#![allow(clippy::undocumented_unsafe_blocks)]
#![allow(unsafe_op_in_unsafe_fn)]
#[cfg(test)]
use std::mem::size_of;
use std::{arch::x86_64::*, num::NonZeroUsize};
#[cfg(test)]
use crate::util::Pixel;
#[cfg(test)]
#[must_use]
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
/// Test-only dispatcher: sum of absolute differences between a
/// `width`×`height` block of `src` and the co-located block of `ref_`.
///
/// Forwards to the monomorphized [`get_sad_u8`] / [`get_sad_u16`] kernel chosen
/// by the pixel size (`size_of::<T>()`) and the block dimensions. Pitches are
/// measured in pixels (elements of `T`), not bytes.
///
/// # Panics
///
/// Panics if either slice is too short to contain the block at its pitch
/// (`(height - 1) * pitch + width` elements), or if the
/// `(size_of::<T>(), width, height)` combination has no kernel.
///
/// # Safety
///
/// The running CPU must support every target feature enabled on this function.
pub(super) unsafe fn get_sad<T: Pixel>(
    width: NonZeroUsize,
    height: NonZeroUsize,
    src: &[T],
    src_pitch: NonZeroUsize,
    ref_: &[T],
    ref_pitch: NonZeroUsize,
) -> u64 {
    // The kernels read through raw pointers, so validate the slices up front:
    // the last row starts at `(height - 1) * pitch` and spans `width` pixels.
    let required = |pitch: NonZeroUsize| (height.get() - 1) * pitch.get() + width.get();
    assert!(
        src.len() >= required(src_pitch),
        "src slice too short for block"
    );
    assert!(
        ref_.len() >= required(ref_pitch),
        "ref slice too short for block"
    );

    // Both kernel families take byte pointers; the u16 kernels re-cast internally.
    let s = src.as_ptr().cast::<u8>();
    let r = ref_.as_ptr().cast::<u8>();
    match (size_of::<T>(), width.get(), height.get()) {
        (1, 2, 2) => get_sad_u8::<2, 2>(s, src_pitch, r, ref_pitch),
        (1, 2, 4) => get_sad_u8::<2, 4>(s, src_pitch, r, ref_pitch),
        (1, 4, 2) => get_sad_u8::<4, 2>(s, src_pitch, r, ref_pitch),
        (1, 4, 4) => get_sad_u8::<4, 4>(s, src_pitch, r, ref_pitch),
        (1, 4, 8) => get_sad_u8::<4, 8>(s, src_pitch, r, ref_pitch),
        (1, 8, 1) => get_sad_u8::<8, 1>(s, src_pitch, r, ref_pitch),
        (1, 8, 2) => get_sad_u8::<8, 2>(s, src_pitch, r, ref_pitch),
        (1, 8, 4) => get_sad_u8::<8, 4>(s, src_pitch, r, ref_pitch),
        (1, 8, 8) => get_sad_u8::<8, 8>(s, src_pitch, r, ref_pitch),
        (1, 8, 16) => get_sad_u8::<8, 16>(s, src_pitch, r, ref_pitch),
        (1, 16, 1) => get_sad_u8::<16, 1>(s, src_pitch, r, ref_pitch),
        (1, 16, 2) => get_sad_u8::<16, 2>(s, src_pitch, r, ref_pitch),
        (1, 16, 4) => get_sad_u8::<16, 4>(s, src_pitch, r, ref_pitch),
        (1, 16, 8) => get_sad_u8::<16, 8>(s, src_pitch, r, ref_pitch),
        (1, 16, 16) => get_sad_u8::<16, 16>(s, src_pitch, r, ref_pitch),
        (1, 16, 32) => get_sad_u8::<16, 32>(s, src_pitch, r, ref_pitch),
        (1, 32, 8) => get_sad_u8::<32, 8>(s, src_pitch, r, ref_pitch),
        (1, 32, 16) => get_sad_u8::<32, 16>(s, src_pitch, r, ref_pitch),
        (1, 32, 32) => get_sad_u8::<32, 32>(s, src_pitch, r, ref_pitch),
        (1, 32, 64) => get_sad_u8::<32, 64>(s, src_pitch, r, ref_pitch),
        (1, 64, 16) => get_sad_u8::<64, 16>(s, src_pitch, r, ref_pitch),
        (1, 64, 32) => get_sad_u8::<64, 32>(s, src_pitch, r, ref_pitch),
        (1, 64, 64) => get_sad_u8::<64, 64>(s, src_pitch, r, ref_pitch),
        (1, 64, 128) => get_sad_u8::<64, 128>(s, src_pitch, r, ref_pitch),
        (1, 128, 32) => get_sad_u8::<128, 32>(s, src_pitch, r, ref_pitch),
        (1, 128, 64) => get_sad_u8::<128, 64>(s, src_pitch, r, ref_pitch),
        (1, 128, 128) => get_sad_u8::<128, 128>(s, src_pitch, r, ref_pitch),
        (2, 2, 2) => get_sad_u16::<2, 2>(s, src_pitch, r, ref_pitch),
        (2, 2, 4) => get_sad_u16::<2, 4>(s, src_pitch, r, ref_pitch),
        (2, 4, 2) => get_sad_u16::<4, 2>(s, src_pitch, r, ref_pitch),
        (2, 4, 4) => get_sad_u16::<4, 4>(s, src_pitch, r, ref_pitch),
        (2, 4, 8) => get_sad_u16::<4, 8>(s, src_pitch, r, ref_pitch),
        (2, 8, 1) => get_sad_u16::<8, 1>(s, src_pitch, r, ref_pitch),
        (2, 8, 2) => get_sad_u16::<8, 2>(s, src_pitch, r, ref_pitch),
        (2, 8, 4) => get_sad_u16::<8, 4>(s, src_pitch, r, ref_pitch),
        (2, 8, 8) => get_sad_u16::<8, 8>(s, src_pitch, r, ref_pitch),
        (2, 8, 16) => get_sad_u16::<8, 16>(s, src_pitch, r, ref_pitch),
        (2, 16, 1) => get_sad_u16::<16, 1>(s, src_pitch, r, ref_pitch),
        (2, 16, 2) => get_sad_u16::<16, 2>(s, src_pitch, r, ref_pitch),
        (2, 16, 4) => get_sad_u16::<16, 4>(s, src_pitch, r, ref_pitch),
        (2, 16, 8) => get_sad_u16::<16, 8>(s, src_pitch, r, ref_pitch),
        (2, 16, 16) => get_sad_u16::<16, 16>(s, src_pitch, r, ref_pitch),
        (2, 16, 32) => get_sad_u16::<16, 32>(s, src_pitch, r, ref_pitch),
        (2, 32, 8) => get_sad_u16::<32, 8>(s, src_pitch, r, ref_pitch),
        (2, 32, 16) => get_sad_u16::<32, 16>(s, src_pitch, r, ref_pitch),
        (2, 32, 32) => get_sad_u16::<32, 32>(s, src_pitch, r, ref_pitch),
        (2, 32, 64) => get_sad_u16::<32, 64>(s, src_pitch, r, ref_pitch),
        (2, 64, 16) => get_sad_u16::<64, 16>(s, src_pitch, r, ref_pitch),
        (2, 64, 32) => get_sad_u16::<64, 32>(s, src_pitch, r, ref_pitch),
        (2, 64, 64) => get_sad_u16::<64, 64>(s, src_pitch, r, ref_pitch),
        (2, 64, 128) => get_sad_u16::<64, 128>(s, src_pitch, r, ref_pitch),
        (2, 128, 32) => get_sad_u16::<128, 32>(s, src_pitch, r, ref_pitch),
        (2, 128, 64) => get_sad_u16::<128, 64>(s, src_pitch, r, ref_pitch),
        (2, 128, 128) => get_sad_u16::<128, 128>(s, src_pitch, r, ref_pitch),
        _ => unreachable!("unsupported block size"),
    }
}
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
/// Adds the eight 64-bit lanes of `sum` together and returns the total.
///
/// Lane addition wraps on overflow. The signed result of the reduction is
/// reinterpreted bit-for-bit as `u64`.
unsafe fn horizontal_sum_u64_zmm(sum: __m512i) -> u64 {
    let total: i64 = _mm512_reduce_add_epi64(sum);
    total as u64
}
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
/// Zero-extends the sixteen 32-bit lanes of `sum` into two vectors of eight
/// 64-bit lanes.
///
/// Within every 128-bit lane, u32 elements 0 and 1 land in the first returned
/// vector and elements 2 and 3 in the second. Every input element appears
/// exactly once across the pair, so adding the two vectors preserves the total.
unsafe fn widen_u32_pairwise_to_u64(sum: __m512i) -> (__m512i, __m512i) {
    // Interleaving with all-zero dwords places a zero in the upper half of each
    // resulting qword, i.e. an unsigned 32 -> 64 bit extension.
    let zeros = _mm512_setzero_si512();
    let low_pairs = _mm512_unpacklo_epi32(sum, zeros);
    let high_pairs = _mm512_unpackhi_epi32(sum, zeros);
    (low_pairs, high_pairs)
}
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
/// Returns the total of all sixteen unsigned 32-bit lanes of `sum` as a `u64`.
///
/// The lanes are zero-extended to 64 bits before being added, so the total
/// itself cannot overflow 32 bits.
unsafe fn sum_u32_to_u64(sum: __m512i) -> u64 {
    let widened = widen_u32_pairwise_to_u64(sum);
    let combined = _mm512_add_epi64(widened.0, widened.1);
    horizontal_sum_u64_zmm(combined)
}
#[must_use]
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
/// Sum of absolute differences between two `WIDTH`×`HEIGHT` blocks of 8-bit
/// pixels, computed with AVX-512BW `vpsadbw`.
///
/// Widths below 32 are delegated to the AVX2 implementation. A width of 32
/// packs two rows into each zmm register, with a half-register tail when
/// `HEIGHT` is odd. Widths of 64 and 128 process one full row per iteration.
///
/// # Safety
///
/// The running CPU must support the enabled target features. `src` and `ref_`
/// must be valid for reading `HEIGHT` rows of `WIDTH` bytes, with consecutive
/// rows `src_pitch` / `ref_pitch` bytes apart. For `WIDTH >= 32`, `WIDTH` must
/// be 32, 64 or 128.
pub(crate) unsafe fn get_sad_u8<const WIDTH: usize, const HEIGHT: usize>(
    src: *const u8,
    src_pitch: NonZeroUsize,
    ref_: *const u8,
    ref_pitch: NonZeroUsize,
) -> u64 {
    if WIDTH < 32 {
        return crate::sad::avx2::get_sad_u8::<WIDTH, HEIGHT>(src, src_pitch, ref_, ref_pitch);
    }
    let src_pitch = src_pitch.get();
    let ref_pitch = ref_pitch.get();
    if WIDTH == 32 {
        let mut acc = _mm512_setzero_si512();
        // Rows are consumed in pairs. An odd final row is handled separately
        // below, so the loop never loads the row just past the block.
        let full_rows = HEIGHT / 2 * 2;
        for j in (0..full_rows).step_by(2) {
            let src0 = _mm256_loadu_si256(src.add(j * src_pitch).cast::<__m256i>());
            let src1 = _mm256_loadu_si256(src.add((j + 1) * src_pitch).cast::<__m256i>());
            let ref0 = _mm256_loadu_si256(ref_.add(j * ref_pitch).cast::<__m256i>());
            let ref1 = _mm256_loadu_si256(ref_.add((j + 1) * ref_pitch).cast::<__m256i>());
            let src_pair = _mm512_inserti64x4(_mm512_castsi256_si512(src0), src1, 1);
            let ref_pair = _mm512_inserti64x4(_mm512_castsi256_si512(ref0), ref1, 1);
            acc = _mm512_add_epi64(acc, _mm512_sad_epu8(src_pair, ref_pair));
        }
        if full_rows != HEIGHT {
            let src_row = _mm256_loadu_si256(src.add(full_rows * src_pitch).cast::<__m256i>());
            let ref_row = _mm256_loadu_si256(ref_.add(full_rows * ref_pitch).cast::<__m256i>());
            // Zero-extend so the upper 256 bits contribute nothing to the sum.
            acc = _mm512_add_epi64(
                acc,
                _mm512_zextsi256_si512(_mm256_sad_epu8(src_row, ref_row)),
            );
        }
        return horizontal_sum_u64_zmm(acc);
    }
    if WIDTH == 64 {
        // Two accumulators alternate by row parity, presumably to split the
        // chain of dependent vector adds.
        let mut acc0 = _mm512_setzero_si512();
        let mut acc1 = _mm512_setzero_si512();
        for j in 0..HEIGHT {
            let sad = _mm512_sad_epu8(
                _mm512_loadu_si512(src.add(j * src_pitch).cast::<__m512i>()),
                _mm512_loadu_si512(ref_.add(j * ref_pitch).cast::<__m512i>()),
            );
            if j % 2 == 0 {
                acc0 = _mm512_add_epi64(acc0, sad);
            } else {
                acc1 = _mm512_add_epi64(acc1, sad);
            }
        }
        return horizontal_sum_u64_zmm(_mm512_add_epi64(acc0, acc1));
    }
    debug_assert_eq!(WIDTH, 128);
    // Each 128-byte row is two zmm loads. Even and odd rows feed separate
    // accumulator pairs.
    let mut acc0 = _mm512_setzero_si512();
    let mut acc1 = _mm512_setzero_si512();
    let mut acc2 = _mm512_setzero_si512();
    let mut acc3 = _mm512_setzero_si512();
    for j in 0..HEIGHT {
        let src_row = src.add(j * src_pitch);
        let ref_row = ref_.add(j * ref_pitch);
        let sad0 = _mm512_sad_epu8(
            _mm512_loadu_si512(src_row.cast::<__m512i>()),
            _mm512_loadu_si512(ref_row.cast::<__m512i>()),
        );
        let sad1 = _mm512_sad_epu8(
            _mm512_loadu_si512(src_row.add(64).cast::<__m512i>()),
            _mm512_loadu_si512(ref_row.add(64).cast::<__m512i>()),
        );
        if j % 2 == 0 {
            acc0 = _mm512_add_epi64(acc0, sad0);
            acc1 = _mm512_add_epi64(acc1, sad1);
        } else {
            acc2 = _mm512_add_epi64(acc2, sad0);
            acc3 = _mm512_add_epi64(acc3, sad1);
        }
    }
    horizontal_sum_u64_zmm(_mm512_add_epi64(
        _mm512_add_epi64(acc0, acc1),
        _mm512_add_epi64(acc2, acc3),
    ))
}
#[must_use]
#[target_feature(enable = "avx2,avx512f,avx512cd,avx512vl,avx512dq,avx512bw")]
/// Sum of absolute differences between two `WIDTH`×`HEIGHT` blocks of 16-bit
/// pixels, computed with AVX-512BW.
///
/// `src` and `ref_` are passed as byte pointers but are read as `u16` samples,
/// and both pitches count samples rather than bytes. Widths below 16 are
/// delegated to the AVX2 implementation.
///
/// Per-sample differences are zero-extended into 32-bit lanes for
/// accumulation. They are widened to 64 bits only for the final reduction.
///
/// # Safety
///
/// The running CPU must support the enabled target features. Both pointers
/// must be valid for reading `HEIGHT` rows of `WIDTH` samples spaced by their
/// pitch. For `WIDTH > 16`, rows are read in 32-sample chunks, so `WIDTH`
/// should be a multiple of 32 to avoid reading past the row.
pub(crate) unsafe fn get_sad_u16<const WIDTH: usize, const HEIGHT: usize>(
    src: *const u8,
    src_pitch: NonZeroUsize,
    ref_: *const u8,
    ref_pitch: NonZeroUsize,
) -> u64 {
    if WIDTH < 16 {
        return crate::sad::avx2::get_sad_u16::<WIDTH, HEIGHT>(src, src_pitch, ref_, ref_pitch);
    }

    let src_px = src.cast::<u16>();
    let ref_px = ref_.cast::<u16>();
    let src_stride = src_pitch.get();
    let ref_stride = ref_pitch.get();
    let zero = _mm512_setzero_si512();
    // Low / high halves of each 128-bit lane's u16 differences, widened to u32.
    let mut sums_a = _mm512_setzero_si512();
    let mut sums_b = _mm512_setzero_si512();

    if WIDTH == 16 {
        // A 16-sample row fills one ymm register, so two rows share a zmm.
        let mut row = 0;
        while row + 1 < HEIGHT {
            let src_top = _mm256_loadu_si256(src_px.add(row * src_stride).cast());
            let src_bottom = _mm256_loadu_si256(src_px.add((row + 1) * src_stride).cast());
            let ref_top = _mm256_loadu_si256(ref_px.add(row * ref_stride).cast());
            let ref_bottom = _mm256_loadu_si256(ref_px.add((row + 1) * ref_stride).cast());
            let s = _mm512_inserti64x4(_mm512_castsi256_si512(src_top), src_bottom, 1);
            let r = _mm512_inserti64x4(_mm512_castsi256_si512(ref_top), ref_bottom, 1);
            // |s - r| for unsigned samples without overflow: max - min.
            let diff = _mm512_sub_epi16(_mm512_max_epu16(s, r), _mm512_min_epu16(s, r));
            sums_a = _mm512_add_epi32(sums_a, _mm512_unpacklo_epi16(diff, zero));
            sums_b = _mm512_add_epi32(sums_b, _mm512_unpackhi_epi16(diff, zero));
            row += 2;
        }
        // Odd HEIGHT: one row remains. Its upper 256 bits are zero-filled.
        if row < HEIGHT {
            let s = _mm256_loadu_si256(src_px.add(row * src_stride).cast());
            let r = _mm256_loadu_si256(ref_px.add(row * ref_stride).cast());
            let diff = _mm512_zextsi256_si512(_mm256_sub_epi16(
                _mm256_max_epu16(s, r),
                _mm256_min_epu16(s, r),
            ));
            sums_a = _mm512_add_epi32(sums_a, _mm512_unpacklo_epi16(diff, zero));
            sums_b = _mm512_add_epi32(sums_b, _mm512_unpackhi_epi16(diff, zero));
        }
    } else {
        for row in 0..HEIGHT {
            let src_line = src_px.add(row * src_stride);
            let ref_line = ref_px.add(row * ref_stride);
            // 32 u16 samples per zmm load.
            for col in (0..WIDTH).step_by(32) {
                let s = _mm512_loadu_si512(src_line.add(col).cast::<__m512i>());
                let r = _mm512_loadu_si512(ref_line.add(col).cast::<__m512i>());
                let diff = _mm512_sub_epi16(_mm512_max_epu16(s, r), _mm512_min_epu16(s, r));
                sums_a = _mm512_add_epi32(sums_a, _mm512_unpacklo_epi16(diff, zero));
                sums_b = _mm512_add_epi32(sums_b, _mm512_unpackhi_epi16(diff, zero));
            }
        }
    }

    sum_u32_to_u64(_mm512_add_epi32(sums_a, sums_b))
}