#![allow(clippy::undocumented_unsafe_blocks)]
#![allow(unsafe_op_in_unsafe_fn)]
use std::{
arch::x86_64::*,
mem::size_of,
num::{NonZeroU8, NonZeroUsize},
};
use cpudetect::target_family;
use num_traits::clamp;
use crate::util::{Pixel, round_ties_to_even};
use semisafe::option::unwrap as semisafe_opt_unwrap;
use semisafe::slice::get as semisafe_get;
use semisafe::slice::get_mut as semisafe_get_mut;
/// Converts a `size_x` x `size_y` block of `f32` values in `src_dct` into
/// pixels in `dst`.
///
/// The bulk of the work is delegated to the 8-bit or 16-bit kernel, chosen by
/// `size_of::<T>()`; both kernels scale by sqrt(2)/2 and shift by `dct_shift`.
/// Afterwards element 0 of `dst` is recomputed on its own: `src_dct[0]` is
/// scaled by 0.5 instead, rounded, shifted right by `dct_shift0`, offset by
/// half the pixel range and clamped to `0..=(1 << bits_per_sample) - 1`.
///
/// * `dst` / `dst_pitch` - destination pixels and the distance between rows,
///   counted in pixels (the kernels advance a typed pointer by it).
/// * `src_dct` - source values, rows packed back to back (`size_x` per row).
/// * `bits_per_sample` - pixel depth used for the offset and clamp range. The
///   8-bit kernel ignores it and always uses a depth of 8.
///
/// # Safety
///
/// * The CPU must support the instruction set selected by
///   `#[target_family("x86_64_v3")]`; the kernels use AVX2 intrinsics.
/// * `dst` must hold at least `(size_y - 1) * dst_pitch + size_x` pixels. The
///   kernels write through a raw pointer, so this is only verified in debug
///   builds.
/// * `src_dct` is expected to hold at least `size_x * size_y` values.
///   NOTE(review): whether the `semisafe` accessors verify this in release
///   builds is not visible from here - confirm.
/// * `bits_per_sample`, `dct_shift` and `dct_shift0` are used as `i32` shift
///   counts and are presumably expected to stay well below 32.
///
/// # Panics
///
/// Panics if `T` is neither one nor two bytes wide and, in debug builds, if
/// `dst` is too short for the requested block.
#[target_family("x86_64_v3")]
pub(super) unsafe fn float_src_to_pixels<T: Pixel>(
    dst: &mut [T],
    dst_pitch: NonZeroUsize,
    src_dct: &[f32],
    size_x: NonZeroUsize,
    size_y: NonZeroUsize,
    bits_per_sample: NonZeroU8,
    dct_shift: usize,
    dct_shift0: usize,
) {
    // The kernels only receive a raw pointer, so the slice length is lost past
    // this point. Check it here; `None` means the size computation overflowed,
    // which no real slice can satisfy either.
    let required_len = (size_y.get() - 1)
        .checked_mul(dst_pitch.get())
        .and_then(|rows| rows.checked_add(size_x.get()));
    debug_assert!(
        matches!(required_len, Some(len) if len <= dst.len()),
        "dst holds {} pixels, too few for a {}x{} block with pitch {}",
        dst.len(),
        size_x,
        size_y,
        dst_pitch,
    );

    let out = dst.as_mut_ptr();
    match size_of::<T>() {
        1 => float_src_to_pixels_u8(
            out.cast(),
            dst_pitch,
            src_dct,
            size_x,
            size_y,
            bits_per_sample,
            dct_shift,
        ),
        2 => float_src_to_pixels_u16(
            out.cast(),
            dst_pitch,
            src_dct,
            size_x,
            size_y,
            bits_per_sample,
            dct_shift,
        ),
        _ => unreachable!(),
    }

    // Element 0 uses a different scale (0.5) and its own shift, so it is
    // overwritten after the kernel has filled the whole block.
    let depth = usize::from(bits_per_sample.get());
    let pixel_max = (1i32 << depth) - 1;
    let pixel_half = 1i32 << (depth - 1);
    let first = *semisafe_get(src_dct, 0) * 0.5;
    let rounded = round_ties_to_even(first) as i32;
    let value = clamp((rounded >> dct_shift0) + pixel_half, 0, pixel_max);
    *semisafe_get_mut(dst, 0) = semisafe_opt_unwrap(T::from(value));
}
/// Fills a `size_x` x `size_y` block of 8-bit pixels from `src_dct`.
///
/// Every source value is multiplied by sqrt(2)/2, converted to an integer,
/// arithmetically shifted right by `dct_shift`, offset by 128 and clamped to
/// `0..=255`. Source rows are packed (`size_x` values apart); destination rows
/// are `dst_pitch` bytes apart. `_bits_per_sample` is ignored: the depth is
/// fixed at 8.
///
/// Each row is processed eight values at a time with AVX2, followed by a
/// scalar tail for the remaining `size_x % 8` values.
///
/// NOTE(review): the vector path converts with `_mm256_cvtps_epi32` (current
/// MXCSR rounding mode, out-of-range values become `i32::MIN`), while the tail
/// uses `round_ties_to_even` and an `as i32` cast. The two presumably agree
/// for in-range inputs under the default rounding mode - confirm that inputs
/// are bounded.
///
/// # Safety
///
/// * The CPU must support the AVX2 intrinsics used here.
/// * `dst` must be valid for writes of `(size_y - 1) * dst_pitch + size_x`
///   bytes.
/// * `src_dct` is expected to hold at least `size_x * size_y` values.
#[target_family("x86_64_v3")]
unsafe fn float_src_to_pixels_u8(
    dst: *mut u8,
    dst_pitch: NonZeroUsize,
    src_dct: &[f32],
    size_x: NonZeroUsize,
    size_y: NonZeroUsize,
    _bits_per_sample: NonZeroU8,
    dct_shift: usize,
) {
    const LANES: usize = 8;
    const DEPTH: usize = 8;

    let width = size_x.get();
    let height = size_y.get();
    let stride = dst_pitch.get();

    let scale = (2f32).sqrt() / 2.0;
    let bias = 1i32 << (DEPTH - 1);
    let peak = (1i32 << DEPTH) - 1;

    let v_scale = _mm256_set1_ps(scale);
    let v_bias = _mm256_set1_epi32(bias);
    let v_peak = _mm256_set1_epi32(peak);
    let v_floor = _mm256_setzero_si256();
    let v_shift = _mm_cvtsi32_si128(dct_shift as i32);

    // Columns below this index are handled by full eight-wide vectors.
    let vector_end = width - width % LANES;

    for row in 0..height {
        let row_start = row * width;
        let src_row = semisafe_get(src_dct, row_start..row_start + width);
        let out = dst.add(row * stride);

        for col in (0..vector_end).step_by(LANES) {
            let values = _mm256_loadu_ps(src_row.as_ptr().add(col));
            let rounded = _mm256_cvtps_epi32(_mm256_mul_ps(values, v_scale));
            let biased = _mm256_add_epi32(_mm256_sra_epi32(rounded, v_shift), v_bias);
            let bounded = _mm256_max_epi32(_mm256_min_epi32(biased, v_peak), v_floor);

            // The 32 -> 16 bit pack works per 128-bit lane, leaving the eight
            // results in quadwords 0 and 2; the permute moves them side by
            // side into the low 128 bits.
            let words = _mm256_packus_epi32(bounded, v_floor);
            let words = _mm256_permute4x64_epi64(words, 0b11_01_10_00);

            // After the 16 -> 8 bit pack the eight pixels sit in the low 64 bits.
            let bytes = _mm256_packus_epi16(words, v_floor);
            let eight_pixels = _mm_cvtsi128_si64(_mm256_castsi256_si128(bytes)) as u64;
            out.add(col).cast::<u64>().write_unaligned(eight_pixels);
        }

        for col in vector_end..width {
            let scaled = semisafe_get(src_row, col) * scale;
            let rounded = round_ties_to_even(scaled) as i32;
            let pixel = ((rounded >> dct_shift) + bias).clamp(0, peak);
            *out.add(col) = pixel as u8;
        }
    }
}
/// Fills a `size_x` x `size_y` block of 16-bit pixels from `src_dct`.
///
/// Every source value is multiplied by sqrt(2)/2, converted to an integer,
/// arithmetically shifted right by `dct_shift`, offset by
/// `1 << (bits_per_sample - 1)` and clamped to
/// `0..=(1 << bits_per_sample) - 1`. Source rows are packed (`size_x` values
/// apart); destination rows are `dst_pitch` pixels apart.
///
/// Each row is processed eight values at a time with AVX2, followed by a
/// scalar tail for the remaining `size_x % 8` values.
///
/// NOTE(review): the vector path converts with `_mm256_cvtps_epi32` (current
/// MXCSR rounding mode, out-of-range values become `i32::MIN`), while the tail
/// uses `round_ties_to_even` and an `as i32` cast. The two presumably agree
/// for in-range inputs under the default rounding mode - confirm that inputs
/// are bounded.
///
/// # Safety
///
/// * The CPU must support the AVX2 intrinsics used here.
/// * `dst` must be valid for writes of `(size_y - 1) * dst_pitch + size_x`
///   `u16` elements.
/// * `src_dct` is expected to hold at least `size_x * size_y` values.
#[target_family("x86_64_v3")]
unsafe fn float_src_to_pixels_u16(
    dst: *mut u16,
    dst_pitch: NonZeroUsize,
    src_dct: &[f32],
    size_x: NonZeroUsize,
    size_y: NonZeroUsize,
    bits_per_sample: NonZeroU8,
    dct_shift: usize,
) {
    const LANES: usize = 8;

    let width = size_x.get();
    let height = size_y.get();
    let stride = dst_pitch.get();

    let scale = (2f32).sqrt() / 2.0;
    let bias = 1i32 << (bits_per_sample.get() as usize - 1);
    let peak = (1i32 << bits_per_sample.get() as usize) - 1;

    let v_scale = _mm256_set1_ps(scale);
    let v_bias = _mm256_set1_epi32(bias);
    let v_peak = _mm256_set1_epi32(peak);
    let v_floor = _mm256_setzero_si256();
    let v_shift = _mm_cvtsi32_si128(dct_shift as i32);

    // Columns below this index are handled by full eight-wide vectors.
    let vector_end = width - width % LANES;

    for row in 0..height {
        let row_start = row * width;
        let src_row = semisafe_get(src_dct, row_start..row_start + width);
        let out = dst.add(row * stride);

        for col in (0..vector_end).step_by(LANES) {
            let values = _mm256_loadu_ps(src_row.as_ptr().add(col));
            let rounded = _mm256_cvtps_epi32(_mm256_mul_ps(values, v_scale));
            let biased = _mm256_add_epi32(_mm256_sra_epi32(rounded, v_shift), v_bias);
            let bounded = _mm256_max_epi32(_mm256_min_epi32(biased, v_peak), v_floor);

            // The 32 -> 16 bit pack works per 128-bit lane, leaving the eight
            // results in quadwords 0 and 2; the permute moves them side by
            // side into the low 128 bits, which are then stored.
            let words = _mm256_packus_epi32(bounded, v_floor);
            let words = _mm256_permute4x64_epi64(words, 0b11_01_10_00);
            _mm_storeu_si128(out.add(col).cast::<__m128i>(), _mm256_castsi256_si128(words));
        }

        for col in vector_end..width {
            let scaled = semisafe_get(src_row, col) * scale;
            let rounded = round_ties_to_even(scaled) as i32;
            let pixel = ((rounded >> dct_shift) + bias).clamp(0, peak);
            *out.add(col) = pixel as u16;
        }
    }
}