1#![allow(non_upper_case_globals)]
2#![allow(dead_code, unused_variables)]
3
// NOTE(review): `M` and `N` are not referenced in this part of the file —
// presumably consumed by the generated `asm.rs` include below; confirm.
const M: usize = 4;
const N: usize = 32;
6
7use core::cell::RefCell;
8use core::ptr::{null, null_mut};
9use core::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
10
11use cache::CACHE_INFO;
12
13include!(concat!(env!("OUT_DIR"), "/asm.rs"));
14
/// A (row, col) coordinate of the current tile within the destination matrix.
///
/// Passed to the assembly microkernel by pointer (`rdi` in
/// [`call_microkernel`]), hence the `repr(C)` layout.
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct Position {
    /// Row offset, in elements.
    pub row: usize,
    /// Column offset, in elements.
    pub col: usize,
}
21
22mod cache;
23
// Bit flags stored in `MicrokernelInfo::flags`.
const FLAGS_ACCUM: usize = 1 << 0; // accumulate into dst instead of overwriting
const FLAGS_CONJ_LHS: usize = 1 << 1; // conjugate the lhs (complex types)
const FLAGS_CONJ_NEQ: usize = 1 << 2; // lhs/rhs conjugation flags differ
const FLAGS_LOWER: usize = 1 << 3; // only the lower triangle of dst is written
const FLAGS_UPPER: usize = 1 << 4; // only the upper triangle of dst is written
const FLAGS_32BIT_IDX: usize = 1 << 5; // row/col index arrays hold 32-bit entries
const FLAGS_CPLX: usize = 1 << 62; // elements are complex (pairs of reals)
const FLAGS_ROWMAJOR: usize = 1 << 63; // traverse tiles row-major (millikernel_rowmajor)
32
/// Per-block state handed to the assembly microkernel (`rsi` in
/// [`call_microkernel`]). `repr(C)` because the asm reads fields at fixed
/// offsets. All strides are in bytes.
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct MicrokernelInfo {
    /// Bitwise OR of the `FLAGS_*` constants.
    pub flags: usize,
    /// K-dimension (depth) of the current block.
    pub depth: usize,
    /// lhs row stride, in bytes.
    pub lhs_rs: isize,
    /// lhs column stride, in bytes.
    pub lhs_cs: isize,
    /// rhs row stride, in bytes.
    pub rhs_rs: isize,
    /// rhs column stride, in bytes.
    pub rhs_cs: isize,
    /// Type-erased pointer to the scaling factor alpha.
    pub alpha: *const (),

    /// Destination matrix base pointer.
    pub ptr: *mut (),
    /// Destination row stride, in bytes.
    pub rs: isize,
    /// Destination column stride, in bytes.
    pub cs: isize,
    /// Destination row-index array; entry width follows `FLAGS_32BIT_IDX`.
    /// NOTE(review): only consumed by the asm kernels — semantics not visible here.
    pub row_idx: *const (),
    /// Destination column-index array (see `row_idx`).
    pub col_idx: *const (),

    /// Optional real diagonal applied while packing the lhs (see `pack_lhs`);
    /// null when no diagonal scaling is requested.
    pub diag_ptr: *const (),
    /// Byte stride between consecutive diagonal entries.
    pub diag_stride: isize,
}
55
/// Strides used by the millikernel drivers to step between consecutive
/// `mr x nr` micro-tiles, plus the shared microkernel state.
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct MillikernelInfo {
    /// Byte stride from one mr-row band of the lhs to the next.
    pub lhs_rs: isize,
    /// Byte stride between mr-row bands inside the lhs packing buffer.
    pub packed_lhs_rs: isize,
    /// Byte stride from one nr-column band of the rhs to the next.
    pub rhs_cs: isize,
    /// Byte stride between nr-column bands inside the rhs packing buffer.
    pub packed_rhs_cs: isize,
    /// State forwarded to the microkernel for every tile.
    pub micro: MicrokernelInfo,
}
65
/// Copies a `depth x nr` tile out of a strided source into a dense
/// destination.
///
/// `src` is read with byte strides `rs` (between rows) and `cs` (between
/// columns); `dst` is written row-major with `stride` elements between
/// consecutive rows.
#[inline(always)]
unsafe fn pack_rhs_imp<T: Copy>(dst: *mut T, src: *const (), depth: usize, stride: usize, nr: usize, rs: isize, cs: isize) {
    unsafe {
        for row in 0..depth {
            let dst_row = dst.add(row * stride);
            let src_row = src.byte_offset(row as isize * rs);

            for col in 0..nr {
                let src_elem = src_row.byte_offset(col as isize * cs) as *const T;
                *dst_row.add(col) = *src_elem;
            }
        }
    }
}
82
83#[inline(never)]
84unsafe fn pack_rhs(dst: *mut (), src: *const (), depth: usize, nr: usize, rs: isize, cs: isize, sizeof: usize) {
85 if !src.is_null() && src != dst as *const () {
86 unsafe {
87 match sizeof {
88 4 => pack_rhs_imp(dst as *mut f32, src, depth, nr, nr, rs, cs),
89 8 => pack_rhs_imp(dst as *mut [f32; 2], src, depth, nr, nr, rs, cs),
90 16 => pack_rhs_imp(dst as *mut [f64; 2], src, depth, nr, nr, rs, cs),
91 _ => unreachable!(),
92 }
93 }
94 }
95}
96
/// Invokes a generated assembly microkernel through a custom register-based
/// calling convention (not the C ABI):
///
/// - `rax` / `r15`: lhs pointer / packed-lhs buffer
/// - `rcx` / `rdx`: rhs pointer / packed-rhs buffer
/// - `rdi`: `&mut Position` of the destination tile
/// - `rsi`: `&MicrokernelInfo` (strides, flags, alpha, destination)
/// - `r8` / `r9`: nrows / ncols, read and possibly rewritten by the kernel
/// - `r10`: address of the microkernel to call
///
/// All AVX-512 vector registers `zmm0..=zmm31` and mask registers `k1..=k4`
/// are declared clobbered.
///
/// Returns the `(nrows, ncols)` values left in `r8`/`r9` by the kernel.
#[inline(always)]
pub unsafe fn call_microkernel(
    microkernel: unsafe extern "C" fn(),
    lhs: *const (),
    packed_lhs: *mut (),

    rhs: *const (),
    packed_rhs: *mut (),

    mut nrows: usize,
    mut ncols: usize,

    micro: &MicrokernelInfo,
    position: &mut Position,
) -> (usize, usize) {
    unsafe {
        core::arch::asm! {
            "call r10",

            in("rax") lhs,
            in("r15") packed_lhs,
            in("rcx") rhs,
            in("rdx") packed_rhs,
            in("rdi") position,
            in("rsi") micro,
            inout("r8") nrows,
            inout("r9") ncols,
            in("r10") microkernel,

            out("zmm0") _,
            out("zmm1") _,
            out("zmm2") _,
            out("zmm3") _,
            out("zmm4") _,
            out("zmm5") _,
            out("zmm6") _,
            out("zmm7") _,
            out("zmm8") _,
            out("zmm9") _,
            out("zmm10") _,
            out("zmm11") _,
            out("zmm12") _,
            out("zmm13") _,
            out("zmm14") _,
            out("zmm15") _,
            out("zmm16") _,
            out("zmm17") _,
            out("zmm18") _,
            out("zmm19") _,
            out("zmm20") _,
            out("zmm21") _,
            out("zmm22") _,
            out("zmm23") _,
            out("zmm24") _,
            out("zmm25") _,
            out("zmm26") _,
            out("zmm27") _,
            out("zmm28") _,
            out("zmm29") _,
            out("zmm30") _,
            out("zmm31") _,
            out("k1") _,
            out("k2") _,
            out("k3") _,
            out("k4") _,
        }
    }
    (nrows, ncols)
}
166
/// Runs the microkernel over an `nrows x ncols` block, sweeping the
/// `nr`-wide column tiles in the inner loop and the `mr`-tall row bands in
/// the outer loop (tile-row-major traversal).
///
/// A null `lhs`/`rhs` means "operand already packed — read `packed_lhs`/
/// `packed_rhs` instead"; once an unpacked operand has been packed, the
/// local pointer is nulled so subsequent tiles reuse the packed copy.
///
/// When the destination is triangular (`FLAGS_LOWER`/`FLAGS_UPPER`), tiles
/// strictly outside the stored triangle skip the microkernel call but still
/// pack the lhs tile so later tiles can consume it.
pub unsafe fn millikernel_rowmajor(
    microkernel: unsafe extern "C" fn(),
    pack: unsafe extern "C" fn(),
    mr: usize,
    nr: usize,
    sizeof: usize,

    lhs: *const (),
    packed_lhs: *mut (),

    rhs: *const (),
    packed_rhs: *mut (),

    nrows: usize,
    ncols: usize,

    milli: &MillikernelInfo,

    pos: &mut Position,
) {
    let mut rhs = rhs;
    let mut nrows = nrows;
    let mut lhs = lhs;
    let mut packed_lhs = packed_lhs;

    let tril = milli.micro.flags & FLAGS_LOWER != 0;
    let triu = milli.micro.flags & FLAGS_UPPER != 0;
    let rectangular = !tril && !triu;

    // One outer iteration per mr-tall row band.
    loop {
        // NOTE(review): `rs` is never read (file-level allow(unused_variables)).
        let rs = milli.micro.lhs_rs;
        unsafe {
            let mut rhs = rhs;
            let mut packed_rhs = packed_rhs;
            let mut ncols = ncols;
            let mut lhs = lhs;
            let col = pos.col;

            // `iter!(lhs)` — first column tile of the band; may pack the lhs
            // and then null it. `iter!()` — the remaining tiles of the band.
            macro_rules! iter {
                ($($lhs: ident)?) => {{
                    $({
                        let _ = $lhs;
                        // Pack the lhs up front when packing is required:
                        // a diagonal scaling is requested, or the lhs rows
                        // are not contiguous (unit byte stride).
                        if lhs != packed_lhs && !lhs.is_null() && (!milli.micro.diag_ptr.is_null() || milli.micro.lhs_rs != sizeof as isize) {
                            pack_lhs(pack, milli, Ord::min(nrows, mr), packed_lhs, lhs, sizeof);
                            lhs = null();
                        }
                    })*

                    let row_chunk = Ord::min(nrows, mr);
                    let col_chunk = Ord::min(ncols, nr);

                    {
                        let mut rhs = rhs;
                        if rhs != packed_rhs && !rhs.is_null() {
                            pack_rhs(
                                packed_rhs,
                                rhs,
                                milli.micro.depth,
                                col_chunk,
                                milli.micro.rhs_rs,
                                milli.micro.rhs_cs,
                                sizeof,
                            );
                            rhs = null();
                        }


                        // Skip tiles strictly outside the stored triangle.
                        if rectangular || (tril && pos.row + mr > pos.col) || (triu && pos.col + col_chunk > pos.row) {
                            call_microkernel(
                                microkernel,
                                lhs,
                                packed_lhs,
                                rhs,
                                packed_rhs,
                                row_chunk,
                                col_chunk,
                                &milli.micro,
                                pos,
                            );
                        } else {
                            // Still pack the lhs so later tiles in this band
                            // can read the packed copy.
                            if lhs != packed_lhs && !lhs.is_null() {
                                pack_lhs(pack, milli, row_chunk, packed_lhs, lhs, sizeof);
                            }
                        }
                    }

                    pos.col += col_chunk;
                    ncols -= col_chunk;
                    // Advance the row position only once the band is finished.
                    if ncols == 0 {
                        pos.row += row_chunk;
                        nrows -= row_chunk;
                    }

                    if !rhs.is_null() {
                        rhs = rhs.wrapping_byte_offset(milli.rhs_cs);
                    }
                    packed_rhs = packed_rhs.wrapping_byte_offset(milli.packed_rhs_cs);

                    $(if lhs != packed_lhs {
                        $lhs = null();
                    })?
                }};
            }
            iter!(lhs);
            while ncols > 0 {
                iter!();
            }
            pos.col = col;
        }

        if !lhs.is_null() {
            lhs = lhs.wrapping_byte_offset(milli.lhs_rs);
        }
        packed_lhs = packed_lhs.wrapping_byte_offset(milli.packed_lhs_rs);
        // After the first band the rhs is fully packed (unless it aliases
        // the packing buffer), so later bands read the packed copy.
        if rhs != packed_rhs {
            rhs = null();
        }

        if nrows == 0 {
            break;
        }
    }
}
290
291pub unsafe fn millikernel_colmajor(
292 microkernel: unsafe extern "C" fn(),
293 pack: unsafe extern "C" fn(),
294 mr: usize,
295 nr: usize,
296 sizeof: usize,
297
298 lhs: *const (),
299 packed_lhs: *mut (),
300
301 rhs: *const (),
302 packed_rhs: *mut (),
303
304 nrows: usize,
305 ncols: usize,
306
307 milli: &MillikernelInfo,
308
309 pos: &mut Position,
310) {
311 let mut lhs = lhs;
312 let mut ncols = ncols;
313 let mut rhs = rhs;
314 let mut packed_rhs = packed_rhs;
315
316 let tril = milli.micro.flags & FLAGS_LOWER != 0;
317 let triu = milli.micro.flags & FLAGS_UPPER != 0;
318 let rectangular = !tril && !triu;
319
320 let mut j = 0;
321
322 loop {
323 let cs = milli.micro.rhs_cs;
324 unsafe {
325 let mut lhs = lhs;
326 let mut packed_lhs = packed_lhs;
327 let mut nrows = nrows;
328 let mut rhs = rhs;
329 let row = pos.row;
330
331 macro_rules! iter {
332 ($($rhs: ident)?) => {{
333 {
334 let mut lhs = lhs;
335
336 let row_chunk = Ord::min(nrows, mr);
337 let col_chunk = Ord::min(ncols, nr);
338
339 if lhs != packed_lhs && !lhs.is_null() && (!milli.micro.diag_ptr.is_null() || milli.micro.lhs_rs != sizeof as isize) {
340 pack_lhs(pack, milli, row_chunk, packed_lhs, lhs, sizeof);
341 lhs = null();
342 }
343
344 $({
345 let _ = $rhs;
346 if rhs != packed_rhs && !rhs.is_null() {
347 pack_rhs(
348 packed_rhs,
349 rhs,
350 milli.micro.depth,
351 col_chunk,
352 milli.micro.rhs_rs,
353 milli.micro.rhs_cs,
354 sizeof,
355 );
356 rhs = null();
357 }
358 })*
359 if rectangular || (tril && pos.row + mr > pos.col) || (triu && pos.col + col_chunk > pos.row) {
360 call_microkernel(
361 microkernel,
362 lhs,
363 packed_lhs,
364 rhs,
365 packed_rhs,
366 row_chunk,
367 col_chunk,
368 &milli.micro,
369 pos,
370 );
371 } else {
372 if lhs != packed_lhs && !lhs.is_null() {
373 pack_lhs(pack, milli, row_chunk, packed_lhs, lhs, sizeof);
374 }
375 }
376
377 pos.row += row_chunk;
378 nrows -= row_chunk;
379 if nrows == 0 {
380 pos.col += col_chunk;
381 ncols -= col_chunk;
382 }
383 }
384
385 if !lhs.is_null() {
386 lhs = lhs.wrapping_byte_offset(milli.lhs_rs);
387 }
388 packed_lhs = packed_lhs.wrapping_byte_offset(milli.packed_lhs_rs);
389
390 $(if rhs != packed_rhs {
391 $rhs = null();
392 })?
393 }};
394 }
395 iter!(rhs);
396 while nrows > 0 {
397 iter!();
398 }
399 pos.row = row;
400 }
401
402 if !rhs.is_null() {
403 rhs = rhs.wrapping_byte_offset(milli.rhs_cs);
404 }
405 packed_rhs = packed_rhs.wrapping_byte_offset(milli.packed_rhs_cs);
406 if lhs != packed_lhs {
407 lhs = null();
408 }
409
410 j += 1;
411 if ncols == 0 {
412 break;
413 }
414 }
415}
416
/// One work item of the threaded millikernel: work item `thd_id` owns an
/// `mf x nf` grid of `mr x nr` tiles inside the `nrows x ncols` block.
///
/// `pack_lhs_job[i]` / `pack_rhs_job[j]` are cross-thread progress flags for
/// lhs row-block `i` / rhs column-block `j`: the value 2 means "packed copy
/// is ready", letting this worker read the packed buffer instead of packing
/// again.
///
/// NOTE(review): `n_threads`, `microkernel_job`, `finished` and `hyper` are
/// unused in this function (file has `allow(unused_variables)`) — confirm
/// whether they are vestigial.
pub unsafe fn millikernel_par(
    thd_id: usize,
    n_threads: usize,

    microkernel_job: &[AtomicU8],
    pack_lhs_job: &[AtomicU8],
    pack_rhs_job: &[AtomicU8],
    finished: &AtomicUsize,
    hyper: usize,

    mr: usize,
    nr: usize,
    sizeof: usize,

    mf: usize,
    nf: usize,

    microkernel: unsafe extern "C" fn(),
    pack: unsafe extern "C" fn(),

    lhs: *const (),
    packed_lhs: *mut (),

    rhs: *const (),
    packed_rhs: *mut (),

    nrows: usize,
    ncols: usize,

    milli: &MillikernelInfo,

    pos: Position,
    tall: bool,
) {
    // Number of row bands / column bands of (mf*mr) x (nf*nr) super-tiles.
    let n_threads0 = nrows.div_ceil(mf * mr);
    let n_threads1 = ncols.div_ceil(nf * nr);

    // Linearized work-item id -> (row band, column band) coordinates.
    let thd_id0 = thd_id % (n_threads0);
    let thd_id1 = thd_id / (n_threads0);

    let tril = milli.micro.flags & FLAGS_LOWER != 0;
    let triu = milli.micro.flags & FLAGS_UPPER != 0;
    let rectangular = !tril && !triu;

    // Top-left tile coordinates of this worker's mf x nf grid.
    let i = mf * thd_id0;
    let j = nf * thd_id1;

    let colmajor = !tall;

    for ij in 0..mf * nf {
        // Walk the grid column-major unless the block is tall.
        let (i, j) = if colmajor {
            (i + ij % mf, j + ij / mf)
        } else {
            (i + ij / nf, j + ij % nf)
        };

        let row = Ord::min(nrows, i * mr);
        let col = Ord::min(ncols, j * nr);

        let row_chunk = Ord::min(nrows - row, mr);
        let col_chunk = Ord::min(ncols - col, nr);

        if row_chunk == 0 || col_chunk == 0 {
            continue;
        }

        let packed_lhs = packed_lhs.wrapping_byte_offset(milli.packed_lhs_rs * i as isize);
        let packed_rhs = packed_rhs.wrapping_byte_offset(milli.packed_rhs_cs * j as isize);

        let mut lhs = lhs;
        let mut rhs = rhs;

        {
            if !lhs.is_null() {
                lhs = lhs.wrapping_byte_offset(milli.lhs_rs * i as isize);
            }

            if lhs != packed_lhs {
                // 2 = another worker already published the packed lhs block.
                let val = pack_lhs_job[i].load(Ordering::Acquire);

                if val == 2 {
                    lhs = null();
                }
            }
        }

        {
            if !rhs.is_null() {
                rhs = rhs.wrapping_byte_offset(milli.rhs_cs * j as isize);
            }
            if rhs != packed_rhs {
                // 2 = another worker already published the packed rhs block.
                let val = pack_rhs_job[j].load(Ordering::Acquire);

                if val == 2 {
                    rhs = null();
                }
            }

            unsafe {
                // Pack the lhs when needed (diagonal scaling requested, or
                // rows not contiguous), then publish it for other workers.
                if lhs != packed_lhs && !lhs.is_null() && (!milli.micro.diag_ptr.is_null() || milli.micro.lhs_rs != sizeof as isize) {
                    pack_lhs(pack, milli, row_chunk, packed_lhs, lhs, sizeof);

                    lhs = null();
                    pack_lhs_job[i].store(2, Ordering::Release);
                }
                if rhs != packed_rhs && !rhs.is_null() {
                    pack_rhs(
                        packed_rhs,
                        rhs,
                        milli.micro.depth,
                        col_chunk,
                        milli.micro.rhs_rs,
                        milli.micro.rhs_cs,
                        sizeof,
                    );
                    rhs = null();
                    pack_rhs_job[j].store(2, Ordering::Release);
                }

                // Skip tiles strictly outside the stored triangle.
                if rectangular || (tril && pos.row + mr > pos.col) || (triu && pos.col + col_chunk > pos.row) {
                    call_microkernel(
                        microkernel,
                        lhs,
                        packed_lhs,
                        rhs,
                        packed_rhs,
                        row_chunk,
                        col_chunk,
                        &milli.micro,
                        &mut Position {
                            row: row + pos.row,
                            col: col + pos.col,
                        },
                    );
                } else {
                    if lhs != packed_lhs && !lhs.is_null() {
                        pack_lhs(pack, milli, row_chunk, packed_lhs, lhs, sizeof);
                    }
                }
            }

            // The microkernel packs as a side effect when handed a raw
            // operand; publish that progress for other workers.
            if !lhs.is_null() && lhs != packed_lhs {
                pack_lhs_job[i].store(2, Ordering::Release);
            }
            if !rhs.is_null() && rhs != packed_rhs {
                pack_rhs_job[j].store(2, Ordering::Release);
            }
        }
    }
}
567
/// Packs (and optionally diagonal-scales) a `row_chunk`-row slab of the lhs
/// into `packed_lhs`.
///
/// First calls the generated asm packing routine `pack` via a custom
/// register convention: `rax` = src, `r15` = dst, `rsi` = &MicrokernelInfo,
/// and `r8` carrying `row_chunk` on entry. After the call `r8` is used as
/// the packed column stride in *bytes* — presumably the asm rescales it;
/// TODO confirm against the generated code.
///
/// When neither lhs stride equals the element size, a scalar fallback then
/// gathers the strided elements, applying the optional real diagonal
/// scaling from `diag_ptr` (byte stride `diag_stride`) per depth index.
unsafe fn pack_lhs(pack: unsafe extern "C" fn(), milli: &MillikernelInfo, row_chunk: usize, packed_lhs: *mut (), lhs: *const (), sizeof: usize) {
    unsafe {
        {
            let mut dst_cs = row_chunk;
            core::arch::asm! {
                "call r10",
                in("r10") pack,
                in("rax") lhs,
                in("r15") packed_lhs,
                inout("r8") dst_cs,
                in("rsi") &milli.micro,

                out("zmm0") _,
                out("zmm1") _,
                out("zmm2") _,
                out("zmm3") _,
                out("zmm4") _,
                out("zmm5") _,
                out("zmm6") _,
                out("zmm7") _,
                out("zmm8") _,
                out("zmm9") _,
                out("zmm10") _,
                out("zmm11") _,
                out("zmm12") _,
                out("zmm13") _,
                out("zmm14") _,
                out("zmm15") _,
                out("zmm16") _,
                out("zmm17") _,
                out("zmm18") _,
                out("zmm19") _,
                out("zmm20") _,
                out("zmm21") _,
                out("zmm22") _,
                out("zmm23") _,
                out("zmm24") _,
                out("zmm25") _,
                out("zmm26") _,
                out("zmm27") _,
                out("zmm28") _,
                out("zmm29") _,
                out("zmm30") _,
                out("zmm31") _,
                out("k1") _,
                out("k2") _,
                out("k3") _,
                out("k4") _,
            };

            // Scalar fallback: both strides non-unit (general gather).
            if milli.micro.lhs_rs != sizeof as isize && milli.micro.lhs_cs != sizeof as isize {
                for j in 0..milli.micro.depth {
                    let dst = packed_lhs.byte_add(j * dst_cs);
                    let src = lhs.byte_offset(j as isize * milli.micro.lhs_cs);
                    let diag_ptr = milli.micro.diag_ptr.byte_offset(j as isize * milli.micro.diag_stride);

                    if sizeof == 4 {
                        // f32
                        let dst = dst as *mut f32;
                        let src = src as *const f32;
                        for i in 0..row_chunk {
                            let dst = dst.add(i);
                            let src = src.byte_offset(i as isize * milli.micro.lhs_rs);

                            if diag_ptr.is_null() {
                                *dst = *src;
                            } else {
                                *dst = *src * *(diag_ptr as *const f32);
                            }
                        }
                    } else if sizeof == 16 {
                        // c64: scale both real and imaginary parts by the
                        // real diagonal entry.
                        let dst = dst as *mut [f64; 2];
                        let src = src as *const [f64; 2];
                        for i in 0..row_chunk {
                            let dst = dst.add(i);
                            let src = src.byte_offset(i as isize * milli.micro.lhs_rs);

                            if diag_ptr.is_null() {
                                *dst = *src;
                            } else {
                                (*dst)[0] = (*src)[0] * *(diag_ptr as *const f64);
                                (*dst)[1] = (*src)[1] * *(diag_ptr as *const f64);
                            }
                        }
                    } else {
                        // 8-byte elements: FLAGS_CPLX (bit 62) distinguishes
                        // c32 pairs from plain f64.
                        if (milli.micro.flags >> 62) & 1 == 1 {
                            let dst = dst as *mut [f32; 2];
                            let src = src as *const [f32; 2];
                            for i in 0..row_chunk {
                                let dst = dst.add(i);
                                let src = src.byte_offset(i as isize * milli.micro.lhs_rs);

                                if diag_ptr.is_null() {
                                    *dst = *src;
                                } else {
                                    (*dst)[0] = (*src)[0] * *(diag_ptr as *const f32);
                                    (*dst)[1] = (*src)[1] * *(diag_ptr as *const f32);
                                }
                            }
                        } else {
                            let dst = dst as *mut f64;
                            let src = src as *const f64;
                            for i in 0..row_chunk {
                                let dst = dst.add(i);
                                let src = src.byte_offset(i as isize * milli.micro.lhs_rs);

                                if diag_ptr.is_null() {
                                    *dst = *src;
                                } else {
                                    *dst = *src * *(diag_ptr as *const f64);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
686
/// Strategy for executing the millikernel over one cache-level block:
/// [`Milli`] runs on the current thread, `MilliPar` (rayon feature) fans the
/// tiles out over a thread pool.
pub unsafe trait Millikernel {
    /// Processes the `nrows x ncols` block whose top-left corner is `pos`.
    ///
    /// Null `lhs`/`rhs` mean "operand already packed — read the packed
    /// buffer instead".
    ///
    /// # Safety
    /// All pointers and the strides in `milli` must describe memory valid
    /// for the whole block, and `microkernel`/`pack` must follow the custom
    /// register conventions of [`call_microkernel`] / [`pack_lhs`].
    unsafe fn call(
        &mut self,

        microkernel: unsafe extern "C" fn(),
        pack: unsafe extern "C" fn(),

        lhs: *const (),
        packed_lhs: *mut (),

        rhs: *const (),
        packed_rhs: *mut (),

        nrows: usize,
        ncols: usize,

        milli: &MillikernelInfo,

        pos: Position,
    );
}
708
/// Sequential [`Millikernel`] implementation.
struct Milli {
    // Microkernel tile rows.
    mr: usize,
    // Microkernel tile columns.
    nr: usize,
    // Element size in bytes (4, 8, or 16).
    sizeof: usize,
}
/// Threaded [`Millikernel`] implementation (spindle-driven work queue).
#[cfg(feature = "rayon")]
struct MilliPar {
    // Microkernel tile rows.
    mr: usize,
    // Microkernel tile columns.
    nr: usize,
    // NOTE(review): forwarded to `millikernel_par` but unused there — confirm.
    hyper: usize,
    // Element size in bytes (4, 8, or 16).
    sizeof: usize,

    // NOTE(review): reset each call but never read — confirm.
    microkernel_job: Box<[AtomicU8]>,
    // Per-row-block lhs packing progress flags (2 = packed copy ready).
    pack_lhs_job: Box<[AtomicU8]>,
    // Per-column-block rhs packing progress flags (2 = packed copy ready).
    pack_rhs_job: Box<[AtomicU8]>,
    // NOTE(review): reset each call but never read — confirm.
    finished: AtomicUsize,
    // Worker count used for packing splits and blocking heuristics.
    n_threads: usize,
}
727
728unsafe impl Millikernel for Milli {
729 unsafe fn call(
730 &mut self,
731
732 microkernel: unsafe extern "C" fn(),
733 pack: unsafe extern "C" fn(),
734
735 lhs: *const (),
736 packed_lhs: *mut (),
737
738 rhs: *const (),
739 packed_rhs: *mut (),
740
741 nrows: usize,
742 ncols: usize,
743
744 milli: &MillikernelInfo,
745 pos: Position,
746 ) {
747 unsafe {
748 (if milli.micro.flags >> 63 == 1 {
749 millikernel_rowmajor
750 } else {
751 millikernel_colmajor
752 })(
753 microkernel,
754 pack,
755 self.mr,
756 self.nr,
757 self.sizeof,
758 lhs,
759 packed_lhs,
760 rhs,
761 packed_rhs,
762 nrows,
763 ncols,
764 milli,
765 &mut { pos },
766 )
767 }
768 }
769}
770
/// Wrapper that unconditionally implements `Send`/`Sync`, used to move raw
/// pointers and references into the parallel closures below.
#[derive(Copy, Clone)]
pub struct ForceSync<T>(pub T);
// SAFETY: the wrapper proves nothing by itself — each use site must ensure
// that concurrent access through the wrapped value is actually sound.
unsafe impl<T> Sync for ForceSync<T> {}
unsafe impl<T> Send for ForceSync<T> {}
775
#[cfg(feature = "rayon")]
unsafe impl Millikernel for MilliPar {
    /// Threaded block driver: resets the cross-thread packing flags, picks
    /// `mf`/`nf` super-tile factors from the block shape and L3 size,
    /// optionally pre-packs the rhs across all workers, then drains a global
    /// work queue of `millikernel_par` items.
    unsafe fn call(
        &mut self,

        microkernel: unsafe extern "C" fn(),
        pack: unsafe extern "C" fn(),

        lhs: *const (),
        packed_lhs: *mut (),

        rhs: *const (),
        packed_rhs: *mut (),

        nrows: usize,
        ncols: usize,

        milli: &MillikernelInfo,
        pos: Position,
    ) {
        // Raw pointers must be smuggled into the worker closures.
        let lhs = ForceSync(lhs);
        let mut rhs = ForceSync(rhs);
        let packed_lhs = ForceSync(packed_lhs);
        let packed_rhs = ForceSync(packed_rhs);
        let milli = ForceSync(milli);

        // Reset cross-thread progress flags for this block.
        self.microkernel_job.fill_with(|| AtomicU8::new(0));
        self.pack_lhs_job.fill_with(|| AtomicU8::new(0));
        self.pack_rhs_job.fill_with(|| AtomicU8::new(0));
        self.finished = AtomicUsize::new(0);

        // Depth-dependent shrink factor for the effective L3 budget.
        let f = Ord::min(8, milli.0.micro.depth.div_ceil(64));
        let l3 = CACHE_INFO[2].cache_bytes / f;

        let tall = nrows >= l3;
        let wide = ncols >= 2 * nrows;

        // Super-tile factors: mf*mr rows x nf*nr cols per work item.
        let mut mf = Ord::clamp(nrows.div_ceil(self.mr).div_ceil(2 * self.n_threads), 2, 4);
        if tall {
            mf = 16 / f;
        }
        if wide {
            mf = 2;
        }
        let par_rows = nrows.div_ceil(mf * self.mr);
        let nf = Ord::clamp(ncols.div_ceil(self.nr).div_ceil(8 * self.n_threads) * par_rows, 1, 1024 / f);
        // NOTE(review): shadows the heuristic `nf` computed just above —
        // looks like a tuning override; confirm whether intentional.
        let nf = 32 / self.nr;

        // Total number of work items.
        let n = nrows.div_ceil(mf * self.mr) * ncols.div_ceil(nf * self.nr);

        let mr = self.mr;
        let nr = self.nr;

        // Pre-pack the whole rhs cooperatively (split along the depth axis)
        // unless the block is wide, in which case packing is left to the
        // per-tile workers.
        if !rhs.0.is_null() && rhs.0 != packed_rhs.0 {
            let depth = { milli }.0.micro.depth;

            // Split `depth` into n_threads contiguous ranges, the first
            // `rem` of them one element longer.
            let div = depth / self.n_threads;
            let rem = depth % self.n_threads;

            if !wide {
                spindle::for_each_raw(self.n_threads, |j| {
                    let mut start = j * div;
                    if j <= rem {
                        start += j;
                    } else {
                        start += rem;
                    }
                    let end = start + div + if j < rem { 1 } else { 0 };
                    let milli = { milli }.0;

                    for i in 0..ncols.div_ceil(nr) {
                        let col = Ord::min(ncols, i * nr);
                        let ncols = Ord::min(ncols - col, nr);

                        // Packed rows are `ncols` elements wide.
                        let rs = ncols;
                        let rhs = { rhs }.0.wrapping_byte_offset(milli.rhs_cs * i as isize);
                        let packed_rhs = { packed_rhs }.0.wrapping_byte_offset(milli.packed_rhs_cs * i as isize);

                        pack_rhs(
                            packed_rhs.wrapping_byte_offset((start * rs * self.sizeof) as isize),
                            rhs.wrapping_byte_offset(start as isize * milli.micro.rhs_rs),
                            end - start,
                            ncols,
                            milli.micro.rhs_rs,
                            milli.micro.rhs_cs,
                            self.sizeof,
                        );
                    }
                });
                // Everything is packed; workers read the packed buffer.
                rhs.0 = null();
            }
        }

        // Global work queue: workers claim item ids until exhausted.
        let gtid = AtomicUsize::new(0);

        spindle::for_each_raw(self.n_threads, |_| unsafe {
            loop {
                let tid = gtid.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
                if tid >= n {
                    return;
                }
                let milli = { milli }.0;

                millikernel_par(
                    tid,
                    n,
                    &self.microkernel_job,
                    &self.pack_lhs_job,
                    &self.pack_rhs_job,
                    &self.finished,
                    self.hyper,
                    self.mr,
                    self.nr,
                    self.sizeof,
                    mf,
                    nf,
                    microkernel,
                    pack,
                    { lhs }.0,
                    { packed_lhs }.0,
                    { rhs }.0,
                    { packed_rhs }.0,
                    nrows,
                    ncols,
                    milli,
                    pos,
                    tall,
                );
            }
        });
    }
}
908
/// Cache-blocking driver: walks a multi-level blocking tree over the
/// `nrows x ncols` problem, where level `d` splits the current block into
/// `row_chunk[d] x col_chunk[d]` children; the innermost level hands each
/// leaf block to `millikernel`.
///
/// The recursion is implemented iteratively with an explicit stack of up to
/// 16 frames. Each frame holds: operand pointers (lhs, packed_lhs, rhs,
/// packed_rhs), the block origin (row, col) and extent (nrows, ncols),
/// element/chunk counters (i, j, ii, jj), "operand already packed" flags,
/// and two traversal-reversal flags.
///
/// The per-level `lhs_rs`/`rhs_cs`/`packed_*` arrays give the byte strides
/// between sibling blocks; the last entries seed the `MillikernelInfo`
/// used at the leaves.
#[inline(never)]
unsafe fn kernel_imp(
    millikernel: &mut dyn Millikernel,

    microkernel: &[unsafe extern "C" fn()],
    pack: &[unsafe extern "C" fn()],

    mr: usize,
    nr: usize,

    lhs: *const (),
    packed_lhs: *mut (),

    rhs: *const (),
    packed_rhs: *mut (),

    nrows: usize,
    ncols: usize,

    row_chunk: &[usize],
    col_chunk: &[usize],
    lhs_rs: &[isize],
    rhs_cs: &[isize],
    packed_lhs_rs: &[isize],
    packed_rhs_cs: &[isize],

    row: usize,
    col: usize,

    pos: Position,
    info: &MicrokernelInfo,
) {
    let _ = mr;

    // Explicit recursion stack; see the doc comment for the field meanings.
    let mut stack: [(
        *const (),
        *mut (),
        *const (),
        *mut (),
        usize,
        usize,
        usize,
        usize,
        usize,
        usize,
        usize,
        usize,
        bool,
        bool,
        bool,
        bool,
    ); 16] = const { [(null(), null_mut(), null(), null_mut(), 0, 0, 0, 0, 0, 0, 0, 0, false, false, false, false); 16] };

    stack[0] = (
        lhs, packed_lhs, rhs, packed_rhs, row, col, nrows, ncols, 0, 0, 0, 0, false, false, false, false,
    );

    let mut pos = pos;
    let mut depth = 0;
    let max_depth = row_chunk.len();

    let milli_rs = *lhs_rs.last().unwrap();
    let milli_cs = *rhs_cs.last().unwrap();

    let micro_rs = info.lhs_rs;
    let micro_cs = info.rhs_cs;

    // Leaf-level stride info shared by every millikernel call.
    let milli = MillikernelInfo {
        lhs_rs: milli_rs,
        packed_lhs_rs: *packed_lhs_rs.last().unwrap(),
        rhs_cs: milli_cs,
        packed_rhs_cs: *packed_rhs_cs.last().unwrap(),
        micro: *info,
    };
    let microkernel = microkernel[nr - 1];
    let pack = pack[0];

    // Drop the leaf level from the chunk arrays; keep full stride arrays.
    let q = row_chunk.len();
    let row_chunk = &row_chunk[..q - 1];
    let col_chunk = &col_chunk[..q - 1];
    let lhs_rs = &lhs_rs[..q];
    let packed_lhs_rs = &packed_lhs_rs[..q];
    let rhs_cs = &rhs_cs[..q];
    let packed_rhs_cs = &packed_rhs_cs[..q];

    loop {
        let (lhs, packed_lhs, rhs, packed_rhs, row, col, nrows, ncols, i, j, ii, jj, is_packed_lhs, is_packed_rhs, row_rev, col_rev) = stack[depth];
        // NOTE(review): the frame's reversal flags are immediately shadowed
        // — the boustrophedon traversal below is currently disabled.
        let row_rev = false;
        let col_rev = false;

        if depth + 1 == max_depth {
            // Leaf: run the millikernel on this block.
            let mut lhs = lhs;
            let mut rhs = rhs;

            pos.row = row;
            pos.col = col;

            // Null pointers tell the millikernel the operand is pre-packed.
            if is_packed_lhs && lhs != packed_lhs {
                lhs = null();
            }
            if is_packed_rhs && rhs != packed_rhs {
                rhs = null();
            }

            unsafe {
                millikernel.call(microkernel, pack, lhs, packed_lhs, rhs, packed_rhs, nrows, ncols, &milli, pos);
            }

            // Unwind: advance the parent frames' counters; column-major or
            // row-major sibling order per FLAGS_ROWMAJOR.
            while depth > 0 {
                depth -= 1;

                let (_, _, _, _, _, _, nrows, ncols, i, j, ii, jj, _, _, _, _) = &mut stack[depth];

                let col_chunk = col_chunk[depth];
                let row_chunk = row_chunk[depth];

                let j_chunk = Ord::min(col_chunk, *ncols - *j);
                let i_chunk = Ord::min(row_chunk, *nrows - *i);

                if milli.micro.flags & FLAGS_ROWMAJOR == 0 {
                    // Column-major: advance rows first, then columns.
                    *i += i_chunk;
                    *ii += 1;
                    if *i == *nrows {
                        *i = 0;
                        *ii = 0;
                        *j += j_chunk;
                        *jj += 1;

                        if *j == *ncols {
                            if depth == 0 {
                                return;
                            }

                            *j = 0;
                            *jj = 0;
                            continue;
                        }
                    }
                } else {
                    // Row-major: advance columns first, then rows.
                    *j += j_chunk;
                    *jj += 1;
                    if *j == *ncols {
                        *j = 0;
                        *jj = 0;
                        *i += i_chunk;
                        *ii += 1;

                        if *i == *nrows {
                            *i = 0;
                            *ii = 0;
                            if depth == 0 {
                                return;
                            }
                            continue;
                        }
                    }
                }
                break;
            }
        } else {
            // Interior: compute the child block at (i, j) and push a frame.
            let col_chunk = col_chunk[depth];
            let row_chunk = row_chunk[depth];
            let rhs_cs = rhs_cs[depth];
            let lhs_rs = lhs_rs[depth];
            let prhs_cs = packed_rhs_cs[depth];
            let plhs_rs = packed_lhs_rs[depth];

            // Size of the final (possibly partial) chunk in each dimension.
            let last_row_chunk = if nrows == 0 { 0 } else { ((nrows - 1) % row_chunk) + 1 };

            let last_col_chunk = if ncols == 0 { 0 } else { ((ncols - 1) % col_chunk) + 1 };

            // Mirror the iteration order when reversal is enabled (currently
            // never — see the NOTE above).
            let (i, ii) = if row_rev {
                (nrows - last_row_chunk - i, nrows.div_ceil(row_chunk) - 1 - ii)
            } else {
                (i, ii)
            };

            let (j, jj) = if col_rev {
                (ncols - last_col_chunk - j, ncols.div_ceil(col_chunk) - 1 - jj)
            } else {
                (j, jj)
            };
            assert!(i as isize >= 0);
            assert!(j as isize >= 0);

            let j_chunk = Ord::min(col_chunk, ncols - j);
            let i_chunk = Ord::min(row_chunk, nrows - i);

            depth += 1;
            stack[depth] = (
                lhs.wrapping_byte_offset(lhs_rs * ii as isize),
                packed_lhs.wrapping_byte_offset(plhs_rs * ii as isize),
                rhs.wrapping_byte_offset(rhs_cs * jj as isize),
                packed_rhs.wrapping_byte_offset(prhs_cs * jj as isize),
                row + i,
                col + j,
                i_chunk,
                j_chunk,
                0,
                0,
                0,
                0,
                // Once a sibling pass has packed the operand, children can
                // reuse the packed copy.
                is_packed_lhs || (j > 0 && packed_lhs_rs[depth - 1] != 0),
                is_packed_rhs || (i > 0 && packed_rhs_cs[depth - 1] != 0),
                jj % 2 == 1,
                ii % 2 == 1,
            );
            continue;
        }
    }
}
1120
/// SIMD instruction set the generated kernels target.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum InstrSet {
    /// 256-bit AVX kernels.
    Avx256,
    /// 512-bit AVX-512 kernels.
    Avx512,
}
1126
/// Scalar element type of the matrices.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DType {
    /// Real single precision (4 bytes).
    F32,
    /// Real double precision (8 bytes).
    F64,
    /// Complex single precision (8 bytes).
    C32,
    /// Complex double precision (16 bytes).
    C64,
}
1134
/// How results combine with the existing destination contents.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Accum {
    /// Overwrite the destination.
    Replace,
    /// Accumulate into the destination (sets `FLAGS_ACCUM`).
    Add,
}
1140
/// Integer width of the destination row/column index arrays.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum IType {
    /// 32-bit indices (sets `FLAGS_32BIT_IDX`).
    U32,
    /// 64-bit indices.
    U64,
}
1146
/// Which part of the destination matrix is written.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DstKind {
    /// Only the lower triangle (sets `FLAGS_LOWER`).
    Lower,
    /// Only the upper triangle (sets `FLAGS_UPPER`).
    Upper,
    /// The full rectangle.
    Full,
}
1153
1154pub unsafe fn gemm(
1155 dtype: DType,
1156 itype: IType,
1157
1158 instr: InstrSet,
1159 nrows: usize,
1160 ncols: usize,
1161 depth: usize,
1162
1163 dst: *mut (),
1164 dst_rs: isize,
1165 dst_cs: isize,
1166 dst_row_idx: *const (),
1167 dst_col_idx: *const (),
1168 dst_kind: DstKind,
1169
1170 beta: Accum,
1171
1172 lhs: *const (),
1173 lhs_rs: isize,
1174 lhs_cs: isize,
1175 conj_lhs: bool,
1176
1177 real_diag: *const (),
1178 diag_stride: isize,
1179
1180 rhs: *const (),
1181 rhs_rs: isize,
1182 rhs_cs: isize,
1183 conj_rhs: bool,
1184
1185 alpha: *const (),
1186
1187 n_threads: usize,
1188) {
1189 let (sizeof, cplx) = match dtype {
1190 DType::F32 => (4, false),
1191 DType::F64 => (8, false),
1192 DType::C32 => (8, true),
1193 DType::C64 => (16, true),
1194 };
1195 let mut lhs_rs = lhs_rs * sizeof as isize;
1196 let mut lhs_cs = lhs_cs * sizeof as isize;
1197 let mut rhs_rs = rhs_rs * sizeof as isize;
1198 let mut rhs_cs = rhs_cs * sizeof as isize;
1199 let mut dst_rs = dst_rs * sizeof as isize;
1200 let mut dst_cs = dst_cs * sizeof as isize;
1201 let real_diag_stride = diag_stride * sizeof as isize;
1202
1203 if nrows == 0 || ncols == 0 || (depth == 0 && beta == Accum::Add) {
1204 return;
1205 }
1206
1207 let mut nrows = nrows;
1208 let mut ncols = ncols;
1209
1210 let mut dst = dst;
1211 let mut dst_row_idx = dst_row_idx;
1212 let mut dst_col_idx = dst_col_idx;
1213 let mut dst_kind = dst_kind;
1214
1215 let mut lhs = lhs;
1216 let mut conj_lhs = conj_lhs;
1217
1218 let mut rhs = rhs;
1219 let mut conj_rhs = conj_rhs;
1220
1221 if dst_rs.unsigned_abs() > dst_cs.unsigned_abs() {
1222 use core::mem::swap;
1223 swap(&mut dst_rs, &mut dst_cs);
1224 swap(&mut dst_row_idx, &mut dst_col_idx);
1225 dst_kind = match dst_kind {
1226 DstKind::Lower => DstKind::Upper,
1227 DstKind::Upper => DstKind::Lower,
1228 DstKind::Full => DstKind::Full,
1229 };
1230 swap(&mut lhs, &mut rhs);
1231 swap(&mut lhs_rs, &mut rhs_cs);
1232 swap(&mut lhs_cs, &mut rhs_rs);
1233 swap(&mut conj_lhs, &mut conj_rhs);
1234 swap(&mut nrows, &mut ncols);
1235 }
1236
1237 if dst_rs < 0 && dst_kind == DstKind::Full && dst_row_idx.is_null() {
1238 dst = dst.wrapping_byte_offset((nrows - 1) as isize * dst_rs);
1239 lhs = lhs.wrapping_byte_offset((nrows - 1) as isize * lhs_rs);
1240 dst_rs = -dst_rs;
1241 lhs_rs = -lhs_rs;
1242 }
1243
1244 if lhs_cs < 0 && depth > 0 {
1245 lhs = lhs.wrapping_byte_offset((depth - 1) as isize * lhs_cs);
1246 rhs = rhs.wrapping_byte_offset((depth - 1) as isize * rhs_rs);
1247
1248 lhs_cs = -lhs_cs;
1249 rhs_rs = -rhs_rs;
1250 }
1251
1252 let (microkernel, pack, mr, nr) = match (instr, dtype) {
1253 (InstrSet::Avx256, DType::F32) => (F32_SIMD256.as_slice(), F32_SIMDpack_256.as_slice(), 24, 4),
1254 (InstrSet::Avx256, DType::F64) => (F64_SIMD256.as_slice(), F64_SIMDpack_256.as_slice(), 12, 4),
1255 (InstrSet::Avx256, DType::C32) => (C32_SIMD256.as_slice(), C32_SIMDpack_256.as_slice(), 12, 4),
1256 (InstrSet::Avx256, DType::C64) => (C64_SIMD256.as_slice(), C64_SIMDpack_256.as_slice(), 6, 4),
1257 (InstrSet::Avx512, DType::F32) => (F32_SIMD512x4.as_slice(), F32_SIMDpack_512.as_slice(), 96, 4),
1258 (InstrSet::Avx512, DType::F64) => {
1259 if nrows > 48 {
1260 (F64_SIMD512x4.as_slice(), F64_SIMDpack_512.as_slice(), 48, 4)
1261 } else {
1262 (F64_SIMD512x8.as_slice(), F64_SIMDpack_512.as_slice(), 24, 8)
1263 }
1264 },
1265 (InstrSet::Avx512, DType::C32) => (C32_SIMD512x4.as_slice(), C32_SIMDpack_512.as_slice(), 48, 4),
1266 (InstrSet::Avx512, DType::C64) => (C64_SIMD512x4.as_slice(), C64_SIMDpack_512.as_slice(), 24, 4),
1267 };
1268
1269 let m = nrows;
1270 let n = ncols;
1271
1272 let kc = Ord::min(depth, 512);
1273
1274 let cache = *cache::CACHE_INFO;
1275
1276 let l1 = cache[0].cache_bytes / sizeof;
1277 let l2 = cache[1].cache_bytes / sizeof;
1278 let l3 = cache[2].cache_bytes / sizeof;
1279
1280 #[repr(align(4096))]
1281 struct Page([u8; 4096]);
1282
1283 let lhs_size = (l3.next_multiple_of(16) * sizeof).div_ceil(size_of::<Page>());
1284 let rhs_size = (l3.next_multiple_of(nr) * sizeof).div_ceil(size_of::<Page>());
1285
1286 thread_local! {
1287 static MEM: RefCell<Vec::<core::mem::MaybeUninit<Page>>> = {
1288 let cache = *cache::CACHE_INFO;
1289 let l3 = cache[2].cache_bytes;
1290
1291 let lhs_size = l3.div_ceil(size_of::<Page>());
1292 let rhs_size = l3.div_ceil(size_of::<Page>());
1293
1294 let mut mem = Vec::with_capacity(lhs_size + rhs_size);
1295 unsafe { mem.set_len(lhs_size + rhs_size) };
1296 RefCell::new(mem)
1297 };
1298 }
1299
1300 MEM.with(|mem| {
1301 let mut storage;
1302 let mut alloc;
1303
1304 let mem = match mem.try_borrow_mut() {
1305 Ok(mem) => {
1306 storage = mem;
1307 &mut *storage
1308 },
1309 Err(_) => {
1310 alloc = Vec::with_capacity(lhs_size + rhs_size);
1311
1312 &mut alloc
1313 },
1314 };
1315 if mem.len() < lhs_size + rhs_size {
1316 mem.reserve_exact(lhs_size + rhs_size);
1317 unsafe { mem.set_len(lhs_size + rhs_size) };
1318 }
1319
1320 let (packed_lhs, packed_rhs) = mem.split_at_mut(lhs_size);
1321 let (packed_rhs, _) = packed_rhs.split_at_mut(rhs_size);
1322
1323 let lhs = ForceSync(lhs);
1324 let rhs = ForceSync(rhs);
1325 let dst = ForceSync(dst);
1326 let real_diag = ForceSync(real_diag);
1327 let dst_row_idx = ForceSync(dst_row_idx);
1328 let dst_col_idx = ForceSync(dst_col_idx);
1329 let alpha = ForceSync(alpha);
1330 let mut f = || {
1331 let mut k = 0;
1332 let mut beta = beta;
1333 let mut lhs = { lhs }.0;
1334 let mut rhs = { rhs }.0;
1335 let mut real_diag = { real_diag }.0;
1336 let dst = { dst }.0;
1337 while k < depth {
1338 let kc = Ord::min(depth - k, kc);
1339
1340 let f = kc.div_ceil(64);
1341 let l1 = l1 / 64 / f;
1342 let l2 = l2 / 64 / f;
1343 let l3 = l3 / 64 / f;
1344
1345 let tall = m >= 3 * n / 2 && m >= l3;
1346 let pack_lhs = !real_diag.is_null() || (n > 6 * nr && tall) || (n > 3 * nr * n_threads) || lhs_rs != sizeof as isize;
1347 let pack_rhs = tall;
1348
1349 let rowmajor = if n_threads > 1 {
1350 false
1351 } else if tall {
1352 true
1353 } else {
1354 false
1355 };
1356
1357 let info = MicrokernelInfo {
1358 flags: match beta {
1359 Accum::Replace => 0,
1360 Accum::Add => FLAGS_ACCUM,
1361 } | if conj_lhs { FLAGS_CONJ_LHS } else { 0 }
1362 | if conj_lhs != conj_rhs { FLAGS_CONJ_NEQ } else { 0 }
1363 | match itype {
1364 IType::U32 => FLAGS_32BIT_IDX,
1365 IType::U64 => 0,
1366 } | if cplx { FLAGS_CPLX } else { 0 }
1367 | match dst_kind {
1368 DstKind::Lower => FLAGS_LOWER,
1369 DstKind::Upper => FLAGS_UPPER,
1370 DstKind::Full => 0,
1371 } | if rowmajor { FLAGS_ROWMAJOR } else { 0 },
1372 depth: kc,
1373 lhs_rs,
1374 lhs_cs,
1375 rhs_rs,
1376 rhs_cs,
1377 alpha: { alpha }.0,
1378 ptr: dst,
1379 rs: dst_rs,
1380 cs: dst_cs,
1381 row_idx: { dst_row_idx }.0,
1382 col_idx: { dst_col_idx }.0,
1383 diag_ptr: real_diag,
1384 diag_stride: real_diag_stride,
1385 };
1386
1387 if n_threads <= 1 && !rowmajor && m < l2 && n < l2 {
1388 let microkernel = microkernel[nr - 1];
1389 let pack = pack[0];
1390 millikernel_colmajor(
1391 microkernel,
1392 pack,
1393 mr,
1394 nr,
1395 sizeof,
1396 lhs,
1397 if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs as _ },
1398 rhs,
1399 if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs as _ },
1400 nrows,
1401 ncols,
1402 &MillikernelInfo {
1403 lhs_rs: lhs_rs * mr as isize,
1404 packed_lhs_rs: if pack_lhs { (sizeof * mr * kc) as isize } else { lhs_rs * mr as isize },
1405 rhs_cs: rhs_cs * nr as isize,
1406 packed_rhs_cs: if pack_rhs { (sizeof * nr * kc) as isize } else { rhs_cs * nr as isize },
1407 micro: info,
1408 },
1409 &mut Position { row: 0, col: 0 },
1410 );
1411 } else {
1412 let (row_chunk, col_chunk, rowmajor) = if n_threads > 1 {
1413 (
1414 [m, m, m, l3 / 16 * 16, mr],
1416 [n, n, n, l3 / 16 * 16, nr],
1417 false,
1418 )
1419 } else if true {
1420 (
1421 [m, l3, l2, l2 / 2, mr],
1423 [n, 2 * l3, l3 / 2, l2, nr],
1424 true,
1425 )
1426 } else {
1427 (
1428 [2 * l3, l3 / 2, l3 / 2, l2, mr],
1430 [l3, l3 / 2, l2 / 2, l1, nr],
1431 false,
1432 )
1433 };
1434
1435 let mut row_chunk = row_chunk.map(|r| if r == mr { mr } else { r.next_multiple_of(16) });
1436 let mut col_chunk = col_chunk.map(|c| c.next_multiple_of(nr));
1437
1438 let q = row_chunk.len();
1439 {
1440 for i in (1..q - 1).rev() {
1441 row_chunk[i - 1] = Ord::max(row_chunk[i - 1].next_multiple_of(row_chunk[i]), row_chunk[i]);
1442 if row_chunk[i - 1] > l3 / 2 && row_chunk[i - 1] < l3 {
1443 row_chunk[i - 1] = l3 / 2;
1444 }
1445 if row_chunk[i - 1] >= l3 {
1446 row_chunk[i - 1] = Ord::min(row_chunk[i - 1], 2 * row_chunk[i]);
1447 }
1448 }
1449 for i in (1..q - 1).rev() {
1450 col_chunk[i - 1] = Ord::max(col_chunk[i - 1].next_multiple_of(col_chunk[i]), col_chunk[i]);
1451 if col_chunk[i - 1] > l3 / 2 && col_chunk[i - 1] < l3 {
1452 col_chunk[i - 1] = l3 / 2;
1453 }
1454 if col_chunk[i - 1] >= l3 {
1455 col_chunk[i - 1] = Ord::min(col_chunk[i - 1], 2 * col_chunk[i]);
1456 }
1457 }
1458 }
1459
1460 let all_lhs_rs = row_chunk.map(|m| m as isize * lhs_rs);
1461 let all_rhs_cs = col_chunk.map(|n| n as isize * rhs_cs);
1462
1463 let mut packed_lhs_rs = row_chunk.map(|x| if x > l3 / 2 { 0 } else { (x * kc * sizeof) as isize });
1464 let mut packed_rhs_cs = col_chunk.map(|x| if x > l3 / 2 { 0 } else { (x * kc * sizeof) as isize });
1465 packed_lhs_rs[0] = 0;
1466 packed_rhs_cs[0] = 0;
1467
1468 assert!(lhs_size * size_of::<Page>() >= row_chunk[q - 2] * kc * sizeof);
1469 assert!(rhs_size * size_of::<Page>() >= col_chunk[q - 2] * kc * sizeof);
1470
1471 unsafe {
1472 kernel(
1473 n_threads,
1474 microkernel,
1475 pack,
1476 mr,
1477 nr,
1478 sizeof,
1479 lhs,
1480 if pack_lhs { packed_lhs.as_mut_ptr() as *mut () } else { lhs as *mut () },
1481 rhs,
1482 if pack_rhs { packed_rhs.as_mut_ptr() as *mut () } else { rhs as *mut () },
1483 nrows,
1484 ncols,
1485 &row_chunk,
1486 &col_chunk,
1487 &all_lhs_rs,
1488 &all_rhs_cs,
1489 if pack_lhs { &packed_lhs_rs } else { &all_lhs_rs },
1490 if pack_rhs { &packed_rhs_cs } else { &all_rhs_cs },
1491 0,
1492 0,
1493 Position { row: 0, col: 0 },
1494 &info,
1495 )
1496 };
1497 }
1498
1499 k += kc;
1500 lhs = lhs.wrapping_byte_offset(lhs_cs * kc as isize);
1501 rhs = rhs.wrapping_byte_offset(rhs_rs * kc as isize);
1502 real_diag = real_diag.wrapping_byte_offset(real_diag_stride * kc as isize);
1503
1504 beta = Accum::Add;
1505 }
1506 };
1507 if n_threads <= 1 {
1508 f();
1509 } else {
1510 #[cfg(feature = "rayon")]
1511 spindle::with_lock(n_threads, f);
1512
1513 #[cfg(not(feature = "rayon"))]
1514 f();
1515 }
1516 });
1517}
1518
/// Top-level blocked-GEMM driver: selects a scheduler and forwards everything
/// to `kernel_imp`.
///
/// With the `rayon` feature enabled and `n_threads > 1`, a parallel `MilliPar`
/// scheduler is built, holding one atomic job slot per (`mr` x `nr`) output
/// tile plus one per LHS/RHS packing job; otherwise the sequential `Milli`
/// scheduler is used.
///
/// * `row_chunk`/`col_chunk` — blocking sizes per recursion level, outermost
///   level first.
/// * `lhs_rs`/`rhs_cs` and `packed_lhs_rs`/`packed_rhs_cs` — per-level byte
///   strides for advancing through the (packed) operands, parallel to the
///   chunk arrays.
/// * `row`/`col`, `pos` — starting offsets within the output.
///
/// # Safety
/// All pointers must be valid for the extents and byte strides described by
/// `info` and the chunk/stride tables, and the `microkernel`/`pack` function
/// pointers must match the element size (`sizeof`) and tile shape (`mr`, `nr`).
pub unsafe fn kernel(
    n_threads: usize,
    microkernel: &[unsafe extern "C" fn()],
    pack: &[unsafe extern "C" fn()],

    mr: usize,
    nr: usize,
    sizeof: usize,

    lhs: *const (),
    packed_lhs: *mut (),

    rhs: *const (),
    packed_rhs: *mut (),

    nrows: usize,
    ncols: usize,

    row_chunk: &[usize],
    col_chunk: &[usize],
    lhs_rs: &[isize],
    rhs_cs: &[isize],
    packed_lhs_rs: &[isize],
    packed_rhs_cs: &[isize],

    row: usize,
    col: usize,

    pos: Position,
    info: &MicrokernelInfo,
) {
    unsafe {
        // Sequential scheduler; always constructed so the `else` branch below
        // can hand out a borrow even when `rayon` is enabled.
        let mut seq = Milli { mr, nr, sizeof };
        // Declared (uninitialized) outside the `if` so the borrow handed to
        // `kernel_imp` outlives the branch that initializes it.
        #[cfg(feature = "rayon")]
        let mut par;
        kernel_imp(
            #[cfg(feature = "rayon")]
            if n_threads > 1 {
                par = {
                    // Tile grid dimensions, rounded up to whole tiles.
                    let max_i = nrows.div_ceil(mr);
                    let max_j = ncols.div_ceil(nr);
                    let max_jobs = max_i * max_j;
                    let c = max_i;

                    MilliPar {
                        mr,
                        nr,
                        sizeof,
                        hyper: 1,
                        // `c * max_j == max_jobs`: one claim slot per output tile.
                        microkernel_job: (0..c * max_j).map(|_| AtomicU8::new(0)).collect(),
                        pack_lhs_job: (0..max_i).map(|_| AtomicU8::new(0)).collect(),
                        pack_rhs_job: (0..max_j).map(|_| AtomicU8::new(0)).collect(),
                        finished: AtomicUsize::new(0),
                        n_threads,
                    }
                };
                &mut par
            } else {
                &mut seq
            },
            #[cfg(not(feature = "rayon"))]
            &mut seq,
            microkernel,
            pack,
            mr,
            nr,
            lhs,
            packed_lhs,
            rhs,
            packed_rhs,
            nrows,
            ncols,
            row_chunk,
            col_chunk,
            lhs_rs,
            rhs_cs,
            packed_lhs_rs,
            packed_rhs_cs,
            row,
            col,
            pos,
            info,
        )
    };
}
1604
// Tests for the f64 kernels: each test computes a scalar reference product
// (column-major, `dst += alpha * lhs @ rhs`) and compares it against the
// SIMD millikernel / blocked-kernel / dispatching-gemm entry points.
#[cfg(test)]
mod tests_f64 {
    use core::ptr::null_mut;

    use super::*;

    use aligned_vec::*;
    use rand::prelude::*;

    // Exhaustive small-size sweep of the AVX-512 f64 millikernel
    // (mr = 48, nr = 4), with and without LHS packing.
    #[test]
    fn test_avx512_microkernel() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<f64>() as isize;
        let len = 64 / size_of::<f64>();

        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                for alpha in [1.0.into(), 0.0.into(), 2.5.into()] {
                    let alpha: f64 = alpha;
                    for m in 1..=48usize {
                        // n = 5 exercises the wrap past a full nr = 4 column block.
                        for n in (1..=4usize).chain([5]) {
                            for cs in [m.next_multiple_of(48)] {
                                let acs = m.next_multiple_of(48);
                                let k = 2usize;

                                let packed_lhs: &mut [f64] = &mut *avec![0.0.into(); acs * k];
                                let packed_rhs: &mut [f64] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                let lhs: &mut [f64] = &mut *avec![0.0.into(); cs * k];
                                let rhs: &mut [f64] = &mut *avec![0.0.into(); n * k];
                                let dst: &mut [f64] = &mut *avec![0.0.into(); cs * n];
                                let target = &mut *avec![0.0.into(); cs * n];

                                rng.fill(lhs);
                                rng.fill(rhs);

                                // Scalar reference: target[i,j] += alpha * sum_d lhs[i,d] * rhs[d,j].
                                for i in 0..m {
                                    for j in 0..n {
                                        let target = &mut target[i + cs * j];
                                        let mut acc = 0.0.into();
                                        for depth in 0..k {
                                            acc = f64::mul_add(lhs[i + cs * depth], rhs[depth + k * j], acc);
                                        }
                                        *target = f64::mul_add(acc, alpha, *target);
                                    }
                                }

                                unsafe {
                                    millikernel_colmajor(
                                        F64_SIMD512x4[3],
                                        F64_SIMDpack_512[0],
                                        48, // mr
                                        4,  // nr
                                        8,  // sizeof(f64)
                                        lhs.as_ptr() as _,
                                        if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                                        rhs.as_ptr() as _,
                                        if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                                        m,
                                        n,
                                        &mut MillikernelInfo {
                                            // Byte strides between consecutive row/column blocks.
                                            lhs_rs: 48 * sizeof,
                                            packed_lhs_rs: if pack_lhs { 48 * sizeof * k as isize } else { 48 * sizeof },
                                            rhs_cs: 4 * sizeof * k as isize,
                                            packed_rhs_cs: 4 * sizeof * k as isize,
                                            micro: MicrokernelInfo {
                                                flags: 0,
                                                depth: k,
                                                lhs_rs: 1 * sizeof,
                                                lhs_cs: cs as isize * sizeof,
                                                rhs_rs: 1 * sizeof,
                                                rhs_cs: k as isize * sizeof,
                                                alpha: &raw const alpha as _,
                                                ptr: dst.as_mut_ptr() as _,
                                                rs: 1 * sizeof,
                                                cs: cs as isize * sizeof,
                                                row_idx: null_mut(),
                                                col_idx: null_mut(),
                                                diag_ptr: null(),
                                                diag_stride: 0,
                                            },
                                        },
                                        &mut Position { row: 0, col: 0 },
                                    )
                                };
                                // mul_add on both sides, so exact equality is expected.
                                assert_eq!(dst, target);
                            }
                        }
                    }
                }
            }
        }
    }

    // Sweep of the runtime-dispatching `gemm` entry point (AVX-256 and
    // AVX-512 instruction sets) against the scalar reference, including a
    // large 513x512x513 case.
    #[test]
    fn test_gemm() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<f64>() as isize;
        let len = 64 / size_of::<f64>();

        for instr in [InstrSet::Avx256, InstrSet::Avx512] {
            for pack_lhs in [false, true] {
                for pack_rhs in [false] {
                    for alpha in [1.0.into(), 0.0.into(), 2.5.into()] {
                        let alpha: f64 = alpha;
                        for m in (1..=48usize).chain([513]) {
                            for n in (1..=4usize).chain([512]) {
                                for cs in [m.next_multiple_of(48)] {
                                    let acs = m.next_multiple_of(48);
                                    let k = 513usize;

                                    let packed_lhs: &mut [f64] = &mut *avec![0.0.into(); acs * k];
                                    let packed_rhs: &mut [f64] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                    let lhs: &mut [f64] = &mut *avec![0.0.into(); cs * k];
                                    let rhs: &mut [f64] = &mut *avec![0.0.into(); n * k];
                                    let dst: &mut [f64] = &mut *avec![0.0.into(); cs * n];
                                    let target = &mut *avec![0.0.into(); cs * n];

                                    rng.fill(lhs);
                                    rng.fill(rhs);

                                    // Scalar reference accumulation.
                                    for i in 0..m {
                                        for j in 0..n {
                                            let target = &mut target[i + cs * j];
                                            let mut acc = 0.0.into();
                                            for depth in 0..k {
                                                acc = f64::mul_add(lhs[i + cs * depth], rhs[depth + k * j], acc);
                                            }
                                            *target = f64::mul_add(acc, alpha, *target);
                                        }
                                    }

                                    unsafe {
                                        gemm(
                                            DType::F64,
                                            IType::U64,
                                            instr,
                                            m,
                                            n,
                                            k,
                                            // dst: ptr, row stride (elements), col stride.
                                            dst.as_mut_ptr() as _,
                                            1,
                                            cs as isize,
                                            // no row/col index scatter tables.
                                            null(),
                                            null(),
                                            DstKind::Full,
                                            Accum::Add,
                                            lhs.as_ptr() as _,
                                            1,
                                            cs as isize,
                                            false, // conj_lhs
                                            // no diagonal scaling.
                                            null(),
                                            0,
                                            rhs.as_ptr() as _,
                                            1,
                                            k as isize,
                                            false, // conj_rhs
                                            &raw const alpha as _,
                                            1, // n_threads
                                        )
                                    };
                                    // Large k accumulates rounding differences; compare with tolerance.
                                    let mut i = 0;
                                    for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                                        if !((target - dst).abs() < 1e-6) {
                                            dbg!(i / cs, i % cs, target, dst);
                                            panic!();
                                        }
                                        i += 1;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    // Blocked `kernel` driver on a 1023x1023x5 problem, validated against the
    // external `gemm` crate as the reference implementation.
    #[test]
    fn test_avx512_kernel() {
        let m = 1023usize;
        let n = 1023usize;
        let k = 5usize;

        let rng = &mut StdRng::seed_from_u64(0);
        let sizeof = size_of::<f64>() as isize;
        let cs = m.next_multiple_of(8);
        // Oversized column stride exercises non-contiguous dst layouts.
        let cs = Ord::max(4096, cs);

        let lhs: &mut [f64] = &mut *avec![0.0; cs * k];
        let rhs: &mut [f64] = &mut *avec![0.0; k * n];
        let target: &mut [f64] = &mut *avec![0.0; cs * n];

        rng.fill(lhs);
        rng.fill(rhs);

        // Reference result from the `gemm` crate.
        unsafe {
            gemm::gemm(
                m,
                n,
                k,
                target.as_mut_ptr(),
                cs as isize,
                1,
                true,
                lhs.as_ptr(),
                cs as isize,
                1,
                rhs.as_ptr(),
                k as isize,
                1,
                1.0,
                1.0,
                false,
                false,
                false,
                gemm::Parallelism::None,
            );
        }

        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                let dst = &mut *avec![0.0; cs * n];
                let packed_lhs = &mut *avec![0.0f64; m.next_multiple_of(8) * k];
                let packed_rhs = &mut *avec![0.0; if pack_rhs { n.next_multiple_of(4) * k } else { 0 }];

                unsafe {
                    // Blocking hierarchy, outermost level first; innermost is
                    // one (48 x 4) tile.
                    let row_chunk = [48 * 32, 48 * 16, 48];
                    let col_chunk = [48 * 64, 48 * 32, 48, 4];

                    // Per-level byte strides matching the chunk tables.
                    let lhs_rs = row_chunk.map(|m| m as isize * sizeof);
                    let rhs_cs = col_chunk.map(|n| (n * k) as isize * sizeof);
                    let packed_lhs_rs = row_chunk.map(|m| (m * k) as isize * sizeof);
                    let packed_rhs_cs = col_chunk.map(|n| (n * k) as isize * sizeof);

                    kernel(
                        1, // single-threaded
                        &F64_SIMD512x4[..24],
                        &F64_SIMDpack_512,
                        48,
                        4,
                        size_of::<f64>(),
                        lhs.as_ptr() as _,
                        if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                        rhs.as_ptr() as _,
                        if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                        m,
                        n,
                        &row_chunk,
                        &col_chunk,
                        &lhs_rs,
                        &rhs_cs,
                        &if pack_lhs { packed_lhs_rs } else { lhs_rs },
                        &if pack_rhs { packed_rhs_cs } else { rhs_cs },
                        0,
                        0,
                        Position { row: 0, col: 0 },
                        &MicrokernelInfo {
                            flags: 0,
                            depth: k,
                            lhs_rs: sizeof,
                            lhs_cs: cs as isize * sizeof,
                            rhs_rs: sizeof,
                            rhs_cs: k as isize * sizeof,
                            alpha: &raw const *&1.0f64 as _,
                            ptr: dst.as_mut_ptr() as _,
                            rs: sizeof,
                            cs: cs as isize * sizeof,
                            row_idx: null_mut(),
                            col_idx: null_mut(),
                            diag_ptr: null(),
                            diag_stride: 0,
                        },
                    );
                }
                let mut i = 0;
                for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                    if !((target - dst).abs() < 1e-6) {
                        dbg!(i / cs, i % cs, target, dst);
                        panic!();
                    }
                    i += 1;
                }
            }
        }
    }
}
1892
// Tests for the c64 (complex double) AVX-512 millikernel, covering all four
// conjugation combinations against a scalar complex reference.
#[cfg(test)]
mod tests_c64 {
    use super::*;

    use aligned_vec::*;
    use bytemuck::*;
    use core::ptr::null_mut;
    use gemm::c64;
    use rand::prelude::*;

    #[test]
    fn test_avx512_microkernel() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<c64>() as isize;
        // c64 elements per 64-byte SIMD register/cache line.
        let len = 64 / size_of::<c64>();

        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                for alpha in [1.0.into(), 0.0.into(), c64::new(0.0, 3.5), c64::new(2.5, 3.5)] {
                    let alpha: c64 = alpha;
                    for m in 1..=24usize {
                        for n in (1..=4usize).into_iter().chain([8]) {
                            // Both padded and exact (unpadded) column strides.
                            for cs in [m.next_multiple_of(len), m] {
                                for conj_lhs in [false, true] {
                                    for conj_rhs in [false, true] {
                                        let conj_different = conj_lhs != conj_rhs;

                                        let acs = m.next_multiple_of(len);
                                        let k = 1usize;

                                        let packed_lhs: &mut [c64] = &mut *avec![0.0.into(); acs * k];
                                        let packed_rhs: &mut [c64] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                        let lhs: &mut [c64] = &mut *avec![0.0.into(); cs * k];
                                        let rhs: &mut [c64] = &mut *avec![0.0.into(); n * k];
                                        let dst: &mut [c64] = &mut *avec![0.0.into(); cs * n];
                                        let target: &mut [c64] = &mut *avec![0.0.into(); cs * n];

                                        // Fill the complex buffers through their f64 views.
                                        rng.fill(cast_slice_mut::<c64, f64>(lhs));
                                        rng.fill(cast_slice_mut::<c64, f64>(rhs));

                                        // Scalar reference with explicit conjugation handling.
                                        for i in 0..m {
                                            for j in 0..n {
                                                let target = &mut target[i + cs * j];
                                                let mut acc: c64 = 0.0.into();
                                                for depth in 0..k {
                                                    let mut l = lhs[i + cs * depth];
                                                    let mut r = rhs[depth + k * j];
                                                    if conj_lhs {
                                                        l = l.conj();
                                                    }
                                                    if conj_rhs {
                                                        r = r.conj();
                                                    }

                                                    acc = l * r + acc;
                                                }
                                                *target = acc * alpha + *target;
                                            }
                                        }

                                        unsafe {
                                            millikernel_colmajor(
                                                C64_SIMD512x4[3],
                                                C64_SIMDpack_512[0],
                                                24, // mr
                                                4,  // nr
                                                16, // sizeof(c64)
                                                lhs.as_ptr() as _,
                                                if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                                                rhs.as_ptr() as _,
                                                if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                                                m,
                                                n,
                                                &mut MillikernelInfo {
                                                    lhs_rs: 24 * sizeof,
                                                    packed_lhs_rs: 24 * sizeof * k as isize,
                                                    rhs_cs: 4 * sizeof * k as isize,
                                                    packed_rhs_cs: 4 * sizeof * k as isize,
                                                    micro: MicrokernelInfo {
                                                        // bit 1 = FLAGS_CONJ_LHS, bit 2 = FLAGS_CONJ_NEQ.
                                                        flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2),
                                                        depth: k,
                                                        lhs_rs: 1 * sizeof,
                                                        lhs_cs: cs as isize * sizeof,
                                                        rhs_rs: 1 * sizeof,
                                                        rhs_cs: k as isize * sizeof,
                                                        alpha: &raw const alpha as _,
                                                        ptr: dst.as_mut_ptr() as _,
                                                        rs: 1 * sizeof,
                                                        cs: cs as isize * sizeof,
                                                        row_idx: null_mut(),
                                                        col_idx: null_mut(),
                                                        diag_ptr: null(),
                                                        diag_stride: 0,
                                                    },
                                                },
                                                &mut Position { row: 0, col: 0 },
                                            )
                                        };
                                        // Complex modulus tolerance check.
                                        let mut i = 0;
                                        for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                                            if !((target - dst).norm_sqr().sqrt() < 1e-6) {
                                                dbg!(i / cs, i % cs, target, dst);
                                                panic!();
                                            }
                                            i += 1;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
2010
// Tests for the f32 kernels: the row-major millikernel and the blocked
// `kernel` driver, validated against scalar / `gemm`-crate references.
#[cfg(test)]
mod tests_f32 {
    use core::ptr::null_mut;

    use super::*;

    use aligned_vec::*;
    use rand::prelude::*;

    // Row-major-traversal AVX-512 f32 millikernel (mr = 96, nr = 4) against
    // the scalar reference.
    #[test]
    fn test_avx512_microkernel() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<f32>() as isize;
        // f32 elements per 64-byte SIMD register/cache line.
        let len = 64 / size_of::<f32>();

        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                for alpha in [1.0.into(), 0.0.into(), 2.5.into()] {
                    let alpha: f32 = alpha;
                    for m in 1..=96usize {
                        for n in (1..=4usize).into_iter().chain([8]) {
                            // Both padded and exact column strides.
                            for cs in [m.next_multiple_of(len), m] {
                                let acs = m.next_multiple_of(len);
                                let k = 1usize;

                                let packed_lhs: &mut [f32] = &mut *avec![0.0.into(); acs * k];
                                let packed_rhs: &mut [f32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                let lhs: &mut [f32] = &mut *avec![0.0.into(); cs * k];
                                let rhs: &mut [f32] = &mut *avec![0.0.into(); n * k];
                                let dst: &mut [f32] = &mut *avec![0.0.into(); cs * n];
                                let target = &mut *avec![0.0.into(); cs * n];

                                rng.fill(lhs);
                                rng.fill(rhs);

                                // Scalar reference: target[i,j] += alpha * sum_d lhs[i,d] * rhs[d,j].
                                for i in 0..m {
                                    for j in 0..n {
                                        let target = &mut target[i + cs * j];
                                        let mut acc = 0.0.into();
                                        for depth in 0..k {
                                            acc = f32::mul_add(lhs[i + cs * depth], rhs[depth + k * j], acc);
                                        }
                                        *target = f32::mul_add(acc, alpha, *target);
                                    }
                                }

                                unsafe {
                                    millikernel_rowmajor(
                                        F32_SIMD512x4[3],
                                        F32_SIMDpack_512[0],
                                        96, // mr
                                        4,  // nr
                                        4,  // sizeof(f32)
                                        lhs.as_ptr() as _,
                                        if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                                        rhs.as_ptr() as _,
                                        if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                                        m,
                                        n,
                                        &mut MillikernelInfo {
                                            lhs_rs: 96 * sizeof,
                                            packed_lhs_rs: 96 * sizeof * k as isize,
                                            rhs_cs: 4 * sizeof * k as isize,
                                            packed_rhs_cs: 4 * sizeof * k as isize,
                                            micro: MicrokernelInfo {
                                                // bit 63 = FLAGS_ROWMAJOR.
                                                flags: (1 << 63),
                                                depth: k,
                                                lhs_rs: 1 * sizeof,
                                                lhs_cs: cs as isize * sizeof,
                                                rhs_rs: 1 * sizeof,
                                                rhs_cs: k as isize * sizeof,
                                                alpha: &raw const alpha as _,
                                                ptr: dst.as_mut_ptr() as _,
                                                rs: 1 * sizeof,
                                                cs: cs as isize * sizeof,
                                                row_idx: null_mut(),
                                                col_idx: null_mut(),
                                                diag_ptr: null(),
                                                diag_stride: 0,
                                            },
                                        },
                                        &mut Position { row: 0, col: 0 },
                                    )
                                };
                                // mul_add on both sides, so exact equality is expected.
                                assert_eq!(dst, target);
                            }
                        }
                    }
                }
            }
        }
    }

    // Blocked `kernel` driver on a large 6000x2000x5 f32 problem, validated
    // against the external `gemm` crate.
    #[test]
    fn test_avx512_kernel() {
        let m = 6000usize;
        let n = 2000usize;
        let k = 5usize;

        let rng = &mut StdRng::seed_from_u64(0);
        let sizeof = size_of::<f32>() as isize;
        let cs = m.next_multiple_of(16);
        // Oversized column stride exercises non-contiguous dst layouts.
        let cs = Ord::max(4096, cs);

        let lhs: &mut [f32] = &mut *avec![0.0; cs * k];
        let rhs: &mut [f32] = &mut *avec![0.0; k * n];
        let target: &mut [f32] = &mut *avec![0.0; cs * n];

        rng.fill(lhs);
        rng.fill(rhs);

        // Reference result from the `gemm` crate.
        unsafe {
            gemm::gemm(
                m,
                n,
                k,
                target.as_mut_ptr(),
                cs as isize,
                1,
                true,
                lhs.as_ptr(),
                cs as isize,
                1,
                rhs.as_ptr(),
                k as isize,
                1,
                1.0,
                1.0,
                false,
                false,
                false,
                gemm::Parallelism::None,
            );
        }

        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                let dst = &mut *avec![0.0; cs * n];
                let packed_lhs = &mut *avec![0.0f32; m.next_multiple_of(16) * k];
                let packed_rhs = &mut *avec![0.0; if pack_rhs { n.next_multiple_of(4) * k } else { 0 }];

                unsafe {
                    // Blocking hierarchy, outermost level first.
                    let row_chunk = [96 * 32, 96 * 16, 96 * 4, 96];
                    let col_chunk = [1024, 256, 64, 16, 4];

                    let lhs_rs = row_chunk.map(|m| m as isize * sizeof);
                    let rhs_cs = col_chunk.map(|n| (n * k) as isize * sizeof);
                    let packed_lhs_rs = row_chunk.map(|m| (m * k) as isize * sizeof);
                    let mut packed_rhs_cs = col_chunk.map(|n| (n * k) as isize * sizeof);
                    // Outermost level does not advance within the packed RHS buffer.
                    packed_rhs_cs[0] = 0;

                    kernel(
                        1, // single-threaded
                        &F32_SIMD512x4[..24],
                        &F32_SIMDpack_512,
                        96,
                        4,
                        size_of::<f32>(),
                        lhs.as_ptr() as _,
                        if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                        rhs.as_ptr() as _,
                        if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                        m,
                        n,
                        &row_chunk,
                        &col_chunk,
                        &lhs_rs,
                        &rhs_cs,
                        &if pack_lhs { packed_lhs_rs } else { lhs_rs },
                        &if pack_rhs { packed_rhs_cs } else { rhs_cs },
                        0,
                        0,
                        Position { row: 0, col: 0 },
                        &MicrokernelInfo {
                            flags: 0,
                            depth: k,
                            lhs_rs: sizeof,
                            lhs_cs: cs as isize * sizeof,
                            rhs_rs: sizeof,
                            rhs_cs: k as isize * sizeof,
                            alpha: &raw const *&1.0f32 as _,
                            ptr: dst.as_mut_ptr() as _,
                            rs: sizeof,
                            cs: cs as isize * sizeof,
                            row_idx: null_mut(),
                            col_idx: null_mut(),
                            diag_ptr: null(),
                            diag_stride: 0,
                        },
                    )
                }
                // NOTE(review): 1e-6 is tight for f32 — presumably fine because
                // k = 5 keeps accumulation error small.
                let mut i = 0;
                for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                    if !((target - dst).abs() < 1e-6) {
                        dbg!(i / cs, i % cs, target, dst);
                        panic!();
                    }
                    i += 1;
                }
            }
        }
    }
}
2215
// Tests for the c32 (complex float) AVX-512 millikernel: all conjugation
// combinations, optionally with a real diagonal scaling of the inner product
// (dst += alpha * lhs @ diag @ rhs), against a scalar reference.
#[cfg(test)]
mod tests_c32 {
    use super::*;

    use aligned_vec::*;
    use bytemuck::*;
    use core::ptr::null_mut;
    use gemm::c32;
    use rand::prelude::*;

    #[test]
    fn test_avx512_microkernel() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<c32>() as isize;
        // c32 elements per 64-byte SIMD register/cache line.
        let len = 64 / size_of::<c32>();

        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
                    let alpha: c32 = alpha;
                    for m in 1..=127usize {
                        for n in (1..=4usize).into_iter().chain([8]) {
                            for cs in [m.next_multiple_of(len), m] {
                                for conj_lhs in [false, true] {
                                    for conj_rhs in [false, true] {
                                        for diag_scale in [false, true] {
                                            // Diagonal scaling is only applied through the packed
                                            // LHS path, so skip the unsupported combination.
                                            if diag_scale && !pack_lhs {
                                                continue;
                                            }
                                            let conj_different = conj_lhs != conj_rhs;

                                            let acs = m.next_multiple_of(len);
                                            let k = 1usize;

                                            let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
                                            let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                            let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
                                            let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
                                            let dst: &mut [c32] = &mut *avec![0.0.into(); cs * n];
                                            let target: &mut [c32] = &mut *avec![0.0.into(); cs * n];

                                            // Real-valued diagonal, one entry per depth index.
                                            let diag: &mut [f32] = &mut *avec![0.0.into(); k];

                                            rng.fill(cast_slice_mut::<c32, f32>(lhs));
                                            rng.fill(cast_slice_mut::<c32, f32>(rhs));
                                            rng.fill(diag);

                                            // Scalar reference with conjugation and optional
                                            // diagonal scaling.
                                            for i in 0..m {
                                                for j in 0..n {
                                                    let target = &mut target[i + cs * j];
                                                    let mut acc: c32 = 0.0.into();
                                                    for depth in 0..k {
                                                        let mut l = lhs[i + cs * depth];
                                                        let mut r = rhs[depth + k * j];
                                                        let d = diag[depth];

                                                        if conj_lhs {
                                                            l = l.conj();
                                                        }
                                                        if conj_rhs {
                                                            r = r.conj();
                                                        }

                                                        if diag_scale {
                                                            acc += d * l * r;
                                                        } else {
                                                            acc += l * r;
                                                        }
                                                    }
                                                    *target = acc * alpha + *target;
                                                }
                                            }

                                            unsafe {
                                                millikernel_colmajor(
                                                    C32_SIMD512x4[3],
                                                    C32_SIMDpack_512[0],
                                                    48, // mr
                                                    4,  // nr
                                                    8,  // sizeof(c32)
                                                    lhs.as_ptr() as _,
                                                    if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                                                    rhs.as_ptr() as _,
                                                    if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                                                    m,
                                                    n,
                                                    &mut MillikernelInfo {
                                                        lhs_rs: 48 * sizeof,
                                                        packed_lhs_rs: 48 * sizeof * k as isize,
                                                        rhs_cs: 4 * sizeof * k as isize,
                                                        packed_rhs_cs: 4 * sizeof * k as isize,
                                                        micro: MicrokernelInfo {
                                                            // bit 1 = FLAGS_CONJ_LHS, bit 2 = FLAGS_CONJ_NEQ.
                                                            flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2),
                                                            depth: k,
                                                            lhs_rs: 1 * sizeof,
                                                            lhs_cs: cs as isize * sizeof,
                                                            rhs_rs: 1 * sizeof,
                                                            rhs_cs: k as isize * sizeof,
                                                            alpha: &raw const alpha as _,
                                                            ptr: dst.as_mut_ptr() as _,
                                                            rs: 1 * sizeof,
                                                            cs: cs as isize * sizeof,
                                                            row_idx: null_mut(),
                                                            col_idx: null_mut(),
                                                            diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
                                                            diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
                                                        },
                                                    },
                                                    &mut Position { row: 0, col: 0 },
                                                )
                                            };
                                            // Complex modulus tolerance check (looser for f32).
                                            let mut i = 0;
                                            for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                                                if !((target - dst).norm_sqr().sqrt() < 1e-4) {
                                                    dbg!(i / cs, i % cs, target, dst);
                                                    panic!();
                                                }
                                                i += 1;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
2347
// Tests for the c32 millikernels with FLAGS_LOWER set: only the lower
// triangle (i >= j) of dst is written. The reference skips i < j, so the
// element-wise comparison also verifies that the strict upper triangle of
// dst is left untouched (both stay zero).
#[cfg(test)]
mod tests_c32_lower {
    use super::*;

    use aligned_vec::*;
    use bytemuck::*;
    use core::ptr::null_mut;
    use gemm::c32;
    use rand::prelude::*;

    // AVX-512 variant (mr = 48, nr = 4).
    #[test]
    fn test_avx512_microkernel() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<c32>() as isize;
        // c32 elements per 64-byte SIMD register/cache line.
        let len = 64 / size_of::<c32>();

        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
                    let alpha: c32 = alpha;
                    for m in 1..=127usize {
                        for n in (1..=4usize).chain([8, 32]) {
                            for cs in [m, m.next_multiple_of(len)] {
                                for conj_lhs in [false, true] {
                                    for conj_rhs in [false, true] {
                                        for diag_scale in [false, true] {
                                            // Diagonal scaling requires the packed LHS path.
                                            if diag_scale && !pack_lhs {
                                                continue;
                                            }
                                            let conj_different = conj_lhs != conj_rhs;

                                            let acs = m.next_multiple_of(len);
                                            let k = 1usize;

                                            let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
                                            let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                            let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
                                            let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
                                            let dst: &mut [c32] = &mut *avec![0.0.into(); cs * n];
                                            let target: &mut [c32] = &mut *avec![0.0.into(); cs * n];

                                            // Real-valued diagonal, one entry per depth index.
                                            let diag: &mut [f32] = &mut *avec![0.0.into(); k];

                                            rng.fill(cast_slice_mut::<c32, f32>(lhs));
                                            rng.fill(cast_slice_mut::<c32, f32>(rhs));
                                            rng.fill(diag);

                                            for i in 0..m {
                                                for j in 0..n {
                                                    // Lower triangle only: leave i < j untouched.
                                                    if i < j {
                                                        continue;
                                                    }
                                                    let target = &mut target[i + cs * j];
                                                    let mut acc: c32 = 0.0.into();
                                                    for depth in 0..k {
                                                        let mut l = lhs[i + cs * depth];
                                                        let mut r = rhs[depth + k * j];
                                                        let d = diag[depth];

                                                        if conj_lhs {
                                                            l = l.conj();
                                                        }
                                                        if conj_rhs {
                                                            r = r.conj();
                                                        }

                                                        if diag_scale {
                                                            acc += d * l * r;
                                                        } else {
                                                            acc += l * r;
                                                        }
                                                    }
                                                    *target = acc * alpha + *target;
                                                }
                                            }

                                            unsafe {
                                                millikernel_colmajor(
                                                    C32_SIMD512x4[3],
                                                    C32_SIMDpack_512[0],
                                                    48, // mr
                                                    4,  // nr
                                                    8,  // sizeof(c32)
                                                    lhs.as_ptr() as _,
                                                    if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                                                    rhs.as_ptr() as _,
                                                    if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                                                    m,
                                                    n,
                                                    &mut MillikernelInfo {
                                                        lhs_rs: 48 * sizeof,
                                                        packed_lhs_rs: 48 * sizeof * k as isize,
                                                        rhs_cs: 4 * sizeof * k as isize,
                                                        packed_rhs_cs: 4 * sizeof * k as isize,
                                                        micro: MicrokernelInfo {
                                                            // bit 1 = FLAGS_CONJ_LHS, bit 2 = FLAGS_CONJ_NEQ,
                                                            // bit 3 = FLAGS_LOWER.
                                                            flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2) | (1 << 3),
                                                            depth: k,
                                                            lhs_rs: 1 * sizeof,
                                                            lhs_cs: cs as isize * sizeof,
                                                            rhs_rs: 1 * sizeof,
                                                            rhs_cs: k as isize * sizeof,
                                                            alpha: &raw const alpha as _,
                                                            ptr: dst.as_mut_ptr() as _,
                                                            rs: 1 * sizeof,
                                                            cs: cs as isize * sizeof,
                                                            row_idx: null_mut(),
                                                            col_idx: null_mut(),
                                                            diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
                                                            diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
                                                        },
                                                    },
                                                    &mut Position { row: 0, col: 0 },
                                                )
                                            };
                                            let mut i = 0;
                                            for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                                                if !((target - dst).norm_sqr().sqrt() < 1e-4) {
                                                    dbg!(i / cs, i % cs, target, dst);
                                                    panic!();
                                                }
                                                i += 1;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    // AVX-256 variant of the same lower-triangular test (mr = 12, nr = 4).
    #[test]
    fn test_avx256microkernel() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<c32>() as isize;
        let len = 64 / size_of::<c32>();

        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
                    let alpha: c32 = alpha;
                    for m in 1..=127usize {
                        for n in (1..=4usize).chain([8, 32]) {
                            for cs in [m, m.next_multiple_of(len)] {
                                for conj_lhs in [false, true] {
                                    for conj_rhs in [false, true] {
                                        for diag_scale in [false, true] {
                                            // Diagonal scaling requires the packed LHS path.
                                            if diag_scale && !pack_lhs {
                                                continue;
                                            }

                                            let conj_different = conj_lhs != conj_rhs;

                                            let acs = m.next_multiple_of(len);
                                            let k = 1usize;

                                            let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
                                            let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                            let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
                                            let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
                                            let dst: &mut [c32] = &mut *avec![0.0.into(); cs * n];
                                            let target: &mut [c32] = &mut *avec![0.0.into(); cs * n];

                                            let diag: &mut [f32] = &mut *avec![0.0.into(); k];

                                            rng.fill(cast_slice_mut::<c32, f32>(lhs));
                                            rng.fill(cast_slice_mut::<c32, f32>(rhs));
                                            rng.fill(diag);

                                            for i in 0..m {
                                                for j in 0..n {
                                                    // Lower triangle only.
                                                    if i < j {
                                                        continue;
                                                    }
                                                    let target = &mut target[i + cs * j];
                                                    let mut acc: c32 = 0.0.into();
                                                    for depth in 0..k {
                                                        let mut l = lhs[i + cs * depth];
                                                        let mut r = rhs[depth + k * j];
                                                        let d = diag[depth];

                                                        if conj_lhs {
                                                            l = l.conj();
                                                        }
                                                        if conj_rhs {
                                                            r = r.conj();
                                                        }

                                                        if diag_scale {
                                                            acc += d * l * r;
                                                        } else {
                                                            acc += l * r;
                                                        }
                                                    }
                                                    *target = acc * alpha + *target;
                                                }
                                            }

                                            unsafe {
                                                millikernel_colmajor(
                                                    C32_SIMD256[3],
                                                    C32_SIMDpack_256[0],
                                                    12, // mr
                                                    4,  // nr
                                                    8,  // sizeof(c32)
                                                    lhs.as_ptr() as _,
                                                    if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                                                    rhs.as_ptr() as _,
                                                    if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                                                    m,
                                                    n,
                                                    &mut MillikernelInfo {
                                                        lhs_rs: 12 * sizeof,
                                                        packed_lhs_rs: 12 * sizeof * k as isize,
                                                        rhs_cs: 4 * sizeof * k as isize,
                                                        packed_rhs_cs: 4 * sizeof * k as isize,
                                                        micro: MicrokernelInfo {
                                                            // bit 1 = FLAGS_CONJ_LHS, bit 2 = FLAGS_CONJ_NEQ,
                                                            // bit 3 = FLAGS_LOWER.
                                                            flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2) | (1 << 3),
                                                            depth: k,
                                                            lhs_rs: 1 * sizeof,
                                                            lhs_cs: cs as isize * sizeof,
                                                            rhs_rs: 1 * sizeof,
                                                            rhs_cs: k as isize * sizeof,
                                                            alpha: &raw const alpha as _,
                                                            ptr: dst.as_mut_ptr() as _,
                                                            rs: 1 * sizeof,
                                                            cs: cs as isize * sizeof,
                                                            row_idx: null_mut(),
                                                            col_idx: null_mut(),
                                                            diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
                                                            diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
                                                        },
                                                    },
                                                    &mut Position { row: 0, col: 0 },
                                                )
                                            };
                                            let mut i = 0;
                                            for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                                                if !((target - dst).norm_sqr().sqrt() < 1e-4) {
                                                    dbg!(i / cs, i % cs, target, dst);
                                                    panic!();
                                                }
                                                i += 1;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
2607
#[cfg(test)]
mod tests_c32_lower_add {
    // NOTE(review): despite the `c32`/`lower_add` module name, these tests
    // exercise the `c64` (complex f64) path with `DstKind::Full` and
    // `Accum::Replace` — confirm the intended module name.
    use super::*;

    use aligned_vec::*;
    use bytemuck::*;
    use gemm::c64;
    use rand::prelude::*;

    /// Row-major LHS (`lhs_rs = k`, `lhs_cs = 1`) against a column-major RHS,
    /// checked entry-wise against a naive triple-loop reference for every
    /// combination of alpha, conjugation flags and optional diagonal scaling.
    #[test]
    fn test_avx512_microkernel_rowmajor() {
        let rng = &mut StdRng::seed_from_u64(0);

        // number of c64 elements in one 64-byte AVX-512 register
        let len = 64 / size_of::<c64>();

        for alpha in [1.0.into(), 0.0.into(), c64::new(0.0, 3.5), c64::new(2.5, 3.5)] {
            let alpha: c64 = alpha;
            // The previous version looped `m in 1..=127` / `n in 1..=4, 8,
            // 32, 1024` and immediately shadowed both loop variables with the
            // fixed values below, re-running one identical (and expensive:
            // 4005 x 4005) case ~900 times. Iterate the single case directly,
            // matching the style of `test_avx512_microkernel_colmajor`.
            for m in [4005usize] {
                for n in [2usize] {
                    // exact column stride, then padded to a full register
                    for cs in [m, m.next_multiple_of(len)] {
                        for conj_lhs in [false, true] {
                            for conj_rhs in [false, true] {
                                for diag_scale in [true, false] {
                                    let k = 4005usize;

                                    let lhs: &mut [c64] = &mut *avec![0.0.into(); m * k];
                                    let rhs: &mut [c64] = &mut *avec![0.0.into(); n * k];
                                    let dst: &mut [c64] = &mut *avec![0.0.into(); cs * n];
                                    rng.fill(cast_slice_mut::<c64, f64>(dst));

                                    let target0: &mut [c64] = &mut *dst.to_vec();

                                    // real-valued scaling factors, stored in
                                    // the real part of a c64 slice
                                    let diag: &mut [c64] = &mut *avec![0.0.into(); k];

                                    rng.fill(cast_slice_mut::<c64, f64>(lhs));
                                    rng.fill(cast_slice_mut::<c64, f64>(rhs));

                                    for x in &mut *diag {
                                        x.re = rng.random();
                                    }

                                    // naive reference:
                                    // target0[i, j] = alpha * sum_d (diag[d]?) * lhs[i, d] * rhs[d, j]
                                    for i in 0..m {
                                        for j in 0..n {
                                            let target = &mut target0[i + cs * j];
                                            let mut acc: c64 = 0.0.into();
                                            for depth in 0..k {
                                                let mut l = lhs[i * k + depth];
                                                let mut r = rhs[depth + k * j];
                                                let d = diag[depth];

                                                if conj_lhs {
                                                    l = l.conj();
                                                }
                                                if conj_rhs {
                                                    r = r.conj();
                                                }

                                                if diag_scale {
                                                    acc += d * l * r;
                                                } else {
                                                    acc += l * r;
                                                }
                                            }
                                            *target = acc * alpha;
                                        }
                                    }

                                    // row-major lhs: strides (k, 1); col-major rhs: strides (1, k)
                                    unsafe {
                                        gemm(
                                            DType::C64,
                                            IType::U64,
                                            InstrSet::Avx512,
                                            m,
                                            n,
                                            k,
                                            dst.as_mut_ptr() as _,
                                            1,
                                            cs as isize,
                                            null(),
                                            null(),
                                            DstKind::Full,
                                            Accum::Replace,
                                            lhs.as_ptr() as _,
                                            k as isize,
                                            1,
                                            conj_lhs,
                                            if diag_scale { diag.as_ptr() as _ } else { null() },
                                            if diag_scale { 1 } else { 0 },
                                            rhs.as_ptr() as _,
                                            1,
                                            k as isize,
                                            conj_rhs,
                                            &raw const alpha as _,
                                            1,
                                        )
                                    };

                                    // entry-wise comparison against the reference
                                    let mut i = 0;
                                    for (&target, &dst) in core::iter::zip(&*target0, &*dst) {
                                        if !((target - dst).norm_sqr().sqrt() < 1e-4) {
                                            dbg!(i / cs, i % cs, target, dst);
                                            panic!();
                                        }
                                        i += 1;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /// Column-major LHS and RHS sharing the padded column stride `cs`,
    /// checked entry-wise against a naive triple-loop reference.
    #[test]
    fn test_avx512_microkernel_colmajor() {
        let rng = &mut StdRng::seed_from_u64(0);

        for alpha in [1.0.into(), 0.0.into(), c64::new(0.0, 3.5), c64::new(2.5, 3.5)] {
            let alpha: c64 = alpha;
            for m in [4005usize] {
                for n in [2usize] {
                    for cs in [4008] {
                        for conj_lhs in [false, true] {
                            for conj_rhs in [false, true] {
                                for diag_scale in [true, false] {
                                    let k = 4005usize;

                                    let lhs: &mut [c64] = &mut *avec![0.0.into(); cs * k];
                                    let rhs: &mut [c64] = &mut *avec![0.0.into(); n * cs];
                                    let dst: &mut [c64] = &mut *avec![0.0.into(); cs * n];
                                    rng.fill(cast_slice_mut::<c64, f64>(dst));

                                    let target0: &mut [c64] = &mut *dst.to_vec();

                                    // real-valued scaling factors, stored in
                                    // the real part of a c64 slice
                                    let diag: &mut [c64] = &mut *avec![0.0.into(); k];

                                    rng.fill(cast_slice_mut::<c64, f64>(lhs));
                                    rng.fill(cast_slice_mut::<c64, f64>(rhs));

                                    for x in &mut *diag {
                                        x.re = rng.random();
                                    }

                                    // naive reference:
                                    // target0[i, j] = alpha * sum_d (diag[d]?) * lhs[i, d] * rhs[d, j]
                                    for i in 0..m {
                                        for j in 0..n {
                                            let target = &mut target0[i + cs * j];
                                            let mut acc: c64 = 0.0.into();
                                            for depth in 0..k {
                                                let mut l = lhs[i + cs * depth];
                                                let mut r = rhs[depth + cs * j];
                                                let d = diag[depth];

                                                if conj_lhs {
                                                    l = l.conj();
                                                }
                                                if conj_rhs {
                                                    r = r.conj();
                                                }

                                                if diag_scale {
                                                    acc += d * l * r;
                                                } else {
                                                    acc += l * r;
                                                }
                                            }
                                            *target = acc * alpha;
                                        }
                                    }

                                    // both operands column-major with column stride cs
                                    unsafe {
                                        gemm(
                                            DType::C64,
                                            IType::U64,
                                            InstrSet::Avx512,
                                            m,
                                            n,
                                            k,
                                            dst.as_mut_ptr() as _,
                                            1,
                                            cs as isize,
                                            null(),
                                            null(),
                                            DstKind::Full,
                                            Accum::Replace,
                                            lhs.as_ptr() as _,
                                            1,
                                            cs as isize,
                                            conj_lhs,
                                            if diag_scale { diag.as_ptr() as _ } else { null() },
                                            if diag_scale { 1 } else { 0 },
                                            rhs.as_ptr() as _,
                                            1,
                                            cs as isize,
                                            conj_rhs,
                                            &raw const alpha as _,
                                            2,
                                        )
                                    };

                                    // NOTE(review): this test uses a 1e-8 threshold while the
                                    // row-major test uses 1e-4 — confirm which is intended.
                                    let mut i = 0;
                                    for (&target, &dst) in core::iter::zip(&*target0, &*dst) {
                                        if !((target - dst).norm_sqr().sqrt() < 1e-8) {
                                            dbg!(i / cs, i % cs, target, dst);
                                            panic!();
                                        }
                                        i += 1;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
2843
#[cfg(test)]
mod tests_c32_upper {
    // Tests for the c32 (complex f32) AVX-512 millikernel with the
    // upper-triangular destination flag set: only entries with row <= col
    // are computed and accumulated into `dst`.
    use super::*;

    use aligned_vec::*;
    use bytemuck::*;
    use core::ptr::null_mut;
    use gemm::c32;
    use rand::prelude::*;

    #[test]
    fn test_avx512_microkernel() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<c32>() as isize;
        // number of c32 elements in one 64-byte AVX-512 register
        let len = 64 / size_of::<c32>();

        // sweep packing modes, alpha values, sizes, strides and conjugation flags
        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
                    let alpha: c32 = alpha;
                    for m in 1..=127usize {
                        // NOTE(review): `8` appears three times in this iterator, so the
                        // n == 8 case is re-run — confirm whether that is intentional.
                        for n in [8].into_iter().chain(1..=4usize).chain([8]) {
                            // cs: exact column stride, then padded to a full register
                            for cs in [m, m.next_multiple_of(len)] {
                                for conj_lhs in [false, true] {
                                    for conj_rhs in [false, true] {
                                        for diag_scale in [false, true] {
                                            // diagonal scaling is only exercised on the packed-lhs path
                                            if diag_scale && !pack_lhs {
                                                continue;
                                            }
                                            let conj_different = conj_lhs != conj_rhs;

                                            // packed-lhs column stride, padded to a register multiple
                                            let acs = m.next_multiple_of(len);
                                            let k = 1usize;

                                            let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
                                            let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                            let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
                                            let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
                                            let dst: &mut [c32] = &mut *avec![0.0.into(); cs * n];
                                            let target: &mut [c32] = &mut *avec![0.0.into(); cs * n];

                                            // real-valued per-depth scaling factors
                                            let diag: &mut [f32] = &mut *avec![0.0.into(); k];

                                            rng.fill(cast_slice_mut::<c32, f32>(lhs));
                                            rng.fill(cast_slice_mut::<c32, f32>(rhs));
                                            rng.fill(diag);

                                            // naive reference, upper triangle only:
                                            // target[i, j] += alpha * sum_d (diag[d]?) * lhs[i, d] * rhs[d, j]
                                            for i in 0..m {
                                                for j in 0..n {
                                                    if i > j {
                                                        continue;
                                                    }
                                                    let target = &mut target[i + cs * j];
                                                    let mut acc: c32 = 0.0.into();
                                                    for depth in 0..k {
                                                        let mut l = lhs[i + cs * depth];
                                                        let mut r = rhs[depth + k * j];
                                                        let d = diag[depth];

                                                        if conj_lhs {
                                                            l = l.conj();
                                                        }
                                                        if conj_rhs {
                                                            r = r.conj();
                                                        }

                                                        if diag_scale {
                                                            acc += d * l * r;
                                                        } else {
                                                            acc += l * r;
                                                        }
                                                    }
                                                    *target = acc * alpha + *target;
                                                }
                                            }

                                            unsafe {
                                                millikernel_colmajor(
                                                    // microkernel / pack routines from the generated asm tables
                                                    C32_SIMD512x4[3],
                                                    C32_SIMDpack_512[0],
                                                    // assumes a 48 x 4 microkernel tile with 8-element
                                                    // registers — TODO confirm against the asm generator
                                                    48,
                                                    4,
                                                    8,
                                                    lhs.as_ptr() as _,
                                                    if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                                                    rhs.as_ptr() as _,
                                                    if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                                                    m,
                                                    n,
                                                    &mut MillikernelInfo {
                                                        // byte strides between consecutive row/column blocks
                                                        lhs_rs: 48 * sizeof,
                                                        packed_lhs_rs: 48 * sizeof * k as isize,
                                                        rhs_cs: 4 * sizeof * k as isize,
                                                        packed_rhs_cs: 4 * sizeof * k as isize,
                                                        micro: MicrokernelInfo {
                                                            // 1 << 4 is FLAGS_UPPER: restrict writes to row <= col
                                                            flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2) | (1 << 4),
                                                            depth: k,
                                                            lhs_rs: 1 * sizeof,
                                                            lhs_cs: cs as isize * sizeof,
                                                            rhs_rs: 1 * sizeof,
                                                            rhs_cs: k as isize * sizeof,
                                                            alpha: &raw const alpha as _,
                                                            ptr: dst.as_mut_ptr() as _,
                                                            rs: 1 * sizeof,
                                                            cs: cs as isize * sizeof,
                                                            row_idx: null_mut(),
                                                            col_idx: null_mut(),
                                                            diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
                                                            diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
                                                        },
                                                    },
                                                    &mut Position { row: 0, col: 0 },
                                                )
                                            };
                                            // entry-wise comparison against the reference
                                            let mut i = 0;
                                            for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                                                if !((target - dst).norm_sqr().sqrt() < 1e-4) {
                                                    dbg!(i / cs, i % cs, target, dst);
                                                    panic!();
                                                }
                                                i += 1;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
2978
#[cfg(test)]
mod transpose_tests {
    // Tests for the generated SIMD pack routines: each reads an m x n
    // row-major source and writes it transposed/packed with the row count
    // padded up to a whole number of 64-byte (512-bit) registers.
    use super::*;
    use aligned_vec::avec;
    use rand::prelude::*;

    #[test]
    fn test_b128() {
        let rng = &mut StdRng::seed_from_u64(0);

        // up to 24 rows of 128-bit elements per packed panel
        for m in 1..=24 {
            let n = 127;

            let src = &mut *avec![0u128; m * n];
            let dst = &mut *avec![0u128; m.next_multiple_of(8) * n];

            rng.fill(src);
            rng.fill(dst);

            // variant index presumably selects the tail handling for the
            // padded rows (4 u128 lanes per register) — TODO confirm
            let ptr = C64_SIMDpack_512[(24 - m) / 4];
            let info = MicrokernelInfo {
                flags: 0,
                depth: n,
                // source is row-major: stepping one row skips n elements
                lhs_rs: (n * size_of::<u128>()) as isize,
                lhs_cs: size_of::<u128>() as isize,
                rhs_rs: 0,
                rhs_cs: 0,
                alpha: null(),
                ptr: null_mut(),
                rs: 0,
                cs: 0,
                row_idx: null(),
                col_idx: null(),
                diag_ptr: null(),
                diag_stride: 0,
            };

            unsafe {
                // custom register-based calling convention for the pack
                // routine: rax = src, r15 = dst, r8 = row count, rsi = &info
                core::arch::asm! {"
                    call r10
                    ",
                    in("r10") ptr,
                    in("rax") src.as_ptr(),
                    in("r15") dst.as_mut_ptr(),
                    in("r8") m,
                    in("rsi") &info,
                };
            }

            // NOTE(review): dst was allocated with stride m.next_multiple_of(8)
            // but is checked with stride m.next_multiple_of(4) — confirm the
            // intended padding for 128-bit elements.
            for j in 0..n {
                for i in 0..m {
                    assert_eq!(src[i * n + j], dst[i + m.next_multiple_of(4) * j]);
                }
            }
        }
    }

    #[test]
    fn test_b64() {
        let rng = &mut StdRng::seed_from_u64(0);

        // up to 48 rows of 64-bit elements per packed panel
        for m in 1..=48 {
            let n = 127;

            let src = &mut *avec![0u64; m * n];
            let dst = &mut *avec![0u64; m.next_multiple_of(8) * n];

            rng.fill(src);
            rng.fill(dst);

            // 8 u64 lanes per 512-bit register
            let ptr = F64_SIMDpack_512[(48 - m) / 8];
            let info = MicrokernelInfo {
                flags: 0,
                depth: n,
                // source is row-major: stepping one row skips n elements
                lhs_rs: (n * size_of::<u64>()) as isize,
                lhs_cs: size_of::<u64>() as isize,
                rhs_rs: 0,
                rhs_cs: 0,
                alpha: null(),
                ptr: null_mut(),
                rs: 0,
                cs: 0,
                row_idx: null(),
                col_idx: null(),
                diag_ptr: null(),
                diag_stride: 0,
            };

            unsafe {
                // rax = src, r15 = dst, r8 = row count, rsi = &info
                core::arch::asm! {"
                    call r10
                    ",
                    in("r10") ptr,
                    in("rax") src.as_ptr(),
                    in("r15") dst.as_mut_ptr(),
                    in("r8") m,
                    in("rsi") &info,
                };
            }

            // dst is column-major with rows padded to a full register
            for j in 0..n {
                for i in 0..m {
                    assert_eq!(src[i * n + j], dst[i + m.next_multiple_of(8) * j]);
                }
            }
        }
    }

    #[test]
    fn test_b32() {
        let rng = &mut StdRng::seed_from_u64(0);

        // up to 96 rows of 32-bit elements per packed panel
        for m in 1..=96 {
            let n = 127;

            let src = &mut *avec![0u32; m * n];
            let dst = &mut *avec![0u32; m.next_multiple_of(16) * n];

            rng.fill(src);
            rng.fill(dst);

            // 16 u32 lanes per 512-bit register
            let ptr = F32_SIMDpack_512[(96 - m) / 16];
            let info = MicrokernelInfo {
                flags: 0,
                depth: n,
                // source is row-major: stepping one row skips n elements
                lhs_rs: (n * size_of::<f32>()) as isize,
                lhs_cs: size_of::<f32>() as isize,
                rhs_rs: 0,
                rhs_cs: 0,
                alpha: null(),
                ptr: null_mut(),
                rs: 0,
                cs: 0,
                row_idx: null(),
                col_idx: null(),
                diag_ptr: null(),
                diag_stride: 0,
            };

            unsafe {
                // rax = src, r15 = dst, r8 = row count, rsi = &info
                core::arch::asm! {"
                    call r10
                    ",
                    in("r10") ptr,
                    in("rax") src.as_ptr(),
                    in("r15") dst.as_mut_ptr(),
                    in("r8") m,
                    in("rsi") &info,
                };
            }

            // dst is column-major with rows padded to a full register
            for j in 0..n {
                for i in 0..m {
                    assert_eq!(src[i * n + j], dst[i + m.next_multiple_of(16) * j]);
                }
            }
        }
    }
}
3138
#[cfg(test)]
mod tests_c32_gather_scatter {
    // Tests for the c32 AVX-512 millikernel writing through strided and
    // index-vector (gather/scatter) destinations.
    use super::*;

    use aligned_vec::*;
    use bytemuck::*;
    use core::ptr::null_mut;
    use gemm::c32;
    use rand::prelude::*;

    /// Strided destination: `rs`/`cs` are doubled so the kernel touches only
    /// every other `c32` of `dst`; checked against a naive upper-triangular
    /// reference.
    #[test]
    fn test_avx512_microkernel() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<c32>() as isize;
        // number of c32 elements in one 64-byte AVX-512 register
        let len = 64 / size_of::<c32>();

        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
                    let alpha: c32 = alpha;
                    // The previous version looped `m in 1..=127` and
                    // `cs in [m, m.next_multiple_of(len)]`, then shadowed both
                    // with the fixed values below, re-running one identical
                    // case ~250 times; iterate the single case directly. The
                    // `n` iterator also listed 8 three times; once is enough.
                    for m in [2usize] {
                        for n in (1..=4usize).chain([8]) {
                            for cs in [m] {
                                for conj_lhs in [false, true] {
                                    for conj_rhs in [false, true] {
                                        for diag_scale in [false, true] {
                                            // diagonal scaling is only exercised on the packed-lhs path
                                            if diag_scale && !pack_lhs {
                                                continue;
                                            }
                                            let conj_different = conj_lhs != conj_rhs;

                                            // packed-lhs column stride, padded to a register multiple
                                            let acs = m.next_multiple_of(len);
                                            let k = 1usize;

                                            let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
                                            let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                            let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
                                            let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
                                            // 2x sized: the kernel writes only even-indexed elements
                                            let dst: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];
                                            let target: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];

                                            // real-valued per-depth scaling factors
                                            let diag: &mut [f32] = &mut *avec![0.0.into(); k];

                                            rng.fill(cast_slice_mut::<c32, f32>(lhs));
                                            rng.fill(cast_slice_mut::<c32, f32>(rhs));
                                            rng.fill(diag);

                                            // naive reference, upper triangle only, written at the
                                            // same doubled stride the kernel uses
                                            for i in 0..m {
                                                for j in 0..n {
                                                    if i > j {
                                                        continue;
                                                    }
                                                    let target = &mut target[2 * (i + cs * j)];
                                                    let mut acc: c32 = 0.0.into();
                                                    for depth in 0..k {
                                                        let mut l = lhs[i + cs * depth];
                                                        let mut r = rhs[depth + k * j];
                                                        let d = diag[depth];

                                                        if conj_lhs {
                                                            l = l.conj();
                                                        }
                                                        if conj_rhs {
                                                            r = r.conj();
                                                        }

                                                        if diag_scale {
                                                            acc += d * l * r;
                                                        } else {
                                                            acc += l * r;
                                                        }
                                                    }
                                                    *target = acc * alpha + *target;
                                                }
                                            }

                                            unsafe {
                                                millikernel_colmajor(
                                                    C32_SIMD512x4[3],
                                                    C32_SIMDpack_512[0],
                                                    48,
                                                    4,
                                                    8,
                                                    lhs.as_ptr() as _,
                                                    if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                                                    rhs.as_ptr() as _,
                                                    if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                                                    m,
                                                    n,
                                                    &mut MillikernelInfo {
                                                        lhs_rs: 48 * sizeof,
                                                        packed_lhs_rs: 48 * sizeof * k as isize,
                                                        rhs_cs: 4 * sizeof * k as isize,
                                                        packed_rhs_cs: 4 * sizeof * k as isize,
                                                        micro: MicrokernelInfo {
                                                            // 1 << 4 is FLAGS_UPPER
                                                            flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2) | (1 << 4),
                                                            depth: k,
                                                            lhs_rs: 1 * sizeof,
                                                            lhs_cs: cs as isize * sizeof,
                                                            rhs_rs: 1 * sizeof,
                                                            rhs_cs: k as isize * sizeof,
                                                            alpha: &raw const alpha as _,
                                                            ptr: dst.as_mut_ptr() as _,
                                                            // doubled strides: write every other element
                                                            rs: 2 * sizeof,
                                                            cs: 2 * cs as isize * sizeof,
                                                            row_idx: null_mut(),
                                                            col_idx: null_mut(),
                                                            diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
                                                            diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
                                                        },
                                                    },
                                                    &mut Position { row: 0, col: 0 },
                                                )
                                            };
                                            // entry-wise comparison (untouched slots are zero on both sides)
                                            let mut i = 0;
                                            for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                                                if !((target - dst).norm_sqr().sqrt() < 1e-4) {
                                                    dbg!(i / cs, i % cs, target, dst);
                                                    panic!();
                                                }
                                                i += 1;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /// Same layout as [`test_avx512_microkernel`], but the doubled stride is
    /// expressed through 32-bit gather/scatter index vectors
    /// (`row_idx`/`col_idx` map logical index `i` to physical index `2 * i`).
    #[test]
    fn test_avx512_microkernel2() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<c32>() as isize;
        // number of c32 elements in one 64-byte AVX-512 register
        let len = 64 / size_of::<c32>();

        for pack_lhs in [false, true] {
            for pack_rhs in [false] {
                for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
                    let alpha: c32 = alpha;
                    // iterate the single (m, cs) case directly instead of
                    // shadowing loop variables (see test_avx512_microkernel)
                    for m in [2usize] {
                        for n in (1..=4usize).chain([8]) {
                            for cs in [m] {
                                for conj_lhs in [false, true] {
                                    for conj_rhs in [false, true] {
                                        for diag_scale in [false, true] {
                                            // diagonal scaling is only exercised on the packed-lhs path
                                            if diag_scale && !pack_lhs {
                                                continue;
                                            }
                                            let conj_different = conj_lhs != conj_rhs;
                                            // physical index of logical row/column i is 2 * i
                                            let idx = (0..Ord::max(m, n)).map(|i| 2 * i as u32).collect::<Vec<_>>();

                                            // packed-lhs column stride, padded to a register multiple
                                            let acs = m.next_multiple_of(len);
                                            let k = 1usize;

                                            let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
                                            let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                            let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
                                            let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
                                            // 2x sized: the kernel scatters to even indices only
                                            let dst: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];
                                            let target: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];

                                            // real-valued per-depth scaling factors
                                            let diag: &mut [f32] = &mut *avec![0.0.into(); k];

                                            rng.fill(cast_slice_mut::<c32, f32>(lhs));
                                            rng.fill(cast_slice_mut::<c32, f32>(rhs));
                                            rng.fill(diag);

                                            // naive reference, upper triangle only
                                            for i in 0..m {
                                                for j in 0..n {
                                                    if i > j {
                                                        continue;
                                                    }
                                                    let target = &mut target[2 * (i + cs * j)];
                                                    let mut acc: c32 = 0.0.into();
                                                    for depth in 0..k {
                                                        let mut l = lhs[i + cs * depth];
                                                        let mut r = rhs[depth + k * j];
                                                        let d = diag[depth];

                                                        if conj_lhs {
                                                            l = l.conj();
                                                        }
                                                        if conj_rhs {
                                                            r = r.conj();
                                                        }

                                                        if diag_scale {
                                                            acc += d * l * r;
                                                        } else {
                                                            acc += l * r;
                                                        }
                                                    }
                                                    *target = acc * alpha + *target;
                                                }
                                            }

                                            unsafe {
                                                millikernel_colmajor(
                                                    C32_SIMD512x4[3],
                                                    C32_SIMDpack_512[0],
                                                    48,
                                                    4,
                                                    8,
                                                    lhs.as_ptr() as _,
                                                    if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                                                    rhs.as_ptr() as _,
                                                    if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                                                    m,
                                                    n,
                                                    &mut MillikernelInfo {
                                                        lhs_rs: 48 * sizeof,
                                                        packed_lhs_rs: 48 * sizeof * k as isize,
                                                        rhs_cs: 4 * sizeof * k as isize,
                                                        packed_rhs_cs: 4 * sizeof * k as isize,
                                                        micro: MicrokernelInfo {
                                                            // 1 << 4 is FLAGS_UPPER, 1 << 5 is FLAGS_32BIT_IDX
                                                            flags: ((conj_lhs as usize) << 1)
                                                                | ((conj_different as usize) << 2) | (1 << 4) | (1 << 5),
                                                            depth: k,
                                                            lhs_rs: 1 * sizeof,
                                                            lhs_cs: cs as isize * sizeof,
                                                            rhs_rs: 1 * sizeof,
                                                            rhs_cs: k as isize * sizeof,
                                                            alpha: &raw const alpha as _,
                                                            ptr: dst.as_mut_ptr() as _,
                                                            // unit strides: the doubling comes from the index vectors
                                                            rs: sizeof,
                                                            cs: cs as isize * sizeof,
                                                            row_idx: idx.as_ptr() as _,
                                                            col_idx: idx.as_ptr() as _,
                                                            diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
                                                            diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
                                                        },
                                                    },
                                                    &mut Position { row: 0, col: 0 },
                                                )
                                            };
                                            let mut i = 0;
                                            for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                                                if !((target - dst).norm_sqr().sqrt() < 1e-4) {
                                                    dbg!(i / cs, i % cs, target, dst);
                                                    panic!();
                                                }
                                                i += 1;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /// Gather/scatter indices combined with a strided (interleaved) LHS and
    /// the `FLAGS_CPLX` bit; only the packed-lhs path supports this layout.
    #[test]
    fn test_avx512_microkernel3() {
        let rng = &mut StdRng::seed_from_u64(0);

        let sizeof = size_of::<c32>() as isize;
        // number of c32 elements in one 64-byte AVX-512 register
        let len = 64 / size_of::<c32>();

        for pack_lhs in [true] {
            for pack_rhs in [false] {
                for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
                    let alpha: c32 = alpha;
                    // iterate the single (m, cs) case directly instead of
                    // shadowing loop variables (see test_avx512_microkernel)
                    for m in [2usize] {
                        for n in (1..=4usize).chain([8]) {
                            for cs in [m] {
                                for conj_lhs in [false, true] {
                                    for conj_rhs in [false, true] {
                                        for diag_scale in [false, true] {
                                            // diagonal scaling is only exercised on the packed-lhs path
                                            if diag_scale && !pack_lhs {
                                                continue;
                                            }
                                            let conj_different = conj_lhs != conj_rhs;
                                            // physical index of logical row/column i is 2 * i
                                            let idx = (0..Ord::max(m, n)).map(|i| 2 * i as u32).collect::<Vec<_>>();

                                            // packed-lhs column stride, padded to a register multiple
                                            let acs = m.next_multiple_of(len);
                                            let k = 1usize;

                                            let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
                                            let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
                                            // 2x sized: the lhs itself is read at stride 2
                                            let lhs: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * k];
                                            let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
                                            let dst: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];
                                            let target: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];

                                            // real-valued per-depth scaling factors
                                            let diag: &mut [f32] = &mut *avec![0.0.into(); k];

                                            rng.fill(cast_slice_mut::<c32, f32>(lhs));
                                            rng.fill(cast_slice_mut::<c32, f32>(rhs));
                                            rng.fill(diag);

                                            // naive reference, upper triangle only, lhs read at stride 2
                                            for i in 0..m {
                                                for j in 0..n {
                                                    if i > j {
                                                        continue;
                                                    }
                                                    let target = &mut target[2 * (i + cs * j)];
                                                    let mut acc: c32 = 0.0.into();
                                                    for depth in 0..k {
                                                        let mut l = lhs[2 * (i + cs * depth)];
                                                        let mut r = rhs[depth + k * j];
                                                        let d = diag[depth];

                                                        if conj_lhs {
                                                            l = l.conj();
                                                        }
                                                        if conj_rhs {
                                                            r = r.conj();
                                                        }

                                                        if diag_scale {
                                                            acc += d * l * r;
                                                        } else {
                                                            acc += l * r;
                                                        }
                                                    }
                                                    *target = acc * alpha + *target;
                                                }
                                            }

                                            unsafe {
                                                millikernel_colmajor(
                                                    C32_SIMD512x4[3],
                                                    C32_SIMDpack_512[0],
                                                    48,
                                                    4,
                                                    8,
                                                    lhs.as_ptr() as _,
                                                    if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
                                                    rhs.as_ptr() as _,
                                                    if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
                                                    m,
                                                    n,
                                                    &mut MillikernelInfo {
                                                        lhs_rs: 48 * sizeof,
                                                        packed_lhs_rs: 48 * sizeof * k as isize,
                                                        rhs_cs: 4 * sizeof * k as isize,
                                                        packed_rhs_cs: 4 * sizeof * k as isize,
                                                        micro: MicrokernelInfo {
                                                            flags: ((conj_lhs as usize) * FLAGS_CONJ_LHS)
                                                                | ((conj_different as usize) * FLAGS_CONJ_NEQ) | (1 * FLAGS_UPPER)
                                                                | (1 * FLAGS_32BIT_IDX) | (1 * FLAGS_CPLX),
                                                            depth: k,
                                                            // lhs is read at doubled (interleaved) strides
                                                            lhs_rs: 2 * sizeof,
                                                            lhs_cs: 2 * cs as isize * sizeof,
                                                            rhs_rs: 1 * sizeof,
                                                            rhs_cs: k as isize * sizeof,
                                                            alpha: &raw const alpha as _,
                                                            ptr: dst.as_mut_ptr() as _,
                                                            // unit strides: the doubling comes from the index vectors
                                                            rs: sizeof,
                                                            cs: cs as isize * sizeof,
                                                            row_idx: idx.as_ptr() as _,
                                                            col_idx: idx.as_ptr() as _,
                                                            diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
                                                            diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
                                                        },
                                                    },
                                                    &mut Position { row: 0, col: 0 },
                                                )
                                            };
                                            let mut i = 0;
                                            for (&target, &dst) in core::iter::zip(&*target, &*dst) {
                                                if !((target - dst).norm_sqr().sqrt() < 1e-4) {
                                                    dbg!(i / cs, i % cs, target, dst);
                                                    panic!();
                                                }
                                                i += 1;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}