Skip to main content

fpzip_rs/codec/
range_encoder.rs

1use super::rc_qs_model::RCQsModel;
2
3extern crate alloc;
4use alloc::vec::Vec;
5
/// Range coder (arithmetic encoder) for entropy coding.
///
/// The encoder keeps a 32-bit coding interval described by `low` and `range`
/// and appends finished bytes to an in-memory buffer as the interval narrows.
#[derive(Debug, Clone)]
pub struct RangeEncoder {
    /// Compressed bytes emitted so far.
    output: Vec<u8>,
    /// Low end of the current coding interval; its top byte is the next byte
    /// written to `output`.
    low: u32,
    /// Width of the current coding interval.
    range: u32,
}
12
13impl Default for RangeEncoder {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl RangeEncoder {
20    /// Creates a new range encoder.
21    pub fn new() -> Self {
22        Self {
23            output: Vec::new(),
24            low: 0,
25            range: 0xFFFFFFFF,
26        }
27    }
28
29    /// Creates a new range encoder with pre-allocated capacity.
30    pub fn with_capacity(capacity: usize) -> Self {
31        Self {
32            output: Vec::with_capacity(capacity),
33            low: 0,
34            range: 0xFFFFFFFF,
35        }
36    }
37
38    /// Returns the number of bytes written so far.
39    pub fn bytes_written(&self) -> usize {
40        self.output.len()
41    }
42
43    /// Finishes encoding and returns the compressed data.
44    pub fn finish(mut self) -> Vec<u8> {
45        self.put(4);
46        self.output
47    }
48
49    /// Returns reference to the output buffer.
50    pub fn output(&self) -> &[u8] {
51        &self.output
52    }
53
54    /// Consumes and returns the output buffer without finishing.
55    pub fn into_output(self) -> Vec<u8> {
56        self.output
57    }
58
59    /// Encodes a single bit.
60    #[inline]
61    pub fn encode_bit(&mut self, bit: bool) {
62        self.range >>= 1;
63        if bit {
64            self.low = self.low.wrapping_add(self.range);
65        }
66        self.normalize();
67    }
68
69    /// Encodes a symbol using a probability model.
70    #[inline]
71    pub fn encode_with_model(&mut self, symbol: u32, model: &mut RCQsModel) {
72        let (l, r) = model.encode(symbol);
73        model.normalize(&mut self.range);
74        self.low = self.low.wrapping_add(self.range.wrapping_mul(l));
75        self.range = self.range.wrapping_mul(r);
76        self.normalize();
77    }
78
79    /// Encodes an n-bit unsigned integer (0 <= s < 2^n).
80    #[inline]
81    pub fn encode_uint(&mut self, s: u32, n: i32) {
82        if n <= 0 {
83            return;
84        }
85        let mut s = s;
86        let mut n = n;
87        if n > 16 {
88            self.encode_shift(s & 0xFFFF, 16);
89            s >>= 16;
90            n -= 16;
91        }
92        self.encode_shift(s, n);
93    }
94
95    /// Encodes a 64-bit unsigned integer with n bits.
96    pub fn encode_ulong(&mut self, s: u64, n: i32) {
97        if n <= 0 {
98            return;
99        }
100        let mut s = s;
101        let mut n = n;
102        while n > 16 {
103            self.encode_shift((s & 0xFFFF) as u32, 16);
104            s >>= 16;
105            n -= 16;
106        }
107        self.encode_shift(s as u32, n);
108    }
109
110    /// Encodes an integer using shift (for power-of-2 ranges).
111    #[inline]
112    fn encode_shift(&mut self, s: u32, n: i32) {
113        self.range >>= n as u32;
114        self.low = self.low.wrapping_add(self.range.wrapping_mul(s));
115        self.normalize();
116    }
117
118    /// Normalizes the range and outputs fixed bits.
119    #[inline]
120    fn normalize(&mut self) {
121        while ((self.low ^ self.low.wrapping_add(self.range)) >> 24) == 0 {
122            self.put(1);
123            self.range <<= 8;
124        }
125        if (self.range >> 16) == 0 {
126            self.put(2);
127            self.range = 0u32.wrapping_sub(self.low);
128        }
129    }
130
131    /// Outputs n bytes from the high bits of low.
132    #[inline]
133    fn put(&mut self, n: i32) {
134        for _ in 0..n {
135            self.output.push((self.low >> 24) as u8);
136            self.low <<= 8;
137        }
138    }
139}
140
#[cfg(test)]
mod tests {
    //! Round-trip tests: everything written by `RangeEncoder` is read back
    //! with `RangeDecoder` and compared against the original values.

    use super::*;
    use crate::codec::range_decoder::RangeDecoder;

    /// Individual bits survive an encode/decode round trip in order.
    #[test]
    fn encode_decode_bits() {
        let bits = [true, false, true, true, false];

        let mut enc = RangeEncoder::new();
        for &bit in &bits {
            enc.encode_bit(bit);
        }
        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        for &bit in &bits {
            assert_eq!(dec.decode_bit(), bit);
        }
    }

    /// An 8-bit value round-trips.
    #[test]
    fn encode_decode_uint8() {
        let mut enc = RangeEncoder::new();
        enc.encode_uint(0xAB, 8);
        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        assert_eq!(dec.decode_uint(8), 0xAB);
    }

    /// A 16-bit value (exactly one chunk) round-trips.
    #[test]
    fn encode_decode_uint16() {
        let mut enc = RangeEncoder::new();
        enc.encode_uint(0xABCD, 16);
        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        assert_eq!(dec.decode_uint(16), 0xABCD);
    }

    /// A 32-bit value (two chunks) round-trips.
    #[test]
    fn encode_decode_uint32() {
        let mut enc = RangeEncoder::new();
        enc.encode_uint(0xDEADBEEF, 32);
        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        assert_eq!(dec.decode_uint(32), 0xDEADBEEF);
    }

    /// A 64-bit value (four chunks) round-trips.
    #[test]
    fn encode_decode_ulong64() {
        let mut enc = RangeEncoder::new();
        enc.encode_ulong(0xDEADBEEFCAFEBABE, 64);
        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        assert_eq!(dec.decode_ulong(64), 0xDEADBEEFCAFEBABE);
    }

    /// Model-coded symbols round-trip when both sides start from the same
    /// default model.
    #[test]
    fn encode_decode_with_model() {
        let symbols = [32u32, 0, 64]; // 32 is the bias symbol

        let mut enc = RangeEncoder::new();
        let mut model = RCQsModel::with_defaults(true, 65);
        for &s in &symbols {
            enc.encode_with_model(s, &mut model);
        }
        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        let mut dmodel = RCQsModel::with_defaults(false, 65);
        for &expected in &symbols {
            assert_eq!(dec.decode_with_model(&mut dmodel), expected);
        }
    }

    /// `bytes_written` starts at zero, mirrors the output buffer, and
    /// `finish` appends exactly the four flush bytes.
    #[test]
    fn bytes_written_tracking() {
        let mut enc = RangeEncoder::new();
        assert_eq!(enc.bytes_written(), 0);

        enc.encode_bit(true);
        assert_eq!(enc.bytes_written(), enc.output().len());

        let pending = enc.bytes_written();
        let data = enc.finish();
        assert!(!data.is_empty());
        assert_eq!(data.len(), pending + 4);
    }

    /// Bit, fixed-width and model-based coding can be interleaved freely.
    #[test]
    fn mixed_encoding_modes() {
        let mut enc = RangeEncoder::new();
        let mut model = RCQsModel::with_defaults(true, 10);

        enc.encode_bit(true);
        enc.encode_uint(42, 8);
        enc.encode_with_model(5, &mut model);
        enc.encode_uint(0x1234, 16);
        enc.encode_bit(false);

        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        let mut dmodel = RCQsModel::with_defaults(false, 10);

        assert!(dec.decode_bit());
        assert_eq!(dec.decode_uint(8), 42);
        assert_eq!(dec.decode_with_model(&mut dmodel), 5);
        assert_eq!(dec.decode_uint(16), 0x1234);
        assert!(!dec.decode_bit());
    }

    /// A long symbol stream round-trips through the model coder.
    #[test]
    fn large_symbol_sequence() {
        let mut enc = RangeEncoder::new();
        let mut model = RCQsModel::with_defaults(true, 65);

        let symbols: Vec<u32> = (0..1000).map(|i| (i % 65) as u32).collect();
        for &s in &symbols {
            enc.encode_with_model(s, &mut model);
        }
        let data = enc.finish();

        let mut dec = RangeDecoder::new(&data);
        dec.init();
        let mut dmodel = RCQsModel::with_defaults(false, 65);

        for &expected in &symbols {
            assert_eq!(dec.decode_with_model(&mut dmodel), expected);
        }
    }
}