use std::simd::cmp::{SimdPartialEq, SimdPartialOrd};
use std::simd::{LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount};
/// Width, in bits, of the SIMD vectors this module is compiled for.
///
/// Selected at compile time from the enabled target features: 512 when
/// `avx512f` is enabled, 256 when `avx2` is enabled without `avx512f`, and
/// 128 in every other configuration.
pub const SIMD_WIDTH_BITS: usize = if cfg!(target_feature = "avx512f") {
    512
} else if cfg!(target_feature = "avx2") {
    256
} else {
    128
};
/// Private home of the `Sealed` marker trait.
///
/// `SortedSimdElement` lists `sealed::Sealed` as a supertrait. Because this
/// module is not `pub`, only code that can name `sealed::Sealed` is able to
/// add new `SortedSimdElement` implementations.
mod sealed {
/// Marker trait with no methods; implemented alongside every
/// `SortedSimdElement` impl generated in this file.
pub trait Sealed {}
}
/// Queries over a SIMD mask, abstracted over the concrete mask type.
///
/// Implementors must be `Copy` and support `|` with themselves, so two masks
/// can be combined by value. This file implements the trait for
/// `std::simd::Mask<T, N>`.
pub trait SimdMaskOps: Copy + std::ops::BitOr<Output = Self> {
/// Returns `true` if every lane of the mask is set.
fn all(self) -> bool;
/// Returns `true` if at least one lane of the mask is set.
fn any(self) -> bool;
/// Returns whether lane number `lane` is set.
///
/// NOTE(review): the `Mask` implementation in this file forwards to
/// `Mask::test`, which panics when `lane` is out of range.
fn test(self, lane: usize) -> bool;
/// Packs the mask into an integer: bit `i` of the result is 1 exactly when
/// lane `i` is set.
fn to_bitmask(self) -> u64;
}
/// `SimdMaskOps` for the portable-SIMD `Mask` type at any supported lane count.
///
/// `all`, `any` and `test` forward to the inherent `Mask` methods of the same
/// name; `to_bitmask` builds its result one lane at a time.
impl<T, const N: usize> SimdMaskOps for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    /// Returns `true` if every lane is set.
    #[inline(always)]
    fn all(self) -> bool {
        // Type-qualified path: selects the inherent `Mask::all`.
        <Mask<T, N>>::all(self)
    }

    /// Returns `true` if at least one lane is set.
    #[inline(always)]
    fn any(self) -> bool {
        // Type-qualified path: selects the inherent `Mask::any`.
        <Mask<T, N>>::any(self)
    }

    /// Returns whether lane `lane` is set, as reported by the inherent
    /// `Mask::test` (which takes the mask by reference).
    #[inline(always)]
    fn test(self, lane: usize) -> bool {
        <Mask<T, N>>::test(&self, lane)
    }

    /// Packs the mask into a `u64` with bit `i` set exactly when lane `i` is set.
    #[inline(always)]
    fn to_bitmask(self) -> u64 {
        // Visit lanes in ascending order and OR in one bit per set lane.
        (0..N)
            .filter(|&lane| <Mask<T, N>>::test(&self, lane))
            .fold(0u64, |bits, lane| bits | (1u64 << lane))
    }
}
/// Scalar element types for which this module provides a SIMD vector type,
/// a matching mask type, and basic vector constructors.
///
/// The trait is sealed: `sealed::Sealed` is a supertrait, so implementations
/// can only be added where that private trait can be named. The impls in this
/// file are generated by `impl_sorted_simd_element!`.
pub trait SortedSimdElement:
sealed::Sealed + SimdElement + Copy + PartialOrd + PartialEq + Default + Sized
{
/// Element type of the mask; the generated impls use it as the element
/// parameter of `SimdMask`.
type MaskElement: MaskElement;
/// Number of lanes; the generated impls use the same value as the lane
/// count of `SimdVec` and `SimdMask`.
const LANES: usize;
/// Vector type for this element. Its lane-wise equality and ordering
/// comparisons must produce `Self::SimdMask`.
type SimdVec: Copy
+ SimdPartialEq<Mask = Self::SimdMask>
+ SimdPartialOrd<Mask = Self::SimdMask>;
/// Mask type produced by comparing two `SimdVec` values.
type SimdMask: SimdMaskOps;
/// Builds a vector from a single scalar `value`.
///
/// The generated impls call `Simd::splat`, which copies `value` into
/// every lane.
fn simd_splat(value: Self) -> Self::SimdVec;
/// Builds a vector from the contents of `slice`.
///
/// NOTE(review): the generated impls call `Simd::from_slice`, which reads
/// the leading lanes' worth of elements and panics if `slice` is shorter
/// than the lane count — callers must guarantee the length.
fn simd_from_slice(slice: &[Self]) -> Self::SimdVec;
}
/// Generates the `sealed::Sealed` and `SortedSimdElement` impls for one
/// scalar type.
///
/// Arguments, in order:
/// 1. the scalar element type,
/// 2. the mask element type used for the associated `SimdMask`,
/// 3. the lane count shared by the vector and the mask.
macro_rules! impl_sorted_simd_element {
    ($elem:ty, $mask_elem:ty, $width:expr) => {
        // Opt the scalar into the sealed set so the trait impl below is allowed.
        impl sealed::Sealed for $elem {}

        impl SortedSimdElement for $elem
        where
            LaneCount<$width>: SupportedLaneCount,
        {
            const LANES: usize = $width;
            type MaskElement = $mask_elem;
            type SimdMask = Mask<$mask_elem, $width>;
            type SimdVec = Simd<$elem, $width>;

            #[inline(always)]
            fn simd_splat(value: Self) -> Self::SimdVec {
                <Simd<$elem, $width>>::splat(value)
            }

            #[inline(always)]
            fn simd_from_slice(slice: &[Self]) -> Self::SimdVec {
                <Simd<$elem, $width>>::from_slice(slice)
            }
        }
    };
}
/// `SortedSimdElement` impls used when `avx512f` is enabled.
///
/// Every lane count below equals 512 divided by the element width in bits.
#[cfg(target_feature = "avx512f")]
mod impls {
    use super::*;

    // 8-bit elements: 64 lanes.
    impl_sorted_simd_element!(i8, i8, 64);
    impl_sorted_simd_element!(u8, i8, 64);

    // 16-bit elements: 32 lanes.
    impl_sorted_simd_element!(i16, i16, 32);
    impl_sorted_simd_element!(u16, i16, 32);

    // 32-bit elements: 16 lanes.
    impl_sorted_simd_element!(i32, i32, 16);
    impl_sorted_simd_element!(u32, i32, 16);

    // 64-bit elements: 8 lanes.
    impl_sorted_simd_element!(i64, i64, 8);
    impl_sorted_simd_element!(u64, i64, 8);
}
/// `SortedSimdElement` impls used when `avx2` is enabled and `avx512f` is not.
///
/// Every lane count below equals 256 divided by the element width in bits.
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
mod impls {
    use super::*;

    // 8-bit elements: 32 lanes.
    impl_sorted_simd_element!(i8, i8, 32);
    impl_sorted_simd_element!(u8, i8, 32);

    // 16-bit elements: 16 lanes.
    impl_sorted_simd_element!(i16, i16, 16);
    impl_sorted_simd_element!(u16, i16, 16);

    // 32-bit elements: 8 lanes.
    impl_sorted_simd_element!(i32, i32, 8);
    impl_sorted_simd_element!(u32, i32, 8);

    // 64-bit elements: 4 lanes.
    impl_sorted_simd_element!(i64, i64, 4);
    impl_sorted_simd_element!(u64, i64, 4);
}
/// `SortedSimdElement` impls used when neither `avx2` nor `avx512f` is enabled.
///
/// Every lane count below equals 128 divided by the element width in bits.
#[cfg(not(any(target_feature = "avx2", target_feature = "avx512f")))]
mod impls {
    use super::*;

    // 8-bit elements: 16 lanes.
    impl_sorted_simd_element!(i8, i8, 16);
    impl_sorted_simd_element!(u8, i8, 16);

    // 16-bit elements: 8 lanes.
    impl_sorted_simd_element!(i16, i16, 8);
    impl_sorted_simd_element!(u16, i16, 8);

    // 32-bit elements: 4 lanes.
    impl_sorted_simd_element!(i32, i32, 4);
    impl_sorted_simd_element!(u32, i32, 4);

    // 64-bit elements: 2 lanes.
    impl_sorted_simd_element!(i64, i64, 2);
    impl_sorted_simd_element!(u64, i64, 2);
}