#[cfg(not(target_arch = "wasm32"))]
use rayon::prelude::*;
use crate::dispatch::KernelDispatcher;
use crate::error::{KernelError, KernelResult};
#[cfg(not(target_arch = "wasm32"))]
use crate::tiled::{optimal_tile_rows, L2_TILE_ROWS};
use crate::traits::OneBitKernel;
use crate::traits::TernaryKernel;
use oxibonsai_core::tensor::{BlockQ1_0G128, QK1_0_G128};
// Dispatch thresholds shared by the parallel, adaptive and configuration code in this module.

/// Largest row count (inclusive) for which `select_gemv_strategy` picks
/// `AdaptiveStrategy::Direct`.
const DIRECT_DISPATCH_MAX_ROWS: usize = 32;

/// Largest row count (inclusive) for which `select_gemv_strategy` picks
/// `AdaptiveStrategy::ParallelRow`; anything larger maps to `ParallelTiled`.
const MEDIUM_PARALLEL_MAX_ROWS: usize = 256;

/// Minimum number of output rows before `gemv_parallel_tiled` stops delegating to the
/// sequential tiled kernel; also the default `ParallelConfig::gemv_threshold`.
const PAR_TILED_MIN_ROWS: usize = 128;

/// Minimum batch size `m` before `gemm_parallel_tiled` / `gemm_adaptive_ternary` take their
/// parallel path; also the default `ParallelConfig::gemm_threshold`.
const PAR_TILED_MIN_BATCH: usize = 4;
/// Validates the arguments of a single-vector 1-bit GEMV and returns the number of
/// `BlockQ1_0G128` blocks that make up one weight row (`k / QK1_0_G128`).
///
/// # Errors
///
/// * `KernelError::NotBlockAligned` if `k` is not a multiple of `QK1_0_G128`.
/// * `KernelError::DimensionMismatch` if `input` holds fewer than `k` values.
/// * `KernelError::BufferTooSmall` if `output` holds fewer than `n_rows` values, or if
///   `blocks` holds fewer than `n_rows * blocks_per_row` blocks.
fn validate_gemv(
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &[f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<usize> {
    if k % QK1_0_G128 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK1_0_G128,
        });
    }
    if input.len() < k {
        return Err(KernelError::DimensionMismatch {
            expected: k,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }
    let blocks_per_row = k / QK1_0_G128;
    // Saturate instead of wrapping (release) or panicking (debug). On 32-bit targets such
    // as wasm32 a wrapped product can be small enough for an undersized `blocks` slice to
    // pass. A saturated `usize::MAX` cannot be satisfied by any real slice of non-zero-sized
    // blocks, so such shapes are reported as `BufferTooSmall`.
    let expected_blocks = n_rows.saturating_mul(blocks_per_row);
    if blocks.len() < expected_blocks {
        return Err(KernelError::BufferTooSmall {
            needed: expected_blocks,
            available: blocks.len(),
        });
    }
    Ok(blocks_per_row)
}
/// Validates the arguments of a batched 1-bit GEMM and returns the number of
/// `BlockQ1_0G128` blocks that make up one weight row (`k / QK1_0_G128`).
///
/// `input` must hold at least `m * k` values and `output` at least `m * n_rows` values.
///
/// # Errors
///
/// * `KernelError::NotBlockAligned` if `k` is not a multiple of `QK1_0_G128`.
/// * `KernelError::DimensionMismatch` if `input` holds fewer than `m * k` values.
/// * `KernelError::BufferTooSmall` if `output` holds fewer than `m * n_rows` values, or if
///   `blocks` holds fewer than `n_rows * blocks_per_row` blocks.
///
/// All size products saturate instead of wrapping. A wrapped product (reachable on 32-bit
/// targets such as wasm32) could otherwise let undersized buffers pass validation. A
/// saturated `usize::MAX` requirement can never be met, so the shape is rejected.
fn validate_gemm(
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &[f32],
    m: usize,
    n_rows: usize,
    k: usize,
) -> KernelResult<usize> {
    if k % QK1_0_G128 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK1_0_G128,
        });
    }
    let input_needed = m.saturating_mul(k);
    if input.len() < input_needed {
        return Err(KernelError::DimensionMismatch {
            expected: input_needed,
            got: input.len(),
        });
    }
    let output_needed = m.saturating_mul(n_rows);
    if output.len() < output_needed {
        return Err(KernelError::BufferTooSmall {
            needed: output_needed,
            available: output.len(),
        });
    }
    let blocks_per_row = k / QK1_0_G128;
    let expected_blocks = n_rows.saturating_mul(blocks_per_row);
    if blocks.len() < expected_blocks {
        return Err(KernelError::BufferTooSmall {
            needed: expected_blocks,
            available: blocks.len(),
        });
    }
    Ok(blocks_per_row)
}
/// Runs the dispatcher's 1-bit GEMV over `n_rows` weight rows, fanning the rows out over
/// rayon in `L2_TILE_ROWS`-row tiles.
///
/// Inside each tile, rows are processed in sub-tiles of `optimal_tile_rows(k)` rows (at least
/// one), with one `dispatcher.gemv` call per sub-tile. Fewer than `PAR_TILED_MIN_ROWS` rows,
/// or any wasm32 build, delegates to `crate::tiled::gemv_tiled` instead.
///
/// # Errors
///
/// Validation errors from `validate_gemv`, or an error returned by the dispatcher or by the
/// sequential fallback.
pub fn gemv_parallel_tiled(
    dispatcher: &KernelDispatcher,
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    // The block count is only consumed by the rayon path; the wasm32 binding exists purely
    // to run validation without an unused-variable warning.
    #[cfg(target_arch = "wasm32")]
    let _blocks_per_row = validate_gemv(blocks, input, output, n_rows, k)?;
    #[cfg(not(target_arch = "wasm32"))]
    let blocks_per_row = validate_gemv(blocks, input, output, n_rows, k)?;

    // Too few rows for the parallel split to pay off.
    if n_rows < PAR_TILED_MIN_ROWS {
        return crate::tiled::gemv_tiled(dispatcher, blocks, input, output, n_rows, k);
    }

    // wasm32 builds have no rayon path here; always use the sequential tiled kernel.
    #[cfg(target_arch = "wasm32")]
    {
        return crate::tiled::gemv_tiled(dispatcher, blocks, input, output, n_rows, k);
    }

    #[cfg(not(target_arch = "wasm32"))]
    {
        let sub_tile_rows = optimal_tile_rows(k).max(1);
        output[..n_rows]
            .par_chunks_mut(L2_TILE_ROWS)
            .enumerate()
            .try_for_each(|(tile_idx, tile_out)| -> KernelResult<()> {
                let first_row = tile_idx * L2_TILE_ROWS;
                let rows_in_tile = tile_out.len();
                let tile_blocks = &blocks
                    [first_row * blocks_per_row..(first_row + rows_in_tile) * blocks_per_row];
                for (sub_idx, sub_out) in tile_out.chunks_mut(sub_tile_rows).enumerate() {
                    let sub_first = sub_idx * sub_tile_rows;
                    let sub_rows = sub_out.len();
                    let sub_blocks = &tile_blocks
                        [sub_first * blocks_per_row..(sub_first + sub_rows) * blocks_per_row];
                    dispatcher.gemv(sub_blocks, input, sub_out, sub_rows, k)?;
                }
                Ok(())
            })
    }
}
/// Batched 1-bit GEMM over `m` input vectors of length `k` and `n_rows` weight rows.
///
/// `input` is read as `m` consecutive `k`-length vectors. `output` is written as `m`
/// consecutive `n_rows`-length rows, so batch index `mi` fills `output[mi * n_rows..]`.
/// Batches smaller than `PAR_TILED_MIN_BATCH`, and every wasm32 build, are delegated to
/// `crate::tiled::gemm_tiled`. Otherwise each batch row becomes one rayon task. That task
/// calls `dispatcher.gemm` (with `m = 1`) once per L1 tile of `optimal_tile_rows(k)`
/// weight rows, with at least one row per tile.
///
/// # Errors
///
/// Validation errors from `validate_gemm`, or an error returned by the dispatcher or by the
/// sequential fallback.
pub fn gemm_parallel_tiled(
    dispatcher: &KernelDispatcher,
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &mut [f32],
    m: usize,
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    // The block count is only consumed by the rayon path; the wasm32 binding exists purely
    // to run validation without an unused-variable warning.
    #[cfg(not(target_arch = "wasm32"))]
    let blocks_per_row = validate_gemm(blocks, input, output, m, n_rows, k)?;
    #[cfg(target_arch = "wasm32")]
    let _blocks_per_row = validate_gemm(blocks, input, output, m, n_rows, k)?;

    // Small batches do not amortise the rayon scheduling overhead.
    if m < PAR_TILED_MIN_BATCH {
        return crate::tiled::gemm_tiled(dispatcher, blocks, input, output, m, n_rows, k);
    }

    // wasm32 builds have no rayon path here; always use the sequential tiled kernel.
    #[cfg(target_arch = "wasm32")]
    {
        return crate::tiled::gemm_tiled(dispatcher, blocks, input, output, m, n_rows, k);
    }

    #[cfg(not(target_arch = "wasm32"))]
    {
        // rayon's `par_chunks_mut` panics when the chunk size is 0. With no weight rows there
        // is no output element to compute (validation already succeeded), so finish here.
        if n_rows == 0 {
            return Ok(());
        }
        let l1_tile = optimal_tile_rows(k).max(1);
        output[..m * n_rows]
            .par_chunks_mut(n_rows)
            .enumerate()
            .try_for_each(|(mi, out_row)| -> KernelResult<()> {
                let x = &input[mi * k..(mi + 1) * k];
                for (tile_idx, out_tile) in out_row.chunks_mut(l1_tile).enumerate() {
                    let row_start = tile_idx * l1_tile;
                    let tile_rows = out_tile.len();
                    let tile_blocks = &blocks
                        [row_start * blocks_per_row..(row_start + tile_rows) * blocks_per_row];
                    dispatcher.gemm(tile_blocks, x, out_tile, 1, tile_rows, k)?;
                }
                Ok(())
            })
    }
}
/// Execution strategy chosen by `select_gemv_strategy` for a GEMV with a given row count.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AdaptiveStrategy {
    /// A single call straight into the dispatcher, with no thread fan-out.
    Direct,
    /// The row-parallel kernels in `crate::parallel`.
    ParallelRow,
    /// The rayon-parallel, L2/L1-tiled kernel (`gemv_parallel_tiled`).
    ParallelTiled,
}
/// Picks a GEMV execution strategy from the number of output rows alone.
///
/// * `n_rows <= DIRECT_DISPATCH_MAX_ROWS` gives `Direct`.
/// * `n_rows <= MEDIUM_PARALLEL_MAX_ROWS` gives `ParallelRow`.
/// * Anything larger gives `ParallelTiled`.
///
/// The `_k` argument is accepted for API symmetry but is currently ignored.
pub fn select_gemv_strategy(n_rows: usize, _k: usize) -> AdaptiveStrategy {
    match n_rows {
        rows if rows <= DIRECT_DISPATCH_MAX_ROWS => AdaptiveStrategy::Direct,
        rows if rows <= MEDIUM_PARALLEL_MAX_ROWS => AdaptiveStrategy::ParallelRow,
        _ => AdaptiveStrategy::ParallelTiled,
    }
}
/// Runs a 1-bit GEMV through whichever path `select_gemv_strategy` picks for `n_rows`.
///
/// * `Direct` calls `dispatcher.gemv`.
/// * `ParallelRow` calls `crate::parallel::gemv_1bit_g128_par`.
/// * `ParallelTiled` calls `gemv_parallel_tiled`.
///
/// # Errors
///
/// Whatever the selected path returns.
pub fn gemv_adaptive(
    dispatcher: &KernelDispatcher,
    blocks: &[BlockQ1_0G128],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    let strategy = select_gemv_strategy(n_rows, k);
    match strategy {
        AdaptiveStrategy::ParallelTiled => {
            gemv_parallel_tiled(dispatcher, blocks, input, output, n_rows, k)
        }
        AdaptiveStrategy::ParallelRow => {
            crate::parallel::gemv_1bit_g128_par(dispatcher, blocks, input, output, n_rows, k)
        }
        AdaptiveStrategy::Direct => dispatcher.gemv(blocks, input, output, n_rows, k),
    }
}
/// Ternary (TQ2_0_g128) counterpart of `gemv_adaptive`.
///
/// `Direct` calls `dispatcher.gemv_ternary_g128`. Both parallel strategies share
/// `crate::parallel::gemv_ternary_g128_par`.
///
/// # Errors
///
/// Whatever the selected path returns.
pub fn gemv_adaptive_ternary(
    dispatcher: &KernelDispatcher,
    blocks: &[oxibonsai_core::BlockTQ2_0_g128],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    let strategy = select_gemv_strategy(n_rows, k);
    match strategy {
        AdaptiveStrategy::ParallelRow | AdaptiveStrategy::ParallelTiled => {
            crate::parallel::gemv_ternary_g128_par(dispatcher, blocks, input, output, n_rows, k)
        }
        AdaptiveStrategy::Direct => {
            dispatcher.gemv_ternary_g128(blocks, input, output, n_rows, k)
        }
    }
}
/// Ternary (TQ2_0_g128) batched GEMM.
///
/// Batches of at least `PAR_TILED_MIN_BATCH` vectors go to
/// `crate::parallel::gemm_ternary_g128_par`. Smaller batches call
/// `dispatcher.gemm_ternary_g128` directly.
///
/// # Errors
///
/// Whatever the selected path returns.
pub fn gemm_adaptive_ternary(
    dispatcher: &KernelDispatcher,
    blocks: &[oxibonsai_core::BlockTQ2_0_g128],
    input: &[f32],
    output: &mut [f32],
    m: usize,
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    let batch_is_large = m >= PAR_TILED_MIN_BATCH;
    if batch_is_large {
        return crate::parallel::gemm_ternary_g128_par(
            dispatcher, blocks, input, output, m, n_rows, k,
        );
    }
    dispatcher.gemm_ternary_g128(blocks, input, output, m, n_rows, k)
}
/// Tunable knobs describing when parallel execution should be preferred.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Worker-thread count assumed to be available; parallelism is only suggested when > 1.
    pub num_threads: usize,
    /// Row count at or above which `should_parallelize_gemv` answers `true`.
    pub gemv_threshold: usize,
    /// Batch size at or above which `should_parallelize_gemm` answers `true`.
    pub gemm_threshold: usize,
    /// Tiling preference flag.
    /// NOTE(review): not read by any code in this file's visible section; confirm its consumer.
    pub use_tiling: bool,
}

impl Default for ParallelConfig {
    /// Takes the thread count from rayon's current pool (fixed at 1 on wasm32), uses the
    /// module's `PAR_TILED_MIN_ROWS` / `PAR_TILED_MIN_BATCH` thresholds, and enables tiling.
    fn default() -> Self {
        #[cfg(target_arch = "wasm32")]
        let num_threads = 1usize;
        #[cfg(not(target_arch = "wasm32"))]
        let num_threads = rayon::current_num_threads();
        Self {
            use_tiling: true,
            gemm_threshold: PAR_TILED_MIN_BATCH,
            gemv_threshold: PAR_TILED_MIN_ROWS,
            num_threads,
        }
    }
}

impl ParallelConfig {
    /// A configuration that never recommends parallel execution. It uses one thread,
    /// maximal thresholds, and tiling disabled.
    pub fn single_threaded() -> Self {
        Self {
            use_tiling: false,
            gemm_threshold: usize::MAX,
            gemv_threshold: usize::MAX,
            num_threads: 1,
        }
    }

    /// `true` when more than one thread is available and `n_rows >= gemv_threshold`.
    pub fn should_parallelize_gemv(&self, n_rows: usize) -> bool {
        self.meets_threshold(n_rows, self.gemv_threshold)
    }

    /// `true` when more than one thread is available and `m >= gemm_threshold`.
    pub fn should_parallelize_gemm(&self, m: usize) -> bool {
        self.meets_threshold(m, self.gemm_threshold)
    }

    /// Shared predicate: multiple threads available and `size` reaches `threshold`.
    fn meets_threshold(&self, size: usize, threshold: usize) -> bool {
        self.num_threads > 1 && size >= threshold
    }
}
#[cfg(test)]
mod tests {
    // Compares the parallel / adaptive entry points against direct `KernelDispatcher`
    // calls on small synthetic inputs. Also checks strategy selection, `ParallelConfig`
    // behaviour and argument validation.
    use super::*;
    use half::f16;

    // Builds one `BlockQ1_0G128` from an f32 scale (stored as f16 in `d`) and its 16 `qs` bytes.
    fn make_block(scale: f32, bits: [u8; 16]) -> BlockQ1_0G128 {
        BlockQ1_0G128 {
            d: f16::from_f32(scale),
            qs: bits,
        }
    }

    // Deterministic fixture with two parts. The first is `n_rows * (k / QK1_0_G128)` blocks
    // pushed row by row. Each block is filled with one byte derived from (row, block index)
    // and has scale `0.5 + row * 0.01`. The second is a `k`-long input ramp
    // `i * 0.01 - 1.28`.
    fn make_test_data(n_rows: usize, k: usize) -> (Vec<BlockQ1_0G128>, Vec<f32>) {
        let blocks_per_row = k / QK1_0_G128;
        let mut blocks = Vec::with_capacity(n_rows * blocks_per_row);
        for row in 0..n_rows {
            for bi in 0..blocks_per_row {
                let bits = [((row * 37 + bi * 13) & 0xFF) as u8; 16];
                blocks.push(make_block(0.5 + (row as f32) * 0.01, bits));
            }
        }
        let input: Vec<f32> = (0..k).map(|i| (i as f32 * 0.01) - 1.28).collect();
        (blocks, input)
    }

    // Builds one ternary block from its 32 `qs` bytes with scale `d = 1.0`.
    fn make_ternary_block(qs: [u8; 32]) -> oxibonsai_core::BlockTQ2_0_g128 {
        oxibonsai_core::BlockTQ2_0_g128 { qs, d: f16::ONE }
    }

    // 256 rows is >= PAR_TILED_MIN_ROWS, so (off wasm32) this exercises the rayon tiled path.
    // Results must agree with a single direct dispatcher call within 1e-4.
    #[test]
    fn parallel_tiled_gemv_matches_sequential() {
        let n_rows = 256;
        let k = 256;
        let (blocks, input) = make_test_data(n_rows, k);
        let dispatcher = KernelDispatcher::auto_detect();
        let mut out_seq = vec![0.0f32; n_rows];
        let mut out_par = vec![0.0f32; n_rows];
        dispatcher
            .gemv(&blocks, &input, &mut out_seq, n_rows, k)
            .expect("direct gemv should succeed");
        gemv_parallel_tiled(&dispatcher, &blocks, &input, &mut out_par, n_rows, k)
            .expect("parallel tiled gemv should succeed");
        for i in 0..n_rows {
            assert!(
                (out_seq[i] - out_par[i]).abs() < 1e-4,
                "row {i}: seq={}, par_tiled={}",
                out_seq[i],
                out_par[i]
            );
        }
    }

    // 16 rows is below PAR_TILED_MIN_ROWS, so `gemv_parallel_tiled` delegates to
    // `crate::tiled::gemv_tiled`. The result is expected to match the direct call within
    // f32::EPSILON.
    #[test]
    fn parallel_tiled_gemv_small_fallback() {
        let n_rows = 16;
        let k = 128;
        let (blocks, input) = make_test_data(n_rows, k);
        let dispatcher = KernelDispatcher::auto_detect();
        let mut out_seq = vec![0.0f32; n_rows];
        let mut out_par = vec![0.0f32; n_rows];
        dispatcher
            .gemv(&blocks, &input, &mut out_seq, n_rows, k)
            .expect("direct gemv should succeed");
        gemv_parallel_tiled(&dispatcher, &blocks, &input, &mut out_par, n_rows, k)
            .expect("fallback tiled gemv should succeed");
        for i in 0..n_rows {
            assert!(
                (out_seq[i] - out_par[i]).abs() < f32::EPSILON,
                "row {i}: seq={}, par={}",
                out_seq[i],
                out_par[i]
            );
        }
    }

    // m = 8 is >= PAR_TILED_MIN_BATCH, so (off wasm32) each batch row runs as its own rayon
    // task. The output is compared with one direct `dispatcher.gemm` call within 1e-3.
    #[test]
    fn parallel_tiled_gemm_matches_sequential() {
        let m = 8;
        let n_rows = 32;
        let k = 128;
        let blocks_per_row = k / QK1_0_G128;
        let mut blocks = Vec::new();
        for ni in 0..n_rows {
            for bi in 0..blocks_per_row {
                let bits = [((ni * 17 + bi * 7) & 0xFF) as u8; 16];
                blocks.push(make_block(1.0 + ni as f32 * 0.2, bits));
            }
        }
        let input: Vec<f32> = (0..m * k).map(|i| (i as f32 * 0.005) - 0.32).collect();
        let dispatcher = KernelDispatcher::auto_detect();
        let mut out_seq = vec![0.0f32; m * n_rows];
        let mut out_par = vec![0.0f32; m * n_rows];
        dispatcher
            .gemm(&blocks, &input, &mut out_seq, m, n_rows, k)
            .expect("direct gemm should succeed");
        gemm_parallel_tiled(&dispatcher, &blocks, &input, &mut out_par, m, n_rows, k)
            .expect("parallel tiled gemm should succeed");
        for i in 0..(m * n_rows) {
            assert!(
                (out_seq[i] - out_par[i]).abs() < 1e-3,
                "idx {i}: seq={}, par_tiled={}",
                out_seq[i],
                out_par[i]
            );
        }
    }

    // 16 <= DIRECT_DISPATCH_MAX_ROWS.
    #[test]
    fn adaptive_selects_direct_for_small() {
        let strategy = select_gemv_strategy(16, 128);
        assert_eq!(strategy, AdaptiveStrategy::Direct);
    }

    // 128 lies in (DIRECT_DISPATCH_MAX_ROWS, MEDIUM_PARALLEL_MAX_ROWS].
    #[test]
    fn adaptive_selects_parallel_row_for_medium() {
        let strategy = select_gemv_strategy(128, 256);
        assert_eq!(strategy, AdaptiveStrategy::ParallelRow);
    }

    // 512 > MEDIUM_PARALLEL_MAX_ROWS.
    #[test]
    fn adaptive_selects_parallel_tiled_for_large() {
        let strategy = select_gemv_strategy(512, 4096);
        assert_eq!(strategy, AdaptiveStrategy::ParallelTiled);
    }

    // 64 rows selects ParallelRow; the output must match a direct call within 1e-4.
    #[test]
    fn adaptive_gemv_matches_direct() {
        let n_rows = 64;
        let k = 256;
        let (blocks, input) = make_test_data(n_rows, k);
        let dispatcher = KernelDispatcher::auto_detect();
        let mut out_direct = vec![0.0f32; n_rows];
        let mut out_adaptive = vec![0.0f32; n_rows];
        dispatcher
            .gemv(&blocks, &input, &mut out_direct, n_rows, k)
            .expect("direct gemv should succeed");
        gemv_adaptive(&dispatcher, &blocks, &input, &mut out_adaptive, n_rows, k)
            .expect("adaptive gemv should succeed");
        for i in 0..n_rows {
            assert!(
                (out_direct[i] - out_adaptive[i]).abs() < 1e-4,
                "row {i}: direct={}, adaptive={}",
                out_direct[i],
                out_adaptive[i]
            );
        }
    }

    // 16 rows selects the Direct strategy. The test only asserts that the call returns Ok;
    // it does not check output values.
    #[test]
    fn adaptive_ternary_gemv_small_is_direct() -> KernelResult<()> {
        let n_rows = 16;
        let k = 128;
        let blocks_per_row = k / oxibonsai_core::QK_TQ2_0_G128;
        let blocks = vec![make_ternary_block([0xAAu8; 32]); n_rows * blocks_per_row];
        let input: Vec<f32> = (0..k).map(|i| (i as f32 * 0.01) - 1.28).collect();
        let dispatcher = KernelDispatcher::auto_detect();
        let mut output = vec![0.0f32; n_rows];
        gemv_adaptive_ternary(&dispatcher, &blocks, &input, &mut output, n_rows, k)
    }

    // 512 rows selects a parallel strategy (routed to the ternary row-parallel kernel).
    // The test only asserts that the call returns Ok.
    #[test]
    fn adaptive_ternary_gemv_large_is_parallel() -> KernelResult<()> {
        let n_rows = 512;
        let k = 128;
        let blocks_per_row = k / oxibonsai_core::QK_TQ2_0_G128;
        let blocks = vec![make_ternary_block([0xAAu8; 32]); n_rows * blocks_per_row];
        let input: Vec<f32> = (0..k).map(|i| (i as f32 * 0.01) - 1.28).collect();
        let dispatcher = KernelDispatcher::auto_detect();
        let mut output = vec![0.0f32; n_rows];
        gemv_adaptive_ternary(&dispatcher, &blocks, &input, &mut output, n_rows, k)
    }

    // The default config mirrors the module constants and enables tiling.
    #[test]
    fn parallel_config_default() {
        let config = ParallelConfig::default();
        assert!(config.num_threads >= 1);
        assert_eq!(config.gemv_threshold, PAR_TILED_MIN_ROWS);
        assert_eq!(config.gemm_threshold, PAR_TILED_MIN_BATCH);
        assert!(config.use_tiling);
    }

    // A single-threaded config never recommends parallelism, even for huge sizes.
    #[test]
    fn parallel_config_single_threaded() {
        let config = ParallelConfig::single_threaded();
        assert_eq!(config.num_threads, 1);
        assert!(!config.use_tiling);
        assert!(!config.should_parallelize_gemv(1_000_000));
        assert!(!config.should_parallelize_gemm(1_000_000));
    }

    // The thresholds are only asserted when the rayon pool has more than one thread, since
    // a single thread always answers `false`.
    #[test]
    fn parallel_config_threshold_checks() {
        let config = ParallelConfig::default();
        if config.num_threads > 1 {
            assert!(!config.should_parallelize_gemv(64));
            assert!(config.should_parallelize_gemv(256));
            assert!(!config.should_parallelize_gemm(2));
            assert!(config.should_parallelize_gemm(8));
        }
    }

    // k = 100 is not a multiple of QK1_0_G128 (presumably 128), so both entry points must
    // return an error. The input holds 128 values, so this is the alignment check, not a
    // length check.
    #[test]
    fn validation_errors_propagate() {
        let dispatcher = KernelDispatcher::auto_detect();
        let blocks = vec![make_block(1.0, [0xFF; 16])];
        let input = vec![1.0f32; 128];
        let mut output = vec![0.0f32; 1];
        let result = gemv_parallel_tiled(&dispatcher, &blocks, &input, &mut output, 1, 100);
        assert!(result.is_err());
        let result = gemm_parallel_tiled(&dispatcher, &blocks, &input, &mut output, 1, 1, 100);
        assert!(result.is_err());
    }
}