use std::mem::MaybeUninit;
use crate::Isa;
use crate::functional::simd_map;
use crate::ops::{GetNumOps, GetSimd};
use crate::span::SrcDest;
pub trait SimdOp {
type Output;
fn eval<I: Isa>(self, isa: I) -> Self::Output;
fn dispatch(self) -> Self::Output
where
Self: Sized,
{
dispatch(self)
}
}
pub fn dispatch<Op: SimdOp>(op: Op) -> Op::Output {
#[cfg(target_arch = "aarch64")]
if let Some(isa) = super::arch::aarch64::ArmNeonIsa::new() {
return op.eval(isa);
}
#[cfg(target_arch = "x86_64")]
{
{
#[target_feature(enable = "avx512f")]
#[target_feature(enable = "avx512vl")]
#[target_feature(enable = "avx512bw")]
#[target_feature(enable = "avx512dq")]
unsafe fn dispatch_avx512<Op: SimdOp>(isa: impl Isa, op: Op) -> Op::Output {
op.eval(isa)
}
if let Some(isa) = super::arch::x86_64::Avx512Isa::new() {
unsafe {
return dispatch_avx512(isa, op);
}
}
}
#[target_feature(enable = "avx2")]
#[target_feature(enable = "avx")]
#[target_feature(enable = "fma")]
unsafe fn dispatch_avx2<Op: SimdOp>(isa: impl Isa, op: Op) -> Op::Output {
op.eval(isa)
}
if let Some(isa) = super::arch::x86_64::Avx2Isa::new() {
unsafe {
return dispatch_avx2(isa, op);
}
}
}
#[cfg(target_arch = "wasm32")]
#[cfg(target_feature = "simd128")]
{
if let Some(isa) = super::arch::wasm32::Wasm32Isa::new() {
return op.eval(isa);
}
}
let isa = super::arch::generic::GenericIsa::new();
op.eval(isa)
}
pub trait SimdUnaryOp<T: GetSimd> {
fn eval<I: Isa>(&self, isa: I, x: T::Simd<I>) -> T::Simd<I>;
#[inline(always)]
fn apply<I: Isa>(isa: I, x: T::Simd<I>) -> T::Simd<I>
where
Self: Default,
{
Self::default().eval(isa, x)
}
fn map<'dst>(&self, input: &[T], output: &'dst mut [MaybeUninit<T>]) -> &'dst mut [T]
where
Self: Sized,
T: GetNumOps,
{
let wrapped_op = SimdMapOp::wrap((input, output).into(), self);
dispatch(wrapped_op)
}
#[allow(private_bounds)]
fn map_mut(&self, input: &mut [T])
where
Self: Sized,
T: GetNumOps,
{
let wrapped_op = SimdMapOp::wrap(input.into(), self);
dispatch(wrapped_op);
}
fn scalar_eval(&self, x: T) -> T
where
Self: Sized,
T: GetNumOps,
{
let mut array = [x];
self.map_mut(&mut array);
array[0]
}
}
struct SimdMapOp<'src, 'dst, 'op, T: GetSimd, Op: SimdUnaryOp<T>> {
src_dest: SrcDest<'src, 'dst, T>,
op: &'op Op,
}
impl<'src, 'dst, 'op, T: GetSimd, Op: SimdUnaryOp<T>> SimdMapOp<'src, 'dst, 'op, T, Op> {
pub fn wrap(src_dest: SrcDest<'src, 'dst, T>, op: &'op Op) -> Self {
SimdMapOp { src_dest, op }
}
}
impl<'dst, T: GetNumOps + GetSimd, Op: SimdUnaryOp<T>> SimdOp for SimdMapOp<'_, 'dst, '_, T, Op> {
type Output = &'dst mut [T];
#[inline(always)]
fn eval<I: Isa>(self, isa: I) -> Self::Output {
simd_map(
T::num_ops(isa),
self.src_dest,
#[inline(always)]
|x| self.op.eval(isa, x),
)
}
}
#[cfg(test)]
macro_rules! test_simd_op {
($isa:ident, $op:block) => {{
struct TestOp {}
impl SimdOp for TestOp {
type Output = ();
fn eval<I: Isa>(self, $isa: I) {
$op
}
}
TestOp {}.dispatch()
}};
}
#[cfg(test)]
pub(crate) use test_simd_op;
#[cfg(test)]
mod tests {
use super::SimdUnaryOp;
use crate::Isa;
use crate::ops::{FloatOps, GetNumOps, GetSimd, NumOps};
#[test]
fn test_unary_float_op() {
struct Reciprocal {}
impl SimdUnaryOp<f32> for Reciprocal {
fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
let ops = isa.f32();
ops.div(ops.one(), x)
}
}
let mut buf = [1., 2., 3., 4.];
Reciprocal {}.map_mut(&mut buf);
assert_eq!(buf, [1., 1. / 2., 1. / 3., 1. / 4.]);
}
#[test]
fn test_unary_generic_op() {
struct Double {}
impl<T> SimdUnaryOp<T> for Double
where
T: GetSimd + GetNumOps,
{
fn eval<I: Isa>(&self, isa: I, x: T::Simd<I>) -> T::Simd<I> {
let ops = T::num_ops(isa);
ops.add(x, x)
}
}
let mut buf = [1i32, 2, 3, 4];
Double {}.map_mut(&mut buf);
assert_eq!(buf, [2, 4, 6, 8]);
let mut buf = [1.0f32, 2., 3., 4.];
Double {}.map_mut(&mut buf);
assert_eq!(buf, [2., 4., 6., 8.]);
}
}