Skip to main content

oximedia_codec/av1/
symbols.rs

1//! AV1 symbol decoding from entropy-coded bitstream.
2//!
3//! This module provides high-level symbol decoding for AV1 syntax elements
4//! using the entropy coding engine. It handles:
5//!
6//! - Partition decoding
7//! - Mode decoding (intra/inter)
8//! - Motion vector decoding
9//! - Transform type and size selection
10//! - Skip and skip_mode flags
11//! - Segmentation and quantization selection
12//!
13//! # Symbol Decoding Flow
14//!
15//! 1. **Block-level symbols** - Partition, skip, segment_id
16//! 2. **Mode symbols** - Intra/inter mode selection
17//! 3. **Reference frame symbols** - For inter blocks
18//! 4. **Motion vector symbols** - MV decoding
19//! 5. **Transform symbols** - TX size and type
20//! 6. **Coefficient symbols** - Via coeff_decode module
21//!
22//! # Context Modeling
23//!
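//! AV1 uses adaptive, context-dependent probability models whose contexts
//! are computed from neighboring blocks and frame-level state. The
//! `compute_*_context` helpers below derive such contexts, but the current
//! `SymbolDecoder`/`SymbolEncoder` methods still code every element with a
//! fixed, locally built CDF and ignore their `ctx` arguments.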
26
27#![forbid(unsafe_code)]
28#![allow(dead_code)]
29#![allow(clippy::doc_markdown)]
30#![allow(clippy::unused_self)]
31#![allow(clippy::missing_errors_doc)]
32#![allow(clippy::similar_names)]
33#![allow(clippy::cast_possible_truncation)]
34#![allow(clippy::cast_sign_loss)]
35#![allow(clippy::too_many_arguments)]
36#![allow(clippy::struct_excessive_bools)]
37#![allow(clippy::module_name_repetitions)]
38
39use super::block::{BlockModeInfo, BlockSize, InterMode, IntraMode, PartitionType};
40use super::entropy::{uniform_cdf, SymbolReader};
41use super::entropy_tables::CdfContext;
42use super::transform::{TxSize, TxType};
43use crate::error::{CodecError, CodecResult};
44
45// =============================================================================
46// Constants
47// =============================================================================
48
/// Number of partition contexts (one per above/left split combination).
pub const PARTITION_CONTEXTS: usize = 4;

/// Number of skip-flag contexts.
pub const SKIP_CONTEXTS: usize = 3;

/// Number of luma intra-mode contexts.
pub const INTRA_MODE_CONTEXTS: usize = 5;

/// Number of inter-mode contexts.
pub const INTER_MODE_CONTEXTS: usize = 7;

/// Number of reference-frame contexts.
pub const REF_CONTEXTS: usize = 3;

/// Number of motion-vector contexts.
pub const MV_CONTEXTS: usize = 2;

/// Number of transform-size contexts.
pub const TX_SIZE_CONTEXTS: usize = 4;

/// Number of transform-type contexts.
pub const TX_TYPE_CONTEXTS: usize = 4;

/// Number of reference frames selectable by a block.
pub const NUM_REF_FRAMES: usize = 7;

/// Largest motion-vector component magnitude.
pub const MAX_MV_COMPONENT: i16 = 1023;
78
79// =============================================================================
80// Symbol Decoder State
81// =============================================================================
82
83/// Symbol decoder for AV1 syntax elements.
84#[derive(Debug)]
85pub struct SymbolDecoder {
86    /// Underlying symbol reader.
87    reader: SymbolReader,
88    /// CDF context for probability models.
89    cdf_context: CdfContext,
90    /// Current frame is intra-only.
91    frame_is_intra: bool,
92    /// Allow intraBC.
93    allow_intrabc: bool,
94    /// Current segment ID.
95    segment_id: u8,
96}
97
98impl SymbolDecoder {
99    /// Create a new symbol decoder.
100    pub fn new(data: Vec<u8>, frame_is_intra: bool) -> Self {
101        Self {
102            reader: SymbolReader::new(data),
103            cdf_context: CdfContext::new(),
104            frame_is_intra,
105            allow_intrabc: false,
106            segment_id: 0,
107        }
108    }
109
110    /// Read partition symbol.
111    pub fn read_partition(&mut self, bsize: BlockSize, _ctx: u8) -> CodecResult<PartitionType> {
112        if bsize == BlockSize::Block4x4 {
113            // 4x4 blocks cannot be partitioned
114            return Ok(PartitionType::None);
115        }
116
117        let mut cdf = uniform_cdf(10);
118        let symbol = self.reader.read_symbol(&mut cdf);
119        PartitionType::from_u8(symbol as u8)
120            .ok_or_else(|| CodecError::InvalidBitstream("Invalid partition type".to_string()))
121    }
122
123    /// Read skip flag.
124    pub fn read_skip(&mut self, _ctx: u8) -> bool {
125        // Use uniform CDF for now
126        let mut cdf = [16384u16, 32768, 0];
127        self.reader.read_bool(&mut cdf)
128    }
129
130    /// Read skip_mode flag.
131    pub fn read_skip_mode(&mut self, _ctx: u8) -> bool {
132        if self.frame_is_intra {
133            return false;
134        }
135
136        // Use local CDF array
137        let mut cdf = [16384u16, 32768, 0];
138        self.reader.read_bool(&mut cdf)
139    }
140
141    /// Read segment ID.
142    pub fn read_segment_id(&mut self, _ctx: u8, max_segments: u8) -> u8 {
143        if max_segments == 1 {
144            return 0;
145        }
146
147        let mut cdf = uniform_cdf(max_segments as usize);
148        let segment_id = self.reader.read_symbol(&mut cdf) as u8;
149        self.segment_id = segment_id;
150        segment_id
151    }
152
153    /// Read is_inter flag.
154    pub fn read_is_inter(&mut self, _ctx: u8) -> bool {
155        if self.frame_is_intra {
156            return false;
157        }
158
159        let mut cdf = [16384u16, 32768, 0];
160        self.reader.read_bool(&mut cdf)
161    }
162
163    /// Read intra mode for luma.
164    pub fn read_intra_mode(&mut self, _ctx: u8, _bsize: BlockSize) -> CodecResult<IntraMode> {
165        let mut cdf = uniform_cdf(13);
166        let symbol = self.reader.read_symbol(&mut cdf);
167        IntraMode::from_u8(symbol as u8)
168            .ok_or_else(|| CodecError::InvalidBitstream("Invalid intra mode".to_string()))
169    }
170
171    /// Read intra mode for chroma (UV).
172    pub fn read_uv_mode(&mut self, _y_mode: IntraMode, _ctx: u8) -> CodecResult<IntraMode> {
173        let mut cdf = uniform_cdf(13);
174        let symbol = self.reader.read_symbol(&mut cdf);
175        IntraMode::from_u8(symbol as u8)
176            .ok_or_else(|| CodecError::InvalidBitstream("Invalid UV mode".to_string()))
177    }
178
179    /// Read angle delta for directional intra modes.
180    pub fn read_angle_delta(&mut self, mode: IntraMode) -> i8 {
181        if !mode.is_directional() {
182            return 0;
183        }
184
185        let mut cdf = uniform_cdf(7);
186        let symbol = self.reader.read_symbol(&mut cdf);
187
188        // Map symbol to delta: 0->-3, 1->-2, 2->-1, 3->0, 4->1, 5->2, 6->3
189        (symbol as i8) - 3
190    }
191
192    /// Read palette mode flag.
193    pub fn read_use_palette(&mut self, bsize: BlockSize, _ctx: u8) -> bool {
194        if bsize == BlockSize::Block4x4
195            || bsize == BlockSize::Block4x8
196            || bsize == BlockSize::Block8x4
197        {
198            return false;
199        }
200
201        let mut cdf = [16384u16, 32768, 0];
202        self.reader.read_bool(&mut cdf)
203    }
204
205    /// Read filter intra mode.
206    pub fn read_filter_intra_mode(&mut self) -> u8 {
207        let mut cdf = uniform_cdf(5);
208        self.reader.read_symbol(&mut cdf) as u8
209    }
210
211    /// Read inter mode.
212    pub fn read_inter_mode(&mut self, _ctx: u8) -> CodecResult<InterMode> {
213        let mut cdf = uniform_cdf(4);
214        let symbol = self.reader.read_symbol(&mut cdf);
215        InterMode::from_u8(symbol as u8)
216            .ok_or_else(|| CodecError::InvalidBitstream("Invalid inter mode".to_string()))
217    }
218
219    /// Read reference frame indices.
220    pub fn read_ref_frames(&mut self, _ctx: u8) -> [i8; 2] {
221        if self.frame_is_intra {
222            return [-1, -1];
223        }
224
225        // Read compound mode flag
226        let mut compound_cdf = [16384u16, 32768, 0];
227        let is_compound = self.reader.read_bool(&mut compound_cdf);
228
229        if is_compound {
230            // Read two reference frames
231            let ref0 = self.read_single_ref_frame(0);
232            let ref1 = self.read_single_ref_frame(1);
233            [ref0, ref1]
234        } else {
235            // Single reference
236            let ref0 = self.read_single_ref_frame(0);
237            [ref0, -1]
238        }
239    }
240
241    /// Read a single reference frame index.
242    fn read_single_ref_frame(&mut self, _idx: usize) -> i8 {
243        let mut cdf = uniform_cdf(7);
244        self.reader.read_symbol(&mut cdf) as i8
245    }
246
247    /// Read motion vector.
248    pub fn read_mv(&mut self, ctx: u8) -> [i16; 2] {
249        let row = self.read_mv_component(ctx, true);
250        let col = self.read_mv_component(ctx, false);
251        [row, col]
252    }
253
254    /// Read a single MV component.
255    fn read_mv_component(&mut self, _ctx: u8, _is_row: bool) -> i16 {
256        // Read sign
257        let mut sign_cdf = [16384u16, 32768, 0];
258        let sign = self.reader.read_bool(&mut sign_cdf);
259
260        // Read class (magnitude range) - simplified to uniform
261        let mut class_cdf = uniform_cdf(11);
262        let class = self.reader.read_symbol(&mut class_cdf) as u8;
263
264        // Read bits based on class
265        let mag = self.read_mv_magnitude(class);
266
267        if sign {
268            -(mag as i16)
269        } else {
270            mag as i16
271        }
272    }
273
274    /// Read MV magnitude bits.
275    fn read_mv_magnitude(&mut self, class: u8) -> u16 {
276        match class {
277            0 => 0, // Class 0: magnitude 0
278            1 => 1, // Class 1: magnitude 1
279            _ => {
280                // Classes 2-10: read additional bits
281                let offset_bits = class - 2;
282                let mut mag = 1u16 << (offset_bits + 1);
283
284                for _ in 0..offset_bits {
285                    let mut bit_cdf = [16384u16, 32768, 0];
286                    let bit = self.reader.read_bool(&mut bit_cdf);
287                    mag |= u16::from(bit);
288                    mag <<= 1;
289                }
290
291                mag >> 1
292            }
293        }
294    }
295
296    /// Read transform size.
297    pub fn read_tx_size(&mut self, bsize: BlockSize, _ctx: u8) -> TxSize {
298        let max_tx_size = bsize.max_tx_size();
299
300        // Use uniform CDF
301        let mut cdf = uniform_cdf(5);
302        let symbol = self.reader.read_symbol(&mut cdf);
303
304        // Map symbol to TX size
305        self.map_tx_size_symbol(symbol, max_tx_size)
306    }
307
308    /// Map TX size symbol to actual TX size.
309    fn map_tx_size_symbol(&self, symbol: usize, max_tx_size: TxSize) -> TxSize {
310        match symbol {
311            0 => TxSize::Tx4x4,
312            1 => TxSize::Tx8x8.min(max_tx_size),
313            2 => TxSize::Tx16x16.min(max_tx_size),
314            3 => TxSize::Tx32x32.min(max_tx_size),
315            _ => max_tx_size,
316        }
317    }
318
319    /// Read transform type.
320    pub fn read_tx_type(&mut self, _tx_size: TxSize, _is_inter: bool, _ctx: u8) -> TxType {
321        // Use uniform CDF
322        let mut cdf = uniform_cdf(16);
323        let symbol = self.reader.read_symbol(&mut cdf);
324        TxType::from_u8(symbol as u8).unwrap_or(TxType::DctDct)
325    }
326
327    /// Read compound type for compound prediction.
328    pub fn read_compound_type(&mut self, _ctx: u8) -> u8 {
329        let mut cdf = uniform_cdf(3);
330        self.reader.read_symbol(&mut cdf) as u8
331    }
332
333    /// Read interpolation filter.
334    pub fn read_interp_filter(&mut self, _ctx: u8) -> u8 {
335        let mut cdf = uniform_cdf(4);
336        self.reader.read_symbol(&mut cdf) as u8
337    }
338
339    /// Read motion mode (simple, OBMC, warped).
340    pub fn read_motion_mode(&mut self, bsize: BlockSize, _ctx: u8) -> u8 {
341        if bsize == BlockSize::Block4x4
342            || bsize == BlockSize::Block4x8
343            || bsize == BlockSize::Block8x4
344        {
345            return 0; // Simple motion only for small blocks
346        }
347
348        let mut cdf = uniform_cdf(3);
349        self.reader.read_symbol(&mut cdf) as u8
350    }
351
352    /// Decode complete block mode info.
353    pub fn decode_block_mode(
354        &mut self,
355        bsize: BlockSize,
356        ctx_skip: u8,
357        ctx_mode: u8,
358    ) -> CodecResult<BlockModeInfo> {
359        let mut mode_info = BlockModeInfo::new();
360        mode_info.block_size = bsize;
361
362        // Read skip flag
363        mode_info.skip = self.read_skip(ctx_skip);
364
365        // Read segment ID (if segmentation is enabled)
366        mode_info.segment_id = self.segment_id;
367
368        // Read is_inter
369        mode_info.is_inter = self.read_is_inter(ctx_mode);
370
371        if mode_info.is_inter {
372            // Inter block
373            mode_info.inter_mode = self.read_inter_mode(ctx_mode)?;
374            mode_info.ref_frames = self.read_ref_frames(ctx_mode);
375
376            // Read motion vectors if needed
377            if mode_info.inter_mode.has_newmv() {
378                let mv = self.read_mv(ctx_mode);
379                mode_info.mv[0] = mv;
380            }
381
382            // Read interpolation filter
383            mode_info.interp_filter = [
384                self.read_interp_filter(ctx_mode),
385                self.read_interp_filter(ctx_mode),
386            ];
387
388            // Read motion mode
389            mode_info.motion_mode = self.read_motion_mode(bsize, ctx_mode);
390
391            // Compound prediction
392            if mode_info.is_compound() {
393                mode_info.compound_type = self.read_compound_type(ctx_mode);
394            }
395        } else {
396            // Intra block
397            mode_info.intra_mode = self.read_intra_mode(ctx_mode, bsize)?;
398            mode_info.uv_mode = self.read_uv_mode(mode_info.intra_mode, ctx_mode)?;
399            mode_info.angle_delta = [
400                self.read_angle_delta(mode_info.intra_mode),
401                self.read_angle_delta(mode_info.uv_mode),
402            ];
403
404            // Palette mode
405            mode_info.use_palette = self.read_use_palette(bsize, ctx_mode);
406
407            // Filter intra
408            if bsize.width() <= 32 && bsize.height() <= 32 {
409                mode_info.filter_intra_mode = self.read_filter_intra_mode();
410            }
411        }
412
413        // Read transform size
414        mode_info.tx_size = self.read_tx_size(bsize, ctx_mode);
415        // TX type would be read separately based on TX size
416
417        Ok(mode_info)
418    }
419
420    /// Check if more data is available.
421    pub fn has_more_data(&self) -> bool {
422        self.reader.has_more_data()
423    }
424
425    /// Get current position in bytes.
426    pub fn position(&self) -> usize {
427        self.reader.position()
428    }
429
430    /// Get remaining bytes.
431    pub fn remaining(&self) -> usize {
432        self.reader.remaining()
433    }
434}
435
436// =============================================================================
437// TxSize min/max helpers
438// =============================================================================
439
440impl TxSize {
441    /// Get the minimum of two TX sizes.
442    #[must_use]
443    pub const fn min(self, other: Self) -> Self {
444        let self_area = self.area();
445        let other_area = other.area();
446        if self_area <= other_area {
447            self
448        } else {
449            other
450        }
451    }
452}
453
454// =============================================================================
455// Context Computation Helpers
456// =============================================================================
457
458/// Compute partition context from neighbors.
459#[must_use]
460pub fn compute_partition_context(above: u8, left: u8, bsize: BlockSize) -> u8 {
461    let bs = bsize.width_log2();
462    let above_split = above < bs;
463    let left_split = left < bs;
464
465    match (above_split, left_split) {
466        (false, false) => 0,
467        (true, false) => 1,
468        (false, true) => 2,
469        (true, true) => 3,
470    }
471}
472
/// Compute the skip context as the number of skipped neighbors (above, left).
///
/// Returns 0, 1 or 2.
#[must_use]
pub fn compute_skip_context(above_skip: bool, left_skip: bool) -> u8 {
    [above_skip, left_skip].iter().filter(|&&skipped| skipped).count() as u8
}
482
/// Compute the is_inter context as the number of inter-coded neighbors (above, left).
///
/// Returns 0, 1 or 2.
#[must_use]
pub fn compute_is_inter_context(above_inter: bool, left_inter: bool) -> u8 {
    u8::from(above_inter) + u8::from(left_inter)
}
492
493/// Compute TX size context from neighbors.
494#[must_use]
495pub fn compute_tx_size_context(above_tx: TxSize, left_tx: TxSize, max_tx: TxSize) -> u8 {
496    let above_cat = tx_size_category(above_tx, max_tx);
497    let left_cat = tx_size_category(left_tx, max_tx);
498
499    (above_cat + left_cat).min(3)
500}
501
502/// Categorize TX size relative to max.
503fn tx_size_category(tx: TxSize, max_tx: TxSize) -> u8 {
504    if tx == max_tx {
505        0
506    } else if tx.width() * 2 >= max_tx.width() && tx.height() * 2 >= max_tx.height() {
507        1
508    } else {
509        2
510    }
511}
512
513// =============================================================================
514// MV Prediction and Context
515// =============================================================================
516
/// Motion vector predictor holding up to three candidate MVs.
#[derive(Clone, Copy, Debug, Default)]
pub struct MvPredictor {
    /// Candidate motion vectors; only the first `count` entries are valid.
    pub candidates: [[i16; 2]; 3],
    /// Number of valid candidates (at most 3 when filled via `add_candidate`).
    pub count: usize,
}

impl MvPredictor {
    /// L1 magnitude at or above which a candidate counts as "large" in `compute_context`.
    const LARGE_MV_THRESHOLD: u16 = 16;

    /// Create an empty MV predictor.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            candidates: [[0, 0]; 3],
            count: 0,
        }
    }

    /// Append a candidate MV; candidates beyond the available slots are ignored.
    pub fn add_candidate(&mut self, mv: [i16; 2]) {
        // `get_mut` yields None once `count` reaches the array length, so the
        // capacity follows `candidates` instead of a separate literal.
        if let Some(slot) = self.candidates.get_mut(self.count) {
            *slot = mv;
            self.count += 1;
        }
    }

    /// The nearest MV: the first candidate, or `[0, 0]` when there is none.
    #[must_use]
    pub fn nearest(&self) -> [i16; 2] {
        if self.count > 0 {
            self.candidates[0]
        } else {
            [0, 0]
        }
    }

    /// The near MV: the second candidate, falling back to `nearest()`.
    #[must_use]
    pub fn near(&self) -> [i16; 2] {
        if self.count > 1 {
            self.candidates[1]
        } else {
            self.nearest()
        }
    }

    /// Compute the MV context.
    ///
    /// Returns 1 if either of the first two candidates has an L1 magnitude
    /// of at least 16, and 0 otherwise (including when there are no candidates).
    #[must_use]
    pub fn compute_context(&self) -> u8 {
        if self.count == 0 {
            return 0;
        }

        let mv0_mag = self.mv_magnitude(self.candidates[0]);
        let mv1_mag = if self.count > 1 {
            self.mv_magnitude(self.candidates[1])
        } else {
            0
        };

        u8::from(mv0_mag >= Self::LARGE_MV_THRESHOLD || mv1_mag >= Self::LARGE_MV_THRESHOLD)
    }

    /// L1 magnitude `|mv[0]| + |mv[1]|`, saturating at `u16::MAX`.
    ///
    /// Unsigned absolute values avoid the overflow that `i16::abs()` plus
    /// i16 addition hits for `i16::MIN` components or large sums. That
    /// overflow panics in debug builds and wraps to a wrong value in release.
    fn mv_magnitude(&self, mv: [i16; 2]) -> u16 {
        mv[0].unsigned_abs().saturating_add(mv[1].unsigned_abs())
    }
}
590
591// =============================================================================
592// Symbol Encoding (for completeness)
593// =============================================================================
594
595/// Symbol encoder for AV1 syntax elements.
596#[derive(Debug)]
597pub struct SymbolEncoder {
598    /// Underlying symbol writer.
599    writer: super::entropy::SymbolWriter,
600    /// CDF context.
601    cdf_context: CdfContext,
602}
603
604impl SymbolEncoder {
605    /// Create a new symbol encoder.
606    #[must_use]
607    pub fn new() -> Self {
608        Self {
609            writer: super::entropy::SymbolWriter::new(),
610            cdf_context: CdfContext::new(),
611        }
612    }
613
614    /// Write partition symbol.
615    pub fn write_partition(&mut self, partition: PartitionType, _ctx: u8) {
616        let mut cdf = uniform_cdf(10);
617        self.writer.write_symbol(partition as usize, &mut cdf);
618    }
619
620    /// Write skip flag.
621    pub fn write_skip(&mut self, skip: bool, _ctx: u8) {
622        // Use local CDF array
623        let mut cdf = [16384u16, 32768, 0];
624        self.writer.write_bool(skip, &mut cdf);
625    }
626
627    /// Finalize and get output.
628    #[must_use]
629    pub fn finish(self) -> Vec<u8> {
630        self.writer.finish()
631    }
632}
633
634impl Default for SymbolEncoder {
635    fn default() -> Self {
636        Self::new()
637    }
638}
639
640// =============================================================================
641// Tests
642// =============================================================================
643
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_symbol_decoder_creation() {
        let data = vec![0u8; 128];
        let decoder = SymbolDecoder::new(data, false);
        assert!(decoder.has_more_data());
    }

    #[test]
    fn test_intra_frame_reads_no_inter_syntax() {
        // Inter-only elements must decode to fixed values without consuming bits.
        let mut decoder = SymbolDecoder::new(vec![0u8; 32], true);
        let before = decoder.remaining();
        assert!(!decoder.read_skip_mode(0));
        assert!(!decoder.read_is_inter(0));
        assert_eq!(decoder.read_ref_frames(0), [-1, -1]);
        assert_eq!(decoder.remaining(), before);
    }

    #[test]
    fn test_single_segment_reads_nothing() {
        let mut decoder = SymbolDecoder::new(vec![0u8; 32], false);
        let before = decoder.remaining();
        assert_eq!(decoder.read_segment_id(0, 1), 0);
        assert_eq!(decoder.remaining(), before);
    }

    #[test]
    fn test_4x4_partition_is_none() {
        let mut decoder = SymbolDecoder::new(vec![0u8; 32], false);
        let before = decoder.remaining();
        assert!(matches!(
            decoder.read_partition(BlockSize::Block4x4, 0),
            Ok(PartitionType::None)
        ));
        assert_eq!(decoder.remaining(), before);
    }

    #[test]
    fn test_partition_context() {
        assert_eq!(compute_partition_context(0, 0, BlockSize::Block16x16), 3);
        assert_eq!(compute_partition_context(4, 4, BlockSize::Block16x16), 0);
        assert_eq!(compute_partition_context(4, 3, BlockSize::Block16x16), 2);
    }

    #[test]
    fn test_skip_context() {
        assert_eq!(compute_skip_context(false, false), 0);
        assert_eq!(compute_skip_context(true, false), 1);
        assert_eq!(compute_skip_context(false, true), 1);
        assert_eq!(compute_skip_context(true, true), 2);
    }

    #[test]
    fn test_is_inter_context() {
        assert_eq!(compute_is_inter_context(false, false), 0);
        assert_eq!(compute_is_inter_context(true, false), 1);
        assert_eq!(compute_is_inter_context(false, true), 1);
        assert_eq!(compute_is_inter_context(true, true), 2);
    }

    #[test]
    fn test_tx_size_context() {
        let max_tx = TxSize::Tx16x16;
        // 8x8 is half of 16x16 in both dimensions: category 1 for each neighbor.
        assert_eq!(compute_tx_size_context(TxSize::Tx8x8, TxSize::Tx8x8, max_tx), 2);
        assert_eq!(compute_tx_size_context(max_tx, max_tx, max_tx), 0);
        // Two category-2 neighbors sum to 4, which is capped at 3.
        assert_eq!(compute_tx_size_context(TxSize::Tx4x4, TxSize::Tx4x4, max_tx), 3);
    }

    #[test]
    fn test_mv_predictor() {
        let mut pred = MvPredictor::new();
        assert_eq!(pred.count, 0);

        pred.add_candidate([10, 20]);
        assert_eq!(pred.count, 1);
        assert_eq!(pred.nearest(), [10, 20]);

        pred.add_candidate([5, 15]);
        assert_eq!(pred.count, 2);
        assert_eq!(pred.near(), [5, 15]);
    }

    #[test]
    fn test_mv_predictor_near_falls_back_to_nearest() {
        let mut pred = MvPredictor::new();
        assert_eq!(pred.nearest(), [0, 0]);
        assert_eq!(pred.near(), [0, 0]);

        pred.add_candidate([3, -4]);
        assert_eq!(pred.near(), [3, -4]);
    }

    #[test]
    fn test_mv_predictor_caps_candidates() {
        let mut pred = MvPredictor::new();
        for i in 0..5i16 {
            pred.add_candidate([i, i]);
        }
        assert_eq!(pred.count, 3);
        assert_eq!(pred.candidates, [[0, 0], [1, 1], [2, 2]]);
    }

    #[test]
    fn test_mv_predictor_context() {
        let mut pred = MvPredictor::new();
        assert_eq!(pred.compute_context(), 0);

        // |5| + |10| = 15 is below the threshold of 16.
        pred.add_candidate([5, 10]);
        assert_eq!(pred.compute_context(), 0);

        // |8| + |12| = 20 reaches the threshold.
        pred.add_candidate([8, 12]);
        assert_eq!(pred.compute_context(), 1);
    }

    #[test]
    fn test_tx_size_min() {
        assert_eq!(TxSize::Tx8x8.min(TxSize::Tx16x16), TxSize::Tx8x8);
        assert_eq!(TxSize::Tx16x16.min(TxSize::Tx8x8), TxSize::Tx8x8);
        assert_eq!(TxSize::Tx4x4.min(TxSize::Tx4x4), TxSize::Tx4x4);
    }

    #[test]
    fn test_symbol_encoder() {
        let mut encoder = SymbolEncoder::new();
        encoder.write_skip(true, 0);
        encoder.write_partition(PartitionType::None, 0);
        let output = encoder.finish();
        assert!(!output.is_empty());
    }

    #[test]
    fn test_mv_magnitude() {
        let pred = MvPredictor::new();
        assert_eq!(pred.mv_magnitude([0, 0]), 0);
        assert_eq!(pred.mv_magnitude([10, 20]), 30);
        assert_eq!(pred.mv_magnitude([-10, 20]), 30);
    }

    #[test]
    fn test_constants() {
        assert_eq!(PARTITION_CONTEXTS, 4);
        assert_eq!(SKIP_CONTEXTS, 3);
        assert_eq!(INTRA_MODE_CONTEXTS, 5);
        assert_eq!(INTER_MODE_CONTEXTS, 7);
        assert_eq!(NUM_REF_FRAMES, 7);
        assert_eq!(MAX_MV_COMPONENT, 1023);
    }

    #[test]
    fn test_symbol_decoder_position() {
        let data = vec![0u8; 128];
        let decoder = SymbolDecoder::new(data, false);
        // Decoder init reads 15 bits (2 bytes) for arithmetic decoder state
        assert_eq!(decoder.remaining(), 126);
    }

    #[test]
    fn test_tx_size_category() {
        let max_tx = TxSize::Tx32x32;
        assert_eq!(tx_size_category(TxSize::Tx32x32, max_tx), 0);
        assert_eq!(tx_size_category(TxSize::Tx16x16, max_tx), 1);
        assert_eq!(tx_size_category(TxSize::Tx4x4, max_tx), 2);
    }
}