#![allow(clippy::undocumented_unsafe_blocks)]
#![allow(unsafe_op_in_unsafe_fn)]
#[cfg(test)]
use std::mem::size_of;
use std::{arch::x86_64::*, num::NonZeroUsize};
use crate::degrain::MAX_REFS_SIZE;
#[cfg(test)]
use crate::util::Pixel;
use cpudetect::target_family;
use semisafe::slice::get as semisafe_get;
use semisafe::slice::get_mut as semisafe_get_mut;
/// AVX2 kernel that blends one `WIDTH`x`HEIGHT` block of 8-bit pixels.
///
/// For every pixel the sum
/// `128 + src * w_src + sum(refs[r] * w_refs[r] for r in 0..RADIUS * 2)` is accumulated in
/// 32-bit lanes, arithmetically shifted right by 8, and narrowed to `u8` with unsigned
/// saturation before being written to `dest`.
///
/// * `dest` / `dest_stride_bytes` - output block and the byte distance between its rows.
/// * `src` / `src_stride_bytes` - source block and the byte distance between its rows.
/// * `refs` / `refs_strides_bytes` - one pointer and one row stride per reference block; only
///   the first `RADIUS * 2` entries are used.
/// * `w_src` / `w_refs` - integer weights for the source and for each reference block.
///
/// # Safety
///
/// * The CPU must support the AVX2 intrinsics used here.
/// * `refs`, `refs_strides_bytes` and `w_refs` must each hold at least `RADIUS * 2` entries, and
///   `RADIUS * 2` must not exceed `MAX_REFS_SIZE`.
/// * Every pointer must be valid for `HEIGHT` rows at its stride. For `WIDTH >= 8` each step
///   reads and writes 8 bytes, so a width that is not a multiple of 8 touches bytes past
///   `WIDTH`; for `WIDTH` of 4 or 2 exactly `WIDTH` bytes per row are accessed.
///
/// # Panics
///
/// Panics via `unreachable!()` when `WIDTH` is below 8 and is neither 4 nor 2.
#[target_family("x86_64_v3")]
pub(crate) unsafe fn degrain_u8<const RADIUS: usize, const WIDTH: usize, const HEIGHT: usize>(
    dest: *mut u8,
    dest_stride_bytes: NonZeroUsize,
    src: *const u8,
    src_stride_bytes: NonZeroUsize,
    refs: &[*const u8],
    refs_strides_bytes: &[NonZeroUsize],
    w_src: i32,
    w_refs: &[i32],
) {
    // Local copies of the per-reference row pointers and strides so the pointers can be
    // advanced row by row without touching the caller's slices.
    let mut refs_rows = [std::ptr::null(); MAX_REFS_SIZE];
    let mut refs_strides = [0_usize; MAX_REFS_SIZE];
    for r in 0..(RADIUS * 2) {
        *semisafe_get_mut(&mut refs_rows, r) = *semisafe_get(refs, r);
        *semisafe_get_mut(&mut refs_strides, r) = semisafe_get(refs_strides_bytes, r).get();
    }

    let mut src_row = src;
    let mut dest_row = dest;
    // Rounding term added before the `>> 8` below.
    let bias = _mm256_set1_epi32(128);
    let w_src_vec = _mm256_set1_epi32(w_src);

    for _y in 0..HEIGHT {
        if WIDTH >= 8 {
            // Wide path: 8 pixels per step, widened to eight 32-bit lanes.
            for x in (0..WIDTH).step_by(8) {
                let src_8 = _mm_loadl_epi64(src_row.add(x).cast());
                let src_i32 = _mm256_cvtepu8_epi32(src_8);
                let mut sum = _mm256_add_epi32(bias, _mm256_mullo_epi32(src_i32, w_src_vec));
                for r in 0..(RADIUS * 2) {
                    let ref_row = *semisafe_get(&refs_rows, r);
                    let ref_8 = _mm_loadl_epi64(ref_row.add(x).cast());
                    let ref_i32 = _mm256_cvtepu8_epi32(ref_8);
                    let w_ref_vec = _mm256_set1_epi32(*semisafe_get(w_refs, r));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(ref_i32, w_ref_vec));
                }
                let shifted = _mm256_srai_epi32(sum, 8);
                // `packus` works per 128-bit lane, so the eight results end up in 64-bit
                // quads 0 and 2. The 0xd8 permute swaps quads 1 and 2 to make them contiguous.
                let packed16 = _mm256_packus_epi32(shifted, _mm256_setzero_si256());
                let packed16 = _mm256_permute4x64_epi64(packed16, 0xd8);
                let packed8 = _mm256_packus_epi16(packed16, _mm256_setzero_si256());
                // The low 64 bits now hold the eight output bytes.
                _mm_storel_epi64(
                    dest_row.add(x) as *mut __m128i,
                    _mm256_castsi256_si128(packed8),
                );
            }
        } else {
            // Narrow path: a whole row (4 or 2 pixels) fits in one scalar load.
            let src_4 = match WIDTH {
                4 => _mm_cvtsi32_si128((src_row as *const u32).read_unaligned() as i32),
                2 => _mm_cvtsi32_si128((src_row as *const u16).read_unaligned() as i32),
                _ => unreachable!(),
            };
            let src_i32 = _mm_cvtepu8_epi32(src_4);
            let mut sum = _mm_add_epi32(
                _mm_set1_epi32(128),
                _mm_mullo_epi32(src_i32, _mm_set1_epi32(w_src)),
            );
            for r in 0..(RADIUS * 2) {
                let ref_row = *semisafe_get(&refs_rows, r);
                let ref_4 = match WIDTH {
                    4 => _mm_cvtsi32_si128((ref_row as *const u32).read_unaligned() as i32),
                    2 => _mm_cvtsi32_si128((ref_row as *const u16).read_unaligned() as i32),
                    _ => unreachable!(),
                };
                let ref_i32 = _mm_cvtepu8_epi32(ref_4);
                let w_ref_vec = _mm_set1_epi32(*semisafe_get(w_refs, r));
                sum = _mm_add_epi32(sum, _mm_mullo_epi32(ref_i32, w_ref_vec));
            }
            let shifted = _mm_srai_epi32(sum, 8);
            let packed16 = _mm_packus_epi32(shifted, _mm_setzero_si128());
            let packed8 = _mm_packus_epi16(packed16, _mm_setzero_si128());
            let out = _mm_cvtsi128_si32(packed8) as u32;
            match WIDTH {
                4 => (dest_row as *mut u32).write_unaligned(out),
                2 => (dest_row as *mut u16).write_unaligned(out as u16),
                _ => unreachable!(),
            }
        }

        // `wrapping_add` instead of `add`: after the final row these pointers are never
        // dereferenced again, but they may land beyond one-past-the-end of their buffers,
        // which `add` does not permit.
        dest_row = dest_row.wrapping_add(dest_stride_bytes.get());
        src_row = src_row.wrapping_add(src_stride_bytes.get());
        for r in 0..(RADIUS * 2) {
            let ref_row = semisafe_get_mut(&mut refs_rows, r);
            *ref_row = (*ref_row).wrapping_add(*semisafe_get(&refs_strides, r));
        }
    }
}
/// AVX2 kernel that blends one `WIDTH`x`HEIGHT` block of 16-bit pixels.
///
/// The pointers are typed as bytes and all strides are in bytes, but pixels are loaded and
/// stored as `u16`. For every pixel the sum
/// `128 + src * w_src + sum(refs[r] * w_refs[r] for r in 0..RADIUS * 2)` is accumulated in
/// 32-bit lanes, arithmetically shifted right by 8, and narrowed to `u16` with unsigned
/// saturation. No clamp to a smaller bit-depth maximum is applied in this function.
///
/// * `dest` / `dest_stride_bytes` - output block and the byte distance between its rows.
/// * `src` / `src_stride_bytes` - source block and the byte distance between its rows.
/// * `refs` / `refs_strides_bytes` - one pointer and one row stride per reference block; only
///   the first `RADIUS * 2` entries are used.
/// * `w_src` / `w_refs` - integer weights for the source and for each reference block.
///
/// # Safety
///
/// * The CPU must support the AVX2 intrinsics used here.
/// * `refs`, `refs_strides_bytes` and `w_refs` must each hold at least `RADIUS * 2` entries, and
///   `RADIUS * 2` must not exceed `MAX_REFS_SIZE`.
/// * Every pointer must be valid for `HEIGHT` rows at its stride. For `WIDTH >= 8` each step
///   reads and writes 16 bytes (8 pixels), so a width that is not a multiple of 8 touches
///   memory past `WIDTH` pixels; for `WIDTH` of 4 or 2 exactly `WIDTH * 2` bytes per row are
///   accessed.
///
/// # Panics
///
/// Panics via `unreachable!()` when `WIDTH` is below 8 and is neither 4 nor 2.
#[target_family("x86_64_v3")]
pub(crate) unsafe fn degrain_u16<const RADIUS: usize, const WIDTH: usize, const HEIGHT: usize>(
    dest: *mut u8,
    dest_stride_bytes: NonZeroUsize,
    src: *const u8,
    src_stride_bytes: NonZeroUsize,
    refs: &[*const u8],
    refs_strides_bytes: &[NonZeroUsize],
    w_src: i32,
    w_refs: &[i32],
) {
    // Local copies of the per-reference row pointers and strides so the pointers can be
    // advanced row by row without touching the caller's slices.
    let mut refs_rows = [std::ptr::null(); MAX_REFS_SIZE];
    let mut refs_strides = [0_usize; MAX_REFS_SIZE];
    for r in 0..(RADIUS * 2) {
        *semisafe_get_mut(&mut refs_rows, r) = *semisafe_get(refs, r);
        *semisafe_get_mut(&mut refs_strides, r) = semisafe_get(refs_strides_bytes, r).get();
    }

    let mut src_row = src;
    let mut dest_row = dest;
    // Rounding term added before the `>> 8` below.
    let bias = _mm256_set1_epi32(128);
    let w_src_vec = _mm256_set1_epi32(w_src);

    for _y in 0..HEIGHT {
        if WIDTH >= 8 {
            // Wide path: 8 pixels (16 bytes) per step, widened to eight 32-bit lanes.
            for x in (0..WIDTH).step_by(8) {
                // Fully qualified so this non-test code does not depend on the
                // test-only `size_of` import at the top of the file.
                let offset = x * std::mem::size_of::<u16>();
                let src_8 = _mm_loadu_si128(src_row.add(offset).cast());
                let src_i32 = _mm256_cvtepu16_epi32(src_8);
                let mut sum = _mm256_add_epi32(bias, _mm256_mullo_epi32(src_i32, w_src_vec));
                for r in 0..(RADIUS * 2) {
                    let ref_row = *semisafe_get(&refs_rows, r);
                    let ref_8 = _mm_loadu_si128(ref_row.add(offset).cast());
                    let ref_i32 = _mm256_cvtepu16_epi32(ref_8);
                    let w_ref_vec = _mm256_set1_epi32(*semisafe_get(w_refs, r));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(ref_i32, w_ref_vec));
                }
                let shifted = _mm256_srai_epi32(sum, 8);
                // `packus` works per 128-bit lane, so the eight results end up in 64-bit
                // quads 0 and 2. The 0xd8 permute swaps quads 1 and 2 so the low 128 bits
                // hold all eight output pixels in order.
                let packed = _mm256_packus_epi32(shifted, _mm256_setzero_si256());
                let packed = _mm256_permute4x64_epi64(packed, 0xd8);
                _mm_storeu_si128(
                    dest_row.add(offset) as *mut __m128i,
                    _mm256_castsi256_si128(packed),
                );
            }
        } else {
            // Narrow path: a whole row (4 or 2 pixels) fits in one 64- or 32-bit load.
            let src_4 = match WIDTH {
                4 => _mm_loadl_epi64(src_row.cast()),
                2 => _mm_cvtsi32_si128((src_row as *const u32).read_unaligned() as i32),
                _ => unreachable!(),
            };
            let src_i32 = _mm_cvtepu16_epi32(src_4);
            let mut sum = _mm_add_epi32(
                _mm_set1_epi32(128),
                _mm_mullo_epi32(src_i32, _mm_set1_epi32(w_src)),
            );
            for r in 0..(RADIUS * 2) {
                let ref_row = *semisafe_get(&refs_rows, r);
                let ref_4 = match WIDTH {
                    4 => _mm_loadl_epi64(ref_row.cast()),
                    2 => _mm_cvtsi32_si128((ref_row as *const u32).read_unaligned() as i32),
                    _ => unreachable!(),
                };
                let ref_i32 = _mm_cvtepu16_epi32(ref_4);
                let w_ref_vec = _mm_set1_epi32(*semisafe_get(w_refs, r));
                sum = _mm_add_epi32(sum, _mm_mullo_epi32(ref_i32, w_ref_vec));
            }
            let shifted = _mm_srai_epi32(sum, 8);
            // The low 64 bits of `packed` hold the four `u16` results.
            let packed = _mm_packus_epi32(shifted, _mm_setzero_si128());
            match WIDTH {
                4 => (dest_row as *mut u64).write_unaligned(_mm_cvtsi128_si64(packed) as u64),
                2 => (dest_row as *mut u32).write_unaligned(_mm_cvtsi128_si32(packed) as u32),
                _ => unreachable!(),
            }
        }

        // `wrapping_add` instead of `add`: after the final row these pointers are never
        // dereferenced again, but they may land beyond one-past-the-end of their buffers,
        // which `add` does not permit.
        dest_row = dest_row.wrapping_add(dest_stride_bytes.get());
        src_row = src_row.wrapping_add(src_stride_bytes.get());
        for r in 0..(RADIUS * 2) {
            let ref_row = semisafe_get_mut(&mut refs_rows, r);
            *ref_row = (*ref_row).wrapping_add(*semisafe_get(&refs_strides, r));
        }
    }
}
/// Test-only driver that runs the AVX2 degrain kernel chosen by
/// `super::select_degrain_avx2` for the given block size.
///
/// The radius is derived as `refs.len() / 2`; radii 1 through 6 are dispatched. Pixel strides
/// are converted to byte strides by multiplying with `size_of::<T>()`.
///
/// NOTE(review): `dest` is driven with the byte stride derived from `src_stride_pixels`, since
/// there is no separate destination stride parameter. Confirm callers size `dest` for that
/// stride.
///
/// # Safety
///
/// `dest`, `src` and every slice in `refs` must be large enough for the selected kernel to
/// access `height` rows of `width` pixels at the strides passed here. The CPU must support
/// the instructions the selected kernel uses.
///
/// # Panics
///
/// Panics if the radius is outside `1..=6`, if `refs_strides_pixels` or `w_refs` has fewer
/// than `radius * 2` entries, or if a stride converts to zero bytes.
#[cfg(test)]
#[target_family("x86_64_v3")]
pub(super) unsafe fn degrain_test<T: Pixel>(
    dest: &mut Vec<T>,
    width: NonZeroUsize,
    height: NonZeroUsize,
    src: &[T],
    src_stride_pixels: NonZeroUsize,
    refs: &[&[T]],
    refs_strides_pixels: &[NonZeroUsize],
    w_src: i32,
    w_refs: &[i32],
) {
    let radius = refs.len() / 2;
    let ref_count = radius * 2;

    // The kernels in this file index the first `RADIUS * 2` strides and weights, so check the
    // lengths here and fail loudly on a short slice.
    assert!(
        refs_strides_pixels.len() >= ref_count,
        "expected at least {} reference strides, got {}",
        ref_count,
        refs_strides_pixels.len()
    );
    assert!(
        w_refs.len() >= ref_count,
        "expected at least {} reference weights, got {}",
        ref_count,
        w_refs.len()
    );

    // Checked conversion: a zero result (only possible for a zero-sized `T`) panics instead
    // of producing an invalid `NonZeroUsize`.
    let to_stride_bytes = |pixels: NonZeroUsize| {
        NonZeroUsize::new(pixels.get().saturating_mul(size_of::<T>()))
            .expect("stride in bytes must be non-zero")
    };

    let stride_bytes = to_stride_bytes(src_stride_pixels);
    let refs_ptrs = refs
        .iter()
        .map(|ref_| ref_.as_ptr().cast())
        .collect::<Box<[_]>>();
    let refs_strides_bytes = refs_strides_pixels
        .iter()
        .map(|&stride| to_stride_bytes(stride))
        .collect::<Box<[_]>>();

    let func = match radius {
        1 => super::select_degrain_avx2::<T, 1>(width, height),
        2 => super::select_degrain_avx2::<T, 2>(width, height),
        3 => super::select_degrain_avx2::<T, 3>(width, height),
        4 => super::select_degrain_avx2::<T, 4>(width, height),
        5 => super::select_degrain_avx2::<T, 5>(width, height),
        6 => super::select_degrain_avx2::<T, 6>(width, height),
        _ => unreachable!("unsupported degrain radius"),
    };

    unsafe {
        func(
            dest.as_mut_ptr().cast(),
            stride_bytes,
            src.as_ptr().cast(),
            stride_bytes,
            &refs_ptrs,
            &refs_strides_bytes,
            w_src,
            w_refs,
        );
    }
}