use crate::{
    cache::{kernel_params, DivCeil, KernelParams, CACHE_INFO},
    gemv, gevv,
    microkernel::{HMicroKernelFn, MicroKernelFn},
    pack_operands::{pack_lhs, pack_rhs},
    simd::MixedSimd,
    Parallelism, Ptr,
};
use core::sync::atomic::{AtomicUsize, Ordering};
use dyn_stack::{DynStack, StackReq};
#[cfg(feature = "f16")]
use half::f16;
use num_traits::{One, Zero};

#[allow(non_camel_case_types)]
pub type c32 = num_complex::Complex32;
#[allow(non_camel_case_types)]
pub type c64 = num_complex::Complex64;

pub const CACHELINE_ALIGN: usize = {
    #[cfg(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "powerpc64",
    ))]
    {
        128
    }
    #[cfg(any(
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "riscv64",
    ))]
    {
        32
    }
    #[cfg(target_arch = "s390x")]
    {
        256
    }
    #[cfg(not(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "powerpc64",
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "riscv64",
        target_arch = "s390x",
    )))]
    {
        64
    }
};

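// Per-thread scratch buffer sized to the detected L2 cache (`CACHE_INFO[1]`).
// The packing code below borrows it to hold packed LHS panels instead of
// allocating on every call.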
#[cfg(feature = "std")]
thread_local! {
    pub static L2_SLAB: core::cell::RefCell<dyn_stack::MemBuffer> = core::cell::RefCell::new(dyn_stack::MemBuffer::new(
        StackReq::new_aligned::<u8>(CACHE_INFO[1].cache_bytes, CACHELINE_ALIGN)
    ));
}

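// Conjugation for the element types supported here: a no-op for real types,
// and negation of the imaginary part for `c32`/`c64`.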
pub trait Conj: Copy {
    fn conj(self) -> Self;
}

#[cfg(feature = "f16")]
impl Conj for f16 {
    #[inline(always)]
    fn conj(self) -> Self {
        self
    }
}

impl Conj for f32 {
    #[inline(always)]
    fn conj(self) -> Self {
        self
    }
}
impl Conj for f64 {
    #[inline(always)]
    fn conj(self) -> Self {
        self
    }
}

impl Conj for c32 {
    #[inline(always)]
    fn conj(self) -> Self {
        c32 {
            re: self.re,
            im: -self.im,
        }
    }
}
impl Conj for c64 {
    #[inline(always)]
    fn conj(self) -> Self {
        c64 {
            re: self.re,
            im: -self.im,
        }
    }
}

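// Runtime-tunable thresholds. `THREADING_THRESHOLD` is the minimum amount of
// work (roughly m * n * k, scaled down for complex types) before the rayon
// path uses more than one thread; the packing thresholds control when the RHS
// and LHS operands are copied into packed panels. They can be adjusted through
// the setters below, e.g. `set_threading_threshold(1 << 20)`.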
pub const DEFAULT_THREADING_THRESHOLD: usize = 48 * 48 * 256;

#[cfg(target_arch = "aarch64")]
pub const DEFAULT_RHS_PACKING_THRESHOLD: usize = 2;
#[cfg(not(target_arch = "aarch64"))]
pub const DEFAULT_RHS_PACKING_THRESHOLD: usize = 128;

pub const DEFAULT_LHS_PACKING_THRESHOLD_SINGLE_THREAD: usize = 8;
pub const DEFAULT_LHS_PACKING_THRESHOLD_MULTI_THREAD: usize = 16;

static THREADING_THRESHOLD: AtomicUsize = AtomicUsize::new(DEFAULT_THREADING_THRESHOLD);
static RHS_PACKING_THRESHOLD: AtomicUsize = AtomicUsize::new(DEFAULT_RHS_PACKING_THRESHOLD);
static LHS_PACKING_THRESHOLD_SINGLE_THREAD: AtomicUsize =
    AtomicUsize::new(DEFAULT_LHS_PACKING_THRESHOLD_SINGLE_THREAD);
static LHS_PACKING_THRESHOLD_MULTI_THREAD: AtomicUsize =
    AtomicUsize::new(DEFAULT_LHS_PACKING_THRESHOLD_MULTI_THREAD);

#[inline]
pub fn get_threading_threshold() -> usize {
    THREADING_THRESHOLD.load(Ordering::Relaxed)
}
#[inline]
pub fn set_threading_threshold(value: usize) {
    THREADING_THRESHOLD.store(value, Ordering::Relaxed);
}

#[inline]
pub fn get_rhs_packing_threshold() -> usize {
    RHS_PACKING_THRESHOLD.load(Ordering::Relaxed)
}
#[inline]
pub fn set_rhs_packing_threshold(value: usize) {
    RHS_PACKING_THRESHOLD.store(value.min(256), Ordering::Relaxed);
}

#[inline]
pub fn get_lhs_packing_threshold_single_thread() -> usize {
    LHS_PACKING_THRESHOLD_SINGLE_THREAD.load(Ordering::Relaxed)
}
#[inline]
pub fn set_lhs_packing_threshold_single_thread(value: usize) {
    LHS_PACKING_THRESHOLD_SINGLE_THREAD.store(value.min(256), Ordering::Relaxed);
}

#[inline]
pub fn get_lhs_packing_threshold_multi_thread() -> usize {
    LHS_PACKING_THRESHOLD_MULTI_THREAD.load(Ordering::Relaxed)
}
#[inline]
pub fn set_lhs_packing_threshold_multi_thread(value: usize) {
    LHS_PACKING_THRESHOLD_MULTI_THREAD.store(value.min(256), Ordering::Relaxed);
}

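// Runs `func(tid)` for every `tid` in `0..n_threads` on the rayon thread pool.
// The non-generic `inner` keeps the rayon plumbing behind a single `&dyn Fn`,
// so it is not re-instantiated for every caller.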
#[cfg(feature = "rayon")]
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
    fn inner(n_threads: usize, func: &(dyn Fn(usize) + Send + Sync)) {
        use rayon::prelude::*;
        (0..n_threads).into_par_iter().for_each(func);
    }

    inner(n_threads, &func)
}

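// Generic GEMM driver shared by every scalar type and architecture backend.
//
// Computes `dst = alpha * dst + beta * lhs * rhs`, with optional conjugation of
// each operand; matrices are described by raw pointers plus column/row strides
// (`*_cs` / `*_rs`), and `read_dst == false` forces `alpha` to zero so the
// destination is overwritten. After the degenerate and small cases below, the
// operation is blocked into (mc x kc) LHS and (kc x nc) RHS panels, which are
// optionally packed and then dispatched tile by tile (MR x NR) to the provided
// microkernels, in parallel when `parallelism` allows it.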
#[inline(always)]
pub unsafe fn gemm_basic_generic<
    S: MixedSimd<T, T, T, T>,
    T: Copy
        + Zero
        + One
        + Conj
        + Send
        + Sync
        + core::fmt::Debug
        + core::ops::Add<Output = T>
        + core::ops::Mul<Output = T>
        + core::cmp::PartialEq
        + 'static,
    const N: usize,
    const MR: usize,
    const NR: usize,
    const MR_DIV_N: usize,
    const H_M: usize,
    const H_N: usize,
>(
    simd: S,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_cs: isize,
    dst_rs: isize,
    read_dst: bool,
    lhs: *const T,
    lhs_cs: isize,
    lhs_rs: isize,
    rhs: *const T,
    rhs_cs: isize,
    rhs_rs: isize,
    mut alpha: T,
    beta: T,
    conj_dst: bool,
    conj_lhs: bool,
    conj_rhs: bool,
    mul_add: impl Copy + Fn(T, T, T) -> T,
    dispatcher: &[[MicroKernelFn<T>; NR]; MR_DIV_N],
    horizontal_dispatcher: &[[HMicroKernelFn<T>; H_N]; H_M],
    _requires_row_major_rhs: bool,
    parallelism: Parallelism,
) {
    if m == 0 || n == 0 {
        return;
    }
    if !read_dst {
        alpha.set_zero();
    }

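    // k == 0: there is no product to accumulate, so only `alpha` (and the
    // destination conjugation) needs to be applied to `dst`.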
    if k == 0 {
        if alpha.is_zero() {
            for j in 0..n {
                for i in 0..m {
                    *dst.offset(i as isize * dst_rs + j as isize * dst_cs) = T::zero();
                }
            }
            return;
        }

        if alpha.is_one() && !conj_dst {
            return;
        }

        if conj_dst {
            for j in 0..n {
                for i in 0..m {
                    let dst = dst.offset(i as isize * dst_rs + j as isize * dst_cs);
                    *dst = alpha * (*dst).conj();
                }
            }
        } else {
            for j in 0..n {
                for i in 0..m {
                    let dst = dst.offset(i as isize * dst_rs + j as isize * dst_cs);
                    *dst = alpha * *dst;
                }
            }
        }
        return;
    }

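    // Small outputs (m * n <= 256) whose operands are contiguous along the k
    // dimension are handled by the horizontal microkernels, which accumulate
    // directly along k in chunks of at most `kc` elements.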
    if (H_M > 0 && H_N > 0) && (!conj_dst && lhs_cs == 1 && rhs_rs == 1 && (m * n) <= 16 * 16) {
        let kc = 1024;
        let mut depth = 0;
        let mut conj_dst = conj_dst;
        while depth < k {
            let kb = Ord::min(kc, k - depth);
            let alpha_status = if alpha.is_zero() {
                0
            } else if alpha.is_one() {
                1
            } else {
                2
            };

            let mut col = 0;
            while col < n {
                let nb = Ord::min(H_N, n - col);

                let mut row = 0;
                while row < m {
                    let mb = Ord::min(H_M, m - row);

                    horizontal_dispatcher[mb - 1][nb - 1](
                        kb,
                        dst.wrapping_offset(dst_rs * row as isize + dst_cs * col as isize),
                        lhs.wrapping_offset(lhs_rs * row as isize + depth as isize),
                        rhs.wrapping_offset(rhs_cs * col as isize + depth as isize),
                        dst_cs,
                        dst_rs,
                        lhs_rs,
                        rhs_cs,
                        alpha,
                        beta,
                        alpha_status,
                        conj_dst,
                        conj_lhs,
                        conj_rhs,
                    );

                    row += mb;
                }

                col += nb;
            }

            alpha = T::one();
            conj_dst = false;
            depth += kb;
        }

        return;
    }

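    // Without any conjugation, degenerate shapes get dedicated paths: k <= 2
    // goes through the rank-update kernel (gevv), and matrix-vector products
    // (n <= 1 or m <= 1) go through gemv, with the operands swapped when the
    // destination is a single row.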
    if !conj_dst && !conj_lhs && !conj_rhs {
        if k <= 2 {
            gevv::gevv(
                simd, m, n, k, dst, dst_cs, dst_rs, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
                alpha, beta, mul_add,
            );
            return;
        }

        if n <= 1 && lhs_rs == 1 && dst_rs == 1 {
            gemv::mixed_gemv_colmajor(
                simd, m, n, k, dst, dst_cs, dst_rs, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
                alpha, beta,
            );
            return;
        }
        if n <= 1 && lhs_cs == 1 && rhs_rs == 1 {
            gemv::mixed_gemv_rowmajor(
                simd, m, n, k, dst, dst_cs, dst_rs, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
                alpha, beta,
            );
            return;
        }
        if m <= 1 && rhs_cs == 1 && dst_cs == 1 {
            gemv::mixed_gemv_colmajor(
                simd, n, m, k, dst, dst_rs, dst_cs, rhs, rhs_rs, rhs_cs, lhs, lhs_rs, lhs_cs,
                alpha, beta,
            );
            return;
        }
        if m <= 1 && rhs_rs == 1 && lhs_cs == 1 {
            gemv::mixed_gemv_rowmajor(
                simd, n, m, k, dst, dst_rs, dst_cs, rhs, rhs_rs, rhs_cs, lhs, lhs_rs, lhs_cs,
                alpha, beta,
            );
            return;
        }
    }

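    // Pick the cache blocking sizes. For small problems (m, n <= 64), kc is
    // capped at 512 and mc is sized so an (mc x kc) LHS panel fits in the L2
    // cache; otherwise `kernel_params` computes kc/mc/nc from the matrix
    // dimensions and element size.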
    let KernelParams { kc, mc, nc } = if m <= 64 && n <= 64 {
        let kc = k.min(512);
        let alloc = CACHE_INFO[1].cache_bytes / core::mem::size_of::<T>();
        let mc = (alloc / kc) / MR * MR;

        KernelParams {
            kc,
            mc,
            nc: n.msrv_next_multiple_of(NR),
        }
    } else {
        kernel_params(m, n, k, MR, NR, core::mem::size_of::<T>())
    };
    let nc = if nc > 0 {
        nc
    } else {
        match parallelism {
            Parallelism::None => 128 * NR,
            #[cfg(feature = "rayon")]
            Parallelism::Rayon(_) => n.msrv_next_multiple_of(NR),
        }
    };

    let simd_align = CACHELINE_ALIGN;

    let packed_rhs_stride = kc * NR;
    let packed_lhs_stride = kc * MR;

    let dst = Ptr(dst);
    let lhs = Ptr(lhs as *mut T);
    let rhs = Ptr(rhs as *mut T);

    #[cfg(feature = "rayon")]
    let max_threads = match parallelism {
        Parallelism::None => 1,
        Parallelism::Rayon(n_threads) => {
            if n_threads == 0 {
                rayon::current_num_threads()
            } else {
                n_threads
            }
        }
    };

    #[cfg(feature = "rayon")]
    let threading_threshold = {
        use core::any::TypeId;
        let is_c32 = TypeId::of::<c32>() == TypeId::of::<T>();
        let is_c64 = TypeId::of::<c64>() == TypeId::of::<T>();
        if is_c32 {
            get_threading_threshold() / 4
        } else if is_c64 {
            get_threading_threshold() / 16
        } else {
            get_threading_threshold()
        }
    };

    #[cfg(target_arch = "aarch64")]
    let do_pack_rhs = _requires_row_major_rhs || m > get_rhs_packing_threshold() * MR;

    #[cfg(not(target_arch = "aarch64"))]
    let do_pack_rhs = (rhs_rs.unsigned_abs() != 1 && m > 2 * MR)
        || (rhs_rs.unsigned_abs() == 1 && m > get_rhs_packing_threshold() * MR);
    let do_prepack_lhs = m <= 2 * mc && ((m % N != 0) || lhs_rs != 1);

    let mut mem = if do_pack_rhs || do_prepack_lhs {
        let rhs_req = StackReq::new_aligned::<T>(
            if do_pack_rhs {
                packed_rhs_stride * (nc / NR)
            } else {
                0
            },
            simd_align,
        );
        let lhs_req = StackReq::new_aligned::<T>(
            if do_prepack_lhs {
                packed_lhs_stride * (m.msrv_next_multiple_of(MR) / MR)
            } else {
                0
            },
            simd_align,
        );
        Some(dyn_stack::MemBuffer::new(rhs_req.and(lhs_req)))
    } else {
        None
    };

    #[cfg(not(feature = "std"))]
    let mut l2_slab = dyn_stack::MemBuffer::new(StackReq::new_aligned::<T>(
        packed_lhs_stride * (mc / MR),
        simd_align,
    ));

    let mut packed_storage = mem.as_mut().map(|mem| {
        let stack = DynStack::new(mem);
        let (rhs, stack) = stack.make_aligned_uninit::<T>(
            if do_pack_rhs {
                packed_rhs_stride * (nc / NR)
            } else {
                0
            },
            simd_align,
        );

        (
            rhs,
            stack
                .make_aligned_uninit::<T>(
                    if do_prepack_lhs {
                        packed_lhs_stride * (m.msrv_next_multiple_of(MR) / MR)
                    } else {
                        0
                    },
                    simd_align,
                )
                .0,
        )
    });

    let (packed_rhs, prepacked_lhs) = packed_storage
        .as_mut()
        .map(|storage| {
            (
                storage.0.as_mut_ptr() as *mut T,
                storage.1.as_mut_ptr() as *mut T,
            )
        })
        .unwrap_or((core::ptr::null_mut(), core::ptr::null_mut()));

    let packed_rhs = Ptr(packed_rhs);
    let prepacked_lhs = Ptr(prepacked_lhs);

    let packed_rhs_rs = if do_pack_rhs { NR as isize } else { rhs_rs };
    let packed_rhs_cs = if do_pack_rhs { 1 } else { rhs_cs };

    let mut did_pack_lhs = alloc::vec![false; mc / MR];
    let did_pack_lhs = Ptr((&mut *did_pack_lhs) as *mut _);

    let mut col_outer = 0;
    while col_outer != n {
        let n_chunk = nc.min(n - col_outer);

        let mut alpha = alpha;
        let mut conj_dst = conj_dst;

        let mut depth_outer = 0;
        while depth_outer != k {
            let k_chunk = kc.min(k - depth_outer);
            let alpha_status = if alpha.is_zero() {
                0
            } else if alpha.is_one() {
                1
            } else {
                2
            };

            let n_threads = match parallelism {
                Parallelism::None => 1,
                #[cfg(feature = "rayon")]
                Parallelism::Rayon(_) => {
                    let total_work = (m * n_chunk).saturating_mul(k_chunk);
                    if total_work < threading_threshold {
                        1
                    } else {
                        max_threads
                    }
                }
            };

            let packing_threshold = if n_threads == 1 {
                get_lhs_packing_threshold_single_thread()
            } else {
                get_lhs_packing_threshold_multi_thread()
            };

            if do_pack_rhs {
                if n_threads <= 1 {
                    #[cfg(target_arch = "aarch64")]
                    pack_rhs::<T, N, NR, _>(
                        simd,
                        n_chunk,
                        k_chunk,
                        packed_rhs,
                        rhs.wrapping_offset(
                            depth_outer as isize * rhs_rs + col_outer as isize * rhs_cs,
                        ),
                        rhs_cs,
                        rhs_rs,
                        packed_rhs_stride,
                    );
                    #[cfg(not(target_arch = "aarch64"))]
                    pack_rhs::<T, 1, NR, _>(
                        simd,
                        n_chunk,
                        k_chunk,
                        packed_rhs,
                        rhs.wrapping_offset(
                            depth_outer as isize * rhs_rs + col_outer as isize * rhs_cs,
                        ),
                        rhs_cs,
                        rhs_rs,
                        packed_rhs_stride,
                    );
                } else {
                    #[cfg(feature = "rayon")]
                    {
                        let n_tasks = n_chunk.msrv_div_ceil(NR);
                        let base = n_tasks / n_threads;
                        let rem = n_tasks % n_threads;

                        let tid_to_col_inner = |tid: usize| {
                            if tid == n_threads {
                                return n_chunk;
                            }

                            let col = if tid < rem {
                                NR * tid * (base + 1)
                            } else {
                                NR * (rem + tid * base)
                            };
                            col.min(n_chunk)
                        };

                        let func = |tid: usize| {
                            let col_inner = tid_to_col_inner(tid);
                            let ncols = tid_to_col_inner(tid + 1) - col_inner;
                            let j = col_inner / NR;

                            if ncols > 0 {
                                #[cfg(target_arch = "aarch64")]
                                pack_rhs::<T, N, NR, _>(
                                    simd,
                                    ncols,
                                    k_chunk,
                                    packed_rhs.wrapping_add(j * packed_rhs_stride),
                                    rhs.wrapping_offset(
                                        depth_outer as isize * rhs_rs
                                            + (col_outer + col_inner) as isize * rhs_cs,
                                    ),
                                    rhs_cs,
                                    rhs_rs,
                                    packed_rhs_stride,
                                );
                                #[cfg(not(target_arch = "aarch64"))]
                                pack_rhs::<T, 1, NR, _>(
                                    simd,
                                    ncols,
                                    k_chunk,
                                    packed_rhs.wrapping_add(j * packed_rhs_stride),
                                    rhs.wrapping_offset(
                                        depth_outer as isize * rhs_rs
                                            + (col_outer + col_inner) as isize * rhs_cs,
                                    ),
                                    rhs_cs,
                                    rhs_rs,
                                    packed_rhs_stride,
                                );
                            }
                        };
                        par_for_each(n_threads, func);
                    }

                    #[cfg(not(feature = "rayon"))]
                    {
                        unreachable!();
                    }
                }
            }
            if do_prepack_lhs {
                pack_lhs::<T, N, MR, _>(
                    simd,
                    m,
                    k_chunk,
                    prepacked_lhs,
                    lhs.wrapping_offset(depth_outer as isize * lhs_cs),
                    lhs_cs,
                    lhs_rs,
                    packed_lhs_stride,
                );
            }

            let n_col_mini_chunks = (n_chunk + (NR - 1)) / NR;

            let mut n_jobs = 0;
            let mut row_outer = 0;
            while row_outer != m {
                let mut m_chunk = mc.min(m - row_outer);
                if m_chunk > N && !do_prepack_lhs {
                    m_chunk = m_chunk / N * N;
                }
                let n_row_mini_chunks = (m_chunk + (MR - 1)) / MR;
                n_jobs += n_col_mini_chunks * n_row_mini_chunks;
                row_outer += m_chunk;
            }

            let func = move |tid, packed_lhs: Ptr<T>| {
                let mut did_pack_lhs_storage =
                    alloc::vec![false; if tid > 0 { mc / MR } else { 0 }];
                let did_pack_lhs = if tid > 0 {
                    &mut *did_pack_lhs_storage
                } else {
                    &mut *({ did_pack_lhs }.0)
                };

                let min_jobs_per_thread = n_jobs / n_threads;
                let rem = n_jobs - n_threads * min_jobs_per_thread;

                let (job_start, job_end) = if tid < rem {
                    let start = tid * (min_jobs_per_thread + 1);
                    (start, start + min_jobs_per_thread + 1)
                } else {
                    let start = tid * min_jobs_per_thread + rem;
                    (start, start + min_jobs_per_thread)
                };

                let mut row_outer = 0;
                let mut job_id = 0;
                while row_outer != m {
                    let mut m_chunk = mc.min(m - row_outer);
                    if m_chunk > N && !do_prepack_lhs {
                        m_chunk = m_chunk / N * N;
                    }
                    let n_row_mini_chunks = (m_chunk + (MR - 1)) / MR;

                    let n_mini_jobs = n_col_mini_chunks * n_row_mini_chunks;

                    if job_id >= job_end {
                        return;
                    }
                    if job_id + n_mini_jobs < job_start {
                        row_outer += m_chunk;
                        job_id += n_mini_jobs;
                        continue;
                    }

                    let do_pack_lhs = !do_prepack_lhs
                        && ((m_chunk % N != 0) || lhs_rs != 1 || n_chunk > packing_threshold * NR);
                    let packed_lhs_cs = if do_prepack_lhs || do_pack_lhs {
                        MR as isize
                    } else {
                        lhs_cs
                    };

                    let mut j = 0;
                    did_pack_lhs.fill(false);
                    while j < n_col_mini_chunks {
                        let mut i = 0;
                        while i < n_row_mini_chunks {
                            let col_inner = NR * j;
                            let n_chunk_inner = NR.min(n_chunk - col_inner);

                            let row_inner = MR * i;
                            let m_chunk_inner = MR.min(m_chunk - row_inner);

                            if job_id < job_start || job_id >= job_end {
                                job_id += 1;
                                i += 1;
                                continue;
                            }
                            job_id += 1;

                            let dst = dst.wrapping_offset(
                                (row_outer + row_inner) as isize * dst_rs
                                    + (col_outer + col_inner) as isize * dst_cs,
                            );

                            let func =
                                dispatcher[(m_chunk_inner + (N - 1)) / N - 1][n_chunk_inner - 1];

                            if do_pack_lhs && !did_pack_lhs[i] {
                                pack_lhs::<T, N, MR, _>(
                                    simd,
                                    m_chunk_inner,
                                    k_chunk,
                                    packed_lhs.wrapping_add(i * packed_lhs_stride),
                                    lhs.wrapping_offset(
                                        (row_outer + row_inner) as isize * lhs_rs
                                            + depth_outer as isize * lhs_cs,
                                    ),
                                    lhs_cs,
                                    lhs_rs,
                                    packed_lhs_stride,
                                );
                                did_pack_lhs[i] = true;
                            }

                            func(
                                m_chunk_inner,
                                n_chunk_inner,
                                k_chunk,
                                dst.0,
                                if do_pack_lhs {
                                    packed_lhs.wrapping_add(i * packed_lhs_stride).0
                                } else if do_prepack_lhs {
                                    packed_lhs
                                        .wrapping_add((i + row_outer / MR) * packed_lhs_stride)
                                        .0
                                } else {
                                    lhs.wrapping_offset(
                                        (row_outer + row_inner) as isize * lhs_rs
                                            + depth_outer as isize * lhs_cs,
                                    )
                                    .0
                                },
                                if do_pack_rhs {
                                    packed_rhs.wrapping_add(j * packed_rhs_stride).0
                                } else {
                                    rhs.wrapping_offset(
                                        depth_outer as isize * rhs_rs
                                            + (col_outer + col_inner) as isize * rhs_cs,
                                    )
                                    .0
                                },
                                dst_cs,
                                dst_rs,
                                packed_lhs_cs,
                                packed_rhs_rs,
                                packed_rhs_cs,
                                alpha,
                                beta,
                                alpha_status,
                                conj_dst,
                                conj_lhs,
                                conj_rhs,
                                core::ptr::null(),
                            );
                            i += 1;
                        }
                        j += 1;
                    }

                    row_outer += m_chunk;
                }
            };

            if do_prepack_lhs {
                match parallelism {
                    Parallelism::None => func(0, prepacked_lhs),
                    #[cfg(feature = "rayon")]
                    Parallelism::Rayon(_) => {
                        if n_threads == 1 {
                            func(0, prepacked_lhs);
                        } else {
                            par_for_each(n_threads, |tid| func(tid, prepacked_lhs));
                        }
                    }
                }
            } else {
                #[cfg(feature = "std")]
                let func = |tid: usize| {
                    L2_SLAB.with(|mem| {
                        let mut mem = mem.borrow_mut();
                        let stack = DynStack::new(&mut mem);
                        let (packed_lhs_storage, _) = stack
                            .make_aligned_uninit::<T>(packed_lhs_stride * (mc / MR), simd_align);
                        let packed_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut T);
                        func(tid, packed_lhs);
                    });
                };

                #[cfg(not(feature = "std"))]
                let mut func = |tid: usize| {
                    let stack = DynStack::new(&mut l2_slab);
                    let (packed_lhs_storage, _) =
                        stack.make_aligned_uninit::<T>(packed_lhs_stride * (mc / MR), simd_align);
                    let packed_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut T);
                    func(tid, packed_lhs);
                };

                match parallelism {
                    Parallelism::None => func(0),
                    #[cfg(feature = "rayon")]
                    Parallelism::Rayon(_) => {
                        if n_threads == 1 {
                            func(0);
                        } else {
                            par_for_each(n_threads, func);
                        }
                    }
                }
            }

            conj_dst = false;
            alpha.set_one();

            depth_outer += k_chunk;
        }
        col_outer += n_chunk;
    }
}

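// Expands to a module that wraps `gemm_basic_generic` with one architecture's
// SIMD type and its microkernel dispatch tables (`UKR` / `H_UKR`) for the
// given element type.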
#[macro_export]
macro_rules! __inject_mod {
    ($module: ident, $ty: ident, $N: expr, $simd: ident, $requires_packed_rhs: expr) => {
        mod $module {
            use super::*;
            use crate::gemm_common::simd::MixedSimd;
            use crate::microkernel::$module::$ty::*;
            const N: usize = $N;

            #[inline(never)]
            pub unsafe fn gemm_basic(
                m: usize,
                n: usize,
                k: usize,
                dst: *mut $ty,
                dst_cs: isize,
                dst_rs: isize,
                read_dst: bool,
                lhs: *const $ty,
                lhs_cs: isize,
                lhs_rs: isize,
                rhs: *const $ty,
                rhs_cs: isize,
                rhs_rs: isize,
                alpha: $ty,
                beta: $ty,
                conj_dst: bool,
                conj_lhs: bool,
                conj_rhs: bool,
                parallelism: $crate::Parallelism,
            ) {
                $crate::gemm::gemm_basic_generic::<
                    _,
                    $ty,
                    N,
                    { MR_DIV_N * N },
                    NR,
                    MR_DIV_N,
                    H_M,
                    H_N,
                >(
                    <$crate::simd::$simd as MixedSimd<$ty, $ty, $ty, $ty>>::try_new().unwrap(),
                    m,
                    n,
                    k,
                    dst,
                    dst_cs,
                    dst_rs,
                    read_dst,
                    lhs,
                    lhs_cs,
                    lhs_rs,
                    rhs,
                    rhs_cs,
                    rhs_rs,
                    alpha,
                    beta,
                    conj_dst,
                    conj_lhs,
                    conj_rhs,
                    |a, b, c| a * b + c,
                    &UKR,
                    &H_UKR,
                    $requires_packed_rhs,
                    parallelism,
                );
            }
        }
    };
}

#[macro_export]
macro_rules! __inject_mod_cplx {
    ($module: ident, $ty: ident, $N: expr, $simd: ident) => {
        paste::paste! {
            mod [<$module _cplx>] {
                use super::*;
                use crate::microkernel::$module::$ty::*;
                use crate::gemm_common::simd::MixedSimd;
                const N: usize = $N;

                #[inline(never)]
                pub unsafe fn gemm_basic_cplx(
                    m: usize,
                    n: usize,
                    k: usize,
                    dst: *mut num_complex::Complex<T>,
                    dst_cs: isize,
                    dst_rs: isize,
                    read_dst: bool,
                    lhs: *const num_complex::Complex<T>,
                    lhs_cs: isize,
                    lhs_rs: isize,
                    rhs: *const num_complex::Complex<T>,
                    rhs_cs: isize,
                    rhs_rs: isize,
                    alpha: num_complex::Complex<T>,
                    beta: num_complex::Complex<T>,
                    conj_dst: bool,
                    conj_lhs: bool,
                    conj_rhs: bool,
                    parallelism: $crate::Parallelism,
                ) {
                    $crate::gemm::gemm_basic_generic::<_, _, N, { CPLX_MR_DIV_N * N }, CPLX_NR, CPLX_MR_DIV_N, H_CPLX_M, H_CPLX_N>(
                        <$crate::simd::$simd as MixedSimd<T, T, T, T>>::try_new().unwrap(),
                        m,
                        n,
                        k,
                        dst,
                        dst_cs,
                        dst_rs,
                        read_dst,
                        lhs,
                        lhs_cs,
                        lhs_rs,
                        rhs,
                        rhs_cs,
                        rhs_rs,
                        alpha,
                        beta,
                        conj_dst,
                        conj_lhs,
                        conj_rhs,
                        |a, b, c| a * b + c,
                        &CPLX_UKR,
                        &H_CPLX_UKR,
                        false,
                        parallelism,
                    );
                }
            }
        }
    };
}

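// Instantiates the per-architecture modules for a real element type and
// defines `get_gemm_fn`, which selects an implementation via runtime feature
// detection on first use and caches the function pointer in an `AtomicPtr`.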
#[macro_export]
macro_rules! gemm_def {
    ($ty: tt, $multiplier: expr) => {
        type GemmTy = unsafe fn(
            usize,
            usize,
            usize,
            *mut T,
            isize,
            isize,
            bool,
            *const T,
            isize,
            isize,
            *const T,
            isize,
            isize,
            T,
            T,
            bool,
            bool,
            bool,
            $crate::Parallelism,
        );

        #[inline]
        fn init_gemm_fn() -> GemmTy {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                #[cfg(feature = "nightly")]
                if $crate::feature_detected!("avx512f") {
                    return avx512f::gemm_basic;
                }
                if $crate::feature_detected!("fma") {
                    fma::gemm_basic
                } else {
                    scalar::gemm_basic
                }
            }

            #[cfg(target_arch = "aarch64")]
            {
                if $crate::feature_detected!("neon") {
                    #[cfg(feature = "experimental-apple-amx")]
                    if $crate::cache::HasAmx::get() {
                        return amx::gemm_basic;
                    }
                    neon::gemm_basic
                } else {
                    scalar::gemm_basic
                }
            }

            #[cfg(target_arch = "wasm32")]
            {
                if $crate::feature_detected!("simd128") {
                    simd128::gemm_basic
                } else {
                    scalar::gemm_basic
                }
            }

            #[cfg(not(any(
                target_arch = "x86",
                target_arch = "x86_64",
                target_arch = "aarch64",
                target_arch = "wasm32",
            )))]
            {
                scalar::gemm_basic
            }
        }

        static GEMM_PTR: ::core::sync::atomic::AtomicPtr<()> =
            ::core::sync::atomic::AtomicPtr::new(::core::ptr::null_mut());

        #[inline(never)]
        fn init_gemm_ptr() -> GemmTy {
            let gemm_fn = init_gemm_fn();
            GEMM_PTR.store(gemm_fn as *mut (), ::core::sync::atomic::Ordering::Relaxed);
            gemm_fn
        }

        #[inline(always)]
        pub fn get_gemm_fn() -> GemmTy {
            let mut gemm_fn = GEMM_PTR.load(::core::sync::atomic::Ordering::Relaxed);
            if gemm_fn.is_null() {
                gemm_fn = init_gemm_ptr() as *mut ();
            }
            unsafe { ::core::mem::transmute(gemm_fn) }
        }

        $crate::__inject_mod!(scalar, $ty, 1, Scalar, false);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        $crate::__inject_mod!(fma, $ty, 4 * $multiplier, V3, false);
        #[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
        $crate::__inject_mod!(avx512f, $ty, 8 * $multiplier, V4, false);

        #[cfg(target_arch = "aarch64")]
        $crate::__inject_mod!(neon, $ty, 2 * $multiplier, Scalar, false);
        #[cfg(target_arch = "aarch64")]
        #[cfg(feature = "experimental-apple-amx")]
        $crate::__inject_mod!(amx, $ty, 8 * $multiplier, Scalar, true);

        #[cfg(target_arch = "wasm32")]
        $crate::__inject_mod!(simd128, $ty, 2 * $multiplier, Scalar, false);
    };
}

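// Complex counterpart of `gemm_def!`: chooses between the scalar, FMA,
// AVX-512 and (on aarch64) NEON+FCMA implementations and caches the selected
// function pointer in the same way.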
#[macro_export]
macro_rules! gemm_cplx_def {
    ($ty: tt, $cplx_ty: tt, $multiplier: expr) => {
        type GemmCplxTy = unsafe fn(
            usize,
            usize,
            usize,
            *mut num_complex::Complex<T>,
            isize,
            isize,
            bool,
            *const num_complex::Complex<T>,
            isize,
            isize,
            *const num_complex::Complex<T>,
            isize,
            isize,
            num_complex::Complex<T>,
            num_complex::Complex<T>,
            bool,
            bool,
            bool,
            $crate::Parallelism,
        );

        fn init_gemm_cplx_fn() -> GemmCplxTy {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                #[cfg(feature = "nightly")]
                if $crate::feature_detected!("avx512f") {
                    return avx512f_cplx::gemm_basic_cplx;
                }
                if $crate::feature_detected!("fma") {
                    return fma_cplx::gemm_basic_cplx;
                }
            }

            #[cfg(target_arch = "aarch64")]
            {
                #[cfg(target_arch = "aarch64")]
                if $crate::feature_detected!("neon") && $crate::feature_detected!("fcma") {
                    return neonfcma::gemm_basic;
                }
            }

            scalar_cplx::gemm_basic_cplx
        }

        static GEMM_PTR: ::core::sync::atomic::AtomicPtr<()> =
            ::core::sync::atomic::AtomicPtr::new(::core::ptr::null_mut());

        #[inline(never)]
        fn init_gemm_ptr() -> GemmCplxTy {
            let gemm_fn = init_gemm_cplx_fn();
            GEMM_PTR.store(gemm_fn as *mut (), ::core::sync::atomic::Ordering::Relaxed);
            gemm_fn
        }

        #[inline(always)]
        pub fn get_gemm_fn() -> GemmCplxTy {
            let mut gemm_fn = GEMM_PTR.load(::core::sync::atomic::Ordering::Relaxed);
            if gemm_fn.is_null() {
                gemm_fn = init_gemm_ptr() as *mut ();
            }
            unsafe { ::core::mem::transmute(gemm_fn) }
        }

        $crate::__inject_mod_cplx!(scalar, $ty, 1, Scalar);

        #[cfg(target_arch = "aarch64")]
        $crate::__inject_mod!(neonfcma, $cplx_ty, 1 * $multiplier, Scalar, false);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        $crate::__inject_mod_cplx!(fma, $ty, 2 * $multiplier, V3);
        #[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
        $crate::__inject_mod_cplx!(avx512f, $ty, 4 * $multiplier, V4);
    };
}