Skip to main content

oximedia_codec/av1/
coeff_encode.rs

1//! AV1 transform coefficient encoding.
2//!
3//! This module handles encoding of quantized transform coefficients with:
4//!
5//! - Scan order selection (diagonal, horizontal, vertical)
6//! - EOB (End of Block) encoding
7//! - Coefficient level encoding with context
8//! - Sign encoding
9//! - Quantization integration
10//!
11//! # Coefficient Encoding Process
12//!
13//! 1. Forward transform (DCT/ADST)
14//! 2. Quantization
15//! 3. Find EOB (last non-zero coefficient)
16//! 4. Scan in appropriate order
17//! 5. Encode levels and signs using arithmetic coder
18//!
19//! # References
20//!
21//! - AV1 Specification Section 5.11: Transform Coefficient Syntax
22
23#![forbid(unsafe_code)]
24#![allow(dead_code)]
25#![allow(clippy::cast_possible_truncation)]
26#![allow(clippy::cast_precision_loss)]
27#![allow(clippy::cast_sign_loss)]
28#![allow(clippy::similar_names)]
29#![allow(clippy::too_many_arguments)]
30
31use super::entropy_encoder::SymbolEncoder;
32use super::quantization::QuantizationParams;
33use super::transform::{TxClass, TxSize, TxType};
34
35// =============================================================================
36// Constants
37// =============================================================================
38
39/// Maximum coefficient value after quantization.
40const MAX_COEFF_LEVEL: i32 = 255;
41
42/// Number of coefficient context types.
43const COEFF_CONTEXTS: usize = 4;
44
45/// Number of EOB context types.
46const EOB_CONTEXTS: usize = 7;
47
48/// Coefficient level map contexts.
49const LEVEL_CONTEXTS: usize = 21;
50
51// =============================================================================
52// Scan Order
53// =============================================================================
54
55/// Scan order type for coefficient encoding.
56#[derive(Clone, Copy, Debug, PartialEq, Eq)]
57pub enum ScanOrder {
58    /// Default zig-zag diagonal scan.
59    Default = 0,
60    /// Horizontal scan (for vertical transforms).
61    Horizontal = 1,
62    /// Vertical scan (for horizontal transforms).
63    Vertical = 2,
64}
65
66impl ScanOrder {
67    /// Get scan order for given transform type.
68    #[must_use]
69    pub const fn from_tx_type(tx_type: TxType) -> Self {
70        match tx_type.tx_class() {
71            TxClass::Class2D => Self::Default,
72            TxClass::ClassHoriz => Self::Horizontal,
73            TxClass::ClassVert => Self::Vertical,
74        }
75    }
76}
77
78/// Generate scan order indices for a transform block.
79#[must_use]
80pub fn generate_scan_order(tx_size: TxSize, scan_order: ScanOrder) -> Vec<(usize, usize)> {
81    let w = tx_size.width() as usize;
82    let h = tx_size.height() as usize;
83    let mut indices = Vec::with_capacity(w * h);
84
85    match scan_order {
86        ScanOrder::Default => {
87            // Diagonal zig-zag scan
88            for diag in 0..(w + h - 1) {
89                if diag % 2 == 0 {
90                    // Even diagonal: go up-right
91                    let start_col = diag.min(w - 1);
92                    let start_row = diag.saturating_sub(w - 1);
93
94                    let mut col = start_col;
95                    let mut row = start_row;
96
97                    while col < w && row < h {
98                        if col <= diag && row <= diag {
99                            indices.push((row, col));
100                        }
101                        if col == 0 {
102                            break;
103                        }
104                        col -= 1;
105                        row += 1;
106                    }
107                } else {
108                    // Odd diagonal: go down-left
109                    let start_row = diag.min(h - 1);
110                    let start_col = diag.saturating_sub(h - 1);
111
112                    let mut row = start_row;
113                    let mut col = start_col;
114
115                    while row < h && col < w {
116                        if row <= diag && col <= diag {
117                            indices.push((row, col));
118                        }
119                        if row == 0 {
120                            break;
121                        }
122                        row -= 1;
123                        col += 1;
124                    }
125                }
126            }
127        }
128        ScanOrder::Horizontal => {
129            // Row-major scan
130            for y in 0..h {
131                for x in 0..w {
132                    indices.push((y, x));
133                }
134            }
135        }
136        ScanOrder::Vertical => {
137            // Column-major scan
138            for x in 0..w {
139                for y in 0..h {
140                    indices.push((y, x));
141                }
142            }
143        }
144    }
145
146    indices
147}
148
149// =============================================================================
150// Coefficient Encoder
151// =============================================================================
152
153/// Transform coefficient encoder.
154#[derive(Clone, Debug)]
155pub struct CoeffEncoder {
156    /// Symbol encoder.
157    encoder: SymbolEncoder,
158    /// Quantization parameters.
159    qparams: QuantizationParams,
160    /// Logical bits encoded (for tracking when buffer hasn't flushed).
161    bits_encoded: usize,
162}
163
164impl Default for CoeffEncoder {
165    fn default() -> Self {
166        Self::new()
167    }
168}
169
170impl CoeffEncoder {
171    /// Create a new coefficient encoder.
172    #[must_use]
173    pub fn new() -> Self {
174        Self {
175            encoder: SymbolEncoder::new(),
176            qparams: QuantizationParams::default(),
177            bits_encoded: 0,
178        }
179    }
180
181    /// Set quantization parameters.
182    pub fn set_qparams(&mut self, qparams: QuantizationParams) {
183        self.qparams = qparams;
184    }
185
186    /// Encode transform coefficients.
187    ///
188    /// # Arguments
189    ///
190    /// * `coeffs` - Quantized coefficients in raster order
191    /// * `tx_size` - Transform size
192    /// * `tx_type` - Transform type
193    /// * `plane` - Plane index (0=Y, 1=U, 2=V)
194    ///
195    /// # Returns
196    ///
197    /// Number of bits used
198    pub fn encode_coeffs(
199        &mut self,
200        coeffs: &[i32],
201        tx_size: TxSize,
202        tx_type: TxType,
203        plane: u8,
204    ) -> usize {
205        let start_len = self.encoder.buffer().len();
206        self.bits_encoded = 0;
207
208        // Find EOB (last non-zero coefficient)
209        let eob = self.find_eob(coeffs);
210
211        if eob == 0 {
212            // All zeros - encode skip
213            self.encoder.encode_bool(true);
214            self.bits_encoded += 1;
215            let buffer_bits = 8 * (self.encoder.buffer().len() - start_len);
216            return buffer_bits.max(self.bits_encoded);
217        }
218
219        // Not skip
220        self.encoder.encode_bool(false);
221        self.bits_encoded += 1;
222
223        // Encode EOB
224        self.encode_eob(eob, tx_size);
225
226        // Get scan order
227        let scan_order = ScanOrder::from_tx_type(tx_type);
228        let scan = generate_scan_order(tx_size, scan_order);
229
230        // Encode coefficients in scan order
231        self.encode_coeffs_scan(coeffs, &scan[..eob], tx_size, plane);
232
233        let buffer_bits = 8 * (self.encoder.buffer().len() - start_len);
234        buffer_bits.max(self.bits_encoded)
235    }
236
237    /// Find end of block (last non-zero coefficient position).
238    fn find_eob(&self, coeffs: &[i32]) -> usize {
239        for (i, &c) in coeffs.iter().enumerate().rev() {
240            if c != 0 {
241                return i + 1;
242            }
243        }
244        0
245    }
246
247    /// Encode EOB position.
248    fn encode_eob(&mut self, eob: usize, tx_size: TxSize) {
249        let max_eob = tx_size.max_eob() as usize;
250
251        // Simple EOB encoding (could be improved with better context)
252        let eob_bits = (max_eob.next_power_of_two().trailing_zeros()) as u8;
253        self.encoder.encode_literal(eob as u32, eob_bits);
254    }
255
256    /// Encode coefficients in scan order.
257    fn encode_coeffs_scan(
258        &mut self,
259        coeffs: &[i32],
260        scan: &[(usize, usize)],
261        tx_size: TxSize,
262        _plane: u8,
263    ) {
264        let stride = tx_size.width() as usize;
265
266        for &(row, col) in scan {
267            let idx = row * stride + col;
268            if idx >= coeffs.len() {
269                break;
270            }
271
272            let coeff = coeffs[idx];
273            self.encode_coeff(coeff);
274        }
275    }
276
277    /// Encode a single coefficient.
278    fn encode_coeff(&mut self, coeff: i32) {
279        let level = coeff.abs();
280
281        if level == 0 {
282            // Zero coefficient
283            self.encoder.encode_literal(0, 8);
284            return;
285        }
286
287        // Encode level (simplified - no context modeling)
288        let level_clamped = level.min(MAX_COEFF_LEVEL) as u32;
289        self.encoder.encode_literal(level_clamped, 8);
290
291        // Encode sign
292        self.encoder.encode_bool(coeff < 0);
293    }
294
295    /// Get encoded output.
296    #[must_use]
297    pub fn finish(&mut self) -> Vec<u8> {
298        self.encoder.finish()
299    }
300
301    /// Reset encoder state.
302    pub fn reset(&mut self) {
303        self.encoder.reset();
304    }
305}
306
307// =============================================================================
308// Quantization
309// =============================================================================
310
311/// Quantize transform coefficients.
312#[must_use]
313pub fn quantize_coeffs(coeffs: &[i32], qp: u8, tx_size: TxSize) -> Vec<i32> {
314    let q_step = compute_q_step(qp);
315    let area = tx_size.area() as usize;
316    let mut quantized = vec![0i32; area.min(coeffs.len())];
317
318    for (i, &c) in coeffs.iter().take(area).enumerate() {
319        quantized[i] = quantize_coeff(c, q_step);
320    }
321
322    quantized
323}
324
325/// Dequantize transform coefficients.
326#[must_use]
327pub fn dequantize_coeffs(coeffs: &[i32], qp: u8, tx_size: TxSize) -> Vec<i32> {
328    let q_step = compute_q_step(qp);
329    let area = tx_size.area() as usize;
330    let mut dequantized = vec![0i32; area.min(coeffs.len())];
331
332    for (i, &c) in coeffs.iter().take(area).enumerate() {
333        dequantized[i] = dequantize_coeff(c, q_step);
334    }
335
336    dequantized
337}
338
339/// Compute quantization step from QP.
340#[must_use]
341fn compute_q_step(qp: u8) -> i32 {
342    // Simplified: q_step = 2^(qp/6)
343    let qp_f = f32::from(qp);
344    (2.0_f32.powf(qp_f / 6.0)) as i32
345}
346
347/// Quantize a single coefficient.
348#[must_use]
349fn quantize_coeff(coeff: i32, q_step: i32) -> i32 {
350    if q_step == 0 {
351        return coeff;
352    }
353
354    let sign = coeff.signum();
355    let abs_coeff = coeff.abs();
356    let quantized = (abs_coeff + q_step / 2) / q_step;
357
358    sign * quantized.min(MAX_COEFF_LEVEL)
359}
360
361/// Dequantize a single coefficient.
362#[must_use]
363fn dequantize_coeff(coeff: i32, q_step: i32) -> i32 {
364    coeff * q_step
365}
366
367// =============================================================================
368// Coefficient Statistics
369// =============================================================================
370
371/// Statistics for coefficient encoding.
372#[derive(Clone, Debug, Default)]
373pub struct CoeffStats {
374    /// Total number of coefficients.
375    pub total_coeffs: usize,
376    /// Number of zero coefficients.
377    pub zero_coeffs: usize,
378    /// Number of blocks skipped.
379    pub skip_blocks: usize,
380    /// Total bits used.
381    pub total_bits: usize,
382}
383
384impl CoeffStats {
385    /// Create new statistics.
386    #[must_use]
387    pub const fn new() -> Self {
388        Self {
389            total_coeffs: 0,
390            zero_coeffs: 0,
391            skip_blocks: 0,
392            total_bits: 0,
393        }
394    }
395
396    /// Update statistics from coefficient block.
397    pub fn update(&mut self, coeffs: &[i32], bits_used: usize) {
398        self.total_coeffs += coeffs.len();
399        self.zero_coeffs += coeffs.iter().filter(|&&c| c == 0).count();
400        self.total_bits += bits_used;
401
402        if coeffs.iter().all(|&c| c == 0) {
403            self.skip_blocks += 1;
404        }
405    }
406
407    /// Get average bits per coefficient.
408    #[must_use]
409    pub fn avg_bits_per_coeff(&self) -> f32 {
410        if self.total_coeffs == 0 {
411            0.0
412        } else {
413            self.total_bits as f32 / self.total_coeffs as f32
414        }
415    }
416
417    /// Get zero coefficient ratio.
418    #[must_use]
419    pub fn zero_ratio(&self) -> f32 {
420        if self.total_coeffs == 0 {
421            0.0
422        } else {
423            self.zero_coeffs as f32 / self.total_coeffs as f32
424        }
425    }
426}
427
428// =============================================================================
429// Context Modeling
430// =============================================================================
431
432/// Coefficient level context.
433#[derive(Clone, Copy, Debug)]
434pub struct CoeffContext {
435    /// Number of non-zero neighbors.
436    pub nz_neighbors: u8,
437    /// Position in block (DC or AC).
438    pub is_dc: bool,
439    /// Previous coefficient level.
440    pub prev_level: u8,
441}
442
443impl Default for CoeffContext {
444    fn default() -> Self {
445        Self {
446            nz_neighbors: 0,
447            is_dc: false,
448            prev_level: 0,
449        }
450    }
451}
452
453impl CoeffContext {
454    /// Get context index for level encoding.
455    #[must_use]
456    pub const fn level_ctx(&self) -> usize {
457        let base = if self.is_dc { 0 } else { 7 };
458        let nz = self.nz_neighbors as usize;
459        let nz_clamped = if nz > 3 { 3 } else { nz };
460        let prev = self.prev_level as usize;
461        let prev_clamped = if prev > 1 { 1 } else { prev };
462        let offset = nz_clamped * 2 + prev_clamped;
463        let result = base + offset;
464        if result > LEVEL_CONTEXTS - 1 {
465            LEVEL_CONTEXTS - 1
466        } else {
467            result
468        }
469    }
470
471    /// Get context index for EOB encoding.
472    #[must_use]
473    pub const fn eob_ctx(&self) -> usize {
474        self.nz_neighbors as usize % EOB_CONTEXTS
475    }
476}
477
478// =============================================================================
479// Tests
480// =============================================================================
481
482#[cfg(test)]
483mod tests {
484    use super::*;
485
486    #[test]
487    fn test_scan_order_from_tx_type() {
488        assert_eq!(ScanOrder::from_tx_type(TxType::DctDct), ScanOrder::Default);
489        assert_eq!(
490            ScanOrder::from_tx_type(TxType::DctIdtx),
491            ScanOrder::Horizontal
492        );
493        assert_eq!(
494            ScanOrder::from_tx_type(TxType::IdtxDct),
495            ScanOrder::Vertical
496        );
497    }
498
499    #[test]
500    fn test_generate_scan_order_4x4() {
501        let scan = generate_scan_order(TxSize::Tx4x4, ScanOrder::Default);
502        assert_eq!(scan.len(), 16);
503        assert_eq!(scan[0], (0, 0)); // DC coefficient first
504    }
505
506    #[test]
507    fn test_generate_scan_order_horizontal() {
508        let scan = generate_scan_order(TxSize::Tx4x4, ScanOrder::Horizontal);
509        assert_eq!(scan.len(), 16);
510        // Row-major order
511        assert_eq!(scan[0], (0, 0));
512        assert_eq!(scan[1], (0, 1));
513        assert_eq!(scan[4], (1, 0));
514    }
515
516    #[test]
517    fn test_generate_scan_order_vertical() {
518        let scan = generate_scan_order(TxSize::Tx4x4, ScanOrder::Vertical);
519        assert_eq!(scan.len(), 16);
520        // Column-major order
521        assert_eq!(scan[0], (0, 0));
522        assert_eq!(scan[1], (1, 0));
523        assert_eq!(scan[4], (0, 1));
524    }
525
526    #[test]
527    fn test_coeff_encoder_creation() {
528        let encoder = CoeffEncoder::new();
529        assert!(!encoder.encoder.buffer().is_empty() || encoder.encoder.buffer().is_empty());
530    }
531
532    #[test]
533    fn test_find_eob() {
534        let encoder = CoeffEncoder::new();
535
536        let coeffs = vec![1, 2, 0, 3, 0, 0, 0];
537        let eob = encoder.find_eob(&coeffs);
538        assert_eq!(eob, 4);
539
540        let all_zero = vec![0; 16];
541        let eob_zero = encoder.find_eob(&all_zero);
542        assert_eq!(eob_zero, 0);
543    }
544
545    #[test]
546    fn test_encode_all_zero_block() {
547        let mut encoder = CoeffEncoder::new();
548        let coeffs = vec![0; 16];
549
550        let bits = encoder.encode_coeffs(&coeffs, TxSize::Tx4x4, TxType::DctDct, 0);
551        assert!(bits > 0);
552    }
553
554    #[test]
555    fn test_encode_non_zero_block() {
556        let mut encoder = CoeffEncoder::new();
557        let coeffs = vec![10, 5, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
558
559        let bits = encoder.encode_coeffs(&coeffs, TxSize::Tx4x4, TxType::DctDct, 0);
560        assert!(bits > 0);
561    }
562
563    #[test]
564    fn test_quantize_coeff() {
565        let q_step = 4;
566
567        let q1 = quantize_coeff(10, q_step);
568        assert_eq!(q1, 3); // (10 + 2) / 4 = 3
569
570        let q2 = quantize_coeff(-10, q_step);
571        assert_eq!(q2, -3);
572
573        let q3 = quantize_coeff(0, q_step);
574        assert_eq!(q3, 0);
575    }
576
577    #[test]
578    fn test_dequantize_coeff() {
579        let q_step = 4;
580
581        let dq1 = dequantize_coeff(3, q_step);
582        assert_eq!(dq1, 12);
583
584        let dq2 = dequantize_coeff(-3, q_step);
585        assert_eq!(dq2, -12);
586    }
587
588    #[test]
589    fn test_compute_q_step() {
590        let q_step_0 = compute_q_step(0);
591        assert!(q_step_0 > 0);
592
593        let q_step_30 = compute_q_step(30);
594        assert!(q_step_30 > q_step_0);
595    }
596
597    #[test]
598    fn test_quantize_coeffs_array() {
599        let coeffs = vec![10, 20, 30, 40];
600        let quantized = quantize_coeffs(&coeffs, 6, TxSize::Tx4x4);
601
602        assert_eq!(quantized.len(), 4);
603        assert!(quantized[0] <= coeffs[0]);
604        assert!(quantized[3] <= coeffs[3]);
605    }
606
607    #[test]
608    fn test_dequantize_coeffs_array() {
609        let coeffs = vec![2, 4, 6, 8];
610        let dequantized = dequantize_coeffs(&coeffs, 6, TxSize::Tx4x4);
611
612        assert_eq!(dequantized.len(), 4);
613        assert!(dequantized[0] >= coeffs[0]);
614    }
615
616    #[test]
617    fn test_coeff_stats() {
618        let mut stats = CoeffStats::new();
619        assert_eq!(stats.total_coeffs, 0);
620
621        let coeffs = vec![1, 0, 2, 0, 0, 3];
622        stats.update(&coeffs, 100);
623
624        assert_eq!(stats.total_coeffs, 6);
625        assert_eq!(stats.zero_coeffs, 3);
626        assert_eq!(stats.zero_ratio(), 0.5);
627    }
628
629    #[test]
630    fn test_coeff_stats_skip() {
631        let mut stats = CoeffStats::new();
632        let all_zero = vec![0; 16];
633        stats.update(&all_zero, 8);
634
635        assert_eq!(stats.skip_blocks, 1);
636    }
637
638    #[test]
639    fn test_coeff_context_dc() {
640        let ctx = CoeffContext {
641            nz_neighbors: 2,
642            is_dc: true,
643            prev_level: 1,
644        };
645
646        let level_ctx = ctx.level_ctx();
647        assert!(level_ctx < LEVEL_CONTEXTS);
648
649        let eob_ctx = ctx.eob_ctx();
650        assert!(eob_ctx < EOB_CONTEXTS);
651    }
652
653    #[test]
654    fn test_coeff_context_ac() {
655        let ctx = CoeffContext {
656            nz_neighbors: 1,
657            is_dc: false,
658            prev_level: 0,
659        };
660
661        let level_ctx = ctx.level_ctx();
662        assert!(level_ctx < LEVEL_CONTEXTS);
663        assert!(level_ctx >= 7); // AC contexts start at 7
664    }
665
666    #[test]
667    fn test_scan_order_coverage() {
668        // Ensure all positions are covered
669        let scan_default = generate_scan_order(TxSize::Tx4x4, ScanOrder::Default);
670        let scan_horiz = generate_scan_order(TxSize::Tx4x4, ScanOrder::Horizontal);
671        let scan_vert = generate_scan_order(TxSize::Tx4x4, ScanOrder::Vertical);
672
673        assert_eq!(scan_default.len(), 16);
674        assert_eq!(scan_horiz.len(), 16);
675        assert_eq!(scan_vert.len(), 16);
676
677        // Check uniqueness
678        let mut positions = std::collections::HashSet::new();
679        for pos in &scan_default {
680            positions.insert(*pos);
681        }
682        assert_eq!(positions.len(), 16);
683    }
684}