Skip to main content

oximedia_codec/simd/
dct.rs

1//! Discrete Cosine Transform (DCT) operations.
2//!
3//! DCT is fundamental to video compression, converting spatial domain data
4//! to frequency domain. This module provides:
5//!
6//! - Forward DCT for encoding
7//! - Inverse DCT for decoding
8//! - Support for 4x4, 8x8, 16x16, and 32x32 block sizes
9//!
10//! The implementations use integer arithmetic for deterministic results
11//! across platforms.
12//!
13//! # DCT Types
14//!
15//! Video codecs typically use DCT-II for forward transform and DCT-III
16//! (the inverse of DCT-II) for inverse transform. Modern codecs like
17//! AV1 also use asymmetric DST for certain blocks.
18
19#![forbid(unsafe_code)]
20// Allow loop indexing for matrix operations
21#![allow(clippy::needless_range_loop)]
22// Allow truncation casts for DCT coefficient handling
23#![allow(clippy::cast_possible_truncation)]
24
25use super::scalar::ScalarFallback;
26use super::traits::{SimdOps, SimdOpsExt};
27use super::types::{I16x8, I32x4};
28
/// Integer DCT/IDCT operations parameterised over a SIMD backend `S`.
///
/// The transform methods compute with scalar integer arithmetic; the backend is
/// exercised by the butterfly and multiply-add helpers. Trait bounds are placed
/// on the `impl` blocks that need them rather than on the struct itself, so the
/// type can be named (and derived traits used) without extra bounds.
#[derive(Debug, Clone, Copy)]
pub struct DctOps<S> {
    simd: S,
}
33
34impl<S: SimdOps + Default> Default for DctOps<S> {
35    fn default() -> Self {
36        Self::new(S::default())
37    }
38}
39
40impl<S: SimdOps> DctOps<S> {
41    /// Create a new DCT operations instance.
42    #[inline]
43    #[must_use]
44    pub const fn new(simd: S) -> Self {
45        Self { simd }
46    }
47
48    /// Get the underlying SIMD implementation.
49    #[inline]
50    #[must_use]
51    pub const fn simd(&self) -> &S {
52        &self.simd
53    }
54}
55
/// Integer 4-point DCT-II matrix used by the 4x4 forward/inverse transforms.
///
/// The DC row is `64`; every other row approximates
/// `64 · √2 · cos(π · k · (2n + 1) / 8)`. This is the orthonormal DCT-II scaled
/// by `64 · √4`, so every row has a squared norm of ≈ `64² · 4 = 16384` and the
/// rows are mutually orthogonal. These are the H.265/HEVC 4-point core-transform
/// values (83/36 rather than the plainly rounded 84/35).
///
/// ```text
/// [ a  a  a  a ]   a = 64             (= 64·√2·cos(π/4) for row 2)
/// [ b  c -c -b ]   b = 83 ≈ 64·√2·cos(π/8)
/// [ a -a -a  a ]   c = 36 ≈ 64·√2·cos(3π/8)
/// [ c -b  b -c ]
/// ```
#[allow(dead_code)]
pub const DCT4_COEFFS: [[i16; 4]; 4] = [
    [64, 64, 64, 64],   // row 0: DC
    [83, 36, -36, -83], // row 1: b, c, -c, -b
    [64, -64, -64, 64], // row 2: a, -a, -a, a
    [36, -83, 83, -36], // row 3: c, -b, b, -c
];
72
/// Integer 8-point DCT-II matrix used by the 8x8 forward/inverse transforms.
///
/// The DC row is `64`; every other row approximates
/// `64 · √2 · cos(π · k · (2n + 1) / 16)`. Each row therefore has a squared norm
/// of ≈ `64² · 8 = 32768`, which is the scaling the 8x8 inverse's `>> 18`
/// normalisation relies on. These are the H.265/HEVC 8-point core-transform
/// values. The first four entries of the even rows equal the 4-point matrix
/// `DCT4_COEFFS`.
#[allow(dead_code)]
pub const DCT8_COEFFS: [[i16; 8]; 8] = [
    [64, 64, 64, 64, 64, 64, 64, 64],
    [89, 75, 50, 18, -18, -50, -75, -89],
    [83, 36, -36, -83, -83, -36, 36, 83],
    [75, -18, -89, -50, 50, 89, 18, -75],
    [64, -64, -64, 64, 64, -64, -64, 64],
    [50, -89, 18, 75, -75, -18, 89, -50],
    [36, -83, 83, -36, -36, 83, -83, 36],
    [18, -50, 75, -89, 89, -75, 50, -18],
];
85
86impl<S: SimdOps + SimdOpsExt> DctOps<S> {
87    /// Forward 4x4 DCT.
88    ///
89    /// Transforms a 4x4 block of residuals to frequency coefficients.
90    ///
91    /// # Arguments
92    /// * `input` - 4x4 input block (row-major)
93    /// * `output` - 4x4 output coefficients (row-major)
94    #[allow(dead_code)]
95    pub fn forward_dct_4x4(&self, input: &[i16; 16], output: &mut [i16; 16]) {
96        // Load input rows into vectors
97        let rows = [
98            I16x8::from_array([input[0], input[1], input[2], input[3], 0, 0, 0, 0]),
99            I16x8::from_array([input[4], input[5], input[6], input[7], 0, 0, 0, 0]),
100            I16x8::from_array([input[8], input[9], input[10], input[11], 0, 0, 0, 0]),
101            I16x8::from_array([input[12], input[13], input[14], input[15], 0, 0, 0, 0]),
102        ];
103
104        // First pass: transform rows
105        let mut temp = [[0i16; 4]; 4];
106        for i in 0..4 {
107            for j in 0..4 {
108                let mut sum = 0i32;
109                for k in 0..4 {
110                    sum += i32::from(rows[i].0[k]) * i32::from(DCT4_COEFFS[j][k]);
111                }
112                // Round and scale
113                temp[i][j] = ((sum + 32) >> 6) as i16;
114            }
115        }
116
117        // Second pass: transform columns (transpose and transform)
118        for j in 0..4 {
119            for i in 0..4 {
120                let mut sum = 0i32;
121                for k in 0..4 {
122                    sum += i32::from(temp[k][j]) * i32::from(DCT4_COEFFS[i][k]);
123                }
124                // Round and scale
125                output[i * 4 + j] = ((sum + 32) >> 6) as i16;
126            }
127        }
128    }
129
130    /// Inverse 4x4 DCT.
131    ///
132    /// Transforms 4x4 frequency coefficients back to spatial domain.
133    ///
134    /// # Arguments
135    /// * `input` - 4x4 input coefficients (row-major)
136    /// * `output` - 4x4 output block (row-major)
137    #[allow(dead_code)]
138    pub fn inverse_dct_4x4(&self, input: &[i16; 16], output: &mut [i16; 16]) {
139        // First pass: transform columns
140        let mut temp = [[0i64; 4]; 4];
141        for j in 0..4 {
142            for i in 0..4 {
143                let mut sum = 0i64;
144                for k in 0..4 {
145                    sum += i64::from(input[k * 4 + j]) * i64::from(DCT4_COEFFS[k][i]);
146                }
147                temp[i][j] = sum;
148            }
149        }
150
151        // Second pass: transform rows
152        // Total normalization: 64*64*N*N = 64*64*16 = 65536 = 2^16
153        for i in 0..4 {
154            for j in 0..4 {
155                let mut sum = 0i64;
156                for k in 0..4 {
157                    sum += temp[i][k] * i64::from(DCT4_COEFFS[k][j]);
158                }
159                // Round and scale (divide by 65536 = 64*64*4*4)
160                output[i * 4 + j] = ((sum + 32768) >> 16) as i16;
161            }
162        }
163    }
164
165    /// Forward 8x8 DCT.
166    #[allow(dead_code)]
167    pub fn forward_dct_8x8(&self, input: &[i16; 64], output: &mut [i16; 64]) {
168        // First pass: transform rows
169        let mut temp = [[0i32; 8]; 8];
170        for i in 0..8 {
171            for j in 0..8 {
172                let mut sum = 0i32;
173                for k in 0..8 {
174                    sum += i32::from(input[i * 8 + k]) * i32::from(DCT8_COEFFS[j][k]);
175                }
176                temp[i][j] = (sum + 32) >> 6;
177            }
178        }
179
180        // Second pass: transform columns
181        for j in 0..8 {
182            for i in 0..8 {
183                let mut sum = 0i32;
184                for k in 0..8 {
185                    sum += temp[k][j] * i32::from(DCT8_COEFFS[i][k]);
186                }
187                output[i * 8 + j] = ((sum + 32) >> 6) as i16;
188            }
189        }
190    }
191
192    /// Inverse 8x8 DCT.
193    #[allow(dead_code)]
194    pub fn inverse_dct_8x8(&self, input: &[i16; 64], output: &mut [i16; 64]) {
195        // First pass: transform columns
196        let mut temp = [[0i64; 8]; 8];
197        for j in 0..8 {
198            for i in 0..8 {
199                let mut sum = 0i64;
200                for k in 0..8 {
201                    sum += i64::from(input[k * 8 + j]) * i64::from(DCT8_COEFFS[k][i]);
202                }
203                temp[i][j] = sum;
204            }
205        }
206
207        // Second pass: transform rows
208        // Total normalization: 64*64*N*N = 64*64*64 = 262144 = 2^18
209        for i in 0..8 {
210            for j in 0..8 {
211                let mut sum = 0i64;
212                for k in 0..8 {
213                    sum += temp[i][k] * i64::from(DCT8_COEFFS[k][j]);
214                }
215                // Round and scale (divide by 262144 = 64*64*8*8)
216                output[i * 8 + j] = ((sum + 131_072) >> 18) as i16;
217            }
218        }
219    }
220
221    /// Forward 16x16 DCT using recursive decomposition.
222    ///
223    /// Decomposes into 4 8x8 DCTs for efficiency.
224    #[allow(dead_code)]
225    pub fn forward_dct_16x16(&self, input: &[i16; 256], output: &mut [i16; 256]) {
226        // For now, use direct computation
227        // A real implementation would use recursive decomposition
228        self.forward_dct_nxn::<16>(input, output);
229    }
230
231    /// Inverse 16x16 DCT.
232    #[allow(dead_code)]
233    pub fn inverse_dct_16x16(&self, input: &[i16; 256], output: &mut [i16; 256]) {
234        self.inverse_dct_nxn::<16>(input, output);
235    }
236
237    /// Forward 32x32 DCT.
238    #[allow(dead_code)]
239    pub fn forward_dct_32x32(&self, input: &[i16; 1024], output: &mut [i16; 1024]) {
240        self.forward_dct_nxn::<32>(input, output);
241    }
242
243    /// Inverse 32x32 DCT.
244    #[allow(dead_code)]
245    pub fn inverse_dct_32x32(&self, input: &[i16; 1024], output: &mut [i16; 1024]) {
246        self.inverse_dct_nxn::<32>(input, output);
247    }
248
249    /// Generic forward DCT for `NxN` block.
250    #[allow(dead_code, clippy::unused_self)]
251    fn forward_dct_nxn<const N: usize>(&self, input: &[i16], output: &mut [i16]) {
252        let coeffs = generate_dct_coeffs::<N>();
253
254        // First pass: rows
255        let mut temp = vec![0i32; N * N];
256        for i in 0..N {
257            for j in 0..N {
258                let mut sum = 0i32;
259                for k in 0..N {
260                    sum += i32::from(input[i * N + k]) * coeffs[j][k];
261                }
262                temp[i * N + j] = (sum + 32) >> 6;
263            }
264        }
265
266        // Second pass: columns
267        for j in 0..N {
268            for i in 0..N {
269                let mut sum = 0i32;
270                for k in 0..N {
271                    sum += temp[k * N + j] * coeffs[i][k];
272                }
273                output[i * N + j] = ((sum + 32) >> 6) as i16;
274            }
275        }
276    }
277
278    /// Generic inverse DCT for `NxN` block.
279    #[allow(dead_code, clippy::unused_self)]
280    fn inverse_dct_nxn<const N: usize>(&self, input: &[i16], output: &mut [i16]) {
281        let coeffs = generate_dct_coeffs::<N>();
282
283        // Calculate shift: 12 + 2*log2(N)
284        // N=4: shift=16, N=8: shift=18, N=16: shift=20, N=32: shift=22
285        let n_shift = (N as u32).trailing_zeros();
286        let total_shift = 12 + 2 * n_shift;
287        let round = 1i64 << (total_shift - 1);
288
289        // First pass: columns
290        let mut temp = vec![0i64; N * N];
291        for j in 0..N {
292            for i in 0..N {
293                let mut sum = 0i64;
294                for k in 0..N {
295                    sum += i64::from(input[k * N + j]) * i64::from(coeffs[k][i]);
296                }
297                temp[i * N + j] = sum;
298            }
299        }
300
301        // Second pass: rows
302        for i in 0..N {
303            for j in 0..N {
304                let mut sum = 0i64;
305                for k in 0..N {
306                    sum += temp[i * N + k] * i64::from(coeffs[k][j]);
307                }
308                output[i * N + j] = ((sum + round) >> total_shift) as i16;
309            }
310        }
311    }
312
313    /// Butterfly operation for DCT.
314    #[inline]
315    #[allow(dead_code)]
316    pub fn butterfly_add(&self, a: I16x8, b: I16x8) -> I16x8 {
317        self.simd.add_i16x8(a, b)
318    }
319
320    /// Butterfly operation for DCT (subtraction).
321    #[inline]
322    #[allow(dead_code)]
323    pub fn butterfly_sub(&self, a: I16x8, b: I16x8) -> I16x8 {
324        self.simd.sub_i16x8(a, b)
325    }
326
327    /// Multiply-add for DCT coefficients.
328    #[inline]
329    #[allow(dead_code)]
330    pub fn dct_madd(&self, a: I16x8, coeff: I16x8) -> I32x4 {
331        self.simd.pmaddwd(a, coeff)
332    }
333}
334
/// Generate integer DCT-II coefficients for an `N`-point transform.
///
/// Entry `C[k][n] = round(s_k · cos(π · k · (2n + 1) / (2N)))`, with `s_0 = 64`
/// and `s_k = 64 · √2` for `k > 0`. This is the orthonormal DCT-II scaled by
/// `64 · √N`, so every row has a squared norm of ≈ `64² · N`.
///
/// That is the scaling assumed by the generic inverse transform's
/// `12 + 2·log2(N)` normalisation shift. It is also the convention of the fixed
/// `DCT4_COEFFS` / `DCT8_COEFFS` tables; those hand-tune a few entries, e.g.
/// 83/36 instead of the plainly rounded 84/35.
///
/// Without the `√2` on the AC rows, a forward + inverse round trip would halve
/// every AC component along each axis.
#[allow(clippy::cast_precision_loss)]
fn generate_dct_coeffs<const N: usize>() -> Vec<Vec<i32>> {
    use std::f64::consts::{PI, SQRT_2};

    let n_f64 = N as f64;
    (0..N)
        .map(|k| {
            // DC keeps amplitude 64; AC rows get the extra √2 of the orthonormal basis.
            let gain = if k == 0 { 64.0 } else { 64.0 * SQRT_2 };
            (0..N)
                .map(|n| {
                    let angle = PI * (k as f64) * (2.0 * (n as f64) + 1.0) / (2.0 * n_f64);
                    (gain * angle.cos()).round() as i32
                })
                .collect()
        })
        .collect()
}
353
354/// Create a DCT operations instance with scalar fallback.
355#[inline]
356#[must_use]
357pub fn dct_ops() -> DctOps<ScalarFallback> {
358    DctOps::new(ScalarFallback::new())
359}
360
/// Quantize DCT coefficients.
///
/// Simplified power-of-two quantizer (real codecs use per-QP tables): each
/// coefficient is divided by `2^(qp / 6)` with round-half-up on its magnitude,
/// and the sign is reapplied afterwards.
///
/// # Arguments
/// * `coeffs` - DCT coefficients
/// * `qp` - Quantization parameter (0-51 for H.264/AV1). Larger values are
///   accepted; from `qp = 90` on, the divisor saturates at `2^15` instead of
///   underflowing the shift amount.
/// * `output` - Quantized coefficients
#[allow(dead_code)]
pub fn quantize_4x4(coeffs: &[i16; 16], qp: u8, output: &mut [i16; 16]) {
    // Q15 multiplier 2^(15 - qp/6). `saturating_sub` avoids the u8 underflow
    // (and resulting panic) that `15 - qp / 6` hits for qp >= 96.
    let scale: i32 = 1 << 15u8.saturating_sub(qp / 6);

    for (out, &c) in output.iter_mut().zip(coeffs.iter()) {
        let val = i32::from(c);
        // |val| <= 32768 and scale <= 2^15, so the product fits in i32.
        let magnitude = (val.abs() * scale + (1 << 14)) >> 15;
        *out = (val.signum() * magnitude) as i16;
    }
}
378
/// Dequantize DCT coefficients.
///
/// Inverse of the step applied by `quantize_4x4`: each level is multiplied by
/// `2^(qp / 6)`. Results that do not fit in `i16` are saturated to
/// `i16::MIN..=i16::MAX` rather than wrapping (e.g. level 200 at qp 51 gives
/// 51200, which previously wrapped to a negative value).
///
/// # Arguments
/// * `coeffs` - Quantized levels
/// * `qp` - Quantization parameter (any `u8` is accepted)
/// * `output` - Reconstructed coefficients
#[allow(dead_code)]
pub fn dequantize_4x4(coeffs: &[i16; 16], qp: u8, output: &mut [i16; 16]) {
    // qp / 6 <= 42 for every u8, so the step and all products fit in i64
    // (the former i32 `1 << (qp / 6)` overflowed for qp >= 192).
    let scale = 1i64 << (qp / 6);

    for (out, &c) in output.iter_mut().zip(coeffs.iter()) {
        let value = i64::from(c) * scale;
        *out = value.clamp(i64::from(i16::MIN), i64::from(i16::MAX)) as i16;
    }
}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every reconstructed sample is within 2 of the original.
    fn assert_round_trip(original: &[i16], rebuilt: &[i16]) {
        for (idx, (&want, &got)) in original.iter().zip(rebuilt.iter()).enumerate() {
            let diff = (i32::from(want) - i32::from(got)).abs();
            assert!(diff <= 2, "Mismatch at {}: {} vs {}", idx, want, got);
        }
    }

    #[test]
    fn test_dct4_coeffs_orthogonality() {
        // Distinct rows of the 4-point matrix should have a near-zero dot product.
        for (i, upper) in DCT4_COEFFS.iter().enumerate() {
            for (j, lower) in DCT4_COEFFS.iter().enumerate().skip(i + 1) {
                let dot: i32 = upper
                    .iter()
                    .zip(lower.iter())
                    .map(|(&a, &b)| i32::from(a) * i32::from(b))
                    .sum();
                assert!(
                    dot.abs() < 100,
                    "Rows {} and {} not orthogonal: {}",
                    i,
                    j,
                    dot
                );
            }
        }
    }

    #[test]
    fn test_forward_inverse_4x4_identity() {
        let ops = dct_ops();

        // Smooth ramp: +2 per column, +10 per row.
        let input: [i16; 16] = [
            100, 102, 104, 106, 110, 112, 114, 116, 120, 122, 124, 126, 130, 132, 134, 136,
        ];
        let mut coeffs = [0i16; 16];
        let mut rebuilt = [0i16; 16];

        ops.forward_dct_4x4(&input, &mut coeffs);
        ops.inverse_dct_4x4(&coeffs, &mut rebuilt);

        assert_round_trip(&input, &rebuilt);
    }

    #[test]
    fn test_forward_inverse_8x8_identity() {
        let ops = dct_ops();

        // A flat block should put all of its energy into the DC coefficient.
        let input = [128i16; 64];
        let mut coeffs = [0i16; 64];
        let mut rebuilt = [0i16; 64];

        ops.forward_dct_8x8(&input, &mut coeffs);

        assert!(coeffs[0].abs() > 100);
        for (i, &c) in coeffs.iter().enumerate().skip(1) {
            assert!(c.abs() < 10, "Non-DC coeff {} too large: {}", i, c);
        }

        ops.inverse_dct_8x8(&coeffs, &mut rebuilt);
        assert_round_trip(&input, &rebuilt);
    }

    #[test]
    fn test_dct_zero_input() {
        let ops = dct_ops();

        let input = [0i16; 16];
        // Pre-filled with non-zero values so stale output would be detected.
        let mut output = [1i16; 16];

        ops.forward_dct_4x4(&input, &mut output);

        for (i, v) in output.iter().copied().enumerate() {
            assert_eq!(v, 0, "Non-zero output at {}: {}", i, v);
        }
    }

    #[test]
    fn test_dct_dc_only() {
        let ops = dct_ops();

        // A constant block must yield a DC term and (almost) nothing else.
        let input = [64i16; 16];
        let mut output = [0i16; 16];

        ops.forward_dct_4x4(&input, &mut output);

        assert!(output[0] != 0);
        for (i, v) in output.iter().copied().enumerate().skip(1) {
            assert!(v.abs() < 5, "AC coeff {} too large: {}", i, v);
        }
    }

    #[test]
    fn test_quantize_dequantize() {
        let coeffs = [100i16, -50, 25, -12, 6, -3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let mut quantized = [0i16; 16];
        let mut dequantized = [0i16; 16];

        // Moderate QP: the large DC term must survive.
        quantize_4x4(&coeffs, 20, &mut quantized);
        assert!(quantized[0] != 0);

        dequantize_4x4(&quantized, 20, &mut dequantized);

        // The reconstructed DC should land within half of the original value.
        let dc_diff = (i32::from(coeffs[0]) - i32::from(dequantized[0])).abs();
        assert!(
            dc_diff < i32::from(coeffs[0]) / 2,
            "DC diff too large: {}",
            dc_diff
        );
    }

    #[test]
    fn test_generate_dct_coeffs() {
        let coeffs = generate_dct_coeffs::<4>();

        assert_eq!(coeffs.len(), 4);
        assert_eq!(coeffs[0].len(), 4);

        // DC row: cos(0) = 1, so every entry is positive.
        assert!(coeffs[0].iter().all(|&c| c > 0));
    }

    #[test]
    fn test_dct8_coeffs() {
        // Row 0 is the flat DC basis.
        assert_eq!(DCT8_COEFFS[0], [64, 64, 64, 64, 64, 64, 64, 64]);

        // Row 4 follows the +, -, -, +, +, -, -, + pattern.
        assert_eq!(DCT8_COEFFS[4], [64, -64, -64, 64, 64, -64, -64, 64]);
    }
}