rten_simd/
dispatch.rs

1use std::mem::MaybeUninit;
2
3use crate::Isa;
4use crate::functional::simd_map;
5use crate::ops::{GetNumOps, GetSimd};
6use crate::span::SrcDest;
7
8/// A vectorized operation which can be instantiated for different instruction
9/// sets.
10pub trait SimdOp {
11    /// The type of the operation's result.
12    type Output;
13
14    /// Evaluate the operation using the given instruction set.
15    fn eval<I: Isa>(self, isa: I) -> Self::Output;
16
17    /// Dispatch this operation using the preferred ISA for the current platform.
18    fn dispatch(self) -> Self::Output
19    where
20        Self: Sized,
21    {
22        dispatch(self)
23    }
24}
25
26/// Invoke a SIMD operation using the preferred ISA for the current system.
27///
28/// This function will check the available SIMD instruction sets and then
29/// dispatch to [`SimdOp::eval`], passing the selected [`Isa`].
30pub fn dispatch<Op: SimdOp>(op: Op) -> Op::Output {
31    #[cfg(target_arch = "aarch64")]
32    if let Some(isa) = super::arch::aarch64::ArmNeonIsa::new() {
33        return op.eval(isa);
34    }
35
36    #[cfg(target_arch = "x86_64")]
37    {
38        {
39            // The target features enabled here must match those tested for by `Avx512Isa::new`.
40            #[target_feature(enable = "avx512f")]
41            #[target_feature(enable = "avx512vl")]
42            #[target_feature(enable = "avx512bw")]
43            #[target_feature(enable = "avx512dq")]
44            unsafe fn dispatch_avx512<Op: SimdOp>(isa: impl Isa, op: Op) -> Op::Output {
45                op.eval(isa)
46            }
47
48            if let Some(isa) = super::arch::x86_64::Avx512Isa::new() {
49                // Safety: AVX-512 is supported
50                unsafe {
51                    return dispatch_avx512(isa, op);
52                }
53            }
54        }
55
56        // The target features enabled here must match those tested for by `Avx2Isa::new`.
57        #[target_feature(enable = "avx2")]
58        #[target_feature(enable = "avx")]
59        #[target_feature(enable = "fma")]
60        unsafe fn dispatch_avx2<Op: SimdOp>(isa: impl Isa, op: Op) -> Op::Output {
61            op.eval(isa)
62        }
63
64        if let Some(isa) = super::arch::x86_64::Avx2Isa::new() {
65            // Safety: AVX2 is supported
66            unsafe {
67                return dispatch_avx2(isa, op);
68            }
69        }
70    }
71
72    #[cfg(target_arch = "wasm32")]
73    #[cfg(target_feature = "simd128")]
74    {
75        if let Some(isa) = super::arch::wasm32::Wasm32Isa::new() {
76            return op.eval(isa);
77        }
78    }
79
80    let isa = super::arch::generic::GenericIsa::new();
81    op.eval(isa)
82}
83
84/// Convenience trait for defining vectorized unary operations.
85pub trait SimdUnaryOp<T: GetSimd> {
86    /// Evaluate the unary function on the elements in `x`.
87    ///
88    /// ```
89    /// use rten_simd::{Isa, Simd, SimdUnaryOp};
90    /// use rten_simd::ops::{FloatOps, NumOps};
91    ///
92    /// struct Reciprocal {}
93    ///
94    /// impl SimdUnaryOp<f32> for Reciprocal {
95    ///     fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
96    ///         let ops = isa.f32();
97    ///         ops.div(ops.one(), x)
98    ///     }
99    /// }
100    /// ```
101    fn eval<I: Isa>(&self, isa: I, x: T::Simd<I>) -> T::Simd<I>;
102
103    /// Evaluate the unary function on elements in `x`.
104    ///
105    /// This is a shorthand for `Self::default().eval(x)`. It is mainly useful
106    /// when one vectorized operation needs to call another as part of its
107    /// implementation.
108    #[inline(always)]
109    fn apply<I: Isa>(isa: I, x: T::Simd<I>) -> T::Simd<I>
110    where
111        Self: Default,
112    {
113        Self::default().eval(isa, x)
114    }
115
116    /// Apply this function to a slice.
117    ///
118    /// This reads elements from `input` in SIMD vector-sized chunks, applies
119    /// the operation and writes the results to `output`.
120    fn map<'dst>(&self, input: &[T], output: &'dst mut [MaybeUninit<T>]) -> &'dst mut [T]
121    where
122        Self: Sized,
123        T: GetNumOps,
124    {
125        let wrapped_op = SimdMapOp::wrap((input, output).into(), self);
126        dispatch(wrapped_op)
127    }
128
129    /// Apply a vectorized unary function to a mutable slice.
130    ///
131    /// This is similar to [`map`](SimdUnaryOp::map) but reads and writes
132    /// to the same slice.
133    #[allow(private_bounds)]
134    fn map_mut(&self, input: &mut [T])
135    where
136        Self: Sized,
137        T: GetNumOps,
138    {
139        let wrapped_op = SimdMapOp::wrap(input.into(), self);
140        dispatch(wrapped_op);
141    }
142
143    /// Apply this operation to a single element.
144    fn scalar_eval(&self, x: T) -> T
145    where
146        Self: Sized,
147        T: GetNumOps,
148    {
149        let mut array = [x];
150        self.map_mut(&mut array);
151        array[0]
152    }
153}
154
155/// SIMD operation which applies a unary operator `Op` to all elements in
156/// an input buffer using [`simd_map`].
157struct SimdMapOp<'src, 'dst, 'op, T: GetSimd, Op: SimdUnaryOp<T>> {
158    src_dest: SrcDest<'src, 'dst, T>,
159    op: &'op Op,
160}
161
162impl<'src, 'dst, 'op, T: GetSimd, Op: SimdUnaryOp<T>> SimdMapOp<'src, 'dst, 'op, T, Op> {
163    pub fn wrap(src_dest: SrcDest<'src, 'dst, T>, op: &'op Op) -> Self {
164        SimdMapOp { src_dest, op }
165    }
166}
167
168impl<'dst, T: GetNumOps + GetSimd, Op: SimdUnaryOp<T>> SimdOp for SimdMapOp<'_, 'dst, '_, T, Op> {
169    type Output = &'dst mut [T];
170
171    #[inline(always)]
172    fn eval<I: Isa>(self, isa: I) -> Self::Output {
173        simd_map(
174            T::num_ops(isa),
175            self.src_dest,
176            #[inline(always)]
177            |x| self.op.eval(isa, x),
178        )
179    }
180}
181
/// Test helper which defines a one-off [`SimdOp`] and dispatches it at once.
///
/// `$isa` is the identifier bound to the selected ISA inside `$op`, and `$op`
/// is the block used as the body of `eval`. The names `SimdOp` and `Isa` are
/// written unqualified, so they must be in scope where the macro is invoked.
#[cfg(test)]
macro_rules! test_simd_op {
    ($isa:ident, $op:block) => {{
        struct TestOp;

        impl SimdOp for TestOp {
            type Output = ();

            fn eval<I: Isa>(self, $isa: I) -> Self::Output {
                $op
            }
        }

        TestOp.dispatch()
    }};
}
199
200#[cfg(test)]
201pub(crate) use test_simd_op;
202
#[cfg(test)]
mod tests {
    use super::SimdUnaryOp;
    use crate::Isa;
    use crate::ops::{FloatOps, GetNumOps, GetSimd, NumOps};

    /// A unary op implemented for `f32` only, applied in place to a slice.
    #[test]
    fn test_unary_float_op() {
        struct Reciprocal;

        impl SimdUnaryOp<f32> for Reciprocal {
            fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
                let float_ops = isa.f32();
                float_ops.div(float_ops.one(), x)
            }
        }

        let mut values = [1.0f32, 2.0, 3.0, 4.0];
        // Scalar reference results, computed before `values` is overwritten.
        let expected = values.map(|v| 1.0 / v);

        Reciprocal.map_mut(&mut values);

        assert_eq!(values, expected);
    }

    /// A unary op implemented generically over the element type, exercised
    /// with both an integer and a float slice.
    #[test]
    fn test_unary_generic_op() {
        struct Double;

        impl<T> SimdUnaryOp<T> for Double
        where
            T: GetSimd + GetNumOps,
        {
            fn eval<I: Isa>(&self, isa: I, x: T::Simd<I>) -> T::Simd<I> {
                let num_ops = T::num_ops(isa);
                num_ops.add(x, x)
            }
        }

        let mut ints = [1i32, 2, 3, 4];
        Double.map_mut(&mut ints);
        assert_eq!(ints, [2, 4, 6, 8]);

        let mut floats = [1.0f32, 2., 3., 4.];
        Double.map_mut(&mut floats);
        assert_eq!(floats, [2., 4., 6., 8.]);
    }
}