1#[cfg(target_arch = "aarch64")]
2pub(crate) mod arm64;
3#[cfg(target_arch = "x86_64")]
4pub(crate) mod x86_64_arch;
5#[cfg(target_arch = "x86")]
6pub(crate) mod x86_arch;
7
8#[cfg(target_arch = "x86_64")]
9use x86_64_arch::{
10 get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher,
11};
12
13#[cfg(target_arch = "x86")]
14use x86_arch::{
15 get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher,
16};
17
18#[cfg(target_arch = "aarch64")]
19use arm64::{get_mcnckc_simd, packa_fn_simd, packb_fn_simd, pire_gemm, round_k_simd, round_m_simd, KernelDispatcher};
20
21pub(crate) mod reference;
22
23use core::mem::size_of;
24use num_complex::Complex;
25
// Element types for the c32 (single-precision complex) GEMM: the A and B
// operands and the C output are all Complex<f32>.
pub(crate) type TA = Complex<f32>;
pub(crate) type TB = Complex<f32>;
pub(crate) type TC = Complex<f32>;
// Byte size of one C element (8 for Complex<f32>). Not referenced on every
// target (hence the allow) — presumably consumed by some arch backends.
#[allow(unused)]
const TC_SIZE: usize = size_of::<TC>();
31
32use reference::{packa_fn_ref, packb_fn_ref, round_k_ref, round_m_ref, RefGemm};
33
34use pire_base::{
35 get_cache_params, has_c32_compute, Array, ArrayMut, GemmCache, IdentityFn, PirePar, UnaryFn, AB_ALIGN,
36};
37
/// Shorthand bound for unary epilogue functions over the output type `TC`.
/// The blanket impl below makes every `UnaryFn<TC>` automatically satisfy
/// `UnaryFnC`, so generic signatures can spell the shorter bound.
pub trait UnaryFnC: UnaryFn<TC> {}
impl<F: UnaryFn<TC>> UnaryFnC for F {}
40
/// Fused complex-f32 GEMM entry point shared by all public wrappers:
/// computes C = alpha * A * B + beta * C and applies the unary epilogue `f`
/// to the output, dispatching to a SIMD backend when possible.
///
/// Dispatch: when an arch-specific module was compiled in (x86_64 / x86 /
/// aarch64) *and* `has_c32_compute()` reports support at runtime, the SIMD
/// `pire_gemm` runs and we return early. On any other target the inner
/// `#[cfg]` block compiles away, the `if` body is empty, and control falls
/// through to the portable reference implementation.
///
/// # Safety
/// The `Array`/`ArrayMut` views must describe valid m x k (A), k x n (B) and
/// m x n (C) matrices; they are forwarded to unchecked kernels.
pub(crate) unsafe fn pire_cgemm_fused<F: UnaryFnC>(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: Array<TA>,
    b: Array<TB>,
    beta: TC,
    c: ArrayMut<TC>,
    f: F,
) {
    // Parallelism/partitioning configuration derived from the problem shape
    // (PirePar from pire_base).
    let par = PirePar::default(m, n);
    if has_c32_compute() {
        #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
        {
            // `f` is moved into the dispatcher here; the early return keeps
            // the fallback below reachable only on paths where `f` was not
            // consumed, which is what satisfies the borrow checker.
            let hw_config = KernelDispatcher::new(f);
            pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
            return;
        }
    }
    // Portable fallback: reference kernels with the same fused epilogue.
    let hw_config = RefGemm::new(f);
    reference::pire_gemm(&hw_config, m, n, k, alpha, a, b, beta, c, &par);
}
65
66pub unsafe fn pire_cgemm(
67 m: usize,
68 n: usize,
69 k: usize,
70 alpha: TA,
71 a: *const TA,
72 a_rs: usize,
73 a_cs: usize,
74 b: *const TB,
75 b_rs: usize,
76 b_cs: usize,
77 beta: TC,
78 c: *mut TC,
79 c_rs: usize,
80 c_cs: usize,
81) {
82 let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_cs == 1 && c_rs != 1 {
84 (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
85 } else {
86 (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
87 };
88 let a = Array::strided_matrix(a, a_rs, a_cs);
89 let b = Array::strided_matrix(b, b_rs, b_cs);
90 let c = ArrayMut::strided_matrix(c, c_rs, c_cs);
91 let identity_fn = IdentityFn {};
92 pire_cgemm_fused(m, n, k, alpha, a, b, beta, c, identity_fn);
93}
94
/// Fused complex-f32 GEMM over raw strided buffers, taking the unary
/// epilogue as a plain function pointer (`feature = "fuse"` only).
///
/// Behaves like `pire_cgemm` but applies `unary` to the output instead of
/// the identity epilogue.
///
/// # Safety
/// Same contract as `pire_cgemm`; additionally `unary` must be safe to call
/// on each contiguous run of output elements handed to it.
#[cfg(feature = "fuse")]
pub unsafe fn pire_cgemm_fn_ptr(
    m: usize,
    n: usize,
    k: usize,
    alpha: TA,
    a: *const TA,
    a_rs: usize,
    a_cs: usize,
    b: *const TB,
    b_rs: usize,
    b_cs: usize,
    beta: TC,
    c: *mut TC,
    c_rs: usize,
    c_cs: usize,
    unary: unsafe fn(*mut TC, usize),
) {
    // Same normalization as `pire_cgemm`: a row-major C is rewritten as the
    // transposed problem C^T = B^T * A^T so the kernel sees column-major C.
    let c_is_row_major = c_cs == 1 && c_rs != 1;
    let (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b) = if c_is_row_major {
        (n, m, b_rs, b_cs, a_rs, a_cs, c_cs, c_rs, b, a)
    } else {
        (m, n, a_rs, a_cs, b_rs, b_cs, c_rs, c_cs, a, b)
    };
    let a_view = Array::strided_matrix(a, a_rs, a_cs);
    let b_view = Array::strided_matrix(b, b_rs, b_cs);
    let c_view = ArrayMut::strided_matrix(c, c_rs, c_cs);
    pire_cgemm_fused(m, n, k, alpha, a_view, b_view, beta, c_view, unary);
}
124
125fn dispatch_round_m() -> fn(usize) -> usize {
126 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
127 {
128 if has_c32_compute() {
129 return round_m_simd;
130 }
131 }
132 round_m_ref
133}
134fn dispatch_round_k() -> fn(usize) -> usize {
135 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
136 {
137 if has_c32_compute() {
138 return round_k_simd;
139 }
140 }
141 round_k_ref
142}
143
144fn dispatch_pack_a() -> unsafe fn(*const TA, *mut TA, usize, usize, usize, usize) {
145 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
146 {
147 if has_c32_compute() {
148 return packa_fn_simd;
149 }
150 }
151 packa_fn_ref
152}
153
154fn dispatch_pack_b() -> unsafe fn(*const TB, *mut TB, usize, usize, usize, usize) {
155 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
156 {
157 if has_c32_compute() {
158 return packb_fn_simd;
159 }
160 }
161 packb_fn_ref
162}
163
164fn dispatch_get_mcnckc() -> (usize, usize, usize) {
165 #[cfg(any(target_arch = "x86_64", target_arch = "x86", target_arch = "aarch64"))]
166 {
167 if has_c32_compute() {
168 return get_mcnckc_simd();
169 }
170 }
171 get_cache_params()
172}
173
// Generates the packing API over TA/TB from the dispatch_* helpers above —
// presumably `pack_a` / `pack_b` / `a_size_packed` / `b_size_packed`, which
// the tests below call; the expansion lives in pire_base (confirm there).
pire_base::packing_api!(TA, TB);
175
176#[cfg(test)]
177mod tests {
178 use super::*;
179 use aligned_vec::avec;
180 use pire_base::{get_cache_params, matrix_size};
181 use pire_dev::{
182 check_gemm_c32, generate_k_dims, generate_m_dims, generate_n_dims, layout_to_strides, random_matrix_uniform,
183 ABLayout,
184 };
185
186 #[test]
187 fn test_pack_a() {
188 let a_stride_scale = 1;
189 let (mc, _, kc) = get_mcnckc();
190 let (mr, _, kr) = (48, 8, 8);
191 let m_dims = generate_m_dims(mc, mr);
192 let k_dims = generate_k_dims(kc, kr);
193
194 for &m in &m_dims {
195 for &k in &k_dims {
196 let a_rs = 1 * a_stride_scale;
197 let a_cs = m * a_stride_scale;
198 let a_size = a_size_packed(m, k);
199 let a = vec![TA::ZERO; m * k * a_stride_scale * size_of::<TA>()];
200 let mut ap = avec![[AB_ALIGN]| TA::ZERO; a_size];
201 let ap_array = pack_a(m, k, &a, a_rs, a_cs, &mut ap);
202 assert!(!ap_array.is_strided() || m == 1);
203 }
204 }
205 }
206
207 #[test]
208 fn test_pack_b() {
209 let b_stride_scale = 1;
210 let (_, nc, kc) = get_mcnckc();
211 let (_, nr, kr) = (48, 8, 8);
212 let n_dims = generate_n_dims(nc, nr);
213 let k_dims = generate_k_dims(kc, kr);
214
215 for &n in &n_dims {
216 for &k in &k_dims {
217 let b_rs = 1 * b_stride_scale;
218 let b_cs = k * b_stride_scale;
219 let b_size = b_size_packed(n, k);
220 let b = vec![TB::ZERO; b_size];
221 let mut bp = avec!([AB_ALIGN]| TB::ZERO; b_size);
222 let bp_array = pack_b(n, k, &b, b_rs, b_cs, &mut bp);
223 assert!(!bp_array.is_strided() || n == 1);
224 }
225 }
226 }
227
228 #[allow(unreachable_code)]
229 pub(crate) fn get_mcnckc() -> (usize, usize, usize) {
230 #[cfg(target_arch = "x86_64")]
231 {
232 return x86_64_arch::get_mcnckc_simd();
233 }
234 get_cache_params()
235 }
236
237 unsafe fn unary_fn_test(c: *mut TC, m: usize) {
238 for i in 0..m {
239 *c.add(i) *= 2.0;
240 }
241 }
242
    // Tolerance on the max elementwise deviation reported by the checker.
    const EPS: f64 = 2e-2;

    // Single alpha/beta pair; both have nonzero imaginary parts so the
    // kernels' real/imaginary mixing is actually exercised.
    static ALPHA_ARR: [TA; 1] = [Complex { re: 1.0, im: 0.79 }];
    static BETA_ARR: [TC; 1] = [Complex { re: 1.0, im: 1.7 }];
247
    /// Runs the full GEMM sweep for one A/B layout and one packing choice:
    /// for every (m, n, k) from the generated dimension lists and every
    /// (alpha, beta) pair, executes `pire_cgemm_fused` and compares against
    /// the reference checker `check_gemm_c32`.
    ///
    /// `is_a_packed` / `is_b_packed` select whether A / B are pre-packed via
    /// `pack_a` / `pack_b` or passed as plain strided views.
    fn test_gemm(layout: &ABLayout, is_a_packed: bool, is_b_packed: bool) {
        // Stride multipliers; C uses 2 so the output is genuinely strided.
        let a_stride_scale = 1;
        let b_stride_scale = 1;
        let c_stride_scale = 2;
        let (mc, nc, kc) = get_mcnckc();
        // (mr, nr, kr) — presumably the kernel register-block sizes used to
        // generate boundary dimensions; NOTE(review): hard-coded here and in
        // the pack tests, confirm against the arch backends.
        let (mr, nr, kr) = (48, 8, 8);
        let m_dims = generate_m_dims(mc, mr);
        let n_dims = generate_n_dims(nc, nr);
        let k_dims = generate_k_dims(kc, kr);
        let unary_fn: unsafe fn(*mut TC, usize) = unary_fn_test;
        // Allocate all buffers once at the maximum dimensions and reuse them
        // across the entire sweep.
        let m_max = *m_dims.iter().max().unwrap();
        let n_max = *n_dims.iter().max().unwrap();
        let k_max = *k_dims.iter().max().unwrap();
        let a_size = matrix_size(m_max, k_max) * a_stride_scale;
        let b_size = matrix_size(k_max, n_max) * b_stride_scale;
        let c_size = matrix_size(m_max, n_max) * c_stride_scale;
        let mut a = vec![TA::ZERO; a_size];
        let mut b = vec![TB::ZERO; b_size];
        random_matrix_uniform(&mut a);
        random_matrix_uniform(&mut b);
        let mut c = vec![TC::ZERO; c_size];
        let mut c_ref = vec![TC::ZERO; c_size];

        // Packed-operand scratch aligned to AB_ALIGN; zero-sized when the
        // corresponding operand is not packed.
        let ap_size = if is_a_packed { a_size_packed(m_max, k_max) } else { 0 };
        let mut ap = avec![[AB_ALIGN]| TA::ZERO; ap_size];

        let bp_size = if is_b_packed { b_size_packed(n_max, k_max) } else { 0 };
        let mut bp = avec![[AB_ALIGN]| TB::ZERO; bp_size];
        for &m in &m_dims {
            for &n in &n_dims {
                for &k in &k_dims {
                    // Base strides for this layout, then scaled by the
                    // per-operand stride multipliers above.
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = layout_to_strides(&layout, m, n, k);
                    let (a_rs, a_cs, b_rs, b_cs, c_rs, c_cs) = (
                        a_rs * a_stride_scale,
                        a_cs * a_stride_scale,
                        b_rs * b_stride_scale,
                        b_cs * b_stride_scale,
                        c_rs * c_stride_scale,
                        c_cs * c_stride_scale,
                    );
                    // Either pack the operand for this problem size or hand
                    // the kernel a strided view of the raw buffer.
                    let a_matrix = if is_a_packed {
                        pack_a(m, k, &a, a_rs, a_cs, &mut ap)
                    } else {
                        Array::strided_matrix(a.as_ptr(), a_rs, a_cs)
                    };
                    let b_matrix = if is_b_packed {
                        pack_b(n, k, &b, b_rs, b_cs, &mut bp)
                    } else {
                        Array::strided_matrix(b.as_ptr(), b_rs, b_cs)
                    };
                    for alpha in ALPHA_ARR {
                        for beta in BETA_ARR {
                            // Fresh random C each run; keep a copy so the
                            // checker can recompute the reference result.
                            random_matrix_uniform(&mut c);
                            c_ref.copy_from_slice(&c);
                            let c_matrix = ArrayMut::strided_matrix(c.as_mut_ptr(), c_rs, c_cs);
                            unsafe {
                                pire_cgemm_fused(m, n, k, alpha, a_matrix, b_matrix, beta, c_matrix, unary_fn);
                            }
                            // Max elementwise deviation between the fused
                            // result and the reference (same unary epilogue).
                            let diff_max = unsafe {
                                check_gemm_c32(
                                    m,
                                    n,
                                    k,
                                    alpha,
                                    a.as_ptr(),
                                    a_rs,
                                    a_cs,
                                    b.as_ptr(),
                                    b_rs,
                                    b_cs,
                                    beta,
                                    &mut c,
                                    c_rs,
                                    c_cs,
                                    &mut c_ref,
                                    unary_fn,
                                    EPS,
                                )
                            };
                            assert!(
                                diff_max < EPS,
                                "diff_max: {}, m: {}, n: {}, k: {}, alpha: {}, beta: {}",
                                diff_max,
                                m,
                                n,
                                k,
                                alpha,
                                beta
                            );
                        }
                    }
                }
            }
        }
    }
    // Layout x packing matrix: each test runs the full `test_gemm` sweep for
    // one A/B transpose layout (NN/NT/TN/TT) and one packing combination —
    // neither operand packed, A packed (`_ap`), B packed (`_bp`), or both
    // (`_apbp`).
    #[test]
    fn test_nn_col() {
        test_gemm(&ABLayout::NN, false, false);
    }

    #[test]
    fn test_nt_col() {
        test_gemm(&ABLayout::NT, false, false);
    }

    #[test]
    fn test_tn_col() {
        test_gemm(&ABLayout::TN, false, false);
    }

    #[test]
    fn test_tt_col() {
        test_gemm(&ABLayout::TT, false, false);
    }
    #[test]
    fn test_nn_col_ap() {
        test_gemm(&ABLayout::NN, true, false);
    }
    #[test]
    fn test_nt_col_ap() {
        test_gemm(&ABLayout::NT, true, false);
    }
    #[test]
    fn test_tn_col_ap() {
        test_gemm(&ABLayout::TN, true, false);
    }
    #[test]
    fn test_tt_col_ap() {
        test_gemm(&ABLayout::TT, true, false);
    }
    #[test]
    fn test_nn_col_bp() {
        test_gemm(&ABLayout::NN, false, true);
    }
    #[test]
    fn test_nt_col_bp() {
        test_gemm(&ABLayout::NT, false, true);
    }
    #[test]
    fn test_tn_col_bp() {
        test_gemm(&ABLayout::TN, false, true);
    }
    #[test]
    fn test_tt_col_bp() {
        test_gemm(&ABLayout::TT, false, true);
    }

    #[test]
    fn test_nn_col_apbp() {
        test_gemm(&ABLayout::NN, true, true);
    }
    #[test]
    fn test_nt_col_apbp() {
        test_gemm(&ABLayout::NT, true, true);
    }
    #[test]
    fn test_tn_col_apbp() {
        test_gemm(&ABLayout::TN, true, true);
    }
413 #[test]
414 fn test_tt_col_apbp() {
415 test_gemm(&ABLayout::TT, true, true);
416 }
417}