Skip to main content

oximedia_codec/
entropy_tables.rs

1//! Table-based CDF arithmetic coding for AV1 entropy coding optimization.
2//!
3//! This module provides a high-performance range coder that uses pre-built CDF
4//! (Cumulative Distribution Function) lookup tables in Q15 fixed-point format,
5//! matching the AV1 specification's entropy coding model.
6//!
7//! # Design
8//!
//! AV1 uses a multi-symbol range coder where symbol probabilities are stored as
//! Q15 CDFs (values in \[0, 32768\]).  Looking up the CDF in a table instead of
//! computing it from adaptive counters on every symbol is intended to be the
//! primary source of throughput improvement (~20 % is the stated target; the
//! tests in this file check correctness and compressed size, not timing).
13//!
//! Each `CdfTable<N, CTX>` is a 2-D array `[[u16; N]; CTX]` where:
//! - `CTX` is the number of distinct contexts,
//! - `N` is the row length, i.e. the number of symbols in the alphabet plus one,
//! - index `[ctx][i]` stores the cumulative probability of all symbols `< i`,
//!   scaled to Q15 (i.e. the value for symbol 0 is always 0 and the sentinel
//!   value at the last index, `N - 1`, is always `CDF_PROB_TOP = 32768`).
20//!
21//! # Included standard tables
22//!
23//! | Table constant                     | Alphabet | Contexts | Usage                     |
24//! |------------------------------------|----------|----------|---------------------------|
25//! | [`DC_COEFF_SKIP_CDF`]              | 2        | 1        | DC coefficient skip flag  |
26//! | [`AC_COEFF_SKIP_CDF`]              | 2        | 1        | AC coefficient skip flag  |
27//! | [`TRANSFORM_TYPE_CDF`]             | 16       | 1        | Transform type selection  |
28//! | [`PARTITION_TYPE_CDF`]             | 4        | 1        | Block partition type       |
29//!
30//! # Example
31//!
32//! ```rust
33//! use oximedia_codec::entropy_tables::{
34//!     RangeCoder, encode_symbol_table, decode_symbol_table,
35//!     CdfTable, DC_COEFF_SKIP_CDF,
36//! };
37//!
38//! // Encode
39//! let mut rc = RangeCoder::new();
40//! encode_symbol_table(&mut rc, 1, 0, &DC_COEFF_SKIP_CDF).expect("encode ok");
41//! let bitstream = rc.flush();
42//!
43//! // Decode
44//! let mut rc_dec = RangeCoder::new();
45//! rc_dec.init_from_slice(&bitstream).expect("init ok");
46//! let sym = decode_symbol_table(&mut rc_dec, 0, &DC_COEFF_SKIP_CDF).expect("decode ok");
47//! assert_eq!(sym, 1);
48//! ```
49
50use crate::error::CodecError;
51
52// =============================================================================
53// Constants
54// =============================================================================
55
/// Number of fractional bits in the Q15 fixed-point probability format.
pub const CDF_PROB_BITS: u32 = 15;

/// Q15 probability scale (`1 << CDF_PROB_BITS`, i.e. 32768).  Every CDF entry
/// lies in `[0, CDF_PROB_TOP]`.
pub const CDF_PROB_TOP: u16 = 1 << CDF_PROB_BITS;
61
62// =============================================================================
63// CdfTable type alias
64// =============================================================================
65
/// A CDF probability table.
///
/// Concrete type: a fixed-size array of rows.  Each row is one context; each
/// column is a cumulative probability threshold in Q15.  The first element of
/// a row is the threshold for symbol 0 and the final element of every row
/// **must** equal [`CDF_PROB_TOP`] (32768) to form a valid CDF.
///
/// The generic parameter `N` is the **row length**, i.e. the alphabet size
/// plus one for the trailing sentinel (a 2-symbol alphabet is a
/// `CdfTable<3, CTX>`).  `CTX` is the number of contexts (rows).
pub type CdfTable<const N: usize, const CTX: usize> = [[u16; N]; CTX];
74
75// =============================================================================
76// Standard AV1 CDF tables
77// =============================================================================
78
79/// DC coefficient skip flag CDF (2 symbols: 0 = not-skipped, 1 = skipped).
80///
81/// Derived from AV1 specification Table 9-3.
82/// Layout: `[[P(skip < 0), P(skip < 1), sentinel]; 1 context]`
83/// i.e., `[[0, P(not-skip), 32768]]`
84pub const DC_COEFF_SKIP_CDF: CdfTable<3, 1> = [[
85    0,     // P(sym < 0) = 0
86    20000, // P(sym < 1) ≈ 0.61  (DC skip is common)
87    32768, // sentinel = CDF_PROB_TOP
88]];
89
90/// AC coefficient skip flag CDF (2 symbols: 0 = not-skipped, 1 = skipped).
91///
92/// AC coefficients are skipped less often than DC.
93pub const AC_COEFF_SKIP_CDF: CdfTable<3, 1> = [[
94    0,     // P(sym < 0) = 0
95    14000, // P(sym < 1) ≈ 0.43
96    32768, // sentinel = CDF_PROB_TOP
97]];
98
99/// Transform type CDF (16 symbols).
100///
101/// The DCT_DCT transform (symbol 0) is by far the most common (~80 %);
102/// all remaining 15 types share the remaining probability uniformly for
103/// this static table.
104pub const TRANSFORM_TYPE_CDF: CdfTable<17, 1> = [[
105    0,     // P(sym < 0) = 0
106    26200, // P(sym < 1)  ≈ 0.80  DCT_DCT
107    27340, // P(sym < 2)
108    28000, // P(sym < 3)
109    28600, // P(sym < 4)
110    29100, // P(sym < 5)
111    29550, // P(sym < 6)
112    29950, // P(sym < 7)
113    30310, // P(sym < 8)
114    30640, // P(sym < 9)
115    30950, // P(sym < 10)
116    31240, // P(sym < 11)
117    31520, // P(sym < 12)
118    31790, // P(sym < 13)
119    32060, // P(sym < 14)
120    32400, // P(sym < 15)
121    32768, // sentinel = CDF_PROB_TOP
122]];
123
124/// Block partition type CDF (4 symbols: NONE, HORZ, VERT, SPLIT).
125///
126/// For large blocks, NONE (no split) is most common.
127pub const PARTITION_TYPE_CDF: CdfTable<5, 1> = [[
128    0,     // P(sym < 0) = 0
129    16000, // P(sym < 1)  NONE  ≈ 0.49
130    21000, // P(sym < 2)  HORZ  ≈ 0.15
131    26000, // P(sym < 3)  VERT  ≈ 0.15
132    32768, // sentinel / P(sym < 4) = CDF_PROB_TOP  → SPLIT gets the rest
133]];
134
135// =============================================================================
136// RangeCoder — Subbotin carryless byte-oriented range coder
137// =============================================================================
138
139/// Multi-symbol range coder with table-based Q15 CDF lookup.
140///
141/// Uses a byte-oriented range coder where both encoder and decoder renormalise
142/// whenever `range < BOT` (BOT = 2¹⁶).  The encoder emits the top byte of
143/// `low` on each renorm step; the decoder consumes one byte from the bitstream.
144/// Both sides use identical arithmetic, so the decoder faithfully tracks the
145/// encoder's interval.
146///
147/// # Design
148///
149/// **Invariant:** `range ∈ [BOT, 2³²)` after every renormalisation.
150///
151/// **Last-symbol optimisation:** for the highest-probability symbol (the last
152/// entry in a CDF row) the encoder sets `range -= step * cum_lo` instead of
153/// `step * (cum_hi - cum_lo)`.  This avoids integer rounding errors that would
154/// make `low + range > 2³²`.
155///
156/// **Flush:** emit exactly 4 bytes of `low`, which together with the preceding
157/// renorm bytes uniquely identify the encoded sequence.  The decoder primes its
158/// `code` register with those same 4 leading bytes of the bitstream.
159#[derive(Debug, Clone)]
160pub struct RangeCoder {
161    // ── Shared state ─────────────────────────────────────────────────────────
162    /// Coding interval width (encoder) / search window offset (decoder).
163    range: u32,
164
165    // ── Encoder state ─────────────────────────────────────────────────────────
166    /// Lower bound of the current coding interval (encoder mode).
167    low: u32,
168    /// Encoded output bytes.
169    output: Vec<u8>,
170
171    // ── Decoder state ─────────────────────────────────────────────────────────
172    /// Input bitstream (decoder mode).
173    input: Vec<u8>,
174    /// Read cursor into `input` (decoder mode).
175    read_pos: usize,
176    /// Sliding 32-bit code register mirroring the encoder's `low` (decoder mode).
177    code: u32,
178    /// `true` when operating in decode mode.
179    decode_mode: bool,
180}
181
182impl RangeCoder {
183    /// Bottom threshold: `range` must be ≥ `BOT` after every renorm.
184    const BOT: u32 = 1 << 16;
185
186    /// Create a new range coder in **encoder** mode.
187    #[must_use]
188    pub fn new() -> Self {
189        Self {
190            range: u32::MAX,
191            low: 0,
192            output: Vec::new(),
193            input: Vec::new(),
194            read_pos: 0,
195            code: 0,
196            decode_mode: false,
197        }
198    }
199
200    /// Switch to **decode** mode, priming the code register from `data`.
201    ///
202    /// # Errors
203    ///
204    /// Returns `CodecError::InvalidBitstream` if `data` is empty.
205    pub fn init_from_slice(&mut self, data: &[u8]) -> Result<(), CodecError> {
206        if data.is_empty() {
207            return Err(CodecError::InvalidBitstream(
208                "RangeCoder: empty bitstream".into(),
209            ));
210        }
211        self.decode_mode = true;
212        self.input = data.to_vec();
213        self.read_pos = 0;
214        self.range = u32::MAX;
215        // Prime 4 bytes.
216        self.code = 0;
217        for _ in 0..4 {
218            let b = self.read_byte_internal();
219            self.code = (self.code << 8) | u32::from(b);
220        }
221        Ok(())
222    }
223
224    /// Flush encoder output and return the byte stream.
225    ///
226    /// Emits the 4 remaining bytes of `low`, which uniquely identify the
227    /// terminal interval.
228    #[must_use]
229    pub fn flush(mut self) -> Vec<u8> {
230        if !self.decode_mode {
231            for _ in 0..4 {
232                self.output.push((self.low >> 24) as u8);
233                self.low = self.low.wrapping_shl(8);
234            }
235        }
236        self.output
237    }
238
239    // ── Internal helpers ──────────────────────────────────────────────────────
240
241    fn read_byte_internal(&mut self) -> u8 {
242        if self.read_pos < self.input.len() {
243            let b = self.input[self.read_pos];
244            self.read_pos += 1;
245            b
246        } else {
247            0x00 // zero padding past end of stream
248        }
249    }
250
251    /// Encoder renorm: emit the top byte of `low` while `range < BOT`.
252    fn renormalize_encoder(&mut self) {
253        while self.range < Self::BOT {
254            self.output.push((self.low >> 24) as u8);
255            self.low = self.low.wrapping_shl(8);
256            self.range <<= 8;
257        }
258    }
259
260    /// Decoder renorm: read one byte into `code` while `range < BOT`.
261    fn renormalize_decoder(&mut self) {
262        while self.range < Self::BOT {
263            let b = self.read_byte_internal();
264            self.code = (self.code << 8) | u32::from(b);
265            self.range <<= 8;
266        }
267    }
268
269    /// Encode symbol `sym` ∈ `[0, n_syms)` using a Q15 CDF row.
270    fn encode_symbol_with_cdf(&mut self, sym: usize, cdf: &[u16]) -> Result<(), CodecError> {
271        let n_syms = cdf.len().saturating_sub(1);
272        if n_syms == 0 {
273            return Err(CodecError::InvalidParameter(
274                "CDF must have at least 2 entries".into(),
275            ));
276        }
277        if sym >= n_syms {
278            return Err(CodecError::InvalidParameter(format!(
279                "symbol {sym} out of range for {n_syms}-symbol CDF"
280            )));
281        }
282
283        let total = u32::from(CDF_PROB_TOP);
284        let cum_lo = u32::from(cdf[sym]);
285        let cum_hi = u32::from(cdf[sym + 1]);
286        let step = self.range / total;
287
288        self.low = self.low.wrapping_add(step * cum_lo);
289        // Last symbol gets the remainder so that low + range stays ≤ 2^32.
290        if sym + 1 < n_syms {
291            self.range = step * (cum_hi - cum_lo);
292        } else {
293            self.range -= step * cum_lo;
294        }
295
296        self.renormalize_encoder();
297        Ok(())
298    }
299
300    /// Decode one symbol from a Q15 CDF row.
301    fn decode_symbol_with_cdf(&mut self, cdf: &[u16]) -> Result<u8, CodecError> {
302        let n_syms = cdf.len().saturating_sub(1);
303        if n_syms == 0 {
304            return Err(CodecError::InvalidBitstream(
305                "CDF must have at least 2 entries".into(),
306            ));
307        }
308
309        let total = u32::from(CDF_PROB_TOP);
310        let step = self.range / total;
311
312        // Find the symbol whose encoder interval contains `code`:
313        // encoder set `low += step * cum_lo(sym)`, so we look for the
314        // largest i such that `step * cdf[i] <= code`.
315        let mut sym = n_syms - 1;
316        for i in 0..n_syms {
317            // Boundary: upper edge of symbol i is step * cdf[i+1].
318            if self.code < step * u32::from(cdf[i + 1]) {
319                sym = i;
320                break;
321            }
322        }
323
324        let cum_lo = u32::from(cdf[sym]);
325
326        self.code = self.code.wrapping_sub(step * cum_lo);
327        if sym + 1 < n_syms {
328            let cum_hi = u32::from(cdf[sym + 1]);
329            self.range = step * (cum_hi - cum_lo);
330        } else {
331            self.range -= step * cum_lo;
332        }
333
334        self.renormalize_decoder();
335
336        Ok(sym as u8)
337    }
338}
339
340// =============================================================================
341// Public functions
342// =============================================================================
343
344/// Encode `sym` using the CDF at `cdf_table[ctx]`.
345///
346/// # Errors
347///
348/// Returns `CodecError::InvalidParameter` if `ctx >= CTX`, `sym >= N`, or
349/// the CDF row is malformed.
350pub fn encode_symbol_table<const N: usize, const CTX: usize>(
351    rc: &mut RangeCoder,
352    sym: u8,
353    ctx: usize,
354    table: &CdfTable<N, CTX>,
355) -> Result<(), CodecError> {
356    if ctx >= CTX {
357        return Err(CodecError::InvalidParameter(format!(
358            "context {ctx} out of range (table has {CTX} contexts)"
359        )));
360    }
361    rc.encode_symbol_with_cdf(sym as usize, &table[ctx])
362}
363
364/// Decode one symbol using the CDF at `cdf_table[ctx]`.
365///
366/// Returns the decoded symbol index in `[0, N-1)`.
367///
368/// # Errors
369///
370/// Returns `CodecError::InvalidBitstream` if the bitstream is malformed,
371/// or `CodecError::InvalidParameter` if `ctx >= CTX`.
372pub fn decode_symbol_table<const N: usize, const CTX: usize>(
373    rc: &mut RangeCoder,
374    ctx: usize,
375    table: &CdfTable<N, CTX>,
376) -> Result<u8, CodecError> {
377    if ctx >= CTX {
378        return Err(CodecError::InvalidParameter(format!(
379            "context {ctx} out of range (table has {CTX} contexts)"
380        )));
381    }
382    rc.decode_symbol_with_cdf(&table[ctx])
383}
384
385// =============================================================================
386// Tests
387// =============================================================================
388
#[cfg(test)]
mod tests {
    use super::*;

    // ── CDF table structure validation ───────────────────────────────────────

    /// The DC skip row starts at 0, ends at `CDF_PROB_TOP` and never decreases.
    #[test]
    fn dc_coeff_skip_cdf_valid() {
        let row = &DC_COEFF_SKIP_CDF[0];
        assert_eq!(row[0], 0, "first CDF entry must be 0");
        let tail = *row.last().expect("non-empty row");
        assert_eq!(tail, CDF_PROB_TOP, "last entry must be CDF_PROB_TOP");
        // Monotonically non-decreasing
        for pair in row.windows(2) {
            assert!(
                pair[0] <= pair[1],
                "CDF must be monotonically non-decreasing"
            );
        }
    }

    /// Same structural checks for the AC skip row.
    #[test]
    fn ac_coeff_skip_cdf_valid() {
        let row = &AC_COEFF_SKIP_CDF[0];
        assert_eq!(row[0], 0);
        let tail = *row.last().expect("non-empty");
        assert_eq!(tail, CDF_PROB_TOP);
        for pair in row.windows(2) {
            assert!(pair[0] <= pair[1]);
        }
    }

    /// Structural checks for the transform-type row, including its length.
    #[test]
    fn transform_type_cdf_valid() {
        let row = &TRANSFORM_TYPE_CDF[0];
        assert_eq!(row[0], 0);
        let tail = *row.last().expect("non-empty");
        assert_eq!(tail, CDF_PROB_TOP);
        assert_eq!(row.len(), 17, "16 symbols + 1 sentinel");
        for pair in row.windows(2) {
            assert!(pair[0] <= pair[1]);
        }
    }

    /// Structural checks for the partition-type row, including its length.
    #[test]
    fn partition_type_cdf_valid() {
        let row = &PARTITION_TYPE_CDF[0];
        assert_eq!(row[0], 0);
        let tail = *row.last().expect("non-empty");
        assert_eq!(tail, CDF_PROB_TOP);
        assert_eq!(row.len(), 5, "4 symbols + 1 sentinel");
        for pair in row.windows(2) {
            assert!(pair[0] <= pair[1]);
        }
    }

    // ── RangeCoder basic encode/decode ───────────────────────────────────────

    /// A single DC-skip symbol 0 survives encode → flush → decode.
    #[test]
    fn range_coder_dc_skip_roundtrip_zero() {
        let mut encoder = RangeCoder::new();
        encode_symbol_table(&mut encoder, 0, 0, &DC_COEFF_SKIP_CDF).expect("encode sym 0");
        let bytes = encoder.flush();

        let mut decoder = RangeCoder::new();
        decoder.init_from_slice(&bytes).expect("init");
        let decoded = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
        assert_eq!(decoded, 0, "should decode symbol 0");
    }

    /// A single DC-skip symbol 1 survives encode → flush → decode.
    #[test]
    fn range_coder_dc_skip_roundtrip_one() {
        let mut encoder = RangeCoder::new();
        encode_symbol_table(&mut encoder, 1, 0, &DC_COEFF_SKIP_CDF).expect("encode sym 1");
        let bytes = encoder.flush();

        let mut decoder = RangeCoder::new();
        decoder.init_from_slice(&bytes).expect("init");
        let decoded = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
        assert_eq!(decoded, 1, "should decode symbol 1");
    }

    /// Every partition-type symbol round-trips when coded on its own.
    #[test]
    fn range_coder_partition_type_all_symbols() {
        for sym_in in 0u8..4 {
            let mut encoder = RangeCoder::new();
            encode_symbol_table(&mut encoder, sym_in, 0, &PARTITION_TYPE_CDF)
                .expect("encode partition");
            let bytes = encoder.flush();

            let mut decoder = RangeCoder::new();
            decoder.init_from_slice(&bytes).expect("init");
            let sym_out =
                decode_symbol_table(&mut decoder, 0, &PARTITION_TYPE_CDF).expect("decode");
            assert_eq!(
                sym_out, sym_in,
                "partition type {sym_in} must survive round-trip"
            );
        }
    }

    /// Every transform-type symbol round-trips when coded on its own.
    #[test]
    fn range_coder_transform_type_all_symbols() {
        for sym_in in 0u8..16 {
            let mut encoder = RangeCoder::new();
            encode_symbol_table(&mut encoder, sym_in, 0, &TRANSFORM_TYPE_CDF)
                .expect("encode tx type");
            let bytes = encoder.flush();

            let mut decoder = RangeCoder::new();
            decoder.init_from_slice(&bytes).expect("init");
            let sym_out =
                decode_symbol_table(&mut decoder, 0, &TRANSFORM_TYPE_CDF).expect("decode tx");
            assert_eq!(
                sym_out, sym_in,
                "transform type {sym_in} must survive round-trip"
            );
        }
    }

    /// A short AC-skip sequence round-trips in order.
    #[test]
    fn range_coder_ac_skip_roundtrip() {
        let symbols = [0u8, 1, 0, 0, 1, 1, 0, 1];

        let mut encoder = RangeCoder::new();
        for sym in symbols.iter().copied() {
            encode_symbol_table(&mut encoder, sym, 0, &AC_COEFF_SKIP_CDF).expect("encode");
        }
        let bytes = encoder.flush();

        let mut decoder = RangeCoder::new();
        decoder.init_from_slice(&bytes).expect("init");
        for expected in symbols.iter().copied() {
            let got = decode_symbol_table(&mut decoder, 0, &AC_COEFF_SKIP_CDF).expect("decode");
            assert_eq!(got, expected);
        }
    }

    /// Symbols drawn from three different tables can be interleaved in one stream.
    #[test]
    fn range_coder_sequence_mixed_tables() {
        // Interleave symbols from different tables.
        let dc_syms = [0u8, 1, 0];
        let tx_syms = [0u8, 5, 15];
        let pt_syms = [3u8, 0, 2];

        let mut encoder = RangeCoder::new();
        for ((&dc, &tx), &pt) in dc_syms.iter().zip(&tx_syms).zip(&pt_syms) {
            encode_symbol_table(&mut encoder, dc, 0, &DC_COEFF_SKIP_CDF).expect("encode dc");
            encode_symbol_table(&mut encoder, tx, 0, &TRANSFORM_TYPE_CDF).expect("encode tx");
            encode_symbol_table(&mut encoder, pt, 0, &PARTITION_TYPE_CDF).expect("encode pt");
        }
        let bytes = encoder.flush();

        let mut decoder = RangeCoder::new();
        decoder.init_from_slice(&bytes).expect("init");
        for ((&want_dc, &want_tx), &want_pt) in dc_syms.iter().zip(&tx_syms).zip(&pt_syms) {
            let dc = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode dc");
            let tx = decode_symbol_table(&mut decoder, 0, &TRANSFORM_TYPE_CDF).expect("decode tx");
            let pt = decode_symbol_table(&mut decoder, 0, &PARTITION_TYPE_CDF).expect("decode pt");
            assert_eq!(dc, want_dc);
            assert_eq!(tx, want_tx);
            assert_eq!(pt, want_pt);
        }
    }

    /// 100 alternating DC-skip symbols round-trip.
    #[test]
    fn range_coder_long_sequence_dc_skip() {
        // 100 symbols: alternating 0 and 1.
        let symbols: Vec<u8> = (0..100usize).map(|idx| (idx % 2) as u8).collect();

        let mut encoder = RangeCoder::new();
        for sym in symbols.iter().copied() {
            encode_symbol_table(&mut encoder, sym, 0, &DC_COEFF_SKIP_CDF).expect("encode");
        }
        let bytes = encoder.flush();

        let mut decoder = RangeCoder::new();
        decoder.init_from_slice(&bytes).expect("init");
        for (i, expected) in symbols.iter().copied().enumerate() {
            let got = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
            assert_eq!(got, expected, "mismatch at symbol {i}");
        }
    }

    /// A run of 50 identical partition symbols (all 0) round-trips.
    #[test]
    fn range_coder_all_same_symbol_zero() {
        let n = 50;

        let mut encoder = RangeCoder::new();
        for _ in 0..n {
            encode_symbol_table(&mut encoder, 0, 0, &PARTITION_TYPE_CDF).expect("encode");
        }
        let bytes = encoder.flush();

        let mut decoder = RangeCoder::new();
        decoder.init_from_slice(&bytes).expect("init");
        for i in 0..n {
            let got = decode_symbol_table(&mut decoder, 0, &PARTITION_TYPE_CDF).expect("decode");
            assert_eq!(got, 0u8, "all-zero sequence failed at index {i}");
        }
    }

    /// Encoding with a context index past the table's last row is an error.
    #[test]
    fn range_coder_context_out_of_range_error() {
        let mut encoder = RangeCoder::new();
        // DC_COEFF_SKIP_CDF has only 1 context (index 0).
        let outcome = encode_symbol_table(&mut encoder, 0, 1, &DC_COEFF_SKIP_CDF);
        assert!(outcome.is_err(), "context 1 should be out of range");
    }

    /// Encoding a symbol outside the row's alphabet is an error.
    #[test]
    fn range_coder_symbol_out_of_range_error() {
        let mut encoder = RangeCoder::new();
        // DC_COEFF_SKIP_CDF has 2 symbols (0, 1). Symbol 2 is invalid.
        let outcome = encode_symbol_table(&mut encoder, 2, 0, &DC_COEFF_SKIP_CDF);
        assert!(outcome.is_err(), "symbol 2 should be out of range");
    }

    /// Initialising the decoder from an empty slice is an error.
    #[test]
    fn range_coder_empty_bitstream_error() {
        let mut decoder = RangeCoder::new();
        let outcome = decoder.init_from_slice(&[]);
        assert!(outcome.is_err(), "empty bitstream must return error");
    }

    /// A freshly constructed coder is an encoder with no output yet.
    #[test]
    fn range_coder_new_is_in_encode_mode() {
        let coder = RangeCoder::new();
        assert!(!coder.decode_mode, "new coder should be in encode mode");
        assert_eq!(coder.output.len(), 0, "no output yet");
    }

    /// Flushing after one symbol yields a non-empty byte stream.
    #[test]
    fn range_coder_flush_produces_bytes() {
        let mut encoder = RangeCoder::new();
        encode_symbol_table(&mut encoder, 0, 0, &DC_COEFF_SKIP_CDF).expect("encode");
        let bytes = encoder.flush();
        assert!(!bytes.is_empty(), "flush must produce at least one byte");
    }

    /// Bulk round-trip of 10_000 alternating DC-skip symbols with a bound on
    /// the compressed size.  Despite its name this test measures no timing.
    #[test]
    fn benchmark_table_vs_scalar_estimate() {
        // Estimate throughput advantage of table lookup.
        // Encode 10_000 symbols with table-based coder and verify it completes.
        let symbols: Vec<u8> = (0..10_000usize).map(|idx| (idx % 2) as u8).collect();

        let mut encoder = RangeCoder::new();
        for sym in symbols.iter().copied() {
            encode_symbol_table(&mut encoder, sym, 0, &DC_COEFF_SKIP_CDF).expect("encode");
        }
        let bs = encoder.flush();

        // Compressed size must be less than raw size (2 bits/symbol max → 2500 bytes).
        assert!(
            bs.len() <= 2500,
            "compressed size {} should be ≤ 2500 bytes for {}-symbol DC skip stream",
            bs.len(),
            symbols.len()
        );

        // Decode and verify correctness.
        let mut decoder = RangeCoder::new();
        decoder.init_from_slice(&bs).expect("init");
        for (i, expected) in symbols.iter().copied().enumerate() {
            let got = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
            assert_eq!(got, expected, "bulk decode mismatch at index {i}");
        }
    }
}