use oxibonsai_core::{BlockQ8_0, QK_Q8_0};
use crate::error::{KernelError, KernelResult};
/// Matrix–vector product over Q8_0-quantized weights, accumulated in `f32`.
///
/// For every row `r` in `0..n_rows`, the row's blocks are dequantized one at a
/// time into a scratch buffer and dotted with the matching slice of `input`;
/// the sum is stored in `output[r]`.
///
/// # Arguments
///
/// * `blocks` - Weight blocks laid out row after row: row `r` owns
///   `blocks[r * (in_features / QK_Q8_0) .. (r + 1) * (in_features / QK_Q8_0)]`.
/// * `input` - Input vector; only the first `in_features` elements are read.
/// * `output` - Destination; only the first `n_rows` elements are written.
/// * `n_rows` - Number of weight rows (and of produced outputs).
/// * `in_features` - Row length in elements; must be a multiple of `QK_Q8_0`.
///
/// Slices longer than required are accepted and their tails are ignored.
///
/// # Errors
///
/// * [`KernelError::NotBlockAligned`] if `in_features` is not a multiple of
///   `QK_Q8_0`.
/// * [`KernelError::DimensionMismatch`] if `blocks` holds fewer than
///   `n_rows * (in_features / QK_Q8_0)` blocks (including the case where that
///   product does not fit in `usize`), or if `input` has fewer than
///   `in_features` elements.
/// * [`KernelError::BufferTooSmall`] if `output` has fewer than `n_rows`
///   elements.
pub fn gemv_q8_0(
    blocks: &[BlockQ8_0],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    in_features: usize,
) -> KernelResult<()> {
    if in_features % QK_Q8_0 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: in_features,
            block_size: QK_Q8_0,
        });
    }
    let blocks_per_row = in_features / QK_Q8_0;

    // Saturate rather than multiply unchecked: an overflowing product would
    // panic with overflow checks enabled and wrap to a meaningless (possibly
    // tiny) requirement without them. No slice can hold `usize::MAX` blocks,
    // so a saturated value is always rejected by the check below.
    let expected_blocks = n_rows.saturating_mul(blocks_per_row);
    if blocks.len() < expected_blocks {
        return Err(KernelError::DimensionMismatch {
            expected: expected_blocks,
            got: blocks.len(),
        });
    }
    if input.len() < in_features {
        return Err(KernelError::DimensionMismatch {
            expected: in_features,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }

    // Trim to the validated extents once so the loops below iterate over
    // exactly the data that was checked.
    let input = &input[..in_features];

    // Scratch space reused for every block; `dequant_to_buf` is expected to
    // fill it with the block's dequantized weights.
    let mut tmp = [0.0f32; QK_Q8_0];

    for (row, out) in output[..n_rows].iter_mut().enumerate() {
        // Cannot overflow: `(row + 1) * blocks_per_row <= expected_blocks`.
        let first = row * blocks_per_row;
        let row_blocks = &blocks[first..first + blocks_per_row];

        // Elements are accumulated strictly left to right within the row.
        let mut acc = 0.0f32;
        for (block, x) in row_blocks.iter().zip(input.chunks_exact(QK_Q8_0)) {
            block.dequant_to_buf(&mut tmp);
            for (w, xi) in tmp.iter().zip(x) {
                acc += w * xi;
            }
        }
        *out = acc;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;

    /// Assembles a `BlockQ8_0` from an `f32` scale (stored as `f16`) and 32
    /// quantized values.
    fn make_q8_block(scale: f32, qs: [i8; 32]) -> BlockQ8_0 {
        let d = f16::from_f32(scale);
        BlockQ8_0 { d, qs }
    }

    /// A scale-1.0 block whose quantized values are all zero.
    fn zero_block() -> BlockQ8_0 {
        make_q8_block(1.0, [0i8; 32])
    }

    /// Runs a 1×32 GEMV with a single scale-1.0 block built from `qs` against
    /// an all-ones input and returns the one output value. The output slot
    /// starts at `prefill`, so a stale value shows the kernel did not write it.
    fn single_block_output(qs: [i8; 32], prefill: f32) -> f32 {
        let weights = [make_q8_block(1.0, qs)];
        let ones = [1.0f32; 32];
        let mut out = [prefill];
        gemv_q8_0(&weights, &ones, &mut out, 1, 32).unwrap();
        out[0]
    }

    /// Zero weights must overwrite the prefilled output with zero.
    #[test]
    fn q8_0_all_zero_weights() {
        let y = single_block_output([0i8; 32], 99.0);
        assert!(y.abs() < 1e-5, "all-zero weights → output 0, got {}", y);
    }

    /// 32 weights of 1 dotted with 32 ones gives 32.
    #[test]
    fn q8_0_unit_weights_single_block() {
        let y = single_block_output([1i8; 32], 0.0);
        assert!(
            (y - 32.0).abs() < 1e-4,
            "unit weights: expected 32.0, got {}",
            y
        );
    }

    /// Alternating +1 / -1 weights cancel out against an all-ones input.
    #[test]
    fn q8_0_positive_negative_mix() {
        // Start from all +1 and flip every odd index to -1.
        let mut qs = [1i8; 32];
        for q in qs.iter_mut().skip(1).step_by(2) {
            *q = -1;
        }
        let y = single_block_output(qs, 99.0);
        assert!(y.abs() < 1e-4, "alternating sign: expected 0, got {}", y);
    }

    /// The largest quantized value (127) in every slot, with scale 1.0.
    #[test]
    fn q8_0_max_scale() {
        let y = single_block_output([127i8; 32], 0.0);
        assert!(
            (y - 32.0 * 127.0).abs() < 1.0,
            "max scale: expected {}, got {}",
            32.0 * 127.0,
            y
        );
    }

    /// Three rows of two blocks each; every row must use its own blocks.
    #[test]
    fn q8_0_multiple_rows() {
        let zero = zero_block();
        let unit = make_q8_block(1.0, [1i8; 32]);
        // Row 0: unit|zero, row 1: zero|unit, row 2: unit|unit.
        let weights = [unit, zero, zero, unit, unit, unit];
        let ones = [1.0f32; 64];
        let mut y = [0.0f32; 3];
        gemv_q8_0(&weights, &ones, &mut y, 3, 64).unwrap();
        assert!((y[0] - 32.0).abs() < 1e-4, "row0: {}", y[0]);
        assert!((y[1] - 32.0).abs() < 1e-4, "row1: {}", y[1]);
        assert!((y[2] - 64.0).abs() < 1e-4, "row2: {}", y[2]);
    }

    /// `in_features` of 31 is not a multiple of the block size.
    #[test]
    fn q8_0_not_block_aligned() {
        let weights = [zero_block()];
        let ones = [1.0f32; 31];
        let mut y = [0.0f32; 1];
        let result = gemv_q8_0(&weights, &ones, &mut y, 1, 31);
        assert!(
            matches!(result, Err(KernelError::NotBlockAligned { .. })),
            "expected NotBlockAligned, got {result:?}"
        );
    }

    /// Two rows are requested but only one block is supplied.
    #[test]
    fn q8_0_wrong_block_count() {
        let weights = [zero_block()];
        let ones = [1.0f32; 32];
        let mut y = [0.0f32; 2];
        let result = gemv_q8_0(&weights, &ones, &mut y, 2, 32);
        assert!(
            matches!(result, Err(KernelError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {result:?}"
        );
    }

    /// An empty output buffer cannot hold even one row.
    #[test]
    fn q8_0_output_too_small() {
        let weights = [zero_block()];
        let ones = [1.0f32; 32];
        let mut y: Vec<f32> = Vec::new();
        let result = gemv_q8_0(&weights, &ones, &mut y, 1, 32);
        assert!(
            matches!(result, Err(KernelError::BufferTooSmall { .. })),
            "expected BufferTooSmall, got {result:?}"
        );
    }

    /// The kernel must agree with dequantizing the same blocks and summing
    /// them (the input is all ones, so the dot product is a plain sum).
    #[test]
    fn q8_0_gemv_matches_quantized_dequant() {
        let raw: Vec<f32> = (0..64).map(|i| (i as f32) * 0.5 - 16.0).collect();
        let blocks = BlockQ8_0::quantize(&raw).unwrap();

        let mut deq = vec![0.0f32; 64];
        BlockQ8_0::dequant(&blocks, &mut deq).unwrap();

        let ones = vec![1.0f32; 64];
        let reference: f32 = deq.iter().sum();

        let mut y = vec![0.0f32; 1];
        gemv_q8_0(&blocks, &ones, &mut y, 1, 64).unwrap();
        assert!(
            (y[0] - reference).abs() < 1e-3,
            "GEMV must match dequant+dot: expected {reference}, got {}",
            y[0]
        );
    }
}