1use std::fmt::Display;
2use paste::paste;
3use crate::acceleration::simd::Simd;
4
// Generates the NEON element-wise binary-op kernels for a single operator.
//
// Macro parameters:
//   $name     - operation name embedded in the generated method names
//   $simd_op  - the vector op on `Self` (e.g. `simd_add`) applied lane-wise
//   $operator - the matching scalar Rust operator, used for tail elements
//
// Generated methods are named `simd_<$name>_stride_<L>_<R>`, where <L>/<R>
// describe the lhs/rhs memory layouts respectively:
//   0 = scalar operand (a single element, broadcast across all lanes)
//   1 = contiguous operand (unit stride)
//   n = strided operand (element stride passed as an extra parameter)
// `dst` is always written contiguously.
//
// Every kernel has the same shape: a 4x-unrolled main loop consuming whole
// groups of `4 * Self::LANES` elements, followed by a scalar tail loop for
// the remaining `count % (4 * LANES)` elements, so any `count` is handled.
//
// SAFETY (contract for all generated fns): pointers must be valid for the
// number of elements each kernel reads/writes (`count` elements for unit
// stride, `count * stride` span for strided operands, one element for scalar
// operands), and `dst` must not overlap the inputs in a way the unrolled
// stores would observe.
macro_rules! simd_elementwise_operations {
    ($name:ident, $simd_op:ident, $operator:tt) => {
        paste! {
            // lhs scalar (broadcast), rhs contiguous.
            // NOTE(review): `*lhs` is dereferenced before the `count` check,
            // so `lhs` must be valid even when `count == 0` — confirm callers
            // guarantee this (same applies to the other 0/scalar variants).
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_0_1>](lhs: *const Self, mut rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                // Splat the scalar once; reused for the whole loop.
                let a = Self::simd_from_constant(*lhs);

                // Main loop: 4 vectors (4 * LANES elements) per iteration.
                while count >= 4 * Self::LANES {
                    let b0 = Self::simd_load(rhs.add(0 * Self::LANES));
                    let b1 = Self::simd_load(rhs.add(1 * Self::LANES));
                    let b2 = Self::simd_load(rhs.add(2 * Self::LANES));
                    let b3 = Self::simd_load(rhs.add(3 * Self::LANES));

                    let ab0 = Self::$simd_op(a, b0);
                    let ab1 = Self::$simd_op(a, b1);
                    let ab2 = Self::$simd_op(a, b2);
                    let ab3 = Self::$simd_op(a, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    rhs = rhs.add(4 * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail: fewer than 4 * LANES elements remain.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    rhs = rhs.add(1);
                    dst = dst.add(1);
                }
            }

            // lhs contiguous, rhs scalar (broadcast).
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_1_0>](mut lhs: *const Self, rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                // Splat the scalar rhs once, up front.
                let b = Self::simd_from_constant(*rhs);

                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_load(lhs.add(0 * Self::LANES));
                    let a1 = Self::simd_load(lhs.add(1 * Self::LANES));
                    let a2 = Self::simd_load(lhs.add(2 * Self::LANES));
                    let a3 = Self::simd_load(lhs.add(3 * Self::LANES));

                    let ab0 = Self::$simd_op(a0, b);
                    let ab1 = Self::$simd_op(a1, b);
                    let ab2 = Self::$simd_op(a2, b);
                    let ab3 = Self::$simd_op(a3, b);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    lhs = lhs.add(4 * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.add(1);
                    dst = dst.add(1);
                }
            }

            // Both operands contiguous — the fully vectorizable fast path.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_1_1>](mut lhs: *const Self, mut rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                while count >= 4 * Self::LANES {
                    // Loads are interleaved a/b per vector pair.
                    let a0 = Self::simd_load(lhs.add(0 * Self::LANES));
                    let b0 = Self::simd_load(rhs.add(0 * Self::LANES));

                    let a1 = Self::simd_load(lhs.add(1 * Self::LANES));
                    let b1 = Self::simd_load(rhs.add(1 * Self::LANES));

                    let a2 = Self::simd_load(lhs.add(2 * Self::LANES));
                    let b2 = Self::simd_load(rhs.add(2 * Self::LANES));

                    let a3 = Self::simd_load(lhs.add(3 * Self::LANES));
                    let b3 = Self::simd_load(rhs.add(3 * Self::LANES));

                    let ab0 = Self::$simd_op(a0, b0);
                    let ab1 = Self::$simd_op(a1, b1);
                    let ab2 = Self::$simd_op(a2, b2);
                    let ab3 = Self::$simd_op(a3, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    lhs = lhs.add(4 * Self::LANES);
                    rhs = rhs.add(4 * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.add(1);
                    rhs = rhs.add(1);
                    dst = dst.add(1);
                }
            }

            // lhs strided (gather via `simd_vec_from_stride`), rhs scalar.
            // `lhs_stride` is in elements; a vector covers LANES * stride
            // source elements.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_n_0>](mut lhs: *const Self, lhs_stride: usize, rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                let b = Self::simd_from_constant(*rhs);

                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_vec_from_stride(lhs.add(0 * lhs_stride * Self::LANES), lhs_stride);
                    let a1 = Self::simd_vec_from_stride(lhs.add(1 * lhs_stride * Self::LANES), lhs_stride);
                    let a2 = Self::simd_vec_from_stride(lhs.add(2 * lhs_stride * Self::LANES), lhs_stride);
                    let a3 = Self::simd_vec_from_stride(lhs.add(3 * lhs_stride * Self::LANES), lhs_stride);

                    let ab0 = Self::$simd_op(a0, b);
                    let ab1 = Self::$simd_op(a1, b);
                    let ab2 = Self::$simd_op(a2, b);
                    let ab3 = Self::$simd_op(a3, b);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    // lhs advances by its stride; dst stays contiguous.
                    lhs = lhs.add(4 * lhs_stride * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail follows the lhs stride one element at a time.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.add(lhs_stride);
                    dst = dst.add(1);
                }
            }

            // lhs scalar, rhs strided.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_0_n>](lhs: *const Self, mut rhs: *const Self, rhs_stride: usize, mut dst: *mut Self, mut count: usize) {
                let a = Self::simd_from_constant(*lhs);

                while count >= 4 * Self::LANES {
                    let b0 = Self::simd_vec_from_stride(rhs.add(0 * rhs_stride * Self::LANES), rhs_stride);
                    let b1 = Self::simd_vec_from_stride(rhs.add(1 * rhs_stride * Self::LANES), rhs_stride);
                    let b2 = Self::simd_vec_from_stride(rhs.add(2 * rhs_stride * Self::LANES), rhs_stride);
                    let b3 = Self::simd_vec_from_stride(rhs.add(3 * rhs_stride * Self::LANES), rhs_stride);

                    let ab0 = Self::$simd_op(a, b0);
                    let ab1 = Self::$simd_op(a, b1);
                    let ab2 = Self::$simd_op(a, b2);
                    let ab3 = Self::$simd_op(a, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    rhs = rhs.add(4 * rhs_stride * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail follows the rhs stride.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    rhs = rhs.add(rhs_stride);
                    dst = dst.add(1);
                }
            }

            // lhs strided, rhs contiguous.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_n_1>](mut lhs: *const Self, lhs_stride: usize, mut rhs: *const Self, mut dst: *mut Self, mut count: usize) {
                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_vec_from_stride(lhs.add(0 * lhs_stride * Self::LANES), lhs_stride);
                    let b0 = Self::simd_load(rhs.add(0 * Self::LANES));

                    let a1 = Self::simd_vec_from_stride(lhs.add(1 * lhs_stride * Self::LANES), lhs_stride);
                    let b1 = Self::simd_load(rhs.add(1 * Self::LANES));

                    let a2 = Self::simd_vec_from_stride(lhs.add(2 * lhs_stride * Self::LANES), lhs_stride);
                    let b2 = Self::simd_load(rhs.add(2 * Self::LANES));

                    let a3 = Self::simd_vec_from_stride(lhs.add(3 * lhs_stride * Self::LANES), lhs_stride);
                    let b3 = Self::simd_load(rhs.add(3 * Self::LANES));

                    let ab0 = Self::$simd_op(a0, b0);
                    let ab1 = Self::$simd_op(a1, b1);
                    let ab2 = Self::$simd_op(a2, b2);
                    let ab3 = Self::$simd_op(a3, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    lhs = lhs.add(4 * lhs_stride * Self::LANES);
                    rhs = rhs.add(4 * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail: lhs advances by stride, rhs by one.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.add(lhs_stride);
                    rhs = rhs.add(1);
                    dst = dst.add(1);
                }
            }

            // lhs contiguous, rhs strided.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_1_n>](mut lhs: *const Self, mut rhs: *const Self, rhs_stride: usize, mut dst: *mut Self, mut count: usize) {
                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_load(lhs.add(0 * Self::LANES));
                    let b0 = Self::simd_vec_from_stride(rhs.add(0 * rhs_stride * Self::LANES), rhs_stride);

                    let a1 = Self::simd_load(lhs.add(1 * Self::LANES));
                    let b1 = Self::simd_vec_from_stride(rhs.add(1 * rhs_stride * Self::LANES), rhs_stride);

                    let a2 = Self::simd_load(lhs.add(2 * Self::LANES));
                    let b2 = Self::simd_vec_from_stride(rhs.add(2 * rhs_stride * Self::LANES), rhs_stride);

                    let a3 = Self::simd_load(lhs.add(3 * Self::LANES));
                    let b3 = Self::simd_vec_from_stride(rhs.add(3 * rhs_stride * Self::LANES), rhs_stride);

                    let ab0 = Self::$simd_op(a0, b0);
                    let ab1 = Self::$simd_op(a1, b1);
                    let ab2 = Self::$simd_op(a2, b2);
                    let ab3 = Self::$simd_op(a3, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    lhs = lhs.add(4 * Self::LANES);
                    rhs = rhs.add(4 * rhs_stride * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail: lhs advances by one, rhs by stride.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.add(1);
                    rhs = rhs.add(rhs_stride);
                    dst = dst.add(1);
                }
            }

            // Both operands strided — the fully general (slowest) path.
            #[cfg(neon_simd)]
            unsafe fn [<simd_ $name _stride_n_n>](mut lhs: *const Self, lhs_stride: usize,
                                                  mut rhs: *const Self, rhs_stride: usize,
                                                  mut dst: *mut Self, mut count: usize) {
                while count >= 4 * Self::LANES {
                    let a0 = Self::simd_vec_from_stride(lhs.add(0 * lhs_stride * Self::LANES), lhs_stride);
                    let b0 = Self::simd_vec_from_stride(rhs.add(0 * rhs_stride * Self::LANES), rhs_stride);

                    let a1 = Self::simd_vec_from_stride(lhs.add(1 * lhs_stride * Self::LANES), lhs_stride);
                    let b1 = Self::simd_vec_from_stride(rhs.add(1 * rhs_stride * Self::LANES), rhs_stride);

                    let a2 = Self::simd_vec_from_stride(lhs.add(2 * lhs_stride * Self::LANES), lhs_stride);
                    let b2 = Self::simd_vec_from_stride(rhs.add(2 * rhs_stride * Self::LANES), rhs_stride);

                    let a3 = Self::simd_vec_from_stride(lhs.add(3 * lhs_stride * Self::LANES), lhs_stride);
                    let b3 = Self::simd_vec_from_stride(rhs.add(3 * rhs_stride * Self::LANES), rhs_stride);

                    let ab0 = Self::$simd_op(a0, b0);
                    let ab1 = Self::$simd_op(a1, b1);
                    let ab2 = Self::$simd_op(a2, b2);
                    let ab3 = Self::$simd_op(a3, b3);

                    Self::simd_store(dst.add(0 * Self::LANES), ab0);
                    Self::simd_store(dst.add(1 * Self::LANES), ab1);
                    Self::simd_store(dst.add(2 * Self::LANES), ab2);
                    Self::simd_store(dst.add(3 * Self::LANES), ab3);

                    count -= 4 * Self::LANES;
                    lhs = lhs.add(4 * lhs_stride * Self::LANES);
                    rhs = rhs.add(4 * rhs_stride * Self::LANES);
                    dst = dst.add(4 * Self::LANES);
                }

                // Scalar tail: each operand advances by its own stride.
                while count != 0 {
                    *dst = *lhs $operator *rhs;

                    count -= 1;
                    lhs = lhs.add(lhs_stride);
                    rhs = rhs.add(rhs_stride);
                    dst = dst.add(1);
                }
            }
        }
    };
}
312
// Crate-internal trait providing the SIMD kernels for the four basic
// element-wise binary operations. Each invocation below expands to the eight
// stride specializations (`simd_<op>_stride_{0_1,1_0,1_1,n_0,0_n,n_1,1_n,n_n}`)
// with default bodies, all gated behind `#[cfg(neon_simd)]`.
// NOTE(review): the `Display` bound is not used by the generated kernels in
// this file — presumably required by something else; confirm before removing.
pub(crate) trait SimdBinaryOps: Simd + Display {
    simd_elementwise_operations!(add, simd_add, +);
    simd_elementwise_operations!(sub, simd_sub, -);
    simd_elementwise_operations!(mul, simd_mul, *);
    simd_elementwise_operations!(div, simd_div, /);
}
319
320impl<T: Simd + Display> SimdBinaryOps for T {}
321
// Expands (inside an impl or trait body) to the eight stride-specialized
// entry points for the binary operation `$name`. Every generated function is
// a thin shim that forwards straight to the matching
// `SimdBinaryOps::simd_<$name>_stride_*` kernel; the trait is imported
// locally in each shim so callers of this macro need no extra `use`.
//
// Stride suffix convention: 0 = scalar operand, 1 = contiguous, n = strided.
#[macro_export]
macro_rules! simd_binary_op_specializations {
    ($name: ident) => {
        paste! {
            // Scalar lhs, contiguous rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_0_1>](lhs: *const Self, rhs: *const Self, dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_0_1>](lhs, rhs, dst, count);
            }

            // Scalar lhs, strided rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_0_n>](lhs: *const Self, rhs: *const Self, rhs_stride: usize, dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_0_n>](lhs, rhs, rhs_stride, dst, count);
            }

            // Contiguous lhs, scalar rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_1_0>](lhs: *const Self, rhs: *const Self, dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_1_0>](lhs, rhs, dst, count);
            }

            // Both operands contiguous.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_1_1>](lhs: *const Self, rhs: *const Self, dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_1_1>](lhs, rhs, dst, count);
            }

            // Contiguous lhs, strided rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_1_n>](lhs: *const Self, rhs: *const Self, rhs_stride: usize, dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_1_n>](lhs, rhs, rhs_stride, dst, count);
            }

            // Strided lhs, scalar rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_n_0>](lhs: *const Self, lhs_stride: usize, rhs: *const Self, dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_n_0>](lhs, lhs_stride, rhs, dst, count);
            }

            // Strided lhs, contiguous rhs.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_n_1>](lhs: *const Self, lhs_stride: usize, rhs: *const Self, dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_n_1>](lhs, lhs_stride, rhs, dst, count);
            }

            // Both operands strided — fully general case.
            #[cfg(neon_simd)]
            unsafe fn [<$name _stride_n_n>](lhs: *const Self, lhs_stride: usize, rhs: *const Self, rhs_stride: usize, dst: *mut Self, count: usize) {
                use $crate::ops::simd_binary_ops::SimdBinaryOps;
                Self::[<simd_ $name _stride_n_n>](lhs, lhs_stride, rhs, rhs_stride, dst, count);
            }
        }
    };
}