1use std::fmt::Display;
7use std::fmt::Write;
8
mod generic {
    use super::*;

    /// Description of one real-valued micro-kernel; its `Display` impl emits
    /// the Rust source of a `matmul_{mr}_{nr}_{k|dyn}` function.
    pub struct RealKernel {
        /// Scalar element type name used in the emitted signature.
        pub ty: &'static str,
        /// Type name given to every emitted accumulator variable.
        pub reg_ty: &'static str,
        /// Caps the row-chunk width at `2^floor(log2(n))`
        /// (presumably the number of `ty` lanes in `reg_ty` — confirm with callers).
        pub n: usize,
        /// Number of rows of the tile.
        pub mr: usize,
        /// Number of columns of the tile.
        pub nr: usize,
        /// `Some(k)` unrolls exactly `k` depth steps; `None` emits a runtime
        /// `for depth in 0..k` loop.
        pub k: Option<usize>,

        /// Value placed inside `#[target_feature(enable = "...")]`.
        pub target_features: &'static str,
        /// Load function names, indexed by `log2` of the row-chunk width.
        pub load_unaligned: [&'static str; 3],
        /// Store function names, indexed by `log2` of the row-chunk width.
        pub store_unaligned: [&'static str; 3],
        /// Function name applied to scalars (`alpha`, `beta`, rhs elements).
        pub set1: &'static str,
        /// Three-argument function name, emitted as `acc = mul_add(lhs, rhs, acc)`.
        pub mul_add: &'static str,
    }

    /// Description of one complex-valued micro-kernel; same layout as
    /// [`RealKernel`] plus the conjugation helpers.
    pub struct CplxKernel {
        /// Scalar element type name used in the emitted signature and in the
        /// `alpha == (ty { re, im })` comparisons.
        pub ty: &'static str,
        /// Type name given to every emitted accumulator variable.
        pub reg_ty: &'static str,
        /// Caps the row-chunk width at `2^floor(log2(n))`
        /// (presumably the number of `ty` values in `reg_ty` — confirm with callers).
        pub n: usize,
        /// Number of rows of the tile.
        pub mr: usize,
        /// Number of columns of the tile.
        pub nr: usize,
        /// `Some(k)` unrolls exactly `k` depth steps; `None` emits a runtime loop.
        pub k: Option<usize>,

        /// Value placed inside `#[target_feature(enable = "...")]`.
        pub target_features: &'static str,
        /// Load function names, indexed by `log2` of the row-chunk width.
        pub load_unaligned: [&'static str; 3],
        /// Store function names, indexed by `log2` of the row-chunk width.
        pub store_unaligned: [&'static str; 3],
        /// Function name applied to scalars (`alpha`, `beta`, rhs elements).
        pub set1: &'static str,
        /// Accumulate function used when `conj_lhs == conj_rhs`, and in the
        /// write-back step.
        pub mul_add: &'static str,
        /// Accumulate function used when `conj_lhs != conj_rhs`.
        pub conj_mul_add: &'static str,
        /// Function applied to every accumulator when `conj_rhs` is set.
        pub conj: &'static str,
    }

    /// The parameters both generators share, together with the emitters that
    /// used to be duplicated between them.
    struct Tile {
        ty: &'static str,
        reg_ty: &'static str,
        n: usize,
        mr: usize,
        nr: usize,
        k: Option<usize>,
        target_features: &'static str,
        load_unaligned: [&'static str; 3],
        store_unaligned: [&'static str; 3],
        set1: &'static str,
        mul_add: &'static str,
    }

    impl Tile {
        /// Walks the `mr` rows in power-of-two chunks, yielding
        /// `(first_row, log2(chunk_width))`. Each chunk is the largest power
        /// of two that fits both the remaining rows and `n`.
        fn chunks(&self) -> impl Iterator<Item = (usize, usize)> {
            let (mr, n) = (self.mr, self.n);
            let mut row = 0;
            std::iter::from_fn(move || {
                if row >= mr {
                    return None;
                }
                let log2_width = Ord::min((mr - row).ilog2() as usize, n.ilog2() as usize);
                let chunk = (row, log2_width);
                row += 1 << log2_width;
                Some(chunk)
            })
        }

        /// Emits the attribute, the function signature and the zeroed
        /// accumulators. `extra_fields` is spliced into the destructuring
        /// pattern of `MicroKernelData` just before the trailing `..`.
        fn write_prologue(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            extra_fields: &str,
        ) -> std::fmt::Result {
            let Self {
                ty,
                reg_ty,
                mr,
                nr,
                target_features,
                ..
            } = *self;
            // Built lazily: the "dyn" string is only allocated when `k` is `None`.
            let suffix = self
                .k
                .map_or_else(|| "dyn".to_string(), |depth| depth.to_string());

            writeln!(f, "#[target_feature(enable = \"{target_features}\")]")?;
            write!(
                f,
                r#"pub unsafe fn matmul_{mr}_{nr}_{suffix}(
    &nano_gemm_core::MicroKernelData {{ alpha, beta, k, dst_cs, lhs_cs, rhs_rs, rhs_cs, {extra_fields}.. }}: &nano_gemm_core::MicroKernelData< {ty} >,
    dst: *mut {ty},
    lhs: *const {ty},
    rhs: *const {ty},
) {{
"#,
            )?;

            writeln!(f, "_ = k;")?;
            for (row, _) in self.chunks() {
                for col in 0..nr {
                    writeln!(f, "let mut acc_{row}_{col}: {reg_ty} = core::mem::zeroed();")?;
                }
            }
            Ok(())
        }

        /// Emits one depth step: load every lhs chunk, then for each column
        /// broadcast the rhs element and accumulate with `fma`.
        fn write_inner_kernel(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            fma: &str,
        ) -> std::fmt::Result {
            let Self {
                nr,
                set1,
                load_unaligned,
                ..
            } = *self;

            for (row, log2_width) in self.chunks() {
                write!(
                    f,
                    "let tmp_lhs_{row} = {}(lhs.offset(depth * lhs_cs + {row}));",
                    load_unaligned[log2_width],
                )?;
            }
            for col in 0..nr {
                writeln!(
                    f,
                    "let tmp_rhs = {set1}(*rhs.offset(depth * rhs_rs + {col} * rhs_cs));",
                )?;
                for (row, _) in self.chunks() {
                    writeln!(
                        f,
                        "acc_{row}_{col} = {fma}(tmp_lhs_{row}, tmp_rhs, acc_{row}_{col});",
                    )?;
                }
            }
            Ok(())
        }

        /// Emits the depth loop: fully unrolled when `k` is known, otherwise a
        /// runtime `for` loop over the `k` read from `MicroKernelData`.
        fn write_depth_loop(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            fma: &str,
        ) -> std::fmt::Result {
            match self.k {
                Some(depth_count) => {
                    for depth in 0..depth_count {
                        writeln!(f, "let depth = {depth};")?;
                        self.write_inner_kernel(f, fma)?;
                    }
                }
                None => {
                    write!(f, "for depth in 0..k as isize {{")?;
                    self.write_inner_kernel(f, fma)?;
                    write!(f, "}}")?;
                }
            }
            Ok(())
        }

        /// Emits one scoped store per (column, row chunk). `addend` receives
        /// the load function name for the chunk and returns the expression
        /// added to `beta * acc`.
        fn write_stores(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            addend: impl Fn(&str) -> String,
        ) -> std::fmt::Result {
            let Self {
                nr,
                mul_add,
                load_unaligned,
                store_unaligned,
                ..
            } = *self;

            for col in 0..nr {
                for (row, log2_width) in self.chunks() {
                    let store = store_unaligned[log2_width];
                    let addend = addend(load_unaligned[log2_width]);
                    write!(f, "{{")?;
                    write!(f, "let dst = dst.offset({row} + {col} * dst_cs);")?;
                    writeln!(
                        f,
                        "{store}(dst, {mul_add}(beta, acc_{row}_{col}, {addend}));",
                    )?;
                    write!(f, "}}")?;
                }
            }
            Ok(())
        }

        /// Emits the write-back, specialised on `alpha`:
        /// `alpha == one` adds the existing `dst`, `alpha == zero` never reads
        /// `dst`, and the general case scales `dst` by `alpha` first.
        /// `one` / `zero` are the literal expressions `alpha` is compared to.
        fn write_epilogue(
            &self,
            f: &mut std::fmt::Formatter<'_>,
            one: &str,
            zero: &str,
        ) -> std::fmt::Result {
            let Self { set1, mul_add, .. } = *self;

            write!(f, "if alpha == {one} {{")?;
            writeln!(f, "let beta = {set1}(beta);")?;
            self.write_stores(f, |load| format!("{load}(dst)"))?;
            write!(f, "}}")?;

            write!(f, "else if alpha == {zero} {{")?;
            writeln!(f, "let beta = {set1}(beta);")?;
            self.write_stores(f, |_| "core::mem::zeroed()".to_string())?;
            write!(f, "}}")?;

            write!(f, "else {{")?;
            writeln!(f, "let beta = {set1}(beta);")?;
            writeln!(f, "let alpha = {set1}(alpha);")?;
            self.write_stores(f, |load| {
                format!("{mul_add}(alpha, {load}(dst), core::mem::zeroed())")
            })?;
            write!(f, "}}")
        }
    }

    impl RealKernel {
        /// Copies the fields the shared emitters need.
        fn tile(&self) -> Tile {
            Tile {
                ty: self.ty,
                reg_ty: self.reg_ty,
                n: self.n,
                mr: self.mr,
                nr: self.nr,
                k: self.k,
                target_features: self.target_features,
                load_unaligned: self.load_unaligned,
                store_unaligned: self.store_unaligned,
                set1: self.set1,
                mul_add: self.mul_add,
            }
        }
    }

    impl CplxKernel {
        /// Copies the fields the shared emitters need.
        fn tile(&self) -> Tile {
            Tile {
                ty: self.ty,
                reg_ty: self.reg_ty,
                n: self.n,
                mr: self.mr,
                nr: self.nr,
                k: self.k,
                target_features: self.target_features,
                load_unaligned: self.load_unaligned,
                store_unaligned: self.store_unaligned,
                set1: self.set1,
                mul_add: self.mul_add,
            }
        }
    }

    impl Display for RealKernel {
        /// Writes the complete source of the real micro-kernel function:
        /// prologue, depth loop, alpha-specialised write-back, closing brace.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let tile = self.tile();

            tile.write_prologue(f, "")?;
            tile.write_depth_loop(f, self.mul_add)?;
            tile.write_epilogue(f, "1.0", "0.0")?;

            // Closes the emitted function body.
            write!(f, "}}")
        }
    }

    impl Display for CplxKernel {
        /// Writes the complete source of the complex micro-kernel function.
        ///
        /// The emitted code accumulates with `mul_add` when both conjugation
        /// flags agree and with `conj_mul_add` otherwise, then applies `conj`
        /// to every accumulator when `conj_rhs` is set.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let Self { ty, conj, .. } = self;
            let tile = self.tile();

            tile.write_prologue(f, "conj_lhs, conj_rhs, ")?;

            write!(f, "if conj_lhs == conj_rhs {{")?;
            tile.write_depth_loop(f, self.mul_add)?;
            write!(f, "}} else {{")?;
            tile.write_depth_loop(f, self.conj_mul_add)?;
            write!(f, "}}")?;

            write!(f, "if conj_rhs {{")?;
            for col in 0..self.nr {
                for (row, _) in tile.chunks() {
                    write!(f, "acc_{row}_{col} = {conj}(acc_{row}_{col});")?;
                }
            }
            write!(f, "}}")?;

            let one = format!("({ty} {{ re: 1.0, im: 0.0 }})");
            let zero = format!("({ty} {{ re: 0.0, im: 0.0 }})");
            tile.write_epilogue(f, &one, &zero)?;

            // Closes the emitted function body.
            write!(f, "}}")
        }
    }
}
437
438pub mod aarch64 {
439 use super::*;
440 use generic::{CplxKernel, RealKernel};
441
442 pub fn codegen_f32() -> Result<String, Box<dyn std::error::Error>> {
443 let mut code = String::new();
444
445 write!(code, "pub mod f32 {{\n")?;
446 write!(code, "pub mod neon {{\n")?;
447 write!(
448 code,
449 r###"
450 use core::arch::aarch64::*;
451 use core::mem::transmute;
452 use core::mem::transmute_copy;
453
454 #[inline(always)]
455 unsafe fn set1(v: f32) -> float32x4_t {{
456 transmute([v; 4])
457 }}
458 #[inline(always)]
459 unsafe fn mul_add(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {{
460 vmlaq_f32(c, a, b)
461 }}
462 #[inline(always)]
463 unsafe fn load_1(ptr: *const f32) -> float32x4_t {{
464 transmute([*ptr; 4])
465 }}
466 #[inline(always)]
467 unsafe fn load_2(ptr: *const f32) -> float32x4_t {{
468 transmute([*(ptr as *const [f32; 2]); 2])
469 }}
470 #[inline(always)]
471 unsafe fn load_4(ptr: *const f32) -> float32x4_t {{
472 transmute(*(ptr as *const [f32; 4]))
473 }}
474
475 #[inline(always)]
476 unsafe fn store_1(ptr: *mut f32, v: float32x4_t) {{
477 *(ptr as *mut [f32; 1]) = transmute_copy(&v);
478 }}
479 #[inline(always)]
480 unsafe fn store_2(ptr: *mut f32, v: float32x4_t) {{
481 *(ptr as *mut [f32; 2]) = transmute_copy(&v);
482 }}
483 #[inline(always)]
484 unsafe fn store_4(ptr: *mut f32, v: float32x4_t) {{
485 *(ptr as *mut [f32; 4]) = transmute_copy(&v);
486 }}
487 "###
488 )?;
489 for mr in 1..=8 {
490 for nr in 1..=4 {
491 for k in (1..=16).into_iter().map(Some).chain([None]) {
492 let kernel = RealKernel {
493 ty: "f32",
494 reg_ty: "float32x4_t",
495 n: 4,
496 mr,
497 nr,
498 k,
499 target_features: "neon",
500 load_unaligned: ["load_1", "load_2", "load_4"],
501 store_unaligned: ["store_1", "store_2", "store_4"],
502 set1: "set1",
503 mul_add: "mul_add",
504 };
505 write!(code, "{kernel}")?;
506 }
507 }
508 }
509
510 write!(
511 code,
512 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 8]; 17] = [\n"
513 )?;
514 for k in (1..=16).into_iter().map(Some).chain([None]) {
515 write!(code, "[\n")?;
516 for mr in 1..=8 {
517 write!(code, "[\n")?;
518 for nr in 1..=4 {
519 write!(
520 code,
521 "matmul_{mr}_{nr}_{},",
522 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
523 )?;
524 }
525 write!(code, "],\n")?;
526 }
527 write!(code, "],\n")?;
528 }
529 write!(code, "];\n")?;
530 write!(code, "}}")?;
531 write!(code, "}}")?;
532
533 Ok(code)
534 }
535
536 pub fn codegen_f64() -> Result<String, Box<dyn std::error::Error>> {
537 let mut code = String::new();
538
539 write!(code, "pub mod f64 {{\n")?;
540 write!(code, "pub mod neon {{\n")?;
541 write!(
542 code,
543 r###"
544 use core::arch::aarch64::*;
545 use core::mem::transmute;
546 use core::mem::transmute_copy;
547
548 #[inline(always)]
549 unsafe fn set1(v: f64) -> float64x2_t {{
550 transmute([v; 2])
551 }}
552 #[inline(always)]
553 unsafe fn mul_add(a: float64x2_t, b: float64x2_t, c: float64x2_t) -> float64x2_t {{
554 vmlaq_f64(c, a, b)
555 }}
556 #[inline(always)]
557 unsafe fn load_1(ptr: *const f64) -> float64x2_t {{
558 transmute([*ptr; 2])
559 }}
560 #[inline(always)]
561 unsafe fn load_2(ptr: *const f64) -> float64x2_t {{
562 transmute(*(ptr as *const [f64; 2]))
563 }}
564
565 #[inline(always)]
566 unsafe fn store_1(ptr: *mut f64, v: float64x2_t) {{
567 *(ptr as *mut [f64; 1]) = transmute_copy(&v);
568 }}
569 #[inline(always)]
570 unsafe fn store_2(ptr: *mut f64, v: float64x2_t) {{
571 *(ptr as *mut [f64; 2]) = transmute_copy(&v);
572 }}
573 "###
574 )?;
575 for mr in 1..=4 {
576 for nr in 1..=4 {
577 for k in (1..=16).into_iter().map(Some).chain([None]) {
578 let kernel = RealKernel {
579 ty: "f64",
580 reg_ty: "float64x2_t",
581 n: 2,
582 mr,
583 nr,
584 k,
585 target_features: "neon",
586 load_unaligned: ["load_1", "load_2", "load_4"],
587 store_unaligned: ["store_1", "store_2", "store_4"],
588 set1: "set1",
589 mul_add: "mul_add",
590 };
591 write!(code, "{kernel}")?;
592 }
593 }
594 }
595
596 write!(
597 code,
598 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f64>; 4]; 4]; 17] = [\n"
599 )?;
600 for k in (1..=16).into_iter().map(Some).chain([None]) {
601 write!(code, "[\n")?;
602 for mr in 1..=4 {
603 write!(code, "[\n")?;
604 for nr in 1..=4 {
605 write!(
606 code,
607 "matmul_{mr}_{nr}_{},",
608 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
609 )?;
610 }
611 write!(code, "],\n")?;
612 }
613 write!(code, "],\n")?;
614 }
615 write!(code, "];\n")?;
616 write!(code, "}}")?;
617 write!(code, "}}")?;
618
619 Ok(code)
620 }
621
622 pub fn codegen_c32() -> Result<String, Box<dyn std::error::Error>> {
623 let mut code = String::new();
624
625 write!(code, "pub mod c32 {{\n")?;
626 write!(code, "pub mod neon {{ use crate::c32;\n")?;
627 write!(
628 code,
629 r###"
630 use core::arch::aarch64::*;
631 use core::mem::transmute;
632 use core::mem::transmute_copy;
633
634 #[inline(always)]
635 unsafe fn set1(v: c32) -> float32x4_t {{
636 transmute([v; 2])
637 }}
638 #[inline(always)]
639 unsafe fn mul_add(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {{
640 vcmlaq_90_f32(vcmlaq_0_f32(c, a, b), a, b)
641 }}
642 #[inline(always)]
643 unsafe fn conj_mul_add(a: float32x4_t, b: float32x4_t, c: float32x4_t) -> float32x4_t {{
644 vcmlaq_270_f32(vcmlaq_0_f32(c, a, b), a, b)
645 }}
646 #[inline(always)]
647 unsafe fn conj(a: float32x4_t) -> float32x4_t {{
648 transmute(veorq_u32(transmute(a), transmute([0.0, -0.0, 0.0, -0.0f32])))
649 }}
650 #[inline(always)]
651 unsafe fn load_1(ptr: *const c32) -> float32x4_t {{
652 transmute([*ptr; 2])
653 }}
654 #[inline(always)]
655 unsafe fn load_2(ptr: *const c32) -> float32x4_t {{
656 transmute(*(ptr as *const [c32; 2]))
657 }}
658
659 #[inline(always)]
660 unsafe fn store_1(ptr: *mut c32, v: float32x4_t) {{
661 *(ptr as *mut [c32; 1]) = transmute_copy(&v);
662 }}
663 #[inline(always)]
664 unsafe fn store_2(ptr: *mut c32, v: float32x4_t) {{
665 *(ptr as *mut [c32; 2]) = transmute_copy(&v);
666 }}
667
668#[inline]
669#[target_feature(enable = "neon,fcma")]
670unsafe fn vcmlaq_0_f32(mut acc: float32x4_t, lhs: float32x4_t, rhs: float32x4_t) -> float32x4_t {{
671 core::arch::asm!(
672 "fcmla {{0:v}}.4s, {{1:v}}.4s, {{2:v}}.4s, 0",
673 inout(vreg) acc,
674 in(vreg) lhs,
675 in(vreg) rhs,
676 options(pure, nomem, nostack));
677 acc
678}}
679
680#[inline]
681#[target_feature(enable = "neon,fcma")]
682unsafe fn vcmlaq_90_f32(mut acc: float32x4_t, lhs: float32x4_t, rhs: float32x4_t) -> float32x4_t {{
683 core::arch::asm!(
684 "fcmla {{0:v}}.4s, {{1:v}}.4s, {{2:v}}.4s, 90",
685 inout(vreg) acc,
686 in(vreg) lhs,
687 in(vreg) rhs,
688 options(pure, nomem, nostack));
689 acc
690}}
691
692#[inline]
693#[target_feature(enable = "neon,fcma")]
694unsafe fn vcmlaq_270_f32(mut acc: float32x4_t, lhs: float32x4_t, rhs: float32x4_t) -> float32x4_t {{
695 core::arch::asm!(
696 "fcmla {{0:v}}.4s, {{1:v}}.4s, {{2:v}}.4s, 270",
697 inout(vreg) acc,
698 in(vreg) lhs,
699 in(vreg) rhs,
700 options(pure, nomem, nostack));
701 acc
702}}
703 "###
704 )?;
705 for mr in 1..=4 {
706 for nr in 1..=4 {
707 for k in (1..=16).into_iter().map(Some).chain([None]) {
708 let kernel = CplxKernel {
709 ty: "c32",
710 reg_ty: "float32x4_t",
711 n: 2,
712 mr,
713 nr,
714 k,
715 target_features: "neon,fcma",
716 load_unaligned: ["load_1", "load_2", "load_4"],
717 store_unaligned: ["store_1", "store_2", "store_4"],
718 set1: "set1",
719 mul_add: "mul_add",
720 conj_mul_add: "conj_mul_add",
721 conj: "conj",
722 };
723 write!(code, "{kernel}")?;
724 }
725 }
726 }
727
728 write!(
729 code,
730 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<c32>; 4]; 4]; 17] = [\n"
731 )?;
732 for k in (1..=16).into_iter().map(Some).chain([None]) {
733 write!(code, "[\n")?;
734 for mr in 1..=4 {
735 write!(code, "[\n")?;
736 for nr in 1..=4 {
737 write!(
738 code,
739 "matmul_{mr}_{nr}_{},",
740 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
741 )?;
742 }
743 write!(code, "],\n")?;
744 }
745 write!(code, "],\n")?;
746 }
747 write!(code, "];\n")?;
748 write!(code, "}}")?;
749 write!(code, "}}")?;
750
751 Ok(code)
752 }
753
754 pub fn codegen_c64() -> Result<String, Box<dyn std::error::Error>> {
755 let mut code = String::new();
756
757 write!(code, "pub mod c64 {{\n")?;
758 write!(code, "pub mod neon {{ use crate::c64;\n")?;
759 write!(
760 code,
761 r###"
762 use core::arch::aarch64::*;
763 use core::mem::transmute;
764
765 #[inline(always)]
766 unsafe fn set1(v: c64) -> float64x2_t {{
767 transmute(v)
768 }}
769 #[inline(always)]
770 unsafe fn mul_add(a: float64x2_t, b: float64x2_t, c: float64x2_t) -> float64x2_t {{
771 vcmlaq_90_f64(vcmlaq_0_f64(c, a, b), a, b)
772 }}
773 #[inline(always)]
774 unsafe fn conj_mul_add(a: float64x2_t, b: float64x2_t, c: float64x2_t) -> float64x2_t {{
775 vcmlaq_270_f64(vcmlaq_0_f64(c, a, b), a, b)
776 }}
777 #[inline(always)]
778 unsafe fn conj(a: float64x2_t) -> float64x2_t {{
779 transmute(veorq_u64(transmute(a), transmute([0.0, -0.0f64])))
780 }}
781 #[inline(always)]
782 unsafe fn load_1(ptr: *const c64) -> float64x2_t {{
783 transmute(*ptr)
784 }}
785
786 #[inline(always)]
787 unsafe fn store_1(ptr: *mut c64, v: float64x2_t) {{
788 *ptr = transmute(v);
789 }}
790
791#[inline]
792#[target_feature(enable = "neon,fcma")]
793unsafe fn vcmlaq_0_f64(mut acc: float64x2_t, lhs: float64x2_t, rhs: float64x2_t) -> float64x2_t {{
794 core::arch::asm!(
795 "fcmla {{0:v}}.2d, {{1:v}}.2d, {{2:v}}.2d, 0",
796 inout(vreg) acc,
797 in(vreg) lhs,
798 in(vreg) rhs,
799 options(pure, nomem, nostack));
800 acc
801}}
802
803#[inline]
804#[target_feature(enable = "neon,fcma")]
805unsafe fn vcmlaq_90_f64(mut acc: float64x2_t, lhs: float64x2_t, rhs: float64x2_t) -> float64x2_t {{
806 core::arch::asm!(
807 "fcmla {{0:v}}.2d, {{1:v}}.2d, {{2:v}}.2d, 90",
808 inout(vreg) acc,
809 in(vreg) lhs,
810 in(vreg) rhs,
811 options(pure, nomem, nostack));
812 acc
813}}
814
815#[inline]
816#[target_feature(enable = "neon,fcma")]
817unsafe fn vcmlaq_270_f64(mut acc: float64x2_t, lhs: float64x2_t, rhs: float64x2_t) -> float64x2_t {{
818 core::arch::asm!(
819 "fcmla {{0:v}}.2d, {{1:v}}.2d, {{2:v}}.2d, 270",
820 inout(vreg) acc,
821 in(vreg) lhs,
822 in(vreg) rhs,
823 options(pure, nomem, nostack));
824 acc
825}}
826 "###
827 )?;
828 for mr in 1..=2 {
829 for nr in 1..=4 {
830 for k in (1..=16).into_iter().map(Some).chain([None]) {
831 let kernel = CplxKernel {
832 ty: "c64",
833 reg_ty: "float64x2_t",
834 n: 1,
835 mr,
836 nr,
837 k,
838 target_features: "neon,fcma",
839 load_unaligned: ["load_1", "load_2", "load_4"],
840 store_unaligned: ["store_1", "store_2", "store_4"],
841 set1: "set1",
842 mul_add: "mul_add",
843 conj_mul_add: "conj_mul_add",
844 conj: "conj",
845 };
846 write!(code, "{kernel}")?;
847 }
848 }
849 }
850
851 write!(
852 code,
853 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<c64>; 4]; 2]; 17] = [\n"
854 )?;
855 for k in (1..=16).into_iter().map(Some).chain([None]) {
856 write!(code, "[\n")?;
857 for mr in 1..=2 {
858 write!(code, "[\n")?;
859 for nr in 1..=4 {
860 write!(
861 code,
862 "matmul_{mr}_{nr}_{},",
863 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
864 )?;
865 }
866 write!(code, "],\n")?;
867 }
868 write!(code, "],\n")?;
869 }
870 write!(code, "];\n")?;
871 write!(code, "}}")?;
872 write!(code, "}}")?;
873
874 Ok(code)
875 }
876}
877
878pub mod x86 {
879 use super::*;
880
881 struct RealKernel {
882 ty: &'static str,
883 reg_ty: &'static str,
884 mask_ty: &'static str,
885 n: usize,
887 mr_div_n: usize,
888 nr: usize,
889 k: Option<usize>,
890
891 target_features: &'static str,
892 set1: &'static str,
893 load_unaligned: &'static str,
894 store_unaligned: &'static str,
895 mask_load_unaligned: Box<dyn Fn(String, String) -> String>,
896
897 mask_store_unaligned: &'static str,
898 mul_add: &'static str,
899 mul: &'static str,
900 need_mask: bool,
901 }
902
903 struct CplxKernel {
904 ty: &'static str,
905 reg_ty: &'static str,
906 mask_ty: &'static str,
907 n: usize,
909 mr_div_n: usize,
910 nr: usize,
911 k: Option<usize>,
912
913 target_features: &'static str,
914 set1: &'static str,
915 swap_re_im: &'static str,
916
917 load_unaligned: &'static str,
918 store_unaligned: &'static str,
919 mask_load_unaligned: Box<dyn Fn(String, String) -> String>,
920
921 mask_store_unaligned: &'static str,
922 mul_addsub: &'static str,
923 mul_subadd: &'static str,
924 xor: &'static str,
925 need_mask: bool,
926 }
927
928 impl Display for RealKernel {
929 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
930 write!(
936 f,
937 "
938 #[target_feature(enable = \"{}\")]\n",
939 self.target_features
940 )?;
941 write!(
942 f,
943 r#"pub unsafe fn matmul_{0:}_{1:}_{2:}(
944 &nano_gemm_core::MicroKernelData {{ alpha, beta, k, dst_cs, lhs_cs, rhs_rs, rhs_cs, last_mask, .. }}: &nano_gemm_core::MicroKernelData< {3:} >,
945 dst: *mut {3:},
946 lhs: *const {3:},
947 rhs: *const {3:},
948 ) {{
949"#,
950 self.mr_div_n,
951 self.nr,
952 self.k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
953 self.ty,
954 )?;
955 write!(
956 f,
957 r#"
958 #[cfg(target_arch = "x86_64")]
959 use core::arch::x86_64::*;
960 #[cfg(target_arch = "x86")]
961 use core::arch::x86::*;
962 "#
963 )?;
964
965 write!(f, "_ = k;\n")?;
966 write!(f, "type Reg = {};\n", self.reg_ty)?;
967 write!(f, "const N: isize = {};\n", self.n)?;
968 write!(
969 f,
970 "let mut acc: [[Reg; {}]; {}] = core::mem::zeroed();\n",
971 self.nr, self.mr_div_n
972 )?;
973 write!(
974 f,
975 "let mut tmp_lhs: [Reg; {}] = core::mem::zeroed();\n",
976 self.mr_div_n
977 )?;
978
979 if self.need_mask {
980 write!(
981 f,
982 "let last_mask = *(last_mask as *const {});\n",
983 self.mask_ty
984 )?;
985 } else {
986 write!(f, "_ = last_mask;\n")?;
987 }
988
989 if let Some(k) = self.k {
990 for depth in 0..k {
991 write!(f, "let depth = {depth};\n")?;
992 for i in 0..self.mr_div_n {
993 self.write_load_lhs(i, f)?;
994 }
995 for j in 0..self.nr {
996 write!(
997 f,
998 "let tmp_rhs = {}(*rhs.offset(depth * rhs_rs + {j} * rhs_cs));\n",
999 self.set1
1000 )?;
1001
1002 for i in 0..self.mr_div_n {
1003 if depth > 0 {
1004 write!(
1005 f,
1006 "acc[{i}][{j}] = {}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1007 self.mul_add
1008 )?;
1009 } else {
1010 write!(
1011 f,
1012 "acc[{i}][{j}] = {}(tmp_lhs[{i}], tmp_rhs);\n",
1013 self.mul
1014 )?;
1015 }
1016 }
1017 }
1018 }
1019 } else {
1020 write!(f, "for depth in 0..k as isize {{")?;
1021 for i in 0..self.mr_div_n {
1022 self.write_load_lhs(i, f)?;
1023 }
1024 for j in 0..self.nr {
1025 write!(
1026 f,
1027 "let tmp_rhs = {}(*rhs.offset(depth * rhs_rs + {j} * rhs_cs));\n",
1028 self.set1
1029 )?;
1030
1031 for i in 0..self.mr_div_n {
1032 write!(
1033 f,
1034 "acc[{i}][{j}] = {}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1035 self.mul_add
1036 )?;
1037 }
1038 }
1039 write!(f, "}}")?;
1040 }
1041
1042 write!(f, "if alpha == 1.0 {{")?;
1043 write!(f, "let beta = {}(beta);\n", self.set1)?;
1044 for j in 0..self.nr {
1045 for i in 0..self.mr_div_n {
1046 write!(f, "{{")?;
1047 write!(f, "let dst = dst.offset({i} * N + {j} * dst_cs);")?;
1048 if i + 1 < self.mr_div_n || !self.need_mask {
1049 write!(
1050 f,
1051 "{}(dst, {}(beta, acc[{i}][{j}], {}(dst)));\n",
1052 self.store_unaligned, self.mul_add, self.load_unaligned
1053 )?;
1054 } else {
1055 write!(
1056 f,
1057 "{}(dst, last_mask, {}(beta, acc[{i}][{j}], {}));\n",
1058 self.mask_store_unaligned,
1059 self.mul_add,
1060 (self.mask_load_unaligned)(format!("dst"), "last_mask".to_string()),
1061 )?;
1062 }
1063 write!(f, "}}")?;
1064 }
1065 }
1066 write!(f, "}}")?;
1067 write!(f, "else if alpha == 0.0 {{")?;
1068 write!(f, "let beta = {}(beta);\n", self.set1)?;
1069 for j in 0..self.nr {
1070 for i in 0..self.mr_div_n {
1071 write!(f, "{{")?;
1072 write!(f, "let dst = dst.offset({i} * N + {j} * dst_cs);")?;
1073 if i + 1 < self.mr_div_n || !self.need_mask {
1074 write!(
1075 f,
1076 "{}(dst, {}(beta, acc[{i}][{j}]));\n",
1077 self.store_unaligned, self.mul
1078 )?;
1079 } else {
1080 write!(
1081 f,
1082 "{}(dst, last_mask, {}(beta, acc[{i}][{j}]));\n",
1083 self.mask_store_unaligned, self.mul,
1084 )?;
1085 }
1086 write!(f, "}}")?;
1087 }
1088 }
1089 write!(f, "}}")?;
1090 write!(f, "else {{")?;
1091 write!(f, "let beta = {}(beta);\n", self.set1)?;
1092 write!(f, "let alpha = {}(alpha);\n", self.set1)?;
1093 for j in 0..self.nr {
1094 for i in 0..self.mr_div_n {
1095 write!(f, "{{")?;
1096 write!(f, "let dst = dst.offset({i} * N + {j} * dst_cs);")?;
1097 if i + 1 < self.mr_div_n || !self.need_mask {
1098 write!(
1099 f,
1100 "{}(dst, {}(beta, acc[{i}][{j}], {}({}(dst), alpha)));\n",
1101 self.store_unaligned, self.mul_add, self.mul, self.load_unaligned
1102 )?;
1103 } else {
1104 write!(
1105 f,
1106 "{}(dst, last_mask, {}(beta, acc[{i}][{j}], {}({}, alpha)));\n",
1107 self.mask_store_unaligned,
1108 self.mul_add,
1109 self.mul,
1110 (self.mask_load_unaligned)(format!("dst"), "last_mask".to_string()),
1111 )?;
1112 }
1113 write!(f, "}}")?;
1114 }
1115 }
1116 write!(f, "}}")?;
1117
1118 write!(f, "}}\n")
1119 }
1120 }
1121
1122 impl Display for CplxKernel {
1123 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1124 let Self {
1130 swap_re_im,
1131 load_unaligned,
1132 store_unaligned,
1133 mask_store_unaligned,
1134 mul_addsub,
1135 xor,
1136 ..
1137 } = self;
1138
1139 write!(
1140 f,
1141 "
1142 #[target_feature(enable = \"{}\")]\n",
1143 self.target_features
1144 )?;
1145 write!(
1146 f,
1147 r#"pub unsafe fn matmul_{0:}_{1:}_{2:}(
1148 &nano_gemm_core::MicroKernelData {{ alpha, beta, k, dst_cs, lhs_cs, rhs_rs, rhs_cs, last_mask, conj_lhs, conj_rhs }}: &nano_gemm_core::MicroKernelData<num_complex::Complex< {3:} >>,
1149 dst: *mut num_complex::Complex< {3:} >,
1150 lhs: *const num_complex::Complex< {3:} >,
1151 rhs: *const num_complex::Complex< {3:} >,
1152 ) {{
1153"#,
1154 self.mr_div_n,
1155 self.nr,
1156 self.k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1157 self.ty,
1158 )?;
1159 write!(
1160 f,
1161 r#"
1162 #[cfg(target_arch = "x86_64")]
1163 use core::arch::x86_64::*;
1164 #[cfg(target_arch = "x86")]
1165 use core::arch::x86::*;
1166 "#
1167 )?;
1168
1169 write!(f, "_ = k;\n")?;
1170 write!(f, "type Reg = {};\n", self.reg_ty)?;
1171 write!(f, "const N: isize = {};\n", self.n)?;
1172 write!(
1173 f,
1174 "let mut acc: [[Reg; {}]; {}] = core::mem::zeroed();\n",
1175 self.nr, self.mr_div_n
1176 )?;
1177 write!(
1178 f,
1179 "let mut tmp_lhs: [Reg; {}] = core::mem::zeroed();\n",
1180 self.mr_div_n
1181 )?;
1182
1183 if self.need_mask {
1184 write!(
1185 f,
1186 "let last_mask = *(last_mask as *const {});\n",
1187 self.mask_ty
1188 )?;
1189 } else {
1190 write!(f, "_ = last_mask;\n")?;
1191 }
1192
1193 for idx in 0..2 {
1194 if idx == 0 {
1195 write!(f, "if conj_lhs == conj_rhs {{\n")?;
1196 } else {
1197 write!(f, "else {{\n")?;
1198 }
1199
1200 let mul_add = if idx == 0 {
1201 self.mul_subadd
1202 } else {
1203 self.mul_addsub
1204 };
1205
1206 if let Some(k) = self.k {
1207 for depth in 0..k {
1208 write!(f, "let depth = {depth};\n")?;
1209 for i in 0..self.mr_div_n {
1210 self.write_load_lhs(i, f)?;
1211 }
1212
1213 for j in 0..self.nr {
1214 write!(
1215 f,
1216 "let tmp_rhs = {}((*rhs.offset(depth * rhs_rs + {j} * rhs_cs)).re);\n",
1217 self.set1
1218 )?;
1219
1220 for i in 0..self.mr_div_n {
1221 write!(
1222 f,
1223 "acc[{i}][{j}] = {mul_add}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1224 )?;
1225 }
1226 }
1227 for i in 0..self.mr_div_n {
1228 write!(f, "tmp_lhs[{i}] = {}(tmp_lhs[{i}]);", self.swap_re_im)?;
1229 }
1230 for j in 0..self.nr {
1231 write!(
1232 f,
1233 "let tmp_rhs = {}((*rhs.offset(depth * rhs_rs + {j} * rhs_cs)).im);\n",
1234 self.set1
1235 )?;
1236
1237 for i in 0..self.mr_div_n {
1238 write!(
1239 f,
1240 "acc[{i}][{j}] = {mul_add}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1241 )?;
1242 }
1243 }
1244 }
1245 } else {
1246 write!(f, "for depth in 0..k as isize {{")?;
1247 for i in 0..self.mr_div_n {
1248 self.write_load_lhs(i, f)?;
1249 }
1250
1251 for j in 0..self.nr {
1252 write!(
1253 f,
1254 "let tmp_rhs = {}((*rhs.offset(depth * rhs_rs + {j} * rhs_cs)).re);\n",
1255 self.set1
1256 )?;
1257
1258 for i in 0..self.mr_div_n {
1259 write!(
1260 f,
1261 "acc[{i}][{j}] = {mul_add}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1262 )?;
1263 }
1264 }
1265 for i in 0..self.mr_div_n {
1266 write!(f, "tmp_lhs[{i}] = {}(tmp_lhs[{i}]);", self.swap_re_im)?;
1267 }
1268 for j in 0..self.nr {
1269 write!(
1270 f,
1271 "let tmp_rhs = {}((*rhs.offset(depth * rhs_rs + {j} * rhs_cs)).im);\n",
1272 self.set1
1273 )?;
1274
1275 for i in 0..self.mr_div_n {
1276 write!(
1277 f,
1278 "acc[{i}][{j}] = {mul_add}(tmp_lhs[{i}], tmp_rhs, acc[{i}][{j}]);\n",
1279 )?;
1280 }
1281 }
1282
1283 write!(f, "}}")?;
1284 }
1285 write!(f, "}}")?;
1286 }
1287
1288 write!(
1289 f,
1290 "let mask = XOR_MASKS[conj_lhs as usize + 2 * conj_rhs as usize];"
1291 )?;
1292 for j in 0..self.nr {
1293 for i in 0..self.mr_div_n {
1294 write!(f, "acc[{i}][{j}] = core::mem::transmute({xor}(core::mem::transmute(acc[{i}][{j}]), core::mem::transmute(mask)));")?;
1295 }
1296 }
1297
1298 write!(
1299 f,
1300 "if alpha == (num_complex::Complex {{ re: 1.0, im: 0.0 }}) {{"
1301 )?;
1302 write!(f, "let beta_re = {}(beta.re);\n", self.set1)?;
1303 write!(f, "let beta_im = {}(beta.im);\n", self.set1)?;
1304 for j in 0..self.nr {
1305 for i in 0..self.mr_div_n {
1306 write!(f, "{{")?;
1307 write!(
1308 f,
1309 "let dst = dst.offset({i} * N + {j} * dst_cs) as *mut {};",
1310 self.ty,
1311 )?;
1312 if i + 1 < self.mr_div_n || !self.need_mask {
1313 write!(
1314 f,
1315 "{store_unaligned}(
1316 dst,
1317 {mul_addsub}(
1318 {swap_re_im}(acc[{i}][{j}]),
1319 beta_im,
1320 {mul_addsub}(
1321 acc[{i}][{j}],
1322 beta_re,
1323 {load_unaligned}(dst),
1324 ),
1325 ),
1326 );\n",
1327 )?;
1328 } else {
1329 write!(
1330 f,
1331 "{mask_store_unaligned}(
1332 dst,
1333 last_mask,
1334 {mul_addsub}(
1335 {swap_re_im}(acc[{i}][{j}]),
1336 beta_im,
1337 {mul_addsub}(
1338 acc[{i}][{j}],
1339 beta_re,
1340 {},
1341 ),
1342 ),
1343 );\n",
1344 (self.mask_load_unaligned)(format!("dst"), "last_mask".to_string()),
1345 )?;
1346 }
1347 write!(f, "}}")?;
1348 }
1349 }
1350 write!(f, "}}")?;
1351
1352 write!(
1353 f,
1354 "else if alpha == (num_complex::Complex {{ re: 0.0, im: 0.0 }}) {{"
1355 )?;
1356 write!(f, "let beta_re = {}(beta.re);\n", self.set1)?;
1357 write!(f, "let beta_im = {}(beta.im);\n", self.set1)?;
1358 for j in 0..self.nr {
1359 for i in 0..self.mr_div_n {
1360 write!(f, "{{")?;
1361 write!(
1362 f,
1363 "let dst = dst.offset({i} * N + {j} * dst_cs) as *mut {};",
1364 self.ty,
1365 )?;
1366 if i + 1 < self.mr_div_n || !self.need_mask {
1367 write!(
1368 f,
1369 "{store_unaligned}(
1370 dst,
1371 {mul_addsub}(
1372 {swap_re_im}(acc[{i}][{j}]),
1373 beta_im,
1374 {mul_addsub}(
1375 acc[{i}][{j}],
1376 beta_re,
1377 core::mem::zeroed(),
1378 ),
1379 ),
1380 );\n",
1381 )?;
1382 } else {
1383 write!(
1384 f,
1385 "{mask_store_unaligned}(
1386 dst,
1387 last_mask,
1388 {mul_addsub}(
1389 {swap_re_im}(acc[{i}][{j}]),
1390 beta_im,
1391 {mul_addsub}(
1392 acc[{i}][{j}],
1393 beta_re,
1394 core::mem::zeroed(),
1395 ),
1396 ),
1397 );\n"
1398 )?;
1399 }
1400 write!(f, "}}")?;
1401 }
1402 }
1403 write!(f, "}}")?;
1404 write!(f, "else {{")?;
1405 write!(f, "let beta_re = {}(beta.re);\n", self.set1)?;
1406 write!(f, "let beta_im = {}(beta.im);\n", self.set1)?;
1407 write!(f, "let alpha_re = {}(alpha.re);\n", self.set1)?;
1408 write!(f, "let alpha_im = {}(alpha.im);\n", self.set1)?;
1409 for j in 0..self.nr {
1410 for i in 0..self.mr_div_n {
1411 write!(f, "{{")?;
1412 write!(
1413 f,
1414 "let dst = dst.offset({i} * N + {j} * dst_cs) as *mut {};",
1415 self.ty,
1416 )?;
1417 if i + 1 < self.mr_div_n || !self.need_mask {
1418 write!(
1419 f,
1420 "let dst_conj = core::mem::transmute({xor}(
1421 core::mem::transmute({load_unaligned}(dst)),
1422 core::mem::transmute(XOR_MASKS[1]),
1423 ));"
1424 )?;
1425
1426 write!(
1427 f,
1428 "{store_unaligned}(
1429 dst,
1430 {mul_addsub}(
1431 {swap_re_im}(acc[{i}][{j}]),
1432 beta_im,
1433 {mul_addsub}(
1434 acc[{i}][{j}],
1435 beta_re,
1436 {mul_addsub}(
1437 {swap_re_im}(dst_conj),
1438 alpha_im,
1439 {mul_addsub}(
1440 dst_conj,
1441 alpha_re,
1442 core::mem::zeroed(),
1443 ),
1444 ),
1445 ),
1446 ),
1447 );\n",
1448 )?;
1449 } else {
1450 write!(
1451 f,
1452 "let dst_conj = core::mem::transmute({xor}(
1453 core::mem::transmute({}),
1454 core::mem::transmute(XOR_MASKS[1]),
1455 ));",
1456 (self.mask_load_unaligned)(format!("dst"), "last_mask".to_string())
1457 )?;
1458
1459 write!(
1460 f,
1461 "{mask_store_unaligned}(
1462 dst,
1463 last_mask,
1464 {mul_addsub}(
1465 {swap_re_im}(acc[{i}][{j}]),
1466 beta_im,
1467 {mul_addsub}(
1468 acc[{i}][{j}],
1469 beta_re,
1470 {mul_addsub}(
1471 {swap_re_im}(dst_conj),
1472 alpha_im,
1473 {mul_addsub}(
1474 dst_conj,
1475 alpha_re,
1476 core::mem::zeroed(),
1477 ),
1478 ),
1479 ),
1480 ),
1481 );\n"
1482 )?;
1483 }
1484 write!(f, "}}")?;
1485 }
1486 }
1487 write!(f, "}}")?;
1488
1489 write!(f, "}}\n")
1490 }
1491 }
1492
1493 impl CplxKernel {
1494 fn write_load_lhs(
1495 &self,
1496 i: usize,
1497 f: &mut std::fmt::Formatter<'_>,
1498 ) -> Result<(), std::fmt::Error> {
1499 Ok(if i + 1 < self.mr_div_n || !self.need_mask {
1500 write!(
1501 f,
1502 "tmp_lhs[{i}] = {}(lhs.offset(depth * lhs_cs + {i} * N) as *const {});\n",
1503 self.load_unaligned, self.ty,
1504 )?;
1505 } else {
1506 write!(
1507 f,
1508 "tmp_lhs[{i}] = {};\n",
1509 (self.mask_load_unaligned)(
1510 format!("lhs.offset(depth * lhs_cs + {i} * N) as *const {}", self.ty,),
1511 "last_mask".to_string()
1512 ),
1513 )?;
1514 })
1515 }
1516 }
1517 impl RealKernel {
1518 fn write_load_lhs(
1519 &self,
1520 i: usize,
1521 f: &mut std::fmt::Formatter<'_>,
1522 ) -> Result<(), std::fmt::Error> {
1523 Ok(if i + 1 < self.mr_div_n || !self.need_mask {
1524 write!(
1525 f,
1526 "tmp_lhs[{i}] = {}(lhs.offset(depth * lhs_cs + {i} * N));\n",
1527 self.load_unaligned
1528 )?;
1529 } else {
1530 write!(
1531 f,
1532 "tmp_lhs[{i}] = {};\n",
1533 (self.mask_load_unaligned)(
1534 format!("lhs.offset(depth * lhs_cs + {i} * N)"),
1535 "last_mask".to_string()
1536 ),
1537 )?;
1538 })
1539 }
1540 }
1541
1542 pub fn codegen_f32() -> Result<String, Box<dyn std::error::Error>> {
1543 let mut code = String::new();
1544
1545 write!(code, "pub mod f32 {{\n")?;
1546 write!(code, "pub mod f32x1 {{\n")?;
1547 {
1548 write!(
1549 code,
1550 "pub const MR_DIV_N: usize = 1; pub const NR: usize = 4; pub const N: usize = 1;"
1551 )?;
1552
1553 for mr_div_n in 1..=1 {
1554 for nr in 1..=4 {
1555 for k in (1..=16).into_iter().map(Some).chain([None]) {
1556 let kernel = RealKernel {
1557 ty: "f32",
1558 reg_ty: "__m128",
1559 mask_ty: "__m128i",
1560 mr_div_n,
1561 nr,
1562 k,
1563 target_features: "avx,avx2,fma",
1564 n: 1,
1565 set1: "crate::x86::splat_1s",
1566 load_unaligned: "_mm_load_ss",
1567 store_unaligned: "_mm_store_ss",
1568 mask_load_unaligned: Box::new(|_, _| String::new()),
1569 mask_store_unaligned: "",
1570 mul_add: "_mm_fmadd_ss",
1571 mul: "_mm_mul_ss",
1572 need_mask: false,
1573 };
1574
1575 write!(code, "{kernel}")?;
1576 }
1577 }
1578 }
1579
1580 write!(
1581 code,
1582 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 1]; 17] = [\n"
1583 )?;
1584 for k in (1..=16).into_iter().map(Some).chain([None]) {
1585 write!(code, "[\n")?;
1586 for mr_div_n in 1..=1 {
1587 write!(code, "[\n")?;
1588 for nr in 1..=4 {
1589 write!(
1590 code,
1591 "matmul_{mr_div_n}_{nr}_{},",
1592 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1593 )?;
1594 }
1595 write!(code, "],\n")?;
1596 }
1597 write!(code, "],\n")?;
1598 }
1599 write!(code, "];\n")?;
1600 }
1601 write!(code, "}}\n")?;
1602 write!(code, "pub mod f32x2 {{\n")?;
1603 {
1604 write!(
1605 code,
1606 "pub const MR_DIV_N: usize = 1; pub const NR: usize = 4; pub const N: usize = 2;"
1607 )?;
1608 for mr_div_n in 1..=1 {
1609 for nr in 1..=4 {
1610 for k in (1..=16).into_iter().map(Some).chain([None]) {
1611 let kernel = RealKernel {
1612 need_mask: false,
1613 ty: "f32",
1614 reg_ty: "__m128",
1615 mask_ty: "__m128i",
1616 mr_div_n,
1617 nr,
1618 k,
1619 target_features: "avx,avx2,fma",
1620 n: 2,
1621 set1: "_mm_set1_ps",
1622 load_unaligned: "crate::x86::load_2s",
1623 store_unaligned: "crate::x86::store_2s",
1624 mask_load_unaligned: Box::new(|_, _| String::new()),
1625 mask_store_unaligned: "",
1626 mul_add: "_mm_fmadd_ps",
1627 mul: "_mm_mul_ps",
1628 };
1629
1630 write!(code, "{kernel}")?;
1631 }
1632 }
1633 }
1634
1635 write!(
1636 code,
1637 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 1]; 17] = [\n"
1638 )?;
1639 for k in (1..=16).into_iter().map(Some).chain([None]) {
1640 write!(code, "[\n")?;
1641 for mr_div_n in 1..=1 {
1642 write!(code, "[\n")?;
1643 for nr in 1..=4 {
1644 write!(
1645 code,
1646 "matmul_{mr_div_n}_{nr}_{},",
1647 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1648 )?;
1649 }
1650 write!(code, "],\n")?;
1651 }
1652 write!(code, "],\n")?;
1653 }
1654 write!(code, "];\n")?;
1655 }
1656 write!(code, "}}\n")?;
1657
1658 write!(code, "pub mod f32x4 {{\n")?;
1659 {
1660 write!(
1661 code,
1662 "pub const MR_DIV_N: usize = 1; pub const NR: usize = 4; pub const N: usize = 4;"
1663 )?;
1664 for mr_div_n in 1..=1 {
1665 for nr in 1..=4 {
1666 for k in (1..=16).into_iter().map(Some).chain([None]) {
1667 let kernel = RealKernel {
1668 need_mask: true,
1669 ty: "f32",
1670 reg_ty: "__m128",
1671 mask_ty: "__m128i",
1672 mr_div_n,
1673 nr,
1674 k,
1675 target_features: "avx,avx2,fma",
1676 n: 4,
1677 set1: "_mm_set1_ps",
1678 load_unaligned: "_mm_loadu_ps",
1679 store_unaligned: "_mm_storeu_ps",
1680 mask_load_unaligned: Box::new(|ptr, mask| {
1681 format!("_mm_maskload_ps({ptr}, {mask})")
1682 }),
1683 mask_store_unaligned: "_mm_maskstore_ps",
1684 mul_add: "_mm_fmadd_ps",
1685 mul: "_mm_mul_ps",
1686 };
1687
1688 write!(code, "{kernel}")?;
1689 }
1690 }
1691 }
1692
1693 write!(
1694 code,
1695 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 1]; 17] = [\n"
1696 )?;
1697 for k in (1..=16).into_iter().map(Some).chain([None]) {
1698 write!(code, "[\n")?;
1699 for mr_div_n in 1..=1 {
1700 write!(code, "[\n")?;
1701 for nr in 1..=4 {
1702 write!(
1703 code,
1704 "matmul_{mr_div_n}_{nr}_{},",
1705 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1706 )?;
1707 }
1708 write!(code, "],\n")?;
1709 }
1710 write!(code, "],\n")?;
1711 }
1712 write!(code, "];\n")?;
1713 write!(
1714 code,
1715 "
1716 pub static MASKS: [crate::x86::__m128i; 4] = unsafe {{ core::mem::transmute([
1717 [u32::MAX, u32::MAX, u32::MAX, u32::MAX],
1718
1719 [u32::MAX, 0, 0, 0],
1720 [u32::MAX, u32::MAX, 0, 0],
1721 [u32::MAX, u32::MAX, u32::MAX, 0],
1722 ]) }};
1723 "
1724 )?;
1725 }
1726 write!(code, "}}\n")?;
1727 write!(code, "pub mod avx {{\n")?;
1728 {
1729 write!(
1730 code,
1731 "pub const MR_DIV_N: usize = 2; pub const NR: usize = 4; pub const N: usize = 8;"
1732 )?;
1733
1734 for mr_div_n in 1..=2 {
1735 for nr in 1..=4 {
1736 for k in (1..=16).into_iter().map(Some).chain([None]) {
1737 let kernel = RealKernel {
1738 need_mask: true,
1739 ty: "f32",
1740 reg_ty: "__m256",
1741 mask_ty: "__m256i",
1742 mr_div_n,
1743 nr,
1744 k,
1745 target_features: "avx,avx2,fma",
1746 n: 8,
1747 set1: "_mm256_set1_ps",
1748 load_unaligned: "_mm256_loadu_ps",
1749 store_unaligned: "_mm256_storeu_ps",
1750 mask_load_unaligned: Box::new(|ptr, mask| {
1751 format!("_mm256_maskload_ps({ptr}, {mask})")
1752 }),
1753 mask_store_unaligned: "_mm256_maskstore_ps",
1754 mul_add: "_mm256_fmadd_ps",
1755 mul: "_mm256_mul_ps",
1756 };
1757
1758 write!(code, "{kernel}")?;
1759 }
1760 }
1761 }
1762
1763 write!(
1764 code,
1765 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 2]; 17] = [\n"
1766 )?;
1767 for k in (1..=16).into_iter().map(Some).chain([None]) {
1768 write!(code, "[\n")?;
1769 for mr_div_n in 1..=2 {
1770 write!(code, "[\n")?;
1771 for nr in 1..=4 {
1772 write!(
1773 code,
1774 "matmul_{mr_div_n}_{nr}_{},",
1775 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1776 )?;
1777 }
1778 write!(code, "],\n")?;
1779 }
1780 write!(code, "],\n")?;
1781 }
1782 write!(code, "];\n")?;
1783 write!(
1784 code,
1785 "
1786 pub static MASKS: [crate::x86::__m256i; 8] = unsafe {{ core::mem::transmute([
1787 [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX],
1788
1789 [u32::MAX, 0, 0, 0, 0, 0, 0, 0],
1790 [u32::MAX, u32::MAX, 0, 0, 0, 0, 0, 0],
1791 [u32::MAX, u32::MAX, u32::MAX, 0, 0, 0, 0, 0],
1792 [u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0, 0, 0, 0],
1793 [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0, 0, 0],
1794 [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0, 0],
1795 [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0],
1796 ]) }};
1797 "
1798 )?;
1799 }
1800 write!(code, "}}\n")?;
1801
1802 write!(code, "#[cfg(feature = \"x86-v4\")] pub mod avx512 {{\n")?;
1803 {
1804 write!(
1805 code,
1806 "pub const MR_DIV_N: usize = 2; pub const NR: usize = 4; pub const N: usize = 16;"
1807 )?;
1808
1809 for mr_div_n in 1..=2 {
1810 for nr in 1..=4 {
1811 for k in (1..=16).into_iter().map(Some).chain([None]) {
1812 let kernel = RealKernel {
1813 need_mask: true,
1814 ty: "f32",
1815 reg_ty: "__m512",
1816 mask_ty: "u16",
1817 mr_div_n,
1818 nr,
1819 k,
1820 target_features: "avx512f",
1821 n: 16,
1822 set1: "_mm512_set1_ps",
1823 load_unaligned: "_mm512_loadu_ps",
1824 store_unaligned: "_mm512_storeu_ps",
1825 mask_load_unaligned: Box::new(|ptr, mask| {
1826 format!("_mm512_maskz_loadu_ps({mask}, {ptr})")
1827 }),
1828 mask_store_unaligned: "_mm512_mask_storeu_ps",
1829 mul_add: "_mm512_fmadd_ps",
1830 mul: "_mm512_mul_ps",
1831 };
1832
1833 write!(code, "{kernel}")?;
1834 }
1835 }
1836 }
1837
1838 write!(
1839 code,
1840 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f32>; 4]; 2]; 17] = [\n"
1841 )?;
1842 for k in (1..=16).into_iter().map(Some).chain([None]) {
1843 write!(code, "[\n")?;
1844 for mr_div_n in 1..=2 {
1845 write!(code, "[\n")?;
1846 for nr in 1..=4 {
1847 write!(
1848 code,
1849 "matmul_{mr_div_n}_{nr}_{},",
1850 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1851 )?;
1852 }
1853 write!(code, "],\n")?;
1854 }
1855 write!(code, "],\n")?;
1856 }
1857 write!(code, "];\n")?;
1858 write!(
1859 code,
1860 "
1861 pub static MASKS: [u16; 16] = [
1862 0b1111_1111_1111_1111,
1863 0b0000_0000_0000_0001,
1864 0b0000_0000_0000_0011,
1865 0b0000_0000_0000_0111,
1866 0b0000_0000_0000_1111,
1867 0b0000_0000_0001_1111,
1868 0b0000_0000_0011_1111,
1869 0b0000_0000_0111_1111,
1870 0b0000_0000_1111_1111,
1871 0b0000_0001_1111_1111,
1872 0b0000_0011_1111_1111,
1873 0b0000_0111_1111_1111,
1874 0b0000_1111_1111_1111,
1875 0b0001_1111_1111_1111,
1876 0b0011_1111_1111_1111,
1877 0b0111_1111_1111_1111,
1878 ];
1879 "
1880 )?;
1881 }
1882 write!(code, "}}\n")?;
1883 write!(code, "}}\n")?;
1884
1885 Ok(code)
1886 }
1887
1888 pub fn codegen_f64() -> Result<String, Box<dyn std::error::Error>> {
1889 let mut code = String::new();
1890
1891 write!(code, "pub mod f64 {{\n")?;
1892 write!(code, "pub mod f64x1 {{\n")?;
1893 {
1894 write!(
1895 code,
1896 "pub const MR_DIV_N: usize = 1; pub const NR: usize = 4; pub const N: usize = 1;"
1897 )?;
1898
1899 for mr_div_n in 1..=1 {
1900 for nr in 1..=4 {
1901 for k in (1..=16).into_iter().map(Some).chain([None]) {
1902 let kernel = RealKernel {
1903 need_mask: false,
1904 ty: "f64",
1905 reg_ty: "__m128d",
1906 mask_ty: "__m128i",
1907 mr_div_n,
1908 nr,
1909 k,
1910 target_features: "avx,avx2,fma",
1911 n: 1,
1912 set1: "crate::x86::splat_1d",
1913 load_unaligned: "_mm_load_sd",
1914 store_unaligned: "_mm_store_sd",
1915 mask_load_unaligned: Box::new(|_, _| String::new()),
1916 mask_store_unaligned: "",
1917 mul_add: "_mm_fmadd_sd",
1918 mul: "_mm_mul_sd",
1919 };
1920
1921 write!(code, "{kernel}")?;
1922 }
1923 }
1924 }
1925
1926 write!(
1927 code,
1928 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f64>; 4]; 1]; 17] = [\n"
1929 )?;
1930 for k in (1..=16).into_iter().map(Some).chain([None]) {
1931 write!(code, "[\n")?;
1932 for mr_div_n in 1..=1 {
1933 write!(code, "[\n")?;
1934 for nr in 1..=4 {
1935 write!(
1936 code,
1937 "matmul_{mr_div_n}_{nr}_{},",
1938 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1939 )?;
1940 }
1941 write!(code, "],\n")?;
1942 }
1943 write!(code, "],\n")?;
1944 }
1945 write!(code, "];\n")?;
1946 }
1947 write!(code, "}}\n")?;
1948 write!(code, "pub mod f64x2 {{\n")?;
1949 {
1950 write!(
1951 code,
1952 "pub const MR_DIV_N: usize = 1; pub const NR: usize = 4; pub const N: usize = 2;"
1953 )?;
1954 for mr_div_n in 1..=1 {
1955 for nr in 1..=4 {
1956 for k in (1..=16).into_iter().map(Some).chain([None]) {
1957 let kernel = RealKernel {
1958 need_mask: false,
1959 ty: "f64",
1960 reg_ty: "__m128d",
1961 mask_ty: "__m128i",
1962 mr_div_n,
1963 nr,
1964 k,
1965 target_features: "avx,avx2,fma",
1966 n: 2,
1967 set1: "_mm_set1_pd",
1968 load_unaligned: "_mm_loadu_pd",
1969 store_unaligned: "_mm_storeu_pd",
1970 mask_load_unaligned: Box::new(|_, _| String::new()),
1971 mask_store_unaligned: "",
1972 mul_add: "_mm_fmadd_pd",
1973 mul: "_mm_mul_pd",
1974 };
1975
1976 write!(code, "{kernel}")?;
1977 }
1978 }
1979 }
1980
1981 write!(
1982 code,
1983 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f64>; 4]; 1]; 17] = [\n"
1984 )?;
1985 for k in (1..=16).into_iter().map(Some).chain([None]) {
1986 write!(code, "[\n")?;
1987 for mr_div_n in 1..=1 {
1988 write!(code, "[\n")?;
1989 for nr in 1..=4 {
1990 write!(
1991 code,
1992 "matmul_{mr_div_n}_{nr}_{},",
1993 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
1994 )?;
1995 }
1996 write!(code, "],\n")?;
1997 }
1998 write!(code, "],\n")?;
1999 }
2000 write!(code, "];\n")?;
2001 }
2002 write!(code, "}}\n")?;
2003
2004 write!(
2005 code,
2006 "
2007 pub mod avx {{\n"
2008 )?;
2009 {
2010 write!(
2011 code,
2012 "pub const MR_DIV_N: usize = 2; pub const NR: usize = 4; pub const N: usize = 4;"
2013 )?;
2014
2015 {
2016 for mr_div_n in 1..=2 {
2017 for nr in 1..=4 {
2018 for k in (1..=16).into_iter().map(Some).chain([None]) {
2019 let kernel = RealKernel {
2020 need_mask: true,
2021 ty: "f64",
2022 reg_ty: "__m256d",
2023 mask_ty: "__m256i",
2024 mr_div_n,
2025 nr,
2026 k,
2027 target_features: "avx,avx2,fma",
2028 n: 4,
2029 set1: "_mm256_set1_pd",
2030 load_unaligned: "_mm256_loadu_pd",
2031 store_unaligned: "_mm256_storeu_pd",
2032 mask_load_unaligned: Box::new(|ptr, mask| {
2033 format!("_mm256_maskload_pd({ptr}, {mask})")
2034 }),
2035 mask_store_unaligned: "_mm256_maskstore_pd",
2036 mul_add: "_mm256_fmadd_pd",
2037 mul: "_mm256_mul_pd",
2038 };
2039
2040 write!(code, "{kernel}")?;
2041 }
2042 }
2043 }
2044
2045 write!(
2046 code,
2047 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f64>; 4]; 2]; 17] = [\n"
2048 )?;
2049 for k in (1..=16).into_iter().map(Some).chain([None]) {
2050 write!(code, "[\n")?;
2051 for mr_div_n in 1..=2 {
2052 write!(code, "[\n")?;
2053 for nr in 1..=4 {
2054 write!(
2055 code,
2056 "matmul_{mr_div_n}_{nr}_{},",
2057 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2058 )?;
2059 }
2060 write!(code, "],\n")?;
2061 }
2062 write!(code, "],\n")?;
2063 }
2064 write!(code, "];\n")?;
2065 write!(
2066 code,
2067 "
2068 pub static MASKS: [crate::x86::__m256i; 4] = unsafe {{ core::mem::transmute([
2069 [u64::MAX, u64::MAX, u64::MAX, u64::MAX],
2070
2071 [u64::MAX, 0, 0, 0],
2072 [u64::MAX, u64::MAX, 0, 0],
2073 [u64::MAX, u64::MAX, u64::MAX, 0],
2074 ]) }};
2075 "
2076 )?;
2077 }
2078 }
2079 write!(code, "}}\n")?;
2080
2081 write!(
2082 code,
2083 "
2084 #[cfg(feature = \"x86-v4\")]
2085 pub mod avx512 {{\n"
2086 )?;
2087 {
2088 write!(
2089 code,
2090 "pub const MR_DIV_N: usize = 2; pub const NR: usize = 4; pub const N: usize = 8;"
2091 )?;
2092
2093 for mr_div_n in 1..=2 {
2094 for nr in 1..=4 {
2095 for k in (1..=16).into_iter().map(Some).chain([None]) {
2096 let kernel = RealKernel {
2097 need_mask: true,
2098 ty: "f64",
2099 reg_ty: "__m512d",
2100 mask_ty: "u8",
2101 mr_div_n,
2102 nr,
2103 k,
2104 target_features: "avx512f",
2105 n: 8,
2106 set1: "_mm512_set1_pd",
2107 load_unaligned: "_mm512_loadu_pd",
2108 store_unaligned: "_mm512_storeu_pd",
2109 mask_load_unaligned: Box::new(|ptr, mask| {
2110 format!("_mm512_maskz_loadu_pd({mask}, {ptr})")
2111 }),
2112 mask_store_unaligned: "_mm512_mask_storeu_pd",
2113 mul_add: "_mm512_fmadd_pd",
2114 mul: "_mm512_mul_pd",
2115 };
2116
2117 write!(code, "{kernel}")?;
2118 }
2119 }
2120 }
2121
2122 write!(
2123 code,
2124 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<f64>; 4]; 2]; 17] = [\n"
2125 )?;
2126 for k in (1..=16).into_iter().map(Some).chain([None]) {
2127 write!(code, "[\n")?;
2128 for mr_div_n in 1..=2 {
2129 write!(code, "[\n")?;
2130 for nr in 1..=4 {
2131 write!(
2132 code,
2133 "matmul_{mr_div_n}_{nr}_{},",
2134 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2135 )?;
2136 }
2137 write!(code, "],\n")?;
2138 }
2139 write!(code, "],\n")?;
2140 }
2141 write!(code, "];\n")?;
2142 write!(
2143 code,
2144 "
2145 pub static MASKS: [u8; 8] = [
2146 0b1111_1111,
2147 0b0000_0001,
2148 0b0000_0011,
2149 0b0000_0111,
2150 0b0000_1111,
2151 0b0001_1111,
2152 0b0011_1111,
2153 0b0111_1111,
2154 ];
2155 "
2156 )?;
2157 }
2158 write!(code, "}}\n")?;
2159 write!(code, "}}\n")?;
2160
2161 Ok(code)
2162 }
2163
2164 pub fn codegen_c32() -> Result<String, Box<dyn std::error::Error>> {
2165 let mut code = String::new();
2166
2167 write!(code, "pub mod c32 {{\n")?;
2168 write!(code, "pub mod c32x1 {{\n")?;
2169 {
2170 write!(
2171 code,
2172 "pub const MR_DIV_N: usize = 1; pub const NR: usize = 2; pub const N: usize = 1;"
2173 )?;
2174 write!(
2175 code,
2176 "const XOR_MASKS: [crate::x86::__m128; 4] = unsafe {{ core::mem::transmute([
2177 [-0.0, -0.0, -0.0, -0.0f32], // no conj
2178 [ 0.0, -0.0, 0.0, -0.0f32], // conj lhs
2179 [ 0.0, 0.0, 0.0, 0.0f32], // conj rhs
2180 [-0.0, 0.0, -0.0, 0.0f32], // conj both
2181 ]) }};\n"
2182 )?;
2183
2184 for mr_div_n in 1..=1 {
2185 for nr in 1..=2 {
2186 for k in (1..=16).into_iter().map(Some).chain([None]) {
2187 let kernel = CplxKernel {
2188 need_mask: false,
2189 ty: "f32",
2190 reg_ty: "__m128",
2191 mask_ty: "__m128i",
2192 mr_div_n,
2193 nr,
2194 k,
2195 target_features: "avx,avx2,fma",
2196 n: 1,
2197 set1: "_mm_set1_ps",
2198 load_unaligned: "crate::x86::load_2s",
2199 store_unaligned: "crate::x86::store_2s",
2200 mask_load_unaligned: Box::new(|ptr, mask| {
2201 format!("_mm_maskload_ps({ptr}, {mask})")
2202 }),
2203 mask_store_unaligned: "_mm_maskstore_ps",
2204 swap_re_im: "_mm_permute_ps::<0b10_11_00_01>",
2205 mul_addsub: "_mm_fmsubadd_ps",
2206 mul_subadd: "_mm_fmaddsub_ps",
2207 xor: "_mm_xor_ps",
2208 };
2209
2210 write!(code, "{kernel}")?;
2211 }
2212 }
2213 }
2214
2215 write!(
2216 code,
2217 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f32>>; 2]; 1]; 17] = [\n"
2218 )?;
2219 for k in (1..=16).into_iter().map(Some).chain([None]) {
2220 write!(code, "[\n")?;
2221 for mr_div_n in 1..=1 {
2222 write!(code, "[\n")?;
2223 for nr in 1..=2 {
2224 write!(
2225 code,
2226 "matmul_{mr_div_n}_{nr}_{},",
2227 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2228 )?;
2229 }
2230 write!(code, "],\n")?;
2231 }
2232 write!(code, "],\n")?;
2233 }
2234 write!(code, "];\n")?;
2235 }
2236 write!(code, "}}\n")?;
2237
2238 write!(code, "pub mod c32x2 {{\n")?;
2239 {
2240 write!(
2241 code,
2242 "pub const MR_DIV_N: usize = 1; pub const NR: usize = 2; pub const N: usize = 2;"
2243 )?;
2244 write!(
2245 code,
2246 "const XOR_MASKS: [crate::x86::__m128; 4] = unsafe {{ core::mem::transmute([
2247 [-0.0, -0.0, -0.0, -0.0f32], // no conj
2248 [ 0.0, -0.0, 0.0, -0.0f32], // conj lhs
2249 [ 0.0, 0.0, 0.0, 0.0f32], // conj rhs
2250 [-0.0, 0.0, -0.0, 0.0f32], // conj both
2251 ]) }};\n"
2252 )?;
2253
2254 for mr_div_n in 1..=1 {
2255 for nr in 1..=2 {
2256 for k in (1..=16).into_iter().map(Some).chain([None]) {
2257 let kernel = CplxKernel {
2258 need_mask: false,
2259 ty: "f32",
2260 reg_ty: "__m128",
2261 mask_ty: "__m128i",
2262 mr_div_n,
2263 nr,
2264 k,
2265 target_features: "avx,avx2,fma",
2266 n: 2,
2267 set1: "_mm_set1_ps",
2268 load_unaligned: "_mm_loadu_ps",
2269 store_unaligned: "_mm_storeu_ps",
2270 mask_load_unaligned: Box::new(|ptr, mask| {
2271 format!("_mm_maskload_ps({ptr}, {mask})")
2272 }),
2273 mask_store_unaligned: "_mm_maskstore_ps",
2274 swap_re_im: "_mm_permute_ps::<0b10_11_00_01>",
2275 mul_addsub: "_mm_fmsubadd_ps",
2276 mul_subadd: "_mm_fmaddsub_ps",
2277 xor: "_mm_xor_ps",
2278 };
2279
2280 write!(code, "{kernel}")?;
2281 }
2282 }
2283 }
2284
2285 write!(
2286 code,
2287 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f32>>; 2]; 1]; 17] = [\n"
2288 )?;
2289 for k in (1..=16).into_iter().map(Some).chain([None]) {
2290 write!(code, "[\n")?;
2291 for mr_div_n in 1..=1 {
2292 write!(code, "[\n")?;
2293 for nr in 1..=2 {
2294 write!(
2295 code,
2296 "matmul_{mr_div_n}_{nr}_{},",
2297 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2298 )?;
2299 }
2300 write!(code, "],\n")?;
2301 }
2302 write!(code, "],\n")?;
2303 }
2304 write!(code, "];\n")?;
2305 }
2306 write!(code, "}}\n")?;
2307 write!(code, " pub mod avx {{\n")?;
2308 {
2309 write!(
2310 code,
2311 "pub const MR_DIV_N: usize = 2; pub const NR: usize = 2; pub const N: usize = 4;"
2312 )?;
2313 write!(
2314 code,
2315 "const XOR_MASKS: [crate::x86::__m256; 4] = unsafe {{ core::mem::transmute([
2316 [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0f32], // no conj
2317 [ 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0f32], // conj lhs
2318 [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0f32], // conj rhs
2319 [-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0f32], // conj both
2320 ]) }};\n"
2321 )?;
2322
2323 for mr_div_n in 1..=2 {
2324 for nr in 1..=2 {
2325 for k in (1..=16).into_iter().map(Some).chain([None]) {
2326 let kernel = CplxKernel {
2327 need_mask: true,
2328 ty: "f32",
2329 reg_ty: "__m256",
2330 mask_ty: "__m256i",
2331 mr_div_n,
2332 nr,
2333 k,
2334 target_features: "avx,avx2,fma",
2335 n: 4,
2336 set1: "_mm256_set1_ps",
2337 load_unaligned: "_mm256_loadu_ps",
2338 store_unaligned: "_mm256_storeu_ps",
2339 mask_load_unaligned: Box::new(|ptr, mask| {
2340 format!("_mm256_maskload_ps({ptr}, {mask})")
2341 }),
2342 mask_store_unaligned: "_mm256_maskstore_ps",
2343 swap_re_im: "_mm256_permute_ps::<0b10_11_00_01>",
2344 mul_addsub: "_mm256_fmsubadd_ps",
2345 mul_subadd: "_mm256_fmaddsub_ps",
2346 xor: "_mm256_xor_ps",
2347 };
2348
2349 write!(code, "{kernel}")?;
2350 }
2351 }
2352 }
2353
2354 write!(
2355 code,
2356 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f32>>; 2]; 2]; 17] = [\n"
2357 )?;
2358 for k in (1..=16).into_iter().map(Some).chain([None]) {
2359 write!(code, "[\n")?;
2360 for mr_div_n in 1..=2 {
2361 write!(code, "[\n")?;
2362 for nr in 1..=2 {
2363 write!(
2364 code,
2365 "matmul_{mr_div_n}_{nr}_{},",
2366 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2367 )?;
2368 }
2369 write!(code, "],\n")?;
2370 }
2371 write!(code, "],\n")?;
2372 }
2373 write!(code, "];\n")?;
2374 write!(
2375 code,
2376 "
2377 pub static MASKS: [crate::x86::__m256i; 4] = unsafe {{ core::mem::transmute([
2378 [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX],
2379
2380 [u32::MAX, u32::MAX, 0, 0, 0, 0, 0, 0],
2381 [u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0, 0, 0, 0],
2382 [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX, 0, 0],
2383 ]) }};
2384 "
2385 )?;
2386 }
2387 write!(code, "}}\n")?;
2388 write!(
2389 code,
2390 "
2391 #[cfg(feature = \"x86-v4\")]
2392 pub mod avx512 {{\n"
2393 )?;
2394
2395 {
2396 write!(
2397 code,
2398 "pub const MR_DIV_N: usize = 2; pub const NR: usize = 2; pub const N: usize = 8;"
2399 )?;
2400 write!(
2401 code,
2402 "const XOR_MASKS: [crate::x86::__m512; 4] = unsafe {{ core::mem::transmute([
2403 [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0f32], // no conj
2404 [ 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0f32], // conj lhs
2405 [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0f32], // conj rhs
2406 [-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0f32], // conj both
2407 ]) }};\n"
2408 )?;
2409
2410 for mr_div_n in 1..=2 {
2411 for nr in 1..=2 {
2412 for k in (1..=16).into_iter().map(Some).chain([None]) {
2413 let kernel = CplxKernel {
2414 need_mask: true,
2415 ty: "f32",
2416 reg_ty: "__m512",
2417 mask_ty: "u16",
2418 mr_div_n,
2419 nr,
2420 k,
2421 target_features: "avx512f",
2422 n: 8,
2423 set1: "_mm512_set1_ps",
2424 load_unaligned: "_mm512_loadu_ps",
2425 store_unaligned: "_mm512_storeu_ps",
2426 mask_load_unaligned: Box::new(|ptr, mask| {
2427 format!("_mm512_maskz_loadu_ps({mask}, {ptr})")
2428 }),
2429 mask_store_unaligned: "_mm512_mask_storeu_ps",
2430 swap_re_im: "_mm512_permute_ps::<0b10_11_00_01>",
2431 mul_addsub: "crate::x86::subadd_ps",
2432 mul_subadd: "_mm512_fmaddsub_ps",
2433 xor: "_mm512_xor_si512",
2434 };
2435
2436 write!(code, "{kernel}")?;
2437 }
2438 }
2439 }
2440
2441 write!(
2442 code,
2443 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f32>>; 2]; 2]; 17] = [\n"
2444 )?;
2445 for k in (1..=16).into_iter().map(Some).chain([None]) {
2446 write!(code, "[\n")?;
2447 for mr_div_n in 1..=2 {
2448 write!(code, "[\n")?;
2449 for nr in 1..=2 {
2450 write!(
2451 code,
2452 "matmul_{mr_div_n}_{nr}_{},",
2453 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2454 )?;
2455 }
2456 write!(code, "],\n")?;
2457 }
2458 write!(code, "],\n")?;
2459 }
2460 write!(code, "];\n")?;
2461 write!(
2462 code,
2463 "
2464 pub static MASKS: [u16; 8] = [
2465 0b1111_1111_1111_1111,
2466 0b0000_0000_0000_0011,
2467 0b0000_0000_0000_1111,
2468 0b0000_0000_0011_1111,
2469 0b0000_0000_1111_1111,
2470 0b0000_0011_1111_1111,
2471 0b0000_1111_1111_1111,
2472 0b0011_1111_1111_1111,
2473 ];
2474 "
2475 )?;
2476 }
2477 write!(code, "}}\n")?;
2478 write!(code, "}}\n")?;
2479
2480 Ok(code)
2481 }
2482
2483 pub fn codegen_c64() -> Result<String, Box<dyn std::error::Error>> {
2484 let mut code = String::new();
2485
2486 write!(code, "pub mod c64 {{\n")?;
2487 write!(code, "pub mod c64x1 {{\n")?;
2488 {
2489 write!(
2490 code,
2491 "pub const MR_DIV_N: usize = 1; pub const NR: usize = 2; pub const N: usize = 1;"
2492 )?;
2493 write!(
2494 code,
2495 "const XOR_MASKS: [crate::x86::__m128d; 4] = unsafe {{ core::mem::transmute([
2496 [-0.0, -0.0f64], // no conj
2497 [ 0.0, -0.0f64], // conj lhs
2498 [ 0.0, 0.0f64], // conj rhs
2499 [-0.0, 0.0f64], // conj both
2500 ]) }};\n"
2501 )?;
2502
2503 for mr_div_n in 1..=1 {
2504 for nr in 1..=2 {
2505 for k in (1..=16).into_iter().map(Some).chain([None]) {
2506 let kernel = CplxKernel {
2507 need_mask: false,
2508 ty: "f64",
2509 reg_ty: "__m128d",
2510 mask_ty: "__m128i",
2511 mr_div_n,
2512 nr,
2513 k,
2514 target_features: "avx,avx2,fma",
2515 n: 1,
2516 set1: "_mm_set1_pd",
2517 load_unaligned: "_mm_loadu_pd",
2518 store_unaligned: "_mm_storeu_pd",
2519 mask_load_unaligned: Box::new(|ptr, mask| {
2520 format!("_mm_maskload_pd({ptr}, {mask})")
2521 }),
2522 mask_store_unaligned: "_mm_maskstore_pd",
2523 swap_re_im: "_mm_permute_pd::<0b01>",
2524 mul_addsub: "_mm_fmsubadd_pd",
2525 mul_subadd: "_mm_fmaddsub_pd",
2526 xor: "_mm_xor_pd",
2527 };
2528
2529 write!(code, "{kernel}")?;
2530 }
2531 }
2532 }
2533
2534 write!(
2535 code,
2536 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f64>>; 2]; 1]; 17] = [\n"
2537 )?;
2538 for k in (1..=16).into_iter().map(Some).chain([None]) {
2539 write!(code, "[\n")?;
2540 for mr_div_n in 1..=1 {
2541 write!(code, "[\n")?;
2542 for nr in 1..=2 {
2543 write!(
2544 code,
2545 "matmul_{mr_div_n}_{nr}_{},",
2546 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2547 )?;
2548 }
2549 write!(code, "],\n")?;
2550 }
2551 write!(code, "],\n")?;
2552 }
2553 write!(code, "];\n")?;
2554 }
2555 write!(code, "}}\n")?;
2556
2557 write!(code, "pub mod avx {{\n")?;
2558 {
2559 write!(
2560 code,
2561 "pub const MR_DIV_N: usize = 2; pub const NR: usize = 2; pub const N: usize = 2;"
2562 )?;
2563 write!(
2564 code,
2565 "const XOR_MASKS: [crate::x86::__m256d; 4] = unsafe {{ core::mem::transmute([
2566 [-0.0, -0.0, -0.0, -0.0f64], // no conj
2567 [ 0.0, -0.0, 0.0, -0.0f64], // conj lhs
2568 [ 0.0, 0.0, 0.0, 0.0f64], // conj rhs
2569 [-0.0, 0.0, -0.0, 0.0f64], // conj both
2570 ]) }};\n"
2571 )?;
2572
2573 for mr_div_n in 1..=2 {
2574 for nr in 1..=2 {
2575 for k in (1..=16).into_iter().map(Some).chain([None]) {
2576 let kernel = CplxKernel {
2577 need_mask: true,
2578 ty: "f64",
2579 reg_ty: "__m256d",
2580 mask_ty: "__m256i",
2581 mr_div_n,
2582 nr,
2583 k,
2584 target_features: "avx,avx2,fma",
2585 n: 2,
2586 set1: "_mm256_set1_pd",
2587 load_unaligned: "_mm256_loadu_pd",
2588 store_unaligned: "_mm256_storeu_pd",
2589 mask_load_unaligned: Box::new(|ptr, mask| {
2590 format!("_mm256_maskload_pd({ptr}, {mask})")
2591 }),
2592 mask_store_unaligned: "_mm256_maskstore_pd",
2593 swap_re_im: "_mm256_permute_pd::<0b0101>",
2594 mul_addsub: "_mm256_fmsubadd_pd",
2595 mul_subadd: "_mm256_fmaddsub_pd",
2596 xor: "_mm256_xor_pd",
2597 };
2598
2599 write!(code, "{kernel}")?;
2600 }
2601 }
2602 }
2603
2604 write!(
2605 code,
2606 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f64>>; 2]; 2]; 17] = [\n"
2607 )?;
2608 for k in (1..=16).into_iter().map(Some).chain([None]) {
2609 write!(code, "[\n")?;
2610 for mr_div_n in 1..=2 {
2611 write!(code, "[\n")?;
2612 for nr in 1..=2 {
2613 write!(
2614 code,
2615 "matmul_{mr_div_n}_{nr}_{},",
2616 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2617 )?;
2618 }
2619 write!(code, "],\n")?;
2620 }
2621 write!(code, "],\n")?;
2622 }
2623 write!(code, "];\n")?;
2624 write!(
2625 code,
2626 "
2627 pub static MASKS: [crate::x86::__m256i; 2] = unsafe {{ core::mem::transmute([
2628 [u64::MAX, u64::MAX, u64::MAX, u64::MAX],
2629
2630 [u64::MAX, u64::MAX, 0, 0],
2631 ]) }};
2632 "
2633 )?;
2634 }
2635 write!(code, "}}\n")?;
2636 write!(
2637 code,
2638 "
2639 #[cfg(feature = \"x86-v4\")]
2640 pub mod avx512 {{\n"
2641 )?;
2642
2643 {
2644 write!(
2645 code,
2646 "pub const MR_DIV_N: usize = 2; pub const NR: usize = 2; pub const N: usize = 4;"
2647 )?;
2648 write!(
2649 code,
2650 "const XOR_MASKS: [crate::x86::__m512; 4] = unsafe {{ core::mem::transmute([
2651 [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0f64], // no conj
2652 [ 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0f64], // conj lhs
2653 [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0f64], // conj rhs
2654 [-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0f64], // conj both
2655 ]) }};\n"
2656 )?;
2657
2658 for mr_div_n in 1..=2 {
2659 for nr in 1..=2 {
2660 for k in (1..=16).into_iter().map(Some).chain([None]) {
2661 let kernel = CplxKernel {
2662 need_mask: true,
2663 ty: "f64",
2664 reg_ty: "__m512d",
2665 mask_ty: "u8",
2666 mr_div_n,
2667 nr,
2668 k,
2669 target_features: "avx512f",
2670 n: 4,
2671 set1: "_mm512_set1_pd",
2672 load_unaligned: "_mm512_loadu_pd",
2673 store_unaligned: "_mm512_storeu_pd",
2674 mask_load_unaligned: Box::new(|ptr, mask| {
2675 format!("_mm512_maskz_loadu_pd({mask}, {ptr})")
2676 }),
2677 mask_store_unaligned: "_mm512_mask_storeu_pd",
2678 swap_re_im: "_mm512_permute_pd::<0b01010101>",
2679 mul_addsub: "crate::x86::subadd_pd",
2680 mul_subadd: "_mm512_fmaddsub_pd",
2681 xor: "_mm512_xor_si512",
2682 };
2683
2684 write!(code, "{kernel}")?;
2685 }
2686 }
2687 }
2688
2689 write!(
2690 code,
2691 "pub static MICROKERNELS: [[[nano_gemm_core::MicroKernel<num_complex::Complex<f64>>; 2]; 2]; 17] = [\n"
2692 )?;
2693 for k in (1..=16).into_iter().map(Some).chain([None]) {
2694 write!(code, "[\n")?;
2695 for mr_div_n in 1..=2 {
2696 write!(code, "[\n")?;
2697 for nr in 1..=2 {
2698 write!(
2699 code,
2700 "matmul_{mr_div_n}_{nr}_{},",
2701 k.map(|k| k.to_string()).unwrap_or("dyn".to_string()),
2702 )?;
2703 }
2704 write!(code, "],\n")?;
2705 }
2706 write!(code, "],\n")?;
2707 }
2708 write!(code, "];\n")?;
2709 write!(
2710 code,
2711 "
2712 pub static MASKS: [u8; 4] = [
2713 0b1111_1111,
2714 0b0000_0011,
2715 0b0000_1111,
2716 0b0011_1111,
2717 ];
2718 "
2719 )?;
2720 }
2721 write!(code, "}}\n")?;
2722 write!(code, "}}\n")?;
2723
2724 Ok(code)
2725 }
2726}