// rten_simd/dispatch.rs

//! Dispatch SIMD operations using the preferred SIMD instruction set for the
//! current system. On x86_64 the instruction set is detected at runtime; on
//! other platforms it is selected at compile time.
3
4use std::mem::MaybeUninit;
5
6use crate::functional::simd_map;
7use crate::span::{MutPtrLen, PtrLen};
8use crate::SimdFloat;
9
/// Dispatches SIMD operations using the preferred SIMD types for the current
/// platform.
///
/// The dispatcher holds no state, so it is cheap to create with
/// [`Default::default`] and to copy.
#[derive(Clone, Copy, Debug, Default)]
pub struct SimdDispatcher {}
14
15impl SimdDispatcher {
16    /// Evaluate `op` using the preferred SIMD instruction set for the current
17    /// system.
18    #[allow(unused_imports)]
19    #[allow(unreachable_code)] // Ignore fallback, if unused
20    pub fn dispatch<Op: SimdOp>(&self, op: Op) -> Op::Output {
21        #[cfg(feature = "avx512")]
22        #[cfg(target_arch = "x86_64")]
23        #[target_feature(enable = "avx512f")]
24        #[target_feature(enable = "avx512vl")]
25        unsafe fn simd_op_avx512<Op: SimdOp>(op: Op) -> Op::Output {
26            use std::arch::x86_64::__m512;
27            op.eval::<__m512>()
28        }
29
30        #[cfg(target_arch = "x86_64")]
31        #[target_feature(enable = "avx2")]
32        #[target_feature(enable = "fma")]
33        unsafe fn simd_op_avx<Op: SimdOp>(op: Op) -> Op::Output {
34            use std::arch::x86_64::__m256;
35            op.eval::<__m256>()
36        }
37
38        #[cfg(target_arch = "x86_64")]
39        {
40            #[cfg(feature = "avx512")]
41            if crate::is_avx512_supported() {
42                return unsafe { simd_op_avx512(op) };
43            }
44
45            if is_x86_feature_detected!("fma") && is_x86_feature_detected!("avx2") {
46                // Safety: We've checked that AVX2 + FMA are available.
47                return unsafe { simd_op_avx(op) };
48            }
49        }
50
51        #[cfg(target_arch = "wasm32")]
52        #[cfg(target_feature = "simd128")]
53        {
54            use crate::arch::wasm::v128f;
55
56            // Safety: The WASM runtime will have verified SIMD instructions
57            // are accepted when loading the binary.
58            return unsafe { op.eval::<v128f>() };
59        }
60
61        #[cfg(target_arch = "aarch64")]
62        {
63            use std::arch::aarch64::float32x4_t;
64            return unsafe { op.eval::<float32x4_t>() };
65        }
66
67        // Generic fallback.
68        unsafe { op.eval::<f32>() }
69    }
70}
71
72/// Run `op` using the default SIMD dispatch configuration.
73pub fn dispatch<Op: SimdOp>(op: Op) -> Op::Output {
74    SimdDispatcher::default().dispatch(op)
75}
76
77/// Trait for SIMD operations which can be evaluated using different SIMD
78/// vector types.
79///
80/// To dispatch the operation using the preferred instruction set for the
81/// current system, call the [`dispatch`](SimdOp::dispatch) method.
82pub trait SimdOp {
83    /// Output type returned by the operation.
84    type Output;
85
86    /// Evaluate the operator using a given SIMD vector type.
87    ///
88    /// # Safety
89    ///
90    /// The caller must ensure that the `S` is a supported SIMD vector type
91    /// on the current system.
92    unsafe fn eval<S: SimdFloat>(self) -> Self::Output;
93
94    /// Evaluate this operator using the default SIMD dispatch configuration
95    /// for the current platform.
96    ///
97    /// To customize the dispatch, use the [`SimdDispatcher`] API directly.
98    fn dispatch(self) -> Self::Output
99    where
100        Self: Sized,
101    {
102        SimdDispatcher::default().dispatch(self)
103    }
104}
105
106/// Trait for evaluating a unary function on a SIMD vector.
107pub trait SimdUnaryOp {
108    /// Evaluate the unary function on the elements in `x`.
109    ///
110    /// # Safety
111    ///
112    /// The caller must ensure that the `S` is a supported SIMD vector type
113    /// on the current system.
114    unsafe fn eval<S: SimdFloat>(&self, x: S) -> S;
115
116    /// Evaluate the unary function on elements in `x`.
117    ///
118    /// This is a shorthand for `Self::default().eval(x)`. It is mainly useful
119    /// when one vectorized operation needs to call another as part of its
120    /// implementation.
121    ///
122    /// # Safety
123    ///
124    /// See safety notes for [`eval`](SimdUnaryOp::eval).
125    #[inline(always)]
126    unsafe fn apply<S: SimdFloat>(x: S) -> S
127    where
128        Self: Default,
129    {
130        Self::default().eval(x)
131    }
132
133    /// Evaluate the unary function on `x`.
134    fn scalar_eval(&self, x: f32) -> f32 {
135        // Safety: `f32` is a supported "SIMD" type on all platforms.
136        unsafe { self.eval(x) }
137    }
138
139    /// Apply this function to a slice.
140    ///
141    /// This reads elements from `input` in SIMD vector-sized chunks, applies
142    /// `op` and writes the results to `output`.
143    fn map(&self, input: &[f32], output: &mut [MaybeUninit<f32>])
144    where
145        Self: Sized,
146    {
147        let wrapped_op = SimdMapOp::wrap(input.into(), output.into(), self);
148        dispatch(wrapped_op)
149    }
150
151    /// Apply a vectorized unary function to a mutable slice.
152    ///
153    /// This is similar to [`map`](SimdUnaryOp::map) but reads and writes to the
154    /// same slice.
155    fn map_mut(&self, input: &mut [f32])
156    where
157        Self: Sized,
158    {
159        let out: MutPtrLen<f32> = input.into();
160        let wrapped_op = SimdMapOp::wrap(input.into(), out.as_uninit(), self);
161        dispatch(wrapped_op)
162    }
163}
164
165/// SIMD operation which applies a unary operator `Op` to all elements in
166/// an input buffer using [`simd_map`].
167pub struct SimdMapOp<'a, Op: SimdUnaryOp> {
168    input: PtrLen<f32>,
169    output: MutPtrLen<MaybeUninit<f32>>,
170    op: &'a Op,
171}
172
173impl<'a, Op: SimdUnaryOp> SimdMapOp<'a, Op> {
174    pub fn wrap(
175        input: PtrLen<f32>,
176        output: MutPtrLen<MaybeUninit<f32>>,
177        op: &'a Op,
178    ) -> SimdMapOp<'a, Op> {
179        SimdMapOp { input, output, op }
180    }
181}
182
183impl<Op: SimdUnaryOp> SimdOp for SimdMapOp<'_, Op> {
184    type Output = ();
185
186    #[inline(always)]
187    unsafe fn eval<S: SimdFloat>(self) {
188        simd_map(
189            self.input,
190            self.output,
191            #[inline(always)]
192            |x: S| self.op.eval(x),
193        );
194    }
195}