/// Square root of two as `f32`; used when rebuilding the first odd-indexed
/// coefficient after each half-size transform.
#[allow(clippy::excessive_precision)]
const SQRT2: f32 = core::f32::consts::SQRT_2;
/// Weights applied to the two "difference" terms of the 4-point transform.
///
/// Entry `i` equals `1 / (2 * cos((i + 0.5) * PI / 4))`. The literals carry
/// more digits than `f32` can hold, hence the `excessive_precision` allow.
#[allow(clippy::excessive_precision)]
const WC_MULTIPLIERS_4: [f32; 2] = [0.541196100146197, 1.3065629648763764];
/// Weights applied to the four "difference" terms of the 8-point transform.
///
/// Entry `i` equals `1 / (2 * cos((i + 0.5) * PI / 8))`.
#[allow(clippy::excessive_precision)]
const WC_MULTIPLIERS_8: [f32; 4] = [
0.5097955791041592,
0.6013448869350453,
0.8999762231364156,
2.5629154477415055,
];
/// Weights applied to the eight "difference" terms of the 16-point transform.
///
/// Entry `i` equals `1 / (2 * cos((i + 0.5) * PI / 16))`.
#[allow(clippy::excessive_precision)]
const WC_MULTIPLIERS_16: [f32; 8] = [
0.5024192861881557,
0.5224986149396889,
0.5669440348163577,
0.6468217833599901,
0.7881546234512502,
1.060677685990347,
1.7224470982383342,
5.101148618689155,
];
/// Weights applied to the sixteen "difference" terms of the 32-point
/// transform (shared by the scalar and the batched SIMD implementation).
///
/// Entry `i` equals `1 / (2 * cos((i + 0.5) * PI / 32))`.
#[allow(clippy::excessive_precision)]
const WC_MULTIPLIERS_32: [f32; 16] = [
0.5006029982351963,
0.5054709598975436,
0.5154473099226246,
0.5310425910897841,
0.5531038960344445,
0.5829349682061339,
0.6225041230356648,
0.6748083414550057,
0.7445362710022986,
0.8393496454155268,
0.9725682378619608,
1.1694399334328847,
1.4841646163141662,
2.057781009953411,
3.407608418468719,
10.190008123548033,
];
/// In-place 2-point butterfly.
///
/// After the call `mem[0]` holds the sum of the first two elements and
/// `mem[1]` holds their difference (first minus second). Elements past
/// index 1 are left untouched.
///
/// # Panics
/// Panics if `mem` has fewer than two elements.
#[inline]
fn dct1d_2_scalar(mem: &mut [f32]) {
    let (first, second) = (mem[0], mem[1]);
    mem[0] = first + second;
    mem[1] = first - second;
}
/// In-place 4-point 1-D DCT step over `mem[0..4]`.
///
/// The mirrored sums go through the 2-point butterfly and land in the even
/// output slots; the mirrored differences are weighted by
/// `WC_MULTIPLIERS_4`, run through the same butterfly, post-processed and
/// land in the odd slots. No `1/N` scale is applied here: `mem[0]` ends up
/// as the plain sum of the four inputs.
///
/// # Panics
/// Panics if `mem` has fewer than four elements.
fn dct1d_4_scalar(mem: &mut [f32]) {
    let mut work = crate::scratch_buf::<4>();

    // Fold the input around its centre: sums in the lower half of `work`,
    // differences in the upper half.
    for lo in 0..2 {
        let (front, back) = (mem[lo], mem[3 - lo]);
        work[lo] = front + back;
        work[2 + lo] = front - back;
    }

    dct1d_2_scalar(&mut work[0..2]);

    for (k, &weight) in WC_MULTIPLIERS_4.iter().enumerate() {
        work[2 + k] *= weight;
    }
    dct1d_2_scalar(&mut work[2..4]);
    work[2] = SQRT2 * work[2] + work[3];

    // Interleave: sums half -> even indices, differences half -> odd indices.
    for (k, pair) in mem[..4].chunks_exact_mut(2).enumerate() {
        pair[0] = work[k];
        pair[1] = work[2 + k];
    }
}
/// In-place 8-point 1-D DCT step over `mem[0..8]`, built from two 4-point
/// transforms.
///
/// The mirrored sums are transformed by `dct1d_4_scalar` and written to the
/// even output slots. The mirrored differences are weighted by
/// `WC_MULTIPLIERS_8`, transformed, post-processed (first entry becomes
/// `SQRT2 * d[0] + d[1]`, the middle entries each absorb their right-hand
/// neighbour) and written to the odd slots. No `1/N` scale is applied here.
///
/// # Panics
/// Panics if `mem` has fewer than eight elements.
fn dct1d_8_scalar(mem: &mut [f32]) {
    const HALF: usize = 4;
    const LAST: usize = 7;

    let mut work = crate::scratch_buf::<8>();
    for lo in 0..HALF {
        let (front, back) = (mem[lo], mem[LAST - lo]);
        work[lo] = front + back;
        work[HALF + lo] = front - back;
    }

    dct1d_4_scalar(&mut work[0..HALF]);

    for (k, &weight) in WC_MULTIPLIERS_8.iter().enumerate() {
        work[HALF + k] *= weight;
    }
    dct1d_4_scalar(&mut work[HALF..8]);

    work[HALF] = SQRT2 * work[HALF] + work[HALF + 1];
    // The final entry of the differences half is intentionally left as is.
    for k in (HALF + 1)..LAST {
        work[k] += work[k + 1];
    }

    for (k, pair) in mem[..8].chunks_exact_mut(2).enumerate() {
        pair[0] = work[k];
        pair[1] = work[HALF + k];
    }
}
/// In-place 16-point 1-D DCT step over `mem[0..16]`, built from two 8-point
/// transforms.
///
/// Same scheme as the smaller sizes: mirrored sums feed `dct1d_8_scalar`
/// and become the even outputs; mirrored differences are weighted by
/// `WC_MULTIPLIERS_16`, transformed, post-processed and become the odd
/// outputs. No `1/N` scale is applied here.
///
/// # Panics
/// Panics if `mem` has fewer than sixteen elements.
fn dct1d_16_scalar(mem: &mut [f32]) {
    const HALF: usize = 8;
    const LAST: usize = 15;

    let mut work = crate::scratch_buf::<16>();
    for lo in 0..HALF {
        let (front, back) = (mem[lo], mem[LAST - lo]);
        work[lo] = front + back;
        work[HALF + lo] = front - back;
    }

    dct1d_8_scalar(&mut work[0..HALF]);

    for (k, &weight) in WC_MULTIPLIERS_16.iter().enumerate() {
        work[HALF + k] *= weight;
    }
    dct1d_8_scalar(&mut work[HALF..16]);

    work[HALF] = SQRT2 * work[HALF] + work[HALF + 1];
    // The final entry of the differences half is intentionally left as is.
    for k in (HALF + 1)..LAST {
        work[k] += work[k + 1];
    }

    for (k, pair) in mem[..16].chunks_exact_mut(2).enumerate() {
        pair[0] = work[k];
        pair[1] = work[HALF + k];
    }
}
/// In-place 32-point 1-D DCT step over `mem[0..32]`, built from two 16-point
/// transforms.
///
/// Mirrored sums feed `dct1d_16_scalar` and become the even outputs;
/// mirrored differences are weighted by `WC_MULTIPLIERS_32`, transformed,
/// post-processed and become the odd outputs. No `1/N` scale is applied
/// here; the 2-D drivers in this file multiply by `1/32` afterwards.
///
/// # Panics
/// Panics if `mem` has fewer than thirty-two elements.
fn dct1d_32_scalar(mem: &mut [f32]) {
    const HALF: usize = 16;
    const LAST: usize = 31;

    let mut work = crate::scratch_buf::<32>();
    for lo in 0..HALF {
        let (front, back) = (mem[lo], mem[LAST - lo]);
        work[lo] = front + back;
        work[HALF + lo] = front - back;
    }

    dct1d_16_scalar(&mut work[0..HALF]);

    for (k, &weight) in WC_MULTIPLIERS_32.iter().enumerate() {
        work[HALF + k] *= weight;
    }
    dct1d_16_scalar(&mut work[HALF..32]);

    work[HALF] = SQRT2 * work[HALF] + work[HALF + 1];
    // The final entry of the differences half is intentionally left as is.
    for k in (HALF + 1)..LAST {
        work[k] += work[k + 1];
    }

    for (k, pair) in mem[..32].chunks_exact_mut(2).enumerate() {
        pair[0] = work[k];
        pair[1] = work[HALF + k];
    }
}
/// Scalar 2-D DCT of a 32x32 block.
///
/// `input` is read as 32 rows of 32 samples. Every row is transformed with
/// `dct1d_32_scalar` and scaled by `1/32`, the intermediate block is
/// transposed, and every row of the transposed block is transformed and
/// scaled the same way. `output` receives that second-pass buffer directly,
/// i.e. it is laid out transposed relative to `input`.
#[inline]
pub fn dct_32x32_scalar(input: &[f32; 1024], output: &mut [f32; 1024]) {
    const SIDE: usize = 32;
    const SCALE: f32 = 1.0 / 32.0;

    // Pass 1: transform every input row.
    let mut row_pass = crate::scratch_buf::<1024>();
    for (line, src) in row_pass
        .chunks_exact_mut(SIDE)
        .zip(input.chunks_exact(SIDE))
    {
        line.copy_from_slice(src);
        dct1d_32_scalar(line);
        for coeff in line.iter_mut() {
            *coeff *= SCALE;
        }
    }

    // Transpose so the second pass can again work on contiguous rows.
    let mut col_pass = crate::scratch_buf::<1024>();
    for (r, line) in row_pass.chunks_exact(SIDE).enumerate() {
        for (c, &value) in line.iter().enumerate() {
            col_pass[c * SIDE + r] = value;
        }
    }

    // Pass 2: transform every row of the transposed block.
    for line in col_pass.chunks_exact_mut(SIDE) {
        dct1d_32_scalar(line);
        for coeff in line.iter_mut() {
            *coeff *= SCALE;
        }
    }

    output.copy_from_slice(&col_pass);
}
/// Scalar 2-D DCT of a block stored as 32 rows of 16 samples.
///
/// Each 16-sample row is transformed with `dct1d_16_scalar` and scaled by
/// `1/16`. The intermediate block is then transposed into 16 rows of 32
/// samples, each of which is transformed with `dct1d_32_scalar` and scaled
/// by `1/32`. `output` receives that second-pass buffer, so it is laid out
/// as 16 rows of 32 coefficients.
#[inline]
pub fn dct_32x16_scalar(input: &[f32; 512], output: &mut [f32; 512]) {
    const SHORT: usize = 16;
    const LONG: usize = 32;
    const SHORT_SCALE: f32 = 1.0 / 16.0;
    const LONG_SCALE: f32 = 1.0 / 32.0;

    // Pass 1: 32 rows, 16 samples each.
    let mut row_pass = crate::scratch_buf::<512>();
    for (line, src) in row_pass
        .chunks_exact_mut(SHORT)
        .zip(input.chunks_exact(SHORT))
    {
        line.copy_from_slice(src);
        dct1d_16_scalar(line);
        for coeff in line.iter_mut() {
            *coeff *= SHORT_SCALE;
        }
    }

    // Transpose 32x16 -> 16x32.
    let mut col_pass = crate::scratch_buf::<512>();
    for (r, line) in row_pass.chunks_exact(SHORT).enumerate() {
        for (c, &value) in line.iter().enumerate() {
            col_pass[c * LONG + r] = value;
        }
    }

    // Pass 2: 16 rows, 32 samples each.
    for line in col_pass.chunks_exact_mut(LONG) {
        dct1d_32_scalar(line);
        for coeff in line.iter_mut() {
            *coeff *= LONG_SCALE;
        }
    }

    output.copy_from_slice(&col_pass);
}
/// Scalar 2-D DCT of a block stored as 16 rows of 32 samples.
///
/// Each 32-sample row is transformed with `dct1d_32_scalar` and scaled by
/// `1/32`. The intermediate block is transposed into 32 rows of 16 samples,
/// each of which is transformed with `dct1d_16_scalar` and scaled by
/// `1/16`. Unlike the other two drivers, the result is transposed once more
/// while being written, so `output` is laid out as 16 rows of 32
/// coefficients.
#[inline]
pub fn dct_16x32_scalar(input: &[f32; 512], output: &mut [f32; 512]) {
    const SHORT: usize = 16;
    const LONG: usize = 32;
    const SHORT_SCALE: f32 = 1.0 / 16.0;
    const LONG_SCALE: f32 = 1.0 / 32.0;

    // Pass 1: 16 rows, 32 samples each.
    let mut row_pass = crate::scratch_buf::<512>();
    for (line, src) in row_pass
        .chunks_exact_mut(LONG)
        .zip(input.chunks_exact(LONG))
    {
        line.copy_from_slice(src);
        dct1d_32_scalar(line);
        for coeff in line.iter_mut() {
            *coeff *= LONG_SCALE;
        }
    }

    // Transpose 16x32 -> 32x16.
    let mut col_pass = crate::scratch_buf::<512>();
    for (r, line) in row_pass.chunks_exact(LONG).enumerate() {
        for (c, &value) in line.iter().enumerate() {
            col_pass[c * SHORT + r] = value;
        }
    }

    // Pass 2: 32 rows, 16 samples each.
    for line in col_pass.chunks_exact_mut(SHORT) {
        dct1d_16_scalar(line);
        for coeff in line.iter_mut() {
            *coeff *= SHORT_SCALE;
        }
    }

    // Transpose back (32x16 -> 16x32) straight into the caller's buffer.
    for (r, line) in col_pass.chunks_exact(SHORT).enumerate() {
        for (c, &value) in line.iter().enumerate() {
            output[c * LONG + r] = value;
        }
    }
}
/// Local shorthand for `crate::gather_col_strided`; all five arguments are
/// forwarded unchanged and its result is returned as is.
///
/// NOTE(review): presumably this loads element `j` from eight consecutive
/// rows starting at `base_row`, with rows `stride` floats apart — confirm
/// against `gather_col_strided`, which is not visible here.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn gather_col(
token: archmage::X64V3Token,
data: &[f32],
base_row: usize,
j: usize,
stride: usize,
) -> magetypes::simd::f32x8 {
crate::gather_col_strided(token, data, base_row, j, stride)
}
/// Local shorthand for `crate::scatter_col_strided`; all five arguments are
/// forwarded unchanged.
///
/// NOTE(review): presumably the inverse of `gather_col`, storing the eight
/// lanes of `v` to element `j` of eight consecutive rows starting at
/// `base_row` (rows `stride` floats apart) — confirm against
/// `scatter_col_strided`, which is not visible here.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn scatter_col(
v: magetypes::simd::f32x8,
data: &mut [f32],
base_row: usize,
j: usize,
stride: usize,
) {
crate::scatter_col_strided(v, data, base_row, j, stride)
}
/// Batched 32-point 1-D DCT step: the same arithmetic as `dct1d_32_scalar`,
/// applied lane-wise to 32 `f32x8` vectors in place.
///
/// Mirrored sums are transformed by `crate::dct16::dct1d_16_batch` and end
/// up at the even indices of `v`; mirrored differences are weighted by
/// `WC_MULTIPLIERS_32`, transformed, post-processed and end up at the odd
/// indices. No `1/32` scale is applied here.
#[cfg(target_arch = "x86_64")]
#[archmage::arcane]
#[inline(always)]
pub(crate) fn dct1d_32_batch(token: archmage::X64V3Token, v: &mut [magetypes::simd::f32x8; 32]) {
    use magetypes::simd::f32x8;

    let mut sums = [f32x8::zero(token); 16];
    let mut diffs = [f32x8::zero(token); 16];
    for k in 0..16 {
        let (front, back) = (v[k], v[31 - k]);
        sums[k] = front + back;
        diffs[k] = front - back;
    }

    crate::dct16::dct1d_16_batch(token, &mut sums);

    for (lane, &weight) in diffs.iter_mut().zip(WC_MULTIPLIERS_32.iter()) {
        *lane *= f32x8::splat(token, weight);
    }
    crate::dct16::dct1d_16_batch(token, &mut diffs);

    let root2 = f32x8::splat(token, SQRT2);
    diffs[0] = root2 * diffs[0] + diffs[1];
    // The final entry of the differences half is intentionally left as is.
    for k in 1..15 {
        diffs[k] += diffs[k + 1];
    }

    // Interleave: sums -> even indices, differences -> odd indices.
    for (k, pair) in v.chunks_exact_mut(2).enumerate() {
        pair[0] = sums[k];
        pair[1] = diffs[k];
    }
}
/// SIMD counterpart of `dct_32x32_scalar`.
///
/// Follows the same pipeline — 32-point pass scaled by `1/32`, transpose,
/// second 32-point pass scaled by `1/32` — but handles the data in four
/// batches (`base` = 0, 8, 16, 24) using `gather_col`, `dct1d_32_batch` and
/// `scatter_col`. The second pass scatters straight into `output`.
///
/// NOTE(review): each batch presumably covers eight consecutive rows
/// starting at `base`; that depends on the gather/scatter helpers, which
/// are defined elsewhere.
#[cfg(target_arch = "x86_64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_32x32_avx2(token: archmage::X64V3Token, input: &[f32; 1024], output: &mut [f32; 1024]) {
    use magetypes::simd::f32x8;

    const SIDE: usize = 32;
    let inv32 = f32x8::splat(token, 1.0 / 32.0);

    // Pass 1: input -> row_pass.
    let mut row_pass = crate::scratch_buf::<1024>();
    for base in (0..SIDE).step_by(8) {
        let mut lanes = [f32x8::zero(token); 32];
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane = gather_col(token, input, base, col, SIDE);
        }
        dct1d_32_batch(token, &mut lanes);
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane *= inv32;
            scatter_col(*lane, &mut row_pass, base, col, SIDE);
        }
    }

    // Transpose so the second pass sees the other axis as rows.
    let mut col_pass = crate::scratch_buf::<1024>();
    for (r, line) in row_pass.chunks_exact(SIDE).enumerate() {
        for (c, &value) in line.iter().enumerate() {
            col_pass[c * SIDE + r] = value;
        }
    }

    // Pass 2: col_pass -> output.
    for base in (0..SIDE).step_by(8) {
        let mut lanes = [f32x8::zero(token); 32];
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane = gather_col(token, &col_pass, base, col, SIDE);
        }
        dct1d_32_batch(token, &mut lanes);
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane *= inv32;
            scatter_col(*lane, output, base, col, SIDE);
        }
    }
}
/// SIMD counterpart of `dct_32x16_scalar`.
///
/// Pass 1 runs the batched 16-point transform (scaled by `1/16`) in four
/// batches over the 32x16 input; the intermediate block is transposed to
/// 16x32; pass 2 runs the batched 32-point transform (scaled by `1/32`) in
/// two batches and scatters straight into `output`.
///
/// NOTE(review): each batch presumably covers eight consecutive rows
/// starting at `base`; that depends on the gather/scatter helpers, which
/// are defined elsewhere.
#[cfg(target_arch = "x86_64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_32x16_avx2(token: archmage::X64V3Token, input: &[f32; 512], output: &mut [f32; 512]) {
    use magetypes::simd::f32x8;

    const SHORT: usize = 16;
    const LONG: usize = 32;
    let inv16 = f32x8::splat(token, 1.0 / 16.0);
    let inv32 = f32x8::splat(token, 1.0 / 32.0);

    // Pass 1: 32 rows of 16 samples, input -> row_pass.
    let mut row_pass = crate::scratch_buf::<512>();
    for base in (0..LONG).step_by(8) {
        let mut lanes = [f32x8::zero(token); 16];
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane = gather_col(token, input, base, col, SHORT);
        }
        crate::dct16::dct1d_16_batch(token, &mut lanes);
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane *= inv16;
            scatter_col(*lane, &mut row_pass, base, col, SHORT);
        }
    }

    // Transpose 32x16 -> 16x32.
    let mut col_pass = crate::scratch_buf::<512>();
    for (r, line) in row_pass.chunks_exact(SHORT).enumerate() {
        for (c, &value) in line.iter().enumerate() {
            col_pass[c * LONG + r] = value;
        }
    }

    // Pass 2: 16 rows of 32 samples, col_pass -> output.
    for base in (0..SHORT).step_by(8) {
        let mut lanes = [f32x8::zero(token); 32];
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane = gather_col(token, &col_pass, base, col, LONG);
        }
        dct1d_32_batch(token, &mut lanes);
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane *= inv32;
            scatter_col(*lane, output, base, col, LONG);
        }
    }
}
/// SIMD counterpart of `dct_16x32_scalar`.
///
/// Pass 1 runs the batched 32-point transform (scaled by `1/32`) in two
/// batches over the 16x32 input; the intermediate block is transposed to
/// 32x16; pass 2 runs the batched 16-point transform (scaled by `1/16`) in
/// four batches, writing its results back into the same buffer it gathered
/// from. A final transpose writes the 16x32 result into `output`.
///
/// NOTE(review): each batch presumably covers eight consecutive rows
/// starting at `base`; that depends on the gather/scatter helpers, which
/// are defined elsewhere.
#[cfg(target_arch = "x86_64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_16x32_avx2(token: archmage::X64V3Token, input: &[f32; 512], output: &mut [f32; 512]) {
    use magetypes::simd::f32x8;

    const SHORT: usize = 16;
    const LONG: usize = 32;
    let inv16 = f32x8::splat(token, 1.0 / 16.0);
    let inv32 = f32x8::splat(token, 1.0 / 32.0);

    // Pass 1: 16 rows of 32 samples, input -> row_pass.
    let mut row_pass = crate::scratch_buf::<512>();
    for base in (0..SHORT).step_by(8) {
        let mut lanes = [f32x8::zero(token); 32];
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane = gather_col(token, input, base, col, LONG);
        }
        dct1d_32_batch(token, &mut lanes);
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane *= inv32;
            scatter_col(*lane, &mut row_pass, base, col, LONG);
        }
    }

    // Transpose 16x32 -> 32x16.
    let mut col_pass = crate::scratch_buf::<512>();
    for (r, line) in row_pass.chunks_exact(LONG).enumerate() {
        for (c, &value) in line.iter().enumerate() {
            col_pass[c * SHORT + r] = value;
        }
    }

    // Pass 2: 32 rows of 16 samples, transformed in place inside col_pass.
    // All gathers of a batch finish before its first scatter.
    for base in (0..LONG).step_by(8) {
        let mut lanes = [f32x8::zero(token); 16];
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane = gather_col(token, &col_pass, base, col, SHORT);
        }
        crate::dct16::dct1d_16_batch(token, &mut lanes);
        for (col, lane) in lanes.iter_mut().enumerate() {
            *lane *= inv16;
            scatter_col(*lane, &mut col_pass, base, col, SHORT);
        }
    }

    // Transpose back (32x16 -> 16x32) straight into the caller's buffer.
    for (r, line) in col_pass.chunks_exact(SHORT).enumerate() {
        for (c, &value) in line.iter().enumerate() {
            output[c * LONG + r] = value;
        }
    }
}
/// 2-D DCT of a 32x32 block with runtime dispatch.
///
/// On `x86_64`, if `archmage::X64V3Token::summon()` yields a token the work
/// is done by `dct_32x32_avx2`; in every other case it falls back to
/// `dct_32x32_scalar`.
#[inline]
pub fn dct_32x32(input: &[f32; 1024], output: &mut [f32; 1024]) {
    #[cfg(target_arch = "x86_64")]
    {
        use archmage::SimdToken;
        if let Some(simd) = archmage::X64V3Token::summon() {
            return dct_32x32_avx2(simd, input, output);
        }
    }
    dct_32x32_scalar(input, output)
}
/// 2-D DCT of a 32x16 block (512 samples) with runtime dispatch.
///
/// On `x86_64`, if `archmage::X64V3Token::summon()` yields a token the work
/// is done by `dct_32x16_avx2`; in every other case it falls back to
/// `dct_32x16_scalar`.
#[inline]
pub fn dct_32x16(input: &[f32; 512], output: &mut [f32; 512]) {
    #[cfg(target_arch = "x86_64")]
    {
        use archmage::SimdToken;
        if let Some(simd) = archmage::X64V3Token::summon() {
            return dct_32x16_avx2(simd, input, output);
        }
    }
    dct_32x16_scalar(input, output)
}
/// 2-D DCT of a 16x32 block (512 samples) with runtime dispatch.
///
/// On `x86_64`, if `archmage::X64V3Token::summon()` yields a token the work
/// is done by `dct_16x32_avx2`; in every other case it falls back to
/// `dct_16x32_scalar`.
#[inline]
pub fn dct_16x32(input: &[f32; 512], output: &mut [f32; 512]) {
    #[cfg(target_arch = "x86_64")]
    {
        use archmage::SimdToken;
        if let Some(simd) = archmage::X64V3Token::summon() {
            return dct_16x32_avx2(simd, input, output);
        }
    }
    dct_16x32_scalar(input, output)
}
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;

    /// Runs `scalar_fn` once as the reference, then runs `dispatch_fn` under
    /// every token permutation offered by the archmage test harness and
    /// asserts that the largest absolute element difference stays below
    /// `1e-2`. The harness report is printed to stderr.
    fn assert_simd_matches_scalar_1024(
        scalar_fn: fn(&[f32; 1024], &mut [f32; 1024]),
        dispatch_fn: fn(&[f32; 1024], &mut [f32; 1024]),
        input: &[f32; 1024],
        label: &str,
    ) {
        let mut scalar_out = [0.0f32; 1024];
        scalar_fn(input, &mut scalar_out);

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = [0.0f32; 1024];
                dispatch_fn(input, &mut simd_out);

                // First index holding the strictly largest deviation.
                let (max_idx, max_diff) = scalar_out
                    .iter()
                    .zip(simd_out.iter())
                    .map(|(want, got)| (want - got).abs())
                    .enumerate()
                    .fold((0usize, 0.0f32), |worst, (idx, diff)| {
                        if diff > worst.1 {
                            (idx, diff)
                        } else {
                            worst
                        }
                    });

                assert!(
                    max_diff < 1e-2,
                    "{label} max diff = {max_diff} at {max_idx} (scalar={}, simd={}) [{perm}]",
                    scalar_out[max_idx],
                    simd_out[max_idx],
                );
            },
        );
        std::eprintln!("{label}: {report}");
    }

    /// 512-element twin of `assert_simd_matches_scalar_1024`, used for the
    /// two rectangular block shapes.
    fn assert_simd_matches_scalar_512(
        scalar_fn: fn(&[f32; 512], &mut [f32; 512]),
        dispatch_fn: fn(&[f32; 512], &mut [f32; 512]),
        input: &[f32; 512],
        label: &str,
    ) {
        let mut scalar_out = [0.0f32; 512];
        scalar_fn(input, &mut scalar_out);

        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = [0.0f32; 512];
                dispatch_fn(input, &mut simd_out);

                // First index holding the strictly largest deviation.
                let (max_idx, max_diff) = scalar_out
                    .iter()
                    .zip(simd_out.iter())
                    .map(|(want, got)| (want - got).abs())
                    .enumerate()
                    .fold((0usize, 0.0f32), |worst, (idx, diff)| {
                        if diff > worst.1 {
                            (idx, diff)
                        } else {
                            worst
                        }
                    });

                assert!(
                    max_diff < 1e-2,
                    "{label} max diff = {max_diff} at {max_idx} (scalar={}, simd={}) [{perm}]",
                    scalar_out[max_idx],
                    simd_out[max_idx],
                );
            },
        );
        std::eprintln!("{label}: {report}");
    }

    /// Cosine-shaped input, 32x32.
    #[test]
    fn test_dct_32x32_simd_matches_scalar() {
        let input: [f32; 1024] =
            core::array::from_fn(|i| ((i as f32) * 0.31 + 1.7).cos() * 100.0);
        assert_simd_matches_scalar_1024(dct_32x32_scalar, dct_32x32, &input, "DCT32x32 cos");
    }

    /// Single non-zero sample at index 0, 32x32.
    #[test]
    fn test_dct_32x32_dc_only() {
        let mut input = [0.0f32; 1024];
        input[0] = 128.0;
        assert_simd_matches_scalar_1024(dct_32x32_scalar, dct_32x32, &input, "DCT32x32 DC");
    }

    /// Ramp input (each sample equals its index), 32x32.
    #[test]
    fn test_dct_32x32_sequential() {
        let input: [f32; 1024] = core::array::from_fn(|i| i as f32);
        assert_simd_matches_scalar_1024(dct_32x32_scalar, dct_32x32, &input, "DCT32x32 seq");
    }

    /// Cosine-shaped input, 32x16.
    #[test]
    fn test_dct_32x16_simd_matches_scalar() {
        let input: [f32; 512] = core::array::from_fn(|i| ((i as f32) * 0.43 + 2.1).cos() * 80.0);
        assert_simd_matches_scalar_512(dct_32x16_scalar, dct_32x16, &input, "DCT32x16");
    }

    /// Sine-shaped input, 16x32.
    #[test]
    fn test_dct_16x32_simd_matches_scalar() {
        let input: [f32; 512] = core::array::from_fn(|i| ((i as f32) * 0.29 + 0.7).sin() * 120.0);
        assert_simd_matches_scalar_512(dct_16x32_scalar, dct_16x32, &input, "DCT16x32");
    }

    /// After the `1/32` scale, the 32-point transform is expected to keep
    /// `1/32` of the input energy.
    #[test]
    fn test_dct1d_32_energy() {
        // Sum of squares, accumulated in f64.
        let energy = |samples: &[f32]| -> f64 {
            samples
                .iter()
                .map(|&x| {
                    let wide = x as f64;
                    wide * wide
                })
                .sum()
        };

        let input: [f32; 32] = core::array::from_fn(|i| ((i as f32) * 0.7 + 0.3).sin() * 50.0);
        let input_energy = energy(&input);

        let mut output = input;
        dct1d_32_scalar(&mut output);
        for coeff in output.iter_mut() {
            *coeff *= 1.0 / 32.0;
        }
        let output_energy = energy(&output);

        let ratio = output_energy / input_energy;
        assert!(
            (ratio - 1.0 / 32.0).abs() < 0.001,
            "Energy ratio {ratio:.6} far from 1/32 = 0.03125"
        );
    }
}