// Stub entry point on non-x86_64 targets; this example is AVX2-only.
#[cfg(not(target_arch = "x86_64"))]
fn main() {}

#[cfg(target_arch = "x86_64")]
mod x86_impl {
#![allow(clippy::excessive_precision)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::empty_line_after_outer_attr)]
use archmage::{SimdToken, X64V3Token, arcane};
use magetypes::simd::f32x8;
use std::time::Instant;
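/// Eight 8-point DCT-II transforms computed in parallel: vector `vi` holds
/// row `i` of the block, so lane `j` across `v0..v7` is column `j`, and each
/// lane independently receives a full 1-D DCT along the vector index.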
#[arcane]
fn dct1d_8(
token: X64V3Token,
v0: f32x8,
v1: f32x8,
v2: f32x8,
v3: f32x8,
v4: f32x8,
v5: f32x8,
v6: f32x8,
v7: f32x8,
) -> [f32x8; 8] {
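// Orthonormal DCT-II constants: every value is 0.5 * cos(m * PI / 16),
// with m = 4 for c0, m = 1,3,5,7 for c10..c13, and m = 2,6 for c20/c21.
// The nc* values are their negations, pre-splatted to keep the FMA chains flat.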
let c0 = f32x8::splat(token, 0.353_553_4);
let c10 = f32x8::splat(token, 0.490_392_6);
let c11 = f32x8::splat(token, 0.415_734_8);
let c12 = f32x8::splat(token, 0.277_785_1);
let c13 = f32x8::splat(token, 0.097_545_16);
let c20 = f32x8::splat(token, 0.461_939_8);
let c21 = f32x8::splat(token, 0.191_341_7);
let nc10 = f32x8::splat(token, -0.490_392_6);
let nc11 = f32x8::splat(token, -0.415_734_8);
let nc12 = f32x8::splat(token, -0.277_785_1);
let nc13 = f32x8::splat(token, -0.097_545_16);
let nc20 = f32x8::splat(token, -0.461_939_8);
let nc21 = f32x8::splat(token, -0.191_341_7);
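// Each out_k is the dot product of DCT basis row k with the eight inputs,
// evaluated lane-wise as a chain of fused multiply-adds.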
let out0 = (v0 + v1 + v2 + v3 + v4 + v5 + v6 + v7) * c0;
let out1 = v0.mul_add(
c10,
v1.mul_add(
c11,
v2.mul_add(
c12,
v3.mul_add(
c13,
v4.mul_add(nc13, v5.mul_add(nc12, v6.mul_add(nc11, v7 * nc10))),
),
),
),
);
let out2 = v0.mul_add(
c20,
v1.mul_add(
c21,
v2.mul_add(
nc21,
v3.mul_add(
nc20,
v4.mul_add(nc20, v5.mul_add(nc21, v6.mul_add(c21, v7 * c20))),
),
),
),
);
let out3 = v0.mul_add(
c11,
v1.mul_add(
nc13,
v2.mul_add(
nc10,
v3.mul_add(
nc12,
v4.mul_add(c12, v5.mul_add(c10, v6.mul_add(c13, v7 * nc11))),
),
),
),
);
let out4 = (v0 - v1 - v2 + v3 + v4 - v5 - v6 + v7) * c0;
let out5 = v0.mul_add(
c12,
v1.mul_add(
nc10,
v2.mul_add(
c13,
v3.mul_add(
c11,
v4.mul_add(nc11, v5.mul_add(nc13, v6.mul_add(c10, v7 * nc12))),
),
),
),
);
let out6 = v0.mul_add(
c21,
v1.mul_add(
nc20,
v2.mul_add(
c20,
v3.mul_add(
nc21,
v4.mul_add(nc21, v5.mul_add(c20, v6.mul_add(nc20, v7 * c21))),
),
),
),
);
let out7 = v0.mul_add(
c13,
v1.mul_add(
nc12,
v2.mul_add(
c11,
v3.mul_add(
nc10,
v4.mul_add(c10, v5.mul_add(nc11, v6.mul_add(c12, v7 * nc13))),
),
),
),
);
[out0, out1, out2, out3, out4, out5, out6, out7]
}
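/// In-register 8x8 transpose using the classic three-stage AVX pattern:
/// unpack interleaves within 128-bit lanes, shuffle regroups 64-bit pairs,
/// and permute2f128 exchanges the 128-bit halves.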
#[arcane]
fn transpose_8x8_vecs(token: X64V3Token, rows: &[f32x8; 8]) -> [f32x8; 8] {
use core::arch::x86_64::*;
let r0 = rows[0].raw();
let r1 = rows[1].raw();
let r2 = rows[2].raw();
let r3 = rows[3].raw();
let r4 = rows[4].raw();
let r5 = rows[5].raw();
let r6 = rows[6].raw();
let r7 = rows[7].raw();
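// Stage 1: interleave adjacent rows element-wise within each 128-bit lane.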
let t0 = _mm256_unpacklo_ps(r0, r1);
let t1 = _mm256_unpackhi_ps(r0, r1);
let t2 = _mm256_unpacklo_ps(r2, r3);
let t3 = _mm256_unpackhi_ps(r2, r3);
let t4 = _mm256_unpacklo_ps(r4, r5);
let t5 = _mm256_unpackhi_ps(r4, r5);
let t6 = _mm256_unpacklo_ps(r6, r7);
let t7 = _mm256_unpackhi_ps(r6, r7);
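// Stage 2: regroup 64-bit pairs (0x44 picks the low halves, 0xEE the high).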
let s0 = _mm256_shuffle_ps::<0x44>(t0, t2);
let s1 = _mm256_shuffle_ps::<0xEE>(t0, t2);
let s2 = _mm256_shuffle_ps::<0x44>(t1, t3);
let s3 = _mm256_shuffle_ps::<0xEE>(t1, t3);
let s4 = _mm256_shuffle_ps::<0x44>(t4, t6);
let s5 = _mm256_shuffle_ps::<0xEE>(t4, t6);
let s6 = _mm256_shuffle_ps::<0x44>(t5, t7);
let s7 = _mm256_shuffle_ps::<0xEE>(t5, t7);
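// Stage 3: exchange 128-bit halves (0x20 = both lows, 0x31 = both highs).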
let c0 = _mm256_permute2f128_ps::<0x20>(s0, s4);
let c1 = _mm256_permute2f128_ps::<0x20>(s1, s5);
let c2 = _mm256_permute2f128_ps::<0x20>(s2, s6);
let c3 = _mm256_permute2f128_ps::<0x20>(s3, s7);
let c4 = _mm256_permute2f128_ps::<0x31>(s0, s4);
let c5 = _mm256_permute2f128_ps::<0x31>(s1, s5);
let c6 = _mm256_permute2f128_ps::<0x31>(s2, s6);
let c7 = _mm256_permute2f128_ps::<0x31>(s3, s7);
[
f32x8::from_m256(token, c0),
f32x8::from_m256(token, c1),
f32x8::from_m256(token, c2),
f32x8::from_m256(token, c3),
f32x8::from_m256(token, c4),
f32x8::from_m256(token, c5),
f32x8::from_m256(token, c6),
f32x8::from_m256(token, c7),
]
}
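/// Loads a row-major 8x8 block into eight row vectors.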
#[arcane]
fn load_block(token: X64V3Token, block: &[f32; 64]) -> [f32x8; 8] {
[
f32x8::load(token, block[0..8].try_into().unwrap()),
f32x8::load(token, block[8..16].try_into().unwrap()),
f32x8::load(token, block[16..24].try_into().unwrap()),
f32x8::load(token, block[24..32].try_into().unwrap()),
f32x8::load(token, block[32..40].try_into().unwrap()),
f32x8::load(token, block[40..48].try_into().unwrap()),
f32x8::load(token, block[48..56].try_into().unwrap()),
f32x8::load(token, block[56..64].try_into().unwrap()),
]
}
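/// Writes eight row vectors back into a row-major 8x8 block.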
#[arcane]
fn store_block(token: X64V3Token, vecs: &[f32x8; 8], block: &mut [f32; 64]) {
vecs[0].store((&mut block[0..8]).try_into().unwrap());
vecs[1].store((&mut block[8..16]).try_into().unwrap());
vecs[2].store((&mut block[16..24]).try_into().unwrap());
vecs[3].store((&mut block[24..32]).try_into().unwrap());
vecs[4].store((&mut block[32..40]).try_into().unwrap());
vecs[5].store((&mut block[40..48]).try_into().unwrap());
vecs[6].store((&mut block[48..56]).try_into().unwrap());
vecs[7].store((&mut block[56..64]).try_into().unwrap());
}
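/// 2-D DCT-II of an 8x8 block: row-pass DCT, transpose, second DCT pass,
/// then a final transpose back to row-major coefficient order.
///
/// Illustrative call site (mirrors `main`; the token is runtime proof that
/// AVX2/FMA are available):
///
/// ```ignore
/// if let Some(token) = X64V3Token::summon() {
///     let mut block = [0.0f32; 64];
///     fast_dct8x8(token, &mut block);
/// }
/// ```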
#[arcane]
pub fn fast_dct8x8(token: X64V3Token, block: &mut [f32; 64]) {
let rows = load_block(token, block);
let dct_rows = dct1d_8(
token, rows[0], rows[1], rows[2], rows[3], rows[4], rows[5], rows[6], rows[7],
);
let cols = transpose_8x8_vecs(token, &dct_rows);
let dct_cols = dct1d_8(
token, cols[0], cols[1], cols[2], cols[3], cols[4], cols[5], cols[6], cols[7],
);
let result = transpose_8x8_vecs(token, &dct_cols);
store_block(token, &result, block);
}
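/// Applies `fast_dct8x8` to every block in the slice.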
#[arcane]
pub fn fast_dct8x8_batch(token: X64V3Token, blocks: &mut [[f32; 64]]) {
for block in blocks {
fast_dct8x8(token, block);
}
}
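/// Reference orthonormal DCT-II basis:
/// COEFF[k][n] = s_k * cos((2n + 1) * k * PI / 16),
/// with s_0 = 1 / (2 * sqrt(2)) and s_k = 0.5 for k > 0.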
const DCT_COEFF: [[f32; 8]; 8] = [
[
0.353553391,
0.353553391,
0.353553391,
0.353553391,
0.353553391,
0.353553391,
0.353553391,
0.353553391,
],
[
0.490392640,
0.415734806,
0.277785117,
0.097545161,
-0.097545161,
-0.277785117,
-0.415734806,
-0.490392640,
],
[
0.461939766,
0.191341716,
-0.191341716,
-0.461939766,
-0.461939766,
-0.191341716,
0.191341716,
0.461939766,
],
[
0.415734806,
-0.097545161,
-0.490392640,
-0.277785117,
0.277785117,
0.490392640,
0.097545161,
-0.415734806,
],
[
0.353553391,
-0.353553391,
-0.353553391,
0.353553391,
0.353553391,
-0.353553391,
-0.353553391,
0.353553391,
],
[
0.277785117,
-0.490392640,
0.097545161,
0.415734806,
-0.415734806,
-0.097545161,
0.490392640,
-0.277785117,
],
[
0.191341716,
-0.461939766,
0.461939766,
-0.191341716,
-0.191341716,
0.461939766,
-0.461939766,
0.191341716,
],
[
0.097545161,
-0.277785117,
0.415734806,
-0.490392640,
0.490392640,
-0.415734806,
0.277785117,
-0.097545161,
],
];
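/// Reference 1-D DCT-II: a plain matrix-vector product with `DCT_COEFF`.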
fn dct8_scalar(input: &[f32; 8], output: &mut [f32; 8]) {
for k in 0..8 {
let mut sum = 0.0f32;
for n in 0..8 {
sum += input[n] * DCT_COEFF[k][n];
}
output[k] = sum;
}
}
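/// Reference 2-D DCT-II: a row-pass DCT followed by a transpose, applied
/// twice, leaving coefficients in row-major order like the fast path.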
fn dct8x8_scalar(block: &mut [f32; 64]) {
let mut temp = [0.0f32; 64];
for row in 0..8 {
let input: [f32; 8] = block[row * 8..(row + 1) * 8].try_into().unwrap();
let mut output = [0.0f32; 8];
dct8_scalar(&input, &mut output);
temp[row * 8..(row + 1) * 8].copy_from_slice(&output);
}
for i in 0..8 {
for j in 0..8 {
block[i * 8 + j] = temp[j * 8 + i];
}
}
for row in 0..8 {
let input: [f32; 8] = block[row * 8..(row + 1) * 8].try_into().unwrap();
let mut output = [0.0f32; 8];
dct8_scalar(&input, &mut output);
temp[row * 8..(row + 1) * 8].copy_from_slice(&output);
}
for i in 0..8 {
for j in 0..8 {
block[i * 8 + j] = temp[j * 8 + i];
}
}
}
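/// Checks the AVX2 path against the scalar reference on a gradient block,
/// then sanity-checks a constant block, whose energy must collapse into DC.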
fn test_correctness() {
println!("=== Correctness Tests ===\n");
if let Some(token) = X64V3Token::summon() {
let mut scalar_block: [f32; 64] = core::array::from_fn(|i| i as f32);
let mut fast_block = scalar_block;
dct8x8_scalar(&mut scalar_block);
fast_dct8x8(token, &mut fast_block);
let max_error: f32 = scalar_block
.iter()
.zip(fast_block.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0, f32::max);
let avg_magnitude: f32 = scalar_block.iter().map(|x| x.abs()).sum::<f32>() / 64.0;
let relative_error = max_error / avg_magnitude;
println!(" Gradient block:");
println!(" Max absolute error: {:.2e}", max_error);
println!(" Relative error: {:.2e}", relative_error);
println!(" DC coefficient (scalar): {:.2}", scalar_block[0]);
println!(" DC coefficient (fast): {:.2}", fast_block[0]);
let mut const_block: [f32; 64] = [100.0; 64];
fast_dct8x8(token, &mut const_block);
println!("\n Constant block (100.0):");
println!(
" DC: {:.2} (expected ~800: DC = 8 x mean for the orthonormal DCT)",
const_block[0]
);
println!(
" AC max: {:.2e}",
const_block[1..]
.iter()
.map(|x| x.abs())
.fold(0.0f32, f32::max)
);
} else {
println!(" AVX2 not available, skipping tests");
}
println!();
}
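/// Times single-block and batched transforms for both implementations.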
fn benchmark() {
const ITERATIONS: u32 = 50_000;
const BATCH_SIZE: usize = 1024;
const BATCH_ITERS: u32 = 100;
println!("=== Benchmarks ===\n");
let original: [f32; 64] = core::array::from_fn(|i| (i as f32) - 32.0);
let mut block = original;
let start = Instant::now();
for _ in 0..ITERATIONS {
block = original;
dct8x8_scalar(&mut block);
std::hint::black_box(&block);
}
let scalar_time = start.elapsed();
let scalar_blocks_per_sec = ITERATIONS as f64 / scalar_time.as_secs_f64();
println!(
" Scalar DCT-8x8: {:>8.2} ms ({:.1}M blocks/sec)",
scalar_time.as_secs_f64() * 1000.0,
scalar_blocks_per_sec / 1_000_000.0
);
if let Some(token) = X64V3Token::summon() {
let start = Instant::now();
for _ in 0..ITERATIONS {
block = original;
fast_dct8x8(token, &mut block);
std::hint::black_box(&block);
}
let fast_time = start.elapsed();
let fast_blocks_per_sec = ITERATIONS as f64 / fast_time.as_secs_f64();
println!(
" Fast AVX2 DCT-8x8: {:>8.2} ms ({:.1}M blocks/sec, {:.1}x faster)",
fast_time.as_secs_f64() * 1000.0,
fast_blocks_per_sec / 1_000_000.0,
scalar_time.as_secs_f64() / fast_time.as_secs_f64()
);
println!(
"\n Batch processing ({} blocks x {} iterations):\n",
BATCH_SIZE, BATCH_ITERS
);
let original_batch: Vec<[f32; 64]> = (0..BATCH_SIZE)
.map(|b| core::array::from_fn(|i| ((b * 64 + i) % 256) as f32 - 128.0))
.collect();
let mut batch = original_batch.clone();
let start = Instant::now();
for _ in 0..BATCH_ITERS {
batch.copy_from_slice(&original_batch);
for block in &mut batch {
dct8x8_scalar(block);
}
std::hint::black_box(&batch);
}
let scalar_batch_time = start.elapsed();
let scalar_batch_rate =
(BATCH_SIZE as f64 * BATCH_ITERS as f64) / scalar_batch_time.as_secs_f64();
println!(
" Scalar: {:>8.2} ms ({:.1}M blocks/sec)",
scalar_batch_time.as_secs_f64() * 1000.0,
scalar_batch_rate / 1_000_000.0
);
let start = Instant::now();
for _ in 0..BATCH_ITERS {
batch.copy_from_slice(&original_batch);
fast_dct8x8_batch(token, &mut batch);
std::hint::black_box(&batch);
}
let fast_batch_time = start.elapsed();
let fast_batch_rate =
(BATCH_SIZE as f64 * BATCH_ITERS as f64) / fast_batch_time.as_secs_f64();
println!(
" Fast AVX2:{:>8.2} ms ({:.1}M blocks/sec, {:.1}x faster)",
fast_batch_time.as_secs_f64() * 1000.0,
fast_batch_rate / 1_000_000.0,
scalar_batch_time.as_secs_f64() / fast_batch_time.as_secs_f64()
);
} else {
println!(" AVX2 not available");
}
println!();
}
pub fn main() {
println!("\n╔══════════════════════════════════════════════════════════════╗");
println!("║ Fast DCT-8x8 using archmage SIMD vectors ║");
println!("╚══════════════════════════════════════════════════════════════╝\n");
println!("This implementation processes 8 rows in parallel using AVX2,");
println!("with FMA (fused multiply-add) for maximum throughput.\n");
test_correctness();
benchmark();
println!("=== Algorithm Summary ===\n");
println!(" Vectorized matrix multiplication:");
println!(" 1. Load 8 rows as 8 f32x8 vectors (one row per vector)");
println!(" 2. Compute DCT using FMA chains for each output coefficient");
println!(" 3. Transpose 8x8 using AVX2 shuffle/permute intrinsics");
println!(" 4. Repeat DCT for the column transform, then transpose back");
println!();
println!(" Key optimizations:");
println!(" - Each f32x8 holds one full row; lane j across vectors is column j");
println!(" - FMA chains: 7 dependent mul_add ops per output vector");
println!(" - In-register transpose: no memory round-trip");
println!(" - Typical speedup: ~6-7x over scalar (batch: 7-10x), machine-dependent");
println!();
}
}
#[cfg(target_arch = "x86_64")]
fn main() {
x86_impl::main()
}