redstone_ml/ops/simd_binary_ops.rs

1use std::fmt::Display;
2use paste::paste;
3use crate::acceleration::simd::Simd;
4
/// Generates the SIMD kernels for one elementwise binary operator.
///
/// Must be invoked inside a trait whose `Self` provides `LANES`, `simd_from_constant`,
/// `simd_load`, `simd_store`, `simd_vec_from_stride` and the vector operation `$simd_op`
/// (the `Simd` supertrait). For an operator named `$name` it emits one default method per
/// stride combination, `simd_<name>_stride_<l>_<r>`, where `<l>` / `<r>` say how `lhs` /
/// `rhs` are walked:
///
/// * `0` — a single scalar, broadcast to every output element;
/// * `1` — contiguous elements;
/// * `n` — elements `stride` apart (read via `simd_vec_from_stride`, presumably a gather of
///   `LANES` elements `stride` apart — NOTE(review): confirm against the `Simd` trait).
///
/// `dst` is always written contiguously. Each kernel handles `4 * LANES` elements per
/// iteration (four independent vectors), then finishes the remaining
/// `count % (4 * LANES)` elements with the scalar `$operator`.
///
/// Strided pointers are advanced with `wrapping_add`. Once the last element has been read,
/// the next strided position may lie well past the end of the allocation, and computing
/// such a pointer with `pointer::add` is undefined behaviour even if it is never read.
macro_rules! simd_elementwise_operations {
    ($name:ident, $simd_op:ident, $operator:tt) => {
        paste! {
            /// Computes `dst[i] = lhs[0] op rhs[i]` for `i in 0..count`.
            ///
            /// # Safety
            /// `lhs` must be readable; `rhs` must be readable and `dst` writable for `count`
            /// contiguous elements.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_0_1>](lhs: *const Self, mut rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                // The scalar lhs is broadcast once and reused for every vector.
                let a = Self::simd_from_constant(*lhs);

                while count >= 4 * Self::LANES {
                    let b0 = Self::simd_load(rhs.add(0 * Self::LANES));
                    let b1 = Self::simd_load(rhs.add(1 * Self::LANES));
                    let b2 = Self::simd_load(rhs.add(2 * Self::LANES));
                    let b3 = Self::simd_load(rhs.add(3 * Self::LANES));

                    let ab0 = Self::$simd_op(a, b0);
                    let ab1 = Self::$simd_op(a, b1);
                    let ab2 = Self::$simd_op(a, b2);
                    let ab3 = Self::$simd_op(a, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    rhs = rhs.add(4 * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail for the elements that do not fill four vectors.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    rhs = rhs.add(1);
                    dst = dst.add(1);
                }
            }

            /// Computes `dst[i] = lhs[i] op rhs[0]` for `i in 0..count`.
            ///
            /// # Safety
            /// `rhs` must be readable; `lhs` must be readable and `dst` writable for `count`
            /// contiguous elements.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_1_0>](mut lhs: *const Self, rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                // The scalar rhs is broadcast once and reused for every vector.
                let b = Self::simd_from_constant(*rhs);

                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_load(lhs.add(0 * Self::LANES));
                    let a1 = Self::simd_load(lhs.add(1 * Self::LANES));
                    let a2 = Self::simd_load(lhs.add(2 * Self::LANES));
                    let a3 = Self::simd_load(lhs.add(3 * Self::LANES));

                    let ab0 = Self::$simd_op(a0, b);
                    let ab1 = Self::$simd_op(a1, b);
                    let ab2 = Self::$simd_op(a2, b);
                    let ab3 = Self::$simd_op(a3, b);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    lhs = lhs.add(4 * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail for the elements that do not fill four vectors.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.add(1);
                    dst = dst.add(1);
                }
            }

            /// Computes `dst[i] = lhs[i] op rhs[i]` for `i in 0..count`.
            ///
            /// # Safety
            /// `lhs` and `rhs` must be readable and `dst` writable for `count` contiguous
            /// elements.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_1_1>](mut lhs: *const Self, mut rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_load(lhs.add(0 * Self::LANES));
                    let b0 = Self::simd_load(rhs.add(0 * Self::LANES));

                    let a1 = Self::simd_load(lhs.add(1 * Self::LANES));
                    let b1 = Self::simd_load(rhs.add(1 * Self::LANES));

                    let a2 = Self::simd_load(lhs.add(2 * Self::LANES));
                    let b2 = Self::simd_load(rhs.add(2 * Self::LANES));

                    let a3 = Self::simd_load(lhs.add(3 * Self::LANES));
                    let b3 = Self::simd_load(rhs.add(3 * Self::LANES));

                    let ab0 = Self::$simd_op(a0, b0);
                    let ab1 = Self::$simd_op(a1, b1);
                    let ab2 = Self::$simd_op(a2, b2);
                    let ab3 = Self::$simd_op(a3, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    lhs = lhs.add(4 * Self::LANES);
                    rhs = rhs.add(4 * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail for the elements that do not fill four vectors.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.add(1);
                    rhs = rhs.add(1);
                    dst = dst.add(1);
                }
            }

            /// Computes `dst[i] = lhs[i * lhs_stride] op rhs[0]` for `i in 0..count`.
            ///
            /// # Safety
            /// `lhs` must be readable at offsets `i * lhs_stride` for `i in 0..count`; `rhs`
            /// must be readable and `dst` writable for `count` contiguous elements.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_n_0>](mut lhs: *const Self, lhs_stride: usize, rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                let b = Self::simd_from_constant(*rhs);

                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_vec_from_stride(lhs.add(0 * lhs_stride * Self::LANES), lhs_stride);
                    let a1 = Self::simd_vec_from_stride(lhs.add(1 * lhs_stride * Self::LANES), lhs_stride);
                    let a2 = Self::simd_vec_from_stride(lhs.add(2 * lhs_stride * Self::LANES), lhs_stride);
                    let a3 = Self::simd_vec_from_stride(lhs.add(3 * lhs_stride * Self::LANES), lhs_stride);

                    let ab0 = Self::$simd_op(a0, b);
                    let ab1 = Self::$simd_op(a1, b);
                    let ab2 = Self::$simd_op(a2, b);
                    let ab3 = Self::$simd_op(a3, b);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    // May step past the allocation when this was the last chunk; see macro docs.
                    lhs = lhs.wrapping_add(4 * lhs_stride * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail for the elements that do not fill four vectors.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.wrapping_add(lhs_stride);
                    dst = dst.add(1);
                }
            }

            /// Computes `dst[i] = lhs[0] op rhs[i * rhs_stride]` for `i in 0..count`.
            ///
            /// # Safety
            /// `lhs` must be readable; `rhs` must be readable at offsets `i * rhs_stride` for
            /// `i in 0..count`; `dst` must be writable for `count` contiguous elements.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_0_n>](lhs: *const Self, mut rhs: *const Self, rhs_stride: usize, mut dst: *mut Self, mut count: usize) {
                let a = Self::simd_from_constant(*lhs);

                while count >= 4 * Self::LANES {
                    let b0 = Self::simd_vec_from_stride(rhs.add(0 * rhs_stride * Self::LANES), rhs_stride);
                    let b1 = Self::simd_vec_from_stride(rhs.add(1 * rhs_stride * Self::LANES), rhs_stride);
                    let b2 = Self::simd_vec_from_stride(rhs.add(2 * rhs_stride * Self::LANES), rhs_stride);
                    let b3 = Self::simd_vec_from_stride(rhs.add(3 * rhs_stride * Self::LANES), rhs_stride);

                    let ab0 = Self::$simd_op(a, b0);
                    let ab1 = Self::$simd_op(a, b1);
                    let ab2 = Self::$simd_op(a, b2);
                    let ab3 = Self::$simd_op(a, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    // May step past the allocation when this was the last chunk; see macro docs.
                    rhs = rhs.wrapping_add(4 * rhs_stride * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail for the elements that do not fill four vectors.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    rhs = rhs.wrapping_add(rhs_stride);
                    dst = dst.add(1);
                }
            }

            /// Computes `dst[i] = lhs[i * lhs_stride] op rhs[i]` for `i in 0..count`.
            ///
            /// # Safety
            /// `lhs` must be readable at offsets `i * lhs_stride` for `i in 0..count`; `rhs`
            /// must be readable and `dst` writable for `count` contiguous elements.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_n_1>](mut lhs: *const Self, lhs_stride: usize, mut rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_vec_from_stride(lhs.add(0 * lhs_stride * Self::LANES), lhs_stride);
                    let b0 = Self::simd_load(rhs.add(0 * Self::LANES));

                    let a1 = Self::simd_vec_from_stride(lhs.add(1 * lhs_stride * Self::LANES), lhs_stride);
                    let b1 = Self::simd_load(rhs.add(1 * Self::LANES));

                    let a2 = Self::simd_vec_from_stride(lhs.add(2 * lhs_stride * Self::LANES), lhs_stride);
                    let b2 = Self::simd_load(rhs.add(2 * Self::LANES));

                    let a3 = Self::simd_vec_from_stride(lhs.add(3 * lhs_stride * Self::LANES), lhs_stride);
                    let b3 = Self::simd_load(rhs.add(3 * Self::LANES));

                    let ab0 = Self::$simd_op(a0, b0);
                    let ab1 = Self::$simd_op(a1, b1);
                    let ab2 = Self::$simd_op(a2, b2);
                    let ab3 = Self::$simd_op(a3, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    // May step past the allocation when this was the last chunk; see macro docs.
                    lhs = lhs.wrapping_add(4 * lhs_stride * Self::LANES);
                    rhs = rhs.add(4 * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail for the elements that do not fill four vectors.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.wrapping_add(lhs_stride);
                    rhs = rhs.add(1);
                    dst = dst.add(1);
                }
            }

            /// Computes `dst[i] = lhs[i] op rhs[i * rhs_stride]` for `i in 0..count`.
            ///
            /// # Safety
            /// `rhs` must be readable at offsets `i * rhs_stride` for `i in 0..count`; `lhs`
            /// must be readable and `dst` writable for `count` contiguous elements.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_1_n>](mut lhs: *const Self, mut rhs: *const Self, rhs_stride: usize, mut dst: *mut Self, mut count: usize) {
                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_load(lhs.add(0 * Self::LANES));
                    let b0 = Self::simd_vec_from_stride(rhs.add(0 * rhs_stride * Self::LANES), rhs_stride);

                    let a1 = Self::simd_load(lhs.add(1 * Self::LANES));
                    let b1 = Self::simd_vec_from_stride(rhs.add(1 * rhs_stride * Self::LANES), rhs_stride);

                    let a2 = Self::simd_load(lhs.add(2 * Self::LANES));
                    let b2 = Self::simd_vec_from_stride(rhs.add(2 * rhs_stride * Self::LANES), rhs_stride);

                    let a3 = Self::simd_load(lhs.add(3 * Self::LANES));
                    let b3 = Self::simd_vec_from_stride(rhs.add(3 * rhs_stride * Self::LANES), rhs_stride);

                    let ab0 = Self::$simd_op(a0, b0);
                    let ab1 = Self::$simd_op(a1, b1);
                    let ab2 = Self::$simd_op(a2, b2);
                    let ab3 = Self::$simd_op(a3, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    lhs = lhs.add(4 * Self::LANES);
                    // May step past the allocation when this was the last chunk; see macro docs.
                    rhs = rhs.wrapping_add(4 * rhs_stride * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail for the elements that do not fill four vectors.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.add(1);
                    rhs = rhs.wrapping_add(rhs_stride);
                    dst = dst.add(1);
                }
            }

            /// Computes `dst[i] = lhs[i * lhs_stride] op rhs[i * rhs_stride]` for
            /// `i in 0..count`.
            ///
            /// # Safety
            /// `lhs` / `rhs` must be readable at offsets `i * lhs_stride` / `i * rhs_stride`
            /// for `i in 0..count`; `dst` must be writable for `count` contiguous elements.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_n_n>](mut lhs: *const Self, lhs_stride: usize,
                                                  mut rhs: *const Self, rhs_stride: usize,
                                                  mut dst: *mut Self, mut count: usize) {
                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_vec_from_stride(lhs.add(0 * lhs_stride * Self::LANES), lhs_stride);
                    let b0 = Self::simd_vec_from_stride(rhs.add(0 * rhs_stride * Self::LANES), rhs_stride);

                    let a1 = Self::simd_vec_from_stride(lhs.add(1 * lhs_stride * Self::LANES), lhs_stride);
                    let b1 = Self::simd_vec_from_stride(rhs.add(1 * rhs_stride * Self::LANES), rhs_stride);

                    let a2 = Self::simd_vec_from_stride(lhs.add(2 * lhs_stride * Self::LANES), lhs_stride);
                    let b2 = Self::simd_vec_from_stride(rhs.add(2 * rhs_stride * Self::LANES), rhs_stride);

                    let a3 = Self::simd_vec_from_stride(lhs.add(3 * lhs_stride * Self::LANES), lhs_stride);
                    let b3 = Self::simd_vec_from_stride(rhs.add(3 * rhs_stride * Self::LANES), rhs_stride);

                    let ab0 = Self::$simd_op(a0, b0);
                    let ab1 = Self::$simd_op(a1, b1);
                    let ab2 = Self::$simd_op(a2, b2);
                    let ab3 = Self::$simd_op(a3, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    // Either strided pointer may step past its allocation after the last
                    // chunk; see macro docs.
                    lhs = lhs.wrapping_add(4 * lhs_stride * Self::LANES);
                    rhs = rhs.wrapping_add(4 * rhs_stride * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail for the elements that do not fill four vectors.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.wrapping_add(lhs_stride);
                    rhs = rhs.wrapping_add(rhs_stride);
                    dst = dst.add(1);
                }
            }
        }
    };
}
312
313pub(crate) trait SimdBinaryOps: Simd + Display {
314    simd_elementwise_operations!(add, simd_add, +);
315    simd_elementwise_operations!(sub, simd_sub, -);
316    simd_elementwise_operations!(mul, simd_mul, *);
317    simd_elementwise_operations!(div, simd_div, /);
318}
319
320impl<T: Simd + Display> SimdBinaryOps for T {}
321
/// Expands to the stride-specialised `unsafe fn <name>_stride_<l>_<r>` methods for one binary
/// operator. Each method forwards unchanged to the matching
/// `SimdBinaryOps::simd_<name>_stride_<l>_<r>` kernel.
///
/// Meant to be invoked inside an `impl` block whose `Self` implements `SimdBinaryOps`
/// (NOTE(review): presumably the impl of a trait that declares these methods — confirm at
/// the call sites). All generated methods exist only when the `neon_simd` cfg is set.
///
/// `paste!` is referenced by absolute path. The macro is `#[macro_export]`ed, and an
/// unqualified `paste!` would resolve at the call site, forcing every caller to
/// `use paste::paste;`.
#[macro_export]
macro_rules! simd_binary_op_specializations {
    ($name: ident) => {
        ::paste::paste! {
            /// # Safety
            /// Same contract as `SimdBinaryOps::simd_<name>_stride_0_1`.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_0_1>](lhs: *const Self, rhs: *const Self,
                                            dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_0_1>](lhs, rhs, dst, count);
            }

            /// # Safety
            /// Same contract as `SimdBinaryOps::simd_<name>_stride_1_0`.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_1_0>](lhs: *const Self, rhs: *const Self,
                                            dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_1_0>](lhs, rhs, dst, count);
            }

            /// # Safety
            /// Same contract as `SimdBinaryOps::simd_<name>_stride_0_n`.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_0_n>](lhs: *const Self,
                                            rhs: *const Self, rhs_stride: usize,
                                            dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_0_n>](lhs, rhs, rhs_stride, dst, count);
            }

            /// # Safety
            /// Same contract as `SimdBinaryOps::simd_<name>_stride_n_0`.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_n_0>](lhs: *const Self, lhs_stride: usize,
                                            rhs: *const Self,
                                            dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_n_0>](lhs, lhs_stride, rhs, dst, count);
            }

            /// # Safety
            /// Same contract as `SimdBinaryOps::simd_<name>_stride_1_1`.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_1_1>](lhs: *const Self, rhs: *const Self,
                                            dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_1_1>](lhs, rhs, dst, count);
            }

            /// # Safety
            /// Same contract as `SimdBinaryOps::simd_<name>_stride_1_n`.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_1_n>](lhs: *const Self,
                                            rhs: *const Self, rhs_stride: usize,
                                            dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_1_n>](lhs, rhs, rhs_stride, dst, count);
            }

            /// # Safety
            /// Same contract as `SimdBinaryOps::simd_<name>_stride_n_1`.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_n_1>](lhs: *const Self, lhs_stride: usize,
                                            rhs: *const Self,
                                            dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_n_1>](lhs, lhs_stride, rhs, dst, count);
            }

            /// # Safety
            /// Same contract as `SimdBinaryOps::simd_<name>_stride_n_n`.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_n_n>](lhs: *const Self, lhs_stride: usize,
                                            rhs: *const Self, rhs_stride: usize,
                                            dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_n_n>](lhs, lhs_stride, rhs, rhs_stride, dst, count);
            }
        }
    };
}