1#[cfg(target_arch = "aarch64")]
2pub(crate) mod arm64;
3#[cfg(target_arch = "x86_64")]
4pub(crate) mod x86_64_arch;
5#[cfg(target_arch = "x86")]
6pub(crate) mod x86_arch;
7
8#[cfg(target_arch = "x86_64")]
9use x86_64_arch::{
10 get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher,
11};
12
13#[cfg(target_arch = "x86")]
14use x86_arch::{
15 get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher,
16};
17
18#[cfg(target_arch = "aarch64")]
19use arm64::{get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher};
20
21pub(crate) mod reference;
22
23use core::mem::size_of;
24
// Element types for this crate's single-precision GEMM:
// A-matrix, B-matrix, and C-(output)-matrix elements respectively.
pub(crate) type TA = f32;
pub(crate) type TB = f32;
pub(crate) type TC = f32;
// Byte size of one C element; marked #[allow(unused)] since not every
// arch/back-end configuration appears to read it — TODO confirm which do.
#[allow(unused)]
const TC_SIZE: usize = size_of::<TC>();
30
31use reference::{packa_fn_ref, packb_fn_ref, round_k_ref, round_m_ref, RefGemm};
32
33use pire_base::{
34 get_cache_params, has_f32_compute, Array, ArrayMut, GemmCache, IdentityFn, PirePar, UnaryFn, AB_ALIGN,
35};
36
/// Trait alias for element-wise epilogue functions applied to `TC` (f32) output.
pub trait UnaryFnC: UnaryFn<TC> {}
/// Blanket impl: anything implementing `UnaryFn<TC>` qualifies automatically.
impl<F: UnaryFn<TC>> UnaryFnC for F {}
39
/// Fused single-precision GEMM core: `C = alpha * A * B + beta * C`, with the
/// element-wise epilogue `f` applied to the output by the selected backend.
///
/// Dispatches at runtime: when the CPU reports f32 compute support (and the
/// target arch has a SIMD backend compiled in), the arch-specific kernels run;
/// otherwise the portable reference implementation is used.
///
/// # Safety
/// The `Array`/`ArrayMut` views wrap raw pointers; callers must ensure they
/// describe allocations valid for the given `m`/`n`/`k` and their strides, and
/// that `c` is valid for writes.
pub(crate) unsafe fn pire_sgemm_fused<F: UnaryFnC>(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: Array<TA>,
    b: Array<TB>,
    beta: TC,
    c: ArrayMut<TC>,
    f: F,
) {
    let par = PirePar::default(m, n);
    if has_f32_compute() {
        // The SIMD path only exists on these arches; elsewhere this whole block
        // compiles away and execution falls through to the reference path below.
        #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
        {
            let hw_config = KernelDispatcher::new(f);
            pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
            return;
        }
    }
    // Portable reference fallback (also taken on SIMD arches when f32 compute
    // support is not detected at runtime).
    let hw_config = RefGemm::new(f);
    reference::pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
}
64
65pub unsafe fn pire_sgemm(
66 m: usize,
67 n: usize,
68 k: usize,
69 alpha: TA,
70 a: *const TA,
71 a_rs: usize,
72 a_cs: usize,
73 b: *const TB,
74 b_rs: usize,
75 b_cs: usize,
76 beta: TC,
77 c: *mut TC,
78 c_rs: usize,
79 c_cs: usize,
80) {
81 let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
83 (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
84 } else {
85 (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
86 };
87 let a = Array::strided_matrix(a, a_rs, a_cs);
88 let b = Array::strided_matrix(b, b_rs, b_cs);
89 let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
90 let identity_fn = IdentityFn {};
91 pire_sgemm_fused(m, n, k, alpha, a, b, beta, c, identity_fn);
92}
93
/// Raw-pointer SGEMM entry that fuses a caller-supplied epilogue function
/// pointer: `C = alpha * A * B + beta * C`, then `unary` is applied to the
/// output rows by the kernels.
///
/// Uses the same row-major-C transposition trick as [`pire_sgemm`] so the
/// backends only ever see a column-major destination.
///
/// # Safety
/// Same pointer/stride validity requirements as [`pire_sgemm`]; additionally,
/// `unary` must be safe to call on any in-bounds output span.
#[cfg(feature = "fuse")]
pub unsafe fn pire_sgemm_fn_ptr(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    b: *const TB,
    b_rs: usize,
    b_cs: usize,
    beta: TC,
    c: *mut TC,
    c_rs: usize,
    c_cs: usize,
    unary: unsafe fn(*mut TC, usize),
) {
    // Shadow the arguments mutably so the transpose case can swap them in place.
    let (mut m, mut n) = (m, n);
    let (mut a, mut b) = (a, b);
    let (mut a_rs, mut a_cs) = (a_rs, a_cs);
    let (mut b_rs, mut b_cs) = (b_rs, b_cs);
    let (mut c_rs, mut c_cs) = (c_rs, c_cs);
    if c_cs == 1 && c_rs != 1 {
        // Row-major C: solve the transposed problem C^T = B^T * A^T instead.
        core::mem::swap(&mut m, &mut n);
        core::mem::swap(&mut a, &mut b);
        core::mem::swap(&mut a_rs, &mut b_rs);
        core::mem::swap(&mut a_cs, &mut b_cs);
        core::mem::swap(&mut c_rs, &mut c_cs);
    }
    let a = Array::strided_matrix(a, a_rs, a_cs);
    let b = Array::strided_matrix(b, b_rs, b_cs);
    let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
    pire_sgemm_fused(m, n, k, alpha, a, b, beta, c, unary);
}
123
124fn dispatch_round_m() -> fn(usize) -> usize {
125 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
126 {
127 if has_f32_compute() {
128 return round_m_simd;
129 }
130 }
131 round_m_ref
132}
133fn dispatch_round_k() -> fn(usize) -> usize {
134 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
135 {
136 if has_f32_compute() {
137 return round_k_simd;
138 }
139 }
140 round_k_ref
141}
142
143fn dispatch_pack_a() -> unsafe fn(*const TA, *mut TA, usize, usize, usize, usize) {
144 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
145 {
146 if has_f32_compute() {
147 return packa_fn_simd;
148 }
149 }
150 packa_fn_ref
151}
152
153fn dispatch_pack_b() -> unsafe fn(*const TB, *mut TB, usize, usize, usize, usize) {
154 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
155 {
156 if has_f32_compute() {
157 return packb_fn_simd;
158 }
159 }
160 packb_fn_ref
161}
162
163fn dispatch_get_mcnckc() -> (usize, usize, usize) {
164 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
165 {
166 if has_f32_compute() {
167 return get_mcnckc_simd();
168 }
169 }
170 get_cache_params()
171}
172
// Expands to the public packing API for these element types — the tests below
// use the generated `pack_a`/`pack_b` and `a_size_packed`/`b_size_packed`;
// presumably it also wires in the `dispatch_*` helpers above — TODO confirm
// against pire_base's macro definition.
pire_base::packing_api!(TA, TB);
174
#[cfg(test)]
mod tests {
    //! Tests for the f32 GEMM path: packing-layout invariants plus a numerical
    //! sweep of `pire_sgemm_fused` against a reference implementation across
    //! layouts, packed/unpacked operands, and blocking-boundary dimensions.
    use super::*;
    use pire_base::{get_cache_params, matrix_size};
    use pire_dev::{
        check_gemm_f32, generate_k_dims, generate_m_dims, generate_n_dims, layout_to_strides, random_matrix_uniform,
        ABLayout,
    };
    // Packing A for every (m, k) around the blocking boundaries must yield a
    // packed (non-strided) array; m == 1 is exempt — presumably the single-row
    // case legitimately stays strided (TODO confirm against packing_api!).
    #[test]
    fn test_pack_a() {
        let a_stride_scale = 1;
        let (mc, _, kc) = get_mcnckc();
        // Register-tile sizes (mr, nr, kr); hard-coded to match the kernels.
        let (mr, _, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let k_dims = generate_k_dims(kc, kr);

        for &m in &m_dims {
            for &k in &k_dims {
                // Column-major A: unit row stride, column stride m.
                let a_rs = 1 * a_stride_scale;
                let a_cs = m * a_stride_scale;
                let a_size = a_size_packed(m, k);
                let a = vec![0.0; m * k * a_stride_scale];
                // Over-allocate so the destination can be aligned to AB_ALIGN.
                let mut ap = vec![0.0; a_size + AB_ALIGN];
                let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
                let ap_array = pack_a(m, k, &a, a_rs, a_cs, &mut ap[ap_align_offset..]);
                assert!(!ap_array.is_strided() || m == 1);
            }
        }
    }

    // Same invariant as test_pack_a, for the B operand (n == 1 exempt).
    #[test]
    fn test_pack_b() {
        let b_stride_scale = 1;
        let (_, nc, kc) = get_mcnckc();
        let (_, nr, kr) = (48, 8, 8);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);

        for &n in &n_dims {
            for &k in &k_dims {
                // Column-major k x n B: unit row stride, column stride k.
                let b_rs = 1 * b_stride_scale;
                let b_cs = k * b_stride_scale;
                let b_size = b_size_packed(n, k);
                let b = vec![0.0; n * k * b_stride_scale];
                // Over-allocate so the destination can be aligned to AB_ALIGN.
                let mut bp = vec![0.0; b_size + AB_ALIGN];
                let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
                let bp_array = pack_b(n, k, &b, b_rs, b_cs, &mut bp[bp_align_offset..]);
                assert!(!bp_array.is_strided() || n == 1);
            }
        }
    }

    // Cache-blocking parameters (mc, nc, kc) used to size the test sweeps.
    // NOTE(review): only the x86_64 SIMD config is consulted here; on aarch64
    // the tests use generic cache params even though the kernels dispatch via
    // arm64::get_mcnckc_simd — confirm this is intentional.
    #[allow(unreachable_code)]
    pub(crate) fn get_mcnckc() -> (usize, usize, usize) {
        #[cfg(target_arch = "x86_64")]
        {
            return x86_64_arch::get_mcnckc_simd();
        }
        get_cache_params()
    }

    // Fused epilogue used in the GEMM tests: doubles each of the m elements.
    unsafe fn unary_fn_test(c: *mut TC, m: usize) {
        for i in 0..m {
            *c.add(i) *= 2.0;
        }
    }

    // Maximum allowed elementwise deviation from the reference result.
    const EPS: f64 = 2e-2;

    // alpha/beta values swept in test_gemm (single value each for now).
    static ALPHA_ARR: [f32; 1] = [2.0];
    static BETA_ARR: [f32; 1] = [3.1415];

    // Sweeps (m, n, k) around the blocking boundaries for the given A/B layout,
    // optionally pre-packing A and/or B, and checks the fused GEMM against the
    // reference within EPS for every (alpha, beta) combination.
    fn test_gemm(layout: &ABLayout, is_a_packed: bool, is_b_packed: bool) {
        let a_stride_scale = 1;
        let b_stride_scale = 1;
        // C gets stride scale 2 to exercise a non-contiguous destination.
        let c_stride_scale = 2;
        let (mc, nc, kc) = get_mcnckc();
        let (mr, nr, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);
        let unary_fn: unsafe fn(*mut TC, usize) = unary_fn_test;
        // Allocate all buffers once at the largest dimensions and reuse them.
        let m_max = *m_dims.iter().max().unwrap();
        let n_max = *n_dims.iter().max().unwrap();
        let k_max = *k_dims.iter().max().unwrap();
        let a_size = matrix_size(m_max, k_max) * a_stride_scale;
        let b_size = matrix_size(k_max, n_max) * b_stride_scale;
        let c_size = matrix_size(m_max, n_max) * c_stride_scale;
        let mut a = vec![0f32; a_size];
        let mut b = vec![0f32; b_size];
        random_matrix_uniform(&mut a);
        random_matrix_uniform(&mut b);
        let mut c = vec![0f32; c_size];
        let mut c_ref = vec![0f32; c_size];

        // Aligned scratch for packed A (empty when not packing).
        let ap_size = if is_a_packed { a_size_packed(m_max, k_max) } else { 0 };
        let mut ap = vec![0f32; ap_size + AB_ALIGN];
        let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
        let ap_mut_ref = &mut ap[ap_align_offset..];

        // Aligned scratch for packed B (empty when not packing).
        let bp_size = if is_b_packed { b_size_packed(n_max, k_max) } else { 0 };
        let mut bp = vec![0f32; bp_size + AB_ALIGN];
        let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
        let bp_mut_ref = &mut bp[bp_align_offset..];
        for &m in &m_dims {
            for &n in &n_dims {
                for &k in &k_dims {
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = layout_to_strides(&layout, m, n, k);
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = (
                        a_rs * a_stride_scale,
                        a_cs * a_stride_scale,
                        b_rs * b_stride_scale,
                        b_cs * b_stride_scale,
                        c_rs * c_stride_scale,
                        c_cs * c_stride_scale,
                    );
                    let a_matrix = if is_a_packed {
                        pack_a(m, k, &a, a_rs, a_cs, ap_mut_ref)
                    } else {
                        Array::strided_matrix(a.as_ptr(), a_rs, a_cs)
                    };
                    let b_matrix = if is_b_packed {
                        pack_b(n, k, &b, b_rs, b_cs, bp_mut_ref)
                    } else {
                        Array::strided_matrix(b.as_ptr(), b_rs, b_cs)
                    };
                    for alpha in ALPHA_ARR {
                        for beta in BETA_ARR {
                            // Fresh random C, mirrored into c_ref so both the
                            // tested and reference paths start identically.
                            random_matrix_uniform(&mut c);
                            c_ref.copy_from_slice(&c);
                            let c_matrix = ArrayMut::strided_matrix(c.as_mut_ptr(), c_rs, c_cs);
                            unsafe {
                                pire_sgemm_fused(m, n, k, alpha, a_matrix, b_matrix, beta, c_matrix, unary_fn);
                            }
                            // check_gemm_f32 computes the reference result into
                            // c_ref and returns the max elementwise deviation.
                            let diff_max = unsafe {
                                check_gemm_f32(
                                    m,
                                    n,
                                    k,
                                    alpha,
                                    a.as_ptr(),
                                    a_rs,
                                    a_cs,
                                    b.as_ptr(),
                                    b_rs,
                                    b_cs,
                                    beta,
                                    &mut c,
                                    c_rs,
                                    c_cs,
                                    &mut c_ref,
                                    unary_fn,
                                    EPS,
                                )
                            };
                            assert!(
                                diff_max < EPS,
                                "diff_max: {}, m: {}, n: {}, k: {}, alpha: {}, beta: {}",
                                diff_max,
                                m,
                                n,
                                k,
                                alpha,
                                beta
                            );
                        }
                    }
                }
            }
        }
    }
    // Layout x packing matrix: {NN, NT, TN, TT} x {unpacked, A packed, B packed, both}.
    #[test]
    fn test_nn_col() {
        test_gemm(&ABLayout::NN, false, false);
    }

    #[test]
    fn test_nt_col() {
        test_gemm(&ABLayout::NT, false, false);
    }

    #[test]
    fn test_tn_col() {
        test_gemm(&ABLayout::TN, false, false);
    }

    #[test]
    fn test_tt_col() {
        test_gemm(&ABLayout::TT, false, false);
    }
    #[test]
    fn test_nn_col_ap() {
        test_gemm(&ABLayout::NN, true, false);
    }
    #[test]
    fn test_nt_col_ap() {
        test_gemm(&ABLayout::NT, true, false);
    }
    #[test]
    fn test_tn_col_ap() {
        test_gemm(&ABLayout::TN, true, false);
    }
    #[test]
    fn test_tt_col_ap() {
        test_gemm(&ABLayout::TT, true, false);
    }
    #[test]
    fn test_nn_col_bp() {
        test_gemm(&ABLayout::NN, false, true);
    }
    #[test]
    fn test_nt_col_bp() {
        test_gemm(&ABLayout::NT, false, true);
    }
    #[test]
    fn test_tn_col_bp() {
        test_gemm(&ABLayout::TN, false, true);
    }
    #[test]
    fn test_tt_col_bp() {
        test_gemm(&ABLayout::TT, false, true);
    }

    #[test]
    fn test_nn_col_apbp() {
        test_gemm(&ABLayout::NN, true, true);
    }
    #[test]
    fn test_nt_col_apbp() {
        test_gemm(&ABLayout::NT, true, true);
    }
    #[test]
    fn test_tn_col_apbp() {
        test_gemm(&ABLayout::TN, true, true);
    }
    #[test]
    fn test_tt_col_apbp() {
        test_gemm(&ABLayout::TT, true, true);
    }
}