#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
const PASS1_BITS: i32 = 2;
const CONST_BITS: i32 = 13;
const FIX_0_298: i32 = 2446;
const FIX_0_390: i32 = 3196;
const FIX_0_541: i32 = 4433;
const FIX_0_765: i32 = 6270;
const FIX_0_899: i32 = 7373;
const FIX_1_175: i32 = 9633;
const FIX_1_501: i32 = 12299;
const FIX_1_847: i32 = 15137;
const FIX_1_961: i32 = 16069;
const FIX_2_053: i32 = 16819;
const FIX_2_562: i32 = 20995;
const FIX_3_072: i32 = 25172;
macro_rules! transpose4x4 {
($r0:expr, $r1:expr, $r2:expr, $r3:expr) => {{
let t0 = _mm_unpacklo_epi32($r0, $r1);
let t1 = _mm_unpackhi_epi32($r0, $r1);
let t2 = _mm_unpacklo_epi32($r2, $r3);
let t3 = _mm_unpackhi_epi32($r2, $r3);
$r0 = _mm_unpacklo_epi64(t0, t2);
$r1 = _mm_unpackhi_epi64(t0, t2);
$r2 = _mm_unpacklo_epi64(t1, t3);
$r3 = _mm_unpackhi_epi64(t1, t3);
}};
}
#[inline(always)]
unsafe fn even_part(
s0: __m128i,
s2: __m128i,
s4: __m128i,
s6: __m128i,
bias: __m128i,
) -> (__m128i, __m128i, __m128i, __m128i) {
let tmp0 = _mm_slli_epi32::<CONST_BITS>(s0);
let tmp2 = _mm_slli_epi32::<CONST_BITS>(s4);
let tmp10 = _mm_add_epi32(tmp0, tmp2);
let tmp11 = _mm_sub_epi32(tmp0, tmp2);
let z1 = _mm_add_epi32(s2, s6);
let fix_0541 = _mm_set1_epi32(FIX_0_541);
let fix_0765 = _mm_set1_epi32(FIX_0_765);
let fix_1847 = _mm_set1_epi32(FIX_1_847);
let tmp13 = _mm_add_epi32(_mm_mullo_epi32(z1, fix_0541), _mm_mullo_epi32(s2, fix_0765));
let tmp12 = _mm_sub_epi32(_mm_mullo_epi32(z1, fix_0541), _mm_mullo_epi32(s6, fix_1847));
let t0 = _mm_add_epi32(_mm_add_epi32(tmp10, tmp13), bias);
let t3 = _mm_add_epi32(_mm_sub_epi32(tmp10, tmp13), bias);
let t1 = _mm_add_epi32(_mm_add_epi32(tmp11, tmp12), bias);
let t2 = _mm_add_epi32(_mm_sub_epi32(tmp11, tmp12), bias);
(t0, t1, t2, t3)
}
#[inline(always)]
unsafe fn odd_part(
s1: __m128i,
s3: __m128i,
s5: __m128i,
s7: __m128i,
) -> (__m128i, __m128i, __m128i, __m128i) {
let z1 = _mm_add_epi32(s7, s1);
let z2 = _mm_add_epi32(s5, s3);
let z3 = _mm_add_epi32(s7, s3);
let z4 = _mm_add_epi32(s5, s1);
let z5 = _mm_mullo_epi32(_mm_add_epi32(z3, z4), _mm_set1_epi32(FIX_1_175));
let p7 = _mm_mullo_epi32(s7, _mm_set1_epi32(FIX_0_298));
let p5 = _mm_mullo_epi32(s5, _mm_set1_epi32(FIX_2_053));
let p3 = _mm_mullo_epi32(s3, _mm_set1_epi32(FIX_3_072));
let p1 = _mm_mullo_epi32(s1, _mm_set1_epi32(FIX_1_501));
let z1 = _mm_mullo_epi32(z1, _mm_set1_epi32(-FIX_0_899));
let z2 = _mm_mullo_epi32(z2, _mm_set1_epi32(-FIX_2_562));
let z3 = _mm_add_epi32(_mm_mullo_epi32(z3, _mm_set1_epi32(-FIX_1_961)), z5);
let z4 = _mm_add_epi32(_mm_mullo_epi32(z4, _mm_set1_epi32(-FIX_0_390)), z5);
(
_mm_add_epi32(_mm_add_epi32(p7, z1), z3),
_mm_add_epi32(_mm_add_epi32(p5, z2), z4),
_mm_add_epi32(_mm_add_epi32(p3, z2), z3),
_mm_add_epi32(_mm_add_epi32(p1, z1), z4),
)
}
pub fn idct_8x8_sse41(coeffs: &[i32; 64], output: &mut [u8], stride: usize) {
unsafe { idct_8x8_sse41_inner(coeffs, output, stride) }
}
#[target_feature(enable = "sse4.1")]
unsafe fn idct_8x8_sse41_inner(coeffs: &[i32; 64], output: &mut [u8], stride: usize) {
let mut workspace = [0i32; 64];
let half = _mm_set1_epi32(1 << (CONST_BITS - 1));
for col_group in 0..2 {
let base = col_group * 4;
let s0 = _mm_loadu_si128(coeffs.as_ptr().add(base) as *const __m128i);
let s1 = _mm_loadu_si128(coeffs.as_ptr().add(8 + base) as *const __m128i);
let s2 = _mm_loadu_si128(coeffs.as_ptr().add(16 + base) as *const __m128i);
let s3 = _mm_loadu_si128(coeffs.as_ptr().add(24 + base) as *const __m128i);
let s4 = _mm_loadu_si128(coeffs.as_ptr().add(32 + base) as *const __m128i);
let s5 = _mm_loadu_si128(coeffs.as_ptr().add(40 + base) as *const __m128i);
let s6 = _mm_loadu_si128(coeffs.as_ptr().add(48 + base) as *const __m128i);
let s7 = _mm_loadu_si128(coeffs.as_ptr().add(56 + base) as *const __m128i);
let (t0, t1, t2, t3) = even_part(s0, s2, s4, s6, half);
let (o0, o1, o2, o3) = odd_part(s1, s3, s5, s7);
let shift = 13 - 2; let r0 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_add_epi32(t0, o3));
let r7 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_sub_epi32(t0, o3));
let r1 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_add_epi32(t1, o2));
let r6 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_sub_epi32(t1, o2));
let r2 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_add_epi32(t2, o1));
let r5 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_sub_epi32(t2, o1));
let r3 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_add_epi32(t3, o0));
let r4 = _mm_srai_epi32::<{ 13 - 2 }>(_mm_sub_epi32(t3, o0));
_mm_storeu_si128(workspace.as_mut_ptr().add(base) as *mut __m128i, r0);
_mm_storeu_si128(workspace.as_mut_ptr().add(8 + base) as *mut __m128i, r1);
_mm_storeu_si128(workspace.as_mut_ptr().add(16 + base) as *mut __m128i, r2);
_mm_storeu_si128(workspace.as_mut_ptr().add(24 + base) as *mut __m128i, r3);
_mm_storeu_si128(workspace.as_mut_ptr().add(32 + base) as *mut __m128i, r4);
_mm_storeu_si128(workspace.as_mut_ptr().add(40 + base) as *mut __m128i, r5);
_mm_storeu_si128(workspace.as_mut_ptr().add(48 + base) as *mut __m128i, r6);
_mm_storeu_si128(workspace.as_mut_ptr().add(56 + base) as *mut __m128i, r7);
let _ = shift; }
let range_shift = CONST_BITS + PASS1_BITS + 3;
let bias_val = (1 << (range_shift - 1)) + (128 << range_shift);
let bias = _mm_set1_epi32(bias_val);
let zero = _mm_setzero_si128();
let max255 = _mm_set1_epi32(255);
for row_group in 0..2 {
let row_base = row_group * 4 * 8;
let mut r0l = _mm_loadu_si128(workspace.as_ptr().add(row_base) as *const __m128i);
let mut r1l = _mm_loadu_si128(workspace.as_ptr().add(row_base + 8) as *const __m128i);
let mut r2l = _mm_loadu_si128(workspace.as_ptr().add(row_base + 16) as *const __m128i);
let mut r3l = _mm_loadu_si128(workspace.as_ptr().add(row_base + 24) as *const __m128i);
transpose4x4!(r0l, r1l, r2l, r3l);
let mut r0r = _mm_loadu_si128(workspace.as_ptr().add(row_base + 4) as *const __m128i);
let mut r1r = _mm_loadu_si128(workspace.as_ptr().add(row_base + 8 + 4) as *const __m128i);
let mut r2r = _mm_loadu_si128(workspace.as_ptr().add(row_base + 16 + 4) as *const __m128i);
let mut r3r = _mm_loadu_si128(workspace.as_ptr().add(row_base + 24 + 4) as *const __m128i);
transpose4x4!(r0r, r1r, r2r, r3r);
let (t0, t1, t2, t3) = even_part(r0l, r2l, r0r, r2r, bias);
let (o0, o1, o2, o3) = odd_part(r1l, r3l, r1r, r3r);
let clamp = |v: __m128i| -> __m128i {
let shifted = _mm_srai_epi32::<{ 13 + 2 + 3 }>(v);
_mm_min_epi32(_mm_max_epi32(shifted, zero), max255)
};
let out0 = clamp(_mm_add_epi32(t0, o3));
let out7 = clamp(_mm_sub_epi32(t0, o3));
let out1 = clamp(_mm_add_epi32(t1, o2));
let out6 = clamp(_mm_sub_epi32(t1, o2));
let out2 = clamp(_mm_add_epi32(t2, o1));
let out5 = clamp(_mm_sub_epi32(t2, o1));
let out3 = clamp(_mm_add_epi32(t3, o0));
let out4 = clamp(_mm_sub_epi32(t3, o0));
let mut tmp = [0i32; 32];
_mm_storeu_si128(tmp.as_mut_ptr() as *mut __m128i, out0);
_mm_storeu_si128(tmp.as_mut_ptr().add(4) as *mut __m128i, out1);
_mm_storeu_si128(tmp.as_mut_ptr().add(8) as *mut __m128i, out2);
_mm_storeu_si128(tmp.as_mut_ptr().add(12) as *mut __m128i, out3);
_mm_storeu_si128(tmp.as_mut_ptr().add(16) as *mut __m128i, out4);
_mm_storeu_si128(tmp.as_mut_ptr().add(20) as *mut __m128i, out5);
_mm_storeu_si128(tmp.as_mut_ptr().add(24) as *mut __m128i, out6);
_mm_storeu_si128(tmp.as_mut_ptr().add(28) as *mut __m128i, out7);
for local_row in 0..4 {
let row_idx = row_group * 4 + local_row;
let out_ptr = output.as_mut_ptr().add(row_idx * stride);
for col in 0..8 {
*out_ptr.add(col) = tmp[col * 4 + local_row] as u8;
}
}
}
}
pub fn idct_dc_only_sse41(dc_value: i32, output: &mut [u8], stride: usize) {
let range_shift = CONST_BITS + PASS1_BITS + 3;
let round = 1 << (range_shift - 1);
let bias = round + (128 << range_shift);
let scaled = dc_value << (CONST_BITS + PASS1_BITS);
let val = ((scaled + bias) >> range_shift).clamp(0, 255) as u8;
unsafe {
let fill = _mm_set1_epi8(val as i8);
for row in 0..8 {
_mm_storel_epi64(output.as_mut_ptr().add(row * stride) as *mut __m128i, fill);
}
}
}
#[cfg(test)]
#[cfg(target_arch = "x86_64")]
mod tests {
use super::super::scalar::{idct_8x8_scalar, idct_dc_only_scalar};
use super::{idct_8x8_sse41, idct_dc_only_sse41};
fn make_test_coeffs() -> [i32; 64] {
let mut c = [0i32; 64];
c[0] = 256;
c[1] = 32;
c[8] = 16;
c[2] = -24;
c[16] = 20;
c[9] = -12;
c[3] = 8;
c[24] = -8;
c
}
#[test]
fn idct_8x8_parity() {
if !is_x86_feature_detected!("sse4.1") {
eprintln!("SIMD feature not available, skipping");
return;
}
let coeffs = make_test_coeffs();
let mut scalar_out = [0u8; 64];
let mut simd_out = [0u8; 64];
idct_8x8_scalar(&coeffs, &mut scalar_out, 8);
idct_8x8_sse41(&coeffs, &mut simd_out, 8);
for i in 0..64 {
let diff = (scalar_out[i] as i32 - simd_out[i] as i32).abs();
assert!(
diff <= 1,
"parity mismatch at index {i}: scalar={}, simd={}, diff={}",
scalar_out[i],
simd_out[i],
diff
);
}
}
#[test]
fn idct_8x8_parity_zero_block() {
if !is_x86_feature_detected!("sse4.1") {
eprintln!("SIMD feature not available, skipping");
return;
}
let coeffs = [0i32; 64];
let mut scalar_out = [0u8; 64];
let mut simd_out = [0u8; 64];
idct_8x8_scalar(&coeffs, &mut scalar_out, 8);
idct_8x8_sse41(&coeffs, &mut simd_out, 8);
for i in 0..64 {
let diff = (scalar_out[i] as i32 - simd_out[i] as i32).abs();
assert!(
diff <= 1,
"zero-block parity mismatch at index {i}: scalar={}, simd={}, diff={}",
scalar_out[i],
simd_out[i],
diff
);
}
}
#[test]
fn idct_8x8_parity_strided() {
if !is_x86_feature_detected!("sse4.1") {
eprintln!("SIMD feature not available, skipping");
return;
}
let coeffs = make_test_coeffs();
let stride = 16usize;
let mut scalar_out = vec![0u8; stride * 8];
let mut simd_out = vec![0u8; stride * 8];
idct_8x8_scalar(&coeffs, &mut scalar_out, stride);
idct_8x8_sse41(&coeffs, &mut simd_out, stride);
for row in 0..8 {
for col in 0..8 {
let i = row * stride + col;
let diff = (scalar_out[i] as i32 - simd_out[i] as i32).abs();
assert!(
diff <= 1,
"strided parity mismatch at row={row} col={col}: scalar={}, simd={}, diff={}",
scalar_out[i],
simd_out[i],
diff
);
}
}
}
#[test]
fn idct_dc_only_parity() {
if !is_x86_feature_detected!("sse4.1") {
eprintln!("SIMD feature not available, skipping");
return;
}
for dc in [0i32, 8, 64, 128, -64, 255, -255] {
let mut scalar_out = [0u8; 64];
let mut simd_out = [0u8; 64];
idct_dc_only_scalar(dc, &mut scalar_out, 8);
idct_dc_only_sse41(dc, &mut simd_out, 8);
for i in 0..64 {
let diff = (scalar_out[i] as i32 - simd_out[i] as i32).abs();
assert!(
diff <= 1,
"dc_only parity mismatch dc={dc} at index {i}: scalar={}, simd={}, diff={}",
scalar_out[i],
simd_out[i],
diff
);
}
}
}
}