rten_simd/
dispatch.rs

1use std::mem::MaybeUninit;
2
3use crate::functional::simd_map;
4use crate::ops::GetNumOps;
5use crate::span::SrcDest;
6use crate::{Elem, Isa, Simd};
7
8/// A vectorized operation which can be instantiated for different instruction
9/// sets.
10pub trait SimdOp {
11    /// The type of the operation's result.
12    type Output;
13
14    /// Evaluate the operation using the given instruction set.
15    fn eval<I: Isa>(self, isa: I) -> Self::Output;
16
17    /// Dispatch this operation using the preferred ISA for the current platform.
18    fn dispatch(self) -> Self::Output
19    where
20        Self: Sized,
21    {
22        dispatch(self)
23    }
24}
25
26/// Invoke a SIMD operation using the preferred ISA for the current system.
27///
28/// This function will check the available SIMD instruction sets and then
29/// dispatch to [`SimdOp::eval`], passing the selected [`Isa`].
30pub fn dispatch<Op: SimdOp>(op: Op) -> Op::Output {
31    #[cfg(target_arch = "aarch64")]
32    if let Some(isa) = super::arch::aarch64::ArmNeonIsa::new() {
33        return op.eval(isa);
34    }
35
36    #[cfg(target_arch = "x86_64")]
37    {
38        #[cfg(feature = "avx512")]
39        {
40            // The target features enabled here must match those tested for by `Avx512Isa::new`.
41            #[target_feature(enable = "avx512f")]
42            #[target_feature(enable = "avx512vl")]
43            #[target_feature(enable = "avx512bw")]
44            #[target_feature(enable = "avx512dq")]
45            unsafe fn dispatch_avx512<Op: SimdOp>(isa: impl Isa, op: Op) -> Op::Output {
46                op.eval(isa)
47            }
48
49            if let Some(isa) = super::arch::x86_64::Avx512Isa::new() {
50                // Safety: AVX-512 is supported
51                unsafe {
52                    return dispatch_avx512(isa, op);
53                }
54            }
55        }
56
57        // The target features enabled here must match those tested for by `Avx2Isa::new`.
58        #[target_feature(enable = "avx2")]
59        #[target_feature(enable = "avx")]
60        #[target_feature(enable = "fma")]
61        unsafe fn dispatch_avx2<Op: SimdOp>(isa: impl Isa, op: Op) -> Op::Output {
62            op.eval(isa)
63        }
64
65        if let Some(isa) = super::arch::x86_64::Avx2Isa::new() {
66            // Safety: AVX2 is supported
67            unsafe {
68                return dispatch_avx2(isa, op);
69            }
70        }
71    }
72
73    #[cfg(target_arch = "wasm32")]
74    #[cfg(target_feature = "simd128")]
75    {
76        if let Some(isa) = super::arch::wasm32::Wasm32Isa::new() {
77            return op.eval(isa);
78        }
79    }
80
81    let isa = super::arch::generic::GenericIsa::new();
82    op.eval(isa)
83}
84
85/// Convenience trait for defining vectorized unary operations.
86pub trait SimdUnaryOp<T: Elem> {
87    /// Evaluate the unary function on the elements in `x`.
88    ///
89    /// In order to perform operations on `x`, it will need to be cast to
90    /// the specific type used by the ISA:
91    ///
92    /// ```
93    /// use rten_simd::{Isa, Simd, SimdUnaryOp};
94    /// use rten_simd::ops::{FloatOps, NumOps};
95    ///
96    /// struct Reciprocal {}
97    ///
98    /// impl SimdUnaryOp<f32> for Reciprocal {
99    ///     fn eval<I: Isa, S: Simd<Elem=f32, Isa=I>>(&self, isa: I, x: S) -> S {
100    ///         let ops = isa.f32();
101    ///         let x = x.same_cast();
102    ///         let reciprocal = ops.div(ops.one(), x);
103    ///         reciprocal.same_cast()
104    ///     }
105    /// }
106    /// ```
107    fn eval<I: Isa, S: Simd<Elem = T, Isa = I>>(&self, isa: I, x: S) -> S;
108
109    /// Evaluate the unary function on elements in `x`.
110    ///
111    /// This is a shorthand for `Self::default().eval(x)`. It is mainly useful
112    /// when one vectorized operation needs to call another as part of its
113    /// implementation.
114    #[inline(always)]
115    fn apply<I: Isa, S: Simd<Elem = T, Isa = I>>(isa: I, x: S) -> S
116    where
117        Self: Default,
118    {
119        Self::default().eval(isa, x)
120    }
121
122    /// Apply this function to a slice.
123    ///
124    /// This reads elements from `input` in SIMD vector-sized chunks, applies
125    /// the operation and writes the results to `output`.
126    #[allow(private_bounds)]
127    fn map(&self, input: &[T], output: &mut [MaybeUninit<T>])
128    where
129        Self: Sized,
130        T: GetNumOps,
131    {
132        let wrapped_op = SimdMapOp::wrap((input, output).into(), self);
133        dispatch(wrapped_op);
134    }
135
136    /// Apply a vectorized unary function to a mutable slice.
137    ///
138    /// This is similar to [`map`](SimdUnaryOp::map) but reads and writes
139    /// to the same slice.
140    #[allow(private_bounds)]
141    fn map_mut(&self, input: &mut [T])
142    where
143        Self: Sized,
144        T: GetNumOps,
145    {
146        let wrapped_op = SimdMapOp::wrap(input.into(), self);
147        dispatch(wrapped_op);
148    }
149
150    /// Apply this operation to a single element.
151    #[allow(private_bounds)]
152    fn scalar_eval(&self, x: T) -> T
153    where
154        Self: Sized,
155        T: GetNumOps,
156    {
157        let mut array = [x];
158        self.map_mut(&mut array);
159        array[0]
160    }
161}
162
163/// SIMD operation which applies a unary operator `Op` to all elements in
164/// an input buffer using [`simd_map`].
165struct SimdMapOp<'src, 'dst, 'op, T: Elem, Op: SimdUnaryOp<T>> {
166    src_dest: SrcDest<'src, 'dst, T>,
167    op: &'op Op,
168}
169
170impl<'src, 'dst, 'op, T: Elem, Op: SimdUnaryOp<T>> SimdMapOp<'src, 'dst, 'op, T, Op> {
171    pub fn wrap(src_dest: SrcDest<'src, 'dst, T>, op: &'op Op) -> Self {
172        SimdMapOp { src_dest, op }
173    }
174}
175
176impl<'dst, T: GetNumOps, Op: SimdUnaryOp<T>> SimdOp for SimdMapOp<'_, 'dst, '_, T, Op> {
177    type Output = &'dst mut [T];
178
179    #[inline(always)]
180    fn eval<I: Isa>(self, isa: I) -> Self::Output {
181        simd_map(
182            T::num_ops(isa),
183            self.src_dest,
184            #[inline(always)]
185            |x| self.op.eval(isa, x),
186        )
187    }
188}
189
/// Define a throwaway [`SimdOp`] whose `eval` body is `$op`, then dispatch it
/// immediately.
///
/// `$isa` names the `eval` parameter holding the selected ISA, so the block
/// can refer to it. `SimdOp` and `Isa` are written unqualified in the
/// expansion, so both must be in scope where the macro is invoked.
#[cfg(test)]
macro_rules! test_simd_op {
    ($isa:ident, $op:block) => {{
        struct TestOp {}

        impl SimdOp for TestOp {
            type Output = ();

            fn eval<I: Isa>(self, $isa: I) -> Self::Output {
                $op
            }
        }

        let test_op = TestOp {};
        test_op.dispatch()
    }};
}
207
208#[cfg(test)]
209pub(crate) use test_simd_op;
210
#[cfg(test)]
mod tests {
    use super::SimdUnaryOp;
    use crate::ops::{FloatOps, GetNumOps, NumOps};
    use crate::{Isa, Simd};

    /// An operator implemented only for `f32` can be applied in place with
    /// `map_mut`.
    #[test]
    fn test_unary_float_op() {
        struct Reciprocal;

        impl SimdUnaryOp<f32> for Reciprocal {
            fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
                let float_ops = isa.f32();
                let v = x.same_cast();
                let recip = float_ops.div(float_ops.one(), v);
                recip.same_cast()
            }
        }

        let mut values: [f32; 4] = [1., 2., 3., 4.];
        Reciprocal.map_mut(&mut values);

        let expected: [f32; 4] = [1., 1. / 2., 1. / 3., 1. / 4.];
        assert_eq!(values, expected);
    }

    /// An operator implemented generically over the element type works for
    /// both integer and float slices.
    #[test]
    fn test_unary_generic_op() {
        struct Double;

        impl<T: GetNumOps> SimdUnaryOp<T> for Double {
            fn eval<I: Isa, S: Simd<Elem = T, Isa = I>>(&self, isa: I, x: S) -> S {
                let num_ops = T::num_ops(isa);
                let v = x.same_cast();
                num_ops.add(v, v).same_cast()
            }
        }

        let mut ints = [1i32, 2, 3, 4];
        Double.map_mut(&mut ints);
        assert_eq!(ints, [2, 4, 6, 8]);

        let mut floats = [1.0f32, 2., 3., 4.];
        Double.map_mut(&mut floats);
        assert_eq!(floats, [2., 4., 6., 8.]);
    }
}
256}