#![allow(dead_code)]
use crate::foundation::consts::DCT_BLOCK_SIZE;
#[cfg(target_arch = "x86_64")]
use archmage::SimdToken;
use archmage::prelude::*;
use magetypes::simd::generic::f32x8 as GenericF32x8;
// Twiddle factors for the 4-point IDCT butterfly:
// WC4[k] = 1 / (2 * cos((2k + 1) * PI / 8)).
const WC4: [f32; 2] = [0.541196100146197, 1.3065629648763764];
// Twiddle factors for the 8-point IDCT butterfly:
// WC8[k] = 1 / (2 * cos((2k + 1) * PI / 16)).
const WC8: [f32; 4] = [
    0.5097955791041592,
    0.6013448869350453,
    0.8999762231364156,
    2.5629154477415055,
];
// sqrt(2). The stdlib constant rounds to the same f32 bit pattern as the
// previously hand-written truncated literal 1.41421356237.
const SQRT2: f32 = std::f32::consts::SQRT_2;
#[cfg(test)]
#[inline]
fn transpose_8x8(input: &[f32; 64], output: &mut [f32; 64]) {
    // Scalar reference transpose: element (row, col) lands at (col, row).
    for (idx, &value) in input.iter().enumerate() {
        let (row, col) = (idx / 8, idx % 8);
        output[col * 8 + row] = value;
    }
}
// Portable SIMD implementation of the 8x8 inverse DCT, built on the
// `magetypes`/`archmage` macros which generate per-target variants
// (x86-64-v3, NEON, wasm128, scalar).
mod simd {
use super::*;
// Lane-parallel 1-D 8-point inverse DCT: each `GenericF32x8` carries one
// row of eight values, so a single call performs eight independent 8-point
// IDCTs at once. The data flow mirrors the scalar test helpers in this
// file: even/odd split -> two 4-point butterflies -> WC8 recombination.
#[inline(always)]
fn idct_1d_vec_generic<T: magetypes::simd::backends::F32x8Backend>(
token: T,
cols: [GenericF32x8<T>; 8],
) -> [GenericF32x8<T>; 8] {
#[allow(non_camel_case_types)]
type f32x8<U> = GenericF32x8<U>;
// De-interleave: t0..t3 take the even coefficients (c0, c2, c4, c6),
// t4..t7 take the odd coefficients (c1, c3, c5, c7).
let [c0, c1, c2, c3, c4, c5, c6, c7] = cols;
let t0 = c0;
let t1 = c2;
let t2 = c4;
let t3 = c6;
let t4 = c1;
let t5 = c3;
let t6 = c5;
let t7 = c7;
// 4-point IDCT of the even half (second-level even/odd split: e = c0,c4;
// o = c2,c6), producing r0..r3.
let e0 = t0;
let e1 = t2;
let o0 = t1;
let o1 = t3;
let e00 = e0 + e1;
let e01 = e0 - e1;
let o1b = o1 + o0;
let sqrt2 = f32x8::splat(token, SQRT2);
let o0b = o0 * sqrt2;
let o00 = o0b + o1b;
let o01 = o0b - o1b;
let wc4_0 = f32x8::splat(token, WC4[0]);
let wc4_1 = f32x8::splat(token, WC4[1]);
let prod0 = wc4_0 * o00;
let prod1 = wc4_1 * o01;
let r0 = e00 + prod0;
let r3 = e00 - prod0;
let r1 = e01 + prod1;
let r2 = e01 - prod1;
// Pre-twiddle the odd half: running sums plus sqrt(2) on the head — the
// vector analogue of the scalar `b_transpose::<4>`.
let t7b = t7 + t6;
let t6b = t6 + t5;
let t5b = t5 + t4;
let t4b = t4 * sqrt2;
// 4-point IDCT of the pre-twiddled odd half (same pattern as above),
// producing r0o..r3o.
let e0o = t4b;
let e1o = t6b;
let o0o = t5b;
let o1o = t7b;
let e00o = e0o + e1o;
let e01o = e0o - e1o;
let o1bo = o1o + o0o;
let o0bo = o0o * sqrt2;
let o00o = o0bo + o1bo;
let o01o = o0bo - o1bo;
let prod0o = wc4_0 * o00o;
let prod1o = wc4_1 * o01o;
let r0o = e00o + prod0o;
let r3o = e00o - prod0o;
let r1o = e01o + prod1o;
let r2o = e01o - prod1o;
// Final butterfly: even results +/- WC8-weighted odd results give the
// symmetric output pairs (k, 7 - k).
let wc8 = [
f32x8::splat(token, WC8[0]),
f32x8::splat(token, WC8[1]),
f32x8::splat(token, WC8[2]),
f32x8::splat(token, WC8[3]),
];
let prod8_0 = wc8[0] * r0o;
let prod8_1 = wc8[1] * r1o;
let prod8_2 = wc8[2] * r2o;
let prod8_3 = wc8[3] * r3o;
[
r0 + prod8_0, r1 + prod8_1, r2 + prod8_2, r3 + prod8_3, r3 - prod8_3, r2 - prod8_2, r1 - prod8_1, r0 - prod8_0, ]
}
// Scalar 8-point IDCT using precomputed butterfly constants
// (A1 = 1/sqrt(2), A5 = sin(pi/8), ...). Marked dead_code and not called
// from this module — presumably kept for reference or benchmarking;
// TODO(review): confirm before removing.
#[allow(dead_code)]
#[inline]
pub fn idct1d_fast(input: &[f32], output: &mut [f32]) {
const A1: f32 = 0.707_106_77;
const A2: f32 = 0.541_196_1;
const A3: f32 = 0.707_106_77;
const A4: f32 = 1.306_562_96;
const A5: f32 = 0.382_683_43;
let v0 = input[0];
let v1 = input[1];
let v2 = input[2];
let v3 = input[3];
let v4 = input[4];
let v5 = input[5];
let v6 = input[6];
let v7 = input[7];
// Reorder: even coefficients first (0,4,2,6), then odd (1,5,3,7).
let t0 = v0;
let t1 = v4;
let t2 = v2;
let t3 = v6;
let t4 = v1;
let t5 = v5;
let t6 = v3;
let t7 = v7;
// Even part.
let s0 = t0 + t1;
let s1 = t0 - t1;
let s2 = t2 * A1 - t3 * A1;
let s3 = t2 * A1 + t3 * A1;
let e0 = s0 + s3;
let e1 = s1 + s2;
let e2 = s1 - s2;
let e3 = s0 - s3;
// Odd part.
let p4 = t4 + t7;
let p5 = t5 + t6;
let p6 = t4 + t6;
let p7 = t5 + t7;
let z5 = (p6 - p7) * A5;
let o4 = t4 * A2 + z5 + p4 * (-A2 - A5);
let o5 = t5 * A3 + p5 * (-A3);
let o6 = t6 * A4 + z5 + p6 * (-A4 + A5);
let o7 = t7 * A1;
let q4 = o4 + o7;
let q5 = o5 + o6;
let q6 = o5 - o6;
let q7 = -o4 + o7;
// Final symmetric butterfly into the output row.
output[0] = e0 + q4;
output[1] = e1 + q5;
output[2] = e2 + q6;
output[3] = e3 + q7;
output[4] = e3 - q7;
output[5] = e2 - q6;
output[6] = e1 - q5;
output[7] = e0 - q4;
}
// Public entry point: `incant!` selects among the target-specific variants
// generated by `#[magetypes(...)]` below (see archmage/magetypes docs).
pub fn inverse_dct_8x8_simd(input: &[f32; 64]) -> [f32; 64] {
incant!(inverse_dct_8x8_simd_impl(input))
}
// Row-column decomposition: transpose, 1-D IDCT (over original columns),
// transpose back, 1-D IDCT (over rows), then scale everything by 1/8.
#[magetypes(v3, neon, wasm128, scalar)]
fn inverse_dct_8x8_simd_impl(token: Token, input: &[f32; 64]) -> [f32; 64] {
#[allow(non_camel_case_types)]
type f32x8 = GenericF32x8<Token>;
let mut rows = f32x8::load_8x8(input);
f32x8::transpose_8x8(&mut rows);
let cols_after_row = idct_1d_vec_generic(token, rows);
let mut rows_for_col = cols_after_row;
f32x8::transpose_8x8(&mut rows_for_col);
let final_rows = idct_1d_vec_generic(token, rows_for_col);
let scale = f32x8::splat(token, 1.0 / 8.0);
let scaled: [f32x8; 8] = core::array::from_fn(|i| final_rows[i] * scale);
let mut output = [0.0f32; 64];
f32x8::store_8x8(&scaled, &mut output);
output
}
// SIMD 8x8 transpose, dispatched the same way as the IDCT above.
pub fn transpose_8x8_simd(input: &[f32; 64], output: &mut [f32; 64]) {
incant!(transpose_8x8_simd_impl(input, output));
}
#[magetypes(v3, neon, wasm128, scalar)]
fn transpose_8x8_simd_impl(_token: Token, input: &[f32; 64], output: &mut [f32; 64]) {
#[allow(non_camel_case_types)]
type f32x8 = GenericF32x8<Token>;
let mut rows = f32x8::load_8x8(input);
f32x8::transpose_8x8(&mut rows);
f32x8::store_8x8(&rows, output);
}
}
#[cfg(target_arch = "x86_64")]
// AVX (x86-64-v3 token) implementation of the 8x8 inverse DCT. All
// intrinsic use is inside archmage `#[rite]`/`#[arcane]` functions gated by
// an `X64V3Token` capability proof, and loads/stores go through
// `safe_unaligned_simd`, so no raw-pointer handling is needed here.
mod archmage_idct {
use archmage::{arcane, rite};
use safe_unaligned_simd::x86_64 as safe_simd;
use super::{SQRT2, WC4, WC8};
#[allow(unused_imports)]
use core::arch::x86_64::*;
// In-place 4-point IDCT across four row vectors. Input is in natural
// coefficient order [c0, c1, c2, c3]; internally split into even (c0, c2)
// and odd (c1, c3) halves, recombined with the WC4 twiddles.
#[rite]
fn mage_idct1d_4(_token: archmage::X64V3Token, m: &mut [__m256; 4]) {
let wc4_0 = _mm256_set1_ps(WC4[0]);
let wc4_1 = _mm256_set1_ps(WC4[1]);
let sqrt2 = _mm256_set1_ps(SQRT2);
let e0 = m[0];
let e1 = m[2];
let o0 = m[1];
let o1 = m[3];
let e00 = _mm256_add_ps(e0, e1);
let e01 = _mm256_sub_ps(e0, e1);
let o1b = _mm256_add_ps(o1, o0);
let o0b = _mm256_mul_ps(o0, sqrt2);
let o00 = _mm256_add_ps(o0b, o1b);
let o01 = _mm256_sub_ps(o0b, o1b);
let prod0 = _mm256_mul_ps(wc4_0, o00);
let prod1 = _mm256_mul_ps(wc4_1, o01);
m[0] = _mm256_add_ps(e00, prod0);
m[1] = _mm256_add_ps(e01, prod1);
m[2] = _mm256_sub_ps(e01, prod1);
m[3] = _mm256_sub_ps(e00, prod0);
}
// In-place 8-point IDCT across eight row vectors: even coefficients go
// straight through a 4-point IDCT; odd coefficients are pre-twiddled
// (running sums + sqrt(2) on the head) before theirs; a WC8 butterfly
// produces the symmetric output pairs (k, 7 - k).
#[rite]
fn mage_idct1d_8(token: archmage::X64V3Token, m: &mut [__m256; 8]) {
let sqrt2 = _mm256_set1_ps(SQRT2);
// De-interleave even (m[0,2,4,6]) and odd (m[1,3,5,7]) coefficients.
let t0 = m[0]; let t1 = m[2]; let t2 = m[4]; let t3 = m[6]; let t4 = m[1]; let t5 = m[3]; let t6 = m[5]; let t7 = m[7];
let mut even = [t0, t1, t2, t3];
mage_idct1d_4(token, &mut even);
// Pre-twiddle the odd half (vector analogue of `b_transpose::<4>`).
let t7b = _mm256_add_ps(t7, t6);
let t6b = _mm256_add_ps(t6, t5);
let t5b = _mm256_add_ps(t5, t4);
let t4b = _mm256_mul_ps(t4, sqrt2);
let mut odd = [t4b, t5b, t6b, t7b];
mage_idct1d_4(token, &mut odd);
let wc8 = [
_mm256_set1_ps(WC8[0]),
_mm256_set1_ps(WC8[1]),
_mm256_set1_ps(WC8[2]),
_mm256_set1_ps(WC8[3]),
];
let prod0 = _mm256_mul_ps(wc8[0], odd[0]);
let prod1 = _mm256_mul_ps(wc8[1], odd[1]);
let prod2 = _mm256_mul_ps(wc8[2], odd[2]);
let prod3 = _mm256_mul_ps(wc8[3], odd[3]);
m[0] = _mm256_add_ps(even[0], prod0);
m[1] = _mm256_add_ps(even[1], prod1);
m[2] = _mm256_add_ps(even[2], prod2);
m[3] = _mm256_add_ps(even[3], prod3);
m[4] = _mm256_sub_ps(even[3], prod3);
m[5] = _mm256_sub_ps(even[2], prod2);
m[6] = _mm256_sub_ps(even[1], prod1);
m[7] = _mm256_sub_ps(even[0], prod0);
}
// In-place 8x8 float transpose: two unpacklo/unpackhi stages interleave
// within 128-bit lanes, then permute2f128 exchanges the lane halves.
#[rite]
fn mage_transpose_8x8_inplace(_token: archmage::X64V3Token, r: &mut [__m256; 8]) {
let q0 = _mm256_unpacklo_ps(r[0], r[2]);
let q1 = _mm256_unpacklo_ps(r[1], r[3]);
let q2 = _mm256_unpackhi_ps(r[0], r[2]);
let q3 = _mm256_unpackhi_ps(r[1], r[3]);
let q4 = _mm256_unpacklo_ps(r[4], r[6]);
let q5 = _mm256_unpacklo_ps(r[5], r[7]);
let q6 = _mm256_unpackhi_ps(r[4], r[6]);
let q7 = _mm256_unpackhi_ps(r[5], r[7]);
let s0 = _mm256_unpacklo_ps(q0, q1);
let s1 = _mm256_unpackhi_ps(q0, q1);
let s2 = _mm256_unpacklo_ps(q2, q3);
let s3 = _mm256_unpackhi_ps(q2, q3);
let s4 = _mm256_unpacklo_ps(q4, q5);
let s5 = _mm256_unpackhi_ps(q4, q5);
let s6 = _mm256_unpacklo_ps(q6, q7);
let s7 = _mm256_unpackhi_ps(q6, q7);
// 0x20 = low 128-bit halves of both sources; 0x31 = high halves.
r[0] = _mm256_permute2f128_ps::<0x20>(s0, s4);
r[1] = _mm256_permute2f128_ps::<0x20>(s1, s5);
r[2] = _mm256_permute2f128_ps::<0x20>(s2, s6);
r[3] = _mm256_permute2f128_ps::<0x20>(s3, s7);
r[4] = _mm256_permute2f128_ps::<0x31>(s0, s4);
r[5] = _mm256_permute2f128_ps::<0x31>(s1, s5);
r[6] = _mm256_permute2f128_ps::<0x31>(s2, s6);
r[7] = _mm256_permute2f128_ps::<0x31>(s3, s7);
}
// Full 8x8 inverse DCT: load the 8 rows, run transpose + 1-D IDCT twice
// (first pass covers the original columns, second the rows), scale by 1/8
// and store. Matches `simd::inverse_dct_8x8_simd_impl` stage for stage.
#[arcane]
#[inline]
pub fn mage_inverse_dct_8x8(
token: archmage::X64V3Token,
input: &[f32; 64],
output: &mut [f32; 64],
) {
let scale = _mm256_set1_ps(1.0 / 8.0);
let mut reg = [
safe_simd::_mm256_loadu_ps(<&[f32; 8]>::try_from(&input[0..8]).unwrap()),
safe_simd::_mm256_loadu_ps(<&[f32; 8]>::try_from(&input[8..16]).unwrap()),
safe_simd::_mm256_loadu_ps(<&[f32; 8]>::try_from(&input[16..24]).unwrap()),
safe_simd::_mm256_loadu_ps(<&[f32; 8]>::try_from(&input[24..32]).unwrap()),
safe_simd::_mm256_loadu_ps(<&[f32; 8]>::try_from(&input[32..40]).unwrap()),
safe_simd::_mm256_loadu_ps(<&[f32; 8]>::try_from(&input[40..48]).unwrap()),
safe_simd::_mm256_loadu_ps(<&[f32; 8]>::try_from(&input[48..56]).unwrap()),
safe_simd::_mm256_loadu_ps(<&[f32; 8]>::try_from(&input[56..64]).unwrap()),
];
mage_transpose_8x8_inplace(token, &mut reg);
mage_idct1d_8(token, &mut reg);
mage_transpose_8x8_inplace(token, &mut reg);
mage_idct1d_8(token, &mut reg);
safe_simd::_mm256_storeu_ps(
<&mut [f32; 8]>::try_from(&mut output[0..8]).unwrap(),
_mm256_mul_ps(reg[0], scale),
);
safe_simd::_mm256_storeu_ps(
<&mut [f32; 8]>::try_from(&mut output[8..16]).unwrap(),
_mm256_mul_ps(reg[1], scale),
);
safe_simd::_mm256_storeu_ps(
<&mut [f32; 8]>::try_from(&mut output[16..24]).unwrap(),
_mm256_mul_ps(reg[2], scale),
);
safe_simd::_mm256_storeu_ps(
<&mut [f32; 8]>::try_from(&mut output[24..32]).unwrap(),
_mm256_mul_ps(reg[3], scale),
);
safe_simd::_mm256_storeu_ps(
<&mut [f32; 8]>::try_from(&mut output[32..40]).unwrap(),
_mm256_mul_ps(reg[4], scale),
);
safe_simd::_mm256_storeu_ps(
<&mut [f32; 8]>::try_from(&mut output[40..48]).unwrap(),
_mm256_mul_ps(reg[5], scale),
);
safe_simd::_mm256_storeu_ps(
<&mut [f32; 8]>::try_from(&mut output[48..56]).unwrap(),
_mm256_mul_ps(reg[6], scale),
);
safe_simd::_mm256_storeu_ps(
<&mut [f32; 8]>::try_from(&mut output[56..64]).unwrap(),
_mm256_mul_ps(reg[7], scale),
);
}
}
#[cfg(test)]
#[inline]
fn forward_even_odd<const N: usize>(input: &[f32], output: &mut [f32]) {
    // De-interleave: even-indexed samples fill the first half of `output`,
    // odd-indexed samples fill the second half.
    let half = N / 2;
    for i in 0..half {
        let (even, odd) = (input[2 * i], input[2 * i + 1]);
        output[i] = even;
        output[half + i] = odd;
    }
}
#[cfg(test)]
#[inline]
fn b_transpose<const N: usize>(coeff: &mut [f32]) {
    // Back-to-front running-sum pass: coeff[i] becomes coeff[i] + coeff[i-1].
    let mut i = N;
    while i > 1 {
        i -= 1;
        coeff[i] += coeff[i - 1];
    }
    // The head element is scaled instead of summed.
    coeff[0] *= SQRT2;
}
#[cfg(test)]
#[inline]
fn multiply_and_add_8(input: &[f32], output: &mut [f32]) {
    // Butterfly stage: even result +/- WC8-weighted odd result yields the
    // symmetric output pair (i, 7 - i).
    for (i, &weight) in WC8.iter().enumerate() {
        let even = input[i];
        let weighted_odd = weight * input[4 + i];
        output[i] = even + weighted_odd;
        output[7 - i] = even - weighted_odd;
    }
}
#[cfg(test)]
#[inline]
fn multiply_and_add_4(input: &[f32], output: &mut [f32]) {
    // Butterfly stage: even result +/- WC4-weighted odd result yields the
    // symmetric output pair (i, 3 - i).
    for (i, &weight) in WC4.iter().enumerate() {
        let even = input[i];
        let weighted_odd = weight * input[2 + i];
        output[i] = even + weighted_odd;
        output[3 - i] = even - weighted_odd;
    }
}
#[cfg(test)]
#[inline]
fn idct1d_2(input: &[f32], output: &mut [f32]) {
    // The 2-point IDCT is a plain sum/difference butterfly.
    let (a, b) = (input[0], input[1]);
    output[0] = a + b;
    output[1] = a - b;
}
#[cfg(test)]
fn idct1d_4(input: &[f32], output: &mut [f32]) {
    // Even coefficients (0, 2) go straight into a 2-point IDCT.
    let even_in = [input[0], input[2]];
    let mut even = [0.0f32; 2];
    idct1d_2(&even_in, &mut even);
    // Odd coefficients (1, 3) are pre-twiddled first: running sum into the
    // tail, sqrt(2) scaling on the head (same values, same op order, as the
    // in-place form this replaces).
    let odd_in = [input[1] * SQRT2, input[3] + input[1]];
    let mut odd = [0.0f32; 2];
    idct1d_2(&odd_in, &mut odd);
    // Recombine the halves through the WC4 butterfly.
    multiply_and_add_4(&[even[0], even[1], odd[0], odd[1]], output);
}
#[cfg(test)]
fn idct1d_8(input: &[f32], output: &mut [f32]) {
    // Even-indexed coefficients feed a 4-point IDCT directly.
    let evens = [input[0], input[2], input[4], input[6]];
    let mut even_out = [0.0f32; 4];
    idct1d_4(&evens, &mut even_out);
    // Odd-indexed coefficients are pre-twiddled (running sums plus a
    // sqrt(2)-scaled head) before their own 4-point IDCT.
    let mut odds = [input[1], input[3], input[5], input[7]];
    b_transpose::<4>(&mut odds);
    let mut odd_out = [0.0f32; 4];
    idct1d_4(&odds, &mut odd_out);
    // Recombine both halves through the WC8 butterfly.
    let combined = [
        even_out[0], even_out[1], even_out[2], even_out[3],
        odd_out[0], odd_out[1], odd_out[2], odd_out[3],
    ];
    multiply_and_add_8(&combined, output);
}
#[cfg(test)]
fn idct_rows(input: &[f32; 64], output: &mut [f32; 64]) {
    // Apply the 1-D 8-point IDCT independently to each of the 8 rows.
    for (src, dst) in input.chunks_exact(8).zip(output.chunks_exact_mut(8)) {
        let mut row = [0.0f32; 8];
        idct1d_8(src, &mut row);
        dst.copy_from_slice(&row);
    }
}
#[inline]
fn is_dc_only(input: &[f32; DCT_BLOCK_SIZE]) -> bool {
    // Everything past the DC coefficient must be (numerically) zero.
    input.iter().skip(1).all(|&coeff| coeff.abs() < 1e-10)
}
/// Full 8x8 inverse DCT of one coefficient block.
///
/// DC-only blocks short-circuit to a constant fill of `input[0] / 8.0`.
/// On x86_64, if an `X64V3Token` can be summoned (CPU supports the v3
/// feature level — see archmage docs), the AVX kernel is used; otherwise
/// the portable SIMD path runs. All paths scale the result by 1/8.
#[must_use]
pub fn inverse_dct_8x8(input: &[f32; DCT_BLOCK_SIZE]) -> [f32; DCT_BLOCK_SIZE] {
if is_dc_only(input) {
// All AC terms are zero, so the spatial block is uniform.
let dc_value = input[0] / 8.0;
return [dc_value; DCT_BLOCK_SIZE];
}
// NOTE: the #[cfg] applies to the whole `if let`; on non-x86_64 targets we
// fall straight through to the portable implementation.
#[cfg(target_arch = "x86_64")]
if let Some(token) = archmage::X64V3Token::summon() {
let mut output = [0.0f32; DCT_BLOCK_SIZE];
archmage_idct::mage_inverse_dct_8x8(token, input, &mut output);
return output;
}
simd::inverse_dct_8x8_simd(input)
}
/// Inverse DCT followed by conversion to 8-bit samples: level-shift by
/// +128, round to nearest, and saturate into the 0..=255 range.
#[must_use]
pub fn inverse_dct_8x8_u8(input: &[f32; DCT_BLOCK_SIZE]) -> [u8; DCT_BLOCK_SIZE] {
    let spatial = inverse_dct_8x8(input);
    let mut result = [0u8; DCT_BLOCK_SIZE];
    for (dst, &sample) in result.iter_mut().zip(spatial.iter()) {
        *dst = (sample + 128.0).round().clamp(0.0, 255.0) as u8;
    }
    result
}
pub fn inverse_dct_blocks(blocks: &[[f32; DCT_BLOCK_SIZE]]) -> Vec<[f32; DCT_BLOCK_SIZE]> {
blocks.iter().map(inverse_dct_8x8).collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::dct::forward_dct_8x8;

    /// Forward DCT then inverse DCT reproduces the original samples.
    /// The inverse path divides by 8, so coefficients are pre-scaled by 8.
    #[test]
    fn test_dct_idct_roundtrip() {
        let mut input = [0.0f32; DCT_BLOCK_SIZE];
        for i in 0..DCT_BLOCK_SIZE {
            input[i] = (i as f32 * 3.7).sin() * 100.0;
        }
        let dct = forward_dct_8x8(&input);
        let mut dct_scaled = dct;
        for v in &mut dct_scaled {
            *v *= 8.0;
        }
        let recovered = inverse_dct_8x8(&dct_scaled);
        for i in 0..DCT_BLOCK_SIZE {
            assert!(
                (input[i] - recovered[i]).abs() < 0.1,
                "Mismatch at {}: {} vs {}",
                i,
                input[i],
                recovered[i]
            );
        }
    }

    /// A DC-only block must decode to a uniform value.
    #[test]
    fn test_idct_dc_only() {
        let mut input = [0.0f32; DCT_BLOCK_SIZE];
        input[0] = 64.0;
        let output = inverse_dct_8x8(&input);
        let first = output[0];
        for (i, &v) in output.iter().enumerate() {
            assert!(
                (v - first).abs() < 0.01,
                "Value at {} differs: {} vs {}",
                i,
                v,
                first
            );
        }
    }

    /// u8 conversion saturates at both ends of the range.
    #[test]
    fn test_idct_u8_clamping() {
        let mut input = [0.0f32; DCT_BLOCK_SIZE];
        input[0] = 1000.0;
        let output = inverse_dct_8x8_u8(&input);
        let max_val = output.iter().copied().max().unwrap();
        assert!(
            max_val > 200,
            "Large DC should produce high values, got max={}",
            max_val
        );
        input[0] = -1000.0;
        let output_neg = inverse_dct_8x8_u8(&input);
        let min_val_neg = output_neg.iter().copied().min().unwrap();
        assert!(
            min_val_neg < 10,
            "Large negative DC should produce low values clamped near 0, got min={}",
            min_val_neg
        );
    }

    /// Round-trip accuracy holds across several pseudo-random blocks.
    #[test]
    fn test_roundtrip_random_patterns() {
        for seed in 0..10 {
            let mut input = [0.0f32; DCT_BLOCK_SIZE];
            for i in 0..DCT_BLOCK_SIZE {
                input[i] = ((i + seed * 7) as f32 * 1.23).sin() * 127.0;
            }
            let dct = forward_dct_8x8(&input);
            let mut dct_scaled = dct;
            for v in &mut dct_scaled {
                *v *= 8.0;
            }
            let recovered = inverse_dct_8x8(&dct_scaled);
            for i in 0..DCT_BLOCK_SIZE {
                assert!(
                    (input[i] - recovered[i]).abs() < 0.1,
                    "Pattern {}: Mismatch at {}: {} vs {}",
                    seed,
                    i,
                    input[i],
                    recovered[i]
                );
            }
        }
    }

    /// A flat (constant) block survives the round trip.
    #[test]
    fn test_dc_value_preservation() {
        let value = 50.0f32;
        let input = [value; DCT_BLOCK_SIZE];
        let dct = forward_dct_8x8(&input);
        let mut dct_scaled = dct;
        for v in &mut dct_scaled {
            *v *= 8.0;
        }
        let recovered = inverse_dct_8x8(&dct_scaled);
        for (i, &v) in recovered.iter().enumerate() {
            assert!(
                (v - value).abs() < 0.01,
                "DC preservation failed at {}: {} vs {}",
                i,
                v,
                value
            );
        }
    }

    /// Batch helper decodes each block independently.
    #[test]
    fn test_inverse_dct_blocks() {
        let mut block1 = [0.0f32; DCT_BLOCK_SIZE];
        block1[0] = 64.0;
        let mut block2 = [0.0f32; DCT_BLOCK_SIZE];
        block2[0] = 32.0;
        let blocks = vec![block1, block2];
        let results = inverse_dct_blocks(&blocks);
        assert_eq!(results.len(), 2);
        let first_val = results[0][0];
        for &v in &results[0] {
            assert!((v - first_val).abs() < 0.01);
        }
    }

    /// A strongly negative DC clamps every output sample to 0.
    #[test]
    fn test_inverse_dct_8x8_u8_negative() {
        let mut input = [0.0f32; DCT_BLOCK_SIZE];
        input[0] = -2000.0;
        let output = inverse_dct_8x8_u8(&input);
        for &v in &output {
            assert_eq!(v, 0);
        }
    }

    /// DC-only detection: true when all AC terms are zero, false otherwise.
    #[test]
    fn test_is_dc_only() {
        let mut dc_only = [0.0f32; DCT_BLOCK_SIZE];
        dc_only[0] = 100.0;
        assert!(is_dc_only(&dc_only));
        let mut not_dc_only = dc_only;
        not_dc_only[1] = 50.0;
        assert!(!is_dc_only(&not_dc_only));
    }

    /// Scalar reference transpose swaps (row, col) -> (col, row).
    #[test]
    fn test_transpose_8x8() {
        let mut input = [0.0f32; 64];
        for i in 0..64 {
            input[i] = i as f32;
        }
        let mut output = [0.0f32; 64];
        transpose_8x8(&input, &mut output);
        for row in 0..8 {
            for col in 0..8 {
                assert_eq!(output[col * 8 + row], input[row * 8 + col]);
            }
        }
    }

    /// 2-point IDCT is a sum/difference butterfly.
    #[test]
    fn test_idct1d_2() {
        let input = [3.0f32, 1.0];
        let mut output = [0.0f32; 2];
        idct1d_2(&input, &mut output);
        assert_eq!(output[0], 4.0);
        assert_eq!(output[1], 2.0);
    }

    /// Even-indexed samples land in the first half, odd-indexed in the second.
    #[test]
    fn test_forward_even_odd() {
        let input = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let mut output = [0.0f32; 8];
        forward_even_odd::<8>(&input, &mut output);
        assert_eq!(output[0], 0.0);
        assert_eq!(output[1], 2.0);
        assert_eq!(output[2], 4.0);
        assert_eq!(output[3], 6.0);
        assert_eq!(output[4], 1.0);
        assert_eq!(output[5], 3.0);
        assert_eq!(output[6], 5.0);
        assert_eq!(output[7], 7.0);
    }

    /// b_transpose sums neighbors from the back and scales the head by sqrt(2).
    #[test]
    fn test_b_transpose() {
        let mut coeff = [1.0f32, 2.0, 3.0, 4.0];
        b_transpose::<4>(&mut coeff);
        assert!((coeff[0] - SQRT2).abs() < 0.001);
        assert_eq!(coeff[1], 3.0);
    }

    /// Butterfly combines even input with the WC4-weighted odd input.
    #[test]
    fn test_multiply_and_add_4() {
        let input = [10.0f32, 20.0, 5.0, 8.0];
        let mut output = [0.0f32; 4];
        multiply_and_add_4(&input, &mut output);
        assert!((output[0] - (10.0 + WC4[0] * 5.0)).abs() < 0.001);
        assert!((output[3] - (10.0 - WC4[0] * 5.0)).abs() < 0.001);
    }

    /// Butterfly combines even input with the WC8-weighted odd input.
    #[test]
    fn test_multiply_and_add_8() {
        let input = [1.0f32, 2.0, 3.0, 4.0, 0.5, 0.6, 0.7, 0.8];
        let mut output = [0.0f32; 8];
        multiply_and_add_8(&input, &mut output);
        assert!((output[0] - (1.0 + WC8[0] * 0.5)).abs() < 0.001);
        assert!((output[7] - (1.0 - WC8[0] * 0.5)).abs() < 0.001);
    }

    /// 4-point IDCT produces finite output for ordinary input.
    #[test]
    fn test_idct1d_4() {
        let input = [10.0f32, 5.0, 2.0, 1.0];
        let mut output = [0.0f32; 4];
        idct1d_4(&input, &mut output);
        for v in &output {
            assert!(v.is_finite());
        }
    }

    /// DC-only input to the 8-point IDCT yields a flat output row.
    #[test]
    fn test_idct1d_8() {
        let input = [64.0f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let mut output = [0.0f32; 8];
        idct1d_8(&input, &mut output);
        let first = output[0];
        for v in &output {
            assert!((v - first).abs() < 0.01);
        }
    }

    /// Row-wise IDCT of a DC-only block flattens the first row.
    #[test]
    fn test_idct_rows() {
        let mut input = [0.0f32; 64];
        input[0] = 64.0;
        let mut output = [0.0f32; 64];
        idct_rows(&input, &mut output);
        let first = output[0];
        for i in 0..8 {
            assert!((output[i] - first).abs() < 0.01);
        }
    }

    /// Portable SIMD transpose agrees with the scalar reference.
    #[test]
    fn test_simd_transpose_matches_scalar() {
        let mut input = [0.0f32; 64];
        for i in 0..64 {
            input[i] = (i as f32 * 1.5).sin() * 100.0;
        }
        let mut scalar_output = [0.0f32; 64];
        let mut simd_output = [0.0f32; 64];
        transpose_8x8(&input, &mut scalar_output);
        simd::transpose_8x8_simd(&input, &mut simd_output);
        for i in 0..64 {
            assert!(
                (scalar_output[i] - simd_output[i]).abs() < 0.001,
                "Mismatch at {}: {} vs {}",
                i,
                scalar_output[i],
                simd_output[i]
            );
        }
    }

    /// Portable SIMD IDCT agrees with the scalar transpose/row-IDCT pipeline.
    #[test]
    fn test_simd_idct_matches_scalar() {
        let mut input = [0.0f32; 64];
        for i in 0..64 {
            input[i] = ((i * 17) as f32).sin() * 50.0;
        }
        let simd_result = simd::inverse_dct_8x8_simd(&input);
        let mut block0 = input;
        let mut block1 = [0.0f32; 64];
        transpose_8x8(&block0, &mut block1);
        idct_rows(&block1, &mut block0);
        transpose_8x8(&block0, &mut block1);
        let mut scalar_result = [0.0f32; 64];
        idct_rows(&block1, &mut scalar_result);
        for v in &mut scalar_result {
            *v /= 8.0;
        }
        for i in 0..64 {
            assert!(
                (scalar_result[i] - simd_result[i]).abs() < 0.1,
                "SIMD/scalar mismatch at {}: {} vs {}",
                i,
                scalar_result[i],
                simd_result[i]
            );
        }
    }

    /// archmage AVX kernel agrees with the portable SIMD path.
    /// Skipped when the CPU cannot supply an `X64V3Token`.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_archmage_idct_matches_wide() {
        use archmage::SimdToken;
        let Some(token) = archmage::X64V3Token::summon() else {
            return;
        };
        for seed in 0..10 {
            let mut input = [0.0f32; 64];
            for i in 0..64 {
                input[i] = ((i * 17 + seed * 31) as f32).sin() * 50.0;
            }
            let wide_result = simd::inverse_dct_8x8_simd(&input);
            let mut mage_result = [0.0f32; 64];
            archmage_idct::mage_inverse_dct_8x8(token, &input, &mut mage_result);
            for i in 0..64 {
                assert!(
                    (wide_result[i] - mage_result[i]).abs() < 1e-5,
                    "archmage/wide mismatch at [{}] seed {}: {} vs {}",
                    i,
                    seed,
                    wide_result[i],
                    mage_result[i]
                );
            }
        }
    }
}