1use crate::error::TruenoError;
7
8use super::compute::{gemm_blis, gemm_blis_with_prepacked_b};
9use super::prepacked::PrepackedB;
10#[cfg(feature = "parallel")]
11use super::{MC, MR};
12
/// "Heijunka" (load-leveling) scheduler: splits the GEMM M dimension
/// into contiguous row ranges spread evenly across worker threads.
#[derive(Debug, Clone)]
pub struct HeijunkaScheduler {
    /// Number of worker threads to partition row-blocks across.
    pub num_threads: usize,
    /// Load-imbalance tolerance (defaults to 0.05, i.e. 5%).
    /// NOTE(review): not read anywhere in this file — presumably consumed
    /// by callers deciding when to rebalance; confirm before relying on it.
    pub variance_threshold: f32,
}
21
22impl Default for HeijunkaScheduler {
23 fn default() -> Self {
24 #[cfg(feature = "parallel")]
25 let threads = rayon::current_num_threads();
26 #[cfg(not(feature = "parallel"))]
27 let threads = 1;
28
29 Self {
30 num_threads: threads,
31 variance_threshold: 0.05, }
33 }
34}
35
36impl HeijunkaScheduler {
37 pub fn partition_m(&self, m: usize, mc: usize) -> Vec<std::ops::Range<usize>> {
39 let num_blocks = (m + mc - 1) / mc;
40 let blocks_per_thread = num_blocks / self.num_threads;
41 let remainder = num_blocks % self.num_threads;
42
43 let mut partitions = Vec::with_capacity(self.num_threads);
44 let mut start_block = 0;
45
46 for t in 0..self.num_threads {
47 let extra = if t < remainder { 1 } else { 0 };
48 let thread_blocks = blocks_per_thread + extra;
49
50 let start_row = start_block * mc;
51 let end_row = ((start_block + thread_blocks) * mc).min(m);
52
53 if start_row < end_row {
54 partitions.push(start_row..end_row);
55 }
56
57 start_block += thread_blocks;
58 }
59
60 partitions
61 }
62}
63
#[cfg(feature = "parallel")]
/// Row-parallel GEMM: partitions the M dimension across rayon workers and
/// runs the serial BLIS kernel (`gemm_blis`) on each disjoint row slab of
/// `A`/`C`, with `B` shared read-only.
///
/// Problems below 8M multiply-adds run serially (threading overhead would
/// dominate); above that, the thread count is tiered by problem size and
/// capped at the physical core count.
///
/// # Errors
/// Returns `TruenoError::InvalidInput` when the slice lengths disagree
/// with `m`, `n`, `k`.
pub fn gemm_blis_parallel(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;
    // Design-by-contract hook defined elsewhere; presumably records or
    // asserts the Amdahl-speedup precondition — TODO confirm its effect.
    contract_pre_amdahl_speedup!();

    if a.len() != m * k || b.len() != k * n || c.len() != m * n {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }

    // Serial fallback below the parallelization threshold.
    let flops = m * n * k;
    if flops < 8_000_000 {
        return gemm_blis(m, n, k, a, b, c, None);
    }

    // Thread-count tiers by problem size, capped at physical cores
    // (hyperthreads rarely help a FLOP-bound kernel).
    let phys_cores = num_cpus::get_physical();
    let max_threads = if flops < 64_000_000 {
        2.min(phys_cores)
    } else if flops < 512_000_000 {
        4.min(phys_cores)
    } else if flops < 4_000_000_000 {
        8.min(phys_cores)
    } else {
        (phys_cores / 2).max(8).min(phys_cores)
    };

    let mut scheduler = HeijunkaScheduler::default();
    scheduler.num_threads = scheduler.num_threads.min(max_threads);
    // Partition size: whole MC cache blocks normally; for short matrices
    // split rows evenly, but never below one MR micro-panel per thread.
    let ps = if m <= MC { MR.max(m / scheduler.num_threads) } else { MC };
    let partitions = scheduler.partition_m(m, ps);

    // Smuggle the C pointer as usize so the closure is Send; partitions
    // cover disjoint row ranges, so worker writes never overlap.
    let c_ptr = c.as_mut_ptr() as usize;

    partitions.into_par_iter().for_each(|m_range| {
        let m_local = m_range.len();
        let m_start = m_range.start;

        // This worker's rows of A.
        let a_local = &a[m_start * k..(m_start + m_local) * k];

        // SAFETY: partition ranges are disjoint and within 0..m, so the
        // element range [m_start*n, (m_start+m_local)*n) of `c` is owned
        // exclusively by this worker for the closure's duration.
        let c_local = unsafe {
            let ptr = c_ptr as *mut f32;
            std::slice::from_raw_parts_mut(ptr.add(m_start * n), m_local * n)
        };

        // Dimensions were validated above, so the per-slab call cannot
        // fail its length checks; the Result is deliberately ignored.
        let _ = gemm_blis(m_local, n, k, a_local, b, c_local, None);
    });

    Ok(())
}
151
#[cfg(feature = "parallel")]
/// Parallel GEMM that packs each `B` tile once per `(jc, pc)` iteration
/// and shares the packed buffer read-only across all workers; each worker
/// packs only its own row slab of `A` into a thread-local buffer.
///
/// Accumulates into `c` (`c[i*n + j] += sum_p a[i*k + p] * b[p*n + j]` on
/// the scalar path; the AVX-512 microkernel is assumed to match —
/// NOTE(review): confirm against `avx512_microkernel_8x32_rowmajor`).
///
/// # Errors
/// Returns `TruenoError::InvalidInput` when the slice lengths disagree
/// with `m`, `n`, `k`.
pub fn gemm_blis_parallel_shared_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;

    if a.len() != m * k || b.len() != k * n || c.len() != m * n {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }

    // Small problems are not worth the packing + threading overhead.
    let flops = m * n * k;
    if flops < 8_000_000 {
        return gemm_blis(m, n, k, a, b, c, None);
    }

    // The fast path is tuned for the 8x32 AVX-512 microkernel; on x86_64
    // without AVX-512F, fall back to the serial BLIS path.
    #[cfg(target_arch = "x86_64")]
    if !std::arch::is_x86_feature_detected!("avx512f") {
        return gemm_blis(m, n, k, a, b, c, None);
    }

    // Thread-count tiers by problem size. The former `flops < 4e9` tier
    // and the final `else` were byte-identical, so they are collapsed.
    let phys_cores = num_cpus::get_physical();
    let max_threads = if flops < 64_000_000 {
        2.min(phys_cores)
    } else if flops < 512_000_000 {
        4.min(phys_cores)
    } else {
        (phys_cores / 2).max(8).min(phys_cores)
    };

    // BLIS-style blocking parameters derived from the cache topology.
    let blk = super::cache_topology::blocking_8x32();
    let mr = blk.mr;
    let nr = blk.nr;
    let mc = blk.mc.min(m);
    let nc = blk.nc.min(n);
    let kc = blk.kc;

    // One shared buffer sized for the widest (kc x nc) B tile; smaller
    // edge tiles simply use a prefix of it.
    let b_panels = (nc + nr - 1) / nr;
    let packed_b_size = b_panels * nr * kc;
    let mut packed_b = vec![0.0f32; packed_b_size];

    // Smuggle the C pointer as usize so the closures are Send; each tid
    // owns a disjoint, mr-aligned row range (see ic_start / ic_end).
    let c_ptr = c.as_mut_ptr() as usize;
    let num_threads = max_threads.min(rayon::current_num_threads());

    for jc in (0..n).step_by(nc) {
        let nc_block = nc.min(n - jc);

        for pc in (0..k).step_by(kc) {
            let kc_block = kc.min(k - pc);

            // Pack B once; every worker reads the same packed tile.
            super::compute::pack_b_block_generic(
                b,
                n,
                pc,
                jc,
                kc_block,
                nc_block,
                nr,
                &mut packed_b,
            );
            let shared_b: &[f32] = &packed_b;

            // Rows per thread, rounded up to a multiple of mr so micro-
            // panel boundaries never straddle two threads.
            let m_per_thread = ((m + num_threads - 1) / num_threads + mr - 1) / mr * mr;

            (0..num_threads).into_par_iter().for_each(|tid| {
                let ic_start = tid * m_per_thread;
                if ic_start >= m {
                    return;
                }
                let ic_end = (ic_start + m_per_thread).min(m);

                // Per-thread A packing buffer, reused across (jc, pc)
                // tiles to avoid re-allocating on every iteration.
                thread_local! {
                    static TL_A: std::cell::RefCell<Vec<f32>> =
                        const { std::cell::RefCell::new(Vec::new()) };
                }
                TL_A.with(|tl| {
                    let a_panels = (m_per_thread + mr - 1) / mr;
                    let needed = a_panels * mr * kc_block;
                    let mut packed_a = tl.borrow_mut();
                    if packed_a.len() < needed {
                        packed_a.resize(needed, 0.0);
                    }

                    let panels_n = (nc_block + nr - 1) / nr;

                    for ic in (ic_start..ic_end).step_by(mc) {
                        let mc_block = mc.min(ic_end - ic);

                        super::packing::pack_a_block(
                            a,
                            k,
                            ic,
                            pc,
                            mc_block,
                            kc_block,
                            &mut packed_a,
                        );

                        let panels_m = (mc_block + mr - 1) / mr;

                        for ir_panel in 0..panels_m {
                            let ir = ir_panel * mr;
                            let mr_block = mr.min(mc_block - ir);

                            for jr_panel in 0..panels_n {
                                let jr = jr_panel * nr;
                                let nr_block = nr.min(nc_block - jr);

                                let a_panel = &packed_a[ir_panel * mr * kc_block..];
                                let b_panel = &shared_b[jr_panel * nr * kc_block..];

                                // BUG FIX: the SIMD branch used to be taken
                                // for any full 8x32 tile while the kernel
                                // call inside it was cfg-gated to x86_64 —
                                // on other targets the tile was silently
                                // skipped (wrong results). Gate branch
                                // *selection* on the target instead, so
                                // non-x86_64 always takes the scalar path.
                                #[cfg(target_arch = "x86_64")]
                                let use_simd = mr_block == 8 && nr_block == 32;
                                #[cfg(not(target_arch = "x86_64"))]
                                let use_simd = false;

                                if use_simd {
                                    #[cfg(target_arch = "x86_64")]
                                    // SAFETY: a_panel/b_panel hold at least
                                    // kc_block*8 and kc_block*32 packed
                                    // elements; the 8x32 C tile at row
                                    // stride n is in-bounds and its rows
                                    // belong exclusively to this thread's
                                    // [ic_start, ic_end) slab.
                                    unsafe {
                                        super::compute::avx512_microkernel_8x32_rowmajor(
                                            kc_block,
                                            a_panel.as_ptr(),
                                            b_panel.as_ptr(),
                                            (c_ptr as *mut f32).add((ic + ir) * n + (jc + jr)),
                                            n,
                                        );
                                    }
                                } else {
                                    // Scalar fallback for edge tiles (and
                                    // for all tiles on non-x86_64).
                                    for ir_local in 0..mr_block {
                                        for jr_local in 0..nr_block {
                                            let mut sum = 0.0f32;
                                            for p in 0..kc_block {
                                                sum += a_panel[p * mr + ir_local]
                                                    * b_panel[p * nr + jr_local];
                                            }
                                            // SAFETY: ic+ir+ir_local < m and
                                            // jc+jr+jr_local < n, so the
                                            // index is within c's m*n
                                            // elements, and the row belongs
                                            // to this thread alone.
                                            unsafe {
                                                let c_base = c_ptr as *mut f32;
                                                *c_base.add(
                                                    (ic + ir + ir_local) * n + (jc + jr + jr_local),
                                                ) += sum;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                });
            });
        }
    }

    Ok(())
}
327
#[cfg(not(feature = "parallel"))]
/// Serial stand-in compiled when the `parallel` feature is disabled:
/// identical signature to the parallel version, delegating directly to
/// the serial BLIS kernel so callers need no feature-gated code.
pub fn gemm_blis_parallel(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    gemm_blis(m, n, k, a, b, c, None)
}
340
#[cfg(feature = "parallel")]
/// Row-parallel GEMM against a pre-packed `B` (`PrepackedB`): partitions
/// the M dimension into MC-sized slabs across rayon workers, each running
/// the serial pre-packed kernel on its disjoint rows of `A`/`C`.
///
/// Problems below 1M multiply-adds run serially.
///
/// # Errors
/// Returns `TruenoError::InvalidInput` when `a`/`c` lengths or the
/// pre-packed `B` dimensions disagree with `m`, `n`, `k`.
pub fn gemm_blis_parallel_with_prepacked_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    prepacked_b: &PrepackedB,
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;

    if a.len() != m * k || c.len() != m * n {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }
    // The packed B must have been built for exactly this (k, n).
    if prepacked_b.k != k || prepacked_b.n != n {
        return Err(TruenoError::InvalidInput(format!(
            "PrepackedB dimension mismatch: expected ({}, {}), got ({}, {})",
            k, n, prepacked_b.k, prepacked_b.n
        )));
    }

    // Serial fallback below the parallelization threshold.
    if m * n * k < 1_000_000 {
        return gemm_blis_with_prepacked_b(m, n, k, a, prepacked_b, c, None);
    }

    let scheduler = HeijunkaScheduler::default();
    let partitions = scheduler.partition_m(m, MC);

    // Smuggle the C pointer as usize so the closure is Send; partitions
    // cover disjoint row ranges, so worker writes never overlap.
    let c_ptr = c.as_mut_ptr() as usize;

    partitions.into_par_iter().for_each(|m_range| {
        let m_local = m_range.len();
        let m_start = m_range.start;

        // This worker's rows of A.
        let a_local = &a[m_start * k..(m_start + m_local) * k];

        // SAFETY: partition ranges are disjoint and within 0..m, so the
        // element range [m_start*n, (m_start+m_local)*n) of `c` is owned
        // exclusively by this worker for the closure's duration.
        let c_local = unsafe {
            let ptr = c_ptr as *mut f32;
            std::slice::from_raw_parts_mut(ptr.add(m_start * n), m_local * n)
        };

        // Dimensions were validated above; the Result is deliberately
        // ignored for the per-slab call.
        let _ = gemm_blis_with_prepacked_b(m_local, n, k, a_local, prepacked_b, c_local, None);
    });

    Ok(())
}
401
#[cfg(not(feature = "parallel"))]
/// Serial stand-in compiled when the `parallel` feature is disabled:
/// identical signature to the parallel version, delegating directly to
/// the serial pre-packed kernel so callers need no feature-gated code.
pub fn gemm_blis_parallel_with_prepacked_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    prepacked_b: &PrepackedB,
    c: &mut [f32],
) -> Result<(), TruenoError> {
    gemm_blis_with_prepacked_b(m, n, k, a, prepacked_b, c, None)
}