// pire_gemm_f32/lib.rs

1#[cfg(target_arch = "aarch64")]
2pub(crate) mod arm64;
3#[cfg(target_arch = "x86_64")]
4pub(crate) mod x86_64_arch;
5#[cfg(target_arch = "x86")]
6pub(crate) mod x86_arch;
7
8#[cfg(target_arch = "x86_64")]
9use x86_64_arch::{
10    get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher,
11};
12
13#[cfg(target_arch = "x86")]
14use x86_arch::{
15    get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher,
16};
17
18#[cfg(target_arch = "aarch64")]
19use arm64::{get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher};
20
21pub(crate) mod reference;
22
23use core::mem::size_of;
24
// Element types of the f32 GEMM operands: A and B inputs, C output.
pub(crate) type TA = f32;
pub(crate) type TB = f32;
pub(crate) type TC = f32;
// Byte size of one C element; kept for cfg-gated code paths that need it.
#[allow(unused)]
const TC_SIZE: usize = size_of::<TC>();
30
31use reference::{packa_fn_ref, packb_fn_ref, round_k_ref, round_m_ref, RefGemm};
32
33use pire_base::{
34    get_cache_params, has_f32_compute, Array, ArrayMut, GemmCache, IdentityFn, PirePar, UnaryFn, AB_ALIGN,
35};
36
/// Marker trait for unary epilogue functions applied to C elements.
/// Blanket-implemented for every `UnaryFn<TC>`, so any such function
/// (including plain `unsafe fn(*mut TC, usize)` pointers used below)
/// can be passed to [`pire_sgemm_fused`].
pub trait UnaryFnC: UnaryFn<TC> {}
impl<F: UnaryFn<TC>> UnaryFnC for F {}
39
/// Core fused SGEMM: computes `C = alpha * A * B + beta * C` and applies
/// the unary function `f` to C as part of the kernel epilogue.
///
/// Dispatches to the arch-specific SIMD implementation when f32 compute
/// is available on a supported target, otherwise falls back to the
/// portable reference implementation.
///
/// # Safety
/// `a`, `b` and `c` must describe valid, in-bounds matrices for the given
/// `m`, `n`, `k` dimensions and their strides.
pub(crate) unsafe fn pire_sgemm_fused<F: UnaryFnC>(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: Array<TA>,
    b: Array<TB>,
    beta: TC,
    c: ArrayMut<TC>,
    f: F,
) {
    // Thread/partition configuration derived from the problem shape.
    let par = PirePar::default(m, n);
    if has_f32_compute() {
        // On unsupported targets this cfg block compiles to nothing, so
        // control falls through to the reference path below.
        #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
        {
            let hw_config = KernelDispatcher::new(f);
            pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
            return;
        }
    }
    // if none of the optimized paths are available, use reference implementation
    let hw_config = RefGemm::new(f);
    reference::pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
}
64
65pub unsafe fn pire_sgemm(
66    m: usize,
67    n: usize,
68    k: usize,
69    alpha: TA,
70    a: *const TA,
71    a_rs: usize,
72    a_cs: usize,
73    b: *const TB,
74    b_rs: usize,
75    b_cs: usize,
76    beta: TC,
77    c: *mut TC,
78    c_rs: usize,
79    c_cs: usize,
80) {
81    // transpose if c is row strided i.e. c_cs == 1 and c_rs != 1
82    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
83        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
84    } else {
85        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
86    };
87    let a = Array::strided_matrix(a, a_rs, a_cs);
88    let b = Array::strided_matrix(b, b_rs, b_cs);
89    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
90    let identity_fn = IdentityFn {};
91    pire_sgemm_fused(m, n, k, alpha, a, b, beta, c, identity_fn);
92}
93
/// Strided-pointer SGEMM with a caller-supplied unary epilogue fused into
/// the kernel (available with the `fuse` feature).
///
/// # Safety
/// Every pointer must be valid for the matrix extents implied by `m`, `n`,
/// `k` and the corresponding strides; `unary` must be sound to call on each
/// written span of C.
#[cfg(feature = "fuse")]
pub unsafe fn pire_sgemm_fn_ptr(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    b: *const TB,
    b_rs: usize,
    b_cs: usize,
    beta: TC,
    c: *mut TC,
    c_rs: usize,
    c_cs: usize,
    unary: unsafe fn(*mut TC, usize),
) {
    // If C is row-strided (c_cs == 1 while c_rs != 1), solve the transposed
    // problem C^T = B^T * A^T instead, so the kernel always sees a
    // column-strided C.
    let transpose = c_cs == 1 && c_rs != 1;
    let (m, n) = if transpose { (n, m) } else { (m, n) };
    let (lhs, lhs_rs, lhs_cs) = if transpose { (b, b_rs, b_cs) } else { (a, a_rs, a_cs) };
    let (rhs, rhs_rs, rhs_cs) = if transpose { (a, a_rs, a_cs) } else { (b, b_rs, b_cs) };
    let (c_rs, c_cs) = if transpose { (c_cs, c_rs) } else { (c_rs, c_cs) };
    let a = Array::strided_matrix(lhs, lhs_rs, lhs_cs);
    let b = Array::strided_matrix(rhs, rhs_rs, rhs_cs);
    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
    pire_sgemm_fused(m, n, k, alpha, a, b, beta, c, unary);
}
123
124fn dispatch_round_m() -> fn(usize) -> usize {
125    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
126    {
127        if has_f32_compute() {
128            return round_m_simd;
129        }
130    }
131    round_m_ref
132}
133fn dispatch_round_k() -> fn(usize) -> usize {
134    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
135    {
136        if has_f32_compute() {
137            return round_k_simd;
138        }
139    }
140    round_k_ref
141}
142
143fn dispatch_pack_a() -> unsafe fn(*const TA, *mut TA, usize, usize, usize, usize) {
144    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
145    {
146        if has_f32_compute() {
147            return packa_fn_simd;
148        }
149    }
150    packa_fn_ref
151}
152
153fn dispatch_pack_b() -> unsafe fn(*const TB, *mut TB, usize, usize, usize, usize) {
154    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
155    {
156        if has_f32_compute() {
157            return packb_fn_simd;
158        }
159    }
160    packb_fn_ref
161}
162
163fn dispatch_get_mcnckc() -> (usize, usize, usize) {
164    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
165    {
166        if has_f32_compute() {
167            return get_mcnckc_simd();
168        }
169    }
170    get_cache_params()
171}
172
// Expands the shared packing API for f32 operands — presumably generates
// `pack_a`/`pack_b` and `a_size_packed`/`b_size_packed` used by the tests
// below; confirm against the macro definition in `pire_base`.
pire_base::packing_api!(TA, TB);
174
#[cfg(test)]
mod tests {
    //! Packing-layout invariants for `pack_a`/`pack_b` and end-to-end GEMM
    //! correctness checks against a reference, over every A/B layout and
    //! packing combination.
    use super::*;
    use pire_base::{get_cache_params, matrix_size};
    use pire_dev::{
        check_gemm_f32, generate_k_dims, generate_m_dims, generate_n_dims, layout_to_strides, random_matrix_uniform,
        ABLayout,
    };
    /// Packing A must yield a packed (non-strided) array for every generated
    /// (m, k) size, except the degenerate m == 1 case.
    #[test]
    fn test_pack_a() {
        let a_stride_scale = 1;
        let (mc, _, kc) = get_mcnckc();
        // Microkernel register-block sizes (mr, nr, kr); nr unused here.
        let (mr, _, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let k_dims = generate_k_dims(kc, kr);

        for &m in &m_dims {
            for &k in &k_dims {
                // Column-major A: unit row stride, column stride = m.
                let a_rs = 1 * a_stride_scale;
                let a_cs = m * a_stride_scale;
                let a_size = a_size_packed(m, k);
                let a = vec![0.0; m * k * a_stride_scale];
                // Over-allocate so the packing destination can be aligned
                // to AB_ALIGN via the slice offset below.
                let mut ap = vec![0.0; a_size + AB_ALIGN];
                let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
                let ap_array = pack_a(m, k, &a, a_rs, a_cs, &mut ap[ap_align_offset..]);
                assert!(!ap_array.is_strided() || m == 1);
            }
        }
    }

    /// Same invariant as `test_pack_a`, for the B operand (n == 1 is the
    /// degenerate case here).
    #[test]
    fn test_pack_b() {
        let b_stride_scale = 1;
        let (_, nc, kc) = get_mcnckc();
        let (_, nr, kr) = (48, 8, 8);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);

        for &n in &n_dims {
            for &k in &k_dims {
                // Column-major B: unit row stride, column stride = k.
                let b_rs = 1 * b_stride_scale;
                let b_cs = k * b_stride_scale;
                let b_size = b_size_packed(n, k);
                let b = vec![0.0; n * k * b_stride_scale];
                let mut bp = vec![0.0; b_size + AB_ALIGN];
                let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
                let bp_array = pack_b(n, k, &b, b_rs, b_cs, &mut bp[bp_align_offset..]);
                assert!(!bp_array.is_strided() || n == 1);
            }
        }
    }

    /// Cache-blocking parameters used to generate test dimensions.
    /// NOTE(review): unlike `dispatch_get_mcnckc`, this always takes the
    /// x86_64 SIMD path without checking `has_f32_compute`, and has no
    /// aarch64/x86 arms — confirm that is intended for the tests.
    #[allow(unreachable_code)]
    pub(crate) fn get_mcnckc() -> (usize, usize, usize) {
        #[cfg(target_arch = "x86_64")]
        {
            return x86_64_arch::get_mcnckc_simd();
        }
        get_cache_params()
    }

    /// Unary epilogue used by the GEMM tests: doubles the first `m`
    /// elements starting at `c`.
    unsafe fn unary_fn_test(c: *mut TC, m: usize) {
        for i in 0..m {
            *c.add(i) *= 2.0;
        }
    }

    // Maximum allowed elementwise deviation from the reference result.
    const EPS: f64 = 2e-2;

    // Scaling factors exercised for alpha and beta.
    static ALPHA_ARR: [f32; 1] = [2.0];
    static BETA_ARR: [f32; 1] = [3.1415];

    /// Runs the fused GEMM over a grid of (m, n, k) sizes for the given A/B
    /// layout and packing choices, checking every result against
    /// `check_gemm_f32` with tolerance `EPS`.
    fn test_gemm(layout: &ABLayout, is_a_packed: bool, is_b_packed: bool) {
        let a_stride_scale = 1;
        let b_stride_scale = 1;
        // Non-unit scale so the C matrix is exercised with padded strides.
        let c_stride_scale = 2;
        let (mc, nc, kc) = get_mcnckc();
        let (mr, nr, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);
        let unary_fn: unsafe fn(*mut TC, usize) = unary_fn_test;
        // Allocate all buffers once, sized for the largest tested dims.
        let m_max = *m_dims.iter().max().unwrap();
        let n_max = *n_dims.iter().max().unwrap();
        let k_max = *k_dims.iter().max().unwrap();
        let a_size = matrix_size(m_max, k_max) * a_stride_scale;
        let b_size = matrix_size(k_max, n_max) * b_stride_scale;
        let c_size = matrix_size(m_max, n_max) * c_stride_scale;
        let mut a = vec![0f32; a_size];
        let mut b = vec![0f32; b_size];
        random_matrix_uniform(&mut a);
        random_matrix_uniform(&mut b);
        let mut c = vec![0f32; c_size];
        let mut c_ref = vec![0f32; c_size];

        // Aligned scratch buffers for pre-packed A/B when requested.
        let ap_size = if is_a_packed { a_size_packed(m_max, k_max) } else { 0 };
        let mut ap = vec![0f32; ap_size + AB_ALIGN];
        let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
        let ap_mut_ref = &mut ap[ap_align_offset..];

        let bp_size = if is_b_packed { b_size_packed(n_max, k_max) } else { 0 };
        let mut bp = vec![0f32; bp_size + AB_ALIGN];
        let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
        let bp_mut_ref = &mut bp[bp_align_offset..];
        for &m in &m_dims {
            for &n in &n_dims {
                for &k in &k_dims {
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = layout_to_strides(&layout, m, n, k);
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = (
                        a_rs * a_stride_scale,
                        a_cs * a_stride_scale,
                        b_rs * b_stride_scale,
                        b_cs * b_stride_scale,
                        c_rs * c_stride_scale,
                        c_cs * c_stride_scale,
                    );
                    // Feed either pre-packed or strided views of A and B.
                    let a_matrix = if is_a_packed {
                        pack_a(m, k, &a, a_rs, a_cs, ap_mut_ref)
                    } else {
                        Array::strided_matrix(a.as_ptr(), a_rs, a_cs)
                    };
                    let b_matrix = if is_b_packed {
                        pack_b(n, k, &b, b_rs, b_cs, bp_mut_ref)
                    } else {
                        Array::strided_matrix(b.as_ptr(), b_rs, b_cs)
                    };
                    for alpha in ALPHA_ARR {
                        for beta in BETA_ARR {
                            // Fresh random C, mirrored into c_ref so both
                            // paths start from identical contents.
                            random_matrix_uniform(&mut c);
                            c_ref.copy_from_slice(&c);
                            let c_matrix = ArrayMut::strided_matrix(c.as_mut_ptr(), c_rs, c_cs);
                            unsafe {
                                pire_sgemm_fused(m, n, k, alpha, a_matrix, b_matrix, beta, c_matrix, unary_fn);
                            }
                            let diff_max = unsafe {
                                check_gemm_f32(
                                    m,
                                    n,
                                    k,
                                    alpha,
                                    a.as_ptr(),
                                    a_rs,
                                    a_cs,
                                    b.as_ptr(),
                                    b_rs,
                                    b_cs,
                                    beta,
                                    &mut c,
                                    c_rs,
                                    c_cs,
                                    &mut c_ref,
                                    unary_fn,
                                    EPS,
                                )
                            };
                            // if diff_max >= EPS {
                            // 	println!("a: {:?}", a);
                            // 	println!("b: {:?}", b);
                            // 	println!("c:     {:?}", c);
                            // 	println!("c_ref: {:?}", c_ref);
                            // }
                            assert!(
                                diff_max < EPS,
                                "diff_max: {}, m: {}, n: {}, k: {}, alpha: {}, beta: {}",
                                diff_max,
                                m,
                                n,
                                k,
                                alpha,
                                beta
                            );
                        }
                    }
                }
            }
        }
    }
    // Layout matrix: N/T for A and B, with each packing combination
    // (neither packed, A packed, B packed, both packed).
    #[test]
    fn test_nn_col() {
        test_gemm(&ABLayout::NN, false, false);
    }

    #[test]
    fn test_nt_col() {
        test_gemm(&ABLayout::NT, false, false);
    }

    #[test]
    fn test_tn_col() {
        test_gemm(&ABLayout::TN, false, false);
    }

    #[test]
    fn test_tt_col() {
        test_gemm(&ABLayout::TT, false, false);
    }
    #[test]
    fn test_nn_col_ap() {
        test_gemm(&ABLayout::NN, true, false);
    }
    #[test]
    fn test_nt_col_ap() {
        test_gemm(&ABLayout::NT, true, false);
    }
    #[test]
    fn test_tn_col_ap() {
        test_gemm(&ABLayout::TN, true, false);
    }
    #[test]
    fn test_tt_col_ap() {
        test_gemm(&ABLayout::TT, true, false);
    }
    #[test]
    fn test_nn_col_bp() {
        test_gemm(&ABLayout::NN, false, true);
    }
    #[test]
    fn test_nt_col_bp() {
        test_gemm(&ABLayout::NT, false, true);
    }
    #[test]
    fn test_tn_col_bp() {
        test_gemm(&ABLayout::TN, false, true);
    }
    #[test]
    fn test_tt_col_bp() {
        test_gemm(&ABLayout::TT, false, true);
    }

    #[test]
    fn test_nn_col_apbp() {
        test_gemm(&ABLayout::NN, true, true);
    }
    #[test]
    fn test_nt_col_apbp() {
        test_gemm(&ABLayout::NT, true, true);
    }
    #[test]
    fn test_tn_col_apbp() {
        test_gemm(&ABLayout::TN, true, true);
    }
    #[test]
    fn test_tt_col_apbp() {
        test_gemm(&ABLayout::TT, true, true);
    }
}