// private_gemm_x86/lib.rs

1#![allow(non_upper_case_globals)]
2#![allow(dead_code, unused_variables)]
3
// NOTE(review): `M` and `N` are not referenced in the visible portion of this
// file — possibly consumed by the `include!`d generated code; confirm before
// removing.
const M: usize = 4;
const N: usize = 32;
6
7use core::cell::RefCell;
8use core::ptr::{null, null_mut};
9use core::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
10
11use cache::CACHE_INFO;
12
13include!(concat!(env!("OUT_DIR"), "/asm.rs"));
14
/// Absolute (row, column) coordinates of a destination block; advanced by the
/// millikernels as blocks are processed. Passed by pointer to the asm
/// microkernels (`rdi`), hence `#[repr(C)]`.
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct Position {
	pub row: usize,
	pub col: usize,
}
21
22mod cache;
23
// Bit flags stored in `MicrokernelInfo::flags`. LOWER/UPPER select a
// triangular destination, ROWMAJOR (bit 63) selects the traversal order and
// CPLX (bit 62) marks complex data — those four are read by the Rust code
// below; the remaining bits are presumably consumed by the generated asm
// kernels — TODO confirm against asm.rs.
const FLAGS_ACCUM: usize = 1 << 0;
const FLAGS_CONJ_LHS: usize = 1 << 1;
const FLAGS_CONJ_NEQ: usize = 1 << 2;
const FLAGS_LOWER: usize = 1 << 3;
const FLAGS_UPPER: usize = 1 << 4;
const FLAGS_32BIT_IDX: usize = 1 << 5;
const FLAGS_CPLX: usize = 1 << 62;
const FLAGS_ROWMAJOR: usize = 1 << 63;
32
/// Parameters handed to the generated microkernels through `rsi`; the field
/// order is part of the asm ABI (`#[repr(C)]`).
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct MicrokernelInfo {
	/// `FLAGS_*` bit set (see the constants above).
	pub flags: usize,
	/// Length of the contraction (depth/k) dimension.
	pub depth: usize,
	/// LHS row/column strides, in bytes (used via `byte_offset` below).
	pub lhs_rs: isize,
	pub lhs_cs: isize,
	/// RHS row/column strides, in bytes.
	pub rhs_rs: isize,
	pub rhs_cs: isize,
	/// Type-erased scaling factor; presumably applied by the asm kernels —
	/// TODO confirm against asm.rs.
	pub alpha: *const (),

	// dst
	pub ptr: *mut (),
	pub rs: isize,
	pub cs: isize,
	pub row_idx: *const (),
	pub col_idx: *const (),

	// diag
	/// Optional diagonal applied to the LHS during packing (null = none).
	pub diag_ptr: *const (),
	/// Byte stride between consecutive diagonal entries.
	pub diag_stride: isize,
}
55
/// Millikernel-level strides wrapped around a [`MicrokernelInfo`]. The four
/// strides here are byte offsets between consecutive `mr`-row LHS blocks /
/// `nr`-column RHS blocks (and their packed counterparts).
#[derive(Copy, Clone, Debug)]
#[repr(C)]
pub struct MillikernelInfo {
	pub lhs_rs: isize,
	pub packed_lhs_rs: isize,
	pub rhs_cs: isize,
	pub packed_rhs_cs: isize,
	pub micro: MicrokernelInfo,
}
65
/// Copies a `depth × nr` panel out of a strided, type-erased source into a
/// dense destination: element `(k, c)` of the source (byte strides `rs`/`cs`)
/// lands at `dst[k * stride + c]`.
#[inline(always)]
unsafe fn pack_rhs_imp<T: Copy>(dst: *mut T, src: *const (), depth: usize, stride: usize, nr: usize, rs: isize, cs: isize) {
	unsafe {
		for k in 0..depth {
			// Row `k`: destination is dense, source advances by `rs` bytes.
			let dst_row = dst.add(k * stride);
			let src_row = src.byte_offset(k as isize * rs);

			for c in 0..nr {
				let elem = src_row.byte_offset(c as isize * cs) as *const T;
				*dst_row.add(c) = *elem;
			}
		}
	}
}
82
83#[inline(never)]
84unsafe fn pack_rhs(dst: *mut (), src: *const (), depth: usize, nr: usize, rs: isize, cs: isize, sizeof: usize) {
85	if !src.is_null() && src != dst as *const () {
86		unsafe {
87			match sizeof {
88				4 => pack_rhs_imp(dst as *mut f32, src, depth, nr, nr, rs, cs),
89				8 => pack_rhs_imp(dst as *mut [f32; 2], src, depth, nr, nr, rs, cs),
90				16 => pack_rhs_imp(dst as *mut [f64; 2], src, depth, nr, nr, rs, cs),
91				_ => unreachable!(),
92			}
93		}
94	}
95}
96
/// Invokes a generated assembly microkernel through a custom register ABI:
/// `rax` = lhs, `r15` = packed lhs, `rcx` = rhs, `rdx` = packed rhs,
/// `rdi` = `&mut Position`, `rsi` = `&MicrokernelInfo`, `r8`/`r9` =
/// nrows/ncols (read and written back by the callee), `r10` = kernel address.
/// All 32 `zmm` registers and mask registers `k1`–`k4` are declared clobbered.
///
/// Returns the `(nrows, ncols)` values the callee left in `r8`/`r9`.
#[inline(always)]
pub unsafe fn call_microkernel(
	microkernel: unsafe extern "C" fn(),
	lhs: *const (),
	packed_lhs: *mut (),

	rhs: *const (),
	packed_rhs: *mut (),

	mut nrows: usize,
	mut ncols: usize,

	micro: &MicrokernelInfo,
	position: &mut Position,
) -> (usize, usize) {
	unsafe {
		core::arch::asm! {
			"call r10",

			in("rax") lhs,
			in("r15") packed_lhs,
			in("rcx") rhs,
			in("rdx") packed_rhs,
			in("rdi") position,
			in("rsi") micro,
			inout("r8") nrows,
			inout("r9") ncols,
			in("r10") microkernel,

			out("zmm0") _,
			out("zmm1") _,
			out("zmm2") _,
			out("zmm3") _,
			out("zmm4") _,
			out("zmm5") _,
			out("zmm6") _,
			out("zmm7") _,
			out("zmm8") _,
			out("zmm9") _,
			out("zmm10") _,
			out("zmm11") _,
			out("zmm12") _,
			out("zmm13") _,
			out("zmm14") _,
			out("zmm15") _,
			out("zmm16") _,
			out("zmm17") _,
			out("zmm18") _,
			out("zmm19") _,
			out("zmm20") _,
			out("zmm21") _,
			out("zmm22") _,
			out("zmm23") _,
			out("zmm24") _,
			out("zmm25") _,
			out("zmm26") _,
			out("zmm27") _,
			out("zmm28") _,
			out("zmm29") _,
			out("zmm30") _,
			out("zmm31") _,
			out("k1") _,
			out("k2") _,
			out("k3") _,
			out("k4") _,
		}
	}
	(nrows, ncols)
}
166
/// Serial millikernel: sweeps an `nrows × ncols` tile with `mr`-row blocks in
/// the outer loop and `nr`-column blocks in the inner loop (row-major order).
///
/// Packing protocol: on the first column sweep the LHS block may be packed
/// into `packed_lhs` (only when a diagonal must be applied or the LHS row
/// stride is not element-contiguous), after which `lhs` is nulled so later
/// sweeps read the packed copy. Each RHS block is packed into `packed_rhs`
/// before use. `pos.col` is restored after every sweep; `pos.row` advances
/// once per row block.
pub unsafe fn millikernel_rowmajor(
	microkernel: unsafe extern "C" fn(),
	pack: unsafe extern "C" fn(),
	mr: usize,
	nr: usize,
	sizeof: usize,

	lhs: *const (),
	packed_lhs: *mut (),

	rhs: *const (),
	packed_rhs: *mut (),

	nrows: usize,
	ncols: usize,

	milli: &MillikernelInfo,

	pos: &mut Position,
) {
	let mut rhs = rhs;
	let mut nrows = nrows;
	let mut lhs = lhs;
	let mut packed_lhs = packed_lhs;

	// Triangular destination masks; a rectangular product visits every block.
	let tril = milli.micro.flags & FLAGS_LOWER != 0;
	let triu = milli.micro.flags & FLAGS_UPPER != 0;
	let rectangular = !tril && !triu;

	loop {
		// NOTE(review): `rs` is never read — dead local.
		let rs = milli.micro.lhs_rs;
		unsafe {
			// Shadowed per-row-block state for the sweep over column blocks.
			let mut rhs = rhs;
			let mut packed_rhs = packed_rhs;
			let mut ncols = ncols;
			let mut lhs = lhs;
			let col = pos.col;

			// `iter!(lhs)` expands the first-column-block variant (may pack
			// the LHS and then null the outer `lhs`); `iter!()` the rest.
			macro_rules! iter {
                ($($lhs: ident)?) => {{
                    $({
                        let _ = $lhs;
                        // Pack the LHS once per row block when a diagonal must
                        // be applied or its rows are not contiguous in memory.
                        if lhs != packed_lhs && !lhs.is_null() && (!milli.micro.diag_ptr.is_null() || milli.micro.lhs_rs != sizeof as isize) {
                            pack_lhs(pack, milli, Ord::min(nrows, mr), packed_lhs, lhs, sizeof);
                            lhs = null();
                        }
                    })*

                    let row_chunk = Ord::min(nrows, mr);
                    let col_chunk = Ord::min(ncols, nr);

                    {
                        let mut rhs = rhs;
                        if rhs != packed_rhs && !rhs.is_null() {
                            pack_rhs(
                                packed_rhs,
                                rhs,
                                milli.micro.depth,
                                col_chunk,
                                milli.micro.rhs_rs,
                                milli.micro.rhs_cs,
                                sizeof,
                            );
                            rhs = null();
                        }


                        // Skip blocks strictly outside the requested triangle,
                        // but still pack the LHS so later blocks can reuse it.
                        if rectangular || (tril && pos.row + mr > pos.col) || (triu && pos.col + col_chunk > pos.row) {
                            call_microkernel(
                                microkernel,
                                lhs,
                                packed_lhs,
                                rhs,
                                packed_rhs,
                                row_chunk,
                                col_chunk,
                                &milli.micro,
                                pos,
                            );
                        } else {
                            if lhs != packed_lhs && !lhs.is_null() {
                                pack_lhs(pack, milli, row_chunk, packed_lhs, lhs, sizeof);
                            }
                        }
                    }

                    pos.col += col_chunk;
                    ncols -= col_chunk;
                    if ncols == 0 {
                        // Last column block of this sweep: move to the next row block.
                        pos.row += row_chunk;
                        nrows -= row_chunk;
                    }

                    if !rhs.is_null() {
                        rhs = rhs.wrapping_byte_offset(milli.rhs_cs);
                    }
                    packed_rhs = packed_rhs.wrapping_byte_offset(milli.packed_rhs_cs);

                    $(if lhs != packed_lhs {
                        // LHS is packed now; later sweeps read `packed_lhs`.
                        $lhs = null();
                    })?
                }};
            }
			iter!(lhs);
			while ncols > 0 {
				iter!();
			}
			pos.col = col;
		}

		if !lhs.is_null() {
			lhs = lhs.wrapping_byte_offset(milli.lhs_rs);
		}
		packed_lhs = packed_lhs.wrapping_byte_offset(milli.packed_lhs_rs);
		// The RHS was fully packed during the sweep (unless it aliases
		// `packed_rhs`); drop the source pointer for the remaining row blocks.
		if rhs != packed_rhs {
			rhs = null();
		}

		if nrows == 0 {
			break;
		}
	}
}
290
291pub unsafe fn millikernel_colmajor(
292	microkernel: unsafe extern "C" fn(),
293	pack: unsafe extern "C" fn(),
294	mr: usize,
295	nr: usize,
296	sizeof: usize,
297
298	lhs: *const (),
299	packed_lhs: *mut (),
300
301	rhs: *const (),
302	packed_rhs: *mut (),
303
304	nrows: usize,
305	ncols: usize,
306
307	milli: &MillikernelInfo,
308
309	pos: &mut Position,
310) {
311	let mut lhs = lhs;
312	let mut ncols = ncols;
313	let mut rhs = rhs;
314	let mut packed_rhs = packed_rhs;
315
316	let tril = milli.micro.flags & FLAGS_LOWER != 0;
317	let triu = milli.micro.flags & FLAGS_UPPER != 0;
318	let rectangular = !tril && !triu;
319
320	let mut j = 0;
321
322	loop {
323		let cs = milli.micro.rhs_cs;
324		unsafe {
325			let mut lhs = lhs;
326			let mut packed_lhs = packed_lhs;
327			let mut nrows = nrows;
328			let mut rhs = rhs;
329			let row = pos.row;
330
331			macro_rules! iter {
332                ($($rhs: ident)?) => {{
333                    {
334                        let mut lhs = lhs;
335
336                        let row_chunk = Ord::min(nrows, mr);
337                        let col_chunk = Ord::min(ncols, nr);
338
339                        if lhs != packed_lhs && !lhs.is_null() && (!milli.micro.diag_ptr.is_null() || milli.micro.lhs_rs != sizeof as isize) {
340                            pack_lhs(pack, milli, row_chunk, packed_lhs, lhs, sizeof);
341                            lhs = null();
342                        }
343
344                        $({
345                            let _ = $rhs;
346                            if rhs != packed_rhs && !rhs.is_null() {
347                                pack_rhs(
348                                    packed_rhs,
349                                    rhs,
350                                    milli.micro.depth,
351                                    col_chunk,
352                                    milli.micro.rhs_rs,
353                                    milli.micro.rhs_cs,
354                                    sizeof,
355                                );
356                                rhs = null();
357                            }
358                        })*
359                        if rectangular || (tril && pos.row + mr > pos.col) || (triu && pos.col + col_chunk > pos.row) {
360                            call_microkernel(
361                                microkernel,
362                                lhs,
363                                packed_lhs,
364                                rhs,
365                                packed_rhs,
366                                row_chunk,
367                                col_chunk,
368                                &milli.micro,
369                                pos,
370                            );
371                        } else {
372                            if lhs != packed_lhs && !lhs.is_null() {
373                                pack_lhs(pack, milli, row_chunk, packed_lhs, lhs, sizeof);
374                            }
375                        }
376
377                        pos.row += row_chunk;
378                        nrows -= row_chunk;
379                        if nrows == 0 {
380                            pos.col += col_chunk;
381                            ncols -= col_chunk;
382                        }
383                    }
384
385                    if !lhs.is_null() {
386                        lhs = lhs.wrapping_byte_offset(milli.lhs_rs);
387                    }
388                    packed_lhs = packed_lhs.wrapping_byte_offset(milli.packed_lhs_rs);
389
390                    $(if rhs != packed_rhs {
391                        $rhs = null();
392                    })?
393                }};
394            }
395			iter!(rhs);
396			while nrows > 0 {
397				iter!();
398			}
399			pos.row = row;
400		}
401
402		if !rhs.is_null() {
403			rhs = rhs.wrapping_byte_offset(milli.rhs_cs);
404		}
405		packed_rhs = packed_rhs.wrapping_byte_offset(milli.packed_rhs_cs);
406		if lhs != packed_lhs {
407			lhs = null();
408		}
409
410		j += 1;
411		if ncols == 0 {
412			break;
413		}
414	}
415}
416
/// Executes one tile of the parallel decomposition: the worker identified by
/// `thd_id` owns an `mf × nf` grid of (`mr` × `nr`) blocks whose origin is
/// derived from its id. Packing completion is published through
/// `pack_lhs_job` / `pack_rhs_job` (a stored `2` means "packed; read the
/// packed buffer"), with release/acquire ordering.
///
/// NOTE(review): `microkernel_job`, `finished`, `hyper` and `n_threads` are
/// accepted but never read in this body — presumably reserved for a finer
/// synchronization scheme; confirm before removing.
pub unsafe fn millikernel_par(
	thd_id: usize,
	n_threads: usize,

	microkernel_job: &[AtomicU8],
	pack_lhs_job: &[AtomicU8],
	pack_rhs_job: &[AtomicU8],
	finished: &AtomicUsize,
	hyper: usize,

	mr: usize,
	nr: usize,
	sizeof: usize,

	mf: usize,
	nf: usize,

	microkernel: unsafe extern "C" fn(),
	pack: unsafe extern "C" fn(),

	lhs: *const (),
	packed_lhs: *mut (),

	rhs: *const (),
	packed_rhs: *mut (),

	nrows: usize,
	ncols: usize,

	milli: &MillikernelInfo,

	pos: Position,
	tall: bool,
) {
	// Decompose the id into a (row, column) position on the tile grid.
	let n_threads0 = nrows.div_ceil(mf * mr);
	let n_threads1 = ncols.div_ceil(nf * nr);

	let thd_id0 = thd_id % (n_threads0);
	let thd_id1 = thd_id / (n_threads0);

	let tril = milli.micro.flags & FLAGS_LOWER != 0;
	let triu = milli.micro.flags & FLAGS_UPPER != 0;
	let rectangular = !tril && !triu;

	let i = mf * thd_id0;
	let j = nf * thd_id1;

	// Non-tall problems walk the local grid column-major; tall ones row-major.
	let colmajor = !tall;

	for ij in 0..mf * nf {
		let (i, j) = if colmajor {
			(i + ij % mf, j + ij / mf)
		} else {
			(i + ij / nf, j + ij % nf)
		};

		let row = Ord::min(nrows, i * mr);
		let col = Ord::min(ncols, j * nr);

		let row_chunk = Ord::min(nrows - row, mr);
		let col_chunk = Ord::min(ncols - col, nr);

		if row_chunk == 0 || col_chunk == 0 {
			continue;
		}

		let packed_lhs = packed_lhs.wrapping_byte_offset(milli.packed_lhs_rs * i as isize);
		let packed_rhs = packed_rhs.wrapping_byte_offset(milli.packed_rhs_cs * j as isize);

		let mut lhs = lhs;
		let mut rhs = rhs;

		{
			if !lhs.is_null() {
				lhs = lhs.wrapping_byte_offset(milli.lhs_rs * i as isize);
			}

			// If another worker already packed this LHS block, read the packed copy.
			if lhs != packed_lhs {
				let val = pack_lhs_job[i].load(Ordering::Acquire);

				if val == 2 {
					lhs = null();
				}
			}
		}

		{
			if !rhs.is_null() {
				rhs = rhs.wrapping_byte_offset(milli.rhs_cs * j as isize);
			}
			// Likewise for the RHS block.
			if rhs != packed_rhs {
				let val = pack_rhs_job[j].load(Ordering::Acquire);

				if val == 2 {
					rhs = null();
				}
			}

			unsafe {
				// Pack the LHS eagerly when a diagonal must be applied or its
				// rows are not contiguous, then publish completion.
				if lhs != packed_lhs && !lhs.is_null() && (!milli.micro.diag_ptr.is_null() || milli.micro.lhs_rs != sizeof as isize) {
					pack_lhs(pack, milli, row_chunk, packed_lhs, lhs, sizeof);

					lhs = null();
					pack_lhs_job[i].store(2, Ordering::Release);
				}
				if rhs != packed_rhs && !rhs.is_null() {
					pack_rhs(
						packed_rhs,
						rhs,
						milli.micro.depth,
						col_chunk,
						milli.micro.rhs_rs,
						milli.micro.rhs_cs,
						sizeof,
					);
					rhs = null();
					pack_rhs_job[j].store(2, Ordering::Release);
				}

				// NOTE(review): this triangle test uses the base `pos`, not the
				// per-block `row + pos.row` / `col + pos.col` handed to the
				// microkernel below — verify whole-tile culling is intended.
				if rectangular || (tril && pos.row + mr > pos.col) || (triu && pos.col + col_chunk > pos.row) {
					call_microkernel(
						microkernel,
						lhs,
						packed_lhs,
						rhs,
						packed_rhs,
						row_chunk,
						col_chunk,
						&milli.micro,
						&mut Position {
							row: row + pos.row,
							col: col + pos.col,
						},
					);
				} else {
					if lhs != packed_lhs && !lhs.is_null() {
						pack_lhs(pack, milli, row_chunk, packed_lhs, lhs, sizeof);
					}
				}
			}

			// If the microkernel consumed an unpacked operand, mark the block
			// packed. NOTE(review): this assumes the asm kernel writes the
			// packed buffer as a side effect when given an unpacked source —
			// not visible from this file; confirm against asm.rs.
			if !lhs.is_null() && lhs != packed_lhs {
				pack_lhs_job[i].store(2, Ordering::Release);
			}
			if !rhs.is_null() && rhs != packed_rhs {
				pack_rhs_job[j].store(2, Ordering::Release);
			}
		}
	}
}
567
/// Packs one `row_chunk`-row LHS block into `packed_lhs` via the generated
/// `pack` routine, then runs a scalar pass that copies and applies the
/// optional diagonal scaling when neither LHS stride is element-contiguous.
///
/// Asm call ABI: `rax` = src, `r15` = dst, `r8` = column stride (inout; the
/// value left by the callee is reused below as a byte stride), `rsi` =
/// `&MicrokernelInfo`. zmm0–zmm31 and k1–k4 are clobbered.
unsafe fn pack_lhs(pack: unsafe extern "C" fn(), milli: &MillikernelInfo, row_chunk: usize, packed_lhs: *mut (), lhs: *const (), sizeof: usize) {
	unsafe {
		{
			let mut dst_cs = row_chunk;
			core::arch::asm! {
				"call r10",
				in("r10") pack,
				in("rax") lhs,
				in("r15") packed_lhs,
				inout("r8") dst_cs,
				in("rsi") &milli.micro,

				out("zmm0") _,
				out("zmm1") _,
				out("zmm2") _,
				out("zmm3") _,
				out("zmm4") _,
				out("zmm5") _,
				out("zmm6") _,
				out("zmm7") _,
				out("zmm8") _,
				out("zmm9") _,
				out("zmm10") _,
				out("zmm11") _,
				out("zmm12") _,
				out("zmm13") _,
				out("zmm14") _,
				out("zmm15") _,
				out("zmm16") _,
				out("zmm17") _,
				out("zmm18") _,
				out("zmm19") _,
				out("zmm20") _,
				out("zmm21") _,
				out("zmm22") _,
				out("zmm23") _,
				out("zmm24") _,
				out("zmm25") _,
				out("zmm26") _,
				out("zmm27") _,
				out("zmm28") _,
				out("zmm29") _,
				out("zmm30") _,
				out("zmm31") _,
				out("k1") _,
				out("k2") _,
				out("k3") _,
				out("k4") _,
			};

			// Scalar path for a fully strided LHS (neither stride equals the
			// element size). NOTE(review): presumably the asm routine handles
			// the contiguous layouts — confirm against asm.rs.
			if milli.micro.lhs_rs != sizeof as isize && milli.micro.lhs_cs != sizeof as isize {
				for j in 0..milli.micro.depth {
					// `dst_cs` is used as a byte stride between packed columns.
					let dst = packed_lhs.byte_add(j * dst_cs);
					let src = lhs.byte_offset(j as isize * milli.micro.lhs_cs);
					let diag_ptr = milli.micro.diag_ptr.byte_offset(j as isize * milli.micro.diag_stride);

					if sizeof == 4 {
						// Real f32.
						let dst = dst as *mut f32;
						let src = src as *const f32;
						for i in 0..row_chunk {
							let dst = dst.add(i);
							let src = src.byte_offset(i as isize * milli.micro.lhs_rs);

							if diag_ptr.is_null() {
								*dst = *src;
							} else {
								*dst = *src * *(diag_ptr as *const f32);
							}
						}
					} else if sizeof == 16 {
						// Complex f64: scale both components by the real diagonal.
						let dst = dst as *mut [f64; 2];
						let src = src as *const [f64; 2];
						for i in 0..row_chunk {
							let dst = dst.add(i);
							let src = src.byte_offset(i as isize * milli.micro.lhs_rs);

							if diag_ptr.is_null() {
								*dst = *src;
							} else {
								(*dst)[0] = (*src)[0] * *(diag_ptr as *const f64);
								(*dst)[1] = (*src)[1] * *(diag_ptr as *const f64);
							}
						}
					} else {
						// sizeof == 8: complex f32 or real f64, chosen by the
						// CPLX bit (bit 62) of the flags.
						if (milli.micro.flags >> 62) & 1 == 1 {
							let dst = dst as *mut [f32; 2];
							let src = src as *const [f32; 2];
							for i in 0..row_chunk {
								let dst = dst.add(i);
								let src = src.byte_offset(i as isize * milli.micro.lhs_rs);

								if diag_ptr.is_null() {
									*dst = *src;
								} else {
									(*dst)[0] = (*src)[0] * *(diag_ptr as *const f32);
									(*dst)[1] = (*src)[1] * *(diag_ptr as *const f32);
								}
							}
						} else {
							let dst = dst as *mut f64;
							let src = src as *const f64;
							for i in 0..row_chunk {
								let dst = dst.add(i);
								let src = src.byte_offset(i as isize * milli.micro.lhs_rs);

								if diag_ptr.is_null() {
									*dst = *src;
								} else {
									*dst = *src * *(diag_ptr as *const f64);
								}
							}
						}
					}
				}
			}
		}
	}
}
686
/// Strategy interface for driving one packed millikernel sweep over an
/// `nrows × ncols` tile; implemented by the serial `Milli` and (with the
/// `rayon` feature) the parallel `MilliPar` drivers.
///
/// # Safety
/// Implementations forward the raw pointers to generated assembly; callers
/// must uphold the stride/layout contract described by `milli`.
pub unsafe trait Millikernel {
	unsafe fn call(
		&mut self,

		microkernel: unsafe extern "C" fn(),
		pack: unsafe extern "C" fn(),

		lhs: *const (),
		packed_lhs: *mut (),

		rhs: *const (),
		packed_rhs: *mut (),

		nrows: usize,
		ncols: usize,

		milli: &MillikernelInfo,

		pos: Position,
	);
}
708
/// Serial [`Millikernel`] implementation: carries only the blocking
/// parameters.
struct Milli {
	// Microkernel block height (rows).
	mr: usize,
	// Microkernel block width (columns).
	nr: usize,
	// Element size in bytes.
	sizeof: usize,
}
/// Parallel [`Millikernel`] implementation (only with the `rayon` feature).
#[cfg(feature = "rayon")]
struct MilliPar {
	mr: usize,
	nr: usize,
	// NOTE(review): forwarded to `millikernel_par` but unused there.
	hyper: usize,
	sizeof: usize,

	// Per-block packing status shared between workers (2 = packed).
	microkernel_job: Box<[AtomicU8]>,
	pack_lhs_job: Box<[AtomicU8]>,
	pack_rhs_job: Box<[AtomicU8]>,
	finished: AtomicUsize,
	n_threads: usize,
}
727
728unsafe impl Millikernel for Milli {
729	unsafe fn call(
730		&mut self,
731
732		microkernel: unsafe extern "C" fn(),
733		pack: unsafe extern "C" fn(),
734
735		lhs: *const (),
736		packed_lhs: *mut (),
737
738		rhs: *const (),
739		packed_rhs: *mut (),
740
741		nrows: usize,
742		ncols: usize,
743
744		milli: &MillikernelInfo,
745		pos: Position,
746	) {
747		unsafe {
748			(if milli.micro.flags >> 63 == 1 {
749				millikernel_rowmajor
750			} else {
751				millikernel_colmajor
752			})(
753				microkernel,
754				pack,
755				self.mr,
756				self.nr,
757				self.sizeof,
758				lhs,
759				packed_lhs,
760				rhs,
761				packed_rhs,
762				nrows,
763				ncols,
764				milli,
765				&mut { pos },
766			)
767		}
768	}
769}
770
/// Marks any `T` as `Send + Sync` so raw pointers can be captured by the
/// parallel worker closures.
///
/// SAFETY: the wrapper itself provides no synchronization; soundness rests on
/// the callers partitioning the pointed-to data between threads.
#[derive(Copy, Clone)]
pub struct ForceSync<T>(pub T);
unsafe impl<T> Sync for ForceSync<T> {}
unsafe impl<T> Send for ForceSync<T> {}
775
#[cfg(feature = "rayon")]
unsafe impl Millikernel for MilliPar {
	/// Parallel driver: optionally pre-packs the RHS cooperatively (depth
	/// sliced across workers), then has every worker pull tile ids from a
	/// shared counter and run [`millikernel_par`] on each.
	unsafe fn call(
		&mut self,

		microkernel: unsafe extern "C" fn(),
		pack: unsafe extern "C" fn(),

		lhs: *const (),
		packed_lhs: *mut (),

		rhs: *const (),
		packed_rhs: *mut (),

		nrows: usize,
		ncols: usize,

		milli: &MillikernelInfo,
		pos: Position,
	) {
		// Raw pointers are not Send/Sync; ForceSync moves them into the
		// worker closures (soundness relies on the tile partitioning below).
		let lhs = ForceSync(lhs);
		let mut rhs = ForceSync(rhs);
		let packed_lhs = ForceSync(packed_lhs);
		let packed_rhs = ForceSync(packed_rhs);
		let milli = ForceSync(milli);

		// Reset the shared per-block packing states (0 = unpacked).
		self.microkernel_job.fill_with(|| AtomicU8::new(0));
		self.pack_lhs_job.fill_with(|| AtomicU8::new(0));
		self.pack_rhs_job.fill_with(|| AtomicU8::new(0));
		self.finished = AtomicUsize::new(0);

		// Shape heuristics: `f` grows with depth and shrinks the effective L3
		// budget; `tall`/`wide` choose the tile-grid shape and traversal.
		let f = Ord::min(8, milli.0.micro.depth.div_ceil(64));
		let l3 = CACHE_INFO[2].cache_bytes / f;

		let tall = nrows >= l3;
		let wide = ncols >= 2 * nrows;

		let mut mf = Ord::clamp(nrows.div_ceil(self.mr).div_ceil(2 * self.n_threads), 2, 4);
		if tall {
			mf = 16 / f;
		}
		if wide {
			mf = 2;
		}
		let par_rows = nrows.div_ceil(mf * self.mr);
		// NOTE(review): this `nf` is immediately shadowed by the constant
		// expression on the next line — the heuristic below is dead code.
		let nf = Ord::clamp(ncols.div_ceil(self.nr).div_ceil(8 * self.n_threads) * par_rows, 1, 1024 / f);
		let nf = 32 / self.nr;

		// Total number of `mf*mr × nf*nr` tiles to schedule.
		let n = nrows.div_ceil(mf * self.mr) * ncols.div_ceil(nf * self.nr);

		let mr = self.mr;
		let nr = self.nr;

		if !rhs.0.is_null() && rhs.0 != packed_rhs.0 {
			let depth = { milli }.0.micro.depth;

			// Split the depth range into `n_threads` near-equal slices; the
			// first `rem` slices get one extra row.
			let div = depth / self.n_threads;
			let rem = depth % self.n_threads;

			if !wide {
				spindle::for_each_raw(self.n_threads, |j| {
					let mut start = j * div;
					if j <= rem {
						start += j;
					} else {
						start += rem;
					}
					let end = start + div + if j < rem { 1 } else { 0 };
					let milli = { milli }.0;

					// Each worker packs its depth slice of every column block.
					for i in 0..ncols.div_ceil(nr) {
						let col = Ord::min(ncols, i * nr);
						let ncols = Ord::min(ncols - col, nr);

						let rs = ncols;
						let rhs = { rhs }.0.wrapping_byte_offset(milli.rhs_cs * i as isize);
						let packed_rhs = { packed_rhs }.0.wrapping_byte_offset(milli.packed_rhs_cs * i as isize);

						pack_rhs(
							packed_rhs.wrapping_byte_offset((start * rs * self.sizeof) as isize),
							rhs.wrapping_byte_offset(start as isize * milli.micro.rhs_rs),
							end - start,
							ncols,
							milli.micro.rhs_rs,
							milli.micro.rhs_cs,
							self.sizeof,
						);
					}
				});
				rhs.0 = null();
			}
		}

		// Dynamic scheduling: workers grab tile ids until all `n` are claimed.
		let gtid = AtomicUsize::new(0);

		spindle::for_each_raw(self.n_threads, |_| unsafe {
			loop {
				let tid = gtid.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
				if tid >= n {
					return;
				}
				let milli = { milli }.0;

				// NOTE(review): `n` (the tile count) is passed as the
				// `n_threads` argument of `millikernel_par`; harmless today
				// because that parameter is unused there — confirm the intent.
				millikernel_par(
					tid,
					n,
					&self.microkernel_job,
					&self.pack_lhs_job,
					&self.pack_rhs_job,
					&self.finished,
					self.hyper,
					self.mr,
					self.nr,
					self.sizeof,
					mf,
					nf,
					microkernel,
					pack,
					{ lhs }.0,
					{ packed_lhs }.0,
					{ rhs }.0,
					{ packed_rhs }.0,
					nrows,
					ncols,
					milli,
					pos,
					tall,
				);
			}
		});
	}
}
908
/// Iterative driver for the multi-level cache-blocking hierarchy.
///
/// The levels are described positionally by the parallel slices `row_chunk`,
/// `col_chunk`, `lhs_rs`, `rhs_cs`, `packed_lhs_rs`, `packed_rhs_cs`; the
/// last entries of the stride slices feed the millikernel itself. Instead of
/// recursing, an explicit stack of up to 16 frames tracks the current block
/// (pointers, origin, extents) and the cursor state at each level; the leaf
/// level dispatches to `millikernel.call`.
#[inline(never)]
unsafe fn kernel_imp(
	millikernel: &mut dyn Millikernel,

	microkernel: &[unsafe extern "C" fn()],
	pack: &[unsafe extern "C" fn()],

	mr: usize,
	nr: usize,

	lhs: *const (),
	packed_lhs: *mut (),

	rhs: *const (),
	packed_rhs: *mut (),

	nrows: usize,
	ncols: usize,

	row_chunk: &[usize],
	col_chunk: &[usize],
	lhs_rs: &[isize],
	rhs_cs: &[isize],
	packed_lhs_rs: &[isize],
	packed_rhs_cs: &[isize],

	row: usize,
	col: usize,

	pos: Position,
	info: &MicrokernelInfo,
) {
	let _ = mr;

	// One frame per blocking level:
	// (lhs, packed_lhs, rhs, packed_rhs, row, col, nrows, ncols,
	//  i, j   = element offsets of the sub-block cursor,
	//  ii, jj = sub-block indices,
	//  is_packed_lhs, is_packed_rhs, row_rev, col_rev).
	let mut stack: [(
		*const (),
		*mut (),
		*const (),
		*mut (),
		usize,
		usize,
		usize,
		usize,
		usize,
		usize,
		usize,
		usize,
		bool,
		bool,
		bool,
		bool,
	); 16] = const { [(null(), null_mut(), null(), null_mut(), 0, 0, 0, 0, 0, 0, 0, 0, false, false, false, false); 16] };

	stack[0] = (
		lhs, packed_lhs, rhs, packed_rhs, row, col, nrows, ncols, 0, 0, 0, 0, false, false, false, false,
	);

	let mut pos = pos;
	let mut depth = 0;
	let max_depth = row_chunk.len();

	// The innermost strides drive the millikernel directly.
	let milli_rs = *lhs_rs.last().unwrap();
	let milli_cs = *rhs_cs.last().unwrap();

	let micro_rs = info.lhs_rs;
	let micro_cs = info.rhs_cs;

	let milli = MillikernelInfo {
		lhs_rs: milli_rs,
		packed_lhs_rs: *packed_lhs_rs.last().unwrap(),
		rhs_cs: milli_cs,
		packed_rhs_cs: *packed_rhs_cs.last().unwrap(),
		micro: *info,
	};
	let microkernel = microkernel[nr - 1];
	let pack = pack[0];

	// Drop the innermost chunk level; it is handled by the millikernel via
	// `milli` above.
	let q = row_chunk.len();
	let row_chunk = &row_chunk[..q - 1];
	let col_chunk = &col_chunk[..q - 1];
	let lhs_rs = &lhs_rs[..q];
	let packed_lhs_rs = &packed_lhs_rs[..q];
	let rhs_cs = &rhs_cs[..q];
	let packed_rhs_cs = &packed_rhs_cs[..q];

	loop {
		let (lhs, packed_lhs, rhs, packed_rhs, row, col, nrows, ncols, i, j, ii, jj, is_packed_lhs, is_packed_rhs, row_rev, col_rev) = stack[depth];
		// NOTE(review): the reversal flags recorded in the frame are forced
		// off here, disabling serpentine traversal; the `row_rev`/`col_rev`
		// arms below are currently dead. Confirm whether this is temporary.
		let row_rev = false;
		let col_rev = false;

		if depth + 1 == max_depth {
			// Leaf level: hand the block to the millikernel, then pop frames,
			// advancing each parent cursor until one still has work left.
			let mut lhs = lhs;
			let mut rhs = rhs;

			pos.row = row;
			pos.col = col;

			if is_packed_lhs && lhs != packed_lhs {
				lhs = null();
			}
			if is_packed_rhs && rhs != packed_rhs {
				rhs = null();
			}

			unsafe {
				millikernel.call(microkernel, pack, lhs, packed_lhs, rhs, packed_rhs, nrows, ncols, &milli, pos);
			}

			while depth > 0 {
				depth -= 1;

				let (_, _, _, _, _, _, nrows, ncols, i, j, ii, jj, _, _, _, _) = &mut stack[depth];

				let col_chunk = col_chunk[depth];
				let row_chunk = row_chunk[depth];

				let j_chunk = Ord::min(col_chunk, *ncols - *j);
				let i_chunk = Ord::min(row_chunk, *nrows - *i);

				if milli.micro.flags & FLAGS_ROWMAJOR == 0 {
					// Column-major: advance rows first, then columns.
					*i += i_chunk;
					*ii += 1;
					if *i == *nrows {
						*i = 0;
						*ii = 0;
						*j += j_chunk;
						*jj += 1;

						if *j == *ncols {
							if depth == 0 {
								return;
							}

							*j = 0;
							*jj = 0;
							continue;
						}
					}
				} else {
					// Row-major: advance columns first, then rows.
					*j += j_chunk;
					*jj += 1;
					if *j == *ncols {
						*j = 0;
						*jj = 0;
						*i += i_chunk;
						*ii += 1;

						if *i == *nrows {
							*i = 0;
							*ii = 0;
							if depth == 0 {
								return;
							}
							continue;
						}
					}
				}
				break;
			}
		} else {
			// Interior level: push a child frame for the current sub-block.
			let col_chunk = col_chunk[depth];
			let row_chunk = row_chunk[depth];
			let rhs_cs = rhs_cs[depth];
			let lhs_rs = lhs_rs[depth];
			let prhs_cs = packed_rhs_cs[depth];
			let plhs_rs = packed_lhs_rs[depth];

			let last_row_chunk = if nrows == 0 { 0 } else { ((nrows - 1) % row_chunk) + 1 };

			let last_col_chunk = if ncols == 0 { 0 } else { ((ncols - 1) % col_chunk) + 1 };

			// Dead while the reversal flags above are forced off.
			let (i, ii) = if row_rev {
				(nrows - last_row_chunk - i, nrows.div_ceil(row_chunk) - 1 - ii)
			} else {
				(i, ii)
			};

			let (j, jj) = if col_rev {
				(ncols - last_col_chunk - j, ncols.div_ceil(col_chunk) - 1 - jj)
			} else {
				(j, jj)
			};
			assert!(i as isize >= 0);
			assert!(j as isize >= 0);

			let j_chunk = Ord::min(col_chunk, ncols - j);
			let i_chunk = Ord::min(row_chunk, nrows - i);

			depth += 1;
			stack[depth] = (
				lhs.wrapping_byte_offset(lhs_rs * ii as isize),
				packed_lhs.wrapping_byte_offset(plhs_rs * ii as isize),
				rhs.wrapping_byte_offset(rhs_cs * jj as isize),
				packed_rhs.wrapping_byte_offset(prhs_cs * jj as isize),
				row + i,
				col + j,
				i_chunk,
				j_chunk,
				0,
				0,
				0,
				0,
				// A sibling block already packed this operand when the packed
				// stride at this level is nonzero.
				is_packed_lhs || (j > 0 && packed_lhs_rs[depth - 1] != 0),
				is_packed_rhs || (i > 0 && packed_rhs_cs[depth - 1] != 0),
				jj % 2 == 1,
				ii % 2 == 1,
			);
			continue;
		}
	}
}
1120
/// SIMD instruction-set selector for kernel dispatch.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum InstrSet {
	Avx256,
	Avx512,
}
1126
/// Element type: real/complex, single/double precision.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DType {
	F32,
	F64,
	C32,
	C64,
}
1134
/// Whether the product replaces the destination or accumulates into it.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Accum {
	Replace,
	Add,
}
1140
/// Index element width for the optional destination row/column index arrays.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum IType {
	U32,
	U64,
}
1146
/// Which part of the destination matrix is written.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DstKind {
	Lower,
	Upper,
	Full,
}
1153
1154pub unsafe fn gemm(
1155	dtype: DType,
1156	itype: IType,
1157
1158	instr: InstrSet,
1159	nrows: usize,
1160	ncols: usize,
1161	depth: usize,
1162
1163	dst: *mut (),
1164	dst_rs: isize,
1165	dst_cs: isize,
1166	dst_row_idx: *const (),
1167	dst_col_idx: *const (),
1168	dst_kind: DstKind,
1169
1170	beta: Accum,
1171
1172	lhs: *const (),
1173	lhs_rs: isize,
1174	lhs_cs: isize,
1175	conj_lhs: bool,
1176
1177	real_diag: *const (),
1178	diag_stride: isize,
1179
1180	rhs: *const (),
1181	rhs_rs: isize,
1182	rhs_cs: isize,
1183	conj_rhs: bool,
1184
1185	alpha: *const (),
1186
1187	n_threads: usize,
1188) {
1189	let (sizeof, cplx) = match dtype {
1190		DType::F32 => (4, false),
1191		DType::F64 => (8, false),
1192		DType::C32 => (8, true),
1193		DType::C64 => (16, true),
1194	};
1195	let mut lhs_rs = lhs_rs * sizeof as isize;
1196	let mut lhs_cs = lhs_cs * sizeof as isize;
1197	let mut rhs_rs = rhs_rs * sizeof as isize;
1198	let mut rhs_cs = rhs_cs * sizeof as isize;
1199	let mut dst_rs = dst_rs * sizeof as isize;
1200	let mut dst_cs = dst_cs * sizeof as isize;
1201	let real_diag_stride = diag_stride * sizeof as isize;
1202
1203	if nrows == 0 || ncols == 0 || (depth == 0 && beta == Accum::Add) {
1204		return;
1205	}
1206
1207	let mut nrows = nrows;
1208	let mut ncols = ncols;
1209
1210	let mut dst = dst;
1211	let mut dst_row_idx = dst_row_idx;
1212	let mut dst_col_idx = dst_col_idx;
1213	let mut dst_kind = dst_kind;
1214
1215	let mut lhs = lhs;
1216	let mut conj_lhs = conj_lhs;
1217
1218	let mut rhs = rhs;
1219	let mut conj_rhs = conj_rhs;
1220
1221	if dst_rs.unsigned_abs() > dst_cs.unsigned_abs() {
1222		use core::mem::swap;
1223		swap(&mut dst_rs, &mut dst_cs);
1224		swap(&mut dst_row_idx, &mut dst_col_idx);
1225		dst_kind = match dst_kind {
1226			DstKind::Lower => DstKind::Upper,
1227			DstKind::Upper => DstKind::Lower,
1228			DstKind::Full => DstKind::Full,
1229		};
1230		swap(&mut lhs, &mut rhs);
1231		swap(&mut lhs_rs, &mut rhs_cs);
1232		swap(&mut lhs_cs, &mut rhs_rs);
1233		swap(&mut conj_lhs, &mut conj_rhs);
1234		swap(&mut nrows, &mut ncols);
1235	}
1236
1237	if dst_rs < 0 && dst_kind == DstKind::Full && dst_row_idx.is_null() {
1238		dst = dst.wrapping_byte_offset((nrows - 1) as isize * dst_rs);
1239		lhs = lhs.wrapping_byte_offset((nrows - 1) as isize * lhs_rs);
1240		dst_rs = -dst_rs;
1241		lhs_rs = -lhs_rs;
1242	}
1243
1244	if lhs_cs < 0 && depth > 0 {
1245		lhs = lhs.wrapping_byte_offset((depth - 1) as isize * lhs_cs);
1246		rhs = rhs.wrapping_byte_offset((depth - 1) as isize * rhs_rs);
1247
1248		lhs_cs = -lhs_cs;
1249		rhs_rs = -rhs_rs;
1250	}
1251
1252	let (microkernel, pack, mr, nr) = match (instr, dtype) {
1253		(InstrSet::Avx256, DType::F32) => (F32_SIMD256.as_slice(), F32_SIMDpack_256.as_slice(), 24, 4),
1254		(InstrSet::Avx256, DType::F64) => (F64_SIMD256.as_slice(), F64_SIMDpack_256.as_slice(), 12, 4),
1255		(InstrSet::Avx256, DType::C32) => (C32_SIMD256.as_slice(), C32_SIMDpack_256.as_slice(), 12, 4),
1256		(InstrSet::Avx256, DType::C64) => (C64_SIMD256.as_slice(), C64_SIMDpack_256.as_slice(), 6, 4),
1257		(InstrSet::Avx512, DType::F32) => (F32_SIMD512x4.as_slice(), F32_SIMDpack_512.as_slice(), 96, 4),
1258		(InstrSet::Avx512, DType::F64) => {
1259			if nrows > 48 {
1260				(F64_SIMD512x4.as_slice(), F64_SIMDpack_512.as_slice(), 48, 4)
1261			} else {
1262				(F64_SIMD512x8.as_slice(), F64_SIMDpack_512.as_slice(), 24, 8)
1263			}
1264		},
1265		(InstrSet::Avx512, DType::C32) => (C32_SIMD512x4.as_slice(), C32_SIMDpack_512.as_slice(), 48, 4),
1266		(InstrSet::Avx512, DType::C64) => (C64_SIMD512x4.as_slice(), C64_SIMDpack_512.as_slice(), 24, 4),
1267	};
1268
1269	let m = nrows;
1270	let n = ncols;
1271
1272	let kc = Ord::min(depth, 512);
1273
1274	let cache = *cache::CACHE_INFO;
1275
1276	let l1 = cache[0].cache_bytes / sizeof;
1277	let l2 = cache[1].cache_bytes / sizeof;
1278	let l3 = cache[2].cache_bytes / sizeof;
1279
1280	#[repr(align(4096))]
1281	struct Page([u8; 4096]);
1282
1283	let lhs_size = (l3.next_multiple_of(16) * sizeof).div_ceil(size_of::<Page>());
1284	let rhs_size = (l3.next_multiple_of(nr) * sizeof).div_ceil(size_of::<Page>());
1285
1286	thread_local! {
1287		static MEM: RefCell<Vec::<core::mem::MaybeUninit<Page>>> = {
1288			let cache = *cache::CACHE_INFO;
1289			let l3 = cache[2].cache_bytes;
1290
1291			let lhs_size = l3.div_ceil(size_of::<Page>());
1292			let rhs_size = l3.div_ceil(size_of::<Page>());
1293
1294			let mut mem = Vec::with_capacity(lhs_size + rhs_size);
1295			unsafe { mem.set_len(lhs_size + rhs_size) };
1296			RefCell::new(mem)
1297		};
1298	}
1299
1300	MEM.with(|mem| {
1301		let mut storage;
1302		let mut alloc;
1303
1304		let mem = match mem.try_borrow_mut() {
1305			Ok(mem) => {
1306				storage = mem;
1307				&mut *storage
1308			},
1309			Err(_) => {
1310				alloc = Vec::with_capacity(lhs_size + rhs_size);
1311
1312				&mut alloc
1313			},
1314		};
1315		if mem.len() < lhs_size + rhs_size {
1316			mem.reserve_exact(lhs_size + rhs_size);
1317			unsafe { mem.set_len(lhs_size + rhs_size) };
1318		}
1319
1320		let (packed_lhs, packed_rhs) = mem.split_at_mut(lhs_size);
1321		let (packed_rhs, _) = packed_rhs.split_at_mut(rhs_size);
1322
1323		let lhs = ForceSync(lhs);
1324		let rhs = ForceSync(rhs);
1325		let dst = ForceSync(dst);
1326		let real_diag = ForceSync(real_diag);
1327		let dst_row_idx = ForceSync(dst_row_idx);
1328		let dst_col_idx = ForceSync(dst_col_idx);
1329		let alpha = ForceSync(alpha);
1330		let mut f = || {
1331			let mut k = 0;
1332			let mut beta = beta;
1333			let mut lhs = { lhs }.0;
1334			let mut rhs = { rhs }.0;
1335			let mut real_diag = { real_diag }.0;
1336			let dst = { dst }.0;
1337			while k < depth {
1338				let kc = Ord::min(depth - k, kc);
1339
1340				let f = kc.div_ceil(64);
1341				let l1 = l1 / 64 / f;
1342				let l2 = l2 / 64 / f;
1343				let l3 = l3 / 64 / f;
1344
1345				let tall = m >= 3 * n / 2 && m >= l3;
1346				let pack_lhs = !real_diag.is_null() || (n > 6 * nr && tall) || (n > 3 * nr * n_threads) || lhs_rs != sizeof as isize;
1347				let pack_rhs = tall;
1348
1349				let rowmajor = if n_threads > 1 {
1350					false
1351				} else if tall {
1352					true
1353				} else {
1354					false
1355				};
1356
1357				let info = MicrokernelInfo {
1358					flags: match beta {
1359						Accum::Replace => 0,
1360						Accum::Add => FLAGS_ACCUM,
1361					} | if conj_lhs { FLAGS_CONJ_LHS } else { 0 }
1362						| if conj_lhs != conj_rhs { FLAGS_CONJ_NEQ } else { 0 }
1363						| match itype {
1364							IType::U32 => FLAGS_32BIT_IDX,
1365							IType::U64 => 0,
1366						} | if cplx { FLAGS_CPLX } else { 0 }
1367						| match dst_kind {
1368							DstKind::Lower => FLAGS_LOWER,
1369							DstKind::Upper => FLAGS_UPPER,
1370							DstKind::Full => 0,
1371						} | if rowmajor { FLAGS_ROWMAJOR } else { 0 },
1372					depth: kc,
1373					lhs_rs,
1374					lhs_cs,
1375					rhs_rs,
1376					rhs_cs,
1377					alpha: { alpha }.0,
1378					ptr: dst,
1379					rs: dst_rs,
1380					cs: dst_cs,
1381					row_idx: { dst_row_idx }.0,
1382					col_idx: { dst_col_idx }.0,
1383					diag_ptr: real_diag,
1384					diag_stride: real_diag_stride,
1385				};
1386
1387				if n_threads <= 1 && !rowmajor && m < l2 && n < l2 {
1388					let microkernel = microkernel[nr - 1];
1389					let pack = pack[0];
1390					millikernel_colmajor(
1391						microkernel,
1392						pack,
1393						mr,
1394						nr,
1395						sizeof,
1396						lhs,
1397						if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs as _ },
1398						rhs,
1399						if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs as _ },
1400						nrows,
1401						ncols,
1402						&MillikernelInfo {
1403							lhs_rs: lhs_rs * mr as isize,
1404							packed_lhs_rs: if pack_lhs { (sizeof * mr * kc) as isize } else { lhs_rs * mr as isize },
1405							rhs_cs: rhs_cs * nr as isize,
1406							packed_rhs_cs: if pack_rhs { (sizeof * nr * kc) as isize } else { rhs_cs * nr as isize },
1407							micro: info,
1408						},
1409						&mut Position { row: 0, col: 0 },
1410					);
1411				} else {
1412					let (row_chunk, col_chunk, rowmajor) = if n_threads > 1 {
1413						(
1414							//
1415							[m, m, m, l3 / 16 * 16, mr],
1416							[n, n, n, l3 / 16 * 16, nr],
1417							false,
1418						)
1419					} else if true {
1420						(
1421							//
1422							[m, l3, l2, l2 / 2, mr],
1423							[n, 2 * l3, l3 / 2, l2, nr],
1424							true,
1425						)
1426					} else {
1427						(
1428							//
1429							[2 * l3, l3 / 2, l3 / 2, l2, mr],
1430							[l3, l3 / 2, l2 / 2, l1, nr],
1431							false,
1432						)
1433					};
1434
1435					let mut row_chunk = row_chunk.map(|r| if r == mr { mr } else { r.next_multiple_of(16) });
1436					let mut col_chunk = col_chunk.map(|c| c.next_multiple_of(nr));
1437
1438					let q = row_chunk.len();
1439					{
1440						for i in (1..q - 1).rev() {
1441							row_chunk[i - 1] = Ord::max(row_chunk[i - 1].next_multiple_of(row_chunk[i]), row_chunk[i]);
1442							if row_chunk[i - 1] > l3 / 2 && row_chunk[i - 1] < l3 {
1443								row_chunk[i - 1] = l3 / 2;
1444							}
1445							if row_chunk[i - 1] >= l3 {
1446								row_chunk[i - 1] = Ord::min(row_chunk[i - 1], 2 * row_chunk[i]);
1447							}
1448						}
1449						for i in (1..q - 1).rev() {
1450							col_chunk[i - 1] = Ord::max(col_chunk[i - 1].next_multiple_of(col_chunk[i]), col_chunk[i]);
1451							if col_chunk[i - 1] > l3 / 2 && col_chunk[i - 1] < l3 {
1452								col_chunk[i - 1] = l3 / 2;
1453							}
1454							if col_chunk[i - 1] >= l3 {
1455								col_chunk[i - 1] = Ord::min(col_chunk[i - 1], 2 * col_chunk[i]);
1456							}
1457						}
1458					}
1459
1460					let all_lhs_rs = row_chunk.map(|m| m as isize * lhs_rs);
1461					let all_rhs_cs = col_chunk.map(|n| n as isize * rhs_cs);
1462
1463					let mut packed_lhs_rs = row_chunk.map(|x| if x > l3 / 2 { 0 } else { (x * kc * sizeof) as isize });
1464					let mut packed_rhs_cs = col_chunk.map(|x| if x > l3 / 2 { 0 } else { (x * kc * sizeof) as isize });
1465					packed_lhs_rs[0] = 0;
1466					packed_rhs_cs[0] = 0;
1467
1468					assert!(lhs_size * size_of::<Page>() >= row_chunk[q - 2] * kc * sizeof);
1469					assert!(rhs_size * size_of::<Page>() >= col_chunk[q - 2] * kc * sizeof);
1470
1471					unsafe {
1472						kernel(
1473							n_threads,
1474							microkernel,
1475							pack,
1476							mr,
1477							nr,
1478							sizeof,
1479							lhs,
1480							if pack_lhs { packed_lhs.as_mut_ptr() as *mut () } else { lhs as *mut () },
1481							rhs,
1482							if pack_rhs { packed_rhs.as_mut_ptr() as *mut () } else { rhs as *mut () },
1483							nrows,
1484							ncols,
1485							&row_chunk,
1486							&col_chunk,
1487							&all_lhs_rs,
1488							&all_rhs_cs,
1489							if pack_lhs { &packed_lhs_rs } else { &all_lhs_rs },
1490							if pack_rhs { &packed_rhs_cs } else { &all_rhs_cs },
1491							0,
1492							0,
1493							Position { row: 0, col: 0 },
1494							&info,
1495						)
1496					};
1497				}
1498
1499				k += kc;
1500				lhs = lhs.wrapping_byte_offset(lhs_cs * kc as isize);
1501				rhs = rhs.wrapping_byte_offset(rhs_rs * kc as isize);
1502				real_diag = real_diag.wrapping_byte_offset(real_diag_stride * kc as isize);
1503
1504				beta = Accum::Add;
1505			}
1506		};
1507		if n_threads <= 1 {
1508			f();
1509		} else {
1510			#[cfg(feature = "rayon")]
1511			spindle::with_lock(n_threads, f);
1512
1513			#[cfg(not(feature = "rayon"))]
1514			f();
1515		}
1516	});
1517}
1518
/// Blocked GEMM driver: walks the tiling hierarchy described by
/// `row_chunk`/`col_chunk` and dispatches work to the generated assembly
/// microkernels via `kernel_imp`.
///
/// Chooses the millikernel scheduler first: a sequential `Milli` always, and —
/// only when the `rayon` feature is enabled and `n_threads > 1` — a parallel
/// `MilliPar` with atomic job-state tables sized by the microtile grid.
///
/// # Safety
/// The caller must uphold the same contract as the raw microkernels: every
/// pointer/stride pair must describe valid, appropriately-sized matrix data,
/// and `info` must be consistent with `mr`/`nr`/`sizeof`.
pub unsafe fn kernel(
	n_threads: usize,
	// function-pointer tables generated by the build script (see `asm.rs` include)
	microkernel: &[unsafe extern "C" fn()],
	pack: &[unsafe extern "C" fn()],

	// register-block dimensions and element size in bytes
	mr: usize,
	nr: usize,
	sizeof: usize,

	lhs: *const (),
	packed_lhs: *mut (),

	rhs: *const (),
	packed_rhs: *mut (),

	nrows: usize,
	ncols: usize,

	// per-level tile sizes (outermost level first) with matching byte strides
	row_chunk: &[usize],
	col_chunk: &[usize],
	lhs_rs: &[isize],
	rhs_cs: &[isize],
	packed_lhs_rs: &[isize],
	packed_rhs_cs: &[isize],

	row: usize,
	col: usize,

	pos: Position,
	info: &MicrokernelInfo,
) {
	unsafe {
		let mut seq = Milli { mr, nr, sizeof };
		// declared up front so the `&mut par` borrow below outlives the call
		#[cfg(feature = "rayon")]
		let mut par;
		kernel_imp(
			// scheduler argument: parallel when rayon is enabled and requested,
			// sequential otherwise (the cfg attributes select exactly one arm)
			#[cfg(feature = "rayon")]
			if n_threads > 1 {
				par = {
					// grid of (mr x nr) microtiles covering the output
					let max_i = nrows.div_ceil(mr);
					let max_j = ncols.div_ceil(nr);
					let max_jobs = max_i * max_j;
					let c = max_i;

					MilliPar {
						mr,
						nr,
						sizeof,
						hyper: 1,
						// one atomic state byte per microtile job, plus one per
						// row-panel / col-panel packing job — presumably claimed
						// by worker threads; see `MilliPar` for the protocol
						microkernel_job: (0..c * max_j).map(|_| AtomicU8::new(0)).collect(),
						pack_lhs_job: (0..max_i).map(|_| AtomicU8::new(0)).collect(),
						pack_rhs_job: (0..max_j).map(|_| AtomicU8::new(0)).collect(),
						finished: AtomicUsize::new(0),
						n_threads,
					}
				};
				&mut par
			} else {
				&mut seq
			},
			#[cfg(not(feature = "rayon"))]
			&mut seq,
			microkernel,
			pack,
			mr,
			nr,
			lhs,
			packed_lhs,
			rhs,
			packed_rhs,
			nrows,
			ncols,
			row_chunk,
			col_chunk,
			lhs_rs,
			rhs_cs,
			packed_lhs_rs,
			packed_rhs_cs,
			row,
			col,
			pos,
			info,
		)
	};
}
1604
#[cfg(test)]
mod tests_f64 {
	// f64 correctness tests: outputs of the assembly kernels are compared
	// against a scalar reference computation (and, for the large blocked
	// case, against the external `gemm` crate).
	use core::ptr::null_mut;

	use super::*;

	use aligned_vec::*;
	use rand::prelude::*;

	/// Exercises `millikernel_colmajor` with the 48x4 AVX-512 f64 microkernel
	/// for every m in 1..=48 and n in 1..=5, with and without LHS packing.
	#[test]
	fn test_avx512_microkernel() {
		// fixed seed for reproducibility
		let rng = &mut StdRng::seed_from_u64(0);

		let sizeof = size_of::<f64>() as isize;
		// f64 lanes per 64-byte cache line (not used in this test)
		let len = 64 / size_of::<f64>();

		for pack_lhs in [false, true] {
			for pack_rhs in [false] {
				for alpha in [1.0.into(), 0.0.into(), 2.5.into()] {
					let alpha: f64 = alpha;
					for m in 1..=48usize {
						for n in (1..=4usize).chain([5]) {
							// column stride padded up to the full 48-row tile
							for cs in [m.next_multiple_of(48)] {
								let acs = m.next_multiple_of(48);
								let k = 2usize;

								let packed_lhs: &mut [f64] = &mut *avec![0.0.into(); acs * k];
								let packed_rhs: &mut [f64] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
								let lhs: &mut [f64] = &mut *avec![0.0.into(); cs * k];
								let rhs: &mut [f64] = &mut *avec![0.0.into(); n * k];
								let dst: &mut [f64] = &mut *avec![0.0.into(); cs * n];
								let target = &mut *avec![0.0.into(); cs * n];

								rng.fill(lhs);
								rng.fill(rhs);

								// scalar reference: target += alpha * (lhs @ rhs),
								// accumulated with mul_add so it can match an
								// FMA-based microkernel bit-for-bit
								for i in 0..m {
									for j in 0..n {
										let target = &mut target[i + cs * j];
										let mut acc = 0.0.into();
										for depth in 0..k {
											acc = f64::mul_add(lhs[i + cs * depth], rhs[depth + k * j], acc);
										}
										*target = f64::mul_add(acc, alpha, *target);
									}
								}

								unsafe {
									millikernel_colmajor(
										// index 3: variant for the full nr = 4 columns
										// (the driver indexes this table with nr - 1)
										F64_SIMD512x4[3],
										F64_SIMDpack_512[0],
										48,
										4,
										8,
										lhs.as_ptr() as _,
										if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
										rhs.as_ptr() as _,
										if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
										m,
										n,
										&mut MillikernelInfo {
											// strides below are in bytes
											lhs_rs: 48 * sizeof,
											packed_lhs_rs: if pack_lhs { 48 * sizeof * k as isize } else { 48 * sizeof },
											rhs_cs: 4 * sizeof * k as isize,
											packed_rhs_cs: 4 * sizeof * k as isize,
											micro: MicrokernelInfo {
												flags: 0,
												depth: k,
												lhs_rs: 1 * sizeof,
												lhs_cs: cs as isize * sizeof,
												rhs_rs: 1 * sizeof,
												rhs_cs: k as isize * sizeof,
												alpha: &raw const alpha as _,
												ptr: dst.as_mut_ptr() as _,
												rs: 1 * sizeof,
												cs: cs as isize * sizeof,
												row_idx: null_mut(),
												col_idx: null_mut(),
												diag_ptr: null(),
												diag_stride: 0,
											},
										},
										&mut Position { row: 0, col: 0 },
									)
								};
								// exact equality expected (same mul_add accumulation)
								assert_eq!(dst, target);
							}
						}
					}
				}
			}
		}
	}

	/// Drives the public `gemm` entry point for both AVX-256 and AVX-512 with
	/// k = 513 — just above the kc = 512 depth block — so the depth-splitting
	/// / accumulate path is exercised, including one large (513 x 512) case.
	#[test]
	fn test_gemm() {
		let rng = &mut StdRng::seed_from_u64(0);

		let sizeof = size_of::<f64>() as isize;
		// f64 lanes per 64-byte cache line (not used in this test)
		let len = 64 / size_of::<f64>();

		for instr in [InstrSet::Avx256, InstrSet::Avx512] {
			for pack_lhs in [false, true] {
				for pack_rhs in [false] {
					for alpha in [1.0.into(), 0.0.into(), 2.5.into()] {
						let alpha: f64 = alpha;
						for m in (1..=48usize).chain([513]) {
							for n in (1..=4usize).chain([512]) {
								for cs in [m.next_multiple_of(48)] {
									let acs = m.next_multiple_of(48);
									// k > 512 forces at least two depth blocks
									let k = 513usize;

									let packed_lhs: &mut [f64] = &mut *avec![0.0.into(); acs * k];
									let packed_rhs: &mut [f64] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
									let lhs: &mut [f64] = &mut *avec![0.0.into(); cs * k];
									let rhs: &mut [f64] = &mut *avec![0.0.into(); n * k];
									let dst: &mut [f64] = &mut *avec![0.0.into(); cs * n];
									let target = &mut *avec![0.0.into(); cs * n];

									rng.fill(lhs);
									rng.fill(rhs);

									// scalar reference: target += alpha * (lhs @ rhs)
									for i in 0..m {
										for j in 0..n {
											let target = &mut target[i + cs * j];
											let mut acc = 0.0.into();
											for depth in 0..k {
												acc = f64::mul_add(lhs[i + cs * depth], rhs[depth + k * j], acc);
											}
											*target = f64::mul_add(acc, alpha, *target);
										}
									}

									unsafe {
										gemm(
											DType::F64,
											IType::U64,
											instr,
											m,
											n,
											k,
											dst.as_mut_ptr() as _,
											1,
											cs as isize,
											// no row/col index tables: dense destination
											null(),
											null(),
											DstKind::Full,
											Accum::Add,
											lhs.as_ptr() as _,
											1,
											cs as isize,
											false,
											// no real-diagonal scaling
											null(),
											0,
											rhs.as_ptr() as _,
											1,
											k as isize,
											false,
											&raw const alpha as _,
											1,
										)
									};
									// blocked accumulation order differs from the
									// reference, so compare with a tolerance
									let mut i = 0;
									for (&target, &dst) in core::iter::zip(&*target, &*dst) {
										if !((target - dst).abs() < 1e-6) {
											dbg!(i / cs, i % cs, target, dst);
											panic!();
										}
										i += 1;
									}
								}
							}
						}
					}
				}
			}
		}
	}
	/// Runs the multi-level blocked `kernel` on a 1023 x 1023 x 5 problem and
	/// checks it against the external `gemm` crate as a reference.
	#[test]
	fn test_avx512_kernel() {
		let m = 1023usize;
		let n = 1023usize;
		let k = 5usize;

		let rng = &mut StdRng::seed_from_u64(0);
		let sizeof = size_of::<f64>() as isize;
		// column stride: at least 4096, aligned to 8 rows
		let cs = m.next_multiple_of(8);
		let cs = Ord::max(4096, cs);

		let lhs: &mut [f64] = &mut *avec![0.0; cs * k];
		let rhs: &mut [f64] = &mut *avec![0.0; k * n];
		let target: &mut [f64] = &mut *avec![0.0; cs * n];

		rng.fill(lhs);
		rng.fill(rhs);

		// reference result computed by the external `gemm` crate
		unsafe {
			gemm::gemm(
				m,
				n,
				k,
				target.as_mut_ptr(),
				cs as isize,
				1,
				true,
				lhs.as_ptr(),
				cs as isize,
				1,
				rhs.as_ptr(),
				k as isize,
				1,
				1.0,
				1.0,
				false,
				false,
				false,
				gemm::Parallelism::None,
			);
		}

		for pack_lhs in [false, true] {
			for pack_rhs in [false] {
				let dst = &mut *avec![0.0; cs * n];
				let packed_lhs = &mut *avec![0.0f64; m.next_multiple_of(8) * k];
				let packed_rhs = &mut *avec![0.0; if pack_rhs { n.next_multiple_of(4) * k } else { 0 }];

				unsafe {
					// three row levels / four column levels of tiling, down to
					// the mr = 48 / nr = 4 register block
					let row_chunk = [48 * 32, 48 * 16, 48];
					let col_chunk = [48 * 64, 48 * 32, 48, 4];

					// per-level byte strides for the plain and packed operands
					let lhs_rs = row_chunk.map(|m| m as isize * sizeof);
					let rhs_cs = col_chunk.map(|n| (n * k) as isize * sizeof);
					let packed_lhs_rs = row_chunk.map(|m| (m * k) as isize * sizeof);
					let packed_rhs_cs = col_chunk.map(|n| (n * k) as isize * sizeof);

					kernel(
						1,
						&F64_SIMD512x4[..24],
						&F64_SIMDpack_512,
						48,
						4,
						size_of::<f64>(),
						lhs.as_ptr() as _,
						if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
						rhs.as_ptr() as _,
						if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
						m,
						n,
						&row_chunk,
						&col_chunk,
						&lhs_rs,
						&rhs_cs,
						&if pack_lhs { packed_lhs_rs } else { lhs_rs },
						&if pack_rhs { packed_rhs_cs } else { rhs_cs },
						0,
						0,
						Position { row: 0, col: 0 },
						&MicrokernelInfo {
							flags: 0,
							depth: k,
							lhs_rs: sizeof,
							lhs_cs: cs as isize * sizeof,
							rhs_rs: sizeof,
							rhs_cs: k as isize * sizeof,
							// alpha = 1.0, passed by reference to a temporary
							alpha: &raw const *&1.0f64 as _,
							ptr: dst.as_mut_ptr() as _,
							rs: sizeof,
							cs: cs as isize * sizeof,
							row_idx: null_mut(),
							col_idx: null_mut(),
							diag_ptr: null(),
							diag_stride: 0,
						},
					);
				}
				// tolerance comparison: blocked accumulation order differs
				let mut i = 0;
				for (&target, &dst) in core::iter::zip(&*target, &*dst) {
					if !((target - dst).abs() < 1e-6) {
						dbg!(i / cs, i % cs, target, dst);
						panic!();
					}
					i += 1;
				}
			}
		}
	}
}
1892
#[cfg(test)]
mod tests_c64 {
	// c64 (double-precision complex) microkernel tests, covering all four
	// combinations of lhs/rhs conjugation against a scalar reference.
	use super::*;

	use aligned_vec::*;
	use bytemuck::*;
	use core::ptr::null_mut;
	use gemm::c64;
	use rand::prelude::*;

	/// Exercises `millikernel_colmajor` with the 24x4 AVX-512 c64 microkernel
	/// for m in 1..=24, n in 1..=4 and 8, padded and unpadded column strides,
	/// and every conjugation combination.
	#[test]
	fn test_avx512_microkernel() {
		let rng = &mut StdRng::seed_from_u64(0);

		let sizeof = size_of::<c64>() as isize;
		// c64 elements per 64-byte cache line
		let len = 64 / size_of::<c64>();

		for pack_lhs in [false, true] {
			for pack_rhs in [false] {
				for alpha in [1.0.into(), 0.0.into(), c64::new(0.0, 3.5), c64::new(2.5, 3.5)] {
					let alpha: c64 = alpha;
					for m in 1..=24usize {
						for n in (1..=4usize).into_iter().chain([8]) {
							// both a tile-padded and a tight column stride
							for cs in [m.next_multiple_of(len), m] {
								for conj_lhs in [false, true] {
									for conj_rhs in [false, true] {
										// the kernel flags encode conj(lhs) and
										// whether lhs/rhs conjugation differs
										let conj_different = conj_lhs != conj_rhs;

										let acs = m.next_multiple_of(len);
										let k = 1usize;

										let packed_lhs: &mut [c64] = &mut *avec![0.0.into(); acs * k];
										let packed_rhs: &mut [c64] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
										let lhs: &mut [c64] = &mut *avec![0.0.into(); cs * k];
										let rhs: &mut [c64] = &mut *avec![0.0.into(); n * k];
										let dst: &mut [c64] = &mut *avec![0.0.into(); cs * n];
										let target: &mut [c64] = &mut *avec![0.0.into(); cs * n];

										// fill the complex buffers through their f64 views
										rng.fill(cast_slice_mut::<c64, f64>(lhs));
										rng.fill(cast_slice_mut::<c64, f64>(rhs));

										// scalar reference: target += alpha * (op(lhs) @ op(rhs))
										for i in 0..m {
											for j in 0..n {
												let target = &mut target[i + cs * j];
												let mut acc: c64 = 0.0.into();
												for depth in 0..k {
													let mut l = lhs[i + cs * depth];
													let mut r = rhs[depth + k * j];
													if conj_lhs {
														l = l.conj();
													}
													if conj_rhs {
														r = r.conj();
													}

													acc = l * r + acc;
												}
												*target = acc * alpha + *target;
											}
										}

										unsafe {
											millikernel_colmajor(
												// index 3: variant for the full nr = 4 columns
												C64_SIMD512x4[3],
												C64_SIMDpack_512[0],
												24,
												4,
												16,
												lhs.as_ptr() as _,
												if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
												rhs.as_ptr() as _,
												if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
												m,
												n,
												&mut MillikernelInfo {
													lhs_rs: 24 * sizeof,
													packed_lhs_rs: 24 * sizeof * k as isize,
													rhs_cs: 4 * sizeof * k as isize,
													packed_rhs_cs: 4 * sizeof * k as isize,
													micro: MicrokernelInfo {
														// bit 1 = FLAGS_CONJ_LHS, bit 2 = FLAGS_CONJ_NEQ
														flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2),
														depth: k,
														lhs_rs: 1 * sizeof,
														lhs_cs: cs as isize * sizeof,
														rhs_rs: 1 * sizeof,
														rhs_cs: k as isize * sizeof,
														alpha: &raw const alpha as _,
														ptr: dst.as_mut_ptr() as _,
														rs: 1 * sizeof,
														cs: cs as isize * sizeof,
														row_idx: null_mut(),
														col_idx: null_mut(),
														diag_ptr: null(),
														diag_stride: 0,
													},
												},
												&mut Position { row: 0, col: 0 },
											)
										};
										// complex compare with tolerance on the modulus
										let mut i = 0;
										for (&target, &dst) in core::iter::zip(&*target, &*dst) {
											if !((target - dst).norm_sqr().sqrt() < 1e-6) {
												dbg!(i / cs, i % cs, target, dst);
												panic!();
											}
											i += 1;
										}
									}
								}
							}
						}
					}
				}
			}
		}
	}
}
2010
#[cfg(test)]
mod tests_f32 {
	// f32 tests: the row-major millikernel path against a scalar reference,
	// and the full blocked `kernel` against the external `gemm` crate.
	use core::ptr::null_mut;

	use super::*;

	use aligned_vec::*;
	use rand::prelude::*;

	/// Exercises `millikernel_rowmajor` (FLAGS_ROWMAJOR set) with the 96x4
	/// AVX-512 f32 microkernel for m in 1..=96, n in 1..=4 and 8, both padded
	/// and tight column strides, with and without LHS packing.
	#[test]
	fn test_avx512_microkernel() {
		let rng = &mut StdRng::seed_from_u64(0);

		let sizeof = size_of::<f32>() as isize;
		// f32 lanes per 64-byte cache line
		let len = 64 / size_of::<f32>();

		for pack_lhs in [false, true] {
			for pack_rhs in [false] {
				for alpha in [1.0.into(), 0.0.into(), 2.5.into()] {
					let alpha: f32 = alpha;
					for m in 1..=96usize {
						for n in (1..=4usize).into_iter().chain([8]) {
							for cs in [m.next_multiple_of(len), m] {
								let acs = m.next_multiple_of(len);
								let k = 1usize;

								let packed_lhs: &mut [f32] = &mut *avec![0.0.into(); acs * k];
								let packed_rhs: &mut [f32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
								let lhs: &mut [f32] = &mut *avec![0.0.into(); cs * k];
								let rhs: &mut [f32] = &mut *avec![0.0.into(); n * k];
								let dst: &mut [f32] = &mut *avec![0.0.into(); cs * n];
								let target = &mut *avec![0.0.into(); cs * n];

								rng.fill(lhs);
								rng.fill(rhs);

								// scalar reference: target += alpha * (lhs @ rhs),
								// via mul_add to match FMA accumulation exactly
								for i in 0..m {
									for j in 0..n {
										let target = &mut target[i + cs * j];
										let mut acc = 0.0.into();
										for depth in 0..k {
											acc = f32::mul_add(lhs[i + cs * depth], rhs[depth + k * j], acc);
										}
										*target = f32::mul_add(acc, alpha, *target);
									}
								}

								unsafe {
									millikernel_rowmajor(
										// index 3: variant for the full nr = 4 columns
										F32_SIMD512x4[3],
										F32_SIMDpack_512[0],
										96,
										4,
										4,
										lhs.as_ptr() as _,
										if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
										rhs.as_ptr() as _,
										if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
										m,
										n,
										&mut MillikernelInfo {
											lhs_rs: 96 * sizeof,
											packed_lhs_rs: 96 * sizeof * k as isize,
											rhs_cs: 4 * sizeof * k as isize,
											packed_rhs_cs: 4 * sizeof * k as isize,
											micro: MicrokernelInfo {
												// bit 63 = FLAGS_ROWMAJOR
												flags: (1 << 63),
												depth: k,
												lhs_rs: 1 * sizeof,
												lhs_cs: cs as isize * sizeof,
												rhs_rs: 1 * sizeof,
												rhs_cs: k as isize * sizeof,
												alpha: &raw const alpha as _,
												ptr: dst.as_mut_ptr() as _,
												rs: 1 * sizeof,
												cs: cs as isize * sizeof,
												row_idx: null_mut(),
												col_idx: null_mut(),
												diag_ptr: null(),
												diag_stride: 0,
											},
										},
										&mut Position { row: 0, col: 0 },
									)
								};
								// exact equality expected (same mul_add accumulation)
								assert_eq!(dst, target);
							}
						}
					}
				}
			}
		}
	}

	/// Runs the multi-level blocked `kernel` on a large 6000 x 2000 x 5 f32
	/// problem and checks it against the external `gemm` crate reference.
	#[test]
	fn test_avx512_kernel() {
		let m = 6000usize;
		let n = 2000usize;
		let k = 5usize;

		let rng = &mut StdRng::seed_from_u64(0);
		let sizeof = size_of::<f32>() as isize;
		// column stride: at least 4096, aligned to 16 rows
		let cs = m.next_multiple_of(16);
		let cs = Ord::max(4096, cs);

		let lhs: &mut [f32] = &mut *avec![0.0; cs * k];
		let rhs: &mut [f32] = &mut *avec![0.0; k * n];
		let target: &mut [f32] = &mut *avec![0.0; cs * n];

		rng.fill(lhs);
		rng.fill(rhs);

		// reference result computed by the external `gemm` crate
		unsafe {
			gemm::gemm(
				m,
				n,
				k,
				target.as_mut_ptr(),
				cs as isize,
				1,
				true,
				lhs.as_ptr(),
				cs as isize,
				1,
				rhs.as_ptr(),
				k as isize,
				1,
				1.0,
				1.0,
				false,
				false,
				false,
				gemm::Parallelism::None,
			);
		}

		for pack_lhs in [false, true] {
			for pack_rhs in [false] {
				let dst = &mut *avec![0.0; cs * n];
				let packed_lhs = &mut *avec![0.0f32; m.next_multiple_of(16) * k];
				let packed_rhs = &mut *avec![0.0; if pack_rhs { n.next_multiple_of(4) * k } else { 0 }];

				unsafe {
					// four row levels / five column levels of tiling, down to
					// the mr = 96 / nr = 4 register block
					let row_chunk = [96 * 32, 96 * 16, 96 * 4, 96];
					let col_chunk = [1024, 256, 64, 16, 4];

					let lhs_rs = row_chunk.map(|m| m as isize * sizeof);
					let rhs_cs = col_chunk.map(|n| (n * k) as isize * sizeof);
					let packed_lhs_rs = row_chunk.map(|m| (m * k) as isize * sizeof);
					let mut packed_rhs_cs = col_chunk.map(|n| (n * k) as isize * sizeof);
					// zero outermost packed stride — presumably so the same
					// packed-RHS region is reused across outer column tiles,
					// mirroring what the `gemm` driver does; verify if changed
					packed_rhs_cs[0] = 0;

					kernel(
						1,
						&F32_SIMD512x4[..24],
						&F32_SIMDpack_512,
						96,
						4,
						size_of::<f32>(),
						lhs.as_ptr() as _,
						if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
						rhs.as_ptr() as _,
						if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
						m,
						n,
						&row_chunk,
						&col_chunk,
						&lhs_rs,
						&rhs_cs,
						&if pack_lhs { packed_lhs_rs } else { lhs_rs },
						&if pack_rhs { packed_rhs_cs } else { rhs_cs },
						0,
						0,
						Position { row: 0, col: 0 },
						&MicrokernelInfo {
							flags: 0,
							depth: k,
							lhs_rs: sizeof,
							lhs_cs: cs as isize * sizeof,
							rhs_rs: sizeof,
							rhs_cs: k as isize * sizeof,
							// alpha = 1.0, passed by reference to a temporary
							alpha: &raw const *&1.0f32 as _,
							ptr: dst.as_mut_ptr() as _,
							rs: sizeof,
							cs: cs as isize * sizeof,
							row_idx: null_mut(),
							col_idx: null_mut(),
							diag_ptr: null(),
							diag_stride: 0,
						},
					)
				}
				// tolerance comparison: blocked accumulation order differs
				let mut i = 0;
				for (&target, &dst) in core::iter::zip(&*target, &*dst) {
					if !((target - dst).abs() < 1e-6) {
						dbg!(i / cs, i % cs, target, dst);
						panic!();
					}
					i += 1;
				}
			}
		}
	}
}
2215
2216#[cfg(test)]
2217mod tests_c32 {
2218	use super::*;
2219
2220	use aligned_vec::*;
2221	use bytemuck::*;
2222	use core::ptr::null_mut;
2223	use gemm::c32;
2224	use rand::prelude::*;
2225
2226	#[test]
2227	fn test_avx512_microkernel() {
2228		let rng = &mut StdRng::seed_from_u64(0);
2229
2230		let sizeof = size_of::<c32>() as isize;
2231		let len = 64 / size_of::<c32>();
2232
2233		for pack_lhs in [false, true] {
2234			for pack_rhs in [false] {
2235				for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
2236					let alpha: c32 = alpha;
2237					for m in 1..=127usize {
2238						for n in (1..=4usize).into_iter().chain([8]) {
2239							for cs in [m.next_multiple_of(len), m] {
2240								for conj_lhs in [false, true] {
2241									for conj_rhs in [false, true] {
2242										for diag_scale in [false, true] {
2243											if diag_scale && !pack_lhs {
2244												continue;
2245											}
2246											let conj_different = conj_lhs != conj_rhs;
2247
2248											let acs = m.next_multiple_of(len);
2249											let k = 1usize;
2250
2251											let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
2252											let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
2253											let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
2254											let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
2255											let dst: &mut [c32] = &mut *avec![0.0.into(); cs * n];
2256											let target: &mut [c32] = &mut *avec![0.0.into(); cs * n];
2257
2258											let diag: &mut [f32] = &mut *avec![0.0.into(); k];
2259
2260											rng.fill(cast_slice_mut::<c32, f32>(lhs));
2261											rng.fill(cast_slice_mut::<c32, f32>(rhs));
2262											rng.fill(diag);
2263
2264											for i in 0..m {
2265												for j in 0..n {
2266													let target = &mut target[i + cs * j];
2267													let mut acc: c32 = 0.0.into();
2268													for depth in 0..k {
2269														let mut l = lhs[i + cs * depth];
2270														let mut r = rhs[depth + k * j];
2271														let d = diag[depth];
2272
2273														if conj_lhs {
2274															l = l.conj();
2275														}
2276														if conj_rhs {
2277															r = r.conj();
2278														}
2279
2280														if diag_scale {
2281															acc += d * l * r;
2282														} else {
2283															acc += l * r;
2284														}
2285													}
2286													*target = acc * alpha + *target;
2287												}
2288											}
2289
2290											unsafe {
2291												millikernel_colmajor(
2292													C32_SIMD512x4[3],
2293													C32_SIMDpack_512[0],
2294													48,
2295													4,
2296													8,
2297													lhs.as_ptr() as _,
2298													if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
2299													rhs.as_ptr() as _,
2300													if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
2301													m,
2302													n,
2303													&mut MillikernelInfo {
2304														lhs_rs: 48 * sizeof,
2305														packed_lhs_rs: 48 * sizeof * k as isize,
2306														rhs_cs: 4 * sizeof * k as isize,
2307														packed_rhs_cs: 4 * sizeof * k as isize,
2308														micro: MicrokernelInfo {
2309															flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2),
2310															depth: k,
2311															lhs_rs: 1 * sizeof,
2312															lhs_cs: cs as isize * sizeof,
2313															rhs_rs: 1 * sizeof,
2314															rhs_cs: k as isize * sizeof,
2315															alpha: &raw const alpha as _,
2316															ptr: dst.as_mut_ptr() as _,
2317															rs: 1 * sizeof,
2318															cs: cs as isize * sizeof,
2319															row_idx: null_mut(),
2320															col_idx: null_mut(),
2321															diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
2322															diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
2323														},
2324													},
2325													&mut Position { row: 0, col: 0 },
2326												)
2327											};
2328											let mut i = 0;
2329											for (&target, &dst) in core::iter::zip(&*target, &*dst) {
2330												if !((target - dst).norm_sqr().sqrt() < 1e-4) {
2331													dbg!(i / cs, i % cs, target, dst);
2332													panic!();
2333												}
2334												i += 1;
2335											}
2336										}
2337									}
2338								}
2339							}
2340						}
2341					}
2342				}
2343			}
2344		}
2345	}
2346}
2347
/// Tests for the c32 millikernel with the LOWER flag set (bit 3,
/// `FLAGS_LOWER`): only the lower triangle (i >= j) of the destination
/// is expected to be updated by the kernel, and the scalar reference
/// below accumulates only that triangle.
#[cfg(test)]
mod tests_c32_lower {
	use super::*;

	use aligned_vec::*;
	use bytemuck::*;
	use core::ptr::null_mut;
	use gemm::c32;
	use rand::prelude::*;

	#[test]
	fn test_avx512_microkernel() {
		let rng = &mut StdRng::seed_from_u64(0);

		let sizeof = size_of::<c32>() as isize;
		// number of c32 elements in one 512-bit vector (64 bytes / 8 bytes = 8)
		let len = 64 / size_of::<c32>();

		for pack_lhs in [false, true] {
			for pack_rhs in [false] {
				for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
					let alpha: c32 = alpha;
					for m in 1..=127usize {
						for n in (1..=4usize).chain([8, 32]) {
							// destination column stride: both tight (cs == m) and padded to a full vector
							for cs in [m, m.next_multiple_of(len)] {
								for conj_lhs in [false, true] {
									for conj_rhs in [false, true] {
										for diag_scale in [false, true] {
											// diagonal scaling is only exercised through the packed-lhs path
											if diag_scale && !pack_lhs {
												continue;
											}
											let conj_different = conj_lhs != conj_rhs;

											// packed lhs stride, rounded up to a whole vector
											let acs = m.next_multiple_of(len);
											let k = 1usize;

											let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
											let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
											let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
											let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
											let dst: &mut [c32] = &mut *avec![0.0.into(); cs * n];
											let target: &mut [c32] = &mut *avec![0.0.into(); cs * n];

											// real-valued per-depth scaling factors
											let diag: &mut [f32] = &mut *avec![0.0.into(); k];

											rng.fill(cast_slice_mut::<c32, f32>(lhs));
											rng.fill(cast_slice_mut::<c32, f32>(rhs));
											rng.fill(diag);

											// scalar reference: accumulate only the lower triangle (i >= j)
											for i in 0..m {
												for j in 0..n {
													if i < j {
														continue;
													}
													let target = &mut target[i + cs * j];
													let mut acc: c32 = 0.0.into();
													for depth in 0..k {
														let mut l = lhs[i + cs * depth];
														let mut r = rhs[depth + k * j];
														let d = diag[depth];

														if conj_lhs {
															l = l.conj();
														}
														if conj_rhs {
															r = r.conj();
														}

														if diag_scale {
															acc += d * l * r;
														} else {
															acc += l * r;
														}
													}
													*target = acc * alpha + *target;
												}
											}

											unsafe {
												millikernel_colmajor(
													C32_SIMD512x4[3],
													C32_SIMDpack_512[0],
													// NOTE(review): 48/4/8 presumably mr (row block),
													// nr (column block) and vector length — confirm
													// against the asm table generator
													48,
													4,
													8,
													lhs.as_ptr() as _,
													if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
													rhs.as_ptr() as _,
													if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
													m,
													n,
													&mut MillikernelInfo {
														lhs_rs: 48 * sizeof,
														packed_lhs_rs: 48 * sizeof * k as isize,
														rhs_cs: 4 * sizeof * k as isize,
														packed_rhs_cs: 4 * sizeof * k as isize,
														micro: MicrokernelInfo {
															// bit 1: FLAGS_CONJ_LHS, bit 2: FLAGS_CONJ_NEQ, bit 3: FLAGS_LOWER
															flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2) | (1 << 3),
															depth: k,
															// all strides below are in bytes
															lhs_rs: 1 * sizeof,
															lhs_cs: cs as isize * sizeof,
															rhs_rs: 1 * sizeof,
															rhs_cs: k as isize * sizeof,
															alpha: &raw const alpha as _,
															ptr: dst.as_mut_ptr() as _,
															rs: 1 * sizeof,
															cs: cs as isize * sizeof,
															row_idx: null_mut(),
															col_idx: null_mut(),
															diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
															diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
														},
													},
													&mut Position { row: 0, col: 0 },
												)
											};
											// compare kernel output against the scalar reference
											// (absolute tolerance; inputs are O(1) and k == 1)
											let mut i = 0;
											for (&target, &dst) in core::iter::zip(&*target, &*dst) {
												if !((target - dst).norm_sqr().sqrt() < 1e-4) {
													dbg!(i / cs, i % cs, target, dst);
													panic!();
												}
												i += 1;
											}
										}
									}
								}
							}
						}
					}
				}
			}
		}
	}

	#[test]
	fn test_avx256microkernel() {
		let rng = &mut StdRng::seed_from_u64(0);

		let sizeof = size_of::<c32>() as isize;
		// NOTE(review): computed from 64 bytes (512-bit) even though this test
		// drives the 256-bit kernel (mr = 12); the rounded-up strides are merely
		// larger than strictly needed, which looks harmless — confirm intent
		let len = 64 / size_of::<c32>();

		for pack_lhs in [false, true] {
			for pack_rhs in [false] {
				for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
					let alpha: c32 = alpha;
					for m in 1..=127usize {
						for n in (1..=4usize).chain([8, 32]) {
							for cs in [m, m.next_multiple_of(len)] {
								for conj_lhs in [false, true] {
									for conj_rhs in [false, true] {
										for diag_scale in [false, true] {
											// diagonal scaling is only exercised through the packed-lhs path
											if diag_scale && !pack_lhs {
												continue;
											}

											let conj_different = conj_lhs != conj_rhs;

											let acs = m.next_multiple_of(len);
											let k = 1usize;

											let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
											let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
											let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
											let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
											let dst: &mut [c32] = &mut *avec![0.0.into(); cs * n];
											let target: &mut [c32] = &mut *avec![0.0.into(); cs * n];

											let diag: &mut [f32] = &mut *avec![0.0.into(); k];

											rng.fill(cast_slice_mut::<c32, f32>(lhs));
											rng.fill(cast_slice_mut::<c32, f32>(rhs));
											rng.fill(diag);

											// scalar reference: lower triangle only (i >= j)
											for i in 0..m {
												for j in 0..n {
													if i < j {
														continue;
													}
													let target = &mut target[i + cs * j];
													let mut acc: c32 = 0.0.into();
													for depth in 0..k {
														let mut l = lhs[i + cs * depth];
														let mut r = rhs[depth + k * j];
														let d = diag[depth];

														if conj_lhs {
															l = l.conj();
														}
														if conj_rhs {
															r = r.conj();
														}

														if diag_scale {
															acc += d * l * r;
														} else {
															acc += l * r;
														}
													}
													*target = acc * alpha + *target;
												}
											}

											unsafe {
												millikernel_colmajor(
													// 256-bit kernel variants; row block is 12 instead of 48
													C32_SIMD256[3],
													C32_SIMDpack_256[0],
													12,
													4,
													8,
													lhs.as_ptr() as _,
													if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
													rhs.as_ptr() as _,
													if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
													m,
													n,
													&mut MillikernelInfo {
														lhs_rs: 12 * sizeof,
														packed_lhs_rs: 12 * sizeof * k as isize,
														rhs_cs: 4 * sizeof * k as isize,
														packed_rhs_cs: 4 * sizeof * k as isize,
														micro: MicrokernelInfo {
															// bit 1: FLAGS_CONJ_LHS, bit 2: FLAGS_CONJ_NEQ, bit 3: FLAGS_LOWER
															flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2) | (1 << 3),
															depth: k,
															lhs_rs: 1 * sizeof,
															lhs_cs: cs as isize * sizeof,
															rhs_rs: 1 * sizeof,
															rhs_cs: k as isize * sizeof,
															alpha: &raw const alpha as _,
															ptr: dst.as_mut_ptr() as _,
															rs: 1 * sizeof,
															cs: cs as isize * sizeof,
															row_idx: null_mut(),
															col_idx: null_mut(),
															diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
															diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
														},
													},
													&mut Position { row: 0, col: 0 },
												)
											};
											let mut i = 0;
											for (&target, &dst) in core::iter::zip(&*target, &*dst) {
												if !((target - dst).norm_sqr().sqrt() < 1e-4) {
													dbg!(i / cs, i % cs, target, dst);
													panic!();
												}
												i += 1;
											}
										}
									}
								}
							}
						}
					}
				}
			}
		}
	}
}
2607
/// Tests for the high-level `gemm` entry point with `Accum::Replace`.
///
/// NOTE(review): the module name says `c32` but the element type exercised
/// here is `c64`; the name is kept to avoid churning test paths — consider
/// renaming in a follow-up.
#[cfg(test)]
mod tests_c32_lower_add {
	use super::*;

	use aligned_vec::*;
	use bytemuck::*;
	use gemm::c64;
	use rand::prelude::*;

	/// Row-major lhs (row stride `k`), column-major rhs/dst, compared against
	/// a scalar reference for every conjugation / diagonal-scaling combination.
	#[test]
	fn test_avx512_microkernel_rowmajor() {
		let rng = &mut StdRng::seed_from_u64(0);

		// number of c64 elements in one 512-bit vector (64 bytes / 16 bytes = 4)
		let len = 64 / size_of::<c64>();

		for alpha in [1.0.into(), 0.0.into(), c64::new(0.0, 3.5), c64::new(2.5, 3.5)] {
			let alpha: c64 = alpha;
			// This used to be `for m in 1..=127` / `for n in (1..=4).chain(...)` with
			// the loop variables immediately shadowed by `m = 4005` / `n = 2`, which
			// re-ran the same very expensive 4005x2x4005 case hundreds of times (and
			// allocated unused ~250 MB pack buffers each time). Iterate the intended
			// single case directly, matching the colmajor test below.
			for m in [4005usize] {
				for n in [2usize] {
					// destination column stride: tight and vector-padded
					for cs in [m, m.next_multiple_of(len)] {
						for conj_lhs in [false, true] {
							for conj_rhs in [false, true] {
								for diag_scale in [true, false] {
									let k = 4005usize;
									dbg!(m, n, k, diag_scale, conj_lhs, conj_rhs);

									let lhs: &mut [c64] = &mut *avec![0.0.into(); m * k];
									let rhs: &mut [c64] = &mut *avec![0.0.into(); n * k];
									let dst: &mut [c64] = &mut *avec![0.0.into(); cs * n];
									rng.fill(cast_slice_mut::<c64, f64>(dst));

									// snapshot of dst before the call; the reference result is
									// written over it (Accum::Replace ignores the old contents)
									let target0: &mut [c64] = &mut *dst.to_vec();

									let diag: &mut [c64] = &mut *avec![0.0.into(); k];

									rng.fill(cast_slice_mut::<c64, f64>(lhs));
									rng.fill(cast_slice_mut::<c64, f64>(rhs));

									// the diagonal is real-valued
									for x in &mut *diag {
										x.re = rng.random();
									}

									// scalar reference (lhs is row-major: element (i, depth) at i * k + depth)
									for i in 0..m {
										for j in 0..n {
											let target = &mut target0[i + cs * j];
											let mut acc: c64 = 0.0.into();
											for depth in 0..k {
												let mut l = lhs[i * k + depth];
												let mut r = rhs[depth + k * j];
												let d = diag[depth];

												if conj_lhs {
													l = l.conj();
												}
												if conj_rhs {
													r = r.conj();
												}

												if diag_scale {
													acc += d * l * r;
												} else {
													acc += l * r;
												}
											}
											// Accum::Replace: overwrite, do not accumulate
											*target = acc * alpha;
										}
									}

									unsafe {
										gemm(
											DType::C64,
											IType::U64,
											InstrSet::Avx512,
											m,
											n,
											k,
											dst.as_mut_ptr() as _,
											1,
											cs as isize,
											null(),
											null(),
											DstKind::Full,
											Accum::Replace,
											lhs.as_ptr() as _,
											// row-major lhs: row stride k, column stride 1
											k as isize,
											1,
											conj_lhs,
											if diag_scale { diag.as_ptr() as _ } else { null() },
											if diag_scale { 1 } else { 0 },
											rhs.as_ptr() as _,
											1,
											k as isize,
											conj_rhs,
											&raw const alpha as _,
											1,
										)
									};

									let mut i = 0;
									for (&target, &dst) in core::iter::zip(&*target0, &*dst) {
										if !((target - dst).norm_sqr().sqrt() < 1e-4) {
											dbg!(i / cs, i % cs, target, dst);
											panic!();
										}
										i += 1;
									}
								}
							}
						}
					}
				}
			}
		}
	}

	/// Same case as above, but with a column-major lhs and the rhs stored with
	/// column stride `cs`.
	#[test]
	fn test_avx512_microkernel_colmajor() {
		let rng = &mut StdRng::seed_from_u64(0);

		for alpha in [1.0.into(), 0.0.into(), c64::new(0.0, 3.5), c64::new(2.5, 3.5)] {
			let alpha: c64 = alpha;
			for m in [4005usize] {
				for n in [2usize] {
					for cs in [4008] {
						for conj_lhs in [false, true] {
							for conj_rhs in [false, true] {
								for diag_scale in [true, false] {
									let k = 4005usize;
									dbg!(m, n, k, diag_scale, conj_lhs, conj_rhs);

									let lhs: &mut [c64] = &mut *avec![0.0.into(); cs * k];
									// rhs uses column stride cs (>= k), hence n * cs elements
									let rhs: &mut [c64] = &mut *avec![0.0.into(); n * cs];
									let dst: &mut [c64] = &mut *avec![0.0.into(); cs * n];
									rng.fill(cast_slice_mut::<c64, f64>(dst));

									// reference result is written over a snapshot of dst
									let target0: &mut [c64] = &mut *dst.to_vec();

									let diag: &mut [c64] = &mut *avec![0.0.into(); k];

									rng.fill(cast_slice_mut::<c64, f64>(lhs));
									rng.fill(cast_slice_mut::<c64, f64>(rhs));

									// the diagonal is real-valued
									for x in &mut *diag {
										x.re = rng.random();
									}

									// scalar reference (both operands column-major with stride cs)
									for i in 0..m {
										for j in 0..n {
											let target = &mut target0[i + cs * j];
											let mut acc: c64 = 0.0.into();
											for depth in 0..k {
												let mut l = lhs[i + cs * depth];
												let mut r = rhs[depth + cs * j];
												let d = diag[depth];

												if conj_lhs {
													l = l.conj();
												}
												if conj_rhs {
													r = r.conj();
												}

												if diag_scale {
													acc += d * l * r;
												} else {
													acc += l * r;
												}
											}
											*target = acc * alpha;
										}
									}

									unsafe {
										gemm(
											DType::C64,
											IType::U64,
											InstrSet::Avx512,
											m,
											n,
											k,
											dst.as_mut_ptr() as _,
											1,
											cs as isize,
											null(),
											null(),
											DstKind::Full,
											Accum::Replace,
											lhs.as_ptr() as _,
											1,
											cs as isize,
											conj_lhs,
											if diag_scale { diag.as_ptr() as _ } else { null() },
											if diag_scale { 1 } else { 0 },
											rhs.as_ptr() as _,
											1,
											cs as isize,
											conj_rhs,
											&raw const alpha as _,
											// NOTE(review): differs from the rowmajor test (2 vs 1) —
											// presumably a parallelism/thread count; confirm intent
											2,
										)
									};

									let mut i = 0;
									for (&target, &dst) in core::iter::zip(&*target0, &*dst) {
										// NOTE(review): much tighter tolerance than the c32 tests
										// (1e-8 vs 1e-4); f64 accumulation over k = 4005 evidently
										// stays within it — confirm this bound is intended
										if !((target - dst).norm_sqr().sqrt() < 1e-8) {
											dbg!(i / cs, i % cs, target, dst);
											panic!();
										}
										i += 1;
									}
								}
							}
						}
					}
				}
			}
		}
	}
}
2843
/// Tests for the c32 millikernel with the UPPER flag set (bit 4,
/// `FLAGS_UPPER`): only the upper triangle (i <= j) of the destination
/// is expected to be updated; the scalar reference below accumulates
/// only that triangle.
#[cfg(test)]
mod tests_c32_upper {
	use super::*;

	use aligned_vec::*;
	use bytemuck::*;
	use core::ptr::null_mut;
	use gemm::c32;
	use rand::prelude::*;

	#[test]
	fn test_avx512_microkernel() {
		let rng = &mut StdRng::seed_from_u64(0);

		let sizeof = size_of::<c32>() as isize;
		// number of c32 elements in one 512-bit vector (64 bytes / 8 bytes = 8)
		let len = 64 / size_of::<c32>();

		for pack_lhs in [false, true] {
			for pack_rhs in [false] {
				for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
					let alpha: c32 = alpha;
					for m in 1..=127usize {
						// NOTE(review): n == 8 is tested first and again last —
						// presumably deliberate ordering; confirm it is not a typo
						for n in [8].into_iter().chain(1..=4usize).chain([8]) {
							// destination column stride: tight and vector-padded
							for cs in [m, m.next_multiple_of(len)] {
								for conj_lhs in [false, true] {
									for conj_rhs in [false, true] {
										for diag_scale in [false, true] {
											// diagonal scaling is only exercised through the packed-lhs path
											if diag_scale && !pack_lhs {
												continue;
											}
											let conj_different = conj_lhs != conj_rhs;

											let acs = m.next_multiple_of(len);
											let k = 1usize;

											let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
											let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
											let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
											let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
											let dst: &mut [c32] = &mut *avec![0.0.into(); cs * n];
											let target: &mut [c32] = &mut *avec![0.0.into(); cs * n];

											// real-valued per-depth scaling factors
											let diag: &mut [f32] = &mut *avec![0.0.into(); k];

											rng.fill(cast_slice_mut::<c32, f32>(lhs));
											rng.fill(cast_slice_mut::<c32, f32>(rhs));
											rng.fill(diag);

											// scalar reference: accumulate only the upper triangle (i <= j)
											for i in 0..m {
												for j in 0..n {
													if i > j {
														continue;
													}
													let target = &mut target[i + cs * j];
													let mut acc: c32 = 0.0.into();
													for depth in 0..k {
														let mut l = lhs[i + cs * depth];
														let mut r = rhs[depth + k * j];
														let d = diag[depth];

														if conj_lhs {
															l = l.conj();
														}
														if conj_rhs {
															r = r.conj();
														}

														if diag_scale {
															acc += d * l * r;
														} else {
															acc += l * r;
														}
													}
													*target = acc * alpha + *target;
												}
											}

											unsafe {
												millikernel_colmajor(
													C32_SIMD512x4[3],
													C32_SIMDpack_512[0],
													48,
													4,
													8,
													lhs.as_ptr() as _,
													if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
													rhs.as_ptr() as _,
													if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
													m,
													n,
													&mut MillikernelInfo {
														lhs_rs: 48 * sizeof,
														packed_lhs_rs: 48 * sizeof * k as isize,
														rhs_cs: 4 * sizeof * k as isize,
														packed_rhs_cs: 4 * sizeof * k as isize,
														micro: MicrokernelInfo {
															// bit 1: FLAGS_CONJ_LHS, bit 2: FLAGS_CONJ_NEQ, bit 4: FLAGS_UPPER
															flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2) | (1 << 4),
															depth: k,
															// all strides below are in bytes
															lhs_rs: 1 * sizeof,
															lhs_cs: cs as isize * sizeof,
															rhs_rs: 1 * sizeof,
															rhs_cs: k as isize * sizeof,
															alpha: &raw const alpha as _,
															ptr: dst.as_mut_ptr() as _,
															rs: 1 * sizeof,
															cs: cs as isize * sizeof,
															row_idx: null_mut(),
															col_idx: null_mut(),
															diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
															diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
														},
													},
													&mut Position { row: 0, col: 0 },
												)
											};
											// compare kernel output against the scalar reference
											let mut i = 0;
											for (&target, &dst) in core::iter::zip(&*target, &*dst) {
												if !((target - dst).norm_sqr().sqrt() < 1e-4) {
													dbg!(i / cs, i % cs, target, dst);
													panic!();
												}
												i += 1;
											}
										}
									}
								}
							}
						}
					}
				}
			}
		}
	}
}
2978
/// Tests for the asm packing (transpose) routines: each routine reads a
/// row-major `m x n` matrix and writes it out column-major with the row
/// count rounded up to the routine's packed stride.
///
/// Register contract used by all three tests (visible from the `asm!`
/// operands): r10 = routine address, rax = source, r15 = destination,
/// r8 = row count, rsi = pointer to `MicrokernelInfo`. NOTE(review): no
/// clobber list is declared, so the called routine must preserve every
/// register it touches — confirm against the asm generator.
#[cfg(test)]
mod transpose_tests {
	use super::*;
	use aligned_vec::avec;
	use rand::prelude::*;

	/// 128-bit elements (c64-sized), packed via `C64_SIMDpack_512`.
	#[test]
	fn test_b128() {
		let rng = &mut StdRng::seed_from_u64(0);

		for m in 1..=24 {
			let n = 127;

			let src = &mut *avec![0u128; m * n];
			// NOTE(review): allocated with the row count rounded to 8, but the
			// check below reads with a stride rounded to 4 — the extra space
			// appears to be harmless slack; confirm the intended packed stride
			let dst = &mut *avec![0u128; m.next_multiple_of(8) * n];

			rng.fill(src);
			rng.fill(dst);

			// one packing variant per residual row-block size (m in 1..=24 -> index 0..=5)
			let ptr = C64_SIMDpack_512[(24 - m) / 4];
			// only the lhs fields matter to the packing routine;
			// src is row-major: row stride n elements, column stride 1 element (in bytes)
			let info = MicrokernelInfo {
				flags: 0,
				depth: n,
				lhs_rs: (n * size_of::<u128>()) as isize,
				lhs_cs: size_of::<u128>() as isize,
				rhs_rs: 0,
				rhs_cs: 0,
				alpha: null(),
				ptr: null_mut(),
				rs: 0,
				cs: 0,
				row_idx: null(),
				col_idx: null(),
				diag_ptr: null(),
				diag_stride: 0,
			};

			unsafe {
				core::arch::asm! {"
                call r10
                ",
					in("r10") ptr,
					in("rax") src.as_ptr(),
					in("r15") dst.as_mut_ptr(),
					in("r8") m,
					in("rsi") &info,
				};
			}

			// dst must now hold src transposed into column-major layout
			for j in 0..n {
				for i in 0..m {
					assert_eq!(src[i * n + j], dst[i + m.next_multiple_of(4) * j]);
				}
			}
		}
	}

	/// 64-bit elements (f64-sized), packed via `F64_SIMDpack_512`;
	/// packed stride is m rounded up to 8.
	#[test]
	fn test_b64() {
		let rng = &mut StdRng::seed_from_u64(0);

		for m in 1..=48 {
			let n = 127;

			let src = &mut *avec![0u64; m * n];
			let dst = &mut *avec![0u64; m.next_multiple_of(8) * n];

			rng.fill(src);
			rng.fill(dst);

			// one packing variant per residual row-block size (m in 1..=48 -> index 0..=5)
			let ptr = F64_SIMDpack_512[(48 - m) / 8];
			let info = MicrokernelInfo {
				flags: 0,
				depth: n,
				lhs_rs: (n * size_of::<u64>()) as isize,
				lhs_cs: size_of::<u64>() as isize,
				rhs_rs: 0,
				rhs_cs: 0,
				alpha: null(),
				ptr: null_mut(),
				rs: 0,
				cs: 0,
				row_idx: null(),
				col_idx: null(),
				diag_ptr: null(),
				diag_stride: 0,
			};

			unsafe {
				core::arch::asm! {"
                call r10
                ",
					in("r10") ptr,
					in("rax") src.as_ptr(),
					in("r15") dst.as_mut_ptr(),
					in("r8") m,
					in("rsi") &info,
				};
			}

			for j in 0..n {
				for i in 0..m {
					assert_eq!(src[i * n + j], dst[i + m.next_multiple_of(8) * j]);
				}
			}
		}
	}

	/// 32-bit elements (f32-sized), packed via `F32_SIMDpack_512`;
	/// packed stride is m rounded up to 16.
	#[test]
	fn test_b32() {
		let rng = &mut StdRng::seed_from_u64(0);

		for m in 1..=96 {
			let n = 127;

			let src = &mut *avec![0u32; m * n];
			let dst = &mut *avec![0u32; m.next_multiple_of(16) * n];

			rng.fill(src);
			rng.fill(dst);

			// one packing variant per residual row-block size (m in 1..=96 -> index 0..=5)
			let ptr = F32_SIMDpack_512[(96 - m) / 16];
			let info = MicrokernelInfo {
				flags: 0,
				depth: n,
				lhs_rs: (n * size_of::<f32>()) as isize,
				lhs_cs: size_of::<f32>() as isize,
				rhs_rs: 0,
				rhs_cs: 0,
				alpha: null(),
				ptr: null_mut(),
				rs: 0,
				cs: 0,
				row_idx: null(),
				col_idx: null(),
				diag_ptr: null(),
				diag_stride: 0,
			};

			unsafe {
				core::arch::asm! {"
                call r10
                ",
					in("r10") ptr,
					in("rax") src.as_ptr(),
					in("r15") dst.as_mut_ptr(),
					in("r8") m,
					in("rsi") &info,
				};
			}

			for j in 0..n {
				for i in 0..m {
					assert_eq!(src[i * n + j], dst[i + m.next_multiple_of(16) * j]);
				}
			}
		}
	}
}
3138
3139#[cfg(test)]
3140mod tests_c32_gather_scatter {
3141	use super::*;
3142
3143	use aligned_vec::*;
3144	use bytemuck::*;
3145	use core::ptr::null_mut;
3146	use gemm::c32;
3147	use rand::prelude::*;
3148
3149	#[test]
3150	fn test_avx512_microkernel() {
3151		let rng = &mut StdRng::seed_from_u64(0);
3152
3153		let sizeof = size_of::<c32>() as isize;
3154		let len = 64 / size_of::<c32>();
3155
3156		for pack_lhs in [false, true] {
3157			for pack_rhs in [false] {
3158				for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
3159					let alpha: c32 = alpha;
3160					for m in 1..=127usize {
3161						for n in [8].into_iter().chain(1..=4usize).chain([8]) {
3162							for cs in [m, m.next_multiple_of(len)] {
3163								for conj_lhs in [false, true] {
3164									for conj_rhs in [false, true] {
3165										for diag_scale in [false, true] {
3166											if diag_scale && !pack_lhs {
3167												continue;
3168											}
3169
3170											let m = 2usize;
3171											let cs = m;
3172											let conj_different = conj_lhs != conj_rhs;
3173
3174											let acs = m.next_multiple_of(len);
3175											let k = 1usize;
3176
3177											let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
3178											let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
3179											let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
3180											let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
3181											let dst: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];
3182											let target: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];
3183
3184											let diag: &mut [f32] = &mut *avec![0.0.into(); k];
3185
3186											rng.fill(cast_slice_mut::<c32, f32>(lhs));
3187											rng.fill(cast_slice_mut::<c32, f32>(rhs));
3188											rng.fill(diag);
3189
3190											for i in 0..m {
3191												for j in 0..n {
3192													if i > j {
3193														continue;
3194													}
3195													let target = &mut target[2 * (i + cs * j)];
3196													let mut acc: c32 = 0.0.into();
3197													for depth in 0..k {
3198														let mut l = lhs[i + cs * depth];
3199														let mut r = rhs[depth + k * j];
3200														let d = diag[depth];
3201
3202														if conj_lhs {
3203															l = l.conj();
3204														}
3205														if conj_rhs {
3206															r = r.conj();
3207														}
3208
3209														if diag_scale {
3210															acc += d * l * r;
3211														} else {
3212															acc += l * r;
3213														}
3214													}
3215													*target = acc * alpha + *target;
3216												}
3217											}
3218
3219											unsafe {
3220												millikernel_colmajor(
3221													C32_SIMD512x4[3],
3222													C32_SIMDpack_512[0],
3223													48,
3224													4,
3225													8,
3226													lhs.as_ptr() as _,
3227													if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
3228													rhs.as_ptr() as _,
3229													if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
3230													m,
3231													n,
3232													&mut MillikernelInfo {
3233														lhs_rs: 48 * sizeof,
3234														packed_lhs_rs: 48 * sizeof * k as isize,
3235														rhs_cs: 4 * sizeof * k as isize,
3236														packed_rhs_cs: 4 * sizeof * k as isize,
3237														micro: MicrokernelInfo {
3238															flags: ((conj_lhs as usize) << 1) | ((conj_different as usize) << 2) | (1 << 4),
3239															depth: k,
3240															lhs_rs: 1 * sizeof,
3241															lhs_cs: cs as isize * sizeof,
3242															rhs_rs: 1 * sizeof,
3243															rhs_cs: k as isize * sizeof,
3244															alpha: &raw const alpha as _,
3245															ptr: dst.as_mut_ptr() as _,
3246															rs: 2 * sizeof,
3247															cs: 2 * cs as isize * sizeof,
3248															row_idx: null_mut(),
3249															col_idx: null_mut(),
3250															diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
3251															diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
3252														},
3253													},
3254													&mut Position { row: 0, col: 0 },
3255												)
3256											};
3257											let mut i = 0;
3258											for (&target, &dst) in core::iter::zip(&*target, &*dst) {
3259												if !((target - dst).norm_sqr().sqrt() < 1e-4) {
3260													dbg!(i / cs, i % cs, target, dst);
3261													panic!();
3262												}
3263												i += 1;
3264											}
3265										}
3266									}
3267								}
3268							}
3269						}
3270					}
3271				}
3272			}
3273		}
3274	}
3275
3276	#[test]
3277	fn test_avx512_microkernel2() {
3278		let rng = &mut StdRng::seed_from_u64(0);
3279
3280		let sizeof = size_of::<c32>() as isize;
3281		let len = 64 / size_of::<c32>();
3282
3283		for pack_lhs in [false, true] {
3284			for pack_rhs in [false] {
3285				for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
3286					let alpha: c32 = alpha;
3287					for m in 1..=127usize {
3288						for n in [8].into_iter().chain(1..=4usize).chain([8]) {
3289							for cs in [m, m.next_multiple_of(len)] {
3290								for conj_lhs in [false, true] {
3291									for conj_rhs in [false, true] {
3292										for diag_scale in [false, true] {
3293											if diag_scale && !pack_lhs {
3294												continue;
3295											}
3296											let m = 2usize;
3297											let cs = m;
3298											let conj_different = conj_lhs != conj_rhs;
3299											let idx = (0..Ord::max(m, n)).map(|i| 2 * i as u32).collect::<Vec<_>>();
3300
3301											let acs = m.next_multiple_of(len);
3302											let k = 1usize;
3303
3304											let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
3305											let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
3306											let lhs: &mut [c32] = &mut *avec![0.0.into(); cs * k];
3307											let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
3308											let dst: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];
3309											let target: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];
3310
3311											let diag: &mut [f32] = &mut *avec![0.0.into(); k];
3312
3313											rng.fill(cast_slice_mut::<c32, f32>(lhs));
3314											rng.fill(cast_slice_mut::<c32, f32>(rhs));
3315											rng.fill(diag);
3316
3317											for i in 0..m {
3318												for j in 0..n {
3319													if i > j {
3320														continue;
3321													}
3322													let target = &mut target[2 * (i + cs * j)];
3323													let mut acc: c32 = 0.0.into();
3324													for depth in 0..k {
3325														let mut l = lhs[i + cs * depth];
3326														let mut r = rhs[depth + k * j];
3327														let d = diag[depth];
3328
3329														if conj_lhs {
3330															l = l.conj();
3331														}
3332														if conj_rhs {
3333															r = r.conj();
3334														}
3335
3336														if diag_scale {
3337															acc += d * l * r;
3338														} else {
3339															acc += l * r;
3340														}
3341													}
3342													*target = acc * alpha + *target;
3343												}
3344											}
3345
3346											unsafe {
3347												millikernel_colmajor(
3348													C32_SIMD512x4[3],
3349													C32_SIMDpack_512[0],
3350													48,
3351													4,
3352													8,
3353													lhs.as_ptr() as _,
3354													if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
3355													rhs.as_ptr() as _,
3356													if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
3357													m,
3358													n,
3359													&mut MillikernelInfo {
3360														lhs_rs: 48 * sizeof,
3361														packed_lhs_rs: 48 * sizeof * k as isize,
3362														rhs_cs: 4 * sizeof * k as isize,
3363														packed_rhs_cs: 4 * sizeof * k as isize,
3364														micro: MicrokernelInfo {
3365															flags: ((conj_lhs as usize) << 1)
3366																| ((conj_different as usize) << 2) | (1 << 4) | (1 << 5),
3367															depth: k,
3368															lhs_rs: 1 * sizeof,
3369															lhs_cs: cs as isize * sizeof,
3370															rhs_rs: 1 * sizeof,
3371															rhs_cs: k as isize * sizeof,
3372															alpha: &raw const alpha as _,
3373															ptr: dst.as_mut_ptr() as _,
3374															rs: sizeof,
3375															cs: cs as isize * sizeof,
3376															row_idx: idx.as_ptr() as _,
3377															col_idx: idx.as_ptr() as _,
3378															diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
3379															diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
3380														},
3381													},
3382													&mut Position { row: 0, col: 0 },
3383												)
3384											};
3385											let mut i = 0;
3386											for (&target, &dst) in core::iter::zip(&*target, &*dst) {
3387												if !((target - dst).norm_sqr().sqrt() < 1e-4) {
3388													dbg!(i / cs, i % cs, target, dst);
3389													panic!();
3390												}
3391												i += 1;
3392											}
3393										}
3394									}
3395								}
3396							}
3397						}
3398					}
3399				}
3400			}
3401		}
3402	}
3403
	#[test]
	fn test_avx512_microkernel3() {
		// Like test_avx512_microkernel2, but additionally sets FLAGS_CPLX and feeds
		// the lhs with doubled strides (lhs_rs = 2*sizeof, lhs_cs = 2*cs*sizeof),
		// i.e. only every other lhs element participates; the scalar reference
		// below reads lhs[2 * (i + cs * depth)] to match. Only the packed-lhs
		// path is exercised here.
		let rng = &mut StdRng::seed_from_u64(0);

		let sizeof = size_of::<c32>() as isize;
		// Number of c32 elements per 64-byte SIMD register.
		let len = 64 / size_of::<c32>();

		for pack_lhs in [true] {
			for pack_rhs in [false] {
				for alpha in [1.0.into(), 0.0.into(), c32::new(0.0, 3.5), c32::new(2.5, 3.5)] {
					let alpha: c32 = alpha;
					for m in 1..=127usize {
						for n in [8].into_iter().chain(1..=4usize).chain([8]) {
							for cs in [m, m.next_multiple_of(len)] {
								for conj_lhs in [false, true] {
									for conj_rhs in [false, true] {
										for diag_scale in [false, true] {
											// diag scaling is only applied through the packed-lhs path.
											if diag_scale && !pack_lhs {
												continue;
											}
											// NOTE(review): shadowing `m`/`cs` pins every iteration to m = 2,
											// which defeats the outer `m`/`cs` sweeps above. Looks like a
											// debugging leftover — confirm whether the full sweep should run.
											let m = 2usize;
											let cs = m;
											let conj_different = conj_lhs != conj_rhs;
											// Index tables map logical row/col i to storage index 2*i
											// (u32 entries, matching FLAGS_32BIT_IDX below).
											let idx = (0..Ord::max(m, n)).map(|i| 2 * i as u32).collect::<Vec<_>>();

											let acs = m.next_multiple_of(len);
											let k = 1usize;

											let packed_lhs: &mut [c32] = &mut *avec![0.0.into(); acs * k];
											let packed_rhs: &mut [c32] = &mut *avec![0.0.into(); n.next_multiple_of(4) * k];
											// 2x over-allocated: the kernel reads lhs with doubled strides.
											let lhs: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * k];
											let rhs: &mut [c32] = &mut *avec![0.0.into(); n * k];
											// dst/target are 2x over-allocated because the index tables
											// scatter results to storage index 2*(i + cs*j).
											let dst: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];
											let target: &mut [c32] = &mut *avec![0.0.into(); 2 * cs * n];

											let diag: &mut [f32] = &mut *avec![0.0.into(); k];

											rng.fill(cast_slice_mut::<c32, f32>(lhs));
											rng.fill(cast_slice_mut::<c32, f32>(rhs));
											rng.fill(diag);

											// Scalar reference: only the upper triangle (i <= j) is computed,
											// matching FLAGS_UPPER.
											for i in 0..m {
												for j in 0..n {
													if i > j {
														continue;
													}
													let target = &mut target[2 * (i + cs * j)];
													let mut acc: c32 = 0.0.into();
													for depth in 0..k {
														// Doubled lhs index mirrors the kernel's lhs_rs/lhs_cs.
														let mut l = lhs[2 * (i + cs * depth)];
														let mut r = rhs[depth + k * j];
														let d = diag[depth];

														if conj_lhs {
															l = l.conj();
														}
														if conj_rhs {
															r = r.conj();
														}

														if diag_scale {
															acc += d * l * r;
														} else {
															acc += l * r;
														}
													}
													*target = acc * alpha + *target;
												}
											}

											unsafe {
												millikernel_colmajor(
													C32_SIMD512x4[3],
													C32_SIMDpack_512[0],
													48,
													4,
													8,
													lhs.as_ptr() as _,
													if pack_lhs { packed_lhs.as_mut_ptr() as _ } else { lhs.as_ptr() as _ },
													rhs.as_ptr() as _,
													if pack_rhs { packed_rhs.as_mut_ptr() as _ } else { rhs.as_ptr() as _ },
													m,
													n,
													&mut MillikernelInfo {
														lhs_rs: 48 * sizeof,
														packed_lhs_rs: 48 * sizeof * k as isize,
														rhs_cs: 4 * sizeof * k as isize,
														packed_rhs_cs: 4 * sizeof * k as isize,
														micro: MicrokernelInfo {
															// FLAGS_CPLX is the distinguishing bit vs. the
															// other microkernel tests in this file.
															flags: ((conj_lhs as usize) * FLAGS_CONJ_LHS)
																| ((conj_different as usize) * FLAGS_CONJ_NEQ) | (1 * FLAGS_UPPER)
																| (1 * FLAGS_32BIT_IDX) | (1 * FLAGS_CPLX),
															depth: k,
															// Doubled lhs strides; see reference loop above.
															lhs_rs: 2 * sizeof,
															lhs_cs: 2 * cs as isize * sizeof,
															rhs_rs: 1 * sizeof,
															rhs_cs: k as isize * sizeof,
															alpha: &raw const alpha as _,
															ptr: dst.as_mut_ptr() as _,
															rs: sizeof,
															cs: cs as isize * sizeof,
															row_idx: idx.as_ptr() as _,
															col_idx: idx.as_ptr() as _,
															diag_ptr: if diag_scale { diag.as_ptr() as *const () } else { null() },
															diag_stride: if diag_scale { size_of::<f32>() as isize } else { 0 },
														},
													},
													&mut Position { row: 0, col: 0 },
												)
											};
											// Elementwise comparison against the scalar reference; untouched
											// entries are zero in both buffers.
											let mut i = 0;
											for (&target, &dst) in core::iter::zip(&*target, &*dst) {
												if !((target - dst).norm_sqr().sqrt() < 1e-4) {
													dbg!(i / cs, i % cs, target, dst);
													panic!();
												}
												i += 1;
											}
										}
									}
								}
							}
						}
					}
				}
			}
		}
	}
3532}