// pire_gemm_f16/lib.rs

1#[cfg(target_arch = "aarch64")]
2pub(crate) mod arm64;
3#[cfg(target_arch = "x86_64")]
4pub(crate) mod x86_64_arch;
5
6#[cfg(target_arch = "x86_64")]
7use x86_64_arch::{
8    get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, pire_gemm_f32, round_k_simd, round_m_simd,
9    KernelDispatcher, KernelDispatcherF32,
10};
11
12use core::mem::size_of;
13
14#[cfg(target_arch = "aarch64")]
15use arm64::{get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher};
16pub(crate) mod reference;
17
// Operand element types: A and B inputs and the C output are all IEEE
// half-precision (`half::f16`).
pub(crate) type TA = f16;
pub(crate) type TB = f16;
pub(crate) type TC = f16;
// Byte size of one C element (2 for f16); only referenced on some cfg paths.
#[allow(unused)]
const TC_SIZE: usize = size_of::<TC>();

// Re-export `f16` so callers don't need a direct dependency on `half`.
pub use half::f16;
25
26use pire_base::{
27    get_cache_params, has_f16_compute, has_f16f32_compute, Array, ArrayMut, GemmCache, IdentityFn, PirePar, UnaryFn,
28    AB_ALIGN,
29};
30use reference::{packa_fn_ref, packb_fn_ref, round_k_ref, round_m_ref, RefGemm};
31
32pub trait UnaryFnC: UnaryFn<TC> {}
33impl<F: UnaryFn<TC>> UnaryFnC for F {}
34
/// Core fused half-precision GEMM: C = f(alpha * A * B + beta * C), where
/// `f` is a unary epilogue applied to the output.
///
/// Dispatch order:
/// 1. native-f16 SIMD kernel when the CPU reports f16 compute support,
/// 2. (x86_64 only) f16-storage / f32-compute kernel when f16->f32 compute
///    is available — alpha/beta are widened to f32 for it,
/// 3. the portable reference implementation otherwise.
///
/// # Safety
/// The `Array`/`ArrayMut` descriptors must reference valid memory for an
/// `m x k` A, `k x n` B and `m x n` C under their strides.
pub(crate) unsafe fn pire_hgemm_fused<F: UnaryFnC>(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: Array<TA>,
    b: Array<TB>,
    beta: TC,
    c: ArrayMut<TC>,
    f: F,
) {
    // Parallelization plan derived from the problem shape.
    let par = PirePar::default(m, n);
    if has_f16_compute() {
        // On arches without a SIMD backend this block compiles away and we
        // fall through to the next path.
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
        {
            let hw_config = KernelDispatcher::new(f);
            pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
            return;
        }
    }
    if has_f16f32_compute() {
        #[cfg(target_arch = "x86_64")]
        {
            // f16 storage with f32 accumulation: widen the scalars once here.
            let hw_config = KernelDispatcherF32::new(f);
            pire_gemm_f32(&hw_config, m, n, k, alpha.to_f32(), a, b, beta.to_f32(), c, &par);
            return;
        }
    }

    // if none of the optimized paths are available, use reference implementation
    let hw_config = RefGemm::new(f);
    reference::pire_gemm(&hw_config, m, n, k, alpha.to_f32(), a, b, beta.to_f32(), c, &par);
}
68
69pub unsafe fn pire_hgemm(
70    m: usize,
71    n: usize,
72    k: usize,
73    alpha: f16,
74    a: *const f16,
75    a_rs: usize,
76    a_cs: usize,
77    b: *const f16,
78    b_rs: usize,
79    b_cs: usize,
80    beta: f16,
81    c: *mut f16,
82    c_rs: usize,
83    c_cs: usize,
84) {
85    // do not exchange if transa && transb
86    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
87        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
88    } else {
89        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
90    };
91    let a = Array::strided_matrix(a, a_rs, a_cs);
92    let b = Array::strided_matrix(b, b_rs, b_cs);
93    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
94    let identity_fn = IdentityFn {};
95    pire_hgemm_fused(m, n, k, alpha, a, b, beta, c, identity_fn);
96}
97
/// Like [`pire_hgemm`], but fuses a caller-supplied unary epilogue into the
/// GEMM. The epilogue receives a pointer into C plus an element count (see
/// how the test epilogue is invoked with a column length).
///
/// # Safety
/// Same pointer/stride requirements as [`pire_hgemm`]; `unary` must be safe
/// to call on every region of C it is handed.
#[cfg(feature = "fuse")]
pub unsafe fn pire_hgemm_fn_ptr(
    m: usize,
    n: usize,
    k: usize,
    alpha: f16,
    a: *const f16,
    a_rs: usize,
    a_cs: usize,
    b: *const f16,
    b_rs: usize,
    b_cs: usize,
    beta: f16,
    c: *mut f16,
    c_rs: usize,
    c_cs: usize,
    unary: unsafe fn(*mut f16, usize),
) {
    // Row-major strided C (c_cs == 1, c_rs != 1): run the transposed problem
    // so the kernel sees a column-major C.
    let swap_operands = c_cs == 1 && c_rs != 1;
    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if swap_operands {
        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
    } else {
        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
    };
    let (a, b) = (Array::strided_matrix(a, a_rs, a_cs), Array::strided_matrix(b, b_rs, b_cs));
    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
    pire_hgemm_fused(m, n, k, alpha, a, b, beta, c, unary);
}
127
128fn dispatch_round_m() -> fn(usize) -> usize {
129    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
130    {
131        if has_f16_compute() || has_f16f32_compute() {
132            return round_m_simd;
133        }
134    }
135    round_m_ref
136}
137fn dispatch_round_k() -> fn(usize) -> usize {
138    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
139    {
140        if has_f16_compute() || has_f16f32_compute() {
141            return round_k_simd;
142        }
143    }
144    round_k_ref
145}
146
147fn dispatch_pack_a() -> unsafe fn(*const TA, *mut TA, usize, usize, usize, usize) {
148    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
149    {
150        if has_f16_compute() || has_f16f32_compute() {
151            return packa_fn_simd;
152        }
153    }
154    packa_fn_ref
155}
156
157fn dispatch_pack_b() -> unsafe fn(*const TB, *mut TB, usize, usize, usize, usize) {
158    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
159    {
160        if has_f16_compute() || has_f16f32_compute() {
161            return packb_fn_simd;
162        }
163    }
164    packb_fn_ref
165}
166
167fn dispatch_get_mcnckc() -> (usize, usize, usize) {
168    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
169    {
170        if has_f16_compute() || has_f16f32_compute() {
171            return get_mcnckc_simd();
172        }
173    }
174    get_cache_params()
175}
176
// Generates the public packing API for the f16 operand types — presumably
// `pack_a`/`pack_b` and `a_size_packed`/`b_size_packed` (used by the tests
// below) wired through the `dispatch_*` selectors above. NOTE(review):
// confirm the exact generated items against the macro in `pire_base`.
pire_base::packing_api!(TA, TB);
178
#[cfg(test)]
mod tests {
    use super::*;
    use pire_base::{get_cache_params, matrix_size};
    use pire_dev::{
        check_gemm_f16, generate_k_dims, generate_m_dims, generate_n_dims, layout_to_strides, random_matrix_uniform,
        ABLayout,
    };

    /// Packing a column-major `m x k` A for every generated (m, k) must
    /// yield a packed (non-strided) array, except when m == 1 where a
    /// strided view is acceptable.
    #[test]
    fn test_pack_a() {
        let a_stride_scale = 1;
        let (mc, _, kc) = get_mcnckc();
        // Register-blocking sizes used to generate dims around blocking
        // boundaries.
        let (mr, _, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let k_dims = generate_k_dims(kc, kr);

        for &m in &m_dims {
            for &k in &k_dims {
                let a_rs = 1 * a_stride_scale;
                let a_cs = m * a_stride_scale;
                let a_size = a_size_packed(m, k);
                let a = vec![TA::ZERO; m * k * a_stride_scale];
                // Over-allocate so the packed buffer can start AB_ALIGN-aligned.
                let mut ap = vec![TA::ZERO; a_size + AB_ALIGN];
                let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
                let ap_array = pack_a(m, k, &a, a_rs, a_cs, &mut ap[ap_align_offset..]);
                assert!(!ap_array.is_strided() || m == 1);
            }
        }
    }

    /// Same as `test_pack_a`, for the B operand (`k x n` with column stride k).
    #[test]
    fn test_pack_b() {
        let b_stride_scale = 1;
        let (_, nc, kc) = get_mcnckc();
        let (_, nr, kr) = (48, 8, 8);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);

        for &n in &n_dims {
            for &k in &k_dims {
                let b_rs = 1 * b_stride_scale;
                let b_cs = k * b_stride_scale;
                let b_size = b_size_packed(n, k);
                let b = vec![TB::ZERO; n * k * b_stride_scale];
                // Over-allocate so the packed buffer can start AB_ALIGN-aligned.
                let mut bp = vec![TB::ZERO; b_size + AB_ALIGN];
                let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
                let bp_array = pack_b(n, k, &b, b_rs, b_cs, &mut bp[bp_align_offset..]);
                assert!(!bp_array.is_strided() || n == 1);
            }
        }
    }

    /// Blocking parameters (mc, nc, kc) used to size the test dimension
    /// grids. Unlike `dispatch_get_mcnckc`, this unconditionally prefers the
    /// SIMD-tuned values so tests stress the SIMD blocking boundaries.
    /// Fix: the aarch64 branch was missing, so ARM test runs sized their
    /// grids from generic cache params instead of the arm64 kernel blocking.
    #[allow(unreachable_code)]
    pub(crate) fn get_mcnckc() -> (usize, usize, usize) {
        #[cfg(target_arch = "x86_64")]
        {
            return x86_64_arch::get_mcnckc_simd();
        }
        #[cfg(target_arch = "aarch64")]
        {
            return arm64::get_mcnckc_simd();
        }
        get_cache_params()
    }

    /// Test epilogue: doubles each of the `m` elements starting at `c`, so a
    /// fused epilogue is observable in the checked result.
    unsafe fn unary_fn_test(c: *mut TC, m: usize) {
        for i in 0..m {
            *c.add(i) *= f16::from_f32(2.0);
        }
    }

    // Tolerance against the reference result. Note 10e-1 == 1.0: f16 carries
    // only ~3 decimal digits, and long-k accumulation plus the x2 epilogue
    // magnifies rounding error, hence the loose bound.
    const EPS: f64 = 10e-1;

    static ALPHA_ARR: [f16; 1] = [f16::from_f32_const(1.23)];
    static BETA_ARR: [f16; 1] = [f16::from_f32_const(1.17)];

    /// Drives `pire_hgemm_fused` over a grid of (m, n, k) around the
    /// blocking boundaries for one A/B layout, optionally pre-packing A
    /// and/or B, and compares each result against a reference GEMM.
    fn test_gemm(layout: &ABLayout, is_a_packed: bool, is_b_packed: bool) {
        let a_stride_scale = 1;
        let b_stride_scale = 1;
        // C uses a stride scale > 1 so non-contiguous output is exercised.
        let c_stride_scale = 2;
        let (mc, nc, kc) = get_mcnckc();
        let (mr, nr, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);
        let unary_fn: unsafe fn(*mut TC, usize) = unary_fn_test;
        // Allocate once for the largest dims; smaller problems reuse buffers.
        let m_max = *m_dims.iter().max().unwrap();
        let n_max = *n_dims.iter().max().unwrap();
        let k_max = *k_dims.iter().max().unwrap();
        let a_size = matrix_size(m_max, k_max) * a_stride_scale;
        let b_size = matrix_size(k_max, n_max) * b_stride_scale;
        let c_size = matrix_size(m_max, n_max) * c_stride_scale;
        let mut a = vec![TA::ZERO; a_size];
        let mut b = vec![TB::ZERO; b_size];
        random_matrix_uniform(&mut a);
        random_matrix_uniform(&mut b);
        let mut c = vec![TC::ZERO; c_size];
        let mut c_ref = vec![TC::ZERO; c_size];

        let ap_size = if is_a_packed { a_size_packed(m_max, k_max) } else { 0 };
        let mut ap = vec![TA::ZERO; ap_size + AB_ALIGN];
        let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
        let ap_mut_ref = &mut ap[ap_align_offset..];

        let bp_size = if is_b_packed { b_size_packed(n_max, k_max) } else { 0 };
        let mut bp = vec![TB::ZERO; bp_size + AB_ALIGN];
        let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
        let bp_mut_ref = &mut bp[bp_align_offset..];
        for &m in &m_dims {
            for &n in &n_dims {
                for &k in &k_dims {
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = layout_to_strides(&layout, m, n, k);
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = (
                        a_rs * a_stride_scale,
                        a_cs * a_stride_scale,
                        b_rs * b_stride_scale,
                        b_cs * b_stride_scale,
                        c_rs * c_stride_scale,
                        c_cs * c_stride_scale,
                    );
                    let a_matrix = if is_a_packed {
                        pack_a(m, k, &a, a_rs, a_cs, ap_mut_ref)
                    } else {
                        Array::strided_matrix(a.as_ptr(), a_rs, a_cs)
                    };
                    let b_matrix = if is_b_packed {
                        pack_b(n, k, &b, b_rs, b_cs, bp_mut_ref)
                    } else {
                        Array::strided_matrix(b.as_ptr(), b_rs, b_cs)
                    };
                    for alpha in ALPHA_ARR {
                        for beta in BETA_ARR {
                            // Fresh random C, mirrored into c_ref so both the
                            // tested and reference GEMM see identical input.
                            random_matrix_uniform(&mut c);
                            c_ref.copy_from_slice(&c);
                            let c_matrix = ArrayMut::strided_matrix(c.as_mut_ptr(), c_rs, c_cs);
                            unsafe {
                                pire_hgemm_fused(m, n, k, alpha, a_matrix, b_matrix, beta, c_matrix, unary_fn);
                            }
                            let diff_max = unsafe {
                                check_gemm_f16(
                                    m,
                                    n,
                                    k,
                                    alpha,
                                    a.as_ptr(),
                                    a_rs,
                                    a_cs,
                                    b.as_ptr(),
                                    b_rs,
                                    b_cs,
                                    beta,
                                    &mut c,
                                    c_rs,
                                    c_cs,
                                    &mut c_ref,
                                    unary_fn,
                                    EPS,
                                )
                            };
                            assert!(
                                diff_max < EPS,
                                "diff_max: {}, m: {}, n: {}, k: {}, alpha: {}, beta: {}",
                                diff_max,
                                m,
                                n,
                                k,
                                alpha,
                                beta
                            );
                        }
                    }
                }
            }
        }
    }

    // One test per (layout, packing) combination so failures localize.
    #[test]
    fn test_nn_col_ap() {
        test_gemm(&ABLayout::NN, true, false);
    }
    #[test]
    fn test_nt_col_ap() {
        test_gemm(&ABLayout::NT, true, false);
    }
    #[test]
    fn test_tn_col_ap() {
        test_gemm(&ABLayout::TN, true, false);
    }
    #[test]
    fn test_tt_col_ap() {
        test_gemm(&ABLayout::TT, true, false);
    }
    #[test]
    fn test_nn_col_bp() {
        test_gemm(&ABLayout::NN, false, true);
    }
    #[test]
    fn test_nt_col_bp() {
        test_gemm(&ABLayout::NT, false, true);
    }
    #[test]
    fn test_tn_col_bp() {
        test_gemm(&ABLayout::TN, false, true);
    }
    #[test]
    fn test_tt_col_bp() {
        test_gemm(&ABLayout::TT, false, true);
    }
    #[test]
    fn test_nn_col() {
        test_gemm(&ABLayout::NN, false, false);
    }
    #[test]
    fn test_nt_col() {
        test_gemm(&ABLayout::NT, false, false);
    }
    #[test]
    fn test_tn_col() {
        test_gemm(&ABLayout::TN, false, false);
    }
    #[test]
    fn test_tt_col() {
        test_gemm(&ABLayout::TT, false, false);
    }
    #[test]
    fn test_nn_col_apbp() {
        test_gemm(&ABLayout::NN, true, true);
    }
    #[test]
    fn test_nt_col_apbp() {
        test_gemm(&ABLayout::NT, true, true);
    }
    #[test]
    fn test_tn_col_apbp() {
        test_gemm(&ABLayout::TN, true, true);
    }
    #[test]
    fn test_tt_col_apbp() {
        test_gemm(&ABLayout::TT, true, true);
    }
}