1use super::*;
2
/// Scalar stand-in for the SIMD `min` lane operation.
///
/// Deliberately mirrors SSE `min_ps` semantics rather than `f32::min`:
/// when `a < b` is false — including when `a` is NaN — the second
/// operand `b` is returned, so `min(NaN, x) == x` but `min(x, NaN)` is NaN.
#[inline(always)]
fn scalar_simd_min(a: f32, b: f32) -> f32 {
    match a < b {
        true => a,
        false => b,
    }
}
11
/// Scalar stand-in for the SIMD `max` lane operation.
///
/// Deliberately mirrors SSE `max_ps` semantics rather than `f32::max`:
/// when `a > b` is false — including when `a` is NaN — the second
/// operand `b` is returned, so `max(NaN, x) == x` but `max(x, NaN)` is NaN.
#[inline(always)]
fn scalar_simd_max(a: f32, b: f32) -> f32 {
    match a > b {
        true => a,
        false => b,
    }
}
20
// Lane-wise addition: a + b per f32 lane (single value on the Scalar backend).
impl_op! {
    fn add<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_add_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_add_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_add_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_add_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            a + b
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vaddq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_add(a, b)
        }
    }
}
46
// Lane-wise subtraction: a - b per f32 lane.
impl_op! {
    fn sub<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_sub_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_sub_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_sub_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_sub_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            a - b
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vsubq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_sub(a, b)
        }
    }
}
72
// Lane-wise multiplication: a * b per f32 lane.
impl_op! {
    fn mul<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_mul_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_mul_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_mul_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_mul_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            a * b
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vmulq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_mul(a, b)
        }
    }
}
98
// Lane-wise division: a / b per f32 lane.
impl_op! {
    fn div<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_div_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_div_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_div_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_div_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            a / b
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vdivq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_div(a, b)
        }
    }
}
124
// Lane-wise fused multiply-add: a * b + c.
// AVX-512/AVX2/NEON use true FMA intrinsics (single rounding); the SSE and
// WASM fallbacks use a separate multiply then add (two roundings), so results
// can differ by 1 ulp across backends.
impl_op! {
    fn mul_add<f32> {
        for Avx512(a: __m512, b: __m512, c: __m512) -> __m512 {
            _mm512_fmadd_ps(a, b, c)
        }
        for Avx2(a: __m256, b: __m256, c: __m256) -> __m256 {
            _mm256_fmadd_ps(a, b, c)
        }
        for Sse41(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_add_ps(_mm_mul_ps(a, b), c)
        }
        for Sse2(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_add_ps(_mm_mul_ps(a, b), c)
        }
        for Scalar(a: f32, b: f32, c: f32) -> f32 {
            a * b + c
        }
        for Neon(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
            // NEON's fma takes the addend first: vfmaq_f32(c, a, b) == c + a * b.
            vfmaq_f32(c, a, b)
        }
        for Wasm(a: v128, b: v128, c: v128) -> v128 {
            f32x4_add(f32x4_mul(a, b), c)
        }
    }
}
150
// Lane-wise fused multiply-subtract: a * b - c.
impl_op! {
    fn mul_sub<f32> {
        for Avx512(a: __m512, b: __m512, c: __m512) -> __m512 {
            _mm512_fmsub_ps(a, b, c)
        }
        for Avx2(a: __m256, b: __m256, c: __m256) -> __m256 {
            _mm256_fmsub_ps(a, b, c)
        }
        for Sse41(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_sub_ps(_mm_mul_ps(a, b), c)
        }
        for Sse2(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_sub_ps(_mm_mul_ps(a, b), c)
        }
        for Scalar(a: f32, b: f32, c: f32) -> f32 {
            a * b - c
        }
        for Neon(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
            // vfmsq_f32(c, a, b) computes c - a * b; negating yields a * b - c.
            vnegq_f32(vfmsq_f32(c, a, b))
        }
        for Wasm(a: v128, b: v128, c: v128) -> v128 {
            f32x4_sub(f32x4_mul(a, b), c)
        }
    }
}
176
// Lane-wise negated multiply-add: c - a * b (i.e. -(a * b) + c).
impl_op! {
    fn neg_mul_add<f32> {
        for Avx512(a: __m512, b: __m512, c: __m512) -> __m512 {
            _mm512_fnmadd_ps(a, b, c)
        }
        for Avx2(a: __m256, b: __m256, c: __m256) -> __m256 {
            _mm256_fnmadd_ps(a, b, c)
        }
        for Sse41(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_sub_ps(c, _mm_mul_ps(a, b))
        }
        for Sse2(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_sub_ps(c, _mm_mul_ps(a, b))
        }
        for Scalar(a: f32, b: f32, c: f32) -> f32 {
            c - a * b
        }
        for Neon(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
            // vfmsq_f32(c, a, b) computes c - a * b directly.
            vfmsq_f32(c, a, b)
        }
        for Wasm(a: v128, b: v128, c: v128) -> v128 {
            f32x4_sub(c, f32x4_mul(a, b))
        }
    }
}
202
// Lane-wise negated multiply-subtract: -(a * b) - c.
impl_op! {
    fn neg_mul_sub<f32> {
        for Avx512(a: __m512, b: __m512, c: __m512) -> __m512 {
            _mm512_fnmsub_ps(a, b, c)
        }
        for Avx2(a: __m256, b: __m256, c: __m256) -> __m256 {
            _mm256_fnmsub_ps(a, b, c)
        }
        for Sse41(a: __m128, b: __m128, c: __m128) -> __m128 {
            // No FMA available: negate the product via 0 - (a * b), then subtract c.
            let mul = _mm_mul_ps(a, b);
            let neg = _mm_sub_ps(_mm_setzero_ps(), mul);
            _mm_sub_ps(neg, c)
        }
        for Sse2(a: __m128, b: __m128, c: __m128) -> __m128 {
            let mul = _mm_mul_ps(a, b);
            let neg = _mm_sub_ps(_mm_setzero_ps(), mul);
            _mm_sub_ps(neg, c)
        }
        for Scalar(a: f32, b: f32, c: f32) -> f32 {
            -a * b - c
        }
        for Neon(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
            // -(c + a * b) == -(a * b) - c.
            vnegq_f32(vfmaq_f32(c, a, b))
        }
        for Wasm(a: v128, b: v128, c: v128) -> v128 {
            f32x4_sub(f32x4_neg(f32x4_mul(a, b)), c)
        }
    }
}
232
// Lane-wise square root.
impl_op! {
    fn sqrt<f32> {
        for Avx512(a: __m512) -> __m512 {
            _mm512_sqrt_ps(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_sqrt_ps(a)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_sqrt_ps(a)
        }
        for Sse2(a: __m128) -> __m128 {
            _mm_sqrt_ps(a)
        }
        for Scalar(a: f32) -> f32 {
            // m_sqrt is a project helper (no_std-friendly sqrt) — defined elsewhere.
            a.m_sqrt()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vsqrtq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_sqrt(a)
        }
    }
}
258
// Lane-wise reciprocal: 1 / a.
// NOTE(review): precision is NOT uniform across backends — the x86 rcp
// intrinsics and NEON vrecpeq_f32 are fast low-precision approximations,
// while Scalar and WASM compute an exact IEEE division. Callers needing
// full precision should not rely on this op.
impl_op! {
    fn recip<f32> {
        for Avx512(a: __m512) -> __m512 {
            // Approximation with relative error <= 2^-14.
            _mm512_rcp14_ps(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_rcp_ps(a)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_rcp_ps(a)
        }
        for Sse2(a: __m128) -> __m128 {
            _mm_rcp_ps(a)
        }
        for Scalar(a: f32) -> f32 {
            1.0 / a
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vrecpeq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_div(f32x4_splat(1.0), a)
        }
    }
}
284
// Lane-wise reciprocal square root: 1 / sqrt(a).
// NOTE(review): same caveat as `recip` — x86/NEON use fast approximate
// estimate instructions, Scalar/WASM compute the exact value, so accuracy
// differs by backend.
impl_op! {
    fn rsqrt<f32> {
        for Avx512(a: __m512) -> __m512 {
            // Approximation with relative error <= 2^-14.
            _mm512_rsqrt14_ps(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_rsqrt_ps(a)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_rsqrt_ps(a)
        }
        for Sse2(a: __m128) -> __m128 {
            _mm_rsqrt_ps(a)
        }
        for Scalar(a: f32) -> f32 {
            1.0 / a.m_sqrt()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vrsqrteq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_div(f32x4_splat(1.0), f32x4_sqrt(a))
        }
    }
}
310
// Lane-wise minimum. The Scalar backend uses scalar_simd_min, which mirrors
// SSE min_ps NaN behavior (returns b when the a < b comparison fails).
// NOTE(review): NEON vminq_f32 and WASM f32x4_min propagate NaN from either
// operand, so NaN lanes may differ from the x86/Scalar backends — confirm
// whether callers depend on NaN handling here.
impl_op! {
    fn min<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_min_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_min_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_min_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_min_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            scalar_simd_min(a, b)
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vminq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_min(a, b)
        }
    }
}
336
// Lane-wise maximum. Scalar backend mirrors SSE max_ps NaN behavior via
// scalar_simd_max; see the NaN-handling note on `min` — the same backend
// divergence applies here.
impl_op! {
    fn max<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_max_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_max_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_max_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_max_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            scalar_simd_max(a, b)
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vmaxq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_max(a, b)
        }
    }
}
362
// Lane-wise absolute value.
// x86 backends clear the sign bit by AND-NOT with -0.0 (whose only set bit
// is the sign bit), which also maps -0.0 to +0.0.
impl_op! {
    fn abs<f32> {
        for Avx512(a: __m512) -> __m512 {
            _mm512_andnot_ps(_mm512_set1_ps(-0.0), a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_andnot_ps(_mm256_set1_ps(-0.0), a)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_andnot_ps(_mm_set1_ps(-0.0), a)
        }
        for Sse2(a: __m128) -> __m128 {
            _mm_andnot_ps(_mm_set1_ps(-0.0), a)
        }
        for Scalar(a: f32) -> f32 {
            a.m_abs()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vabsq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_abs(a)
        }
    }
}
388
// Lane-wise round to nearest integer.
// NOTE(review): rounding of exact .5 ties differs by backend — the x86/WASM
// "nearest" paths round ties to even, while Scalar m_round and NEON vrndaq_f32
// round ties away from zero; confirm callers tolerate this.
impl_op! {
    fn round<f32> {
        for Avx512(a: __m512) -> __m512 {
            // Immediate 0x08 = round-to-nearest-even, suppress exceptions.
            _mm512_roundscale_ps::<0x08>(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
        }
        for Sse2(a: __m128) -> __m128 {
            // SSE2 has no round instruction: add-then-subtract a sign-matched
            // "magic" constant (2^23, 0x4B000000) so the addition discards the
            // fractional bits in the current rounding mode. Inputs with
            // |a| >= 2^23 are already integral, so the identity still holds
            // presumably for all finite inputs — TODO confirm for edge cases.
            let sign_mask = _mm_set1_ps(-0.0);
            let magic = _mm_castsi128_ps(_mm_set1_epi32(0x4B000000));
            let sign = _mm_and_ps(a, sign_mask);
            let signed_magic = _mm_or_ps(magic, sign);
            let b = _mm_add_ps(a, signed_magic);
            _mm_sub_ps(b, signed_magic)
        }
        for Scalar(a: f32) -> f32 {
            a.m_round()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            // Round to nearest, ties away from zero.
            vrndaq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            // Round to nearest, ties to even.
            f32x4_nearest(a)
        }
    }
}
419
// Lane-wise floor (round toward negative infinity).
impl_op! {
    fn floor<f32> {
        for Avx512(a: __m512) -> __m512 {
            // Immediate 0x09 = round toward -inf, suppress exceptions.
            _mm512_roundscale_ps::<0x09>(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)
        }
        for Sse2(a: __m128) -> __m128 {
            // SSE2 fallback: floor each lane in scalar code via transmute.
            let nums_arr = core::mem::transmute::<__m128, [f32; 4]>(a);
            // NOTE(review): local is named `ceil` but holds floored values —
            // harmless copy-paste name from the `ceil` op below.
            let ceil = [
                nums_arr[0].m_floor(),
                nums_arr[1].m_floor(),
                nums_arr[2].m_floor(),
                nums_arr[3].m_floor(),
            ];
            core::mem::transmute::<[f32; 4], __m128>(ceil)
        }
        for Scalar(a: f32) -> f32 {
            a.m_floor()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vrndmq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_floor(a)
        }
    }
}
452
// Lane-wise ceiling (round toward positive infinity).
impl_op! {
    fn ceil<f32> {
        for Avx512(a: __m512) -> __m512 {
            // Immediate 0x0A = round toward +inf, suppress exceptions.
            _mm512_roundscale_ps::<0x0A>(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)
        }
        for Sse2(a: __m128) -> __m128 {
            // SSE2 fallback: ceil each lane in scalar code via transmute.
            let nums_arr = core::mem::transmute::<__m128, [f32; 4]>(a);
            let ceil = [
                nums_arr[0].m_ceil(),
                nums_arr[1].m_ceil(),
                nums_arr[2].m_ceil(),
                nums_arr[3].m_ceil(),
            ];
            core::mem::transmute::<[f32; 4], __m128>(ceil)
        }
        for Scalar(a: f32) -> f32 {
            a.m_ceil()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vrndpq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_ceil(a)
        }
    }
}
485
// "Fast" round — for f32 every backend already has a cheap exact round,
// so this simply forwards to Self::round on all backends.
impl_op! {
    fn fast_round<f32> {
        for Avx512(a: __m512) -> __m512 {
            Self::round(a)
        }
        for Avx2(a: __m256) -> __m256 {
            Self::round(a)
        }
        for Sse41(a: __m128) -> __m128 {
            Self::round(a)
        }
        for Sse2(a: __m128) -> __m128 {
            Self::round(a)
        }
        for Scalar(a: f32) -> f32 {
            Self::round(a)
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            Self::round(a)
        }
        for Wasm(a: v128) -> v128 {
            Self::round(a)
        }
    }
}
511
// "Fast" floor — forwards to the exact Self::floor on every backend.
impl_op! {
    fn fast_floor<f32> {
        for Avx512(a: __m512) -> __m512 {
            Self::floor(a)
        }
        for Avx2(a: __m256) -> __m256 {
            Self::floor(a)
        }
        for Sse41(a: __m128) -> __m128 {
            Self::floor(a)
        }
        for Sse2(a: __m128) -> __m128 {
            Self::floor(a)
        }
        for Scalar(a: f32) -> f32 {
            Self::floor(a)
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            Self::floor(a)
        }
        for Wasm(a: v128) -> v128 {
            Self::floor(a)
        }
    }
}
537
// "Fast" ceil — forwards to the exact Self::ceil on every backend.
impl_op! {
    fn fast_ceil<f32> {
        for Avx512(a: __m512) -> __m512 {
            Self::ceil(a)
        }
        for Avx2(a: __m256) -> __m256 {
            Self::ceil(a)
        }
        for Sse41(a: __m128) -> __m128 {
            Self::ceil(a)
        }
        for Sse2(a: __m128) -> __m128 {
            Self::ceil(a)
        }
        for Scalar(a: f32) -> f32 {
            Self::ceil(a)
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            Self::ceil(a)
        }
        for Wasm(a: v128) -> v128 {
            Self::ceil(a)
        }
    }
}
563
// Lane-wise equality compare. Produces a float-typed bitmask per lane:
// all-ones bits (NaN pattern) for true, all-zero bits (0.0) for false.
// Uses the ordered predicate, so any NaN operand compares as not-equal.
impl_op! {
    fn eq<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            // AVX-512 compares produce a k-mask; expand it back to a
            // per-lane all-ones/all-zeros vector for API uniformity.
            let k = _mm512_cmp_ps_mask::<_CMP_EQ_OQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_EQ_OQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmpeq_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmpeq_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a == b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vceqq_f32(a, b))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_eq(a, b)
        }
    }
}
594
// Lane-wise inequality compare, returning an all-ones/all-zeros float mask.
// Uses the unordered predicate (_CMP_NEQ_UQ): NaN lanes compare as true.
// The NEON path matches by bitwise-inverting the equality mask.
impl_op! {
    fn neq<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_NEQ_UQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_NEQ_UQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmpneq_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmpneq_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a != b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            // NOT(eq) — also yields true for NaN lanes, matching _CMP_NEQ_UQ.
            vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(a, b)))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_ne(a, b)
        }
    }
}
625
// Lane-wise less-than compare (ordered; NaN lanes yield false),
// returning an all-ones/all-zeros float mask per lane.
impl_op! {
    fn lt<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_LT_OQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_LT_OQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmplt_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmplt_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a < b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vcltq_f32(a, b))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_lt(a, b)
        }
    }
}
656
// Lane-wise less-than-or-equal compare (ordered; NaN lanes yield false),
// returning an all-ones/all-zeros float mask per lane.
impl_op! {
    fn lte<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_LE_OQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_LE_OQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmple_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmple_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a <= b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vcleq_f32(a, b))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_le(a, b)
        }
    }
}
687
// Lane-wise greater-than compare (ordered; NaN lanes yield false),
// returning an all-ones/all-zeros float mask per lane.
impl_op! {
    fn gt<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_GT_OQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_GT_OQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmpgt_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmpgt_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a > b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vcgtq_f32(a, b))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_gt(a, b)
        }
    }
}
718
// Lane-wise greater-than-or-equal compare (ordered; NaN lanes yield false),
// returning an all-ones/all-zeros float mask per lane.
impl_op! {
    fn gte<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_GE_OQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_GE_OQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmpge_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmpge_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a >= b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vcgeq_f32(a, b))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_ge(a, b)
        }
    }
}
749
// Lane-wise select: pick b where the mask lane is "true", else a.
// NOTE(review): backends only agree when mask lanes are all-ones or
// all-zeros (as produced by the compare ops above). For partial masks,
// AVX-512/AVX2/SSE4.1 select on the sign (top) bit, SSE2/NEON/WASM do a
// per-bit merge, and Scalar treats any nonzero bit pattern as true —
// confirm callers never pass arbitrary masks.
impl_op! {
    fn blendv<f32> {
        for Avx512(a: __m512, b: __m512, mask: __m512) -> __m512 {
            // Compress the top bit of each lane into a k-mask, then blend.
            let k = _mm512_movepi32_mask(_mm512_castps_si512(mask));
            _mm512_mask_blend_ps(k, a, b)
        }
        for Avx2(a: __m256, b: __m256, mask: __m256) -> __m256 {
            _mm256_blendv_ps(a, b, mask)
        }
        for Sse41(a: __m128, b: __m128, mask: __m128) -> __m128 {
            _mm_blendv_ps(a, b, mask)
        }
        for Sse2(a: __m128, b: __m128, mask: __m128) -> __m128 {
            // Bitwise select: (mask & b) | (!mask & a).
            _mm_or_ps(_mm_and_ps(mask, b), _mm_andnot_ps(mask, a))
        }
        for Scalar(a: f32, b: f32, mask: f32) -> f32 {
            // Compare the raw bit pattern, not the float value, so a 0.0 /
            // all-ones mask from the compare ops selects correctly.
            if mask.to_bits() == 0 {
                a
            } else {
                b
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t, mask: float32x4_t) -> float32x4_t {
            // vbslq picks from its second arg where mask bits are set.
            vbslq_f32(vreinterpretq_u32_f32(mask), b, a)
        }
        for Wasm(a: v128, b: v128, mask: v128) -> v128 {
            // WASM v128_andnot(a, mask) == a & !mask (note operand order
            // is opposite to x86 andnot).
            v128_or(v128_and(mask, b), v128_andnot(a, mask))
        }
    }
}
780
// Horizontal reduction: sum all lanes of the vector into a single f32.
// NOTE(review): summation order differs per backend, so results can differ
// by rounding for the same input vector.
impl_op! {
    fn horizontal_add<f32> {
        for Avx512(a: __m512) -> f32 {
            _mm512_reduce_add_ps(a)
        }
        for Avx2(a: __m256) -> f32 {
            // Two hadd passes reduce within each 128-bit half; then sum the
            // low lane of each half.
            let a = _mm256_hadd_ps(a, a);
            let b = _mm256_hadd_ps(a, a);

            let first = _mm_cvtss_f32(_mm256_extractf128_ps(b, 0));
            let second = _mm_cvtss_f32(_mm256_extractf128_ps(b, 1));

            first + second
        }
        for Sse41(a: __m128) -> f32 {
            // Two hadd passes leave the total in every lane; extract lane 0.
            let a = _mm_hadd_ps(a, a);
            let b = _mm_hadd_ps(a, a);

            _mm_cvtss_f32(b)
        }
        for Sse2(a: __m128) -> f32 {
            // No hadd on SSE2: fold high half onto low half, then add the
            // remaining two lanes via a shuffle.
            let t1 = _mm_movehl_ps(a, a);
            let t2 = _mm_add_ps(a, t1);
            let t3 = _mm_shuffle_ps(t2, t2, 1);
            _mm_cvtss_f32(t2) + _mm_cvtss_f32(t3)
        }
        for Scalar(a: f32) -> f32 {
            // Single-lane "vector": the sum is the value itself.
            a
        }
        for Neon(a: float32x4_t) -> f32 {
            // Two pairwise-add passes accumulate the total into lane 0.
            let a = vpaddq_f32(a, a);
            let a = vpaddq_f32(a, a);
            vgetq_lane_f32(a, 0)
        }
        for Wasm(a: v128) -> f32 {
            // No horizontal add in WASM SIMD: extract and sum scalar lanes.
            let l0 = f32x4_extract_lane::<0>(a);
            let l1 = f32x4_extract_lane::<1>(a);
            let l2 = f32x4_extract_lane::<2>(a);
            let l3 = f32x4_extract_lane::<3>(a);
            l0 + l1 + l2 + l3
        }
    }
}
824
// Numeric conversion f32 -> i32, rounding to the nearest integer.
// Scalar/NEON/WASM round ties-to-even explicitly before converting;
// the x86 cvtps intrinsics round per the current MXCSR mode — presumably
// the default round-to-nearest-even, so behavior matches; TODO confirm
// nothing alters MXCSR. NOTE(review): out-of-range/NaN results also differ
// (x86 yields i32::MIN sentinel, WASM saturates, Rust `as` saturates).
impl_op! {
    fn cast_i32<f32> {
        for Avx512(a: __m512) -> __m512i {
            _mm512_cvtps_epi32(a)
        }
        for Avx2(a: __m256) -> __m256i {
            _mm256_cvtps_epi32(a)
        }
        for Sse41(a: __m128) -> __m128i {
            _mm_cvtps_epi32(a)
        }
        for Sse2(a: __m128) -> __m128i {
            _mm_cvtps_epi32(a)
        }
        for Scalar(a: f32) -> i32 {
            a.m_round_ties_even() as i32
        }
        for Neon(a: float32x4_t) -> int32x4_t {
            // Round to nearest-even first, then truncate-convert.
            let a = vrndnq_f32(a);
            vcvtq_s32_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            // Round to nearest-even first, then saturating truncate-convert.
            let a = f32x4_nearest(a);
            i32x4_trunc_sat_f32x4(a)
        }
    }
}
853
// Bit-level reinterpretation f32 -> i32: the raw bits are preserved,
// no numeric conversion happens on any backend.
impl_op! {
    fn bitcast_i32<f32> {
        for Avx512(a: __m512) -> __m512i {
            _mm512_castps_si512(a)
        }
        for Avx2(a: __m256) -> __m256i {
            _mm256_castps_si256(a)
        }
        for Sse41(a: __m128) -> __m128i {
            _mm_castps_si128(a)
        }
        for Sse2(a: __m128) -> __m128i {
            _mm_castps_si128(a)
        }
        for Scalar(a: f32) -> i32 {
            a.to_bits() as i32
        }
        for Neon(a: float32x4_t) -> int32x4_t {
            vreinterpretq_s32_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            // WASM v128 is untyped; the bits are already the answer.
            a
        }
    }
}
879
// Construct a vector with every f32 lane set to 0.0.
impl_op! {
    fn zeroes<f32> {
        for Avx512() -> __m512 {
            _mm512_setzero_ps()
        }
        for Avx2() -> __m256 {
            _mm256_setzero_ps()
        }
        for Sse41() -> __m128 {
            _mm_setzero_ps()
        }
        for Sse2() -> __m128 {
            _mm_setzero_ps()
        }
        for Scalar() -> f32 {
            0.0
        }
        for Neon() -> float32x4_t {
            vdupq_n_f32(0.0)
        }
        for Wasm() -> v128 {
            f32x4_splat(0.0)
        }
    }
}
905
// Broadcast a single f32 value into every lane of the vector.
impl_op! {
    fn set1<f32> {
        for Avx512(val: f32) -> __m512 {
            _mm512_set1_ps(val)
        }
        for Avx2(val: f32) -> __m256 {
            _mm256_set1_ps(val)
        }
        for Sse41(val: f32) -> __m128 {
            _mm_set1_ps(val)
        }
        for Sse2(val: f32) -> __m128 {
            _mm_set1_ps(val)
        }
        for Scalar(val: f32) -> f32 {
            val
        }
        for Neon(val: f32) -> float32x4_t {
            vdupq_n_f32(val)
        }
        for Wasm(val: f32) -> v128 {
            f32x4_splat(val)
        }
    }
}
931
// Load one vector's worth of f32s from a pointer with no alignment
// requirement. Caller must guarantee the pointer is valid for a full
// vector-width read.
impl_op! {
    fn load_unaligned<f32> {
        for Avx512(ptr: *const f32) -> __m512 {
            _mm512_loadu_ps(ptr)
        }
        for Avx2(ptr: *const f32) -> __m256 {
            _mm256_loadu_ps(ptr)
        }
        for Sse41(ptr: *const f32) -> __m128 {
            _mm_loadu_ps(ptr)
        }
        for Sse2(ptr: *const f32) -> __m128 {
            _mm_loadu_ps(ptr)
        }
        for Scalar(ptr: *const f32) -> f32 {
            unsafe { *ptr }
        }
        for Neon(ptr: *const f32) -> float32x4_t {
            vld1q_f32(ptr)
        }
        for Wasm(ptr: *const f32) -> v128 {
            // v128_load permits unaligned addresses in WASM.
            unsafe { v128_load(ptr as *const v128) }
        }
    }
}
957
// Load one vector's worth of f32s from a pointer the caller guarantees is
// aligned to the vector width (64B for AVX-512, 32B for AVX2, 16B for SSE).
// NOTE(review): the NEON arm is identical to load_unaligned (vld1q_f32 has
// no alignment precondition), and the WASM arm is a plain deref — presumably
// relying on v128's own alignment; confirm intended.
impl_op! {
    fn load_aligned<f32> {
        for Avx512(ptr: *const f32) -> __m512 {
            _mm512_load_ps(ptr)
        }
        for Avx2(ptr: *const f32) -> __m256 {
            _mm256_load_ps(ptr)
        }
        for Sse41(ptr: *const f32) -> __m128 {
            _mm_load_ps(ptr)
        }
        for Sse2(ptr: *const f32) -> __m128 {
            _mm_load_ps(ptr)
        }
        for Scalar(ptr: *const f32) -> f32 {
            unsafe { *ptr }
        }
        for Neon(ptr: *const f32) -> float32x4_t {
            vld1q_f32(ptr)
        }
        for Wasm(ptr: *const f32) -> v128 {
            *(ptr as *const v128)
        }
    }
}
983
// Store one vector's worth of f32s to a pointer with no alignment
// requirement. Caller must guarantee the pointer is valid for a full
// vector-width write.
impl_op! {
    fn store_unaligned<f32> {
        for Avx512(ptr: *mut f32, a: __m512) {
            _mm512_storeu_ps(ptr, a)
        }
        for Avx2(ptr: *mut f32, a: __m256) {
            _mm256_storeu_ps(ptr, a)
        }
        for Sse41(ptr: *mut f32, a: __m128) {
            _mm_storeu_ps(ptr, a)
        }
        for Sse2(ptr: *mut f32, a: __m128) {
            _mm_storeu_ps(ptr, a)
        }
        for Scalar(ptr: *mut f32, a: f32) {
            unsafe { *ptr = a }
        }
        for Neon(ptr: *mut f32, a: float32x4_t) {
            vst1q_f32(ptr, a)
        }
        for Wasm(ptr: *mut f32, a: v128) {
            // v128_store permits unaligned addresses in WASM.
            unsafe { v128_store(ptr as *mut v128, a) }
        }
    }
}
1009
// Store one vector's worth of f32s to a pointer the caller guarantees is
// aligned to the vector width (64B AVX-512, 32B AVX2, 16B SSE). The NEON
// and Scalar arms are identical to store_unaligned; the WASM arm writes
// through a v128 pointer directly.
impl_op! {
    fn store_aligned<f32> {
        for Avx512(ptr: *mut f32, a: __m512) {
            _mm512_store_ps(ptr, a)
        }
        for Avx2(ptr: *mut f32, a: __m256) {
            _mm256_store_ps(ptr, a)
        }
        for Sse41(ptr: *mut f32, a: __m128) {
            _mm_store_ps(ptr, a)
        }
        for Sse2(ptr: *mut f32, a: __m128) {
            _mm_store_ps(ptr, a)
        }
        for Scalar(ptr: *mut f32, a: f32) {
            unsafe { *ptr = a }
        }
        for Neon(ptr: *mut f32, a: float32x4_t) {
            vst1q_f32(ptr, a)
        }
        for Wasm(ptr: *mut f32, a: v128) {
            *(ptr as *mut v128) = a;
        }
    }
}