use crate::error::TruenoError;
use super::compute::{gemm_blis, gemm_blis_with_prepacked_b};
use super::prepacked::PrepackedB;
#[cfg(feature = "parallel")]
use super::{MC, MR};
/// Heijunka ("production leveling") scheduler: splits the M dimension of a
/// GEMM into contiguous per-thread row ranges that are balanced to within
/// one cache block.
#[derive(Debug, Clone)]
pub struct HeijunkaScheduler {
    // Number of worker threads to partition work across.
    pub num_threads: usize,
    // Allowed load-imbalance ratio between threads.
    // NOTE(review): not read anywhere in this file — presumably consumed by a
    // caller; confirm before removing.
    pub variance_threshold: f32,
}

impl Default for HeijunkaScheduler {
    fn default() -> Self {
        // Use rayon's pool size when the `parallel` feature is on; serial
        // builds get a single thread.
        #[cfg(feature = "parallel")]
        let threads = rayon::current_num_threads();
        #[cfg(not(feature = "parallel"))]
        let threads = 1;
        Self {
            num_threads: threads,
            variance_threshold: 0.05,
        }
    }
}

impl HeijunkaScheduler {
    /// Partition rows `0..m` into at most `num_threads` contiguous ranges,
    /// each covering a whole number of `mc`-sized row blocks. The first
    /// `num_blocks % num_threads` threads receive one extra block so the
    /// imbalance is at most one block. Empty ranges are omitted, so fewer
    /// than `num_threads` ranges may be returned (including zero for `m == 0`).
    pub fn partition_m(&self, m: usize, mc: usize) -> Vec<std::ops::Range<usize>> {
        // Bug fix: `num_threads` is a public field and `mc` is caller-supplied;
        // either being 0 previously panicked with a divide-by-zero. Clamp both
        // to at least 1 instead.
        let threads = self.num_threads.max(1);
        let mc = mc.max(1);
        let num_blocks = (m + mc - 1) / mc; // ceil(m / mc)
        let blocks_per_thread = num_blocks / threads;
        let remainder = num_blocks % threads;
        let mut partitions = Vec::with_capacity(threads);
        let mut start_block = 0;
        for t in 0..threads {
            // Levelling: the first `remainder` threads take one extra block.
            let extra = usize::from(t < remainder);
            let thread_blocks = blocks_per_thread + extra;
            let start_row = start_block * mc;
            // The final block may be partial; never run past `m`.
            let end_row = ((start_block + thread_blocks) * mc).min(m);
            if start_row < end_row {
                partitions.push(start_row..end_row);
            }
            start_block += thread_blocks;
        }
        partitions
    }
}
/// Parallel BLIS-style GEMM (`C += A * B`) with row-wise work splitting.
///
/// Small problems run serially; otherwise the thread count is capped by a
/// FLOP-based tier table so scheduling overhead stays amortized, and the M
/// dimension is partitioned into disjoint row ranges, one rayon task each.
///
/// # Errors
/// Returns `TruenoError::InvalidInput` when the `a`/`b`/`c` slice lengths
/// disagree with `m`, `n`, `k`.
#[cfg(feature = "parallel")]
pub fn gemm_blis_parallel(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;
    // Crate-level contract hook (pre-condition side of the Amdahl speedup
    // check); macro is defined elsewhere in the crate.
    contract_pre_amdahl_speedup!();
    if a.len() != m * k || b.len() != k * n || c.len() != m * n {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }
    // Work estimate: one multiply-add per (i, j, p) triple.
    let flops = m * n * k;
    // Below this size, thread spawn/steal overhead outweighs any speedup.
    if flops < 8_000_000 {
        return gemm_blis(m, n, k, a, b, c, None);
    }
    let phys_cores = num_cpus::get_physical();
    // Tiered cap: larger problems may use more threads, but never more than
    // the number of physical cores.
    let max_threads = if flops < 64_000_000 {
        2.min(phys_cores)
    } else if flops < 512_000_000 {
        4.min(phys_cores)
    } else if flops < 4_000_000_000 {
        8.min(phys_cores)
    } else {
        // Very large GEMMs: at least 8 threads, at most all physical cores.
        (phys_cores / 2).max(8).min(phys_cores)
    };
    let mut scheduler = HeijunkaScheduler::default();
    scheduler.num_threads = scheduler.num_threads.min(max_threads);
    // Partition step: whole MC blocks normally; for short matrices split
    // finer so every thread gets work, but never below the MR tile height.
    let ps = if m <= MC { MR.max(m / scheduler.num_threads) } else { MC };
    let partitions = scheduler.partition_m(m, ps);
    // Smuggle the C base pointer across the closure as an integer so the
    // closure is Send; each task only touches its own row range.
    let c_ptr = c.as_mut_ptr() as usize;
    partitions.into_par_iter().for_each(|m_range| {
        let m_local = m_range.len();
        let m_start = m_range.start;
        let a_local = &a[m_start * k..(m_start + m_local) * k];
        // SAFETY: `partition_m` returns pairwise-disjoint, in-bounds row
        // ranges, so every task gets an exclusive, non-overlapping sub-slice
        // of `c`.
        let c_local = unsafe {
            let ptr = c_ptr as *mut f32;
            std::slice::from_raw_parts_mut(ptr.add(m_start * n), m_local * n)
        };
        let _ = gemm_blis(m_local, n, k, a_local, b, c_local, None);
    });
    // Crate-level contract hook (post-condition side of the Amdahl check).
    contract_post_amdahl_speedup!(c);
    Ok(())
}
/// Parallel GEMM that packs each B block once and shares the packed panel,
/// read-only, across all worker threads; each thread packs its own A panels
/// into thread-local storage.
///
/// Uses the AVX-512 8x32 micro-kernel on x86_64 when available; edge tiles
/// (and all tiles on non-x86_64 targets) use a scalar fallback. Small
/// problems and machines without AVX-512F delegate to the serial `gemm_blis`.
///
/// # Errors
/// Returns `TruenoError::InvalidInput` when the `a`/`b`/`c` slice lengths
/// disagree with `m`, `n`, `k`.
#[cfg(feature = "parallel")]
pub fn gemm_blis_parallel_shared_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;
    if a.len() != m * k || b.len() != k * n || c.len() != m * n {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }
    let flops = m * n * k;
    // Small problems: packing + threading overhead is not worth it.
    if flops < 8_000_000 {
        return gemm_blis(m, n, k, a, b, c, None);
    }
    // The 8x32 micro-kernel below requires AVX-512F.
    #[cfg(target_arch = "x86_64")]
    if !std::arch::is_x86_feature_detected!("avx512f") {
        return gemm_blis(m, n, k, a, b, c, None);
    }
    let phys_cores = num_cpus::get_physical();
    // Tiered thread cap by problem size, never exceeding physical cores.
    let max_threads = if flops < 64_000_000 {
        2.min(phys_cores)
    } else if flops < 512_000_000 {
        4.min(phys_cores)
    } else {
        // Collapsed two byte-identical arms (`flops < 4_000_000_000` and the
        // final `else`) into one; behavior is unchanged. NOTE(review): the
        // sibling `gemm_blis_parallel` uses `8.min(phys_cores)` for the
        // 512M..4G tier — confirm which tier table is intended here.
        (phys_cores / 2).max(8).min(phys_cores)
    };
    // Cache-aware blocking parameters for the 8x32 kernel.
    let blk = super::cache_topology::blocking_8x32();
    let mr = blk.mr;
    let nr = blk.nr;
    let mc = blk.mc.min(m);
    let nc = blk.nc.min(n);
    let kc = blk.kc;
    // One shared buffer holds the packed B panel for the current (jc, pc).
    let b_panels = (nc + nr - 1) / nr;
    let packed_b_size = b_panels * nr * kc;
    let mut packed_b = vec![0.0f32; packed_b_size];
    // C base pointer as usize so the rayon closure is Send; each (tid, tile)
    // pair writes a distinct region of C.
    let c_ptr = c.as_mut_ptr() as usize;
    let num_threads = max_threads.min(rayon::current_num_threads());
    for jc in (0..n).step_by(nc) {
        let nc_block = nc.min(n - jc);
        for pc in (0..k).step_by(kc) {
            let kc_block = kc.min(k - pc);
            // Pack B once, serially; all threads then read it.
            super::compute::pack_b_block_generic(
                b,
                n,
                pc,
                jc,
                kc_block,
                nc_block,
                nr,
                &mut packed_b,
            );
            let shared_b: &[f32] = &packed_b;
            // Rows per thread, rounded up to a whole number of MR tiles so
            // micro-panel boundaries never straddle two threads.
            let m_per_thread = ((m + num_threads - 1) / num_threads + mr - 1) / mr * mr;
            (0..num_threads).into_par_iter().for_each(|tid| {
                let ic_start = tid * m_per_thread;
                if ic_start >= m {
                    return; // more threads than row chunks
                }
                let ic_end = (ic_start + m_per_thread).min(m);
                // Thread-local A-panel buffer, reused across blocks to avoid
                // re-allocating on every (jc, pc) iteration.
                thread_local! {
                    static TL_A: std::cell::RefCell<Vec<f32>> =
                        const { std::cell::RefCell::new(Vec::new()) };
                }
                TL_A.with(|tl| {
                    let a_panels = (m_per_thread + mr - 1) / mr;
                    let needed = a_panels * mr * kc_block;
                    let mut packed_a = tl.borrow_mut();
                    if packed_a.len() < needed {
                        packed_a.resize(needed, 0.0);
                    }
                    let panels_n = (nc_block + nr - 1) / nr;
                    for ic in (ic_start..ic_end).step_by(mc) {
                        let mc_block = mc.min(ic_end - ic);
                        super::packing::pack_a_block(
                            a,
                            k,
                            ic,
                            pc,
                            mc_block,
                            kc_block,
                            &mut packed_a,
                        );
                        let panels_m = (mc_block + mr - 1) / mr;
                        for ir_panel in 0..panels_m {
                            let ir = ir_panel * mr;
                            let mr_block = mr.min(mc_block - ir);
                            for jr_panel in 0..panels_n {
                                let jr = jr_panel * nr;
                                let nr_block = nr.min(nc_block - jr);
                                let a_panel = &packed_a[ir_panel * mr * kc_block..];
                                let b_panel = &shared_b[jr_panel * nr * kc_block..];
                                // Bug fix: gate the fast path on the target
                                // architecture. Previously a full 8x32 tile
                                // on a non-x86_64 build selected this branch
                                // whose body is entirely cfg'd out, so the
                                // tile was silently never computed. With the
                                // `cfg!` guard such tiles take the scalar
                                // path instead; x86_64 behavior is unchanged.
                                if cfg!(target_arch = "x86_64")
                                    && mr_block == 8
                                    && nr_block == 32
                                {
                                    // SAFETY: AVX-512F was verified above;
                                    // the packed panels hold at least
                                    // kc_block*8 (A) and kc_block*32 (B)
                                    // elements, and the C tile starting at
                                    // (ic + ir, jc + jr) is written only by
                                    // this thread for this (jc, pc) block.
                                    #[cfg(target_arch = "x86_64")]
                                    unsafe {
                                        super::compute::avx512_microkernel_8x32_rowmajor(
                                            kc_block,
                                            a_panel.as_ptr(),
                                            b_panel.as_ptr(),
                                            (c_ptr as *mut f32).add((ic + ir) * n + (jc + jr)),
                                            n,
                                        );
                                    }
                                } else {
                                    // Scalar fallback for partial edge tiles
                                    // (and all tiles on non-x86_64 targets).
                                    for ir_local in 0..mr_block {
                                        for jr_local in 0..nr_block {
                                            let mut sum = 0.0f32;
                                            for p in 0..kc_block {
                                                sum += a_panel[p * mr + ir_local]
                                                    * b_panel[p * nr + jr_local];
                                            }
                                            // SAFETY: the index is in bounds
                                            // (row < m, col < n) and this C
                                            // element is exclusive to this
                                            // thread for this block.
                                            unsafe {
                                                let c = c_ptr as *mut f32;
                                                *c.add(
                                                    (ic + ir + ir_local) * n
                                                        + (jc + jr + jr_local),
                                                ) += sum;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                });
            });
        }
    }
    Ok(())
}
/// Serial stand-in for `gemm_blis_parallel` when the `parallel` feature is
/// disabled: delegates directly to the single-threaded BLIS-style kernel so
/// callers compile unchanged either way.
#[cfg(not(feature = "parallel"))]
pub fn gemm_blis_parallel(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    gemm_blis(m, n, k, a, b, c, None)
}
/// Parallel GEMM against a pre-packed B matrix: the M dimension is split
/// into disjoint row ranges and each range runs the serial pre-packed
/// kernel on its own slice of A and C.
///
/// # Errors
/// Returns `TruenoError::InvalidInput` when `a`/`c` lengths disagree with
/// `m`, `n`, `k`, or when `prepacked_b` was packed for different dimensions.
#[cfg(feature = "parallel")]
pub fn gemm_blis_parallel_with_prepacked_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    prepacked_b: &PrepackedB,
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;
    // Validate A and C against the requested dimensions.
    if a.len() != m * k || c.len() != m * n {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }
    // The packed B must have been produced for exactly this (k, n).
    if prepacked_b.k != k || prepacked_b.n != n {
        return Err(TruenoError::InvalidInput(format!(
            "PrepackedB dimension mismatch: expected ({}, {}), got ({}, {})",
            k, n, prepacked_b.k, prepacked_b.n
        )));
    }
    // Small problems: threading overhead dominates, stay serial.
    if m * n * k < 1_000_000 {
        return gemm_blis_with_prepacked_b(m, n, k, a, prepacked_b, c, None);
    }
    let row_ranges = HeijunkaScheduler::default().partition_m(m, MC);
    // Pass the C base pointer through the closure as an integer so rayon
    // does not see a captured `&mut` borrow.
    let c_base = c.as_mut_ptr() as usize;
    row_ranges.into_par_iter().for_each(|rows| {
        let first = rows.start;
        let height = rows.end - rows.start;
        let a_block = &a[first * k..rows.end * k];
        // SAFETY: `partition_m` yields pairwise-disjoint, in-bounds row
        // ranges, so each task gets an exclusive slice of C.
        let c_block = unsafe {
            std::slice::from_raw_parts_mut((c_base as *mut f32).add(first * n), height * n)
        };
        let _ = gemm_blis_with_prepacked_b(height, n, k, a_block, prepacked_b, c_block, None);
    });
    Ok(())
}
/// Serial stand-in for `gemm_blis_parallel_with_prepacked_b` when the
/// `parallel` feature is disabled: delegates directly to the
/// single-threaded pre-packed kernel so callers compile unchanged.
#[cfg(not(feature = "parallel"))]
pub fn gemm_blis_parallel_with_prepacked_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    prepacked_b: &PrepackedB,
    c: &mut [f32],
) -> Result<(), TruenoError> {
    gemm_blis_with_prepacked_b(m, n, k, a, prepacked_b, c, None)
}