#![allow(unsafe_op_in_unsafe_fn)]

use std::cell::RefCell;
use std::time::Instant;

use crate::error::TruenoError;

#[cfg(target_arch = "x86_64")]
use super::microkernels::microkernel_16x8_avx512;
#[cfg(target_arch = "x86_64")]
use super::microkernels::microkernel_8x6_true_asm;
use super::microkernels::microkernel_scalar;
use super::packing::{pack_a_block, pack_b_block, packed_a_size, packed_b_size};
#[cfg(target_arch = "x86_64")]
use super::packing::{pack_a_block_512, pack_b_block_512, packed_a_size_512, packed_b_size_512};
use super::prepacked::PrepackedB;
use super::profiler::{BlisProfileLevel, BlisProfiler};
use super::reference::gemm_reference;
use super::{KC, MC, MR, NC, NR};
#[cfg(target_arch = "x86_64")]
use super::{KC_512, MC_512, MR_512, NC_512, NR_512};
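
// Thread-local scratch buffers for packed A/B panels and the C micro-tile,
// grown on demand and reused across calls to avoid per-call allocation.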
thread_local! {
    static TL_PACKED_A: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
    static TL_PACKED_B: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
    static TL_C_MICRO: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
}
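
/// Copies an `mr` x `nr` tile of row-major C into the column-major
/// `MR` x `NR` micro-tile buffer, zero-padding rows and columns past
/// `mr`/`nr` so full-size micro-kernels can run on edge tiles.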
#[inline(always)]
fn load_c_tile(
    c: &[f32],
    c_micro: &mut [f32],
    row: usize,
    col: usize,
    mr: usize,
    nr: usize,
    n: usize,
) {
    for jj in 0..nr {
        for ii in 0..mr {
            c_micro[jj * MR + ii] = c[(row + ii) * n + (col + jj)];
        }
        for ii in mr..MR {
            c_micro[jj * MR + ii] = 0.0;
        }
    }
    for jj in nr..NR {
        for ii in 0..MR {
            c_micro[jj * MR + ii] = 0.0;
        }
    }
}
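
/// Writes the valid `mr` x `nr` region of the micro-tile back into
/// row-major C; the zero padding is discarded.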
#[inline(always)]
fn store_c_tile(
    c: &mut [f32],
    c_micro: &[f32],
    row: usize,
    col: usize,
    mr: usize,
    nr: usize,
    n: usize,
) {
    for jj in 0..nr {
        for ii in 0..mr {
            c[(row + ii) * n + (col + jj)] = c_micro[jj * MR + ii];
        }
    }
}
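
/// Picks a micro-kernel for one `MR` x `NR` micro-tile: the AVX2+FMA
/// assembly kernel on full tiles when the CPU supports it, otherwise the
/// scalar kernel (which also covers partial edge tiles via zero padding).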
#[inline(always)]
fn dispatch_microkernel(
    kc: usize,
    a_panel: &[f32],
    b_panel: &[f32],
    c_micro: &mut [f32],
    mr_block: usize,
    nr_block: usize,
) {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma")
            && mr_block == MR
            && nr_block == NR
        {
            unsafe {
                microkernel_8x6_true_asm(
                    kc,
                    a_panel.as_ptr(),
                    b_panel.as_ptr(),
                    c_micro.as_mut_ptr(),
                    MR,
                );
            }
            return;
        }
    }
    microkernel_scalar(kc, a_panel, b_panel, c_micro, MR);
}
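
/// Macro-kernel: walks `MR` x `NR` micro-tiles over one packed
/// `mc_block` x `nc_block` block of C, optionally recording per-tile
/// (micro) and per-block (midi) timings in the profiler.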
#[allow(clippy::too_many_arguments)]
fn compute_macroblock(
    c: &mut [f32],
    packed_a: &[f32],
    packed_b: &[f32],
    c_micro: &mut [f32],
    ic: usize,
    jc: usize,
    mc_block: usize,
    nc_block: usize,
    kc_block: usize,
    n: usize,
    profiler: &mut Option<&mut BlisProfiler>,
) {
    let track_time = profiler.is_some();
    let midi_start = if track_time { Some(Instant::now()) } else { None };

    for ir in (0..mc_block).step_by(MR) {
        let mr_block = MR.min(mc_block - ir);
        for jr in (0..nc_block).step_by(NR) {
            let nr_block = NR.min(nc_block - jr);
            let micro_start = if track_time { Some(Instant::now()) } else { None };

            let a_panel = &packed_a[(ir / MR) * MR * kc_block..];
            let b_panel = &packed_b[(jr / NR) * NR * kc_block..];

            load_c_tile(c, c_micro, ic + ir, jc + jr, mr_block, nr_block, n);
            dispatch_microkernel(kc_block, a_panel, b_panel, c_micro, mr_block, nr_block);
            store_c_tile(c, c_micro, ic + ir, jc + jr, mr_block, nr_block, n);

            if let (Some(ref mut prof), Some(start)) = (profiler.as_deref_mut(), micro_start) {
                prof.record(
                    BlisProfileLevel::Micro,
                    start.elapsed().as_nanos() as u64,
                    (2 * mr_block * nr_block * kc_block) as u64,
                );
            }
        }
    }

    if let (Some(ref mut prof), Some(start)) = (profiler.as_deref_mut(), midi_start) {
        prof.record(
            BlisProfileLevel::Midi,
            start.elapsed().as_nanos() as u64,
            (2 * mc_block * nc_block * kc_block) as u64,
        );
    }
}
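
/// Direct 8x8-register-tile GEMM over row-major A, B, and C with no
/// packing, accumulating into C. The caller guarantees `m % 8 == 0` and
/// `n % 8 == 0`; the k loop is unrolled by 4 with one scalar-broadcast FMA
/// per row of the tile.
///
/// # Safety
/// Requires AVX2 and FMA (checked by the caller before dispatch).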
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_direct_rowmajor(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use std::arch::x86_64::*;

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let c_ptr = c.as_mut_ptr();

    unsafe {
        for ir in (0..m).step_by(8) {
            for jr in (0..n).step_by(8) {
                let c_base = c_ptr.add(ir * n + jr);
                let mut c0 = _mm256_loadu_ps(c_base);
                let mut c1 = _mm256_loadu_ps(c_base.add(n));
                let mut c2 = _mm256_loadu_ps(c_base.add(2 * n));
                let mut c3 = _mm256_loadu_ps(c_base.add(3 * n));
                let mut c4 = _mm256_loadu_ps(c_base.add(4 * n));
                let mut c5 = _mm256_loadu_ps(c_base.add(5 * n));
                let mut c6 = _mm256_loadu_ps(c_base.add(6 * n));
                let mut c7 = _mm256_loadu_ps(c_base.add(7 * n));

                let a0 = a_ptr.add(ir * k);
                let a1 = a_ptr.add((ir + 1) * k);
                let a2 = a_ptr.add((ir + 2) * k);
                let a3 = a_ptr.add((ir + 3) * k);
                let a4 = a_ptr.add((ir + 4) * k);
                let a5 = a_ptr.add((ir + 5) * k);
                let a6 = a_ptr.add((ir + 6) * k);
                let a7 = a_ptr.add((ir + 7) * k);

                let b_base = b_ptr.add(jr);

                let k4 = k / 4;
                let k_rem = k % 4;

                for p4 in 0..k4 {
                    let p = p4 * 4;

                    let b_row = _mm256_loadu_ps(b_base.add(p * n));
                    c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p)), b_row, c0);
                    c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p)), b_row, c1);
                    c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p)), b_row, c2);
                    c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p)), b_row, c3);
                    c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p)), b_row, c4);
                    c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p)), b_row, c5);
                    c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p)), b_row, c6);
                    c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p)), b_row, c7);

                    let b_row = _mm256_loadu_ps(b_base.add((p + 1) * n));
                    c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p + 1)), b_row, c0);
                    c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p + 1)), b_row, c1);
                    c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p + 1)), b_row, c2);
                    c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p + 1)), b_row, c3);
                    c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p + 1)), b_row, c4);
                    c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p + 1)), b_row, c5);
                    c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p + 1)), b_row, c6);
                    c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p + 1)), b_row, c7);

                    let b_row = _mm256_loadu_ps(b_base.add((p + 2) * n));
                    c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p + 2)), b_row, c0);
                    c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p + 2)), b_row, c1);
                    c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p + 2)), b_row, c2);
                    c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p + 2)), b_row, c3);
                    c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p + 2)), b_row, c4);
                    c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p + 2)), b_row, c5);
                    c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p + 2)), b_row, c6);
                    c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p + 2)), b_row, c7);

                    let b_row = _mm256_loadu_ps(b_base.add((p + 3) * n));
                    c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p + 3)), b_row, c0);
                    c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p + 3)), b_row, c1);
                    c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p + 3)), b_row, c2);
                    c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p + 3)), b_row, c3);
                    c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p + 3)), b_row, c4);
                    c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p + 3)), b_row, c5);
                    c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p + 3)), b_row, c6);
                    c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p + 3)), b_row, c7);
                }

                let base_rem = k4 * 4;
                for rp in 0..k_rem {
                    let pp = base_rem + rp;
                    let b_row = _mm256_loadu_ps(b_base.add(pp * n));
                    c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(pp)), b_row, c0);
                    c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(pp)), b_row, c1);
                    c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(pp)), b_row, c2);
                    c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(pp)), b_row, c3);
                    c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(pp)), b_row, c4);
                    c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(pp)), b_row, c5);
                    c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(pp)), b_row, c6);
                    c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(pp)), b_row, c7);
                }

                _mm256_storeu_ps(c_base, c0);
                _mm256_storeu_ps(c_base.add(n), c1);
                _mm256_storeu_ps(c_base.add(2 * n), c2);
                _mm256_storeu_ps(c_base.add(3 * n), c3);
                _mm256_storeu_ps(c_base.add(4 * n), c4);
                _mm256_storeu_ps(c_base.add(5 * n), c5);
                _mm256_storeu_ps(c_base.add(6 * n), c6);
                _mm256_storeu_ps(c_base.add(7 * n), c7);
            }
        }
    }
    Ok(())
}
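
/// Small-matrix fallback that works directly on strided row-major data:
/// columns of A and tiles of C are gathered element by element, so it
/// handles arbitrary `m` and `n` edges at the cost of gather overhead.
///
/// # Safety
/// Requires AVX2 and FMA (checked by the caller before dispatch).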
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_small_strided_avx2(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use std::arch::x86_64::*;
    unsafe {
        for jr in (0..n).step_by(NR) {
            let nr = NR.min(n - jr);
            for ir in (0..m).step_by(MR) {
                let mr = MR.min(m - ir);
                let mut cv = [_mm256_setzero_ps(); 6];
                for j in 0..nr {
                    if mr == MR {
                        cv[j] = _mm256_set_ps(
                            *c.get_unchecked((ir + 7) * n + jr + j),
                            *c.get_unchecked((ir + 6) * n + jr + j),
                            *c.get_unchecked((ir + 5) * n + jr + j),
                            *c.get_unchecked((ir + 4) * n + jr + j),
                            *c.get_unchecked((ir + 3) * n + jr + j),
                            *c.get_unchecked((ir + 2) * n + jr + j),
                            *c.get_unchecked((ir + 1) * n + jr + j),
                            *c.get_unchecked(ir * n + jr + j),
                        );
                    } else {
                        let mut t = [0.0f32; 8];
                        for i in 0..mr {
                            t[i] = *c.get_unchecked((ir + i) * n + jr + j);
                        }
                        cv[j] = _mm256_loadu_ps(t.as_ptr());
                    }
                }
                for p in 0..k {
                    let a_col = if mr == MR {
                        _mm256_set_ps(
                            *a.get_unchecked((ir + 7) * k + p),
                            *a.get_unchecked((ir + 6) * k + p),
                            *a.get_unchecked((ir + 5) * k + p),
                            *a.get_unchecked((ir + 4) * k + p),
                            *a.get_unchecked((ir + 3) * k + p),
                            *a.get_unchecked((ir + 2) * k + p),
                            *a.get_unchecked((ir + 1) * k + p),
                            *a.get_unchecked(ir * k + p),
                        )
                    } else {
                        let mut t = [0.0f32; 8];
                        for i in 0..mr {
                            t[i] = *a.get_unchecked((ir + i) * k + p);
                        }
                        _mm256_loadu_ps(t.as_ptr())
                    };
                    let bp = b.as_ptr().add(p * n + jr);
                    if nr == NR {
                        cv[0] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp), cv[0]);
                        cv[1] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(1)), cv[1]);
                        cv[2] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(2)), cv[2]);
                        cv[3] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(3)), cv[3]);
                        cv[4] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(4)), cv[4]);
                        cv[5] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(5)), cv[5]);
                    } else {
                        for j in 0..nr {
                            cv[j] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(j)), cv[j]);
                        }
                    }
                }
                for j in 0..nr {
                    let mut t = [0.0f32; 8];
                    _mm256_storeu_ps(t.as_mut_ptr(), cv[j]);
                    for i in 0..mr {
                        *c.get_unchecked_mut((ir + i) * n + jr + j) = t[i];
                    }
                }
            }
        }
    }
    Ok(())
}
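
/// Small-matrix path for `m % 8 == 0 && n % 8 == 0`: all B panels are
/// packed once up front, A is transposed 8x8 in registers per panel, and
/// every tile runs through the 8x8 AVX2+FMA micro-kernel. The fixed-size
/// stack buffer for packed A relies on the caller's `k <= 256` dispatch
/// guarantee.
///
/// # Safety
/// Requires AVX2 and FMA (checked by the caller before dispatch).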
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_small_nopack_8x8(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use crate::blis::microkernels::microkernel_8x8_avx2_fma;
    use std::arch::x86_64::*;

    let panels_m = m / 8;
    let panels_n = n / 8;

    // Stack buffer sized for the caller's k <= 256 guarantee.
    let mut packed_a = [0.0f32; 8 * 256];
    let mut c_micro = [0.0f32; 64];

    let mut all_packed_b = vec![0.0f32; panels_n * k * 8];

    unsafe {
        for jr_panel in 0..panels_n {
            let jr = jr_panel * 8;
            let b_dst = all_packed_b.as_mut_ptr().add(jr_panel * k * 8);
            for p in 0..k {
                _mm256_storeu_ps(b_dst.add(p * 8), _mm256_loadu_ps(b.as_ptr().add(p * n + jr)));
            }
        }

        for ir_panel in 0..panels_m {
            let ir = ir_panel * 8;

            // Transpose 8x8 blocks of A into the column-major panel layout.
            let k_blocks = k / 8;
            let k_rem_start = k_blocks * 8;
            for kb in 0..k_blocks {
                let p = kb * 8;
                let r0 = _mm256_loadu_ps(a.as_ptr().add(ir * k + p));
                let r1 = _mm256_loadu_ps(a.as_ptr().add((ir + 1) * k + p));
                let r2 = _mm256_loadu_ps(a.as_ptr().add((ir + 2) * k + p));
                let r3 = _mm256_loadu_ps(a.as_ptr().add((ir + 3) * k + p));
                let r4 = _mm256_loadu_ps(a.as_ptr().add((ir + 4) * k + p));
                let r5 = _mm256_loadu_ps(a.as_ptr().add((ir + 5) * k + p));
                let r6 = _mm256_loadu_ps(a.as_ptr().add((ir + 6) * k + p));
                let r7 = _mm256_loadu_ps(a.as_ptr().add((ir + 7) * k + p));

                let t0 = _mm256_unpacklo_ps(r0, r1);
                let t1 = _mm256_unpackhi_ps(r0, r1);
                let t2 = _mm256_unpacklo_ps(r2, r3);
                let t3 = _mm256_unpackhi_ps(r2, r3);
                let t4 = _mm256_unpacklo_ps(r4, r5);
                let t5 = _mm256_unpackhi_ps(r4, r5);
                let t6 = _mm256_unpacklo_ps(r6, r7);
                let t7 = _mm256_unpackhi_ps(r6, r7);

                let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
                let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
                let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
                let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
                let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
                let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
                let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
                let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);

                let dst = packed_a.as_mut_ptr().add(p * 8);
                _mm256_storeu_ps(dst, _mm256_permute2f128_ps(u0, u4, 0x20));
                _mm256_storeu_ps(dst.add(8), _mm256_permute2f128_ps(u1, u5, 0x20));
                _mm256_storeu_ps(dst.add(16), _mm256_permute2f128_ps(u2, u6, 0x20));
                _mm256_storeu_ps(dst.add(24), _mm256_permute2f128_ps(u3, u7, 0x20));
                _mm256_storeu_ps(dst.add(32), _mm256_permute2f128_ps(u0, u4, 0x31));
                _mm256_storeu_ps(dst.add(40), _mm256_permute2f128_ps(u1, u5, 0x31));
                _mm256_storeu_ps(dst.add(48), _mm256_permute2f128_ps(u2, u6, 0x31));
                _mm256_storeu_ps(dst.add(56), _mm256_permute2f128_ps(u3, u7, 0x31));
            }
            for p in k_rem_start..k {
                for i in 0..8 {
                    *packed_a.get_unchecked_mut(p * 8 + i) = *a.get_unchecked((ir + i) * k + p);
                }
            }

            for jr_panel in 0..panels_n {
                let jr = jr_panel * 8;
                let packed_b_ptr = all_packed_b.as_ptr().add(jr_panel * k * 8);

                for jj in 0..8 {
                    for ii in 0..8 {
                        c_micro[jj * 8 + ii] = *c.get_unchecked((ir + ii) * n + jr + jj);
                    }
                }

                microkernel_8x8_avx2_fma(
                    k,
                    packed_a.as_ptr(),
                    packed_b_ptr,
                    c_micro.as_mut_ptr(),
                    8,
                );

                for jj in 0..8 {
                    for ii in 0..8 {
                        *c.get_unchecked_mut((ir + ii) * n + jr + jj) = c_micro[jj * 8 + ii];
                    }
                }
            }
        }
    }
    Ok(())
}
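
/// Fully packed 8x8 small-matrix variant that zero-pads partial panels;
/// retained for reference but not currently dispatched.
///
/// # Safety
/// Requires AVX2 and FMA.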
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[allow(dead_code)]
unsafe fn gemm_small_8x8(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use crate::blis::microkernels::microkernel_8x8_avx2_fma;

    let panels_m = (m + 7) / 8;
    let panels_n = (n + 7) / 8;
    // Panel buffers are padded to whole 8-row / 8-column panels so the
    // packing loops below can always write a full panel.
    let mut packed_a = vec![0.0f32; panels_m * 8 * k];
    let mut packed_b = vec![0.0f32; panels_n * 8 * k];
    let mut c_micro = [0.0f32; 8 * 8];

    for panel in 0..panels_m {
        let ir = panel * 8;
        let mr = 8.min(m - ir);
        for p in 0..k {
            for i in 0..8 {
                unsafe {
                    packed_a[panel * 8 * k + p * 8 + i] =
                        if i < mr { *a.get_unchecked((ir + i) * k + p) } else { 0.0 };
                }
            }
        }
    }
    for panel in 0..panels_n {
        let jr = panel * 8;
        let nr = 8.min(n - jr);
        for p in 0..k {
            for j in 0..8 {
                unsafe {
                    packed_b[panel * 8 * k + p * 8 + j] =
                        if j < nr { *b.get_unchecked(p * n + jr + j) } else { 0.0 };
                }
            }
        }
    }

    unsafe {
        for ir_panel in 0..panels_m {
            let ir = ir_panel * 8;
            let mr = 8.min(m - ir);
            for jr_panel in 0..panels_n {
                let jr = jr_panel * 8;
                let nr = 8.min(n - jr);
                for jj in 0..8 {
                    for ii in 0..8 {
                        c_micro[jj * 8 + ii] = if ii < mr && jj < nr {
                            *c.get_unchecked((ir + ii) * n + jr + jj)
                        } else {
                            0.0
                        };
                    }
                }
                let ap = packed_a.as_ptr().add(ir_panel * 8 * k);
                let bp = packed_b.as_ptr().add(jr_panel * 8 * k);
                microkernel_8x8_avx2_fma(k, ap, bp, c_micro.as_mut_ptr(), 8);
                for jj in 0..nr {
                    for ii in 0..mr {
                        *c.get_unchecked_mut((ir + ii) * n + jr + jj) = c_micro[jj * 8 + ii];
                    }
                }
            }
        }
    }
    Ok(())
}
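
/// Small-matrix AVX-512 path for `m % 16 == 0 && n % 8 == 0`: B panels are
/// packed once, A is transposed into 16-row panels (two 8x8 register
/// transposes per k-block), and every tile runs the 16x8 AVX-512
/// micro-kernel. The stack buffer relies on the caller's `k <= 256`
/// dispatch guarantee.
///
/// # Safety
/// Requires AVX-512F (checked by the caller before dispatch).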
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn gemm_small_avx512_16x8(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use super::microkernels::microkernel_16x8_avx512;
    use std::arch::x86_64::*;

    let panels_m = m / 16;
    let panels_n = n / 8;

    let mut all_packed_b = vec![0.0f32; panels_n * k * 8];
    unsafe {
        for jr_panel in 0..panels_n {
            let jr = jr_panel * 8;
            let b_dst = all_packed_b.as_mut_ptr().add(jr_panel * k * 8);
            for p in 0..k {
                _mm256_storeu_ps(b_dst.add(p * 8), _mm256_loadu_ps(b.as_ptr().add(p * n + jr)));
            }
        }
    }

    // Stack buffer sized for the caller's k <= 256 guarantee.
    let mut packed_a = [0.0f32; 16 * 256];
    let mut c_micro = [0.0f32; 16 * 8];

    unsafe {
        for ir_panel in 0..panels_m {
            let ir = ir_panel * 16;

            let k_blocks = k / 8;
            let k_rem_start = k_blocks * 8;

            for kb in 0..k_blocks {
                let p = kb * 8;

                // Transpose rows 0..8 of the 16-row panel.
                let r0 = _mm256_loadu_ps(a.as_ptr().add(ir * k + p));
                let r1 = _mm256_loadu_ps(a.as_ptr().add((ir + 1) * k + p));
                let r2 = _mm256_loadu_ps(a.as_ptr().add((ir + 2) * k + p));
                let r3 = _mm256_loadu_ps(a.as_ptr().add((ir + 3) * k + p));
                let r4 = _mm256_loadu_ps(a.as_ptr().add((ir + 4) * k + p));
                let r5 = _mm256_loadu_ps(a.as_ptr().add((ir + 5) * k + p));
                let r6 = _mm256_loadu_ps(a.as_ptr().add((ir + 6) * k + p));
                let r7 = _mm256_loadu_ps(a.as_ptr().add((ir + 7) * k + p));

                let t0 = _mm256_unpacklo_ps(r0, r1);
                let t1 = _mm256_unpackhi_ps(r0, r1);
                let t2 = _mm256_unpacklo_ps(r2, r3);
                let t3 = _mm256_unpackhi_ps(r2, r3);
                let t4 = _mm256_unpacklo_ps(r4, r5);
                let t5 = _mm256_unpackhi_ps(r4, r5);
                let t6 = _mm256_unpacklo_ps(r6, r7);
                let t7 = _mm256_unpackhi_ps(r6, r7);

                let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
                let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
                let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
                let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
                let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
                let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
                let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
                let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);

                let dst = packed_a.as_mut_ptr().add(p * 16);
                _mm256_storeu_ps(dst, _mm256_permute2f128_ps(u0, u4, 0x20));
                _mm256_storeu_ps(dst.add(16), _mm256_permute2f128_ps(u1, u5, 0x20));
                _mm256_storeu_ps(dst.add(32), _mm256_permute2f128_ps(u2, u6, 0x20));
                _mm256_storeu_ps(dst.add(48), _mm256_permute2f128_ps(u3, u7, 0x20));
                _mm256_storeu_ps(dst.add(64), _mm256_permute2f128_ps(u0, u4, 0x31));
                _mm256_storeu_ps(dst.add(80), _mm256_permute2f128_ps(u1, u5, 0x31));
                _mm256_storeu_ps(dst.add(96), _mm256_permute2f128_ps(u2, u6, 0x31));
                _mm256_storeu_ps(dst.add(112), _mm256_permute2f128_ps(u3, u7, 0x31));

                // Transpose rows 8..16 of the panel into the upper half slots.
                let r0 = _mm256_loadu_ps(a.as_ptr().add((ir + 8) * k + p));
                let r1 = _mm256_loadu_ps(a.as_ptr().add((ir + 9) * k + p));
                let r2 = _mm256_loadu_ps(a.as_ptr().add((ir + 10) * k + p));
                let r3 = _mm256_loadu_ps(a.as_ptr().add((ir + 11) * k + p));
                let r4 = _mm256_loadu_ps(a.as_ptr().add((ir + 12) * k + p));
                let r5 = _mm256_loadu_ps(a.as_ptr().add((ir + 13) * k + p));
                let r6 = _mm256_loadu_ps(a.as_ptr().add((ir + 14) * k + p));
                let r7 = _mm256_loadu_ps(a.as_ptr().add((ir + 15) * k + p));

                let t0 = _mm256_unpacklo_ps(r0, r1);
                let t1 = _mm256_unpackhi_ps(r0, r1);
                let t2 = _mm256_unpacklo_ps(r2, r3);
                let t3 = _mm256_unpackhi_ps(r2, r3);
                let t4 = _mm256_unpacklo_ps(r4, r5);
                let t5 = _mm256_unpackhi_ps(r4, r5);
                let t6 = _mm256_unpacklo_ps(r6, r7);
                let t7 = _mm256_unpackhi_ps(r6, r7);

                let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
                let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
                let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
                let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
                let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
                let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
                let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
                let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);

                let dst_hi = packed_a.as_mut_ptr().add(p * 16 + 8);
                _mm256_storeu_ps(dst_hi, _mm256_permute2f128_ps(u0, u4, 0x20));
                _mm256_storeu_ps(dst_hi.add(16), _mm256_permute2f128_ps(u1, u5, 0x20));
                _mm256_storeu_ps(dst_hi.add(32), _mm256_permute2f128_ps(u2, u6, 0x20));
                _mm256_storeu_ps(dst_hi.add(48), _mm256_permute2f128_ps(u3, u7, 0x20));
                _mm256_storeu_ps(dst_hi.add(64), _mm256_permute2f128_ps(u0, u4, 0x31));
                _mm256_storeu_ps(dst_hi.add(80), _mm256_permute2f128_ps(u1, u5, 0x31));
                _mm256_storeu_ps(dst_hi.add(96), _mm256_permute2f128_ps(u2, u6, 0x31));
                _mm256_storeu_ps(dst_hi.add(112), _mm256_permute2f128_ps(u3, u7, 0x31));
            }

            for p in k_rem_start..k {
                for i in 0..16 {
                    *packed_a.get_unchecked_mut(p * 16 + i) = *a.get_unchecked((ir + i) * k + p);
                }
            }

            for jr_panel in 0..panels_n {
                let jr = jr_panel * 8;
                let packed_b_ptr = all_packed_b.as_ptr().add(jr_panel * k * 8);

                for jj in 0..8 {
                    for ii in 0..16 {
                        c_micro[jj * 16 + ii] = *c.get_unchecked((ir + ii) * n + jr + jj);
                    }
                }

                microkernel_16x8_avx512(
                    k,
                    packed_a.as_ptr(),
                    packed_b_ptr,
                    c_micro.as_mut_ptr(),
                    16,
                );

                for jj in 0..8 {
                    for ii in 0..16 {
                        *c.get_unchecked_mut((ir + ii) * n + jr + jj) = c_micro[jj * 16 + ii];
                    }
                }
            }
        }
    }
    Ok(())
}
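
/// Checks that the slice lengths of A, B, and C match the `m` x `k`,
/// `k` x `n`, and `m` x `n` shapes, returning `InvalidInput` otherwise.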
fn validate_gemm_dims(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &[f32],
) -> Result<(), TruenoError> {
    if a.len() != m * k {
        return Err(TruenoError::InvalidInput(format!(
            "A size mismatch: expected {}, got {}",
            m * k,
            a.len()
        )));
    }
    if b.len() != k * n {
        return Err(TruenoError::InvalidInput(format!(
            "B size mismatch: expected {}, got {}",
            k * n,
            b.len()
        )));
    }
    if c.len() != m * n {
        return Err(TruenoError::InvalidInput(format!(
            "C size mismatch: expected {}, got {}",
            m * n,
            c.len()
        )));
    }
    Ok(())
}
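
/// Records one profiler event when both a profiler and a start timestamp
/// are present; a no-op otherwise.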
#[inline(always)]
fn record_prof(
    profiler: &mut Option<&mut BlisProfiler>,
    level: BlisProfileLevel,
    start: Option<Instant>,
    flops: u64,
) {
    if let (Some(ref mut prof), Some(s)) = (profiler.as_deref_mut(), start) {
        prof.record(level, s.elapsed().as_nanos() as u64, flops);
    }
}
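
/// BLIS-style single-precision GEMM: accumulates `A * B` into row-major C,
/// where A is `m` x `k`, B is `k` x `n`, and C is `m` x `n`.
///
/// Dispatch order: the reference kernel for tiny problems; specialized
/// small-matrix kernels when `m`, `n`, `k <= 256` on x86_64 and no profiler
/// is attached; the AVX-512 blocked path when available; the AVX2 row-major
/// path; and finally the generic packed loop nest using thread-local
/// buffers.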
pub fn gemm_blis(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    mut profiler: Option<&mut BlisProfiler>,
) -> Result<(), TruenoError> {
    contract_pre_flops_per_tile!();
    validate_gemm_dims(m, n, k, a, b, c)?;

    if m == 0 || n == 0 || k == 0 {
        return Ok(());
    }
    if m * n * k < 4096 {
        return gemm_reference(m, n, k, a, b, c);
    }

    #[cfg(target_arch = "x86_64")]
    if profiler.is_none()
        && m <= 256
        && n <= 256
        && k <= 256
        && is_x86_feature_detected!("avx2")
        && is_x86_feature_detected!("fma")
    {
        unsafe {
            if m <= 128 && n <= 128 && m % 8 == 0 && n % 8 == 0 {
                return gemm_direct_rowmajor(m, n, k, a, b, c);
            }
            if is_x86_feature_detected!("avx512f") && m >= 16 && m % 16 == 0 && n % 8 == 0 {
                return gemm_small_avx512_16x8(m, n, k, a, b, c);
            }
            if m >= MR && m % 8 == 0 && n % 8 == 0 {
                return gemm_small_nopack_8x8(m, n, k, a, b, c);
            }
            return gemm_small_strided_avx2(m, n, k, a, b, c);
        }
    }

    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma") {
        return unsafe { gemm_blis_avx512_large(m, n, k, a, b, c, &mut profiler) };
    }

    #[cfg(target_arch = "x86_64")]
    if profiler.is_none() && is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
        return unsafe { gemm_blis_nr8_rowmajor_c(m, n, k, a, b, c) };
    }

    let track_time = profiler.is_some();
    let start = if track_time { Some(Instant::now()) } else { None };

    let mc = MC.min(m);
    let nc = NC.min(n);
    let kc = KC.min(k);

    let needed_a = packed_a_size(mc, kc);
    let needed_b = packed_b_size(kc, nc);
    let needed_c = MR * NR;

    TL_PACKED_A.with(|tl_a| {
        TL_PACKED_B.with(|tl_b| {
            TL_C_MICRO.with(|tl_c| {
                let mut packed_a = tl_a.borrow_mut();
                let mut packed_b = tl_b.borrow_mut();
                let mut c_micro = tl_c.borrow_mut();

                if packed_a.len() < needed_a {
                    packed_a.resize(needed_a, 0.0);
                }
                if packed_b.len() < needed_b {
                    packed_b.resize(needed_b, 0.0);
                }
                if c_micro.len() < needed_c {
                    c_micro.resize(needed_c, 0.0);
                }

                for jc in (0..n).step_by(NC) {
                    let nc_block = NC.min(n - jc);

                    for pc in (0..k).step_by(KC) {
                        let kc_block = KC.min(k - pc);

                        let pack_start = if track_time { Some(Instant::now()) } else { None };
                        pack_b_block(b, n, pc, jc, kc_block, nc_block, &mut packed_b);
                        record_prof(&mut profiler, BlisProfileLevel::Pack, pack_start, 0);

                        for ic in (0..m).step_by(MC) {
                            let mc_block = MC.min(m - ic);

                            let pack_start = if track_time { Some(Instant::now()) } else { None };
                            pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);
                            record_prof(&mut profiler, BlisProfileLevel::Pack, pack_start, 0);

                            compute_macroblock(
                                c,
                                &packed_a,
                                &packed_b,
                                &mut c_micro,
                                ic,
                                jc,
                                mc_block,
                                nc_block,
                                kc_block,
                                n,
                                &mut profiler,
                            );
                        }
                    }
                }

                if let (Some(prof), Some(s)) = (profiler, start) {
                    prof.record(
                        BlisProfileLevel::Macro,
                        s.elapsed().as_nanos() as u64,
                        (2 * m * n * k) as u64,
                    );
                }
            });
        });
    });

    contract_post_flops_per_tile!(c);
    Ok(())
}
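
/// AVX-512 blocked GEMM for large matrices. Blocking parameters come from
/// the cache topology (8x32 tiles when `n >= 32`, otherwise 8x16; an 8x48
/// tile is used when the blocking requests `nr == 48`), with a scalar loop
/// covering edge tiles.
///
/// # Safety
/// Requires AVX-512F and FMA (checked by the caller before dispatch).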
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
unsafe fn gemm_blis_avx512_large(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    profiler: &mut Option<&mut BlisProfiler>,
) -> Result<(), TruenoError> {
    let track_time = profiler.is_some();
    let start = if track_time { Some(Instant::now()) } else { None };

    let blk = if n >= 32 {
        super::cache_topology::blocking_8x32()
    } else {
        super::cache_topology::blocking_8x16()
    };
    let mr = blk.mr;
    let nr = blk.nr;
    let mc = blk.mc.min(m);
    let nc = blk.nc.min(n);
    let kc_param = blk.kc;

    TL_PACKED_A.with(|tl_a| {
        TL_PACKED_B.with(|tl_b| {
            let mut packed_a = tl_a.borrow_mut();
            let mut packed_b = tl_b.borrow_mut();

            let needed_a = packed_a_size(mc, kc_param);
            let b_panels = (nc + nr - 1) / nr;
            let needed_b = b_panels * nr * kc_param;
            if packed_a.len() < needed_a {
                packed_a.resize(needed_a, 0.0);
            }
            if packed_b.len() < needed_b {
                packed_b.resize(needed_b, 0.0);
            }

            for jc in (0..n).step_by(nc) {
                let nc_block = nc.min(n - jc);

                for pc in (0..k).step_by(kc_param) {
                    let kc_block = kc_param.min(k - pc);

                    if nr == 48 {
                        pack_b_block_generic(b, n, pc, jc, kc_block, nc_block, 48, &mut packed_b);
                    } else if nr == 32 {
                        pack_b_block_generic(b, n, pc, jc, kc_block, nc_block, 32, &mut packed_b);
                    } else {
                        pack_b_block_nr16(b, n, pc, jc, kc_block, nc_block, &mut packed_b);
                    }

                    for ic in (0..m).step_by(mc) {
                        let mc_block = mc.min(m - ic);

                        pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);

                        let panels_m = (mc_block + mr - 1) / mr;
                        let panels_n = (nc_block + nr - 1) / nr;

                        for ir_panel in 0..panels_m {
                            let ir = ir_panel * mr;
                            let mr_block = mr.min(mc_block - ir);

                            for jr_panel in 0..panels_n {
                                let jr = jr_panel * nr;
                                let nr_block = nr.min(nc_block - jr);

                                let a_panel = &packed_a[ir_panel * mr * kc_block..];
                                let b_panel = &packed_b[jr_panel * nr * kc_block..];

                                if mr_block == 8 && nr_block == 48 && nr == 48 {
                                    unsafe {
                                        super::microkernels::codegen::microkernel_8x48_avx512_gen(
                                            kc_block,
                                            a_panel.as_ptr(),
                                            b_panel.as_ptr(),
                                            c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
                                            n,
                                        );
                                    }
                                } else if mr_block == 8 && nr_block == 32 && nr == 32 {
                                    unsafe {
                                        avx512_microkernel_8x32_rowmajor(
                                            kc_block,
                                            a_panel.as_ptr(),
                                            b_panel.as_ptr(),
                                            c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
                                            n,
                                        );
                                    }
                                } else if mr_block == 8 && nr_block == 16 && nr == 16 {
                                    unsafe {
                                        avx512_microkernel_8x16_rowmajor(
                                            kc_block,
                                            a_panel.as_ptr(),
                                            b_panel.as_ptr(),
                                            c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
                                            n,
                                        );
                                    }
                                } else {
                                    for ir_local in 0..mr_block {
                                        for jr_local in 0..nr_block {
                                            let mut sum =
                                                c[(ic + ir + ir_local) * n + (jc + jr + jr_local)];
                                            for p in 0..kc_block {
                                                sum += a_panel[p * mr + ir_local]
                                                    * b_panel[p * nr + jr_local];
                                            }
                                            c[(ic + ir + ir_local) * n + (jc + jr + jr_local)] =
                                                sum;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
            if let (Some(prof), Some(s)) = (profiler.as_mut(), start) {
                prof.record_avx512_blis(m, n, k, s.elapsed());
            }

            Ok(())
        })
    })
}
1122
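
/// Broadcast-B blocking variant: 64x6 register tiles where A is loaded as
/// vectors and single B elements are splatted, inverting the usual
/// broadcast-A scheme. Edge tiles fall back to a scalar loop that
/// accumulates into C.
///
/// # Safety
/// Requires AVX-512F and FMA (checked by the caller before dispatch).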
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
unsafe fn gemm_blis_avx512_bcast_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    let blk = super::cache_topology::blocking_64x6_bcast_b();
    let mr = blk.mr;
    let nr = blk.nr;
    let mc = blk.mc.min(m);
    let nc = blk.nc.min(n);
    let kc = blk.kc;

    let a_panels = (mc + mr - 1) / mr;
    let needed_a = a_panels * mr * kc;
    let b_panels = (nc + nr - 1) / nr;
    let needed_b = b_panels * nr * kc;

    let mut packed_a = vec![0.0f32; needed_a];
    let mut packed_b = vec![0.0f32; needed_b];

    for jc in (0..n).step_by(nc) {
        let nc_block = nc.min(n - jc);

        for pc in (0..k).step_by(kc) {
            let kc_block = kc.min(k - pc);

            pack_b_block_generic(b, n, pc, jc, kc_block, nc_block, nr, &mut packed_b);

            for ic in (0..m).step_by(mc) {
                let mc_block = mc.min(m - ic);

                pack_a_block_generic(a, k, ic, pc, mc_block, kc_block, mr, &mut packed_a);

                let panels_m = (mc_block + mr - 1) / mr;
                let panels_n = (nc_block + nr - 1) / nr;

                for ir_panel in 0..panels_m {
                    let ir = ir_panel * mr;
                    let mr_block = mr.min(mc_block - ir);

                    for jr_panel in 0..panels_n {
                        let jr = jr_panel * nr;
                        let nr_block = nr.min(nc_block - jr);

                        let a_panel = &packed_a[ir_panel * mr * kc_block..];
                        let b_panel = &packed_b[jr_panel * nr * kc_block..];

                        if mr_block == 64 && nr_block == 6 {
                            unsafe {
                                super::microkernels::codegen::microkernel_64x6_avx512_bcast_b(
                                    kc_block,
                                    a_panel.as_ptr(),
                                    b_panel.as_ptr(),
                                    c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
                                    n,
                                );
                            }
                        } else {
                            for ir_local in 0..mr_block {
                                for jr_local in 0..nr_block {
                                    let mut sum = 0.0f32;
                                    for p in 0..kc_block {
                                        sum +=
                                            a_panel[p * mr + ir_local] * b_panel[p * nr + jr_local];
                                    }
                                    c[(ic + ir + ir_local) * n + (jc + jr + jr_local)] += sum;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    Ok(())
}
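
/// Packs a `rows` x `cols` block of row-major A into column-major panels
/// of `mr` rows, zero-padding the final partial panel.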
fn pack_a_block_generic(
    a: &[f32],
    lda: usize,
    row_start: usize,
    col_start: usize,
    rows: usize,
    cols: usize,
    mr: usize,
    packed: &mut [f32],
) {
    let panels = (rows + mr - 1) / mr;
    let mut pack_idx = 0;
    for panel in 0..panels {
        let ir = panel * mr;
        let mr_actual = mr.min(rows - ir);
        for col in 0..cols {
            for row in 0..mr {
                if row < mr_actual {
                    packed[pack_idx] = a[(row_start + ir + row) * lda + col_start + col];
                } else {
                    packed[pack_idx] = 0.0;
                }
                pack_idx += 1;
            }
        }
    }
}
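
/// Public entry point for the broadcast-B variant: dispatches to the
/// AVX-512 64x6 kernel when AVX-512F is available, otherwise falls back
/// to `gemm_blis`.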
#[cfg(target_arch = "x86_64")]
pub fn gemm_blis_broadcast_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    if a.len() != m * k || b.len() != k * n || c.len() != m * n {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }
    if std::arch::is_x86_feature_detected!("avx512f") {
        unsafe { gemm_blis_avx512_bcast_b(m, n, k, a, b, c) }
    } else {
        gemm_blis(m, n, k, a, b, c, None)
    }
}
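
/// 8x16 AVX-512 micro-kernel that accumulates directly into row-major C
/// with leading dimension `ldc`, avoiding the C micro-tile copy.
///
/// # Safety
/// Requires AVX-512F+FMA; `a` must hold `k * 8` packed floats, `b` must
/// hold `k * 16` packed floats, and `c` must point to a full 8x16 tile
/// valid for stride `ldc`.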
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
pub(super) unsafe fn avx512_microkernel_8x16_rowmajor(
    k: usize,
    a: *const f32,
    b: *const f32,
    c: *mut f32,
    ldc: usize,
) {
    use std::arch::x86_64::*;

    let mut c0 = _mm512_loadu_ps(c);
    let mut c1 = _mm512_loadu_ps(c.add(ldc));
    let mut c2 = _mm512_loadu_ps(c.add(2 * ldc));
    let mut c3 = _mm512_loadu_ps(c.add(3 * ldc));
    let mut c4 = _mm512_loadu_ps(c.add(4 * ldc));
    let mut c5 = _mm512_loadu_ps(c.add(5 * ldc));
    let mut c6 = _mm512_loadu_ps(c.add(6 * ldc));
    let mut c7 = _mm512_loadu_ps(c.add(7 * ldc));

    for p in 0..k {
        let b_row = _mm512_loadu_ps(b.add(p * 16));
        let ap = a.add(p * 8);
        c0 = _mm512_fmadd_ps(_mm512_set1_ps(*ap), b_row, c0);
        c1 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(1)), b_row, c1);
        c2 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(2)), b_row, c2);
        c3 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(3)), b_row, c3);
        c4 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(4)), b_row, c4);
        c5 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(5)), b_row, c5);
        c6 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(6)), b_row, c6);
        c7 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(7)), b_row, c7);
    }

    _mm512_storeu_ps(c, c0);
    _mm512_storeu_ps(c.add(ldc), c1);
    _mm512_storeu_ps(c.add(2 * ldc), c2);
    _mm512_storeu_ps(c.add(3 * ldc), c3);
    _mm512_storeu_ps(c.add(4 * ldc), c4);
    _mm512_storeu_ps(c.add(5 * ldc), c5);
    _mm512_storeu_ps(c.add(6 * ldc), c6);
    _mm512_storeu_ps(c.add(7 * ldc), c7);
}
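
/// 8x32 AVX-512 micro-kernel (two 16-lane halves per row) that accumulates
/// directly into row-major C with leading dimension `ldc`.
///
/// # Safety
/// Requires AVX-512F+FMA; `a` must hold `k * 8` packed floats, `b` must
/// hold `k * 32` packed floats, and `c` must point to a full 8x32 tile
/// valid for stride `ldc`.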
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
pub(super) unsafe fn avx512_microkernel_8x32_rowmajor(
    k: usize,
    a: *const f32,
    b: *const f32,
    c: *mut f32,
    ldc: usize,
) {
    use std::arch::x86_64::*;

    let mut c0l = _mm512_loadu_ps(c);
    let mut c0h = _mm512_loadu_ps(c.add(16));
    let mut c1l = _mm512_loadu_ps(c.add(ldc));
    let mut c1h = _mm512_loadu_ps(c.add(ldc + 16));
    let mut c2l = _mm512_loadu_ps(c.add(2 * ldc));
    let mut c2h = _mm512_loadu_ps(c.add(2 * ldc + 16));
    let mut c3l = _mm512_loadu_ps(c.add(3 * ldc));
    let mut c3h = _mm512_loadu_ps(c.add(3 * ldc + 16));
    let mut c4l = _mm512_loadu_ps(c.add(4 * ldc));
    let mut c4h = _mm512_loadu_ps(c.add(4 * ldc + 16));
    let mut c5l = _mm512_loadu_ps(c.add(5 * ldc));
    let mut c5h = _mm512_loadu_ps(c.add(5 * ldc + 16));
    let mut c6l = _mm512_loadu_ps(c.add(6 * ldc));
    let mut c6h = _mm512_loadu_ps(c.add(6 * ldc + 16));
    let mut c7l = _mm512_loadu_ps(c.add(7 * ldc));
    let mut c7h = _mm512_loadu_ps(c.add(7 * ldc + 16));

    for p in 0..k {
        let bl = _mm512_loadu_ps(b.add(p * 32));
        let bh = _mm512_loadu_ps(b.add(p * 32 + 16));
        let ap = a.add(p * 8);

        let a0 = _mm512_set1_ps(*ap);
        c0l = _mm512_fmadd_ps(a0, bl, c0l);
        c0h = _mm512_fmadd_ps(a0, bh, c0h);
        let a1 = _mm512_set1_ps(*ap.add(1));
        c1l = _mm512_fmadd_ps(a1, bl, c1l);
        c1h = _mm512_fmadd_ps(a1, bh, c1h);
        let a2 = _mm512_set1_ps(*ap.add(2));
        c2l = _mm512_fmadd_ps(a2, bl, c2l);
        c2h = _mm512_fmadd_ps(a2, bh, c2h);
        let a3 = _mm512_set1_ps(*ap.add(3));
        c3l = _mm512_fmadd_ps(a3, bl, c3l);
        c3h = _mm512_fmadd_ps(a3, bh, c3h);
        let a4 = _mm512_set1_ps(*ap.add(4));
        c4l = _mm512_fmadd_ps(a4, bl, c4l);
        c4h = _mm512_fmadd_ps(a4, bh, c4h);
        let a5 = _mm512_set1_ps(*ap.add(5));
        c5l = _mm512_fmadd_ps(a5, bl, c5l);
        c5h = _mm512_fmadd_ps(a5, bh, c5h);
        let a6 = _mm512_set1_ps(*ap.add(6));
        c6l = _mm512_fmadd_ps(a6, bl, c6l);
        c6h = _mm512_fmadd_ps(a6, bh, c6h);
        let a7 = _mm512_set1_ps(*ap.add(7));
        c7l = _mm512_fmadd_ps(a7, bl, c7l);
        c7h = _mm512_fmadd_ps(a7, bh, c7h);
    }

    _mm512_storeu_ps(c, c0l);
    _mm512_storeu_ps(c.add(16), c0h);
    _mm512_storeu_ps(c.add(ldc), c1l);
    _mm512_storeu_ps(c.add(ldc + 16), c1h);
    _mm512_storeu_ps(c.add(2 * ldc), c2l);
    _mm512_storeu_ps(c.add(2 * ldc + 16), c2h);
    _mm512_storeu_ps(c.add(3 * ldc), c3l);
    _mm512_storeu_ps(c.add(3 * ldc + 16), c3h);
    _mm512_storeu_ps(c.add(4 * ldc), c4l);
    _mm512_storeu_ps(c.add(4 * ldc + 16), c4h);
    _mm512_storeu_ps(c.add(5 * ldc), c5l);
    _mm512_storeu_ps(c.add(5 * ldc + 16), c5h);
    _mm512_storeu_ps(c.add(6 * ldc), c6l);
    _mm512_storeu_ps(c.add(6 * ldc + 16), c6h);
    _mm512_storeu_ps(c.add(7 * ldc), c7l);
    _mm512_storeu_ps(c.add(7 * ldc + 16), c7h);
}
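
/// Packs a `kc` x `nc` block of row-major B into 16-column panels,
/// zero-padding the final partial panel.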
pub(super) fn pack_b_block_nr16(
    b: &[f32],
    ldb: usize,
    pc: usize,
    jc: usize,
    kc: usize,
    nc: usize,
    packed: &mut [f32],
) {
    let nr = 16;
    let panels = (nc + nr - 1) / nr;
    for panel in 0..panels {
        let j_start = panel * nr;
        let nr_local = nr.min(nc - j_start);
        for p in 0..kc {
            for j in 0..nr_local {
                packed[panel * nr * kc + p * nr + j] = b[(pc + p) * ldb + (jc + j_start + j)];
            }
            for j in nr_local..nr {
                packed[panel * nr * kc + p * nr + j] = 0.0;
            }
        }
    }
}
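
/// Packs a `kc` x `nc` block of row-major B into `nr`-column panels,
/// taking the AVX-512 fast path for `nr == 32` when available.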
pub(super) fn pack_b_block_generic(
    b: &[f32],
    ldb: usize,
    pc: usize,
    jc: usize,
    kc: usize,
    nc: usize,
    nr: usize,
    packed: &mut [f32],
) {
    #[cfg(target_arch = "x86_64")]
    if nr == 32 && std::arch::is_x86_feature_detected!("avx512f") {
        unsafe {
            pack_b_block_nr32_avx512(b, ldb, pc, jc, kc, nc, packed);
        }
        return;
    }

    let panels = (nc + nr - 1) / nr;
    for panel in 0..panels {
        let j_start = panel * nr;
        let nr_local = nr.min(nc - j_start);
        for p in 0..kc {
            let dst_base = panel * nr * kc + p * nr;
            for j in 0..nr_local {
                packed[dst_base + j] = b[(pc + p) * ldb + (jc + j_start + j)];
            }
            for j in nr_local..nr {
                packed[dst_base + j] = 0.0;
            }
        }
    }
}
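
/// AVX-512 packing for 32-column B panels: full panels are copied two
/// k-rows at a time with 512-bit loads and stores; partial panels fall
/// back to the scalar path with zero padding.
///
/// # Safety
/// Requires AVX-512F (checked by the caller before dispatch).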
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn pack_b_block_nr32_avx512(
    b: &[f32],
    ldb: usize,
    pc: usize,
    jc: usize,
    kc: usize,
    nc: usize,
    packed: &mut [f32],
) {
    use std::arch::x86_64::*;

    let nr = 32;
    let panels = (nc + nr - 1) / nr;
    for panel in 0..panels {
        let j_start = panel * nr;
        let nr_local = nr.min(nc - j_start);

        if nr_local == 32 {
            let panel_base = panel * nr * kc;
            let b_col = jc + j_start;
            let kc2 = kc / 2 * 2;
            let mut p = 0;
            while p < kc2 {
                let src0 = b.as_ptr().add((pc + p) * ldb + b_col);
                let src1 = b.as_ptr().add((pc + p + 1) * ldb + b_col);
                let dst0 = packed.as_mut_ptr().add(panel_base + p * nr);
                let dst1 = packed.as_mut_ptr().add(panel_base + (p + 1) * nr);
                let v0a = _mm512_loadu_ps(src0);
                let v0b = _mm512_loadu_ps(src0.add(16));
                let v1a = _mm512_loadu_ps(src1);
                let v1b = _mm512_loadu_ps(src1.add(16));
                _mm512_storeu_ps(dst0, v0a);
                _mm512_storeu_ps(dst0.add(16), v0b);
                _mm512_storeu_ps(dst1, v1a);
                _mm512_storeu_ps(dst1.add(16), v1b);
                p += 2;
            }
            while p < kc {
                let src = b.as_ptr().add((pc + p) * ldb + b_col);
                let dst = packed.as_mut_ptr().add(panel_base + p * nr);
                let v0 = _mm512_loadu_ps(src);
                let v1 = _mm512_loadu_ps(src.add(16));
                _mm512_storeu_ps(dst, v0);
                _mm512_storeu_ps(dst.add(16), v1);
                p += 1;
            }
        } else {
            for p in 0..kc {
                let dst_base = panel * nr * kc + p * nr;
                for j in 0..nr_local {
                    packed[dst_base + j] = b[(pc + p) * ldb + (jc + j_start + j)];
                }
                for j in nr_local..nr {
                    packed[dst_base + j] = 0.0;
                }
            }
        }
    }
}
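
/// AVX2 blocked path with 8-wide B panels that accumulates straight into
/// row-major C (no C micro-tile). Full 8x8 tiles run an inline
/// broadcast-FMA kernel unrolled by 4 over k; edge tiles use a scalar
/// loop.
///
/// # Safety
/// Requires AVX2 and FMA (checked by the caller before dispatch).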
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_blis_nr8_rowmajor_c(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use std::arch::x86_64::*;

    let mc = 64_usize.min(m);
    let nc = 1024_usize.min(n);
    let kc_param = KC;
    let nr = 8_usize;
    let mr = MR;

    TL_PACKED_A.with(|tl_a| {
        TL_PACKED_B.with(|tl_b| {
            let mut packed_a = tl_a.borrow_mut();
            let mut packed_b = tl_b.borrow_mut();

            let needed_a = packed_a_size(mc, kc_param);
            let needed_b = packed_b_size_512(kc_param, nc);
            if packed_a.len() < needed_a {
                packed_a.resize(needed_a, 0.0);
            }
            if packed_b.len() < needed_b {
                packed_b.resize(needed_b, 0.0);
            }

            for jc in (0..n).step_by(nc) {
                let nc_block = nc.min(n - jc);

                for pc in (0..k).step_by(kc_param) {
                    let kc_block = kc_param.min(k - pc);

                    pack_b_block_512(b, n, pc, jc, kc_block, nc_block, &mut packed_b);

                    for ic in (0..m).step_by(mc) {
                        let mc_block = mc.min(m - ic);

                        pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);

                        let panels_m = (mc_block + mr - 1) / mr;
                        let panels_n = (nc_block + nr - 1) / nr;

                        for ir_panel in 0..panels_m {
                            let ir = ir_panel * mr;
                            let mr_block = mr.min(mc_block - ir);

                            for jr_panel in 0..panels_n {
                                let jr = jr_panel * nr;
                                let nr_block = nr.min(nc_block - jr);

                                let a_panel = &packed_a[ir_panel * mr * kc_block..];
                                let b_panel = &packed_b[jr_panel * nr * kc_block..];

                                if mr_block == 8 && nr_block == 8 {
                                    unsafe {
                                        let c_base =
                                            c.as_mut_ptr().add((ic + ir) * n + (jc + jr));

                                        let mut c0 = _mm256_loadu_ps(c_base);
                                        let mut c1 = _mm256_loadu_ps(c_base.add(n));
                                        let mut c2 = _mm256_loadu_ps(c_base.add(2 * n));
                                        let mut c3 = _mm256_loadu_ps(c_base.add(3 * n));
                                        let mut c4 = _mm256_loadu_ps(c_base.add(4 * n));
                                        let mut c5 = _mm256_loadu_ps(c_base.add(5 * n));
                                        let mut c6 = _mm256_loadu_ps(c_base.add(6 * n));
                                        let mut c7 = _mm256_loadu_ps(c_base.add(7 * n));

                                        let ap = a_panel.as_ptr();
                                        let bp = b_panel.as_ptr();

                                        let k4 = kc_block / 4;
                                        let k_rem = kc_block % 4;

                                        for p4 in 0..k4 {
                                            let p = p4 * 4;

                                            let b_row = _mm256_loadu_ps(bp.add(p * 8));
                                            c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(p * 8)), b_row, c0);
                                            c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(p * 8 + 1)), b_row, c1);
                                            c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(p * 8 + 2)), b_row, c2);
                                            c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(p * 8 + 3)), b_row, c3);
                                            c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(p * 8 + 4)), b_row, c4);
                                            c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(p * 8 + 5)), b_row, c5);
                                            c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(p * 8 + 6)), b_row, c6);
                                            c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(p * 8 + 7)), b_row, c7);

                                            let b_row = _mm256_loadu_ps(bp.add((p + 1) * 8));
                                            c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 1) * 8)), b_row, c0);
                                            c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 1)), b_row, c1);
                                            c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 2)), b_row, c2);
                                            c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 3)), b_row, c3);
                                            c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 4)), b_row, c4);
                                            c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 5)), b_row, c5);
                                            c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 6)), b_row, c6);
                                            c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 7)), b_row, c7);

                                            let b_row = _mm256_loadu_ps(bp.add((p + 2) * 8));
                                            c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 2) * 8)), b_row, c0);
                                            c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 1)), b_row, c1);
                                            c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 2)), b_row, c2);
                                            c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 3)), b_row, c3);
                                            c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 4)), b_row, c4);
                                            c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 5)), b_row, c5);
                                            c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 6)), b_row, c6);
                                            c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 7)), b_row, c7);

                                            let b_row = _mm256_loadu_ps(bp.add((p + 3) * 8));
                                            c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 3) * 8)), b_row, c0);
                                            c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 1)), b_row, c1);
                                            c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 2)), b_row, c2);
                                            c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 3)), b_row, c3);
                                            c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 4)), b_row, c4);
                                            c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 5)), b_row, c5);
                                            c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 6)), b_row, c6);
                                            c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 7)), b_row, c7);
                                        }

                                        let base_rem = k4 * 4;
                                        for rp in 0..k_rem {
                                            let pp = base_rem + rp;
                                            let b_row = _mm256_loadu_ps(bp.add(pp * 8));
                                            c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(pp * 8)), b_row, c0);
                                            c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(pp * 8 + 1)), b_row, c1);
                                            c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(pp * 8 + 2)), b_row, c2);
                                            c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(pp * 8 + 3)), b_row, c3);
                                            c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(pp * 8 + 4)), b_row, c4);
                                            c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(pp * 8 + 5)), b_row, c5);
                                            c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(pp * 8 + 6)), b_row, c6);
                                            c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*ap.add(pp * 8 + 7)), b_row, c7);
                                        }

                                        _mm256_storeu_ps(c_base, c0);
                                        _mm256_storeu_ps(c_base.add(n), c1);
                                        _mm256_storeu_ps(c_base.add(2 * n), c2);
                                        _mm256_storeu_ps(c_base.add(3 * n), c3);
                                        _mm256_storeu_ps(c_base.add(4 * n), c4);
                                        _mm256_storeu_ps(c_base.add(5 * n), c5);
                                        _mm256_storeu_ps(c_base.add(6 * n), c6);
                                        _mm256_storeu_ps(c_base.add(7 * n), c7);
                                    }
                                } else {
                                    for p in 0..kc_block {
                                        for jj in 0..nr_block {
                                            let b_val = b_panel[p * nr + jj];
                                            for ii in 0..mr_block {
                                                c[(ic + ir + ii) * n + (jc + jr + jj)] +=
                                                    a_panel[p * mr + ii] * b_val;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        });
    });

    Ok(())
}
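
/// Packed AVX-512 path built around the 16x8 micro-kernel and a C
/// micro-tile; retained for reference but superseded by
/// `gemm_blis_avx512_large` in the dispatch.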
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
fn gemm_blis_avx512_packed(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    let mc = MC_512.min(m);
    let nc = NC_512.min(n);
    let kc = KC_512.min(k);

    let needed_a = packed_a_size_512(mc, kc);
    let needed_b = packed_b_size_512(kc, nc);
    let needed_c = MR_512 * NR_512;

    TL_PACKED_A.with(|tl_a| {
        TL_PACKED_B.with(|tl_b| {
            TL_C_MICRO.with(|tl_c| {
                let mut packed_a = tl_a.borrow_mut();
                let mut packed_b = tl_b.borrow_mut();
                let mut c_micro = tl_c.borrow_mut();

                if packed_a.len() < needed_a {
                    packed_a.resize(needed_a, 0.0);
                }
                if packed_b.len() < needed_b {
                    packed_b.resize(needed_b, 0.0);
                }
                if c_micro.len() < needed_c {
                    c_micro.resize(needed_c, 0.0);
                }

                for jc in (0..n).step_by(NC_512) {
                    let nc_block = NC_512.min(n - jc);

                    for pc in (0..k).step_by(KC_512) {
                        let kc_block = KC_512.min(k - pc);

                        pack_b_block_512(b, n, pc, jc, kc_block, nc_block, &mut packed_b);

                        for ic in (0..m).step_by(MC_512) {
                            let mc_block = MC_512.min(m - ic);

                            pack_a_block_512(a, k, ic, pc, mc_block, kc_block, &mut packed_a);

                            for ir in (0..mc_block).step_by(MR_512) {
                                let mr_block = MR_512.min(mc_block - ir);
                                for jr in (0..nc_block).step_by(NR_512) {
                                    let nr_block = NR_512.min(nc_block - jr);

                                    let a_panel = &packed_a[(ir / MR_512) * MR_512 * kc_block..];
                                    let b_panel = &packed_b[(jr / NR_512) * NR_512 * kc_block..];

                                    for jj in 0..nr_block {
                                        for ii in 0..mr_block {
                                            c_micro[jj * MR_512 + ii] =
                                                c[(ic + ir + ii) * n + (jc + jr + jj)];
                                        }
                                        for ii in mr_block..MR_512 {
                                            c_micro[jj * MR_512 + ii] = 0.0;
                                        }
                                    }
                                    for jj in nr_block..NR_512 {
                                        for ii in 0..MR_512 {
                                            c_micro[jj * MR_512 + ii] = 0.0;
                                        }
                                    }

                                    if mr_block == MR_512 && nr_block == NR_512 {
                                        unsafe {
                                            microkernel_16x8_avx512(
                                                kc_block,
                                                a_panel.as_ptr(),
                                                b_panel.as_ptr(),
                                                c_micro.as_mut_ptr(),
                                                MR_512,
                                            );
                                        }
                                    } else {
                                        for p in 0..kc_block {
                                            for jj in 0..NR_512 {
                                                let b_val = b_panel[p * NR_512 + jj];
                                                for ii in 0..MR_512 {
                                                    c_micro[jj * MR_512 + ii] +=
                                                        a_panel[p * MR_512 + ii] * b_val;
                                                }
                                            }
                                        }
                                    }

                                    for jj in 0..nr_block {
                                        for ii in 0..mr_block {
                                            c[(ic + ir + ii) * n + (jc + jr + jj)] =
                                                c_micro[jj * MR_512 + ii];
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });
        });
    });

    Ok(())
}
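
/// GEMM with a pre-packed B operand: B packing is hoisted out of the loop
/// nest entirely, and each `kc` x `nc` tile is fetched from `prepacked_b`
/// by its `(jc, pc)` block index. A is still packed per macro-block.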
pub fn gemm_blis_with_prepacked_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    prepacked_b: &PrepackedB,
    c: &mut [f32],
    mut profiler: Option<&mut BlisProfiler>,
) -> Result<(), TruenoError> {
    if a.len() != m * k {
        return Err(TruenoError::InvalidInput(format!(
            "A size mismatch: expected {}, got {}",
            m * k,
            a.len()
        )));
    }
    if c.len() != m * n {
        return Err(TruenoError::InvalidInput(format!(
            "C size mismatch: expected {}, got {}",
            m * n,
            c.len()
        )));
    }
    if prepacked_b.k != k || prepacked_b.n != n {
        return Err(TruenoError::InvalidInput(format!(
            "PrepackedB dimension mismatch: expected ({}, {}), got ({}, {})",
            k, n, prepacked_b.k, prepacked_b.n
        )));
    }

    if m == 0 || n == 0 || k == 0 {
        return Ok(());
    }

    let track_time = profiler.is_some();
    let start = if track_time { Some(Instant::now()) } else { None };

    let mc = MC.min(m);
    let kc = KC.min(k);

    let needed_a = packed_a_size(mc, kc);
    let needed_c = MR * NR;

    TL_PACKED_A.with(|tl_a| {
        TL_C_MICRO.with(|tl_c| {
            let mut packed_a = tl_a.borrow_mut();
            let mut c_micro = tl_c.borrow_mut();

            if packed_a.len() < needed_a {
                packed_a.resize(needed_a, 0.0);
            } else {
                packed_a[..needed_a].fill(0.0);
            }
            if c_micro.len() < needed_c {
                c_micro.resize(needed_c, 0.0);
            } else {
                c_micro[..needed_c].fill(0.0);
            }

            for (jc_idx, jc) in (0..n).step_by(NC).enumerate() {
                let nc_block = NC.min(n - jc);

                for (pc_idx, pc) in (0..k).step_by(KC).enumerate() {
                    let kc_block = KC.min(k - pc);

                    let packed_b_tile = prepacked_b.tile(jc_idx, pc_idx);

                    for ic in (0..m).step_by(MC) {
                        let mc_block = MC.min(m - ic);

                        let pack_start = if track_time { Some(Instant::now()) } else { None };
                        pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);
                        record_prof(&mut profiler, BlisProfileLevel::Pack, pack_start, 0);

                        compute_macroblock(
                            c,
                            &packed_a,
                            packed_b_tile,
                            &mut c_micro,
                            ic,
                            jc,
                            mc_block,
                            nc_block,
                            kc_block,
                            n,
                            &mut profiler,
                        );
                    }
                }
            }

            if let (Some(prof), Some(s)) = (profiler, start) {
                prof.record(
                    BlisProfileLevel::Macro,
                    s.elapsed().as_nanos() as u64,
                    (2 * m * n * k) as u64,
                );
            }
        });
    });

    Ok(())
}
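
// A minimal round-trip check, added as an illustrative sketch rather than a
// reconstruction of the original test suite. It assumes `gemm_reference`
// follows the same accumulate-into-C convention as `gemm_blis`, so both are
// run from a zeroed C to make the comparison convention-independent. The
// size (64, 48, 33) is large enough to bypass the tiny-problem reference
// shortcut and lands in the small-matrix dispatch on x86_64.
#[cfg(test)]
mod dispatch_sketch_tests {
    use super::*;

    #[test]
    fn gemm_blis_matches_reference() {
        let (m, n, k) = (64, 48, 33);
        // Deterministic, sign-varying inputs; no RNG dependency.
        let a: Vec<f32> = (0..m * k).map(|i| (i % 13) as f32 * 0.25 - 1.0).collect();
        let b: Vec<f32> = (0..k * n).map(|i| (i % 7) as f32 * 0.5 - 1.5).collect();

        let mut c_blis = vec![0.0f32; m * n];
        gemm_blis(m, n, k, &a, &b, &mut c_blis, None).unwrap();

        let mut c_ref = vec![0.0f32; m * n];
        gemm_reference(m, n, k, &a, &b, &mut c_ref).unwrap();

        for (i, (x, y)) in c_blis.iter().zip(c_ref.iter()).enumerate() {
            // Loose relative tolerance: FMA reassociation changes rounding.
            assert!(
                (x - y).abs() <= 1e-3 * y.abs().max(1.0),
                "mismatch at {i}: {x} vs {y}"
            );
        }
    }
}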