use oxibonsai_core::BlockQ8K;
use crate::error::{KernelError, KernelResult};
/// Matrix-vector product over weights stored as [`BlockQ8K`] blocks.
///
/// `blocks` is read row by row, `in_features / 256` blocks per row. Each row
/// is dequantized into a scratch buffer via `BlockQ8K::dequant`, and its dot
/// product with `input` is stored in the matching element of `output`.
///
/// Slices longer than required are accepted: only the first
/// `n_rows * (in_features / 256)` blocks are read and only the first `n_rows`
/// elements of `output` are written.
///
/// # Errors
///
/// - [`KernelError::NotBlockAligned`] if `in_features` is zero or not a
///   multiple of 256.
/// - [`KernelError::DimensionMismatch`] if `input` has fewer than
///   `in_features` elements, or if `blocks` has fewer than
///   `n_rows * (in_features / 256)` elements. When that product does not fit
///   in a `usize`, the error reports `expected: usize::MAX`.
/// - [`KernelError::BufferTooSmall`] if `output` has fewer than `n_rows`
///   elements.
/// - [`KernelError::Core`] wrapping any error returned by `BlockQ8K::dequant`.
///   Rows processed before the failing one have already been written.
pub fn gemv_q8k(
    blocks: &[BlockQ8K],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    in_features: usize,
) -> KernelResult<()> {
    /// Number of weights held by a single block.
    const QK_K: usize = 256;

    if in_features == 0 || in_features % QK_K != 0 {
        return Err(KernelError::NotBlockAligned {
            count: in_features,
            block_size: QK_K,
        });
    }
    if input.len() < in_features {
        return Err(KernelError::DimensionMismatch {
            expected: in_features,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }

    // `in_features` is a non-zero multiple of QK_K here, so this is >= 1.
    let blocks_per_row = in_features / QK_K;

    // A plain `*` wraps in release builds; a wrapped (too small) count would
    // pass the length check below and then panic while slicing. Report the
    // overflow as a dimension error instead.
    let expected_blocks = n_rows
        .checked_mul(blocks_per_row)
        .ok_or(KernelError::DimensionMismatch {
            expected: usize::MAX,
            got: blocks.len(),
        })?;
    if blocks.len() < expected_blocks {
        return Err(KernelError::DimensionMismatch {
            expected: expected_blocks,
            got: blocks.len(),
        });
    }

    // One scratch buffer, reused for every row.
    let mut row_buf = vec![0.0f32; in_features];

    // Exactly `n_rows` chunks come out of the first `expected_blocks` blocks,
    // pairing one-to-one with the first `n_rows` output slots.
    let rows = blocks[..expected_blocks].chunks_exact(blocks_per_row);
    for (row_blocks, out) in rows.zip(output[..n_rows].iter_mut()) {
        BlockQ8K::dequant(row_blocks, &mut row_buf).map_err(KernelError::Core)?;
        // `zip` stops at the shorter side, so surplus `input` elements are ignored.
        *out = row_buf.iter().zip(input.iter()).map(|(w, x)| w * x).sum();
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxibonsai_core::BlockQ8K;

    /// Quantizes 256 copies of `value` and returns the first resulting block.
    fn make_q8k_block(value: f32) -> BlockQ8K {
        let samples = vec![value; 256];
        let quantized = BlockQ8K::quantize(&samples).expect("quantize ok");
        quantized[0]
    }

    /// One row of ~1.0 weights against an all-ones input sums to roughly 256.
    #[test]
    fn gemv_q8k_single_row_uniform() {
        let weights = [make_q8k_block(1.0)];
        let ones = vec![1.0f32; 256];
        let mut output = vec![0.0f32; 1];

        gemv_q8k(&weights, &ones, &mut output, 1, 256).expect("gemv ok");

        let deviation = (output[0] - 256.0).abs();
        assert!(deviation < 5.0, "expected ~256.0, got {}", output[0]);
    }

    /// Two rows with opposite-sign weights produce opposite-sign results.
    #[test]
    fn gemv_q8k_two_rows() {
        let weights = [make_q8k_block(0.5), make_q8k_block(-0.5)];
        let ones = vec![1.0f32; 256];
        let mut output = vec![0.0f32; 2];

        gemv_q8k(&weights, &ones, &mut output, 2, 256).expect("gemv ok");

        let deviation_row0 = (output[0] - 128.0).abs();
        assert!(
            deviation_row0 < 5.0,
            "row 0: expected ~128, got {}",
            output[0]
        );
        let deviation_row1 = (output[1] + 128.0).abs();
        assert!(
            deviation_row1 < 5.0,
            "row 1: expected ~-128, got {}",
            output[1]
        );
    }

    /// An `in_features` that is not a multiple of 256 is rejected.
    #[test]
    fn gemv_q8k_not_block_aligned_errors() {
        let weights = [make_q8k_block(1.0)];
        let short_input = vec![1.0f32; 100];
        let mut output = vec![0.0f32; 1];

        let outcome = gemv_q8k(&weights, &short_input, &mut output, 1, 100);
        assert!(
            outcome.is_err(),
            "should error when in_features not multiple of 256"
        );
    }

    /// Asking for two rows while supplying a single block is rejected.
    #[test]
    fn gemv_q8k_wrong_block_count_errors() {
        let weights = [make_q8k_block(1.0)];
        let ones = vec![1.0f32; 256];
        let mut output = vec![0.0f32; 2];

        let outcome = gemv_q8k(&weights, &ones, &mut output, 2, 256);
        assert!(outcome.is_err(), "should error on block count mismatch");
    }

    /// An empty output slice cannot hold one row's result.
    #[test]
    fn gemv_q8k_output_too_small_errors() {
        let weights = [make_q8k_block(1.0)];
        let ones = vec![1.0f32; 256];
        let mut output = vec![0.0f32; 0];

        let outcome = gemv_q8k(&weights, &ones, &mut output, 1, 256);
        assert!(
            outcome.is_err(),
            "should error when output buffer is too small"
        );
    }
}