1use std::mem::MaybeUninit;
2
3use crate::Isa;
4use crate::functional::simd_map;
5use crate::ops::{GetNumOps, GetSimd};
6use crate::span::SrcDest;
7
8pub trait SimdOp {
11 type Output;
13
14 fn eval<I: Isa>(self, isa: I) -> Self::Output;
16
17 fn dispatch(self) -> Self::Output
19 where
20 Self: Sized,
21 {
22 dispatch(self)
23 }
24}
25
26pub fn dispatch<Op: SimdOp>(op: Op) -> Op::Output {
31 #[cfg(target_arch = "aarch64")]
32 if let Some(isa) = super::arch::aarch64::ArmNeonIsa::new() {
33 return op.eval(isa);
34 }
35
36 #[cfg(target_arch = "x86_64")]
37 {
38 {
39 #[target_feature(enable = "avx512f")]
41 #[target_feature(enable = "avx512vl")]
42 #[target_feature(enable = "avx512bw")]
43 #[target_feature(enable = "avx512dq")]
44 unsafe fn dispatch_avx512<Op: SimdOp>(isa: impl Isa, op: Op) -> Op::Output {
45 op.eval(isa)
46 }
47
48 if let Some(isa) = super::arch::x86_64::Avx512Isa::new() {
49 unsafe {
51 return dispatch_avx512(isa, op);
52 }
53 }
54 }
55
56 #[target_feature(enable = "avx2")]
58 #[target_feature(enable = "avx")]
59 #[target_feature(enable = "fma")]
60 unsafe fn dispatch_avx2<Op: SimdOp>(isa: impl Isa, op: Op) -> Op::Output {
61 op.eval(isa)
62 }
63
64 if let Some(isa) = super::arch::x86_64::Avx2Isa::new() {
65 unsafe {
67 return dispatch_avx2(isa, op);
68 }
69 }
70 }
71
72 #[cfg(target_arch = "wasm32")]
73 #[cfg(target_feature = "simd128")]
74 {
75 if let Some(isa) = super::arch::wasm32::Wasm32Isa::new() {
76 return op.eval(isa);
77 }
78 }
79
80 let isa = super::arch::generic::GenericIsa::new();
81 op.eval(isa)
82}
83
84pub trait SimdUnaryOp<T: GetSimd> {
86 fn eval<I: Isa>(&self, isa: I, x: T::Simd<I>) -> T::Simd<I>;
102
103 #[inline(always)]
109 fn apply<I: Isa>(isa: I, x: T::Simd<I>) -> T::Simd<I>
110 where
111 Self: Default,
112 {
113 Self::default().eval(isa, x)
114 }
115
116 fn map<'dst>(&self, input: &[T], output: &'dst mut [MaybeUninit<T>]) -> &'dst mut [T]
121 where
122 Self: Sized,
123 T: GetNumOps,
124 {
125 let wrapped_op = SimdMapOp::wrap((input, output).into(), self);
126 dispatch(wrapped_op)
127 }
128
129 #[allow(private_bounds)]
134 fn map_mut(&self, input: &mut [T])
135 where
136 Self: Sized,
137 T: GetNumOps,
138 {
139 let wrapped_op = SimdMapOp::wrap(input.into(), self);
140 dispatch(wrapped_op);
141 }
142
143 fn scalar_eval(&self, x: T) -> T
145 where
146 Self: Sized,
147 T: GetNumOps,
148 {
149 let mut array = [x];
150 self.map_mut(&mut array);
151 array[0]
152 }
153}
154
155struct SimdMapOp<'src, 'dst, 'op, T: GetSimd, Op: SimdUnaryOp<T>> {
158 src_dest: SrcDest<'src, 'dst, T>,
159 op: &'op Op,
160}
161
162impl<'src, 'dst, 'op, T: GetSimd, Op: SimdUnaryOp<T>> SimdMapOp<'src, 'dst, 'op, T, Op> {
163 pub fn wrap(src_dest: SrcDest<'src, 'dst, T>, op: &'op Op) -> Self {
164 SimdMapOp { src_dest, op }
165 }
166}
167
168impl<'dst, T: GetNumOps + GetSimd, Op: SimdUnaryOp<T>> SimdOp for SimdMapOp<'_, 'dst, '_, T, Op> {
169 type Output = &'dst mut [T];
170
171 #[inline(always)]
172 fn eval<I: Isa>(self, isa: I) -> Self::Output {
173 simd_map(
174 T::num_ops(isa),
175 self.src_dest,
176 #[inline(always)]
177 |x| self.op.eval(isa, x),
178 )
179 }
180}
181
/// Test helper: define a one-off [`SimdOp`] and run it via [`SimdOp::dispatch`].
///
/// The operation's `eval` runs the block `$op`, with the selected ISA bound to the
/// identifier `$isa`. Because `macro_rules!` resolves item paths at the call site, the
/// invoking scope must have `SimdOp` and `Isa` in scope.
#[cfg(test)]
macro_rules! test_simd_op {
    ($isa:ident, $op:block) => {{
        struct TestOp {}

        impl SimdOp for TestOp {
            type Output = ();

            fn eval<I: Isa>(self, $isa: I) -> Self::Output {
                $op
            }
        }

        SimdOp::dispatch(TestOp {})
    }};
}

// Re-export the macro so other test modules in the crate can invoke it by path.
#[cfg(test)]
pub(crate) use test_simd_op;
202
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use super::SimdUnaryOp;
    use crate::Isa;
    use crate::ops::{FloatOps, GetNumOps, GetSimd, NumOps};

    /// Computes `1 / x` for each element of an `f32` vector.
    struct Reciprocal {}

    impl SimdUnaryOp<f32> for Reciprocal {
        fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
            let ops = isa.f32();
            ops.div(ops.one(), x)
        }
    }

    /// Computes `x + x` for each element, for any element type with numeric ops.
    struct Double {}

    impl<T> SimdUnaryOp<T> for Double
    where
        T: GetSimd + GetNumOps,
    {
        fn eval<I: Isa>(&self, isa: I, x: T::Simd<I>) -> T::Simd<I> {
            let ops = T::num_ops(isa);
            ops.add(x, x)
        }
    }

    #[test]
    fn test_unary_float_op() {
        let mut buf = [1., 2., 3., 4.];
        Reciprocal {}.map_mut(&mut buf);

        assert_eq!(buf, [1., 1. / 2., 1. / 3., 1. / 4.]);
    }

    #[test]
    fn test_unary_generic_op() {
        let mut buf = [1i32, 2, 3, 4];
        Double {}.map_mut(&mut buf);
        assert_eq!(buf, [2, 4, 6, 8]);

        let mut buf = [1.0f32, 2., 3., 4.];
        Double {}.map_mut(&mut buf);
        assert_eq!(buf, [2., 4., 6., 8.]);
    }

    /// `map` writes into uninitialized storage and returns the initialized results.
    #[test]
    fn test_unary_op_map_uninit_output() {
        let input = [1.0f32, 2., 4., 8.];
        let mut output = [MaybeUninit::<f32>::uninit(); 4];
        let result = Reciprocal {}.map(&input, &mut output);
        assert_eq!(result, [1.0f32, 0.5, 0.25, 0.125]);
    }

    /// Inputs longer than one vector exercise the full-vector loop as well as the tail.
    #[test]
    fn test_unary_op_multiple_vectors_with_tail() {
        // 19 is not a multiple of 4, 8 or 16, so whatever the lane count there is a
        // partial tail after at least one full vector.
        let mut buf: Vec<i32> = (0..19).collect();
        Double {}.map_mut(&mut buf[..]);
        let expected: Vec<i32> = (0..19).map(|x| x * 2).collect();
        assert_eq!(buf, expected);
    }

    /// `scalar_eval` applies the operation to a single value.
    #[test]
    fn test_unary_op_scalar_eval() {
        assert_eq!(Double {}.scalar_eval(21i32), 42);
        assert_eq!(Reciprocal {}.scalar_eval(4.0f32), 0.25);
    }
}