// opus_rs/range_coder.rs — range encoder/decoder used by the Opus port.

/// Number of bits emitted per output symbol (one byte).
pub const EC_SYM_BITS: u32 = 8;
/// Width in bits of the coder's `rng`/`val` registers.
pub const EC_CODE_BITS: u32 = 32;
/// Largest value a single output symbol can hold (`0xFF`).
pub const EC_SYM_MAX: u32 = !(u32::MAX << EC_SYM_BITS);
/// Right shift that extracts the next output symbol plus its carry bit from `val`.
pub const EC_CODE_SHIFT: u32 = EC_CODE_BITS - (EC_SYM_BITS + 1);
/// Top bit of the code register; `val` is kept below this value.
pub const EC_CODE_TOP: u32 = 1u32 << (EC_CODE_BITS - 1);
/// Renormalization threshold: while `rng` is at or below this, a symbol is shifted.
pub const EC_CODE_BOT: u32 = EC_CODE_TOP >> EC_SYM_BITS;
/// Bits of the first input byte consumed when a decoder is initialized.
pub const EC_CODE_EXTRA: u32 = 1 + (EC_CODE_BITS - 2) % EC_SYM_BITS;
/// Fractional precision of `tell_frac` (results are in 1/8-bit units).
pub const BITRES: i32 = 3;

/// Reciprocal table for small odd divisors:
/// `SMALL_DIV_TABLE[i] == floor(2^32 / (2*i + 1))`, saturated to `u32::MAX` at `i == 0`.
///
/// `celt_udiv` indexes it with `(odd_part_of_d - 1) / 2` to replace a division by
/// `d <= 256` with a multiply-high plus a single correction step. Generating the
/// entries at compile time keeps them verifiably consistent with that formula.
const SMALL_DIV_TABLE: [u32; 128] = {
    let mut table = [0u32; 128];
    let mut i = 0;
    while i < 128 {
        let reciprocal = (1u64 << 32) / (2 * i as u64 + 1);
        table[i] = if reciprocal > u32::MAX as u64 {
            u32::MAX
        } else {
            reciprocal as u32
        };
        i += 1;
    }
    table
};

/// Expression form of `RangeCoder::tell_frac`: evaluates to the number of bits
/// consumed so far in 1/8-bit units, for any `$rc` exposing `nbits_total: i32`
/// and `rng: u32` fields.
///
/// `BITRES` is referenced unqualified, so it must be in scope at the call site.
#[macro_export]
macro_rules! tell_frac_inline {
    ($rc:expr) => {{
        // Thresholds that refine the integer log2 of `rng` by one extra step
        // within each 1/8-bit bucket (same table as `RangeCoder::tell_frac`).
        static CORRECTION: [u32; 8] = [35733, 38967, 42495, 46340, 50535, 55109, 60097, 65535];
        let scaled_total = $rc.nbits_total << BITRES;
        let rng_bits = 32 - $rc.rng.leading_zeros() as i32;
        // Top 16 significant bits of `rng`, in [2^15, 2^16).
        let top16 = $rc.rng >> (rng_bits - 16);
        let mut frac = (top16 >> 12).wrapping_sub(8);
        // SAFETY: for rng >= 2^15 (the coder keeps rng > EC_CODE_BOT after every
        // normalization), `top16 >> 12` is in 8..=15, so `frac` is in 0..=7.
        let threshold = unsafe { *CORRECTION.get_unchecked(frac as usize) };
        if top16 > threshold {
            frac += 1;
        }
        scaled_total - ((rng_bits << 3) + frac as i32)
    }};
}

44#[derive(Clone)]
45pub struct RangeCoder {
46    pub buf: Vec<u8>,
47    pub storage: u32,
48    pub end_offs: u32,
49    pub end_window: u32,
50    pub nend_bits: i32,
51    pub nbits_total: i32,
52    pub offs: u32,
53    pub rng: u32,
54    pub val: u32,
55    pub ext: u32,
56    pub rem: i32,
57    pub error: i32,
58}
59
60impl RangeCoder {
61    pub fn new_encoder(size: u32) -> Self {
62        let buf = vec![0u8; size as usize];
63        RangeCoder {
64            buf,
65            storage: size,
66            end_offs: 0,
67            end_window: 0,
68            nend_bits: 0,
69            nbits_total: 33,
70            offs: 0,
71            rng: 1 << 31,
72            val: 0,
73            ext: 0,
74            rem: -1,
75            error: 0,
76        }
77    }
78
79    #[inline]
80    pub fn reset_for_encode(&mut self, size: u32) {
81        if self.buf.len() < size as usize {
82            self.buf.resize(size as usize, 0);
83        }
84        unsafe {
85            self.buf.set_len(size as usize);
86        }
87        self.storage = size;
88        self.end_offs = 0;
89        self.end_window = 0;
90        self.nend_bits = 0;
91        self.nbits_total = 33;
92        self.offs = 0;
93        self.rng = 1 << 31;
94        self.val = 0;
95        self.ext = 0;
96        self.rem = -1;
97        self.error = 0;
98    }
99
100    pub fn new_decoder(data: &[u8]) -> Self {
101        let storage = data.len() as u32;
102        let buf = data.to_vec();
103        let mut rc = RangeCoder {
104            buf,
105            storage,
106            end_offs: 0,
107            end_window: 0,
108            nend_bits: 0,
109            nbits_total: (EC_CODE_BITS + 1
110                - ((EC_CODE_BITS - EC_CODE_EXTRA) / EC_SYM_BITS) * EC_SYM_BITS)
111                as i32,
112            offs: 0,
113            rng: 1 << EC_CODE_EXTRA,
114            val: 0,
115            ext: 0,
116            rem: 0,
117            error: 0,
118        };
119
120        rc.rem = rc.read_byte() as i32;
121        rc.val = rc
122            .rng
123            .wrapping_sub(1)
124            .wrapping_sub(rc.rem as u32 >> (EC_SYM_BITS - EC_CODE_EXTRA));
125
126        rc.normalize_decoder();
127        rc
128    }
129
130    #[inline(always)]
131    fn normalize_decoder(&mut self) {
132        let mut guard = 0u32;
133        while self.rng <= EC_CODE_BOT {
134            guard += 1;
135            if guard > 100 {
136                self.error = 1;
137                self.rng = EC_CODE_BOT + 1;
138                break;
139            }
140            self.nbits_total += EC_SYM_BITS as i32;
141            self.rng <<= EC_SYM_BITS;
142
143            let sym = self.rem;
144            self.rem = self.read_byte() as i32;
145
146            let combined_sym = ((sym << EC_SYM_BITS) | self.rem) >> (EC_SYM_BITS - EC_CODE_EXTRA);
147            self.val = (self.val << EC_SYM_BITS).wrapping_add(EC_SYM_MAX & !combined_sym as u32)
148                & (EC_CODE_TOP - 1);
149        }
150    }
151
152    fn read_byte(&mut self) -> u8 {
153        if self.offs < self.storage {
154            let b = self.buf[self.offs as usize];
155            self.offs += 1;
156            b
157        } else {
158            0
159        }
160    }
161
162    #[inline(always)]
163    pub fn enc_uint(&mut self, fl: u32, ft: u32) {
164        if ft > 1 {
165            let ft_minus_1 = ft - 1;
166            let ftb = 32 - ft_minus_1.leading_zeros() as i32;
167            if ftb > 8 {
168                let s = ftb - 8;
169                self.enc_bits(fl & ((1 << s) - 1), s as u32);
170                let fl = fl >> s;
171                let ft = (ft_minus_1 >> s) + 1;
172                self.encode(fl, fl.wrapping_add(1), ft);
173            } else {
174                self.encode(fl, fl.wrapping_add(1), ft);
175            }
176        }
177    }
178
179    #[inline(always)]
180    pub fn dec_uint(&mut self, ft: u32) -> u32 {
181        if ft > 1 {
182            let ft_minus_1 = ft - 1;
183            let ftb = 32 - ft_minus_1.leading_zeros() as i32;
184            if ftb > 8 {
185                let s = ftb - 8;
186                let r = self.dec_bits(s as u32);
187                let ft = (ft_minus_1 >> s) + 1;
188                let fs = self.decode(ft);
189                self.update(fs, fs.wrapping_add(1), ft);
190                (fs << s) | r
191            } else {
192                let fs = self.decode(ft);
193                self.update(fs, fs.wrapping_add(1), ft);
194                fs
195            }
196        } else {
197            0
198        }
199    }
200
201    #[inline(always)]
202    pub fn enc_bits(&mut self, val: u32, bits: u32) {
203        if bits == 0 {
204            return;
205        }
206        let mut window = self.end_window;
207        let mut used = self.nend_bits;
208        if (used as u32) + bits > EC_CODE_BITS {
209            while used >= EC_SYM_BITS as i32 {
210                self.write_byte_at_end((window & EC_SYM_MAX) as u8);
211                window >>= EC_SYM_BITS;
212                used -= EC_SYM_BITS as i32;
213            }
214        }
215        window |= (val & ((1 << bits) - 1)) << used;
216        used += bits as i32;
217        self.end_window = window;
218        self.nend_bits = used;
219        self.nbits_total += bits as i32;
220    }
221
222    pub fn pad_to_bits(&mut self, target_bits: i32) {
223        let remaining = target_bits - self.nbits_total;
224        if remaining <= 0 {
225            return;
226        }
227        let mut remaining = remaining as u32;
228
229        let partial =
230            (EC_SYM_BITS - (self.nend_bits as u32 & (EC_SYM_BITS - 1))) & (EC_SYM_BITS - 1);
231        if partial > 0 && remaining >= partial {
232            self.enc_bits(0, partial.min(remaining));
233            remaining -= partial.min(remaining);
234        }
235
236        let full_bytes = remaining / EC_SYM_BITS;
237        if full_bytes > 0 {
238            let available = self.storage - self.offs - self.end_offs;
239            let write_count = full_bytes.min(available);
240            if write_count > 0 {
241                let start = (self.storage - self.end_offs - write_count) as usize;
242                unsafe {
243                    std::ptr::write_bytes(
244                        self.buf.as_mut_ptr().add(start),
245                        0,
246                        write_count as usize,
247                    );
248                }
249                self.end_offs += write_count;
250            }
251            if write_count < full_bytes {
252                self.error = 1;
253            }
254            self.nbits_total += (full_bytes * EC_SYM_BITS) as i32;
255            remaining -= full_bytes * EC_SYM_BITS;
256        }
257
258        if remaining > 0 {
259            self.enc_bits(0, remaining);
260        }
261    }
262
263    pub fn dec_bits(&mut self, bits: u32) -> u32 {
264        if bits == 0 {
265            return 0;
266        }
267        let mut window = self.end_window;
268        let mut used = self.nend_bits;
269        if used < bits as i32 {
270            loop {
271                let byte = if self.end_offs < self.storage {
272                    self.end_offs += 1;
273                    self.buf[(self.storage - self.end_offs) as usize]
274                } else {
275                    0
276                };
277                window |= (byte as u32) << used;
278                used += 8;
279                if used > 32 - 8 {
280                    break;
281                }
282            }
283        }
284        let ret = window & ((1 << bits) - 1);
285        self.end_window = window >> bits;
286        self.nend_bits = used - bits as i32;
287        self.nbits_total += bits as i32;
288        ret
289    }
290
291    #[inline(always)]
292    pub fn tell_frac(&self) -> i32 {
293        static CORRECTION: [u32; 8] = [35733, 38967, 42495, 46340, 50535, 55109, 60097, 65535];
294        let nbits = self.nbits_total << BITRES;
295        let l = 32 - self.rng.leading_zeros() as i32;
296        let r = self.rng >> (l - 16);
297        let b = (r >> 12).wrapping_sub(8);
298        let b = b + (r > CORRECTION[b as usize]) as u32;
299        nbits - (l << 3) - b as i32
300    }
301
302    /// Integer bit count matching C's `ec_tell()`: `nbits_total - EC_ILOG(rng)`.
303    #[inline(always)]
304    pub fn tell(&self) -> i32 {
305        self.nbits_total - (32 - self.rng.leading_zeros() as i32)
306    }
307
308    #[inline(always)]
309    pub fn tell_fast(&self) -> i32 {
310        self.nbits_total
311    }
312
313    /// Shrink the range coder buffer, moving end-coded bytes to the new end.
314    /// Equivalent to C's `ec_enc_shrink`.
315    pub fn shrink(&mut self, new_size: u32) {
316        debug_assert!(self.offs + self.end_offs <= new_size);
317        if self.end_offs > 0 {
318            let old_end_start = (self.storage - self.end_offs) as usize;
319            let old_end_end = self.storage as usize;
320            let new_end_start = (new_size - self.end_offs) as usize;
321            self.buf
322                .copy_within(old_end_start..old_end_end, new_end_start);
323        }
324        self.storage = new_size;
325    }
326
327    #[inline(always)]
328    fn write_byte(&mut self, value: u8) {
329        if self.offs + self.end_offs < self.storage {
330            unsafe {
331                *self.buf.get_unchecked_mut(self.offs as usize) = value;
332            }
333            self.offs += 1;
334        } else {
335            self.error = 1;
336        }
337    }
338
339    #[inline(always)]
340    fn carry_out(&mut self, c: i32) {
341        if c != EC_SYM_MAX as i32 {
342            let carry = c >> EC_SYM_BITS;
343            if self.rem >= 0 {
344                self.write_byte((self.rem + carry) as u8);
345            }
346            if self.ext > 0 {
347                let sym = (EC_SYM_MAX as i32 + carry) & EC_SYM_MAX as i32;
348
349                let ext = self.ext as usize;
350                for _j in 0..ext {
351                    self.write_byte(sym as u8);
352                }
353                self.ext = 0;
354            }
355            self.rem = c & EC_SYM_MAX as i32;
356        } else {
357            self.ext += 1;
358        }
359    }
360
361    #[inline(always)]
362    fn celt_udiv(n: u32, d: u32) -> u32 {
363        if d <= 256 {
364            let t = d.trailing_zeros();
365            let v = d >> t;
366            let idx = (v - 1) >> 1;
367
368            let table_val = unsafe { *SMALL_DIV_TABLE.get_unchecked(idx as usize) };
369            let q = ((table_val as u64 * (n >> t) as u64) >> 32) as u32;
370            return q + (n.wrapping_sub(q.wrapping_mul(d)) >= d) as u32;
371        }
372        n / d
373    }
374
375    #[inline(always)]
376    pub fn encode(&mut self, fl: u32, fh: u32, ft: u32) {
377        debug_assert!(ft > 0, "encode: ft must be > 0");
378        let r = Self::celt_udiv(self.rng, ft);
379        if fl > 0 {
380            self.val = self
381                .val
382                .wrapping_add(self.rng.wrapping_sub(r.wrapping_mul(ft.wrapping_sub(fl))));
383            self.rng = r.wrapping_mul(fh.wrapping_sub(fl));
384        } else {
385            self.rng = self.rng.wrapping_sub(r.wrapping_mul(ft.wrapping_sub(fh)));
386        }
387        self.normalize_encoder();
388    }
389
390    #[inline(always)]
391    fn normalize_encoder(&mut self) {
392        while self.rng <= EC_CODE_BOT {
393            let c = (self.val >> EC_CODE_SHIFT) as i32;
394            if c != EC_SYM_MAX as i32 {
395                let carry = c >> EC_SYM_BITS;
396                if self.rem >= 0 {
397                    if self.offs + self.end_offs < self.storage {
398                        unsafe {
399                            *self.buf.get_unchecked_mut(self.offs as usize) =
400                                (self.rem + carry) as u8;
401                        }
402                        self.offs += 1;
403                    } else {
404                        self.error = 1;
405                    }
406                }
407                if self.ext > 0 {
408                    let sym = (EC_SYM_MAX as i32 + carry) & EC_SYM_MAX as i32;
409                    let ext = self.ext as usize;
410                    for _j in 0..ext {
411                        if self.offs + self.end_offs < self.storage {
412                            unsafe {
413                                *self.buf.get_unchecked_mut(self.offs as usize) = sym as u8;
414                            }
415                            self.offs += 1;
416                        } else {
417                            self.error = 1;
418                        }
419                    }
420                    self.ext = 0;
421                }
422                self.rem = c & EC_SYM_MAX as i32;
423            } else {
424                self.ext += 1;
425            }
426            self.val = (self.val << EC_SYM_BITS) & (EC_CODE_TOP - 1);
427            self.rng <<= EC_SYM_BITS;
428            self.nbits_total = self.nbits_total.wrapping_add(EC_SYM_BITS as i32);
429        }
430    }
431
432    #[inline(always)]
433    pub fn encode_bit_logp(&mut self, val: bool, logp: u32) {
434        let s = self.rng >> logp;
435        let r = self.rng.wrapping_sub(s);
436        if val {
437            self.val = self.val.wrapping_add(r);
438            self.rng = s;
439        } else {
440            self.rng = r;
441        }
442        self.normalize_encoder();
443    }
444
445    #[inline(always)]
446    pub fn encode_icdf(&mut self, s: i32, icdf: &[u8], ftb: u32) {
447        let r = self.rng >> ftb;
448        if s > 0 {
449            let val = unsafe { *icdf.get_unchecked((s - 1) as usize) as u32 };
450            self.val = self
451                .val
452                .wrapping_add(self.rng.wrapping_sub(r.wrapping_mul(val)));
453            let lower = unsafe { *icdf.get_unchecked(s as usize) };
454            self.rng = r.wrapping_mul(val.wrapping_sub(lower as u32));
455        } else {
456            let val = unsafe { *icdf.get_unchecked(s as usize) as u32 };
457            self.rng = self.rng.wrapping_sub(r.wrapping_mul(val));
458        }
459        self.normalize_encoder();
460    }
461
462    #[inline(always)]
463    pub fn decode_bit_logp(&mut self, logp: u32) -> bool {
464        let s = self.rng >> logp;
465        let ret = self.val < s;
466        if !ret {
467            self.val = self.val.wrapping_sub(s);
468            self.rng = self.rng.wrapping_sub(s);
469        } else {
470            self.rng = s;
471        }
472        self.normalize_decoder();
473        ret
474    }
475
476    #[inline(always)]
477    pub fn decode_icdf(&mut self, icdf: &[u8], ftb: u32) -> i32 {
478        let mut s = self.rng;
479        let d = self.val;
480        let r = s >> ftb;
481        let mut ret = 0;
482        let mut t;
483
484        loop {
485            t = s;
486            s = r.wrapping_mul(icdf[ret] as u32);
487            ret += 1;
488            if d >= s {
489                break;
490            }
491        }
492
493        self.val = d.wrapping_sub(s);
494        self.rng = t.wrapping_sub(s);
495        self.normalize_decoder();
496        (ret - 1) as i32
497    }
498
499    #[inline(always)]
500    pub fn decode(&mut self, ft: u32) -> u32 {
501        let r = self.rng / ft;
502        self.ext = r;
503        let s = self.val / r;
504        ft - ft.min(s.wrapping_add(1))
505    }
506
507    #[inline(always)]
508    pub fn update(&mut self, fl: u32, fh: u32, ft: u32) {
509        let s = self.ext.wrapping_mul(ft.wrapping_sub(fh));
510        self.val = self.val.wrapping_sub(s);
511        self.rng = if fl > 0 {
512            self.ext.wrapping_mul(fh.wrapping_sub(fl))
513        } else {
514            self.rng.wrapping_sub(s)
515        };
516        self.normalize_decoder();
517    }
518
519    pub fn laplace_encode(&mut self, value: &mut i32, fs: u32, decay: i32) {
520        let mut val = *value;
521        let mut fl = 0;
522        let mut fs_val = fs;
523
524        if val != 0 {
525            let s = if val < 0 { -1 } else { 0 };
526            val = (val + s) ^ s;
527            fl = fs_val;
528            fs_val = self.laplace_get_freq1(fs_val, decay);
529
530            let mut i = 1;
531            while fs_val > 0 && i < val {
532                fs_val *= 2;
533                fl += fs_val + 2;
534                fs_val = ((fs_val as i32 * decay) >> 15) as u32;
535                i += 1;
536            }
537
538            if fs_val == 0 {
539                let ndi_max = 32768 - fl + 1 - 1;
540                let ndi_max = (ndi_max as i32 - s) >> 1;
541                let di = (val - i).min(ndi_max - 1);
542                fl += (2 * di + 1 + s) as u32;
543                fs_val = 1u32.min(32768 - fl);
544                *value = (i + di + s) ^ s;
545            } else {
546                fs_val += 1;
547                fl += fs_val & (!s as u32);
548            }
549        }
550        self.encode(fl, fl.wrapping_add(fs_val), 1 << 15);
551    }
552
553    fn laplace_get_freq1(&self, fs0: u32, decay: i32) -> u32 {
554        let ft = 32768 - (2 * 16) - fs0;
555        ((ft as i32 * (16384 - decay)) >> 15) as u32
556    }
557
558    pub fn laplace_decode(&mut self, fs: u32, decay: i32) -> i32 {
559        let fm = self.decode(1 << 15);
560        let mut fl = 0;
561        let mut fs_val = fs;
562        let mut val = 0;
563
564        if fm >= fs_val {
565            val += 1;
566            fl = fs_val;
567            fs_val = self.laplace_get_freq1(fs_val, decay) + 1;
568
569            while fs_val > 1 && fm >= fl + 2 * fs_val {
570                fs_val *= 2;
571                fl += fs_val;
572                fs_val = (((fs_val as i32 - 2) * decay) >> 15) as u32 + 1;
573                val += 1;
574            }
575
576            if fs_val <= 1 {
577                let di = (fm - fl) >> 1;
578                val += di as i32;
579                fl += 2 * di;
580            }
581
582            if fm < fl + fs_val {
583                val = -val;
584            } else {
585                fl += fs_val;
586            }
587        }
588
589        self.update(fl, fl.wrapping_add(fs_val.min(32768 - fl)), 1 << 15);
590        val
591    }
592
593    #[inline(always)]
594    fn write_byte_at_end(&mut self, value: u8) {
595        if self.offs + self.end_offs < self.storage {
596            self.end_offs += 1;
597            let idx = (self.storage - self.end_offs) as usize;
598            unsafe {
599                *self.buf.get_unchecked_mut(idx) = value;
600            }
601        } else {
602            self.error = 1;
603        }
604    }
605
606    pub fn patch_initial_bits(&mut self, val: u32, nbits: u32) {
607        let shift = EC_SYM_BITS - nbits;
608        let mask = ((1u32 << nbits) - 1) << shift;
609        if self.offs > 0 {
610            self.buf[0] = ((self.buf[0] as u32 & !mask) | (val << shift)) as u8;
611        } else if self.rem >= 0 {
612            self.rem = ((self.rem as u32 & !mask) | (val << shift)) as i32;
613        } else if self.rng <= (EC_CODE_TOP >> nbits) {
614            let mask_shifted = mask << EC_CODE_SHIFT;
615            self.val = (self.val & !mask_shifted) | (val << (EC_CODE_SHIFT + shift));
616        } else {
617            self.error = -1;
618        }
619    }
620
621    pub fn done(&mut self) {
622        let ilog = 32 - self.rng.leading_zeros();
623        let mut l = (EC_CODE_BITS - ilog) as i32;
624        let mut msk = (EC_CODE_TOP - 1) >> l;
625        let mut end = (self.val.wrapping_add(msk)) & !msk;
626
627        if (end | msk) >= self.val.wrapping_add(self.rng) {
628            l += 1;
629            msk >>= 1;
630            end = (self.val.wrapping_add(msk)) & !msk;
631        }
632
633        while l > 0 {
634            self.carry_out((end >> EC_CODE_SHIFT) as i32);
635            end = (end << EC_SYM_BITS) & (EC_CODE_TOP - 1);
636            l -= EC_SYM_BITS as i32;
637        }
638
639        if self.rem >= 0 || self.ext > 0 {
640            self.carry_out(0);
641        }
642
643        let mut window = self.end_window;
644        let mut used = self.nend_bits;
645        while used >= EC_SYM_BITS as i32 {
646            self.write_byte_at_end((window & EC_SYM_MAX) as u8);
647            window >>= EC_SYM_BITS;
648            used -= EC_SYM_BITS as i32;
649        }
650
651        if self.error == 0 {
652            for i in self.offs..(self.storage - self.end_offs) {
653                self.buf[i as usize] = 0;
654            }
655
656            if used > 0 {
657                if self.end_offs >= self.storage {
658                    self.error = -1;
659                } else {
660                    // If we've busted, don't add too many extra bits to the
661                    // last byte; it would corrupt the range coder data.
662                    if self.offs + self.end_offs >= self.storage && (-l as i32) < used {
663                        window &= (1u32 << (-l as u32)) - 1;
664                        self.error = -1;
665                    }
666                    let idx = (self.storage - self.end_offs - 1) as usize;
667                    self.buf[idx] |= window as u8;
668                }
669            }
670        }
671    }
672
673    pub fn finish(&mut self) -> Vec<u8> {
674        self.done();
675
676        let extra_end = if self.nend_bits > 0 && self.end_offs == 0 {
677            1
678        } else {
679            0
680        };
681        let mut result = Vec::with_capacity((self.offs + self.end_offs + extra_end) as usize);
682        result.extend_from_slice(&self.buf[0..self.offs as usize]);
683        if extra_end > 0 {
684            result.push(self.buf[(self.storage - 1) as usize]);
685        }
686        result.extend_from_slice(
687            &self.buf[(self.storage - self.end_offs) as usize..self.storage as usize],
688        );
689        result
690    }
691}
692
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `symbols` with `icdf`/`ftb` into a `size`-byte encoder, finalizes
    /// it, and returns only the range-coded front bytes.
    fn encode_icdf_symbols(symbols: &[i32], icdf: &[u8], ftb: u32, size: u32) -> Vec<u8> {
        let mut enc = RangeCoder::new_encoder(size);
        for &sym in symbols {
            enc.encode_icdf(sym, icdf, ftb);
        }
        enc.done();
        enc.buf[..enc.offs as usize].to_vec()
    }

    #[test]
    fn test_laplace() {
        let fs = 100 << 7;
        let decay = 120 << 6;
        let mut value = -3;

        let mut enc = RangeCoder::new_encoder(100);
        enc.laplace_encode(&mut value, fs, decay);
        enc.done();

        assert_eq!(enc.offs, 1);
        assert_eq!(enc.buf[0], 224);

        let written = &enc.buf[..enc.offs as usize];
        let mut dec = RangeCoder::new_decoder(written);
        assert_eq!(dec.laplace_decode(fs, decay), -3);
    }

    #[test]
    fn test_icdf_consistency() {
        let icdf: &[u8] = &[2, 1, 0];
        let data = encode_icdf_symbols(&[0, 1, 2], icdf, 2, 1024);

        let mut dec = RangeCoder::new_decoder(&data);
        let decoded: Vec<i32> = (0..3).map(|_| dec.decode_icdf(icdf, 2)).collect();
        assert_eq!(decoded, [0, 1, 2]);
    }

    #[test]
    fn test_icdf_last_symbol_no_oob() {
        let icdf: &[u8] = &[170, 85, 0];
        let ftb = 8u32;

        for sym in 0..3i32 {
            let data = encode_icdf_symbols(&[sym], icdf, ftb, 256);
            let decoded = RangeCoder::new_decoder(&data).decode_icdf(icdf, ftb);
            assert_eq!(decoded, sym, "往返失败: 编码 symbol={sym} 解码得 {decoded}");
        }
    }

    #[test]
    fn test_icdf_decode_terminates() {
        let icdf: &[u8] = &[192, 128, 64, 0];
        let ftb = 8u32;
        let symbols = [0i32, 1, 2, 3];

        let data = encode_icdf_symbols(&symbols, icdf, ftb, 256);
        let mut dec = RangeCoder::new_decoder(&data);
        for &expected in &symbols {
            let got = dec.decode_icdf(icdf, ftb);
            assert_eq!(got, expected, "解码器输出 {got},期望 {expected}");
        }
    }

    #[test]
    fn test_bits_only() {
        // (value, width) pairs written and read back in order.
        let fields: [(u32, u32); 4] = [(1, 1), (5, 3), (7, 3), (0, 2)];

        let mut enc = RangeCoder::new_encoder(1024);
        for &(value, width) in &fields {
            enc.enc_bits(value, width);
        }
        let data = enc.finish();

        let mut dec = RangeCoder::new_decoder(&data);
        for &(value, width) in &fields {
            assert_eq!(dec.dec_bits(width), value);
        }
    }

    #[test]
    fn test_interleaved_bits_entropy() {
        let mut enc = RangeCoder::new_encoder(1024);
        enc.enc_bits(1, 1);
        enc.encode(10, 20, 100);
        enc.enc_bits(5, 3);
        enc.encode(50, 60, 100);
        let data = enc.finish();

        let mut dec = RangeCoder::new_decoder(&data);
        let first_bits = dec.dec_bits(1);
        let d1 = dec.decode(100);
        dec.update(10, 20, 100);
        let second_bits = dec.dec_bits(3);
        let d2 = dec.decode(100);
        dec.update(50, 60, 100);

        assert_eq!(first_bits, 1);
        assert!((10..20).contains(&d1), "d1={} expected in [10, 20)", d1);
        assert_eq!(second_bits, 5);
        assert!((50..60).contains(&d2), "d2={} expected in [50, 60)", d2);
    }
}
818        assert!((10..20).contains(&d1), "d1={} expected in [10, 20)", d1);
819        assert_eq!(b2, 5);
820        assert!((50..60).contains(&d2), "d2={} expected in [50, 60)", d2);
821    }
822}