#![allow(clippy::undocumented_unsafe_blocks)]
#![allow(unsafe_op_in_unsafe_fn)]
use std::{arch::x86_64::*, num::NonZeroUsize};
/// Multiplies eight `u8` samples by eight `u16` window weights and returns
/// `(sample * weight) >> 6` for each pair as eight 32-bit lanes, in input order.
///
/// # Safety
///
/// The CPU must support AVX2. `src` must be readable for 8 bytes and `window`
/// for 8 `u16` values (16 bytes). Both loads are unaligned loads.
#[target_feature(enable = "avx2")]
unsafe fn overlap_u8x8_product(src: *const u8, window: *const u16) -> __m256i {
    // Zero-extend both operands to 32 bits so the full product is kept.
    let samples = _mm256_cvtepu8_epi32(_mm_loadl_epi64(src.cast()));
    let weights = _mm256_cvtepu16_epi32(_mm_loadu_si128(window.cast()));
    _mm256_srli_epi32(_mm256_mullo_epi32(samples, weights), 6)
}
/// Computes `(sample * weight) >> 6` for sixteen `u8` samples and sixteen
/// `u16` window weights, returning the results as sixteen 16-bit lanes in
/// input order. Each result is saturated to the unsigned 16-bit range.
///
/// # Safety
///
/// The CPU must support AVX2. `src` must be readable for 16 bytes and
/// `window` for 16 `u16` values (32 bytes).
#[target_feature(enable = "avx2")]
unsafe fn overlap_u8x16_product(src: *const u8, window: *const u16) -> __m256i {
    const HALF: usize = 8;

    let first = overlap_u8x8_product(src, window);
    let second = overlap_u8x8_product(src.add(HALF), window.add(HALF));

    // `packus` works on each 128-bit lane separately, so its 64-bit quarters
    // come out as [first 0..4, second 0..4, first 4..8, second 4..8].
    let interleaved = _mm256_packus_epi32(first, second);

    // Swap the two middle quarters to restore input order.
    _mm256_permute4x64_epi64(interleaved, 0b11_01_10_00)
}
/// Adds `(src * window) >> 6` to a `WIDTH` x `HEIGHT` block of 16-bit sums,
/// reading 8-bit source samples.
///
/// For every row `y` and column `x` the `u16` destination element is
/// incremented (with wrap-around) by the product of the `u8` source sample and
/// the `u16` window weight, shifted right by 6. Each row is handled 16 columns
/// at a time, then at most one group of 8 columns, then column by column.
///
/// In the vector paths the shifted product is saturated to `u16::MAX` before
/// the add, while the scalar tail truncates it to 16 bits; the results are
/// identical whenever the shifted product fits in 16 bits.
///
/// # Safety
///
/// - The CPU must support AVX2.
/// - `dest` must be valid for reads and writes of `HEIGHT` rows of `WIDTH`
///   `u16` elements, consecutive rows starting `dest_stride_bytes` bytes
///   apart. No alignment is required.
/// - `src` must be valid for reads of `HEIGHT` rows of `WIDTH` bytes,
///   consecutive rows starting `src_stride_bytes` bytes apart.
/// - `window` must be valid for reads of `HEIGHT` rows of `WIDTH` `u16`
///   elements, consecutive rows starting `window_stride` *elements* apart.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn overlaps_u8<const WIDTH: usize, const HEIGHT: usize>(
    dest: *mut u8,
    dest_stride_bytes: NonZeroUsize,
    src: *const u8,
    src_stride_bytes: NonZeroUsize,
    window: *const u16,
    window_stride: NonZeroUsize,
) {
    let dest_stride = dest_stride_bytes.get();
    let src_stride = src_stride_bytes.get();
    let window_stride = window_stride.get();

    for y in 0..HEIGHT {
        let src_row = src.add(y * src_stride);
        let dest_row = dest.add(y * dest_stride).cast::<u16>();
        let window_row = window.add(y * window_stride);

        let mut x = 0;

        // Full 16-column groups.
        while x + 16 <= WIDTH {
            let product = overlap_u8x16_product(src_row.add(x), window_row.add(x));
            let sums = dest_row.add(x).cast::<__m256i>();
            _mm256_storeu_si256(sums, _mm256_add_epi16(_mm256_loadu_si256(sums), product));
            x += 16;
        }

        // At most one 8-column group can remain after the loop above.
        if x + 8 <= WIDTH {
            let product = overlap_u8x8_product(src_row.add(x), window_row.add(x));
            // Narrow the eight 32-bit products to 16 bits, keeping their order.
            let product = _mm_packus_epi32(
                _mm256_castsi256_si128(product),
                _mm256_extracti128_si256(product, 1),
            );
            let sums = dest_row.add(x).cast::<__m128i>();
            _mm_storeu_si128(sums, _mm_add_epi16(_mm_loadu_si128(sums), product));
            x += 8;
        }

        // Scalar tail. `dest_row` is derived from a byte pointer and a byte
        // stride, so it is accessed unaligned like the vector paths do, and
        // the add wraps like `_mm_add_epi16` instead of panicking in debug
        // builds.
        while x < WIDTH {
            let product = (u32::from(*src_row.add(x)) * u32::from(*window_row.add(x))) >> 6;
            let slot = dest_row.add(x);
            slot.write_unaligned(slot.read_unaligned().wrapping_add(product as u16));
            x += 1;
        }
    }
}
/// Adds `(src * window) >> 6` to a `WIDTH` x `HEIGHT` block of 32-bit sums,
/// reading 16-bit source samples.
///
/// For every row `y` and column `x` the `u32` destination element is
/// incremented (with wrap-around) by the product of the `u16` source sample
/// and the `u16` window weight, shifted right by 6. Each row is handled 8
/// columns at a time, then column by column.
///
/// # Safety
///
/// - The CPU must support AVX2.
/// - `dest` must be valid for reads and writes of `HEIGHT` rows of `WIDTH`
///   `u32` elements, consecutive rows starting `dest_stride_bytes` bytes
///   apart. No alignment is required.
/// - `src` must be valid for reads of `HEIGHT` rows of `WIDTH` `u16`
///   elements, consecutive rows starting `src_stride_bytes` bytes apart. No
///   alignment is required.
/// - `window` must be valid for reads of `HEIGHT` rows of `WIDTH` `u16`
///   elements, consecutive rows starting `window_stride` *elements* apart.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn overlaps_u16<const WIDTH: usize, const HEIGHT: usize>(
    dest: *mut u8,
    dest_stride_bytes: NonZeroUsize,
    src: *const u8,
    src_stride_bytes: NonZeroUsize,
    window: *const u16,
    window_stride: NonZeroUsize,
) {
    let dest_stride = dest_stride_bytes.get();
    let src_stride = src_stride_bytes.get();
    let window_stride = window_stride.get();

    for y in 0..HEIGHT {
        let src_row = src.add(y * src_stride).cast::<u16>();
        let dest_row = dest.add(y * dest_stride).cast::<u32>();
        let window_row = window.add(y * window_stride);

        let mut x = 0;

        // Full 8-column groups: widen both operands to 32 bits, multiply,
        // shift, and accumulate into the 32-bit sums.
        while x + 8 <= WIDTH {
            let samples = _mm256_cvtepu16_epi32(_mm_loadu_si128(src_row.add(x).cast()));
            let weights = _mm256_cvtepu16_epi32(_mm_loadu_si128(window_row.add(x).cast()));
            let product = _mm256_srli_epi32(_mm256_mullo_epi32(samples, weights), 6);
            let sums = dest_row.add(x).cast::<__m256i>();
            _mm256_storeu_si256(sums, _mm256_add_epi32(_mm256_loadu_si256(sums), product));
            x += 8;
        }

        // Scalar tail. `src_row` and `dest_row` are derived from byte pointers
        // and byte strides, so they are accessed unaligned like the vector
        // path does, and the add wraps like `_mm256_add_epi32` instead of
        // panicking in debug builds.
        while x < WIDTH {
            let sample = u32::from(src_row.add(x).read_unaligned());
            let weight = u32::from(*window_row.add(x));
            let slot = dest_row.add(x);
            slot.write_unaligned(slot.read_unaligned().wrapping_add((sample * weight) >> 6));
            x += 1;
        }
    }
}