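//! Range coder (entropy encoder/decoder) — `opus_rs/range_coder.rs`.
//!
//! Mirrors the libopus C range coder (`ec_enc_*` / `ec_dec_*` functions).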

/// Bits carried by one output symbol of the range coder (one byte).
pub const EC_SYM_BITS: u32 = 8;
/// Width in bits of the coder registers.
pub const EC_CODE_BITS: u32 = 32;
/// Largest value a single output symbol can take.
pub const EC_SYM_MAX: u32 = (1u32 << EC_SYM_BITS) - 1;
/// Right shift that exposes the next output symbol (plus one carry bit) of `val`.
pub const EC_CODE_SHIFT: u32 = EC_CODE_BITS - (EC_SYM_BITS + 1);
/// Top of the coding range; the encoder's initial `rng`.
pub const EC_CODE_TOP: u32 = 1u32 << (EC_CODE_BITS - 1);
/// Renormalization threshold: `rng` at or below this triggers a byte shift.
pub const EC_CODE_BOT: u32 = EC_CODE_TOP >> EC_SYM_BITS;
/// Number of bits of the first input byte the decoder consumes up front.
pub const EC_CODE_EXTRA: u32 = 1 + (EC_CODE_BITS - 2) % EC_SYM_BITS;
/// Fractional resolution (in bits) of `tell_frac`.
pub const BITRES: i32 = 3;

10#[derive(Clone)]
11pub struct RangeCoder {
12    pub buf: Vec<u8>,
13    pub storage: u32,
14    pub end_offs: u32,
15    pub end_window: u32,
16    pub nend_bits: i32,
17    pub nbits_total: i32,
18    pub offs: u32,
19    pub rng: u32,
20    pub val: u64,
21    pub ext: u32,
22    pub rem: i32,
23    pub error: i32,
24}
25
26impl RangeCoder {
27    pub fn new_encoder(size: u32) -> Self {
28        RangeCoder {
29            buf: vec![0; size as usize],
30            storage: size,
31            end_offs: 0,
32            end_window: 0,
33            nend_bits: 0,
34            nbits_total: 33, // EC_CODE_BITS + 1
35            offs: 0,
36            rng: 1 << 31, // EC_CODE_TOP
37            val: 0,
38            ext: 0,
39            rem: -1,
40            error: 0,
41        }
42    }
43
44    pub fn new_decoder(buf: Vec<u8>) -> Self {
45        let storage = buf.len() as u32;
46        let mut rc = RangeCoder {
47            buf,
48            storage,
49            end_offs: 0,
50            end_window: 0,
51            nend_bits: 0,
52            nbits_total: (EC_CODE_BITS + 1
53                - ((EC_CODE_BITS - EC_CODE_EXTRA) / EC_SYM_BITS) * EC_SYM_BITS)
54                as i32,
55            offs: 0,
56            rng: 1 << EC_CODE_EXTRA,
57            val: 0,
58            ext: 0,
59            rem: 0,
60            error: 0,
61        };
62
63        rc.rem = rc.read_byte() as i32;
64        rc.val = (rc.rng - 1 - (rc.rem as u32 >> (EC_SYM_BITS - EC_CODE_EXTRA))) as u64;
65
66        rc.normalize_decoder();
67        rc
68    }
69
70    fn normalize_decoder(&mut self) {
71        while self.rng <= EC_CODE_BOT {
72            self.nbits_total += EC_SYM_BITS as i32;
73            self.rng <<= EC_SYM_BITS;
74
75            let sym = self.rem;
76            self.rem = self.read_byte() as i32;
77
78            let combined_sym = ((sym << EC_SYM_BITS) | self.rem) >> (EC_SYM_BITS - EC_CODE_EXTRA);
79            self.val = ((self.val << EC_SYM_BITS)
80                + (EC_SYM_MAX as u32 & !combined_sym as u32) as u64)
81                & (EC_CODE_TOP as u64 - 1);
82        }
83    }
84
85    fn read_byte(&mut self) -> u8 {
86        if self.offs < self.storage {
87            let b = self.buf[self.offs as usize];
88            self.offs += 1;
89            b
90        } else {
91            0
92        }
93    }
94
95    pub fn enc_uint(&mut self, fl: u32, ft: u32) {
96        if ft > (1 << 8) {
97            let mut ft = ft - 1;
98            let s = 32 - ft.leading_zeros() as i32 - 8;
99            self.enc_bits(fl & ((1 << s) - 1), s as u32);
100            let fl = fl >> s;
101            ft >>= s;
102            ft += 1;
103            self.encode(fl, fl + 1, ft);
104        } else if ft > 1 {
105            self.encode(fl, fl + 1, ft);
106        }
107    }
108
109    pub fn dec_uint(&mut self, ft: u32) -> u32 {
110        if ft > (1 << 8) {
111            let mut ft = ft - 1;
112            let s = 32 - ft.leading_zeros() as i32 - 8;
113            let r = self.dec_bits(s as u32);
114            ft >>= s;
115            ft += 1;
116            let fs = self.decode(ft);
117            self.update(fs, fs + 1, ft);
118            (fs << s) | r
119        } else if ft > 1 {
120            let fs = self.decode(ft);
121            self.update(fs, fs + 1, ft);
122            fs
123        } else {
124            0
125        }
126    }
127
128    pub fn enc_bits(&mut self, val: u32, bits: u32) {
129        if bits == 0 {
130            return;
131        }
132        let mut window = self.end_window;
133        let mut used = self.nend_bits;
134        if (used as u32) + bits > EC_CODE_BITS {
135            while used >= EC_SYM_BITS as i32 {
136                self.write_byte_at_end((window & EC_SYM_MAX) as u8);
137                window >>= EC_SYM_BITS;
138                used -= EC_SYM_BITS as i32;
139            }
140        }
141        window |= (val & ((1 << bits) - 1)) << used;
142        used += bits as i32;
143        self.end_window = window;
144        self.nend_bits = used;
145        self.nbits_total += bits as i32;
146    }
147
148    pub fn dec_bits(&mut self, bits: u32) -> u32 {
149        if bits == 0 {
150            return 0;
151        }
152        let mut window = self.end_window;
153        let mut used = self.nend_bits;
154        if used < bits as i32 {
155            loop {
156                // Read byte from end: buf[storage - ++end_offs]
157                let byte = if self.end_offs < self.storage {
158                    self.end_offs += 1;
159                    self.buf[(self.storage - self.end_offs) as usize]
160                } else {
161                    0
162                };
163                window |= (byte as u32) << used;
164                used += 8;
165                if used > 32 - 8 {
166                    break;
167                }
168            }
169        }
170        let ret = window & ((1 << bits) - 1);
171        self.end_window = window >> bits;
172        self.nend_bits = used - bits as i32;
173        self.nbits_total += bits as i32;
174        ret
175    }
176
177    pub fn tell_frac(&self) -> i32 {
178        static CORRECTION: [u32; 8] = [35733, 38967, 42495, 46340, 50535, 55109, 60097, 65535];
179        let nbits = self.nbits_total << BITRES;
180        let mut l = 32 - self.rng.leading_zeros() as i32;
181        let r = self.rng >> (l - 16);
182        let mut b = (r >> 12) - 8;
183        if b < 7 && r > CORRECTION[b as usize] {
184            b += 1;
185        }
186        l = (l << 3) + b as i32;
187        nbits - l
188    }
189
190    pub fn tell(&self) -> i32 {
191        (self.tell_frac() + 7) >> 3
192    }
193
194    fn write_byte(&mut self, value: u8) {
195        if self.offs + self.end_offs < self.storage {
196            self.buf[self.offs as usize] = value;
197            self.offs += 1;
198        } else {
199            self.error = 1;
200        }
201    }
202
203    fn carry_out(&mut self, c: i32) {
204        if c != EC_SYM_MAX as i32 {
205            let carry = c >> EC_SYM_BITS;
206            if self.rem >= 0 {
207                self.write_byte((self.rem + carry) as u8);
208            }
209            if self.ext > 0 {
210                let sym = (EC_SYM_MAX as i32 + carry) & EC_SYM_MAX as i32;
211                for _ in 0..self.ext {
212                    self.write_byte(sym as u8);
213                }
214                self.ext = 0;
215            }
216            self.rem = c & EC_SYM_MAX as i32;
217        } else {
218            self.ext += 1;
219        }
220    }
221
222    pub fn encode(&mut self, fl: u32, fh: u32, ft: u32) {
223        if ft == 0 {
224            return;
225        }
226        let r = self.rng / ft;
227        if fl > 0 {
228            self.val += (self.rng - r * (ft - fl)) as u64;
229            self.rng = r * (fh - fl);
230        } else {
231            self.rng -= r * (ft - fh);
232        }
233        self.normalize_encoder();
234    }
235
236    fn normalize_encoder(&mut self) {
237        while self.rng <= EC_CODE_BOT {
238            self.carry_out((self.val >> EC_CODE_SHIFT) as i32);
239            self.val = (self.val << EC_SYM_BITS) & (EC_CODE_TOP as u64 - 1);
240            self.rng <<= EC_SYM_BITS;
241            self.nbits_total = self.nbits_total.wrapping_add(EC_SYM_BITS as i32);
242        }
243    }
244
245    pub fn encode_bit_logp(&mut self, val: bool, logp: u32) {
246        let s = self.rng >> logp;
247        let r = self.rng - s;
248        if val {
249            self.val += r as u64;
250            self.rng = s;
251        } else {
252            self.rng = r;
253        }
254        self.normalize_encoder();
255    }
256
257    pub fn encode_icdf(&mut self, s: i32, icdf: &[u8], ftb: u32) {
258        let r = self.rng >> ftb;
259        if s > 0 {
260            let val = icdf[(s - 1) as usize] as u32;
261            self.val += (self.rng as u64) - (r as u64 * val as u64);
262            let diff = val - icdf[s as usize] as u32;
263            self.rng = r * diff;
264        } else {
265            let val = icdf[s as usize] as u32;
266            self.rng -= r * val;
267        }
268        self.normalize_encoder();
269    }
270
271    pub fn decode_bit_logp(&mut self, logp: u32) -> bool {
272        let s = self.rng >> logp;
273        let ret = self.val < s as u64;
274        if !ret {
275            self.val -= s as u64;
276            self.rng -= s;
277        } else {
278            self.rng = s;
279        }
280        self.normalize_decoder();
281        ret
282    }
283
284    pub fn decode_icdf(&mut self, icdf: &[u8], ftb: u32) -> i32 {
285        let mut s = self.rng;
286        let d = self.val as u32;
287        let r = s >> ftb;
288        let mut ret = -1;
289        let mut t;
290        loop {
291            ret += 1;
292            t = s;
293            s = r * (icdf[ret as usize] as u32);
294            if d >= s {
295                break;
296            }
297        }
298        self.val = (d - s) as u64;
299        self.rng = t - s;
300        self.normalize_decoder();
301        ret
302    }
303
304    pub fn decode(&mut self, ft: u32) -> u32 {
305        let r = self.rng / ft;
306        self.ext = r;
307        let s = (self.val / r as u64) as u32;
308        ft - ft.min(s + 1)
309    }
310
311    pub fn update(&mut self, fl: u32, fh: u32, ft: u32) {
312        let s = self.ext * (ft - fh);
313        self.val -= s as u64;
314        self.rng = if fl > 0 {
315            self.ext * (fh - fl)
316        } else {
317            self.rng - s
318        };
319
320        self.normalize_decoder();
321    }
322
323    pub fn laplace_encode(&mut self, value: &mut i32, fs: u32, decay: i32) {
324        let mut val = *value;
325        let mut fl = 0;
326        let mut fs_val = fs;
327
328        if val != 0 {
329            let s = if val < 0 { -1 } else { 0 };
330            val = (val + s) ^ s;
331            fl = fs_val;
332            fs_val = self.laplace_get_freq1(fs_val, decay);
333
334            let mut i = 1;
335            while fs_val > 0 && i < val {
336                fs_val *= 2;
337                fl += fs_val + 2; // 2 * LAPLACE_MINP
338                fs_val = (fs_val as i32 * decay >> 15) as u32;
339                i += 1;
340            }
341
342            if fs_val == 0 {
343                let ndi_max = (32768 - fl + 1 - 1) >> 0;
344                let ndi_max = (ndi_max as i32 - s) >> 1;
345                let di = (val - i).min(ndi_max - 1);
346                fl += (2 * di + 1 + s) as u32;
347                fs_val = 1u32.min(32768 - fl);
348                *value = (i + di + s) ^ s;
349            } else {
350                fs_val += 1;
351                fl += fs_val & (!s as u32);
352            }
353        }
354        self.encode(fl, fl + fs_val, 1 << 15);
355    }
356
357    fn laplace_get_freq1(&self, fs0: u32, decay: i32) -> u32 {
358        let ft = 32768 - 1 * (2 * 16) - fs0; // LAPLACE_MINP=1, LAPLACE_NMIN=16
359        (ft as i32 * (16384 - decay) >> 15) as u32
360    }
361
362    pub fn laplace_decode(&mut self, fs: u32, decay: i32) -> i32 {
363        let fm = self.decode(1 << 15);
364        let mut fl = 0;
365        let mut fs_val = fs;
366        let mut val = 0;
367
368        if fm >= fs_val {
369            val += 1;
370            fl = fs_val;
371            fs_val = self.laplace_get_freq1(fs_val, decay) + 1;
372
373            while fs_val > 1 && fm >= fl + 2 * fs_val {
374                fs_val *= 2;
375                fl += fs_val;
376                fs_val = ((fs_val as i32 - 2) * decay >> 15) as u32 + 1;
377                val += 1;
378            }
379
380            if fs_val <= 1 {
381                let di = (fm - fl) >> 1;
382                val += di as i32;
383                fl += 2 * di;
384            }
385
386            if fm < fl + fs_val {
387                val = -val;
388            } else {
389                fl += fs_val;
390            }
391        }
392
393        self.update(fl, fl + fs_val.min(32768 - fl), 1 << 15);
394        val
395    }
396
397    fn write_byte_at_end(&mut self, value: u8) {
398        if self.offs + self.end_offs < self.storage {
399            self.end_offs += 1;
400            let idx = (self.storage - self.end_offs) as usize;
401            self.buf[idx] = value;
402        } else {
403            self.error = 1;
404        }
405    }
406
407    /// Patch the initial bits of the range-coded stream.
408    /// Used by SILK to retroactively insert VAD/LBRR flags at the start.
409    /// Equivalent to C `ec_enc_patch_initial_bits`.
410    pub fn patch_initial_bits(&mut self, val: u32, nbits: u32) {
411        let shift = EC_SYM_BITS - nbits;
412        let mask = ((1u32 << nbits) - 1) << shift;
413        if self.offs > 0 {
414            // The first byte has been finalized
415            self.buf[0] = ((self.buf[0] as u32 & !mask) | (val << shift)) as u8;
416        } else if self.rem >= 0 {
417            // The first byte is still awaiting carry propagation
418            self.rem = ((self.rem as u32 & !mask) | (val << shift)) as i32;
419        } else if self.rng <= (EC_CODE_TOP >> nbits) {
420            // The renormalization loop has never been run
421            let mask64 = (mask as u64) << EC_CODE_SHIFT;
422            self.val = (self.val & !mask64) | ((val as u64) << (EC_CODE_SHIFT + shift));
423        } else {
424            // The encoder hasn't even encoded nbits of data yet
425            self.error = -1;
426        }
427    }
428
429    pub fn done(&mut self) {
430        let ilog = 32 - self.rng.leading_zeros(); // Matches C EC_ILOG(rng)
431        let mut l = (EC_CODE_BITS - ilog) as i32;
432        let mut msk = (EC_CODE_TOP as u64 - 1) >> l;
433        let mut end = (self.val + msk) & !msk;
434
435        if (end | msk) >= self.val + self.rng as u64 {
436            l += 1;
437            msk >>= 1;
438            end = (self.val + msk) & !msk;
439        }
440
441        while l > 0 {
442            self.carry_out((end >> EC_CODE_SHIFT) as i32);
443            end = (end << EC_SYM_BITS) & (EC_CODE_TOP as u64 - 1);
444            l -= EC_SYM_BITS as i32;
445        }
446
447        if self.rem >= 0 || self.ext > 0 {
448            self.carry_out(0);
449        }
450
451        let mut window = self.end_window;
452        let mut used = self.nend_bits;
453        while used >= EC_SYM_BITS as i32 {
454            self.write_byte_at_end((window & EC_SYM_MAX) as u8);
455            window >>= EC_SYM_BITS;
456            used -= EC_SYM_BITS as i32;
457        }
458
459        if self.error == 0 {
460            // Clear excess space
461            for i in self.offs..(self.storage - self.end_offs) {
462                self.buf[i as usize] = 0;
463            }
464
465            if used > 0 {
466                if self.end_offs >= self.storage {
467                    self.error = -1;
468                } else {
469                    let idx = (self.storage - self.end_offs - 1) as usize;
470                    self.buf[idx] |= window as u8;
471                    // Count the byte containing remaining bits as part of end_offs
472                    self.end_offs += 1;
473                }
474            }
475        }
476    }
477
478    pub fn finish(&mut self) -> Vec<u8> {
479        self.done();
480        // The buffer is written from both ends:
481        // - buf[0..offs]: entropy coded data (from start)
482        // - buf[storage-end_offs..storage]: bit-packed data (from end)
483        // After done(), end_offs includes any byte containing remaining bits
484        let mut result = Vec::with_capacity((self.offs + self.end_offs) as usize);
485        result.extend_from_slice(&self.buf[0..self.offs as usize]);
486        result.extend_from_slice(
487            &self.buf[(self.storage - self.end_offs) as usize..self.storage as usize],
488        );
489        result
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Laplace round trip of a negative value; also pins the encoded byte.
    #[test]
    fn test_laplace() {
        let fs = 100 << 7;
        let decay = 120 << 6;
        let mut value = -3;

        let mut encoder = RangeCoder::new_encoder(100);
        encoder.laplace_encode(&mut value, fs, decay);
        encoder.done();
        assert_eq!(encoder.offs, 1);
        assert_eq!(encoder.buf[0], 224);

        let payload = encoder.buf[..encoder.offs as usize].to_vec();
        let mut decoder = RangeCoder::new_decoder(payload);
        assert_eq!(decoder.laplace_decode(fs, decay), -3);
    }

    /// Every symbol of a small inverse-CDF table survives a round trip.
    #[test]
    fn test_icdf_consistency() {
        // ftb = 2, so the total frequency is 4.
        let icdf = [2, 1, 0];

        let mut encoder = RangeCoder::new_encoder(1024);
        for symbol in 0..3 {
            encoder.encode_icdf(symbol, &icdf, 2);
        }
        encoder.done();
        let payload = encoder.buf[..encoder.offs as usize].to_vec();

        let mut decoder = RangeCoder::new_decoder(payload);
        let decoded: Vec<i32> = (0..3).map(|_| decoder.decode_icdf(&icdf, 2)).collect();
        assert_eq!(decoded, vec![0, 1, 2]);
    }

    /// A stream made only of raw bits reads back field by field.
    #[test]
    fn test_bits_only() {
        // (value, width) pairs, written and then read back in order.
        let fields = [(1, 1), (5, 3), (7, 3), (0, 2)];

        let mut encoder = RangeCoder::new_encoder(1024);
        for &(value, width) in &fields {
            encoder.enc_bits(value, width);
        }

        let mut decoder = RangeCoder::new_decoder(encoder.finish());
        for &(value, width) in &fields {
            assert_eq!(decoder.dec_bits(width), value);
        }
    }

    /// Raw bits and range-coded symbols can be interleaved freely.
    #[test]
    fn test_interleaved_bits_entropy() {
        let mut encoder = RangeCoder::new_encoder(1024);
        encoder.enc_bits(1, 1);
        encoder.encode(10, 20, 100);
        encoder.enc_bits(5, 3);
        encoder.encode(50, 60, 100);
        let payload = encoder.finish();

        let mut decoder = RangeCoder::new_decoder(payload);
        let b1 = decoder.dec_bits(1);
        let d1 = decoder.decode(100);
        decoder.update(10, 20, 100);
        let b2 = decoder.dec_bits(3);
        let d2 = decoder.decode(100);
        decoder.update(50, 60, 100);

        assert_eq!(b1, 1);
        assert!((10..20).contains(&d1), "d1={} expected in [10, 20)", d1);
        assert_eq!(b2, 5);
        assert!((50..60).contains(&d2), "d2={} expected in [50, 60)", d2);
    }
}