Skip to main content

trueno/blis/
parallel.rs

1//! Parallel GEMM with Heijunka (load-leveling) scheduling.
2//!
3//! Uses Rayon for parallel execution when the `parallel` feature is enabled,
4//! with balanced M-dimension partitioning via [`HeijunkaScheduler`].
5
6use crate::error::TruenoError;
7
8use super::compute::{gemm_blis, gemm_blis_with_prepacked_b};
9use super::prepacked::PrepackedB;
10#[cfg(feature = "parallel")]
11use super::{MC, MR};
12
/// Heijunka (load-leveling) scheduler for parallel GEMM.
///
/// Splits the M (row) dimension of a GEMM into contiguous, `mc`-aligned row
/// ranges whose block counts differ by at most one between threads.
#[derive(Debug, Clone)]
pub struct HeijunkaScheduler {
    /// Number of threads (partitions) to balance across. A value of 0 is
    /// treated as 1 by [`HeijunkaScheduler::partition_m`].
    pub num_threads: usize,
    /// Target load variance threshold.
    ///
    /// NOTE(review): not read by [`HeijunkaScheduler::partition_m`]; confirm
    /// whether any other code consumes it.
    pub variance_threshold: f32,
}

impl Default for HeijunkaScheduler {
    /// Uses the Rayon pool size when the `parallel` feature is enabled,
    /// otherwise a single thread, with a 5% variance target.
    fn default() -> Self {
        #[cfg(feature = "parallel")]
        let threads = rayon::current_num_threads();
        #[cfg(not(feature = "parallel"))]
        let threads = 1;

        Self {
            num_threads: threads,
            variance_threshold: 0.05, // 5% variance target
        }
    }
}

impl HeijunkaScheduler {
    /// Partition the M dimension into balanced, contiguous row ranges.
    ///
    /// The `m` rows are grouped into `ceil(m / mc)` blocks of `mc` rows (the
    /// last block may be shorter). Each of the `num_threads` partitions gets
    /// `num_blocks / num_threads` blocks, and the first
    /// `num_blocks % num_threads` partitions get one extra block. Empty
    /// partitions are omitted. The result therefore holds at most
    /// `num_threads` ascending, non-overlapping ranges that together cover
    /// `0..m`. It is empty when `m == 0`.
    ///
    /// A `num_threads` or `mc` of 0 is treated as 1 instead of panicking with
    /// a division by zero.
    pub fn partition_m(&self, m: usize, mc: usize) -> Vec<std::ops::Range<usize>> {
        let threads = self.num_threads.max(1);
        let mc = mc.max(1);

        // ceil(m / mc) without the overflow risk of `m + mc - 1`.
        let num_blocks = m / mc + usize::from(m % mc != 0);
        let blocks_per_thread = num_blocks / threads;
        let remainder = num_blocks % threads;

        let mut partitions = Vec::with_capacity(threads.min(num_blocks));
        let mut start_block = 0;

        for t in 0..threads {
            // Once every block is assigned, the remaining threads would only
            // receive empty ranges.
            if start_block >= num_blocks {
                break;
            }
            let thread_blocks = blocks_per_thread + usize::from(t < remainder);

            let start_row = start_block * mc;
            let end_row = (start_block + thread_blocks).saturating_mul(mc).min(m);

            if start_row < end_row {
                partitions.push(start_row..end_row);
            }

            start_block += thread_blocks;
        }

        partitions
    }
}
63
/// Parallel BLIS GEMM using Rayon.
///
/// Multiplies row-major `a` (`m × k`) by `b` (`k × n`) into row-major `c`
/// (`m × n`). Whether `c` is overwritten or accumulated into is decided by
/// [`gemm_blis`]; this function does not change that.
///
/// M is split into contiguous row ranges by [`HeijunkaScheduler`]. Each range
/// runs [`gemm_blis`] on its own rows of `a` and `c`, in parallel. Problems
/// below 8M multiply-adds run single-threaded.
///
/// # Errors
///
/// Returns [`TruenoError::InvalidInput`] if the slice lengths do not match
/// `m`, `n`, `k`. Also propagates any error reported by a per-partition
/// [`gemm_blis`] call. Such errors were previously discarded.
#[cfg(feature = "parallel")]
pub fn gemm_blis_parallel(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;
    contract_pre_amdahl_speedup!();

    // Dimension validation
    if a.len() != m * k || b.len() != k * n || c.len() != m * n {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }

    // Single-threaded threshold: 8M FLOPs ≈ 200³.
    // Rayon dispatch costs ~3µs. For GEMM ≤128 (~4M FLOP, ~35µs compute),
    // rayon overhead dominates. GEMM 256+ (33M FLOP, ~300µs) benefits.
    let flops = m * n * k;
    if flops < 8_000_000 {
        return gemm_blis(m, n, k, a, b, c, None);
    }

    // Scale thread count to problem size and cache topology.
    // cgp profile scaling measurements (2026-04-05, Threadripper 7960X 24C/48T):
    //
    //   256x256: 1T=27.8, 2T=34.5 (peak), 4T=35.2 → cap at 2
    //   512x512: 1T=82.6, 4T=176 (peak), 8T=158 → cap at 4
    //   1024x1024: 1T=106, 8T=489 (peak), 12T=417, 16T=450, 24T=426 → cap at 8
    //
    // Root cause for small-problem regression: L3 contention and thread spawn
    // overhead (~40µs per thread::scope) dominate when compute < 1ms.
    // Root cause for 1024 12T regression: cross-CCD L3 thrashing. 8T fits
    // in a single CCD (12 cores, 32MB L3). 12+ threads span both CCDs.
    let phys_cores = num_cpus::get_physical();
    let max_threads = if flops < 64_000_000 {
        // 256³ and below: barely benefits from parallelism
        2.min(phys_cores)
    } else if flops < 512_000_000 {
        // 512³ range: 4T is peak, >4 regresses due to L3 contention
        4.min(phys_cores)
    } else if flops < 4_000_000_000 {
        // 1024³ range (~2B FLOPs): 8T is empirical peak (626 GFLOPS).
        // 12T regresses to 559 GFLOPS due to cross-CCD L3 thrashing — each thread
        // independently packs B, and 12 copies × ~1MB packed_b exceeds one CCD's
        // 32MB L3 share. Capping at 8 keeps all threads on one CCD.
        // Measured 2026-04-05 on Threadripper 7960X (2 CCDs × 12 cores).
        8.min(phys_cores)
    } else {
        // Very large (>4B FLOPs): use phys_cores/2 (one thread per CCD core).
        // Beyond phys_cores/2, SMT contention regresses AVX-512 throughput.
        (phys_cores / 2).max(8).min(phys_cores)
    };

    let mut scheduler = HeijunkaScheduler::default();
    scheduler.num_threads = scheduler.num_threads.min(max_threads);
    let ps = if m <= MC { MR.max(m / scheduler.num_threads) } else { MC };
    let partitions = scheduler.partition_m(m, ps);

    // NEGATIVE RESULT (2026-04-06): shared-B per (jc,pc) block REGRESSED 597→318 GFLOPS.
    // Root cause: Rayon barrier after each K-tile pack forces thread synchronization.
    // With K=4 tiles for 1024×1024, threads stall 4× per GEMM waiting for B pack.
    // Per-thread independent packing (below) avoids synchronization entirely.
    // The 8× redundant B packing (~8MB) fits in L3 (64MB) and eliminates barriers.
    // Future fix: producer-consumer B packing (one thread packs while others compute).

    // Give each partition its own disjoint `&mut` row slice of C. Because
    // `split_at_mut` proves the slices are disjoint, no raw-pointer aliasing
    // is needed. Gaps between ranges (none are produced today) are skipped.
    let mut tasks = Vec::with_capacity(partitions.len());
    let mut rest: &mut [f32] = c;
    let mut rest_row = 0;
    for rows in partitions {
        let (_, from_start) = std::mem::take(&mut rest).split_at_mut((rows.start - rest_row) * n);
        let (c_local, tail) = from_start.split_at_mut(rows.len() * n);
        rest = tail;
        rest_row = rows.end;
        tasks.push((rows, c_local));
    }

    // Propagate the first per-partition error instead of discarding it.
    tasks.into_par_iter().try_for_each(|(rows, c_local)| {
        let a_local = &a[rows.start * k..rows.end * k];
        gemm_blis(rows.len(), n, k, a_local, b, c_local, None)
    })
}
151
/// Parallel GEMM with shared packed-B: pack B once per (jc,pc) block,
/// distribute M-slices across threads. Each thread only packs its own A.
/// This eliminates O(threads) redundant B packings.
///
/// BLIS loop structure:
///   for jc (N tiles):      ← sequential
///     for pc (K tiles):    ← sequential, pack B ONCE
///       for ic (M tiles):  ← PARALLEL across threads
///         pack A_local
///         microkernel(packed_a, shared_packed_b, c_local)
///
/// Full 8×32 tiles use the AVX-512 microkernel. Edge tiles use a scalar loop
/// that adds `a_tile · b_tile` into `c`. Small problems (< 8M multiply-adds)
/// fall back to [`gemm_blis`]. So does any target without AVX-512F, which
/// includes every non-x86_64 target, where the microkernel is not compiled.
///
/// # Errors
///
/// Returns [`TruenoError::InvalidInput`] if the slice lengths do not match
/// `m`, `n`, `k`. Also propagates any error from the [`gemm_blis`] fallback.
#[cfg(feature = "parallel")]
pub fn gemm_blis_parallel_shared_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;

    if a.len() != m * k || b.len() != k * n || c.len() != m * n {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }

    // For small problems, use single-thread path
    let flops = m * n * k;
    if flops < 8_000_000 {
        return gemm_blis(m, n, k, a, b, c, None);
    }

    // The 8×32 microkernel requires AVX-512F. On non-x86_64 targets the
    // microkernel is not compiled at all, so full tiles could never be
    // computed there. Take the generic path in that case too.
    #[cfg(target_arch = "x86_64")]
    let has_avx512f = std::arch::is_x86_feature_detected!("avx512f");
    #[cfg(not(target_arch = "x86_64"))]
    let has_avx512f = false;
    if !has_avx512f {
        return gemm_blis(m, n, k, a, b, c, None);
    }

    let phys_cores = num_cpus::get_physical();
    let max_threads = if flops < 64_000_000 {
        2.min(phys_cores)
    } else if flops < 512_000_000 {
        4.min(phys_cores)
    } else {
        // Shared-B means less L3 pressure per thread, so we can potentially
        // use more threads than the per-thread-B path. For every larger
        // problem, use phys_cores/2 (at least 8, never more than phys_cores).
        (phys_cores / 2).max(8).min(phys_cores)
    };

    let blk = super::cache_topology::blocking_8x32();
    let mr = blk.mr; // 8
    let nr = blk.nr; // 32
    let mc = blk.mc.min(m);
    let nc = blk.nc.min(n);
    let kc = blk.kc;

    // Shared packed B: one allocation sized for the largest (nc × kc) B block.
    let b_panels = (nc + nr - 1) / nr;
    let packed_b_size = b_panels * nr * kc;
    let mut packed_b = vec![0.0f32; packed_b_size];

    // C's base address travels as `usize` so the Rayon closures stay Send + Sync.
    let c_ptr = c.as_mut_ptr() as usize;
    let num_threads = max_threads.min(rayon::current_num_threads());
    // Rows per thread, rounded up to a multiple of MR. This is loop-invariant.
    let m_per_thread = ((m + num_threads - 1) / num_threads + mr - 1) / mr * mr;

    for jc in (0..n).step_by(nc) {
        let nc_block = nc.min(n - jc);

        for pc in (0..k).step_by(kc) {
            let kc_block = kc.min(k - pc);

            // Pack B ONCE (sequential) — shared by all threads
            super::compute::pack_b_block_generic(
                b,
                n,
                pc,
                jc,
                kc_block,
                nc_block,
                nr,
                &mut packed_b,
            );
            let shared_b: &[f32] = &packed_b;

            // Parallel ic loop: each thread gets a contiguous slice of M rows.
            (0..num_threads).into_par_iter().for_each(|tid| {
                let ic_start = tid * m_per_thread;
                if ic_start >= m {
                    return;
                }
                let ic_end = (ic_start + m_per_thread).min(m);

                // Thread-local packed A — reuse across (jc, pc) iterations
                // via thread_local! to avoid heap allocation per iteration.
                thread_local! {
                    static TL_A: std::cell::RefCell<Vec<f32>> =
                        const { std::cell::RefCell::new(Vec::new()) };
                }
                TL_A.with(|tl| {
                    let a_panels = (m_per_thread + mr - 1) / mr;
                    let needed = a_panels * mr * kc_block;
                    let mut packed_a = tl.borrow_mut();
                    if packed_a.len() < needed {
                        packed_a.resize(needed, 0.0);
                    }

                    let panels_n = (nc_block + nr - 1) / nr;

                    for ic in (ic_start..ic_end).step_by(mc) {
                        let mc_block = mc.min(ic_end - ic);

                        super::packing::pack_a_block(
                            a,
                            k,
                            ic,
                            pc,
                            mc_block,
                            kc_block,
                            &mut packed_a,
                        );

                        let panels_m = (mc_block + mr - 1) / mr;

                        for ir_panel in 0..panels_m {
                            let ir = ir_panel * mr;
                            let mr_block = mr.min(mc_block - ir);

                            for jr_panel in 0..panels_n {
                                let jr = jr_panel * nr;
                                let nr_block = nr.min(nc_block - jr);

                                let a_panel = &packed_a[ir_panel * mr * kc_block..];
                                let b_panel = &shared_b[jr_panel * nr * kc_block..];

                                // Full tiles go to the AVX-512 microkernel. Every
                                // other tile (and every tile on non-x86_64) falls
                                // through to the scalar loop, so no tile is skipped.
                                #[cfg(target_arch = "x86_64")]
                                if mr_block == 8 && nr_block == 32 {
                                    // SAFETY: avx512f was detected on entry. The
                                    // 8×32 tile at (ic+ir, jc+jr) lies inside C
                                    // (ic+ir+8 <= ic_end <= m, jc+jr+32 <= n) and
                                    // within this thread's disjoint rows
                                    // [ic_start, ic_end).
                                    unsafe {
                                        super::compute::avx512_microkernel_8x32_rowmajor(
                                            kc_block,
                                            a_panel.as_ptr(),
                                            b_panel.as_ptr(),
                                            (c_ptr as *mut f32).add((ic + ir) * n + (jc + jr)),
                                            n,
                                        );
                                    }
                                    continue;
                                }

                                // Scalar fallback for edge tiles
                                for ir_local in 0..mr_block {
                                    for jr_local in 0..nr_block {
                                        let mut sum = 0.0f32;
                                        for p in 0..kc_block {
                                            sum += a_panel[p * mr + ir_local]
                                                * b_panel[p * nr + jr_local];
                                        }
                                        // SAFETY: the element lies inside C and in
                                        // rows [ic_start, ic_end), which no other
                                        // thread writes.
                                        unsafe {
                                            let c_base = c_ptr as *mut f32;
                                            *c_base.add(
                                                (ic + ir + ir_local) * n + (jc + jr + jr_local),
                                            ) += sum;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }); // TL_A.with
            });
        }
    }

    Ok(())
}
327
328/// Non-parallel fallback
329#[cfg(not(feature = "parallel"))]
330pub fn gemm_blis_parallel(
331    m: usize,
332    n: usize,
333    k: usize,
334    a: &[f32],
335    b: &[f32],
336    c: &mut [f32],
337) -> Result<(), TruenoError> {
338    gemm_blis(m, n, k, a, b, c, None)
339}
340
/// Parallel BLIS GEMM with pre-packed B matrix.
///
/// Key optimization: the pre-packed B is shared immutably across all threads.
/// Each thread only packs A (which differs per M partition). This eliminates
/// N_threads × redundant B packings per GEMM call.
///
/// `a` is row-major `m × k` and `c` is row-major `m × n`. M is split into
/// `MC`-aligned row ranges by [`HeijunkaScheduler`], and each range runs
/// [`gemm_blis_with_prepacked_b`] in parallel. Problems below 1M
/// multiply-adds run single-threaded.
///
/// # Errors
///
/// Returns [`TruenoError::InvalidInput`] if `a`/`c` lengths or the
/// `PrepackedB` dimensions do not match `m`, `n`, `k`. Also propagates any
/// error reported by a per-partition call. Such errors were previously
/// discarded.
///
/// # WAPR-KAIZEN Cycle 12
///
/// For 16-thread encoder FFN: eliminates 15 redundant B packings per GEMM call
/// (128 total across 2 GEMMs × 4 layers).
#[cfg(feature = "parallel")]
pub fn gemm_blis_parallel_with_prepacked_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    prepacked_b: &PrepackedB,
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;

    if a.len() != m * k || c.len() != m * n {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }
    if prepacked_b.k != k || prepacked_b.n != n {
        return Err(TruenoError::InvalidInput(format!(
            "PrepackedB dimension mismatch: expected ({}, {}), got ({}, {})",
            k, n, prepacked_b.k, prepacked_b.n
        )));
    }

    // Small matrices: single-threaded
    if m * n * k < 1_000_000 {
        return gemm_blis_with_prepacked_b(m, n, k, a, prepacked_b, c, None);
    }

    let scheduler = HeijunkaScheduler::default();
    let partitions = scheduler.partition_m(m, MC);

    // Give each partition its own disjoint `&mut` row slice of C. Because
    // `split_at_mut` proves the slices are disjoint, no raw-pointer aliasing
    // is needed. Gaps between ranges (none are produced today) are skipped.
    let mut tasks = Vec::with_capacity(partitions.len());
    let mut rest: &mut [f32] = c;
    let mut rest_row = 0;
    for rows in partitions {
        let (_, from_start) = std::mem::take(&mut rest).split_at_mut((rows.start - rest_row) * n);
        let (c_local, tail) = from_start.split_at_mut(rows.len() * n);
        rest = tail;
        rest_row = rows.end;
        tasks.push((rows, c_local));
    }

    // Key: prepacked_b is shared (immutable &) across all threads — zero redundant packing.
    // The first per-partition error is propagated instead of being discarded.
    tasks.into_par_iter().try_for_each(|(rows, c_local)| {
        let a_local = &a[rows.start * k..rows.end * k];
        gemm_blis_with_prepacked_b(rows.len(), n, k, a_local, prepacked_b, c_local, None)
    })
}
401
402/// Non-parallel fallback for pre-packed B
403#[cfg(not(feature = "parallel"))]
404pub fn gemm_blis_parallel_with_prepacked_b(
405    m: usize,
406    n: usize,
407    k: usize,
408    a: &[f32],
409    prepacked_b: &PrepackedB,
410    c: &mut [f32],
411) -> Result<(), TruenoError> {
412    gemm_blis_with_prepacked_b(m, n, k, a, prepacked_b, c, None)
413}