pire_base/lib.rs

//! # This crate is only for internal use in the pire project
//! Nothing here is expected to be used outside the pire project
//! No semver guarantees

use core::mem::size_of;
use once_cell::sync::Lazy;
use std::sync::{Barrier, Mutex, MutexGuard, RwLock, RwLockReadGuard};

pub mod range_rwlock;

#[derive(Copy, Clone)]
pub struct IdentityFn;

pub trait UnaryFn<T>: Copy + std::marker::Sync {
    unsafe fn call(self, c: *mut T, m: usize);
}

impl<T> UnaryFn<T> for IdentityFn {
    #[inline(always)]
    unsafe fn call(self, _c: *mut T, _m: usize) {}
}

impl<T> UnaryFn<T> for unsafe fn(*mut T, m: usize) {
    #[inline(always)]
    unsafe fn call(self, c: *mut T, m: usize) {
        self(c, m);
    }
}
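
// Illustrative sketch (not part of the original source): thanks to the fn-pointer
// impl above, any plain `unsafe fn(*mut T, usize)` can be passed around as a
// `UnaryFn` post-op, e.g. an in-place ReLU over `m` elements.
#[test]
fn unary_fn_pointer_example() {
    unsafe fn relu_f32(c: *mut f32, m: usize) {
        for i in 0..m {
            let v = c.add(i);
            *v = (*v).max(0.0);
        }
    }
    let mut buf = [-1.0f32, 2.0, -3.0];
    let f: unsafe fn(*mut f32, usize) = relu_f32;
    unsafe { f.call(buf.as_mut_ptr(), buf.len()) };
    assert_eq!(buf, [0.0, 2.0, 0.0]);
}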

#[inline(always)]
pub unsafe fn load_buf<T: Copy>(c: *const T, c_rs: usize, c_cs: usize, c_buf: &mut [T], m: usize, n: usize, mr: usize) {
    for j in 0..n {
        for i in 0..m {
            c_buf[i + j * mr] = *c.add(i * c_rs + j * c_cs);
        }
    }
}

#[inline(always)]
pub unsafe fn store_buf<T: Copy>(c: *mut T, c_rs: usize, c_cs: usize, c_buf: &[T], m: usize, n: usize, mr: usize) {
    for j in 0..n {
        for i in 0..m {
            *c.add(i * c_rs + j * c_cs) = c_buf[i + j * mr];
        }
    }
}
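
// A small illustrative check (not part of the original source): `load_buf` stages an
// m x n tile from a strided matrix into a column-major scratch buffer whose column
// stride is `mr`, and `store_buf` writes it back out.
#[test]
fn load_store_buf_roundtrip() {
    // 2x2 row-major tile (rs = 2, cs = 1) staged through a buffer with mr = 4.
    let c = [1.0f32, 2.0, 3.0, 4.0];
    let mut c_buf = [0.0f32; 8];
    let mut c_out = [0.0f32; 4];
    unsafe {
        load_buf(c.as_ptr(), 2, 1, &mut c_buf, 2, 2, 4);
        store_buf(c_out.as_mut_ptr(), 2, 1, &c_buf, 2, 2, 4);
    }
    assert_eq!(c, c_out);
}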

pub fn matrix_size(m: usize, n: usize) -> usize {
    n * m
}

use range_rwlock::{RangeLock, RangeLockReadGuard, RangeLockWriteGuard};

#[cfg(target_arch = "x86_64")]
#[derive(Copy, Clone)]
pub struct CpuFeatures {
    pub sse: bool,
    pub sse2: bool,
    pub sse3: bool,
    pub ssse3: bool,
    pub avx: bool,
    pub avx2: bool,
    pub avx512f: bool,
    pub avx512f16: bool,
    // pub avx512bf16: bool,
    pub avx512bw: bool,
    pub avx512_vnni: bool,
    pub fma: bool,
    pub fma4: bool,
    pub f16c: bool,
}

#[cfg(target_arch = "x86")]
#[derive(Copy, Clone)]
pub struct CpuFeatures {
    pub sse: bool,
    pub sse2: bool,
    pub sse3: bool,
    pub ssse3: bool,
}

#[cfg(target_arch = "aarch64")]
#[derive(Copy, Clone)]
pub struct CpuFeatures {
    pub sve: bool,
    pub neon: bool,
    pub fp16: bool,
    pub f32mm: bool,
    pub fcma: bool,
    pub i8mm: bool,
}

#[cfg(not(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64")))]
#[derive(Copy, Clone)]
pub struct CpuFeatures {
    pub dummy: bool,
}

// padding in bytes
const CACHELINE_PAD: usize = 1024;

pub struct HWConfig {
    pub cpu_ft: CpuFeatures,
    pub hw_model: HWModel,
    is_l1_shared: bool,
    is_l2_shared: bool,
    is_l3_shared: bool,
}

impl HWConfig {
    pub fn get_cache_info(&self) -> (bool, bool, bool) {
        (self.is_l1_shared, self.is_l2_shared, self.is_l3_shared)
    }
    pub fn hw_model(&self) -> HWModel {
        self.hw_model
    }

    pub fn cpu_ft(&self) -> CpuFeatures {
        self.cpu_ft
    }
}

#[derive(Copy, Clone)]
pub enum HWModel {
    Reference,
    Haswell,
    Skylake,
}

const SKYLAKE: [u8; 13] = [78, 85, 94, 126, 140, 141, 167, 151, 154, 183, 186, 143, 207];

const HASWELL: [u8; 10] = [69, 70, 63, 42, 58, 165, 79, 86, 61, 71];

impl HWModel {
    pub fn from_hw(family_id: u8, model_id: u8) -> Self {
        if family_id == 6 {
            if SKYLAKE.contains(&model_id) {
                return HWModel::Skylake;
            }
            if HASWELL.contains(&model_id) {
                return HWModel::Haswell;
            }
        }

        // default to Reference
        return HWModel::Reference;
    }
    pub fn get_cache_info(&self) -> (bool, bool, bool) {
        match self {
            HWModel::Reference => (false, false, true),
            HWModel::Haswell => (false, false, true),
            HWModel::Skylake => (false, false, true),
        }
    }
}
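
// Illustrative check (not part of the original source): family 6 / model 94 is in the
// Skylake-class id list above, model 63 is Haswell-class, and anything unknown falls
// back to Reference.
#[test]
fn hw_model_from_hw_example() {
    assert!(matches!(HWModel::from_hw(6, 94), HWModel::Skylake));
    assert!(matches!(HWModel::from_hw(6, 63), HWModel::Haswell));
    assert!(matches!(HWModel::from_hw(15, 1), HWModel::Reference));
}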

// Use family and model id instead of cache size parameters,
// since the relation between optimal parameters (based on performance) and cache size parameters can be non-trivial,
// e.g. it might be cpu model dependent.

#[inline]
fn detect_hw_config() -> HWConfig {
    #[cfg(target_arch = "x86_64")]
    {
        let cpuid = raw_cpuid::CpuId::new();
        let feature_info = cpuid.get_feature_info().unwrap();
        let extended_feature_info = cpuid.get_extended_feature_info().unwrap();
        let sse = feature_info.has_sse();
        let sse2 = feature_info.has_sse2();
        let sse3 = feature_info.has_sse3();
        let ssse3 = feature_info.has_ssse3();
        let avx = feature_info.has_avx();
        let fma = feature_info.has_fma();
        let avx2 = extended_feature_info.has_avx2();
        let avx512f16 = extended_feature_info.has_avx512_fp16();
        // let avx512bf16 = extended_feature_info.has_avx512_bf16();
        let avx512f = extended_feature_info.has_avx512f();
        let avx512bw = extended_feature_info.has_avx512bw();
        let avx512_vnni = extended_feature_info.has_avx512vnni();
        let f16c = feature_info.has_f16c();
        let extended_processor_info = cpuid.get_extended_processor_and_feature_identifiers().unwrap();
        let fma4 = extended_processor_info.has_fma4();
        let cpu_ft = CpuFeatures {
            sse,
            sse2,
            sse3,
            ssse3,
            avx,
            avx2,
            avx512f,
            avx512f16,
            avx512bw,
            avx512_vnni,
            fma,
            fma4,
            f16c,
        };
        let family_id = feature_info.family_id();
        let model_id = feature_info.model_id();
        let hw_model = HWModel::from_hw(family_id, model_id);
        let (is_l1_shared, is_l2_shared, is_l3_shared) = hw_model.get_cache_info();
        return HWConfig { cpu_ft, hw_model, is_l1_shared, is_l2_shared, is_l3_shared };
    }
    #[cfg(target_arch = "x86")]
    {
        let cpuid = raw_cpuid::CpuId::new();
        let feature_info = cpuid.get_feature_info().unwrap();
        let sse = feature_info.has_sse();
        let sse2 = feature_info.has_sse2();
        let sse3 = feature_info.has_sse3();
        let ssse3 = feature_info.has_ssse3();
        let cpu_ft = CpuFeatures { sse, sse2, sse3, ssse3 };
        let family_id = feature_info.family_id();
        let model_id = feature_info.model_id();
        let hw_model = HWModel::from_hw(family_id, model_id);
        let (is_l1_shared, is_l2_shared, is_l3_shared) = hw_model.get_cache_info();
        return HWConfig { cpu_ft, hw_model, is_l1_shared, is_l2_shared, is_l3_shared };
    }
    #[cfg(target_arch = "aarch64")]
    {
        use std::arch::is_aarch64_feature_detected;
        let neon = is_aarch64_feature_detected!("neon");
        let sve = is_aarch64_feature_detected!("sve");
        let fp16 = is_aarch64_feature_detected!("fp16");
        let f32mm = is_aarch64_feature_detected!("f32mm");
        let fcma = is_aarch64_feature_detected!("fcma");
        let i8mm = is_aarch64_feature_detected!("i8mm");

        return HWConfig {
            cpu_ft: CpuFeatures { neon, sve, fp16, f32mm, fcma, i8mm },
            hw_model: HWModel::Reference,
            is_l1_shared: false,
            is_l2_shared: false,
            is_l3_shared: true,
        };
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64")))]
    {
        return HWConfig {
            cpu_ft: CpuFeatures { dummy: false },
            hw_model: HWModel::Reference,
            is_l1_shared: false,
            is_l2_shared: false,
            is_l3_shared: true,
        };
    }
}

#[cfg(feature = "debug_cpu_features")]
#[allow(unused)]
fn apply_debug_cpu_features(cpu_ft: &mut CpuFeatures) {
    #[cfg(target_arch = "x86_64")]
    {
        let sse_turn_off = std::env::var("PIRE_SSE_OFF").is_ok();
        let sse2_turn_off = std::env::var("PIRE_SSE2_OFF").is_ok();
        let sse3_turn_off = std::env::var("PIRE_SSE3_OFF").is_ok();
        let ssse3_turn_off = std::env::var("PIRE_SSSE3_OFF").is_ok();
        let avx_turn_off = std::env::var("PIRE_AVX_OFF").is_ok();
        let avx2_turn_off = std::env::var("PIRE_AVX2_OFF").is_ok();
        let avx512f_turn_off = std::env::var("PIRE_AVX512F_OFF").is_ok();
        let avx512f16_turn_off = std::env::var("PIRE_AVX512F16_OFF").is_ok();
        let avx512bw_turn_off = std::env::var("PIRE_AVX512BW_OFF").is_ok();
        let avx512_vnni_turn_off = std::env::var("PIRE_AVX512_VNNI_OFF").is_ok();
        let fma_turn_off = std::env::var("PIRE_FMA_OFF").is_ok();
        let fma4_turn_off = std::env::var("PIRE_FMA4_OFF").is_ok();
        let f16c_turn_off = std::env::var("PIRE_F16C_OFF").is_ok();

        cpu_ft.sse = cpu_ft.sse && !sse_turn_off;
        cpu_ft.sse2 = cpu_ft.sse2 && !sse2_turn_off;
        cpu_ft.sse3 = cpu_ft.sse3 && !sse3_turn_off;
        cpu_ft.ssse3 = cpu_ft.ssse3 && !ssse3_turn_off;
        cpu_ft.avx = cpu_ft.avx && !avx_turn_off;
        cpu_ft.avx2 = cpu_ft.avx2 && !avx2_turn_off;
        cpu_ft.avx512f = cpu_ft.avx512f && !avx512f_turn_off;
        cpu_ft.avx512f16 = cpu_ft.avx512f16 && !avx512f16_turn_off;
        cpu_ft.avx512bw = cpu_ft.avx512bw && !avx512bw_turn_off;
        cpu_ft.avx512_vnni = cpu_ft.avx512_vnni && !avx512_vnni_turn_off;
        cpu_ft.fma = cpu_ft.fma && !fma_turn_off;
        cpu_ft.fma4 = cpu_ft.fma4 && !fma4_turn_off;
        cpu_ft.f16c = cpu_ft.f16c && !f16c_turn_off;
    }
    #[cfg(target_arch = "x86")]
    {
        let sse_turn_off = std::env::var("PIRE_SSE_OFF").is_ok();
        let sse2_turn_off = std::env::var("PIRE_SSE2_OFF").is_ok();
        let sse3_turn_off = std::env::var("PIRE_SSE3_OFF").is_ok();
        let ssse3_turn_off = std::env::var("PIRE_SSSE3_OFF").is_ok();

        cpu_ft.sse = cpu_ft.sse && !sse_turn_off;
        cpu_ft.sse2 = cpu_ft.sse2 && !sse2_turn_off;
        cpu_ft.sse3 = cpu_ft.sse3 && !sse3_turn_off;
        cpu_ft.ssse3 = cpu_ft.ssse3 && !ssse3_turn_off;
    }
    #[cfg(target_arch = "aarch64")]
    {
        let neon_turn_off = std::env::var("PIRE_NEON_OFF").is_ok();
        let sve_turn_off = std::env::var("PIRE_SVE_OFF").is_ok();
        let fp16_turn_off = std::env::var("PIRE_FP16_OFF").is_ok();
        let f32mm_turn_off = std::env::var("PIRE_F32MM_OFF").is_ok();
        let fcma_turn_off = std::env::var("PIRE_FCMA_OFF").is_ok();
        let i8mm_turn_off = std::env::var("PIRE_I8MM_OFF").is_ok();

        cpu_ft.neon = cpu_ft.neon && !neon_turn_off;
        cpu_ft.sve = cpu_ft.sve && !sve_turn_off;
        cpu_ft.fp16 = cpu_ft.fp16 && !fp16_turn_off;
        cpu_ft.f32mm = cpu_ft.f32mm && !f32mm_turn_off;
        cpu_ft.fcma = cpu_ft.fcma && !fcma_turn_off;
        cpu_ft.i8mm = cpu_ft.i8mm && !i8mm_turn_off;
    }
}
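
// Example usage (illustrative, assuming the crate is built with the
// `debug_cpu_features` feature): disable a feature at run time to exercise the
// fallback kernels, e.g.
//
//     PIRE_AVX512F_OFF=1 cargo test --features debug_cpu_features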

#[cfg(not(feature = "debug_cpu_features"))]
pub static RUNTIME_HW_CONFIG: Lazy<HWConfig> = Lazy::new(|| detect_hw_config());
#[cfg(feature = "debug_cpu_features")]
pub static RUNTIME_HW_CONFIG: Lazy<HWConfig> = Lazy::new(|| {
    let mut hw_config = detect_hw_config();
    apply_debug_cpu_features(&mut hw_config.cpu_ft);
    hw_config
});

pub static PIRE_NUM_THREADS: Lazy<usize> = Lazy::new(|| {
    let n_core = std::thread::available_parallelism().unwrap().get();
    // PIRE_NUM_THREADS or the number of logical cores
    let x = std::env::var("PIRE_NUM_THREADS").unwrap_or(n_core.to_string());
    x.parse::<usize>().unwrap()
});
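
// For example (illustrative): `PIRE_NUM_THREADS=4 ./bench` caps the thread count at 4;
// otherwise all logical cores reported by `available_parallelism` are used.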
#[cfg(target_arch = "x86_64")]
pub(crate) mod cpu_features {
    use super::HWModel;
    use super::RUNTIME_HW_CONFIG;

    pub fn hw_model() -> HWModel {
        RUNTIME_HW_CONFIG.hw_model
    }

    pub fn has_f32_compute() -> bool {
        // RUNTIME_HW_CONFIG.cpu_ft.avx512f || RUNTIME_HW_CONFIG.cpu_ft.avx
        // don't use the above since some avx512f kernels also rely on avx instructions
        // (even though avx512f should imply avx); we are being super conservative here
        RUNTIME_HW_CONFIG.cpu_ft.avx || RUNTIME_HW_CONFIG.cpu_ft.sse
    }

    pub fn has_c32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.avx || (RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse3)
    }

    pub fn has_f16f32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.avx && RUNTIME_HW_CONFIG.cpu_ft.f16c
    }
    pub fn has_f64_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.avx || (RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2)
    }

    pub fn has_c64_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.avx
            || (RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2 && RUNTIME_HW_CONFIG.cpu_ft.sse3)
    }

    pub fn has_f16_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.avx512f16
            && RUNTIME_HW_CONFIG.cpu_ft.avx
            && RUNTIME_HW_CONFIG.cpu_ft.f16c
            && RUNTIME_HW_CONFIG.cpu_ft.fma
    }
    pub fn has_i16i32_compute() -> bool {
        (RUNTIME_HW_CONFIG.cpu_ft.avx2 && RUNTIME_HW_CONFIG.cpu_ft.avx)
            || (RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2)
    }
    pub fn has_i8i32_compute() -> bool {
        (RUNTIME_HW_CONFIG.cpu_ft.avx2 && RUNTIME_HW_CONFIG.cpu_ft.avx)
            || (RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2 && RUNTIME_HW_CONFIG.cpu_ft.ssse3)
    }
    // TODO: Use actual info from hardware
    pub fn get_cache_params() -> (usize, usize, usize) {
        (4800, 256, 128)
    }
}
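
// Illustrative check (not part of the original source): the intended pattern is to
// gate a code path on these capability queries at run time; `has_f32_compute()` is
// simply avx || sse, so it always agrees with the detected feature flags.
#[cfg(target_arch = "x86_64")]
#[test]
fn capability_query_example() {
    let flavor = if RUNTIME_HW_CONFIG.cpu_ft.avx {
        "avx"
    } else if RUNTIME_HW_CONFIG.cpu_ft.sse {
        "sse"
    } else {
        "reference"
    };
    assert_eq!(cpu_features::has_f32_compute(), flavor != "reference");
}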

#[cfg(target_arch = "x86")]
pub(crate) mod cpu_features {
    use super::HWModel;
    use super::RUNTIME_HW_CONFIG;

    pub fn hw_model() -> HWModel {
        RUNTIME_HW_CONFIG.hw_model
    }

    pub fn has_f32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.sse
    }

    pub fn has_c32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse3
    }

    pub fn has_f16f32_compute() -> bool {
        false
    }
    pub fn has_f64_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2
    }

    pub fn has_c64_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2 && RUNTIME_HW_CONFIG.cpu_ft.sse3
    }

    pub fn has_f16_compute() -> bool {
        false
    }
    pub fn has_i16i32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2
    }
    pub fn has_i8i32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2 && RUNTIME_HW_CONFIG.cpu_ft.ssse3
    }
    // TODO: Use actual info from hardware
    pub fn get_cache_params() -> (usize, usize, usize) {
        (4800, 256, 128)
    }
}
#[cfg(target_arch = "aarch64")]
pub(crate) mod cpu_features {

    // neon is required for all the compute since
    // it is used for packing and unpacking, and it is
    // available on all archs for which the other extensions are available,
    // unless something weird happens with the vendor.
    // For those (marginal) cases, we probably don't want to bother with support.
    use super::HWModel;
    use super::RUNTIME_HW_CONFIG;

    pub fn hw_model() -> HWModel {
        RUNTIME_HW_CONFIG.hw_model
    }

    pub fn has_f32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.neon
    }

    pub fn has_c32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.neon
    }

    pub fn has_f16f32_compute() -> bool {
        false
    }
    pub fn has_f64_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.neon
    }

    pub fn has_c64_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.neon
    }

    pub fn has_f16_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.fp16 && RUNTIME_HW_CONFIG.cpu_ft.neon
    }
    pub fn has_i16i32_compute() -> bool {
        // currently we do not support this,
        // since the only instruction is smlal, whose throughput is not high enough
        false
    }
    pub fn has_i8i32_compute() -> bool {
        RUNTIME_HW_CONFIG.cpu_ft.i8mm && RUNTIME_HW_CONFIG.cpu_ft.neon
    }
    // TODO: Use actual info from hardware
    pub fn get_cache_params() -> (usize, usize, usize) {
        (4800, 256, 128)
    }
}

#[cfg(not(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64")))]
pub(crate) mod cpu_features {
    use super::HWModel;
    use super::RUNTIME_HW_CONFIG;

    pub fn hw_model() -> HWModel {
        RUNTIME_HW_CONFIG.hw_model
    }

    pub fn has_f32_compute() -> bool {
        false
    }

    pub fn has_c32_compute() -> bool {
        false
    }

    pub fn has_f16f32_compute() -> bool {
        false
    }
    pub fn has_f64_compute() -> bool {
        false
    }

    pub fn has_c64_compute() -> bool {
        false
    }

    pub fn has_f16_compute() -> bool {
        false
    }
    pub fn has_i16i32_compute() -> bool {
        false
    }
    pub fn has_i8i32_compute() -> bool {
        false
    }
    pub fn get_cache_params() -> (usize, usize, usize) {
        (4800, 256, 128)
    }
}
pub use cpu_features::*;

pub struct PackPool {
    pub buffer: RwLock<Vec<Mutex<Vec<u8>>>>,
}

pub static PACK_POOL: PackPool = PackPool { buffer: RwLock::new(vec![]) };

pub fn acquire<'a>(
    pool_guard: &'a RwLockReadGuard<'a, Vec<Mutex<Vec<u8>>>>,
    pack_size: usize,
) -> Option<MutexGuard<'a, Vec<u8>>> {
    // find the first free buffer with enough size
    // let x = PACK_POOL.buffer.read().unwrap();
    for i in pool_guard.iter() {
        // TODO: this might be the most optimal algo in terms of fragmentation/memory reuse.
        // It is very optimal for all cases (except for a few exceptional cases).
        // Exceptional case: you have multithreading along the mc loop that changes at run time
        // (so it requires a varying number of pack pools for one gemm run).
        // This is exceptional since it can happen only if the ThreadConfig is created by the user
        // and the ThreadConfig changes its parallelism along mc during the run.
        // I cannot think of a reason why someone would do that (maybe unusual hardware, or just experimentation).
        // Also, the current algo is very simple and easy to understand.
        let lock = i.try_lock();
        if let Ok(mutex) = lock {
            if mutex.len() >= pack_size {
                return Some(mutex);
            }
        }
    }

    None
}

pub fn extend(pool_vec: Vec<u8>) {
    let mut pool_guard = PACK_POOL.buffer.write().unwrap();
    pool_guard.push(Mutex::new(pool_vec));
}
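
// A minimal sketch (not part of the original source) of the intended pool protocol:
// try to reuse a buffer from PACK_POOL, allocate on a miss, and donate the fresh
// allocation back with `extend` so later calls can reuse it.
#[test]
fn pack_pool_reuse_example() {
    let need = 1024;
    let reused = {
        let guard = PACK_POOL.buffer.read().unwrap();
        acquire(&guard, need).is_some()
    };
    if !reused {
        extend(vec![0u8; need]);
    }
    // Either way, a buffer of at least `need` bytes is now available for reuse.
    let guard = PACK_POOL.buffer.read().unwrap();
    assert!(acquire(&guard, need).is_some());
}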

pub struct PireThreadConfig<'a> {
    pub ic_id: usize,
    // pc_id: usize,
    pub jc_id: usize,
    pub ir_id: usize,
    pub jr_id: usize,
    pub i_load_p_idx: usize,
    pub j_load_p_idx: usize,
    pub mc_eff: usize,
    pub nc_eff: usize,
    pub kc_eff: usize,
    pub par: PirePar,
    pub packa_barrier: &'a [Barrier],
    pub packb_barrier: &'a [Barrier],
}

pub fn get_apbp_barrier(par: &PirePar) -> (Vec<Barrier>, Vec<Barrier>) {
    let mut packa_barrier = vec![];
    for _ in 0..par.ic_par {
        let barrier = Barrier::new(par.jc_par * par.pc_par * par.ir_par * par.jr_par);
        packa_barrier.push(barrier);
    }

    let mut packb_barrier = vec![];
    for _ in 0..par.jc_par {
        let barrier = Barrier::new(par.ic_par * par.pc_par * par.ir_par * par.jr_par);
        packb_barrier.push(barrier);
    }

    (packa_barrier, packb_barrier)
}
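
// Worked example (not part of the original source): with ic_par = 2, jc_par = 3 and
// pc_par = ir_par = jr_par = 1 there is one A-pack barrier per ic slot, each
// synchronizing the jc_par * pc_par * ir_par * jr_par = 3 threads that share that
// packed-A buffer, and symmetrically three B-pack barriers of 2 threads each.
#[test]
fn apbp_barrier_counts_example() {
    let par = PirePar::new(6, 2, 1, 3, 1, 1);
    let (packa_barrier, packb_barrier) = get_apbp_barrier(&par);
    assert_eq!(packa_barrier.len(), 2);
    assert_eq!(packb_barrier.len(), 3);
}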

impl<'a> PireThreadConfig<'a> {
    pub fn new(
        par: PirePar,
        packa_barrier: &'a [Barrier],
        packb_barrier: &'a [Barrier],
        t_id: usize,
        mc_eff: usize,
        nc_eff: usize,
        kc_eff: usize,
    ) -> Self {
        let ic_id = par.get_ic_id(t_id);
        // let pc_id = par.get_pc_id(t_id);
        let jc_id = par.get_jc_id(t_id);
        let ir_id = par.get_ir_id(t_id);
        let jr_id = par.get_jr_id(t_id);
        let i_load_p_idx = jc_id * par.ir_par * par.jr_par + ir_id * par.jr_par + jr_id;
        let j_load_p_idx = ic_id * par.ir_par * par.jr_par + ir_id * par.jr_par + jr_id;

        Self {
            ic_id,
            // pc_id,
            jc_id,
            ir_id,
            jr_id,
            i_load_p_idx,
            j_load_p_idx,
            mc_eff,
            nc_eff,
            kc_eff,
            par,
            packa_barrier,
            packb_barrier,
        }
    }
    #[inline]
    pub fn wait_packa(&self) {
        if self.par.jc_par * self.par.pc_par * self.par.ir_par * self.par.jr_par > 1 {
            self.packa_barrier[self.ic_id].wait();
        }
    }

    #[inline]
    pub fn wait_packb(&self) {
        if self.par.ic_par * self.par.pc_par * self.par.ir_par * self.par.jr_par > 1 {
            self.packb_barrier[self.jc_id].wait();
        }
    }
}

// pub fn check_mem_size(mem_size: usize, rs: usize, cs: usize, m: usize, n: usize) {
//     assert!(mem_size >= rs * cs * m * n);
//     assert!(rs >= 1 && cs >= 1 && m >= 0 && n >= 0);
// }

// once this is read, it cannot be changed for the time being.
#[inline(always)]
pub fn pire_num_threads() -> usize {
    return *PIRE_NUM_THREADS;
}

#[derive(Copy, Clone)]
pub struct PirePar {
    pub num_threads: usize,
    pub ic_par: usize,
    pub pc_par: usize,
    pub jc_par: usize,
    pub ir_par: usize,
    pub jr_par: usize,
}

// greedy algo to distribute the number of threads evenly;
// simple, and works for the time being
#[inline(always)]
fn inc_par(ic_par: usize, jc_par: usize, ic_max: usize, jc_max: usize, factor: usize) -> (usize, usize, usize, usize) {
    if (ic_par < jc_par && ic_par < ic_max) || (jc_par >= jc_max && ic_par < ic_max) {
        (ic_par * factor, jc_par, ic_max / factor, jc_max)
    } else if (ic_par >= jc_par && jc_par < jc_max) || (ic_par >= ic_max && jc_par < jc_max) {
        (ic_par, jc_par * factor, ic_max, jc_max / factor)
    } else {
        (ic_par, jc_par, ic_max, jc_max)
    }
}
impl PirePar {
    pub fn new(num_threads: usize, ic_par: usize, pc_par: usize, jc_par: usize, ir_par: usize, jr_par: usize) -> Self {
        assert_eq!(num_threads, jc_par * pc_par * ic_par * jr_par * ir_par);
        Self { num_threads, ic_par, pc_par, jc_par, ir_par, jr_par }
    }
    pub fn from_num_threads(num_threads: usize, m: usize, n: usize) -> Self {
        let mut num_threads = num_threads;
        let mut ic_par_max = if m < 96 {
            1
        } else if m < 400 {
            2
        } else {
            m / 200
        };
        let mut jc_par_max = if n < 48 {
            1
        } else if n < 200 {
            2
        } else {
            n / 100
        };

        if num_threads <= 12 {
            let jc_par_max = jc_par_max.min(num_threads);
            let n_thread = (num_threads / jc_par_max) * jc_par_max;
            return Self::new(n_thread, num_threads / jc_par_max, 1, jc_par_max, 1, 1);
        }
        // let mut jr_par_max = if k < 96 { 1 } else if jc_par_max => 4 { 4.min(k / 4) };
        num_threads = num_threads.min(ic_par_max * jc_par_max);
        let mut ic_par = 1;
        let pc_par = 1;
        let mut jc_par = 1;
        let mut ir_par = 1;
        let jr_par = 1;

        while num_threads > 1 {
            if num_threads % 2 == 0 {
                num_threads = num_threads / 2;
                (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 2);
            } else if num_threads % 3 == 0 {
                num_threads = num_threads / 3;
                (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 3);
            } else if num_threads % 5 == 0 {
                num_threads = num_threads / 5;
                (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 5);
                continue;
            } else if num_threads % 7 == 0 {
                num_threads = num_threads / 7;
                (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 7);
                continue;
            } else {
                // if num_threads has a non-trivial prime factor (i.e. it is not divisible by 2, 3, 5, or 7),
                // round it down to an even number so it becomes a "nice" number
                num_threads = num_threads / 2 * 2;
            }
            // if num_threads % 11 == 0 {
            //     num_threads = num_threads / 11;
            //     (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 11);
            //     continue;
            // }
            // if num_threads % 13 == 0 {
            //     num_threads = num_threads / 13;
            //     (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 13);
            //     continue;
            // }
            // if num_threads % 17 == 0 {
            //     num_threads = num_threads / 17;
            //     (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 17);
            //     continue;
            // }
        }
        if ic_par >= 8 {
            ic_par = ic_par / 2;
            ir_par = 2;
        }
        let num_threads = ic_par * pc_par * jc_par * ir_par * jr_par;
        Self { num_threads, ic_par, pc_par, jc_par, ir_par, jr_par }
    }
    #[inline(always)]
    pub fn default(m: usize, n: usize) -> Self {
        let num_threads = pire_num_threads();
        Self::from_num_threads(num_threads, m, n)
    }
    #[inline]
    fn get_ic_id(&self, t_id: usize) -> usize {
        (t_id / (self.pc_par * self.jc_par * self.ir_par * self.jr_par)) % self.ic_par
    }

    //    #[inline]
    //    fn get_pc_id(&self, t_id: usize) -> usize {
    //        (t_id / (self.jr_par*self.ir_par*self.ic_par)) % self.pc_par
    //    }
    #[inline]
    fn get_jc_id(&self, t_id: usize) -> usize {
        (t_id / (self.jr_par * self.ir_par)) % self.jc_par
    }
    #[inline]
    fn get_jr_id(&self, t_id: usize) -> usize {
        (t_id / self.ir_par) % self.jr_par
    }
    #[inline]
    fn get_ir_id(&self, t_id: usize) -> usize {
        t_id % self.ir_par
    }

    pub fn get_load_par(
        &self,
        gemm_mode: &GemmPool,
        m: usize,
        n: usize,
        mc_eff: usize,
        nc_eff: usize,
    ) -> (usize, usize) {
        let m = (m / self.ic_par).min(mc_eff);
        let n = (n / self.jc_par).min(nc_eff);
        let i_load_par = ((m + 127) / 128).min(self.num_threads / self.ic_par);
        let j_load_par = ((n + 127) / 128).min(self.num_threads / self.jc_par);
        let i_load_par = match gemm_mode {
            GemmPool::Goto => i_load_par,
            GemmPool::SmallM => i_load_par,
            GemmPool::SmallN => 1,
        };
        (i_load_par.max(1), j_load_par.max(1))
    }
}
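
// Worked examples (not part of the original source). `from_num_threads` with 12
// threads on a 1200 x 1200 problem takes the num_threads <= 12 early path:
// jc_par_max = 1200 / 100 = 12, so all parallelism lands on jc. Thread ids then
// decompose with jr/ir fastest and ic slowest.
#[test]
fn pire_par_examples() {
    let par = PirePar::from_num_threads(12, 1200, 1200);
    assert_eq!((par.ic_par, par.jc_par), (1, 12));
    assert_eq!(par.num_threads, 12);

    // With ic_par = 2, jc_par = 3: thread 4 maps to (ic, jc) = (1, 1).
    let par = PirePar::new(6, 2, 1, 3, 1, 1);
    assert_eq!(par.get_ic_id(4), 1); // 4 / (1 * 3 * 1 * 1) % 2
    assert_eq!(par.get_jc_id(4), 1); // 4 / (1 * 1) % 3
}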

#[inline]
pub fn split_c_range(m: usize, mc: usize, mr: usize, ic_id: usize, ic_par: usize) -> (usize, usize, bool) {
    let chunk_len = (m / (mr * ic_par)) * mr;
    let rem = m % (mr * ic_par);
    if ic_id == 0 {
        let x = chunk_len + rem % mr;
        let mc_left = ((((x + mc - 1) / mc) * mc) * ic_par) < m;
        return (m - chunk_len - (rem % mr), m, mc_left);
    }
    let ic_id = ic_id - 1;
    let m0 = (m / mr) * mr;
    let rem = m0 % (mr * ic_par);
    let start_delta = rem.min(ic_id * mr);
    let end_delta = rem.min((ic_id + 1) * mr);
    //    let is_m_boundary = (chunk_len + end_delta - start_delta ) % mc == 0;
    let mc_coeff = (chunk_len + end_delta - start_delta + mc - 1) / mc;
    let mc_left = ((mc_coeff * mc) * ic_par) < m;
    //    let mc_left = is_m_boundary && rem != 0 && end_delta == start_delta;
    (chunk_len * ic_id + start_delta, chunk_len * (ic_id + 1) + end_delta, mc_left)
}

#[inline]
pub fn split_range(range_len: usize, unit_len: usize, r_id: usize, r_par: usize) -> (usize, usize) {
    let chunk_start = (range_len / (unit_len * r_par)) * unit_len * r_id;
    let chunk_end = (range_len / (unit_len * r_par)) * unit_len * (r_id + 1);
    let rem = range_len % (unit_len * r_par);
    let rem = rem - rem % unit_len;
    let rem_start = rem.min(r_id * unit_len);
    let rem_end = rem.min((r_id + 1) * unit_len);
    if r_id == r_par - 1 {
        return (chunk_start + rem_start, range_len);
    }
    (chunk_start + rem_start, chunk_end + rem_end)
}
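
// Worked example (not part of the original source): splitting 12 elements in units
// of 4 over 2 workers gives each a base chunk of 4, and the one full remainder unit
// goes to worker 0, so the ranges are [0, 8) and [8, 12). The last worker always
// absorbs the sub-unit tail.
#[test]
fn split_range_example() {
    assert_eq!(split_range(12, 4, 0, 2), (0, 8));
    assert_eq!(split_range(12, 4, 1, 2), (8, 12));
    assert_eq!(split_range(10, 4, 1, 2), (4, 10));
}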

pub trait BaseNum: Copy + 'static + Send {}

impl<T> BaseNum for T where T: Copy + 'static + Send {}

#[derive(Copy, Clone)]
pub struct PoolSize {
    pub m: usize,
    pub n: usize,
    pub k: usize,
    pub ap_pool_size: usize,
    pub ap_pool_multiplicity: usize,
    pub bp_pool_size: usize,
    pub bp_pool_multiplicity: usize,
}

impl PoolSize {
    // add alignment padding for a/b only to the total memory pool size
    pub fn mem_pool_size_b<TA, TB>(&self) -> usize {
        // be conservative and always add 2 * AB_ALIGN padding
        self.ap_pool_size * size_of::<TA>() * self.ap_pool_multiplicity
            + self.bp_pool_size * size_of::<TB>() * self.bp_pool_multiplicity
            + 2 * AB_ALIGN
    }

    pub fn ap_size_b<TA>(&self) -> usize {
        self.ap_pool_size * size_of::<TA>()
    }

    pub fn bp_size_b<TB>(&self) -> usize {
        self.bp_pool_size * size_of::<TB>()
    }

    pub fn ap_size_t_b<TA>(&self) -> usize {
        self.ap_pool_size * size_of::<TA>() * self.ap_pool_multiplicity
    }

    pub fn bp_size_t_b<TB>(&self) -> usize {
        self.bp_pool_size * size_of::<TB>() * self.bp_pool_multiplicity
    }

    pub fn slice_mut_from_pool<TA, TB>(
        &self,
        mem_pool: &mut [u8],
        i_load_par: usize,
        j_load_par: usize,
        pool_size: PoolSize,
        mr: usize,
        nr: usize,
        // mc: usize, nc: usize, kc: usize, mr: usize, nr: usize,
    ) -> (Vec<RangeLock<'_, TA>>, Vec<RangeLock<'_, TB>>) {
        let m_size = pool_size.m;
        let n_size = pool_size.n;
        let k_size = pool_size.k;
        let ap_pool_size = self.ap_pool_size;
        let ap_pool_size_b = ap_pool_size * size_of::<TA>();
        let a_alignment = std::mem::align_of::<TA>();
        assert_eq!(ap_pool_size_b % a_alignment, 0);
        let bp_pool_size = self.bp_pool_size;
        let bp_pool_size_b = bp_pool_size * size_of::<TB>();
        let b_alignment = std::mem::align_of::<TB>();
        assert_eq!(bp_pool_size_b % b_alignment, 0);
        let mut ap = vec![];
        let mut bp = vec![];
        // safety for pointer to slice casting: assert len of mem_pool is enough
        assert!(mem_pool.len() >= self.mem_pool_size_b::<TA, TB>());
        // align mem_pool
        let align_offset = mem_pool.as_ptr().align_offset(AB_ALIGN);
        let mut mem_pool = &mut mem_pool[align_offset..];
        // safety for pointer to slice casting: ap has the right alignment
        assert_eq!(mem_pool.as_ptr().align_offset(a_alignment), 0);
        for _ in 0..self.ap_pool_multiplicity {
            let (a, rest) = mem_pool.split_at_mut(ap_pool_size_b);
            let ap_pool = unsafe { std::slice::from_raw_parts_mut::<TA>(a.as_mut_ptr() as *mut TA, ap_pool_size) };
            if ap_pool_size == 0 {
                ap.push(RangeLock::from(ap_pool, i_load_par, 0, k_size, mr));
            } else {
                ap.push(RangeLock::from(ap_pool, i_load_par, m_size, k_size, mr));
            }
            mem_pool = rest;
        }
        let align_offset = mem_pool.as_ptr().align_offset(AB_ALIGN);
        let mut mem_pool = &mut mem_pool[align_offset..];
        // safety for pointer to slice casting: bp has the right alignment
        assert_eq!(mem_pool.as_ptr().align_offset(b_alignment), 0);
        for _ in 0..self.bp_pool_multiplicity {
            let (b, rest) = mem_pool.split_at_mut(bp_pool_size_b);
            let bp_pool = unsafe { std::slice::from_raw_parts_mut::<TB>(b.as_mut_ptr() as *mut TB, bp_pool_size) };
            if bp_pool_size == 0 {
                bp.push(RangeLock::from(bp_pool, j_load_par, 0, k_size, nr));
            } else {
                bp.push(RangeLock::from(bp_pool, j_load_par, n_size, k_size, nr));
            }
            mem_pool = rest;
        }
        (ap, bp)
    }
}
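
// Worked arithmetic (not part of the original source): for f32 buffers,
// 1000 elements with multiplicity 2 on the A side and 500 elements with
// multiplicity 1 on the B side need 1000*4*2 + 500*4 + 2*AB_ALIGN bytes in total.
#[test]
fn mem_pool_size_example() {
    let pool = PoolSize {
        m: 0,
        n: 0,
        k: 0,
        ap_pool_size: 1000,
        ap_pool_multiplicity: 2,
        bp_pool_size: 500,
        bp_pool_multiplicity: 1,
    };
    assert_eq!(pool.mem_pool_size_b::<f32, f32>(), 1000 * 4 * 2 + 500 * 4 + 2 * AB_ALIGN);
}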

pub fn get_mem_pool_size_goto<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
    hw_config: &HWConfig,
    par: &PirePar,
    a_need_pool: bool,
    b_need_pool: bool,
) -> PoolSize {
    let m = hw_config.get_mc_eff(par.ic_par);
    let n = hw_config.get_nc_eff(par.jc_par);
    let k = hw_config.get_kc_eff();
    let (ap_pool_size, ap_pool_multiplicity) = if a_need_pool {
        let ap_pool_multiplicity = par.ic_par;
        let ap_pool_size = hw_config.get_ap_pool_size(par.ic_par) + CACHELINE_PAD / size_of::<AP>();
        (ap_pool_size, ap_pool_multiplicity)
    } else {
        (0, 1)
    };
    let (bp_pool_size, bp_pool_multiplicity) = if b_need_pool {
        let bp_pool_multiplicity = par.jc_par;
        let bp_pool_size = hw_config.get_bp_pool_size(par.jc_par) + CACHELINE_PAD / size_of::<BP>();
        (bp_pool_size, bp_pool_multiplicity)
    } else {
        (0, 1)
    };
    PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size, bp_pool_multiplicity }
}

pub fn get_mem_pool_size_small_m<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
    hw_config: &HWConfig,
    par: &PirePar,
    a_need_pool: bool,
) -> PoolSize {
    let m = hw_config.get_mc_eff(par.ic_par);
    let n = hw_config.get_nc_eff(par.jc_par);
    let k = hw_config.get_kc_eff();
    if a_need_pool {
        let ap_pool_multiplicity = par.ic_par;
        let ap_pool_size = hw_config.get_ap_pool_size(par.ic_par) + CACHELINE_PAD / size_of::<AP>();
        PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size: 0, bp_pool_multiplicity: 1 }
    } else {
        PoolSize { m, n, k, ap_pool_size: 0, ap_pool_multiplicity: 1, bp_pool_size: 0, bp_pool_multiplicity: 1 }
    }
}

pub fn get_mem_pool_size_small_n<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
    hw_config: &HWConfig,
    par: &PirePar,
    b_need_pool: bool,
) -> PoolSize {
    let ap_pool_size = hw_config.get_ap_pool_size2() + CACHELINE_PAD / size_of::<AP>();
    let ap_pool_multiplicity = par.num_threads;
    let m = hw_config.mr();
    let n = hw_config.get_nc_eff(par.jc_par);
    let k = hw_config.get_kc_eff();
    if b_need_pool {
        let bp_pool_multiplicity = par.jc_par;
        let bp_pool_size = hw_config.get_bp_pool_size(par.jc_par) + CACHELINE_PAD / size_of::<BP>();
        PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size, bp_pool_multiplicity }
    } else {
        PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size: 0, bp_pool_multiplicity: 1 }
    }
}

// Choose ap_size, bp_size as arguments since they are specific to the gemm implementation:
// they are determined by the hardware and the gemm implementation (e.g. f64, f32, f16).
// Otherwise, this base crate would include code coupled with the other gemm crates, which
// would require either a cyclic dep (not allowed, of course) or separate code for each
// specific hardware and gemm implementation inside this crate, which is not desirable.
// We want this crate to be as decoupled as possible from specific gemm implementations
// and hardware.

pub fn run_small_m(m: usize) -> bool {
    m < 144
}

pub fn run_small_n(n: usize) -> bool {
    n < 144
}

pub enum GemmPool {
    Goto,
    SmallM,
    SmallN,
}

#[derive(Clone, Copy)]
pub struct StridedMatrix<T> {
    pub(crate) src: *const T,
    pub(crate) rs: usize,
    pub(crate) cs: usize,
}

impl<T> StridedMatrix<T> {
    pub fn new(src: *const T, rs: usize, cs: usize) -> Self {
        Self { src, rs, cs }
    }
}

unsafe impl<T> Send for StridedMatrix<T> {}

#[derive(Clone, Copy)]
pub struct StridedMatrixMut<T> {
    pub(crate) src: *mut T,
    pub(crate) rs: usize,
    pub(crate) cs: usize,
}

unsafe impl<T> Send for StridedMatrixMut<T> {}

impl<T> StridedMatrixMut<T> {
    pub fn new(src: *mut T, rs: usize, cs: usize) -> Self {
        Self { src, rs, cs }
    }
}

#[derive(Clone)]
pub struct StridedMatrixP<'a, T, U> {
    pub(crate) src: *const T,
    pub(crate) rs: usize,
    pub(crate) cs: usize,
    pub(crate) dst: &'a RangeLock<'a, U>,
}

unsafe impl<'a, T, U> Send for StridedMatrixP<'a, T, U> {}

impl<'a, T, U> StridedMatrixP<'a, T, U> {
    pub fn src(&self) -> *const T {
        self.src
    }
    pub fn dst_write(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, U> {
        self.dst.write(idx, kc).unwrap()
    }
    pub fn dst_read(&self) -> RangeLockReadGuard<'a, 'a, U> {
        self.dst.read().unwrap()
    }
    pub fn get_mc(&self) -> usize {
        self.dst.get_mc()
    }
    pub fn rs(&self) -> usize {
        self.rs
    }
    pub fn cs(&self) -> usize {
        self.cs
    }
}

#[derive(Clone, Copy)]
pub struct PackedMatrix<T> {
    pub(crate) src: *const T,
    pub(crate) k: usize,
    pub(crate) m: usize,
    // pub(crate) m0: usize,
    // pub(crate) k0: usize,
}

unsafe impl<T> Send for PackedMatrix<T> {}

impl<T> PackedMatrix<T> {
    pub fn src(&self) -> *const T {
        self.src
    }
    pub fn k(&self) -> usize {
        self.k
    }
    pub fn m(&self) -> usize {
        self.m
    }
    // pub fn at(&self, m: usize, k: usize) -> *const T {
    //     let m_rounded = (m+m0-1) / m0 * m0;
    //     self.src.add(m_rounded)
    // }
}

#[derive(Clone)]
pub struct PackedMatrixMixed<'a, X, Y> {
    pub(crate) src: *const X,
    pub(crate) dst: &'a RangeLock<'a, Y>,
    pub(crate) k: usize,
    pub(crate) m: usize,
}

impl<'a, X, Y> PackedMatrixMixed<'a, X, Y> {
    pub fn src(&self) -> *const X {
        self.src
    }
    pub fn k(&self) -> usize {
        self.k
    }
    pub fn m(&self) -> usize {
        self.m
    }

    pub fn dst_write(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, Y> {
        self.dst.write(idx, kc).unwrap()
    }

    pub fn get_mc(&self) -> usize {
        self.dst.get_mc()
    }

    pub fn dst_read(&self) -> RangeLockReadGuard<'a, 'a, Y> {
        self.dst.read().unwrap()
    }
}

unsafe impl<X, Y> Send for PackedMatrixMixed<'_, X, Y> {}

// must be a multiple of the largest vector size that we support;
// currently that is avx512 -> 64 bytes
pub const AB_ALIGN: usize = 1024;

pub trait GemmCache {
    fn mr(&self) -> usize;
    fn get_mc_eff(&self, par: usize) -> usize;
    fn get_kc_eff(&self) -> usize;
    fn get_nc_eff(&self, par: usize) -> usize;
    fn get_ap_pool_size(&self, ic_par: usize) -> usize {
        let mc_eff = self.get_mc_eff(ic_par);
        let kc_eff = self.get_kc_eff();
        mc_eff * kc_eff
    }
    fn get_ap_pool_size2(&self) -> usize {
        let kc_eff = self.get_kc_eff();
        self.mr() * kc_eff
    }
    fn get_bp_pool_size(&self, jc_par: usize) -> usize {
        let nc_eff = self.get_nc_eff(jc_par);
        let kc_eff = self.get_kc_eff();
        nc_eff * kc_eff
    }
}
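
// A minimal sketch (not part of the original source) of a `GemmCache` provider with
// fixed, made-up blocking parameters; the default pool-size methods then just return
// the blocked panel areas mc*kc and nc*kc.
#[cfg(test)]
struct ToyCache;
#[cfg(test)]
impl GemmCache for ToyCache {
    fn mr(&self) -> usize {
        8
    }
    fn get_mc_eff(&self, par: usize) -> usize {
        4800 / par.max(1)
    }
    fn get_kc_eff(&self) -> usize {
        256
    }
    fn get_nc_eff(&self, _par: usize) -> usize {
        128
    }
}
#[test]
fn gemm_cache_pool_sizes_example() {
    let cache = ToyCache;
    assert_eq!(cache.get_ap_pool_size(2), 2400 * 256);
    assert_eq!(cache.get_bp_pool_size(1), 128 * 256);
}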

#[derive(Copy, Clone)]
pub enum Array<X> {
    StridedMatrix(StridedMatrix<X>),
    PackedMatrix(PackedMatrix<X>),
}

impl<X> Array<X> {
    pub fn strided_matrix(src: *const X, rs: usize, cs: usize) -> Self {
        Array::StridedMatrix(StridedMatrix::new(src, rs, cs))
    }
    pub fn packed_matrix(src: *const X, m: usize, k: usize) -> Self {
        Array::PackedMatrix(PackedMatrix { src, k, m })
    }
    pub fn into_pack_array<'a>(&self, a: &'a [RangeLock<'a, X>], p_id: usize) -> PArray<'a, X> {
        match self {
            Array::StridedMatrix(x) => {
                let x = StridedMatrixP { src: x.src, rs: x.rs, cs: x.cs, dst: &a[p_id] };
                PArray::<X>::StridedMatrix(x)
            }
            Array::PackedMatrix(x) => {
                let x = PackedMatrix { src: x.src, k: x.k, m: x.m };
                PArray::PackedMatrix(x)
            }
        }
    }
    pub fn into_pack_array2<'a, Y>(&self, a: &'a [RangeLock<'a, Y>], p_id: usize) -> PArrayMixed<'a, X, Y> {
        match self {
            Array::StridedMatrix(x) => {
                let x = StridedMatrixP { src: x.src, rs: x.rs, cs: x.cs, dst: &a[p_id] };
                PArrayMixed::<X, Y>::StridedMatrix(x)
            }
            Array::PackedMatrix(x) => {
                let x = PackedMatrixMixed { src: x.src, dst: &a[p_id], k: x.k, m: x.m };
                PArrayMixed::PackedMatrix(x)
            }
        }
    }

    pub fn src(&self) -> *const X {
        match self {
            Array::StridedMatrix(x) => x.src,
            Array::PackedMatrix(x) => x.src,
        }
    }

    pub fn transpose(&mut self) {
        match self {
            Array::StridedMatrix(x) => {
                let temp = x.rs;
                x.rs = x.cs;
                x.cs = temp;
            }
            _ => {
                panic!("Only StridedMatrix has transpose");
            }
        }
    }

    pub fn rs(&self) -> usize {
        match self {
            Array::StridedMatrix(x) => x.rs,
            _ => {
                panic!("Only StridedMatrix has rs");
            }
        }
    }

    pub fn cs(&self) -> usize {
        match self {
            Array::StridedMatrix(x) => x.cs,
            _ => {
                panic!("Only StridedMatrix has cs");
            }
        }
    }

    pub fn is_strided(&self) -> bool {
        match self {
            Array::StridedMatrix(_) => true,
            _ => false,
        }
    }
}

#[derive(Copy, Clone)]
pub enum ArrayMut<X> {
    StridedMatrix(StridedMatrixMut<X>),
}

impl<X> ArrayMut<X> {
    pub fn strided_matrix(src: *mut X, rs: usize, cs: usize) -> Self {
        ArrayMut::StridedMatrix(StridedMatrixMut::new(src, rs, cs))
    }

    pub fn src(&self) -> *mut X {
        match self {
            ArrayMut::StridedMatrix(x) => x.src,
        }
    }

    pub fn transpose(&mut self) {
        match self {
            ArrayMut::StridedMatrix(x) => {
                let temp = x.rs;
                x.rs = x.cs;
                x.cs = temp;
            }
        }
    }

    pub fn rs(&self) -> usize {
        match self {
            ArrayMut::StridedMatrix(x) => x.rs,
        }
    }

    pub fn cs(&self) -> usize {
        match self {
            ArrayMut::StridedMatrix(x) => x.cs,
        }
    }
}

#[derive(Clone)]
pub enum PArray<'a, X> {
    StridedMatrix(StridedMatrixP<'a, X, X>),
    PackedMatrix(PackedMatrix<X>),
}

impl<'a, X> PArray<'a, X> {
    pub fn src(&self) -> *const X {
        match self {
            Self::StridedMatrix(x) => x.src,
            Self::PackedMatrix(x) => x.src,
        }
    }

    pub fn rs(&self) -> usize {
        match self {
            Self::StridedMatrix(x) => x.rs,
            _ => {
                panic!("Only StridedMatrix has rs");
            }
        }
    }

    pub fn cs(&self) -> usize {
        match self {
            Self::StridedMatrix(x) => x.cs,
            _ => {
                panic!("Only StridedMatrix has cs");
            }
        }
    }

    pub fn dst_write(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, X> {
        match self {
            Self::StridedMatrix(x) => x.dst.write(idx, kc).unwrap(),
            _ => {
                panic!("Only StridedMatrix has write guard");
            }
        }
    }

    pub fn dst_read(&self) -> RangeLockReadGuard<'a, 'a, X> {
        match self {
            Self::StridedMatrix(x) => x.dst.read().unwrap(),
            _ => {
                panic!("Only StridedMatrix has read guard");
            }
        }
    }

    pub fn is_strided(&self) -> bool {
        match self {
            Self::StridedMatrix(_) => true,
            _ => false,
        }
    }
}

#[derive(Clone)]
pub enum PArrayMixed<'a, X, Y> {
    StridedMatrix(StridedMatrixP<'a, X, Y>),
    PackedMatrix(PackedMatrixMixed<'a, X, Y>),
}

impl<'a, X, Y> PArrayMixed<'a, X, Y> {
    pub fn src(&self) -> *const X {
        match self {
            Self::StridedMatrix(x) => x.src,
            Self::PackedMatrix(x) => x.src,
        }
    }

    pub fn rs(&self) -> usize {
        match self {
            Self::StridedMatrix(x) => x.rs,
            _ => {
                panic!("Only StridedMatrix has rs");
            }
        }
    }

    pub fn cs(&self) -> usize {
        match self {
            Self::StridedMatrix(x) => x.cs,
            _ => {
                panic!("Only StridedMatrix has cs");
            }
        }
    }

    pub fn dst_write(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, Y> {
        match self {
            Self::StridedMatrix(x) => x.dst.write(idx, kc).unwrap(),
            Self::PackedMatrix(x) => x.dst.write(idx, kc).unwrap(),
        }
    }
    pub fn dst_read(&self) -> RangeLockReadGuard<'a, 'a, Y> {
        match self {
            Self::StridedMatrix(x) => x.dst.read().unwrap(),
            Self::PackedMatrix(x) => x.dst.read().unwrap(),
        }
    }
    pub fn is_strided(&self) -> bool {
        match self {
            Self::StridedMatrix(_) => true,
            _ => false,
        }
    }
}

pub enum PtrData<'a, X> {
    RefData(RangeLockReadGuard<'a, 'a, X>),
    PtrData(*const X),
}

impl<'a, X> PtrData<'a, X> {
    pub fn src(&self) -> *const X {
        match self {
            PtrData::RefData(x) => x.get().as_ptr(),
            PtrData::PtrData(x) => x.clone(),
        }
    }
}

pub fn matrix_size_strided(m: usize, n: usize, rs: usize, cs: usize) -> usize {
    (m - 1) * rs + (n - 1) * cs
}
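
// Worked example (not part of the original source): a 3 x 2 column-major matrix with
// rs = 1, cs = 3 touches linear indices up to (3-1)*1 + (2-1)*3 = 5, so a backing
// slice must hold at least 6 elements; the packing checks below use this value as a
// lower bound on slice length.
#[test]
fn matrix_size_strided_example() {
    assert_eq!(matrix_size_strided(3, 2, 1, 3), 5);
}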

#[macro_export]
macro_rules! packing_api {
    ($ta:ty, $tb:ty) => {
        fn a_size_packed(m: usize, k: usize) -> usize {
            let round_m_fn = dispatch_round_m();
            let round_k_fn = dispatch_round_k();
            let m_round = round_m_fn(m);
            let k_round = round_k_fn(k);
            return m_round * k_round;
        }

        fn b_size_packed(n: usize, k: usize) -> usize {
            let round_k_fn = dispatch_round_k();
            let k_round = round_k_fn(k);
            return n * k_round;
        }
        // the block idx for packa and packb is chosen s.t. the m dim (resp. n dim) is contiguous;
        // this is to ensure that indexing for parallelization over these dims is easy
        // (otherwise ranges would have to be in the same mc, nc range).
        // this is not an issue since we do not parallelize over the k dim (think about this when
        // we parallelize over the k dim in the future, which is beneficial only in the special
        // case of very large k and small m, n)

        /// # Safety
        ///
        /// a and ap must have big enough size to store the packed matrix
        pub unsafe fn pack_a_unchecked(
            m: usize,
            k: usize,
            a: *const $ta,
            a_rs: usize,
            a_cs: usize,
            ap: *mut $ta,
        ) -> Array<TA> {
            assert_eq!(ap.align_offset(AB_ALIGN), 0);
            if m == 1 {
                for j in 0..k {
                    *ap.add(j) = *a.add(j * a_cs);
                }
                return Array::strided_matrix(ap, 1, m);
            }
            let pack_fn = dispatch_pack_a();
            let round_m_fn = dispatch_round_m();
            let round_k_fn = dispatch_round_k();

            let (mc, _, kc) = dispatch_get_mcnckc();
            let mut ap_cur = ap;
            for p in (0..k).step_by(kc) {
                let kc_len = kc.min(k - p);
                let kc_len_eff = round_k_fn(kc_len);
                for i in (0..m).step_by(mc) {
                    let mc_len = mc.min(m - i);
                    let mc_len_eff = round_m_fn(mc_len);
                    let a_cur = a.add(i * a_rs + p * a_cs);
                    pack_fn(a_cur, ap_cur, mc_len, kc_len, a_rs, a_cs);
                    ap_cur = ap_cur.add(mc_len_eff * kc_len_eff);
                }
            }
            return Array::packed_matrix(ap, m, k);
        }

        /// # Safety
        ///
        /// b and bp must have big enough size to store the packed matrix
        pub unsafe fn pack_b_unchecked(
            n: usize,
            k: usize,
            b: *const $tb,
            b_rs: usize,
            b_cs: usize,
            bp: *mut $tb,
        ) -> Array<TB> {
            assert_eq!(bp.align_offset(AB_ALIGN), 0);
            if n == 1 {
                for j in 0..k {
                    *bp.add(j) = *b.add(j * b_rs);
                }
                return Array::strided_matrix(bp, 1, k);
            }
            let pack_fn = dispatch_pack_b();
            let round_k_fn = dispatch_round_k();

            let (_, nc, kc) = dispatch_get_mcnckc();
            let mut bp_cur = bp;
            for p in (0..k).step_by(kc) {
                let kc_len = kc.min(k - p);
                let kc_len_eff = round_k_fn(kc_len);
                for i in (0..n).step_by(nc) {
                    let nc_len = nc.min(n - i);
                    let b_cur = b.add(i * b_cs + p * b_rs);
                    pack_fn(b_cur, bp_cur, nc_len, kc_len, b_rs, b_cs);
                    bp_cur = bp_cur.add(nc_len * kc_len_eff);
                }
            }
            return Array::packed_matrix(bp, n, k);
        }

        pub fn pack_a(m: usize, k: usize, a: &[$ta], a_rs: usize, a_cs: usize, ap: &mut [$ta]) -> Array<TA> {
            // panics if ap does not have enough size
            // safety check for size
            assert!(ap.len() >= a_size_packed(m, k));
            assert!(a.len() >= pire_base::matrix_size_strided(m, k, a_rs, a_cs));
            // safety: ap has enough size due to the assert above
            unsafe { pack_a_unchecked(m, k, a.as_ptr(), a_rs, a_cs, ap.as_mut_ptr()) }
        }

        pub fn pack_b(n: usize, k: usize, b: &[$tb], b_rs: usize, b_cs: usize, bp: &mut [$tb]) -> Array<TB> {
            // panics if bp does not have enough size
            // safety check for size
            assert!(bp.len() >= b_size_packed(n, k));
            assert!(b.len() >= pire_base::matrix_size_strided(k, n, b_rs, b_cs));
            // safety: bp has enough size due to the assert above
            unsafe { pack_b_unchecked(n, k, b.as_ptr(), b_rs, b_cs, bp.as_mut_ptr()) }
        }
    };
}
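
// A minimal usage sketch (not part of the original source). A gemm crate is expected
// to provide the `dispatch_*` helpers the macro calls (dispatch_round_m,
// dispatch_round_k, dispatch_pack_a, dispatch_pack_b, dispatch_get_mcnckc) and the
// TA/TB type aliases before instantiating, and the destination buffer passed to
// `pack_a` must be AB_ALIGN-aligned:
//
//     type TA = f32;
//     type TB = f32;
//     packing_api!(f32, f32);
//
//     // let a_packed = pack_a(m, k, &a, a_rs, a_cs, &mut ap);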
1525
1526#[macro_export]
1527macro_rules! is_mixed {
1528    (T, $st1:expr, $st2:expr) => {
1529        $st1
1530    };
1531    (F, $src:expr, $st2:expr) => {
1532        $st2
1533    };
1534}
1535
1536#[macro_export]
1537macro_rules! def_pa {
1538    ($packa_ty:tt,F,$ta:tt,$tap:tt) => {
1539        type $packa_ty<'a> = PArray<'a, $tap>;
1540    };
1541    ($packa_ty:tt,T,$ta:tt,$tap:tt) => {
1542        type $packa_ty<'a> = PArrayMixed<'a, $ta, $tap>;
1543    };
1544}
1545
1546#[macro_export]
1547macro_rules! def_pire_gemm {
1548    (
1549        $t_dispatcher:tt,
1550        $ta:tt,$tap:ty,$tb:ty,$tbp:ty,$tc:ty,$t_as:ty,$t_bs:ty,
1551        $packa_ty:tt,$packb_ty:tt,
1552        $one:expr,
1553        $name:ident, $name_mt:ident,
1554        $goto_name:ident, $goto_kernel:ident,
1555        $small_m_name:ident, $small_m_kernel:ident,
1556        $small_n_name:ident, $small_n_kernel:ident,
1557        $gemv_name:ident, $gemv_name2:ident,
1558        $packa_name:ident, $packb_name:ident,
1559        $packa_name0:ident, $packb_name0:ident,
1560        $run_small_m:expr, $run_small_n:expr,
1561        $pack_fn:tt, $include_flag:tt,
1562    ) => {
1563        def_pa!($packa_ty,$include_flag,$ta,$tap);
1564        def_pa!($packb_ty,$include_flag,$tb,$tbp);
1565        pub(crate) unsafe fn $name <F:UnaryFnC>(
1566            hw_config: &$t_dispatcher <F>,
1567            m: usize, n: usize, k: usize,
1568            alpha: $t_as,
1569            a: Array<$ta>,
1570            b: Array<$tb>,
1571            beta: $t_bs,
1572            c: ArrayMut<$tc>,
1573            par: &PirePar,
1574        )
1575        {
1576            let a_need_pool = a.is_strided() || !hw_config.is_compute_native();
1577            let b_need_pool = b.is_strided() || !hw_config.is_compute_native();
1578            if n == 1 && a.is_strided() {
1579                let alpha = &alpha as *const $t_as;
1580                let beta = &beta as *const $t_bs;
1581                $gemv_name(hw_config, m, k, alpha, a, b, beta, c);
1582                return;
1583            }
1584            if m == 1 && b.is_strided() {
1585                let alpha = &alpha as *const $t_as;
1586                let beta = &beta as *const $t_bs;
1587                let mut a = a;
1588                a.transpose();
1589                let mut b = b;
1590                b.transpose();
1591                let mut c = c;
1592                c.transpose();
1593                $gemv_name2(hw_config, n, k, alpha.into(), b, a, beta, c);
1594                return;
1595            }
1596            let (gemm_mode, gemm_fun, pool_info)
1597            : (
1598                GemmPool, unsafe fn(
1599                    &$t_dispatcher <F>, usize, usize, usize, *const $t_as, $packa_ty, $packb_ty, *const $t_bs, ArrayMut<$tc>, &PireThreadConfig,
1600                ),
1601                PoolSize
1602            )
1603             = if run_small_m(m) && $run_small_m && b.is_strided() {
1604                (GemmPool::SmallM, $small_m_name, get_mem_pool_size_small_m::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, a_need_pool))
1605            } else if run_small_n(n) && $run_small_n && a.is_strided() {
1606                (GemmPool::SmallN, $small_n_name, get_mem_pool_size_small_n::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, b_need_pool))
1607            } else {
1608                (GemmPool::Goto, $goto_name, get_mem_pool_size_goto::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, a_need_pool, b_need_pool))
1609            };
1610            let mem_pool_size = pool_info.mem_pool_size_b::<$tap,$tbp>();
1611            // TODO: the zero pool size case (a and b both pre-packed) is special to optimize for; the optimization would not be worth it
1612            // if mem_pool_size == 0 {
1613            //     let mut pool_vec = [0_u8; 1];
1614            //     let pool_buf = &mut pool_vec;
1615            //     $name_mt(
1616            //         hw_config, m, n, k, alpha, a, b, beta, c, par, pool_buf, gemm_mode, pool_info, gemm_fun
1617            //     );
1618            //     return;
1619            // }
1620            // run the selected algorithm, reusing a pack buffer from the global pool when one is available
1621            {
1622                let pool_guard = PACK_POOL.buffer.read().unwrap();
1623                let y = acquire(&pool_guard, mem_pool_size);
1624                if let Some(mut pool_vec) = y {
1625                    let pool_buf = &mut pool_vec;
1626                    $name_mt(
1627                        hw_config, m, n, k, alpha, a, b, beta, c, par, pool_buf, gemm_mode, pool_info, gemm_fun
1628                    );
1629                    return;
1630                }
1631            }
1632            let mut pool_vec = vec![0_u8; mem_pool_size];
1633            let pool_buf = &mut pool_vec;
1634            $name_mt(
1635                hw_config, m, n, k, alpha, a, b, beta, c, par, pool_buf, gemm_mode, pool_info, gemm_fun
1636            );
1637            extend(pool_vec);
1638        }
1639
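        // Multi-threaded driver: carves per-thread packing slices out of `pool_buf`,
        // spawns `num_threads - 1` scoped workers, and runs thread id 0 inline so
        // the calling thread participates in the same `gemm_fn`.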
1640        pub(crate) unsafe fn $name_mt<F:UnaryFnC>(
1641            hw_config: &$t_dispatcher <F>,
1642            m: usize, n: usize, k: usize,
1643            alpha: $t_as,
1644            a: Array<$ta>,
1645            b: Array<$tb>,
1646            beta: $t_bs,
1647            c: ArrayMut<$tc>,
1648            par: &PirePar,
1649            pool_buf: &mut [u8],
1650            gemm_mode: GemmPool,
1651            pool_info: PoolSize,
1652            gemm_fn: unsafe fn(
1653                &$t_dispatcher <F>, usize, usize, usize, *const $t_as, $packa_ty, $packb_ty, *const $t_bs, ArrayMut<$tc>, &PireThreadConfig
1654            )
1655        )
1656        where $t_dispatcher <F>: GemmCache
1657        {
1658
1659            let mc_eff = <$t_dispatcher::<F> as GemmCache>::get_mc_eff(hw_config, par.ic_par);
1660            let nc_eff = <$t_dispatcher::<F> as GemmCache>::get_nc_eff(hw_config, par.jc_par);
1661            let kc_eff = <$t_dispatcher::<F> as GemmCache>::get_kc_eff(hw_config);
1662            let (pa_br_vec_ref, pb_br_vec_ref) = get_apbp_barrier(par);
1663
1664            let (i_load_par, j_load_par) = par.get_load_par(&gemm_mode, m, n, mc_eff, nc_eff);
1665            let (ap_pool_vec, bp_pool_vec) = pool_info.slice_mut_from_pool::<$tap,$tbp>(
1666                pool_buf, i_load_par, j_load_par, pool_info, hw_config.mr, hw_config.nr
1667            );
1668            let (ap_pool, bp_pool) = (&ap_pool_vec, &bp_pool_vec);
1669
1670            // TODO: remove the par.clone() calls below
1671            std::thread::scope(|s| {
1672                for t_id in 1..par.num_threads {
1673                    let t_cfg = PireThreadConfig::new(
1674                        par.clone(), &pa_br_vec_ref, &pb_br_vec_ref, t_id, mc_eff, nc_eff, kc_eff
1675                    );
1676                    let ic_id = t_cfg.ic_id;
1677                    let jc_id = t_cfg.jc_id;
1678                    let ap_id = match gemm_mode {
1679                        GemmPool::Goto => ic_id,
1680                        GemmPool::SmallM => ic_id,
1681                        GemmPool::SmallN => t_id,
1682                    };
1683                    let bp_id = match gemm_mode {
1684                        GemmPool::Goto => jc_id,
1685                        GemmPool::SmallM => 0,
1686                        GemmPool::SmallN => jc_id,
1687                    };
1688                    let ap_cur = a.$pack_fn(ap_pool, ap_id);
1689                    let bp_cur = b.$pack_fn(bp_pool, bp_id);
1690                    let g = hw_config;
1691                    s.spawn(move || {
1692                            let alpha = &alpha as *const $t_as;
1693                            let beta = &beta as *const $t_bs;
1694                            gemm_fn(g, m, n, k, alpha, ap_cur, bp_cur, beta, c, &t_cfg);
1695                        }
1696                    );
1697                }
1698                {
1699                    let ap = a.$pack_fn(ap_pool, 0);
1700                    let bp = b.$pack_fn(bp_pool, 0);
1701                    let t_id: usize = 0;
1702                    let t_cfg = PireThreadConfig::new(par.clone(), &pa_br_vec_ref, &pb_br_vec_ref, t_id, mc_eff, nc_eff, kc_eff);
1703                    let alpha = &alpha as *const $t_as;
1704                    let beta = &beta as *const $t_bs;
1705                    gemm_fn(hw_config, m, n, k, alpha, ap, bp, beta, c, &t_cfg);
1706                }
1707            });
1708        }
1709
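        // Goto-style blocked gemm: the loop nest is mc -> kc -> nc. A is packed once
        // per (mc, kc) tile and B once per (nc, kc) tile; the paired wait_packa /
        // wait_packb calls keep threads with no remaining tile of their own in
        // lockstep with the threads that are still packing.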
1710        unsafe fn $goto_name<F:UnaryFnC>(
1711            hw_cfg: &$t_dispatcher <F>,
1712            m: usize, n: usize, k: usize,
1713            alpha: *const $t_as,
1714            a: $packa_ty,
1715            b: $packb_ty,
1716            beta: *const $t_bs,
1717            c: ArrayMut<$tc>,
1718            t_cfg: &PireThreadConfig
1719        ) {
1720            let ic_id = t_cfg.ic_id;
1721            let jc_id = t_cfg.jc_id;
1722            let ir_id = t_cfg.ir_id;
1723            let jr_id = t_cfg.jr_id;
1724            let ir_par = t_cfg.par.ir_par;
1725            let jr_par = t_cfg.par.jr_par;
1726            let ic_par = t_cfg.par.ic_par;
1727            let jc_par = t_cfg.par.jc_par;
1728            let mc = t_cfg.mc_eff;
1729            let nc = t_cfg.nc_eff;
1730            let kc = t_cfg.kc_eff;
1731            let mr = hw_cfg.mr;
1732            let nr = hw_cfg.nr;
1733            let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, ic_par);
1734            let (nc_start, nc_end, nc_left) = split_c_range(n, nc, nr, jc_id, jc_par);
1735            let (kc_start, d1_end) = (0, k);
1736            let one = $one;
1737            let c_rs = c.rs();
1738            let c_cs = c.cs();
1739            let c_ptr = c.src();
1740            let mut mc_i = mc_start;
1741            while mc_i < mc_end {
1742                let mc_len = mc.min(mc_end - mc_i);
1743                let mut kc_i = kc_start;
1744                let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
1745                let mr_len = mr_end - mr_start;
1746                let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
1747                while kc_i < d1_end {
1748                    let kc_len = kc.min(d1_end - kc_i);
1749                    let kc_len_eff = hw_cfg.round_k(kc_len);
1750                    let mut nc_i = nc_start;
1751                    let kc_last = kc_i + kc_len == d1_end;
1752                    let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
1753                    let ap_data = $packa_name(hw_cfg, &a, mc_i, kc_i, mc_len, kc_len, t_cfg);
1754                    let ap = ap_data.src();
1755                    let ap = ap.add(mr_start*kc_len_eff);
1756                    while nc_i < nc_end {
1757                        let nc_len = nc.min(nc_end - nc_i);
1758                        let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
1759                        let nr_len = nr_end - nr_start;
1760                        let c_ij = c_i.add((nc_i+nr_start) * c_cs);
1761                        let bp_data = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1762                        let bp = bp_data.src();
1763                        let bp = bp.add(nr_start*kc_len_eff);
1764                        $goto_kernel(
1765                            hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t, c_ij, c_rs, c_cs,
1766                            ap, bp,
1767                            kc_last,
1768                        );
1769
1770                        nc_i += nc;
1771                    }
1772                    if nc_left {
1773                        t_cfg.wait_packb();
1774                        t_cfg.wait_packb();
1775                    }
1776                    kc_i += kc;
1777                }
1778                mc_i += mc;
1779            }
1780            if mc_left {
1781                let mut kc_i = kc_start;
1782                while kc_i < d1_end {
1783                    let kc_len = kc.min(d1_end - kc_i);
1784                    t_cfg.wait_packa();
1785                    t_cfg.wait_packa();
1786                    let mut nc_i = nc_start;
1787                    while nc_i < nc_end {
1788                        let nc_len = nc.min(nc_end - nc_i);
1789                        let _ = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1790                        nc_i += nc;
1791                    }
1792                    if nc_left {
1793                        t_cfg.wait_packb();
1794                        t_cfg.wait_packb();
1795                    }
1796                    kc_i += kc;
1797                }
1798            }
1799        }
1800        unsafe fn $small_m_name<F:UnaryFnC>(
1801            hw_cfg: &$t_dispatcher <F>,
1802            m: usize, n: usize, k: usize,
1803            alpha: *const $t_as,
1804            a: $packa_ty,
1805            b: $packb_ty,
1806            beta: *const $t_bs,
1807            c: ArrayMut<$tc>,
1808            t_cfg: &PireThreadConfig
1809        ) {
1810            let par = &t_cfg.par;
1811            let ic_id = t_cfg.ic_id;
1812            let jc_id = t_cfg.jc_id;
1813            let ir_id = t_cfg.ir_id;
1814            let ir_par = par.ir_par;
1815            let jr_id = t_cfg.jr_id;
1816            let jr_par = par.jr_par;
1817            let mc = t_cfg.mc_eff;
1818            let nc = t_cfg.nc_eff;
1819            let kc = t_cfg.kc_eff;
1820            let mr = hw_cfg.mr;
1821            let nr = hw_cfg.nr;
1822            let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, par.ic_par);
1823            let (nc_start, nc_end, _) = split_c_range(n, nc, nr, jc_id, par.jc_par);
1824            let (kc_start, kc_end) = (0, k);
1825            let one = $one;
1826
1827            let b_ptr = b.src();
1828            let b_rs = b.rs();
1829            let b_cs = b.cs();
1830            let c_rs = c.rs();
1831            let c_cs = c.cs();
1832            let c_ptr = c.src();
1833            let mut mc_i = mc_start;
1834            while mc_i < mc_end {
1835                let mc_len = mc.min(mc_end - mc_i);
1836                let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
1837                let mr_len = mr_end - mr_start;
1838                let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
1839                let mut kc_i = kc_start;
1840                while kc_i < kc_end {
1841                    let kc_len = kc.min(kc_end - kc_i);
1842                    let kc_len_eff = hw_cfg.round_k(kc_len);
1843                    let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
1844                    let kc_last = kc_i + kc_len == kc_end;
1845                    let mut nc_i = nc_start;
1846                    let ap_data = $packa_name(hw_cfg, &a, mc_i, kc_i, mc_len, kc_len, t_cfg);
1847                    let ap = ap_data.src();
1848                    let ap = ap.add(mr_start*kc_len_eff);
1849                    let b_j = b_ptr.add(kc_i * b_rs);
1850                    while nc_i < nc_end {
1851                        let nc_len = nc.min(nc_end - nc_i);
1852                        let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
1853                        let nr_len = nr_end - nr_start;
1854                        let c_ij = c_i.add((nc_i + nr_start) * c_cs);
1855                        let b_cur = b_j.add((nc_i + nr_start) * b_cs);
1856                        $small_m_kernel(
1857                            hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t,
1858                            b_cur, b_rs, b_cs,
1859                            c_ij, c_rs, c_cs,
1860                            ap,
1861                            kc_last,
1862                        );
1863                        nc_i += nc;
1864                    }
1865                    kc_i += kc;
1866                }
1867                mc_i += mc;
1868            }
1869
1870            if mc_left {
1871                let mut kc_i = kc_start;
1872                while kc_i < kc_end {
1873                    t_cfg.wait_packa();
1874                    t_cfg.wait_packa();
1875                    kc_i += kc;
1876                }
1877            }
1878        }
1879        unsafe fn $small_n_name<F:UnaryFnC>(
1880            hw_cfg: &$t_dispatcher <F>,
1881            m: usize, n: usize, k: usize,
1882            alpha: *const $t_as,
1883            a: $packa_ty,
1884            b: $packb_ty,
1885            beta: *const $t_bs,
1886            c: ArrayMut<$tc>,
1887            t_cfg: &PireThreadConfig
1888        ) {
1889            let par = &t_cfg.par;
1890            let ic_id = t_cfg.ic_id;
1891            let jc_id = t_cfg.jc_id;
1892            let ir_id = t_cfg.ir_id;
1893            let ir_par = par.ir_par;
1894            let jr_id = t_cfg.jr_id;
1895            let jr_par = par.jr_par;
1896            let mc = t_cfg.mc_eff;
1897            let nc = t_cfg.nc_eff;
1898            let kc = t_cfg.kc_eff;
1899            let mr = hw_cfg.mr;
1900            let nr = hw_cfg.nr;
1901            let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, par.ic_par);
1902            let (nc_start, nc_end, nc_left) = split_c_range(n, nc, nr, jc_id, par.jc_par);
1903            let (kc_start, kc_end) = (0, k);
1904            let one = $one;
1905
1906            let c_rs = c.rs();
1907            let c_cs = c.cs();
1908            let c_ptr = c.src();
1909            let a_ptr = a.src();
1910            let a_rs = a.rs();
1911            let a_cs = a.cs();
1912            // make sure this ap is the whole slice (pack index 0, full kc)
1913            let a_dst = a.dst_write(0, kc);
1914            let a_dst_ref = a_dst.get();
1915            let a_dst_ptr = a_dst_ref.as_mut_ptr();
1916            let mut mc_i = mc_start;
1917            while mc_i < mc_end {
1918                let mc_len = mc.min(mc_end - mc_i);
1919                let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
1920                let mr_len = mr_end - mr_start;
1921                let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
1922                let a_i = a_ptr.add((mc_i+mr_start) * a_rs);
1923                let mut kc_i = kc_start;
1924                while kc_i < kc_end {
1925                    let kc_len = kc.min(kc_end - kc_i);
1926                    let kc_last = kc_i + kc_len == kc_end;
1927                    let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
1928                    let a_cur = a_i.add(kc_i*a_cs);
1929                    let mut nc_i = nc_start;
1930                    while nc_i < nc_end {
1931                        let nc_len = nc.min(nc_end - nc_i);
1932                        let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
1933                        let nr_len = nr_end - nr_start;
1934                        let bp_data = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1935                        let bp = bp_data.src();
1936                        let c_ij = c_i.add((nc_i + nr_start) * c_cs);
1937                        $small_n_kernel(
1938                            hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t,
1939                            a_cur, a_rs, a_cs,
1940                            a_dst_ptr, bp,
1941                            c_ij, c_rs, c_cs,
1942                            kc_last,
1943                        );
1944                        nc_i += nc;
1945                    }
1946                    if nc_left {
1947                        t_cfg.wait_packb();
1948                        t_cfg.wait_packb();
1949                    }
1950                    kc_i += kc;
1951                }
1952                mc_i += mc;
1953            }
1954            if mc_left {
1955                let mut kc_i = kc_start;
1956                while kc_i < kc_end {
1957                    let kc_len = kc.min(kc_end - kc_i);
1958                    let mut nc_i = nc_start;
1959                    while nc_i < nc_end {
1960                        let nc_len = nc.min(nc_end - nc_i);
1961                        let _ = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1962                        nc_i += nc;
1963                    }
1964                    if nc_left {
1965                        t_cfg.wait_packb();
1966                        t_cfg.wait_packb();
1967                    }
1968                    kc_i += kc;
1969                }
1970            }
1971        }
1972        // for the packed api, mc_i (nc_i) should be a multiple of mr (nr), which we ensure via split_c_range
1973        // for the packed api, kc_i should be a multiple of kc_eff, which always holds since we don't parallelize over kc
1974        // this is subject to change if we ever parallelize over kc, but that is not in the plan
1975        // sync right before write and right before read
1976        // NOTE: don't return before the second wait_packa since it ensures sync between threads
1977        pub(crate) unsafe fn $packa_name<'a,'b,F:UnaryFnC>(hw_cfg: &$t_dispatcher <F>, x: &'b $packa_ty<'a>, mc_i: usize, kc_i: usize, mc_len: usize, kc_len: usize, t_cfg: &PireThreadConfig) -> PtrData<'a,$tap> {
1978            t_cfg.wait_packa();
1979            let xp_ptr = match x {
1980                $packa_ty::StridedMatrix(x_i) => {
1981                    let mc_par = x_i.get_mc();
1982                    let mc_offset = mc_par * t_cfg.i_load_p_idx;
1983                    if mc_len > mc_offset {
1984                        let kc_len_ro = hw_cfg.round_k(kc_len);
1985                        let mc_len_x = (mc_len - mc_offset).min(mc_par);
1986                        let mc_i = mc_i + mc_offset;
1987                        let (rs, cs) = (x_i.rs(), x_i.cs());
1988                        let src_ptr = x_i.src().add(mc_i*rs + kc_i*cs);
1989                        let dst = x_i.dst_write(t_cfg.i_load_p_idx, kc_len_ro);
1990                        let dst_ref = dst.get();
1991                        let dst_ptr = dst_ref.as_mut_ptr();
1992                        $packa_name0(src_ptr, dst_ptr, mc_len_x, kc_len, rs, cs);
1993                    }
1994                    t_cfg.wait_packa();
1995                    PtrData::RefData(x_i.dst_read())
1996                }
1997                $packa_ty::PackedMatrix(x_i) => {
1998                    let m_ro = hw_cfg.round_m(x_i.m());
1999                    let kc_len_ro = hw_cfg.round_k(kc_len);
2000                    let res = is_mixed!(
2001                        $include_flag,
2002                        {
2003                            let mc_par = x_i.get_mc();
2004                            let mc_offset = mc_par * t_cfg.i_load_p_idx;
2005                            if mc_len > mc_offset {
2006                                let mc_len_x = (mc_len - mc_offset).min(mc_par);
2007                                let mc_i = mc_i + mc_offset;
2008                                let src_ptr = x_i.src().add(mc_i*kc_len_ro + kc_i*m_ro);
2009                                let dst = x_i.dst_write(t_cfg.i_load_p_idx, kc_len_ro);
2010                                let dst_ref = dst.get();
2011                                let dst_ptr = dst_ref.as_mut_ptr();
2012                                let mc_len_x_ro = hw_cfg.round_m(mc_len_x);
2013                                hw_cfg.cvt_mixed(src_ptr, dst_ptr, mc_len_x_ro*kc_len_ro);
2014                            }
2015                            t_cfg.wait_packa();
2016                            PtrData::RefData(x_i.dst_read())
2017                        },
2018                        {
2019                            let src_ptr = x_i.src().add(mc_i*kc_len_ro + kc_i*m_ro);
2020                            t_cfg.wait_packa();
2021                            PtrData::PtrData(src_ptr)
2022                        }
2023
2024                    );
2025                    res
2026
2027                }
2028            };
2029            xp_ptr
2030        }
2031        // NOTE: don't return before the second wait_packb since it ensures sync between threads
2032        pub(crate) unsafe fn $packb_name<'a,'b,F:UnaryFnC>(hw_cfg: & $t_dispatcher <F>, x: &'b$packb_ty<'a>, nc_i: usize, kc_i: usize, nc_len: usize, kc_len: usize, t_cfg: &PireThreadConfig) -> PtrData<'a,$tbp> {
2033            t_cfg.wait_packb();
2034            let xp_ptr = match x {
2035                $packb_ty::StridedMatrix(x_i) => {
2036                    let nc_par = x_i.get_mc();
2037                    let nc_offset = nc_par * t_cfg.j_load_p_idx;
2038                    if nc_len > nc_offset {
2039                        let kc_len_ro = hw_cfg.round_k(kc_len);
2040                        let nc_len_x = (nc_len - nc_offset).min(nc_par);
2041                        let nc_i = nc_i + nc_offset;
2042                        let rs = x_i.rs();
2043                        let cs = x_i.cs();
2044                        let src_ptr = x_i.src().add(kc_i*rs + nc_i*cs);
2045                        let dst = x_i.dst_write(t_cfg.j_load_p_idx, kc_len_ro);
2046                        let dst_ref = dst.get();
2047                        let dst_ptr = dst_ref.as_mut_ptr();
2048                        $packb_name0(src_ptr, dst_ptr, nc_len_x, kc_len, rs, cs);
2049                    }
2050                    t_cfg.wait_packb();
2051                    PtrData::RefData(x_i.dst_read())
2052                }
2053                $packb_ty::PackedMatrix(x_i) => {
2054                    let kc_len_ro = hw_cfg.round_k(kc_len);
2055                    let n_ro = x_i.m();
2056                    let res = is_mixed!(
2057                        $include_flag,
2058                        {
2059                            let nc_par = x_i.get_mc();
2060                            let nc_offset = nc_par * t_cfg.j_load_p_idx;
2061                            if nc_len > nc_offset {
2062                                let nc_len_x = (nc_len - nc_offset).min(nc_par);
2063                                let nc_i = nc_i + nc_offset;
2064                                let src_ptr = x_i.src().add(nc_i*kc_len_ro + kc_i*n_ro);
2065                                let dst = x_i.dst_write(t_cfg.j_load_p_idx, kc_len_ro);
2066                                let dst_ref = dst.get();
2067                                let dst_ptr = dst_ref.as_mut_ptr();
2068                                hw_cfg.cvt_mixed(src_ptr, dst_ptr, nc_len_x*kc_len_ro);
2069                            }
2070                            t_cfg.wait_packb();
2071                            PtrData::RefData(x_i.dst_read())
2072                        },
2073                        {
2074                            let src_ptr = x_i.src().add(nc_i*kc_len_ro + kc_i*n_ro);
2075                            t_cfg.wait_packb();
2076                            PtrData::PtrData(src_ptr)
2077                        }
2078
2079                    );
2080                    res
2081                }
2082            };
2083            xp_ptr
2084        }
2085    }
2086}
2087#[macro_export]
2088macro_rules! partial_strided {
2089    ($strided:tt, $strided2:tt, F) => {
2090        $strided
2091    };
2092    ($strided:tt, $strided2:tt, T) => {
2093        $strided2
2094    };
2095}
2096
2097#[target_feature(enable = "avx")]
2098#[cfg(target_arch = "x86_64")]
2099unsafe fn vzeroupper_unchecked() {
2100    core::arch::x86_64::_mm256_zeroupper();
2101}
2102
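// Executing vzeroupper after 256-bit kernels avoids the AVX -> legacy-SSE
// transition penalty in subsequent SSE code. The runtime feature check makes
// `avx_vzeroupper` safe to call unconditionally, and on non-x86_64 targets the
// function compiles down to a no-op.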
2103pub fn avx_vzeroupper() {
2104    #[cfg(target_arch = "x86_64")]
2105    if (*RUNTIME_HW_CONFIG).cpu_ft.avx {
2106        unsafe {
2107            vzeroupper_unchecked();
2108        }
2109    }
2110}
2111
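// `def_kernel_bb_pf1!` emits a packed-packed (bb) macro-kernel with software
// prefetch of the next A panel: `$pf1_0` sets the initial prefetch offset in
// multiples of k, and `$pf_step` advances it after each NR-wide column block.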
2112#[macro_export]
2113macro_rules! def_kernel_bb_pf1 {
2114    (
2115        $t_ap:ty, $t_bp:ty, $t_c:ty, $t_s:ty,
2116        $no_partial:tt,
2117        $RS:tt,
2118        $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt
2119    ) => {
2120
2121        pub unsafe fn kernel_bb<F: UnaryFnC, const STRIDED: bool>(
2122            m: usize, n: usize, k: usize,
2123            alpha: *const $t_s,
2124            beta: *const $t_s,
2125            c: *mut $t_c, c_rs: usize, c_cs: usize,
2126            ap: *const $t_ap, bp: *const $t_bp,
2127            f: F,
2128        ) {
2129            const STRIDED_PARTIAL: bool = true;
2130            const MR: usize = $MR * VS;
2131            const NR: usize = $NR;
2132            let m_rounded = m / MR * MR;
2133            let n_rounded = n / NR * NR;
2134            let m_left = m % MR;
2135            let n_left = n % NR;
2136
2137            let d_arr = [0, 0, c_rs];
2138
2139            let mut m_i = 0;
2140            while m_i < m_rounded {
2141                let c_cur0 = c.add(m_i * c_rs);
2142                let ap_cur = ap.add(m_i * k);
2143                let mut a_pft1_offset = $pf1_0 * k;
2144                let mut n_i = 0;
2145                while n_i < n_rounded {
2146                    let bp_cur = bp.add(n_i * k);
2147                    let c_cur1 = c_cur0.add(n_i * c_cs);
2148                    ukernel_bbc::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, a_pft1_offset, NR, f);
2149                    n_i += NR;
2150                    a_pft1_offset += $pf_step * k;
2151                }
2153                if n_left != 0 {
2154                    let bp_cur = bp.add(n_i * k);
2155                    let c_cur1 = c_cur0.add(n_i * c_cs);
2156                    ukernel_n_bbc::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, MR, n_left, f);
2157                }
2158                m_i += MR;
2159            }
2160
2161
2162            seq_macro::seq!(mr_left in 1..=$MR {
2163                if (m_left+VS-1) / VS == mr_left {
2164                    let c_cur0 = c.add(m_i * c_rs);
2165                    let ap_cur = ap.add(m_i * k);
2166                    let mut n_i = 0;
2167                    while n_i < n_rounded {
2168                        let bp_cur = bp.add(n_i * k);
2169                        let c_cur1 = c_cur0.add(n_i * c_cs);
2170                        paste::paste! {
2171                            [<ukernel_ mr_left _bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, NR, f);
2172                        }
2173                        n_i += NR;
2174                    }
2175                    if n_left != 0 {
2176                        let bp_cur = bp.add(n_i * k);
2177                        let c_cur1 = c_cur0.add(n_i * c_cs);
2178                        paste::paste! {
2179                            [<ukernel_ mr_left xn_bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, n_left, f);
2180                        }
2181                    }
2182                }
2183            });
2184        }
2185
2186        pub(crate) unsafe fn kernel<F: UnaryFnC>(
2187            m: usize,
2188            n: usize,
2189            k: usize,
2190            alpha: *const $t_s,
2191            beta: *const $t_s,
2192            c: *mut $t_c,
2193            c_rs: usize,
2194            c_cs: usize,
2195            ap: *const $t_ap,
2196            bp: *const $t_bp,
2197            f: F,
2198        ) {
2199            let k = (k + $RS - 1) / $RS * $RS;
2200            if c_rs == 1 {
2201                kernel_bb::<_, false>(m, n, k, alpha, beta, c, c_rs, c_cs, ap, bp, f)
2202            } else {
2203                kernel_bb::<_, true>(m, n, k, alpha, beta, c, c_rs, c_cs, ap, bp, f)
2204            }
2205            pire_base::avx_vzeroupper();
2206        }
2207
2208    };
2209}
2210
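// `def_kernel_bb_v0!` is the prefetch-free variant of the kernel above. It also
// queries the SIMD width at run time via `simd_vector_length()` instead of the
// compile-time `VS` constant, so a single expansion can serve targets whose
// vector width is only known at run time.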
2211#[macro_export]
2212macro_rules! def_kernel_bb_v0 {
2213    (
2214        $t_ap:ty, $t_bp:ty, $t_c:ty, $t_s:ty,
2215        $no_partial:tt,
2216        $RS:tt,
2217        $MR:tt, $NR:tt
2218    ) => {
2219        pub unsafe fn kernel_bb<F: UnaryFnC, const STRIDED: bool>(
2220            m: usize, n: usize, k: usize,
2221            alpha: *const $t_s,
2222            beta: *const $t_s,
2223            c: *mut $t_c, c_rs: usize, c_cs: usize,
2224            ap: *const $t_ap, bp: *const $t_bp,
2225            f: F,
2226        ) {
2227            let vs = simd_vector_length();
2228            const STRIDED_PARTIAL: bool = true;
2229            let mr = $MR * vs;
2230            const NR: usize = $NR;
2231            let m_rounded = m / mr * mr;
2232            let n_rounded = n / NR * NR;
2233            let m_left = m % mr;
2234            let n_left = n % NR;
2235
2236            let d_arr = [0, 0, c_rs];
2237
2238            let mut m_i = 0;
2239            while m_i < m_rounded {
2240                let c_cur0 = c.add(m_i * c_rs);
2241                let ap_cur = ap.add(m_i * k);
2242                let mut n_i = 0;
2243                while n_i < n_rounded {
2244                    let bp_cur = bp.add(n_i * k);
2245                    let c_cur1 = c_cur0.add(n_i * c_cs);
2246                    ukernel_bbc::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, mr, NR, f);
2247                    n_i += NR;
2248                }
2249                if n_left != 0 {
2250                    let bp_cur = bp.add(n_i * k);
2251                    let c_cur1 = c_cur0.add(n_i * c_cs);
2252                    ukernel_n_bbc::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, mr, n_left, f);
2253                }
2254                m_i += mr;
2255            }
2256
2257            seq_macro::seq!(mr_left in 1..=$MR {
2258                if (m_left+vs-1) / vs == mr_left {
2259                    let c_cur0 = c.add(m_i * c_rs);
2260                    let ap_cur = ap.add(m_i * k);
2261                    let mut n_i = 0;
2262                    while n_i < n_rounded {
2263                        let bp_cur = bp.add(n_i * k);
2264                        let c_cur1 = c_cur0.add(n_i * c_cs);
2265                        paste::paste! {
2266                            [<ukernel_ mr_left _bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, NR, f);
2267                        }
2268                        n_i += NR;
2269                    }
2270                    if n_left != 0 {
2271                        let bp_cur = bp.add(n_i * k);
2272                        let c_cur1 = c_cur0.add(n_i * c_cs);
2273                        paste::paste! {
2274                            [<ukernel_ mr_left xn_bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, n_left, f);
2275                        }
2276                    }
2277                }
2278            });
2279        }
2280        pub(crate) unsafe fn kernel<F: UnaryFnC>(
2281            m: usize,
2282            n: usize,
2283            k: usize,
2284            alpha: *const $t_s,
2285            beta: *const $t_s,
2286            c: *mut $t_c,
2287            c_rs: usize,
2288            c_cs: usize,
2289            ap: *const $t_ap,
2290            bp: *const $t_bp,
2291            f: F,
2292        ) {
2293            let k = (k + $RS - 1) / $RS * $RS;
2294            if c_rs == 1 {
2295                kernel_bb::<_, false>(m, n, k, alpha, beta, c, c_rs, c_cs, ap, bp, f)
2296            } else {
2297                kernel_bb::<_, true>(m, n, k, alpha, beta, c, c_rs, c_cs, ap, bp, f)
2298            }
2299            pire_base::avx_vzeroupper();
2300        }
2301    };
2302}
2303
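// The sb kernels multiply a strided A by a pre-packed B: each MR-high row panel
// of A is packed into the `ap` scratch buffer with `$pack_fn` immediately before
// the micro-kernel calls that consume it.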
2304#[macro_export]
2305macro_rules! def_kernel_sb_pf1 {
2306    (
2307        $t_a:ty, $t_ap:ty, $t_bp:ty, $t_c:ty, $t_s:ty,
2308        $pack_fn:tt,
2309        $RS:tt,
2310        $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt
2311    ) => {
2312        pub unsafe fn kernel_sb_v0<F: UnaryFnC, const STRIDED: bool>(
2313            m: usize, n: usize, k: usize,
2314            alpha: *const $t_s, beta: *const $t_s,
2315            a: *const $t_a, a_rs: usize, a_cs: usize,
2316            bp: *const $t_bp,
2317            c: *mut $t_c, c_rs: usize, c_cs: usize,
2318            ap: *mut $t_ap,
2319            f: F,
2320        ) {
2321            let k_eff = (k+$RS-1) / $RS * $RS;
2322            const MR: usize = $MR * VS;
2323            const NR: usize = $NR;
2324            let m_rounded = m / MR * MR;
2325            let n_rounded = n / NR * NR;
2326            let m_left = m % MR;
2327            let n_left = n % NR;
2328
2329            let d_arr = [0, 0, c_rs];
2330
2331            let mut m_i = 0;
2332            while m_i < m_rounded {
2333                let c_cur0 = c.add(m_i * c_rs);
2334                let a_cur = a.add(m_i * a_rs);
2335                let a_pft1_offset = $pf1_0 * k;
2336                $pack_fn(MR, k, a_cur, a_rs, a_cs, ap, VS);
2337                let mut n_i = 0;
2338                while n_i < n_rounded {
2339                    let bp_cur = bp.add(n_i * k_eff);
2340                    let c_cur1 = c_cur0.add(n_i * c_cs);
2341                    ukernel_bbc::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, a_pft1_offset, NR, f);
2342                    n_i += NR;
2343                }
2344                if n_left != 0 {
2345                    let bp_cur = bp.add(n_i * k_eff);
2346                    let c_cur1 = c_cur0.add(n_i * c_cs);
2347                    ukernel_n_bbc::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, MR, n_left, f);
2348                }
2349                m_i += MR;
2350            }
2351
2352            seq_macro::seq!(mr_left in 1..=$MR {
2353                if (m_left+VS-1) / VS == mr_left {
2354                    let c_cur0 = c.add(m_i * c_rs);
2355                    let a_cur = a.add(m_i * a_rs);
2356                    $pack_fn(m_left, k, a_cur, a_rs, a_cs, ap, VS);
2357                    let mut n_i = 0;
2358                    while n_i < n_rounded {
2359                        let bp_cur = bp.add(n_i * k_eff);
2360                        let c_cur1 = c_cur0.add(n_i * c_cs);
2361                        paste::paste! {
2362                            [<ukernel_ mr_left _bbp>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, m_left, NR, f);
2363                        }
2364                        n_i += NR;
2365                    }
2366                    if n_left != 0 {
2367                        let bp_cur = bp.add(n_i * k_eff);
2368                        let c_cur1 = c_cur0.add(n_i * c_cs);
2369                        paste::paste! {
2370                            [<ukernel_ mr_left xn_bbp>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, m_left, n_left, f);
2371                        }
2372                    }
2373                    return;
2374                }
2375            });
2376        }
2377
2378        pub(crate) unsafe fn kernel_sb<F: UnaryFnC>(
2379            m: usize,
2380            n: usize,
2381            k: usize,
2382            alpha: *const $t_s,
2383            beta: *const $t_s,
2384            a: *const $t_a,
2385            a_rs: usize,
2386            a_cs: usize,
2387            b: *const $t_bp,
2388            c: *mut $t_c,
2389            c_rs: usize,
2390            c_cs: usize,
2391            ap_buf: *mut $t_ap,
2392            f: F,
2393        ) {
2394            if c_rs == 1 {
2395                kernel_sb_v0::<_, false>(m, n, k, alpha, beta, a, a_rs, a_cs, b, c, c_rs, c_cs, ap_buf, f);
2396            } else {
2397                kernel_sb_v0::<_, true>(m, n, k, alpha, beta, a, a_rs, a_cs, b, c, c_rs, c_cs, ap_buf, f);
2398            }
2399            pire_base::avx_vzeroupper();
2400        }
2401    };
2402}
2403
2404#[macro_export]
2405macro_rules! def_kernel_sb_v0 {
2406    (
2407        $t_a:ty, $t_ap:ty, $t_bp:ty, $t_c:ty, $t_s:ty,
2408        $no_partial:tt,
2409        $pack_fn:tt,
2410        $RS:tt,
2411        $MR:tt, $NR:tt
2412    ) => {
2413        pub unsafe fn kernel_sb_v0<F: UnaryFnC, const STRIDED: bool>(
2414            m: usize, n: usize, k: usize,
2415            alpha: *const $t_s, beta: *const $t_s,
2416            a: *const $t_a, a_rs: usize, a_cs: usize,
2417            bp: *const $t_bp,
2418            c: *mut $t_c, c_rs: usize, c_cs: usize,
2419            ap: *mut $t_ap,
2420            f: F,
2421        ) {
2422            let k_eff = (k+$RS-1) / $RS * $RS;
2423            const STRIDED_PARTIAL: bool = true;
2424            let vs = simd_vector_length();
2425            let mr = $MR * vs;
2426            const NR: usize = $NR;
2427            let m_rounded = m / mr * mr;
2428            let n_rounded = n / NR * NR;
2429            let m_left = m % mr;
2430            let n_left = n % NR;
2431
2432            let d_arr = [0, 0, c_rs];
2433
2434            let mut m_i = 0;
2435            while m_i < m_rounded {
2436                let c_cur0 = c.add(m_i * c_rs);
2437                let a_cur = a.add(m_i * a_rs);
2438                $pack_fn(mr, k, a_cur, a_rs, a_cs, ap, vs);
2439                let mut n_i = 0;
2440                while n_i < n_rounded {
2441                    let bp_cur = bp.add(n_i * k_eff);
2442                    let c_cur1 = c_cur0.add(n_i * c_cs);
2443                    ukernel_bbc::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, mr, NR, f);
2444                    n_i += NR;
2445                }
2446                if n_left != 0 {
2447                    let bp_cur = bp.add(n_i * k_eff);
2448                    let c_cur1 = c_cur0.add(n_i * c_cs);
2449                    ukernel_n_bbc::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, mr, n_left, f);
2450                }
2451                m_i += mr;
2452            }
2453
2454            seq_macro::seq!(mr_left in 1..=$MR {
2455                if (m_left+vs-1) / vs == mr_left {
2456                    let c_cur0 = c.add(m_i * c_rs);
2457                    let a_cur = a.add(m_i * a_rs);
2458                    $pack_fn(m_left, k, a_cur, a_rs, a_cs, ap, vs);
2459                    let mut n_i = 0;
2460                    while n_i < n_rounded {
2461                        let bp_cur = bp.add(n_i * k_eff);
2462                        let c_cur1 = c_cur0.add(n_i * c_cs);
2463                        paste::paste! {
2464                            [<ukernel_ mr_left _bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, m_left, NR, f);
2465                        }
2466                        n_i += NR;
2467                    }
2468                    if n_left != 0 {
2469                        let bp_cur = bp.add(n_i * k_eff);
2470                        let c_cur1 = c_cur0.add(n_i * c_cs);
2471                        paste::paste! {
2472                            [<ukernel_ mr_left xn_bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, m_left, n_left, f);
2473                        }
2474                    }
2475                    return;
2476                }
2477            });
2478        }
2479
2480        pub(crate) unsafe fn kernel_sb<F: UnaryFnC>(
2481            m: usize,
2482            n: usize,
2483            k: usize,
2484            alpha: *const $t_s,
2485            beta: *const $t_s,
2486            a: *const $t_a,
2487            a_rs: usize,
2488            a_cs: usize,
2489            b: *const $t_bp,
2490            c: *mut $t_c,
2491            c_rs: usize,
2492            c_cs: usize,
2493            ap_buf: *mut $t_ap,
2494            f: F,
2495        ) {
2496            if c_rs == 1 {
2497                kernel_sb_v0::<_, false>(m, n, k, alpha, beta, a, a_rs, a_cs, b, c, c_rs, c_cs, ap_buf, f);
2498            } else {
2499                kernel_sb_v0::<_, true>(m, n, k, alpha, beta, a, a_rs, a_cs, b, c, c_rs, c_cs, ap_buf, f);
2500            }
2501            pire_base::avx_vzeroupper();
2502        }
2503    };
2504}
2505
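// The bs kernels are the mirror case: pre-packed A against strided B, whose row
// and column strides reach the micro-kernel through `d_arr`.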
2506#[macro_export]
2507macro_rules! def_kernel_bs {
2508    (
2509        $t_ap:ty, $t_b:ty, $t_c:ty, $t_s:ty,
2510        $MR:tt, $NR:tt
2511    ) => {
2512        pub unsafe fn kernel_bs_v0<F: UnaryFnC, const STRIDED: bool>(
2513            m: usize, n: usize, k: usize,
2514            alpha: *const $t_s, beta: *const $t_s,
2515            b: *const $t_b, b_rs: usize, b_cs: usize,
2516            c: *mut $t_c, c_rs: usize, c_cs: usize,
2517            ap: *const $t_ap,
2518            f: F,
2519        ) {
2520            const MR: usize = $MR * VS;
2521            const NR: usize = $NR;
2522            let m_rounded = m / MR * MR;
2523            let n_rounded = n / NR * NR;
2524            let m_left = m % MR;
2525            let n_left = n % NR;
2526
2527            let d_arr = [b_rs, b_cs, c_rs];
2528
2529            let mut m_i = 0;
2530            while m_i < m_rounded {
2531                let c_cur0 = c.add(m_i * c_rs);
2532                let ap_cur = ap.add(m_i * k);
2533                let mut n_i = 0;
2534                while n_i < n_rounded {
2535                    let b_cur = b.add(n_i * b_cs);
2536                    let c_cur1 = c_cur0.add(n_i * c_cs);
2537                    ukernel_bsc::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, c_cs, MR, NR, f);
2538                    n_i += NR;
2539                }
2540                if n_left != 0 {
2541                    let b_cur = b.add(n_i * b_cs);
2542                    let c_cur1 = c_cur0.add(n_i * c_cs);
2543                    ukernel_n_bsc::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, c_cs, MR, n_left, f);
2544                }
2545                m_i += MR;
2546            }
2547            seq_macro::seq!(mr_left in 1..=$MR {
2548                if (m_left+VS-1) / VS == mr_left {
2549                    let c_cur0 = c.add(m_i * c_rs);
2550                    let ap_cur = ap.add(m_i * k);
2551                    let mut n_i = 0;
2552                    while n_i < n_rounded {
2553                        let b_cur = b.add(n_i * b_cs);
2554                        let c_cur1 = c_cur0.add(n_i * c_cs);
2555                        paste::paste! {
2556                            [<ukernel_ mr_left _bsp>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, NR, f);
2557                        }
2558                        n_i += NR;
2559                    }
2560                    if n_left != 0 {
2561                        let b_cur = b.add(n_i * b_cs);
2562                        let c_cur1 = c_cur0.add(n_i * c_cs);
2563                        paste::paste! {
2564                            [<ukernel_ mr_left xn_bsp>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, n_left, f);
2565                        }
2566                    }
2567                    return;
2568                }
2569            });
2570        }
2571
2572        pub(crate) unsafe fn kernel_bs<F: UnaryFnC>(
2573            m: usize,
2574            n: usize,
2575            k: usize,
2576            alpha: *const $t_s,
2577            beta: *const $t_s,
2578            b: *const $t_b,
2579            b_rs: usize,
2580            b_cs: usize,
2581            c: *mut $t_c,
2582            c_rs: usize,
2583            c_cs: usize,
2584            ap: *const $t_ap,
2585            f: F,
2586        ) {
2587            if c_rs == 1 {
2588                kernel_bs_v0::<_, false>(m, n, k, alpha, beta, b, b_rs, b_cs, c, c_rs, c_cs, ap, f);
2589            } else {
2590                kernel_bs_v0::<_, true>(m, n, k, alpha, beta, b, b_rs, b_cs, c, c_rs, c_cs, ap, f);
2591            }
2592            pire_base::avx_vzeroupper();
2593        }
2594    };
2595}
2596
2597#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2598#[macro_export]
2599macro_rules! mem {
2600    ($m0:tt, $b0:tt) => {
2601        concat!($b0, "+", $m0)
2602    };
2603}
2604
2605#[cfg(target_arch = "aarch64")]
2606#[macro_export]
2607macro_rules! mem {
2608    ($m0:tt, $b0:tt, $b1:tt) => {
2609        concat!("[", $m0, ", #", $b0, ", ", $b1, "]")
2610    };
2611    ($m0:tt, $b0:tt) => {
2612        concat!("[", $m0, ", #", $b0, "]")
2613    };
2614    ($m0:tt) => {
2615        concat!("[", $m0, "]")
2616    };
2617}
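// `mem!` renders a memory operand for inline asm in the target's own syntax: on
// x86/x86_64 it prepends a byte displacement to an AT&T operand, so
// `mem!("0({cx})", "0x20")` expands to "0x20+0({cx})", while on aarch64 it
// builds the bracketed `[base, #offset]` form.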
2618
2619#[macro_export]
2620macro_rules! n_cond {
2621    (1, $ni:tt, $nr:tt) => {
2622        $ni == $nr
2623    };
2624    ($n0:tt, $ni:tt, $nr:tt) => {
2625        true
2626    };
2627}
2628
2629#[macro_export]
2630macro_rules! load_a_avx {
2631    ($mr:tt, $K:tt) => {
2632        pire_base::loadp_avx!($mr, concat!($mr, "*32*", $K, "({ax})"))
2633    };
2634}
2635#[macro_export]
2636macro_rules! load_a_avx512 {
2637    ($mr:tt) => {
2638        pire_base::loadp_avx512!($mr, "0({ax})")
2639    };
2640}
2641
2642/*
2643
2644x1 -> cs_a
2645x2 -> cs_b
2646x3 -> ax + 3*cs_a
2647x4 -> bx + 3*cs_b
2648
2649*/
2650
2651#[macro_export]
2652macro_rules! init_ab {
2653    (B) => {
2654        concat!(
2655            "/* {x5} */\n",
2656            "/* {x4} */\n",
2657            "/* {x3} */\n",
2658            "/* {x2} */\n",
2659            "/* {x1} */\n",
2660            "mov 24({dim_arrx}),{x0}\n",
2661        )
2662    };
2663    (S) => {
2664        concat!(
2665            // load cs_a and cs_b into regs
2666            "mov ({dim_arrx}), {x1}\n",
2667            "mov 8({dim_arrx}), {x2}\n",
2668            "lea ({x2}, {x2}, 2), {x5}\n",
2669            "lea ({bx}, {x5}, 1), {x3}\n",
2670            "lea ({x3}, {x5}, 1), {x4}\n",
2671            "lea ({x4}, {x5}, 1), {x5}\n",
2672            "mov 24({dim_arrx}),{x0}\n",
2673        )
2674    };
2675}
2676
2677#[macro_export]
2678macro_rules! init_ab_2 {
2679    (B) => {
2680        concat!("mov 8({dim_arrx}),{x0}\n",)
2681    };
2682}
2683
2684#[macro_export]
2685macro_rules! c_load {
2686    () => {
2687        concat!(
2688            "mov 16({dim_arrx}),{x0}\n",
2689            "lea ({x0}, {x0}, 2), {x3}\n",
2690            "lea ({cx}, {x3},), {x1}\n",
2691            "lea ({x1}, {x3},), {x2}\n",
2692            "lea ({x2}, {x3},), {x3}\n",
2693        )
2694    };
2695}
2696
2697#[macro_export]
2698macro_rules! init_ab_avx {
2699    (B) => {
2700        concat!("/* {x3} */\n", "/* {x2} */\n", "/* {x1} */\n", "mov 24({dim_arrx}),{x0}\n",)
2701    };
2702    (S) => {
2703        concat!(
2704            // load cs_a and cs_b into regs
2705            "mov ({dim_arrx}), {x1}\n",
2706            "mov 8({dim_arrx}), {x2}\n",
2707            "lea ({x2}, {x2}, 2), {x3}\n",
2708            "lea ({bx}, {x3}, 1), {x3}\n",
2709            "mov 24({dim_arrx}),{x0}\n",
2710        )
2711    };
2712}
2713
2714#[macro_export]
2715macro_rules! b_reg {
2716    (0) => {
2717        "({bx})"
2718    };
2719    (1) => {
2720        "({bx},{x2},1)"
2721    };
2722    (2) => {
2723        "({bx},{x2},2)"
2724    };
2725    (3) => {
2726        "({x3})"
2727    };
2728    (4) => {
2729        "({x3},{x2},1)"
2730    };
2731    (5) => {
2732        "({x3},{x2},2)"
2733    };
2734    (6) => {
2735        "({x4})"
2736    };
2737    (7) => {
2738        "({x4},{x2},1)"
2739    };
2740    (8) => {
2741        "({x4},{x2},2)"
2742    };
2743    (9) => {
2744        "({x5})"
2745    };
2746    (10) => {
2747        "({x5},{x2},1)"
2748    };
2749    (11) => {
2750        "({x5},{x2},2)"
2751    };
2752}
2753
2754#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2755#[macro_export]
2756macro_rules! c_mem {
2757    (0) => {
2758        "0({cx})"
2759    };
2760    (1) => {
2761        "0({cx}, {x0})"
2762    };
2763    (2) => {
2764        "0({cx}, {x0}, 2)"
2765    };
2766    (3) => {
2767        "0({x1})"
2768    };
2769    (4) => {
2770        "0({x1}, {x0})"
2771    };
2772    (5) => {
2773        "0({x1}, {x0}, 2)"
2774    };
2775    (6) => {
2776        "0({x2})"
2777    };
2778    (7) => {
2779        "0({x2}, {x0})"
2780    };
2781    (8) => {
2782        "0({x2}, {x0}, 2)"
2783    };
2784    (9) => {
2785        "0({x3})"
2786    };
2787    (10) => {
2788        "0({x3}, {x0})"
2789    };
2790    (11) => {
2791        "0({x3}, {x0}, 2)"
2792    };
2793    (12) => {
2794        "0({x4})"
2795    };
2796    (13) => {
2797        "0({x4}, {x0})"
2798    };
2799    (14) => {
2800        "0({x4}, {x0}, 2)"
2801    };
2802}
2803
2804#[cfg(target_arch = "aarch64")]
2805#[macro_export]
2806macro_rules! c_mem {
2807    (0) => {
2808        "{cx}"
2809    };
2810    (1) => {
2811        "{x1}"
2812    };
2813    (2) => {
2814        "{x2}"
2815    };
2816    (3) => {
2817        "{x3}"
2818    };
2819    (4) => {
2820        "{x4}"
2821    };
2822    (5) => {
2823        "{x5}"
2824    };
2825    (6) => {
2826        "{x6}"
2827    };
2828    (7) => {
2829        "{x7}"
2830    };
2831    (8) => {
2832        "{x8}"
2833    };
2834    (9) => {
2835        "{x9}"
2836    };
2837    (10) => {
2838        "{x10}"
2839    };
2840    (11) => {
2841        "{x11}"
2842    };
2843    (12) => {
2844        "{x12}"
2845    };
2846    (13) => {
2847        "{x13}"
2848    };
2849    (14) => {
2850        "{x14}"
2851    };
2852}
2853
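// The c_reg_RxC! tables map a (row, column) position of the RxC accumulator
// micro-tile to the SIMD register index holding it; the AVX-sized tables stay
// within registers 4..=15 (ymm) and the AVX-512 tables within 8..=31 (zmm).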
2854#[macro_export]
2855macro_rules! c_reg_2x4 {
2856    (0,0) => {
2857        4
2858    };
2859    (1,0) => {
2860        5
2861    };
2862    (0,1) => {
2863        6
2864    };
2865    (1,1) => {
2866        7
2867    };
2868    (0,2) => {
2869        8
2870    };
2871    (1,2) => {
2872        9
2873    };
2874    (0,3) => {
2875        10
2876    };
2877    (1,3) => {
2878        11
2879    };
2880}
2881#[macro_export]
2882macro_rules! c_reg_1x4 {
2883    (0,0) => {
2884        7
2885    };
2886    (0,1) => {
2887        8
2888    };
2889    (0,2) => {
2890        9
2891    };
2892    (0,3) => {
2893        10
2894    };
2895}
2896#[macro_export]
2897macro_rules! c_reg_3x4 {
2898    (0,0) => {
2899        4
2900    };
2901    (1,0) => {
2902        5
2903    };
2904    (2,0) => {
2905        6
2906    };
2907    (0,1) => {
2908        7
2909    };
2910    (1,1) => {
2911        8
2912    };
2913    (2,1) => {
2914        9
2915    };
2916    (0,2) => {
2917        10
2918    };
2919    (1,2) => {
2920        11
2921    };
2922    (2,2) => {
2923        12
2924    };
2925    (0,3) => {
2926        13
2927    };
2928    (1,3) => {
2929        14
2930    };
2931    (2,3) => {
2932        15
2933    };
2934}
2935#[macro_export]
2936macro_rules! c_reg_2x6 {
2937    (0,0) => {
2938        4
2939    };
2940    (1,0) => {
2941        5
2942    };
2943    (0,1) => {
2944        6
2945    };
2946    (1,1) => {
2947        7
2948    };
2949    (0,2) => {
2950        8
2951    };
2952    (1,2) => {
2953        9
2954    };
2955    (0,3) => {
2956        10
2957    };
2958    (1,3) => {
2959        11
2960    };
2961    (0,4) => {
2962        12
2963    };
2964    (1,4) => {
2965        13
2966    };
2967    (0,5) => {
2968        14
2969    };
2970    (1,5) => {
2971        15
2972    };
2973}
2974#[macro_export]
2975macro_rules! c_reg_1x6 {
2976    (0,0) => {
2977        7
2978    };
2979    (0,1) => {
2980        8
2981    };
2982    (0,2) => {
2983        9
2984    };
2985    (0,3) => {
2986        10
2987    };
2988    (0,4) => {
2989        11
2990    };
2991    (0,5) => {
2992        12
2993    };
2994}
2995#[macro_export]
2996macro_rules! c_reg_3x8 {
2997    (0,0) => {
2998        8
2999    };
3000    (1,0) => {
3001        9
3002    };
3003    (2,0) => {
3004        10
3005    };
3006    (0,1) => {
3007        11
3008    };
3009    (1,1) => {
3010        12
3011    };
3012    (2,1) => {
3013        13
3014    };
3015    (0,2) => {
3016        14
3017    };
3018    (1,2) => {
3019        15
3020    };
3021    (2,2) => {
3022        16
3023    };
3024    (0,3) => {
3025        17
3026    };
3027    (1,3) => {
3028        18
3029    };
3030    (2,3) => {
3031        19
3032    };
3033    (0,4) => {
3034        20
3035    };
3036    (1,4) => {
3037        21
3038    };
3039    (2,4) => {
3040        22
3041    };
3042    (0,5) => {
3043        23
3044    };
3045    (1,5) => {
3046        24
3047    };
3048    (2,5) => {
3049        25
3050    };
3051    (0,6) => {
3052        26
3053    };
3054    (1,6) => {
3055        27
3056    };
3057    (2,6) => {
3058        28
3059    };
3060    (0,7) => {
3061        29
3062    };
3063    (1,7) => {
3064        30
3065    };
3066    (2,7) => {
3067        31
3068    };
3069}
3070#[macro_export]
3071macro_rules! c_reg_2x12 {
3072    (0,0) => {
3073        8
3074    };
3075    (1,0) => {
3076        9
3077    };
3078    (0,1) => {
3079        10
3080    };
3081    (1,1) => {
3082        11
3083    };
3084    (0,2) => {
3085        12
3086    };
3087    (1,2) => {
3088        13
3089    };
3090    (0,3) => {
3091        14
3092    };
3093    (1,3) => {
3094        15
3095    };
3096    (0,4) => {
3097        16
3098    };
3099    (1,4) => {
3100        17
3101    };
3102    (0,5) => {
3103        18
3104    };
3105    (1,5) => {
3106        19
3107    };
3108    (0,6) => {
3109        20
3110    };
3111    (1,6) => {
3112        21
3113    };
3114    (0,7) => {
3115        22
3116    };
3117    (1,7) => {
3118        23
3119    };
3120    (0,8) => {
3121        24
3122    };
3123    (1,8) => {
3124        25
3125    };
3126    (0,9) => {
3127        26
3128    };
3129    (1,9) => {
3130        27
3131    };
3132    (0,10) => {
3133        28
3134    };
3135    (1,10) => {
3136        29
3137    };
3138    (0,11) => {
3139        30
3140    };
3141    (1,11) => {
3142        31
3143    };
3144}
3145#[macro_export]
3146macro_rules! c_reg_1x12 {
3147    (0,0) => {
3148        9
3149    };
3150    (0,1) => {
3151        10
3152    };
3153    (0,2) => {
3154        11
3155    };
3156    (0,3) => {
3157        12
3158    };
3159    (0,4) => {
3160        13
3161    };
3162    (0,5) => {
3163        14
3164    };
3165    (0,6) => {
3166        15
3167    };
3168    (0,7) => {
3169        16
3170    };
3171    (0,8) => {
3172        17
3173    };
3174    (0,9) => {
3175        18
3176    };
3177    (0,10) => {
3178        19
3179    };
3180    (0,11) => {
3181        20
3182    };
3183}
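// The acc_MxN!/store_MxN! pairs tie each accumulator tile to its C memory:
// acc_* accumulates the existing C values (scaled via beta_fmadd) into the live
// registers, with the last register of a row taking `$layout` so partial tiles
// can use a different load path, and store_* writes them back, stepping 0x20
// bytes per ymm register and 0x40 bytes per zmm register.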
3184
3185#[macro_export]
3186macro_rules! acc_3x4 {
3187    ($ni:tt, $layout:tt, $q:tt) => {
3188        acc_p_avx!($layout, c_mem!($ni), $q, c_reg_3x4!(0, $ni), c_reg_3x4!(1, $ni), c_reg_3x4!(2, $ni))
3189    };
3190}
3191#[macro_export]
3192macro_rules! store_3x4 {
3193    ($ni:tt, $layout:tt) => {
3194        storep_avx!($layout, c_mem!($ni), c_reg_3x4!(0, $ni), c_reg_3x4!(1, $ni), c_reg_3x4!(2, $ni))
3195    };
3196}
3197#[macro_export]
3198macro_rules! acc_2x6 {
3199    ($ni:tt, $layout:tt, $q:tt) => {
3200        acc_p_avx!($layout, c_mem!($ni), $q, c_reg_2x6!(0, $ni), c_reg_2x6!(1, $ni))
3201    };
3202}
3203#[macro_export]
3204macro_rules! store_2x6 {
3205    ($ni:tt, $layout:tt) => {
3206        storep_avx!($layout, c_mem!($ni), c_reg_2x6!(0, $ni), c_reg_2x6!(1, $ni))
3207    };
3208}
3209#[macro_export]
3210macro_rules! acc_1x6 {
3211    ($ni:tt, $layout:tt, $q:tt) => {
3212        acc_p_avx!($layout, c_mem!($ni), $q, c_reg_1x6!(0, $ni))
3213    };
3214}
3215#[macro_export]
3216macro_rules! store_1x6 {
3217    ($ni:tt, $layout:tt) => {
3218        storep_avx!($layout, c_mem!($ni), c_reg_1x6!(0, $ni))
3219    };
3220}
3221
3222#[macro_export]
3223macro_rules! acc_3x8 {
3224    ($ni:tt, $layout:tt, $q:tt) => {
3225        acc_p_avx512!($layout, c_mem!($ni), $q, c_reg_3x8!(0, $ni), c_reg_3x8!(1, $ni), c_reg_3x8!(2, $ni))
3226    };
3227}
3228#[macro_export]
3229macro_rules! store_3x8 {
3230    ($ni:tt, $layout:tt) => {
3231        storep_avx512!($layout, c_mem!($ni), c_reg_3x8!(0, $ni), c_reg_3x8!(1, $ni), c_reg_3x8!(2, $ni))
3232    };
3233}
3234#[macro_export]
3235macro_rules! acc_2x12 {
3236    ($ni:tt, $layout:tt, $q:tt) => {
3237        acc_p_avx512!($layout, c_mem!($ni), $q, c_reg_2x12!(0, $ni), c_reg_2x12!(1, $ni))
3238    };
3239}
3240#[macro_export]
3241macro_rules! store_2x12 {
3242    ($ni:tt, $layout:tt) => {
3243        storep_avx512!($layout, c_mem!($ni), c_reg_2x12!(0, $ni), c_reg_2x12!(1, $ni))
3244    };
3245}
3246#[macro_export]
3247macro_rules! acc_1x12 {
3248    ($ni:tt, $layout:tt, $q:tt) => {
3249        acc_p_avx512!($layout, c_mem!($ni), $q, c_reg_1x12!(0, $ni))
3250    };
3251}
3252#[macro_export]
3253macro_rules! store_1x12 {
3254    ($ni:tt, $layout:tt) => {
3255        storep_avx512!($layout, c_mem!($ni), c_reg_1x12!(0, $ni))
3256    };
3257}
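// The acc_MxN!/store_MxN! wrappers above bind one tile shape's register map
// (c_reg_MxN) to the generic accumulate/store helpers that follow, with
// c_mem!(ni) supplying the address of column ni of C.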
3258#[macro_export]
3259macro_rules! acc_p_avx {
3260    ($layout:tt, $m0:expr, $q:tt, $r1:expr, $r2:expr, $r3:expr) => {
3261        concat!(
3262            beta_fmadd!(C, $m0, $r1, $q),
3263            beta_fmadd!(C, pire_base::mem!($m0, "0x20"), $r2, $q),
3264            beta_fmadd!($layout, pire_base::mem!($m0, "0x40"), $r3, $q),
3265        )
3266    };
3267    ($layout:tt, $m0:expr, $q:tt, $r1:expr, $r2:expr) => {
3268        concat!(beta_fmadd!(C, $m0, $r1, $q), beta_fmadd!($layout, pire_base::mem!($m0, "0x20"), $r2, $q),)
3269    };
3270    ($layout:tt, $m0:expr, $q:tt, $r1:expr) => {
3271        concat!(beta_fmadd!($layout, $m0, $r1, $q),)
3272    };
3273}
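// acc_p_avx! accumulates up to three 32-byte (ymm-sized) sub-vectors of a C
// column, stepping the memory operand by 0x20 per register. Only the last
// sub-vector receives $layout, so a partial/masked tile layout affects just
// the final, possibly incomplete register; the leading registers always take
// the full-width "C" path. The AVX-512 helpers below are identical in
// structure but step by 0x40 (zmm width).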
3274
3275#[macro_export]
3276macro_rules! loadp_avx {
3277    (3, $m0:expr) => {
3278        concat!(
3279            loadp_unit!($m0, 0),
3280            loadp_unit!(pire_base::mem!($m0, "0x20"), 1),
3281            loadp_unit!(pire_base::mem!($m0, "0x40"), 2),
3282        )
3283    };
3284    (2, $m0:expr) => {
3285        concat!(loadp_unit!($m0, 0), loadp_unit!(pire_base::mem!($m0, "0x20"), 1),)
3286    };
3287    (1, $m0:expr) => {
3288        concat!(loadp_unit!($m0, 0),)
3289    };
3290}
3291
3292#[macro_export]
3293macro_rules! storep_avx {
3294    ($layout:tt, $m0:expr, $r1:expr, $r2:expr, $r3:expr) => {
3295        concat!(
3296            storep_unit!(C, $r1, $m0),
3297            storep_unit!(C, $r2, pire_base::mem!($m0, "0x20")),
3298            storep_unit!($layout, $r3, pire_base::mem!($m0, "0x40")),
3299        )
3300    };
3301    ($layout:tt, $m0:expr, $r1:expr, $r2:expr) => {
3302        concat!(storep_unit!(C, $r1, $m0), storep_unit!($layout, $r2, pire_base::mem!($m0, "0x20")),)
3303    };
3304    ($layout:tt, $m0:expr, $r1:expr) => {
3305        concat!(storep_unit!($layout, $r1, $m0),)
3306    };
3307}
3308
3309#[macro_export]
3310macro_rules! acc_p_avx512 {
3311    ($layout:tt, $m0:expr, $q:tt, $r1:expr, $r2:expr, $r3:expr) => {
3312        concat!(
3313            beta_fmadd!(C, $m0, $r1, $q),
3314            beta_fmadd!(C, pire_base::mem!($m0, "0x40"), $r2, $q),
3315            beta_fmadd!($layout, pire_base::mem!($m0, "0x80"), $r3, $q),
3316        )
3317    };
3318    ($layout:tt, $m0:expr, $q:tt, $r1:expr, $r2:expr) => {
3319        concat!(beta_fmadd!(C, $m0, $r1, $q), beta_fmadd!($layout, pire_base::mem!($m0, "0x40"), $r2, $q),)
3320    };
3321    ($layout:tt, $m0:expr, $q:tt, $r1:expr) => {
3322        concat!(beta_fmadd!($layout, $m0, $r1, $q),)
3323    };
3324}
3325
3326#[macro_export]
3327macro_rules! loadp_avx512 {
3328    (3, $m0:expr) => {
3329        concat!(
3330            loadp_unit!($m0, 0),
3331            loadp_unit!(pire_base::mem!($m0, "0x40"), 1),
3332            loadp_unit!(pire_base::mem!($m0, "0x80"), 2),
3333        )
3334    };
3335    (2, $m0:expr) => {
3336        concat!(loadp_unit!($m0, 0), loadp_unit!(pire_base::mem!($m0, "0x40"), 1),)
3337    };
3338    (1, $m0:expr) => {
3339        concat!(loadp_unit!($m0, 0),)
3340    };
3341}
3342
3343#[macro_export]
3344macro_rules! storep_avx512 {
3345    ($layout:tt, $m0:expr, $r1:expr, $r2:expr, $r3:expr) => {
3346        concat!(
3347            storep_unit!(C, $r1, $m0),
3348            storep_unit!(C, $r2, pire_base::mem!($m0, "0x40")),
3349            storep_unit!($layout, $r3, pire_base::mem!($m0, "0x80")),
3350        )
3351    };
3352    ($layout:tt, $m0:expr, $r1:expr, $r2:expr) => {
3353        concat!(storep_unit!(C, $r1, $m0), storep_unit!($layout, $r2, pire_base::mem!($m0, "0x40")),)
3354    };
3355    ($layout:tt, $m0:expr, $r1:expr) => {
3356        concat!(storep_unit!($layout, $r1, $m0),)
3357    };
3358}
3359
3360#[macro_export]
3361macro_rules! cum_seq {
3362    ($step_macro:tt, $nr:tt, $layout:tt, $b:tt) => {
3363        seq!(n in 0..$nr {
3364            concat!(#($step_macro!(n, $layout, $b),)*)
3365        })
3366    };
3367    ($step_macro:tt, $nr:tt, $layout:tt) => {
3368        seq!(n in 0..$nr {
3369            concat!(#($step_macro!(n, $layout),)*)
3370        })
3371    };
3372}
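// cum_seq! concatenates one expansion of $step_macro per column index.
// For illustration (assuming `seq!` from the seq-macro crate is in scope):
//   pire_base::cum_seq!(acc_2x6, 2, C, 1)
// expands to
//   concat!(acc_2x6!(0, C, 1), acc_2x6!(1, C, 1),)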
3373
3374#[macro_export]
3375macro_rules! b_num_3x8 {
3376    (0) => {
3377        3
3378    };
3379    (1) => {
3380        4
3381    };
3382    (2) => {
3383        5
3384    };
3385    (3) => {
3386        6
3387    };
3388    (4) => {
3389        7
3390    };
3391    (5) => {
3392        3
3393    };
3394    (6) => {
3395        4
3396    };
3397    (7) => {
3398        5
3399    };
3400}
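// The b_num_* tables choose which low register holds the broadcast B value
// for column $ni. The indices deliberately wrap: b_num_3x8 above maps
// columns 0..=4 to registers 3..=7 and then reuses 3..=5 for columns 5..=7,
// rotating a small set of registers across the columns of one k-step.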
3401#[macro_export]
3402macro_rules! b_num_2x12 {
3403    (0) => {
3404        2
3405    };
3406    (1) => {
3407        3
3408    };
3409    (2) => {
3410        4
3411    };
3412    (3) => {
3413        5
3414    };
3415    (4) => {
3416        6
3417    };
3418    (5) => {
3419        7
3420    };
3421    (6) => {
3422        2
3423    };
3424    (7) => {
3425        3
3426    };
3427    (8) => {
3428        4
3429    };
3430    (9) => {
3431        5
3432    };
3433    (10) => {
3434        6
3435    };
3436    (11) => {
3437        7
3438    };
3439}
3440#[macro_export]
3441macro_rules! b_num_1x12 {
3442    (0) => {
3443        1
3444    };
3445    (1) => {
3446        2
3447    };
3448    (2) => {
3449        3
3450    };
3451    (3) => {
3452        4
3453    };
3454    (4) => {
3455        5
3456    };
3457    (5) => {
3458        6
3459    };
3460    (6) => {
3461        7
3462    };
3463    (7) => {
3464        8
3465    };
3466    (8) => {
3467        9
3468    };
3469    (9) => {
3470        10
3471    };
3472    (10) => {
3473        11
3474    };
3475    (11) => {
3476        12
3477    };
3478}
3479
3480#[macro_export]
3481macro_rules! b_num_2x4 {
3482    (0) => {
3483        2
3484    };
3485    (1) => {
3486        3
3487    };
3488    (2) => {
3489        2
3490    };
3491    (3) => {
3492        3
3493    };
3494}
3495#[macro_export]
3496macro_rules! b_num_1x4 {
3497    (0) => {
3498        1
3499    };
3500    (1) => {
3501        2
3502    };
3503    (2) => {
3504        3
3505    };
3506    (3) => {
3507        4
3508    };
3509}
3510#[macro_export]
3511macro_rules! b_num_2x6 {
3512    (0) => {
3513        2
3514    };
3515    (1) => {
3516        3
3517    };
3518    (2) => {
3519        2
3520    };
3521    (3) => {
3522        3
3523    };
3524    (4) => {
3525        2
3526    };
3527    (5) => {
3528        3
3529    };
3530}
3531
3532#[macro_export]
3533macro_rules! b_num_1x6 {
3534    (0) => {
3535        1
3536    };
3537    (1) => {
3538        2
3539    };
3540    (2) => {
3541        3
3542    };
3543    (3) => {
3544        4
3545    };
3546    (4) => {
3547        5
3548    };
3549    (5) => {
3550        6
3551    };
3552}
3553
3554#[macro_export]
3555macro_rules! fmadd_3x8 {
3556    ($ni:tt) => {
3557        concat!(
3558            vfmadd!(0, b_num_3x8!($ni), c_reg_3x8!(0, $ni)),
3559            vfmadd!(1, b_num_3x8!($ni), c_reg_3x8!(1, $ni)),
3560            vfmadd!(2, b_num_3x8!($ni), c_reg_3x8!(2, $ni)),
3561        )
3562    };
3563}
3564#[macro_export]
3565macro_rules! fmadd_2x12 {
3566    ($ni:tt) => {
3567        concat!(vfmadd!(0, b_num_2x12!($ni), c_reg_2x12!(0, $ni)), vfmadd!(1, b_num_2x12!($ni), c_reg_2x12!(1, $ni)),)
3568    };
3569}
3570#[macro_export]
3571macro_rules! fmadd_1x12 {
3572    ($ni:tt) => {
3573        concat!(vfmadd!(0, b_num_1x12!($ni), c_reg_1x12!(0, $ni)),)
3574    };
3575}
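// fmadd_MxN! emits one vfmadd per sub-vector of column $ni: A registers
// 0..M-1 are each multiplied by the broadcast B register picked by
// b_num_MxN! and accumulated into the C register picked by c_reg_MxN!.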
3576
3577#[cfg(target_arch = "x86_64")]
3578#[macro_export]
3579macro_rules! prefetch_c_avx512 {
3580    (3, $nr:tt, $c:tt, $ldc:tt) => {
3581        use std::arch::x86_64::_mm_prefetch;
3582        seq!(j in 0..$nr {
3583            let c_u8 = $c.add(j*$ldc) as *const i8;
3584            _mm_prefetch(c_u8, 3);
3585            _mm_prefetch(c_u8.add(64), 3);
3586            _mm_prefetch(c_u8.add(128), 3);
3587        });
3588    };
3589    (2, $nr:tt, $c:tt, $ldc:tt) => {
3590        use std::arch::x86_64::_mm_prefetch;
3591        seq!(j in 0..$nr {
3592            let c_u8 = $c.add(j*$ldc) as *const i8;
3593            _mm_prefetch(c_u8, 3);
3594            _mm_prefetch(c_u8.add(64), 3);
3595        });
3596    };
3597    (1, $nr:tt, $c:tt, $ldc:tt) => {
3598        use std::arch::x86_64::_mm_prefetch;
3599        seq!(j in 0..$nr {
3600            let c_u8 = $c.add(j*$ldc) as *const i8;
3601            _mm_prefetch(c_u8, 3);
3602        });
3603    };
3604}
3605
3606#[cfg(target_arch = "x86_64")]
3607#[macro_export]
3608macro_rules! prefetch_c_avx {
3609    (3, $nr:tt, $c:tt, $ldc:tt) => {
3610        use std::arch::x86_64::_mm_prefetch;
3611        seq!(j in 0..$nr {
3612            let c_u8 = $c.add(j*$ldc) as *const i8;
3613            _mm_prefetch(c_u8, 3);
3614            _mm_prefetch(c_u8.add(64), 3);
3615            _mm_prefetch(c_u8.add(92), 3);
3616        });
3617    };
3618    (2, $nr:tt, $c:tt, $ldc:tt) => {
3619        use std::arch::x86_64::_mm_prefetch;
3620        seq!(j in 0..$nr {
3621            let c_u8 = $c.add(j*$ldc) as *const i8;
3622            _mm_prefetch(c_u8, 3);
3623            _mm_prefetch(c_u8.add(60), 3);
3624        });
3625    };
3626    (1, $nr:tt, $c:tt, $ldc:tt) => {
3627        use std::arch::x86_64::_mm_prefetch;
3628        seq!(j in 0..$nr {
3629            let c_u8 = $c.add(j*$ldc) as *const i8;
3630            _mm_prefetch(c_u8, 3);
3631        });
3632    };
3633}
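// Prefetch hint 3 is _MM_HINT_T0 (fetch into all cache levels). The odd
// offsets (92 for three ymm registers, 60 for two) are not line-aligned;
// they appear intended to pull in the cache line holding the last bytes of
// a 96- or 64-byte column when c is not 64-byte aligned, though that is an
// observation from the offsets rather than a documented contract.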
3634
3635#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
3636#[macro_export]
3637macro_rules! prefetch_c_sse {
3638    (3, $nr:tt, $c:tt, $ldc:tt) => {
3639        #[cfg(target_arch="x86")]
3640        use std::arch::x86::_mm_prefetch;
3641        #[cfg(target_arch="x86_64")]
3642        use std::arch::x86_64::_mm_prefetch;
3643        seq!(j in 0..$nr {
3644            let c_u8 = $c.add(j*$ldc) as *const i8;
3645            _mm_prefetch(c_u8, 3);
3646            _mm_prefetch(c_u8.add(64), 3);
3647        });
3648    };
3649    (2, $nr:tt, $c:tt, $ldc:tt) => {
3650        #[cfg(target_arch="x86")]
3651        use std::arch::x86::_mm_prefetch;
3652        #[cfg(target_arch="x86_64")]
3653        use std::arch::x86_64::_mm_prefetch;
3654        seq!(j in 0..$nr {
3655            let c_u8 = $c.add(j*$ldc) as *const i8;
3656            _mm_prefetch(c_u8, 3);
3657        });
3658    };
3659    (1, $nr:tt, $c:tt, $ldc:tt) => {
3660        #[cfg(target_arch="x86")]
3661        use std::arch::x86::_mm_prefetch;
3662        #[cfg(target_arch="x86_64")]
3663        use std::arch::x86_64::_mm_prefetch;
3664        seq!(j in 0..$nr {
3665            let c_u8 = $c.add(j*$ldc) as *const i8;
3666            _mm_prefetch(c_u8, 3);
3667        });
3668    };
3669}
3670
3671#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
3672#[macro_export]
3673macro_rules! prefetch_0 {
3674    ($dist:tt, $reg:tt) => {
3675        concat!("prefetcht0 ", $dist, "(", $reg, ")\n",)
3676    };
3677}
3678
3679#[cfg(target_arch = "aarch64")]
3680#[macro_export]
3681macro_rules! prefetch_0 {
3682    ($dist:tt, $reg:tt) => {
3683        concat!("prfm pldl1keep, [", $reg, ", #", $dist, "] \n",)
3684    };
3685}
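// prefetch_0! builds the textual prefetch instruction for inline asm.
// For example, prefetch_0!(128, "{bx}") expands to
// "prefetcht0 128({bx})\n" on x86/x86_64 and to
// "prfm pldl1keep, [{bx}, #128] \n" on aarch64.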
3686
3687#[macro_export]
3688macro_rules! prefetch_b {
3689    (S) => {
3690        ""
3691    };
3692    (B) => {
3693        concat!("prefetcht0 192({bx}) \n",)
3694    };
3695}
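// prefetch_b! only prefetches when B is packed ("B" layout), reaching 192
// bytes ahead of the current B pointer; for the strided ("S") layout it
// emits nothing.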
3696
3697// *********************************************** def ukernel ************************************************
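// Every def_ukernel_* macro below stamps out a microkernel with the same
// skeleton, using numeric asm labels:
//   2: KITER        main k loop, unrolled 4x (dim_arr holds k / (4*$k_unit))
//   3: CONSIDKLEFT  reload the remainder count from dim_arr
//   4: KLEFT        remainder loop, one k-step per iteration
//   5: POSTACCUM    scale by alpha unless alpha == 1, then, depending on
//                   beta, store directly (beta == 0), accumulate C without
//                   the multiply (beta == 1), or load beta and fmadd beta*C
// When BUF is true, the C tile is first copied into the contiguous stack
// buffer c_buf (leading dimension MR) so edge tiles can use full-width
// loads/stores; the unary epilogue f runs on the buffer (or directly on C)
// and store_buf copies the result back.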
3698
3699#[cfg(target_arch = "x86_64")]
3700#[macro_export]
3701macro_rules! def_ukernel_avx {
3702    (
3703        $k_unit:tt,
3704        $step_macro:tt,
3705        $acc_macro:tt,
3706        $store_macro:tt,
3707        $mr:tt, $nr:tt,
3708        $n0:tt, $n1:tt,
3709        $b_layout:tt,
3710        $is_partial:tt,
3711        $func_name:ident
3712    ) => {
3713        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
3714            a: *const TA, b: *const TB, c: *mut TC,
3715            alpha: *const TS, beta: *const TS,
3716            k: usize,
3717            d_arr: [usize; 3], c_cs: usize,
3718            m: usize, n: usize,
3719            f: F,
3720        ) {
3721            use core::mem::size_of;
3722            const MR: usize = $mr * VS;
3723            mask_ptr!($is_partial, m, x, mask_ptr);
3724            let mut dim_arr = [d_arr[0]*size_of::<TA>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / ($k_unit*4), (k % ($k_unit*4)) / $k_unit];
3725            let mut c_k = c;
3726            let mut c_buf = [ZERO;MR*$nr];
3727            let alpha_st = if *alpha == ONE_SCALAR {
3728                0i32
3729            } else {
3730                1i32
3731            };
3732            let beta_st = if *beta == ZERO_SCALAR {
3733                0i32
3734            } else if *beta == ONE_SCALAR {
3735                1i32
3736            } else {
3737                2i32
3738            };
3739            if BUF {
3740                pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, n, MR);
3741                dim_arr[2] = MR*TC_SIZE;
3742                c_k = c_buf.as_mut_ptr();
3743            }
3744            let _ = 'blk: {
3745                seq!(ni in $n0..$n1 {
3746                    if pire_base::n_cond!($n0, ni, n) {
3747                        pire_base::prefetch_c_avx!($mr,ni,c,c_cs);
3748                        asm!(
3749                            vzero_kernel!(),
3750
3751                            init_ab_avx!($b_layout),
3752
3753                            "test {x0}, {x0}", "je 3f", // CONSIDKLEFT
3754
3755                            "2:", // KITER
3756                            pire_base::prefetch_b!($b_layout),
3757                            $step_macro!(ni, $b_layout, 0),
3758                            $step_macro!(ni, $b_layout, 1),
3759                            $step_macro!(ni, $b_layout, 2),
3760                            $step_macro!(ni, $b_layout, 3),
3761
3762                            inc_a_k_unroll!($mr, 4),
3763                            inc_b_k_unroll!($b_layout, ni, 4),
3764
3765                            "dec {x0}", "jne 2b", // KITER
3766
3767                            "3:", // CONSIDKLEFT
3768                            "mov 32({dim_arrx}), {x0}",
3769                            "test {x0},{x0}", "je 5f", // POSTACCUM
3770
3771                            "4:", // KLEFT
3772                            $step_macro!(ni, $b_layout, 0),
3773                            inc_a_k_unroll!($mr, 1),
3774                            inc_b_k_unroll!($b_layout, ni, 1),
3775
3776                            "dec {x0}", "jne 4b", // KLEFT
3777
3778                            "5:", // POSTACCUM
3779                            c_load!(),
3780
3781                            "cmpw $0, ({alpha_st})",
3782                            "je 9f",
3783                            alpha_scale!(),
3784                            "9:",
3785
3786                            load_mask!($is_partial),
3787
3788                            "cmpw $0, ({beta_st})",
3789                            "je 6f",
3790
3791                            "cmpw $1, ({beta_st})",
3792                            "je 15f",
3793
3794                            load_beta!(),
3795                            pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
3796                            "jmp 6f",
3797
3798                            "15:",
3799                            pire_base::cum_seq!($acc_macro,ni,$is_partial,1),
3800
3801                            "6:",
3802                            pire_base::cum_seq!($store_macro,ni,$is_partial),
3803
3804                            ax = inout(reg) a => _,
3805                            bx = inout(reg) b => _,
3806                            cx = inout(reg) c_k => _,
3807                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
3808                            alphax = inout(reg) alpha => _,
3809                            betax = inout(reg) beta => _,
3810                            beta_st = in(reg) &beta_st,
3811                            alpha_st = in(reg) &alpha_st,
3812                            maskx = inout(reg) mask_ptr => _,
3813                            x0 = out(reg) _,
3814                            x1 = out(reg) _,
3815                            x2 = out(reg) _,
3816                            x3 = out(reg) _,
3817                            out("ymm0") _, out("ymm1") _, out("ymm2") _, out("ymm3") _,
3818                            out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
3819                            out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
3820                            out("ymm12") _, out("ymm13") _, out("ymm14") _, out("ymm15") _,
3821                            options(att_syntax)
3822                        );
3823                        break 'blk;
3824                    }
3825                });
3826            };
3827            if BUF {
3828                for j in 0..n {
3829                    f.call(c_k.add(j*MR), MR);
3830                }
3831                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
3832            } else {
3833                for j in 0..n {
3834                    f.call(c_k.add(j*c_cs), m);
3835                }
3836            }
3837        }
3838    };
3839}
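// Hypothetical instantiation (the step macro name is illustrative, not
// necessarily defined in any one crate):
//   def_ukernel_avx!(1, step_3x4, acc_3x4, store_3x4, 3, 4, 0, 4, B, C, ukernel_bb);
// would define `ukernel_bb` computing a (3*VS) x n tile of
// C = alpha*A*B + beta*C for packed B ("B" layout) with full-width ("C")
// stores, selecting the ni branch that matches the runtime n via n_cond!.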
3840
3841#[cfg(target_arch = "x86_64")]
#[macro_export]
3842macro_rules! def_ukernel_avx512 {
3843    (
3844        $k_unit:tt,
3845        $step_macro:tt,
3846        $acc_macro:tt,
3847        $store_macro:tt,
3848        $mr:tt, $nr:tt,
3849        $n0:tt, $n1:tt,
3850        $b_layout:tt,
3851        $is_partial:tt,
3852        $func_name:ident
3853    ) => {
3854        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
3855            a: *const TA, b: *const TB, c: *mut TC,
3856            alpha: *const TS, beta: *const TS,
3857            k: usize,
3858            d_arr: [usize; 3], c_cs: usize,
3859            m: usize, n: usize,
3860            f: F,
3861        ) {
3862            use core::mem::size_of;
3863            const MR: usize = $mr * VS;
3864            mask_ptr!($is_partial, m, x, mask_ptr);
3865            let mut dim_arr = [d_arr[0]*size_of::<TA>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / ($k_unit*4), (k % ($k_unit*4)) / $k_unit];
3866            let mut c_k = c;
3867            let mut c_buf = [ZERO;MR*$nr];
3868            let alpha_st = if *alpha == ONE_SCALAR {
3869                0i32
3870            } else {
3871                1i32
3872            };
3873            let beta_st = if *beta == ZERO_SCALAR {
3874                0i32
3875            } else if *beta == ONE_SCALAR {
3876                1i32
3877            } else {
3878                2i32
3879            };
3880            if BUF {
3881                pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, n, MR);
3882                dim_arr[2] = MR*TC_SIZE;
3883                c_k = c_buf.as_mut_ptr();
3884            }
3885            let _ = 'blk: {
3886                seq!(ni in $n0..$n1 {
3887                    if pire_base::n_cond!($n0, ni, n) {
3888                        pire_base::prefetch_c_avx512!($mr,ni,c,c_cs);
3889                        asm!(
3890                            vzero_kernel!(),
3891
3892                            init_ab!($b_layout),
3893                            "test {x0}, {x0}", "je 3f", // CONSIDKLEFT
3894
3895                            "2:", // KITER
3896                            $step_macro!(ni, $b_layout),
3897                            $step_macro!(ni, $b_layout),
3898                            $step_macro!(ni, $b_layout),
3899                            $step_macro!(ni, $b_layout),
3900                            "dec {x0}", "jne 2b", // KITER
3901
3902                            "3:", // CONSIDKLEFT
3903                            "mov 32({dim_arrx}), {x0}",
3904                            "test {x0},{x0}", "je 5f", // POSTACCUM
3905
3906                            "4:", // KLEFT
3907                            $step_macro!(ni, $b_layout),
3908
3909                            "dec {x0}", "jne 4b", // KLEFT
3910
3911                            "5:", // POSTACCUM
3912                            c_load!(),
3913
3914                            "cmpw $0, ({alpha_st})",
3915                            "je 9f",
3916                            alpha_scale!(),
3917                            "9:",
3918
3919                            load_mask!($is_partial),
3920
3921                            "cmpw $0, ({beta_st})",
3922                            "je 6f",
3923
3924                            "cmpw $1, ({beta_st})",
3925                            "je 15f",
3926
3927                            load_beta!(),
3928                            pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
3929                            "jmp 6f",
3930
3931                            "15:",
3932                            pire_base::cum_seq!($acc_macro,ni,$is_partial,1),
3933
3934                            "6:",
3935                            pire_base::cum_seq!($store_macro,ni,$is_partial),
3936
3937                            ax = inout(reg) a => _,
3938                            bx = inout(reg) b => _,
3939                            cx = inout(reg) c_k => _,
3940                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
3941                            alphax = inout(reg) alpha => _,
3942                            betax = inout(reg) beta => _,
3943                            beta_st = in(reg) &beta_st,
3944                            alpha_st = in(reg) &alpha_st,
3945                            maskx = inout(reg) mask_ptr => _,
3946                            x0 = out(reg) _,
3947                            x1 = out(reg) _,
3948                            x2 = out(reg) _,
3949                            x3 = out(reg) _,
3950                            x4 = out(reg) _,
3951                            x5 = out(reg) _,
3952                            out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _,
3953                            out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
3954                            out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
3955                            out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _,
3956                            out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
3957                            out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _,
3958                            out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _,
3959                            out("zmm28") _, out("zmm29") _, out("zmm30") _, out("zmm31") _,
3960                            out("k1") _,
3961                            options(att_syntax)
3962                        );
3963                        break 'blk;
3964                    }
3965                });
3966            };
3967            if BUF {
3968                for j in 0..n {
3969                    f.call(c_k.add(j*MR), MR);
3970                }
3971                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
3972            } else {
3973                for j in 0..n {
3974                    f.call(c_k.add(j*c_cs), m);
3975                }
3976            }
3977        }
3978    };
3979}
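// Relative to def_ukernel_avx!, the AVX-512 variant clobbers zmm0..zmm31
// plus the k1 mask register, and its $step_macro takes no unroll index:
// there are no inc_a_k_unroll!/inc_b_k_unroll! calls in its loops, so the
// pointer advances are expected to live inside the step macro itself.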
3980
3981#[cfg(target_arch = "x86_64")]
3982#[macro_export]
3983macro_rules! def_ukernel_sse {
3984    (
3985        $k_unit:tt,
3986        $step_macro:tt,
3987        $acc_macro:tt,
3988        $store_macro:tt,
3989        $mr:tt, $nr:tt,
3990        $n0:tt, $n1:tt,
3991        $b_layout:tt,
3992        $is_partial:tt,
3993        $func_name:ident
3994    ) => {
3995        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
3996            a: *const TA, b: *const TB, c: *mut TC,
3997            alpha: *const TS, beta: *const TS,
3998            k: usize,
3999            d_arr: [usize; 3], c_cs: usize,
4000            m: usize, n: usize,
4001            f: F,
4002        ) {
4003            use core::mem::size_of;
4004            const MR: usize = $mr * VS;
4005            let mut dim_arr = [d_arr[0]*size_of::<TA>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / ($k_unit*4), (k % ($k_unit*4)) / $k_unit];
4006            let mut c_k = c;
4007            let mut c_buf = [ZERO;MR*$nr];
4008            let alpha_st = if *alpha == ONE_SCALAR {
4009                0i32
4010            } else {
4011                1i32
4012            };
4013            let beta_st = if *beta == ZERO_SCALAR {
4014                0i32
4015            } else if *beta == ONE_SCALAR {
4016                1i32
4017            } else {
4018                2i32
4019            };
4020            if BUF {
4021                pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, n, MR);
4022                dim_arr[2] = MR*TC_SIZE;
4023                c_k = c_buf.as_mut_ptr();
4024            }
4025            let _ = 'blk: {
4026                seq!(ni in $n0..$n1 {
4027                    if pire_base::n_cond!($n0, ni, n) {
4028                        pire_base::prefetch_c_sse!($mr,ni,c,c_cs);
4029                        asm!(
4030                            vzero_kernel!(),
4031
4032                            init_ab_avx!($b_layout),
4033                            "test {x0}, {x0}", "je 3f", // CONSIDKLEFT
4034
4035                            "2:", // KITER
4036                            pire_base::prefetch_b!($b_layout),
4037                            $step_macro!(ni, $b_layout, 0),
4038                            $step_macro!(ni, $b_layout, 1),
4039                            $step_macro!(ni, $b_layout, 2),
4040                            $step_macro!(ni, $b_layout, 3),
4041
4042                            inc_a_k_unroll!($mr, 4),
4043                            inc_b_k_unroll!($b_layout, ni, 4),
4044                            "dec {x0}", "jne 2b", // KITER
4045
4046                            "3:", // CONSIDKLEFT
4047                            "mov 32({dim_arrx}), {x0}",
4048                            "test {x0},{x0}", "je 5f", // POSTACCUM
4049
4050                            "4:", // KLEFT
4051                            $step_macro!(ni, $b_layout, 0),
4052                            inc_a_k_unroll!($mr, 1),
4053                            inc_b_k_unroll!($b_layout, ni, 1),
4054
4055                            "dec {x0}", "jne 4b", // KLEFT
4056
4057                            "5:", // POSTACCUM
4058                            c_load!(),
4059
4060                            "cmpw $0, ({alpha_st})",
4061                            "je 9f",
4062                            alpha_scale!(),
4063                            "9:",
4064
4065                            "cmpw $0, ({beta_st})",
4066                            "je 6f",
4067
4068                            "cmpw $1, ({beta_st})",
4069                            "je 15f",
4070
4071                            load_beta!(),
4072                            pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
4073                            "jmp 6f",
4074
4075                            "15:",
4076                            pire_base::cum_seq!($acc_macro,ni,$is_partial,1),
4077
4078                            "6:",
4079                            pire_base::cum_seq!($store_macro,ni,$is_partial),
4080
4081                            ax = inout(reg) a => _,
4082                            bx = inout(reg) b => _,
4083                            cx = inout(reg) c_k => _,
4084                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4085                            alphax = inout(reg) alpha => _,
4086                            betax = inout(reg) beta => _,
4087                            beta_st = in(reg) &beta_st,
4088                            alpha_st = in(reg) &alpha_st,
4089                            x0 = out(reg) _,
4090                            x1 = out(reg) _,
4091                            x2 = out(reg) _,
4092                            x3 = out(reg) _,
4093                            out("xmm0") _, out("xmm1") _, out("xmm2") _, out("xmm3") _,
4094                            out("xmm4") _, out("xmm5") _, out("xmm6") _, out("xmm7") _,
4095                            out("xmm8") _, out("xmm9") _, out("xmm10") _, out("xmm11") _,
4096                            out("xmm12") _, out("xmm13") _, out("xmm14") _, out("xmm15") _,
4097                            options(att_syntax)
4098                        );
4099                        break 'blk;
4100                    }
4101                });
4102            };
4103            if BUF {
4104                for j in 0..n {
4105                    f.call(c_k.add(j*MR), MR);
4106                }
4107                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
4108            } else {
4109                for j in 0..n {
4110                    f.call(c_k.add(j*c_cs), m);
4111                }
4112            }
4113        }
4114    };
4115}
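// The x86_64 SSE variant mirrors def_ukernel_avx! with xmm0..xmm15 and no
// mask_ptr!/load_mask! step, since SSE has no masked loads/stores; partial
// tiles are still threaded through $is_partial to the acc/store macros.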
4116
4117#[cfg(target_arch = "x86_64")]
4118#[macro_export]
4119macro_rules! def_ukernel_avx_2 {
4120    ($k_unit:tt, $step:ident, $acc:ident, $store:ident, $mr:tt, $nr:tt, $kl_pf:tt, $pf1_step:tt) => {
4121        pub(crate) unsafe fn ukernel_bbc<F: UnaryFnC, const BUF: bool>(
4122            a: *const TA, b: *const TB, c: *mut TC,
4123            alpha: *const TS, beta: *const TS,
4124            k: usize,
4125            d_arr: [usize; 3], c_cs: usize,
4126            a_pft1_offset: usize, _n: usize,
4127            f: F,
4128        ) {
4129            const MR: usize = $mr * VS;
4130            let k_l0 = k % $kl_pf;
4131            let k_l = if k_l0 == 0 {$kl_pf/$k_unit} else {k_l0/$k_unit};
4132            let k_i = (k - k_l*$k_unit) / (4*$k_unit);
4133            let mut c_k = c;
4134
4135            let mut dim_arr = [c_cs*TC_SIZE, k_i, k_l, a_pft1_offset];
4136            let mut c_buf = [ZERO; MR*$nr];
4137            let alpha_st = if *alpha == ONE_SCALAR {
4138                0i32
4139            } else {
4140                1i32
4141            };
4142            let beta_st = if *beta == ZERO_SCALAR {
4143                0i32
4144            } else if *beta == ONE_SCALAR {
4145                1i32
4146            } else {
4147                2i32
4148            };
4149            if BUF {
4150                pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, MR, $nr, MR);
4151                dim_arr[0] = MR*TC_SIZE;
4152                c_k = c_buf.as_mut_ptr();
4153            }
4154            asm!(
4155                vzero_kernel!(),
4156                init_ab_2!(B),
4157                "test {x0},{x0}",
4158                "je 3f",
4159                "mov {cx}, {x2}",
4160                "mov {ax}, {x5}",
4161                "mov 24({dim_arrx}),{x1}",
4162                "add {x1}, {x5}",
4163                "mov ({dim_arrx}),{x1}",
4164                "2:",
4165                prefetch_0!(256, "{bx}"),
4166                $step!($nr, B, 0),
4167
4168                "movq $64*4, {x4}",
4169                // divisibility by 4
4170                "testq $3, {x0}",
4171                "cmovz {x1},{x4}",
4172
4173                $step!($nr, B, 1),
4174
4175                "prefetcht1 ({x2})",
4176
4177                "subq $64*3, {x2}",
4178                "addq {x4}, {x2}",
4179
4180                $step!($nr, B, 2),
4181
4182                "prefetcht1 ({x5})",
4183                concat!("addq $", $pf1_step, ", {x5}"),
4184
4185                "testq $63, {x0}",
4186                "cmovz {cx},{x2}",
4187
4188                $step!($nr, B, 3),
4189
4190                inc_a_k_unroll!($mr, 4),
4191                inc_b_k_unroll!(B, $nr, 4),
4192
4193                "dec {x0}",
4194                "jne 2b",
4195                "3:",
4196                "mov 16({dim_arrx}),{x0}",
4197                "test {x0},{x0}", "je 5f", // POSTACCUM
4198
4199                "mov {cx}, {x2}",
4200                "mov ({dim_arrx}),{x1}",
4201                "4:",
4202                "prefetcht0 ({x2})",
4203                "prefetcht0 64({x2})",
4204                "prefetcht0 92({x2})",
4205                $step!($nr, B, 0),
4206                inc_a_k_unroll!($mr, 1),
4207                inc_b_k_unroll!(B, $nr, 1),
4208                "add {x1}, {x2}", "dec {x0}", "jne 4b",
4209
4210                "5:",
4211                c_load_2!(),
4212
4213                "cmpw $0, ({alpha_st})",
4214                "je 9f",
4215                alpha_scale!(),
4216                "9:",
4217                "cmpw $0, ({beta_st})",
4218                "je 6f",
4219
4220                "cmpw $1, ({beta_st})",
4221                "je 15f",
4222
4223                load_beta!(),
4224                pire_base::cum_seq!($acc,$nr,C,2),
4225                "jmp 6f",
4226
4227                "15:",
4228                pire_base::cum_seq!($acc,$nr,C,1),
4229
4230                "6:",
4231                pire_base::cum_seq!($store,$nr,C),
4232
4233                ax = inout(reg) a => _,
4234                bx = inout(reg) b => _,
4235                cx = inout(reg) c_k => _,
4236                dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4237                alphax = inout(reg) alpha => _,
4238                betax = inout(reg) beta => _,
4239                beta_st = in(reg) &beta_st,
4240                alpha_st = in(reg) &alpha_st,
4241                x0 = out(reg) _,
4242                x1 = out(reg) _,
4243                x2 = out(reg) _,
4244                x3 = out(reg) _,
4245                x4 = out(reg) _,
4246                x5 = out(reg) _,
4247                out("ymm0") _, out("ymm1") _, out("ymm2") _, out("ymm3") _,
4248                out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
4249                out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
4250                out("ymm12") _, out("ymm13") _, out("ymm14") _, out("ymm15") _,
4251                options(att_syntax)
4252            );
4253            if BUF {
4254                for j in 0..$nr {
4255                    f.call(c_k.add(j*MR), MR);
4256                }
4257                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, MR, $nr, MR);
4258            } else {
4259                for j in 0..$nr {
4260                    f.call(c_k.add(j*c_cs), MR);
4261                }
4262            }
4263        }
4264    };
4265}
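// This and the following AVX-512 "_2" variant are fixed-shape (full
// MR x $nr) kernels for packed B that interleave prefetching with the k
// loop: {x2} walks prefetcht1 over the C tile and {x5} over A starting at
// a_pft1_offset, with the cmovz divisibility tests on {x0} controlling how
// fast each pointer advances. k is split into k_i iterations of the
// 4x-unrolled KITER loop plus k_l trailing KLEFT iterations, which also
// prefetcht0 the C tile just before the post-accumulation phase.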
4266
4267#[cfg(target_arch = "x86_64")]
4268#[macro_export]
4269macro_rules! def_ukernel_avx512_2 {
4270    ($k_unit:tt, $step:ident, $acc:ident, $store:ident, $mr:tt, $nr:tt, $kl_pf:tt, $pf1_step:tt) => {
4271        pub(crate) unsafe fn ukernel_bbc<F: UnaryFnC, const BUF: bool>(
4272            a: *const TA, b: *const TB, c: *mut TC,
4273            alpha: *const TS, beta: *const TS,
4274            k: usize,
4275            d_arr: [usize; 3], c_cs: usize,
4276            a_pft1_offset: usize, _n: usize,
4277            f: F,
4278        ) {
4279            const MR: usize = $mr * VS;
4280            let k_l0 = k % $kl_pf;
4281            let k_l = if k_l0 == 0 {$kl_pf/$k_unit} else {k_l0/$k_unit};
4282            let k_i = (k - k_l*$k_unit) / (4*$k_unit);
4283            let mut c_k = c;
4284
4285            let mut dim_arr = [c_cs*TC_SIZE, k_i, k_l, a_pft1_offset];
4286            let mut c_buf = [ZERO; MR * $nr];
4287            let alpha_st = if *alpha == ONE_SCALAR {
4288                0i32
4289            } else {
4290                1i32
4291            };
4292            let beta_st = if *beta == ZERO_SCALAR {
4293                0i32
4294            } else if *beta == ONE_SCALAR {
4295                1i32
4296            } else {
4297                2i32
4298            };
4299            if BUF {
4300                pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, MR, $nr, MR);
4301                dim_arr[0] = MR*TC_SIZE;
4302                c_k = c_buf.as_mut_ptr();
4303            }
4304            asm!(
4305                vzero_kernel!(),
4306                init_ab_2!(B),
4307                "test {x0},{x0}", "je 3f",
4308
4309                "mov {cx}, {x2}",
4310                "mov {ax}, {x5}",
4311                "mov 24({dim_arrx}),{x1}",
4312                "add {x1}, {x5}",
4313                "mov ({dim_arrx}),{x1}",
4314
4315                "2:", // KITER
4316                $step!($nr, B),
4317
4318                "movq $64*4, {x4}",
4319                // divisibility by 4
4320                "testq $3, {x0}",
4321                "cmovz {x1},{x4}",
4322
4323                $step!($nr, B),
4324
4325                "prefetcht1 ({x2})",
4326
4327                "subq $64*3, {x2}",
4328                "addq {x4}, {x2}",
4329
4330                $step!($nr, B),
4331
4332                "prefetcht1 ({x5})",
4333                concat!("addq $", $pf1_step, ", {x5}"),
4334
4335                "testq $63, {x0}",
4336                "cmovz {cx},{x2}",
4337
4338                $step!($nr, B),
4339
4340                "dec {x0}", "jne 2b", // KITER
4341
4342                "3:",
4343                "mov 16({dim_arrx}),{x0}",
4344                "test {x0},{x0}", "je 5f", // POSTACCUM
4345
4346
4347                "mov {cx}, {x2}",
4348                "mov ({dim_arrx}),{x1}",
4349
4350                "4:", // KLEFT
4351                "prefetcht0 ({x2})",
4352                "prefetcht0 64({x2})",
4353                "prefetcht0 128({x2})",
4354                $step!($nr, B),
4355                "add {x1}, {x2}", "dec {x0}", "jne 4b", // KLEFT
4356
4357                "5:", // POSTACCUM
4358                c_load_2!(),
4359
4360                "cmpw $0, ({alpha_st})",
4361                "je 9f",
4362                alpha_scale!(),
4363                "9:",
4364
4365                "cmpw $0, ({beta_st})",
4366                "je 6f",
4367
4368                "cmpw $1, ({beta_st})",
4369                "je 15f",
4370
4371                load_beta!(),
4372                pire_base::cum_seq!($acc,$nr,C,2),
4373                "jmp 6f",
4374
4375                "15:",
4376                pire_base::cum_seq!($acc,$nr,C,1),
4377
4378                "6:",
4379                pire_base::cum_seq!($store,$nr,C),
4380
4381                ax = inout(reg) a => _,
4382                bx = inout(reg) b => _,
4383                cx = inout(reg) c_k => _,
4384                dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4385                alphax = inout(reg) alpha => _,
4386                betax = inout(reg) beta => _,
4387                beta_st = in(reg) &beta_st,
4388                alpha_st = in(reg) &alpha_st,
4389                x0 = out(reg) _,
4390                x1 = out(reg) _,
4391                x2 = out(reg) _,
4392                x3 = out(reg) _,
4393                x4 = out(reg) _,
4394                x5 = out(reg) _,
4395                out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _,
4396                out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
4397                out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
4398                out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _,
4399                out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
4400                out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _,
4401                out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _,
4402                out("zmm28") _, out("zmm29") _, out("zmm30") _, out("zmm31") _,
4403                out("k1") _,
4404                options(att_syntax)
4405            );
4406            if BUF {
4407                for j in 0..$nr {
4408                    f.call(c_k.add(j*MR), MR);
4409                }
4410                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, MR, $nr, MR);
4411            } else {
4412                for j in 0..$nr {
4413                    f.call(c_k.add(j*c_cs), MR);
4414                }
4415            }
4416        }
4417    };
4418}
4419
4420#[cfg(target_arch = "x86")]
4421#[macro_export]
4422macro_rules! def_ukernel_sse {
4423    (
4424        $k_unit:tt,
4425        $step_macro:tt,
4426        $acc_macro:tt,
4427        $store_macro:tt,
4428        $mr:tt, $nr:tt,
4429        $n0:tt, $n1:tt,
4430        $b_layout:tt,
4431        $is_partial:tt,
4432        $func_name:ident
4433    ) => {
4434        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
4435            a: *const TA, b: *const TB, c: *mut TC,
4436            alpha: *const TS, beta: *const TS,
4437            k: usize,
4438            d_arr: [usize; 3], c_cs: usize,
4439            m: usize, n: usize,
4440            f: F,
4441        ) {
4442            let alpha_st = if *alpha == ONE_SCALAR {
4443                0i32
4444            } else {
4445                1i32
4446            };
4447            let beta_st = if *beta == ZERO_SCALAR {
4448                0i32
4449            } else if *beta == ONE_SCALAR {
4450                1i32
4451            } else {
4452                2i32
4453            };
4454            const MR: usize = $mr * VS;
4455            let mut dim_arr = [d_arr[0]*core::mem::size_of::<TA>(), d_arr[1]*core::mem::size_of::<TB>(), c_cs*TC_SIZE, k / ($k_unit*4), (k % ($k_unit*4)) / $k_unit, beta_st as usize, alpha_st as usize];
4456            let ptr_arr = [alpha, beta];
4457            let mut cf = c;
4458            let mut c_buf = [ZERO;MR*$nr];
4459            if BUF {
4460                pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, n, MR);
4461                dim_arr[2] = MR*TC_SIZE;
4462                cf = c_buf.as_mut_ptr();
4463            }
4464            let _ = 'blk: {
4465                seq!(ni in $n0..$n1 {
4466                    if pire_base::n_cond!($n0, ni, n) {
4467                        pire_base::prefetch_c_sse!($mr,ni,c,c_cs);
4468                        asm!(
4469                            vzero_kernel!(),
4470
4471                            init_ab!($b_layout),
4472                            "test {x0}, {x0}", "je 3f", // CONSIDKLEFT
4473
4474                            "2:", // KITER
4475                            pire_base::prefetch_b!($b_layout),
4476                            $step_macro!(ni, $b_layout, 0),
4477                            $step_macro!(ni, $b_layout, 1),
4478                            $step_macro!(ni, $b_layout, 2),
4479                            $step_macro!(ni, $b_layout, 3),
4480
4481                            inc_a_k_unroll!($mr, 4),
4482                            inc_b_k_unroll!($b_layout, ni, 4),
4483
4484                            "dec {x0}", "jne 2b", // KITER
4485
4486                            "3:", // CONSIDKLEFT
4487                            "mov 16({dim_arrx}), {x0}",
4488                            "test {x0},{x0}", "je 5f", // POSTACCUM
4489
4490                            "4:", // KLEFT
4491                            $step_macro!(ni, $b_layout, 0),
4492                            inc_a_k_unroll!($mr, 1),
4493                            inc_b_k_unroll!($b_layout, ni, 1),
4494
4495                            "dec {x0}", "jne 4b", // KLEFT
4496
4497                            "5:", // POSTACCUM
4498                            c_load!(),
4499
4500                            "cmpw $0, 24({dim_arrx})",
4501                            "je 9f",
4502                            alpha_scale!(),
4503                            "9:",
4504
4505                            "cmpw $0, 20({dim_arrx})",
4506                            "je 6f",
4507
4508                            "cmpw $1, 20({dim_arrx})",
4509                            "je 15f",
4510
4511                            load_beta!(),
4512                            pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
4513                            "jmp 6f",
4514
4515                            "15:",
4516                            pire_base::cum_seq!($acc_macro,ni,$is_partial,1),
4517
4518                            "6:",
4519                            pire_base::cum_seq!($store_macro,ni,$is_partial),
4520
4521                            ax = inout(reg) a => _,
4522                            bx = inout(reg) b => _,
4523                            cx = inout(reg) cf => _,
4524                            ptr_arrx = inout(reg) ptr_arr.as_ptr() => _,
4525                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4526                            x0 = out(reg) _,
4527                            out("xmm0") _, out("xmm1") _, out("xmm2") _, out("xmm3") _,
4528                            out("xmm4") _, out("xmm5") _, out("xmm6") _, out("xmm7") _,
4529                            options(att_syntax)
4530                        );
4531                        break 'blk;
4532                    }
4533                });
4534            };
4535            if BUF {
4536                for j in 0..n {
4537                    f.call(cf.add(j*MR), MR);
4538                }
4539                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
4540            } else {
4541                for j in 0..n {
4542                    f.call(cf.add(j*c_cs), m);
4543                }
4544            }
4545        }
4546    };
4547}
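// On 32-bit x86 there are too few general-purpose registers to pin alpha,
// beta and the status flags individually, so alpha_st/beta_st ride in the
// tail of dim_arr (word compares at offsets 24 and 20, with 4-byte usize)
// and alpha/beta sit behind ptr_arrx; only xmm0..xmm7 exist in this mode,
// which also limits the tile shapes this variant can cover.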
4548
4549#[macro_export]
4550macro_rules! def_ukernel_neon {
4551    (
4552        $step_macro:tt,
4553        $acc_macro:tt,
4554        $store_macro:tt,
4555        $mr:tt, $nr:tt,
4556        $n0:tt, $n1:tt,
4557        $b_layout:tt,
4558        $is_partial:tt,
4559        $func_name:ident
4560    ) => {
4561        #[target_feature(enable="neon")]
4562        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
4563            a: *const TA, b: *const TB, c: *mut TC,
4564            alpha: *const TA, beta: *const TB,
4565            k: usize,
4566            d_arr: [usize; 3], c_cs: usize,
4567            m: usize, n: usize,
4568            f: F,
4569        ) {
4570            const MR: usize = $mr * VS;
4571            use core::mem::size_of;
4572            let mut dim_arr = [d_arr[0]*size_of::<TA>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 4, k % 4];
4573            let mut cf = c;
4574            let mut c_buf = [ZERO;MR*$nr];
4575            let alpha_st = if *alpha == ONE_SCALAR {
4576                0i32
4577            } else {
4578                1i32
4579            };
4580            let beta_st = if *beta == ZERO_SCALAR {
4581                0i32
4582            } else {
4583                1i32
4584            };
4585            let _ = 'blk: {
4586                seq!(ni in $n0..$n1 {
4587                    if pire_base::n_cond!($n0, ni, n) {
4588                        if BUF {
4589                            pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, MR);
4590                            dim_arr[2] = MR*TC_SIZE;
4591                            cf = c_buf.as_mut_ptr();
4592                        }
4593                        asm!(
4594                            prefetch_c!(),
4595                            vzero_kernel!(),
4596
4597                            init_ab!($b_layout),
4598
4599                            // 3 -> CONSIDKLEFT
4600                            "cmp {x0}, #0", "BEQ 3f",
4601
4602                            // 2 -> KITER
4603                            "2:",
4604                            prefetch_0!(128, "{bx}"),
4605                            $step_macro!(ni, $b_layout),
4606                            $step_macro!(ni, $b_layout),
4607                            $step_macro!(ni, $b_layout),
4608                            $step_macro!(ni, $b_layout),
4609
4610                            "sub {x0}, {x0}, #1",
4611                            // 2 -> KITER
4612                            "cmp {x0}, 0",
4613                            "BNE 2b",
4614
4615                            // 3 -> CONSIDKLEFT
4616                            "3:",
4617                            "ldr {x0}, [{dim_arrx}, #32]",
4618                            "cmp {x0}, #0",
4619
4620                            // 5 -> POSTACCUM
4621                            "BEQ 5f",
4622                            // 4 -> KLEFT
4623                            "4:",
4624                            $step_macro!(ni, $b_layout),
4625
4626                            "sub {x0}, {x0}, #1",
4627
4628                            // 4 -> KLEFT
4629                            "cmp {x0}, 0",
4630                            "BNE 4b",
4631
4632                            // 5 -> POSTACCUM
4633                            "5:",
4634                            c_load!(),
4635                            "cmp {alpha_st:w}, #0",
4636                            "BEQ 13f",
4637                            alpha_scale!(),
4638                            "13:",
4639
4640                            "cmp {beta_st:w}, #0",
4641                            "BEQ 6f",
4642
4643                            load_beta!(),
4644
4645                            pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
4646
4647                            // 6 -> BETAZERO
4648                            "6:",
4649                            pire_base::cum_seq!($store_macro,ni,$is_partial),
4650
4651                            ax = inout(reg) a => _,
4652                            bx = inout(reg) b => _,
4653                            cx = inout(reg) cf => _,
4654                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4655                            alphax = inout(reg) alpha => _,
4656                            betax = inout(reg) beta => _,
4657                            alpha_st = in(reg) alpha_st,
4658                            beta_st = in(reg) beta_st,
4659                            x0 = out(reg) _,
4660                            x1 = out(reg) _,
4661                            x2 = out(reg) _,
4662                            x3 = out(reg) _,
4663                            x4 = out(reg) _,
4664                            x5 = out(reg) _,
4665                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
4666                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
4667                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
4668                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
4669                        );
4670                        break 'blk;
4671                    }
4672                });
4673            };
4674            if BUF {
4675                for j in 0..n {
4676                    f.call(cf.add(j*MR), MR);
4677                }
4678                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
4679            } else {
4680                for j in 0..n {
4681                    f.call(cf.add(j*c_cs), m);
4682                }
4683            }
4684        }
4685    };
4686}
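// The NEON kernel keeps the same numeric-label skeleton but uses explicit
// cmp/branch pairs, passes alpha_st/beta_st directly in registers, and
// clobbers v0..v31. Note that beta only has a zero/non-zero status here:
// there is no separate beta == 1 fast path, unlike the x86 variants.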
4687
4688#[macro_export]
4689macro_rules! def_ukernel_neon_alt {
4690    (
4691        $step_macro:tt,
4692        $acc_macro:tt,
4693        $store_macro:tt,
4694        $mr:tt, $nr:tt,
4695        $n0:tt, $n1:tt,
4696        $b_layout:tt,
4697        $is_partial:tt,
4698        $func_name:ident
4699    ) => {
4700        #[target_feature(enable="neon")]
4701        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
4702            a: *const TA, b: *const TB, c: *mut TC,
4703            alpha: *const TA, beta: *const TB,
4704            k: usize,
4705            d_arr: [usize; 3], c_cs: usize,
4706            m: usize, n: usize,
4707            f: F,
4708        ) {
4709            alt_arr!(alt);
4710            const MR: usize = $mr * VS;
4711            use core::mem::size_of;
4712            let mut dim_arr = [d_arr[0]*size_of::<TA>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 4, k % 4];
4713            let mut cf = c;
4714            let mut c_buf = [ZERO;MR*$nr];
4715            let alpha_st = if *alpha == ONE_SCALAR {
4716                0i32
4717            } else {
4718                1i32
4719            };
4720            let beta_st = if *beta == ZERO_SCALAR {
4721                0i32
4722            } else {
4723                1i32
4724            };
4725            let _ = 'blk: {
4726                seq!(ni in $n0..$n1 {
4727                    if pire_base::n_cond!($n0, ni, n) {
4728                        if BUF {
4729                            pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, MR);
4730                            dim_arr[2] = MR*TC_SIZE;
4731                            cf = c_buf.as_mut_ptr();
4732                        }
4733                        asm!(
4734                            prefetch_c!(),
4735                            vzero_kernel!(),
4736
4737                            init_ab!($b_layout),
4738
4739                            // 3 -> CONSIDKLEFT
4740                            "cmp {x0}, #0", "BEQ 3f",
4741
4742                            // 2 -> KITER
4743                            "2:",
4744                            prefetch_0!(128, "{bx}"),
4745                            $step_macro!(ni, $b_layout),
4746                            $step_macro!(ni, $b_layout),
4747                            $step_macro!(ni, $b_layout),
4748                            $step_macro!(ni, $b_layout),
4749
4750                            "sub {x0}, {x0}, #1",
4751                            // 2 -> KITER
4752                            "cmp {x0}, 0",
4753                            "BNE 2b",
4754
4755                            // 3 -> CONSIDKLEFT
4756                            "3:",
4757                            "ldr {x0}, [{dim_arrx}, #32]",
4758                            "cmp {x0}, #0",
4759
4760                            // 5 -> POSTACCUM
4761                            "BEQ 5f",
4762                            // 4 -> KLEFT
4763                            "4:",
4764                            $step_macro!(ni, $b_layout),
4765
4766                            "sub {x0}, {x0}, #1",
4767
4768                            // 4 -> KLEFT
4769                            "cmp {x0}, 0",
4770                            "BNE 4b",
4771
4772                            // 5 -> POSTACCUM
4773                            "5:",
4774                            c_load!(),
4775                            "cmp {alpha_st:w}, #0",
4776                            "BEQ 13f",
4777                            alpha_scale!(),
4778                            "13:",
4779                            "cmp {beta_st:w}, #0",
4780                            "BEQ 6f",
4781
4782                            load_beta!(),
4783
4784                            pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
4785
4786                            // 6 -> BETAZERO
4787                            "6:",
4788                            pire_base::cum_seq!($store_macro,ni,$is_partial),
4789
4790                            ax = inout(reg) a => _,
4791                            bx = inout(reg) b => _,
4792                            cx = inout(reg) cf => _,
4793                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4794                            alphax = inout(reg) alpha => _,
4795                            betax = inout(reg) beta => _,
4796                            alpha_st = in(reg) alpha_st,
4797                            beta_st = in(reg) beta_st,
4798                            altx = inout(reg) alt.as_ptr() => _,
4799                            x0 = out(reg) _,
4800                            x1 = out(reg) _,
4801                            x2 = out(reg) _,
4802                            x3 = out(reg) _,
4803                            x4 = out(reg) _,
4804                            x5 = out(reg) _,
4805                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
4806                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
4807                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
4808                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
4809                        );
4810                        break 'blk;
4811                    }
4812                });
4813            };
4814            if BUF {
4815                for j in 0..n {
4816                    f.call(cf.add(j*MR), MR);
4817                }
4818                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
4819            } else {
4820                for j in 0..n {
4821                    f.call(cf.add(j*c_cs), m);
4822                }
4823            }
4824        }
4825    };
4826}
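// Each def_ukernel_* macro below stamps out one monomorphic microkernel per
// (step, acc, store, mr, nr, B-layout, partial/full) combination. A hedged
// sketch of an instantiation site (step_2x4 / acc_2x4 / store_2x4 and
// ukernel_2x4_bb are hypothetical names, not identifiers from this crate):
//
//   def_ukernel_neon_fp16!(step_2x4, acc_2x4, store_2x4, 2, 4, 1, 5, B, C, ukernel_2x4_bb);
//
// which would define ukernel_2x4_bb::<F, BUF> computing a (2*VS) x 4 tile of
// C = alpha*A*B + beta*C.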
4827
4828#[macro_export]
4829macro_rules! def_ukernel_neon_fp16 {
4830    (
4831        $step_macro:tt,
4832        $acc_macro:tt,
4833        $store_macro:tt,
4834        $mr:tt, $nr:tt,
4835        $n0:tt, $n1:tt,
4836        $b_layout:tt,
4837        $is_partial:tt,
4838        $func_name:ident
4839    ) => {
4840        #[target_feature(enable="neon,fp16")]
4841        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
4842            a: *const TA, b: *const TB, c: *mut TC,
4843            alpha: *const TA, beta: *const TB,
4844            k: usize,
4845            d_arr: [usize; 3], c_cs: usize,
4846            m: usize, n: usize,
4847            f: F,
4848        ) {
4849            const MR: usize = $mr * VS;
4850            use core::mem::size_of;
4851            let mut dim_arr = [d_arr[0]*size_of::<TB>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 4, k % 4];
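            // dim_arr converts the element strides in d_arr to byte strides and appends the
            // k/4 (KITER) and k%4 (KLEFT) trip counts; the asm reads them through {dim_arrx}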
4852            let mut cf = c;
4853            let mut c_buf = [ZERO;MR*$nr];
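            // stack-local MR x $nr column-major staging tile for C, used only when BUF is set;
            // written back with store_buf after the unary epilogue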
4854            let alpha_st = if *alpha == ONE_SCALAR {
4855                0i32
4856            } else {
4857                1i32
4858            };
4859            let beta_st = if *beta == ZERO_SCALAR {
4860                0i32
4861            } else {
4862                1i32
4863            };
4864            let _ = 'blk: {
4865                seq!(ni in $n0..$n1 {
4866                    if pire_base::n_cond!($n0, ni, n) {
4867                        if BUF {
4868                            pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, MR);
4869                            dim_arr[2] = MR*TC_SIZE;
4870                            cf = c_buf.as_mut_ptr();
4871                        }
4872                        asm!(
4873                            prefetch_c!(),
4874                            vzero_kernel!(),
4875
4876                            init_ab!($b_layout),
4877
4878                            // 3 -> CONSIDKLEFT
4879                            "cmp {x0}, #0", "BEQ 3f",
4880
4881                            // 2 -> KITER
4882                            "2:",
4883                            prefetch_0!(128, "{bx}"),
4884                            $step_macro!(ni, $b_layout),
4885                            $step_macro!(ni, $b_layout),
4886                            $step_macro!(ni, $b_layout),
4887                            $step_macro!(ni, $b_layout),
4888
4889                            "sub {x0}, {x0}, #1",
4890                            // 2 -> KITER
4891                            "cmp {x0}, 0",
4892                            "BNE 2b",
4893
4894                            // 3 -> CONSIDKLEFT
4895                            "3:",
4896                            "ldr {x0}, [{dim_arrx}, #32]",
4897                            "cmp {x0}, #0",
4898
4899                            // 5 -> POSTACCUM
4900                            "BEQ 5f",
4901                            // 4 -> KLEFT
4902                            "4:",
4903                            $step_macro!(ni, $b_layout),
4904
4905                            "sub {x0}, {x0}, #1",
4906
4907                            // 4 -> KLEFT
4908                            "cmp {x0}, 0",
4909                            "BNE 4b",
4910
4911                            // 5 -> POSTACCUM
4912                            "5:",
4913                            c_load!(),
4914                            "cmp {alpha_st:w}, #0",
4915                            "BEQ 13f",
4916                            alpha_scale!(),
4917                            "13:",
4918
4919                            "cmp {beta_st:w}, #0",
4920                            "BEQ 6f",
4921
4922                            load_beta!(),
4923
4924                            pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
4925
4926                            // 6 -> BETAZERO
4927                            "6:",
4928                            pire_base::cum_seq!($store_macro,ni,$is_partial),
4929
4930                            ax = inout(reg) a => _,
4931                            bx = inout(reg) b => _,
4932                            cx = inout(reg) cf => _,
4933                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4934                            alphax = inout(reg) alpha => _,
4935                            betax = inout(reg) beta => _,
4936                            alpha_st = in(reg) alpha_st,
4937                            beta_st = in(reg) beta_st,
4938                            x0 = out(reg) _,
4939                            x1 = out(reg) _,
4940                            x2 = out(reg) _,
4941                            x3 = out(reg) _,
4942                            x4 = out(reg) _,
4943                            x5 = out(reg) _,
4944                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
4945                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
4946                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
4947                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
4948                        );
4949                        break 'blk;
4950                    }
4951                });
4952            };
4953            if BUF {
4954                for j in 0..n {
4955                    f.call(cf.add(j*MR), MR);
4956                }
4957                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
4958            } else {
4959                for j in 0..n {
4960                    f.call(cf.add(j*c_cs), m);
4961                }
4962            }
4963        }
4964    };
4965}
4966
4967#[macro_export]
4968macro_rules! def_ukernel_neon_i8mm {
4969    (
4970        $step_macro:tt,
4971        $acc_macro:tt,
4972        $store_macro:tt,
4973        $mr:tt, $nr:tt,
4974        $n0:tt, $n1:tt,
4975        $b_layout:tt,
4976        $is_partial:tt,
4977        $func_name:ident
4978    ) => {
4979        #[target_feature(enable="neon,i8mm")]
4980        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
4981            a: *const TA, b: *const TB, c: *mut TC,
4982            alpha: *const TS, beta: *const TS,
4983            k: usize,
4984            d_arr: [usize; 3], c_cs: usize,
4985            m: usize, n: usize,
4986            f: F,
4987        ) {
4988            const MR: usize = $mr * VS;
4989            use core::mem::size_of;
4990            let mut dim_arr = [d_arr[0]*size_of::<TB>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 32, (k % 32) / 8];
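            // the int8 mmla steps consume k in groups of 8, so the 4x-unrolled KITER loop
            // advances 32 per trip: k / 32 main iterations plus (k % 32) / 8 KLEFT steps
            // (any k % 8 remainder is presumably handled by zero-padding during packing)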
4991            let mut cf = c;
4992            let mut c_buf = [ZERO;MR*$nr];
4993            let alpha_st = if *alpha == ONE_SCALAR {
4994                0i32
4995            } else {
4996                1i32
4997            };
4998            let beta_st = if *beta == ZERO_SCALAR {
4999                0i32
5000            } else if *beta == ONE_SCALAR {
5001                1i32
5002            } else {
5003                2i32
5004            };
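            // three-way beta dispatch: 0 skips accumulation entirely, 1 adds C without
            // scaling (the BETAONE path at label 9), 2 loads beta and scales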
5005            let _ = 'blk: {
5006                seq!(ni in $n0..$n1 {
5007                    if pire_base::n_cond!($n0, ni, n) {
5008                        if BUF {
5009                            pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, MR);
5010                            dim_arr[2] = MR*TC_SIZE;
5011                            cf = c_buf.as_mut_ptr();
5012                        }
5013                        asm!(
5014                            prefetch_c!(),
5015                            vzero_kernel!(),
5016
5017                            init_ab!($b_layout),
5018
5019                            // 3 -> CONSIDKLEFT
5020                            "cmp {x0}, #0", "BEQ 3f",
5021
5022                            // 2 -> KITER
5023                            "2:",
5024                            prefetch_0!(128, "{bx}"),
5025                            $step_macro!(ni, $b_layout),
5026                            $step_macro!(ni, $b_layout),
5027                            $step_macro!(ni, $b_layout),
5028                            $step_macro!(ni, $b_layout),
5029
5030                            "sub {x0}, {x0}, #1",
5031                            // 2 -> KITER
5032                            "cmp {x0}, 0",
5033                            "BNE 2b",
5034
5035                            // 3 -> CONSIDKLEFT
5036                            "3:",
5037                            "ldr {x0}, [{dim_arrx}, #32]",
5038                            "cmp {x0}, #0",
5039
5040                            // 5 -> POSTACCUM
5041                            "BEQ 5f",
5042                            // 4 -> KLEFT
5043                            "4:",
5044                            $step_macro!(ni, $b_layout),
5045
5046                            "sub {x0}, {x0}, #1",
5047
5048                            // 4 -> KLEFT
5049                            "cmp {x0}, 0",
5050                            "BNE 4b",
5051
5052                            // 5 -> POSTACCUM
5053                            "5:",
5054                            c_load!(),
5055                            "cmp {alpha_st:w}, #0",
5056                            "BEQ 13f",
5057                            alpha_scale!(),
5058                            "13:",
5059
5060                            "cmp {beta_st:w}, #0",
5061                            "BEQ 6f",
5062
5063                            "cmp {beta_st:w}, #1",
5064                            "BEQ 9f",
5065
5066                            load_beta!(),
5067                            pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
5068                            "B 6f",
5069
5070                            "9:",
5071                            // 9 -> BETAONE
5072                            pire_base::cum_seq!($acc_macro,ni,$is_partial,1),
5073
5074                            // 6 -> BETAZERO
5075                            "6:",
5076                            pire_base::cum_seq!($store_macro,ni,$is_partial),
5077
5078                            ax = inout(reg) a => _,
5079                            bx = inout(reg) b => _,
5080                            cx = inout(reg) cf => _,
5081                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
5082                            alphax = inout(reg) alpha => _,
5083                            betax = inout(reg) beta => _,
5084                            alpha_st = in(reg) alpha_st,
5085                            beta_st = in(reg) beta_st,
5086                            x0 = out(reg) _,
5087                            x1 = out(reg) _,
5088                            x2 = out(reg) _,
5089                            x3 = out(reg) _,
5090                            x4 = out(reg) _,
5091                            x5 = out(reg) _,
5092                            x6 = out(reg) _,
5093                            x7 = out(reg) _,
5094                            x8 = out(reg) _,
5095                            x9 = out(reg) _,
5096                            x10 = out(reg) _,
5097                            x11 = out(reg) _,
5098                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
5099                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
5100                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
5101                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
5102                        );
5103                        break 'blk;
5104                    }
5105                });
5106            };
5107            if BUF {
5108                for j in 0..n {
5109                    f.call(cf.add(j*MR), MR);
5110                }
5111                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
5112            } else {
5113                for j in 0..n {
5114                    f.call(cf.add(j*c_cs), m);
5115                }
5116            }
5117        }
5118    };
5119}
5120
5121#[macro_export]
5122macro_rules! def_ukernel_sve {
5123    (
5124        $step_macro:tt,
5125        $acc_macro:tt,
5126        $store_macro:tt,
5127        $mr:tt, $nr:tt,
5128        $n0:tt, $n1:tt,
5129        $b_layout:tt,
5130        $is_partial:tt,
5131        $func_name:ident
5132    ) => {
5133        #[target_feature(enable="neon,sve")]
5134        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
5135            a: *const TA, b: *const TB, c: *mut TC,
5136            alpha: *const TA, beta: *const TB,
5137            k: usize,
5138            d_arr: [usize; 3], c_cs: usize,
5139            m: usize, n: usize,
5140            f: F,
5141        ) {
5142            use core::mem::size_of;
5143            let vs = sve_vs();
5144            let m_left = if m % vs == 0 {vs} else {m%vs};
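            // vs is the runtime SVE vector length in elements; m_left is the number of
            // active lanes in the last, possibly partial, vector along the m dimension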
5145            let inc_a = vs * $mr * size_of::<TA>();
5146            let mr = $mr * vs;
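            // inc_a is the byte stride of one k-step of the packed A panel; e.g. with
            // 256-bit SVE and 16-bit elements (vs = 16), $mr = 2 would give mr = 32 rows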
5147            let mut dim_arr = [d_arr[0]*size_of::<TB>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 4, k % 4];
5148            let mut cf = c;
5149            let mut c_buf = [ZERO;(256/size_of::<TC>())*$mr*$nr];
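            // sized for the architectural maximum SVE vector length of 2048 bits
            // (256 bytes), so the stack tile fits whatever vs the hardware reports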
5150            let alpha_st = if *alpha == ONE_SCALAR {
5151                0i32
5152            } else {
5153                1i32
5154            };
5155            let beta_st = if *beta == ZERO_SCALAR {
5156                0i32
5157            } else {
5158                1i32
5159            };
5160            let _ = 'blk: {
5161                seq!(ni in $n0..$n1 {
5162                    // using a dynamic n here hits an LLVM SVE bug on windows, so n is expanded statically via seq!
5163                    // see: https://github.com/llvm/llvm-project/issues/80009
5164                    if BUF {
5165                        pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, mr);
5166                        dim_arr[2] = mr*TC_SIZE;
5167                        cf = c_buf.as_mut_ptr();
5168                    }
5169                    if pire_base::n_cond!($n0, ni, n) {
5170                        asm!(
5171                            "ptrue p0.h",
5172                            "mov {m_s}, #0",
5173                            "/* {m_e} */", "\n",
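                            // m_s/m_e bracket the active row range; the step/acc/store macros
                            // presumably build their lane predicates from them for partial tiles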
5174                            prefetch_c!(),
5175                            vzero_kernel!(),
5176
5177                            init_ab!($b_layout),
5178
5179                            // 3 -> CONSIDKLEFT
5180                            "cmp {x0}, #0", "BEQ 3f",
5181
5182                            // 2 -> KITER
5183                            "2:",
5184                            prefetch_0!(128, "{bx}"),
5185                            $step_macro!(ni, $b_layout),
5186                            $step_macro!(ni, $b_layout),
5187                            $step_macro!(ni, $b_layout),
5188                            $step_macro!(ni, $b_layout),
5189
5190                            "sub {x0}, {x0}, #1",
5191                            // 2 -> KITER
5192                            "cmp {x0}, 0",
5193                            "BNE 2b",
5194
5195                            // 3 -> CONSIDKLEFT
5196                            "3:",
5197                            "ldr {x0}, [{dim_arrx}, #32]",
5198                            "cmp {x0}, #0",
5199
5200                            // 5 -> POSTACCUM
5201                            "BEQ 5f",
5202                            // 4 -> KLEFT
5203                            "4:",
5204                            $step_macro!(ni, $b_layout),
5205
5206                            "sub {x0}, {x0}, #1",
5207
5208                            // 4 -> KLEFT
5209                            "cmp {x0}, 0",
5210                            "BNE 4b",
5211
5212                            // 5 -> POSTACCUM
5213                            "5:",
5214                            c_load!(),
5215                            "cmp {alpha_st:w}, #0",
5216                            "BEQ 13f",
5217                            alpha_scale!(),
5218                            "13:",
5219
5220                            "cmp {beta_st:w}, #0",
5221                            "BEQ 6f",
5222
5223                            load_beta!(),
5224
5225                            pire_base::cum_seq!($acc_macro,ni,$is_partial),
5226
5227                            // 6 -> BETAZERO
5228                            "6:",
5229                            pire_base::cum_seq!($store_macro,ni,$is_partial),
5230
5231                            ax = inout(reg) a => _,
5232                            bx = inout(reg) b => _,
5233                            cx = inout(reg) cf => _,
5234                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
5235                            alphax = inout(reg) alpha => _,
5236                            betax = inout(reg) beta => _,
5237                            incax = in(reg) inc_a as u64,
5238                            alpha_st = in(reg) alpha_st,
5239                            beta_st = in(reg) beta_st,
5240                            m_s = out(reg) _,
5241                            m_e = inout(reg) m_left as u64 => _,
5242                            x0 = out(reg) _,
5243                            x1 = out(reg) _,
5244                            x2 = out(reg) _,
5245                            x3 = out(reg) _,
5246                            x4 = out(reg) _,
5247                            x5 = out(reg) _,
5248                            x6 = out(reg) _,
5249                            x7 = out(reg) _,
5250                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
5251                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
5252                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
5253                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
5254                            out("p0") _, out("p1") _, out("p2") _, out("p3") _,
5255                        );
5256                        break 'blk;
5257                    }
5258                });
5259            };
5260            if BUF {
5261                for j in 0..n {
5262                    f.call(cf.add(j*mr), mr);
5263                }
5264                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, mr);
5265            } else {
5266                for j in 0..n {
5267                    f.call(cf.add(j*c_cs), m);
5268                }
5269            }
5270        }
5271    };
5272}
5273
5274#[macro_export]
5275macro_rules! def_ukernel_sve_i8mm {
5276    (
5277        $step_macro:tt,
5278        $acc_macro:tt,
5279        $store_macro:tt,
5280        $mr:tt, $nr:tt,
5281        $n0:tt, $n1:tt,
5282        $b_layout:tt,
5283        $is_partial:tt,
5284        // $feature_enable:tt,
5285        $func_name:ident
5286    ) => {
5287        #[target_feature(enable="neon,sve,i8mm")]
5288        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
5289            a: *const TA, b: *const TB, c: *mut TC,
5290            alpha: *const TS, beta: *const TS,
5291            k: usize,
5292            d_arr: [usize; 3], c_cs: usize,
5293            m: usize, n: usize,
5294            f: F,
5295        ) {
5296            use core::mem::size_of;
5297            let vs = sve_vs();
5298            let m_left = if m % vs == 0 {vs} else {m%vs};
5299            let inc_a = $mr * vs * size_of::<TA>() * 8;
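            // the extra factor of 8 matches the 8-deep int8 k-packing of the A panel,
            // so {incax} advances one full k-group per step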
5300            let mr = $mr * vs;
5301            let mut dim_arr = [d_arr[0]*size_of::<TB>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 32, (k % 32) / 8];
5302            let mut cf = c;
5303            let mut c_buf = [ZERO;(256/size_of::<TC>())*$mr*$nr];
5304            let alpha_st = if *alpha == ONE_SCALAR {
5305                0i32
5306            } else {
5307                1i32
5308            };
5309            let beta_st = if *beta == ZERO_SCALAR {
5310                0i32
5311            } else if *beta == ONE_SCALAR {
5312                1i32
5313            } else {
5314                2i32
5315            };
5316            let _ = 'blk: {
5317                seq!(ni in $n0..$n1 {
5318                    // using a dynamic n here hits an LLVM SVE bug on windows, so n is expanded statically via seq!
5319                    // see: https://github.com/llvm/llvm-project/issues/80009
5320                    if BUF {
5321                        pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, mr);
5322                        dim_arr[2] = mr*TC_SIZE;
5323                        cf = c_buf.as_mut_ptr();
5324                    }
5325                    if pire_base::n_cond!($n0, ni, n) {
5326                        asm!(
5327                            "ptrue p0.h",
5328                            "/* {m_s} */", "\n",
5329                            "/* {m_e} */", "\n",
5330                            prefetch_c!(),
5331                            vzero_kernel!(),
5332
5333                            init_ab!($b_layout),
5334
5335                            // 3 -> CONSIDKLEFT
5336                            "cmp {x0}, #0", "BEQ 3f",
5337
5338                            // 2 -> KITER
5339                            "2:",
5340                            prefetch_0!(128, "{bx}"),
5341                            $step_macro!(ni, $b_layout),
5342                            $step_macro!(ni, $b_layout),
5343                            $step_macro!(ni, $b_layout),
5344                            $step_macro!(ni, $b_layout),
5345
5346                            "sub {x0}, {x0}, #1",
5347                            // 2 -> KITER
5348                            "cmp {x0}, 0",
5349                            "BNE 2b",
5350
5351                            // 3 -> CONSIDKLEFT
5352                            "3:",
5353                            "ldr {x0}, [{dim_arrx}, #32]",
5354                            "cmp {x0}, #0",
5355
5356                            // 5 -> POSTACCUM
5357                            "BEQ 5f",
5358                            // 4 -> KLEFT
5359                            "4:",
5360                            $step_macro!(ni, $b_layout),
5361
5362                            "sub {x0}, {x0}, #1",
5363
5364                            // 4 -> KLEFT
5365                            "cmp {x0}, 0",
5366                            "BNE 4b",
5367
5368                            // 5 -> POSTACCUM
5369                            "5:",
5370                            c_load!(),
5371                            "cmp {alpha_st:w}, #0",
5372                            "BEQ 13f",
5373                            alpha_scale!(),
5374                            "13:",
5375
5376                            "cmp {beta_st:w}, #0",
5377                            "BEQ 6f",
5378
5379                            "cmp {beta_st:w}, #1",
5380                            "BEQ 9f",
5381
5382                            load_beta!(),
5383                            pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
5384                            "B 6f",
5385
5386                            "9:",
5387                            // 9 -> BETAONE
5388                            pire_base::cum_seq!($acc_macro,ni,$is_partial,1),
5389
5390                            // 6 -> BETAZERO
5391                            "6:",
5392                            pire_base::cum_seq!($store_macro,ni,$is_partial),
5393                            ax = inout(reg) a => _,
5394                            bx = inout(reg) b => _,
5395                            cx = inout(reg) cf => _,
5396                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
5397                            alphax = inout(reg) alpha => _,
5398                            betax = inout(reg) beta => _,
5399                            incax = in(reg) inc_a as u64,
5400                            alpha_st = in(reg) alpha_st,
5401                            beta_st = in(reg) beta_st,
5402                            m_s = in(reg) 0u64,
5403                            m_e = in(reg) m_left as u64,
5404                            x0 = out(reg) _,
5405                            x1 = out(reg) _,
5406                            x2 = out(reg) _,
5407                            x3 = out(reg) _,
5408                            x4 = out(reg) _,
5409                            x5 = out(reg) _,
5410                            x6 = out(reg) _,
5411                            x7 = out(reg) _,
5412                            x8 = out(reg) _,
5413                            x9 = out(reg) _,
5414                            x10 = out(reg) _,
5415                            x11 = out(reg) _,
5416                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
5417                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
5418                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
5419                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
5420                            out("p0") _, out("p1") _, out("p2") _, out("p3") _,
5421                        );
5422                        break 'blk;
5423                    }
5424                });
5425            };
5426            if BUF {
5427                for j in 0..n {
5428                    f.call(cf.add(j*mr), mr);
5429                }
5430                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, mr);
5431            } else {
5432                for j in 0..n {
5433                    f.call(cf.add(j*c_cs), m);
5434                }
5435            }
5436        }
5437    };
5438}
5439
5440// mod test {
5441//     // prints the split_c_range partition of m across ic_par threads
5442//     #[test]
5443//     fn test_split_c_range() {
5444//         let m = 143;
5445//         let mc = 4800;
5446//         let mr = 24;
5447//         let ic_par = 4;
5448//         for ic_id in 0..ic_par {
5449//             let (mc_start, mc_end, mc_left) = super::split_c_range(m, mc, mr, ic_id, ic_par);
5450//             println!("mc_start: {}, mc_end: {}, mc_left: {}", mc_start, mc_end, mc_left);
5451//         }
5452//         assert!(false); // failing assert makes `cargo test` print the output above
5453//     }
5454// }