Skip to main content

oximedia_codec/av1/
entropy_encoder.rs

1//! AV1 entropy encoding with arithmetic coder and CDF context.
2//!
3//! This module implements the entropy encoding system for AV1, including:
4//!
5//! - Arithmetic encoding with range coder
6//! - CDF (Cumulative Distribution Function) context management
7//! - Symbol encoding with adaptive probabilities
8//! - Bitstream output with OBU framing
9//!
10//! # AV1 Entropy Coding
11//!
12//! AV1 uses an arithmetic coder with symbol probabilities stored as CDFs.
13//! The probabilities are adapted based on previously encoded symbols to
14//! improve compression efficiency.
15//!
16//! # References
17//!
18//! - AV1 Specification Section 8.2: Arithmetic Coding Engine
19//! - AV1 Specification Section 8.3: Symbol Decoding Process
20
21#![forbid(unsafe_code)]
22#![allow(dead_code)]
23#![allow(clippy::cast_possible_truncation)]
24#![allow(clippy::cast_precision_loss)]
25#![allow(clippy::cast_sign_loss)]
26#![allow(clippy::similar_names)]
27#![allow(clippy::too_many_arguments)]
28
29use super::entropy_tables::{CDF_PROB_BITS, CDF_PROB_TOP};
30
31// =============================================================================
32// Constants
33// =============================================================================
34
/// Probability precision parameter of the arithmetic coder, in bits.
const EC_PROB_SHIFT: u32 = 6;

/// Window size of the arithmetic coder (2^16); also the coder's initial range.
const EC_WINDOW_SIZE: u32 = 1 << 16;

/// Renormalization threshold for the arithmetic coder's range.
const EC_MIN_RANGE: u32 = 128;

/// Largest symbol value of the default alphabet, which therefore holds
/// `MAX_SYMBOL_VALUE + 1` symbols.
const MAX_SYMBOL_VALUE: u16 = 15;

/// Right shift applied to the adaptation step when a CDF is updated.
///
/// The step is the distance to the target shifted right by this amount, so a
/// larger value means smaller steps and therefore slower adaptation.
const CDF_UPDATE_RATE: u16 = 5;
49
50// =============================================================================
51// Arithmetic Encoder
52// =============================================================================
53
54/// Arithmetic encoder state.
55#[derive(Clone, Debug)]
56pub struct ArithmeticEncoder {
57    /// Current range.
58    range: u32,
59    /// Low value (accumulated bits).
60    low: u32,
61    /// Number of outstanding bits.
62    cnt: i32,
63    /// Output buffer.
64    buffer: Vec<u8>,
65}
66
67impl Default for ArithmeticEncoder {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl ArithmeticEncoder {
74    /// Create a new arithmetic encoder.
75    #[must_use]
76    pub fn new() -> Self {
77        Self {
78            range: EC_WINDOW_SIZE,
79            low: 0,
80            cnt: -9,
81            buffer: Vec::with_capacity(4096),
82        }
83    }
84
85    /// Encode a symbol using CDF.
86    ///
87    /// # Arguments
88    ///
89    /// * `symbol` - Symbol value to encode
90    /// * `cdf` - Cumulative distribution function
91    pub fn encode_symbol(&mut self, symbol: u16, cdf: &[u16]) {
92        assert!(symbol < cdf.len() as u16 - 1, "Symbol out of range");
93
94        let fl = u32::from(if symbol == 0 {
95            0
96        } else {
97            cdf[symbol as usize - 1]
98        });
99        let fh = u32::from(cdf[symbol as usize]);
100        let _ft = u32::from(cdf[cdf.len() - 1]);
101
102        // Compute new range
103        let u = self.range;
104        let v = ((u >> 8) * (fh - fl)) >> (CDF_PROB_BITS - 8);
105        let r_new = if v < EC_MIN_RANGE { EC_MIN_RANGE } else { v };
106
107        // Update low value
108        self.low += ((u >> 8) * fl) >> (CDF_PROB_BITS - 8);
109        self.range = r_new;
110
111        // Renormalize if needed
112        self.renormalize();
113    }
114
115    /// Encode a binary symbol (0 or 1).
116    pub fn encode_bool(&mut self, symbol: bool, prob: u16) {
117        // For boolean: CDF needs 3 values for 2 symbols: [0], [prob(0)], [total]
118        // But we store cumulative, so: [prob_0, prob_0 + prob_1] where prob_0 + prob_1 = total
119        let cdf = [CDF_PROB_TOP - prob, CDF_PROB_TOP, CDF_PROB_TOP];
120        let symbol_val = u16::from(symbol);
121        self.encode_symbol(symbol_val, &cdf);
122    }
123
124    /// Encode a literal value with uniform distribution.
125    pub fn encode_literal(&mut self, value: u32, num_bits: u8) {
126        for i in (0..num_bits).rev() {
127            let bit = (value >> i) & 1;
128            self.encode_bool(bit != 0, CDF_PROB_TOP / 2);
129        }
130    }
131
132    /// Renormalize the encoder state.
133    fn renormalize(&mut self) {
134        while self.range < EC_MIN_RANGE {
135            let c = (self.low >> 23) as u8;
136            self.buffer.push(c);
137
138            self.low = (self.low << 8) & 0x7F_FF_FF;
139            self.range <<= 8;
140            self.cnt += 8;
141        }
142    }
143
144    /// Flush encoder and get output bytes.
145    pub fn flush(&mut self) -> Vec<u8> {
146        // Final renormalization - output accumulated bits
147        while self.cnt >= 0 {
148            let c = (self.low >> 23) as u8;
149            self.buffer.push(c);
150            self.low = (self.low << 8) & 0x7F_FF_FF;
151            self.cnt -= 8;
152        }
153
154        // Output the remaining partial byte from the encoder state
155        let c = (self.low >> 23) as u8;
156        self.buffer.push(c);
157
158        // Ensure byte alignment (pad to multiple of 4)
159        while self.buffer.len() % 4 != 0 {
160            self.buffer.push(0);
161        }
162
163        std::mem::take(&mut self.buffer)
164    }
165
166    /// Get current buffer without flushing.
167    #[must_use]
168    pub fn buffer(&self) -> &[u8] {
169        &self.buffer
170    }
171
172    /// Reset encoder state.
173    pub fn reset(&mut self) {
174        self.range = EC_WINDOW_SIZE;
175        self.low = 0;
176        self.cnt = -9;
177        self.buffer.clear();
178    }
179}
180
181// =============================================================================
182// CDF Context Management
183// =============================================================================
184
185/// CDF (Cumulative Distribution Function) for symbol probabilities.
186#[derive(Clone, Debug)]
187pub struct CdfContext {
188    /// CDF values (cumulative probabilities).
189    cdf: Vec<u16>,
190    /// Number of symbols in alphabet.
191    nsymb: usize,
192}
193
194impl CdfContext {
195    /// Create a new CDF context with uniform distribution.
196    #[must_use]
197    pub fn new(nsymb: usize) -> Self {
198        let mut cdf = Vec::with_capacity(nsymb + 1);
199        let step = CDF_PROB_TOP / nsymb as u16;
200
201        for i in 0..nsymb {
202            cdf.push(step * (i as u16 + 1));
203        }
204        cdf[nsymb - 1] = CDF_PROB_TOP;
205
206        Self { cdf, nsymb }
207    }
208
209    /// Get CDF slice.
210    #[must_use]
211    pub fn cdf(&self) -> &[u16] {
212        &self.cdf
213    }
214
215    /// Update CDF based on observed symbol.
216    pub fn update(&mut self, symbol: u16) {
217        if symbol >= self.nsymb as u16 {
218            return;
219        }
220
221        // Adaptive CDF update using moving average
222        for i in symbol as usize..self.nsymb {
223            let delta = CDF_PROB_TOP.saturating_sub(self.cdf[i]) >> CDF_UPDATE_RATE;
224            self.cdf[i] = self.cdf[i].saturating_add(delta);
225        }
226
227        // Ensure last value is always CDF_PROB_TOP
228        self.cdf[self.nsymb - 1] = CDF_PROB_TOP;
229    }
230
231    /// Reset CDF to uniform distribution.
232    pub fn reset(&mut self) {
233        let step = CDF_PROB_TOP / self.nsymb as u16;
234        for i in 0..self.nsymb {
235            self.cdf[i] = step * (i as u16 + 1);
236        }
237        self.cdf[self.nsymb - 1] = CDF_PROB_TOP;
238    }
239}
240
241// =============================================================================
242// Symbol Encoder
243// =============================================================================
244
245/// High-level symbol encoder with CDF management.
246#[derive(Clone, Debug)]
247pub struct SymbolEncoder {
248    /// Arithmetic encoder.
249    encoder: ArithmeticEncoder,
250    /// CDF contexts for different symbol types.
251    contexts: Vec<CdfContext>,
252}
253
254impl Default for SymbolEncoder {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
260impl SymbolEncoder {
261    /// Create a new symbol encoder.
262    #[must_use]
263    pub fn new() -> Self {
264        Self {
265            encoder: ArithmeticEncoder::new(),
266            contexts: Vec::new(),
267        }
268    }
269
270    /// Initialize contexts for encoding.
271    pub fn init_contexts(&mut self, num_contexts: usize, nsymb: usize) {
272        self.contexts.clear();
273        for _ in 0..num_contexts {
274            self.contexts.push(CdfContext::new(nsymb));
275        }
276    }
277
278    /// Encode a symbol with given context.
279    pub fn encode(&mut self, symbol: u16, context_id: usize) {
280        if context_id >= self.contexts.len() {
281            // Use default uniform CDF
282            let cdf = CdfContext::new(MAX_SYMBOL_VALUE as usize + 1);
283            self.encoder.encode_symbol(symbol, cdf.cdf());
284            return;
285        }
286
287        let cdf = self.contexts[context_id].cdf().to_vec();
288        self.encoder.encode_symbol(symbol, &cdf);
289
290        // Update CDF after encoding
291        self.contexts[context_id].update(symbol);
292    }
293
294    /// Encode a boolean value.
295    pub fn encode_bool(&mut self, value: bool) {
296        self.encoder.encode_bool(value, CDF_PROB_TOP / 2);
297    }
298
299    /// Encode a literal value.
300    pub fn encode_literal(&mut self, value: u32, num_bits: u8) {
301        self.encoder.encode_literal(value, num_bits);
302    }
303
304    /// Finish encoding and get output.
305    pub fn finish(&mut self) -> Vec<u8> {
306        self.encoder.flush()
307    }
308
309    /// Get current output without finishing.
310    #[must_use]
311    pub fn buffer(&self) -> &[u8] {
312        self.encoder.buffer()
313    }
314
315    /// Reset encoder and contexts.
316    pub fn reset(&mut self) {
317        self.encoder.reset();
318        for ctx in &mut self.contexts {
319            ctx.reset();
320        }
321    }
322}
323
324// =============================================================================
325// Bitstream Writer
326// =============================================================================
327
/// Bitstream writer for byte-aligned output.
///
/// Bits are packed most-significant-bit first into `current_byte`; a byte is
/// appended to the buffer once it holds eight bits or when the writer is
/// aligned.
#[derive(Clone, Debug, Default)]
pub struct BitstreamWriter {
    /// Output buffer (completed bytes only).
    buffer: Vec<u8>,
    /// Current byte being written.
    current_byte: u8,
    /// Number of bits written in current byte (0..=7).
    bit_pos: u8,
}

impl BitstreamWriter {
    /// Create a new, empty bitstream writer.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Write a single bit.
    pub fn write_bit(&mut self, bit: bool) {
        if bit {
            self.current_byte |= 0x80 >> self.bit_pos;
        }

        self.bit_pos += 1;
        if self.bit_pos == 8 {
            self.push_current();
        }
    }

    /// Write the low `num_bits` bits of `value`, most significant bit first.
    ///
    /// Bit positions at or above 32 do not exist in `value` and are written
    /// as zero.
    pub fn write_bits(&mut self, value: u32, num_bits: u8) {
        for shift in (0..u32::from(num_bits)).rev() {
            // `checked_shr` yields `None` for shifts >= 32 instead of overflowing.
            let bit = value.checked_shr(shift).map_or(false, |v| v & 1 != 0);
            self.write_bit(bit);
        }
    }

    /// Write a byte-aligned value, zero-padding any partial byte first.
    pub fn write_byte(&mut self, byte: u8) {
        self.align();
        self.buffer.push(byte);
    }

    /// Align to byte boundary, padding a partial byte with zero bits.
    pub fn align(&mut self) {
        if self.bit_pos != 0 {
            self.push_current();
        }
    }

    /// Write a slice of bytes, zero-padding any partial byte first.
    pub fn write_bytes(&mut self, bytes: &[u8]) {
        self.align();
        self.buffer.extend_from_slice(bytes);
    }

    /// Get output buffer (excludes a partially written byte).
    #[must_use]
    pub fn buffer(&self) -> &[u8] {
        &self.buffer
    }

    /// Consume writer and get output, including any zero-padded partial byte.
    #[must_use]
    pub fn finish(mut self) -> Vec<u8> {
        self.align();
        self.buffer
    }

    /// Get buffer length in bytes, counting a partially written byte as one.
    #[must_use]
    pub fn len(&self) -> usize {
        self.buffer.len() + usize::from(self.bit_pos > 0)
    }

    /// Check if writer is empty (no completed bytes and no pending bits).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty() && self.bit_pos == 0
    }

    /// Reset writer, discarding all output.
    pub fn reset(&mut self) {
        self.buffer.clear();
        self.current_byte = 0;
        self.bit_pos = 0;
    }

    /// Append the byte under construction to the buffer and start a new one.
    fn push_current(&mut self) {
        self.buffer.push(self.current_byte);
        self.current_byte = 0;
        self.bit_pos = 0;
    }
}
425
426// =============================================================================
427// OBU Writer
428// =============================================================================
429
430/// OBU (Open Bitstream Unit) writer.
431#[derive(Clone, Debug)]
432pub struct ObuWriter {
433    /// Bitstream writer.
434    writer: BitstreamWriter,
435}
436
437impl Default for ObuWriter {
438    fn default() -> Self {
439        Self::new()
440    }
441}
442
443impl ObuWriter {
444    /// Create a new OBU writer.
445    #[must_use]
446    pub fn new() -> Self {
447        Self {
448            writer: BitstreamWriter::new(),
449        }
450    }
451
452    /// Write OBU header.
453    pub fn write_obu_header(&mut self, obu_type: u8, has_size: bool) {
454        // Forbidden bit
455        self.writer.write_bit(false);
456
457        // OBU type (4 bits)
458        self.writer.write_bits(u32::from(obu_type), 4);
459
460        // Extension flag
461        self.writer.write_bit(false);
462
463        // Has size flag
464        self.writer.write_bit(has_size);
465
466        // Reserved bit
467        self.writer.write_bit(false);
468    }
469
470    /// Write LEB128 encoded size.
471    pub fn write_leb128(&mut self, mut value: u64) {
472        loop {
473            let mut byte = (value & 0x7F) as u8;
474            value >>= 7;
475
476            if value != 0 {
477                byte |= 0x80;
478            }
479
480            self.writer.write_byte(byte);
481
482            if value == 0 {
483                break;
484            }
485        }
486    }
487
488    /// Write OBU with size field.
489    pub fn write_obu(&mut self, obu_type: u8, payload: &[u8]) {
490        self.write_obu_header(obu_type, true);
491        self.write_leb128(payload.len() as u64);
492        self.writer.write_bytes(payload);
493    }
494
495    /// Get output buffer.
496    #[must_use]
497    pub fn buffer(&self) -> &[u8] {
498        self.writer.buffer()
499    }
500
501    /// Finish and get output.
502    #[must_use]
503    pub fn finish(self) -> Vec<u8> {
504        self.writer.finish()
505    }
506}
507
508// =============================================================================
509// Utility Functions
510// =============================================================================
511
512/// Compute CDF from probability mass function (PMF).
513#[must_use]
514pub fn pmf_to_cdf(pmf: &[u16]) -> Vec<u16> {
515    let mut cdf = Vec::with_capacity(pmf.len());
516    let mut cumsum = 0u16;
517
518    for &p in pmf {
519        cumsum = cumsum.saturating_add(p);
520        cdf.push(cumsum);
521    }
522
523    // Normalize to CDF_PROB_TOP
524    if let Some(&last) = cdf.last() {
525        if last > 0 && last != CDF_PROB_TOP {
526            for val in &mut cdf {
527                *val = (u32::from(*val) * u32::from(CDF_PROB_TOP) / u32::from(last)) as u16;
528            }
529        }
530    }
531
532    cdf
533}
534
535/// Estimate rate for symbol given CDF.
536#[must_use]
537pub fn estimate_symbol_rate(symbol: u16, cdf: &[u16]) -> f32 {
538    if symbol >= cdf.len() as u16 {
539        return 8.0; // Default rate for unknown symbol
540    }
541
542    let fl = if symbol == 0 {
543        0
544    } else {
545        cdf[symbol as usize - 1]
546    };
547    let fh = cdf[symbol as usize];
548    let prob = fh.saturating_sub(fl);
549
550    if prob == 0 {
551        16.0 // Very unlikely symbol
552    } else {
553        -(f32::from(prob) / f32::from(CDF_PROB_TOP)).log2()
554    }
555}
556
557// =============================================================================
558// Tests
559// =============================================================================
560
#[cfg(test)]
mod tests {
    use super::*;

    /// True when every entry is at least as large as its predecessor.
    fn is_non_decreasing(values: &[u16]) -> bool {
        values.windows(2).all(|pair| pair[0] <= pair[1])
    }

    #[test]
    fn test_arithmetic_encoder_creation() {
        let encoder = ArithmeticEncoder::new();
        assert_eq!(encoder.range, EC_WINDOW_SIZE);
        assert_eq!(encoder.low, 0);
        assert!(encoder.buffer.is_empty());
    }

    #[test]
    fn test_arithmetic_encoder_bool() {
        let mut encoder = ArithmeticEncoder::new();
        encoder.encode_bool(true, CDF_PROB_TOP / 2);
        encoder.encode_bool(false, CDF_PROB_TOP / 2);

        let output = encoder.flush();
        assert!(!output.is_empty());
        // flush pads its output to a multiple of four bytes
        assert_eq!(output.len() % 4, 0);
    }

    #[test]
    fn test_arithmetic_encoder_literal() {
        let mut encoder = ArithmeticEncoder::new();
        encoder.encode_literal(0xFF, 8);

        let output = encoder.flush();
        assert!(!output.is_empty());
        assert_eq!(output.len() % 4, 0);
    }

    #[test]
    fn test_cdf_context_creation() {
        let cdf = CdfContext::new(4);
        assert_eq!(cdf.nsymb, 4);
        assert_eq!(cdf.cdf().len(), 4);
        assert!(is_non_decreasing(cdf.cdf()));
        assert_eq!(
            *cdf.cdf().last().expect("should have last element"),
            CDF_PROB_TOP
        );
    }

    #[test]
    fn test_cdf_context_update() {
        let mut cdf = CdfContext::new(4);
        let initial_cdf = cdf.cdf().to_vec();

        cdf.update(1);
        let updated_cdf = cdf.cdf();

        // CDF should change after update
        assert_ne!(initial_cdf, updated_cdf);
        // ...and must remain a valid cumulative distribution
        assert!(is_non_decreasing(updated_cdf));
        assert_eq!(
            *updated_cdf.last().expect("should have last element"),
            CDF_PROB_TOP
        );
        // The coded symbol must not lose probability mass
        let mass_before = initial_cdf[1] - initial_cdf[0];
        let mass_after = updated_cdf[1] - updated_cdf[0];
        assert!(mass_after >= mass_before);
    }

    #[test]
    fn test_cdf_context_reset() {
        let mut cdf = CdfContext::new(4);
        let initial_cdf = cdf.cdf().to_vec();

        cdf.update(1);
        cdf.update(2);
        cdf.reset();

        assert_eq!(cdf.cdf(), &initial_cdf[..]);
    }

    #[test]
    fn test_symbol_encoder() {
        let mut encoder = SymbolEncoder::new();
        encoder.init_contexts(4, 8);

        encoder.encode(0, 0);
        encoder.encode(1, 0);
        encoder.encode(2, 1);

        let output = encoder.finish();
        assert!(!output.is_empty());
    }

    #[test]
    fn test_symbol_encoder_bool() {
        let mut encoder = SymbolEncoder::new();
        encoder.encode_bool(true);
        encoder.encode_bool(false);
        encoder.encode_bool(true);

        let output = encoder.finish();
        assert!(!output.is_empty());
    }

    #[test]
    fn test_bitstream_writer_bit() {
        let mut writer = BitstreamWriter::new();
        for &bit in &[true, false, true, true, false, false, false, true] {
            writer.write_bit(bit);
        }

        let output = writer.finish();
        assert_eq!(output, vec![0b1011_0001u8]);
    }

    #[test]
    fn test_bitstream_writer_bits() {
        let mut writer = BitstreamWriter::new();
        writer.write_bits(0xFF, 8);

        let output = writer.finish();
        assert_eq!(output, vec![0xFFu8]);
    }

    #[test]
    fn test_bitstream_writer_align() {
        let mut writer = BitstreamWriter::new();
        writer.write_bit(true);
        writer.write_bit(false);
        writer.align();

        // The partial byte is padded with zero bits
        let output = writer.finish();
        assert_eq!(output, vec![0b1000_0000u8]);
    }

    #[test]
    fn test_bitstream_writer_bytes() {
        let mut writer = BitstreamWriter::new();
        writer.write_bytes(&[0xAB, 0xCD, 0xEF]);

        let output = writer.finish();
        assert_eq!(output, vec![0xABu8, 0xCD, 0xEF]);
    }

    #[test]
    fn test_obu_writer_header() {
        let mut writer = ObuWriter::new();
        writer.write_obu_header(1, true);

        // forbidden=0, type=0001, extension=0, has_size=1, reserved=0
        assert_eq!(writer.buffer(), &[0b0000_1010u8][..]);
    }

    #[test]
    fn test_obu_writer_leb128() {
        let mut writer = ObuWriter::new();
        writer.write_leb128(127);
        assert_eq!(writer.buffer(), &[127u8][..]);

        let mut writer2 = ObuWriter::new();
        writer2.write_leb128(128);
        assert_eq!(writer2.buffer(), &[0x80u8, 0x01][..]);
    }

    #[test]
    fn test_obu_writer_complete() {
        let mut writer = ObuWriter::new();
        let payload = vec![1, 2, 3, 4];
        writer.write_obu(1, &payload);

        // header byte, LEB128 length, payload
        let output = writer.finish();
        assert_eq!(output, vec![0x0Au8, 0x04, 1, 2, 3, 4]);
    }

    #[test]
    fn test_pmf_to_cdf() {
        let pmf = vec![100, 200, 300, 400];
        let cdf = pmf_to_cdf(&pmf);

        assert_eq!(cdf.len(), 4);
        // The result is normalized so that the total is CDF_PROB_TOP
        assert_eq!(
            *cdf.last().expect("should have last element"),
            CDF_PROB_TOP
        );
        assert!(is_non_decreasing(&cdf));
    }

    #[test]
    fn test_estimate_symbol_rate() {
        // Symbol masses: 100, 200, 300, remainder.
        let cdf = vec![100, 300, 600, CDF_PROB_TOP];

        let rate0 = estimate_symbol_rate(0, &cdf);
        let rate1 = estimate_symbol_rate(1, &cdf);
        let rate2 = estimate_symbol_rate(2, &cdf);

        assert!(rate0 > 0.0);
        assert!(rate1 > 0.0);
        // More probable symbol should have lower rate
        assert!(rate1 < rate0);
        assert!(rate2 < rate1);
    }

    #[test]
    fn test_bitstream_writer_len() {
        let mut writer = BitstreamWriter::new();
        assert_eq!(writer.len(), 0);
        assert!(writer.is_empty());

        writer.write_byte(0xFF);
        assert_eq!(writer.len(), 1);
        assert!(!writer.is_empty());

        // A partially filled byte counts toward the length
        writer.write_bit(true);
        assert_eq!(writer.len(), 2);
    }

    #[test]
    fn test_symbol_encoder_reset() {
        let mut encoder = SymbolEncoder::new();
        encoder.init_contexts(2, 4);
        encoder.encode(1, 0);

        encoder.reset();
        assert!(encoder.buffer().is_empty());
    }

    #[test]
    fn test_arithmetic_encoder_reset() {
        let mut encoder = ArithmeticEncoder::new();
        encoder.encode_bool(true, CDF_PROB_TOP / 2);

        encoder.reset();
        assert_eq!(encoder.range, EC_WINDOW_SIZE);
        assert_eq!(encoder.low, 0);
        assert!(encoder.buffer.is_empty());
    }
}