neo_frizbee 0.9.1

Fast, typo-resistant fuzzy matching via SIMD Smith-Waterman, using an algorithm similar to FZF/FZY.
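For orientation before the SIMD internals, a hedged usage sketch (assuming this fork keeps the upstream frizbee entry points: `match_list`, `Options`, and the `Match` fields `index_in_haystack` and `score` are assumed names, not confirmed by this listing):

use neo_frizbee::{match_list, Options};

fn demo() {
    // Score candidates against a (possibly typo'd) needle; typo resistance
    // comes from the Smith-Waterman scoring run on the vectors below.
    let haystacks = ["print", "println", "sprintf", "parent"];
    // NOTE: `match_list`, `Options`, `index_in_haystack`, and `score` are
    // assumed from the upstream frizbee API and may differ in this fork.
    let matches = match_list("prnt", &haystacks, Options::default());
    for m in matches {
        println!("{}: {}", haystacks[m.index_in_haystack as usize], m.score);
    }
}

The source below implements the SSE backend of the crate's SIMD vector abstraction: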
use std::arch::x86_64::*;

/// 256-bit vector emulated with two internal 128-bit SSE vectors; requires SSE4.1 (see `is_available`)
#[derive(Debug, Clone, Copy)]
pub struct SSE256Vector(pub(crate) (__m128i, __m128i));

impl super::Vector for SSE256Vector {
    fn is_available() -> bool {
        // SSE4.1 is required for PHMINPOSUW (`_mm_minpos_epu16`) and
        // PMAXUW (`_mm_max_epu16`) used below.
        raw_cpuid::CpuId::new()
            .get_feature_info()
            .is_some_and(|info| info.has_sse41())
    }

    #[inline(always)]
    unsafe fn zero() -> Self {
        unsafe { Self((_mm_setzero_si128(), _mm_setzero_si128())) }
    }

    #[inline(always)]
    unsafe fn splat_u8(value: u8) -> Self {
        unsafe { Self((_mm_set1_epi8(value as i8), _mm_set1_epi8(value as i8))) }
    }

    #[inline(always)]
    unsafe fn splat_u16(value: u16) -> Self {
        unsafe { Self((_mm_set1_epi16(value as i16), _mm_set1_epi16(value as i16))) }
    }

    #[inline(always)]
    unsafe fn eq_u8(self, other: Self) -> Self {
        unsafe {
            Self((
                _mm_cmpeq_epi8(self.0.0, other.0.0),
                _mm_cmpeq_epi8(self.0.1, other.0.1),
            ))
        }
    }

    #[inline(always)]
    unsafe fn gt_u8(self, other: Self) -> Self {
        unsafe {
            // SSE lacks an unsigned byte comparison, so flip the sign bit of
            // both operands; this maps unsigned order onto signed order and
            // lets the signed PCMPGTB compute an unsigned greater-than.
            let sign_bit = _mm_set1_epi8(-128i8);
            let a_0_flipped = _mm_xor_si128(self.0.0, sign_bit);
            let b_0_flipped = _mm_xor_si128(other.0.0, sign_bit);
            let a_1_flipped = _mm_xor_si128(self.0.1, sign_bit);
            let b_1_flipped = _mm_xor_si128(other.0.1, sign_bit);

            Self((
                _mm_cmpgt_epi8(a_0_flipped, b_0_flipped),
                _mm_cmpgt_epi8(a_1_flipped, b_1_flipped),
            ))
        }
    }

    #[inline(always)]
    unsafe fn lt_u8(self, other: Self) -> Self {
        unsafe {
            // Same sign-flip trick as `gt_u8`, paired with the signed less-than.
            let sign_bit = _mm_set1_epi8(-128i8);
            let a_0_flipped = _mm_xor_si128(self.0.0, sign_bit);
            let b_0_flipped = _mm_xor_si128(other.0.0, sign_bit);
            let a_1_flipped = _mm_xor_si128(self.0.1, sign_bit);
            let b_1_flipped = _mm_xor_si128(other.0.1, sign_bit);

            Self((
                _mm_cmplt_epi8(a_0_flipped, b_0_flipped),
                _mm_cmplt_epi8(a_1_flipped, b_1_flipped),
            ))
        }
    }

    #[inline(always)]
    unsafe fn max_u16(self, other: Self) -> Self {
        unsafe {
            Self((
                _mm_max_epu16(self.0.0, other.0.0),
                _mm_max_epu16(self.0.1, other.0.1),
            ))
        }
    }

    #[inline(always)]
    unsafe fn smax_u16(self) -> u16 {
        unsafe {
            // PHMINPOSUW finds minimum, so we invert to find maximum
            let all_ones = _mm_set1_epi16(-1); // 0xFFFF
            let inverted = (
                _mm_xor_si128(self.0.0, all_ones),
                _mm_xor_si128(self.0.1, all_ones),
            ); // ~v

            // Find the minimum of the inverted values (= maximum of the originals).
            // PHMINPOSUW leaves the minimum in lane 0, so a lane-wise min of the
            // two results keeps the smaller of the two minima in lane 0.
            let min_pos = (_mm_minpos_epu16(inverted.0), _mm_minpos_epu16(inverted.1));
            let min_pos = _mm_min_epu16(min_pos.0, min_pos.1);

            // Extract and invert back
            let min_val = _mm_extract_epi16(min_pos, 0) as u16;
            !min_val // Invert to get original max
        }
    }

    #[inline(always)]
    unsafe fn add_u16(self, other: Self) -> Self {
        unsafe {
            Self((
                _mm_add_epi16(self.0.0, other.0.0),
                _mm_add_epi16(self.0.1, other.0.1),
            ))
        }
    }

    #[inline(always)]
    unsafe fn subs_u16(self, other: Self) -> Self {
        unsafe {
            Self((
                _mm_subs_epu16(self.0.0, other.0.0),
                _mm_subs_epu16(self.0.1, other.0.1),
            ))
        }
    }

    #[inline(always)]
    unsafe fn and(self, other: Self) -> Self {
        unsafe {
            Self((
                _mm_and_si128(self.0.0, other.0.0),
                _mm_and_si128(self.0.1, other.0.1),
            ))
        }
    }

    #[inline(always)]
    unsafe fn or(self, other: Self) -> Self {
        unsafe {
            Self((
                _mm_or_si128(self.0.0, other.0.0),
                _mm_or_si128(self.0.1, other.0.1),
            ))
        }
    }

    #[inline(always)]
    unsafe fn not(self) -> Self {
        unsafe {
            Self((
                _mm_xor_si128(self.0.0, _mm_set1_epi32(-1)),
                _mm_xor_si128(self.0.1, _mm_set1_epi32(-1)),
            ))
        }
    }

    #[inline(always)]
    unsafe fn shift_right_padded_u16<const L: i32>(self, other: Self) -> Self {
        unsafe {
            const { assert!(L >= 0 && L <= 8) };

            // `_mm_alignr_epi8::<N>(a, b)` shifts the byte concatenation a:b
            // right by N bytes and keeps the low 128 bits. With N = 16 - L * 2,
            // every u16 lane moves up by L positions: the low half is padded
            // from `other`'s high half, the high half from our own low half.
            macro_rules! impl_shift {
                ($l:expr) => {
                    Self((
                        _mm_alignr_epi8::<{ 16 - $l * 2 }>(self.0.0, other.0.1),
                        _mm_alignr_epi8::<{ 16 - $l * 2 }>(self.0.1, self.0.0),
                    ))
                };
            }

            match L {
                0 => self,
                1 => impl_shift!(1),
                2 => impl_shift!(2),
                3 => impl_shift!(3),
                4 => impl_shift!(4),
                5 => impl_shift!(5),
                6 => impl_shift!(6),
                7 => impl_shift!(7),
                8 => Self((other.0.1, self.0.0)),
                _ => unreachable!(),
            }
        }
    }

    #[cfg(test)]
    fn from_array(arr: [u8; 16]) -> Self {
        // Mirror the 16-byte payload into both halves so tests of the 128-bit
        // trait surface exercise both internal vectors.
        Self((
            unsafe { _mm_loadu_si128(arr.as_ptr() as *const __m128i) },
            unsafe { _mm_loadu_si128(arr.as_ptr() as *const __m128i) },
        ))
    }

    #[cfg(test)]
    fn to_array(self) -> [u8; 16] {
        // Read back only the low half.
        let mut arr = [0u8; 16];
        unsafe { _mm_storeu_si128(arr.as_mut_ptr() as *mut __m128i, self.0.0) };
        arr
    }

    #[cfg(test)]
    fn from_array_u16(arr: [u16; 8]) -> Self {
        // Mirror the eight u16 lanes into both halves, as in `from_array`.
        Self((
            unsafe { _mm_loadu_si128(arr.as_ptr() as *const __m128i) },
            unsafe { _mm_loadu_si128(arr.as_ptr() as *const __m128i) },
        ))
    }

    #[cfg(test)]
    fn to_array_u16(self) -> [u16; 8] {
        // Read back only the low half.
        let mut arr = [0u16; 8];
        unsafe { _mm_storeu_si128(arr.as_mut_ptr() as *mut __m128i, self.0.0) };
        arr
    }
}
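
// The sign-flip trick in `gt_u8`/`lt_u8` rests on the identity
// (a > b) as unsigned  <=>  ((a ^ 0x80) > (b ^ 0x80)) as signed:
// XOR-ing the sign bit maps the unsigned range 0..=255 monotonically onto
// the signed range -128..=127. A minimal scalar sketch of that identity
// (illustrative test module, not part of the crate):
#[cfg(test)]
mod sign_flip_identity_sketch {
    #[test]
    fn unsigned_gt_matches_sign_flipped_signed_gt() {
        for a in 0u8..=255 {
            for b in 0u8..=255 {
                let flipped = ((a ^ 0x80) as i8) > ((b ^ 0x80) as i8);
                assert_eq!(a > b, flipped);
            }
        }
    }
}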

impl super::Vector256 for SSE256Vector {
    #[cfg(test)]
    fn from_array_256_u16(arr: [u16; 16]) -> Self {
        // Load all 16 lanes: the low eight into the first half, the high
        // eight into the second.
        Self((
            unsafe { _mm_loadu_si128(arr.as_ptr() as *const __m128i) },
            unsafe { _mm_loadu_si128(arr.as_ptr().add(8) as *const __m128i) },
        ))
    }

    #[cfg(test)]
    fn to_array_256_u16(self) -> [u16; 16] {
        // Store both halves back into a single 16-lane array.
        let mut arr = [0u16; 16];
        unsafe { _mm_storeu_si128(arr.as_mut_ptr() as *mut __m128i, self.0.0) };
        unsafe { _mm_storeu_si128(arr.as_mut_ptr().add(8) as *mut __m128i, self.0.1) };
        arr
    }

    #[inline(always)]
    unsafe fn load_unaligned(data: [u8; 32]) -> Self {
        Self((
            unsafe { _mm_loadu_si128(data.as_ptr() as *const __m128i) },
            unsafe { _mm_loadu_si128(data.as_ptr().add(16) as *const __m128i) },
        ))
    }

    #[inline(always)]
    unsafe fn idx_u16(self, search: u16) -> usize {
        unsafe {
            // compare all elements against `search`, get byte masks
            let (cmp_low, cmp_high) = (
                _mm_cmpeq_epi16(self.0.0, _mm_set1_epi16(search as i16)),
                _mm_cmpeq_epi16(self.0.1, _mm_set1_epi16(search as i16)),
            );
            let (mask_low, mask_high) = (
                _mm_movemask_epi8(cmp_low) as u32,
                _mm_movemask_epi8(cmp_high) as u32,
            );

            // find the first set bit in each mask, dividing by 2 to turn a
            // byte index into a u16 element index; a half with no match has
            // trailing_zeros() == 32, so an absent `search` yields index 16
            let low_trailing = mask_low.trailing_zeros() as usize / 2;
            let high_trailing = mask_high.trailing_zeros() as usize / 2 + 8;
            low_trailing.min(high_trailing)
        }
    }
}
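
// `smax_u16` relies on the complement identity max(v) == !min(!v): bitwise
// NOT reverses unsigned order (!x == u16::MAX - x), so the minimum of the
// inverted lanes is the inverted maximum of the original lanes. A scalar
// sketch of that identity (illustrative test module, not part of the crate):
#[cfg(test)]
mod horizontal_max_identity_sketch {
    #[test]
    fn max_equals_not_min_of_not() {
        let v: [u16; 16] = [3, 9, 1, 60000, 0, 7, 42, 8, 5, 6, 2, 11, 13, 17, 19, 23];
        let min_of_inverted = v.iter().map(|&x| !x).min().unwrap();
        assert_eq!(!min_of_inverted, *v.iter().max().unwrap());
    }
}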