// simdeez/ops/f32.rs — per-backend f32 SIMD operation implementations.
1use super::*;
2
/// Scalar analogue of the SIMD lane `min`: returns `a` only when `a < b`,
/// otherwise `b`. Ties (e.g. `-0.0` vs `0.0`) and NaN comparisons therefore
/// yield `b`, mirroring the x86 `_mm_min_ps` operand ordering rather than
/// `f32::min`'s NaN handling.
#[inline(always)]
fn scalar_simd_min(a: f32, b: f32) -> f32 {
    match a < b {
        true => a,
        false => b,
    }
}
11
/// Scalar analogue of the SIMD lane `max`: returns `a` only when `a > b`,
/// otherwise `b`. Ties and NaN comparisons therefore yield `b`, mirroring
/// the x86 `_mm_max_ps` operand ordering rather than `f32::max`'s NaN
/// handling.
#[inline(always)]
fn scalar_simd_max(a: f32, b: f32) -> f32 {
    match a > b {
        true => a,
        false => b,
    }
}
20
// Lane-wise floating-point addition `a + b` for every backend.
impl_op! {
    fn add<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_add_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_add_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_add_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_add_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            a + b
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vaddq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_add(a, b)
        }
    }
}
46
// Lane-wise floating-point subtraction `a - b` for every backend.
impl_op! {
    fn sub<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_sub_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_sub_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_sub_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_sub_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            a - b
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vsubq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_sub(a, b)
        }
    }
}
72
// Lane-wise floating-point multiplication `a * b` for every backend.
impl_op! {
    fn mul<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_mul_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_mul_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_mul_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_mul_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            a * b
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vmulq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_mul(a, b)
        }
    }
}
98
// Lane-wise floating-point division `a / b` for every backend.
impl_op! {
    fn div<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_div_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_div_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_div_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_div_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            a / b
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vdivq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_div(a, b)
        }
    }
}
124
// Lane-wise `a * b + c`. AVX-512/AVX2/NEON use fused multiply-add
// intrinsics; SSE, Scalar, and WASM fall back to a separate multiply and
// add, so results can differ from the fused backends by a rounding step.
impl_op! {
    fn mul_add<f32> {
        for Avx512(a: __m512, b: __m512, c: __m512) -> __m512 {
            _mm512_fmadd_ps(a, b, c)
        }
        for Avx2(a: __m256, b: __m256, c: __m256) -> __m256 {
            _mm256_fmadd_ps(a, b, c)
        }
        for Sse41(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_add_ps(_mm_mul_ps(a, b), c)
        }
        for Sse2(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_add_ps(_mm_mul_ps(a, b), c)
        }
        for Scalar(a: f32, b: f32, c: f32) -> f32 {
            a * b + c
        }
        for Neon(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
            // vfmaq_f32 computes c + a * b (accumulator first).
            vfmaq_f32(c, a, b)
        }
        for Wasm(a: v128, b: v128, c: v128) -> v128 {
            f32x4_add(f32x4_mul(a, b), c)
        }
    }
}
150
// Lane-wise `a * b - c`. Fused on AVX-512/AVX2; NEON has no direct
// fms-with-subtract form, so it negates `c - a*b` instead.
impl_op! {
    fn mul_sub<f32> {
        for Avx512(a: __m512, b: __m512, c: __m512) -> __m512 {
            _mm512_fmsub_ps(a, b, c)
        }
        for Avx2(a: __m256, b: __m256, c: __m256) -> __m256 {
            _mm256_fmsub_ps(a, b, c)
        }
        for Sse41(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_sub_ps(_mm_mul_ps(a, b), c)
        }
        for Sse2(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_sub_ps(_mm_mul_ps(a, b), c)
        }
        for Scalar(a: f32, b: f32, c: f32) -> f32 {
            a * b - c
        }
        for Neon(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
            // vfmsq_f32(c, a, b) = c - a*b; negating yields a*b - c.
            vnegq_f32(vfmsq_f32(c, a, b))
        }
        for Wasm(a: v128, b: v128, c: v128) -> v128 {
            f32x4_sub(f32x4_mul(a, b), c)
        }
    }
}
176
// Lane-wise `c - a * b` (negated multiply, then add). Fused on
// AVX-512/AVX2/NEON; plain sub+mul elsewhere.
impl_op! {
    fn neg_mul_add<f32> {
        for Avx512(a: __m512, b: __m512, c: __m512) -> __m512 {
            _mm512_fnmadd_ps(a, b, c)
        }
        for Avx2(a: __m256, b: __m256, c: __m256) -> __m256 {
            _mm256_fnmadd_ps(a, b, c)
        }
        for Sse41(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_sub_ps(c, _mm_mul_ps(a, b))
        }
        for Sse2(a: __m128, b: __m128, c: __m128) -> __m128 {
            _mm_sub_ps(c, _mm_mul_ps(a, b))
        }
        for Scalar(a: f32, b: f32, c: f32) -> f32 {
            c - a * b
        }
        for Neon(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
            // vfmsq_f32(c, a, b) = c - a*b, which is exactly this op.
            vfmsq_f32(c, a, b)
        }
        for Wasm(a: v128, b: v128, c: v128) -> v128 {
            f32x4_sub(c, f32x4_mul(a, b))
        }
    }
}
202
// Lane-wise `-(a * b) - c`. Fused on AVX-512/AVX2; SSE backends negate
// the product by subtracting it from zero; NEON negates `c + a*b`.
impl_op! {
    fn neg_mul_sub<f32> {
        for Avx512(a: __m512, b: __m512, c: __m512) -> __m512 {
            _mm512_fnmsub_ps(a, b, c)
        }
        for Avx2(a: __m256, b: __m256, c: __m256) -> __m256 {
            _mm256_fnmsub_ps(a, b, c)
        }
        for Sse41(a: __m128, b: __m128, c: __m128) -> __m128 {
            let mul = _mm_mul_ps(a, b);
            let neg = _mm_sub_ps(_mm_setzero_ps(), mul);
            _mm_sub_ps(neg, c)
        }
        for Sse2(a: __m128, b: __m128, c: __m128) -> __m128 {
            let mul = _mm_mul_ps(a, b);
            let neg = _mm_sub_ps(_mm_setzero_ps(), mul);
            _mm_sub_ps(neg, c)
        }
        for Scalar(a: f32, b: f32, c: f32) -> f32 {
            -a * b - c
        }
        for Neon(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {
            // -(c + a*b) == -a*b - c.
            vnegq_f32(vfmaq_f32(c, a, b))
        }
        for Wasm(a: v128, b: v128, c: v128) -> v128 {
            f32x4_sub(f32x4_neg(f32x4_mul(a, b)), c)
        }
    }
}
232
// Lane-wise square root for every backend.
impl_op! {
    fn sqrt<f32> {
        for Avx512(a: __m512) -> __m512 {
            _mm512_sqrt_ps(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_sqrt_ps(a)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_sqrt_ps(a)
        }
        for Sse2(a: __m128) -> __m128 {
            _mm_sqrt_ps(a)
        }
        for Scalar(a: f32) -> f32 {
            a.m_sqrt()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vsqrtq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_sqrt(a)
        }
    }
}
258
// Lane-wise reciprocal (1/a). NOTE(review): the x86 `rcp` and NEON
// `vrecpeq` intrinsics are hardware *approximations* with limited
// precision, while the Scalar and WASM paths compute an exact division —
// confirm callers tolerate this cross-backend precision difference.
impl_op! {
    fn recip<f32> {
        for Avx512(a: __m512) -> __m512 {
            _mm512_rcp14_ps(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_rcp_ps(a)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_rcp_ps(a)
        }
        for Sse2(a: __m128) -> __m128 {
            _mm_rcp_ps(a)
        }
        for Scalar(a: f32) -> f32 {
            1.0 / a
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vrecpeq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_div(f32x4_splat(1.0), a)
        }
    }
}
284
// Lane-wise reciprocal square root (1/sqrt(a)). NOTE(review): like
// `recip`, the x86 `rsqrt` and NEON `vrsqrteq` paths are hardware
// approximations; Scalar and WASM compute the exact value.
impl_op! {
    fn rsqrt<f32> {
        for Avx512(a: __m512) -> __m512 {
            _mm512_rsqrt14_ps(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_rsqrt_ps(a)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_rsqrt_ps(a)
        }
        for Sse2(a: __m128) -> __m128 {
            _mm_rsqrt_ps(a)
        }
        for Scalar(a: f32) -> f32 {
            1.0 / a.m_sqrt()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vrsqrteq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_div(f32x4_splat(1.0), f32x4_sqrt(a))
        }
    }
}
310
// Lane-wise minimum. The scalar fallback mirrors x86 `min_ps` operand
// ordering (returns `b` when `a < b` is false, including NaN/ties).
// NOTE(review): NEON `vminq_f32` and WASM `f32x4_min` handle NaN lanes
// differently from x86 — confirm cross-backend NaN behavior is acceptable.
impl_op! {
    fn min<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_min_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_min_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_min_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_min_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            scalar_simd_min(a, b)
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vminq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_min(a, b)
        }
    }
}
336
// Lane-wise maximum. Same NaN caveats as `min`: the scalar fallback
// mirrors x86 `max_ps` operand ordering; NEON/WASM NaN semantics differ.
impl_op! {
    fn max<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            _mm512_max_ps(a, b)
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_max_ps(a, b)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_max_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_max_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            scalar_simd_max(a, b)
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vmaxq_f32(a, b)
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_max(a, b)
        }
    }
}
362
// Lane-wise absolute value. On x86 this clears the sign bit by AND-NOT
// with the bit pattern of -0.0 (sign bit only).
impl_op! {
    fn abs<f32> {
        for Avx512(a: __m512) -> __m512 {
            _mm512_andnot_ps(_mm512_set1_ps(-0.0), a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_andnot_ps(_mm256_set1_ps(-0.0), a)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_andnot_ps(_mm_set1_ps(-0.0), a)
        }
        for Sse2(a: __m128) -> __m128 {
            _mm_andnot_ps(_mm_set1_ps(-0.0), a)
        }
        for Scalar(a: f32) -> f32 {
            a.m_abs()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vabsq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_abs(a)
        }
    }
}
388
// Lane-wise round-to-nearest integer (as f32). AVX-512 uses roundscale
// imm 0x08 (nearest, suppress exceptions); AVX2/SSE4.1 use round_ps with
// the equivalent flags.
impl_op! {
    fn round<f32> {
        for Avx512(a: __m512) -> __m512 {
            _mm512_roundscale_ps::<0x08>(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
        }
        for Sse2(a: __m128) -> __m128 {
            // SSE2 has no round instruction, so use the "magic number"
            // trick: adding 2^23 (0x4B000000 as an f32 bit pattern) forces
            // the fraction bits out of the mantissa under the current FP
            // rounding mode (round-to-nearest-even by default), and
            // subtracting it back recovers the rounded value. The input's
            // sign bit is OR'd into the magic constant so the trick also
            // works for negative inputs.
            let sign_mask = _mm_set1_ps(-0.0);
            let magic = _mm_castsi128_ps(_mm_set1_epi32(0x4B000000));
            let sign = _mm_and_ps(a, sign_mask);
            let signed_magic = _mm_or_ps(magic, sign);
            let b = _mm_add_ps(a, signed_magic);
            _mm_sub_ps(b, signed_magic)
        }
        for Scalar(a: f32) -> f32 {
            a.m_round()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            // NOTE(review): vrndaq_f32 rounds ties *away from zero*, while
            // the x86 paths round ties to even — confirm this difference
            // is acceptable.
            vrndaq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_nearest(a)
        }
    }
}
419
420impl_op! {
421    fn floor<f32> {
422        for Avx512(a: __m512) -> __m512 {
423            _mm512_roundscale_ps::<0x09>(a)
424        }
425        for Avx2(a: __m256) -> __m256 {
426            _mm256_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)
427        }
428        for Sse41(a: __m128) -> __m128 {
429            _mm_round_ps(a, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)
430        }
431        for Sse2(a: __m128) -> __m128 {
432            let nums_arr = core::mem::transmute::<__m128, [f32; 4]>(a);
433            let ceil = [
434                nums_arr[0].m_floor(),
435                nums_arr[1].m_floor(),
436                nums_arr[2].m_floor(),
437                nums_arr[3].m_floor(),
438            ];
439            core::mem::transmute::<[f32; 4], __m128>(ceil)
440        }
441        for Scalar(a: f32) -> f32 {
442            a.m_floor()
443        }
444        for Neon(a: float32x4_t) -> float32x4_t {
445            vrndmq_f32(a)
446        }
447        for Wasm(a: v128) -> v128 {
448            f32x4_floor(a)
449        }
450    }
451}
452
// Lane-wise ceiling (round toward positive infinity). SSE2 lacks a ceil
// instruction, so that backend ceils each lane with the scalar helper.
impl_op! {
    fn ceil<f32> {
        for Avx512(a: __m512) -> __m512 {
            _mm512_roundscale_ps::<0x0A>(a)
        }
        for Avx2(a: __m256) -> __m256 {
            _mm256_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)
        }
        for Sse41(a: __m128) -> __m128 {
            _mm_round_ps(a, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)
        }
        for Sse2(a: __m128) -> __m128 {
            // No SSE2 ceil instruction: ceil each lane in scalar code.
            let nums_arr = core::mem::transmute::<__m128, [f32; 4]>(a);
            let ceil = [
                nums_arr[0].m_ceil(),
                nums_arr[1].m_ceil(),
                nums_arr[2].m_ceil(),
                nums_arr[3].m_ceil(),
            ];
            core::mem::transmute::<[f32; 4], __m128>(ceil)
        }
        for Scalar(a: f32) -> f32 {
            a.m_ceil()
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            vrndpq_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            f32x4_ceil(a)
        }
    }
}
485
// `fast_round` for f32: no backend has a cheaper approximation here, so
// every backend simply delegates to the exact `round`.
impl_op! {
    fn fast_round<f32> {
        for Avx512(a: __m512) -> __m512 {
            Self::round(a)
        }
        for Avx2(a: __m256) -> __m256 {
            Self::round(a)
        }
        for Sse41(a: __m128) -> __m128 {
            Self::round(a)
        }
        for Sse2(a: __m128) -> __m128 {
            Self::round(a)
        }
        for Scalar(a: f32) -> f32 {
            Self::round(a)
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            Self::round(a)
        }
        for Wasm(a: v128) -> v128 {
            Self::round(a)
        }
    }
}
511
// `fast_floor` for f32: every backend delegates to the exact `floor`.
impl_op! {
    fn fast_floor<f32> {
        for Avx512(a: __m512) -> __m512 {
            Self::floor(a)
        }
        for Avx2(a: __m256) -> __m256 {
            Self::floor(a)
        }
        for Sse41(a: __m128) -> __m128 {
            Self::floor(a)
        }
        for Sse2(a: __m128) -> __m128 {
            Self::floor(a)
        }
        for Scalar(a: f32) -> f32 {
            Self::floor(a)
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            Self::floor(a)
        }
        for Wasm(a: v128) -> v128 {
            Self::floor(a)
        }
    }
}
537
// `fast_ceil` for f32: every backend delegates to the exact `ceil`.
impl_op! {
    fn fast_ceil<f32> {
        for Avx512(a: __m512) -> __m512 {
            Self::ceil(a)
        }
        for Avx2(a: __m256) -> __m256 {
            Self::ceil(a)
        }
        for Sse41(a: __m128) -> __m128 {
            Self::ceil(a)
        }
        for Sse2(a: __m128) -> __m128 {
            Self::ceil(a)
        }
        for Scalar(a: f32) -> f32 {
            Self::ceil(a)
        }
        for Neon(a: float32x4_t) -> float32x4_t {
            Self::ceil(a)
        }
        for Wasm(a: v128) -> v128 {
            Self::ceil(a)
        }
    }
}
563
// Lane-wise equality compare. Each lane yields an all-ones bit pattern
// (as f32) when true and 0.0 when false. Uses the *ordered* predicate
// (_CMP_EQ_OQ), so NaN lanes compare false. AVX-512 converts its compare
// mask register back into a full-width vector mask.
impl_op! {
    fn eq<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_EQ_OQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_EQ_OQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmpeq_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmpeq_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a == b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vceqq_f32(a, b))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_eq(a, b)
        }
    }
}
594
// Lane-wise inequality compare (mask semantics as in `eq`). Uses the
// *unordered* predicate (_CMP_NEQ_UQ), so NaN lanes compare true; the
// NEON path matches this by bitwise-NOT-ing the equality mask.
impl_op! {
    fn neq<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_NEQ_UQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_NEQ_UQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmpneq_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmpneq_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a != b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vmvnq_u32(vceqq_f32(a, b)))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_ne(a, b)
        }
    }
}
625
// Lane-wise less-than compare (ordered; NaN lanes false). Mask semantics
// as in `eq`.
impl_op! {
    fn lt<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_LT_OQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_LT_OQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmplt_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmplt_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a < b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vcltq_f32(a, b))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_lt(a, b)
        }
    }
}
656
// Lane-wise less-than-or-equal compare (ordered; NaN lanes false). Mask
// semantics as in `eq`.
impl_op! {
    fn lte<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_LE_OQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_LE_OQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmple_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmple_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a <= b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vcleq_f32(a, b))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_le(a, b)
        }
    }
}
687
// Lane-wise greater-than compare (ordered; NaN lanes false). Mask
// semantics as in `eq`.
impl_op! {
    fn gt<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_GT_OQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_GT_OQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmpgt_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmpgt_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a > b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vcgtq_f32(a, b))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_gt(a, b)
        }
    }
}
718
// Lane-wise greater-than-or-equal compare (ordered; NaN lanes false).
// Mask semantics as in `eq`.
impl_op! {
    fn gte<f32> {
        for Avx512(a: __m512, b: __m512) -> __m512 {
            let k = _mm512_cmp_ps_mask::<_CMP_GE_OQ>(a, b);
            _mm512_castsi512_ps(_mm512_movm_epi32(k))
        }
        for Avx2(a: __m256, b: __m256) -> __m256 {
            _mm256_cmp_ps(a, b, _CMP_GE_OQ)
        }
        for Sse41(a: __m128, b: __m128) -> __m128 {
            _mm_cmpge_ps(a, b)
        }
        for Sse2(a: __m128, b: __m128) -> __m128 {
            _mm_cmpge_ps(a, b)
        }
        for Scalar(a: f32, b: f32) -> f32 {
            if a >= b {
                f32::from_bits(u32::MAX)
            } else {
                0.0
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t) -> float32x4_t {
            vreinterpretq_f32_u32(vcgeq_f32(a, b))
        }
        for Wasm(a: v128, b: v128) -> v128 {
            f32x4_ge(a, b)
        }
    }
}
749
// Lane-wise select: picks `b` where `mask` is set, else `a` (the mask is
// usually the output of one of the compare ops above). NOTE(review): the
// backends test the mask differently — SSE4.1 `_mm_blendv_ps` keys off
// each lane's sign bit, the SSE2/NEON/WASM bitwise forms assume each lane
// is all-ones or all-zeros, and the scalar path treats *any* nonzero bit
// pattern as "select b". Consistent only for proper all-ones/all-zeros
// compare masks — confirm callers never pass partial masks.
impl_op! {
    fn blendv<f32> {
        for Avx512(a: __m512, b: __m512, mask: __m512) -> __m512 {
            let k = _mm512_movepi32_mask(_mm512_castps_si512(mask));
            _mm512_mask_blend_ps(k, a, b)
        }
        for Avx2(a: __m256, b: __m256, mask: __m256) -> __m256 {
            _mm256_blendv_ps(a, b, mask)
        }
        for Sse41(a: __m128, b: __m128, mask: __m128) -> __m128 {
            _mm_blendv_ps(a, b, mask)
        }
        for Sse2(a: __m128, b: __m128, mask: __m128) -> __m128 {
            // (mask & b) | (!mask & a)
            _mm_or_ps(_mm_and_ps(mask, b), _mm_andnot_ps(mask, a))
        }
        for Scalar(a: f32, b: f32, mask: f32) -> f32 {
            if mask.to_bits() == 0 {
                a
            } else {
                b
            }
        }
        for Neon(a: float32x4_t, b: float32x4_t, mask: float32x4_t) -> float32x4_t {
            vbslq_f32(vreinterpretq_u32_f32(mask), b, a)
        }
        for Wasm(a: v128, b: v128, mask: v128) -> v128 {
            // wasm v128_andnot(a, mask) computes a & !mask.
            v128_or(v128_and(mask, b), v128_andnot(a, mask))
        }
    }
}
780
// Horizontal reduction: sums all lanes of the vector into a single f32.
impl_op! {
    fn horizontal_add<f32> {
        for Avx512(a: __m512) -> f32 {
            _mm512_reduce_add_ps(a)
        }
        for Avx2(a: __m256) -> f32 {
            // Two hadd passes reduce each 128-bit half to its lane sum;
            // the halves are then extracted and added together.
            let a = _mm256_hadd_ps(a, a);
            let b = _mm256_hadd_ps(a, a);

            let first = _mm_cvtss_f32(_mm256_extractf128_ps(b, 0));
            let second = _mm_cvtss_f32(_mm256_extractf128_ps(b, 1));

            first + second
        }
        for Sse41(a: __m128) -> f32 {
            // Two hadd passes leave the total in lane 0.
            let a = _mm_hadd_ps(a, a);
            let b = _mm_hadd_ps(a, a);

            _mm_cvtss_f32(b)
        }
        for Sse2(a: __m128) -> f32 {
            // No hadd on SSE2: fold the high half onto the low half, then
            // add the remaining two lanes in scalar code.
            let t1 = _mm_movehl_ps(a, a);
            let t2 = _mm_add_ps(a, t1);
            let t3 = _mm_shuffle_ps(t2, t2, 1);
            _mm_cvtss_f32(t2) + _mm_cvtss_f32(t3)
        }
        for Scalar(a: f32) -> f32 {
            a
        }
        for Neon(a: float32x4_t) -> f32 {
            // Two pairwise-add passes leave the total in every lane.
            let a = vpaddq_f32(a, a);
            let a = vpaddq_f32(a, a);
            vgetq_lane_f32(a, 0)
        }
        for Wasm(a: v128) -> f32 {
            let l0 = f32x4_extract_lane::<0>(a);
            let l1 = f32x4_extract_lane::<1>(a);
            let l2 = f32x4_extract_lane::<2>(a);
            let l3 = f32x4_extract_lane::<3>(a);
            l0 + l1 + l2 + l3
        }
    }
}
824
// Numeric conversion f32 -> i32, rounding to the nearest integer.
// Scalar/NEON/WASM explicitly round to nearest (ties to even) before
// converting; the x86 cvtps paths round per the MXCSR rounding mode
// (nearest-even by default — NOTE(review): confirm the mode is never
// changed elsewhere).
impl_op! {
    fn cast_i32<f32> {
        for Avx512(a: __m512) -> __m512i {
            _mm512_cvtps_epi32(a)
        }
        for Avx2(a: __m256) -> __m256i {
            _mm256_cvtps_epi32(a)
        }
        for Sse41(a: __m128) -> __m128i {
            _mm_cvtps_epi32(a)
        }
        for Sse2(a: __m128) -> __m128i {
            _mm_cvtps_epi32(a)
        }
        for Scalar(a: f32) -> i32 {
            a.m_round_ties_even() as i32
        }
        for Neon(a: float32x4_t) -> int32x4_t {
            // Because other intrinsics round instead of flooring, we round here first.
            let a = vrndnq_f32(a);
            vcvtq_s32_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            // Round to nearest first, then saturating-truncate to i32.
            let a = f32x4_nearest(a);
            i32x4_trunc_sat_f32x4(a)
        }
    }
}
853
// Bit-level reinterpretation f32 -> i32 (no numeric conversion). WASM's
// v128 is untyped, so the value passes through unchanged there.
impl_op! {
    fn bitcast_i32<f32> {
        for Avx512(a: __m512) -> __m512i {
            _mm512_castps_si512(a)
        }
        for Avx2(a: __m256) -> __m256i {
            _mm256_castps_si256(a)
        }
        for Sse41(a: __m128) -> __m128i {
            _mm_castps_si128(a)
        }
        for Sse2(a: __m128) -> __m128i {
            _mm_castps_si128(a)
        }
        for Scalar(a: f32) -> i32 {
            a.to_bits() as i32
        }
        for Neon(a: float32x4_t) -> int32x4_t {
            vreinterpretq_s32_f32(a)
        }
        for Wasm(a: v128) -> v128 {
            a
        }
    }
}
879
// Constructs a vector (or scalar) with every lane set to 0.0.
impl_op! {
    fn zeroes<f32> {
        for Avx512() -> __m512 {
            _mm512_setzero_ps()
        }
        for Avx2() -> __m256 {
            _mm256_setzero_ps()
        }
        for Sse41() -> __m128 {
            _mm_setzero_ps()
        }
        for Sse2() -> __m128 {
            _mm_setzero_ps()
        }
        for Scalar() -> f32 {
            0.0
        }
        for Neon() -> float32x4_t {
            vdupq_n_f32(0.0)
        }
        for Wasm() -> v128 {
            f32x4_splat(0.0)
        }
    }
}
905
// Broadcast: constructs a vector with every lane set to `val`.
impl_op! {
    fn set1<f32> {
        for Avx512(val: f32) -> __m512 {
            _mm512_set1_ps(val)
        }
        for Avx2(val: f32) -> __m256 {
            _mm256_set1_ps(val)
        }
        for Sse41(val: f32) -> __m128 {
            _mm_set1_ps(val)
        }
        for Sse2(val: f32) -> __m128 {
            _mm_set1_ps(val)
        }
        for Scalar(val: f32) -> f32 {
            val
        }
        for Neon(val: f32) -> float32x4_t {
            vdupq_n_f32(val)
        }
        for Wasm(val: f32) -> v128 {
            f32x4_splat(val)
        }
    }
}
931
// Loads a full vector from `ptr` with no alignment requirement beyond
// the element type. The pointer must be valid for the vector's width.
impl_op! {
    fn load_unaligned<f32> {
        for Avx512(ptr: *const f32) -> __m512 {
            _mm512_loadu_ps(ptr)
        }
        for Avx2(ptr: *const f32) -> __m256 {
            _mm256_loadu_ps(ptr)
        }
        for Sse41(ptr: *const f32) -> __m128 {
            _mm_loadu_ps(ptr)
        }
        for Sse2(ptr: *const f32) -> __m128 {
            _mm_loadu_ps(ptr)
        }
        for Scalar(ptr: *const f32) -> f32 {
            unsafe { *ptr }
        }
        for Neon(ptr: *const f32) -> float32x4_t {
            vld1q_f32(ptr)
        }
        for Wasm(ptr: *const f32) -> v128 {
            // v128_load performs an unaligned load in wasm SIMD.
            unsafe { v128_load(ptr as *const v128) }
        }
    }
}
957
// Loads a full vector from `ptr`, which the x86 `load_ps` intrinsics
// require to be aligned to the vector width. NOTE(review): the WASM path
// dereferences `ptr as *const v128`, which requires v128 alignment —
// confirm callers guarantee it; NEON reuses the same vld1q as the
// unaligned path.
impl_op! {
    fn load_aligned<f32> {
        for Avx512(ptr: *const f32) -> __m512 {
            _mm512_load_ps(ptr)
        }
        for Avx2(ptr: *const f32) -> __m256 {
            _mm256_load_ps(ptr)
        }
        for Sse41(ptr: *const f32) -> __m128 {
            _mm_load_ps(ptr)
        }
        for Sse2(ptr: *const f32) -> __m128 {
            _mm_load_ps(ptr)
        }
        for Scalar(ptr: *const f32) -> f32 {
            unsafe { *ptr }
        }
        for Neon(ptr: *const f32) -> float32x4_t {
            vld1q_f32(ptr)
        }
        for Wasm(ptr: *const f32) -> v128 {
            *(ptr as *const v128)
        }
    }
}
983
// Stores a full vector to `ptr` with no alignment requirement beyond the
// element type. The pointer must be valid for the vector's width.
impl_op! {
    fn store_unaligned<f32> {
        for Avx512(ptr: *mut f32, a: __m512) {
            _mm512_storeu_ps(ptr, a)
        }
        for Avx2(ptr: *mut f32, a: __m256) {
            _mm256_storeu_ps(ptr, a)
        }
        for Sse41(ptr: *mut f32, a: __m128) {
            _mm_storeu_ps(ptr, a)
        }
        for Sse2(ptr: *mut f32, a: __m128) {
            _mm_storeu_ps(ptr, a)
        }
        for Scalar(ptr: *mut f32, a: f32) {
            unsafe { *ptr = a }
        }
        for Neon(ptr: *mut f32, a: float32x4_t) {
            vst1q_f32(ptr, a)
        }
        for Wasm(ptr: *mut f32, a: v128) {
            // v128_store performs an unaligned store in wasm SIMD.
            unsafe { v128_store(ptr as *mut v128, a) }
        }
    }
}
1009
// Stores a full vector to `ptr`, which the x86 `store_ps` intrinsics
// require to be aligned to the vector width. NOTE(review): the WASM path
// writes through `ptr as *mut v128`, which requires v128 alignment —
// confirm callers guarantee it; NEON reuses the same vst1q as the
// unaligned path.
impl_op! {
    fn store_aligned<f32> {
        for Avx512(ptr: *mut f32, a: __m512) {
            _mm512_store_ps(ptr, a)
        }
        for Avx2(ptr: *mut f32, a: __m256) {
            _mm256_store_ps(ptr, a)
        }
        for Sse41(ptr: *mut f32, a: __m128) {
            _mm_store_ps(ptr, a)
        }
        for Sse2(ptr: *mut f32, a: __m128) {
            _mm_store_ps(ptr, a)
        }
        for Scalar(ptr: *mut f32, a: f32) {
            unsafe { *ptr = a }
        }
        for Neon(ptr: *mut f32, a: float32x4_t) {
            vst1q_f32(ptr, a)
        }
        for Wasm(ptr: *mut f32, a: v128) {
            *(ptr as *mut v128) = a;
        }
    }
}