Skip to main content

oximedia_codec/av1/
coeff_decode.rs

1//! AV1 transform coefficient decoding.
2//!
3//! This module handles the complete decoding of transform coefficients from
4//! the entropy-coded bitstream, including:
5//!
6//! - EOB (End of Block) position parsing
7//! - Coefficient level decoding using multi-level scheme
8//! - Coefficient sign decoding
9//! - Dequantization
10//! - Scan order application
11//!
12//! # Coefficient Decoding Process
13//!
14//! 1. **EOB Parsing** - Determine position of last non-zero coefficient
15//! 2. **Coefficient Levels** - Decode base levels and ranges
16//! 3. **DC Sign** - Decode sign of DC coefficient
17//! 4. **AC Signs** - Decode signs of AC coefficients
18//! 5. **Dequantization** - Apply quantization parameters
19//! 6. **Scan Order** - Convert from scan order to raster order
20//!
21//! # Context Modeling
22//!
23//! Coefficient decoding uses adaptive context models based on:
24//! - Position within the block
25//! - Magnitude of neighboring coefficients
26//! - Transform size and type
27
28#![forbid(unsafe_code)]
29#![allow(dead_code)]
30#![allow(clippy::doc_markdown)]
31#![allow(clippy::too_many_arguments)]
32#![allow(clippy::cast_possible_truncation)]
33#![allow(clippy::cast_sign_loss)]
34#![allow(clippy::cast_possible_wrap)]
35#![allow(clippy::similar_names)]
36#![allow(clippy::module_name_repetitions)]
37
38use super::coefficients::{
39    dequantize_block, get_dequant_shift, CoeffBuffer, CoeffContext, CoeffStats, EobContext, EobPt,
40    LevelContext, ScanOrderCache,
41};
42use super::entropy::SymbolReader;
43use super::entropy_tables::CdfContext;
44use super::quantization::QuantizationParams;
45use super::transform::{TxSize, TxType};
46use crate::error::CodecResult;
47
48// =============================================================================
49// Constants
50// =============================================================================
51
/// Largest level representable by the base-level symbol alone.
///
/// A decoded base level at or above this value is followed by a base-range
/// pass that adds the remainder of the magnitude.
pub const COEFF_BASE_MAX: u32 = 3;

/// Alphabet size of the base-range symbol.
///
/// The top symbol (`BR_CDF_SIZE - 1`) acts as an escape meaning "add this
/// much and read another symbol"; any smaller symbol ends the range.
pub const BR_CDF_SIZE: usize = 4;

/// Maximum Golomb-Rice parameter.
pub const MAX_BR_PARAM: u8 = 5;

/// Number of EOB offset bits, one entry per EOB-multi symbol (`0..=11`).
pub const EOB_OFFSET_BITS: [u8; 12] = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];

/// Scan order coeff skip threshold.
pub const COEFF_SKIP_THRESHOLD: u16 = 256;
66
67// =============================================================================
68// Coefficient Decoder
69// =============================================================================
70
71/// Decoder for transform coefficients.
72#[derive(Debug)]
73pub struct CoeffDecoder {
74    /// Symbol reader.
75    reader: SymbolReader,
76    /// CDF context for probability models.
77    cdf_context: CdfContext,
78    /// Scan order cache.
79    scan_cache: ScanOrderCache,
80    /// Quantization parameters.
81    quant_params: QuantizationParams,
82    /// Current bit depth.
83    bit_depth: u8,
84}
85
86impl CoeffDecoder {
87    /// Create a new coefficient decoder.
88    pub fn new(data: Vec<u8>, quant_params: QuantizationParams, bit_depth: u8) -> Self {
89        Self {
90            reader: SymbolReader::new(data),
91            cdf_context: CdfContext::new(),
92            scan_cache: ScanOrderCache::new(),
93            quant_params,
94            bit_depth,
95        }
96    }
97
98    /// Decode coefficients for a transform block.
99    pub fn decode_coefficients(
100        &mut self,
101        tx_size: TxSize,
102        tx_type: TxType,
103        plane: u8,
104        skip: bool,
105    ) -> CodecResult<CoeffBuffer> {
106        let mut ctx = CoeffContext::new(tx_size, tx_type, plane);
107
108        if skip {
109            // Skip blocks have all-zero coefficients
110            return Ok(CoeffBuffer::from_tx_size(tx_size));
111        }
112
113        // Decode EOB position
114        ctx.eob = self.decode_eob(tx_size, plane)?;
115
116        if ctx.eob == 0 {
117            // No coefficients
118            return Ok(CoeffBuffer::from_tx_size(tx_size));
119        }
120
121        // Get scan order (clone to avoid borrow issues)
122        let scan = self.scan_cache.get(tx_size, ctx.tx_class()).to_vec();
123
124        // Decode coefficient levels
125        self.decode_coeff_levels(&mut ctx, &scan)?;
126
127        // Decode signs
128        self.decode_signs(&mut ctx, &scan)?;
129
130        // Dequantize coefficients
131        self.dequantize_coefficients(&mut ctx, plane)?;
132
133        // Convert from scan order to raster order
134        let mut buffer = CoeffBuffer::from_tx_size(tx_size);
135        self.reorder_coefficients(&ctx, &scan, &mut buffer)?;
136
137        Ok(buffer)
138    }
139
140    /// Decode EOB (End of Block) position.
141    ///
142    /// The EOB-multi CDF emits a *symbol index* (the index of the EOB point
143    /// class, 0..=11), not the EOB position itself. The symbol then dictates
144    /// how many extra bits the decoder must consume to recover the precise
145    /// position within the class. Confusing these two values desynchronizes
146    /// the bitstream — see [`EobPt::from_symbol`] for the relationship.
147    fn decode_eob(&mut self, tx_size: TxSize, plane: u8) -> CodecResult<u16> {
148        let _eob_ctx = EobContext::new(tx_size);
149
150        // Read EOB multi-symbol
151        let ctx = (tx_size as usize * 3) + (plane as usize);
152        let eob_multi_cdf = self.cdf_context.get_eob_multi_cdf_mut(ctx);
153        let eob_multi_symbol = self.reader.read_symbol(eob_multi_cdf);
154
155        self.finalize_eob_from_symbol(eob_multi_symbol)
156    }
157
158    /// Compute the final EOB position from a previously-read EOB-multi
159    /// *symbol* (CDF index in `0..=11`).
160    ///
161    /// This is split out from [`Self::decode_eob`] so that regression tests
162    /// can inject a known symbol and observe the precise number of extra
163    /// bits consumed plus the final EOB position — the two outputs that
164    /// diverge between the correct (`from_symbol`) and the historically
165    /// buggy (`from_eob`) implementations of this step.
166    ///
167    /// # Errors
168    ///
169    /// Returns [`crate::error::CodecError::InvalidBitstream`] if `symbol`
170    /// is outside the legal range 0..=11.
171    fn finalize_eob_from_symbol(&mut self, symbol: usize) -> CodecResult<u16> {
172        if symbol == 0 {
173            return Ok(0); // No coefficients
174        }
175
176        // Convert the EOB-multi *symbol* (CDF index) into its EobPt class.
177        // NOTE: do not call `EobPt::from_eob` here — that helper takes an
178        // actual EOB position, not a symbol index. See the unit tests for
179        // a demonstration of how the two formulations diverge.
180        let eob_pt = EobPt::from_symbol(symbol)?;
181        let extra_bits = eob_pt.extra_bits();
182
183        let eob_extra = if extra_bits > 0 {
184            self.reader.read_literal(extra_bits)
185        } else {
186            0
187        };
188
189        let eob_multi = symbol as u8;
190        let eob = EobContext::compute_eob(eob_multi, eob_extra as u16);
191        Ok(eob)
192    }
193
194    /// Decode coefficient levels.
195    fn decode_coeff_levels(&mut self, ctx: &mut CoeffContext, scan: &[u16]) -> CodecResult<()> {
196        let eob = ctx.eob as usize;
197
198        // Decode in reverse scan order (from EOB to DC)
199        for scan_idx in (0..eob).rev() {
200            let pos = scan[scan_idx] as usize;
201            let level = self.decode_coeff_level(ctx, pos, scan_idx == eob - 1)?;
202            ctx.levels[pos] = level as i32;
203        }
204
205        Ok(())
206    }
207
208    /// Decode a single coefficient level.
209    fn decode_coeff_level(
210        &mut self,
211        ctx: &CoeffContext,
212        pos: usize,
213        is_eob: bool,
214    ) -> CodecResult<u32> {
215        let level_ctx = ctx.compute_level_context(pos);
216
217        // Decode base level (0-3)
218        let base_level = if is_eob {
219            // At EOB, coefficient is at least 1
220            1 + self.decode_coeff_base_eob(&level_ctx, ctx.plane)?
221        } else {
222            self.decode_coeff_base(&level_ctx, ctx.plane)?
223        };
224
225        if base_level >= COEFF_BASE_MAX {
226            // Decode additional range
227            let range = self.decode_coeff_base_range(&level_ctx, ctx.plane)?;
228            Ok(base_level + range)
229        } else {
230            Ok(base_level)
231        }
232    }
233
234    /// Decode coefficient base level.
235    fn decode_coeff_base(&mut self, level_ctx: &LevelContext, _plane: u8) -> CodecResult<u32> {
236        let context = level_ctx.context() as usize;
237        let cdf = self.cdf_context.get_coeff_base_cdf_mut(context);
238        Ok(self.reader.read_symbol(cdf) as u32)
239    }
240
241    /// Decode coefficient base at EOB.
242    fn decode_coeff_base_eob(&mut self, level_ctx: &LevelContext, _plane: u8) -> CodecResult<u32> {
243        let context = level_ctx.context() as usize;
244        let cdf = self.cdf_context.get_coeff_base_eob_cdf_mut(context);
245        Ok(self.reader.read_symbol(cdf) as u32)
246    }
247
248    /// Decode coefficient base range for high magnitude coefficients.
249    fn decode_coeff_base_range(
250        &mut self,
251        level_ctx: &LevelContext,
252        _plane: u8,
253    ) -> CodecResult<u32> {
254        let context = level_ctx.mag_context() as usize;
255        let mut total_range = 0u32;
256
257        // Multi-level Golomb-Rice coding
258        for _level in 0..5 {
259            let br_cdf = self.cdf_context.get_coeff_br_cdf_mut(context);
260            let br_symbol = self.reader.read_symbol(br_cdf);
261
262            if br_symbol < BR_CDF_SIZE - 1 {
263                total_range += br_symbol as u32;
264                break;
265            } else {
266                total_range += (BR_CDF_SIZE - 1) as u32;
267            }
268        }
269
270        Ok(total_range)
271    }
272
273    /// Decode signs for all non-zero coefficients.
274    fn decode_signs(&mut self, ctx: &mut CoeffContext, scan: &[u16]) -> CodecResult<()> {
275        let eob = ctx.eob as usize;
276
277        // Decode DC sign first (if DC is non-zero)
278        if ctx.levels[0] != 0 {
279            let dc_sign_ctx = ctx.dc_sign_context();
280            let cdf_slice = self.cdf_context.get_dc_sign_cdf_mut(dc_sign_ctx as usize);
281            // read_bool needs &mut [u16; 3], but we have &mut [u16]
282            // Copy to a fixed-size array
283            if cdf_slice.len() >= 3 {
284                let mut cdf_array = [cdf_slice[0], cdf_slice[1], cdf_slice[2]];
285                let sign = self.reader.read_bool(&mut cdf_array);
286                // Copy back updated CDF
287                cdf_slice[0] = cdf_array[0];
288                cdf_slice[1] = cdf_array[1];
289                cdf_slice[2] = cdf_array[2];
290                ctx.signs[0] = sign;
291                if sign {
292                    ctx.levels[0] = -ctx.levels[0];
293                }
294            }
295        }
296
297        // Decode AC signs
298        for scan_idx in 1..eob {
299            let pos = scan[scan_idx] as usize;
300            if ctx.levels[pos] != 0 {
301                // AC signs use equiprobable model
302                let sign = self.reader.read_bool_eq();
303                ctx.signs[pos] = sign;
304                if sign {
305                    ctx.levels[pos] = -ctx.levels[pos];
306                }
307            }
308        }
309
310        Ok(())
311    }
312
313    /// Dequantize coefficients.
314    fn dequantize_coefficients(&mut self, ctx: &mut CoeffContext, plane: u8) -> CodecResult<()> {
315        // Get dequant values
316        let dc_dequant = self
317            .quant_params
318            .get_dc_quant(plane as usize, self.bit_depth) as i16;
319        let ac_dequant = self
320            .quant_params
321            .get_ac_quant(plane as usize, self.bit_depth) as i16;
322        let shift = get_dequant_shift(self.bit_depth);
323
324        dequantize_block(&mut ctx.levels, dc_dequant, ac_dequant, shift);
325
326        Ok(())
327    }
328
329    /// Reorder coefficients from scan order to raster order.
330    fn reorder_coefficients(
331        &self,
332        ctx: &CoeffContext,
333        scan: &[u16],
334        buffer: &mut CoeffBuffer,
335    ) -> CodecResult<()> {
336        let eob = ctx.eob as usize;
337
338        for scan_idx in 0..eob {
339            let pos = scan[scan_idx] as usize;
340            if pos < ctx.levels.len() {
341                let level = ctx.levels[pos];
342                let (row, col) = ctx.get_scan_position(pos);
343                buffer.set(row as usize, col as usize, level);
344            }
345        }
346
347        Ok(())
348    }
349
350    /// Check if more data is available.
351    pub fn has_more_data(&self) -> bool {
352        self.reader.has_more_data()
353    }
354
355    /// Get current position.
356    pub fn position(&self) -> usize {
357        self.reader.position()
358    }
359}
360
361// =============================================================================
362// Coefficient Encoding
363// =============================================================================
364
365/// Encoder for transform coefficients.
366#[derive(Debug)]
367pub struct CoeffEncoder {
368    /// Symbol writer.
369    writer: super::entropy::SymbolWriter,
370    /// CDF context.
371    cdf_context: CdfContext,
372    /// Scan order cache.
373    scan_cache: ScanOrderCache,
374}
375
376impl CoeffEncoder {
377    /// Create a new coefficient encoder.
378    #[must_use]
379    pub fn new() -> Self {
380        Self {
381            writer: super::entropy::SymbolWriter::new(),
382            cdf_context: CdfContext::new(),
383            scan_cache: ScanOrderCache::new(),
384        }
385    }
386
387    /// Encode coefficients for a transform block.
388    pub fn encode_coefficients(
389        &mut self,
390        buffer: &CoeffBuffer,
391        tx_size: TxSize,
392        tx_type: TxType,
393        plane: u8,
394    ) -> CodecResult<()> {
395        let tx_class = tx_type.tx_class();
396        let scan = self.scan_cache.get(tx_size, tx_class).to_vec();
397
398        // Find EOB
399        let eob = self.find_eob(buffer, &scan);
400
401        // Encode EOB
402        self.encode_eob(eob, tx_size, plane)?;
403
404        if eob == 0 {
405            return Ok(());
406        }
407
408        // Convert to scan order
409        let mut levels = vec![0i32; eob as usize];
410        buffer.copy_to_scan(&mut levels, &scan[..eob as usize]);
411
412        // Encode levels
413        self.encode_levels(&levels, &scan, plane)?;
414
415        // Encode signs
416        self.encode_signs(&levels, &scan)?;
417
418        Ok(())
419    }
420
421    /// Find EOB position.
422    fn find_eob(&self, buffer: &CoeffBuffer, scan: &[u16]) -> u16 {
423        for (i, &pos) in scan.iter().enumerate().rev() {
424            let (row, col) = self.pos_to_rowcol(pos as usize, buffer);
425            if buffer.get(row, col) != 0 {
426                return (i + 1) as u16;
427            }
428        }
429        0
430    }
431
432    /// Convert position to row/col.
433    fn pos_to_rowcol(&self, pos: usize, buffer: &CoeffBuffer) -> (usize, usize) {
434        let width = buffer.width();
435        (pos / width, pos % width)
436    }
437
438    /// Encode EOB.
439    fn encode_eob(&mut self, eob: u16, tx_size: TxSize, plane: u8) -> CodecResult<()> {
440        let ctx = (tx_size as usize * 3) + (plane as usize);
441
442        if eob == 0 {
443            let cdf = self.cdf_context.get_eob_multi_cdf_mut(ctx);
444            self.writer.write_symbol(0, cdf);
445            return Ok(());
446        }
447
448        let eob_pt = EobPt::from_eob(eob);
449        let cdf = self.cdf_context.get_eob_multi_cdf_mut(ctx);
450        self.writer.write_symbol(eob_pt as usize, cdf);
451
452        // Write extra bits
453        let extra_bits = eob_pt.extra_bits();
454        if extra_bits > 0 {
455            let offset = eob - eob_pt.base_eob();
456            self.writer.write_literal(offset as u32, extra_bits);
457        }
458
459        Ok(())
460    }
461
462    /// Encode coefficient levels.
463    fn encode_levels(&mut self, levels: &[i32], scan: &[u16], plane: u8) -> CodecResult<()> {
464        let mut ctx = LevelContext::new();
465
466        for (scan_idx, &_pos) in scan.iter().enumerate().rev() {
467            let level = levels[scan_idx].unsigned_abs();
468
469            let base_level = level.min(COEFF_BASE_MAX);
470            let is_eob = scan_idx == levels.len() - 1;
471
472            if is_eob {
473                let cdf = self
474                    .cdf_context
475                    .get_coeff_base_eob_cdf_mut(ctx.context() as usize);
476                self.writer.write_symbol((base_level - 1) as usize, cdf);
477            } else {
478                let cdf = self
479                    .cdf_context
480                    .get_coeff_base_cdf_mut(ctx.context() as usize);
481                self.writer.write_symbol(base_level as usize, cdf);
482            }
483
484            if level >= COEFF_BASE_MAX {
485                self.encode_base_range(level - COEFF_BASE_MAX, &ctx, plane)?;
486            }
487
488            // Update context
489            ctx.mag += level;
490            if level > 0 {
491                ctx.count += 1;
492            }
493        }
494
495        Ok(())
496    }
497
498    /// Encode coefficient base range.
499    fn encode_base_range(&mut self, range: u32, ctx: &LevelContext, _plane: u8) -> CodecResult<()> {
500        let mut remaining = range;
501        let mag_ctx = ctx.mag_context() as usize;
502
503        for _level in 0..5 {
504            if remaining == 0 {
505                break;
506            }
507
508            let symbol = remaining.min((BR_CDF_SIZE - 1) as u32) as usize;
509            let cdf = self.cdf_context.get_coeff_br_cdf_mut(mag_ctx);
510            self.writer.write_symbol(symbol, cdf);
511
512            if symbol < BR_CDF_SIZE - 1 {
513                break;
514            }
515
516            remaining -= (BR_CDF_SIZE - 1) as u32;
517        }
518
519        Ok(())
520    }
521
522    /// Encode signs.
523    fn encode_signs(&mut self, levels: &[i32], scan: &[u16]) -> CodecResult<()> {
524        // DC sign
525        if !levels.is_empty() && levels[0] != 0 {
526            let dc_ctx = 1; // Simplified context
527            let cdf_slice = self.cdf_context.get_dc_sign_cdf_mut(dc_ctx);
528            if cdf_slice.len() >= 3 {
529                let mut cdf = [cdf_slice[0], cdf_slice[1], cdf_slice[2]];
530                self.writer.write_bool(levels[0] < 0, &mut cdf);
531                // Copy back updated CDF
532                cdf_slice[0] = cdf[0];
533                cdf_slice[1] = cdf[1];
534                cdf_slice[2] = cdf[2];
535            }
536        }
537
538        // AC signs
539        for (idx, &_pos) in scan.iter().enumerate().skip(1) {
540            if idx < levels.len() && levels[idx] != 0 {
541                // Equiprobable
542                let mut cdf = [16384u16, 32768, 0];
543                self.writer.write_bool(levels[idx] < 0, &mut cdf);
544            }
545        }
546
547        Ok(())
548    }
549
550    /// Finalize and get output.
551    #[must_use]
552    pub fn finish(self) -> Vec<u8> {
553        self.writer.finish()
554    }
555}
556
557impl Default for CoeffEncoder {
558    fn default() -> Self {
559        Self::new()
560    }
561}
562
563// =============================================================================
564// Batched Coefficient Decoder
565// =============================================================================
566
567/// Batched decoder for multiple coefficient blocks.
568pub struct BatchedCoeffDecoder {
569    /// Base decoder.
570    decoder: CoeffDecoder,
571    /// Decoded blocks cache.
572    blocks: Vec<CoeffBuffer>,
573}
574
575impl BatchedCoeffDecoder {
576    /// Create a new batched decoder.
577    pub fn new(data: Vec<u8>, quant_params: QuantizationParams, bit_depth: u8) -> Self {
578        Self {
579            decoder: CoeffDecoder::new(data, quant_params, bit_depth),
580            blocks: Vec::new(),
581        }
582    }
583
584    /// Decode multiple blocks.
585    pub fn decode_blocks(
586        &mut self,
587        specs: &[(TxSize, TxType, u8, bool)],
588    ) -> CodecResult<Vec<CoeffBuffer>> {
589        self.blocks.clear();
590
591        for &(tx_size, tx_type, plane, skip) in specs {
592            let buffer = self
593                .decoder
594                .decode_coefficients(tx_size, tx_type, plane, skip)?;
595            self.blocks.push(buffer);
596        }
597
598        Ok(std::mem::take(&mut self.blocks))
599    }
600
601    /// Get statistics for decoded blocks.
602    #[must_use]
603    pub fn get_statistics(&self) -> Vec<CoeffStats> {
604        self.blocks
605            .iter()
606            .map(|b| CoeffStats::from_coeffs(b.as_slice()))
607            .collect()
608    }
609}
610
611// =============================================================================
612// Coefficient Analysis
613// =============================================================================
614
615/// Analyze coefficient distribution.
616#[derive(Clone, Debug, Default)]
617pub struct CoeffAnalysis {
618    /// Total coefficients analyzed.
619    pub total_coeffs: u64,
620    /// Zero coefficients.
621    pub zero_count: u64,
622    /// Non-zero coefficients.
623    pub nonzero_count: u64,
624    /// DC coefficient sum.
625    pub dc_sum: i64,
626    /// AC coefficient sum.
627    pub ac_sum: i64,
628    /// Maximum absolute value.
629    pub max_abs: u32,
630}
631
632impl CoeffAnalysis {
633    /// Create new analysis.
634    #[must_use]
635    pub const fn new() -> Self {
636        Self {
637            total_coeffs: 0,
638            zero_count: 0,
639            nonzero_count: 0,
640            dc_sum: 0,
641            ac_sum: 0,
642            max_abs: 0,
643        }
644    }
645
646    /// Analyze a coefficient buffer.
647    pub fn analyze(&mut self, buffer: &CoeffBuffer) {
648        let coeffs = buffer.as_slice();
649        self.total_coeffs += coeffs.len() as u64;
650
651        if !coeffs.is_empty() {
652            self.dc_sum += i64::from(coeffs[0]);
653        }
654
655        for (i, &coeff) in coeffs.iter().enumerate() {
656            let abs_val = coeff.unsigned_abs();
657
658            if coeff == 0 {
659                self.zero_count += 1;
660            } else {
661                self.nonzero_count += 1;
662                self.max_abs = self.max_abs.max(abs_val);
663
664                if i > 0 {
665                    self.ac_sum += i64::from(coeff);
666                }
667            }
668        }
669    }
670
671    /// Get sparsity ratio (percentage of zeros).
672    #[must_use]
673    pub fn sparsity(&self) -> f64 {
674        if self.total_coeffs > 0 {
675            (self.zero_count as f64 / self.total_coeffs as f64) * 100.0
676        } else {
677            0.0
678        }
679    }
680
681    /// Get average DC value.
682    #[must_use]
683    pub fn avg_dc(&self) -> f64 {
684        if self.total_coeffs > 0 {
685            self.dc_sum as f64 / self.total_coeffs as f64
686        } else {
687            0.0
688        }
689    }
690}
691
692// =============================================================================
693// Tests
694// =============================================================================
695
696#[cfg(test)]
697mod tests {
698    use super::*;
699
700    fn create_test_quant_params() -> QuantizationParams {
701        QuantizationParams {
702            base_q_idx: 100,
703            delta_q_y_dc: 0,
704            delta_q_u_dc: 0,
705            delta_q_v_dc: 0,
706            delta_q_u_ac: 0,
707            delta_q_v_ac: 0,
708            using_qmatrix: false,
709            qm_y: 15,
710            qm_u: 15,
711            qm_v: 15,
712            delta_q_present: false,
713            delta_q_res: 0,
714        }
715    }
716
717    #[test]
718    fn test_coeff_decoder_creation() {
719        let data = vec![0u8; 128];
720        let quant = create_test_quant_params();
721        let decoder = CoeffDecoder::new(data, quant, 8);
722        assert!(decoder.has_more_data());
723    }
724
725    #[test]
726    fn test_coeff_encoder_creation() {
727        let encoder = CoeffEncoder::new();
728        let output = encoder.finish();
729        assert!(!output.is_empty() || output.is_empty());
730    }
731
732    #[test]
733    fn test_batched_decoder() {
734        let data = vec![0u8; 256];
735        let quant = create_test_quant_params();
736        let mut decoder = BatchedCoeffDecoder::new(data, quant, 8);
737
738        let specs = vec![
739            (TxSize::Tx4x4, TxType::DctDct, 0, false),
740            (TxSize::Tx8x8, TxType::DctDct, 0, false),
741        ];
742
743        // Decoding may fail with test data, but should not crash
744        let _ = decoder.decode_blocks(&specs);
745    }
746
747    #[test]
748    fn test_coeff_analysis() {
749        let mut analysis = CoeffAnalysis::new();
750        let mut buffer = CoeffBuffer::new(4, 4);
751
752        buffer.set(0, 0, 100); // DC
753        buffer.set(1, 1, 50); // AC
754        buffer.set(2, 2, -30); // AC
755
756        analysis.analyze(&buffer);
757
758        assert_eq!(analysis.total_coeffs, 16);
759        assert_eq!(analysis.nonzero_count, 3);
760        assert_eq!(analysis.zero_count, 13);
761        assert_eq!(analysis.max_abs, 100);
762    }
763
764    #[test]
765    fn test_coeff_analysis_sparsity() {
766        let mut analysis = CoeffAnalysis::new();
767        let buffer = CoeffBuffer::new(8, 8); // All zeros
768
769        analysis.analyze(&buffer);
770
771        assert_eq!(analysis.sparsity(), 100.0);
772    }
773
774    #[test]
775    fn test_constants() {
776        assert_eq!(COEFF_BASE_MAX, 3);
777        assert_eq!(BR_CDF_SIZE, 4);
778        assert_eq!(MAX_BR_PARAM, 5);
779    }
780
781    #[test]
782    fn test_eob_offset_bits() {
783        assert_eq!(EOB_OFFSET_BITS[0], 0);
784        assert_eq!(EOB_OFFSET_BITS[3], 1);
785        assert_eq!(EOB_OFFSET_BITS[11], 9);
786    }
787
788    #[test]
789    fn test_coeff_analysis_avg_dc() {
790        let mut analysis = CoeffAnalysis::new();
791        let mut buffer = CoeffBuffer::new(4, 4);
792        buffer.set(0, 0, 200);
793
794        analysis.analyze(&buffer);
795
796        // DC sum is 200, total coeffs is 16
797        assert!((analysis.avg_dc() - 12.5).abs() < 0.1);
798    }
799
800    #[test]
801    fn test_coeff_decoder_position() {
802        let data = vec![0u8; 128];
803        let quant = create_test_quant_params();
804        let decoder = CoeffDecoder::new(data, quant, 8);
805        assert!(decoder.position() <= 128);
806    }
807
808    /// Regression for the symbol-vs-position bug at the original line 154 of
809    /// this file.
810    ///
811    /// `finalize_eob_from_symbol` performs the exact computation that
812    /// `decode_eob` carries out once it has obtained the EOB-multi symbol.
813    /// By calling it directly with a known symbol we can observe the precise
814    /// EOB position produced and confirm that:
815    ///
816    ///   * for symbols 0..=3 the result is fully determined by
817    ///     `EOB_GROUP_START` (the buggy and fixed paths coincide here);
818    ///   * for symbols 4..=11 the result corresponds to the *correct* extra-
819    ///     bits count (`from_symbol(s).extra_bits()`), proving the fix; the
820    ///     buggy `from_eob(s).extra_bits()` would have read fewer bits and
821    ///     yielded a smaller EOB.
822    ///
823    /// The reader is fed all-ones bytes so `read_literal(k)` deterministically
824    /// returns `(1 << k) - 1`. This makes the fixed-vs-buggy divergence
825    /// observable in the EOB *value* itself, not just in subsequent bit
826    /// consumption.
827    #[test]
828    fn test_finalize_eob_from_symbol_reads_correct_extra_bits() {
829        use super::super::coefficients::{EobPt, EOB_GROUP_START};
830
831        // For each legal symbol s ∈ 1..=11, the fixed decoder should produce:
832        //     eob = EOB_GROUP_START[s] + ((1 << extra_bits) - 1)
833        // where extra_bits = EobPt::from_symbol(s).extra_bits().
834        //
835        // The historically buggy decoder used EobPt::from_eob(s as u16)
836        // instead, which yields a *smaller* extra_bits count for s >= 4
837        // (e.g. s=4 → buggy extra=1 vs fixed extra=2), so on all-ones input
838        // the EOB value diverges. We assert the FIXED value and additionally
839        // assert it is strictly larger than what the bug would have produced.
840        for symbol in 1..=11usize {
841            // 64 bytes of 0xFF is enough for the largest extra_bits (9).
842            let data = vec![0xFFu8; 64];
843            let quant = create_test_quant_params();
844            let mut decoder = CoeffDecoder::new(data, quant, 8);
845
846            let eob = decoder
847                .finalize_eob_from_symbol(symbol)
848                .expect("legal symbol must succeed");
849
850            let fixed_extra_bits = EobPt::from_symbol(symbol)
851                .expect("legal symbol")
852                .extra_bits();
853            let expected_extra: u16 = if fixed_extra_bits == 0 {
854                0
855            } else {
856                (1u16 << fixed_extra_bits) - 1
857            };
858            let expected_eob = EOB_GROUP_START[symbol] + expected_extra;
859            assert_eq!(
860                eob, expected_eob,
861                "symbol {symbol}: fixed decoder must read {fixed_extra_bits} \
862                 extra bits (all 1s → {expected_extra}) and produce \
863                 EOB_GROUP_START[{symbol}]+{expected_extra}={expected_eob}",
864            );
865
866            // Divergence check vs. the original bug.
867            let buggy_extra_bits = EobPt::from_eob(symbol as u16).extra_bits();
868            if symbol >= 4 {
869                // For s>=4 the buggy decoder reads fewer extra bits, so the
870                // EOB it would compute is strictly smaller.
871                let buggy_extra: u16 = if buggy_extra_bits == 0 {
872                    0
873                } else {
874                    (1u16 << buggy_extra_bits) - 1
875                };
876                let buggy_eob = EOB_GROUP_START[symbol] + buggy_extra;
877                assert!(
878                    eob > buggy_eob,
879                    "symbol {symbol}: fixed eob {eob} must exceed buggy eob \
880                     {buggy_eob} (proves the regression test discriminates)",
881                );
882            } else {
883                // For s<=3 the two paths produce identical EobPt classes.
884                assert_eq!(
885                    fixed_extra_bits, buggy_extra_bits,
886                    "symbol {symbol} (<=3): fixed and buggy extra_bits must \
887                     coincide",
888                );
889            }
890        }
891    }
892
/// Regression: symbol 0 means "no coefficients", so
/// `finalize_eob_from_symbol(0)` must succeed and report EOB = 0.
///
/// Only the returned EOB value is asserted here; the short-circuit is
/// expected to leave the bit window untouched, but this test does not
/// observe the reader position.
#[test]
fn test_finalize_eob_from_symbol_zero_returns_zero() {
    // An all-ones payload: any accidental extra-bit read would yield a
    // non-zero EOB.
    let mut decoder = CoeffDecoder::new(vec![0xFFu8; 16], create_test_quant_params(), 8);

    let outcome = decoder.finalize_eob_from_symbol(0);
    let eob = outcome.expect("symbol 0 always succeeds");

    assert_eq!(eob, 0, "symbol 0 must short-circuit to EOB = 0");
}
906
/// Regression: on a malformed bitstream the 16-symbol EOB-multi CDF can in
/// principle produce symbols 12..=15. Each of them must be reported as a
/// decoder error rather than being silently saturated, which would corrupt
/// subsequent reads.
#[test]
fn test_finalize_eob_from_symbol_rejects_out_of_range() {
    let mut decoder = CoeffDecoder::new(vec![0xFFu8; 16], create_test_quant_params(), 8);

    // The same decoder instance is reused for every rejected symbol.
    for bad_symbol in 12usize..=15 {
        let msg = decoder
            .finalize_eob_from_symbol(bad_symbol)
            .expect_err("symbols >=12 must be rejected")
            .to_string();

        assert!(
            msg.contains("EOB-multi"),
            "error message should mention EOB-multi: {msg}",
        );
    }
}
927
928    // -------------------------------------------------------------------------
929    // Regression tests: CoeffEncoder::pos_to_rowcol non-square TX blocks
930    // -------------------------------------------------------------------------
931    //
932    // Before the fix, width was derived as `(slice.len() as f64).sqrt() as
933    // usize`, which is only correct for square TX sizes.  For any non-square
934    // TX size the sqrt of the total coefficient count does not equal the block
935    // width, producing wrong (row, col) pairs.  These tests verify that the
936    // fix (using `CoeffBuffer::width()`) gives correct results for several
937    // representative non-square sizes.
938
/// `pos_to_rowcol` on a Tx4x8 block (width 4, height 8, 32 coefficients).
///
/// The old sqrt-based width, `(32f64).sqrt() as usize`, truncates to 5,
/// whereas the real block width is 4.
#[test]
fn test_pos_to_rowcol_tx4x8() {
    let encoder = CoeffEncoder::new();
    let buf = CoeffBuffer::from_tx_size(TxSize::Tx4x8);
    assert_eq!(buf.width(), 4);
    assert_eq!(buf.height(), 8);

    // (raster position, expected (row, col), failure label)
    let cases = [
        (0, (0, 0), "pos 0 -> (0,0)"),
        (3, (0, 3), "pos 3 -> (0,3)"),
        (4, (1, 0), "pos 4 -> (1,0)"),
        (7, (1, 3), "pos 7 -> (1,3)"),
    ];
    for (pos, expected, label) in cases {
        assert_eq!(encoder.pos_to_rowcol(pos, &buf), expected, "{label}");
    }
}
953
/// `pos_to_rowcol` on a Tx8x4 block (width 8, height 4, 32 coefficients).
///
/// The old sqrt-based width, `(32f64).sqrt() as usize`, truncates to 5,
/// whereas the real block width is 8.
#[test]
fn test_pos_to_rowcol_tx8x4() {
    let encoder = CoeffEncoder::new();
    let buf = CoeffBuffer::from_tx_size(TxSize::Tx8x4);
    assert_eq!(buf.width(), 8);
    assert_eq!(buf.height(), 4);

    // (raster position, expected (row, col), failure label); the
    // second-row probe is checked first.
    let cases = [
        (9, (1, 1), "pos 9 -> (1,1)"),
        (0, (0, 0), "pos 0 -> (0,0)"),
        (7, (0, 7), "pos 7 -> (0,7)"),
        (8, (1, 0), "pos 8 -> (1,0)"),
    ];
    for (pos, expected, label) in cases {
        assert_eq!(encoder.pos_to_rowcol(pos, &buf), expected, "{label}");
    }
}
968
/// `pos_to_rowcol` on a Tx4x16 block (width 4, height 16, 64 coefficients).
///
/// Here the old sqrt-based width is an exact integer, sqrt(64) = 8, yet it
/// is still wrong: the real block width is 4.
#[test]
fn test_pos_to_rowcol_tx4x16() {
    let encoder = CoeffEncoder::new();
    let buf = CoeffBuffer::from_tx_size(TxSize::Tx4x16);
    assert_eq!(buf.width(), 4);
    assert_eq!(buf.height(), 16);

    // (raster position, expected (row, col), failure label); the deep-row
    // probe is checked first.
    let cases = [
        (17, (4, 1), "pos 17 -> (4,1)"),
        (0, (0, 0), "pos 0 -> (0,0)"),
        (3, (0, 3), "pos 3 -> (0,3)"),
        (4, (1, 0), "pos 4 -> (1,0)"),
    ];
    for (pos, expected, label) in cases {
        assert_eq!(encoder.pos_to_rowcol(pos, &buf), expected, "{label}");
    }
}
983
/// Square sanity check: for Tx4x4 (width 4, height 4) the sqrt-based and
/// width-based derivations agree, so every raster position must still map
/// to `(pos / 4, pos % 4)`.
#[test]
fn test_pos_to_rowcol_square_unchanged() {
    let encoder = CoeffEncoder::new();
    let buf = CoeffBuffer::from_tx_size(TxSize::Tx4x4);
    assert_eq!(buf.width(), 4);
    assert_eq!(buf.height(), 4);

    // Walk the block row by row; this visits pos = 0..16 in ascending order.
    for want_row in 0..4usize {
        for want_col in 0..4usize {
            let pos = want_row * 4 + want_col;
            let (row, col) = encoder.pos_to_rowcol(pos, &buf);
            assert_eq!(row, want_row, "Tx4x4 pos {pos}: row mismatch");
            assert_eq!(col, want_col, "Tx4x4 pos {pos}: col mismatch");
        }
    }
}
999}