nano_gemm_codegen/
lib.rs

1// only targeting scalar, avx2, avx512
2// no arm64 (until SVE comes), because no masked load instructions
3
4// f32, f64, c32, c64
5
6use std::fmt::Display;
7use std::fmt::Write;
8
mod generic {
    use super::*;

    /// Description of one real-valued microkernel to render as Rust source.
    ///
    /// Formatting the value with `Display` emits an
    /// `unsafe fn matmul_{mr}_{nr}_{k|dyn}` that accumulates `lhs * rhs` in
    /// registers and then writes `beta * acc + alpha * dst` back to `dst`.
    pub struct RealKernel {
        /// Element type name spliced into the generated signature.
        pub ty: &'static str,
        /// Register type name used for the accumulators.
        pub reg_ty: &'static str,
        // register size, in elements of `ty`; row chunks are never wider than this
        pub n: usize,
        /// Number of `dst` rows handled by the kernel.
        pub mr: usize,
        /// Number of `dst` columns handled by the kernel.
        pub nr: usize,
        /// Depth to unroll, or `None` to emit a run-time `for depth in 0..k` loop.
        pub k: Option<usize>,

        /// Value placed in `#[target_feature(enable = "...")]`.
        pub target_features: &'static str,
        /// Load function names, indexed by log2 of the chunk width (1, 2, 4).
        pub load_unaligned: [&'static str; 3],
        /// Store function names, indexed by log2 of the chunk width (1, 2, 4).
        pub store_unaligned: [&'static str; 3],
        /// Name of the function that broadcasts a scalar into a register.
        pub set1: &'static str,
        /// Name of the function called as `mul_add(a, b, c)`.
        pub mul_add: &'static str,
    }

    /// Description of one complex-valued microkernel to render as Rust source.
    ///
    /// Same layout as [`RealKernel`], plus the names of the conjugating
    /// operations the generated code calls.
    pub struct CplxKernel {
        /// Element type name spliced into the generated signature.
        pub ty: &'static str,
        /// Register type name used for the accumulators.
        pub reg_ty: &'static str,
        // register size, in elements of `ty`; row chunks are never wider than this
        pub n: usize,
        /// Number of `dst` rows handled by the kernel.
        pub mr: usize,
        /// Number of `dst` columns handled by the kernel.
        pub nr: usize,
        /// Depth to unroll, or `None` to emit a run-time `for depth in 0..k` loop.
        pub k: Option<usize>,

        /// Value placed in `#[target_feature(enable = "...")]`.
        pub target_features: &'static str,
        /// Load function names, indexed by log2 of the chunk width (1, 2, 4).
        pub load_unaligned: [&'static str; 3],
        /// Store function names, indexed by log2 of the chunk width (1, 2, 4).
        pub store_unaligned: [&'static str; 3],
        /// Name of the function that broadcasts a scalar into a register.
        pub set1: &'static str,
        /// Name of the function called as `mul_add(a, b, c)`.
        pub mul_add: &'static str,
        /// Accumulate function emitted when exactly one of `conj_lhs` /
        /// `conj_rhs` is set in the generated code.
        pub conj_mul_add: &'static str,
        /// Function applied to every accumulator when `conj_rhs` is set.
        pub conj: &'static str,
    }

    /// Suffix of the generated function name: the unrolled depth, or `dyn`
    /// when the depth is only known at run time.
    fn depth_suffix(k: Option<usize>) -> String {
        k.map_or_else(|| "dyn".to_string(), |k| k.to_string())
    }

    /// Tiles `mr` rows into power-of-two chunks no wider than `n`.
    ///
    /// Yields `(first_row, log2_width)` pairs; `log2_width` doubles as the
    /// index into the `load_unaligned` / `store_unaligned` tables. Nothing is
    /// computed (and so nothing can panic) when `mr == 0`.
    fn row_chunks(mr: usize, n: usize) -> impl Iterator<Item = (usize, usize)> {
        let mut row = 0;
        std::iter::from_fn(move || {
            if row >= mr {
                return None;
            }
            let log2_width = Ord::min((mr - row).ilog2(), n.ilog2()) as usize;
            let chunk = (row, log2_width);
            row += 1 << log2_width;
            Some(chunk)
        })
    }

    /// Settings common to [`RealKernel`] and [`CplxKernel`], gathered so the
    /// emitters shared by both kernel kinds are written exactly once.
    struct Shared {
        ty: &'static str,
        reg_ty: &'static str,
        n: usize,
        mr: usize,
        nr: usize,
        k: Option<usize>,
        target_features: &'static str,
        load_unaligned: [&'static str; 3],
        store_unaligned: [&'static str; 3],
        set1: &'static str,
        mul_add: &'static str,
    }

    impl Shared {
        /// Row chunks of this kernel, see [`row_chunks`].
        fn chunks(&self) -> impl Iterator<Item = (usize, usize)> {
            row_chunks(self.mr, self.n)
        }

        /// Emits the attribute, the function signature and the zeroed
        /// accumulators.
        ///
        /// `extra_fields` is inserted verbatim into the `MicroKernelData`
        /// destructuring pattern, right before the trailing `..`.
        fn write_prologue(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            extra_fields: &str,
        ) -> std::fmt::Result {
            let Self {
                ty,
                reg_ty,
                mr,
                nr,
                target_features,
                ..
            } = *self;
            let suffix = depth_suffix(self.k);

            writeln!(f, "#[target_feature(enable = \"{target_features}\")]")?;
            write!(
                f,
                r#"pub unsafe fn matmul_{mr}_{nr}_{suffix}(
                &nano_gemm_core::MicroKernelData {{ alpha, beta, k, dst_cs, lhs_cs, rhs_rs, rhs_cs, {extra_fields}.. }}: &nano_gemm_core::MicroKernelData< {ty} >,
                dst: *mut {ty},
                lhs: *const {ty},
                rhs: *const {ty},
            ) {{
"#
            )?;

            // Keeps `k` "used" in kernels whose depth is fully unrolled.
            writeln!(f, "_ = k;")?;
            for (i, _) in self.chunks() {
                for j in 0..nr {
                    writeln!(f, "let mut acc_{i}_{j}: {reg_ty} = core::mem::zeroed();")?;
                }
            }
            Ok(())
        }

        /// Emits the accumulation over the depth dimension, either fully
        /// unrolled (`k` is `Some`) or as a run-time loop. `fma` names the
        /// accumulate function to call.
        fn write_depth_loop(&self, f: &mut std::fmt::Formatter<'_>, fma: &str) -> std::fmt::Result {
            match self.k {
                Some(k) => {
                    for depth in 0..k {
                        writeln!(f, "let depth = {depth};")?;
                        self.write_inner_kernel(f, fma)?;
                    }
                }
                None => {
                    write!(f, "for depth in 0..k as isize {{")?;
                    self.write_inner_kernel(f, fma)?;
                    write!(f, "}}")?;
                }
            }
            Ok(())
        }

        /// Emits one depth step: load every lhs chunk once, then for each
        /// column broadcast the rhs scalar and update that column's
        /// accumulators.
        fn write_inner_kernel(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            fma: &str,
        ) -> std::fmt::Result {
            let set1 = self.set1;

            for (i, log2_width) in self.chunks() {
                let load = self.load_unaligned[log2_width];
                write!(
                    f,
                    "let tmp_lhs_{i} = {load}(lhs.offset(depth * lhs_cs + {i}));"
                )?;
            }
            for j in 0..self.nr {
                writeln!(
                    f,
                    "let tmp_rhs = {set1}(*rhs.offset(depth * rhs_rs + {j} * rhs_cs));"
                )?;
                for (i, _) in self.chunks() {
                    writeln!(f, "acc_{i}_{j} = {fma}(tmp_lhs_{i}, tmp_rhs, acc_{i}_{j});")?;
                }
            }
            Ok(())
        }

        /// Emits one store per accumulator: `dst = beta * acc + <addend>`.
        ///
        /// `addend` receives the load function matching the chunk width and
        /// returns the expression to add to `beta * acc`.
        fn write_stores(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            addend: impl Fn(&str) -> String,
        ) -> std::fmt::Result {
            let mul_add = self.mul_add;

            for j in 0..self.nr {
                for (i, log2_width) in self.chunks() {
                    write!(f, "{{")?;
                    write!(f, "let dst = dst.offset({i} + {j} * dst_cs);")?;
                    let store = self.store_unaligned[log2_width];
                    let addend = addend(self.load_unaligned[log2_width]);
                    writeln!(
                        f,
                        "{store}(dst, {mul_add}(beta, acc_{i}_{j}, {addend}));"
                    )?;
                    write!(f, "}}")?;
                }
            }
            Ok(())
        }

        /// Emits the write-back, specialised on `alpha`, and closes the
        /// generated function.
        ///
        /// `one` and `zero` are the source expressions `alpha` is compared
        /// against to pick the `alpha == 1` and `alpha == 0` fast paths.
        fn write_epilogue(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            one: &str,
            zero: &str,
        ) -> std::fmt::Result {
            let Self { set1, mul_add, .. } = *self;

            // alpha == 1: dst = beta * acc + dst
            write!(f, "if alpha == {one} {{")?;
            writeln!(f, "let beta = {set1}(beta);")?;
            self.write_stores(f, |load| format!("{load}(dst)"))?;
            write!(f, "}}")?;

            // alpha == 0: dst is never read
            write!(f, "else if alpha == {zero} {{")?;
            writeln!(f, "let beta = {set1}(beta);")?;
            self.write_stores(f, |_| "core::mem::zeroed()".to_string())?;
            write!(f, "}}")?;

            // general case: dst = beta * acc + alpha * dst
            write!(f, "else {{")?;
            writeln!(f, "let beta = {set1}(beta);")?;
            writeln!(f, "let alpha = {set1}(alpha);")?;
            self.write_stores(f, |load| {
                format!("{mul_add}(alpha, {load}(dst), core::mem::zeroed())")
            })?;
            write!(f, "}}")?;

            // closes the generated function body
            write!(f, "}}")
        }
    }

    impl RealKernel {
        /// Copies the settings shared with [`CplxKernel`].
        fn shared(&self) -> Shared {
            Shared {
                ty: self.ty,
                reg_ty: self.reg_ty,
                n: self.n,
                mr: self.mr,
                nr: self.nr,
                k: self.k,
                target_features: self.target_features,
                load_unaligned: self.load_unaligned,
                store_unaligned: self.store_unaligned,
                set1: self.set1,
                mul_add: self.mul_add,
            }
        }
    }

    impl CplxKernel {
        /// Copies the settings shared with [`RealKernel`].
        fn shared(&self) -> Shared {
            Shared {
                ty: self.ty,
                reg_ty: self.reg_ty,
                n: self.n,
                mr: self.mr,
                nr: self.nr,
                k: self.k,
                target_features: self.target_features,
                load_unaligned: self.load_unaligned,
                store_unaligned: self.store_unaligned,
                set1: self.set1,
                mul_add: self.mul_add,
            }
        }
    }

    impl Display for RealKernel {
        /// Renders the complete source of the real microkernel.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let shared = self.shared();
            shared.write_prologue(f, "")?;
            shared.write_depth_loop(f, self.mul_add)?;
            shared.write_epilogue(f, "1.0", "0.0")
        }
    }

    impl Display for CplxKernel {
        /// Renders the complete source of the complex microkernel.
        ///
        /// The generated code accumulates with `mul_add` when `conj_lhs ==
        /// conj_rhs` and with `conj_mul_add` otherwise, then applies `conj`
        /// to every accumulator when `conj_rhs` is set.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let shared = self.shared();
            let Self { ty, conj, .. } = *self;

            shared.write_prologue(f, "conj_lhs, conj_rhs, ")?;

            write!(f, "if conj_lhs == conj_rhs {{")?;
            shared.write_depth_loop(f, self.mul_add)?;
            write!(f, "}} else {{")?;
            shared.write_depth_loop(f, self.conj_mul_add)?;
            write!(f, "}}")?;

            write!(f, "if conj_rhs {{")?;
            for j in 0..self.nr {
                for (i, _) in shared.chunks() {
                    write!(f, "acc_{i}_{j} = {conj}(acc_{i}_{j});")?;
                }
            }
            write!(f, "}}")?;

            shared.write_epilogue(
                f,
                &format!("({ty} {{ re: 1.0, im: 0.0 }})"),
                &format!("({ty} {{ re: 0.0, im: 0.0 }})"),
            )
        }
    }
}
437
438pub mod aarch64 {
439    use super::*;
440    use generic::{CplxKernel, RealKernel};
441
442    pub fn codegen_f32() -> Result<String, Box<dyn std::error::Error>> {
443        let mut code = String::new();
444
445        write!(code, "pub mod f32 {{\n")?;
446        write!(code, "pub mod neon {{\n")?;
447        write!(
448            code,
449            r###"
450            use core::arch::aarch64::*;
451            use core::mem::transmute;
452            use core::mem::transmute_copy;
453
454            #[inline(always)]
455            unsafe fn set1(v: f32) -> float32x4_t {{
456                transmute([v; 4])
457            }}
458            #[inline(always)]
459            unsafe fn mul_add(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {{
460                vmlaq_f32(c, a, b)
461            }}
462            #[inline(always)]
463            unsafe fn load_1(ptr: *const f32) -> float32x4_t {{
464                transmute([*ptr; 4])
465            }}
466            #[inline(always)]
467            unsafe fn load_2(ptr: *const f32) -> float32x4_t {{
468                transmute([*(ptr as *const [f32; 2]); 2])
469            }}
470            #[inline(always)]
471            unsafe fn load_4(ptr: *const f32) -> float32x4_t {{
472                transmute(*(ptr as *const [f32; 4]))
473            }}
474
475            #[inline(always)]
476            unsafe fn store_1(ptr: *mut f32, v: float32x4_t) {{
477                *(ptr as *mut [f32; 1]) = transmute_copy(&v);
478            }}
479            #[inline(always)]
480            unsafe fn store_2(ptr: *mut f32, v: float32x4_t) {{
481                *(ptr as *mut [f32; 2]) = transmute_copy(&v);
482            }}
483            #[inline(always)]
484            unsafe fn store_4(ptr: *mut f32, v: float32x4_t) {{
485                *(ptr as *mut [f32; 4]) = transmute_copy(&v);
486            }}
487            "###
488        )?;
489        for mr in 1..=8 {
490            for nr in 1..=4 {
491                for k in (1..=16).into_iter().map(Some).chain([None]) {
492                    let kernel = RealKernel {
493                        ty: "f32",
494                        reg_ty: "float32x4_t",
495                        n: 4,
496                        mr,
497                        nr,
498                        k,
499                        target_features: "neon",
500                        load_unaligned: ["load_1", "load_2", "load_4"],
501                        store_unaligned: ["store_1", "store_2", "store_4"],
502                        set1: "set1",
503                        mul_add: "mul_add",
504                    };
505                    write!(code, "{kernel}")?;
506                }
507            }
508        }
509
510        write!(
511            code,
512            "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 8]; 17] = [\n"
513        )?;
514        for k in (1..=16).into_iter().map(Some).chain([None]) {
515            write!(code, "[\n")?;
516            for mr in 1..=8 {
517                write!(code, "[\n")?;
518                for nr in 1..=4 {
519                    write!(
520                        code,
521                        "matmul_{mr}_{nr}_{},",
522                        k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
523                    )?;
524                }
525                write!(code, "],\n")?;
526            }
527            write!(code, "],\n")?;
528        }
529        write!(code, "];\n")?;
530        write!(code, "}}")?;
531        write!(code, "}}")?;
532
533        Ok(code)
534    }
535
536    pub fn codegen_f64() -> Result<String, Box<dyn std::error::Error>> {
537        let mut code = String::new();
538
539        write!(code, "pub mod f64 {{\n")?;
540        write!(code, "pub mod neon {{\n")?;
541        write!(
542            code,
543            r###"
544            use core::arch::aarch64::*;
545            use core::mem::transmute;
546            use core::mem::transmute_copy;
547
548            #[inline(always)]
549            unsafe fn set1(v: f64) -> float64x2_t {{
550                transmute([v; 2])
551            }}
552            #[inline(always)]
553            unsafe fn mul_add(a: float64x2_t, b: float64x2_t, c: float64x2_t) -> float64x2_t {{
554                vmlaq_f64(c, a, b)
555            }}
556            #[inline(always)]
557            unsafe fn load_1(ptr: *const f64) -> float64x2_t {{
558                transmute([*ptr; 2])
559            }}
560            #[inline(always)]
561            unsafe fn load_2(ptr: *const f64) -> float64x2_t {{
562                transmute(*(ptr as *const [f64; 2]))
563            }}
564
565            #[inline(always)]
566            unsafe fn store_1(ptr: *mut f64, v: float64x2_t) {{
567                *(ptr as *mut [f64; 1]) = transmute_copy(&v);
568            }}
569            #[inline(always)]
570            unsafe fn store_2(ptr: *mut f64, v: float64x2_t) {{
571                *(ptr as *mut [f64; 2]) = transmute_copy(&v);
572            }}
573            "###
574        )?;
575        for mr in 1..=4 {
576            for nr in 1..=4 {
577                for k in (1..=16).into_iter().map(Some).chain([None]) {
578                    let kernel = RealKernel {
579                        ty: "f64",
580                        reg_ty: "float64x2_t",
581                        n: 2,
582                        mr,
583                        nr,
584                        k,
585                        target_features: "neon",
586                        load_unaligned: ["load_1", "load_2", "load_4"],
587                        store_unaligned: ["store_1", "store_2", "store_4"],
588                        set1: "set1",
589                        mul_add: "mul_add",
590                    };
591                    write!(code, "{kernel}")?;
592                }
593            }
594        }
595
596        write!(
597            code,
598            "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f64>; 4]; 4]; 17] = [\n"
599        )?;
600        for k in (1..=16).into_iter().map(Some).chain([None]) {
601            write!(code, "[\n")?;
602            for mr in 1..=4 {
603                write!(code, "[\n")?;
604                for nr in 1..=4 {
605                    write!(
606                        code,
607                        "matmul_{mr}_{nr}_{},",
608                        k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
609                    )?;
610                }
611                write!(code, "],\n")?;
612            }
613            write!(code, "],\n")?;
614        }
615        write!(code, "];\n")?;
616        write!(code, "}}")?;
617        write!(code, "}}")?;
618
619        Ok(code)
620    }
621
622    pub fn codegen_c32() -> Result<String, Box<dyn std::error::Error>> {
623        let mut code = String::new();
624
625        write!(code, "pub mod c32 {{\n")?;
626        write!(code, "pub mod neon {{ use crate::c32;\n")?;
627        write!(
628            code,
629            r###"
630            use core::arch::aarch64::*;
631            use core::mem::transmute;
632            use core::mem::transmute_copy;
633
634            #[inline(always)]
635            unsafe fn set1(v: c32) -> float32x4_t {{
636                transmute([v; 2])
637            }}
638            #[inline(always)]
639            unsafe fn mul_add(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {{
640                vcmlaq_90_f32(vcmlaq_0_f32(c, a, b), a, b)
641            }}
642            #[inline(always)]
643            unsafe fn conj_mul_add(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {{
644                vcmlaq_270_f32(vcmlaq_0_f32(c, a, b), a, b)
645            }}
646            #[inline(always)]
647            unsafe fn conj(a: float32x4_t) -> float32x4_t {{
648                transmute(veorq_u32(transmute(a), transmute([0.0, -0.0, 0.0, -0.0f32])))
649            }}
650            #[inline(always)]
651            unsafe fn load_1(ptr: *const c32) -> float32x4_t {{
652                transmute([*ptr; 2])
653            }}
654            #[inline(always)]
655            unsafe fn load_2(ptr: *const c32) -> float32x4_t {{
656                transmute(*(ptr as *const [c32; 2]))
657            }}
658
659            #[inline(always)]
660            unsafe fn store_1(ptr: *mut c32, v: float32x4_t) {{
661                *(ptr as *mut [c32; 1]) = transmute_copy(&v);
662            }}
663            #[inline(always)]
664            unsafe fn store_2(ptr: *mut c32, v: float32x4_t) {{
665                *(ptr as *mut [c32; 2]) = transmute_copy(&v);
666            }}
667
668#[inline]
669#[target_feature(enable = "neon,fcma")]
670unsafe fn vcmlaq_0_f32(mut acc: float32x4_t, lhs: float32x4_t, rhs: float32x4_t) -> float32x4_t {{
671    core::arch::asm!(
672        "fcmla {{0:v}}.4s, {{1:v}}.4s, {{2:v}}.4s, 0",
673        inout(vreg) acc,
674        in(vreg) lhs,
675        in(vreg) rhs,
676        options(pure, nomem, nostack));
677    acc
678}}
679
680#[inline]
681#[target_feature(enable = "neon,fcma")]
682unsafe fn vcmlaq_90_f32(mut acc: float32x4_t, lhs: float32x4_t, rhs: float32x4_t) -> float32x4_t {{
683    core::arch::asm!(
684        "fcmla {{0:v}}.4s, {{1:v}}.4s, {{2:v}}.4s, 90",
685        inout(vreg) acc,
686        in(vreg) lhs,
687        in(vreg) rhs,
688        options(pure, nomem, nostack));
689    acc
690}}
691
692#[inline]
693#[target_feature(enable = "neon,fcma")]
694unsafe fn vcmlaq_270_f32(mut acc: float32x4_t, lhs: float32x4_t, rhs: float32x4_t) -> float32x4_t {{
695    core::arch::asm!(
696        "fcmla {{0:v}}.4s, {{1:v}}.4s, {{2:v}}.4s, 270",
697        inout(vreg) acc,
698        in(vreg) lhs,
699        in(vreg) rhs,
700        options(pure, nomem, nostack));
701    acc
702}}
703            "###
704        )?;
705        for mr in 1..=4 {
706            for nr in 1..=4 {
707                for k in (1..=16).into_iter().map(Some).chain([None]) {
708                    let kernel = CplxKernel {
709                        ty: "c32",
710                        reg_ty: "float32x4_t",
711                        n: 2,
712                        mr,
713                        nr,
714                        k,
715                        target_features: "neon,fcma",
716                        load_unaligned: ["load_1", "load_2", "load_4"],
717                        store_unaligned: ["store_1", "store_2", "store_4"],
718                        set1: "set1",
719                        mul_add: "mul_add",
720                        conj_mul_add: "conj_mul_add",
721                        conj: "conj",
722                    };
723                    write!(code, "{kernel}")?;
724                }
725            }
726        }
727
728        write!(
729            code,
730            "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<c32>; 4]; 4]; 17] = [\n"
731        )?;
732        for k in (1..=16).into_iter().map(Some).chain([None]) {
733            write!(code, "[\n")?;
734            for mr in 1..=4 {
735                write!(code, "[\n")?;
736                for nr in 1..=4 {
737                    write!(
738                        code,
739                        "matmul_{mr}_{nr}_{},",
740                        k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
741                    )?;
742                }
743                write!(code, "],\n")?;
744            }
745            write!(code, "],\n")?;
746        }
747        write!(code, "];\n")?;
748        write!(code, "}}")?;
749        write!(code, "}}")?;
750
751        Ok(code)
752    }
753
754    pub fn codegen_c64() -> Result<String, Box<dyn std::error::Error>> {
755        let mut code = String::new();
756
757        write!(code, "pub mod c64 {{\n")?;
758        write!(code, "pub mod neon {{ use crate::c64;\n")?;
759        write!(
760            code,
761            r###"
762            use core::arch::aarch64::*;
763            use core::mem::transmute;
764
765            #[inline(always)]
766            unsafe fn set1(v: c64) -> float64x2_t {{
767                transmute(v)
768            }}
769            #[inline(always)]
770            unsafe fn mul_add(a: float64x2_t, b: float64x2_t, c: float64x2_t) -> float64x2_t {{
771                vcmlaq_90_f64(vcmlaq_0_f64(c, a, b), a, b)
772            }}
773            #[inline(always)]
774            unsafe fn conj_mul_add(a: float64x2_t, b: float64x2_t, c: float64x2_t) -> float64x2_t {{
775                vcmlaq_270_f64(vcmlaq_0_f64(c, a, b), a, b)
776            }}
777            #[inline(always)]
778            unsafe fn conj(a: float64x2_t) -> float64x2_t {{
779                transmute(veorq_u64(transmute(a), transmute([0.0, -0.0f64])))
780            }}
781            #[inline(always)]
782            unsafe fn load_1(ptr: *const c64) -> float64x2_t {{
783                transmute(*ptr)
784            }}
785
786            #[inline(always)]
787            unsafe fn store_1(ptr: *mut c64, v: float64x2_t) {{
788                *ptr = transmute(v);
789            }}
790
791#[inline]
792#[target_feature(enable = "neon,fcma")]
793unsafe fn vcmlaq_0_f64(mut acc: float64x2_t, lhs: float64x2_t, rhs: float64x2_t) -> float64x2_t {{
794    core::arch::asm!(
795        "fcmla {{0:v}}.2d, {{1:v}}.2d, {{2:v}}.2d, 0",
796        inout(vreg) acc,
797        in(vreg) lhs,
798        in(vreg) rhs,
799        options(pure, nomem, nostack));
800    acc
801}}
802
803#[inline]
804#[target_feature(enable = "neon,fcma")]
805unsafe fn vcmlaq_90_f64(mut acc: float64x2_t, lhs: float64x2_t, rhs: float64x2_t) -> float64x2_t {{
806    core::arch::asm!(
807        "fcmla {{0:v}}.2d, {{1:v}}.2d, {{2:v}}.2d, 90",
808        inout(vreg) acc,
809        in(vreg) lhs,
810        in(vreg) rhs,
811        options(pure, nomem, nostack));
812    acc
813}}
814
815#[inline]
816#[target_feature(enable = "neon,fcma")]
817unsafe fn vcmlaq_270_f64(mut acc: float64x2_t, lhs: float64x2_t, rhs: float64x2_t) -> float64x2_t {{
818    core::arch::asm!(
819        "fcmla {{0:v}}.2d, {{1:v}}.2d, {{2:v}}.2d, 270",
820        inout(vreg) acc,
821        in(vreg) lhs,
822        in(vreg) rhs,
823        options(pure, nomem, nostack));
824    acc
825}}
826            "###
827        )?;
828        for mr in 1..=2 {
829            for nr in 1..=4 {
830                for k in (1..=16).into_iter().map(Some).chain([None]) {
831                    let kernel = CplxKernel {
832                        ty: "c64",
833                        reg_ty: "float64x2_t",
834                        n: 1,
835                        mr,
836                        nr,
837                        k,
838                        target_features: "neon,fcma",
839                        load_unaligned: ["load_1", "load_2", "load_4"],
840                        store_unaligned: ["store_1", "store_2", "store_4"],
841                        set1: "set1",
842                        mul_add: "mul_add",
843                        conj_mul_add: "conj_mul_add",
844                        conj: "conj",
845                    };
846                    write!(code, "{kernel}")?;
847                }
848            }
849        }
850
851        write!(
852            code,
853            "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<c64>; 4]; 2]; 17] = [\n"
854        )?;
855        for k in (1..=16).into_iter().map(Some).chain([None]) {
856            write!(code, "[\n")?;
857            for mr in 1..=2 {
858                write!(code, "[\n")?;
859                for nr in 1..=4 {
860                    write!(
861                        code,
862                        "matmul_{mr}_{nr}_{},",
863                        k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
864                    )?;
865                }
866                write!(code, "],\n")?;
867            }
868            write!(code, "],\n")?;
869        }
870        write!(code, "];\n")?;
871        write!(code, "}}")?;
872        write!(code, "}}")?;
873
874        Ok(code)
875    }
876}
877
878pub mod x86 {
879    use super::*;
880
    /// Parameters for one real-valued x86 microkernel.
    ///
    /// The `Display` impl for this type renders it as the source of a
    /// `matmul_*` function. Most string fields look like names of functions
    /// or types spliced into that generated source, but the code that uses
    /// them is not fully visible here — TODO confirm each against the
    /// formatter.
    struct RealKernel {
        // presumably the element type name used in the generated signature — TODO confirm
        ty: &'static str,
        // presumably the vector register type name — TODO confirm
        reg_ty: &'static str,
        // presumably the mask type name used for partial loads/stores — TODO confirm
        mask_ty: &'static str,
        // register size
        n: usize,
        // number of registers along the row dimension; per the note in the
        // `Display` impl the kernel multiplies a (mr_div_n * n, k) block
        mr_div_n: usize,
        // number of columns of the (k, nr) right-hand block
        nr: usize,
        // NOTE(review): looks like `Some` = fixed depth, `None` = run-time
        // depth, as in the generic kernels — confirm
        k: Option<usize>,

        // value placed in the generated `#[target_feature(enable = "...")]`
        target_features: &'static str,
        // presumably the scalar-broadcast function name — TODO confirm
        set1: &'static str,
        // presumably the full-width load function name — TODO confirm
        load_unaligned: &'static str,
        // presumably the full-width store function name — TODO confirm
        store_unaligned: &'static str,
        // builds a masked-load expression from two source snippets; which
        // argument is the pointer and which the mask is not visible here
        mask_load_unaligned: Box<dyn Fn(String, String) -> String>,

        // presumably the masked store function name — TODO confirm
        mask_store_unaligned: &'static str,
        // presumably the multiply-add function name — TODO confirm
        mul_add: &'static str,
        // presumably the plain multiply function name — TODO confirm
        mul: &'static str,
        // NOTE(review): looks like this selects masked access for the last
        // row chunk — confirm against the formatter
        need_mask: bool,
    }
902
903    struct CplxKernel {
904        ty: &'static str,
905        reg_ty: &'static str,
906        mask_ty: &'static str,
907        // register size
908        n: usize,
909        mr_div_n: usize,
910        nr: usize,
911        k: Option<usize>,
912
913        target_features: &'static str,
914        set1: &'static str,
915        swap_re_im: &'static str,
916
917        load_unaligned: &'static str,
918        store_unaligned: &'static str,
919        mask_load_unaligned: Box<dyn Fn(String, String) -> String>,
920
921        mask_store_unaligned: &'static str,
922        mul_addsub: &'static str,
923        mul_subadd: &'static str,
924        xor: &'static str,
925        need_mask: bool,
926    }
927
928    impl Display for RealKernel {
929        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
930            // function that multiplies (mr_div_n * n, k) by (k, n_r)
931            // C += beta * A * B
932
933            // not exactly mr_div_n
934            // the actual number of rows is between mr_div_n * (n-1) and mr_div_n * n
935            write!(
936                f,
937                "
938            #[target_feature(enable = \"{}\")]\n",
939                self.target_features
940            )?;
941            write!(
942                f,
943                r#"pub unsafe fn matmul_{0:}_{1:}_{2:}(
944                &nano_gemm_core::MicroKernelData {{ alpha, beta, k, dst_cs, lhs_cs, rhs_rs, rhs_cs, last_mask, .. }}: &nano_gemm_core::MicroKernelData< {3:} >,
945                dst: *mut {3:},
946                lhs: *const {3:},
947                rhs: *const {3:},
948            ) {{
949"#,
950                self.mr_div_n,
951                self.nr,
952                self.k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
953                self.ty,
954            )?;
955            write!(
956                f,
957                r#"
958                #[cfg(target_arch = "x86_64")]
959                use core::arch::x86_64::*;
960                #[cfg(target_arch = "x86")]
961                use core::arch::x86::*;
962            "#
963            )?;
964
965            write!(f, "_ = k;\n")?;
966            write!(f, "type Reg = {};\n", self.reg_ty)?;
967            write!(f, "const N: isize = {};\n", self.n)?;
968            write!(
969                f,
970                "let mut acc: [[Reg; {}]; {}] = core::mem::zeroed();\n",
971                self.nr, self.mr_div_n
972            )?;
973            write!(
974                f,
975                "let mut tmp_lhs: [Reg; {}] = core::mem::zeroed();\n",
976                self.mr_div_n
977            )?;
978
979            if self.need_mask {
980                write!(
981                    f,
982                    "let last_mask = *(last_mask as *const {});\n",
983                    self.mask_ty
984                )?;
985            } else {
986                write!(f, "_ = last_mask;\n")?;
987            }
988
989            if let Some(k) = self.k {
990                for depth in 0..k {
991                    write!(f, "let depth = {depth};\n")?;
992                    for i in 0..self.mr_div_n {
993                        self.write_load_lhs(i, f)?;
994                    }
995                    for j in 0..self.nr {
996                        write!(
997                            f,
998                            "let tmp_rhs = {}(*rhs.offset(depth * rhs_rs + {j} * rhs_cs));\n",
999                            self.set1
1000                        )?;
1001
1002                        for i in 0..self.mr_div_n {
1003                            if depth > 0 {
1004                                write!(
1005                                    f,
1006                                    "acc[{i}][{j}] = {}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1007                                    self.mul_add
1008                                )?;
1009                            } else {
1010                                write!(
1011                                    f,
1012                                    "acc[{i}][{j}] = {}(tmp_lhs[{i}], tmp_rhs);\n",
1013                                    self.mul
1014                                )?;
1015                            }
1016                        }
1017                    }
1018                }
1019            } else {
1020                write!(f, "for depth in 0..k as isize {{")?;
1021                for i in 0..self.mr_div_n {
1022                    self.write_load_lhs(i, f)?;
1023                }
1024                for j in 0..self.nr {
1025                    write!(
1026                        f,
1027                        "let tmp_rhs = {}(*rhs.offset(depth * rhs_rs + {j} * rhs_cs));\n",
1028                        self.set1
1029                    )?;
1030
1031                    for i in 0..self.mr_div_n {
1032                        write!(
1033                            f,
1034                            "acc[{i}][{j}] = {}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1035                            self.mul_add
1036                        )?;
1037                    }
1038                }
1039                write!(f, "}}")?;
1040            }
1041
1042            write!(f, "if alpha == 1.0 {{")?;
1043            write!(f, "let beta = {}(beta);\n", self.set1)?;
1044            for j in 0..self.nr {
1045                for i in 0..self.mr_div_n {
1046                    write!(f, "{{")?;
1047                    write!(f, "let dst = dst.offset({i} * N + {j} * dst_cs);")?;
1048                    if i + 1 < self.mr_div_n || !self.need_mask {
1049                        write!(
1050                            f,
1051                            "{}(dst, {}(beta, acc[{i}][{j}], {}(dst)));\n",
1052                            self.store_unaligned, self.mul_add, self.load_unaligned
1053                        )?;
1054                    } else {
1055                        write!(
1056                            f,
1057                            "{}(dst, last_mask, {}(beta, acc[{i}][{j}], {}));\n",
1058                            self.mask_store_unaligned,
1059                            self.mul_add,
1060                            (self.mask_load_unaligned)(format!("dst"), "last_mask".to_string()),
1061                        )?;
1062                    }
1063                    write!(f, "}}")?;
1064                }
1065            }
1066            write!(f, "}}")?;
1067            write!(f, "else if alpha == 0.0 {{")?;
1068            write!(f, "let beta = {}(beta);\n", self.set1)?;
1069            for j in 0..self.nr {
1070                for i in 0..self.mr_div_n {
1071                    write!(f, "{{")?;
1072                    write!(f, "let dst = dst.offset({i} * N + {j} * dst_cs);")?;
1073                    if i + 1 < self.mr_div_n || !self.need_mask {
1074                        write!(
1075                            f,
1076                            "{}(dst, {}(beta, acc[{i}][{j}]));\n",
1077                            self.store_unaligned, self.mul
1078                        )?;
1079                    } else {
1080                        write!(
1081                            f,
1082                            "{}(dst, last_mask, {}(beta, acc[{i}][{j}]));\n",
1083                            self.mask_store_unaligned, self.mul,
1084                        )?;
1085                    }
1086                    write!(f, "}}")?;
1087                }
1088            }
1089            write!(f, "}}")?;
1090            write!(f, "else {{")?;
1091            write!(f, "let beta = {}(beta);\n", self.set1)?;
1092            write!(f, "let alpha = {}(alpha);\n", self.set1)?;
1093            for j in 0..self.nr {
1094                for i in 0..self.mr_div_n {
1095                    write!(f, "{{")?;
1096                    write!(f, "let dst = dst.offset({i} * N + {j} * dst_cs);")?;
1097                    if i + 1 < self.mr_div_n || !self.need_mask {
1098                        write!(
1099                            f,
1100                            "{}(dst, {}(beta, acc[{i}][{j}], {}({}(dst), alpha)));\n",
1101                            self.store_unaligned, self.mul_add, self.mul, self.load_unaligned
1102                        )?;
1103                    } else {
1104                        write!(
1105                            f,
1106                            "{}(dst, last_mask, {}(beta, acc[{i}][{j}], {}({}, alpha)));\n",
1107                            self.mask_store_unaligned,
1108                            self.mul_add,
1109                            self.mul,
1110                            (self.mask_load_unaligned)(format!("dst"), "last_mask".to_string()),
1111                        )?;
1112                    }
1113                    write!(f, "}}")?;
1114                }
1115            }
1116            write!(f, "}}")?;
1117
1118            write!(f, "}}\n")
1119        }
1120    }
1121
1122    impl Display for CplxKernel {
1123        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1124            // function that multiplies (mr_div_n * n, k) by (k, n_r)
1125            // C += beta * A * B
1126
1127            // not exactly mr_div_n
1128            // the actual number of rows is between mr_div_n * (n-1) and mr_div_n * n
1129            let Self {
1130                swap_re_im,
1131                load_unaligned,
1132                store_unaligned,
1133                mask_store_unaligned,
1134                mul_addsub,
1135                xor,
1136                ..
1137            } = self;
1138
1139            write!(
1140                f,
1141                "
1142            #[target_feature(enable = \"{}\")]\n",
1143                self.target_features
1144            )?;
1145            write!(
1146                f,
1147                r#"pub unsafe fn matmul_{0:}_{1:}_{2:}(
1148                &nano_gemm_core::MicroKernelData {{ alpha, beta, k, dst_cs, lhs_cs, rhs_rs, rhs_cs, last_mask, conj_lhs, conj_rhs }}: &nano_gemm_core::MicroKernelData<num_complex::Complex< {3:} >>,
1149                dst: *mut num_complex::Complex< {3:} >,
1150                lhs: *const num_complex::Complex< {3:} >,
1151                rhs: *const num_complex::Complex< {3:} >,
1152            ) {{
1153"#,
1154                self.mr_div_n,
1155                self.nr,
1156                self.k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1157                self.ty,
1158            )?;
1159            write!(
1160                f,
1161                r#"
1162                #[cfg(target_arch = "x86_64")]
1163                use core::arch::x86_64::*;
1164                #[cfg(target_arch = "x86")]
1165                use core::arch::x86::*;
1166            "#
1167            )?;
1168
1169            write!(f, "_ = k;\n")?;
1170            write!(f, "type Reg = {};\n", self.reg_ty)?;
1171            write!(f, "const N: isize = {};\n", self.n)?;
1172            write!(
1173                f,
1174                "let mut acc: [[Reg; {}]; {}] = core::mem::zeroed();\n",
1175                self.nr, self.mr_div_n
1176            )?;
1177            write!(
1178                f,
1179                "let mut tmp_lhs: [Reg; {}] = core::mem::zeroed();\n",
1180                self.mr_div_n
1181            )?;
1182
1183            if self.need_mask {
1184                write!(
1185                    f,
1186                    "let last_mask = *(last_mask as *const {});\n",
1187                    self.mask_ty
1188                )?;
1189            } else {
1190                write!(f, "_ = last_mask;\n")?;
1191            }
1192
1193            for idx in 0..2 {
1194                if idx == 0 {
1195                    write!(f, "if conj_lhs == conj_rhs {{\n")?;
1196                } else {
1197                    write!(f, "else {{\n")?;
1198                }
1199
1200                let mul_add = if idx == 0 {
1201                    self.mul_subadd
1202                } else {
1203                    self.mul_addsub
1204                };
1205
1206                if let Some(k) = self.k {
1207                    for depth in 0..k {
1208                        write!(f, "let depth = {depth};\n")?;
1209                        for i in 0..self.mr_div_n {
1210                            self.write_load_lhs(i, f)?;
1211                        }
1212
1213                        for j in 0..self.nr {
1214                            write!(
1215                            f,
1216                            "let tmp_rhs = {}((*rhs.offset(depth * rhs_rs + {j} * rhs_cs)).re);\n",
1217                            self.set1
1218                        )?;
1219
1220                            for i in 0..self.mr_div_n {
1221                                write!(
1222                                f,
1223                                "acc[{i}][{j}] = {mul_add}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1224                            )?;
1225                            }
1226                        }
1227                        for i in 0..self.mr_div_n {
1228                            write!(f, "tmp_lhs[{i}] = {}(tmp_lhs[{i}]);", self.swap_re_im)?;
1229                        }
1230                        for j in 0..self.nr {
1231                            write!(
1232                            f,
1233                            "let tmp_rhs = {}((*rhs.offset(depth * rhs_rs + {j} * rhs_cs)).im);\n",
1234                            self.set1
1235                        )?;
1236
1237                            for i in 0..self.mr_div_n {
1238                                write!(
1239                                f,
1240                                "acc[{i}][{j}] = {mul_add}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1241                            )?;
1242                            }
1243                        }
1244                    }
1245                } else {
1246                    write!(f, "for depth in 0..k as isize {{")?;
1247                    for i in 0..self.mr_div_n {
1248                        self.write_load_lhs(i, f)?;
1249                    }
1250
1251                    for j in 0..self.nr {
1252                        write!(
1253                            f,
1254                            "let tmp_rhs = {}((*rhs.offset(depth * rhs_rs + {j} * rhs_cs)).re);\n",
1255                            self.set1
1256                        )?;
1257
1258                        for i in 0..self.mr_div_n {
1259                            write!(
1260                            f,
1261                            "acc[{i}][{j}] = {mul_add}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1262                        )?;
1263                        }
1264                    }
1265                    for i in 0..self.mr_div_n {
1266                        write!(f, "tmp_lhs[{i}] = {}(tmp_lhs[{i}]);", self.swap_re_im)?;
1267                    }
1268                    for j in 0..self.nr {
1269                        write!(
1270                            f,
1271                            "let tmp_rhs = {}((*rhs.offset(depth * rhs_rs + {j} * rhs_cs)).im);\n",
1272                            self.set1
1273                        )?;
1274
1275                        for i in 0..self.mr_div_n {
1276                            write!(
1277                            f,
1278                            "acc[{i}][{j}] = {mul_add}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1279                        )?;
1280                        }
1281                    }
1282
1283                    write!(f, "}}")?;
1284                }
1285                write!(f, "}}")?;
1286            }
1287
1288            write!(
1289                f,
1290                "let mask = XOR_MASKS[conj_lhs as usize + 2 * conj_rhs as usize];"
1291            )?;
1292            for j in 0..self.nr {
1293                for i in 0..self.mr_div_n {
1294                    write!(f, "acc[{i}][{j}] = core::mem::transmute({xor}(core::mem::transmute(acc[{i}][{j}]), core::mem::transmute(mask)));")?;
1295                }
1296            }
1297
1298            write!(
1299                f,
1300                "if alpha == (num_complex::Complex {{ re: 1.0, im: 0.0 }}) {{"
1301            )?;
1302            write!(f, "let beta_re = {}(beta.re);\n", self.set1)?;
1303            write!(f, "let beta_im = {}(beta.im);\n", self.set1)?;
1304            for j in 0..self.nr {
1305                for i in 0..self.mr_div_n {
1306                    write!(f, "{{")?;
1307                    write!(
1308                        f,
1309                        "let dst = dst.offset({i} * N + {j} * dst_cs) as *mut {};",
1310                        self.ty,
1311                    )?;
1312                    if i + 1 < self.mr_div_n || !self.need_mask {
1313                        write!(
1314                            f,
1315                            "{store_unaligned}(
1316                            dst,
1317                            {mul_addsub}(
1318                                {swap_re_im}(acc[{i}][{j}]),
1319                                beta_im,
1320                                {mul_addsub}(
1321                                    acc[{i}][{j}],
1322                                    beta_re,
1323                                    {load_unaligned}(dst),
1324                                ),
1325                            ),
1326                        );\n",
1327                        )?;
1328                    } else {
1329                        write!(
1330                            f,
1331                            "{mask_store_unaligned}(
1332                            dst,
1333                            last_mask,
1334                            {mul_addsub}(
1335                                {swap_re_im}(acc[{i}][{j}]),
1336                                beta_im,
1337                                {mul_addsub}(
1338                                    acc[{i}][{j}],
1339                                    beta_re,
1340                                    {},
1341                                ),
1342                            ),
1343                        );\n",
1344                            (self.mask_load_unaligned)(format!("dst"), "last_mask".to_string()),
1345                        )?;
1346                    }
1347                    write!(f, "}}")?;
1348                }
1349            }
1350            write!(f, "}}")?;
1351
1352            write!(
1353                f,
1354                "else if alpha == (num_complex::Complex {{ re: 0.0, im: 0.0 }}) {{"
1355            )?;
1356            write!(f, "let beta_re = {}(beta.re);\n", self.set1)?;
1357            write!(f, "let beta_im = {}(beta.im);\n", self.set1)?;
1358            for j in 0..self.nr {
1359                for i in 0..self.mr_div_n {
1360                    write!(f, "{{")?;
1361                    write!(
1362                        f,
1363                        "let dst = dst.offset({i} * N + {j} * dst_cs) as *mut {};",
1364                        self.ty,
1365                    )?;
1366                    if i + 1 < self.mr_div_n || !self.need_mask {
1367                        write!(
1368                            f,
1369                            "{store_unaligned}(
1370                            dst,
1371                            {mul_addsub}(
1372                                {swap_re_im}(acc[{i}][{j}]),
1373                                beta_im,
1374                                {mul_addsub}(
1375                                    acc[{i}][{j}],
1376                                    beta_re,
1377                                    core::mem::zeroed(),
1378                                ),
1379                            ),
1380                        );\n",
1381                        )?;
1382                    } else {
1383                        write!(
1384                            f,
1385                            "{mask_store_unaligned}(
1386                            dst,
1387                            last_mask,
1388                            {mul_addsub}(
1389                                {swap_re_im}(acc[{i}][{j}]),
1390                                beta_im,
1391                                {mul_addsub}(
1392                                    acc[{i}][{j}],
1393                                    beta_re,
1394                                    core::mem::zeroed(),
1395                                ),
1396                            ),
1397                        );\n"
1398                        )?;
1399                    }
1400                    write!(f, "}}")?;
1401                }
1402            }
1403            write!(f, "}}")?;
1404            write!(f, "else {{")?;
1405            write!(f, "let beta_re = {}(beta.re);\n", self.set1)?;
1406            write!(f, "let beta_im = {}(beta.im);\n", self.set1)?;
1407            write!(f, "let alpha_re = {}(alpha.re);\n", self.set1)?;
1408            write!(f, "let alpha_im = {}(alpha.im);\n", self.set1)?;
1409            for j in 0..self.nr {
1410                for i in 0..self.mr_div_n {
1411                    write!(f, "{{")?;
1412                    write!(
1413                        f,
1414                        "let dst = dst.offset({i} * N + {j} * dst_cs) as *mut {};",
1415                        self.ty,
1416                    )?;
1417                    if i + 1 < self.mr_div_n || !self.need_mask {
1418                        write!(
1419                            f,
1420                            "let dst_conj = core::mem::transmute({xor}(
1421                            core::mem::transmute({load_unaligned}(dst)),
1422                            core::mem::transmute(XOR_MASKS[1]),
1423                        ));"
1424                        )?;
1425
1426                        write!(
1427                            f,
1428                            "{store_unaligned}(
1429                            dst,
1430                            {mul_addsub}(
1431                                {swap_re_im}(acc[{i}][{j}]),
1432                                beta_im,
1433                                {mul_addsub}(
1434                                    acc[{i}][{j}],
1435                                    beta_re,
1436                                    {mul_addsub}(
1437                                        {swap_re_im}(dst_conj),
1438                                        alpha_im,
1439                                        {mul_addsub}(
1440                                            dst_conj,
1441                                            alpha_re,
1442                                            core::mem::zeroed(),
1443                                        ),
1444                                    ),
1445                                ),
1446                            ),
1447                        );\n",
1448                        )?;
1449                    } else {
1450                        write!(
1451                            f,
1452                            "let dst_conj = core::mem::transmute({xor}(
1453                            core::mem::transmute({}),
1454                            core::mem::transmute(XOR_MASKS[1]),
1455                        ));",
1456                            (self.mask_load_unaligned)(format!("dst"), "last_mask".to_string())
1457                        )?;
1458
1459                        write!(
1460                            f,
1461                            "{mask_store_unaligned}(
1462                            dst,
1463                            last_mask,
1464                            {mul_addsub}(
1465                                {swap_re_im}(acc[{i}][{j}]),
1466                                beta_im,
1467                                {mul_addsub}(
1468                                    acc[{i}][{j}],
1469                                    beta_re,
1470                                    {mul_addsub}(
1471                                        {swap_re_im}(dst_conj),
1472                                        alpha_im,
1473                                        {mul_addsub}(
1474                                            dst_conj,
1475                                            alpha_re,
1476                                            core::mem::zeroed(),
1477                                        ),
1478                                    ),
1479                                ),
1480                            ),
1481                        );\n"
1482                        )?;
1483                    }
1484                    write!(f, "}}")?;
1485                }
1486            }
1487            write!(f, "}}")?;
1488
1489            write!(f, "}}\n")
1490        }
1491    }
1492
1493    impl CplxKernel {
1494        fn write_load_lhs(
1495            &self,
1496            i: usize,
1497            f: &mut std::fmt::Formatter<'_>,
1498        ) -> Result<(), std::fmt::Error> {
1499            Ok(if i + 1 < self.mr_div_n || !self.need_mask {
1500                write!(
1501                    f,
1502                    "tmp_lhs[{i}] = {}(lhs.offset(depth * lhs_cs + {i} * N) as *const {});\n",
1503                    self.load_unaligned, self.ty,
1504                )?;
1505            } else {
1506                write!(
1507                    f,
1508                    "tmp_lhs[{i}] = {};\n",
1509                    (self.mask_load_unaligned)(
1510                        format!("lhs.offset(depth * lhs_cs + {i} * N) as *const {}", self.ty,),
1511                        "last_mask".to_string()
1512                    ),
1513                )?;
1514            })
1515        }
1516    }
1517    impl RealKernel {
1518        fn write_load_lhs(
1519            &self,
1520            i: usize,
1521            f: &mut std::fmt::Formatter<'_>,
1522        ) -> Result<(), std::fmt::Error> {
1523            Ok(if i + 1 < self.mr_div_n || !self.need_mask {
1524                write!(
1525                    f,
1526                    "tmp_lhs[{i}] = {}(lhs.offset(depth * lhs_cs + {i} * N));\n",
1527                    self.load_unaligned
1528                )?;
1529            } else {
1530                write!(
1531                    f,
1532                    "tmp_lhs[{i}] = {};\n",
1533                    (self.mask_load_unaligned)(
1534                        format!("lhs.offset(depth * lhs_cs + {i} * N)"),
1535                        "last_mask".to_string()
1536                    ),
1537                )?;
1538            })
1539        }
1540    }
1541
1542    pub fn codegen_f32() -> Result<String, Box<dyn std::error::Error>> {
1543        let mut code = String::new();
1544
1545        write!(code, "pub mod f32 {{\n")?;
1546        write!(code, "pub mod f32x1 {{\n")?;
1547        {
1548            write!(
1549                code,
1550                "pub const MR_DIV_N: usize = 1; pub const NR: usize = 4; pub const N: usize = 1;"
1551            )?;
1552
1553            for mr_div_n in 1..=1 {
1554                for nr in 1..=4 {
1555                    for k in (1..=16).into_iter().map(Some).chain([None]) {
1556                        let kernel = RealKernel {
1557                            ty: "f32",
1558                            reg_ty: "__m128",
1559                            mask_ty: "__m128i",
1560                            mr_div_n,
1561                            nr,
1562                            k,
1563                            target_features: "avx,avx2,fma",
1564                            n: 1,
1565                            set1: "crate::x86::splat_1s",
1566                            load_unaligned: "_mm_load_ss",
1567                            store_unaligned: "_mm_store_ss",
1568                            mask_load_unaligned: Box::new(|_, _| String::new()),
1569                            mask_store_unaligned: "",
1570                            mul_add: "_mm_fmadd_ss",
1571                            mul: "_mm_mul_ss",
1572                            need_mask: false,
1573                        };
1574
1575                        write!(code, "{kernel}")?;
1576                    }
1577                }
1578            }
1579
1580            write!(
1581                code,
1582                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 1]; 17] = [\n"
1583            )?;
1584            for k in (1..=16).into_iter().map(Some).chain([None]) {
1585                write!(code, "[\n")?;
1586                for mr_div_n in 1..=1 {
1587                    write!(code, "[\n")?;
1588                    for nr in 1..=4 {
1589                        write!(
1590                            code,
1591                            "matmul_{mr_div_n}_{nr}_{},",
1592                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1593                        )?;
1594                    }
1595                    write!(code, "],\n")?;
1596                }
1597                write!(code, "],\n")?;
1598            }
1599            write!(code, "];\n")?;
1600        }
1601        write!(code, "}}\n")?;
1602        write!(code, "pub mod f32x2 {{\n")?;
1603        {
1604            write!(
1605                code,
1606                "pub const MR_DIV_N: usize = 1; pub const NR: usize = 4; pub const N: usize = 2;"
1607            )?;
1608            for mr_div_n in 1..=1 {
1609                for nr in 1..=4 {
1610                    for k in (1..=16).into_iter().map(Some).chain([None]) {
1611                        let kernel = RealKernel {
1612                            need_mask: false,
1613                            ty: "f32",
1614                            reg_ty: "__m128",
1615                            mask_ty: "__m128i",
1616                            mr_div_n,
1617                            nr,
1618                            k,
1619                            target_features: "avx,avx2,fma",
1620                            n: 2,
1621                            set1: "_mm_set1_ps",
1622                            load_unaligned: "crate::x86::load_2s",
1623                            store_unaligned: "crate::x86::store_2s",
1624                            mask_load_unaligned: Box::new(|_, _| String::new()),
1625                            mask_store_unaligned: "",
1626                            mul_add: "_mm_fmadd_ps",
1627                            mul: "_mm_mul_ps",
1628                        };
1629
1630                        write!(code, "{kernel}")?;
1631                    }
1632                }
1633            }
1634
1635            write!(
1636                code,
1637                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 1]; 17] = [\n"
1638            )?;
1639            for k in (1..=16).into_iter().map(Some).chain([None]) {
1640                write!(code, "[\n")?;
1641                for mr_div_n in 1..=1 {
1642                    write!(code, "[\n")?;
1643                    for nr in 1..=4 {
1644                        write!(
1645                            code,
1646                            "matmul_{mr_div_n}_{nr}_{},",
1647                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1648                        )?;
1649                    }
1650                    write!(code, "],\n")?;
1651                }
1652                write!(code, "],\n")?;
1653            }
1654            write!(code, "];\n")?;
1655        }
1656        write!(code, "}}\n")?;
1657
1658        write!(code, "pub mod f32x4 {{\n")?;
1659        {
1660            write!(
1661                code,
1662                "pub const MR_DIV_N: usize = 1; pub const NR: usize = 4; pub const N: usize = 4;"
1663            )?;
1664            for mr_div_n in 1..=1 {
1665                for nr in 1..=4 {
1666                    for k in (1..=16).into_iter().map(Some).chain([None]) {
1667                        let kernel = RealKernel {
1668                            need_mask: true,
1669                            ty: "f32",
1670                            reg_ty: "__m128",
1671                            mask_ty: "__m128i",
1672                            mr_div_n,
1673                            nr,
1674                            k,
1675                            target_features: "avx,avx2,fma",
1676                            n: 4,
1677                            set1: "_mm_set1_ps",
1678                            load_unaligned: "_mm_loadu_ps",
1679                            store_unaligned: "_mm_storeu_ps",
1680                            mask_load_unaligned: Box::new(|ptr, mask| {
1681                                format!("_mm_maskload_ps({ptr}, {mask})")
1682                            }),
1683                            mask_store_unaligned: "_mm_maskstore_ps",
1684                            mul_add: "_mm_fmadd_ps",
1685                            mul: "_mm_mul_ps",
1686                        };
1687
1688                        write!(code, "{kernel}")?;
1689                    }
1690                }
1691            }
1692
1693            write!(
1694                code,
1695                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 1]; 17] = [\n"
1696            )?;
1697            for k in (1..=16).into_iter().map(Some).chain([None]) {
1698                write!(code, "[\n")?;
1699                for mr_div_n in 1..=1 {
1700                    write!(code, "[\n")?;
1701                    for nr in 1..=4 {
1702                        write!(
1703                            code,
1704                            "matmul_{mr_div_n}_{nr}_{},",
1705                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1706                        )?;
1707                    }
1708                    write!(code, "],\n")?;
1709                }
1710                write!(code, "],\n")?;
1711            }
1712            write!(code, "];\n")?;
1713            write!(
1714                code,
1715                "
1716            pub static MASKS: [crate::x86::__m128i; 4] = unsafe {{ core::mem::transmute([
1717                [u32::MAX, u32::MAX, u32::MAX, u32::MAX],
1718
1719                [u32::MAX, 0, 0, 0],
1720                [u32::MAX, u32::MAX, 0, 0],
1721                [u32::MAX, u32::MAX, u32::MAX, 0],
1722            ]) }};
1723        "
1724            )?;
1725        }
1726        write!(code, "}}\n")?;
1727        write!(code, "pub mod avx {{\n")?;
1728        {
1729            write!(
1730                code,
1731                "pub const MR_DIV_N: usize = 2; pub const NR: usize = 4; pub const N: usize = 8;"
1732            )?;
1733
1734            for mr_div_n in 1..=2 {
1735                for nr in 1..=4 {
1736                    for k in (1..=16).into_iter().map(Some).chain([None]) {
1737                        let kernel = RealKernel {
1738                            need_mask: true,
1739                            ty: "f32",
1740                            reg_ty: "__m256",
1741                            mask_ty: "__m256i",
1742                            mr_div_n,
1743                            nr,
1744                            k,
1745                            target_features: "avx,avx2,fma",
1746                            n: 8,
1747                            set1: "_mm256_set1_ps",
1748                            load_unaligned: "_mm256_loadu_ps",
1749                            store_unaligned: "_mm256_storeu_ps",
1750                            mask_load_unaligned: Box::new(|ptr, mask| {
1751                                format!("_mm256_maskload_ps({ptr}, {mask})")
1752                            }),
1753                            mask_store_unaligned: "_mm256_maskstore_ps",
1754                            mul_add: "_mm256_fmadd_ps",
1755                            mul: "_mm256_mul_ps",
1756                        };
1757
1758                        write!(code, "{kernel}")?;
1759                    }
1760                }
1761            }
1762
1763            write!(
1764                code,
1765                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 2]; 17] = [\n"
1766            )?;
1767            for k in (1..=16).into_iter().map(Some).chain([None]) {
1768                write!(code, "[\n")?;
1769                for mr_div_n in 1..=2 {
1770                    write!(code, "[\n")?;
1771                    for nr in 1..=4 {
1772                        write!(
1773                            code,
1774                            "matmul_{mr_div_n}_{nr}_{},",
1775                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1776                        )?;
1777                    }
1778                    write!(code, "],\n")?;
1779                }
1780                write!(code, "],\n")?;
1781            }
1782            write!(code, "];\n")?;
1783            write!(
1784                code,
1785                "
1786            pub static MASKS: [crate::x86::__m256i; 8] = unsafe {{ core::mem::transmute([
1787                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX],
1788
1789                [u32::MAX, 0, 0, 0, 0, 0, 0, 0],
1790                [u32::MAX, u32::MAX, 0, 0, 0, 0, 0, 0],
1791                [u32::MAX, u32::MAX, u32::MAX, 0, 0, 0, 0, 0],
1792                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0, 0, 0, 0],
1793                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0, 0, 0],
1794                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0, 0],
1795                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0],
1796            ]) }};
1797        "
1798            )?;
1799        }
1800        write!(code, "}}\n")?;
1801
1802        write!(code, "#[cfg(feature = \"x86-v4\")] pub mod avx512 {{\n")?;
1803        {
1804            write!(
1805                code,
1806                "pub const MR_DIV_N: usize = 2; pub const NR: usize = 4; pub const N: usize = 16;"
1807            )?;
1808
1809            for mr_div_n in 1..=2 {
1810                for nr in 1..=4 {
1811                    for k in (1..=16).into_iter().map(Some).chain([None]) {
1812                        let kernel = RealKernel {
1813                            need_mask: true,
1814                            ty: "f32",
1815                            reg_ty: "__m512",
1816                            mask_ty: "u16",
1817                            mr_div_n,
1818                            nr,
1819                            k,
1820                            target_features: "avx512f",
1821                            n: 16,
1822                            set1: "_mm512_set1_ps",
1823                            load_unaligned: "_mm512_loadu_ps",
1824                            store_unaligned: "_mm512_storeu_ps",
1825                            mask_load_unaligned: Box::new(|ptr, mask| {
1826                                format!("_mm512_maskz_loadu_ps({mask}, {ptr})")
1827                            }),
1828                            mask_store_unaligned: "_mm512_mask_storeu_ps",
1829                            mul_add: "_mm512_fmadd_ps",
1830                            mul: "_mm512_mul_ps",
1831                        };
1832
1833                        write!(code, "{kernel}")?;
1834                    }
1835                }
1836            }
1837
1838            write!(
1839                code,
1840                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 2]; 17] = [\n"
1841            )?;
1842            for k in (1..=16).into_iter().map(Some).chain([None]) {
1843                write!(code, "[\n")?;
1844                for mr_div_n in 1..=2 {
1845                    write!(code, "[\n")?;
1846                    for nr in 1..=4 {
1847                        write!(
1848                            code,
1849                            "matmul_{mr_div_n}_{nr}_{},",
1850                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1851                        )?;
1852                    }
1853                    write!(code, "],\n")?;
1854                }
1855                write!(code, "],\n")?;
1856            }
1857            write!(code, "];\n")?;
1858            write!(
1859                code,
1860                "
1861            pub static MASKS: [u16; 16] = [
1862                0b1111_1111_1111_1111,
1863                0b0000_0000_0000_0001,
1864                0b0000_0000_0000_0011,
1865                0b0000_0000_0000_0111,
1866                0b0000_0000_0000_1111,
1867                0b0000_0000_0001_1111,
1868                0b0000_0000_0011_1111,
1869                0b0000_0000_0111_1111,
1870                0b0000_0000_1111_1111,
1871                0b0000_0001_1111_1111,
1872                0b0000_0011_1111_1111,
1873                0b0000_0111_1111_1111,
1874                0b0000_1111_1111_1111,
1875                0b0001_1111_1111_1111,
1876                0b0011_1111_1111_1111,
1877                0b0111_1111_1111_1111,
1878            ];
1879        "
1880            )?;
1881        }
1882        write!(code, "}}\n")?;
1883        write!(code, "}}\n")?;
1884
1885        Ok(code)
1886    }
1887
1888    pub fn codegen_f64() -> Result<String, Box<dyn std::error::Error>> {
1889        let mut code = String::new();
1890
1891        write!(code, "pub mod f64 {{\n")?;
1892        write!(code, "pub mod f64x1 {{\n")?;
1893        {
1894            write!(
1895                code,
1896                "pub const MR_DIV_N: usize = 1; pub const NR: usize = 4; pub const N: usize = 1;"
1897            )?;
1898
1899            for mr_div_n in 1..=1 {
1900                for nr in 1..=4 {
1901                    for k in (1..=16).into_iter().map(Some).chain([None]) {
1902                        let kernel = RealKernel {
1903                            need_mask: false,
1904                            ty: "f64",
1905                            reg_ty: "__m128d",
1906                            mask_ty: "__m128i",
1907                            mr_div_n,
1908                            nr,
1909                            k,
1910                            target_features: "avx,avx2,fma",
1911                            n: 1,
1912                            set1: "crate::x86::splat_1d",
1913                            load_unaligned: "_mm_load_sd",
1914                            store_unaligned: "_mm_store_sd",
1915                            mask_load_unaligned: Box::new(|_, _| String::new()),
1916                            mask_store_unaligned: "",
1917                            mul_add: "_mm_fmadd_sd",
1918                            mul: "_mm_mul_sd",
1919                        };
1920
1921                        write!(code, "{kernel}")?;
1922                    }
1923                }
1924            }
1925
1926            write!(
1927                code,
1928                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f64>; 4]; 1]; 17] = [\n"
1929            )?;
1930            for k in (1..=16).into_iter().map(Some).chain([None]) {
1931                write!(code, "[\n")?;
1932                for mr_div_n in 1..=1 {
1933                    write!(code, "[\n")?;
1934                    for nr in 1..=4 {
1935                        write!(
1936                            code,
1937                            "matmul_{mr_div_n}_{nr}_{},",
1938                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1939                        )?;
1940                    }
1941                    write!(code, "],\n")?;
1942                }
1943                write!(code, "],\n")?;
1944            }
1945            write!(code, "];\n")?;
1946        }
1947        write!(code, "}}\n")?;
1948        write!(code, "pub mod f64x2 {{\n")?;
1949        {
1950            write!(
1951                code,
1952                "pub const MR_DIV_N: usize = 1; pub const NR: usize = 4; pub const N: usize = 2;"
1953            )?;
1954            for mr_div_n in 1..=1 {
1955                for nr in 1..=4 {
1956                    for k in (1..=16).into_iter().map(Some).chain([None]) {
1957                        let kernel = RealKernel {
1958                            need_mask: false,
1959                            ty: "f64",
1960                            reg_ty: "__m128d",
1961                            mask_ty: "__m128i",
1962                            mr_div_n,
1963                            nr,
1964                            k,
1965                            target_features: "avx,avx2,fma",
1966                            n: 2,
1967                            set1: "_mm_set1_pd",
1968                            load_unaligned: "_mm_loadu_pd",
1969                            store_unaligned: "_mm_storeu_pd",
1970                            mask_load_unaligned: Box::new(|_, _| String::new()),
1971                            mask_store_unaligned: "",
1972                            mul_add: "_mm_fmadd_pd",
1973                            mul: "_mm_mul_pd",
1974                        };
1975
1976                        write!(code, "{kernel}")?;
1977                    }
1978                }
1979            }
1980
1981            write!(
1982                code,
1983                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f64>; 4]; 1]; 17] = [\n"
1984            )?;
1985            for k in (1..=16).into_iter().map(Some).chain([None]) {
1986                write!(code, "[\n")?;
1987                for mr_div_n in 1..=1 {
1988                    write!(code, "[\n")?;
1989                    for nr in 1..=4 {
1990                        write!(
1991                            code,
1992                            "matmul_{mr_div_n}_{nr}_{},",
1993                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1994                        )?;
1995                    }
1996                    write!(code, "],\n")?;
1997                }
1998                write!(code, "],\n")?;
1999            }
2000            write!(code, "];\n")?;
2001        }
2002        write!(code, "}}\n")?;
2003
2004        write!(
2005            code,
2006            "
2007        pub mod avx {{\n"
2008        )?;
2009        {
2010            write!(
2011                code,
2012                "pub const MR_DIV_N: usize = 2; pub const NR: usize = 4; pub const N: usize = 4;"
2013            )?;
2014
2015            {
2016                for mr_div_n in 1..=2 {
2017                    for nr in 1..=4 {
2018                        for k in (1..=16).into_iter().map(Some).chain([None]) {
2019                            let kernel = RealKernel {
2020                                need_mask: true,
2021                                ty: "f64",
2022                                reg_ty: "__m256d",
2023                                mask_ty: "__m256i",
2024                                mr_div_n,
2025                                nr,
2026                                k,
2027                                target_features: "avx,avx2,fma",
2028                                n: 4,
2029                                set1: "_mm256_set1_pd",
2030                                load_unaligned: "_mm256_loadu_pd",
2031                                store_unaligned: "_mm256_storeu_pd",
2032                                mask_load_unaligned: Box::new(|ptr, mask| {
2033                                    format!("_mm256_maskload_pd({ptr}, {mask})")
2034                                }),
2035                                mask_store_unaligned: "_mm256_maskstore_pd",
2036                                mul_add: "_mm256_fmadd_pd",
2037                                mul: "_mm256_mul_pd",
2038                            };
2039
2040                            write!(code, "{kernel}")?;
2041                        }
2042                    }
2043                }
2044
2045                write!(
2046                    code,
2047                    "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f64>; 4]; 2]; 17] = [\n"
2048                )?;
2049                for k in (1..=16).into_iter().map(Some).chain([None]) {
2050                    write!(code, "[\n")?;
2051                    for mr_div_n in 1..=2 {
2052                        write!(code, "[\n")?;
2053                        for nr in 1..=4 {
2054                            write!(
2055                                code,
2056                                "matmul_{mr_div_n}_{nr}_{},",
2057                                k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2058                            )?;
2059                        }
2060                        write!(code, "],\n")?;
2061                    }
2062                    write!(code, "],\n")?;
2063                }
2064                write!(code, "];\n")?;
2065                write!(
2066                    code,
2067                    "
2068            pub static MASKS: [crate::x86::__m256i; 4] = unsafe {{ core::mem::transmute([
2069                [u64::MAX, u64::MAX, u64::MAX, u64::MAX],
2070
2071                [u64::MAX, 0, 0, 0],
2072                [u64::MAX, u64::MAX, 0, 0],
2073                [u64::MAX, u64::MAX, u64::MAX, 0],
2074            ]) }};
2075        "
2076                )?;
2077            }
2078        }
2079        write!(code, "}}\n")?;
2080
2081        write!(
2082            code,
2083            "
2084        #[cfg(feature = \"x86-v4\")]
2085        pub mod avx512 {{\n"
2086        )?;
2087        {
2088            write!(
2089                code,
2090                "pub const MR_DIV_N: usize = 2; pub const NR: usize = 4; pub const N: usize = 8;"
2091            )?;
2092
2093            for mr_div_n in 1..=2 {
2094                for nr in 1..=4 {
2095                    for k in (1..=16).into_iter().map(Some).chain([None]) {
2096                        let kernel = RealKernel {
2097                            need_mask: true,
2098                            ty: "f64",
2099                            reg_ty: "__m512d",
2100                            mask_ty: "u8",
2101                            mr_div_n,
2102                            nr,
2103                            k,
2104                            target_features: "avx512f",
2105                            n: 8,
2106                            set1: "_mm512_set1_pd",
2107                            load_unaligned: "_mm512_loadu_pd",
2108                            store_unaligned: "_mm512_storeu_pd",
2109                            mask_load_unaligned: Box::new(|ptr, mask| {
2110                                format!("_mm512_maskz_loadu_pd({mask}, {ptr})")
2111                            }),
2112                            mask_store_unaligned: "_mm512_mask_storeu_pd",
2113                            mul_add: "_mm512_fmadd_pd",
2114                            mul: "_mm512_mul_pd",
2115                        };
2116
2117                        write!(code, "{kernel}")?;
2118                    }
2119                }
2120            }
2121
2122            write!(
2123                code,
2124                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f64>; 4]; 2]; 17] = [\n"
2125            )?;
2126            for k in (1..=16).into_iter().map(Some).chain([None]) {
2127                write!(code, "[\n")?;
2128                for mr_div_n in 1..=2 {
2129                    write!(code, "[\n")?;
2130                    for nr in 1..=4 {
2131                        write!(
2132                            code,
2133                            "matmul_{mr_div_n}_{nr}_{},",
2134                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2135                        )?;
2136                    }
2137                    write!(code, "],\n")?;
2138                }
2139                write!(code, "],\n")?;
2140            }
2141            write!(code, "];\n")?;
2142            write!(
2143                code,
2144                "
2145            pub static MASKS: [u8; 8] = [
2146                0b1111_1111,
2147                0b0000_0001,
2148                0b0000_0011,
2149                0b0000_0111,
2150                0b0000_1111,
2151                0b0001_1111,
2152                0b0011_1111,
2153                0b0111_1111,
2154            ];
2155        "
2156            )?;
2157        }
2158        write!(code, "}}\n")?;
2159        write!(code, "}}\n")?;
2160
2161        Ok(code)
2162    }
2163
2164    pub fn codegen_c32() -> Result<String, Box<dyn std::error::Error>> {
2165        let mut code = String::new();
2166
2167        write!(code, "pub mod c32 {{\n")?;
2168        write!(code, "pub mod c32x1 {{\n")?;
2169        {
2170            write!(
2171                code,
2172                "pub const MR_DIV_N: usize = 1; pub const NR: usize = 2; pub const N: usize = 1;"
2173            )?;
2174            write!(
2175                code,
2176                "const XOR_MASKS: [crate::x86::__m128; 4] = unsafe {{ core::mem::transmute([
2177                   [-0.0, -0.0, -0.0, -0.0f32], // no conj
2178                   [ 0.0, -0.0,  0.0, -0.0f32], // conj lhs
2179                   [ 0.0,  0.0,  0.0,  0.0f32], // conj rhs
2180                   [-0.0,  0.0, -0.0,  0.0f32], // conj both
2181                ]) }};\n"
2182            )?;
2183
2184            for mr_div_n in 1..=1 {
2185                for nr in 1..=2 {
2186                    for k in (1..=16).into_iter().map(Some).chain([None]) {
2187                        let kernel = CplxKernel {
2188                            need_mask: false,
2189                            ty: "f32",
2190                            reg_ty: "__m128",
2191                            mask_ty: "__m128i",
2192                            mr_div_n,
2193                            nr,
2194                            k,
2195                            target_features: "avx,avx2,fma",
2196                            n: 1,
2197                            set1: "_mm_set1_ps",
2198                            load_unaligned: "crate::x86::load_2s",
2199                            store_unaligned: "crate::x86::store_2s",
2200                            mask_load_unaligned: Box::new(|ptr, mask| {
2201                                format!("_mm_maskload_ps({ptr}, {mask})")
2202                            }),
2203                            mask_store_unaligned: "_mm_maskstore_ps",
2204                            swap_re_im: "_mm_permute_ps::<0b10_11_00_01>",
2205                            mul_addsub: "_mm_fmsubadd_ps",
2206                            mul_subadd: "_mm_fmaddsub_ps",
2207                            xor: "_mm_xor_ps",
2208                        };
2209
2210                        write!(code, "{kernel}")?;
2211                    }
2212                }
2213            }
2214
2215            write!(
2216                code,
2217                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f32>>; 2]; 1]; 17] = [\n"
2218            )?;
2219            for k in (1..=16).into_iter().map(Some).chain([None]) {
2220                write!(code, "[\n")?;
2221                for mr_div_n in 1..=1 {
2222                    write!(code, "[\n")?;
2223                    for nr in 1..=2 {
2224                        write!(
2225                            code,
2226                            "matmul_{mr_div_n}_{nr}_{},",
2227                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2228                        )?;
2229                    }
2230                    write!(code, "],\n")?;
2231                }
2232                write!(code, "],\n")?;
2233            }
2234            write!(code, "];\n")?;
2235        }
2236        write!(code, "}}\n")?;
2237
2238        write!(code, "pub mod c32x2 {{\n")?;
2239        {
2240            write!(
2241                code,
2242                "pub const MR_DIV_N: usize = 1; pub const NR: usize = 2; pub const N: usize = 2;"
2243            )?;
2244            write!(
2245                code,
2246                "const XOR_MASKS: [crate::x86::__m128; 4] = unsafe {{ core::mem::transmute([
2247                   [-0.0, -0.0, -0.0, -0.0f32], // no conj
2248                   [ 0.0, -0.0,  0.0, -0.0f32], // conj lhs
2249                   [ 0.0,  0.0,  0.0,  0.0f32], // conj rhs
2250                   [-0.0,  0.0, -0.0,  0.0f32], // conj both
2251                ]) }};\n"
2252            )?;
2253
2254            for mr_div_n in 1..=1 {
2255                for nr in 1..=2 {
2256                    for k in (1..=16).into_iter().map(Some).chain([None]) {
2257                        let kernel = CplxKernel {
2258                            need_mask: false,
2259                            ty: "f32",
2260                            reg_ty: "__m128",
2261                            mask_ty: "__m128i",
2262                            mr_div_n,
2263                            nr,
2264                            k,
2265                            target_features: "avx,avx2,fma",
2266                            n: 2,
2267                            set1: "_mm_set1_ps",
2268                            load_unaligned: "_mm_loadu_ps",
2269                            store_unaligned: "_mm_storeu_ps",
2270                            mask_load_unaligned: Box::new(|ptr, mask| {
2271                                format!("_mm_maskload_ps({ptr}, {mask})")
2272                            }),
2273                            mask_store_unaligned: "_mm_maskstore_ps",
2274                            swap_re_im: "_mm_permute_ps::<0b10_11_00_01>",
2275                            mul_addsub: "_mm_fmsubadd_ps",
2276                            mul_subadd: "_mm_fmaddsub_ps",
2277                            xor: "_mm_xor_ps",
2278                        };
2279
2280                        write!(code, "{kernel}")?;
2281                    }
2282                }
2283            }
2284
2285            write!(
2286                code,
2287                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f32>>; 2]; 1]; 17] = [\n"
2288            )?;
2289            for k in (1..=16).into_iter().map(Some).chain([None]) {
2290                write!(code, "[\n")?;
2291                for mr_div_n in 1..=1 {
2292                    write!(code, "[\n")?;
2293                    for nr in 1..=2 {
2294                        write!(
2295                            code,
2296                            "matmul_{mr_div_n}_{nr}_{},",
2297                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2298                        )?;
2299                    }
2300                    write!(code, "],\n")?;
2301                }
2302                write!(code, "],\n")?;
2303            }
2304            write!(code, "];\n")?;
2305        }
2306        write!(code, "}}\n")?;
2307        write!(code, " pub mod avx {{\n")?;
2308        {
2309            write!(
2310                code,
2311                "pub const MR_DIV_N: usize = 2; pub const NR: usize = 2; pub const N: usize = 4;"
2312            )?;
2313            write!(
2314                code,
2315                "const XOR_MASKS: [crate::x86::__m256; 4] = unsafe {{ core::mem::transmute([
2316                   [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0f32], // no conj
2317                   [ 0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0f32], // conj lhs
2318                   [ 0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0f32], // conj rhs
2319                   [-0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0f32], // conj both
2320                ]) }};\n"
2321            )?;
2322
2323            for mr_div_n in 1..=2 {
2324                for nr in 1..=2 {
2325                    for k in (1..=16).into_iter().map(Some).chain([None]) {
2326                        let kernel = CplxKernel {
2327                            need_mask: true,
2328                            ty: "f32",
2329                            reg_ty: "__m256",
2330                            mask_ty: "__m256i",
2331                            mr_div_n,
2332                            nr,
2333                            k,
2334                            target_features: "avx,avx2,fma",
2335                            n: 4,
2336                            set1: "_mm256_set1_ps",
2337                            load_unaligned: "_mm256_loadu_ps",
2338                            store_unaligned: "_mm256_storeu_ps",
2339                            mask_load_unaligned: Box::new(|ptr, mask| {
2340                                format!("_mm256_maskload_ps({ptr}, {mask})")
2341                            }),
2342                            mask_store_unaligned: "_mm256_maskstore_ps",
2343                            swap_re_im: "_mm256_permute_ps::<0b10_11_00_01>",
2344                            mul_addsub: "_mm256_fmsubadd_ps",
2345                            mul_subadd: "_mm256_fmaddsub_ps",
2346                            xor: "_mm256_xor_ps",
2347                        };
2348
2349                        write!(code, "{kernel}")?;
2350                    }
2351                }
2352            }
2353
2354            write!(
2355                code,
2356                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f32>>; 2]; 2]; 17] = [\n"
2357            )?;
2358            for k in (1..=16).into_iter().map(Some).chain([None]) {
2359                write!(code, "[\n")?;
2360                for mr_div_n in 1..=2 {
2361                    write!(code, "[\n")?;
2362                    for nr in 1..=2 {
2363                        write!(
2364                            code,
2365                            "matmul_{mr_div_n}_{nr}_{},",
2366                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2367                        )?;
2368                    }
2369                    write!(code, "],\n")?;
2370                }
2371                write!(code, "],\n")?;
2372            }
2373            write!(code, "];\n")?;
2374            write!(
2375                code,
2376                "
2377            pub static MASKS: [crate::x86::__m256i; 4] = unsafe {{ core::mem::transmute([
2378                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX],
2379
2380                [u32::MAX, u32::MAX, 0, 0, 0, 0, 0, 0],
2381                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0, 0, 0, 0],
2382                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0, 0],
2383            ]) }};
2384        "
2385            )?;
2386        }
2387        write!(code, "}}\n")?;
2388        write!(
2389            code,
2390            "
2391        #[cfg(feature = \"x86-v4\")]
2392        pub mod avx512 {{\n"
2393        )?;
2394
2395        {
2396            write!(
2397                code,
2398                "pub const MR_DIV_N: usize = 2; pub const NR: usize = 2; pub const N: usize = 8;"
2399            )?;
2400            write!(
2401            code,
2402            "const XOR_MASKS: [crate::x86::__m512; 4] = unsafe {{ core::mem::transmute([
2403                   [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0f32], // no conj
2404                   [ 0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0f32], // conj lhs
2405                   [ 0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0f32], // conj rhs
2406                   [-0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0f32], // conj both
2407                ]) }};\n"
2408            )?;
2409
2410            for mr_div_n in 1..=2 {
2411                for nr in 1..=2 {
2412                    for k in (1..=16).into_iter().map(Some).chain([None]) {
2413                        let kernel = CplxKernel {
2414                            need_mask: true,
2415                            ty: "f32",
2416                            reg_ty: "__m512",
2417                            mask_ty: "u16",
2418                            mr_div_n,
2419                            nr,
2420                            k,
2421                            target_features: "avx512f",
2422                            n: 8,
2423                            set1: "_mm512_set1_ps",
2424                            load_unaligned: "_mm512_loadu_ps",
2425                            store_unaligned: "_mm512_storeu_ps",
2426                            mask_load_unaligned: Box::new(|ptr, mask| {
2427                                format!("_mm512_maskz_loadu_ps({mask}, {ptr})")
2428                            }),
2429                            mask_store_unaligned: "_mm512_mask_storeu_ps",
2430                            swap_re_im: "_mm512_permute_ps::<0b10_11_00_01>",
2431                            mul_addsub: "crate::x86::subadd_ps",
2432                            mul_subadd: "_mm512_fmaddsub_ps",
2433                            xor: "_mm512_xor_si512",
2434                        };
2435
2436                        write!(code, "{kernel}")?;
2437                    }
2438                }
2439            }
2440
2441            write!(
2442                code,
2443                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f32>>; 2]; 2]; 17] = [\n"
2444            )?;
2445            for k in (1..=16).into_iter().map(Some).chain([None]) {
2446                write!(code, "[\n")?;
2447                for mr_div_n in 1..=2 {
2448                    write!(code, "[\n")?;
2449                    for nr in 1..=2 {
2450                        write!(
2451                            code,
2452                            "matmul_{mr_div_n}_{nr}_{},",
2453                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2454                        )?;
2455                    }
2456                    write!(code, "],\n")?;
2457                }
2458                write!(code, "],\n")?;
2459            }
2460            write!(code, "];\n")?;
2461            write!(
2462                code,
2463                "
2464            pub static MASKS: [u16; 8] = [
2465                0b1111_1111_1111_1111,
2466                0b0000_0000_0000_0011,
2467                0b0000_0000_0000_1111,
2468                0b0000_0000_0011_1111,
2469                0b0000_0000_1111_1111,
2470                0b0000_0011_1111_1111,
2471                0b0000_1111_1111_1111,
2472                0b0011_1111_1111_1111,
2473            ];
2474        "
2475            )?;
2476        }
2477        write!(code, "}}\n")?;
2478        write!(code, "}}\n")?;
2479
2480        Ok(code)
2481    }
2482
2483    pub fn codegen_c64() -> Result<String, Box<dyn std::error::Error>> {
2484        let mut code = String::new();
2485
2486        write!(code, "pub mod c64 {{\n")?;
2487        write!(code, "pub mod c64x1 {{\n")?;
2488        {
2489            write!(
2490                code,
2491                "pub const MR_DIV_N: usize = 1; pub const NR: usize = 2; pub const N: usize = 1;"
2492            )?;
2493            write!(
2494                code,
2495                "const XOR_MASKS: [crate::x86::__m128d; 4] = unsafe {{ core::mem::transmute([
2496                   [-0.0, -0.0f64], // no conj
2497                   [ 0.0, -0.0f64], // conj lhs
2498                   [ 0.0,  0.0f64], // conj rhs
2499                   [-0.0,  0.0f64], // conj both
2500                ]) }};\n"
2501            )?;
2502
2503            for mr_div_n in 1..=1 {
2504                for nr in 1..=2 {
2505                    for k in (1..=16).into_iter().map(Some).chain([None]) {
2506                        let kernel = CplxKernel {
2507                            need_mask: false,
2508                            ty: "f64",
2509                            reg_ty: "__m128d",
2510                            mask_ty: "__m128i",
2511                            mr_div_n,
2512                            nr,
2513                            k,
2514                            target_features: "avx,avx2,fma",
2515                            n: 1,
2516                            set1: "_mm_set1_pd",
2517                            load_unaligned: "_mm_loadu_pd",
2518                            store_unaligned: "_mm_storeu_pd",
2519                            mask_load_unaligned: Box::new(|ptr, mask| {
2520                                format!("_mm_maskload_pd({ptr}, {mask})")
2521                            }),
2522                            mask_store_unaligned: "_mm_maskstore_pd",
2523                            swap_re_im: "_mm_permute_pd::<0b01>",
2524                            mul_addsub: "_mm_fmsubadd_pd",
2525                            mul_subadd: "_mm_fmaddsub_pd",
2526                            xor: "_mm_xor_pd",
2527                        };
2528
2529                        write!(code, "{kernel}")?;
2530                    }
2531                }
2532            }
2533
2534            write!(
2535                code,
2536                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f64>>; 2]; 1]; 17] = [\n"
2537            )?;
2538            for k in (1..=16).into_iter().map(Some).chain([None]) {
2539                write!(code, "[\n")?;
2540                for mr_div_n in 1..=1 {
2541                    write!(code, "[\n")?;
2542                    for nr in 1..=2 {
2543                        write!(
2544                            code,
2545                            "matmul_{mr_div_n}_{nr}_{},",
2546                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2547                        )?;
2548                    }
2549                    write!(code, "],\n")?;
2550                }
2551                write!(code, "],\n")?;
2552            }
2553            write!(code, "];\n")?;
2554        }
2555        write!(code, "}}\n")?;
2556
2557        write!(code, "pub mod avx {{\n")?;
2558        {
2559            write!(
2560                code,
2561                "pub const MR_DIV_N: usize = 2; pub const NR: usize = 2; pub const N: usize = 2;"
2562            )?;
2563            write!(
2564                code,
2565                "const XOR_MASKS: [crate::x86::__m256d; 4] = unsafe {{ core::mem::transmute([
2566                   [-0.0, -0.0, -0.0, -0.0f64], // no conj
2567                   [ 0.0, -0.0,  0.0, -0.0f64], // conj lhs
2568                   [ 0.0,  0.0,  0.0,  0.0f64], // conj rhs
2569                   [-0.0,  0.0, -0.0,  0.0f64], // conj both
2570                ]) }};\n"
2571            )?;
2572
2573            for mr_div_n in 1..=2 {
2574                for nr in 1..=2 {
2575                    for k in (1..=16).into_iter().map(Some).chain([None]) {
2576                        let kernel = CplxKernel {
2577                            need_mask: true,
2578                            ty: "f64",
2579                            reg_ty: "__m256d",
2580                            mask_ty: "__m256i",
2581                            mr_div_n,
2582                            nr,
2583                            k,
2584                            target_features: "avx,avx2,fma",
2585                            n: 2,
2586                            set1: "_mm256_set1_pd",
2587                            load_unaligned: "_mm256_loadu_pd",
2588                            store_unaligned: "_mm256_storeu_pd",
2589                            mask_load_unaligned: Box::new(|ptr, mask| {
2590                                format!("_mm256_maskload_pd({ptr}, {mask})")
2591                            }),
2592                            mask_store_unaligned: "_mm256_maskstore_pd",
2593                            swap_re_im: "_mm256_permute_pd::<0b0101>",
2594                            mul_addsub: "_mm256_fmsubadd_pd",
2595                            mul_subadd: "_mm256_fmaddsub_pd",
2596                            xor: "_mm256_xor_pd",
2597                        };
2598
2599                        write!(code, "{kernel}")?;
2600                    }
2601                }
2602            }
2603
2604            write!(
2605                code,
2606                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f64>>; 2]; 2]; 17] = [\n"
2607            )?;
2608            for k in (1..=16).into_iter().map(Some).chain([None]) {
2609                write!(code, "[\n")?;
2610                for mr_div_n in 1..=2 {
2611                    write!(code, "[\n")?;
2612                    for nr in 1..=2 {
2613                        write!(
2614                            code,
2615                            "matmul_{mr_div_n}_{nr}_{},",
2616                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2617                        )?;
2618                    }
2619                    write!(code, "],\n")?;
2620                }
2621                write!(code, "],\n")?;
2622            }
2623            write!(code, "];\n")?;
2624            write!(
2625                code,
2626                "
2627            pub static MASKS: [crate::x86::__m256i; 2] = unsafe {{ core::mem::transmute([
2628                [u64::MAX, u64::MAX, u64::MAX, u64::MAX],
2629
2630                [u64::MAX, u64::MAX, 0, 0],
2631            ]) }};
2632        "
2633            )?;
2634        }
2635        write!(code, "}}\n")?;
2636        write!(
2637            code,
2638            "
2639        #[cfg(feature = \"x86-v4\")]
2640        pub mod avx512 {{\n"
2641        )?;
2642
2643        {
2644            write!(
2645                code,
2646                "pub const MR_DIV_N: usize = 2; pub const NR: usize = 2; pub const N: usize = 4;"
2647            )?;
2648            write!(
2649                code,
2650                "const XOR_MASKS: [crate::x86::__m512; 4] = unsafe {{ core::mem::transmute([
2651                   [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0f64], // no conj
2652                   [ 0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0f64], // conj lhs
2653                   [ 0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0f64], // conj rhs
2654                   [-0.0,  0.0, -0.0,  0.0, -0.0,  0.0, -0.0,  0.0f64], // conj both
2655                ]) }};\n"
2656            )?;
2657
2658            for mr_div_n in 1..=2 {
2659                for nr in 1..=2 {
2660                    for k in (1..=16).into_iter().map(Some).chain([None]) {
2661                        let kernel = CplxKernel {
2662                            need_mask: true,
2663                            ty: "f64",
2664                            reg_ty: "__m512d",
2665                            mask_ty: "u8",
2666                            mr_div_n,
2667                            nr,
2668                            k,
2669                            target_features: "avx512f",
2670                            n: 4,
2671                            set1: "_mm512_set1_pd",
2672                            load_unaligned: "_mm512_loadu_pd",
2673                            store_unaligned: "_mm512_storeu_pd",
2674                            mask_load_unaligned: Box::new(|ptr, mask| {
2675                                format!("_mm512_maskz_loadu_pd({mask}, {ptr})")
2676                            }),
2677                            mask_store_unaligned: "_mm512_mask_storeu_pd",
2678                            swap_re_im: "_mm512_permute_pd::<0b01010101>",
2679                            mul_addsub: "crate::x86::subadd_pd",
2680                            mul_subadd: "_mm512_fmaddsub_pd",
2681                            xor: "_mm512_xor_si512",
2682                        };
2683
2684                        write!(code, "{kernel}")?;
2685                    }
2686                }
2687            }
2688
2689            write!(
2690                code,
2691                "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f64>>; 2]; 2]; 17] = [\n"
2692            )?;
2693            for k in (1..=16).into_iter().map(Some).chain([None]) {
2694                write!(code, "[\n")?;
2695                for mr_div_n in 1..=2 {
2696                    write!(code, "[\n")?;
2697                    for nr in 1..=2 {
2698                        write!(
2699                            code,
2700                            "matmul_{mr_div_n}_{nr}_{},",
2701                            k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2702                        )?;
2703                    }
2704                    write!(code, "],\n")?;
2705                }
2706                write!(code, "],\n")?;
2707            }
2708            write!(code, "];\n")?;
2709            write!(
2710                code,
2711                "
2712            pub static MASKS: [u8; 4] = [
2713                0b1111_1111,
2714                0b0000_0011,
2715                0b0000_1111,
2716                0b0011_1111,
2717            ];
2718        "
2719            )?;
2720        }
2721        write!(code, "}}\n")?;
2722        write!(code, "}}\n")?;
2723
2724        Ok(code)
2725    }
2726}