1use std::sync::{Barrier, Mutex, MutexGuard, RwLock, RwLockReadGuard};
2use once_cell::sync::Lazy;
4
5pub mod range_rwlock;
6
7pub fn matrix_size(rs: usize, cs: usize, m: usize, n: usize) -> usize {
8 m * rs + n * cs - (rs + cs) + 1
9}
10
11use range_rwlock::{RangeLock, RangeLockReadGuard, RangeLockWriteGuard};
12
13#[macro_export]
14macro_rules! env_or {
15 ($name:expr, $default:expr) => {
16 if let Some(value) = std::option_env!($name) {
17 const_str::parse!(value, usize)
18 } else {
19 $default
20 }
21 };
22}
23
24#[derive(Copy, Clone)]
25pub struct CpuFeatures {
26 pub avx: bool,
27 pub avx2: bool,
28 pub avx512f: bool,
29 pub avx512f16: bool,
30 pub avx512bw: bool,
32 pub avx512_vnni: bool,
33 pub fma: bool,
34 pub fma4: bool,
35 pub f16c: bool,
36}
37
38const CACHELINE_PAD: usize = 1024;
40
41#[cfg(target_arch = "x86_64")]
42pub struct HWConfig {
43 pub cpu_ft: CpuFeatures,
44 pub hw_model: HWModel,
45 is_l1_shared: bool,
46 is_l2_shared: bool,
47 is_l3_shared: bool,
48}
49
50impl HWConfig {
51 pub fn get_cache_info(&self) -> (bool, bool, bool) {
52 (self.is_l1_shared, self.is_l2_shared, self.is_l3_shared)
53 }
54 pub fn hw_model(&self) -> HWModel {
55 self.hw_model
56 }
57
58 pub fn cpu_ft(&self) -> CpuFeatures {
59 self.cpu_ft
60 }
61}
62
63#[cfg(target_arch = "aarch64")]
64pub struct HWConfig {
65 neon: bool,
66}
67
68#[derive(Copy, Clone)]
69pub enum HWModel {
70 Reference,
71 Haswell,
72 Skylake,
73}
74
75const SKYLAKE: [u8; 13] = [78, 85, 94, 126, 140, 141, 167, 151, 154, 183, 186, 143, 207];
76
77const HASWELL: [u8; 10] = [69, 70, 63, 42, 58, 165, 79, 86, 61, 71];
78
79impl HWModel {
80 pub fn from_hw(family_id: u8, model_id: u8) -> Self {
81 if family_id == 6 {
82 if SKYLAKE.contains(&model_id) {
83 return HWModel::Skylake;
84 }
85 if HASWELL.contains(&model_id) {
86 return HWModel::Haswell;
87 }
88 }
89
90 return HWModel::Reference;
92 }
93 pub fn get_cache_info(&self) -> (bool, bool, bool) {
94 match self {
95 HWModel::Reference => (false, false, true),
96 HWModel::Haswell => (false, false, true),
97 HWModel::Skylake => (false, false, true),
98 }
99 }
100}
101
102#[inline]
107fn detect_hw_config() -> HWConfig {
108 #[cfg(target_arch = "x86_64")]
109 {
110 let cpuid = raw_cpuid::CpuId::new();
111 let feature_info = cpuid.get_feature_info().unwrap();
112 let extended_feature_info = cpuid.get_extended_feature_info().unwrap();
113 let avx = feature_info.has_avx();
114 let fma = feature_info.has_fma();
115 let avx2 = extended_feature_info.has_avx2();
116 let avx512f16 = extended_feature_info.has_avx512_fp16();
117 let avx512f = extended_feature_info.has_avx512f();
119 let avx512bw = extended_feature_info.has_avx512bw();
120 let avx512_vnni = extended_feature_info.has_avx512vnni();
121 let f16c = feature_info.has_f16c();
122 let extended_prcoessor_info = cpuid.get_extended_processor_and_feature_identifiers().unwrap();
123 let fma4 = extended_prcoessor_info.has_fma4();
124 let cpu_ft = CpuFeatures { avx, avx2, avx512f, avx512f16, avx512bw, avx512_vnni, fma, fma4, f16c };
125 let family_id = feature_info.family_id();
126 let model_id = feature_info.model_id();
127 let hw_model = HWModel::from_hw(family_id, model_id);
128 let (is_l1_shared, is_l2_shared, is_l3_shared) = hw_model.get_cache_info();
129 return HWConfig { cpu_ft, hw_model, is_l1_shared, is_l2_shared, is_l3_shared };
130 }
131 #[cfg(target_arch = "aarch64")]
132 {
133 return HWConfig { neon: true };
134 }
135}
136
137pub static RUNTIME_HW_CONFIG: Lazy<HWConfig> = Lazy::new(|| detect_hw_config());
138
139pub static GLAR_NUM_THREADS: Lazy<usize> = Lazy::new(|| {
140 let n_core = std::thread::available_parallelism().unwrap().get();
141 let x = std::env::var("GLAR_NUM_THREADS").unwrap_or(n_core.to_string());
143 x.parse::<usize>().unwrap()
144});
145#[cfg(target_arch = "x86_64")]
146pub(crate) mod cpu_features {
147 use super::HWModel;
148 use super::RUNTIME_HW_CONFIG;
149
150 pub fn hw_model() -> HWModel {
151 RUNTIME_HW_CONFIG.hw_model
152 }
153
154 pub fn has_f32_compute() -> bool {
155 RUNTIME_HW_CONFIG.cpu_ft.avx
159 }
160
161 pub fn has_f16f32_compute() -> bool {
162 RUNTIME_HW_CONFIG.cpu_ft.avx && RUNTIME_HW_CONFIG.cpu_ft.f16c
166 }
167 pub fn has_f64_compute() -> bool {
168 RUNTIME_HW_CONFIG.cpu_ft.avx
169 }
170 pub fn has_f16_compute() -> bool {
171 RUNTIME_HW_CONFIG.cpu_ft.avx512f16
175 && RUNTIME_HW_CONFIG.cpu_ft.avx
176 && RUNTIME_HW_CONFIG.cpu_ft.f16c
177 && RUNTIME_HW_CONFIG.cpu_ft.fma
178 }
179 pub fn has_i16i32_compute() -> bool {
180 RUNTIME_HW_CONFIG.cpu_ft.avx2 && RUNTIME_HW_CONFIG.cpu_ft.avx
181 }
182 pub fn has_i8i32_compute() -> bool {
183 RUNTIME_HW_CONFIG.cpu_ft.avx2 && RUNTIME_HW_CONFIG.cpu_ft.avx
184 }
185 pub fn get_cache_params() -> (usize, usize, usize) {
187 (4800, 256, 128)
188 }
189}
190#[cfg(target_arch = "aarch64")]
191pub(crate) mod cpu_features {
192 use super::RUNTIME_HW_CONFIG;
193 pub fn hw_neon() -> bool {
194 RUNTIME_HW_CONFIG.neon
195 }
196}
197pub use cpu_features::*;
198
199pub struct PackPool {
200 pub buffer: RwLock<Vec<Mutex<Vec<u8>>>>,
201}
202
203pub static PACK_POOL: PackPool = PackPool { buffer: RwLock::new(vec![]) };
204
205pub fn acquire<'a>(
206 pool_guard: &'a RwLockReadGuard<'a, Vec<Mutex<Vec<u8>>>>,
207 pack_size: usize,
208) -> Option<MutexGuard<'a, Vec<u8>>> {
209 for i in pool_guard.iter() {
212 let lock = i.try_lock();
219 if let Ok(mutex) = lock {
220 if mutex.len() >= pack_size {
221 return Some(mutex);
222 }
223 }
224 }
225
226 None
227}
228
229pub fn extend<'a>(pool_vec: Vec<u8>) {
230 let mut pool_guard = PACK_POOL.buffer.write().unwrap();
231 pool_guard.push(Mutex::new(pool_vec));
232}
233
234pub struct GlarThreadConfig<'a> {
235 pub ic_id: usize,
236 pub jc_id: usize,
238 pub ir_id: usize,
239 pub jr_id: usize,
240 pub i_load_p_idx: usize,
241 pub j_load_p_idx: usize,
242 pub mc_eff: usize,
243 pub nc_eff: usize,
244 pub kc_eff: usize,
245 pub par: GlarPar,
246 pub packa_barrier: &'a [Barrier],
247 pub packb_barrier: &'a [Barrier],
248}
249
250pub fn get_apbp_barrier(par: &GlarPar) -> (Vec<Barrier>, Vec<Barrier>) {
251 let mut packa_barrier = vec![];
252 for _ in 0..par.ic_par {
253 let barrier = Barrier::new(par.jc_par * par.pc_par * par.ir_par * par.jr_par);
254 packa_barrier.push(barrier);
255 }
256
257 let mut packb_barrier = vec![];
258 for _ in 0..par.jc_par {
259 let barrier = Barrier::new(par.ic_par * par.pc_par * par.ir_par * par.jr_par);
260 packb_barrier.push(barrier);
261 }
262
263 (packa_barrier, packb_barrier)
264}
265
266impl<'a> GlarThreadConfig<'a> {
267 pub fn new(
268 par: GlarPar,
269 packa_barrier: &'a [Barrier],
270 packb_barrier: &'a [Barrier],
271 t_id: usize,
272 mc_eff: usize,
273 nc_eff: usize,
274 kc_eff: usize,
275 ) -> Self {
276 let ic_id = par.get_ic_id(t_id);
277 let jc_id = par.get_jc_id(t_id);
279 let ir_id = par.get_ir_id(t_id);
280 let jr_id = par.get_jr_id(t_id);
281 let i_load_p_idx = jc_id * par.ir_par * par.jr_par + ir_id * par.jr_par + jr_id;
282 let j_load_p_idx = ic_id * par.ir_par * par.jr_par + ir_id * par.jr_par + jr_id;
283
284 Self {
285 ic_id,
286 jc_id,
288 ir_id,
289 jr_id,
290 i_load_p_idx,
291 j_load_p_idx,
292 mc_eff,
293 nc_eff,
294 kc_eff,
295 par,
296 packa_barrier,
297 packb_barrier,
298 }
299 }
300 #[inline]
301 pub fn wait_packa(&self) {
302 if self.par.jc_par * self.par.pc_par * self.par.ir_par * self.par.jr_par > 1 {
303 self.packa_barrier[self.ic_id].wait();
304 }
305 }
306
307 #[inline]
308 pub fn wait_packb(&self) {
309 if self.par.ic_par * self.par.pc_par * self.par.ir_par * self.par.jr_par > 1 {
310 self.packb_barrier[self.jc_id].wait();
311 }
312 }
313}
314
315pub fn check_mem_size(mem_size: usize, rs: usize, cs: usize, m: usize, n: usize) {
316 assert!(mem_size >= rs * cs * m * n);
317 assert!(rs >= 1 && cs >= 1 && m >= 0 && n >= 0);
318}
319
320#[inline(always)]
322pub fn glar_num_threads() -> usize {
323 return *GLAR_NUM_THREADS;
324}
325
326#[derive(Copy, Clone)]
327pub struct GlarPar {
328 pub num_threads: usize,
329 pub ic_par: usize,
330 pub pc_par: usize,
331 pub jc_par: usize,
332 pub ir_par: usize,
333 pub jr_par: usize,
334}
335
336#[inline(always)]
339fn inc_par(ic_par: usize, jc_par: usize, ic_max: usize, jc_max: usize, factor: usize) -> (usize, usize, usize, usize) {
340 if (ic_par < jc_par && ic_par < ic_max) || (jc_par >= jc_max && ic_par < ic_max) {
341 (ic_par * factor, jc_par, ic_max / factor, jc_max)
342 } else if (ic_par >= jc_par && jc_par < jc_max) || (ic_par >= ic_max && jc_par < jc_max) {
343 (ic_par, jc_par * factor, ic_max, jc_max / factor)
344 } else {
345 (ic_par, jc_par, ic_max, jc_max)
346 }
347}
348impl GlarPar {
349 pub fn new(num_threads: usize, ic_par: usize, pc_par: usize, jc_par: usize, ir_par: usize, jr_par: usize) -> Self {
350 assert_eq!(num_threads, jc_par * pc_par * ic_par * jr_par * ir_par);
351 Self { num_threads, ic_par, pc_par, jc_par, ir_par, jr_par }
352 }
353 pub fn from_num_threads(num_threads: usize, m: usize, n: usize) -> Self {
354 let mut num_threads = num_threads;
355 let mut ic_par_max = if m < 96 {
356 1
357 } else if m < 400 {
358 2
359 } else {
360 m / 200
361 };
362 let mut jc_par_max = if n < 48 {
363 1
364 } else if n < 200 {
365 2
366 } else {
367 n / 100
368 };
369
370 if num_threads <= 12 {
371 let jc_par_max = jc_par_max.min(num_threads);
372 let n_thread = (num_threads / jc_par_max) * jc_par_max;
373 return Self::new(n_thread, num_threads / jc_par_max, 1, jc_par_max, 1, 1);
374 }
375 num_threads = num_threads.min(ic_par_max * jc_par_max);
377 let mut ic_par = 1;
378 let pc_par = 1;
379 let mut jc_par = 1;
380 let mut ir_par = 1;
381 let jr_par = 1;
382
383 while num_threads > 1 {
384 if num_threads % 2 == 0 {
385 num_threads = num_threads / 2;
386 (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 2);
387 } else if num_threads % 3 == 0 {
388 num_threads = num_threads / 3;
389 (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 3);
390 } else if num_threads % 5 == 0 {
391 num_threads = num_threads / 5;
392 (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 5);
393 continue;
394 } else if num_threads % 7 == 0 {
395 num_threads = num_threads / 7;
396 (ic_par, jc_par, ic_par_max, jc_par_max) = inc_par(ic_par, jc_par, ic_par_max, jc_par_max, 7);
397 continue;
398 } else {
399 num_threads = num_threads / 2 * 2;
402 }
403 }
419 if ic_par >= 8 {
420 ic_par = ic_par / 2;
421 ir_par = 2;
422 }
423 let num_threads = ic_par * pc_par * jc_par * ir_par * jr_par;
424 Self { num_threads, ic_par, pc_par, jc_par, ir_par, jr_par }
425 }
426 #[inline(always)]
427 pub fn default(m: usize, n: usize) -> Self {
428 let num_threads = glar_num_threads();
429 Self::from_num_threads(num_threads, m, n)
430 }
431 #[inline]
432 fn get_ic_id(&self, t_id: usize) -> usize {
433 (t_id / (self.pc_par * self.jc_par * self.ir_par * self.jr_par)) % self.ic_par
434 }
435
436 #[inline]
441 fn get_jc_id(&self, t_id: usize) -> usize {
442 (t_id / (self.jr_par * self.ir_par)) % self.jc_par
443 }
444 #[inline]
445 fn get_jr_id(&self, t_id: usize) -> usize {
446 (t_id / self.ir_par) % self.jr_par
447 }
448 #[inline]
449 fn get_ir_id(&self, t_id: usize) -> usize {
450 t_id % self.ir_par
451 }
452
453 pub fn get_load_par(
454 &self,
455 gemm_mode: &GemmPool,
456 m: usize,
457 n: usize,
458 mc_eff: usize,
459 nc_eff: usize,
460 ) -> (usize, usize) {
461 let m = (m / self.ic_par).min(mc_eff);
462 let n = (n / self.jc_par).min(nc_eff);
463 let i_load_par = ((m + 127) / 128).min(self.num_threads / self.ic_par);
464 let j_load_par = ((n + 127) / 128).min(self.num_threads / self.jc_par);
465 let i_load_par = match gemm_mode {
466 GemmPool::Goto => i_load_par,
467 GemmPool::SmallM => i_load_par,
468 GemmPool::SmallN => 1,
469 };
470 (i_load_par.max(1), j_load_par.max(1))
471 }
472}
473
474#[inline]
475pub fn split_c_range(m: usize, mc: usize, mr: usize, ic_id: usize, ic_par: usize) -> (usize, usize, bool) {
476 let chunk_len = (m / (mr * ic_par)) * mr;
477 let rem = m % (mr * ic_par);
478 if ic_id == 0 {
479 let x = chunk_len + rem % mr;
480 let mc_left = ((((x + mc - 1) / mc) * mc) * ic_par) < m;
481 return (m - chunk_len - (rem % mr), m, mc_left);
482 }
483 let ic_id = ic_id - 1;
484 let m0 = (m / mr) * mr;
485 let rem = m0 % (mr * ic_par);
486 let start_delta = rem.min(ic_id * mr);
487 let end_delta = rem.min((ic_id + 1) * mr);
488 let mc_coeff = (chunk_len + end_delta - start_delta + mc - 1) / mc;
490 let mc_left = ((mc_coeff * mc) * ic_par) < m;
491 (chunk_len * ic_id + start_delta, chunk_len * (ic_id + 1) + end_delta, mc_left)
493}
494
495#[inline]
496pub fn split_range(range_len: usize, unit_len: usize, r_id: usize, r_par: usize) -> (usize, usize) {
497 let chunk_start = (range_len / (unit_len * r_par)) * unit_len * r_id;
498 let chunk_end = (range_len / (unit_len * r_par)) * unit_len * (r_id + 1);
499 let rem = range_len % (unit_len * r_par);
500 let rem = rem - rem % unit_len;
501 let rem_start = rem.min(r_id * unit_len);
502 let rem_end = rem.min((r_id + 1) * unit_len);
503 if r_id == r_par - 1 {
504 return (chunk_start + rem_start, range_len);
505 }
506 (chunk_start + rem_start, chunk_end + rem_end)
507}
508
509pub trait BaseNum: Copy + 'static + Send {}
510
511impl<T> BaseNum for T where T: Copy + 'static + Send {}
512
513#[derive(Copy, Clone)]
514pub struct PoolSize {
515 pub m: usize,
516 pub n: usize,
517 pub k: usize,
518 pub ap_pool_size: usize,
519 pub ap_pool_multiplicity: usize,
520 pub bp_pool_size: usize,
521 pub bp_pool_multiplicity: usize,
522}
523
524impl PoolSize {
525 pub fn mem_pool_size_b<TA, TB>(&self) -> usize {
527 self.ap_pool_size * std::mem::size_of::<TA>() * self.ap_pool_multiplicity
529 + self.bp_pool_size * std::mem::size_of::<TB>() * self.bp_pool_multiplicity
530 + 2 * AB_ALIGN
531 }
532
533 pub fn ap_size_b<TA>(&self) -> usize {
534 self.ap_pool_size * std::mem::size_of::<TA>()
535 }
536
537 pub fn bp_size_b<TB>(&self) -> usize {
538 self.bp_pool_size * std::mem::size_of::<TB>()
539 }
540
541 pub fn ap_size_t_b<TA>(&self) -> usize {
542 self.ap_pool_size * std::mem::size_of::<TA>() * self.ap_pool_multiplicity
543 }
544
545 pub fn bp_size_t_b<TB>(&self) -> usize {
546 self.bp_pool_size * std::mem::size_of::<TB>() * self.bp_pool_multiplicity
547 }
548
549 pub fn slice_mut_from_pool<TA, TB>(
550 &self,
551 mem_pool: &mut [u8],
552 i_load_par: usize,
553 j_load_par: usize,
554 pool_size: PoolSize,
555 mr: usize,
556 nr: usize,
557 ) -> (Vec<RangeLock<'_, TA>>, Vec<RangeLock<'_, TB>>) {
559 let m_size = pool_size.m;
560 let n_size = pool_size.n;
561 let k_size = pool_size.k;
562 let ap_pool_size = self.ap_pool_size;
563 let ap_pool_size_b = ap_pool_size * std::mem::size_of::<TA>();
564 let a_alignment = std::mem::align_of::<TA>();
565 assert_eq!(ap_pool_size_b % a_alignment, 0);
566 let bp_pool_size = self.bp_pool_size;
567 let bp_pool_size_b = bp_pool_size * std::mem::size_of::<TB>();
568 let b_alignment = std::mem::align_of::<TB>();
569 assert_eq!(bp_pool_size_b % b_alignment, 0);
570 let mut ap = vec![];
571 let mut bp = vec![];
572 assert!(mem_pool.len() >= self.mem_pool_size_b::<TA, TB>());
575 let align_offset = mem_pool.as_ptr().align_offset(AB_ALIGN);
577 let mut mem_pool = &mut mem_pool[align_offset..];
578 assert_eq!(mem_pool.as_ptr().align_offset(a_alignment), 0);
580 for _ in 0..self.ap_pool_multiplicity {
581 let (a, rest) = mem_pool.split_at_mut(ap_pool_size_b);
582 let ap_pool = unsafe { std::slice::from_raw_parts_mut::<TA>(a.as_mut_ptr() as *mut TA, ap_pool_size) };
583 if ap_pool_size == 0 {
584 ap.push(RangeLock::from(ap_pool, i_load_par, 0, k_size, mr));
585 } else {
586 ap.push(RangeLock::from(ap_pool, i_load_par, m_size, k_size, mr));
587 }
588 mem_pool = rest;
589 }
590 let align_offset = mem_pool.as_ptr().align_offset(AB_ALIGN);
591 let mut mem_pool = &mut mem_pool[align_offset..];
592 assert_eq!(mem_pool.as_ptr().align_offset(b_alignment), 0);
594 for _ in 0..self.bp_pool_multiplicity {
595 let (b, rest) = mem_pool.split_at_mut(bp_pool_size_b);
596 let bp_pool = unsafe { std::slice::from_raw_parts_mut::<TB>(b.as_mut_ptr() as *mut TB, bp_pool_size) };
597 if bp_pool_size == 0 {
598 bp.push(RangeLock::from(bp_pool, j_load_par, 0, k_size, nr));
599 } else {
600 bp.push(RangeLock::from(bp_pool, j_load_par, n_size, k_size, nr));
601 }
602 mem_pool = rest;
603 }
604 (ap, bp)
605 }
606}
607
608pub fn get_mem_pool_size_goto<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
609 hw_config: &HWConfig,
610 par: &GlarPar,
611 a_need_pool: bool,
612 b_need_pool: bool,
613) -> PoolSize {
614 let m = hw_config.get_mc_eff(par.ic_par);
615 let n = hw_config.get_nc_eff(par.jc_par);
616 let k = hw_config.get_kc_eff();
617 let (ap_pool_size, ap_pool_multiplicity) = if a_need_pool {
618 let ap_pool_multiplicity = par.ic_par;
619 let ap_pool_size = hw_config.get_ap_pool_size(par.ic_par) + CACHELINE_PAD / std::mem::size_of::<AP>();
620 (ap_pool_size, ap_pool_multiplicity)
621 } else {
622 (0, 1)
623 };
624 let (bp_pool_size, bp_pool_multiplicity) = if b_need_pool {
625 let bp_pool_multiplicity = par.jc_par;
626 let bp_pool_size = hw_config.get_bp_pool_size(par.jc_par) + CACHELINE_PAD / std::mem::size_of::<BP>();
627 (bp_pool_size, bp_pool_multiplicity)
628 } else {
629 (0, 1)
630 };
631 PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size, bp_pool_multiplicity }
632}
633
634pub fn get_mem_pool_size_small_m<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
635 hw_config: &HWConfig,
636 par: &GlarPar,
637 a_need_pool: bool,
638) -> PoolSize {
639 let m = hw_config.get_mc_eff(par.ic_par);
640 let n = hw_config.get_nc_eff(par.jc_par);
641 let k = hw_config.get_kc_eff();
642 if a_need_pool {
643 let ap_pool_multiplicity = par.ic_par;
644 let ap_pool_size = hw_config.get_ap_pool_size(par.ic_par) + CACHELINE_PAD / std::mem::size_of::<AP>();
645 PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size: 0, bp_pool_multiplicity: 1 }
646 } else {
647 PoolSize { m, n, k, ap_pool_size: 0, ap_pool_multiplicity: 1, bp_pool_size: 0, bp_pool_multiplicity: 1 }
648 }
649}
650
651pub fn get_mem_pool_size_small_n<AP: BaseNum, BP: BaseNum, HWConfig: GemmCache>(
652 hw_config: &HWConfig,
653 par: &GlarPar,
654 b_need_pool: bool,
655) -> PoolSize {
656 let ap_pool_size = hw_config.get_ap_pool_size2() + CACHELINE_PAD / std::mem::size_of::<AP>();
657 let ap_pool_multiplicity = par.num_threads;
658 let m = hw_config.mr();
659 let n = hw_config.get_nc_eff(par.jc_par);
660 let k = hw_config.get_kc_eff();
661 if b_need_pool {
662 let bp_pool_multiplicity = par.jc_par;
663 let bp_pool_size = hw_config.get_bp_pool_size(par.jc_par) + CACHELINE_PAD / std::mem::size_of::<BP>();
664 PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size, bp_pool_multiplicity }
665 } else {
666 PoolSize { m, n, k, ap_pool_size, ap_pool_multiplicity, bp_pool_size: 0, bp_pool_multiplicity: 1 }
667 }
668}
669
670pub fn run_small_m(m: usize) -> bool {
678 m < 144
679}
680
681pub fn run_small_n(n: usize) -> bool {
682 n < 144
683}
684
685pub enum GemmPool {
686 Goto,
687 SmallM,
688 SmallN,
689}
690
691pub fn ap_size<T>(m: usize, k: usize) -> usize {
692 let vs = 64 / std::mem::size_of::<T>();
693 let m_max = (m + vs - 1) / vs * vs;
694 m_max * k + AB_ALIGN / std::mem::size_of::<T>()
695}
696
697pub fn bp_size<T>(n: usize, k: usize) -> usize {
698 n * k + AB_ALIGN / std::mem::size_of::<T>()
699}
700
701pub fn ap_size_int<T, P>(m: usize, k: usize) -> usize {
702 let vs = 64 / std::mem::size_of::<T>();
703 let c_r = std::mem::size_of::<P>() / std::mem::size_of::<T>();
704 let k_r = (k + c_r - 1) / c_r * c_r;
705 let m_max = (m + vs - 1) / vs * vs;
706 m_max * k_r + AB_ALIGN / std::mem::size_of::<T>()
707}
708
709pub fn bp_size_int<T, P>(n: usize, k: usize) -> usize {
710 let c_r = std::mem::size_of::<P>() / std::mem::size_of::<T>();
711 let k_r = (k + c_r - 1) / c_r * c_r;
712 n * k_r + AB_ALIGN / std::mem::size_of::<T>()
713}
714
715#[derive(Clone, Copy)]
716pub struct StridedMatrix<T> {
717 pub(crate) src: *const T,
718 pub(crate) rs: usize,
719 pub(crate) cs: usize,
720}
721
722impl<T> StridedMatrix<T> {
723 pub fn new(src: *const T, rs: usize, cs: usize) -> Self {
724 Self { src, rs, cs }
725 }
726}
727
728unsafe impl<T> Send for StridedMatrix<T> {}
729
730#[derive(Clone, Copy)]
731pub struct StridedMatrixMut<T> {
732 pub(crate) src: *mut T,
733 pub(crate) rs: usize,
734 pub(crate) cs: usize,
735}
736
737unsafe impl<T> Send for StridedMatrixMut<T> {}
738
739impl<T> StridedMatrixMut<T> {
740 pub fn new(src: *mut T, rs: usize, cs: usize) -> Self {
741 Self { src, rs, cs }
742 }
743}
744
745#[derive(Clone)]
746pub struct StridedMatrixP<'a, T, U> {
747 pub(crate) src: *const T,
748 pub(crate) rs: usize,
749 pub(crate) cs: usize,
750 pub(crate) dst: &'a RangeLock<'a, U>,
751}
752
753unsafe impl<'a, T, U> Send for StridedMatrixP<'a, T, U> {}
754
755impl<'a, T, U> StridedMatrixP<'a, T, U> {
756 pub fn src(&self) -> *const T {
757 self.src
758 }
759 pub fn dst_w(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, U> {
760 self.dst.write(idx, kc).unwrap()
761 }
762 pub fn dst_r(&self) -> RangeLockReadGuard<'a, 'a, U> {
763 self.dst.read().unwrap()
764 }
765 pub fn get_mc(&self) -> usize {
766 self.dst.get_mc()
767 }
768 pub fn rs(&self) -> usize {
769 self.rs
770 }
771 pub fn cs(&self) -> usize {
772 self.cs
773 }
774}
775
776#[derive(Clone, Copy)]
777pub struct PackedMatrix<T> {
778 pub(crate) src: *const T,
779 pub(crate) k: usize,
780 pub(crate) m: usize,
781}
782
783unsafe impl<T> Send for PackedMatrix<T> {}
784
785impl<T> PackedMatrix<T> {
786 pub fn src(&self) -> *const T {
787 self.src
788 }
789 pub fn k(&self) -> usize {
790 self.k
791 }
792 pub fn m(&self) -> usize {
793 self.m
794 }
795}
796
797#[derive(Clone)]
798pub struct PackedMatrixMixed<'a, X, Y> {
799 pub(crate) src: *const X,
800 pub(crate) dst: &'a RangeLock<'a, Y>,
801 pub(crate) k: usize,
802 pub(crate) m: usize,
803}
804
805impl<'a, X, Y> PackedMatrixMixed<'a, X, Y> {
806 pub fn src(&self) -> *const X {
807 self.src
808 }
809 pub fn k(&self) -> usize {
810 self.k
811 }
812 pub fn m(&self) -> usize {
813 self.m
814 }
815
816 pub fn dst_w(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, Y> {
817 self.dst.write(idx, kc).unwrap()
818 }
819
820 pub fn get_mc(&self) -> usize {
821 self.dst.get_mc()
822 }
823
824 pub fn dst_r(&self) -> RangeLockReadGuard<'a, 'a, Y> {
825 self.dst.read().unwrap()
826 }
827}
828
829unsafe impl<X, Y> Send for PackedMatrixMixed<'_, X, Y> {}
830
831pub const AB_ALIGN: usize = 1024;
834
835pub trait GemmCache {
836 fn mr(&self) -> usize;
837 fn nr(&self) -> usize;
838 fn get_mc_eff(&self, par: usize) -> usize;
839 fn get_kc_eff(&self) -> usize;
840 fn get_nc_eff(&self, par: usize) -> usize;
841 fn get_ap_pool_size(&self, ic_par: usize) -> usize {
842 let mc_eff = self.get_mc_eff(ic_par);
843 let kc_eff = self.get_kc_eff();
844 mc_eff * kc_eff
845 }
846 fn get_ap_pool_size2(&self) -> usize {
847 let kc_eff = self.get_kc_eff();
848 self.mr() * kc_eff
849 }
850 fn get_bp_pool_size(&self, jc_par: usize) -> usize {
851 let nc_eff = self.get_nc_eff(jc_par);
852 let kc_eff = self.get_kc_eff();
853 nc_eff * kc_eff
854 }
855}
856
857#[derive(Copy, Clone)]
858pub enum Array<X> {
859 StridedMatrix(StridedMatrix<X>),
860 PackedMatrix(PackedMatrix<X>),
861}
862
863impl<X> Array<X> {
864 pub fn strided_matrix(src: *const X, rs: usize, cs: usize) -> Self {
865 Array::StridedMatrix(StridedMatrix::new(src, rs, cs))
866 }
867 pub fn packed_matrix(src: *const X, m: usize, k: usize) -> Self {
868 Array::PackedMatrix(PackedMatrix { src, k, m })
869 }
870 pub fn into_pack_array<'a>(&self, a: &'a [RangeLock<'a, X>], p_id: usize) -> PArray<'a, X> {
871 match self {
872 Array::StridedMatrix(x) => {
873 let x = StridedMatrixP { src: x.src, rs: x.rs, cs: x.cs, dst: &a[p_id] };
874 PArray::<X>::StridedMatrix(x)
875 }
876 Array::PackedMatrix(x) => {
877 let x = PackedMatrix { src: x.src, k: x.k, m: x.m };
878 PArray::PackedMatrix(x)
879 }
880 }
881 }
882 pub fn into_pack_array2<'a, Y>(&self, a: &'a [RangeLock<'a, Y>], p_id: usize) -> PArrayMixed<'a, X, Y> {
883 match self {
884 Array::StridedMatrix(x) => {
885 let x = StridedMatrixP { src: x.src, rs: x.rs, cs: x.cs, dst: &a[p_id] };
886 PArrayMixed::<X, Y>::StridedMatrix(x)
887 }
888 Array::PackedMatrix(x) => {
889 let x = PackedMatrixMixed { src: x.src, dst: &a[p_id], k: x.k, m: x.m };
890 PArrayMixed::PackedMatrix(x)
891 }
892 }
893 }
894
895 pub fn src(&self) -> *const X {
896 match self {
897 Array::StridedMatrix(x) => x.src,
898 Array::PackedMatrix(x) => x.src,
899 }
900 }
901
902 pub fn transpose(&mut self) {
903 match self {
904 Array::StridedMatrix(x) => {
905 let temp = x.rs;
906 x.rs = x.cs;
907 x.cs = temp;
908 }
909 _ => {
910 panic!("Only StridedMatrix has transpose");
911 }
912 }
913 }
914
915 pub fn rs(&self) -> usize {
916 match self {
917 Array::StridedMatrix(x) => x.rs,
918 _ => {
919 panic!("Only StridedMatrix has rs");
920 }
921 }
922 }
923
924 pub fn cs(&self) -> usize {
925 match self {
926 Array::StridedMatrix(x) => x.cs,
927 _ => {
928 panic!("Only StridedMatrix has cs");
929 }
930 }
931 }
932
933 pub fn is_strided(&self) -> bool {
934 match self {
935 Array::StridedMatrix(_) => true,
936 _ => false,
937 }
938 }
939}
940
941#[derive(Copy, Clone)]
942pub enum ArrayMut<X> {
943 StridedMatrix(StridedMatrixMut<X>),
944}
945
946impl<X> ArrayMut<X> {
947 pub fn strided_matrix(src: *mut X, rs: usize, cs: usize) -> Self {
948 ArrayMut::StridedMatrix(StridedMatrixMut::new(src, rs, cs))
949 }
950
951 pub fn src(&self) -> *mut X {
952 match self {
953 ArrayMut::StridedMatrix(x) => x.src,
954 }
955 }
956
957 pub fn transpose(&mut self) {
958 match self {
959 ArrayMut::StridedMatrix(x) => {
960 let temp = x.rs;
961 x.rs = x.cs;
962 x.cs = temp;
963 }
964 }
965 }
966
967 pub fn rs(&self) -> usize {
968 match self {
969 ArrayMut::StridedMatrix(x) => x.rs,
970 }
971 }
972
973 pub fn cs(&self) -> usize {
974 match self {
975 ArrayMut::StridedMatrix(x) => x.cs,
976 }
977 }
978}
979
980#[derive(Clone)]
981pub enum PArray<'a, X> {
982 StridedMatrix(StridedMatrixP<'a, X, X>),
983 PackedMatrix(PackedMatrix<X>),
984}
985
986impl<'a, X> PArray<'a, X> {
987 pub fn src(&self) -> *const X {
988 match self {
989 Self::StridedMatrix(x) => x.src,
990 Self::PackedMatrix(x) => x.src,
991 }
992 }
993
994 pub fn rs(&self) -> usize {
995 match self {
996 Self::StridedMatrix(x) => x.rs,
997 _ => {
998 panic!("Only StridedMatrix has rs");
999 }
1000 }
1001 }
1002
1003 pub fn cs(&self) -> usize {
1004 match self {
1005 Self::StridedMatrix(x) => x.cs,
1006 _ => {
1007 panic!("Only StridedMatrix has cs");
1008 }
1009 }
1010 }
1011
1012 pub fn dst_w(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, X> {
1013 match self {
1014 Self::StridedMatrix(x) => x.dst.write(idx, kc).unwrap(),
1015 _ => {
1016 panic!("Only StridedMatrix has write guard");
1017 }
1018 }
1019 }
1020
1021 pub fn dst_r(&self) -> RangeLockReadGuard<'a, 'a, X> {
1022 match self {
1023 Self::StridedMatrix(x) => x.dst.read().unwrap(),
1024 _ => {
1025 panic!("Only StridedMatrix has read guard");
1026 }
1027 }
1028 }
1029
1030 pub fn is_strided(&self) -> bool {
1031 match self {
1032 Self::StridedMatrix(_) => true,
1033 _ => false,
1034 }
1035 }
1036}
1037
1038#[derive(Clone)]
1039pub enum PArrayMixed<'a, X, Y> {
1040 StridedMatrix(StridedMatrixP<'a, X, Y>),
1041 PackedMatrix(PackedMatrixMixed<'a, X, Y>),
1042}
1043
1044impl<'a, X, Y> PArrayMixed<'a, X, Y> {
1045 pub fn src(&self) -> *const X {
1046 match self {
1047 Self::StridedMatrix(x) => x.src,
1048 Self::PackedMatrix(x) => x.src,
1049 }
1050 }
1051
1052 pub fn rs(&self) -> usize {
1053 match self {
1054 Self::StridedMatrix(x) => x.rs,
1055 _ => {
1056 panic!("Only StridedMatrix has rs");
1057 }
1058 }
1059 }
1060
1061 pub fn cs(&self) -> usize {
1062 match self {
1063 Self::StridedMatrix(x) => x.cs,
1064 _ => {
1065 panic!("Only StridedMatrix has cs");
1066 }
1067 }
1068 }
1069
1070 pub fn dst_w(&self, idx: usize, kc: usize) -> RangeLockWriteGuard<'a, 'a, Y> {
1071 match self {
1072 Self::StridedMatrix(x) => x.dst.write(idx, kc).unwrap(),
1073 Self::PackedMatrix(x) => x.dst.write(idx, kc).unwrap(),
1074 }
1075 }
1076 pub fn dst_r(&self) -> RangeLockReadGuard<'a, 'a, Y> {
1077 match self {
1078 Self::StridedMatrix(x) => x.dst.read().unwrap(),
1079 Self::PackedMatrix(x) => x.dst.read().unwrap(),
1080 }
1081 }
1082 pub fn is_strided(&self) -> bool {
1083 match self {
1084 Self::StridedMatrix(_) => true,
1085 _ => false,
1086 }
1087 }
1088}
1089
1090pub enum PtrData<'a, X> {
1091 RefData(RangeLockReadGuard<'a, 'a, X>),
1092 PtrData(*const X),
1093}
1094
1095impl<'a, X> PtrData<'a, X> {
1096 pub fn src(&self) -> *const X {
1097 match self {
1098 PtrData::RefData(x) => x.get().as_ptr(),
1099 PtrData::PtrData(x) => x.clone(),
1100 }
1101 }
1102}
1103
1104#[macro_export]
1105macro_rules! is_mixed {
1106 (T, $st1:expr, $st2:expr) => {
1107 $st1
1108 };
1109 (F, $src:expr, $st2:expr) => {
1110 $st2
1111 };
1112}
1113
1114#[macro_export]
1115macro_rules! def_pa {
1116 ($packa_ty:tt,F,$ta:tt,$tap:tt) => {
1117 type $packa_ty<'a> = PArray<'a, $tap>;
1118 };
1119 ($packa_ty:tt,T,$ta:tt,$tap:tt) => {
1120 type $packa_ty<'a> = PArrayMixed<'a, $ta, $tap>;
1121 };
1122}
1123
1124#[macro_export]
1125macro_rules! def_glar_gemm {
1126 (
1127 $t_dispatcher:tt,
1128 $ta:tt,$tap:ty,$tb:ty,$tbp:ty,$tc:ty,$t_as:ty,$t_bs:ty,
1129 $packa_ty:tt,$packb_ty:tt,
1130 $one:expr,
1131 $name:ident, $name_mt:ident,
1132 $goto_name:ident, $goto_kernel:ident,
1133 $small_m_name:ident, $small_m_kernel:ident,
1134 $small_n_name:ident, $small_n_kernel:ident,
1135 $gemv_name:ident, $gemv_name2:ident,
1136 $packa_name:ident, $packb_name:ident,
1137 $run_small_m:expr, $run_small_n:expr,
1138 $pack_fn:tt, $include_flag:tt,
1139 ) => {
1140 def_pa!($packa_ty,$include_flag,$ta,$tap);
1141 def_pa!($packb_ty,$include_flag,$tb,$tbp);
1142 pub unsafe fn $name <F:MyFn>(
1143 hw_config: &$t_dispatcher <F>,
1144 m: usize, n: usize, k: usize,
1145 alpha: $t_as,
1146 a: Array<$ta>,
1147 b: Array<$tb>,
1148 beta: $t_bs,
1149 c: ArrayMut<$tc>,
1150 par: &GlarPar,
1151 )
1152 {
1153 let a_need_pool = a.is_strided() || !hw_config.is_compute_native();
1154 let b_need_pool = b.is_strided() || !hw_config.is_compute_native();
1155 if n == 1 && a.is_strided() {
1156 let alpha = &alpha as *const $t_as;
1157 let beta = &beta as *const $t_bs;
1158 $gemv_name(hw_config, m, k, alpha, a, b, beta, c);
1159 return;
1160 }
1161 if m == 1 && b.is_strided() {
1162 let alpha = &alpha as *const $t_as;
1163 let beta = &beta as *const $t_bs;
1164 let mut a = a;
1165 a.transpose();
1166 let mut b = b;
1167 b.transpose();
1168 let mut c = c;
1169 c.transpose();
1170 $gemv_name2(hw_config, n, k, alpha.into(), b, a, beta, c);
1171 return;
1172 }
1173 let (gemm_mode, gemm_fun, pool_info)
1174 : (
1175 GemmPool, unsafe fn(
1176 &$t_dispatcher <F>, usize, usize, usize, *const $t_as, $packa_ty, $packb_ty, *const $t_bs, ArrayMut<$tc>, &GlarThreadConfig,
1177 ),
1178 PoolSize
1179 )
1180 = if run_small_m(m) && $run_small_m && b.is_strided() {
1181 (GemmPool::SmallM, $small_m_name, get_mem_pool_size_small_m::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, a_need_pool))
1182 } else if run_small_n(n) && $run_small_n && a.is_strided() {
1183 (GemmPool::SmallN, $small_n_name, get_mem_pool_size_small_n::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, b_need_pool))
1184 } else {
1185 (GemmPool::Goto, $goto_name, get_mem_pool_size_goto::<$tap,$tbp,$t_dispatcher::<F>>(hw_config, par, a_need_pool, b_need_pool))
1186 };
1187 let mem_pool_size = pool_info.mem_pool_size_b::<$tap,$tbp>();
1188 {
1199 let pool_guard = PACK_POOL.buffer.read().unwrap();
1200 let y = acquire(&pool_guard, mem_pool_size);
1201 if let Some(mut pool_vec) = y {
1202 let pool_buf = &mut pool_vec;
1203 $name_mt(
1204 hw_config, m, n, k, alpha, a, b, beta, c, par, pool_buf, gemm_mode, pool_info, gemm_fun
1205 );
1206 return;
1207 }
1208 }
1209 let mut pool_vec = vec![0_u8; mem_pool_size];
1210 let pool_buf = &mut pool_vec;
1211 $name_mt(
1212 hw_config, m, n, k, alpha, a, b, beta, c, par, pool_buf, gemm_mode, pool_info, gemm_fun
1213 );
1214 extend(pool_vec);
1215 }
1216
1217 pub unsafe fn $name_mt<F:MyFn>(
1218 hw_config: &$t_dispatcher <F>,
1219 m: usize, n: usize, k: usize,
1220 alpha: $t_as,
1221 a: Array<$ta>,
1222 b: Array<$tb>,
1223 beta: $t_bs,
1224 c: ArrayMut<$tc>,
1225 par: &GlarPar,
1226 pool_buf: &mut [u8],
1227 gemm_mode: GemmPool,
1228 pool_info: PoolSize,
1229 gemm_fn: unsafe fn(
1230 &$t_dispatcher <F>, usize, usize, usize, *const $t_as, $packa_ty, $packb_ty, *const $t_bs, ArrayMut<$tc>, &GlarThreadConfig
1231 )
1232 )
1233 where $t_dispatcher <F>: GemmCache
1234 {
1235
1236 let mc_eff = <$t_dispatcher::<F> as GemmCache>::get_mc_eff(hw_config, par.ic_par);
1237 let nc_eff = <$t_dispatcher::<F> as GemmCache>::get_nc_eff(hw_config, par.jc_par);
1238 let kc_eff = <$t_dispatcher::<F> as GemmCache>::get_kc_eff(hw_config);
1239 let (pa_br_vec_ref, pb_br_vec_ref) = get_apbp_barrier(par);
1240
1241 let (i_load_par, j_load_par) = par.get_load_par(&gemm_mode, m, n, mc_eff, nc_eff);
1242 let (ap_pool_vec, bp_pool_vec) = pool_info.slice_mut_from_pool::<$tap,$tbp>(
1243 pool_buf, i_load_par, j_load_par, pool_info, hw_config.mr, hw_config.nr
1244 );
1245 let (ap_pool, bp_pool) = (&ap_pool_vec, &bp_pool_vec);
1246
1247 std::thread::scope(|s| {
1249 for t_id in 1..par.num_threads {
1250 let t_cfg = GlarThreadConfig::new(
1251 par.clone(), &pa_br_vec_ref, &pb_br_vec_ref, t_id, mc_eff, nc_eff, kc_eff
1252 );
1253 let ic_id = t_cfg.ic_id;
1254 let jc_id = t_cfg.jc_id;
1255 let ap_id = match gemm_mode {
1256 GemmPool::Goto => ic_id,
1257 GemmPool::SmallM => ic_id,
1258 GemmPool::SmallN => t_id,
1259 };
1260 let bp_id = match gemm_mode {
1261 GemmPool::Goto => jc_id,
1262 GemmPool::SmallM => 0,
1263 GemmPool::SmallN => jc_id,
1264 };
1265 let ap_cur = a.$pack_fn(ap_pool, ap_id);
1266 let bp_cur = b.$pack_fn(bp_pool, bp_id);
1267 let g = hw_config;
1268 s.spawn(move || {
1269 let alpha = &alpha as *const $t_as;
1270 let beta = &beta as *const $t_bs;
1271 gemm_fn(g, m, n, k, alpha, ap_cur, bp_cur, beta, c, &t_cfg);
1272 }
1273 );
1274 }
1275 {
1276 let ap = a.$pack_fn(ap_pool, 0);
1277 let bp = b.$pack_fn(bp_pool, 0);
1278 let t_id: usize = 0;
1279 let t_cfg = GlarThreadConfig::new(par.clone(), &pa_br_vec_ref, &pb_br_vec_ref, t_id, mc_eff, nc_eff, kc_eff);
1280 let alpha = &alpha as *const $t_as;
1281 let beta = &beta as *const $t_bs;
1282 gemm_fn(hw_config, m, n, k, alpha, ap, bp, beta, c, &t_cfg);
1283 }
1284 });
1285 }
1286
1287 unsafe fn $goto_name<F:MyFn>(
1288 hw_cfg: &$t_dispatcher <F>,
1289 m: usize, n: usize, k: usize,
1290 alpha: *const $t_as,
1291 a: $packa_ty,
1292 b: $packb_ty,
1293 beta: *const $t_bs,
1294 c: ArrayMut<$tc>,
1295 t_cfg: &GlarThreadConfig
1296 ) {
1297 let ic_id = t_cfg.ic_id;
1298 let jc_id = t_cfg.jc_id;
1299 let ir_id = t_cfg.ir_id;
1300 let jr_id = t_cfg.jr_id;
1301 let ir_par = t_cfg.par.ir_par;
1302 let jr_par = t_cfg.par.jr_par;
1303 let ic_par = t_cfg.par.ic_par;
1304 let jc_par = t_cfg.par.jc_par;
1305 let mc = t_cfg.mc_eff;
1306 let nc = t_cfg.nc_eff;
1307 let kc = t_cfg.kc_eff;
1308 let mr = hw_cfg.mr;
1309 let nr = hw_cfg.nr;
1310 let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, ic_par);
1311 let (nc_start, nc_end, nc_left) = split_c_range(n, nc, nr, jc_id, jc_par);
1312 let (kc_start, d1_end) = (0, k);
1313 let one = $one;
1314 let c_rs = c.rs();
1315 let c_cs = c.cs();
1316 let c_ptr = c.src();
1317 let mut mc_i = mc_start;
1318 while mc_i < mc_end {
1319 let mc_len = mc.min(mc_end - mc_i);
1320 let mut kc_i = kc_start;
1321 let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
1322 let mr_len = mr_end - mr_start;
1323 let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
1324 while kc_i < d1_end {
1325 let kc_len = kc.min(d1_end - kc_i);
1326 let kc_len_eff = hw_cfg.round_up(kc_len);
1327 let mut nc_i = nc_start;
1328 let kc_last = kc_i + kc_len == d1_end;
1329 let kc_first = kc_i == kc_start;
1330 let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
1331 let ap_data = $packa_name(hw_cfg, &a, mc_i, kc_i, mc_len, kc_len, t_cfg);
1332 let ap = ap_data.src();
1333 let ap = ap.add(mr_start*kc_len_eff);
1334 while nc_i < nc_end {
1335 let nc_len = nc.min(nc_end - nc_i);
1336 let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
1337 let nr_len = nr_end - nr_start;
1338 let c_ij = c_i.add((nc_i+nr_start) * c_cs);
1339 let bp_data = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1340 let bp = bp_data.src();
1341 let bp = bp.add(nr_start*kc_len_eff);
1342 $goto_kernel(
1343 hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t, c_ij, c_rs, c_cs,
1344 ap, bp,
1345 kc_last, kc_first
1346 );
1347
1348 nc_i += nc;
1349 }
1350 if nc_left {
1351 t_cfg.wait_packb();
1352 t_cfg.wait_packb();
1353 }
1354 kc_i += kc;
1355 }
1356 mc_i += mc;
1357 }
1358 if mc_left {
1359 let mut kc_i = kc_start;
1360 while kc_i < d1_end {
1361 let kc_len = kc.min(d1_end -kc_i);
1362 t_cfg.wait_packa();
1363 t_cfg.wait_packa();
1364 let mut nc_i = nc_start;
1365 while nc_i < nc_end {
1366 let nc_len = nc.min(nc_end - nc_i);
1367 let _ = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1368 nc_i += nc;
1369 }
1370 if nc_left{
1371 t_cfg.wait_packb();
1372 t_cfg.wait_packb();
1373 }
1374 kc_i += kc;
1375 }
1376 }
1377 }
1378 unsafe fn $small_m_name<F:MyFn>(
1379 hw_cfg: &$t_dispatcher <F>,
1380 m: usize, n: usize, k: usize,
1381 alpha: *const $t_as,
1382 a: $packa_ty,
1383 b: $packb_ty,
1384 beta: *const $t_bs,
1385 c: ArrayMut<$tc>,
1386 t_cfg: &GlarThreadConfig
1387 ) {
1388 let par = &t_cfg.par;
1389 let ic_id = t_cfg.ic_id;
1390 let jc_id = t_cfg.jc_id;
1391 let ir_id = t_cfg.ir_id;
1392 let ir_par = par.ir_par;
1393 let jr_id = t_cfg.jr_id;
1394 let jr_par = par.jr_par;
1395 let mc = t_cfg.mc_eff;
1396 let nc = t_cfg.nc_eff;
1397 let kc = t_cfg.kc_eff;
1398 let mr = hw_cfg.mr;
1399 let nr = hw_cfg.nr;
1400 let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, par.ic_par);
1401 let (nc_start, nc_end, _) = split_c_range(n, nc, nr, jc_id, par.jc_par);
1402 let (kc_start, kc_end) = (0, k);
1403 let one = $one;
1404
1405 let b_ptr = b.src();
1406 let b_rs = b.rs();
1407 let b_cs = b.cs();
1408 let c_rs = c.rs();
1409 let c_cs = c.cs();
1410 let c_ptr = c.src();
1411 let mut mc_i = mc_start;
1412 while mc_i < mc_end {
1413 let mc_len = mc.min(mc_end - mc_i);
1414 let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
1415 let mr_len = mr_end - mr_start;
1416 let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
1417 let mut kc_i = kc_start;
1418 while kc_i < kc_end {
1419 let kc_len = kc.min(kc_end - kc_i);
1420 let kc_len_eff = hw_cfg.round_up(kc_len);
1421 let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
1422 let kc_last = kc_i + kc_len == kc_end;
1423 let kc_first = kc_i == kc_start;
1424 let mut nc_i = nc_start;
1425 let ap_data = $packa_name(hw_cfg, &a, mc_i, kc_i, mc_len, kc_len, t_cfg);
1426 let ap = ap_data.src();
1427 let ap = ap.add(mr_start*kc_len_eff);
1428 let b_j = b_ptr.add(kc_i * b_rs);
1429 while nc_i < nc_end {
1430 let nc_len = nc.min(nc_end - nc_i);
1431 let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
1432 let nr_len = nr_end - nr_start;
1433 let c_ij = c_i.add((nc_i + nr_start) * c_cs);
1434 let b_cur = b_j.add((nc_i + nr_start) * b_cs);
1435 $small_m_kernel(
1436 hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t,
1437 b_cur, b_rs, b_cs,
1438 c_ij, c_rs, c_cs,
1439 ap,
1440 kc_last, kc_first
1441 );
1442 nc_i += nc;
1443 }
1444 kc_i += kc;
1445 }
1446 mc_i += mc;
1447 }
1448
1449 if mc_left {
1450 let mut kc_i = kc_start;
1451 while kc_i < kc_end {
1452 t_cfg.wait_packa();
1453 t_cfg.wait_packa();
1454 kc_i += kc;
1455 }
1456 }
1457 }
1458 unsafe fn $small_n_name<F:MyFn>(
1459 hw_cfg: &$t_dispatcher <F>,
1460 m: usize, n: usize, k: usize,
1461 alpha: *const $t_as,
1462 a: $packa_ty,
1463 b: $packb_ty,
1464 beta: *const $t_bs,
1465 c: ArrayMut<$tc>,
1466 t_cfg: &GlarThreadConfig
1467 ) {
1468 let par = &t_cfg.par;
1469 let ic_id = t_cfg.ic_id;
1470 let jc_id = t_cfg.jc_id;
1471 let ir_id = t_cfg.ir_id;
1472 let ir_par = par.ir_par;
1473 let jr_id = t_cfg.jr_id;
1474 let jr_par = par.jr_par;
1475 let mc = t_cfg.mc_eff;
1476 let nc = t_cfg.nc_eff;
1477 let kc = t_cfg.kc_eff;
1478 let mr = hw_cfg.mr;
1479 let nr = hw_cfg.nr;
1480 let (mc_start, mc_end, mc_left) = split_c_range(m, mc, mr, ic_id, par.ic_par);
1481 let (nc_start, nc_end, nc_left) = split_c_range(n, nc, nr, jc_id, par.jc_par);
1482 let (kc_start, kc_end) = (0, k);
1483 let one = $one;
1484
1485 let c_rs = c.rs();
1486 let c_cs = c.cs();
1487 let c_ptr = c.src();
1488 let a_ptr = a.src();
1489 let a_rs = a.rs();
1490 let a_cs = a.cs();
1491 let mut a_dst = a.dst_w(0, kc);
1493 let a_dst_ref = a_dst.get();
1494 let a_dst_ptr = a_dst_ref.as_mut_ptr();
1495 let mut mc_i = mc_start;
1496 while mc_i < mc_end {
1497 let mc_len = mc.min(mc_end - mc_i);
1498 let (mr_start, mr_end) = split_range(mc_len, mr, ir_id, ir_par);
1499 let mr_len = mr_end - mr_start;
1500 let c_i = c_ptr.add((mc_i+mr_start) * c_rs);
1501 let a_i = a_ptr.add((mc_i+mr_start) * a_rs);
1502 let mut kc_i = kc_start;
1503 while kc_i < kc_end {
1504 let kc_len = kc.min(kc_end - kc_i);
1505 let kc_last = kc_i + kc_len == kc_end;
1506 let kc_first = kc_i == kc_start;
1507 let beta_t = if kc_i == kc_start { beta } else { &one as *const $t_bs};
1508 let a_cur = a_i.add(kc_i*a_cs);
1509 let mut nc_i = nc_start;
1510 while nc_i < nc_end {
1511 let nc_len = nc.min(nc_end - nc_i);
1512 let (nr_start, nr_end) = split_range(nc_len, nr, jr_id, jr_par);
1513 let nr_len = nr_end - nr_start;
1514 let bp_data = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1515 let bp = bp_data.src();
1516 let c_ij = c_i.add((nc_i + nr_start) * c_cs);
1517 $small_n_kernel(
1518 hw_cfg, mr_len, nr_len, kc_len, alpha, beta_t,
1519 a_cur, a_rs, a_cs,
1520 a_dst_ptr, bp,
1521 c_ij, c_rs, c_cs,
1522 kc_last, kc_first
1523 );
1524 nc_i += nc;
1525 }
1526 if nc_left {
1527 t_cfg.wait_packb();
1528 t_cfg.wait_packb();
1529 }
1530 kc_i += kc;
1531 }
1532 mc_i += mc;
1533 }
1534 if mc_left {
1535 let mut kc_i = kc_start;
1536 while kc_i < kc_end {
1537 let kc_len = kc.min(kc_end - kc_i);
1538 let mut nc_i = nc_start;
1539 while nc_i < nc_end {
1540 let nc_len = nc.min(nc_end - nc_i);
1541 let _ = $packb_name(hw_cfg, &b, nc_i, kc_i, nc_len, kc_len, t_cfg);
1542 nc_i += nc;
1543 }
1544 if nc_left{
1545 t_cfg.wait_packb();
1546 t_cfg.wait_packb();
1547 }
1548 kc_i += kc;
1549 }
1550 }
1551 }
1552 pub(crate) unsafe fn $packa_name<'a,'b,F:MyFn>(hw_cfg: &$t_dispatcher <F>, x: &'b $packa_ty<'a>, mc_i: usize, kc_i: usize, mc_len: usize, kc_len: usize, t_cfg: &GlarThreadConfig) -> PtrData<'a,$tap> {
1558 t_cfg.wait_packa();
1559 let xp_ptr = match x {
1560 $packa_ty::StridedMatrix(x_i) => {
1561 let mc_par = x_i.get_mc();
1562 let mc_offset = mc_par * t_cfg.i_load_p_idx;
1563 if mc_len > mc_offset {
1564 let kc_len_ro = hw_cfg.round_up(kc_len);
1565 let mc_len_x = (mc_len - mc_offset).min(mc_par);
1566 let mc_i = mc_i + mc_offset;
1567 let rs = x_i.rs();
1568 let cs = x_i.cs();
1569 let src_ptr = x_i.src().add(mc_i*rs + kc_i*cs);
1570 let mut dst = x_i.dst_w(t_cfg.i_load_p_idx, kc_len_ro);
1571 let dst_ref = dst.get();
1572 let dst_ptr = dst_ref.as_mut_ptr();
1573 hw_cfg.packa_fn(src_ptr, dst_ptr, mc_len_x, kc_len, rs, cs);
1574 }
1575 t_cfg.wait_packa();
1576 PtrData::RefData(x_i.dst_r())
1577 }
1578 $packa_ty::PackedMatrix(x_i) => {
1579 let vs = hw_cfg.vs;
1580 let m_ro = (x_i.m() + vs - 1) / vs * vs;
1581 let kc_len_ro = hw_cfg.round_up(kc_len);
1582 let res = is_mixed!(
1583 $include_flag,
1584 {
1585 let mc_par = x_i.get_mc();
1586 let mc_offset = mc_par * t_cfg.i_load_p_idx;
1587 let mc_len_ro = (mc_len + vs - 1) / vs * vs;
1588 if mc_len_ro > mc_offset {
1589 let mc_len_ro_x = (mc_len_ro - mc_offset).min(mc_par);
1590 let mc_i = mc_i + mc_offset;
1591 let src_ptr = x_i.src().add(mc_i*kc_len_ro + kc_i*m_ro);
1592 let mut dst = x_i.dst_w(t_cfg.i_load_p_idx, kc_len_ro);
1593 let dst_ref = dst.get();
1594 let dst_ptr = dst_ref.as_mut_ptr();
1595 hw_cfg.cvt_mixed(src_ptr, dst_ptr, mc_len_ro_x*kc_len_ro);
1596 }
1597 t_cfg.wait_packa();
1598 PtrData::RefData(x_i.dst_r())
1599 },
1600 {
1601 let src_ptr = x_i.src().add(mc_i*kc_len_ro + kc_i*m_ro);
1602 t_cfg.wait_packa();
1603 PtrData::PtrData(src_ptr)
1604 }
1605
1606 );
1607 res
1608
1609 }
1610 };
1611 xp_ptr
1612 }
1613 pub(crate) unsafe fn $packb_name<'a,'b,F:MyFn>(hw_cfg: & $t_dispatcher <F>, x: &'b$packb_ty<'a>, nc_i: usize, kc_i: usize, nc_len: usize, kc_len: usize, t_cfg: &GlarThreadConfig) -> PtrData<'a,$tbp> {
1615 t_cfg.wait_packb();
1616 let xp_ptr = match x {
1617 $packb_ty::StridedMatrix(x_i) => {
1618 let nc_par = x_i.get_mc();
1619 let nc_offset = nc_par * t_cfg.j_load_p_idx;
1620 if nc_len > nc_offset {
1621 let kc_len_ro = hw_cfg.round_up(kc_len);
1622 let nc_len_x = (nc_len - nc_offset).min(nc_par);
1623 let nc_i = nc_i + nc_offset;
1624 let rs = x_i.rs();
1625 let cs = x_i.cs();
1626 let src_ptr = x_i.src().add(kc_i*rs + nc_i*cs);
1627 let mut dst = x_i.dst_w(t_cfg.j_load_p_idx, kc_len_ro);
1628 let dst_ref = dst.get();
1629 let dst_ptr = dst_ref.as_mut_ptr();
1630 hw_cfg.packb_fn(src_ptr, dst_ptr, nc_len_x, kc_len, rs, cs);
1631 }
1632 t_cfg.wait_packb();
1633 PtrData::RefData(x_i.dst_r())
1634 }
1635 $packb_ty::PackedMatrix(x_i) => {
1636 let kc_len_ro = hw_cfg.round_up(kc_len);
1637 let n_ro = x_i.m();
1638 let res = is_mixed!(
1639 $include_flag,
1640 {
1641 let nc_par = x_i.get_mc();
1642 let nc_offset = nc_par * t_cfg.j_load_p_idx;
1643 if nc_len > nc_offset {
1644 let nc_len_x = (nc_len - nc_offset).min(nc_par);
1645 let nc_i = nc_i + nc_offset;
1646 let src_ptr = x_i.src().add(nc_i*kc_len_ro + kc_i*n_ro);
1647 let mut dst = x_i.dst_w(t_cfg.j_load_p_idx, kc_len_ro);
1648 let dst_ref = dst.get();
1649 let dst_ptr = dst_ref.as_mut_ptr();
1650 hw_cfg.cvt_mixed(src_ptr, dst_ptr, nc_len_x*kc_len_ro);
1651 }
1652 t_cfg.wait_packb();
1653 PtrData::RefData(x_i.dst_r())
1654 },
1655 {
1656 let src_ptr = x_i.src().add(nc_i*kc_len_ro + kc_i*n_ro);
1657 t_cfg.wait_packb();
1658 PtrData::PtrData(src_ptr)
1659 }
1660
1661 );
1662 res
1663 }
1664 };
1665 xp_ptr
1666 }
1667 }
1668}
1669
1670#[macro_export]
1671macro_rules! def_kernel_bb_pf1_no_beta {
1672 (
1673 $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
1674 $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt, $($mr_left:tt),*
1675 ) => {
1676 paste! {
1677 #[target_feature(enable = "avx")]
1678 pub unsafe fn kernel_bb<F: MyFn, const STRIDED: bool>(
1679 m: usize, n: usize, k: usize,
1680 alpha: *const $t_as,
1681 c: *mut $tc, c_rs: usize, c_cs: usize,
1682 ap: *const $ta, bp: *const $tb,
1683 f: F,
1684 ) {
1685 const MR: usize = $MR;
1686 const NR: usize = $NR;
1687 let m_rounded = m / MR * MR;
1688 let n_rounded = n / NR * NR;
1689 let m_left = m % MR;
1690 let n_left = n % NR;
1691
1692 let d_arr = [0, 0, c_rs, c_cs];
1693
1694 let mut m_i = 0;
1695 while m_i < m_rounded {
1696 let c_cur0 = c.add(m_i * c_rs);
1697 let ap_cur = ap.add(m_i * k);
1698 let mut a_pft1_offset = $pf1_0 * k;
1699 let mut n_i = 0;
1700 while n_i < n_rounded {
1701 let bp_cur = bp.add(n_i * k);
1702 let c_cur1 = c_cur0.add(n_i * c_cs);
1703 [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, a_pft1_offset, f);
1704 n_i += NR;
1705 a_pft1_offset += $pf_step * k;
1706 }
1707 if n_left != 0 {
1709 let bp_cur = bp.add(n_i * k);
1710 let c_cur1 = c_cur0.add(n_i * c_cs);
1711 [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, MR, n_left, f);
1712 }
1713 m_i += MR;
1714 }
1715
1716 $(
1717 if (m_left+VS-1) / VS * VS == $mr_left {
1718 let c_cur0 = c.add(m_i * c_rs);
1719 let ap_cur = ap.add(m_i * k);
1720 let mut n_i = 0;
1721 while n_i < n_rounded {
1722 let bp_cur = bp.add(n_i * k);
1723 let c_cur1 = c_cur0.add(n_i * c_cs);
1724 [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, m_left, f);
1725 n_i += NR;
1726 }
1727 if n_left !=0 {
1728 let bp_cur = bp.add(n_i * k);
1729 let c_cur1 = c_cur0.add(n_i * c_cs);
1730 [<ukernel_$mr_left x n_bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, m_left, n_left, f);
1731 }
1732 }
1733 )*
1734
1735 asm!("vzeroupper");
1736 }
1737 }
1738 };
1739}
1740
1741#[macro_export]
1742macro_rules! def_kernel_bb_pf1 {
1743 (
1744 $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
1745 $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt, $($mr_left:tt),*
1746 ) => {
1747 paste! {
1748 #[target_feature(enable = "avx")]
1749 pub unsafe fn kernel_bb<F: MyFn, const STRIDED: bool>(
1750 m: usize, n: usize, k: usize,
1751 alpha: *const $t_as,
1752 beta: *const $t_bs,
1753 c: *mut $tc, c_rs: usize, c_cs: usize,
1754 ap: *const $ta, bp: *const $tb,
1755 f: F,
1756 ) {
1757 const MR: usize = $MR;
1758 const NR: usize = $NR;
1759 let m_rounded = m / MR * MR;
1760 let n_rounded = n / NR * NR;
1761 let m_left = m % MR;
1762 let n_left = n % NR;
1763
1764 let d_arr = [0, 0, c_rs, c_cs];
1765
1766 let mut m_i = 0;
1767 while m_i < m_rounded {
1768 let c_cur0 = c.add(m_i * c_rs);
1769 let ap_cur = ap.add(m_i * k);
1770 let mut a_pft1_offset = $pf1_0 * k;
1771 let mut n_i = 0;
1772 while n_i < n_rounded {
1773 let bp_cur = bp.add(n_i * k);
1774 let c_cur1 = c_cur0.add(n_i * c_cs);
1775 [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, a_pft1_offset, f);
1776 n_i += NR;
1777 a_pft1_offset += $pf_step * k;
1778 }
1779 if n_left != 0 {
1781 let bp_cur = bp.add(n_i * k);
1782 let c_cur1 = c_cur0.add(n_i * c_cs);
1783 [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, MR, n_left, f);
1784 }
1785 m_i += MR;
1786 }
1787
1788
1789 $(
1790 if (m_left+VS-1) / VS * VS == $mr_left {
1791 let c_cur0 = c.add(m_i * c_rs);
1792 let ap_cur = ap.add(m_i * k);
1793 let mut n_i = 0;
1794 while n_i < n_rounded {
1795 let bp_cur = bp.add(n_i * k);
1796 let c_cur1 = c_cur0.add(n_i * c_cs);
1797 [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, m_left, f);
1798 n_i += NR;
1799 }
1800 if n_left !=0 {
1801 let bp_cur = bp.add(n_i * k);
1802 let c_cur1 = c_cur0.add(n_i * c_cs);
1803 [<ukernel_$mr_left x n_bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, m_left, n_left, f);
1804 }
1805 }
1806 )*
1807
1808 asm!("vzeroupper");
1809 }
1810 }
1811 };
1812}
1813
1814#[macro_export]
1815macro_rules! def_kernel_bb_v0 {
1816 (
1817 $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
1818 $MR:tt, $NR:tt, $($mr_left:tt),*
1819 ) => {
1820 paste! {
1821 #[target_feature(enable = "avx")]
1822 pub unsafe fn [<kernel_$MR x $NR _bb>]<F: MyFn, const STRIDED: bool>(
1823 m: usize, n: usize, k: usize,
1824 alpha: *const $t_as,
1825 beta: *const $t_bs,
1826 c: *mut $tc, c_rs: usize, c_cs: usize,
1827 ap: *const $ta, bp: *const $tb,
1828 f: F,
1829 ) {
1830 const MR: usize = $MR;
1831 const NR: usize = $NR;
1832 let m_rounded = m / MR * MR;
1833 let n_rounded = n / NR * NR;
1834 let m_left = m % MR;
1835 let n_left = n % NR;
1836
1837 let d_arr = [0, 0, c_rs, c_cs];
1838
1839 let mut m_i = 0;
1840 while m_i < m_rounded {
1841 let c_cur0 = c.add(m_i * c_rs);
1842 let ap_cur = ap.add(m_i * k);
1843 let mut n_i = 0;
1844 while n_i < n_rounded {
1845 let bp_cur = bp.add(n_i * k);
1846 let c_cur1 = c_cur0.add(n_i * c_cs);
1847 [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, MR, f);
1848 n_i += NR;
1849 }
1850 if n_left != 0 {
1851 let bp_cur = bp.add(n_i * k);
1852 let c_cur1 = c_cur0.add(n_i * c_cs);
1853 [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, MR, n_left, f);
1854 }
1855 m_i += MR;
1856 }
1857
1858 $(
1859 if (m_left+VS-1) / VS * VS == $mr_left {
1860 let c_cur0 = c.add(m_i * c_rs);
1861 let ap_cur = ap.add(m_i * k);
1862 let mut n_i = 0;
1863 while n_i < n_rounded {
1864 let bp_cur = bp.add(n_i * k);
1865 let c_cur1 = c_cur0.add(n_i * c_cs);
1866 [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, m_left, f);
1867 n_i += NR;
1868 }
1869 if n_left !=0 {
1870 let bp_cur = bp.add(n_i * k);
1871 let c_cur1 = c_cur0.add(n_i * c_cs);
1872 [<ukernel_$mr_left x n_bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, beta, k, d_arr, m_left, n_left, f);
1873 }
1874 }
1875 )*
1876
1877 asm!("vzeroupper");
1878 }
1879 }
1880 };
1881}
1882
1883#[macro_export]
1884macro_rules! def_kernel_bb_v0_no_beta {
1885 (
1886 $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
1887 $MR:tt, $NR:tt, $($mr_left:tt),*
1888 ) => {
1889 paste! {
1890 #[target_feature(enable = "avx")]
1891 pub unsafe fn [<kernel_$MR x $NR _bb>]<F: MyFn, const STRIDED: bool>(
1892 m: usize, n: usize, k: usize,
1893 alpha: *const $t_as,
1894 c: *mut $tc, c_rs: usize, c_cs: usize,
1895 ap: *const $ta, bp: *const $tb,
1896 f: F,
1897 ) {
1898 const MR: usize = $MR;
1899 const NR: usize = $NR;
1900 let m_rounded = m / MR * MR;
1901 let n_rounded = n / NR * NR;
1902 let m_left = m % MR;
1903 let n_left = n % NR;
1904
1905 let d_arr = [0, 0, c_rs, c_cs];
1906
1907 let mut m_i = 0;
1908 while m_i < m_rounded {
1909 let c_cur0 = c.add(m_i * c_rs);
1910 let ap_cur = ap.add(m_i * k);
1911 let mut n_i = 0;
1912 while n_i < n_rounded {
1913 let bp_cur = bp.add(n_i * k);
1914 let c_cur1 = c_cur0.add(n_i * c_cs);
1915 [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, MR, f);
1916 n_i += NR;
1917 }
1918 if n_left != 0 {
1919 let bp_cur = bp.add(n_i * k);
1920 let c_cur1 = c_cur0.add(n_i * c_cs);
1921 [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, MR, n_left, f);
1922 }
1923 m_i += MR;
1924 }
1925
1926
1927 $(
1928 if (m_left+VS-1) / VS * VS == $mr_left {
1929 let c_cur0 = c.add(m_i * c_rs);
1930 let ap_cur = ap.add(m_i * k);
1931 let mut n_i = 0;
1932 while n_i < n_rounded {
1933 let bp_cur = bp.add(n_i * k);
1934 let c_cur1 = c_cur0.add(n_i * c_cs);
1935 [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, m_left, f);
1936 n_i += NR;
1937 }
1938 if n_left !=0 {
1939 let bp_cur = bp.add(n_i * k);
1940 let c_cur1 = c_cur0.add(n_i * c_cs);
1941 [<ukernel_$mr_left x n_bb_partial>]::<_, STRIDED>(ap_cur, bp_cur, c_cur1, alpha, k, d_arr, m_left, n_left, f);
1942 }
1943 }
1944 )*
1945
1946 asm!("vzeroupper");
1947 }
1948 }
1949 };
1950}
1951
#[macro_export]
macro_rules! def_kernel_sb_pf1 {
    (
        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
        $RS:tt,
        $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt, $($mr_left:tt),*
    ) => {
        paste! {
            #[target_feature(enable = "avx")]
            pub unsafe fn [<kernel_$MR x $NR _sb_v0>]<F: MyFn, const STRIDED: bool>(
                m: usize, n: usize, k: usize,
                alpha: *const $t_as, beta: *const $t_bs,
                a: *const $ta, a_rs: usize, a_cs: usize,
                bp: *const $tb,
                c: *mut $tc, c_rs: usize, c_cs: usize,
                ap: *mut $ta,
                f: F,
            ) {
                let k_eff = (k + $RS - 1) / $RS * $RS;
                const MR: usize = $MR;
                const NR: usize = $NR;
                let m_rounded = m / MR * MR;
                let n_rounded = n / NR * NR;
                let m_left = m % MR;
                let n_left = n % NR;

                let d_arr = [0, 0, c_rs, c_cs];

                let mut m_i = 0;
                while m_i < m_rounded {
                    let c_cur0 = c.add(m_i * c_rs);
                    let a_cur = a.add(m_i * a_rs);
                    let a_pft1_offset = $pf1_0 * k;
                    [<packa_panel_$MR>](MR, k, a_cur, a_rs, a_cs, ap, VS);
                    let mut n_i = 0;
                    while n_i < n_rounded {
                        let bp_cur = bp.add(n_i * k_eff);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, a_pft1_offset, f);
                        n_i += NR;
                    }
                    if n_left != 0 {
                        let bp_cur = bp.add(n_i * k_eff);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, MR, n_left, f);
                    }
                    m_i += MR;
                }

                $(
                    if (m_left + VS - 1) / VS * VS == $mr_left {
                        let c_cur0 = c.add(m_i * c_rs);
                        let a_cur = a.add(m_i * a_rs);
                        [<packa_panel_$MR>](m_left, k, a_cur, a_rs, a_cs, ap, VS);
                        let mut n_i = 0;
                        while n_i < n_rounded {
                            let bp_cur = bp.add(n_i * k_eff);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, m_left, f);
                            n_i += NR;
                        }
                        if n_left != 0 {
                            let bp_cur = bp.add(n_i * k_eff);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x n _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, m_left, n_left, f);
                        }
                        return;
                    }
                )*

                asm!("vzeroupper");
            }
        }
    };
}

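// A hypothetical invocation sketch for `def_kernel_sb_pf1` (the element types,
// MR/NR, prefetch offsets, and partial-MR list below are made-up values for
// illustration; the real ones come from the architecture-specific kernel crates
// that also provide the `ukernel_*`/`packa_panel_*` functions and `VS`):
//
//     def_kernel_sb_pf1!(f32, f32, f32, f32, f32, 1, 24, 4, 128, 96, 16, 8);
//
// This would expand to `kernel_24x4_sb_v0`, packing A panel by panel and
// calling `ukernel_24x4_bb` with an L1 prefetch offset of `128 * k` elements.
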
#[macro_export]
macro_rules! def_kernel_sb_v0 {
    (
        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
        $RS:tt,
        $MR:tt, $NR:tt, $($mr_left:tt),*
    ) => {
        paste! {
            #[target_feature(enable = "avx")]
            pub unsafe fn [<kernel_$MR x $NR _sb_v0>]<F: MyFn, const STRIDED: bool>(
                m: usize, n: usize, k: usize,
                alpha: *const $t_as, beta: *const $t_bs,
                a: *const $ta, a_rs: usize, a_cs: usize,
                bp: *const $tb,
                c: *mut $tc, c_rs: usize, c_cs: usize,
                ap: *mut $ta,
                f: F,
            ) {
                let k_eff = (k + $RS - 1) / $RS * $RS;
                const MR: usize = $MR;
                const NR: usize = $NR;
                let m_rounded = m / MR * MR;
                let n_rounded = n / NR * NR;
                let m_left = m % MR;
                let n_left = n % NR;

                let d_arr = [0, 0, c_rs, c_cs];

                let mut m_i = 0;
                while m_i < m_rounded {
                    let c_cur0 = c.add(m_i * c_rs);
                    let a_cur = a.add(m_i * a_rs);
                    [<packa_panel_$MR>](MR, k, a_cur, a_rs, a_cs, ap, VS);
                    let mut n_i = 0;
                    while n_i < n_rounded {
                        let bp_cur = bp.add(n_i * k_eff);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, MR, f);
                        n_i += NR;
                    }
                    if n_left != 0 {
                        let bp_cur = bp.add(n_i * k_eff);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, MR, n_left, f);
                    }
                    m_i += MR;
                }

                $(
                    if (m_left + VS - 1) / VS * VS == $mr_left {
                        let c_cur0 = c.add(m_i * c_rs);
                        let a_cur = a.add(m_i * a_rs);
                        [<packa_panel_$MR>](m_left, k, a_cur, a_rs, a_cs, ap, VS);
                        let mut n_i = 0;
                        while n_i < n_rounded {
                            let bp_cur = bp.add(n_i * k_eff);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, m_left, f);
                            n_i += NR;
                        }
                        if n_left != 0 {
                            let bp_cur = bp.add(n_i * k_eff);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x n _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, beta, k_eff, d_arr, m_left, n_left, f);
                        }
                        return;
                    }
                )*

                asm!("vzeroupper");
            }
        }
    };
}

#[macro_export]
macro_rules! def_kernel_sb_v0_no_beta {
    (
        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
        $MR:tt, $NR:tt, $($mr_left:tt),*
    ) => {
        paste! {
            #[target_feature(enable = "avx")]
            pub unsafe fn [<kernel_$MR x $NR _sb_v0>]<F: MyFn, const STRIDED: bool>(
                m: usize, n: usize, k: usize,
                alpha: *const $t_as,
                a: *const $ta, a_rs: usize, a_cs: usize,
                bp: *const $tb,
                c: *mut $tc, c_rs: usize, c_cs: usize,
                ap: *mut $ta,
                f: F,
            ) {
                const MR: usize = $MR;
                const NR: usize = $NR;
                let m_rounded = m / MR * MR;
                let n_rounded = n / NR * NR;
                let m_left = m % MR;
                let n_left = n % NR;

                let d_arr = [0, 0, c_rs, c_cs];

                let mut m_i = 0;
                while m_i < m_rounded {
                    let c_cur0 = c.add(m_i * c_rs);
                    let a_cur = a.add(m_i * a_rs);
                    [<packa_panel_$MR>](MR, k, a_cur, a_rs, a_cs, ap, VS);
                    let mut n_i = 0;
                    while n_i < n_rounded {
                        let bp_cur = bp.add(n_i * k);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, MR, f);
                        n_i += NR;
                    }
                    if n_left != 0 {
                        let bp_cur = bp.add(n_i * k);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, MR, n_left, f);
                    }
                    m_i += MR;
                }

                $(
                    if (m_left + VS - 1) / VS * VS == $mr_left {
                        let c_cur0 = c.add(m_i * c_rs);
                        let a_cur = a.add(m_i * a_rs);
                        [<packa_panel_$MR>](m_left, k, a_cur, a_rs, a_cs, ap, VS);
                        let mut n_i = 0;
                        while n_i < n_rounded {
                            let bp_cur = bp.add(n_i * k);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, m_left, f);
                            n_i += NR;
                        }
                        if n_left != 0 {
                            let bp_cur = bp.add(n_i * k);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x n _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, m_left, n_left, f);
                        }
                        return;
                    }
                )*

                asm!("vzeroupper");
            }
        }
    };
}

#[macro_export]
macro_rules! def_kernel_sb_pf1_no_beta {
    (
        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
        $MR:tt, $NR:tt, $pf1_0:tt, $pf_step:tt, $($mr_left:tt),*
    ) => {
        paste! {
            #[target_feature(enable = "avx")]
            pub unsafe fn [<kernel_$MR x $NR _sb_v0>]<F: MyFn, const STRIDED: bool>(
                m: usize, n: usize, k: usize,
                alpha: *const $t_as,
                a: *const $ta, a_rs: usize, a_cs: usize,
                bp: *const $tb,
                c: *mut $tc, c_rs: usize, c_cs: usize,
                ap: *mut $ta,
                f: F,
            ) {
                const MR: usize = $MR;
                const NR: usize = $NR;
                let m_rounded = m / MR * MR;
                let n_rounded = n / NR * NR;
                let m_left = m % MR;
                let n_left = n % NR;

                let d_arr = [0, 0, c_rs, c_cs];

                let mut m_i = 0;
                while m_i < m_rounded {
                    let c_cur0 = c.add(m_i * c_rs);
                    let a_cur = a.add(m_i * a_rs);
                    let a_pft1_offset = $pf1_0 * k;
                    [<packa_panel_$MR>](MR, k, a_cur, a_rs, a_cs, ap, VS);
                    let mut n_i = 0;
                    while n_i < n_rounded {
                        let bp_cur = bp.add(n_i * k);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x $NR _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, a_pft1_offset, f);
                        n_i += NR;
                    }
                    if n_left != 0 {
                        let bp_cur = bp.add(n_i * k);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x n _bb>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, MR, n_left, f);
                    }
                    m_i += MR;
                }

                $(
                    if (m_left + VS - 1) / VS * VS == $mr_left {
                        let c_cur0 = c.add(m_i * c_rs);
                        let a_cur = a.add(m_i * a_rs);
                        [<packa_panel_$MR>](m_left, k, a_cur, a_rs, a_cs, ap, VS);
                        let mut n_i = 0;
                        while n_i < n_rounded {
                            let bp_cur = bp.add(n_i * k);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x $NR _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, m_left, f);
                            n_i += NR;
                        }
                        if n_left != 0 {
                            let bp_cur = bp.add(n_i * k);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x n _bb_partial>]::<_, STRIDED>(ap, bp_cur, c_cur1, alpha, k, d_arr, m_left, n_left, f);
                        }
                        return;
                    }
                )*

                asm!("vzeroupper");
            }
        }
    };
}

#[macro_export]
macro_rules! def_kernel_bs_no_beta {
    (
        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
        $MR:tt, $NR:tt, $($mr_left:tt),*
    ) => {
        paste! {
            #[target_feature(enable = "avx")]
            pub unsafe fn [<kernel_$MR x $NR _bs_v0>]<F: MyFn, const STRIDED: bool>(
                m: usize, n: usize, k: usize,
                alpha: *const $t_as,
                b: *const $tb, b_rs: usize, b_cs: usize,
                c: *mut $tc, c_rs: usize, c_cs: usize,
                ap: *const $ta,
                f: F,
            ) {
                const MR: usize = $MR;
                const NR: usize = $NR;
                let m_rounded = m / MR * MR;
                let n_rounded = n / NR * NR;
                let m_left = m % MR;
                let n_left = n % NR;

                let d_arr = [b_rs, b_cs, c_rs, c_cs];

                let mut m_i = 0;
                while m_i < m_rounded {
                    let c_cur0 = c.add(m_i * c_rs);
                    let ap_cur = ap.add(m_i * k);
                    let mut n_i = 0;
                    while n_i < n_rounded {
                        let b_cur = b.add(n_i * b_cs);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x $NR _bs>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, k, d_arr, MR, f);
                        n_i += NR;
                    }
                    if n_left != 0 {
                        let b_cur = b.add(n_i * b_cs);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x n _bs>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, k, d_arr, MR, n_left, f);
                    }
                    m_i += MR;
                }

                $(
                    if (m_left + VS - 1) / VS * VS == $mr_left {
                        let c_cur0 = c.add(m_i * c_rs);
                        let ap_cur = ap.add(m_i * k);
                        let mut n_i = 0;
                        while n_i < n_rounded {
                            let b_cur = b.add(n_i * b_cs);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x $NR _bs_partial>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, k, d_arr, m_left, f);
                            n_i += NR;
                        }
                        if n_left != 0 {
                            let b_cur = b.add(n_i * b_cs);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x n _bs_partial>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, k, d_arr, m_left, n_left, f);
                        }
                        return;
                    }
                )*

                asm!("vzeroupper");
            }
        }
    };
}

#[macro_export]
macro_rules! def_kernel_bs {
    (
        $ta:ty, $tb:ty, $tc:ty, $t_as:ty, $t_bs:ty,
        $MR:tt, $NR:tt, $($mr_left:tt),*
    ) => {
        paste! {
            #[target_feature(enable = "avx")]
            pub unsafe fn [<kernel_$MR x $NR _bs_v0>]<F: MyFn, const STRIDED: bool>(
                m: usize, n: usize, k: usize,
                alpha: *const $t_as, beta: *const $t_bs,
                b: *const $tb, b_rs: usize, b_cs: usize,
                c: *mut $tc, c_rs: usize, c_cs: usize,
                ap: *const $ta,
                f: F,
            ) {
                const MR: usize = $MR;
                const NR: usize = $NR;
                let m_rounded = m / MR * MR;
                let n_rounded = n / NR * NR;
                let m_left = m % MR;
                let n_left = n % NR;

                let d_arr = [b_rs, b_cs, c_rs, c_cs];

                let mut m_i = 0;
                while m_i < m_rounded {
                    let c_cur0 = c.add(m_i * c_rs);
                    let ap_cur = ap.add(m_i * k);
                    let mut n_i = 0;
                    while n_i < n_rounded {
                        let b_cur = b.add(n_i * b_cs);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x $NR _bs>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, MR, f);
                        n_i += NR;
                    }
                    if n_left != 0 {
                        let b_cur = b.add(n_i * b_cs);
                        let c_cur1 = c_cur0.add(n_i * c_cs);
                        [<ukernel_$MR x n _bs>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, MR, n_left, f);
                    }
                    m_i += MR;
                }

                $(
                    if (m_left + VS - 1) / VS * VS == $mr_left {
                        let c_cur0 = c.add(m_i * c_rs);
                        let ap_cur = ap.add(m_i * k);
                        let mut n_i = 0;
                        while n_i < n_rounded {
                            let b_cur = b.add(n_i * b_cs);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x $NR _bs_partial>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, m_left, f);
                            n_i += NR;
                        }
                        if n_left != 0 {
                            let b_cur = b.add(n_i * b_cs);
                            let c_cur1 = c_cur0.add(n_i * c_cs);
                            [<ukernel_$mr_left x n _bs_partial>]::<_, STRIDED>(ap_cur, b_cur, c_cur1, alpha, beta, k, d_arr, m_left, n_left, f);
                        }
                        return;
                    }
                )*

                asm!("vzeroupper");
            }
        }
    };
}
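
// As with the packed-B variants above, `def_kernel_bs` is only wired up by the
// per-architecture crates. A hypothetical invocation (types and tile sizes are
// illustrative assumptions):
//
//     def_kernel_bs!(f32, f32, f32, f32, f32, 24, 4, 16, 8);
//
// This emits `kernel_24x4_bs_v0`, which reads B through its native strides
// (`b_rs`/`b_cs` are forwarded to the microkernels via `d_arr`) rather than
// from a packed buffer, while A is consumed from the packed `ap` panels.
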
#[macro_export]
macro_rules! c_mem {
    (0) => {
        "0({cx})"
    };
    (1) => {
        "0({cx}, {x0})"
    };
    (2) => {
        "0({cx}, {x0}, 2)"
    };
    (3) => {
        "0({x1})"
    };
    (4) => {
        "0({x1}, {x0})"
    };
    (5) => {
        "0({x1}, {x0}, 2)"
    };
    (6) => {
        "0({x2})"
    };
    (7) => {
        "0({x2}, {x0})"
    };
    (8) => {
        "0({x2}, {x0}, 2)"
    };
    (9) => {
        "0({x3})"
    };
    (10) => {
        "0({x3}, {x0})"
    };
    (11) => {
        "0({x3}, {x0}, 2)"
    };
    (12) => {
        "0({x4})"
    };
    (13) => {
        "0({x4}, {x0})"
    };
    (14) => {
        "0({x4}, {x0}, 2)"
    };
}

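// `c_mem!(i)` expands to an AT&T-syntax memory operand for the `i`-th line of
// the C micro-tile: indices 0..=2 are addressed off the base register `{cx}`,
// 3..=5 off `{x1}`, and so on, with `{x0}` holding a C stride in bytes and
// `{x1}`..`{x4}` presumably preloaded three strides apart by the surrounding
// asm! block. The expansion is just a string literal, so it can be checked
// directly; a minimal sketch:
#[cfg(test)]
mod c_mem_sketch {
    #[test]
    fn expands_to_expected_operands() {
        assert_eq!(c_mem!(0), "0({cx})");
        assert_eq!(c_mem!(4), "0({x1}, {x0})");
        assert_eq!(c_mem!(14), "0({x4}, {x0}, 2)");
    }
}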