Skip to main content

oximedia_codec/av1/
coefficients.rs

1//! AV1 coefficient parsing and dequantization.
2//!
3//! This module handles the parsing of transform coefficients from the
4//! entropy-coded bitstream, including:
5//!
6//! - End of block (EOB) position parsing
7//! - Coefficient level and sign decoding
8//! - Scan order for coefficient serialization
9//! - Dequantization helpers
10//!
11//! # Coefficient Coding Structure
12//!
13//! AV1 uses a sophisticated multi-level coding scheme:
14//!
15//! 1. **EOB (End of Block)** - Position of last non-zero coefficient
16//! 2. **Coefficient Base** - Base level (0-2) using multi-symbol coding
17//! 3. **Coefficient Base Range** - Extended range using Golomb-Rice codes
18//! 4. **DC Sign** - Sign of DC coefficient
19//! 5. **AC Signs** - Signs of AC coefficients
20//!
21//! # Scan Orders
22//!
23//! Coefficients are scanned in a specific order based on:
24//! - Transform class (2D, horizontal, vertical)
25//! - Transform size
26//!
27//! # Reference
28//!
29//! See AV1 Specification Section 5.11.39 for coefficient syntax and
30//! Section 7.12 for coefficient semantics.
31
32#![allow(dead_code)]
33#![allow(clippy::cast_possible_truncation)]
34#![allow(clippy::cast_sign_loss)]
35#![allow(clippy::bool_to_int_with_if)]
36#![allow(clippy::needless_bool_assign)]
37#![allow(clippy::if_not_else)]
38#![allow(clippy::cast_possible_wrap)]
39#![allow(clippy::match_same_arms)]
40#![allow(clippy::doc_markdown)]
41#![allow(clippy::explicit_iter_loop)]
42#![allow(clippy::cast_precision_loss)]
43#![allow(clippy::comparison_chain)]
44#![allow(clippy::cast_lossless)]
45
46use super::transform::{TxClass, TxSize, TxType};
47use crate::error::{CodecError, CodecResult};
48
49// =============================================================================
50// Constants
51// =============================================================================
52
/// Largest end-of-block (EOB) position for any transform size.
pub const MAX_EOB: usize = 4_096;

/// How many contexts are used for the EOB position.
pub const EOB_COEF_CONTEXTS: usize = 9;

/// How many contexts are used for the coefficient base.
pub const COEFF_BASE_CONTEXTS: usize = 42;

/// How many contexts are used for the coefficient base at the EOB.
pub const COEFF_BASE_EOB_CONTEXTS: usize = 3;

/// How many contexts are used for the DC sign.
pub const DC_SIGN_CONTEXTS: usize = 3;

/// How many contexts are used for the coefficient base range.
pub const COEFF_BR_CONTEXTS: usize = 21;

/// Upper limit for the coefficient base level.
pub const COEFF_BASE_RANGE_MAX: u32 = 3;

/// Golomb-Rice parameter used by coefficient coding.
pub const COEFF_BR_RICE_PARAM: u8 = 1;

/// Cutoff values for the base levels in coefficient coding.
pub const BASE_LEVEL_CUTOFFS: [u32; 5] = [0, 1, 2, 3, 4];

/// How many TX classes exist.
pub const TX_CLASSES: usize = 3;

/// Limit applied to the coefficient context position.
pub const COEFF_CONTEXT_MASK: usize = 63;

/// Largest number of neighbors considered when computing a context.
pub const MAX_NEIGHBORS: usize = 2;
88
89// =============================================================================
90// EOB Position Tables
91// =============================================================================
92
/// EOB offset per transform size, indexed by the transform-size ordinal.
pub const EOB_OFFSET: [u16; 19] = [
    // TX_4X4, TX_8X8, TX_16X16, TX_32X32, TX_64X64
    0, 16, 80, 336, 1_360,
    // TX_4X8, TX_8X4, TX_8X16, TX_16X8
    16, 16, 80, 80,
    // TX_16X32, TX_32X16, TX_32X64, TX_64X32
    336, 336, 1_360, 1_360,
    // TX_4X16, TX_16X4, TX_8X32, TX_32X8
    48, 48, 176, 176,
    // TX_16X64, TX_64X16
    592, 592,
];

/// Number of EOB extra bits per transform size, indexed like [`EOB_OFFSET`].
pub const EOB_EXTRA_BITS: [u8; 19] = [
    // TX_4X4, TX_8X8, TX_16X16, TX_32X32, TX_64X64
    0, 1, 2, 3, 4,
    // TX_4X8, TX_8X4, TX_8X16, TX_16X8
    1, 1, 2, 2,
    // TX_16X32, TX_32X16, TX_32X64, TX_64X32
    3, 3, 4, 4,
    // TX_4X16, TX_16X4, TX_8X32, TX_32X8
    2, 2, 3, 3,
    // TX_16X64, TX_64X16
    4, 4,
];

/// First EOB position covered by each EOB group.
pub const EOB_GROUP_START: [u16; 12] = [
    0, 1, 2, 3, 5, 9, 17, 33, 65, 129, 257, 513,
];

/// Mapping from EOB symbol to position (an identity table over `0..16`).
pub const EOB_TO_POS: [u16; 16] = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
];
144
145// =============================================================================
146// Level Context
147// =============================================================================
148
/// Neighbor-derived state used to choose a context for coefficient levels.
#[derive(Clone, Debug, Default)]
pub struct LevelContext {
    /// Sum of the magnitudes of the neighboring coefficients.
    pub mag: u32,
    /// How many neighbors are non-zero.
    pub count: u8,
    /// Context derived from the coefficient position.
    pub pos_ctx: u8,
}

impl LevelContext {
    /// Build a level context with every field zeroed.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            mag: 0,
            count: 0,
            pos_ctx: 0,
        }
    }

    /// Bucket the accumulated magnitude into a context in `0..=4`.
    ///
    /// The bucket boundaries are 64, 128, 256 and 512 (each inclusive on the
    /// lower bucket).
    #[must_use]
    pub fn mag_context(&self) -> u8 {
        match self.mag {
            0..=64 => 0,
            65..=128 => 1,
            129..=256 => 2,
            257..=512 => 3,
            _ => 4,
        }
    }

    /// Combine the magnitude bucket with the neighbor count (capped at 2).
    ///
    /// The result is `mag_context() * 3 + min(count, 2)`.
    #[must_use]
    pub fn context(&self) -> u8 {
        let capped_count = self.count.min(2);
        3 * self.mag_context() + capped_count
    }
}
194
195// =============================================================================
196// Coefficient Context
197// =============================================================================
198
199/// Context for coefficient parsing state.
200#[derive(Clone, Debug)]
201pub struct CoeffContext {
202    /// Transform size.
203    pub tx_size: TxSize,
204    /// Transform type.
205    pub tx_type: TxType,
206    /// Plane index (0=Y, 1=U, 2=V).
207    pub plane: u8,
208    /// Current scan position.
209    pub scan_pos: u16,
210    /// End of block position.
211    pub eob: u16,
212    /// Coefficient levels (dequantized).
213    pub levels: Vec<i32>,
214    /// Sign bits.
215    pub signs: Vec<bool>,
216    /// Accumulated left context.
217    pub left_ctx: Vec<u8>,
218    /// Accumulated above context.
219    pub above_ctx: Vec<u8>,
220    /// Block width in 4x4 units.
221    pub block_width: u8,
222    /// Block height in 4x4 units.
223    pub block_height: u8,
224}
225
226impl CoeffContext {
227    /// Create a new coefficient context.
228    #[must_use]
229    pub fn new(tx_size: TxSize, tx_type: TxType, plane: u8) -> Self {
230        let area = tx_size.area() as usize;
231        let width = (tx_size.width() / 4) as u8;
232        let height = (tx_size.height() / 4) as u8;
233
234        Self {
235            tx_size,
236            tx_type,
237            plane,
238            scan_pos: 0,
239            eob: 0,
240            levels: vec![0; area],
241            signs: vec![false; area],
242            left_ctx: vec![0; height as usize * 4],
243            above_ctx: vec![0; width as usize * 4],
244            block_width: width,
245            block_height: height,
246        }
247    }
248
249    /// Reset the context for a new block.
250    pub fn reset(&mut self) {
251        self.scan_pos = 0;
252        self.eob = 0;
253        self.levels.fill(0);
254        self.signs.fill(false);
255        self.left_ctx.fill(0);
256        self.above_ctx.fill(0);
257    }
258
259    /// Get the transform class.
260    #[must_use]
261    pub fn tx_class(&self) -> TxClass {
262        self.tx_type.tx_class()
263    }
264
265    /// Get scan position for a coefficient index.
266    #[must_use]
267    pub fn get_scan_position(&self, idx: usize) -> (u32, u32) {
268        let width = self.tx_size.width();
269        let row = (idx as u32) / width;
270        let col = (idx as u32) % width;
271        (row, col)
272    }
273
274    /// Get coefficient index from row and column.
275    #[must_use]
276    pub fn get_coeff_index(&self, row: u32, col: u32) -> usize {
277        (row * self.tx_size.width() + col) as usize
278    }
279
280    /// Compute level context for a position.
281    #[must_use]
282    pub fn compute_level_context(&self, pos: usize) -> LevelContext {
283        let width = self.tx_size.width() as usize;
284        let _height = self.tx_size.height() as usize;
285        let row = pos / width;
286        let col = pos % width;
287
288        let mut ctx = LevelContext::new();
289
290        // Get neighbors (left and above)
291        if col > 0 {
292            let left = self.levels[row * width + col - 1].unsigned_abs();
293            ctx.mag += left;
294            if left > 0 {
295                ctx.count += 1;
296            }
297        }
298
299        if row > 0 {
300            let above = self.levels[(row - 1) * width + col].unsigned_abs();
301            ctx.mag += above;
302            if above > 0 {
303                ctx.count += 1;
304            }
305        }
306
307        // Diagonal neighbor
308        if row > 0 && col > 0 {
309            let diag = self.levels[(row - 1) * width + col - 1].unsigned_abs();
310            ctx.mag += diag;
311        }
312
313        // Position context
314        ctx.pos_ctx = if row + col == 0 {
315            0
316        } else if row + col < 2 {
317            1
318        } else if row + col < 4 {
319            2
320        } else {
321            3
322        };
323
324        ctx
325    }
326
327    /// Get DC sign context.
328    #[must_use]
329    pub fn dc_sign_context(&self) -> u8 {
330        let left_sign = if !self.left_ctx.is_empty() {
331            (self.left_ctx[0] as i8 - 1).signum()
332        } else {
333            0
334        };
335
336        let above_sign = if !self.above_ctx.is_empty() {
337            (self.above_ctx[0] as i8 - 1).signum()
338        } else {
339            0
340        };
341
342        let sign_sum = left_sign + above_sign;
343
344        if sign_sum < 0 {
345            0
346        } else if sign_sum > 0 {
347            2
348        } else {
349            1
350        }
351    }
352
353    /// Set coefficient value at position.
354    pub fn set_coeff(&mut self, pos: usize, level: i32, sign: bool) {
355        if pos < self.levels.len() {
356            self.levels[pos] = if sign { -level } else { level };
357            self.signs[pos] = sign;
358        }
359    }
360
361    /// Get coefficient value at position.
362    #[must_use]
363    pub fn get_coeff(&self, pos: usize) -> i32 {
364        self.levels.get(pos).copied().unwrap_or(0)
365    }
366
367    /// Check if block has any non-zero coefficients.
368    #[must_use]
369    pub fn has_nonzero(&self) -> bool {
370        self.eob > 0
371    }
372
373    /// Get the number of non-zero coefficients.
374    #[must_use]
375    pub fn count_nonzero(&self) -> u16 {
376        self.levels.iter().filter(|&&l| l != 0).count() as u16
377    }
378}
379
380impl Default for CoeffContext {
381    fn default() -> Self {
382        Self::new(TxSize::Tx4x4, TxType::DctDct, 0)
383    }
384}
385
386// =============================================================================
387// Scan Order Generation
388// =============================================================================
389
/// Generate a diagonal scan order for a `width` x `height` block.
///
/// Anti-diagonals (`row + col == const`) are visited from the top-left
/// corner outwards; within a diagonal, rows are visited in increasing order.
/// Each entry is the raster index `row * width + col`.
///
/// Returns an empty vector when either dimension is zero. Raster indices are
/// narrowed to `u16`, so blocks with more than 65536 coefficients would wrap.
#[must_use]
pub fn generate_diagonal_scan(width: usize, height: usize) -> Vec<u16> {
    // A degenerate block has no coefficients; bail out before the
    // `width - 1` / `height - 1` arithmetic below can underflow.
    if width == 0 || height == 0 {
        return Vec::new();
    }

    let mut scan = Vec::with_capacity(width * height);

    for diag in 0..(width + height - 1) {
        // Rows on this diagonal: the column `diag - row` must stay below
        // `width`, and the row itself must stay below `height`.
        let first_row = diag.saturating_sub(width - 1);
        let last_row = diag.min(height - 1);

        scan.extend((first_row..=last_row).map(|row| (row * width + (diag - row)) as u16));
    }

    scan
}
413
/// Generate a horizontal scan order: raster indices visited row by row,
/// left to right within each row.
#[must_use]
pub fn generate_horizontal_scan(width: usize, height: usize) -> Vec<u16> {
    let mut order = Vec::with_capacity(width * height);

    let raster_indices =
        (0..height).flat_map(|row| (0..width).map(move |col| (row * width + col) as u16));
    order.extend(raster_indices);

    order
}
427
/// Generate a vertical scan order: raster indices visited column by column,
/// top to bottom within each column.
#[must_use]
pub fn generate_vertical_scan(width: usize, height: usize) -> Vec<u16> {
    let mut order = Vec::with_capacity(width * height);

    let raster_indices =
        (0..width).flat_map(|col| (0..height).map(move |row| (row * width + col) as u16));
    order.extend(raster_indices);

    order
}
441
442/// Get the scan order for a given transform.
443#[must_use]
444pub fn get_scan_order(tx_size: TxSize, tx_class: TxClass) -> Vec<u16> {
445    let width = tx_size.width() as usize;
446    let height = tx_size.height() as usize;
447
448    match tx_class {
449        TxClass::Class2D => generate_diagonal_scan(width, height),
450        TxClass::ClassHoriz => generate_horizontal_scan(width, height),
451        TxClass::ClassVert => generate_vertical_scan(width, height),
452    }
453}
454
455/// Scan order cache for common transform sizes.
456#[derive(Clone, Debug)]
457pub struct ScanOrderCache {
458    /// Cached scan orders indexed by [tx_size][tx_class].
459    cache: Vec<Vec<Vec<u16>>>,
460}
461
462impl ScanOrderCache {
463    /// Create a new scan order cache.
464    #[must_use]
465    pub fn new() -> Self {
466        let mut cache = Vec::with_capacity(19);
467
468        for tx_size_idx in 0..19 {
469            let tx_size = TxSize::from_u8(tx_size_idx as u8).unwrap_or_default();
470            let mut class_scans = Vec::with_capacity(3);
471
472            for tx_class_idx in 0..3 {
473                let tx_class = TxClass::from_u8(tx_class_idx as u8).unwrap_or_default();
474                class_scans.push(get_scan_order(tx_size, tx_class));
475            }
476
477            cache.push(class_scans);
478        }
479
480        Self { cache }
481    }
482
483    /// Get scan order from cache.
484    #[must_use]
485    pub fn get(&self, tx_size: TxSize, tx_class: TxClass) -> &[u16] {
486        let size_idx = tx_size as usize;
487        let class_idx = tx_class as usize;
488
489        if size_idx < self.cache.len() && class_idx < self.cache[size_idx].len() {
490            &self.cache[size_idx][class_idx]
491        } else {
492            &[]
493        }
494    }
495}
496
497impl Default for ScanOrderCache {
498    fn default() -> Self {
499        Self::new()
500    }
501}
502
503// =============================================================================
504// EOB Parsing Helpers
505// =============================================================================
506
507/// EOB (End of Block) position context.
508#[derive(Clone, Debug, Default)]
509pub struct EobContext {
510    /// EOB multi-context.
511    pub eob_multi: u8,
512    /// EOB extra bits.
513    pub eob_extra: u8,
514    /// Base context.
515    pub base_ctx: u8,
516}
517
518impl EobContext {
519    /// Create EOB context for a transform size.
520    #[must_use]
521    pub fn new(tx_size: TxSize) -> Self {
522        let size_idx = tx_size as usize;
523        let extra_bits = if size_idx < EOB_EXTRA_BITS.len() {
524            EOB_EXTRA_BITS[size_idx]
525        } else {
526            0
527        };
528
529        Self {
530            eob_multi: 0,
531            eob_extra: extra_bits,
532            base_ctx: 0,
533        }
534    }
535
536    /// Get the EOB context from position.
537    #[must_use]
538    pub fn get_eob_context(eob: u16) -> u8 {
539        if eob <= 1 {
540            0
541        } else if eob <= 2 {
542            1
543        } else if eob <= 4 {
544            2
545        } else if eob <= 8 {
546            3
547        } else if eob <= 16 {
548            4
549        } else if eob <= 32 {
550            5
551        } else if eob <= 64 {
552            6
553        } else if eob <= 128 {
554            7
555        } else {
556            8
557        }
558    }
559
560    /// Compute EOB from multi-symbol and extra bits.
561    #[must_use]
562    pub fn compute_eob(eob_multi: u8, eob_extra: u16) -> u16 {
563        let group_idx = eob_multi as usize;
564        if group_idx >= EOB_GROUP_START.len() {
565            return 0;
566        }
567
568        let base = EOB_GROUP_START[group_idx];
569        base + eob_extra
570    }
571}
572
573/// EOB point parsing state.
574#[derive(Clone, Copy, Debug, PartialEq, Eq)]
575pub enum EobPt {
576    /// No coefficients.
577    EobPt0 = 0,
578    /// 1 coefficient.
579    EobPt1 = 1,
580    /// 2 coefficients.
581    EobPt2 = 2,
582    /// 3-4 coefficients.
583    EobPt3To4 = 3,
584    /// 5-8 coefficients.
585    EobPt5To8 = 4,
586    /// 9-16 coefficients.
587    EobPt9To16 = 5,
588    /// 17-32 coefficients.
589    EobPt17To32 = 6,
590    /// 33-64 coefficients.
591    EobPt33To64 = 7,
592    /// 65-128 coefficients.
593    EobPt65To128 = 8,
594    /// 129-256 coefficients.
595    EobPt129To256 = 9,
596    /// 257-512 coefficients.
597    EobPt257To512 = 10,
598    /// 513-1024 coefficients.
599    EobPt513To1024 = 11,
600}
601
602impl EobPt {
603    /// Get the EOB point from an EOB *position* (the actual index of the last
604    /// non-zero coefficient in scan order, i.e. the count of decoded
605    /// coefficients).
606    ///
607    /// Use this when you already have the final EOB position — for example in
608    /// the encoder, after locating the last non-zero entry in a transform
609    /// block.
610    ///
611    /// For the decoder path, where the bitstream carries the EOB-multi *symbol*
612    /// (the index into the EOB-multi CDF, range 0..11), use
613    /// [`EobPt::from_symbol`] instead. The two values are not interchangeable:
614    /// symbol `s` corresponds to base position
615    /// [`EOB_GROUP_START`]`[s]`, which is `(1 << (s - 1)) + 1` for `s ≥ 2`.
616    #[must_use]
617    pub fn from_eob(eob: u16) -> Self {
618        match eob {
619            0 => Self::EobPt0,
620            1 => Self::EobPt1,
621            2 => Self::EobPt2,
622            3..=4 => Self::EobPt3To4,
623            5..=8 => Self::EobPt5To8,
624            9..=16 => Self::EobPt9To16,
625            17..=32 => Self::EobPt17To32,
626            33..=64 => Self::EobPt33To64,
627            65..=128 => Self::EobPt65To128,
628            129..=256 => Self::EobPt129To256,
629            257..=512 => Self::EobPt257To512,
630            _ => Self::EobPt513To1024,
631        }
632    }
633
634    /// Build an [`EobPt`] from the EOB-multi *symbol index* read from the
635    /// bitstream (range `0..=11`).
636    ///
637    /// The EOB-multi CDF in AV1 emits a symbol that directly identifies one of
638    /// the 12 EOB point classes. Each enum variant's discriminant equals the
639    /// symbol value it represents, so this is essentially the inverse of
640    /// [`EobPt::base_eob`] / [`EobPt::extra_bits`].
641    ///
642    /// Returns [`CodecError::InvalidBitstream`] if `symbol > 11`. The largest
643    /// EOB-multi CDF in the AV1 spec carries 16 symbols, so a malformed
644    /// bitstream can in principle place symbols 12..15 here — those values are
645    /// rejected rather than silently saturated, because they desynchronize the
646    /// EOB extra-bits read that follows.
647    ///
648    /// # Errors
649    ///
650    /// Returns [`CodecError::InvalidBitstream`] if `symbol >= 12`.
651    pub fn from_symbol(symbol: usize) -> CodecResult<Self> {
652        match symbol {
653            0 => Ok(Self::EobPt0),
654            1 => Ok(Self::EobPt1),
655            2 => Ok(Self::EobPt2),
656            3 => Ok(Self::EobPt3To4),
657            4 => Ok(Self::EobPt5To8),
658            5 => Ok(Self::EobPt9To16),
659            6 => Ok(Self::EobPt17To32),
660            7 => Ok(Self::EobPt33To64),
661            8 => Ok(Self::EobPt65To128),
662            9 => Ok(Self::EobPt129To256),
663            10 => Ok(Self::EobPt257To512),
664            11 => Ok(Self::EobPt513To1024),
665            other => Err(CodecError::InvalidBitstream(format!(
666                "EOB-multi symbol {other} out of range 0..=11"
667            ))),
668        }
669    }
670
671    /// Get the base EOB for this point.
672    #[must_use]
673    pub const fn base_eob(self) -> u16 {
674        match self {
675            Self::EobPt0 => 0,
676            Self::EobPt1 => 1,
677            Self::EobPt2 => 2,
678            Self::EobPt3To4 => 3,
679            Self::EobPt5To8 => 5,
680            Self::EobPt9To16 => 9,
681            Self::EobPt17To32 => 17,
682            Self::EobPt33To64 => 33,
683            Self::EobPt65To128 => 65,
684            Self::EobPt129To256 => 129,
685            Self::EobPt257To512 => 257,
686            Self::EobPt513To1024 => 513,
687        }
688    }
689
690    /// Get the number of extra bits for this point.
691    #[must_use]
692    pub const fn extra_bits(self) -> u8 {
693        match self {
694            Self::EobPt0 | Self::EobPt1 | Self::EobPt2 => 0,
695            Self::EobPt3To4 => 1,
696            Self::EobPt5To8 => 2,
697            Self::EobPt9To16 => 3,
698            Self::EobPt17To32 => 4,
699            Self::EobPt33To64 => 5,
700            Self::EobPt65To128 => 6,
701            Self::EobPt129To256 => 7,
702            Self::EobPt257To512 => 8,
703            Self::EobPt513To1024 => 9,
704        }
705    }
706}
707
708// =============================================================================
709// Coefficient Base Range
710// =============================================================================
711
712/// Coefficient base range context.
713#[derive(Clone, Copy, Debug, Default)]
714pub struct CoeffBaseRange {
715    /// Base level (0-4).
716    pub base_level: u8,
717    /// Range context.
718    pub range_ctx: u8,
719}
720
721impl CoeffBaseRange {
722    /// Get context for coefficient base range coding.
723    #[must_use]
724    pub fn get_br_context(level_ctx: &LevelContext, pos: usize, width: usize) -> u8 {
725        let row = pos / width;
726        let col = pos % width;
727
728        // Base context from position
729        let pos_ctx = if row + col == 0 {
730            0
731        } else if row + col < 2 {
732            7
733        } else {
734            14
735        };
736
737        // Combine with magnitude context
738        pos_ctx + level_ctx.mag_context().min(6)
739    }
740
741    /// Compute level from base and range.
742    #[must_use]
743    pub fn compute_level(base: u8, range: u16) -> u32 {
744        u32::from(base) + u32::from(range)
745    }
746}
747
748// =============================================================================
749// Dequantization Helpers
750// =============================================================================
751
/// Dequantize a single coefficient.
///
/// Computes `(|level| * dequant) >> shift` and re-applies the sign of
/// `level`. The arithmetic is carried out in 64 bits so that neither
/// `i32::MIN` nor a large `level * dequant` product can overflow; a result
/// that does not fit in `i32` saturates to `i32::MIN` / `i32::MAX`. Shift
/// amounts above 63 are treated as 63.
#[must_use]
pub fn dequantize_coeff(level: i32, dequant: i16, shift: u8) -> i32 {
    // |level| <= 2^31 and |dequant| <= 2^15, so the product fits in i64.
    let product = i64::from(level.unsigned_abs()) * i64::from(dequant);
    let scaled = product >> u32::from(shift).min(63);

    let signed = if level < 0 { -scaled } else { scaled };

    signed.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
}
764
765/// Dequantize all coefficients in a block.
766pub fn dequantize_block(coeffs: &mut [i32], dc_dequant: i16, ac_dequant: i16, shift: u8) {
767    if coeffs.is_empty() {
768        return;
769    }
770
771    // DC coefficient
772    coeffs[0] = dequantize_coeff(coeffs[0], dc_dequant, shift);
773
774    // AC coefficients
775    for coeff in coeffs.iter_mut().skip(1) {
776        *coeff = dequantize_coeff(*coeff, ac_dequant, shift);
777    }
778}
779
/// Look up the dequantization shift for a bit depth.
///
/// Returns 2 for 10-bit and 4 for 12-bit; 8-bit and every other value
/// yield 0.
#[must_use]
pub const fn get_dequant_shift(bit_depth: u8) -> u8 {
    match bit_depth {
        12 => 4,
        10 => 2,
        // 8-bit shares the zero shift with all unrecognised depths.
        _ => 0,
    }
}
790
791// =============================================================================
792// Coefficient Buffer
793// =============================================================================
794
795/// Buffer for storing and manipulating coefficient data.
796#[derive(Clone, Debug)]
797pub struct CoeffBuffer {
798    /// Coefficient storage.
799    coeffs: Vec<i32>,
800    /// Width of the buffer.
801    width: usize,
802    /// Height of the buffer.
803    height: usize,
804}
805
806impl CoeffBuffer {
807    /// Create a new coefficient buffer.
808    #[must_use]
809    pub fn new(width: usize, height: usize) -> Self {
810        Self {
811            coeffs: vec![0; width * height],
812            width,
813            height,
814        }
815    }
816
817    /// Create from transform size.
818    #[must_use]
819    pub fn from_tx_size(tx_size: TxSize) -> Self {
820        Self::new(tx_size.width() as usize, tx_size.height() as usize)
821    }
822
823    /// Get coefficient at position.
824    #[must_use]
825    pub fn get(&self, row: usize, col: usize) -> i32 {
826        if row < self.height && col < self.width {
827            self.coeffs[row * self.width + col]
828        } else {
829            0
830        }
831    }
832
833    /// Set coefficient at position.
834    pub fn set(&mut self, row: usize, col: usize, value: i32) {
835        if row < self.height && col < self.width {
836            self.coeffs[row * self.width + col] = value;
837        }
838    }
839
840    /// Clear all coefficients.
841    pub fn clear(&mut self) {
842        self.coeffs.fill(0);
843    }
844
845    /// Get the width (number of columns) of the buffer.
846    #[must_use]
847    pub const fn width(&self) -> usize {
848        self.width
849    }
850
851    /// Get the height (number of rows) of the buffer.
852    #[must_use]
853    pub const fn height(&self) -> usize {
854        self.height
855    }
856
857    /// Get mutable slice of coefficients.
858    pub fn as_mut_slice(&mut self) -> &mut [i32] {
859        &mut self.coeffs
860    }
861
862    /// Get immutable slice of coefficients.
863    #[must_use]
864    pub fn as_slice(&self) -> &[i32] {
865        &self.coeffs
866    }
867
868    /// Copy from scan order.
869    pub fn copy_from_scan(&mut self, src: &[i32], scan: &[u16]) {
870        for (i, &pos) in scan.iter().enumerate() {
871            if i < src.len() && (pos as usize) < self.coeffs.len() {
872                self.coeffs[pos as usize] = src[i];
873            }
874        }
875    }
876
877    /// Copy to scan order.
878    pub fn copy_to_scan(&self, dst: &mut [i32], scan: &[u16]) {
879        for (i, &pos) in scan.iter().enumerate() {
880            if i < dst.len() && (pos as usize) < self.coeffs.len() {
881                dst[i] = self.coeffs[pos as usize];
882            }
883        }
884    }
885}
886
887impl Default for CoeffBuffer {
888    fn default() -> Self {
889        Self::new(4, 4)
890    }
891}
892
893// =============================================================================
894// Neighbor Context Computation
895// =============================================================================
896
/// Get neighbor positions for context computation.
///
/// For the coefficient at raster index `pos` in a block `width`
/// coefficients wide, returns five `(raster_index, valid)` pairs in this
/// order: left, above, above-left, above-right, two-to-the-left. A neighbor
/// that would fall outside the block on the left, top or right edge is
/// reported as `(0, false)`.
///
/// `_height` is currently unused. When `width` is zero no position exists,
/// so every neighbor is reported as invalid.
#[must_use]
pub fn get_neighbor_positions(pos: usize, width: usize, _height: usize) -> [(usize, bool); 5] {
    let mut neighbors = [(0usize, false); 5];

    // Guard the divisions below against a degenerate block.
    if width == 0 {
        return neighbors;
    }

    let row = pos / width;
    let col = pos % width;

    // Left neighbor
    if col > 0 {
        neighbors[0] = (pos - 1, true);
    }

    // Above neighbor
    if row > 0 {
        neighbors[1] = (pos - width, true);
    }

    // Top-left diagonal
    if row > 0 && col > 0 {
        neighbors[2] = (pos - width - 1, true);
    }

    // Top-right diagonal
    if row > 0 && col + 1 < width {
        neighbors[3] = (pos - width + 1, true);
    }

    // Two positions left
    if col > 1 {
        neighbors[4] = (pos - 2, true);
    }

    neighbors
}
932
/// Compute a context from the levels of a coefficient's neighbors.
///
/// Every neighbor flagged valid whose position lies inside `levels`
/// contributes its magnitude to a running sum and, when non-zero, to a
/// count. The sum is bucketed at 64 / 128 / 256 / 512 into `0..=4`, and the
/// result is `bucket * 3 + min(count, 2)`.
///
/// The magnitude sum saturates, so extreme levels (up to `i32::MIN`) land
/// in the top bucket instead of overflowing.
#[must_use]
pub fn compute_context_from_neighbors(levels: &[i32], neighbors: &[(usize, bool); 5]) -> u8 {
    let (mag, count) = neighbors
        .iter()
        .filter(|&&(_, valid)| valid)
        .filter_map(|&(pos, _)| levels.get(pos))
        .map(|level| level.unsigned_abs())
        .fold((0u32, 0u8), |(mag, count), level| {
            // At most five neighbors, so `count` cannot overflow a `u8`.
            (mag.saturating_add(level), count + u8::from(level > 0))
        });

    // Context based on magnitude and count
    let mag_ctx: u8 = match mag {
        0..=64 => 0,
        65..=128 => 1,
        129..=256 => 2,
        257..=512 => 3,
        _ => 4,
    };

    mag_ctx * 3 + count.min(2)
}
964
965// =============================================================================
966// Sign Coding Helpers
967// =============================================================================
968
/// Derive the DC sign context from the left and above DC values.
///
/// The signs of both values are added: a negative total gives 0, a zero
/// total gives 1 and a positive total gives 2.
#[must_use]
pub fn compute_dc_sign_context(left_dc: i32, above_dc: i32) -> u8 {
    let balance = left_dc.signum() + above_dc.signum();

    match balance.signum() {
        -1 => 0,
        0 => 1,
        _ => 2,
    }
}
985
/// Record a decoded coefficient in the left/above context arrays.
///
/// The stored byte is `min(|level|, 63) + 1`, written to `left_ctx[row]`
/// and `above_ctx[col]`. An index outside its array is skipped.
pub fn update_level_context(
    left_ctx: &mut [u8],
    above_ctx: &mut [u8],
    level: i32,
    row: usize,
    col: usize,
) {
    let clamped_magnitude = level.unsigned_abs().min(63) as u8;
    let stored = clamped_magnitude + 1;

    if let Some(slot) = left_ctx.get_mut(row) {
        *slot = stored;
    }

    if let Some(slot) = above_ctx.get_mut(col) {
        *slot = stored;
    }
}
1004
1005// =============================================================================
1006// Coefficient Statistics
1007// =============================================================================
1008
/// Summary statistics over the coefficients of a block.
#[derive(Clone, Debug, Default)]
pub struct CoeffStats {
    /// How many coefficients are zero.
    pub zero_count: u32,
    /// How many coefficients have absolute level 1.
    pub level1_count: u32,
    /// How many coefficients have absolute level 2.
    pub level2_count: u32,
    /// How many coefficients have absolute level above 2.
    pub high_level_count: u32,
    /// Sum of all absolute levels.
    pub level_sum: u64,
    /// Largest absolute level seen.
    pub max_level: u32,
}

impl CoeffStats {
    /// Gather statistics from a slice of signed coefficients.
    #[must_use]
    pub fn from_coeffs(coeffs: &[i32]) -> Self {
        coeffs
            .iter()
            .map(|coeff| coeff.unsigned_abs())
            .fold(Self::default(), |mut acc, level| {
                // Exactly one of the four level buckets grows per coefficient.
                let bucket = match level {
                    0 => &mut acc.zero_count,
                    1 => &mut acc.level1_count,
                    2 => &mut acc.level2_count,
                    _ => &mut acc.high_level_count,
                };
                *bucket += 1;

                acc.level_sum += u64::from(level);
                acc.max_level = acc.max_level.max(level);
                acc
            })
    }

    /// Total number of non-zero coefficients (all buckets except zero).
    #[must_use]
    pub fn nonzero_count(&self) -> u32 {
        [self.level1_count, self.level2_count, self.high_level_count]
            .iter()
            .sum()
    }

    /// Mean absolute level over the non-zero coefficients, or `0.0` when
    /// there are none.
    #[must_use]
    pub fn average_level(&self) -> f64 {
        match self.nonzero_count() {
            0 => 0.0,
            nonzero => self.level_sum as f64 / f64::from(nonzero),
        }
    }
}
1066
1067// =============================================================================
1068// Tests
1069// =============================================================================
1070
#[cfg(test)]
mod tests {
    //! Unit tests for the coefficient contexts, scan orders, EOB helpers,
    //! dequantization helpers, coefficient buffer and coefficient statistics.

    use super::*;

    /// A fresh `LevelContext` starts zeroed; a populated one yields the
    /// expected magnitude and combined contexts.
    #[test]
    fn test_level_context() {
        let mut ctx = LevelContext::new();
        assert_eq!(ctx.mag, 0);
        assert_eq!(ctx.count, 0);

        ctx.mag = 100;
        ctx.count = 2;
        assert_eq!(ctx.mag_context(), 1);
        assert_eq!(ctx.context(), 1 * 3 + 2);
    }

    /// An 8x8 DCT context holds 64 levels and reports the 2D transform class.
    #[test]
    fn test_coeff_context_new() {
        let ctx = CoeffContext::new(TxSize::Tx8x8, TxType::DctDct, 0);
        assert_eq!(ctx.levels.len(), 64);
        assert_eq!(ctx.tx_class(), TxClass::Class2D);
    }

    /// `set_coeff` followed by `get_coeff` round-trips the value, applying
    /// the sign flag.
    #[test]
    fn test_coeff_context_set_get() {
        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
        ctx.set_coeff(5, 100, false);
        assert_eq!(ctx.get_coeff(5), 100);

        ctx.set_coeff(10, 50, true);
        assert_eq!(ctx.get_coeff(10), -50);
    }

    /// The 4x4 diagonal scan covers 16 positions and starts at 0, 1.
    #[test]
    fn test_diagonal_scan_4x4() {
        let scan = generate_diagonal_scan(4, 4);
        assert_eq!(scan.len(), 16);
        // First few elements should be diagonal
        assert_eq!(scan[0], 0); // (0,0)
        assert_eq!(scan[1], 1); // (0,1) - in row-major for 4x4
    }

    /// The 4x4 horizontal scan is the identity permutation.
    #[test]
    fn test_horizontal_scan() {
        let scan = generate_horizontal_scan(4, 4);
        assert_eq!(scan.len(), 16);
        for i in 0..16 {
            assert_eq!(scan[i], i as u16);
        }
    }

    /// The 4x4 vertical scan walks down the first column first.
    #[test]
    fn test_vertical_scan() {
        let scan = generate_vertical_scan(4, 4);
        assert_eq!(scan.len(), 16);
        assert_eq!(scan[0], 0);
        assert_eq!(scan[1], 4);
        assert_eq!(scan[2], 8);
        assert_eq!(scan[3], 12);
    }

    /// The scan-order cache returns a 16-entry scan for 4x4 / 2D.
    #[test]
    fn test_scan_order_cache() {
        let cache = ScanOrderCache::new();
        let scan = cache.get(TxSize::Tx4x4, TxClass::Class2D);
        assert_eq!(scan.len(), 16);
    }

    /// An 8x8 EOB context has a non-zero `eob_extra`.
    #[test]
    fn test_eob_context() {
        let ctx = EobContext::new(TxSize::Tx8x8);
        assert!(ctx.eob_extra > 0);
    }

    /// `EobPt::from_eob` classifies positions, and `EobPt5To8` reports its
    /// extra-bit count and base position.
    #[test]
    fn test_eob_pt() {
        assert_eq!(EobPt::from_eob(0), EobPt::EobPt0);
        assert_eq!(EobPt::from_eob(1), EobPt::EobPt1);
        assert_eq!(EobPt::from_eob(5), EobPt::EobPt5To8);
        assert_eq!(EobPt::from_eob(100), EobPt::EobPt65To128);

        assert_eq!(EobPt::EobPt5To8.extra_bits(), 2);
        assert_eq!(EobPt::EobPt5To8.base_eob(), 5);
    }

    /// Every symbol in 0..=11 maps to the variant whose discriminant equals
    /// the symbol.
    #[test]
    fn test_eob_pt_from_symbol_basic() {
        // Regression: the EOB-multi CDF emits a *symbol* (0..=11) that
        // identifies the EobPt class directly. Each variant's discriminant
        // equals its symbol value.
        let expected = [
            EobPt::EobPt0,
            EobPt::EobPt1,
            EobPt::EobPt2,
            EobPt::EobPt3To4,
            EobPt::EobPt5To8,
            EobPt::EobPt9To16,
            EobPt::EobPt17To32,
            EobPt::EobPt33To64,
            EobPt::EobPt65To128,
            EobPt::EobPt129To256,
            EobPt::EobPt257To512,
            EobPt::EobPt513To1024,
        ];
        for (symbol, &want) in expected.iter().enumerate() {
            let got = EobPt::from_symbol(symbol).expect("symbol in 0..=11 must succeed");
            assert_eq!(got, want, "symbol {symbol} should map to {want:?}");
            assert_eq!(
                got as usize, symbol,
                "discriminant of {got:?} should equal symbol {symbol}"
            );
        }
    }

    /// Symbols above 11 are rejected with an error that mentions "EOB-multi".
    #[test]
    fn test_eob_pt_from_symbol_rejects_out_of_range() {
        // The 16-symbol EOB-multi CDF can in principle yield symbols 12..15
        // from a malformed bitstream. Those must produce a decoding error
        // rather than silently saturate.
        for bad in 12..=15 {
            let err = EobPt::from_symbol(bad).expect_err("symbols 12..=15 must be rejected");
            let msg = format!("{err}");
            assert!(
                msg.contains("EOB-multi"),
                "error message should mention EOB-multi: {msg}"
            );
        }
        assert!(EobPt::from_symbol(usize::MAX).is_err());
    }

    /// `from_symbol` and `from_eob` agree only for symbols 0..=3.
    #[test]
    fn test_eob_pt_from_symbol_differs_from_from_eob() {
        // This is the heart of the bug: the old decoder confused these two.
        // For symbols >= 4 the two functions return different EobPt classes,
        // and crucially different `extra_bits` counts, so a decoder that
        // misuses `from_eob` on a symbol desynchronizes the bitstream.
        //
        // Symbols 0..=3 happen to coincide (EobPt0, EobPt1, EobPt2,
        // EobPt3To4) because for those values the position table and the
        // symbol enumeration agree.
        for symbol in 0..=3usize {
            let from_sym = EobPt::from_symbol(symbol).expect("low symbol always succeeds");
            let from_pos = EobPt::from_eob(symbol as u16);
            assert_eq!(from_sym, from_pos, "symbols 0..=3 coincidentally match");
        }
        for symbol in 4..=11usize {
            let from_sym = EobPt::from_symbol(symbol).expect("symbols 4..=11 must succeed");
            let from_pos = EobPt::from_eob(symbol as u16);
            assert_ne!(
                from_sym, from_pos,
                "symbol {symbol} must NOT coincide with from_eob (the bug)",
            );
            // And the extra_bits read from the bitstream would have been
            // wrong, which is the actual harm.
            assert_ne!(
                from_sym.extra_bits(),
                from_pos.extra_bits(),
                "symbol {symbol} extra_bits mismatch is the desync source",
            );
        }
    }

    /// For each legal symbol, the group base matches `EOB_GROUP_START` and
    /// the top of the group still classifies as the same `EobPt`.
    #[test]
    fn test_eob_pt_symbol_chain_covers_group_start() {
        // For every legal symbol, the decoder consumes `extra_bits` literal
        // bits and combines them with EOB_GROUP_START[symbol] to obtain the
        // final EOB position. Verify that this combination spans precisely
        // the range advertised by the EobPt variant.
        for symbol in 0..=11usize {
            let pt = EobPt::from_symbol(symbol).expect("legal symbol");
            assert_eq!(pt.base_eob(), EOB_GROUP_START[symbol]);

            let extra = pt.extra_bits();
            // Range size = 2^extra; symbol 0..=2 are degenerate (size 1).
            let range_size: u16 = 1u16 << extra;
            if symbol <= 2 {
                assert_eq!(range_size, 1);
            } else {
                // Final position fits in [base, base + range_size - 1].
                let max_offset = range_size - 1;
                let max_pos = EobPt::from_eob(pt.base_eob() + max_offset);
                assert_eq!(
                    max_pos, pt,
                    "EOB position at top of group must classify as same EobPt",
                );
            }
        }
    }

    /// With a zero shift, dequantization is a signed multiply.
    #[test]
    fn test_dequantize_coeff() {
        let level = 10;
        let dequant = 16;
        let result = dequantize_coeff(level, dequant, 0);
        assert_eq!(result, 160);

        let neg_result = dequantize_coeff(-level, dequant, 0);
        assert_eq!(neg_result, -160);
    }

    /// Block dequantization applies the DC factor to index 0 and the AC
    /// factor to the following coefficient.
    #[test]
    fn test_dequantize_block() {
        let mut coeffs = vec![10, 5, 5, 5, 5, 5, 5, 5];
        dequantize_block(&mut coeffs, 20, 10, 0);

        assert_eq!(coeffs[0], 200); // DC: 10 * 20
        assert_eq!(coeffs[1], 50); // AC: 5 * 10
    }

    /// The dequant shift is 0 / 2 / 4 for 8 / 10 / 12-bit depth.
    #[test]
    fn test_get_dequant_shift() {
        assert_eq!(get_dequant_shift(8), 0);
        assert_eq!(get_dequant_shift(10), 2);
        assert_eq!(get_dequant_shift(12), 4);
    }

    /// `CoeffBuffer` stores values by (row, col), defaults to zero and is
    /// zeroed again by `clear`.
    #[test]
    fn test_coeff_buffer() {
        let mut buf = CoeffBuffer::new(4, 4);
        buf.set(1, 2, 100);
        assert_eq!(buf.get(1, 2), 100);
        assert_eq!(buf.get(0, 0), 0);

        buf.clear();
        assert_eq!(buf.get(1, 2), 0);
    }

    /// A buffer built from an 8x8 transform size exposes 64 coefficients.
    #[test]
    fn test_coeff_buffer_from_tx_size() {
        let buf = CoeffBuffer::from_tx_size(TxSize::Tx8x8);
        assert_eq!(buf.as_slice().len(), 64);
    }

    /// Position 5 in a 4x4 block has valid left (4) and above (1) neighbors.
    #[test]
    fn test_neighbor_positions() {
        let neighbors = get_neighbor_positions(5, 4, 4);

        // Position 5 is row=1, col=1 in 4x4
        // Left neighbor should be valid at position 4
        assert!(neighbors[0].1);
        assert_eq!(neighbors[0].0, 4);

        // Above neighbor should be valid at position 1
        assert!(neighbors[1].1);
        assert_eq!(neighbors[1].0, 1);
    }

    /// The DC sign context is 0 / 1 / 2 for negative / neutral / positive
    /// neighbor sums.
    #[test]
    fn test_compute_dc_sign_context() {
        assert_eq!(compute_dc_sign_context(-5, -3), 0); // Both negative
        assert_eq!(compute_dc_sign_context(5, 3), 2); // Both positive
        assert_eq!(compute_dc_sign_context(-5, 3), 1); // Mixed
        assert_eq!(compute_dc_sign_context(0, 0), 1); // Zero
    }

    /// `CoeffStats::from_coeffs` fills the level histogram, the level sum,
    /// the maximum and the derived average.
    #[test]
    fn test_coeff_stats() {
        let coeffs = vec![0, 1, 2, 3, 0, 1, 5, 0];
        let stats = CoeffStats::from_coeffs(&coeffs);

        assert_eq!(stats.zero_count, 3);
        assert_eq!(stats.level1_count, 2);
        assert_eq!(stats.level2_count, 1);
        assert_eq!(stats.high_level_count, 2);
        assert_eq!(stats.max_level, 5);
        assert_eq!(stats.nonzero_count(), 5);

        // Sum of absolute levels: 1 + 2 + 3 + 1 + 5 = 12, averaged over the
        // five non-zero coefficients.
        assert_eq!(stats.level_sum, 12);
        assert!((stats.average_level() - 2.4).abs() < f64::EPSILON);
    }

    /// `CoeffStats` handles an empty buffer (no division by zero) and
    /// classifies negative coefficients by magnitude, including `i32::MIN`.
    #[test]
    fn test_coeff_stats_empty_and_negative() {
        let empty = CoeffStats::from_coeffs(&[]);
        assert_eq!(empty.zero_count, 0);
        assert_eq!(empty.nonzero_count(), 0);
        assert_eq!(empty.level_sum, 0);
        assert_eq!(empty.max_level, 0);
        // With no non-zero coefficients the average falls back to 0.0.
        assert!(empty.average_level().abs() < f64::EPSILON);

        // `i32::MIN` has no positive counterpart in `i32`; its magnitude must
        // still be counted without overflow.
        let min_magnitude = i32::MIN.unsigned_abs();
        let signed = CoeffStats::from_coeffs(&[-1, -2, i32::MIN]);
        assert_eq!(signed.zero_count, 0);
        assert_eq!(signed.level1_count, 1);
        assert_eq!(signed.level2_count, 1);
        assert_eq!(signed.high_level_count, 1);
        assert_eq!(signed.max_level, min_magnitude);
        assert_eq!(signed.level_sum, 3 + u64::from(min_magnitude));
    }

    /// A freshly built context yields a DC sign context within 0..=2.
    #[test]
    fn test_coeff_context_dc_sign() {
        let ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
        // Default context with empty context arrays
        // Results in neutral sign context (0 or 1 depending on implementation)
        let dc_ctx = ctx.dc_sign_context();
        // Context should be valid (0, 1, or 2)
        assert!(dc_ctx <= 2);
    }

    /// Non-zero neighbors produce a positive magnitude in the level context.
    #[test]
    fn test_coeff_context_level_context() {
        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
        ctx.levels[0] = 5;
        ctx.levels[1] = 3;
        ctx.levels[4] = 2;

        let level_ctx = ctx.compute_level_context(5);
        // Position 5 has neighbors at 4 (left) and 1 (above)
        assert!(level_ctx.mag > 0);
    }

    /// For the first three symbols with no extra bits, the EOB equals the
    /// symbol.
    #[test]
    fn test_eob_compute() {
        // Test EOB computation from multi-symbol and extra bits
        assert_eq!(EobContext::compute_eob(0, 0), 0);
        assert_eq!(EobContext::compute_eob(1, 0), 1);
        assert_eq!(EobContext::compute_eob(2, 0), 2);
    }

    /// The base-range context is positive for a busy neighborhood, and the
    /// level is the base plus the range.
    #[test]
    fn test_coeff_base_range() {
        let level_ctx = LevelContext {
            mag: 100,
            count: 2,
            pos_ctx: 1,
        };

        let br_ctx = CoeffBaseRange::get_br_context(&level_ctx, 5, 4);
        assert!(br_ctx > 0);

        let level = CoeffBaseRange::compute_level(2, 5);
        assert_eq!(level, 7);
    }

    /// Guards the values of the module-level constants.
    #[test]
    fn test_constants() {
        assert_eq!(MAX_EOB, 4096);
        assert_eq!(EOB_COEF_CONTEXTS, 9);
        assert_eq!(TX_CLASSES, 3);
    }

    /// `reset` clears both the EOB and the stored levels.
    #[test]
    fn test_coeff_context_reset() {
        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
        ctx.eob = 10;
        ctx.levels[5] = 100;

        ctx.reset();
        assert_eq!(ctx.eob, 0);
        assert_eq!(ctx.levels[5], 0);
    }

    /// `count_nonzero` counts negative levels as non-zero too.
    #[test]
    fn test_coeff_context_count_nonzero() {
        let mut ctx = CoeffContext::new(TxSize::Tx4x4, TxType::DctDct, 0);
        ctx.levels[0] = 5;
        ctx.levels[5] = 3;
        ctx.levels[10] = -2;

        assert_eq!(ctx.count_nonzero(), 3);
    }

    /// Every valid (size, class) pair yields a scan as long as the
    /// transform's area.
    #[test]
    fn test_scan_order_all_sizes() {
        // Test that scan order generation works for all sizes
        for size_idx in 0..19 {
            if let Some(tx_size) = TxSize::from_u8(size_idx) {
                for class_idx in 0..3 {
                    if let Some(tx_class) = TxClass::from_u8(class_idx) {
                        let scan = get_scan_order(tx_size, tx_class);
                        assert_eq!(scan.len(), tx_size.area() as usize);
                    }
                }
            }
        }
    }
}