1use std::mem::MaybeUninit;
2
3use crate::functional::simd_map;
4use crate::ops::GetNumOps;
5use crate::span::SrcDest;
6use crate::{Elem, Isa, Simd};
7
8pub trait SimdOp {
11 type Output;
13
14 fn eval<I: Isa>(self, isa: I) -> Self::Output;
16
17 fn dispatch(self) -> Self::Output
19 where
20 Self: Sized,
21 {
22 dispatch(self)
23 }
24}
25
26pub fn dispatch<Op: SimdOp>(op: Op) -> Op::Output {
31 #[cfg(target_arch = "aarch64")]
32 if let Some(isa) = super::arch::aarch64::ArmNeonIsa::new() {
33 return op.eval(isa);
34 }
35
36 #[cfg(target_arch = "x86_64")]
37 {
38 #[cfg(feature = "avx512")]
39 {
40 #[target_feature(enable = "avx512f")]
42 #[target_feature(enable = "avx512vl")]
43 #[target_feature(enable = "avx512bw")]
44 #[target_feature(enable = "avx512dq")]
45 unsafe fn dispatch_avx512<Op: SimdOp>(isa: impl Isa, op: Op) -> Op::Output {
46 op.eval(isa)
47 }
48
49 if let Some(isa) = super::arch::x86_64::Avx512Isa::new() {
50 unsafe {
52 return dispatch_avx512(isa, op);
53 }
54 }
55 }
56
57 #[target_feature(enable = "avx2")]
59 #[target_feature(enable = "avx")]
60 #[target_feature(enable = "fma")]
61 unsafe fn dispatch_avx2<Op: SimdOp>(isa: impl Isa, op: Op) -> Op::Output {
62 op.eval(isa)
63 }
64
65 if let Some(isa) = super::arch::x86_64::Avx2Isa::new() {
66 unsafe {
68 return dispatch_avx2(isa, op);
69 }
70 }
71 }
72
73 #[cfg(target_arch = "wasm32")]
74 #[cfg(target_feature = "simd128")]
75 {
76 if let Some(isa) = super::arch::wasm32::Wasm32Isa::new() {
77 return op.eval(isa);
78 }
79 }
80
81 let isa = super::arch::generic::GenericIsa::new();
82 op.eval(isa)
83}
84
85pub trait SimdUnaryOp<T: Elem> {
87 fn eval<I: Isa, S: Simd<Elem = T, Isa = I>>(&self, isa: I, x: S) -> S;
108
109 #[inline(always)]
115 fn apply<I: Isa, S: Simd<Elem = T, Isa = I>>(isa: I, x: S) -> S
116 where
117 Self: Default,
118 {
119 Self::default().eval(isa, x)
120 }
121
122 #[allow(private_bounds)]
127 fn map(&self, input: &[T], output: &mut [MaybeUninit<T>])
128 where
129 Self: Sized,
130 T: GetNumOps,
131 {
132 let wrapped_op = SimdMapOp::wrap((input, output).into(), self);
133 dispatch(wrapped_op);
134 }
135
136 #[allow(private_bounds)]
141 fn map_mut(&self, input: &mut [T])
142 where
143 Self: Sized,
144 T: GetNumOps,
145 {
146 let wrapped_op = SimdMapOp::wrap(input.into(), self);
147 dispatch(wrapped_op);
148 }
149
150 #[allow(private_bounds)]
152 fn scalar_eval(&self, x: T) -> T
153 where
154 Self: Sized,
155 T: GetNumOps,
156 {
157 let mut array = [x];
158 self.map_mut(&mut array);
159 array[0]
160 }
161}
162
163struct SimdMapOp<'src, 'dst, 'op, T: Elem, Op: SimdUnaryOp<T>> {
166 src_dest: SrcDest<'src, 'dst, T>,
167 op: &'op Op,
168}
169
170impl<'src, 'dst, 'op, T: Elem, Op: SimdUnaryOp<T>> SimdMapOp<'src, 'dst, 'op, T, Op> {
171 pub fn wrap(src_dest: SrcDest<'src, 'dst, T>, op: &'op Op) -> Self {
172 SimdMapOp { src_dest, op }
173 }
174}
175
176impl<'dst, T: GetNumOps, Op: SimdUnaryOp<T>> SimdOp for SimdMapOp<'_, 'dst, '_, T, Op> {
177 type Output = &'dst mut [T];
178
179 #[inline(always)]
180 fn eval<I: Isa>(self, isa: I) -> Self::Output {
181 simd_map(
182 T::num_ops(isa),
183 self.src_dest,
184 #[inline(always)]
185 |x| self.op.eval(isa, x),
186 )
187 }
188}
189
/// Run a block of test code through ISA dispatch.
///
/// `$isa` names the variable that is bound to the selected ISA inside `$op`.
/// The block becomes the body of a throwaway [`SimdOp`] implementation whose
/// output is `()`, and that operation is dispatched immediately.
///
/// The expansion refers to `SimdOp` and `Isa` by their bare names, so both
/// must be in scope wherever the macro is invoked.
#[cfg(test)]
macro_rules! test_simd_op {
    ($isa:ident, $op:block) => {{
        struct TestOp {}

        impl SimdOp for TestOp {
            type Output = ();

            fn eval<I: Isa>(self, $isa: I) -> Self::Output {
                $op
            }
        }

        SimdOp::dispatch(TestOp {})
    }};
}
207
208#[cfg(test)]
209pub(crate) use test_simd_op;
210
#[cfg(test)]
mod tests {
    use super::SimdUnaryOp;
    use crate::ops::{FloatOps, GetNumOps, NumOps};
    use crate::{Isa, Simd};

    /// A unary op implemented for one concrete element type (`f32`) can be
    /// mapped over a slice in place.
    #[test]
    fn test_unary_float_op() {
        struct Reciprocal {}

        impl SimdUnaryOp<f32> for Reciprocal {
            fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
                let ops = isa.f32();
                // Convert between the generic vector type `S` and the ISA's
                // own f32 vector type around the division.
                let denominator = x.same_cast();
                ops.div(ops.one(), denominator).same_cast()
            }
        }

        let mut values = [1., 2., 3., 4.];
        let op = Reciprocal {};
        op.map_mut(&mut values);

        let expected: [f32; 4] = [1., 1. / 2., 1. / 3., 1. / 4.];
        assert_eq!(values, expected);
    }

    /// A unary op implemented generically over the element type works for
    /// both integer and float slices.
    #[test]
    fn test_unary_generic_op() {
        struct Double {}

        impl<T: GetNumOps> SimdUnaryOp<T> for Double {
            fn eval<I: Isa, S: Simd<Elem = T, Isa = I>>(&self, isa: I, x: S) -> S {
                let ops = T::num_ops(isa);
                let v = x.same_cast();
                ops.add(v, v).same_cast()
            }
        }

        let op = Double {};

        let mut ints = [1i32, 2, 3, 4];
        op.map_mut(&mut ints);
        assert_eq!(ints, [2, 4, 6, 8]);

        let mut floats = [1.0f32, 2., 3., 4.];
        op.map_mut(&mut floats);
        assert_eq!(floats, [2., 4., 6., 8.]);
    }
}