glar_base/
lib.rs

use std::sync::{Barrier, Mutex, MutexGuard, RwLock, RwLockReadGuard};
// Consider Once Cell
use once_cell::sync::Lazy;

pub mod range_rwlock;

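// Minimal number of elements an m x n strided matrix with row stride `rs` and
// column stride `cs` must provide: the index of its last element,
// (m - 1) * rs + (n - 1) * cs, plus one.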
pub fn matrix_size(rs: usize, cs: usize, m: usize, n: usize) -> usize {
    m * rs + n * cs - (rs + cs) + 1
}

use range_rwlock::{RangeLock, RangeLockReadGuard, RangeLockWriteGuard};

#[macro_export]
macro_rules! env_or {
    ($name:expr, $default:expr) => {
        if let Some(value) = std::option_env!($name) {
            const_str::parse!(value, usize)
        } else {
            $default
        }
    };
}
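
// Illustrative use of `env_or!` (the constant name and environment variable
// below are made up for the example, not part of the crate's configuration):
// resolve a value at build time from an environment variable, falling back to
// a default when it is unset.
#[allow(dead_code)]
const EXAMPLE_MC: usize = env_or!("GLAR_EXAMPLE_MC", 4800);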

#[derive(Copy, Clone)]
pub struct CpuFeatures {
    pub avx: bool,
    pub avx2: bool,
    pub avx512f: bool,
    pub avx512f16: bool,
    // pub avx512bf16: bool,
    pub avx512bw: bool,
    pub avx512_vnni: bool,
    pub fma: bool,
    pub fma4: bool,
    pub f16c: bool,
}

// padding in bytes
const CACHELINE_PAD: usize = 1024;

#[cfg(target_arch = "x86_64")]
pub struct HWConfig {
    pub cpu_ft: CpuFeatures,
    pub hw_model: HWModel,
    is_l1_shared: bool,
    is_l2_shared: bool,
    is_l3_shared: bool,
}

impl HWConfig {
    pub fn get_cache_info(&self) -> (bool, bool, bool) {
        (self.is_l1_shared, self.is_l2_shared, self.is_l3_shared)
    }
    pub fn hw_model(&self) -> HWModel {
        self.hw_model
    }

    pub fn cpu_ft(&self) -> CpuFeatures {
        self.cpu_ft
    }
}

#[cfg(target_arch = "aarch64")]
pub struct HWConfig {
    neon: bool,
}

#[derive(Copy, Clone)]
pub enum HWModel {
    Reference,
    Haswell,
    Skylake,
}

const SKYLAKE: [u8; 13] = [78, 85, 94, 126, 140, 141, 167, 151, 154, 183, 186, 143, 207];

const HASWELL: [u8; 10] = [69, 70, 63, 42, 58, 165, 79, 86, 61, 71];

impl HWModel {
    pub fn from_hw(family_id: u8, model_id: u8) -> Self {
        if family_id == 6 {
            if SKYLAKE.contains(&model_id) {
                return HWModel::Skylake;
            }
            if HASWELL.contains(&model_id) {
                return HWModel::Haswell;
            }
        }

        // default to reference
        return HWModel::Reference;
    }
    pub fn get_cache_info(&self) -> (bool, bool, bool) {
        match self {
            HWModel::Reference => (false, false, true),
            HWModel::Haswell => (false, false, true),
            HWModel::Skylake => (false, false, true),
        }
    }
}
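
// Sanity check added as an illustration; the model ids below are taken from the
// SKYLAKE and HASWELL tables above, so family 6 / model 94 maps to Skylake,
// family 6 / model 63 to Haswell, and anything unknown falls back to Reference.
#[cfg(test)]
mod hw_model_example {
    use super::HWModel;

    #[test]
    fn classifies_known_model_ids() {
        assert!(matches!(HWModel::from_hw(6, 94), HWModel::Skylake));
        assert!(matches!(HWModel::from_hw(6, 63), HWModel::Haswell));
        assert!(matches!(HWModel::from_hw(5, 94), HWModel::Reference));
    }
}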

// Use family and model id instead of cache size parameters,
// since the relation between optimal parameters (based on performance) and cache size parameters can be non-trivial,
// e.g. it might be cpu model dependent.

#[inline]
fn detect_hw_config() -> HWConfig {
    #[cfg(target_arch = "x86_64")]
    {
        let cpuid = raw_cpuid::CpuId::new();
        let feature_info = cpuid.get_feature_info().unwrap();
        let extended_feature_info = cpuid.get_extended_feature_info().unwrap();
        let avx = feature_info.has_avx();
        let fma = feature_info.has_fma();
        let avx2 = extended_feature_info.has_avx2();
        let avx512f16 = extended_feature_info.has_avx512_fp16();
        // let avx512bf16 = extended_feature_info.has_avx512_bf16();
        let avx512f = extended_feature_info.has_avx512f();
        let avx512bw = extended_feature_info.has_avx512bw();
        let avx512_vnni = extended_feature_info.has_avx512vnni();
        let f16c = feature_info.has_f16c();
        let extended_processor_info = cpuid.get_extended_processor_and_feature_identifiers().unwrap();
        let fma4 = extended_processor_info.has_fma4();
        let cpu_ft = CpuFeatures { avx, avx2, avx512f, avx512f16, avx512bw, avx512_vnni, fma, fma4, f16c };
        let family_id = feature_info.family_id();
        let model_id = feature_info.model_id();
        let hw_model = HWModel::from_hw(family_id, model_id);
        let (is_l1_shared, is_l2_shared, is_l3_shared) = hw_model.get_cache_info();
        return HWConfig { cpu_ft, hw_model, is_l1_shared, is_l2_shared, is_l3_shared };
    }
    #[cfg(target_arch = "aarch64")]
    {
        return HWConfig { neon: true };
    }
}

pub static RUNTIME_HW_CONFIG: Lazy<HWConfig> = Lazy::new(|| detect_hw_config());

pub static GLAR_NUM_THREADS: Lazy<usize> = Lazy::new(|| {
    let n_core = std::thread::available_parallelism().unwrap().get();
    // GLAR_NUM_THREADS or the number of logical cores
    let x = std::env::var("GLAR_NUM_THREADS").unwrap_or(n_core.to_string());
    x.parse::<usize>().unwrap()
});
#[cfg(target_arch = "x86_64")]
pub(crate) mod cpu_features {
    use super::HWModel;
    use super::RUNTIME_HW_CONFIG;

    pub fn hw_model() -> HWModel {
        RUNTIME_HW_CONFIG.hw_model
    }

    pub fn has_f32_compute() -> bool {
        // RUNTIME_HW_CONFIG.cpu_ft.avx512f || RUNTIME_HW_CONFIG.cpu_ft.avx
        // don't use the expression above since some avx512f code paths also rely on avx instructions
        // (even though avx512f should imply avx); we are being super conservative here
        RUNTIME_HW_CONFIG.cpu_ft.avx
    }

    pub fn has_f16f32_compute() -> bool {
        // RUNTIME_HW_CONFIG.cpu_ft.avx512f || RUNTIME_HW_CONFIG.cpu_ft.avx
        // don't use the expression above since some avx512f code paths also rely on avx instructions
        // (even though avx512f should imply avx); we are being super conservative here
        RUNTIME_HW_CONFIG.cpu_ft.avx && RUNTIME_HW_CONFIG.cpu_ft.f16c
    }
    pub fn has_f64_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.avx
    }
    pub fn has_f16_compute() -> bool {
        // since avx512_f16 is not stabilized, we also require avx+fma+f16c (the f16f32 compute path);
        // this should not be a problem since all avx512_f16 cpus also have these features,
        // otherwise it is a very obscure cpu / vm for which the support is not worth the effort
        RUNTIME_HW_CONFIG.cpu_ft.avx512f16
            && RUNTIME_HW_CONFIG.cpu_ft.avx
            && RUNTIME_HW_CONFIG.cpu_ft.f16c
            && RUNTIME_HW_CONFIG.cpu_ft.fma
    }
    pub fn has_i16i32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.avx2 && RUNTIME_HW_CONFIG.cpu_ft.avx
    }
    pub fn has_i8i32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.avx2 && RUNTIME_HW_CONFIG.cpu_ft.avx
    }
    // TODO: Use actual info from hardware
    pub fn get_cache_params() -> (usize, usize, usize) {
        (4800, 256, 128)
    }
}
#[cfg(target_arch = "aarch64")]
pub(crate) mod cpu_features {
    use super::RUNTIME_HW_CONFIG;
    pub fn hw_neon() -> bool {
        RUNTIME_HW_CONFIG.neon
    }
}
pub use cpu_features::*;

pub struct PackPool {
    pub buffer: RwLock<Vec<Mutex<Vec<u8>>>>,
}

pub static PACK_POOL: PackPool = PackPool { buffer: RwLock::new(vec![]) };

pub fn acquire<'a>(
    pool_guard: &'a RwLockReadGuard<'a, Vec<Mutex<Vec<u8>>>>,
    pack_size: usize,
) -> Option<MutexGuard<'a, Vec<u8>>> {
    // find the first free buffer with enough size
    // let x = PACK_POOL.buffer.read().unwrap();
    for i in pool_guard.iter() {
        // TODO: this might not be the most optimal algo in terms of fragmentation/memory reuse.
        // It works well for all but a few exceptional cases.
        // Exceptional case: multithreading along the mc loop changes at run time
        // (so a varying number of pack-a pools is required within one gemm run).
        // This is exceptional since it can only happen if the thread config is created by the user
        // and its parallelism along mc changes during the run.
        // I cannot think of a reason why someone would do that (maybe unusual hardware, or just experimentation).
        // Also, the current algo is very simple and easy to understand.
        let lock = i.try_lock();
        if let Ok(mutex) = lock {
            if mutex.len() >= pack_size {
                return Some(mutex);
            }
        }
    }

    None
}

pub fn extend(pool_vec: Vec<u8>) {
    let mut pool_guard = PACK_POOL.buffer.write().unwrap();
    pool_guard.push(Mutex::new(pool_vec));
}
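
// Usage sketch (illustrative size): a buffer donated via `extend` can later be
// reused through `acquire`, which hands out the first unlocked pooled buffer
// that is large enough. This mirrors the allocation path in `def_glar_gemm!` below.
#[cfg(test)]
mod pack_pool_example {
    use super::{acquire, extend, PACK_POOL};

    #[test]
    fn buffers_are_reused_once_donated() {
        extend(vec![0_u8; 512]);
        let pool_guard = PACK_POOL.buffer.read().unwrap();
        let reused = acquire(&pool_guard, 512);
        assert!(reused.is_some());
        assert!(reused.unwrap().len() >= 512);
    }
}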

pub struct GlarThreadConfig<'a> {
    pub ic_id: usize,
    // pc_id: usize,
    pub jc_id: usize,
    pub ir_id: usize,
    pub jr_id: usize,
    pub i_load_p_idx: usize,
    pub j_load_p_idx: usize,
    pub mc_eff: usize,
    pub nc_eff: usize,
    pub kc_eff: usize,
    pub par: GlarPar,
    pub packa_barrier: &'a [Barrier],
    pub packb_barrier: &'a [Barrier],
}

pub fn get_apbp_barrier(par: &GlarPar) -> (Vec<Barrier>, Vec<Barrier>) {
    let mut packa_barrier = vec![];
    for _ in 0..par.ic_par {
        let barrier = Barrier::new(par.jc_par * par.pc_par * par.ir_par * par.jr_par);
        packa_barrier.push(barrier);
    }

    let mut packb_barrier = vec![];
    for _ in 0..par.jc_par {
        let barrier = Barrier::new(par.ic_par * par.pc_par * par.ir_par * par.jr_par);
        packb_barrier.push(barrier);
    }

    (packa_barrier, packb_barrier)
}

impl<'a> GlarThreadConfig<'a> {
    pub fn new(
        par: GlarPar,
        packa_barrier: &'a [Barrier],
        packb_barrier: &'a [Barrier],
        t_id: usize,
        mc_eff: usize,
        nc_eff: usize,
        kc_eff: usize,
    ) -> Self {
        let ic_id = par.get_ic_id(t_id);
        // let pc_id = par.get_pc_id(t_id);
        let jc_id = par.get_jc_id(t_id);
        let ir_id = par.get_ir_id(t_id);
        let jr_id = par.get_jr_id(t_id);
        let i_load_p_idx = jc_id * par.ir_par * par.jr_par + ir_id * par.jr_par + jr_id;
        let j_load_p_idx = ic_id * par.ir_par * par.jr_par + ir_id * par.jr_par + jr_id;

        Self {
            ic_id,
            // pc_id,
            jc_id,
            ir_id,
            jr_id,
            i_load_p_idx,
            j_load_p_idx,
            mc_eff,
            nc_eff,
            kc_eff,
            par,
            packa_barrier,
            packb_barrier,
        }
    }
    #[inline]
    pub fn wait_packa(&self) {
        if self.par.jc_par * self.par.pc_par * self.par.ir_par * self.par.jr_par > 1 {
            self.packa_barrier[self.ic_id].wait();
        }
    }

    #[inline]
    pub fn wait_packb(&self) {
        if self.par.ic_par * self.par.pc_par * self.par.ir_par * self.par.jr_par > 1 {
            self.packb_barrier[self.jc_id].wait();
        }
    }
}

pub fn check_mem_size(mem_size: usize, rs: usize, cs: usize, m: usize, n: usize) {
    assert!(mem_size >= rs * cs * m * n);
    assert!(rs >= 1 && cs >= 1 && m >= 0 && n >= 0);
}

// once this is read, it cannot be changed for the time being.
#[inline(always)]
pub fn glar_num_threads() -> usize {
    return *GLAR_NUM_THREADS;
}

#[derive(Copy, Clone)]
pub struct GlarPar {
    pub num_threads: usize,
    pub ic_par: usize,
    pub pc_par: usize,
    pub jc_par: usize,
    pub ir_par: usize,
    pub jr_par: usize,
}

// greedy algo to distribute the number of threads evenly
// simple, works for the time being
#[inline(always)]
fn inc_par(ic_par: usize, jc_par: usize, ic_max: usize, jc_max: usize, factor: usize) -> (usize, usize, usize, usize) {
    if (ic_par < jc_par && ic_par < ic_max) || (jc_par >= jc_max && ic_par < ic_max) {
        (ic_par * factor, jc_par, ic_max / factor, jc_max)
    } else if (ic_par >= jc_par && jc_par < jc_max) || (ic_par >= ic_max && jc_par < jc_max) {
        (ic_par, jc_par * factor, ic_max, jc_max / factor)
    } else {
        (ic_par, jc_par, ic_max, jc_max)
    }
}
impl GlarPar {
    pub fn new(num_threads: usize, ic_par: usize, pc_par: usize, jc_par: usize, ir_par: usize, jr_par: usize) -> Self {
        assert_eq!(num_threads, jc_par * pc_par * ic_par * jr_par * ir_par);
        Self { num_threads, ic_par, pc_par, jc_par, ir_par, jr_par }
    }
    pub fn from_num_threads(num_threads: usize, m: usize, n: usize) -> Self {
        let mut num_threads = num_threads;
        let mut ic_par_max = if m < 96 {
            1
        } else if m < 400 {
            2
        } else {
            m / 200
        };
        let mut jc_par_max = if n < 48 {
            1
        } else if n < 200 {
            2
        } else {
            n / 100
        };

        if num_threads <= 12 {
            let jc_par_max = jc_par_max.min(num_threads);
            let n_thread = (num_threads / jc_par_max) * jc_par_max;
            return Self::new(n_thread, num_threads / jc_par_max, 1, jc_par_max, 1, 1);
        }
        // let mut jr_par_max = if k < 96 { 1 } else if jc_par_max => 4 { 4.min(k / 4) };
        num_threads = num_threads.min(ic_par_max * jc_par_max);
        let mut ic_par = 1;
        let pc_par = 1;
        let mut jc_par = 1;
        let mut ir_par = 1;
        let jr_par = 1;

        while num_threads > 1 {
            if num_threads % 2 == 0 {
                num_threads = num_threads / 2;
                (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 2);
            } else if num_threads % 3 == 0 {
                num_threads = num_threads / 3;
                (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 3);
            } else if num_threads % 5 == 0 {
                num_threads = num_threads / 5;
                (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 5);
                continue;
            } else if num_threads % 7 == 0 {
                num_threads = num_threads / 7;
                (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 7);
                continue;
            } else {
                // if it is a non-trivial prime factor (i.e. not divisible by 2,3,5,7),
                // round it so it is a "nice" number
                num_threads = num_threads / 2 * 2;
            }
            // if num_threads % 11 == 0 {
            //     num_threads = num_threads / 11;
            //     (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 11);
            //     continue;
            // }
            // if num_threads % 13 == 0 {
            //     num_threads = num_threads / 13;
            //     (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 13);
            //     continue;
            // }
            // if num_threads % 17 == 0 {
            //     num_threads = num_threads / 17;
            //     (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 17);
            //     continue;
            // }
        }
        if ic_par >= 8 {
            ic_par = ic_par / 2;
            ir_par = 2;
        }
        let num_threads = ic_par * pc_par * jc_par * ir_par * jr_par;
        Self { num_threads, ic_par, pc_par, jc_par, ir_par, jr_par }
    }
    #[inline(always)]
    pub fn default(m: usize, n: usize) -> Self {
        let num_threads = glar_num_threads();
        Self::from_num_threads(num_threads, m, n)
    }
    #[inline]
    fn get_ic_id(&self, t_id: usize) -> usize {
        (t_id / (self.pc_par * self.jc_par * self.ir_par * self.jr_par)) % self.ic_par
    }

    //    #[inline]
    //    fn get_pc_id(&self, t_id: usize) -> usize {
    //        (t_id / (self.jr_par*self.ir_par*self.ic_par)) % self.pc_par
    //    }
    #[inline]
    fn get_jc_id(&self, t_id: usize) -> usize {
        (t_id / (self.jr_par * self.ir_par)) % self.jc_par
    }
    #[inline]
    fn get_jr_id(&self, t_id: usize) -> usize {
        (t_id / self.ir_par) % self.jr_par
    }
    #[inline]
    fn get_ir_id(&self, t_id: usize) -> usize {
        t_id % self.ir_par
    }

    pub fn get_load_par(
        &self,
        gemm_mode: &GemmPool,
        m: usize,
        n: usize,
        mc_eff: usize,
        nc_eff: usize,
    ) -> (usize, usize) {
        let m = (m / self.ic_par).min(mc_eff);
        let n = (n / self.jc_par).min(nc_eff);
        let i_load_par = ((m + 127) / 128).min(self.num_threads / self.ic_par);
        let j_load_par = ((n + 127) / 128).min(self.num_threads / self.jc_par);
        let i_load_par = match gemm_mode {
            GemmPool::Goto => i_load_par,
            GemmPool::SmallM => i_load_par,
            GemmPool::SmallN => 1,
        };
        (i_load_par.max(1), j_load_par.max(1))
    }
}
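
// Worked example (illustrative sizes): with 4 threads, a tall m and a small n,
// from_num_threads splits the threads between the ic and jc loops; the numbers
// below follow the thresholds above (n < 200 caps jc_par at 2).
#[cfg(test)]
mod glar_par_example {
    use super::GlarPar;

    #[test]
    fn four_threads_tall_matrix() {
        let par = GlarPar::from_num_threads(4, 1024, 150);
        assert_eq!(par.num_threads, 4);
        assert_eq!((par.ic_par, par.jc_par), (2, 2));
        assert_eq!((par.pc_par, par.ir_par, par.jr_par), (1, 1, 1));
    }
}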

#[inline]
pub fn split_c_range(m: usize, mc: usize, mr: usize, ic_id: usize, ic_par: usize) -> (usize, usize, bool) {
    let chunk_len = (m / (mr * ic_par)) * mr;
    let rem = m % (mr * ic_par);
    if ic_id == 0 {
        let x = chunk_len + rem % mr;
        let mc_left = ((((x + mc - 1) / mc) * mc) * ic_par) < m;
        return (m - chunk_len - (rem % mr), m, mc_left);
    }
    let ic_id = ic_id - 1;
    let m0 = (m / mr) * mr;
    let rem = m0 % (mr * ic_par);
    let start_delta = rem.min(ic_id * mr);
    let end_delta = rem.min((ic_id + 1) * mr);
    //    let is_m_boundary = (chunk_len + end_delta - start_delta ) % mc == 0;
    let mc_coeff = (chunk_len + end_delta - start_delta + mc - 1) / mc;
    let mc_left = ((mc_coeff * mc) * ic_par) < m;
    //    let mc_left = is_m_boundary && rem != 0 && end_delta == start_delta;
    (chunk_len * ic_id + start_delta, chunk_len * (ic_id + 1) + end_delta, mc_left)
}

#[inline]
pub fn split_range(range_len: usize, unit_len: usize, r_id: usize, r_par: usize) -> (usize, usize) {
    let chunk_start = (range_len / (unit_len * r_par)) * unit_len * r_id;
    let chunk_end = (range_len / (unit_len * r_par)) * unit_len * (r_id + 1);
    let rem = range_len % (unit_len * r_par);
    let rem = rem - rem % unit_len;
    let rem_start = rem.min(r_id * unit_len);
    let rem_end = rem.min((r_id + 1) * unit_len);
    if r_id == r_par - 1 {
        return (chunk_start + rem_start, range_len);
    }
    (chunk_start + rem_start, chunk_end + rem_end)
}
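
// Quick check of how split_range partitions work (illustrative parameters): the
// pieces are contiguous, each non-final piece ends on a unit_len boundary, and
// the last worker absorbs whatever remainder is left.
#[cfg(test)]
mod split_range_example {
    use super::split_range;

    #[test]
    fn partitions_are_contiguous() {
        let (range_len, unit_len, r_par) = (110, 8, 3);
        let mut prev_end = 0;
        for r_id in 0..r_par {
            let (start, end) = split_range(range_len, unit_len, r_id, r_par);
            assert_eq!(start, prev_end);
            prev_end = end;
        }
        assert_eq!(prev_end, range_len);
    }
}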

pub trait BaseNum: Copy + 'static + Send {}

impl<T> BaseNum for T where T: Copy + 'static + Send {}

#[derive(Copy, Clone)]
pub struct PoolSize {
    pub m: usize,
    pub n: usize,
    pub k: usize,
    pub ap_pool_size: usize,
    pub ap_pool_multiplicity: usize,
    pub bp_pool_size: usize,
    pub bp_pool_multiplicity: usize,
}

impl PoolSize {
    // alignment padding for a and b is added only to the total memory pool size
    pub fn mem_pool_size_b<TA, TB>(&self) -> usize {
        // be conservative and add 2 * AB_ALIGN padding always
        self.ap_pool_size * std::mem::size_of::<TA>() * self.ap_pool_multiplicity
            + self.bp_pool_size * std::mem::size_of::<TB>() * self.bp_pool_multiplicity
            + 2 * AB_ALIGN
    }

    pub fn ap_size_b<TA>(&self) -> usize {
        self.ap_pool_size * std::mem::size_of::<TA>()
    }

    pub fn bp_size_b<TB>(&self) -> usize {
        self.bp_pool_size * std::mem::size_of::<TB>()
    }

    pub fn ap_size_t_b<TA>(&self) -> usize {
        self.ap_pool_size * std::mem::size_of::<TA>() * self.ap_pool_multiplicity
    }

    pub fn bp_size_t_b<TB>(&self) -> usize {
        self.bp_pool_size * std::mem::size_of::<TB>() * self.bp_pool_multiplicity
    }

    pub fn slice_mut_from_pool<TA, TB>(
        &self,
        mem_pool: &mut [u8],
        i_load_par: usize,
        j_load_par: usize,
        pool_size: PoolSize,
        mr: usize,
        nr: usize,
        // mc: usize, nc: usize, kc: usize, mr: usize, nr: usize,
    ) -> (Vec<RangeLock<'_, TA>>, Vec<RangeLock<'_, TB>>) {
        let m_size = pool_size.m;
        let n_size = pool_size.n;
        let k_size = pool_size.k;
        let ap_pool_size = self.ap_pool_size;
        let ap_pool_size_b = ap_pool_size * std::mem::size_of::<TA>();
        let a_alignment = std::mem::align_of::<TA>();
        assert_eq!(ap_pool_size_b % a_alignment, 0);
        let bp_pool_size = self.bp_pool_size;
        let bp_pool_size_b = bp_pool_size * std::mem::size_of::<TB>();
        let b_alignment = std::mem::align_of::<TB>();
        assert_eq!(bp_pool_size_b % b_alignment, 0);
        let mut ap = vec![];
        let mut bp = vec![];
        // safety for pointer to slice casting: assert len of mem_pool is enough
        // ap_pool_size
        assert!(mem_pool.len() >= self.mem_pool_size_b::<TA, TB>());
        // align mem_pool
        let align_offset = mem_pool.as_ptr().align_offset(AB_ALIGN);
        let mut mem_pool = &mut mem_pool[align_offset..];
        // safety for pointer to slice casting: ap has right alignment
        assert_eq!(mem_pool.as_ptr().align_offset(a_alignment), 0);
        for _ in 0..self.ap_pool_multiplicity {
            let (a, rest) = mem_pool.split_at_mut(ap_pool_size_b);
            let ap_pool = unsafe { std::slice::from_raw_parts_mut::<TA>(a.as_mut_ptr() as *mut TA, ap_pool_size) };
            if ap_pool_size == 0 {
                ap.push(RangeLock::from(ap_pool, i_load_par, 0, k_size, mr));
            } else {
                ap.push(RangeLock::from(ap_pool, i_load_par, m_size, k_size, mr));
            }
            mem_pool = rest;
        }
        let align_offset = mem_pool.as_ptr().align_offset(AB_ALIGN);
        let mut mem_pool = &mut mem_pool[align_offset..];
        // safety for pointer to slice casting: bp has right alignment
        assert_eq!(mem_pool.as_ptr().align_offset(b_alignment), 0);
        for _ in 0..self.bp_pool_multiplicity {
            let (b, rest) = mem_pool.split_at_mut(bp_pool_size_b);
            let bp_pool = unsafe { std::slice::from_raw_parts_mut::<TB>(b.as_mut_ptr() as *mut TB, bp_pool_size) };
            if bp_pool_size == 0 {
                bp.push(RangeLock::from(bp_pool, j_load_par, 0, k_size, nr));
            } else {
                bp.push(RangeLock::from(bp_pool, j_load_par, n_size, k_size, nr));
            }
            mem_pool = rest;
        }
        (ap, bp)
    }
}
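
// Byte accounting example (made-up sizes): two packed-A pools of 1024 f32
// elements and one packed-B pool of 2048 f32 elements, plus the fixed
// 2 * AB_ALIGN alignment padding that mem_pool_size_b always reserves.
#[cfg(test)]
mod pool_size_example {
    use super::{PoolSize, AB_ALIGN};

    #[test]
    fn total_pool_bytes() {
        let pool = PoolSize {
            m: 96,
            n: 512,
            k: 256,
            ap_pool_size: 1024,
            ap_pool_multiplicity: 2,
            bp_pool_size: 2048,
            bp_pool_multiplicity: 1,
        };
        let expected = 1024 * 4 * 2 + 2048 * 4 + 2 * AB_ALIGN;
        assert_eq!(pool.mem_pool_size_b::<f32, f32>(), expected);
    }
}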

pub fn get_mem_pool_size_goto<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
    hw_config: &HWConfig,
    par: &GlarPar,
    a_need_pool: bool,
    b_need_pool: bool,
) -> PoolSize {
    let m = hw_config.get_mc_eff(par.ic_par);
    let n = hw_config.get_nc_eff(par.jc_par);
    let k = hw_config.get_kc_eff();
    let (ap_pool_size, ap_pool_multiplicity) = if a_need_pool {
        let ap_pool_multiplicity = par.ic_par;
        let ap_pool_size = hw_config.get_ap_pool_size(par.ic_par) + CACHELINE_PAD / std::mem::size_of::<AP>();
        (ap_pool_size, ap_pool_multiplicity)
    } else {
        (0, 1)
    };
    let (bp_pool_size, bp_pool_multiplicity) = if b_need_pool {
        let bp_pool_multiplicity = par.jc_par;
        let bp_pool_size = hw_config.get_bp_pool_size(par.jc_par) + CACHELINE_PAD / std::mem::size_of::<BP>();
        (bp_pool_size, bp_pool_multiplicity)
    } else {
        (0, 1)
    };
    PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size, bp_pool_multiplicity }
}

pub fn get_mem_pool_size_small_m<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
    hw_config: &HWConfig,
    par: &GlarPar,
    a_need_pool: bool,
) -> PoolSize {
    let m = hw_config.get_mc_eff(par.ic_par);
    let n = hw_config.get_nc_eff(par.jc_par);
    let k = hw_config.get_kc_eff();
    if a_need_pool {
        let ap_pool_multiplicity = par.ic_par;
        let ap_pool_size = hw_config.get_ap_pool_size(par.ic_par) + CACHELINE_PAD / std::mem::size_of::<AP>();
        PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size: 0, bp_pool_multiplicity: 1 }
    } else {
        PoolSize { m, n, k, ap_pool_size: 0, ap_pool_multiplicity: 1, bp_pool_size: 0, bp_pool_multiplicity: 1 }
    }
}

pub fn get_mem_pool_size_small_n<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
    hw_config: &HWConfig,
    par: &GlarPar,
    b_need_pool: bool,
) -> PoolSize {
    let ap_pool_size = hw_config.get_ap_pool_size2() + CACHELINE_PAD / std::mem::size_of::<AP>();
    let ap_pool_multiplicity = par.num_threads;
    let m = hw_config.mr();
    let n = hw_config.get_nc_eff(par.jc_par);
    let k = hw_config.get_kc_eff();
    if b_need_pool {
        let bp_pool_multiplicity = par.jc_par;
        let bp_pool_size = hw_config.get_bp_pool_size(par.jc_par) + CACHELINE_PAD / std::mem::size_of::<BP>();
        PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size, bp_pool_multiplicity }
    } else {
        PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size: 0, bp_pool_multiplicity: 1 }
    }
}

// ap_size and bp_size are taken as arguments since they are specific to the gemm implementation:
// they are determined by the hardware and the gemm implementation (e.g. f64, f32, f16).
// Otherwise, this base crate would include code coupled with the other gemm crates,
// which would require either a cyclic dependency (not allowed, of course) or separate code for each
// specific hardware and gemm implementation inside this crate, which is not desirable. We want this
// crate to be as decoupled as possible from specific gemm implementations and hardware.

pub fn run_small_m(m: usize) -> bool {
    m < 144
}

pub fn run_small_n(n: usize) -> bool {
    n < 144
}

pub enum GemmPool {
    Goto,
    SmallM,
    SmallN,
}

pub fn ap_size<T>(m: usize, k: usize) -> usize {
    let vs = 64 / std::mem::size_of::<T>();
    let m_max = (m + vs - 1) / vs * vs;
    m_max * k + AB_ALIGN / std::mem::size_of::<T>()
}

pub fn bp_size<T>(n: usize, k: usize) -> usize {
    n * k + AB_ALIGN / std::mem::size_of::<T>()
}

pub fn ap_size_int<T, P>(m: usize, k: usize) -> usize {
    let vs = 64 / std::mem::size_of::<T>();
    let c_r = std::mem::size_of::<P>() / std::mem::size_of::<T>();
    let k_r = (k + c_r - 1) / c_r * c_r;
    let m_max = (m + vs - 1) / vs * vs;
    m_max * k_r + AB_ALIGN / std::mem::size_of::<T>()
}

pub fn bp_size_int<T, P>(n: usize, k: usize) -> usize {
    let c_r = std::mem::size_of::<P>() / std::mem::size_of::<T>();
    let k_r = (k + c_r - 1) / c_r * c_r;
    n * k_r + AB_ALIGN / std::mem::size_of::<T>()
}
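
// Worked example of the rounding above (illustrative sizes): for f32 the panel
// height is rounded up to the 16-lane vector width (64 bytes / 4 bytes per
// element), and AB_ALIGN bytes' worth of extra elements are reserved.
#[cfg(test)]
mod ap_size_example {
    use super::{ap_size, AB_ALIGN};

    #[test]
    fn rounds_m_to_vector_width() {
        // m = 150 rounds up to 160 rows; k = 256.
        assert_eq!(ap_size::<f32>(150, 256), 160 * 256 + AB_ALIGN / 4);
    }
}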

#[derive(Clone, Copy)]
pub struct StridedMatrix<T> {
    pub(crate) src: *const T,
    pub(crate) rs: usize,
    pub(crate) cs: usize,
}

impl<T> StridedMatrix<T> {
    pub fn new(src: *const T, rs: usize, cs: usize) -> Self {
        Self { src, rs, cs }
    }
}

unsafe impl<T> Send for StridedMatrix<T> {}

#[derive(Clone, Copy)]
pub struct StridedMatrixMut<T> {
    pub(crate) src: *mut T,
    pub(crate) rs: usize,
    pub(crate) cs: usize,
}

unsafe impl<T> Send for StridedMatrixMut<T> {}

impl<T> StridedMatrixMut<T> {
    pub fn new(src: *mut T, rs: usize, cs: usize) -> Self {
        Self { src, rs, cs }
    }
}

#[derive(Clone)]
pub struct StridedMatrixP<'a, T, U> {
    pub(crate) src: *const T,
    pub(crate) rs: usize,
    pub(crate) cs: usize,
    pub(crate) dst: &'a RangeLock<'a, U>,
}

unsafe impl<'a, T, U> Send for StridedMatrixP<'a, T, U> {}

impl<'a, T, U> StridedMatrixP<'a, T, U> {
    pub fn src(&self) -> *const T {
        self.src
    }
    pub fn dst_w(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, U> {
        self.dst.write(idx, kc).unwrap()
    }
    pub fn dst_r(&self) -> RangeLockReadGuard<'a, 'a, U> {
        self.dst.read().unwrap()
    }
    pub fn get_mc(&self) -> usize {
        self.dst.get_mc()
    }
    pub fn rs(&self) -> usize {
        self.rs
    }
    pub fn cs(&self) -> usize {
        self.cs
    }
}

#[derive(Clone, Copy)]
pub struct PackedMatrix<T> {
    pub(crate) src: *const T,
    pub(crate) k: usize,
    pub(crate) m: usize,
}

unsafe impl<T> Send for PackedMatrix<T> {}

impl<T> PackedMatrix<T> {
    pub fn src(&self) -> *const T {
        self.src
    }
    pub fn k(&self) -> usize {
        self.k
    }
    pub fn m(&self) -> usize {
        self.m
    }
}

#[derive(Clone)]
pub struct PackedMatrixMixed<'a, X, Y> {
    pub(crate) src: *const X,
    pub(crate) dst: &'a RangeLock<'a, Y>,
    pub(crate) k: usize,
    pub(crate) m: usize,
}

impl<'a, X, Y> PackedMatrixMixed<'a, X, Y> {
    pub fn src(&self) -> *const X {
        self.src
    }
    pub fn k(&self) -> usize {
        self.k
    }
    pub fn m(&self) -> usize {
        self.m
    }

    pub fn dst_w(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, Y> {
        self.dst.write(idx, kc).unwrap()
    }

    pub fn get_mc(&self) -> usize {
        self.dst.get_mc()
    }

    pub fn dst_r(&self) -> RangeLockReadGuard<'a, 'a, Y> {
        self.dst.read().unwrap()
    }
}

unsafe impl<X, Y> Send for PackedMatrixMixed<'_, X, Y> {}

// must be a multiple of the largest vector size that we support
// currently avx512 -> 64 bytes
pub const AB_ALIGN: usize = 1024;

pub trait GemmCache {
    fn mr(&self) -> usize;
    fn nr(&self) -> usize;
    fn get_mc_eff(&self, par: usize) -> usize;
    fn get_kc_eff(&self) -> usize;
    fn get_nc_eff(&self, par: usize) -> usize;
    fn get_ap_pool_size(&self, ic_par: usize) -> usize {
        let mc_eff = self.get_mc_eff(ic_par);
        let kc_eff = self.get_kc_eff();
        mc_eff * kc_eff
    }
    fn get_ap_pool_size2(&self) -> usize {
        let kc_eff = self.get_kc_eff();
        self.mr() * kc_eff
    }
    fn get_bp_pool_size(&self, jc_par: usize) -> usize {
        let nc_eff = self.get_nc_eff(jc_par);
        let kc_eff = self.get_kc_eff();
        nc_eff * kc_eff
    }
}
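
// Toy GemmCache implementation (blocking numbers are made up) showing how the
// pool-size helpers above consume the trait: mc/nc/kc come from the dispatcher
// and a CACHELINE_PAD bytes' worth of extra elements is added on top of the
// mc_eff * kc_eff (and nc_eff * kc_eff) panels.
#[cfg(test)]
mod gemm_cache_example {
    use super::{get_mem_pool_size_goto, GemmCache, GlarPar, CACHELINE_PAD};

    struct ToyCache;

    impl GemmCache for ToyCache {
        fn mr(&self) -> usize {
            8
        }
        fn nr(&self) -> usize {
            4
        }
        fn get_mc_eff(&self, _ic_par: usize) -> usize {
            96
        }
        fn get_kc_eff(&self) -> usize {
            256
        }
        fn get_nc_eff(&self, _jc_par: usize) -> usize {
            512
        }
    }

    #[test]
    fn goto_pool_size_from_blocking_params() {
        let par = GlarPar::new(4, 2, 1, 2, 1, 1);
        let pool = get_mem_pool_size_goto::<f32, f32, ToyCache>(&ToyCache, &par, true, true);
        assert_eq!(pool.ap_pool_size, 96 * 256 + CACHELINE_PAD / 4);
        assert_eq!(pool.ap_pool_multiplicity, 2);
        assert_eq!(pool.bp_pool_size, 512 * 256 + CACHELINE_PAD / 4);
        assert_eq!(pool.bp_pool_multiplicity, 2);
    }
}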

#[derive(Copy, Clone)]
pub enum Array<X> {
    StridedMatrix(StridedMatrix<X>),
    PackedMatrix(PackedMatrix<X>),
}

impl<X> Array<X> {
    pub fn strided_matrix(src: *const X, rs: usize, cs: usize) -> Self {
        Array::StridedMatrix(StridedMatrix::new(src, rs, cs))
    }
    pub fn packed_matrix(src: *const X, m: usize, k: usize) -> Self {
        Array::PackedMatrix(PackedMatrix { src, k, m })
    }
    pub fn into_pack_array<'a>(&self, a: &'a [RangeLock<'a, X>], p_id: usize) -> PArray<'a, X> {
        match self {
            Array::StridedMatrix(x) => {
                let x = StridedMatrixP { src: x.src, rs: x.rs, cs: x.cs, dst: &a[p_id] };
                PArray::<X>::StridedMatrix(x)
            }
            Array::PackedMatrix(x) => {
                let x = PackedMatrix { src: x.src, k: x.k, m: x.m };
                PArray::PackedMatrix(x)
            }
        }
    }
    pub fn into_pack_array2<'a, Y>(&self, a: &'a [RangeLock<'a, Y>], p_id: usize) -> PArrayMixed<'a, X, Y> {
        match self {
            Array::StridedMatrix(x) => {
                let x = StridedMatrixP { src: x.src, rs: x.rs, cs: x.cs, dst: &a[p_id] };
                PArrayMixed::<X, Y>::StridedMatrix(x)
            }
            Array::PackedMatrix(x) => {
                let x = PackedMatrixMixed { src: x.src, dst: &a[p_id], k: x.k, m: x.m };
                PArrayMixed::PackedMatrix(x)
            }
        }
    }

    pub fn src(&self) -> *const X {
        match self {
            Array::StridedMatrix(x) => x.src,
            Array::PackedMatrix(x) => x.src,
        }
    }

    pub fn transpose(&mut self) {
        match self {
            Array::StridedMatrix(x) => {
                let temp = x.rs;
                x.rs = x.cs;
                x.cs = temp;
            }
            _ => {
                panic!("Only StridedMatrix has transpose");
            }
        }
    }

    pub fn rs(&self) -> usize {
        match self {
            Array::StridedMatrix(x) => x.rs,
            _ => {
                panic!("Only StridedMatrix has rs");
            }
        }
    }

    pub fn cs(&self) -> usize {
        match self {
            Array::StridedMatrix(x) => x.cs,
            _ => {
                panic!("Only StridedMatrix has cs");
            }
        }
    }

    pub fn is_strided(&self) -> bool {
        match self {
            Array::StridedMatrix(_) => true,
            _ => false,
        }
    }
}
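
// Small usage sketch (illustrative strides): a strided view only records a raw
// pointer plus rs/cs, so transposing just swaps the two strides.
#[cfg(test)]
mod array_example {
    use super::Array;

    #[test]
    fn transpose_swaps_strides() {
        let data = [0.0f32; 6];
        // 2 x 3 row-major view: row stride 3, column stride 1.
        let mut a = Array::strided_matrix(data.as_ptr(), 3, 1);
        a.transpose();
        assert_eq!((a.rs(), a.cs()), (1, 3));
    }
}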

#[derive(Copy, Clone)]
pub enum ArrayMut<X> {
    StridedMatrix(StridedMatrixMut<X>),
}

impl<X> ArrayMut<X> {
    pub fn strided_matrix(src: *mut X, rs: usize, cs: usize) -> Self {
        ArrayMut::StridedMatrix(StridedMatrixMut::new(src, rs, cs))
    }

    pub fn src(&self) -> *mut X {
        match self {
            ArrayMut::StridedMatrix(x) => x.src,
        }
    }

    pub fn transpose(&mut self) {
        match self {
            ArrayMut::StridedMatrix(x) => {
                let temp = x.rs;
                x.rs = x.cs;
                x.cs = temp;
            }
        }
    }

    pub fn rs(&self) -> usize {
        match self {
            ArrayMut::StridedMatrix(x) => x.rs,
        }
    }

    pub fn cs(&self) -> usize {
        match self {
            ArrayMut::StridedMatrix(x) => x.cs,
        }
    }
}

#[derive(Clone)]
pub enum PArray<'a, X> {
    StridedMatrix(StridedMatrixP<'a, X, X>),
    PackedMatrix(PackedMatrix<X>),
}

impl<'a, X> PArray<'a, X> {
    pub fn src(&self) -> *const X {
        match self {
            Self::StridedMatrix(x) => x.src,
            Self::PackedMatrix(x) => x.src,
        }
    }

    pub fn rs(&self) -> usize {
        match self {
            Self::StridedMatrix(x) => x.rs,
            _ => {
                panic!("Only StridedMatrix has rs");
            }
        }
    }

    pub fn cs(&self) -> usize {
        match self {
            Self::StridedMatrix(x) => x.cs,
            _ => {
                panic!("Only StridedMatrix has cs");
            }
        }
    }

    pub fn dst_w(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, X> {
        match self {
            Self::StridedMatrix(x) => x.dst.write(idx, kc).unwrap(),
            _ => {
                panic!("Only StridedMatrix has write guard");
            }
        }
    }

    pub fn dst_r(&self) -> RangeLockReadGuard<'a, 'a, X> {
        match self {
            Self::StridedMatrix(x) => x.dst.read().unwrap(),
            _ => {
                panic!("Only StridedMatrix has read guard");
            }
        }
    }

    pub fn is_strided(&self) -> bool {
        match self {
            Self::StridedMatrix(_) => true,
            _ => false,
        }
    }
}

#[derive(Clone)]
pub enum PArrayMixed<'a, X, Y> {
    StridedMatrix(StridedMatrixP<'a, X, Y>),
    PackedMatrix(PackedMatrixMixed<'a, X, Y>),
}

impl<'a, X, Y> PArrayMixed<'a, X, Y> {
    pub fn src(&self) -> *const X {
        match self {
            Self::StridedMatrix(x) => x.src,
            Self::PackedMatrix(x) => x.src,
        }
    }

    pub fn rs(&self) -> usize {
        match self {
            Self::StridedMatrix(x) => x.rs,
            _ => {
                panic!("Only StridedMatrix has rs");
            }
        }
    }

    pub fn cs(&self) -> usize {
        match self {
            Self::StridedMatrix(x) => x.cs,
            _ => {
                panic!("Only StridedMatrix has cs");
            }
        }
    }

    pub fn dst_w(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, Y> {
        match self {
            Self::StridedMatrix(x) => x.dst.write(idx, kc).unwrap(),
            Self::PackedMatrix(x) => x.dst.write(idx, kc).unwrap(),
        }
    }
    pub fn dst_r(&self) -> RangeLockReadGuard<'a, 'a, Y> {
        match self {
            Self::StridedMatrix(x) => x.dst.read().unwrap(),
            Self::PackedMatrix(x) => x.dst.read().unwrap(),
        }
    }
    pub fn is_strided(&self) -> bool {
        match self {
            Self::StridedMatrix(_) => true,
            _ => false,
        }
    }
}

pub enum PtrData<'a, X> {
    RefData(RangeLockReadGuard<'a, 'a, X>),
    PtrData(*const X),
}

impl<'a, X> PtrData<'a, X> {
    pub fn src(&self) -> *const X {
        match self {
            PtrData::RefData(x) => x.get().as_ptr(),
            PtrData::PtrData(x) => x.clone(),
        }
    }
}

#[macro_export]
macro_rules! is_mixed {
    (T, $st1:expr, $st2:expr) => {
        $st1
    };
    (F, $src:expr, $st2:expr) => {
        $st2
    };
}

#[macro_export]
macro_rules! def_pa {
    ($packa_ty:tt,F,$ta:tt,$tap:tt) => {
        type $packa_ty<'a> = PArray<'a, $tap>;
    };
    ($packa_ty:tt,T,$ta:tt,$tap:tt) => {
        type $packa_ty<'a> = PArrayMixed<'a, $ta, $tap>;
    };
}
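
// Illustration (hypothetical identifiers): `def_pa!(APackTy, F, f32, f32);` expands to
// `type APackTy<'a> = PArray<'a, f32>;`, while the `T` form selects the mixed-precision
// alias `type APackTy<'a> = PArrayMixed<'a, f32, f32>;`. The same T/F flag drives
// `is_mixed!`, which keeps the first or the second expression at macro-expansion time.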

#[macro_export]
macro_rules! def_glar_gemm {
    (
        $t_dispatcher:tt,
        $ta:tt,$tap:ty,$tb:ty,$tbp:ty,$tc:ty,$t_as:ty,$t_bs:ty,
        $packa_ty:tt,$packb_ty:tt,
        $one:expr,
        $name:ident, $name_mt:ident,
        $goto_name:ident, $goto_kernel:ident,
        $small_m_name:ident, $small_m_kernel:ident,
        $small_n_name:ident, $small_n_kernel:ident,
        $gemv_name:ident, $gemv_name2:ident,
        $packa_name:ident, $packb_name:ident,
        $run_small_m:expr, $run_small_n:expr,
        $pack_fn:tt, $include_flag:tt,
    ) => {
        def_pa!($packa_ty,$include_flag,$ta,$tap);
        def_pa!($packb_ty,$include_flag,$tb,$tbp);
        pub unsafe fn $name <F:MyFn>(
            hw_config: &$t_dispatcher <F>,
            m: usize, n: usize, k: usize,
            alpha: $t_as,
            a: Array<$ta>,
            b: Array<$tb>,
            beta: $t_bs,
            c: ArrayMut<$tc>,
            par: &GlarPar,
        )
        {
            let a_need_pool = a.is_strided() || !hw_config.is_compute_native();
            let b_need_pool = b.is_strided() || !hw_config.is_compute_native();
            if n == 1 && a.is_strided() {
                let alpha = &alpha as *const $t_as;
                let beta = &beta as *const $t_bs;
                $gemv_name(hw_config, m, k, alpha, a, b, beta, c);
                return;
            }
            if m == 1 && b.is_strided() {
                let alpha = &alpha as *const $t_as;
                let beta = &beta as *const $t_bs;
                let mut a = a;
                a.transpose();
                let mut b = b;
                b.transpose();
                let mut c = c;
                c.transpose();
                $gemv_name2(hw_config, n, k, alpha.into(), b, a, beta, c);
                return;
            }
            let (gemm_mode, gemm_fun, pool_info)
            : (
                GemmPool, unsafe fn(
                    &$t_dispatcher <F>, usize, usize, usize, *const $t_as, $packa_ty, $packb_ty, *const $t_bs, ArrayMut<$tc>, &GlarThreadConfig,
                ),
                PoolSize
            )
             = if run_small_m(m) && $run_small_m && b.is_strided() {
                (GemmPool::SmallM, $small_m_name, get_mem_pool_size_small_m::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, a_need_pool))
            } else if run_small_n(n) && $run_small_n && a.is_strided() {
                (GemmPool::SmallN, $small_n_name, get_mem_pool_size_small_n::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, b_need_pool))
            } else {
                (GemmPool::Goto, $goto_name, get_mem_pool_size_goto::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, a_need_pool, b_need_pool))
            };
            let mem_pool_size = pool_info.mem_pool_size_b::<$tap,$tbp>();
            // TODO: the zero pool size case (a and b both packed) is very special to optimize; the optimization will not be worth it
            // if mem_pool_size == 0 {
            //     let mut pool_vec = [0_u8; 1];
            //     let pool_buf = &mut pool_vec;
            //     $name_mt(
            //         hw_config, m, n, k, alpha, a, b, beta, c, par, pool_buf, gemm_mode, pool_info, gemm_fun
            //     );
            //     return;
            // }
            // run goto algo
            {
                let pool_guard = PACK_POOL.buffer.read().unwrap();
                let y = acquire(&pool_guard, mem_pool_size);
                if let Some(mut pool_vec) = y {
                    let pool_buf = &mut pool_vec;
                    $name_mt(
                        hw_config, m, n, k, alpha, a, b, beta, c, par, pool_buf, gemm_mode, pool_info, gemm_fun
                    );
                    return;
                }
            }
            let mut pool_vec = vec![0_u8; mem_pool_size];
            let pool_buf = &mut pool_vec;
            $name_mt(
                hw_config, m, n, k, alpha, a, b, beta, c, par, pool_buf, gemm_mode, pool_info, gemm_fun
            );
            extend(pool_vec);
        }

        pub unsafe fn $name_mt<F:MyFn>(
            hw_config: &$t_dispatcher <F>,
            m: usize, n: usize, k: usize,
            alpha: $t_as,
            a: Array<$ta>,
            b: Array<$tb>,
            beta: $t_bs,
            c: ArrayMut<$tc>,
            par: &GlarPar,
            pool_buf: &mut [u8],
            gemm_mode: GemmPool,
            pool_info: PoolSize,
            gemm_fn: unsafe fn(
                &$t_dispatcher <F>, usize, usize, usize, *const $t_as, $packa_ty, $packb_ty, *const $t_bs, ArrayMut<$tc>, &GlarThreadConfig
            )
        )
        where $t_dispatcher <F>: GemmCache
        {

            let mc_eff = <$t_dispatcher::<F> as GemmCache>::get_mc_eff(hw_config, par.ic_par);
            let nc_eff = <$t_dispatcher::<F> as GemmCache>::get_nc_eff(hw_config, par.jc_par);
            let kc_eff = <$t_dispatcher::<F> as GemmCache>::get_kc_eff(hw_config);
            let (pa_br_vec_ref, pb_br_vec_ref) = get_apbp_barrier(par);

            let (i_load_par, j_load_par) = par.get_load_par(&gemm_mode, m, n, mc_eff, nc_eff);
            let (ap_pool_vec, bp_pool_vec) = pool_info.slice_mut_from_pool::<$tap,$tbp>(
                pool_buf, i_load_par, j_load_par, pool_info, hw_config.mr, hw_config.nr
            );
            let (ap_pool, bp_pool) = (&ap_pool_vec, &bp_pool_vec);

            // remove par.clone
            std::thread::scope(|s| {
                for t_id in 1..par.num_threads {
                    let t_cfg = GlarThreadConfig::new(
                        par.clone(), &pa_br_vec_ref, &pb_br_vec_ref, t_id, mc_eff, nc_eff, kc_eff
                    );
                    let ic_id = t_cfg.ic_id;
                    let jc_id = t_cfg.jc_id;
                    let ap_id = match gemm_mode {
                        GemmPool::Goto => ic_id,
                        GemmPool::SmallM => ic_id,
                        GemmPool::SmallN => t_id,
                    };
                    let bp_id = match gemm_mode {
                        GemmPool::Goto => jc_id,
                        GemmPool::SmallM => 0,
                        GemmPool::SmallN => jc_id,
                    };
                    let ap_cur = a.$pack_fn(ap_pool, ap_id);
                    let bp_cur = b.$pack_fn(bp_pool, bp_id);
                    let g = hw_config;
                    s.spawn(move || {
                            let alpha = &alpha as *const $t_as;
                            let beta = &beta as *const $t_bs;
                            gemm_fn(g, m, n, k, alpha, ap_cur, bp_cur, beta, c, &t_cfg);
                        }
                    );
                }
                {
                    let ap = a.$pack_fn(ap_pool, 0);
                    let bp = b.$pack_fn(bp_pool, 0);
                    let t_id: usize = 0;
                    let t_cfg = GlarThreadConfig::new(par.clone(), &pa_br_vec_ref, &pb_br_vec_ref, t_id, mc_eff, nc_eff, kc_eff);
                    let alpha = &alpha as *const $t_as;
                    let beta = &beta as *const $t_bs;
                    gemm_fn(hw_config, m, n, k, alpha, ap, bp, beta, c, &t_cfg);
                }
            });
        }

        unsafe fn $goto_name<F:MyFn>(
            hw_cfg: &$t_dispatcher <F>,
            m: usize, n: usize, k: usize,
            alpha: *const $t_as,
            a: $packa_ty,
            b: $packb_ty,
            beta: *const $t_bs,
            c: ArrayMut<$tc>,
            t_cfg: &GlarThreadConfig
        ) {
            let ic_id = t_cfg.ic_id;
            let jc_id = t_cfg.jc_id;
            let ir_id = t_cfg.ir_id;
            let jr_id = t_cfg.jr_id;
            let ir_par = t_cfg.par.ir_par;
            let jr_par = t_cfg.par.jr_par;
            let ic_par = t_cfg.par.ic_par;
            let jc_par = t_cfg.par.jc_par;
            let mc = t_cfg.mc_eff;
            let nc = t_cfg.nc_eff;
            let kc = t_cfg.kc_eff;
            let mr = hw_cfg.mr;
            let nr = hw_cfg.nr;
            let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, ic_par);
            let (nc_start, nc_end, nc_left) = split_c_range(n, nc, nr, jc_id, jc_par);
            let (kc_start, d1_end) = (0, k);
            let one = $one;
            let c_rs = c.rs();
            let c_cs = c.cs();
            let c_ptr = c.src();
            let mut mc_i = mc_start;
            while mc_i < mc_end {
                let mc_len = mc.min(mc_end - mc_i);
                let mut kc_i = kc_start;
                let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
                let mr_len = mr_end - mr_start;
                let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
                while kc_i < d1_end {
                    let kc_len = kc.min(d1_end - kc_i);
                    let kc_len_eff = hw_cfg.round_up(kc_len);
                    let mut nc_i = nc_start;
                    let kc_last = kc_i + kc_len == d1_end;
                    let kc_first = kc_i == kc_start;
                    let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
                    let ap_data = $packa_name(hw_cfg, &a, mc_i, kc_i, mc_len, kc_len, t_cfg);
                    let ap = ap_data.src();
                    let ap = ap.add(mr_start*kc_len_eff);
                    while nc_i < nc_end {
                        let nc_len = nc.min(nc_end - nc_i);
                        let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
                        let nr_len = nr_end - nr_start;
                        let c_ij = c_i.add((nc_i+nr_start) * c_cs);
                        let bp_data = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
                        let bp = bp_data.src();
                        let bp = bp.add(nr_start*kc_len_eff);
                        $goto_kernel(
                            hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t, c_ij, c_rs, c_cs,
                            ap, bp,
                            kc_last, kc_first
                        );

                        nc_i += nc;
                    }
                    if nc_left {
                        t_cfg.wait_packb();
                        t_cfg.wait_packb();
                    }
                    kc_i += kc;
                }
                mc_i += mc;
            }
            if mc_left {
                let mut kc_i = kc_start;
                while kc_i < d1_end {
                    let kc_len = kc.min(d1_end - kc_i);
                    t_cfg.wait_packa();
                    t_cfg.wait_packa();
                    let mut nc_i = nc_start;
                    while nc_i < nc_end {
                        let nc_len = nc.min(nc_end - nc_i);
                        let _ = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
                        nc_i += nc;
                    }
                    if nc_left {
                        t_cfg.wait_packb();
                        t_cfg.wait_packb();
                    }
                    kc_i += kc;
                }
            }
        }
        unsafe fn $small_m_name<F:MyFn>(
            hw_cfg: &$t_dispatcher <F>,
            m: usize, n: usize, k: usize,
            alpha: *const $t_as,
            a: $packa_ty,
            b: $packb_ty,
            beta: *const $t_bs,
            c: ArrayMut<$tc>,
            t_cfg: &GlarThreadConfig
        ) {
            let par = &t_cfg.par;
            let ic_id = t_cfg.ic_id;
            let jc_id = t_cfg.jc_id;
            let ir_id = t_cfg.ir_id;
            let ir_par = par.ir_par;
            let jr_id = t_cfg.jr_id;
            let jr_par = par.jr_par;
            let mc = t_cfg.mc_eff;
            let nc = t_cfg.nc_eff;
            let kc = t_cfg.kc_eff;
            let mr = hw_cfg.mr;
            let nr = hw_cfg.nr;
            let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, par.ic_par);
            let (nc_start, nc_end, _) = split_c_range(n, nc, nr, jc_id, par.jc_par);
            let (kc_start, kc_end) = (0, k);
            let one = $one;

            let b_ptr = b.src();
            let b_rs = b.rs();
            let b_cs = b.cs();
            let c_rs = c.rs();
            let c_cs = c.cs();
            let c_ptr = c.src();
            let mut mc_i = mc_start;
            while mc_i < mc_end {
                let mc_len = mc.min(mc_end - mc_i);
                let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
                let mr_len = mr_end - mr_start;
                let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
                let mut kc_i = kc_start;
                while kc_i < kc_end {
                    let kc_len = kc.min(kc_end - kc_i);
                    let kc_len_eff = hw_cfg.round_up(kc_len);
                    let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
                    let kc_last = kc_i + kc_len == kc_end;
                    let kc_first = kc_i == kc_start;
                    let mut nc_i = nc_start;
                    let ap_data = $packa_name(hw_cfg, &a, mc_i, kc_i, mc_len, kc_len, t_cfg);
                    let ap = ap_data.src();
                    let ap = ap.add(mr_start*kc_len_eff);
                    let b_j = b_ptr.add(kc_i * b_rs);
                    while nc_i < nc_end {
                        let nc_len = nc.min(nc_end - nc_i);
                        let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
                        let nr_len = nr_end - nr_start;
                        let c_ij = c_i.add((nc_i + nr_start) * c_cs);
                        let b_cur = b_j.add((nc_i + nr_start) * b_cs);
                        $small_m_kernel(
                            hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t,
                            b_cur, b_rs, b_cs,
                            c_ij, c_rs, c_cs,
                            ap,
                            kc_last, kc_first
                        );
                        nc_i += nc;
                    }
                    kc_i += kc;
                }
                mc_i += mc;
            }

            if mc_left {
                let mut kc_i = kc_start;
                while kc_i < kc_end {
                    t_cfg.wait_packa();
                    t_cfg.wait_packa();
                    kc_i += kc;
                }
            }
        }
        unsafe fn $small_n_name<F:MyFn>(
            hw_cfg: &$t_dispatcher <F>,
            m: usize, n: usize, k: usize,
            alpha: *const $t_as,
            a: $packa_ty,
            b: $packb_ty,
            beta: *const $t_bs,
            c: ArrayMut<$tc>,
            t_cfg: &GlarThreadConfig
        ) {
            let par = &t_cfg.par;
            let ic_id = t_cfg.ic_id;
            let jc_id = t_cfg.jc_id;
            let ir_id = t_cfg.ir_id;
            let ir_par = par.ir_par;
            let jr_id = t_cfg.jr_id;
            let jr_par = par.jr_par;
            let mc = t_cfg.mc_eff;
            let nc = t_cfg.nc_eff;
            let kc = t_cfg.kc_eff;
            let mr = hw_cfg.mr;
            let nr = hw_cfg.nr;
            let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, par.ic_par);
            let (nc_start, nc_end, nc_left) = split_c_range(n, nc, nr, jc_id, par.jc_par);
            let (kc_start, kc_end) = (0, k);
            let one = $one;

            let c_rs = c.rs();
            let c_cs = c.cs();
            let c_ptr = c.src();
            let a_ptr = a.src();
            let a_rs = a.rs();
            let a_cs = a.cs();
            // make sure this ap is the whole slice
            let mut a_dst = a.dst_w(0, kc);
            let a_dst_ref = a_dst.get();
            let a_dst_ptr = a_dst_ref.as_mut_ptr();
            let mut mc_i = mc_start;
            while mc_i < mc_end {
                let mc_len = mc.min(mc_end - mc_i);
                let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
                let mr_len = mr_end - mr_start;
                let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
                let a_i = a_ptr.add((mc_i+mr_start) * a_rs);
                let mut kc_i = kc_start;
                while kc_i < kc_end {
                    let kc_len = kc.min(kc_end - kc_i);
                    let kc_last = kc_i + kc_len == kc_end;
                    let kc_first = kc_i == kc_start;
                    let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
                    let a_cur = a_i.add(kc_i*a_cs);
                    let mut nc_i = nc_start;
                    while nc_i < nc_end {
                        let nc_len = nc.min(nc_end - nc_i);
                        let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
                        let nr_len = nr_end - nr_start;
                        let bp_data = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
                        let bp = bp_data.src();
                        let c_ij = c_i.add((nc_i + nr_start) * c_cs);
                        $small_n_kernel(
                            hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t,
                            a_cur, a_rs, a_cs,
                            a_dst_ptr, bp,
                            c_ij, c_rs, c_cs,
                            kc_last, kc_first
                        );
                        nc_i += nc;
                    }
                    if nc_left {
                        t_cfg.wait_packb();
                        t_cfg.wait_packb();
                    }
                    kc_i += kc;
                }
                mc_i += mc;
            }
            if mc_left {
                let mut kc_i = kc_start;
                while kc_i < kc_end {
                    let kc_len = kc.min(kc_end - kc_i);
                    let mut nc_i = nc_start;
                    while nc_i < nc_end {
                        let nc_len = nc.min(nc_end - nc_i);
                        let _ = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
                        nc_i += nc;
                    }
                    if nc_left {
                        t_cfg.wait_packb();
                        t_cfg.wait_packb();
                    }
                    kc_i += kc;
                }
            }
        }
1552        // for the packed api, mc_i (nc_i) must be a multiple of mr (nr), which split_c_range guarantees
1553        // for the packed api, kc_i must be a multiple of kc_eff, which always holds since we don't parallelize over kc
1554        // this would change if we ever parallelized over kc, but that is not planned
1555        // sync right before write and right before read
1556        // NOTE: don't return before the second wait_packa, as it ensures sync between threads
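        // the first wait_packa lets every thread finish reading the previous packed block before
        // it is overwritten; the second makes sure packing is complete before anyone reads it;
        // threads that skip the packing work (e.g. the mc_left loops above) still call wait_packa
        // twice per kc block so the barrier counts stay balanced across threads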
1557        pub(crate) unsafe fn $packa_name<'a,'b,F:MyFn>(hw_cfg: &$t_dispatcher <F>, x: &'b $packa_ty<'a>, mc_i: usize, kc_i: usize, mc_len: usize, kc_len: usize, t_cfg: &GlarThreadConfig) -> PtrData<'a,$tap> {
1558            t_cfg.wait_packa();
1559            let xp_ptr = match x {
1560                $packa_ty::StridedMatrix(x_i) => {
1561                    let mc_par = x_i.get_mc();
1562                    let mc_offset = mc_par * t_cfg.i_load_p_idx;
1563                    if mc_len > mc_offset {
1564                        let kc_len_ro = hw_cfg.round_up(kc_len);
1565                        let mc_len_x = (mc_len - mc_offset).min(mc_par);
1566                        let mc_i = mc_i + mc_offset;
1567                        let rs = x_i.rs();
1568                        let cs = x_i.cs();
1569                        let src_ptr = x_i.src().add(mc_i*rs + kc_i*cs);
1570                        let mut dst = x_i.dst_w(t_cfg.i_load_p_idx, kc_len_ro);
1571                        let dst_ref = dst.get();
1572                        let dst_ptr = dst_ref.as_mut_ptr();
1573                        hw_cfg.packa_fn(src_ptr, dst_ptr, mc_len_x, kc_len, rs, cs);
1574                    }
1575                    t_cfg.wait_packa();
1576                    PtrData::RefData(x_i.dst_r())
1577                }
1578                $packa_ty::PackedMatrix(x_i) => {
1579                    let vs = hw_cfg.vs;
1580                    let m_ro = (x_i.m() + vs - 1) / vs * vs;
1581                    let kc_len_ro = hw_cfg.round_up(kc_len);
1582                    let res = is_mixed!(
1583                        $include_flag,
1584                        {
1585                            let mc_par = x_i.get_mc();
1586                            let mc_offset = mc_par * t_cfg.i_load_p_idx;
1587                            let mc_len_ro = (mc_len + vs - 1) / vs * vs;
1588                            if mc_len_ro > mc_offset {
1589                                let mc_len_ro_x = (mc_len_ro - mc_offset).min(mc_par);
1590                                let mc_i = mc_i + mc_offset;
1591                                let src_ptr = x_i.src().add(mc_i*kc_len_ro + kc_i*m_ro);
1592                                let mut dst = x_i.dst_w(t_cfg.i_load_p_idx, kc_len_ro);
1593                                let dst_ref = dst.get();
1594                                let dst_ptr = dst_ref.as_mut_ptr();
1595                                hw_cfg.cvt_mixed(src_ptr, dst_ptr, mc_len_ro_x*kc_len_ro);
1596                            }
1597                            t_cfg.wait_packa();
1598                            PtrData::RefData(x_i.dst_r())
1599                        },
1600                        {
1601                            let src_ptr = x_i.src().add(mc_i*kc_len_ro + kc_i*m_ro);
1602                            t_cfg.wait_packa();
1603                            PtrData::PtrData(src_ptr)
1604                        }
1605
1606                    );
1607                    res
1608
1609                }
1610            };
1611            xp_ptr
1612        }
1613        // NOTE: don't return before the second wait_packb, as it ensures sync between threads
1614        pub(crate) unsafe fn $packb_name<'a,'b,F:MyFn>(hw_cfg: & $t_dispatcher <F>, x: &'b$packb_ty<'a>, nc_i: usize, kc_i: usize, nc_len: usize, kc_len: usize, t_cfg: &GlarThreadConfig) -> PtrData<'a,$tbp> {
1615            t_cfg.wait_packb();
1616            let xp_ptr = match x {
1617                $packb_ty::StridedMatrix(x_i) => {
1618                    let nc_par = x_i.get_mc();
1619                    let nc_offset = nc_par * t_cfg.j_load_p_idx;
1620                    if nc_len > nc_offset {
1621                        let kc_len_ro = hw_cfg.round_up(kc_len);
1622                        let nc_len_x = (nc_len - nc_offset).min(nc_par);
1623                        let nc_i = nc_i + nc_offset;
1624                        let rs = x_i.rs();
1625                        let cs = x_i.cs();
1626                        let src_ptr = x_i.src().add(kc_i*rs + nc_i*cs);
1627                        let mut dst = x_i.dst_w(t_cfg.j_load_p_idx, kc_len_ro);
1628                        let dst_ref = dst.get();
1629                        let dst_ptr = dst_ref.as_mut_ptr();
1630                        hw_cfg.packb_fn(src_ptr, dst_ptr, nc_len_x, kc_len, rs, cs);
1631                    }
1632                    t_cfg.wait_packb();
1633                    PtrData::RefData(x_i.dst_r())
1634                }
1635                $packb_ty::PackedMatrix(x_i) => {
1636                    let kc_len_ro = hw_cfg.round_up(kc_len);
1637                    let n_ro = x_i.m();
1638                    let res = is_mixed!(
1639                        $include_flag,
1640                        {
1641                            let nc_par = x_i.get_mc();
1642                            let nc_offset = nc_par * t_cfg.j_load_p_idx;
1643                            if nc_len > nc_offset {
1644                                let nc_len_x = (nc_len - nc_offset).min(nc_par);
1645                                let nc_i = nc_i + nc_offset;
1646                                let src_ptr = x_i.src().add(nc_i*kc_len_ro + kc_i*n_ro);
1647                                let mut dst = x_i.dst_w(t_cfg.j_load_p_idx, kc_len_ro);
1648                                let dst_ref = dst.get();
1649                                let dst_ptr = dst_ref.as_mut_ptr();
1650                                hw_cfg.cvt_mixed(src_ptr, dst_ptr, nc_len_x*kc_len_ro);
1651                            }
1652                            t_cfg.wait_packb();
1653                            PtrData::RefData(x_i.dst_r())
1654                        },
1655                        {
1656                            let src_ptr = x_i.src().add(nc_i*kc_len_ro + kc_i*n_ro);
1657                            t_cfg.wait_packb();
1658                            PtrData::PtrData(src_ptr)
1659                        }
1660
1661                    );
1662                    res
1663                }
1664            };
1665            xp_ptr
1666        }
1667    }
1668}
1669
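// Illustrative sketch, not part of the macros above: the mc -> kc -> nc blocked loop order
// followed by the generated small_m / small_n drivers, with threading, the split_* range
// helpers and the real packing routines stripped out. pack_a, pack_b and micro_kernel are
// hypothetical stand-ins, not functions from this crate.
#[allow(dead_code)]
fn blocked_loop_sketch(m: usize, n: usize, k: usize, mc: usize, nc: usize, kc: usize) {
    // stand-in closures: in the real drivers these are the packa/packb routines and ukernels
    let pack_a = |_mc_i: usize, _kc_i: usize, _mc_len: usize, _kc_len: usize| {};
    let pack_b = |_nc_i: usize, _kc_i: usize, _nc_len: usize, _kc_len: usize| {};
    let micro_kernel = |_mc_i: usize, _nc_i: usize, _mc_len: usize, _nc_len: usize| {};
    let mut mc_i = 0;
    while mc_i < m {
        let mc_len = mc.min(m - mc_i);
        let mut kc_i = 0;
        while kc_i < k {
            let kc_len = kc.min(k - kc_i);
            // A(mc_i.., kc_i..) is packed once per (mc, kc) block and reused for every nc block;
            // the small_m / small_n drivers above skip the B / A packing respectively
            pack_a(mc_i, kc_i, mc_len, kc_len);
            let mut nc_i = 0;
            while nc_i < n {
                let nc_len = nc.min(n - nc_i);
                // B(kc_i.., nc_i..) is packed once per (nc, kc) block, then the micro-kernel
                // updates the mc_len x nc_len tile of C
                pack_b(nc_i, kc_i, nc_len, kc_len);
                micro_kernel(mc_i, nc_i, mc_len, nc_len);
                nc_i += nc;
            }
            kc_i += kc;
        }
        mc_i += mc;
    }
}
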
1670#[macro_export]
1671macro_rules! def_kernel_bb_pf1_no_beta {
1672    (
1673        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
1674        $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt, $($mr_left:tt),*
1675    ) => {
1676        paste! {
1677            #[target_feature(enable = "avx")]
1678            pub unsafe fn kernel_bb<F: MyFn, const STRIDED: bool>(
1679                m: usize, n: usize, k: usize,
1680                alpha: *const $t_as,
1681                c: *mut $tc, c_rs: usize, c_cs: usize,
1682                ap: *const $ta, bp: *const $tb,
1683                f: F,
1684            ) {
1685                const MR: usize = $MR;
1686                const NR: usize = $NR;
1687                let m_rounded = m / MR * MR;
1688                let n_rounded = n / NR * NR;
1689                let m_left = m % MR;
1690                let n_left = n % NR;
1691
1692                let d_arr = [0, 0, c_rs, c_cs];
1693
1694                let mut m_i = 0;
1695                while m_i < m_rounded {
1696                    let c_cur0 = c.add(m_i * c_rs);
1697                    let ap_cur = ap.add(m_i * k);
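                    // element offset into packed A that the ukernel uses as its prefetch target:
                    // it starts $pf1_0 * k elements in and advances by $pf_step * k after each
                    // NR-wide block of C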
1698                    let mut a_pft1_offset = $pf1_0 * k;
1699                    let mut n_i = 0;
1700                    while n_i < n_rounded {
1701                        let bp_cur = bp.add(n_i * k);
1702                        let c_cur1 = c_cur0.add(n_i * c_cs);
1703                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, a_pft1_offset, f);
1704                        n_i += NR;
1705                        a_pft1_offset += $pf_step * k;
1706                    }
1707                    // let a_pft1_offset = ($MR+(n_iter0-n_iter)*2)*4*k;
1708                    if n_left != 0 {
1709                        let bp_cur = bp.add(n_i * k);
1710                        let c_cur1 = c_cur0.add(n_i * c_cs);
1711                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, MR, n_left, f);
1712                    }
1713                    m_i += MR;
1714                }
1715
1716                $(
1717                    if (m_left+VS-1) / VS * VS == $mr_left {
1718                        let c_cur0 = c.add(m_i * c_rs);
1719                        let ap_cur = ap.add(m_i * k);
1720                        let mut n_i = 0;
1721                        while n_i < n_rounded {
1722                            let bp_cur = bp.add(n_i * k);
1723                            let c_cur1 = c_cur0.add(n_i * c_cs);
1724                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, m_left, f);
1725                            n_i += NR;
1726                        }
1727                        if n_left !=0 {
1728                            let bp_cur = bp.add(n_i * k);
1729                            let c_cur1 = c_cur0.add(n_i * c_cs);
1730                            [<ukernel_$mr_left x n_bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, m_left, n_left, f);
1731                        }
1732                    }
1733                )*
1734
1735                asm!("vzeroupper");
1736            }
1737        }
1738    };
1739}
1740
1741#[macro_export]
1742macro_rules! def_kernel_bb_pf1 {
1743    (
1744        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
1745        $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt, $($mr_left:tt),*
1746    ) => {
1747        paste! {
1748            #[target_feature(enable = "avx")]
1749            pub unsafe fn kernel_bb<F: MyFn, const STRIDED: bool>(
1750                m: usize, n: usize, k: usize,
1751                alpha: *const $t_as,
1752                beta: *const $t_bs,
1753                c: *mut $tc, c_rs: usize, c_cs: usize,
1754                ap: *const $ta, bp: *const $tb,
1755                f: F,
1756            ) {
1757                const MR: usize = $MR;
1758                const NR: usize = $NR;
1759                let m_rounded = m / MR * MR;
1760                let n_rounded = n / NR * NR;
1761                let m_left = m % MR;
1762                let n_left = n % NR;
1763
1764                let d_arr = [0, 0, c_rs, c_cs];
1765
1766                let mut m_i = 0;
1767                while m_i < m_rounded {
1768                    let c_cur0 = c.add(m_i * c_rs);
1769                    let ap_cur = ap.add(m_i * k);
1770                    let mut a_pft1_offset = $pf1_0 * k;
1771                    let mut n_i = 0;
1772                    while n_i < n_rounded {
1773                        let bp_cur = bp.add(n_i * k);
1774                        let c_cur1 = c_cur0.add(n_i * c_cs);
1775                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, a_pft1_offset, f);
1776                        n_i += NR;
1777                        a_pft1_offset += $pf_step * k;
1778                    }
1779                    // let a_pft1_offset = ($MR+(n_iter0-n_iter)*2)*4*k;
1780                    if n_left != 0 {
1781                        let bp_cur = bp.add(n_i * k);
1782                        let c_cur1 = c_cur0.add(n_i * c_cs);
1783                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, MR, n_left, f);
1784                    }
1785                    m_i += MR;
1786                }
1787
1788
1789                $(
1790                    if (m_left+VS-1) / VS * VS == $mr_left {
1791                        let c_cur0 = c.add(m_i * c_rs);
1792                        let ap_cur = ap.add(m_i * k);
1793                        let mut n_i = 0;
1794                        while n_i < n_rounded {
1795                            let bp_cur = bp.add(n_i * k);
1796                            let c_cur1 = c_cur0.add(n_i * c_cs);
1797                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, m_left, f);
1798                            n_i += NR;
1799                        }
1800                        if n_left !=0 {
1801                            let bp_cur = bp.add(n_i * k);
1802                            let c_cur1 = c_cur0.add(n_i * c_cs);
1803                            [<ukernel_$mr_left x n_bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, m_left, n_left, f);
1804                        }
1805                    }
1806                )*
1807
1808                asm!("vzeroupper");
1809            }
1810        }
1811    };
1812}
1813
1814#[macro_export]
1815macro_rules! def_kernel_bb_v0 {
1816    (
1817        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
1818        $MR:tt, $NR:tt, $($mr_left:tt),*
1819    ) => {
1820        paste! {
1821            #[target_feature(enable = "avx")]
1822            pub unsafe fn [<kernel_$MR x $NR _bb>]<F: MyFn, const STRIDED: bool>(
1823                m: usize, n: usize, k: usize,
1824                alpha: *const $t_as,
1825                beta: *const $t_bs,
1826                c: *mut $tc, c_rs: usize, c_cs: usize,
1827                ap: *const $ta, bp: *const $tb,
1828                f: F,
1829            ) {
1830                const MR: usize = $MR;
1831                const NR: usize = $NR;
1832                let m_rounded = m / MR * MR;
1833                let n_rounded = n / NR * NR;
1834                let m_left = m % MR;
1835                let n_left = n % NR;
1836
1837                let d_arr = [0, 0, c_rs, c_cs];
1838
1839                let mut m_i = 0;
1840                while m_i < m_rounded {
1841                    let c_cur0 = c.add(m_i * c_rs);
1842                    let ap_cur = ap.add(m_i * k);
1843                    let mut n_i = 0;
1844                    while n_i < n_rounded {
1845                        let bp_cur = bp.add(n_i * k);
1846                        let c_cur1 = c_cur0.add(n_i * c_cs);
1847                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, MR, f);
1848                        n_i += NR;
1849                    }
1850                    if n_left != 0 {
1851                        let bp_cur = bp.add(n_i * k);
1852                        let c_cur1 = c_cur0.add(n_i * c_cs);
1853                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, MR, n_left, f);
1854                    }
1855                    m_i += MR;
1856                }
1857
1858                $(
1859                    if (m_left+VS-1) / VS * VS == $mr_left {
1860                        let c_cur0 = c.add(m_i * c_rs);
1861                        let ap_cur = ap.add(m_i * k);
1862                        let mut n_i = 0;
1863                        while n_i < n_rounded {
1864                            let bp_cur = bp.add(n_i * k);
1865                            let c_cur1 = c_cur0.add(n_i * c_cs);
1866                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, m_left, f);
1867                            n_i += NR;
1868                        }
1869                        if n_left !=0 {
1870                            let bp_cur = bp.add(n_i * k);
1871                            let c_cur1 = c_cur0.add(n_i * c_cs);
1872                            [<ukernel_$mr_left x n_bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, m_left, n_left, f);
1873                        }
1874                    }
1875                )*
1876
1877                asm!("vzeroupper");
1878            }
1879        }
1880    };
1881}
1882
1883#[macro_export]
1884macro_rules! def_kernel_bb_v0_no_beta {
1885    (
1886        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
1887        $MR:tt, $NR:tt, $($mr_left:tt),*
1888    ) => {
1889        paste! {
1890            #[target_feature(enable = "avx")]
1891            pub unsafe fn [<kernel_$MR x $NR _bb>]<F: MyFn, const STRIDED: bool>(
1892                m: usize, n: usize, k: usize,
1893                alpha: *const $t_as,
1894                c: *mut $tc, c_rs: usize, c_cs: usize,
1895                ap: *const $ta, bp: *const $tb,
1896                f: F,
1897            ) {
1898                const MR: usize = $MR;
1899                const NR: usize = $NR;
1900                let m_rounded = m / MR * MR;
1901                let n_rounded = n / NR * NR;
1902                let m_left = m % MR;
1903                let n_left = n % NR;
1904
1905                let d_arr = [0, 0, c_rs, c_cs];
1906
1907                let mut m_i = 0;
1908                while m_i < m_rounded {
1909                    let c_cur0 = c.add(m_i * c_rs);
1910                    let ap_cur = ap.add(m_i * k);
1911                    let mut n_i = 0;
1912                    while n_i < n_rounded {
1913                        let bp_cur = bp.add(n_i * k);
1914                        let c_cur1 = c_cur0.add(n_i * c_cs);
1915                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, MR, f);
1916                        n_i += NR;
1917                    }
1918                    if n_left != 0 {
1919                        let bp_cur = bp.add(n_i * k);
1920                        let c_cur1 = c_cur0.add(n_i * c_cs);
1921                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, MR, n_left, f);
1922                    }
1923                    m_i += MR;
1924                }
1925
1926
1927                $(
1928                    if (m_left+VS-1) / VS * VS == $mr_left {
1929                        let c_cur0 = c.add(m_i * c_rs);
1930                        let ap_cur = ap.add(m_i * k);
1931                        let mut n_i = 0;
1932                        while n_i < n_rounded {
1933                            let bp_cur = bp.add(n_i * k);
1934                            let c_cur1 = c_cur0.add(n_i * c_cs);
1935                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, m_left, f);
1936                            n_i += NR;
1937                        }
1938                        if n_left !=0 {
1939                            let bp_cur = bp.add(n_i * k);
1940                            let c_cur1 = c_cur0.add(n_i * c_cs);
1941                            [<ukernel_$mr_left x n_bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, m_left, n_left, f);
1942                        }
1943                    }
1944                )*
1945
1946                asm!("vzeroupper");
1947            }
1948        }
1949    };
1950}
1951
1952#[macro_export]
1953macro_rules! def_kernel_sb_pf1 {
1954    (
1955        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
1956        $RS:tt,
1957        $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt, $($mr_left:tt),*
1958    ) => {
1959        paste! {
1960            #[target_feature(enable = "avx")]
1961            pub unsafe fn [<kernel_$MR x $NR _sb_v0>]<F: MyFn, const STRIDED: bool>(
1962                m: usize, n: usize, k: usize,
1963                alpha: *const $t_as, beta: *const $t_bs,
1964                a: *const $ta, a_rs: usize, a_cs: usize,
1965                bp: *const $tb,
1966                c: *mut $tc, c_rs: usize, c_cs: usize,
1967                ap: *mut $ta,
1968                f: F,
1969            ) {
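                // packed B panels are stored with k rounded up to a multiple of $RS, so the same
                // rounding is applied here to keep the n_i * k_eff offsets into bp aligned with
                // that padded layout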
1970                let k_eff = (k+$RS-1) / $RS * $RS;
1971                const MR: usize = $MR;
1972                const NR: usize = $NR;
1973                let m_rounded = m / MR * MR;
1974                let n_rounded = n / NR * NR;
1975                let m_left = m % MR;
1976                let n_left = n % NR;
1977
1978                let d_arr = [0, 0, c_rs, c_cs];
1979
1980                let mut m_i = 0;
1981                while m_i < m_rounded {
1982                    let c_cur0 = c.add(m_i * c_rs);
1983                    let a_cur = a.add(m_i * a_rs);
1984                    let a_pft1_offset = $pf1_0 * k;
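                    // pack the current MR x k panel of the strided A into ap; the packed panel
                    // is then reused for every NR block of columns below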
1985                    [<packa_panel_$MR>](MR, k, a_cur, a_rs, a_cs, ap, VS);
1986                    let mut n_i = 0;
1987                    while n_i < n_rounded {
1988                        let bp_cur = bp.add(n_i * k_eff);
1989                        let c_cur1 = c_cur0.add(n_i * c_cs);
1990                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, a_pft1_offset, f);
1991                        n_i += NR;
1992                    }
1993                    if n_left != 0 {
1994                        let bp_cur = bp.add(n_i * k_eff);
1995                        let c_cur1 = c_cur0.add(n_i * c_cs);
1996                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, MR, n_left, f);
1997                    }
1998                    m_i += MR;
1999                }
2000
2001                $(
2002                    if (m_left+VS-1) / VS *VS == $mr_left {
2003                        let c_cur0 = c.add(m_i * c_rs);
2004                        let a_cur = a.add(m_i * a_rs);
2005                        [<packa_panel_ $MR>](m_left, k, a_cur, a_rs, a_cs, ap, VS);
2006                        let mut n_i = 0;
2007                        while n_i < n_rounded {
2008                            let bp_cur = bp.add(n_i * k_eff);
2009                            let c_cur1 = c_cur0.add(n_i * c_cs);
2010                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, m_left, f);
2011                            n_i += NR;
2012                        }
2013                        if n_left != 0 {
2014                            let bp_cur = bp.add(n_i * k_eff);
2015                            let c_cur1 = c_cur0.add(n_i * c_cs);
2016                            [<ukernel_$mr_left xn_bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, m_left, n_left, f);
2017                        }
2018                        return;
2019                    }
2020                )*
2021
2022                asm!("vzeroupper");
2023            }
2024        }
2025    };
2026}
2027
2028#[macro_export]
2029macro_rules! def_kernel_sb_v0 {
2030    (
2031        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
2032        $RS:tt,
2033        $MR:tt, $NR:tt, $($mr_left:tt),*
2034    ) => {
2035        paste! {
2036            #[target_feature(enable = "avx")]
2037            pub unsafe fn [<kernel_$MR x $NR _sb_v0>]<F: MyFn, const STRIDED: bool>(
2038                m: usize, n: usize, k: usize,
2039                alpha: *const $t_as, beta: *const $t_bs,
2040                a: *const $ta, a_rs: usize, a_cs: usize,
2041                bp: *const $tb,
2042                c: *mut $tc, c_rs: usize, c_cs: usize,
2043                ap: *mut $ta,
2044                f: F,
2045            ) {
2046                let k_eff = (k+$RS-1) / $RS * $RS;
2047                const MR: usize = $MR;
2048                const NR: usize = $NR;
2049                let m_rounded = m / MR * MR;
2050                let n_rounded = n / NR * NR;
2051                let m_left = m % MR;
2052                let n_left = n % NR;
2053
2054                let d_arr = [0, 0, c_rs, c_cs];
2055
2056                let mut m_i = 0;
2057                while m_i < m_rounded {
2058                    let c_cur0 = c.add(m_i * c_rs);
2059                    let a_cur = a.add(m_i * a_rs);
2060                    [<packa_panel_$MR>](MR, k, a_cur, a_rs, a_cs, ap, VS);
2061                    let mut n_i = 0;
2062                    while n_i < n_rounded {
2063                        let bp_cur = bp.add(n_i * k_eff);
2064                        let c_cur1 = c_cur0.add(n_i * c_cs);
2065                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, MR, f);
2066                        n_i += NR;
2067                    }
2068                    if n_left != 0 {
2069                        let bp_cur = bp.add(n_i * k_eff);
2070                        let c_cur1 = c_cur0.add(n_i * c_cs);
2071                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, MR, n_left, f);
2072                    }
2073                    m_i += MR;
2074                }
2075
2076                $(
2077                    if (m_left+VS-1) / VS *VS == $mr_left {
2078                        let c_cur0 = c.add(m_i * c_rs);
2079                        let a_cur = a.add(m_i * a_rs);
2080                        [<packa_panel_ $MR>](m_left, k, a_cur, a_rs, a_cs, ap, VS);
2081                        let mut n_i = 0;
2082                        while n_i < n_rounded {
2083                            let bp_cur = bp.add(n_i * k_eff);
2084                            let c_cur1 = c_cur0.add(n_i * c_cs);
2085                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, m_left, f);
2086                            n_i += NR;
2087                        }
2088                        if n_left != 0 {
2089                            let bp_cur = bp.add(n_i * k_eff);
2090                            let c_cur1 = c_cur0.add(n_i * c_cs);
2091                            [<ukernel_$mr_left xn_bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, m_left, n_left, f);
2092                        }
2093                        return;
2094                    }
2095                )*
2096
2097                asm!("vzeroupper");
2098            }
2099        }
2100    };
2101}
2102
2103#[macro_export]
2104macro_rules! def_kernel_sb_v0_no_beta {
2105    (
2106        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
2107        $MR:tt, $NR:tt, $($mr_left:tt),*
2108    ) => {
2109        paste! {
2110            #[target_feature(enable = "avx")]
2111            pub unsafe fn [<kernel_$MR x $NR _sb_v0>]<F: MyFn, const STRIDED: bool>(
2112                m: usize, n: usize, k: usize,
2113                alpha: *const $t_as,
2114                a: *const $ta, a_rs: usize, a_cs: usize,
2115                bp: *const $tb,
2116                c: *mut $tc, c_rs: usize, c_cs: usize,
2117                ap: *mut $ta,
2118                f: F,
2119            ) {
2120                const MR: usize = $MR;
2121                const NR: usize = $NR;
2122                let m_rounded = m / MR * MR;
2123                let n_rounded = n / NR * NR;
2124                let m_left = m % MR;
2125                let n_left = n % NR;
2126
2127                let d_arr = [0, 0, c_rs, c_cs];
2128
2129                let mut m_i = 0;
2130                while m_i < m_rounded {
2131                    let c_cur0 = c.add(m_i * c_rs);
2132                    let a_cur = a.add(m_i * a_rs);
2133                    [<packa_panel_$MR>](MR, k, a_cur, a_rs, a_cs, ap, VS);
2134                    let mut n_i = 0;
2135                    while n_i < n_rounded {
2136                        let bp_cur = bp.add(n_i * k);
2137                        let c_cur1 = c_cur0.add(n_i * c_cs);
2138                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, MR, f);
2139                        n_i += NR;
2140                    }
2141                    if n_left != 0 {
2142                        let bp_cur = bp.add(n_i * k);
2143                        let c_cur1 = c_cur0.add(n_i * c_cs);
2144                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, MR, n_left, f);
2145                    }
2146                    m_i += MR;
2147                }
2148
2149                $(
2150                    if (m_left+VS-1) / VS *VS == $mr_left {
2151                        let c_cur0 = c.add(m_i * c_rs);
2152                        let a_cur = a.add(m_i * a_rs);
2153                        [<packa_panel_ $MR>](m_left, k, a_cur, a_rs, a_cs, ap, VS);
2154                        let mut n_i = 0;
2155                        while n_i < n_rounded {
2156                            let bp_cur = bp.add(n_i * k);
2157                            let c_cur1 = c_cur0.add(n_i * c_cs);
2158                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, m_left, f);
2159                            n_i += NR;
2160                        }
2161                        if n_left != 0 {
2162                            let bp_cur = bp.add(n_i * k);
2163                            let c_cur1 = c_cur0.add(n_i * c_cs);
2164                            [<ukernel_$mr_left xn_bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, m_left, n_left, f);
2165                        }
2166                        return;
2167                    }
2168                )*
2169
2170                asm!("vzeroupper");
2171            }
2172        }
2173    };
2174}
2175
2176#[macro_export]
2177macro_rules! def_kernel_sb_pf1_no_beta {
2178    (
2179        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
2180        $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt, $($mr_left:tt),*
2181    ) => {
2182        paste! {
2183            #[target_feature(enable = "avx")]
2184            pub unsafe fn [<kernel_$MR x $NR _sb_v0>]<F: MyFn, const STRIDED: bool>(
2185                m: usize, n: usize, k: usize,
2186                alpha: *const $t_as,
2187                a: *const $ta, a_rs: usize, a_cs: usize,
2188                bp: *const $tb,
2189                c: *mut $tc, c_rs: usize, c_cs: usize,
2190                ap: *mut $ta,
2191                f: F,
2192            ) {
2193                const MR: usize = $MR;
2194                const NR: usize = $NR;
2195                let m_rounded = m / MR * MR;
2196                let n_rounded = n / NR * NR;
2197                let m_left = m % MR;
2198                let n_left = n % NR;
2199
2200                let d_arr = [0, 0, c_rs, c_cs];
2201
2202                let mut m_i = 0;
2203                while m_i < m_rounded {
2204                    let c_cur0 = c.add(m_i * c_rs);
2205                    let a_cur = a.add(m_i * a_rs);
2206                    let a_pft1_offset = $pf1_0 * k;
2207                    [<packa_panel_$MR>](MR, k, a_cur, a_rs, a_cs, ap, VS);
2208                    let mut n_i = 0;
2209                    while n_i < n_rounded {
2210                        let bp_cur = bp.add(n_i * k);
2211                        let c_cur1 = c_cur0.add(n_i * c_cs);
2212                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, a_pft1_offset, f);
2213                        n_i += NR;
2214                    }
2215                    if n_left != 0 {
2216                        let bp_cur = bp.add(n_i * k);
2217                        let c_cur1 = c_cur0.add(n_i * c_cs);
2218                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, MR, n_left, f);
2219                    }
2220                    m_i += MR;
2221                }
2222
2223                $(
2224                    if (m_left+VS-1) / VS *VS == $mr_left {
2225                        let c_cur0 = c.add(m_i * c_rs);
2226                        let a_cur = a.add(m_i * a_rs);
2227                        [<packa_panel_ $MR>](m_left, k, a_cur, a_rs, a_cs, ap, VS);
2228                        let mut n_i = 0;
2229                        while n_i < n_rounded {
2230                            let bp_cur = bp.add(n_i * k);
2231                            let c_cur1 = c_cur0.add(n_i * c_cs);
2232                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, m_left, f);
2233                            n_i += NR;
2234                        }
2235                        if n_left != 0 {
2236                            let bp_cur = bp.add(n_i * k);
2237                            let c_cur1 = c_cur0.add(n_i * c_cs);
2238                            [<ukernel_$mr_left xn_bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, m_left, n_left, f);
2239                        }
2240                        return;
2241                    }
2242                )*
2243
2244                asm!("vzeroupper");
2245            }
2246        }
2247    };
2248}
2249
2250#[macro_export]
2251macro_rules! def_kernel_bs_no_beta {
2252    (
2253        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
2254        $MR:tt, $NR:tt, $($mr_left:tt),*
2255    ) => {
2256        paste! {
2257            #[target_feature(enable = "avx")]
2258            pub unsafe fn [<kernel_$MR x $NR _bs_v0>]<F: MyFn, const STRIDED: bool>(
2259                m: usize, n: usize, k: usize,
2260                alpha: *const $t_as,
2261                b: *const $tb, b_rs: usize, b_cs: usize,
2262                c: *mut $tc, c_rs: usize, c_cs: usize,
2263                ap: *const $ta,
2264                f: F,
2265            ) {
2266                const MR: usize = $MR;
2267                const NR: usize = $NR;
2268                let m_rounded = m / MR * MR;
2269                let n_rounded = n / NR * NR;
2270                let m_left = m % MR;
2271                let n_left = n % NR;
2272
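                // B is consumed in strided form by the _bs micro-kernels, so its strides are
                // passed along in d_arr next to C's (the packed-B kernels pass zeros there)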
2273                let d_arr = [b_rs, b_cs, c_rs, c_cs];
2274
2275                let mut m_i = 0;
2276                while m_i < m_rounded {
2277                    let c_cur0 = c.add(m_i * c_rs);
2278                    let ap_cur = ap.add(m_i * k);
2279                    let mut n_i = 0;
2280                    while n_i < n_rounded {
2281                        let b_cur = b.add(n_i * b_cs);
2282                        let c_cur1 = c_cur0.add(n_i * c_cs);
2283                        [<ukernel_$MR x $NR _bs>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, k, d_arr, MR, f);
2284                        n_i += NR;
2285                    }
2286                    if n_left != 0 {
2287                        let b_cur = b.add(n_i * b_cs); // B is strided here: advance by column stride, as in the other _bs paths
2288                        let c_cur1 = c_cur0.add(n_i * c_cs);
2289                        [<ukernel_$MR xn_bs>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, k, d_arr, MR, n_left, f);
2290                    }
2291                    m_i += MR;
2292                }
2293
2294                $(
2295                    if (m_left+VS-1) / VS * VS == $mr_left {
2296                        let c_cur0 = c.add(m_i * c_rs);
2297                        let ap_cur = ap.add(m_i * k);
2298                        let mut n_i = 0;
2299                        while n_i < n_rounded {
2300                            let b_cur = b.add(n_i * b_cs);
2301                            let c_cur1 = c_cur0.add(n_i * c_cs);
2302                            [<ukernel_$mr_left x $NR _bs_partial>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, k, d_arr, m_left, f);
2303                            n_i += NR;
2304                        }
2305                        if n_left != 0 {
2306                            let b_cur = b.add(n_i * b_cs);
2307                            let c_cur1 = c_cur0.add(n_i * c_cs);
2308                            [<ukernel_$mr_left xn_bs_partial>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, k, d_arr, m_left, n_left, f);
2309                        }
2310                        return;
2311                    }
2312                )*
2313
2314                asm!("vzeroupper");
2315            }
2316        }
2317    };
2318}
2319
2320#[macro_export]
2321macro_rules! def_kernel_bs {
2322    (
2323        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
2324        $MR:tt, $NR:tt, $($mr_left:tt),*
2325    ) => {
2326        paste! {
2327            #[target_feature(enable = "avx")]
2328            pub unsafe fn [<kernel_$MR x $NR _bs_v0>]<F: MyFn, const STRIDED: bool>(
2329                m: usize, n: usize, k: usize,
2330                alpha: *const $t_as, beta: *const $t_bs,
2331                b: *const $tb, b_rs: usize, b_cs: usize,
2332                c: *mut $tc, c_rs: usize, c_cs: usize,
2333                ap: *const $ta,
2334                f: F,
2335            ) {
2336                const MR: usize = $MR;
2337                const NR: usize = $NR;
2338                let m_rounded = m / MR * MR;
2339                let n_rounded = n / NR * NR;
2340                let m_left = m % MR;
2341                let n_left = n % NR;
2342
2343                let d_arr = [b_rs, b_cs, c_rs, c_cs];
2344
2345                let mut m_i = 0;
2346                while m_i < m_rounded {
2347                    let c_cur0 = c.add(m_i * c_rs);
2348                    let ap_cur = ap.add(m_i * k);
2349                    let mut n_i = 0;
2350                    while n_i < n_rounded {
2351                        let b_cur = b.add(n_i * b_cs);
2352                        let c_cur1 = c_cur0.add(n_i * c_cs);
2353                        [<ukernel_$MR x $NR _bs>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, MR, f);
2354                        n_i += NR;
2355                    }
2356                    if n_left != 0 {
2357                        let b_cur = b.add(n_i * b_cs);
2358                        let c_cur1 = c_cur0.add(n_i * c_cs);
2359                        [<ukernel_$MR xn_bs>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, MR, n_left, f);
2360                    }
2361                    m_i += MR;
2362                }
2363
2364                $(
2365                    if (m_left+VS-1) / VS * VS == $mr_left {
2366                        let c_cur0 = c.add(m_i * c_rs);
2367                        let ap_cur = ap.add(m_i * k);
2368                        let mut n_i = 0;
2369                        while n_i < n_rounded {
2370                            let b_cur = b.add(n_i * b_cs);
2371                            let c_cur1 = c_cur0.add(n_i * c_cs);
2372                            [<ukernel_$mr_left x $NR _bs_partial>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, m_left, f);
2373                            n_i += NR;
2374                        }
2375                        if n_left != 0 {
2376                            let b_cur = b.add(n_i * b_cs);
2377                            let c_cur1 = c_cur0.add(n_i * c_cs);
2378                            [<ukernel_$mr_left xn_bs_partial>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, m_left, n_left, f);
2379                        }
2380                        return;
2381                    }
2382                )*
2383
2384                asm!("vzeroupper");
2385            }
2386        }
2387    };
2388}
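
// c_mem! maps a small constant index to the x86-64 memory operand the inline-asm
// micro-kernels use to address C: {cx}, {x1}, {x2}, {x3} and {x4} hold pointers spaced
// three entries apart, {x0} holds the stride, and each group of three is reached as
// base, base + {x0} and base + 2*{x0}; e.g. c_mem!(4) expands to "0({x1}, {x0})".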
2389#[macro_export]
2390macro_rules! c_mem {
2391    (0) => {
2392        "0({cx})"
2393    };
2394    (1) => {
2395        "0({cx}, {x0})"
2396    };
2397    (2) => {
2398        "0({cx}, {x0}, 2)"
2399    };
2400    (3) => {
2401        "0({x1})"
2402    };
2403    (4) => {
2404        "0({x1}, {x0})"
2405    };
2406    (5) => {
2407        "0({x1}, {x0}, 2)"
2408    };
2409    (6) => {
2410        "0({x2})"
2411    };
2412    (7) => {
2413        "0({x2}, {x0})"
2414    };
2415    (8) => {
2416        "0({x2}, {x0}, 2)"
2417    };
2418    (9) => {
2419        "0({x3})"
2420    };
2421    (10) => {
2422        "0({x3}, {x0})"
2423    };
2424    (11) => {
2425        "0({x3}, {x0}, 2)"
2426    };
2427    (12) => {
2428        "0({x4})"
2429    };
2430    (13) => {
2431        "0({x4}, {x0})"
2432    };
2433    (14) => {
2434        "0({x4}, {x0}, 2)"
2435    };
2436}
2437
2438#[cfg(test)]
2439mod test {
2440    // prints per-thread split_c_range output; run with `cargo test -- --nocapture` to see it
2441    #[test]
2442    fn test_split_c_range() {
2443        let m = 143;
2444        let mc = 4800;
2445        let mr = 24;
2446        let ic_par = 4;
2447        for ic_id in 0..ic_par {
2448            let (mc_start, mc_end, mc_left) = super::split_c_range(m, mc, mr, ic_id, ic_par);
2449            println!("mc_start: {}, mc_end: {}, mc_left: {}", mc_start, mc_end, mc_left);
2450        }
2451    }
2452}