1#[cfg(target_arch = "aarch64")]
2pub(crate) mod arm64;
3#[cfg(target_arch = "x86_64")]
4pub(crate) mod x86_64_arch;
5
6#[cfg(target_arch = "x86_64")]
7use x86_64_arch::{
8 get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, pire_gemm_f32, round_k_simd, round_m_simd,
9 KernelDispatcher, KernelDispatcherF32,
10};
11
12use core::mem::size_of;
13
14#[cfg(target_arch = "aarch64")]
15use arm64::{get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher};
16pub(crate) mod reference;
17
/// Element type of the left-hand operand `A`.
pub(crate) type TA = f16;
/// Element type of the right-hand operand `B`.
pub(crate) type TB = f16;
/// Element type of the output matrix `C`.
pub(crate) type TC = f16;
// Byte size of one C element; presumably consumed by the `packing_api!`
// expansion at the bottom of the file — hence the `unused` allowance here.
#[allow(unused)]
const TC_SIZE: usize = size_of::<TC>();
23
24pub use half::f16;
25
26use pire_base::{
27 get_cache_params, has_f16_compute, has_f16f32_compute, Array, ArrayMut, GemmCache, IdentityFn, PirePar, UnaryFn,
28 AB_ALIGN,
29};
30use reference::{packa_fn_ref, packb_fn_ref, round_k_ref, round_m_ref, RefGemm};
31
/// Trait alias for the element-wise epilogue applied to C: any `UnaryFn<TC>`
/// qualifies via the blanket impl below, so callers can pass plain fn pointers
/// or `IdentityFn` interchangeably.
pub trait UnaryFnC: UnaryFn<TC> {}
impl<F: UnaryFn<TC>> UnaryFnC for F {}
34
/// Core fused half-precision GEMM: `C = alpha * A * B + beta * C`, with the
/// unary epilogue `f` applied by the selected backend.
///
/// Backend selection is a runtime/compile-time cascade:
/// 1. native f16 SIMD compute (`KernelDispatcher`) when the CPU reports it;
/// 2. f16 storage with f32 accumulation (`KernelDispatcherF32`) — this path
///    only exists on x86_64; on other targets the inner block is compiled out
///    and control falls through;
/// 3. the portable scalar reference, which computes in f32 (note the
///    `to_f32()` conversions of alpha/beta).
///
/// # Safety
/// The `Array`/`ArrayMut` handles must describe valid allocations for the
/// given `m`, `n`, `k`; `c` must be valid for writes.
pub(crate) unsafe fn pire_hgemm_fused<F: UnaryFnC>(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: Array<TA>,
    b: Array<TB>,
    beta: TC,
    c: ArrayMut<TC>,
    f: F,
) {
    let par = PirePar::default(m, n);
    if has_f16_compute() {
        #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
        {
            let hw_config = KernelDispatcher::new(f);
            pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
            return;
        }
    }
    if has_f16f32_compute() {
        // NOTE(review): on aarch64 this branch is a no-op (no f32-accumulate
        // dispatcher is imported there), so an aarch64 CPU with only
        // f16f32 compute lands on the scalar reference below — confirm intended.
        #[cfg(target_arch = "x86_64")]
        {
            let hw_config = KernelDispatcherF32::new(f);
            pire_gemm_f32(&hw_config, m, n, k, alpha.to_f32(), a, b, beta.to_f32(), c, &par);
            return;
        }
    }

    // Portable scalar fallback; accumulates in f32 and stores back to f16.
    let hw_config = RefGemm::new(f);
    reference::pire_gemm(&hw_config, m, n, k, alpha.to_f32(), a, b, beta.to_f32(), c, &par);
}
68
69pub unsafe fn pire_hgemm(
70 m: usize,
71 n: usize,
72 k: usize,
73 alpha: f16,
74 a: *const f16,
75 a_rs: usize,
76 a_cs: usize,
77 b: *const f16,
78 b_rs: usize,
79 b_cs: usize,
80 beta: f16,
81 c: *mut f16,
82 c_rs: usize,
83 c_cs: usize,
84) {
85 let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
87 (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
88 } else {
89 (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
90 };
91 let a = Array::strided_matrix(a, a_rs, a_cs);
92 let b = Array::strided_matrix(b, b_rs, b_cs);
93 let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
94 let identity_fn = IdentityFn {};
95 pire_hgemm_fused(m, n, k, alpha, a, b, beta, c, identity_fn);
96}
97
/// Feature-gated variant of [`pire_hgemm`] that fuses a caller-supplied
/// unary epilogue `unary` (applied as `unary(ptr, len)`) into the GEMM.
///
/// # Safety
/// Same pointer/stride requirements as [`pire_hgemm`]; additionally, `unary`
/// must be safe to call on any in-bounds contiguous run of C elements.
#[cfg(feature = "fuse")]
pub unsafe fn pire_hgemm_fn_ptr(
    m: usize,
    n: usize,
    k: usize,
    alpha: f16,
    a: *const f16,
    a_rs: usize,
    a_cs: usize,
    b: *const f16,
    b_rs: usize,
    b_cs: usize,
    beta: f16,
    c: *mut f16,
    c_rs: usize,
    c_cs: usize,
    unary: unsafe fn(*mut f16, usize),
) {
    // Same row-major-C handling as `pire_hgemm`: compute C^T = B^T * A^T so
    // the output is always column-major from the kernels' point of view.
    let c_is_row_major = c_cs == 1 && c_rs != 1;
    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_is_row_major {
        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
    } else {
        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
    };
    let lhs = Array::strided_matrix(a, a_rs, a_cs);
    let rhs = Array::strided_matrix(b, b_rs, b_cs);
    let dst = ArrayMut::strided_matrix(c, c_rs, c_cs);
    pire_hgemm_fused(m, n, k, alpha, lhs, rhs, beta, dst, unary);
}
127
128fn dispatch_round_m() -> fn(usize) -> usize {
129 #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
130 {
131 if has_f16_compute() || has_f16f32_compute() {
132 return round_m_simd;
133 }
134 }
135 round_m_ref
136}
137fn dispatch_round_k() -> fn(usize) -> usize {
138 #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
139 {
140 if has_f16_compute() || has_f16f32_compute() {
141 return round_k_simd;
142 }
143 }
144 round_k_ref
145}
146
147fn dispatch_pack_a() -> unsafe fn(*const TA, *mut TA, usize, usize, usize, usize) {
148 #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
149 {
150 if has_f16_compute() || has_f16f32_compute() {
151 return packa_fn_simd;
152 }
153 }
154 packa_fn_ref
155}
156
157fn dispatch_pack_b() -> unsafe fn(*const TB, *mut TB, usize, usize, usize, usize) {
158 #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
159 {
160 if has_f16_compute() || has_f16f32_compute() {
161 return packb_fn_simd;
162 }
163 }
164 packb_fn_ref
165}
166
167fn dispatch_get_mcnckc() -> (usize, usize, usize) {
168 #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
169 {
170 if has_f16_compute() || has_f16f32_compute() {
171 return get_mcnckc_simd();
172 }
173 }
174 get_cache_params()
175}
176
177pire_base::packing_api!(TA, TB);
178
#[cfg(test)]
mod tests {
    use super::*;
    use pire_base::{get_cache_params, matrix_size};
    use pire_dev::{
        check_gemm_f16, generate_k_dims, generate_m_dims, generate_n_dims, layout_to_strides, random_matrix_uniform,
        ABLayout,
    };
    // Verifies that packing A into an AB_ALIGN-aligned buffer produces a
    // packed (non-strided) array for every (m, k) around the blocking sizes.
    #[test]
    fn test_pack_a() {
        let a_stride_scale = 1;
        let (mc, _, kc) = get_mcnckc();
        // Register-block sizes (mr, nr, kr); nr is unused when packing A.
        let (mr, _, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let k_dims = generate_k_dims(kc, kr);

        for &m in &m_dims {
            for &k in &k_dims {
                // Column-major A: unit row stride, column stride = m.
                let a_rs = 1 * a_stride_scale;
                let a_cs = m * a_stride_scale;
                let a_size = a_size_packed(m, k);
                let a = vec![TA::ZERO; m * k * a_stride_scale];
                // Over-allocate by AB_ALIGN so the packed buffer can be aligned.
                let mut ap = vec![TA::ZERO; a_size + AB_ALIGN];
                let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
                let ap_array = pack_a(m, k, &a, a_rs, a_cs, &mut ap[ap_align_offset..]);
                // m == 1 is allowed to remain strided (nothing to reorder).
                assert!(!ap_array.is_strided() || m == 1);
            }
        }
    }

    // Same as test_pack_a but for the B operand (row-major-style strides).
    #[test]
    fn test_pack_b() {
        let b_stride_scale = 1;
        let (_, nc, kc) = get_mcnckc();
        let (_, nr, kr) = (48, 8, 8);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);

        for &n in &n_dims {
            for &k in &k_dims {
                // B with unit row stride and column stride = k.
                let b_rs = 1 * b_stride_scale;
                let b_cs = k * b_stride_scale;
                let b_size = b_size_packed(n, k);
                let b = vec![TB::ZERO; n * k * b_stride_scale];
                let mut bp = vec![TB::ZERO; b_size + AB_ALIGN];
                let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
                let bp_array = pack_b(n, k, &b, b_rs, b_cs, &mut bp[bp_align_offset..]);
                assert!(!bp_array.is_strided() || n == 1);
            }
        }
    }

    // Test-local blocking parameters. Unlike `dispatch_get_mcnckc`, this
    // unconditionally takes the SIMD values on x86_64 (no runtime feature
    // check) and has no aarch64 branch — NOTE(review): confirm intended.
    #[allow(unreachable_code)]
    pub(crate) fn get_mcnckc() -> (usize, usize, usize) {
        #[cfg(target_arch = "x86_64")]
        {
            return x86_64_arch::get_mcnckc_simd();
        }
        get_cache_params()
    }

    // Epilogue used by test_gemm: doubles the first `m` elements at `c`.
    unsafe fn unary_fn_test(c: *mut TC, m: usize) {
        for i in 0..m {
            *c.add(i) *= f16::from_f32(2.0);
        }
    }

    // NOTE(review): 10e-1 == 1.0 — a very loose tolerance; if 0.1 was meant,
    // this should read 1e-1. Kept as-is; f16 accumulation error can be large.
    const EPS: f64 = 10e-1;

    // Single non-trivial alpha/beta pair exercising the scaling paths.
    static ALPHA_ARR: [f16; 1] = [f16::from_f32_const(1.23)];
    static BETA_ARR: [f16; 1] = [f16::from_f32_const(1.17)];

    // Full GEMM check against the pire_dev reference for one A/B layout,
    // optionally pre-packing A and/or B, over a grid of (m, n, k) sizes.
    fn test_gemm(layout: &ABLayout, is_a_packed: bool, is_b_packed: bool) {
        let a_stride_scale = 1;
        let b_stride_scale = 1;
        // Non-unit scale on C strides exercises the strided-output path.
        let c_stride_scale = 2;
        let (mc, nc, kc) = get_mcnckc();
        let (mr, nr, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);
        let unary_fn: unsafe fn(*mut TC, usize) = unary_fn_test;
        // Size all buffers once for the largest dimensions, reuse across sizes.
        let m_max = *m_dims.iter().max().unwrap();
        let n_max = *n_dims.iter().max().unwrap();
        let k_max = *k_dims.iter().max().unwrap();
        let a_size = matrix_size(m_max, k_max) * a_stride_scale;
        let b_size = matrix_size(k_max, n_max) * b_stride_scale;
        let c_size = matrix_size(m_max, n_max) * c_stride_scale;
        let mut a = vec![TA::ZERO; a_size];
        let mut b = vec![TB::ZERO; b_size];
        random_matrix_uniform(&mut a);
        random_matrix_uniform(&mut b);
        let mut c = vec![TC::ZERO; c_size];
        let mut c_ref = vec![TC::ZERO; c_size];

        // Aligned scratch for packed A (empty when packing is disabled).
        let ap_size = if is_a_packed { a_size_packed(m_max, k_max) } else { 0 };
        let mut ap = vec![TA::ZERO; ap_size + AB_ALIGN];
        let ap_align_offset = ap.as_ptr().align_offset(AB_ALIGN);
        let ap_mut_ref = &mut ap[ap_align_offset..];

        // Aligned scratch for packed B.
        let bp_size = if is_b_packed { b_size_packed(n_max, k_max) } else { 0 };
        let mut bp = vec![TB::ZERO; bp_size + AB_ALIGN];
        let bp_align_offset = bp.as_ptr().align_offset(AB_ALIGN);
        let bp_mut_ref = &mut bp[bp_align_offset..];
        for &m in &m_dims {
            for &n in &n_dims {
                for &k in &k_dims {
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = layout_to_strides(&layout, m, n, k);
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = (
                        a_rs * a_stride_scale,
                        a_cs * a_stride_scale,
                        b_rs * b_stride_scale,
                        b_cs * b_stride_scale,
                        c_rs * c_stride_scale,
                        c_cs * c_stride_scale,
                    );
                    let a_matrix = if is_a_packed {
                        pack_a(m, k, &a, a_rs, a_cs, ap_mut_ref)
                    } else {
                        Array::strided_matrix(a.as_ptr(), a_rs, a_cs)
                    };
                    let b_matrix = if is_b_packed {
                        pack_b(n, k, &b, b_rs, b_cs, bp_mut_ref)
                    } else {
                        Array::strided_matrix(b.as_ptr(), b_rs, b_cs)
                    };
                    for alpha in ALPHA_ARR {
                        for beta in BETA_ARR {
                            // Fresh random C; keep a copy for the reference run.
                            random_matrix_uniform(&mut c);
                            c_ref.copy_from_slice(&c);
                            let c_matrix = ArrayMut::strided_matrix(c.as_mut_ptr(), c_rs, c_cs);
                            unsafe {
                                pire_hgemm_fused(m, n, k, alpha, a_matrix, b_matrix, beta, c_matrix, unary_fn);
                            }
                            // Compares `c` (computed above) against a reference
                            // GEMM into `c_ref`; returns the max element diff.
                            let diff_max = unsafe {
                                check_gemm_f16(
                                    m,
                                    n,
                                    k,
                                    alpha,
                                    a.as_ptr(),
                                    a_rs,
                                    a_cs,
                                    b.as_ptr(),
                                    b_rs,
                                    b_cs,
                                    beta,
                                    &mut c,
                                    c_rs,
                                    c_cs,
                                    &mut c_ref,
                                    unary_fn,
                                    EPS,
                                )
                            };
                            assert!(
                                diff_max < EPS,
                                "diff_max: {}, m: {}, n: {}, k: {}, alpha: {}, beta: {}",
                                diff_max,
                                m,
                                n,
                                k,
                                alpha,
                                beta
                            );
                        }
                    }
                }
            }
        }
    }
    // One wrapper per (layout, packing) combination: N/T on A and B, with
    // A packed (ap), B packed (bp), neither, or both (apbp).
    #[test]
    fn test_nn_col_ap() {
        test_gemm(&ABLayout::NN, true, false);
    }
    #[test]
    fn test_nt_col_ap() {
        test_gemm(&ABLayout::NT, true, false);
    }
    #[test]
    fn test_tn_col_ap() {
        test_gemm(&ABLayout::TN, true, false);
    }
    #[test]
    fn test_tt_col_ap() {
        test_gemm(&ABLayout::TT, true, false);
    }
    #[test]
    fn test_nn_col_bp() {
        test_gemm(&ABLayout::NN, false, true);
    }
    #[test]
    fn test_nt_col_bp() {
        test_gemm(&ABLayout::NT, false, true);
    }
    #[test]
    fn test_tn_col_bp() {
        test_gemm(&ABLayout::TN, false, true);
    }
    #[test]
    fn test_tt_col_bp() {
        test_gemm(&ABLayout::TT, false, true);
    }
    #[test]
    fn test_nn_col() {
        test_gemm(&ABLayout::NN, false, false);
    }
    #[test]
    fn test_nt_col() {
        test_gemm(&ABLayout::NT, false, false);
    }
    #[test]
    fn test_tn_col() {
        test_gemm(&ABLayout::TN, false, false);
    }
    #[test]
    fn test_tt_col() {
        test_gemm(&ABLayout::TT, false, false);
    }
    #[test]
    fn test_nn_col_apbp() {
        test_gemm(&ABLayout::NN, true, true);
    }
    #[test]
    fn test_nt_col_apbp() {
        test_gemm(&ABLayout::NT, true, true);
    }
    #[test]
    fn test_tn_col_apbp() {
        test_gemm(&ABLayout::TN, true, true);
    }
    #[test]
    fn test_tt_col_apbp() {
        test_gemm(&ABLayout::TT, true, true);
    }
}