// trueno/blis/compute.rs

1// Rust 1.93+: BLIS microkernels use bare unsafe ops inside `unsafe fn`.
2// Wrapping each intrinsic in `unsafe {}` would add 100+ blocks with no safety benefit.
3#![allow(unsafe_op_in_unsafe_fn)]
4//! Core BLIS compute routines: microkernel dispatch, macroblock execution,
5//! and the cache-blocked GEMM main loop.
6//!
7//! Implements the 5-loop BLIS algorithm (Van Zee & Van de Geijn, 2015):
8//! - Loop 5 (jc): N dimension, L3 blocking
9//! - Loop 4 (pc): K dimension, L2 blocking
10//! - Loop 3 (ic): M dimension, L1 blocking
11//! - Loop 2 (jr): Microkernel columns
12//! - Loop 1 (ir): Microkernel rows
13
14use std::cell::RefCell;
15use std::time::Instant;
16
17use crate::error::TruenoError;
18
19#[cfg(target_arch = "x86_64")]
20use super::microkernels::microkernel_16x8_avx512;
21#[cfg(target_arch = "x86_64")]
22use super::microkernels::microkernel_8x6_true_asm;
23use super::microkernels::microkernel_scalar;
24use super::packing::{pack_a_block, pack_b_block, packed_a_size, packed_b_size};
25#[cfg(target_arch = "x86_64")]
26use super::packing::{pack_a_block_512, pack_b_block_512, packed_a_size_512, packed_b_size_512};
27use super::prepacked::PrepackedB;
28use super::profiler::{BlisProfileLevel, BlisProfiler};
29use super::reference::gemm_reference;
30use super::{KC, MC, MR, NC, NR};
31#[cfg(target_arch = "x86_64")]
32use super::{KC_512, MC_512, MR_512, NC_512, NR_512};
33
34// Thread-local workspace buffers to eliminate allocation churn in gemm_blis.
35// These grow to the high-water mark and are reused across calls, avoiding
36// ~4.3 MB of allocation+deallocation per GEMM invocation.
thread_local! {
    // Packed A panels (MC x KC, microkernel layout); grows to high-water mark.
    static TL_PACKED_A: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
    // Packed B panels (KC x NC, microkernel layout); grows to high-water mark.
    static TL_PACKED_B: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
    // MR x NR micro-tile accumulator workspace for C.
    static TL_C_MICRO: RefCell<Vec<f32>> = const { RefCell::new(Vec::new()) };
}
42
43/// Load a tile of C into the micro workspace for accumulation.
44#[inline(always)]
45fn load_c_tile(
46    c: &[f32],
47    c_micro: &mut [f32],
48    row: usize,
49    col: usize,
50    mr: usize,
51    nr: usize,
52    n: usize,
53) {
54    for jj in 0..nr {
55        for ii in 0..mr {
56            c_micro[jj * MR + ii] = c[(row + ii) * n + (col + jj)];
57        }
58        for ii in mr..MR {
59            c_micro[jj * MR + ii] = 0.0;
60        }
61    }
62    for jj in nr..NR {
63        for ii in 0..MR {
64            c_micro[jj * MR + ii] = 0.0;
65        }
66    }
67}
68
69/// Store a micro tile back into C.
70#[inline(always)]
71fn store_c_tile(
72    c: &mut [f32],
73    c_micro: &[f32],
74    row: usize,
75    col: usize,
76    mr: usize,
77    nr: usize,
78    n: usize,
79) {
80    for jj in 0..nr {
81        for ii in 0..mr {
82            c[(row + ii) * n + (col + jj)] = c_micro[jj * MR + ii];
83        }
84    }
85}
86
87/// Dispatch to the best available microkernel (AVX2 ASM or scalar fallback).
88#[inline(always)]
89fn dispatch_microkernel(
90    kc: usize,
91    a_panel: &[f32],
92    b_panel: &[f32],
93    c_micro: &mut [f32],
94    mr_block: usize,
95    nr_block: usize,
96) {
97    #[cfg(target_arch = "x86_64")]
98    {
99        if is_x86_feature_detected!("avx2")
100            && is_x86_feature_detected!("fma")
101            && mr_block == MR
102            && nr_block == NR
103        {
104            // SAFETY: AVX2+FMA verified by is_x86_feature_detected!() above.
105            unsafe {
106                microkernel_8x6_true_asm(
107                    kc,
108                    a_panel.as_ptr(),
109                    b_panel.as_ptr(),
110                    c_micro.as_mut_ptr(),
111                    MR,
112                );
113            }
114            return;
115        }
116    }
117    microkernel_scalar(kc, a_panel, b_panel, c_micro, MR);
118}
119
120/// Execute microkernel tile iterations over one MC x NC x KC macro-block.
121#[allow(clippy::too_many_arguments)]
122fn compute_macroblock(
123    c: &mut [f32],
124    packed_a: &[f32],
125    packed_b: &[f32],
126    c_micro: &mut [f32],
127    ic: usize,
128    jc: usize,
129    mc_block: usize,
130    nc_block: usize,
131    kc_block: usize,
132    n: usize,
133    profiler: &mut Option<&mut BlisProfiler>,
134) {
135    // KAIZEN-038: Avoid Instant::now() syscall (~20-40ns) when profiler is disabled.
136    // For 1024x1024 GEMM, this eliminates thousands of syscalls per macroblock.
137    let track_time = profiler.is_some();
138    let midi_start = if track_time { Some(Instant::now()) } else { None };
139
140    for ir in (0..mc_block).step_by(MR) {
141        let mr_block = MR.min(mc_block - ir);
142        for jr in (0..nc_block).step_by(NR) {
143            let nr_block = NR.min(nc_block - jr);
144            let micro_start = if track_time { Some(Instant::now()) } else { None };
145
146            let a_panel = &packed_a[(ir / MR) * MR * kc_block..];
147            let b_panel = &packed_b[(jr / NR) * NR * kc_block..];
148
149            load_c_tile(c, c_micro, ic + ir, jc + jr, mr_block, nr_block, n);
150            dispatch_microkernel(kc_block, a_panel, b_panel, c_micro, mr_block, nr_block);
151            store_c_tile(c, c_micro, ic + ir, jc + jr, mr_block, nr_block, n);
152
153            if let (Some(ref mut prof), Some(start)) = (profiler.as_deref_mut(), micro_start) {
154                prof.record(
155                    BlisProfileLevel::Micro,
156                    start.elapsed().as_nanos() as u64,
157                    (2 * mr_block * nr_block * kc_block) as u64,
158                );
159            }
160        }
161    }
162
163    if let (Some(ref mut prof), Some(start)) = (profiler.as_deref_mut(), midi_start) {
164        prof.record(
165            BlisProfileLevel::Midi,
166            start.elapsed().as_nanos() as u64,
167            (2 * mc_block * nc_block * kc_block) as u64,
168        );
169    }
170}
171
/// Zero-pack row-major GEMM for small matrices (≤128).
///
/// Key design: NO packing of A or B, NO C layout conversion.
/// - A: broadcast scalar elements directly from row-major layout
/// - B: SIMD load contiguous rows directly from row-major layout
/// - C: SIMD load/store of contiguous rows (row-major accumulation)
///
/// This eliminates ~2µs of overhead for 64×64 GEMM:
/// - No heap allocation for packed buffers
/// - No SIMD transpose for A packing
/// - No scalar C load/store (128 ops → 16 SIMD ops per tile)
///
/// Inner loop: for each K, load 1 B row (8 f32), broadcast 8 A elements,
/// execute 8 FMAs into 8 C row accumulators. 4-way K unrolled.
///
/// Reference: Goto & Van de Geijn (2008), "Anatomy of High-Performance
/// Matrix Multiplication" — panel-panel multiply with register blocking.
///
/// # Safety
///
/// - Caller must have verified AVX2+FMA support (`#[target_feature]` fn).
/// - The tile loops step by 8 with NO edge handling, so `m` and `n` must be
///   multiples of 8 — NOTE(review): presumably enforced by the dispatcher,
///   which is not visible in this chunk; confirm before relying on it.
/// - `a`, `b`, `c` must hold at least `m*k`, `k*n`, `m*n` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_direct_rowmajor(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use std::arch::x86_64::*;

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let c_ptr = c.as_mut_ptr();

    unsafe {
        for ir in (0..m).step_by(8) {
            for jr in (0..n).step_by(8) {
                // Load 8 rows of C (each row is contiguous in row-major)
                let c_base = c_ptr.add(ir * n + jr);
                let mut c0 = _mm256_loadu_ps(c_base);
                let mut c1 = _mm256_loadu_ps(c_base.add(n));
                let mut c2 = _mm256_loadu_ps(c_base.add(2 * n));
                let mut c3 = _mm256_loadu_ps(c_base.add(3 * n));
                let mut c4 = _mm256_loadu_ps(c_base.add(4 * n));
                let mut c5 = _mm256_loadu_ps(c_base.add(5 * n));
                let mut c6 = _mm256_loadu_ps(c_base.add(6 * n));
                let mut c7 = _mm256_loadu_ps(c_base.add(7 * n));

                // A row base pointers (stride = k between columns)
                let a0 = a_ptr.add(ir * k);
                let a1 = a_ptr.add((ir + 1) * k);
                let a2 = a_ptr.add((ir + 2) * k);
                let a3 = a_ptr.add((ir + 3) * k);
                let a4 = a_ptr.add((ir + 4) * k);
                let a5 = a_ptr.add((ir + 5) * k);
                let a6 = a_ptr.add((ir + 6) * k);
                let a7 = a_ptr.add((ir + 7) * k);

                // B base (stride = n between K rows)
                let b_base = b_ptr.add(jr);

                // 4-way K-unrolled main loop
                let k4 = k / 4;
                let k_rem = k % 4;

                for p4 in 0..k4 {
                    let p = p4 * 4;

                    // K+0
                    let b_row = _mm256_loadu_ps(b_base.add(p * n));
                    c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p)), b_row, c0);
                    c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p)), b_row, c1);
                    c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p)), b_row, c2);
                    c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p)), b_row, c3);
                    c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p)), b_row, c4);
                    c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p)), b_row, c5);
                    c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p)), b_row, c6);
                    c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p)), b_row, c7);

                    // K+1
                    let b_row = _mm256_loadu_ps(b_base.add((p + 1) * n));
                    c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p + 1)), b_row, c0);
                    c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p + 1)), b_row, c1);
                    c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p + 1)), b_row, c2);
                    c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p + 1)), b_row, c3);
                    c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p + 1)), b_row, c4);
                    c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p + 1)), b_row, c5);
                    c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p + 1)), b_row, c6);
                    c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p + 1)), b_row, c7);

                    // K+2
                    let b_row = _mm256_loadu_ps(b_base.add((p + 2) * n));
                    c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p + 2)), b_row, c0);
                    c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p + 2)), b_row, c1);
                    c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p + 2)), b_row, c2);
                    c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p + 2)), b_row, c3);
                    c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p + 2)), b_row, c4);
                    c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p + 2)), b_row, c5);
                    c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p + 2)), b_row, c6);
                    c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p + 2)), b_row, c7);

                    // K+3
                    let b_row = _mm256_loadu_ps(b_base.add((p + 3) * n));
                    c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(p + 3)), b_row, c0);
                    c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(p + 3)), b_row, c1);
                    c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(p + 3)), b_row, c2);
                    c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(p + 3)), b_row, c3);
                    c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(p + 3)), b_row, c4);
                    c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(p + 3)), b_row, c5);
                    c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(p + 3)), b_row, c6);
                    c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(p + 3)), b_row, c7);
                }

                // Remainder (k not divisible by 4)
                let base_rem = k4 * 4;
                for rp in 0..k_rem {
                    let pp = base_rem + rp;
                    let b_row = _mm256_loadu_ps(b_base.add(pp * n));
                    c0 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a0.add(pp)), b_row, c0);
                    c1 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a1.add(pp)), b_row, c1);
                    c2 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a2.add(pp)), b_row, c2);
                    c3 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a3.add(pp)), b_row, c3);
                    c4 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a4.add(pp)), b_row, c4);
                    c5 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a5.add(pp)), b_row, c5);
                    c6 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a6.add(pp)), b_row, c6);
                    c7 = _mm256_fmadd_ps(_mm256_broadcast_ss(&*a7.add(pp)), b_row, c7);
                }

                // Store 8 rows of C (contiguous SIMD stores)
                _mm256_storeu_ps(c_base, c0);
                _mm256_storeu_ps(c_base.add(n), c1);
                _mm256_storeu_ps(c_base.add(2 * n), c2);
                _mm256_storeu_ps(c_base.add(3 * n), c3);
                _mm256_storeu_ps(c_base.add(4 * n), c4);
                _mm256_storeu_ps(c_base.add(5 * n), c5);
                _mm256_storeu_ps(c_base.add(6 * n), c6);
                _mm256_storeu_ps(c_base.add(7 * n), c7);
            }
        }
    }
    Ok(())
}
313
/// Small-matrix stride-based GEMM — no packing, no c_micro buffer.
/// For m,n <= 96 where packing overhead > cache benefit.
///
/// Each MR×NR tile keeps NR column accumulators in YMM registers; edge tiles
/// (mr < MR or nr < NR) go through a small stack staging array.
///
/// # Safety
///
/// - Caller must have verified AVX2+FMA support (`#[target_feature]` fn).
/// - `a`, `b`, `c` must hold at least `m*k`, `k*n`, `m*n` elements
///   (indexing uses `get_unchecked`).
/// - `cv` holds exactly 6 accumulators, so NR must be ≤ 6 here —
///   NOTE(review): assumed NR == 6 per the "Unrolled FMA for NR=6" branch;
///   confirm against the constants in `super`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_small_strided_avx2(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use std::arch::x86_64::*;
    unsafe {
        for jr in (0..n).step_by(NR) {
            let nr = NR.min(n - jr);
            for ir in (0..m).step_by(MR) {
                let mr = MR.min(m - ir);
                let mut cv = [_mm256_setzero_ps(); 6];
                // Load the C tile into accumulators (gathered column-wise,
                // since C is row-major and each accumulator is one column).
                for j in 0..nr {
                    if mr == MR {
                        cv[j] = _mm256_set_ps(
                            *c.get_unchecked((ir + 7) * n + jr + j),
                            *c.get_unchecked((ir + 6) * n + jr + j),
                            *c.get_unchecked((ir + 5) * n + jr + j),
                            *c.get_unchecked((ir + 4) * n + jr + j),
                            *c.get_unchecked((ir + 3) * n + jr + j),
                            *c.get_unchecked((ir + 2) * n + jr + j),
                            *c.get_unchecked((ir + 1) * n + jr + j),
                            *c.get_unchecked(ir * n + jr + j),
                        );
                    } else {
                        // Edge tile: stage through a zero-padded stack array.
                        let mut t = [0.0f32; 8];
                        for i in 0..mr {
                            t[i] = *c.get_unchecked((ir + i) * n + jr + j);
                        }
                        cv[j] = _mm256_loadu_ps(t.as_ptr());
                    }
                }
                // Rank-1 update per k step: one A column vs one B row slice.
                for p in 0..k {
                    let a_col = if mr == MR {
                        _mm256_set_ps(
                            *a.get_unchecked((ir + 7) * k + p),
                            *a.get_unchecked((ir + 6) * k + p),
                            *a.get_unchecked((ir + 5) * k + p),
                            *a.get_unchecked((ir + 4) * k + p),
                            *a.get_unchecked((ir + 3) * k + p),
                            *a.get_unchecked((ir + 2) * k + p),
                            *a.get_unchecked((ir + 1) * k + p),
                            *a.get_unchecked(ir * k + p),
                        )
                    } else {
                        let mut t = [0.0f32; 8];
                        for i in 0..mr {
                            t[i] = *a.get_unchecked((ir + i) * k + p);
                        }
                        _mm256_loadu_ps(t.as_ptr())
                    };
                    let bp = b.as_ptr().add(p * n + jr);
                    // Unrolled FMA for NR=6 common case
                    if nr == NR {
                        cv[0] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp), cv[0]);
                        cv[1] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(1)), cv[1]);
                        cv[2] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(2)), cv[2]);
                        cv[3] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(3)), cv[3]);
                        cv[4] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(4)), cv[4]);
                        cv[5] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(5)), cv[5]);
                    } else {
                        for j in 0..nr {
                            cv[j] = _mm256_fmadd_ps(a_col, _mm256_set1_ps(*bp.add(j)), cv[j]);
                        }
                    }
                }
                // Scatter accumulators back to row-major C via a stack array.
                for j in 0..nr {
                    let mut t = [0.0f32; 8];
                    _mm256_storeu_ps(t.as_mut_ptr(), cv[j]);
                    for i in 0..mr {
                        *c.get_unchecked_mut((ir + i) * n + jr + j) = t[i];
                    }
                }
            }
        }
    }
    Ok(())
}
399
/// 8x8 GEMM: pre-packed B, SIMD transpose A packing, K-unrolled micro-kernel.
/// For m,n divisible by 8 and m,n,k ≤ 256.
///
/// Key optimization: B panels are packed ONCE before the tile loop, eliminating
/// panels_m redundant repacking passes. For 128x128 this removes 15/16 = 94%
/// of B packing work.
///
/// # Safety
///
/// - Caller must have verified AVX2+FMA support (`#[target_feature]` fn).
/// - `m` and `n` must be exact multiples of 8 (no edge tiles are handled).
/// - `k` must be ≤ 256: the A panel is a fixed `[f32; 8 * 256]` stack buffer.
/// - `a`, `b`, `c` must hold at least `m*k`, `k*n`, `m*n` elements
///   (indexing uses raw pointers and `get_unchecked`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_small_nopack_8x8(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use crate::blis::microkernels::microkernel_8x8_avx2_fma;
    use std::arch::x86_64::*;

    let panels_m = m / 8;
    let panels_n = n / 8;

    // Stack buffer for one A panel (8×K col-major).
    let mut packed_a = [0.0f32; 8 * 256];
    let mut c_micro = [0.0f32; 64];

    // Pre-pack ALL B panels once (eliminates panels_m redundant repacking).
    // For 256x256: 32 panels × 256 × 8 = 256KB — acceptable heap alloc.
    let mut all_packed_b = vec![0.0f32; panels_n * k * 8];

    unsafe {
        // B is row-major, so each 8-wide panel row is already contiguous:
        // packing is a straight SIMD copy, no transpose needed.
        for jr_panel in 0..panels_n {
            let jr = jr_panel * 8;
            let b_dst = all_packed_b.as_mut_ptr().add(jr_panel * k * 8);
            for p in 0..k {
                _mm256_storeu_ps(b_dst.add(p * 8), _mm256_loadu_ps(b.as_ptr().add(p * n + jr)));
            }
        }

        for ir_panel in 0..panels_m {
            let ir = ir_panel * 8;

            // Pack A panel: SIMD 8×8 transpose blocks (row-major → col-major)
            let k_blocks = k / 8;
            let k_rem = k_blocks * 8;
            for kb in 0..k_blocks {
                let p = kb * 8;
                let r0 = _mm256_loadu_ps(a.as_ptr().add(ir * k + p));
                let r1 = _mm256_loadu_ps(a.as_ptr().add((ir + 1) * k + p));
                let r2 = _mm256_loadu_ps(a.as_ptr().add((ir + 2) * k + p));
                let r3 = _mm256_loadu_ps(a.as_ptr().add((ir + 3) * k + p));
                let r4 = _mm256_loadu_ps(a.as_ptr().add((ir + 4) * k + p));
                let r5 = _mm256_loadu_ps(a.as_ptr().add((ir + 5) * k + p));
                let r6 = _mm256_loadu_ps(a.as_ptr().add((ir + 6) * k + p));
                let r7 = _mm256_loadu_ps(a.as_ptr().add((ir + 7) * k + p));

                // Standard 8×8 in-register transpose:
                // unpack (interleave pairs) → shuffle (4-wide) → permute (lanes).
                let t0 = _mm256_unpacklo_ps(r0, r1);
                let t1 = _mm256_unpackhi_ps(r0, r1);
                let t2 = _mm256_unpacklo_ps(r2, r3);
                let t3 = _mm256_unpackhi_ps(r2, r3);
                let t4 = _mm256_unpacklo_ps(r4, r5);
                let t5 = _mm256_unpackhi_ps(r4, r5);
                let t6 = _mm256_unpacklo_ps(r6, r7);
                let t7 = _mm256_unpackhi_ps(r6, r7);

                let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
                let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
                let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
                let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
                let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
                let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
                let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
                let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);

                let dst = packed_a.as_mut_ptr().add(p * 8);
                _mm256_storeu_ps(dst, _mm256_permute2f128_ps(u0, u4, 0x20));
                _mm256_storeu_ps(dst.add(8), _mm256_permute2f128_ps(u1, u5, 0x20));
                _mm256_storeu_ps(dst.add(16), _mm256_permute2f128_ps(u2, u6, 0x20));
                _mm256_storeu_ps(dst.add(24), _mm256_permute2f128_ps(u3, u7, 0x20));
                _mm256_storeu_ps(dst.add(32), _mm256_permute2f128_ps(u0, u4, 0x31));
                _mm256_storeu_ps(dst.add(40), _mm256_permute2f128_ps(u1, u5, 0x31));
                _mm256_storeu_ps(dst.add(48), _mm256_permute2f128_ps(u2, u6, 0x31));
                _mm256_storeu_ps(dst.add(56), _mm256_permute2f128_ps(u3, u7, 0x31));
            }
            // Remainder k columns (k not divisible by 8): scalar transpose.
            for p in k_rem..k {
                for i in 0..8 {
                    *packed_a.get_unchecked_mut(p * 8 + i) = *a.get_unchecked((ir + i) * k + p);
                }
            }

            for jr_panel in 0..panels_n {
                let jr = jr_panel * 8;
                let packed_b_ptr = all_packed_b.as_ptr().add(jr_panel * k * 8);

                // Load C tile (column-major for micro-kernel)
                for jj in 0..8 {
                    for ii in 0..8 {
                        c_micro[jj * 8 + ii] = *c.get_unchecked((ir + ii) * n + jr + jj);
                    }
                }

                microkernel_8x8_avx2_fma(
                    k,
                    packed_a.as_ptr(),
                    packed_b_ptr,
                    c_micro.as_mut_ptr(),
                    8,
                );

                // Store C tile back to row-major
                for jj in 0..8 {
                    for ii in 0..8 {
                        *c.get_unchecked_mut((ir + ii) * n + jr + jj) = c_micro[jj * 8 + ii];
                    }
                }
            }
        }
    }
    Ok(())
}
520
/// Small-matrix 8x8 GEMM — heap-packed A/B, striped 8x8 AVX2 kernel.
/// Fewer tiles than 8x6 (64 outputs vs 48 per tile = 33% fewer tiles).
/// Handles arbitrary m,n,k via zero-padded partial panels.
///
/// # Safety
///
/// - Caller must have verified AVX2+FMA support (`#[target_feature]` fn).
/// - `a`, `b`, `c` must hold at least `m*k`, `k*n`, `m*n` elements
///   (indexing uses `get_unchecked`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[allow(dead_code)] // Superseded by gemm_small_nopack_8x8 but retained for profiling comparison
unsafe fn gemm_small_8x8(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use crate::blis::microkernels::microkernel_8x8_avx2_fma;
    // Heap pack buffers (vec! allocations; max 128*128 = 64KB each for 128x128)
    let mut packed_a = vec![0.0f32; m * k];
    let mut packed_b = vec![0.0f32; k * n];
    let mut c_micro = [0.0f32; 8 * 8]; // 8x8 tile

    // Pack A: row-major a[i*k+p] → column-major packed_a[p*8 + (i%8)] per 8-row panel
    // Rows beyond m are zero-filled so edge panels are safe for the kernel.
    let panels_m = (m + 7) / 8;
    for panel in 0..panels_m {
        let ir = panel * 8;
        let mr = 8.min(m - ir);
        for p in 0..k {
            for i in 0..8 {
                unsafe {
                    packed_a[panel * 8 * k + p * 8 + i] =
                        if i < mr { *a.get_unchecked((ir + i) * k + p) } else { 0.0 };
                }
            }
        }
    }
    // Pack B: row-major b[p*n+j] → row-major packed_b[panel*8*k + p*8 + (j%8)]
    // Columns beyond n are zero-filled.
    let panels_n = (n + 7) / 8;
    for panel in 0..panels_n {
        let jr = panel * 8;
        let nr = 8.min(n - jr);
        for p in 0..k {
            for j in 0..8 {
                unsafe {
                    packed_b[panel * 8 * k + p * 8 + j] =
                        if j < nr { *b.get_unchecked(p * n + jr + j) } else { 0.0 };
                }
            }
        }
    }

    // Run 8x8 micro-tiles
    unsafe {
        for ir_panel in 0..panels_m {
            let ir = ir_panel * 8;
            let mr = 8.min(m - ir);
            for jr_panel in 0..panels_n {
                let jr = jr_panel * 8;
                let nr = 8.min(n - jr);
                // Load C tile (8x8 column-major); padding lanes zeroed.
                for jj in 0..8 {
                    for ii in 0..8 {
                        c_micro[jj * 8 + ii] = if ii < mr && jj < nr {
                            *c.get_unchecked((ir + ii) * n + jr + jj)
                        } else {
                            0.0
                        };
                    }
                }
                let ap = packed_a.as_ptr().add(ir_panel * 8 * k);
                let bp = packed_b.as_ptr().add(jr_panel * 8 * k);
                microkernel_8x8_avx2_fma(k, ap, bp, c_micro.as_mut_ptr(), 8);
                // Store C tile (only the live mr×nr region; padding discarded)
                for jj in 0..nr {
                    for ii in 0..mr {
                        *c.get_unchecked_mut((ir + ii) * n + jr + jj) = c_micro[jj * 8 + ii];
                    }
                }
            }
        }
    }
    Ok(())
}
602
603/// AVX-512 small GEMM: 16×8 tiles, pre-packed B, scalar-packed A.
604/// For m divisible by 16, n divisible by 8, m,n,k ≤ 256.
605/// 2× vector width over AVX2 gives ~2× throughput on compute-bound tiles.
606#[cfg(target_arch = "x86_64")]
607#[target_feature(enable = "avx512f")]
608unsafe fn gemm_small_avx512_16x8(
609    m: usize,
610    n: usize,
611    k: usize,
612    a: &[f32],
613    b: &[f32],
614    c: &mut [f32],
615) -> Result<(), TruenoError> {
616    use super::microkernels::microkernel_16x8_avx512;
617    use std::arch::x86_64::*;
618
619    let panels_m = m / 16;
620    let panels_n = n / 8;
621
622    // Pre-pack ALL B panels once (SIMD 8-wide contiguous copies).
623    let mut all_packed_b = vec![0.0f32; panels_n * k * 8];
624    unsafe {
625        for jr_panel in 0..panels_n {
626            let jr = jr_panel * 8;
627            let b_dst = all_packed_b.as_mut_ptr().add(jr_panel * k * 8);
628            for p in 0..k {
629                _mm256_storeu_ps(b_dst.add(p * 8), _mm256_loadu_ps(b.as_ptr().add(p * n + jr)));
630            }
631        }
632    }
633
634    // Stack buffers — A panel: 16×K column-major, C micro tile: 16×8.
635    let mut packed_a = [0.0f32; 16 * 256];
636    let mut c_micro = [0.0f32; 16 * 8];
637
638    unsafe {
639        for ir_panel in 0..panels_m {
640            let ir = ir_panel * 16;
641
642            // Pack A: row-major → column-major via two 8×8 SIMD transposes.
643            let k_blocks = k / 8;
644            let k_rem_start = k_blocks * 8;
645
646            for kb in 0..k_blocks {
647                let p = kb * 8;
648
649                // Upper 8 rows
650                let r0 = _mm256_loadu_ps(a.as_ptr().add(ir * k + p));
651                let r1 = _mm256_loadu_ps(a.as_ptr().add((ir + 1) * k + p));
652                let r2 = _mm256_loadu_ps(a.as_ptr().add((ir + 2) * k + p));
653                let r3 = _mm256_loadu_ps(a.as_ptr().add((ir + 3) * k + p));
654                let r4 = _mm256_loadu_ps(a.as_ptr().add((ir + 4) * k + p));
655                let r5 = _mm256_loadu_ps(a.as_ptr().add((ir + 5) * k + p));
656                let r6 = _mm256_loadu_ps(a.as_ptr().add((ir + 6) * k + p));
657                let r7 = _mm256_loadu_ps(a.as_ptr().add((ir + 7) * k + p));
658
659                let t0 = _mm256_unpacklo_ps(r0, r1);
660                let t1 = _mm256_unpackhi_ps(r0, r1);
661                let t2 = _mm256_unpacklo_ps(r2, r3);
662                let t3 = _mm256_unpackhi_ps(r2, r3);
663                let t4 = _mm256_unpacklo_ps(r4, r5);
664                let t5 = _mm256_unpackhi_ps(r4, r5);
665                let t6 = _mm256_unpacklo_ps(r6, r7);
666                let t7 = _mm256_unpackhi_ps(r6, r7);
667
668                let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
669                let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
670                let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
671                let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
672                let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
673                let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
674                let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
675                let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);
676
677                // Direct stores — stride 16 between K columns
678                let dst = packed_a.as_mut_ptr().add(p * 16);
679                _mm256_storeu_ps(dst, _mm256_permute2f128_ps(u0, u4, 0x20));
680                _mm256_storeu_ps(dst.add(16), _mm256_permute2f128_ps(u1, u5, 0x20));
681                _mm256_storeu_ps(dst.add(32), _mm256_permute2f128_ps(u2, u6, 0x20));
682                _mm256_storeu_ps(dst.add(48), _mm256_permute2f128_ps(u3, u7, 0x20));
683                _mm256_storeu_ps(dst.add(64), _mm256_permute2f128_ps(u0, u4, 0x31));
684                _mm256_storeu_ps(dst.add(80), _mm256_permute2f128_ps(u1, u5, 0x31));
685                _mm256_storeu_ps(dst.add(96), _mm256_permute2f128_ps(u2, u6, 0x31));
686                _mm256_storeu_ps(dst.add(112), _mm256_permute2f128_ps(u3, u7, 0x31));
687
688                // Lower 8 rows
689                let r0 = _mm256_loadu_ps(a.as_ptr().add((ir + 8) * k + p));
690                let r1 = _mm256_loadu_ps(a.as_ptr().add((ir + 9) * k + p));
691                let r2 = _mm256_loadu_ps(a.as_ptr().add((ir + 10) * k + p));
692                let r3 = _mm256_loadu_ps(a.as_ptr().add((ir + 11) * k + p));
693                let r4 = _mm256_loadu_ps(a.as_ptr().add((ir + 12) * k + p));
694                let r5 = _mm256_loadu_ps(a.as_ptr().add((ir + 13) * k + p));
695                let r6 = _mm256_loadu_ps(a.as_ptr().add((ir + 14) * k + p));
696                let r7 = _mm256_loadu_ps(a.as_ptr().add((ir + 15) * k + p));
697
698                let t0 = _mm256_unpacklo_ps(r0, r1);
699                let t1 = _mm256_unpackhi_ps(r0, r1);
700                let t2 = _mm256_unpacklo_ps(r2, r3);
701                let t3 = _mm256_unpackhi_ps(r2, r3);
702                let t4 = _mm256_unpacklo_ps(r4, r5);
703                let t5 = _mm256_unpackhi_ps(r4, r5);
704                let t6 = _mm256_unpacklo_ps(r6, r7);
705                let t7 = _mm256_unpackhi_ps(r6, r7);
706
707                let u0 = _mm256_shuffle_ps(t0, t2, 0x44);
708                let u1 = _mm256_shuffle_ps(t0, t2, 0xEE);
709                let u2 = _mm256_shuffle_ps(t1, t3, 0x44);
710                let u3 = _mm256_shuffle_ps(t1, t3, 0xEE);
711                let u4 = _mm256_shuffle_ps(t4, t6, 0x44);
712                let u5 = _mm256_shuffle_ps(t4, t6, 0xEE);
713                let u6 = _mm256_shuffle_ps(t5, t7, 0x44);
714                let u7 = _mm256_shuffle_ps(t5, t7, 0xEE);
715
716                // Lower rows at +8 offset
717                let dst_lo = packed_a.as_mut_ptr().add(p * 16 + 8);
718                _mm256_storeu_ps(dst_lo, _mm256_permute2f128_ps(u0, u4, 0x20));
719                _mm256_storeu_ps(dst_lo.add(16), _mm256_permute2f128_ps(u1, u5, 0x20));
720                _mm256_storeu_ps(dst_lo.add(32), _mm256_permute2f128_ps(u2, u6, 0x20));
721                _mm256_storeu_ps(dst_lo.add(48), _mm256_permute2f128_ps(u3, u7, 0x20));
722                _mm256_storeu_ps(dst_lo.add(64), _mm256_permute2f128_ps(u0, u4, 0x31));
723                _mm256_storeu_ps(dst_lo.add(80), _mm256_permute2f128_ps(u1, u5, 0x31));
724                _mm256_storeu_ps(dst_lo.add(96), _mm256_permute2f128_ps(u2, u6, 0x31));
725                _mm256_storeu_ps(dst_lo.add(112), _mm256_permute2f128_ps(u3, u7, 0x31));
726            }
727
728            // Remainder k columns: scalar pack
729            for p in k_rem_start..k {
730                for i in 0..16 {
731                    *packed_a.get_unchecked_mut(p * 16 + i) = *a.get_unchecked((ir + i) * k + p);
732                }
733            }
734
735            for jr_panel in 0..panels_n {
736                let jr = jr_panel * 8;
737                let packed_b_ptr = all_packed_b.as_ptr().add(jr_panel * k * 8);
738
739                // Load C tile (column-major for micro-kernel: c_micro[j*16+i])
740                for jj in 0..8 {
741                    for ii in 0..16 {
742                        c_micro[jj * 16 + ii] = *c.get_unchecked((ir + ii) * n + jr + jj);
743                    }
744                }
745
746                microkernel_16x8_avx512(
747                    k,
748                    packed_a.as_ptr(),
749                    packed_b_ptr,
750                    c_micro.as_mut_ptr(),
751                    16,
752                );
753
754                // Store C tile back to row-major
755                for jj in 0..8 {
756                    for ii in 0..16 {
757                        *c.get_unchecked_mut((ir + ii) * n + jr + jj) = c_micro[jj * 16 + ii];
758                    }
759                }
760            }
761        }
762    }
763    Ok(())
764}
765
766/// Validate GEMM dimension inputs (Poka-yoke).
767fn validate_gemm_dims(
768    m: usize,
769    n: usize,
770    k: usize,
771    a: &[f32],
772    b: &[f32],
773    c: &[f32],
774) -> Result<(), TruenoError> {
775    if a.len() != m * k {
776        return Err(TruenoError::InvalidInput(format!(
777            "A size mismatch: expected {}, got {}",
778            m * k,
779            a.len()
780        )));
781    }
782    if b.len() != k * n {
783        return Err(TruenoError::InvalidInput(format!(
784            "B size mismatch: expected {}, got {}",
785            k * n,
786            b.len()
787        )));
788    }
789    if c.len() != m * n {
790        return Err(TruenoError::InvalidInput(format!(
791            "C size mismatch: expected {}, got {}",
792            m * n,
793            c.len()
794        )));
795    }
796    Ok(())
797}
798
799/// Record a profiler event if profiler is active.
800#[inline(always)]
801fn record_prof(
802    profiler: &mut Option<&mut BlisProfiler>,
803    level: BlisProfileLevel,
804    start: Option<Instant>,
805    flops: u64,
806) {
807    if let (Some(ref mut prof), Some(s)) = (profiler.as_deref_mut(), start) {
808        prof.record(level, s.elapsed().as_nanos() as u64, flops);
809    }
810}
811
/// BLIS-style blocked GEMM: accumulates `C += A * B` for row-major f32
/// matrices (`a` is m×k, `b` is k×n, `c` is m×n, densely packed).
///
/// Implements the 5-loop BLIS algorithm (Van Zee & Van de Geijn, 2015):
/// Loop 5 (jc): N dimension, L3 blocking
/// Loop 4 (pc): K dimension, L2 blocking
/// Loop 3 (ic): M dimension, L1 blocking
/// Loop 2 (jr): Microkernel columns
/// Loop 1 (ir): Microkernel rows
///
/// Before the generic packed loop below, dispatches to faster paths in
/// order: naive reference loop for tiny problems, small-matrix kernels
/// (≤256 per dim, profiler inactive, AVX2+FMA), AVX-512 large BLIS path,
/// and an AVX2 NR=8 row-major-C fallback.
///
/// # Errors
/// Returns `TruenoError::InvalidInput` when a slice length does not match
/// its declared m/n/k dimensions (see `validate_gemm_dims`).
pub fn gemm_blis(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    mut profiler: Option<&mut BlisProfiler>,
) -> Result<(), TruenoError> {
    contract_pre_flops_per_tile!();
    validate_gemm_dims(m, n, k, a, b, c)?;

    // Degenerate dimensions: nothing to accumulate.
    if m == 0 || n == 0 || k == 0 {
        return Ok(());
    }
    // Tiny problems: packing overhead dwarfs the FLOPs; use the naive loop.
    if m * n * k < 4096 {
        return gemm_reference(m, n, k, a, b, c);
    }

    // Small: optimized GEMM paths (skip when profiler active).
    #[cfg(target_arch = "x86_64")]
    if profiler.is_none()
        && m <= 256
        && n <= 256
        && k <= 256
        && is_x86_feature_detected!("avx2")
        && is_x86_feature_detected!("fma")
    {
        // SAFETY: AVX2+FMA verified above; the AVX-512 branch is further
        // gated on avx512f. Slice lengths validated by validate_gemm_dims.
        unsafe {
            // Zero-pack row-major GEMM for ≤128: no packing, no C transpose.
            if m <= 128 && n <= 128 && m % 8 == 0 && n % 8 == 0 {
                return gemm_direct_rowmajor(m, n, k, a, b, c);
            }
            // AVX-512 for 129-256: 16×8 tiles, pre-packed B.
            if is_x86_feature_detected!("avx512f") && m >= 16 && m % 16 == 0 && n % 8 == 0 {
                return gemm_small_avx512_16x8(m, n, k, a, b, c);
            }
            if m >= MR && m % 8 == 0 && n % 8 == 0 {
                return gemm_small_nopack_8x8(m, n, k, a, b, c);
            }
            return gemm_small_strided_avx2(m, n, k, a, b, c);
        }
    }

    // AVX-512 BLIS: MR=8, NR=16 using zmm registers (2× throughput vs AVX2).
    // This closes the gap with OpenBLAS which uses AVX-512 on Zen 4.
    // CRITICAL: Without this, trueno is 0.49x NumPy at 8T (shipping blocker).
    // Contract: avx512-blis-v1.yaml (C-AVX512-BLIS-001, C-AVX512-PROF-001)
    #[cfg(target_arch = "x86_64")]
    if is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("fma") {
        return unsafe { gemm_blis_avx512_large(m, n, k, a, b, c, &mut profiler) };
    }

    // NR=8 BLIS with row-major C SIMD load/store (AVX2 fallback).
    #[cfg(target_arch = "x86_64")]
    if profiler.is_none() && is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
        return unsafe { gemm_blis_nr8_rowmajor_c(m, n, k, a, b, c) };
    }

    // KAIZEN-038: Only call Instant::now() when profiler is active
    let track_time = profiler.is_some();
    let start = if track_time { Some(Instant::now()) } else { None };

    // Clamp blocking parameters to the problem size.
    let mc = MC.min(m);
    let nc = NC.min(n);
    let kc = KC.min(k);

    let needed_a = packed_a_size(mc, kc);
    let needed_b = packed_b_size(kc, nc);
    let needed_c = MR * NR;

    // Borrow thread-local workspace buffers, growing if necessary.
    // This eliminates ~4.3 MB of allocation churn per gemm_blis call.
    TL_PACKED_A.with(|tl_a| {
        TL_PACKED_B.with(|tl_b| {
            TL_C_MICRO.with(|tl_c| {
                let mut packed_a = tl_a.borrow_mut();
                let mut packed_b = tl_b.borrow_mut();
                let mut c_micro = tl_c.borrow_mut();

                // Grow buffers to required size (high-water mark).
                // Zero-fill to match the semantics of the original vec![0.0; N].
                if packed_a.len() < needed_a {
                    packed_a.resize(needed_a, 0.0);
                }
                if packed_b.len() < needed_b {
                    packed_b.resize(needed_b, 0.0);
                }
                if c_micro.len() < needed_c {
                    c_micro.resize(needed_c, 0.0);
                }

                // Loop 5 (jc): column blocks of C/B, L3-resident.
                for jc in (0..n).step_by(NC) {
                    let nc_block = NC.min(n - jc);

                    // Loop 4 (pc): K blocks; B panel packed once per (jc, pc).
                    for pc in (0..k).step_by(KC) {
                        let kc_block = KC.min(k - pc);

                        let pack_start = if track_time { Some(Instant::now()) } else { None };
                        pack_b_block(b, n, pc, jc, kc_block, nc_block, &mut packed_b);
                        record_prof(&mut profiler, BlisProfileLevel::Pack, pack_start, 0);

                        // Loop 3 (ic): row blocks; A panel packed per (ic, pc).
                        for ic in (0..m).step_by(MC) {
                            let mc_block = MC.min(m - ic);

                            let pack_start = if track_time { Some(Instant::now()) } else { None };
                            pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);
                            record_prof(&mut profiler, BlisProfileLevel::Pack, pack_start, 0);

                            // Loops 2+1 (jr/ir) live inside compute_macroblock.
                            compute_macroblock(
                                c,
                                &packed_a,
                                &packed_b,
                                &mut c_micro,
                                ic,
                                jc,
                                mc_block,
                                nc_block,
                                kc_block,
                                n,
                                &mut profiler,
                            );
                        }
                    }
                }

                // Whole-GEMM timing event: 2*m*n*k FLOPs (one FMA = 2 FLOPs).
                if let (Some(prof), Some(s)) = (profiler, start) {
                    prof.record(
                        BlisProfileLevel::Macro,
                        s.elapsed().as_nanos() as u64,
                        (2 * m * n * k) as u64,
                    );
                }
            });
        });
    });

    contract_post_flops_per_tile!(c);
    Ok(())
}
960
/// AVX-512 BLIS 5-loop GEMM — MR=8, NR ∈ {16, 32} using zmm registers,
/// with BlisProfiler support.
///
/// 2× throughput vs AVX2 path: each C row = 16 f32 = 1 zmm register.
/// 8 zmm accumulators (8 rows × 16 cols = 128 elements), well within
/// AVX-512's 32-register budget. B loaded as zmm (16 f32), A broadcast.
///
/// Blocking parameters come from `cache_topology` (NR=32 tier when n ≥ 32,
/// NR=16 otherwise).
/// A packing: MR=8 column-major panels (reuses pack_a_block).
/// B packing: NR-wide row-major panels.
/// Contract: avx512-blis-v1.yaml (C-AVX512-BLIS-001, C-AVX512-PROF-001)
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F and FMA, and that the
/// slice lengths match m/n/k (enforced by the public `gemm_blis` entry).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
unsafe fn gemm_blis_avx512_large(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
    profiler: &mut Option<&mut BlisProfiler>,
) -> Result<(), TruenoError> {
    // KAIZEN-038: Only call Instant::now() when profiler is active
    let track_time = profiler.is_some();
    let start = if track_time { Some(Instant::now()) } else { None };

    // Phase 4-6: tiered NR selection with dynamic cache blocking.
    // Contract: cgp-dynamic-cache-v1.yaml, cgp-gemm-codegen-v1.yaml
    // NR=48 (codegen): 24 FMA/K-step, KC=128 (L1-limited)
    // NR=32 (hand-written): 16 FMA/K-step, KC=256
    // NR=16 (hand-written): 8 FMA/K-step, KC=256
    // NEGATIVE RESULT (2026-04-05): NR=48 codegen with KC=128 regressed:
    //   512: 41 → 135 GFLOPS after reverting to NR=32
    //   1024: 85 → 130 GFLOPS after reverting to NR=32
    // Root cause: KC halved (128 vs 256) → 2× more K-loop packing passes.
    // The 24 FMA/K-step doesn't compensate for the packing overhead increase.
    // NR=48 path disabled pending KC optimization (prefetch, double-buffer).
    let blk = if n >= 32 {
        super::cache_topology::blocking_8x32()
    } else {
        super::cache_topology::blocking_8x16()
    };
    let mr = blk.mr;
    let nr = blk.nr;
    let mc = blk.mc.min(m);
    let nc = blk.nc.min(n);
    let kc_param = blk.kc;

    // Reuse the same thread-local workspaces as gemm_blis (grown, never shrunk).
    TL_PACKED_A.with(|tl_a| {
        TL_PACKED_B.with(|tl_b| {
            let mut packed_a = tl_a.borrow_mut();
            let mut packed_b = tl_b.borrow_mut();

            let needed_a = packed_a_size(mc, kc_param);
            // packed B: panels * nr * kc, where panels rounds up nc/nr
            let b_panels = (nc + nr - 1) / nr;
            let needed_b = b_panels * nr * kc_param;
            if packed_a.len() < needed_a {
                packed_a.resize(needed_a, 0.0);
            }
            if packed_b.len() < needed_b {
                packed_b.resize(needed_b, 0.0);
            }

            // Loop 5 (jc): N blocks.
            for jc in (0..n).step_by(nc) {
                let nc_block = nc.min(n - jc);

                // Loop 4 (pc): K blocks.
                for pc in (0..k).step_by(kc_param) {
                    let kc_block = kc_param.min(k - pc);

                    // Pack B with NR matching the selected microkernel
                    if nr == 48 {
                        pack_b_block_generic(b, n, pc, jc, kc_block, nc_block, 48, &mut packed_b);
                    } else if nr == 32 {
                        pack_b_block_generic(b, n, pc, jc, kc_block, nc_block, 32, &mut packed_b);
                    } else {
                        pack_b_block_nr16(b, n, pc, jc, kc_block, nc_block, &mut packed_b);
                    }

                    // Loop 3 (ic): M blocks.
                    for ic in (0..m).step_by(mc) {
                        let mc_block = mc.min(m - ic);

                        pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);

                        let panels_m = (mc_block + mr - 1) / mr;
                        let panels_n = (nc_block + nr - 1) / nr;

                        // Loops 2+1 (ir/jr): microkernel tiles over the macroblock.
                        for ir_panel in 0..panels_m {
                            let ir = ir_panel * mr;
                            let mr_block = mr.min(mc_block - ir);

                            for jr_panel in 0..panels_n {
                                let jr = jr_panel * nr;
                                let nr_block = nr.min(nc_block - jr);

                                // Each packed panel is mr*kc (A) / nr*kc (B) contiguous f32.
                                let a_panel = &packed_a[ir_panel * mr * kc_block..];
                                let b_panel = &packed_b[jr_panel * nr * kc_block..];

                                if mr_block == 8 && nr_block == 48 && nr == 48 {
                                    // Full 8×48 codegen tile (Phase 6: 24 accumulators)
                                    // Contract: cgp-gemm-codegen-v1.yaml C-CODEGEN-002
                                    unsafe {
                                        super::microkernels::codegen::microkernel_8x48_avx512_gen(
                                            kc_block,
                                            a_panel.as_ptr(),
                                            b_panel.as_ptr(),
                                            c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
                                            n,
                                        );
                                    }
                                } else if mr_block == 8 && nr_block == 32 && nr == 32 {
                                    // Full 8×32 AVX-512 tile (Phase 4: 16 accumulators)
                                    unsafe {
                                        avx512_microkernel_8x32_rowmajor(
                                            kc_block,
                                            a_panel.as_ptr(),
                                            b_panel.as_ptr(),
                                            c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
                                            n,
                                        );
                                    }
                                } else if mr_block == 8 && nr_block == 16 && nr == 16 {
                                    // Full 8×16 AVX-512 tile (original path)
                                    unsafe {
                                        avx512_microkernel_8x16_rowmajor(
                                            kc_block,
                                            a_panel.as_ptr(),
                                            b_panel.as_ptr(),
                                            c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
                                            n,
                                        );
                                    }
                                } else {
                                    // Scalar fallback for edge tiles
                                    for ir_local in 0..mr_block {
                                        for jr_local in 0..nr_block {
                                            let mut sum =
                                                c[(ic + ir + ir_local) * n + (jc + jr + jr_local)];
                                            for p in 0..kc_block {
                                                sum += a_panel[p * mr + ir_local]
                                                    * b_panel[p * nr + jr_local];
                                            }
                                            c[(ic + ir + ir_local) * n + (jc + jr + jr_local)] =
                                                sum;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
            // Record profiler event if active (C-AVX512-PROF-001)
            if let (Some(prof), Some(s)) = (profiler.as_mut(), start) {
                prof.record_avx512_blis(m, n, k, s.elapsed());
            }

            Ok(())
        })
    })
}
1122
/// AVX-512 broadcast-B BLIS GEMM — MR=64, NR=6 (faer-style).
///
/// Key difference from broadcast-A path:
/// - A is loaded as zmm vectors (64 elements = 4 zmm per K step)
/// - B is broadcast as scalars (6 per K step)
/// - 24 FMA accumulators = 75% register utilization (matching faer)
/// - NR=6 → B panel is tiny (6×KC×4 bytes) → KC can stay large (256+)
///
/// This avoids the KC-halving problem that killed the 8×48 attempt.
///
/// Note: experimental path — packing buffers are freshly allocated per
/// call rather than reusing the thread-local workspaces.
///
/// # Safety
/// Caller must ensure the CPU supports AVX-512F and FMA, and that the
/// slice lengths match m/n/k (enforced by `gemm_blis_broadcast_b`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
unsafe fn gemm_blis_avx512_bcast_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    let blk = super::cache_topology::blocking_64x6_bcast_b();
    let mr = blk.mr; // 64
    let nr = blk.nr; // 6
    let mc = blk.mc.min(m);
    let nc = blk.nc.min(n);
    let kc = blk.kc;

    // Allocate packing buffers
    let a_panels = (mc + mr - 1) / mr;
    let needed_a = a_panels * mr * kc;
    let b_panels = (nc + nr - 1) / nr;
    let needed_b = b_panels * nr * kc;

    let mut packed_a = vec![0.0f32; needed_a];
    let mut packed_b = vec![0.0f32; needed_b];

    // Standard BLIS 5-loop nest: jc (N) → pc (K) → ic (M) → ir/jr tiles.
    for jc in (0..n).step_by(nc) {
        let nc_block = nc.min(n - jc);

        for pc in (0..k).step_by(kc) {
            let kc_block = kc.min(k - pc);

            // Pack B with NR=6
            pack_b_block_generic(b, n, pc, jc, kc_block, nc_block, nr, &mut packed_b);

            for ic in (0..m).step_by(mc) {
                let mc_block = mc.min(m - ic);

                // Pack A with MR=64 (generic column-major packing)
                pack_a_block_generic(a, k, ic, pc, mc_block, kc_block, mr, &mut packed_a);

                let panels_m = (mc_block + mr - 1) / mr;
                let panels_n = (nc_block + nr - 1) / nr;

                for ir_panel in 0..panels_m {
                    let ir = ir_panel * mr;
                    let mr_block = mr.min(mc_block - ir);

                    for jr_panel in 0..panels_n {
                        let jr = jr_panel * nr;
                        let nr_block = nr.min(nc_block - jr);

                        let a_panel = &packed_a[ir_panel * mr * kc_block..];
                        let b_panel = &packed_b[jr_panel * nr * kc_block..];

                        if mr_block == 64 && nr_block == 6 {
                            // Full 64×6 broadcast-B tile
                            unsafe {
                                super::microkernels::codegen::microkernel_64x6_avx512_bcast_b(
                                    kc_block,
                                    a_panel.as_ptr(),
                                    b_panel.as_ptr(),
                                    c.as_mut_ptr().add((ic + ir) * n + (jc + jr)),
                                    n,
                                );
                            }
                        } else {
                            // Scalar fallback for edge tiles; accumulates into C
                            // (+=), matching the microkernel's accumulate semantics.
                            for ir_local in 0..mr_block {
                                for jr_local in 0..nr_block {
                                    let mut sum = 0.0f32;
                                    for p in 0..kc_block {
                                        sum +=
                                            a_panel[p * mr + ir_local] * b_panel[p * nr + jr_local];
                                    }
                                    c[(ic + ir + ir_local) * n + (jc + jr + jr_local)] += sum;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    Ok(())
}
1219
/// Generic A-packing: column-major panels with arbitrary MR.
/// Packs A[row_start..row_start+rows][col_start..col_start+cols] into
/// panels of MR×cols (column-major within each panel). Rows beyond the
/// last partial panel are zero-filled so every panel is exactly MR tall.
fn pack_a_block_generic(
    a: &[f32],
    lda: usize,
    row_start: usize,
    col_start: usize,
    rows: usize,
    cols: usize,
    mr: usize,
    packed: &mut [f32],
) {
    let mut dst = 0;
    // One iteration per MR-tall panel of source rows.
    for ir in (0..rows).step_by(mr) {
        let live_rows = mr.min(rows - ir);
        for col in 0..cols {
            let src_col = col_start + col;
            for row in 0..mr {
                // Zero-pad the panel tail when fewer than MR rows remain.
                packed[dst] = if row < live_rows {
                    a[(row_start + ir + row) * lda + src_col]
                } else {
                    0.0
                };
                dst += 1;
            }
        }
    }
}
1250
1251/// Safe public wrapper for broadcast-B GEMM (experimental).
1252/// Uses MR=64, NR=6 codegen microkernel with large KC.
1253#[cfg(target_arch = "x86_64")]
1254pub fn gemm_blis_broadcast_b(
1255    m: usize,
1256    n: usize,
1257    k: usize,
1258    a: &[f32],
1259    b: &[f32],
1260    c: &mut [f32],
1261) -> Result<(), TruenoError> {
1262    if a.len() != m * k || b.len() != k * n || c.len() != m * n {
1263        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
1264    }
1265    if std::arch::is_x86_feature_detected!("avx512f") {
1266        // SAFETY: AVX-512 detected, dimensions validated
1267        unsafe { gemm_blis_avx512_bcast_b(m, n, k, a, b, c) }
1268    } else {
1269        gemm_blis(m, n, k, a, b, c, None)
1270    }
1271}
1272
/// AVX-512 8×16 microkernel for row-major C (stride = n).
/// Processes 8 rows × 16 columns using zmm registers for C rows.
/// Each C row is 16 f32 = 1 zmm register. 8 rows = 8 zmm accumulators.
/// A: broadcast scalar to zmm. B: load 16 f32 (1 zmm) per K step.
/// 8 FMA ops per K step, each processing 16 elements = 2× throughput vs AVX2.
///
/// # Safety
/// - AVX-512F and FMA must be available (callers gate on runtime detection).
/// - `a` must be valid for `k * 8` f32 reads (MR=8 packed column-major).
/// - `b` must be valid for `k * 16` f32 reads (NR=16 packed row-major).
/// - `c` must be valid for reads and writes of an 8×16 f32 tile with row
///   stride `ldc`; i.e. `c[7*ldc + 15]` must be in bounds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
pub(super) unsafe fn avx512_microkernel_8x16_rowmajor(
    k: usize,
    a: *const f32, // MR=8 packed column-major
    b: *const f32, // NR=16 packed row-major
    c: *mut f32,
    ldc: usize, // row stride = n for row-major
) {
    use std::arch::x86_64::*;

    // Load 8 C rows (each row = 16 f32 = 1 zmm)
    let mut c0 = _mm512_loadu_ps(c);
    let mut c1 = _mm512_loadu_ps(c.add(ldc));
    let mut c2 = _mm512_loadu_ps(c.add(2 * ldc));
    let mut c3 = _mm512_loadu_ps(c.add(3 * ldc));
    let mut c4 = _mm512_loadu_ps(c.add(4 * ldc));
    let mut c5 = _mm512_loadu_ps(c.add(5 * ldc));
    let mut c6 = _mm512_loadu_ps(c.add(6 * ldc));
    let mut c7 = _mm512_loadu_ps(c.add(7 * ldc));

    // Main loop: for each K, load B[16] into zmm, broadcast A[i] to zmm, FMA.
    // NOTE: Manual 4-way K-unrolling was tested (2026-04-05) but REGRESSED
    // from 567→400 GFLOPS at 12T. The compiler (LLVM) already unrolls this
    // loop optimally. Manual unrolling causes register spills from 4× live
    // B vectors + address calculations exceeding the register budget.
    for p in 0..k {
        let b_row = _mm512_loadu_ps(b.add(p * 16));
        let ap = a.add(p * 8); // MR=8

        c0 = _mm512_fmadd_ps(_mm512_set1_ps(*ap), b_row, c0);
        c1 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(1)), b_row, c1);
        c2 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(2)), b_row, c2);
        c3 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(3)), b_row, c3);
        c4 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(4)), b_row, c4);
        c5 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(5)), b_row, c5);
        c6 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(6)), b_row, c6);
        c7 = _mm512_fmadd_ps(_mm512_set1_ps(*ap.add(7)), b_row, c7);
    }

    // Store 8 C rows
    _mm512_storeu_ps(c, c0);
    _mm512_storeu_ps(c.add(ldc), c1);
    _mm512_storeu_ps(c.add(2 * ldc), c2);
    _mm512_storeu_ps(c.add(3 * ldc), c3);
    _mm512_storeu_ps(c.add(4 * ldc), c4);
    _mm512_storeu_ps(c.add(5 * ldc), c5);
    _mm512_storeu_ps(c.add(6 * ldc), c6);
    _mm512_storeu_ps(c.add(7 * ldc), c7);
}
1328
/// AVX-512 8×32 microkernel for row-major C (stride = n).
/// Phase 4 (Appendix D): doubles NR from 16→32 to use 16 zmm accumulators.
/// Each C row spans 2 zmm (32 f32). 8 rows = 16 zmm accumulators.
/// B: 2 zmm loads per K step (32 columns). A: 8 scalar broadcasts.
/// FMAs per K step: 16 (2× the 8×16 kernel).
///
/// # Safety
/// - AVX-512F and FMA must be available (callers gate on runtime detection).
/// - `a` must be valid for `k * 8` f32 reads (MR=8 packed column-major).
/// - `b` must be valid for `k * 32` f32 reads (NR=32 packed row-major).
/// - `c` must be valid for reads and writes of an 8×32 f32 tile with row
///   stride `ldc`; i.e. `c[7*ldc + 31]` must be in bounds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f", enable = "fma")]
pub(super) unsafe fn avx512_microkernel_8x32_rowmajor(
    k: usize,
    a: *const f32, // MR=8 packed column-major
    b: *const f32, // NR=32 packed row-major
    c: *mut f32,
    ldc: usize, // row stride = n for row-major
) {
    use std::arch::x86_64::*;

    // Load 8 C rows × 2 zmm halves = 16 accumulators
    let mut c0l = _mm512_loadu_ps(c);
    let mut c0h = _mm512_loadu_ps(c.add(16));
    let mut c1l = _mm512_loadu_ps(c.add(ldc));
    let mut c1h = _mm512_loadu_ps(c.add(ldc + 16));
    let mut c2l = _mm512_loadu_ps(c.add(2 * ldc));
    let mut c2h = _mm512_loadu_ps(c.add(2 * ldc + 16));
    let mut c3l = _mm512_loadu_ps(c.add(3 * ldc));
    let mut c3h = _mm512_loadu_ps(c.add(3 * ldc + 16));
    let mut c4l = _mm512_loadu_ps(c.add(4 * ldc));
    let mut c4h = _mm512_loadu_ps(c.add(4 * ldc + 16));
    let mut c5l = _mm512_loadu_ps(c.add(5 * ldc));
    let mut c5h = _mm512_loadu_ps(c.add(5 * ldc + 16));
    let mut c6l = _mm512_loadu_ps(c.add(6 * ldc));
    let mut c6h = _mm512_loadu_ps(c.add(6 * ldc + 16));
    let mut c7l = _mm512_loadu_ps(c.add(7 * ldc));
    let mut c7h = _mm512_loadu_ps(c.add(7 * ldc + 16));

    // NOTE: Manual 2-way K-unrolling was tested (2026-04-05) but regressed
    // from 15.62ms→15.9ms at 1024 and 34.3→34.9µs at 128. The 8×32 kernel
    // uses 16 zmm accumulators + 2 B loads = 18 zmm live, leaving only 14
    // for unrolled state. LLVM's autounroll is better at managing this pressure.
    // Also tested MC=192 (from 96): regressed at 128-256 due to increased A-packing.
    for p in 0..k {
        // One K step: load the 32-wide B row (low/high zmm), then broadcast
        // each of the 8 A values and FMA into both halves of its C row.
        let bl = _mm512_loadu_ps(b.add(p * 32));
        let bh = _mm512_loadu_ps(b.add(p * 32 + 16));
        let ap = a.add(p * 8);

        let a0 = _mm512_set1_ps(*ap);
        c0l = _mm512_fmadd_ps(a0, bl, c0l);
        c0h = _mm512_fmadd_ps(a0, bh, c0h);
        let a1 = _mm512_set1_ps(*ap.add(1));
        c1l = _mm512_fmadd_ps(a1, bl, c1l);
        c1h = _mm512_fmadd_ps(a1, bh, c1h);
        let a2 = _mm512_set1_ps(*ap.add(2));
        c2l = _mm512_fmadd_ps(a2, bl, c2l);
        c2h = _mm512_fmadd_ps(a2, bh, c2h);
        let a3 = _mm512_set1_ps(*ap.add(3));
        c3l = _mm512_fmadd_ps(a3, bl, c3l);
        c3h = _mm512_fmadd_ps(a3, bh, c3h);
        let a4 = _mm512_set1_ps(*ap.add(4));
        c4l = _mm512_fmadd_ps(a4, bl, c4l);
        c4h = _mm512_fmadd_ps(a4, bh, c4h);
        let a5 = _mm512_set1_ps(*ap.add(5));
        c5l = _mm512_fmadd_ps(a5, bl, c5l);
        c5h = _mm512_fmadd_ps(a5, bh, c5h);
        let a6 = _mm512_set1_ps(*ap.add(6));
        c6l = _mm512_fmadd_ps(a6, bl, c6l);
        c6h = _mm512_fmadd_ps(a6, bh, c6h);
        let a7 = _mm512_set1_ps(*ap.add(7));
        c7l = _mm512_fmadd_ps(a7, bl, c7l);
        c7h = _mm512_fmadd_ps(a7, bh, c7h);
    }

    // Store 8 C rows × 2 zmm
    _mm512_storeu_ps(c, c0l);
    _mm512_storeu_ps(c.add(16), c0h);
    _mm512_storeu_ps(c.add(ldc), c1l);
    _mm512_storeu_ps(c.add(ldc + 16), c1h);
    _mm512_storeu_ps(c.add(2 * ldc), c2l);
    _mm512_storeu_ps(c.add(2 * ldc + 16), c2h);
    _mm512_storeu_ps(c.add(3 * ldc), c3l);
    _mm512_storeu_ps(c.add(3 * ldc + 16), c3h);
    _mm512_storeu_ps(c.add(4 * ldc), c4l);
    _mm512_storeu_ps(c.add(4 * ldc + 16), c4h);
    _mm512_storeu_ps(c.add(5 * ldc), c5l);
    _mm512_storeu_ps(c.add(5 * ldc + 16), c5h);
    _mm512_storeu_ps(c.add(6 * ldc), c6l);
    _mm512_storeu_ps(c.add(6 * ldc + 16), c6h);
    _mm512_storeu_ps(c.add(7 * ldc), c7l);
    _mm512_storeu_ps(c.add(7 * ldc + 16), c7h);
}
1417
1418/// Pack B block with NR=16 row-major panels for AVX-512.
1419/// Each panel is KC × 16, stored as kc_block × nr contiguous.
1420pub(super) fn pack_b_block_nr16(
1421    b: &[f32],
1422    ldb: usize,
1423    pc: usize,
1424    jc: usize,
1425    kc: usize,
1426    nc: usize,
1427    packed: &mut [f32],
1428) {
1429    let nr = 16;
1430    let panels = (nc + nr - 1) / nr;
1431    for panel in 0..panels {
1432        let j_start = panel * nr;
1433        let nr_local = nr.min(nc - j_start);
1434        for p in 0..kc {
1435            for j in 0..nr_local {
1436                packed[panel * nr * kc + p * nr + j] = b[(pc + p) * ldb + (jc + j_start + j)];
1437            }
1438            // Zero-pad if nr_local < 16
1439            for j in nr_local..nr {
1440                packed[panel * nr * kc + p * nr + j] = 0.0;
1441            }
1442        }
1443    }
1444}
1445
1446/// Pack B block with generic NR for AVX-512 (NR=32 for 8×32 microkernel).
1447/// For full NR=32 panels with contiguous source, uses AVX-512 vectorized copy
1448/// (2 zmm loads + 2 zmm stores per K step vs 32 scalar copies).
1449pub(super) fn pack_b_block_generic(
1450    b: &[f32],
1451    ldb: usize,
1452    pc: usize,
1453    jc: usize,
1454    kc: usize,
1455    nc: usize,
1456    nr: usize,
1457    packed: &mut [f32],
1458) {
1459    #[cfg(target_arch = "x86_64")]
1460    if nr == 32 && std::arch::is_x86_feature_detected!("avx512f") {
1461        // SAFETY: AVX-512 detected, nr=32 = 2 zmm
1462        unsafe {
1463            pack_b_block_nr32_avx512(b, ldb, pc, jc, kc, nc, packed);
1464        }
1465        return;
1466    }
1467
1468    let panels = (nc + nr - 1) / nr;
1469    for panel in 0..panels {
1470        let j_start = panel * nr;
1471        let nr_local = nr.min(nc - j_start);
1472        for p in 0..kc {
1473            let dst_base = panel * nr * kc + p * nr;
1474            for j in 0..nr_local {
1475                packed[dst_base + j] = b[(pc + p) * ldb + (jc + j_start + j)];
1476            }
1477            for j in nr_local..nr {
1478                packed[dst_base + j] = 0.0;
1479            }
1480        }
1481    }
1482}
1483
/// AVX-512 optimized B packing for NR=32 (2 zmm per K step).
/// Full panels: 2× _mm512_loadu_ps + 2× _mm512_storeu_ps.
/// Edge panels: scalar fallback for partial rows.
///
/// # Safety
///
/// - AVX-512F must be available (`#[target_feature]` contract; the
///   dispatcher checks `is_x86_feature_detected!` before calling).
/// - Full panels read `b` through unchecked raw pointers: `b` must cover
///   index `(pc + kc - 1) * ldb + jc + j_start + 31` for every full panel,
///   and `packed` must hold at least `ceil(nc / 32) * 32 * kc` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn pack_b_block_nr32_avx512(
    b: &[f32],
    ldb: usize,
    pc: usize,
    jc: usize,
    kc: usize,
    nc: usize,
    packed: &mut [f32],
) {
    use std::arch::x86_64::*;

    let nr = 32;
    let panels = (nc + nr - 1) / nr; // ceil(nc / 32)
    for panel in 0..panels {
        let j_start = panel * nr;
        let nr_local = nr.min(nc - j_start);

        if nr_local == 32 {
            // Full panel: SIMD copy — 2 zmm per K step, 2-way K-unrolled.
            // CGP-DBUF: unrolling amortizes loop overhead (~1 cycle/iter saved).
            let panel_base = panel * nr * kc;
            let b_col = jc + j_start;
            // Largest even K count handled by the unrolled loop.
            let kc2 = kc / 2 * 2;
            let mut p = 0;
            while p < kc2 {
                // Two consecutive B rows (stride ldb) → two packed rows (stride nr).
                let src0 = b.as_ptr().add((pc + p) * ldb + b_col);
                let src1 = b.as_ptr().add((pc + p + 1) * ldb + b_col);
                let dst0 = packed.as_mut_ptr().add(panel_base + p * nr);
                let dst1 = packed.as_mut_ptr().add(panel_base + (p + 1) * nr);
                let v0a = _mm512_loadu_ps(src0);
                let v0b = _mm512_loadu_ps(src0.add(16));
                let v1a = _mm512_loadu_ps(src1);
                let v1b = _mm512_loadu_ps(src1.add(16));
                _mm512_storeu_ps(dst0, v0a);
                _mm512_storeu_ps(dst0.add(16), v0b);
                _mm512_storeu_ps(dst1, v1a);
                _mm512_storeu_ps(dst1.add(16), v1b);
                p += 2;
            }
            // Remainder K (at most one row when kc is odd)
            while p < kc {
                let src = b.as_ptr().add((pc + p) * ldb + b_col);
                let dst = packed.as_mut_ptr().add(panel_base + p * nr);
                let v0 = _mm512_loadu_ps(src);
                let v1 = _mm512_loadu_ps(src.add(16));
                _mm512_storeu_ps(dst, v0);
                _mm512_storeu_ps(dst.add(16), v1);
                p += 1;
            }
        } else {
            // Edge panel: scalar with zero-padding (bounds-checked slice
            // indexing here, so the narrower source is read safely).
            for p in 0..kc {
                let dst_base = panel * nr * kc + p * nr;
                for j in 0..nr_local {
                    packed[dst_base + j] = b[(pc + p) * ldb + (jc + j_start + j)];
                }
                for j in nr_local..nr {
                    packed[dst_base + j] = 0.0;
                }
            }
        }
    }
}
1552
/// BLIS 5-loop GEMM with NR=8 and row-major C SIMD load/store (AVX2).
///
/// Key optimization vs standard BLIS: C rows are loaded/stored with
/// `_mm256_loadu_ps`/`_mm256_storeu_ps` (NR=8 = 1 ymm per row).
/// This replaces 96 scalar C ops per tile with 16 SIMD ops.
/// Matches matrixmultiply's approach (MR=8, NR=8, contiguous C rows).
///
/// Uses existing `pack_a_block` (MR=8) and `pack_b_block_512` (NR=8).
/// Inner loop: broadcast packed A elements, load packed B row, FMA.
///
/// Accumulates into `c` (C += A·B): every full tile loads the existing C
/// rows before the FMA chain and the edge path uses `+=`, so callers that
/// want C = A·B must zero `c` first. Always returns `Ok(())`; the `Result`
/// keeps the signature uniform with sibling GEMM entry points.
/// NOTE(review): assumes row-major `a` (m×k), `b` (k×n), `c` (m×n) with
/// lengths validated by the caller — confirm at call sites.
///
/// # Safety
///
/// Caller must ensure AVX2 and FMA are available (`#[target_feature]`
/// contract) and that `c.len() >= m * n`: full 8×8 tiles access C through
/// raw pointers with no bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn gemm_blis_nr8_rowmajor_c(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use std::arch::x86_64::*;

    // Cache blocking: match matrixmultiply's parameters.
    // MC=64: A~ = MC×KC×4 = 64KB, fits in L2 (1MB/core on Zen 4).
    // KC=256: B panel per tile = KC×8×4 = 8KB, fits in L1 (32KB).
    // NC=1024: B~ = KC×NC×4 = 1MB, fits in L3 (~5MB/CCX on Zen 4).
    let mc = 64_usize.min(m);
    let nc = 1024_usize.min(n);
    let kc_param = KC;
    let nr = 8_usize; // NR=8 for ymm-width C rows
    let mr = MR; // MR=8

    TL_PACKED_A.with(|tl_a| {
        TL_PACKED_B.with(|tl_b| {
            let mut packed_a = tl_a.borrow_mut();
            let mut packed_b = tl_b.borrow_mut();

            // Grow thread-local workspaces to the high-water mark; never shrink.
            let needed_a = packed_a_size(mc, kc_param);
            let needed_b = packed_b_size_512(kc_param, nc); // NR=8 packing
            if packed_a.len() < needed_a {
                packed_a.resize(needed_a, 0.0);
            }
            if packed_b.len() < needed_b {
                packed_b.resize(needed_b, 0.0);
            }

            // BLIS 5-loop with NR=8 packing and row-major C
            for jc in (0..n).step_by(nc) {
                let nc_block = nc.min(n - jc);

                for pc in (0..k).step_by(kc_param) {
                    let kc_block = kc_param.min(k - pc);

                    // Pack B with NR=8
                    pack_b_block_512(b, n, pc, jc, kc_block, nc_block, &mut packed_b);

                    for ic in (0..m).step_by(mc) {
                        let mc_block = mc.min(m - ic);

                        // Pack A with MR=8
                        pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);

                        // Microkernel loop: 8×8 tiles with row-major C
                        let panels_m = (mc_block + mr - 1) / mr;
                        let panels_n = (nc_block + nr - 1) / nr;

                        for ir_panel in 0..panels_m {
                            let ir = ir_panel * mr;
                            let mr_block = mr.min(mc_block - ir);

                            for jr_panel in 0..panels_n {
                                let jr = jr_panel * nr;
                                let nr_block = nr.min(nc_block - jr);

                                // Micro-panels: mr×kc_block of A, kc_block×nr of B,
                                // each contiguous in the packed buffers.
                                let a_panel = &packed_a[ir_panel * mr * kc_block..];
                                let b_panel = &packed_b[jr_panel * nr * kc_block..];

                                if mr_block == 8 && nr_block == 8 {
                                    // Full 8×8 tile: SIMD C load/store + FMA
                                    unsafe {
                                        let c_base = c.as_mut_ptr().add((ic + ir) * n + (jc + jr));

                                        // Load 8 C rows (each 8 f32 = 1 ymm)
                                        let mut c0 = _mm256_loadu_ps(c_base);
                                        let mut c1 = _mm256_loadu_ps(c_base.add(n));
                                        let mut c2 = _mm256_loadu_ps(c_base.add(2 * n));
                                        let mut c3 = _mm256_loadu_ps(c_base.add(3 * n));
                                        let mut c4 = _mm256_loadu_ps(c_base.add(4 * n));
                                        let mut c5 = _mm256_loadu_ps(c_base.add(5 * n));
                                        let mut c6 = _mm256_loadu_ps(c_base.add(6 * n));
                                        let mut c7 = _mm256_loadu_ps(c_base.add(7 * n));

                                        let ap = a_panel.as_ptr();
                                        let bp = b_panel.as_ptr();

                                        // 4-way K-unrolled inner loop
                                        let k4 = kc_block / 4;
                                        let k_rem = kc_block % 4;

                                        for p4 in 0..k4 {
                                            let p = p4 * 4;

                                            // K step p: one B row broadcast-FMA'd
                                            // against 8 scalars of the packed A column.
                                            let b_row = _mm256_loadu_ps(bp.add(p * 8));
                                            c0 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(p * 8)),
                                                b_row,
                                                c0,
                                            );
                                            c1 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(p * 8 + 1)),
                                                b_row,
                                                c1,
                                            );
                                            c2 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(p * 8 + 2)),
                                                b_row,
                                                c2,
                                            );
                                            c3 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(p * 8 + 3)),
                                                b_row,
                                                c3,
                                            );
                                            c4 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(p * 8 + 4)),
                                                b_row,
                                                c4,
                                            );
                                            c5 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(p * 8 + 5)),
                                                b_row,
                                                c5,
                                            );
                                            c6 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(p * 8 + 6)),
                                                b_row,
                                                c6,
                                            );
                                            c7 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(p * 8 + 7)),
                                                b_row,
                                                c7,
                                            );

                                            // K step p + 1 (`b_row` shadowed per step)
                                            let b_row = _mm256_loadu_ps(bp.add((p + 1) * 8));
                                            c0 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 1) * 8)),
                                                b_row,
                                                c0,
                                            );
                                            c1 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 1)),
                                                b_row,
                                                c1,
                                            );
                                            c2 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 2)),
                                                b_row,
                                                c2,
                                            );
                                            c3 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 3)),
                                                b_row,
                                                c3,
                                            );
                                            c4 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 4)),
                                                b_row,
                                                c4,
                                            );
                                            c5 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 5)),
                                                b_row,
                                                c5,
                                            );
                                            c6 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 6)),
                                                b_row,
                                                c6,
                                            );
                                            c7 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 1) * 8 + 7)),
                                                b_row,
                                                c7,
                                            );

                                            // K step p + 2
                                            let b_row = _mm256_loadu_ps(bp.add((p + 2) * 8));
                                            c0 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 2) * 8)),
                                                b_row,
                                                c0,
                                            );
                                            c1 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 1)),
                                                b_row,
                                                c1,
                                            );
                                            c2 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 2)),
                                                b_row,
                                                c2,
                                            );
                                            c3 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 3)),
                                                b_row,
                                                c3,
                                            );
                                            c4 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 4)),
                                                b_row,
                                                c4,
                                            );
                                            c5 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 5)),
                                                b_row,
                                                c5,
                                            );
                                            c6 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 6)),
                                                b_row,
                                                c6,
                                            );
                                            c7 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 2) * 8 + 7)),
                                                b_row,
                                                c7,
                                            );

                                            // K step p + 3
                                            let b_row = _mm256_loadu_ps(bp.add((p + 3) * 8));
                                            c0 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 3) * 8)),
                                                b_row,
                                                c0,
                                            );
                                            c1 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 1)),
                                                b_row,
                                                c1,
                                            );
                                            c2 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 2)),
                                                b_row,
                                                c2,
                                            );
                                            c3 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 3)),
                                                b_row,
                                                c3,
                                            );
                                            c4 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 4)),
                                                b_row,
                                                c4,
                                            );
                                            c5 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 5)),
                                                b_row,
                                                c5,
                                            );
                                            c6 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 6)),
                                                b_row,
                                                c6,
                                            );
                                            c7 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add((p + 3) * 8 + 7)),
                                                b_row,
                                                c7,
                                            );
                                        }

                                        // Remainder K iterations (kc_block % 4)
                                        let base_rem = k4 * 4;
                                        for rp in 0..k_rem {
                                            let pp = base_rem + rp;
                                            let b_row = _mm256_loadu_ps(bp.add(pp * 8));
                                            c0 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(pp * 8)),
                                                b_row,
                                                c0,
                                            );
                                            c1 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(pp * 8 + 1)),
                                                b_row,
                                                c1,
                                            );
                                            c2 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(pp * 8 + 2)),
                                                b_row,
                                                c2,
                                            );
                                            c3 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(pp * 8 + 3)),
                                                b_row,
                                                c3,
                                            );
                                            c4 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(pp * 8 + 4)),
                                                b_row,
                                                c4,
                                            );
                                            c5 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(pp * 8 + 5)),
                                                b_row,
                                                c5,
                                            );
                                            c6 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(pp * 8 + 6)),
                                                b_row,
                                                c6,
                                            );
                                            c7 = _mm256_fmadd_ps(
                                                _mm256_broadcast_ss(&*ap.add(pp * 8 + 7)),
                                                b_row,
                                                c7,
                                            );
                                        }

                                        // Store 8 C rows (SIMD)
                                        _mm256_storeu_ps(c_base, c0);
                                        _mm256_storeu_ps(c_base.add(n), c1);
                                        _mm256_storeu_ps(c_base.add(2 * n), c2);
                                        _mm256_storeu_ps(c_base.add(3 * n), c3);
                                        _mm256_storeu_ps(c_base.add(4 * n), c4);
                                        _mm256_storeu_ps(c_base.add(5 * n), c5);
                                        _mm256_storeu_ps(c_base.add(6 * n), c6);
                                        _mm256_storeu_ps(c_base.add(7 * n), c7);
                                    }
                                } else {
                                    // Edge tile: scalar fallback. Only the
                                    // mr_block×nr_block live region is touched,
                                    // via bounds-checked slice indexing.
                                    for p in 0..kc_block {
                                        for jj in 0..nr_block {
                                            let b_val = b_panel[p * nr + jj];
                                            for ii in 0..mr_block {
                                                c[(ic + ir + ii) * n + (jc + jr + jj)] +=
                                                    a_panel[p * mr + ii] * b_val;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        });
    });

    Ok(())
}
1901
/// AVX-512 BLIS 5-loop GEMM with packed 16×8 microkernel.
///
/// Uses MR_512=16, NR_512=8 packing for 2× compute density over AVX2 8×6.
/// C tiles loaded/stored with AVX-512 `_mm512_loadu_ps` (16 f32 per load).
/// Packing converts strided A/B into contiguous micro-panel layout for
/// sequential access in the microkernel.
///
/// Each tile is copied into the column-major `c_micro` scratch buffer
/// (zero-padded to a full MR_512×NR_512 tile), updated by the microkernel
/// or the scalar edge path, and written back — so across K blocks the net
/// effect is C += A·B. Always returns `Ok(())`; the `Result` keeps the
/// signature uniform with the other GEMM entry points.
/// NOTE(review): assumes row-major `a` (m×k), `b` (k×n), `c` (m×n) with
/// lengths validated by the caller — confirm at call sites.
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)] // Retained for AVX-512-only systems; superseded by nr8_rowmajor_c on AVX2
fn gemm_blis_avx512_packed(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    // kc here only sizes the A workspace; the loops below step by KC_512.
    let mc = MC_512.min(m);
    let nc = NC_512.min(n);
    let kc = KC_512.min(k);

    let needed_a = packed_a_size_512(mc, kc);
    let needed_b = packed_b_size_512(kc, nc);
    let needed_c = MR_512 * NR_512;

    TL_PACKED_A.with(|tl_a| {
        TL_PACKED_B.with(|tl_b| {
            TL_C_MICRO.with(|tl_c| {
                let mut packed_a = tl_a.borrow_mut();
                let mut packed_b = tl_b.borrow_mut();
                let mut c_micro = tl_c.borrow_mut();

                // Grow thread-local workspaces to the high-water mark.
                if packed_a.len() < needed_a {
                    packed_a.resize(needed_a, 0.0);
                }
                if packed_b.len() < needed_b {
                    packed_b.resize(needed_b, 0.0);
                }
                if c_micro.len() < needed_c {
                    c_micro.resize(needed_c, 0.0);
                }

                for jc in (0..n).step_by(NC_512) {
                    let nc_block = NC_512.min(n - jc);

                    for pc in (0..k).step_by(KC_512) {
                        let kc_block = KC_512.min(k - pc);

                        pack_b_block_512(b, n, pc, jc, kc_block, nc_block, &mut packed_b);

                        for ic in (0..m).step_by(MC_512) {
                            let mc_block = MC_512.min(m - ic);

                            pack_a_block_512(a, k, ic, pc, mc_block, kc_block, &mut packed_a);

                            // AVX-512 macroblock: 16×8 tiles
                            for ir in (0..mc_block).step_by(MR_512) {
                                let mr_block = MR_512.min(mc_block - ir);
                                for jr in (0..nc_block).step_by(NR_512) {
                                    let nr_block = NR_512.min(nc_block - jr);

                                    // Contiguous micro-panels in the packed buffers.
                                    let a_panel = &packed_a[(ir / MR_512) * MR_512 * kc_block..];
                                    let b_panel = &packed_b[(jr / NR_512) * NR_512 * kc_block..];

                                    // Load C tile (column-major for microkernel),
                                    // zero-padding rows/columns beyond the live
                                    // mr_block×nr_block region so the full-tile
                                    // kernel math is harmless on the padding.
                                    for jj in 0..nr_block {
                                        for ii in 0..mr_block {
                                            c_micro[jj * MR_512 + ii] =
                                                c[(ic + ir + ii) * n + (jc + jr + jj)];
                                        }
                                        for ii in mr_block..MR_512 {
                                            c_micro[jj * MR_512 + ii] = 0.0;
                                        }
                                    }
                                    for jj in nr_block..NR_512 {
                                        for ii in 0..MR_512 {
                                            c_micro[jj * MR_512 + ii] = 0.0;
                                        }
                                    }

                                    // Full tile → AVX-512, edge → scalar
                                    if mr_block == MR_512 && nr_block == NR_512 {
                                        // SAFETY: AVX-512 verified by is_x86_feature_detected
                                        // in caller. Packed layout matches microkernel.
                                        unsafe {
                                            microkernel_16x8_avx512(
                                                kc_block,
                                                a_panel.as_ptr(),
                                                b_panel.as_ptr(),
                                                c_micro.as_mut_ptr(),
                                                MR_512,
                                            );
                                        }
                                    } else {
                                        // Edge tiles: scalar over the full padded
                                        // tile (packed panels are zero-padded, so
                                        // the extra lanes contribute zeros).
                                        for p in 0..kc_block {
                                            for jj in 0..NR_512 {
                                                let b_val = b_panel[p * NR_512 + jj];
                                                for ii in 0..MR_512 {
                                                    c_micro[jj * MR_512 + ii] +=
                                                        a_panel[p * MR_512 + ii] * b_val;
                                                }
                                            }
                                        }
                                    }

                                    // Store C tile — only the live region is
                                    // written back; padding is discarded.
                                    for jj in 0..nr_block {
                                        for ii in 0..mr_block {
                                            c[(ic + ir + ii) * n + (jc + jr + jj)] =
                                                c_micro[jj * MR_512 + ii];
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            });
        });
    });

    Ok(())
}
2025
2026/// BLIS-style blocked GEMM with pre-packed B matrix.
2027///
2028/// Identical to [`gemm_blis`] but skips B packing entirely, reading packed
2029/// tiles from `prepacked_b` instead. This eliminates redundant B packing
2030/// when the same weight matrix is reused across calls (e.g., in parallel GEMM
2031/// where each thread would otherwise pack B independently).
2032///
2033/// # WAPR-KAIZEN Cycle 12
2034///
2035/// For encoder FFN: 16 threads × 2 GEMMs × 4 layers = 128 B packings eliminated.
2036pub fn gemm_blis_with_prepacked_b(
2037    m: usize,
2038    n: usize,
2039    k: usize,
2040    a: &[f32],
2041    prepacked_b: &PrepackedB,
2042    c: &mut [f32],
2043    mut profiler: Option<&mut BlisProfiler>,
2044) -> Result<(), TruenoError> {
2045    if a.len() != m * k {
2046        return Err(TruenoError::InvalidInput(format!(
2047            "A size mismatch: expected {}, got {}",
2048            m * k,
2049            a.len()
2050        )));
2051    }
2052    if c.len() != m * n {
2053        return Err(TruenoError::InvalidInput(format!(
2054            "C size mismatch: expected {}, got {}",
2055            m * n,
2056            c.len()
2057        )));
2058    }
2059    if prepacked_b.k != k || prepacked_b.n != n {
2060        return Err(TruenoError::InvalidInput(format!(
2061            "PrepackedB dimension mismatch: expected ({}, {}), got ({}, {})",
2062            k, n, prepacked_b.k, prepacked_b.n
2063        )));
2064    }
2065
2066    if m == 0 || n == 0 || k == 0 {
2067        return Ok(());
2068    }
2069
2070    let track_time = profiler.is_some();
2071    let start = if track_time { Some(Instant::now()) } else { None };
2072
2073    let mc = MC.min(m);
2074    let kc = KC.min(k);
2075
2076    let needed_a = packed_a_size(mc, kc);
2077    let needed_c = MR * NR;
2078
2079    // Only need A and C micro buffers — B is already packed
2080    TL_PACKED_A.with(|tl_a| {
2081        TL_C_MICRO.with(|tl_c| {
2082            let mut packed_a = tl_a.borrow_mut();
2083            let mut c_micro = tl_c.borrow_mut();
2084
2085            if packed_a.len() < needed_a {
2086                packed_a.resize(needed_a, 0.0);
2087            } else {
2088                packed_a[..needed_a].fill(0.0);
2089            }
2090            if c_micro.len() < needed_c {
2091                c_micro.resize(needed_c, 0.0);
2092            } else {
2093                c_micro[..needed_c].fill(0.0);
2094            }
2095
2096            for (jc_idx, jc) in (0..n).step_by(NC).enumerate() {
2097                let nc_block = NC.min(n - jc);
2098
2099                for (pc_idx, pc) in (0..k).step_by(KC).enumerate() {
2100                    let kc_block = KC.min(k - pc);
2101
2102                    // Use pre-packed B tile instead of runtime packing
2103                    let packed_b_tile = prepacked_b.tile(jc_idx, pc_idx);
2104
2105                    for ic in (0..m).step_by(MC) {
2106                        let mc_block = MC.min(m - ic);
2107
2108                        let pack_start = if track_time { Some(Instant::now()) } else { None };
2109                        pack_a_block(a, k, ic, pc, mc_block, kc_block, &mut packed_a);
2110                        record_prof(&mut profiler, BlisProfileLevel::Pack, pack_start, 0);
2111
2112                        compute_macroblock(
2113                            c,
2114                            &packed_a,
2115                            packed_b_tile,
2116                            &mut c_micro,
2117                            ic,
2118                            jc,
2119                            mc_block,
2120                            nc_block,
2121                            kc_block,
2122                            n,
2123                            &mut profiler,
2124                        );
2125                    }
2126                }
2127            }
2128
2129            if let (Some(prof), Some(s)) = (profiler, start) {
2130                prof.record(
2131                    BlisProfileLevel::Macro,
2132                    s.elapsed().as_nanos() as u64,
2133                    (2 * m * n * k) as u64,
2134                );
2135            }
2136        });
2137    });
2138
2139    Ok(())
2140}