/// Odd-half pre-multipliers of the 4-point DCT: `1 / (2 * cos((2 * i + 1) * pi / 8))`.
const WC_MULTIPLIERS_4: [f32; 2] = [0.541_196_1, 1.306_563];

/// Odd-half pre-multipliers of the 8-point DCT: `1 / (2 * cos((2 * i + 1) * pi / 16))`.
const WC_MULTIPLIERS_8: [f32; 4] = [0.509_795_6, 0.601_344_9, 0.899_976_2, 2.562_915_5];

/// Odd-half pre-multipliers of the 16-point DCT: `1 / (2 * cos((2 * i + 1) * pi / 32))`.
const WC_MULTIPLIERS_16: [f32; 8] = [
    0.502_419_3,
    0.522_498_6,
    0.566_944_06,
    0.646_821_8,
    0.788_154_65,
    1.060_677_7,
    1.722_447_1,
    5.101_148_6,
];

/// `sqrt(2)` as `f32`; weights the first output of every odd-half recombination.
const SQRT2: f32 = core::f32::consts::SQRT_2;
/// Forward 16x16 DCT with runtime kernel selection.
///
/// On the matching target architecture this tries, in order, the x86-64-v3
/// kernel (`dct_16x16_avx2`), the NEON kernel (`dct_16x16_neon`) and the wasm
/// SIMD128 kernel (`dct_16x16_wasm128`). Each is gated on its archmage token
/// being summonable. If none applies it runs [`dct_16x16_scalar`], whose doc
/// describes the input/output layout; the SIMD kernels are tested to match it.
#[inline]
pub fn dct_16x16(input: &[f32; 256], output: &mut [f32; 256]) {
    #[cfg(target_arch = "x86_64")]
    {
        use archmage::SimdToken;
        if let Some(v3) = archmage::X64V3Token::summon() {
            return dct_16x16_avx2(v3, input, output);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        use archmage::SimdToken;
        if let Some(neon) = archmage::NeonToken::summon() {
            return dct_16x16_neon(neon, input, output);
        }
    }
    #[cfg(target_arch = "wasm32")]
    {
        use archmage::SimdToken;
        if let Some(simd128) = archmage::Wasm128Token::summon() {
            return dct_16x16_wasm128(simd128, input, output);
        }
    }
    dct_16x16_scalar(input, output)
}
/// Portable 16x16 forward DCT; the tests use it as the reference for the SIMD kernels.
///
/// `input` is read as 16 row-major rows of 16 samples. Each row is run through
/// `dct1d_16_scalar` and multiplied by 1/16. The intermediate block is then
/// transposed, and each resulting row (one input column) is transformed and
/// multiplied by 1/16 again. No final transpose is applied:
/// `output[u * 16 + v]` holds frequency `u` of the transform along the input
/// rows and frequency `v` of the transform along the input columns.
#[inline]
pub fn dct_16x16_scalar(input: &[f32; 256], output: &mut [f32; 256]) {
    const N: usize = 16;
    const INV_N: f32 = 1.0 / 16.0;

    // Pass 1: transform each input row inside the scratch buffer.
    let mut row_pass = crate::scratch_buf::<256>();
    for r in 0..N {
        let line = &mut row_pass[r * N..(r + 1) * N];
        line.copy_from_slice(&input[r * N..(r + 1) * N]);
        dct1d_16_scalar(line);
        for x in line.iter_mut() {
            *x *= INV_N;
        }
    }

    // Transpose so that every input column becomes a contiguous row.
    let mut col_pass = crate::scratch_buf::<256>();
    for c in 0..N {
        for r in 0..N {
            col_pass[c * N + r] = row_pass[r * N + c];
        }
    }

    // Pass 2: transform each former column.
    for c in 0..N {
        let line = &mut col_pass[c * N..(c + 1) * N];
        dct1d_16_scalar(line);
        for x in line.iter_mut() {
            *x *= INV_N;
        }
    }

    output.copy_from_slice(&col_pass);
}
/// Two-point DCT butterfly on the first two elements of `mem`:
/// `[a, b] -> [a + b, a - b]`. Any further elements are left untouched.
/// Panics if `mem` has fewer than two elements.
#[inline]
fn dct1d_2_scalar(mem: &mut [f32]) {
    let sum = mem[0] + mem[1];
    let diff = mem[0] - mem[1];
    mem[..2].copy_from_slice(&[sum, diff]);
}
/// In-place 4-point scaled DCT-II of `mem[..4]`.
///
/// Output `k` equals `sum_i x_i * cos(pi * (2i + 1) * k / 8)`, multiplied by
/// `sqrt(2)` when `k > 0`. The even outputs come from a 2-point DCT of the
/// mirrored sums. The odd outputs come from a 2-point DCT of the
/// `WC_MULTIPLIERS_4`-weighted mirrored differences, followed by the
/// `sqrt(2) * o0 + o1` recombination.
fn dct1d_4_scalar(mem: &mut [f32]) {
    let mut even = [mem[0] + mem[3], mem[1] + mem[2]];
    let mut odd = [
        (mem[0] - mem[3]) * WC_MULTIPLIERS_4[0],
        (mem[1] - mem[2]) * WC_MULTIPLIERS_4[1],
    ];

    dct1d_2_scalar(&mut even);
    dct1d_2_scalar(&mut odd);
    odd[0] = SQRT2 * odd[0] + odd[1];

    // Interleave: even half -> even indices, odd half -> odd indices.
    mem[0] = even[0];
    mem[1] = odd[0];
    mem[2] = even[1];
    mem[3] = odd[1];
}
/// In-place 8-point scaled DCT-II of `mem[..8]`.
///
/// Output `k` equals `sum_i x_i * cos(pi * (2i + 1) * k / 16)`, multiplied by
/// `sqrt(2)` when `k > 0`. The input is split into mirrored sums and
/// `WC_MULTIPLIERS_8`-weighted mirrored differences. Each half goes through
/// `dct1d_4_scalar`, the odd half is recombined, and the two halves are
/// interleaved.
fn dct1d_8_scalar(mem: &mut [f32]) {
    let mut even = [0.0f32; 4];
    let mut odd = [0.0f32; 4];
    for (i, (e, o)) in even.iter_mut().zip(odd.iter_mut()).enumerate() {
        let (lo, hi) = (mem[i], mem[7 - i]);
        *e = lo + hi;
        *o = (lo - hi) * WC_MULTIPLIERS_8[i];
    }

    dct1d_4_scalar(&mut even);
    dct1d_4_scalar(&mut odd);

    // Odd-half recombination: o0 = sqrt(2) * o0 + o1, then o[i] += o[i + 1].
    odd[0] = SQRT2 * odd[0] + odd[1];
    for i in 1..3 {
        odd[i] += odd[i + 1];
    }

    for (i, (&e, &o)) in even.iter().zip(odd.iter()).enumerate() {
        mem[2 * i] = e;
        mem[2 * i + 1] = o;
    }
}
/// In-place 16-point scaled DCT-II of `mem[..16]`.
///
/// Output `k` equals `sum_i x_i * cos(pi * (2i + 1) * k / 32)`, multiplied by
/// `sqrt(2)` when `k > 0`. It uses the same even/odd split as
/// `dct1d_8_scalar`: mirrored sums and `WC_MULTIPLIERS_16`-weighted mirrored
/// differences each go through `dct1d_8_scalar`, the odd half is recombined,
/// and the halves are interleaved.
fn dct1d_16_scalar(mem: &mut [f32]) {
    let mut even = [0.0f32; 8];
    let mut odd = [0.0f32; 8];
    for (i, (e, o)) in even.iter_mut().zip(odd.iter_mut()).enumerate() {
        let (lo, hi) = (mem[i], mem[15 - i]);
        *e = lo + hi;
        *o = (lo - hi) * WC_MULTIPLIERS_16[i];
    }

    dct1d_8_scalar(&mut even);
    dct1d_8_scalar(&mut odd);

    // Odd-half recombination: o0 = sqrt(2) * o0 + o1, then o[i] += o[i + 1].
    odd[0] = SQRT2 * odd[0] + odd[1];
    for i in 1..7 {
        odd[i] += odd[i + 1];
    }

    for (i, (&e, &o)) in even.iter().zip(odd.iter()).enumerate() {
        mem[2 * i] = e;
        mem[2 * i + 1] = o;
    }
}
/// Loads element `j` of rows `base_row..base_row + 8` from a row-major buffer
/// with 16 floats per row. Lane `r` of the result comes from row `base_row + r`.
#[cfg(target_arch = "x86_64")]
#[archmage::arcane]
#[inline(always)]
fn gather_col(
    token: archmage::X64V3Token,
    data: &[f32],
    base_row: usize,
    j: usize,
) -> magetypes::simd::f32x8 {
    // Successive lanes are exactly one row (16 floats) apart.
    let first = base_row * 16 + j;
    magetypes::simd::f32x8::from_array(
        token,
        [
            data[first],
            data[first + 16],
            data[first + 32],
            data[first + 48],
            data[first + 64],
            data[first + 80],
            data[first + 96],
            data[first + 112],
        ],
    )
}
/// Inverse of `gather_col`: writes lane `r` of `v` to element `j` of row
/// `base_row + r` in a row-major buffer with 16 floats per row.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn scatter_col(v: magetypes::simd::f32x8, data: &mut [f32], base_row: usize, j: usize) {
    let mut lanes = [0.0f32; 8];
    v.store(&mut lanes);
    for (r, &x) in lanes.iter().enumerate() {
        data[(base_row + r) * 16 + j] = x;
    }
}
/// Four-point scaled DCT applied independently in each of the eight lanes.
///
/// `v[i]` holds sample `i` of eight separate signals; on return `v[k]` holds
/// coefficient `k`. The steps match `dct1d_4_scalar`, with the
/// `sqrt(2) * o0 + o1` recombination computed through `mul_add`.
#[cfg(target_arch = "x86_64")]
#[archmage::arcane]
#[inline(always)]
fn dct1d_4_batch(token: archmage::X64V3Token, v: &mut [magetypes::simd::f32x8; 4]) {
    use magetypes::simd::f32x8;
    let [x0, x1, x2, x3] = *v;

    // Even half: 2-point DCT of the mirrored sums.
    let sum03 = x0 + x3;
    let sum12 = x1 + x2;
    let dc = sum03 + sum12;
    let even1 = sum03 - sum12;

    // Odd half: weighted mirrored differences, 2-point DCT, recombination.
    let d03 = (x0 - x3) * f32x8::splat(token, WC_MULTIPLIERS_4[0]);
    let d12 = (x1 - x2) * f32x8::splat(token, WC_MULTIPLIERS_4[1]);
    let odd_sum = d03 + d12;
    let odd_diff = d03 - d12;
    let odd0 = f32x8::splat(token, SQRT2).mul_add(odd_sum, odd_diff);

    *v = [dc, odd0, even1, odd_diff];
}
/// Eight-point scaled DCT applied independently in each of the eight lanes.
///
/// `v[i]` holds sample `i` of eight separate signals; on return `v[k]` holds
/// coefficient `k`. This mirrors `dct1d_8_scalar`, using `dct1d_4_batch` for
/// both halves.
#[cfg(target_arch = "x86_64")]
#[archmage::arcane]
#[inline(always)]
fn dct1d_8_batch(token: archmage::X64V3Token, v: &mut [magetypes::simd::f32x8; 8]) {
    use magetypes::simd::f32x8;
    let x = *v;
    let mut even = [x[0] + x[7], x[1] + x[6], x[2] + x[5], x[3] + x[4]];
    let mut odd = [
        (x[0] - x[7]) * f32x8::splat(token, WC_MULTIPLIERS_8[0]),
        (x[1] - x[6]) * f32x8::splat(token, WC_MULTIPLIERS_8[1]),
        (x[2] - x[5]) * f32x8::splat(token, WC_MULTIPLIERS_8[2]),
        (x[3] - x[4]) * f32x8::splat(token, WC_MULTIPLIERS_8[3]),
    ];

    dct1d_4_batch(token, &mut even);
    dct1d_4_batch(token, &mut odd);

    odd[0] = f32x8::splat(token, SQRT2).mul_add(odd[0], odd[1]);
    odd[1] += odd[2];
    odd[2] += odd[3];

    for (k, (&e, &o)) in even.iter().zip(odd.iter()).enumerate() {
        v[2 * k] = e;
        v[2 * k + 1] = o;
    }
}
/// Sixteen-point scaled DCT applied independently in each of the eight lanes.
///
/// `v[i]` holds sample `i` of eight separate 16-point signals (one per lane);
/// on return `v[k]` holds coefficient `k` of each. The arithmetic follows
/// `dct1d_16_scalar` step for step:
/// - the mirrored sums feed an 8-point DCT for the even outputs;
/// - the `WC_MULTIPLIERS_16`-weighted mirrored differences feed a second
///   8-point DCT for the odd outputs;
/// - the odd half is recombined (`sqrt(2) * o0 + o1` via `mul_add`, then
///   `o[i] += o[i + 1]`).
// The former `#[allow(clippy::too_many_arguments)]` was stale (this function
// takes two arguments) and has been dropped.
#[cfg(target_arch = "x86_64")]
#[archmage::arcane]
#[inline(always)]
pub(crate) fn dct1d_16_batch(token: archmage::X64V3Token, v: &mut [magetypes::simd::f32x8; 16]) {
    use magetypes::simd::f32x8;
    let x = *v;
    let mut even = [
        x[0] + x[15],
        x[1] + x[14],
        x[2] + x[13],
        x[3] + x[12],
        x[4] + x[11],
        x[5] + x[10],
        x[6] + x[9],
        x[7] + x[8],
    ];
    let mut odd = [
        (x[0] - x[15]) * f32x8::splat(token, WC_MULTIPLIERS_16[0]),
        (x[1] - x[14]) * f32x8::splat(token, WC_MULTIPLIERS_16[1]),
        (x[2] - x[13]) * f32x8::splat(token, WC_MULTIPLIERS_16[2]),
        (x[3] - x[12]) * f32x8::splat(token, WC_MULTIPLIERS_16[3]),
        (x[4] - x[11]) * f32x8::splat(token, WC_MULTIPLIERS_16[4]),
        (x[5] - x[10]) * f32x8::splat(token, WC_MULTIPLIERS_16[5]),
        (x[6] - x[9]) * f32x8::splat(token, WC_MULTIPLIERS_16[6]),
        (x[7] - x[8]) * f32x8::splat(token, WC_MULTIPLIERS_16[7]),
    ];

    dct1d_8_batch(token, &mut even);
    dct1d_8_batch(token, &mut odd);

    odd[0] = f32x8::splat(token, SQRT2).mul_add(odd[0], odd[1]);
    // Ascending order: each step reads o[i + 1] before it is itself updated.
    for i in 1..7 {
        let next = odd[i + 1];
        odd[i] += next;
    }

    for (k, (&e, &o)) in even.iter().zip(odd.iter()).enumerate() {
        v[2 * k] = e;
        v[2 * k + 1] = o;
    }
}
/// x86-64-v3 kernel behind [`dct_16x16`]; same layout and scaling as
/// [`dct_16x16_scalar`].
///
/// Each pass handles eight rows per batch, holding one `f32x8` per column
/// (lane `r` is row `base_row + r`). Two batches cover the 16 rows. Between
/// the passes the block is transposed.
#[cfg(target_arch = "x86_64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_16x16_avx2(token: archmage::X64V3Token, input: &[f32; 256], output: &mut [f32; 256]) {
    use magetypes::simd::f32x8;
    let scale = f32x8::splat(token, 1.0 / 16.0);

    // Pass 1: row transforms into `tmp`.
    let mut tmp = crate::scratch_buf::<256>();
    for base_row in [0usize, 8] {
        let mut cols = [f32x8::zero(token); 16];
        for (j, col) in cols.iter_mut().enumerate() {
            *col = gather_col(token, input, base_row, j);
        }
        dct1d_16_batch(token, &mut cols);
        for (j, col) in cols.iter().enumerate() {
            scatter_col(*col * scale, &mut tmp, base_row, j);
        }
    }

    // Transpose so each input column becomes a row.
    let mut transposed = crate::scratch_buf::<256>();
    for c in 0..16 {
        for r in 0..16 {
            transposed[c * 16 + r] = tmp[r * 16 + c];
        }
    }

    // Pass 2: transform the former columns straight into `output`.
    for base_row in [0usize, 8] {
        let mut cols = [f32x8::zero(token); 16];
        for (j, col) in cols.iter_mut().enumerate() {
            *col = gather_col(token, &transposed, base_row, j);
        }
        dct1d_16_batch(token, &mut cols);
        for (j, col) in cols.iter().enumerate() {
            scatter_col(*col * scale, output, base_row, j);
        }
    }
}
/// Forward DCT of a 16-row by 8-column block with runtime kernel selection.
///
/// On the matching target architecture this tries the x86-64-v3, NEON and wasm
/// SIMD128 kernels, each gated on its archmage token being summonable. If none
/// applies it runs [`dct_16x8_scalar`], whose doc describes the layout.
#[inline]
pub fn dct_16x8(input: &[f32; 128], output: &mut [f32; 128]) {
    #[cfg(target_arch = "x86_64")]
    {
        use archmage::SimdToken;
        if let Some(v3) = archmage::X64V3Token::summon() {
            return dct_16x8_avx2(v3, input, output);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        use archmage::SimdToken;
        if let Some(neon) = archmage::NeonToken::summon() {
            return dct_16x8_neon(neon, input, output);
        }
    }
    #[cfg(target_arch = "wasm32")]
    {
        use archmage::SimdToken;
        if let Some(simd128) = archmage::Wasm128Token::summon() {
            return dct_16x8_wasm128(simd128, input, output);
        }
    }
    dct_16x8_scalar(input, output)
}
/// Portable forward DCT of a 16-row by 8-column block.
///
/// `input` is read as 16 row-major rows of 8 samples:
/// 1. Each row gets an 8-point DCT scaled by 1/8.
/// 2. The block is transposed.
/// 3. Each input column gets a 16-point DCT scaled by 1/16.
///
/// No final transpose is applied, so `output[u * 16 + v]` holds frequency `u`
/// (0..8) along the input rows and frequency `v` (0..16) along the input
/// columns.
#[inline]
pub fn dct_16x8_scalar(input: &[f32; 128], output: &mut [f32; 128]) {
    // Pass 1: 8-point transform of every input row.
    let mut tmp = crate::scratch_buf::<128>();
    for r in 0..16 {
        let line = &mut tmp[r * 8..(r + 1) * 8];
        line.copy_from_slice(&input[r * 8..(r + 1) * 8]);
        dct1d_8_scalar(line);
        for x in line.iter_mut() {
            *x *= 1.0 / 8.0;
        }
    }

    // Transpose 16x8 -> 8x16 so each input column is contiguous.
    let mut transposed = crate::scratch_buf::<128>();
    for col in 0..8 {
        for row in 0..16 {
            transposed[col * 16 + row] = tmp[row * 8 + col];
        }
    }

    // Pass 2: 16-point transform of every former column.
    for col in 0..8 {
        let line = &mut transposed[col * 16..(col + 1) * 16];
        dct1d_16_scalar(line);
        for x in line.iter_mut() {
            *x *= 1.0 / 16.0;
        }
    }

    output.copy_from_slice(&transposed);
}
/// Loads element `j` of rows `base_row..base_row + 8` from a row-major buffer
/// with 8 floats per row. Lane `r` of the result comes from row `base_row + r`.
#[cfg(target_arch = "x86_64")]
#[archmage::arcane]
#[inline(always)]
fn gather_col_s8(
    token: archmage::X64V3Token,
    data: &[f32],
    base_row: usize,
    j: usize,
) -> magetypes::simd::f32x8 {
    // Successive lanes are exactly one row (8 floats) apart.
    let first = base_row * 8 + j;
    magetypes::simd::f32x8::from_array(
        token,
        [
            data[first],
            data[first + 8],
            data[first + 16],
            data[first + 24],
            data[first + 32],
            data[first + 40],
            data[first + 48],
            data[first + 56],
        ],
    )
}
/// Inverse of `gather_col_s8`: writes lane `r` of `v` to element `j` of row
/// `base_row + r` in a row-major buffer with 8 floats per row.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn scatter_col_s8(v: magetypes::simd::f32x8, data: &mut [f32], base_row: usize, j: usize) {
    let mut lanes = [0.0f32; 8];
    v.store(&mut lanes);
    for (r, &x) in lanes.iter().enumerate() {
        data[(base_row + r) * 8 + j] = x;
    }
}
/// x86-64-v3 kernel behind [`dct_16x8`]; same layout and scaling as
/// [`dct_16x8_scalar`].
///
/// - Pass 1 runs the 8-point transform on two batches of eight rows.
/// - The block is transposed.
/// - Pass 2 runs the 16-point transform on all eight former columns in a
///   single batch, writing directly to `output`.
#[cfg(target_arch = "x86_64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_16x8_avx2(token: archmage::X64V3Token, input: &[f32; 128], output: &mut [f32; 128]) {
    use magetypes::simd::f32x8;
    let scale8 = f32x8::splat(token, 1.0 / 8.0);
    let scale16 = f32x8::splat(token, 1.0 / 16.0);

    // Pass 1: 8-point row transforms, eight rows per batch.
    let mut tmp = crate::scratch_buf::<128>();
    for base_row in [0usize, 8] {
        let mut cols = [f32x8::zero(token); 8];
        for (j, col) in cols.iter_mut().enumerate() {
            *col = gather_col_s8(token, input, base_row, j);
        }
        dct1d_8_batch(token, &mut cols);
        for (j, col) in cols.iter().enumerate() {
            scatter_col_s8(*col * scale8, &mut tmp, base_row, j);
        }
    }

    // Transpose 16x8 -> 8x16.
    let mut transposed = crate::scratch_buf::<128>();
    for col in 0..8 {
        for row in 0..16 {
            transposed[col * 16 + row] = tmp[row * 8 + col];
        }
    }

    // Pass 2: 16-point transforms of the eight former columns.
    let mut cols = [f32x8::zero(token); 16];
    for (j, col) in cols.iter_mut().enumerate() {
        *col = gather_col(token, &transposed, 0, j);
    }
    dct1d_16_batch(token, &mut cols);
    for (j, col) in cols.iter().enumerate() {
        scatter_col(*col * scale16, output, 0, j);
    }
}
/// Forward DCT of an 8-row by 16-column block with runtime kernel selection.
///
/// On the matching target architecture this tries the x86-64-v3, NEON and wasm
/// SIMD128 kernels, each gated on its archmage token being summonable. If none
/// applies it runs [`dct_8x16_scalar`], whose doc describes the layout.
#[inline]
pub fn dct_8x16(input: &[f32; 128], output: &mut [f32; 128]) {
    #[cfg(target_arch = "x86_64")]
    {
        use archmage::SimdToken;
        if let Some(v3) = archmage::X64V3Token::summon() {
            return dct_8x16_avx2(v3, input, output);
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        use archmage::SimdToken;
        if let Some(neon) = archmage::NeonToken::summon() {
            return dct_8x16_neon(neon, input, output);
        }
    }
    #[cfg(target_arch = "wasm32")]
    {
        use archmage::SimdToken;
        if let Some(simd128) = archmage::Wasm128Token::summon() {
            return dct_8x16_wasm128(simd128, input, output);
        }
    }
    dct_8x16_scalar(input, output)
}
/// Portable forward DCT of an 8-row by 16-column block.
///
/// `input` is read as 8 row-major rows of 16 samples:
/// 1. Each row gets a 16-point DCT scaled by 1/16.
/// 2. The block is transposed.
/// 3. Each input column gets an 8-point DCT scaled by 1/8.
/// 4. A final transpose restores row-major order.
///
/// As a result, `output[v * 16 + u]` holds frequency `v` (0..8) along the input
/// columns and frequency `u` (0..16) along the input rows.
#[inline]
pub fn dct_8x16_scalar(input: &[f32; 128], output: &mut [f32; 128]) {
    // Pass 1: 16-point transform of every input row.
    let mut tmp = crate::scratch_buf::<128>();
    for r in 0..8 {
        let line = &mut tmp[r * 16..(r + 1) * 16];
        line.copy_from_slice(&input[r * 16..(r + 1) * 16]);
        dct1d_16_scalar(line);
        for x in line.iter_mut() {
            *x *= 1.0 / 16.0;
        }
    }

    // Transpose 8x16 -> 16x8 so each input column is contiguous.
    let mut transposed = crate::scratch_buf::<128>();
    for col in 0..16 {
        for row in 0..8 {
            transposed[col * 8 + row] = tmp[row * 16 + col];
        }
    }

    // Pass 2: 8-point transform of every former column.
    for col in 0..16 {
        let line = &mut transposed[col * 8..(col + 1) * 8];
        dct1d_8_scalar(line);
        for x in line.iter_mut() {
            *x *= 1.0 / 8.0;
        }
    }

    // Transpose back to 8 rows of 16 coefficients.
    for row in 0..8 {
        for col in 0..16 {
            output[row * 16 + col] = transposed[col * 8 + row];
        }
    }
}
/// x86-64-v3 kernel behind [`dct_8x16`]; same layout and scaling as
/// [`dct_8x16_scalar`].
///
/// - Pass 1 runs the 16-point transform on all eight rows in one batch.
/// - The block is transposed to 16x8.
/// - Pass 2 runs the 8-point transform in place on two batches of eight rows.
/// - A final transpose writes `output`.
#[cfg(target_arch = "x86_64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_8x16_avx2(token: archmage::X64V3Token, input: &[f32; 128], output: &mut [f32; 128]) {
    use magetypes::simd::f32x8;
    let scale8 = f32x8::splat(token, 1.0 / 8.0);
    let scale16 = f32x8::splat(token, 1.0 / 16.0);

    // Pass 1: 16-point row transforms; all eight rows fit in one batch.
    let mut tmp = crate::scratch_buf::<128>();
    let mut cols = [f32x8::zero(token); 16];
    for (j, col) in cols.iter_mut().enumerate() {
        *col = gather_col(token, input, 0, j);
    }
    dct1d_16_batch(token, &mut cols);
    for (j, col) in cols.iter().enumerate() {
        scatter_col(*col * scale16, &mut tmp, 0, j);
    }

    // Transpose 8x16 -> 16x8.
    let mut transposed = crate::scratch_buf::<128>();
    for col in 0..16 {
        for row in 0..8 {
            transposed[col * 8 + row] = tmp[row * 16 + col];
        }
    }

    // Pass 2: 8-point transforms in place; each batch is fully gathered
    // before it is scattered back.
    for base_row in [0usize, 8] {
        let mut lanes = [f32x8::zero(token); 8];
        for (j, lane) in lanes.iter_mut().enumerate() {
            *lane = gather_col_s8(token, &transposed, base_row, j);
        }
        dct1d_8_batch(token, &mut lanes);
        for (j, lane) in lanes.iter().enumerate() {
            scatter_col_s8(*lane * scale8, &mut transposed, base_row, j);
        }
    }

    // Transpose back to 8 rows of 16 coefficients.
    for row in 0..8 {
        for col in 0..16 {
            output[row * 16 + col] = transposed[col * 8 + row];
        }
    }
}
/// Loads element `j` of rows `base_row..base_row + 4` from a row-major buffer
/// with `s` floats per row. Lane `r` of the result comes from row `base_row + r`.
#[cfg(target_arch = "aarch64")]
#[archmage::rite]
fn gather_col_neon(
    token: archmage::NeonToken,
    data: &[f32],
    base_row: usize,
    j: usize,
    s: usize,
) -> magetypes::simd::f32x4 {
    // Successive lanes are exactly one row (`s` floats) apart.
    let first = base_row * s + j;
    magetypes::simd::f32x4::from_array(
        token,
        [
            data[first],
            data[first + s],
            data[first + 2 * s],
            data[first + 3 * s],
        ],
    )
}
/// Inverse of `gather_col_neon`: writes lane `r` of `v` to element `j` of row
/// `base_row + r` in a row-major buffer with `s` floats per row.
#[cfg(target_arch = "aarch64")]
#[archmage::rite]
fn scatter_col_neon(
    _token: archmage::NeonToken,
    v: magetypes::simd::f32x4,
    data: &mut [f32],
    base_row: usize,
    j: usize,
    s: usize,
) {
    let mut lanes = [0.0f32; 4];
    v.store(&mut lanes);
    for (r, &x) in lanes.iter().enumerate() {
        data[(base_row + r) * s + j] = x;
    }
}
/// Four-point scaled DCT applied independently in each of the four lanes.
///
/// `v[i]` holds sample `i` of four separate signals; on return `v[k]` holds
/// coefficient `k`. The steps match `dct1d_4_scalar`, with the
/// `sqrt(2) * o0 + o1` recombination computed through `mul_add`.
#[cfg(target_arch = "aarch64")]
#[archmage::rite]
fn dct1d_4_batch_neon(token: archmage::NeonToken, v: &mut [magetypes::simd::f32x4; 4]) {
    use magetypes::simd::f32x4;
    let [x0, x1, x2, x3] = *v;

    // Even half: 2-point DCT of the mirrored sums.
    let sum03 = x0 + x3;
    let sum12 = x1 + x2;
    let dc = sum03 + sum12;
    let even1 = sum03 - sum12;

    // Odd half: weighted mirrored differences, 2-point DCT, recombination.
    let d03 = (x0 - x3) * f32x4::splat(token, WC_MULTIPLIERS_4[0]);
    let d12 = (x1 - x2) * f32x4::splat(token, WC_MULTIPLIERS_4[1]);
    let odd_sum = d03 + d12;
    let odd_diff = d03 - d12;
    let odd0 = f32x4::splat(token, SQRT2).mul_add(odd_sum, odd_diff);

    *v = [dc, odd0, even1, odd_diff];
}
/// Eight-point scaled DCT applied independently in each of the four lanes.
///
/// `v[i]` holds sample `i` of four separate signals; on return `v[k]` holds
/// coefficient `k`. This mirrors `dct1d_8_scalar`, using `dct1d_4_batch_neon`
/// for both halves.
#[cfg(target_arch = "aarch64")]
#[archmage::rite]
fn dct1d_8_batch_neon(token: archmage::NeonToken, v: &mut [magetypes::simd::f32x4; 8]) {
    use magetypes::simd::f32x4;
    let x = *v;
    let mut even = [x[0] + x[7], x[1] + x[6], x[2] + x[5], x[3] + x[4]];
    let mut odd = [
        (x[0] - x[7]) * f32x4::splat(token, WC_MULTIPLIERS_8[0]),
        (x[1] - x[6]) * f32x4::splat(token, WC_MULTIPLIERS_8[1]),
        (x[2] - x[5]) * f32x4::splat(token, WC_MULTIPLIERS_8[2]),
        (x[3] - x[4]) * f32x4::splat(token, WC_MULTIPLIERS_8[3]),
    ];

    dct1d_4_batch_neon(token, &mut even);
    dct1d_4_batch_neon(token, &mut odd);

    odd[0] = f32x4::splat(token, SQRT2).mul_add(odd[0], odd[1]);
    odd[1] += odd[2];
    odd[2] += odd[3];

    for (k, (&e, &o)) in even.iter().zip(odd.iter()).enumerate() {
        v[2 * k] = e;
        v[2 * k + 1] = o;
    }
}
/// Sixteen-point scaled DCT applied independently in each of the four lanes.
///
/// `v[i]` holds sample `i` of four separate signals; on return `v[k]` holds
/// coefficient `k`. This mirrors `dct1d_16_scalar`, using
/// `dct1d_8_batch_neon` for both halves.
#[cfg(target_arch = "aarch64")]
#[archmage::rite]
pub(crate) fn dct1d_16_batch_neon(
    token: archmage::NeonToken,
    v: &mut [magetypes::simd::f32x4; 16],
) {
    use magetypes::simd::f32x4;
    let x = *v;
    let mut even = [
        x[0] + x[15],
        x[1] + x[14],
        x[2] + x[13],
        x[3] + x[12],
        x[4] + x[11],
        x[5] + x[10],
        x[6] + x[9],
        x[7] + x[8],
    ];
    let mut odd = [
        (x[0] - x[15]) * f32x4::splat(token, WC_MULTIPLIERS_16[0]),
        (x[1] - x[14]) * f32x4::splat(token, WC_MULTIPLIERS_16[1]),
        (x[2] - x[13]) * f32x4::splat(token, WC_MULTIPLIERS_16[2]),
        (x[3] - x[12]) * f32x4::splat(token, WC_MULTIPLIERS_16[3]),
        (x[4] - x[11]) * f32x4::splat(token, WC_MULTIPLIERS_16[4]),
        (x[5] - x[10]) * f32x4::splat(token, WC_MULTIPLIERS_16[5]),
        (x[6] - x[9]) * f32x4::splat(token, WC_MULTIPLIERS_16[6]),
        (x[7] - x[8]) * f32x4::splat(token, WC_MULTIPLIERS_16[7]),
    ];

    dct1d_8_batch_neon(token, &mut even);
    dct1d_8_batch_neon(token, &mut odd);

    odd[0] = f32x4::splat(token, SQRT2).mul_add(odd[0], odd[1]);
    // Ascending order: each step reads o[i + 1] before it is itself updated.
    for i in 1..7 {
        let next = odd[i + 1];
        odd[i] += next;
    }

    for (k, (&e, &o)) in even.iter().zip(odd.iter()).enumerate() {
        v[2 * k] = e;
        v[2 * k + 1] = o;
    }
}
/// Runs the 8-point transform on rows `base_row..base_row + 4` of a row-major
/// buffer with `stride` floats per row. It reads from `data_in`, multiplies
/// every coefficient by `scale`, and writes the same positions in `data_out`.
#[cfg(target_arch = "aarch64")]
#[archmage::rite]
#[allow(clippy::needless_range_loop)]
fn neon_dct8_batch(
    token: archmage::NeonToken,
    data_in: &[f32],
    data_out: &mut [f32],
    base_row: usize,
    stride: usize,
    scale: magetypes::simd::f32x4,
) {
    let mut cols = [magetypes::simd::f32x4::zero(token); 8];
    for (j, col) in cols.iter_mut().enumerate() {
        *col = gather_col_neon(token, data_in, base_row, j, stride);
    }
    dct1d_8_batch_neon(token, &mut cols);
    for (j, col) in cols.iter().enumerate() {
        scatter_col_neon(token, *col * scale, data_out, base_row, j, stride);
    }
}
/// Runs the 16-point transform on rows `base_row..base_row + 4` of a row-major
/// buffer with `stride` floats per row. It reads from `data_in`, multiplies
/// every coefficient by `scale`, and writes the same positions in `data_out`.
#[cfg(target_arch = "aarch64")]
#[archmage::rite]
#[allow(clippy::needless_range_loop)]
fn neon_dct16_batch(
    token: archmage::NeonToken,
    data_in: &[f32],
    data_out: &mut [f32],
    base_row: usize,
    stride: usize,
    scale: magetypes::simd::f32x4,
) {
    let mut cols = [magetypes::simd::f32x4::zero(token); 16];
    for (j, col) in cols.iter_mut().enumerate() {
        *col = gather_col_neon(token, data_in, base_row, j, stride);
    }
    dct1d_16_batch_neon(token, &mut cols);
    for (j, col) in cols.iter().enumerate() {
        scatter_col_neon(token, *col * scale, data_out, base_row, j, stride);
    }
}
/// NEON kernel behind [`dct_16x16`]; same layout and scaling as
/// [`dct_16x16_scalar`]. Each pass processes four rows per batch.
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_16x16_neon(token: archmage::NeonToken, input: &[f32; 256], output: &mut [f32; 256]) {
    use magetypes::simd::f32x4;
    let scale = f32x4::splat(token, 1.0 / 16.0);

    // Pass 1: row transforms.
    let mut tmp = crate::scratch_buf::<256>();
    for base_row in (0..16).step_by(4) {
        neon_dct16_batch(token, input, &mut tmp, base_row, 16, scale);
    }

    // Transpose so each input column becomes a row.
    let mut transposed = crate::scratch_buf::<256>();
    for c in 0..16 {
        for r in 0..16 {
            transposed[c * 16 + r] = tmp[r * 16 + c];
        }
    }

    // Pass 2: transforms of the former columns, written to `output`.
    for base_row in (0..16).step_by(4) {
        neon_dct16_batch(token, &transposed, output, base_row, 16, scale);
    }
}
/// NEON kernel behind [`dct_16x8`]; same layout and scaling as
/// [`dct_16x8_scalar`]. Each batch covers four rows.
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_16x8_neon(token: archmage::NeonToken, input: &[f32; 128], output: &mut [f32; 128]) {
    use magetypes::simd::f32x4;
    let scale8 = f32x4::splat(token, 1.0 / 8.0);
    let scale16 = f32x4::splat(token, 1.0 / 16.0);

    // Pass 1: 8-point transforms of the 16 input rows.
    let mut tmp = crate::scratch_buf::<128>();
    for base_row in (0..16).step_by(4) {
        neon_dct8_batch(token, input, &mut tmp, base_row, 8, scale8);
    }

    // Transpose 16x8 -> 8x16.
    let mut transposed = crate::scratch_buf::<128>();
    for col in 0..8 {
        for row in 0..16 {
            transposed[col * 16 + row] = tmp[row * 8 + col];
        }
    }

    // Pass 2: 16-point transforms of the 8 former columns.
    for base_row in (0..8).step_by(4) {
        neon_dct16_batch(token, &transposed, output, base_row, 16, scale16);
    }
}
/// NEON kernel behind [`dct_8x16`]; same layout and scaling as
/// [`dct_8x16_scalar`]. Each batch covers four rows.
#[cfg(target_arch = "aarch64")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_8x16_neon(token: archmage::NeonToken, input: &[f32; 128], output: &mut [f32; 128]) {
    use magetypes::simd::f32x4;
    let scale8 = f32x4::splat(token, 1.0 / 8.0);
    let scale16 = f32x4::splat(token, 1.0 / 16.0);

    // Pass 1: 16-point transforms of the 8 input rows.
    let mut tmp = crate::scratch_buf::<128>();
    for base_row in (0..8).step_by(4) {
        neon_dct16_batch(token, input, &mut tmp, base_row, 16, scale16);
    }

    // Transpose 8x16 -> 16x8.
    let mut transposed = crate::scratch_buf::<128>();
    for col in 0..16 {
        for row in 0..8 {
            transposed[col * 8 + row] = tmp[row * 16 + col];
        }
    }

    // Pass 2: 8-point transforms of the 16 former columns.
    let mut pass2_out = crate::scratch_buf::<128>();
    for base_row in (0..16).step_by(4) {
        neon_dct8_batch(token, &transposed, &mut pass2_out, base_row, 8, scale8);
    }

    // Transpose back to 8 rows of 16 coefficients.
    for row in 0..8 {
        for col in 0..16 {
            output[row * 16 + col] = pass2_out[col * 8 + row];
        }
    }
}
/// Loads element `j` of rows `base_row..base_row + 4` from a row-major buffer
/// with `s` floats per row. Lane `r` of the result comes from row `base_row + r`.
#[cfg(target_arch = "wasm32")]
#[archmage::rite]
fn gather_col_wasm128(
    token: archmage::Wasm128Token,
    data: &[f32],
    base_row: usize,
    j: usize,
    s: usize,
) -> magetypes::simd::f32x4 {
    // Successive lanes are exactly one row (`s` floats) apart.
    let first = base_row * s + j;
    magetypes::simd::f32x4::from_array(
        token,
        [
            data[first],
            data[first + s],
            data[first + 2 * s],
            data[first + 3 * s],
        ],
    )
}
/// Inverse of `gather_col_wasm128`: writes lane `r` of `v` to element `j` of
/// row `base_row + r` in a row-major buffer with `s` floats per row.
#[cfg(target_arch = "wasm32")]
#[archmage::rite]
fn scatter_col_wasm128(
    _token: archmage::Wasm128Token,
    v: magetypes::simd::f32x4,
    data: &mut [f32],
    base_row: usize,
    j: usize,
    s: usize,
) {
    let mut lanes = [0.0f32; 4];
    v.store(&mut lanes);
    for (r, &x) in lanes.iter().enumerate() {
        data[(base_row + r) * s + j] = x;
    }
}
/// Four-point scaled DCT applied independently in each of the four lanes.
///
/// `v[i]` holds sample `i` of four separate signals; on return `v[k]` holds
/// coefficient `k`. The steps match `dct1d_4_scalar`, with the
/// `sqrt(2) * o0 + o1` recombination computed through `mul_add`.
#[cfg(target_arch = "wasm32")]
#[archmage::rite]
fn dct1d_4_batch_wasm128(token: archmage::Wasm128Token, v: &mut [magetypes::simd::f32x4; 4]) {
    use magetypes::simd::f32x4;
    let [x0, x1, x2, x3] = *v;

    // Even half: 2-point DCT of the mirrored sums.
    let sum03 = x0 + x3;
    let sum12 = x1 + x2;
    let dc = sum03 + sum12;
    let even1 = sum03 - sum12;

    // Odd half: weighted mirrored differences, 2-point DCT, recombination.
    let d03 = (x0 - x3) * f32x4::splat(token, WC_MULTIPLIERS_4[0]);
    let d12 = (x1 - x2) * f32x4::splat(token, WC_MULTIPLIERS_4[1]);
    let odd_sum = d03 + d12;
    let odd_diff = d03 - d12;
    let odd0 = f32x4::splat(token, SQRT2).mul_add(odd_sum, odd_diff);

    *v = [dc, odd0, even1, odd_diff];
}
/// Eight-point scaled DCT applied independently in each of the four lanes.
///
/// `v[i]` holds sample `i` of four separate signals; on return `v[k]` holds
/// coefficient `k`. This mirrors `dct1d_8_scalar`, using
/// `dct1d_4_batch_wasm128` for both halves.
#[cfg(target_arch = "wasm32")]
#[archmage::rite]
fn dct1d_8_batch_wasm128(token: archmage::Wasm128Token, v: &mut [magetypes::simd::f32x4; 8]) {
    use magetypes::simd::f32x4;
    let x = *v;
    let mut even = [x[0] + x[7], x[1] + x[6], x[2] + x[5], x[3] + x[4]];
    let mut odd = [
        (x[0] - x[7]) * f32x4::splat(token, WC_MULTIPLIERS_8[0]),
        (x[1] - x[6]) * f32x4::splat(token, WC_MULTIPLIERS_8[1]),
        (x[2] - x[5]) * f32x4::splat(token, WC_MULTIPLIERS_8[2]),
        (x[3] - x[4]) * f32x4::splat(token, WC_MULTIPLIERS_8[3]),
    ];

    dct1d_4_batch_wasm128(token, &mut even);
    dct1d_4_batch_wasm128(token, &mut odd);

    odd[0] = f32x4::splat(token, SQRT2).mul_add(odd[0], odd[1]);
    odd[1] += odd[2];
    odd[2] += odd[3];

    for (k, (&e, &o)) in even.iter().zip(odd.iter()).enumerate() {
        v[2 * k] = e;
        v[2 * k + 1] = o;
    }
}
/// Sixteen-point scaled DCT applied independently in each of the four lanes.
///
/// `v[i]` holds sample `i` of four separate signals; on return `v[k]` holds
/// coefficient `k`. This mirrors `dct1d_16_scalar`, using
/// `dct1d_8_batch_wasm128` for both halves.
#[cfg(target_arch = "wasm32")]
#[archmage::rite]
pub(crate) fn dct1d_16_batch_wasm128(
    token: archmage::Wasm128Token,
    v: &mut [magetypes::simd::f32x4; 16],
) {
    use magetypes::simd::f32x4;
    let x = *v;
    let mut even = [
        x[0] + x[15],
        x[1] + x[14],
        x[2] + x[13],
        x[3] + x[12],
        x[4] + x[11],
        x[5] + x[10],
        x[6] + x[9],
        x[7] + x[8],
    ];
    let mut odd = [
        (x[0] - x[15]) * f32x4::splat(token, WC_MULTIPLIERS_16[0]),
        (x[1] - x[14]) * f32x4::splat(token, WC_MULTIPLIERS_16[1]),
        (x[2] - x[13]) * f32x4::splat(token, WC_MULTIPLIERS_16[2]),
        (x[3] - x[12]) * f32x4::splat(token, WC_MULTIPLIERS_16[3]),
        (x[4] - x[11]) * f32x4::splat(token, WC_MULTIPLIERS_16[4]),
        (x[5] - x[10]) * f32x4::splat(token, WC_MULTIPLIERS_16[5]),
        (x[6] - x[9]) * f32x4::splat(token, WC_MULTIPLIERS_16[6]),
        (x[7] - x[8]) * f32x4::splat(token, WC_MULTIPLIERS_16[7]),
    ];

    dct1d_8_batch_wasm128(token, &mut even);
    dct1d_8_batch_wasm128(token, &mut odd);

    odd[0] = f32x4::splat(token, SQRT2).mul_add(odd[0], odd[1]);
    // Ascending order: each step reads o[i + 1] before it is itself updated.
    for i in 1..7 {
        let next = odd[i + 1];
        odd[i] += next;
    }

    for (k, (&e, &o)) in even.iter().zip(odd.iter()).enumerate() {
        v[2 * k] = e;
        v[2 * k + 1] = o;
    }
}
/// Runs the 8-point transform on rows `base_row..base_row + 4` of a row-major
/// buffer with `stride` floats per row. It reads from `data_in`, multiplies
/// every coefficient by `scale`, and writes the same positions in `data_out`.
#[cfg(target_arch = "wasm32")]
#[archmage::rite]
#[allow(clippy::needless_range_loop)]
fn wasm128_dct8_batch(
    token: archmage::Wasm128Token,
    data_in: &[f32],
    data_out: &mut [f32],
    base_row: usize,
    stride: usize,
    scale: magetypes::simd::f32x4,
) {
    let mut cols = [magetypes::simd::f32x4::zero(token); 8];
    for (j, col) in cols.iter_mut().enumerate() {
        *col = gather_col_wasm128(token, data_in, base_row, j, stride);
    }
    dct1d_8_batch_wasm128(token, &mut cols);
    for (j, col) in cols.iter().enumerate() {
        scatter_col_wasm128(token, *col * scale, data_out, base_row, j, stride);
    }
}
/// Runs the 16-point transform on rows `base_row..base_row + 4` of a row-major
/// buffer with `stride` floats per row. It reads from `data_in`, multiplies
/// every coefficient by `scale`, and writes the same positions in `data_out`.
#[cfg(target_arch = "wasm32")]
#[archmage::rite]
#[allow(clippy::needless_range_loop)]
fn wasm128_dct16_batch(
    token: archmage::Wasm128Token,
    data_in: &[f32],
    data_out: &mut [f32],
    base_row: usize,
    stride: usize,
    scale: magetypes::simd::f32x4,
) {
    let mut cols = [magetypes::simd::f32x4::zero(token); 16];
    for (j, col) in cols.iter_mut().enumerate() {
        *col = gather_col_wasm128(token, data_in, base_row, j, stride);
    }
    dct1d_16_batch_wasm128(token, &mut cols);
    for (j, col) in cols.iter().enumerate() {
        scatter_col_wasm128(token, *col * scale, data_out, base_row, j, stride);
    }
}
/// wasm SIMD128 kernel behind [`dct_16x16`]; same layout and scaling as
/// [`dct_16x16_scalar`]. Each pass processes four rows per batch.
#[cfg(target_arch = "wasm32")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_16x16_wasm128(
    token: archmage::Wasm128Token,
    input: &[f32; 256],
    output: &mut [f32; 256],
) {
    use magetypes::simd::f32x4;
    let scale = f32x4::splat(token, 1.0 / 16.0);

    // Pass 1: row transforms.
    let mut tmp = crate::scratch_buf::<256>();
    for base_row in (0..16).step_by(4) {
        wasm128_dct16_batch(token, input, &mut tmp, base_row, 16, scale);
    }

    // Transpose so each input column becomes a row.
    let mut transposed = crate::scratch_buf::<256>();
    for c in 0..16 {
        for r in 0..16 {
            transposed[c * 16 + r] = tmp[r * 16 + c];
        }
    }

    // Pass 2: transforms of the former columns, written to `output`.
    for base_row in (0..16).step_by(4) {
        wasm128_dct16_batch(token, &transposed, output, base_row, 16, scale);
    }
}
/// wasm SIMD128 kernel behind [`dct_16x8`]; same layout and scaling as
/// [`dct_16x8_scalar`]. Each batch covers four rows.
#[cfg(target_arch = "wasm32")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_16x8_wasm128(
    token: archmage::Wasm128Token,
    input: &[f32; 128],
    output: &mut [f32; 128],
) {
    use magetypes::simd::f32x4;
    let scale8 = f32x4::splat(token, 1.0 / 8.0);
    let scale16 = f32x4::splat(token, 1.0 / 16.0);

    // Pass 1: 8-point transforms of the 16 input rows.
    let mut tmp = crate::scratch_buf::<128>();
    for base_row in (0..16).step_by(4) {
        wasm128_dct8_batch(token, input, &mut tmp, base_row, 8, scale8);
    }

    // Transpose 16x8 -> 8x16.
    let mut transposed = crate::scratch_buf::<128>();
    for col in 0..8 {
        for row in 0..16 {
            transposed[col * 16 + row] = tmp[row * 8 + col];
        }
    }

    // Pass 2: 16-point transforms of the 8 former columns.
    for base_row in (0..8).step_by(4) {
        wasm128_dct16_batch(token, &transposed, output, base_row, 16, scale16);
    }
}
/// wasm SIMD128 kernel behind [`dct_8x16`]; same layout and scaling as
/// [`dct_8x16_scalar`]. Each batch covers four rows.
#[cfg(target_arch = "wasm32")]
#[inline]
#[archmage::arcane]
#[allow(clippy::needless_range_loop)]
pub fn dct_8x16_wasm128(
    token: archmage::Wasm128Token,
    input: &[f32; 128],
    output: &mut [f32; 128],
) {
    use magetypes::simd::f32x4;
    let scale8 = f32x4::splat(token, 1.0 / 8.0);
    let scale16 = f32x4::splat(token, 1.0 / 16.0);

    // Pass 1: 16-point transforms of the 8 input rows.
    let mut tmp = crate::scratch_buf::<128>();
    for base_row in (0..8).step_by(4) {
        wasm128_dct16_batch(token, input, &mut tmp, base_row, 16, scale16);
    }

    // Transpose 8x16 -> 16x8.
    let mut transposed = crate::scratch_buf::<128>();
    for col in 0..16 {
        for row in 0..8 {
            transposed[col * 8 + row] = tmp[row * 16 + col];
        }
    }

    // Pass 2: 8-point transforms of the 16 former columns.
    let mut pass2_out = crate::scratch_buf::<128>();
    for base_row in (0..16).step_by(4) {
        wasm128_dct8_batch(token, &transposed, &mut pass2_out, base_row, 8, scale8);
    }

    // Transpose back to 8 rows of 16 coefficients.
    for row in 0..8 {
        for col in 0..16 {
            output[row * 16 + col] = pass2_out[col * 8 + row];
        }
    }
}
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;

    /// Largest absolute element-wise difference between `a` and `b`, and the
    /// first index where it occurs. Returns `(0.0, 0)` when all elements agree.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> (f32, usize) {
        assert_eq!(a.len(), b.len(), "compared buffers must have equal length");
        let mut max_diff = 0.0f32;
        let mut max_idx = 0;
        for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
            let diff = (x - y).abs();
            if diff > max_diff {
                max_diff = diff;
                max_idx = i;
            }
        }
        (max_diff, max_idx)
    }

    /// Directly evaluates one coefficient of the separable DCT-II of a
    /// row-major `rows x cols` block. `fr` is the frequency along each column
    /// (over the row index) and `fc` the frequency along each row (over the
    /// column index). The normalisation matches the fast transforms here:
    /// a `1/N` factor per axis and a `sqrt(2)` gain on every non-zero
    /// frequency. Comparing against it catches scaling or coefficient-layout
    /// mistakes that SIMD-vs-scalar comparisons cannot see.
    fn reference_coeff(input: &[f32], rows: usize, cols: usize, fr: usize, fc: usize) -> f64 {
        use core::f64::consts::{PI, SQRT_2};
        let gain = |k: usize| if k == 0 { 1.0 } else { SQRT_2 };
        let mut acc = 0.0f64;
        for r in 0..rows {
            let wr = (PI * ((2 * r + 1) * fr) as f64 / (2 * rows) as f64).cos();
            for c in 0..cols {
                let wc = (PI * ((2 * c + 1) * fc) as f64 / (2 * cols) as f64).cos();
                acc += f64::from(input[r * cols + c]) * wr * wc;
            }
        }
        acc * gain(fr) * gain(fc) / (rows * cols) as f64
    }

    #[test]
    fn test_dct_16x16_simd_matches_scalar() {
        let mut input = [0.0f32; 256];
        for (i, val) in input.iter_mut().enumerate() {
            *val = i as f32;
        }
        let mut scalar_out = [0.0f32; 256];
        dct_16x16_scalar(&input, &mut scalar_out);
        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = [0.0f32; 256];
                dct_16x16(&input, &mut simd_out);
                let (max_diff, max_idx) = max_abs_diff(&scalar_out, &simd_out);
                assert!(
                    max_diff < 1e-2,
                    "DCT16x16 max diff = {max_diff} at {max_idx} (scalar={}, simd={}) [{perm}]",
                    scalar_out[max_idx],
                    simd_out[max_idx],
                );
            },
        );
        std::eprintln!("{report}");
    }

    #[test]
    fn test_dct_16x16_dc_only() {
        let input = [42.0f32; 256];
        let mut output = [0.0f32; 256];
        dct_16x16(&input, &mut output);
        assert!(
            output[0].abs() > 1.0,
            "DC coefficient should be nonzero, got {}",
            output[0],
        );
        let mut max_ac = 0.0f32;
        let mut max_ac_idx = 0;
        for (i, &coeff) in output.iter().enumerate().skip(1) {
            let val = coeff.abs();
            if val > max_ac {
                max_ac = val;
                max_ac_idx = i;
            }
        }
        assert!(
            max_ac < 1e-3,
            "AC coefficients should be near zero, max = {} at index {}",
            max_ac,
            max_ac_idx,
        );
    }

    #[test]
    fn test_dct_16x16_roundtrip() {
        let mut input = [0.0f32; 256];
        for (i, val) in input.iter_mut().enumerate() {
            *val = ((i as f32) * 0.37 + 1.5).cos() * 100.0;
        }
        let mut dct_out = [0.0f32; 256];
        let mut roundtrip = [0.0f32; 256];
        dct_16x16(&input, &mut dct_out);
        super::super::idct16::idct_16x16(&dct_out, &mut roundtrip);
        let (max_diff, max_idx) = max_abs_diff(&input, &roundtrip);
        assert!(
            max_diff < 1e-2,
            "DCT16x16 roundtrip max diff = {} at index {} (input={}, roundtrip={})",
            max_diff,
            max_idx,
            input[max_idx],
            roundtrip[max_idx],
        );
    }

    #[test]
    fn test_dct_16x16_scalar_matches_reference() {
        let mut input = [0.0f32; 256];
        for (i, val) in input.iter_mut().enumerate() {
            *val = ((i as f32) * 0.37 + 1.5).cos() * 100.0;
        }
        let mut output = [0.0f32; 256];
        dct_16x16_scalar(&input, &mut output);
        // Stored without a final transpose: major index = frequency along rows.
        for u in 0..16 {
            for v in 0..16 {
                let want = reference_coeff(&input, 16, 16, v, u);
                let got = f64::from(output[u * 16 + v]);
                assert!(
                    (got - want).abs() < 1e-2,
                    "DCT16x16 coeff (u={u}, v={v}): got {got}, want {want}",
                );
            }
        }
    }

    #[test]
    fn test_dct_16x8_simd_matches_scalar() {
        let mut input = [0.0f32; 128];
        for (i, val) in input.iter_mut().enumerate() {
            *val = ((i as f32) * 0.43 + 2.1).cos() * 80.0;
        }
        let mut scalar_out = [0.0f32; 128];
        dct_16x8_scalar(&input, &mut scalar_out);
        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = [0.0f32; 128];
                dct_16x8(&input, &mut simd_out);
                let (max_diff, max_idx) = max_abs_diff(&scalar_out, &simd_out);
                assert!(
                    max_diff < 1e-2,
                    "DCT16x8 max diff = {max_diff} at {max_idx} (scalar={}, simd={}) [{perm}]",
                    scalar_out[max_idx],
                    simd_out[max_idx],
                );
            },
        );
        std::eprintln!("{report}");
    }

    #[test]
    fn test_dct_16x8_scalar_matches_reference() {
        // 16 rows x 8 columns in; the 8-point (along-row) frequency is the major index.
        let mut input = [0.0f32; 128];
        for (i, val) in input.iter_mut().enumerate() {
            *val = ((i as f32) * 0.43 + 2.1).cos() * 80.0;
        }
        let mut output = [0.0f32; 128];
        dct_16x8_scalar(&input, &mut output);
        for u in 0..8 {
            for v in 0..16 {
                let want = reference_coeff(&input, 16, 8, v, u);
                let got = f64::from(output[u * 16 + v]);
                assert!(
                    (got - want).abs() < 1e-2,
                    "DCT16x8 coeff (u={u}, v={v}): got {got}, want {want}",
                );
            }
        }
    }

    #[test]
    fn test_dct_8x16_simd_matches_scalar() {
        let mut input = [0.0f32; 128];
        for (i, val) in input.iter_mut().enumerate() {
            *val = ((i as f32) * 0.29 + 0.7).sin() * 120.0;
        }
        let mut scalar_out = [0.0f32; 128];
        dct_8x16_scalar(&input, &mut scalar_out);
        let report = archmage::testing::for_each_token_permutation(
            archmage::testing::CompileTimePolicy::Warn,
            |perm| {
                let mut simd_out = [0.0f32; 128];
                dct_8x16(&input, &mut simd_out);
                let (max_diff, max_idx) = max_abs_diff(&scalar_out, &simd_out);
                assert!(
                    max_diff < 1e-2,
                    "DCT8x16 max diff = {max_diff} at {max_idx} (scalar={}, simd={}) [{perm}]",
                    scalar_out[max_idx],
                    simd_out[max_idx],
                );
            },
        );
        std::eprintln!("{report}");
    }

    #[test]
    fn test_dct_8x16_scalar_matches_reference() {
        // 8 rows x 16 columns in; the 8-point (along-column) frequency is the major index.
        let mut input = [0.0f32; 128];
        for (i, val) in input.iter_mut().enumerate() {
            *val = ((i as f32) * 0.29 + 0.7).sin() * 120.0;
        }
        let mut output = [0.0f32; 128];
        dct_8x16_scalar(&input, &mut output);
        for v in 0..8 {
            for u in 0..16 {
                let want = reference_coeff(&input, 8, 16, v, u);
                let got = f64::from(output[v * 16 + u]);
                assert!(
                    (got - want).abs() < 1e-2,
                    "DCT8x16 coeff (v={v}, u={u}): got {got}, want {want}",
                );
            }
        }
    }
}