Skip to main content

oximedia_codec/av1/
coefficients.rs

1//! AV1 coefficient parsing and dequantization.
2//!
3//! This module handles the parsing of transform coefficients from the
4//! entropy-coded bitstream, including:
5//!
6//! - End of block (EOB) position parsing
7//! - Coefficient level and sign decoding
8//! - Scan order for coefficient serialization
9//! - Dequantization helpers
10//!
11//! # Coefficient Coding Structure
12//!
13//! AV1 uses a sophisticated multi-level coding scheme:
14//!
15//! 1. **EOB (End of Block)** - Position of last non-zero coefficient
//! 2. **Coefficient Base** - Base level (0-3) coded with a multi-symbol CDF
//! 3. **Coefficient Base Range** - Further level increments coded with
//!    multi-symbol CDFs; any remainder beyond that range uses Exp-Golomb codes
18//! 4. **DC Sign** - Sign of DC coefficient
19//! 5. **AC Signs** - Signs of AC coefficients
20//!
21//! # Scan Orders
22//!
23//! Coefficients are scanned in a specific order based on:
24//! - Transform class (2D, horizontal, vertical)
25//! - Transform size
26//!
27//! # Reference
28//!
29//! See AV1 Specification Section 5.11.39 for coefficient syntax and
30//! Section 7.12 for coefficient semantics.
31
32#![allow(dead_code)]
33#![allow(clippy::cast_possible_truncation)]
34#![allow(clippy::cast_sign_loss)]
35#![allow(clippy::bool_to_int_with_if)]
36#![allow(clippy::needless_bool_assign)]
37#![allow(clippy::if_not_else)]
38#![allow(clippy::cast_possible_wrap)]
39#![allow(clippy::match_same_arms)]
40#![allow(clippy::doc_markdown)]
41#![allow(clippy::explicit_iter_loop)]
42#![allow(clippy::cast_precision_loss)]
43#![allow(clippy::comparison_chain)]
44#![allow(clippy::cast_lossless)]
45
46use super::transform::{TxClass, TxSize, TxType};
47
48// =============================================================================
49// Constants
50// =============================================================================
51
/// Maximum EOB position for any transform size.
///
/// Equal to `64 * 64`, the coefficient count of the largest transform size
/// listed in the EOB tables of this module (TX_64X64).
pub const MAX_EOB: usize = 4096;

/// Number of EOB position contexts.
///
/// `EobContext::get_eob_context` returns values in `0..=8`, i.e. this many
/// distinct contexts.
pub const EOB_COEF_CONTEXTS: usize = 9;

/// Number of coefficient base contexts.
pub const COEFF_BASE_CONTEXTS: usize = 42;

/// Number of coefficient base EOB contexts.
// NOTE(review): libaom's equivalent (`SIG_COEF_CONTEXTS_EOB`) is 4; confirm
// that 3 is intentional before sizing CDF tables with this constant.
pub const COEFF_BASE_EOB_CONTEXTS: usize = 3;

/// Number of DC sign contexts.
///
/// `compute_dc_sign_context` and `CoeffContext::dc_sign_context` both return
/// one of `0`, `1` or `2`.
pub const DC_SIGN_CONTEXTS: usize = 3;

/// Number of coefficient base range contexts.
///
/// `CoeffBaseRange::get_br_context` currently produces values in `0..=18`,
/// which fits within this count.
pub const COEFF_BR_CONTEXTS: usize = 21;

/// Maximum coefficient base level.
// NOTE(review): not referenced in the visible part of this file; confirm the
// intended meaning (base-level cap vs. base-range limit) at its call sites.
pub const COEFF_BASE_RANGE_MAX: u32 = 3;

/// Golomb-Rice parameter for coefficient coding.
pub const COEFF_BR_RICE_PARAM: u8 = 1;

/// Base level cutoffs for coefficient coding.
pub const BASE_LEVEL_CUTOFFS: [u32; 5] = [0, 1, 2, 3, 4];

/// Number of TX classes: 2D, horizontal and vertical, matching the three
/// `TxClass` variants handled by `get_scan_order`.
pub const TX_CLASSES: usize = 3;

/// Coefficient context position limit.
///
/// `update_level_context` caps magnitudes at the same value (63), but uses a
/// literal rather than this constant.
pub const COEFF_CONTEXT_MASK: usize = 63;

/// Maximum neighbors for context computation.
// NOTE(review): `get_neighbor_positions` returns five neighbour slots, while
// `CoeffContext::compute_level_context` counts at most two (left, above);
// confirm which of these this limit is meant to describe.
pub const MAX_NEIGHBORS: usize = 2;
87
88// =============================================================================
89// EOB Position Tables
90// =============================================================================
91
/// EOB offset for each transform size.
///
/// Entries are listed in the order given by the trailing comments
/// (TX_4X4 .. TX_64X16). This presumably matches the `TxSize` discriminant
/// order, as used when indexing `EOB_EXTRA_BITS` — TODO confirm.
// NOTE(review): not referenced in the visible part of this file, and the
// values are not derived from the transform areas in any obvious way;
// verify them against their consumer.
pub const EOB_OFFSET: [u16; 19] = [
    0,    // TX_4X4
    16,   // TX_8X8
    80,   // TX_16X16
    336,  // TX_32X32
    1360, // TX_64X64
    16,   // TX_4X8
    16,   // TX_8X4
    80,   // TX_8X16
    80,   // TX_16X8
    336,  // TX_16X32
    336,  // TX_32X16
    1360, // TX_32X64
    1360, // TX_64X32
    48,   // TX_4X16
    48,   // TX_16X4
    176,  // TX_8X32
    176,  // TX_32X8
    592,  // TX_16X64
    592,  // TX_64X16
];
114
/// EOB extra bits for each transform size.
///
/// `EobContext::new` indexes this table with `tx_size as usize` and falls
/// back to 0 for indices past the end. Entry order must therefore follow the
/// `TxSize` discriminants, which are listed in the trailing comments.
pub const EOB_EXTRA_BITS: [u8; 19] = [
    0, // TX_4X4
    1, // TX_8X8
    2, // TX_16X16
    3, // TX_32X32
    4, // TX_64X64
    1, // TX_4X8
    1, // TX_8X4
    2, // TX_8X16
    2, // TX_16X8
    3, // TX_16X32
    3, // TX_32X16
    4, // TX_32X64
    4, // TX_64X32
    2, // TX_4X16
    2, // TX_16X4
    3, // TX_8X32
    3, // TX_32X8
    4, // TX_16X64
    4, // TX_64X16
];
137
/// EOB group start positions.
///
/// Entry `i` is the smallest EOB value in group `i`, and
/// `EobContext::compute_eob` adds the decoded extra bits to it. The values
/// agree with `EobPt::base_eob` for the matching `EobPt` discriminant.
pub const EOB_GROUP_START: [u16; 12] = [0, 1, 2, 3, 5, 9, 17, 33, 65, 129, 257, 513];

/// EOB symbol to position mapping.
///
/// Currently the identity mapping: entry `i` equals `i`.
pub const EOB_TO_POS: [u16; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
143
144// =============================================================================
145// Level Context
146// =============================================================================
147
/// Neighbourhood summary used to pick a coefficient level context.
#[derive(Clone, Debug, Default)]
pub struct LevelContext {
    /// Sum of the absolute levels of the neighbouring coefficients.
    pub mag: u32,
    /// How many of the counted neighbours are non-zero.
    pub count: u8,
    /// Position-based context bucket (derived from `row + col`).
    pub pos_ctx: u8,
}

impl LevelContext {
    /// Creates an all-zero level context (same as `Default`).
    #[must_use]
    pub const fn new() -> Self {
        Self { mag: 0, count: 0, pos_ctx: 0 }
    }

    /// Buckets the accumulated magnitude into `0..=4`.
    ///
    /// The thresholds are exclusive on the low side: 64 maps to 0, 65 to 1,
    /// 129 to 2, 257 to 3, and anything above 512 to 4.
    #[must_use]
    pub fn mag_context(&self) -> u8 {
        match self.mag {
            0..=64 => 0,
            65..=128 => 1,
            129..=256 => 2,
            257..=512 => 3,
            _ => 4,
        }
    }

    /// Combines the magnitude bucket with the neighbour count, capped at 2.
    ///
    /// The result lies in `0..=14`.
    #[must_use]
    pub fn context(&self) -> u8 {
        let count_bucket = if self.count > 2 { 2 } else { self.count };
        3 * self.mag_context() + count_bucket
    }
}
193
194// =============================================================================
195// Coefficient Context
196// =============================================================================
197
198/// Context for coefficient parsing state.
199#[derive(Clone, Debug)]
200pub struct CoeffContext {
201    /// Transform size.
202    pub tx_size: TxSize,
203    /// Transform type.
204    pub tx_type: TxType,
205    /// Plane index (0=Y, 1=U, 2=V).
206    pub plane: u8,
207    /// Current scan position.
208    pub scan_pos: u16,
209    /// End of block position.
210    pub eob: u16,
211    /// Coefficient levels (dequantized).
212    pub levels: Vec<i32>,
213    /// Sign bits.
214    pub signs: Vec<bool>,
215    /// Accumulated left context.
216    pub left_ctx: Vec<u8>,
217    /// Accumulated above context.
218    pub above_ctx: Vec<u8>,
219    /// Block width in 4x4 units.
220    pub block_width: u8,
221    /// Block height in 4x4 units.
222    pub block_height: u8,
223}
224
225impl CoeffContext {
226    /// Create a new coefficient context.
227    #[must_use]
228    pub fn new(tx_size: TxSize, tx_type: TxType, plane: u8) -> Self {
229        let area = tx_size.area() as usize;
230        let width = (tx_size.width() / 4) as u8;
231        let height = (tx_size.height() / 4) as u8;
232
233        Self {
234            tx_size,
235            tx_type,
236            plane,
237            scan_pos: 0,
238            eob: 0,
239            levels: vec![0; area],
240            signs: vec![false; area],
241            left_ctx: vec![0; height as usize * 4],
242            above_ctx: vec![0; width as usize * 4],
243            block_width: width,
244            block_height: height,
245        }
246    }
247
248    /// Reset the context for a new block.
249    pub fn reset(&mut self) {
250        self.scan_pos = 0;
251        self.eob = 0;
252        self.levels.fill(0);
253        self.signs.fill(false);
254        self.left_ctx.fill(0);
255        self.above_ctx.fill(0);
256    }
257
258    /// Get the transform class.
259    #[must_use]
260    pub fn tx_class(&self) -> TxClass {
261        self.tx_type.tx_class()
262    }
263
264    /// Get scan position for a coefficient index.
265    #[must_use]
266    pub fn get_scan_position(&self, idx: usize) -> (u32, u32) {
267        let width = self.tx_size.width();
268        let row = (idx as u32) / width;
269        let col = (idx as u32) % width;
270        (row, col)
271    }
272
273    /// Get coefficient index from row and column.
274    #[must_use]
275    pub fn get_coeff_index(&self, row: u32, col: u32) -> usize {
276        (row * self.tx_size.width() + col) as usize
277    }
278
279    /// Compute level context for a position.
280    #[must_use]
281    pub fn compute_level_context(&self, pos: usize) -> LevelContext {
282        let width = self.tx_size.width() as usize;
283        let _height = self.tx_size.height() as usize;
284        let row = pos / width;
285        let col = pos % width;
286
287        let mut ctx = LevelContext::new();
288
289        // Get neighbors (left and above)
290        if col > 0 {
291            let left = self.levels[row * width + col - 1].unsigned_abs();
292            ctx.mag += left;
293            if left > 0 {
294                ctx.count += 1;
295            }
296        }
297
298        if row > 0 {
299            let above = self.levels[(row - 1) * width + col].unsigned_abs();
300            ctx.mag += above;
301            if above > 0 {
302                ctx.count += 1;
303            }
304        }
305
306        // Diagonal neighbor
307        if row > 0 && col > 0 {
308            let diag = self.levels[(row - 1) * width + col - 1].unsigned_abs();
309            ctx.mag += diag;
310        }
311
312        // Position context
313        ctx.pos_ctx = if row + col == 0 {
314            0
315        } else if row + col < 2 {
316            1
317        } else if row + col < 4 {
318            2
319        } else {
320            3
321        };
322
323        ctx
324    }
325
326    /// Get DC sign context.
327    #[must_use]
328    pub fn dc_sign_context(&self) -> u8 {
329        let left_sign = if !self.left_ctx.is_empty() {
330            (self.left_ctx[0] as i8 - 1).signum()
331        } else {
332            0
333        };
334
335        let above_sign = if !self.above_ctx.is_empty() {
336            (self.above_ctx[0] as i8 - 1).signum()
337        } else {
338            0
339        };
340
341        let sign_sum = left_sign + above_sign;
342
343        if sign_sum < 0 {
344            0
345        } else if sign_sum > 0 {
346            2
347        } else {
348            1
349        }
350    }
351
352    /// Set coefficient value at position.
353    pub fn set_coeff(&mut self, pos: usize, level: i32, sign: bool) {
354        if pos < self.levels.len() {
355            self.levels[pos] = if sign { -level } else { level };
356            self.signs[pos] = sign;
357        }
358    }
359
360    /// Get coefficient value at position.
361    #[must_use]
362    pub fn get_coeff(&self, pos: usize) -> i32 {
363        self.levels.get(pos).copied().unwrap_or(0)
364    }
365
366    /// Check if block has any non-zero coefficients.
367    #[must_use]
368    pub fn has_nonzero(&self) -> bool {
369        self.eob > 0
370    }
371
372    /// Get the number of non-zero coefficients.
373    #[must_use]
374    pub fn count_nonzero(&self) -> u16 {
375        self.levels.iter().filter(|&&l| l != 0).count() as u16
376    }
377}
378
379impl Default for CoeffContext {
380    fn default() -> Self {
381        Self::new(TxSize::Tx4x4, TxType::DctDct, 0)
382    }
383}
384
385// =============================================================================
386// Scan Order Generation
387// =============================================================================
388
/// Generate a diagonal scan order for a `width` x `height` block.
///
/// Walks the anti-diagonals `row + col = d` for `d = 0, 1, ...`. Each
/// diagonal is visited from its top row downwards, and raster indices
/// `row * width + col` are emitted. Returns an empty scan when either
/// dimension is 0; previously such inputs underflowed and panicked.
#[must_use]
pub fn generate_diagonal_scan(width: usize, height: usize) -> Vec<u16> {
    // Guard the `width + height - 1`, `width - 1` and `height - 1` below.
    if width == 0 || height == 0 {
        return Vec::new();
    }

    let mut scan = Vec::with_capacity(width * height);

    for diag in 0..(width + height - 1) {
        // Rows on this diagonal whose column `diag - row` lies in `0..width`.
        let row_start = diag.saturating_sub(width - 1);
        let row_end = diag.min(height - 1);

        for row in row_start..=row_end {
            let col = diag - row;
            scan.push((row * width + col) as u16);
        }
    }

    scan
}
412
/// Generate a horizontal (row-major) scan order.
///
/// Row-major traversal visits raster indices in increasing order, so the
/// result is `0, 1, ..., width * height - 1`, each cast to `u16`.
#[must_use]
pub fn generate_horizontal_scan(width: usize, height: usize) -> Vec<u16> {
    (0..width * height).map(|idx| idx as u16).collect()
}
426
/// Generate a vertical (column-major) scan order.
///
/// Emits every raster index of column 0 from top to bottom, then column 1,
/// and so on.
#[must_use]
pub fn generate_vertical_scan(width: usize, height: usize) -> Vec<u16> {
    let total = width * height;
    let mut order = Vec::with_capacity(total);

    // Column `col` holds raster indices col, col + width, col + 2 * width, ...
    for col in 0..width {
        order.extend((col..total).step_by(width).map(|idx| idx as u16));
    }

    order
}
440
441/// Get the scan order for a given transform.
442#[must_use]
443pub fn get_scan_order(tx_size: TxSize, tx_class: TxClass) -> Vec<u16> {
444    let width = tx_size.width() as usize;
445    let height = tx_size.height() as usize;
446
447    match tx_class {
448        TxClass::Class2D => generate_diagonal_scan(width, height),
449        TxClass::ClassHoriz => generate_horizontal_scan(width, height),
450        TxClass::ClassVert => generate_vertical_scan(width, height),
451    }
452}
453
454/// Scan order cache for common transform sizes.
455#[derive(Clone, Debug)]
456pub struct ScanOrderCache {
457    /// Cached scan orders indexed by [tx_size][tx_class].
458    cache: Vec<Vec<Vec<u16>>>,
459}
460
461impl ScanOrderCache {
462    /// Create a new scan order cache.
463    #[must_use]
464    pub fn new() -> Self {
465        let mut cache = Vec::with_capacity(19);
466
467        for tx_size_idx in 0..19 {
468            let tx_size = TxSize::from_u8(tx_size_idx as u8).unwrap_or_default();
469            let mut class_scans = Vec::with_capacity(3);
470
471            for tx_class_idx in 0..3 {
472                let tx_class = TxClass::from_u8(tx_class_idx as u8).unwrap_or_default();
473                class_scans.push(get_scan_order(tx_size, tx_class));
474            }
475
476            cache.push(class_scans);
477        }
478
479        Self { cache }
480    }
481
482    /// Get scan order from cache.
483    #[must_use]
484    pub fn get(&self, tx_size: TxSize, tx_class: TxClass) -> &[u16] {
485        let size_idx = tx_size as usize;
486        let class_idx = tx_class as usize;
487
488        if size_idx < self.cache.len() && class_idx < self.cache[size_idx].len() {
489            &self.cache[size_idx][class_idx]
490        } else {
491            &[]
492        }
493    }
494}
495
496impl Default for ScanOrderCache {
497    fn default() -> Self {
498        Self::new()
499    }
500}
501
502// =============================================================================
503// EOB Parsing Helpers
504// =============================================================================
505
506/// EOB (End of Block) position context.
507#[derive(Clone, Debug, Default)]
508pub struct EobContext {
509    /// EOB multi-context.
510    pub eob_multi: u8,
511    /// EOB extra bits.
512    pub eob_extra: u8,
513    /// Base context.
514    pub base_ctx: u8,
515}
516
517impl EobContext {
518    /// Create EOB context for a transform size.
519    #[must_use]
520    pub fn new(tx_size: TxSize) -> Self {
521        let size_idx = tx_size as usize;
522        let extra_bits = if size_idx < EOB_EXTRA_BITS.len() {
523            EOB_EXTRA_BITS[size_idx]
524        } else {
525            0
526        };
527
528        Self {
529            eob_multi: 0,
530            eob_extra: extra_bits,
531            base_ctx: 0,
532        }
533    }
534
535    /// Get the EOB context from position.
536    #[must_use]
537    pub fn get_eob_context(eob: u16) -> u8 {
538        if eob <= 1 {
539            0
540        } else if eob <= 2 {
541            1
542        } else if eob <= 4 {
543            2
544        } else if eob <= 8 {
545            3
546        } else if eob <= 16 {
547            4
548        } else if eob <= 32 {
549            5
550        } else if eob <= 64 {
551            6
552        } else if eob <= 128 {
553            7
554        } else {
555            8
556        }
557    }
558
559    /// Compute EOB from multi-symbol and extra bits.
560    #[must_use]
561    pub fn compute_eob(eob_multi: u8, eob_extra: u16) -> u16 {
562        let group_idx = eob_multi as usize;
563        if group_idx >= EOB_GROUP_START.len() {
564            return 0;
565        }
566
567        let base = EOB_GROUP_START[group_idx];
568        base + eob_extra
569    }
570}
571
/// EOB point parsing state.
///
/// Each variant names a group of EOB values; its discriminant is the group
/// index.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EobPt {
    /// No coefficients.
    EobPt0 = 0,
    /// 1 coefficient.
    EobPt1 = 1,
    /// 2 coefficients.
    EobPt2 = 2,
    /// 3-4 coefficients.
    EobPt3To4 = 3,
    /// 5-8 coefficients.
    EobPt5To8 = 4,
    /// 9-16 coefficients.
    EobPt9To16 = 5,
    /// 17-32 coefficients.
    EobPt17To32 = 6,
    /// 33-64 coefficients.
    EobPt33To64 = 7,
    /// 65-128 coefficients.
    EobPt65To128 = 8,
    /// 129-256 coefficients.
    EobPt129To256 = 9,
    /// 257-512 coefficients.
    EobPt257To512 = 10,
    /// 513-1024 coefficients.
    EobPt513To1024 = 11,
}

impl EobPt {
    /// Every EOB point in discriminant order, so `ALL[p as usize] == p`.
    const ALL: [Self; 12] = [
        Self::EobPt0,
        Self::EobPt1,
        Self::EobPt2,
        Self::EobPt3To4,
        Self::EobPt5To8,
        Self::EobPt9To16,
        Self::EobPt17To32,
        Self::EobPt33To64,
        Self::EobPt65To128,
        Self::EobPt129To256,
        Self::EobPt257To512,
        Self::EobPt513To1024,
    ];

    /// Classify an EOB value into its EOB point.
    ///
    /// Values above 1024 are clamped into the last group.
    #[must_use]
    pub fn from_eob(eob: u16) -> Self {
        if eob <= 2 {
            return Self::ALL[eob as usize];
        }
        // For eob >= 3 the group index is one more than the bit length of
        // `eob - 1` (3..=4 -> 3, 5..=8 -> 4, ...), capped at the last group.
        let bit_len = u16::BITS - (eob - 1).leading_zeros();
        Self::ALL[(bit_len as usize + 1).min(Self::ALL.len() - 1)]
    }

    /// The smallest EOB value belonging to this point.
    #[must_use]
    pub const fn base_eob(self) -> u16 {
        let idx = self as u16;
        if idx < 3 {
            idx
        } else {
            (1 << (idx - 2)) + 1
        }
    }

    /// Number of extra bits needed to select an EOB within this point.
    #[must_use]
    pub const fn extra_bits(self) -> u8 {
        (self as u8).saturating_sub(2)
    }
}
657
658// =============================================================================
659// Coefficient Base Range
660// =============================================================================
661
662/// Coefficient base range context.
663#[derive(Clone, Copy, Debug, Default)]
664pub struct CoeffBaseRange {
665    /// Base level (0-4).
666    pub base_level: u8,
667    /// Range context.
668    pub range_ctx: u8,
669}
670
671impl CoeffBaseRange {
672    /// Get context for coefficient base range coding.
673    #[must_use]
674    pub fn get_br_context(level_ctx: &LevelContext, pos: usize, width: usize) -> u8 {
675        let row = pos / width;
676        let col = pos % width;
677
678        // Base context from position
679        let pos_ctx = if row + col == 0 {
680            0
681        } else if row + col < 2 {
682            7
683        } else {
684            14
685        };
686
687        // Combine with magnitude context
688        pos_ctx + level_ctx.mag_context().min(6)
689    }
690
691    /// Compute level from base and range.
692    #[must_use]
693    pub fn compute_level(base: u8, range: u16) -> u32 {
694        u32::from(base) + u32::from(range)
695    }
696}
697
698// =============================================================================
699// Dequantization Helpers
700// =============================================================================
701
/// Dequantize a single coefficient.
///
/// Multiplies the magnitude of `level` by `dequant`, shifts right by `shift`,
/// and restores the sign of `level`. The arithmetic is done in `i64`, and the
/// result saturates to the `i32` range. Extreme inputs (such as `i32::MIN`,
/// or products above 2^31) therefore no longer overflow or panic. Shift
/// amounts above 63 are clamped to 63.
#[must_use]
pub fn dequantize_coeff(level: i32, dequant: i16, shift: u8) -> i32 {
    // |level| <= 2^31 and |dequant| <= 2^15, so the product fits in i64.
    let magnitude = i64::from(level).abs() * i64::from(dequant);
    let dq_level = magnitude >> u32::from(shift).min(63);
    let signed = if level < 0 { -dq_level } else { dq_level };

    // After clamping, the `as` cast is lossless.
    signed.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
}
714
715/// Dequantize all coefficients in a block.
716pub fn dequantize_block(coeffs: &mut [i32], dc_dequant: i16, ac_dequant: i16, shift: u8) {
717    if coeffs.is_empty() {
718        return;
719    }
720
721    // DC coefficient
722    coeffs[0] = dequantize_coeff(coeffs[0], dc_dequant, shift);
723
724    // AC coefficients
725    for coeff in coeffs.iter_mut().skip(1) {
726        *coeff = dequantize_coeff(*coeff, ac_dequant, shift);
727    }
728}
729
/// Compute the dequantization shift for a bit depth.
///
/// 10- and 12-bit content shifts by `bit_depth - 8` (2 and 4). Every other
/// bit depth, including 8, uses no shift.
#[must_use]
pub const fn get_dequant_shift(bit_depth: u8) -> u8 {
    match bit_depth {
        10 | 12 => bit_depth - 8,
        _ => 0,
    }
}
740
741// =============================================================================
742// Coefficient Buffer
743// =============================================================================
744
745/// Buffer for storing and manipulating coefficient data.
746#[derive(Clone, Debug)]
747pub struct CoeffBuffer {
748    /// Coefficient storage.
749    coeffs: Vec<i32>,
750    /// Width of the buffer.
751    width: usize,
752    /// Height of the buffer.
753    height: usize,
754}
755
756impl CoeffBuffer {
757    /// Create a new coefficient buffer.
758    #[must_use]
759    pub fn new(width: usize, height: usize) -> Self {
760        Self {
761            coeffs: vec![0; width * height],
762            width,
763            height,
764        }
765    }
766
767    /// Create from transform size.
768    #[must_use]
769    pub fn from_tx_size(tx_size: TxSize) -> Self {
770        Self::new(tx_size.width() as usize, tx_size.height() as usize)
771    }
772
773    /// Get coefficient at position.
774    #[must_use]
775    pub fn get(&self, row: usize, col: usize) -> i32 {
776        if row < self.height && col < self.width {
777            self.coeffs[row * self.width + col]
778        } else {
779            0
780        }
781    }
782
783    /// Set coefficient at position.
784    pub fn set(&mut self, row: usize, col: usize, value: i32) {
785        if row < self.height && col < self.width {
786            self.coeffs[row * self.width + col] = value;
787        }
788    }
789
790    /// Clear all coefficients.
791    pub fn clear(&mut self) {
792        self.coeffs.fill(0);
793    }
794
795    /// Get mutable slice of coefficients.
796    pub fn as_mut_slice(&mut self) -> &mut [i32] {
797        &mut self.coeffs
798    }
799
800    /// Get immutable slice of coefficients.
801    #[must_use]
802    pub fn as_slice(&self) -> &[i32] {
803        &self.coeffs
804    }
805
806    /// Copy from scan order.
807    pub fn copy_from_scan(&mut self, src: &[i32], scan: &[u16]) {
808        for (i, &pos) in scan.iter().enumerate() {
809            if i < src.len() && (pos as usize) < self.coeffs.len() {
810                self.coeffs[pos as usize] = src[i];
811            }
812        }
813    }
814
815    /// Copy to scan order.
816    pub fn copy_to_scan(&self, dst: &mut [i32], scan: &[u16]) {
817        for (i, &pos) in scan.iter().enumerate() {
818            if i < dst.len() && (pos as usize) < self.coeffs.len() {
819                dst[i] = self.coeffs[pos as usize];
820            }
821        }
822    }
823}
824
825impl Default for CoeffBuffer {
826    fn default() -> Self {
827        Self::new(4, 4)
828    }
829}
830
831// =============================================================================
832// Neighbor Context Computation
833// =============================================================================
834
/// Get the neighbour positions used for context computation.
///
/// Returns five `(raster_index, valid)` slots in this order: left, above,
/// top-left, top-right, and two-to-the-left. Slots that would fall outside
/// the block are `(0, false)`. `_height` is unused.
///
/// # Panics
/// Panics if `width` is 0 (division by zero).
#[must_use]
pub fn get_neighbor_positions(pos: usize, width: usize, _height: usize) -> [(usize, bool); 5] {
    let (row, col) = (pos / width, pos % width);
    let invalid = (0usize, false);
    let has_above = row > 0;

    [
        // Left
        if col > 0 { (pos - 1, true) } else { invalid },
        // Above
        if has_above { (pos - width, true) } else { invalid },
        // Top-left diagonal
        if has_above && col > 0 { (pos - width - 1, true) } else { invalid },
        // Top-right diagonal
        if has_above && col + 1 < width { (pos - width + 1, true) } else { invalid },
        // Two positions left
        if col > 1 { (pos - 2, true) } else { invalid },
    ]
}
870
/// Compute a level context from neighbour levels.
///
/// Only slots marked valid whose index lies inside `levels` are used. Their
/// absolute levels are summed into a magnitude, and the non-zero ones are
/// counted. The result is `3 * mag_bucket + min(count, 2)`, where
/// `mag_bucket` is 0..=4 with thresholds >64, >128, >256 and >512, giving a
/// value in `0..=14`.
///
/// The magnitude sum saturates at `u32::MAX`. Five large neighbours would
/// otherwise overflow, and every sum above 512 already lands in the top
/// bucket, so saturation does not change the context.
#[must_use]
pub fn compute_context_from_neighbors(levels: &[i32], neighbors: &[(usize, bool); 5]) -> u8 {
    let mut mag = 0u32;
    let mut count = 0u8;

    let neighbor_levels = neighbors
        .iter()
        .filter(|&&(_, valid)| valid)
        .filter_map(|&(pos, _)| levels.get(pos))
        .map(|l| l.unsigned_abs());

    for level in neighbor_levels {
        mag = mag.saturating_add(level);
        if level > 0 {
            count += 1;
        }
    }

    let mag_ctx: u8 = match mag {
        0..=64 => 0,
        65..=128 => 1,
        129..=256 => 2,
        257..=512 => 3,
        _ => 4,
    };

    mag_ctx * 3 + count.min(2)
}
902
903// =============================================================================
904// Sign Coding Helpers
905// =============================================================================
906
/// Compute the DC sign context from the neighbouring DC values.
///
/// Each neighbour contributes its sign (-1, 0 or +1). A negative sum maps to
/// 0, a zero sum to 1, and a positive sum to 2.
// NOTE(review): the AV1 spec's dc_sign context uses 0 for a zero sum and 1
// for a negative one; confirm this numbering matches the CDF tables in use.
#[must_use]
pub fn compute_dc_sign_context(left_dc: i32, above_dc: i32) -> u8 {
    match left_dc.signum() + above_dc.signum() {
        0 => 1,
        s if s < 0 => 0,
        _ => 2,
    }
}
923
/// Record a decoded coefficient in the left/above context arrays.
///
/// Stores `min(|level|, 63) + 1` (so a zero level is stored as 1) at
/// `left_ctx[row]` and `above_ctx[col]`. Indices past the end of either
/// slice are ignored.
pub fn update_level_context(
    left_ctx: &mut [u8],
    above_ctx: &mut [u8],
    level: i32,
    row: usize,
    col: usize,
) {
    let stored = level.unsigned_abs().min(63) as u8 + 1;

    if let Some(slot) = left_ctx.get_mut(row) {
        *slot = stored;
    }
    if let Some(slot) = above_ctx.get_mut(col) {
        *slot = stored;
    }
}
942
943// =============================================================================
944// Coefficient Statistics
945// =============================================================================
946
/// Statistics about coefficients in a block.
#[derive(Clone, Debug, Default)]
pub struct CoeffStats {
    /// Number of zero coefficients.
    pub zero_count: u32,
    /// Number of coefficients with level 1.
    pub level1_count: u32,
    /// Number of coefficients with level 2.
    pub level2_count: u32,
    /// Number of coefficients with level > 2.
    pub high_level_count: u32,
    /// Sum of absolute levels.
    pub level_sum: u64,
    /// Maximum absolute level.
    pub max_level: u32,
}

impl CoeffStats {
    /// Gather statistics over the absolute values of `coeffs`.
    #[must_use]
    pub fn from_coeffs(coeffs: &[i32]) -> Self {
        coeffs
            .iter()
            .map(|c| c.unsigned_abs())
            .fold(Self::default(), |mut stats, level| {
                let bucket = match level {
                    0 => &mut stats.zero_count,
                    1 => &mut stats.level1_count,
                    2 => &mut stats.level2_count,
                    _ => &mut stats.high_level_count,
                };
                *bucket += 1;
                stats.level_sum += u64::from(level);
                stats.max_level = stats.max_level.max(level);
                stats
            })
    }

    /// Total number of non-zero coefficients.
    #[must_use]
    pub fn nonzero_count(&self) -> u32 {
        self.high_level_count + self.level2_count + self.level1_count
    }

    /// Mean absolute level over the non-zero coefficients (0.0 when there
    /// are none).
    #[must_use]
    pub fn average_level(&self) -> f64 {
        match self.nonzero_count() {
            0 => 0.0,
            n => self.level_sum as f64 / f64::from(n),
        }
    }
}
1004
1005// =============================================================================
1006// Tests
1007// =============================================================================
1008
1009#[cfg(test)]
1010mod tests {
1011    use super::*;
1012
1013    #[test]
1014    fn test_level_context() {
1015        let mut ctx = LevelContext::new();
1016        assert_eq!(ctx.mag, 0);
1017        assert_eq!(ctx.count, 0);
1018
1019        ctx.mag = 100;
1020        ctx.count = 2;
1021        assert_eq!(ctx.mag_context(), 1);
1022        assert_eq!(ctx.context(), 1 * 3 + 2);
1023    }
1024
1025    #[test]
1026    fn test_coeff_context_new() {
1027        let ctx = CoeffContext::new(TxSize::Tx8x8, TxType::DctDct, 0);
1028        assert_eq!(ctx.levels.len(), 64);
1029        assert_eq!(ctx.tx_class(), TxClass::Class2D);
1030    }
1031
1032    #[test]
1033    fn test_coeff_context_set_get() {
1034        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
1035        ctx.set_coeff(5, 100, false);
1036        assert_eq!(ctx.get_coeff(5), 100);
1037
1038        ctx.set_coeff(10, 50, true);
1039        assert_eq!(ctx.get_coeff(10), -50);
1040    }
1041
1042    #[test]
1043    fn test_diagonal_scan_4x4() {
1044        let scan = generate_diagonal_scan(4, 4);
1045        assert_eq!(scan.len(), 16);
1046        // First few elements should be diagonal
1047        assert_eq!(scan[0], 0); // (0,0)
1048        assert_eq!(scan[1], 1); // (0,1) - in row-major for 4x4
1049    }
1050
1051    #[test]
1052    fn test_horizontal_scan() {
1053        let scan = generate_horizontal_scan(4, 4);
1054        assert_eq!(scan.len(), 16);
1055        for i in 0..16 {
1056            assert_eq!(scan[i], i as u16);
1057        }
1058    }
1059
1060    #[test]
1061    fn test_vertical_scan() {
1062        let scan = generate_vertical_scan(4, 4);
1063        assert_eq!(scan.len(), 16);
1064        assert_eq!(scan[0], 0);
1065        assert_eq!(scan[1], 4);
1066        assert_eq!(scan[2], 8);
1067        assert_eq!(scan[3], 12);
1068    }
1069
1070    #[test]
1071    fn test_scan_order_cache() {
1072        let cache = ScanOrderCache::new();
1073        let scan = cache.get(TxSize::Tx4x4, TxClass::Class2D);
1074        assert_eq!(scan.len(), 16);
1075    }
1076
1077    #[test]
1078    fn test_eob_context() {
1079        let ctx = EobContext::new(TxSize::Tx8x8);
1080        assert!(ctx.eob_extra > 0);
1081    }
1082
1083    #[test]
1084    fn test_eob_pt() {
1085        assert_eq!(EobPt::from_eob(0), EobPt::EobPt0);
1086        assert_eq!(EobPt::from_eob(1), EobPt::EobPt1);
1087        assert_eq!(EobPt::from_eob(5), EobPt::EobPt5To8);
1088        assert_eq!(EobPt::from_eob(100), EobPt::EobPt65To128);
1089
1090        assert_eq!(EobPt::EobPt5To8.extra_bits(), 2);
1091        assert_eq!(EobPt::EobPt5To8.base_eob(), 5);
1092    }
1093
1094    #[test]
1095    fn test_dequantize_coeff() {
1096        let level = 10;
1097        let dequant = 16;
1098        let result = dequantize_coeff(level, dequant, 0);
1099        assert_eq!(result, 160);
1100
1101        let neg_result = dequantize_coeff(-level, dequant, 0);
1102        assert_eq!(neg_result, -160);
1103    }
1104
1105    #[test]
1106    fn test_dequantize_block() {
1107        let mut coeffs = vec![10, 5, 5, 5, 5, 5, 5, 5];
1108        dequantize_block(&mut coeffs, 20, 10, 0);
1109
1110        assert_eq!(coeffs[0], 200); // DC: 10 * 20
1111        assert_eq!(coeffs[1], 50); // AC: 5 * 10
1112    }
1113
1114    #[test]
1115    fn test_get_dequant_shift() {
1116        assert_eq!(get_dequant_shift(8), 0);
1117        assert_eq!(get_dequant_shift(10), 2);
1118        assert_eq!(get_dequant_shift(12), 4);
1119    }
1120
1121    #[test]
1122    fn test_coeff_buffer() {
1123        let mut buf = CoeffBuffer::new(4, 4);
1124        buf.set(1, 2, 100);
1125        assert_eq!(buf.get(1, 2), 100);
1126        assert_eq!(buf.get(0, 0), 0);
1127
1128        buf.clear();
1129        assert_eq!(buf.get(1, 2), 0);
1130    }
1131
1132    #[test]
1133    fn test_coeff_buffer_from_tx_size() {
1134        let buf = CoeffBuffer::from_tx_size(TxSize::Tx8x8);
1135        assert_eq!(buf.as_slice().len(), 64);
1136    }
1137
1138    #[test]
1139    fn test_neighbor_positions() {
1140        let neighbors = get_neighbor_positions(5, 4, 4);
1141
1142        // Position 5 is row=1, col=1 in 4x4
1143        // Left neighbor should be valid at position 4
1144        assert!(neighbors[0].1);
1145        assert_eq!(neighbors[0].0, 4);
1146
1147        // Above neighbor should be valid at position 1
1148        assert!(neighbors[1].1);
1149        assert_eq!(neighbors[1].0, 1);
1150    }
1151
1152    #[test]
1153    fn test_compute_dc_sign_context() {
1154        assert_eq!(compute_dc_sign_context(-5, -3), 0); // Both negative
1155        assert_eq!(compute_dc_sign_context(5, 3), 2); // Both positive
1156        assert_eq!(compute_dc_sign_context(-5, 3), 1); // Mixed
1157        assert_eq!(compute_dc_sign_context(0, 0), 1); // Zero
1158    }
1159
1160    #[test]
1161    fn test_coeff_stats() {
1162        let coeffs = vec![0, 1, 2, 3, 0, 1, 5, 0];
1163        let stats = CoeffStats::from_coeffs(&coeffs);
1164
1165        assert_eq!(stats.zero_count, 3);
1166        assert_eq!(stats.level1_count, 2);
1167        assert_eq!(stats.level2_count, 1);
1168        assert_eq!(stats.high_level_count, 2);
1169        assert_eq!(stats.max_level, 5);
1170        assert_eq!(stats.nonzero_count(), 5);
1171    }
1172
1173    #[test]
1174    fn test_coeff_context_dc_sign() {
1175        let ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
1176        // Default context with empty context arrays
1177        // Results in neutral sign context (0 or 1 depending on implementation)
1178        let dc_ctx = ctx.dc_sign_context();
1179        // Context should be valid (0, 1, or 2)
1180        assert!(dc_ctx <= 2);
1181    }
1182
1183    #[test]
1184    fn test_coeff_context_level_context() {
1185        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
1186        ctx.levels[0] = 5;
1187        ctx.levels[1] = 3;
1188        ctx.levels[4] = 2;
1189
1190        let level_ctx = ctx.compute_level_context(5);
1191        // Position 5 has neighbors at 4 (left) and 1 (above)
1192        assert!(level_ctx.mag > 0);
1193    }
1194
1195    #[test]
1196    fn test_eob_compute() {
1197        // Test EOB computation from multi-symbol and extra bits
1198        assert_eq!(EobContext::compute_eob(0, 0), 0);
1199        assert_eq!(EobContext::compute_eob(1, 0), 1);
1200        assert_eq!(EobContext::compute_eob(2, 0), 2);
1201    }
1202
1203    #[test]
1204    fn test_coeff_base_range() {
1205        let level_ctx = LevelContext {
1206            mag: 100,
1207            count: 2,
1208            pos_ctx: 1,
1209        };
1210
1211        let br_ctx = CoeffBaseRange::get_br_context(&level_ctx, 5, 4);
1212        assert!(br_ctx > 0);
1213
1214        let level = CoeffBaseRange::compute_level(2, 5);
1215        assert_eq!(level, 7);
1216    }
1217
1218    #[test]
1219    fn test_constants() {
1220        assert_eq!(MAX_EOB, 4096);
1221        assert_eq!(EOB_COEF_CONTEXTS, 9);
1222        assert_eq!(TX_CLASSES, 3);
1223    }
1224
1225    #[test]
1226    fn test_coeff_context_reset() {
1227        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
1228        ctx.eob = 10;
1229        ctx.levels[5] = 100;
1230
1231        ctx.reset();
1232        assert_eq!(ctx.eob, 0);
1233        assert_eq!(ctx.levels[5], 0);
1234    }
1235
1236    #[test]
1237    fn test_coeff_context_count_nonzero() {
1238        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
1239        ctx.levels[0] = 5;
1240        ctx.levels[5] = 3;
1241        ctx.levels[10] = -2;
1242
1243        assert_eq!(ctx.count_nonzero(), 3);
1244    }
1245
1246    #[test]
1247    fn test_scan_order_all_sizes() {
1248        // Test that scan order generation works for all sizes
1249        for size_idx in 0..19 {
1250            if let Some(tx_size) = TxSize::from_u8(size_idx) {
1251                for class_idx in 0..3 {
1252                    if let Some(tx_class) = TxClass::from_u8(class_idx) {
1253                        let scan = get_scan_order(tx_size, tx_class);
1254                        assert_eq!(scan.len(), tx_size.area() as usize);
1255                    }
1256                }
1257            }
1258        }
1259    }
1260}