lance_linalg/simd/f32.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! `f32x8` and `f32x16`: SIMD vectors of 8 and 16 `f32` values.
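//!
//! A minimal usage sketch (ignored doctest; it assumes this module is exported
//! as `lance_linalg::simd::f32` and that the `SIMD` trait is in scope):
//!
//! ```ignore
//! use lance_linalg::simd::{f32::f32x8, SIMD};
//!
//! let a = f32x8::from(&[1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
//! let b = f32x8::splat(2.0);
//! assert_eq!((a + b).reduce_sum(), 52.0); // (1 + .. + 8) + 8 * 2.0
//! ```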

use std::fmt::Formatter;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "loongarch64")]
use std::arch::loongarch64::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "loongarch64")]
use std::mem::transmute;
use std::ops::{Add, AddAssign, Mul, Sub, SubAssign};

use super::{FloatSimd, SIMD};

/// Eight 32-bit `f32` values, using 256-bit SIMD where available.
#[allow(non_camel_case_types)]
#[cfg(target_arch = "x86_64")]
#[derive(Clone, Copy)]
pub struct f32x8(std::arch::x86_64::__m256);

/// Eight 32-bit `f32` values, using 256-bit SIMD where available.
#[allow(non_camel_case_types)]
#[cfg(target_arch = "aarch64")]
#[derive(Clone, Copy)]
pub struct f32x8(float32x4x2_t);

/// Eight 32-bit `f32` values, using 256-bit SIMD where available.
#[allow(non_camel_case_types)]
#[cfg(target_arch = "loongarch64")]
#[derive(Clone, Copy)]
pub struct f32x8(v8f32);

impl std::fmt::Debug for f32x8 {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut arr = [0.0_f32; 8];
        unsafe {
            self.store_unaligned(arr.as_mut_ptr());
        }
        write!(f, "f32x8({:?})", arr)
    }
}

impl f32x8 {
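    /// Gather eight `f32` values from `slice` at the given element `indices`.
    ///
    /// The indices are element offsets into `slice` and are not bounds-checked,
    /// so callers must keep every index within the slice. A usage sketch
    /// (ignored doctest; crate paths as assumed in the module docs):
    ///
    /// ```ignore
    /// let data: Vec<f32> = (0..32).map(|i| i as f32).collect();
    /// let v = f32x8::gather(&data, &[0, 2, 4, 6, 8, 10, 12, 14]);
    /// assert_eq!(v.reduce_sum(), 56.0);
    /// ```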
    #[inline]
    pub fn gather(slice: &[f32], indices: &[i32; 8]) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            use super::i32::i32x8;

            let idx = i32x8::from(indices);
            Self(_mm256_i32gather_ps::<4>(slice.as_ptr(), idx.0))
        }

        #[cfg(target_arch = "aarch64")]
        unsafe {
            // aarch64 (NEON) has no gather instruction; fall back to scalar loads.
            let ptr = slice.as_ptr();

            let values = [
                *ptr.add(indices[0] as usize),
                *ptr.add(indices[1] as usize),
                *ptr.add(indices[2] as usize),
                *ptr.add(indices[3] as usize),
                *ptr.add(indices[4] as usize),
                *ptr.add(indices[5] as usize),
                *ptr.add(indices[6] as usize),
                *ptr.add(indices[7] as usize),
            ];
            Self::load_unaligned(values.as_ptr())
        }

        #[cfg(target_arch = "loongarch64")]
        unsafe {
            // loongarch64 (LASX) has no gather instruction; fall back to scalar loads.
            let ptr = slice.as_ptr();

            let values = [
                *ptr.add(indices[0] as usize),
                *ptr.add(indices[1] as usize),
                *ptr.add(indices[2] as usize),
                *ptr.add(indices[3] as usize),
                *ptr.add(indices[4] as usize),
                *ptr.add(indices[5] as usize),
                *ptr.add(indices[6] as usize),
                *ptr.add(indices[7] as usize),
            ];
            Self::load_unaligned(values.as_ptr())
        }
    }
}

impl From<&[f32]> for f32x8 {
    fn from(value: &[f32]) -> Self {
        unsafe { Self::load_unaligned(value.as_ptr()) }
    }
}

impl<'a> From<&'a [f32; 8]> for f32x8 {
    fn from(value: &'a [f32; 8]) -> Self {
        unsafe { Self::load_unaligned(value.as_ptr()) }
    }
}

impl SIMD<f32, 8> for f32x8 {
    fn splat(val: f32) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_set1_ps(val))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x2_t(vdupq_n_f32(val), vdupq_n_f32(val)))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(transmute(lasx_xvreplgr2vr_w(transmute(val))))
        }
    }

    fn zeros() -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_setzero_ps())
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self::splat(0.0)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self::splat(0.0)
        }
    }

    #[inline]
    unsafe fn load(ptr: *const f32) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_load_ps(ptr))
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self::load_unaligned(ptr)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self(transmute(lasx_xvld::<0>(transmute(ptr))))
        }
    }

    #[inline]
    unsafe fn load_unaligned(ptr: *const f32) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_loadu_ps(ptr))
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self(vld1q_f32_x2(ptr))
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self(transmute(lasx_xvld::<0>(transmute(ptr))))
        }
    }

    unsafe fn store(&self, ptr: *mut f32) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            _mm256_store_ps(ptr, self.0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            vst1q_f32_x2(ptr, self.0);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            lasx_xvst::<0>(transmute(self.0), transmute(ptr));
        }
    }

    unsafe fn store_unaligned(&self, ptr: *mut f32) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            _mm256_storeu_ps(ptr, self.0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            vst1q_f32_x2(ptr, self.0);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            lasx_xvst::<0>(transmute(self.0), transmute(ptr));
        }
    }

    #[inline]
    fn reduce_sum(&self) -> f32 {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            let mut sum = self.0;
            // Shift and add vector, until only 1 value left.
            // sums = [x0-x7], shift = [x4-x7]
            let mut shift = _mm256_permute2f128_ps(sum, sum, 1);
            // [x0+x4, x1+x5, ..]
            sum = _mm256_add_ps(sum, shift);
            shift = _mm256_permute_ps(sum, 14);
            sum = _mm256_add_ps(sum, shift);
            sum = _mm256_hadd_ps(sum, sum);
            let mut results: [f32; 8] = [0f32; 8];
            _mm256_storeu_ps(results.as_mut_ptr(), sum);
            results[0]
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            let sum = vaddq_f32(self.0 .0, self.0 .1);
            vaddvq_f32(sum)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            self.as_array().iter().sum()
        }
    }

    fn reduce_min(&self) -> f32 {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe {
                let mut min = self.0;
                // Shift and take the element-wise minimum, until only 1 value is left.
                // min = [x0..x7], shift = [x4..x7, x0..x3]
                let mut shift = _mm256_permute2f128_ps(min, min, 1);
                // [min(x0, x4), min(x1, x5), ..]
                min = _mm256_min_ps(min, shift);
                shift = _mm256_permute_ps(min, 14);
                min = _mm256_min_ps(min, shift);
                shift = _mm256_permute_ps(min, 1);
                min = _mm256_min_ps(min, shift);
                _mm256_cvtss_f32(min)
            }
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            let m = vminq_f32(self.0 .0, self.0 .1);
            vminvq_f32(m)
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            // Same shift-and-min reduction as the x86_64 path, expressed with LASX permutes.
            let m1 = lasx_xvpermi_d::<14>(transmute(self.0));
            let m2 = lasx_xvfmin_s(transmute(m1), self.0);
            let m1 = lasx_xvpermi_w::<14>(transmute(m2), transmute(m2));
            let m2 = lasx_xvfmin_s(transmute(m1), transmute(m2));
            let m1 = lasx_xvpermi_w::<1>(transmute(m2), transmute(m2));
            let m2 = lasx_xvfmin_s(transmute(m1), transmute(m2));
            transmute(lasx_xvpickve2gr_w::<0>(transmute(m2)))
        }
    }

    fn min(&self, rhs: &Self) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_min_ps(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x2_t(
                vminq_f32(self.0 .0, rhs.0 .0),
                vminq_f32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfmin_s(self.0, rhs.0))
        }
    }

    fn find(&self, val: f32) -> Option<i32> {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            for i in 0..8 {
                if self.as_array().get_unchecked(i) == &val {
                    return Some(i as i32);
                }
            }
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            let tgt = vdupq_n_f32(val);
            let mut arr = [0; 8];
            let mask1 = vceqq_f32(self.0 .0, tgt);
            let mask2 = vceqq_f32(self.0 .1, tgt);
            vst1q_u32(arr.as_mut_ptr(), mask1);
            vst1q_u32(arr.as_mut_ptr().add(4), mask2);
            for i in 0..8 {
                if arr.get_unchecked(i) != &0 {
                    return Some(i as i32);
                }
            }
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            for i in 0..8 {
                if self.as_array().get_unchecked(i) == &val {
                    return Some(i as i32);
                }
            }
        }
        None
    }
}

impl FloatSimd<f32, 8> for f32x8 {
    fn multiply_add(&mut self, a: Self, b: Self) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            self.0 = _mm256_fmadd_ps(a.0, b.0, self.0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vfmaq_f32(self.0 .0, a.0 .0, b.0 .0);
            self.0 .1 = vfmaq_f32(self.0 .1, a.0 .1, b.0 .1);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfmadd_s(a.0, b.0, self.0);
        }
    }
}

impl Add for f32x8 {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_add_ps(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x2_t(
                vaddq_f32(self.0 .0, rhs.0 .0),
                vaddq_f32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfadd_s(self.0, rhs.0))
        }
    }
}

impl AddAssign for f32x8 {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            self.0 = _mm256_add_ps(self.0, rhs.0)
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vaddq_f32(self.0 .0, rhs.0 .0);
            self.0 .1 = vaddq_f32(self.0 .1, rhs.0 .1);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfadd_s(self.0, rhs.0);
        }
    }
}

impl Sub for f32x8 {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: Self) -> Self::Output {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_sub_ps(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x2_t(
                vsubq_f32(self.0 .0, rhs.0 .0),
                vsubq_f32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfsub_s(self.0, rhs.0))
        }
    }
}

impl SubAssign for f32x8 {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            self.0 = _mm256_sub_ps(self.0, rhs.0)
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vsubq_f32(self.0 .0, rhs.0 .0);
            self.0 .1 = vsubq_f32(self.0 .1, rhs.0 .1);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfsub_s(self.0, rhs.0);
        }
    }
}

impl Mul for f32x8 {
    type Output = Self;

    #[inline]
    fn mul(self, rhs: Self) -> Self::Output {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_mul_ps(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x2_t(
                vmulq_f32(self.0 .0, rhs.0 .0),
                vmulq_f32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfmul_s(self.0, rhs.0))
        }
    }
}

/// Sixteen 32-bit `f32` values, using 512-bit SIMD where available.
#[allow(non_camel_case_types)]
#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
#[derive(Clone, Copy)]
pub struct f32x16(__m256, __m256);

/// Sixteen 32-bit `f32` values, using 512-bit SIMD where available.
#[allow(non_camel_case_types)]
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[derive(Clone, Copy)]
pub struct f32x16(__m512);

/// Sixteen 32-bit `f32` values, using 512-bit SIMD where available.
#[allow(non_camel_case_types)]
#[cfg(target_arch = "aarch64")]
#[derive(Clone, Copy)]
pub struct f32x16(float32x4x4_t);

/// Sixteen 32-bit `f32` values, stored in two 256-bit LASX registers.
#[allow(non_camel_case_types)]
#[cfg(target_arch = "loongarch64")]
#[derive(Clone, Copy)]
pub struct f32x16(v8f32, v8f32);

impl std::fmt::Debug for f32x16 {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut arr = [0.0_f32; 16];
        unsafe {
            self.store_unaligned(arr.as_mut_ptr());
        }
        write!(f, "f32x16({:?})", arr)
    }
}

impl From<&[f32]> for f32x16 {
    fn from(value: &[f32]) -> Self {
        unsafe { Self::load_unaligned(value.as_ptr()) }
    }
}

impl<'a> From<&'a [f32; 16]> for f32x16 {
    fn from(value: &'a [f32; 16]) -> Self {
        unsafe { Self::load_unaligned(value.as_ptr()) }
    }
}

impl SIMD<f32, 16> for f32x16 {
    #[inline]
    fn splat(val: f32) -> Self {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_set1_ps(val))
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_set1_ps(val), _mm256_set1_ps(val))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x4_t(
                vdupq_n_f32(val),
                vdupq_n_f32(val),
                vdupq_n_f32(val),
                vdupq_n_f32(val),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(
                transmute(lasx_xvreplgr2vr_w(transmute(val))),
                transmute(lasx_xvreplgr2vr_w(transmute(val))),
            )
        }
    }

    #[inline]
    fn zeros() -> Self {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_setzero_ps())
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_setzero_ps(), _mm256_setzero_ps())
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self::splat(0.0)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self::splat(0.0)
        }
    }

    #[inline]
    unsafe fn load(ptr: *const f32) -> Self {
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_load_ps(ptr), _mm256_load_ps(ptr.add(8)))
        }
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_load_ps(ptr))
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self::load_unaligned(ptr)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self(
                transmute(lasx_xvld::<0>(transmute(ptr))),
                transmute(lasx_xvld::<32>(transmute(ptr))),
            )
        }
    }

    #[inline]
    unsafe fn load_unaligned(ptr: *const f32) -> Self {
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_loadu_ps(ptr), _mm256_loadu_ps(ptr.add(8)))
        }
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_loadu_ps(ptr))
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self(vld1q_f32_x4(ptr))
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self(
                transmute(lasx_xvld::<0>(transmute(ptr))),
                transmute(lasx_xvld::<32>(transmute(ptr))),
            )
        }
    }

    #[inline]
    unsafe fn store(&self, ptr: *mut f32) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            _mm512_store_ps(ptr, self.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            _mm256_store_ps(ptr, self.0);
            _mm256_store_ps(ptr.add(8), self.1);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            vst1q_f32_x4(ptr, self.0);
        }
        #[cfg(target_arch = "loongarch64")]
        {
            lasx_xvst::<0>(transmute(self.0), transmute(ptr));
            lasx_xvst::<32>(transmute(self.1), transmute(ptr));
        }
    }

    #[inline]
    unsafe fn store_unaligned(&self, ptr: *mut f32) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            _mm512_storeu_ps(ptr, self.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            _mm256_storeu_ps(ptr, self.0);
            _mm256_storeu_ps(ptr.add(8), self.1);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            vst1q_f32_x4(ptr, self.0);
        }
        #[cfg(target_arch = "loongarch64")]
        {
            lasx_xvst::<0>(transmute(self.0), transmute(ptr));
            lasx_xvst::<32>(transmute(self.1), transmute(ptr));
        }
    }

    fn reduce_sum(&self) -> f32 {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            _mm512_mask_reduce_add_ps(0xFFFF, self.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            let mut sum = _mm256_add_ps(self.0, self.1);
            // Shift and add vector, until only 1 value left.
            // sums = [x0-x7], shift = [x4-x7]
            let mut shift = _mm256_permute2f128_ps(sum, sum, 1);
            // [x0+x4, x1+x5, ..]
            sum = _mm256_add_ps(sum, shift);
            shift = _mm256_permute_ps(sum, 14);
            sum = _mm256_add_ps(sum, shift);
            sum = _mm256_hadd_ps(sum, sum);
            let mut results: [f32; 8] = [0f32; 8];
            _mm256_storeu_ps(results.as_mut_ptr(), sum);
            results[0]
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            let mut sum1 = vaddq_f32(self.0 .0, self.0 .1);
            let sum2 = vaddq_f32(self.0 .2, self.0 .3);
            sum1 = vaddq_f32(sum1, sum2);
            vaddvq_f32(sum1)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            self.as_array().iter().sum()
        }
    }

    #[inline]
    fn reduce_min(&self) -> f32 {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            _mm512_mask_reduce_min_ps(0xFFFF, self.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            // Min of the two halves, then shift-and-min until one value is left.
            let mut m1 = _mm256_min_ps(self.0, self.1);
            let mut m2 = _mm256_permute2f128_ps(m1, m1, 1);
            m1 = _mm256_min_ps(m1, m2);
            m2 = _mm256_permute_ps(m1, 14);
            m1 = _mm256_min_ps(m1, m2);
            m2 = _mm256_permute_ps(m1, 1);
            m1 = _mm256_min_ps(m1, m2);
            _mm256_cvtss_f32(m1)
        }

        #[cfg(target_arch = "aarch64")]
        unsafe {
            let m1 = vminq_f32(self.0 .0, self.0 .1);
            let m2 = vminq_f32(self.0 .2, self.0 .3);
            let m = vminq_f32(m1, m2);
            vminvq_f32(m)
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            // Min of the two halves, then the same shift-and-min reduction via LASX permutes.
            let m1 = lasx_xvfmin_s(self.0, self.1);
            let m2 = lasx_xvpermi_d::<14>(transmute(m1));
            let m1 = lasx_xvfmin_s(transmute(m1), transmute(m2));
            let m2 = lasx_xvpermi_w::<14>(transmute(m1), transmute(m1));
            let m1 = lasx_xvfmin_s(transmute(m1), transmute(m2));
            let m2 = lasx_xvpermi_w::<1>(transmute(m1), transmute(m1));
            let m1 = lasx_xvfmin_s(transmute(m1), transmute(m2));
            transmute(lasx_xvpickve2gr_w::<0>(transmute(m1)))
        }
    }

    #[inline]
    fn min(&self, rhs: &Self) -> Self {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_min_ps(self.0, rhs.0))
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_min_ps(self.0, rhs.0), _mm256_min_ps(self.1, rhs.1))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x4_t(
                vminq_f32(self.0 .0, rhs.0 .0),
                vminq_f32(self.0 .1, rhs.0 .1),
                vminq_f32(self.0 .2, rhs.0 .2),
                vminq_f32(self.0 .3, rhs.0 .3),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfmin_s(self.0, rhs.0), lasx_xvfmin_s(self.1, rhs.1))
        }
    }

    fn find(&self, val: f32) -> Option<i32> {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            // let tgt = _mm512_set1_ps(val);
            // let mask = _mm512_cmpeq_ps_mask(self.0, tgt);
            // if mask != 0 {
            //     return Some(mask.trailing_zeros() as i32);
            // }
            todo!()
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            // `_mm256_cmpeq_ps_mask` requires "avx512vl", so scan the lanes instead.
            for i in 0..16 {
                if self.as_array().get_unchecked(i) == &val {
                    return Some(i as i32);
                }
            }
            None
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            let tgt = vdupq_n_f32(val);
            let mut arr = [0; 16];
            let mask1 = vceqq_f32(self.0 .0, tgt);
            let mask2 = vceqq_f32(self.0 .1, tgt);
            let mask3 = vceqq_f32(self.0 .2, tgt);
            let mask4 = vceqq_f32(self.0 .3, tgt);

            vst1q_u32(arr.as_mut_ptr(), mask1);
            vst1q_u32(arr.as_mut_ptr().add(4), mask2);
            vst1q_u32(arr.as_mut_ptr().add(8), mask3);
            vst1q_u32(arr.as_mut_ptr().add(12), mask4);

            for i in 0..16 {
                if arr.get_unchecked(i) != &0 {
                    return Some(i as i32);
                }
            }
            None
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            for i in 0..16 {
                if self.as_array().get_unchecked(i) == &val {
                    return Some(i as i32);
                }
            }
            None
        }
    }
}

impl FloatSimd<f32, 16> for f32x16 {
    #[inline]
    fn multiply_add(&mut self, a: Self, b: Self) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            self.0 = _mm512_fmadd_ps(a.0, b.0, self.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            self.0 = _mm256_fmadd_ps(a.0, b.0, self.0);
            self.1 = _mm256_fmadd_ps(a.1, b.1, self.1);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vfmaq_f32(self.0 .0, a.0 .0, b.0 .0);
            self.0 .1 = vfmaq_f32(self.0 .1, a.0 .1, b.0 .1);
            self.0 .2 = vfmaq_f32(self.0 .2, a.0 .2, b.0 .2);
            self.0 .3 = vfmaq_f32(self.0 .3, a.0 .3, b.0 .3);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfmadd_s(a.0, b.0, self.0);
            self.1 = lasx_xvfmadd_s(a.1, b.1, self.1);
        }
    }
}

impl Add for f32x16 {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_add_ps(self.0, rhs.0))
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_add_ps(self.0, rhs.0), _mm256_add_ps(self.1, rhs.1))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x4_t(
                vaddq_f32(self.0 .0, rhs.0 .0),
                vaddq_f32(self.0 .1, rhs.0 .1),
                vaddq_f32(self.0 .2, rhs.0 .2),
                vaddq_f32(self.0 .3, rhs.0 .3),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfadd_s(self.0, rhs.0), lasx_xvfadd_s(self.1, rhs.1))
        }
    }
}

impl AddAssign for f32x16 {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            self.0 = _mm512_add_ps(self.0, rhs.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            self.0 = _mm256_add_ps(self.0, rhs.0);
            self.1 = _mm256_add_ps(self.1, rhs.1);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vaddq_f32(self.0 .0, rhs.0 .0);
            self.0 .1 = vaddq_f32(self.0 .1, rhs.0 .1);
            self.0 .2 = vaddq_f32(self.0 .2, rhs.0 .2);
            self.0 .3 = vaddq_f32(self.0 .3, rhs.0 .3);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfadd_s(self.0, rhs.0);
            self.1 = lasx_xvfadd_s(self.1, rhs.1);
        }
    }
}

impl Mul for f32x16 {
    type Output = Self;

    #[inline]
    fn mul(self, rhs: Self) -> Self::Output {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_mul_ps(self.0, rhs.0))
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_mul_ps(self.0, rhs.0), _mm256_mul_ps(self.1, rhs.1))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x4_t(
                vmulq_f32(self.0 .0, rhs.0 .0),
                vmulq_f32(self.0 .1, rhs.0 .1),
                vmulq_f32(self.0 .2, rhs.0 .2),
                vmulq_f32(self.0 .3, rhs.0 .3),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfmul_s(self.0, rhs.0), lasx_xvfmul_s(self.1, rhs.1))
        }
    }
}

impl Sub for f32x16 {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: Self) -> Self::Output {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_sub_ps(self.0, rhs.0))
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_sub_ps(self.0, rhs.0), _mm256_sub_ps(self.1, rhs.1))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x4_t(
                vsubq_f32(self.0 .0, rhs.0 .0),
                vsubq_f32(self.0 .1, rhs.0 .1),
                vsubq_f32(self.0 .2, rhs.0 .2),
                vsubq_f32(self.0 .3, rhs.0 .3),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfsub_s(self.0, rhs.0), lasx_xvfsub_s(self.1, rhs.1))
        }
    }
}

impl SubAssign for f32x16 {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            self.0 = _mm512_sub_ps(self.0, rhs.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            self.0 = _mm256_sub_ps(self.0, rhs.0);
            self.1 = _mm256_sub_ps(self.1, rhs.1);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vsubq_f32(self.0 .0, rhs.0 .0);
            self.0 .1 = vsubq_f32(self.0 .1, rhs.0 .1);
            self.0 .2 = vsubq_f32(self.0 .2, rhs.0 .2);
            self.0 .3 = vsubq_f32(self.0 .3, rhs.0 .3);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfsub_s(self.0, rhs.0);
            self.1 = lasx_xvfsub_s(self.1, rhs.1);
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_basic_ops() {
        let a = (0..8).map(|f| f as f32).collect::<Vec<_>>();
        let b = (10..18).map(|f| f as f32).collect::<Vec<_>>();

        let mut simd_a = unsafe { f32x8::load_unaligned(a.as_ptr()) };
        let simd_b = unsafe { f32x8::load_unaligned(b.as_ptr()) };

        let simd_add = simd_a + simd_b;
        assert!((0..8)
            .zip(simd_add.as_array().iter())
            .all(|(x, &y)| (x + x + 10) as f32 == y));

        let simd_mul = simd_a * simd_b;
        assert!((0..8)
            .zip(simd_mul.as_array().iter())
            .all(|(x, &y)| (x * (x + 10)) as f32 == y));

        let simd_sub = simd_b - simd_a;
        assert!(simd_sub.as_array().iter().all(|&v| v == 10.0));

        simd_a -= simd_b;
        assert_eq!(simd_a.reduce_sum(), -80.0);

        let mut simd_power = f32x8::splat(0.0);
        simd_power.multiply_add(simd_a, simd_a);

        assert_eq!(
            "f32x8([100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0])",
            format!("{:?}", simd_power)
        );
    }

    #[test]
    fn test_f32x8_cmp_ops() {
        let a = [1.0_f32, 2.0, 5.0, 6.0, 7.0, 3.0, 2.0, 1.0];
        let b = [2.0_f32, 1.0, 4.0, 5.0, 9.0, 5.0, 6.0, 2.0];
        let c = [2.0_f32, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0];
        let simd_a: f32x8 = (&a).into();
        let simd_b: f32x8 = (&b).into();
        let simd_c: f32x8 = (&c).into();

        let min_simd = simd_a.min(&simd_b);
        assert_eq!(
            min_simd.as_array(),
            [1.0, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0]
        );
        let min_val = min_simd.reduce_min();
        assert_eq!(min_val, 1.0);
        let min_val = simd_c.reduce_min();
        assert_eq!(min_val, 1.0);

        assert_eq!(Some(2), simd_a.find(5.0));
        assert_eq!(Some(1), simd_a.find(2.0));
        assert_eq!(None, simd_a.find(-200.0));
    }

    #[test]
    fn test_basic_f32x16_ops() {
        let a = (0..16).map(|f| f as f32).collect::<Vec<_>>();
        let b = (10..26).map(|f| f as f32).collect::<Vec<_>>();

        let mut simd_a = unsafe { f32x16::load_unaligned(a.as_ptr()) };
        let simd_b = unsafe { f32x16::load_unaligned(b.as_ptr()) };

        let simd_add = simd_a + simd_b;
        assert!((0..16)
            .zip(simd_add.as_array().iter())
            .all(|(x, &y)| (x + x + 10) as f32 == y));

        let simd_mul = simd_a * simd_b;
        assert!((0..16)
            .zip(simd_mul.as_array().iter())
            .all(|(x, &y)| (x * (x + 10)) as f32 == y));

        simd_a -= simd_b;
        assert_eq!(simd_a.reduce_sum(), -160.0);

        let mut simd_power = f32x16::zeros();
        simd_power.multiply_add(simd_a, simd_a);

        assert_eq!(
            format!("f32x16({:?})", [100.0; 16]),
            format!("{:?}", simd_power)
        );
    }

    #[test]
    fn test_f32x16_cmp_ops() {
        let a = [
            1.0_f32, 2.0, 5.0, 6.0, 7.0, 3.0, 2.0, 1.0, -0.5, 5.0, 6.0, 7.0, 8.0, 9.0, 1.0, 2.0,
        ];
        let b = [
            2.0_f32, 1.0, 4.0, 5.0, 9.0, 5.0, 6.0, 2.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 1.0,
        ];
        let c = [
            1.0_f32, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0, -0.5, 5.0, 6.0, 7.0, 8.0, 9.0, 1.0, -1.0,
        ];
        let simd_a: f32x16 = (&a).into();
        let simd_b: f32x16 = (&b).into();
        let simd_c: f32x16 = (&c).into();

        let min_simd = simd_a.min(&simd_b);
        assert_eq!(
            min_simd.as_array(),
            [1.0, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0, -0.5, 5.0, 6.0, 7.0, 8.0, 9.0, 1.0, 1.0]
        );
        let min_val = min_simd.reduce_min();
        assert_eq!(min_val, -0.5);
        let min_val = simd_c.reduce_min();
        assert_eq!(min_val, -1.0);

        assert_eq!(Some(2), simd_a.find(5.0));
        assert_eq!(Some(1), simd_a.find(2.0));
        assert_eq!(Some(13), simd_a.find(9.0));
        assert_eq!(None, simd_a.find(-200.0));
    }

    #[test]
    fn test_f32x8_gather() {
        let a = (0..256).map(|f| f as f32).collect::<Vec<_>>();
        let idx = [0_i32, 4, 8, 12, 16, 20, 24, 29];
        let v = f32x8::gather(&a, &idx);
        assert_eq!(v.reduce_sum(), 113.0);
    }
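
    // Sanity check: `splat` broadcasts one value to every lane, and `zeros()`
    // reduces to 0.0. Uses only APIs exercised by the tests above.
    #[test]
    fn test_f32x8_splat_zeros() {
        let s = f32x8::splat(2.5);
        assert!(s.as_array().iter().all(|&v| v == 2.5));
        assert_eq!(f32x8::zeros().reduce_sum(), 0.0);
    }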
}