1#[cfg(target_arch = "aarch64")]
2pub(crate) mod arm64;
3#[cfg(target_arch = "x86_64")]
4pub(crate) mod x86_64_arch;
5#[cfg(target_arch = "x86")]
6pub(crate) mod x86_arch;
7
8#[cfg(target_arch = "x86_64")]
9use x86_64_arch::{
10 get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher,
11};
12
13#[cfg(target_arch = "x86")]
14use x86_arch::{
15 get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher,
16};
17
18#[cfg(target_arch = "aarch64")]
19use arm64::{get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher};
20
21pub(crate) mod reference;
22use core::mem::size_of;
23
/// Element type of the A operand.
pub(crate) type TA = f64;
/// Element type of the B operand.
pub(crate) type TB = f64;
/// Element type of the C (output) operand.
pub(crate) type TC = f64;
#[allow(unused)]
// Size in bytes of one C element; allowed-unused because presumably only some
// target-specific code paths read it — confirm against the arch modules.
const TC_SIZE: usize = size_of::<TC>();
29
30use pire_base::{
31 get_cache_params, has_f64_compute, Array, ArrayMut, GemmCache, IdentityFn, PirePar, UnaryFn, AB_ALIGN,
32};
33use reference::{packa_fn_ref, packb_fn_ref, round_k_ref, round_m_ref, RefGemm};
34
/// Shorthand bound for unary epilogue functions applied to C elements.
/// Blanket-implemented for every `UnaryFn<TC>`, so it adds no requirements of
/// its own — it only shortens the generic bounds below.
pub trait UnaryFnC: UnaryFn<TC> {}
impl<F: UnaryFn<TC>> UnaryFnC for F {}
37
/// Internal f64 GEMM entry point over pre-wrapped `Array`/`ArrayMut` operands,
/// computing `C = alpha * A * B + beta * C` with the unary epilogue `f` fused
/// into the output stage.
///
/// Dispatch: on x86_64/x86/aarch64 with runtime f64 SIMD support
/// (`has_f64_compute()`), the arch-specific `pire_gemm` runs; on all other
/// targets, or when SIMD is unavailable, the portable reference
/// implementation is used.
///
/// # Safety
/// The `Array`/`ArrayMut` wrappers carry raw pointers; the caller must ensure
/// they reference allocations valid for an `m x k` A, `k x n` B, and
/// writable `m x n` C under their respective strides.
pub(crate) unsafe fn pire_dgemm_fused<F: UnaryFnC>(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: Array<TA>,
    b: Array<TB>,
    beta: TC,
    c: ArrayMut<TC>,
    f: F,
) {
    // Thread/parallelism configuration derived from the problem shape.
    let par = PirePar::default(m, n);
    if has_f64_compute() {
        // This arm only exists on architectures with a SIMD backend; on other
        // targets the cfg removes it and control falls through to the
        // reference path below. Note `f` is moved here, which is fine because
        // the branch always returns.
        #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
        {
            let hw_config = KernelDispatcher::new(f);
            pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
            return;
        }
    }
    // Portable fallback path.
    let hw_config = RefGemm::new(f);
    reference::pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
}
62pub unsafe fn pire_dgemm(
63 m: usize,
64 n: usize,
65 k: usize,
66 alpha: TA,
67 a: *const TA,
68 a_rs: usize,
69 a_cs: usize,
70 b: *const TB,
71 b_rs: usize,
72 b_cs: usize,
73 beta: TC,
74 c: *mut TC,
75 c_rs: usize,
76 c_cs: usize,
77) {
78 let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
80 (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
81 } else {
82 (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
83 };
84 let a = Array::strided_matrix(a, a_rs, a_cs);
85 let b = Array::strided_matrix(b, b_rs, b_cs);
86 let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
87 let identity_fn = IdentityFn {};
88 pire_dgemm_fused(m, n, k, alpha, a, b, beta, c, identity_fn);
89}
90
#[cfg(feature = "fuse")]
/// Variant of `pire_dgemm` that fuses a caller-supplied unary function into
/// the GEMM epilogue. `unary` receives a pointer into C and an element count
/// (see `unary_fn_test` in the tests for an example). Only built with the
/// `fuse` feature.
///
/// # Safety
/// Same pointer/stride requirements as `pire_dgemm`; additionally `unary`
/// must be sound to call on any in-bounds run of C elements.
pub unsafe fn pire_dgemm_fn_ptr(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    b: *const TB,
    b_rs: usize,
    b_cs: usize,
    beta: TC,
    c: *mut TC,
    c_rs: usize,
    c_cs: usize,
    unary: unsafe fn(*mut TC, usize),
) {
    // Row-major C is handled by solving the transposed problem
    // C^T = alpha * B^T * A^T + beta * C^T; transposing an operand swaps its
    // row/column strides, hence b_cs/b_rs and a_cs/a_rs are exchanged along
    // with the operand swap.
    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
        (n, m, b_cs, b_rs, a_cs, a_rs, c_cs, c_rs, b, a)
    } else {
        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
    };
    let a = Array::strided_matrix(a, a_rs, a_cs);
    let b = Array::strided_matrix(b, b_rs, b_cs);
    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
    pire_dgemm_fused(m, n, k, alpha, a, b, beta, c, unary);
}
120
121fn dispatch_round_m() -> fn(usize) -> usize {
122 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
123 {
124 if has_f64_compute() {
125 return round_m_simd;
126 }
127 }
128 round_m_ref
129}
130fn dispatch_round_k() -> fn(usize) -> usize {
131 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
132 {
133 if has_f64_compute() {
134 return round_k_simd;
135 }
136 }
137 round_k_ref
138}
139
140fn dispatch_pack_a() -> unsafe fn(*const TA, *mut TA, usize, usize, usize, usize) {
141 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
142 {
143 if has_f64_compute() {
144 return packa_fn_simd;
145 }
146 }
147 packa_fn_ref
148}
149
150fn dispatch_pack_b() -> unsafe fn(*const TB, *mut TB, usize, usize, usize, usize) {
151 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
152 {
153 if has_f64_compute() {
154 return packb_fn_simd;
155 }
156 }
157 packb_fn_ref
158}
159
160fn dispatch_get_mcnckc() -> (usize, usize, usize) {
161 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
162 {
163 if has_f64_compute() {
164 return get_mcnckc_simd();
165 }
166 }
167 get_cache_params()
168}
169
// Expands the crate's packing API for the f64 operand types — presumably the
// source of `pack_a`/`pack_b` and `a_size_packed`/`b_size_packed` used by the
// tests below; confirm against `pire_base::packing_api!`.
pire_base::packing_api!(TA, TB);
171
#[cfg(test)]
mod tests {
    // Correctness tests: packing round-trips plus fused GEMM runs compared
    // against the `pire_dev` reference checker for every layout/packing
    // combination.
    use super::*;
    use pire_base::{get_cache_params, matrix_size};
    use pire_dev::{
        check_gemm_f64, generate_k_dims, generate_m_dims, generate_n_dims, layout_to_strides, random_matrix_uniform,
        ABLayout,
    };
    // Packing A over a sweep of (m, k) sizes must produce a packed
    // (non-strided) array, except for the degenerate m == 1 case.
    #[test]
    fn test_pack_a() {
        let a_stride_scale = 1;
        let (mc, _, kc) = get_mcnckc();
        // (mr, _, kr): register-blocking dims used to generate the test sizes.
        let (mr, _, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let k_dims = generate_k_dims(kc, kr);

        for &m in &m_dims {
            for &k in &k_dims {
                // Column-major A: unit row stride, column stride m.
                let a_rs = 1 * a_stride_scale;
                let a_cs = m * a_stride_scale;
                let a_size = a_size_packed(m, k);
                let a = vec![0.0; m * k * a_stride_scale];
                // Over-allocate so the packed buffer can be aligned to AB_ALIGN.
                let mut ap = vec![0.0; a_size + AB_ALIGN];
                let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
                let ap_array = pack_a(m, k, &a, a_rs, a_cs, &mut ap[ap_align_offset..]);
                assert!(!ap_array.is_strided() || m == 1);
            }
        }
    }

    // Same as test_pack_a, for the B operand (the exception is n == 1).
    #[test]
    fn test_pack_b() {
        let b_stride_scale = 1;
        let (_, nc, kc) = get_mcnckc();
        let (_, nr, kr) = (48, 8, 8);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);

        for &n in &n_dims {
            for &k in &k_dims {
                // Column-major B: unit row stride, column stride k.
                let b_rs = 1 * b_stride_scale;
                let b_cs = k * b_stride_scale;
                let b_size = b_size_packed(n, k);
                let b = vec![0.0; n * k * b_stride_scale];
                let mut bp = vec![0.0; b_size + AB_ALIGN];
                let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
                let bp_array = pack_b(n, k, &b, b_rs, b_cs, &mut bp[bp_align_offset..]);
                assert!(!bp_array.is_strided() || n == 1);
            }
        }
    }

    // Cache-blocking dims used to size the test sweeps.
    // NOTE(review): unlike `dispatch_get_mcnckc`, this only special-cases
    // x86_64 (no x86/aarch64 arms and no has_f64_compute() check) — confirm
    // this divergence is intentional.
    #[allow(unreachable_code)]
    pub(crate) fn get_mcnckc() -> (usize, usize, usize) {
        #[cfg(target_arch = "x86_64")]
        {
            return x86_64_arch::get_mcnckc_simd();
        }
        get_cache_params()
    }

    // Fused epilogue used by the GEMM tests: doubles the first `m` elements
    // starting at `c`.
    unsafe fn unary_fn_test(c: *mut TC, m: usize) {
        for i in 0..m {
            *c.add(i) *= 2.0;
        }
    }

    // Maximum tolerated elementwise difference vs. the reference result.
    const EPS: f64 = 2e-2;

    // Scalar coefficients exercised by test_gemm.
    static ALPHA_ARR: [f64; 1] = [1.79];
    static BETA_ARR: [f64; 1] = [3.0];

    // Runs the fused GEMM over a sweep of (m, n, k) sizes for one A/B layout
    // and packing combination, checking each result with check_gemm_f64.
    fn test_gemm(layout: &ABLayout, is_a_packed: bool, is_b_packed: bool) {
        let a_stride_scale = 1;
        let b_stride_scale = 1;
        // C strides are doubled to exercise non-contiguous output handling.
        let c_stride_scale = 2;
        let (mc, nc, kc) = get_mcnckc();
        let (mr, nr, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);
        let unary_fn: unsafe fn(*mut TC, usize) = unary_fn_test;
        // Allocate once for the largest problem; smaller (m, n, k) runs reuse
        // the same buffers.
        let m_max = *m_dims.iter().max().unwrap();
        let n_max = *n_dims.iter().max().unwrap();
        let k_max = *k_dims.iter().max().unwrap();
        let a_size = matrix_size(m_max, k_max) * a_stride_scale;
        let b_size = matrix_size(k_max, n_max) * b_stride_scale;
        let c_size = matrix_size(m_max, n_max) * c_stride_scale;
        let mut a = vec![0f64; a_size];
        let mut b = vec![0f64; b_size];
        random_matrix_uniform(&mut a);
        random_matrix_uniform(&mut b);
        let mut c = vec![0f64; c_size];
        let mut c_ref = vec![0f64; c_size];

        // Aligned scratch for the packed A operand (only sized when used).
        let ap_size = if is_a_packed { a_size_packed(m_max, k_max) } else { 0 };
        let mut ap = vec![0f64; ap_size + AB_ALIGN];
        let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
        let ap_mut_ref = &mut ap[ap_align_offset..];

        // Aligned scratch for the packed B operand (only sized when used).
        let bp_size = if is_b_packed { b_size_packed(n_max, k_max) } else { 0 };
        let mut bp = vec![0f64; bp_size + AB_ALIGN];
        let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
        let bp_mut_ref = &mut bp[bp_align_offset..];
        for &m in &m_dims {
            for &n in &n_dims {
                for &k in &k_dims {
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = layout_to_strides(&layout, m, n, k);
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = (
                        a_rs * a_stride_scale,
                        a_cs * a_stride_scale,
                        b_rs * b_stride_scale,
                        b_cs * b_stride_scale,
                        c_rs * c_stride_scale,
                        c_cs * c_stride_scale,
                    );
                    // Hand the kernel either a pre-packed operand or a strided view.
                    let a_matrix = if is_a_packed {
                        pack_a(m, k, &a, a_rs, a_cs, ap_mut_ref)
                    } else {
                        Array::strided_matrix(a.as_ptr(), a_rs, a_cs)
                    };
                    let b_matrix = if is_b_packed {
                        pack_b(n, k, &b, b_rs, b_cs, bp_mut_ref)
                    } else {
                        Array::strided_matrix(b.as_ptr(), b_rs, b_cs)
                    };
                    for alpha in ALPHA_ARR {
                        for beta in BETA_ARR {
                            // Fresh random C each run; keep a copy for the
                            // reference computation.
                            random_matrix_uniform(&mut c);
                            c_ref.copy_from_slice(&c);
                            let c_matrix = ArrayMut::strided_matrix(c.as_mut_ptr(), c_rs, c_cs);
                            unsafe {
                                pire_dgemm_fused(m, n, k, alpha, a_matrix, b_matrix, beta, c_matrix, unary_fn);
                            }
                            let diff_max = unsafe {
                                check_gemm_f64(
                                    m,
                                    n,
                                    k,
                                    alpha,
                                    a.as_ptr(),
                                    a_rs,
                                    a_cs,
                                    b.as_ptr(),
                                    b_rs,
                                    b_cs,
                                    beta,
                                    &mut c,
                                    c_rs,
                                    c_cs,
                                    &mut c_ref,
                                    unary_fn,
                                    EPS,
                                )
                            };
                            assert!(
                                diff_max < EPS,
                                "diff_max: {}, m: {}, n: {}, k: {}, alpha: {}, beta: {}",
                                diff_max,
                                m,
                                n,
                                k,
                                alpha,
                                beta
                            );
                        }
                    }
                }
            }
        }
    }
    // Layout matrix: NN/NT/TN/TT x {unpacked, A-packed, B-packed, both-packed}.
    #[test]
    fn test_nn_col() {
        test_gemm(&ABLayout::NN, false, false);
    }

    #[test]
    fn test_nt_col() {
        test_gemm(&ABLayout::NT, false, false);
    }

    #[test]
    fn test_tn_col() {
        test_gemm(&ABLayout::TN, false, false);
    }

    #[test]
    fn test_tt_col() {
        test_gemm(&ABLayout::TT, false, false);
    }
    #[test]
    fn test_nn_col_ap() {
        test_gemm(&ABLayout::NN, true, false);
    }
    #[test]
    fn test_nt_col_ap() {
        test_gemm(&ABLayout::NT, true, false);
    }
    #[test]
    fn test_tn_col_ap() {
        test_gemm(&ABLayout::TN, true, false);
    }
    #[test]
    fn test_tt_col_ap() {
        test_gemm(&ABLayout::TT, true, false);
    }
    #[test]
    fn test_nn_col_bp() {
        test_gemm(&ABLayout::NN, false, true);
    }
    #[test]
    fn test_nt_col_bp() {
        test_gemm(&ABLayout::NT, false, true);
    }
    #[test]
    fn test_tn_col_bp() {
        test_gemm(&ABLayout::TN, false, true);
    }
    #[test]
    fn test_tt_col_bp() {
        test_gemm(&ABLayout::TT, false, true);
    }

    #[test]
    fn test_nn_col_apbp() {
        test_gemm(&ABLayout::NN, true, true);
    }
    #[test]
    fn test_nt_col_apbp() {
        test_gemm(&ABLayout::NT, true, true);
    }
    #[test]
    fn test_tn_col_apbp() {
        test_gemm(&ABLayout::TN, true, true);
    }
    #[test]
    fn test_tt_col_apbp() {
        test_gemm(&ABLayout::TT, true, true);
    }
}
417}