//! Runtime architecture detection and SIMD dispatch.
//!
//! [`Architecture`] abstracts over a concrete instruction-set level
//! (x86-64 V3/V4, AArch64 NEON, or a scalar fallback); the free
//! `dispatch*` functions pick the best level available at runtime.
use half::f16;
use crate::{
Const, SIMDCast, SIMDDotProduct, SIMDFloat, SIMDMask, SIMDSelect, SIMDSigned, SIMDSumTree,
SIMDUnsigned, SIMDVector, SplitJoin, ZipUnzip, lifetime::AddLifetime,
};
pub(crate) mod emulated;
pub use emulated::Scalar;
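/// The instruction-set level an [`Architecture`] corresponds to
/// (the scalar fallback, x86-64 V3/V4, or AArch64 NEON, depending on target).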
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct Level(LevelInner);
impl Level {
const fn scalar() -> Self {
Self(LevelInner::Scalar)
}
}
cfg_if::cfg_if! {
if #[cfg(target_arch = "x86_64")] {
pub mod x86_64;
use x86_64::LevelInner;
pub use x86_64::current;
pub use x86_64::Current;
pub use x86_64::dispatch;
pub use x86_64::dispatch1;
pub use x86_64::dispatch2;
pub use x86_64::dispatch3;
pub use x86_64::dispatch_no_features;
pub use x86_64::dispatch1_no_features;
pub use x86_64::dispatch2_no_features;
pub use x86_64::dispatch3_no_features;
impl Level {
const fn v3() -> Self {
Self(LevelInner::V3)
}
const fn v4() -> Self {
Self(LevelInner::V4)
}
}
} else if #[cfg(target_arch = "aarch64")] {
pub mod aarch64;
use aarch64::LevelInner;
pub use aarch64::current;
pub use aarch64::Current;
pub use aarch64::dispatch;
pub use aarch64::dispatch1;
pub use aarch64::dispatch2;
pub use aarch64::dispatch3;
pub use aarch64::dispatch_no_features;
pub use aarch64::dispatch1_no_features;
pub use aarch64::dispatch2_no_features;
pub use aarch64::dispatch3_no_features;
impl Level {
const fn neon() -> Self {
Self(LevelInner::Neon)
}
}
} else {
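// No dedicated SIMD backend for this target: every dispatch runs the
// scalar emulation, and the `_no_features` variants are identical because
// there are no target features to detect or enable.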
pub type Current = Scalar;
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
enum LevelInner {
Scalar,
}
pub const fn current() -> Current {
Scalar::new()
}
pub fn dispatch<T, R>(f: T) -> R
where T: Target<Scalar, R> {
f.run(Scalar::new())
}
pub fn dispatch1<T, T0, R>(f: T, x0: T0) -> R
where T: Target1<Scalar, R, T0> {
f.run(Scalar::new(), x0)
}
pub fn dispatch2<T, T0, T1, R>(f: T, x0: T0, x1: T1) -> R
where T: Target2<Scalar, R, T0, T1> {
f.run(Scalar::new(), x0, x1)
}
pub fn dispatch3<T, T0, T1, T2, R>(f: T, x0: T0, x1: T1, x2: T2) -> R
where T: Target3<Scalar, R, T0, T1, T2> {
f.run(Scalar::new(), x0, x1, x2)
}
pub fn dispatch_no_features<T, R>(f: T) -> R
where T: Target<Scalar, R> {
f.run(Scalar::new())
}
pub fn dispatch1_no_features<T, T0, R>(f: T, x0: T0) -> R
where T: Target1<Scalar, R, T0> {
f.run(Scalar::new(), x0)
}
pub fn dispatch2_no_features<T, T0, T1, R>(f: T, x0: T0, x1: T1) -> R
where T: Target2<Scalar, R, T0, T1> {
f.run(Scalar::new(), x0, x1)
}
pub fn dispatch3_no_features<T, T0, T1, T2, R>(f: T, x0: T0, x1: T1, x2: T2) -> R
where T: Target3<Scalar, R, T0, T1, T2> {
f.run(Scalar::new(), x0, x1, x2)
}
}
}
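// Illustrative use of the free `dispatch*` entry points (a sketch, assuming
// the backend `dispatch1` accepts any `Target1` implementor; `data` is a
// hypothetical `Vec<f32>`). Plain closures qualify through the blanket
// `FnOnce` impls of `Target*` below:
//
//     let sum = dispatch1(|xs: &[f32]| xs.iter().sum::<f32>(), data.as_slice());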
mod sealed {
/// Prevents types outside this crate from implementing [`Architecture`](super::Architecture).
pub trait Sealed: std::fmt::Debug + Copy + PartialEq + Send + Sync + 'static {}
}
pub(crate) use sealed::Sealed;
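// Declares an associated vector type with a fixed scalar type, lane count,
// and mask, plus any extra bounds. For example,
//     vector!(f32x4: <Self, f32, 4, mask_f32x4> + SIMDFloat + SIMDSumTree);
// expands to
//     type f32x4: SIMDVector<Arch = Self, Scalar = f32, ConstLanes = Const<4>, Mask = Self::mask_f32x4>
//         + SIMDFloat + SIMDSumTree;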
macro_rules! vector {
($me:ident: <$self:ident, $T:ty, $N:literal, $mask:ident> + $($rest:tt)*) => {
type $me: SIMDVector<Arch = $self, Scalar = $T, ConstLanes = Const<$N>, Mask = Self::$mask> + $($rest)*;
}
}
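/// A concrete SIMD instruction-set level.
///
/// Implementors are zero-sized tokens whose existence proves the
/// corresponding target features are available (the `hide*` helpers below
/// assert the zero-sized layout). The associated types name every vector
/// and mask shape the crate uses; the `run*` methods execute a [`Target`]
/// with the token, and the `dispatch*` methods type-erase an
/// [`FTarget1`]/[`FTarget2`]/[`FTarget3`] into a plain function pointer
/// (see [`Dispatched1`] and friends).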
#[allow(non_camel_case_types)]
pub trait Architecture: sealed::Sealed {
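// Mask types for each vector shape. Extra bounds record which masks can
// blend lanes (`SIMDSelect`) or be converted from a float-vector mask.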
type mask_f16x8: SIMDMask;
type mask_f16x16: SIMDMask;
type mask_f32x4: SIMDMask + SIMDSelect<Self::f32x4>;
type mask_f32x8: SIMDMask + SIMDSelect<Self::f32x8>;
type mask_f32x16: SIMDMask + SIMDSelect<Self::f32x16>;
type mask_i8x16: SIMDMask;
type mask_i8x32: SIMDMask;
type mask_i8x64: SIMDMask;
type mask_i16x8: SIMDMask;
type mask_i16x16: SIMDMask;
type mask_i16x32: SIMDMask;
type mask_i32x4: SIMDMask;
type mask_i32x8: SIMDMask + From<Self::mask_f32x8> + SIMDSelect<Self::i32x8>;
type mask_i32x16: SIMDMask + SIMDSelect<Self::i32x16>;
type mask_u8x16: SIMDMask;
type mask_u8x32: SIMDMask;
type mask_u8x64: SIMDMask;
type mask_u32x4: SIMDMask;
type mask_u32x8: SIMDMask + From<Self::mask_f32x8>;
type mask_u32x16: SIMDMask + SIMDSelect<Self::u32x16>;
type mask_u64x2: SIMDMask;
type mask_u64x4: SIMDMask;
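// Vector types, declared via `vector!`. Extra bounds record the lane-width
// conversions, split/join, zip/unzip, and dot-product kernels each shape
// must support.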
vector!(
f16x8: <Self, f16, 8, mask_f16x8>
+ SIMDCast<f32, Cast = Self::f32x8>
);
vector!(
f16x16: <Self, f16, 16, mask_f16x16>
+ SplitJoin<Halved = Self::f16x8>
+ ZipUnzip<Halved = Self::f16x8>
+ SIMDCast<f32, Cast = Self::f32x16>
);
vector!(
f32x4: <Self, f32, 4, mask_f32x4>
+ SIMDFloat
+ SIMDSumTree
);
vector!(
f32x8: <Self, f32, 8, mask_f32x8>
+ SIMDFloat
+ SIMDSumTree
+ SIMDCast<f16, Cast = Self::f16x8>
+ SplitJoin<Halved = Self::f32x4>
+ From<Self::f16x8>
);
vector!(
f32x16: <Self, f32, 16, mask_f32x16>
+ SIMDFloat
+ SplitJoin<Halved = Self::f32x8>
+ SIMDSumTree
+ From<Self::f16x16>
);
vector!(
i8x16: <Self, i8, 16, mask_i8x16>
+ SIMDSigned
);
vector!(
i8x32: <Self, i8, 32, mask_i8x32>
+ SIMDSigned
+ SplitJoin<Halved = Self::i8x16>
+ ZipUnzip<Halved = Self::i8x16>
);
vector!(
i8x64: <Self, i8, 64, mask_i8x64>
+ SIMDSigned
);
vector!(
i16x8: <Self, i16, 8, mask_i16x8>
+ SIMDSigned
);
vector!(
i16x16: <Self, i16, 16, mask_i16x16>
+ SIMDSigned
+ SplitJoin<Halved = Self::i16x8>
+ ZipUnzip<Halved = Self::i16x8>
+ From<Self::i8x16>
+ From<Self::u8x16>
);
vector!(
i16x32: <Self, i16, 32, mask_i16x32>
+ SIMDSigned
+ SplitJoin<Halved = Self::i16x16>
+ From<Self::i8x32>
+ From<Self::u8x32>
);
vector!(
i32x4: <Self, i32, 4, mask_i32x4>
+ SIMDSigned
);
vector!(
i32x8: <Self, i32, 8, mask_i32x8>
+ SIMDSigned
+ SIMDSumTree
+ SplitJoin<Halved = Self::i32x4>
+ ZipUnzip<Halved = Self::i32x4>
+ SIMDDotProduct<Self::i16x16>
+ SIMDDotProduct<Self::u8x32, Self::i8x32>
+ SIMDDotProduct<Self::i8x32, Self::u8x32>
+ SIMDCast<f32, Cast = Self::f32x8>
);
vector!(
i32x16: <Self, i32, 16, mask_i32x16>
+ SIMDSigned
+ SIMDSumTree
+ SplitJoin<Halved = Self::i32x8>
+ SIMDDotProduct<Self::u8x64, Self::i8x64>
+ SIMDDotProduct<Self::i8x64, Self::u8x64>
);
vector!(
u8x16: <Self, u8, 16, mask_u8x16>
+ SIMDUnsigned
);
vector!(
u8x32: <Self, u8, 32, mask_u8x32>
+ SIMDUnsigned
+ SplitJoin<Halved = Self::u8x16>
+ ZipUnzip<Halved = Self::u8x16>
);
vector!(
u8x64: <Self, u8, 64, mask_u8x64>
+ SIMDUnsigned
);
vector!(
u32x4: <Self, u32, 4, mask_u32x4>
+ SIMDUnsigned
);
vector!(
u32x8: <Self, u32, 8, mask_u32x8>
+ SplitJoin<Halved = Self::u32x4>
+ ZipUnzip<Halved = Self::u32x4>
+ SIMDUnsigned
+ SIMDSumTree
);
vector!(
u32x16: <Self, u32, 16, mask_u32x16>
+ SIMDUnsigned
+ SIMDSumTree
+ SplitJoin<Halved = Self::u32x8>
);
vector!(
u64x2: <Self, u64, 2, mask_u64x2>
+ SIMDUnsigned
);
vector!(
u64x4: <Self, u64, 4, mask_u64x4>
+ SplitJoin<Halved = Self::u64x2>
+ SIMDUnsigned
);
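// Entry points. The `run*` methods execute a `Target*` with this token; the
// `_inline` variants are intended to be inlined into callers already
// compiled with the required features. The `dispatch*` methods bake an
// `FTarget*` into a type-erased function pointer that can be stored and
// called later.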
fn level() -> Level;
fn run<F, R>(self, f: F) -> R
where
F: Target<Self, R>;
fn run_inline<F, R>(self, f: F) -> R
where
F: Target<Self, R>;
fn run1<F, T0, R>(self, f: F, x0: T0) -> R
where
F: Target1<Self, R, T0>;
fn run1_inline<F, T0, R>(self, f: F, x0: T0) -> R
where
F: Target1<Self, R, T0>;
fn run2<F, T0, T1, R>(self, f: F, x0: T0, x1: T1) -> R
where
F: Target2<Self, R, T0, T1>;
fn run2_inline<F, T0, T1, R>(self, f: F, x0: T0, x1: T1) -> R
where
F: Target2<Self, R, T0, T1>;
fn run3<F, T0, T1, T2, R>(self, f: F, x0: T0, x1: T1, x2: T2) -> R
where
F: Target3<Self, R, T0, T1, T2>;
fn run3_inline<F, T0, T1, T2, R>(self, f: F, x0: T0, x1: T1, x2: T2) -> R
where
F: Target3<Self, R, T0, T1, T2>;
fn dispatch1<F, R, T0>(self) -> Dispatched1<R, T0>
where
T0: AddLifetime,
F: for<'a> FTarget1<Self, R, T0::Of<'a>>;
fn dispatch2<F, R, T0, T1>(self) -> Dispatched2<R, T0, T1>
where
T0: AddLifetime,
T1: AddLifetime,
F: for<'a, 'b> FTarget2<Self, R, T0::Of<'a>, T1::Of<'b>>;
fn dispatch3<F, R, T0, T1, T2>(self) -> Dispatched3<R, T0, T1, T2>
where
T0: AddLifetime,
T1: AddLifetime,
T2: AddLifetime,
F: for<'a, 'b, 'c> FTarget3<Self, R, T0::Of<'a>, T1::Of<'b>, T2::Of<'c>>;
}
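/// A zero-argument operation that can be compiled per [`Architecture`].
/// [`Target1`]/[`Target2`]/[`Target3`] are the one- to three-argument
/// variants.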
pub trait Target<A, R>
where
A: Architecture,
{
fn run(self, arch: A) -> R;
}
pub trait Target1<A, R, T0>
where
A: Architecture,
{
fn run(self, arch: A, x0: T0) -> R;
}
pub trait Target2<A, R, T0, T1>
where
A: Architecture,
{
fn run(self, arch: A, x0: T0, x1: T1) -> R;
}
pub trait Target3<A, R, T0, T1, T2>
where
A: Architecture,
{
fn run(self, arch: A, x0: T0, x1: T1, x2: T2) -> R;
}
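/// Like [`Target1`], but `run` is an associated function with no receiver,
/// so it can be turned into a bare function pointer by `dispatch1` (and
/// likewise for [`FTarget2`]/[`FTarget3`]).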
pub trait FTarget1<A, R, T0>
where
A: Architecture,
{
fn run(arch: A, x0: T0) -> R;
}
pub trait FTarget2<A, R, T0, T1>
where
A: Architecture,
{
fn run(arch: A, x0: T0, x1: T1) -> R;
}
pub trait FTarget3<A, R, T0, T1, T2>
where
A: Architecture,
{
fn run(arch: A, x0: T0, x1: T1, x2: T2) -> R;
}
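// Blanket impls: any `FnOnce` with a matching signature is a `Target*`,
// which is what lets plain closures be passed to the `run*` and free
// `dispatch*` entry points.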
impl<A, R, F> Target<A, R> for F
where
A: Architecture,
F: FnOnce() -> R,
{
#[inline]
fn run(self, _: A) -> R {
(self)()
}
}
impl<A, R, T0, F> Target1<A, R, T0> for F
where
A: Architecture,
F: FnOnce(T0) -> R,
{
#[inline]
fn run(self, _: A, x0: T0) -> R {
(self)(x0)
}
}
impl<A, R, T0, T1, F> Target2<A, R, T0, T1> for F
where
A: Architecture,
F: FnOnce(T0, T1) -> R,
{
#[inline]
fn run(self, _: A, x0: T0, x1: T1) -> R {
(self)(x0, x1)
}
}
impl<A, R, T0, T1, T2, F> Target3<A, R, T0, T1, T2> for F
where
A: Architecture,
F: FnOnce(T0, T1, T2) -> R,
{
#[inline]
fn run(self, _: A, x0: T0, x1: T1, x2: T2) -> R {
(self)(x0, x1, x2)
}
}
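/// Opaque stand-in for a concrete architecture token inside a type-erased
/// function pointer. The asserts below pin down the zero-size, align-1
/// layout that `hide!` relies on for ABI compatibility.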
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
struct Hidden;
const _ASSERT_ZST: () = assert!(
std::mem::size_of::<Hidden>() == 0,
"Hidden **must** be zero sized"
);
const _ASSERT_ALIGNED: () = assert!(
std::mem::align_of::<Hidden>() == 1,
"Hidden **must** be alignment 1"
);
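// `dispatched!` generates `Dispatched1`/`Dispatched2`/`Dispatched3`:
// `repr(transparent)` wrappers around a higher-ranked function pointer whose
// hidden first argument is the erased architecture token. Each wrapper is
// pointer sized with a niche, so `Option<DispatchedN<..>>` stays pointer
// sized too (the tests below assert both). A sketch, with `arch` some
// `Architecture` token and a hypothetical `SumOp: FTarget1`:
//
//     let f: Dispatched1<f32, Ref<[f32]>> = arch.dispatch1::<SumOp, f32, Ref<[f32]>>();
//     let sum = f.call(&data);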
macro_rules! dispatched {
($name:ident, { $($Ts:ident )* }, { $($xs:ident )* }, { $($lt:lifetime )* }) => {
#[derive(Debug)]
#[repr(transparent)]
pub struct $name<R, $($Ts,)*>
where
$($Ts: AddLifetime,)*
{
f: for<$($lt,)*> unsafe fn(Hidden, $($Ts::Of<$lt>,)*) -> R,
}
impl<R, $($Ts,)*> $name<R, $($Ts,)*>
where
$($Ts: AddLifetime,)*
{
unsafe fn new(f: unsafe fn(Hidden, $($Ts::Of<'_>,)*) -> R) -> Self {
Self { f }
}
#[inline(always)]
pub fn call(self, $($xs: $Ts::Of<'_>,)*) -> R {
unsafe { (self.f)(Hidden, $($xs,)*) }
}
}
impl<R, $($Ts,)*> Clone for $name<R, $($Ts,)*>
where
$($Ts: AddLifetime,)*
{
fn clone(&self) -> Self {
*self
}
}
impl<R, $($Ts,)*> Copy for $name<R, $($Ts,)*>
where
$($Ts: AddLifetime,)*
{
}
}
}
dispatched!(Dispatched1, { T0 }, { x0 }, { 'a0 });
dispatched!(Dispatched2, { T0 T1 }, { x0 x1 }, { 'a0 'a1 });
dispatched!(Dispatched3, { T0 T1 T2 }, { x0 x1 x2 }, { 'a0 'a1 'a2 });
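// `hide!` erases the concrete architecture parameter from a function
// pointer by transmuting `fn(A, ..)` to `fn(Hidden, ..)`. This is sound
// only because both `A` and `Hidden` are zero-sized with alignment 1,
// which the const blocks assert.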
macro_rules! hide {
($name:ident, $dispatched:ident, { $($Ts:ident )* }) => {
unsafe fn $name<A, R, $($Ts,)*>(
f: unsafe fn(A, $($Ts::Of<'_>,)*) -> R
) -> $dispatched<R, $($Ts,)*>
where
$($Ts: AddLifetime,)*
{
const {
assert!(
std::mem::size_of::<A>() == 0,
"A must be zero sized to be ABI compatible with `Hidden`"
)
};
const {
assert!(
std::mem::align_of::<A>() == 1,
"A must have an alignment of 1 to be ABI compatible with `Hidden`"
)
};
let f = unsafe {
std::mem::transmute::<
unsafe fn(A, $($Ts::Of<'_>,)*) -> R,
unsafe fn(Hidden, $($Ts::Of<'_>,)*) -> R
>(f)
};
unsafe { $dispatched::new(f) }
}
}
}
hide!(hide1, Dispatched1, { T0 });
hide!(hide2, Dispatched2, { T0 T1 });
hide!(hide3, Dispatched3, { T0 T1 T2 });
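// Used by backend impls of `Architecture`: defines each `mask_*` associated
// type as the `Mask` of the corresponding vector type, e.g.
//     maskdef!(mask_f32x4 = f32x4);
// expands to
//     type mask_f32x4 = <f32x4 as SIMDVector>::Mask;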
macro_rules! maskdef {
($mask:ident = $repr:ty) => {
type $mask = <$repr as SIMDVector>::Mask;
};
($($mask:ident = $repr:ty),+ $(,)?) => {
$($crate::arch::maskdef!($mask = $repr);)+
};
() => {
$crate::arch::maskdef!(
mask_f16x8 = f16x8,
mask_f16x16 = f16x16,
mask_f32x4 = f32x4,
mask_f32x8 = f32x8,
mask_f32x16 = f32x16,
mask_i8x16 = i8x16,
mask_i8x32 = i8x32,
mask_i8x64 = i8x64,
mask_i16x8 = i16x8,
mask_i16x16 = i16x16,
mask_i16x32 = i16x32,
mask_i32x4 = i32x4,
mask_i32x8 = i32x8,
mask_i32x16 = i32x16,
mask_u8x16 = u8x16,
mask_u8x32 = u8x32,
mask_u8x64 = u8x64,
mask_u32x4 = u32x4,
mask_u32x8 = u32x8,
mask_u32x16 = u32x16,
mask_u64x2 = u64x2,
mask_u64x4 = u64x4,
);
};
}
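// Companion to `maskdef!`: maps each `Architecture` vector associated type
// to the backend's concrete type of the same name, e.g. `typedef!(f32x4)`
// expands to `type f32x4 = f32x4;`, where the right-hand side resolves to
// the backend module's own `f32x4` (bare paths never resolve to the
// associated type being defined).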
macro_rules! typedef {
() => {
$crate::arch::typedef!(
f16x8,
f16x16,
f32x4,
f32x8,
f32x16,
i8x16,
i8x32,
i8x64,
i16x8,
i16x16,
i16x32,
i32x4,
i32x8,
i32x16,
u8x16,
u8x32,
u8x64,
u32x4,
u32x8,
u32x16,
u64x2,
u64x4,
);
};
($repr:ident) => {
type $repr = $repr;
};
($($repr:ident),+ $(,)?) => {
$($crate::arch::typedef!($repr);)+
};
}
pub(crate) use maskdef;
pub(crate) use typedef;
#[cfg(test)]
mod tests {
use super::*;
use crate::lifetime::{Mut, Ref};
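// `TestOp` implements every `Target*`/`FTarget*` arity so a single type can
// exercise all of the `run*`, `run*_inline`, and `dispatch*` entry points
// on each architecture available at test time.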
struct TestOp;
impl<A> Target<A, &'static str> for TestOp
where
A: Architecture,
{
fn run(self, _: A) -> &'static str {
"hello world"
}
}
impl<A> Target1<A, f32, &[f32]> for TestOp
where
A: Architecture,
{
fn run(self, _: A, x: &[f32]) -> f32 {
x.iter().sum()
}
}
impl<A> FTarget1<A, f32, &[f32]> for TestOp
where
A: Architecture,
{
fn run(arch: A, x: &[f32]) -> f32 {
<_ as Target1<_, _, _>>::run(Self, arch, x)
}
}
impl<A> Target2<A, f32, &mut [f32], &[f32]> for TestOp
where
A: Architecture,
{
fn run(self, _: A, x: &mut [f32], y: &[f32]) -> f32 {
x.copy_from_slice(y);
y.iter().sum()
}
}
impl<A> FTarget2<A, f32, &mut [f32], &[f32]> for TestOp
where
A: Architecture,
{
fn run(arch: A, x: &mut [f32], y: &[f32]) -> f32 {
<_ as Target2<_, _, _, _>>::run(TestOp, arch, x, y)
}
}
impl<A> Target3<A, f32, &mut [f32], &[f32], f32> for TestOp
where
A: Architecture,
{
fn run(self, _: A, x: &mut [f32], y: &[f32], z: f32) -> f32 {
assert_eq!(x.len(), y.len());
x.iter_mut().zip(y.iter()).for_each(|(d, s)| *d = *s + z);
y.iter().sum()
}
}
impl<A> FTarget3<A, f32, &mut [f32], &[f32], f32> for TestOp
where
A: Architecture,
{
fn run(arch: A, x: &mut [f32], y: &[f32], z: f32) -> f32 {
<_ as Target3<_, _, _, _, _>>::run(TestOp, arch, x, y, z)
}
}
#[test]
fn zero_arg_target() {
let expected = "hello world";
assert_eq!((Scalar).run(TestOp), expected);
assert_eq!((Scalar).run_inline(TestOp), expected);
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V3::new_checked_uncached() {
assert_eq!(arch.run(TestOp), expected);
assert_eq!(arch.run_inline(TestOp), expected);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V4::new_checked_miri() {
assert_eq!(arch.run(TestOp), expected);
assert_eq!(arch.run_inline(TestOp), expected);
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = aarch64::Neon::new_checked() {
assert_eq!(arch.run(TestOp), expected);
assert_eq!(arch.run_inline(TestOp), expected);
}
}
#[test]
fn one_arg_target() {
let src = [1.0f32, 2.0f32, 3.0f32];
let sum: f32 = src.iter().sum();
assert_eq!((Scalar).run1(TestOp, &src), sum);
assert_eq!((Scalar).run1_inline(TestOp, &src), sum);
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V3::new_checked_uncached() {
assert_eq!(arch.run1(TestOp, &src), sum);
assert_eq!(arch.run1_inline(TestOp, &src), sum);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V4::new_checked_miri() {
assert_eq!(arch.run1(TestOp, &src), sum);
assert_eq!(arch.run1_inline(TestOp, &src), sum);
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = aarch64::Neon::new_checked() {
assert_eq!(arch.run1(TestOp, &src), sum);
assert_eq!(arch.run1_inline(TestOp, &src), sum);
}
}
#[test]
fn two_arg_target() {
let src = [1.0f32, 2.0f32, 3.0f32];
let sum: f32 = src.iter().sum();
macro_rules! gen_test {
($arch:ident) => {{
let mut dst = [0.0f32; 3];
assert_eq!($arch.run2(TestOp, &mut dst, &src), sum);
assert_eq!(dst, src);
}
{
let mut dst = [0.0f32; 3];
assert_eq!($arch.run2_inline(TestOp, &mut dst, &src), sum);
assert_eq!(dst, src);
}};
}
gen_test!(Scalar);
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V3::new_checked_uncached() {
gen_test!(arch);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V4::new_checked_miri() {
gen_test!(arch);
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = aarch64::Neon::new_checked() {
gen_test!(arch);
}
}
#[test]
fn three_arg_target() {
let src = [1.0f32, 2.0f32, 3.0f32];
let sum: f32 = src.iter().sum();
let offset = 10.0f32;
let expected = [11.0f32, 12.0f32, 13.0f32];
macro_rules! gen_test {
($arch:ident) => {{
let mut dst = [0.0f32; 3];
assert_eq!($arch.run3(TestOp, &mut dst, &src, offset), sum);
assert_eq!(dst, expected);
}
{
let mut dst = [0.0f32; 3];
assert_eq!($arch.run3_inline(TestOp, &mut dst, &src, offset), sum);
assert_eq!(dst, expected);
}};
}
gen_test!(Scalar);
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V3::new_checked_uncached() {
gen_test!(arch);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V4::new_checked_miri() {
gen_test!(arch);
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = aarch64::Neon::new_checked() {
gen_test!(arch);
}
}
#[test]
fn one_arg_function_pointer() {
let src = [1.0f32, 2.0f32, 3.0f32];
let sum: f32 = src.iter().sum();
type FnPtr = Dispatched1<f32, Ref<[f32]>>;
assert_eq!(std::mem::size_of::<FnPtr>(), std::mem::size_of::<fn()>());
assert_eq!(
std::mem::size_of::<Option<FnPtr>>(),
std::mem::size_of::<fn()>()
);
{
let f: FnPtr = (Scalar).dispatch1::<TestOp, f32, Ref<[f32]>>();
assert_eq!(f.call(&src), sum);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V3::new_checked_uncached() {
let f: FnPtr = arch.dispatch1::<TestOp, f32, Ref<[f32]>>();
assert_eq!(f.call(&src), sum);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V4::new_checked_miri() {
let f: FnPtr = arch.dispatch1::<TestOp, f32, Ref<[f32]>>();
assert_eq!(f.call(&src), sum);
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = aarch64::Neon::new_checked() {
let f: FnPtr = arch.dispatch1::<TestOp, f32, Ref<[f32]>>();
assert_eq!(f.call(&src), sum);
}
}
#[test]
fn two_arg_function_pointer() {
let src = [1.0f32, 2.0f32, 3.0f32];
let sum: f32 = src.iter().sum();
type FnPtr = Dispatched2<f32, Mut<[f32]>, Ref<[f32]>>;
assert_eq!(std::mem::size_of::<FnPtr>(), std::mem::size_of::<fn()>());
assert_eq!(
std::mem::size_of::<Option<FnPtr>>(),
std::mem::size_of::<fn()>()
);
{
let mut dst = [0.0f32; 3];
let f: FnPtr = (Scalar).dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
assert_eq!(f.call(&mut dst, &src), sum);
assert_eq!(dst, src);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V3::new_checked_uncached() {
let mut dst = [0.0f32; 3];
let f: FnPtr = arch.dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
assert_eq!(f.call(&mut dst, &src), sum);
assert_eq!(dst, src);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V4::new_checked_miri() {
let mut dst = [0.0f32; 3];
let f: FnPtr = arch.dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
assert_eq!(f.call(&mut dst, &src), sum);
assert_eq!(dst, src);
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = aarch64::Neon::new_checked() {
let mut dst = [0.0f32; 3];
let f: FnPtr = arch.dispatch2::<TestOp, f32, Mut<[f32]>, Ref<[f32]>>();
assert_eq!(f.call(&mut dst, &src), sum);
assert_eq!(dst, src);
}
}
#[test]
fn three_arg_function_pointer() {
let src = [1.0f32, 2.0f32, 3.0f32];
let sum: f32 = src.iter().sum();
let offset = 10.0f32;
let expected = [11.0f32, 12.0f32, 13.0f32];
type FnPtr = Dispatched3<f32, Mut<[f32]>, Ref<[f32]>, f32>;
assert_eq!(std::mem::size_of::<FnPtr>(), std::mem::size_of::<fn()>());
assert_eq!(
std::mem::size_of::<Option<FnPtr>>(),
std::mem::size_of::<fn()>()
);
{
let mut dst = [0.0f32; 3];
let f: FnPtr = (Scalar).dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
assert_eq!(f.call(&mut dst, &src, offset), sum);
assert_eq!(dst, expected);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V3::new_checked_uncached() {
let mut dst = [0.0f32; 3];
let f: FnPtr = arch.dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
assert_eq!(f.call(&mut dst, &src, offset), sum);
assert_eq!(dst, expected);
}
#[cfg(target_arch = "x86_64")]
if let Some(arch) = x86_64::V4::new_checked_miri() {
let mut dst = [0.0f32; 3];
let f: FnPtr = arch.dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
assert_eq!(f.call(&mut dst, &src, offset), sum);
assert_eq!(dst, expected);
}
#[cfg(target_arch = "aarch64")]
if let Some(arch) = aarch64::Neon::new_checked() {
let mut dst = [0.0f32; 3];
let f: FnPtr = arch.dispatch3::<TestOp, f32, Mut<[f32]>, Ref<[f32]>, f32>();
assert_eq!(f.call(&mut dst, &src, offset), sum);
assert_eq!(dst, expected);
}
}
}