Skip to main content

fpzip_rs/codec/
rc_qs_model.rs

1extern crate alloc;
2use alloc::vec;
3use alloc::vec::Vec;
4
/// log2 of the number of buckets in the decoder's symbol lookup table.
///
/// A cumulative frequency `l` falls into bucket `l >> (bits - TABLE_SHIFT)`, so the
/// table holds `(1 << TABLE_SHIFT) + 1` entries (one extra for the upper bound of the
/// last bucket).
const TABLE_SHIFT: i32 = 7;

/// Quasi-Static adaptive probability model for the range coder.
///
/// Per-symbol counts are accumulated in `symf` and only folded into the cumulative
/// table `cumf` every `rescale` coded symbols, which keeps encoding and decoding cheap
/// while still adapting to the symbol statistics.
#[derive(Debug, Clone)]
pub struct RCQsModel {
    /// Number of symbols in the alphabet.
    symbols: usize,
    /// log2 of the total frequency count; `cumf[symbols] == 1 << bits`.
    bits: i32,
    /// Symbol updates remaining before the next call to `update`.
    left: i32,
    /// Symbol updates still owed `incr + 1` once `left` runs out.
    more: i32,
    /// Frequency increment added to a symbol each time it is coded.
    incr: u32,
    /// Current interval (in coded symbols) between rescales; grows toward `target_rescale`.
    rescale: i32,
    /// Target interval between rescales (the `period` constructor argument).
    target_rescale: i32,
    /// Partially updated per-symbol frequencies.
    symf: Vec<u32>,
    /// Cumulative frequencies; symbol `i` owns the interval `cumf[i]..cumf[i + 1]`.
    cumf: Vec<u32>,
    /// Shift that maps a cumulative frequency to a lookup-table bucket.
    search_shift: i32,
    /// Decoder-only lookup table bounding the symbol search; `None` when compressing.
    search: Option<Vec<u32>>,
}
21
22impl RCQsModel {
23    /// Creates a new quasi-static probability model.
24    ///
25    /// `compress`: true for compression, false for decompression.
26    /// `symbols`: number of symbols.
27    /// `bits`: log2 of total frequency count (must be <= 16).
28    /// `period`: max symbols between normalizations.
29    pub fn new(compress: bool, symbols: usize, bits: i32, period: i32) -> Self {
30        assert!(bits <= 16, "bits must be <= 16");
31        assert!(period < (1 << (bits + 1)), "period too large");
32
33        let n = symbols;
34        let symf = vec![0u32; n + 1];
35        let mut cumf = vec![0u32; n + 1];
36        cumf[0] = 0;
37        cumf[n] = 1u32 << bits;
38
39        let (search, search_shift) = if compress {
40            (None, 0)
41        } else {
42            let ss = bits - TABLE_SHIFT;
43            let s = vec![0u32; (1 << TABLE_SHIFT) + 1];
44            (Some(s), ss)
45        };
46
47        let mut model = Self {
48            symbols,
49            bits,
50            left: 0,
51            more: 0,
52            incr: 0,
53            rescale: 0,
54            target_rescale: period,
55            symf,
56            cumf,
57            search_shift,
58            search,
59        };
60        model.reset();
61        model
62    }
63
64    /// Creates a new model with default bits=16 and period=0x400.
65    pub fn with_defaults(compress: bool, symbols: usize) -> Self {
66        Self::new(compress, symbols, 16, 0x400)
67    }
68
69    pub fn symbols(&self) -> usize {
70        self.symbols
71    }
72
73    /// Reinitializes the model to uniform distribution.
74    pub fn reset(&mut self) {
75        let n = self.symbols;
76        self.rescale = (n as i32 >> 4) | 2;
77        self.more = 0;
78
79        let total_freq = self.cumf[n];
80        let f = total_freq / n as u32;
81        let m = total_freq % n as u32;
82
83        for i in 0..m as usize {
84            self.symf[i] = f + 1;
85        }
86        for i in m as usize..n {
87            self.symf[i] = f;
88        }
89
90        self.update();
91    }
92
93    /// Gets the cumulative and individual frequencies for encoding symbol s.
94    #[inline]
95    pub fn encode(&mut self, s: u32) -> (u32, u32) {
96        let cum_freq = self.cumf[s as usize];
97        let freq = self.cumf[s as usize + 1] - cum_freq;
98        self.update_symbol(s);
99        (cum_freq, freq)
100    }
101
102    /// Returns the symbol corresponding to cumulative frequency l.
103    /// Updates l to the cumulative frequency and returns (symbol, freq).
104    pub fn decode(&mut self, l: &mut u32, r: &mut u32) -> u32 {
105        let search = self.search.as_ref().unwrap();
106        let i = (*l >> self.search_shift) as usize;
107        let mut s = search[i];
108        let mut h = search[i + 1] + 1;
109
110        // Binary search
111        while s + 1 < h {
112            let m = (s + h) >> 1;
113            if *l < self.cumf[m as usize] {
114                h = m;
115            } else {
116                s = m;
117            }
118        }
119
120        *l = self.cumf[s as usize];
121        *r = self.cumf[s as usize + 1] - *l;
122        self.update_symbol(s);
123
124        s
125    }
126
127    /// Normalizes the range by shifting right by bits.
128    #[inline]
129    pub fn normalize(&self, r: &mut u32) {
130        *r >>= self.bits;
131    }
132
133    /// Main update routine - rescales frequencies and rebuilds tables.
134    fn update(&mut self) {
135        if self.more > 0 {
136            self.left = self.more;
137            self.more = 0;
138            self.incr += 1;
139            return;
140        }
141
142        if self.rescale != self.target_rescale {
143            self.rescale *= 2;
144            if self.rescale > self.target_rescale {
145                self.rescale = self.target_rescale;
146            }
147        }
148
149        let n = self.symbols;
150        let mut cf = self.cumf[n];
151        let mut count = cf;
152
153        for i in (0..n).rev() {
154            let mut sf = self.symf[i];
155            cf -= sf;
156            self.cumf[i] = cf;
157            sf = (sf >> 1) | 1; // halve with odd bit set
158            count -= sf;
159            self.symf[i] = sf;
160        }
161
162        self.incr = count / self.rescale as u32;
163        self.more = (count % self.rescale as u32) as i32;
164        self.left = self.rescale - self.more;
165
166        // Build lookup table
167        if let Some(ref mut search) = self.search {
168            let mut h = 1i32 << TABLE_SHIFT;
169            for i in (0..n).rev() {
170                let new_h = (self.cumf[i] >> self.search_shift) as i32;
171                for l in new_h..=h {
172                    search[l as usize] = i as u32;
173                }
174                h = new_h;
175            }
176        }
177    }
178
179    /// Updates frequency for a single symbol.
180    #[inline]
181    fn update_symbol(&mut self, s: u32) {
182        if self.left == 0 {
183            self.update();
184        }
185        self.left -= 1;
186        self.symf[s as usize] += self.incr;
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a skewed pseudo-random symbol stream with a compression model and checks
    /// that a decompression model with the same parameters recovers every symbol,
    /// cumulative frequency and frequency. The stream is long enough to span many
    /// rescales and lookup-table rebuilds.
    fn assert_roundtrip(symbols: u32, bits: i32, period: i32, count: u32) {
        let mut enc = RCQsModel::new(true, symbols as usize, bits, period);
        let mut dec = RCQsModel::new(false, symbols as usize, bits, period);
        let mut x: u32 = 0x1234_5678;
        for k in 0..count {
            x = x.wrapping_mul(1_103_515_245).wrapping_add(12_345);
            // Minimum of two draws skews the distribution toward small symbols.
            let s = ((x >> 8) % symbols).min((x >> 20) % symbols);
            let (cum, freq) = enc.encode(s);
            assert!(freq > 0, "freq must be > 0 for symbol {s}");
            // Probe a different offset inside the symbol's interval each time.
            let mut l = cum + k % freq;
            let mut r = 0;
            assert_eq!(dec.decode(&mut l, &mut r), s, "wrong symbol at step {k}");
            assert_eq!(l, cum, "wrong cumulative frequency at step {k}");
            assert_eq!(r, freq, "wrong frequency at step {k}");
        }
    }

    #[test]
    fn new_default_params() {
        let m = RCQsModel::with_defaults(true, 65);
        assert_eq!(m.symbols(), 65);
    }

    #[test]
    #[should_panic(expected = "bits must be <= 16")]
    fn bits_too_large() {
        RCQsModel::new(true, 10, 17, 0x400);
    }

    #[test]
    #[should_panic(expected = "period too large")]
    fn period_too_large() {
        RCQsModel::new(true, 10, 16, 1 << 17);
    }

    #[test]
    fn compress_mode_no_search_table() {
        let m = RCQsModel::new(true, 10, 16, 0x400);
        assert!(m.search.is_none());
    }

    #[test]
    fn decompress_mode_has_search_table() {
        let m = RCQsModel::new(false, 10, 16, 0x400);
        assert!(m.search.is_some());
    }

    #[test]
    fn encode_returns_valid_frequencies() {
        let mut m = RCQsModel::with_defaults(true, 65);
        for s in 0..65u32 {
            let (cum, freq) = m.encode(s);
            assert!(freq > 0, "freq must be > 0 for symbol {s}");
            assert!(
                cum + freq <= (1 << 16),
                "cumulative overflow for symbol {s}"
            );
        }
    }

    #[test]
    fn reset_restores_uniform() {
        let mut m = RCQsModel::with_defaults(true, 10);
        // Encode some symbols to change distribution
        for _ in 0..100 {
            m.encode(0);
        }
        m.reset();
        // After reset, frequencies should be roughly uniform
        let (_, f0) = m.encode(0);
        let (_, f5) = m.encode(5);
        // They should be close after reset
        let diff = (f0 as i64 - f5 as i64).unsigned_abs();
        assert!(diff <= 2, "frequencies should be roughly equal after reset");
    }

    #[test]
    fn reset_splits_total_exactly() {
        let mut m = RCQsModel::with_defaults(true, 10);
        for _ in 0..500 {
            m.encode(0);
        }
        m.reset();
        // 65536 = 10 * 6553 + 6: the first six symbols carry the remainder.
        for s in 0..10 {
            let expected = if s < 6 { 6554 } else { 6553 };
            assert_eq!(m.cumf[s + 1] - m.cumf[s], expected, "symbol {s}");
        }
    }

    #[test]
    fn decode_inverts_encode_default_model() {
        assert_roundtrip(65, 16, 0x400, 4000);
    }

    #[test]
    fn decode_inverts_encode_edge_configurations() {
        // Single symbol: the whole frequency range belongs to symbol 0.
        assert_roundtrip(1, 16, 0x400, 100);
        // Minimum decoder precision (bits == TABLE_SHIFT, search_shift == 0).
        assert_roundtrip(2, 7, 16, 2000);
        // As many symbols as frequency counts: every symbol has frequency 1.
        assert_roundtrip(128, 7, 200, 2000);
    }
}