use std::arch::x86_64::*;
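// A 256-bit vector emulated as a pair of 128-bit SSE registers: the low half
// lives in `.0.0` and the high half in `.0.1`.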
#[derive(Debug, Clone, Copy)]
pub struct SSE256Vector(pub(crate) (__m128i, __m128i));
impl super::Vector for SSE256Vector {
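// SSE4.1 is needed for `_mm_max_epu16` and `_mm_minpos_epu16` below; any CPU
// with SSE4.1 also supports the SSSE3 `_mm_alignr_epi8` used for shifting.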
fn is_available() -> bool {
raw_cpuid::CpuId::new()
.get_feature_info()
.is_some_and(|info| info.has_sse41())
}
#[inline(always)]
unsafe fn zero() -> Self {
unsafe { Self((_mm_setzero_si128(), _mm_setzero_si128())) }
}
#[inline(always)]
unsafe fn splat_u8(value: u8) -> Self {
unsafe { Self((_mm_set1_epi8(value as i8), _mm_set1_epi8(value as i8))) }
}
#[inline(always)]
unsafe fn splat_u16(value: u16) -> Self {
unsafe { Self((_mm_set1_epi16(value as i16), _mm_set1_epi16(value as i16))) }
}
#[inline(always)]
unsafe fn eq_u8(self, other: Self) -> Self {
unsafe {
Self((
_mm_cmpeq_epi8(self.0.0, other.0.0),
_mm_cmpeq_epi8(self.0.1, other.0.1),
))
}
}
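// SSE2 only has signed byte comparisons, so `gt_u8`/`lt_u8` flip the sign bit
// of both operands (an order-preserving map from unsigned to signed range)
// and then compare as signed.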
#[inline(always)]
unsafe fn gt_u8(self, other: Self) -> Self {
unsafe {
let sign_bit = _mm_set1_epi8(-128i8);
let a_0_flipped = _mm_xor_si128(self.0.0, sign_bit);
let b_0_flipped = _mm_xor_si128(other.0.0, sign_bit);
let a_1_flipped = _mm_xor_si128(self.0.1, sign_bit);
let b_1_flipped = _mm_xor_si128(other.0.1, sign_bit);
Self((
_mm_cmpgt_epi8(a_0_flipped, b_0_flipped),
_mm_cmpgt_epi8(a_1_flipped, b_1_flipped),
))
}
}
#[inline(always)]
unsafe fn lt_u8(self, other: Self) -> Self {
unsafe {
let sign_bit = _mm_set1_epi8(-128i8);
let a_0_flipped = _mm_xor_si128(self.0.0, sign_bit);
let b_0_flipped = _mm_xor_si128(other.0.0, sign_bit);
let a_1_flipped = _mm_xor_si128(self.0.1, sign_bit);
let b_1_flipped = _mm_xor_si128(other.0.1, sign_bit);
Self((
_mm_cmplt_epi8(a_0_flipped, b_0_flipped),
_mm_cmplt_epi8(a_1_flipped, b_1_flipped),
))
}
}
#[inline(always)]
unsafe fn max_u16(self, other: Self) -> Self {
unsafe {
Self((
_mm_max_epu16(self.0.0, other.0.0),
_mm_max_epu16(self.0.1, other.0.1),
))
}
}
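// Horizontal maximum over all sixteen u16 lanes. SSE has no horizontal max,
// but `_mm_minpos_epu16` provides a horizontal minimum, so every lane is
// bitwise-complemented first: for unsigned values, min(!x) == !max(x).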
#[inline(always)]
unsafe fn smax_u16(self) -> u16 {
unsafe {
let all_ones = _mm_set1_epi16(-1);
let inverted = (
_mm_xor_si128(self.0.0, all_ones),
_mm_xor_si128(self.0.1, all_ones),
);
let min_pos = (_mm_minpos_epu16(inverted.0), _mm_minpos_epu16(inverted.1));
let min_pos = _mm_min_epu16(min_pos.0, min_pos.1);
let min_val = _mm_extract_epi16::<0>(min_pos) as u16;
!min_val
}
}
#[inline(always)]
unsafe fn add_u16(self, other: Self) -> Self {
unsafe {
Self((
_mm_add_epi16(self.0.0, other.0.0),
_mm_add_epi16(self.0.1, other.0.1),
))
}
}
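// Lane-wise saturating subtraction: lanes clamp at zero instead of wrapping.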
#[inline(always)]
unsafe fn subs_u16(self, other: Self) -> Self {
unsafe {
Self((
_mm_subs_epu16(self.0.0, other.0.0),
_mm_subs_epu16(self.0.1, other.0.1),
))
}
}
#[inline(always)]
unsafe fn and(self, other: Self) -> Self {
unsafe {
Self((
_mm_and_si128(self.0.0, other.0.0),
_mm_and_si128(self.0.1, other.0.1),
))
}
}
#[inline(always)]
unsafe fn or(self, other: Self) -> Self {
unsafe {
Self((
_mm_or_si128(self.0.0, other.0.0),
_mm_or_si128(self.0.1, other.0.1),
))
}
}
#[inline(always)]
unsafe fn not(self) -> Self {
unsafe {
Self((
_mm_xor_si128(self.0.0, _mm_set1_epi32(-1)),
_mm_xor_si128(self.0.1, _mm_set1_epi32(-1)),
))
}
}
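// Shifts the sixteen u16 lanes towards higher indices by `L`, filling the
// vacated low lanes from the top of `other`. `_mm_alignr_epi8` concatenates
// two registers and extracts a byte-shifted window, so each output half is
// stitched from the half that precedes it (`other`'s high half feeds the low
// output, `self`'s low half feeds the high output).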
#[inline(always)]
unsafe fn shift_right_padded_u16<const L: i32>(self, other: Self) -> Self {
unsafe {
const { assert!(L >= 0 && L <= 8) };
macro_rules! impl_shift {
($l:expr) => {
Self((
_mm_alignr_epi8::<{ 16 - $l * 2 }>(self.0.0, other.0.1),
_mm_alignr_epi8::<{ 16 - $l * 2 }>(self.0.1, self.0.0),
))
};
}
match L {
0 => self,
1 => impl_shift!(1),
2 => impl_shift!(2),
3 => impl_shift!(3),
4 => impl_shift!(4),
5 => impl_shift!(5),
6 => impl_shift!(6),
7 => impl_shift!(7),
8 => Self((other.0.1, self.0.0)),
_ => unreachable!(),
}
}
}
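// The test helpers from the base `Vector` trait work on a single 128-bit
// lane: the 16-byte array is duplicated into both halves on load and only the
// low half is read back on store.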
#[cfg(test)]
fn from_array(arr: [u8; 16]) -> Self {
Self((
unsafe { _mm_loadu_si128(arr.as_ptr() as *const __m128i) },
unsafe { _mm_loadu_si128(arr.as_ptr() as *const __m128i) },
))
}
#[cfg(test)]
fn to_array(self) -> [u8; 16] {
let mut arr = [0u8; 16];
unsafe { _mm_storeu_si128(arr.as_mut_ptr() as *mut __m128i, self.0.0) };
arr
}
#[cfg(test)]
fn from_array_u16(arr: [u16; 8]) -> Self {
Self((
unsafe { _mm_loadu_si128(arr.as_ptr() as *const __m128i) },
unsafe { _mm_loadu_si128(arr.as_ptr() as *const __m128i) },
))
}
#[cfg(test)]
fn to_array_u16(self) -> [u16; 8] {
let mut arr = [0u16; 8];
unsafe { _mm_storeu_si128(arr.as_mut_ptr() as *mut __m128i, self.0.0) };
arr
}
}
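// Full-width 256-bit operations from the `Vector256` trait, spanning both halves.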
impl super::Vector256 for SSE256Vector {
#[cfg(test)]
fn from_array_256_u16(arr: [u16; 16]) -> Self {
Self((
unsafe { _mm_loadu_si128(arr.as_ptr() as *const __m128i) },
unsafe { _mm_loadu_si128(arr.as_ptr().add(8) as *const __m128i) },
))
}
#[cfg(test)]
fn to_array_256_u16(self) -> [u16; 16] {
let mut arr = [0u16; 16];
unsafe { _mm_storeu_si128(arr.as_mut_ptr() as *mut __m128i, self.0.0) };
unsafe { _mm_storeu_si128(arr.as_mut_ptr().add(8) as *mut __m128i, self.0.1) };
arr
}
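// Loads 32 bytes into the two 128-bit halves using unaligned loads.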
#[inline(always)]
unsafe fn load_unaligned(data: [u8; 32]) -> Self {
Self((
unsafe { _mm_loadu_si128(data.as_ptr() as *const __m128i) },
unsafe { _mm_loadu_si128(data.as_ptr().add(16) as *const __m128i) },
))
}
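// Index of the first u16 lane equal to `search`. Each compare mask is reduced
// with `movemask`, which sets two bits per matching u16 lane, hence the
// division by two. If neither half matches, `trailing_zeros` on the zero mask
// yields 32 and the result is 16, presumably signalling "no match" to callers.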
#[inline(always)]
unsafe fn idx_u16(self, search: u16) -> usize {
unsafe {
let (cmp_low, cmp_high) = (
_mm_cmpeq_epi16(self.0.0, _mm_set1_epi16(search as i16)),
_mm_cmpeq_epi16(self.0.1, _mm_set1_epi16(search as i16)),
);
let (mask_low, mask_high) = (
_mm_movemask_epi8(cmp_low) as u32,
_mm_movemask_epi8(cmp_high) as u32,
);
let low_trailing = mask_low.trailing_zeros() as usize / 2;
let high_trailing = mask_high.trailing_zeros() as usize / 2 + 8;
low_trailing.min(high_trailing)
}
}
}