#[cfg(target_arch = "x86_64")]
use archmage::{Desktop64, SimdToken, arcane};
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
#[cfg(target_arch = "wasm32")]
use archmage::{SimdToken, Wasm128Token, arcane};
#[cfg(target_arch = "wasm32")]
use core::arch::wasm32::*;
/// Multiplier for `pmulhrsw`-style rounding of the summed intermediates.
///
/// `(x * 1024 + 0x4000) >> 15` equals `(x + 16) >> 5`, i.e. division by 32 with
/// round-half-up. Only the SIMD kernels reference this constant, so it is unused
/// on targets that compile neither of them; the `allow` keeps those builds
/// warning-free (and passing under `-D warnings`).
#[cfg_attr(
    not(any(target_arch = "x86_64", target_arch = "wasm32")),
    allow(dead_code)
)]
const PW_1024: i16 = 1024;
/// AVX2 kernel for `avg_8bpc`: averages two `i16` intermediate buffers into 8-bit pixels.
///
/// For each pixel, `sum = tmp1.wrapping_add(tmp2)` (16-bit wrapping add) and the
/// output is `clamp((sum + 16) >> 5, 0, 255)`. The vector path computes this as
/// `pmulhrsw(sum, 1024)` (`(sum * 1024 + 0x4000) >> 15`) followed by an
/// unsigned-saturating pack to bytes.
///
/// * `_token` — archmage capability token; in this file it is obtained via
///   `Desktop64::summon()`.
/// * `dst` — output; row `r` starts at `r * dst_stride` and is `w` bytes wide.
/// * `tmp1`, `tmp2` — inputs with rows packed contiguously (row `r` starts at `r * w`).
///
/// Columns are processed 32 at a time. Any leftover columns (when `w` is not a
/// multiple of 32) go through a scalar tail with identical rounding, so every
/// width is accepted.
///
/// # Panics
///
/// Panics if `tmp1`/`tmp2` hold fewer than `w * h` elements or `dst` is too short
/// for `h` rows of `w` bytes at `dst_stride`.
#[cfg(target_arch = "x86_64")]
#[arcane]
pub fn avg_8bpc_avx2(
    _token: Desktop64,
    dst: &mut [u8],
    dst_stride: usize,
    tmp1: &[i16],
    tmp2: &[i16],
    w: usize,
    h: usize,
) {
    // No width precondition here: the scalar tail handles `w % 32 != 0`, and the
    // `avg_8bpc` dispatcher sends every width to this kernel.
    debug_assert!(tmp1.len() >= w * h, "tmp1 too small");
    debug_assert!(tmp2.len() >= w * h, "tmp2 too small");
    // Short-circuit on `h == 0`: `h - 1` would underflow (a panic in debug builds).
    debug_assert!(
        h == 0 || dst.len() >= (h - 1) * dst_stride + w,
        "dst too small"
    );
    // pmulhrsw by 1024: (x * 1024 + 0x4000) >> 15 == (x + 16) >> 5.
    let round = _mm256_set1_epi16(PW_1024);
    for row in 0..h {
        let tmp1_row = &tmp1[row * w..][..w];
        let tmp2_row = &tmp2[row * w..][..w];
        let dst_row = &mut dst[row * dst_stride..][..w];
        let mut col = 0;
        while col + 32 <= w {
            // The slices are exactly 16 elements long, so these conversions cannot fail.
            let t1_lo_arr: &[i16; 16] = tmp1_row[col..col + 16].try_into().unwrap();
            let t1_hi_arr: &[i16; 16] = tmp1_row[col + 16..col + 32].try_into().unwrap();
            let t2_lo_arr: &[i16; 16] = tmp2_row[col..col + 16].try_into().unwrap();
            let t2_hi_arr: &[i16; 16] = tmp2_row[col + 16..col + 32].try_into().unwrap();
            let t1_lo = safe_unaligned_simd::x86_64::_mm256_loadu_si256(t1_lo_arr);
            let t1_hi = safe_unaligned_simd::x86_64::_mm256_loadu_si256(t1_hi_arr);
            let t2_lo = safe_unaligned_simd::x86_64::_mm256_loadu_si256(t2_lo_arr);
            let t2_hi = safe_unaligned_simd::x86_64::_mm256_loadu_si256(t2_hi_arr);
            // 16-bit wrapping add, matching `wrapping_add` in the scalar tail.
            let sum_lo = _mm256_add_epi16(t1_lo, t2_lo);
            let sum_hi = _mm256_add_epi16(t1_hi, t2_hi);
            let avg_lo = _mm256_mulhrs_epi16(sum_lo, round);
            let avg_hi = _mm256_mulhrs_epi16(sum_hi, round);
            // packus works per 128-bit lane, producing quadwords
            // [lo0-7, hi0-7, lo8-15, hi8-15]. Reorder them to
            // [q0, q2, q1, q3] to restore column order.
            let packed = _mm256_packus_epi16(avg_lo, avg_hi);
            let result = _mm256_permute4x64_epi64(packed, 0b11_01_10_00);
            let dst_arr: &mut [u8; 32] = (&mut dst_row[col..col + 32]).try_into().unwrap();
            safe_unaligned_simd::x86_64::_mm256_storeu_si256(dst_arr, result);
            col += 32;
        }
        // Scalar tail for the remaining `w % 32` columns; same rounding as above.
        while col < w {
            let sum = tmp1_row[col].wrapping_add(tmp2_row[col]);
            let avg = ((sum as i32 * 1024 + 16384) >> 15).clamp(0, 255) as u8;
            dst_row[col] = avg;
            col += 1;
        }
    }
}
/// WebAssembly SIMD128 kernel for `avg_8bpc`: averages two `i16` intermediate
/// buffers into 8-bit pixels.
///
/// For each pixel, `sum = tmp1.wrapping_add(tmp2)` and the output is
/// `clamp((sum * 1024 + 16384) >> 15, 0, 255)` (equivalently `(sum + 16) >> 5`).
/// The vector path widens to `i32` for the multiply and then narrows back with
/// saturation.
///
/// * `_token` — archmage capability token; in this file it is obtained via
///   `Wasm128Token::summon()`.
/// * `dst` — output; row `r` starts at `r * dst_stride` and is `w` bytes wide.
/// * `tmp1`, `tmp2` — inputs with rows packed contiguously (row `r` starts at `r * w`).
///
/// Columns are processed 8 at a time, and leftover columns use a scalar tail
/// with identical rounding.
///
/// # Panics
///
/// Panics if `tmp1`/`tmp2` hold fewer than `w * h` elements or `dst` is too short
/// for `h` rows of `w` bytes at `dst_stride`.
#[cfg(target_arch = "wasm32")]
#[arcane]
pub fn avg_8bpc_wasm128(
    _token: Wasm128Token,
    dst: &mut [u8],
    dst_stride: usize,
    tmp1: &[i16],
    tmp2: &[i16],
    w: usize,
    h: usize,
) {
    debug_assert!(tmp1.len() >= w * h, "tmp1 too small");
    debug_assert!(tmp2.len() >= w * h, "tmp2 too small");
    // Short-circuit on `h == 0`: `h - 1` would underflow (a panic in debug builds).
    debug_assert!(
        h == 0 || dst.len() >= (h - 1) * dst_stride + w,
        "dst too small"
    );
    let round_const = i32x4_splat(16384);
    let zero = i16x8_splat(0);
    let pw_1024 = i16x8_splat(PW_1024);
    for row in 0..h {
        let tmp1_row = &tmp1[row * w..][..w];
        let tmp2_row = &tmp2[row * w..][..w];
        let dst_row = &mut dst[row * dst_stride..][..w];
        let mut col = 0;
        while col + 8 <= w {
            // The slices are exactly 8 elements long, so these conversions cannot fail.
            let t1_arr: &[i16; 8] = tmp1_row[col..col + 8].try_into().unwrap();
            let t2_arr: &[i16; 8] = tmp2_row[col..col + 8].try_into().unwrap();
            let t1 = safe_unaligned_simd::wasm32::v128_load(t1_arr);
            let t2 = safe_unaligned_simd::wasm32::v128_load(t2_arr);
            // 16-bit wrapping add, matching `wrapping_add` in the scalar tail.
            let sum = i16x8_add(t1, t2);
            // Full-precision i16 x i16 -> i32 products for the low and high halves.
            let prod_lo = i32x4_extmul_low_i16x8(sum, pw_1024);
            let prod_hi = i32x4_extmul_high_i16x8(sum, pw_1024);
            let rounded_lo = i32x4_add(prod_lo, round_const);
            let rounded_hi = i32x4_add(prod_hi, round_const);
            // `i32x4_shr` is an arithmetic (sign-extending) shift, like scalar `>>` on i32.
            let shifted_lo = i32x4_shr(rounded_lo, 15);
            let shifted_hi = i32x4_shr(rounded_hi, 15);
            // Signed-saturating narrow to i16, then unsigned-saturating narrow to u8,
            // which clamps to [0, 255]. The upper 8 bytes (from `zero`) are discarded.
            let narrowed = i16x8_narrow_i32x4(shifted_lo, shifted_hi);
            let packed = u8x16_narrow_i16x8(narrowed, zero);
            // Lane 0 holds the 8 result bytes. wasm32 is little-endian, so native
            // byte order equals memory order.
            let val = i64x2_extract_lane::<0>(packed);
            let bytes = val.to_ne_bytes();
            dst_row[col..col + 8].copy_from_slice(&bytes);
            col += 8;
        }
        // Scalar tail for the remaining `w % 8` columns; same rounding as above.
        while col < w {
            let sum = tmp1_row[col].wrapping_add(tmp2_row[col]);
            let avg = ((sum as i32 * 1024 + 16384) >> 15).clamp(0, 255) as u8;
            dst_row[col] = avg;
            col += 1;
        }
    }
}
/// Portable fallback for `avg_8bpc`. The tests also use it as the reference
/// implementation.
///
/// For every pixel the two inputs are added with 16-bit wrap-around. The sum is
/// then scaled as `(sum * 1024 + 16384) >> 15` (the same as `(sum + 16) >> 5`)
/// and clamped to `0..=255`.
///
/// * `dst` — output; row `r` starts at `r * dst_stride` and is `w` bytes wide.
/// * `tmp1`, `tmp2` — inputs with rows packed contiguously (row `r` starts at `r * w`).
///
/// # Panics
///
/// Panics if any buffer is too short for the requested `w`, `h` and `dst_stride`.
pub fn avg_8bpc_scalar(
    dst: &mut [u8],
    dst_stride: usize,
    tmp1: &[i16],
    tmp2: &[i16],
    w: usize,
    h: usize,
) {
    for y in 0..h {
        let src_a = &tmp1[y * w..][..w];
        let src_b = &tmp2[y * w..][..w];
        let out = &mut dst[y * dst_stride..][..w];
        for (px, (&a, &b)) in out.iter_mut().zip(src_a.iter().zip(src_b)) {
            // Wrap in i16 first (as the SIMD kernels do), then widen for the scaling.
            let wide = i32::from(a.wrapping_add(b));
            *px = ((wide * 1024 + 16384) >> 15).clamp(0, 255) as u8;
        }
    }
}
/// Averages two `i16` intermediate buffers into 8-bit pixels, choosing a kernel at runtime.
///
/// The kernel is chosen in this order:
/// - on x86_64, `avg_8bpc_avx2` when `Desktop64::summon()` yields a token;
/// - on wasm32, `avg_8bpc_wasm128` when `Wasm128Token::summon()` yields a token;
/// - otherwise `avg_8bpc_scalar`.
///
/// All paths share the same buffer layout and rounding; see `avg_8bpc_scalar`.
pub fn avg_8bpc(
    dst: &mut [u8],
    dst_stride: usize,
    tmp1: &[i16],
    tmp2: &[i16],
    w: usize,
    h: usize,
) {
    #[cfg(target_arch = "x86_64")]
    {
        if let Some(token) = Desktop64::summon() {
            return avg_8bpc_avx2(token, dst, dst_stride, tmp1, tmp2, w, h);
        }
    }
    #[cfg(target_arch = "wasm32")]
    {
        if let Some(token) = Wasm128Token::summon() {
            return avg_8bpc_wasm128(token, dst, dst_stride, tmp1, tmp2, w, h);
        }
    }
    // No SIMD token available (or no SIMD kernel for this target).
    avg_8bpc_scalar(dst, dst_stride, tmp1, tmp2, w, h)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Boundary-heavy inputs: powers of two and their neighbours, negatives,
    /// and the `i16` extremes, so that wrapping sums are exercised too.
    const SAMPLE_VALUES: [i16; 28] = [
        0, 1, 2, 127, 128, 255, 256, 511, 512, 1023, 1024, 2047, 2048, 4095, 4096, 8191, 8192,
        16383, 16384, -1, -128, -256, -512, -1024, -2048, -4096, i16::MIN, i16::MAX,
    ];

    /// The dispatched kernel must agree with the scalar reference for every pair
    /// of constant fills.
    #[test]
    fn test_avg_avx2_matches_scalar() {
        let (w, h): (usize, usize) = (64, 2);
        let len = w * h;
        let mut lhs = vec![0i16; len];
        let mut rhs = vec![0i16; len];
        let mut dispatched = vec![0u8; len];
        let mut reference = vec![0u8; len];
        for &v1 in &SAMPLE_VALUES {
            for &v2 in &SAMPLE_VALUES {
                lhs.fill(v1);
                rhs.fill(v2);
                dispatched.fill(0);
                reference.fill(0);
                avg_8bpc_scalar(&mut reference, w, &lhs, &rhs, w, h);
                avg_8bpc(&mut dispatched, w, &lhs, &rhs, w, h);
                assert_eq!(
                    dispatched,
                    reference,
                    "Mismatch for v1={}, v2={}: avx2={:?} scalar={:?}",
                    v1,
                    v2,
                    &dispatched[..8],
                    &reference[..8]
                );
            }
        }
    }

    /// Non-constant input patterns across several rows.
    #[test]
    fn test_avg_varying_data() {
        let (w, h): (usize, usize) = (128, 4);
        let pattern = |mul: usize, offset: usize| -> Vec<i16> {
            (0..w * h).map(|i| ((i * mul + offset) % 8192) as i16).collect()
        };
        let tmp1 = pattern(37, 0);
        let tmp2 = pattern(73, 1000);
        let mut dispatched = vec![0u8; w * h];
        let mut reference = vec![0u8; w * h];
        avg_8bpc_scalar(&mut reference, w, &tmp1, &tmp2, w, h);
        avg_8bpc(&mut dispatched, w, &tmp1, &tmp2, w, h);
        assert_eq!(
            dispatched, reference,
            "Results differ for varying data pattern"
        );
    }

    /// Spot-checks the rounding and saturation of the first output pixel.
    #[test]
    fn test_avg_rounding() {
        let w: usize = 32;
        let h: usize = 1;
        // (tmp1 value, tmp2 value, initial dst fill, expected dst[0], message)
        let cases: [(i16, i16, u8, u8, &str); 4] = [
            (1, 0, 255, 0, "sum=1 should round to 0"),
            (16, 16, 0, 1, "sum=32 should round to 1"),
            (255, 255, 0, 16, "sum=510 should give 16"),
            (8192, 8192, 0, 255, "sum=16384 should saturate to 255"),
        ];
        for &(a, b, fill, expected, msg) in &cases {
            let tmp1 = vec![a; w];
            let tmp2 = vec![b; w];
            let mut dst = vec![fill; w];
            avg_8bpc(&mut dst, w, &tmp1, &tmp2, w, h);
            assert_eq!(dst[0], expected, "{}", msg);
        }
    }
}