gemm_common/simd.rs

1pub use bytemuck::Pod;
2#[cfg(feature = "f16")]
3use half::f16;
4pub use pulp::{cast, NullaryFnOnce};
5
6#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
7pub use x86::*;
8
9#[cfg(target_arch = "aarch64")]
10pub use aarch64::*;
11
12use crate::gemm::{c32, c64};
13
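/// Instruction-set token: `vectorize` runs `f` with this token's target
/// features enabled so the compiler can vectorize the body accordingly.
///
/// # Safety
///
/// The caller must ensure that the CPU actually supports the features the
/// implementing token enables.
///
/// A minimal usage sketch (assuming this module is reachable as
/// `gemm_common::simd` and that closures implement `pulp::NullaryFnOnce`
/// through its blanket impl):
///
/// ```ignore
/// use gemm_common::simd::{Scalar, Simd};
///
/// // `Scalar` enables no extra target features, so this call is always sound.
/// let three = unsafe { Scalar::vectorize(|| 1.0f32 + 2.0) };
/// assert_eq!(three, 3.0);
/// ```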
14pub trait Simd: Copy + Send + Sync + 'static {
15    unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output;
16}
17
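/// Fallback token that performs every operation one element at a time and
/// requires no target features.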
18#[derive(Copy, Clone, Debug)]
19pub struct Scalar;
20
21impl Simd for Scalar {
22    #[inline(always)]
23    unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
24        f.call()
25    }
26}
27
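// Scalar `MixedSimd` implementations: `SIMD_WIDTH` is 1 and the packed
// `LhsN`/`RhsN`/`DstN`/`AccN` types are plain scalars, so every `simd_*`
// method degenerates to its scalar counterpart (e.g. `simd_mult_add` is just
// `lhs * rhs + acc`).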
28#[cfg(feature = "f16")]
29unsafe impl MixedSimd<f16, f16, f16, f32> for Scalar {
30    const SIMD_WIDTH: usize = 1;
31
32    type LhsN = f16;
33    type RhsN = f16;
34    type DstN = f16;
35    type AccN = f32;
36
37    #[inline]
38    fn try_new() -> Option<Self> {
39        Some(Self)
40    }
41
42    #[inline(always)]
43    fn add(self, lhs: f32, rhs: f32) -> f32 {
44        lhs + rhs
45    }
46
47    #[inline(always)]
48    fn mult(self, lhs: f32, rhs: f32) -> f32 {
49        lhs * rhs
50    }
51
52    #[inline(always)]
53    fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
54        lhs * rhs + acc
55    }
56
57    #[inline(always)]
58    fn from_lhs(self, lhs: f16) -> f32 {
59        lhs.into()
60    }
61
62    #[inline(always)]
63    fn from_rhs(self, rhs: f16) -> f32 {
64        rhs.into()
65    }
66
67    #[inline(always)]
68    fn from_dst(self, dst: f16) -> f32 {
69        dst.into()
70    }
71
72    #[inline(always)]
73    fn into_dst(self, acc: f32) -> f16 {
74        f16::from_f32(acc)
75    }
76
77    #[inline(always)]
78    fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
79        lhs * rhs + acc
80    }
81
82    #[inline(always)]
83    fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
84        lhs.into()
85    }
86
87    #[inline(always)]
88    fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
89        rhs.into()
90    }
91
92    #[inline(always)]
93    fn simd_splat(self, lhs: f32) -> Self::AccN {
94        lhs
95    }
96
97    #[inline(always)]
98    fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
99        dst.into()
100    }
101
102    #[inline(always)]
103    fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
104        f16::from_f32(acc)
105    }
106
107    #[inline(always)]
108    fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
109        f.call()
110    }
111
112    #[inline(always)]
113    fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
114        lhs * rhs
115    }
116
117    #[inline(always)]
118    fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
119        lhs + rhs
120    }
121}
122
123unsafe impl MixedSimd<f32, f32, f32, f32> for Scalar {
124    const SIMD_WIDTH: usize = 1;
125
126    type LhsN = f32;
127    type RhsN = f32;
128    type DstN = f32;
129    type AccN = f32;
130
131    #[inline]
132    fn try_new() -> Option<Self> {
133        Some(Self)
134    }
135
136    #[inline(always)]
137    fn mult(self, lhs: f32, rhs: f32) -> f32 {
138        lhs * rhs
139    }
140
141    #[inline(always)]
142    fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
143        lhs * rhs + acc
144    }
145
146    #[inline(always)]
147    fn from_lhs(self, lhs: f32) -> f32 {
148        lhs
149    }
150
151    #[inline(always)]
152    fn from_rhs(self, rhs: f32) -> f32 {
153        rhs
154    }
155
156    #[inline(always)]
157    fn from_dst(self, dst: f32) -> f32 {
158        dst
159    }
160
161    #[inline(always)]
162    fn into_dst(self, acc: f32) -> f32 {
163        acc
164    }
165
166    #[inline(always)]
167    fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
168        lhs * rhs + acc
169    }
170
171    #[inline(always)]
172    fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
173        lhs
174    }
175
176    #[inline(always)]
177    fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
178        rhs
179    }
180
181    #[inline(always)]
182    fn simd_splat(self, lhs: f32) -> Self::AccN {
183        lhs
184    }
185
186    #[inline(always)]
187    fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
188        dst
189    }
190
191    #[inline(always)]
192    fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
193        acc
194    }
195
196    #[inline(always)]
197    fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
198        f.call()
199    }
200
201    #[inline(always)]
202    fn add(self, lhs: f32, rhs: f32) -> f32 {
203        lhs + rhs
204    }
205
206    #[inline(always)]
207    fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
208        lhs * rhs
209    }
210
211    #[inline(always)]
212    fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
213        lhs + rhs
214    }
215}
216
217unsafe impl MixedSimd<f64, f64, f64, f64> for Scalar {
218    const SIMD_WIDTH: usize = 1;
219
220    type LhsN = f64;
221    type RhsN = f64;
222    type DstN = f64;
223    type AccN = f64;
224
225    #[inline]
226    fn try_new() -> Option<Self> {
227        Some(Self)
228    }
229
230    #[inline(always)]
231    fn mult(self, lhs: f64, rhs: f64) -> f64 {
232        lhs * rhs
233    }
234
235    #[inline(always)]
236    fn mult_add(self, lhs: f64, rhs: f64, acc: f64) -> f64 {
237        lhs * rhs + acc
238    }
239
240    #[inline(always)]
241    fn from_lhs(self, lhs: f64) -> f64 {
242        lhs
243    }
244
245    #[inline(always)]
246    fn from_rhs(self, rhs: f64) -> f64 {
247        rhs
248    }
249
250    #[inline(always)]
251    fn from_dst(self, dst: f64) -> f64 {
252        dst
253    }
254
255    #[inline(always)]
256    fn into_dst(self, acc: f64) -> f64 {
257        acc
258    }
259
260    #[inline(always)]
261    fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
262        lhs * rhs + acc
263    }
264
265    #[inline(always)]
266    fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
267        lhs
268    }
269
270    #[inline(always)]
271    fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
272        rhs
273    }
274
275    #[inline(always)]
276    fn simd_splat(self, lhs: f64) -> Self::AccN {
277        lhs
278    }
279
280    #[inline(always)]
281    fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
282        dst
283    }
284
285    #[inline(always)]
286    fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
287        acc
288    }
289
290    #[inline(always)]
291    fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
292        f.call()
293    }
294
295    #[inline(always)]
296    fn add(self, lhs: f64, rhs: f64) -> f64 {
297        lhs + rhs
298    }
299
300    #[inline(always)]
301    fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
302        lhs * rhs
303    }
304
305    #[inline(always)]
306    fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
307        lhs + rhs
308    }
309}
310
311unsafe impl MixedSimd<c32, c32, c32, c32> for Scalar {
312    const SIMD_WIDTH: usize = 1;
313
314    type LhsN = c32;
315    type RhsN = c32;
316    type DstN = c32;
317    type AccN = c32;
318
319    #[inline]
320    fn try_new() -> Option<Self> {
321        Some(Self)
322    }
323
324    #[inline(always)]
325    fn mult(self, lhs: c32, rhs: c32) -> c32 {
326        lhs * rhs
327    }
328
329    #[inline(always)]
330    fn mult_add(self, lhs: c32, rhs: c32, acc: c32) -> c32 {
331        lhs * rhs + acc
332    }
333
334    #[inline(always)]
335    fn from_lhs(self, lhs: c32) -> c32 {
336        lhs
337    }
338
339    #[inline(always)]
340    fn from_rhs(self, rhs: c32) -> c32 {
341        rhs
342    }
343
344    #[inline(always)]
345    fn from_dst(self, dst: c32) -> c32 {
346        dst
347    }
348
349    #[inline(always)]
350    fn into_dst(self, acc: c32) -> c32 {
351        acc
352    }
353
354    #[inline(always)]
355    fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
356        lhs * rhs + acc
357    }
358
359    #[inline(always)]
360    fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
361        lhs
362    }
363
364    #[inline(always)]
365    fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
366        rhs
367    }
368
369    #[inline(always)]
370    fn simd_splat(self, lhs: c32) -> Self::AccN {
371        lhs
372    }
373
374    #[inline(always)]
375    fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
376        dst
377    }
378
379    #[inline(always)]
380    fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
381        acc
382    }
383
384    #[inline(always)]
385    fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
386        f.call()
387    }
388
389    #[inline(always)]
390    fn add(self, lhs: c32, rhs: c32) -> c32 {
391        lhs + rhs
392    }
393
394    #[inline(always)]
395    fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
396        lhs * rhs
397    }
398
399    #[inline(always)]
400    fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
401        lhs + rhs
402    }
403}
404
405unsafe impl MixedSimd<c64, c64, c64, c64> for Scalar {
406    const SIMD_WIDTH: usize = 1;
407
408    type LhsN = c64;
409    type RhsN = c64;
410    type DstN = c64;
411    type AccN = c64;
412
413    #[inline]
414    fn try_new() -> Option<Self> {
415        Some(Self)
416    }
417
418    #[inline(always)]
419    fn mult(self, lhs: c64, rhs: c64) -> c64 {
420        lhs * rhs
421    }
422
423    #[inline(always)]
424    fn mult_add(self, lhs: c64, rhs: c64, acc: c64) -> c64 {
425        lhs * rhs + acc
426    }
427
428    #[inline(always)]
429    fn from_lhs(self, lhs: c64) -> c64 {
430        lhs
431    }
432
433    #[inline(always)]
434    fn from_rhs(self, rhs: c64) -> c64 {
435        rhs
436    }
437
438    #[inline(always)]
439    fn from_dst(self, dst: c64) -> c64 {
440        dst
441    }
442
443    #[inline(always)]
444    fn into_dst(self, acc: c64) -> c64 {
445        acc
446    }
447
448    #[inline(always)]
449    fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
450        lhs * rhs + acc
451    }
452
453    #[inline(always)]
454    fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
455        lhs
456    }
457
458    #[inline(always)]
459    fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
460        rhs
461    }
462
463    #[inline(always)]
464    fn simd_splat(self, lhs: c64) -> Self::AccN {
465        lhs
466    }
467
468    #[inline(always)]
469    fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
470        dst
471    }
472
473    #[inline(always)]
474    fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
475        acc
476    }
477
478    #[inline(always)]
479    fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
480        f.call()
481    }
482
483    #[inline(always)]
484    fn add(self, lhs: c64, rhs: c64) -> c64 {
485        lhs + rhs
486    }
487
488    #[inline(always)]
489    fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
490        lhs * rhs
491    }
492
493    #[inline(always)]
494    fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
495        lhs + rhs
496    }
497}
498
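// x86/x86_64 backends: individual feature tokens (`Sse`, `Avx`, `Fma`) plus
// the `pulp`-generated `V3` (AVX2 + FMA) and, behind the `nightly` feature,
// `V4` (AVX-512F) SIMD types.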
499#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
500mod x86 {
501    use super::*;
502    #[cfg(target_arch = "x86")]
503    use core::arch::x86::*;
504    #[cfg(target_arch = "x86_64")]
505    use core::arch::x86_64::*;
506
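    // Scalar fused multiply-add helpers used by the `mult_add` methods below:
    // `mul_add` from `std` when available, `libm` otherwise.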
507    #[inline(always)]
508    pub unsafe fn v3_fmaf(a: f32, b: f32, c: f32) -> f32 {
509        #[cfg(feature = "std")]
510        {
511            f32::mul_add(a, b, c)
512        }
513        #[cfg(not(feature = "std"))]
514        {
515            libm::fmaf(a, b, c)
516        }
517    }
518
519    #[inline(always)]
520    pub unsafe fn v3_fma(a: f64, b: f64, c: f64) -> f64 {
521        #[cfg(feature = "std")]
522        {
523            f64::mul_add(a, b, c)
524        }
525        #[cfg(not(feature = "std"))]
526        {
527            libm::fma(a, b, c)
528        }
529    }
530
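    // Marker tokens for individual instruction-set levels; their `Simd`
    // impls below enable the matching target features around `f.call()`.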
531    #[derive(Copy, Clone)]
532    pub struct Sse;
533    #[derive(Copy, Clone)]
534    pub struct Avx;
535    #[derive(Copy, Clone)]
536    pub struct Fma;
537
538    #[cfg(feature = "nightly")]
539    #[derive(Copy, Clone)]
540    pub struct Avx512f;
541
542    impl Simd for Sse {
543        #[inline]
544        #[target_feature(enable = "sse,sse2")]
545        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
546            f.call()
547        }
548    }
549
550    impl Simd for Avx {
551        #[inline]
552        #[target_feature(enable = "avx")]
553        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
554            f.call()
555        }
556    }
557
558    impl Simd for Fma {
559        #[inline]
560        #[target_feature(enable = "fma")]
561        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
562            f.call()
563        }
564    }
565
566    #[cfg(feature = "nightly")]
567    impl Simd for Avx512f {
568        #[inline]
569        #[target_feature(enable = "avx512f")]
570        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
571            f.call()
572        }
573    }
574
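    /// Half-width (128-bit) variant of `V3` for 4-wide f16 work; most of its
    /// `MixedSimd` methods are still `todo!()`.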
575    #[derive(Debug, Copy, Clone)]
576    pub struct V3Half {
577        __private: (),
578    }
579
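    // `pulp::simd_type!` generates a token struct whose `try_new` performs
    // runtime detection of every listed feature and whose `vectorize` enables
    // them. The `f16` variants additionally require `f16c` for half-precision
    // conversions.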
580    #[cfg(feature = "f16")]
581    pulp::simd_type! {
582        pub struct V3 {
583            pub sse: "sse",
584            pub sse2: "sse2",
585            pub fxsr: "fxsr",
586            pub sse3: "sse3",
587            pub ssse3: "ssse3",
588            pub sse4_1: "sse4.1",
589            pub sse4_2: "sse4.2",
590            pub avx: "avx",
591            pub avx2: "avx2",
592            pub fma: "fma",
593            pub f16c: "f16c",
594        }
595    }
596
597    #[cfg(feature = "nightly")]
598    #[cfg(feature = "f16")]
599    pulp::simd_type! {
600        pub struct V4 {
601            pub sse: "sse",
602            pub sse2: "sse2",
603            pub fxsr: "fxsr",
604            pub sse3: "sse3",
605            pub ssse3: "ssse3",
606            pub sse4_1: "sse4.1",
607            pub sse4_2: "sse4.2",
608            pub avx: "avx",
609            pub avx2: "avx2",
610            pub fma: "fma",
611            pub f16c: "f16c",
612            pub avx512f: "avx512f",
613        }
614    }
615
616    #[cfg(not(feature = "f16"))]
617    pulp::simd_type! {
618        pub struct V3 {
619            pub sse: "sse",
620            pub sse2: "sse2",
621            pub fxsr: "fxsr",
622            pub sse3: "sse3",
623            pub ssse3: "ssse3",
624            pub sse4_1: "sse4.1",
625            pub sse4_2: "sse4.2",
626            pub avx: "avx",
627            pub avx2: "avx2",
628            pub fma: "fma",
629        }
630    }
631
632    #[cfg(feature = "nightly")]
633    #[cfg(not(feature = "f16"))]
634    pulp::simd_type! {
635        pub struct V4 {
636            pub sse: "sse",
637            pub sse2: "sse2",
638            pub fxsr: "fxsr",
639            pub sse3: "sse3",
640            pub ssse3: "ssse3",
641            pub sse4_1: "sse4.1",
642            pub sse4_2: "sse4.2",
643            pub avx: "avx",
644            pub avx2: "avx2",
645            pub fma: "fma",
646            pub avx512f: "avx512f",
647        }
648    }
649
650    impl Simd for V3Half {
651        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
652            f.call()
653        }
654    }
655
656    #[cfg(feature = "f16")]
657    unsafe impl MixedSimd<f16, f16, f16, f32> for V3Half {
658        const SIMD_WIDTH: usize = 4;
659
660        type LhsN = [f16; 4];
661        type RhsN = [f16; 4];
662        type DstN = [f16; 4];
663        type AccN = [f32; 4];
664
665        #[inline]
666        fn try_new() -> Option<Self> {
667            Some(Self { __private: () })
668        }
669
670        #[inline(always)]
671        fn mult(self, lhs: f32, rhs: f32) -> f32 {
672            lhs * rhs
673        }
674
675        #[inline(always)]
676        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
677            unsafe { v3_fmaf(lhs, rhs, acc) }
678        }
679
680        #[inline(always)]
681        fn from_lhs(self, _lhs: f16) -> f32 {
682            todo!()
683        }
684
685        #[inline(always)]
686        fn from_rhs(self, _rhs: f16) -> f32 {
687            todo!()
688        }
689
690        #[inline(always)]
691        fn from_dst(self, _dst: f16) -> f32 {
692            todo!()
693        }
694
695        #[inline(always)]
696        fn into_dst(self, _acc: f32) -> f16 {
697            todo!()
698        }
699
700        #[inline(always)]
701        fn simd_mult_add(self, _lhs: Self::AccN, _rhs: Self::AccN, _acc: Self::AccN) -> Self::AccN {
702            todo!()
703        }
704
705        #[inline(always)]
706        fn simd_from_lhs(self, _lhs: Self::LhsN) -> Self::AccN {
707            todo!()
708        }
709
710        #[inline(always)]
711        fn simd_from_rhs(self, _rhs: Self::RhsN) -> Self::AccN {
712            todo!()
713        }
714
715        #[inline(always)]
716        fn simd_splat(self, _lhs: f32) -> Self::AccN {
717            todo!()
718        }
719
720        #[inline(always)]
721        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
722            unsafe { cast(_mm_cvtph_ps(cast([dst, [f16::ZERO; 4]]))) }
723        }
724
725        #[inline(always)]
726        fn simd_into_dst(self, _acc: Self::AccN) -> Self::DstN {
727            todo!()
728        }
729
730        #[inline(always)]
731        fn vectorize<F: NullaryFnOnce>(self, _f: F) -> F::Output {
732            todo!()
733        }
734
735        #[inline(always)]
736        fn add(self, _lhs: f32, _rhs: f32) -> f32 {
737            todo!()
738        }
739
740        #[inline(always)]
741        fn simd_mul(self, _lhs: Self::AccN, _rhs: Self::AccN) -> Self::AccN {
742            todo!()
743        }
744
745        #[inline(always)]
746        fn simd_add(self, _lhs: Self::AccN, _rhs: Self::AccN) -> Self::AccN {
747            todo!()
748        }
749    }
750
751    #[cfg(feature = "f16")]
752    unsafe impl MixedSimd<f16, f16, f16, f32> for V3 {
753        const SIMD_WIDTH: usize = 8;
754
755        type LhsN = [f16; 8];
756        type RhsN = [f16; 8];
757        type DstN = [f16; 8];
758        type AccN = [f32; 8];
759
760        #[inline]
761        fn try_new() -> Option<Self> {
762            Self::try_new()
763        }
764
765        #[inline(always)]
766        fn mult(self, lhs: f32, rhs: f32) -> f32 {
767            lhs * rhs
768        }
769
770        #[inline(always)]
771        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
772            unsafe { v3_fmaf(lhs, rhs, acc) }
773        }
774
775        #[inline(always)]
776        fn from_lhs(self, lhs: f16) -> f32 {
777            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(lhs)))) }
778        }
779
780        #[inline(always)]
781        fn from_rhs(self, rhs: f16) -> f32 {
782            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(rhs)))) }
783        }
784
785        #[inline(always)]
786        fn from_dst(self, dst: f16) -> f32 {
787            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(dst)))) }
788        }
789
790        #[inline(always)]
791        fn into_dst(self, acc: f32) -> f16 {
792            unsafe {
793                pulp::cast_lossy(_mm_cvtps_ph::<_MM_FROUND_CUR_DIRECTION>(
794                    self.sse._mm_load_ss(&acc),
795                ))
796            }
797        }
798
799        #[inline(always)]
800        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
801            cast(self.fma._mm256_fmadd_ps(cast(lhs), cast(rhs), cast(acc)))
802        }
803
804        #[inline(always)]
805        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
806            unsafe { cast(_mm256_cvtph_ps(cast(lhs))) }
807        }
808
809        #[inline(always)]
810        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
811            unsafe { cast(_mm256_cvtph_ps(cast(rhs))) }
812        }
813
814        #[inline(always)]
815        fn simd_splat(self, lhs: f32) -> Self::AccN {
816            cast(self.avx._mm256_set1_ps(lhs))
817        }
818
819        #[inline(always)]
820        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
821            unsafe { cast(_mm256_cvtph_ps(cast(dst))) }
822        }
823
824        #[inline(always)]
825        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
826            unsafe { cast(_mm256_cvtps_ph::<_MM_FROUND_CUR_DIRECTION>(cast(acc))) }
827        }
828
829        #[inline(always)]
830        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
831            self.vectorize(f)
832        }
833
834        #[inline(always)]
835        fn add(self, lhs: f32, rhs: f32) -> f32 {
836            lhs + rhs
837        }
838
839        #[inline(always)]
840        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
841            cast(self.avx._mm256_mul_ps(cast(lhs), cast(rhs)))
842        }
843
844        #[inline(always)]
845        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
846            cast(self.avx._mm256_add_ps(cast(lhs), cast(rhs)))
847        }
848    }
849
850    unsafe impl MixedSimd<c32, c32, c32, c32> for V3 {
851        const SIMD_WIDTH: usize = 4;
852
853        type LhsN = [c32; 4];
854        type RhsN = [c32; 4];
855        type DstN = [c32; 4];
856        type AccN = [c32; 4];
857
858        #[inline]
859        fn try_new() -> Option<Self> {
860            Self::try_new()
861        }
862
863        #[inline(always)]
864        fn mult(self, lhs: c32, rhs: c32) -> c32 {
865            lhs * rhs
866        }
867
868        #[inline(always)]
869        fn mult_add(self, lhs: c32, rhs: c32, acc: c32) -> c32 {
870            lhs * rhs + acc
871        }
872
873        #[inline(always)]
874        fn from_lhs(self, lhs: c32) -> c32 {
875            lhs
876        }
877
878        #[inline(always)]
879        fn from_rhs(self, rhs: c32) -> c32 {
880            rhs
881        }
882
883        #[inline(always)]
884        fn from_dst(self, dst: c32) -> c32 {
885            dst
886        }
887
888        #[inline(always)]
889        fn into_dst(self, acc: c32) -> c32 {
890            acc
891        }
892
893        #[inline(always)]
894        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
895            unsafe {
896                let ab = cast(lhs);
897                let xy = cast(rhs);
898
899                let yx = _mm256_permute_ps::<0b10_11_00_01>(xy);
900                let aa = _mm256_moveldup_ps(ab);
901                let bb = _mm256_movehdup_ps(ab);
902
903                cast(_mm256_fmaddsub_ps(
904                    aa,
905                    xy,
906                    _mm256_fmaddsub_ps(bb, yx, cast(acc)),
907                ))
908            }
909        }
910
911        #[inline(always)]
912        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
913            lhs
914        }
915
916        #[inline(always)]
917        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
918            rhs
919        }
920
921        #[inline(always)]
922        fn simd_splat(self, lhs: c32) -> Self::AccN {
923            cast(self.avx._mm256_set1_pd(cast(lhs)))
924        }
925
926        #[inline(always)]
927        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
928            dst
929        }
930
931        #[inline(always)]
932        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
933            acc
934        }
935
936        #[inline(always)]
937        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
938            self.vectorize(f)
939        }
940
941        #[inline(always)]
942        fn add(self, lhs: c32, rhs: c32) -> c32 {
943            lhs + rhs
944        }
945
946        #[inline(always)]
947        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
948            unsafe {
949                let ab = cast(lhs);
950                let xy = cast(rhs);
951
952                let yx = _mm256_permute_ps::<0b10_11_00_01>(xy);
953                let aa = _mm256_moveldup_ps(ab);
954                let bb = _mm256_movehdup_ps(ab);
955
956                cast(_mm256_fmaddsub_ps(aa, xy, _mm256_mul_ps(bb, yx)))
957            }
958        }
959
960        #[inline(always)]
961        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
962            cast(self.avx._mm256_add_ps(cast(lhs), cast(rhs)))
963        }
964    }
965
966    unsafe impl MixedSimd<c64, c64, c64, c64> for V3 {
967        const SIMD_WIDTH: usize = 2;
968
969        type LhsN = [c64; 2];
970        type RhsN = [c64; 2];
971        type DstN = [c64; 2];
972        type AccN = [c64; 2];
973
974        #[inline]
975        fn try_new() -> Option<Self> {
976            Self::try_new()
977        }
978
979        #[inline(always)]
980        fn mult(self, lhs: c64, rhs: c64) -> c64 {
981            lhs * rhs
982        }
983
984        #[inline(always)]
985        fn mult_add(self, lhs: c64, rhs: c64, acc: c64) -> c64 {
986            lhs * rhs + acc
987        }
988
989        #[inline(always)]
990        fn from_lhs(self, lhs: c64) -> c64 {
991            lhs
992        }
993
994        #[inline(always)]
995        fn from_rhs(self, rhs: c64) -> c64 {
996            rhs
997        }
998
999        #[inline(always)]
1000        fn from_dst(self, dst: c64) -> c64 {
1001            dst
1002        }
1003
1004        #[inline(always)]
1005        fn into_dst(self, acc: c64) -> c64 {
1006            acc
1007        }
1008
1009        #[inline(always)]
1010        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
1011            unsafe {
1012                let ab = cast(lhs);
1013                let xy = cast(rhs);
1014
1015                let yx = _mm256_permute_pd::<0b0101>(xy);
1016                let aa = _mm256_unpacklo_pd(ab, ab);
1017                let bb = _mm256_unpackhi_pd(ab, ab);
1018
1019                cast(_mm256_fmaddsub_pd(
1020                    aa,
1021                    xy,
1022                    _mm256_fmaddsub_pd(bb, yx, cast(acc)),
1023                ))
1024            }
1025        }
1026
1027        #[inline(always)]
1028        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
1029            lhs
1030        }
1031
1032        #[inline(always)]
1033        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
1034            rhs
1035        }
1036
1037        #[inline(always)]
1038        fn simd_splat(self, lhs: c64) -> Self::AccN {
1039            cast([lhs; 2])
1040        }
1041
1042        #[inline(always)]
1043        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
1044            dst
1045        }
1046
1047        #[inline(always)]
1048        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
1049            acc
1050        }
1051
1052        #[inline(always)]
1053        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
1054            self.vectorize(f)
1055        }
1056
1057        #[inline(always)]
1058        fn add(self, lhs: c64, rhs: c64) -> c64 {
1059            lhs + rhs
1060        }
1061
1062        #[inline(always)]
1063        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1064            unsafe {
1065                let ab = cast(lhs);
1066                let xy = cast(rhs);
1067
1068                let yx = _mm256_permute_pd::<0b0101>(xy);
1069                let aa = _mm256_unpacklo_pd(ab, ab);
1070                let bb = _mm256_unpackhi_pd(ab, ab);
1071
1072                cast(_mm256_fmaddsub_pd(aa, xy, _mm256_mul_pd(bb, yx)))
1073            }
1074        }
1075
1076        #[inline(always)]
1077        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1078            cast(self.avx._mm256_add_pd(cast(lhs), cast(rhs)))
1079        }
1080    }
1081
1082    #[cfg(feature = "nightly")]
1083    unsafe impl MixedSimd<c32, c32, c32, c32> for V4 {
1084        const SIMD_WIDTH: usize = 8;
1085
1086        type LhsN = [c32; 8];
1087        type RhsN = [c32; 8];
1088        type DstN = [c32; 8];
1089        type AccN = [c32; 8];
1090
1091        #[inline]
1092        fn try_new() -> Option<Self> {
1093            Self::try_new()
1094        }
1095
1096        #[inline(always)]
1097        fn mult(self, lhs: c32, rhs: c32) -> c32 {
1098            lhs * rhs
1099        }
1100
1101        #[inline(always)]
1102        fn mult_add(self, lhs: c32, rhs: c32, acc: c32) -> c32 {
1103            lhs * rhs + acc
1104        }
1105
1106        #[inline(always)]
1107        fn from_lhs(self, lhs: c32) -> c32 {
1108            lhs
1109        }
1110
1111        #[inline(always)]
1112        fn from_rhs(self, rhs: c32) -> c32 {
1113            rhs
1114        }
1115
1116        #[inline(always)]
1117        fn from_dst(self, dst: c32) -> c32 {
1118            dst
1119        }
1120
1121        #[inline(always)]
1122        fn into_dst(self, acc: c32) -> c32 {
1123            acc
1124        }
1125
1126        #[inline(always)]
1127        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
1128            unsafe {
1129                let ab = cast(lhs);
1130                let xy = cast(rhs);
1131
1132                let yx = _mm512_permute_ps::<0b10_11_00_01>(xy);
1133                let aa = _mm512_moveldup_ps(ab);
1134                let bb = _mm512_movehdup_ps(ab);
1135
1136                cast(_mm512_fmaddsub_ps(
1137                    aa,
1138                    xy,
1139                    _mm512_fmaddsub_ps(bb, yx, cast(acc)),
1140                ))
1141            }
1142        }
1143
1144        #[inline(always)]
1145        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
1146            lhs
1147        }
1148
1149        #[inline(always)]
1150        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
1151            rhs
1152        }
1153
1154        #[inline(always)]
1155        fn simd_splat(self, lhs: c32) -> Self::AccN {
1156            cast(self.avx512f._mm512_set1_pd(cast(lhs)))
1157        }
1158
1159        #[inline(always)]
1160        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
1161            dst
1162        }
1163
1164        #[inline(always)]
1165        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
1166            acc
1167        }
1168
1169        #[inline(always)]
1170        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
1171            self.vectorize(f)
1172        }
1173
1174        #[inline(always)]
1175        fn add(self, lhs: c32, rhs: c32) -> c32 {
1176            lhs + rhs
1177        }
1178
1179        #[inline(always)]
1180        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1181            unsafe {
1182                let ab = cast(lhs);
1183                let xy = cast(rhs);
1184
1185                let yx = _mm512_permute_ps::<0b10_11_00_01>(xy);
1186                let aa = _mm512_moveldup_ps(ab);
1187                let bb = _mm512_movehdup_ps(ab);
1188
1189                cast(_mm512_fmaddsub_ps(aa, xy, _mm512_mul_ps(bb, yx)))
1190            }
1191        }
1192
1193        #[inline(always)]
1194        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1195            cast(self.avx512f._mm512_add_ps(cast(lhs), cast(rhs)))
1196        }
1197    }
1198
1199    #[cfg(feature = "nightly")]
1200    unsafe impl MixedSimd<c64, c64, c64, c64> for V4 {
1201        const SIMD_WIDTH: usize = 4;
1202
1203        type LhsN = [c64; 4];
1204        type RhsN = [c64; 4];
1205        type DstN = [c64; 4];
1206        type AccN = [c64; 4];
1207
1208        #[inline]
1209        fn try_new() -> Option<Self> {
1210            Self::try_new()
1211        }
1212
1213        #[inline(always)]
1214        fn mult(self, lhs: c64, rhs: c64) -> c64 {
1215            lhs * rhs
1216        }
1217
1218        #[inline(always)]
1219        fn mult_add(self, lhs: c64, rhs: c64, acc: c64) -> c64 {
1220            lhs * rhs + acc
1221        }
1222
1223        #[inline(always)]
1224        fn from_lhs(self, lhs: c64) -> c64 {
1225            lhs
1226        }
1227
1228        #[inline(always)]
1229        fn from_rhs(self, rhs: c64) -> c64 {
1230            rhs
1231        }
1232
1233        #[inline(always)]
1234        fn from_dst(self, dst: c64) -> c64 {
1235            dst
1236        }
1237
1238        #[inline(always)]
1239        fn into_dst(self, acc: c64) -> c64 {
1240            acc
1241        }
1242
1243        #[inline(always)]
1244        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
1245            unsafe {
1246                let ab = cast(lhs);
1247                let xy = cast(rhs);
1248
1249                let yx = _mm512_permute_pd::<0b01010101>(xy);
1250                let aa = _mm512_unpacklo_pd(ab, ab);
1251                let bb = _mm512_unpackhi_pd(ab, ab);
1252
1253                cast(_mm512_fmaddsub_pd(
1254                    aa,
1255                    xy,
1256                    _mm512_fmaddsub_pd(bb, yx, cast(acc)),
1257                ))
1258            }
1259        }
1260
1261        #[inline(always)]
1262        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
1263            lhs
1264        }
1265
1266        #[inline(always)]
1267        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
1268            rhs
1269        }
1270
1271        #[inline(always)]
1272        fn simd_splat(self, lhs: c64) -> Self::AccN {
1273            cast([lhs; 4])
1274        }
1275
1276        #[inline(always)]
1277        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
1278            dst
1279        }
1280
1281        #[inline(always)]
1282        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
1283            acc
1284        }
1285
1286        #[inline(always)]
1287        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
1288            self.vectorize(f)
1289        }
1290
1291        #[inline(always)]
1292        fn add(self, lhs: c64, rhs: c64) -> c64 {
1293            lhs + rhs
1294        }
1295
1296        #[inline(always)]
1297        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1298            unsafe {
1299                let ab = cast(lhs);
1300                let xy = cast(rhs);
1301
1302                let yx = _mm512_permute_pd::<0b01010101>(xy);
1303                let aa = _mm512_unpacklo_pd(ab, ab);
1304                let bb = _mm512_unpackhi_pd(ab, ab);
1305
1306                cast(_mm512_fmaddsub_pd(aa, xy, _mm512_mul_pd(bb, yx)))
1307            }
1308        }
1309
1310        #[inline(always)]
1311        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1312            cast(self.avx512f._mm512_add_pd(cast(lhs), cast(rhs)))
1313        }
1314    }
1315
1316    unsafe impl MixedSimd<f32, f32, f32, f32> for V3 {
1317        const SIMD_WIDTH: usize = 8;
1318
1319        type LhsN = [f32; 8];
1320        type RhsN = [f32; 8];
1321        type DstN = [f32; 8];
1322        type AccN = [f32; 8];
1323
1324        #[inline]
1325        fn try_new() -> Option<Self> {
1326            Self::try_new()
1327        }
1328
1329        #[inline(always)]
1330        fn mult(self, lhs: f32, rhs: f32) -> f32 {
1331            lhs * rhs
1332        }
1333
1334        #[inline(always)]
1335        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
1336            unsafe { v3_fmaf(lhs, rhs, acc) }
1337        }
1338
1339        #[inline(always)]
1340        fn from_lhs(self, lhs: f32) -> f32 {
1341            lhs
1342        }
1343
1344        #[inline(always)]
1345        fn from_rhs(self, rhs: f32) -> f32 {
1346            rhs
1347        }
1348
1349        #[inline(always)]
1350        fn from_dst(self, dst: f32) -> f32 {
1351            dst
1352        }
1353
1354        #[inline(always)]
1355        fn into_dst(self, acc: f32) -> f32 {
1356            acc
1357        }
1358
1359        #[inline(always)]
1360        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
1361            cast(self.fma._mm256_fmadd_ps(cast(lhs), cast(rhs), cast(acc)))
1362        }
1363
1364        #[inline(always)]
1365        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
1366            lhs
1367        }
1368
1369        #[inline(always)]
1370        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
1371            rhs
1372        }
1373
1374        #[inline(always)]
1375        fn simd_splat(self, lhs: f32) -> Self::AccN {
1376            cast(self.avx._mm256_set1_ps(lhs))
1377        }
1378
1379        #[inline(always)]
1380        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
1381            dst
1382        }
1383
1384        #[inline(always)]
1385        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
1386            acc
1387        }
1388
1389        #[inline(always)]
1390        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
1391            self.vectorize(f)
1392        }
1393
1394        #[inline(always)]
1395        fn add(self, lhs: f32, rhs: f32) -> f32 {
1396            lhs + rhs
1397        }
1398
1399        #[inline(always)]
1400        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1401            cast(self.avx._mm256_mul_ps(cast(lhs), cast(rhs)))
1402        }
1403
1404        #[inline(always)]
1405        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1406            cast(self.avx._mm256_add_ps(cast(lhs), cast(rhs)))
1407        }
1408    }
1409
1410    unsafe impl MixedSimd<f64, f64, f64, f64> for V3 {
1411        const SIMD_WIDTH: usize = 4;
1412
1413        type LhsN = [f64; 4];
1414        type RhsN = [f64; 4];
1415        type DstN = [f64; 4];
1416        type AccN = [f64; 4];
1417
1418        #[inline]
1419        fn try_new() -> Option<Self> {
1420            Self::try_new()
1421        }
1422
1423        #[inline(always)]
1424        fn mult(self, lhs: f64, rhs: f64) -> f64 {
1425            lhs * rhs
1426        }
1427
1428        #[inline(always)]
1429        fn mult_add(self, lhs: f64, rhs: f64, acc: f64) -> f64 {
1430            unsafe { v3_fma(lhs, rhs, acc) }
1431        }
1432
1433        #[inline(always)]
1434        fn from_lhs(self, lhs: f64) -> f64 {
1435            lhs
1436        }
1437
1438        #[inline(always)]
1439        fn from_rhs(self, rhs: f64) -> f64 {
1440            rhs
1441        }
1442
1443        #[inline(always)]
1444        fn from_dst(self, dst: f64) -> f64 {
1445            dst
1446        }
1447
1448        #[inline(always)]
1449        fn into_dst(self, acc: f64) -> f64 {
1450            acc
1451        }
1452
1453        #[inline(always)]
1454        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
1455            cast(self.fma._mm256_fmadd_pd(cast(lhs), cast(rhs), cast(acc)))
1456        }
1457
1458        #[inline(always)]
1459        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
1460            lhs
1461        }
1462
1463        #[inline(always)]
1464        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
1465            rhs
1466        }
1467
1468        #[inline(always)]
1469        fn simd_splat(self, lhs: f64) -> Self::AccN {
1470            cast(self.avx._mm256_set1_pd(lhs))
1471        }
1472
1473        #[inline(always)]
1474        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
1475            dst
1476        }
1477
1478        #[inline(always)]
1479        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
1480            acc
1481        }
1482
1483        #[inline(always)]
1484        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
1485            self.vectorize(f)
1486        }
1487
1488        #[inline(always)]
1489        fn add(self, lhs: f64, rhs: f64) -> f64 {
1490            lhs + rhs
1491        }
1492
1493        #[inline(always)]
1494        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1495            cast(self.avx._mm256_mul_pd(cast(lhs), cast(rhs)))
1496        }
1497
1498        #[inline(always)]
1499        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1500            cast(self.avx._mm256_add_pd(cast(lhs), cast(rhs)))
1501        }
1502    }
1503
1504    impl Simd for V3 {
1505        #[inline(always)]
1506        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
1507            Self::new_unchecked().vectorize(f)
1508        }
1509    }
1510
1511    #[cfg(feature = "nightly")]
1512    impl Simd for V4 {
1513        #[inline(always)]
1514        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
1515            Self::new_unchecked().vectorize(f)
1516        }
1517    }
1518
1519    #[cfg(feature = "nightly")]
1520    #[cfg(feature = "f16")]
1521    unsafe impl MixedSimd<f16, f16, f16, f32> for V4 {
1522        const SIMD_WIDTH: usize = 16;
1523
1524        type LhsN = [f16; 16];
1525        type RhsN = [f16; 16];
1526        type DstN = [f16; 16];
1527        type AccN = [f32; 16];
1528
1529        #[inline]
1530        fn try_new() -> Option<Self> {
1531            Self::try_new()
1532        }
1533
1534        #[inline(always)]
1535        fn mult(self, lhs: f32, rhs: f32) -> f32 {
1536            lhs * rhs
1537        }
1538
1539        #[inline(always)]
1540        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
1541            unsafe { v3_fmaf(lhs, rhs, acc) }
1542        }
1543
1544        #[inline(always)]
1545        fn from_lhs(self, lhs: f16) -> f32 {
1546            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(lhs)))) }
1547        }
1548
1549        #[inline(always)]
1550        fn from_rhs(self, rhs: f16) -> f32 {
1551            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(rhs)))) }
1552        }
1553
1554        #[inline(always)]
1555        fn from_dst(self, dst: f16) -> f32 {
1556            unsafe { pulp::cast_lossy(_mm_cvtph_ps(self.sse2._mm_set1_epi16(cast(dst)))) }
1557        }
1558
1559        #[inline(always)]
1560        fn into_dst(self, acc: f32) -> f16 {
1561            unsafe {
1562                pulp::cast_lossy(_mm_cvtps_ph::<_MM_FROUND_CUR_DIRECTION>(
1563                    self.sse._mm_load_ss(&acc),
1564                ))
1565            }
1566        }
1567
1568        #[inline(always)]
1569        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
1570            cast(
1571                self.avx512f
1572                    ._mm512_fmadd_ps(cast(lhs), cast(rhs), cast(acc)),
1573            )
1574        }
1575
1576        #[inline(always)]
1577        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
1578            unsafe { cast(_mm512_cvtph_ps(cast(lhs))) }
1579        }
1580
1581        #[inline(always)]
1582        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
1583            unsafe { cast(_mm512_cvtph_ps(cast(rhs))) }
1584        }
1585
1586        #[inline(always)]
1587        fn simd_splat(self, lhs: f32) -> Self::AccN {
1588            cast(self.avx512f._mm512_set1_ps(lhs))
1589        }
1590
1591        #[inline(always)]
1592        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
1593            unsafe { cast(_mm512_cvtph_ps(cast(dst))) }
1594        }
1595
1596        #[inline(always)]
1597        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
1598            unsafe { cast(_mm512_cvtps_ph::<_MM_FROUND_CUR_DIRECTION>(cast(acc))) }
1599        }
1600
1601        #[inline(always)]
1602        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
1603            self.vectorize(f)
1604        }
1605
1606        #[inline(always)]
1607        fn add(self, lhs: f32, rhs: f32) -> f32 {
1608            lhs + rhs
1609        }
1610
1611        #[inline(always)]
1612        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1613            cast(self.avx512f._mm512_mul_ps(cast(lhs), cast(rhs)))
1614        }
1615
1616        #[inline(always)]
1617        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1618            cast(self.avx512f._mm512_add_ps(cast(lhs), cast(rhs)))
1619        }
1620    }
1621
1622    #[cfg(feature = "nightly")]
1623    unsafe impl MixedSimd<f32, f32, f32, f32> for V4 {
1624        const SIMD_WIDTH: usize = 16;
1625
1626        type LhsN = [f32; 16];
1627        type RhsN = [f32; 16];
1628        type DstN = [f32; 16];
1629        type AccN = [f32; 16];
1630
1631        #[inline]
1632        fn try_new() -> Option<Self> {
1633            Self::try_new()
1634        }
1635
1636        #[inline(always)]
1637        fn mult(self, lhs: f32, rhs: f32) -> f32 {
1638            lhs * rhs
1639        }
1640
1641        #[inline(always)]
1642        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
1643            unsafe { v3_fmaf(lhs, rhs, acc) }
1644        }
1645
1646        #[inline(always)]
1647        fn from_lhs(self, lhs: f32) -> f32 {
1648            lhs
1649        }
1650
1651        #[inline(always)]
1652        fn from_rhs(self, rhs: f32) -> f32 {
1653            rhs
1654        }
1655
1656        #[inline(always)]
1657        fn from_dst(self, dst: f32) -> f32 {
1658            dst
1659        }
1660
1661        #[inline(always)]
1662        fn into_dst(self, acc: f32) -> f32 {
1663            acc
1664        }
1665
1666        #[inline(always)]
1667        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
1668            cast(
1669                self.avx512f
1670                    ._mm512_fmadd_ps(cast(lhs), cast(rhs), cast(acc)),
1671            )
1672        }
1673
1674        #[inline(always)]
1675        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
1676            lhs
1677        }
1678
1679        #[inline(always)]
1680        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
1681            rhs
1682        }
1683
1684        #[inline(always)]
1685        fn simd_splat(self, lhs: f32) -> Self::AccN {
1686            cast(self.avx512f._mm512_set1_ps(lhs))
1687        }
1688
1689        #[inline(always)]
1690        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
1691            dst
1692        }
1693
1694        #[inline(always)]
1695        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
1696            acc
1697        }
1698
1699        #[inline(always)]
1700        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
1701            self.vectorize(f)
1702        }
1703
1704        #[inline(always)]
1705        fn add(self, lhs: f32, rhs: f32) -> f32 {
1706            lhs + rhs
1707        }
1708
1709        #[inline(always)]
1710        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1711            cast(self.avx512f._mm512_mul_ps(cast(lhs), cast(rhs)))
1712        }
1713
1714        #[inline(always)]
1715        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1716            cast(self.avx512f._mm512_add_ps(cast(lhs), cast(rhs)))
1717        }
1718    }
1719
1720    #[cfg(feature = "nightly")]
1721    unsafe impl MixedSimd<f64, f64, f64, f64> for V4 {
1722        const SIMD_WIDTH: usize = 8;
1723
1724        type LhsN = [f64; 8];
1725        type RhsN = [f64; 8];
1726        type DstN = [f64; 8];
1727        type AccN = [f64; 8];
1728
1729        #[inline]
1730        fn try_new() -> Option<Self> {
1731            Self::try_new()
1732        }
1733
1734        #[inline(always)]
1735        fn mult(self, lhs: f64, rhs: f64) -> f64 {
1736            lhs * rhs
1737        }
1738
1739        #[inline(always)]
1740        fn mult_add(self, lhs: f64, rhs: f64, acc: f64) -> f64 {
1741            unsafe { v3_fma(lhs, rhs, acc) }
1742        }
1743
1744        #[inline(always)]
1745        fn from_lhs(self, lhs: f64) -> f64 {
1746            lhs
1747        }
1748
1749        #[inline(always)]
1750        fn from_rhs(self, rhs: f64) -> f64 {
1751            rhs
1752        }
1753
1754        #[inline(always)]
1755        fn from_dst(self, dst: f64) -> f64 {
1756            dst
1757        }
1758
1759        #[inline(always)]
1760        fn into_dst(self, acc: f64) -> f64 {
1761            acc
1762        }
1763
1764        #[inline(always)]
1765        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
1766            cast(
1767                self.avx512f
1768                    ._mm512_fmadd_pd(cast(lhs), cast(rhs), cast(acc)),
1769            )
1770        }
1771
1772        #[inline(always)]
1773        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
1774            lhs
1775        }
1776
1777        #[inline(always)]
1778        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
1779            rhs
1780        }
1781
1782        #[inline(always)]
1783        fn simd_splat(self, lhs: f64) -> Self::AccN {
1784            cast(self.avx512f._mm512_set1_pd(lhs))
1785        }
1786
1787        #[inline(always)]
1788        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
1789            dst
1790        }
1791
1792        #[inline(always)]
1793        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
1794            acc
1795        }
1796
1797        #[inline(always)]
1798        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
1799            self.vectorize(f)
1800        }
1801
1802        #[inline(always)]
1803        fn add(self, lhs: f64, rhs: f64) -> f64 {
1804            lhs + rhs
1805        }
1806
1807        #[inline(always)]
1808        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1809            cast(self.avx512f._mm512_mul_pd(cast(lhs), cast(rhs)))
1810        }
1811
1812        #[inline(always)]
1813        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
1814            cast(self.avx512f._mm512_add_pd(cast(lhs), cast(rhs)))
1815        }
1816    }
1817}
1818
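/// AArch64 backends: NEON tokens plus inline-assembly helpers for the `fp16`
/// (half-precision arithmetic) extension, for which stable intrinsics are not
/// yet available.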
1819#[cfg(target_arch = "aarch64")]
1820pub mod aarch64 {
1821    use super::*;
1822    use core::arch::aarch64::*;
1823    use core::arch::asm;
1824    #[allow(unused_imports)]
1825    use core::mem::transmute;
1826    use core::mem::MaybeUninit;
1827    use core::ptr;
1828
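    // Scalar fma helpers; without `std` these fall back to an unfused
    // multiply-add instead of pulling in `libm`.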
1829    #[inline(always)]
1830    pub unsafe fn neon_fmaf(a: f32, b: f32, c: f32) -> f32 {
1831        #[cfg(feature = "std")]
1832        {
1833            f32::mul_add(a, b, c)
1834        }
1835        #[cfg(not(feature = "std"))]
1836        {
1837            a * b + c
1838        }
1839    }
1840
1841    #[inline(always)]
1842    pub unsafe fn neon_fma(a: f64, b: f64, c: f64) -> f64 {
1843        #[cfg(feature = "std")]
1844        {
1845            f64::mul_add(a, b, c)
1846        }
1847        #[cfg(not(feature = "std"))]
1848        {
1849            a * b + c
1850        }
1851    }
1852
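    // Scalar and 4-wide f16 <-> f32 conversions via `fcvt`/`fcvtl`/`fcvtn`;
    // callers must ensure the `fp16` target feature is available.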
1853    #[target_feature(enable = "fp16,neon")]
1854    #[inline]
1855    pub unsafe fn f16_to_f32_fp16(i: u16) -> f32 {
1856        let result: f32;
1857        asm!(
1858        "fcvt {0:s}, {1:h}",
1859        out(vreg) result,
1860        in(vreg) i,
1861        options(pure, nomem, nostack));
1862        result
1863    }
1864
1865    #[target_feature(enable = "fp16,neon")]
1866    #[inline]
1867    pub unsafe fn f32_to_f16_fp16(f: f32) -> u16 {
1868        let result: u16;
1869        asm!(
1870        "fcvt {0:h}, {1:s}",
1871        out(vreg) result,
1872        in(vreg) f,
1873        options(pure, nomem, nostack));
1874        result
1875    }
1876
1877    #[target_feature(enable = "fp16,neon")]
1878    #[inline]
1879    pub unsafe fn f16x4_to_f32x4_fp16(v: &[u16; 4]) -> [f32; 4] {
1880        let mut vec = MaybeUninit::<uint16x4_t>::uninit();
1881        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
1882        let result: float32x4_t;
1883        asm!(
1884        "fcvtl {0:v}.4s, {1:v}.4h",
1885        out(vreg) result,
1886        in(vreg) vec.assume_init(),
1887        options(pure, nomem, nostack));
1888        *(&result as *const float32x4_t).cast()
1889    }
1890
1891    #[target_feature(enable = "fp16,neon")]
1892    #[inline]
1893    pub unsafe fn f32x4_to_f16x4_fp16(v: &[f32; 4]) -> [u16; 4] {
1894        let mut vec = MaybeUninit::<float32x4_t>::uninit();
1895        ptr::copy_nonoverlapping(v.as_ptr(), vec.as_mut_ptr().cast(), 4);
1896        let result: uint16x4_t;
1897        asm!(
1898        "fcvtn {0:v}.4h, {1:v}.4s",
1899        out(vreg) result,
1900        in(vreg) vec.assume_init(),
1901        options(pure, nomem, nostack));
1902        *(&result as *const uint16x4_t).cast()
1903    }
1904
1905    #[target_feature(enable = "fp16")]
1906    #[inline]
1907    pub unsafe fn add_f16_fp16(a: u16, b: u16) -> u16 {
1908        let result: u16;
1909        asm!(
1910        "fadd {0:h}, {1:h}, {2:h}",
1911        out(vreg) result,
1912        in(vreg) a,
1913        in(vreg) b,
1914        options(pure, nomem, nostack));
1915        result
1916    }
1917
1918    #[target_feature(enable = "fp16")]
1919    #[inline]
1920    pub unsafe fn fmaq_f16(mut a: u16, b: u16, c: u16) -> u16 {
1921        asm!(
1922        "fmadd {0:h}, {1:h}, {2:h}, {0:h}",
1923        inout(vreg) a,
1924        in(vreg) b,
1925        in(vreg) c,
1926        options(pure, nomem, nostack));
1927        a
1928    }
1929
1930    #[target_feature(enable = "fp16")]
1931    #[inline]
1932    pub unsafe fn multiply_f16_fp16(a: u16, b: u16) -> u16 {
1933        let result: u16;
1934        asm!(
1935        "fmul {0:h}, {1:h}, {2:h}",
1936        out(vreg) result,
1937        in(vreg) a,
1938        in(vreg) b,
1939        options(pure, nomem, nostack));
1940        result
1941    }
1942
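    // `core::arch::aarch64` has no stable `float16x8_t`, so the bit-compatible
    // `uint16x8_t` layout is reused and operated on through inline asm.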
1943    #[allow(non_camel_case_types)]
1944    type float16x8_t = uint16x8_t;
1945
1946    /// Floating point multiplication
1947    /// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FMUL--vector-)
1948    #[inline]
1949    pub unsafe fn vmulq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t {
1950        let result: float16x8_t;
1951        asm!(
1952                "fmul {0:v}.8h, {1:v}.8h, {2:v}.8h",
1953                out(vreg) result,
1954                in(vreg) a,
1955                in(vreg) b,
1956                options(pure, nomem, nostack));
1957        result
1958    }
1959
1960    /// Floating point addition
1961    /// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FADD--vector-)
1962    #[inline]
1963    pub unsafe fn vaddq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t {
1964        let result: float16x8_t;
1965        asm!(
1966                "fadd {0:v}.8h, {1:v}.8h, {2:v}.8h",
1967                out(vreg) result,
1968                in(vreg) a,
1969                in(vreg) b,
1970                options(pure, nomem, nostack));
1971        result
1972    }
1973
    /// Fused multiply add
    /// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FMLA--vector-)
1975    #[inline]
1976    pub unsafe fn vfmaq_f16(mut a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float16x8_t {
1977        asm!(
1978                "fmla {0:v}.8h, {1:v}.8h, {2:v}.8h",
1979                inout(vreg) a,
1980                in(vreg) b,
1981                in(vreg) c,
1982                options(pure, nomem, nostack));
1983        a
1984    }
1985
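    /// Fused multiply add, broadcasting lane `LANE` of `c`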
1986    #[inline]
1987    pub unsafe fn vfmaq_laneq_f16<const LANE: i32>(
1988        mut a: float16x8_t,
1989        b: float16x8_t,
1990        c: float16x8_t,
1991    ) -> float16x8_t {
1992        match LANE {
1993            0 => asm!(
1994                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[0]",
1995                inout(vreg) a,
1996                in(vreg) b,
1997                in(vreg_low16) c,
1998                options(pure, nomem, nostack)),
1999            1 => asm!(
2000                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[1]",
2001                inout(vreg) a,
2002                in(vreg) b,
2003                in(vreg_low16) c,
2004                options(pure, nomem, nostack)),
2005            2 => asm!(
2006                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[2]",
2007                inout(vreg) a,
2008                in(vreg) b,
2009                in(vreg_low16) c,
2010                options(pure, nomem, nostack)),
2011            3 => asm!(
2012                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[3]",
2013                inout(vreg) a,
2014                in(vreg) b,
2015                in(vreg_low16) c,
2016                options(pure, nomem, nostack)),
2017            4 => asm!(
2018                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[4]",
2019                inout(vreg) a,
2020                in(vreg) b,
2021                in(vreg_low16) c,
2022                options(pure, nomem, nostack)),
2023            5 => asm!(
2024                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[5]",
2025                inout(vreg) a,
2026                in(vreg) b,
2027                in(vreg_low16) c,
2028                options(pure, nomem, nostack)),
2029            6 => asm!(
2030                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[6]",
2031                inout(vreg) a,
2032                in(vreg) b,
2033                in(vreg_low16) c,
2034                options(pure, nomem, nostack)),
2035            7 => asm!(
2036                "fmla {0:v}.8h, {1:v}.8h, {2:v}.h[7]",
2037                inout(vreg) a,
2038                in(vreg) b,
2039                in(vreg_low16) c,
2040                options(pure, nomem, nostack)),
2041            _ => unreachable!(),
2042        }
2043        a
2044    }
2045
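    // Illustrative sketch (assumption, not original code): broadcasting lane 3 of `c`
    // across a multiply-accumulate, moving data between `half::f16` values and the raw
    // `u16` lanes that `float16x8_t` is built from.
    #[cfg(all(test, feature = "f16"))]
    mod lane_fma_sketch {
        use super::vfmaq_laneq_f16;
        use core::mem::transmute;
        use half::f16;

        #[test]
        fn lane_broadcast_fma() {
            if !crate::feature_detected!("fp16") {
                return; // skip on CPUs without FEAT_FP16
            }
            let bits = |x: f32| f16::from_f32(x).to_bits();
            let acc = [bits(1.0); 8];
            let b = [bits(2.0); 8];
            let mut c = [bits(0.0); 8];
            c[3] = bits(3.0);
            // Each lane becomes acc[i] + b[i] * c[3] = 1.0 + 2.0 * 3.0 = 7.0.
            let got: [u16; 8] = unsafe {
                transmute(vfmaq_laneq_f16::<3>(transmute(acc), transmute(b), transmute(c)))
            };
            assert_eq!(got, [bits(7.0); 8]);
        }
    }

    // `Neon`, `NeonFp16`, and `NeonFcma` below are zero-sized capability tokens: outside
    // this module they can only be obtained through the `try_new` feature checks, so
    // holding one serves as evidence that the corresponding target features are available
    // when the `unsafe` asm paths run.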
    #[derive(Copy, Clone, Debug)]
    pub struct Neon {
        __private: (),
    }

    #[derive(Copy, Clone, Debug)]
    pub struct NeonFp16 {
        __private: (),
    }

    #[derive(Copy, Clone, Debug)]
    pub struct NeonFcma {
        __private: (),
    }

    impl Simd for Neon {
        #[inline]
        #[target_feature(enable = "neon")]
        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
            f.call()
        }
    }

    impl Simd for NeonFp16 {
        #[inline]
        #[target_feature(enable = "neon,fp16")]
        unsafe fn vectorize<F: NullaryFnOnce>(f: F) -> F::Output {
            f.call()
        }
    }
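
    // Illustrative sketch (assumption, not original code): `Simd::vectorize` exists so a
    // whole kernel can be compiled with the backend's target features enabled. A caller
    // wraps its computation in a `NullaryFnOnce` and passes it to the backend obtained
    // from feature detection; `SumSlice` here is a hypothetical example struct.
    #[cfg(test)]
    mod vectorize_sketch {
        use super::super::{NullaryFnOnce, Simd};
        use super::Neon;

        struct SumSlice<'a>(&'a [f32]);

        impl NullaryFnOnce for SumSlice<'_> {
            type Output = f32;

            #[inline(always)]
            fn call(self) -> f32 {
                self.0.iter().sum()
            }
        }

        #[test]
        fn sum_under_neon() {
            if !crate::feature_detected!("neon") {
                return; // skip if NEON is somehow unavailable
            }
            let data = [1.0f32, 2.0, 3.0, 4.0];
            // Safety: the `neon` feature was detected at runtime just above.
            let total = unsafe { Neon::vectorize(SumSlice(&data)) };
            assert_eq!(total, 10.0);
        }
    }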

    #[cfg(feature = "f16")]
    unsafe impl MixedSimd<f16, f16, f16, f32> for NeonFp16 {
        const SIMD_WIDTH: usize = 4;

        type LhsN = [f16; 4];
        type RhsN = [f16; 4];
        type DstN = [f16; 4];
        type AccN = [f32; 4];

        #[inline]
        fn try_new() -> Option<Self> {
            if crate::feature_detected!("neon") && crate::feature_detected!("fp16") {
                Some(Self { __private: () })
            } else {
                None
            }
        }

        #[inline(always)]
        fn mult(self, lhs: f32, rhs: f32) -> f32 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
            unsafe { neon_fmaf(lhs, rhs, acc) }
        }

        #[inline(always)]
        fn from_lhs(self, lhs: f16) -> f32 {
            unsafe { f16_to_f32_fp16(cast(lhs)) }
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f16) -> f32 {
            unsafe { f16_to_f32_fp16(cast(rhs)) }
        }

        #[inline(always)]
        fn from_dst(self, dst: f16) -> f32 {
            unsafe { f16_to_f32_fp16(cast(dst)) }
        }

        #[inline(always)]
        fn into_dst(self, acc: f32) -> f16 {
            unsafe { cast(f32_to_f16_fp16(acc)) }
        }

        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            unsafe { transmute(vfmaq_f32(transmute(acc), transmute(lhs), transmute(rhs))) }
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            unsafe { f16x4_to_f32x4_fp16(&cast(lhs)) }
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            unsafe { f16x4_to_f32x4_fp16(&cast(rhs)) }
        }

        #[inline(always)]
        fn simd_splat(self, lhs: f32) -> Self::AccN {
            [lhs, lhs, lhs, lhs]
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            unsafe { f16x4_to_f32x4_fp16(&cast(dst)) }
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            unsafe { cast(f32x4_to_f16x4_fp16(&acc)) }
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            #[inline]
            #[target_feature(enable = "neon,fp16")]
            unsafe fn implementation<F: NullaryFnOnce>(f: F) -> F::Output {
                f.call()
            }

            unsafe { implementation(f) }
        }

        #[inline(always)]
        fn add(self, lhs: f32, rhs: f32) -> f32 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vmulq_f32(transmute(lhs), transmute(rhs))) }
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vaddq_f32(transmute(lhs), transmute(rhs))) }
        }
    }
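
    // Illustrative check (assumption, not original code): the widening hooks of
    // `MixedSimd<f16, f16, f16, f32> for NeonFp16` should agree with the software
    // conversion done by the `half` crate. The generic helpers only pin down which of the
    // two `MixedSimd` impls on `NeonFp16` is meant.
    #[cfg(all(test, feature = "f16"))]
    mod neon_fp16_widen_sketch {
        use super::super::MixedSimd;
        use super::NeonFp16;
        use half::f16;

        fn widen<S: MixedSimd<f16, f16, f16, f32>>(simd: S, x: f16) -> f32 {
            simd.from_lhs(x)
        }

        fn widen4<S: MixedSimd<f16, f16, f16, f32>>(simd: S, x: S::LhsN) -> S::AccN {
            simd.simd_from_lhs(x)
        }

        #[test]
        fn widening_matches_half_crate() {
            let simd = match <NeonFp16 as MixedSimd<f16, f16, f16, f32>>::try_new() {
                Some(simd) => simd,
                None => return, // no FEAT_FP16 on this CPU
            };
            let x = f16::from_f32(1.25);
            assert_eq!(widen(simd, x), x.to_f32());
            assert_eq!(widen4(simd, [x; 4]), [x.to_f32(); 4]);
        }
    }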

    #[cfg(feature = "f16")]
    unsafe impl MixedSimd<f16, f16, f16, f16> for NeonFp16 {
        const SIMD_WIDTH: usize = 8;

        type LhsN = [f16; 8];
        type RhsN = [f16; 8];
        type DstN = [f16; 8];
        type AccN = [f16; 8];

        #[inline]
        fn try_new() -> Option<Self> {
            if crate::feature_detected!("neon") && crate::feature_detected!("fp16") {
                Some(Self { __private: () })
            } else {
                None
            }
        }

        #[inline(always)]
        fn mult(self, lhs: f16, rhs: f16) -> f16 {
            unsafe { cast(multiply_f16_fp16(cast(lhs), cast(rhs))) }
        }

        #[inline(always)]
        fn mult_add(self, lhs: f16, rhs: f16, acc: f16) -> f16 {
            unsafe { cast(fmaq_f16(cast(acc), cast(lhs), cast(rhs))) }
        }

        #[inline(always)]
        fn from_lhs(self, lhs: f16) -> f16 {
            lhs
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f16) -> f16 {
            rhs
        }

        #[inline(always)]
        fn from_dst(self, dst: f16) -> f16 {
            dst
        }

        #[inline(always)]
        fn into_dst(self, acc: f16) -> f16 {
            acc
        }

        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            unsafe { transmute(vfmaq_f16(transmute(acc), transmute(lhs), transmute(rhs))) }
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            lhs
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            rhs
        }

        #[inline(always)]
        fn simd_splat(self, lhs: f16) -> Self::AccN {
            [lhs, lhs, lhs, lhs, lhs, lhs, lhs, lhs]
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            dst
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            acc
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            #[inline]
            #[target_feature(enable = "neon,fp16")]
            unsafe fn implementation<F: NullaryFnOnce>(f: F) -> F::Output {
                f.call()
            }

            unsafe { implementation(f) }
        }

        #[inline(always)]
        fn add(self, lhs: f16, rhs: f16) -> f16 {
            unsafe { cast(add_f16_fp16(cast(lhs), cast(rhs))) }
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vmulq_f16(transmute(lhs), transmute(rhs))) }
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vaddq_f16(transmute(lhs), transmute(rhs))) }
        }
    }
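
    // Note on the `[f16; 8]` accumulator impl above: it processes a full 128-bit register
    // of halves per step and uses native fp16 arithmetic, but accumulating in f16 keeps
    // only about three decimal digits of precision, so it trades accuracy for throughput
    // relative to the `[f32; 4]`-accumulating impl further up.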

    #[cfg(feature = "f16")]
    unsafe impl MixedSimd<f16, f16, f16, f32> for Neon {
        const SIMD_WIDTH: usize = 4;

        type LhsN = [f16; 4];
        type RhsN = [f16; 4];
        type DstN = [f16; 4];
        type AccN = [f32; 4];

        #[inline]
        fn try_new() -> Option<Self> {
            if crate::feature_detected!("neon") {
                Some(Self { __private: () })
            } else {
                None
            }
        }

        #[inline(always)]
        fn mult(self, lhs: f32, rhs: f32) -> f32 {
            lhs * rhs
        }

        #[inline(always)]
        fn mult_add(self, lhs: f32, rhs: f32, acc: f32) -> f32 {
            unsafe { neon_fmaf(lhs, rhs, acc) }
        }

        #[inline(always)]
        fn from_lhs(self, lhs: f16) -> f32 {
            lhs.into()
        }

        #[inline(always)]
        fn from_rhs(self, rhs: f16) -> f32 {
            rhs.into()
        }

        #[inline(always)]
        fn from_dst(self, dst: f16) -> f32 {
            dst.into()
        }

        #[inline(always)]
        fn into_dst(self, acc: f32) -> f16 {
            f16::from_f32(acc)
        }

        #[inline(always)]
        fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN {
            unsafe { transmute(vfmaq_f32(transmute(acc), transmute(lhs), transmute(rhs))) }
        }

        #[inline(always)]
        fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN {
            [lhs[0].into(), lhs[1].into(), lhs[2].into(), lhs[3].into()]
        }

        #[inline(always)]
        fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN {
            [rhs[0].into(), rhs[1].into(), rhs[2].into(), rhs[3].into()]
        }

        #[inline(always)]
        fn simd_splat(self, lhs: f32) -> Self::AccN {
            [lhs, lhs, lhs, lhs]
        }

        #[inline(always)]
        fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN {
            [dst[0].into(), dst[1].into(), dst[2].into(), dst[3].into()]
        }

        #[inline(always)]
        fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN {
            [
                f16::from_f32(acc[0]),
                f16::from_f32(acc[1]),
                f16::from_f32(acc[2]),
                f16::from_f32(acc[3]),
            ]
        }

        #[inline(always)]
        fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output {
            #[inline]
            #[target_feature(enable = "neon")]
            unsafe fn implementation<F: NullaryFnOnce>(f: F) -> F::Output {
                f.call()
            }

            unsafe { implementation(f) }
        }

        #[inline(always)]
        fn add(self, lhs: f32, rhs: f32) -> f32 {
            lhs + rhs
        }

        #[inline(always)]
        fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vmulq_f32(transmute(lhs), transmute(rhs))) }
        }

        #[inline(always)]
        fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN {
            unsafe { transmute(vaddq_f32(transmute(lhs), transmute(rhs))) }
        }
    }
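
    // Note on the plain `Neon` fallback above: the f16 <-> f32 conversions go through the
    // `half` crate one lane at a time (`.into()` / `f16::from_f32`), so it runs on NEON
    // CPUs without FEAT_FP16 at the cost of extra conversion work compared to `NeonFp16`.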
}

/// Convenience bound satisfied by every plain-old-data element type used by the kernels;
/// the blanket impl below makes it automatic.
pub trait Boilerplate: Copy + Send + Sync + core::fmt::Debug + 'static + PartialEq {}
impl<T: Copy + Send + Sync + core::fmt::Debug + PartialEq + 'static> Boilerplate for T {}
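
// Minimal illustration (assumption, not original code): `Boilerplate` is a blanket bound
// alias, so any `Copy + Send + Sync + Debug + PartialEq + 'static` type already satisfies
// it; the helper below only checks that claim at compile time for a few element types.
#[cfg(test)]
mod boilerplate_sketch {
    use super::Boilerplate;

    fn assert_boilerplate<T: Boilerplate>() {}

    #[test]
    fn scalar_and_vector_element_types_qualify() {
        assert_boilerplate::<f32>();
        assert_boilerplate::<[f32; 4]>();
        assert_boilerplate::<[u16; 8]>();
    }
}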

/// Mixed-precision SIMD backend: scalar `Lhs`/`Rhs`/`Dst` operands are converted into the
/// accumulator type `Acc`, and the `*N` associated types mirror each scalar operation with
/// `SIMD_WIDTH`-wide packed variants.
pub unsafe trait MixedSimd<Lhs, Rhs, Dst, Acc>: Simd {
    const SIMD_WIDTH: usize;

    type LhsN: Boilerplate;
    type RhsN: Boilerplate;
    type DstN: Boilerplate;
    type AccN: Boilerplate;

    fn try_new() -> Option<Self>;

    fn vectorize<F: NullaryFnOnce>(self, f: F) -> F::Output;

    fn add(self, lhs: Acc, rhs: Acc) -> Acc;
    fn mult(self, lhs: Acc, rhs: Acc) -> Acc;
    fn mult_add(self, lhs: Acc, rhs: Acc, acc: Acc) -> Acc;
    fn from_lhs(self, lhs: Lhs) -> Acc;
    fn from_rhs(self, rhs: Rhs) -> Acc;
    fn from_dst(self, dst: Dst) -> Acc;
    fn into_dst(self, acc: Acc) -> Dst;

    fn simd_mul(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN;
    fn simd_add(self, lhs: Self::AccN, rhs: Self::AccN) -> Self::AccN;
    fn simd_mult_add(self, lhs: Self::AccN, rhs: Self::AccN, acc: Self::AccN) -> Self::AccN;
    fn simd_from_lhs(self, lhs: Self::LhsN) -> Self::AccN;
    fn simd_from_rhs(self, rhs: Self::RhsN) -> Self::AccN;
    fn simd_splat(self, lhs: Acc) -> Self::AccN;

    fn simd_from_dst(self, dst: Self::DstN) -> Self::AccN;
    fn simd_into_dst(self, acc: Self::AccN) -> Self::DstN;
}
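
// Illustrative sketch (assumption, not part of the original API): a purely scalar dot
// product written against `MixedSimd`, using only the conversion and arithmetic hooks of
// the trait. A real kernel would run the packed `simd_*` methods over `SIMD_WIDTH`-sized
// chunks and fall back to a loop like this for the tail; `dot_scalar` is a hypothetical
// helper name, not something defined elsewhere in the crate.
#[cfg(test)]
mod mixed_simd_sketch {
    use super::{MixedSimd, Scalar};

    // The trait has no "zero" constructor for `Acc`, so the accumulator is seeded with the
    // first product and the remaining terms are folded in with `mult_add`.
    fn dot_scalar<Lhs, Rhs, Dst, Acc, S>(simd: S, lhs: &[Lhs], rhs: &[Rhs]) -> Acc
    where
        Lhs: Copy,
        Rhs: Copy,
        Dst: Copy,
        Acc: Copy,
        S: MixedSimd<Lhs, Rhs, Dst, Acc>,
    {
        assert_eq!(lhs.len(), rhs.len());
        assert!(!lhs.is_empty());
        let mut acc = simd.mult(simd.from_lhs(lhs[0]), simd.from_rhs(rhs[0]));
        for i in 1..lhs.len() {
            acc = simd.mult_add(simd.from_lhs(lhs[i]), simd.from_rhs(rhs[i]), acc);
        }
        acc
    }

    #[test]
    fn scalar_backend_dot_product() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        // `Scalar` implements `MixedSimd<f32, f32, f32, f32>`, so it serves as the
        // always-available backend; the turbofish pins `Dst`, which the call cannot infer.
        let dot = dot_scalar::<f32, f32, f32, f32, Scalar>(Scalar, &a, &b);
        assert_eq!(dot, 32.0);
    }
}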