Skip to main content

fpzip_rs/codec/
range_decoder.rs

1use super::rc_qs_model::RCQsModel;
2
3/// Range coder (arithmetic decoder) for entropy decoding.
4pub struct RangeDecoder<'a> {
5    input: &'a [u8],
6    pos: usize,
7    low: u32,
8    range: u32,
9    code: u32,
10    error: bool,
11}
12
13impl<'a> RangeDecoder<'a> {
14    /// Creates a new range decoder reading from the given byte slice.
15    pub fn new(input: &'a [u8]) -> Self {
16        Self {
17            input,
18            pos: 0,
19            low: 0,
20            range: 0xFFFFFFFF,
21            code: 0,
22            error: false,
23        }
24    }
25
26    /// Whether an EOF error was encountered.
27    pub fn error(&self) -> bool {
28        self.error
29    }
30
31    /// Number of bytes read so far.
32    pub fn bytes_read(&self) -> usize {
33        self.pos
34    }
35
36    /// Initializes the decoder by reading the first 4 bytes.
37    pub fn init(&mut self) {
38        self.error = false;
39        self.get(4);
40    }
41
42    /// Decodes a single bit.
43    #[inline]
44    pub fn decode_bit(&mut self) -> bool {
45        self.range >>= 1;
46        let s = self.code >= self.low.wrapping_add(self.range);
47        if s {
48            self.low = self.low.wrapping_add(self.range);
49        }
50        self.normalize();
51        s
52    }
53
54    /// Decodes a symbol using a probability model.
55    #[inline]
56    pub fn decode_with_model(&mut self, model: &mut RCQsModel) -> u32 {
57        model.normalize(&mut self.range);
58        let mut l = self.code.wrapping_sub(self.low) / self.range;
59        let mut r = 0u32;
60        let s = model.decode(&mut l, &mut r);
61        self.low = self.low.wrapping_add(self.range.wrapping_mul(l));
62        self.range = self.range.wrapping_mul(r);
63        self.normalize();
64        s
65    }
66
67    /// Decodes an n-bit unsigned integer (0 <= result < 2^n).
68    #[inline]
69    pub fn decode_uint(&mut self, n: i32) -> u32 {
70        if n <= 0 {
71            return 0;
72        }
73        let mut s = 0u32;
74        let mut m = 0;
75        let mut n = n;
76
77        while n > 16 {
78            s += self.decode_shift(16) << m;
79            m += 16;
80            n -= 16;
81        }
82
83        (self.decode_shift(n) << m) + s
84    }
85
86    /// Decodes a 64-bit unsigned integer with n bits.
87    pub fn decode_ulong(&mut self, n: i32) -> u64 {
88        if n <= 0 {
89            return 0;
90        }
91        let mut s = 0u64;
92        let mut m = 0;
93        let mut n = n;
94
95        while n > 16 {
96            s += (self.decode_shift(16) as u64) << m;
97            m += 16;
98            n -= 16;
99        }
100
101        ((self.decode_shift(n) as u64) << m) + s
102    }
103
104    /// Decodes using shift (for power-of-2 ranges).
105    #[inline]
106    fn decode_shift(&mut self, n: i32) -> u32 {
107        self.range >>= n as u32;
108        let s = self.code.wrapping_sub(self.low) / self.range;
109        self.low = self.low.wrapping_add(self.range.wrapping_mul(s));
110        self.normalize();
111        s
112    }
113
114    /// Normalizes the range and inputs new data.
115    #[inline]
116    fn normalize(&mut self) {
117        while ((self.low ^ self.low.wrapping_add(self.range)) >> 24) == 0 {
118            self.get(1);
119            self.range <<= 8;
120        }
121        if (self.range >> 16) == 0 {
122            self.get(2);
123            self.range = 0u32.wrapping_sub(self.low);
124        }
125    }
126
127    /// Inputs n bytes from the stream.
128    #[inline]
129    fn get(&mut self, n: i32) {
130        for _ in 0..n {
131            self.code <<= 8;
132            self.code |= self.get_byte() as u32;
133            self.low <<= 8;
134        }
135    }
136
137    /// Reads a single byte.
138    #[inline]
139    fn get_byte(&mut self) -> u8 {
140        if self.pos >= self.input.len() {
141            self.error = true;
142            return 0;
143        }
144        let b = self.input[self.pos];
145        self.pos += 1;
146        b
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Driving the decoder past the end of its input must raise the EOF flag.
    #[test]
    fn eof_sets_error() {
        // Four bytes: exactly what `init` consumes, leaving nothing to decode.
        let input = [0u8; 4];
        let mut decoder = RangeDecoder::new(&input);
        decoder.init();
        assert!(!decoder.error());

        // Every bit halves the range, so `normalize` soon has to pull in
        // bytes that do not exist.
        (0..100).for_each(|_| {
            decoder.decode_bit();
        });
        assert!(decoder.error());
    }
}