use crate::{
backend::{Simd, Vector},
seal::Sealed,
Mask, MaskOps,
};
use bytemuck::{NoUninit, Pod};
use half::{bf16, f16};
use paste::paste;
/// Element type that can be packed into a SIMD [`Vector`].
///
/// Every operation is generic over the backend `S: Simd`. The implementations
/// generated by `impl_vectorizable!` forward each method to the matching
/// backend function.
///
/// NOTE(review): `Sealed` is implemented for every scalar by
/// `impl_vectorizable!` but is not a supertrait of `Scalar` — confirm whether
/// this trait is meant to be sealed.
pub trait Scalar: Sized + Copy + Pod + NoUninit + Default {
    /// Backend mask type used for vectors of this scalar.
    type Mask<S: Simd>: MaskOps;

    /// Lane count the backend `S` reports for this scalar.
    fn lanes<S: Simd>() -> usize;

    /// Splits `data` into a leading scalar slice, a middle slice reinterpreted
    /// as whole vectors, and a trailing scalar slice.
    ///
    /// Thin wrapper over [`slice::align_to`]; how the elements are distributed
    /// between the three parts follows that function's rules.
    fn align_to<S: Simd>(data: &[Self]) -> (&[Self], &[Vector<S, Self>], &[Self]) {
        // SAFETY: NOTE(review): soundness relies on any run of `Self` values
        // being a valid `Vector<S, Self>` — confirm against the definition of
        // `Vector`.
        let (head, body, tail) = unsafe { data.align_to::<Vector<S, Self>>() };
        (head, body, tail)
    }

    /// Loads a vector from `ptr`.
    ///
    /// # Safety
    ///
    /// The contract is set by the backend. Presumably `ptr` must be valid for
    /// reading a whole vector and aligned for `Vector<S, Self>` — TODO confirm.
    unsafe fn vload<S: Simd>(ptr: *const Self) -> Vector<S, Self>;

    /// Loads a vector from `ptr` using the backend's unaligned load.
    ///
    /// # Safety
    ///
    /// The contract is set by the backend. Presumably `ptr` must be valid for
    /// reading a whole vector — TODO confirm.
    unsafe fn vload_unaligned<S: Simd>(ptr: *const Self) -> Vector<S, Self>;

    /// Loads a vector from `ptr` using the backend's "low" load.
    ///
    /// # Safety
    ///
    /// The contract is set by the backend. Presumably only the low half of the
    /// vector is read from `ptr` — TODO confirm.
    unsafe fn vload_low<S: Simd>(ptr: *const Self) -> Vector<S, Self>;

    /// Loads a vector from `ptr` using the backend's "high" load.
    ///
    /// # Safety
    ///
    /// The contract is set by the backend. Presumably only the high half of
    /// the vector is read from `ptr` — TODO confirm.
    unsafe fn vload_high<S: Simd>(ptr: *const Self) -> Vector<S, Self>;

    /// Stores `value` to `ptr`.
    ///
    /// # Safety
    ///
    /// The contract is set by the backend. Presumably `ptr` must be valid for
    /// writing a whole vector and aligned for `Vector<S, Self>` — TODO confirm.
    unsafe fn vstore<S: Simd>(ptr: *mut Self, value: Vector<S, Self>);

    /// Stores `value` to `ptr` using the backend's unaligned store.
    ///
    /// # Safety
    ///
    /// The contract is set by the backend. Presumably `ptr` must be valid for
    /// writing a whole vector — TODO confirm.
    unsafe fn vstore_unaligned<S: Simd>(ptr: *mut Self, value: Vector<S, Self>);

    /// Stores `value` to `ptr` using the backend's "low" store.
    ///
    /// # Safety
    ///
    /// The contract is set by the backend. Presumably only the low half of
    /// `value` is written — TODO confirm.
    unsafe fn vstore_low<S: Simd>(ptr: *mut Self, value: Vector<S, Self>);

    /// Stores `value` to `ptr` using the backend's "high" store.
    ///
    /// # Safety
    ///
    /// The contract is set by the backend. Presumably only the high half of
    /// `value` is written — TODO confirm.
    unsafe fn vstore_high<S: Simd>(ptr: *mut Self, value: Vector<S, Self>);

    /// Writes `mask` to `out` as `bool` values.
    ///
    /// # Safety
    ///
    /// The contract is set by the backend. Presumably `out` must be valid for
    /// writing one `bool` per lane — TODO confirm.
    unsafe fn mask_store_as_bool<S: Simd>(out: *mut bool, mask: Mask<S, Self>);

    /// Builds a mask for this scalar from `bools`.
    fn mask_from_bools<S: Simd>(bools: &[bool]) -> Mask<S, Self>;

    /// Converts the scalar into a vector via the backend's splat operation.
    fn splat<S: Simd>(self) -> Vector<S, Self>;
}
/// Implements `Sealed` and `Scalar` for the scalar type `$ty`.
///
/// `$bits` is the lane width of `$ty`. It selects the width-specific backend
/// items by name: `Mask<bits>`, `lanes<bits>`, `mask_store_as_bool_<bits>` and
/// `mask_from_bools_<bits>`. `splat` is resolved per type as `splat_<ty>`.
/// Every method is a thin forwarder, so all of them are `#[inline(always)]`.
macro_rules! impl_vectorizable {
    ($ty: ty, $bits: literal) => {
        paste! {
            impl Sealed for $ty {}

            impl Scalar for $ty {
                type Mask<S: Simd> = S::[<Mask $bits>];

                // Marked `#[inline(always)]` like every other forwarder in
                // this impl.
                #[inline(always)]
                fn lanes<S: Simd>() -> usize {
                    S::[<lanes $bits>]()
                }

                #[inline(always)]
                unsafe fn vload<S: Simd>(ptr: *const Self) -> Vector<S, Self> {
                    // SAFETY: forwarded as-is; the caller upholds the contract
                    // of `Scalar::vload`.
                    unsafe { S::load(ptr) }
                }

                #[inline(always)]
                unsafe fn vload_unaligned<S: Simd>(ptr: *const Self) -> Vector<S, Self> {
                    // SAFETY: forwarded as-is; the caller upholds the contract
                    // of `Scalar::vload_unaligned`.
                    unsafe { S::load_unaligned(ptr) }
                }

                #[inline(always)]
                unsafe fn vload_low<S: Simd>(ptr: *const Self) -> Vector<S, Self> {
                    // SAFETY: forwarded as-is; the caller upholds the contract
                    // of `Scalar::vload_low`.
                    unsafe { S::load_low(ptr) }
                }

                #[inline(always)]
                unsafe fn vload_high<S: Simd>(ptr: *const Self) -> Vector<S, Self> {
                    // SAFETY: forwarded as-is; the caller upholds the contract
                    // of `Scalar::vload_high`.
                    unsafe { S::load_high(ptr) }
                }

                #[inline(always)]
                unsafe fn vstore<S: Simd>(ptr: *mut Self, value: Vector<S, Self>) {
                    // SAFETY: forwarded as-is; the caller upholds the contract
                    // of `Scalar::vstore`.
                    unsafe { S::store(ptr, value) }
                }

                #[inline(always)]
                unsafe fn vstore_unaligned<S: Simd>(ptr: *mut Self, value: Vector<S, Self>) {
                    // SAFETY: forwarded as-is; the caller upholds the contract
                    // of `Scalar::vstore_unaligned`.
                    unsafe { S::store_unaligned(ptr, value) }
                }

                #[inline(always)]
                unsafe fn vstore_low<S: Simd>(ptr: *mut Self, value: Vector<S, Self>) {
                    // SAFETY: forwarded as-is; the caller upholds the contract
                    // of `Scalar::vstore_low`.
                    unsafe { S::store_low(ptr, value) }
                }

                #[inline(always)]
                unsafe fn vstore_high<S: Simd>(ptr: *mut Self, value: Vector<S, Self>) {
                    // SAFETY: forwarded as-is; the caller upholds the contract
                    // of `Scalar::vstore_high`.
                    unsafe { S::store_high(ptr, value) }
                }

                #[inline(always)]
                unsafe fn mask_store_as_bool<S: Simd>(out: *mut bool, mask: Mask<S, Self>) {
                    // NOTE(review): unlike the loads/stores above, this call is
                    // not wrapped in an `unsafe` block — confirm whether the
                    // backend function is a safe fn despite taking a raw
                    // pointer.
                    // `*mask` unwraps the typed `Mask` to the backend value.
                    S::[<mask_store_as_bool_ $bits>](out, *mask);
                }

                #[inline(always)]
                fn mask_from_bools<S: Simd>(bools: &[bool]) -> Mask<S, Self> {
                    Mask(S::[<mask_from_bools_ $bits>](bools))
                }

                #[inline(always)]
                fn splat<S: Simd>(self) -> Vector<S, Self> {
                    // The per-type backend splat result is passed through
                    // `S::typed` to obtain the `Vector<S, Self>` wrapper.
                    S::typed(S::[<splat_ $ty>](self))
                }
            }
        }
    };
}
// `Scalar` implementations, grouped by lane width.

// 8-bit lanes.
impl_vectorizable!(u8, 8);
impl_vectorizable!(i8, 8);

// 16-bit lanes.
impl_vectorizable!(u16, 16);
impl_vectorizable!(i16, 16);
impl_vectorizable!(f16, 16);
impl_vectorizable!(bf16, 16);

// 32-bit lanes.
impl_vectorizable!(u32, 32);
impl_vectorizable!(i32, 32);
impl_vectorizable!(f32, 32);

// 64-bit lanes.
impl_vectorizable!(u64, 64);
impl_vectorizable!(i64, 64);
impl_vectorizable!(f64, 64);
/// Loads a `Vector<S, T>` from `ptr` by forwarding to [`Scalar::vload`].
///
/// Free-function form that lets the scalar type be inferred from the pointer.
///
/// # Safety
///
/// Same contract as [`Scalar::vload`]. Presumably `ptr` must be valid for
/// reading a whole vector and aligned for `Vector<S, T>` — TODO confirm against
/// the backend `Simd::load`.
#[inline(always)]
pub unsafe fn vload<S: Simd, T: Scalar>(ptr: *const T) -> Vector<S, T> {
    // SAFETY: the caller upholds the contract of `Scalar::vload`.
    unsafe { T::vload::<S>(ptr) }
}
/// Loads a `Vector<S, T>` from `ptr` by forwarding to
/// [`Scalar::vload_unaligned`].
///
/// # Safety
///
/// Same contract as [`Scalar::vload_unaligned`]. Presumably `ptr` must be valid
/// for reading a whole vector but need not be vector-aligned — TODO confirm
/// against the backend `Simd::load_unaligned`.
#[inline(always)]
pub unsafe fn vload_unaligned<S: Simd, T: Scalar>(ptr: *const T) -> Vector<S, T> {
    // SAFETY: the caller upholds the contract of `Scalar::vload_unaligned`.
    unsafe { T::vload_unaligned::<S>(ptr) }
}
/// Loads a `Vector<S, T>` from `ptr` by forwarding to [`Scalar::vload_low`].
///
/// # Safety
///
/// Same contract as [`Scalar::vload_low`]. Presumably only the low half of the
/// vector is read from `ptr` — TODO confirm against the backend
/// `Simd::load_low`.
#[inline(always)]
pub unsafe fn vload_low<S: Simd, T: Scalar>(ptr: *const T) -> Vector<S, T> {
    // SAFETY: the caller upholds the contract of `Scalar::vload_low`.
    unsafe { T::vload_low::<S>(ptr) }
}
/// Loads a `Vector<S, T>` from `ptr` by forwarding to [`Scalar::vload_high`].
///
/// # Safety
///
/// Same contract as [`Scalar::vload_high`]. Presumably only the high half of
/// the vector is read from `ptr` — TODO confirm against the backend
/// `Simd::load_high`.
#[inline(always)]
pub unsafe fn vload_high<S: Simd, T: Scalar>(ptr: *const T) -> Vector<S, T> {
    // SAFETY: the caller upholds the contract of `Scalar::vload_high`.
    unsafe { T::vload_high::<S>(ptr) }
}
/// Stores `value` to `ptr` by forwarding to [`Scalar::vstore`].
///
/// # Safety
///
/// Same contract as [`Scalar::vstore`]. Presumably `ptr` must be valid for
/// writing a whole vector and aligned for `Vector<S, T>` — TODO confirm against
/// the backend `Simd::store`.
#[inline(always)]
pub unsafe fn vstore<S: Simd, T: Scalar>(ptr: *mut T, value: Vector<S, T>) {
    // SAFETY: the caller upholds the contract of `Scalar::vstore`.
    unsafe { T::vstore::<S>(ptr, value) }
}
/// Stores `value` to `ptr` by forwarding to [`Scalar::vstore_unaligned`].
///
/// # Safety
///
/// Same contract as [`Scalar::vstore_unaligned`]. Presumably `ptr` must be
/// valid for writing a whole vector but need not be vector-aligned — TODO
/// confirm against the backend `Simd::store_unaligned`.
#[inline(always)]
pub unsafe fn vstore_unaligned<S: Simd, T: Scalar>(ptr: *mut T, value: Vector<S, T>) {
    // SAFETY: the caller upholds the contract of `Scalar::vstore_unaligned`.
    unsafe { T::vstore_unaligned::<S>(ptr, value) }
}
/// Stores `value` to `ptr` by forwarding to [`Scalar::vstore_low`].
///
/// # Safety
///
/// Same contract as [`Scalar::vstore_low`]. Presumably only the low half of
/// `value` is written to `ptr` — TODO confirm against the backend
/// `Simd::store_low`.
#[inline(always)]
pub unsafe fn vstore_low<S: Simd, T: Scalar>(ptr: *mut T, value: Vector<S, T>) {
    // SAFETY: the caller upholds the contract of `Scalar::vstore_low`.
    unsafe { T::vstore_low::<S>(ptr, value) }
}
/// Stores `value` to `ptr` by forwarding to [`Scalar::vstore_high`].
///
/// # Safety
///
/// Same contract as [`Scalar::vstore_high`]. Presumably only the high half of
/// `value` is written to `ptr` — TODO confirm against the backend
/// `Simd::store_high`.
#[inline(always)]
pub unsafe fn vstore_high<S: Simd, T: Scalar>(ptr: *mut T, value: Vector<S, T>) {
    // SAFETY: the caller upholds the contract of `Scalar::vstore_high`.
    unsafe { T::vstore_high::<S>(ptr, value) }
}