use super::{
arch,
bitmask::{BitMask, FromInt},
constant::{Const, SupportedLaneCount},
};
/// Associates a lane count with an array type holding elements of type `T`.
///
/// The implementation for `Const<N>` in this file selects `[T; N]`.
pub trait ArrayType<T>: SupportedLaneCount {
/// Array type selected for this lane count and element type `T`.
type Type;
}
/// Every supported lane count `N` maps element type `T` to the fixed-size array `[T; N]`.
impl<T, const N: usize> ArrayType<T> for Const<N>
where
Const<N>: SupportedLaneCount,
{
type Type = [T; N];
}
/// Associates a lane count and an architecture type `A` with a bit-mask type.
///
/// The implementation for `Const<N>` in this file selects `BitMask<N, A>`.
pub trait BitMaskType<A: arch::Sealed>: SupportedLaneCount {
/// Bit-mask type selected for this lane count on architecture `A`.
type Type; }
/// Every supported lane count `N` maps architecture `A` to `BitMask<N, A>`.
impl<A, const N: usize> BitMaskType<A> for Const<N>
where
Const<N>: SupportedLaneCount,
A: arch::Sealed,
{
type Type = BitMask<N, A>;
}
/// Conversion of a `Copy` value into the SIMD vector type `T`.
///
/// The conversion needs a value of the target vector's architecture type (`T::Arch`).
pub trait AsSIMD<T>: Copy
where
T: SIMDVector,
{
/// Converts `self` into a `T`, using `arch` as the architecture value for the result.
fn as_simd(self, arch: T::Arch) -> T;
}
/// Common interface for SIMD lane masks.
///
/// A mask is tied to an architecture type (`Arch`) and is convertible to and from a bit-mask
/// form (`BitMask`). All provided methods except `get` and `none` are implemented by converting
/// to that bit-mask form; implementors may override them.
pub trait SIMDMask: Copy + std::fmt::Debug {
    /// Architecture type this mask is associated with.
    type Arch: arch::Sealed;

    /// Raw representation exchanged through `to_underlying` / `from_underlying`.
    type Underlying: Copy + std::fmt::Debug;

    /// Bit-mask form of this mask: same architecture, convertible in both directions.
    type BitMask: SIMDMask<Arch = Self::Arch> + Into<Self> + From<Self>;

    /// Number of lanes; `get` returns `None` for indices at or beyond this value.
    const LANES: usize;

    /// NOTE(review): presumably `true` when the mask is itself stored as a bit mask (the
    /// `BitMask` implementation in this file sets it to `true`) — confirm for other
    /// implementors.
    const ISBITS: bool;

    /// Returns the architecture value associated with this mask.
    fn arch(self) -> Self::Arch;

    /// Returns the raw representation of this mask.
    fn to_underlying(self) -> Self::Underlying;

    /// Builds a mask from its raw representation.
    fn from_underlying(arch: Self::Arch, value: Self::Underlying) -> Self;

    /// Returns lane `i` without the range check performed by `get`.
    ///
    /// What happens for `i >= LANES` is up to the implementor (the `BitMask` implementation in
    /// this file returns `false`).
    fn get_unchecked(&self, i: usize) -> bool;

    /// Builds a mask from a lane count `i`.
    ///
    /// NOTE(review): the `BitMask` implementation in this file sets the lowest
    /// `min(i, LANES)` lanes; other implementors presumably do the same — confirm.
    fn keep_first(arch: Self::Arch, i: usize) -> Self;

    /// Returns whatever `first` reports for the bit-mask form of this mask.
    fn first(&self) -> Option<usize> {
        let bits = Self::bitmask(*self);
        bits.first()
    }

    /// Converts this mask into its bit-mask form.
    fn bitmask(self) -> Self::BitMask {
        From::from(self)
    }

    /// Range-checked lane access: `None` when `i >= LANES`, otherwise the value of
    /// `get_unchecked(i)`.
    fn get(&self, i: usize) -> Option<bool> {
        (i < Self::LANES).then(|| self.get_unchecked(i))
    }

    /// Builds a mask by calling the bit-mask form's `from_fn` with `f` and converting the
    /// result back.
    #[inline(always)]
    fn from_fn<F>(arch: Self::Arch, f: F) -> Self
    where
        F: FnMut(usize) -> bool,
    {
        let bits = <Self::BitMask as SIMDMask>::from_fn(arch, f);
        bits.into()
    }

    /// Returns `any()` of the bit-mask form.
    #[inline(always)]
    fn any(self) -> bool {
        let bits: Self::BitMask = From::from(self);
        bits.any()
    }

    /// Returns `all()` of the bit-mask form.
    #[inline(always)]
    fn all(self) -> bool {
        let bits: Self::BitMask = From::from(self);
        bits.all()
    }

    /// Logical negation of `any()`.
    #[inline(always)]
    fn none(self) -> bool {
        !Self::any(self)
    }

    /// Returns `count()` of the bit-mask form.
    #[inline(always)]
    fn count(self) -> usize {
        let bits: Self::BitMask = From::from(self);
        bits.count()
    }
}
/// Core interface of a SIMD vector type.
///
/// A vector is bound to an architecture type (`Arch`), has a lane type (`Scalar`), a raw
/// representation (`Underlying`), a lane count (`LANES` / `ConstLanes`) and an associated mask
/// type (`Mask`).
pub trait SIMDVector: Copy + std::fmt::Debug {
    /// Architecture type this vector is associated with.
    type Arch: arch::Sealed;

    /// Type of a single lane.
    type Scalar: Copy + std::fmt::Debug;

    /// Raw representation exchanged through `to_underlying` / `from_underlying`.
    type Underlying: Copy;

    /// Type-level lane count. It selects the array type used by `to_array` / `from_array` and
    /// the bit-mask type accepted by `load_simd_masked` / `store_simd_masked`.
    type ConstLanes: ArrayType<Self::Scalar> + BitMaskType<Self::Arch>;

    /// Mask type accepted by the `*_masked_logical` methods. It is convertible to and from the
    /// bit-mask type selected by `ConstLanes`.
    type Mask: SIMDMask<Arch = Self::Arch>
        + From<<Self::ConstLanes as BitMaskType<Self::Arch>>::Type>
        + Into<<Self::ConstLanes as BitMaskType<Self::Arch>>::Type>;

    /// Number of lanes.
    const LANES: usize;

    /// NOTE(review): presumably `true` when the implementation is emulated rather than backed
    /// by native SIMD instructions — confirm against implementors.
    const EMULATED: bool;

    /// Returns the architecture value associated with this vector.
    fn arch(self) -> Self::Arch;

    /// Builds a default vector for `arch`.
    ///
    /// NOTE(review): presumably all lanes are zero / `Scalar::default()` — confirm.
    fn default(arch: Self::Arch) -> Self;

    /// Returns the raw representation of this vector.
    fn to_underlying(self) -> Self::Underlying;

    /// Builds a vector from its raw representation.
    fn from_underlying(arch: Self::Arch, repr: Self::Underlying) -> Self;

    /// Converts the vector into the array type selected by `ConstLanes`.
    fn to_array(self) -> <Self::ConstLanes as ArrayType<Self::Scalar>>::Type;

    /// Builds a vector from the array type selected by `ConstLanes`.
    fn from_array(
        arch: Self::Arch,
        x: <Self::ConstLanes as ArrayType<Self::Scalar>>::Type,
    ) -> Self;

    /// Builds a vector from a single scalar.
    ///
    /// NOTE(review): by name, every lane is presumably set to `value` — confirm.
    fn splat(arch: Self::Arch, value: Self::Scalar) -> Self;

    /// Returns `Self::LANES`.
    fn num_lanes() -> usize {
        Self::LANES
    }

    /// Loads a vector from `ptr`.
    ///
    /// # Safety
    ///
    /// The requirements are defined by implementors and are not visible in this file.
    /// NOTE(review): presumably `ptr` must be valid for reading `LANES` scalars — confirm.
    unsafe fn load_simd(arch: Self::Arch, ptr: *const Self::Scalar) -> Self;

    /// Loads a vector from `ptr` under control of `mask`.
    ///
    /// # Safety
    ///
    /// The requirements are defined by implementors and are not visible in this file.
    /// NOTE(review): presumably only lanes selected by `mask` are read — confirm.
    unsafe fn load_simd_masked_logical(
        arch: Self::Arch,
        ptr: *const Self::Scalar,
        mask: Self::Mask,
    ) -> Self;

    /// Same as `load_simd_masked_logical`, but takes the bit-mask type selected by
    /// `ConstLanes` and converts it into `Self::Mask` first.
    ///
    /// # Safety
    ///
    /// Same requirements as `load_simd_masked_logical` with the converted mask.
    #[inline(always)]
    unsafe fn load_simd_masked(
        arch: Self::Arch,
        ptr: *const Self::Scalar,
        mask: <Self::ConstLanes as BitMaskType<Self::Arch>>::Type,
    ) -> Self {
        let logical: Self::Mask = mask.into();
        // SAFETY: arguments are forwarded unchanged (apart from the mask conversion); the
        // caller of this `unsafe fn` is responsible for the callee's requirements.
        unsafe { Self::load_simd_masked_logical(arch, ptr, logical) }
    }

    /// Calls `load_simd_masked_logical` with the mask produced by
    /// `Self::Mask::keep_first(arch, first)`.
    ///
    /// # Safety
    ///
    /// Same requirements as `load_simd_masked_logical` with that mask.
    #[inline(always)]
    unsafe fn load_simd_first(arch: Self::Arch, ptr: *const Self::Scalar, first: usize) -> Self {
        let prefix = Self::Mask::keep_first(arch, first);
        // SAFETY: the caller of this `unsafe fn` is responsible for the callee's requirements
        // with respect to `ptr` and the generated mask.
        unsafe { Self::load_simd_masked_logical(arch, ptr, prefix) }
    }

    /// Stores the vector to `ptr`.
    ///
    /// # Safety
    ///
    /// The requirements are defined by implementors and are not visible in this file.
    /// NOTE(review): presumably `ptr` must be valid for writing `LANES` scalars — confirm.
    unsafe fn store_simd(self, ptr: *mut Self::Scalar);

    /// Stores the vector to `ptr` under control of `mask`.
    ///
    /// # Safety
    ///
    /// The requirements are defined by implementors and are not visible in this file.
    /// NOTE(review): presumably only lanes selected by `mask` are written — confirm.
    unsafe fn store_simd_masked_logical(self, ptr: *mut Self::Scalar, mask: Self::Mask);

    /// Same as `store_simd_masked_logical`, but takes the bit-mask type selected by
    /// `ConstLanes` and converts it into `Self::Mask` first.
    ///
    /// # Safety
    ///
    /// Same requirements as `store_simd_masked_logical` with the converted mask.
    #[inline(always)]
    unsafe fn store_simd_masked(
        self,
        ptr: *mut Self::Scalar,
        mask: <Self::ConstLanes as BitMaskType<Self::Arch>>::Type,
    ) {
        let logical: Self::Mask = mask.into();
        // SAFETY: arguments are forwarded unchanged (apart from the mask conversion); the
        // caller of this `unsafe fn` is responsible for the callee's requirements.
        unsafe { self.store_simd_masked_logical(ptr, logical) }
    }

    /// Calls `store_simd_masked_logical` with the mask produced by
    /// `Self::Mask::keep_first(self.arch(), first)`.
    ///
    /// # Safety
    ///
    /// Same requirements as `store_simd_masked_logical` with that mask.
    #[inline(always)]
    unsafe fn store_simd_first(self, ptr: *mut Self::Scalar, first: usize) {
        let prefix = Self::Mask::keep_first(self.arch(), first);
        // SAFETY: the caller of this `unsafe fn` is responsible for the callee's requirements
        // with respect to `ptr` and the generated mask.
        unsafe { self.store_simd_masked_logical(ptr, prefix) }
    }

    /// Converts the vector via its `SIMDCast<T>` implementation; shorthand for `simd_cast()`
    /// that lets the target scalar be named with a turbofish.
    #[inline(always)]
    fn cast<T>(self) -> <Self as SIMDCast<T>>::Cast
    where
        Self: SIMDCast<T>,
    {
        <Self as SIMDCast<T>>::simd_cast(self)
    }
}
/// Multiply-add over SIMD values.
pub trait SIMDMulAdd {
/// Combines `self`, `rhs` and `accumulator` into a single result.
///
/// NOTE(review): by naming, this presumably computes `self * rhs + accumulator` per lane;
/// whether the operation is fused (single rounding) is not visible here — confirm against
/// implementors.
fn mul_add_simd(self, rhs: Self, accumulator: Self) -> Self;
}
/// Minimum / maximum operations on SIMD values.
///
/// Implementors provide `min_simd` and `max_simd`. The `*_standard` variants default to those
/// and may be overridden.
///
/// NOTE(review): the `_standard` suffix presumably refers to a stricter (e.g. NaN-aware)
/// definition of min/max — the distinction is not visible in this file; confirm against
/// implementors.
pub trait SIMDMinMax: Sized {
    /// Minimum of `self` and `rhs`.
    fn min_simd(self, rhs: Self) -> Self;

    /// Maximum of `self` and `rhs`.
    fn max_simd(self, rhs: Self) -> Self;

    /// Defaults to `min_simd`.
    #[inline(always)]
    fn min_simd_standard(self, rhs: Self) -> Self {
        Self::min_simd(self, rhs)
    }

    /// Defaults to `max_simd`.
    #[inline(always)]
    fn max_simd_standard(self, rhs: Self) -> Self {
        Self::max_simd(self, rhs)
    }
}
/// Absolute-value operation on SIMD values.
pub trait SIMDAbs {
/// NOTE(review): by naming, presumably the per-lane absolute value; behaviour for the most
/// negative integer is not visible here — confirm against implementors.
fn abs_simd(self) -> Self;
}
/// Equality comparisons between two vectors, producing the vector's mask type.
pub trait SIMDPartialEq: SIMDVector {
/// Compares `self` with `other` for equality; the result is a `Self::Mask`
/// (presumably one lane per compared element — confirm against implementors).
fn eq_simd(self, other: Self) -> Self::Mask;
/// Compares `self` with `other` for inequality; the result is a `Self::Mask`.
fn ne_simd(self, other: Self) -> Self::Mask;
}
/// Ordering comparisons between two vectors, producing the vector's mask type.
///
/// Only `lt_simd` and `le_simd` are required; the "greater" variants default to the same
/// comparisons with the operands swapped.
pub trait SIMDPartialOrd: SIMDVector {
    /// `self < other`, as a mask.
    fn lt_simd(self, other: Self) -> Self::Mask;

    /// `self <= other`, as a mask.
    fn le_simd(self, other: Self) -> Self::Mask;

    /// `self > other`, evaluated as `other < self`.
    #[inline(always)]
    fn gt_simd(self, other: Self) -> Self::Mask {
        Self::lt_simd(other, self)
    }

    /// `self >= other`, evaluated as `other <= self`.
    #[inline(always)]
    fn ge_simd(self, other: Self) -> Self::Mask {
        Self::le_simd(other, self)
    }
}
/// Reduction of a vector to a single scalar.
pub trait SIMDSumTree: SIMDVector {
/// Reduces the vector to one `Scalar`.
///
/// NOTE(review): the name suggests a sum of all lanes evaluated in a tree (pairwise) order;
/// the actual summation order is not visible here — confirm against implementors.
fn sum_tree(self) -> <Self as SIMDVector>::Scalar;
}
/// Mask-driven choice between two vectors of type `V`.
pub trait SIMDSelect<V: SIMDVector>: SIMDMask {
/// Combines `x` and `y` into one vector according to this mask.
///
/// NOTE(review): presumably lanes where the mask is set come from `x` and the rest from
/// `y` — the polarity is not visible here; confirm against implementors.
fn select(self, x: V, y: V) -> V;
}
/// Dot-product style operation taking a left vector `L` and a right vector `R`.
///
/// `R` defaults to `L` when not specified.
pub trait SIMDDotProduct<L: SIMDVector, R: SIMDVector = L> {
/// Combines `self` with `left` and `right`, returning a value of the same type as `self`.
///
/// NOTE(review): presumably `self` acts as an accumulator to which products of `left` and
/// `right` lanes are added — confirm against implementors.
fn dot_simd(self, left: L, right: R) -> Self;
}
/// Conversion of one vector type into another vector type `To`.
pub trait SIMDReinterpret<To: SIMDVector>: SIMDVector {
/// Converts `self` into a `To`.
///
/// NOTE(review): by naming, presumably a bit-level reinterpretation rather than a per-lane
/// value conversion — confirm against implementors.
fn reinterpret_simd(self) -> To;
}
/// Conversion of a vector to a vector with scalar type `T`.
pub trait SIMDCast<T>: SIMDVector {
/// Result type: a vector whose `Scalar` is `T` and whose `ConstLanes` equals that of `Self`.
type Cast: SIMDVector<Scalar = T, ConstLanes = Self::ConstLanes>;
/// Converts `self` into `Self::Cast`.
fn simd_cast(self) -> Self::Cast;
}
/// Bundle of the bounds listed below; it declares no items of its own.
///
/// A blanket implementation in this file covers every type that satisfies all of the
/// supertraits, so this trait never needs to be implemented by hand.
///
/// NOTE(review): named for floating-point vectors, but nothing in the bounds restricts the
/// scalar type.
pub trait SIMDFloat:
SIMDVector
+ std::ops::Add<Output = Self>
+ std::ops::Mul<Output = Self>
+ std::ops::Sub<Output = Self>
+ SIMDMulAdd
+ SIMDMinMax
+ SIMDPartialEq
+ SIMDPartialOrd
{
}
/// Blanket implementation: any `T` meeting every `SIMDFloat` supertrait bound is a `SIMDFloat`.
impl<T> SIMDFloat for T where
T: SIMDVector
+ std::ops::Add<Output = Self>
+ std::ops::Mul<Output = Self>
+ std::ops::Sub<Output = Self>
+ SIMDMulAdd
+ SIMDMinMax
+ SIMDPartialEq
+ SIMDPartialOrd
{
}
/// Bundle of the bounds listed below; it declares no items of its own.
///
/// Beyond arithmetic, it requires bitwise operators and shifts, both by another vector and by
/// a single `Scalar`. A blanket implementation in this file covers every type that satisfies
/// all of the supertraits.
///
/// NOTE(review): named for unsigned-integer vectors, but nothing in the bounds restricts the
/// scalar type (`SIMDSigned` in this file builds on it).
pub trait SIMDUnsigned:
SIMDVector
+ std::ops::Add<Output = Self>
+ std::ops::Mul<Output = Self>
+ std::ops::Sub<Output = Self>
+ std::ops::BitAnd<Output = Self>
+ std::ops::BitOr<Output = Self>
+ std::ops::BitXor<Output = Self>
+ std::ops::Shr<Output = Self>
+ std::ops::Shl<Output = Self>
+ std::ops::Shr<Self::Scalar, Output = Self>
+ std::ops::Shl<Self::Scalar, Output = Self>
+ SIMDMulAdd
+ SIMDPartialEq
+ SIMDPartialOrd
{
}
/// Blanket implementation: any `T` meeting every `SIMDUnsigned` supertrait bound is a
/// `SIMDUnsigned`.
impl<T> SIMDUnsigned for T where
T: SIMDVector
+ std::ops::Add<Output = Self>
+ std::ops::Mul<Output = Self>
+ std::ops::Sub<Output = Self>
+ std::ops::BitAnd<Output = Self>
+ std::ops::BitOr<Output = Self>
+ std::ops::BitXor<Output = Self>
+ std::ops::Shr<Output = Self>
+ std::ops::Shl<Output = Self>
+ std::ops::Shr<Self::Scalar, Output = Self>
+ std::ops::Shl<Self::Scalar, Output = Self>
+ SIMDMulAdd
+ SIMDPartialEq
+ SIMDPartialOrd
{
}
/// Bundle trait: everything required by `SIMDUnsigned` plus `SIMDAbs`. Declares no items of
/// its own; a blanket implementation in this file covers every qualifying type.
pub trait SIMDSigned: SIMDUnsigned + SIMDAbs {}
/// Blanket implementation: any `T` that is `SIMDUnsigned + SIMDAbs` is a `SIMDSigned`.
impl<T> SIMDSigned for T where T: SIMDUnsigned + SIMDAbs {}
/// Implements `SIMDMask` for `BitMask<$N, A>` (for every architecture `A`) and adds a
/// `From<BitMask<$N>>` conversion into the underlying integer.
///
/// Arguments:
/// * `$N` — number of lanes.
/// * `$repr` — unsigned integer type that stores the bits (`self.0`).
/// * `$submask` — the `$repr` value with exactly the low `$N` bits set, i.e. the "all lanes
///   set" pattern.
///
/// The implementation relies on the stored integer never having bits at or above `$N` set.
/// `from_underlying` goes through `from_int`, which the `test_zeroing` test in this file shows
/// clearing such bits.
macro_rules! impl_simd_mask_for_bitmask {
    ($N:literal, $repr:ty, $submask:expr) => {
        impl<A: arch::Sealed> SIMDMask for BitMask<$N, A> {
            type Arch = A;
            type Underlying = $repr;
            type BitMask = Self;
            const ISBITS: bool = true;
            const LANES: usize = $N;

            #[inline(always)]
            fn arch(self) -> A {
                self.get_arch()
            }

            #[inline(always)]
            fn to_underlying(self) -> Self::Underlying {
                self.0
            }

            #[inline(always)]
            fn from_underlying(arch: A, value: Self::Underlying) -> Self {
                Self::from_int(arch, value)
            }

            // Sets the lowest `min(i, LANES)` lanes.
            #[inline(always)]
            fn keep_first(arch: A, i: usize) -> Self {
                // Saturated case: every lane is kept. Handling it up front also keeps the
                // shift below strictly under 64 bits, so no special case for 64 lanes is needed.
                if i >= Self::LANES {
                    return Self::from_underlying(arch, $submask);
                }
                let one: u64 = 1;
                Self::from_underlying(arch, ((one << i) - one) as Self::Underlying)
            }

            // Out-of-range indices report `false` instead of shifting past the integer width.
            #[inline(always)]
            fn get_unchecked(&self, i: usize) -> bool {
                if i >= Self::LANES {
                    false
                } else {
                    (self.0 >> i) & 1 == 1
                }
            }

            // Index of the lowest set lane, if any. An empty mask yields a trailing-zero count
            // equal to the integer width, which is always >= LANES.
            #[inline(always)]
            fn first(&self) -> Option<usize> {
                let count = self.0.trailing_zeros() as usize;
                if count >= Self::LANES {
                    None
                } else {
                    Some(count)
                }
            }

            // Calls `f` once per lane, in increasing lane order, and sets the lanes for which
            // it returns `true`.
            #[inline(always)]
            fn from_fn<F>(arch: A, mut f: F) -> Self
            where
                F: FnMut(usize) -> bool,
            {
                let one: $repr = 1;
                let mut bits: $repr = 0;
                for lane in 0..Self::LANES {
                    if f(lane) {
                        bits |= one << lane;
                    }
                }
                Self::from_underlying(arch, bits)
            }

            #[inline(always)]
            fn any(self) -> bool {
                self.0 != 0
            }

            // All lanes are set exactly when the stored bits equal the all-lanes pattern.
            #[inline(always)]
            fn all(self) -> bool {
                self.0 == $submask
            }

            #[inline(always)]
            fn count(self) -> usize {
                self.0.count_ones() as usize
            }
        }

        impl From<BitMask<$N>> for $repr {
            fn from(value: BitMask<$N>) -> Self {
                value.to_underlying()
            }
        }
    };
}
// One `SIMDMask` implementation per supported bit-mask width. Each invocation passes the lane
// count, the integer type that stores the bits, and the value of that integer type with
// exactly the low `N` bits set.
impl_simd_mask_for_bitmask!(1, u8, 0x1);
impl_simd_mask_for_bitmask!(2, u8, 0x3);
impl_simd_mask_for_bitmask!(4, u8, 0xf);
impl_simd_mask_for_bitmask!(8, u8, u8::MAX);
impl_simd_mask_for_bitmask!(16, u16, u16::MAX);
impl_simd_mask_for_bitmask!(32, u32, u32::MAX);
impl_simd_mask_for_bitmask!(64, u64, u64::MAX);
// Tests for the `SIMDMask` implementation of `BitMask`.
#[cfg(test)]
mod test_traits {
use rand::{
SeedableRng,
distr::{Distribution, StandardUniform},
rngs::StdRng,
};
use super::*;
use crate::{
ARCH, arch,
splitjoin::{LoHi, SplitJoin},
test_utils,
};
// Truncating conversion from `u128`, so expected bit patterns can be computed in `u128` and
// then narrowed to whichever integer type backs the mask under test.
trait FromU128 {
fn from_(value: u128) -> Self;
}
impl FromU128 for u8 {
fn from_(value: u128) -> Self {
value as u8
}
}
impl FromU128 for u16 {
fn from_(value: u128) -> Self {
value as u16
}
}
impl FromU128 for u32 {
fn from_(value: u128) -> Self {
value as u32
}
}
impl FromU128 for u64 {
fn from_(value: u128) -> Self {
value as u64
}
}
// For every prefix length `i` in `0..=64`, builds `keep_first(i)` and checks:
// * the underlying integer (and the `From` conversion) equals the low `min(i, N)` bits,
// * `get_unchecked` / `get` agree with `j < i` for in-range lanes and report `false` / `None`
//   for lanes at or beyond `N`,
// * `any` / `all` / `none` match the empty, full and partial cases.
fn test_bitmask_impl<const N: usize, T>()
where
Const<N>: SupportedLaneCount, T: std::fmt::Debug + std::cmp::Eq + FromU128 + From<BitMask<N, arch::Current>>,
BitMask<N, arch::Current>: SIMDMask<Arch = arch::Current, Underlying = T>,
{
// Largest prefix length / lane index probed; equals the widest supported mask.
const MAXLEN: usize = 64;
assert_eq!(N, BitMask::<N, arch::Current>::LANES);
let one = 1_u128;
// All `N` lanes set, computed in `u128` so that `N == 64` does not overflow.
let all: u128 = (one << N) - one;
for i in 0..=MAXLEN {
let mask = BitMask::<N, arch::Current>::keep_first(arch::current(), i);
let expected: u128 = ((one << i) - one) & all;
assert_eq!(mask.to_underlying(), T::from_(expected));
assert_eq!(T::from_(expected), mask.into());
for j in 0..=MAXLEN {
let b = mask.get_unchecked(j);
let o = mask.get(j);
let expected: bool = j < i;
if j < N {
assert_eq!(b, expected);
assert_eq!(o.unwrap(), expected);
} else {
// Out-of-range lanes: unchecked access reports `false`, checked access `None`.
assert!(!b);
assert!(o.is_none());
}
}
if i == 0 {
assert!(!mask.any());
assert!(!mask.all());
assert!(mask.none());
} else if i >= N {
assert!(mask.any());
assert!(mask.all());
assert!(!mask.none());
} else {
assert!(mask.any());
assert!(!mask.all());
assert!(!mask.none());
}
}
}
#[test]
fn test_bitmask() {
test_bitmask_impl::<1, u8>();
test_bitmask_impl::<2, u8>();
test_bitmask_impl::<4, u8>();
test_bitmask_impl::<8, u8>();
test_bitmask_impl::<16, u16>();
test_bitmask_impl::<32, u32>();
test_bitmask_impl::<64, u64>();
}
// Builds `ntrials` random `N`-lane masks from a seeded RNG, splits each into halves of `NHALF`
// lanes, checks that `lo` holds lanes `0..NHALF` and `hi` holds lanes `NHALF..N`, and that
// joining the halves reproduces the original mask.
fn test_bitmask_splitjoin_impl<const N: usize, const NHALF: usize>(ntrials: usize, seed: u64)
where
Const<N>: SupportedLaneCount,
Const<NHALF>: SupportedLaneCount,
BitMask<N, arch::Current>:
SIMDMask<Arch = arch::Current> + SplitJoin<Halved = BitMask<NHALF, arch::Current>>,
BitMask<NHALF, arch::Current>: SIMDMask<Arch = arch::Current>,
{
let mut rng = StdRng::seed_from_u64(seed);
for _ in 0..ntrials {
let base = BitMask::<N>::from_fn(ARCH, |_| StandardUniform {}.sample(&mut rng));
let LoHi { lo, hi } = base.split();
for i in 0..NHALF {
assert_eq!(base.get(i).unwrap(), lo.get(i).unwrap());
}
for i in 0..NHALF {
assert_eq!(base.get(i + NHALF).unwrap(), hi.get(i).unwrap());
}
let joined = BitMask::<N>::join(LoHi::new(lo, hi));
bitmasks_equal(base, joined);
}
}
#[test]
fn test_bitmask_splitjoin() {
test_bitmask_splitjoin_impl::<2, 1>(100, 0xcbdbdca310caec88);
test_bitmask_splitjoin_impl::<4, 2>(100, 0x9c8b9b6c70d941c5);
test_bitmask_splitjoin_impl::<8, 4>(100, 0xc81a25918b683d39);
test_bitmask_splitjoin_impl::<16, 8>(50, 0xad045b437c3fa0cc);
test_bitmask_splitjoin_impl::<32, 16>(50, 0xe710ccdbbd329c77);
test_bitmask_splitjoin_impl::<64, 32>(25, 0xd6697e3c534fc134);
}
// Bits above the lane count passed to `from_underlying` must not survive: both the stored
// value and `count()` only reflect the low `N` bits.
#[test]
fn test_zeroing() {
let b = BitMask::<2>::from_underlying(arch::current(), 0xff);
assert_eq!(b.to_underlying(), 0x3);
assert_eq!(b.count(), 2);
let b = BitMask::<4>::from_underlying(arch::current(), 0xff);
assert_eq!(b.to_underlying(), 0xf);
assert_eq!(b.count(), 4);
}
// Asserts that two bit masks store identical raw bits. Also passed as the comparison callback
// to the shared `test_utils::mask` helpers below.
fn bitmasks_equal<const N: usize>(x: BitMask<N, arch::Current>, y: BitMask<N, arch::Current>)
where
Const<N>: SupportedLaneCount,
BitMask<N, arch::Current>: SIMDMask,
{
assert_eq!(x.0, y.0);
}
// Generates a `test_simd_mask_<N>` test that runs the shared `test_utils::mask` checks
// (`test_keep_first`, `test_from_fn`, `test_reductions`, `test_first`) against
// `BitMask<N, arch::Current>`.
macro_rules! test_simdmask {
($N:literal) => {
paste::paste! {
#[test]
fn [<test_simd_mask_ $N>]() {
let arch = arch::current();
test_utils::mask::test_keep_first::<BitMask<$N, arch::Current>, $N, _, _>(
arch,
bitmasks_equal
);
test_utils::mask::test_from_fn::<BitMask<$N, arch::Current>, $N, _, _>(
arch,
bitmasks_equal
);
test_utils::mask::test_reductions::<BitMask<$N, arch::Current>, $N, _, _>(
arch,
bitmasks_equal
);
test_utils::mask::test_first::<BitMask<$N, arch::Current>, $N, _, _>(
arch,
bitmasks_equal
);
}
}
};
}
test_simdmask!(2);
test_simdmask!(4);
test_simdmask!(8);
test_simdmask!(16);
test_simdmask!(32);
test_simdmask!(64);
}