use oxibonsai_core::{BlockQ4_0, QK_Q4_0};
use crate::error::{KernelError, KernelResult};
/// Matrix-vector product between a Q4_0-quantized weight matrix and an `f32` vector.
///
/// The weights are read row by row: row `r` occupies the `in_features / QK_Q4_0`
/// consecutive blocks starting at `blocks[r * (in_features / QK_Q4_0)]`. Each block is
/// dequantized with `BlockQ4_0::dequant_to_buf` and multiplied element-wise against the
/// matching `QK_Q4_0`-wide window of `input`; the running sum for the row is stored in
/// `output[r]`.
///
/// Only the first `n_rows * in_features / QK_Q4_0` blocks, the first `in_features`
/// inputs and the first `n_rows` outputs are touched; longer slices are accepted and
/// any excess is left alone.
///
/// # Errors
///
/// * [`KernelError::NotBlockAligned`] if `in_features` is not a multiple of `QK_Q4_0`.
/// * [`KernelError::DimensionMismatch`] if `blocks` holds fewer blocks than the matrix
///   needs, if the required block count does not fit in `usize`, or if `input` is
///   shorter than `in_features`.
/// * [`KernelError::BufferTooSmall`] if `output` is shorter than `n_rows`.
pub fn gemv_q4_0(
    blocks: &[BlockQ4_0],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    in_features: usize,
) -> KernelResult<()> {
    if in_features % QK_Q4_0 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: in_features,
            block_size: QK_Q4_0,
        });
    }
    let blocks_per_row = in_features / QK_Q4_0;

    // A plain `n_rows * blocks_per_row` panics on overflow in debug builds and wraps in
    // release builds, where the wrapped value would slip past the length check below.
    // No slice can hold that many blocks, so report it as a dimension mismatch.
    let expected_blocks = match n_rows.checked_mul(blocks_per_row) {
        Some(total) => total,
        None => {
            return Err(KernelError::DimensionMismatch {
                expected: usize::MAX,
                got: blocks.len(),
            })
        }
    };
    if blocks.len() < expected_blocks {
        return Err(KernelError::DimensionMismatch {
            expected: expected_blocks,
            got: blocks.len(),
        });
    }
    if input.len() < in_features {
        return Err(KernelError::DimensionMismatch {
            expected: in_features,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }

    // Trim to the validated extent so the loops below never look past it.
    let input = &input[..in_features];

    // Scratch space for one dequantized block, reused across every block of every row.
    let mut tmp = [0.0f32; QK_Q4_0];

    for (row, out) in output[..n_rows].iter_mut().enumerate() {
        // `(row + 1) * blocks_per_row <= expected_blocks`, which was checked above.
        let row_blocks = &blocks[row * blocks_per_row..(row + 1) * blocks_per_row];

        let mut acc = 0.0f32;
        for (block, window) in row_blocks.iter().zip(input.chunks_exact(QK_Q4_0)) {
            block.dequant_to_buf(&mut tmp);
            // Sequential accumulation keeps the floating-point summation order fixed.
            for (w, x) in tmp.iter().zip(window) {
                acc += w * x;
            }
        }
        *out = acc;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;

    /// Builds a block from 32 four-bit codes and a scale.
    ///
    /// Codes are packed two per byte: the even-indexed code of each pair lands in the
    /// low nibble and the odd-indexed one in the high nibble. Bits above the low four
    /// of every code are discarded.
    fn make_q4_block(scale: f32, nibbles: [u8; 32]) -> BlockQ4_0 {
        let mut qs = [0u8; 16];
        for (packed, pair) in qs.iter_mut().zip(nibbles.chunks_exact(2)) {
            let low = pair[0] & 0x0F;
            let high = pair[1] & 0x0F;
            *packed = low | (high << 4);
        }
        let d = f16::from_f32(scale);
        BlockQ4_0 { d, qs }
    }

    /// Code 8 everywhere is expected to dequantize to all-zero weights.
    #[test]
    fn q4_0_all_zero_weights() {
        let blocks = [make_q4_block(1.0, [8u8; 32])];
        let input = [1.0f32; 32];
        // Non-zero sentinel: the kernel must overwrite the slot, not merely leave it.
        let mut output = [99.0f32; 1];

        gemv_q4_0(&blocks, &input, &mut output, 1, 32).unwrap();

        let magnitude = output[0].abs();
        assert!(
            magnitude < 1e-5,
            "all-zero weights → output 0, got {}",
            output[0]
        );
    }

    /// Code 15 everywhere: the test expects each weight to be `7 * scale`.
    #[test]
    fn q4_0_max_nibbles_single_block() {
        let scale = 1.0f32;
        let blocks = [make_q4_block(scale, [15u8; 32])];
        let input = [1.0f32; 32];
        let mut output = [0.0f32; 1];

        gemv_q4_0(&blocks, &input, &mut output, 1, 32).unwrap();

        let expected = 32.0 * 7.0 * scale;
        let deviation = (output[0] - expected).abs();
        assert!(
            deviation < 1.0,
            "max nibbles: expected {expected}, got {}",
            output[0]
        );
    }

    /// Three rows of two blocks each, mixing all-zero and all-one blocks.
    #[test]
    fn q4_0_multiple_rows() {
        let zero_block = make_q4_block(1.0, [8u8; 32]);
        let unit_block = make_q4_block(1.0, [9u8; 32]);
        // Row-major layout: two consecutive blocks per row.
        let blocks = [
            unit_block, zero_block, // row 0
            zero_block, unit_block, // row 1
            unit_block, unit_block, // row 2
        ];
        let input = [1.0f32; 64];
        let mut output = [0.0f32; 3];

        gemv_q4_0(&blocks, &input, &mut output, 3, 64).unwrap();

        let (row0, row1, row2) = (output[0], output[1], output[2]);
        assert!((row0 - 32.0).abs() < 1.0, "row0: {}", row0);
        assert!((row1 - 32.0).abs() < 1.0, "row1: {}", row1);
        assert!((row2 - 64.0).abs() < 1.0, "row2: {}", row2);
    }

    /// A feature count that is not a multiple of the block size must be rejected.
    #[test]
    fn q4_0_not_block_aligned() {
        let blocks = [make_q4_block(1.0, [8u8; 32])];
        let input = [1.0f32; 31];
        let mut output = [0.0f32; 1];

        let result = gemv_q4_0(&blocks, &input, &mut output, 1, 31);

        let rejected = matches!(result, Err(KernelError::NotBlockAligned { .. }));
        assert!(rejected, "expected NotBlockAligned, got {result:?}");
    }

    /// Two rows requested while only one row's worth of blocks is supplied.
    #[test]
    fn q4_0_wrong_block_count() {
        let blocks = [make_q4_block(1.0, [8u8; 32])];
        let input = [1.0f32; 32];
        let mut output = [0.0f32; 2];

        let result = gemv_q4_0(&blocks, &input, &mut output, 2, 32);

        let rejected = matches!(result, Err(KernelError::DimensionMismatch { .. }));
        assert!(rejected, "expected DimensionMismatch, got {result:?}");
    }

    /// An empty output buffer cannot receive even a single row.
    #[test]
    fn q4_0_output_too_small() {
        let blocks = [make_q4_block(1.0, [8u8; 32])];
        let input = [1.0f32; 32];
        let mut output: [f32; 0] = [];

        let result = gemv_q4_0(&blocks, &input, &mut output, 1, 32);

        let rejected = matches!(result, Err(KernelError::BufferTooSmall { .. }));
        assert!(rejected, "expected BufferTooSmall, got {result:?}");
    }

    /// The fused kernel must agree with quantize → dequantize → plain dot product.
    #[test]
    fn q4_0_gemv_matches_quantized_dequant() {
        // A 64-value ramp from -8.0 upward in steps of 0.25, i.e. two blocks' worth.
        let raw: Vec<f32> = (0..64).map(|i| i as f32 * 0.25 - 8.0).collect();
        let blocks = BlockQ4_0::quantize(&raw).unwrap();

        let mut deq = vec![0.0f32; 64];
        BlockQ4_0::dequant(&blocks, &mut deq).unwrap();

        // With an all-ones input the dot product reduces to the sum of the weights.
        let input = [1.0f32; 64];
        let reference: f32 = deq.iter().sum();

        let mut output = [0.0f32; 1];
        gemv_q4_0(&blocks, &input, &mut output, 1, 64).unwrap();

        let deviation = (output[0] - reference).abs();
        assert!(
            deviation < 1e-3,
            "GEMV must match dequant+dot: expected {reference}, got {}",
            output[0]
        );
    }
}