Skip to main content

opus_rs/
range_coder.rs

/// Number of bits emitted per output symbol (one byte).
pub const EC_SYM_BITS: u32 = 8;
/// Width in bits of the coder's `rng` / `val` registers.
pub const EC_CODE_BITS: u32 = 32;
/// Largest value a single output symbol can take (`0xFF`).
pub const EC_SYM_MAX: u32 = (1u32 << EC_SYM_BITS) - 1;
/// Right shift that extracts the top output symbol from the encoder's `val`.
pub const EC_CODE_SHIFT: u32 = EC_CODE_BITS - (EC_SYM_BITS + 1);
/// Top of the coding interval; also the encoder's initial `rng`.
pub const EC_CODE_TOP: u32 = 1u32 << (EC_CODE_BITS - 1);
/// Renormalisation threshold: the interval is rescaled while `rng <= EC_CODE_BOT`.
pub const EC_CODE_BOT: u32 = 1u32 << (EC_CODE_BITS - 1 - EC_SYM_BITS);
/// Number of bits of the first input byte the decoder consumes on initialisation.
pub const EC_CODE_EXTRA: u32 = 1 + (EC_CODE_BITS - 2) % EC_SYM_BITS;
/// Fractional-bit resolution: `tell_frac` reports in units of `1 / (1 << BITRES)` bit.
pub const BITRES: i32 = 3;
9
/// Inline expansion of `RangeCoder::tell_frac`: the number of bits used so far,
/// in units of `1 / (1 << BITRES)` bit.
///
/// `$rc` must expose `nbits_total: i32` and `rng: u32`. It is evaluated more
/// than once, so pass a place expression (e.g. `self` or a local) rather than
/// something with side effects. `BITRES` is resolved at the call site and must
/// be in scope there.
#[macro_export]
macro_rules! tell_frac_inline {
    ($rc:expr) => {{
        static CORRECTION: [u32; 8] = [35733, 38967, 42495, 46340, 50535, 55109, 60097, 65535];
        let nbits = $rc.nbits_total << BITRES;
        let l = 32 - $rc.rng.leading_zeros() as i32;
        let r = $rc.rng >> (l - 16);
        let b = (r >> 12).wrapping_sub(8);
        // `b` is in 0..8 whenever `rng >= 1 << 15`, which normalisation
        // guarantees (it keeps `rng > EC_CODE_BOT`). Masking keeps the lookup in
        // bounds even for a corrupt state, where an unchecked access would be
        // undefined behaviour; valid states produce identical results.
        let b = b + (r > CORRECTION[(b & 7) as usize]) as u32;
        nbits - (l << 3) - b as i32
    }};
}
24
/// Range coder state shared by the encoder and the decoder.
///
/// Range-coded symbols occupy the front of `buf` (growing up from index 0),
/// while raw bits written with `enc_bits` / read with `dec_bits` occupy the
/// back (growing down from `storage`).
#[derive(Clone, Debug)]
pub struct RangeCoder {
    /// Backing byte buffer; reads and writes are bounded by `storage`.
    pub buf: Vec<u8>,
    /// Number of usable bytes in `buf`.
    pub storage: u32,
    /// Number of raw-bit bytes written (encoder) or read (decoder) at the end of the buffer.
    pub end_offs: u32,
    /// Window of raw bits not yet written (encoder) or not yet consumed (decoder).
    pub end_window: u32,
    /// Number of valid bits held in `end_window`.
    pub nend_bits: i32,
    /// Running bit count used by `tell` / `tell_frac`.
    pub nbits_total: i32,
    /// Number of range-coded bytes written (encoder) or read (decoder) at the front.
    pub offs: u32,
    /// Current size of the coding interval.
    pub rng: u32,
    /// Encoder: low end of the coding interval.
    /// Decoder: the coded value measured down from the top of the interval
    /// (initialised as `rng - 1 - ...`).
    pub val: u32,
    /// Encoder: number of pending `EC_SYM_MAX` symbols whose final value depends on a later carry.
    /// Decoder: the scale factor saved by `decode` for the following `update`.
    pub ext: u32,
    /// Encoder: buffered output byte still subject to carry (`-1` = nothing buffered yet).
    /// Decoder: the most recently read input byte.
    pub rem: i32,
    /// Nonzero once an error (e.g. running out of buffer space) has occurred.
    pub error: i32,
}
40
41impl RangeCoder {
42    pub fn new_encoder(size: u32) -> Self {
43        let buf = vec![0u8; size as usize];
44        RangeCoder {
45            buf,
46            storage: size,
47            end_offs: 0,
48            end_window: 0,
49            nend_bits: 0,
50            nbits_total: 33,
51            offs: 0,
52            rng: 1 << 31,
53            val: 0,
54            ext: 0,
55            rem: -1,
56            error: 0,
57        }
58    }
59
60    #[inline]
61    pub fn reset_for_encode(&mut self, size: u32) {
62        if self.buf.len() < size as usize {
63            self.buf.resize(size as usize, 0);
64        }
65        unsafe {
66            self.buf.set_len(size as usize);
67        }
68        self.storage = size;
69        self.end_offs = 0;
70        self.end_window = 0;
71        self.nend_bits = 0;
72        self.nbits_total = 33;
73        self.offs = 0;
74        self.rng = 1 << 31;
75        self.val = 0;
76        self.ext = 0;
77        self.rem = -1;
78        self.error = 0;
79    }
80
81    pub fn new_decoder(data: &[u8]) -> Self {
82        let storage = data.len() as u32;
83        let buf = data.to_vec();
84        let mut rc = RangeCoder {
85            buf,
86            storage,
87            end_offs: 0,
88            end_window: 0,
89            nend_bits: 0,
90            nbits_total: (EC_CODE_BITS + 1
91                - ((EC_CODE_BITS - EC_CODE_EXTRA) / EC_SYM_BITS) * EC_SYM_BITS)
92                as i32,
93            offs: 0,
94            rng: 1 << EC_CODE_EXTRA,
95            val: 0,
96            ext: 0,
97            rem: 0,
98            error: 0,
99        };
100
101        rc.rem = rc.read_byte() as i32;
102        rc.val = rc
103            .rng
104            .wrapping_sub(1)
105            .wrapping_sub(rc.rem as u32 >> (EC_SYM_BITS - EC_CODE_EXTRA));
106
107        rc.normalize_decoder();
108        rc
109    }
110
111    #[inline(always)]
112    fn normalize_decoder(&mut self) {
113        let mut guard = 0u32;
114        while self.rng <= EC_CODE_BOT {
115            guard += 1;
116            if guard > 100 {
117                self.error = 1;
118                self.rng = EC_CODE_BOT + 1;
119                break;
120            }
121            self.nbits_total += EC_SYM_BITS as i32;
122            self.rng <<= EC_SYM_BITS;
123
124            let sym = self.rem;
125            self.rem = self.read_byte() as i32;
126
127            let combined_sym = ((sym << EC_SYM_BITS) | self.rem) >> (EC_SYM_BITS - EC_CODE_EXTRA);
128            self.val = (self.val << EC_SYM_BITS).wrapping_add(EC_SYM_MAX & !combined_sym as u32)
129                & (EC_CODE_TOP - 1);
130        }
131    }
132
133    fn read_byte(&mut self) -> u8 {
134        if self.offs < self.storage {
135            let b = self.buf[self.offs as usize];
136            self.offs += 1;
137            b
138        } else {
139            0
140        }
141    }
142
143    #[inline(always)]
144    pub fn enc_uint(&mut self, fl: u32, ft: u32) {
145        if ft > 1 {
146            let ft_minus_1 = ft - 1;
147            let ftb = 32 - ft_minus_1.leading_zeros() as i32;
148            if ftb > 8 {
149                let s = ftb - 8;
150                let fl_low = fl & ((1u32 << s) - 1);
151                let fl = fl >> s;
152                let ft = (ft_minus_1 >> s) + 1;
153                self.encode(fl, fl.wrapping_add(1), ft);
154                self.enc_bits(fl_low, s as u32);
155            } else {
156                self.encode(fl, fl.wrapping_add(1), ft);
157            }
158        }
159    }
160
161    #[inline(always)]
162    pub fn dec_uint(&mut self, ft: u32) -> u32 {
163        if ft > 1 {
164            let ft_minus_1 = ft - 1;
165            let ftb = 32 - ft_minus_1.leading_zeros() as i32;
166            if ftb > 8 {
167                let s = ftb - 8;
168                let ft = (ft_minus_1 >> s) + 1;
169                let fs = self.decode(ft);
170                self.update(fs, fs.wrapping_add(1), ft);
171                let r = self.dec_bits(s as u32);
172                (fs << s) | r
173            } else {
174                let fs = self.decode(ft);
175                self.update(fs, fs.wrapping_add(1), ft);
176                fs
177            }
178        } else {
179            0
180        }
181    }
182
183    #[inline(always)]
184    pub fn enc_bits(&mut self, val: u32, bits: u32) {
185        if bits == 0 {
186            return;
187        }
188        let mut window = self.end_window;
189        let mut used = self.nend_bits;
190        if (used as u32) + bits > EC_CODE_BITS {
191            while used >= EC_SYM_BITS as i32 {
192                self.write_byte_at_end((window & EC_SYM_MAX) as u8);
193                window >>= EC_SYM_BITS;
194                used -= EC_SYM_BITS as i32;
195            }
196        }
197        window |= (val & ((1 << bits) - 1)) << used;
198        used += bits as i32;
199        self.end_window = window;
200        self.nend_bits = used;
201        self.nbits_total += bits as i32;
202    }
203
204    pub fn pad_to_bits(&mut self, target_bits: i32) {
205        let remaining = target_bits - self.nbits_total;
206        if remaining <= 0 {
207            return;
208        }
209        let mut remaining = remaining as u32;
210
211        let partial =
212            (EC_SYM_BITS - (self.nend_bits as u32 & (EC_SYM_BITS - 1))) & (EC_SYM_BITS - 1);
213        if partial > 0 && remaining >= partial {
214            self.enc_bits(0, partial.min(remaining));
215            remaining -= partial.min(remaining);
216        }
217
218        let full_bytes = remaining / EC_SYM_BITS;
219        if full_bytes > 0 {
220            let available = self.storage - self.offs - self.end_offs;
221            let write_count = full_bytes.min(available);
222            if write_count > 0 {
223                let start = (self.storage - self.end_offs - write_count) as usize;
224                unsafe {
225                    std::ptr::write_bytes(
226                        self.buf.as_mut_ptr().add(start),
227                        0,
228                        write_count as usize,
229                    );
230                }
231                self.end_offs += write_count;
232            }
233            if write_count < full_bytes {
234                self.error = 1;
235            }
236            self.nbits_total += (full_bytes * EC_SYM_BITS) as i32;
237            remaining -= full_bytes * EC_SYM_BITS;
238        }
239
240        if remaining > 0 {
241            self.enc_bits(0, remaining);
242        }
243    }
244
245    pub fn dec_bits(&mut self, bits: u32) -> u32 {
246        if bits == 0 {
247            return 0;
248        }
249        let mut window = self.end_window;
250        let mut used = self.nend_bits;
251        if used < bits as i32 {
252            loop {
253                let byte = if self.end_offs < self.storage {
254                    self.end_offs += 1;
255                    self.buf[(self.storage - self.end_offs) as usize]
256                } else {
257                    0
258                };
259                window |= (byte as u32) << used;
260                used += 8;
261                if used > 32 - 8 {
262                    break;
263                }
264            }
265        }
266        let ret = window & ((1 << bits) - 1);
267        self.end_window = window >> bits;
268        self.nend_bits = used - bits as i32;
269        self.nbits_total += bits as i32;
270        ret
271    }
272
273    #[inline(always)]
274    pub fn tell_frac(&self) -> i32 {
275        static CORRECTION: [u32; 8] = [35733, 38967, 42495, 46340, 50535, 55109, 60097, 65535];
276        let nbits = self.nbits_total << BITRES;
277        let l = 32 - self.rng.leading_zeros() as i32;
278        let r = self.rng >> (l - 16);
279        let b = (r >> 12).wrapping_sub(8);
280        let b = b + (r > CORRECTION[b as usize]) as u32;
281        nbits - (l << 3) - b as i32
282    }
283
284    /// Integer bit count matching C's `ec_tell()`: `nbits_total - EC_ILOG(rng)`.
285    #[inline(always)]
286    pub fn tell(&self) -> i32 {
287        self.nbits_total - (32 - self.rng.leading_zeros() as i32)
288    }
289
290    #[inline(always)]
291    pub fn tell_fast(&self) -> i32 {
292        self.nbits_total
293    }
294
295    /// Shrink the range coder buffer, moving end-coded bytes to the new end.
296    /// Equivalent to C's `ec_enc_shrink`.
297    pub fn shrink(&mut self, new_size: u32) {
298        debug_assert!(self.offs + self.end_offs <= new_size);
299        if self.end_offs > 0 {
300            let old_end_start = (self.storage - self.end_offs) as usize;
301            let old_end_end = self.storage as usize;
302            let new_end_start = (new_size - self.end_offs) as usize;
303            self.buf
304                .copy_within(old_end_start..old_end_end, new_end_start);
305        }
306        self.storage = new_size;
307    }
308
309    #[inline(always)]
310    fn write_byte(&mut self, value: u8) {
311        if self.offs + self.end_offs < self.storage {
312            unsafe {
313                *self.buf.get_unchecked_mut(self.offs as usize) = value;
314            }
315            self.offs += 1;
316        } else {
317            self.error = 1;
318        }
319    }
320
321    #[inline(always)]
322    fn carry_out(&mut self, c: i32) {
323        if c != EC_SYM_MAX as i32 {
324            let carry = c >> EC_SYM_BITS;
325            if self.rem >= 0 {
326                self.write_byte((self.rem + carry) as u8);
327            }
328            if self.ext > 0 {
329                let sym = (EC_SYM_MAX as i32 + carry) & EC_SYM_MAX as i32;
330
331                let ext = self.ext as usize;
332                for _j in 0..ext {
333                    self.write_byte(sym as u8);
334                }
335                self.ext = 0;
336            }
337            self.rem = c & EC_SYM_MAX as i32;
338        } else {
339            self.ext += 1;
340        }
341    }
342
343    #[inline(always)]
344    fn celt_udiv(n: u32, d: u32) -> u32 {
345        debug_assert!(d > 0);
346        n / d
347    }
348
349    #[inline(always)]
350    pub fn encode(&mut self, fl: u32, fh: u32, ft: u32) {
351        debug_assert!(ft > 0, "encode: ft must be > 0");
352        let r = Self::celt_udiv(self.rng, ft);
353        if fl > 0 {
354            self.val = self
355                .val
356                .wrapping_add(self.rng.wrapping_sub(r.wrapping_mul(ft.wrapping_sub(fl))));
357            self.rng = r.wrapping_mul(fh.wrapping_sub(fl));
358        } else {
359            self.rng = self.rng.wrapping_sub(r.wrapping_mul(ft.wrapping_sub(fh)));
360        }
361        self.normalize_encoder();
362    }
363
364    #[inline(always)]
365    fn normalize_encoder(&mut self) {
366        while self.rng <= EC_CODE_BOT {
367            let c = (self.val >> EC_CODE_SHIFT) as i32;
368            if c != EC_SYM_MAX as i32 {
369                let carry = c >> EC_SYM_BITS;
370                if self.rem >= 0 {
371                    if self.offs + self.end_offs < self.storage {
372                        unsafe {
373                            *self.buf.get_unchecked_mut(self.offs as usize) =
374                                (self.rem + carry) as u8;
375                        }
376                        self.offs += 1;
377                    } else {
378                        self.error = 1;
379                    }
380                }
381                if self.ext > 0 {
382                    let sym = (EC_SYM_MAX as i32 + carry) & EC_SYM_MAX as i32;
383                    let ext = self.ext as usize;
384                    for _j in 0..ext {
385                        if self.offs + self.end_offs < self.storage {
386                            unsafe {
387                                *self.buf.get_unchecked_mut(self.offs as usize) = sym as u8;
388                            }
389                            self.offs += 1;
390                        } else {
391                            self.error = 1;
392                        }
393                    }
394                    self.ext = 0;
395                }
396                self.rem = c & EC_SYM_MAX as i32;
397            } else {
398                self.ext += 1;
399            }
400            self.val = (self.val << EC_SYM_BITS) & (EC_CODE_TOP - 1);
401            self.rng <<= EC_SYM_BITS;
402            self.nbits_total = self.nbits_total.wrapping_add(EC_SYM_BITS as i32);
403        }
404    }
405
406    #[inline(always)]
407    pub fn encode_bit_logp(&mut self, val: bool, logp: u32) {
408        let s = self.rng >> logp;
409        let r = self.rng.wrapping_sub(s);
410        if val {
411            self.val = self.val.wrapping_add(r);
412            self.rng = s;
413        } else {
414            self.rng = r;
415        }
416        self.normalize_encoder();
417    }
418
419    #[inline(always)]
420    pub fn encode_icdf(&mut self, s: i32, icdf: &[u8], ftb: u32) {
421        let r = self.rng >> ftb;
422        if s > 0 {
423            let val = unsafe { *icdf.get_unchecked((s - 1) as usize) as u32 };
424            self.val = self
425                .val
426                .wrapping_add(self.rng.wrapping_sub(r.wrapping_mul(val)));
427            let lower = unsafe { *icdf.get_unchecked(s as usize) };
428            self.rng = r.wrapping_mul(val.wrapping_sub(lower as u32));
429        } else {
430            let val = unsafe { *icdf.get_unchecked(s as usize) as u32 };
431            self.rng = self.rng.wrapping_sub(r.wrapping_mul(val));
432        }
433        self.normalize_encoder();
434    }
435
436    #[inline(always)]
437    pub fn decode_bit_logp(&mut self, logp: u32) -> bool {
438        let s = self.rng >> logp;
439        let ret = self.val < s;
440        if !ret {
441            self.val = self.val.wrapping_sub(s);
442            self.rng = self.rng.wrapping_sub(s);
443        } else {
444            self.rng = s;
445        }
446        self.normalize_decoder();
447        ret
448    }
449
450    #[inline(always)]
451    pub fn decode_icdf(&mut self, icdf: &[u8], ftb: u32) -> i32 {
452        let mut s = self.rng;
453        let d = self.val;
454        let r = s >> ftb;
455        let mut ret = 0;
456        let mut t;
457
458        loop {
459            t = s;
460            s = r.wrapping_mul(icdf[ret] as u32);
461            ret += 1;
462            if d >= s {
463                break;
464            }
465        }
466
467        self.val = d.wrapping_sub(s);
468        self.rng = t.wrapping_sub(s);
469        self.normalize_decoder();
470        (ret - 1) as i32
471    }
472
473    #[inline(always)]
474    pub fn decode(&mut self, ft: u32) -> u32 {
475        let r = Self::celt_udiv(self.rng, ft);
476        self.ext = r;
477        let s = self.val / r;
478        ft - ft.min(s.wrapping_add(1))
479    }
480
481    #[inline(always)]
482    pub fn update(&mut self, fl: u32, fh: u32, ft: u32) {
483        let s = self.ext.wrapping_mul(ft.wrapping_sub(fh));
484        self.val = self.val.wrapping_sub(s);
485        self.rng = if fl > 0 {
486            self.ext.wrapping_mul(fh.wrapping_sub(fl))
487        } else {
488            self.rng.wrapping_sub(s)
489        };
490        self.normalize_decoder();
491    }
492
493    pub fn laplace_encode(&mut self, value: &mut i32, fs: u32, decay: i32) {
494        let mut val = *value;
495        let mut fl = 0;
496        let mut fs_val = fs;
497
498        if val != 0 {
499            let s = if val < 0 { -1 } else { 0 };
500            val = (val + s) ^ s;
501            fl = fs_val;
502            fs_val = self.laplace_get_freq1(fs_val, decay);
503
504            let mut i = 1;
505            while fs_val > 0 && i < val {
506                fs_val *= 2;
507                fl += fs_val + 2;
508                fs_val = ((fs_val as i32 * decay) >> 15) as u32;
509                i += 1;
510            }
511
512            if fs_val == 0 {
513                let ndi_max = 32768 - fl + 1 - 1;
514                let ndi_max = (ndi_max as i32 - s) >> 1;
515                let di = (val - i).min(ndi_max - 1);
516                fl += (2 * di + 1 + s) as u32;
517                fs_val = 1u32.min(32768 - fl);
518                *value = (i + di + s) ^ s;
519            } else {
520                fs_val += 1;
521                fl += fs_val & (!s as u32);
522            }
523        }
524        self.encode(fl, fl.wrapping_add(fs_val), 1 << 15);
525    }
526
527    fn laplace_get_freq1(&self, fs0: u32, decay: i32) -> u32 {
528        let ft = 32768 - (2 * 16) - fs0;
529        ((ft as i32 * (16384 - decay)) >> 15) as u32
530    }
531
532    pub fn laplace_decode(&mut self, fs: u32, decay: i32) -> i32 {
533        let fm = self.decode(1 << 15);
534        let mut fl = 0;
535        let mut fs_val = fs;
536        let mut val = 0;
537
538        if fm >= fs_val {
539            val += 1;
540            fl = fs_val;
541            fs_val = self.laplace_get_freq1(fs_val, decay) + 1;
542
543            while fs_val > 1 && fm >= fl + 2 * fs_val {
544                fs_val *= 2;
545                fl += fs_val;
546                fs_val = (((fs_val as i32 - 2) * decay) >> 15) as u32 + 1;
547                val += 1;
548            }
549
550            if fs_val <= 1 {
551                let di = (fm - fl) >> 1;
552                val += di as i32;
553                fl += 2 * di;
554            }
555
556            if fm < fl + fs_val {
557                val = -val;
558            } else {
559                fl += fs_val;
560            }
561        }
562
563        self.update(fl, fl.wrapping_add(fs_val.min(32768 - fl)), 1 << 15);
564        val
565    }
566
567    #[inline(always)]
568    fn write_byte_at_end(&mut self, value: u8) {
569        if self.offs + self.end_offs < self.storage {
570            self.end_offs += 1;
571            let idx = (self.storage - self.end_offs) as usize;
572            unsafe {
573                *self.buf.get_unchecked_mut(idx) = value;
574            }
575        } else {
576            self.error = 1;
577        }
578    }
579
580    pub fn patch_initial_bits(&mut self, val: u32, nbits: u32) {
581        let shift = EC_SYM_BITS - nbits;
582        let mask = ((1u32 << nbits) - 1) << shift;
583        if self.offs > 0 {
584            self.buf[0] = ((self.buf[0] as u32 & !mask) | (val << shift)) as u8;
585        } else if self.rem >= 0 {
586            self.rem = ((self.rem as u32 & !mask) | (val << shift)) as i32;
587        } else if self.rng <= (EC_CODE_TOP >> nbits) {
588            let mask_shifted = mask << EC_CODE_SHIFT;
589            self.val = (self.val & !mask_shifted) | (val << (EC_CODE_SHIFT + shift));
590        } else {
591            self.error = -1;
592        }
593    }
594
595    pub fn done(&mut self) {
596        let ilog = 32 - self.rng.leading_zeros();
597        let mut l = (EC_CODE_BITS - ilog) as i32;
598        let mut msk = (EC_CODE_TOP - 1) >> l;
599        let mut end = (self.val.wrapping_add(msk)) & !msk;
600
601        if (end | msk) >= self.val.wrapping_add(self.rng) {
602            l += 1;
603            msk >>= 1;
604            end = (self.val.wrapping_add(msk)) & !msk;
605        }
606
607        while l > 0 {
608            self.carry_out((end >> EC_CODE_SHIFT) as i32);
609            end = (end << EC_SYM_BITS) & (EC_CODE_TOP - 1);
610            l -= EC_SYM_BITS as i32;
611        }
612
613        if self.rem >= 0 || self.ext > 0 {
614            self.carry_out(0);
615        }
616
617        let mut window = self.end_window;
618        let mut used = self.nend_bits;
619        while used >= EC_SYM_BITS as i32 {
620            self.write_byte_at_end((window & EC_SYM_MAX) as u8);
621            window >>= EC_SYM_BITS;
622            used -= EC_SYM_BITS as i32;
623        }
624
625        if self.error == 0 {
626            for i in self.offs..(self.storage - self.end_offs) {
627                self.buf[i as usize] = 0;
628            }
629
630            if used > 0 {
631                if self.end_offs >= self.storage {
632                    self.error = -1;
633                } else {
634                    // If we've busted, don't add too many extra bits to the
635                    // last byte; it would corrupt the range coder data.
636                    if self.offs + self.end_offs >= self.storage && -l < used {
637                        window &= (1u32 << (-l as u32)) - 1;
638                        self.error = -1;
639                    }
640                    let idx = (self.storage - self.end_offs - 1) as usize;
641                    self.buf[idx] |= window as u8;
642                }
643            }
644        }
645    }
646
647    pub fn finish(&mut self) -> Vec<u8> {
648        self.done();
649
650        let extra_end = if self.nend_bits > 0 && self.end_offs == 0 {
651            1
652        } else {
653            0
654        };
655        let mut result = Vec::with_capacity((self.offs + self.end_offs + extra_end) as usize);
656        result.extend_from_slice(&self.buf[0..self.offs as usize]);
657        if extra_end > 0 {
658            result.push(self.buf[(self.storage - 1) as usize]);
659        }
660        result.extend_from_slice(
661            &self.buf[(self.storage - self.end_offs) as usize..self.storage as usize],
662        );
663        result
664    }
665}
666
#[cfg(test)]
mod tests {
    use super::*;

    /// Copy of the range-coded (front) bytes an encoder has produced so far.
    fn front_bytes(enc: &RangeCoder) -> Vec<u8> {
        enc.buf[..enc.offs as usize].to_vec()
    }

    #[test]
    fn test_laplace() {
        let fs = 100 << 7;
        let decay = 120 << 6;

        let mut enc = RangeCoder::new_encoder(100);
        let mut value = -3;
        enc.laplace_encode(&mut value, fs, decay);
        enc.done();

        // This symbol fits in a single front byte with a known value.
        assert_eq!(enc.offs, 1);
        assert_eq!(enc.buf[0], 224);

        let mut dec = RangeCoder::new_decoder(&front_bytes(&enc));
        assert_eq!(dec.laplace_decode(fs, decay), -3);
    }

    #[test]
    fn test_icdf_consistency() {
        let icdf = [2, 1, 0];

        let mut enc = RangeCoder::new_encoder(1024);
        for symbol in 0..3 {
            enc.encode_icdf(symbol, &icdf, 2);
        }
        enc.done();

        let data = front_bytes(&enc);
        let mut dec = RangeCoder::new_decoder(&data);
        let decoded: Vec<i32> = (0..3).map(|_| dec.decode_icdf(&icdf, 2)).collect();
        assert_eq!(decoded, [0, 1, 2]);
    }

    #[test]
    fn test_icdf_last_symbol_no_oob() {
        let icdf: &[u8] = &[170, 85, 0];
        let ftb = 8u32;

        // Each symbol, including the last one, must survive a round trip.
        for sym in 0..3i32 {
            let mut enc = RangeCoder::new_encoder(256);
            enc.encode_icdf(sym, icdf, ftb);
            enc.done();

            let mut dec = RangeCoder::new_decoder(&front_bytes(&enc));
            let decoded = dec.decode_icdf(icdf, ftb);
            assert_eq!(decoded, sym, "往返失败: 编码 symbol={sym} 解码得 {decoded}");
        }
    }

    #[test]
    fn test_icdf_decode_terminates() {
        let icdf: &[u8] = &[192, 128, 64, 0];
        let ftb = 8u32;
        let symbols = [0i32, 1, 2, 3];

        let mut enc = RangeCoder::new_encoder(256);
        symbols.iter().for_each(|&s| enc.encode_icdf(s, icdf, ftb));
        enc.done();

        let mut dec = RangeCoder::new_decoder(&front_bytes(&enc));
        for expected in symbols.iter().copied() {
            let got = dec.decode_icdf(icdf, ftb);
            assert_eq!(got, expected, "解码器输出 {got},期望 {expected}");
        }
    }

    #[test]
    fn test_bits_only() {
        // (value, width) pairs written as raw bits.
        let fields: [(u32, u32); 4] = [(1, 1), (5, 3), (7, 3), (0, 2)];

        let mut enc = RangeCoder::new_encoder(1024);
        for &(value, width) in &fields {
            enc.enc_bits(value, width);
        }
        let data = enc.finish();

        let mut dec = RangeCoder::new_decoder(&data);
        let decoded: Vec<u32> = fields
            .iter()
            .map(|&(_, width)| dec.dec_bits(width))
            .collect();
        assert_eq!(decoded, [1, 5, 7, 0]);
    }

    #[test]
    fn test_interleaved_bits_entropy() {
        let mut enc = RangeCoder::new_encoder(1024);
        enc.enc_bits(1, 1);
        enc.encode(10, 20, 100);
        enc.enc_bits(5, 3);
        enc.encode(50, 60, 100);
        let data = enc.finish();

        let mut dec = RangeCoder::new_decoder(&data);
        let first_bit = dec.dec_bits(1);
        let first_sym = dec.decode(100);
        dec.update(10, 20, 100);
        let second_bits = dec.dec_bits(3);
        let second_sym = dec.decode(100);
        dec.update(50, 60, 100);

        assert_eq!(first_bit, 1);
        assert!((10..20).contains(&first_sym), "d1={} expected in [10, 20)", first_sym);
        assert_eq!(second_bits, 5);
        assert!((50..60).contains(&second_sym), "d2={} expected in [50, 60)", second_sym);
    }
}