Skip to main content

opus_rs/
range_coder.rs

/// Number of bits in one output symbol; the coder emits whole bytes.
pub const EC_SYM_BITS: u32 = 8;
/// Width, in bits, of the coder's `rng`/`val` state registers.
pub const EC_CODE_BITS: u32 = 32;
/// Largest value of a single output symbol (`0xFF`).
pub const EC_SYM_MAX: u32 = (1 << EC_SYM_BITS) - 1;
/// Right shift that brings the top output symbol of `val`, together with its carry bit, down
/// to the low-order bits.
pub const EC_CODE_SHIFT: u32 = EC_CODE_BITS - EC_SYM_BITS - 1;
/// `2^31`: the initial encoder range; `val` is kept below this value after every shift.
pub const EC_CODE_TOP: u32 = 1 << (EC_CODE_BITS - 1);
/// Renormalization threshold: a symbol is shifted out/in while `rng <= EC_CODE_BOT`.
pub const EC_CODE_BOT: u32 = EC_CODE_TOP >> EC_SYM_BITS;
/// Number of high-order bits of the first input byte that seed the decoder state.
pub const EC_CODE_EXTRA: u32 = (EC_CODE_BITS - 2) % EC_SYM_BITS + 1;
/// Fractional precision of [`RangeCoder::tell_frac`]: results are in `1/2^BITRES` (1/8) bits.
pub const BITRES: i32 = 3;

/// Range encoder/decoder state, modelled on the C Opus range coder.
///
/// One type serves both directions: build it with [`RangeCoder::new_encoder`] or
/// [`RangeCoder::new_decoder`]. Range-coded data occupies the front of `buf`, while raw bits
/// ([`RangeCoder::enc_bits`] / [`RangeCoder::dec_bits`]) are packed from the back.
#[derive(Clone, Debug)]
pub struct RangeCoder {
    /// Backing byte buffer: the output when encoding, a copy of the input when decoding.
    pub buf: Vec<u8>,
    /// Usable size of `buf`, in bytes.
    pub storage: u32,
    /// Number of raw-bit bytes written to / read from the end of `buf`.
    pub end_offs: u32,
    /// Accumulator for raw bits not yet flushed to (encoder) or consumed from (decoder) `buf`.
    pub end_window: u32,
    /// Number of valid bits currently held in `end_window`.
    pub nend_bits: i32,
    /// Running bit count used by [`RangeCoder::tell`] / [`RangeCoder::tell_frac`].
    pub nbits_total: i32,
    /// Number of range-coded bytes written to / read from the front of `buf`.
    pub offs: u32,
    /// Size of the current coding interval.
    pub rng: u32,
    /// Encoder: low end of the current interval.
    /// Decoder: the input code value expressed relative to the current interval.
    pub val: u64,
    /// Encoder: number of buffered `0xFF` bytes waiting for carry resolution.
    /// Decoder: the `rng / ft` scale saved by [`RangeCoder::decode`] for [`RangeCoder::update`].
    pub ext: u32,
    /// Encoder: byte awaiting possible carry propagation (`-1` = none yet).
    /// Decoder: the most recently read input byte (one byte of look-ahead).
    pub rem: i32,
    /// Non-zero once an error (buffer overflow, collapsed range, corrupt data) has occurred.
    pub error: i32,
}

impl RangeCoder {
    /// Creates an encoder that writes into a zero-filled buffer of `size` bytes.
    ///
    /// Call [`finish`](Self::finish) (or [`done`](Self::done)) once every symbol is encoded.
    pub fn new_encoder(size: u32) -> Self {
        RangeCoder {
            buf: vec![0; size as usize],
            storage: size,
            end_offs: 0,
            end_window: 0,
            nend_bits: 0,
            // EC_CODE_BITS + 1: a freshly created encoder reports `tell() == 1`.
            nbits_total: 33,
            offs: 0,
            rng: 1 << 31,
            val: 0,
            ext: 0,
            // No byte is buffered for carry propagation yet.
            rem: -1,
            error: 0,
        }
    }

    /// Creates a decoder over a copy of `data` and primes its state from the first bytes.
    pub fn new_decoder(data: &[u8]) -> Self {
        let storage = data.len() as u32;
        let buf = data.to_vec();
        let mut rc = RangeCoder {
            buf,
            storage,
            end_offs: 0,
            end_window: 0,
            nend_bits: 0,
            nbits_total: (EC_CODE_BITS + 1
                - ((EC_CODE_BITS - EC_CODE_EXTRA) / EC_SYM_BITS) * EC_SYM_BITS)
                as i32,
            offs: 0,
            rng: 1 << EC_CODE_EXTRA,
            val: 0,
            ext: 0,
            rem: 0,
            error: 0,
        };

        // Only the top EC_CODE_EXTRA bits of the first byte enter the initial state; the rest
        // are picked up through `rem` during the first renormalization.
        rc.rem = rc.read_byte() as i32;
        rc.val = (rc.rng - 1 - (rc.rem as u32 >> (EC_SYM_BITS - EC_CODE_EXTRA))) as u64;

        rc.normalize_decoder();
        rc
    }

    /// Shifts input bytes into the decoder state until `rng` exceeds `EC_CODE_BOT`.
    ///
    /// A zero range can only be the result of a degenerate update; it is reported through
    /// `error` and `rng` is reset to 1 so the loop terminates.
    fn normalize_decoder(&mut self) {
        while self.rng <= EC_CODE_BOT {
            self.nbits_total += EC_SYM_BITS as i32;
            self.rng <<= EC_SYM_BITS;
            if self.rng == 0 {
                debug_assert!(
                    false,
                    "normalize_decoder: rng=0 after shift, corrupt bitstream"
                );
                self.error = 1;
                self.rng = 1;
                return;
            }

            // The bytes straddle the symbol boundary by one bit (EC_CODE_EXTRA = 7), so the
            // previous look-ahead byte and the new one are combined before use.
            let sym = self.rem;
            self.rem = self.read_byte() as i32;

            let combined_sym = ((sym << EC_SYM_BITS) | self.rem) >> (EC_SYM_BITS - EC_CODE_EXTRA);
            self.val = ((self.val << EC_SYM_BITS) + (EC_SYM_MAX & !combined_sym as u32) as u64)
                & (EC_CODE_TOP as u64 - 1);
        }
    }

    /// Returns the next byte from the front of the buffer, or 0 once the input is exhausted.
    fn read_byte(&mut self) -> u8 {
        if self.offs < self.storage {
            let b = self.buf[self.offs as usize];
            self.offs += 1;
            b
        } else {
            0
        }
    }

    /// Encodes `fl`, an integer in `0..ft`, with a uniform distribution.
    ///
    /// When `ft > 256` the low-order bits of `fl` are sent as raw bits and only the remaining
    /// high-order part (at most 256 values) is range coded. `ft <= 1` encodes nothing.
    pub fn enc_uint(&mut self, fl: u32, ft: u32) {
        if ft > (1 << 8) {
            let mut ft = ft - 1;
            let s = 32 - ft.leading_zeros() as i32 - 8;
            self.enc_bits(fl & ((1 << s) - 1), s as u32);
            let fl = fl >> s;
            ft >>= s;
            ft += 1;
            self.encode(fl, fl + 1, ft);
        } else if ft > 1 {
            self.encode(fl, fl + 1, ft);
        }
    }

    /// Decodes an integer in `0..ft` written by [`enc_uint`](Self::enc_uint).
    ///
    /// If a corrupt stream yields a value of `ft` or more, `error` is set to 1 and `ft - 1` is
    /// returned instead, so callers can always use the result as an index below `ft`.
    /// Returns 0 for `ft <= 1`.
    pub fn dec_uint(&mut self, ft: u32) -> u32 {
        if ft > (1 << 8) {
            let max = ft - 1;
            let s = 32 - max.leading_zeros() - 8;
            let low = self.dec_bits(s);
            let top_ft = (max >> s) + 1;
            let fs = self.decode(top_ft);
            self.update(fs, fs + 1, top_ft);
            let t = (fs << s) | low;
            if t <= max {
                t
            } else {
                // The top part can encode values up to 2^s - 1 past `max`; those are invalid.
                self.error = 1;
                max
            }
        } else if ft > 1 {
            let fs = self.decode(ft);
            self.update(fs, fs + 1, ft);
            fs
        } else {
            0
        }
    }

    /// Appends the low `bits` bits of `val` to the raw-bit stream at the end of the buffer.
    ///
    /// Whole bytes are flushed only when the 32-bit window would overflow, and up to 7 bits
    /// may remain after a flush, so keep `bits <= 25`.
    pub fn enc_bits(&mut self, val: u32, bits: u32) {
        if bits == 0 {
            return;
        }
        let mut window = self.end_window;
        let mut used = self.nend_bits;
        if (used as u32) + bits > EC_CODE_BITS {
            while used >= EC_SYM_BITS as i32 {
                self.write_byte_at_end((window & EC_SYM_MAX) as u8);
                window >>= EC_SYM_BITS;
                used -= EC_SYM_BITS as i32;
            }
        }
        window |= (val & ((1 << bits) - 1)) << used;
        used += bits as i32;
        self.end_window = window;
        self.nend_bits = used;
        self.nbits_total += bits as i32;
    }

    /// Reads `bits` raw bits written by [`enc_bits`](Self::enc_bits), refilling the window
    /// from the end of the buffer (zeros once it is exhausted).
    pub fn dec_bits(&mut self, bits: u32) -> u32 {
        if bits == 0 {
            return 0;
        }
        let mut window = self.end_window;
        let mut used = self.nend_bits;
        if used < bits as i32 {
            loop {
                let byte = if self.end_offs < self.storage {
                    self.end_offs += 1;
                    self.buf[(self.storage - self.end_offs) as usize]
                } else {
                    0
                };
                window |= (byte as u32) << used;
                used += 8;
                if used > 32 - 8 {
                    break;
                }
            }
        }
        let ret = window & ((1 << bits) - 1);
        self.end_window = window >> bits;
        self.nend_bits = used - bits as i32;
        self.nbits_total += bits as i32;
        ret
    }

    /// Returns the number of bits written (or read) so far, in `1/2^BITRES` bit units.
    ///
    /// Works for any `rng`, including the collapsed `rng == 1` state left behind by the
    /// error paths; that state previously caused a negative shift amount here.
    pub fn tell_frac(&self) -> i32 {
        static CORRECTION: [u32; 8] = [35733, 38967, 42495, 46340, 50535, 55109, 60097, 65535];
        let nbits = self.nbits_total << BITRES;
        // Guard against a zero range so the 16-bit normalization below is always defined.
        let rng = self.rng.max(1);
        let mut l = 32 - rng.leading_zeros() as i32;
        // Normalize rng to exactly 16 significant bits.
        let r = if l >= 16 {
            rng >> (l - 16)
        } else {
            rng << (16 - l)
        };
        let mut b = (r >> 12) - 8;
        if b < 7 && r > CORRECTION[b as usize] {
            b += 1;
        }
        l = (l << 3) + b as i32;
        nbits - l
    }

    /// Returns the number of whole bits written (or read) so far, rounded up.
    pub fn tell(&self) -> i32 {
        (self.tell_frac() + 7) >> 3
    }

    /// Appends a range-coder byte at the front; sets `error` if it would hit the raw-bit area.
    fn write_byte(&mut self, value: u8) {
        if self.offs + self.end_offs < self.storage {
            self.buf[self.offs as usize] = value;
            self.offs += 1;
        } else {
            self.error = 1;
        }
    }

    /// Handles one outgoing symbol `c` (a byte plus a possible carry in bit 8).
    ///
    /// A `0xFF` symbol is only counted (`ext`), because a later carry could still turn it into
    /// `0x00`. Any other symbol resolves the carry for the buffered byte and the pending run.
    fn carry_out(&mut self, c: i32) {
        if c != EC_SYM_MAX as i32 {
            let carry = c >> EC_SYM_BITS;
            if self.rem >= 0 {
                self.write_byte((self.rem + carry) as u8);
            }
            if self.ext > 0 {
                let sym = (EC_SYM_MAX as i32 + carry) & EC_SYM_MAX as i32;
                for _ in 0..self.ext {
                    self.write_byte(sym as u8);
                }
                self.ext = 0;
            }
            self.rem = c & EC_SYM_MAX as i32;
        } else {
            self.ext += 1;
        }
    }

    /// Encodes a symbol occupying the cumulative-frequency interval `[fl, fh)` out of `ft`.
    /// `ft == 0` is ignored.
    pub fn encode(&mut self, fl: u32, fh: u32, ft: u32) {
        if ft == 0 {
            return;
        }
        let r = self.rng / ft;
        if fl > 0 {
            self.val += (self.rng - r * (ft - fl)) as u64;
            self.rng = r * (fh - fl);
        } else {
            // The first symbol absorbs the rounding remainder of `rng / ft`.
            self.rng -= r * (ft - fh);
        }
        self.normalize_encoder();
    }

    /// Emits bytes until `rng` exceeds `EC_CODE_BOT`. A zero range is reported through
    /// `error` and replaced by 1.
    fn normalize_encoder(&mut self) {
        if self.rng == 0 {
            self.error = 1;
            self.rng = 1;
            return;
        }
        while self.rng <= EC_CODE_BOT {
            self.carry_out((self.val >> EC_CODE_SHIFT) as i32);
            self.val = (self.val << EC_SYM_BITS) & (EC_CODE_TOP as u64 - 1);
            self.rng <<= EC_SYM_BITS;
            self.nbits_total = self.nbits_total.wrapping_add(EC_SYM_BITS as i32);
        }
    }

    /// Encodes a boolean whose probability of being `true` is `1 / 2^logp`.
    pub fn encode_bit_logp(&mut self, val: bool, logp: u32) {
        let s = self.rng >> logp;
        let r = self.rng - s;
        if val {
            self.val += r as u64;
            self.rng = s;
        } else {
            self.rng = r;
        }
        self.normalize_encoder();
    }

    /// Encodes symbol `s` using an inverse CDF table with a total frequency of `2^ftb`.
    ///
    /// `icdf[i]` is `2^ftb` minus the cumulative frequency of symbols `0..=i`. A missing final
    /// entry is treated as 0, so the last symbol may be encoded even without a trailing 0.
    pub fn encode_icdf(&mut self, s: i32, icdf: &[u8], ftb: u32) {
        let r = self.rng >> ftb;
        if s > 0 {
            let val = icdf[(s - 1) as usize] as u32;
            self.val += (self.rng as u64) - (r as u64 * val as u64);
            // The last symbol uses an implicit lower bound of 0.
            let lower = icdf.get(s as usize).copied().unwrap_or(0) as u32;
            let diff = val - lower;
            debug_assert!(
                diff > 0,
                "encode_icdf: zero-probability symbol s={s}, icdf={icdf:?}, ftb={ftb} \
                 (icdf[{prev}]={val} == icdf[{s}]={lower})",
                prev = s - 1,
            );
            self.rng = r * diff;
        } else {
            let val = icdf[s as usize] as u32;
            let full = 1u32 << ftb;
            debug_assert!(
                val < full,
                "encode_icdf: zero-probability symbol s=0, icdf={icdf:?}, ftb={ftb} \
                 (icdf[0]={val} == 2^ftb={full}, symbol has zero probability)"
            );
            self.rng -= r * val;
        }
        self.normalize_encoder();
    }

    /// Decodes a boolean written by [`encode_bit_logp`](Self::encode_bit_logp) with the same
    /// `logp`.
    pub fn decode_bit_logp(&mut self, logp: u32) -> bool {
        let s = self.rng >> logp;
        let ret = self.val < s as u64;
        if !ret {
            self.val -= s as u64;
            self.rng -= s;
        } else {
            self.rng = s;
        }
        self.normalize_decoder();
        ret
    }

    /// Decodes a symbol using an inverse CDF table with a total frequency of `2^ftb`.
    ///
    /// It mirrors [`encode_icdf`](Self::encode_icdf): an entry past the end of `icdf` is
    /// treated as 0. Tables without a trailing 0 can therefore be decoded without indexing
    /// out of bounds. Uses a do-while scan like C Opus: at least one entry is always examined.
    pub fn decode_icdf(&mut self, icdf: &[u8], ftb: u32) -> i32 {
        let d = self.val as u32;
        let r = self.rng >> ftb;
        // `hi` is the upper edge of the candidate symbol's interval, `lo` its lower edge.
        let mut hi = self.rng;
        let mut sym = 0usize;
        let lo = loop {
            let lo = r * icdf.get(sym).copied().unwrap_or(0) as u32;
            if d >= lo {
                break lo;
            }
            hi = lo;
            sym += 1;
        };

        self.val = (d - lo) as u64;
        self.rng = hi - lo;
        self.normalize_decoder();
        sym as i32
    }

    /// Returns the cumulative frequency in `0..ft` that the current code value falls into.
    ///
    /// Must be followed by [`update`](Self::update) with the interval that contains it.
    /// Panics (division by zero) if `ft` is 0 or larger than the current range.
    pub fn decode(&mut self, ft: u32) -> u32 {
        let r = self.rng / ft;
        self.ext = r;
        let s = (self.val / r as u64) as u32;
        ft - ft.min(s + 1)
    }

    /// Consumes the symbol occupying `[fl, fh)` out of `ft`, using the scale saved by
    /// [`decode`](Self::decode).
    pub fn update(&mut self, fl: u32, fh: u32, ft: u32) {
        let s = self.ext * (ft - fh);
        self.val -= s as u64;
        self.rng = if fl > 0 {
            self.ext * (fh - fl)
        } else {
            self.rng - s
        };

        self.normalize_decoder();
    }

    /// Encodes `*value` with a two-sided, geometrically decaying distribution over 2^15.
    ///
    /// `fs` is the frequency of 0 and `decay` the Q15 ratio between successive magnitudes.
    /// Once the modelled frequencies run out, the tail uses frequency 1. If `*value` lies
    /// beyond what the remaining space can represent, it is clamped in place to the value
    /// that was actually encoded.
    pub fn laplace_encode(&mut self, value: &mut i32, fs: u32, decay: i32) {
        let mut val = *value;
        let mut fl = 0;
        let mut fs_val = fs;

        if val != 0 {
            // s is -1 for negative values, 0 otherwise; (val + s) ^ s is |val|.
            let s = if val < 0 { -1 } else { 0 };
            val = (val + s) ^ s;
            fl = fs_val;
            fs_val = self.laplace_get_freq1(fs_val, decay);

            // Walk the decaying part of the distribution (both signs per magnitude).
            let mut i = 1;
            while fs_val > 0 && i < val {
                fs_val *= 2;
                fl += fs_val + 2;
                fs_val = ((fs_val as i32 * decay) >> 15) as u32;
                i += 1;
            }

            if fs_val == 0 {
                // Flat tail: every remaining magnitude/sign pair has frequency 1.
                let ndi_max = 32768 - fl;
                let ndi_max = (ndi_max as i32 - s) >> 1;
                let di = (val - i).min(ndi_max - 1);
                fl += (2 * di + 1 + s) as u32;
                fs_val = 1u32.min(32768 - fl);
                *value = (i + di + s) ^ s;
            } else {
                fs_val += 1;
                // Positive values sit above their negative counterpart.
                fl += fs_val & (!s as u32);
            }
        }
        self.encode(fl, fl + fs_val, 1 << 15);
    }

    /// Frequency assigned to magnitude 1 (per sign, before the +1 minimum), given the
    /// frequency `fs0` of 0.
    fn laplace_get_freq1(&self, fs0: u32, decay: i32) -> u32 {
        let ft = 32768 - (2 * 16) - fs0;
        ((ft as i32 * (16384 - decay)) >> 15) as u32
    }

    /// Decodes a value written by [`laplace_encode`](Self::laplace_encode) with the same
    /// `fs` and `decay`.
    pub fn laplace_decode(&mut self, fs: u32, decay: i32) -> i32 {
        let fm = self.decode(1 << 15);
        let mut fl = 0;
        let mut fs_val = fs;
        let mut val = 0;

        if fm >= fs_val {
            val += 1;
            fl = fs_val;
            fs_val = self.laplace_get_freq1(fs_val, decay) + 1;

            while fs_val > 1 && fm >= fl + 2 * fs_val {
                fs_val *= 2;
                fl += fs_val;
                fs_val = (((fs_val as i32 - 2) * decay) >> 15) as u32 + 1;
                val += 1;
            }

            if fs_val <= 1 {
                // Flat tail: jump straight to the right magnitude.
                let di = (fm - fl) >> 1;
                val += di as i32;
                fl += 2 * di;
            }

            if fm < fl + fs_val {
                val = -val;
            } else {
                fl += fs_val;
            }
        }

        self.update(fl, fl + fs_val.min(32768 - fl), 1 << 15);
        val
    }

    /// Prepends a raw-bit byte at the end of the buffer; sets `error` if it would hit the
    /// range-coder area.
    fn write_byte_at_end(&mut self, value: u8) {
        if self.offs + self.end_offs < self.storage {
            self.end_offs += 1;
            let idx = (self.storage - self.end_offs) as usize;
            self.buf[idx] = value;
        } else {
            self.error = 1;
        }
    }

    /// Overwrites the `nbits` (at most 8) most significant bits of the first output byte with
    /// `val`, wherever that byte currently lives.
    ///
    /// The byte may already be written, may still be buffered in `rem`, or may not yet have
    /// left `val`. Sets `error = -1` if fewer than `nbits` bits have been encoded so far.
    pub fn patch_initial_bits(&mut self, val: u32, nbits: u32) {
        let shift = EC_SYM_BITS - nbits;
        let mask = ((1u32 << nbits) - 1) << shift;
        if self.offs > 0 {
            self.buf[0] = ((self.buf[0] as u32 & !mask) | (val << shift)) as u8;
        } else if self.rem >= 0 {
            self.rem = ((self.rem as u32 & !mask) | (val << shift)) as i32;
        } else if self.rng <= (EC_CODE_TOP >> nbits) {
            let mask64 = (mask as u64) << EC_CODE_SHIFT;
            self.val = (self.val & !mask64) | ((val as u64) << (EC_CODE_SHIFT + shift));
        } else {
            self.error = -1;
        }
    }

    /// Finalizes the encoder.
    ///
    /// Steps, in order:
    /// 1. Emits the fewest range-coder bits that identify the final interval, whatever bits
    ///    follow them.
    /// 2. Flushes buffered whole bytes of raw bits to the end of the buffer.
    /// 3. Zeroes the unused gap between the two regions.
    /// 4. Stores any leftover (< 8) raw bits in one more trailing byte.
    ///
    /// If the buffer is already full, the leftover raw bits share the final range-coder byte
    /// and may only use its unused low-order bits. Bits that do not fit are dropped and
    /// `error` is set to `-1`. Previously this case corrupted that byte and made
    /// [`finish`](Self::finish) return more than `storage` bytes.
    pub fn done(&mut self) {
        let ilog = 32 - self.rng.leading_zeros();
        let mut l = (EC_CODE_BITS - ilog) as i32;
        let mut msk = (EC_CODE_TOP as u64 - 1) >> l;
        let mut end = (self.val + msk) & !msk;

        if (end | msk) >= self.val + self.rng as u64 {
            l += 1;
            msk >>= 1;
            end = (self.val + msk) & !msk;
        }

        while l > 0 {
            self.carry_out((end >> EC_CODE_SHIFT) as i32);
            end = (end << EC_SYM_BITS) & (EC_CODE_TOP as u64 - 1);
            l -= EC_SYM_BITS as i32;
        }
        // Now `-l` (0..=7) low-order bits of the last emitted range-coder byte are unused.

        // Flush the byte still waiting for carry propagation, plus any pending 0xFF run.
        if self.rem >= 0 || self.ext > 0 {
            self.carry_out(0);
        }

        let mut window = self.end_window;
        let mut used = self.nend_bits;
        while used >= EC_SYM_BITS as i32 {
            self.write_byte_at_end((window & EC_SYM_MAX) as u8);
            window >>= EC_SYM_BITS;
            used -= EC_SYM_BITS as i32;
        }

        if self.error != 0 {
            return;
        }

        let gap_start = self.offs as usize;
        let gap_end = (self.storage - self.end_offs) as usize;
        if let Some(gap) = self.buf.get_mut(gap_start..gap_end) {
            gap.fill(0);
        }

        if used > 0 {
            if self.end_offs >= self.storage {
                // No range-coder data at all: nowhere to put the partial byte.
                self.error = -1;
            } else {
                let idx = (self.storage - self.end_offs - 1) as usize;
                let buffer_full = self.offs + self.end_offs >= self.storage;
                if buffer_full {
                    // `idx` is the final range-coder byte; only its spare low bits are free.
                    let spare = -l;
                    if spare < used {
                        window &= (1u32 << spare as u32) - 1;
                        self.error = -1;
                    }
                }
                self.buf[idx] |= window as u8;
                if !buffer_full {
                    // A fresh trailing byte was used; count it so `finish` includes it.
                    self.end_offs += 1;
                }
            }
        }
    }

    /// Finalizes the encoder and returns the packed output.
    ///
    /// The output is the range-coder bytes followed directly by the raw-bit bytes, with the
    /// unused gap between them removed.
    pub fn finish(&mut self) -> Vec<u8> {
        self.done();

        let head = &self.buf[..self.offs as usize];
        let tail = &self.buf[(self.storage - self.end_offs) as usize..self.storage as usize];
        let mut out = Vec::with_capacity(head.len() + tail.len());
        out.extend_from_slice(head);
        out.extend_from_slice(tail);
        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes one Laplace-distributed value, pins the exact encoder output, then decodes it.
    #[test]
    fn test_laplace() {
        let fs = 100 << 7;
        let decay = 120 << 6;
        let mut value = -3;

        let mut enc = RangeCoder::new_encoder(100);
        enc.laplace_encode(&mut value, fs, decay);
        enc.done();

        assert_eq!(enc.offs, 1);
        assert_eq!(enc.buf[0], 224);

        let encoded = &enc.buf[..enc.offs as usize];
        let mut dec = RangeCoder::new_decoder(encoded);
        assert_eq!(dec.laplace_decode(fs, decay), -3);
    }

    /// Encodes symbols 0, 1, 2 from a small zero-terminated ICDF table and decodes them back.
    #[test]
    fn test_icdf_consistency() {
        let icdf: [u8; 3] = [2, 1, 0];

        let mut enc = RangeCoder::new_encoder(1024);
        for sym in 0..3 {
            enc.encode_icdf(sym, &icdf, 2);
        }
        enc.done();
        let data = enc.buf[..enc.offs as usize].to_vec();

        let mut dec = RangeCoder::new_decoder(&data);
        let decoded: Vec<i32> = (0..3).map(|_| dec.decode_icdf(&icdf, 2)).collect();
        assert_eq!(decoded, [0, 1, 2]);
    }

    /// Each symbol of a three-symbol table must survive an encode/decode round trip on its
    /// own. This includes the last symbol, whose ICDF entry is the terminating 0 (an index
    /// that once went out of bounds in `encode_icdf`).
    #[test]
    fn test_icdf_last_symbol_no_oob() {
        // ftb = 8, so the total frequency is 256; each symbol gets roughly 85 and none is 0.
        // icdf[i] holds 256 minus the cumulative frequency of symbols 0..=i:
        //   symbol 0: 256 - 86 = 170
        //   symbol 1: 170 - 85 = 85
        //   symbol 2:  85 - 85 = 0   (the table must end in 0)
        let icdf: &[u8] = &[170, 85, 0];
        let ftb = 8u32;

        for sym in 0..3i32 {
            let mut enc = RangeCoder::new_encoder(256);
            enc.encode_icdf(sym, icdf, ftb);
            enc.done();
            let data = enc.buf[..enc.offs as usize].to_vec();

            let mut dec = RangeCoder::new_decoder(&data);
            let decoded = dec.decode_icdf(icdf, ftb);
            assert_eq!(decoded, sym, "往返失败: 编码 symbol={sym} 解码得 {decoded}");
        }
    }

    /// Decodes a run of symbols from a standard zero-terminated table and checks that the
    /// decoder stops on the right symbol every time.
    #[test]
    fn test_icdf_decode_terminates() {
        // Opus-style table: ftb = 8 (total 256), four equiprobable symbols of 64 each.
        let icdf: &[u8] = &[192, 128, 64, 0];
        let ftb = 8u32;
        let symbols = [0i32, 1, 2, 3];

        let mut enc = RangeCoder::new_encoder(256);
        symbols.iter().for_each(|&s| enc.encode_icdf(s, icdf, ftb));
        enc.done();
        let data = enc.buf[..enc.offs as usize].to_vec();

        let mut dec = RangeCoder::new_decoder(&data);
        for &expected in &symbols {
            let got = dec.decode_icdf(icdf, ftb);
            assert_eq!(got, expected, "解码器输出 {got},期望 {expected}");
        }
    }

    /// Raw bits written with `enc_bits` come back unchanged and in order from `dec_bits`.
    #[test]
    fn test_bits_only() {
        let fields: [(u32, u32); 4] = [(1, 1), (5, 3), (7, 3), (0, 2)];

        let mut enc = RangeCoder::new_encoder(1024);
        for &(value, width) in &fields {
            enc.enc_bits(value, width);
        }
        let data = enc.finish();

        let mut dec = RangeCoder::new_decoder(&data);
        let decoded: Vec<u32> = fields.iter().map(|&(_, width)| dec.dec_bits(width)).collect();
        let expected: Vec<u32> = fields.iter().map(|&(value, _)| value).collect();
        assert_eq!(decoded, expected);
    }

    /// Raw bits and range-coded symbols interleaved in one packet decode independently.
    #[test]
    fn test_interleaved_bits_entropy() {
        let mut enc = RangeCoder::new_encoder(1024);
        enc.enc_bits(1, 1);
        enc.encode(10, 20, 100);
        enc.enc_bits(5, 3);
        enc.encode(50, 60, 100);
        let data = enc.finish();

        let mut dec = RangeCoder::new_decoder(&data);
        let first_bits = dec.dec_bits(1);
        let first_sym = dec.decode(100);
        dec.update(10, 20, 100);
        let second_bits = dec.dec_bits(3);
        let second_sym = dec.decode(100);
        dec.update(50, 60, 100);

        assert_eq!(first_bits, 1);
        assert!((10..20).contains(&first_sym), "d1={} expected in [10, 20)", first_sym);
        assert_eq!(second_bits, 5);
        assert!((50..60).contains(&second_sym), "d2={} expected in [50, 60)", second_sym);
    }
}