use crate::foundation::consts::DCT_BLOCK_SIZE;
use crate::foundation::simd_types::Block8x8f;
#[cfg(target_arch = "x86_64")]
use archmage::SimdToken;
use archmage::autoversion;
use wide::f32x8;
// Twiddle factors for the recursive even/odd DCT decomposition below.
// WC4[i] = 0.5 / cos((2i + 1) * pi / 8): odd-half multipliers for the
// 4-point stage.
#[allow(dead_code)]
const WC4: [f32; 2] = [0.541196100146197, 1.3065629648763764];
// WC8[i] = 0.5 / cos((2i + 1) * pi / 16): odd-half multipliers for the
// 8-point stage.
#[allow(dead_code)]
const WC8: [f32; 4] = [
0.5097955791041592,
0.6013448869350453,
0.8999762231364156,
2.5629154477415055,
];
// sqrt(2), used by the final butterfly of each odd half (see `b_transform`).
#[allow(dead_code)]
const SQRT2: f32 = 1.41421356237;
// Constants for the AAN (Arai-Agui-Nakajima) fast DCT used by `aan_dct_1d`.
mod aan {
// cos(4*pi/16) = 1/sqrt(2).
pub const C4: f32 = 0.707106781;
// cos(6*pi/16).
pub const C6: f32 = 0.382683433;
// cos(2*pi/16) - cos(6*pi/16).
pub const C2_M_C6: f32 = 0.541196100;
// cos(2*pi/16) + cos(6*pi/16).
pub const C2_P_C6: f32 = 1.306562965;
// AAN per-coefficient scale factors: SCALES[0] = 1, and for k >= 1
// SCALES[k] = sqrt(2) * cos(k*pi/16) (index 4 works out to exactly 1.0).
#[allow(dead_code)]
pub const SCALES: [f32; 8] = [
1.0,
1.387039845,
1.306562965,
1.175875602,
1.0,
0.785694958,
0.541196100,
0.275899379,
];
// Reciprocals of SCALES, applied at the end of each 1-D pass in
// `aan_dct_1d` so its output matches the other DCT paths in this file.
pub const INV_SCALES: [f32; 8] = [
1.0,
1.0 / 1.387039845,
1.0 / 1.306562965,
1.0 / 1.175875602,
1.0,
1.0 / 0.785694958,
1.0 / 0.541196100,
1.0 / 0.275899379,
];
}
#[cfg(test)]
#[inline]
// Reference transpose of a row-major 8x8 block (test-only oracle).
fn transpose_8x8(input: &[f32; 64], output: &mut [f32; 64]) {
    // Send each element at (row, col) to (col, row) in the output.
    for (idx, &value) in input.iter().enumerate() {
        let (row, col) = (idx / 8, idx % 8);
        output[col * 8 + row] = value;
    }
}
/// Writes `out[i] = in1[i] + in2[N - 1 - i]` for `i in 0..N`, pairing each
/// sample of the first half with its mirror from the second half (the
/// "even" input of the recursive DCT split).
#[allow(dead_code)]
#[inline]
fn add_reverse<const N: usize>(in1: &[f32], in2: &[f32], out: &mut [f32]) {
    // Catch undersized slices up front (debug builds) with a clear message
    // instead of a mid-loop index panic.
    debug_assert!(
        in1.len() >= N && in2.len() >= N && out.len() >= N,
        "add_reverse: all slices must have at least N elements"
    );
    for (i, o) in out[..N].iter_mut().enumerate() {
        *o = in1[i] + in2[N - 1 - i];
    }
}
/// Writes `out[i] = in1[i] - in2[N - 1 - i]` for `i in 0..N`, the mirrored
/// difference feeding the "odd" input of the recursive DCT split.
#[allow(dead_code)]
#[inline]
fn sub_reverse<const N: usize>(in1: &[f32], in2: &[f32], out: &mut [f32]) {
    // Catch undersized slices up front (debug builds) with a clear message
    // instead of a mid-loop index panic.
    debug_assert!(
        in1.len() >= N && in2.len() >= N && out.len() >= N,
        "sub_reverse: all slices must have at least N elements"
    );
    for (i, o) in out[..N].iter_mut().enumerate() {
        *o = in1[i] - in2[N - 1 - i];
    }
}
#[allow(dead_code)]
#[inline]
// "B transform" applied to the odd-coefficient half after its 4-point DCT:
// coeff[0] is scaled by sqrt(2) and every coefficient absorbs its
// successor. NOTE: coeff[0] must read the *original* coeff[1], so the
// first statement has to run before the accumulation loop — do not
// reorder.
fn b_transform<const N: usize>(coeff: &mut [f32]) {
coeff[0] = coeff[0] * SQRT2 + coeff[1];
for i in 1..(N - 1) {
coeff[i] += coeff[i + 1];
}
}
#[allow(dead_code)]
#[inline]
// Scales the odd-half entries (indices 4..8) by the 8-point twiddles WC8.
fn multiply_wc8(coeff: &mut [f32]) {
    for (c, w) in coeff[4..8].iter_mut().zip(WC8.iter()) {
        *c *= w;
    }
}
#[allow(dead_code)]
#[inline]
// Scales the odd-half entries (indices 2..4) by the 4-point twiddles WC4.
fn multiply_wc4(coeff: &mut [f32]) {
    for (c, w) in coeff[2..4].iter_mut().zip(WC4.iter()) {
        *c *= w;
    }
}
#[allow(dead_code)]
#[inline]
// Interleaves the two halves of `input`: the first half lands on even
// output slots, the second half on odd slots (undoes the even/odd split).
fn inverse_even_odd<const N: usize>(input: &[f32], output: &mut [f32]) {
    let (evens, odds) = input.split_at(N / 2);
    for (i, (&e, &o)) in evens.iter().zip(odds.iter()).enumerate() {
        output[2 * i] = e;
        output[2 * i + 1] = o;
    }
}
#[allow(dead_code)]
#[inline]
// 2-point DCT butterfly: (a, b) -> (a + b, a - b), in place.
fn dct1d_2(mem: &mut [f32]) {
    let (a, b) = (mem[0], mem[1]);
    mem[0] = a + b;
    mem[1] = a - b;
}
#[allow(dead_code)]
#[inline]
// 4-point DCT via even/odd split: the even coefficients are the 2-point
// DCT of the mirrored sums; the odd coefficients are the 2-point DCT of
// the mirrored differences scaled by WC4, finished with the sqrt(2)
// butterfly. Results are interleaved back into natural order.
fn dct1d_4(mem: &mut [f32]) {
let mut tmp = [0.0f32; 4];
add_reverse::<2>(&mem[0..2], &mem[2..4], &mut tmp[0..2]);
dct1d_2(&mut tmp[0..2]);
sub_reverse::<2>(&mem[0..2], &mem[2..4], &mut tmp[2..4]);
multiply_wc4(&mut tmp);
dct1d_2(&mut tmp[2..4]);
// Same role as `b_transform` for a 2-element odd half, written inline.
tmp[2] = tmp[2] * SQRT2 + tmp[3];
inverse_even_odd::<4>(&tmp, mem);
}
#[allow(dead_code)]
#[inline]
// 8-point DCT, same recursive scheme as `dct1d_4`: mirrored sums feed a
// 4-point DCT (even coefficients); mirrored differences are scaled by
// WC8, run through a 4-point DCT plus the B transform (odd
// coefficients); the halves are then interleaved into natural order.
fn dct1d_8(mem: &mut [f32]) {
let mut tmp = [0.0f32; 8];
add_reverse::<4>(&mem[0..4], &mem[4..8], &mut tmp[0..4]);
dct1d_4(&mut tmp[0..4]);
sub_reverse::<4>(&mem[0..4], &mem[4..8], &mut tmp[4..8]);
multiply_wc8(&mut tmp);
dct1d_4(&mut tmp[4..8]);
b_transform::<4>(&mut tmp[4..8]);
inverse_even_odd::<8>(&tmp, mem);
}
pub(crate) mod simd {
use super::*;
#[cfg(target_arch = "x86_64")]
use archmage::{SimdToken, arcane};
#[cfg(target_arch = "x86_64")]
use magetypes::simd::f32x8 as mf32x8;
// Lane-broadcast copies of the scalar twiddle factors (WC4/WC8/SQRT2),
// so the same recursion can run on eight independent rows at once.
#[allow(dead_code)]
const WC4_0: f32x8 = f32x8::new([0.541196100146197; 8]);
#[allow(dead_code)]
const WC4_1: f32x8 = f32x8::new([1.3065629648763764; 8]);
#[allow(dead_code)]
const WC8_0: f32x8 = f32x8::new([0.5097955791041592; 8]);
#[allow(dead_code)]
const WC8_1: f32x8 = f32x8::new([0.6013448869350453; 8]);
#[allow(dead_code)]
const WC8_2: f32x8 = f32x8::new([0.8999762231364156; 8]);
#[allow(dead_code)]
const WC8_3: f32x8 = f32x8::new([2.5629154477415055; 8]);
#[allow(dead_code)]
const SQRT2_VEC: f32x8 = f32x8::new([1.41421356237; 8]);
#[allow(dead_code)]
#[inline]
// Lane-wise 2-point butterfly: (a, b) -> (a + b, a - b) in each lane.
fn dct1d_2_simd(m0: &mut f32x8, m1: &mut f32x8) {
    let (a, b) = (*m0, *m1);
    *m0 = a + b;
    *m1 = a - b;
}
#[allow(dead_code)]
#[inline]
// Vectorized 4-point DCT (lane-parallel version of `dct1d_4`); each lane
// transforms an independent row.
fn dct1d_4_simd(m: &mut [f32x8; 4]) {
// Mirrored sums (even half) and differences (odd half).
let t0 = m[0] + m[3]; let t1 = m[1] + m[2];
let t2 = m[0] - m[3];
let t3 = m[1] - m[2];
// 2-point butterfly on the even half.
let r0 = t0 + t1;
let r1 = t0 - t1;
// Odd half: WC4 twiddles, butterfly, then the sqrt(2) fix-up.
let t2_scaled = t2 * WC4_0;
let t3_scaled = t3 * WC4_1;
let r2 = t2_scaled + t3_scaled;
let r3 = t2_scaled - t3_scaled;
let r2_final = r2 * SQRT2_VEC + r3;
// Interleave even/odd results into natural coefficient order.
m[0] = r0; m[1] = r2_final; m[2] = r1; m[3] = r3; }
#[allow(dead_code)]
#[inline]
// Vectorized 8-point DCT (lane-parallel version of `dct1d_8`); m[i]
// holds sample i of eight independent rows, one per lane.
fn dct1d_8_simd(m: &mut [f32x8; 8]) {
// Mirrored sums feed the even half, mirrored differences the odd half.
let t0 = m[0] + m[7];
let t1 = m[1] + m[6];
let t2 = m[2] + m[5];
let t3 = m[3] + m[4];
let t4 = m[0] - m[7];
let t5 = m[1] - m[6];
let t6 = m[2] - m[5];
let t7 = m[3] - m[4];
let mut first = [t0, t1, t2, t3];
dct1d_4_simd(&mut first);
// Odd half: WC8 twiddles then a 4-point DCT.
let t4_scaled = t4 * WC8_0;
let t5_scaled = t5 * WC8_1;
let t6_scaled = t6 * WC8_2;
let t7_scaled = t7 * WC8_3;
let mut second = [t4_scaled, t5_scaled, t6_scaled, t7_scaled];
dct1d_4_simd(&mut second);
// B transform: second[0] must read the pre-update second[1], so this
// statement stays ahead of the two accumulations — do not reorder.
second[0] = second[0] * SQRT2_VEC + second[1];
second[1] += second[2];
second[2] += second[3];
// Interleave even/odd halves into natural coefficient order.
m[0] = first[0];
m[1] = second[0];
m[2] = first[1];
m[3] = second[1];
m[4] = first[2];
m[5] = second[2];
m[6] = first[3];
m[7] = second[3];
}
#[autoversion]
#[allow(dead_code)]
// Runs one 1-D 8-point DCT over all eight rows of `input` at once: rows
// are loaded into registers, transposed so sample i of every row sits in
// vector mem[i], transformed lane-parallel, and transposed back. No 1/8
// normalization is applied here (single pass only).
pub fn dct_8rows_parallel(input: &[f32; 64], output: &mut [f32; 64]) {
let mut rows: [f32x8; 8] = [f32x8::ZERO; 8];
for row in 0..8 {
let k = row * 8;
let row_slice: [f32; 8] = input[k..k + 8].try_into().unwrap();
rows[row] = f32x8::from(row_slice);
}
let mut mem = transpose_f32x8_array(&rows);
dct1d_8_simd(&mut mem);
let result = transpose_f32x8_array(&mem);
for (i, r) in result.iter().enumerate() {
let arr = r.to_array();
output[i * 8..i * 8 + 8].copy_from_slice(&arr);
}
}
#[allow(dead_code)]
#[inline]
// Thin by-reference wrapper over `wide`'s 8x8 register transpose, so
// callers keep ownership of their copy.
fn transpose_f32x8_array(input: &[f32x8; 8]) -> [f32x8; 8] {
f32x8::transpose(*input)
}
// Transposes a row-major 8x8 block, preferring the `magetypes` AVX2 path
// when the CPU token is available and falling back to the portable
// `wide` implementation otherwise.
pub fn transpose_8x8_simd(input: &[f32; 64], output: &mut [f32; 64]) {
#[cfg(target_arch = "x86_64")]
if let Some(token) = archmage::X64V3Token::summon() {
return mage_transpose_8x8(token, input, output);
}
transpose_8x8_wide(input, output);
}
#[cfg(target_arch = "x86_64")]
#[arcane]
// AVX2 transpose via `magetypes` registers. NOTE(review): the token
// argument appears to serve as proof of CPU support for the
// `#[arcane]`-compiled body — confirm against archmage docs.
fn mage_transpose_8x8(_token: archmage::X64V3Token, input: &[f32; 64], output: &mut [f32; 64]) {
let mut rows = mf32x8::load_8x8(input);
mf32x8::transpose_8x8(&mut rows);
mf32x8::store_8x8(&rows, output);
}
#[allow(dead_code)]
#[inline]
// Portable 8x8 transpose using `wide` registers.
fn transpose_8x8_wide(input: &[f32; 64], output: &mut [f32; 64]) {
    // Load the eight rows into vector registers.
    let mut rows = [f32x8::ZERO; 8];
    for (reg, chunk) in rows.iter_mut().zip(input.chunks_exact(8)) {
        *reg = f32x8::from(<[f32; 8]>::try_from(chunk).unwrap());
    }
    // Transpose in-register, then spill the result back row by row.
    let transposed = f32x8::transpose(rows);
    for (chunk, reg) in output.chunks_exact_mut(8).zip(transposed.iter()) {
        chunk.copy_from_slice(&reg.to_array());
    }
}
#[inline(always)]
// By-value wrapper over `wide`'s register transpose, used by the chained
// DCT passes below.
fn transpose_vec(rows: [f32x8; 8]) -> [f32x8; 8] {
f32x8::transpose(rows)
}
// Normalization applied after each 1-D pass; the full 2-D transform thus
// divides by 8 twice, once per dimension.
const DCT_SCALE: f32x8 = f32x8::new([1.0 / 8.0; 8]);
#[inline(always)]
// Multiplies every row register by the 1/8 normalization factor.
fn scale_vec(v: [f32x8; 8]) -> [f32x8; 8] {
    v.map(|row| row * DCT_SCALE)
}
#[inline(always)]
// 8-point DCT over register arrays (pure-functional variant of
// `dct1d_8_simd`); each lane is an independent row/column.
fn dct_1d_vec(m: [f32x8; 8]) -> [f32x8; 8] {
// Mirrored sums (even half) and differences (odd half).
let t0 = m[0] + m[7];
let t1 = m[1] + m[6];
let t2 = m[2] + m[5];
let t3 = m[3] + m[4];
let t4 = m[0] - m[7];
let t5 = m[1] - m[6];
let t6 = m[2] - m[5];
let t7 = m[3] - m[4];
let first = dct_1d_4_vec([t0, t1, t2, t3]);
// Odd half: WC8 twiddles, 4-point DCT, then the B transform.
let t4_scaled = t4 * WC8_0;
let t5_scaled = t5 * WC8_1;
let t6_scaled = t6 * WC8_2;
let t7_scaled = t7 * WC8_3;
let mut second = dct_1d_4_vec([t4_scaled, t5_scaled, t6_scaled, t7_scaled]);
// second[0] must read the pre-update second[1] — do not reorder.
second[0] = second[0] * SQRT2_VEC + second[1];
second[1] += second[2];
second[2] += second[3];
// Interleave even/odd halves into natural coefficient order.
[
first[0], second[0], first[1], second[1], first[2], second[2], first[3], second[3],
]
}
#[inline(always)]
// 4-point DCT over registers (pure-functional variant of `dct1d_4_simd`).
// Returns [even0, odd0, even1, odd1].
fn dct_1d_4_vec(m: [f32x8; 4]) -> [f32x8; 4] {
let t0 = m[0] + m[3];
let t1 = m[1] + m[2];
let t2 = m[0] - m[3];
let t3 = m[1] - m[2];
// Even half: 2-point butterfly.
let r0 = t0 + t1;
let r1 = t0 - t1;
// Odd half: WC4 twiddles, butterfly, sqrt(2) fix-up.
let t2_scaled = t2 * WC4_0;
let t3_scaled = t3 * WC4_1;
let r2 = t2_scaled + t3_scaled;
let r3 = t2_scaled - t3_scaled;
let r2_final = r2 * SQRT2_VEC + r3;
[r0, r2_final, r1, r3]
}
// Full 2-D forward DCT of one 8x8 block: prefers the `magetypes` AVX2
// kernel when the CPU token is available, otherwise runs the portable
// `wide` fallback below. Both paths scale by 1/8 per 1-D pass.
pub fn forward_dct_8x8_simd_chained(input: &[f32; 64]) -> [f32; 64] {
#[cfg(target_arch = "x86_64")]
if let Some(token) = archmage::X64V3Token::summon() {
let mut output = [0.0f32; 64];
crate::encode::mage_simd::mage_forward_dct_8x8(token, input, &mut output);
return output;
}
forward_dct_8x8_simd_chained_fallback(input)
}
#[autoversion]
// Portable full 2-D DCT: transpose so columns sit in registers, run a
// scaled 1-D pass, transpose back, run the second scaled pass over rows.
fn forward_dct_8x8_simd_chained_fallback(input: &[f32; 64]) -> [f32; 64] {
    // Gather the eight input rows into vector registers.
    let mut rows = [f32x8::ZERO; 8];
    for (reg, chunk) in rows.iter_mut().zip(input.chunks_exact(8)) {
        *reg = f32x8::from(<[f32; 8]>::try_from(chunk).unwrap());
    }
    // Column pass followed by row pass, each with 1/8 normalization.
    let cols_after_row = scale_vec(dct_1d_vec(transpose_vec(rows)));
    let final_rows = scale_vec(dct_1d_vec(transpose_vec(cols_after_row)));
    // Scatter the result registers back into a flat 8x8 array.
    let mut output = [0.0f32; 64];
    for (chunk, reg) in output.chunks_exact_mut(8).zip(final_rows.iter()) {
        chunk.copy_from_slice(&reg.to_array());
    }
    output
}
#[inline]
// Block8x8f variant of the full 2-D DCT: AVX2 `magetypes` kernel when
// available, portable `wide` fallback otherwise.
pub fn forward_dct_8x8_wide(input: &Block8x8f) -> Block8x8f {
#[cfg(target_arch = "x86_64")]
if let Some(token) = archmage::X64V3Token::summon() {
return crate::encode::mage_simd::mage_forward_dct_8x8_wide(token, input);
}
forward_dct_8x8_wide_fallback(input)
}
#[autoversion]
#[inline]
// Portable Block8x8f DCT: two transpose + 1-D pass + 1/8 scale rounds
// (columns first, then rows).
fn forward_dct_8x8_wide_fallback(input: &Block8x8f) -> Block8x8f {
    let column_pass = scale_vec(dct_1d_vec(transpose_vec(input.rows)));
    let row_pass = scale_vec(dct_1d_vec(transpose_vec(column_pass)));
    Block8x8f { rows: row_pass }
}
#[cfg(target_arch = "x86_64")]
#[inline(always)]
// `magetypes` mirror of `dct_1d_4_vec`; the twiddle registers are passed
// in (splatted once per call by `dct1d_8_mage`) and the sqrt(2) step
// uses a fused multiply-add.
fn dct1d_4_mage(m: &mut [mf32x8; 4], wc4_0: mf32x8, wc4_1: mf32x8, sqrt2: mf32x8) {
let t0 = m[0] + m[3];
let t1 = m[1] + m[2];
let t2 = m[0] - m[3];
let t3 = m[1] - m[2];
// Even half: 2-point butterfly.
let r0 = t0 + t1;
let r1 = t0 - t1;
// Odd half: twiddles, butterfly, fused sqrt(2) fix-up.
let t2_scaled = t2 * wc4_0;
let t3_scaled = t3 * wc4_1;
let r2 = t2_scaled + t3_scaled;
let r3 = t2_scaled - t3_scaled;
let r2_final = r2.mul_add(sqrt2, r3);
m[0] = r0;
m[1] = r2_final;
m[2] = r1;
m[3] = r3;
}
#[cfg(target_arch = "x86_64")]
#[inline(always)]
// `magetypes` mirror of `dct1d_8_simd`: one 8-point DCT per lane. The
// twiddle constants are splatted here once and shared by both 4-point
// sub-transforms.
fn dct1d_8_mage(m: &mut [mf32x8; 8], token: archmage::X64V3Token) {
let wc4_0 = mf32x8::splat(token, 0.541196100146197);
let wc4_1 = mf32x8::splat(token, 1.3065629648763764);
let wc8_0 = mf32x8::splat(token, 0.5097955791041592);
let wc8_1 = mf32x8::splat(token, 0.6013448869350453);
let wc8_2 = mf32x8::splat(token, 0.8999762231364156);
let wc8_3 = mf32x8::splat(token, 2.5629154477415055);
let sqrt2 = mf32x8::splat(token, 1.41421356237);
// Mirrored sums (even half) and differences (odd half).
let t0 = m[0] + m[7];
let t1 = m[1] + m[6];
let t2 = m[2] + m[5];
let t3 = m[3] + m[4];
let t4 = m[0] - m[7];
let t5 = m[1] - m[6];
let t6 = m[2] - m[5];
let t7 = m[3] - m[4];
let mut first = [t0, t1, t2, t3];
dct1d_4_mage(&mut first, wc4_0, wc4_1, sqrt2);
// Odd half: WC8 twiddles then a 4-point DCT.
let t4_scaled = t4 * wc8_0;
let t5_scaled = t5 * wc8_1;
let t6_scaled = t6 * wc8_2;
let t7_scaled = t7 * wc8_3;
let mut second = [t4_scaled, t5_scaled, t6_scaled, t7_scaled];
dct1d_4_mage(&mut second, wc4_0, wc4_1, sqrt2);
// B transform: second[0] must read the pre-update second[1] — do not
// reorder these three statements.
second[0] = second[0].mul_add(sqrt2, second[1]);
second[1] += second[2];
second[2] += second[3];
// Interleave even/odd halves into natural coefficient order.
m[0] = first[0];
m[1] = second[0];
m[2] = first[1];
m[3] = second[1];
m[4] = first[2];
m[5] = second[2];
m[6] = first[3];
m[7] = second[3];
}
#[cfg(target_arch = "x86_64")]
#[inline(always)]
/// Full 2-D forward DCT of an 8x8 block via the `magetypes` AVX2 kernels:
/// transpose, 1-D pass + 1/8 scale over columns, transpose, 1-D pass +
/// 1/8 scale over rows.
///
/// # Panics
/// Panics when the `X64V3Token` (AVX2/FMA) cannot be summoned. Callers
/// must gate on `X64V3Token::summon().is_some()` first, as
/// `forward_dct_8x8` does.
pub(crate) fn forward_dct_8x8_mage(input: &[f32; 64], output: &mut [f32; 64]) {
    // `expect` instead of a bare `unwrap` so a misuse panics with the
    // actual precondition rather than a generic Option message.
    let token = archmage::X64V3Token::summon()
        .expect("forward_dct_8x8_mage requires AVX2/FMA; gate on X64V3Token::summon() first");
    let scale = mf32x8::splat(token, 1.0 / 8.0);
    let mut reg = mf32x8::load_8x8(input);
    mf32x8::transpose_8x8(&mut reg);
    dct1d_8_mage(&mut reg, token);
    for r in &mut reg {
        *r *= scale;
    }
    mf32x8::transpose_8x8(&mut reg);
    dct1d_8_mage(&mut reg, token);
    for r in &mut reg {
        *r *= scale;
    }
    // Fixes the mojibake `®` that had replaced `&reg` in the original
    // store call (HTML-entity corruption), which would not compile.
    mf32x8::store_8x8(&reg, output);
}
}
#[allow(dead_code)]
#[inline]
// Runs the scalar 8-point DCT independently over each of the eight rows
// (single 1-D pass, no normalization).
fn dct_rows(input: &[f32; 64], output: &mut [f32; 64]) {
    for (src, dst) in input.chunks_exact(8).zip(output.chunks_exact_mut(8)) {
        let mut tmp: [f32; 8] = src.try_into().unwrap();
        dct1d_8(&mut tmp);
        dst.copy_from_slice(&tmp);
    }
}
#[must_use]
#[inline]
// Public entry point for the 8x8 forward DCT. Uses the `magetypes` AVX2
// kernel when the CPU supports it, otherwise the portable path.
pub fn forward_dct_8x8(input: &[f32; DCT_BLOCK_SIZE]) -> [f32; DCT_BLOCK_SIZE] {
#[cfg(target_arch = "x86_64")]
{
// This check guarantees the token summon inside
// `forward_dct_8x8_mage` cannot fail.
if archmage::X64V3Token::summon().is_some() {
let mut output = [0.0f32; 64];
simd::forward_dct_8x8_mage(input, &mut output);
return output;
}
}
forward_dct_8x8_scalar(input)
}
#[inline]
// Historical name: despite the `_scalar` suffix this delegates to
// `forward_dct_8x8_simd_chained`, which itself prefers the AVX2 kernel
// when available — "scalar" here means "not the direct mage entry", not
// element-at-a-time math.
pub fn forward_dct_8x8_scalar(input: &[f32; DCT_BLOCK_SIZE]) -> [f32; DCT_BLOCK_SIZE] {
simd::forward_dct_8x8_simd_chained(input)
}
#[must_use]
// Level-shifts u8 samples from [0, 255] to [-128.0, 127.0] (JPEG-style),
// eight lanes at a time, then runs the standard forward DCT.
pub fn forward_dct_8x8_u8(input: &[u8; DCT_BLOCK_SIZE]) -> [f32; DCT_BLOCK_SIZE] {
    let mut shifted = [0.0f32; DCT_BLOCK_SIZE];
    let level_shift = f32x8::splat(128.0);
    for (dst, src) in shifted.chunks_exact_mut(8).zip(input.chunks_exact(8)) {
        let lanes = f32x8::from(core::array::from_fn::<f32, 8, _>(|j| src[j] as f32));
        let arr: [f32; 8] = (lanes - level_shift).into();
        dst.copy_from_slice(&arr);
    }
    forward_dct_8x8(&shifted)
}
pub fn forward_dct_blocks(blocks: &[[f32; DCT_BLOCK_SIZE]]) -> Vec<[f32; DCT_BLOCK_SIZE]> {
blocks.iter().map(forward_dct_8x8).collect()
}
#[inline(always)]
// One 1-D pass of the AAN (Arai-Agui-Nakajima) fast DCT, structured like
// libjpeg's floating-point FDCT, with the AAN scale factors divided out
// at the end (INV_SCALES) instead of being folded into quantization.
fn aan_dct_1d(d: &mut [f32; 8]) {
// Stage 1: mirrored sums and differences.
let tmp0 = d[0] + d[7];
let tmp7 = d[0] - d[7];
let tmp1 = d[1] + d[6];
let tmp6 = d[1] - d[6];
let tmp2 = d[2] + d[5];
let tmp5 = d[2] - d[5];
let tmp3 = d[3] + d[4];
let tmp4 = d[3] - d[4];
// Even part: coefficients 0, 2, 4, 6.
let tmp10 = tmp0 + tmp3;
let tmp13 = tmp0 - tmp3;
let tmp11 = tmp1 + tmp2;
let tmp12 = tmp1 - tmp2;
d[0] = tmp10 + tmp11; d[4] = tmp10 - tmp11;
let z1 = (tmp12 + tmp13) * aan::C4;
d[2] = tmp13 + z1;
d[6] = tmp13 - z1;
// Odd part: coefficients 1, 3, 5, 7. The shared z5 term turns the
// rotation into three multiplies instead of four.
let tmp10 = tmp4 + tmp5;
let tmp11 = tmp5 + tmp6;
let tmp12 = tmp6 + tmp7;
let z5 = (tmp10 - tmp12) * aan::C6;
let z2 = aan::C2_M_C6 * tmp10 + z5;
let z4 = aan::C2_P_C6 * tmp12 + z5;
let z3 = tmp11 * aan::C4;
let z11 = tmp7 + z3;
let z13 = tmp7 - z3;
d[5] = z13 + z2;
d[3] = z13 - z2;
d[1] = z11 + z4;
d[7] = z11 - z4;
// Undo the AAN scaling so this pass matches the other DCT paths here.
d[0] *= aan::INV_SCALES[0];
d[1] *= aan::INV_SCALES[1];
d[2] *= aan::INV_SCALES[2];
d[3] *= aan::INV_SCALES[3];
d[4] *= aan::INV_SCALES[4];
d[5] *= aan::INV_SCALES[5];
d[6] *= aan::INV_SCALES[6];
d[7] *= aan::INV_SCALES[7];
}
#[inline]
#[must_use]
// Full 2-D forward DCT via the AAN fast algorithm: a 1-D pass over every
// row, 1/8 normalization, a 1-D pass over every column, and a second 1/8
// normalization.
pub fn aan_forward_dct_8x8(input: &[f32; DCT_BLOCK_SIZE]) -> [f32; DCT_BLOCK_SIZE] {
    let mut data = *input;
    // Pass 1: transform each row in place.
    for row in data.chunks_exact_mut(8) {
        let mut tmp: [f32; 8] = row.try_into().unwrap();
        aan_dct_1d(&mut tmp);
        row.copy_from_slice(&tmp);
    }
    // Normalize after the row pass.
    for v in &mut data {
        *v *= 0.125;
    }
    // Pass 2: transform each column in place.
    for col in 0..8 {
        let mut tmp: [f32; 8] = core::array::from_fn(|r| data[col + r * 8]);
        aan_dct_1d(&mut tmp);
        for (r, &v) in tmp.iter().enumerate() {
            data[col + r * 8] = v;
        }
    }
    // Normalize after the column pass.
    for v in &mut data {
        *v *= 0.125;
    }
    data
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
// A constant block concentrates all energy in the DC coefficient; every
// AC coefficient must be ~0.
fn test_dct_dc_only() {
let input = [128.0f32; DCT_BLOCK_SIZE];
let output = forward_dct_8x8(&input);
assert!(output[0].abs() > 1.0);
for i in 1..DCT_BLOCK_SIZE {
assert!(
output[i].abs() < 0.001,
"AC[{}] = {} should be ~0",
i,
output[i]
);
}
}
#[test]
// An all-zero block must transform to all-zero coefficients.
fn test_dct_zero_block() {
let input = [0.0f32; DCT_BLOCK_SIZE];
let output = forward_dct_8x8(&input);
for (i, &v) in output.iter().enumerate() {
assert!(v.abs() < 0.001, "coeff[{}] = {} should be 0", i, v);
}
}
#[test]
// u8 input of 128 becomes 0 after the level shift, so the whole spectrum
// (including DC) must be ~0.
fn test_dct_u8_level_shift() {
let input = [128u8; DCT_BLOCK_SIZE];
let output = forward_dct_8x8_u8(&input);
for (i, &v) in output.iter().enumerate() {
assert!(v.abs() < 0.001, "coeff[{}] = {} should be 0", i, v);
}
}
#[test]
// Spot-checks the scalar reference transpose on a few known positions.
fn test_transpose() {
let mut input = [0.0f32; 64];
for i in 0..64 {
input[i] = i as f32;
}
let mut output = [0.0f32; 64];
transpose_8x8(&input, &mut output);
assert_eq!(output[0], 0.0); assert_eq!(output[1], 8.0); assert_eq!(output[8], 1.0); assert_eq!(output[9], 9.0); }
#[test]
// The dispatching SIMD transpose must match the scalar reference
// element-for-element on arbitrary data.
fn test_simd_transpose_matches_scalar() {
let mut input = [0.0f32; 64];
for i in 0..64 {
input[i] = (i as f32 * 1.7).sin() * 100.0;
}
let mut scalar_output = [0.0f32; 64];
let mut simd_output = [0.0f32; 64];
transpose_8x8(&input, &mut scalar_output);
simd::transpose_8x8_simd(&input, &mut simd_output);
for i in 0..64 {
assert!(
(scalar_output[i] - simd_output[i]).abs() < 1e-6,
"Mismatch at {}: scalar={} simd={}",
i,
scalar_output[i],
simd_output[i]
);
}
}
#[test]
// The lane-parallel row DCT must agree with per-row scalar `dct1d_8` on
// several representative patterns (flat, ramp, sine, checkerboard).
fn test_simd_dct_rows_matches_scalar() {
let patterns: Vec<[f32; 64]> = vec![
[64.0; 64],
{
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = i as f32;
}
arr
},
{
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = (i as f32 * 0.3).sin() * 100.0;
}
arr
},
{
let mut arr = [0.0f32; 64];
for row in 0..8 {
for col in 0..8 {
arr[row * 8 + col] = if (row + col) % 2 == 0 { 100.0 } else { -100.0 };
}
}
arr
},
];
for (pattern_idx, input) in patterns.iter().enumerate() {
// Scalar oracle: one 8-point DCT per row.
let mut scalar_output = [0.0f32; 64];
for row in 0..8 {
let mut tmp = [0.0f32; 8];
for i in 0..8 {
tmp[i] = input[row * 8 + i];
}
dct1d_8(&mut tmp);
for i in 0..8 {
scalar_output[row * 8 + i] = tmp[i];
}
}
let mut simd_output = [0.0f32; 64];
simd::dct_8rows_parallel(input, &mut simd_output);
let mut max_error = 0.0f32;
for i in 0..64 {
let error = (scalar_output[i] - simd_output[i]).abs();
max_error = max_error.max(error);
}
assert!(
max_error < 1e-4,
"Pattern {}: SIMD DCT rows max error {} exceeds threshold",
pattern_idx,
max_error
);
}
}
#[test]
// The AAN fast DCT must agree with the recursive/SIMD implementation on
// flat, ramp, sine, checkerboard, and pseudo-random patterns.
fn test_aan_dct_matches_recursive() {
let patterns: Vec<[f32; 64]> = vec![
[64.0; 64],
{
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = i as f32;
}
arr
},
{
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = (i as f32 * 0.3).sin() * 100.0;
}
arr
},
{
let mut arr = [0.0f32; 64];
for row in 0..8 {
for col in 0..8 {
arr[row * 8 + col] = if (row + col) % 2 == 0 { 100.0 } else { -100.0 };
}
}
arr
},
{
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = ((i * 17 + 31) % 256) as f32 - 128.0;
}
arr
},
];
for (pattern_idx, input) in patterns.iter().enumerate() {
let recursive_output = forward_dct_8x8(input);
let aan_output = aan_forward_dct_8x8(input);
let mut max_error = 0.0f32;
for i in 0..64 {
let error = (recursive_output[i] - aan_output[i]).abs();
max_error = max_error.max(error);
}
assert!(
max_error < 1e-4,
"Pattern {}: AAN vs recursive max error {} exceeds threshold.\n\
Recursive[0..8]: {:?}\n\
AAN[0..8]: {:?}",
pattern_idx,
max_error,
&recursive_output[0..8],
&aan_output[0..8]
);
}
}
#[cfg(target_arch = "x86_64")]
#[test]
// The AVX2 `magetypes` DCT must match the portable path; skipped when the
// host CPU lacks AVX2/FMA.
fn test_mage_dct_matches_scalar() {
let _lock = archmage::testing::lock_token_testing();
if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
return;
}
let patterns: Vec<[f32; 64]> = vec![
[64.0; 64],
{
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = i as f32;
}
arr
},
{
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = ((i as f32) * 0.5).sin() * 100.0 + 50.0;
}
arr
},
];
for (pattern_idx, input) in patterns.iter().enumerate() {
let scalar_output = forward_dct_8x8_scalar(input);
let mut mage_output = [0.0f32; 64];
simd::forward_dct_8x8_mage(input, &mut mage_output);
let mut max_error = 0.0f32;
for i in 0..64 {
let error = (scalar_output[i] - mage_output[i]).abs();
max_error = max_error.max(error);
}
assert!(
max_error < 1e-4,
"Pattern {}: magetypes vs scalar max error {} exceeds threshold.\n\
Scalar[0..8]: {:?}\n\
Mage[0..8]: {:?}",
pattern_idx,
max_error,
&scalar_output[0..8],
&mage_output[0..8]
);
}
}
#[cfg(target_arch = "x86_64")]
#[test]
// Manual micro-benchmark comparing the two DCT paths; `#[ignore]`d so it
// only runs on demand (cargo test -- --ignored).
#[ignore] fn bench_mage_dct_vs_scalar() {
use std::hint::black_box;
use std::time::Instant;
if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
println!("AVX2+FMA not available, skipping benchmark");
return;
}
// Classic 8x8 luma sample block (widely used JPEG example data).
let input: [f32; 64] = [
52., 55., 61., 66., 70., 61., 64., 73., 63., 59., 55., 90., 109., 85., 69., 72., 62.,
59., 68., 113., 144., 104., 66., 73., 63., 58., 71., 122., 154., 106., 70., 69., 67.,
61., 68., 104., 126., 88., 68., 70., 79., 65., 60., 70., 77., 68., 58., 75., 85., 71.,
64., 59., 55., 61., 65., 83., 87., 79., 69., 68., 65., 76., 78., 94.,
];
let iterations = 10_000_000;
let mut output = [0.0f32; 64];
// Warm-up rounds before each timed loop.
for _ in 0..10000 {
let _ = forward_dct_8x8_scalar(black_box(&input));
}
let start = Instant::now();
for _ in 0..iterations {
let result = forward_dct_8x8_scalar(black_box(&input));
black_box(&result);
}
let scalar_time = start.elapsed();
let scalar_ns = scalar_time.as_nanos() as f64 / iterations as f64;
for _ in 0..10000 {
simd::forward_dct_8x8_mage(black_box(&input), &mut output);
}
let start = Instant::now();
for _ in 0..iterations {
simd::forward_dct_8x8_mage(black_box(&input), black_box(&mut output));
}
let mage_time = start.elapsed();
let mage_ns = mage_time.as_nanos() as f64 / iterations as f64;
println!("DCT Performance ({} iterations):", iterations);
println!(" Scalar: {:.2} ns/DCT", scalar_ns);
println!(" Magetypes: {:.2} ns/DCT", mage_ns);
println!(
" Speedup: {:.2}x ({})",
scalar_ns / mage_ns,
if mage_ns < scalar_ns {
"magetypes faster"
} else {
"Scalar faster"
}
);
}
#[test]
// The Block8x8f entry point must produce the same coefficients as the
// flat-array entry point.
fn test_wide_dct_matches_array_dct() {
use crate::foundation::simd_types::Block8x8f;
let patterns: [[f32; 64]; 4] = [
[128.0; 64], {
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = i as f32;
}
arr
},
{
let mut arr = [0.0f32; 64];
for i in 0..64 {
arr[i] = (i as f32 * 0.3).sin() * 100.0;
}
arr
},
{
let mut arr = [0.0f32; 64];
for row in 0..8 {
for col in 0..8 {
arr[row * 8 + col] = if (row + col) % 2 == 0 { 100.0 } else { -100.0 };
}
}
arr
},
];
for (idx, input) in patterns.iter().enumerate() {
let array_output = forward_dct_8x8(input);
let block_input = Block8x8f::from_array(input);
let block_output = simd::forward_dct_8x8_wide(&block_input);
let wide_output = block_output.to_array();
let mut max_error = 0.0f32;
for i in 0..64 {
let error = (array_output[i] - wide_output[i]).abs();
max_error = max_error.max(error);
}
assert!(
max_error < 1e-4,
"Pattern {}: wide DCT differs from array DCT by {}",
idx,
max_error
);
}
}
#[cfg(target_arch = "x86_64")]
#[test]
// Runs the public DCT under every token-availability permutation that
// archmage's test harness can force, asserting all dispatch targets
// produce the same coefficients.
fn test_dct_dispatch_parity() {
use archmage::testing::{CompileTimePolicy, for_each_token_permutation};
let patterns: Vec<[f32; 64]> = vec![
[42.0; 64],
core::array::from_fn(|i| i as f32),
core::array::from_fn(|i| ((i % 8) as f32 - 3.5) * 20.0),
core::array::from_fn(|i| {
let r = i / 8;
let c = i % 8;
((r * 7 + c * 13) % 256) as f32 - 128.0
}),
];
for (idx, input) in patterns.iter().enumerate() {
let reference = forward_dct_8x8(input);
let report = for_each_token_permutation(CompileTimePolicy::Warn, |perm| {
let result = forward_dct_8x8(input);
for k in 0..64 {
let diff = (result[k] - reference[k]).abs();
assert!(
diff < 1e-4,
"DCT pattern {idx} coeff {k}: ref={} got={} diff={diff} at {perm}",
reference[k],
result[k]
);
}
});
if idx == 0 {
eprintln!("dct_dispatch: {report}");
assert!(
report.permutations_run >= 2,
"expected at least 2 permutations"
);
}
}
}
}