1pub use bytemuck::Pod;
2#[cfg(feature = "f16")]
3use half::f16;
4pub use pulp::{cast, NullaryFnOnce};
5
6#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
7pub use x86::*;
8
9#[cfg(target_arch = "aarch64")]
10pub use aarch64::*;
11
12use crate::gemm::{c32, c64};
13
/// Dispatch token for a SIMD instruction-set level: an implementor stands for
/// a set of CPU target features and runs closures with those features enabled.
pub trait Simd: Copy + Send + Sync + 'static {
    /// Runs `f` with this token's target features enabled.
    ///
    /// # Safety
    ///
    /// The caller must ensure the running CPU actually supports the target
    /// features associated with the implementing type.
    unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output;
}
17
/// Fallback dispatch token: plain scalar execution, available on every target.
#[derive(Copy, Clone, Debug)]
pub struct Scalar;
20
impl Simd for Scalar {
    // No target features to enable; the closure runs directly. Safe to call
    // on any CPU despite the trait-level `unsafe`.
    #[inline(always)]
    unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
        f.call()
    }
}
27
/// Scalar (width-1) mixed-precision support: `f16` storage with `f32`
/// accumulation, converting on every load/store.
#[cfg(feature = "f16")]
unsafe impl MixedSimd<f16, f16, f16, f32> for Scalar {
    const SIMD_WIDTH: usize = 1;

    type LhsN = f16;
    type RhsN = f16;
    type DstN = f16;
    type AccN = f32;

    #[inline]
    fn try_new() -> Option<Self> {
        // The scalar fallback is always available.
        Some(Self)
    }

    #[inline(always)]
    fn add(self, a: f32, b: f32) -> f32 {
        a + b
    }

    #[inline(always)]
    fn mult(self, a: f32, b: f32) -> f32 {
        a * b
    }

    #[inline(always)]
    fn mult_add(self, a: f32, b: f32, c: f32) -> f32 {
        a * b + c
    }

    // Widening conversions: operands are stored as `f16`, but all arithmetic
    // happens in single precision.
    #[inline(always)]
    fn from_lhs(self, a: f16) -> f32 {
        f32::from(a)
    }

    #[inline(always)]
    fn from_rhs(self, b: f16) -> f32 {
        f32::from(b)
    }

    #[inline(always)]
    fn from_dst(self, d: f16) -> f32 {
        f32::from(d)
    }

    #[inline(always)]
    fn into_dst(self, acc: f32) -> f16 {
        // Narrow the accumulator back to half precision for storage.
        f16::from_f32(acc)
    }

    // Width-1 "SIMD" operations coincide with the scalar ones.
    #[inline(always)]
    fn simd_mult_add(self, a: Self::AccN, b: Self::AccN, c: Self::AccN) -> Self::AccN {
        a * b + c
    }

    #[inline(always)]
    fn simd_mul(self, a: Self::AccN, b: Self::AccN) -> Self::AccN {
        a * b
    }

    #[inline(always)]
    fn simd_add(self, a: Self::AccN, b: Self::AccN) -> Self::AccN {
        a + b
    }

    #[inline(always)]
    fn simd_splat(self, value: f32) -> Self::AccN {
        value
    }

    #[inline(always)]
    fn simd_from_lhs(self, a: Self::LhsN) -> Self::AccN {
        f32::from(a)
    }

    #[inline(always)]
    fn simd_from_rhs(self, b: Self::RhsN) -> Self::AccN {
        f32::from(b)
    }

    #[inline(always)]
    fn simd_from_dst(self, d: Self::DstN) -> Self::AccN {
        f32::from(d)
    }

    #[inline(always)]
    fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
        f16::from_f32(acc)
    }

    #[inline(always)]
    fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
        f.call()
    }
}
122
123unsafe impl MixedSimd<f32, f32, f32, f32> for Scalar {
124 const SIMD_WIDTH: usize = 1;
125
126 type LhsN = f32;
127 type RhsN = f32;
128 type DstN = f32;
129 type AccN = f32;
130
131 #[inline]
132 fn try_new() -> Option<Self> {
133 Some(Self)
134 }
135
136 #[inline(always)]
137 fn mult(self, lhs: f32, rhs: f32) -> f32 {
138 lhs * rhs
139 }
140
141 #[inline(always)]
142 fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
143 lhs * rhs + acc
144 }
145
146 #[inline(always)]
147 fn from_lhs(self, lhs: f32) -> f32 {
148 lhs
149 }
150
151 #[inline(always)]
152 fn from_rhs(self, rhs: f32) -> f32 {
153 rhs
154 }
155
156 #[inline(always)]
157 fn from_dst(self, dst: f32) -> f32 {
158 dst
159 }
160
161 #[inline(always)]
162 fn into_dst(self, acc: f32) -> f32 {
163 acc
164 }
165
166 #[inline(always)]
167 fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
168 lhs * rhs + acc
169 }
170
171 #[inline(always)]
172 fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
173 lhs
174 }
175
176 #[inline(always)]
177 fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
178 rhs
179 }
180
181 #[inline(always)]
182 fn simd_splat(self, lhs: f32) -> Self::AccN {
183 lhs
184 }
185
186 #[inline(always)]
187 fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
188 dst
189 }
190
191 #[inline(always)]
192 fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
193 acc
194 }
195
196 #[inline(always)]
197 fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
198 f.call()
199 }
200
201 #[inline(always)]
202 fn add(self, lhs: f32, rhs: f32) -> f32 {
203 lhs + rhs
204 }
205
206 #[inline(always)]
207 fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
208 lhs * rhs
209 }
210
211 #[inline(always)]
212 fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
213 lhs + rhs
214 }
215}
216
217unsafe impl MixedSimd<f64, f64, f64, f64> for Scalar {
218 const SIMD_WIDTH: usize = 1;
219
220 type LhsN = f64;
221 type RhsN = f64;
222 type DstN = f64;
223 type AccN = f64;
224
225 #[inline]
226 fn try_new() -> Option<Self> {
227 Some(Self)
228 }
229
230 #[inline(always)]
231 fn mult(self, lhs: f64, rhs: f64) -> f64 {
232 lhs * rhs
233 }
234
235 #[inline(always)]
236 fn mult_add(self, lhs: f64, rhs: f64, acc: f64) -> f64 {
237 lhs * rhs + acc
238 }
239
240 #[inline(always)]
241 fn from_lhs(self, lhs: f64) -> f64 {
242 lhs
243 }
244
245 #[inline(always)]
246 fn from_rhs(self, rhs: f64) -> f64 {
247 rhs
248 }
249
250 #[inline(always)]
251 fn from_dst(self, dst: f64) -> f64 {
252 dst
253 }
254
255 #[inline(always)]
256 fn into_dst(self, acc: f64) -> f64 {
257 acc
258 }
259
260 #[inline(always)]
261 fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
262 lhs * rhs + acc
263 }
264
265 #[inline(always)]
266 fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
267 lhs
268 }
269
270 #[inline(always)]
271 fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
272 rhs
273 }
274
275 #[inline(always)]
276 fn simd_splat(self, lhs: f64) -> Self::AccN {
277 lhs
278 }
279
280 #[inline(always)]
281 fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
282 dst
283 }
284
285 #[inline(always)]
286 fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
287 acc
288 }
289
290 #[inline(always)]
291 fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
292 f.call()
293 }
294
295 #[inline(always)]
296 fn add(self, lhs: f64, rhs: f64) -> f64 {
297 lhs + rhs
298 }
299
300 #[inline(always)]
301 fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
302 lhs * rhs
303 }
304
305 #[inline(always)]
306 fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
307 lhs + rhs
308 }
309}
310
311unsafe impl MixedSimd<c32, c32, c32, c32> for Scalar {
312 const SIMD_WIDTH: usize = 1;
313
314 type LhsN = c32;
315 type RhsN = c32;
316 type DstN = c32;
317 type AccN = c32;
318
319 #[inline]
320 fn try_new() -> Option<Self> {
321 Some(Self)
322 }
323
324 #[inline(always)]
325 fn mult(self, lhs: c32, rhs: c32) -> c32 {
326 lhs * rhs
327 }
328
329 #[inline(always)]
330 fn mult_add(self, lhs: c32, rhs: c32, acc: c32) -> c32 {
331 lhs * rhs + acc
332 }
333
334 #[inline(always)]
335 fn from_lhs(self, lhs: c32) -> c32 {
336 lhs
337 }
338
339 #[inline(always)]
340 fn from_rhs(self, rhs: c32) -> c32 {
341 rhs
342 }
343
344 #[inline(always)]
345 fn from_dst(self, dst: c32) -> c32 {
346 dst
347 }
348
349 #[inline(always)]
350 fn into_dst(self, acc: c32) -> c32 {
351 acc
352 }
353
354 #[inline(always)]
355 fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
356 lhs * rhs + acc
357 }
358
359 #[inline(always)]
360 fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
361 lhs
362 }
363
364 #[inline(always)]
365 fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
366 rhs
367 }
368
369 #[inline(always)]
370 fn simd_splat(self, lhs: c32) -> Self::AccN {
371 lhs
372 }
373
374 #[inline(always)]
375 fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
376 dst
377 }
378
379 #[inline(always)]
380 fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
381 acc
382 }
383
384 #[inline(always)]
385 fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
386 f.call()
387 }
388
389 #[inline(always)]
390 fn add(self, lhs: c32, rhs: c32) -> c32 {
391 lhs + rhs
392 }
393
394 #[inline(always)]
395 fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
396 lhs * rhs
397 }
398
399 #[inline(always)]
400 fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
401 lhs + rhs
402 }
403}
404
405unsafe impl MixedSimd<c64, c64, c64, c64> for Scalar {
406 const SIMD_WIDTH: usize = 1;
407
408 type LhsN = c64;
409 type RhsN = c64;
410 type DstN = c64;
411 type AccN = c64;
412
413 #[inline]
414 fn try_new() -> Option<Self> {
415 Some(Self)
416 }
417
418 #[inline(always)]
419 fn mult(self, lhs: c64, rhs: c64) -> c64 {
420 lhs * rhs
421 }
422
423 #[inline(always)]
424 fn mult_add(self, lhs: c64, rhs: c64, acc: c64) -> c64 {
425 lhs * rhs + acc
426 }
427
428 #[inline(always)]
429 fn from_lhs(self, lhs: c64) -> c64 {
430 lhs
431 }
432
433 #[inline(always)]
434 fn from_rhs(self, rhs: c64) -> c64 {
435 rhs
436 }
437
438 #[inline(always)]
439 fn from_dst(self, dst: c64) -> c64 {
440 dst
441 }
442
443 #[inline(always)]
444 fn into_dst(self, acc: c64) -> c64 {
445 acc
446 }
447
448 #[inline(always)]
449 fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
450 lhs * rhs + acc
451 }
452
453 #[inline(always)]
454 fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
455 lhs
456 }
457
458 #[inline(always)]
459 fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
460 rhs
461 }
462
463 #[inline(always)]
464 fn simd_splat(self, lhs: c64) -> Self::AccN {
465 lhs
466 }
467
468 #[inline(always)]
469 fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
470 dst
471 }
472
473 #[inline(always)]
474 fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
475 acc
476 }
477
478 #[inline(always)]
479 fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
480 f.call()
481 }
482
483 #[inline(always)]
484 fn add(self, lhs: c64, rhs: c64) -> c64 {
485 lhs + rhs
486 }
487
488 #[inline(always)]
489 fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
490 lhs * rhs
491 }
492
493 #[inline(always)]
494 fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
495 lhs + rhs
496 }
497}
498
499#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
500mod x86 {
501 use super::*;
502 #[cfg(target_arch = "x86")]
503 use core::arch::x86::*;
504 #[cfg(target_arch = "x86_64")]
505 use core::arch::x86_64::*;
506
    /// Fused multiply-add `a * b + c` for `f32`.
    ///
    /// Uses `f32::mul_add` when `std` is available, otherwise falls back to
    /// `libm::fmaf` so the crate remains usable in `no_std` builds.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body performs no unsafe operations; the `unsafe`
    /// marker appears unnecessary — confirm before changing the signature,
    /// since callers invoke it from `unsafe` contexts.
    #[inline(always)]
    pub unsafe fn v3_fmaf(a: f32, b: f32, c: f32) -> f32 {
        #[cfg(feature = "std")]
        {
            f32::mul_add(a, b, c)
        }
        #[cfg(not(feature = "std"))]
        {
            libm::fmaf(a, b, c)
        }
    }
518
    /// Fused multiply-add `a * b + c` for `f64`.
    ///
    /// Uses `f64::mul_add` when `std` is available, otherwise falls back to
    /// `libm::fma` so the crate remains usable in `no_std` builds.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body performs no unsafe operations; the `unsafe`
    /// marker appears unnecessary — confirm before changing the signature.
    #[inline(always)]
    pub unsafe fn v3_fma(a: f64, b: f64, c: f64) -> f64 {
        #[cfg(feature = "std")]
        {
            f64::mul_add(a, b, c)
        }
        #[cfg(not(feature = "std"))]
        {
            libm::fma(a, b, c)
        }
    }
530
    /// Dispatch token for SSE/SSE2-level code paths.
    #[derive(Copy, Clone)]
    pub struct Sse;
    /// Dispatch token for AVX-level code paths.
    #[derive(Copy, Clone)]
    pub struct Avx;
    /// Dispatch token for FMA-level code paths.
    #[derive(Copy, Clone)]
    pub struct Fma;

    /// Dispatch token for AVX-512F code paths (nightly toolchain only).
    #[cfg(feature = "nightly")]
    #[derive(Copy, Clone)]
    pub struct Avx512f;
541
    impl Simd for Sse {
        // SAFETY (caller, per the `Simd` contract): the CPU must support
        // SSE and SSE2 before this is invoked.
        #[inline]
        #[target_feature(enable = "sse,sse2")]
        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
            f.call()
        }
    }
549
    impl Simd for Avx {
        // SAFETY (caller, per the `Simd` contract): the CPU must support AVX.
        #[inline]
        #[target_feature(enable = "avx")]
        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
            f.call()
        }
    }
557
    impl Simd for Fma {
        // SAFETY (caller, per the `Simd` contract): the CPU must support FMA.
        #[inline]
        #[target_feature(enable = "fma")]
        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
            f.call()
        }
    }
565
    #[cfg(feature = "nightly")]
    impl Simd for Avx512f {
        // SAFETY (caller, per the `Simd` contract): the CPU must support
        // AVX-512F.
        #[inline]
        #[target_feature(enable = "avx512f")]
        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
            f.call()
        }
    }
574
    /// Half-width (128-bit lanes) variant of the V3 dispatch token. The
    /// private unit field keeps construction inside this module (via
    /// `try_new`).
    #[derive(Debug, Copy, Clone)]
    pub struct V3Half {
        __private: (),
    }
579
    // x86-64-v3 feature level (SSE through AVX2+FMA). With the `f16` feature
    // enabled, F16C is also required for half<->single conversions.
    // `pulp::simd_type!` generates the struct plus inherent `try_new`
    // (runtime detection), `new_unchecked`, and `vectorize`.
    #[cfg(feature = "f16")]
    pulp::simd_type! {
        pub struct V3 {
            pub sse: "sse",
            pub sse2: "sse2",
            pub fxsr: "fxsr",
            pub sse3: "sse3",
            pub ssse3: "ssse3",
            pub sse4_1: "sse4.1",
            pub sse4_2: "sse4.2",
            pub avx: "avx",
            pub avx2: "avx2",
            pub fma: "fma",
            pub f16c: "f16c",
        }
    }
596
    // x86-64-v4-style feature level: everything in V3 plus AVX-512F
    // (nightly only). F16C variant, used when the `f16` feature is enabled.
    #[cfg(feature = "nightly")]
    #[cfg(feature = "f16")]
    pulp::simd_type! {
        pub struct V4 {
            pub sse: "sse",
            pub sse2: "sse2",
            pub fxsr: "fxsr",
            pub sse3: "sse3",
            pub ssse3: "ssse3",
            pub sse4_1: "sse4.1",
            pub sse4_2: "sse4.2",
            pub avx: "avx",
            pub avx2: "avx2",
            pub fma: "fma",
            pub f16c: "f16c",
            pub avx512f: "avx512f",
        }
    }
615
    // Same V3 feature level without F16C, for builds where the `f16`
    // feature is disabled.
    #[cfg(not(feature = "f16"))]
    pulp::simd_type! {
        pub struct V3 {
            pub sse: "sse",
            pub sse2: "sse2",
            pub fxsr: "fxsr",
            pub sse3: "sse3",
            pub ssse3: "ssse3",
            pub sse4_1: "sse4.1",
            pub sse4_2: "sse4.2",
            pub avx: "avx",
            pub avx2: "avx2",
            pub fma: "fma",
        }
    }
631
    // Same V4 feature level without F16C, for builds where the `f16`
    // feature is disabled (nightly only).
    #[cfg(feature = "nightly")]
    #[cfg(not(feature = "f16"))]
    pulp::simd_type! {
        pub struct V4 {
            pub sse: "sse",
            pub sse2: "sse2",
            pub fxsr: "fxsr",
            pub sse3: "sse3",
            pub ssse3: "ssse3",
            pub sse4_1: "sse4.1",
            pub sse4_2: "sse4.2",
            pub avx: "avx",
            pub avx2: "avx2",
            pub fma: "fma",
            pub avx512f: "avx512f",
        }
    }
649
650 impl Simd for V3Half {
651 unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
652 f.call()
653 }
654 }
655
    /// Half-width mixed-precision support: `f16` storage, `f32` accumulation,
    /// 4 lanes per 128-bit register.
    ///
    /// NOTE(review): most methods here are `todo!()` stubs and will PANIC if
    /// called — only `try_new`, `mult`, `mult_add`, and `simd_from_dst` are
    /// implemented. Presumably the kernels only exercise that subset; confirm
    /// before widening usage.
    #[cfg(feature = "f16")]
    unsafe impl MixedSimd<f16, f16, f16, f32> for V3Half {
        const SIMD_WIDTH: usize = 4;

        type LhsN = [f16; 4];
        type RhsN = [f16; 4];
        type DstN = [f16; 4];
        type AccN = [f32; 4];

        #[inline]
        fn try_new() -> Option<Self> {
            // Unconditionally succeeds — assumes `V3Half` is only requested
            // after the V3 feature set has been detected (TODO confirm).
            Some(Self { __private: () })
        }

        #[inline(always)]
        fn mult(self, lhs: f32, rhs: f32) -> f32 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
            unsafe { v3_fmaf(lhs, rhs, acc) }
        }

        #[inline(always)]
        fn from_lhs(self, _lhs: f16) -> f32 {
            todo!()
        }

        #[inline(always)]
        fn from_rhs(self, _rhs: f16) -> f32 {
            todo!()
        }

        #[inline(always)]
        fn from_dst(self, _dst: f16) -> f32 {
            todo!()
        }

        #[inline(always)]
        fn into_dst(self, _acc: f32) -> f16 {
            todo!()
        }

        #[inline(always)]
        fn simd_mult_add(self, _lhs: Self::AccN, _rhs: Self::AccN, _acc: Self::AccN) -> Self::AccN {
            todo!()
        }

        #[inline(always)]
        fn simd_from_lhs(self, _lhs: Self::LhsN) -> Self::AccN {
            todo!()
        }

        #[inline(always)]
        fn simd_from_rhs(self, _rhs: Self::RhsN) -> Self::AccN {
            todo!()
        }

        #[inline(always)]
        fn simd_splat(self, _lhs: f32) -> Self::AccN {
            todo!()
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            // Pad the 4 halves with zeros to fill the 128-bit register, then
            // widen them to `f32` with F16C's cvtph2ps.
            unsafe { cast(_mm_cvtph_ps(cast([dst, [f16::ZERO; 4]]))) }
        }

        #[inline(always)]
        fn simd_into_dst(self, _acc: Self::AccN) -> Self::DstN {
            todo!()
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, _f: F) -> F::Output {
            todo!()
        }

        #[inline(always)]
        fn add(self, _lhs: f32, _rhs: f32) -> f32 {
            todo!()
        }

        #[inline(always)]
        fn simd_mul(self, _lhs: Self::AccN, _rhs: Self::AccN) -> Self::AccN {
            todo!()
        }

        #[inline(always)]
        fn simd_add(self, _lhs: Self::AccN, _rhs: Self::AccN) -> Self::AccN {
            todo!()
        }
    }
750
    /// AVX2+FMA+F16C mixed-precision support: `f16` storage, `f32`
    /// accumulation, 8 lanes per 256-bit register.
    #[cfg(feature = "f16")]
    unsafe impl MixedSimd<f16, f16, f16, f32> for V3 {
        const SIMD_WIDTH: usize = 8;

        type LhsN = [f16; 8];
        type RhsN = [f16; 8];
        type DstN = [f16; 8];
        type AccN = [f32; 8];

        #[inline]
        fn try_new() -> Option<Self> {
            // Resolves to the inherent `try_new` generated by
            // `pulp::simd_type!` (runtime feature detection); inherent
            // methods shadow trait methods, so this is not recursion.
            Self::try_new()
        }

        #[inline(always)]
        fn mult(self, lhs: f32, rhs: f32) -> f32 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
            unsafe { v3_fmaf(lhs, rhs, acc) }
        }

        #[inline(always)]
        fn from_lhs(self, lhs: f16) -> f32 {
            // Broadcast the half into a vector, widen with F16C, and take
            // lane 0 via the lossy (truncating) cast.
            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(lhs)))) }
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f16) -> f32 {
            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(rhs)))) }
        }

        #[inline(always)]
        fn from_dst(self, dst: f16) -> f32 {
            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(dst)))) }
        }

        #[inline(always)]
        fn into_dst(self, acc: f32) -> f16 {
            // Narrow lane 0 back to half precision, rounding according to
            // the current rounding mode (`_MM_FROUND_CUR_DIRECTION`).
            unsafe {
                pulp::cast_lossy(_mm_cvtps_ph::<_MM_FROUND_CUR_DIRECTION>(
                    self.sse._mm_load_ss(&acc),
                ))
            }
        }

        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            // 8-lane fused multiply-add in single precision.
            cast(self.fma._mm256_fmadd_ps(cast(lhs), cast(rhs), cast(acc)))
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            // Widen 8 halves to 8 floats (F16C).
            unsafe { cast(_mm256_cvtph_ps(cast(lhs))) }
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            unsafe { cast(_mm256_cvtph_ps(cast(rhs))) }
        }

        #[inline(always)]
        fn simd_splat(self, lhs: f32) -> Self::AccN {
            cast(self.avx._mm256_set1_ps(lhs))
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            unsafe { cast(_mm256_cvtph_ps(cast(dst))) }
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            // Narrow 8 floats back to 8 halves.
            unsafe { cast(_mm256_cvtps_ph::<_MM_FROUND_CUR_DIRECTION>(cast(acc))) }
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            // Inherent `vectorize` (runs `f` with the V3 target features
            // enabled); not a recursive call.
            self.vectorize(f)
        }

        #[inline(always)]
        fn add(self, lhs: f32, rhs: f32) -> f32 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx._mm256_mul_ps(cast(lhs), cast(rhs)))
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx._mm256_add_ps(cast(lhs), cast(rhs)))
        }
    }
849
    /// AVX2+FMA support for single-precision complex numbers: 4 interleaved
    /// `(re, im)` lanes per 256-bit register; no precision change between
    /// storage and accumulation.
    unsafe impl MixedSimd<c32, c32, c32, c32> for V3 {
        const SIMD_WIDTH: usize = 4;

        type LhsN = [c32; 4];
        type RhsN = [c32; 4];
        type DstN = [c32; 4];
        type AccN = [c32; 4];

        #[inline]
        fn try_new() -> Option<Self> {
            // Inherent `try_new` from `pulp::simd_type!` (runtime feature
            // detection); inherent methods shadow trait methods, so this is
            // not recursion.
            Self::try_new()
        }

        #[inline(always)]
        fn mult(self, lhs: c32, rhs: c32) -> c32 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: c32, rhs: c32, acc: c32) -> c32 {
            lhs * rhs + acc
        }

        #[inline(always)]
        fn from_lhs(self, lhs: c32) -> c32 {
            lhs
        }

        #[inline(always)]
        fn from_rhs(self, rhs: c32) -> c32 {
            rhs
        }

        #[inline(always)]
        fn from_dst(self, dst: c32) -> c32 {
            dst
        }

        #[inline(always)]
        fn into_dst(self, acc: c32) -> c32 {
            acc
        }

        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            // Interleaved complex FMA. With lhs = a+bi, rhs = x+yi stored as
            // [re, im] pairs (fmaddsub subtracts in even lanes, adds in odd):
            //   inner = fmaddsub(bb, yx, acc) = [b*y - acc_re, b*x + acc_im]
            //   out   = fmaddsub(aa, xy, inner)
            //         = [a*x - b*y + acc_re, a*y + b*x + acc_im]
            // i.e. (a+bi)*(x+yi) + acc in every complex lane.
            unsafe {
                let ab = cast(lhs);
                let xy = cast(rhs);

                let yx = _mm256_permute_ps::<0b10_11_00_01>(xy);
                let aa = _mm256_moveldup_ps(ab);
                let bb = _mm256_movehdup_ps(ab);

                cast(_mm256_fmaddsub_ps(
                    aa,
                    xy,
                    _mm256_fmaddsub_ps(bb, yx, cast(acc)),
                ))
            }
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            lhs
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            rhs
        }

        #[inline(always)]
        fn simd_splat(self, lhs: c32) -> Self::AccN {
            // Reinterpret the 8-byte (re, im) pair as one f64 lane and
            // broadcast it into all 4 complex lanes.
            cast(self.avx._mm256_set1_pd(cast(lhs)))
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            dst
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            acc
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            // Inherent `vectorize` (enables the V3 target features); not
            // recursion.
            self.vectorize(f)
        }

        #[inline(always)]
        fn add(self, lhs: c32, rhs: c32) -> c32 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            // Same interleaved complex product as `simd_mult_add`, without an
            // accumulator: out = [a*x - b*y, a*y + b*x].
            unsafe {
                let ab = cast(lhs);
                let xy = cast(rhs);

                let yx = _mm256_permute_ps::<0b10_11_00_01>(xy);
                let aa = _mm256_moveldup_ps(ab);
                let bb = _mm256_movehdup_ps(ab);

                cast(_mm256_fmaddsub_ps(aa, xy, _mm256_mul_ps(bb, yx)))
            }
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx._mm256_add_ps(cast(lhs), cast(rhs)))
        }
    }
965
    /// AVX2+FMA support for double-precision complex numbers: 2 interleaved
    /// `(re, im)` lanes per 256-bit register.
    unsafe impl MixedSimd<c64, c64, c64, c64> for V3 {
        const SIMD_WIDTH: usize = 2;

        type LhsN = [c64; 2];
        type RhsN = [c64; 2];
        type DstN = [c64; 2];
        type AccN = [c64; 2];

        #[inline]
        fn try_new() -> Option<Self> {
            // Inherent `try_new` from `pulp::simd_type!` (runtime feature
            // detection); not recursion.
            Self::try_new()
        }

        #[inline(always)]
        fn mult(self, lhs: c64, rhs: c64) -> c64 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: c64, rhs: c64, acc: c64) -> c64 {
            lhs * rhs + acc
        }

        #[inline(always)]
        fn from_lhs(self, lhs: c64) -> c64 {
            lhs
        }

        #[inline(always)]
        fn from_rhs(self, rhs: c64) -> c64 {
            rhs
        }

        #[inline(always)]
        fn from_dst(self, dst: c64) -> c64 {
            dst
        }

        #[inline(always)]
        fn into_dst(self, acc: c64) -> c64 {
            acc
        }

        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            // Interleaved complex FMA, double precision. `permute_pd` swaps
            // re/im within each 128-bit half; `unpacklo`/`unpackhi` broadcast
            // the real/imaginary parts. As in the c32 version:
            //   out = fmaddsub(aa, xy, fmaddsub(bb, yx, acc))
            //       = (a+bi)*(x+yi) + acc per complex lane.
            unsafe {
                let ab = cast(lhs);
                let xy = cast(rhs);

                let yx = _mm256_permute_pd::<0b0101>(xy);
                let aa = _mm256_unpacklo_pd(ab, ab);
                let bb = _mm256_unpackhi_pd(ab, ab);

                cast(_mm256_fmaddsub_pd(
                    aa,
                    xy,
                    _mm256_fmaddsub_pd(bb, yx, cast(acc)),
                ))
            }
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            lhs
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            rhs
        }

        #[inline(always)]
        fn simd_splat(self, lhs: c64) -> Self::AccN {
            // Broadcast by array repetition; a c64 lane is a full 128 bits.
            cast([lhs; 2])
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            dst
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            acc
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            // Inherent `vectorize` (enables the V3 target features); not
            // recursion.
            self.vectorize(f)
        }

        #[inline(always)]
        fn add(self, lhs: c64, rhs: c64) -> c64 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            // Complex product without accumulator: [a*x - b*y, a*y + b*x].
            unsafe {
                let ab = cast(lhs);
                let xy = cast(rhs);

                let yx = _mm256_permute_pd::<0b0101>(xy);
                let aa = _mm256_unpacklo_pd(ab, ab);
                let bb = _mm256_unpackhi_pd(ab, ab);

                cast(_mm256_fmaddsub_pd(aa, xy, _mm256_mul_pd(bb, yx)))
            }
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx._mm256_add_pd(cast(lhs), cast(rhs)))
        }
    }
1081
    /// AVX-512F support for single-precision complex numbers: 8 interleaved
    /// `(re, im)` lanes per 512-bit register (nightly only).
    #[cfg(feature = "nightly")]
    unsafe impl MixedSimd<c32, c32, c32, c32> for V4 {
        const SIMD_WIDTH: usize = 8;

        type LhsN = [c32; 8];
        type RhsN = [c32; 8];
        type DstN = [c32; 8];
        type AccN = [c32; 8];

        #[inline]
        fn try_new() -> Option<Self> {
            // Inherent `try_new` from `pulp::simd_type!` (runtime feature
            // detection); not recursion.
            Self::try_new()
        }

        #[inline(always)]
        fn mult(self, lhs: c32, rhs: c32) -> c32 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: c32, rhs: c32, acc: c32) -> c32 {
            lhs * rhs + acc
        }

        #[inline(always)]
        fn from_lhs(self, lhs: c32) -> c32 {
            lhs
        }

        #[inline(always)]
        fn from_rhs(self, rhs: c32) -> c32 {
            rhs
        }

        #[inline(always)]
        fn from_dst(self, dst: c32) -> c32 {
            dst
        }

        #[inline(always)]
        fn into_dst(self, acc: c32) -> c32 {
            acc
        }

        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            // 512-bit version of the interleaved complex FMA used in the V3
            // impl: out = fmaddsub(aa, xy, fmaddsub(bb, yx, acc)), i.e.
            // (a+bi)*(x+yi) + acc per complex lane.
            unsafe {
                let ab = cast(lhs);
                let xy = cast(rhs);

                let yx = _mm512_permute_ps::<0b10_11_00_01>(xy);
                let aa = _mm512_moveldup_ps(ab);
                let bb = _mm512_movehdup_ps(ab);

                cast(_mm512_fmaddsub_ps(
                    aa,
                    xy,
                    _mm512_fmaddsub_ps(bb, yx, cast(acc)),
                ))
            }
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            lhs
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            rhs
        }

        #[inline(always)]
        fn simd_splat(self, lhs: c32) -> Self::AccN {
            // Reinterpret the 8-byte (re, im) pair as one f64 lane and
            // broadcast it into all 8 complex lanes.
            cast(self.avx512f._mm512_set1_pd(cast(lhs)))
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            dst
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            acc
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            // Inherent `vectorize` (enables the V4 target features); not
            // recursion.
            self.vectorize(f)
        }

        #[inline(always)]
        fn add(self, lhs: c32, rhs: c32) -> c32 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            // Complex product without accumulator: [a*x - b*y, a*y + b*x].
            unsafe {
                let ab = cast(lhs);
                let xy = cast(rhs);

                let yx = _mm512_permute_ps::<0b10_11_00_01>(xy);
                let aa = _mm512_moveldup_ps(ab);
                let bb = _mm512_movehdup_ps(ab);

                cast(_mm512_fmaddsub_ps(aa, xy, _mm512_mul_ps(bb, yx)))
            }
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx512f._mm512_add_ps(cast(lhs), cast(rhs)))
        }
    }
1198
    /// AVX-512F support for double-precision complex numbers: 4 interleaved
    /// `(re, im)` lanes per 512-bit register (nightly only).
    #[cfg(feature = "nightly")]
    unsafe impl MixedSimd<c64, c64, c64, c64> for V4 {
        const SIMD_WIDTH: usize = 4;

        type LhsN = [c64; 4];
        type RhsN = [c64; 4];
        type DstN = [c64; 4];
        type AccN = [c64; 4];

        #[inline]
        fn try_new() -> Option<Self> {
            // Inherent `try_new` from `pulp::simd_type!` (runtime feature
            // detection); not recursion.
            Self::try_new()
        }

        #[inline(always)]
        fn mult(self, lhs: c64, rhs: c64) -> c64 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: c64, rhs: c64, acc: c64) -> c64 {
            lhs * rhs + acc
        }

        #[inline(always)]
        fn from_lhs(self, lhs: c64) -> c64 {
            lhs
        }

        #[inline(always)]
        fn from_rhs(self, rhs: c64) -> c64 {
            rhs
        }

        #[inline(always)]
        fn from_dst(self, dst: c64) -> c64 {
            dst
        }

        #[inline(always)]
        fn into_dst(self, acc: c64) -> c64 {
            acc
        }

        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            // 512-bit interleaved complex FMA, double precision:
            // out = fmaddsub(aa, xy, fmaddsub(bb, yx, acc))
            //     = (a+bi)*(x+yi) + acc per complex lane.
            unsafe {
                let ab = cast(lhs);
                let xy = cast(rhs);

                let yx = _mm512_permute_pd::<0b01010101>(xy);
                let aa = _mm512_unpacklo_pd(ab, ab);
                let bb = _mm512_unpackhi_pd(ab, ab);

                cast(_mm512_fmaddsub_pd(
                    aa,
                    xy,
                    _mm512_fmaddsub_pd(bb, yx, cast(acc)),
                ))
            }
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            lhs
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            rhs
        }

        #[inline(always)]
        fn simd_splat(self, lhs: c64) -> Self::AccN {
            // Broadcast by array repetition; a c64 lane is a full 128 bits.
            cast([lhs; 4])
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            dst
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            acc
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            // Inherent `vectorize` (enables the V4 target features); not
            // recursion.
            self.vectorize(f)
        }

        #[inline(always)]
        fn add(self, lhs: c64, rhs: c64) -> c64 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            // Complex product without accumulator: [a*x - b*y, a*y + b*x].
            unsafe {
                let ab = cast(lhs);
                let xy = cast(rhs);

                let yx = _mm512_permute_pd::<0b01010101>(xy);
                let aa = _mm512_unpacklo_pd(ab, ab);
                let bb = _mm512_unpackhi_pd(ab, ab);

                cast(_mm512_fmaddsub_pd(aa, xy, _mm512_mul_pd(bb, yx)))
            }
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx512f._mm512_add_pd(cast(lhs), cast(rhs)))
        }
    }
1315
    /// AVX2+FMA support for homogeneous `f32`: 8 lanes per 256-bit register;
    /// all conversions are identities.
    unsafe impl MixedSimd<f32, f32, f32, f32> for V3 {
        const SIMD_WIDTH: usize = 8;

        type LhsN = [f32; 8];
        type RhsN = [f32; 8];
        type DstN = [f32; 8];
        type AccN = [f32; 8];

        #[inline]
        fn try_new() -> Option<Self> {
            // Inherent `try_new` from `pulp::simd_type!` (runtime feature
            // detection); not recursion.
            Self::try_new()
        }

        #[inline(always)]
        fn mult(self, lhs: f32, rhs: f32) -> f32 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
            // Scalar fused multiply-add matching the vector path's rounding.
            unsafe { v3_fmaf(lhs, rhs, acc) }
        }

        #[inline(always)]
        fn from_lhs(self, lhs: f32) -> f32 {
            lhs
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f32) -> f32 {
            rhs
        }

        #[inline(always)]
        fn from_dst(self, dst: f32) -> f32 {
            dst
        }

        #[inline(always)]
        fn into_dst(self, acc: f32) -> f32 {
            acc
        }

        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            cast(self.fma._mm256_fmadd_ps(cast(lhs), cast(rhs), cast(acc)))
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            lhs
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            rhs
        }

        #[inline(always)]
        fn simd_splat(self, lhs: f32) -> Self::AccN {
            cast(self.avx._mm256_set1_ps(lhs))
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            dst
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            acc
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            // Inherent `vectorize` (enables the V3 target features); not
            // recursion.
            self.vectorize(f)
        }

        #[inline(always)]
        fn add(self, lhs: f32, rhs: f32) -> f32 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx._mm256_mul_ps(cast(lhs), cast(rhs)))
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx._mm256_add_ps(cast(lhs), cast(rhs)))
        }
    }
1409
    /// AVX2+FMA support for homogeneous `f64`: 4 lanes per 256-bit register;
    /// all conversions are identities.
    unsafe impl MixedSimd<f64, f64, f64, f64> for V3 {
        const SIMD_WIDTH: usize = 4;

        type LhsN = [f64; 4];
        type RhsN = [f64; 4];
        type DstN = [f64; 4];
        type AccN = [f64; 4];

        #[inline]
        fn try_new() -> Option<Self> {
            // Inherent `try_new` from `pulp::simd_type!` (runtime feature
            // detection); not recursion.
            Self::try_new()
        }

        #[inline(always)]
        fn mult(self, lhs: f64, rhs: f64) -> f64 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: f64, rhs: f64, acc: f64) -> f64 {
            // Scalar fused multiply-add matching the vector path's rounding.
            unsafe { v3_fma(lhs, rhs, acc) }
        }

        #[inline(always)]
        fn from_lhs(self, lhs: f64) -> f64 {
            lhs
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f64) -> f64 {
            rhs
        }

        #[inline(always)]
        fn from_dst(self, dst: f64) -> f64 {
            dst
        }

        #[inline(always)]
        fn into_dst(self, acc: f64) -> f64 {
            acc
        }

        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            cast(self.fma._mm256_fmadd_pd(cast(lhs), cast(rhs), cast(acc)))
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            lhs
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            rhs
        }

        #[inline(always)]
        fn simd_splat(self, lhs: f64) -> Self::AccN {
            cast(self.avx._mm256_set1_pd(lhs))
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            dst
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            acc
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            // Inherent `vectorize` (enables the V3 target features); not
            // recursion.
            self.vectorize(f)
        }

        #[inline(always)]
        fn add(self, lhs: f64, rhs: f64) -> f64 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx._mm256_mul_pd(cast(lhs), cast(rhs)))
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx._mm256_add_pd(cast(lhs), cast(rhs)))
        }
    }
1503
    // Static (tokenless) vectorize entry point for V3.
    impl Simd for V3 {
        // Safety: the caller must guarantee AVX/AVX2/FMA support —
        // `new_unchecked` performs no runtime feature check.
        #[inline(always)]
        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
            Self::new_unchecked().vectorize(f)
        }
    }
1510
    // Static vectorize entry point for the AVX-512 level (nightly only).
    #[cfg(feature = "nightly")]
    impl Simd for V4 {
        // Safety: the caller must guarantee AVX-512 support —
        // `new_unchecked` performs no runtime feature check.
        #[inline(always)]
        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
            Self::new_unchecked().vectorize(f)
        }
    }
1518
    // f16 inputs/outputs with f32 accumulation on AVX-512: operands are
    // widened with `vcvtph2ps`, accumulated in f32, and narrowed back with
    // `vcvtps2ph` on store.
    #[cfg(feature = "nightly")]
    #[cfg(feature = "f16")]
    unsafe impl MixedSimd<f16, f16, f16, f32> for V4 {
        // A 512-bit register holds 16 f32 accumulator lanes.
        const SIMD_WIDTH: usize = 16;

        type LhsN = [f16; 16];
        type RhsN = [f16; 16];
        type DstN = [f16; 16];
        type AccN = [f32; 16];

        // Resolves to the inherent `V4::try_new` (runtime feature detection).
        #[inline]
        fn try_new() -> Option<Self> {
            Self::try_new()
        }

        #[inline(always)]
        fn mult(self, lhs: f32, rhs: f32) -> f32 {
            lhs * rhs
        }

        // Scalar fused multiply-add helper (shared with the V3 level).
        #[inline(always)]
        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
            unsafe { v3_fmaf(lhs, rhs, acc) }
        }

        // Scalar f16 -> f32: broadcast the bit pattern, convert with F16C
        // `vcvtph2ps`, and take lane 0 via `cast_lossy`.
        #[inline(always)]
        fn from_lhs(self, lhs: f16) -> f32 {
            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(lhs)))) }
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f16) -> f32 {
            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(rhs)))) }
        }

        #[inline(always)]
        fn from_dst(self, dst: f16) -> f32 {
            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(dst)))) }
        }

        // Scalar f32 -> f16 using the current MXCSR rounding mode.
        #[inline(always)]
        fn into_dst(self, acc: f32) -> f16 {
            unsafe {
                pulp::cast_lossy(_mm_cvtps_ph::<_MM_FROUND_CUR_DIRECTION>(
                    self.sse._mm_load_ss(&acc),
                ))
            }
        }

        // Fused lhs * rhs + acc on all 16 f32 lanes.
        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            cast(
                self.avx512f
                    ._mm512_fmadd_ps(cast(lhs), cast(rhs), cast(acc)),
            )
        }

        // Vector f16x16 -> f32x16 widening conversions.
        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            unsafe { cast(_mm512_cvtph_ps(cast(lhs))) }
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            unsafe { cast(_mm512_cvtph_ps(cast(rhs))) }
        }

        #[inline(always)]
        fn simd_splat(self, lhs: f32) -> Self::AccN {
            cast(self.avx512f._mm512_set1_ps(lhs))
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            unsafe { cast(_mm512_cvtph_ps(cast(dst))) }
        }

        // Vector f32x16 -> f16x16 narrowing, current rounding mode.
        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            unsafe { cast(_mm512_cvtps_ph::<_MM_FROUND_CUR_DIRECTION>(cast(acc))) }
        }

        // Delegates to pulp's inherent `V4::vectorize`.
        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            self.vectorize(f)
        }

        #[inline(always)]
        fn add(self, lhs: f32, rhs: f32) -> f32 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx512f._mm512_mul_ps(cast(lhs), cast(rhs)))
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx512f._mm512_add_ps(cast(lhs), cast(rhs)))
        }
    }
1621
    // 16-lane f32 arithmetic on AVX-512 (nightly only). All element types
    // match, so the `from_*`/`into_dst` conversions are identities.
    #[cfg(feature = "nightly")]
    unsafe impl MixedSimd<f32, f32, f32, f32> for V4 {
        // A 512-bit register holds 16 f32 lanes.
        const SIMD_WIDTH: usize = 16;

        type LhsN = [f32; 16];
        type RhsN = [f32; 16];
        type DstN = [f32; 16];
        type AccN = [f32; 16];

        // Resolves to the inherent `V4::try_new` (runtime feature detection).
        #[inline]
        fn try_new() -> Option<Self> {
            Self::try_new()
        }

        #[inline(always)]
        fn mult(self, lhs: f32, rhs: f32) -> f32 {
            lhs * rhs
        }

        // Scalar fused multiply-add helper (shared with the V3 level).
        #[inline(always)]
        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
            unsafe { v3_fmaf(lhs, rhs, acc) }
        }

        #[inline(always)]
        fn from_lhs(self, lhs: f32) -> f32 {
            lhs
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f32) -> f32 {
            rhs
        }

        #[inline(always)]
        fn from_dst(self, dst: f32) -> f32 {
            dst
        }

        #[inline(always)]
        fn into_dst(self, acc: f32) -> f32 {
            acc
        }

        // Fused lhs * rhs + acc on all 16 lanes.
        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            cast(
                self.avx512f
                    ._mm512_fmadd_ps(cast(lhs), cast(rhs), cast(acc)),
            )
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            lhs
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            rhs
        }

        // Broadcasts one f32 into all 16 lanes.
        #[inline(always)]
        fn simd_splat(self, lhs: f32) -> Self::AccN {
            cast(self.avx512f._mm512_set1_ps(lhs))
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            dst
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            acc
        }

        // Delegates to pulp's inherent `V4::vectorize`.
        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            self.vectorize(f)
        }

        #[inline(always)]
        fn add(self, lhs: f32, rhs: f32) -> f32 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx512f._mm512_mul_ps(cast(lhs), cast(rhs)))
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx512f._mm512_add_ps(cast(lhs), cast(rhs)))
        }
    }
1719
    // 8-lane f64 arithmetic on AVX-512 (nightly only). All element types
    // match, so the `from_*`/`into_dst` conversions are identities.
    #[cfg(feature = "nightly")]
    unsafe impl MixedSimd<f64, f64, f64, f64> for V4 {
        // A 512-bit register holds 8 f64 lanes.
        const SIMD_WIDTH: usize = 8;

        type LhsN = [f64; 8];
        type RhsN = [f64; 8];
        type DstN = [f64; 8];
        type AccN = [f64; 8];

        // Resolves to the inherent `V4::try_new` (runtime feature detection).
        #[inline]
        fn try_new() -> Option<Self> {
            Self::try_new()
        }

        #[inline(always)]
        fn mult(self, lhs: f64, rhs: f64) -> f64 {
            lhs * rhs
        }

        // Scalar fused multiply-add helper (shared with the V3 level).
        #[inline(always)]
        fn mult_add(self, lhs: f64, rhs: f64, acc: f64) -> f64 {
            unsafe { v3_fma(lhs, rhs, acc) }
        }

        #[inline(always)]
        fn from_lhs(self, lhs: f64) -> f64 {
            lhs
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f64) -> f64 {
            rhs
        }

        #[inline(always)]
        fn from_dst(self, dst: f64) -> f64 {
            dst
        }

        #[inline(always)]
        fn into_dst(self, acc: f64) -> f64 {
            acc
        }

        // Fused lhs * rhs + acc on all 8 lanes.
        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            cast(
                self.avx512f
                    ._mm512_fmadd_pd(cast(lhs), cast(rhs), cast(acc)),
            )
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            lhs
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            rhs
        }

        // Broadcasts one f64 into all 8 lanes.
        #[inline(always)]
        fn simd_splat(self, lhs: f64) -> Self::AccN {
            cast(self.avx512f._mm512_set1_pd(lhs))
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            dst
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            acc
        }

        // Delegates to pulp's inherent `V4::vectorize`.
        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            self.vectorize(f)
        }

        #[inline(always)]
        fn add(self, lhs: f64, rhs: f64) -> f64 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx512f._mm512_mul_pd(cast(lhs), cast(rhs)))
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            cast(self.avx512f._mm512_add_pd(cast(lhs), cast(rhs)))
        }
    }
1817}
1818
1819#[cfg(target_arch = "aarch64")]
1820pub mod aarch64 {
1821 use super::*;
1822 use core::arch::aarch64::*;
1823 use core::arch::asm;
1824 #[allow(unused_imports)]
1825 use core::mem::transmute;
1826 use core::mem::MaybeUninit;
1827 use core::ptr;
1828
    // Scalar f32 fused multiply-add: `a * b + c`.
    // With `std`, `f32::mul_add` gives a single-rounding FMA; the no-`std`
    // fallback multiplies then adds (two roundings), so the last bit of the
    // result can differ between the two configurations.
    #[inline(always)]
    pub unsafe fn neon_fmaf(a: f32, b: f32, c: f32) -> f32 {
        #[cfg(feature = "std")]
        {
            f32::mul_add(a, b, c)
        }
        #[cfg(not(feature = "std"))]
        {
            a * b + c
        }
    }
1840
    // Scalar f64 fused multiply-add: `a * b + c`.
    // Same rounding caveat as `neon_fmaf`: single rounding with `std`,
    // double rounding in the no-`std` fallback.
    #[inline(always)]
    pub unsafe fn neon_fma(a: f64, b: f64, c: f64) -> f64 {
        #[cfg(feature = "std")]
        {
            f64::mul_add(a, b, c)
        }
        #[cfg(not(feature = "std"))]
        {
            a * b + c
        }
    }
1852
    /// Converts one IEEE half (passed as its raw `u16` bit pattern) to `f32`
    /// with the scalar `fcvt` instruction.
    ///
    /// # Safety
    /// The CPU must support the `fp16` (FEAT_FP16) and `neon` features.
    #[target_feature(enable = "fp16,neon")]
    #[inline]
    pub unsafe fn f16_to_f32_fp16(i: u16) -> f32 {
        let result: f32;
        asm!(
            "fcvt {0:s}, {1:h}",
            out(vreg) result,
            in(vreg) i,
            options(pure, nomem, nostack));
        result
    }
1864
    /// Converts one `f32` to an IEEE half (returned as its raw `u16` bit
    /// pattern) with the scalar `fcvt` instruction.
    ///
    /// # Safety
    /// The CPU must support the `fp16` (FEAT_FP16) and `neon` features.
    #[target_feature(enable = "fp16,neon")]
    #[inline]
    pub unsafe fn f32_to_f16_fp16(f: f32) -> u16 {
        let result: u16;
        asm!(
            "fcvt {0:h}, {1:s}",
            out(vreg) result,
            in(vreg) f,
            options(pure, nomem, nostack));
        result
    }
1876
    /// Widens 4 packed IEEE halves (raw `u16` bit patterns) to 4 `f32`s with
    /// the `fcvtl` instruction.
    ///
    /// # Safety
    /// The CPU must support the `fp16` (FEAT_FP16) and `neon` features.
    #[target_feature(enable = "fp16,neon")]
    #[inline]
    pub unsafe fn f16x4_to_f32x4_fp16(v: &[u16; 4]) -> [f32; 4] {
        // Materialize the 4 halves as a 64-bit NEON vector (uint16x4_t is
        // exactly 4 x u16, so the byte copy fully initializes it).
        let mut vec = MaybeUninit::<uint16x4_t>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let result: float32x4_t;
        asm!(
            "fcvtl {0:v}.4s, {1:v}.4h",
            out(vreg) result,
            in(vreg) vec.assume_init(),
            options(pure, nomem, nostack));
        // Reinterpret the 128-bit result as a plain [f32; 4].
        *(&result as *const float32x4_t).cast()
    }
1890
    /// Narrows 4 `f32`s to 4 packed IEEE halves (raw `u16` bit patterns) with
    /// the `fcvtn` instruction.
    ///
    /// # Safety
    /// The CPU must support the `fp16` (FEAT_FP16) and `neon` features.
    #[target_feature(enable = "fp16,neon")]
    #[inline]
    pub unsafe fn f32x4_to_f16x4_fp16(v: &[f32; 4]) -> [u16; 4] {
        // Materialize the 4 floats as a 128-bit NEON vector.
        let mut vec = MaybeUninit::<float32x4_t>::uninit();
        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
        let result: uint16x4_t;
        asm!(
            "fcvtn {0:v}.4h, {1:v}.4s",
            out(vreg) result,
            in(vreg) vec.assume_init(),
            options(pure, nomem, nostack));
        // Reinterpret the 64-bit result as a plain [u16; 4].
        *(&result as *const uint16x4_t).cast()
    }
1904
    /// Scalar f16 addition (`fadd` on half-precision registers); operands and
    /// result are raw `u16` bit patterns.
    ///
    /// # Safety
    /// The CPU must support the `fp16` (FEAT_FP16) feature.
    #[target_feature(enable = "fp16")]
    #[inline]
    pub unsafe fn add_f16_fp16(a: u16, b: u16) -> u16 {
        let result: u16;
        asm!(
            "fadd {0:h}, {1:h}, {2:h}",
            out(vreg) result,
            in(vreg) a,
            in(vreg) b,
            options(pure, nomem, nostack));
        result
    }
1917
    /// Scalar f16 fused multiply-add computing `b * c + a` (note the operand
    /// order: the accumulator comes first). Raw `u16` bit patterns.
    ///
    /// # Safety
    /// The CPU must support the `fp16` (FEAT_FP16) feature.
    #[target_feature(enable = "fp16")]
    #[inline]
    pub unsafe fn fmaq_f16(mut a: u16, b: u16, c: u16) -> u16 {
        // fmadd d, n, m, a => d = n * m + a; `a` is both addend and result.
        asm!(
            "fmadd {0:h}, {1:h}, {2:h}, {0:h}",
            inout(vreg) a,
            in(vreg) b,
            in(vreg) c,
            options(pure, nomem, nostack));
        a
    }
1929
    /// Scalar f16 multiplication (`fmul` on half-precision registers);
    /// operands and result are raw `u16` bit patterns.
    ///
    /// # Safety
    /// The CPU must support the `fp16` (FEAT_FP16) feature.
    #[target_feature(enable = "fp16")]
    #[inline]
    pub unsafe fn multiply_f16_fp16(a: u16, b: u16) -> u16 {
        let result: u16;
        asm!(
            "fmul {0:h}, {1:h}, {2:h}",
            out(vreg) result,
            in(vreg) a,
            in(vreg) b,
            options(pure, nomem, nostack));
        result
    }
1942
    // `core::arch::aarch64` exposes no stable `float16x8_t`, so f16 vectors
    // are carried in `uint16x8_t` (same size/alignment); the inline asm below
    // reinterprets the lanes as IEEE halves.
    #[allow(non_camel_case_types)]
    type float16x8_t = uint16x8_t;
1945
1946 #[inline]
1949 pub unsafe fn vmulq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t {
1950 let result: float16x8_t;
1951 asm!(
1952 "fmul {0:v}.8h, {1:v}.8h, {2:v}.8h",
1953 out(vreg) result,
1954 in(vreg) a,
1955 in(vreg) b,
1956 options(pure, nomem, nostack));
1957 result
1958 }
1959
1960 #[inline]
1963 pub unsafe fn vaddq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t {
1964 let result: float16x8_t;
1965 asm!(
1966 "fadd {0:v}.8h, {1:v}.8h, {2:v}.8h",
1967 out(vreg) result,
1968 in(vreg) a,
1969 in(vreg) b,
1970 options(pure, nomem, nostack));
1971 result
1972 }
1973
1974 #[inline]
1976 pub unsafe fn vfmaq_f16(mut a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float16x8_t {
1977 asm!(
1978 "fmla {0:v}.8h, {1:v}.8h, {2:v}.8h",
1979 inout(vreg) a,
1980 in(vreg) b,
1981 in(vreg) c,
1982 options(pure, nomem, nostack));
1983 a
1984 }
1985
    // Lanewise f16x8 fused multiply-accumulate by a broadcast lane:
    // `a + b * c[LANE]`. The match dispatches the compile-time LANE constant
    // to the corresponding `fmla ... .h[LANE]` encoding; `vreg_low16` is
    // required because by-element f16 fmla can only address v0-v15.
    // NOTE(review): unlike the scalar fp16 helpers above, this carries no
    // `#[target_feature(enable = "fp16")]` — confirm the assembler accepts
    // the fp16 `fmla` encodings without it.
    #[inline]
    pub unsafe fn vfmaq_laneq_f16<const LANE: i32>(
        mut a: float16x8_t,
        b: float16x8_t,
        c: float16x8_t,
    ) -> float16x8_t {
        match LANE {
            0 => asm!(
                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[0]",
                inout(vreg) a,
                in(vreg) b,
                in(vreg_low16) c,
                options(pure, nomem, nostack)),
            1 => asm!(
                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[1]",
                inout(vreg) a,
                in(vreg) b,
                in(vreg_low16) c,
                options(pure, nomem, nostack)),
            2 => asm!(
                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[2]",
                inout(vreg) a,
                in(vreg) b,
                in(vreg_low16) c,
                options(pure, nomem, nostack)),
            3 => asm!(
                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[3]",
                inout(vreg) a,
                in(vreg) b,
                in(vreg_low16) c,
                options(pure, nomem, nostack)),
            4 => asm!(
                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[4]",
                inout(vreg) a,
                in(vreg) b,
                in(vreg_low16) c,
                options(pure, nomem, nostack)),
            5 => asm!(
                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[5]",
                inout(vreg) a,
                in(vreg) b,
                in(vreg_low16) c,
                options(pure, nomem, nostack)),
            6 => asm!(
                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[6]",
                inout(vreg) a,
                in(vreg) b,
                in(vreg_low16) c,
                options(pure, nomem, nostack)),
            7 => asm!(
                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[7]",
                inout(vreg) a,
                in(vreg) b,
                in(vreg_low16) c,
                options(pure, nomem, nostack)),
            // LANE is a const generic; any value outside 0..=7 is a caller bug.
            _ => unreachable!(),
        }
        a
    }
2045
    /// Zero-sized token proving baseline NEON support was detected.
    #[derive(Copy, Clone, Debug)]
    pub struct Neon {
        // Private field prevents construction outside this module.
        __private: (),
    }
2050
    /// Zero-sized token proving NEON + fp16 (FEAT_FP16) support was detected.
    #[derive(Copy, Clone, Debug)]
    pub struct NeonFp16 {
        // Private field prevents construction outside this module.
        __private: (),
    }
2055
    /// Zero-sized token for NEON + FCMA (complex-number instructions) support.
    #[derive(Copy, Clone, Debug)]
    pub struct NeonFcma {
        // Private field prevents construction outside this module.
        __private: (),
    }
2060
    impl Simd for Neon {
        // Runs `f` in a frame compiled with NEON enabled; caller guarantees
        // the feature is actually present.
        #[inline]
        #[target_feature(enable = "neon")]
        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
            f.call()
        }
    }
2068
    impl Simd for NeonFp16 {
        // Runs `f` in a frame compiled with NEON + fp16 enabled; caller
        // guarantees both features are actually present.
        #[inline]
        #[target_feature(enable = "neon,fp16")]
        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
            f.call()
        }
    }
2076
    // f16 inputs/outputs with f32 accumulation using hardware fp16
    // conversions (`fcvtl`/`fcvtn`) and NEON f32 arithmetic.
    #[cfg(feature = "f16")]
    unsafe impl MixedSimd<f16, f16, f16, f32> for NeonFp16 {
        // 128-bit f32 vectors: 4 accumulator lanes.
        const SIMD_WIDTH: usize = 4;

        type LhsN = [f16; 4];
        type RhsN = [f16; 4];
        type DstN = [f16; 4];
        type AccN = [f32; 4];

        // Runtime detection of both required features.
        #[inline]
        fn try_new() -> Option<Self> {
            if crate::feature_detected!("neon") && crate::feature_detected!("fp16") {
                Some(Self { __private: () })
            } else {
                None
            }
        }

        #[inline(always)]
        fn mult(self, lhs: f32, rhs: f32) -> f32 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
            unsafe { neon_fmaf(lhs, rhs, acc) }
        }

        // Scalar widening via the hardware `fcvt` helper; `cast` moves
        // between `f16` and its `u16` bit pattern.
        #[inline(always)]
        fn from_lhs(self, lhs: f16) -> f32 {
            unsafe { f16_to_f32_fp16(cast(lhs)) }
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f16) -> f32 {
            unsafe { f16_to_f32_fp16(cast(rhs)) }
        }

        #[inline(always)]
        fn from_dst(self, dst: f16) -> f32 {
            unsafe { f16_to_f32_fp16(cast(dst)) }
        }

        #[inline(always)]
        fn into_dst(self, acc: f32) -> f16 {
            unsafe { cast(f32_to_f16_fp16(acc)) }
        }

        // NEON fmla: note `vfmaq_f32(acc, lhs, rhs)` = acc + lhs * rhs.
        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            unsafe { transmute(vfmaq_f32(transmute(acc), transmute(lhs), transmute(rhs))) }
        }

        // Vector widening conversions via `fcvtl`.
        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            unsafe { f16x4_to_f32x4_fp16(&cast(lhs)) }
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            unsafe { f16x4_to_f32x4_fp16(&cast(rhs)) }
        }

        #[inline(always)]
        fn simd_splat(self, lhs: f32) -> Self::AccN {
            [lhs, lhs, lhs, lhs]
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            unsafe { f16x4_to_f32x4_fp16(&cast(dst)) }
        }

        // Vector narrowing via `fcvtn`.
        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            unsafe { cast(f32x4_to_f16x4_fp16(&acc)) }
        }

        // Wraps `f` in a neon+fp16 target-feature frame; sound because
        // holding `self` proves `try_new`'s detection succeeded.
        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            #[inline]
            #[target_feature(enable = "neon,fp16")]
            unsafe fn implementation<F: NullaryFnOnce>(f: F) -> F::Output {
                f.call()
            }

            unsafe { implementation(f) }
        }

        #[inline(always)]
        fn add(self, lhs: f32, rhs: f32) -> f32 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vmulq_f32(transmute(lhs), transmute(rhs))) }
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vaddq_f32(transmute(lhs), transmute(rhs))) }
        }
    }
2181
    // Full f16 arithmetic (accumulation in f16 as well) using the fp16 asm
    // helpers; 8 lanes per 128-bit vector, no widening conversions needed.
    #[cfg(feature = "f16")]
    unsafe impl MixedSimd<f16, f16, f16, f16> for NeonFp16 {
        // 128-bit f16 vectors: 8 lanes.
        const SIMD_WIDTH: usize = 8;

        type LhsN = [f16; 8];
        type RhsN = [f16; 8];
        type DstN = [f16; 8];
        type AccN = [f16; 8];

        // Runtime detection of both required features.
        #[inline]
        fn try_new() -> Option<Self> {
            if crate::feature_detected!("neon") && crate::feature_detected!("fp16") {
                Some(Self { __private: () })
            } else {
                None
            }
        }

        #[inline(always)]
        fn mult(self, lhs: f16, rhs: f16) -> f16 {
            unsafe { cast(multiply_f16_fp16(cast(lhs), cast(rhs))) }
        }

        // `fmaq_f16(a, b, c)` computes b * c + a, so the accumulator is
        // passed first here.
        #[inline(always)]
        fn mult_add(self, lhs: f16, rhs: f16, acc: f16) -> f16 {
            unsafe { cast(fmaq_f16(cast(acc), cast(lhs), cast(rhs))) }
        }

        // All element types are f16, so conversions are identities.
        #[inline(always)]
        fn from_lhs(self, lhs: f16) -> f16 {
            lhs
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f16) -> f16 {
            rhs
        }

        #[inline(always)]
        fn from_dst(self, dst: f16) -> f16 {
            dst
        }

        #[inline(always)]
        fn into_dst(self, acc: f16) -> f16 {
            acc
        }

        // `vfmaq_f16(acc, lhs, rhs)` = acc + lhs * rhs (fmla accumulates).
        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            unsafe { transmute(vfmaq_f16(transmute(acc), transmute(lhs), transmute(rhs))) }
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            lhs
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            rhs
        }

        #[inline(always)]
        fn simd_splat(self, lhs: f16) -> Self::AccN {
            [lhs, lhs, lhs, lhs, lhs, lhs, lhs, lhs]
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            dst
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            acc
        }

        // Wraps `f` in a neon+fp16 target-feature frame; sound because
        // holding `self` proves `try_new`'s detection succeeded.
        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            #[inline]
            #[target_feature(enable = "neon,fp16")]
            unsafe fn implementation<F: NullaryFnOnce>(f: F) -> F::Output {
                f.call()
            }

            unsafe { implementation(f) }
        }

        #[inline(always)]
        fn add(self, lhs: f16, rhs: f16) -> f16 {
            unsafe { cast(add_f16_fp16(cast(lhs), cast(rhs))) }
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vmulq_f16(transmute(lhs), transmute(rhs))) }
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vaddq_f16(transmute(lhs), transmute(rhs))) }
        }
    }
2286
    // Fallback for CPUs with NEON but without FEAT_FP16: f16 <-> f32
    // conversions are done lane-by-lane in software via the `half` crate's
    // `From`/`from_f32`; only the f32 arithmetic uses NEON.
    #[cfg(feature = "f16")]
    unsafe impl MixedSimd<f16, f16, f16, f32> for Neon {
        // 128-bit f32 vectors: 4 accumulator lanes.
        const SIMD_WIDTH: usize = 4;

        type LhsN = [f16; 4];
        type RhsN = [f16; 4];
        type DstN = [f16; 4];
        type AccN = [f32; 4];

        // Only baseline NEON is required.
        #[inline]
        fn try_new() -> Option<Self> {
            if crate::feature_detected!("neon") {
                Some(Self { __private: () })
            } else {
                None
            }
        }

        #[inline(always)]
        fn mult(self, lhs: f32, rhs: f32) -> f32 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
            unsafe { neon_fmaf(lhs, rhs, acc) }
        }

        // Software f16 -> f32 widening (half crate `From<f16> for f32`).
        #[inline(always)]
        fn from_lhs(self, lhs: f16) -> f32 {
            lhs.into()
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f16) -> f32 {
            rhs.into()
        }

        #[inline(always)]
        fn from_dst(self, dst: f16) -> f32 {
            dst.into()
        }

        // Software f32 -> f16 narrowing.
        #[inline(always)]
        fn into_dst(self, acc: f32) -> f16 {
            f16::from_f32(acc)
        }

        // NEON fmla: `vfmaq_f32(acc, lhs, rhs)` = acc + lhs * rhs.
        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            unsafe { transmute(vfmaq_f32(transmute(acc), transmute(lhs), transmute(rhs))) }
        }

        // Lane-by-lane software widening of the 4-lane vectors.
        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            [lhs[0].into(), lhs[1].into(), lhs[2].into(), lhs[3].into()]
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            [rhs[0].into(), rhs[1].into(), rhs[2].into(), rhs[3].into()]
        }

        #[inline(always)]
        fn simd_splat(self, lhs: f32) -> Self::AccN {
            [lhs, lhs, lhs, lhs]
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            [dst[0].into(), dst[1].into(), dst[2].into(), dst[3].into()]
        }

        // Lane-by-lane software narrowing.
        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            [
                f16::from_f32(acc[0]),
                f16::from_f32(acc[1]),
                f16::from_f32(acc[2]),
                f16::from_f32(acc[3]),
            ]
        }

        // Wraps `f` in a neon target-feature frame; sound because holding
        // `self` proves `try_new`'s detection succeeded.
        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            #[inline]
            #[target_feature(enable = "neon")]
            unsafe fn implementation<F: NullaryFnOnce>(f: F) -> F::Output {
                f.call()
            }

            unsafe { implementation(f) }
        }

        #[inline(always)]
        fn add(self, lhs: f32, rhs: f32) -> f32 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vmulq_f32(transmute(lhs), transmute(rhs))) }
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vaddq_f32(transmute(lhs), transmute(rhs))) }
        }
    }
2396}
2397
/// Marker bound bundling the traits every SIMD element/vector type must
/// satisfy; blanket-implemented for all qualifying types.
pub trait Boilerplate: Copy + Send + Sync + core::fmt::Debug + 'static + PartialEq {}
impl<T: Copy + Send + Sync + core::fmt::Debug + PartialEq + 'static> Boilerplate for T {}
2400
/// Mixed-precision SIMD operations: multiply `Lhs` by `Rhs` values,
/// accumulate in `Acc`, and store results as `Dst`.
///
/// # Safety
/// Implementors must guarantee that an instance only exists when the CPU
/// features its methods rely on are available (see `try_new`).
pub unsafe trait MixedSimd<Lhs, Rhs, Dst, Acc>: Simd {
    /// Number of accumulator lanes processed per SIMD operation.
    const SIMD_WIDTH: usize;

    // Vector counterparts of the scalar element types, `SIMD_WIDTH` lanes
    // wide (accumulator width; input vectors hold the same lane count).
    type LhsN: Boilerplate;
    type RhsN: Boilerplate;
    type DstN: Boilerplate;
    type AccN: Boilerplate;

    /// Returns `Some(token)` iff the required CPU features are detected.
    fn try_new() -> Option<Self>;

    /// Runs `f` with the implementation's target features enabled.
    fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output;

    // Scalar operations in the accumulator type.
    fn add(self, lhs: Acc, rhs: Acc) -> Acc;
    fn mult(self, lhs: Acc, rhs: Acc) -> Acc;
    fn mult_add(self, lhs: Acc, rhs: Acc, acc: Acc) -> Acc;
    // Scalar conversions between element types and the accumulator type.
    fn from_lhs(self, lhs: Lhs) -> Acc;
    fn from_rhs(self, rhs: Rhs) -> Acc;
    fn from_dst(self, dst: Dst) -> Acc;
    fn into_dst(self, acc: Acc) -> Dst;

    // Lanewise vector operations in the accumulator type.
    fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN;
    fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN;
    fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN;
    // Lanewise conversions into the accumulator vector type.
    fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN;
    fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN;
    /// Broadcasts one accumulator value into every lane.
    fn simd_splat(self, lhs: Acc) -> Self::AccN;

    fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN;
    fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN;
}