Skip to main content

oximedia_codec/av1/
entropy.rs

1//! AV1 entropy coding.
2//!
3//! AV1 uses a multi-symbol arithmetic coder for entropy coding.
4//! The coder is based on a range coder with 16-bit precision.
5//!
6//! # Symbol Coding
7//!
8//! AV1 uses context-dependent probability models (CDFs) that are
//! adapted as symbols are coded. The adaptation uses an exponential
//! moving average whose rate depends on the symbol count.
11//!
12//! # CDF (Cumulative Distribution Function)
13//!
14//! Probability models are stored as CDFs with 15-bit precision.
15//! The CDF is updated after each symbol using an adaptive algorithm.
16//!
17//! # Contexts
18//!
19//! AV1 has hundreds of different contexts for different syntax elements.
20//! The context depends on neighboring blocks and other state.
21//!
22//! # Symbol Reader
23//!
//! The `SymbolReader` provides a high-level interface for reading
//! symbols from the entropy-coded bitstream.
26
27#![allow(dead_code)]
28#![allow(clippy::cast_possible_truncation)]
29#![allow(clippy::cast_possible_wrap)]
30#![allow(clippy::manual_div_ceil)]
31#![allow(clippy::needless_range_loop)]
32
33#[allow(unused_imports)]
34use super::entropy_tables::{CDF_PROB_BITS, CDF_PROB_TOP};
35
36// =============================================================================
37// Constants
38// =============================================================================
39
/// Number of bits of precision in the range coder.
pub const RANGE_BITS: u8 = 16;

/// Initial range value: `2^RANGE_BITS`.
pub const RANGE_INIT: u32 = 1 << RANGE_BITS;

/// Minimum range value: half of [`RANGE_INIT`], i.e. `2^(RANGE_BITS - 1)`.
pub const RANGE_MIN: u32 = RANGE_INIT >> 1;

/// Number of bits of precision in the value register.
pub const VALUE_BITS: u8 = 16;

/// Width, in bits, of the bit-reading window.
pub const WINDOW_SIZE: u8 = 32;
54
55// =============================================================================
56// Arithmetic Decoder
57// =============================================================================
58
/// Arithmetic decoder state.
///
/// Owns the entropy-coded bytes and the `range`/`value` registers. Input
/// bits are consumed MSB-first, one byte at a time, through `read_bit`.
#[derive(Clone, Debug)]
pub struct ArithmeticDecoder {
    /// Current range.
    range: u32,
    /// Current value; `init` shifts the first 15 input bits into it.
    value: u32,
    /// Bits of `cur_byte` that `read_bit` has not yet returned.
    bits_remaining: u32,
    /// Byte currently being consumed by `read_bit`.
    ///
    /// Kept separate from `value` so that fetching the next input byte does
    /// not overwrite the decoder's value register.
    cur_byte: u8,
    /// Input data.
    data: Vec<u8>,
    /// Index of the next byte of `data` to fetch.
    position: usize,
}

impl ArithmeticDecoder {
    /// Create a new arithmetic decoder over `data`.
    ///
    /// The decoder starts unprimed; call [`ArithmeticDecoder::init`] before
    /// decoding symbols.
    #[must_use]
    pub fn new(data: Vec<u8>) -> Self {
        Self {
            range: 0x8000,
            value: 0,
            bits_remaining: 0,
            cur_byte: 0,
            data,
            position: 0,
        }
    }

    /// Prime the decoder by shifting the first 15 input bits into `value`.
    pub fn init(&mut self) {
        for _ in 0..15 {
            self.value = (self.value << 1) | u32::from(self.read_bit());
        }
    }

    /// Read the next input bit (MSB first).
    ///
    /// Once `data` is exhausted, every further call yields 0.
    fn read_bit(&mut self) -> u8 {
        if self.bits_remaining == 0 {
            let next = self.data.get(self.position).copied();
            self.cur_byte = match next {
                Some(byte) => {
                    self.position += 1;
                    byte
                }
                // Past the end of the stream: pad with zero bits.
                None => 0,
            };
            self.bits_remaining = 8;
        }
        self.bits_remaining -= 1;
        (self.cur_byte >> self.bits_remaining) & 1
    }

    /// Decode one symbol using `cdf`, then adapt `cdf` towards it.
    ///
    /// `cdf` holds the cumulative probabilities followed by one trailing
    /// adaptation counter. The symbol is found by a binary search over
    /// thresholds derived from `range` and `cdf`. Afterwards, entries before
    /// the symbol decay towards 0 and the rest move towards `0x7FFF`. The
    /// adaptation shift is `4 + count / 16` (at most 15), and the counter
    /// saturates at 32.
    ///
    /// NOTE(review): `range` and `value` are only read here, never updated,
    /// so consecutive calls see the same decoder state. The threshold is also
    /// scaled by `>> 7`, whereas AV1's symbol decoding shifts by
    /// `7 - EC_PROB_SHIFT`. Confirm against the specification before relying
    /// on this for real bitstreams.
    ///
    /// # Panics
    ///
    /// Panics if `cdf` is empty.
    #[allow(clippy::cast_possible_truncation)]
    pub fn decode_symbol(&mut self, cdf: &mut [u16]) -> usize {
        let range = self.range;
        let value = self.value;

        // Binary search for symbol
        let mut low = 0;
        let mut high = cdf.len() - 1;
        let mut mid;
        let mut threshold;

        while low < high {
            mid = (low + high) >> 1;
            threshold = ((range >> 8) * u32::from(cdf[mid] >> 6)) >> 7;
            threshold += 4 * (mid as u32 + 1);

            if value < threshold {
                high = mid;
            } else {
                low = mid + 1;
            }
        }

        // Update CDF (simplified)
        let symbol = low;
        let count = u32::from(cdf[cdf.len() - 1]);
        let rate = 4 + (count >> 4);
        let rate = rate.min(15);

        for i in 0..cdf.len() - 1 {
            if i < symbol {
                // Decrease probability
                let diff = cdf[i] >> rate;
                cdf[i] = cdf[i].saturating_sub(diff);
            } else {
                // Increase probability
                let diff = 0x7FFF_u16.saturating_sub(cdf[i]) >> rate;
                cdf[i] = cdf[i].saturating_add(diff);
            }
        }

        // Increment count
        if count < 32 {
            cdf[cdf.len() - 1] += 1;
        }

        symbol
    }
}
158
/// Arithmetic encoder state.
///
/// NOTE(review): the sub-range formula used here (`(range * cdf) >> 15`)
/// differs from the threshold formula in `ArithmeticDecoder::decode_symbol`.
/// Confirm the two are meant to round-trip before relying on it.
#[derive(Clone, Debug)]
pub struct ArithmeticEncoder {
    /// Current low bound.
    low: u64,
    /// Current range.
    range: u32,
    /// Output buffer.
    output: Vec<u8>,
    /// Carry count.
    ///
    /// NOTE(review): reset to 0 in `output_bit` but never incremented here.
    carry_count: u32,
    /// First output byte.
    first_byte: bool,
}

impl ArithmeticEncoder {
    /// Create a new arithmetic encoder with `range = 0x8000` and empty output.
    #[must_use]
    pub fn new() -> Self {
        Self {
            low: 0,
            range: 0x8000,
            output: Vec::new(),
            carry_count: 0,
            first_byte: true,
        }
    }

    /// Encode `symbol` using `cdf`, then adapt `cdf` towards it.
    ///
    /// `cdf` holds cumulative probabilities on a 15-bit scale, followed by one
    /// trailing adaptation counter. `symbol` occupies the interval
    /// `[cdf[symbol - 1], cdf[symbol])`, with a lower bound of 0 for symbol 0.
    ///
    /// # Panics
    ///
    /// Panics if `symbol` is out of bounds for `cdf`, or if its interval is
    /// empty (zero probability). An empty interval would leave `range` at 0,
    /// and `renormalize` would then never terminate.
    #[allow(clippy::similar_names)]
    pub fn encode_symbol(&mut self, symbol: usize, cdf: &mut [u16]) {
        let range = self.range;

        // Calculate sub-range
        let fl = if symbol > 0 { cdf[symbol - 1] } else { 0 };
        let fh = cdf[symbol];
        let range_fl = (range * u32::from(fl)) >> 15;
        let range_fh = (range * u32::from(fh)) >> 15;
        assert!(
            range_fh > range_fl,
            "cannot encode symbol {symbol}: empty CDF interval [{fl}, {fh})"
        );

        // Update range
        self.low += u64::from(range_fl);
        self.range = range_fh - range_fl;

        // Renormalize
        self.renormalize();

        // Update CDF (same adaptation as the decoder)
        let count = u32::from(cdf[cdf.len() - 1]);
        let rate = 4 + (count >> 4);
        let rate = rate.min(15);

        for i in 0..cdf.len() - 1 {
            if i < symbol {
                let diff = cdf[i] >> rate;
                cdf[i] = cdf[i].saturating_sub(diff);
            } else {
                let diff = 0x7FFF_u16.saturating_sub(cdf[i]) >> rate;
                cdf[i] = cdf[i].saturating_add(diff);
            }
        }

        if count < 32 {
            cdf[cdf.len() - 1] += 1;
        }
    }

    /// Double `range` (and `low`) until `range >= 0x8000`, emitting via
    /// `output_bit` once per doubling.
    fn renormalize(&mut self) {
        while self.range < 0x8000 {
            self.output_bit();
            self.low <<= 1;
            self.range <<= 1;
        }
    }

    /// Output a bit with carry handling.
    ///
    /// NOTE(review): this pushes `(low >> 15)` truncated to a whole byte per
    /// call, and `low` is never masked, so the output is not a packed
    /// bitstream. A leading zero is suppressed while `first_byte` is set.
    #[allow(clippy::cast_possible_truncation)]
    fn output_bit(&mut self) {
        let bit = (self.low >> 15) as u8;
        if bit != 0 || !self.first_byte {
            self.output.push(bit);
            for _ in 0..self.carry_count {
                self.output.push(0xFF ^ bit);
            }
            self.carry_count = 0;
            self.first_byte = false;
        }
    }

    /// Finalize encoding and return the output buffer.
    ///
    /// NOTE(review): `range` is always at least `0x8000` here, so the
    /// `renormalize` call emits nothing and the pending state in `low` is not
    /// flushed.
    #[must_use]
    pub fn finish(mut self) -> Vec<u8> {
        // Flush remaining bits
        self.renormalize();
        self.output
    }
}

impl Default for ArithmeticEncoder {
    fn default() -> Self {
        Self::new()
    }
}
262
263// =============================================================================
264// Symbol Reader
265// =============================================================================
266
267/// High-level symbol reader for CDF-based arithmetic coding.
268#[derive(Clone, Debug)]
269pub struct SymbolReader {
270    /// Underlying arithmetic decoder.
271    decoder: ArithmeticDecoder,
272    /// Current bit position for literal reads.
273    bit_pos: u32,
274    /// Window buffer for efficient bit reading.
275    window: u64,
276    /// Bits available in window.
277    window_bits: u8,
278}
279
280impl SymbolReader {
281    /// Create a new symbol reader.
282    #[must_use]
283    pub fn new(data: Vec<u8>) -> Self {
284        let mut reader = Self {
285            decoder: ArithmeticDecoder::new(data),
286            bit_pos: 0,
287            window: 0,
288            window_bits: 0,
289        };
290        reader.decoder.init();
291        reader
292    }
293
294    /// Read a symbol using a CDF.
295    ///
296    /// Updates the CDF after reading.
297    pub fn read_symbol(&mut self, cdf: &mut [u16]) -> usize {
298        self.decoder.decode_symbol(cdf)
299    }
300
301    /// Read a symbol without updating CDF.
302    #[allow(clippy::cast_possible_truncation)]
303    pub fn read_symbol_no_update(&mut self, cdf: &[u16]) -> usize {
304        let range = self.decoder.range;
305        let value = self.decoder.value;
306
307        // Binary search for symbol
308        let mut low = 0;
309        let mut high = cdf.len() - 1;
310        let mut mid;
311        let mut threshold;
312
313        while low < high {
314            mid = (low + high) >> 1;
315            threshold = ((range >> 8) * u32::from(cdf[mid] >> 6)) >> 7;
316            threshold += 4 * (mid as u32 + 1);
317
318            if value < threshold {
319                high = mid;
320            } else {
321                low = mid + 1;
322            }
323        }
324
325        low
326    }
327
328    /// Read a boolean value using a CDF.
329    pub fn read_bool(&mut self, cdf: &mut [u16; 3]) -> bool {
330        self.read_symbol(cdf) == 1
331    }
332
333    /// Read a boolean with fixed probability (128/256).
334    #[allow(clippy::cast_possible_truncation)]
335    pub fn read_bool_eq(&mut self) -> bool {
336        let mut cdf = [16384u16, 32768, 0];
337        self.read_symbol(&mut cdf) == 1
338    }
339
340    /// Read a literal (fixed-length code) of n bits.
341    #[allow(clippy::cast_possible_truncation)]
342    pub fn read_literal(&mut self, n: u8) -> u32 {
343        let mut value = 0u32;
344        for _ in 0..n {
345            value = (value << 1) | u32::from(self.read_bit());
346        }
347        value
348    }
349
350    /// Read a single bit.
351    fn read_bit(&mut self) -> u8 {
352        if self.window_bits == 0 {
353            self.refill_window();
354        }
355
356        self.window_bits -= 1;
357        ((self.window >> self.window_bits) & 1) as u8
358    }
359
360    /// Refill the bit window.
361    fn refill_window(&mut self) {
362        while self.window_bits < 56 && self.bit_pos < self.decoder.data.len() as u32 * 8 {
363            let byte_idx = (self.bit_pos / 8) as usize;
364            if byte_idx < self.decoder.data.len() {
365                self.window = (self.window << 8) | u64::from(self.decoder.data[byte_idx]);
366                self.window_bits += 8;
367            }
368            self.bit_pos += 8;
369        }
370    }
371
372    /// Read an unsigned value using subexp coding.
373    #[allow(clippy::cast_possible_truncation)]
374    pub fn read_subexp(&mut self, k: u8, max_val: u32) -> u32 {
375        let mut b = 0u8;
376        let mk = max_val as i32;
377
378        loop {
379            let range = 1i32 << (b + k);
380            if mk <= range {
381                return self.read_literal(((mk + 1).ilog2() + 1) as u8);
382            }
383
384            let bit = self.read_bit();
385            if bit == 0 {
386                return self.read_literal(b + k);
387            }
388
389            b += 1;
390            if b >= 24 {
391                break;
392            }
393        }
394
395        0
396    }
397
398    /// Read a signed value using subexp coding.
399    #[allow(clippy::cast_possible_wrap)]
400    pub fn read_signed_subexp(&mut self, k: u8, max_val: u32) -> i32 {
401        let unsigned = self.read_subexp(k, 2 * max_val);
402        if unsigned == 0 {
403            0
404        } else if unsigned & 1 == 1 {
405            -((unsigned + 1) as i32 / 2)
406        } else {
407            (unsigned / 2) as i32
408        }
409    }
410
411    /// Read inverse recenter value.
412    pub fn read_inv_recenter(&mut self, r: u32, max_val: u32) -> u32 {
413        let v = self.read_subexp(3, max_val);
414        if v == 0 {
415            r
416        } else if v <= 2 * r {
417            if v & 1 == 1 {
418                r + (v + 1) / 2
419            } else {
420                r - v / 2
421            }
422        } else {
423            v
424        }
425    }
426
427    /// Read NS (non-symmetric) coded value.
428    #[allow(clippy::cast_possible_truncation)]
429    pub fn read_ns(&mut self, n: u32) -> u32 {
430        if n <= 1 {
431            return 0;
432        }
433
434        let w = n.ilog2() as u8;
435        let m = (1u32 << (w + 1)) - n;
436        let v = self.read_literal(w);
437
438        if v < m {
439            v
440        } else {
441            let extra = self.read_bit();
442            (v << 1) - m + u32::from(extra)
443        }
444    }
445
446    /// Check if more data is available.
447    #[must_use]
448    pub fn has_more_data(&self) -> bool {
449        self.decoder.position < self.decoder.data.len()
450    }
451
452    /// Get current byte position.
453    #[must_use]
454    pub fn position(&self) -> usize {
455        self.decoder.position
456    }
457
458    /// Get remaining bytes.
459    #[must_use]
460    pub fn remaining(&self) -> usize {
461        self.decoder
462            .data
463            .len()
464            .saturating_sub(self.decoder.position)
465    }
466}
467
468// =============================================================================
469// Symbol Writer
470// =============================================================================
471
472/// High-level symbol writer for CDF-based arithmetic coding.
473#[derive(Clone, Debug)]
474pub struct SymbolWriter {
475    /// Underlying arithmetic encoder.
476    encoder: ArithmeticEncoder,
477    /// Bit buffer for literal writes.
478    bit_buffer: u64,
479    /// Bits in buffer.
480    bit_count: u8,
481}
482
483impl SymbolWriter {
484    /// Create a new symbol writer.
485    #[must_use]
486    pub fn new() -> Self {
487        Self {
488            encoder: ArithmeticEncoder::new(),
489            bit_buffer: 0,
490            bit_count: 0,
491        }
492    }
493
494    /// Write a symbol using a CDF.
495    ///
496    /// Updates the CDF after writing.
497    pub fn write_symbol(&mut self, symbol: usize, cdf: &mut [u16]) {
498        self.encoder.encode_symbol(symbol, cdf);
499    }
500
501    /// Write a boolean value.
502    pub fn write_bool(&mut self, value: bool, cdf: &mut [u16; 3]) {
503        self.write_symbol(usize::from(value), cdf);
504    }
505
506    /// Write a literal (fixed-length code) of n bits.
507    #[allow(clippy::cast_possible_truncation)]
508    pub fn write_literal(&mut self, value: u32, n: u8) {
509        for i in (0..n).rev() {
510            let bit = ((value >> i) & 1) as u8;
511            self.write_bit(bit);
512        }
513    }
514
515    /// Write a single bit.
516    fn write_bit(&mut self, bit: u8) {
517        self.bit_buffer = (self.bit_buffer << 1) | u64::from(bit & 1);
518        self.bit_count += 1;
519
520        if self.bit_count >= 8 {
521            self.flush_bits();
522        }
523    }
524
525    /// Flush accumulated bits.
526    #[allow(clippy::cast_possible_truncation)]
527    fn flush_bits(&mut self) {
528        while self.bit_count >= 8 {
529            let byte = (self.bit_buffer >> (self.bit_count - 8)) as u8;
530            self.encoder.output.push(byte);
531            self.bit_count -= 8;
532        }
533    }
534
535    /// Write NS (non-symmetric) coded value.
536    #[allow(clippy::cast_possible_truncation)]
537    pub fn write_ns(&mut self, v: u32, n: u32) {
538        if n <= 1 {
539            return;
540        }
541
542        let w = n.ilog2() as u8;
543        let m = (1u32 << (w + 1)) - n;
544
545        if v < m {
546            self.write_literal(v, w);
547        } else {
548            let adjusted = v + m;
549            self.write_literal(adjusted >> 1, w);
550            self.write_bit((adjusted & 1) as u8);
551        }
552    }
553
554    /// Finalize writing and get output.
555    #[must_use]
556    pub fn finish(mut self) -> Vec<u8> {
557        // Flush any remaining bits
558        if self.bit_count > 0 {
559            let remaining = 8 - self.bit_count;
560            self.bit_buffer <<= remaining;
561            self.bit_count = 8;
562            self.flush_bits();
563        }
564
565        self.encoder.finish()
566    }
567}
568
569impl Default for SymbolWriter {
570    fn default() -> Self {
571        Self::new()
572    }
573}
574
575// =============================================================================
576// CDF Update Functions
577// =============================================================================
578
579/// Update CDF with a symbol observation.
580#[allow(clippy::cast_possible_truncation)]
581pub fn update_cdf(cdf: &mut [u16], symbol: usize) {
582    let n = cdf.len() - 1;
583    if n == 0 {
584        return;
585    }
586
587    let count = u32::from(cdf[n]);
588    let rate = 3 + (count >> 4);
589    let rate = rate.min(32);
590
591    for i in 0..n {
592        if i < symbol {
593            let diff = cdf[i] >> rate;
594            cdf[i] = cdf[i].saturating_sub(diff);
595        } else {
596            let diff = (CDF_PROB_TOP - cdf[i]) >> rate;
597            cdf[i] = cdf[i].saturating_add(diff);
598        }
599    }
600
601    if count < 32 {
602        cdf[n] += 1;
603    }
604}
605
606/// Reset CDF to uniform distribution.
607#[allow(clippy::cast_possible_truncation)]
608pub fn reset_cdf(cdf: &mut [u16]) {
609    let n = cdf.len() - 1;
610    if n == 0 {
611        return;
612    }
613
614    for i in 0..n {
615        cdf[i] = (((i + 1) * (CDF_PROB_TOP as usize)) / n) as u16;
616    }
617    cdf[n] = 0; // Reset count
618}
619
620// =============================================================================
621// Utility Constants and Functions
622// =============================================================================
623
/// Initial CDF for a two-symbol (boolean) alphabet.
///
/// Layout: the first symbol ends at `0x4000`, the table ends at `0x7FFF`,
/// and the trailing adaptation counter starts at 0.
///
/// NOTE(review): `uniform_cdf(2)` ends at `0x8000` rather than `0x7FFF`;
/// confirm which terminal value the coder expects.
pub const DEFAULT_BOOL_CDF: [u16; 3] = [16384, 32767, 0];
626
/// Build a uniform CDF over `n` symbols.
///
/// Entry `i` (0-based) is `(i + 1) * 0x8000 / n`, so the last probability
/// entry is exactly `0x8000`. A trailing adaptation counter of 0 is appended,
/// and `n == 0` yields just the counter.
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub fn uniform_cdf(n: usize) -> Vec<u16> {
    (1..=n)
        .map(|k| (k * 0x8000 / n) as u16)
        .chain(std::iter::once(0))
        .collect()
}
638
/// Probability mass, in CDF units, that `cdf` assigns to `symbol`.
///
/// `cdf` is a cumulative table followed by one trailing adaptation counter.
/// Returns 0 in these cases:
/// - `symbol` lies outside the alphabet (including the counter slot);
/// - the table is empty or holds only the counter;
/// - the table is non-monotonic at `symbol` (rather than underflowing).
#[must_use]
pub fn cdf_to_prob(cdf: &[u16], symbol: usize) -> u16 {
    let num_symbols = cdf.len().saturating_sub(1);
    if symbol >= num_symbols {
        return 0;
    }

    let upper = cdf[symbol];
    let lower = if symbol == 0 { 0 } else { cdf[symbol - 1] };
    upper.saturating_sub(lower)
}
650
651/// Compute entropy of a CDF in bits.
652#[must_use]
653pub fn cdf_entropy(cdf: &[u16]) -> f64 {
654    let n = cdf.len() - 1;
655    if n == 0 {
656        return 0.0;
657    }
658
659    let mut entropy = 0.0;
660    let scale = f64::from(CDF_PROB_TOP);
661
662    for i in 0..n {
663        let prob = cdf_to_prob(cdf, i);
664        if prob > 0 {
665            let p = f64::from(prob) / scale;
666            entropy -= p * p.log2();
667        }
668    }
669
670    entropy
671}
672
673// =============================================================================
674// Tests
675// =============================================================================
676
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_arithmetic_decoder_new() {
        let decoder = ArithmeticDecoder::new(vec![0x12, 0x34]);
        assert_eq!(decoder.position, 0);
    }

    #[test]
    fn test_arithmetic_encoder_new() {
        let encoder = ArithmeticEncoder::new();
        assert!(encoder.output.is_empty());
    }

    #[test]
    fn test_uniform_cdf() {
        let cdf = uniform_cdf(4);
        assert_eq!(cdf.len(), 5); // 4 symbols + count
        assert_eq!(cdf[0], 0x2000);
        assert_eq!(cdf[1], 0x4000);
        assert_eq!(cdf[2], 0x6000);
        assert_eq!(cdf[3], 0x8000);
        assert_eq!(cdf[4], 0); // Count
    }

    #[test]
    fn test_symbol_reader_new() {
        let reader = SymbolReader::new(vec![0x12, 0x34, 0x56, 0x78]);
        assert!(reader.has_more_data());
    }

    #[test]
    fn test_symbol_writer_new() {
        let writer = SymbolWriter::new();
        let output = writer.finish();
        // Nothing was written, so finishing must not emit any bytes.
        assert!(output.is_empty());
    }

    #[test]
    fn test_update_cdf() {
        let mut cdf = uniform_cdf(4);
        let orig_0 = cdf[0];

        update_cdf(&mut cdf, 0);

        // Symbol 0 should have increased probability.
        assert!(cdf[0] > orig_0);
        // With a zero counter the adaptation shift is 3: every entry moves
        // 1/8 of the way towards the top, and the counter is bumped.
        assert_eq!(cdf, vec![11264, 18432, 25600, 32768, 1]);
    }

    #[test]
    fn test_reset_cdf() {
        let mut cdf = vec![100u16, 200, 300, 32768, 10];

        reset_cdf(&mut cdf);

        assert_eq!(cdf[0], 8192);
        assert_eq!(cdf[3], 32768);
        assert_eq!(cdf[4], 0); // Count reset
    }

    #[test]
    fn test_cdf_to_prob() {
        let cdf = uniform_cdf(4);

        let prob0 = cdf_to_prob(&cdf, 0);
        let prob1 = cdf_to_prob(&cdf, 1);

        assert_eq!(prob0, 0x2000);
        assert_eq!(prob1, 0x2000);
    }

    #[test]
    fn test_cdf_entropy() {
        let cdf = uniform_cdf(4);
        let entropy = cdf_entropy(&cdf);

        // Entropy of uniform distribution over 4 symbols should be 2 bits
        assert!((entropy - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_symbol_reader_read_literal() {
        let mut reader = SymbolReader::new(vec![0xFF, 0x00, 0xFF, 0x00]);

        // Literal reads start at the first byte of the buffer, MSB first.
        assert_eq!(reader.read_literal(8), 0xFF);
        assert_eq!(reader.read_literal(8), 0x00);
        assert_eq!(reader.read_literal(4), 0xF);
    }

    #[test]
    fn test_symbol_reader_remaining() {
        let reader = SymbolReader::new(vec![0x12, 0x34, 0x56, 0x78]);
        // Priming the decoder consumes 15 bits, i.e. the first two bytes.
        assert_eq!(reader.remaining(), 2);
    }

    #[test]
    fn test_symbol_reader_position() {
        let reader = SymbolReader::new(vec![0x12, 0x34, 0x56, 0x78]);
        // Priming the decoder consumes 15 bits, i.e. the first two bytes.
        assert_eq!(reader.position(), 2);
    }

    #[test]
    fn test_default_bool_cdf() {
        assert_eq!(DEFAULT_BOOL_CDF[0], 0x4000);
        assert_eq!(DEFAULT_BOOL_CDF[1], 0x7FFF);
        assert_eq!(DEFAULT_BOOL_CDF[2], 0);
    }

    #[test]
    fn test_constants() {
        assert_eq!(RANGE_BITS, 16);
        assert_eq!(RANGE_MIN, 0x8000);
        assert_eq!(VALUE_BITS, 16);
    }

    #[test]
    fn test_symbol_writer_write_literal() {
        let mut writer = SymbolWriter::new();
        writer.write_literal(0xAB, 8);
        let output = writer.finish();

        // Output should be exactly the literal byte.
        assert_eq!(output, vec![0xAB]);
    }

    #[test]
    fn test_symbol_reader_read_ns() {
        let mut reader = SymbolReader::new(vec![0x00, 0x00, 0x00, 0x00]);

        // NS coding with n=1 should return 0
        let val = reader.read_ns(1);
        assert_eq!(val, 0);
    }

    #[test]
    fn test_symbol_writer_write_ns() {
        let mut writer = SymbolWriter::new();
        writer.write_ns(5, 10);
        let output = writer.finish();

        // n = 10 gives a 3-bit short code for 5 (0b101), zero-padded to a byte.
        assert_eq!(output, vec![0xA0]);
    }

    #[test]
    fn test_ns_round_trip() {
        for v in 0..10 {
            let mut writer = SymbolWriter::new();
            writer.write_ns(v, 10);
            let mut reader = SymbolReader::new(writer.finish());
            assert_eq!(reader.read_ns(10), v);
        }
    }
}