#![allow(clippy::undocumented_unsafe_blocks)]
#![allow(unsafe_op_in_unsafe_fn)]
#[cfg(test)]
use std::mem::size_of;
use std::{arch::x86_64::*, num::NonZeroUsize};
#[cfg(test)]
use crate::util::Pixel;
use semisafe::slice::get_mut as semisafe_get_mut;
/// Test-only entry point that maps a runtime pixel size and block dimensions onto
/// the monomorphized [`get_satd_u8`] / [`get_satd_u16`] kernels.
///
/// 1-byte pixels go to the 8-bit kernel and 2-byte pixels to the 16-bit kernel.
/// Pitches are forwarded unchanged, so they are measured in elements of `T`.
/// Any other pixel size, or any block size not listed below, panics with
/// "unsupported block size".
///
/// # Safety
///
/// The CPU must support AVX2. `src` and `ref_` must each cover `height` rows of
/// `width` pixels, laid out `src_pitch` / `ref_pitch` elements apart.
#[cfg(test)]
#[must_use]
#[target_feature(enable = "avx2")]
pub(super) unsafe fn get_satd<T: Pixel>(
    width: NonZeroUsize,
    height: NonZeroUsize,
    src: &[T],
    src_pitch: NonZeroUsize,
    ref_: &[T],
    ref_pitch: NonZeroUsize,
) -> u64 {
    let src_ptr: *const u8 = src.as_ptr().cast();
    let ref_ptr: *const u8 = ref_.as_ptr().cast();

    // Expands to a `match` on `(width, height)`. A listed size calls `$kernel` with
    // that size as const generics; any other size panics.
    macro_rules! dispatch_block_size {
        ($kernel:ident; $(($w:tt, $h:tt)),+) => {
            match (width.get(), height.get()) {
                $(($w, $h) => $kernel::<$w, $h>(src_ptr, src_pitch, ref_ptr, ref_pitch),)+
                _ => unreachable!("unsupported block size"),
            }
        };
    }

    match size_of::<T>() {
        1 => dispatch_block_size!(
            get_satd_u8;
            (4, 4), (8, 4), (8, 8), (16, 8), (16, 16), (32, 16),
            (32, 32), (64, 32), (64, 64), (128, 64), (128, 128)
        ),
        2 => dispatch_block_size!(
            get_satd_u16;
            (4, 4), (8, 4), (8, 8), (16, 8), (16, 16), (32, 16),
            (32, 32), (64, 32), (64, 64), (128, 64), (128, 128)
        ),
        _ => unreachable!("unsupported block size"),
    }
}
/// Computes the SATD of a `WIDTH`x`HEIGHT` block of 8-bit pixels.
///
/// The block is processed in strips of 4 rows. Each strip is covered left to right
/// by 16x4 tiles, then 8x4 tiles, then a single 4x4 tile if a 4-wide column is
/// left over. That last case covers 4xN blocks and widths that are 4 more than a
/// multiple of 8. Each tile kernel returns half the sum of the absolute 4x4
/// Hadamard coefficients of its tiles, and the results are accumulated into a
/// `u64`.
///
/// # Safety
///
/// - The CPU must support AVX2.
/// - `src` and `ref_` must each be valid for reads of `HEIGHT` rows of `WIDTH`
///   bytes, with consecutive rows `src_pitch` / `ref_pitch` bytes apart.
/// - `WIDTH` and `HEIGHT` must be multiples of 4. This is only checked in debug
///   builds. A `HEIGHT` that is not a multiple of 4 would make the last strip
///   read past the block.
#[must_use]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn get_satd_u8<const WIDTH: usize, const HEIGHT: usize>(
    src: *const u8,
    src_pitch: NonZeroUsize,
    ref_: *const u8,
    ref_pitch: NonZeroUsize,
) -> u64 {
    debug_assert!(
        WIDTH % 4 == 0 && HEIGHT % 4 == 0,
        "SATD block dimensions must be multiples of 4"
    );
    let src_pitch = src_pitch.get();
    let ref_pitch = ref_pitch.get();
    let mut total = 0u64;
    for y in (0..HEIGHT).step_by(4) {
        let mut x = 0;
        while x + 16 <= WIDTH {
            total += satd_16x4_u8(
                src.add(y * src_pitch + x),
                src_pitch,
                ref_.add(y * ref_pitch + x),
                ref_pitch,
            );
            x += 16;
        }
        while x + 8 <= WIDTH {
            total += satd_8x4_u8(
                src.add(y * src_pitch + x),
                src_pitch,
                ref_.add(y * ref_pitch + x),
                ref_pitch,
            );
            x += 8;
        }
        // At most 4..7 columns can remain after the 8-wide loop, so at most one
        // 4-wide column is left. This also handles the plain 4x4 block.
        if x + 4 <= WIDTH {
            total += satd_4x4_u8(
                src.add(y * src_pitch + x),
                src_pitch,
                ref_.add(y * ref_pitch + x),
                ref_pitch,
            );
        }
    }
    total
}
/// Computes the SATD of one 4x4 block of 8-bit pixels: a 4x4 Hadamard transform of
/// the `src - ref_` residual, with the absolute coefficients summed and halved.
///
/// Each row is widened to i16 lanes. Only the low four lanes carry pixels; the
/// upper four are zero for both inputs and contribute nothing.
///
/// # Safety
///
/// Requires AVX2. Each of the 4 rows of `src` and `ref_` (spaced `src_pitch` and
/// `ref_pitch` bytes apart) must be readable for 4 bytes.
#[target_feature(enable = "avx2")]
unsafe fn satd_4x4_u8(src: *const u8, src_pitch: usize, ref_: *const u8, ref_pitch: usize) -> u64 {
    let mut residual = [_mm_setzero_si128(); 4];
    for row in 0..4 {
        let src_px = (src.add(row * src_pitch) as *const u32).read_unaligned() as i32;
        let ref_px = (ref_.add(row * ref_pitch) as *const u32).read_unaligned() as i32;
        *semisafe_get_mut(&mut residual, row) = _mm_sub_epi16(
            _mm_cvtepu8_epi16(_mm_cvtsi32_si128(src_px)),
            _mm_cvtepu8_epi16(_mm_cvtsi32_si128(ref_px)),
        );
    }

    // Vertical 4-point Hadamard across the four rows.
    let [r0, r1, r2, r3] = residual;
    let sum01 = _mm_add_epi16(r0, r1);
    let dif01 = _mm_sub_epi16(r0, r1);
    let sum23 = _mm_add_epi16(r2, r3);
    let dif23 = _mm_sub_epi16(r2, r3);
    let vertical = [
        _mm_add_epi16(sum01, sum23),
        _mm_add_epi16(dif01, dif23),
        _mm_sub_epi16(sum01, sum23),
        _mm_sub_epi16(dif01, dif23),
    ];

    // Horizontal transform per row. `madd` with ones folds lane pairs of |coeff|
    // into i32 accumulators.
    let ones = _mm_set1_epi16(1);
    let mut acc = _mm_setzero_si128();
    for v in vertical {
        let coeffs = hadamard4_horizontal_128(v);
        acc = _mm_add_epi32(acc, _mm_madd_epi16(_mm_abs_epi16(coeffs), ones));
    }
    u64::from(hsum_4x32_128(acc) >> 1)
}
/// Computes the SATD of an 8x4 block of 8-bit pixels, treated as two 4x4 tiles side
/// by side. Each tile is 4x4 Hadamard-transformed, and the combined absolute
/// coefficient sum is halved.
///
/// # Safety
///
/// Requires AVX2. Each of the 4 rows of `src` and `ref_` (spaced `src_pitch` and
/// `ref_pitch` bytes apart) must be readable for 8 bytes.
#[target_feature(enable = "avx2")]
unsafe fn satd_8x4_u8(src: *const u8, src_pitch: usize, ref_: *const u8, ref_pitch: usize) -> u64 {
    let mut residual = [_mm_setzero_si128(); 4];
    for row in 0..4 {
        let src_bytes = _mm_loadl_epi64(src.add(row * src_pitch) as *const __m128i);
        let ref_bytes = _mm_loadl_epi64(ref_.add(row * ref_pitch) as *const __m128i);
        *semisafe_get_mut(&mut residual, row) =
            _mm_sub_epi16(_mm_cvtepu8_epi16(src_bytes), _mm_cvtepu8_epi16(ref_bytes));
    }

    // Vertical 4-point Hadamard across the four rows.
    let [r0, r1, r2, r3] = residual;
    let sum01 = _mm_add_epi16(r0, r1);
    let dif01 = _mm_sub_epi16(r0, r1);
    let sum23 = _mm_add_epi16(r2, r3);
    let dif23 = _mm_sub_epi16(r2, r3);
    let vertical = [
        _mm_add_epi16(sum01, sum23),
        _mm_add_epi16(dif01, dif23),
        _mm_sub_epi16(sum01, sum23),
        _mm_sub_epi16(dif01, dif23),
    ];

    // Horizontal transform per row (one 4-lane group per tile), then accumulate
    // |coeff| in i32 lanes.
    let ones = _mm_set1_epi16(1);
    let mut acc = _mm_setzero_si128();
    for v in vertical {
        let coeffs = hadamard4_horizontal_128(v);
        acc = _mm_add_epi32(acc, _mm_madd_epi16(_mm_abs_epi16(coeffs), ones));
    }
    u64::from(hsum_4x32_128(acc) >> 1)
}
/// Computes the SATD of a 16x4 block of 8-bit pixels (four 4x4 tiles side by side)
/// using 256-bit registers. Each tile is 4x4 Hadamard-transformed, and the
/// combined absolute coefficient sum is halved.
///
/// # Safety
///
/// Requires AVX2. Each of the 4 rows of `src` and `ref_` (spaced `src_pitch` and
/// `ref_pitch` bytes apart) must be readable for 16 bytes.
#[target_feature(enable = "avx2")]
unsafe fn satd_16x4_u8(src: *const u8, src_pitch: usize, ref_: *const u8, ref_pitch: usize) -> u64 {
    // Widen each 16-pixel row to 16 i16 lanes and take the residual.
    let mut rows = [_mm256_setzero_si256(); 4];
    for i in 0..4 {
        let s = _mm256_cvtepu8_epi16(_mm_loadu_si128(src.add(i * src_pitch) as *const __m128i));
        let r = _mm256_cvtepu8_epi16(_mm_loadu_si128(ref_.add(i * ref_pitch) as *const __m128i));
        *semisafe_get_mut(&mut rows, i) = _mm256_sub_epi16(s, r);
    }

    // Vertical 4-point Hadamard across the four rows.
    let m0 = _mm256_add_epi16(rows[0], rows[1]);
    let m1 = _mm256_sub_epi16(rows[0], rows[1]);
    let m2 = _mm256_add_epi16(rows[2], rows[3]);
    let m3 = _mm256_sub_epi16(rows[2], rows[3]);
    rows[0] = _mm256_add_epi16(m0, m2);
    rows[1] = _mm256_add_epi16(m1, m3);
    rows[2] = _mm256_sub_epi16(m0, m2);
    rows[3] = _mm256_sub_epi16(m1, m3);

    // Horizontal transform within each 4-lane group.
    for row in &mut rows {
        *row = hadamard4_horizontal_256(*row);
    }

    // Accumulate |coeff|; `madd` with ones folds lane pairs into i32.
    let ones = _mm256_set1_epi16(1);
    let mut acc = _mm256_setzero_si256();
    for row in rows {
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(_mm256_abs_epi16(row), ones));
    }

    // Fold the two 128-bit halves first so only one horizontal reduction is needed,
    // as in `satd_8x4_u16`. Halving the combined sum equals halving each half
    // separately, for two reasons:
    // - Every 4x4 Hadamard coefficient is a +/-1 combination of the tile's
    //   residuals, so it has the parity of their sum.
    // - Each tile's 16 absolute coefficients therefore add up to an even number.
    let combined = _mm_add_epi32(
        _mm256_castsi256_si128(acc),
        _mm256_extracti128_si256(acc, 1),
    );
    u64::from(hsum_4x32_128(combined) >> 1)
}
/// Applies an unnormalized 4-point Hadamard transform independently to each group
/// of four consecutive i16 lanes (lanes 0-3 and 4-7), with wrapping arithmetic.
///
/// For a group `[a0, a1, a2, a3]` the output is
/// `[a0+a1+a2+a3, (a1-a0)+(a3-a2), (a2+a3)-(a0+a1), (a3-a2)-(a1-a0)]`.
#[target_feature(enable = "avx2")]
unsafe fn hadamard4_horizontal_128(row: __m128i) -> __m128i {
    // Byte shuffle that swaps each pair of adjacent 16-bit lanes.
    let pair_swap = _mm_setr_epi8(2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13);

    // Stage 1: distance-1 butterflies; odd lanes keep the difference.
    let partner = _mm_shuffle_epi8(row, pair_swap);
    let stage1 = _mm_blend_epi16(
        _mm_add_epi16(row, partner),
        _mm_sub_epi16(row, partner),
        0b1010_1010,
    );

    // Stage 2: distance-2 butterflies by swapping the 32-bit halves of each 64-bit
    // group; lanes 2 and 3 of every group keep the difference.
    let partner = _mm_shuffle_epi32(stage1, 0b10_11_00_01);
    _mm_blend_epi16(
        _mm_add_epi16(stage1, partner),
        _mm_sub_epi16(stage1, partner),
        0b1100_1100,
    )
}
/// Applies an unnormalized 4-point Hadamard transform independently to each group
/// of four consecutive i16 lanes (four groups across the two 128-bit halves), with
/// wrapping arithmetic.
///
/// For a group `[a0, a1, a2, a3]` the output is
/// `[a0+a1+a2+a3, (a1-a0)+(a3-a2), (a2+a3)-(a0+a1), (a3-a2)-(a1-a0)]`.
#[target_feature(enable = "avx2")]
unsafe fn hadamard4_horizontal_256(row: __m256i) -> __m256i {
    // Byte shuffle that swaps each pair of adjacent 16-bit lanes. `vpshufb` indexes
    // within each 128-bit half, so the same 16-byte pattern appears twice.
    let pair_swap = _mm256_setr_epi8(
        2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13, //
        2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13,
    );

    // Stage 1: distance-1 butterflies; odd lanes keep the difference.
    let partner = _mm256_shuffle_epi8(row, pair_swap);
    let stage1 = _mm256_blend_epi16(
        _mm256_add_epi16(row, partner),
        _mm256_sub_epi16(row, partner),
        0b1010_1010,
    );

    // Stage 2: distance-2 butterflies by swapping the 32-bit halves of each 64-bit
    // group; lanes 2 and 3 of every group keep the difference.
    let partner = _mm256_shuffle_epi32(stage1, 0b10_11_00_01);
    _mm256_blend_epi16(
        _mm256_add_epi16(stage1, partner),
        _mm256_sub_epi16(stage1, partner),
        0b1100_1100,
    )
}
/// Returns the wrapping sum of the four i32 lanes of `v`, reinterpreted as `u32`.
#[target_feature(enable = "avx2")]
unsafe fn hsum_4x32_128(v: __m128i) -> u32 {
    // Swap the 64-bit halves and add, so lane 0 holds v0+v2 and lane 1 holds v1+v3.
    let folded64 = _mm_add_epi32(v, _mm_shuffle_epi32(v, 0b01_00_11_10));
    // Swap adjacent 32-bit lanes and add, so lane 0 holds (v0+v2)+(v1+v3).
    let folded32 = _mm_add_epi32(folded64, _mm_shuffle_epi32(folded64, 0b10_11_00_01));
    _mm_cvtsi128_si32(folded32) as u32
}
/// Applies an unnormalized 4-point Hadamard transform to the four i32 lanes of
/// `row`, with wrapping arithmetic.
///
/// For input `[a0, a1, a2, a3]` the output is
/// `[a0+a1+a2+a3, (a1-a0)+(a3-a2), (a2+a3)-(a0+a1), (a3-a2)-(a1-a0)]`.
#[target_feature(enable = "avx2")]
unsafe fn hadamard4_horizontal_i32_128(row: __m128i) -> __m128i {
    // Stage 1: butterflies on lane pairs (0,1) and (2,3); the odd lane keeps the
    // difference.
    let partner = _mm_shuffle_epi32(row, 0b10_11_00_01);
    let stage1 = _mm_blend_epi32(
        _mm_add_epi32(row, partner),
        _mm_sub_epi32(row, partner),
        0b1010,
    );

    // Stage 2: butterflies on lane pairs (0,2) and (1,3); lanes 2 and 3 keep the
    // difference.
    let partner = _mm_shuffle_epi32(stage1, 0b01_00_11_10);
    _mm_blend_epi32(
        _mm_add_epi32(stage1, partner),
        _mm_sub_epi32(stage1, partner),
        0b1100,
    )
}
/// Applies an unnormalized 4-point Hadamard transform independently to each
/// 128-bit half of `row` (four i32 lanes each), with wrapping arithmetic.
///
/// For a half `[a0, a1, a2, a3]` the output is
/// `[a0+a1+a2+a3, (a1-a0)+(a3-a2), (a2+a3)-(a0+a1), (a3-a2)-(a1-a0)]`.
#[target_feature(enable = "avx2")]
unsafe fn hadamard4_horizontal_i32_256(row: __m256i) -> __m256i {
    // Stage 1: butterflies on adjacent lane pairs; odd lanes keep the difference.
    let partner = _mm256_shuffle_epi32(row, 0b10_11_00_01);
    let stage1 = _mm256_blend_epi32(
        _mm256_add_epi32(row, partner),
        _mm256_sub_epi32(row, partner),
        0b1010_1010,
    );

    // Stage 2: butterflies between lanes two apart within each half; lanes 2 and 3
    // of each half keep the difference.
    let partner = _mm256_shuffle_epi32(stage1, 0b01_00_11_10);
    _mm256_blend_epi32(
        _mm256_add_epi32(stage1, partner),
        _mm256_sub_epi32(stage1, partner),
        0b1100_1100,
    )
}
/// Computes the SATD of a `WIDTH`x`HEIGHT` block of 16-bit samples.
///
/// The pointers arrive as `*const u8` and are reinterpreted as `*const u16`, so the
/// pitches are measured in 16-bit samples, not bytes. The block is processed in
/// strips of 4 rows. Each strip is covered by 8x4 tiles, followed by a single 4x4
/// tile when a 4-wide column is left over (4xN blocks, or widths that are 4 more
/// than a multiple of 8). Each tile kernel works in 32-bit lanes and returns half
/// the sum of the absolute 4x4 Hadamard coefficients of its tiles.
///
/// # Safety
///
/// - The CPU must support AVX2.
/// - `src` and `ref_` must each be valid for reads of `HEIGHT` rows of `WIDTH`
///   `u16` samples, with consecutive rows `src_pitch` / `ref_pitch` samples apart.
/// - `WIDTH` and `HEIGHT` must be multiples of 4. This is only checked in debug
///   builds. A `HEIGHT` that is not a multiple of 4 would make the last strip
///   read past the block.
#[must_use]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn get_satd_u16<const WIDTH: usize, const HEIGHT: usize>(
    src: *const u8,
    src_pitch: NonZeroUsize,
    ref_: *const u8,
    ref_pitch: NonZeroUsize,
) -> u64 {
    debug_assert!(
        WIDTH % 4 == 0 && HEIGHT % 4 == 0,
        "SATD block dimensions must be multiples of 4"
    );
    let src: *const u16 = src.cast();
    let ref_: *const u16 = ref_.cast();
    let src_pitch = src_pitch.get();
    let ref_pitch = ref_pitch.get();
    let mut total = 0u64;
    for y in (0..HEIGHT).step_by(4) {
        let mut x = 0;
        while x + 8 <= WIDTH {
            total += satd_8x4_u16(
                src.add(y * src_pitch + x),
                src_pitch,
                ref_.add(y * ref_pitch + x),
                ref_pitch,
            );
            x += 8;
        }
        // At most 4..7 columns can remain after the 8-wide loop, so at most one
        // 4-wide column is left. This also handles the plain 4x4 block.
        if x + 4 <= WIDTH {
            total += satd_4x4_u16(
                src.add(y * src_pitch + x),
                src_pitch,
                ref_.add(y * ref_pitch + x),
                ref_pitch,
            );
        }
    }
    total
}
/// Computes the SATD of one 4x4 block of 16-bit samples: a 4x4 Hadamard transform
/// of the `src - ref_` residual in 32-bit lanes, with the absolute coefficients
/// summed and halved.
///
/// # Safety
///
/// Requires AVX2. Each of the 4 rows of `src` and `ref_` (spaced `src_pitch` and
/// `ref_pitch` samples apart) must be readable for 4 samples (8 bytes).
#[target_feature(enable = "avx2")]
unsafe fn satd_4x4_u16(
    src: *const u16,
    src_pitch: usize,
    ref_: *const u16,
    ref_pitch: usize,
) -> u64 {
    let mut residual = [_mm_setzero_si128(); 4];
    for row in 0..4 {
        let src_px = _mm_loadl_epi64(src.add(row * src_pitch) as *const __m128i);
        let ref_px = _mm_loadl_epi64(ref_.add(row * ref_pitch) as *const __m128i);
        *semisafe_get_mut(&mut residual, row) =
            _mm_sub_epi32(_mm_cvtepu16_epi32(src_px), _mm_cvtepu16_epi32(ref_px));
    }

    // Vertical 4-point Hadamard across the four rows.
    let [r0, r1, r2, r3] = residual;
    let sum01 = _mm_add_epi32(r0, r1);
    let dif01 = _mm_sub_epi32(r0, r1);
    let sum23 = _mm_add_epi32(r2, r3);
    let dif23 = _mm_sub_epi32(r2, r3);
    let vertical = [
        _mm_add_epi32(sum01, sum23),
        _mm_add_epi32(dif01, dif23),
        _mm_sub_epi32(sum01, sum23),
        _mm_sub_epi32(dif01, dif23),
    ];

    // Horizontal transform per row, accumulating |coeff| directly in i32 lanes.
    let mut acc = _mm_setzero_si128();
    for v in vertical {
        acc = _mm_add_epi32(acc, _mm_abs_epi32(hadamard4_horizontal_i32_128(v)));
    }
    u64::from(hsum_4x32_128(acc) >> 1)
}
/// Computes the SATD of an 8x4 block of 16-bit samples (two 4x4 tiles side by side)
/// using 32-bit lanes in 256-bit registers. The combined absolute coefficient sum
/// is halved.
///
/// # Safety
///
/// Requires AVX2. Each of the 4 rows of `src` and `ref_` (spaced `src_pitch` and
/// `ref_pitch` samples apart) must be readable for 8 samples (16 bytes).
#[target_feature(enable = "avx2")]
unsafe fn satd_8x4_u16(
    src: *const u16,
    src_pitch: usize,
    ref_: *const u16,
    ref_pitch: usize,
) -> u64 {
    let mut residual = [_mm256_setzero_si256(); 4];
    for row in 0..4 {
        let src_px = _mm_loadu_si128(src.add(row * src_pitch) as *const __m128i);
        let ref_px = _mm_loadu_si128(ref_.add(row * ref_pitch) as *const __m128i);
        *semisafe_get_mut(&mut residual, row) =
            _mm256_sub_epi32(_mm256_cvtepu16_epi32(src_px), _mm256_cvtepu16_epi32(ref_px));
    }

    // Vertical 4-point Hadamard across the four rows.
    let [r0, r1, r2, r3] = residual;
    let sum01 = _mm256_add_epi32(r0, r1);
    let dif01 = _mm256_sub_epi32(r0, r1);
    let sum23 = _mm256_add_epi32(r2, r3);
    let dif23 = _mm256_sub_epi32(r2, r3);
    let vertical = [
        _mm256_add_epi32(sum01, sum23),
        _mm256_add_epi32(dif01, dif23),
        _mm256_sub_epi32(sum01, sum23),
        _mm256_sub_epi32(dif01, dif23),
    ];

    // Horizontal transform per row (one tile per 128-bit half), accumulating |coeff|.
    let mut acc = _mm256_setzero_si256();
    for v in vertical {
        acc = _mm256_add_epi32(acc, _mm256_abs_epi32(hadamard4_horizontal_i32_256(v)));
    }

    // Fold the high 128-bit half onto the low half, then reduce the remaining 4 lanes.
    let folded = _mm_add_epi32(
        _mm256_castsi256_si128(acc),
        _mm256_extracti128_si256(acc, 1),
    );
    u64::from(hsum_4x32_128(folded) >> 1)
}