glar_gemm_f32/
lib.rs

#[cfg(target_arch = "x86_64")]
pub(crate) mod x86_64_arch;

#[cfg(target_arch = "aarch64")]
pub(crate) mod armv8;

pub(crate) mod reference;

pub(crate) type TA = f32;
pub(crate) type TB = f32;
pub(crate) type TC = f32;

#[derive(Copy, Clone)]
pub(crate) struct NullFn;

pub(crate) trait MyFn: Copy + std::marker::Sync {
    unsafe fn call(self, c: *mut TC, m: usize);
}

impl MyFn for NullFn {
    #[inline(always)]
    unsafe fn call(self, _c: *mut TC, _m: usize) {}
}

impl MyFn for unsafe fn(*mut TC, m: usize) {
    #[inline(always)]
    unsafe fn call(self, c: *mut TC, m: usize) {
        self(c, m);
    }
}

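// Illustrative sketch (not part of the crate's API, hence `allow(dead_code)`): the impl above
// means any `unsafe fn(*mut TC, usize)` can act as a fused epilogue. For example, a hypothetical
// in-place ReLU over `m` contiguous elements of C could look like this, mirroring `my_unary`
// in the tests below.
#[allow(dead_code)]
unsafe fn relu_epilogue_sketch(c: *mut TC, m: usize) {
    for i in 0..m {
        let v = *c.add(i);
        *c.add(i) = if v > 0.0 { v } else { 0.0 };
    }
}
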
#[cfg(target_arch = "x86_64")]
use x86_64_arch::X86_64dispatcher;

use reference::RefGemm;

use glar_base::{
    ap_size, bp_size, get_cache_params, has_f32_compute, Array, ArrayMut, GemmCache, GlarPar, HWModel,
    RUNTIME_HW_CONFIG,
};

#[inline(always)]
pub(crate) unsafe fn load_buf(c: *const TC, c_rs: usize, c_cs: usize, c_buf: &mut [TC], m: usize, n: usize) {
    for j in 0..n {
        for i in 0..m {
            c_buf[i + j * m] = *c.add(i * c_rs + j * c_cs);
        }
    }
}

#[inline(always)]
pub(crate) unsafe fn store_buf(c: *mut TC, c_rs: usize, c_cs: usize, c_buf: &[TC], m: usize, n: usize) {
    for j in 0..n {
        for i in 0..m {
            *c.add(i * c_rs + j * c_cs) = c_buf[i + j * m];
        }
    }
}

#[inline(always)]
fn get_mcnckc() -> (usize, usize, usize) {
    // let mc = std::env::var("GLAR_MC").unwrap_or("4800".to_string()).parse::<usize>().unwrap();
    // let nc = std::env::var("GLAR_NC").unwrap_or("192".to_string()).parse::<usize>().unwrap();
    // let kc = std::env::var("GLAR_KC").unwrap_or("768".to_string()).parse::<usize>().unwrap();
    // return (mc, nc, kc);
    let (mc, nc, kc) = match (*RUNTIME_HW_CONFIG).hw_model {
        HWModel::Skylake => (4800, 384, 1024),
        HWModel::Haswell => (4800, 320, 192),
        _ => get_cache_params(),
    };
    (mc, nc, kc)
}

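// Hypothetical helper (unused, kept only as a sketch): re-enables the env-var override shown in
// the comments above, falling back to the tuned/cached defaults when a variable is missing or
// cannot be parsed.
#[allow(dead_code)]
fn get_mcnckc_from_env() -> (usize, usize, usize) {
    let read = |name: &str, default: usize| {
        std::env::var(name).ok().and_then(|v| v.parse::<usize>().ok()).unwrap_or(default)
    };
    let (mc, nc, kc) = get_mcnckc();
    (read("GLAR_MC", mc), read("GLAR_NC", nc), read("GLAR_KC", kc))
}
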
pub(crate) unsafe fn glar_sgemm_generic<F: MyFn>(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: Array<TA>,
    b: Array<TB>,
    beta: TC,
    c: ArrayMut<TC>,
    f: F,
) {
    let par = GlarPar::default(m, n);
    let (mc, nc, kc) = get_mcnckc();
    // the optimized dispatcher only exists on x86_64 (see the cfg-gated import above)
    #[cfg(target_arch = "x86_64")]
    {
        if has_f32_compute() {
            let hw_config = X86_64dispatcher::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, f);
            x86_64_arch::glar_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
            return;
        }
    }
    // if no optimized path is available, fall back to the reference implementation
    let hw_config = RefGemm::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, f);
    reference::glar_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
}

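/// Strided single-precision GEMM, `C = alpha * A * B + beta * C`, with `A` of shape `m x k`,
/// `B` of shape `k x n` and `C` of shape `m x n`, each described by a (row stride, column
/// stride) pair.
///
/// A minimal usage sketch, assuming column-major buffers (illustrative only, not run as a
/// doc-test):
///
/// ```ignore
/// let (m, n, k) = (4, 3, 2);
/// let a = vec![1.0_f32; m * k];
/// let b = vec![1.0_f32; k * n];
/// let mut c = vec![0.0_f32; m * n];
/// unsafe {
///     // column-major strides: row stride 1, column stride = number of rows
///     glar_sgemm(m, n, k, 1.0, a.as_ptr(), 1, m, b.as_ptr(), 1, k, 0.0, c.as_mut_ptr(), 1, m);
/// }
/// ```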
pub unsafe fn glar_sgemm(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    b: *const TB,
    b_rs: usize,
    b_cs: usize,
    beta: TC,
    c: *mut TC,
    c_rs: usize,
    c_cs: usize,
) {
    // if c is row-strided (c_cs == 1 && c_rs != 1), compute the transposed problem instead:
    // swap a/b and m/n so that c can be treated as column-major
    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
    } else {
        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
    };
    let a = Array::strided_matrix(a, a_rs, a_cs);
    let b = Array::strided_matrix(b, b_rs, b_cs);
    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
    let null_fn = NullFn {};
    glar_sgemm_generic(m, n, k, alpha, a, b, beta, c, null_fn);
}

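/// Same as [`glar_sgemm`], but additionally applies the `unary` callback to the computed output
/// in place (a fused epilogue; the tests pass a function that doubles `m` contiguous elements of
/// `C`). A minimal sketch, assuming column-major buffers and a hypothetical doubling callback
/// (illustrative only, not run as a doc-test):
///
/// ```ignore
/// fn double(c: *mut f32, m: usize) {
///     for i in 0..m {
///         unsafe { *c.add(i) *= 2.0 };
///     }
/// }
/// let (m, n, k) = (4, 3, 2);
/// let a = vec![1.0_f32; m * k];
/// let b = vec![1.0_f32; k * n];
/// let mut c = vec![0.0_f32; m * n];
/// unsafe {
///     glar_sgemm_fused(m, n, k, 1.0, a.as_ptr(), 1, m, b.as_ptr(), 1, k, 0.0, c.as_mut_ptr(), 1, m, double);
/// }
/// ```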
#[cfg(feature = "fuse")]
pub unsafe fn glar_sgemm_fused(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    b: *const TB,
    b_rs: usize,
    b_cs: usize,
    beta: TC,
    c: *mut TC,
    c_rs: usize,
    c_cs: usize,
    unary: fn(*mut TC, usize),
) {
    // if c is row-strided (c_cs == 1 && c_rs != 1), compute the transposed problem instead:
    // swap a/b and m/n so that c can be treated as column-major
    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
    } else {
        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
    };
    let a = Array::strided_matrix(a, a_rs, a_cs);
    let b = Array::strided_matrix(b, b_rs, b_cs);
    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
    // coerce the safe fn pointer to the `unsafe fn` pointer type that implements `MyFn`
    let unary: unsafe fn(*mut TC, usize) = unary;
    glar_sgemm_generic(m, n, k, alpha, a, b, beta, c, unary);
}

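/// Strided single-precision GEMV, `y = alpha * A * x + beta * y`, with `A` of shape `m x n`;
/// implemented as an `m x 1` GEMM. A minimal sketch, assuming a column-major `A` (illustrative
/// only, not run as a doc-test):
///
/// ```ignore
/// let (m, n) = (3, 2);
/// let a = vec![1.0_f32; m * n];
/// let x = vec![1.0_f32; n];
/// let mut y = vec![0.0_f32; m];
/// unsafe {
///     glar_sgemv(m, n, 1.0, a.as_ptr(), 1, m, x.as_ptr(), 1, 0.0, y.as_mut_ptr(), 1);
/// }
/// ```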
pub unsafe fn glar_sgemv(
    m: usize,
    n: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    x: *const TB,
    incx: usize,
    beta: TC,
    y: *mut TC,
    incy: usize,
) {
    glar_sgemm(m, 1, n, alpha, a, a_rs, a_cs, x, 1, incx, beta, y, 1, incy)
}

pub unsafe fn glar_sdot(
    n: usize,
    alpha: TA,
    x: *const TA,
    incx: usize,
    y: *const TB,
    incy: usize,
    beta: TC,
    res: *mut TC,
) {
    glar_sgemm(1, 1, n, alpha, x, incx, 1, y, incy, 1, beta, res, 1, 1)
}

// Block indexing for packa and packb is chosen so that blocks are contiguous along the m dim
// (for A) and the n dim (for B). This keeps indexing simple when parallelizing over those dims
// (otherwise ranges would have to stay within the same mc/nc block). The k dim is not an issue
// since we do not parallelize over it; revisit this if we ever parallelize over k, which only
// pays off in the special case of very large k with small m and n.
pub unsafe fn packa_f32(m: usize, k: usize, a: *const TA, a_rs: usize, a_cs: usize, ap: *mut TA) -> Array<TA> {
    assert_eq!(ap.align_offset(glar_base::AB_ALIGN), 0);
    let mut ap = ap;
    if m == 1 {
        for j in 0..k {
            *ap.add(j) = *a.add(j * a_cs);
        }
        return Array::strided_matrix(ap, 1, m);
    }
    let (mc, nc, kc) = get_mcnckc();
    // reference config, used as the fallback when no optimized path is available
    let hw_config_ref = RefGemm::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, NullFn {});

    #[cfg(target_arch = "x86_64")]
    {
        // the optimized dispatcher only exists on x86_64 (see the cfg-gated import above)
        let hw_config = X86_64dispatcher::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, NullFn {});
        let ap0 = ap;
        let vs = if has_f32_compute() { hw_config.vs } else { hw_config_ref.vs };
        for p in (0..k).step_by(kc) {
            let kc_len = if k >= (p + kc) { kc } else { k - p };
            for i in (0..m).step_by(mc) {
                let mc_len = if m >= (i + mc) { mc } else { m - i };
                // round the packed panel height up to a multiple of the vector size
                let mc_len_eff = (mc_len + vs - 1) / vs * vs;
                let a_cur = a.add(i * a_rs + p * a_cs);
                if has_f32_compute() {
                    hw_config.packa_fn(a_cur, ap, mc_len, kc_len, a_rs, a_cs);
                } else {
                    hw_config_ref.packa_fn(a_cur, ap, mc_len, kc_len, a_rs, a_cs);
                }
                ap = ap.add(mc_len_eff * kc_len);
            }
        }
        return Array::packed_matrix(ap0, m, k);
    }
}

pub unsafe fn packb_f32(n: usize, k: usize, b: *const TB, b_rs: usize, b_cs: usize, bp: *mut TB) -> Array<TB> {
    assert_eq!(bp.align_offset(glar_base::AB_ALIGN), 0);
    let mut bp = bp;
    if n == 1 {
        for j in 0..k {
            *bp.add(j) = *b.add(j * b_rs);
        }
        return Array::strided_matrix(bp, 1, k);
    }
    let (mc, nc, kc) = get_mcnckc();
    // reference config, used as the fallback when no optimized path is available
    let hw_config_ref = RefGemm::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, NullFn {});
    #[cfg(target_arch = "x86_64")]
    {
        // the optimized dispatcher only exists on x86_64 (see the cfg-gated import above)
        let hw_config = X86_64dispatcher::from_hw_cfg(&*RUNTIME_HW_CONFIG, mc, nc, kc, NullFn {});
        let bp0 = bp;
        for p in (0..k).step_by(kc) {
            let kc_len = if k >= (p + kc) { kc } else { k - p };
            for i in (0..n).step_by(nc) {
                let nc_len = if n >= (i + nc) { nc } else { n - i };
                let nc_len_eff = nc_len;
                let b_cur = b.add(i * b_cs + p * b_rs);
                if has_f32_compute() {
                    hw_config.packb_fn(b_cur, bp, nc_len, kc_len, b_rs, b_cs);
                } else {
                    hw_config_ref.packb_fn(b_cur, bp, nc_len, kc_len, b_rs, b_cs);
                }
                bp = bp.add(nc_len_eff * kc_len);
            }
        }
        return Array::packed_matrix(bp0, n, k);
    }
}

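/// Packs `A` (`m x k`, strides `a_rs`/`a_cs`) into the internal blocked layout, writing into the
/// caller-provided slice `ap` and returning a packed `Array` view over it. A minimal sketch,
/// assuming a column-major `A` and a buffer sized with `ap_size` as in the tests below
/// (illustrative only, not run as a doc-test):
///
/// ```ignore
/// let (m, k) = (8, 4);
/// let a = vec![1.0_f32; m * k];
/// let mut ap = vec![0.0_f32; ap_size::<f32>(m, k)];
/// let a_packed = unsafe { packa_f32_with_ref(m, k, &a, 1, m, &mut ap) };
/// ```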
pub unsafe fn packa_f32_with_ref(m: usize, k: usize, a: &[TA], a_rs: usize, a_cs: usize, ap: &mut [TA]) -> Array<TA> {
    let pack_size = ap_size::<TA>(m, k);
    let ap_align_offset = ap.as_ptr().align_offset(glar_base::AB_ALIGN);
    // safety check
    assert!(ap.len() >= pack_size);
    let ap = &mut ap[ap_align_offset..];
    unsafe { packa_f32(m, k, a.as_ptr(), a_rs, a_cs, ap.as_mut_ptr()) }
}

pub unsafe fn packb_f32_with_ref(n: usize, k: usize, b: &[TB], b_rs: usize, b_cs: usize, bp: &mut [TB]) -> Array<TB> {
    let pack_size = bp_size::<TB>(n, k);
    let bp_align_offset = bp.as_ptr().align_offset(glar_base::AB_ALIGN);
    // safety check
    assert!(bp.len() >= pack_size);
    let bp = &mut bp[bp_align_offset..];
    unsafe { packb_f32(n, k, b.as_ptr(), b_rs, b_cs, bp.as_mut_ptr()) }
}

#[cfg(test)]
mod tests {
    use super::*;
    use glar_base::matrix_size;
    use glar_dev::{
        check_gemm_f32, generate_k_dims, generate_m_dims, generate_n_dims, layout_to_strides, random_matrix_uniform,
        ABLayout,
    };

    unsafe fn my_unary(c: *mut TC, m: usize) {
        for i in 0..m {
            *c.add(i) *= 2.0;
        }
    }

    // fn my_unary(_c: *mut TC, _m: usize) {}

    const EPS: f64 = 2e-2;

    // static ALPHA_ARR: [f32; 2] = [1.0, 3.1415];
    // static BETA_ARR: [f32; 3] = [1.0, 3.1415, 0.0];
    static ALPHA_ARR: [f32; 1] = [1.0];
    static BETA_ARR: [f32; 1] = [1.0];

    fn test_gemm(layout: &ABLayout, is_a_packed: bool, is_b_packed: bool) {
        let (mc, nc, kc) = get_mcnckc();
        let (mr, nr, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);
        let unary_fn: unsafe fn(*mut TC, usize) = my_unary;
        for m in m_dims.iter() {
            let m = *m;
            let (c_rs, c_cs) = (1, m);
            for n in n_dims.iter() {
                let n = *n;
                let c_size = matrix_size(c_rs, c_cs, m, n);
                let mut c = vec![0.0; c_size];
                let mut c_ref = vec![0.0; c_size];
                for k in k_dims.iter() {
                    let k = *k;
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = layout_to_strides(&layout, m, n, k);
                    let mut a = vec![0.0; m * k];
                    let mut b = vec![0.0; k * n];
                    random_matrix_uniform(m, k, &mut a, m);
                    random_matrix_uniform(k, n, &mut b, k);
                    let ap_size = if is_a_packed { ap_size::<TA>(m, k) } else { 0 };
                    let mut ap = vec![0_f32; ap_size];
                    let a_matrix = if is_a_packed {
                        unsafe { packa_f32_with_ref(m, k, &a, a_rs, a_cs, &mut ap) }
                    } else {
                        Array::strided_matrix(a.as_ptr(), a_rs, a_cs)
                    };
                    let bp_size = if is_b_packed { bp_size::<TB>(n, k) } else { 0 };
                    let mut bp = vec![0_f32; bp_size];
                    let b_matrix = if is_b_packed {
                        unsafe { packb_f32_with_ref(n, k, &b, b_rs, b_cs, &mut bp) }
                    } else {
                        Array::strided_matrix(b.as_ptr(), b_rs, b_cs)
                    };
                    for alpha in ALPHA_ARR {
                        for beta in BETA_ARR {
                            random_matrix_uniform(m, n, &mut c, m);
                            c_ref.copy_from_slice(&c);
                            let c_matrix = ArrayMut::strided_matrix(c.as_mut_ptr(), c_rs, c_cs);
                            unsafe {
                                glar_sgemm_generic(m, n, k, alpha, a_matrix, b_matrix, beta, c_matrix, unary_fn);
                            }
                            let diff_max = unsafe {
                                check_gemm_f32(
                                    m,
                                    n,
                                    k,
                                    alpha,
                                    a.as_ptr(),
                                    a_rs,
                                    a_cs,
                                    b.as_ptr(),
                                    b_rs,
                                    b_cs,
                                    beta,
                                    &mut c,
                                    c_rs,
                                    c_cs,
                                    &mut c_ref,
                                    unary_fn,
                                    EPS,
                                )
                            };
                            // if diff_max >= EPS {
                            //     println!("a: {:?}", a);
                            //     println!("b: {:?}", b);
                            //     println!("c:     {:?}", c);
                            //     println!("c_ref: {:?}", c_ref);
                            // }
                            assert!(
                                diff_max < EPS,
                                "diff_max: {}, m: {}, n: {}, k: {}, alpha: {}, beta: {}",
                                diff_max,
                                m,
                                n,
                                k,
                                alpha,
                                beta
                            );
                        }
                    }
                }
            }
        }
    }

    #[test]
    fn test_nn_col() {
        test_gemm(&ABLayout::NN, false, false);
    }

    #[test]
    fn test_nt_col() {
        test_gemm(&ABLayout::NT, false, false);
    }

    #[test]
    fn test_tn_col() {
        test_gemm(&ABLayout::TN, false, false);
    }

    #[test]
    fn test_tt_col() {
        test_gemm(&ABLayout::TT, false, false);
    }

    #[test]
    fn test_nn_col_ap() {
        test_gemm(&ABLayout::NN, true, false);
    }

    #[test]
    fn test_nt_col_ap() {
        test_gemm(&ABLayout::NT, true, false);
    }

    #[test]
    fn test_tn_col_ap() {
        test_gemm(&ABLayout::TN, true, false);
    }

    #[test]
    fn test_tt_col_ap() {
        test_gemm(&ABLayout::TT, true, false);
    }

    #[test]
    fn test_nn_col_bp() {
        test_gemm(&ABLayout::NN, false, true);
    }

    #[test]
    fn test_nt_col_bp() {
        test_gemm(&ABLayout::NT, false, true);
    }

    #[test]
    fn test_tn_col_bp() {
        test_gemm(&ABLayout::TN, false, true);
    }

    #[test]
    fn test_tt_col_bp() {
        test_gemm(&ABLayout::TT, false, true);
    }

    #[test]
    fn test_nn_col_apbp() {
        test_gemm(&ABLayout::NN, true, true);
    }

    #[test]
    fn test_nt_col_apbp() {
        test_gemm(&ABLayout::NT, true, true);
    }

    #[test]
    fn test_tn_col_apbp() {
        test_gemm(&ABLayout::TN, true, true);
    }

    #[test]
    fn test_tt_col_apbp() {
        test_gemm(&ABLayout::TT, true, true);
    }
}