gemm_common/gemm.rs

use crate::{
    cache::{kernel_params, DivCeil, KernelParams, CACHE_INFO},
    gemv, gevv,
    microkernel::{HMicroKernelFn, MicroKernelFn},
    pack_operands::{pack_lhs, pack_rhs},
    simd::MixedSimd,
    Parallelism, Ptr,
};
use core::sync::atomic::{AtomicUsize, Ordering};
use dyn_stack::{DynStack, StackReq};
#[cfg(feature = "f16")]
use half::f16;
use num_traits::{One, Zero};

#[allow(non_camel_case_types)]
pub type c32 = num_complex::Complex32;
#[allow(non_camel_case_types)]
pub type c64 = num_complex::Complex64;

// https://rust-lang.github.io/hashbrown/src/crossbeam_utils/cache_padded.rs.html#128-130
pub const CACHELINE_ALIGN: usize = {
    #[cfg(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "powerpc64",
    ))]
    {
        128
    }
    #[cfg(any(
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "riscv64",
    ))]
    {
        32
    }
    #[cfg(target_arch = "s390x")]
    {
        256
    }
    #[cfg(not(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "powerpc64",
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "riscv64",
        target_arch = "s390x",
    )))]
    {
        64
    }
};

#[cfg(feature = "std")]
thread_local! {
    pub static L2_SLAB: core::cell::RefCell<dyn_stack::MemBuffer> = core::cell::RefCell::new(dyn_stack::MemBuffer::new(
        StackReq::new_aligned::<u8>(CACHE_INFO[1].cache_bytes, CACHELINE_ALIGN)
    ));
}

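/// Scalar conjugation as used by the kernels below: the identity for real types,
/// negation of the imaginary part for `c32`/`c64`.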
pub trait Conj: Copy {
    fn conj(self) -> Self;
}

#[cfg(feature = "f16")]
impl Conj for f16 {
    #[inline(always)]
    fn conj(self) -> Self {
        self
    }
}

impl Conj for f32 {
    #[inline(always)]
    fn conj(self) -> Self {
        self
    }
}
impl Conj for f64 {
    #[inline(always)]
    fn conj(self) -> Self {
        self
    }
}

impl Conj for c32 {
    #[inline(always)]
    fn conj(self) -> Self {
        c32 {
            re: self.re,
            im: -self.im,
        }
    }
}
impl Conj for c64 {
    #[inline(always)]
    fn conj(self) -> Self {
        c64 {
            re: self.re,
            im: -self.im,
        }
    }
}

pub const DEFAULT_THREADING_THRESHOLD: usize = 48 * 48 * 256;

// we REALLY want to pack the rhs on aarch64 since we can use mul_add_lane
#[cfg(target_arch = "aarch64")]
pub const DEFAULT_RHS_PACKING_THRESHOLD: usize = 2;
#[cfg(not(target_arch = "aarch64"))]
pub const DEFAULT_RHS_PACKING_THRESHOLD: usize = 128;

pub const DEFAULT_LHS_PACKING_THRESHOLD_SINGLE_THREAD: usize = 8;
pub const DEFAULT_LHS_PACKING_THRESHOLD_MULTI_THREAD: usize = 16;

static THREADING_THRESHOLD: AtomicUsize = AtomicUsize::new(DEFAULT_THREADING_THRESHOLD);
static RHS_PACKING_THRESHOLD: AtomicUsize = AtomicUsize::new(DEFAULT_RHS_PACKING_THRESHOLD);
static LHS_PACKING_THRESHOLD_SINGLE_THREAD: AtomicUsize =
    AtomicUsize::new(DEFAULT_LHS_PACKING_THRESHOLD_SINGLE_THREAD);
static LHS_PACKING_THRESHOLD_MULTI_THREAD: AtomicUsize =
    AtomicUsize::new(DEFAULT_LHS_PACKING_THRESHOLD_MULTI_THREAD);

#[inline]
pub fn get_threading_threshold() -> usize {
    THREADING_THRESHOLD.load(Ordering::Relaxed)
}
#[inline]
pub fn set_threading_threshold(value: usize) {
    THREADING_THRESHOLD.store(value, Ordering::Relaxed);
}

#[inline]
pub fn get_rhs_packing_threshold() -> usize {
    RHS_PACKING_THRESHOLD.load(Ordering::Relaxed)
}
#[inline]
pub fn set_rhs_packing_threshold(value: usize) {
    RHS_PACKING_THRESHOLD.store(value.min(256), Ordering::Relaxed);
}

#[inline]
pub fn get_lhs_packing_threshold_single_thread() -> usize {
    LHS_PACKING_THRESHOLD_SINGLE_THREAD.load(Ordering::Relaxed)
}
#[inline]
pub fn set_lhs_packing_threshold_single_thread(value: usize) {
    LHS_PACKING_THRESHOLD_SINGLE_THREAD.store(value.min(256), Ordering::Relaxed);
}

#[inline]
pub fn get_lhs_packing_threshold_multi_thread() -> usize {
    LHS_PACKING_THRESHOLD_MULTI_THREAD.load(Ordering::Relaxed)
}
#[inline]
pub fn set_lhs_packing_threshold_multi_thread(value: usize) {
    LHS_PACKING_THRESHOLD_MULTI_THREAD.store(value.min(256), Ordering::Relaxed);
}
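
// Illustrative tuning sketch (assumes only the getters/setters defined above are in
// scope for the caller); lowering the threading threshold makes the multithreaded
// path kick in for smaller products:
//
//     set_threading_threshold(32 * 32 * 128);
//     assert_eq!(get_threading_threshold(), 32 * 32 * 128);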

#[cfg(feature = "rayon")]
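// Fans the indices `0..n_threads` out over the rayon pool. The `inner` indirection
// passes the closure as `&dyn Fn`, so rayon's parallel-iterator machinery is
// instantiated once rather than once per caller closure type.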
pub fn par_for_each(n_threads: usize, func: impl Fn(usize) + Send + Sync) {
    fn inner(n_threads: usize, func: &(dyn Fn(usize) + Send + Sync)) {
        use rayon::prelude::*;
        (0..n_threads).into_par_iter().for_each(func);
    }

    inner(n_threads, &func)
}

#[inline(always)]
pub unsafe fn gemm_basic_generic<
    S: MixedSimd<T, T, T, T>,
    T: Copy
        + Zero
        + One
        + Conj
        + Send
        + Sync
        + core::fmt::Debug
        + core::ops::Add<Output = T>
        + core::ops::Mul<Output = T>
        + core::cmp::PartialEq
        + 'static,
    const N: usize,
    const MR: usize,
    const NR: usize,
    const MR_DIV_N: usize,
    const H_M: usize,
    const H_N: usize,
>(
    simd: S,
    m: usize,
    n: usize,
    k: usize,
    dst: *mut T,
    dst_cs: isize,
    dst_rs: isize,
    read_dst: bool,
    lhs: *const T,
    lhs_cs: isize,
    lhs_rs: isize,
    rhs: *const T,
    rhs_cs: isize,
    rhs_rs: isize,
    mut alpha: T,
    beta: T,
    conj_dst: bool,
    conj_lhs: bool,
    conj_rhs: bool,
    mul_add: impl Copy + Fn(T, T, T) -> T,
    dispatcher: &[[MicroKernelFn<T>; NR]; MR_DIV_N],
    horizontal_dispatcher: &[[HMicroKernelFn<T>; H_N]; H_M],
    _requires_row_major_rhs: bool,
    parallelism: Parallelism,
) {
    if m == 0 || n == 0 {
        return;
    }
    if !read_dst {
        alpha.set_zero();
    }

    if k == 0 {
        // dst = alpha * conj?(dst)

        if alpha.is_zero() {
            for j in 0..n {
                for i in 0..m {
                    *dst.offset(i as isize * dst_rs + j as isize * dst_cs) = T::zero();
                }
            }
            return;
        }

        if alpha.is_one() && !conj_dst {
            return;
        }

        if conj_dst {
            for j in 0..n {
                for i in 0..m {
                    let dst = dst.offset(i as isize * dst_rs + j as isize * dst_cs);
                    *dst = alpha * (*dst).conj();
                }
            }
        } else {
            for j in 0..n {
                for i in 0..m {
                    let dst = dst.offset(i as isize * dst_rs + j as isize * dst_cs);
                    *dst = alpha * *dst;
                }
            }
        }
        return;
    }

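    // Small-output fast path: when both operands are contiguous along the depth
    // dimension (`lhs_cs == 1`, `rhs_rs == 1`), the destination is not conjugated, and
    // the output is at most 16x16, accumulate directly with the "horizontal"
    // microkernels, chunking the depth by at most 1024 and skipping packing entirely.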
    if (H_M > 0 && H_N > 0) && (!conj_dst && lhs_cs == 1 && rhs_rs == 1 && (m * n) <= 16 * 16) {
        let kc = 1024;
        let mut depth = 0;
        let mut conj_dst = conj_dst;
        while depth < k {
            let kb = Ord::min(kc, k - depth);
            let alpha_status = if alpha.is_zero() {
                0
            } else if alpha.is_one() {
                1
            } else {
                2
            };

            let mut col = 0;
            while col < n {
                let nb = Ord::min(H_N, n - col);

                let mut row = 0;
                while row < m {
                    let mb = Ord::min(H_M, m - row);

                    horizontal_dispatcher[mb - 1][nb - 1](
                        kb,
                        dst.wrapping_offset(dst_rs * row as isize + dst_cs * col as isize),
                        lhs.wrapping_offset(lhs_rs * row as isize + depth as isize),
                        rhs.wrapping_offset(rhs_cs * col as isize + depth as isize),
                        dst_cs,
                        dst_rs,
                        lhs_rs,
                        rhs_cs,
                        alpha,
                        beta,
                        alpha_status,
                        conj_dst,
                        conj_lhs,
                        conj_rhs,
                    );

                    row += mb;
                }

                col += nb;
            }

            alpha = T::one();
            conj_dst = false;
            depth += kb;
        }

        return;
    }

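    // Degenerate shapes without conjugation: a depth of at most 2 is handled as a sum
    // of outer products by `gevv`, and a single output column or row becomes a
    // matrix-vector product handled by `gemv`; the `m <= 1` cases are forwarded as the
    // transposed problem, with the roles of lhs/rhs and the row/column strides swapped.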
    if !conj_dst && !conj_lhs && !conj_rhs {
        if k <= 2 {
            gevv::gevv(
                simd, m, n, k, dst, dst_cs, dst_rs, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
                alpha, beta, mul_add,
            );
            return;
        }

        if n <= 1 && lhs_rs == 1 && dst_rs == 1 {
            gemv::mixed_gemv_colmajor(
                simd, m, n, k, dst, dst_cs, dst_rs, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
                alpha, beta,
            );
            return;
        }
        if n <= 1 && lhs_cs == 1 && rhs_rs == 1 {
            gemv::mixed_gemv_rowmajor(
                simd, m, n, k, dst, dst_cs, dst_rs, lhs, lhs_cs, lhs_rs, rhs, rhs_cs, rhs_rs,
                alpha, beta,
            );
            return;
        }
        if m <= 1 && rhs_cs == 1 && dst_cs == 1 {
            gemv::mixed_gemv_colmajor(
                simd, n, m, k, dst, dst_rs, dst_cs, rhs, rhs_rs, rhs_cs, lhs, lhs_rs, lhs_cs,
                alpha, beta,
            );
            return;
        }
        if m <= 1 && rhs_rs == 1 && lhs_cs == 1 {
            gemv::mixed_gemv_rowmajor(
                simd, n, m, k, dst, dst_rs, dst_cs, rhs, rhs_rs, rhs_cs, lhs, lhs_rs, lhs_cs,
                alpha, beta,
            );
            return;
        }
    }

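    // Cache blocking parameters: `kc` is the depth of each shared-dimension block,
    // `mc` the number of lhs rows packed into the per-thread L2 buffer, and `nc` the
    // number of rhs columns handled per outer iteration. For small problems a cheap
    // estimate is used instead of the full `kernel_params` computation.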
    let KernelParams { kc, mc, nc } = if m <= 64 && n <= 64 {
        // skip expensive kernel_params call for small sizes
        let kc = k.min(512);
        let alloc = CACHE_INFO[1].cache_bytes / core::mem::size_of::<T>();
        let mc = (alloc / kc) / MR * MR;

        KernelParams {
            kc,
            mc,
            nc: n.msrv_next_multiple_of(NR),
        }
    } else {
        kernel_params(m, n, k, MR, NR, core::mem::size_of::<T>())
    };
    let nc = if nc > 0 {
        nc
    } else {
        match parallelism {
            Parallelism::None => 128 * NR,
            #[cfg(feature = "rayon")]
            Parallelism::Rayon(_) => n.msrv_next_multiple_of(NR),
        }
    };

    let simd_align = CACHELINE_ALIGN;

    let packed_rhs_stride = kc * NR;
    let packed_lhs_stride = kc * MR;

    let dst = Ptr(dst);
    let lhs = Ptr(lhs as *mut T);
    let rhs = Ptr(rhs as *mut T);

    #[cfg(feature = "rayon")]
    let max_threads = match parallelism {
        Parallelism::None => 1,
        Parallelism::Rayon(n_threads) => {
            if n_threads == 0 {
                rayon::current_num_threads()
            } else {
                n_threads
            }
        }
    };

    #[cfg(feature = "rayon")]
    let threading_threshold = {
        use core::any::TypeId;
        let is_c32 = TypeId::of::<c32>() == TypeId::of::<T>();
        let is_c64 = TypeId::of::<c64>() == TypeId::of::<T>();
        if is_c32 {
            get_threading_threshold() / 4
        } else if is_c64 {
            get_threading_threshold() / 16
        } else {
            get_threading_threshold()
        }
    };

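    // Packing heuristics: the rhs is packed when the microkernels require it or when
    // enough output rows reuse it to amortize the copy; the whole lhs is packed up
    // front (`do_prepack_lhs`) when it spans at most two row blocks but is not already
    // column-contiguous with a height that is a multiple of the register width `N`.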
    #[cfg(target_arch = "aarch64")]
    let do_pack_rhs = _requires_row_major_rhs || m > get_rhs_packing_threshold() * MR;

    // no need to pack if the rhs is already contiguous-ish
    #[cfg(not(target_arch = "aarch64"))]
    let do_pack_rhs = (rhs_rs.unsigned_abs() != 1 && m > 2 * MR)
        || (rhs_rs.unsigned_abs() == 1 && m > get_rhs_packing_threshold() * MR);
    let do_prepack_lhs = m <= 2 * mc && ((m % N != 0) || lhs_rs != 1);

    let mut mem = if do_pack_rhs || do_prepack_lhs {
        let rhs_req = StackReq::new_aligned::<T>(
            if do_pack_rhs {
                packed_rhs_stride * (nc / NR)
            } else {
                0
            },
            simd_align,
        );
        let lhs_req = StackReq::new_aligned::<T>(
            if do_prepack_lhs {
                packed_lhs_stride * (m.msrv_next_multiple_of(MR) / MR)
            } else {
                0
            },
            simd_align,
        );
        Some(dyn_stack::MemBuffer::new(rhs_req.and(lhs_req)))
    } else {
        None
    };

    #[cfg(not(feature = "std"))]
    let mut l2_slab = dyn_stack::MemBuffer::new(StackReq::new_aligned::<T>(
        packed_lhs_stride * (mc / MR),
        simd_align,
    ));

    let mut packed_storage = mem.as_mut().map(|mem| {
        let stack = DynStack::new(mem);
        let (rhs, stack) = stack.make_aligned_uninit::<T>(
            if do_pack_rhs {
                packed_rhs_stride * (nc / NR)
            } else {
                0
            },
            simd_align,
        );

        (
            rhs,
            stack
                .make_aligned_uninit::<T>(
                    if do_prepack_lhs {
                        packed_lhs_stride * (m.msrv_next_multiple_of(MR) / MR)
                    } else {
                        0
                    },
                    simd_align,
                )
                .0,
        )
    });

    let (packed_rhs, prepacked_lhs) = packed_storage
        .as_mut()
        .map(|storage| {
            (
                storage.0.as_mut_ptr() as *mut T,
                storage.1.as_mut_ptr() as *mut T,
            )
        })
        .unwrap_or((core::ptr::null_mut(), core::ptr::null_mut()));

    let packed_rhs = Ptr(packed_rhs);
    let prepacked_lhs = Ptr(prepacked_lhs);

    let packed_rhs_rs = if do_pack_rhs { NR as isize } else { rhs_rs };
    let packed_rhs_cs = if do_pack_rhs { 1 } else { rhs_cs };

    let mut did_pack_lhs = alloc::vec![false; mc / MR];
    let did_pack_lhs = Ptr((&mut *did_pack_lhs) as *mut _);

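    // Main blocked loop nest: `nc`-wide column panels of the rhs, then `kc`-deep slices
    // along the shared dimension, then `mc`-tall row panels of the lhs, and finally the
    // MR x NR micro-tiles handled by the dispatched microkernel. After the first depth
    // slice, `alpha` is forced to one and `conj_dst` to false so later updates accumulate.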
    let mut col_outer = 0;
    while col_outer != n {
        let n_chunk = nc.min(n - col_outer);

        let mut alpha = alpha;
        let mut conj_dst = conj_dst;

        let mut depth_outer = 0;
        while depth_outer != k {
            let k_chunk = kc.min(k - depth_outer);
            let alpha_status = if alpha.is_zero() {
                0
            } else if alpha.is_one() {
                1
            } else {
                2
            };

            let n_threads = match parallelism {
                Parallelism::None => 1,
                #[cfg(feature = "rayon")]
                Parallelism::Rayon(_) => {
                    let total_work = (m * n_chunk).saturating_mul(k_chunk);
                    if total_work < threading_threshold {
                        1
                    } else {
                        max_threads
                    }
                }
            };

            let packing_threshold = if n_threads == 1 {
                get_lhs_packing_threshold_single_thread()
            } else {
                get_lhs_packing_threshold_multi_thread()
            };

            if do_pack_rhs {
                if n_threads <= 1 {
                    // on aarch64 we want the registers to be fully initialized
                    // for use with neon/amx
                    #[cfg(target_arch = "aarch64")]
                    pack_rhs::<T, N, NR, _>(
                        simd,
                        n_chunk,
                        k_chunk,
                        packed_rhs,
                        rhs.wrapping_offset(
                            depth_outer as isize * rhs_rs + col_outer as isize * rhs_cs,
                        ),
                        rhs_cs,
                        rhs_rs,
                        packed_rhs_stride,
                    );
                    #[cfg(not(target_arch = "aarch64"))]
                    pack_rhs::<T, 1, NR, _>(
                        simd,
                        n_chunk,
                        k_chunk,
                        packed_rhs,
                        rhs.wrapping_offset(
                            depth_outer as isize * rhs_rs + col_outer as isize * rhs_cs,
                        ),
                        rhs_cs,
                        rhs_rs,
                        packed_rhs_stride,
                    );
                } else {
                    #[cfg(feature = "rayon")]
                    {
                        let n_tasks = n_chunk.msrv_div_ceil(NR);
                        let base = n_tasks / n_threads;
                        let rem = n_tasks % n_threads;

                        let tid_to_col_inner = |tid: usize| {
                            if tid == n_threads {
                                return n_chunk;
                            }

                            let col = if tid < rem {
                                NR * tid * (base + 1)
                            } else {
                                NR * (rem + tid * base)
                            };
                            col.min(n_chunk)
                        };

                        let func = |tid: usize| {
                            let col_inner = tid_to_col_inner(tid);
                            let ncols = tid_to_col_inner(tid + 1) - col_inner;
                            let j = col_inner / NR;

                            if ncols > 0 {
                                #[cfg(target_arch = "aarch64")]
                                pack_rhs::<T, N, NR, _>(
                                    simd,
                                    ncols,
                                    k_chunk,
                                    packed_rhs.wrapping_add(j * packed_rhs_stride),
                                    rhs.wrapping_offset(
                                        depth_outer as isize * rhs_rs
                                            + (col_outer + col_inner) as isize * rhs_cs,
                                    ),
                                    rhs_cs,
                                    rhs_rs,
                                    packed_rhs_stride,
                                );
                                #[cfg(not(target_arch = "aarch64"))]
                                pack_rhs::<T, 1, NR, _>(
                                    simd,
                                    ncols,
                                    k_chunk,
                                    packed_rhs.wrapping_add(j * packed_rhs_stride),
                                    rhs.wrapping_offset(
                                        depth_outer as isize * rhs_rs
                                            + (col_outer + col_inner) as isize * rhs_cs,
                                    ),
                                    rhs_cs,
                                    rhs_rs,
                                    packed_rhs_stride,
                                );
                            }
                        };
                        par_for_each(n_threads, func);
                    }

                    #[cfg(not(feature = "rayon"))]
                    {
                        unreachable!();
                    }
                }
            }
            if do_prepack_lhs {
                pack_lhs::<T, N, MR, _>(
                    simd,
                    m,
                    k_chunk,
                    prepacked_lhs,
                    lhs.wrapping_offset(depth_outer as isize * lhs_cs),
                    lhs_cs,
                    lhs_rs,
                    packed_lhs_stride,
                );
            }

            let n_col_mini_chunks = (n_chunk + (NR - 1)) / NR;

            let mut n_jobs = 0;
            let mut row_outer = 0;
            while row_outer != m {
                let mut m_chunk = mc.min(m - row_outer);
                if m_chunk > N && !do_prepack_lhs {
                    m_chunk = m_chunk / N * N;
                }
                let n_row_mini_chunks = (m_chunk + (MR - 1)) / MR;
                n_jobs += n_col_mini_chunks * n_row_mini_chunks;
                row_outer += m_chunk;
            }

            let func = move |tid, packed_lhs: Ptr<T>| {
                let mut did_pack_lhs_storage =
                    alloc::vec![false; if tid > 0 { mc / MR } else { 0 }];
                let did_pack_lhs = if tid > 0 {
                    &mut *did_pack_lhs_storage
                } else {
                    &mut *({ did_pack_lhs }.0)
                };

                let min_jobs_per_thread = n_jobs / n_threads;
                let rem = n_jobs - n_threads * min_jobs_per_thread;

                // thread `tid` takes min_jobs_per_thread or min_jobs_per_thread + 1
                let (job_start, job_end) = if tid < rem {
                    let start = tid * (min_jobs_per_thread + 1);
                    (start, start + min_jobs_per_thread + 1)
                } else {
                    // start = rem * (min_jobs_per_thread + 1) + (tid - rem) * min_jobs_per_thread;
                    let start = tid * min_jobs_per_thread + rem;
                    (start, start + min_jobs_per_thread)
                };

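                // Walk the (row panel, column mini-chunk, row mini-chunk) grid in the same
                // order on every thread, giving each MR x NR mini-job a sequential id and
                // only executing ids in this thread's [job_start, job_end) range; whole row
                // panels that fall outside the range are skipped in one step.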
                let mut row_outer = 0;
                let mut job_id = 0;
                while row_outer != m {
                    let mut m_chunk = mc.min(m - row_outer);
                    if m_chunk > N && !do_prepack_lhs {
                        m_chunk = m_chunk / N * N;
                    }
                    let n_row_mini_chunks = (m_chunk + (MR - 1)) / MR;

                    let n_mini_jobs = n_col_mini_chunks * n_row_mini_chunks;

                    if job_id >= job_end {
                        return;
                    }
                    if job_id + n_mini_jobs < job_start {
                        row_outer += m_chunk;
                        job_id += n_mini_jobs;
                        continue;
                    }

                    let do_pack_lhs = !do_prepack_lhs
                        && ((m_chunk % N != 0) || lhs_rs != 1 || n_chunk > packing_threshold * NR);
                    let packed_lhs_cs = if do_prepack_lhs || do_pack_lhs {
                        MR as isize
                    } else {
                        lhs_cs
                    };

                    let mut j = 0;
                    did_pack_lhs.fill(false);
                    while j < n_col_mini_chunks {
                        let mut i = 0;
                        while i < n_row_mini_chunks {
                            let col_inner = NR * j;
                            let n_chunk_inner = NR.min(n_chunk - col_inner);

                            let row_inner = MR * i;
                            let m_chunk_inner = MR.min(m_chunk - row_inner);

                            if job_id < job_start || job_id >= job_end {
                                job_id += 1;
                                i += 1;
                                continue;
                            }
                            job_id += 1;

                            let dst = dst.wrapping_offset(
                                (row_outer + row_inner) as isize * dst_rs
                                    + (col_outer + col_inner) as isize * dst_cs,
                            );

                            let func =
                                dispatcher[(m_chunk_inner + (N - 1)) / N - 1][n_chunk_inner - 1];

                            if do_pack_lhs && !did_pack_lhs[i] {
                                pack_lhs::<T, N, MR, _>(
                                    simd,
                                    m_chunk_inner,
                                    k_chunk,
                                    packed_lhs.wrapping_add(i * packed_lhs_stride),
                                    lhs.wrapping_offset(
                                        (row_outer + row_inner) as isize * lhs_rs
                                            + depth_outer as isize * lhs_cs,
                                    ),
                                    lhs_cs,
                                    lhs_rs,
                                    packed_lhs_stride,
                                );
                                did_pack_lhs[i] = true;
                            }

                            func(
                                m_chunk_inner,
                                n_chunk_inner,
                                k_chunk,
                                dst.0,
                                if do_pack_lhs {
                                    packed_lhs.wrapping_add(i * packed_lhs_stride).0
                                } else if do_prepack_lhs {
                                    packed_lhs
                                        .wrapping_add((i + row_outer / MR) * packed_lhs_stride)
                                        .0
                                } else {
                                    lhs.wrapping_offset(
                                        (row_outer + row_inner) as isize * lhs_rs
                                            + depth_outer as isize * lhs_cs,
                                    )
                                    .0
                                },
                                if do_pack_rhs {
                                    packed_rhs.wrapping_add(j * packed_rhs_stride).0
                                } else {
                                    rhs.wrapping_offset(
                                        depth_outer as isize * rhs_rs
                                            + (col_outer + col_inner) as isize * rhs_cs,
                                    )
                                    .0
                                },
                                dst_cs,
                                dst_rs,
                                packed_lhs_cs,
                                packed_rhs_rs,
                                packed_rhs_cs,
                                alpha,
                                beta,
                                alpha_status,
                                conj_dst,
                                conj_lhs,
                                conj_rhs,
                                core::ptr::null(),
                            );
                            i += 1;
                        }
                        j += 1;
                    }

                    row_outer += m_chunk;
                }
            };

            if do_prepack_lhs {
                match parallelism {
                    Parallelism::None => func(0, prepacked_lhs),
                    #[cfg(feature = "rayon")]
                    Parallelism::Rayon(_) => {
                        if n_threads == 1 {
                            func(0, prepacked_lhs);
                        } else {
                            par_for_each(n_threads, |tid| func(tid, prepacked_lhs));
                        }
                    }
                }
            } else {
                #[cfg(feature = "std")]
                let func = |tid: usize| {
                    L2_SLAB.with(|mem| {
                        let mut mem = mem.borrow_mut();
                        let stack = DynStack::new(&mut mem);
                        let (packed_lhs_storage, _) = stack
                            .make_aligned_uninit::<T>(packed_lhs_stride * (mc / MR), simd_align);
                        let packed_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut T);
                        func(tid, packed_lhs);
                    });
                };

                #[cfg(not(feature = "std"))]
                let mut func = |tid: usize| {
                    let stack = DynStack::new(&mut l2_slab);
                    let (packed_lhs_storage, _) =
                        stack.make_aligned_uninit::<T>(packed_lhs_stride * (mc / MR), simd_align);
                    let packed_lhs = Ptr(packed_lhs_storage.as_mut_ptr() as *mut T);
                    func(tid, packed_lhs);
                };

                match parallelism {
                    Parallelism::None => func(0),
                    #[cfg(feature = "rayon")]
                    Parallelism::Rayon(_) => {
                        if n_threads == 1 {
                            func(0);
                        } else {
                            par_for_each(n_threads, func);
                        }
                    }
                }
            }

            conj_dst = false;
            alpha.set_one();

            depth_outer += k_chunk;
        }
        col_outer += n_chunk;
    }
}

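// Generates a module wrapping `gemm_basic_generic` for one scalar type and SIMD
// backend, using the microkernel tables (`UKR`, `H_UKR`) and register-width constants
// exported by the corresponding `crate::microkernel` module.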
#[macro_export]
macro_rules! __inject_mod {
    ($module: ident, $ty: ident, $N: expr, $simd: ident, $requires_packed_rhs: expr) => {
        mod $module {
            use super::*;
            use crate::gemm_common::simd::MixedSimd;
            use crate::microkernel::$module::$ty::*;
            const N: usize = $N;

            #[inline(never)]
            pub unsafe fn gemm_basic(
                m: usize,
                n: usize,
                k: usize,
                dst: *mut $ty,
                dst_cs: isize,
                dst_rs: isize,
                read_dst: bool,
                lhs: *const $ty,
                lhs_cs: isize,
                lhs_rs: isize,
                rhs: *const $ty,
                rhs_cs: isize,
                rhs_rs: isize,
                alpha: $ty,
                beta: $ty,
                conj_dst: bool,
                conj_lhs: bool,
                conj_rhs: bool,
                parallelism: $crate::Parallelism,
            ) {
                $crate::gemm::gemm_basic_generic::<
                    _,
                    $ty,
                    N,
                    { MR_DIV_N * N },
                    NR,
                    MR_DIV_N,
                    H_M,
                    H_N,
                >(
                    <$crate::simd::$simd as MixedSimd<$ty, $ty, $ty, $ty>>::try_new().unwrap(),
                    m,
                    n,
                    k,
                    dst,
                    dst_cs,
                    dst_rs,
                    read_dst,
                    lhs,
                    lhs_cs,
                    lhs_rs,
                    rhs,
                    rhs_cs,
                    rhs_rs,
                    alpha,
                    beta,
                    conj_dst,
                    conj_lhs,
                    conj_rhs,
                    |a, b, c| a * b + c,
                    &UKR,
                    &H_UKR,
                    $requires_packed_rhs,
                    parallelism,
                );
            }
        }
    };
}

#[macro_export]
macro_rules! __inject_mod_cplx {
    ($module: ident, $ty: ident, $N: expr, $simd: ident) => {
        paste::paste! {
            mod [<$module _cplx>] {
                use super::*;
                use crate::microkernel::$module::$ty::*;
                use crate::gemm_common::simd::MixedSimd;
                const N: usize = $N;

                #[inline(never)]
                pub unsafe fn gemm_basic_cplx(
                    m: usize,
                    n: usize,
                    k: usize,
                    dst: *mut num_complex::Complex<T>,
                    dst_cs: isize,
                    dst_rs: isize,
                    read_dst: bool,
                    lhs: *const num_complex::Complex<T>,
                    lhs_cs: isize,
                    lhs_rs: isize,
                    rhs: *const num_complex::Complex<T>,
                    rhs_cs: isize,
                    rhs_rs: isize,
                    alpha: num_complex::Complex<T>,
                    beta: num_complex::Complex<T>,
                    conj_dst: bool,
                    conj_lhs: bool,
                    conj_rhs: bool,
                    parallelism: $crate::Parallelism,
                ) {
                    $crate::gemm::gemm_basic_generic::<_, _, N, { CPLX_MR_DIV_N * N }, CPLX_NR, CPLX_MR_DIV_N, H_CPLX_M, H_CPLX_N>(
                        <$crate::simd::$simd as MixedSimd<T, T, T, T>>::try_new().unwrap(),
                        m,
                        n,
                        k,
                        dst,
                        dst_cs,
                        dst_rs,
                        read_dst,
                        lhs,
                        lhs_cs,
                        lhs_rs,
                        rhs,
                        rhs_cs,
                        rhs_rs,
                        alpha,
                        beta,
                        conj_dst,
                        conj_lhs,
                        conj_rhs,
                        |a, b, c| a * b + c,
                        &CPLX_UKR,
                        &H_CPLX_UKR,
                        false,
                        parallelism,
                    );
                }
            }
        }
    };
}

#[macro_export]
macro_rules! gemm_def {
    ($ty: tt, $multiplier: expr) => {
        type GemmTy = unsafe fn(
            usize,
            usize,
            usize,
            *mut T,
            isize,
            isize,
            bool,
            *const T,
            isize,
            isize,
            *const T,
            isize,
            isize,
            T,
            T,
            bool,
            bool,
            bool,
            $crate::Parallelism,
        );

        #[inline]
        fn init_gemm_fn() -> GemmTy {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                #[cfg(feature = "nightly")]
                if $crate::feature_detected!("avx512f") {
                    return avx512f::gemm_basic;
                }
                if $crate::feature_detected!("fma") {
                    fma::gemm_basic
                } else {
                    scalar::gemm_basic
                }
            }

            #[cfg(target_arch = "aarch64")]
            {
                if $crate::feature_detected!("neon") {
                    #[cfg(feature = "experimental-apple-amx")]
                    if $crate::cache::HasAmx::get() {
                        return amx::gemm_basic;
                    }
                    neon::gemm_basic
                } else {
                    scalar::gemm_basic
                }
            }

            #[cfg(target_arch = "wasm32")]
            {
                if $crate::feature_detected!("simd128") {
                    simd128::gemm_basic
                } else {
                    scalar::gemm_basic
                }
            }

            #[cfg(not(any(
                target_arch = "x86",
                target_arch = "x86_64",
                target_arch = "aarch64",
                target_arch = "wasm32",
            )))]
            {
                scalar::gemm_basic
            }
        }

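        // Cache the selected implementation in a relaxed `AtomicPtr` so runtime feature
        // detection only runs on the first call; `get_gemm_fn` transmutes the stored
        // pointer back to the concrete `GemmTy` function pointer type.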
        static GEMM_PTR: ::core::sync::atomic::AtomicPtr<()> =
            ::core::sync::atomic::AtomicPtr::new(::core::ptr::null_mut());

        #[inline(never)]
        fn init_gemm_ptr() -> GemmTy {
            let gemm_fn = init_gemm_fn();
            GEMM_PTR.store(gemm_fn as *mut (), ::core::sync::atomic::Ordering::Relaxed);
            gemm_fn
        }

        #[inline(always)]
        pub fn get_gemm_fn() -> GemmTy {
            let mut gemm_fn = GEMM_PTR.load(::core::sync::atomic::Ordering::Relaxed);
            if gemm_fn.is_null() {
                gemm_fn = init_gemm_ptr() as *mut ();
            }
            unsafe { ::core::mem::transmute(gemm_fn) }
        }

        $crate::__inject_mod!(scalar, $ty, 1, Scalar, false);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        $crate::__inject_mod!(fma, $ty, 4 * $multiplier, V3, false);
        #[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
        $crate::__inject_mod!(avx512f, $ty, 8 * $multiplier, V4, false);

        #[cfg(target_arch = "aarch64")]
        $crate::__inject_mod!(neon, $ty, 2 * $multiplier, Scalar, false);
        #[cfg(target_arch = "aarch64")]
        #[cfg(feature = "experimental-apple-amx")]
        $crate::__inject_mod!(amx, $ty, 8 * $multiplier, Scalar, true);

        #[cfg(target_arch = "wasm32")]
        $crate::__inject_mod!(simd128, $ty, 2 * $multiplier, Scalar, false);
    };
}
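
// Hypothetical expansion site, for illustration only (assumed usage, not taken from
// this file): a caller that defines `type T = f64;` in scope can invoke
// `gemm_def!(f64, 1);` to emit `get_gemm_fn` together with the backend modules above,
// with `$multiplier` scaling the per-backend register width `N`.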

#[macro_export]
macro_rules! gemm_cplx_def {
    ($ty: tt, $cplx_ty: tt, $multiplier: expr) => {
        type GemmCplxTy = unsafe fn(
            usize,
            usize,
            usize,
            *mut num_complex::Complex<T>,
            isize,
            isize,
            bool,
            *const num_complex::Complex<T>,
            isize,
            isize,
            *const num_complex::Complex<T>,
            isize,
            isize,
            num_complex::Complex<T>,
            num_complex::Complex<T>,
            bool,
            bool,
            bool,
            $crate::Parallelism,
        );

        fn init_gemm_cplx_fn() -> GemmCplxTy {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            {
                #[cfg(feature = "nightly")]
                if $crate::feature_detected!("avx512f") {
                    return avx512f_cplx::gemm_basic_cplx;
                }
                if $crate::feature_detected!("fma") {
                    return fma_cplx::gemm_basic_cplx;
                }
            }

            #[cfg(target_arch = "aarch64")]
            {
                #[cfg(target_arch = "aarch64")]
                if $crate::feature_detected!("neon") && $crate::feature_detected!("fcma") {
                    return neonfcma::gemm_basic;
                }
            }

            scalar_cplx::gemm_basic_cplx
        }

        static GEMM_PTR: ::core::sync::atomic::AtomicPtr<()> =
            ::core::sync::atomic::AtomicPtr::new(::core::ptr::null_mut());

        #[inline(never)]
        fn init_gemm_ptr() -> GemmCplxTy {
            let gemm_fn = init_gemm_cplx_fn();
            GEMM_PTR.store(gemm_fn as *mut (), ::core::sync::atomic::Ordering::Relaxed);
            gemm_fn
        }

        #[inline(always)]
        pub fn get_gemm_fn() -> GemmCplxTy {
            let mut gemm_fn = GEMM_PTR.load(::core::sync::atomic::Ordering::Relaxed);
            if gemm_fn.is_null() {
                gemm_fn = init_gemm_ptr() as *mut ();
            }
            unsafe { ::core::mem::transmute(gemm_fn) }
        }

        $crate::__inject_mod_cplx!(scalar, $ty, 1, Scalar);

        #[cfg(target_arch = "aarch64")]
        $crate::__inject_mod!(neonfcma, $cplx_ty, 1 * $multiplier, Scalar, false);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        $crate::__inject_mod_cplx!(fma, $ty, 2 * $multiplier, V3);
        #[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
        $crate::__inject_mod_cplx!(avx512f, $ty, 4 * $multiplier, V4);
    };
}