#![allow(clippy::needless_range_loop)]
use archmage::prelude::*;
use super::prediction::{CHROMA_BLOCK_SIZE, CHROMA_STRIDE, LUMA_BLOCK_SIZE, LUMA_STRIDE};
#[cfg(target_arch = "x86_64")]
use archmage::intrinsics::x86_64 as simd_mem;
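/// Sum of squared differences between two 4x4 blocks stored as 16
/// contiguous bytes. Scalar reference implementation.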
#[inline]
#[allow(dead_code)]
pub fn sse4x4_scalar(a: &[u8; 16], b: &[u8; 16]) -> u32 {
let mut sum = 0u32;
for i in 0..16 {
let diff = i32::from(a[i]) - i32::from(b[i]);
sum += (diff * diff) as u32;
}
sum
}
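/// SSE2 version of [`sse4x4_scalar`]: widens bytes to i16 against a zero
/// register, then `_mm_madd_epi16` squares the differences and sums
/// adjacent pairs into i32 lanes before a horizontal reduction.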
#[cfg(target_arch = "x86_64")]
#[rite]
#[allow(dead_code)]
pub(crate) fn sse4x4_sse2(_token: X64V3Token, a: &[u8; 16], b: &[u8; 16]) -> u32 {
let zero = _mm_setzero_si128();
let a_bytes = simd_mem::_mm_loadu_si128(a);
let b_bytes = simd_mem::_mm_loadu_si128(b);
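    // Widen u8 lanes to i16 by interleaving with zero.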
let a_lo = _mm_unpacklo_epi8(a_bytes, zero);
let b_lo = _mm_unpacklo_epi8(b_bytes, zero);
let a_hi = _mm_unpackhi_epi8(a_bytes, zero);
let b_hi = _mm_unpackhi_epi8(b_bytes, zero);
let d_lo = _mm_sub_epi16(a_lo, b_lo);
let d_hi = _mm_sub_epi16(a_hi, b_hi);
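    // Square each i16 difference and sum adjacent pairs into i32 lanes.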
let sq_lo = _mm_madd_epi16(d_lo, d_lo);
let sq_hi = _mm_madd_epi16(d_hi, d_hi);
let sum = _mm_add_epi32(sq_lo, sq_hi);
    // Horizontal sum of the four i32 lanes.
    let sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0b10_11_00_01));
    let sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0b01_00_11_10));
_mm_cvtsi128_si32(sum) as u32
}
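/// Scalar SSE between a source block and a prediction corrected by a
/// residual, where each reconstructed sample is
/// `clamp(pred + residual, 0, 255)`.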
#[inline]
#[allow(dead_code)]
pub fn sse4x4_with_residual_scalar(src: &[u8; 16], pred: &[u8; 16], residual: &[i32; 16]) -> u32 {
let mut sum = 0u32;
for i in 0..16 {
let reconstructed = (i32::from(pred[i]) + residual[i]).clamp(0, 255);
let diff = i32::from(src[i]) - reconstructed;
sum += (diff * diff) as u32;
}
sum
}
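/// SSE2 version of [`sse4x4_with_residual_scalar`]. Residuals are packed
/// from i32 to i16 with saturation, so the result matches the scalar
/// clamp as long as `pred + residual` stays within i16 range, which holds
/// for in-range transform residuals.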
#[cfg(target_arch = "x86_64")]
#[rite]
#[allow(dead_code)]
pub(crate) fn sse4x4_with_residual_sse2(
_token: X64V3Token,
src: &[u8; 16],
pred: &[u8; 16],
residual: &[i32; 16],
) -> u32 {
let zero = _mm_setzero_si128();
let max_255 = _mm_set1_epi16(255);
let src_bytes = simd_mem::_mm_loadu_si128(src);
let pred_bytes = simd_mem::_mm_loadu_si128(pred);
let src_lo = _mm_unpacklo_epi8(src_bytes, zero);
let src_hi = _mm_unpackhi_epi8(src_bytes, zero);
let pred_lo = _mm_unpacklo_epi8(pred_bytes, zero);
let pred_hi = _mm_unpackhi_epi8(pred_bytes, zero);
let (r0, r1, r2, r3) = super::q16(residual);
let res0 = simd_mem::_mm_loadu_si128(r0);
let res1 = simd_mem::_mm_loadu_si128(r1);
let res2 = simd_mem::_mm_loadu_si128(r2);
let res3 = simd_mem::_mm_loadu_si128(r3);
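    // Saturating-pack the i32 residuals down to i16 so they can be added
    // to the widened prediction samples.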
let res_lo = _mm_packs_epi32(res0, res1);
let res_hi = _mm_packs_epi32(res2, res3);
let rec_lo = _mm_add_epi16(pred_lo, res_lo);
let rec_hi = _mm_add_epi16(pred_hi, res_hi);
let rec_lo = _mm_max_epi16(rec_lo, zero);
let rec_lo = _mm_min_epi16(rec_lo, max_255);
let rec_hi = _mm_max_epi16(rec_hi, zero);
let rec_hi = _mm_min_epi16(rec_hi, max_255);
let d_lo = _mm_sub_epi16(src_lo, rec_lo);
let d_hi = _mm_sub_epi16(src_hi, rec_hi);
let sq_lo = _mm_madd_epi16(d_lo, d_lo);
let sq_hi = _mm_madd_epi16(d_hi, d_hi);
let sum = _mm_add_epi32(sq_lo, sq_hi);
let sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0b10_11_00_01));
let sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0b01_00_11_10));
_mm_cvtsi128_si32(sum) as u32
}
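/// Scalar SSE over a full 16x16 luma macroblock at (`mbx`, `mby`) in the
/// source plane, compared against a bordered prediction buffer
/// (`LUMA_STRIDE` wide, offset by one row and one column).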
#[inline]
#[allow(dead_code)]
pub fn sse_16x16_luma_scalar(
src_y: &[u8],
src_width: usize,
mbx: usize,
mby: usize,
pred: &[u8; LUMA_BLOCK_SIZE],
) -> u32 {
let mut sse = 0u32;
let src_base = mby * 16 * src_width + mbx * 16;
for y in 0..16 {
let src_row = src_base + y * src_width;
let pred_row = (y + 1) * LUMA_STRIDE + 1;
for x in 0..16 {
let diff = i32::from(src_y[src_row + x]) - i32::from(pred[pred_row + x]);
sse += (diff * diff) as u32;
}
}
sse
}
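/// SSE2 version of [`sse_16x16_luma_scalar`]: loads one 16-byte row per
/// iteration and accumulates the pair-summed squared differences across
/// all 16 rows before a single horizontal reduction.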
#[cfg(target_arch = "x86_64")]
#[rite]
#[allow(dead_code)]
pub(crate) fn sse_16x16_luma_sse2(
_token: X64V3Token,
src_y: &[u8],
src_width: usize,
mbx: usize,
mby: usize,
pred: &[u8; LUMA_BLOCK_SIZE],
) -> u32 {
let zero = _mm_setzero_si128();
let mut total = _mm_setzero_si128();
let src_base = mby * 16 * src_width + mbx * 16;
for y in 0..16 {
let src_row = src_base + y * src_width;
let pred_row = (y + 1) * LUMA_STRIDE + 1;
let src_bytes =
simd_mem::_mm_loadu_si128(<&[u8; 16]>::try_from(&src_y[src_row..][..16]).unwrap());
let pred_bytes =
simd_mem::_mm_loadu_si128(<&[u8; 16]>::try_from(&pred[pred_row..][..16]).unwrap());
let src_lo = _mm_unpacklo_epi8(src_bytes, zero);
let src_hi = _mm_unpackhi_epi8(src_bytes, zero);
let pred_lo = _mm_unpacklo_epi8(pred_bytes, zero);
let pred_hi = _mm_unpackhi_epi8(pred_bytes, zero);
let d_lo = _mm_sub_epi16(src_lo, pred_lo);
let d_hi = _mm_sub_epi16(src_hi, pred_hi);
let sq_lo = _mm_madd_epi16(d_lo, d_lo);
let sq_hi = _mm_madd_epi16(d_hi, d_hi);
total = _mm_add_epi32(total, sq_lo);
total = _mm_add_epi32(total, sq_hi);
}
let sum = _mm_add_epi32(total, _mm_shuffle_epi32(total, 0b10_11_00_01));
let sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0b01_00_11_10));
_mm_cvtsi128_si32(sum) as u32
}
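/// Scalar SSE over an 8x8 chroma block, using the same bordered
/// prediction layout as the luma path (`CHROMA_STRIDE` wide).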
#[inline]
#[allow(dead_code)]
pub fn sse_8x8_chroma_scalar(
src_uv: &[u8],
src_width: usize,
mbx: usize,
mby: usize,
pred: &[u8; CHROMA_BLOCK_SIZE],
) -> u32 {
let mut sse = 0u32;
let src_base = mby * 8 * src_width + mbx * 8;
for y in 0..8 {
let src_row = src_base + y * src_width;
let pred_row = (y + 1) * CHROMA_STRIDE + 1;
for x in 0..8 {
let diff = i32::from(src_uv[src_row + x]) - i32::from(pred[pred_row + x]);
sse += (diff * diff) as u32;
}
}
sse
}
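/// SSE2 version of [`sse_8x8_chroma_scalar`]: 8-byte row loads are
/// widened to i16, then squared and accumulated with `_mm_madd_epi16`.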
#[cfg(target_arch = "x86_64")]
#[rite]
#[allow(dead_code)]
pub(crate) fn sse_8x8_chroma_sse2(
_token: X64V3Token,
src_uv: &[u8],
src_width: usize,
mbx: usize,
mby: usize,
pred: &[u8; CHROMA_BLOCK_SIZE],
) -> u32 {
let zero = _mm_setzero_si128();
let mut total = _mm_setzero_si128();
let src_base = mby * 8 * src_width + mbx * 8;
for y in 0..8 {
let src_row = src_base + y * src_width;
let pred_row = (y + 1) * CHROMA_STRIDE + 1;
let src_bytes =
simd_mem::_mm_loadu_si64(<&[u8; 8]>::try_from(&src_uv[src_row..][..8]).unwrap());
let pred_bytes =
simd_mem::_mm_loadu_si64(<&[u8; 8]>::try_from(&pred[pred_row..][..8]).unwrap());
let src_16 = _mm_unpacklo_epi8(src_bytes, zero);
let pred_16 = _mm_unpacklo_epi8(pred_bytes, zero);
let diff = _mm_sub_epi16(src_16, pred_16);
let sq = _mm_madd_epi16(diff, diff);
total = _mm_add_epi32(total, sq);
}
let sum = _mm_add_epi32(total, _mm_shuffle_epi32(total, 0b10_11_00_01));
let sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0b01_00_11_10));
_mm_cvtsi128_si32(sum) as u32
}
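/// Weighted sum of absolute values of a 4x4 Hadamard-style transform of
/// `input` (rows `stride` bytes apart), used as one half of the
/// texture-distortion metric.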
#[inline]
#[allow(dead_code)]
pub fn t_transform_scalar(input: &[u8], stride: usize, w: &[u16; 16]) -> i32 {
let mut tmp = [0i32; 16];
for i in 0..4 {
let row = i * stride;
let a0 = i32::from(input[row]) + i32::from(input[row + 2]);
let a1 = i32::from(input[row + 1]) + i32::from(input[row + 3]);
let a2 = i32::from(input[row + 1]) - i32::from(input[row + 3]);
let a3 = i32::from(input[row]) - i32::from(input[row + 2]);
tmp[i * 4] = a0 + a1;
tmp[i * 4 + 1] = a3 + a2;
tmp[i * 4 + 2] = a3 - a2;
tmp[i * 4 + 3] = a0 - a1;
}
let mut sum = 0i32;
for i in 0..4 {
let a0 = tmp[i] + tmp[8 + i];
let a1 = tmp[4 + i] + tmp[12 + i];
let a2 = tmp[4 + i] - tmp[12 + i];
let a3 = tmp[i] - tmp[8 + i];
let b0 = a0 + a1;
let b1 = a3 + a2;
let b2 = a3 - a2;
let b3 = a0 - a1;
sum += i32::from(w[i]) * b0.abs();
sum += i32::from(w[4 + i]) * b1.abs();
sum += i32::from(w[8 + i]) * b2.abs();
sum += i32::from(w[12 + i]) * b3.abs();
}
sum
}
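/// SSE2 version of [`t_transform_scalar`]: the row-pass pair sums and
/// differences run in SIMD via `_mm_madd_epi16`, while the butterflies,
/// column pass, and weighting finish in scalar code.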
#[cfg(target_arch = "x86_64")]
#[rite]
#[allow(dead_code)]
fn t_transform_sse2(_token: X64V3Token, input: &[u8], stride: usize, w: &[u16; 16]) -> i32 {
let zero = _mm_setzero_si128();
let row0 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&input[0..4]).unwrap());
let row1 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&input[stride..][..4]).unwrap());
let row2 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&input[stride * 2..][..4]).unwrap());
let row3 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&input[stride * 3..][..4]).unwrap());
let in0 = _mm_unpacklo_epi8(row0, zero);
let in1 = _mm_unpacklo_epi8(row1, zero);
let in2 = _mm_unpacklo_epi8(row2, zero);
let in3 = _mm_unpacklo_epi8(row3, zero);
let in01 = _mm_unpacklo_epi64(in0, in1);
let in23 = _mm_unpacklo_epi64(in2, in3);
let shuf_02_13 = _mm_shufflelo_epi16(_mm_shufflehi_epi16(in01, 0b11_01_10_00), 0b11_01_10_00);
let shuf_02_13_23 =
_mm_shufflelo_epi16(_mm_shufflehi_epi16(in23, 0b11_01_10_00), 0b11_01_10_00);
let add_pairs_01 = _mm_madd_epi16(shuf_02_13, _mm_set1_epi16(1));
let add_pairs_23 = _mm_madd_epi16(shuf_02_13_23, _mm_set1_epi16(1));
let sub_pattern = _mm_set_epi16(1, -1, 1, -1, 1, -1, 1, -1);
let sub_pairs_01 = _mm_madd_epi16(shuf_02_13, sub_pattern);
let sub_pairs_23 = _mm_madd_epi16(shuf_02_13_23, sub_pattern);
let mut ap01 = [0i32; 4];
let mut sp01 = [0i32; 4];
let mut ap23 = [0i32; 4];
let mut sp23 = [0i32; 4];
simd_mem::_mm_storeu_si128(&mut ap01, add_pairs_01);
simd_mem::_mm_storeu_si128(&mut sp01, sub_pairs_01);
simd_mem::_mm_storeu_si128(&mut ap23, add_pairs_23);
simd_mem::_mm_storeu_si128(&mut sp23, sub_pairs_23);
let mut tmp = [[0i32; 4]; 4];
    tmp[0][0] = ap01[0] + ap01[1];
    tmp[0][1] = sp01[0] + sp01[1];
    tmp[0][2] = sp01[0] - sp01[1];
    tmp[0][3] = ap01[0] - ap01[1];
    tmp[1][0] = ap01[2] + ap01[3];
tmp[1][1] = sp01[2] + sp01[3];
tmp[1][2] = sp01[2] - sp01[3];
tmp[1][3] = ap01[2] - ap01[3];
tmp[2][0] = ap23[0] + ap23[1];
tmp[2][1] = sp23[0] + sp23[1];
tmp[2][2] = sp23[0] - sp23[1];
tmp[2][3] = ap23[0] - ap23[1];
tmp[3][0] = ap23[2] + ap23[3];
tmp[3][1] = sp23[2] + sp23[3];
tmp[3][2] = sp23[2] - sp23[3];
tmp[3][3] = ap23[2] - ap23[3];
let mut sum = 0i32;
for i in 0..4 {
let a0 = tmp[0][i] + tmp[2][i];
let a1 = tmp[1][i] + tmp[3][i];
let a2 = tmp[1][i] - tmp[3][i];
let a3 = tmp[0][i] - tmp[2][i];
let b0 = a0 + a1;
let b1 = a3 + a2;
let b2 = a3 - a2;
let b3 = a0 - a1;
        sum += i32::from(w[i]) * b0.abs();
        sum += i32::from(w[4 + i]) * b1.abs();
        sum += i32::from(w[8 + i]) * b2.abs();
        sum += i32::from(w[12 + i]) * b3.abs();
}
sum
}
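/// Fused texture distortion for two 4x4 blocks: transforms `a` and `b`
/// side by side in the same registers (A in the low 64-bit halves, B in
/// the high halves) and returns `|sum_b - sum_a| >> 5`, matching
/// `t_transform_scalar` run on each block. Weights are read as i16 lanes,
/// so they are assumed to stay below `i16::MAX`.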
#[cfg(target_arch = "x86_64")]
#[rite]
pub(crate) fn tdisto_4x4_fused_sse2(
_token: X64V3Token,
a: &[u8],
b: &[u8],
stride: usize,
w: &[u16; 16],
) -> i32 {
let zero = _mm_setzero_si128();
let a0 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&a[0..4]).unwrap());
let a1 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&a[stride..][..4]).unwrap());
let a2 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&a[stride * 2..][..4]).unwrap());
let a3 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&a[stride * 3..][..4]).unwrap());
let b0 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&b[0..4]).unwrap());
let b1 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&b[stride..][..4]).unwrap());
let b2 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&b[stride * 2..][..4]).unwrap());
let b3 = simd_mem::_mm_loadu_si32(<&[u8; 4]>::try_from(&b[stride * 3..][..4]).unwrap());
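    // Pair each row of A with the matching row of B; after widening, the
    // low four i16 lanes hold A's pixels and the next four hold B's.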
let ab0 = _mm_unpacklo_epi32(a0, b0);
let ab1 = _mm_unpacklo_epi32(a1, b1);
let ab2 = _mm_unpacklo_epi32(a2, b2);
let ab3 = _mm_unpacklo_epi32(a3, b3);
let mut tmp0 = _mm_unpacklo_epi8(ab0, zero);
let mut tmp1 = _mm_unpacklo_epi8(ab1, zero);
let mut tmp2 = _mm_unpacklo_epi8(ab2, zero);
let mut tmp3 = _mm_unpacklo_epi8(ab3, zero);
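    // First butterfly pass, operating across rows of both blocks at once,
    // followed by a 4x4 transpose of each block's i16 coefficients.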
{
let va0 = _mm_add_epi16(tmp0, tmp2);
let va1 = _mm_add_epi16(tmp1, tmp3);
let va2 = _mm_sub_epi16(tmp1, tmp3);
let va3 = _mm_sub_epi16(tmp0, tmp2);
let vb0 = _mm_add_epi16(va0, va1);
let vb1 = _mm_add_epi16(va3, va2);
let vb2 = _mm_sub_epi16(va3, va2);
let vb3 = _mm_sub_epi16(va0, va1);
let tr0_0 = _mm_unpacklo_epi16(vb0, vb1);
let tr0_1 = _mm_unpacklo_epi16(vb2, vb3);
let tr0_2 = _mm_unpackhi_epi16(vb0, vb1);
let tr0_3 = _mm_unpackhi_epi16(vb2, vb3);
let tr1_0 = _mm_unpacklo_epi32(tr0_0, tr0_1);
let tr1_1 = _mm_unpacklo_epi32(tr0_2, tr0_3);
let tr1_2 = _mm_unpackhi_epi32(tr0_0, tr0_1);
let tr1_3 = _mm_unpackhi_epi32(tr0_2, tr0_3);
tmp0 = _mm_unpacklo_epi64(tr1_0, tr1_1);
tmp1 = _mm_unpackhi_epi64(tr1_0, tr1_1);
tmp2 = _mm_unpacklo_epi64(tr1_2, tr1_3);
tmp3 = _mm_unpackhi_epi64(tr1_2, tr1_3);
}
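    // Second butterfly pass on the transposed coefficients.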
let ha0 = _mm_add_epi16(tmp0, tmp2);
let ha1 = _mm_add_epi16(tmp1, tmp3);
let ha2 = _mm_sub_epi16(tmp1, tmp3);
let ha3 = _mm_sub_epi16(tmp0, tmp2);
let hb0 = _mm_add_epi16(ha0, ha1);
let hb1 = _mm_add_epi16(ha3, ha2);
let hb2 = _mm_sub_epi16(ha3, ha2);
let hb3 = _mm_sub_epi16(ha0, ha1);
    // Split the interleaved results: low 64-bit halves hold block A's
    // coefficients, high halves hold block B's.
    let a_01 = _mm_unpacklo_epi64(hb0, hb1);
    let a_23 = _mm_unpacklo_epi64(hb2, hb3);
    let b_01 = _mm_unpackhi_epi64(hb0, hb1);
    let b_23 = _mm_unpackhi_epi64(hb2, hb3);
let a_abs_01 = _mm_max_epi16(a_01, _mm_sub_epi16(zero, a_01));
let a_abs_23 = _mm_max_epi16(a_23, _mm_sub_epi16(zero, a_23));
let b_abs_01 = _mm_max_epi16(b_01, _mm_sub_epi16(zero, b_01));
let b_abs_23 = _mm_max_epi16(b_23, _mm_sub_epi16(zero, b_23));
let (w_lo, w_hi) = super::h16(w);
let w_0 = simd_mem::_mm_loadu_si128(w_lo);
let w_8 = simd_mem::_mm_loadu_si128(w_hi);
let a_prod_01 = _mm_madd_epi16(a_abs_01, w_0);
let a_prod_23 = _mm_madd_epi16(a_abs_23, w_8);
let b_prod_01 = _mm_madd_epi16(b_abs_01, w_0);
let b_prod_23 = _mm_madd_epi16(b_abs_23, w_8);
let a_sum_01_23 = _mm_add_epi32(a_prod_01, a_prod_23);
let a_hi = _mm_shuffle_epi32(a_sum_01_23, 0b10_11_00_01);
let a_sum_2 = _mm_add_epi32(a_sum_01_23, a_hi);
let a_final = _mm_shuffle_epi32(a_sum_2, 0b01_00_11_10);
let a_sum_3 = _mm_add_epi32(a_sum_2, a_final);
let sum_a = _mm_cvtsi128_si32(a_sum_3);
let b_sum_01_23 = _mm_add_epi32(b_prod_01, b_prod_23);
let b_hi = _mm_shuffle_epi32(b_sum_01_23, 0b10_11_00_01);
let b_sum_2 = _mm_add_epi32(b_sum_01_23, b_hi);
let b_final = _mm_shuffle_epi32(b_sum_2, 0b01_00_11_10);
let b_sum_3 = _mm_add_epi32(b_sum_2, b_final);
let sum_b = _mm_cvtsi128_si32(b_sum_3);
(sum_b - sum_a).abs() >> 5
}
#[cfg(test)]
mod tests {
use super::*;
use alloc::vec;
#[cfg(target_arch = "x86_64")]
#[arcane]
fn call_sse4x4(t: X64V3Token, a: &[u8; 16], b: &[u8; 16]) -> u32 {
sse4x4_sse2(t, a, b)
}
#[cfg(target_arch = "x86_64")]
#[arcane]
fn call_sse4x4_with_residual(
t: X64V3Token,
src: &[u8; 16],
pred: &[u8; 16],
res: &[i32; 16],
) -> u32 {
sse4x4_with_residual_sse2(t, src, pred, res)
}
#[cfg(target_arch = "x86_64")]
#[arcane]
fn call_sse_16x16_luma(
t: X64V3Token,
src: &[u8],
w: usize,
mx: usize,
my: usize,
p: &[u8; LUMA_BLOCK_SIZE],
) -> u32 {
sse_16x16_luma_sse2(t, src, w, mx, my, p)
}
#[cfg(target_arch = "x86_64")]
#[arcane]
fn call_sse_8x8_chroma(
t: X64V3Token,
src: &[u8],
w: usize,
mx: usize,
my: usize,
p: &[u8; CHROMA_BLOCK_SIZE],
) -> u32 {
sse_8x8_chroma_sse2(t, src, w, mx, my, p)
}
#[cfg(target_arch = "x86_64")]
#[arcane]
fn call_t_transform(t: X64V3Token, input: &[u8], stride: usize, w: &[u16; 16]) -> i32 {
t_transform_sse2(t, input, stride, w)
}
#[cfg(target_arch = "x86_64")]
#[arcane]
fn call_tdisto_fused(t: X64V3Token, a: &[u8], b: &[u8], stride: usize, w: &[u16; 16]) -> i32 {
tdisto_4x4_fused_sse2(t, a, b, stride, w)
}
#[test]
fn test_sse4x4_scalar() {
let a = [10u8; 16];
let b = [12u8; 16];
assert_eq!(sse4x4_scalar(&a, &b), 64);
}
#[test]
fn test_sse4x4_identical() {
let a = [100u8; 16];
assert_eq!(sse4x4_scalar(&a, &a), 0);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_sse4x4_simd_matches_scalar() {
let Some(token) = X64V3Token::summon() else {
return;
};
let a: [u8; 16] = [
10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160,
];
let b: [u8; 16] = [
12, 18, 33, 38, 55, 58, 73, 78, 93, 98, 113, 118, 133, 138, 153, 158,
];
let scalar = sse4x4_scalar(&a, &b);
let simd = call_sse4x4(token, &a, &b);
assert_eq!(scalar, simd);
}
#[test]
fn test_sse4x4_with_residual_scalar() {
let src = [100u8; 16];
let pred = [90u8; 16];
        let residual = [10i32; 16];
        // pred + residual reconstructs src exactly, so the error is zero.
        assert_eq!(sse4x4_with_residual_scalar(&src, &pred, &residual), 0);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_sse4x4_with_residual_simd_matches_scalar() {
let Some(token) = X64V3Token::summon() else {
return;
};
let src: [u8; 16] = [
100, 110, 120, 130, 90, 80, 70, 60, 50, 40, 30, 20, 10, 5, 3, 1,
];
let pred: [u8; 16] = [
95, 105, 115, 125, 85, 75, 65, 55, 45, 35, 25, 15, 5, 2, 1, 0,
];
let residual: [i32; 16] = [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 3, 2, 1];
let scalar = sse4x4_with_residual_scalar(&src, &pred, &residual);
let simd = call_sse4x4_with_residual(token, &src, &pred, &residual);
assert_eq!(scalar, simd);
}
#[test]
fn test_sse_16x16_luma_scalar() {
let src_width = 32;
let mut src_y = vec![100u8; 32 * 32];
for y in 0..16 {
for x in 0..16 {
src_y[y * src_width + x] = 100 + (x as u8);
}
}
let mut pred = [0u8; LUMA_BLOCK_SIZE];
for y in 0..16 {
for x in 0..16 {
                pred[(y + 1) * LUMA_STRIDE + 1 + x] = 102 + (x as u8);
            }
}
let sse = sse_16x16_luma_scalar(&src_y, src_width, 0, 0, &pred);
assert_eq!(sse, 1024);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_sse_16x16_luma_simd_matches_scalar() {
let Some(token) = X64V3Token::summon() else {
return;
};
let src_width = 32;
let mut src_y = vec![0u8; 32 * 32];
for y in 0..16 {
for x in 0..16 {
src_y[y * src_width + x] = ((y * 16 + x) % 256) as u8;
}
}
let mut pred = [0u8; LUMA_BLOCK_SIZE];
for y in 0..16 {
for x in 0..16 {
pred[(y + 1) * LUMA_STRIDE + 1 + x] = ((y * 16 + x + 5) % 256) as u8;
}
}
let scalar = sse_16x16_luma_scalar(&src_y, src_width, 0, 0, &pred);
let simd = call_sse_16x16_luma(token, &src_y, src_width, 0, 0, &pred);
assert_eq!(scalar, simd);
}
#[test]
fn test_sse_8x8_chroma_scalar() {
let src_width = 16;
let mut src_uv = vec![128u8; 16 * 16];
for y in 0..8 {
for x in 0..8 {
src_uv[y * src_width + x] = 128 + (x as u8);
}
}
let mut pred = [0u8; CHROMA_BLOCK_SIZE];
for y in 0..8 {
for x in 0..8 {
                pred[(y + 1) * CHROMA_STRIDE + 1 + x] = 130 + (x as u8);
            }
}
let sse = sse_8x8_chroma_scalar(&src_uv, src_width, 0, 0, &pred);
assert_eq!(sse, 256);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_sse_8x8_chroma_simd_matches_scalar() {
let src_width = 16;
let mut src_uv = vec![0u8; 16 * 16];
for y in 0..8 {
for x in 0..8 {
src_uv[y * src_width + x] = ((y * 8 + x * 3) % 256) as u8;
}
}
let mut pred = [0u8; CHROMA_BLOCK_SIZE];
for y in 0..8 {
for x in 0..8 {
pred[(y + 1) * CHROMA_STRIDE + 1 + x] = ((y * 8 + x * 3 + 7) % 256) as u8;
}
}
let scalar = sse_8x8_chroma_scalar(&src_uv, src_width, 0, 0, &pred);
let Some(token) = X64V3Token::summon() else {
return;
};
let simd = call_sse_8x8_chroma(token, &src_uv, src_width, 0, 0, &pred);
assert_eq!(scalar, simd);
}
#[test]
fn test_t_transform_scalar_basic() {
        let mut input = [0u8; 64];
        for y in 0..4 {
for x in 0..4 {
input[y * 16 + x] = ((y * 4 + x) * 10) as u8;
}
}
let weights: [u16; 16] = [1; 16];
let result = t_transform_scalar(&input, 16, &weights);
assert!(result > 0);
}
#[test]
fn test_t_transform_scalar_uniform() {
let mut input = [128u8; 64];
for y in 0..4 {
for x in 0..4 {
input[y * 16 + x] = 100;
}
}
let weights: [u16; 16] = [1; 16];
let result = t_transform_scalar(&input, 16, &weights);
assert!(result > 0);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_t_transform_simd_matches_scalar() {
let Some(token) = X64V3Token::summon() else {
return;
};
let mut input = [0u8; 64];
for y in 0..4 {
for x in 0..4 {
input[y * 16 + x] = ((y * 37 + x * 23 + 50) % 256) as u8;
}
}
let weights: [u16; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let scalar = t_transform_scalar(&input, 16, &weights);
let simd = call_t_transform(token, &input, 16, &weights);
assert_eq!(scalar, simd);
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_t_transform_simd_matches_scalar_varied() {
let Some(token) = X64V3Token::summon() else {
return;
};
for stride in [4, 8, 16, 32] {
let mut input = vec![0u8; 4 * stride];
for y in 0..4 {
for x in 0..4 {
input[y * stride + x] = ((y * 53 + x * 41 + 17) % 256) as u8;
}
}
let weights: [u16; 16] = [100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 5, 4, 3, 2, 1, 1];
let scalar = t_transform_scalar(&input, stride, &weights);
let simd = call_t_transform(token, &input, stride, &weights);
assert_eq!(scalar, simd, "Mismatch at stride {stride}");
}
}
#[test]
#[cfg(target_arch = "x86_64")]
fn test_tdisto_4x4_fused_matches_scalar() {
let Some(token) = X64V3Token::summon() else {
return;
};
#[rustfmt::skip]
let weights: [u16; 16] = [
38, 32, 20, 9,
32, 28, 17, 7,
20, 17, 10, 4,
9, 7, 4, 2,
];
for stride in [4, 8, 16, 32] {
let mut a = vec![0u8; 4 * stride];
let mut b = vec![0u8; 4 * stride];
for y in 0..4 {
for x in 0..4 {
a[y * stride + x] = ((y * 53 + x * 41 + 17) % 256) as u8;
b[y * stride + x] = ((y * 37 + x * 29 + 11) % 256) as u8;
}
}
let sum_a = t_transform_scalar(&a, stride, &weights);
let sum_b = t_transform_scalar(&b, stride, &weights);
let scalar_result = (sum_b - sum_a).abs() >> 5;
let fused_result = call_tdisto_fused(token, &a, &b, stride, &weights);
assert_eq!(scalar_result, fused_result, "Mismatch at stride {stride}");
}
let same: [u8; 16] = [128; 16];
let result = call_tdisto_fused(token, &same, &same, 4, &weights);
assert_eq!(result, 0, "Identical blocks should have 0 distortion");
}
}