//! opus_rs/range_coder.rs — range coder (entropy encoder/decoder) for Opus.

/// Number of bits output/consumed per range-coder symbol (one byte).
pub const EC_SYM_BITS: u32 = 8;
/// Width in bits of the range-coder working state (`val`/`rng`).
pub const EC_CODE_BITS: u32 = 32;
/// Largest value a single output symbol (byte) can hold.
pub const EC_SYM_MAX: u32 = (1 << EC_SYM_BITS) - 1;
/// Shift that aligns the next output byte at the top of `val`.
pub const EC_CODE_SHIFT: u32 = EC_CODE_BITS - EC_SYM_BITS - 1;
/// Top (carry) bit of the range-coder state.
pub const EC_CODE_TOP: u32 = 1 << (EC_CODE_BITS - 1);
/// Renormalization threshold: normalize while `rng <= EC_CODE_BOT`.
pub const EC_CODE_BOT: u32 = EC_CODE_TOP >> EC_SYM_BITS;
/// Number of usable bits in the first (partial) input/output symbol.
pub const EC_CODE_EXTRA: u32 = (EC_CODE_BITS - 2) % EC_SYM_BITS + 1;
/// Fixed-point shift for fractional bit counts (1/8-bit resolution).
pub const BITRES: i32 = 3;
9
/// Reciprocal table for fast unsigned division by small odd divisors.
/// Entry `i` approximates `2^32 / (2*i + 1)`; entry 0 is saturated to
/// `u32::MAX` because `2^32 / 1` itself does not fit in 32 bits (the
/// single correction step in `celt_udiv` absorbs the off-by-one).
/// Covers every odd divisor up to 255, i.e. all divisors `d <= 256`
/// after their power-of-two factor is shifted out.
const SMALL_DIV_TABLE: [u32; 128] = [
    0xFFFFFFFF, 0x55555555, 0x33333333, 0x24924924,
    0x1C71C71C, 0x1745D174, 0x13B13B13, 0x11111111,
    0x0F0F0F0F, 0x0D79435E, 0x0C30C30C, 0x0B21642C,
    0x0A3D70A3, 0x097B425E, 0x08D3DCB0, 0x08421084,
    0x07C1F07C, 0x07507507, 0x06EB3E45, 0x06906906,
    0x063E7063, 0x05F417D0, 0x05B05B05, 0x0572620A,
    0x05397829, 0x05050505, 0x04D4873E, 0x04A7904A,
    0x047DC11F, 0x0456C797, 0x04325C53, 0x04104104,
    0x03F03F03, 0x03D22635, 0x03B5CC0E, 0x039B0AD1,
    0x0381C0E0, 0x0369D036, 0x03531DEC, 0x033D91D2,
    0x0329161F, 0x03159721, 0x03030303, 0x02F14990,
    0x02E05C0B, 0x02D02D02, 0x02C0B02C, 0x02B1DA46,
    0x02A3A0FD, 0x0295FAD4, 0x0288DF0C, 0x027C4597,
    0x02702702, 0x02647C69, 0x02593F69, 0x024E6A17,
    0x0243F6F0, 0x0239E0D5, 0x02302302, 0x0226B902,
    0x021D9EAD, 0x0214D021, 0x020C49BA, 0x02040810,
    0x01FC07F0, 0x01F44659, 0x01ECC07B, 0x01E573AC,
    0x01DE5D6E, 0x01D77B65, 0x01D0CB58, 0x01CA4B30,
    0x01C3F8F0, 0x01BDD2B8, 0x01B7D6C3, 0x01B20364,
    0x01AC5701, 0x01A6D01A, 0x01A16D3F, 0x019C2D14,
    0x01970E4F, 0x01920FB4, 0x018D3018, 0x01886E5F,
    0x0183C977, 0x017F405F, 0x017AD220, 0x01767DCE,
    0x01724287, 0x016E1F76, 0x016A13CD, 0x01661EC6,
    0x01623FA7, 0x015E75BB, 0x015AC056, 0x01571ED3,
    0x01539094, 0x01501501, 0x014CAB88, 0x0149539E,
    0x01460CBC, 0x0142D662, 0x013FB013, 0x013C995A,
    0x013991C2, 0x013698DF, 0x0133AE45, 0x0130D190,
    0x012E025C, 0x012B404A, 0x01288B01, 0x0125E227,
    0x01234567, 0x0120B470, 0x011E2EF3, 0x011BB4A4,
    0x01194538, 0x0116E068, 0x011485F0, 0x0112358E,
    0x010FEF01, 0x010DB20A, 0x010B7E6E, 0x010953F3,
    0x01073260, 0x0105197F, 0x0103091B, 0x01010101,
];
46
/// Inline expansion of the `tell_frac` computation for hot loops.
///
/// `$rc` must expose `nbits_total: i32` and `rng: u32` fields (a
/// `RangeCoder` or a reference to one). Returns the number of bits used so
/// far in 1/8th-bit (BITRES) resolution as an `i32`.
///
/// Fixes: because the macro is `#[macro_export]`ed, its body must not rely
/// on this module's names being in scope at the expansion site — the old
/// body referenced `BITRES` unqualified and failed to expand elsewhere, so
/// the shift amount (3) is now spelled out locally. The correction-table
/// lookup also uses safe indexing: `r` is `rng` normalized to 16 bits, so
/// `r >> 12` is in `8..=15` and the index `b` is always in `0..=7`.
#[macro_export]
macro_rules! tell_frac_inline {
    ($rc:expr) => {{
        // Same correction thresholds as `RangeCoder::tell_frac`.
        static CORRECTION: [u32; 8] = [35733, 38967, 42495, 46340, 50535, 55109, 60097, 65535];
        // BITRES == 3: fractional bit counts use 1/8-bit resolution.
        let nbits = $rc.nbits_total << 3;
        let l = 32 - $rc.rng.leading_zeros() as i32;
        let r = $rc.rng >> (l - 16);
        let b = (r >> 12).wrapping_sub(8);
        // Index is in 0..8 (see macro docs), so plain indexing is in bounds.
        let b = b + (r > CORRECTION[b as usize]) as u32;
        nbits - (l << 3) - b as i32
    }};
}
64
/// Shared encoder/decoder state for the Opus range coder.
///
/// The buffer is filled from both ends: range-coded bytes grow from the
/// front (`offs`), raw bits grow backward from the back (`end_offs`).
#[derive(Clone)]
pub struct RangeCoder {
    /// Backing byte buffer (front: range-coded data, back: raw bits).
    pub buf: Vec<u8>,
    /// Total usable size of `buf` in bytes.
    pub storage: u32,
    /// Number of raw-bit bytes written/read at the back of `buf`.
    pub end_offs: u32,
    /// Bit window buffering raw bits before they reach the buffer.
    pub end_window: u32,
    /// Number of valid bits currently held in `end_window`.
    pub nend_bits: i32,
    /// Total bits produced/consumed so far (starts at a fixed offset:
    /// 33 for the encoder, a derived constant for the decoder).
    pub nbits_total: i32,
    /// Number of range-coded bytes written/read at the front of `buf`.
    pub offs: u32,
    /// Current size of the coding range.
    pub rng: u32,
    /// Encoder: low end of the current range. Decoder: offset between the
    /// high end of the range and the input value.
    pub val: u32,
    /// Encoder: count of buffered 0xFF bytes awaiting carry resolution.
    /// Decoder: scale factor saved by `decode()` for the next `update()`.
    pub ext: u32,
    /// Most recent output byte not yet finalized (-1 when none).
    pub rem: i32,
    /// Nonzero once an overflow or inconsistency has been detected.
    pub error: i32,
}
80
81impl RangeCoder {
82    pub fn new_encoder(size: u32) -> Self {
83        // Avoid zero-initialization: allocate without initializing.
84        // The encoder writes bytes sequentially and tracks position via offs/end_offs.
85        let mut buf = Vec::with_capacity(size as usize);
86        unsafe {
87            buf.set_len(size as usize);
88        }
89        RangeCoder {
90            buf,
91            storage: size,
92            end_offs: 0,
93            end_window: 0,
94            nend_bits: 0,
95            nbits_total: 33,
96            offs: 0,
97            rng: 1 << 31,
98            val: 0,
99            ext: 0,
100            rem: -1,
101            error: 0,
102        }
103    }
104
105    /// Reset encoder state for reuse with a (possibly different) buffer size.
106    /// Reuses the existing allocation if it's large enough, avoiding per-frame heap allocation.
107    #[inline]
108    pub fn reset_for_encode(&mut self, size: u32) {
109        if self.buf.len() < size as usize {
110            self.buf.resize(size as usize, 0);
111        }
112        unsafe {
113            self.buf.set_len(size as usize);
114        }
115        self.storage = size;
116        self.end_offs = 0;
117        self.end_window = 0;
118        self.nend_bits = 0;
119        self.nbits_total = 33;
120        self.offs = 0;
121        self.rng = 1 << 31;
122        self.val = 0;
123        self.ext = 0;
124        self.rem = -1;
125        self.error = 0;
126    }
127
128    pub fn new_decoder(data: &[u8]) -> Self {
129        let storage = data.len() as u32;
130        let buf = data.to_vec();
131        let mut rc = RangeCoder {
132            buf,
133            storage,
134            end_offs: 0,
135            end_window: 0,
136            nend_bits: 0,
137            nbits_total: (EC_CODE_BITS + 1
138                - ((EC_CODE_BITS - EC_CODE_EXTRA) / EC_SYM_BITS) * EC_SYM_BITS)
139                as i32,
140            offs: 0,
141            rng: 1 << EC_CODE_EXTRA,
142            val: 0,
143            ext: 0,
144            rem: 0,
145            error: 0,
146        };
147
148        rc.rem = rc.read_byte() as i32;
149        rc.val = rc.rng.wrapping_sub(1).wrapping_sub(rc.rem as u32 >> (EC_SYM_BITS - EC_CODE_EXTRA));
150
151        rc.normalize_decoder();
152        rc
153    }
154
    #[inline(always)]
    /// Refill `rng`/`val` one byte at a time until the range is normalized
    /// again (`rng > EC_CODE_BOT`).
    fn normalize_decoder(&mut self) {
        // Defensive loop cap: with a corrupt stream `rng` could reach 0 and
        // the shift below would never lift it above EC_CODE_BOT. The guard
        // trades exact reference parity for guaranteed termination and flags
        // the stream as broken via `error`.
        let mut guard = 0u32;
        while self.rng <= EC_CODE_BOT {
            guard += 1;
            if guard > 100 {
                self.error = 1;
                self.rng = EC_CODE_BOT + 1;
                break;
            }
            self.nbits_total += EC_SYM_BITS as i32;
            self.rng <<= EC_SYM_BITS;

            // Combine the previously buffered byte with the newly read one
            // so that EC_CODE_EXTRA bits stay ahead of the stream.
            let sym = self.rem;
            self.rem = self.read_byte() as i32;

            let combined_sym = ((sym << EC_SYM_BITS) | self.rem) >> (EC_SYM_BITS - EC_CODE_EXTRA);
            // `!combined_sym` plays the role of the C `~sym`; the final mask
            // keeps `val` within the 31-bit working range.
            self.val = (self.val << EC_SYM_BITS)
                .wrapping_add(EC_SYM_MAX & !combined_sym as u32)
                & (EC_CODE_TOP - 1);
        }
    }
177
178    fn read_byte(&mut self) -> u8 {
179        if self.offs < self.storage {
180            let b = self.buf[self.offs as usize];
181            self.offs += 1;
182            b
183        } else {
184            0
185        }
186    }
187
188    #[inline(always)]
189    pub fn enc_uint(&mut self, fl: u32, ft: u32) {
190        if ft > (1 << 8) {
191            let mut ft = ft - 1;
192            let s = 32 - ft.leading_zeros() as i32 - 8;
193            self.enc_bits(fl & ((1 << s) - 1), s as u32);
194            let fl = fl >> s;
195            ft >>= s;
196            ft += 1;
197            self.encode(fl, fl.wrapping_add(1), ft);
198        } else if ft > 1 {
199            self.encode(fl, fl.wrapping_add(1), ft);
200        }
201    }
202
203    #[inline(always)]
204    pub fn dec_uint(&mut self, ft: u32) -> u32 {
205        if ft > (1 << 8) {
206            let mut ft = ft - 1;
207            let s = 32 - ft.leading_zeros() as i32 - 8;
208            let r = self.dec_bits(s as u32);
209            ft >>= s;
210            ft += 1;
211            let fs = self.decode(ft);
212            self.update(fs, fs.wrapping_add(1), ft);
213            (fs << s) | r
214        } else if ft > 1 {
215            let fs = self.decode(ft);
216            self.update(fs, fs.wrapping_add(1), ft);
217            fs
218        } else {
219            0
220        }
221    }
222
223    #[inline(always)]
224    pub fn enc_bits(&mut self, val: u32, bits: u32) {
225        if bits == 0 {
226            return;
227        }
228        let mut window = self.end_window;
229        let mut used = self.nend_bits;
230        if (used as u32) + bits > EC_CODE_BITS {
231            while used >= EC_SYM_BITS as i32 {
232                self.write_byte_at_end((window & EC_SYM_MAX) as u8);
233                window >>= EC_SYM_BITS;
234                used -= EC_SYM_BITS as i32;
235            }
236        }
237        window |= (val & ((1 << bits) - 1)) << used;
238        used += bits as i32;
239        self.end_window = window;
240        self.nend_bits = used;
241        self.nbits_total += bits as i32;
242    }
243
244    /// Efficiently pad with zero bits up to a target bit count.
245    /// This is much faster than calling enc_bits(0, 1) in a loop because it
246    /// processes whole bytes at once.
247    pub fn pad_to_bits(&mut self, target_bits: i32) {
248        let remaining = target_bits - self.nbits_total;
249        if remaining <= 0 {
250            return;
251        }
252        let mut remaining = remaining as u32;
253
254        // First, fill the current partial byte
255        let partial = (EC_SYM_BITS as u32 - (self.nend_bits as u32 & (EC_SYM_BITS - 1))) & (EC_SYM_BITS - 1);
256        if partial > 0 && remaining >= partial {
257            self.enc_bits(0, partial.min(remaining));
258            remaining -= partial.min(remaining);
259        }
260
261        // Write full zero bytes using bulk memset
262        let full_bytes = remaining / EC_SYM_BITS;
263        if full_bytes > 0 {
264            let available = self.storage - self.offs - self.end_offs;
265            let write_count = full_bytes.min(available);
266            if write_count > 0 {
267                // Fill from the end of the buffer (backward), which is where end bits go
268                let start = (self.storage - self.end_offs - write_count) as usize;
269                unsafe {
270                    std::ptr::write_bytes(self.buf.as_mut_ptr().add(start), 0, write_count as usize);
271                }
272                self.end_offs += write_count;
273            }
274            if write_count < full_bytes {
275                self.error = 1;
276            }
277            self.nbits_total += (full_bytes * EC_SYM_BITS) as i32;
278            remaining -= full_bytes * EC_SYM_BITS;
279        }
280
281        // Write remaining bits (< 8)
282        if remaining > 0 {
283            self.enc_bits(0, remaining);
284        }
285    }
286
287    pub fn dec_bits(&mut self, bits: u32) -> u32 {
288        if bits == 0 {
289            return 0;
290        }
291        let mut window = self.end_window;
292        let mut used = self.nend_bits;
293        if used < bits as i32 {
294            loop {
295                let byte = if self.end_offs < self.storage {
296                    self.end_offs += 1;
297                    self.buf[(self.storage - self.end_offs) as usize]
298                } else {
299                    0
300                };
301                window |= (byte as u32) << used;
302                used += 8;
303                if used > 32 - 8 {
304                    break;
305                }
306            }
307        }
308        let ret = window & ((1 << bits) - 1);
309        self.end_window = window >> bits;
310        self.nend_bits = used - bits as i32;
311        self.nbits_total += bits as i32;
312        ret
313    }
314
315    /// Compute the fractional number of bits used.
316    /// This is a hot path function - force inlining for performance.
317    #[inline(always)]
318    pub fn tell_frac(&self) -> i32 {
319        // CORRECTION table moved to static to reduce stack pressure
320        static CORRECTION: [u32; 8] = [35733, 38967, 42495, 46340, 50535, 55109, 60097, 65535];
321        let nbits = self.nbits_total << BITRES;
322        let l = 32 - self.rng.leading_zeros() as i32;
323        let r = self.rng >> (l - 16);
324        let b = (r >> 12).wrapping_sub(8);
325        let b = b + (r > CORRECTION[b as usize]) as u32;
326        nbits - (l << 3) - b as i32
327    }
328
329    #[inline(always)]
330    pub fn tell(&self) -> i32 {
331        // Inline tell_frac computation to avoid function call
332        const CORRECTION: [u32; 8] = [35733, 38967, 42495, 46340, 50535, 55109, 60097, 65535];
333        let nbits = self.nbits_total << BITRES;
334        let l = 32 - self.rng.leading_zeros() as i32;
335        let r = self.rng >> (l - 16);
336        let b = (r >> 12).wrapping_sub(8);
337        let b = b + (r > CORRECTION[b as usize]) as u32;
338        let tell_frac = nbits - (l << 3) - b as i32;
339        (tell_frac + 7) >> 3
340    }
341
    /// Fast tell that returns only the raw `nbits_total` counter.
    /// Much cheaper than `tell_frac()` but less precise: the counter starts
    /// at a fixed initialization offset (33 for the encoder) and ignores the
    /// fractional bits still represented by `rng`, so it overestimates the
    /// true position. Use only where an approximate bit count suffices.
    #[inline(always)]
    pub fn tell_fast(&self) -> i32 {
        self.nbits_total
    }
349
350    #[inline(always)]
351    fn write_byte(&mut self, value: u8) {
352        if self.offs + self.end_offs < self.storage {
353            unsafe {
354                *self.buf.get_unchecked_mut(self.offs as usize) = value;
355            }
356            self.offs += 1;
357        } else {
358            self.error = 1;
359        }
360    }
361
362    #[inline(always)]
363    fn carry_out(&mut self, c: i32) {
364        if c != EC_SYM_MAX as i32 {
365            let carry = c >> EC_SYM_BITS;
366            if self.rem >= 0 {
367                self.write_byte((self.rem + carry) as u8);
368            }
369            if self.ext > 0 {
370                let sym = (EC_SYM_MAX as i32 + carry) & EC_SYM_MAX as i32;
371                // SAFETY: ext counts deferred bytes that must be written
372                let ext = self.ext as usize;
373                for _j in 0..ext {
374                    self.write_byte(sym as u8);
375                }
376                self.ext = 0;
377            }
378            self.rem = c & EC_SYM_MAX as i32;
379        } else {
380            self.ext += 1;
381        }
382    }
383
384    /// Fast unsigned division using small divisor table (matches C opus celt_udiv).
385    /// For divisors <= 256, uses a multiply-high approach that's faster than hardware division.
386    /// Algorithm: factor d = 2^t * v where v is odd, then compute
387    /// q ≈ (SMALL_DIV_TABLE[(v-1)/2] * (n >> t)) >> 32 with correction.
388    #[inline(always)]
389    fn celt_udiv(n: u32, d: u32) -> u32 {
390        if d <= 256 {
391            let t = d.trailing_zeros();
392            let v = d >> t;
393            let idx = (v - 1) >> 1;
394            // SAFETY: d <= 256, v = d >> t (odd), idx = (v-1)/2.
395            // Max odd v for d<=256 is 255, idx = 127, which is < SMALL_DIV_TABLE.len()=128.
396            let table_val = unsafe { *SMALL_DIV_TABLE.get_unchecked(idx as usize) };
397            let q = ((table_val as u64 * (n >> t) as u64) >> 32) as u32;
398            return q + (n.wrapping_sub(q.wrapping_mul(d)) >= d) as u32;
399        }
400        n / d
401    }
402
403    /// All arithmetic on `val` uses wrapping_* to match C opus unsigned 32-bit behavior.
404    #[inline(always)]
405    pub fn encode(&mut self, fl: u32, fh: u32, ft: u32) {
406        debug_assert!(ft > 0, "encode: ft must be > 0");
407        let r = Self::celt_udiv(self.rng, ft);
408        if fl > 0 {
409            self.val = self.val.wrapping_add(self.rng.wrapping_sub(r.wrapping_mul(ft.wrapping_sub(fl))));
410            self.rng = r.wrapping_mul(fh.wrapping_sub(fl));
411        } else {
412            self.rng = self.rng.wrapping_sub(r.wrapping_mul(ft.wrapping_sub(fh)));
413        }
414        self.normalize_encoder();
415    }
416
417    #[inline(always)]
418    fn normalize_encoder(&mut self) {
419        while self.rng <= EC_CODE_BOT {
420            // Inline carry_out + write_byte to avoid function call overhead in hot loop
421            let c = (self.val >> EC_CODE_SHIFT) as i32;
422            if c != EC_SYM_MAX as i32 {
423                let carry = c >> EC_SYM_BITS;
424                if self.rem >= 0 {
425                    // Inline write_byte
426                    if self.offs + self.end_offs < self.storage {
427                        unsafe {
428                            *self.buf.get_unchecked_mut(self.offs as usize) =
429                                (self.rem + carry) as u8;
430                        }
431                        self.offs += 1;
432                    } else {
433                        self.error = 1;
434                    }
435                }
436                if self.ext > 0 {
437                    let sym = (EC_SYM_MAX as i32 + carry) & EC_SYM_MAX as i32;
438                    let ext = self.ext as usize;
439                    for _j in 0..ext {
440                        if self.offs + self.end_offs < self.storage {
441                            unsafe {
442                                *self.buf.get_unchecked_mut(self.offs as usize) = sym as u8;
443                            }
444                            self.offs += 1;
445                        } else {
446                            self.error = 1;
447                        }
448                    }
449                    self.ext = 0;
450                }
451                self.rem = c & EC_SYM_MAX as i32;
452            } else {
453                self.ext += 1;
454            }
455            self.val = (self.val << EC_SYM_BITS) & (EC_CODE_TOP - 1);
456            self.rng <<= EC_SYM_BITS;
457            self.nbits_total = self.nbits_total.wrapping_add(EC_SYM_BITS as i32);
458        }
459    }
460
461    #[inline(always)]
462    pub fn encode_bit_logp(&mut self, val: bool, logp: u32) {
463        let s = self.rng >> logp;
464        let r = self.rng.wrapping_sub(s);
465        if val {
466            self.val = self.val.wrapping_add(r);
467            self.rng = s;
468        } else {
469            self.rng = r;
470        }
471        self.normalize_encoder();
472    }
473
474    #[inline(always)]
475    pub fn encode_icdf(&mut self, s: i32, icdf: &[u8], ftb: u32) {
476        let r = self.rng >> ftb;
477        if s > 0 {
478            let val = unsafe { *icdf.get_unchecked((s - 1) as usize) as u32 };
479            self.val = self.val.wrapping_add(self.rng.wrapping_sub(r.wrapping_mul(val)));
480            let lower = unsafe { *icdf.get_unchecked(s as usize) };
481            self.rng = r.wrapping_mul(val.wrapping_sub(lower as u32));
482        } else {
483            let val = unsafe { *icdf.get_unchecked(s as usize) as u32 };
484            self.rng = self.rng.wrapping_sub(r.wrapping_mul(val));
485        }
486        self.normalize_encoder();
487    }
488
489    #[inline(always)]
490    pub fn decode_bit_logp(&mut self, logp: u32) -> bool {
491        let s = self.rng >> logp;
492        let ret = self.val < s;
493        if !ret {
494            self.val = self.val.wrapping_sub(s);
495            self.rng = self.rng.wrapping_sub(s);
496        } else {
497            self.rng = s;
498        }
499        self.normalize_decoder();
500        ret
501    }
502
503    /// Decode a symbol using an inverse CDF table.
504    /// Uses do-while pattern like C opus for better performance.
505    #[inline(always)]
506    pub fn decode_icdf(&mut self, icdf: &[u8], ftb: u32) -> i32 {
507        let mut s = self.rng;
508        let d = self.val;
509        let r = s >> ftb;
510        let mut ret = 0;
511        let mut t;
512
513        // Do-while loop: at least one iteration is guaranteed
514        // This matches C opus behavior and is faster for typical small icdf tables
515        loop {
516            t = s;
517            s = r.wrapping_mul(icdf[ret] as u32);
518            ret += 1;
519            if d >= s {
520                break;
521            }
522        }
523
524        self.val = d.wrapping_sub(s);
525        self.rng = t.wrapping_sub(s);
526        self.normalize_decoder();
527        (ret - 1) as i32
528    }
529
530    #[inline(always)]
531    pub fn decode(&mut self, ft: u32) -> u32 {
532        let r = self.rng / ft;
533        self.ext = r;
534        let s = self.val / r;
535        ft - ft.min(s.wrapping_add(1))
536    }
537
538    #[inline(always)]
539    pub fn update(&mut self, fl: u32, fh: u32, ft: u32) {
540        let s = self.ext.wrapping_mul(ft.wrapping_sub(fh));
541        self.val = self.val.wrapping_sub(s);
542        self.rng = if fl > 0 {
543            self.ext.wrapping_mul(fh.wrapping_sub(fl))
544        } else {
545            self.rng.wrapping_sub(s)
546        };
547        self.normalize_decoder();
548    }
549
    /// Encode `value` under a Laplace-like distribution centered at 0.
    /// `fs` is the scaled (Q15) probability of zero and `decay` the Q15
    /// geometric decay per magnitude step. `value` may be clamped in place
    /// when it falls in the flat minimum-probability tail.
    pub fn laplace_encode(&mut self, value: &mut i32, fs: u32, decay: i32) {
        let mut val = *value;
        let mut fl = 0;
        let mut fs_val = fs;

        if val != 0 {
            // s = -1 for negative values, 0 otherwise; (x + s) ^ s == |x|.
            let s = if val < 0 { -1 } else { 0 };
            val = (val + s) ^ s;
            fl = fs_val;
            fs_val = self.laplace_get_freq1(fs_val, decay);

            // Walk the geometrically decaying part of the PDF; frequencies
            // double because each band covers both signs.
            let mut i = 1;
            while fs_val > 0 && i < val {
                fs_val *= 2;
                fl += fs_val + 2;
                fs_val = ((fs_val as i32 * decay) >> 15) as u32;
                i += 1;
            }

            if fs_val == 0 {
                // Flat tail: every remaining magnitude gets the minimum
                // probability; clamp and write the encodable value back.
                let ndi_max = 32768 - fl + 1 - 1;
                let ndi_max = (ndi_max as i32 - s) >> 1;
                let di = (val - i).min(ndi_max - 1);
                fl += (2 * di + 1 + s) as u32;
                fs_val = 1u32.min(32768 - fl);
                *value = (i + di + s) ^ s;
            } else {
                fs_val += 1;
                // `!s as u32` is all-ones when s == 0 and zero when s == -1,
                // i.e. only positive values skip past the negative bucket.
                fl += fs_val & (!s as u32);
            }
        }
        self.encode(fl, fl.wrapping_add(fs_val), 1 << 15);
    }
583
584    fn laplace_get_freq1(&self, fs0: u32, decay: i32) -> u32 {
585        let ft = 32768 - (2 * 16) - fs0;
586        ((ft as i32 * (16384 - decay)) >> 15) as u32
587    }
588
    /// Decode a value produced by `laplace_encode`; must be called with the
    /// same `fs` and `decay` parameters.
    pub fn laplace_decode(&mut self, fs: u32, decay: i32) -> i32 {
        // Cumulative frequency of the received symbol within the Q15 total.
        let fm = self.decode(1 << 15);
        let mut fl = 0;
        let mut fs_val = fs;
        let mut val = 0;

        if fm >= fs_val {
            val += 1;
            fl = fs_val;
            fs_val = self.laplace_get_freq1(fs_val, decay) + 1;

            // Walk the decaying part of the PDF (mirrors the encoder loop).
            while fs_val > 1 && fm >= fl + 2 * fs_val {
                fs_val *= 2;
                fl += fs_val;
                fs_val = (((fs_val as i32 - 2) * decay) >> 15) as u32 + 1;
                val += 1;
            }

            // Flat tail: each further magnitude occupies two slots (+/-).
            if fs_val <= 1 {
                let di = (fm - fl) >> 1;
                val += di as i32;
                fl += 2 * di;
            }

            // The lower half of the band carries the negative sign.
            if fm < fl + fs_val {
                val = -val;
            } else {
                fl += fs_val;
            }
        }

        // Clamp the high bound to the Q15 total (min(fl + fs, 32768)).
        self.update(fl, fl.wrapping_add(fs_val.min(32768 - fl)), 1 << 15);
        val
    }
623
624    #[inline(always)]
625    fn write_byte_at_end(&mut self, value: u8) {
626        if self.offs + self.end_offs < self.storage {
627            self.end_offs += 1;
628            let idx = (self.storage - self.end_offs) as usize;
629            unsafe {
630                *self.buf.get_unchecked_mut(idx) = value;
631            }
632        } else {
633            self.error = 1;
634        }
635    }
636
637    pub fn patch_initial_bits(&mut self, val: u32, nbits: u32) {
638        let shift = EC_SYM_BITS - nbits;
639        let mask = ((1u32 << nbits) - 1) << shift;
640        if self.offs > 0 {
641            self.buf[0] = ((self.buf[0] as u32 & !mask) | (val << shift)) as u8;
642        } else if self.rem >= 0 {
643            self.rem = ((self.rem as u32 & !mask) | (val << shift)) as i32;
644        } else if self.rng <= (EC_CODE_TOP >> nbits) {
645            let mask_shifted = mask << EC_CODE_SHIFT;
646            self.val = (self.val & !mask_shifted) | (val << (EC_CODE_SHIFT + shift));
647        } else {
648            self.error = -1;
649        }
650    }
651
652    pub fn done(&mut self) {
653        let ilog = 32 - self.rng.leading_zeros();
654        let mut l = (EC_CODE_BITS - ilog) as i32;
655        let mut msk = (EC_CODE_TOP - 1) >> l;
656        let mut end = (self.val.wrapping_add(msk)) & !msk;
657
658        if (end | msk) >= self.val.wrapping_add(self.rng) {
659            l += 1;
660            msk >>= 1;
661            end = (self.val.wrapping_add(msk)) & !msk;
662        }
663
664        while l > 0 {
665            self.carry_out((end >> EC_CODE_SHIFT) as i32);
666            end = (end << EC_SYM_BITS) & (EC_CODE_TOP - 1);
667            l -= EC_SYM_BITS as i32;
668        }
669
670        if self.rem >= 0 || self.ext > 0 {
671            self.carry_out(0);
672        }
673
674        let mut window = self.end_window;
675        let mut used = self.nend_bits;
676        while used >= EC_SYM_BITS as i32 {
677            self.write_byte_at_end((window & EC_SYM_MAX) as u8);
678            window >>= EC_SYM_BITS;
679            used -= EC_SYM_BITS as i32;
680        }
681
682        if self.error == 0 {
683            for i in self.offs..(self.storage - self.end_offs) {
684                self.buf[i as usize] = 0;
685            }
686
687            if used > 0 {
688                if self.end_offs >= self.storage {
689                    self.error = -1;
690                } else {
691                    let idx = (self.storage - self.end_offs - 1) as usize;
692                    self.buf[idx] |= window as u8;
693
694                    self.end_offs += 1;
695                }
696            }
697        }
698    }
699
700    pub fn finish(&mut self) -> Vec<u8> {
701        self.done();
702
703        let mut result = Vec::with_capacity((self.offs + self.end_offs) as usize);
704        result.extend_from_slice(&self.buf[0..self.offs as usize]);
705        result.extend_from_slice(
706            &self.buf[(self.storage - self.end_offs) as usize..self.storage as usize],
707        );
708        result
709    }
710}
711
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_laplace() {
        let mut enc = RangeCoder::new_encoder(100);
        let mut val = -3;
        let fs = 100 << 7;
        let decay = 120 << 6;
        enc.laplace_encode(&mut val, fs, decay);
        enc.done();

        assert_eq!(enc.offs, 1);
        assert_eq!(enc.buf[0], 224);

        let mut dec = RangeCoder::new_decoder(&enc.buf[..enc.offs as usize]);
        let decoded_val = dec.laplace_decode(fs, decay);
        assert_eq!(decoded_val, -3);
    }

    #[test]
    fn test_icdf_consistency() {
        let mut enc = RangeCoder::new_encoder(1024);
        let icdf = [2, 1, 0];
        enc.encode_icdf(0, &icdf, 2);
        enc.encode_icdf(1, &icdf, 2);
        enc.encode_icdf(2, &icdf, 2);
        enc.done();
        let data = enc.buf[..enc.offs as usize].to_vec();

        let mut dec = RangeCoder::new_decoder(&data);
        let s0 = dec.decode_icdf(&icdf, 2);
        let s1 = dec.decode_icdf(&icdf, 2);
        let s2 = dec.decode_icdf(&icdf, 2);

        assert_eq!(s0, 0);
        assert_eq!(s1, 1);
        assert_eq!(s2, 2);
    }

    /// Regression: encoding the last symbol (s == icdf.len() - 1) used to
    /// perform an out-of-bounds icdf access; after the fix it must
    /// round-trip to the last symbol index correctly.
    #[test]
    fn test_icdf_last_symbol_no_oob() {
        // ftb = 8 → total frequency 256.
        // Three symbols, each with frequency ~85, none zero.
        // icdf semantics: icdf[i] = total - cumulative freq of the first
        // i + 1 symbols:
        // symbol 0: 256 - 86 = 170  → icdf[0] = 170
        // symbol 1: 170 - 85 = 85   → icdf[1] = 85
        // symbol 2: 85  - 85 = 0    → icdf[2] = 0  (last entry must be 0)
        let icdf: &[u8] = &[170, 85, 0];
        let ftb = 8u32;

        // For each symbol: encode → done → decode; verify no panic and a
        // correct round trip.
        for sym in 0..3i32 {
            let mut enc = RangeCoder::new_encoder(256);
            enc.encode_icdf(sym, icdf, ftb); // sym == 2 used to hit the OOB access
            enc.done();
            let data = enc.buf[..enc.offs as usize].to_vec();

            let mut dec = RangeCoder::new_decoder(&data);
            let decoded = dec.decode_icdf(icdf, ftb);
            assert_eq!(decoded, sym, "往返失败: 编码 symbol={sym} 解码得 {decoded}");
        }
    }

    /// decode_icdf would loop forever / index out of bounds on a table whose
    /// last entry is not 0; verify the decoder terminates cleanly with a
    /// standard (zero-terminated) table.
    #[test]
    fn test_icdf_decode_terminates() {
        // Realistic Opus-style ICDF table (must end in 0).
        // ftb = 8, total frequency 256; four equiprobable symbols of 64 each.
        let icdf: &[u8] = &[192, 128, 64, 0];
        let ftb = 8u32;

        let symbols = [0i32, 1, 2, 3];
        let mut enc = RangeCoder::new_encoder(256);
        for &s in &symbols {
            enc.encode_icdf(s, icdf, ftb);
        }
        enc.done();
        let data = enc.buf[..enc.offs as usize].to_vec();

        let mut dec = RangeCoder::new_decoder(&data);
        for &expected in &symbols {
            let got = dec.decode_icdf(icdf, ftb);
            assert_eq!(got, expected, "解码器输出 {got},期望 {expected}");
        }
    }

    #[test]
    fn test_bits_only() {
        let mut enc = RangeCoder::new_encoder(1024);

        enc.enc_bits(1, 1);
        enc.enc_bits(5, 3);
        enc.enc_bits(7, 3);
        enc.enc_bits(0, 2);

        let data = enc.finish();
        let mut dec = RangeCoder::new_decoder(&data);

        let b1 = dec.dec_bits(1);
        let b2 = dec.dec_bits(3);
        let b3 = dec.dec_bits(3);
        let b4 = dec.dec_bits(2);

        assert_eq!(b1, 1);
        assert_eq!(b2, 5);
        assert_eq!(b3, 7);
        assert_eq!(b4, 0);
    }

    #[test]
    fn test_interleaved_bits_entropy() {
        let mut enc = RangeCoder::new_encoder(1024);

        enc.enc_bits(1, 1);

        enc.encode(10, 20, 100);

        enc.enc_bits(5, 3);

        enc.encode(50, 60, 100);

        let data = enc.finish();

        let mut dec = RangeCoder::new_decoder(&data);

        let b1 = dec.dec_bits(1);
        let d1 = dec.decode(100);
        dec.update(10, 20, 100);
        let b2 = dec.dec_bits(3);
        let d2 = dec.decode(100);
        dec.update(50, 60, 100);

        assert_eq!(b1, 1);
        assert!((10..20).contains(&d1), "d1={} expected in [10, 20)", d1);
        assert_eq!(b2, 5);
        assert!((50..60).contains(&d2), "d2={} expected in [50, 60)", d2);
    }
}