// rten_simd/dispatch.rs
//! Dispatch SIMD operations using the preferred SIMD instruction set for the
//! current system, as determined at runtime.

use std::mem::MaybeUninit;

use crate::functional::simd_map;
use crate::span::{MutPtrLen, PtrLen};
use crate::SimdFloat;

/// Dispatches SIMD operations using the preferred SIMD types for the current
/// platform.
///
/// Currently stateless; it exists as a type so dispatch configuration can be
/// customized via a dispatcher instance rather than only the free
/// [`dispatch`] function.
#[derive(Default)]
pub struct SimdDispatcher {}

impl SimdDispatcher {
    /// Evaluate `op` using the preferred SIMD instruction set for the current
    /// system.
    ///
    /// Candidate instruction sets are tried in decreasing order of preference:
    /// AVX-512 (if the `avx512` feature is enabled) then AVX2+FMA on x86_64,
    /// SIMD128 on WASM, NEON on AArch64, and finally a scalar `f32` fallback.
    #[allow(unused_imports)]
    #[allow(unreachable_code)] // Ignore fallback, if unused
    pub fn dispatch<Op: SimdOp>(&self, op: Op) -> Op::Output {
        // `#[target_feature]` wrappers let the compiler emit AVX-512 / AVX2
        // instructions inside these fns even though the crate's baseline
        // target does not enable them; calling them is unsafe and requires a
        // prior runtime feature check.
        #[cfg(feature = "avx512")]
        #[cfg(target_arch = "x86_64")]
        #[target_feature(enable = "avx512f")]
        #[target_feature(enable = "avx512vl")]
        unsafe fn simd_op_avx512<Op: SimdOp>(op: Op) -> Op::Output {
            use std::arch::x86_64::__m512;
            op.eval::<__m512>()
        }

        #[cfg(target_arch = "x86_64")]
        #[target_feature(enable = "avx2")]
        #[target_feature(enable = "fma")]
        unsafe fn simd_op_avx<Op: SimdOp>(op: Op) -> Op::Output {
            use std::arch::x86_64::__m256;
            op.eval::<__m256>()
        }

        #[cfg(target_arch = "x86_64")]
        {
            #[cfg(feature = "avx512")]
            if crate::is_avx512_supported() {
                // Safety: We've checked that AVX-512 F + VL are available.
                return unsafe { simd_op_avx512(op) };
            }

            if is_x86_feature_detected!("fma") && is_x86_feature_detected!("avx2") {
                // Safety: We've checked that AVX2 + FMA are available.
                return unsafe { simd_op_avx(op) };
            }
        }

        #[cfg(target_arch = "wasm32")]
        #[cfg(target_feature = "simd128")]
        {
            use crate::arch::wasm::v128f;

            // Safety: The WASM runtime will have verified SIMD instructions
            // are accepted when loading the binary.
            return unsafe { op.eval::<v128f>() };
        }

        #[cfg(target_arch = "aarch64")]
        {
            use std::arch::aarch64::float32x4_t;
            // Safety: NEON is a mandatory feature of AArch64, so `float32x4_t`
            // is always usable on this architecture.
            return unsafe { op.eval::<float32x4_t>() };
        }

        // Generic fallback.
        // Safety: `f32` acts as a single-lane "SIMD" type on every platform.
        unsafe { op.eval::<f32>() }
    }
}

72/// Run `op` using the default SIMD dispatch configuration.
73pub fn dispatch<Op: SimdOp>(op: Op) -> Op::Output {
74 SimdDispatcher::default().dispatch(op)
75}
76
77/// Trait for SIMD operations which can be evaluated using different SIMD
78/// vector types.
79///
80/// To dispatch the operation using the preferred instruction set for the
81/// current system, call the [`dispatch`](SimdOp::dispatch) method.
82pub trait SimdOp {
83 /// Output type returned by the operation.
84 type Output;
85
86 /// Evaluate the operator using a given SIMD vector type.
87 ///
88 /// # Safety
89 ///
90 /// The caller must ensure that the `S` is a supported SIMD vector type
91 /// on the current system.
92 unsafe fn eval<S: SimdFloat>(self) -> Self::Output;
93
94 /// Evaluate this operator using the default SIMD dispatch configuration
95 /// for the current platform.
96 ///
97 /// To customize the dispatch, use the [`SimdDispatcher`] API directly.
98 fn dispatch(self) -> Self::Output
99 where
100 Self: Sized,
101 {
102 SimdDispatcher::default().dispatch(self)
103 }
104}
105
106/// Trait for evaluating a unary function on a SIMD vector.
107pub trait SimdUnaryOp {
108 /// Evaluate the unary function on the elements in `x`.
109 ///
110 /// # Safety
111 ///
112 /// The caller must ensure that the `S` is a supported SIMD vector type
113 /// on the current system.
114 unsafe fn eval<S: SimdFloat>(&self, x: S) -> S;
115
116 /// Evaluate the unary function on elements in `x`.
117 ///
118 /// This is a shorthand for `Self::default().eval(x)`. It is mainly useful
119 /// when one vectorized operation needs to call another as part of its
120 /// implementation.
121 ///
122 /// # Safety
123 ///
124 /// See safety notes for [`eval`](SimdUnaryOp::eval).
125 #[inline(always)]
126 unsafe fn apply<S: SimdFloat>(x: S) -> S
127 where
128 Self: Default,
129 {
130 Self::default().eval(x)
131 }
132
133 /// Evaluate the unary function on `x`.
134 fn scalar_eval(&self, x: f32) -> f32 {
135 // Safety: `f32` is a supported "SIMD" type on all platforms.
136 unsafe { self.eval(x) }
137 }
138
139 /// Apply this function to a slice.
140 ///
141 /// This reads elements from `input` in SIMD vector-sized chunks, applies
142 /// `op` and writes the results to `output`.
143 fn map(&self, input: &[f32], output: &mut [MaybeUninit<f32>])
144 where
145 Self: Sized,
146 {
147 let wrapped_op = SimdMapOp::wrap(input.into(), output.into(), self);
148 dispatch(wrapped_op)
149 }
150
151 /// Apply a vectorized unary function to a mutable slice.
152 ///
153 /// This is similar to [`map`](SimdUnaryOp::map) but reads and writes to the
154 /// same slice.
155 fn map_mut(&self, input: &mut [f32])
156 where
157 Self: Sized,
158 {
159 let out: MutPtrLen<f32> = input.into();
160 let wrapped_op = SimdMapOp::wrap(input.into(), out.as_uninit(), self);
161 dispatch(wrapped_op)
162 }
163}
164
/// SIMD operation which applies a unary operator `Op` to all elements in
/// an input buffer using [`simd_map`].
pub struct SimdMapOp<'a, Op: SimdUnaryOp> {
    // Source elements, read in SIMD vector-sized chunks.
    input: PtrLen<f32>,
    // Destination buffer, treated as uninitialized memory. NOTE(review): raw
    // pointer spans are used presumably so `output` may alias `input` (as in
    // `SimdUnaryOp::map_mut`) — confirm against `simd_map`'s contract.
    output: MutPtrLen<MaybeUninit<f32>>,
    // Unary operation applied to each vector of elements.
    op: &'a Op,
}

173impl<'a, Op: SimdUnaryOp> SimdMapOp<'a, Op> {
174 pub fn wrap(
175 input: PtrLen<f32>,
176 output: MutPtrLen<MaybeUninit<f32>>,
177 op: &'a Op,
178 ) -> SimdMapOp<'a, Op> {
179 SimdMapOp { input, output, op }
180 }
181}
182
impl<Op: SimdUnaryOp> SimdOp for SimdMapOp<'_, Op> {
    // The operation's result is the side effect of filling `output`.
    type Output = ();

    #[inline(always)]
    unsafe fn eval<S: SimdFloat>(self) {
        // Delegate chunked iteration over the buffers to `simd_map`, which
        // invokes the closure once per SIMD vector of elements.
        simd_map(
            self.input,
            self.output,
            // `inline(always)` so the op's body is inlined into the map loop.
            #[inline(always)]
            |x: S| self.op.eval(x),
        );
    }
}