1use core::mem::size_of;
6use once_cell::sync::Lazy;
7use std::sync::{Barrier, Mutex, MutexGuard, RwLock, RwLockReadGuard};
8
9pub mod range_rwlock;
10
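// Element-wise epilogue hook applied to blocks of the output C (a raw pointer
// plus an element count). `IdentityFn` is the no-op default; plain
// `unsafe fn(*mut T, usize)` pointers can also be used as custom epilogues.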
11#[derive(Copy, Clone)]
12pub struct IdentityFn;
13
14pub trait UnaryFn<T>: Copy + std::marker::Sync {
15 unsafe fn call(self, c: *mut T, m: usize);
16}
17
18impl<T> UnaryFn<T> for IdentityFn {
19 #[inline(always)]
20 unsafe fn call(self, _c: *mut T, _m: usize) {}
21}
22
23impl<T> UnaryFn<T> for unsafe fn(*mut T, m: usize) {
24 #[inline(always)]
25 unsafe fn call(self, c: *mut T, m: usize) {
26 self(c, m);
27 }
28}
29
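// Copy an `m x n` tile between a strided matrix (row stride `c_rs`, column
// stride `c_cs`) and a contiguous column-major scratch buffer with leading
// dimension `mr`: `load_buf` gathers the tile into `c_buf`, `store_buf`
// scatters it back.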
30#[inline(always)]
31pub unsafe fn load_buf<T: Copy>(c: *const T, c_rs: usize, c_cs: usize, c_buf: &mut [T], m: usize, n: usize, mr: usize) {
32 for j in 0..n {
33 for i in 0..m {
34 c_buf[i + j * mr] = *c.add(i * c_rs + j * c_cs);
35 }
36 }
37}
38
39#[inline(always)]
40pub unsafe fn store_buf<T: Copy>(c: *mut T, c_rs: usize, c_cs: usize, c_buf: &[T], m: usize, n: usize, mr: usize) {
41 for j in 0..n {
42 for i in 0..m {
43 *c.add(i * c_rs + j * c_cs) = c_buf[i + j * mr];
44 }
45 }
46}
47
48pub fn matrix_size(m: usize, n: usize) -> usize {
49 n * m
50}
51
52use range_rwlock::{RangeLock, RangeLockReadGuard, RangeLockWriteGuard};
53
54#[cfg(target_arch = "x86_64")]
55#[derive(Copy, Clone)]
56pub struct CpuFeatures {
57 pub sse: bool,
58 pub sse2: bool,
59 pub sse3: bool,
60 pub ssse3: bool,
61 pub avx: bool,
62 pub avx2: bool,
63 pub avx512f: bool,
64 pub avx512f16: bool,
65 pub avx512bw: bool,
67 pub avx512_vnni: bool,
68 pub fma: bool,
69 pub fma4: bool,
70 pub f16c: bool,
71}
72
73#[cfg(target_arch = "x86")]
74#[derive(Copy, Clone)]
75pub struct CpuFeatures {
76 pub sse: bool,
77 pub sse2: bool,
78 pub sse3: bool,
79 pub ssse3: bool,
80}
81
82#[cfg(target_arch = "aarch64")]
83#[derive(Copy, Clone)]
84pub struct CpuFeatures {
85 pub sve: bool,
86 pub neon: bool,
87 pub fp16: bool,
88 pub f32mm: bool,
89 pub fcma: bool,
90 pub i8mm: bool,
91}
92
93#[cfg(not(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64")))]
94#[derive(Copy, Clone)]
95pub struct CpuFeatures {
96 pub dummy: bool,
97}
98
99const CACHELINE_PAD: usize = 1024;
101
102pub struct HWConfig {
103 pub cpu_ft: CpuFeatures,
104 pub hw_model: HWModel,
105 is_l1_shared: bool,
106 is_l2_shared: bool,
107 is_l3_shared: bool,
108}
109
110impl HWConfig {
111 pub fn get_cache_info(&self) -> (bool, bool, bool) {
112 (self.is_l1_shared, self.is_l2_shared, self.is_l3_shared)
113 }
114 pub fn hw_model(&self) -> HWModel {
115 self.hw_model
116 }
117
118 pub fn cpu_ft(&self) -> CpuFeatures {
119 self.cpu_ft
120 }
121}
122
123#[derive(Copy, Clone)]
124pub enum HWModel {
125 Reference,
126 Haswell,
127 Skylake,
128}
129
130const SKYLAKE: [u8; 13] = [78, 85, 94, 126, 140, 141, 167, 151, 154, 183, 186, 143, 207];
131
132const HASWELL: [u8; 10] = [69, 70, 63, 42, 58, 165, 79, 86, 61, 71];
133
134impl HWModel {
135 pub fn from_hw(family_id: u8, model_id: u8) -> Self {
136 if family_id == 6 {
137 if SKYLAKE.contains(&model_id) {
138 return HWModel::Skylake;
139 }
140 if HASWELL.contains(&model_id) {
141 return HWModel::Haswell;
142 }
143 }
144
145 return HWModel::Reference;
147 }
148 pub fn get_cache_info(&self) -> (bool, bool, bool) {
149 match self {
150 HWModel::Reference => (false, false, true),
151 HWModel::Haswell => (false, false, true),
152 HWModel::Skylake => (false, false, true),
153 }
154 }
155}
156
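// Detect CPU features once (cpuid on x86/x86_64, runtime feature detection on
// aarch64) and derive the cache-sharing layout from the hardware model.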
157#[inline]
162fn detect_hw_config() -> HWConfig {
163 #[cfg(target_arch = "x86_64")]
164 {
165 let cpuid = raw_cpuid::CpuId::new();
166 let feature_info = cpuid.get_feature_info().unwrap();
167 let extended_feature_info = cpuid.get_extended_feature_info().unwrap();
168 let sse = feature_info.has_sse();
169 let sse2 = feature_info.has_sse2();
170 let sse3 = feature_info.has_sse3();
171 let ssse3 = feature_info.has_ssse3();
172 let avx = feature_info.has_avx();
173 let fma = feature_info.has_fma();
174 let avx2 = extended_feature_info.has_avx2();
175 let avx512f16 = extended_feature_info.has_avx512_fp16();
176 let avx512f = extended_feature_info.has_avx512f();
178 let avx512bw = extended_feature_info.has_avx512bw();
179 let avx512_vnni = extended_feature_info.has_avx512vnni();
180 let f16c = feature_info.has_f16c();
181 let extended_processor_info = cpuid.get_extended_processor_and_feature_identifiers().unwrap();
182 let fma4 = extended_processor_info.has_fma4();
183 let cpu_ft = CpuFeatures {
184 sse,
185 sse2,
186 sse3,
187 ssse3,
188 avx,
189 avx2,
190 avx512f,
191 avx512f16,
192 avx512bw,
193 avx512_vnni,
194 fma,
195 fma4,
196 f16c,
197 };
198 let family_id = feature_info.family_id();
199 let model_id = feature_info.model_id();
200 let hw_model = HWModel::from_hw(family_id, model_id);
201 let (is_l1_shared, is_l2_shared, is_l3_shared) = hw_model.get_cache_info();
202 return HWConfig { cpu_ft, hw_model, is_l1_shared, is_l2_shared, is_l3_shared };
203 }
204 #[cfg(target_arch = "x86")]
205 {
206 let cpuid = raw_cpuid::CpuId::new();
207 let feature_info = cpuid.get_feature_info().unwrap();
208 let sse = feature_info.has_sse();
209 let sse2 = feature_info.has_sse2();
210 let sse3 = feature_info.has_sse3();
211 let ssse3 = feature_info.has_ssse3();
212 let cpu_ft = CpuFeatures { sse, sse2, sse3, ssse3 };
213 let family_id = feature_info.family_id();
214 let model_id = feature_info.model_id();
215 let hw_model = HWModel::from_hw(family_id, model_id);
216 let (is_l1_shared, is_l2_shared, is_l3_shared) = hw_model.get_cache_info();
217 return HWConfig { cpu_ft, hw_model, is_l1_shared, is_l2_shared, is_l3_shared };
218 }
219 #[cfg(target_arch = "aarch64")]
220 {
221 use std::arch::is_aarch64_feature_detected;
222 let neon = is_aarch64_feature_detected!("neon");
223 let sve = is_aarch64_feature_detected!("sve");
224 let fp16 = is_aarch64_feature_detected!("fp16");
225 let f32mm = is_aarch64_feature_detected!("f32mm");
226 let fcma = is_aarch64_feature_detected!("fcma");
227 let i8mm = is_aarch64_feature_detected!("i8mm");
228
229 return HWConfig {
230 cpu_ft: CpuFeatures { neon, sve, fp16, f32mm, fcma, i8mm },
231 hw_model: HWModel::Reference,
232 is_l1_shared: false,
233 is_l2_shared: false,
234 is_l3_shared: true,
235 };
236 }
237 #[cfg(not(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64")))]
238 {
239 return HWConfig {
240 cpu_ft: CpuFeatures { dummy: false },
241 hw_model: HWModel::Reference,
242 is_l1_shared: false,
243 is_l2_shared: false,
244 is_l3_shared: true,
245 };
246 }
247}
248
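// With the `debug_cpu_features` cargo feature enabled, environment variables of
// the form `PIRE_<FEATURE>_OFF` (e.g. `PIRE_AVX2_OFF=1`) mask detected features
// so that lower-ISA code paths can be exercised on capable hardware.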
249#[cfg(feature = "debug_cpu_features")]
250#[allow(unused)]
251fn apply_debug_cpu_features(cpu_ft: &mut CpuFeatures) {
252 #[cfg(target_arch = "x86_64")]
253 {
254 let sse_turn_off = std::env::var("PIRE_SSE_OFF").is_ok();
255 let sse2_turn_off = std::env::var("PIRE_SSE2_OFF").is_ok();
256 let sse3_turn_off = std::env::var("PIRE_SSE3_OFF").is_ok();
257 let ssse3_turn_off = std::env::var("PIRE_SSSE3_OFF").is_ok();
258 let avx_turn_off = std::env::var("PIRE_AVX_OFF").is_ok();
259 let avx2_turn_off = std::env::var("PIRE_AVX2_OFF").is_ok();
260 let avx512f_turn_off = std::env::var("PIRE_AVX512F_OFF").is_ok();
261 let avx512f16_turn_off = std::env::var("PIRE_AVX512F16_OFF").is_ok();
262 let avx512bw_turn_off = std::env::var("PIRE_AVX512BW_OFF").is_ok();
263 let avx512_vnni_turn_off = std::env::var("PIRE_AVX512_VNNI_OFF").is_ok();
264 let fma_turn_off = std::env::var("PIRE_FMA_OFF").is_ok();
265 let fma4_turn_off = std::env::var("PIRE_FMA4_OFF").is_ok();
266 let f16c_turn_off = std::env::var("PIRE_F16C_OFF").is_ok();
267
268 cpu_ft.sse = cpu_ft.sse && !sse_turn_off;
269 cpu_ft.sse2 = cpu_ft.sse2 && !sse2_turn_off;
270 cpu_ft.sse3 = cpu_ft.sse3 && !sse3_turn_off;
271 cpu_ft.ssse3 = cpu_ft.ssse3 && !ssse3_turn_off;
272 cpu_ft.avx = cpu_ft.avx && !avx_turn_off;
273 cpu_ft.avx2 = cpu_ft.avx2 && !avx2_turn_off;
274 cpu_ft.avx512f = cpu_ft.avx512f && !avx512f_turn_off;
275 cpu_ft.avx512f16 = cpu_ft.avx512f16 && !avx512f16_turn_off;
276 cpu_ft.avx512bw = cpu_ft.avx512bw && !avx512bw_turn_off;
277 cpu_ft.avx512_vnni = cpu_ft.avx512_vnni && !avx512_vnni_turn_off;
278 cpu_ft.fma = cpu_ft.fma && !fma_turn_off;
279 cpu_ft.fma4 = cpu_ft.fma4 && !fma4_turn_off;
280 cpu_ft.f16c = cpu_ft.f16c && !f16c_turn_off;
281 }
282 #[cfg(target_arch = "x86")]
283 {
284 let sse_turn_off = std::env::var("PIRE_SSE_OFF").is_ok();
285 let sse2_turn_off = std::env::var("PIRE_SSE2_OFF").is_ok();
286 let sse3_turn_off = std::env::var("PIRE_SSE3_OFF").is_ok();
287 let ssse3_turn_off = std::env::var("PIRE_SSSE3_OFF").is_ok();
288
289 cpu_ft.sse = cpu_ft.sse && !sse_turn_off;
290 cpu_ft.sse2 = cpu_ft.sse2 && !sse2_turn_off;
291 cpu_ft.sse3 = cpu_ft.sse3 && !sse3_turn_off;
292 cpu_ft.ssse3 = cpu_ft.ssse3 && !ssse3_turn_off;
293 }
294 #[cfg(target_arch = "aarch64")]
295 {
296 let neon_turn_off = std::env::var("PIRE_NEON_OFF").is_ok();
297 let sve_turn_off = std::env::var("PIRE_SVE_OFF").is_ok();
298 let fp16_turn_off = std::env::var("PIRE_FP16_OFF").is_ok();
299 let f32mm_turn_off = std::env::var("PIRE_F32MM_OFF").is_ok();
300 let fcma_turn_off = std::env::var("PIRE_FCMA_OFF").is_ok();
301 let i8mm_turn_off = std::env::var("PIRE_I8MM_OFF").is_ok();
302
303 cpu_ft.neon = cpu_ft.neon && !neon_turn_off;
304 cpu_ft.sve = cpu_ft.sve && !sve_turn_off;
305 cpu_ft.fp16 = cpu_ft.fp16 && !fp16_turn_off;
306 cpu_ft.f32mm = cpu_ft.f32mm && !f32mm_turn_off;
307 cpu_ft.fcma = cpu_ft.fcma && !fcma_turn_off;
308 cpu_ft.i8mm = cpu_ft.i8mm && !i8mm_turn_off;
309 }
310}
311
312#[cfg(not(feature = "debug_cpu_features"))]
313pub static RUNTIME_HW_CONFIG: Lazy<HWConfig> = Lazy::new(|| detect_hw_config());
314#[cfg(feature = "debug_cpu_features")]
315pub static RUNTIME_HW_CONFIG: Lazy<HWConfig> = Lazy::new(|| {
316 let mut hw_config = detect_hw_config();
317 apply_debug_cpu_features(&mut hw_config.cpu_ft);
318 hw_config
319});
320
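// Global thread count: the `PIRE_NUM_THREADS` environment variable if set
// (panics on a non-numeric value), otherwise `available_parallelism()`.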
321pub static PIRE_NUM_THREADS: Lazy<usize> = Lazy::new(|| {
322 let n_core = std::thread::available_parallelism().unwrap().get();
323 let x = std::env::var("PIRE_NUM_THREADS").unwrap_or(n_core.to_string());
325 x.parse::<usize>().unwrap()
326});
327#[cfg(target_arch = "x86_64")]
328pub(crate) mod cpu_features {
329 use super::HWModel;
330 use super::RUNTIME_HW_CONFIG;
331
332 pub fn hw_model() -> HWModel {
333 RUNTIME_HW_CONFIG.hw_model
334 }
335
336 pub fn has_f32_compute() -> bool {
337 RUNTIME_HW_CONFIG.cpu_ft.avx || RUNTIME_HW_CONFIG.cpu_ft.sse
341 }
342
343 pub fn has_c32_compute() -> bool {
344 RUNTIME_HW_CONFIG.cpu_ft.avx || (RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse3)
345 }
346
347 pub fn has_f16f32_compute() -> bool {
348 RUNTIME_HW_CONFIG.cpu_ft.avx && RUNTIME_HW_CONFIG.cpu_ft.f16c
349 }
350 pub fn has_f64_compute() -> bool {
351 RUNTIME_HW_CONFIG.cpu_ft.avx || (RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2)
352 }
353
354 pub fn has_c64_compute() -> bool {
355 RUNTIME_HW_CONFIG.cpu_ft.avx
356 || (RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2 && RUNTIME_HW_CONFIG.cpu_ft.sse3)
357 }
358
359 pub fn has_f16_compute() -> bool {
360 RUNTIME_HW_CONFIG.cpu_ft.avx512f16
361 && RUNTIME_HW_CONFIG.cpu_ft.avx
362 && RUNTIME_HW_CONFIG.cpu_ft.f16c
363 && RUNTIME_HW_CONFIG.cpu_ft.fma
364 }
365 pub fn has_i16i32_compute() -> bool {
366 (RUNTIME_HW_CONFIG.cpu_ft.avx2 && RUNTIME_HW_CONFIG.cpu_ft.avx)
367 || (RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2)
368 }
369 pub fn has_i8i32_compute() -> bool {
370 (RUNTIME_HW_CONFIG.cpu_ft.avx2 && RUNTIME_HW_CONFIG.cpu_ft.avx)
371 || (RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2 && RUNTIME_HW_CONFIG.cpu_ft.ssse3)
372 }
373 pub fn get_cache_params() -> (usize, usize, usize) {
375 (4800, 256, 128)
376 }
377}
378
379#[cfg(target_arch = "x86")]
380pub(crate) mod cpu_features {
381 use super::HWModel;
382 use super::RUNTIME_HW_CONFIG;
383
384 pub fn hw_model() -> HWModel {
385 RUNTIME_HW_CONFIG.hw_model
386 }
387
388 pub fn has_f32_compute() -> bool {
389 RUNTIME_HW_CONFIG.cpu_ft.sse
393 }
394
395 pub fn has_c32_compute() -> bool {
396 RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse3
397 }
398
399 pub fn has_f16f32_compute() -> bool {
400 false
401 }
402 pub fn has_f64_compute() -> bool {
403 RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2
404 }
405
406 pub fn has_c64_compute() -> bool {
407 RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2 && RUNTIME_HW_CONFIG.cpu_ft.sse3
408 }
409
410 pub fn has_f16_compute() -> bool {
411 false
412 }
413 pub fn has_i16i32_compute() -> bool {
414 RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2
415 }
416 pub fn has_i8i32_compute() -> bool {
417 RUNTIME_HW_CONFIG.cpu_ft.sse && RUNTIME_HW_CONFIG.cpu_ft.sse2 && RUNTIME_HW_CONFIG.cpu_ft.ssse3
418 }
419 pub fn get_cache_params() -> (usize, usize, usize) {
421 (4800, 256, 128)
422 }
423}
424#[cfg(target_arch = "aarch64")]
425pub(crate) mod cpu_features {
426
427 use super::HWModel;
433 use super::RUNTIME_HW_CONFIG;
434
435 pub fn hw_model() -> HWModel {
436 RUNTIME_HW_CONFIG.hw_model
437 }
438
439 pub fn has_f32_compute() -> bool {
440 RUNTIME_HW_CONFIG.cpu_ft.neon
444 }
445
446 pub fn has_c32_compute() -> bool {
447 RUNTIME_HW_CONFIG.cpu_ft.neon
448 }
449
450 pub fn has_f16f32_compute() -> bool {
451 false
452 }
453 pub fn has_f64_compute() -> bool {
454 RUNTIME_HW_CONFIG.cpu_ft.neon
455 }
456
457 pub fn has_c64_compute() -> bool {
458 RUNTIME_HW_CONFIG.cpu_ft.neon
459 }
460
461 pub fn has_f16_compute() -> bool {
462 RUNTIME_HW_CONFIG.cpu_ft.fp16 && RUNTIME_HW_CONFIG.cpu_ft.neon
463 }
464 pub fn has_i16i32_compute() -> bool {
465 false
468 }
469 pub fn has_i8i32_compute() -> bool {
470 RUNTIME_HW_CONFIG.cpu_ft.i8mm && RUNTIME_HW_CONFIG.cpu_ft.neon
471 }
472 pub fn get_cache_params() -> (usize, usize, usize) {
474 (4800, 256, 128)
475 }
476}
477
478#[cfg(not(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64")))]
479pub(crate) mod cpu_features {
480 use super::HWModel;
481 use super::RUNTIME_HW_CONFIG;
482
483 pub fn hw_model() -> HWModel {
484 RUNTIME_HW_CONFIG.hw_model
485 }
486
487 pub fn has_f32_compute() -> bool {
488 false
489 }
490
491 pub fn has_c32_compute() -> bool {
492 false
493 }
494
495 pub fn has_f16f32_compute() -> bool {
496 false
497 }
498 pub fn has_f64_compute() -> bool {
499 false
500 }
501
502 pub fn has_c64_compute() -> bool {
503 false
504 }
505
506 pub fn has_f16_compute() -> bool {
507 false
508 }
509 pub fn has_i16i32_compute() -> bool {
510 false
511 }
512 pub fn has_i8i32_compute() -> bool {
513 false
514 }
515 pub fn get_cache_params() -> (usize, usize, usize) {
516 (4800, 256, 128)
517 }
518}
519pub use cpu_features::*;
520
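// Process-wide pool of reusable packing buffers. `acquire` scans for an
// unlocked buffer of at least `pack_size` bytes and returns its guard;
// `extend` donates a freshly allocated buffer to the pool for later calls.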
521pub struct PackPool {
522 pub buffer: RwLock<Vec<Mutex<Vec<u8>>>>,
523}
524
525pub static PACK_POOL: PackPool = PackPool { buffer: RwLock::new(vec![]) };
526
527pub fn acquire<'a>(
528 pool_guard: &'a RwLockReadGuard<'a, Vec<Mutex<Vec<u8>>>>,
529 pack_size: usize,
530) -> Option<MutexGuard<'a, Vec<u8>>> {
531 for i in pool_guard.iter() {
534 let lock = i.try_lock();
541 if let Ok(mutex) = lock {
542 if mutex.len() >= pack_size {
543 return Some(mutex);
544 }
545 }
546 }
547
548 None
549}
550
pub fn extend(pool_vec: Vec<u8>) {
    let mut pool_guard = PACK_POOL.buffer.write().unwrap();
    pool_guard.push(Mutex::new(pool_vec));
}
555
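// Per-thread view of the parallel decomposition: loop indices along each
// blocking dimension, the strip index this thread packs within the shared A/B
// panels, effective block sizes, and the barriers that synchronize packing.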
556pub struct PireThreadConfig<'a> {
557 pub ic_id: usize,
558 pub jc_id: usize,
560 pub ir_id: usize,
561 pub jr_id: usize,
562 pub i_load_p_idx: usize,
563 pub j_load_p_idx: usize,
564 pub mc_eff: usize,
565 pub nc_eff: usize,
566 pub kc_eff: usize,
567 pub par: PirePar,
568 pub packa_barrier: &'a [Barrier],
569 pub packb_barrier: &'a [Barrier],
570}
571
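// One barrier per `ic` slot (shared A panel) and per `jc` slot (shared B
// panel), each sized to the number of threads that reuse that panel.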
572pub fn get_apbp_barrier(par: &PirePar) -> (Vec<Barrier>, Vec<Barrier>) {
573 let mut packa_barrier = vec![];
574 for _ in 0..par.ic_par {
575 let barrier = Barrier::new(par.jc_par * par.pc_par * par.ir_par * par.jr_par);
576 packa_barrier.push(barrier);
577 }
578
579 let mut packb_barrier = vec![];
580 for _ in 0..par.jc_par {
581 let barrier = Barrier::new(par.ic_par * par.pc_par * par.ir_par * par.jr_par);
582 packb_barrier.push(barrier);
583 }
584
585 (packa_barrier, packb_barrier)
586}
587
588impl<'a> PireThreadConfig<'a> {
589 pub fn new(
590 par: PirePar,
591 packa_barrier: &'a [Barrier],
592 packb_barrier: &'a [Barrier],
593 t_id: usize,
594 mc_eff: usize,
595 nc_eff: usize,
596 kc_eff: usize,
597 ) -> Self {
598 let ic_id = par.get_ic_id(t_id);
599 let jc_id = par.get_jc_id(t_id);
601 let ir_id = par.get_ir_id(t_id);
602 let jr_id = par.get_jr_id(t_id);
603 let i_load_p_idx = jc_id * par.ir_par * par.jr_par + ir_id * par.jr_par + jr_id;
604 let j_load_p_idx = ic_id * par.ir_par * par.jr_par + ir_id * par.jr_par + jr_id;
605
606 Self {
607 ic_id,
608 jc_id,
610 ir_id,
611 jr_id,
612 i_load_p_idx,
613 j_load_p_idx,
614 mc_eff,
615 nc_eff,
616 kc_eff,
617 par,
618 packa_barrier,
619 packb_barrier,
620 }
621 }
622 #[inline]
623 pub fn wait_packa(&self) {
624 if self.par.jc_par * self.par.pc_par * self.par.ir_par * self.par.jr_par > 1 {
625 self.packa_barrier[self.ic_id].wait();
626 }
627 }
628
629 #[inline]
630 pub fn wait_packb(&self) {
631 if self.par.ic_par * self.par.pc_par * self.par.ir_par * self.par.jr_par > 1 {
632 self.packb_barrier[self.jc_id].wait();
633 }
634 }
635}
636
637#[inline(always)]
644pub fn pire_num_threads() -> usize {
645 return *PIRE_NUM_THREADS;
646}
647
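// Degree of parallelism along each loop of the blocked GEMM: `ic`/`jc` split m
// and n at the cache-block level, `ir`/`jr` at the register-block level, and
// `pc` splits k; `num_threads` is the product of the five factors.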
648#[derive(Copy, Clone)]
649pub struct PirePar {
650 pub num_threads: usize,
651 pub ic_par: usize,
652 pub pc_par: usize,
653 pub jc_par: usize,
654 pub ir_par: usize,
655 pub jr_par: usize,
656}
657
658#[inline(always)]
661fn inc_par(ic_par: usize, jc_par: usize, ic_max: usize, jc_max: usize, factor: usize) -> (usize, usize, usize, usize) {
662 if (ic_par < jc_par && ic_par < ic_max) || (jc_par >= jc_max && ic_par < ic_max) {
663 (ic_par * factor, jc_par, ic_max / factor, jc_max)
664 } else if (ic_par >= jc_par && jc_par < jc_max) || (ic_par >= ic_max && jc_par < jc_max) {
665 (ic_par, jc_par * factor, ic_max, jc_max / factor)
666 } else {
667 (ic_par, jc_par, ic_max, jc_max)
668 }
669}
670impl PirePar {
671 pub fn new(num_threads: usize, ic_par: usize, pc_par: usize, jc_par: usize, ir_par: usize, jr_par: usize) -> Self {
672 assert_eq!(num_threads, jc_par * pc_par * ic_par * jr_par * ir_par);
673 Self { num_threads, ic_par, pc_par, jc_par, ir_par, jr_par }
674 }
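    // Heuristic split of `num_threads` over ic/jc: cap each dimension based on
    // m and n, then repeatedly peel prime factors 2/3/5/7 from the thread
    // count, giving each factor to whichever of ic/jc still has headroom.
    // Counts with a larger prime factor are rounded down to the nearest even
    // value; a large ic factor is traded for two-way ir parallelism.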
675 pub fn from_num_threads(num_threads: usize, m: usize, n: usize) -> Self {
676 let mut num_threads = num_threads;
677 let mut ic_par_max = if m < 96 {
678 1
679 } else if m < 400 {
680 2
681 } else {
682 m / 200
683 };
684 let mut jc_par_max = if n < 48 {
685 1
686 } else if n < 200 {
687 2
688 } else {
689 n / 100
690 };
691
692 if num_threads <= 12 {
693 let jc_par_max = jc_par_max.min(num_threads);
694 let n_thread = (num_threads / jc_par_max) * jc_par_max;
695 return Self::new(n_thread, num_threads / jc_par_max, 1, jc_par_max, 1, 1);
696 }
697 num_threads = num_threads.min(ic_par_max * jc_par_max);
699 let mut ic_par = 1;
700 let pc_par = 1;
701 let mut jc_par = 1;
702 let mut ir_par = 1;
703 let jr_par = 1;
704
705 while num_threads > 1 {
706 if num_threads % 2 == 0 {
707 num_threads = num_threads / 2;
708 (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 2);
709 } else if num_threads % 3 == 0 {
710 num_threads = num_threads / 3;
711 (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 3);
712 } else if num_threads % 5 == 0 {
713 num_threads = num_threads / 5;
714 (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 5);
715 continue;
716 } else if num_threads % 7 == 0 {
717 num_threads = num_threads / 7;
718 (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 7);
719 continue;
720 } else {
721 num_threads = num_threads / 2 * 2;
724 }
725 }
741 if ic_par >= 8 {
742 ic_par = ic_par / 2;
743 ir_par = 2;
744 }
745 let num_threads = ic_par * pc_par * jc_par * ir_par * jr_par;
746 Self { num_threads, ic_par, pc_par, jc_par, ir_par, jr_par }
747 }
748 #[inline(always)]
749 pub fn default(m: usize, n: usize) -> Self {
750 let num_threads = pire_num_threads();
751 Self::from_num_threads(num_threads, m, n)
752 }
753 #[inline]
754 fn get_ic_id(&self, t_id: usize) -> usize {
755 (t_id / (self.pc_par * self.jc_par * self.ir_par * self.jr_par)) % self.ic_par
756 }
757
758 #[inline]
763 fn get_jc_id(&self, t_id: usize) -> usize {
764 (t_id / (self.jr_par * self.ir_par)) % self.jc_par
765 }
766 #[inline]
767 fn get_jr_id(&self, t_id: usize) -> usize {
768 (t_id / self.ir_par) % self.jr_par
769 }
770 #[inline]
771 fn get_ir_id(&self, t_id: usize) -> usize {
772 t_id % self.ir_par
773 }
774
775 pub fn get_load_par(
776 &self,
777 gemm_mode: &GemmPool,
778 m: usize,
779 n: usize,
780 mc_eff: usize,
781 nc_eff: usize,
782 ) -> (usize, usize) {
783 let m = (m / self.ic_par).min(mc_eff);
784 let n = (n / self.jc_par).min(nc_eff);
785 let i_load_par = ((m + 127) / 128).min(self.num_threads / self.ic_par);
786 let j_load_par = ((n + 127) / 128).min(self.num_threads / self.jc_par);
787 let i_load_par = match gemm_mode {
788 GemmPool::Goto => i_load_par,
789 GemmPool::SmallM => i_load_par,
790 GemmPool::SmallN => 1,
791 };
792 (i_load_par.max(1), j_load_par.max(1))
793 }
794}
795
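// Partition the m (or n) dimension across `ic_par` workers in multiples of
// `mr`; worker 0 takes the tail chunk including the sub-`mr` remainder. The
// returned flag reports whether other workers still have cache blocks left, so
// this worker must keep joining the packing barriers.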
796#[inline]
797pub fn split_c_range(m: usize, mc: usize, mr: usize, ic_id: usize, ic_par: usize) -> (usize, usize, bool) {
798 let chunk_len = (m / (mr * ic_par)) * mr;
799 let rem = m % (mr * ic_par);
800 if ic_id == 0 {
801 let x = chunk_len + rem % mr;
802 let mc_left = ((((x + mc - 1) / mc) * mc) * ic_par) < m;
803 return (m - chunk_len - (rem % mr), m, mc_left);
804 }
805 let ic_id = ic_id - 1;
806 let m0 = (m / mr) * mr;
807 let rem = m0 % (mr * ic_par);
808 let start_delta = rem.min(ic_id * mr);
809 let end_delta = rem.min((ic_id + 1) * mr);
810 let mc_coeff = (chunk_len + end_delta - start_delta + mc - 1) / mc;
812 let mc_left = ((mc_coeff * mc) * ic_par) < m;
813 (chunk_len * ic_id + start_delta, chunk_len * (ic_id + 1) + end_delta, mc_left)
815}
816
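// Split `range_len` into `r_par` contiguous pieces aligned to `unit_len`,
// handing leftover whole units to the lowest-ranked workers and any sub-unit
// tail to the last worker. For instance, split_range(20, 4, 0, 2) == (0, 12)
// and split_range(20, 4, 1, 2) == (12, 20).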
817#[inline]
818pub fn split_range(range_len: usize, unit_len: usize, r_id: usize, r_par: usize) -> (usize, usize) {
819 let chunk_start = (range_len / (unit_len * r_par)) * unit_len * r_id;
820 let chunk_end = (range_len / (unit_len * r_par)) * unit_len * (r_id + 1);
821 let rem = range_len % (unit_len * r_par);
822 let rem = rem - rem % unit_len;
823 let rem_start = rem.min(r_id * unit_len);
824 let rem_end = rem.min((r_id + 1) * unit_len);
825 if r_id == r_par - 1 {
826 return (chunk_start + rem_start, range_len);
827 }
828 (chunk_start + rem_start, chunk_end + rem_end)
829}
830
831pub trait BaseNum: Copy + 'static + Send {}
832
833impl<T> BaseNum for T where T: Copy + 'static + Send {}
834
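// Sizes (in elements) and multiplicities of the packed-A and packed-B regions
// for one GEMM call; `mem_pool_size_b` adds `2 * AB_ALIGN` bytes of slack so
// both regions can be aligned within a single allocation.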
835#[derive(Copy, Clone)]
836pub struct PoolSize {
837 pub m: usize,
838 pub n: usize,
839 pub k: usize,
840 pub ap_pool_size: usize,
841 pub ap_pool_multiplicity: usize,
842 pub bp_pool_size: usize,
843 pub bp_pool_multiplicity: usize,
844}
845
846impl PoolSize {
847 pub fn mem_pool_size_b<TA, TB>(&self) -> usize {
849 self.ap_pool_size * size_of::<TA>() * self.ap_pool_multiplicity
851 + self.bp_pool_size * size_of::<TB>() * self.bp_pool_multiplicity
852 + 2 * AB_ALIGN
853 }
854
855 pub fn ap_size_b<TA>(&self) -> usize {
856 self.ap_pool_size * size_of::<TA>()
857 }
858
859 pub fn bp_size_b<TB>(&self) -> usize {
860 self.bp_pool_size * size_of::<TB>()
861 }
862
863 pub fn ap_size_t_b<TA>(&self) -> usize {
864 self.ap_pool_size * size_of::<TA>() * self.ap_pool_multiplicity
865 }
866
867 pub fn bp_size_t_b<TB>(&self) -> usize {
868 self.bp_pool_size * size_of::<TB>() * self.bp_pool_multiplicity
869 }
870
871 pub fn slice_mut_from_pool<TA, TB>(
872 &self,
873 mem_pool: &mut [u8],
874 i_load_par: usize,
875 j_load_par: usize,
876 pool_size: PoolSize,
877 mr: usize,
878 nr: usize,
879 ) -> (Vec<RangeLock<'_, TA>>, Vec<RangeLock<'_, TB>>) {
881 let m_size = pool_size.m;
882 let n_size = pool_size.n;
883 let k_size = pool_size.k;
884 let ap_pool_size = self.ap_pool_size;
885 let ap_pool_size_b = ap_pool_size * size_of::<TA>();
886 let a_alignment = std::mem::align_of::<TA>();
887 assert_eq!(ap_pool_size_b % a_alignment, 0);
888 let bp_pool_size = self.bp_pool_size;
889 let bp_pool_size_b = bp_pool_size * size_of::<TB>();
890 let b_alignment = std::mem::align_of::<TB>();
891 assert_eq!(bp_pool_size_b % b_alignment, 0);
892 let mut ap = vec![];
893 let mut bp = vec![];
894 assert!(mem_pool.len() >= self.mem_pool_size_b::<TA, TB>());
897 let align_offset = mem_pool.as_ptr().align_offset(AB_ALIGN);
899 let mut mem_pool = &mut mem_pool[align_offset..];
900 assert_eq!(mem_pool.as_ptr().align_offset(a_alignment), 0);
902 for _ in 0..self.ap_pool_multiplicity {
903 let (a, rest) = mem_pool.split_at_mut(ap_pool_size_b);
904 let ap_pool = unsafe { std::slice::from_raw_parts_mut::<TA>(a.as_mut_ptr() as *mut TA, ap_pool_size) };
905 if ap_pool_size == 0 {
906 ap.push(RangeLock::from(ap_pool, i_load_par, 0, k_size, mr));
907 } else {
908 ap.push(RangeLock::from(ap_pool, i_load_par, m_size, k_size, mr));
909 }
910 mem_pool = rest;
911 }
912 let align_offset = mem_pool.as_ptr().align_offset(AB_ALIGN);
913 let mut mem_pool = &mut mem_pool[align_offset..];
914 assert_eq!(mem_pool.as_ptr().align_offset(b_alignment), 0);
916 for _ in 0..self.bp_pool_multiplicity {
917 let (b, rest) = mem_pool.split_at_mut(bp_pool_size_b);
918 let bp_pool = unsafe { std::slice::from_raw_parts_mut::<TB>(b.as_mut_ptr() as *mut TB, bp_pool_size) };
919 if bp_pool_size == 0 {
920 bp.push(RangeLock::from(bp_pool, j_load_par, 0, k_size, nr));
921 } else {
922 bp.push(RangeLock::from(bp_pool, j_load_par, n_size, k_size, nr));
923 }
924 mem_pool = rest;
925 }
926 (ap, bp)
927 }
928}
929
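// Pool-size calculators for the three execution paths: the Goto path may pack
// both A and B, the small-m path packs only A, and the small-n path packs B
// plus one mr-by-kc A strip per thread. Each packed region that is actually
// used is padded by `CACHELINE_PAD / size_of::<T>()` elements.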
930pub fn get_mem_pool_size_goto<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
931 hw_config: &HWConfig,
932 par: &PirePar,
933 a_need_pool: bool,
934 b_need_pool: bool,
935) -> PoolSize {
936 let m = hw_config.get_mc_eff(par.ic_par);
937 let n = hw_config.get_nc_eff(par.jc_par);
938 let k = hw_config.get_kc_eff();
939 let (ap_pool_size, ap_pool_multiplicity) = if a_need_pool {
940 let ap_pool_multiplicity = par.ic_par;
941 let ap_pool_size = hw_config.get_ap_pool_size(par.ic_par) + CACHELINE_PAD / size_of::<AP>();
942 (ap_pool_size, ap_pool_multiplicity)
943 } else {
944 (0, 1)
945 };
946 let (bp_pool_size, bp_pool_multiplicity) = if b_need_pool {
947 let bp_pool_multiplicity = par.jc_par;
948 let bp_pool_size = hw_config.get_bp_pool_size(par.jc_par) + CACHELINE_PAD / size_of::<BP>();
949 (bp_pool_size, bp_pool_multiplicity)
950 } else {
951 (0, 1)
952 };
953 PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size, bp_pool_multiplicity }
954}
955
956pub fn get_mem_pool_size_small_m<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
957 hw_config: &HWConfig,
958 par: &PirePar,
959 a_need_pool: bool,
960) -> PoolSize {
961 let m = hw_config.get_mc_eff(par.ic_par);
962 let n = hw_config.get_nc_eff(par.jc_par);
963 let k = hw_config.get_kc_eff();
964 if a_need_pool {
965 let ap_pool_multiplicity = par.ic_par;
966 let ap_pool_size = hw_config.get_ap_pool_size(par.ic_par) + CACHELINE_PAD / size_of::<AP>();
967 PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size: 0, bp_pool_multiplicity: 1 }
968 } else {
969 PoolSize { m, n, k, ap_pool_size: 0, ap_pool_multiplicity: 1, bp_pool_size: 0, bp_pool_multiplicity: 1 }
970 }
971}
972
973pub fn get_mem_pool_size_small_n<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
974 hw_config: &HWConfig,
975 par: &PirePar,
976 b_need_pool: bool,
977) -> PoolSize {
978 let ap_pool_size = hw_config.get_ap_pool_size2() + CACHELINE_PAD / size_of::<AP>();
979 let ap_pool_multiplicity = par.num_threads;
980 let m = hw_config.mr();
981 let n = hw_config.get_nc_eff(par.jc_par);
982 let k = hw_config.get_kc_eff();
983 if b_need_pool {
984 let bp_pool_multiplicity = par.jc_par;
985 let bp_pool_size = hw_config.get_bp_pool_size(par.jc_par) + CACHELINE_PAD / size_of::<BP>();
986 PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size, bp_pool_multiplicity }
987 } else {
988 PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size: 0, bp_pool_multiplicity: 1 }
989 }
990}
991
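// Dimension thresholds below which the small-m / small-n paths (which read one
// operand directly from memory instead of fully packing it) are preferred.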
992pub fn run_small_m(m: usize) -> bool {
1000 m < 144
1001}
1002
1003pub fn run_small_n(n: usize) -> bool {
1004 n < 144
1005}
1006
1007pub enum GemmPool {
1008 Goto,
1009 SmallM,
1010 SmallN,
1011}
1012
1013#[derive(Clone, Copy)]
1014pub struct StridedMatrix<T> {
1015 pub(crate) src: *const T,
1016 pub(crate) rs: usize,
1017 pub(crate) cs: usize,
1018}
1019
1020impl<T> StridedMatrix<T> {
1021 pub fn new(src: *const T, rs: usize, cs: usize) -> Self {
1022 Self { src, rs, cs }
1023 }
1024}
1025
1026unsafe impl<T> Send for StridedMatrix<T> {}
1027
1028#[derive(Clone, Copy)]
1029pub struct StridedMatrixMut<T> {
1030 pub(crate) src: *mut T,
1031 pub(crate) rs: usize,
1032 pub(crate) cs: usize,
1033}
1034
1035unsafe impl<T> Send for StridedMatrixMut<T> {}
1036
1037impl<T> StridedMatrixMut<T> {
1038 pub fn new(src: *mut T, rs: usize, cs: usize) -> Self {
1039 Self { src, rs, cs }
1040 }
1041}
1042
1043#[derive(Clone)]
1044pub struct StridedMatrixP<'a, T, U> {
1045 pub(crate) src: *const T,
1046 pub(crate) rs: usize,
1047 pub(crate) cs: usize,
1048 pub(crate) dst: &'a RangeLock<'a, U>,
1049}
1050
1051unsafe impl<'a, T, U> Send for StridedMatrixP<'a, T, U> {}
1052
1053impl<'a, T, U> StridedMatrixP<'a, T, U> {
1054 pub fn src(&self) -> *const T {
1055 self.src
1056 }
1057 pub fn dst_write(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, U> {
1058 self.dst.write(idx, kc).unwrap()
1059 }
1060 pub fn dst_read(&self) -> RangeLockReadGuard<'a, 'a, U> {
1061 self.dst.read().unwrap()
1062 }
1063 pub fn get_mc(&self) -> usize {
1064 self.dst.get_mc()
1065 }
1066 pub fn rs(&self) -> usize {
1067 self.rs
1068 }
1069 pub fn cs(&self) -> usize {
1070 self.cs
1071 }
1072}
1073
1074#[derive(Clone, Copy)]
1075pub struct PackedMatrix<T> {
1076 pub(crate) src: *const T,
1077 pub(crate) k: usize,
1078 pub(crate) m: usize,
1079 }
1082
1083unsafe impl<T> Send for PackedMatrix<T> {}
1084
1085impl<T> PackedMatrix<T> {
1086 pub fn src(&self) -> *const T {
1087 self.src
1088 }
1089 pub fn k(&self) -> usize {
1090 self.k
1091 }
1092 pub fn m(&self) -> usize {
1093 self.m
1094 }
1095 }
1099
1100#[derive(Clone)]
1101pub struct PackedMatrixMixed<'a, X, Y> {
1102 pub(crate) src: *const X,
1103 pub(crate) dst: &'a RangeLock<'a, Y>,
1104 pub(crate) k: usize,
1105 pub(crate) m: usize,
1106}
1107
1108impl<'a, X, Y> PackedMatrixMixed<'a, X, Y> {
1109 pub fn src(&self) -> *const X {
1110 self.src
1111 }
1112 pub fn k(&self) -> usize {
1113 self.k
1114 }
1115 pub fn m(&self) -> usize {
1116 self.m
1117 }
1118
1119 pub fn dst_write(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, Y> {
1120 self.dst.write(idx, kc).unwrap()
1121 }
1122
1123 pub fn get_mc(&self) -> usize {
1124 self.dst.get_mc()
1125 }
1126
1127 pub fn dst_read(&self) -> RangeLockReadGuard<'a, 'a, Y> {
1128 self.dst.read().unwrap()
1129 }
1130}
1131
1132unsafe impl<X, Y> Send for PackedMatrixMixed<'_, X, Y> {}
1133
1134pub const AB_ALIGN: usize = 1024;
1137
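// Cache-blocking interface implemented by each architecture dispatcher: the
// micro-kernel row count `mr`, the effective mc/nc/kc block sizes for a given
// degree of parallelism, and the packed-buffer sizes derived from them.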
1138pub trait GemmCache {
1139 fn mr(&self) -> usize;
1140 fn get_mc_eff(&self, par: usize) -> usize;
1141 fn get_kc_eff(&self) -> usize;
1142 fn get_nc_eff(&self, par: usize) -> usize;
1143 fn get_ap_pool_size(&self, ic_par: usize) -> usize {
1144 let mc_eff = self.get_mc_eff(ic_par);
1145 let kc_eff = self.get_kc_eff();
1146 mc_eff * kc_eff
1147 }
1148 fn get_ap_pool_size2(&self) -> usize {
1149 let kc_eff = self.get_kc_eff();
1150 self.mr() * kc_eff
1151 }
1152 fn get_bp_pool_size(&self, jc_par: usize) -> usize {
1153 let nc_eff = self.get_nc_eff(jc_par);
1154 let kc_eff = self.get_kc_eff();
1155 nc_eff * kc_eff
1156 }
1157}
1158
1159#[derive(Copy, Clone)]
1160pub enum Array<X> {
1161 StridedMatrix(StridedMatrix<X>),
1162 PackedMatrix(PackedMatrix<X>),
1163}
1164
1165impl<X> Array<X> {
1166 pub fn strided_matrix(src: *const X, rs: usize, cs: usize) -> Self {
1167 Array::StridedMatrix(StridedMatrix::new(src, rs, cs))
1168 }
1169 pub fn packed_matrix(src: *const X, m: usize, k: usize) -> Self {
1170 Array::PackedMatrix(PackedMatrix { src, k, m })
1171 }
1172 pub fn into_pack_array<'a>(&self, a: &'a [RangeLock<'a, X>], p_id: usize) -> PArray<'a, X> {
1173 match self {
1174 Array::StridedMatrix(x) => {
1175 let x = StridedMatrixP { src: x.src, rs: x.rs, cs: x.cs, dst: &a[p_id] };
1176 PArray::<X>::StridedMatrix(x)
1177 }
1178 Array::PackedMatrix(x) => {
1179 let x = PackedMatrix { src: x.src, k: x.k, m: x.m };
1180 PArray::PackedMatrix(x)
1181 }
1182 }
1183 }
1184 pub fn into_pack_array2<'a, Y>(&self, a: &'a [RangeLock<'a, Y>], p_id: usize) -> PArrayMixed<'a, X, Y> {
1185 match self {
1186 Array::StridedMatrix(x) => {
1187 let x = StridedMatrixP { src: x.src, rs: x.rs, cs: x.cs, dst: &a[p_id] };
1188 PArrayMixed::<X, Y>::StridedMatrix(x)
1189 }
1190 Array::PackedMatrix(x) => {
1191 let x = PackedMatrixMixed { src: x.src, dst: &a[p_id], k: x.k, m: x.m };
1192 PArrayMixed::PackedMatrix(x)
1193 }
1194 }
1195 }
1196
1197 pub fn src(&self) -> *const X {
1198 match self {
1199 Array::StridedMatrix(x) => x.src,
1200 Array::PackedMatrix(x) => x.src,
1201 }
1202 }
1203
1204 pub fn transpose(&mut self) {
1205 match self {
1206 Array::StridedMatrix(x) => {
1207 let temp = x.rs;
1208 x.rs = x.cs;
1209 x.cs = temp;
1210 }
1211 _ => {
1212 panic!("Only StridedMatrix has transpose");
1213 }
1214 }
1215 }
1216
1217 pub fn rs(&self) -> usize {
1218 match self {
1219 Array::StridedMatrix(x) => x.rs,
1220 _ => {
1221 panic!("Only StridedMatrix has rs");
1222 }
1223 }
1224 }
1225
1226 pub fn cs(&self) -> usize {
1227 match self {
1228 Array::StridedMatrix(x) => x.cs,
1229 _ => {
1230 panic!("Only StridedMatrix has cs");
1231 }
1232 }
1233 }
1234
1235 pub fn is_strided(&self) -> bool {
1236 match self {
1237 Array::StridedMatrix(_) => true,
1238 _ => false,
1239 }
1240 }
1241}
1242
1243#[derive(Copy, Clone)]
1244pub enum ArrayMut<X> {
1245 StridedMatrix(StridedMatrixMut<X>),
1246}
1247
1248impl<X> ArrayMut<X> {
1249 pub fn strided_matrix(src: *mut X, rs: usize, cs: usize) -> Self {
1250 ArrayMut::StridedMatrix(StridedMatrixMut::new(src, rs, cs))
1251 }
1252
1253 pub fn src(&self) -> *mut X {
1254 match self {
1255 ArrayMut::StridedMatrix(x) => x.src,
1256 }
1257 }
1258
1259 pub fn transpose(&mut self) {
1260 match self {
1261 ArrayMut::StridedMatrix(x) => {
1262 let temp = x.rs;
1263 x.rs = x.cs;
1264 x.cs = temp;
1265 }
1266 }
1267 }
1268
1269 pub fn rs(&self) -> usize {
1270 match self {
1271 ArrayMut::StridedMatrix(x) => x.rs,
1272 }
1273 }
1274
1275 pub fn cs(&self) -> usize {
1276 match self {
1277 ArrayMut::StridedMatrix(x) => x.cs,
1278 }
1279 }
1280}
1281
1282#[derive(Clone)]
1283pub enum PArray<'a, X> {
1284 StridedMatrix(StridedMatrixP<'a, X, X>),
1285 PackedMatrix(PackedMatrix<X>),
1286}
1287
1288impl<'a, X> PArray<'a, X> {
1289 pub fn src(&self) -> *const X {
1290 match self {
1291 Self::StridedMatrix(x) => x.src,
1292 Self::PackedMatrix(x) => x.src,
1293 }
1294 }
1295
1296 pub fn rs(&self) -> usize {
1297 match self {
1298 Self::StridedMatrix(x) => x.rs,
1299 _ => {
1300 panic!("Only StridedMatrix has rs");
1301 }
1302 }
1303 }
1304
1305 pub fn cs(&self) -> usize {
1306 match self {
1307 Self::StridedMatrix(x) => x.cs,
1308 _ => {
1309 panic!("Only StridedMatrix has cs");
1310 }
1311 }
1312 }
1313
1314 pub fn dst_write(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, X> {
1315 match self {
1316 Self::StridedMatrix(x) => x.dst.write(idx, kc).unwrap(),
1317 _ => {
1318 panic!("Only StridedMatrix has write guard");
1319 }
1320 }
1321 }
1322
1323 pub fn dst_read(&self) -> RangeLockReadGuard<'a, 'a, X> {
1324 match self {
1325 Self::StridedMatrix(x) => x.dst.read().unwrap(),
1326 _ => {
1327 panic!("Only StridedMatrix has read guard");
1328 }
1329 }
1330 }
1331
1332 pub fn is_strided(&self) -> bool {
1333 match self {
1334 Self::StridedMatrix(_) => true,
1335 _ => false,
1336 }
1337 }
1338}
1339
1340#[derive(Clone)]
1341pub enum PArrayMixed<'a, X, Y> {
1342 StridedMatrix(StridedMatrixP<'a, X, Y>),
1343 PackedMatrix(PackedMatrixMixed<'a, X, Y>),
1344}
1345
1346impl<'a, X, Y> PArrayMixed<'a, X, Y> {
1347 pub fn src(&self) -> *const X {
1348 match self {
1349 Self::StridedMatrix(x) => x.src,
1350 Self::PackedMatrix(x) => x.src,
1351 }
1352 }
1353
1354 pub fn rs(&self) -> usize {
1355 match self {
1356 Self::StridedMatrix(x) => x.rs,
1357 _ => {
1358 panic!("Only StridedMatrix has rs");
1359 }
1360 }
1361 }
1362
1363 pub fn cs(&self) -> usize {
1364 match self {
1365 Self::StridedMatrix(x) => x.cs,
1366 _ => {
1367 panic!("Only StridedMatrix has cs");
1368 }
1369 }
1370 }
1371
1372 pub fn dst_write(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, Y> {
1373 match self {
1374 Self::StridedMatrix(x) => x.dst.write(idx, kc).unwrap(),
1375 Self::PackedMatrix(x) => x.dst.write(idx, kc).unwrap(),
1376 }
1377 }
1378 pub fn dst_read(&self) -> RangeLockReadGuard<'a, 'a, Y> {
1379 match self {
1380 Self::StridedMatrix(x) => x.dst.read().unwrap(),
1381 Self::PackedMatrix(x) => x.dst.read().unwrap(),
1382 }
1383 }
1384 pub fn is_strided(&self) -> bool {
1385 match self {
1386 Self::StridedMatrix(_) => true,
1387 _ => false,
1388 }
1389 }
1390}
1391
1392pub enum PtrData<'a, X> {
1393 RefData(RangeLockReadGuard<'a, 'a, X>),
1394 PtrData(*const X),
1395}
1396
1397impl<'a, X> PtrData<'a, X> {
1398 pub fn src(&self) -> *const X {
1399 match self {
1400 PtrData::RefData(x) => x.get().as_ptr(),
1401 PtrData::PtrData(x) => x.clone(),
1402 }
1403 }
1404}
1405
pub fn matrix_size_strided(m: usize, n: usize, rs: usize, cs: usize) -> usize {
    // minimum slice length needed to address element (m - 1, n - 1)
    (m - 1) * rs + (n - 1) * cs + 1
}
1409
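// Generates the public packing API (`pack_a`, `pack_b` and their unchecked
// variants) for one element-type pair on top of the crate-local dispatch
// helpers (`dispatch_pack_a`, `dispatch_round_m`, ...). A backend crate would
// invoke it roughly as `packing_api!(f32, f32);`, assuming `TA`/`TB` and the
// dispatch functions are in scope there.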
1410#[macro_export]
1411macro_rules! packing_api {
1412 ($ta:ty, $tb:ty) => {
1413 fn a_size_packed(m: usize, k: usize) -> usize {
1414 let round_m_fn = dispatch_round_m();
1415 let round_k_fn = dispatch_round_k();
1416 let m_round = round_m_fn(m);
1417 let k_round = round_k_fn(k);
1418 return m_round * k_round;
1419 }
1420
1421 fn b_size_packed(n: usize, k: usize) -> usize {
1422 let round_k_fn = dispatch_round_k();
1423 let k_round = round_k_fn(k);
1424 return n * k_round;
1425 }
1426 pub unsafe fn pack_a_unchecked(
1436 m: usize,
1437 k: usize,
1438 a: *const $ta,
1439 a_rs: usize,
1440 a_cs: usize,
1441 ap: *mut $ta,
1442 ) -> Array<TA> {
1443 assert_eq!(ap.align_offset(AB_ALIGN), 0);
1444 if m == 1 {
1445 for j in 0..k {
1446 *ap.add(j) = *a.add(j * a_cs);
1447 }
1448 return Array::strided_matrix(ap, 1, m);
1449 }
1450 let pack_fn = dispatch_pack_a();
1451 let round_m_fn = dispatch_round_m();
1452 let round_k_fn = dispatch_round_k();
1453
1454 let (mc, _, kc) = dispatch_get_mcnckc();
1455 let mut ap_cur = ap;
1456 for p in (0..k).step_by(kc) {
1457 let kc_len = kc.min(k - p);
1458 let kc_len_eff = round_k_fn(kc_len);
1459 for i in (0..m).step_by(mc) {
1460 let mc_len = mc.min(m - i);
1461 let mc_len_eff = round_m_fn(mc_len);
1462 let a_cur = a.add(i * a_rs + p * a_cs);
1463 pack_fn(a_cur, ap_cur, mc_len, kc_len, a_rs, a_cs);
1464 ap_cur = ap_cur.add(mc_len_eff * kc_len_eff);
1465 }
1466 }
1467 return Array::packed_matrix(ap, m, k);
1468 }
1469
1470 pub unsafe fn pack_b_unchecked(
1474 n: usize,
1475 k: usize,
1476 b: *const $tb,
1477 b_rs: usize,
1478 b_cs: usize,
1479 bp: *mut $tb,
1480 ) -> Array<TB> {
1481 assert_eq!(bp.align_offset(AB_ALIGN), 0);
1482 if n == 1 {
1483 for j in 0..k {
1484 *bp.add(j) = *b.add(j * b_rs);
1485 }
1486 return Array::strided_matrix(bp, 1, k);
1487 }
1488 let pack_fn = dispatch_pack_b();
1489 let round_k_fn = dispatch_round_k();
1490
1491 let (_, nc, kc) = dispatch_get_mcnckc();
1492 let mut bp_cur = bp;
1493 for p in (0..k).step_by(kc) {
1494 let kc_len = kc.min(k - p);
1495 let kc_len_eff = round_k_fn(kc_len);
1496 for i in (0..n).step_by(nc) {
1497 let nc_len = nc.min(n - i);
1498 let b_cur = b.add(i * b_cs + p * b_rs);
1499 pack_fn(b_cur, bp_cur, nc_len, kc_len, b_rs, b_cs);
1500 bp_cur = bp_cur.add(nc_len * kc_len_eff);
1501 }
1502 }
1503 return Array::packed_matrix(bp, n, k);
1504 }
1505
1506 pub fn pack_a(m: usize, k: usize, a: &[$ta], a_rs: usize, a_cs: usize, ap: &mut [$ta]) -> Array<TA> {
1507 assert!(ap.len() >= a_size_packed(m, k));
1510 assert!(a.len() >= pire_base::matrix_size_strided(m, k, a_rs, a_cs));
1511 unsafe { pack_a_unchecked(m, k, a.as_ptr(), a_rs, a_cs, ap.as_mut_ptr()) }
1513 }
1514
1515 pub fn pack_b(n: usize, k: usize, b: &[$tb], b_rs: usize, b_cs: usize, bp: &mut [$tb]) -> Array<TB> {
1516 assert!(bp.len() >= b_size_packed(n, k));
1519 assert!(b.len() >= pire_base::matrix_size_strided(k, n, b_rs, b_cs));
1520 unsafe { pack_b_unchecked(n, k, b.as_ptr(), b_rs, b_cs, bp.as_mut_ptr()) }
1522 }
1523 };
1524}
1525
1526#[macro_export]
1527macro_rules! is_mixed {
1528 (T, $st1:expr, $st2:expr) => {
1529 $st1
1530 };
1531 (F, $src:expr, $st2:expr) => {
1532 $st2
1533 };
1534}
1535
1536#[macro_export]
1537macro_rules! def_pa {
1538 ($packa_ty:tt,F,$ta:tt,$tap:tt) => {
1539 type $packa_ty<'a> = PArray<'a, $tap>;
1540 };
1541 ($packa_ty:tt,T,$ta:tt,$tap:tt) => {
1542 type $packa_ty<'a> = PArrayMixed<'a, $ta, $tap>;
1543 };
1544}
1545
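// Generates the GEMM driver for one dispatcher/type combination: the entry
// point handles the gemv-like m == 1 / n == 1 cases, picks between the Goto,
// small-m and small-n paths, borrows (or allocates and later donates) a packing
// buffer from `PACK_POOL`, and spawns scoped worker threads. The per-path
// functions walk the mc/kc/nc blocking loops, packing panels through
// `$packa_name` / `$packb_name` and invoking the per-path macro-kernels.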
1546#[macro_export]
1547macro_rules! def_pire_gemm {
1548 (
1549 $t_dispatcher:tt,
1550 $ta:tt,$tap:ty,$tb:ty,$tbp:ty,$tc:ty,$t_as:ty,$t_bs:ty,
1551 $packa_ty:tt,$packb_ty:tt,
1552 $one:expr,
1553 $name:ident, $name_mt:ident,
1554 $goto_name:ident, $goto_kernel:ident,
1555 $small_m_name:ident, $small_m_kernel:ident,
1556 $small_n_name:ident, $small_n_kernel:ident,
1557 $gemv_name:ident, $gemv_name2:ident,
1558 $packa_name:ident, $packb_name:ident,
1559 $packa_name0:ident, $packb_name0:ident,
1560 $run_small_m:expr, $run_small_n:expr,
1561 $pack_fn:tt, $include_flag:tt,
1562 ) => {
1563 def_pa!($packa_ty,$include_flag,$ta,$tap);
1564 def_pa!($packb_ty,$include_flag,$tb,$tbp);
1565 pub(crate) unsafe fn $name <F:UnaryFnC>(
1566 hw_config: &$t_dispatcher <F>,
1567 m: usize, n: usize, k: usize,
1568 alpha: $t_as,
1569 a: Array<$ta>,
1570 b: Array<$tb>,
1571 beta: $t_bs,
1572 c: ArrayMut<$tc>,
1573 par: &PirePar,
1574 )
1575 {
1576 let a_need_pool = a.is_strided() || !hw_config.is_compute_native();
1577 let b_need_pool = b.is_strided() || !hw_config.is_compute_native();
1578 if n == 1 && a.is_strided() {
1579 let alpha = &alpha as *const $t_as;
1580 let beta = &beta as *const $t_bs;
1581 $gemv_name(hw_config, m, k, alpha, a, b, beta, c);
1582 return;
1583 }
1584 if m == 1 && b.is_strided() {
1585 let alpha = &alpha as *const $t_as;
1586 let beta = &beta as *const $t_bs;
1587 let mut a = a;
1588 a.transpose();
1589 let mut b = b;
1590 b.transpose();
1591 let mut c = c;
1592 c.transpose();
1593 $gemv_name2(hw_config, n, k, alpha.into(), b, a, beta, c);
1594 return;
1595 }
1596 let (gemm_mode, gemm_fun, pool_info)
1597 : (
1598 GemmPool, unsafe fn(
1599 &$t_dispatcher <F>, usize, usize, usize, *const $t_as, $packa_ty, $packb_ty, *const $t_bs, ArrayMut<$tc>, &PireThreadConfig,
1600 ),
1601 PoolSize
1602 )
1603 = if run_small_m(m) && $run_small_m && b.is_strided() {
1604 (GemmPool::SmallM, $small_m_name, get_mem_pool_size_small_m::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, a_need_pool))
1605 } else if run_small_n(n) && $run_small_n && a.is_strided() {
1606 (GemmPool::SmallN, $small_n_name, get_mem_pool_size_small_n::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, b_need_pool))
1607 } else {
1608 (GemmPool::Goto, $goto_name, get_mem_pool_size_goto::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, a_need_pool, b_need_pool))
1609 };
1610 let mem_pool_size = pool_info.mem_pool_size_b::<$tap,$tbp>();
1611 {
1622 let pool_guard = PACK_POOL.buffer.read().unwrap();
1623 let y = acquire(&pool_guard, mem_pool_size);
1624 if let Some(mut pool_vec) = y {
1625 let pool_buf = &mut pool_vec;
1626 $name_mt(
1627 hw_config, m, n, k, alpha, a, b, beta, c, par, pool_buf, gemm_mode, pool_info, gemm_fun
1628 );
1629 return;
1630 }
1631 }
1632 let mut pool_vec = vec![0_u8; mem_pool_size];
1633 let pool_buf = &mut pool_vec;
1634 $name_mt(
1635 hw_config, m, n, k, alpha, a, b, beta, c, par, pool_buf, gemm_mode, pool_info, gemm_fun
1636 );
1637 extend(pool_vec);
1638 }
1639
1640 pub(crate) unsafe fn $name_mt<F:UnaryFnC>(
1641 hw_config: &$t_dispatcher <F>,
1642 m: usize, n: usize, k: usize,
1643 alpha: $t_as,
1644 a: Array<$ta>,
1645 b: Array<$tb>,
1646 beta: $t_bs,
1647 c: ArrayMut<$tc>,
1648 par: &PirePar,
1649 pool_buf: &mut [u8],
1650 gemm_mode: GemmPool,
1651 pool_info: PoolSize,
1652 gemm_fn: unsafe fn(
1653 &$t_dispatcher <F>, usize, usize, usize, *const $t_as, $packa_ty, $packb_ty, *const $t_bs, ArrayMut<$tc>, &PireThreadConfig
1654 )
1655 )
1656 where $t_dispatcher <F>: GemmCache
1657 {
1658
1659 let mc_eff = <$t_dispatcher::<F> as GemmCache>::get_mc_eff(hw_config, par.ic_par);
1660 let nc_eff = <$t_dispatcher::<F> as GemmCache>::get_nc_eff(hw_config, par.jc_par);
1661 let kc_eff = <$t_dispatcher::<F> as GemmCache>::get_kc_eff(hw_config);
1662 let (pa_br_vec_ref, pb_br_vec_ref) = get_apbp_barrier(par);
1663
1664 let (i_load_par, j_load_par) = par.get_load_par(&gemm_mode, m, n, mc_eff, nc_eff);
1665 let (ap_pool_vec, bp_pool_vec) = pool_info.slice_mut_from_pool::<$tap,$tbp>(
1666 pool_buf, i_load_par, j_load_par, pool_info, hw_config.mr, hw_config.nr
1667 );
1668 let (ap_pool, bp_pool) = (&ap_pool_vec, &bp_pool_vec);
1669
1670 std::thread::scope(|s| {
1672 for t_id in 1..par.num_threads {
1673 let t_cfg = PireThreadConfig::new(
1674 par.clone(), &pa_br_vec_ref, &pb_br_vec_ref, t_id, mc_eff, nc_eff, kc_eff
1675 );
1676 let ic_id = t_cfg.ic_id;
1677 let jc_id = t_cfg.jc_id;
1678 let ap_id = match gemm_mode {
1679 GemmPool::Goto => ic_id,
1680 GemmPool::SmallM => ic_id,
1681 GemmPool::SmallN => t_id,
1682 };
1683 let bp_id = match gemm_mode {
1684 GemmPool::Goto => jc_id,
1685 GemmPool::SmallM => 0,
1686 GemmPool::SmallN => jc_id,
1687 };
1688 let ap_cur = a.$pack_fn(ap_pool, ap_id);
1689 let bp_cur = b.$pack_fn(bp_pool, bp_id);
1690 let g = hw_config;
1691 s.spawn(move || {
1692 let alpha = &alpha as *const $t_as;
1693 let beta = &beta as *const $t_bs;
1694 gemm_fn(g, m, n, k, alpha, ap_cur, bp_cur, beta, c, &t_cfg);
1695 }
1696 );
1697 }
1698 {
1699 let ap = a.$pack_fn(ap_pool, 0);
1700 let bp = b.$pack_fn(bp_pool, 0);
1701 let t_id: usize = 0;
1702 let t_cfg = PireThreadConfig::new(par.clone(), &pa_br_vec_ref, &pb_br_vec_ref, t_id, mc_eff, nc_eff, kc_eff);
1703 let alpha = &alpha as *const $t_as;
1704 let beta = &beta as *const $t_bs;
1705 gemm_fn(hw_config, m, n, k, alpha, ap, bp, beta, c, &t_cfg);
1706 }
1707 });
1708 }
1709
1710 unsafe fn $goto_name<F:UnaryFnC>(
1711 hw_cfg: &$t_dispatcher <F>,
1712 m: usize, n: usize, k: usize,
1713 alpha: *const $t_as,
1714 a: $packa_ty,
1715 b: $packb_ty,
1716 beta: *const $t_bs,
1717 c: ArrayMut<$tc>,
1718 t_cfg: &PireThreadConfig
1719 ) {
1720 let ic_id = t_cfg.ic_id;
1721 let jc_id = t_cfg.jc_id;
1722 let ir_id = t_cfg.ir_id;
1723 let jr_id = t_cfg.jr_id;
1724 let ir_par = t_cfg.par.ir_par;
1725 let jr_par = t_cfg.par.jr_par;
1726 let ic_par = t_cfg.par.ic_par;
1727 let jc_par = t_cfg.par.jc_par;
1728 let mc = t_cfg.mc_eff;
1729 let nc = t_cfg.nc_eff;
1730 let kc = t_cfg.kc_eff;
1731 let mr = hw_cfg.mr;
1732 let nr = hw_cfg.nr;
1733 let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, ic_par);
1734 let (nc_start, nc_end, nc_left) = split_c_range(n, nc, nr, jc_id, jc_par);
1735 let (kc_start, d1_end) = (0, k);
1736 let one = $one;
1737 let c_rs = c.rs();
1738 let c_cs = c.cs();
1739 let c_ptr = c.src();
1740 let mut mc_i = mc_start;
1741 while mc_i < mc_end {
1742 let mc_len = mc.min(mc_end - mc_i);
1743 let mut kc_i = kc_start;
1744 let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
1745 let mr_len = mr_end - mr_start;
1746 let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
1747 while kc_i < d1_end {
1748 let kc_len = kc.min(d1_end - kc_i);
1749 let kc_len_eff = hw_cfg.round_k(kc_len);
1750 let mut nc_i = nc_start;
1751 let kc_last = kc_i + kc_len == d1_end;
1752 let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
1753 let ap_data = $packa_name(hw_cfg, &a, mc_i, kc_i, mc_len, kc_len, t_cfg);
1754 let ap = ap_data.src();
1755 let ap = ap.add(mr_start*kc_len_eff);
1756 while nc_i < nc_end {
1757 let nc_len = nc.min(nc_end - nc_i);
1758 let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
1759 let nr_len = nr_end - nr_start;
1760 let c_ij = c_i.add((nc_i+nr_start) * c_cs);
1761 let bp_data = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1762 let bp = bp_data.src();
1763 let bp = bp.add(nr_start*kc_len_eff);
1764 $goto_kernel(
1765 hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t, c_ij, c_rs, c_cs,
1766 ap, bp,
1767 kc_last,
1768 );
1769
1770 nc_i += nc;
1771 }
1772 if nc_left {
1773 t_cfg.wait_packb();
1774 t_cfg.wait_packb();
1775 }
1776 kc_i += kc;
1777 }
1778 mc_i += mc;
1779 }
1780 if mc_left {
1781 let mut kc_i = kc_start;
1782 while kc_i < d1_end {
1783 let kc_len = kc.min(d1_end -kc_i);
1784 t_cfg.wait_packa();
1785 t_cfg.wait_packa();
1786 let mut nc_i = nc_start;
1787 while nc_i < nc_end {
1788 let nc_len = nc.min(nc_end - nc_i);
1789 let _ = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1790 nc_i += nc;
1791 }
1792 if nc_left{
1793 t_cfg.wait_packb();
1794 t_cfg.wait_packb();
1795 }
1796 kc_i += kc;
1797 }
1798 }
1799 }
1800 unsafe fn $small_m_name<F:UnaryFnC>(
1801 hw_cfg: &$t_dispatcher <F>,
1802 m: usize, n: usize, k: usize,
1803 alpha: *const $t_as,
1804 a: $packa_ty,
1805 b: $packb_ty,
1806 beta: *const $t_bs,
1807 c: ArrayMut<$tc>,
1808 t_cfg: &PireThreadConfig
1809 ) {
1810 let par = &t_cfg.par;
1811 let ic_id = t_cfg.ic_id;
1812 let jc_id = t_cfg.jc_id;
1813 let ir_id = t_cfg.ir_id;
1814 let ir_par = par.ir_par;
1815 let jr_id = t_cfg.jr_id;
1816 let jr_par = par.jr_par;
1817 let mc = t_cfg.mc_eff;
1818 let nc = t_cfg.nc_eff;
1819 let kc = t_cfg.kc_eff;
1820 let mr = hw_cfg.mr;
1821 let nr = hw_cfg.nr;
1822 let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, par.ic_par);
1823 let (nc_start, nc_end, _) = split_c_range(n, nc, nr, jc_id, par.jc_par);
1824 let (kc_start, kc_end) = (0, k);
1825 let one = $one;
1826
1827 let b_ptr = b.src();
1828 let b_rs = b.rs();
1829 let b_cs = b.cs();
1830 let c_rs = c.rs();
1831 let c_cs = c.cs();
1832 let c_ptr = c.src();
1833 let mut mc_i = mc_start;
1834 while mc_i < mc_end {
1835 let mc_len = mc.min(mc_end - mc_i);
1836 let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
1837 let mr_len = mr_end - mr_start;
1838 let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
1839 let mut kc_i = kc_start;
1840 while kc_i < kc_end {
1841 let kc_len = kc.min(kc_end - kc_i);
1842 let kc_len_eff = hw_cfg.round_k(kc_len);
1843 let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
1844 let kc_last = kc_i + kc_len == kc_end;
1845 let mut nc_i = nc_start;
1846 let ap_data = $packa_name(hw_cfg, &a, mc_i, kc_i, mc_len, kc_len, t_cfg);
1847 let ap = ap_data.src();
1848 let ap = ap.add(mr_start*kc_len_eff);
1849 let b_j = b_ptr.add(kc_i * b_rs);
1850 while nc_i < nc_end {
1851 let nc_len = nc.min(nc_end - nc_i);
1852 let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
1853 let nr_len = nr_end - nr_start;
1854 let c_ij = c_i.add((nc_i + nr_start) * c_cs);
1855 let b_cur = b_j.add((nc_i + nr_start) * b_cs);
1856 $small_m_kernel(
1857 hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t,
1858 b_cur, b_rs, b_cs,
1859 c_ij, c_rs, c_cs,
1860 ap,
1861 kc_last,
1862 );
1863 nc_i += nc;
1864 }
1865 kc_i += kc;
1866 }
1867 mc_i += mc;
1868 }
1869
1870 if mc_left {
1871 let mut kc_i = kc_start;
1872 while kc_i < kc_end {
1873 t_cfg.wait_packa();
1874 t_cfg.wait_packa();
1875 kc_i += kc;
1876 }
1877 }
1878 }
1879 unsafe fn $small_n_name<F:UnaryFnC>(
1880 hw_cfg: &$t_dispatcher <F>,
1881 m: usize, n: usize, k: usize,
1882 alpha: *const $t_as,
1883 a: $packa_ty,
1884 b: $packb_ty,
1885 beta: *const $t_bs,
1886 c: ArrayMut<$tc>,
1887 t_cfg: &PireThreadConfig
1888 ) {
1889 let par = &t_cfg.par;
1890 let ic_id = t_cfg.ic_id;
1891 let jc_id = t_cfg.jc_id;
1892 let ir_id = t_cfg.ir_id;
1893 let ir_par = par.ir_par;
1894 let jr_id = t_cfg.jr_id;
1895 let jr_par = par.jr_par;
1896 let mc = t_cfg.mc_eff;
1897 let nc = t_cfg.nc_eff;
1898 let kc = t_cfg.kc_eff;
1899 let mr = hw_cfg.mr;
1900 let nr = hw_cfg.nr;
1901 let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, par.ic_par);
1902 let (nc_start, nc_end, nc_left) = split_c_range(n, nc, nr, jc_id, par.jc_par);
1903 let (kc_start, kc_end) = (0, k);
1904 let one = $one;
1905
1906 let c_rs = c.rs();
1907 let c_cs = c.cs();
1908 let c_ptr = c.src();
1909 let a_ptr = a.src();
1910 let a_rs = a.rs();
1911 let a_cs = a.cs();
1912 let a_dst = a.dst_write(0, kc);
1914 let a_dst_ref = a_dst.get();
1915 let a_dst_ptr = a_dst_ref.as_mut_ptr();
1916 let mut mc_i = mc_start;
1917 while mc_i < mc_end {
1918 let mc_len = mc.min(mc_end - mc_i);
1919 let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
1920 let mr_len = mr_end - mr_start;
1921 let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
1922 let a_i = a_ptr.add((mc_i+mr_start) * a_rs);
1923 let mut kc_i = kc_start;
1924 while kc_i < kc_end {
1925 let kc_len = kc.min(kc_end - kc_i);
1926 let kc_last = kc_i + kc_len == kc_end;
1927 let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
1928 let a_cur = a_i.add(kc_i*a_cs);
1929 let mut nc_i = nc_start;
1930 while nc_i < nc_end {
1931 let nc_len = nc.min(nc_end - nc_i);
1932 let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
1933 let nr_len = nr_end - nr_start;
1934 let bp_data = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1935 let bp = bp_data.src();
1936 let c_ij = c_i.add((nc_i + nr_start) * c_cs);
1937 $small_n_kernel(
1938 hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t,
1939 a_cur, a_rs, a_cs,
1940 a_dst_ptr, bp,
1941 c_ij, c_rs, c_cs,
1942 kc_last,
1943 );
1944 nc_i += nc;
1945 }
1946 if nc_left {
1947 t_cfg.wait_packb();
1948 t_cfg.wait_packb();
1949 }
1950 kc_i += kc;
1951 }
1952 mc_i += mc;
1953 }
1954 if mc_left {
1955 let mut kc_i = kc_start;
1956 while kc_i < kc_end {
1957 let kc_len = kc.min(kc_end - kc_i);
1958 let mut nc_i = nc_start;
1959 while nc_i < nc_end {
1960 let nc_len = nc.min(nc_end - nc_i);
1961 let _ = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1962 nc_i += nc;
1963 }
1964 if nc_left{
1965 t_cfg.wait_packb();
1966 t_cfg.wait_packb();
1967 }
1968 kc_i += kc;
1969 }
1970 }
1971 }
1972 pub(crate) unsafe fn $packa_name<'a,'b,F:UnaryFnC>(hw_cfg: &$t_dispatcher <F>, x: &'b $packa_ty<'a>, mc_i: usize, kc_i: usize, mc_len: usize, kc_len: usize, t_cfg: &PireThreadConfig) -> PtrData<'a,$tap> {
1978 t_cfg.wait_packa();
1979 let xp_ptr = match x {
1980 $packa_ty::StridedMatrix(x_i) => {
1981 let mc_par = x_i.get_mc();
1982 let mc_offset = mc_par * t_cfg.i_load_p_idx;
1983 if mc_len > mc_offset {
1984 let kc_len_ro = hw_cfg.round_k(kc_len);
1985 let mc_len_x = (mc_len - mc_offset).min(mc_par);
1986 let mc_i = mc_i + mc_offset;
1987 let (rs, cs) = (x_i.rs(), x_i.cs());
1988 let src_ptr = x_i.src().add(mc_i*rs + kc_i*cs);
1989 let dst = x_i.dst_write(t_cfg.i_load_p_idx, kc_len_ro);
1990 let dst_ref = dst.get();
1991 let dst_ptr = dst_ref.as_mut_ptr();
1992 $packa_name0(src_ptr, dst_ptr, mc_len_x, kc_len, rs, cs);
1993 }
1994 t_cfg.wait_packa();
1995 PtrData::RefData(x_i.dst_read())
1996 }
1997 $packa_ty::PackedMatrix(x_i) => {
1998 let m_ro = hw_cfg.round_m(x_i.m());
1999 let kc_len_ro = hw_cfg.round_k(kc_len);
2000 let res = is_mixed!(
2001 $include_flag,
2002 {
2003 let mc_par = x_i.get_mc();
2004 let mc_offset = mc_par * t_cfg.i_load_p_idx;
2005 if mc_len > mc_offset {
2006 let mc_len_x = (mc_len - mc_offset).min(mc_par);
2007 let mc_i = mc_i + mc_offset;
2008 let src_ptr = x_i.src().add(mc_i*kc_len_ro + kc_i*m_ro);
2009 let dst = x_i.dst_write(t_cfg.i_load_p_idx, kc_len_ro);
2010 let dst_ref = dst.get();
2011 let dst_ptr = dst_ref.as_mut_ptr();
2012 let mc_len_x_ro = hw_cfg.round_m(mc_len_x);
2013 hw_cfg.cvt_mixed(src_ptr, dst_ptr, mc_len_x_ro*kc_len_ro);
2014 }
2015 t_cfg.wait_packa();
2016 PtrData::RefData(x_i.dst_read())
2017 },
2018 {
2019 let src_ptr = x_i.src().add(mc_i*kc_len_ro + kc_i*m_ro);
2020 t_cfg.wait_packa();
2021 PtrData::PtrData(src_ptr)
2022 }
2023
2024 );
2025 res
2026
2027 }
2028 };
2029 xp_ptr
2030 }
2031 pub(crate) unsafe fn $packb_name<'a,'b,F:UnaryFnC>(hw_cfg: & $t_dispatcher <F>, x: &'b$packb_ty<'a>, nc_i: usize, kc_i: usize, nc_len: usize, kc_len: usize, t_cfg: &PireThreadConfig) -> PtrData<'a,$tbp> {
2033 t_cfg.wait_packb();
2034 let xp_ptr = match x {
2035 $packb_ty::StridedMatrix(x_i) => {
2036 let nc_par = x_i.get_mc();
2037 let nc_offset = nc_par * t_cfg.j_load_p_idx;
2038 if nc_len > nc_offset {
2039 let kc_len_ro = hw_cfg.round_k(kc_len);
2040 let nc_len_x = (nc_len - nc_offset).min(nc_par);
2041 let nc_i = nc_i + nc_offset;
2042 let rs = x_i.rs();
2043 let cs = x_i.cs();
2044 let src_ptr = x_i.src().add(kc_i*rs + nc_i*cs);
2045 let dst = x_i.dst_write(t_cfg.j_load_p_idx, kc_len_ro);
2046 let dst_ref = dst.get();
2047 let dst_ptr = dst_ref.as_mut_ptr();
2048 $packb_name0(src_ptr, dst_ptr, nc_len_x, kc_len, rs, cs);
2049 }
2050 t_cfg.wait_packb();
2051 PtrData::RefData(x_i.dst_read())
2052 }
2053 $packb_ty::PackedMatrix(x_i) => {
2054 let kc_len_ro = hw_cfg.round_k(kc_len);
2055 let n_ro = x_i.m();
2056 let res = is_mixed!(
2057 $include_flag,
2058 {
2059 let nc_par = x_i.get_mc();
2060 let nc_offset = nc_par * t_cfg.j_load_p_idx;
2061 if nc_len > nc_offset {
2062 let nc_len_x = (nc_len - nc_offset).min(nc_par);
2063 let nc_i = nc_i + nc_offset;
2064 let src_ptr = x_i.src().add(nc_i*kc_len_ro + kc_i*n_ro);
2065 let dst = x_i.dst_write(t_cfg.j_load_p_idx, kc_len_ro);
2066 let dst_ref = dst.get();
2067 let dst_ptr = dst_ref.as_mut_ptr();
2068 hw_cfg.cvt_mixed(src_ptr, dst_ptr, nc_len_x*kc_len_ro);
2069 }
2070 t_cfg.wait_packb();
2071 PtrData::RefData(x_i.dst_read())
2072 },
2073 {
2074 let src_ptr = x_i.src().add(nc_i*kc_len_ro + kc_i*n_ro);
2075 t_cfg.wait_packb();
2076 PtrData::PtrData(src_ptr)
2077 }
2078
2079 );
2080 res
2081 }
2082 };
2083 xp_ptr
2084 }
2085 }
2086}
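// `partial_strided!` is a compile-time selector used by the kernel-definition macros
// below: the third token picks which of the first two flags survives expansion.
//   partial_strided!(STRIDED, STRIDED_PARTIAL, F) -> STRIDED
//   partial_strided!(STRIDED, STRIDED_PARTIAL, T) -> STRIDED_PARTIAL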
2087#[macro_export]
2088macro_rules! partial_strided {
2089 ($strided:tt, $strided2:tt, F) => {
2090 $strided
2091 };
2092 ($strided:tt, $strided2:tt, T) => {
2093 $strided2
2094 };
2095}
2096
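// `avx_vzeroupper` clears the upper halves of the ymm registers after AVX code has run,
// avoiding the SSE/AVX transition penalty when later SSE code touches the xmm registers.
// The intrinsic is issued only when the runtime feature detection (`RUNTIME_HW_CONFIG`)
// reports AVX support, so the safe wrapper can be called unconditionally.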
2097#[target_feature(enable = "avx")]
2098#[cfg(target_arch = "x86_64")]
2099unsafe fn vzeroupper_unchecked() {
2100 core::arch::x86_64::_mm256_zeroupper();
2101}
2102
2103pub fn avx_vzeroupper() {
2104 #[cfg(target_arch = "x86_64")]
2105 if (*RUNTIME_HW_CONFIG).cpu_ft.avx {
2106 unsafe {
2107 vzeroupper_unchecked();
2108 }
2109 }
2110}
2111
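// `def_kernel_bb_pf1!` generates the macro-kernel for the packed-A/packed-B ("bb") case.
// It tiles C into MR x NR blocks, calls the `ukernel_*` micro-kernels defined alongside
// it, and threads an A-prefetch offset through the loop: the offset starts at
// `$pf1_0 * k` and grows by `$pf_step * k` per NR-wide column block. A hypothetical
// invocation (illustrative values only, not taken from this crate) could look like:
//   def_kernel_bb_pf1!(f32, f32, f32, f32, F, 1, 3, 4, 128, 32);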
2112#[macro_export]
2113macro_rules! def_kernel_bb_pf1 {
2114 (
2115 $t_ap:ty, $t_bp:ty, $t_c:ty, $t_s:ty,
2116 $no_partial:tt,
2117 $RS:tt,
2118 $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt
2119 ) => {
2120
2121 pub unsafe fn kernel_bb<F: UnaryFnC, const STRIDED: bool>(
2122 m: usize, n: usize, k: usize,
2123 alpha: *const $t_s,
2124 beta: *const $t_s,
2125 c: *mut $t_c, c_rs: usize, c_cs: usize,
2126 ap: *const $t_ap, bp: *const $t_bp,
2127 f: F,
2128 ) {
2129 const STRIDED_PARTIAL: bool = true;
2130 const MR: usize = $MR * VS;
2131 const NR: usize = $NR;
2132 let m_rounded = m / MR * MR;
2133 let n_rounded = n / NR * NR;
2134 let m_left = m % MR;
2135 let n_left = n % NR;
2136
2137 let d_arr = [0, 0, c_rs];
2138
2139 let mut m_i = 0;
2140 while m_i < m_rounded {
2141 let c_cur0 = c.add(m_i * c_rs);
2142 let ap_cur = ap.add(m_i * k);
2143 let mut a_pft1_offset = $pf1_0 * k;
2144 let mut n_i = 0;
2145 while n_i < n_rounded {
2146 let bp_cur = bp.add(n_i * k);
2147 let c_cur1 = c_cur0.add(n_i * c_cs);
2148 ukernel_bbc::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, a_pft1_offset, NR, f);
2149 n_i += NR;
2150 a_pft1_offset += $pf_step * k;
2151 }
2152 if n_left != 0 {
2154 let bp_cur = bp.add(n_i * k);
2155 let c_cur1 = c_cur0.add(n_i * c_cs);
2156 ukernel_n_bbc::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, MR, n_left, f);
2157 }
2158 m_i += MR;
2159 }
2160
2161
2162 seq_macro::seq!(mr_left in 1..=$MR {
2163 if (m_left+VS-1) / VS == mr_left {
2164 let c_cur0 = c.add(m_i * c_rs);
2165 let ap_cur = ap.add(m_i * k);
2166 let mut n_i = 0;
2167 while n_i < n_rounded {
2168 let bp_cur = bp.add(n_i * k);
2169 let c_cur1 = c_cur0.add(n_i * c_cs);
2170 paste::paste! {
2171 [<ukernel_ mr_left _bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, NR, f);
2172 }
2173 n_i += NR;
2174 }
2175 if n_left != 0 {
2176 let bp_cur = bp.add(n_i * k);
2177 let c_cur1 = c_cur0.add(n_i * c_cs);
2178 paste::paste! {
2179 [<ukernel_ mr_left xn_bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, n_left, f);
2180 }
2181 }
2182 }
2183 });
2184 }
2185
2186 pub(crate) unsafe fn kernel<F: UnaryFnC>(
2187 m: usize,
2188 n: usize,
2189 k: usize,
2190 alpha: *const $t_s,
2191 beta: *const $t_s,
2192 c: *mut $t_c,
2193 c_rs: usize,
2194 c_cs: usize,
2195 ap: *const $t_ap,
2196 bp: *const $t_bp,
2197 f: F,
2198 ) {
2199 let k = (k + $RS - 1) / $RS * $RS;
2200 if c_rs == 1 {
2201 kernel_bb::<_, false>(m, n, k, alpha, beta, c, c_rs, c_cs, ap, bp, f)
2202 } else {
2203 kernel_bb::<_, true>(m, n, k, alpha, beta, c, c_rs, c_cs, ap, bp, f)
2204 }
2205 pire_base::avx_vzeroupper();
2206 }
2207
2208 };
2209}
2210
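// `def_kernel_bb_v0!` is the prefetch-free counterpart of `def_kernel_bb_pf1!`: the same
// packed-A/packed-B tiling, but MR is computed from the runtime `simd_vector_length()`
// instead of the compile-time `VS` constant, and the micro-kernel receives `mr` where
// the pf1 variant passes the A-prefetch offset.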
2211#[macro_export]
2212macro_rules! def_kernel_bb_v0 {
2213 (
2214 $t_ap:ty, $t_bp:ty, $t_c:ty, $t_s:ty,
2215 $no_partial:tt,
2216 $RS:tt,
2217 $MR:tt, $NR:tt
2218 ) => {
2219 pub unsafe fn kernel_bb<F: UnaryFnC, const STRIDED: bool>(
2220 m: usize, n: usize, k: usize,
2221 alpha: *const $t_s,
2222 beta: *const $t_s,
2223 c: *mut $t_c, c_rs: usize, c_cs: usize,
2224 ap: *const $t_ap, bp: *const $t_bp,
2225 f: F,
2226 ) {
2227 let vs = simd_vector_length();
2228 const STRIDED_PARTIAL: bool = true;
2229 let mr = $MR * vs;
2230 const NR: usize = $NR;
2231 let m_rounded = m / mr * mr;
2232 let n_rounded = n / NR * NR;
2233 let m_left = m % mr;
2234 let n_left = n % NR;
2235
2236 let d_arr = [0, 0, c_rs];
2237
2238 let mut m_i = 0;
2239 while m_i < m_rounded {
2240 let c_cur0 = c.add(m_i * c_rs);
2241 let ap_cur = ap.add(m_i * k);
2242 let mut n_i = 0;
2243 while n_i < n_rounded {
2244 let bp_cur = bp.add(n_i * k);
2245 let c_cur1 = c_cur0.add(n_i * c_cs);
2246 ukernel_bbc::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, mr, NR, f);
2247 n_i += NR;
2248 }
2249 if n_left != 0 {
2250 let bp_cur = bp.add(n_i * k);
2251 let c_cur1 = c_cur0.add(n_i * c_cs);
2252 ukernel_n_bbc::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, mr, n_left, f);
2253 }
2254 m_i += mr;
2255 }
2256
2257 seq_macro::seq!(mr_left in 1..=$MR {
2258 if (m_left+vs-1) / vs == mr_left {
2259 let c_cur0 = c.add(m_i * c_rs);
2260 let ap_cur = ap.add(m_i * k);
2261 let mut n_i = 0;
2262 while n_i < n_rounded {
2263 let bp_cur = bp.add(n_i * k);
2264 let c_cur1 = c_cur0.add(n_i * c_cs);
2265 paste::paste! {
2266 [<ukernel_ mr_left _bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, NR, f);
2267 }
2268 n_i += NR;
2269 }
2270 if n_left != 0 {
2271 let bp_cur = bp.add(n_i * k);
2272 let c_cur1 = c_cur0.add(n_i * c_cs);
2273 paste::paste! {
2274 [<ukernel_ mr_left xn_bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, n_left, f);
2275 }
2276 }
2277 }
2278 });
2279 }
2280 pub(crate) unsafe fn kernel<F: UnaryFnC>(
2281 m: usize,
2282 n: usize,
2283 k: usize,
2284 alpha: *const $t_s,
2285 beta: *const $t_s,
2286 c: *mut $t_c,
2287 c_rs: usize,
2288 c_cs: usize,
2289 ap: *const $t_ap,
2290 bp: *const $t_bp,
2291 f: F,
2292 ) {
2293 let k = (k + $RS - 1) / $RS * $RS;
2294 if c_rs == 1 {
2295 kernel_bb::<_, false>(m, n, k, alpha, beta, c, c_rs, c_cs, ap, bp, f)
2296 } else {
2297 kernel_bb::<_, true>(m, n, k, alpha, beta, c, c_rs, c_cs, ap, bp, f)
2298 }
2299 pire_base::avx_vzeroupper();
2300 }
2301 };
2302}
2303
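// The "sb" kernels take A in strided form and B already packed. `$pack_fn` packs one
// MR-wide panel of A into the `ap` scratch buffer right before the column loop, so the
// micro-kernels still consume a contiguous A panel; k is rounded up to `$RS` (`k_eff`)
// when indexing the packed B.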
2304#[macro_export]
2305macro_rules! def_kernel_sb_pf1 {
2306 (
2307 $t_a:ty, $t_ap:ty, $t_bp:ty, $t_c:ty, $t_s:ty,
2308 $pack_fn:tt,
2309 $RS:tt,
2310 $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt
2311 ) => {
2312 pub unsafe fn kernel_sb_v0<F: UnaryFnC, const STRIDED: bool>(
2313 m: usize, n: usize, k: usize,
2314 alpha: *const $t_s, beta: *const $t_s,
2315 a: *const $t_a, a_rs: usize, a_cs: usize,
2316 bp: *const $t_bp,
2317 c: *mut $t_c, c_rs: usize, c_cs: usize,
2318 ap: *mut $t_ap,
2319 f: F,
2320 ) {
2321 let k_eff = (k+$RS-1) / $RS * $RS;
2322 const MR: usize = $MR * VS;
2323 const NR: usize = $NR;
2324 let m_rounded = m / MR * MR;
2325 let n_rounded = n / NR * NR;
2326 let m_left = m % MR;
2327 let n_left = n % NR;
2328
2329 let d_arr = [0, 0, c_rs];
2330
2331 let mut m_i = 0;
2332 while m_i < m_rounded {
2333 let c_cur0 = c.add(m_i * c_rs);
2334 let a_cur = a.add(m_i * a_rs);
2335 let a_pft1_offset = $pf1_0 * k;
2336 $pack_fn(MR, k, a_cur, a_rs, a_cs, ap, VS);
2337 let mut n_i = 0;
2338 while n_i < n_rounded {
2339 let bp_cur = bp.add(n_i * k_eff);
2340 let c_cur1 = c_cur0.add(n_i * c_cs);
2341 ukernel_bbc::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, a_pft1_offset, NR, f);
2342 n_i += NR;
2343 }
2344 if n_left != 0 {
2345 let bp_cur = bp.add(n_i * k_eff);
2346 let c_cur1 = c_cur0.add(n_i * c_cs);
2347 ukernel_n_bbc::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, MR, n_left, f);
2348 }
2349 m_i += MR;
2350 }
2351
2352 seq_macro::seq!(mr_left in 1..=$MR {
2353 if (m_left+VS-1) / VS == mr_left {
2354 let c_cur0 = c.add(m_i * c_rs);
2355 let a_cur = a.add(m_i * a_rs);
2356 $pack_fn(m_left, k, a_cur, a_rs, a_cs, ap, VS);
2357 let mut n_i = 0;
2358 while n_i < n_rounded {
2359 let bp_cur = bp.add(n_i * k_eff);
2360 let c_cur1 = c_cur0.add(n_i * c_cs);
2361 paste::paste! {
2362 [<ukernel_ mr_left _bbp>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, m_left, NR, f);
2363 }
2364 n_i += NR;
2365 }
2366 if n_left != 0 {
2367 let bp_cur = bp.add(n_i * k_eff);
2368 let c_cur1 = c_cur0.add(n_i * c_cs);
2369 paste::paste! {
2370 [<ukernel_ mr_left xn_bbp>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, m_left, n_left, f);
2371 }
2372 }
2373 return;
2374 }
2375 });
2376 }
2377
2378 pub(crate) unsafe fn kernel_sb<F: UnaryFnC>(
2379 m: usize,
2380 n: usize,
2381 k: usize,
2382 alpha: *const $t_s,
2383 beta: *const $t_s,
2384 a: *const $t_a,
2385 a_rs: usize,
2386 a_cs: usize,
2387 b: *const $t_bp,
2388 c: *mut $t_c,
2389 c_rs: usize,
2390 c_cs: usize,
2391 ap_buf: *mut $t_ap,
2392 f: F,
2393 ) {
2394 if c_rs == 1 {
2395 kernel_sb_v0::<_, false>(m, n, k, alpha, beta, a, a_rs, a_cs, b, c, c_rs, c_cs, ap_buf, f);
2396 } else {
2397 kernel_sb_v0::<_, true>(m, n, k, alpha, beta, a, a_rs, a_cs, b, c, c_rs, c_cs, ap_buf, f);
2398 }
2399 pire_base::avx_vzeroupper();
2400 }
2401 };
2402}
2403
2404#[macro_export]
2405macro_rules! def_kernel_sb_v0 {
2406 (
2407 $t_a:ty, $t_ap:ty, $t_bp:ty, $t_c:ty, $t_s:ty,
2408 $no_partial:tt,
2409 $pack_fn:tt,
2410 $RS:tt,
2411 $MR:tt, $NR:tt
2412 ) => {
2413 pub unsafe fn kernel_sb_v0<F: UnaryFnC, const STRIDED: bool>(
2414 m: usize, n: usize, k: usize,
2415 alpha: *const $t_s, beta: *const $t_s,
2416 a: *const $t_a, a_rs: usize, a_cs: usize,
2417 bp: *const $t_bp,
2418 c: *mut $t_c, c_rs: usize, c_cs: usize,
2419 ap: *mut $t_ap,
2420 f: F,
2421 ) {
2422 let k_eff = (k+$RS-1) / $RS * $RS;
2423 const STRIDED_PARTIAL: bool = true;
2424 let vs = simd_vector_length();
2425 let mr = $MR * vs;
2426 const NR: usize = $NR;
2427 let m_rounded = m / mr * mr;
2428 let n_rounded = n / NR * NR;
2429 let m_left = m % mr;
2430 let n_left = n % NR;
2431
2432 let d_arr = [0, 0, c_rs];
2433
2434 let mut m_i = 0;
2435 while m_i < m_rounded {
2436 let c_cur0 = c.add(m_i * c_rs);
2437 let a_cur = a.add(m_i * a_rs);
2438 $pack_fn(mr, k, a_cur, a_rs, a_cs, ap, vs);
2439 let mut n_i = 0;
2440 while n_i < n_rounded {
2441 let bp_cur = bp.add(n_i * k_eff);
2442 let c_cur1 = c_cur0.add(n_i * c_cs);
2443 ukernel_bbc::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, mr, NR, f);
2444 n_i += NR;
2445 }
2446 if n_left != 0 {
2447 let bp_cur = bp.add(n_i * k_eff);
2448 let c_cur1 = c_cur0.add(n_i * c_cs);
2449 ukernel_n_bbc::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, mr, n_left, f);
2450 }
2451 m_i += mr;
2452 }
2453
2454 seq_macro::seq!(mr_left in 1..=$MR {
2455 if (m_left+vs-1) / vs == mr_left {
2456 let c_cur0 = c.add(m_i * c_rs);
2457 let a_cur = a.add(m_i * a_rs);
2458 $pack_fn(m_left, k, a_cur, a_rs, a_cs, ap, vs);
2459 let mut n_i = 0;
2460 while n_i < n_rounded {
2461 let bp_cur = bp.add(n_i * k_eff);
2462 let c_cur1 = c_cur0.add(n_i * c_cs);
2463 paste::paste! {
2464 [<ukernel_ mr_left _bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, m_left, NR, f);
2465 }
2466 n_i += NR;
2467 }
2468 if n_left != 0 {
2469 let bp_cur = bp.add(n_i * k_eff);
2470 let c_cur1 = c_cur0.add(n_i * c_cs);
2471 paste::paste! {
2472 [<ukernel_ mr_left xn_bbp>]::<_, pire_base::partial_strided!(STRIDED,STRIDED_PARTIAL,$no_partial)>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, c_cs, m_left, n_left, f);
2473 }
2474 }
2475 return;
2476 }
2477 });
2478 }
2479
2480 pub(crate) unsafe fn kernel_sb<F: UnaryFnC>(
2481 m: usize,
2482 n: usize,
2483 k: usize,
2484 alpha: *const $t_s,
2485 beta: *const $t_s,
2486 a: *const $t_a,
2487 a_rs: usize,
2488 a_cs: usize,
2489 b: *const $t_bp,
2490 c: *mut $t_c,
2491 c_rs: usize,
2492 c_cs: usize,
2493 ap_buf: *mut $t_ap,
2494 f: F,
2495 ) {
2496 if c_rs == 1 {
2497 kernel_sb_v0::<_, false>(m, n, k, alpha, beta, a, a_rs, a_cs, b, c, c_rs, c_cs, ap_buf, f);
2498 } else {
2499 kernel_sb_v0::<_, true>(m, n, k, alpha, beta, a, a_rs, a_cs, b, c, c_rs, c_cs, ap_buf, f);
2500 }
2501 pire_base::avx_vzeroupper();
2502 }
2503 };
2504}
2505
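// `def_kernel_bs!` covers the opposite mix: A is packed, B stays strided. The B strides
// are forwarded to the micro-kernels through `d_arr` ([b_rs, b_cs, c_rs]) and the
// `ukernel_*_bs*` variants address B directly with those strides.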
2506#[macro_export]
2507macro_rules! def_kernel_bs {
2508 (
2509 $t_ap:ty, $t_b:ty, $t_c:ty, $t_s:ty,
2510 $MR:tt, $NR:tt
2511 ) => {
2512 pub unsafe fn kernel_bs_v0<F: UnaryFnC, const STRIDED: bool>(
2513 m: usize, n: usize, k: usize,
2514 alpha: *const $t_s, beta: *const $t_s,
2515 b: *const $t_b, b_rs: usize, b_cs: usize,
2516 c: *mut $t_c, c_rs: usize, c_cs: usize,
2517 ap: *const $t_ap,
2518 f: F,
2519 ) {
2520 const MR: usize = $MR * VS;
2521 const NR: usize = $NR;
2522 let m_rounded = m / MR * MR;
2523 let n_rounded = n / NR * NR;
2524 let m_left = m % MR;
2525 let n_left = n % NR;
2526
2527 let d_arr = [b_rs, b_cs, c_rs];
2528
2529 let mut m_i = 0;
2530 while m_i < m_rounded {
2531 let c_cur0 = c.add(m_i * c_rs);
2532 let ap_cur = ap.add(m_i * k);
2533 let mut n_i = 0;
2534 while n_i < n_rounded {
2535 let b_cur = b.add(n_i * b_cs);
2536 let c_cur1 = c_cur0.add(n_i * c_cs);
2537 ukernel_bsc::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, c_cs, MR, NR, f);
2538 n_i += NR;
2539 }
2540 if n_left != 0 {
2541 let b_cur = b.add(n_i * b_cs);
2542 let c_cur1 = c_cur0.add(n_i * c_cs);
2543 ukernel_n_bsc::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, c_cs, MR, n_left, f);
2544 }
2545 m_i += MR;
2546 }
2547 seq_macro::seq!(mr_left in 1..=$MR {
2548 if (m_left+VS-1) / VS == mr_left {
2549 let c_cur0 = c.add(m_i * c_rs);
2550 let ap_cur = ap.add(m_i * k);
2551 let mut n_i = 0;
2552 while n_i < n_rounded {
2553 let b_cur = b.add(n_i * b_cs);
2554 let c_cur1 = c_cur0.add(n_i * c_cs);
2555 paste::paste! {
2556 [<ukernel_ mr_left _bsp>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, NR, f);
2557 }
2558 n_i += NR;
2559 }
2560 if n_left != 0 {
2561 let b_cur = b.add(n_i * b_cs);
2562 let c_cur1 = c_cur0.add(n_i * c_cs);
2563 paste::paste! {
2564 [<ukernel_ mr_left xn_bsp>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, c_cs, m_left, n_left, f);
2565 }
2566 }
2567 return;
2568 }
2569 });
2570 }
2571
2572 pub(crate) unsafe fn kernel_bs<F: UnaryFnC>(
2573 m: usize,
2574 n: usize,
2575 k: usize,
2576 alpha: *const $t_s,
2577 beta: *const $t_s,
2578 b: *const $t_b,
2579 b_rs: usize,
2580 b_cs: usize,
2581 c: *mut $t_c,
2582 c_rs: usize,
2583 c_cs: usize,
2584 ap: *const $t_ap,
2585 f: F,
2586 ) {
2587 if c_rs == 1 {
2588 kernel_bs_v0::<_, false>(m, n, k, alpha, beta, b, b_rs, b_cs, c, c_rs, c_cs, ap, f);
2589 } else {
2590 kernel_bs_v0::<_, true>(m, n, k, alpha, beta, b, b_rs, b_cs, c, c_rs, c_cs, ap, f);
2591 }
2592 pire_base::avx_vzeroupper();
2593 }
2594 };
2595}
2596
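// `mem!` builds an address-operand string for the inline-asm kernels. On x86/x86_64 it
// prepends a byte offset to an existing AT&T operand, e.g. mem!("0({cx})", "0x20")
// expands to "0x20+0({cx})". On aarch64 it produces the bracketed form, e.g.
// mem!("{cx}", "0x20") expands to "[{cx}, #0x20]".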
2597#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2598#[macro_export]
2599macro_rules! mem {
2600 ($m0:tt, $b0:tt) => {
2601 concat!($b0, "+", $m0)
2602 };
2603}
2604
2605#[cfg(target_arch = "aarch64")]
2606#[macro_export]
2607macro_rules! mem {
2608 ($m0:tt, $b0:tt, $b1:tt) => {
2609 concat!("[", $m0, ", #", $b0, ", ", $b1, "]")
2610 };
2611 ($m0:tt, $b0:tt) => {
2612 concat!("[", $m0, ", #", $b0, "]")
2613 };
2614 ($m0:tt) => {
2615 concat!("[", $m0, "]")
2616 };
2617}
2618
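// `n_cond!` gates which monomorphized n-branch runs inside the generated micro-kernels:
// when the loop lower bound is 1 (the "exact-n" kernels) a branch is taken only if the
// unrolled ni equals the runtime n; for any other lower bound the condition is always
// true and the first matching branch wins via the labelled break.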
2619#[macro_export]
2620macro_rules! n_cond {
2621 (1, $ni:tt, $nr:tt) => {
2622 $ni == $nr
2623 };
2624 ($n0:tt, $ni:tt, $nr:tt) => {
2625 true
2626 };
2627}
2628
2629#[macro_export]
2630macro_rules! load_a_avx {
2631 ($mr:tt, $K:tt) => {
2632 pire_base::loadp_avx!($mr, concat!($mr, "*32*", $K, "({ax})"))
2633 };
2634}
2635#[macro_export]
2636macro_rules! load_a_avx512 {
2637 ($mr:tt) => {
2638 pire_base::loadp_avx512!($mr, "0({ax})")
2639 };
2640}
2641
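// `init_ab!` emits the k-loop preamble. For packed B ("B") it only loads the unrolled
// loop count from offset 24 of dim_arr; the leading `/* {xN} */` comments just reference
// the otherwise-unused named operands. For strided B ("S") it additionally loads
// b_rs/b_cs (in bytes) and precomputes column base pointers spaced three columns apart,
// which `b_reg!` combines with {x2} to address up to twelve B columns.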
2642#[macro_export]
2652macro_rules! init_ab {
2653 (B) => {
2654 concat!(
2655 "/* {x5} */\n",
2656 "/* {x4} */\n",
2657 "/* {x3} */\n",
2658 "/* {x2} */\n",
2659 "/* {x1} */\n",
2660 "mov 24({dim_arrx}),{x0}\n",
2661 )
2662 };
2663 (S) => {
2664 concat!(
2665 "mov ({dim_arrx}), {x1}\n",
2667 "mov 8({dim_arrx}), {x2}\n",
2668 "lea ({x2}, {x2}, 2), {x5}\n",
2669 "lea ({bx}, {x5}, 1), {x3}\n",
2670 "lea ({x3}, {x5}, 1), {x4}\n",
2671 "lea ({x4}, {x5}, 1), {x5}\n",
2672 "mov 24({dim_arrx}),{x0}\n",
2673 )
2674 };
2675}
2676
2677#[macro_export]
2678macro_rules! init_ab_2 {
2679 (B) => {
2680 concat!("mov 8({dim_arrx}),{x0}\n",)
2681 };
2682}
2683
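// `c_load!` loads the C column stride (in bytes) from offset 16 of dim_arr into {x0} and
// precomputes base pointers {x1}..{x3} three columns apart, so `c_mem!` can address each
// column of the C tile with at most one scaled index per operand.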
2684#[macro_export]
2685macro_rules! c_load {
2686 () => {
2687 concat!(
2688 "mov 16({dim_arrx}),{x0}\n",
2689 "lea ({x0}, {x0}, 2), {x3}\n",
2690 "lea ({cx}, {x3},), {x1}\n",
2691 "lea ({x1}, {x3},), {x2}\n",
2692 "lea ({x2}, {x3},), {x3}\n",
2693 )
2694 };
2695}
2696
2697#[macro_export]
2698macro_rules! init_ab_avx {
2699 (B) => {
2700 concat!("/* {x3} */\n", "/* {x2} */\n", "/* {x1} */\n", "mov 24({dim_arrx}),{x0}\n",)
2701 };
2702 (S) => {
2703 concat!(
2704 "mov ({dim_arrx}), {x1}\n",
2706 "mov 8({dim_arrx}), {x2}\n",
2707 "lea ({x2}, {x2}, 2), {x3}\n",
2708 "lea ({bx}, {x3}, 1), {x3}\n",
2709 "mov 24({dim_arrx}),{x0}\n",
2710 )
2711 };
2712}
2713
2714#[macro_export]
2715macro_rules! b_reg {
2716 (0) => {
2717 "({bx})"
2718 };
2719 (1) => {
2720 "({bx},{x2},1)"
2721 };
2722 (2) => {
2723 "({bx},{x2},2)"
2724 };
2725 (3) => {
2726 "({x3})"
2727 };
2728 (4) => {
2729 "({x3},{x2},1)"
2730 };
2731 (5) => {
2732 "({x3},{x2},2)"
2733 };
2734 (6) => {
2735 "({x4})"
2736 };
2737 (7) => {
2738 "({x4},{x2},1)"
2739 };
2740 (8) => {
2741 "({x4},{x2},2)"
2742 };
2743 (9) => {
2744 "({x5})"
2745 };
2746 (10) => {
2747 "({x5},{x2},1)"
2748 };
2749 (11) => {
2750 "({x5},{x2},2)"
2751 };
2752}
2753
2754#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
2755#[macro_export]
2756macro_rules! c_mem {
2757 (0) => {
2758 "0({cx})"
2759 };
2760 (1) => {
2761 "0({cx}, {x0})"
2762 };
2763 (2) => {
2764 "0({cx}, {x0}, 2)"
2765 };
2766 (3) => {
2767 "0({x1})"
2768 };
2769 (4) => {
2770 "0({x1}, {x0})"
2771 };
2772 (5) => {
2773 "0({x1}, {x0}, 2)"
2774 };
2775 (6) => {
2776 "0({x2})"
2777 };
2778 (7) => {
2779 "0({x2}, {x0})"
2780 };
2781 (8) => {
2782 "0({x2}, {x0}, 2)"
2783 };
2784 (9) => {
2785 "0({x3})"
2786 };
2787 (10) => {
2788 "0({x3}, {x0})"
2789 };
2790 (11) => {
2791 "0({x3}, {x0}, 2)"
2792 };
2793 (12) => {
2794 "0({x4})"
2795 };
2796 (13) => {
2797 "0({x4}, {x0})"
2798 };
2799 (14) => {
2800 "0({x4}, {x0}, 2)"
2801 };
2802}
2803
2804#[cfg(target_arch = "aarch64")]
2805#[macro_export]
2806macro_rules! c_mem {
2807 (0) => {
2808 "{cx}"
2809 };
2810 (1) => {
2811 "{x1}"
2812 };
2813 (2) => {
2814 "{x2}"
2815 };
2816 (3) => {
2817 "{x3}"
2818 };
2819 (4) => {
2820 "{x4}"
2821 };
2822 (5) => {
2823 "{x5}"
2824 };
2825 (6) => {
2826 "{x6}"
2827 };
2828 (7) => {
2829 "{x7}"
2830 };
2831 (8) => {
2832 "{x8}"
2833 };
2834 (9) => {
2835 "{x9}"
2836 };
2837 (10) => {
2838 "{x10}"
2839 };
2840 (11) => {
2841 "{x11}"
2842 };
2843 (12) => {
2844 "{x12}"
2845 };
2846 (13) => {
2847 "{x13}"
2848 };
2849 (14) => {
2850 "{x14}"
2851 };
2852}
2853
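// The `c_reg_MxN!` tables fix the register allocation of the accumulator tile: the pair
// (i, j) names the SIMD register holding the i-th vector row of column j. The AVX
// kernels keep accumulators in ymm4..ymm15 (the low registers are reserved for the A
// panel and broadcast B values), while the AVX-512 kernels use zmm8..zmm31.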
2854#[macro_export]
2855macro_rules! c_reg_2x4 {
2856 (0,0) => {
2857 4
2858 };
2859 (1,0) => {
2860 5
2861 };
2862 (0,1) => {
2863 6
2864 };
2865 (1,1) => {
2866 7
2867 };
2868 (0,2) => {
2869 8
2870 };
2871 (1,2) => {
2872 9
2873 };
2874 (0,3) => {
2875 10
2876 };
2877 (1,3) => {
2878 11
2879 };
2880}
2881#[macro_export]
2882macro_rules! c_reg_1x4 {
2883 (0,0) => {
2884 7
2885 };
2886 (0,1) => {
2887 8
2888 };
2889 (0,2) => {
2890 9
2891 };
2892 (0,3) => {
2893 10
2894 };
2895}
2896#[macro_export]
2897macro_rules! c_reg_3x4 {
2898 (0,0) => {
2899 4
2900 };
2901 (1,0) => {
2902 5
2903 };
2904 (2,0) => {
2905 6
2906 };
2907 (0,1) => {
2908 7
2909 };
2910 (1,1) => {
2911 8
2912 };
2913 (2,1) => {
2914 9
2915 };
2916 (0,2) => {
2917 10
2918 };
2919 (1,2) => {
2920 11
2921 };
2922 (2,2) => {
2923 12
2924 };
2925 (0,3) => {
2926 13
2927 };
2928 (1,3) => {
2929 14
2930 };
2931 (2,3) => {
2932 15
2933 };
2934}
2935#[macro_export]
2936macro_rules! c_reg_2x6 {
2937 (0,0) => {
2938 4
2939 };
2940 (1,0) => {
2941 5
2942 };
2943 (0,1) => {
2944 6
2945 };
2946 (1,1) => {
2947 7
2948 };
2949 (0,2) => {
2950 8
2951 };
2952 (1,2) => {
2953 9
2954 };
2955 (0,3) => {
2956 10
2957 };
2958 (1,3) => {
2959 11
2960 };
2961 (0,4) => {
2962 12
2963 };
2964 (1,4) => {
2965 13
2966 };
2967 (0,5) => {
2968 14
2969 };
2970 (1,5) => {
2971 15
2972 };
2973}
2974#[macro_export]
2975macro_rules! c_reg_1x6 {
2976 (0,0) => {
2977 7
2978 };
2979 (0,1) => {
2980 8
2981 };
2982 (0,2) => {
2983 9
2984 };
2985 (0,3) => {
2986 10
2987 };
2988 (0,4) => {
2989 11
2990 };
2991 (0,5) => {
2992 12
2993 };
2994}
2995#[macro_export]
2996macro_rules! c_reg_3x8 {
2997 (0,0) => {
2998 8
2999 };
3000 (1,0) => {
3001 9
3002 };
3003 (2,0) => {
3004 10
3005 };
3006 (0,1) => {
3007 11
3008 };
3009 (1,1) => {
3010 12
3011 };
3012 (2,1) => {
3013 13
3014 };
3015 (0,2) => {
3016 14
3017 };
3018 (1,2) => {
3019 15
3020 };
3021 (2,2) => {
3022 16
3023 };
3024 (0,3) => {
3025 17
3026 };
3027 (1,3) => {
3028 18
3029 };
3030 (2,3) => {
3031 19
3032 };
3033 (0,4) => {
3034 20
3035 };
3036 (1,4) => {
3037 21
3038 };
3039 (2,4) => {
3040 22
3041 };
3042 (0,5) => {
3043 23
3044 };
3045 (1,5) => {
3046 24
3047 };
3048 (2,5) => {
3049 25
3050 };
3051 (0,6) => {
3052 26
3053 };
3054 (1,6) => {
3055 27
3056 };
3057 (2,6) => {
3058 28
3059 };
3060 (0,7) => {
3061 29
3062 };
3063 (1,7) => {
3064 30
3065 };
3066 (2,7) => {
3067 31
3068 };
3069}
3070#[macro_export]
3071macro_rules! c_reg_2x12 {
3072 (0,0) => {
3073 8
3074 };
3075 (1,0) => {
3076 9
3077 };
3078 (0,1) => {
3079 10
3080 };
3081 (1,1) => {
3082 11
3083 };
3084 (0,2) => {
3085 12
3086 };
3087 (1,2) => {
3088 13
3089 };
3090 (0,3) => {
3091 14
3092 };
3093 (1,3) => {
3094 15
3095 };
3096 (0,4) => {
3097 16
3098 };
3099 (1,4) => {
3100 17
3101 };
3102 (0,5) => {
3103 18
3104 };
3105 (1,5) => {
3106 19
3107 };
3108 (0,6) => {
3109 20
3110 };
3111 (1,6) => {
3112 21
3113 };
3114 (0,7) => {
3115 22
3116 };
3117 (1,7) => {
3118 23
3119 };
3120 (0,8) => {
3121 24
3122 };
3123 (1,8) => {
3124 25
3125 };
3126 (0,9) => {
3127 26
3128 };
3129 (1,9) => {
3130 27
3131 };
3132 (0,10) => {
3133 28
3134 };
3135 (1,10) => {
3136 29
3137 };
3138 (0,11) => {
3139 30
3140 };
3141 (1,11) => {
3142 31
3143 };
3144}
3145#[macro_export]
3146macro_rules! c_reg_1x12 {
3147 (0,0) => {
3148 9
3149 };
3150 (0,1) => {
3151 10
3152 };
3153 (0,2) => {
3154 11
3155 };
3156 (0,3) => {
3157 12
3158 };
3159 (0,4) => {
3160 13
3161 };
3162 (0,5) => {
3163 14
3164 };
3165 (0,6) => {
3166 15
3167 };
3168 (0,7) => {
3169 16
3170 };
3171 (0,8) => {
3172 17
3173 };
3174 (0,9) => {
3175 18
3176 };
3177 (0,10) => {
3178 19
3179 };
3180 (0,11) => {
3181 20
3182 };
3183}
3184
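// The `acc_MxN!` / `store_MxN!` wrappers below tie the pieces together: for column `$ni`
// they look up the accumulator registers via the matching `c_reg_MxN!` table and the
// memory operand via `c_mem!`, then delegate to `acc_p_avx*!` (fold the existing C
// values, scaled by beta, into the accumulators) or `storep_avx*!` (write back), with
// `$layout` selecting the full or partial/masked path for the last vector row.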
3185#[macro_export]
3186macro_rules! acc_3x4 {
3187 ($ni:tt, $layout:tt, $q:tt) => {
3188 acc_p_avx!($layout, c_mem!($ni), $q, c_reg_3x4!(0, $ni), c_reg_3x4!(1, $ni), c_reg_3x4!(2, $ni))
3189 };
3190}
3191#[macro_export]
3192macro_rules! store_3x4 {
3193 ($ni:tt, $layout:tt) => {
3194 storep_avx!($layout, c_mem!($ni), c_reg_3x4!(0, $ni), c_reg_3x4!(1, $ni), c_reg_3x4!(2, $ni))
3195 };
3196}
3197#[macro_export]
3198macro_rules! acc_2x6 {
3199 ($ni:tt, $layout:tt, $q:tt) => {
3200 acc_p_avx!($layout, c_mem!($ni), $q, c_reg_2x6!(0, $ni), c_reg_2x6!(1, $ni))
3201 };
3202}
3203#[macro_export]
3204macro_rules! store_2x6 {
3205 ($ni:tt, $layout:tt) => {
3206 storep_avx!($layout, c_mem!($ni), c_reg_2x6!(0, $ni), c_reg_2x6!(1, $ni))
3207 };
3208}
3209#[macro_export]
3210macro_rules! acc_1x6 {
3211 ($ni:tt, $layout:tt, $q:tt) => {
3212 acc_p_avx!($layout, c_mem!($ni), $q, c_reg_1x6!(0, $ni))
3213 };
3214}
3215#[macro_export]
3216macro_rules! store_1x6 {
3217 ($ni:tt, $layout:tt) => {
3218 storep_avx!($layout, c_mem!($ni), c_reg_1x6!(0, $ni))
3219 };
3220}
3221
3222#[macro_export]
3223macro_rules! acc_3x8 {
3224 ($ni:tt, $layout:tt, $q:tt) => {
3225 acc_p_avx512!($layout, c_mem!($ni), $q, c_reg_3x8!(0, $ni), c_reg_3x8!(1, $ni), c_reg_3x8!(2, $ni))
3226 };
3227}
3228#[macro_export]
3229macro_rules! store_3x8 {
3230 ($ni:tt, $layout:tt) => {
3231 storep_avx512!($layout, c_mem!($ni), c_reg_3x8!(0, $ni), c_reg_3x8!(1, $ni), c_reg_3x8!(2, $ni))
3232 };
3233}
3234#[macro_export]
3235macro_rules! acc_2x12 {
3236 ($ni:tt, $layout:tt, $q:tt) => {
3237 acc_p_avx512!($layout, c_mem!($ni), $q, c_reg_2x12!(0, $ni), c_reg_2x12!(1, $ni))
3238 };
3239}
3240#[macro_export]
3241macro_rules! store_2x12 {
3242 ($ni:tt, $layout:tt) => {
3243 storep_avx512!($layout, c_mem!($ni), c_reg_2x12!(0, $ni), c_reg_2x12!(1, $ni))
3244 };
3245}
3246#[macro_export]
3247macro_rules! acc_1x12 {
3248 ($ni:tt, $layout:tt, $q:tt) => {
3249 acc_p_avx512!($layout, c_mem!($ni), $q, c_reg_1x12!(0, $ni))
3250 };
3251}
3252#[macro_export]
3253macro_rules! store_1x12 {
3254 ($ni:tt, $layout:tt) => {
3255 storep_avx512!($layout, c_mem!($ni), c_reg_1x12!(0, $ni))
3256 };
3257}
3258#[macro_export]
3259macro_rules! acc_p_avx {
3260 ($layout:tt, $m0:expr, $q:tt, $r1:expr, $r2:expr, $r3:expr) => {
3261 concat!(
3262 beta_fmadd!(C, $m0, $r1, $q),
3263 beta_fmadd!(C, pire_base::mem!($m0, "0x20"), $r2, $q),
3264 beta_fmadd!($layout, pire_base::mem!($m0, "0x40"), $r3, $q),
3265 )
3266 };
3267 ($layout:tt, $m0:expr, $q:tt, $r1:expr, $r2:expr) => {
3268 concat!(beta_fmadd!(C, $m0, $r1, $q), beta_fmadd!($layout, pire_base::mem!($m0, "0x20"), $r2, $q),)
3269 };
3270 ($layout:tt, $m0:expr, $q:tt, $r1:expr) => {
3271 concat!(beta_fmadd!($layout, $m0, $r1, $q),)
3272 };
3273}
3274
3275#[macro_export]
3276macro_rules! loadp_avx {
3277 (3, $m0:expr) => {
3278 concat!(
3279 loadp_unit!($m0, 0),
3280 loadp_unit!(pire_base::mem!($m0, "0x20"), 1),
3281 loadp_unit!(pire_base::mem!($m0, "0x40"), 2),
3282 )
3283 };
3284 (2, $m0:expr) => {
3285 concat!(loadp_unit!($m0, 0), loadp_unit!(pire_base::mem!($m0, "0x20"), 1),)
3286 };
3287 (1, $m0:expr) => {
3288 concat!(loadp_unit!($m0, 0),)
3289 };
3290}
3291
3292#[macro_export]
3293macro_rules! storep_avx {
3294 ($layout:tt, $m0:expr, $r1:expr, $r2:expr, $r3:expr) => {
3295 concat!(
3296 storep_unit!(C, $r1, $m0),
3297 storep_unit!(C, $r2, pire_base::mem!($m0, "0x20")),
3298 storep_unit!($layout, $r3, pire_base::mem!($m0, "0x40")),
3299 )
3300 };
3301 ($layout:tt, $m0:expr, $r1:expr, $r2:expr) => {
3302 concat!(storep_unit!(C, $r1, $m0), storep_unit!($layout, $r2, pire_base::mem!($m0, "0x20")),)
3303 };
3304 ($layout:tt, $m0:expr, $r1:expr) => {
3305 concat!(storep_unit!($layout, $r1, $m0),)
3306 };
3307}
3308
3309#[macro_export]
3310macro_rules! acc_p_avx512 {
3311 ($layout:tt, $m0:expr, $q:tt, $r1:expr, $r2:expr, $r3:expr) => {
3312 concat!(
3313 beta_fmadd!(C, $m0, $r1, $q),
3314 beta_fmadd!(C, pire_base::mem!($m0, "0x40"), $r2, $q),
3315 beta_fmadd!($layout, pire_base::mem!($m0, "0x80"), $r3, $q),
3316 )
3317 };
3318 ($layout:tt, $m0:expr, $q:tt, $r1:expr, $r2:expr) => {
3319 concat!(beta_fmadd!(C, $m0, $r1, $q), beta_fmadd!($layout, pire_base::mem!($m0, "0x40"), $r2, $q),)
3320 };
3321 ($layout:tt, $m0:expr, $q:tt, $r1:expr) => {
3322 concat!(beta_fmadd!($layout, $m0, $r1, $q),)
3323 };
3324}
3325
3326#[macro_export]
3327macro_rules! loadp_avx512 {
3328 (3, $m0:expr) => {
3329 concat!(
3330 loadp_unit!($m0, 0),
3331 loadp_unit!(pire_base::mem!($m0, "0x40"), 1),
3332 loadp_unit!(pire_base::mem!($m0, "0x80"), 2),
3333 )
3334 };
3335 (2, $m0:expr) => {
3336 concat!(loadp_unit!($m0, 0), loadp_unit!(pire_base::mem!($m0, "0x40"), 1),)
3337 };
3338 (1, $m0:expr) => {
3339 concat!(loadp_unit!($m0, 0),)
3340 };
3341}
3342
3343#[macro_export]
3344macro_rules! storep_avx512 {
3345 ($layout:tt, $m0:expr, $r1:expr, $r2:expr, $r3:expr) => {
3346 concat!(
3347 storep_unit!(C, $r1, $m0),
3348 storep_unit!(C, $r2, pire_base::mem!($m0, "0x40")),
3349 storep_unit!($layout, $r3, pire_base::mem!($m0, "0x80")),
3350 )
3351 };
3352 ($layout:tt, $m0:expr, $r1:expr, $r2:expr) => {
3353 concat!(storep_unit!(C, $r1, $m0), storep_unit!($layout, $r2, pire_base::mem!($m0, "0x40")),)
3354 };
3355 ($layout:tt, $m0:expr, $r1:expr) => {
3356 concat!(storep_unit!($layout, $r1, $m0),)
3357 };
3358}
3359
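// `cum_seq!` expands a per-column step macro for every column index 0..$nr and
// concatenates the resulting asm fragments into one string, e.g.
// cum_seq!(acc_3x8, 4, C, 2) emits the accumulate sequence for columns 0, 1, 2 and 3.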
3360#[macro_export]
3361macro_rules! cum_seq {
3362 ($step_macro:tt, $nr:tt, $layout:tt, $b:tt) => {
3363 seq!(n in 0..$nr {
3364 concat!(#($step_macro!(n, $layout, $b),)*)
3365 })
3366 };
3367 ($step_macro:tt, $nr:tt, $layout:tt) => {
3368 seq!(n in 0..$nr {
3369 concat!(#($step_macro!(n, $layout),)*)
3370 })
3371 };
3372}
3373
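// The `b_num_MxN!` tables give the register into which the j-th B value of the current
// k-step is broadcast. Only a few registers are reserved for B, so wide-N kernels reuse
// them cyclically (e.g. `b_num_3x8!` cycles through registers 3..7).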
3374#[macro_export]
3375macro_rules! b_num_3x8 {
3376 (0) => {
3377 3
3378 };
3379 (1) => {
3380 4
3381 };
3382 (2) => {
3383 5
3384 };
3385 (3) => {
3386 6
3387 };
3388 (4) => {
3389 7
3390 };
3391 (5) => {
3392 3
3393 };
3394 (6) => {
3395 4
3396 };
3397 (7) => {
3398 5
3399 };
3400}
3401#[macro_export]
3402macro_rules! b_num_2x12 {
3403 (0) => {
3404 2
3405 };
3406 (1) => {
3407 3
3408 };
3409 (2) => {
3410 4
3411 };
3412 (3) => {
3413 5
3414 };
3415 (4) => {
3416 6
3417 };
3418 (5) => {
3419 7
3420 };
3421 (6) => {
3422 2
3423 };
3424 (7) => {
3425 3
3426 };
3427 (8) => {
3428 4
3429 };
3430 (9) => {
3431 5
3432 };
3433 (10) => {
3434 6
3435 };
3436 (11) => {
3437 7
3438 };
3439}
3440#[macro_export]
3441macro_rules! b_num_1x12 {
3442 (0) => {
3443 1
3444 };
3445 (1) => {
3446 2
3447 };
3448 (2) => {
3449 3
3450 };
3451 (3) => {
3452 4
3453 };
3454 (4) => {
3455 5
3456 };
3457 (5) => {
3458 6
3459 };
3460 (6) => {
3461 7
3462 };
3463 (7) => {
3464 8
3465 };
3466 (8) => {
3467 9
3468 };
3469 (9) => {
3470 10
3471 };
3472 (10) => {
3473 11
3474 };
3475 (11) => {
3476 12
3477 };
3478}
3479
3480#[macro_export]
3481macro_rules! b_num_2x4 {
3482 (0) => {
3483 2
3484 };
3485 (1) => {
3486 3
3487 };
3488 (2) => {
3489 2
3490 };
3491 (3) => {
3492 3
3493 };
3494}
3495#[macro_export]
3496macro_rules! b_num_1x4 {
3497 (0) => {
3498 1
3499 };
3500 (1) => {
3501 2
3502 };
3503 (2) => {
3504 3
3505 };
3506 (3) => {
3507 4
3508 };
3509}
3510#[macro_export]
3511macro_rules! b_num_2x6 {
3512 (0) => {
3513 2
3514 };
3515 (1) => {
3516 3
3517 };
3518 (2) => {
3519 2
3520 };
3521 (3) => {
3522 3
3523 };
3524 (4) => {
3525 2
3526 };
3527 (5) => {
3528 3
3529 };
3530}
3531
3532#[macro_export]
3533macro_rules! b_num_1x6 {
3534 (0) => {
3535 1
3536 };
3537 (1) => {
3538 2
3539 };
3540 (2) => {
3541 3
3542 };
3543 (3) => {
3544 4
3545 };
3546 (4) => {
3547 5
3548 };
3549 (5) => {
3550 6
3551 };
3552}
3553
3554#[macro_export]
3555macro_rules! fmadd_3x8 {
3556 ($ni:tt) => {
3557 concat!(
3558 vfmadd!(0, b_num_3x8!($ni), c_reg_3x8!(0, $ni)),
3559 vfmadd!(1, b_num_3x8!($ni), c_reg_3x8!(1, $ni)),
3560 vfmadd!(2, b_num_3x8!($ni), c_reg_3x8!(2, $ni)),
3561 )
3562 };
3563}
3564#[macro_export]
3565macro_rules! fmadd_2x12 {
3566 ($ni:tt) => {
3567 concat!(vfmadd!(0, b_num_2x12!($ni), c_reg_2x12!(0, $ni)), vfmadd!(1, b_num_2x12!($ni), c_reg_2x12!(1, $ni)),)
3568 };
3569}
3570#[macro_export]
3571macro_rules! fmadd_1x12 {
3572 ($ni:tt) => {
3573 concat!(vfmadd!(0, b_num_1x12!($ni), c_reg_1x12!(0, $ni)),)
3574 };
3575}
3576
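// The `prefetch_c_*` macros touch every cache line of the C tile before the k-loop so
// the final read-modify-write of C does not stall on memory. The second argument of
// `_mm_prefetch` is the locality hint; the constant 3 matches `_MM_HINT_T0` (prefetch
// into all cache levels).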
3577#[cfg(target_arch = "x86_64")]
3578#[macro_export]
3579macro_rules! prefetch_c_avx512 {
3580 (3, $nr:tt, $c:tt, $ldc:tt) => {
3581 use std::arch::x86_64::_mm_prefetch;
3582 seq!(j in 0..$nr {
3583 let c_u8 = $c.add(j*$ldc) as *const i8;
3584 _mm_prefetch(c_u8, 3);
3585 _mm_prefetch(c_u8.add(64), 3);
3586 _mm_prefetch(c_u8.add(128), 3);
3587 });
3588 };
3589 (2, $nr:tt, $c:tt, $ldc:tt) => {
3590 use std::arch::x86_64::_mm_prefetch;
3591 seq!(j in 0..$nr {
3592 let c_u8 = $c.add(j*$ldc) as *const i8;
3593 _mm_prefetch(c_u8, 3);
3594 _mm_prefetch(c_u8.add(64), 3);
3595 });
3596 };
3597 (1, $nr:tt, $c:tt, $ldc:tt) => {
3598 use std::arch::x86_64::_mm_prefetch;
3599 seq!(j in 0..$nr {
3600 let c_u8 = $c.add(j*$ldc) as *const i8;
3601 _mm_prefetch(c_u8, 3);
3602 });
3603 };
3604}
3605
3606#[cfg(target_arch = "x86_64")]
3607#[macro_export]
3608macro_rules! prefetch_c_avx {
3609 (3, $nr:tt, $c:tt, $ldc:tt) => {
3610 use std::arch::x86_64::_mm_prefetch;
3611 seq!(j in 0..$nr {
3612 let c_u8 = $c.add(j*$ldc) as *const i8;
3613 _mm_prefetch(c_u8, 3);
3614 _mm_prefetch(c_u8.add(64), 3);
3615 _mm_prefetch(c_u8.add(92), 3);
3616 });
3617 };
3618 (2, $nr:tt, $c:tt, $ldc:tt) => {
3619 use std::arch::x86_64::_mm_prefetch;
3620 seq!(j in 0..$nr {
3621 let c_u8 = $c.add(j*$ldc) as *const i8;
3622 _mm_prefetch(c_u8, 3);
3623 _mm_prefetch(c_u8.add(60), 3);
3624 });
3625 };
3626 (1, $nr:tt, $c:tt, $ldc:tt) => {
3627 use std::arch::x86_64::_mm_prefetch;
3628 seq!(j in 0..$nr {
3629 let c_u8 = $c.add(j*$ldc) as *const i8;
3630 _mm_prefetch(c_u8, 3);
3631 });
3632 };
3633}
3634
3635#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
3636#[macro_export]
3637macro_rules! prefetch_c_sse {
3638 (3, $nr:tt, $c:tt, $ldc:tt) => {
3639 #[cfg(target_arch="x86")]
3640 use std::arch::x86::_mm_prefetch;
3641 #[cfg(target_arch="x86_64")]
3642 use std::arch::x86_64::_mm_prefetch;
3643 seq!(j in 0..$nr {
3644 let c_u8 = $c.add(j*$ldc) as *const i8;
3645 _mm_prefetch(c_u8, 3);
3646 _mm_prefetch(c_u8.add(64), 3);
3647 });
3648 };
3649 (2, $nr:tt, $c:tt, $ldc:tt) => {
3650 #[cfg(target_arch="x86")]
3651 use std::arch::x86::_mm_prefetch;
3652 #[cfg(target_arch="x86_64")]
3653 use std::arch::x86_64::_mm_prefetch;
3654 seq!(j in 0..$nr {
3655 let c_u8 = $c.add(j*$ldc) as *const i8;
3656 _mm_prefetch(c_u8, 3);
3657 });
3658 };
3659 (1, $nr:tt, $c:tt, $ldc:tt) => {
3660 #[cfg(target_arch="x86")]
3661 use std::arch::x86::_mm_prefetch;
3662 #[cfg(target_arch="x86_64")]
3663 use std::arch::x86_64::_mm_prefetch;
3664 seq!(j in 0..$nr {
3665 let c_u8 = $c.add(j*$ldc) as *const i8;
3666 _mm_prefetch(c_u8, 3);
3667 });
3668 };
3669}
3670
3671#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
3672#[macro_export]
3673macro_rules! prefetch_0 {
3674 ($dist:tt, $reg:tt) => {
3675 concat!("prefetcht0 ", $dist, "(", $reg, ")\n",)
3676 };
3677}
3678
3679#[cfg(target_arch = "aarch64")]
3680#[macro_export]
3681macro_rules! prefetch_0 {
3682 ($dist:tt, $reg:tt) => {
3683 concat!("prfm pldl1keep, [", $reg, ", #", $dist, "] \n",)
3684 };
3685}
3686
3687#[macro_export]
3688macro_rules! prefetch_b {
3689 (S) => {
3690 ""
3691 };
3692 (B) => {
3693 concat!("prefetcht0 192({bx}) \n",)
3694 };
3695}
3696
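// The `def_ukernel_*` macros below generate the inline-asm micro-kernels themselves.
// Common structure: `alpha_st` is 0 when alpha == 1 (skip scaling) and 1 otherwise;
// `beta_st` is 0 / 1 / 2 for beta == 0, beta == 1 and the general case, so the asm can
// branch past unnecessary multiplies. With `BUF = true` the C tile is first copied into
// a contiguous stack buffer (`load_buf`/`store_buf`) so edge tiles can be handled with
// full-width vector stores, and the `seq!`/`n_cond!` combination with the labelled
// `'blk` break selects exactly one monomorphized column count per call.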
3697#[cfg(target_arch = "x86_64")]
3700#[macro_export]
3701macro_rules! def_ukernel_avx {
3702 (
3703 $k_unit:tt,
3704 $step_macro:tt,
3705 $acc_macro:tt,
3706 $store_macro:tt,
3707 $mr:tt, $nr:tt,
3708 $n0:tt, $n1:tt,
3709 $b_layout:tt,
3710 $is_partial:tt,
3711 $func_name:ident
3712 ) => {
3713 pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
3714 a: *const TA, b: *const TB, c: *mut TC,
3715 alpha: *const TS, beta: *const TS,
3716 k: usize,
3717 d_arr: [usize; 3], c_cs: usize,
3718 m: usize, n: usize,
3719 f: F,
3720 ) {
3721 use core::mem::size_of;
3722 const MR: usize = $mr * VS;
3723 mask_ptr!($is_partial, m, x, mask_ptr);
3724 let mut dim_arr = [d_arr[0]*size_of::<TA>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / ($k_unit*4), (k % ($k_unit*4)) / $k_unit];
3725 let mut c_k = c;
3726 let mut c_buf = [ZERO;MR*$nr];
3727 let alpha_st = if *alpha == ONE_SCALAR {
3728 0i32
3729 } else {
3730 1i32
3731 };
3732 let beta_st = if *beta == ZERO_SCALAR {
3733 0i32
3734 } else if *beta == ONE_SCALAR {
3735 1i32
3736 } else {
3737 2i32
3738 };
3739 if BUF {
3740 pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, n, MR);
3741 dim_arr[2] = MR*TC_SIZE;
3742 c_k = c_buf.as_mut_ptr();
3743 }
3744 let _ = 'blk: {
3745 seq!(ni in $n0..$n1 {
3746 if pire_base::n_cond!($n0, ni, n) {
3747 pire_base::prefetch_c_avx!($mr,ni,c,c_cs);
3748 asm!(
3749 vzero_kernel!(),
3750
3751 init_ab_avx!($b_layout),
3752
3753 "test {x0}, {x0}", "je 3f", "2:", pire_base::prefetch_b!($b_layout),
3757 $step_macro!(ni, $b_layout, 0),
3758 $step_macro!(ni, $b_layout, 1),
3759 $step_macro!(ni, $b_layout, 2),
3760 $step_macro!(ni, $b_layout, 3),
3761
3762 inc_a_k_unroll!($mr, 4),
3763 inc_b_k_unroll!($b_layout, ni, 4),
3764
3765 "dec {x0}", "jne 2b", "3:", "mov 32({dim_arrx}), {x0}",
3769 "test {x0},{x0}", "je 5f", "4:", $step_macro!(ni, $b_layout, 0),
3773 inc_a_k_unroll!($mr, 1),
3774 inc_b_k_unroll!($b_layout, ni, 1),
3775
3776 "dec {x0}", "jne 4b", "5:", c_load!(),
3780
3781 "cmpw $0, ({alpha_st})",
3782 "je 9f",
3783 alpha_scale!(),
3784 "9:",
3785
3786 load_mask!($is_partial),
3787
3788 "cmpw $0, ({beta_st})",
3789 "je 6f",
3790
3791 "cmpw $1, ({beta_st})",
3792 "je 15f",
3793
3794 load_beta!(),
3795 pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
3796 "jmp 6f",
3797
3798 "15:",
3799 pire_base::cum_seq!($acc_macro,ni,$is_partial,1),
3800
3801 "6:",
3802 pire_base::cum_seq!($store_macro,ni,$is_partial),
3803
3804 ax = inout(reg) a => _,
3805 bx = inout(reg) b => _,
3806 cx = inout(reg) c_k => _,
3807 dim_arrx = inout(reg) dim_arr.as_ptr() => _,
3808 alphax = inout(reg) alpha => _,
3809 betax = inout(reg) beta => _,
3810 beta_st = in(reg) &beta_st,
3811 alpha_st = in(reg) &alpha_st,
3812 maskx = inout(reg) mask_ptr => _,
3813 x0 = out(reg) _,
3814 x1 = out(reg) _,
3815 x2 = out(reg) _,
3816 x3 = out(reg) _,
3817 out("ymm0") _, out("ymm1") _, out("ymm2") _, out("ymm3") _,
3818 out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
3819 out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
3820 out("ymm12") _, out("ymm13") _, out("ymm14") _, out("ymm15") _,
3821 options(att_syntax)
3822 );
3823 break 'blk;
3824 }
3825 });
3826 };
3827 if BUF {
3828 for j in 0..n {
3829 f.call(c_k.add(j*MR), MR);
3830 }
3831 pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
3832 } else {
3833 for j in 0..n {
3834 f.call(c_k.add(j*c_cs), m);
3835 }
3836 }
3837 }
3838 };
3839}
3840
3841#[macro_export]
3842macro_rules! def_ukernel_avx512 {
3843 (
3844 $k_unit:tt,
3845 $step_macro:tt,
3846 $acc_macro:tt,
3847 $store_macro:tt,
3848 $mr:tt, $nr:tt,
3849 $n0:tt, $n1:tt,
3850 $b_layout:tt,
3851 $is_partial:tt,
3852 $func_name:ident
3853 ) => {
3854 pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
3855 a: *const TA, b: *const TB, c: *mut TC,
3856 alpha: *const TS, beta: *const TS,
3857 k: usize,
3858 d_arr: [usize; 3], c_cs: usize,
3859 m: usize, n: usize,
3860 f: F,
3861 ) {
3862 use core::mem::size_of;
3863 const MR: usize = $mr * VS;
3864 mask_ptr!($is_partial, m, x, mask_ptr);
3865 let mut dim_arr = [d_arr[0]*size_of::<TA>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / ($k_unit*4), (k % ($k_unit*4)) / $k_unit];
3866 let mut c_k = c;
3867 let mut c_buf = [ZERO;MR*$nr];
3868 let alpha_st = if *alpha == ONE_SCALAR {
3869 0i32
3870 } else {
3871 1i32
3872 };
3873 let beta_st = if *beta == ZERO_SCALAR {
3874 0i32
3875 } else if *beta == ONE_SCALAR {
3876 1i32
3877 } else {
3878 2i32
3879 };
3880 if BUF {
3881 pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, n, MR);
3882 dim_arr[2] = MR*TC_SIZE;
3883 c_k = c_buf.as_mut_ptr();
3884 }
3885 let _ = 'blk: {
3886 seq!(ni in $n0..$n1 {
3887 if pire_base::n_cond!($n0, ni, n) {
3888 pire_base::prefetch_c_avx512!($mr,ni,c,c_cs);
3889 asm!(
3890 vzero_kernel!(),
3891
3892 init_ab!($b_layout),
3893 "test {x0}, {x0}", "je 3f", "2:", $step_macro!(ni, $b_layout),
3897 $step_macro!(ni, $b_layout),
3898 $step_macro!(ni, $b_layout),
3899 $step_macro!(ni, $b_layout),
3900 "dec {x0}", "jne 2b", "3:", "mov 32({dim_arrx}), {x0}",
3904 "test {x0},{x0}", "je 5f", "4:", $step_macro!(ni, $b_layout),
3908
3909 "dec {x0}", "jne 4b", "5:", c_load!(),
3913
3914 "cmpw $0, ({alpha_st})",
3915 "je 9f",
3916 alpha_scale!(),
3917 "9:",
3918
3919 load_mask!($is_partial),
3920
3921 "cmpw $0, ({beta_st})",
3922 "je 6f",
3923
3924 "cmpw $1, ({beta_st})",
3925 "je 15f",
3926
3927 load_beta!(),
3928 pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
3929 "jmp 6f",
3930
3931 "15:",
3932 pire_base::cum_seq!($acc_macro,ni,$is_partial,1),
3933
3934 "6:",
3935 pire_base::cum_seq!($store_macro,ni,$is_partial),
3936
3937 ax = inout(reg) a => _,
3938 bx = inout(reg) b => _,
3939 cx = inout(reg) c_k => _,
3940 dim_arrx = inout(reg) dim_arr.as_ptr() => _,
3941 alphax = inout(reg) alpha => _,
3942 betax = inout(reg) beta => _,
3943 beta_st = in(reg) &beta_st,
3944 alpha_st = in(reg) &alpha_st,
3945 maskx = inout(reg) mask_ptr => _,
3946 x0 = out(reg) _,
3947 x1 = out(reg) _,
3948 x2 = out(reg) _,
3949 x3 = out(reg) _,
3950 x4 = out(reg) _,
3951 x5 = out(reg) _,
3952 out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _,
3953 out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
3954 out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
3955 out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _,
3956 out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
3957 out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _,
3958 out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _,
3959 out("zmm28") _, out("zmm29") _, out("zmm30") _, out("zmm31") _,
3960 out("k1") _,
3961 options(att_syntax)
3962 );
3963 break 'blk;
3964 }
3965 });
3966 };
3967 if BUF {
3968 for j in 0..n {
3969 f.call(c_k.add(j*MR), MR);
3970 }
3971 pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
3972 } else {
3973 for j in 0..n {
3974 f.call(c_k.add(j*c_cs), m);
3975 }
3976 }
3977 }
3978 };
3979}
3980
3981#[cfg(target_arch = "x86_64")]
3982#[macro_export]
3983macro_rules! def_ukernel_sse {
3984 (
3985 $k_unit:tt,
3986 $step_macro:tt,
3987 $acc_macro:tt,
3988 $store_macro:tt,
3989 $mr:tt, $nr:tt,
3990 $n0:tt, $n1:tt,
3991 $b_layout:tt,
3992 $is_partial:tt,
3993 $func_name:ident
3994 ) => {
3995 pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
3996 a: *const TA, b: *const TB, c: *mut TC,
3997 alpha: *const TS, beta: *const TS,
3998 k: usize,
3999 d_arr: [usize; 3], c_cs: usize,
4000 m: usize, n: usize,
4001 f: F,
4002 ) {
4003 use core::mem::size_of;
4004 const MR: usize = $mr * VS;
4005 let mut dim_arr = [d_arr[0]*size_of::<TA>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / ($k_unit*4), (k % ($k_unit*4)) / $k_unit];
4006 let mut c_k = c;
4007 let mut c_buf = [ZERO;MR*$nr];
4008 let alpha_st = if *alpha == ONE_SCALAR {
4009 0i32
4010 } else {
4011 1i32
4012 };
4013 let beta_st = if *beta == ZERO_SCALAR {
4014 0i32
4015 } else if *beta == ONE_SCALAR {
4016 1i32
4017 } else {
4018 2i32
4019 };
4020 if BUF {
4021 pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, n, MR);
4022 dim_arr[2] = MR*TC_SIZE;
4023 c_k = c_buf.as_mut_ptr();
4024 }
4025 let _ = 'blk: {
4026 seq!(ni in $n0..$n1 {
4027 if pire_base::n_cond!($n0, ni, n) {
4028 pire_base::prefetch_c_sse!($mr,ni,c,c_cs);
4029 asm!(
4030 vzero_kernel!(),
4031
4032 init_ab_avx!($b_layout),
4033 "test {x0}, {x0}", "je 3f", "2:", pire_base::prefetch_b!($b_layout),
4037 $step_macro!(ni, $b_layout, 0),
4038 $step_macro!(ni, $b_layout, 1),
4039 $step_macro!(ni, $b_layout, 2),
4040 $step_macro!(ni, $b_layout, 3),
4041
4042 inc_a_k_unroll!($mr, 4),
4043 inc_b_k_unroll!($b_layout, ni, 4),
4044 "dec {x0}", "jne 2b", "3:", "mov 32({dim_arrx}), {x0}",
4048 "test {x0},{x0}", "je 5f", "4:", $step_macro!(ni, $b_layout, 0),
4052 inc_a_k_unroll!($mr, 1),
4053 inc_b_k_unroll!($b_layout, ni, 1),
4054
4055 "dec {x0}", "jne 4b", "5:", c_load!(),
4059
4060 "cmpw $0, ({alpha_st})",
4061 "je 9f",
4062 alpha_scale!(),
4063 "9:",
4064
4065 "cmpw $0, ({beta_st})",
4066 "je 6f",
4067
4068 "cmpw $1, ({beta_st})",
4069 "je 15f",
4070
4071 load_beta!(),
4072 pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
4073 "jmp 6f",
4074
4075 "15:",
4076 pire_base::cum_seq!($acc_macro,ni,$is_partial,1),
4077
4078 "6:",
4079 pire_base::cum_seq!($store_macro,ni,$is_partial),
4080
4081 ax = inout(reg) a => _,
4082 bx = inout(reg) b => _,
4083 cx = inout(reg) c_k => _,
4084 dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4085 alphax = inout(reg) alpha => _,
4086 betax = inout(reg) beta => _,
4087 beta_st = in(reg) &beta_st,
4088 alpha_st = in(reg) &alpha_st,
4089 x0 = out(reg) _,
4090 x1 = out(reg) _,
4091 x2 = out(reg) _,
4092 x3 = out(reg) _,
4093 out("xmm0") _, out("xmm1") _, out("xmm2") _, out("xmm3") _,
4094 out("xmm4") _, out("xmm5") _, out("xmm6") _, out("xmm7") _,
4095 out("xmm8") _, out("xmm9") _, out("xmm10") _, out("xmm11") _,
4096 out("xmm12") _, out("xmm13") _, out("xmm14") _, out("xmm15") _,
4097 options(att_syntax)
4098 );
4099 break 'blk;
4100 }
4101 });
4102 };
4103 if BUF {
4104 for j in 0..n {
4105 f.call(c_k.add(j*MR), MR);
4106 }
4107 pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
4108 } else {
4109 for j in 0..n {
4110 f.call(c_k.add(j*c_cs), m);
4111 }
4112 }
4113 }
4114 };
4115}
4116
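// The `_2` variants generate a specialized `ukernel_bbc` for the packed-B path used by
// the pf1 macro-kernels: the unrolled k-loop interleaves `prefetcht1` of C (walking
// {x2}) and of A at `a_pft1_offset` (walking {x5}), and the last `$kl_pf` iterations
// prefetch C with `prefetcht0` right before the write-back. Note the AVX variant
// advances the A-prefetch pointer by a hardcoded 16 bytes per iteration, while the
// AVX-512 variant steps by `$pf1_step`.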
4117#[cfg(target_arch = "x86_64")]
4118#[macro_export]
4119macro_rules! def_ukernel_avx_2 {
4120 ($k_unit:tt, $step:ident, $acc:ident, $store:ident, $mr:tt, $nr:tt, $kl_pf:tt, $pf1_step:tt) => {
4121 pub(crate) unsafe fn ukernel_bbc<F: UnaryFnC, const BUF: bool>(
4122 a: *const TA, b: *const TB, c: *mut TC,
4123 alpha: *const TS, beta: *const TS,
4124 k: usize,
4125 d_arr: [usize; 3], c_cs: usize,
4126 a_pft1_offset: usize, _n: usize,
4127 f: F,
4128 ) {
4129 const MR: usize = $mr * VS;
4130 let k_l0 = k % $kl_pf;
4131 let k_l = if k_l0 == 0 {$kl_pf/$k_unit} else {k_l0/$k_unit};
4132 let k_i = (k - k_l*$k_unit) / (4*$k_unit);
4133 let mut c_k = c;
4134
4135 let mut dim_arr = [c_cs*TC_SIZE, k_i, k_l, a_pft1_offset];
4136 let mut c_buf = [ZERO; MR*$nr];
4137 let alpha_st = if *alpha == ONE_SCALAR {
4138 0i32
4139 } else {
4140 1i32
4141 };
4142 let beta_st = if *beta == ZERO_SCALAR {
4143 0i32
4144 } else if *beta == ONE_SCALAR {
4145 1i32
4146 } else {
4147 2i32
4148 };
4149 if BUF {
4150 pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, MR, $nr, MR);
4151 dim_arr[0] = MR*TC_SIZE;
4152 c_k = c_buf.as_mut_ptr();
4153 }
4154 asm!(
4155 vzero_kernel!(),
4156 init_ab_2!(B),
4157 "test {x0},{x0}",
4158 "je 3f",
4159 "mov {cx}, {x2}",
4160 "mov {ax}, {x5}",
4161 "mov 24({dim_arrx}),{x1}",
4162 "add {x1}, {x5}",
4163 "mov ({dim_arrx}),{x1}",
4164 "2:",
4165 prefetch_0!(256, "{bx}"),
4166 $step!($nr, B, 0),
4167
4168 "movq $64*4, {x4}",
4169 "testq $3, {x0}",
4171 "cmovz {x1},{x4}",
4172
4173 $step!($nr, B, 1),
4174
4175 "prefetcht1 ({x2})",
4176
4177 "subq $64*3, {x2}",
4178 "addq {x4}, {x2}",
4179
4180 $step!($nr, B, 2),
4181
4182 "prefetcht1 ({x5})",
4183 "addq $16, {x5}",
4184
4185 "testq $63, {x0}",
4186 "cmovz {cx},{x2}",
4187
4188 $step!($nr, B, 3),
4189
4190 inc_a_k_unroll!($mr, 4),
4191 inc_b_k_unroll!(B, $nr, 4),
4192
4193 "dec {x0}",
4194 "jne 2b",
4195 "3:",
4196 "mov 16({dim_arrx}),{x0}",
4197 "test {x0},{x0}", "je 5f", "mov {cx}, {x2}",
4200 "mov ({dim_arrx}),{x1}",
4201 "4:",
4202 "prefetcht0 ({x2})",
4203 "prefetcht0 64({x2})",
4204 "prefetcht0 92({x2})",
4205 $step!($nr, B, 0),
4206 inc_a_k_unroll!($mr, 1),
4207 inc_b_k_unroll!(B, $nr, 1),
4208 "add {x1}, {x2}", "dec {x0}", "jne 4b",
4209
4210 "5:",
4211 c_load_2!(),
4212
4213 "cmpw $0, ({alpha_st})",
4214 "je 9f",
4215 alpha_scale!(),
4216 "9:",
4217 "cmpw $0, ({beta_st})",
4218 "je 6f",
4219
4220 "cmpw $1, ({beta_st})",
4221 "je 15f",
4222
4223 load_beta!(),
4224 pire_base::cum_seq!($acc,$nr,C,2),
4225 "jmp 6f",
4226
4227 "15:",
4228 pire_base::cum_seq!($acc,$nr,C,1),
4229
4230 "6:",
4231 pire_base::cum_seq!($store,$nr,C),
4232
4233 ax = inout(reg) a => _,
4234 bx = inout(reg) b => _,
4235 cx = inout(reg) c_k => _,
4236 dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4237 alphax = inout(reg) alpha => _,
4238 betax = inout(reg) beta => _,
4239 beta_st = in(reg) &beta_st,
4240 alpha_st = in(reg) &alpha_st,
4241 x0 = out(reg) _,
4242 x1 = out(reg) _,
4243 x2 = out(reg) _,
4244 x3 = out(reg) _,
4245 x4 = out(reg) _,
4246 x5 = out(reg) _,
4247 out("ymm0") _, out("ymm1") _, out("ymm2") _, out("ymm3") _,
4248 out("ymm4") _, out("ymm5") _, out("ymm6") _, out("ymm7") _,
4249 out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
4250 out("ymm12") _, out("ymm13") _, out("ymm14") _, out("ymm15") _,
4251 options(att_syntax)
4252 );
4253 if BUF {
4254 for j in 0..$nr {
4255 f.call(c_k.add(j*MR), MR);
4256 }
4257 pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, MR, $nr, MR);
4258 } else {
4259 for j in 0..$nr {
4260 f.call(c_k.add(j*c_cs), MR);
4261 }
4262 }
4263 }
4264 };
4265}
4266
4267#[cfg(target_arch = "x86_64")]
4268#[macro_export]
4269macro_rules! def_ukernel_avx512_2 {
4270 ($k_unit:tt, $step:ident, $acc:ident, $store:ident, $mr:tt, $nr:tt, $kl_pf:tt, $pf1_step:tt) => {
4271 pub(crate) unsafe fn ukernel_bbc<F: UnaryFnC, const BUF: bool>(
4272 a: *const TA, b: *const TB, c: *mut TC,
4273 alpha: *const TS, beta: *const TS,
4274 k: usize,
4275 d_arr: [usize; 3], c_cs: usize,
4276 a_pft1_offset: usize, _n: usize,
4277 f: F,
4278 ) {
4279 const MR: usize = $mr * VS;
4280 let k_l0 = k % $kl_pf;
4281 let k_l = if k_l0 == 0 {$kl_pf/$k_unit} else {k_l0/$k_unit};
4282 let k_i = (k - k_l*$k_unit) / (4*$k_unit);
4283 let mut c_k = c;
4284
4285 let mut dim_arr = [c_cs*TC_SIZE, k_i, k_l, a_pft1_offset];
4286 let mut c_buf = [ZERO; MR * $nr];
4287 let alpha_st = if *alpha == ONE_SCALAR {
4288 0i32
4289 } else {
4290 1i32
4291 };
4292 let beta_st = if *beta == ZERO_SCALAR {
4293 0i32
4294 } else if *beta == ONE_SCALAR {
4295 1i32
4296 } else {
4297 2i32
4298 };
4299 if BUF {
4300 pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, MR, $nr, MR);
4301 dim_arr[0] = MR*TC_SIZE;
4302 c_k = c_buf.as_mut_ptr();
4303 }
4304 asm!(
4305 vzero_kernel!(),
4306 init_ab_2!(B),
4307 "test {x0},{x0}", "je 3f",
4308
4309 "mov {cx}, {x2}",
4310 "mov {ax}, {x5}",
4311 "mov 24({dim_arrx}),{x1}",
4312 "add {x1}, {x5}",
4313 "mov ({dim_arrx}),{x1}",
4314
4315 "2:", $step!($nr, B),
4317
4318 "movq $64*4, {x4}",
4319 "testq $3, {x0}",
4321 "cmovz {x1},{x4}",
4322
4323 $step!($nr, B),
4324
4325 "prefetcht1 ({x2})",
4326
4327 "subq $64*3, {x2}",
4328 "addq {x4}, {x2}",
4329
4330 $step!($nr, B),
4331
4332 "prefetcht1 ({x5})",
4333 concat!("addq $", $pf1_step, ", {x5}"),
4334
4335 "testq $63, {x0}",
4336 "cmovz {cx},{x2}",
4337
4338 $step!($nr, B),
4339
4340 "dec {x0}", "jne 2b", "3:",
4343 "mov 16({dim_arrx}),{x0}",
4344 "test {x0},{x0}", "je 5f", "mov {cx}, {x2}",
4348 "mov ({dim_arrx}),{x1}",
4349
4350 "4:", "prefetcht0 ({x2})",
4352 "prefetcht0 64({x2})",
4353 "prefetcht0 128({x2})",
4354 $step!($nr, B),
4355 "add {x1}, {x2}", "dec {x0}", "jne 4b", "5:", c_load_2!(),
4359
4360 "cmpw $0, ({alpha_st})",
4361 "je 9f",
4362 alpha_scale!(),
4363 "9:",
4364
4365 "cmpw $0, ({beta_st})",
4366 "je 6f",
4367
4368 "cmpw $1, ({beta_st})",
4369 "je 15f",
4370
4371 load_beta!(),
4372 pire_base::cum_seq!($acc,$nr,C,2),
4373 "jmp 6f",
4374
4375 "15:",
4376 pire_base::cum_seq!($acc,$nr,C,1),
4377
4378 "6:",
4379 pire_base::cum_seq!($store,$nr,C),
4380
4381 ax = inout(reg) a => _,
4382 bx = inout(reg) b => _,
4383 cx = inout(reg) c_k => _,
4384 dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4385 alphax = inout(reg) alpha => _,
4386 betax = inout(reg) beta => _,
4387 beta_st = in(reg) &beta_st,
4388 alpha_st = in(reg) &alpha_st,
4389 x0 = out(reg) _,
4390 x1 = out(reg) _,
4391 x2 = out(reg) _,
4392 x3 = out(reg) _,
4393 x4 = out(reg) _,
4394 x5 = out(reg) _,
4395 out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _,
4396 out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
4397 out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
4398 out("zmm12") _, out("zmm13") _, out("zmm14") _, out("zmm15") _,
4399 out("zmm16") _, out("zmm17") _, out("zmm18") _, out("zmm19") _,
4400 out("zmm20") _, out("zmm21") _, out("zmm22") _, out("zmm23") _,
4401 out("zmm24") _, out("zmm25") _, out("zmm26") _, out("zmm27") _,
4402 out("zmm28") _, out("zmm29") _, out("zmm30") _, out("zmm31") _,
4403 out("k1") _,
4404 options(att_syntax)
4405 );
4406 if BUF {
4407 for j in 0..$nr {
4408 f.call(c_k.add(j*MR), MR);
4409 }
4410 pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, MR, $nr, MR);
4411 } else {
4412 for j in 0..$nr {
4413 f.call(c_k.add(j*c_cs), MR);
4414 }
4415 }
4416 }
4417 };
4418}
4419
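// 32-bit x86 version of the SSE micro-kernel: with only eight general-purpose and eight
// xmm registers available, alpha/beta travel through `ptr_arr` and the alpha_st/beta_st
// flags are appended to `dim_arr` (offsets 24 and 20 with a 4-byte usize) instead of
// occupying dedicated registers.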
4420#[cfg(target_arch = "x86")]
4421#[macro_export]
4422macro_rules! def_ukernel_sse {
4423 (
4424 $k_unit:tt,
4425 $step_macro:tt,
4426 $acc_macro:tt,
4427 $store_macro:tt,
4428 $mr:tt, $nr:tt,
4429 $n0:tt, $n1:tt,
4430 $b_layout:tt,
4431 $is_partial:tt,
4432 $func_name:ident
4433 ) => {
4434 pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
4435 a: *const TA, b: *const TB, c: *mut TC,
4436 alpha: *const TS, beta: *const TS,
4437 k: usize,
4438 d_arr: [usize; 3], c_cs: usize,
4439 m: usize, n: usize,
4440 f: F,
4441 ) {
4442 let alpha_st = if *alpha == ONE_SCALAR {
4443 0i32
4444 } else {
4445 1i32
4446 };
4447 let beta_st = if *beta == ZERO_SCALAR {
4448 0i32
4449 } else if *beta == ONE_SCALAR {
4450 1i32
4451 } else {
4452 2i32
4453 };
4454 const MR: usize = $mr * VS;
4455 let mut dim_arr = [d_arr[0]*8, d_arr[1]*8, c_cs*TC_SIZE, k / ($k_unit*4), (k % ($k_unit*4)) / $k_unit, beta_st as usize, alpha_st as usize];
4456 let mut ptr_arr = [alpha, beta];
4457 let mut cf = c;
4458 let mut c_buf = [ZERO;MR*$nr];
4459 if BUF {
4460 pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, n, MR);
4461 dim_arr[2] = MR*TC_SIZE;
4462 cf = c_buf.as_mut_ptr();
4463 }
4464 let _ = 'blk: {
4465 seq!(ni in $n0..$n1 {
4466 if pire_base::n_cond!($n0, ni, n) {
4467 pire_base::prefetch_c_sse!($mr,ni,c,c_cs);
4468 asm!(
4469 vzero_kernel!(),
4470
4471 init_ab!($b_layout),
4472 "test {x0}, {x0}", "je 3f", "2:", pire_base::prefetch_b!($b_layout),
4476 $step_macro!(ni, $b_layout, 0),
4477 $step_macro!(ni, $b_layout, 1),
4478 $step_macro!(ni, $b_layout, 2),
4479 $step_macro!(ni, $b_layout, 3),
4480
4481 inc_a_k_unroll!($mr, 4),
4482 inc_b_k_unroll!($b_layout, ni, 4),
4483
4484 "dec {x0}", "jne 2b", "3:", "mov 16({dim_arrx}), {x0}",
4488 "test {x0},{x0}", "je 5f", "4:", $step_macro!(ni, $b_layout, 0),
4492 inc_a_k_unroll!($mr, 1),
4493 inc_b_k_unroll!($b_layout, ni, 1),
4494
4495 "dec {x0}", "jne 4b", "5:", c_load!(),
4499
4500 "cmpw $0, 24({dim_arrx})",
4501 "je 9f",
4502 alpha_scale!(),
4503 "9:",
4504
4505 "cmpw $0, 20({dim_arrx})",
4506 "je 6f",
4507
4508 "cmpw $1, 20({dim_arrx})",
4509 "je 15f",
4510
4511 load_beta!(),
4512 pire_base::cum_seq!($acc_macro,ni,$is_partial,2),
4513 "jmp 6f",
4514
4515 "15:",
4516 pire_base::cum_seq!($acc_macro,ni,$is_partial,1),
4517
4518 "6:",
4519 pire_base::cum_seq!($store_macro,ni,$is_partial),
4520
4521 ax = inout(reg) a => _,
4522 bx = inout(reg) b => _,
4523 cx = inout(reg) cf => _,
4524 ptr_arrx = inout(reg) ptr_arr.as_ptr() => _,
4525 dim_arrx = inout(reg) dim_arr.as_ptr() => _,
4526 x0 = out(reg) _,
4527 out("xmm0") _, out("xmm1") _, out("xmm2") _, out("xmm3") _,
4528 out("xmm4") _, out("xmm5") _, out("xmm6") _, out("xmm7") _,
4529 options(att_syntax)
4530 );
4531 break 'blk;
4532 }
4533 });
4534 };
4535 if BUF {
4536 for j in 0..n {
4537 f.call(cf.add(j*MR), MR);
4538 }
4539 pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
4540 } else {
4541 for j in 0..n {
4542 f.call(cf.add(j*c_cs), m);
4543 }
4544 }
4545 }
4546 };
4547}
4548
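// AArch64/NEON micro-kernel generator. Same overall shape as the x86 variants, but the
// C buffer is loaded per selected n inside the `seq!` arm, the alpha/beta state is
// passed in registers rather than through memory, and only two beta cases are
// distinguished (beta == 0 skips the accumulate, any other beta goes through
// `load_beta!`).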
#[macro_export]
macro_rules! def_ukernel_neon {
    (
        $step_macro:tt,
        $acc_macro:tt,
        $store_macro:tt,
        $mr:tt, $nr:tt,
        $n0:tt, $n1:tt,
        $b_layout:tt,
        $is_partial:tt,
        $func_name:ident
    ) => {
        #[target_feature(enable = "neon")]
        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
            a: *const TA, b: *const TB, c: *mut TC,
            alpha: *const TA, beta: *const TB,
            k: usize,
            d_arr: [usize; 3], c_cs: usize,
            m: usize, n: usize,
            f: F,
        ) {
            const MR: usize = $mr * VS;
            use core::mem::size_of;
            let mut dim_arr = [d_arr[0]*size_of::<TB>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 4, k % 4];
            let mut cf = c;
            let mut c_buf = [ZERO; MR*$nr];
            let alpha_st = if *alpha == ONE_SCALAR {
                0i32
            } else {
                1i32
            };
            let beta_st = if *beta == ZERO_SCALAR {
                0i32
            } else {
                1i32
            };
            let _ = 'blk: {
                seq!(ni in $n0..$n1 {
                    if pire_base::n_cond!($n0, ni, n) {
                        if BUF {
                            pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, MR);
                            dim_arr[2] = MR*TC_SIZE;
                            cf = c_buf.as_mut_ptr();
                        }
                        asm!(
                            prefetch_c!(),
                            vzero_kernel!(),

                            init_ab!($b_layout),

                            "cmp {x0}, #0", "BEQ 3f",

                            "2:",
                            prefetch_0!(128, "{bx}"),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",
                            "cmp {x0}, 0",
                            "BNE 2b",

                            "3:",
                            "ldr {x0}, [{dim_arrx}, #32]",
                            "cmp {x0}, #0",

                            "BEQ 5f",
                            "4:",
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",

                            "cmp {x0}, 0",
                            "BNE 4b",

                            "5:",
                            c_load!(),
                            "cmp {alpha_st:w}, #0",
                            "BEQ 13f",
                            alpha_scale!(),
                            "13:",

                            "cmp {beta_st:w}, #0",
                            "BEQ 6f",

                            load_beta!(),

                            pire_base::cum_seq!($acc_macro, ni, $is_partial, 2),

                            "6:",
                            pire_base::cum_seq!($store_macro, ni, $is_partial),

                            ax = inout(reg) a => _,
                            bx = inout(reg) b => _,
                            cx = inout(reg) cf => _,
                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
                            alphax = inout(reg) alpha => _,
                            betax = inout(reg) beta => _,
                            alpha_st = in(reg) alpha_st,
                            beta_st = in(reg) beta_st,
                            x0 = out(reg) _,
                            x1 = out(reg) _,
                            x2 = out(reg) _,
                            x3 = out(reg) _,
                            x4 = out(reg) _,
                            x5 = out(reg) _,
                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
                        );
                        break 'blk;
                    }
                });
            };
            if BUF {
                for j in 0..n {
                    f.call(cf.add(j*MR), MR);
                }
                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
            } else {
                for j in 0..n {
                    f.call(cf.add(j*c_cs), m);
                }
            }
        }
    };
}

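/// Variant of `def_ukernel_neon` that additionally binds the constant array
/// produced by `alt_arr!` to the `altx` register operand (e.g. for step
/// macros that need an extra table of constants); control flow, alpha/beta
/// handling and the `BUF` staging path are otherwise identical.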
#[macro_export]
macro_rules! def_ukernel_neon_alt {
    (
        $step_macro:tt,
        $acc_macro:tt,
        $store_macro:tt,
        $mr:tt, $nr:tt,
        $n0:tt, $n1:tt,
        $b_layout:tt,
        $is_partial:tt,
        $func_name:ident
    ) => {
        #[target_feature(enable = "neon")]
        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
            a: *const TA, b: *const TB, c: *mut TC,
            alpha: *const TA, beta: *const TB,
            k: usize,
            d_arr: [usize; 3], c_cs: usize,
            m: usize, n: usize,
            f: F,
        ) {
            alt_arr!(alt);
            const MR: usize = $mr * VS;
            use core::mem::size_of;
            let mut dim_arr = [d_arr[0]*size_of::<TB>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 4, k % 4];
            let mut cf = c;
            let mut c_buf = [ZERO; MR*$nr];
            let alpha_st = if *alpha == ONE_SCALAR {
                0i32
            } else {
                1i32
            };
            let beta_st = if *beta == ZERO_SCALAR {
                0i32
            } else {
                1i32
            };
            let _ = 'blk: {
                seq!(ni in $n0..$n1 {
                    if pire_base::n_cond!($n0, ni, n) {
                        if BUF {
                            pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, MR);
                            dim_arr[2] = MR*TC_SIZE;
                            cf = c_buf.as_mut_ptr();
                        }
                        asm!(
                            prefetch_c!(),
                            vzero_kernel!(),

                            init_ab!($b_layout),

                            "cmp {x0}, #0", "BEQ 3f",

                            "2:",
                            prefetch_0!(128, "{bx}"),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",
                            "cmp {x0}, 0",
                            "BNE 2b",

                            "3:",
                            "ldr {x0}, [{dim_arrx}, #32]",
                            "cmp {x0}, #0",

                            "BEQ 5f",
                            "4:",
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",

                            "cmp {x0}, 0",
                            "BNE 4b",

                            "5:",
                            c_load!(),
                            "cmp {alpha_st:w}, #0",
                            "BEQ 13f",
                            alpha_scale!(),
                            "13:",
                            "cmp {beta_st:w}, #0",
                            "BEQ 6f",

                            load_beta!(),

                            pire_base::cum_seq!($acc_macro, ni, $is_partial, 2),

                            "6:",
                            pire_base::cum_seq!($store_macro, ni, $is_partial),

                            ax = inout(reg) a => _,
                            bx = inout(reg) b => _,
                            cx = inout(reg) cf => _,
                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
                            alphax = inout(reg) alpha => _,
                            betax = inout(reg) beta => _,
                            alpha_st = in(reg) alpha_st,
                            beta_st = in(reg) beta_st,
                            altx = inout(reg) alt.as_ptr() => _,
                            x0 = out(reg) _,
                            x1 = out(reg) _,
                            x2 = out(reg) _,
                            x3 = out(reg) _,
                            x4 = out(reg) _,
                            x5 = out(reg) _,
                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
                        );
                        break 'blk;
                    }
                });
            };
            if BUF {
                for j in 0..n {
                    f.call(cf.add(j*MR), MR);
                }
                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
            } else {
                for j in 0..n {
                    f.call(cf.add(j*c_cs), m);
                }
            }
        }
    };
}

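/// Half-precision (`fp16`) counterpart of `def_ukernel_neon`. The assembly
/// body and write-back path match the plain NEON kernel; note that here the
/// `BUF` staging of `c_buf` happens before the `n_cond!` check in each
/// `seq!` expansion.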
#[macro_export]
macro_rules! def_ukernel_neon_fp16 {
    (
        $step_macro:tt,
        $acc_macro:tt,
        $store_macro:tt,
        $mr:tt, $nr:tt,
        $n0:tt, $n1:tt,
        $b_layout:tt,
        $is_partial:tt,
        $func_name:ident
    ) => {
        #[target_feature(enable = "neon,fp16")]
        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
            a: *const TA, b: *const TB, c: *mut TC,
            alpha: *const TA, beta: *const TB,
            k: usize,
            d_arr: [usize; 3], c_cs: usize,
            m: usize, n: usize,
            f: F,
        ) {
            const MR: usize = $mr * VS;
            use core::mem::size_of;
            let mut dim_arr = [d_arr[0]*size_of::<TB>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 4, k % 4];
            let mut cf = c;
            let mut c_buf = [ZERO; MR*$nr];
            let alpha_st = if *alpha == ONE_SCALAR {
                0i32
            } else {
                1i32
            };
            let beta_st = if *beta == ZERO_SCALAR {
                0i32
            } else {
                1i32
            };
            let _ = 'blk: {
                seq!(ni in $n0..$n1 {
                    if BUF {
                        pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, MR);
                        dim_arr[2] = MR*TC_SIZE;
                        cf = c_buf.as_mut_ptr();
                    }
                    if pire_base::n_cond!($n0, ni, n) {
                        asm!(
                            prefetch_c!(),
                            vzero_kernel!(),

                            init_ab!($b_layout),

                            "cmp {x0}, #0", "BEQ 3f",

                            "2:",
                            prefetch_0!(128, "{bx}"),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",
                            "cmp {x0}, 0",
                            "BNE 2b",

                            "3:",
                            "ldr {x0}, [{dim_arrx}, #32]",
                            "cmp {x0}, #0",

                            "BEQ 5f",
                            "4:",
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",

                            "cmp {x0}, 0",
                            "BNE 4b",

                            "5:",
                            c_load!(),
                            "cmp {alpha_st:w}, #0",
                            "BEQ 13f",
                            alpha_scale!(),
                            "13:",

                            "cmp {beta_st:w}, #0",
                            "BEQ 6f",

                            load_beta!(),

                            pire_base::cum_seq!($acc_macro, ni, $is_partial, 2),

                            "6:",
                            pire_base::cum_seq!($store_macro, ni, $is_partial),

                            ax = inout(reg) a => _,
                            bx = inout(reg) b => _,
                            cx = inout(reg) cf => _,
                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
                            alphax = inout(reg) alpha => _,
                            betax = inout(reg) beta => _,
                            alpha_st = in(reg) alpha_st,
                            beta_st = in(reg) beta_st,
                            x0 = out(reg) _,
                            x1 = out(reg) _,
                            x2 = out(reg) _,
                            x3 = out(reg) _,
                            x4 = out(reg) _,
                            x5 = out(reg) _,
                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
                        );
                        break 'blk;
                    }
                });
            };
            if BUF {
                for j in 0..n {
                    f.call(cf.add(j*MR), MR);
                }
                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
            } else {
                for j in 0..n {
                    f.call(cf.add(j*c_cs), m);
                }
            }
        }
    };
}

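/// Integer (`i8mm`) NEON micro-kernel. `k` is consumed in units of eight:
/// the main loop runs `k / 32` times over four `$step_macro!` calls and the
/// remainder loop `(k % 32) / 8` times. `alpha`/`beta` are `*const TS`
/// scaling factors, and `beta_st` is three-valued (0: ignore the existing C,
/// 1: accumulate it as-is, 2: scale it by beta via `load_beta!`).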
#[macro_export]
macro_rules! def_ukernel_neon_i8mm {
    (
        $step_macro:tt,
        $acc_macro:tt,
        $store_macro:tt,
        $mr:tt, $nr:tt,
        $n0:tt, $n1:tt,
        $b_layout:tt,
        $is_partial:tt,
        $func_name:ident
    ) => {
        #[target_feature(enable = "neon,i8mm")]
        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
            a: *const TA, b: *const TB, c: *mut TC,
            alpha: *const TS, beta: *const TS,
            k: usize,
            d_arr: [usize; 3], c_cs: usize,
            m: usize, n: usize,
            f: F,
        ) {
            const MR: usize = $mr * VS;
            use core::mem::size_of;
            let mut dim_arr = [d_arr[0]*size_of::<TB>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 32, (k % 32) / 8];
            let mut cf = c;
            let mut c_buf = [ZERO; MR*$nr];
            let alpha_st = if *alpha == ONE_SCALAR {
                0i32
            } else {
                1i32
            };
            let beta_st = if *beta == ZERO_SCALAR {
                0i32
            } else if *beta == ONE_SCALAR {
                1i32
            } else {
                2i32
            };
            let _ = 'blk: {
                seq!(ni in $n0..$n1 {
                    if pire_base::n_cond!($n0, ni, n) {
                        if BUF {
                            pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, MR);
                            dim_arr[2] = MR*TC_SIZE;
                            cf = c_buf.as_mut_ptr();
                        }
                        asm!(
                            prefetch_c!(),
                            vzero_kernel!(),

                            init_ab!($b_layout),

                            "cmp {x0}, #0", "BEQ 3f",

                            "2:",
                            prefetch_0!(128, "{bx}"),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",
                            "cmp {x0}, 0",
                            "BNE 2b",

                            "3:",
                            "ldr {x0}, [{dim_arrx}, #32]",
                            "cmp {x0}, #0",

                            "BEQ 5f",
                            "4:",
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",

                            "cmp {x0}, 0",
                            "BNE 4b",

                            "5:",
                            c_load!(),
                            "cmp {alpha_st:w}, #0",
                            "BEQ 13f",
                            alpha_scale!(),
                            "13:",

                            "cmp {beta_st:w}, #0",
                            "BEQ 6f",

                            "cmp {beta_st:w}, #1",
                            "BEQ 9f",

                            load_beta!(),
                            pire_base::cum_seq!($acc_macro, ni, $is_partial, 2),
                            "B 6f",

                            "9:",
                            pire_base::cum_seq!($acc_macro, ni, $is_partial, 1),

                            "6:",
                            pire_base::cum_seq!($store_macro, ni, $is_partial),

                            ax = inout(reg) a => _,
                            bx = inout(reg) b => _,
                            cx = inout(reg) cf => _,
                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
                            alphax = inout(reg) alpha => _,
                            betax = inout(reg) beta => _,
                            alpha_st = in(reg) alpha_st,
                            beta_st = in(reg) beta_st,
                            x0 = out(reg) _,
                            x1 = out(reg) _,
                            x2 = out(reg) _,
                            x3 = out(reg) _,
                            x4 = out(reg) _,
                            x5 = out(reg) _,
                            x6 = out(reg) _,
                            x7 = out(reg) _,
                            x8 = out(reg) _,
                            x9 = out(reg) _,
                            x10 = out(reg) _,
                            x11 = out(reg) _,
                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
                        );
                        break 'blk;
                    }
                });
            };
            if BUF {
                for j in 0..n {
                    f.call(cf.add(j*MR), MR);
                }
                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, MR);
            } else {
                for j in 0..n {
                    f.call(cf.add(j*c_cs), m);
                }
            }
        }
    };
}

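/// SVE micro-kernel: the vector length is queried at run time via `sve_vs()`,
/// so a tile is `$mr * vs` rows tall and `c_buf` is sized for the largest
/// (256-byte) SVE vector. `m_s`/`m_e` hand the extent of the final partial
/// vector (`m_left`) to the assembly, and `incax` carries the per-iteration
/// A-panel increment in bytes.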
#[macro_export]
macro_rules! def_ukernel_sve {
    (
        $step_macro:tt,
        $acc_macro:tt,
        $store_macro:tt,
        $mr:tt, $nr:tt,
        $n0:tt, $n1:tt,
        $b_layout:tt,
        $is_partial:tt,
        $func_name:ident
    ) => {
        #[target_feature(enable = "neon,sve")]
        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
            a: *const TA, b: *const TB, c: *mut TC,
            alpha: *const TA, beta: *const TB,
            k: usize,
            d_arr: [usize; 3], c_cs: usize,
            m: usize, n: usize,
            f: F,
        ) {
            use core::mem::size_of;
            let vs = sve_vs();
            let m_left = if m % vs == 0 { vs } else { m % vs };
            let inc_a = vs * $mr * size_of::<TA>();
            let mr = $mr * vs;
            let mut dim_arr = [d_arr[0]*size_of::<TB>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 4, k % 4];
            let mut cf = c;
            let mut c_buf = [ZERO; (256 / size_of::<TC>()) * $mr * $nr];
            let alpha_st = if *alpha == ONE_SCALAR {
                0i32
            } else {
                1i32
            };
            let beta_st = if *beta == ZERO_SCALAR {
                0i32
            } else {
                1i32
            };
            let _ = 'blk: {
                seq!(ni in $n0..$n1 {
                    if BUF {
                        pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, mr);
                        dim_arr[2] = mr*TC_SIZE;
                        cf = c_buf.as_mut_ptr();
                    }
                    if pire_base::n_cond!($n0, ni, n) {
                        asm!(
                            "ptrue p0.h",
                            "mov {m_s}, #0",
                            "/* {m_e} */", "\n",
                            prefetch_c!(),
                            vzero_kernel!(),

                            init_ab!($b_layout),

                            "cmp {x0}, #0", "BEQ 3f",

                            "2:",
                            prefetch_0!(128, "{bx}"),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",
                            "cmp {x0}, 0",
                            "BNE 2b",

                            "3:",
                            "ldr {x0}, [{dim_arrx}, #32]",
                            "cmp {x0}, #0",

                            "BEQ 5f",
                            "4:",
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",

                            "cmp {x0}, 0",
                            "BNE 4b",

                            "5:",
                            c_load!(),
                            "cmp {alpha_st:w}, #0",
                            "BEQ 13f",
                            alpha_scale!(),
                            "13:",

                            "cmp {beta_st:w}, #0",
                            "BEQ 6f",

                            load_beta!(),

                            pire_base::cum_seq!($acc_macro, ni, $is_partial),

                            "6:",
                            pire_base::cum_seq!($store_macro, ni, $is_partial),

                            ax = inout(reg) a => _,
                            bx = inout(reg) b => _,
                            cx = inout(reg) cf => _,
                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
                            alphax = inout(reg) alpha => _,
                            betax = inout(reg) beta => _,
                            incax = in(reg) inc_a as u64,
                            alpha_st = in(reg) alpha_st,
                            beta_st = in(reg) beta_st,
                            m_s = out(reg) _,
                            m_e = inout(reg) m_left as u64 => _,
                            x0 = out(reg) _,
                            x1 = out(reg) _,
                            x2 = out(reg) _,
                            x3 = out(reg) _,
                            x4 = out(reg) _,
                            x5 = out(reg) _,
                            x6 = out(reg) _,
                            x7 = out(reg) _,
                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
                            out("p0") _, out("p1") _, out("p2") _, out("p3") _,
                        );
                        break 'blk;
                    }
                });
            };
            if BUF {
                for j in 0..n {
                    f.call(cf.add(j*mr), mr);
                }
                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, mr);
            } else {
                for j in 0..n {
                    f.call(cf.add(j*c_cs), m);
                }
            }
        }
    };
}

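/// SVE `i8mm` variant: combines the run-time vector-length handling of
/// `def_ukernel_sve` with the eight-wide k units and three-valued `beta_st`
/// of the NEON `i8mm` kernel; `inc_a` therefore includes the factor-of-eight
/// k unroll.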
#[macro_export]
macro_rules! def_ukernel_sve_i8mm {
    (
        $step_macro:tt,
        $acc_macro:tt,
        $store_macro:tt,
        $mr:tt, $nr:tt,
        $n0:tt, $n1:tt,
        $b_layout:tt,
        $is_partial:tt,
        $func_name:ident
    ) => {
        #[target_feature(enable = "neon,sve,i8mm")]
        pub(crate) unsafe fn $func_name<F: UnaryFnC, const BUF: bool>(
            a: *const TA, b: *const TB, c: *mut TC,
            alpha: *const TS, beta: *const TS,
            k: usize,
            d_arr: [usize; 3], c_cs: usize,
            m: usize, n: usize,
            f: F,
        ) {
            use core::mem::size_of;
            let vs = sve_vs();
            let m_left = if m % vs == 0 { vs } else { m % vs };
            let inc_a = $mr * vs * size_of::<TA>() * 8;
            let mr = $mr * vs;
            let mut dim_arr = [d_arr[0]*size_of::<TB>(), d_arr[1]*size_of::<TB>(), c_cs*TC_SIZE, k / 32, (k % 32) / 8];
            let mut cf = c;
            let mut c_buf = [ZERO; (256 / size_of::<TC>()) * $mr * $nr];
            let alpha_st = if *alpha == ONE_SCALAR {
                0i32
            } else {
                1i32
            };
            let beta_st = if *beta == ZERO_SCALAR {
                0i32
            } else if *beta == ONE_SCALAR {
                1i32
            } else {
                2i32
            };
            let _ = 'blk: {
                seq!(ni in $n0..$n1 {
                    if BUF {
                        pire_base::load_buf(c, d_arr[2], c_cs, &mut c_buf, m, ni, mr);
                        dim_arr[2] = mr*TC_SIZE;
                        cf = c_buf.as_mut_ptr();
                    }
                    if pire_base::n_cond!($n0, ni, n) {
                        asm!(
                            "ptrue p0.h",
                            "/* {m_s} */", "\n",
                            "/* {m_e} */", "\n",
                            prefetch_c!(),
                            vzero_kernel!(),

                            init_ab!($b_layout),

                            "cmp {x0}, #0", "BEQ 3f",

                            "2:",
                            prefetch_0!(128, "{bx}"),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",
                            "cmp {x0}, 0",
                            "BNE 2b",

                            "3:",
                            "ldr {x0}, [{dim_arrx}, #32]",
                            "cmp {x0}, #0",

                            "BEQ 5f",
                            "4:",
                            $step_macro!(ni, $b_layout),

                            "sub {x0}, {x0}, #1",

                            "cmp {x0}, 0",
                            "BNE 4b",

                            "5:",
                            c_load!(),
                            "cmp {alpha_st:w}, #0",
                            "BEQ 13f",
                            alpha_scale!(),
                            "13:",

                            "cmp {beta_st:w}, #0",
                            "BEQ 6f",

                            "cmp {beta_st:w}, #1",
                            "BEQ 9f",

                            load_beta!(),
                            pire_base::cum_seq!($acc_macro, ni, $is_partial, 2),
                            "B 6f",

                            "9:",
                            pire_base::cum_seq!($acc_macro, ni, $is_partial, 1),

                            "6:",
                            pire_base::cum_seq!($store_macro, ni, $is_partial),

                            ax = inout(reg) a => _,
                            bx = inout(reg) b => _,
                            cx = inout(reg) cf => _,
                            dim_arrx = inout(reg) dim_arr.as_ptr() => _,
                            alphax = inout(reg) alpha => _,
                            betax = inout(reg) beta => _,
                            incax = in(reg) inc_a as u64,
                            alpha_st = in(reg) alpha_st,
                            beta_st = in(reg) beta_st,
                            m_s = in(reg) 0 as u64,
                            m_e = in(reg) m_left as u64,
                            x0 = out(reg) _,
                            x1 = out(reg) _,
                            x2 = out(reg) _,
                            x3 = out(reg) _,
                            x4 = out(reg) _,
                            x5 = out(reg) _,
                            x6 = out(reg) _,
                            x7 = out(reg) _,
                            x8 = out(reg) _,
                            x9 = out(reg) _,
                            x10 = out(reg) _,
                            x11 = out(reg) _,
                            out("v0") _, out("v1") _, out("v2") _, out("v3") _, out("v4") _, out("v5") _, out("v6") _, out("v7") _,
                            out("v8") _, out("v9") _, out("v10") _, out("v11") _, out("v12") _, out("v13") _, out("v14") _, out("v15") _,
                            out("v16") _, out("v17") _, out("v18") _, out("v19") _, out("v20") _, out("v21") _, out("v22") _, out("v23") _,
                            out("v24") _, out("v25") _, out("v26") _, out("v27") _, out("v28") _, out("v29") _, out("v30") _, out("v31") _,
                            out("p0") _, out("p1") _, out("p2") _, out("p3") _,
                        );
                        break 'blk;
                    }
                });
            };
            if BUF {
                for j in 0..n {
                    f.call(cf.add(j*mr), mr);
                }
                pire_base::store_buf(c, d_arr[2], c_cs, &c_buf, m, n, mr);
            } else {
                for j in 0..n {
                    f.call(cf.add(j*c_cs), m);
                }
            }
        }
    };
}

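// Hypothetical usage sketch (the step/acc/store macro names, tile shape, and
// kernel name below are illustrative only, not definitions from this crate):
// a caller pairs one of these macros with per-shape step / accumulate / store
// macros and instantiates a concrete kernel, roughly as
//
//     def_ukernel_neon!(step_2x4, acc_2x4, store_2x4, 2, 4, 0, 4, B, C, ukernel_2x4_bb);
//
// The generated `ukernel_2x4_bb::<F, BUF>` then computes a `2*VS x 4` tile of
// `C = alpha*A*B + beta*C` from packed A/B panels, with `BUF == true`
// selecting the buffered write-back path for edge tiles.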