Skip to main content

oxideav_opus/
range_decoder.rs

//! Range decoder primitives for the Opus codec.
//!
//! This module implements the bit-exact range decoder described in
//! RFC 6716 §4.1 (`docs/audio/opus/rfc6716-opus.txt`). The implementation
//! is clean-room: every routine is transcribed from the prose and
//! pseudocode equations in the RFC; no external library source was
//! consulted.
//!
//! The range decoder is the SHARED entropy primitive that both the SILK
//! and CELT layers of Opus invoke for every coded symbol. The
//! `oxideav-celt` crate carries its own copy of the same primitive;
//! each crate owns its copy until a shared low-level primitive crate
//! exists in the workspace. The two copies are independent clean-room
//! transcriptions of the same RFC sections and are expected to be
//! behaviourally identical.
//!
//! The following routines are wired up:
//!
//! * Initialization (§4.1.1).
//! * Symbol-update internal helper (§4.1.2).
//! * Renormalization (§4.1.2.1).
//! * [`RangeDecoder::decode_bin`] for power-of-two `ft` symbols (§4.1.3.1).
//! * [`RangeDecoder::dec_bit_logp`] (§4.1.3.2).
//! * [`RangeDecoder::dec_icdf`] for inverse-CDF table decoding (§4.1.3.3).
//! * [`RangeDecoder::dec_bits`] for raw bits (§4.1.4).
//! * [`RangeDecoder::dec_uint`] for uniformly-distributed integers
//!   (§4.1.5).
//! * [`RangeDecoder::tell`] for whole-bit accounting (§4.1.6.1).
//! * [`RangeDecoder::tell_frac`] for 1/8th-bit-precision accounting
//!   (§4.1.6.2).
//! * [`RangeDecoder::ec_decode`] / [`RangeDecoder::ec_dec_update`] for
//!   the generic two-step symbol path (§4.1.2). These are the building
//!   blocks for custom symbol decoders that an inverse-CDF table cannot
//!   express directly — notably the CELT §4.3.2.1 coarse-energy
//!   Laplace decoder and the §4.3.3 allocation interpolation search,
//!   both of which decode against a frequency model computed at
//!   run time rather than a fixed `icdf[]` table.

39use crate::Error;
40
41/// Bit-exact CELT/SILK range decoder state per RFC 6716 §4.1.
42///
43/// The decoder splits the input buffer into two halves. The range
44/// coder consumes bytes from the front (MSB-first into the range
45/// state) and the raw-bit reader consumes bytes from the back
46/// (LSB-first). RFC 6716 §4.1.4 explicitly permits the two readers
47/// to overlap; the decoder MUST allow it.
48#[derive(Debug)]
49pub struct RangeDecoder<'a> {
50    /// Input bitstream backing this decoder.
51    buf: &'a [u8],
52    /// Offset of the next byte the range coder will consume (advances
53    /// forward through `buf`).
54    fwd: usize,
55    /// Number of bytes consumed by the raw-bit reader, measured from
56    /// the END of `buf`. A value of `0` means no raw bit has yet been
57    /// read; the next raw byte fetched comes from `buf[buf.len() - 1]`.
58    back: usize,
59    /// Number of unconsumed bits currently sitting in `back_window`
60    /// (0..=8 at rest, may exceed during refill).
61    back_bits_avail: u32,
62    /// Buffer of unconsumed raw bits, packed with the next bit to
63    /// emit in bit 0.
64    back_window: u32,
65    /// One-bit buffer holding the LSB of the previously-consumed
66    /// forward byte (used in the next renormalization step, §4.1.2.1).
67    rem: u32,
68    /// Range size; the renormalization invariant is `rng > 2**23`.
69    rng: u32,
70    /// Top of range minus current code value, minus one.
71    val: u32,
72    /// Running tally of whole bits the range coder has consumed
73    /// (RFC 6716 §4.1.6 `nbits_total`).
74    nbits_total: u32,
75    /// Number of raw bits the decoder has read so far. RFC 6716 §4.1.6
76    /// adds these into the bit-usage accounting on top of `nbits_total`.
77    nbits_raw: u32,
78    /// Sticky error flag: any decode that detects a corrupt frame
79    /// latches an error. Once set, subsequent decodes return zeroes
80    /// rather than corrupting the caller's state. RFC 6716 §4.1.5
81    /// recommends this behaviour for malformed integer decodes.
82    error: bool,
83}
84
85impl<'a> RangeDecoder<'a> {
86    /// Renormalization invariant from §4.1.2.1: `rng > 2**23`.
87    const RNG_MIN: u32 = 1 << 23;
88
89    /// Initialize the range decoder over `buf` per RFC 6716 §4.1.1.
90    ///
91    /// The spec defines `b0` as "the first input byte (or zero if
92    /// there are no bytes in this Opus frame)". The decoder sets
93    /// `rng = 128`, `val = 127 - (b0 >> 1)`, buffers the leftover bit
94    /// `(b0 & 1)`, then immediately invokes renormalization so the
95    /// invariant `rng > 2**23` holds before any symbol is decoded.
96    pub fn new(buf: &'a [u8]) -> Self {
97        let b0 = buf.first().copied().unwrap_or(0) as u32;
98        let mut dec = Self {
99            buf,
100            // §4.1.1: the first byte is consumed by initialization,
101            // so the next forward fetch starts at index 1.
102            fwd: if buf.is_empty() { 0 } else { 1 },
103            back: 0,
104            back_bits_avail: 0,
105            back_window: 0,
106            rem: b0 & 1,
107            rng: 128,
108            val: 127 - (b0 >> 1),
109            // §4.1.6: "nbits_total is initialized to 9 just before the
110            // initial range renormalization process completes."
111            nbits_total: 9,
112            nbits_raw: 0,
113            error: false,
114        };
115        dec.normalize();
116        dec
117    }
118
119    /// Whether this decoder has latched a `frame corrupt` error
120    /// somewhere in its history. Higher-level decoders use this to
121    /// abort the current frame and apply packet-loss concealment.
122    pub fn has_error(&self) -> bool {
123        self.error
124    }
125
126    /// Current whole-bit budget consumed by the range coder plus the
127    /// raw-bit reader, per RFC 6716 §4.1.6.1.
128    ///
129    /// `ec_tell` is defined as `nbits_total - ilog(rng)`. Raw bits are
130    /// added separately because §4.1.6 specifies that raw bits also
131    /// count against the total.
132    pub fn tell(&self) -> u32 {
133        // `ilog(rng)` is the position of the most-significant set bit
134        // of `rng`, counting from 1. The renormalization invariant
135        // keeps `rng >= 2**23`, so `lg` is always at least 24.
136        let lg = 32 - self.rng.leading_zeros();
137        self.nbits_total
138            .saturating_sub(lg)
139            .saturating_add(self.nbits_raw)
140    }
141
142    /// Current 1/8th-bit-precision budget consumed by the range coder
143    /// plus the raw-bit reader, per RFC 6716 §4.1.6.2.
144    ///
145    /// Follows §4.1.6.2 directly: from `lg = ilog(rng)`, extract
146    /// `r_Q15 = rng >> (lg - 16)` as a Q15 value in `[2^15, 2^16)`.
147    /// Three iterations of
148    /// `r_Q15 = (r_Q15*r_Q15) >> 15; lg = 2*lg + (r_Q15 >> 16)` extend
149    /// `lg` to 1/8th-bit precision. Raw bits add `8*nbits_raw` (whole
150    /// bits scaled into eighths). By construction,
151    /// `ec_tell() == ceil(ec_tell_frac() / 8.0)`.
152    pub fn tell_frac(&self) -> u32 {
153        let lg0 = 32 - self.rng.leading_zeros();
154        // §4.1.6.2: lg >= 24 after renormalization, so the shift below
155        // is well-defined.  r_Q15 in [2^15, 2^16).
156        let mut r_q15 = self.rng >> (lg0 - 16);
157        // Build the 1/8th-bit-precision lg one bit at a time. The
158        // spec doubles `lg` on each of the three refinement passes;
159        // the accumulator starts at the whole-bit value `lg0`.
160        let mut lg_frac = lg0;
161        // Three passes yield three extra bits = 1/8th-bit precision.
162        for _ in 0..3 {
163            r_q15 = (r_q15 * r_q15) >> 15;
164            let bit = r_q15 >> 16;
165            lg_frac = 2 * lg_frac + bit;
166            // If `bit == 1`, halve r_Q15 so it falls back into
167            // [2^15, 2^16).
168            if bit == 1 {
169                r_q15 >>= 1;
170            }
171        }
172        // Final value = nbits_total*8 - lg_frac + nbits_raw*8.
173        self.nbits_total
174            .saturating_mul(8)
175            .saturating_sub(lg_frac)
176            .saturating_add(self.nbits_raw.saturating_mul(8))
177    }
178
179    /// Decode a single binary symbol with probability `2^-logp` of
180    /// being a "1", per RFC 6716 §4.1.3.2.
181    ///
182    /// Mathematically equivalent to `ec_decode(ft = 1<<logp)` followed
183    /// by `ec_dec_update(0, ft-1, ft)` (for a "0") or
184    /// `ec_dec_update(ft-1, ft, ft)` (for a "1"). The implementation
185    /// is multiply-and-divide-free: `r >> logp` replaces `rng/ft`, and
186    /// the discriminator collapses to a comparison.
187    pub fn dec_bit_logp(&mut self, logp: u32) -> u32 {
188        let r = self.rng;
189        let d = self.val;
190        // `s = r >> logp` corresponds to `rng/ft` with `ft = 1<<logp`
191        // (an exact shift when ft is a power of two).
192        let s = r >> logp;
193        // The "1" half corresponds to `fl = ft-1, fh = ft`, leading to
194        //   val unchanged, rng = s.
195        // The "0" half is `fl = 0, fh = ft-1`, leading to
196        //   val -= s, rng = r - s.
197        let bit = if d < s { 1 } else { 0 };
198        if bit == 1 {
199            self.rng = s;
200        } else {
201            self.val = d - s;
202            self.rng = r - s;
203        }
204        self.normalize();
205        bit
206    }
207
208    /// Decode `bits` raw bits per RFC 6716 §4.1.4.
209    ///
210    /// Raw bits are packed at the END of the frame: the least
211    /// significant bit of the first value is the LSB of the last
212    /// byte; reads proceed toward the front. The function returns the
213    /// raw bits in the order written — the LSB of the result holds
214    /// the bit the encoder emitted first.
215    ///
216    /// Returns `0` on errors (`bits > 32`); also returns zero-extended
217    /// bits past the end of the frame, matching §4.1.4's "the decoder
218    /// MUST continue to use zero for any further input bytes required".
219    pub fn dec_bits(&mut self, bits: u32) -> u32 {
220        if bits == 0 {
221            return 0;
222        }
223        if bits > 32 {
224            self.error = true;
225            return 0;
226        }
227        let mut window = self.back_window;
228        let mut avail = self.back_bits_avail;
229        // Refill the window until it holds enough bits to service the
230        // requested read.
231        while avail < bits {
232            let byte = if self.back < self.buf.len() {
233                self.buf[self.buf.len() - 1 - self.back]
234            } else {
235                // §4.1.4: zero-extend past the end of the frame.
236                0
237            };
238            self.back = self.back.saturating_add(1);
239            // Concatenate the new byte ABOVE the existing window so the
240            // intra-byte LSB-first packing is preserved.
241            window |= (byte as u32) << avail;
242            avail += 8;
243        }
244        let mask: u32 = if bits == 32 { !0 } else { (1u32 << bits) - 1 };
245        let result = window & mask;
246        // Consume the served bits.
247        self.back_window = window >> bits;
248        self.back_bits_avail = avail - bits;
249        self.nbits_raw += bits;
250        result
251    }
252
253    /// Decode one of `ft` equiprobable values in `0..ft`, per
254    /// RFC 6716 §4.1.5.
255    ///
256    /// Values of `ft <= 1` degenerate to the constant `0`. `ft` may
257    /// be as large as `2^32 - 1`. The §4.1.5 procedure splits the
258    /// value: the top 8 bits go through the range coder, the
259    /// remainder through raw bits. If the reconstructed value is
260    /// `>= ft`, the frame is corrupt — the decoder latches the error
261    /// flag and saturates to `ft - 1` per §4.1.5's concealment
262    /// recommendation.
263    pub fn dec_uint(&mut self, ft: u32) -> Result<u32, Error> {
264        if ft <= 1 {
265            return Ok(0);
266        }
267        // `ftb = ilog(ft - 1)`: number of bits needed for `ft - 1`.
268        let ftb = 32 - (ft - 1).leading_zeros();
269        if ftb <= 8 {
270            // Small case: a single range-coded symbol covers the whole
271            // value.
272            let t = self.decode(ft);
273            self.dec_update(t, t + 1, ft);
274            Ok(t)
275        } else {
276            // Large case: top 8 bits range-coded, remainder raw.
277            let split_bits = ftb - 8;
278            let top_ft = ((ft - 1) >> split_bits) + 1;
279            let t_hi = self.decode(top_ft);
280            self.dec_update(t_hi, t_hi + 1, top_ft);
281            let t_lo = self.dec_bits(split_bits);
282            let t = (t_hi << split_bits) | t_lo;
283            if t >= ft {
284                self.error = true;
285                Ok(ft - 1)
286            } else {
287                Ok(t)
288            }
289        }
290    }
291
292    /// Decode `fs` for a power-of-two `ft = 1<<ftb` per RFC 6716
293    /// §4.1.3.1 (`ec_decode_bin`).
294    ///
295    /// Mathematically equivalent to [`Self::decode`] with `ft = 1<<ftb`
296    /// but avoids the division: `rng / ft == rng >> ftb`. The caller is
297    /// expected to follow with [`Self::dec_update`] (or use
298    /// [`Self::dec_icdf`] which fuses the two steps).
299    ///
300    /// Returns `fs` in the range `[0, 1<<ftb)`.
301    pub fn decode_bin(&mut self, ftb: u32) -> u32 {
302        let s = self.rng >> ftb;
303        if s == 0 {
304            // Would only happen for ftb > ilog(rng). The
305            // renormalization invariant keeps ilog(rng) >= 24, so any
306            // practical ftb (icdf uses up to 8) is safe. Defensively
307            // saturate to 0.
308            return 0;
309        }
310        let ft = 1u32 << ftb;
311        let approx = (self.val / s).saturating_add(1);
312        ft - approx.min(ft)
313    }
314
315    /// Decode a symbol via an inverse-CDF table, per RFC 6716 §4.1.3.3
316    /// (`ec_dec_icdf`).
317    ///
318    /// `icdf[k]` stores `(1<<ftb) - fh[k]`, terminated by a `0` entry
319    /// (the implicit `fh[K_last] == ft`). `fl[0]` is implicitly 0; the
320    /// table values are strictly monotonically decreasing.
321    ///
322    /// Fuses the search step (find the smallest `k` such that
323    /// `fs < ft - icdf[k]`) with the range/value update, eliminating
324    /// the division. The renormalization loop runs before returning.
325    ///
326    /// Returns the decoded symbol index `k` in `0..icdf.len()-1`. On a
327    /// malformed table (no terminating zero), the decoder latches its
328    /// sticky error flag and returns 0.
329    pub fn dec_icdf(&mut self, icdf: &[u8], ftb: u32) -> u32 {
330        // `s` corresponds to `rng / ft` for `ft = 1<<ftb`.
331        let s = self.rng >> ftb;
332        // Forward walk: for each candidate k, compute
333        //   next = s * icdf[k]
334        // which is the "remaining range above this symbol". The first
335        // k where `val >= next` is the decoded symbol. `t` tracks the
336        // previous step's `next` so that `rng' = t - next` matches the
337        // §4.1.2 update for `fl[k] == prev_next/s` and `fh[k] ==
338        // next/s` (with k=0 falling out to `rng - s*icdf[0]` since
339        // `t` starts at `rng`).
340        let mut t = self.rng;
341        for (k, &cell) in icdf.iter().enumerate() {
342            let next = s.saturating_mul(cell as u32);
343            if self.val >= next {
344                self.val -= next;
345                self.rng = t - next;
346                self.normalize();
347                return k as u32;
348            }
349            t = next;
350        }
351        // Malformed table: no terminator reached. §4.1.5 advises
352        // latching the corrupt-frame error and returning a saturated
353        // value.
354        self.error = true;
355        0
356    }
357
358    /// `ec_decode(ft)` per RFC 6716 §4.1.2 — the first of the two
359    /// symbol-decode steps.
360    ///
361    /// Computes the 16-bit symbol proxy
362    /// `fs = ft - min(val / (rng / ft) + 1, ft)`, which "lies within
363    /// the range of some symbol in the current context". The caller
364    /// then identifies the symbol `k` whose three-tuple
365    /// `(fl[k], fh[k], ft)` satisfies `fl[k] <= fs < fh[k]` and feeds
366    /// that tuple to [`Self::ec_dec_update`].
367    ///
368    /// This split form is needed when the frequency model is computed
369    /// at run time and cannot be pre-baked into a static inverse-CDF
370    /// table (RFC 6716 §4.3.2.1 coarse-energy Laplace decode, §4.3.3
371    /// allocation search). For fixed PDFs prefer [`Self::dec_icdf`],
372    /// which fuses both steps and avoids a division.
373    ///
374    /// `ft` must be in `1..=2**16` for the §4.1.2 derivation to hold
375    /// (the renormalization invariant keeps `rng > 2**23`, so
376    /// `rng / ft >= 1`). `ft == 0` would divide by zero; the decoder
377    /// latches its sticky error flag and returns `0` instead. The
378    /// returned `fs` lies in `[0, ft)`.
379    pub fn ec_decode(&mut self, ft: u32) -> u32 {
380        if ft == 0 {
381            self.error = true;
382            return 0;
383        }
384        self.decode(ft)
385    }
386
387    /// `ec_dec_update(fl, fh, ft)` per RFC 6716 §4.1.2 — the second of
388    /// the two symbol-decode steps.
389    ///
390    /// Narrows the range to the chosen symbol's `[fl, fh)` sub-interval
391    /// of `[0, ft)` per the §4.1.2 update equations, then renormalizes
392    /// to restore `rng > 2**23`. Pair this with the index returned by
393    /// the caller's search over the value from [`Self::ec_decode`].
394    ///
395    /// The three-tuple MUST satisfy `0 <= fl < fh <= ft` and
396    /// `1 <= ft <= 2**16`. A malformed tuple (`ft == 0`, `fh > ft`, or
397    /// `fl >= fh`) cannot come from a well-formed search; the decoder
398    /// latches its sticky error flag and leaves its state unchanged
399    /// rather than underflowing `val` or zeroing `rng`.
400    pub fn ec_dec_update(&mut self, fl: u32, fh: u32, ft: u32) {
401        if ft == 0 || fh > ft || fl >= fh {
402            self.error = true;
403            return;
404        }
405        self.dec_update(fl, fh, ft);
406    }
407
408    // ----- internal helpers -----
409
410    /// `ec_decode(ft)` per RFC 6716 §4.1.2: compute the symbol-proxy
411    /// `fs = ft - min(val / (rng / ft) + 1, ft)`.
412    fn decode(&mut self, ft: u32) -> u32 {
413        // The spec uses integer division. `rng/ft` is computed first;
414        // the divisor is then `val / (rng/ft)`. The renormalization
415        // invariant ensures `rng/ft >= 1` in all practical cases
416        // (rng > 2**23 and ft <= 2**16 on the symbol-decode path).
417        let s = self.rng / ft;
418        let approx = self.val / s + 1;
419        ft - approx.min(ft)
420    }
421
422    /// `ec_dec_update(fl, fh, ft)` per RFC 6716 §4.1.2.
423    ///
424    /// Narrows the range to the chosen symbol's interval, then runs
425    /// renormalization to restore `rng > 2**23`.
426    fn dec_update(&mut self, fl: u32, fh: u32, ft: u32) {
427        let s = self.rng / ft;
428        self.val -= s * (ft - fh);
429        if fl > 0 {
430            self.rng = s * (fh - fl);
431        } else {
432            self.rng -= s * (ft - fh);
433        }
434        self.normalize();
435    }
436
437    /// `ec_dec_normalize` per RFC 6716 §4.1.2.1.
438    ///
439    /// Until `rng > 2**23`, shift `rng` left by 8 and pull a fresh
440    /// `sym` byte. `sym` combines the previously-buffered low bit
441    /// (`rem`, as MSB) with the top 7 bits of the new byte; the LSB of
442    /// the new byte is buffered for next time. When the frame is
443    /// exhausted, zero bytes are substituted.
444    fn normalize(&mut self) {
445        while self.rng <= Self::RNG_MIN {
446            let byte = if self.fwd < self.buf.len() {
447                let b = self.buf[self.fwd];
448                self.fwd += 1;
449                b as u32
450            } else {
451                0
452            };
453            let sym = (self.rem << 7) | (byte >> 1);
454            self.rem = byte & 1;
455            self.rng <<= 8;
456            self.val = ((self.val << 8) + (255 - sym)) & 0x7FFF_FFFF;
457            // §4.1.6: each iteration adds 8 to nbits_total.
458            self.nbits_total = self.nbits_total.saturating_add(8);
459        }
460    }
461}
462
463#[cfg(test)]
464mod tests {
465    use super::*;
466
467    /// §4.1.1 initialization over an empty buffer must still satisfy
468    /// the §4.1.2.1 invariant and report `ec_tell() == 1`
469    /// (§4.1.6.1: "In a newly initialized decoder, before any symbols
470    /// have been read, this reports that 1 bit has been used").
471    #[test]
472    fn init_empty_buffer_satisfies_invariant() {
473        let dec = RangeDecoder::new(&[]);
474        assert!(dec.rng > RangeDecoder::RNG_MIN);
475        assert!(!dec.has_error());
476        assert_eq!(dec.tell(), 1);
477    }
478
479    /// Non-empty initialization also satisfies the invariant and
480    /// reports a sensible tell.
481    #[test]
482    fn init_nonempty_buffer_holds_invariant() {
483        let dec = RangeDecoder::new(&[0xAB, 0xCD, 0xEF, 0x12]);
484        assert!(dec.rng > RangeDecoder::RNG_MIN);
485        assert!(!dec.has_error());
486        assert!(dec.tell() >= 1);
487    }
488
489    /// `dec_bit_logp` should be statistically biased by the surrounding
490    /// bytes: an all-zero stream pushes `val` high, biasing toward "0",
491    /// and an all-ones stream pushes it low, biasing toward "1".
492    #[test]
493    fn dec_bit_logp_bias_with_extreme_inputs() {
494        // All-zero stream: bias toward "0".
495        let mut dec0 = RangeDecoder::new(&[0u8; 16]);
496        let mut zero_count = 0;
497        for _ in 0..32 {
498            if dec0.dec_bit_logp(1) == 0 {
499                zero_count += 1;
500            }
501        }
502        assert!(!dec0.has_error());
503        assert!(
504            zero_count > 16,
505            "all-zero stream should be biased toward 0: zero_count={}",
506            zero_count
507        );
508
509        // All-ones stream: bias toward "1".
510        let mut dec1 = RangeDecoder::new(&[0xFFu8; 16]);
511        let mut one_count = 0;
512        for _ in 0..32 {
513            if dec1.dec_bit_logp(1) == 1 {
514                one_count += 1;
515            }
516        }
517        assert!(!dec1.has_error());
518        assert!(
519            one_count > 16,
520            "all-ones stream should be biased toward 1: one_count={}",
521            one_count
522        );
523    }
524
525    /// `dec_bits` reads raw bits LSB-first from the END of the buffer.
526    /// With the last byte = 0b1010_0110, the first 4 raw bits returned
527    /// are 0b0110 = 6, and the next 4 are 0b1010 = 0xA.
528    #[test]
529    fn dec_bits_lsb_first_from_end() {
530        let mut dec = RangeDecoder::new(&[0x00, 0x00, 0xA6]);
531        let lo = dec.dec_bits(4);
532        let hi = dec.dec_bits(4);
533        assert_eq!(lo, 0x6);
534        assert_eq!(hi, 0xA);
535        assert!(!dec.has_error());
536    }
537
538    /// `dec_bits` past the end of the frame must zero-extend, per
539    /// §4.1.4 ("the decoder MUST continue to use zero for any further
540    /// input bytes required"). The function must not panic or set the
541    /// error flag in that case.
542    #[test]
543    fn dec_bits_zero_past_end_of_frame() {
544        let mut dec = RangeDecoder::new(&[0xFF, 0xFF]);
545        for _ in 0..4 {
546            let v = dec.dec_bits(4);
547            assert_eq!(v, 0xF);
548        }
549        // The next 8 bits should come back as zero (the range coder
550        // may or may not have shared bytes with the raw reader — but
551        // the *raw* side reads past-EOF as 0).
552        let pad = dec.dec_bits(8);
553        let _ = pad;
554        assert!(!dec.has_error());
555    }
556
557    /// `dec_uint(1)` is degenerate — the only value in `0..1` is 0 —
558    /// and consumes no bits.
559    #[test]
560    fn dec_uint_ft_one_is_zero_no_consumption() {
561        let mut dec = RangeDecoder::new(&[0x12, 0x34, 0x56]);
562        let before = dec.tell();
563        let v = dec.dec_uint(1).expect("ft=1 must succeed");
564        let after = dec.tell();
565        assert_eq!(v, 0);
566        assert_eq!(after, before);
567    }
568
569    /// `dec_uint` with `ft` in the small (`ftb <= 8`) regime: returned
570    /// values must lie in `[0, ft)` and never trip the error flag for
571    /// well-formed inputs.
572    #[test]
573    fn dec_uint_small_ft_in_range() {
574        let mut dec = RangeDecoder::new(&[0x42, 0x18, 0xC3, 0x7F]);
575        for _ in 0..8 {
576            let v = dec.dec_uint(200).expect("ft=200 must succeed");
577            assert!(v < 200, "v={} out of range", v);
578        }
579        assert!(!dec.has_error());
580    }
581
582    /// `dec_uint` with `ft` in the large (`ftb > 8`) regime: returned
583    /// values must lie in `[0, ft)`. The saturation path is allowed
584    /// to set the error flag, but the returned value remains bounded.
585    #[test]
586    fn dec_uint_large_ft_in_range() {
587        let buf: Vec<u8> = (0..64).collect();
588        let mut dec = RangeDecoder::new(&buf);
589        for _ in 0..8 {
590            let v = dec.dec_uint(1_000_000).expect("ft=1_000_000 must succeed");
591            assert!(v < 1_000_000, "v={} out of range", v);
592        }
593    }
594
595    /// `dec_uint` with `ft = 0` is degenerate and returns 0 without
596    /// consumption.
597    #[test]
598    fn dec_uint_ft_zero_returns_zero() {
599        let mut dec = RangeDecoder::new(&[0xAA, 0xBB, 0xCC, 0xDD]);
600        let before = dec.tell();
601        let v = dec.dec_uint(0).expect("ft=0 must succeed");
602        assert_eq!(v, 0);
603        assert_eq!(dec.tell(), before);
604    }
605
606    /// `tell()` must monotonically non-decrease across operations.
607    #[test]
608    fn tell_is_monotonic_across_decodes() {
609        let mut dec = RangeDecoder::new(&[0x55; 8]);
610        let mut prev = dec.tell();
611        for _ in 0..16 {
612            let _ = dec.dec_bit_logp(2);
613            let now = dec.tell();
614            assert!(now >= prev, "tell() went backwards: {} -> {}", prev, now);
615            prev = now;
616        }
617    }
618
619    /// `decode_bin(ftb)` must agree with the generic `decode(1<<ftb)`
620    /// path bit-for-bit (RFC 6716 §4.1.3.1: the two are mathematically
621    /// equivalent). Drive both with the same input bytes and compare.
622    #[test]
623    fn decode_bin_matches_generic_decode() {
624        for &ftb in &[1u32, 4, 8, 12, 15] {
625            let buf = [0x37u8, 0x91, 0xC4, 0x18, 0xA2, 0x5D, 0x6E, 0xFF];
626            let mut a = RangeDecoder::new(&buf);
627            let mut b = RangeDecoder::new(&buf);
628            let from_bin = a.decode_bin(ftb);
629            let from_generic = b.decode(1u32 << ftb);
630            assert_eq!(
631                from_bin, from_generic,
632                "decode_bin({ftb}) != decode(1<<{ftb})"
633            );
634            assert!(from_bin < (1u32 << ftb), "fs={from_bin} out of range");
635        }
636    }
637
638    /// RFC 6716 §4.1.6.1 specifies the identity
639    /// `ec_tell() == ceil(ec_tell_frac() / 8.0)`. Walk a decoder
640    /// forward through mixed symbol and raw-bit reads and assert this
641    /// at every step.
642    #[test]
643    fn tell_frac_consistent_with_tell() {
644        let mut dec = RangeDecoder::new(&[0xA3, 0x7F, 0x10, 0x5C, 0xE8, 0x91, 0x42, 0xB7]);
645        // §4.1.6.1: a fresh decoder reports tell() == 1.
646        assert_eq!(dec.tell(), 1);
647        for _ in 0..12 {
648            let whole = dec.tell();
649            let frac = dec.tell_frac();
650            let ceil_eighths = frac.div_ceil(8);
651            assert_eq!(
652                ceil_eighths, whole,
653                "tell()={whole} != ceil(tell_frac()={frac} / 8)={ceil_eighths}"
654            );
655            let _ = dec.dec_bit_logp(1);
656            let _ = dec.dec_bits(2);
657        }
658    }
659
660    /// `tell_frac()` of a fresh decoder sits in `[1, 8]` (since
661    /// `tell()` is `1` and the §4.1.6.1 ceiling identity holds).
662    #[test]
663    fn tell_frac_initial_within_one_bit() {
664        let dec = RangeDecoder::new(&[0xCC, 0xDD, 0xEE, 0xFF]);
665        let frac = dec.tell_frac();
666        assert!(
667            (1..=8).contains(&frac),
668            "tell_frac initial out of [1,8]: {frac}"
669        );
670        assert!(frac.div_ceil(8) == dec.tell());
671    }
672
673    /// `dec_icdf` over a binary `{ft - 1, 1}/ft` distribution must
674    /// agree with `dec_bit_logp(logp)` step-for-step — both are
675    /// special cases of `ec_decode` with `ft = 1<<ftb` (RFC 6716
676    /// §4.1.3.2 + §4.1.3.3).
677    #[test]
678    fn dec_icdf_matches_dec_bit_logp_for_binary() {
679        let buf = [0xDE, 0xAD, 0xBE, 0xEF, 0x10, 0x32, 0x54, 0x76];
680        // logp = 3 → ft = 8, P("1") = 1/8. icdf {ft-fh[0], ft-fh[1]} =
681        // {1, 0}: symbol 0 is the high-probability outcome (the "0"
682        // bit).
683        let logp = 3u32;
684        let icdf = [1u8, 0];
685        let mut a = RangeDecoder::new(&buf);
686        let mut b = RangeDecoder::new(&buf);
687        for _ in 0..16 {
688            let via_logp = a.dec_bit_logp(logp);
689            let via_icdf = b.dec_icdf(&icdf, logp);
690            assert_eq!(
691                via_logp, via_icdf,
692                "dec_bit_logp({logp}) != dec_icdf({icdf:?}, {logp})"
693            );
694        }
695        assert!(!a.has_error() && !b.has_error());
696    }
697
698    /// `dec_icdf` over a uniform `{1,1,1,1,1,1,1,1}/8` PDF must return
699    /// a symbol in `[0, 8)` every time without error.
700    #[test]
701    fn dec_icdf_uniform_returns_in_range() {
702        // Uniform 8-way PDF: fh = {1,2,3,4,5,6,7,8} → icdf =
703        // {7,6,5,4,3,2,1,0}.
704        let icdf = [7u8, 6, 5, 4, 3, 2, 1, 0];
705        let mut dec = RangeDecoder::new(&[0x42, 0x18, 0xC3, 0x7F, 0x55, 0xAA, 0x33, 0xCC]);
706        for _ in 0..16 {
707            let k = dec.dec_icdf(&icdf, 3);
708            assert!(k < 8, "icdf uniform returned {k} out of [0, 8)");
709        }
710        assert!(!dec.has_error());
711    }
712
713    /// `dec_icdf` over the degenerate single-symbol table `{0}` (only
714    /// the terminator) means symbol 0 covers the whole interval, so it
715    /// is always returned. No range mass is consumed and the error
716    /// flag stays clear.
717    #[test]
718    fn dec_icdf_single_symbol_always_zero() {
719        let icdf = [0u8];
720        let mut dec = RangeDecoder::new(&[0x77, 0x33, 0x11, 0xAA]);
721        let before_tell = dec.tell();
722        for _ in 0..4 {
723            let k = dec.dec_icdf(&icdf, 3);
724            assert_eq!(k, 0);
725        }
726        assert!(dec.tell() >= before_tell);
727        assert!(!dec.has_error());
728    }
729
730    /// `tell_frac()` is monotonically non-decreasing across mixed ops
731    /// (§4.1.6.2 inherits the monotonicity of `ec_tell` since the
732    /// procedure only adds bits).
733    #[test]
734    fn tell_frac_is_monotonic() {
735        let mut dec = RangeDecoder::new(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88]);
736        // Uniform 8-way icdf so each call burns ~3 bits.
737        let icdf = [7u8, 6, 5, 4, 3, 2, 1, 0];
738        let mut prev = dec.tell_frac();
739        for i in 0..24 {
740            match i % 3 {
741                0 => {
742                    let _ = dec.dec_bit_logp(2);
743                }
744                1 => {
745                    let _ = dec.dec_icdf(&icdf, 3);
746                }
747                _ => {
748                    let _ = dec.dec_bits(2);
749                }
750            }
751            let now = dec.tell_frac();
752            assert!(
753                now >= prev,
754                "tell_frac() went backwards: {} -> {}",
755                prev,
756                now
757            );
758            prev = now;
759        }
760    }
761
762    /// `dec_bits(0)` returns 0 and consumes nothing.
763    #[test]
764    fn dec_bits_zero_width_is_noop() {
765        let mut dec = RangeDecoder::new(&[0x12, 0x34, 0x56]);
766        let before = dec.tell();
767        let v = dec.dec_bits(0);
768        assert_eq!(v, 0);
769        assert_eq!(dec.tell(), before);
770        assert!(!dec.has_error());
771    }
772
773    /// `dec_bits` with an over-large width sets the error flag and
774    /// returns 0 (guard against caller misuse).
775    #[test]
776    fn dec_bits_oversize_latches_error() {
777        let mut dec = RangeDecoder::new(&[0xAA, 0xBB, 0xCC, 0xDD]);
778        let v = dec.dec_bits(33);
779        assert_eq!(v, 0);
780        assert!(dec.has_error());
781    }
782
783    /// The public two-step `ec_decode` / `ec_dec_update` path must
784    /// reproduce, symbol for symbol, what the fused `dec_icdf` produces
785    /// for the same fixed PDF. This is the RFC 6716 §4.1.2 ↔ §4.1.3.3
786    /// equivalence: `dec_icdf` is exactly `ec_decode(1<<ftb)` followed
787    /// by a search and `ec_dec_update`. We drive both decoders over
788    /// identical input bytes and assert they stay in lockstep.
789    #[test]
790    fn ec_decode_update_matches_dec_icdf_for_fixed_pdf() {
791        // Uniform 8-way PDF: fh = {1,2,..,8}, ft = 8, icdf = {7,..,0}.
792        let icdf = [7u8, 6, 5, 4, 3, 2, 1, 0];
793        let ftb = 3u32;
794        let ft = 1u32 << ftb;
795        let buf = [0x42u8, 0x18, 0xC3, 0x7F, 0x55, 0xAA, 0x33, 0xCC];
796        let mut fused = RangeDecoder::new(&buf);
797        let mut split = RangeDecoder::new(&buf);
798        for _ in 0..16 {
799            let k_fused = fused.dec_icdf(&icdf, ftb);
800
801            // Reconstruct the same decode via the public split steps.
802            let fs = split.ec_decode(ft);
803            // icdf[k] == ft - fh[k]; fl[k] == fh[k-1] (fl[0] == 0).
804            // Find the symbol whose [fl, fh) contains fs.
805            let mut k_split = 0u32;
806            let mut fl = 0u32;
807            let mut fh = ft - icdf[0] as u32;
808            for (idx, w) in icdf.windows(2).enumerate() {
809                if fs < fh {
810                    break;
811                }
812                fl = ft - w[0] as u32;
813                fh = ft - w[1] as u32;
814                k_split = (idx + 1) as u32;
815            }
816            split.ec_dec_update(fl, fh, ft);
817
818            assert_eq!(
819                k_fused, k_split,
820                "fused dec_icdf and split ec_decode/ec_dec_update diverged"
821            );
822        }
823        assert!(!fused.has_error() && !split.has_error());
824    }
825
826    /// `ec_decode(ft)` returns a value in `[0, ft)` for every well-formed
827    /// `ft`, matching the §4.1.2 derivation `fs = ft - min(.., ft)`.
828    #[test]
829    fn ec_decode_returns_in_range() {
830        let buf = [0x37u8, 0x91, 0xC4, 0x18, 0xA2, 0x5D, 0x6E, 0xFF];
831        for &ft in &[2u32, 7, 100, 1000, 1 << 16] {
832            let mut dec = RangeDecoder::new(&buf);
833            let fs = dec.ec_decode(ft);
834            assert!(fs < ft, "ec_decode({ft}) = {fs} out of [0, {ft})");
835            assert!(!dec.has_error());
836        }
837    }
838
839    /// `ec_decode(0)` is malformed (division by zero in the §4.1.2
840    /// formula); it must latch the error flag and return 0 rather than
841    /// panic.
842    #[test]
843    fn ec_decode_ft_zero_latches_error() {
844        let mut dec = RangeDecoder::new(&[0x11, 0x22, 0x33, 0x44]);
845        let fs = dec.ec_decode(0);
846        assert_eq!(fs, 0);
847        assert!(dec.has_error());
848    }
849
850    /// `ec_dec_update` with a malformed tuple (`fl >= fh`, `fh > ft`,
851    /// or `ft == 0`) latches the error flag and leaves state untouched,
852    /// guarding against an underflow of `val` or a zeroing of `rng`.
853    #[test]
854    fn ec_dec_update_rejects_malformed_tuple() {
855        // fl >= fh
856        let mut a = RangeDecoder::new(&[0xAA, 0xBB, 0xCC, 0xDD]);
857        a.ec_dec_update(4, 4, 8);
858        assert!(a.has_error());
859
860        // fh > ft
861        let mut b = RangeDecoder::new(&[0xAA, 0xBB, 0xCC, 0xDD]);
862        b.ec_dec_update(0, 9, 8);
863        assert!(b.has_error());
864
865        // ft == 0
866        let mut c = RangeDecoder::new(&[0xAA, 0xBB, 0xCC, 0xDD]);
867        c.ec_dec_update(0, 1, 0);
868        assert!(c.has_error());
869    }
870
871    /// The public split path must also reproduce the `dec_uint` small
872    /// regime (`ftb <= 8`), which is internally `decode(ft)` followed by
873    /// `dec_update(t, t+1, ft)`. Drive `dec_uint` and the public
874    /// `ec_decode` / `ec_dec_update` in lockstep for a small `ft`.
875    #[test]
876    fn ec_decode_update_matches_dec_uint_small_regime() {
877        let buf = [0x42u8, 0x18, 0xC3, 0x7F, 0x55, 0xAA, 0x33, 0xCC];
878        let ft = 200u32; // ftb <= 8 → small dec_uint regime
879        let mut via_uint = RangeDecoder::new(&buf);
880        let mut via_split = RangeDecoder::new(&buf);
881        for _ in 0..8 {
882            let u = via_uint.dec_uint(ft).expect("ft=200 small regime");
883            let t = via_split.ec_decode(ft);
884            via_split.ec_dec_update(t, t + 1, ft);
885            assert_eq!(u, t, "dec_uint and split ec_decode/update diverged");
886        }
887        assert!(!via_uint.has_error() && !via_split.has_error());
888    }
889}