#![allow(clippy::undocumented_unsafe_blocks)]
#![allow(unsafe_op_in_unsafe_fn)]
use std::{
arch::x86_64::*,
mem::size_of,
num::{NonZeroU8, NonZeroUsize},
};
use crate::util::Pixel;
/// Horizontal half-pel bilinear refinement, dispatched on the byte width of `T`.
///
/// One-byte pixels are forwarded to [`refine_horizontal_bilinear_u8`] and two-byte
/// pixels to [`refine_horizontal_bilinear_u16`]. `pitch`, `width` and `height` are
/// passed through unchanged; `_bits_per_sample` is ignored.
///
/// # Panics
///
/// Hits `unreachable!()` if `size_of::<T>()` is neither 1 nor 2.
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
pub(super) fn refine_horizontal_bilinear<T: Pixel>(
    dest: &mut [T],
    src: &[T],
    pitch: NonZeroUsize,
    width: NonZeroUsize,
    height: NonZeroUsize,
    _bits_per_sample: NonZeroU8,
) {
    let sample_bytes = size_of::<T>();
    if sample_bytes == 1 {
        // NOTE(review): the slice lengths are not checked here; the kernel trusts the
        // caller to pass planes large enough for `pitch`/`width`/`height`.
        unsafe {
            refine_horizontal_bilinear_u8(
                src.as_ptr().cast(),
                dest.as_mut_ptr().cast(),
                pitch,
                width,
                height,
            );
        }
    } else if sample_bytes == 2 {
        // NOTE(review): presumably `T` is `u16` here (same size and alignment) — confirm
        // against the `Pixel` implementations.
        unsafe {
            refine_horizontal_bilinear_u16(
                src.as_ptr().cast(),
                dest.as_mut_ptr().cast(),
                pitch,
                width,
                height,
            );
        }
    } else {
        unreachable!();
    }
}
/// Vertical half-pel bilinear refinement, dispatched on the byte width of `T`.
///
/// One-byte pixels are forwarded to [`refine_vertical_bilinear_u8`] and two-byte
/// pixels to [`refine_vertical_bilinear_u16`]. `pitch`, `width` and `height` are
/// passed through unchanged; `_bits_per_sample` is ignored.
///
/// # Panics
///
/// Hits `unreachable!()` if `size_of::<T>()` is neither 1 nor 2.
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
#[cfg(any(test, feature = "experimental"))]
pub(super) fn refine_vertical_bilinear<T: Pixel>(
    dest: &mut [T],
    src: &[T],
    pitch: NonZeroUsize,
    width: NonZeroUsize,
    height: NonZeroUsize,
    _bits_per_sample: NonZeroU8,
) {
    let sample_bytes = size_of::<T>();
    if sample_bytes == 1 {
        // NOTE(review): the slice lengths are not checked here; the kernel trusts the
        // caller to pass planes large enough for `pitch`/`width`/`height`.
        unsafe {
            refine_vertical_bilinear_u8(
                src.as_ptr().cast(),
                dest.as_mut_ptr().cast(),
                pitch,
                width,
                height,
            );
        }
    } else if sample_bytes == 2 {
        // NOTE(review): presumably `T` is `u16` here (same size and alignment) — confirm
        // against the `Pixel` implementations.
        unsafe {
            refine_vertical_bilinear_u16(
                src.as_ptr().cast(),
                dest.as_mut_ptr().cast(),
                pitch,
                width,
                height,
            );
        }
    } else {
        unreachable!();
    }
}
/// Diagonal half-pel bilinear refinement, dispatched on the byte width of `T`.
///
/// One-byte pixels are forwarded to [`refine_diagonal_bilinear_u8`] and two-byte
/// pixels to [`refine_diagonal_bilinear_u16`]. `pitch`, `width` and `height` are
/// passed through unchanged; `_bits_per_sample` is ignored.
///
/// # Panics
///
/// Hits `unreachable!()` if `size_of::<T>()` is neither 1 nor 2.
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
pub(super) fn refine_diagonal_bilinear<T: Pixel>(
    dest: &mut [T],
    src: &[T],
    pitch: NonZeroUsize,
    width: NonZeroUsize,
    height: NonZeroUsize,
    _bits_per_sample: NonZeroU8,
) {
    let sample_bytes = size_of::<T>();
    if sample_bytes == 1 {
        // NOTE(review): the slice lengths are not checked here; the kernel trusts the
        // caller to pass planes large enough for `pitch`/`width`/`height`.
        unsafe {
            refine_diagonal_bilinear_u8(
                src.as_ptr().cast(),
                dest.as_mut_ptr().cast(),
                pitch,
                width,
                height,
            );
        }
    } else if sample_bytes == 2 {
        // NOTE(review): presumably `T` is `u16` here (same size and alignment) — confirm
        // against the `Pixel` implementations.
        unsafe {
            refine_diagonal_bilinear_u16(
                src.as_ptr().cast(),
                dest.as_mut_ptr().cast(),
                pitch,
                width,
                height,
            );
        }
    } else {
        unreachable!();
    }
}
/// Horizontal half-pel bilinear refinement for 8-bit planes using AVX-512.
///
/// Within each row, every output pixel except the last is the rounded-up mean of the
/// source pixel and its right-hand neighbour; the last pixel of the row is copied
/// unchanged. Rows start `pitch` elements apart and only the first `width` columns
/// of each row are written.
///
/// # Safety
///
/// * The CPU must support every feature listed in the `target_feature` attribute.
/// * `src` must be valid for reads and `dest` valid for writes of
///   `(height - 1) * pitch + width` bytes.
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
unsafe fn refine_horizontal_bilinear_u8(
    src: *const u8,
    dest: *mut u8,
    pitch: NonZeroUsize,
    width: NonZeroUsize,
    height: NonZeroUsize,
) {
    let stride = pitch.get();
    let cols = width.get();
    let rows = height.get();
    let last_col = cols - 1;

    for row in 0..rows {
        let src_row = src.add(row * stride);
        let dest_row = dest.add(row * stride);
        let mut x = 0;

        // 64 pixels per step. The strict `<` keeps the `x + 1` load inside the row.
        while x + 64 < cols {
            let left = _mm512_loadu_si512(src_row.add(x).cast::<__m512i>());
            let right = _mm512_loadu_si512(src_row.add(x + 1).cast::<__m512i>());
            let blended = _mm512_avg_epu8(left, right);
            _mm512_storeu_si512(dest_row.add(x).cast::<__m512i>(), blended);
            x += 64;
        }

        // 32 pixels per step for what the 512-bit loop left over.
        while x + 32 < cols {
            let left = _mm256_loadu_si256(src_row.add(x).cast::<__m256i>());
            let right = _mm256_loadu_si256(src_row.add(x + 1).cast::<__m256i>());
            let blended = _mm256_avg_epu8(left, right);
            _mm256_storeu_si256(dest_row.add(x).cast::<__m256i>(), blended);
            x += 32;
        }

        // Scalar tail, widened to u16 so the sum cannot overflow.
        for col in x..last_col {
            let left = u16::from(*src_row.add(col));
            let right = u16::from(*src_row.add(col + 1));
            *dest_row.add(col) = ((left + right + 1) >> 1) as u8;
        }

        // The final column has no right-hand neighbour.
        *dest_row.add(last_col) = *src_row.add(last_col);
    }
}
/// Horizontal half-pel bilinear refinement for 16-bit planes using AVX-512.
///
/// Within each row, every output pixel except the last is the rounded-up mean of the
/// source pixel and its right-hand neighbour; the last pixel of the row is copied
/// unchanged. `pitch` is measured in `u16` elements, and only the first `width`
/// columns of each row are written.
///
/// # Safety
///
/// * The CPU must support every feature listed in the `target_feature` attribute.
/// * `src` must be valid for reads and `dest` valid for writes of
///   `(height - 1) * pitch + width` `u16` elements.
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
unsafe fn refine_horizontal_bilinear_u16(
    src: *const u16,
    dest: *mut u16,
    pitch: NonZeroUsize,
    width: NonZeroUsize,
    height: NonZeroUsize,
) {
    let stride = pitch.get();
    let cols = width.get();
    let rows = height.get();
    let last_col = cols - 1;

    for row in 0..rows {
        let src_row = src.add(row * stride);
        let dest_row = dest.add(row * stride);
        let mut x = 0;

        // 32 pixels per step. The strict `<` keeps the `x + 1` load inside the row.
        while x + 32 < cols {
            let left = _mm512_loadu_si512(src_row.add(x).cast::<__m512i>());
            let right = _mm512_loadu_si512(src_row.add(x + 1).cast::<__m512i>());
            let blended = _mm512_avg_epu16(left, right);
            _mm512_storeu_si512(dest_row.add(x).cast::<__m512i>(), blended);
            x += 32;
        }

        // 16 pixels per step for what the 512-bit loop left over.
        while x + 16 < cols {
            let left = _mm256_loadu_si256(src_row.add(x).cast::<__m256i>());
            let right = _mm256_loadu_si256(src_row.add(x + 1).cast::<__m256i>());
            let blended = _mm256_avg_epu16(left, right);
            _mm256_storeu_si256(dest_row.add(x).cast::<__m256i>(), blended);
            x += 16;
        }

        // Scalar tail, widened to u32 so the sum cannot overflow.
        for col in x..last_col {
            let left = u32::from(*src_row.add(col));
            let right = u32::from(*src_row.add(col + 1));
            *dest_row.add(col) = ((left + right + 1) >> 1) as u16;
        }

        // The final column has no right-hand neighbour.
        *dest_row.add(last_col) = *src_row.add(last_col);
    }
}
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
#[cfg(any(test, feature = "experimental"))]
unsafe fn refine_vertical_bilinear_u8(
mut src: *const u8,
mut dest: *mut u8,
pitch: NonZeroUsize,
width: NonZeroUsize,
height: NonZeroUsize,
) {
let pitch = pitch.get();
let width = width.get();
let height = height.get();
for _ in 0..(height - 1) {
let mut i = 0;
while i + 64 <= width {
let current = _mm512_loadu_si512(src.add(i).cast::<__m512i>());
let next = _mm512_loadu_si512(src.add(pitch + i).cast::<__m512i>());
let result = _mm512_avg_epu8(current, next);
_mm512_storeu_si512(dest.add(i).cast::<__m512i>(), result);
i += 64;
}
while i + 32 <= width {
let current = _mm256_loadu_si256(src.add(i).cast::<__m256i>());
let next = _mm256_loadu_si256(src.add(pitch + i).cast::<__m256i>());
let result = _mm256_avg_epu8(current, next);
_mm256_storeu_si256(dest.add(i).cast::<__m256i>(), result);
i += 32;
}
while i < width {
let a = *src.add(i) as u16;
let b = *src.add(pitch + i) as u16;
*dest.add(i) = ((a + b + 1) / 2) as u8;
i += 1;
}
src = src.add(pitch);
dest = dest.add(pitch);
}
if height > 0 {
std::ptr::copy_nonoverlapping(src, dest, width);
}
}
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
#[cfg(any(test, feature = "experimental"))]
unsafe fn refine_vertical_bilinear_u16(
src: *const u16,
dest: *mut u16,
pitch: NonZeroUsize,
width: NonZeroUsize,
height: NonZeroUsize,
) {
let pitch = pitch.get();
let width = width.get();
let height = height.get();
for j in 0..height - 1 {
let row_offset = j * pitch;
let mut i = 0;
while i + 32 <= width {
let current = _mm512_loadu_si512(src.add(row_offset + i).cast::<__m512i>());
let next = _mm512_loadu_si512(src.add(row_offset + pitch + i).cast::<__m512i>());
let result = _mm512_avg_epu16(current, next);
_mm512_storeu_si512(dest.add(row_offset + i).cast::<__m512i>(), result);
i += 32;
}
while i + 16 <= width {
let current = _mm256_loadu_si256(src.add(row_offset + i).cast::<__m256i>());
let next = _mm256_loadu_si256(src.add(row_offset + pitch + i).cast::<__m256i>());
let result = _mm256_avg_epu16(current, next);
_mm256_storeu_si256(dest.add(row_offset + i).cast::<__m256i>(), result);
i += 16;
}
while i < width {
let a = *src.add(row_offset + i) as u32;
let b = *src.add(row_offset + pitch + i) as u32;
*dest.add(row_offset + i) = ((a + b + 1) / 2) as u16;
i += 1;
}
}
if height > 0 {
let last_row_offset = (height - 1) * pitch;
std::ptr::copy_nonoverlapping(src.add(last_row_offset), dest.add(last_row_offset), width);
}
}
/// Diagonal half-pel bilinear refinement for 8-bit planes using AVX-512.
///
/// For every row except the last, each output pixel is the rounded mean of the 2x2
/// source neighbourhood `(x, y)`, `(x + 1, y)`, `(x, y + 1)`, `(x + 1, y + 1)`. The
/// last column of those rows has no right-hand neighbour and averages only the two
/// vertically adjacent pixels. The last row has no row below it, so it averages
/// horizontal neighbours only, and its final pixel is copied unchanged.
///
/// # Safety
///
/// * The CPU must support every feature listed in the `target_feature` attribute.
/// * `src` must be valid for reads and `dest` valid for writes of
///   `(height - 1) * pitch + width` bytes.
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
unsafe fn refine_diagonal_bilinear_u8(
    src: *const u8,
    dest: *mut u8,
    pitch: NonZeroUsize,
    width: NonZeroUsize,
    height: NonZeroUsize,
) {
    let pitch = pitch.get();
    let width = width.get();
    let height = height.get();
    // `width` is non-zero, so this cannot underflow.
    let last_col = width - 1;
    let two = _mm512_set1_epi16(2);
    let avx2_two = _mm256_set1_epi16(2);
    let mut offset = 0;

    // Only rows that have a row below them take the 2x2 path. Looping over all
    // `height` rows would read one row past the plane here and push the final-row
    // pass below onto a row that does not exist.
    for _ in 0..height - 1 {
        let mut i = 0;

        // 32 pixels per step: widen to 16-bit lanes so the four-way sum cannot overflow.
        while i + 32 < width {
            let a = _mm256_loadu_si256(src.add(offset + i).cast::<__m256i>());
            let b = _mm256_loadu_si256(src.add(offset + i + 1).cast::<__m256i>());
            let c = _mm256_loadu_si256(src.add(offset + pitch + i).cast::<__m256i>());
            let d = _mm256_loadu_si256(src.add(offset + pitch + i + 1).cast::<__m256i>());
            let sum_ab = _mm512_add_epi16(_mm512_cvtepu8_epi16(a), _mm512_cvtepu8_epi16(b));
            let sum_cd = _mm512_add_epi16(_mm512_cvtepu8_epi16(c), _mm512_cvtepu8_epi16(d));
            let sum = _mm512_add_epi16(_mm512_add_epi16(sum_ab, sum_cd), two);
            let result = _mm512_srli_epi16(sum, 2);
            // After the shift every lane fits in 8 bits, so truncation is lossless.
            let packed = _mm512_cvtepi16_epi8(result);
            _mm256_storeu_si256(dest.add(offset + i).cast::<__m256i>(), packed);
            i += 32;
        }

        // 16 pixels per step via the helper in the sibling `avx2` module.
        // NOTE(review): presumably it computes the same rounded four-way mean and
        // narrows back to bytes — confirm against its definition.
        while i + 16 < width {
            let a = _mm_loadu_si128(src.add(offset + i).cast::<__m128i>());
            let b = _mm_loadu_si128(src.add(offset + i + 1).cast::<__m128i>());
            let c = _mm_loadu_si128(src.add(offset + pitch + i).cast::<__m128i>());
            let d = _mm_loadu_si128(src.add(offset + pitch + i + 1).cast::<__m128i>());
            let result = super::avx2::apply_diagonal_bilinear_u8_avx2(
                _mm256_cvtepu8_epi16(a),
                _mm256_cvtepu8_epi16(b),
                _mm256_cvtepu8_epi16(c),
                _mm256_cvtepu8_epi16(d),
                avx2_two,
            );
            _mm_storeu_si128(dest.add(offset + i).cast::<__m128i>(), result);
            i += 16;
        }

        // Scalar tail for the remaining columns that still have a right-hand neighbour.
        while i < last_col {
            let a = *src.add(offset + i) as u16;
            let b = *src.add(offset + i + 1) as u16;
            let c = *src.add(offset + pitch + i) as u16;
            let d = *src.add(offset + pitch + i + 1) as u16;
            *dest.add(offset + i) = ((a + b + c + d + 2) >> 2) as u8;
            i += 1;
        }

        // Last column: only the pixel below is available.
        let a = *src.add(offset + last_col) as u16;
        let b = *src.add(offset + pitch + last_col) as u16;
        *dest.add(offset + last_col) = ((a + b + 1) >> 1) as u8;

        offset += pitch;
    }

    // `offset` now addresses the last row: horizontal average only.
    for i in 0..last_col {
        let a = *src.add(offset + i) as u16;
        let b = *src.add(offset + i + 1) as u16;
        *dest.add(offset + i) = ((a + b + 1) >> 1) as u8;
    }
    // Bottom-right pixel has neither neighbour.
    *dest.add(offset + last_col) = *src.add(offset + last_col);
}
/// Diagonal half-pel bilinear refinement for 16-bit planes using AVX-512.
///
/// For every row except the last, each output pixel is the rounded mean of the 2x2
/// source neighbourhood `(x, y)`, `(x + 1, y)`, `(x, y + 1)`, `(x + 1, y + 1)`. The
/// last column of those rows has no right-hand neighbour and averages only the two
/// vertically adjacent pixels. The last row has no row below it, so it averages
/// horizontal neighbours only, and its final pixel is copied unchanged. `pitch` is
/// measured in `u16` elements.
///
/// # Safety
///
/// * The CPU must support every feature listed in the `target_feature` attribute.
/// * `src` must be valid for reads and `dest` valid for writes of
///   `(height - 1) * pitch + width` `u16` elements.
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
unsafe fn refine_diagonal_bilinear_u16(
    src: *const u16,
    dest: *mut u16,
    pitch: NonZeroUsize,
    width: NonZeroUsize,
    height: NonZeroUsize,
) {
    let pitch = pitch.get();
    let width = width.get();
    let height = height.get();
    // `width` is non-zero, so this cannot underflow.
    let last_col = width - 1;
    let two = _mm512_set1_epi32(2);
    let avx2_two = _mm256_set1_epi32(2);
    let mut offset = 0;

    // Only rows that have a row below them take the 2x2 path. Looping over all
    // `height` rows would read one row past the plane here and push the final-row
    // pass below onto a row that does not exist.
    for _ in 0..height - 1 {
        let mut i = 0;

        // 16 pixels per step: widen to 32-bit lanes so the four-way sum cannot overflow.
        while i + 16 < width {
            let a = _mm256_loadu_si256(src.add(offset + i).cast::<__m256i>());
            let b = _mm256_loadu_si256(src.add(offset + i + 1).cast::<__m256i>());
            let c = _mm256_loadu_si256(src.add(offset + pitch + i).cast::<__m256i>());
            let d = _mm256_loadu_si256(src.add(offset + pitch + i + 1).cast::<__m256i>());
            let sum_ab = _mm512_add_epi32(_mm512_cvtepu16_epi32(a), _mm512_cvtepu16_epi32(b));
            let sum_cd = _mm512_add_epi32(_mm512_cvtepu16_epi32(c), _mm512_cvtepu16_epi32(d));
            let sum = _mm512_add_epi32(_mm512_add_epi32(sum_ab, sum_cd), two);
            let result = _mm512_srli_epi32(sum, 2);
            // After the shift every lane fits in 16 bits, so truncation is lossless.
            let packed = _mm512_cvtepi32_epi16(result);
            _mm256_storeu_si256(dest.add(offset + i).cast::<__m256i>(), packed);
            i += 16;
        }

        // 8 pixels per step via the helper in the sibling `avx2` module.
        // NOTE(review): presumably it computes the same rounded four-way mean and
        // narrows back to 16-bit lanes — confirm against its definition.
        while i + 8 < width {
            let a = _mm_loadu_si128(src.add(offset + i).cast::<__m128i>());
            let b = _mm_loadu_si128(src.add(offset + i + 1).cast::<__m128i>());
            let c = _mm_loadu_si128(src.add(offset + pitch + i).cast::<__m128i>());
            let d = _mm_loadu_si128(src.add(offset + pitch + i + 1).cast::<__m128i>());
            let result = super::avx2::apply_diagonal_bilinear_u16_avx2(
                _mm256_cvtepu16_epi32(a),
                _mm256_cvtepu16_epi32(b),
                _mm256_cvtepu16_epi32(c),
                _mm256_cvtepu16_epi32(d),
                avx2_two,
            );
            _mm_storeu_si128(dest.add(offset + i).cast::<__m128i>(), result);
            i += 8;
        }

        // Scalar tail for the remaining columns that still have a right-hand neighbour.
        while i < last_col {
            let a = *src.add(offset + i) as u32;
            let b = *src.add(offset + i + 1) as u32;
            let c = *src.add(offset + pitch + i) as u32;
            let d = *src.add(offset + pitch + i + 1) as u32;
            *dest.add(offset + i) = ((a + b + c + d + 2) >> 2) as u16;
            i += 1;
        }

        // Last column: only the pixel below is available.
        let a = *src.add(offset + last_col) as u32;
        let b = *src.add(offset + pitch + last_col) as u32;
        *dest.add(offset + last_col) = ((a + b + 1) >> 1) as u16;

        offset += pitch;
    }

    // `offset` now addresses the last row: horizontal average only.
    for i in 0..last_col {
        let a = *src.add(offset + i) as u32;
        let b = *src.add(offset + i + 1) as u32;
        *dest.add(offset + i) = ((a + b + 1) >> 1) as u16;
    }
    // Bottom-right pixel has neither neighbour.
    *dest.add(offset + last_col) = *src.add(offset + last_col);
}