#[cfg(target_arch = "x86_64")]
pub(crate) mod x86_64_arch;

#[cfg(target_arch = "aarch64")]
pub(crate) mod armv8;

pub(crate) mod reference;

pub(crate) type TA = f32;
pub(crate) type TB = f32;
pub(crate) type TC = f32;

#[derive(Copy, Clone)]
pub(crate) struct NullFn;

pub(crate) trait MyFn: Copy + std::marker::Sync {
    unsafe fn call(self, c: *mut TC, m: usize);
}

impl MyFn for NullFn {
    #[inline(always)]
    unsafe fn call(self, _c: *mut TC, _m: usize) {}
}

impl MyFn for unsafe fn(*mut TC, m: usize) {
    #[inline(always)]
    unsafe fn call(self, c: *mut TC, m: usize) {
        self(c, m);
    }
}
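
// Illustrative example: the impl above for `unsafe fn(*mut TC, usize)` means any
// such function can act as a fused post-op. It receives a pointer to a run of
// output elements and the run length, e.g. an in-place ReLU:
//
//     unsafe fn relu(c: *mut TC, m: usize) {
//         for i in 0..m {
//             if *c.add(i) < 0.0 {
//                 *c.add(i) = 0.0;
//             }
//         }
//     }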

#[cfg(target_arch = "x86_64")]
use x86_64_arch::X86_64dispatcher;

use reference::RefGemm;

use glar_base::{
    ap_size, bp_size, get_cache_params, has_f32_compute, Array, ArrayMut, GemmCache, GlarPar, HWModel,
    RUNTIME_HW_CONFIG,
};

#[inline(always)]
pub(crate) unsafe fn load_buf(c: *const TC, c_rs: usize, c_cs: usize, c_buf: &mut [TC], m: usize, n: usize) {
    // Gather a strided m x n block of C into a contiguous column-major buffer.
    for j in 0..n {
        for i in 0..m {
            c_buf[i + j * m] = *c.add(i * c_rs + j * c_cs);
        }
    }
}

#[inline(always)]
pub(crate) unsafe fn store_buf(c: *mut TC, c_rs: usize, c_cs: usize, c_buf: &[TC], m: usize, n: usize) {
    // Scatter a contiguous column-major buffer back into the strided m x n block of C.
    for j in 0..n {
        for i in 0..m {
            *c.add(i * c_rs + j * c_cs) = c_buf[i + j * m];
        }
    }
}

#[inline(always)]
fn get_mcnckc() -> (usize, usize, usize) {
    // Blocking parameters (mc, nc, kc) tuned per hardware model; otherwise fall
    // back to the generic cache-derived parameters.
    let (mc, nc, kc) = match (*RUNTIME_HW_CONFIG).hw_model {
        HWModel::Skylake => (4800, 384, 1024),
        HWModel::Haswell => (4800, 320, 192),
        _ => get_cache_params(),
    };
    (mc, nc, kc)
}

pub(crate) unsafe fn glar_sgemm_generic<F: MyFn>(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: Array<TA>,
    b: Array<TB>,
    beta: TC,
    c: ArrayMut<TC>,
    f: F,
) {
    let par = GlarPar::default(m, n);
    let (mc, nc, kc) = get_mcnckc();
    // Dispatch to the architecture-specific kernels when f32 compute is
    // available; the gate also keeps the x86_64-only dispatcher out of
    // non-x86_64 builds.
    #[cfg(target_arch = "x86_64")]
    {
        if has_f32_compute() {
            let hw_config = X86_64dispatcher::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, f);
            x86_64_arch::glar_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
            return;
        }
    }
    // Fallback: portable reference implementation.
    let hw_config = RefGemm::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, f);
    reference::glar_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
}

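/// Single-precision general matrix multiply over raw strided buffers:
/// `C := alpha * A * B + beta * C`, with A of shape m x k, B of shape k x n and
/// C of shape m x n, each described by a base pointer plus row and column
/// strides.
///
/// If C's strides describe a row-major layout (`c_cs == 1`, `c_rs != 1`), the
/// problem is restated as the equivalent transposed column-major one by
/// exchanging the roles of A and B and swapping C's strides, so the kernels
/// only ever see a column-major C.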
pub unsafe fn glar_sgemm(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    b: *const TB,
    b_rs: usize,
    b_cs: usize,
    beta: TC,
    c: *mut TC,
    c_rs: usize,
    c_cs: usize,
) {
    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
    } else {
        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
    };
    let a = Array::strided_matrix(a, a_rs, a_cs);
    let b = Array::strided_matrix(b, b_rs, b_cs);
    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
    let null_fn = NullFn {};
    glar_sgemm_generic(m, n, k, alpha, a, b, beta, c, null_fn);
}

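/// Same contract as [`glar_sgemm`], but additionally applies `unary` to the
/// computed output: the function pointer is invoked with a pointer to a run of
/// output elements and the length of that run, so it can scale, clamp or
/// otherwise post-process C in place. Only available with the `fuse` feature.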
#[cfg(feature = "fuse")]
pub unsafe fn glar_sgemm_fused(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    b: *const TB,
    b_rs: usize,
    b_cs: usize,
    beta: TC,
    c: *mut TC,
    c_rs: usize,
    c_cs: usize,
    unary: fn(*mut TC, usize),
) {
    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
    } else {
        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
    };
    let a = Array::strided_matrix(a, a_rs, a_cs);
    let b = Array::strided_matrix(b, b_rs, b_cs);
    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
    // `MyFn` is implemented for `unsafe fn(*mut TC, usize)`, so coerce the safe
    // function pointer before handing it to the generic driver.
    let unary: unsafe fn(*mut TC, usize) = unary;
    glar_sgemm_generic(m, n, k, alpha, a, b, beta, c, unary);
}

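/// Single-precision matrix-vector product, `y := alpha * A * x + beta * y`,
/// expressed as an m x 1 GEMM with k = n.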
pub unsafe fn glar_sgemv(
    m: usize,
    n: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    x: *const TB,
    incx: usize,
    beta: TC,
    y: *mut TC,
    incy: usize,
) {
    glar_sgemm(m, 1, n, alpha, a, a_rs, a_cs, x, 1, incx, beta, y, 1, incy)
}
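
/// Scaled dot product, `res := alpha * (x . y) + beta * res`, expressed as a
/// 1 x 1 GEMM with k = n.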
pub unsafe fn glar_sdot(
    n: usize,
    alpha: TA,
    x: *const TA,
    incx: usize,
    y: *const TB,
    incy: usize,
    beta: TC,
    res: *mut TC,
) {
    glar_sgemm(1, 1, n, alpha, x, incx, 1, y, incy, 1, beta, res, 1, 1)
}

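/// Packs the m x k matrix A into the blocked layout expected by the GEMM
/// kernels and returns an `Array` handle over the packed data in `ap`. `ap`
/// must be aligned to `glar_base::AB_ALIGN` (asserted) and large enough for the
/// packed data (see `ap_size`); the slice-based wrapper further below handles
/// alignment and size checking.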
pub unsafe fn packa_f32(m: usize, k: usize, a: *const TA, a_rs: usize, a_cs: usize, ap: *mut TA) -> Array<TA> {
    assert_eq!(ap.align_offset(glar_base::AB_ALIGN), 0);
    let mut ap = ap;
    if m == 1 {
        // A single row needs no blocking; copy it contiguously.
        for j in 0..k {
            *ap.add(j) = *a.add(j * a_cs);
        }
        return Array::strided_matrix(ap, 1, m);
    }
    let (mc, nc, kc) = get_mcnckc();
    let hw_config = X86_64dispatcher::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, NullFn {});
    let hw_config_ref = RefGemm::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, NullFn {});

    #[cfg(target_arch = "x86_64")]
    {
        let ap0 = ap;
        let vs = if has_f32_compute() { hw_config.vs } else { hw_config_ref.vs };
        for p in (0..k).step_by(kc) {
            let kc_len = if k >= (p + kc) { kc } else { k - p };
            for i in (0..m).step_by(mc) {
                let mc_len = if m >= (i + mc) { mc } else { m - i };
                // Packed panels are padded up to a multiple of the vector width.
                let mc_len_eff = (mc_len + vs - 1) / vs * vs;
                let a_cur = a.add(i * a_rs + p * a_cs);
                if has_f32_compute() {
                    hw_config.packa_fn(a_cur, ap, mc_len, kc_len, a_rs, a_cs);
                } else {
                    hw_config_ref.packa_fn(a_cur, ap, mc_len, kc_len, a_rs, a_cs);
                }
                ap = ap.add(mc_len_eff * kc_len);
            }
        }
        return Array::packed_matrix(ap0, m, k);
    }
}

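/// Packs the k x n matrix B into the blocked layout expected by the GEMM
/// kernels and returns an `Array` handle over the packed data in `bp`. `bp`
/// must be aligned to `glar_base::AB_ALIGN` (asserted) and large enough for the
/// packed data (see `bp_size`); the slice-based wrapper further below handles
/// alignment and size checking.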
pub unsafe fn packb_f32(n: usize, k: usize, b: *const TB, b_rs: usize, b_cs: usize, bp: *mut TB) -> Array<TB> {
    assert_eq!(bp.align_offset(glar_base::AB_ALIGN), 0);
    let mut bp = bp;
    if n == 1 {
        // A single column needs no blocking; copy it contiguously.
        for j in 0..k {
            *bp.add(j) = *b.add(j * b_rs);
        }
        return Array::strided_matrix(bp, 1, k);
    }
    let (mc, nc, kc) = get_mcnckc();
    let hw_config_ref = RefGemm::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, NullFn {});
    let hw_config = X86_64dispatcher::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, NullFn {});
    #[cfg(target_arch = "x86_64")]
    {
        let bp0 = bp;
        for p in (0..k).step_by(kc) {
            let kc_len = if k >= (p + kc) { kc } else { k - p };
            for i in (0..n).step_by(nc) {
                let nc_len = if n >= (i + nc) { nc } else { n - i };
                let nc_len_eff = nc_len;
                let b_cur = b.add(i * b_cs + p * b_rs);
                if has_f32_compute() {
                    hw_config.packb_fn(b_cur, bp, nc_len, kc_len, b_rs, b_cs);
                } else {
                    hw_config_ref.packb_fn(b_cur, bp, nc_len, kc_len, b_rs, b_cs);
                }
                bp = bp.add(nc_len_eff * kc_len);
            }
        }
        return Array::packed_matrix(bp0, n, k);
    }
}

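/// Slice-based wrapper around [`packa_f32`]: checks that `ap` is large enough
/// (per `ap_size`), skips to the first `AB_ALIGN`-aligned element and packs
/// into the remainder.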
pub unsafe fn packa_f32_with_ref(m: usize, k: usize, a: &[TA], a_rs: usize, a_cs: usize, ap: &mut [TA]) -> Array<TA> {
    let pack_size = ap_size::<TA>(m, k);
    let ap_align_offset = ap.as_ptr().align_offset(glar_base::AB_ALIGN);
    assert!(ap.len() >= pack_size);
    let ap = &mut ap[ap_align_offset..];
    unsafe { packa_f32(m, k, a.as_ptr(), a_rs, a_cs, ap.as_mut_ptr()) }
}

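/// Slice-based wrapper around [`packb_f32`]: checks that `bp` is large enough
/// (per `bp_size`), skips to the first `AB_ALIGN`-aligned element and packs
/// into the remainder.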
pub unsafe fn packb_f32_with_ref(n: usize, k: usize, b: &[TB], b_rs: usize, b_cs: usize, bp: &mut [TB]) -> Array<TB> {
    let pack_size = bp_size::<TB>(n, k);
    let bp_align_offset = bp.as_ptr().align_offset(glar_base::AB_ALIGN);
    assert!(bp.len() >= pack_size);
    let bp = &mut bp[bp_align_offset..];
    unsafe { packb_f32(n, k, b.as_ptr(), b_rs, b_cs, bp.as_mut_ptr()) }
}

#[cfg(test)]
mod tests {
    use super::*;
    use glar_base::matrix_size;
    use glar_dev::{
        check_gemm_f32, generate_k_dims, generate_m_dims, generate_n_dims, layout_to_strides, random_matrix_uniform,
        ABLayout,
    };

    unsafe fn my_unary(c: *mut TC, m: usize) {
        for i in 0..m {
            *c.add(i) *= 2.0;
        }
    }

    const EPS: f64 = 2e-2;

    static ALPHA_ARR: [f32; 1] = [1.0];
    static BETA_ARR: [f32; 1] = [1.0];

    fn test_gemm(layout: &ABLayout, is_a_packed: bool, is_b_packed: bool) {
        let (mc, nc, kc) = get_mcnckc();
        let (mr, nr, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);
        let unary_fn: unsafe fn(*mut TC, usize) = my_unary;
        for m in m_dims.iter() {
            let m = *m;
            let (c_rs, c_cs) = (1, m);
            for n in n_dims.iter() {
                let n = *n;
                let c_size = matrix_size(c_rs, c_cs, m, n);
                let mut c = vec![0.0; c_size];
                let mut c_ref = vec![0.0; c_size];
                for k in k_dims.iter() {
                    let k = *k;
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = layout_to_strides(&layout, m, n, k);
                    let mut a = vec![0.0; m * k];
                    let mut b = vec![0.0; k * n];
                    random_matrix_uniform(m, k, &mut a, m);
                    random_matrix_uniform(k, n, &mut b, k);
                    let ap_size = if is_a_packed { ap_size::<TA>(m, k) } else { 0 };
                    let mut ap = vec![0_f32; ap_size];
                    let a_matrix = if is_a_packed {
                        unsafe { packa_f32_with_ref(m, k, &a, a_rs, a_cs, &mut ap) }
                    } else {
                        Array::strided_matrix(a.as_ptr(), a_rs, a_cs)
                    };
                    let bp_size = if is_b_packed { bp_size::<TB>(n, k) } else { 0 };
                    let mut bp = vec![0_f32; bp_size];
                    let b_matrix = if is_b_packed {
                        unsafe { packb_f32_with_ref(n, k, &b, b_rs, b_cs, &mut bp) }
                    } else {
                        Array::strided_matrix(b.as_ptr(), b_rs, b_cs)
                    };
                    for alpha in ALPHA_ARR {
                        for beta in BETA_ARR {
                            random_matrix_uniform(m, n, &mut c, m);
                            c_ref.copy_from_slice(&c);
                            let c_matrix = ArrayMut::strided_matrix(c.as_mut_ptr(), c_rs, c_cs);
                            unsafe {
                                glar_sgemm_generic(m, n, k, alpha, a_matrix, b_matrix, beta, c_matrix, unary_fn);
                            }
                            let diff_max = unsafe {
                                check_gemm_f32(
                                    m,
                                    n,
                                    k,
                                    alpha,
                                    a.as_ptr(),
                                    a_rs,
                                    a_cs,
                                    b.as_ptr(),
                                    b_rs,
                                    b_cs,
                                    beta,
                                    &mut c,
                                    c_rs,
                                    c_cs,
                                    &mut c_ref,
                                    unary_fn,
                                    EPS,
                                )
                            };
                            assert!(
                                diff_max < EPS,
                                "diff_max: {}, m: {}, n: {}, k: {}, alpha: {}, beta: {}",
                                diff_max,
                                m,
                                n,
                                k,
                                alpha,
                                beta
                            );
                        }
                    }
                }
            }
        }
    }
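
    // Minimal hand-checked sanity test: a tiny 2x2 column-major product through
    // the public `glar_sgemm` entry point, with alpha = 1 and beta = 0.
    #[test]
    fn test_sgemm_2x2_sanity() {
        // Column-major A = [1 3; 2 4] and B = [5 7; 6 8].
        let a: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let b: [f32; 4] = [5.0, 6.0, 7.0, 8.0];
        let mut c = [0.0_f32; 4];
        // Hand-computed A * B = [23 31; 34 46], stored column-major.
        let expected = [23.0_f32, 34.0, 31.0, 46.0];
        unsafe {
            glar_sgemm(2, 2, 2, 1.0, a.as_ptr(), 1, 2, b.as_ptr(), 1, 2, 0.0, c.as_mut_ptr(), 1, 2);
        }
        for (got, want) in c.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5, "got {:?}, expected {:?}", c, expected);
        }
    }
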
    #[test]
    fn test_nn_col() {
        test_gemm(&ABLayout::NN, false, false);
    }

    #[test]
    fn test_nt_col() {
        test_gemm(&ABLayout::NT, false, false);
    }

    #[test]
    fn test_tn_col() {
        test_gemm(&ABLayout::TN, false, false);
    }

    #[test]
    fn test_tt_col() {
        test_gemm(&ABLayout::TT, false, false);
    }

    #[test]
    fn test_nn_col_ap() {
        test_gemm(&ABLayout::NN, true, false);
    }

    #[test]
    fn test_nt_col_ap() {
        test_gemm(&ABLayout::NT, true, false);
    }

    #[test]
    fn test_tn_col_ap() {
        test_gemm(&ABLayout::TN, true, false);
    }

    #[test]
    fn test_tt_col_ap() {
        test_gemm(&ABLayout::TT, true, false);
    }

    #[test]
    fn test_nn_col_bp() {
        test_gemm(&ABLayout::NN, false, true);
    }

    #[test]
    fn test_nt_col_bp() {
        test_gemm(&ABLayout::NT, false, true);
    }

    #[test]
    fn test_tn_col_bp() {
        test_gemm(&ABLayout::TN, false, true);
    }

    #[test]
    fn test_tt_col_bp() {
        test_gemm(&ABLayout::TT, false, true);
    }

    #[test]
    fn test_nn_col_apbp() {
        test_gemm(&ABLayout::NN, true, true);
    }

    #[test]
    fn test_nt_col_apbp() {
        test_gemm(&ABLayout::NT, true, true);
    }

    #[test]
    fn test_tn_col_apbp() {
        test_gemm(&ABLayout::TN, true, true);
    }

    #[test]
    fn test_tt_col_apbp() {
        test_gemm(&ABLayout::TT, true, true);
    }
}