#![cfg_attr(not(feature = "std"), no_std)]
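//! Small fixed-shape matrix multiplication kernels.
//!
//! A [`Plan`] selects, ahead of time, the microkernel/millikernel pair for a given
//! `(m, n, k)` shape and layout; [`Plan::execute_unchecked`] then computes
//! `dst = alpha * dst + beta * lhs * rhs` on raw pointers and strides. The
//! [`planless`] module offers one-shot entry points that build a throwaway plan
//! internally.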

use core::mem::MaybeUninit;
use equator::debug_assert;

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
pub mod x86 {
    pub use nano_gemm_c32::x86::*;
    pub use nano_gemm_c64::x86::*;
    pub use nano_gemm_f32::x86::*;
    pub use nano_gemm_f64::x86::*;
}
#[cfg(target_arch = "aarch64")]
pub mod aarch64 {
    pub use nano_gemm_c32::aarch64::*;
    pub use nano_gemm_c64::aarch64::*;
    pub use nano_gemm_f32::aarch64::*;
    pub use nano_gemm_f64::aarch64::*;
}

#[allow(non_camel_case_types)]
pub type c32 = num_complex::Complex32;
#[allow(non_camel_case_types)]
pub type c64 = num_complex::Complex64;

pub use nano_gemm_core::*;

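/// Precomputed multiply plan for a fixed `(m, n, k)` shape.
///
/// `microkernels` is a 2×2 table indexed by (is-last-row-tile, is-last-column-tile),
/// and `millikernel` walks the output in `mr × nr` tiles, dispatching into that
/// table. Stride fields recorded as `isize::MIN` are unconstrained; any other value
/// is checked against the strides passed to [`Plan::execute_unchecked`].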
#[derive(Copy, Clone)]
pub struct Plan<T> {
    microkernels: [[MaybeUninit<MicroKernel<T>>; 2]; 2],
    millikernel: unsafe fn(
        microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
        mr: usize,
        nr: usize,
        m: usize,
        n: usize,
        k: usize,
        dst: *mut T,
        dst_rs: isize,
        dst_cs: isize,
        lhs: *const T,
        lhs_rs: isize,
        lhs_cs: isize,
        rhs: *const T,
        rhs_rs: isize,
        rhs_cs: isize,
        alpha: T,
        beta: T,
        conj_lhs: bool,
        conj_rhs: bool,
        full_mask: *const (),
        last_mask: *const (),
    ),
    mr: usize,
    nr: usize,
    full_mask: *const (),
    last_mask: *const (),
    m: usize,
    n: usize,
    k: usize,
    dst_cs: isize,
    dst_rs: isize,
    lhs_cs: isize,
    lhs_rs: isize,
    rhs_cs: isize,
    rhs_rs: isize,
}

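// Millikernel for degenerate shapes (`m == 0` or `n == 0`): there is nothing to
// write, so every argument is ignored.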
#[allow(unused_variables)]
unsafe fn noop_millikernel<T: Copy>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
}

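// Scalar reference millikernel: computes `dst = alpha * dst + beta * op(lhs) * op(rhs)`
// one element at a time, where `op` is optional conjugation. This is the path used by
// the scalar fallback plans, which never touch the microkernel table.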
#[allow(unused_variables)]
unsafe fn naive_millikernel<
    T: Copy + core::ops::Mul<Output = T> + core::ops::Add<Output = T> + PartialEq + Conj,
>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
    let zero: T = core::mem::zeroed();
    if alpha == zero {
        for j in 0..n {
            for i in 0..m {
                let mut acc = zero;
                for depth in 0..k {
                    let lhs = *lhs.offset(lhs_rs * i as isize + lhs_cs * depth as isize);
                    let rhs = *rhs.offset(rhs_rs * depth as isize + rhs_cs * j as isize);
                    acc = acc
                        + if conj_lhs { lhs.conj() } else { lhs }
                            * if conj_rhs { rhs.conj() } else { rhs };
                }
                *dst.offset(dst_rs * i as isize + dst_cs * j as isize) = beta * acc;
            }
        }
    } else {
        for j in 0..n {
            for i in 0..m {
                let mut acc = zero;
                for depth in 0..k {
                    let lhs = *lhs.offset(lhs_rs * i as isize + lhs_cs * depth as isize);
                    let rhs = *rhs.offset(rhs_rs * depth as isize + rhs_cs * j as isize);
                    acc = acc
                        + if conj_lhs { lhs.conj() } else { lhs }
                            * if conj_rhs { rhs.conj() } else { rhs };
                }
                let dst = dst.offset(dst_rs * i as isize + dst_cs * j as isize);
                *dst = alpha * *dst + beta * acc;
            }
        }
    }
}

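// Millikernel for `k == 0`: the product term vanishes, so the destination is either
// zero-filled (`alpha == 0`) or scaled in place by `alpha`.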
#[allow(unused_variables)]
unsafe fn fill_millikernel<T: Copy + PartialEq + core::ops::Mul<Output = T>>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
    let zero: T = core::mem::zeroed();
    if alpha == zero {
        for j in 0..n {
            for i in 0..m {
                *dst.offset(dst_rs * i as isize + dst_cs * j as isize) = core::mem::zeroed();
            }
        }
    } else {
        for j in 0..n {
            for i in 0..m {
                let dst = dst.offset(dst_rs * i as isize + dst_cs * j as isize);
                *dst = alpha * *dst;
            }
        }
    }
}

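// Column-major millikernel monomorphized over the tile counts: `M_DIVCEIL_MR` and
// `N_DIVCEIL_NR` are `ceil(m / mr)` and `ceil(n / nr)` (1 or 2 where this is
// instantiated), so both loops below have compile-time trip counts.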
#[inline(always)]
unsafe fn small_direct_millikernel<
    T: Copy,
    const M_DIVCEIL_MR: usize,
    const N_DIVCEIL_NR: usize,
>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
    _ = (m, n);
    debug_assert!(all(lhs_rs == 1, dst_rs == 1));

    let mut data = MicroKernelData {
        alpha,
        beta,
        conj_lhs,
        conj_rhs,
        k,
        dst_cs,
        lhs_cs,
        rhs_rs,
        rhs_cs,
        last_mask,
    };

    let mut i = 0usize;
    while i < M_DIVCEIL_MR {
        data.last_mask = if i + 1 < M_DIVCEIL_MR {
            full_mask
        } else {
            last_mask
        };

        let microkernels = microkernels.get_unchecked((i + 1 >= M_DIVCEIL_MR) as usize);
        {
            let i = i * mr;
            let dst = dst.offset(i as isize);

            let mut j = 0usize;
            while j < N_DIVCEIL_NR {
                let microkernel = microkernels
                    .get_unchecked((j + 1 >= N_DIVCEIL_NR) as usize)
                    .assume_init();

                {
                    let j = j * nr;
                    microkernel(
                        &data,
                        dst.offset(j as isize * dst_cs),
                        lhs.offset(i as isize),
                        rhs.offset(j as isize * rhs_cs),
                    );
                }

                j += 1;
            }
        }
        i += 1;
    }
}

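// General column-major millikernel: `dst` and `lhs` must have unit row stride. The
// output is walked in `mr × nr` tiles; an edge tile in either dimension switches to
// the second row/column of the microkernel table and to the partial mask.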
unsafe fn direct_millikernel<T: Copy>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
    debug_assert!(all(lhs_rs == 1, dst_rs == 1));

    let mut data = MicroKernelData {
        alpha,
        beta,
        conj_lhs,
        conj_rhs,
        k,
        dst_cs,
        lhs_cs,
        rhs_rs,
        rhs_cs,
        last_mask,
    };

    let mut i = 0usize;
    while i < m {
        data.last_mask = if i + mr <= m { full_mask } else { last_mask };
        let microkernels = microkernels.get_unchecked((i + mr > m) as usize);
        let dst = dst.offset(i as isize);

        let mut j = 0usize;
        while j < n {
            let microkernel = microkernels
                .get_unchecked((j + nr > n) as usize)
                .assume_init();

            microkernel(
                &data,
                dst.offset(j as isize * dst_cs),
                lhs.offset(i as isize),
                rhs.offset(j as isize * rhs_cs),
            );

            j += nr;
        }

        i += mr;
    }
}

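// Minimal numeric helpers: `One` provides the multiplicative identity and `Conj`
// provides (possibly trivial) conjugation for the element types used above.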
trait One {
    const ONE: Self;
}
trait Conj {
    fn conj(self) -> Self;
}

impl One for f32 {
    const ONE: Self = 1.0;
}
impl One for f64 {
    const ONE: Self = 1.0;
}
impl One for c32 {
    const ONE: Self = Self { re: 1.0, im: 0.0 };
}
impl One for c64 {
    const ONE: Self = Self { re: 1.0, im: 0.0 };
}

impl Conj for f32 {
    #[inline]
    fn conj(self) -> Self {
        self
    }
}
impl Conj for f64 {
    #[inline]
    fn conj(self) -> Self {
        self
    }
}

impl Conj for c32 {
    #[inline]
    fn conj(self) -> Self {
        Self::conj(&self)
    }
}
impl Conj for c64 {
    #[inline]
    fn conj(self) -> Self {
        Self::conj(&self)
    }
}

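// Millikernel for layouts the direct path cannot handle (non-unit row stride on `dst`
// or `lhs`). It tiles the problem into 64×64×64 blocks, packs the current `lhs` block
// (and, when `dst_rs != 1`, the `dst` block) into contiguous column-major scratch
// buffers, runs `direct_millikernel` on those, and scatters the result back out with
// the original strides. After the first depth block, `alpha` is forced to one so
// later blocks accumulate instead of overwriting.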
unsafe fn copy_millikernel<
    T: Copy + PartialEq + core::ops::Add<Output = T> + core::ops::Mul<Output = T> + Conj + One,
>(
    microkernels: &[[MaybeUninit<MicroKernel<T>>; 2]; 2],
    mr: usize,
    nr: usize,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const T,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const T,
    rhs_rs: isize,
    rhs_cs: isize,
    mut alpha: T,
    beta: T,
    conj_lhs: bool,
    conj_rhs: bool,
    full_mask: *const (),
    last_mask: *const (),
) {
    if dst_rs == 1 && lhs_rs == 1 {
        let gemm_dst = dst;
        let gemm_lhs = lhs;
        let gemm_dst_cs = dst_cs;
        let gemm_lhs_cs = lhs_cs;

        direct_millikernel(
            microkernels,
            mr,
            nr,
            m,
            n,
            k,
            gemm_dst,
            1,
            gemm_dst_cs,
            gemm_lhs,
            1,
            gemm_lhs_cs,
            rhs,
            rhs_rs,
            rhs_cs,
            alpha,
            beta,
            conj_lhs,
            conj_rhs,
            full_mask,
            last_mask,
        );
    } else {
        const M_BS: usize = 64;
        const N_BS: usize = 64;
        const K_BS: usize = 64;
        let mut dst_tmp: MaybeUninit<[T; M_BS * N_BS]> = core::mem::MaybeUninit::uninit();
        let mut lhs_tmp: MaybeUninit<[T; M_BS * K_BS]> = core::mem::MaybeUninit::uninit();

        let dst_tmp = &mut *((&mut dst_tmp) as *mut _ as *mut [[MaybeUninit<T>; M_BS]; N_BS]);
        let lhs_tmp = &mut *((&mut lhs_tmp) as *mut _ as *mut [[MaybeUninit<T>; M_BS]; K_BS]);

        let gemm_dst = if dst_rs == 1 {
            dst
        } else {
            dst_tmp.as_mut_ptr() as *mut T
        };
        let gemm_lhs = lhs_tmp.as_mut_ptr() as *mut T;
        let gemm_dst_cs = if dst_rs == 1 { dst_cs } else { M_BS as isize };
        let gemm_lhs_cs = M_BS as isize;

        let mut depth = 0usize;
        while depth < k {
            let depth_bs = Ord::min(K_BS, k - depth);

            let mut i = 0usize;
            while i < m {
                let i_bs = Ord::min(M_BS, m - i);

                let lhs = lhs.offset(lhs_rs * i as isize + lhs_cs * depth as isize);

                for ii in 0..i_bs {
                    for jj in 0..depth_bs {
                        let ii = ii as isize;
                        let jj = jj as isize;
                        *(gemm_lhs.offset(ii + gemm_lhs_cs * jj)) =
                            *(lhs.offset(lhs_rs * ii + lhs_cs * jj));
                    }
                }

                let mut j = 0usize;
                while j < n {
                    let j_bs = Ord::min(N_BS, n - j);

                    let rhs = rhs.offset(rhs_rs * depth as isize + rhs_cs * j as isize);

                    let dst = dst.offset(dst_rs * i as isize + dst_cs * j as isize);
                    let gemm_dst = if dst_rs == 1 {
                        gemm_dst.offset(i as isize + gemm_dst_cs * j as isize)
                    } else {
                        gemm_dst
                    };

                    direct_millikernel(
                        microkernels,
                        mr,
                        nr,
                        i_bs,
                        j_bs,
                        depth_bs,
                        gemm_dst,
                        1,
                        gemm_dst_cs,
                        gemm_lhs,
                        1,
                        gemm_lhs_cs,
                        rhs,
                        rhs_rs,
                        rhs_cs,
                        if dst_rs == 1 {
                            alpha
                        } else {
                            core::mem::zeroed()
                        },
                        beta,
                        conj_lhs,
                        conj_rhs,
                        full_mask,
                        if i + i_bs == m { last_mask } else { full_mask },
                    );

                    if dst_rs != 1 {
                        if alpha == core::mem::zeroed() {
                            for ii in 0..i_bs {
                                for jj in 0..j_bs {
                                    let ii = ii as isize;
                                    let jj = jj as isize;
                                    *(dst.offset(dst_rs * ii + dst_cs * jj)) =
                                        *(gemm_dst.offset(ii + gemm_dst_cs * jj));
                                }
                            }
                        } else {
                            for ii in 0..i_bs {
                                for jj in 0..j_bs {
                                    let ii = ii as isize;
                                    let jj = jj as isize;
                                    let dst = dst.offset(dst_rs * ii + dst_cs * jj);
                                    *dst = alpha * *dst + *(gemm_dst.offset(ii + gemm_dst_cs * jj));
                                }
                            }
                        }
                    }

                    j += j_bs;
                }

                i += i_bs;
            }

            alpha = T::ONE;
            depth += depth_bs;
        }
    }
}

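// Plan construction helpers. The per-ISA tables are indexed as
// `const_microkernels[min(k, 17) - 1][row_tile][col_tile]`, and the 2×2 table kept in
// the plan stores the (full, full), (full, edge), (edge, full) and (edge, edge)
// variants for the requested shape. The `wrapping_sub(1)` calls are fine for
// zero-sized problems because those are routed to the noop/fill millikernels, which
// never read the microkernel table.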
impl<T> Plan<T> {
    #[allow(dead_code)]
    #[inline(always)]
    fn from_masked_impl<const MR_DIV_N: usize, const NR: usize, const N: usize, Mask>(
        const_microkernels: &[[[MicroKernel<T>; NR]; MR_DIV_N]; 17],
        const_masks: Option<&[Mask; N]>,
        m: usize,
        n: usize,
        k: usize,
        is_col_major: bool,
    ) -> Self
    where
        T: Copy + PartialEq + core::ops::Add<Output = T> + core::ops::Mul<Output = T> + Conj + One,
    {
        let mut microkernels = [[MaybeUninit::<MicroKernel<T>>::uninit(); 2]; 2];

        let mr = MR_DIV_N * N;
        let nr = NR;

        {
            let k = Ord::min(k.wrapping_sub(1), 16);
            let m = (m.wrapping_sub(1) / N) % (mr / N);
            let n = n.wrapping_sub(1) % nr;

            microkernels[0][0].write(const_microkernels[k][MR_DIV_N - 1][NR - 1]);
            microkernels[0][1].write(const_microkernels[k][MR_DIV_N - 1][n]);
            microkernels[1][0].write(const_microkernels[k][m][NR - 1]);
            microkernels[1][1].write(const_microkernels[k][m][n]);
        }

        Self {
            microkernels,
            millikernel: if m == 0 || n == 0 {
                noop_millikernel
            } else if k == 0 {
                fill_millikernel
            } else if is_col_major {
                if m <= mr && n <= nr {
                    small_direct_millikernel::<_, 1, 1>
                } else if m <= mr && n <= 2 * nr {
                    small_direct_millikernel::<_, 1, 2>
                } else if m <= 2 * mr && n <= nr {
                    small_direct_millikernel::<_, 2, 1>
                } else if m <= 2 * mr && n <= 2 * nr {
                    small_direct_millikernel::<_, 2, 2>
                } else {
                    direct_millikernel
                }
            } else {
                copy_millikernel
            },
            mr,
            nr,
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
            full_mask: if let Some(const_masks) = const_masks {
                (&const_masks[0]) as *const _ as *const ()
            } else {
                &()
            },
            last_mask: if let Some(const_masks) = const_masks {
                (&const_masks[m % N]) as *const _ as *const ()
            } else {
                &()
            },
        }
    }

    #[allow(dead_code)]
    #[inline(always)]
    fn from_non_masked_impl<const MR: usize, const NR: usize>(
        const_microkernels: &[[[MicroKernel<T>; NR]; MR]; 17],
        m: usize,
        n: usize,
        k: usize,
        is_col_major: bool,
    ) -> Self
    where
        T: Copy + PartialEq + core::ops::Add<Output = T> + core::ops::Mul<Output = T> + Conj + One,
    {
        let mut microkernels = [[MaybeUninit::<MicroKernel<T>>::uninit(); 2]; 2];

        let mr = MR;
        let nr = NR;

        {
            let k = Ord::min(k.wrapping_sub(1), 16);
            let m = m.wrapping_sub(1) % mr;
            let n = n.wrapping_sub(1) % nr;

            microkernels[0][0].write(const_microkernels[k][MR - 1][NR - 1]);
            microkernels[0][1].write(const_microkernels[k][MR - 1][n]);
            microkernels[1][0].write(const_microkernels[k][m][NR - 1]);
            microkernels[1][1].write(const_microkernels[k][m][n]);
        }

        Self {
            microkernels,
            millikernel: if m == 0 || n == 0 {
                noop_millikernel
            } else if k == 0 {
                fill_millikernel
            } else if is_col_major {
                if m <= mr && n <= nr {
                    small_direct_millikernel::<_, 1, 1>
                } else if m <= mr && n <= 2 * nr {
                    small_direct_millikernel::<_, 1, 2>
                } else if m <= 2 * mr && n <= nr {
                    small_direct_millikernel::<_, 2, 1>
                } else if m <= 2 * mr && n <= 2 * nr {
                    small_direct_millikernel::<_, 2, 2>
                } else {
                    direct_millikernel
                }
            } else {
                copy_millikernel
            },
            mr,
            nr,
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
            full_mask: &(),
            last_mask: &(),
        }
    }
}

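// Scalar fallback plans: they leave the microkernel table uninitialized and run
// everything through `naive_millikernel`, which ignores it.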
impl Plan<f32> {
    fn new_f32_scalar(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        Self {
            microkernels: [[MaybeUninit::<MicroKernel<f32>>::uninit(); 2]; 2],
            millikernel: naive_millikernel,
            mr: 0,
            nr: 0,
            full_mask: core::ptr::null(),
            last_mask: core::ptr::null(),
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
        }
    }
}
impl Plan<f64> {
    fn new_f64_scalar(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        Self {
            microkernels: [[MaybeUninit::<MicroKernel<f64>>::uninit(); 2]; 2],
            millikernel: naive_millikernel,
            mr: 0,
            nr: 0,
            full_mask: core::ptr::null(),
            last_mask: core::ptr::null(),
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
        }
    }
}
impl Plan<c32> {
    fn new_c32_scalar(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        Self {
            microkernels: [[MaybeUninit::<MicroKernel<c32>>::uninit(); 2]; 2],
            millikernel: naive_millikernel,
            mr: 0,
            nr: 0,
            full_mask: core::ptr::null(),
            last_mask: core::ptr::null(),
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
        }
    }
}
impl Plan<c64> {
    fn new_c64_scalar(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        Self {
            microkernels: [[MaybeUninit::<MicroKernel<c64>>::uninit(); 2]; 2],
            millikernel: naive_millikernel,
            mr: 0,
            nr: 0,
            full_mask: core::ptr::null(),
            last_mask: core::ptr::null(),
            m,
            n,
            k,
            dst_rs: if is_col_major { 1 } else { isize::MIN },
            dst_cs: isize::MIN,
            lhs_rs: if is_col_major { 1 } else { isize::MIN },
            lhs_cs: isize::MIN,
            rhs_cs: isize::MIN,
            rhs_rs: isize::MIN,
        }
    }
}

impl<T> Plan<T> {
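    /// Runs the plan on raw pointers: `dst = alpha * dst + beta * lhs * rhs`, with
    /// optional conjugation of either operand.
    ///
    /// # Safety
    ///
    /// The shape must match the one the plan was built with, the strides must be
    /// consistent with any layout promised at construction time (checked by the
    /// debug assertions below), and all pointers must be valid for the implied
    /// accesses.
    ///
    /// A minimal column-major usage sketch (the `nano_gemm` crate path is an
    /// assumption here):
    ///
    /// ```ignore
    /// let (m, n, k) = (3, 3, 3);
    /// let lhs = vec![1.0f32; m * k];
    /// let rhs = vec![1.0f32; k * n];
    /// let mut dst = vec![0.0f32; m * n];
    ///
    /// let plan = nano_gemm::Plan::new_colmajor_lhs_and_dst_f32(m, n, k);
    /// unsafe {
    ///     plan.execute_unchecked(
    ///         m, n, k,
    ///         dst.as_mut_ptr(), 1, m as isize,
    ///         lhs.as_ptr(), 1, m as isize,
    ///         rhs.as_ptr(), 1, k as isize,
    ///         0.0, 1.0, // alpha, beta
    ///         false, false,
    ///     );
    /// }
    /// // every entry of `dst` is now 3.0
    /// ```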
    #[inline(always)]
    pub unsafe fn execute_unchecked(
        &self,
        m: usize,
        n: usize,
        k: usize,
        dst: *mut T,
        dst_rs: isize,
        dst_cs: isize,
        lhs: *const T,
        lhs_rs: isize,
        lhs_cs: isize,
        rhs: *const T,
        rhs_rs: isize,
        rhs_cs: isize,
        alpha: T,
        beta: T,
        conj_lhs: bool,
        conj_rhs: bool,
    ) {
        debug_assert!(m == self.m);
        debug_assert!(n == self.n);
        debug_assert!(k == self.k);
        if self.dst_cs != isize::MIN {
            debug_assert!(dst_cs == self.dst_cs);
        }
        if self.dst_rs != isize::MIN {
            debug_assert!(dst_rs == self.dst_rs);
        }
        if self.lhs_cs != isize::MIN {
            debug_assert!(lhs_cs == self.lhs_cs);
        }
        if self.lhs_rs != isize::MIN {
            debug_assert!(lhs_rs == self.lhs_rs);
        }
        if self.rhs_cs != isize::MIN {
            debug_assert!(rhs_cs == self.rhs_cs);
        }
        if self.rhs_rs != isize::MIN {
            debug_assert!(rhs_rs == self.rhs_rs);
        }

        (self.millikernel)(
            &self.microkernels,
            self.mr,
            self.nr,
            m,
            n,
            k,
            dst,
            dst_rs,
            dst_cs,
            lhs,
            lhs_rs,
            lhs_cs,
            rhs,
            rhs_rs,
            rhs_cs,
            alpha,
            beta,
            conj_lhs,
            conj_rhs,
            self.full_mask,
            self.last_mask,
        );
    }
}

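// Runtime plan constructors. With the `std` feature they probe CPU features and pick
// the widest microkernel family available (optionally AVX-512 behind the `x86-v4`
// feature), otherwise they fall back to the scalar plans above. `is_col_major == true`
// records that `dst` and `lhs` will have unit row stride, which unlocks the direct
// millikernels.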
impl Plan<f32> {
    #[track_caller]
    pub fn new_f32_impl(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        #[cfg(feature = "std")]
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "x86-v4")]
            if m > 8 && std::is_x86_feature_detected!("avx512f") {
                return Self::new_f32_avx512(m, n, k, is_col_major);
            }

            if std::is_x86_feature_detected!("avx2") {
                if m == 1 {
                    return Self::new_f32x1(m, n, k, is_col_major);
                }
                if m == 2 {
                    return Self::new_f32x2(m, n, k, is_col_major);
                }
                if m <= 4 {
                    return Self::new_f32x4(m, n, k, is_col_major);
                }

                return Self::new_f32_avx(m, n, k, is_col_major);
            }
        }
        #[cfg(feature = "std")]
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                return Self::from_non_masked_impl(
                    &aarch64::f32::neon::MICROKERNELS,
                    m,
                    n,
                    k,
                    is_col_major,
                );
            }
        }

        Self::new_f32_scalar(m, n, k, is_col_major)
    }

    #[track_caller]
    pub fn new_colmajor_lhs_and_dst_f32(m: usize, n: usize, k: usize) -> Self {
        Self::new_f32_impl(m, n, k, true)
    }

    #[track_caller]
    pub fn new_f32(m: usize, n: usize, k: usize) -> Self {
        Self::new_f32_impl(m, n, k, false)
    }
}

impl Plan<f64> {
    #[track_caller]
    pub fn new_f64_impl(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        #[cfg(feature = "std")]
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "x86-v4")]
            if m > 4 && std::is_x86_feature_detected!("avx512f") {
                return Self::new_f64_avx512(m, n, k, is_col_major);
            }

            if std::is_x86_feature_detected!("avx2") {
                if m == 1 {
                    return Self::new_f64x1(m, n, k, is_col_major);
                }
                if m == 2 {
                    return Self::new_f64x2(m, n, k, is_col_major);
                }

                return Self::new_f64_avx(m, n, k, is_col_major);
            }
        }

        #[cfg(feature = "std")]
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon") {
                return Self::from_non_masked_impl(
                    &aarch64::f64::neon::MICROKERNELS,
                    m,
                    n,
                    k,
                    is_col_major,
                );
            }
        }

        Self::new_f64_scalar(m, n, k, is_col_major)
    }

    #[track_caller]
    pub fn new_colmajor_lhs_and_dst_f64(m: usize, n: usize, k: usize) -> Self {
        Self::new_f64_impl(m, n, k, true)
    }

    #[track_caller]
    pub fn new_f64(m: usize, n: usize, k: usize) -> Self {
        Self::new_f64_impl(m, n, k, false)
    }
}

impl Plan<c32> {
    #[track_caller]
    pub fn new_c32_impl(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        #[cfg(feature = "std")]
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "x86-v4")]
            if m > 4 && std::is_x86_feature_detected!("avx512f") {
                return Self::new_c32_avx512(m, n, k, is_col_major);
            }

            if std::is_x86_feature_detected!("avx2") {
                if m == 1 {
                    return Self::new_c32x1(m, n, k, is_col_major);
                }
                if m == 2 {
                    return Self::new_c32x2(m, n, k, is_col_major);
                }

                return Self::new_c32_avx(m, n, k, is_col_major);
            }
        }

        #[cfg(feature = "std")]
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon")
                && std::arch::is_aarch64_feature_detected!("fcma")
            {
                return Self::from_non_masked_impl(
                    &aarch64::c32::neon::MICROKERNELS,
                    m,
                    n,
                    k,
                    is_col_major,
                );
            }
        }

        Self::new_c32_scalar(m, n, k, is_col_major)
    }

    #[track_caller]
    pub fn new_colmajor_lhs_and_dst_c32(m: usize, n: usize, k: usize) -> Self {
        Self::new_c32_impl(m, n, k, true)
    }

    #[track_caller]
    pub fn new_c32(m: usize, n: usize, k: usize) -> Self {
        Self::new_c32_impl(m, n, k, false)
    }
}

impl Plan<c64> {
    #[track_caller]
    pub fn new_c64_impl(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
        #[cfg(feature = "std")]
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "x86-v4")]
            if m > 2 && std::is_x86_feature_detected!("avx512f") {
                return Self::new_c64_avx512(m, n, k, is_col_major);
            }

            if std::is_x86_feature_detected!("avx2") {
                if m == 1 {
                    return Self::new_c64x1(m, n, k, is_col_major);
                }
                return Self::new_c64_avx(m, n, k, is_col_major);
            }
        }

        #[cfg(feature = "std")]
        #[cfg(target_arch = "aarch64")]
        {
            if std::arch::is_aarch64_feature_detected!("neon")
                && std::arch::is_aarch64_feature_detected!("fcma")
            {
                return Self::from_non_masked_impl(
                    &aarch64::c64::neon::MICROKERNELS,
                    m,
                    n,
                    k,
                    is_col_major,
                );
            }
        }
        Self::new_c64_scalar(m, n, k, is_col_major)
    }

    #[track_caller]
    pub fn new_colmajor_lhs_and_dst_c64(m: usize, n: usize, k: usize) -> Self {
        Self::new_c64_impl(m, n, k, true)
    }

    #[track_caller]
    pub fn new_c64(m: usize, n: usize, k: usize) -> Self {
        Self::new_c64_impl(m, n, k, false)
    }
}

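// Per-width constructors over the x86 microkernel tables. The `*x1`/`*x2` variants
// pass `None` for the masks, while the wider variants forward their `MASKS` tables
// used for the partial last row tile.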
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod x86_api {
    use super::*;

    impl Plan<f32> {
        pub fn new_f32x1(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f32::f32x1::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }
        pub fn new_f32x2(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f32::f32x2::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }
        pub fn new_f32x4(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f32::f32x4::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }

        pub fn new_f32_avx(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f32::avx::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }

        #[cfg(feature = "x86-v4")]
        pub fn new_f32_avx512(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f32::avx512::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }
    }

    impl Plan<f64> {
        pub fn new_f64x1(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f64::f64x1::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }
        pub fn new_f64x2(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f64::f64x2::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }

        pub fn new_f64_avx(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f64::avx::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }

        #[cfg(feature = "x86-v4")]
        pub fn new_f64_avx512(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::f64::avx512::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }
    }
    impl Plan<c32> {
        pub fn new_c32x1(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c32::c32x1::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }
        pub fn new_c32x2(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c32::c32x2::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }

        pub fn new_c32_avx(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c32::avx::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }

        #[cfg(feature = "x86-v4")]
        pub fn new_c32_avx512(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c32::avx512::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }
    }
    impl Plan<c64> {
        pub fn new_c64x1(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c64::c64x1::*;
            Self::from_masked_impl::<MR_DIV_N, NR, N, ()>(
                &MICROKERNELS,
                None,
                m,
                n,
                k,
                is_col_major,
            )
        }

        pub fn new_c64_avx(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c64::avx::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }

        #[cfg(feature = "x86-v4")]
        pub fn new_c64_avx512(m: usize, n: usize, k: usize, is_col_major: bool) -> Self {
            use x86::c64::avx512::*;
            Self::from_masked_impl(&MICROKERNELS, Some(&MASKS), m, n, k, is_col_major)
        }
    }
}

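// Plan-free entry points. Each call first normalizes the layout: if `dst` is closer
// to row-major, the whole problem is transposed (swapping `lhs`/`rhs` and their
// strides), and a `dst` row stride of -1 is flipped by re-basing the pointers. A
// one-off `Plan` is then built and executed immediately.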
pub mod planless {
    use super::*;

    #[inline(always)]
    pub unsafe fn execute_f32(
        mut m: usize,
        mut n: usize,
        k: usize,
        mut dst: *mut f32,
        mut dst_rs: isize,
        mut dst_cs: isize,
        mut lhs: *const f32,
        mut lhs_rs: isize,
        mut lhs_cs: isize,
        mut rhs: *const f32,
        mut rhs_rs: isize,
        mut rhs_cs: isize,
        alpha: f32,
        beta: f32,
        mut conj_lhs: bool,
        mut conj_rhs: bool,
    ) {
        if dst_cs.unsigned_abs() < dst_rs.unsigned_abs() {
            core::mem::swap(&mut m, &mut n);
            core::mem::swap(&mut dst_rs, &mut dst_cs);
            core::mem::swap(&mut lhs, &mut rhs);
            core::mem::swap(&mut lhs_rs, &mut rhs_cs);
            core::mem::swap(&mut lhs_cs, &mut rhs_rs);
            core::mem::swap(&mut conj_lhs, &mut conj_rhs);
        }
        if dst_rs == -1 && m > 0 {
            dst = dst.wrapping_offset((m - 1) as isize * dst_rs);
            dst_rs = dst_rs.wrapping_neg();
            lhs = lhs.wrapping_offset((m - 1) as isize * lhs_rs);
            lhs_rs = lhs_rs.wrapping_neg();
        }

        let plan = if lhs_rs == 1 && dst_rs == 1 {
            Plan::new_colmajor_lhs_and_dst_f32(m, n, k)
        } else {
            Plan::new_f32(m, n, k)
        };
        plan.execute_unchecked(
            m, n, k, dst, dst_rs, dst_cs, lhs, lhs_rs, lhs_cs, rhs, rhs_rs, rhs_cs, alpha, beta,
            conj_lhs, conj_rhs,
        )
    }

    #[inline(always)]
    pub unsafe fn execute_c32(
        mut m: usize,
        mut n: usize,
        k: usize,
        mut dst: *mut c32,
        mut dst_rs: isize,
        mut dst_cs: isize,
        mut lhs: *const c32,
        mut lhs_rs: isize,
        mut lhs_cs: isize,
        mut rhs: *const c32,
        mut rhs_rs: isize,
        mut rhs_cs: isize,
        alpha: c32,
        beta: c32,
        mut conj_lhs: bool,
        mut conj_rhs: bool,
    ) {
        if dst_cs.unsigned_abs() < dst_rs.unsigned_abs() {
            core::mem::swap(&mut m, &mut n);
            core::mem::swap(&mut dst_rs, &mut dst_cs);
            core::mem::swap(&mut lhs, &mut rhs);
            core::mem::swap(&mut lhs_rs, &mut rhs_cs);
            core::mem::swap(&mut lhs_cs, &mut rhs_rs);
            core::mem::swap(&mut conj_lhs, &mut conj_rhs);
        }
        if dst_rs == -1 && m > 0 {
            dst = dst.wrapping_offset((m - 1) as isize * dst_rs);
            dst_rs = dst_rs.wrapping_neg();
            lhs = lhs.wrapping_offset((m - 1) as isize * lhs_rs);
            lhs_rs = lhs_rs.wrapping_neg();
        }

        let plan = if lhs_rs == 1 && dst_rs == 1 {
            Plan::new_colmajor_lhs_and_dst_c32(m, n, k)
        } else {
            Plan::new_c32(m, n, k)
        };
        plan.execute_unchecked(
            m, n, k, dst, dst_rs, dst_cs, lhs, lhs_rs, lhs_cs, rhs, rhs_rs, rhs_cs, alpha, beta,
            conj_lhs, conj_rhs,
        )
    }

    #[inline(always)]
    pub unsafe fn execute_f64(
        mut m: usize,
        mut n: usize,
        k: usize,
        mut dst: *mut f64,
        mut dst_rs: isize,
        mut dst_cs: isize,
        mut lhs: *const f64,
        mut lhs_rs: isize,
        mut lhs_cs: isize,
        mut rhs: *const f64,
        mut rhs_rs: isize,
        mut rhs_cs: isize,
        alpha: f64,
        beta: f64,
        mut conj_lhs: bool,
        mut conj_rhs: bool,
    ) {
        if dst_cs.unsigned_abs() < dst_rs.unsigned_abs() {
            core::mem::swap(&mut m, &mut n);
            core::mem::swap(&mut dst_rs, &mut dst_cs);
            core::mem::swap(&mut lhs, &mut rhs);
            core::mem::swap(&mut lhs_rs, &mut rhs_cs);
            core::mem::swap(&mut lhs_cs, &mut rhs_rs);
            core::mem::swap(&mut conj_lhs, &mut conj_rhs);
        }
        if dst_rs == -1 && m > 0 {
            dst = dst.wrapping_offset((m - 1) as isize * dst_rs);
            dst_rs = dst_rs.wrapping_neg();
            lhs = lhs.wrapping_offset((m - 1) as isize * lhs_rs);
            lhs_rs = lhs_rs.wrapping_neg();
        }

        let plan = if lhs_rs == 1 && dst_rs == 1 {
            Plan::new_colmajor_lhs_and_dst_f64(m, n, k)
        } else {
            Plan::new_f64(m, n, k)
        };
        plan.execute_unchecked(
            m, n, k, dst, dst_rs, dst_cs, lhs, lhs_rs, lhs_cs, rhs, rhs_rs, rhs_cs, alpha, beta,
            conj_lhs, conj_rhs,
        )
    }

    #[inline(always)]
    pub unsafe fn execute_c64(
        mut m: usize,
        mut n: usize,
        k: usize,
        mut dst: *mut c64,
        mut dst_rs: isize,
        mut dst_cs: isize,
        mut lhs: *const c64,
        mut lhs_rs: isize,
        mut lhs_cs: isize,
        mut rhs: *const c64,
        mut rhs_rs: isize,
        mut rhs_cs: isize,
        alpha: c64,
        beta: c64,
        mut conj_lhs: bool,
        mut conj_rhs: bool,
    ) {
        if dst_cs.unsigned_abs() < dst_rs.unsigned_abs() {
            core::mem::swap(&mut m, &mut n);
            core::mem::swap(&mut dst_rs, &mut dst_cs);
            core::mem::swap(&mut lhs, &mut rhs);
            core::mem::swap(&mut lhs_rs, &mut rhs_cs);
            core::mem::swap(&mut lhs_cs, &mut rhs_rs);
            core::mem::swap(&mut conj_lhs, &mut conj_rhs);
        }
        if dst_rs == -1 && m > 0 {
            dst = dst.wrapping_offset((m - 1) as isize * dst_rs);
            dst_rs = dst_rs.wrapping_neg();
            lhs = lhs.wrapping_offset((m - 1) as isize * lhs_rs);
            lhs_rs = lhs_rs.wrapping_neg();
        }

        let plan = if lhs_rs == 1 && dst_rs == 1 {
            Plan::new_colmajor_lhs_and_dst_c64(m, n, k)
        } else {
            Plan::new_c64(m, n, k)
        };
        plan.execute_unchecked(
            m, n, k, dst, dst_rs, dst_cs, lhs, lhs_rs, lhs_cs, rhs, rhs_rs, rhs_cs, alpha, beta,
            conj_lhs, conj_rhs,
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use equator::assert;

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_kernel() {
        let gen = |_| rand::random::<f32>();
        let a: [[f32; 17]; 3] = core::array::from_fn(|_| core::array::from_fn(gen));
        let b: [[f32; 6]; 4] = core::array::from_fn(|_| core::array::from_fn(gen));
        let c: [[f32; 15]; 4] = core::array::from_fn(|_| core::array::from_fn(gen));
        assert!(std::is_x86_feature_detected!("avx2"));
        let mut dst = c;

        let last_mask: std::arch::x86_64::__m256i = unsafe {
            core::mem::transmute([
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                0,
            ])
        };

        let beta = 2.5;
        let alpha = 1.0;

        unsafe {
            x86::f32::avx::matmul_2_4_dyn(
                &MicroKernelData {
                    alpha,
                    beta,
                    conj_lhs: false,
                    conj_rhs: false,
                    k: 3,
                    dst_cs: dst[0].len() as isize,
                    lhs_cs: a[0].len() as isize,
                    rhs_rs: 2,
                    rhs_cs: 6,
                    last_mask: (&last_mask) as *const _ as *const (),
                },
                dst.as_mut_ptr() as *mut f32,
                a.as_ptr() as *const f32,
                b.as_ptr() as *const f32,
            );
        };

        let mut expected_dst = c;
        for i in 0..15 {
            for j in 0..4 {
                let mut acc = 0.0f32;
                for depth in 0..3 {
                    acc = f32::mul_add(a[depth][i], b[j][2 * depth], acc);
                }
                expected_dst[j][i] = f32::mul_add(beta, acc, expected_dst[j][i]);
            }
        }

        assert!(dst == expected_dst);
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_kernel_cplx() {
        let gen = |_| rand::random::<c32>();
        let a: [[c32; 9]; 3] = core::array::from_fn(|_| core::array::from_fn(gen));
        let b: [[c32; 6]; 2] = core::array::from_fn(|_| core::array::from_fn(gen));
        let c: [[c32; 7]; 2] = core::array::from_fn(|_| core::array::from_fn(gen));
        assert!(std::is_x86_feature_detected!("avx2"));

        let last_mask: std::arch::x86_64::__m256i = unsafe {
            core::mem::transmute([
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                u32::MAX,
                0,
                0,
            ])
        };

        let beta = c32::new(2.5, 3.5);
        let alpha = c32::new(1.0, 0.0);

        for (conj_lhs, conj_rhs) in [(false, false), (false, true), (true, false), (true, true)] {
            let mut dst = c;
            unsafe {
                x86::c32::avx::matmul_2_2_dyn(
                    &MicroKernelData {
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        k: 3,
                        dst_cs: dst[0].len() as isize,
                        lhs_cs: a[0].len() as isize,
                        rhs_rs: 2,
                        rhs_cs: b[0].len() as isize,
                        last_mask: (&last_mask) as *const _ as *const (),
                    },
                    dst.as_mut_ptr() as *mut c32,
                    a.as_ptr() as *const c32,
                    b.as_ptr() as *const c32,
                );
            };

            let mut expected_dst = c;
            for i in 0..7 {
                for j in 0..2 {
                    let mut acc = c32::new(0.0, 0.0);
                    for depth in 0..3 {
                        let mut a = a[depth][i];
                        let mut b = b[j][2 * depth];
                        if conj_lhs {
                            a = a.conj();
                        }
                        if conj_rhs {
                            b = b.conj();
                        }
                        acc += a * b;
                    }
                    expected_dst[j][i] += beta * acc;
                }
            }

            for (&dst, &expected_dst) in
                core::iter::zip(dst.iter().flatten(), expected_dst.iter().flatten())
            {
                assert!((dst.re - expected_dst.re).abs() < 1e-5);
                assert!((dst.im - expected_dst.im).abs() < 1e-5);
            }
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_kernel_cplx64() {
        let gen = |_| rand::random::<c64>();
        let a: [[c64; 5]; 3] = core::array::from_fn(|_| core::array::from_fn(gen));
        let b: [[c64; 6]; 2] = core::array::from_fn(|_| core::array::from_fn(gen));
        let c: [[c64; 3]; 2] = core::array::from_fn(|_| core::array::from_fn(gen));
        assert!(std::is_x86_feature_detected!("avx2"));

        let last_mask: std::arch::x86_64::__m256i =
            unsafe { core::mem::transmute([u64::MAX, u64::MAX, 0, 0]) };

        let beta = c64::new(2.5, 3.5);
        let alpha = c64::new(1.0, 0.0);

        for (conj_lhs, conj_rhs) in [(false, false), (false, true), (true, false), (true, true)] {
            let mut dst = c;
            unsafe {
                x86::c64::avx::matmul_2_2_dyn(
                    &MicroKernelData {
                        alpha,
                        beta,
                        conj_lhs,
                        conj_rhs,
                        k: 3,
                        dst_cs: dst[0].len() as isize,
                        lhs_cs: a[0].len() as isize,
                        rhs_rs: 2,
                        rhs_cs: b[0].len() as isize,
                        last_mask: (&last_mask) as *const _ as *const (),
                    },
                    dst.as_mut_ptr() as *mut c64,
                    a.as_ptr() as *const c64,
                    b.as_ptr() as *const c64,
                );
            };

            let mut expected_dst = c;
            for i in 0..3 {
                for j in 0..2 {
                    let mut acc = c64::new(0.0, 0.0);
                    for depth in 0..3 {
                        let mut a = a[depth][i];
                        let mut b = b[j][2 * depth];
                        if conj_lhs {
                            a = a.conj();
                        }
                        if conj_rhs {
                            b = b.conj();
                        }
                        acc += a * b;
                    }
                    expected_dst[j][i] += beta * acc;
                }
            }

            for (&dst, &expected_dst) in
                core::iter::zip(dst.iter().flatten(), expected_dst.iter().flatten())
            {
                assert!((dst.re - expected_dst.re).abs() < 1e-5);
                assert!((dst.im - expected_dst.im).abs() < 1e-5);
            }
        }
    }
    #[test]
    fn test_plan() {
        let gen = |_| rand::random::<f32>();
        for ((m, n), k) in (64..=64).zip(64..=64).zip([1, 4, 64]) {
            let a = (0..m * k).into_iter().map(gen).collect::<Vec<_>>();
            let b = (0..k * n).into_iter().map(gen).collect::<Vec<_>>();
            let c = (0..m * n).into_iter().map(|_| 0.0).collect::<Vec<_>>();
            let mut dst = c.clone();

            let plan = Plan::new_f32(m, n, k);
            let beta = 2.5;

            unsafe {
                plan.execute_unchecked(
                    m,
                    n,
                    k,
                    dst.as_mut_ptr(),
                    1,
                    m as isize,
                    a.as_ptr(),
                    1,
                    m as isize,
                    b.as_ptr(),
                    1,
                    k as isize,
                    1.0,
                    beta,
                    false,
                    false,
                );
            };

            let mut expected_dst = c;
            for i in 0..m {
                for j in 0..n {
                    let mut acc = 0.0f32;
                    for depth in 0..k {
                        acc = f32::mul_add(a[depth * m + i], b[j * k + depth], acc);
                    }
                    expected_dst[j * m + i] = f32::mul_add(beta, acc, expected_dst[j * m + i]);
                }
            }

            for (dst, expected_dst) in dst.iter().zip(&expected_dst) {
                assert!((dst - expected_dst).abs() < 1e-4);
            }
        }
    }

    #[test]
    fn test_plan_cplx() {
        let gen = |_| rand::random::<c64>();
        for ((m, n), k) in (0..128).zip(0..128).zip([1, 4, 17]) {
            let a = (0..m * k).into_iter().map(gen).collect::<Vec<_>>();
            let b = (0..k * n).into_iter().map(gen).collect::<Vec<_>>();
            let c = (0..m * n).into_iter().map(gen).collect::<Vec<_>>();

            for (conj_lhs, conj_rhs) in [(false, true), (false, false), (true, true), (true, false)]
            {
                for alpha in [c64::new(0.0, 0.0), c64::new(1.0, 0.0), c64::new(2.7, 3.7)] {
                    let mut dst = c.clone();

                    let plan = Plan::new_colmajor_lhs_and_dst_c64(m, n, k);
                    let beta = c64::new(2.5, 0.0);

                    unsafe {
                        plan.execute_unchecked(
                            m,
                            n,
                            k,
                            dst.as_mut_ptr(),
                            1,
                            m as isize,
                            a.as_ptr(),
                            1,
                            m as isize,
                            b.as_ptr(),
                            1,
                            k as isize,
                            alpha,
                            beta,
                            conj_lhs,
                            conj_rhs,
                        );
                    };

                    let mut expected_dst = c.clone();
                    for i in 0..m {
                        for j in 0..n {
                            let mut acc = c64::new(0.0, 0.0);
                            for depth in 0..k {
                                let mut a = a[depth * m + i];
                                let mut b = b[j * k + depth];
                                if conj_lhs {
                                    a = a.conj();
                                }
                                if conj_rhs {
                                    b = b.conj();
                                }
                                acc += a * b;
                            }
                            expected_dst[j * m + i] = alpha * expected_dst[j * m + i] + beta * acc;
                        }
                    }

                    for (&dst, &expected_dst) in core::iter::zip(dst.iter(), expected_dst.iter()) {
                        assert!((dst.re - expected_dst.re).abs() < 1e-5);
                        assert!((dst.im - expected_dst.im).abs() < 1e-5);
                    }
                }
            }
        }
    }

    #[test]
    fn test_plan_strided() {
        let gen = |_| rand::random::<f32>();
        for ((m, n), k) in (0..128).zip(0..128).zip([1, 4, 17]) {
            let a = (0..2 * 200 * k).into_iter().map(gen).collect::<Vec<_>>();
            let b = (0..k * n).into_iter().map(gen).collect::<Vec<_>>();
            let c = (0..3 * 400 * n)
                .into_iter()
                .map(|_| 0.0)
                .collect::<Vec<_>>();
            let mut dst = c.clone();

            let plan = Plan::new_f32(m, n, k);
            let beta = 2.5;

            unsafe {
                plan.execute_unchecked(
                    m,
                    n,
                    k,
                    dst.as_mut_ptr(),
                    3,
                    400,
                    a.as_ptr(),
                    2,
                    200,
                    b.as_ptr(),
                    1,
                    k as isize,
                    1.0,
                    beta,
                    false,
                    false,
                );
            };

            let mut expected_dst = c;
            for i in 0..m {
                for j in 0..n {
                    let mut acc = 0.0f32;
                    for depth in 0..k {
                        acc = f32::mul_add(a[depth * 200 + i * 2], b[j * k + depth], acc);
                    }
                    expected_dst[j * 400 + i * 3] =
                        f32::mul_add(beta, acc, expected_dst[j * 400 + i * 3]);
                }
            }

            for (dst, expected_dst) in dst.iter().zip(&expected_dst) {
                assert!((dst - expected_dst).abs() < 1e-4);
            }
        }
    }

    #[test]
    fn test_plan_cplx_strided() {
        let gen = |_| c64::new(rand::random(), rand::random());
        for ((m, n), k) in (0..128).zip(0..128).zip([1, 4, 17, 190]) {
            let a = (0..2 * 200 * k).into_iter().map(gen).collect::<Vec<_>>();
            let b = (0..k * n).into_iter().map(gen).collect::<Vec<_>>();
            let c = (0..3 * 400 * n)
                .into_iter()
                .map(|_| c64::ZERO)
                .collect::<Vec<_>>();
            let mut dst = c.clone();

            let beta = 2.5.into();

            unsafe {
                planless::execute_c64(
                    m,
                    n,
                    k,
                    dst.as_mut_ptr(),
                    3,
                    400,
                    a.as_ptr(),
                    2,
                    200,
                    b.as_ptr(),
                    1,
                    k as isize,
                    1.0.into(),
                    beta,
                    false,
                    false,
                );
            };

            let mut expected_dst = c;
            for i in 0..m {
                for j in 0..n {
                    let mut acc = c64::ZERO;
                    for depth in 0..k {
                        acc += a[depth * 200 + i * 2] * b[j * k + depth];
                    }
                    expected_dst[j * 400 + i * 3] = beta * acc + expected_dst[j * 400 + i * 3];
                }
            }

            for (dst, expected_dst) in dst.iter().zip(&expected_dst) {
                use num_complex::ComplexFloat;
                assert!((dst - expected_dst).abs() < 1e-4);
            }
        }
    }

    #[test]
    fn test_plan_cplx_strided2() {
        let gen = |_| c64::new(rand::random(), rand::random());
        let m = 102;
        let n = 2;
        let k = 190;
        {
            let a = (0..2 * 200 * k).into_iter().map(gen).collect::<Vec<_>>();
            let b = (0..k * n).into_iter().map(gen).collect::<Vec<_>>();
            let c = (0..400 * n)
                .into_iter()
                .map(|_| c64::ZERO)
                .collect::<Vec<_>>();
            let mut dst = c.clone();

            let beta = 2.5.into();

            unsafe {
                planless::execute_c64(
                    m,
                    n,
                    k,
                    dst.as_mut_ptr(),
                    1,
                    400,
                    a.as_ptr(),
                    2,
                    200,
                    b.as_ptr(),
                    1,
                    k as isize,
                    1.0.into(),
                    beta,
                    false,
                    false,
                );
            };

            let mut expected_dst = c;
            for i in 0..m {
                for j in 0..n {
                    let mut acc = c64::ZERO;
                    for depth in 0..k {
                        acc += a[depth * 200 + i * 2] * b[j * k + depth];
                    }
                    expected_dst[j * 400 + i] = beta * acc + expected_dst[j * 400 + i];
                }
            }

            for (dst, expected_dst) in dst.iter().zip(&expected_dst) {
                use num_complex::ComplexFloat;
                assert!((dst - expected_dst).abs() < 1e-4);
            }
        }
    }
}