use crate::foundation::consts::DCT_BLOCK_SIZE;
use crate::foundation::simd_types::Block8x8f;
#[cfg(target_arch = "x86_64")]
use archmage::SimdToken;
use archmage::prelude::*;
use magetypes::simd::generic::f32x8 as GenericF32x8;
// Twiddle factors for the recursive (even/odd split) DCT below.
// WC4 scales the odd half of the 4-point transform; WC8 scales the odd half
// of the 8-point transform. Values match the splats used by the SIMD kernels
// in `simd::dct_1d_4_vec_generic` / `simd::dct_1d_vec_generic`.
#[allow(dead_code)]
const WC4: [f32; 2] = [0.541196100146197, 1.3065629648763764];
#[allow(dead_code)]
const WC8: [f32; 4] = [
    0.5097955791041592,
    0.6013448869350453,
    0.8999762231364156,
    2.5629154477415055,
];
// sqrt(2), truncated to f32 precision; used by the "B" correction step.
#[allow(dead_code)]
const SQRT2: f32 = 1.41421356237;
// Constants for the AAN (Arai–Agui–Nakajima) fast DCT used by `aan_dct_1d`.
mod aan {
    // cos(π/4) = 1/sqrt(2)
    pub const C4: f32 = 0.707106781;
    // cos(3π/8)
    pub const C6: f32 = 0.382683433;
    // cos(π/8) - cos(3π/8)
    pub const C2_M_C6: f32 = 0.541196100;
    // cos(π/8) + cos(3π/8)
    pub const C2_P_C6: f32 = 1.306562965;
    // Per-coefficient output scales of the raw AAN butterfly.
    #[allow(dead_code)]
    pub const SCALES: [f32; 8] = [
        1.0,
        1.387039845,
        1.306562965,
        1.175875602,
        1.0,
        0.785694958,
        0.541196100,
        0.275899379,
    ];
    // Reciprocals of SCALES, applied at the end of `aan_dct_1d` so the
    // transform output matches the unscaled DCT-II convention.
    pub const INV_SCALES: [f32; 8] = [
        1.0,
        1.0 / 1.387039845,
        1.0 / 1.306562965,
        1.0 / 1.175875602,
        1.0,
        1.0 / 0.785694958,
        1.0 / 0.541196100,
        1.0 / 0.275899379,
    ];
}
/// Reference (test-only) transpose of a row-major 8x8 block.
#[cfg(test)]
#[inline]
fn transpose_8x8(input: &[f32; 64], output: &mut [f32; 64]) {
    // Walk the flat input once; element (r, c) lands at (c, r).
    for (idx, &value) in input.iter().enumerate() {
        let (r, c) = (idx / 8, idx % 8);
        output[c * 8 + r] = value;
    }
}
/// Element-wise sum of `in1` with `in2` reversed: `out[i] = in1[i] + in2[N-1-i]`.
#[allow(dead_code)]
#[inline]
fn add_reverse<const N: usize>(in1: &[f32], in2: &[f32], out: &mut [f32]) {
    let mirrored = in2[..N].iter().rev();
    for (dst, (a, b)) in out[..N].iter_mut().zip(in1[..N].iter().zip(mirrored)) {
        *dst = a + b;
    }
}
/// Element-wise difference of `in1` with `in2` reversed: `out[i] = in1[i] - in2[N-1-i]`.
#[allow(dead_code)]
#[inline]
fn sub_reverse<const N: usize>(in1: &[f32], in2: &[f32], out: &mut [f32]) {
    let mirrored = in2[..N].iter().rev();
    for (dst, (a, b)) in out[..N].iter_mut().zip(in1[..N].iter().zip(mirrored)) {
        *dst = a - b;
    }
}
/// "B" transform of the recursive DCT: a sqrt(2) correction on the first
/// coefficient followed by a running neighbour-add over indices `1..N-1`.
#[allow(dead_code)]
#[inline]
fn b_transform<const N: usize>(coeff: &mut [f32]) {
    coeff[0] = coeff[0] * SQRT2 + coeff[1];
    for idx in 1..N - 1 {
        coeff[idx] += coeff[idx + 1];
    }
}
/// Scale the odd-half slots (`coeff[4..8]`) by the 8-point twiddle factors.
#[allow(dead_code)]
#[inline]
fn multiply_wc8(coeff: &mut [f32]) {
    for (slot, &w) in coeff[4..8].iter_mut().zip(WC8.iter()) {
        *slot *= w;
    }
}
/// Scale the odd-half slots (`coeff[2..4]`) by the 4-point twiddle factors.
#[allow(dead_code)]
#[inline]
fn multiply_wc4(coeff: &mut [f32]) {
    for (slot, &w) in coeff[2..4].iter_mut().zip(WC4.iter()) {
        *slot *= w;
    }
}
/// Undo the even/odd split: the first half of `input` fills the even output
/// slots, the second half fills the odd ones.
#[allow(dead_code)]
#[inline]
fn inverse_even_odd<const N: usize>(input: &[f32], output: &mut [f32]) {
    let half = N / 2;
    for (idx, pair) in output[..2 * half].chunks_exact_mut(2).enumerate() {
        pair[0] = input[idx];
        pair[1] = input[half + idx];
    }
}
/// 2-point DCT butterfly: `(a, b) -> (a + b, a - b)`.
#[allow(dead_code)]
#[inline]
fn dct1d_2(mem: &mut [f32]) {
    let (a, b) = (mem[0], mem[1]);
    mem[0] = a + b;
    mem[1] = a - b;
}
/// 4-point DCT via even/odd split: 2-point DCTs on the sum and difference
/// halves, twiddle scaling, a sqrt(2) correction, then re-interleave.
#[allow(dead_code)]
#[inline]
fn dct1d_4(mem: &mut [f32]) {
    let mut scratch = [0.0f32; 4];
    // Even half: pairwise sums with the mirrored second half.
    add_reverse::<2>(&mem[0..2], &mem[2..4], &mut scratch[0..2]);
    dct1d_2(&mut scratch[0..2]);
    // Odd half: pairwise differences, scaled by the 4-point twiddles.
    sub_reverse::<2>(&mem[0..2], &mem[2..4], &mut scratch[2..4]);
    multiply_wc4(&mut scratch);
    dct1d_2(&mut scratch[2..4]);
    // sqrt(2) correction on the first odd coefficient (B transform, N = 2).
    scratch[2] = scratch[2] * SQRT2 + scratch[3];
    inverse_even_odd::<4>(&scratch, mem);
}
/// 8-point DCT via even/odd split: recursive 4-point DCTs on the sum and
/// difference halves, twiddle scaling and the B transform on the odd half,
/// then re-interleave into natural coefficient order.
#[allow(dead_code)]
#[inline]
fn dct1d_8(mem: &mut [f32]) {
    let mut scratch = [0.0f32; 8];
    // Even half: sums with the mirrored second half.
    add_reverse::<4>(&mem[0..4], &mem[4..8], &mut scratch[0..4]);
    dct1d_4(&mut scratch[0..4]);
    // Odd half: differences, scaled by the 8-point twiddles.
    sub_reverse::<4>(&mem[0..4], &mem[4..8], &mut scratch[4..8]);
    multiply_wc8(&mut scratch);
    dct1d_4(&mut scratch[4..8]);
    b_transform::<4>(&mut scratch[4..8]);
    inverse_even_odd::<8>(&scratch, mem);
}
/// SIMD implementations of the forward 8x8 DCT, dispatched by CPU-feature
/// tokens from `archmage` over the `magetypes` vector backends.
pub(crate) mod simd {
    use super::*;
    #[cfg(target_arch = "x86_64")]
    use archmage::{SimdToken, arcane};
    #[cfg(target_arch = "x86_64")]
    use magetypes::simd::f32x8 as mf32x8;

    /// Transpose an 8x8 f32 block, preferring the AVX2 kernel when the CPU
    /// supports it and falling back to the token-dispatched generic kernel.
    pub fn transpose_8x8_simd(input: &[f32; 64], output: &mut [f32; 64]) {
        #[cfg(target_arch = "x86_64")]
        if let Some(token) = archmage::X64V3Token::summon() {
            return mage_transpose_8x8(token, input, output);
        }
        incant!(transpose_8x8_generic(input, output));
    }

    /// AVX2 transpose kernel; the token proves the required CPU features.
    #[cfg(target_arch = "x86_64")]
    #[arcane]
    fn mage_transpose_8x8(_token: archmage::X64V3Token, input: &[f32; 64], output: &mut [f32; 64]) {
        let mut rows = mf32x8::load_8x8(input);
        mf32x8::transpose_8x8(&mut rows);
        mf32x8::store_8x8(&rows, output);
    }

    /// Backend-generic transpose; `Token` is supplied by the `#[magetypes]` expansion.
    #[magetypes(v3, neon, wasm128, scalar)]
    #[inline(always)]
    fn transpose_8x8_generic(_token: Token, input: &[f32; 64], output: &mut [f32; 64]) {
        #[allow(non_camel_case_types)]
        type f32x8 = GenericF32x8<Token>;
        let mut rows = f32x8::load_8x8(input);
        f32x8::transpose_8x8(&mut rows);
        f32x8::store_8x8(&rows, output);
    }

    /// Vectorized 4-point DCT butterfly: each lane carries one independent
    /// column. Returns coefficients interleaved as [even0, odd0, even1, odd1].
    #[inline(always)]
    fn dct_1d_4_vec_generic<T: magetypes::simd::backends::F32x8Backend>(
        token: T,
        m: [GenericF32x8<T>; 4],
    ) -> [GenericF32x8<T>; 4] {
        #[allow(non_camel_case_types)]
        type f32x8<U> = GenericF32x8<U>;
        let wc4_0 = f32x8::splat(token, 0.541196100146197);
        let wc4_1 = f32x8::splat(token, 1.3065629648763764);
        let sqrt2 = f32x8::splat(token, 1.41421356237);
        // Even half: sums with the mirrored second half.
        let t0 = m[0] + m[3];
        let t1 = m[1] + m[2];
        // Odd half: differences.
        let t2 = m[0] - m[3];
        let t3 = m[1] - m[2];
        let r0 = t0 + t1;
        let r1 = t0 - t1;
        let t2_scaled = t2 * wc4_0;
        let t3_scaled = t3 * wc4_1;
        let r2 = t2_scaled + t3_scaled;
        let r3 = t2_scaled - t3_scaled;
        // sqrt(2) correction on the first odd coefficient.
        let r2_final = r2 * sqrt2 + r3;
        [r0, r2_final, r1, r3]
    }

    /// Vectorized 8-point DCT: mirrors the scalar `dct1d_8` structure
    /// (even/odd split, twiddles, B transform, interleave) across 8 lanes.
    #[inline(always)]
    fn dct_1d_vec_generic<T: magetypes::simd::backends::F32x8Backend>(
        token: T,
        m: [GenericF32x8<T>; 8],
    ) -> [GenericF32x8<T>; 8] {
        #[allow(non_camel_case_types)]
        type f32x8<U> = GenericF32x8<U>;
        let wc8_0 = f32x8::splat(token, 0.5097955791041592);
        let wc8_1 = f32x8::splat(token, 0.6013448869350453);
        let wc8_2 = f32x8::splat(token, 0.8999762231364156);
        let wc8_3 = f32x8::splat(token, 2.5629154477415055);
        let sqrt2 = f32x8::splat(token, 1.41421356237);
        let t0 = m[0] + m[7];
        let t1 = m[1] + m[6];
        let t2 = m[2] + m[5];
        let t3 = m[3] + m[4];
        let t4 = m[0] - m[7];
        let t5 = m[1] - m[6];
        let t6 = m[2] - m[5];
        let t7 = m[3] - m[4];
        let first = dct_1d_4_vec_generic(token, [t0, t1, t2, t3]);
        let t4_scaled = t4 * wc8_0;
        let t5_scaled = t5 * wc8_1;
        let t6_scaled = t6 * wc8_2;
        let t7_scaled = t7 * wc8_3;
        let mut second = dct_1d_4_vec_generic(token, [t4_scaled, t5_scaled, t6_scaled, t7_scaled]);
        // B transform on the odd half (cf. scalar `b_transform::<4>`).
        second[0] = second[0] * sqrt2 + second[1];
        second[1] += second[2];
        second[2] += second[3];
        [
            first[0], second[0], first[1], second[1], first[2], second[2], first[3], second[3],
        ]
    }

    /// Multiply every row vector by 1/8 (one of the two normalization passes
    /// of the chained 2-D DCT).
    #[inline(always)]
    fn scale_vec_generic<T: magetypes::simd::backends::F32x8Backend>(
        token: T,
        v: [GenericF32x8<T>; 8],
    ) -> [GenericF32x8<T>; 8] {
        let scale = GenericF32x8::<T>::splat(token, 1.0 / 8.0);
        [
            v[0] * scale,
            v[1] * scale,
            v[2] * scale,
            v[3] * scale,
            v[4] * scale,
            v[5] * scale,
            v[6] * scale,
            v[7] * scale,
        ]
    }

    /// Run the 1-D DCT over all 8 rows in parallel (one row per vector).
    /// Used by tests to compare against the scalar row pass; no scaling.
    #[allow(dead_code)]
    pub fn dct_8rows_parallel(input: &[f32; 64], output: &mut [f32; 64]) {
        incant!(dct_8rows_parallel_impl(input, output));
    }

    #[magetypes(v3, neon, wasm128, scalar)]
    #[allow(dead_code)]
    fn dct_8rows_parallel_impl(token: Token, input: &[f32; 64], output: &mut [f32; 64]) {
        #[allow(non_camel_case_types)]
        type f32x8 = GenericF32x8<Token>;
        // Transpose in, transform, transpose back: the vector kernel works on
        // columns, so row processing needs a transpose on either side.
        let mut rows = f32x8::load_8x8(input);
        f32x8::transpose_8x8(&mut rows);
        rows = dct_1d_vec_generic(token, rows);
        f32x8::transpose_8x8(&mut rows);
        f32x8::store_8x8(&rows, output);
    }

    /// Full 2-D forward DCT (rows then columns, each followed by a 1/8 scale),
    /// dispatching to the best available backend.
    pub fn forward_dct_8x8_simd_chained(input: &[f32; 64]) -> [f32; 64] {
        #[cfg(target_arch = "x86_64")]
        if let Some(token) = archmage::X64V3Token::summon() {
            let mut output = [0.0f32; 64];
            crate::encode::mage_simd::mage_forward_dct_8x8(token, input, &mut output);
            return output;
        }
        #[cfg(target_arch = "wasm32")]
        if let Some(token) = archmage::Wasm128Token::summon() {
            return forward_dct_8x8_simd_chained_wasm(token, input);
        }
        incant!(forward_dct_8x8_simd_chained_fallback(input))
    }

    /// wasm SIMD128 variant of the chained 2-D DCT.
    #[cfg(target_arch = "wasm32")]
    #[inline]
    pub(crate) fn forward_dct_8x8_simd_chained_wasm(
        token: archmage::Wasm128Token,
        input: &[f32; 64],
    ) -> [f32; 64] {
        type F32x8 = GenericF32x8<archmage::Wasm128Token>;
        let mut rows = F32x8::load_8x8(input);
        F32x8::transpose_8x8(&mut rows);
        let cols_after_row = scale_vec_generic(token, dct_1d_vec_generic(token, rows));
        let mut rows_for_col = cols_after_row;
        F32x8::transpose_8x8(&mut rows_for_col);
        let final_rows = scale_vec_generic(token, dct_1d_vec_generic(token, rows_for_col));
        let mut output = [0.0f32; 64];
        F32x8::store_8x8(&final_rows, &mut output);
        output
    }

    /// Token-dispatched fallback for the chained 2-D DCT.
    #[magetypes(v3, neon, wasm128, scalar)]
    fn forward_dct_8x8_simd_chained_fallback(token: Token, input: &[f32; 64]) -> [f32; 64] {
        #[allow(non_camel_case_types)]
        type f32x8 = GenericF32x8<Token>;
        let mut rows = f32x8::load_8x8(input);
        f32x8::transpose_8x8(&mut rows);
        let cols_after_row = scale_vec_generic(token, dct_1d_vec_generic(token, rows));
        let mut rows_for_col = cols_after_row;
        f32x8::transpose_8x8(&mut rows_for_col);
        let final_rows = scale_vec_generic(token, dct_1d_vec_generic(token, rows_for_col));
        let mut output = [0.0f32; 64];
        f32x8::store_8x8(&final_rows, &mut output);
        output
    }

    /// 2-D forward DCT over the `Block8x8f` wide representation.
    #[inline]
    pub fn forward_dct_8x8_wide(input: &Block8x8f) -> Block8x8f {
        #[cfg(target_arch = "x86_64")]
        if let Some(token) = archmage::X64V3Token::summon() {
            return crate::encode::mage_simd::mage_forward_dct_8x8_wide(token, input);
        }
        #[cfg(target_arch = "wasm32")]
        if let Some(token) = archmage::Wasm128Token::summon() {
            return forward_dct_8x8_wide_wasm(token, input);
        }
        incant!(forward_dct_8x8_wide_fallback(input))
    }

    /// wasm SIMD128 variant of the wide 2-D DCT.
    #[cfg(target_arch = "wasm32")]
    #[inline]
    pub fn forward_dct_8x8_wide_wasm(
        token: archmage::Wasm128Token,
        input: &Block8x8f,
    ) -> Block8x8f {
        type F32x8 = GenericF32x8<archmage::Wasm128Token>;
        let mut rows: [F32x8; 8] =
            core::array::from_fn(|i| F32x8::from_array(token, input.rows[i]));
        F32x8::transpose_8x8(&mut rows);
        let cols_after_row = scale_vec_generic(token, dct_1d_vec_generic(token, rows));
        let mut rows_for_col = cols_after_row;
        F32x8::transpose_8x8(&mut rows_for_col);
        let final_rows = scale_vec_generic(token, dct_1d_vec_generic(token, rows_for_col));
        Block8x8f {
            rows: core::array::from_fn(|i| final_rows[i].to_array()),
        }
    }

    /// Token-dispatched fallback for the wide 2-D DCT.
    #[magetypes(v3, neon, wasm128, scalar)]
    #[inline]
    fn forward_dct_8x8_wide_fallback(token: Token, input: &Block8x8f) -> Block8x8f {
        #[allow(non_camel_case_types)]
        type f32x8 = GenericF32x8<Token>;
        let mut rows: [f32x8; 8] =
            core::array::from_fn(|i| f32x8::from_array(token, input.rows[i]));
        f32x8::transpose_8x8(&mut rows);
        let cols_after_row = scale_vec_generic(token, dct_1d_vec_generic(token, rows));
        let mut rows_for_col = cols_after_row;
        f32x8::transpose_8x8(&mut rows_for_col);
        let final_rows = scale_vec_generic(token, dct_1d_vec_generic(token, rows_for_col));
        Block8x8f {
            rows: core::array::from_fn(|i| final_rows[i].to_array()),
        }
    }

    /// AVX2 4-point DCT butterfly (uses FMA via `mul_add` for the sqrt(2) step).
    #[cfg(target_arch = "x86_64")]
    #[inline(always)]
    fn dct1d_4_mage(m: &mut [mf32x8; 4], wc4_0: mf32x8, wc4_1: mf32x8, sqrt2: mf32x8) {
        let t0 = m[0] + m[3];
        let t1 = m[1] + m[2];
        let t2 = m[0] - m[3];
        let t3 = m[1] - m[2];
        let r0 = t0 + t1;
        let r1 = t0 - t1;
        let t2_scaled = t2 * wc4_0;
        let t3_scaled = t3 * wc4_1;
        let r2 = t2_scaled + t3_scaled;
        let r3 = t2_scaled - t3_scaled;
        let r2_final = r2.mul_add(sqrt2, r3);
        m[0] = r0;
        m[1] = r2_final;
        m[2] = r1;
        m[3] = r3;
    }

    /// AVX2 8-point DCT over 8 lanes; mirrors `dct_1d_vec_generic`.
    #[cfg(target_arch = "x86_64")]
    #[inline(always)]
    fn dct1d_8_mage(m: &mut [mf32x8; 8], token: archmage::X64V3Token) {
        let wc4_0 = mf32x8::splat(token, 0.541196100146197);
        let wc4_1 = mf32x8::splat(token, 1.3065629648763764);
        let wc8_0 = mf32x8::splat(token, 0.5097955791041592);
        let wc8_1 = mf32x8::splat(token, 0.6013448869350453);
        let wc8_2 = mf32x8::splat(token, 0.8999762231364156);
        let wc8_3 = mf32x8::splat(token, 2.5629154477415055);
        let sqrt2 = mf32x8::splat(token, 1.41421356237);
        let t0 = m[0] + m[7];
        let t1 = m[1] + m[6];
        let t2 = m[2] + m[5];
        let t3 = m[3] + m[4];
        let t4 = m[0] - m[7];
        let t5 = m[1] - m[6];
        let t6 = m[2] - m[5];
        let t7 = m[3] - m[4];
        let mut first = [t0, t1, t2, t3];
        dct1d_4_mage(&mut first, wc4_0, wc4_1, sqrt2);
        let t4_scaled = t4 * wc8_0;
        let t5_scaled = t5 * wc8_1;
        let t6_scaled = t6 * wc8_2;
        let t7_scaled = t7 * wc8_3;
        let mut second = [t4_scaled, t5_scaled, t6_scaled, t7_scaled];
        dct1d_4_mage(&mut second, wc4_0, wc4_1, sqrt2);
        second[0] = second[0].mul_add(sqrt2, second[1]);
        second[1] += second[2];
        second[2] += second[3];
        m[0] = first[0];
        m[1] = second[0];
        m[2] = first[1];
        m[3] = second[1];
        m[4] = first[2];
        m[5] = second[2];
        m[6] = first[3];
        m[7] = second[3];
    }

    /// Full 2-D forward DCT on the AVX2 path.
    ///
    /// # Panics
    /// Panics if AVX2/FMA are unavailable (`summon().unwrap()`); callers must
    /// check `X64V3Token::summon()` first, as `forward_dct_8x8` does.
    #[cfg(target_arch = "x86_64")]
    #[inline(always)]
    pub(crate) fn forward_dct_8x8_mage(input: &[f32; 64], output: &mut [f32; 64]) {
        let token = archmage::X64V3Token::summon().unwrap();
        let scale = mf32x8::splat(token, 1.0 / 8.0);
        let mut reg = mf32x8::load_8x8(input);
        mf32x8::transpose_8x8(&mut reg);
        dct1d_8_mage(&mut reg, token);
        for r in &mut reg {
            *r *= scale;
        }
        mf32x8::transpose_8x8(&mut reg);
        dct1d_8_mage(&mut reg, token);
        for r in &mut reg {
            *r *= scale;
        }
        // BUGFIX: the original source contained a mojibake "®" here (a mangled
        // "&reg"), which does not compile.
        mf32x8::store_8x8(&reg, output);
    }
}
/// Apply the scalar 8-point DCT to each of the 8 rows independently.
#[allow(dead_code)]
#[inline]
fn dct_rows(input: &[f32; 64], output: &mut [f32; 64]) {
    for (src, dst) in input.chunks_exact(8).zip(output.chunks_exact_mut(8)) {
        let mut row = [0.0f32; 8];
        row.copy_from_slice(src);
        dct1d_8(&mut row);
        dst.copy_from_slice(&row);
    }
}
/// Forward 2-D DCT of an 8x8 block, picking the fastest available backend:
/// AVX2 on x86_64, SIMD128 on wasm32, otherwise the portable chained path.
#[must_use]
#[inline]
pub fn forward_dct_8x8(input: &[f32; DCT_BLOCK_SIZE]) -> [f32; DCT_BLOCK_SIZE] {
    #[cfg(target_arch = "x86_64")]
    if archmage::X64V3Token::summon().is_some() {
        let mut coeffs = [0.0f32; 64];
        simd::forward_dct_8x8_mage(input, &mut coeffs);
        return coeffs;
    }
    #[cfg(target_arch = "wasm32")]
    if let Some(token) = archmage::Wasm128Token::summon() {
        return simd::forward_dct_8x8_simd_chained_wasm(token, input);
    }
    forward_dct_8x8_scalar(input)
}
/// Portable forward 2-D DCT entry point.
///
/// NOTE(review): despite the `_scalar` suffix, this delegates to
/// `simd::forward_dct_8x8_simd_chained`, which itself prefers a SIMD backend
/// and only falls back to scalar code through `incant!` — consider renaming.
#[inline]
pub fn forward_dct_8x8_scalar(input: &[f32; DCT_BLOCK_SIZE]) -> [f32; DCT_BLOCK_SIZE] {
    simd::forward_dct_8x8_simd_chained(input)
}
/// Forward DCT of an 8x8 block of unsigned samples, applying the JPEG-style
/// level shift (subtract 128) before transforming.
#[must_use]
pub fn forward_dct_8x8_u8(input: &[u8; DCT_BLOCK_SIZE]) -> [f32; DCT_BLOCK_SIZE] {
    let mut shifted = [0.0f32; DCT_BLOCK_SIZE];
    for (dst, &src) in shifted.iter_mut().zip(input.iter()) {
        *dst = f32::from(src) - 128.0;
    }
    forward_dct_8x8(&shifted)
}
pub fn forward_dct_blocks(blocks: &[[f32; DCT_BLOCK_SIZE]]) -> Vec<[f32; DCT_BLOCK_SIZE]> {
blocks.iter().map(forward_dct_8x8).collect()
}
// 1-D forward DCT butterfly in the style of the `aan` constant module,
// followed by the INV_SCALES normalization so the output matches the
// unscaled DCT-II convention. The statement order is part of the numerics
// (f32 rounding) — do not reorder.
#[inline(always)]
fn aan_dct_1d(d: &mut [f32; 8]) {
    // Stage 1: mirrored sums/differences.
    let tmp0 = d[0] + d[7];
    let tmp7 = d[0] - d[7];
    let tmp1 = d[1] + d[6];
    let tmp6 = d[1] - d[6];
    let tmp2 = d[2] + d[5];
    let tmp5 = d[2] - d[5];
    let tmp3 = d[3] + d[4];
    let tmp4 = d[3] - d[4];
    // Even part: produces coefficients 0, 2, 4, 6.
    let tmp10 = tmp0 + tmp3;
    let tmp13 = tmp0 - tmp3;
    let tmp11 = tmp1 + tmp2;
    let tmp12 = tmp1 - tmp2;
    d[0] = tmp10 + tmp11;
    d[4] = tmp10 - tmp11;
    let z1 = (tmp12 + tmp13) * aan::C4;
    d[2] = tmp13 + z1;
    d[6] = tmp13 - z1;
    // Odd part: produces coefficients 1, 3, 5, 7 via rotations.
    let tmp10 = tmp4 + tmp5;
    let tmp11 = tmp5 + tmp6;
    let tmp12 = tmp6 + tmp7;
    let z5 = (tmp10 - tmp12) * aan::C6;
    let z2 = aan::C2_M_C6 * tmp10 + z5;
    let z4 = aan::C2_P_C6 * tmp12 + z5;
    let z3 = tmp11 * aan::C4;
    let z11 = tmp7 + z3;
    let z13 = tmp7 - z3;
    d[5] = z13 + z2;
    d[3] = z13 - z2;
    d[1] = z11 + z4;
    d[7] = z11 - z4;
    // Undo the implicit AAN output scaling.
    d[0] *= aan::INV_SCALES[0];
    d[1] *= aan::INV_SCALES[1];
    d[2] *= aan::INV_SCALES[2];
    d[3] *= aan::INV_SCALES[3];
    d[4] *= aan::INV_SCALES[4];
    d[5] *= aan::INV_SCALES[5];
    d[6] *= aan::INV_SCALES[6];
    d[7] *= aan::INV_SCALES[7];
}
/// 2-D forward DCT built from two passes of `aan_dct_1d` (rows, then
/// columns), each followed by a 1/8 normalization of the whole block.
#[inline]
#[must_use]
pub fn aan_forward_dct_8x8(input: &[f32; DCT_BLOCK_SIZE]) -> [f32; DCT_BLOCK_SIZE] {
    let mut data = *input;
    // Pass 1: transform each row in place.
    for row in 0..8 {
        let base = row * 8;
        let mut scratch: [f32; 8] = core::array::from_fn(|i| data[base + i]);
        aan_dct_1d(&mut scratch);
        data[base..base + 8].copy_from_slice(&scratch);
    }
    for v in &mut data {
        *v *= 0.125;
    }
    // Pass 2: transform each column (stride-8 gather/scatter).
    for col in 0..8 {
        let mut scratch: [f32; 8] = core::array::from_fn(|i| data[col + i * 8]);
        aan_dct_1d(&mut scratch);
        for (i, &v) in scratch.iter().enumerate() {
            data[col + i * 8] = v;
        }
    }
    for v in &mut data {
        *v *= 0.125;
    }
    data
}
#[cfg(test)]
mod tests {
    use super::*;
    // DC-only input: a constant block must produce a large DC coefficient and
    // near-zero AC coefficients.
    #[test]
    fn test_dct_dc_only() {
        let input = [128.0f32; DCT_BLOCK_SIZE];
        let output = forward_dct_8x8(&input);
        assert!(output[0].abs() > 1.0);
        for i in 1..DCT_BLOCK_SIZE {
            assert!(
                output[i].abs() < 0.001,
                "AC[{}] = {} should be ~0",
                i,
                output[i]
            );
        }
    }
    // All-zero input must transform to all-zero coefficients.
    #[test]
    fn test_dct_zero_block() {
        let input = [0.0f32; DCT_BLOCK_SIZE];
        let output = forward_dct_8x8(&input);
        for (i, &v) in output.iter().enumerate() {
            assert!(v.abs() < 0.001, "coeff[{}] = {} should be 0", i, v);
        }
    }
    // A block of all 128s should be exactly cancelled by the level shift.
    #[test]
    fn test_dct_u8_level_shift() {
        let input = [128u8; DCT_BLOCK_SIZE];
        let output = forward_dct_8x8_u8(&input);
        for (i, &v) in output.iter().enumerate() {
            assert!(v.abs() < 0.001, "coeff[{}] = {} should be 0", i, v);
        }
    }
    // Spot-check the scalar reference transpose.
    #[test]
    fn test_transpose() {
        let mut input = [0.0f32; 64];
        for i in 0..64 {
            input[i] = i as f32;
        }
        let mut output = [0.0f32; 64];
        transpose_8x8(&input, &mut output);
        assert_eq!(output[0], 0.0); assert_eq!(output[1], 8.0); assert_eq!(output[8], 1.0); assert_eq!(output[9], 9.0); }
    // SIMD transpose must agree with the scalar reference element-wise.
    #[test]
    fn test_simd_transpose_matches_scalar() {
        let mut input = [0.0f32; 64];
        for i in 0..64 {
            input[i] = (i as f32 * 1.7).sin() * 100.0;
        }
        let mut scalar_output = [0.0f32; 64];
        let mut simd_output = [0.0f32; 64];
        transpose_8x8(&input, &mut scalar_output);
        simd::transpose_8x8_simd(&input, &mut simd_output);
        for i in 0..64 {
            assert!(
                (scalar_output[i] - simd_output[i]).abs() < 1e-6,
                "Mismatch at {}: scalar={} simd={}",
                i,
                scalar_output[i],
                simd_output[i]
            );
        }
    }
    // The row-parallel SIMD 1-D DCT must match the scalar per-row DCT on a
    // variety of input patterns (constant, ramp, sinusoid, checkerboard).
    #[test]
    fn test_simd_dct_rows_matches_scalar() {
        let patterns: Vec<[f32; 64]> = vec![
            [64.0; 64],
            {
                let mut arr = [0.0f32; 64];
                for i in 0..64 {
                    arr[i] = i as f32;
                }
                arr
            },
            {
                let mut arr = [0.0f32; 64];
                for i in 0..64 {
                    arr[i] = (i as f32 * 0.3).sin() * 100.0;
                }
                arr
            },
            {
                let mut arr = [0.0f32; 64];
                for row in 0..8 {
                    for col in 0..8 {
                        arr[row * 8 + col] = if (row + col) % 2 == 0 { 100.0 } else { -100.0 };
                    }
                }
                arr
            },
        ];
        for (pattern_idx, input) in patterns.iter().enumerate() {
            let mut scalar_output = [0.0f32; 64];
            for row in 0..8 {
                let mut tmp = [0.0f32; 8];
                for i in 0..8 {
                    tmp[i] = input[row * 8 + i];
                }
                dct1d_8(&mut tmp);
                for i in 0..8 {
                    scalar_output[row * 8 + i] = tmp[i];
                }
            }
            let mut simd_output = [0.0f32; 64];
            simd::dct_8rows_parallel(input, &mut simd_output);
            let mut max_error = 0.0f32;
            for i in 0..64 {
                let error = (scalar_output[i] - simd_output[i]).abs();
                max_error = max_error.max(error);
            }
            assert!(
                max_error < 1e-4,
                "Pattern {}: SIMD DCT rows max error {} exceeds threshold",
                pattern_idx,
                max_error
            );
        }
    }
    // The AAN implementation must agree with the recursive implementation.
    #[test]
    fn test_aan_dct_matches_recursive() {
        let patterns: Vec<[f32; 64]> = vec![
            [64.0; 64],
            {
                let mut arr = [0.0f32; 64];
                for i in 0..64 {
                    arr[i] = i as f32;
                }
                arr
            },
            {
                let mut arr = [0.0f32; 64];
                for i in 0..64 {
                    arr[i] = (i as f32 * 0.3).sin() * 100.0;
                }
                arr
            },
            {
                let mut arr = [0.0f32; 64];
                for row in 0..8 {
                    for col in 0..8 {
                        arr[row * 8 + col] = if (row + col) % 2 == 0 { 100.0 } else { -100.0 };
                    }
                }
                arr
            },
            {
                let mut arr = [0.0f32; 64];
                for i in 0..64 {
                    arr[i] = ((i * 17 + 31) % 256) as f32 - 128.0;
                }
                arr
            },
        ];
        for (pattern_idx, input) in patterns.iter().enumerate() {
            let recursive_output = forward_dct_8x8(input);
            let aan_output = aan_forward_dct_8x8(input);
            let mut max_error = 0.0f32;
            for i in 0..64 {
                let error = (recursive_output[i] - aan_output[i]).abs();
                max_error = max_error.max(error);
            }
            assert!(
                max_error < 1e-4,
                "Pattern {}: AAN vs recursive max error {} exceeds threshold.\n\
                Recursive[0..8]: {:?}\n\
                AAN[0..8]: {:?}",
                pattern_idx,
                max_error,
                &recursive_output[0..8],
                &aan_output[0..8]
            );
        }
    }
    // AVX2 kernel vs portable path; skipped when the CPU lacks AVX2/FMA.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_mage_dct_matches_scalar() {
        let _lock = archmage::testing::lock_token_testing();
        if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
            return;
        }
        let patterns: Vec<[f32; 64]> = vec![
            [64.0; 64],
            {
                let mut arr = [0.0f32; 64];
                for i in 0..64 {
                    arr[i] = i as f32;
                }
                arr
            },
            {
                let mut arr = [0.0f32; 64];
                for i in 0..64 {
                    arr[i] = ((i as f32) * 0.5).sin() * 100.0 + 50.0;
                }
                arr
            },
        ];
        for (pattern_idx, input) in patterns.iter().enumerate() {
            let scalar_output = forward_dct_8x8_scalar(input);
            let mut mage_output = [0.0f32; 64];
            simd::forward_dct_8x8_mage(input, &mut mage_output);
            let mut max_error = 0.0f32;
            for i in 0..64 {
                let error = (scalar_output[i] - mage_output[i]).abs();
                max_error = max_error.max(error);
            }
            assert!(
                max_error < 1e-4,
                "Pattern {}: magetypes vs scalar max error {} exceeds threshold.\n\
                Scalar[0..8]: {:?}\n\
                Mage[0..8]: {:?}",
                pattern_idx,
                max_error,
                &scalar_output[0..8],
                &mage_output[0..8]
            );
        }
    }
    // Manual benchmark (run with --ignored); not a correctness test.
    #[cfg(target_arch = "x86_64")]
    #[test]
    #[ignore] fn bench_mage_dct_vs_scalar() {
        use std::hint::black_box;
        use std::time::Instant;
        if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
            println!("AVX2+FMA not available, skipping benchmark");
            return;
        }
        let input: [f32; 64] = [
            52., 55., 61., 66., 70., 61., 64., 73., 63., 59., 55., 90., 109., 85., 69., 72., 62.,
            59., 68., 113., 144., 104., 66., 73., 63., 58., 71., 122., 154., 106., 70., 69., 67.,
            61., 68., 104., 126., 88., 68., 70., 79., 65., 60., 70., 77., 68., 58., 75., 85., 71.,
            64., 59., 55., 61., 65., 83., 87., 79., 69., 68., 65., 76., 78., 94.,
        ];
        let iterations = 10_000_000;
        let mut output = [0.0f32; 64];
        // Warm-up loops before each timed section.
        for _ in 0..10000 {
            let _ = forward_dct_8x8_scalar(black_box(&input));
        }
        let start = Instant::now();
        for _ in 0..iterations {
            let result = forward_dct_8x8_scalar(black_box(&input));
            black_box(&result);
        }
        let scalar_time = start.elapsed();
        let scalar_ns = scalar_time.as_nanos() as f64 / iterations as f64;
        for _ in 0..10000 {
            simd::forward_dct_8x8_mage(black_box(&input), &mut output);
        }
        let start = Instant::now();
        for _ in 0..iterations {
            simd::forward_dct_8x8_mage(black_box(&input), black_box(&mut output));
        }
        let mage_time = start.elapsed();
        let mage_ns = mage_time.as_nanos() as f64 / iterations as f64;
        println!("DCT Performance ({} iterations):", iterations);
        println!("  Scalar:    {:.2} ns/DCT", scalar_ns);
        println!("  Magetypes: {:.2} ns/DCT", mage_ns);
        println!(
            "  Speedup: {:.2}x ({})",
            scalar_ns / mage_ns,
            if mage_ns < scalar_ns {
                "magetypes faster"
            } else {
                "Scalar faster"
            }
        );
    }
    // The Block8x8f (wide) entry point must match the array entry point.
    #[test]
    fn test_wide_dct_matches_array_dct() {
        use crate::foundation::simd_types::Block8x8f;
        let patterns: [[f32; 64]; 4] = [
            [128.0; 64], {
                let mut arr = [0.0f32; 64];
                for i in 0..64 {
                    arr[i] = i as f32;
                }
                arr
            },
            {
                let mut arr = [0.0f32; 64];
                for i in 0..64 {
                    arr[i] = (i as f32 * 0.3).sin() * 100.0;
                }
                arr
            },
            {
                let mut arr = [0.0f32; 64];
                for row in 0..8 {
                    for col in 0..8 {
                        arr[row * 8 + col] = if (row + col) % 2 == 0 { 100.0 } else { -100.0 };
                    }
                }
                arr
            },
        ];
        for (idx, input) in patterns.iter().enumerate() {
            let array_output = forward_dct_8x8(input);
            let block_input = Block8x8f::from_array(input);
            let block_output = simd::forward_dct_8x8_wide(&block_input);
            let wide_output = block_output.to_array();
            let mut max_error = 0.0f32;
            for i in 0..64 {
                let error = (array_output[i] - wide_output[i]).abs();
                max_error = max_error.max(error);
            }
            assert!(
                max_error < 1e-4,
                "Pattern {}: wide DCT differs from array DCT by {}",
                idx,
                max_error
            );
        }
    }
    // All token-dispatch permutations must produce the same coefficients.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_dct_dispatch_parity() {
        use archmage::testing::{CompileTimePolicy, for_each_token_permutation};
        let patterns: Vec<[f32; 64]> = vec![
            [42.0; 64],
            core::array::from_fn(|i| i as f32),
            core::array::from_fn(|i| ((i % 8) as f32 - 3.5) * 20.0),
            core::array::from_fn(|i| {
                let r = i / 8;
                let c = i % 8;
                ((r * 7 + c * 13) % 256) as f32 - 128.0
            }),
        ];
        for (idx, input) in patterns.iter().enumerate() {
            let reference = forward_dct_8x8(input);
            let report = for_each_token_permutation(CompileTimePolicy::Warn, |perm| {
                let result = forward_dct_8x8(input);
                for k in 0..64 {
                    let diff = (result[k] - reference[k]).abs();
                    assert!(
                        diff < 1e-4,
                        "DCT pattern {idx} coeff {k}: ref={} got={} diff={diff} at {perm}",
                        reference[k],
                        result[k]
                    );
                }
            });
            if idx == 0 {
                eprintln!("dct_dispatch: {report}");
                assert!(
                    report.permutations_run >= 2,
                    "expected at least 2 permutations"
                );
            }
        }
    }
}