Skip to main content

oximedia_codec/
entropy_coding.rs

1//! Entropy coding primitives.
2//!
3//! This module provides simplified implementations of common entropy coding
4//! techniques used in video codecs: arithmetic coding, range coding, and
5//! Huffman coding.
6
7// -------------------------------------------------------------------------
8// Arithmetic Coder
9// -------------------------------------------------------------------------
10
/// Simplified binary arithmetic coder.
///
/// Maintains the inclusive interval `[low, high]` and narrows it on each coded
/// symbol. The implementation uses integer arithmetic:
///
/// * E1/E2 scaling emits a definite bit whenever the interval lies entirely in
///   the lower or upper half.
/// * E3 scaling defers "follow" bits while the interval straddles the midpoint.
///   Those bits are emitted, complemented, after the next definite bit.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct ArithmeticCoder {
    /// Lower bound of the current coding interval (inclusive).
    pub low: u32,
    /// Upper bound of the current coding interval (inclusive).
    pub high: u32,
    /// Pending follow bits to emit after the next definite bit.
    pub bits_to_follow: u32,
}

impl ArithmeticCoder {
    /// Creates a new arithmetic coder whose interval spans the full 32-bit space.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            low: 0,
            high: 0xFFFF_FFFF,
            bits_to_follow: 0,
        }
    }

    /// Encodes a single bit given a probability `prob_one ∈ (0.0, 1.0)` that
    /// the bit is `1`.
    ///
    /// The interval is split so that the `0` sub-interval has a width of about
    /// `range * (1 - prob_one)`. Both sub-intervals are always kept at least
    /// one unit wide, even when `prob_one` rounds to 0 or 1. Because of that,
    /// out-of-range probabilities cannot overflow `low` or empty the interval.
    ///
    /// Returns the bits emitted during this step, packed MSB-first into bytes.
    /// Note: this is a simplified model. A trailing partial byte is zero-padded,
    /// so the outputs of successive calls are not a contiguous bitstream.
    #[allow(dead_code)]
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn encode_bit(&mut self, prob_one: f32, bit: bool) -> Vec<u8> {
        let range = self.get_range();
        // `1.0 - prob_one` is evaluated in f32 before widening. Very small
        // probabilities therefore round the `0` width up to the whole range.
        #[allow(clippy::cast_precision_loss)]
        let zero_width = (range as f64 * f64::from(1.0 - prob_one)) as u64;
        // `mid` is the last value of the `0` sub-interval. Capping the offset at
        // `range - 2` guarantees `mid < high`, so `mid + 1` cannot overflow and
        // the `1` sub-interval is never empty.
        let split = zero_width.saturating_sub(1).min(range.saturating_sub(2));
        let mid = u64::from(self.low) + split;

        if bit {
            // mid + 1 <= high <= u32::MAX, so the cast is lossless.
            self.low = (mid + 1) as u32;
        } else {
            // mid < high <= u32::MAX, so the cast is lossless.
            self.high = mid as u32;
        }

        // Normalise: emit bits while the interval is contained in one half.
        // Follow bits always take the complement of the bit just emitted.
        let mut emitted_bits: Vec<bool> = Vec::new();
        loop {
            if self.high < 0x8000_0000 {
                // Both bounds in [0, 0.5): emit 0, then any pending 1s.
                emitted_bits.push(false);
                emitted_bits.extend(std::iter::repeat(true).take(self.bits_to_follow as usize));
                self.bits_to_follow = 0;
                self.low <<= 1;
                self.high = (self.high << 1) | 1;
            } else if self.low >= 0x8000_0000 {
                // Both bounds in [0.5, 1): emit 1, then any pending 0s.
                emitted_bits.push(true);
                emitted_bits.extend(std::iter::repeat(false).take(self.bits_to_follow as usize));
                self.bits_to_follow = 0;
                self.low = (self.low - 0x8000_0000) << 1;
                self.high = ((self.high - 0x8000_0000) << 1) | 1;
            } else if self.low >= 0x4000_0000 && self.high < 0xC000_0000 {
                // Interval straddles the midpoint inside [0.25, 0.75): E3 scaling.
                self.bits_to_follow += 1;
                self.low = (self.low - 0x4000_0000) << 1;
                self.high = ((self.high - 0x4000_0000) << 1) | 1;
            } else {
                break;
            }
        }

        bits_to_bytes(&emitted_bits)
    }

    /// Returns the current interval range `high - low + 1`.
    ///
    /// Returns a `u64` because the initial range is `0xFFFF_FFFF - 0 + 1 = 2^32`,
    /// which overflows `u32`.
    #[allow(dead_code)]
    pub fn get_range(&self) -> u64 {
        u64::from(self.high) - u64::from(self.low) + 1
    }
}

/// Packs a slice of bits into bytes, MSB-first within each byte.
///
/// A trailing partial byte is left-aligned and padded with zero bits.
#[allow(dead_code)]
fn bits_to_bytes(bits: &[bool]) -> Vec<u8> {
    bits.chunks(8)
        .map(|chunk| {
            let packed = chunk.iter().fold(0u8, |acc, &b| (acc << 1) | u8::from(b));
            // `chunks` never yields an empty chunk, so the shift is in 0..=7.
            packed << (8 - chunk.len())
        })
        .collect()
}
124
125// -------------------------------------------------------------------------
126// Range Coder
127// -------------------------------------------------------------------------
128
/// Simplified range coder (decoder side).
///
/// Range coding is a generalisation of arithmetic coding used in many modern
/// codecs (VP8, VP9, AV1, …).
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct RangeCoder {
    /// Current range. `normalize` brings it back into `[128, 256)`; the value
    /// set by `new` is 256.
    pub range: u32,
    /// Current code word.
    pub code: u32,
}

impl RangeCoder {
    /// Creates a new range coder with a full-range initial state.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            range: 256,
            code: 0,
        }
    }

    /// Normalises the range back into `[128, 256)` by doubling `range` and
    /// `code`, returning the number of doublings (bits consumed from the
    /// bitstream).
    ///
    /// A zero `range` cannot be normalised by doubling. It is left untouched
    /// and 0 is returned, rather than looping forever.
    #[allow(dead_code)]
    pub fn normalize(&mut self) -> u32 {
        if self.range == 0 {
            return 0;
        }
        let mut bits_consumed = 0;
        while self.range < 128 {
            self.range <<= 1;
            self.code <<= 1;
            bits_consumed += 1;
        }
        bits_consumed
    }

    /// Decodes one symbol given a split probability `prob ∈ [0, 256)`.
    ///
    /// Values of `prob` above 255 are clamped to 255. This keeps `split`
    /// strictly below any non-zero range, so the high partition can never
    /// become empty or underflow.
    ///
    /// Returns `true` for the high partition (code ≥ split), `false` otherwise.
    #[allow(dead_code)]
    pub fn decode_symbol(&mut self, prob: u32) -> bool {
        // Widen to u64 so `range * prob` cannot overflow for any stored range.
        let prob = u64::from(prob.min(255));
        // split <= range * 255 / 256 < range, so it always fits in u32.
        let split = ((u64::from(self.range) * prob) >> 8) as u32;
        if self.code >= split {
            self.code -= split;
            self.range -= split;
            true
        } else {
            self.range = split;
            false
        }
    }
}
181
182// -------------------------------------------------------------------------
183// Huffman Coding
184// -------------------------------------------------------------------------
185
/// A node in a Huffman tree.
#[derive(Debug)]
#[allow(dead_code)]
pub struct HuffmanNode {
    /// Present only on leaf nodes; the symbol value.
    pub symbol: Option<u8>,
    /// Aggregate frequency of the subtree rooted here (saturates at `u32::MAX`).
    pub freq: u32,
    /// Left child: the first node taken when merging (lowest frequency; the
    /// older node on ties). Reached via bit `0`.
    pub left: Option<Box<HuffmanNode>>,
    /// Right child: the second node taken when merging. Reached via bit `1`.
    pub right: Option<Box<HuffmanNode>>,
}

impl HuffmanNode {
    /// Returns `true` when this node has no children.
    ///
    /// The `symbol` field is not inspected.
    #[allow(dead_code)]
    pub fn is_leaf(&self) -> bool {
        self.left.is_none() && self.right.is_none()
    }
}

/// Builds a Huffman tree from a frequency table using a priority queue.
///
/// `freqs[i]` is the frequency of symbol `i`. Only the first 256 entries are
/// considered, because symbols are stored as `u8`. Symbols with frequency 0
/// are excluded.
///
/// On each step the two lowest-frequency nodes are merged. Ties go to the
/// node created first, which makes the tree deterministic. The first node
/// taken becomes the left child. Parent frequencies saturate at `u32::MAX`.
///
/// Special cases:
/// * No non-zero frequency: a lone leaf for symbol 0 with frequency 0.
/// * A single symbol: a parent whose only (left) child is that leaf, so the
///   symbol gets a 1-bit code.
#[allow(dead_code)]
pub fn build_huffman_tree(freqs: &[u32]) -> HuffmanNode {
    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    // Only indices representable as `u8` can become symbols.
    let symbol_limit = usize::from(u8::MAX) + 1;

    // Arena of nodes not merged yet. A node's index is also its creation order.
    let mut arena: Vec<Option<HuffmanNode>> = Vec::new();
    // Min-heap keyed by (frequency, creation order).
    let mut queue: BinaryHeap<Reverse<(u32, usize)>> = BinaryHeap::new();

    for (index, &freq) in freqs.iter().take(symbol_limit).enumerate() {
        if freq > 0 {
            queue.push(Reverse((freq, arena.len())));
            arena.push(Some(HuffmanNode {
                symbol: Some(index as u8),
                freq,
                left: None,
                right: None,
            }));
        }
    }

    if arena.is_empty() {
        // Degenerate: return a leaf for symbol 0 with freq 0.
        return HuffmanNode {
            symbol: Some(0),
            freq: 0,
            left: None,
            right: None,
        };
    }

    if arena.len() == 1 {
        // Single-symbol alphabet: wrap in a parent so tree depth ≥ 1.
        let leaf = arena
            .pop()
            .and_then(|slot| slot)
            .expect("exactly one leaf was collected");
        return HuffmanNode {
            symbol: None,
            freq: leaf.freq,
            left: Some(Box::new(leaf)),
            right: None,
        };
    }

    while queue.len() > 1 {
        let Reverse((_, left_index)) = queue.pop().expect("queue holds at least two nodes");
        let Reverse((_, right_index)) = queue.pop().expect("queue holds at least two nodes");
        let left = arena[left_index].take().expect("each node is merged exactly once");
        let right = arena[right_index].take().expect("each node is merged exactly once");
        let parent = HuffmanNode {
            symbol: None,
            // Saturate instead of overflowing on pathological frequency totals.
            freq: left.freq.saturating_add(right.freq),
            left: Some(Box::new(left)),
            right: Some(Box::new(right)),
        };
        queue.push(Reverse((parent.freq, arena.len())));
        arena.push(Some(parent));
    }

    let Reverse((_, root_index)) = queue.pop().expect("queue holds the root");
    arena[root_index].take().expect("the root is never merged")
}

/// Traverses the Huffman tree depth-first (left before right), collecting
/// `(symbol, code_bits)` pairs at each leaf.
///
/// `prefix` is the bit-path from the root to `node` (each `u8` is `0` or
/// `1`). Left edges append `0` and right edges append `1`. Leaves without a
/// symbol contribute nothing.
#[allow(dead_code)]
pub fn compute_huffman_codes(node: &HuffmanNode, prefix: Vec<u8>) -> Vec<(u8, Vec<u8>)> {
    let mut codes = Vec::new();
    let mut path = prefix;
    collect_huffman_codes(node, &mut path, &mut codes);
    codes
}

/// Recursive worker for `compute_huffman_codes`.
///
/// It reuses a single path buffer, pushing a bit before descending and
/// popping it afterwards.
fn collect_huffman_codes(node: &HuffmanNode, path: &mut Vec<u8>, codes: &mut Vec<(u8, Vec<u8>)>) {
    if node.is_leaf() {
        if let Some(sym) = node.symbol {
            codes.push((sym, path.clone()));
        }
        return;
    }
    for (bit, child) in [(0u8, &node.left), (1u8, &node.right)] {
        if let Some(child) = child {
            path.push(bit);
            collect_huffman_codes(child, path, codes);
            path.pop();
        }
    }
}
295
296// -------------------------------------------------------------------------
297// Table-Based Arithmetic Coder (ANS-style lookup tables)
298// -------------------------------------------------------------------------
299
/// Precision of the probability tables, in bits.
const TABLE_PROB_BITS: u32 = 8;
/// Number of probability table slots = 2^TABLE_PROB_BITS = 256.
const TABLE_SIZE: usize = 1 << TABLE_PROB_BITS;
/// Mask selecting a table slot (`TABLE_SIZE - 1`).
#[allow(dead_code)]
const TABLE_MASK: u32 = (TABLE_SIZE as u32) - 1;

/// Pre-computed lookup table entry for one symbol.
///
/// An entry describes a binary split of the scaled space `[0, TABLE_SIZE)`:
/// * the low partition is `[cum_prob_low, cum_prob_low + width_low)`;
/// * the high partition is `[cum_prob_low + width_low, TABLE_SIZE)`.
#[derive(Clone, Copy, Debug, Default)]
pub struct ProbTableEntry {
    /// Start of the low partition, in `[0, TABLE_SIZE]`.
    pub cum_prob_low: u32,
    /// Width of the low partition (`freq * TABLE_SIZE / total`, rounded down).
    pub width_low: u32,
    /// Width of the high partition (`TABLE_SIZE - cum_prob_low - width_low`).
    pub width_high: u32,
}

/// Builds one [`ProbTableEntry`] per symbol from unnormalised frequencies.
///
/// `freqs[i]` is the frequency of symbol `i`. The returned vector has
/// `freqs.len()` entries, indexed by symbol.
///
/// Widths are rounded down, so they never sum past `TABLE_SIZE`. The catch is
/// that a symbol whose share is below one slot gets `width_low == 0` and cannot
/// be coded through its low partition. If every frequency is 0, all entries
/// are zeroed.
#[allow(dead_code)]
pub fn build_prob_table(freqs: &[u32]) -> Vec<ProbTableEntry> {
    let total: u64 = freqs.iter().map(|&f| u64::from(f)).sum();
    if total == 0 {
        return vec![ProbTableEntry::default(); freqs.len()];
    }
    let scale = TABLE_SIZE as u64;
    let mut cum: u32 = 0;
    freqs
        .iter()
        .map(|&freq| {
            let width = ((u64::from(freq) * scale) / total) as u32;
            let entry = ProbTableEntry {
                cum_prob_low: cum,
                width_low: width,
                width_high: TABLE_SIZE as u32 - cum - width,
            };
            cum += width;
            entry
        })
        .collect()
}

/// High-throughput table-based arithmetic coder.
///
/// Uses pre-computed probability tables for O(1) symbol lookup, which avoids
/// floating-point division in the inner loop.
///
/// # Design
///
/// The coding interval is maintained as `(low, range)` in a 32-bit window.
/// A carry out of `low` is propagated into the bytes already emitted. After
/// each symbol, whole bytes are shifted out of `low` while `range` is below
/// `RANGE_MIN`. The finished stream is the big-endian value of the final
/// lower bound.
///
/// # Example
///
/// ```rust
/// use oximedia_codec::entropy_coding::{
///     build_prob_table, TableArithmeticCoder, TableArithmeticDecoder,
/// };
///
/// let freqs = [10u32, 30, 20, 5];
/// let table = build_prob_table(&freqs);
/// let mut enc = TableArithmeticCoder::new();
/// enc.encode_symbol(true, &table[1]);
/// enc.encode_symbol(false, &table[0]);
/// let bitstream = enc.flush();
///
/// let mut dec = TableArithmeticDecoder::new(&bitstream);
/// assert!(dec.decode_symbol(&table[1]));
/// assert!(!dec.decode_symbol(&table[0]));
/// ```
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct TableArithmeticCoder {
    /// Current interval lower bound (32-bit window).
    low: u32,
    /// Current interval range.
    range: u32,
    /// Bytes emitted so far; the last ones may still receive a carry.
    output: Vec<u8>,
}

impl TableArithmeticCoder {
    /// Minimum range before renormalisation.
    const RANGE_MIN: u32 = 0x0100_0000;
    /// Maximum range (top of 32-bit); not used by the current implementation.
    #[allow(dead_code)]
    const RANGE_MAX: u32 = 0xFF00_0000;

    /// Create a new table-based arithmetic coder.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            low: 0,
            range: 0xFFFF_FF00,
            output: Vec::new(),
        }
    }

    /// Encode one binary decision using its pre-computed `ProbTableEntry`.
    ///
    /// `sym_is_high` selects the high partition when `true`, otherwise the
    /// low partition.
    #[allow(dead_code)]
    pub fn encode_symbol(&mut self, sym_is_high: bool, entry: &ProbTableEntry) {
        let high_start = entry.cum_prob_low.saturating_add(entry.width_low);
        let (cum, width) = if sym_is_high {
            (high_start, entry.width_high)
        } else {
            (entry.cum_prob_low, entry.width_low)
        };

        // Scale the range by the symbol probability.
        let r = self.range >> TABLE_PROB_BITS;
        self.add_to_low(r.saturating_mul(cum));
        self.range = r.saturating_mul(width).max(1);

        // Renormalise: emit the high byte while the range is small.
        while self.range < Self::RANGE_MIN {
            self.output.push((self.low >> 24) as u8);
            self.low <<= 8;
            self.range <<= 8;
        }
    }

    /// Adds `amount` to `low`. On overflow, the carry is rippled into the
    /// bytes already emitted (`0xFF` bytes become `0x00` and the carry moves
    /// further back).
    fn add_to_low(&mut self, amount: u32) {
        let (sum, carried) = self.low.overflowing_add(amount);
        self.low = sum;
        if carried {
            for byte in self.output.iter_mut().rev() {
                let (next, wrapped) = byte.overflowing_add(1);
                *byte = next;
                if !wrapped {
                    break;
                }
            }
        }
    }

    /// Flush any remaining state, returning the full byte stream.
    #[allow(dead_code)]
    pub fn flush(mut self) -> Vec<u8> {
        // Emit the 4 bytes of `low` to pin down a value inside the final interval.
        for _ in 0..4 {
            self.output.push((self.low >> 24) as u8);
            self.low <<= 8;
        }
        self.output
    }

    /// Returns the number of bytes emitted so far (before flush).
    #[allow(dead_code)]
    pub fn bytes_emitted(&self) -> usize {
        self.output.len()
    }
}

/// Table-based arithmetic decoder counterpart.
///
/// Reads bytes from a stream produced by [`TableArithmeticCoder`] and
/// reconstructs decisions. Each call must pass the same `ProbTableEntry`,
/// in the same order, as used during encoding.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct TableArithmeticDecoder<'a> {
    /// Compressed input data.
    data: &'a [u8],
    /// Read position (byte index).
    pos: usize,
    /// Offset of the stream value above the current interval's lower bound.
    code: u32,
    /// Current interval range.
    range: u32,
}

impl<'a> TableArithmeticDecoder<'a> {
    /// Create a decoder over `data`.
    ///
    /// Reads the initial 4-byte code word. Missing bytes are treated as `0xFF`.
    #[allow(dead_code)]
    pub fn new(data: &'a [u8]) -> Self {
        let mut dec = Self {
            data,
            pos: 0,
            code: 0,
            range: 0xFFFF_FF00,
        };
        // Prime the code register.
        for _ in 0..4 {
            dec.code = (dec.code << 8) | u32::from(dec.read_byte());
        }
        dec
    }

    /// Returns the next input byte, or `0xFF` padding past the end of `data`.
    fn read_byte(&mut self) -> u8 {
        if self.pos < self.data.len() {
            let b = self.data[self.pos];
            self.pos += 1;
            b
        } else {
            0xFF // padding
        }
    }

    /// Decode one decision.
    ///
    /// Returns `true` for the high partition, `false` for the low partition.
    #[allow(dead_code)]
    pub fn decode_symbol(&mut self, entry: &ProbTableEntry) -> bool {
        let r = self.range >> TABLE_PROB_BITS;
        let split = r.saturating_mul(entry.cum_prob_low.saturating_add(entry.width_low));
        let is_high = self.code >= split;

        if is_high {
            self.code = self.code.wrapping_sub(split);
            self.range = r.saturating_mul(entry.width_high).max(1);
        } else {
            // The encoder advanced `low` by `r * cum_prob_low` for the low
            // partition; remove that offset so `code` stays relative to `low`.
            self.code = self.code.wrapping_sub(r.saturating_mul(entry.cum_prob_low));
            self.range = r.saturating_mul(entry.width_low).max(1);
        }

        // Renormalise in lock-step with the encoder.
        while self.range < TableArithmeticCoder::RANGE_MIN {
            self.code = (self.code << 8) | u32::from(self.read_byte());
            self.range <<= 8;
        }

        is_high
    }
}
509
510// -------------------------------------------------------------------------
511// Context-Adaptive Binary Arithmetic Coding (CABAC)
512// -------------------------------------------------------------------------
513
/// A single CABAC context model.
///
/// `state / 128` is the estimated probability of the symbol labelled `mps`.
/// The estimate adapts after each coded bin.
#[derive(Clone, Debug)]
pub struct CabacContext {
    /// Probability of the MPS in units of 1/128, kept in [1, 127].
    pub state: u8,
    /// Most probable symbol (false = 0, true = 1).
    pub mps: bool,
}

impl CabacContext {
    /// Create a new context with equi-probable initial state.
    pub fn new() -> Self {
        Self {
            state: 64, // p ≈ 0.5
            mps: false,
        }
    }

    /// Create a context with a biased initial probability.
    ///
    /// `init_state` is clamped to [1, 127]. Values near 1 mean the `mps` label
    /// is unlikely (strong bias towards the LPS); 127 means a strong bias
    /// towards the MPS.
    pub fn with_state(init_state: u8, mps: bool) -> Self {
        Self {
            state: init_state.clamp(1, 127),
            mps,
        }
    }

    /// Update context after observing a bin value.
    ///
    /// * MPS observed: `state` moves towards 127 by `max((127 - state) >> 3, 1)`.
    /// * LPS observed: `state` moves towards 1 by `max(state >> 3, 1)`.
    /// * LPS observed at `state <= 1`: the labels are swapped and the
    ///   probability is mirrored (`state = 128 - state`, capped at 127). Each
    ///   symbol therefore keeps its learned probability instead of being reset
    ///   to near-zero confidence.
    ///
    /// An out-of-range `state` (above 127, set through the public field) is
    /// first clamped to 127.
    pub fn update(&mut self, bin: bool) {
        let state = self.state.min(127);
        if bin == self.mps {
            // MPS observed: increase confidence.
            let step = ((127 - state) >> 3).max(1);
            self.state = (state + step).min(127);
        } else if state <= 1 {
            // The other symbol is now near-certain: swap labels and mirror
            // P(new MPS) = 1 - P(old MPS).
            self.mps = !self.mps;
            self.state = (128 - state).min(127);
        } else {
            // LPS observed: decrease confidence.
            self.state = state - (state >> 3).max(1);
        }
    }

    /// Return the estimated probability of MPS (`state / 128`), in (0, 1) for
    /// states in [1, 127].
    pub fn mps_probability(&self) -> f64 {
        f64::from(self.state) / 128.0
    }
}
575
576/// CABAC encoder with multiple context models.
577#[derive(Clone, Debug)]
578pub struct CabacEncoder {
579    /// Context model array (indexed by context ID).
580    pub contexts: Vec<CabacContext>,
581    /// Underlying arithmetic coder.
582    pub coder: ArithmeticCoder,
583    /// Total bins encoded.
584    pub bins_encoded: u64,
585}
586
587impl CabacEncoder {
588    /// Create a CABAC encoder with `num_contexts` equi-probable contexts.
589    pub fn new(num_contexts: usize) -> Self {
590        Self {
591            contexts: (0..num_contexts).map(|_| CabacContext::new()).collect(),
592            coder: ArithmeticCoder::new(),
593            bins_encoded: 0,
594        }
595    }
596
597    /// Encode a single bin using context `ctx_id`.
598    ///
599    /// Returns any bytes flushed from the arithmetic coder.
600    pub fn encode_bin(&mut self, ctx_id: usize, bin: bool) -> Vec<u8> {
601        let ctx = if ctx_id < self.contexts.len() {
602            &self.contexts[ctx_id]
603        } else {
604            // Fall back to equi-probable.
605            return self.coder.encode_bit(0.5, bin);
606        };
607
608        let prob_one = if ctx.mps {
609            ctx.mps_probability()
610        } else {
611            1.0 - ctx.mps_probability()
612        };
613
614        let bytes = self.coder.encode_bit(prob_one as f32, bin);
615
616        // Adapt context.
617        if ctx_id < self.contexts.len() {
618            self.contexts[ctx_id].update(bin);
619        }
620        self.bins_encoded += 1;
621        bytes
622    }
623
624    /// Encode a bin in bypass mode (equi-probable, no context update).
625    pub fn encode_bypass(&mut self, bin: bool) -> Vec<u8> {
626        self.bins_encoded += 1;
627        self.coder.encode_bit(0.5, bin)
628    }
629}
630
631// -------------------------------------------------------------------------
632// Enhanced Range Coder
633// -------------------------------------------------------------------------
634
/// Multi-symbol range encoder that writes bytes to an output buffer.
///
/// The lower bound of the coding interval lives in a 32-bit window (`low`)
/// plus one carry bit. As the range is renormalised, bytes leave the top of
/// the window. A carry out of the window is propagated into the bytes already
/// produced, and runs of `0xFF` bytes are withheld until it is known whether
/// a later carry turns them into `0x00`.
///
/// The finished stream is the big-endian encoding of the final lower bound. A
/// matching decoder:
/// 1. starts from `range = u32::MAX`;
/// 2. reads 4 bytes up front;
/// 3. reads one more byte each time it scales its range by 256 to keep it at
///    or above 2^24.
#[derive(Clone, Debug)]
pub struct RangeEncoder {
    /// Lower bound of current interval. Bits 0..32 are the active window;
    /// bit 32 holds a carry not yet propagated into `output`.
    low: u64,
    /// Current range (always below 2^32; at least `TOP` between symbols).
    range: u64,
    /// Committed bytes. Only the trailing ones can still change, via a carry.
    output: Vec<u8>,
    /// Number of withheld `0xFF` bytes that logically follow `output`.
    carry_count: u32,
}

impl RangeEncoder {
    /// Renormalisation threshold: `range` is kept at or above 2^24.
    const TOP: u64 = 1 << 24;
    /// Largest supported `total_freq` (2^16). Together with `range >= TOP`,
    /// this keeps `range / total_freq >= 256`.
    const BOT: u64 = 1 << 16;

    /// Create a new range encoder.
    pub fn new() -> Self {
        Self {
            low: 0,
            range: u64::from(u32::MAX),
            output: Vec::new(),
            carry_count: 0,
        }
    }

    /// Encode a symbol with cumulative frequency `cum_freq`, symbol frequency
    /// `sym_freq`, out of `total_freq`.
    ///
    /// # Panics
    ///
    /// Panics if any of these hold:
    /// * `total_freq` is 0 or greater than 2^16;
    /// * `sym_freq` is 0;
    /// * `cum_freq + sym_freq > total_freq`.
    ///
    /// Such inputs would otherwise divide by zero, loop forever on a zero
    /// range, or silently corrupt the stream.
    pub fn encode(&mut self, cum_freq: u64, sym_freq: u64, total_freq: u64) {
        assert!(
            total_freq > 0 && total_freq <= Self::BOT,
            "total_freq must be in 1..=65536, got {}",
            total_freq
        );
        assert!(
            sym_freq > 0 && sym_freq <= total_freq && cum_freq <= total_freq - sym_freq,
            "invalid symbol interval: cum_freq={} sym_freq={} total_freq={}",
            cum_freq,
            sym_freq,
            total_freq
        );

        let r = self.range / total_freq;
        // `low + range` never exceeds its value after the previous shift
        // (< 2^33), so `low` fits comfortably in u64.
        self.low += r * cum_freq;
        self.range = r * sym_freq;
        while self.range < Self::TOP {
            self.range <<= 8;
            self.shift_low();
        }
    }

    /// Moves the top byte of the 32-bit `low` window towards the output.
    fn shift_low(&mut self) {
        if self.low > 0xFFFF_FFFF {
            // Carry out of the window: bump the committed bytes, and every
            // withheld 0xFF becomes 0x00.
            self.propagate_carry();
            self.release_pending(0x00);
            self.low &= 0xFFFF_FFFF;
        }
        let top = (self.low >> 24) as u8;
        if top == 0xFF {
            // A future carry could still ripple through this byte.
            self.carry_count += 1;
        } else {
            // This byte absorbs any future carry, so withheld 0xFFs are final.
            self.release_pending(0xFF);
            self.output.push(top);
        }
        self.low = (self.low << 8) & 0xFFFF_FFFF;
    }

    /// Adds one to the committed output, rippling through trailing `0xFF` bytes.
    fn propagate_carry(&mut self) {
        for byte in self.output.iter_mut().rev() {
            let (next, wrapped) = byte.overflowing_add(1);
            *byte = next;
            if !wrapped {
                return;
            }
        }
    }

    /// Appends the withheld bytes with their now-known `value` and clears the count.
    fn release_pending(&mut self, value: u8) {
        let new_len = self.output.len() + self.carry_count as usize;
        self.output.resize(new_len, value);
        self.carry_count = 0;
    }

    /// Flush and return the compressed byte stream.
    ///
    /// Writes out all 4 bytes of the final lower bound, which lies inside the
    /// final interval, followed by any bytes still withheld.
    pub fn flush(mut self) -> Vec<u8> {
        for _ in 0..4 {
            self.shift_low();
        }
        self.release_pending(0xFF);
        self.output
    }

    /// Return the number of bytes committed so far.
    ///
    /// Withheld `0xFF` bytes that may still change through a carry are not
    /// counted.
    pub fn bytes_emitted(&self) -> usize {
        self.output.len()
    }
}
719
720// -------------------------------------------------------------------------
721// Huffman Tree Optimisation
722// -------------------------------------------------------------------------
723
724/// Compute optimal code lengths for a set of symbol frequencies.
725///
726/// Uses the package-merge algorithm to compute length-limited Huffman codes.
727/// `max_length` limits the maximum code word length.
728///
729/// Returns a vector of `(symbol_index, code_length)` pairs for non-zero-frequency symbols.
730pub fn optimal_code_lengths(freqs: &[u32], max_length: u8) -> Vec<(usize, u8)> {
731    let max_length = max_length.max(1).min(30);
732
733    // Collect non-zero frequency symbols.
734    let symbols: Vec<(usize, u32)> = freqs
735        .iter()
736        .enumerate()
737        .filter(|(_, &f)| f > 0)
738        .map(|(i, &f)| (i, f))
739        .collect();
740
741    if symbols.is_empty() {
742        return vec![];
743    }
744    if symbols.len() == 1 {
745        return vec![(symbols[0].0, 1)];
746    }
747
748    // Build standard Huffman tree and extract depths.
749    let tree = build_huffman_tree(freqs);
750    let codes = compute_huffman_codes(&tree, vec![]);
751
752    let mut lengths: Vec<(usize, u8)> = codes
753        .iter()
754        .map(|(sym, code)| (*sym as usize, code.len() as u8))
755        .collect();
756
757    // Clamp to max_length using the heuristic: redistribute overlong codes.
758    let mut changed = true;
759    while changed {
760        changed = false;
761        // Find any code exceeding max_length.
762        for entry in lengths.iter_mut() {
763            if entry.1 > max_length {
764                entry.1 = max_length;
765                changed = true;
766            }
767        }
768
769        // Verify Kraft inequality: sum of 2^(-L_i) <= 1.
770        let kraft_sum: f64 = lengths
771            .iter()
772            .map(|(_, l)| 2.0_f64.powi(-(*l as i32)))
773            .sum();
774        if kraft_sum > 1.0 && changed {
775            // Need to lengthen some short codes to compensate.
776            // Sort by length (ascending), then increase shortest codes.
777            lengths.sort_by_key(|(_, l)| *l);
778            for idx in 0..lengths.len() {
779                if lengths[idx].1 < max_length {
780                    let new_kraft: f64 = (0..lengths.len())
781                        .map(|i| 2.0_f64.powi(-(lengths[i].1 as i32)))
782                        .sum();
783                    if new_kraft > 1.0 {
784                        lengths[idx].1 += 1;
785                    } else {
786                        break;
787                    }
788                }
789            }
790        }
791    }
792
793    // Sort by symbol index.
794    lengths.sort_by_key(|(sym, _)| *sym);
795    lengths
796}
797
798// -------------------------------------------------------------------------
799// Entropy Estimation
800// -------------------------------------------------------------------------
801
/// Estimates how many bits an ideal entropy coder would spend on `symbols`,
/// without producing a bitstream.
///
/// A byte histogram yields the Shannon entropy `H = -Σ p_i · log2(p_i)` in
/// bits per symbol, which is scaled by the number of symbols. An empty slice
/// costs `0.0` bits.
pub fn estimate_block_entropy(symbols: &[u8]) -> f64 {
    if symbols.is_empty() {
        return 0.0;
    }

    let mut histogram = [0u32; 256];
    symbols
        .iter()
        .for_each(|&byte| histogram[usize::from(byte)] += 1);

    let count = symbols.len() as f64;
    let bits_per_symbol = histogram
        .iter()
        .filter(|&&hits| hits > 0)
        .map(|&hits| f64::from(hits) / count)
        .fold(0.0_f64, |acc, p| acc - p * p.log2());

    bits_per_symbol * count
}
828
/// Returns the Shannon entropy, in bits per symbol, of the distribution
/// described by `freqs`.
///
/// Zero entries are skipped. An all-zero or empty table yields `0.0`.
pub fn estimate_entropy_from_freqs(freqs: &[u32]) -> f64 {
    let total = freqs.iter().map(|&f| f as u64).sum::<u64>();
    if total == 0 {
        return 0.0;
    }

    let total = total as f64;
    freqs
        .iter()
        .filter(|&&f| f > 0)
        .fold(0.0_f64, |acc, &f| {
            let p = f as f64 / total;
            acc - p * p.log2()
        })
}

/// Reports whether strategy A (`freqs_a`) needs no more estimated bits than
/// strategy B (`freqs_b`) when coding `symbol_count` symbols.
///
/// Each strategy's cost is its per-symbol entropy times `symbol_count`; ties
/// favour A.
pub fn compare_coding_strategies(freqs_a: &[u32], freqs_b: &[u32], symbol_count: u64) -> bool {
    let estimated_bits = |freqs: &[u32]| estimate_entropy_from_freqs(freqs) * symbol_count as f64;
    estimated_bits(freqs_a) <= estimated_bits(freqs_b)
}
856
857// -------------------------------------------------------------------------
858// Symbol Frequency Adaptation (Sliding Window)
859// -------------------------------------------------------------------------
860
/// Sliding-window symbol statistics for adaptive probability estimation.
///
/// The most recent `capacity` symbols are kept in a ring buffer. Per-symbol
/// counts are updated incrementally as symbols enter and leave the window.
#[derive(Clone, Debug)]
pub struct AdaptiveFrequencyTracker {
    /// Ring buffer holding the symbols currently inside the window.
    window: Vec<u8>,
    /// Slot that the next observation will overwrite.
    pos: usize,
    /// How many slots hold real observations (grows until it reaches `capacity`).
    count: usize,
    /// Size of the ring buffer (at least 1).
    capacity: usize,
    /// Occurrence count of each byte value inside the window.
    freq: [u32; 256],
}

impl AdaptiveFrequencyTracker {
    /// Creates an empty tracker whose window holds `window_size` symbols;
    /// a size of 0 is treated as 1.
    pub fn new(window_size: usize) -> Self {
        let capacity = window_size.max(1);
        Self {
            window: vec![0; capacity],
            pos: 0,
            count: 0,
            capacity,
            freq: [0u32; 256],
        }
    }

    /// Records `symbol`. Once the window is full, the oldest symbol is evicted.
    pub fn observe(&mut self, symbol: u8) {
        let slot = self.pos;
        if self.count < self.capacity {
            self.count += 1;
        } else {
            let evicted = self.window[slot] as usize;
            self.freq[evicted] = self.freq[evicted].saturating_sub(1);
        }
        self.window[slot] = symbol;
        self.freq[symbol as usize] += 1;
        self.pos = if slot + 1 == self.capacity { 0 } else { slot + 1 };
    }

    /// Number of times `symbol` occurs in the current window.
    pub fn frequency(&self, symbol: u8) -> u32 {
        self.freq[symbol as usize]
    }

    /// Number of observations currently inside the window.
    pub fn total(&self) -> usize {
        self.count
    }

    /// Relative frequency of `symbol` in the window; `0.0` before any observation.
    pub fn probability(&self, symbol: u8) -> f64 {
        match self.count {
            0 => 0.0,
            seen => self.freq[symbol as usize] as f64 / seen as f64,
        }
    }

    /// Copy of the 256-entry frequency table.
    pub fn frequency_table(&self) -> [u32; 256] {
        self.freq
    }

    /// Empties the window and clears all counts.
    pub fn reset(&mut self) {
        self.window.iter_mut().for_each(|slot| *slot = 0);
        self.freq = [0u32; 256];
        self.count = 0;
        self.pos = 0;
    }
}
939
#[cfg(test)]
mod tests {
    use super::*;

    // --- ArithmeticCoder tests ---

    #[test]
    fn arithmetic_coder_new_initial_range() {
        let coder = ArithmeticCoder::new();
        assert_eq!(coder.low, 0);
        assert_eq!(coder.high, 0xFFFF_FFFF);
        // Initial range spans the full 32-bit space: 2^32.
        assert_eq!(coder.get_range(), 0x1_0000_0000u64);
    }

    #[test]
    fn arithmetic_coder_get_range() {
        let c = ArithmeticCoder::new();
        let initial_range = c.get_range();
        // Initial range must be positive.
        assert!(initial_range > 0);
        // Initial range is the full 32-bit span.
        assert_eq!(initial_range, 0x1_0000_0000u64);
        // After encoding with a strongly-biased probability, the coder should
        // still maintain a valid (positive) range.
        let mut c2 = ArithmeticCoder::new();
        c2.encode_bit(0.9, true);
        assert!(c2.get_range() > 0);
        assert!(c2.low <= c2.high);
    }

    #[test]
    fn arithmetic_coder_encode_bit_does_not_panic() {
        let mut c = ArithmeticCoder::new();
        let _bytes = c.encode_bit(0.5, true);
        let _bytes = c.encode_bit(0.5, false);
        let _bytes = c.encode_bit(0.9, true);
        // No panic is sufficient for this test.
    }

    #[test]
    fn arithmetic_coder_bits_to_follow_increments() {
        let mut c = ArithmeticCoder::new();
        // Repeated near-50% bits tend to trigger E3 scaling.
        for _ in 0..16 {
            c.encode_bit(0.5, true);
        }
        // State should remain coherent (low ≤ high).
        assert!(c.low <= c.high);
    }

    #[test]
    fn arithmetic_coder_encode_sequence_returns_bytes() {
        let mut c = ArithmeticCoder::new();
        let mut all_bytes = Vec::new();
        // Encode 32 bits with strong probability – should flush many bytes.
        for _ in 0..32 {
            all_bytes.extend(c.encode_bit(0.95, true));
        }
        // We don't verify bit-exact values, just that the coder is usable.
        assert!(all_bytes.len() <= 32 * 2); // sanity upper bound
    }

    // --- bits_to_bytes helper ---

    #[test]
    fn bits_to_bytes_empty() {
        let b = bits_to_bytes(&[]);
        assert!(b.is_empty());
    }

    #[test]
    fn bits_to_bytes_full_byte() {
        // 0b1010_1010 = 0xAA
        let bits = [true, false, true, false, true, false, true, false];
        let b = bits_to_bytes(&bits);
        assert_eq!(b, vec![0xAA]);
    }

    // --- RangeCoder tests ---

    #[test]
    fn range_coder_new() {
        let rc = RangeCoder::new();
        assert_eq!(rc.range, 256);
        assert_eq!(rc.code, 0);
    }

    #[test]
    fn range_coder_normalize_already_normalised() {
        let mut rc = RangeCoder::new();
        let bits = rc.normalize();
        assert_eq!(bits, 0); // already in [128, 256)
    }

    #[test]
    fn range_coder_normalize_below_128() {
        let mut rc = RangeCoder { range: 32, code: 0 };
        let bits = rc.normalize();
        assert!(rc.range >= 128);
        assert_eq!(bits, 2); // 32 → 64 → 128, two doublings
    }

    #[test]
    fn range_coder_decode_symbol_high_partition() {
        let mut rc = RangeCoder {
            range: 256,
            code: 200,
        };
        // split = (256 * 128) >> 8 = 128; code(200) >= split(128) → true
        let sym = rc.decode_symbol(128);
        assert!(sym);
        assert_eq!(rc.range, 256 - 128);
        assert_eq!(rc.code, 200 - 128);
    }

    #[test]
    fn range_coder_decode_symbol_low_partition() {
        let mut rc = RangeCoder {
            range: 256,
            code: 50,
        };
        // split = 128; code(50) < split(128) → false
        let sym = rc.decode_symbol(128);
        assert!(!sym);
        assert_eq!(rc.range, 128);
        assert_eq!(rc.code, 50);
    }

    // --- HuffmanNode tests ---

    #[test]
    fn huffman_node_is_leaf_true() {
        let leaf = HuffmanNode {
            symbol: Some(42),
            freq: 10,
            left: None,
            right: None,
        };
        assert!(leaf.is_leaf());
    }

    #[test]
    fn huffman_node_is_leaf_false() {
        let inner = HuffmanNode {
            symbol: None,
            freq: 20,
            left: Some(Box::new(HuffmanNode {
                symbol: Some(0),
                freq: 10,
                left: None,
                right: None,
            })),
            right: None,
        };
        assert!(!inner.is_leaf());
    }

    #[test]
    fn build_huffman_tree_two_symbols() {
        let freqs = [10u32, 20];
        let tree = build_huffman_tree(&freqs);
        assert!(!tree.is_leaf());
        assert_eq!(tree.freq, 30);
        let codes = compute_huffman_codes(&tree, vec![]);
        // Two leaves → two codes.
        assert_eq!(codes.len(), 2);
    }

    #[test]
    fn build_huffman_tree_multiple_symbols() {
        // Typical small alphabet.
        let freqs = [5u32, 9, 12, 13, 16, 45];
        let tree = build_huffman_tree(&freqs);
        let codes = compute_huffman_codes(&tree, vec![]);
        assert_eq!(codes.len(), 6);
        // Higher-frequency symbols should have shorter codes.
        let mut code_map = std::collections::HashMap::new();
        for (sym, code) in &codes {
            code_map.insert(*sym, code.len());
        }
        // Symbol 5 (freq=45) should have the shortest code.
        assert!(code_map[&5] <= code_map[&0]);
    }

    #[test]
    fn build_huffman_tree_empty_freqs() {
        let tree = build_huffman_tree(&[]);
        // Degenerate: single leaf for symbol 0.
        assert!(tree.is_leaf());
        assert_eq!(tree.symbol, Some(0));
    }

    #[test]
    fn build_huffman_tree_single_symbol() {
        let freqs = [0u32, 7, 0];
        let tree = build_huffman_tree(&freqs);
        // Wrapped in a parent.
        assert!(!tree.is_leaf());
        let codes = compute_huffman_codes(&tree, vec![]);
        assert_eq!(codes.len(), 1);
        assert_eq!(codes[0].0, 1); // symbol index 1
    }

    #[test]
    fn compute_huffman_codes_all_unique() {
        let freqs = [1u32, 2, 4, 8];
        let tree = build_huffman_tree(&freqs);
        let codes = compute_huffman_codes(&tree, vec![]);
        let symbols: Vec<u8> = codes.iter().map(|(s, _)| *s).collect();
        // All symbols should appear exactly once.
        let mut sorted = symbols.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(sorted.len(), symbols.len());
    }

    // --- TableArithmeticCoder tests ---

    #[test]
    fn table_coder_build_prob_table_basic() {
        let freqs = [10u32, 30, 20, 5];
        let table = build_prob_table(&freqs);
        assert_eq!(table.len(), 4);
        // Every entry's cumulative lower bound must lie within TABLE_SIZE.
        for entry in &table {
            assert!(entry.cum_prob_low <= TABLE_SIZE as u32);
        }
    }

    #[test]
    fn table_coder_build_prob_table_empty() {
        let table = build_prob_table(&[]);
        assert!(table.is_empty());
    }

    #[test]
    fn table_coder_encode_produces_bytes() {
        let freqs = [128u32, 128u32]; // equal prob
        let table = build_prob_table(&freqs);
        let mut enc = TableArithmeticCoder::new();
        for _ in 0..32 {
            enc.encode_symbol(false, &table[0]);
        }
        let data = enc.flush();
        assert!(!data.is_empty());
    }

    #[test]
    fn table_coder_encode_decode_roundtrip() {
        // Use a uniform binary alphabet: two equal-frequency symbols.
        // With equal widths the encoder/decoder split at the midpoint.
        let freqs = [128u32, 128u32]; // equal prob → split always at TABLE_SIZE/2
        let table = build_prob_table(&freqs);
        let symbols: Vec<bool> = vec![false, false, true, false, true, true, false];

        // Encode: false → low partition (index 0), true → high partition (index 1)
        let mut enc = TableArithmeticCoder::new();
        for &s in &symbols {
            enc.encode_symbol(s, &table[0]); // same entry for both; high/low is the flag
        }
        let data = enc.flush();

        // Decode using the same entry
        let mut dec = TableArithmeticDecoder::new(&data);
        let mut decoded = Vec::new();
        for _ in 0..symbols.len() {
            decoded.push(dec.decode_symbol(&table[0]));
        }

        assert_eq!(
            decoded, symbols,
            "Round-trip must reproduce the original symbols"
        );
    }

    #[test]
    fn table_coder_bytes_emitted_before_flush() {
        let freqs = [1u32, 255u32];
        let table = build_prob_table(&freqs);
        let mut enc = TableArithmeticCoder::new();
        for _ in 0..100 {
            enc.encode_symbol(true, &table[1]);
        }
        // bytes_emitted() should reflect renormalisation output
        let mid_count = enc.bytes_emitted();
        let data = enc.flush();
        assert!(data.len() >= mid_count);
    }

    #[test]
    fn table_coder_all_high_partition() {
        let freqs = [50u32, 206u32];
        let table = build_prob_table(&freqs);
        let symbols = vec![true; 20];

        let mut enc = TableArithmeticCoder::new();
        for &s in &symbols {
            enc.encode_symbol(s, &table[1]);
        }
        let data = enc.flush();

        // The decoder must use the same probability entry as the encoder;
        // decoding with a different entry would not be a real round-trip.
        let mut dec = TableArithmeticDecoder::new(&data);
        for _ in 0..symbols.len() {
            let sym = dec.decode_symbol(&table[1]);
            assert!(sym, "should decode as high partition");
        }
    }

    #[test]
    fn table_coder_all_low_partition() {
        let freqs = [200u32, 56u32];
        let table = build_prob_table(&freqs);
        let symbols = vec![false; 20];

        let mut enc = TableArithmeticCoder::new();
        for &s in &symbols {
            enc.encode_symbol(s, &table[0]);
        }
        let data = enc.flush();

        let mut dec = TableArithmeticDecoder::new(&data);
        for _ in 0..symbols.len() {
            let sym = dec.decode_symbol(&table[0]);
            assert!(!sym, "should decode as low partition");
        }
    }

    // --- CABAC Context tests ---

    #[test]
    fn cabac_context_initial_equi_probable() {
        let ctx = CabacContext::new();
        let p = ctx.mps_probability();
        assert!((p - 0.5).abs() < 0.01);
    }

    #[test]
    fn cabac_context_adapts_towards_mps() {
        let mut ctx = CabacContext::new();
        for _ in 0..20 {
            ctx.update(ctx.mps);
        }
        assert!(
            ctx.mps_probability() > 0.7,
            "should converge towards high confidence"
        );
    }

    #[test]
    fn cabac_context_adapts_towards_lps() {
        let mut ctx = CabacContext::new();
        let lps = !ctx.mps;
        for _ in 0..30 {
            ctx.update(lps);
        }
        // After many LPS updates, MPS should have flipped.
        assert!(ctx.mps == lps || ctx.state <= 10);
    }

    #[test]
    fn cabac_context_with_biased_state() {
        let ctx = CabacContext::with_state(120, true);
        assert!(ctx.mps_probability() > 0.9);
        assert!(ctx.mps);
    }

    #[test]
    fn cabac_encoder_basic() {
        let mut enc = CabacEncoder::new(4);
        let mut bytes = Vec::new();
        for i in 0..16 {
            bytes.extend(enc.encode_bin(i % 4, i % 2 == 0));
        }
        assert_eq!(enc.bins_encoded, 16);
    }

    #[test]
    fn cabac_encoder_bypass_mode() {
        let mut enc = CabacEncoder::new(1);
        // The emitted bytes are not inspected here; the underscore prefix keeps
        // the build free of `unused_variables` warnings.
        let _bytes = enc.encode_bypass(true);
        assert_eq!(enc.bins_encoded, 1);
        // Bypass should not affect any context.
        let p = enc.contexts[0].mps_probability();
        assert!((p - 0.5).abs() < 0.01);
    }

    // --- Enhanced Range Coder tests ---

    #[test]
    fn range_encoder_encode_flush() {
        let mut enc = RangeEncoder::new();
        enc.encode(0, 50, 100);
        enc.encode(50, 50, 100);
        let data = enc.flush();
        assert!(!data.is_empty());
    }

    #[test]
    fn range_encoder_bytes_emitted() {
        let mut enc = RangeEncoder::new();
        for _ in 0..100 {
            enc.encode(0, 128, 256);
        }
        let mid = enc.bytes_emitted();
        let data = enc.flush();
        assert!(data.len() >= mid);
    }

    // --- Huffman Optimisation tests ---

    #[test]
    fn optimal_code_lengths_basic() {
        let freqs = [10u32, 20, 40, 80];
        let lengths = optimal_code_lengths(&freqs, 15);
        assert_eq!(lengths.len(), 4);
        // Higher frequency → shorter code.
        let len_map: std::collections::HashMap<usize, u8> = lengths.iter().copied().collect();
        assert!(len_map[&3] <= len_map[&0]);
    }

    #[test]
    fn optimal_code_lengths_max_length_respected() {
        let freqs = [1u32, 1, 1, 1, 1, 1, 1, 1, 100];
        let lengths = optimal_code_lengths(&freqs, 4);
        for (_, l) in &lengths {
            assert!(*l <= 4, "code length {} exceeds max 4", l);
        }
    }

    #[test]
    fn optimal_code_lengths_single_symbol() {
        let freqs = [0u32, 0, 42];
        let lengths = optimal_code_lengths(&freqs, 10);
        assert_eq!(lengths.len(), 1);
        assert_eq!(lengths[0], (2, 1));
    }

    #[test]
    fn optimal_code_lengths_empty() {
        let lengths = optimal_code_lengths(&[], 10);
        assert!(lengths.is_empty());
    }

    // --- Entropy Estimation tests ---

    #[test]
    fn estimate_block_entropy_uniform() {
        // All same symbol: entropy = 0.
        let block = vec![42u8; 100];
        let bits = estimate_block_entropy(&block);
        assert!(
            bits < 1.0,
            "uniform block entropy should be ~0, got {}",
            bits
        );
    }

    #[test]
    fn estimate_block_entropy_binary() {
        // Two symbols, equal frequency: entropy = 1 bit/symbol.
        let mut block = vec![0u8; 100];
        for b in block.iter_mut().step_by(2) {
            *b = 1;
        }
        let bits = estimate_block_entropy(&block);
        let bits_per_sym = bits / 100.0;
        assert!((bits_per_sym - 1.0).abs() < 0.1);
    }

    #[test]
    fn estimate_entropy_from_freqs_uniform() {
        // 256 symbols each with freq 1: max entropy = 8 bits/sym.
        let freqs = vec![1u32; 256];
        let entropy = estimate_entropy_from_freqs(&freqs);
        assert!((entropy - 8.0).abs() < 0.01);
    }

    #[test]
    fn compare_coding_strategies_picks_better() {
        // Strategy A: more concentrated → lower entropy.
        let a = [100u32, 1, 1, 1];
        let b = [25u32, 25, 25, 25];
        assert!(compare_coding_strategies(&a, &b, 1000));
        assert!(!compare_coding_strategies(&b, &a, 1000));
    }

    // --- Adaptive Frequency Tracker tests ---

    #[test]
    fn adaptive_tracker_basic() {
        let mut tracker = AdaptiveFrequencyTracker::new(10);
        tracker.observe(5);
        tracker.observe(5);
        tracker.observe(3);
        assert_eq!(tracker.frequency(5), 2);
        assert_eq!(tracker.frequency(3), 1);
        assert_eq!(tracker.total(), 3);
    }

    #[test]
    fn adaptive_tracker_window_eviction() {
        let mut tracker = AdaptiveFrequencyTracker::new(3);
        tracker.observe(1);
        tracker.observe(2);
        tracker.observe(3);
        assert_eq!(tracker.frequency(1), 1);

        // Adding a 4th symbol evicts symbol 1.
        tracker.observe(4);
        assert_eq!(tracker.frequency(1), 0);
        assert_eq!(tracker.frequency(4), 1);
        assert_eq!(tracker.total(), 3);
    }

    #[test]
    fn adaptive_tracker_probability() {
        let mut tracker = AdaptiveFrequencyTracker::new(100);
        for _ in 0..75 {
            tracker.observe(0);
        }
        for _ in 0..25 {
            tracker.observe(1);
        }
        let p0 = tracker.probability(0);
        assert!((p0 - 0.75).abs() < 0.01);
    }

    #[test]
    fn adaptive_tracker_reset() {
        let mut tracker = AdaptiveFrequencyTracker::new(10);
        tracker.observe(42);
        tracker.reset();
        assert_eq!(tracker.frequency(42), 0);
        assert_eq!(tracker.total(), 0);
    }

    #[test]
    fn adaptive_tracker_frequency_table() {
        let mut tracker = AdaptiveFrequencyTracker::new(100);
        tracker.observe(10);
        tracker.observe(10);
        tracker.observe(20);
        let table = tracker.frequency_table();
        assert_eq!(table[10], 2);
        assert_eq!(table[20], 1);
        assert_eq!(table[0], 0);
    }
}