//! pire_gemm_f64/lib.rs — double-precision GEMM with per-architecture dispatch.

1#[cfg(target_arch = "aarch64")]
2pub(crate) mod arm64;
3#[cfg(target_arch = "x86_64")]
4pub(crate) mod x86_64_arch;
5#[cfg(target_arch = "x86")]
6pub(crate) mod x86_arch;
7
8#[cfg(target_arch = "x86_64")]
9use x86_64_arch::{
10    get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher,
11};
12
13#[cfg(target_arch = "x86")]
14use x86_arch::{
15    get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher,
16};
17
18#[cfg(target_arch = "aarch64")]
19use arm64::{get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher};
20
21pub(crate) mod reference;
22use core::mem::size_of;
23
// Operand element types for this crate: A, B and C are all f64.
pub(crate) type TA = f64;
pub(crate) type TB = f64;
pub(crate) type TC = f64;
// Byte size of one C element. NOTE(review): presumably consumed by the
// cfg-gated arch modules, hence allow(unused) on targets that don't use it.
#[allow(unused)]
const TC_SIZE: usize = size_of::<TC>();
29
30use pire_base::{
31    get_cache_params, has_f64_compute, Array, ArrayMut, GemmCache, IdentityFn, PirePar, UnaryFn, AB_ALIGN,
32};
33use reference::{packa_fn_ref, packb_fn_ref, round_k_ref, round_m_ref, RefGemm};
34
35pub trait UnaryFnC: UnaryFn<TC> {}
36impl<F: UnaryFn<TC>> UnaryFnC for F {}
37
/// Fused gemm entry point shared by the public wrappers: computes
/// C = alpha * A * B + beta * C and applies the unary epilogue `f`,
/// dispatching to the arch-specific SIMD backend when f64 compute is
/// available and to the portable reference implementation otherwise.
///
/// # Safety
/// `a`, `b` and `c` must describe valid matrices consistent with (m, n, k)
/// and their strides; `c` must be writable for the duration of the call.
pub(crate) unsafe fn pire_dgemm_fused<F: UnaryFnC>(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: Array<TA>,
    b: Array<TB>,
    beta: TC,
    c: ArrayMut<TC>,
    f: F,
) {
    // NOTE(review): PirePar::default presumably derives a parallel partition
    // from the problem shape — see pire_base.
    let par = PirePar::default(m, n);
    if has_f64_compute() {
        // On targets outside this cfg list the block compiles away, so even
        // with has_f64_compute() true we fall through to the reference path.
        #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
        {
            let hw_config = KernelDispatcher::new(f);
            pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
            return;
        }
    }
    // if none of the optimized paths are available, use reference implementation
    let hw_config = RefGemm::new(f);
    reference::pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
}
62pub unsafe fn pire_dgemm(
63    m: usize,
64    n: usize,
65    k: usize,
66    alpha: TA,
67    a: *const TA,
68    a_rs: usize,
69    a_cs: usize,
70    b: *const TB,
71    b_rs: usize,
72    b_cs: usize,
73    beta: TC,
74    c: *mut TC,
75    c_rs: usize,
76    c_cs: usize,
77) {
78    // transpose if c is row strided i.e. c_cs == 1 and c_rs != 1
79    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
80        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
81    } else {
82        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
83    };
84    let a = Array::strided_matrix(a, a_rs, a_cs);
85    let b = Array::strided_matrix(b, b_rs, b_cs);
86    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
87    let identity_fn = IdentityFn {};
88    pire_dgemm_fused(m, n, k, alpha, a, b, beta, c, identity_fn);
89}
90
/// Strided f64 gemm with a fused epilogue: computes
/// C = alpha * A * B + beta * C, then the kernel applies `unary` to the
/// written portions of C (the callback receives a pointer and element count).
///
/// # Safety
/// Same matrix/stride requirements as `pire_dgemm`; additionally `unary`
/// must be sound for every pointer/length pair inside C.
#[cfg(feature = "fuse")]
pub unsafe fn pire_dgemm_fn_ptr(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    b: *const TB,
    b_rs: usize,
    b_cs: usize,
    beta: TC,
    c: *mut TC,
    c_rs: usize,
    c_cs: usize,
    unary: unsafe fn(*mut TC, usize),
) {
    // Row-major C (c_cs == 1, c_rs != 1): compute C^T = B^T * A^T instead so
    // the kernel always sees a column-strided destination.
    let row_major_c = c_cs == 1 && c_rs != 1;
    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if row_major_c {
        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
    } else {
        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
    };
    let a_mat = Array::strided_matrix(a, a_rs, a_cs);
    let b_mat = Array::strided_matrix(b, b_rs, b_cs);
    let c_mat = ArrayMut::strided_matrix(c, c_rs, c_cs);
    pire_dgemm_fused(m, n, k, alpha, a_mat, b_mat, beta, c_mat, unary);
}
120
121fn dispatch_round_m() -> fn(usize) -> usize {
122    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
123    {
124        if has_f64_compute() {
125            return round_m_simd;
126        }
127    }
128    round_m_ref
129}
130fn dispatch_round_k() -> fn(usize) -> usize {
131    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
132    {
133        if has_f64_compute() {
134            return round_k_simd;
135        }
136    }
137    round_k_ref
138}
139
140fn dispatch_pack_a() -> unsafe fn(*const TA, *mut TA, usize, usize, usize, usize) {
141    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
142    {
143        if has_f64_compute() {
144            return packa_fn_simd;
145        }
146    }
147    packa_fn_ref
148}
149
150fn dispatch_pack_b() -> unsafe fn(*const TB, *mut TB, usize, usize, usize, usize) {
151    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
152    {
153        if has_f64_compute() {
154            return packb_fn_simd;
155        }
156    }
157    packb_fn_ref
158}
159
160fn dispatch_get_mcnckc() -> (usize, usize, usize) {
161    #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
162    {
163        if has_f64_compute() {
164            return get_mcnckc_simd();
165        }
166    }
167    get_cache_params()
168}
169
// NOTE(review): presumably expands to the public packing API (pack_a/pack_b,
// a_size_packed/b_size_packed) wired to the dispatch_* helpers above — see
// pire_base::packing_api.
pire_base::packing_api!(TA, TB);
171
172#[cfg(test)]
173mod tests {
174    use super::*;
175    use pire_base::{get_cache_params, matrix_size};
176    use pire_dev::{
177        check_gemm_f64, generate_k_dims, generate_m_dims, generate_n_dims, layout_to_strides, random_matrix_uniform,
178        ABLayout,
179    };
180    #[test]
181    fn test_pack_a() {
182        let a_stride_scale = 1;
183        let (mc, _, kc) = get_mcnckc();
184        let (mr, _, kr) = (48, 8, 8);
185        let m_dims = generate_m_dims(mc, mr);
186        let k_dims = generate_k_dims(kc, kr);
187
188        for &m in &m_dims {
189            for &k in &k_dims {
190                let a_rs = 1 * a_stride_scale;
191                let a_cs = m * a_stride_scale;
192                let a_size = a_size_packed(m, k);
193                let a = vec![0.0; m * k * a_stride_scale];
194                let mut ap = vec![0.0; a_size + AB_ALIGN];
195                let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
196                let ap_array = pack_a(m, k, &a, a_rs, a_cs, &mut ap[ap_align_offset..]);
197                assert!(!ap_array.is_strided() || m == 1);
198            }
199        }
200    }
201
202    #[test]
203    fn test_pack_b() {
204        let b_stride_scale = 1;
205        let (_, nc, kc) = get_mcnckc();
206        let (_, nr, kr) = (48, 8, 8);
207        let n_dims = generate_n_dims(nc, nr);
208        let k_dims = generate_k_dims(kc, kr);
209
210        for &n in &n_dims {
211            for &k in &k_dims {
212                let b_rs = 1 * b_stride_scale;
213                let b_cs = k * b_stride_scale;
214                let b_size = b_size_packed(n, k);
215                let b = vec![0.0; n * k * b_stride_scale];
216                let mut bp = vec![0.0; b_size + AB_ALIGN];
217                let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
218                let bp_array = pack_b(n, k, &b, b_rs, b_cs, &mut bp[bp_align_offset..]);
219                assert!(!bp_array.is_strided() || n == 1);
220            }
221        }
222    }
223
224    #[allow(unreachable_code)]
225    pub(crate) fn get_mcnckc() -> (usize, usize, usize) {
226        #[cfg(target_arch = "x86_64")]
227        {
228            return x86_64_arch::get_mcnckc_simd();
229        }
230        get_cache_params()
231    }
232
233    unsafe fn unary_fn_test(c: *mut TC, m: usize) {
234        for i in 0..m {
235            *c.add(i) *= 2.0;
236        }
237    }
238
    // Maximum allowed elementwise |c - c_ref| in test_gemm.
    // NOTE(review): 2e-2 is loose for f64 — presumably to absorb
    // accumulation-order differences between kernels; confirm.
    const EPS: f64 = 2e-2;

    // alpha/beta scaling factors swept by test_gemm (one value each).
    static ALPHA_ARR: [f64; 1] = [1.79];
    static BETA_ARR: [f64; 1] = [3.0];
243
    /// Exhaustive driver: runs `pire_dgemm_fused` over (m, n, k) grids that
    /// straddle the cache-blocking boundaries for the given A/B `layout`,
    /// optionally with pre-packed A and/or B, and checks the result against
    /// the reference gemm within `EPS`.
    fn test_gemm(layout: &ABLayout, is_a_packed: bool, is_b_packed: bool) {
        let a_stride_scale = 1;
        let b_stride_scale = 1;
        // C gets a non-unit stride scale so strided-output paths are hit.
        let c_stride_scale = 2;
        let (mc, nc, kc) = get_mcnckc();
        // NOTE(review): register-block sizes are hard-coded here rather than
        // queried from the arch backend — confirm they match the kernels.
        let (mr, nr, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);
        let unary_fn: unsafe fn(*mut TC, usize) = unary_fn_test;
        // Allocate every buffer once at the maximum problem size; smaller
        // (m, n, k) combinations reuse prefixes of the same allocations.
        let m_max = *m_dims.iter().max().unwrap();
        let n_max = *n_dims.iter().max().unwrap();
        let k_max = *k_dims.iter().max().unwrap();
        let a_size = matrix_size(m_max, k_max) * a_stride_scale;
        let b_size = matrix_size(k_max, n_max) * b_stride_scale;
        let c_size = matrix_size(m_max, n_max) * c_stride_scale;
        let mut a = vec![0f64; a_size];
        let mut b = vec![0f64; b_size];
        random_matrix_uniform(&mut a);
        random_matrix_uniform(&mut b);
        let mut c = vec![0f64; c_size];
        let mut c_ref = vec![0f64; c_size];

        // Packed-A scratch, over-allocated so it can be AB_ALIGN-aligned.
        let ap_size = if is_a_packed { a_size_packed(m_max, k_max) } else { 0 };
        let mut ap = vec![0f64; ap_size + AB_ALIGN];
        let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
        let ap_mut_ref = &mut ap[ap_align_offset..];

        // Packed-B scratch, same alignment treatment.
        let bp_size = if is_b_packed { b_size_packed(n_max, k_max) } else { 0 };
        let mut bp = vec![0f64; bp_size + AB_ALIGN];
        let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
        let bp_mut_ref = &mut bp[bp_align_offset..];
        for &m in &m_dims {
            for &n in &n_dims {
                for &k in &k_dims {
                    // Derive strides from the layout, then apply the scales.
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = layout_to_strides(&layout, m, n, k);
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = (
                        a_rs * a_stride_scale,
                        a_cs * a_stride_scale,
                        b_rs * b_stride_scale,
                        b_cs * b_stride_scale,
                        c_rs * c_stride_scale,
                        c_cs * c_stride_scale,
                    );
                    let a_matrix = if is_a_packed {
                        pack_a(m, k, &a, a_rs, a_cs, ap_mut_ref)
                    } else {
                        Array::strided_matrix(a.as_ptr(), a_rs, a_cs)
                    };
                    let b_matrix = if is_b_packed {
                        pack_b(n, k, &b, b_rs, b_cs, bp_mut_ref)
                    } else {
                        Array::strided_matrix(b.as_ptr(), b_rs, b_cs)
                    };
                    for alpha in ALPHA_ARR {
                        for beta in BETA_ARR {
                            // Fresh random C per run; keep a copy so the
                            // reference side starts from the same values.
                            random_matrix_uniform(&mut c);
                            c_ref.copy_from_slice(&c);
                            let c_matrix = ArrayMut::strided_matrix(c.as_mut_ptr(), c_rs, c_cs);
                            unsafe {
                                pire_dgemm_fused(m, n, k, alpha, a_matrix, b_matrix, beta, c_matrix, unary_fn);
                            }
                            // NOTE(review): check_gemm_f64 presumably computes
                            // the reference result into c_ref and returns the
                            // max elementwise deviation — see pire_dev.
                            let diff_max = unsafe {
                                check_gemm_f64(
                                    m,
                                    n,
                                    k,
                                    alpha,
                                    a.as_ptr(),
                                    a_rs,
                                    a_cs,
                                    b.as_ptr(),
                                    b_rs,
                                    b_cs,
                                    beta,
                                    &mut c,
                                    c_rs,
                                    c_cs,
                                    &mut c_ref,
                                    unary_fn,
                                    EPS,
                                )
                            };
                            // if diff_max >= EPS {
                            // 	println!("a: {:?}", a);
                            // 	println!("b: {:?}", b);
                            // 	println!("c:     {:?}", c);
                            // 	println!("c_ref: {:?}", c_ref);
                            // }
                            assert!(
                                diff_max < EPS,
                                "diff_max: {}, m: {}, n: {}, k: {}, alpha: {}, beta: {}",
                                diff_max,
                                m,
                                n,
                                k,
                                alpha,
                                beta
                            );
                        }
                    }
                }
            }
        }
    }
    // --- unpacked A and B, all four layouts ---
    #[test]
    fn test_nn_col() {
        test_gemm(&ABLayout::NN, false, false);
    }

    #[test]
    fn test_nt_col() {
        test_gemm(&ABLayout::NT, false, false);
    }

    #[test]
    fn test_tn_col() {
        test_gemm(&ABLayout::TN, false, false);
    }

    #[test]
    fn test_tt_col() {
        test_gemm(&ABLayout::TT, false, false);
    }
    // --- pre-packed A only ---
    #[test]
    fn test_nn_col_ap() {
        test_gemm(&ABLayout::NN, true, false);
    }
    #[test]
    fn test_nt_col_ap() {
        test_gemm(&ABLayout::NT, true, false);
    }
    #[test]
    fn test_tn_col_ap() {
        test_gemm(&ABLayout::TN, true, false);
    }
    #[test]
    fn test_tt_col_ap() {
        test_gemm(&ABLayout::TT, true, false);
    }
    // --- pre-packed B only ---
    #[test]
    fn test_nn_col_bp() {
        test_gemm(&ABLayout::NN, false, true);
    }
    #[test]
    fn test_nt_col_bp() {
        test_gemm(&ABLayout::NT, false, true);
    }
    #[test]
    fn test_tn_col_bp() {
        test_gemm(&ABLayout::TN, false, true);
    }
    #[test]
    fn test_tt_col_bp() {
        test_gemm(&ABLayout::TT, false, true);
    }

    // --- both A and B pre-packed ---
    #[test]
    fn test_nn_col_apbp() {
        test_gemm(&ABLayout::NN, true, true);
    }
    #[test]
    fn test_nt_col_apbp() {
        test_gemm(&ABLayout::NT, true, true);
    }
    #[test]
    fn test_tn_col_apbp() {
        test_gemm(&ABLayout::TN, true, true);
    }
    #[test]
    fn test_tt_col_apbp() {
        test_gemm(&ABLayout::TT, true, true);
    }
417}