use std::f32::consts::{FRAC_1_SQRT_2, PI};
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
/// Fixed-point precision: multiplier constants below are scaled by 2^13.
const CONST_BITS: i32 = 13;
/// Extra precision bits carried through the first (row) pass of the DCT.
const PASS1_BITS: i32 = 2;
/// Multiply `a` by a 2^13-scaled fixed-point constant `b`.
///
/// The product is widened to i64 so it cannot overflow, then the scale
/// factor is dropped again. Note the shift truncates toward negative
/// infinity; no rounding bias is added here.
#[inline(always)]
fn fix_mul(a: i32, b: i32) -> i32 {
    let wide = i64::from(a) * i64::from(b);
    (wide >> CONST_BITS) as i32
}
const FIX_0_298631336: i32 = 2446; const FIX_0_390180644: i32 = 3196; const FIX_0_541196100: i32 = 4433; const FIX_0_765366865: i32 = 6270; const FIX_0_899976223: i32 = 7373; const FIX_1_175875602: i32 = 9633; const FIX_1_501321110: i32 = 12299; const FIX_1_847759065: i32 = 15137; const FIX_1_961570560: i32 = 16069; const FIX_2_053119869: i32 = 16819; const FIX_2_562915447: i32 = 20995; const FIX_3_072711026: i32 = 25172;
/// Forward 8x8 DCT in 2^13 fixed-point arithmetic.
///
/// Two 1-D passes: rows first (results kept at elevated precision in
/// `workspace`), then columns, with a rounding right-shift by
/// `PASS1_BITS + 3` producing the final coefficients in row-major order
/// (result[freq_row * 8 + freq_col]).
///
/// NOTE(review): in the row pass, outputs 0 and 4 are scaled up by
/// `<< PASS1_BITS` while outputs 1,2,3,5,6,7 come out of `fix_mul`
/// descaled by the full CONST_BITS (libjpeg's jpeg_fdct_islow descales
/// those by CONST_BITS - PASS1_BITS instead) — confirm this scaling
/// asymmetry is intended.
pub fn dct_2d_integer(block: &[i16; 64]) -> [i32; 64] {
let mut workspace = [0i32; 64];
// ---- Pass 1: 1-D DCT on each row, into workspace ----
for row in 0..8 {
let row_offset = row * 8;
let d0 = block[row_offset] as i32;
let d1 = block[row_offset + 1] as i32;
let d2 = block[row_offset + 2] as i32;
let d3 = block[row_offset + 3] as i32;
let d4 = block[row_offset + 4] as i32;
let d5 = block[row_offset + 5] as i32;
let d6 = block[row_offset + 6] as i32;
let d7 = block[row_offset + 7] as i32;
// Even part: butterflies of symmetric sample pairs (d_i + d_{7-i}).
let tmp0 = d0 + d7;
let tmp1 = d1 + d6;
let tmp2 = d2 + d5;
let tmp3 = d3 + d4;
let tmp10 = tmp0 + tmp3;
let tmp12 = tmp0 - tmp3;
let tmp11 = tmp1 + tmp2;
let tmp13 = tmp1 - tmp2;
// Odd part inputs: antisymmetric differences (d_i - d_{7-i}),
// shadowing the earlier sums.
let tmp0 = d0 - d7;
let tmp1 = d1 - d6;
let tmp2 = d2 - d5;
let tmp3 = d3 - d4;
// Even outputs 0 and 4 need no multiplies; keep PASS1_BITS extra
// precision for the second pass.
workspace[row_offset] = (tmp10 + tmp11) << PASS1_BITS;
workspace[row_offset + 4] = (tmp10 - tmp11) << PASS1_BITS;
// Even outputs 2 and 6 via the shared rotation by 0.541196100.
let z1 = fix_mul(tmp12 + tmp13, FIX_0_541196100);
workspace[row_offset + 2] = z1 + fix_mul(tmp12, FIX_0_765366865);
workspace[row_offset + 6] = z1 - fix_mul(tmp13, FIX_1_847759065);
// Odd part: cross sums feeding the five-rotation network.
let tmp10 = tmp0 + tmp3;
let tmp11 = tmp1 + tmp2;
let tmp12 = tmp0 + tmp2;
let tmp13 = tmp1 + tmp3;
let z1 = fix_mul(tmp12 + tmp13, FIX_1_175875602);
let tmp0 = fix_mul(tmp0, FIX_1_501321110);
let tmp1 = fix_mul(tmp1, FIX_3_072711026);
let tmp2 = fix_mul(tmp2, FIX_2_053119869);
let tmp3 = fix_mul(tmp3, FIX_0_298631336);
let tmp10 = fix_mul(tmp10, -FIX_0_899976223);
let tmp11 = fix_mul(tmp11, -FIX_2_562915447);
let tmp12 = fix_mul(tmp12, -FIX_0_390180644) + z1;
let tmp13 = fix_mul(tmp13, -FIX_1_961570560) + z1;
workspace[row_offset + 1] = tmp0 + tmp10 + tmp12;
workspace[row_offset + 3] = tmp1 + tmp11 + tmp13;
workspace[row_offset + 5] = tmp2 + tmp11 + tmp12;
workspace[row_offset + 7] = tmp3 + tmp10 + tmp13;
}
// ---- Pass 2: same 1-D DCT down each column, with final descale ----
let mut result = [0i32; 64];
for col in 0..8 {
let d0 = workspace[col];
let d1 = workspace[col + 8];
let d2 = workspace[col + 16];
let d3 = workspace[col + 24];
let d4 = workspace[col + 32];
let d5 = workspace[col + 40];
let d6 = workspace[col + 48];
let d7 = workspace[col + 56];
let tmp0 = d0 + d7;
let tmp1 = d1 + d6;
let tmp2 = d2 + d5;
let tmp3 = d3 + d4;
let tmp10 = tmp0 + tmp3;
let tmp12 = tmp0 - tmp3;
let tmp11 = tmp1 + tmp2;
let tmp13 = tmp1 - tmp2;
let tmp0 = d0 - d7;
let tmp1 = d1 - d6;
let tmp2 = d2 - d5;
let tmp3 = d3 - d4;
// Drop the PASS1_BITS headroom plus the 8x gain of the 2-D
// transform (3 bits), rounding to nearest via the +half bias.
let descale = PASS1_BITS + 3;
result[col] = (tmp10 + tmp11 + (1 << (descale - 1))) >> descale;
result[col + 32] = (tmp10 - tmp11 + (1 << (descale - 1))) >> descale;
let z1 = fix_mul(tmp12 + tmp13, FIX_0_541196100);
result[col + 16] = (z1 + fix_mul(tmp12, FIX_0_765366865) + (1 << (descale - 1))) >> descale;
result[col + 48] = (z1 - fix_mul(tmp13, FIX_1_847759065) + (1 << (descale - 1))) >> descale;
let tmp10 = tmp0 + tmp3;
let tmp11 = tmp1 + tmp2;
let tmp12 = tmp0 + tmp2;
let tmp13 = tmp1 + tmp3;
let z1 = fix_mul(tmp12 + tmp13, FIX_1_175875602);
let tmp0 = fix_mul(tmp0, FIX_1_501321110);
let tmp1 = fix_mul(tmp1, FIX_3_072711026);
let tmp2 = fix_mul(tmp2, FIX_2_053119869);
let tmp3 = fix_mul(tmp3, FIX_0_298631336);
let tmp10 = fix_mul(tmp10, -FIX_0_899976223);
let tmp11 = fix_mul(tmp11, -FIX_2_562915447);
let tmp12 = fix_mul(tmp12, -FIX_0_390180644) + z1;
let tmp13 = fix_mul(tmp13, -FIX_1_961570560) + z1;
result[col + 8] = (tmp0 + tmp10 + tmp12 + (1 << (descale - 1))) >> descale;
result[col + 24] = (tmp1 + tmp11 + tmp13 + (1 << (descale - 1))) >> descale;
result[col + 40] = (tmp2 + tmp11 + tmp12 + (1 << (descale - 1))) >> descale;
result[col + 56] = (tmp3 + tmp10 + tmp13 + (1 << (descale - 1))) >> descale;
}
result
}
/// Safe wrapper around the NEON implementation of the integer DCT.
#[cfg(target_arch = "aarch64")]
pub fn dct_2d_integer_neon(block: &[i16; 64]) -> [i32; 64] {
// SAFETY: NEON is a mandatory (baseline) feature on AArch64, so the
// `target_feature(enable = "neon")` function can be called unconditionally.
unsafe { dct_2d_integer_neon_impl(block) }
}
/// NEON-assisted forward 8x8 integer DCT.
///
/// Mirrors `dct_2d_integer`: a row pass into `workspace`, then the
/// workspace is loaded as vectors, transposed 8x8, and the same 1-D
/// transform (with final descale) is run on each column.
///
/// # Safety
/// Caller must ensure NEON is available (guaranteed on AArch64).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dct_2d_integer_neon_impl(block: &[i16; 64]) -> [i32; 64] {
let mut workspace = [0i32; 64];
for row in 0..8 {
let offset = row * 8;
// Load 8 i16 samples and widen to two i32x4 halves.
let row_s16 = vld1q_s16(block[offset..].as_ptr());
let lo = vmovl_s16(vget_low_s16(row_s16));
let hi = vmovl_high_s16(row_s16);
let result = dct_row_neon_vectorized(lo, hi);
vst1q_s32(workspace[offset..].as_mut_ptr(), result.0);
vst1q_s32(workspace[offset + 4..].as_mut_ptr(), result.1);
}
let mut result = [0i32; 64];
// SAFETY-adjacent note: all-zero bits are a valid int32x4x2_t, so
// mem::zeroed is sound here; every element is overwritten below anyway.
let mut rows: [int32x4x2_t; 8] = std::mem::zeroed();
for row in 0..8 {
rows[row].0 = vld1q_s32(workspace[row * 8..].as_ptr());
rows[row].1 = vld1q_s32(workspace[row * 8 + 4..].as_ptr());
}
// Transpose so that each entry holds one column of the workspace.
let transposed = transpose_8x8_neon(&rows);
for col in 0..8 {
let col_result = dct_column_neon_vectorized(transposed[col].0, transposed[col].1);
// Scatter the 8 column outputs back into row-major order:
// output frequency k for column `col` lands at result[k * 8 + col].
let mut temp = [0i32; 8];
vst1q_s32(temp[0..4].as_mut_ptr(), col_result.0);
vst1q_s32(temp[4..8].as_mut_ptr(), col_result.1);
for row in 0..8 {
result[row * 8 + col] = temp[row];
}
}
result
}
/// 1-D row DCT for one 8-sample row held as two i32x4 vectors
/// (lo = d0..d3, hi = d4..d7). Returns (out0..out3, out4..out7).
///
/// Only the initial symmetric/antisymmetric butterfly is vectorized;
/// the rotation network afterwards extracts lanes and runs the same
/// scalar fixed-point math as `dct_2d_integer`'s row pass.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn dct_row_neon_vectorized(lo: int32x4_t, hi: int32x4_t) -> (int32x4_t, int32x4_t) {
// Reverse hi (d4..d7 -> d7..d4) so lane i of lo pairs with d_{7-i}:
// vrev64q swaps within 64-bit halves, vextq(.., 2) swaps the halves.
let hi_rev = vrev64q_s32(hi);
let hi_rev = vextq_s32(hi_rev, hi_rev, 2);
// even_sum = d_i + d_{7-i}, even_diff = d_i - d_{7-i}.
let even_sum = vaddq_s32(lo, hi_rev); let even_diff = vsubq_s32(lo, hi_rev);
let tmp0 = vgetq_lane_s32(even_sum, 0);
let tmp1 = vgetq_lane_s32(even_sum, 1);
let tmp2 = vgetq_lane_s32(even_sum, 2);
let tmp3 = vgetq_lane_s32(even_sum, 3);
let tmp10 = tmp0 + tmp3;
let tmp11 = tmp1 + tmp2;
let tmp12 = tmp0 - tmp3;
let tmp13 = tmp1 - tmp2;
let tmp0_odd = vgetq_lane_s32(even_diff, 0);
let tmp1_odd = vgetq_lane_s32(even_diff, 1);
let tmp2_odd = vgetq_lane_s32(even_diff, 2);
let tmp3_odd = vgetq_lane_s32(even_diff, 3);
// Even outputs: 0 and 4 keep PASS1_BITS extra precision, 2 and 6 use
// the shared 0.541196100 rotation — identical to the scalar row pass.
let out0 = (tmp10 + tmp11) << PASS1_BITS;
let out4 = (tmp10 - tmp11) << PASS1_BITS;
let z1 = fix_mul(tmp12 + tmp13, FIX_0_541196100);
let out2 = z1 + fix_mul(tmp12, FIX_0_765366865);
let out6 = z1 - fix_mul(tmp13, FIX_1_847759065);
// Odd outputs: five-rotation network on the antisymmetric differences.
let tmp10_o = tmp0_odd + tmp3_odd;
let tmp11_o = tmp1_odd + tmp2_odd;
let tmp12_o = tmp0_odd + tmp2_odd;
let tmp13_o = tmp1_odd + tmp3_odd;
let z1_o = fix_mul(tmp12_o + tmp13_o, FIX_1_175875602);
let tmp0_m = fix_mul(tmp0_odd, FIX_1_501321110);
let tmp1_m = fix_mul(tmp1_odd, FIX_3_072711026);
let tmp2_m = fix_mul(tmp2_odd, FIX_2_053119869);
let tmp3_m = fix_mul(tmp3_odd, FIX_0_298631336);
let tmp10_m = fix_mul(tmp10_o, -FIX_0_899976223);
let tmp11_m = fix_mul(tmp11_o, -FIX_2_562915447);
let tmp12_m = fix_mul(tmp12_o, -FIX_0_390180644) + z1_o;
let tmp13_m = fix_mul(tmp13_o, -FIX_1_961570560) + z1_o;
let out1 = tmp0_m + tmp10_m + tmp12_m;
let out3 = tmp1_m + tmp11_m + tmp13_m;
let out5 = tmp2_m + tmp11_m + tmp12_m;
let out7 = tmp3_m + tmp10_m + tmp13_m;
// Repack the eight scalar outputs into two vectors.
let result_lo = vsetq_lane_s32(out0, vdupq_n_s32(0), 0);
let result_lo = vsetq_lane_s32(out1, result_lo, 1);
let result_lo = vsetq_lane_s32(out2, result_lo, 2);
let result_lo = vsetq_lane_s32(out3, result_lo, 3);
let result_hi = vsetq_lane_s32(out4, vdupq_n_s32(0), 0);
let result_hi = vsetq_lane_s32(out5, result_hi, 1);
let result_hi = vsetq_lane_s32(out6, result_hi, 2);
let result_hi = vsetq_lane_s32(out7, result_hi, 3);
(result_lo, result_hi)
}
/// 1-D column DCT for one 8-sample column held as two i32x4 vectors
/// (lo = d0..d3, hi = d4..d7). Same structure as the row variant, but
/// applies the final rounding descale by PASS1_BITS + 3, matching the
/// second pass of `dct_2d_integer`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn dct_column_neon_vectorized(lo: int32x4_t, hi: int32x4_t) -> (int32x4_t, int32x4_t) {
// Reverse hi so lane i pairs d_i with d_{7-i} (see row variant).
let hi_rev = vrev64q_s32(hi);
let hi_rev = vextq_s32(hi_rev, hi_rev, 2);
let even_sum = vaddq_s32(lo, hi_rev);
let even_diff = vsubq_s32(lo, hi_rev);
let tmp0 = vgetq_lane_s32(even_sum, 0);
let tmp1 = vgetq_lane_s32(even_sum, 1);
let tmp2 = vgetq_lane_s32(even_sum, 2);
let tmp3 = vgetq_lane_s32(even_sum, 3);
let tmp10 = tmp0 + tmp3;
let tmp11 = tmp1 + tmp2;
let tmp12 = tmp0 - tmp3;
let tmp13 = tmp1 - tmp2;
let tmp0_odd = vgetq_lane_s32(even_diff, 0);
let tmp1_odd = vgetq_lane_s32(even_diff, 1);
let tmp2_odd = vgetq_lane_s32(even_diff, 2);
let tmp3_odd = vgetq_lane_s32(even_diff, 3);
// Final descale: drop pass-1 headroom + 3 bits, rounding to nearest.
let descale = PASS1_BITS + 3;
let round = 1 << (descale - 1);
let out0 = (tmp10 + tmp11 + round) >> descale;
let out4 = (tmp10 - tmp11 + round) >> descale;
let z1 = fix_mul(tmp12 + tmp13, FIX_0_541196100);
let out2 = (z1 + fix_mul(tmp12, FIX_0_765366865) + round) >> descale;
let out6 = (z1 - fix_mul(tmp13, FIX_1_847759065) + round) >> descale;
// Odd part: same rotation network as the scalar column pass.
let tmp10_o = tmp0_odd + tmp3_odd;
let tmp11_o = tmp1_odd + tmp2_odd;
let tmp12_o = tmp0_odd + tmp2_odd;
let tmp13_o = tmp1_odd + tmp3_odd;
let z1_o = fix_mul(tmp12_o + tmp13_o, FIX_1_175875602);
let tmp0_m = fix_mul(tmp0_odd, FIX_1_501321110);
let tmp1_m = fix_mul(tmp1_odd, FIX_3_072711026);
let tmp2_m = fix_mul(tmp2_odd, FIX_2_053119869);
let tmp3_m = fix_mul(tmp3_odd, FIX_0_298631336);
let tmp10_m = fix_mul(tmp10_o, -FIX_0_899976223);
let tmp11_m = fix_mul(tmp11_o, -FIX_2_562915447);
let tmp12_m = fix_mul(tmp12_o, -FIX_0_390180644) + z1_o;
let tmp13_m = fix_mul(tmp13_o, -FIX_1_961570560) + z1_o;
let out1 = (tmp0_m + tmp10_m + tmp12_m + round) >> descale;
let out3 = (tmp1_m + tmp11_m + tmp13_m + round) >> descale;
let out5 = (tmp2_m + tmp11_m + tmp12_m + round) >> descale;
let out7 = (tmp3_m + tmp10_m + tmp13_m + round) >> descale;
// Repack the eight scalar outputs into two vectors.
let result_lo = vsetq_lane_s32(out0, vdupq_n_s32(0), 0);
let result_lo = vsetq_lane_s32(out1, result_lo, 1);
let result_lo = vsetq_lane_s32(out2, result_lo, 2);
let result_lo = vsetq_lane_s32(out3, result_lo, 3);
let result_hi = vsetq_lane_s32(out4, vdupq_n_s32(0), 0);
let result_hi = vsetq_lane_s32(out5, result_hi, 1);
let result_hi = vsetq_lane_s32(out6, result_hi, 2);
let result_hi = vsetq_lane_s32(out7, result_hi, 3);
(result_lo, result_hi)
}
/// Transpose an 8x8 i32 matrix held as 8 rows of two i32x4 vectors.
///
/// Works quadrant-by-quadrant: each 4x4 quadrant is transposed with a
/// 32-bit interleave (vtrn1q/vtrn2q_s32) followed by a 64-bit interleave
/// (vtrn1q/vtrn2q_s64 after reinterpreting), then the off-diagonal
/// quadrants (tr_* and bl_*) swap places in the assembled result.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn transpose_8x8_neon(rows: &[int32x4x2_t; 8]) -> [int32x4x2_t; 8] {
// Top-left 4x4 quadrant (rows 0-3, cols 0-3).
let tl0 = vtrn1q_s32(rows[0].0, rows[1].0);
let tl1 = vtrn2q_s32(rows[0].0, rows[1].0);
let tl2 = vtrn1q_s32(rows[2].0, rows[3].0);
let tl3 = vtrn2q_s32(rows[2].0, rows[3].0);
let tl_r0 = vreinterpretq_s32_s64(vtrn1q_s64(
vreinterpretq_s64_s32(tl0),
vreinterpretq_s64_s32(tl2),
));
let tl_r1 = vreinterpretq_s32_s64(vtrn1q_s64(
vreinterpretq_s64_s32(tl1),
vreinterpretq_s64_s32(tl3),
));
let tl_r2 = vreinterpretq_s32_s64(vtrn2q_s64(
vreinterpretq_s64_s32(tl0),
vreinterpretq_s64_s32(tl2),
));
let tl_r3 = vreinterpretq_s32_s64(vtrn2q_s64(
vreinterpretq_s64_s32(tl1),
vreinterpretq_s64_s32(tl3),
));
// Top-right 4x4 quadrant (rows 0-3, cols 4-7).
let tr0 = vtrn1q_s32(rows[0].1, rows[1].1);
let tr1 = vtrn2q_s32(rows[0].1, rows[1].1);
let tr2 = vtrn1q_s32(rows[2].1, rows[3].1);
let tr3 = vtrn2q_s32(rows[2].1, rows[3].1);
let tr_r0 = vreinterpretq_s32_s64(vtrn1q_s64(
vreinterpretq_s64_s32(tr0),
vreinterpretq_s64_s32(tr2),
));
let tr_r1 = vreinterpretq_s32_s64(vtrn1q_s64(
vreinterpretq_s64_s32(tr1),
vreinterpretq_s64_s32(tr3),
));
let tr_r2 = vreinterpretq_s32_s64(vtrn2q_s64(
vreinterpretq_s64_s32(tr0),
vreinterpretq_s64_s32(tr2),
));
let tr_r3 = vreinterpretq_s32_s64(vtrn2q_s64(
vreinterpretq_s64_s32(tr1),
vreinterpretq_s64_s32(tr3),
));
// Bottom-left 4x4 quadrant (rows 4-7, cols 0-3).
let bl0 = vtrn1q_s32(rows[4].0, rows[5].0);
let bl1 = vtrn2q_s32(rows[4].0, rows[5].0);
let bl2 = vtrn1q_s32(rows[6].0, rows[7].0);
let bl3 = vtrn2q_s32(rows[6].0, rows[7].0);
let bl_r0 = vreinterpretq_s32_s64(vtrn1q_s64(
vreinterpretq_s64_s32(bl0),
vreinterpretq_s64_s32(bl2),
));
let bl_r1 = vreinterpretq_s32_s64(vtrn1q_s64(
vreinterpretq_s64_s32(bl1),
vreinterpretq_s64_s32(bl3),
));
let bl_r2 = vreinterpretq_s32_s64(vtrn2q_s64(
vreinterpretq_s64_s32(bl0),
vreinterpretq_s64_s32(bl2),
));
let bl_r3 = vreinterpretq_s32_s64(vtrn2q_s64(
vreinterpretq_s64_s32(bl1),
vreinterpretq_s64_s32(bl3),
));
// Bottom-right 4x4 quadrant (rows 4-7, cols 4-7).
let br0 = vtrn1q_s32(rows[4].1, rows[5].1);
let br1 = vtrn2q_s32(rows[4].1, rows[5].1);
let br2 = vtrn1q_s32(rows[6].1, rows[7].1);
let br3 = vtrn2q_s32(rows[6].1, rows[7].1);
let br_r0 = vreinterpretq_s32_s64(vtrn1q_s64(
vreinterpretq_s64_s32(br0),
vreinterpretq_s64_s32(br2),
));
let br_r1 = vreinterpretq_s32_s64(vtrn1q_s64(
vreinterpretq_s64_s32(br1),
vreinterpretq_s64_s32(br3),
));
let br_r2 = vreinterpretq_s32_s64(vtrn2q_s64(
vreinterpretq_s64_s32(br0),
vreinterpretq_s64_s32(br2),
));
let br_r3 = vreinterpretq_s32_s64(vtrn2q_s64(
vreinterpretq_s64_s32(br1),
vreinterpretq_s64_s32(br3),
));
// Assemble: transposed row i pairs the left quadrants' column i with
// the corresponding right/bottom quadrant (off-diagonal blocks swap).
[
int32x4x2_t(tl_r0, bl_r0),
int32x4x2_t(tl_r1, bl_r1),
int32x4x2_t(tl_r2, bl_r2),
int32x4x2_t(tl_r3, bl_r3),
int32x4x2_t(tr_r0, br_r0),
int32x4x2_t(tr_r1, br_r1),
int32x4x2_t(tr_r2, br_r2),
int32x4x2_t(tr_r3, br_r3),
]
}
/// Forward 8x8 integer DCT with a constant-block shortcut and a
/// per-architecture SIMD dispatch. Same output scaling as
/// `dct_2d_integer`.
#[inline]
pub fn dct_2d_fast(block: &[i16; 64]) -> [i32; 64] {
    // Shortcut: a flat block transforms to a lone DC coefficient of
    // value * 8 (all AC terms are exactly zero), so skip the transform.
    let first = block[0];
    if block.iter().all(|&sample| sample == first) {
        let mut dc_only = [0i32; 64];
        dc_only[0] = i32::from(first) * 8;
        return dc_only;
    }
    #[cfg(target_arch = "x86_64")]
    {
        // Runtime dispatch: AVX2 path when available, scalar otherwise.
        if is_x86_feature_detected!("avx2") {
            return unsafe { crate::simd::x86_64::dct_2d_avx2(block) };
        }
        dct_2d_integer(block)
    }
    #[cfg(target_arch = "aarch64")]
    {
        // NEON is baseline on AArch64; no runtime detection needed.
        dct_2d_integer_neon(block)
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        dct_2d_integer(block)
    }
}
/// Quantize a block of DCT coefficients, dividing each by its quantizer
/// step with rounding half away from zero.
///
/// A zero entry in `quant_table` would previously divide by zero and
/// panic; steps are now clamped to at least 1 (identity quantization)
/// so malformed tables are handled gracefully.
pub fn quantize_block_integer(dct: &[i32; 64], quant_table: &[u16; 64]) -> [i16; 64] {
    let mut result = [0i16; 64];
    for (out, (&coef, &step)) in result.iter_mut().zip(dct.iter().zip(quant_table.iter())) {
        // Guard against a zero quantizer step (division by zero panic).
        let q = i32::from(step).max(1);
        // Bias by q/2 toward the sign of `coef` so the truncating
        // division rounds half away from zero.
        let quantized = if coef >= 0 {
            (coef + (q >> 1)) / q
        } else {
            (coef - (q >> 1)) / q
        };
        *out = quantized as i16;
    }
    result
}
const A1: f32 = FRAC_1_SQRT_2; const A2: f32 = 0.541_196_1; const A3: f32 = FRAC_1_SQRT_2; const A4: f32 = 1.306_562_9; const A5: f32 = 0.382_683_43;
// Post-scale factors folded into the AAN 1-D pass:
// S[0] = 1 / (2*sqrt(2)), S[k] = 1 / (4*cos(k*pi/16)) for k > 0.
const S: [f32; 8] = [
0.353_553_4, 0.254_897_8, 0.270_598_1, 0.300_672_4, 0.353_553_4, 0.449_988_1, 0.653_281_5, 1.281_457_8, ];
/// Two-dimensional 8x8 forward DCT (floating-point AAN).
///
/// Runs the scaled 1-D AAN transform over every row, then over every
/// column of the intermediate, returning coefficients in row-major
/// order.
pub fn dct_2d(block: &[f32; 64]) -> [f32; 64] {
    let mut temp = [0.0f32; 64];
    // Pass 1: transform each row of the input into `temp`.
    for (src, dst) in block.chunks_exact(8).zip(temp.chunks_exact_mut(8)) {
        let mut scratch = [0.0f32; 8];
        scratch.copy_from_slice(src);
        aan_dct_1d(&mut scratch);
        dst.copy_from_slice(&scratch);
    }
    // Pass 2: gather each column, transform it, scatter it back out.
    let mut result = [0.0f32; 64];
    for col in 0..8 {
        let mut column = [0.0f32; 8];
        for (row, slot) in column.iter_mut().enumerate() {
            *slot = temp[row * 8 + col];
        }
        aan_dct_1d(&mut column);
        for (row, &value) in column.iter().enumerate() {
            result[row * 8 + col] = value;
        }
    }
    result
}
/// In-place 1-D 8-point forward DCT, AAN (Arai-Agui-Nakajima)
/// factorization, with the S[] output scaling applied at the end.
///
/// The statement order is the AAN butterfly network; do not reorder.
#[inline]
fn aan_dct_1d(data: &mut [f32; 8]) {
// Stage 1: symmetric sums and antisymmetric differences.
let tmp0 = data[0] + data[7];
let tmp7 = data[0] - data[7];
let tmp1 = data[1] + data[6];
let tmp6 = data[1] - data[6];
let tmp2 = data[2] + data[5];
let tmp5 = data[2] - data[5];
let tmp3 = data[3] + data[4];
let tmp4 = data[3] - data[4];
// Even part.
let tmp10 = tmp0 + tmp3;
let tmp13 = tmp0 - tmp3;
let tmp11 = tmp1 + tmp2;
let tmp12 = tmp1 - tmp2;
data[0] = tmp10 + tmp11;
data[4] = tmp10 - tmp11;
// Rotation by cos(pi/4) produces outputs 2 and 6.
let z1 = (tmp12 + tmp13) * A1; data[2] = tmp13 + z1;
data[6] = tmp13 - z1;
// Odd part: the AAN trick shares one multiply (z5) between two
// rotations, needing 5 multiplies total for the odd outputs.
let tmp10 = tmp4 + tmp5;
let tmp11 = tmp5 + tmp6;
let tmp12 = tmp6 + tmp7;
let z5 = (tmp10 - tmp12) * A5; let z2 = tmp10 * A2 + z5; let z4 = tmp12 * A4 + z5; let z3 = tmp11 * A3;
let z11 = tmp7 + z3;
let z13 = tmp7 - z3;
data[5] = z13 + z2;
data[3] = z13 - z2;
data[1] = z11 + z4;
data[7] = z11 - z4;
// Fold the AAN scale factors into the outputs.
for i in 0..8 {
data[i] *= S[i];
}
}
/// Cosine basis for the reference IDCT:
/// COS_TABLE[n][k] = cos((2n + 1) * k * pi / 16), built at compile time.
const COS_TABLE: [[f32; 8]; 8] = precompute_cos_table();
/// Fill the 8x8 cosine basis table. `while` loops are used because
/// `for` is not available in `const fn`.
const fn precompute_cos_table() -> [[f32; 8]; 8] {
    let mut table = [[0.0f32; 8]; 8];
    let mut n = 0;
    while n < 8 {
        let mut k = 0;
        while k < 8 {
            let angle = ((2 * n + 1) * k) as f32 * PI / 16.0;
            table[n][k] = cos_approx(angle);
            k += 1;
        }
        n += 1;
    }
    table
}
/// Taylor-series cosine usable in `const` context (`f32::cos` is not
/// `const fn`).
///
/// The argument is first range-reduced into [-pi, pi], then evaluated
/// with the Maclaurin series for cos truncated after the x^10 term.
/// Truncating at x^8 (as before) left an error of ~2.4e-2 near |x| = pi,
/// which is significant for the DCT basis table; the extra term brings
/// the worst-case error below 2e-3.
const fn cos_approx(x: f32) -> f32 {
    // Range-reduce so the polynomial stays in its accurate region.
    let mut x = x;
    while x > PI {
        x -= 2.0 * PI;
    }
    while x < -PI {
        x += 2.0 * PI;
    }
    let x2 = x * x;
    let x4 = x2 * x2;
    let x6 = x4 * x2;
    let x8 = x4 * x4;
    let x10 = x8 * x2;
    // cos x = 1 - x^2/2! + x^4/4! - x^6/6! + x^8/8! - x^10/10! + ...
    1.0 - x2 / 2.0 + x4 / 24.0 - x6 / 720.0 + x8 / 40320.0 - x10 / 3_628_800.0
}
// DCT normalization weights used by the reference IDCT:
// alpha_0 = 1/sqrt(2), alpha_k = 1 for k > 0.
const ALPHA: [f32; 8] = [
FRAC_1_SQRT_2, 1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
];
/// Reference 2-D inverse 8x8 DCT: a 1-D IDCT down every column, then
/// across every row. Marked `dead_code` — it is exercised from tests.
#[allow(dead_code)]
pub fn idct_2d(block: &[f32; 64]) -> [f32; 64] {
    // Column pass into an intermediate buffer.
    let mut temp = [0.0f32; 64];
    for col in 0..8 {
        let mut col_in = [0.0f32; 8];
        for (row, slot) in col_in.iter_mut().enumerate() {
            *slot = block[row * 8 + col];
        }
        let mut col_out = [0.0f32; 8];
        idct_1d(&col_in, &mut col_out);
        for (row, &value) in col_out.iter().enumerate() {
            temp[row * 8 + col] = value;
        }
    }
    // Row pass, written straight into the output.
    let mut result = [0.0f32; 64];
    for (src, dst) in temp.chunks_exact(8).zip(result.chunks_exact_mut(8)) {
        idct_1d(src, dst);
    }
    result
}
/// 1-D 8-point inverse DCT by direct summation:
/// out[n] = 0.5 * sum_k alpha_k * in[k] * cos((2n+1) k pi / 16),
/// with the cosines taken from the precomputed COS_TABLE.
fn idct_1d(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), 8);
    debug_assert_eq!(output.len(), 8);
    for n in 0..8 {
        let sum = (0..8).fold(0.0f32, |acc, k| acc + ALPHA[k] * input[k] * COS_TABLE[n][k]);
        output[n] = 0.5 * sum;
    }
}
#[cfg(test)]
mod tests {
// Unit tests covering the float AAN DCT, the reference IDCT, the
// fixed-point integer DCT (scalar and dispatch wrapper), and the
// quantizer. Tolerances are loose where fixed-point rounding or the
// const-fn cosine approximation introduces small errors.
use super::*;
// Float DCT: an all-zero block must transform to all zeros.
#[test]
fn test_dct_dc_component() {
let block = [0.0f32; 64];
let result = dct_2d(&block);
for &val in &result {
assert!((val).abs() < 0.001);
}
}
// Float DCT: a constant block should concentrate energy in DC.
#[test]
fn test_dct_constant_block() {
let block = [100.0f32; 64];
let result = dct_2d(&block);
assert!(result[0].abs() > 100.0);
for &val in result.iter().skip(1) {
assert!(val.abs() < 5.0, "AC component too large: {val}");
}
}
// DCT followed by the reference IDCT should approximately recover the
// input (tolerance absorbs the Taylor-series cosine table error).
#[test]
fn test_dct_idct_roundtrip() {
let mut block = [0.0f32; 64];
for (i, item) in block.iter_mut().enumerate() {
*item = (i as f32 * 4.0) - 128.0;
}
let dct = dct_2d(&block);
let recovered = idct_2d(&dct);
for i in 0..64 {
assert!(
(block[i] - recovered[i]).abs() < 5.0,
"Mismatch at {}: {} vs {}",
i,
block[i],
recovered[i]
);
}
}
// Spot-check the const-fn cosine table against std's cos.
#[test]
fn test_cos_table_values() {
assert!((COS_TABLE[0][0] - 1.0).abs() < 0.0001);
assert!((COS_TABLE[0][2] - (PI / 8.0).cos()).abs() < 0.001);
}
// Integer DCT: zero input must give exactly zero output.
#[test]
fn test_integer_dct_zeros() {
let block = [0i16; 64];
let result = dct_2d_integer(&block);
for &val in &result {
assert_eq!(val, 0);
}
}
// Integer DCT: constant block -> large DC, AC within rounding of zero.
#[test]
fn test_integer_dct_constant_block() {
let block = [100i16; 64]; let result = dct_2d_integer(&block);
assert!(result[0] > 100, "DC too small: {}", result[0]);
for (i, &val) in result.iter().enumerate().skip(1) {
assert!(val.abs() <= 1, "AC component at {i} too large: {val}");
}
}
// A smooth gradient should put its energy into low frequencies.
#[test]
fn test_integer_dct_energy_preservation() {
let mut block = [0i16; 64];
for row in 0..8 {
for col in 0..8 {
let val = (row as i32 + col as i32) * 16 - 112;
block[row * 8 + col] = val.clamp(-128, 127) as i16;
}
}
let result = dct_2d_integer(&block);
assert!(
result[0].abs() < 50,
"DC coefficient unexpectedly large: {}",
result[0]
);
let low_freq_energy: i64 = result[..16].iter().map(|&x| (x as i64).pow(2)).sum();
let high_freq_energy: i64 = result[48..].iter().map(|&x| (x as i64).pow(2)).sum();
assert!(
low_freq_energy > high_freq_energy,
"Low freq energy {low_freq_energy} should exceed high freq energy {high_freq_energy}"
);
}
// Quantization should not zero out a strong DC coefficient.
#[test]
fn test_integer_quantize() {
let mut block = [0i16; 64];
block[0] = 100;
let dct = dct_2d_integer(&block);
let mut quant = [16u16; 64];
quant[0] = 16;
let quantized = quantize_block_integer(&dct, &quant);
assert!(quantized[0] != 0, "DC was quantized to zero");
}
// The dispatching wrapper should agree with the scalar integer DCT
// (exactly on x86/scalar; loosely on aarch64 where NEON is used).
#[test]
fn test_dct_2d_fast_matches_integer() {
let mut block = [0i16; 64];
for i in 0..64 {
block[i] = ((i as i32 * 7) % 256 - 128) as i16;
}
let fast_result = dct_2d_fast(&block);
let int_result = dct_2d_integer(&block);
#[cfg(not(target_arch = "aarch64"))]
{
assert_eq!(fast_result, int_result);
}
#[cfg(target_arch = "aarch64")]
{
assert!(
(fast_result[0] - int_result[0]).abs() < 5,
"DC mismatch: {} vs {}",
fast_result[0],
int_result[0]
);
}
}
// The constant-block shortcut must produce DC = value * 8, AC = 0,
// for zero, positive, negative, and extreme constants.
#[test]
fn test_dct_2d_fast_constant_block_shortcut() {
let block_zero = [0i16; 64];
let result = dct_2d_fast(&block_zero);
assert_eq!(result[0], 0, "DC should be 0 for zero block");
for (i, &val) in result.iter().enumerate().skip(1) {
assert_eq!(val, 0, "AC component at {i} should be 0 for constant block");
}
let block_pos = [50i16; 64];
let result = dct_2d_fast(&block_pos);
assert_eq!(
result[0],
50 * 8,
"DC should be value * 8 for constant block"
);
for (i, &val) in result.iter().enumerate().skip(1) {
assert_eq!(val, 0, "AC component at {i} should be 0 for constant block");
}
let block_neg = [-30i16; 64];
let result = dct_2d_fast(&block_neg);
assert_eq!(
result[0],
-30 * 8,
"DC should be value * 8 for negative constant"
);
for (i, &val) in result.iter().enumerate().skip(1) {
assert_eq!(val, 0, "AC component at {i} should be 0 for constant block");
}
let block_max = [127i16; 64];
let result = dct_2d_fast(&block_max);
assert_eq!(
result[0],
127 * 8,
"DC should be value * 8 for max constant"
);
let block_min = [-128i16; 64];
let result = dct_2d_fast(&block_min);
assert_eq!(
result[0],
-128 * 8,
"DC should be value * 8 for min constant"
);
}
// A non-constant block must take the full transform path.
#[test]
fn test_dct_2d_fast_non_constant_block() {
let mut block = [0i16; 64];
for i in 0..64 {
block[i] = (i as i16) - 32; }
let result = dct_2d_fast(&block);
let non_zero_ac = result.iter().skip(1).filter(|&&v| v != 0).count();
assert!(
non_zero_ac > 0,
"Gradient block should have non-zero AC components"
);
}
// Quantization rounds half away from zero for both signs.
#[test]
fn test_quantize_block_integer_negative_values() {
let mut dct = [0i32; 64];
dct[0] = 100;
dct[1] = -50;
dct[2] = 75;
dct[3] = -25;
let quant = [16u16; 64];
let quantized = quantize_block_integer(&dct, &quant);
assert_eq!(quantized[0], 6);
assert_eq!(quantized[1], -3);
assert_eq!(quantized[2], 5);
assert_eq!(quantized[3], -2);
}
// Coarser quantization tables must yield smaller quantized values.
#[test]
fn test_quantize_block_integer_various_quant_tables() {
let dct = [100i32; 64];
let quant_low = [8u16; 64];
let quant_high = [64u16; 64];
let q_low = quantize_block_integer(&dct, &quant_low);
let q_high = quantize_block_integer(&dct, &quant_high);
assert!(q_low[0] > q_high[0]);
assert_eq!(q_low[0], (100 + 4) / 8); assert_eq!(q_high[0], (100 + 32) / 64); }
// Values exactly at and just below the rounding threshold.
#[test]
fn test_quantize_block_integer_edge_values() {
let mut dct = [0i32; 64];
dct[0] = 16; dct[1] = 8; dct[2] = 7;
let quant = [16u16; 64];
let quantized = quantize_block_integer(&dct, &quant);
assert_eq!(quantized[0], 1);
assert_eq!(quantized[1], 1);
assert_eq!(quantized[2], 0);
}
// High-frequency patterns: checkerboard and stripes.
#[test]
fn test_integer_dct_checkerboard() {
let mut block = [0i16; 64];
for row in 0..8 {
for col in 0..8 {
block[row * 8 + col] = if (row + col) % 2 == 0 { 100 } else { -100 };
}
}
let result = dct_2d_integer(&block);
let ac_energy: i64 = result[1..].iter().map(|&x| (x as i64).pow(2)).sum();
assert!(ac_energy > 0, "Checkerboard should have AC energy");
}
#[test]
fn test_integer_dct_horizontal_stripes() {
let mut block = [0i16; 64];
for row in 0..8 {
let val = if row % 2 == 0 { 100i16 } else { -100 };
for col in 0..8 {
block[row * 8 + col] = val;
}
}
let result = dct_2d_integer(&block);
assert!(result[0].abs() < 10, "DC should be small: {}", result[0]);
}
#[test]
fn test_integer_dct_vertical_stripes() {
let mut block = [0i16; 64];
for row in 0..8 {
for col in 0..8 {
block[row * 8 + col] = if col % 2 == 0 { 100 } else { -100 };
}
}
let result = dct_2d_integer(&block);
assert!(result[0].abs() < 10, "DC should be small: {}", result[0]);
}
// Saturating inputs should not overflow or flip sign.
#[test]
fn test_integer_dct_extreme_values() {
let block_max = [127i16; 64];
let result_max = dct_2d_integer(&block_max);
assert!(
result_max[0] > 0,
"DC should be positive for positive block"
);
let block_min = [-128i16; 64];
let result_min = dct_2d_integer(&block_min);
assert!(
result_min[0] < 0,
"DC should be negative for negative block"
);
}
#[test]
fn test_float_dct_gradient() {
let mut block = [0.0f32; 64];
for row in 0..8 {
for col in 0..8 {
block[row * 8 + col] = (row + col) as f32 * 10.0;
}
}
let result = dct_2d(&block);
let _avg: f32 = block.iter().sum::<f32>() / 64.0;
assert!(result[0] > 0.0, "DC should be positive");
}
#[test]
fn test_float_dct_single_pixel() {
let mut block = [0.0f32; 64];
block[0] = 100.0;
let result = dct_2d(&block);
let total_energy: f32 = result.iter().map(|&x| x * x).sum();
assert!(total_energy > 0.0, "Should have energy");
}
// A DC-only spectrum must reconstruct to a uniform block.
#[test]
fn test_idct_2d_dc_only() {
let mut block = [0.0f32; 64];
block[0] = 100.0;
let result = idct_2d(&block);
let avg = result.iter().sum::<f32>() / 64.0;
for &val in &result {
assert!(
(val - avg).abs() < 1.0,
"DC-only IDCT should produce uniform values"
);
}
}
// cos(0) column of the table should be 1 for every n.
#[test]
fn test_cos_table_symmetry() {
for i in 0..8 {
assert!(
(COS_TABLE[i][0] - 1.0).abs() < 0.01,
"COS_TABLE[{i}][0] should be ~1.0"
);
}
}
#[test]
fn test_alpha_values() {
assert!((ALPHA[0] - FRAC_1_SQRT_2).abs() < 0.0001);
for &a in &ALPHA[1..] {
assert!((a - 1.0).abs() < 0.0001);
}
}
#[test]
fn test_aan_scale_factors() {
for &s in &S {
assert!(s > 0.0, "Scale factor should be positive");
assert!(s < 2.0, "Scale factor should be less than 2");
}
}
// fix_mul sanity: 8192 represents 1.0 in 2^13 fixed point.
#[test]
fn test_fix_mul_basic() {
let result = fix_mul(8192, 8192); assert_eq!(result, 8192);
let result = fix_mul(16384, 8192); assert_eq!(result, 16384);
let result = fix_mul(8192, 4096); assert_eq!(result, 4096); }
}