nano_gemm/
lib.rs

#![cfg_attr(not(feature = "std"), no_std)]

use core::mem::MaybeUninit;
use equator::debug_assert;

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
pub mod x86 {
    pub use nano_gemm_c32::x86::*;
    pub use nano_gemm_c64::x86::*;
    pub use nano_gemm_f32::x86::*;
    pub use nano_gemm_f64::x86::*;
}
#[cfg(target_arch = "aarch64")]
pub mod aarch64 {
    pub use nano_gemm_c32::aarch64::*;
    pub use nano_gemm_c64::aarch64::*;
    pub use nano_gemm_f32::aarch64::*;
    pub use nano_gemm_f64::aarch64::*;
}

#[allow(non_camel_case_types)]
pub type c32 = num_complex::Complex32;
#[allow(non_camel_case_types)]
pub type c64 = num_complex::Complex64;

pub use nano_gemm_core::*;

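/// Precomputed plan for a small matrix multiply with fixed dimensions
/// `(m, n, k)`, dispatching to the best microkernel found at creation time.
///
/// A minimal usage sketch (column-major `dst`/`lhs`/`rhs` assumed; the call
/// computes `dst = alpha * dst + beta * lhs * rhs`):
///
/// ```ignore
/// let plan = Plan::new_f32(m, n, k);
/// unsafe {
///     plan.execute_unchecked(
///         m, n, k,
///         dst.as_mut_ptr(), 1, m as isize,
///         lhs.as_ptr(), 1, m as isize,
///         rhs.as_ptr(), 1, k as isize,
///         alpha, beta, false, false,
///     );
/// }
/// ```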
#[derive(Copy, Clone)]
pub struct Plan<T> {
    // indexed by [is last row block][is last column block]
    microkernels: [[MaybeUninit<MicroKernel<T>>; 2]; 2],
    millikernel: unsafe fn(
        microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
        mr: usize,
        nr: usize,
        m: usize,
        n: usize,
        k: usize,
        dst: *mut T,
        dst_rs: isize,
        dst_cs: isize,
        lhs: *const T,
        lhs_rs: isize,
        lhs_cs: isize,
        rhs: *const T,
        rhs_rs: isize,
        rhs_cs: isize,
        alpha: T,
        beta: T,
        conj_lhs: bool,
        conj_rhs: bool,
        full_mask: *const (),
        last_mask: *const (),
    ),
    mr: usize,
    nr: usize,
    full_mask: *const (),
    last_mask: *const (),
    m: usize,
    n: usize,
    k: usize,
    // strides fixed at plan creation; `isize::MIN` is a sentinel meaning
    // "any stride is accepted at execution time"
    dst_cs: isize,
    dst_rs: isize,
    lhs_cs: isize,
    lhs_rs: isize,
    rhs_cs: isize,
    rhs_rs: isize,
}

// Millikernel used when `m == 0` or `n == 0`: there is nothing to compute.
#[allow(unused_variables)]
unsafe fn noop_millikernel<T: Copy>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
}

// Scalar fallback millikernel: computes
// `dst = alpha * dst + beta * op(lhs) * op(rhs)` one element at a time,
// where `op` is the identity or the conjugate. Used when no SIMD
// microkernels are available for the target.
#[allow(unused_variables)]
unsafe fn naive_millikernel<
    T: Copy + core::ops::Mul<Output = T> + core::ops::Add<Output = T> + PartialEq + Conj,
>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
    let zero: T = core::mem::zeroed();
    if alpha == zero {
        // `alpha == 0`: overwrite the destination instead of accumulating,
        // so stale or uninitialized `dst` values are never read.
        for j in 0..n {
            for i in 0..m {
                let mut acc = zero;
                for depth in 0..k {
                    let lhs = *lhs.offset(lhs_rs * i as isize + lhs_cs * depth as isize);
                    let rhs = *rhs.offset(rhs_rs * depth as isize + rhs_cs * j as isize);
                    acc = acc
                        + if conj_lhs { lhs.conj() } else { lhs }
                            * if conj_rhs { rhs.conj() } else { rhs };
                }
                *dst.offset(dst_rs * i as isize + dst_cs * j as isize) = beta * acc;
            }
        }
    } else {
        for j in 0..n {
            for i in 0..m {
                let mut acc = zero;
                for depth in 0..k {
                    let lhs = *lhs.offset(lhs_rs * i as isize + lhs_cs * depth as isize);
                    let rhs = *rhs.offset(rhs_rs * depth as isize + rhs_cs * j as isize);
                    acc = acc
                        + if conj_lhs { lhs.conj() } else { lhs }
                            * if conj_rhs { rhs.conj() } else { rhs };
                }
                let dst = dst.offset(dst_rs * i as isize + dst_cs * j as isize);
                *dst = alpha * *dst + beta * acc;
            }
        }
    }
}

// Millikernel used when `k == 0`: the product term is empty, so only the
// `dst = alpha * dst` scaling (or a zero fill when `alpha == 0`) remains.
#[allow(unused_variables)]
unsafe fn fill_millikernel<T: Copy + PartialEq + core::ops::Mul<Output = T>>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
    let zero: T = core::mem::zeroed();
    if alpha == zero {
        for j in 0..n {
            for i in 0..m {
                *dst.offset(dst_rs * i as isize + dst_cs * j as isize) = core::mem::zeroed();
            }
        }
    } else {
        for j in 0..n {
            for i in 0..m {
                let dst = dst.offset(dst_rs * i as isize + dst_cs * j as isize);
                *dst = alpha * *dst;
            }
        }
    }
}

// Millikernel for column-major `lhs`/`dst` when the output is known to fit in
// at most `M_DIVCEIL_MR x N_DIVCEIL_NR` microkernel tiles; the tile loops are
// bounded by compile-time constants so they can be fully unrolled.
#[inline(always)]
unsafe fn small_direct_millikernel<
    T: Copy,
    const M_DIVCEIL_MR: usize,
    const N_DIVCEIL_NR: usize,
>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
    _ = (m, n);
    debug_assert!(all(lhs_rs == 1, dst_rs == 1));

    let mut data = MicroKernelData {
        alpha,
        beta,
        conj_lhs,
        conj_rhs,
        k,
        dst_cs,
        lhs_cs,
        rhs_rs,
        rhs_cs,
        last_mask,
    };

    let mut i = 0usize;
    while i < M_DIVCEIL_MR {
        data.last_mask = if i + 1 < M_DIVCEIL_MR {
            full_mask
        } else {
            last_mask
        };

        let microkernels = microkernels.get_unchecked((i + 1 >= M_DIVCEIL_MR) as usize);
        {
            let i = i * mr;
            let dst = dst.offset(i as isize);

            let mut j = 0usize;
            while j < N_DIVCEIL_NR {
                let microkernel = microkernels
                    .get_unchecked((j + 1 >= N_DIVCEIL_NR) as usize)
                    .assume_init();

                {
                    let j = j * nr;
                    microkernel(
                        &data,
                        dst.offset(j as isize * dst_cs),
                        lhs.offset(i as isize),
                        rhs.offset(j as isize * rhs_cs),
                    );
                }

                j += 1;
            }
        }
        i += 1;
    }
}

// Millikernel for column-major `lhs`/`dst` of arbitrary size: walks the
// output in `mr x nr` tiles, switching to the edge microkernels (and the
// partial row mask) on the last block in each direction.
unsafe fn direct_millikernel<T: Copy>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
    debug_assert!(all(lhs_rs == 1, dst_rs == 1));

    let mut data = MicroKernelData {
        alpha,
        beta,
        conj_lhs,
        conj_rhs,
        k,
        dst_cs,
        lhs_cs,
        rhs_rs,
        rhs_cs,
        last_mask,
    };

    let mut i = 0usize;
    while i < m {
        data.last_mask = if i + mr <= m { full_mask } else { last_mask };
        let microkernels = microkernels.get_unchecked((i + mr > m) as usize);
        let dst = dst.offset(i as isize);

        let mut j = 0usize;
        while j < n {
            let microkernel = microkernels
                .get_unchecked((j + nr > n) as usize)
                .assume_init();

            microkernel(
                &data,
                dst.offset(j as isize * dst_cs),
                lhs.offset(i as isize),
                rhs.offset(j as isize * rhs_cs),
            );

            j += nr;
        }

        i += mr;
    }
}

trait One {
    const ONE: Self;
}
trait Conj {
    fn conj(self) -> Self;
}

impl One for f32 {
    const ONE: Self = 1.0;
}
impl One for f64 {
    const ONE: Self = 1.0;
}
impl One for c32 {
    const ONE: Self = Self { re: 1.0, im: 0.0 };
}
impl One for c64 {
    const ONE: Self = Self { re: 1.0, im: 0.0 };
}

impl Conj for f32 {
    #[inline]
    fn conj(self) -> Self {
        self
    }
}
impl Conj for f64 {
    #[inline]
    fn conj(self) -> Self {
        self
    }
}

impl Conj for c32 {
    #[inline]
    fn conj(self) -> Self {
        Self::conj(&self)
    }
}
impl Conj for c64 {
    #[inline]
    fn conj(self) -> Self {
        Self::conj(&self)
    }
}

// General-purpose millikernel: when `lhs` and `dst` are column-major it
// forwards to `direct_millikernel`; otherwise it packs `lhs` (and, when
// `dst` is strided, a scratch copy of `dst`) into column-major blocks on the
// stack and accumulates the product block by block.
unsafe fn copy_millikernel<
    T: Copy + PartialEq + core::ops::Add<Output = T> + core::ops::Mul<Output = T> + Conj + One,
>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    mut alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
    if dst_rs == 1 && lhs_rs == 1 {
        let gemm_dst = dst;
        let gemm_lhs = lhs;
        let gemm_dst_cs = dst_cs;
        let gemm_lhs_cs = lhs_cs;

        direct_millikernel(
            microkernels,
            mr,
            nr,
            m,
            n,
            k,
            gemm_dst,
            1,
            gemm_dst_cs,
            gemm_lhs,
            1,
            gemm_lhs_cs,
            rhs,
            rhs_rs,
            rhs_cs,
            alpha,
            beta,
            conj_lhs,
            conj_rhs,
            full_mask,
            last_mask,
        );
    } else {
        // 64 is a multiple of both MR and NR for every kernel in this crate
        const M_BS: usize = 64;
        const N_BS: usize = 64;
        const K_BS: usize = 64;
        let mut dst_tmp: MaybeUninit<[T; M_BS * N_BS]> = core::mem::MaybeUninit::uninit();
        let mut lhs_tmp: MaybeUninit<[T; M_BS * K_BS]> = core::mem::MaybeUninit::uninit();

        let dst_tmp = &mut *((&mut dst_tmp) as *mut _ as *mut [[MaybeUninit<T>; M_BS]; N_BS]);
        let lhs_tmp = &mut *((&mut lhs_tmp) as *mut _ as *mut [[MaybeUninit<T>; M_BS]; K_BS]);

        let gemm_dst = if dst_rs == 1 {
            dst
        } else {
            dst_tmp.as_mut_ptr() as *mut T
        };
        let gemm_lhs = lhs_tmp.as_mut_ptr() as *mut T;
        let gemm_dst_cs = if dst_rs == 1 { dst_cs } else { M_BS as isize };
        let gemm_lhs_cs = M_BS as isize;

        let mut depth = 0usize;
        while depth < k {
            let depth_bs = Ord::min(K_BS, k - depth);

            let mut i = 0usize;
            while i < m {
                let i_bs = Ord::min(M_BS, m - i);

                let lhs = lhs.offset(lhs_rs * i as isize + lhs_cs * depth as isize);

                // pack the current lhs block into column-major scratch
                for ii in 0..i_bs {
                    for jj in 0..depth_bs {
                        let ii = ii as isize;
                        let jj = jj as isize;
                        *(gemm_lhs.offset(ii + gemm_lhs_cs * jj)) =
                            *(lhs.offset(lhs_rs * ii + lhs_cs * jj));
                    }
                }

                let mut j = 0usize;
                while j < n {
                    let j_bs = Ord::min(N_BS, n - j);

                    let rhs = rhs.offset(rhs_rs * depth as isize + rhs_cs * j as isize);

                    let dst = dst.offset(dst_rs * i as isize + dst_cs * j as isize);
                    let gemm_dst = if dst_rs == 1 {
                        gemm_dst.offset(i as isize + gemm_dst_cs * j as isize)
                    } else {
                        gemm_dst
                    };

                    direct_millikernel(
                        microkernels,
                        mr,
                        nr,
                        i_bs,
                        j_bs,
                        depth_bs,
                        gemm_dst,
                        1,
                        gemm_dst_cs,
                        gemm_lhs,
                        1,
                        gemm_lhs_cs,
                        rhs,
                        rhs_rs,
                        rhs_cs,
                        if dst_rs == 1 {
                            alpha
                        } else {
                            core::mem::zeroed()
                        },
                        beta,
                        conj_lhs,
                        conj_rhs,
                        full_mask,
                        if i + i_bs == m { last_mask } else { full_mask },
                    );

                    // copy the scratch result back into the strided dst
                    if dst_rs != 1 {
                        if alpha == core::mem::zeroed() {
                            for ii in 0..i_bs {
                                for jj in 0..j_bs {
                                    let ii = ii as isize;
                                    let jj = jj as isize;
                                    *(dst.offset(dst_rs * ii + dst_cs * jj)) =
                                        *(gemm_dst.offset(ii + gemm_dst_cs * jj));
                                }
                            }
                        } else {
                            for ii in 0..i_bs {
                                for jj in 0..j_bs {
                                    let ii = ii as isize;
                                    let jj = jj as isize;
                                    let dst = dst.offset(dst_rs * ii + dst_cs * jj);
                                    *dst = alpha * *dst + *(gemm_dst.offset(ii + gemm_dst_cs * jj));
                                }
                            }
                        }
                    }

                    j += j_bs;
                }

                i += i_bs;
            }

            // later depth blocks accumulate into dst, so alpha becomes 1
            alpha = T::ONE;
            depth += depth_bs;
        }
    }
}

impl<T> Plan<T> {
    #[allow(dead_code)]
    #[inline(always)]
    fn from_masked_impl<const MR_DIV_N: usize, const NR: usize, const N: usize, Mask>(
        const_microkernels: &[[[MicroKernel<T>; NR]; MR_DIV_N]; 17],
        const_masks: Option<&[Mask; N]>,
        m: usize,
        n: usize,
        k: usize,
        is_col_major: bool,
    ) -> Self
    where
        T: Copy + PartialEq + core::ops::Add<Output = T> + core::ops::Mul<Output = T> + Conj + One,
    {
        let mut microkernels = [[MaybeUninit::<MicroKernel<T>>::uninit(); 2]; 2];

        let mr = MR_DIV_N * N;
        let nr = NR;

        {
            // `wrapping_sub` keeps the indices in range even when a size is
            // zero; those plans dispatch to `noop_millikernel` or
            // `fill_millikernel` and never call these microkernels.
            let k = Ord::min(k.wrapping_sub(1), 16);
            let m = (m.wrapping_sub(1) / N) % (mr / N);
            let n = n.wrapping_sub(1) % nr;

            microkernels[0][0].write(const_microkernels[k][MR_DIV_N - 1][NR - 1]);
            microkernels[0][1].write(const_microkernels[k][MR_DIV_N - 1][n]);
            microkernels[1][0].write(const_microkernels[k][m][NR - 1]);
            microkernels[1][1].write(const_microkernels[k][m][n]);
        }
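        // A worked example of the lookup above (a sketch; suppose
        // MR_DIV_N = 2, N = 8 and NR = 4, i.e. mr = 16): for m = 15, n = 6,
        // k = 3 the indices come out as k = min(2, 16) = 2,
        // m = ((15 - 1) / 8) % 2 = 1 and n = (6 - 1) % 4 = 1, so
        // microkernels[0][0] handles full 16 x 4 interior tiles while
        // microkernels[1][1] handles the bottom-right partial tile, with
        // last_mask below selecting the mask for the 15 % 8 = 7 trailing rows.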

        Self {
            microkernels,
            millikernel: if m == 0 || n == 0 {
                noop_millikernel
            } else if k == 0 {
                fill_millikernel
            } else if is_col_major {
                if m <= mr && n <= nr {
                    small_direct_millikernel::<_, 1, 1>
                } else if m <= mr && n <= 2 * nr {
                    small_direct_millikernel::<_, 1, 2>
                } else if m <= 2 * mr && n <= nr {
                    small_direct_millikernel::<_, 2, 1>
                } else if m <= 2 * mr && n <= 2 * nr {
                    small_direct_millikernel::<_, 2, 2>
                } else {
                    direct_millikernel
                }
            } else {
                copy_millikernel
            },
            mr,
            nr,
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
            full_mask: if let Some(const_masks) = const_masks {
                (&const_masks[0]) as *const _ as *const ()
            } else {
                &()
            },
            last_mask: if let Some(const_masks) = const_masks {
                (&const_masks[m % N]) as *const _ as *const ()
            } else {
                &()
            },
        }
    }

    #[allow(dead_code)]
    #[inline(always)]
    fn from_non_masked_impl<const MR: usize, const NR: usize>(
        const_microkernels: &[[[MicroKernel<T>; NR]; MR]; 17],
        m: usize,
        n: usize,
        k: usize,
        is_col_major: bool,
    ) -> Self
    where
        T: Copy + PartialEq + core::ops::Add<Output = T> + core::ops::Mul<Output = T> + Conj + One,
    {
        let mut microkernels = [[MaybeUninit::<MicroKernel<T>>::uninit(); 2]; 2];

        let mr = MR;
        let nr = NR;

        {
            let k = Ord::min(k.wrapping_sub(1), 16);
            let m = m.wrapping_sub(1) % mr;
            let n = n.wrapping_sub(1) % nr;

            microkernels[0][0].write(const_microkernels[k][MR - 1][NR - 1]);
            microkernels[0][1].write(const_microkernels[k][MR - 1][n]);
            microkernels[1][0].write(const_microkernels[k][m][NR - 1]);
            microkernels[1][1].write(const_microkernels[k][m][n]);
        }

        Self {
            microkernels,
            millikernel: if m == 0 || n == 0 {
                noop_millikernel
            } else if k == 0 {
                fill_millikernel
            } else if is_col_major {
                if m <= mr && n <= nr {
                    small_direct_millikernel::<_, 1, 1>
                } else if m <= mr && n <= 2 * nr {
                    small_direct_millikernel::<_, 1, 2>
                } else if m <= 2 * mr && n <= nr {
                    small_direct_millikernel::<_, 2, 1>
                } else if m <= 2 * mr && n <= 2 * nr {
                    small_direct_millikernel::<_, 2, 2>
                } else {
                    direct_millikernel
                }
            } else {
                copy_millikernel
            },
            mr,
            nr,
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
            full_mask: &(),
            last_mask: &(),
        }
    }
}

impl Plan<f32> {
    fn new_f32_scalar(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        Self {
            microkernels: [[MaybeUninit::<MicroKernel<f32>>::uninit(); 2]; 2],
            millikernel: naive_millikernel,
            mr: 0,
            nr: 0,
            full_mask: core::ptr::null(),
            last_mask: core::ptr::null(),
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
        }
    }
}
impl Plan<f64> {
    fn new_f64_scalar(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        Self {
            microkernels: [[MaybeUninit::<MicroKernel<f64>>::uninit(); 2]; 2],
            millikernel: naive_millikernel,
            mr: 0,
            nr: 0,
            full_mask: core::ptr::null(),
            last_mask: core::ptr::null(),
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
        }
    }
}
impl Plan<c32> {
    fn new_c32_scalar(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        Self {
            microkernels: [[MaybeUninit::<MicroKernel<c32>>::uninit(); 2]; 2],
            millikernel: naive_millikernel,
            mr: 0,
            nr: 0,
            full_mask: core::ptr::null(),
            last_mask: core::ptr::null(),
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
        }
    }
}
impl Plan<c64> {
    fn new_c64_scalar(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        Self {
            microkernels: [[MaybeUninit::<MicroKernel<c64>>::uninit(); 2]; 2],
            millikernel: naive_millikernel,
            mr: 0,
            nr: 0,
            full_mask: core::ptr::null(),
            last_mask: core::ptr::null(),
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
        }
    }
}

impl<T> Plan<T> {
    /// Executes the plan.
    ///
    /// # Safety
    ///
    /// `m`, `n` and `k` must match the dimensions the plan was created with,
    /// the pointers must be valid for the accesses implied by the strides,
    /// and every stride that was fixed at plan creation must be passed
    /// unchanged (checked by the `debug_assert!`s below).
    #[inline(always)]
    pub unsafe fn execute_unchecked(
        &self,
        m: usize,
        n: usize,
        k: usize,
        dst: *mut T,
        dst_rs: isize,
        dst_cs: isize,
        lhs: *const T,
        lhs_rs: isize,
        lhs_cs: isize,
        rhs: *const T,
        rhs_rs: isize,
        rhs_cs: isize,
        alpha: T,
        beta: T,
        conj_lhs: bool,
        conj_rhs: bool,
    ) {
        debug_assert!(m == self.m);
        debug_assert!(n == self.n);
        debug_assert!(k == self.k);
        if self.dst_cs != isize::MIN {
            debug_assert!(dst_cs == self.dst_cs);
        }
        if self.dst_rs != isize::MIN {
            debug_assert!(dst_rs == self.dst_rs);
        }
        if self.lhs_cs != isize::MIN {
            debug_assert!(lhs_cs == self.lhs_cs);
        }
        if self.lhs_rs != isize::MIN {
            debug_assert!(lhs_rs == self.lhs_rs);
        }
        if self.rhs_cs != isize::MIN {
            debug_assert!(rhs_cs == self.rhs_cs);
        }
        if self.rhs_rs != isize::MIN {
            debug_assert!(rhs_rs == self.rhs_rs);
        }

        (self.millikernel)(
            &self.microkernels,
            self.mr,
            self.nr,
            m,
            n,
            k,
            dst,
            dst_rs,
            dst_cs,
            lhs,
            lhs_rs,
            lhs_cs,
            rhs,
            rhs_rs,
            rhs_cs,
            alpha,
            beta,
            conj_lhs,
            conj_rhs,
            self.full_mask,
            self.last_mask,
        );
    }
}

impl Plan<f32> {
    #[track_caller]
    pub fn new_f32_impl(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        #[cfg(feature = "std")]
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "x86-v4")]
            if m > 8 && std::is_x86_feature_detected!("avx512f") {
                return Self::new_f32_avx512(m, n, k, is_col_major);
            }

            if std::is_x86_feature_detected!("avx2") {
                if m == 1 {
                    return Self::new_f32x1(m, n, k, is_col_major);
                }
                if m == 2 {
                    return Self::new_f32x2(m, n, k, is_col_major);
                }
                if m <= 4 {
                    return Self::new_f32x4(m, n, k, is_col_major);
                }

                return Self::new_f32_avx(m, n, k, is_col_major);
            }
        }
        #[cfg(feature = "std")]
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                return Self::from_non_masked_impl(
                    &aarch64::f32::neon::MICROKERNELS,
                    m,
                    n,
                    k,
                    is_col_major,
                );
            }
        }

        Self::new_f32_scalar(m, n, k, is_col_major)
    }

    #[track_caller]
    pub fn new_colmajor_lhs_and_dst_f32(m: usize, n: usize, k: usize) -> Self {
        Self::new_f32_impl(m, n, k, true)
    }

    #[track_caller]
    pub fn new_f32(m: usize, n: usize, k: usize) -> Self {
        Self::new_f32_impl(m, n, k, false)
    }
}

impl Plan<f64> {
    #[track_caller]
    pub fn new_f64_impl(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        #[cfg(feature = "std")]
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "x86-v4")]
            if m > 4 && std::is_x86_feature_detected!("avx512f") {
                return Self::new_f64_avx512(m, n, k, is_col_major);
            }

            if std::is_x86_feature_detected!("avx2") {
                if m == 1 {
                    return Self::new_f64x1(m, n, k, is_col_major);
                }
                if m == 2 {
                    return Self::new_f64x2(m, n, k, is_col_major);
                }

                return Self::new_f64_avx(m, n, k, is_col_major);
            }
        }

        #[cfg(feature = "std")]
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                return Self::from_non_masked_impl(
                    &aarch64::f64::neon::MICROKERNELS,
                    m,
                    n,
                    k,
                    is_col_major,
                );
            }
        }

        Self::new_f64_scalar(m, n, k, is_col_major)
    }

    #[track_caller]
    pub fn new_colmajor_lhs_and_dst_f64(m: usize, n: usize, k: usize) -> Self {
        Self::new_f64_impl(m, n, k, true)
    }

    #[track_caller]
    pub fn new_f64(m: usize, n: usize, k: usize) -> Self {
        Self::new_f64_impl(m, n, k, false)
    }
}

impl Plan<c32> {
    #[track_caller]
    pub fn new_c32_impl(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        #[cfg(feature = "std")]
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "x86-v4")]
            if m > 4 && std::is_x86_feature_detected!("avx512f") {
                return Self::new_c32_avx512(m, n, k, is_col_major);
            }

            if std::is_x86_feature_detected!("avx2") {
                if m == 1 {
                    return Self::new_c32x1(m, n, k, is_col_major);
                }
                if m == 2 {
                    return Self::new_c32x2(m, n, k, is_col_major);
                }

                return Self::new_c32_avx(m, n, k, is_col_major);
            }
        }

        #[cfg(feature = "std")]
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon")
                && std::arch::is_aarch64_feature_detected!("fcma")
            {
                return Self::from_non_masked_impl(
                    &aarch64::c32::neon::MICROKERNELS,
                    m,
                    n,
                    k,
                    is_col_major,
                );
            }
        }

        Self::new_c32_scalar(m, n, k, is_col_major)
    }

    #[track_caller]
    pub fn new_colmajor_lhs_and_dst_c32(m: usize, n: usize, k: usize) -> Self {
        Self::new_c32_impl(m, n, k, true)
    }

    #[track_caller]
    pub fn new_c32(m: usize, n: usize, k: usize) -> Self {
        Self::new_c32_impl(m, n, k, false)
    }
}

impl Plan<c64> {
    #[track_caller]
    pub fn new_c64_impl(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        #[cfg(feature = "std")]
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "x86-v4")]
            if m > 2 && std::is_x86_feature_detected!("avx512f") {
                return Self::new_c64_avx512(m, n, k, is_col_major);
            }

            if std::is_x86_feature_detected!("avx2") {
                if m == 1 {
                    return Self::new_c64x1(m, n, k, is_col_major);
                }
                return Self::new_c64_avx(m, n, k, is_col_major);
            }
        }

        #[cfg(feature = "std")]
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon")
                && std::arch::is_aarch64_feature_detected!("fcma")
            {
                return Self::from_non_masked_impl(
                    &aarch64::c64::neon::MICROKERNELS,
                    m,
                    n,
                    k,
                    is_col_major,
                );
            }
        }
        Self::new_c64_scalar(m, n, k, is_col_major)
    }

    #[track_caller]
    pub fn new_colmajor_lhs_and_dst_c64(m: usize, n: usize, k: usize) -> Self {
        Self::new_c64_impl(m, n, k, true)
    }

    #[track_caller]
    pub fn new_c64(m: usize, n: usize, k: usize) -> Self {
        Self::new_c64_impl(m, n, k, false)
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod x86_api {
    use super::*;

    impl Plan<f32> {
        pub fn new_f32x1(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f32::f32x1::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }
        pub fn new_f32x2(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f32::f32x2::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }
        pub fn new_f32x4(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f32::f32x4::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }

        pub fn new_f32_avx(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f32::avx::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }

        #[cfg(feature = "x86-v4")]
        pub fn new_f32_avx512(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f32::avx512::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }
    }

    impl Plan<f64> {
        pub fn new_f64x1(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f64::f64x1::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }
        pub fn new_f64x2(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f64::f64x2::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }

        pub fn new_f64_avx(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f64::avx::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }

        #[cfg(feature = "x86-v4")]
        pub fn new_f64_avx512(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f64::avx512::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }
    }
    impl Plan<c32> {
        pub fn new_c32x1(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c32::c32x1::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }
        pub fn new_c32x2(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c32::c32x2::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }

        pub fn new_c32_avx(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c32::avx::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }

        #[cfg(feature = "x86-v4")]
        pub fn new_c32_avx512(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c32::avx512::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }
    }
    impl Plan<c64> {
        pub fn new_c64x1(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c64::c64x1::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }

        pub fn new_c64_avx(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c64::avx::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }

        #[cfg(feature = "x86-v4")]
        pub fn new_c64_avx512(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c64::avx512::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }
    }
}

pub mod planless {
    //! Convenience entry points that build a fresh [`Plan`] on every call,
    //! after normalizing the operands so that the destination is as close to
    //! column-major as possible.

    use super::*;

    #[inline(always)]
    pub unsafe fn execute_f32(
        mut m: usize,
        mut n: usize,
        k: usize,
        mut dst: *mut f32,
        mut dst_rs: isize,
        mut dst_cs: isize,
        mut lhs: *const f32,
        mut lhs_rs: isize,
        mut lhs_cs: isize,
        mut rhs: *const f32,
        mut rhs_rs: isize,
        mut rhs_cs: isize,
        alpha: f32,
        beta: f32,
        mut conj_lhs: bool,
        mut conj_rhs: bool,
    ) {
        // if dst is closer to row-major, compute the transposed product
        // instead: (lhs * rhs)^T = rhs^T * lhs^T
        if dst_cs.unsigned_abs() < dst_rs.unsigned_abs() {
            core::mem::swap(&mut m, &mut n);
            core::mem::swap(&mut dst_rs, &mut dst_cs);
            core::mem::swap(&mut lhs, &mut rhs);
            core::mem::swap(&mut lhs_rs, &mut rhs_cs);
            core::mem::swap(&mut lhs_cs, &mut rhs_rs);
            core::mem::swap(&mut conj_lhs, &mut conj_rhs);
        }
        // a reversed row stride is normalized by starting from the last row
        // of dst and lhs
        if dst_rs == -1 && m > 0 {
            dst = dst.wrapping_offset((m - 1) as isize * dst_rs);
            dst_rs = dst_rs.wrapping_neg();
            lhs = lhs.wrapping_offset((m - 1) as isize * lhs_rs);
            lhs_rs = lhs_rs.wrapping_neg();
        }

        let plan = if lhs_rs == 1 && dst_rs == 1 {
            Plan::new_colmajor_lhs_and_dst_f32(m, n, k)
        } else {
            Plan::new_f32(m, n, k)
        };
        plan.execute_unchecked(
            m, n, k, dst, dst_rs, dst_cs, lhs, lhs_rs, lhs_cs, rhs, rhs_rs, rhs_cs, alpha, beta,
            conj_lhs, conj_rhs,
        )
    }

    #[inline(always)]
    pub unsafe fn execute_c32(
        mut m: usize,
        mut n: usize,
        k: usize,
        mut dst: *mut c32,
        mut dst_rs: isize,
        mut dst_cs: isize,
        mut lhs: *const c32,
        mut lhs_rs: isize,
        mut lhs_cs: isize,
        mut rhs: *const c32,
        mut rhs_rs: isize,
        mut rhs_cs: isize,
        alpha: c32,
        beta: c32,
        mut conj_lhs: bool,
        mut conj_rhs: bool,
    ) {
        if dst_cs.unsigned_abs() < dst_rs.unsigned_abs() {
            core::mem::swap(&mut m, &mut n);
            core::mem::swap(&mut dst_rs, &mut dst_cs);
            core::mem::swap(&mut lhs, &mut rhs);
            core::mem::swap(&mut lhs_rs, &mut rhs_cs);
            core::mem::swap(&mut lhs_cs, &mut rhs_rs);
            core::mem::swap(&mut conj_lhs, &mut conj_rhs);
        }
        if dst_rs == -1 && m > 0 {
            dst = dst.wrapping_offset((m - 1) as isize * dst_rs);
            dst_rs = dst_rs.wrapping_neg();
            lhs = lhs.wrapping_offset((m - 1) as isize * lhs_rs);
            lhs_rs = lhs_rs.wrapping_neg();
        }

        let plan = if lhs_rs == 1 && dst_rs == 1 {
            Plan::new_colmajor_lhs_and_dst_c32(m, n, k)
        } else {
            Plan::new_c32(m, n, k)
        };
        plan.execute_unchecked(
            m, n, k, dst, dst_rs, dst_cs, lhs, lhs_rs, lhs_cs, rhs, rhs_rs, rhs_cs, alpha, beta,
            conj_lhs, conj_rhs,
        )
    }

    #[inline(always)]
    pub unsafe fn execute_f64(
        mut m: usize,
        mut n: usize,
        k: usize,
        mut dst: *mut f64,
        mut dst_rs: isize,
        mut dst_cs: isize,
        mut lhs: *const f64,
        mut lhs_rs: isize,
        mut lhs_cs: isize,
        mut rhs: *const f64,
        mut rhs_rs: isize,
        mut rhs_cs: isize,
        alpha: f64,
        beta: f64,
        mut conj_lhs: bool,
        mut conj_rhs: bool,
    ) {
        if dst_cs.unsigned_abs() < dst_rs.unsigned_abs() {
            core::mem::swap(&mut m, &mut n);
            core::mem::swap(&mut dst_rs, &mut dst_cs);
            core::mem::swap(&mut lhs, &mut rhs);
            core::mem::swap(&mut lhs_rs, &mut rhs_cs);
            core::mem::swap(&mut lhs_cs, &mut rhs_rs);
            core::mem::swap(&mut conj_lhs, &mut conj_rhs);
        }
        if dst_rs == -1 && m > 0 {
            dst = dst.wrapping_offset((m - 1) as isize * dst_rs);
            dst_rs = dst_rs.wrapping_neg();
            lhs = lhs.wrapping_offset((m - 1) as isize * lhs_rs);
            lhs_rs = lhs_rs.wrapping_neg();
        }

        let plan = if lhs_rs == 1 && dst_rs == 1 {
            Plan::new_colmajor_lhs_and_dst_f64(m, n, k)
        } else {
            Plan::new_f64(m, n, k)
        };
        plan.execute_unchecked(
            m, n, k, dst, dst_rs, dst_cs, lhs, lhs_rs, lhs_cs, rhs, rhs_rs, rhs_cs, alpha, beta,
            conj_lhs, conj_rhs,
        )
    }

    #[inline(always)]
    pub unsafe fn execute_c64(
        mut m: usize,
        mut n: usize,
        k: usize,
        mut dst: *mut c64,
        mut dst_rs: isize,
        mut dst_cs: isize,
        mut lhs: *const c64,
        mut lhs_rs: isize,
        mut lhs_cs: isize,
        mut rhs: *const c64,
        mut rhs_rs: isize,
        mut rhs_cs: isize,
        alpha: c64,
        beta: c64,
        mut conj_lhs: bool,
        mut conj_rhs: bool,
    ) {
        if dst_cs.unsigned_abs() < dst_rs.unsigned_abs() {
            core::mem::swap(&mut m, &mut n);
            core::mem::swap(&mut dst_rs, &mut dst_cs);
            core::mem::swap(&mut lhs, &mut rhs);
            core::mem::swap(&mut lhs_rs, &mut rhs_cs);
            core::mem::swap(&mut lhs_cs, &mut rhs_rs);
            core::mem::swap(&mut conj_lhs, &mut conj_rhs);
        }
        if dst_rs == -1 && m > 0 {
            dst = dst.wrapping_offset((m - 1) as isize * dst_rs);
            dst_rs = dst_rs.wrapping_neg();
            lhs = lhs.wrapping_offset((m - 1) as isize * lhs_rs);
            lhs_rs = lhs_rs.wrapping_neg();
        }

        let plan = if lhs_rs == 1 && dst_rs == 1 {
            Plan::new_colmajor_lhs_and_dst_c64(m, n, k)
        } else {
            Plan::new_c64(m, n, k)
        };
        plan.execute_unchecked(
            m, n, k, dst, dst_rs, dst_cs, lhs, lhs_rs, lhs_cs, rhs, rhs_rs, rhs_cs, alpha, beta,
            conj_lhs, conj_rhs,
        )
    }
}
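// A minimal usage sketch of the planless API above (column-major buffers
// assumed; computes `dst = alpha * dst + beta * lhs * rhs`):
//
//     unsafe {
//         planless::execute_f32(
//             m, n, k,
//             dst.as_mut_ptr(), 1, m as isize,
//             lhs.as_ptr(), 1, m as isize,
//             rhs.as_ptr(), 1, k as isize,
//             alpha, beta, false, false,
//         );
//     }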

#[cfg(test)]
mod tests {
    use super::*;
    use equator::assert;

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_kernel() {
        let gen = |_| rand::random::<f32>();
        let a: [[f32; 17]; 3] = core::array::from_fn(|_| core::array::from_fn(gen));
        let b: [[f32; 6]; 4] = core::array::from_fn(|_| core::array::from_fn(gen));
        let c: [[f32; 15]; 4] = core::array::from_fn(|_| core::array::from_fn(gen));
        assert!(std::is_x86_feature_detected!("avx2"));
        let mut dst = c;

        let last_mask: std::arch::x86_64::__m256i = unsafe {
            core::mem::transmute([
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                0,
            ])
        };

        let beta = 2.5;
        let alpha = 1.0;

        unsafe {
            x86::f32::avx::matmul_2_4_dyn(
                &MicroKernelData {
                    alpha,
                    beta,
                    conj_lhs: false,
                    conj_rhs: false,
                    k: 3,
                    dst_cs: dst[0].len() as isize,
                    lhs_cs: a[0].len() as isize,
                    rhs_rs: 2,
                    rhs_cs: 6,
                    last_mask: (&last_mask) as *const _ as *const (),
                },
                dst.as_mut_ptr() as *mut f32,
                a.as_ptr() as *const f32,
                b.as_ptr() as *const f32,
            );
        };

        let mut expected_dst = c;
        for i in 0..15 {
            for j in 0..4 {
                let mut acc = 0.0f32;
                for depth in 0..3 {
                    acc = f32::mul_add(a[depth][i], b[j][2 * depth], acc);
                }
                expected_dst[j][i] = f32::mul_add(beta, acc, expected_dst[j][i]);
            }
        }

        assert!(dst == expected_dst);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_kernel_cplx() {
        let gen = |_| rand::random::<c32>();
        let a: [[c32; 9]; 3] = core::array::from_fn(|_| core::array::from_fn(gen));
        let b: [[c32; 6]; 2] = core::array::from_fn(|_| core::array::from_fn(gen));
        let c: [[c32; 7]; 2] = core::array::from_fn(|_| core::array::from_fn(gen));
        assert!(std::is_x86_feature_detected!("avx2"));

        let last_mask: std::arch::x86_64::__m256i = unsafe {
            core::mem::transmute([
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                0,
                0,
            ])
        };

        let beta = c32::new(2.5, 3.5);
        let alpha = c32::new(1.0, 0.0);

        for (conj_lhs, conj_rhs) in [(false, false), (false, true), (true, false), (true, true)] {
            let mut dst = c;
            unsafe {
                x86::c32::avx::matmul_2_2_dyn(
                    &MicroKernelData {
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        k: 3,
                        dst_cs: dst[0].len() as isize,
                        lhs_cs: a[0].len() as isize,
                        rhs_rs: 2,
                        rhs_cs: b[0].len() as isize,
                        last_mask: (&last_mask) as *const _ as *const (),
                    },
                    dst.as_mut_ptr() as *mut c32,
                    a.as_ptr() as *const c32,
                    b.as_ptr() as *const c32,
                );
            };

            let mut expected_dst = c;
            for i in 0..7 {
                for j in 0..2 {
                    let mut acc = c32::new(0.0, 0.0);
                    for depth in 0..3 {
                        let mut a = a[depth][i];
                        let mut b = b[j][2 * depth];
                        if conj_lhs {
                            a = a.conj();
                        }
                        if conj_rhs {
                            b = b.conj();
                        }
                        acc += a * b;
                    }
                    expected_dst[j][i] += beta * acc;
                }
            }

            for (&dst, &expected_dst) in
                core::iter::zip(dst.iter().flatten(), expected_dst.iter().flatten())
            {
                assert!((dst.re - expected_dst.re).abs() < 1e-5);
                assert!((dst.im - expected_dst.im).abs() < 1e-5);
            }
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_kernel_cplx64() {
        let gen = |_| rand::random::<c64>();
        let a: [[c64; 5]; 3] = core::array::from_fn(|_| core::array::from_fn(gen));
        let b: [[c64; 6]; 2] = core::array::from_fn(|_| core::array::from_fn(gen));
        let c: [[c64; 3]; 2] = core::array::from_fn(|_| core::array::from_fn(gen));
        assert!(std::is_x86_feature_detected!("avx2"));

        let last_mask: std::arch::x86_64::__m256i =
            unsafe { core::mem::transmute([u64::MAX, u64::MAX, 0, 0]) };

        let beta = c64::new(2.5, 3.5);
        let alpha = c64::new(1.0, 0.0);

        for (conj_lhs, conj_rhs) in [(false, false), (false, true), (true, false), (true, true)] {
            let mut dst = c;
            unsafe {
                x86::c64::avx::matmul_2_2_dyn(
                    &MicroKernelData {
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        k: 3,
                        dst_cs: dst[0].len() as isize,
                        lhs_cs: a[0].len() as isize,
                        rhs_rs: 2,
                        rhs_cs: b[0].len() as isize,
                        last_mask: (&last_mask) as *const _ as *const (),
                    },
                    dst.as_mut_ptr() as *mut c64,
                    a.as_ptr() as *const c64,
                    b.as_ptr() as *const c64,
                );
            };

            let mut expected_dst = c;
            for i in 0..3 {
                for j in 0..2 {
                    let mut acc = c64::new(0.0, 0.0);
                    for depth in 0..3 {
                        let mut a = a[depth][i];
                        let mut b = b[j][2 * depth];
                        if conj_lhs {
                            a = a.conj();
                        }
                        if conj_rhs {
                            b = b.conj();
                        }
                        acc += a * b;
                    }
                    expected_dst[j][i] += beta * acc;
                }
            }

            for (&dst, &expected_dst) in
                core::iter::zip(dst.iter().flatten(), expected_dst.iter().flatten())
            {
                assert!((dst.re - expected_dst.re).abs() < 1e-5);
                assert!((dst.im - expected_dst.im).abs() < 1e-5);
            }
        }
    }
    #[test]
    fn test_plan() {
        let gen = |_| rand::random::<f32>();
        for ((m, n), k) in (64..=64).zip(64..=64).zip([1, 4, 64]) {
            let a = (0..m * k).into_iter().map(gen).collect::<Vec<_>>();
            let b = (0..k * n).into_iter().map(gen).collect::<Vec<_>>();
            let c = (0..m * n).into_iter().map(|_| 0.0).collect::<Vec<_>>();
            let mut dst = c.clone();

            let plan = Plan::new_f32(m, n, k);
            let beta = 2.5;

            unsafe {
                plan.execute_unchecked(
                    m,
                    n,
                    k,
                    dst.as_mut_ptr(),
                    1,
                    m as isize,
                    a.as_ptr(),
                    1,
                    m as isize,
                    b.as_ptr(),
                    1,
                    k as isize,
                    1.0,
                    beta,
                    false,
                    false,
                );
            };

            let mut expected_dst = c;
            for i in 0..m {
                for j in 0..n {
                    let mut acc = 0.0f32;
                    for depth in 0..k {
                        acc = f32::mul_add(a[depth * m + i], b[j * k + depth], acc);
                    }
                    expected_dst[j * m + i] = f32::mul_add(beta, acc, expected_dst[j * m + i]);
                }
            }

            for (dst, expected_dst) in dst.iter().zip(&expected_dst) {
                assert!((dst - expected_dst).abs() < 1e-4);
            }
        }
    }

    #[test]
    fn test_plan_cplx() {
        let gen = |_| rand::random::<c64>();
        for ((m, n), k) in (0..128).zip(0..128).zip([1, 4, 17]) {
            let a = (0..m * k).into_iter().map(gen).collect::<Vec<_>>();
            let b = (0..k * n).into_iter().map(gen).collect::<Vec<_>>();
            let c = (0..m * n).into_iter().map(gen).collect::<Vec<_>>();

            for (conj_lhs, conj_rhs) in [(false, true), (false, false), (true, true), (true, false)]
            {
                for alpha in [c64::new(0.0, 0.0), c64::new(1.0, 0.0), c64::new(2.7, 3.7)] {
                    let mut dst = c.clone();

                    let plan = Plan::new_colmajor_lhs_and_dst_c64(m, n, k);
                    let beta = c64::new(2.5, 0.0);

                    unsafe {
                        plan.execute_unchecked(
                            m,
                            n,
                            k,
                            dst.as_mut_ptr(),
                            1,
                            m as isize,
                            a.as_ptr(),
                            1,
                            m as isize,
                            b.as_ptr(),
                            1,
                            k as isize,
                            alpha,
                            beta,
                            conj_lhs,
                            conj_rhs,
                        );
                    };

                    let mut expected_dst = c.clone();
                    for i in 0..m {
                        for j in 0..n {
                            let mut acc = c64::new(0.0, 0.0);
                            for depth in 0..k {
                                let mut a = a[depth * m + i];
                                let mut b = b[j * k + depth];
                                if conj_lhs {
                                    a = a.conj();
                                }
                                if conj_rhs {
                                    b = b.conj();
                                }
                                acc += a * b;
                            }
                            expected_dst[j * m + i] = alpha * expected_dst[j * m + i] + beta * acc;
                        }
                    }

                    for (&dst, &expected_dst) in core::iter::zip(dst.iter(), expected_dst.iter()) {
                        assert!((dst.re - expected_dst.re).abs() < 1e-5);
                        assert!((dst.im - expected_dst.im).abs() < 1e-5);
                    }
                }
            }
        }
    }

    #[test]
    fn test_plan_strided() {
        let gen = |_| rand::random::<f32>();
        for ((m, n), k) in (0..128).zip(0..128).zip([1, 4, 17]) {
            let a = (0..2 * 200 * k).into_iter().map(gen).collect::<Vec<_>>();
            let b = (0..k * n).into_iter().map(gen).collect::<Vec<_>>();
            let c = (0..3 * 400 * n)
                .into_iter()
                .map(|_| 0.0)
                .collect::<Vec<_>>();
            let mut dst = c.clone();

            let plan = Plan::new_f32(m, n, k);
            let beta = 2.5;

            unsafe {
                plan.execute_unchecked(
                    m,
                    n,
                    k,
                    dst.as_mut_ptr(),
                    3,
                    400,
                    a.as_ptr(),
                    2,
                    200,
                    b.as_ptr(),
                    1,
                    k as isize,
                    1.0,
                    beta,
                    false,
                    false,
                );
            };

            let mut expected_dst = c;
            for i in 0..m {
                for j in 0..n {
                    let mut acc = 0.0f32;
                    for depth in 0..k {
                        acc = f32::mul_add(a[depth * 200 + i * 2], b[j * k + depth], acc);
                    }
                    expected_dst[j * 400 + i * 3] =
                        f32::mul_add(beta, acc, expected_dst[j * 400 + i * 3]);
                }
            }

            for (dst, expected_dst) in dst.iter().zip(&expected_dst) {
                assert!((dst - expected_dst).abs() < 1e-4);
            }
        }
    }

    #[test]
    fn test_plan_cplx_strided() {
        let gen = |_| c64::new(rand::random(), rand::random());
        for ((m, n), k) in (0..128).zip(0..128).zip([1, 4, 17, 190]) {
            let a = (0..2 * 200 * k).into_iter().map(gen).collect::<Vec<_>>();
            let b = (0..k * n).into_iter().map(gen).collect::<Vec<_>>();
            let c = (0..3 * 400 * n)
                .into_iter()
                .map(|_| c64::ZERO)
                .collect::<Vec<_>>();
            let mut dst = c.clone();

            let beta = 2.5.into();

            unsafe {
                planless::execute_c64(
                    m,
                    n,
                    k,
                    dst.as_mut_ptr(),
                    3,
                    400,
                    a.as_ptr(),
                    2,
                    200,
                    b.as_ptr(),
                    1,
                    k as isize,
                    1.0.into(),
                    beta,
                    false,
                    false,
                );
            };

            let mut expected_dst = c;
            for i in 0..m {
                for j in 0..n {
                    let mut acc = c64::ZERO;
                    for depth in 0..k {
                        acc += a[depth * 200 + i * 2] * b[j * k + depth];
                    }
                    expected_dst[j * 400 + i * 3] = beta * acc + expected_dst[j * 400 + i * 3];
                }
            }

            for (dst, expected_dst) in dst.iter().zip(&expected_dst) {
                use num_complex::ComplexFloat;
                assert!((dst - expected_dst).abs() < 1e-4);
            }
        }
    }

    #[test]
    fn test_plan_cplx_strided2() {
        let gen = |_| c64::new(rand::random(), rand::random());
        let m = 102;
        let n = 2;
        let k = 190;
        {
            let a = (0..2 * 200 * k).into_iter().map(gen).collect::<Vec<_>>();
            let b = (0..k * n).into_iter().map(gen).collect::<Vec<_>>();
            let c = (0..400 * n)
                .into_iter()
                .map(|_| c64::ZERO)
                .collect::<Vec<_>>();
            let mut dst = c.clone();

            let beta = 2.5.into();

            unsafe {
                planless::execute_c64(
                    m,
                    n,
                    k,
                    dst.as_mut_ptr(),
                    1,
                    400,
                    a.as_ptr(),
                    2,
                    200,
                    b.as_ptr(),
                    1,
                    k as isize,
                    1.0.into(),
                    beta,
                    false,
                    false,
                );
            };

            let mut expected_dst = c;
            for i in 0..m {
                for j in 0..n {
                    let mut acc = c64::ZERO;
                    for depth in 0..k {
                        acc += a[depth * 200 + i * 2] * b[j * k + depth];
                    }
                    expected_dst[j * 400 + i] = beta * acc + expected_dst[j * 400 + i];
                }
            }

            for (dst, expected_dst) in dst.iter().zip(&expected_dst) {
                use num_complex::ComplexFloat;
                assert!((dst - expected_dst).abs() < 1e-4);
            }
        }
    }
}