#[cfg(feature = "parallel")]
use rayon::prelude::*;
use crate::error::{QuantError, QuantResult};
use crate::types::QuantTensor;
/// Computes a quantized matrix-vector product, one output element per matrix row.
///
/// For every row `r` in `0..n_rows`, the row's packed bytes are sliced out of
/// `quant_matrix.data` and passed to `row_dot` together with the whole `input`
/// slice and the column count; the returned value is written to `output[r]`.
/// Rows are processed with rayon when the `parallel` feature is enabled and
/// sequentially otherwise.
///
/// The row count is `quant_matrix.shape[0]`. The column count is
/// `quant_matrix.shape[1]` when present, otherwise `n_elements() / n_rows`
/// (or `0` when there are no rows).
///
/// # Arguments
///
/// * `quant_matrix` - tensor whose `data` holds the rows back to back, each
///   row occupying `ceil(n_cols / block_size) * block_bytes` bytes.
/// * `input` - input vector; must hold at least `n_cols` values.
/// * `output` - destination; must hold at least `n_rows` values. Only the
///   first `n_rows` entries are written.
/// * `block_size` - number of elements covered by one quantization block.
/// * `block_bytes` - number of bytes occupied by one quantization block.
/// * `row_dot` - kernel called as `row_dot(row_bytes, input, n_cols)`.
///
/// # Errors
///
/// Returns [`QuantError::DimensionMismatch`] when
/// * the tensor shape is empty (`expected: 1, got: 0` dimensions),
/// * `input` is shorter than `n_cols`,
/// * `output` is shorter than `n_rows`, or
/// * `quant_matrix.data` is shorter than `n_rows * row_bytes`.
///
/// # Panics
///
/// Panics if `block_size` is zero.
pub fn parallel_gemv<F>(
    quant_matrix: &QuantTensor,
    input: &[f32],
    output: &mut [f32],
    block_size: usize,
    block_bytes: usize,
    row_dot: F,
) -> QuantResult<()>
where
    F: Fn(&[u8], &[f32], usize) -> f32 + Send + Sync,
{
    // A rank-0 tensor has no row dimension to iterate over.
    let Some(&n_rows) = quant_matrix.shape.first() else {
        return Err(QuantError::DimensionMismatch {
            expected: 1,
            got: 0,
        });
    };
    let n_cols = match quant_matrix.shape.get(1) {
        Some(&cols) => cols,
        // Avoid dividing by zero for an empty 1-D tensor.
        None if n_rows == 0 => 0,
        None => quant_matrix.n_elements() / n_rows,
    };

    if input.len() < n_cols {
        return Err(QuantError::DimensionMismatch {
            expected: n_cols,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(QuantError::DimensionMismatch {
            expected: n_rows,
            got: output.len(),
        });
    }

    let blocks_per_row = n_cols.div_ceil(block_size);
    // Saturating arithmetic: an overflowing size can never be satisfied by a
    // real buffer, so it falls through to the length error below.
    let row_bytes = blocks_per_row.saturating_mul(block_bytes);
    let data = &quant_matrix.data;
    let required_bytes = n_rows.saturating_mul(row_bytes);
    if data.len() < required_bytes {
        return Err(QuantError::DimensionMismatch {
            expected: required_bytes,
            got: data.len(),
        });
    }

    // Only the iterator differs between the two builds; the per-row work is shared.
    #[cfg(feature = "parallel")]
    let rows = output[..n_rows].par_iter_mut();
    #[cfg(not(feature = "parallel"))]
    let rows = output[..n_rows].iter_mut();

    rows.enumerate().for_each(|(row, out)| {
        // In bounds: `row < n_rows` and `n_rows * row_bytes <= data.len()`.
        let row_start = row * row_bytes;
        *out = row_dot(&data[row_start..row_start + row_bytes], input, n_cols);
    });

    Ok(())
}
/// Problem dimensions and block layout consumed by `parallel_gemm`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GemmDims {
    /// Number of input rows (batch rows); one output row is produced per input row.
    pub m: usize,
    /// Number of weight rows, which is also the length of each output row.
    pub n: usize,
    /// Number of `f32` values in each input row; passed to the row kernel as the column count.
    pub k: usize,
    /// Number of elements covered by one quantization block.
    pub block_size: usize,
    /// Number of bytes occupied by one quantization block.
    pub block_bytes: usize,
}
pub fn parallel_gemm<F>(
quant_matrix: &QuantTensor,
input: &[f32],
output: &mut [f32],
dims: &GemmDims,
row_dot: F,
) -> QuantResult<()>
where
F: Fn(&[u8], &[f32], usize) -> f32 + Send + Sync,
{
let blocks_per_row = dims.k.div_ceil(dims.block_size);
let weight_row_bytes = blocks_per_row * dims.block_bytes;
let data = &quant_matrix.data;
#[cfg(feature = "parallel")]
{
output
.par_chunks_mut(dims.n)
.enumerate()
.take(dims.m)
.for_each(|(batch_row, out_row)| {
let inp_row = &input[batch_row * dims.k..(batch_row + 1) * dims.k];
for (weight_row, out) in out_row.iter_mut().enumerate().take(dims.n) {
let row_start = weight_row * weight_row_bytes;
let row_data = &data[row_start..row_start + weight_row_bytes];
*out = row_dot(row_data, inp_row, dims.k);
}
});
}
#[cfg(not(feature = "parallel"))]
{
output
.chunks_mut(dims.n)
.enumerate()
.take(dims.m)
.for_each(|(batch_row, out_row)| {
let inp_row = &input[batch_row * dims.k..(batch_row + 1) * dims.k];
for (weight_row, out) in out_row.iter_mut().enumerate().take(dims.n) {
let row_start = weight_row * weight_row_bytes;
let row_data = &data[row_start..row_start + weight_row_bytes];
*out = row_dot(row_data, inp_row, dims.k);
}
});
}
Ok(())
}
/// Smallest row count for which [`should_parallelize`] can return `true`.
pub const PARALLEL_ROW_THRESHOLD: usize = 64;

/// Reports whether a matrix of `n_rows` x `n_cols` meets both size thresholds:
/// at least [`PARALLEL_ROW_THRESHOLD`] rows and at least 256 columns.
pub fn should_parallelize(n_rows: usize, n_cols: usize) -> bool {
    const MIN_COLS: usize = 256;

    if n_rows < PARALLEL_ROW_THRESHOLD {
        return false;
    }
    n_cols >= MIN_COLS
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reference::Q8_0Ref;
    use crate::traits::QuantKernel;

    /// Serializes one test block: the scale `d` as a little-endian f16
    /// (2 bytes) followed by the 32 quantized values, 34 bytes in total.
    fn make_q8_0_block(d: f32, qs: &[i8; 32]) -> Vec<u8> {
        let mut block = Vec::with_capacity(34);
        let d_bits = half::f16::from_f32(d).to_bits();
        block.extend_from_slice(&d_bits.to_le_bytes());
        for &q in qs {
            block.push(q as u8);
        }
        block
    }

    /// `parallel_gemv` with a hand-written block kernel must agree with the
    /// reference kernel's `gemv` on the same tensor.
    #[test]
    fn test_parallel_gemv_matches_sequential() {
        let n_rows = 4;
        let n_cols = 32;
        let mut data = Vec::new();
        for row in 0..n_rows {
            let mut qs = [0i8; 32];
            for (i, q) in qs.iter_mut().enumerate() {
                *q = ((row as i16 * 7 + i as i16 * 3 - 48).clamp(-128, 127)) as i8;
            }
            data.extend_from_slice(&make_q8_0_block(0.5, &qs));
        }
        let tensor = QuantTensor::new(
            data,
            vec![n_rows, n_cols],
            oxillama_gguf::GgufTensorType::Q8_0,
        );
        let input: Vec<f32> = (0..n_cols).map(|i| (i as f32 * 0.1) - 1.6).collect();

        let kernel = Q8_0Ref;
        let mut seq_output = vec![0.0f32; n_rows];
        kernel.gemv(&tensor, &input, &mut seq_output).unwrap();

        let mut par_output = vec![0.0f32; n_rows];
        parallel_gemv(
            &tensor,
            &input,
            &mut par_output,
            32,
            34,
            |row_data, inp, _n_cols| {
                let d =
                    half::f16::from_bits(u16::from_le_bytes([row_data[0], row_data[1]])).to_f32();
                let qs = &row_data[2..34];
                let mut sum = 0.0f32;
                for (i, &q) in qs.iter().enumerate() {
                    sum += (q as i8) as f32 * inp[i];
                }
                d * sum
            },
        )
        .unwrap();

        for (i, (&s, &p)) in seq_output.iter().zip(par_output.iter()).enumerate() {
            assert!(
                (s - p).abs() < 1e-4,
                "row {i}: sequential={s}, parallel={p}"
            );
        }
    }

    #[test]
    fn test_should_parallelize() {
        assert!(!should_parallelize(1, 256));
        assert!(!should_parallelize(32, 256));
        assert!(should_parallelize(64, 256));
        assert!(should_parallelize(4096, 4096));
        assert!(!should_parallelize(128, 32));
    }

    /// An input vector shorter than the column count must be rejected.
    #[test]
    fn test_parallel_gemv_input_too_small_errors() {
        let tensor = QuantTensor::new(
            make_q8_0_block(1.0, &[0i8; 32]),
            vec![1, 32],
            oxillama_gguf::GgufTensorType::Q8_0,
        );
        let input = vec![0.0f32; 4];
        let mut output = vec![0.0f32; 1];
        let result = parallel_gemv(&tensor, &input, &mut output, 32, 34, |_, _, _| 0.0);
        assert!(result.is_err(), "too-small input should error");
    }

    /// An output buffer shorter than the row count must be rejected.
    #[test]
    fn test_parallel_gemv_output_too_small_errors() {
        let tensor = QuantTensor::new(
            make_q8_0_block(1.0, &[0i8; 32]),
            vec![2, 32],
            oxillama_gguf::GgufTensorType::Q8_0,
        );
        let input = vec![0.0f32; 32];
        let mut output = vec![0.0f32; 1];
        let result = parallel_gemv(&tensor, &input, &mut output, 32, 34, |_, _, _| 0.0);
        assert!(result.is_err(), "too-small output should error");
    }

    /// One batch row against two weight rows, checked against hand-computed values.
    #[test]
    fn test_parallel_gemm_basic() {
        let n_rows = 2usize;
        let n_cols = 32usize;
        let mut data = Vec::new();
        for row in 0..n_rows {
            let mut qs = [0i8; 32];
            for (i, q) in qs.iter_mut().enumerate() {
                *q = ((row as i16 + i as i16) % 10) as i8;
            }
            data.extend_from_slice(&make_q8_0_block(0.25, &qs));
        }
        let tensor = QuantTensor::new(
            data,
            vec![n_rows, n_cols],
            oxillama_gguf::GgufTensorType::Q8_0,
        );

        let m = 1usize;
        let k = n_cols;
        let input = vec![1.0f32; k];
        let mut output = vec![0.0f32; m * n_rows];
        let dims = GemmDims {
            m,
            n: n_rows,
            k,
            block_size: 32,
            block_bytes: 34,
        };
        let result = parallel_gemm(&tensor, &input, &mut output, &dims, |row_data, inp, _nc| {
            let d = half::f16::from_bits(u16::from_le_bytes([row_data[0], row_data[1]])).to_f32();
            let qs = &row_data[2..34];
            let mut sum = 0.0f32;
            for (i, &q) in qs.iter().enumerate() {
                if i < inp.len() {
                    sum += (q as i8) as f32 * inp[i];
                }
            }
            d * sum
        });
        assert!(result.is_ok(), "parallel_gemm should succeed: {result:?}");

        // With an all-ones input each output is scale * sum(qs):
        //   row 0: sum of (i % 10) for i in 0..32       = 136 -> 0.25 * 136 = 34.0
        //   row 1: sum of ((1 + i) % 10) for i in 0..32 = 138 -> 0.25 * 138 = 34.5
        let expected = [34.0f32, 34.5f32];
        assert_eq!(output.len(), expected.len());
        for (i, (&got, &want)) in output.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-6,
                "output[{i}]: got {got}, want {want}"
            );
        }
    }
}