use crate::error::TruenoError;
use super::compute::{gemm_blis, gemm_blis_with_prepacked_b};
use super::prepacked::PrepackedB;
#[cfg(feature = "parallel")]
use super::MC;
/// Static row partitioner for GEMM: splits the `M` dimension into per-thread
/// ranges made of whole `mc`-row blocks, dealing the blocks out as evenly as
/// possible (every range gets `blocks / threads` blocks, the first
/// `blocks % threads` ranges get one more).
#[derive(Debug, Clone)]
pub struct HeijunkaScheduler {
    /// Maximum number of ranges `partition_m` produces. A value of 0 is
    /// treated as 1.
    pub num_threads: usize,
    /// Load-variance tolerance. Not read by `partition_m`;
    /// NOTE(review): confirm where (if anywhere) this is consumed.
    pub variance_threshold: f32,
}

impl Default for HeijunkaScheduler {
    /// Uses the current rayon pool size when the `parallel` feature is
    /// enabled (one thread otherwise) and a variance threshold of `0.05`.
    fn default() -> Self {
        #[cfg(feature = "parallel")]
        let threads = rayon::current_num_threads();
        #[cfg(not(feature = "parallel"))]
        let threads = 1;
        Self {
            num_threads: threads,
            variance_threshold: 0.05,
        }
    }
}

impl HeijunkaScheduler {
    /// Splits rows `0..m` into at most `num_threads` contiguous row ranges.
    ///
    /// Rows are grouped into blocks of `mc` rows (the last block may be
    /// shorter) and the blocks are dealt out so the first
    /// `blocks % num_threads` ranges receive one extra block. The returned
    /// ranges are non-empty, ascending, adjacent (each starts where the
    /// previous one ends), have interior boundaries on multiples of `mc`, and
    /// together cover exactly `0..m`. An empty vector is returned for `m == 0`.
    ///
    /// `num_threads == 0` and `mc == 0` are each treated as 1 rather than
    /// panicking with a division by zero.
    pub fn partition_m(&self, m: usize, mc: usize) -> Vec<std::ops::Range<usize>> {
        let threads = self.num_threads.max(1);
        let mc = mc.max(1);
        // Ceiling division without the `m + mc - 1` overflow.
        let num_blocks = m / mc + usize::from(m % mc != 0);
        let blocks_per_thread = num_blocks / threads;
        let remainder = num_blocks % threads;

        let mut partitions = Vec::with_capacity(threads.min(num_blocks));
        let mut start_block = 0;
        for t in 0..threads {
            // The first `remainder` ranges take one extra block.
            let thread_blocks = blocks_per_thread + usize::from(t < remainder);
            if thread_blocks == 0 {
                // Fewer blocks than threads: every later range would be empty.
                break;
            }
            // `start_block < num_blocks`, so this is `< m` and cannot overflow.
            let start_row = start_block * mc;
            let end_row = (start_block + thread_blocks).saturating_mul(mc).min(m);
            partitions.push(start_row..end_row);
            start_block += thread_blocks;
        }
        partitions
    }
}
/// Multi-threaded, row-partitioned GEMM built on [`gemm_blis`].
///
/// `a` is read as `m` rows of `k` contiguous values and `c` as `m` rows of `n`
/// contiguous values. The rows of `c` are split into disjoint ranges by
/// [`HeijunkaScheduler::partition_m`] (in multiples of `MC` rows). Each range
/// is handed to [`gemm_blis`] on the rayon pool together with the matching
/// rows of `a` and all of `b`. Problems with `m * n * k < 1_000_000` skip the
/// split and call [`gemm_blis`] once on the calling thread.
///
/// # Errors
///
/// * [`TruenoError::InvalidInput`] (`"Dimension mismatch"`) when `a.len()`,
///   `b.len()` or `c.len()` is not `m * k`, `k * n` or `m * n` respectively,
///   including when one of those products overflows `usize`.
/// * Any error returned by a per-range [`gemm_blis`] call. Ranges run
///   concurrently, so on error other ranges of `c` may already be written.
#[cfg(feature = "parallel")]
pub fn gemm_blis_parallel(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    b: &[f32],
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;

    // Checked products: a wrapped product could match a short slice and make
    // the per-range slicing below go out of bounds.
    let dims_ok = m.checked_mul(k) == Some(a.len())
        && k.checked_mul(n) == Some(b.len())
        && m.checked_mul(n) == Some(c.len());
    if !dims_ok {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }
    // Too little work to amortise thread dispatch. Saturate so huge shapes
    // cannot wrap around into this branch.
    if m.saturating_mul(n).saturating_mul(k) < 1_000_000 {
        return gemm_blis(m, n, k, a, b, c, None);
    }

    let partitions = HeijunkaScheduler::default().partition_m(m, MC);

    // Carve `c` into one disjoint mutable slice per row range so each worker
    // owns its output rows without raw-pointer aliasing.
    let mut jobs: Vec<(std::ops::Range<usize>, &mut [f32])> =
        Vec::with_capacity(partitions.len());
    let mut rest: &mut [f32] = c;
    let mut next_row = 0;
    for rows in partitions {
        // Ranges are ascending; skip any gap before this one (none today).
        let (_, tail) = std::mem::take(&mut rest).split_at_mut((rows.start - next_row) * n);
        let (c_rows, tail) = tail.split_at_mut(rows.len() * n);
        next_row = rows.end;
        jobs.push((rows, c_rows));
        rest = tail;
    }

    // Propagate kernel failures instead of silently reporting success.
    jobs.into_par_iter().try_for_each(|(rows, c_local)| {
        let a_local = &a[rows.start * k..rows.end * k];
        gemm_blis(rows.len(), n, k, a_local, b, c_local, None)
    })
}
#[cfg(not(feature = "parallel"))]
pub fn gemm_blis_parallel(
m: usize,
n: usize,
k: usize,
a: &[f32],
b: &[f32],
c: &mut [f32],
) -> Result<(), TruenoError> {
gemm_blis(m, n, k, a, b, c, None)
}
/// Multi-threaded, row-partitioned GEMM against an already packed `B`, built
/// on [`gemm_blis_with_prepacked_b`].
///
/// `a` is read as `m` rows of `k` contiguous values and `c` as `m` rows of `n`
/// contiguous values. The rows of `c` are split into disjoint ranges by
/// [`HeijunkaScheduler::partition_m`] (in multiples of `MC` rows). Each range
/// is processed on the rayon pool with the matching rows of `a` and the shared
/// `prepacked_b`. Problems with `m * n * k < 1_000_000` run as a single
/// [`gemm_blis_with_prepacked_b`] call on the calling thread.
///
/// # Errors
///
/// * [`TruenoError::InvalidInput`] (`"Dimension mismatch"`) when `a.len()` or
///   `c.len()` is not `m * k` or `m * n` respectively, including on overflow.
/// * [`TruenoError::InvalidInput`] when `prepacked_b` was packed for a
///   different `(k, n)`.
/// * Any error returned by a per-range [`gemm_blis_with_prepacked_b`] call.
///   Ranges run concurrently, so on error other ranges of `c` may already be
///   written.
#[cfg(feature = "parallel")]
pub fn gemm_blis_parallel_with_prepacked_b(
    m: usize,
    n: usize,
    k: usize,
    a: &[f32],
    prepacked_b: &PrepackedB,
    c: &mut [f32],
) -> Result<(), TruenoError> {
    use rayon::prelude::*;

    // Checked products: a wrapped product could match a short slice and make
    // the per-range slicing below go out of bounds.
    let dims_ok = m.checked_mul(k) == Some(a.len()) && m.checked_mul(n) == Some(c.len());
    if !dims_ok {
        return Err(TruenoError::InvalidInput("Dimension mismatch".to_string()));
    }
    if prepacked_b.k != k || prepacked_b.n != n {
        return Err(TruenoError::InvalidInput(format!(
            "PrepackedB dimension mismatch: expected ({}, {}), got ({}, {})",
            k, n, prepacked_b.k, prepacked_b.n
        )));
    }
    // Too little work to amortise thread dispatch. Saturate so huge shapes
    // cannot wrap around into this branch.
    if m.saturating_mul(n).saturating_mul(k) < 1_000_000 {
        return gemm_blis_with_prepacked_b(m, n, k, a, prepacked_b, c, None);
    }

    let partitions = HeijunkaScheduler::default().partition_m(m, MC);

    // Carve `c` into one disjoint mutable slice per row range so each worker
    // owns its output rows without raw-pointer aliasing.
    let mut jobs: Vec<(std::ops::Range<usize>, &mut [f32])> =
        Vec::with_capacity(partitions.len());
    let mut rest: &mut [f32] = c;
    let mut next_row = 0;
    for rows in partitions {
        // Ranges are ascending; skip any gap before this one (none today).
        let (_, tail) = std::mem::take(&mut rest).split_at_mut((rows.start - next_row) * n);
        let (c_rows, tail) = tail.split_at_mut(rows.len() * n);
        next_row = rows.end;
        jobs.push((rows, c_rows));
        rest = tail;
    }

    // Propagate kernel failures instead of silently reporting success.
    jobs.into_par_iter().try_for_each(|(rows, c_local)| {
        let a_local = &a[rows.start * k..rows.end * k];
        gemm_blis_with_prepacked_b(rows.len(), n, k, a_local, prepacked_b, c_local, None)
    })
}
#[cfg(not(feature = "parallel"))]
pub fn gemm_blis_parallel_with_prepacked_b(
m: usize,
n: usize,
k: usize,
a: &[f32],
prepacked_b: &PrepackedB,
c: &mut [f32],
) -> Result<(), TruenoError> {
gemm_blis_with_prepacked_b(m, n, k, a, prepacked_b, c, None)
}