use std::fmt::Formatter;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
#[cfg(target_arch = "loongarch64")]
use std::arch::loongarch64::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
#[cfg(target_arch = "loongarch64")]
use std::mem::transmute;
use std::ops::{Add, AddAssign, Mul, Sub, SubAssign};

use super::{FloatSimd, SIMD};
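// `f32x8` is a portable wrapper over one 256-bit vector of eight `f32` lanes.
// The backing storage differs per target: a single AVX `__m256` on x86_64, a
// pair of NEON registers (`float32x4x2_t`) on aarch64, and a single LASX
// `v8f32` on loongarch64. Every operation below dispatches on `target_arch`,
// so callers see one uniform type.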
#[allow(non_camel_case_types)]
#[cfg(target_arch = "x86_64")]
#[derive(Clone, Copy)]
pub struct f32x8(std::arch::x86_64::__m256);

#[allow(non_camel_case_types)]
#[cfg(target_arch = "aarch64")]
#[derive(Clone, Copy)]
pub struct f32x8(float32x4x2_t);

#[allow(non_camel_case_types)]
#[cfg(target_arch = "loongarch64")]
#[derive(Clone, Copy)]
pub struct f32x8(v8f32);

impl std::fmt::Debug for f32x8 {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut arr = [0.0_f32; 8];
        unsafe {
            self.store_unaligned(arr.as_mut_ptr());
        }
        write!(f, "f32x8({:?})", arr)
    }
}

impl f32x8 {
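    /// Gather eight lanes from `slice`, reading `slice[indices[i]]` into lane `i`.
    ///
    /// On x86_64 this maps to the AVX2 `_mm256_i32gather_ps` instruction with a
    /// 4-byte scale (one `f32` per index step); on aarch64 and loongarch64 it
    /// falls back to eight scalar loads. The indices are not bounds checked on
    /// any path, so callers must ensure every index is within `slice`.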
    #[inline]
    pub fn gather(slice: &[f32], indices: &[i32; 8]) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            use super::i32::i32x8;

            let idx = i32x8::from(indices);
            Self(_mm256_i32gather_ps::<4>(slice.as_ptr(), idx.0))
        }

        #[cfg(target_arch = "aarch64")]
        unsafe {
            let ptr = slice.as_ptr();

            let values = [
                *ptr.add(indices[0] as usize),
                *ptr.add(indices[1] as usize),
                *ptr.add(indices[2] as usize),
                *ptr.add(indices[3] as usize),
                *ptr.add(indices[4] as usize),
                *ptr.add(indices[5] as usize),
                *ptr.add(indices[6] as usize),
                *ptr.add(indices[7] as usize),
            ];
            Self::load_unaligned(values.as_ptr())
        }

        #[cfg(target_arch = "loongarch64")]
        unsafe {
            let ptr = slice.as_ptr();

            let values = [
                *ptr.add(indices[0] as usize),
                *ptr.add(indices[1] as usize),
                *ptr.add(indices[2] as usize),
                *ptr.add(indices[3] as usize),
                *ptr.add(indices[4] as usize),
                *ptr.add(indices[5] as usize),
                *ptr.add(indices[6] as usize),
                *ptr.add(indices[7] as usize),
            ];
            Self::load_unaligned(values.as_ptr())
        }
    }
}

impl From<&[f32]> for f32x8 {
    fn from(value: &[f32]) -> Self {
        unsafe { Self::load_unaligned(value.as_ptr()) }
    }
}

impl<'a> From<&'a [f32; 8]> for f32x8 {
    fn from(value: &'a [f32; 8]) -> Self {
        unsafe { Self::load_unaligned(value.as_ptr()) }
    }
}

impl SIMD<f32, 8> for f32x8 {
    fn splat(val: f32) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_set1_ps(val))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x2_t(vdupq_n_f32(val), vdupq_n_f32(val)))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(transmute(lasx_xvreplgr2vr_w(transmute(val))))
        }
    }

    fn zeros() -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_setzero_ps())
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self::splat(0.0)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self::splat(0.0)
        }
    }
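    // Note on the two load flavors: on x86_64, `load` uses `_mm256_load_ps`,
    // which requires the pointer to be 32-byte aligned, while `load_unaligned`
    // uses `_mm256_loadu_ps` with no alignment requirement. The aarch64 path
    // reuses the unaligned NEON load for both, and the loongarch64 path issues
    // the same `lasx_xvld` in either case.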
    #[inline]
    unsafe fn load(ptr: *const f32) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_load_ps(ptr))
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self::load_unaligned(ptr)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self(transmute(lasx_xvld::<0>(transmute(ptr))))
        }
    }

    #[inline]
    unsafe fn load_unaligned(ptr: *const f32) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_loadu_ps(ptr))
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self(vld1q_f32_x2(ptr))
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self(transmute(lasx_xvld::<0>(transmute(ptr))))
        }
    }

    unsafe fn store(&self, ptr: *mut f32) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            _mm256_store_ps(ptr, self.0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            vst1q_f32_x2(ptr, self.0);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            lasx_xvst::<0>(transmute(self.0), transmute(ptr));
        }
    }

    unsafe fn store_unaligned(&self, ptr: *mut f32) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            _mm256_storeu_ps(ptr, self.0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            vst1q_f32_x2(ptr, self.0);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            lasx_xvst::<0>(transmute(self.0), transmute(ptr));
        }
    }
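    // The x86_64 `reduce_sum` below folds the register in halves instead of
    // storing all eight lanes and summing on the scalar side:
    // `_mm256_permute2f128_ps(sum, sum, 1)` swaps the upper and lower 128-bit
    // halves, `_mm256_permute_ps(sum, 14)` (0b00_00_11_10) moves lanes 2 and 3
    // down to lanes 0 and 1, and the final `_mm256_hadd_ps` adds the remaining
    // adjacent pair, leaving the total in lane 0.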
    #[inline]
    fn reduce_sum(&self) -> f32 {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            let mut sum = self.0;
            let mut shift = _mm256_permute2f128_ps(sum, sum, 1);
            sum = _mm256_add_ps(sum, shift);
            shift = _mm256_permute_ps(sum, 14);
            sum = _mm256_add_ps(sum, shift);
            sum = _mm256_hadd_ps(sum, sum);
            let mut results: [f32; 8] = [0f32; 8];
            _mm256_storeu_ps(results.as_mut_ptr(), sum);
            results[0]
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            let sum = vaddq_f32(self.0 .0, self.0 .1);
            vaddvq_f32(sum)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            self.as_array().iter().sum()
        }
    }
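    // `reduce_min` uses the same halving pattern as `reduce_sum`, but with
    // `_mm256_min_ps` at every step: swap the 128-bit halves, bring lanes 2/3
    // down, then lane 1, taking the element-wise minimum each time, and read
    // the result from lane 0 with `_mm256_cvtss_f32`. The aarch64 path uses the
    // horizontal `vminvq_f32`, and the loongarch64 path performs the equivalent
    // shuffle-and-min sequence with LASX permutes.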
    fn reduce_min(&self) -> f32 {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe {
                let mut min = self.0;
                let mut shift = _mm256_permute2f128_ps(min, min, 1);
                min = _mm256_min_ps(min, shift);
                shift = _mm256_permute_ps(min, 14);
                min = _mm256_min_ps(min, shift);
                shift = _mm256_permute_ps(min, 1);
                min = _mm256_min_ps(min, shift);
                _mm256_cvtss_f32(min)
            }
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            let m = vminq_f32(self.0 .0, self.0 .1);
            vminvq_f32(m)
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            let m1 = lasx_xvpermi_d::<14>(transmute(self.0));
            let m2 = lasx_xvfmin_s(transmute(m1), self.0);
            let m1 = lasx_xvpermi_w::<14>(transmute(m2), transmute(m2));
            let m2 = lasx_xvfmin_s(transmute(m1), transmute(m2));
            let m1 = lasx_xvpermi_w::<1>(transmute(m2), transmute(m2));
            let m2 = lasx_xvfmin_s(transmute(m1), transmute(m2));
            transmute(lasx_xvpickve2gr_w::<0>(transmute(m2)))
        }
    }

    fn min(&self, rhs: &Self) -> Self {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_min_ps(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x2_t(
                vminq_f32(self.0 .0, rhs.0 .0),
                vminq_f32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfmin_s(self.0, rhs.0))
        }
    }

    fn find(&self, val: f32) -> Option<i32> {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            for i in 0..8 {
                if self.as_array().get_unchecked(i) == &val {
                    return Some(i as i32);
                }
            }
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            let tgt = vdupq_n_f32(val);
            let mut arr = [0; 8];
            let mask1 = vceqq_f32(self.0 .0, tgt);
            let mask2 = vceqq_f32(self.0 .1, tgt);
            vst1q_u32(arr.as_mut_ptr(), mask1);
            vst1q_u32(arr.as_mut_ptr().add(4), mask2);
            for i in 0..8 {
                if arr.get_unchecked(i) != &0 {
                    return Some(i as i32);
                }
            }
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            for i in 0..8 {
                if self.as_array().get_unchecked(i) == &val {
                    return Some(i as i32);
                }
            }
        }
        None
    }
}
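// `multiply_add` accumulates `a * b` into `self` using a fused multiply-add on
// every backend (`_mm256_fmadd_ps` on x86_64, `vfmaq_f32` per NEON half on
// aarch64, `lasx_xvfmadd_s` on loongarch64), so the product is not rounded
// before the addition.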
impl FloatSimd<f32, 8> for f32x8 {
    fn multiply_add(&mut self, a: Self, b: Self) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            self.0 = _mm256_fmadd_ps(a.0, b.0, self.0);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vfmaq_f32(self.0 .0, a.0 .0, b.0 .0);
            self.0 .1 = vfmaq_f32(self.0 .1, a.0 .1, b.0 .1);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfmadd_s(a.0, b.0, self.0);
        }
    }
}

impl Add for f32x8 {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_add_ps(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x2_t(
                vaddq_f32(self.0 .0, rhs.0 .0),
                vaddq_f32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfadd_s(self.0, rhs.0))
        }
    }
}

impl AddAssign for f32x8 {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            self.0 = _mm256_add_ps(self.0, rhs.0)
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vaddq_f32(self.0 .0, rhs.0 .0);
            self.0 .1 = vaddq_f32(self.0 .1, rhs.0 .1);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfadd_s(self.0, rhs.0);
        }
    }
}

impl Sub for f32x8 {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: Self) -> Self::Output {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_sub_ps(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x2_t(
                vsubq_f32(self.0 .0, rhs.0 .0),
                vsubq_f32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfsub_s(self.0, rhs.0))
        }
    }
}

impl SubAssign for f32x8 {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            self.0 = _mm256_sub_ps(self.0, rhs.0)
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vsubq_f32(self.0 .0, rhs.0 .0);
            self.0 .1 = vsubq_f32(self.0 .1, rhs.0 .1);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfsub_s(self.0, rhs.0);
        }
    }
}

impl Mul for f32x8 {
    type Output = Self;

    #[inline]
    fn mul(self, rhs: Self) -> Self::Output {
        #[cfg(target_arch = "x86_64")]
        unsafe {
            Self(_mm256_mul_ps(self.0, rhs.0))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x2_t(
                vmulq_f32(self.0 .0, rhs.0 .0),
                vmulq_f32(self.0 .1, rhs.0 .1),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfmul_s(self.0, rhs.0))
        }
    }
}
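// `f32x16` widens the same idea to sixteen lanes. The representation again
// follows the target: a single `__m512` when AVX-512F is enabled at compile
// time, two `__m256` halves otherwise on x86_64, a quad of NEON registers
// (`float32x4x4_t`) on aarch64, and two LASX `v8f32` halves on loongarch64.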
#[allow(non_camel_case_types)]
#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
#[derive(Clone, Copy)]
pub struct f32x16(__m256, __m256);
#[allow(non_camel_case_types)]
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[derive(Clone, Copy)]
pub struct f32x16(__m512);

#[allow(non_camel_case_types)]
#[cfg(target_arch = "aarch64")]
#[derive(Clone, Copy)]
pub struct f32x16(float32x4x4_t);

#[allow(non_camel_case_types)]
#[cfg(target_arch = "loongarch64")]
#[derive(Clone, Copy)]
pub struct f32x16(v8f32, v8f32);

impl std::fmt::Debug for f32x16 {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut arr = [0.0_f32; 16];
        unsafe {
            self.store_unaligned(arr.as_mut_ptr());
        }
        write!(f, "f32x16({:?})", arr)
    }
}

impl From<&[f32]> for f32x16 {
    fn from(value: &[f32]) -> Self {
        unsafe { Self::load_unaligned(value.as_ptr()) }
    }
}

impl<'a> From<&'a [f32; 16]> for f32x16 {
    fn from(value: &'a [f32; 16]) -> Self {
        unsafe { Self::load_unaligned(value.as_ptr()) }
    }
}

impl SIMD<f32, 16> for f32x16 {
    #[inline]
    fn splat(val: f32) -> Self {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_set1_ps(val))
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_set1_ps(val), _mm256_set1_ps(val))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x4_t(
                vdupq_n_f32(val),
                vdupq_n_f32(val),
                vdupq_n_f32(val),
                vdupq_n_f32(val),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(
                transmute(lasx_xvreplgr2vr_w(transmute(val))),
                transmute(lasx_xvreplgr2vr_w(transmute(val))),
            )
        }
    }

    #[inline]
    fn zeros() -> Self {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_setzero_ps())
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_setzero_ps(), _mm256_setzero_ps())
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self::splat(0.0)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self::splat(0.0)
        }
    }
    #[inline]
    unsafe fn load(ptr: *const f32) -> Self {
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_load_ps(ptr), _mm256_load_ps(ptr.add(8)))
        }
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_load_ps(ptr))
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self::load_unaligned(ptr)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self(
                transmute(lasx_xvld::<0>(transmute(ptr))),
                transmute(lasx_xvld::<32>(transmute(ptr))),
            )
        }
    }

    #[inline]
    unsafe fn load_unaligned(ptr: *const f32) -> Self {
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_loadu_ps(ptr), _mm256_loadu_ps(ptr.add(8)))
        }
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_loadu_ps(ptr))
        }
        #[cfg(target_arch = "aarch64")]
        {
            Self(vld1q_f32_x4(ptr))
        }
        #[cfg(target_arch = "loongarch64")]
        {
            Self(
                transmute(lasx_xvld::<0>(transmute(ptr))),
                transmute(lasx_xvld::<32>(transmute(ptr))),
            )
        }
    }

    #[inline]
    unsafe fn store(&self, ptr: *mut f32) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            _mm512_store_ps(ptr, self.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            _mm256_store_ps(ptr, self.0);
            _mm256_store_ps(ptr.add(8), self.1);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            vst1q_f32_x4(ptr, self.0);
        }
        #[cfg(target_arch = "loongarch64")]
        {
            lasx_xvst::<0>(transmute(self.0), transmute(ptr));
            lasx_xvst::<32>(transmute(self.1), transmute(ptr));
        }
    }

    #[inline]
    unsafe fn store_unaligned(&self, ptr: *mut f32) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            _mm512_storeu_ps(ptr, self.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            _mm256_storeu_ps(ptr, self.0);
            _mm256_storeu_ps(ptr.add(8), self.1);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            vst1q_f32_x4(ptr, self.0);
        }
        #[cfg(target_arch = "loongarch64")]
        {
            lasx_xvst::<0>(transmute(self.0), transmute(ptr));
            lasx_xvst::<32>(transmute(self.1), transmute(ptr));
        }
    }
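    // With AVX-512F, `reduce_sum` and `reduce_min` defer to the mask-reduce
    // intrinsics with an all-ones mask (0xFFFF selects all sixteen lanes).
    // Without it, the two 256-bit halves are first combined element-wise and
    // the result is reduced with the same shuffle sequence used by `f32x8`.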
    fn reduce_sum(&self) -> f32 {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            _mm512_mask_reduce_add_ps(0xFFFF, self.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            let mut sum = _mm256_add_ps(self.0, self.1);
            let mut shift = _mm256_permute2f128_ps(sum, sum, 1);
            sum = _mm256_add_ps(sum, shift);
            shift = _mm256_permute_ps(sum, 14);
            sum = _mm256_add_ps(sum, shift);
            sum = _mm256_hadd_ps(sum, sum);
            let mut results: [f32; 8] = [0f32; 8];
            _mm256_storeu_ps(results.as_mut_ptr(), sum);
            results[0]
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            let mut sum1 = vaddq_f32(self.0 .0, self.0 .1);
            let sum2 = vaddq_f32(self.0 .2, self.0 .3);
            sum1 = vaddq_f32(sum1, sum2);
            vaddvq_f32(sum1)
        }
        #[cfg(target_arch = "loongarch64")]
        {
            self.as_array().iter().sum()
        }
    }

    #[inline]
    fn reduce_min(&self) -> f32 {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            _mm512_mask_reduce_min_ps(0xFFFF, self.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            let mut m1 = _mm256_min_ps(self.0, self.1);
            let mut m2 = _mm256_permute2f128_ps(m1, m1, 1);
            m1 = _mm256_min_ps(m1, m2);
            m2 = _mm256_permute_ps(m1, 14);
            m1 = _mm256_min_ps(m1, m2);
            m2 = _mm256_permute_ps(m1, 1);
            m1 = _mm256_min_ps(m1, m2);
            _mm256_cvtss_f32(m1)
        }

        #[cfg(target_arch = "aarch64")]
        unsafe {
            let m1 = vminq_f32(self.0 .0, self.0 .1);
            let m2 = vminq_f32(self.0 .2, self.0 .3);
            let m = vminq_f32(m1, m2);
            vminvq_f32(m)
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            let m1 = lasx_xvfmin_s(self.0, self.1);
            let m2 = lasx_xvpermi_d::<14>(transmute(m1));
            let m1 = lasx_xvfmin_s(transmute(m1), transmute(m2));
            let m2 = lasx_xvpermi_w::<14>(transmute(m1), transmute(m1));
            let m1 = lasx_xvfmin_s(transmute(m1), transmute(m2));
            let m2 = lasx_xvpermi_w::<1>(transmute(m1), transmute(m1));
            let m1 = lasx_xvfmin_s(transmute(m1), transmute(m2));
            transmute(lasx_xvpickve2gr_w::<0>(transmute(m1)))
        }
    }

    #[inline]
    fn min(&self, rhs: &Self) -> Self {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_min_ps(self.0, rhs.0))
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_min_ps(self.0, rhs.0), _mm256_min_ps(self.1, rhs.1))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x4_t(
                vminq_f32(self.0 .0, rhs.0 .0),
                vminq_f32(self.0 .1, rhs.0 .1),
                vminq_f32(self.0 .2, rhs.0 .2),
                vminq_f32(self.0 .3, rhs.0 .3),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfmin_s(self.0, rhs.0), lasx_xvfmin_s(self.1, rhs.1))
        }
    }
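    // `find` returns the lane index of the first exact match of `val`, if any.
    // The NEON path compares all lanes at once with `vceqq_f32`, stores the
    // four masks to a scratch array, and scans for the first non-zero word; the
    // non-AVX-512 x86_64 and loongarch64 paths scan `as_array()` directly. The
    // AVX-512 path is still a `todo!()`.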
    fn find(&self, val: f32) -> Option<i32> {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            todo!()
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            for i in 0..16 {
                if self.as_array().get_unchecked(i) == &val {
                    return Some(i as i32);
                }
            }
            None
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            let tgt = vdupq_n_f32(val);
            let mut arr = [0; 16];
            let mask1 = vceqq_f32(self.0 .0, tgt);
            let mask2 = vceqq_f32(self.0 .1, tgt);
            let mask3 = vceqq_f32(self.0 .2, tgt);
            let mask4 = vceqq_f32(self.0 .3, tgt);

            vst1q_u32(arr.as_mut_ptr(), mask1);
            vst1q_u32(arr.as_mut_ptr().add(4), mask2);
            vst1q_u32(arr.as_mut_ptr().add(8), mask3);
            vst1q_u32(arr.as_mut_ptr().add(12), mask4);

            for i in 0..16 {
                if arr.get_unchecked(i) != &0 {
                    return Some(i as i32);
                }
            }
            None
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            for i in 0..16 {
                if self.as_array().get_unchecked(i) == &val {
                    return Some(i as i32);
                }
            }
            None
        }
    }
}

impl FloatSimd<f32, 16> for f32x16 {
    #[inline]
    fn multiply_add(&mut self, a: Self, b: Self) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            self.0 = _mm512_fmadd_ps(a.0, b.0, self.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            self.0 = _mm256_fmadd_ps(a.0, b.0, self.0);
            self.1 = _mm256_fmadd_ps(a.1, b.1, self.1);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vfmaq_f32(self.0 .0, a.0 .0, b.0 .0);
            self.0 .1 = vfmaq_f32(self.0 .1, a.0 .1, b.0 .1);
            self.0 .2 = vfmaq_f32(self.0 .2, a.0 .2, b.0 .2);
            self.0 .3 = vfmaq_f32(self.0 .3, a.0 .3, b.0 .3);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfmadd_s(a.0, b.0, self.0);
            self.1 = lasx_xvfmadd_s(a.1, b.1, self.1);
        }
    }
}

impl Add for f32x16 {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_add_ps(self.0, rhs.0))
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_add_ps(self.0, rhs.0), _mm256_add_ps(self.1, rhs.1))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x4_t(
                vaddq_f32(self.0 .0, rhs.0 .0),
                vaddq_f32(self.0 .1, rhs.0 .1),
                vaddq_f32(self.0 .2, rhs.0 .2),
                vaddq_f32(self.0 .3, rhs.0 .3),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfadd_s(self.0, rhs.0), lasx_xvfadd_s(self.1, rhs.1))
        }
    }
}

impl AddAssign for f32x16 {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            self.0 = _mm512_add_ps(self.0, rhs.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            self.0 = _mm256_add_ps(self.0, rhs.0);
            self.1 = _mm256_add_ps(self.1, rhs.1);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vaddq_f32(self.0 .0, rhs.0 .0);
            self.0 .1 = vaddq_f32(self.0 .1, rhs.0 .1);
            self.0 .2 = vaddq_f32(self.0 .2, rhs.0 .2);
            self.0 .3 = vaddq_f32(self.0 .3, rhs.0 .3);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfadd_s(self.0, rhs.0);
            self.1 = lasx_xvfadd_s(self.1, rhs.1);
        }
    }
}

impl Mul for f32x16 {
    type Output = Self;

    #[inline]
    fn mul(self, rhs: Self) -> Self::Output {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_mul_ps(self.0, rhs.0))
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_mul_ps(self.0, rhs.0), _mm256_mul_ps(self.1, rhs.1))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x4_t(
                vmulq_f32(self.0 .0, rhs.0 .0),
                vmulq_f32(self.0 .1, rhs.0 .1),
                vmulq_f32(self.0 .2, rhs.0 .2),
                vmulq_f32(self.0 .3, rhs.0 .3),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfmul_s(self.0, rhs.0), lasx_xvfmul_s(self.1, rhs.1))
        }
    }
}

impl Sub for f32x16 {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: Self) -> Self::Output {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            Self(_mm512_sub_ps(self.0, rhs.0))
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            Self(_mm256_sub_ps(self.0, rhs.0), _mm256_sub_ps(self.1, rhs.1))
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            Self(float32x4x4_t(
                vsubq_f32(self.0 .0, rhs.0 .0),
                vsubq_f32(self.0 .1, rhs.0 .1),
                vsubq_f32(self.0 .2, rhs.0 .2),
                vsubq_f32(self.0 .3, rhs.0 .3),
            ))
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            Self(lasx_xvfsub_s(self.0, rhs.0), lasx_xvfsub_s(self.1, rhs.1))
        }
    }
}

impl SubAssign for f32x16 {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
        unsafe {
            self.0 = _mm512_sub_ps(self.0, rhs.0)
        }
        #[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
        unsafe {
            self.0 = _mm256_sub_ps(self.0, rhs.0);
            self.1 = _mm256_sub_ps(self.1, rhs.1);
        }
        #[cfg(target_arch = "aarch64")]
        unsafe {
            self.0 .0 = vsubq_f32(self.0 .0, rhs.0 .0);
            self.0 .1 = vsubq_f32(self.0 .1, rhs.0 .1);
            self.0 .2 = vsubq_f32(self.0 .2, rhs.0 .2);
            self.0 .3 = vsubq_f32(self.0 .3, rhs.0 .3);
        }
        #[cfg(target_arch = "loongarch64")]
        unsafe {
            self.0 = lasx_xvfsub_s(self.0, rhs.0);
            self.1 = lasx_xvfsub_s(self.1, rhs.1);
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_basic_ops() {
        let a = (0..8).map(|f| f as f32).collect::<Vec<_>>();
        let b = (10..18).map(|f| f as f32).collect::<Vec<_>>();

        let mut simd_a = unsafe { f32x8::load_unaligned(a.as_ptr()) };
        let simd_b = unsafe { f32x8::load_unaligned(b.as_ptr()) };

        let simd_add = simd_a + simd_b;
        assert!((0..8)
            .zip(simd_add.as_array().iter())
            .all(|(x, &y)| (x + x + 10) as f32 == y));

        let simd_mul = simd_a * simd_b;
        assert!((0..8)
            .zip(simd_mul.as_array().iter())
            .all(|(x, &y)| (x * (x + 10)) as f32 == y));

        let simd_sub = simd_b - simd_a;
        assert!(simd_sub.as_array().iter().all(|&v| v == 10.0));

        simd_a -= simd_b;
        assert_eq!(simd_a.reduce_sum(), -80.0);

        let mut simd_power = f32x8::splat(0.0);
        simd_power.multiply_add(simd_a, simd_a);

        assert_eq!(
            "f32x8([100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0])",
            format!("{:?}", simd_power)
        );
    }

    #[test]
    fn test_f32x8_cmp_ops() {
        let a = [1.0_f32, 2.0, 5.0, 6.0, 7.0, 3.0, 2.0, 1.0];
        let b = [2.0_f32, 1.0, 4.0, 5.0, 9.0, 5.0, 6.0, 2.0];
        let c = [2.0_f32, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0];
        let simd_a: f32x8 = (&a).into();
        let simd_b: f32x8 = (&b).into();
        let simd_c: f32x8 = (&c).into();

        let min_simd = simd_a.min(&simd_b);
        assert_eq!(
            min_simd.as_array(),
            [1.0, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0]
        );
        let min_val = min_simd.reduce_min();
        assert_eq!(min_val, 1.0);
        let min_val = simd_c.reduce_min();
        assert_eq!(min_val, 1.0);

        assert_eq!(Some(2), simd_a.find(5.0));
        assert_eq!(Some(1), simd_a.find(2.0));
        assert_eq!(None, simd_a.find(-200.0));
    }

    #[test]
    fn test_basic_f32x16_ops() {
        let a = (0..16).map(|f| f as f32).collect::<Vec<_>>();
        let b = (10..26).map(|f| f as f32).collect::<Vec<_>>();

        let mut simd_a = unsafe { f32x16::load_unaligned(a.as_ptr()) };
        let simd_b = unsafe { f32x16::load_unaligned(b.as_ptr()) };

        let simd_add = simd_a + simd_b;
        assert!((0..16)
            .zip(simd_add.as_array().iter())
            .all(|(x, &y)| (x + x + 10) as f32 == y));

        let simd_mul = simd_a * simd_b;
        assert!((0..16)
            .zip(simd_mul.as_array().iter())
            .all(|(x, &y)| (x * (x + 10)) as f32 == y));

        simd_a -= simd_b;
        assert_eq!(simd_a.reduce_sum(), -160.0);

        let mut simd_power = f32x16::zeros();
        simd_power.multiply_add(simd_a, simd_a);

        assert_eq!(
            format!("f32x16({:?})", [100.0; 16]),
            format!("{:?}", simd_power)
        );
    }

    #[test]
    fn test_f32x16_cmp_ops() {
        let a = [
            1.0_f32, 2.0, 5.0, 6.0, 7.0, 3.0, 2.0, 1.0, -0.5, 5.0, 6.0, 7.0, 8.0, 9.0, 1.0, 2.0,
        ];
        let b = [
            2.0_f32, 1.0, 4.0, 5.0, 9.0, 5.0, 6.0, 2.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 2.0, 1.0,
        ];
        let c = [
            1.0_f32, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0, -0.5, 5.0, 6.0, 7.0, 8.0, 9.0, 1.0, -1.0,
        ];
        let simd_a: f32x16 = (&a).into();
        let simd_b: f32x16 = (&b).into();
        let simd_c: f32x16 = (&c).into();

        let min_simd = simd_a.min(&simd_b);
        assert_eq!(
            min_simd.as_array(),
            [1.0, 1.0, 4.0, 5.0, 7.0, 3.0, 2.0, 1.0, -0.5, 5.0, 6.0, 7.0, 8.0, 9.0, 1.0, 1.0]
        );
        let min_val = min_simd.reduce_min();
        assert_eq!(min_val, -0.5);
        let min_val = simd_c.reduce_min();
        assert_eq!(min_val, -1.0);

        assert_eq!(Some(2), simd_a.find(5.0));
        assert_eq!(Some(1), simd_a.find(2.0));
        assert_eq!(Some(13), simd_a.find(9.0));
        assert_eq!(None, simd_a.find(-200.0));
    }

    #[test]
    fn test_f32x8_gather() {
        let a = (0..256).map(|f| f as f32).collect::<Vec<_>>();
        let idx = [0_i32, 4, 8, 12, 16, 20, 24, 29];
        let v = f32x8::gather(&a, &idx);
        assert_eq!(v.reduce_sum(), 113.0);
    }
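
    // An extra sanity check, sketched against the same helpers the tests above
    // already rely on (`as_array`, `splat`, `zeros`, `reduce_sum`): `zeros()`
    // should act as the additive identity for `f32x8`.
    #[test]
    fn test_f32x8_zeros_identity() {
        let a = [1.0_f32, -2.0, 3.5, 0.0, 4.0, -4.0, 8.0, 0.25];
        let simd_a: f32x8 = (&a).into();

        let sum = simd_a + f32x8::zeros();
        assert_eq!(sum.as_array(), a);
        assert_eq!(f32x8::zeros().reduce_sum(), 0.0);
        assert_eq!(f32x8::splat(1.0).reduce_sum(), 8.0);
    }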
}