use archmage::prelude::*;
#[cfg(target_arch = "wasm32")]
use safe_unaligned_simd::wasm32::v128_load32_zero;
#[cfg(target_arch = "x86_64")]
use safe_unaligned_simd::x86_64::{_mm_loadu_si32, _mm_storeu_si32};
pub(crate) fn unfilter_avg(row: &mut [u8], prev: &[u8], bpp: usize) {
match bpp {
4 => incant!(
unfilter_avg_bpp4_impl(row, prev),
[v1, neon, wasm128, scalar]
),
_ => unfilter_avg_scalar_any(row, prev, bpp),
}
}
fn unfilter_avg_scalar_any(row: &mut [u8], prev: &[u8], bpp: usize) {
let len = row.len();
for i in 0..bpp.min(len) {
row[i] = row[i].wrapping_add(prev[i] >> 1);
}
for i in bpp..len {
let avg = ((row[i - bpp] as u16 + prev[i] as u16) >> 1) as u8;
row[i] = row[i].wrapping_add(avg);
}
}
#[cfg(target_arch = "x86_64")]
#[arcane]
fn unfilter_avg_bpp4_impl_v1(_token: Sse2Token, row: &mut [u8], prev: &[u8]) {
let len = row.len();
if len < 4 {
return;
}
let zero = _mm_setzero_si128();
let mut a_wide = zero;
let mut i = 0;
while i + 4 <= len {
let b_raw = _mm_loadu_si32(<&[u8; 4]>::try_from(&prev[i..i + 4]).unwrap());
let b_wide = _mm_unpacklo_epi8(b_raw, zero);
let sum = _mm_add_epi16(a_wide, b_wide);
let avg_wide = _mm_srli_epi16(sum, 1);
let avg_narrow = _mm_packus_epi16(avg_wide, zero);
let filt = _mm_loadu_si32(<&[u8; 4]>::try_from(&row[i..i + 4]).unwrap());
let result = _mm_add_epi8(filt, avg_narrow);
_mm_storeu_si32(
<&mut [u8; 4]>::try_from(&mut row[i..i + 4]).unwrap(),
result,
);
a_wide = _mm_unpacklo_epi8(result, zero);
i += 4;
}
}
#[cfg(target_arch = "aarch64")]
#[arcane]
fn unfilter_avg_bpp4_impl_neon(_token: NeonToken, row: &mut [u8], prev: &[u8]) {
let len = row.len();
if len < 4 {
return;
}
let mut a_wide = vdup_n_u16(0);
let mut i = 0;
while i + 4 <= len {
let b_bytes = u32::from_le_bytes(<[u8; 4]>::try_from(&prev[i..i + 4]).unwrap());
let b_raw = vcreate_u8(b_bytes as u64);
let b_wide = vget_low_u16(vmovl_u8(b_raw));
let sum = vadd_u16(a_wide, b_wide);
let avg_wide = vshr_n_u16::<1>(sum);
let avg_narrow = vmovn_u16(vcombine_u16(avg_wide, vdup_n_u16(0)));
let filt_bytes = u32::from_le_bytes(<[u8; 4]>::try_from(&row[i..i + 4]).unwrap());
let filt = vcreate_u8(filt_bytes as u64);
let result = vadd_u8(filt, avg_narrow);
let result_u32 = vget_lane_u32::<0>(vreinterpret_u32_u8(result));
row[i..i + 4].copy_from_slice(&result_u32.to_le_bytes());
a_wide = vget_low_u16(vmovl_u8(result));
i += 4;
}
}
#[cfg(target_arch = "wasm32")]
#[arcane]
fn unfilter_avg_bpp4_impl_wasm128(_token: Wasm128Token, row: &mut [u8], prev: &[u8]) {
let len = row.len();
if len < 4 {
return;
}
let zero = i16x8_splat(0);
let mut a_wide = zero;
let mut i = 0;
while i + 4 <= len {
let b_raw = v128_load32_zero(<&[u8; 4]>::try_from(&prev[i..i + 4]).unwrap());
let b_wide = i16x8_extend_low_u8x16(b_raw);
let sum = i16x8_add(a_wide, b_wide);
let avg_wide = u16x8_shr(sum, 1);
let avg_narrow = u8x16_narrow_i16x8(avg_wide, zero);
let filt = v128_load32_zero(<&[u8; 4]>::try_from(&row[i..i + 4]).unwrap());
let result = i8x16_add(filt, avg_narrow);
let val = (i32x4_extract_lane::<0>(result) as u32).to_le_bytes();
row[i..i + 4].copy_from_slice(&val);
a_wide = i16x8_extend_low_u8x16(result);
i += 4;
}
}
fn unfilter_avg_bpp4_impl_scalar(_token: ScalarToken, row: &mut [u8], prev: &[u8]) {
unfilter_avg_scalar_any(row, prev, 4);
}
#[cfg(test)]
mod tests {
use archmage::testing::{CompileTimePolicy, for_each_token_permutation};
fn scalar_avg(row: &mut [u8], prev: &[u8], bpp: usize) {
let len = row.len();
for i in 0..bpp.min(len) {
row[i] = row[i].wrapping_add(prev[i] >> 1);
}
for i in bpp..len {
let avg = ((row[i - bpp] as u16 + prev[i] as u16) >> 1) as u8;
row[i] = row[i].wrapping_add(avg);
}
}
#[test]
fn avg_bpp4_all_tiers() {
let report = for_each_token_permutation(CompileTimePolicy::Warn, |_perm| {
for &len in &[0, 4, 8, 12, 100, 4096, 65536] {
let prev: Vec<u8> = (0..len).map(|i| (i * 7 + 13) as u8).collect();
let filtered: Vec<u8> = (0..len).map(|i| (i * 3 + 5) as u8).collect();
let mut expected = filtered.clone();
scalar_avg(&mut expected, &prev, 4);
let mut actual = filtered.clone();
super::unfilter_avg(&mut actual, &prev, 4);
assert_eq!(actual, expected, "avg bpp=4 mismatch at len={len}");
}
});
eprintln!("avg bpp=4: {report}");
}
#[test]
fn avg_other_bpp_unchanged() {
for &bpp in &[1, 2, 3, 6, 8] {
for &len in &[0, bpp, bpp * 4, bpp * 100] {
let prev: Vec<u8> = (0..len).map(|i| (i * 11 + 3) as u8).collect();
let filtered: Vec<u8> = (0..len).map(|i| (i * 5 + 7) as u8).collect();
let mut expected = filtered.clone();
scalar_avg(&mut expected, &prev, bpp);
let mut actual = filtered.clone();
super::unfilter_avg(&mut actual, &prev, bpp);
assert_eq!(actual, expected, "avg bpp={bpp} mismatch at len={len}");
}
}
}
}