1#![forbid(unsafe_code)]
20#![allow(clippy::needless_range_loop)]
22#![allow(clippy::cast_possible_truncation)]
24
25use super::scalar::ScalarFallback;
26use super::traits::{SimdOps, SimdOpsExt};
27use super::types::{I16x8, I32x4};
28
29pub struct DctOps<S: SimdOps> {
31 simd: S,
32}
33
34impl<S: SimdOps + Default> Default for DctOps<S> {
35 fn default() -> Self {
36 Self::new(S::default())
37 }
38}
39
40impl<S: SimdOps> DctOps<S> {
41 #[inline]
43 #[must_use]
44 pub const fn new(simd: S) -> Self {
45 Self { simd }
46 }
47
48 #[inline]
50 #[must_use]
51 pub const fn simd(&self) -> &S {
52 &self.simd
53 }
54}
55
/// Integer basis for the 4-point DCT, one basis function per row.
///
/// Row 0 is the DC row (all 64); every row has the same squared norm up to rounding, so the
/// transposed matrix serves as the inverse once the common gain is shifted out.
#[allow(dead_code)]
pub const DCT4_COEFFS: [[i16; 4]; 4] = [
    [64, 64, 64, 64],
    [83, 36, -36, -83],
    [64, -64, -64, 64],
    [36, -83, 83, -36],
];
72
/// Integer basis for the 8-point DCT, one basis function per row.
///
/// The even rows repeat the 4-point basis pattern; the odd rows add the four extra
/// magnitudes 89, 75, 50 and 18.
#[allow(dead_code)]
pub const DCT8_COEFFS: [[i16; 8]; 8] = [
    // Row 0: DC.
    [64, 64, 64, 64, 64, 64, 64, 64],
    // Rows 1-3.
    [89, 75, 50, 18, -18, -50, -75, -89],
    [83, 36, -36, -83, -83, -36, 36, 83],
    [75, -18, -89, -50, 50, 89, 18, -75],
    // Row 4: alternating pairs.
    [64, -64, -64, 64, 64, -64, -64, 64],
    // Rows 5-7.
    [50, -89, 18, 75, -75, -18, 89, -50],
    [36, -83, 83, -36, -36, 83, -83, 36],
    [18, -50, 75, -89, 89, -75, 50, -18],
];
85
86impl<S: SimdOps + SimdOpsExt> DctOps<S> {
87 #[allow(dead_code)]
95 pub fn forward_dct_4x4(&self, input: &[i16; 16], output: &mut [i16; 16]) {
96 let rows = [
98 I16x8::from_array([input[0], input[1], input[2], input[3], 0, 0, 0, 0]),
99 I16x8::from_array([input[4], input[5], input[6], input[7], 0, 0, 0, 0]),
100 I16x8::from_array([input[8], input[9], input[10], input[11], 0, 0, 0, 0]),
101 I16x8::from_array([input[12], input[13], input[14], input[15], 0, 0, 0, 0]),
102 ];
103
104 let mut temp = [[0i16; 4]; 4];
106 for i in 0..4 {
107 for j in 0..4 {
108 let mut sum = 0i32;
109 for k in 0..4 {
110 sum += i32::from(rows[i].0[k]) * i32::from(DCT4_COEFFS[j][k]);
111 }
112 temp[i][j] = ((sum + 32) >> 6) as i16;
114 }
115 }
116
117 for j in 0..4 {
119 for i in 0..4 {
120 let mut sum = 0i32;
121 for k in 0..4 {
122 sum += i32::from(temp[k][j]) * i32::from(DCT4_COEFFS[i][k]);
123 }
124 output[i * 4 + j] = ((sum + 32) >> 6) as i16;
126 }
127 }
128 }
129
130 #[allow(dead_code)]
138 pub fn inverse_dct_4x4(&self, input: &[i16; 16], output: &mut [i16; 16]) {
139 let mut temp = [[0i64; 4]; 4];
141 for j in 0..4 {
142 for i in 0..4 {
143 let mut sum = 0i64;
144 for k in 0..4 {
145 sum += i64::from(input[k * 4 + j]) * i64::from(DCT4_COEFFS[k][i]);
146 }
147 temp[i][j] = sum;
148 }
149 }
150
151 for i in 0..4 {
154 for j in 0..4 {
155 let mut sum = 0i64;
156 for k in 0..4 {
157 sum += temp[i][k] * i64::from(DCT4_COEFFS[k][j]);
158 }
159 output[i * 4 + j] = ((sum + 32768) >> 16) as i16;
161 }
162 }
163 }
164
165 #[allow(dead_code)]
167 pub fn forward_dct_8x8(&self, input: &[i16; 64], output: &mut [i16; 64]) {
168 let mut temp = [[0i32; 8]; 8];
170 for i in 0..8 {
171 for j in 0..8 {
172 let mut sum = 0i32;
173 for k in 0..8 {
174 sum += i32::from(input[i * 8 + k]) * i32::from(DCT8_COEFFS[j][k]);
175 }
176 temp[i][j] = (sum + 32) >> 6;
177 }
178 }
179
180 for j in 0..8 {
182 for i in 0..8 {
183 let mut sum = 0i32;
184 for k in 0..8 {
185 sum += temp[k][j] * i32::from(DCT8_COEFFS[i][k]);
186 }
187 output[i * 8 + j] = ((sum + 32) >> 6) as i16;
188 }
189 }
190 }
191
192 #[allow(dead_code)]
194 pub fn inverse_dct_8x8(&self, input: &[i16; 64], output: &mut [i16; 64]) {
195 let mut temp = [[0i64; 8]; 8];
197 for j in 0..8 {
198 for i in 0..8 {
199 let mut sum = 0i64;
200 for k in 0..8 {
201 sum += i64::from(input[k * 8 + j]) * i64::from(DCT8_COEFFS[k][i]);
202 }
203 temp[i][j] = sum;
204 }
205 }
206
207 for i in 0..8 {
210 for j in 0..8 {
211 let mut sum = 0i64;
212 for k in 0..8 {
213 sum += temp[i][k] * i64::from(DCT8_COEFFS[k][j]);
214 }
215 output[i * 8 + j] = ((sum + 131_072) >> 18) as i16;
217 }
218 }
219 }
220
221 #[allow(dead_code)]
225 pub fn forward_dct_16x16(&self, input: &[i16; 256], output: &mut [i16; 256]) {
226 self.forward_dct_nxn::<16>(input, output);
229 }
230
231 #[allow(dead_code)]
233 pub fn inverse_dct_16x16(&self, input: &[i16; 256], output: &mut [i16; 256]) {
234 self.inverse_dct_nxn::<16>(input, output);
235 }
236
237 #[allow(dead_code)]
239 pub fn forward_dct_32x32(&self, input: &[i16; 1024], output: &mut [i16; 1024]) {
240 self.forward_dct_nxn::<32>(input, output);
241 }
242
243 #[allow(dead_code)]
245 pub fn inverse_dct_32x32(&self, input: &[i16; 1024], output: &mut [i16; 1024]) {
246 self.inverse_dct_nxn::<32>(input, output);
247 }
248
249 #[allow(dead_code, clippy::unused_self)]
251 fn forward_dct_nxn<const N: usize>(&self, input: &[i16], output: &mut [i16]) {
252 let coeffs = generate_dct_coeffs::<N>();
253
254 let mut temp = vec![0i32; N * N];
256 for i in 0..N {
257 for j in 0..N {
258 let mut sum = 0i32;
259 for k in 0..N {
260 sum += i32::from(input[i * N + k]) * coeffs[j][k];
261 }
262 temp[i * N + j] = (sum + 32) >> 6;
263 }
264 }
265
266 for j in 0..N {
268 for i in 0..N {
269 let mut sum = 0i32;
270 for k in 0..N {
271 sum += temp[k * N + j] * coeffs[i][k];
272 }
273 output[i * N + j] = ((sum + 32) >> 6) as i16;
274 }
275 }
276 }
277
278 #[allow(dead_code, clippy::unused_self)]
280 fn inverse_dct_nxn<const N: usize>(&self, input: &[i16], output: &mut [i16]) {
281 let coeffs = generate_dct_coeffs::<N>();
282
283 let n_shift = (N as u32).trailing_zeros();
286 let total_shift = 12 + 2 * n_shift;
287 let round = 1i64 << (total_shift - 1);
288
289 let mut temp = vec![0i64; N * N];
291 for j in 0..N {
292 for i in 0..N {
293 let mut sum = 0i64;
294 for k in 0..N {
295 sum += i64::from(input[k * N + j]) * i64::from(coeffs[k][i]);
296 }
297 temp[i * N + j] = sum;
298 }
299 }
300
301 for i in 0..N {
303 for j in 0..N {
304 let mut sum = 0i64;
305 for k in 0..N {
306 sum += temp[i * N + k] * i64::from(coeffs[k][j]);
307 }
308 output[i * N + j] = ((sum + round) >> total_shift) as i16;
309 }
310 }
311 }
312
313 #[inline]
315 #[allow(dead_code)]
316 pub fn butterfly_add(&self, a: I16x8, b: I16x8) -> I16x8 {
317 self.simd.add_i16x8(a, b)
318 }
319
320 #[inline]
322 #[allow(dead_code)]
323 pub fn butterfly_sub(&self, a: I16x8, b: I16x8) -> I16x8 {
324 self.simd.sub_i16x8(a, b)
325 }
326
327 #[inline]
329 #[allow(dead_code)]
330 pub fn dct_madd(&self, a: I16x8, coeff: I16x8) -> I32x4 {
331 self.simd.pmaddwd(a, coeff)
332 }
333}
334
/// Builds the `N`x`N` integer DCT-II basis, one basis function per row.
///
/// Entry `[k][n]` is `round(g_k * cos(pi * k * (2n + 1) / (2N)))`, with `g_0 = 64` and
/// `g_k = 64 * sqrt(2)` for `k > 0`. The extra `sqrt(2)` on the AC rows gives every row the
/// same squared norm (`N * 64^2`, up to rounding). The fixed 4- and 8-point tables in this
/// file have that property, and the inverse normalisation shift of `12 + 2 * log2(N)` relies
/// on it; without the factor only flat blocks survive a forward/inverse round trip.
///
/// The table is rebuilt on every call; callers on a hot path may want to cache it.
#[allow(clippy::cast_precision_loss)]
fn generate_dct_coeffs<const N: usize>() -> Vec<Vec<i32>> {
    let size = N as f64;

    (0..N)
        .map(|k| {
            // The DC row keeps unit gain; AC rows are scaled so all rows carry equal energy.
            let gain = if k == 0 {
                64.0
            } else {
                64.0 * std::f64::consts::SQRT_2
            };
            (0..N)
                .map(|n| {
                    let angle =
                        std::f64::consts::PI * (k as f64) * (2.0 * (n as f64) + 1.0) / (2.0 * size);
                    (angle.cos() * gain).round() as i32
                })
                .collect()
        })
        .collect()
}
353
354#[inline]
356#[must_use]
357pub fn dct_ops() -> DctOps<ScalarFallback> {
358 DctOps::new(ScalarFallback::new())
359}
360
/// Quantizes a 4x4 coefficient block with a power-of-two step derived from `qp`.
///
/// Each magnitude is multiplied by `2^(15 - qp / 6)`, rounded to nearest (ties up) and
/// shifted right by 15 bits, so the step size doubles every 6 QP units; the sign is restored
/// afterwards.
///
/// `qp / 6` above 15 (that is, `qp >= 96`) is clamped to the coarsest step instead of
/// underflowing the exponent, so any `u8` value is accepted without panicking.
#[allow(dead_code)]
pub fn quantize_4x4(coeffs: &[i16; 16], qp: u8, output: &mut [i16; 16]) {
    // Saturating subtraction keeps the exponent in 0..=15 for every possible `qp`.
    let exponent = 15u32.saturating_sub(u32::from(qp / 6));
    let scale: i32 = 1 << exponent;

    for (dst, &c) in output.iter_mut().zip(coeffs) {
        let val = i32::from(c);
        // |val| <= 2^15 and scale <= 2^15, so the product plus rounding term fits in i32.
        let magnitude = (val.abs() * scale + (1 << 14)) >> 15;
        // magnitude <= |val| after rounding, so the signed result always fits in i16.
        *dst = (val.signum() * magnitude) as i16;
    }
}
378
/// Rescales a quantized 4x4 block by the power-of-two step `2^(qp / 6)`.
///
/// The product is formed in 64-bit arithmetic, so any `u8` value of `qp` is accepted, and
/// results outside the `i16` range saturate at `i16::MIN` / `i16::MAX` instead of wrapping
/// around (which would flip the sign of large coefficients).
#[allow(dead_code)]
pub fn dequantize_4x4(coeffs: &[i16; 16], qp: u8, output: &mut [i16; 16]) {
    // qp / 6 <= 42 and |c| <= 2^15, so the product stays far below the i64 limits.
    let scale = 1i64 << u32::from(qp / 6);
    let lo = i64::from(i16::MIN);
    let hi = i64::from(i16::MAX);

    for (dst, &c) in output.iter_mut().zip(coeffs) {
        // The clamp makes the narrowing cast lossless.
        *dst = (i64::from(c) * scale).clamp(lo, hi) as i16;
    }
}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Distinct rows of the 4-point basis must have a (near-)zero dot product.
    #[test]
    fn test_dct4_coeffs_orthogonality() {
        for (i, row_a) in DCT4_COEFFS.iter().enumerate() {
            for (j, row_b) in DCT4_COEFFS.iter().enumerate().skip(i + 1) {
                let dot: i32 = row_a
                    .iter()
                    .zip(row_b)
                    .map(|(&a, &b)| i32::from(a) * i32::from(b))
                    .sum();
                assert!(
                    dot.abs() < 100,
                    "Rows {} and {} not orthogonal: {}",
                    i,
                    j,
                    dot
                );
            }
        }
    }

    /// Forward followed by inverse 4x4 must reproduce the block within +/-2.
    #[test]
    fn test_forward_inverse_4x4_identity() {
        let ops = dct_ops();
        let input: [i16; 16] = [
            100, 102, 104, 106, 110, 112, 114, 116, 120, 122, 124, 126, 130, 132, 134, 136,
        ];

        let mut dct_output = [0i16; 16];
        ops.forward_dct_4x4(&input, &mut dct_output);

        let mut reconstructed = [0i16; 16];
        ops.inverse_dct_4x4(&dct_output, &mut reconstructed);

        for (i, (&expected, &actual)) in input.iter().zip(&reconstructed).enumerate() {
            let diff = (i32::from(expected) - i32::from(actual)).abs();
            assert!(
                diff <= 2,
                "Mismatch at {}: {} vs {}",
                i,
                expected,
                actual
            );
        }
    }

    /// A flat 8x8 block must concentrate in the DC coefficient and round-trip within +/-2.
    #[test]
    fn test_forward_inverse_8x8_identity() {
        let ops = dct_ops();
        let input = [128i16; 64];

        let mut dct_output = [0i16; 64];
        ops.forward_dct_8x8(&input, &mut dct_output);

        assert!(dct_output[0].abs() > 100);
        for (i, &coeff) in dct_output.iter().enumerate().skip(1) {
            assert!(
                coeff.abs() < 10,
                "Non-DC coeff {} too large: {}",
                i,
                coeff
            );
        }

        let mut reconstructed = [0i16; 64];
        ops.inverse_dct_8x8(&dct_output, &mut reconstructed);

        for (i, (&expected, &actual)) in input.iter().zip(&reconstructed).enumerate() {
            let diff = (i32::from(expected) - i32::from(actual)).abs();
            assert!(
                diff <= 2,
                "Mismatch at {}: {} vs {}",
                i,
                expected,
                actual
            );
        }
    }

    /// An all-zero block must transform to all zeros, overwriting whatever was in `output`.
    #[test]
    fn test_dct_zero_input() {
        let ops = dct_ops();
        let input = [0i16; 16];

        // Start from non-zero garbage so a skipped write would be detected.
        let mut output = [1i16; 16];
        ops.forward_dct_4x4(&input, &mut output);

        for (i, &v) in output.iter().enumerate() {
            assert_eq!(v, 0, "Non-zero output at {}: {}", i, v);
        }
    }

    /// A flat 4x4 block must yield a non-zero DC term and negligible AC terms.
    #[test]
    fn test_dct_dc_only() {
        let ops = dct_ops();
        let input = [64i16; 16];

        let mut output = [0i16; 16];
        ops.forward_dct_4x4(&input, &mut output);

        assert!(output[0] != 0);

        for (i, &v) in output.iter().enumerate().skip(1) {
            assert!(v.abs() < 5, "AC coeff {} too large: {}", i, v);
        }
    }

    /// Quantize then dequantize must keep the DC value within half of its magnitude.
    #[test]
    fn test_quantize_dequantize() {
        let coeffs = [100i16, -50, 25, -12, 6, -3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0];

        let mut quantized = [0i16; 16];
        quantize_4x4(&coeffs, 20, &mut quantized);
        assert!(quantized[0] != 0);

        let mut dequantized = [0i16; 16];
        dequantize_4x4(&quantized, 20, &mut dequantized);

        let original_dc = i32::from(coeffs[0]);
        let dc_diff = (original_dc - i32::from(dequantized[0])).abs();
        assert!(
            dc_diff < original_dc / 2,
            "DC diff too large: {}",
            dc_diff
        );
    }

    /// The generated table must be square and have a strictly positive DC row.
    #[test]
    fn test_generate_dct_coeffs() {
        let coeffs = generate_dct_coeffs::<4>();

        assert_eq!(coeffs.len(), 4);
        assert_eq!(coeffs[0].len(), 4);
        assert!(coeffs[0].iter().all(|&c| c > 0));
    }

    /// Spot-check two rows of the fixed 8-point table.
    #[test]
    fn test_dct8_coeffs() {
        let dc_row = [64, 64, 64, 64, 64, 64, 64, 64];
        let mid_row = [64, -64, -64, 64, 64, -64, -64, 64];

        assert_eq!(DCT8_COEFFS[0], dc_row);
        assert_eq!(DCT8_COEFFS[4], mid_row);
    }
}