use oxibonsai_core::{fp8_e4m3_decode, fp8_e5m2_decode, BlockFP8E4M3, BlockFP8E5M2, QK_FP8};
use crate::error::{KernelError, KernelResult};
/// Matrix-vector product over FP8 E4M3 block-quantized weights.
///
/// `blocks` stores the weight matrix row by row: row `r` occupies
/// `blocks[r * (k / QK_FP8) .. (r + 1) * (k / QK_FP8)]`, and the `b`-th block of
/// a row is paired with `input[b * QK_FP8 .. (b + 1) * QK_FP8]`. For each row:
///
/// ```text
/// output[r] = sum over blocks of  d * sum_i fp8_e4m3_decode(qs[i]) * input[base + i]
/// ```
///
/// where `d` is the block's `d` field converted to `f32`.
///
/// Only the first `k` elements of `input`, the first `n_rows` elements of
/// `output` and the first `n_rows * (k / QK_FP8)` blocks are used; longer
/// buffers are accepted and any `output` elements past `n_rows` are left
/// untouched.
///
/// # Errors
///
/// * [`KernelError::NotBlockAligned`] if `k` is not a multiple of `QK_FP8`.
/// * [`KernelError::DimensionMismatch`] if `input` has fewer than `k` elements,
///   or if `blocks` has fewer than `n_rows * (k / QK_FP8)` entries. When that
///   product does not fit in `usize`, `expected` is reported as `usize::MAX`.
/// * [`KernelError::BufferTooSmall`] if `output` has fewer than `n_rows` elements.
pub fn gemv_fp8_e4m3(
    blocks: &[BlockFP8E4M3],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK_FP8 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK_FP8,
        });
    }
    if input.len() < k {
        return Err(KernelError::DimensionMismatch {
            expected: k,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }

    let blocks_per_row = k / QK_FP8;
    // A plain `n_rows * blocks_per_row` can wrap (notably on 32-bit targets),
    // which would let an undersized `blocks` slice slip past this check and
    // fail later inside the loop. An overflowing count can never be satisfied,
    // so it is reported as a mismatch as well.
    let expected_blocks = match n_rows.checked_mul(blocks_per_row) {
        Some(count) if count <= blocks.len() => count,
        too_many => {
            return Err(KernelError::DimensionMismatch {
                expected: too_many.unwrap_or(usize::MAX),
                got: blocks.len(),
            });
        }
    };

    // Trim every buffer to the validated extent once, up front.
    let blocks = &blocks[..expected_blocks];
    let input = &input[..k];

    for (row, out) in output[..n_rows].iter_mut().enumerate() {
        // Cannot overflow: `(row + 1) * blocks_per_row <= expected_blocks`.
        let row_start = row * blocks_per_row;
        let row_blocks = &blocks[row_start..row_start + blocks_per_row];

        // Explicit `0.0_f32` seeds keep the left-to-right summation order
        // (and therefore the rounding) of a straightforward nested loop.
        *out = row_blocks
            .iter()
            .zip(input.chunks_exact(QK_FP8))
            .fold(0.0_f32, |acc, (block, x)| {
                let block_dot = block
                    .qs
                    .iter()
                    .zip(x)
                    .fold(0.0_f32, |dot, (&q, &xi)| dot + fp8_e4m3_decode(q) * xi);
                // The scale is applied once per block, not per element.
                acc + block.d.to_f32() * block_dot
            });
    }
    Ok(())
}
/// Matrix-vector product over FP8 E5M2 block-quantized weights.
///
/// `blocks` stores the weight matrix row by row: row `r` occupies
/// `blocks[r * (k / QK_FP8) .. (r + 1) * (k / QK_FP8)]`, and the `b`-th block of
/// a row is paired with `input[b * QK_FP8 .. (b + 1) * QK_FP8]`. For each row:
///
/// ```text
/// output[r] = sum over blocks of  d * sum_i fp8_e5m2_decode(qs[i]) * input[base + i]
/// ```
///
/// where `d` is the block's `d` field converted to `f32`.
///
/// Only the first `k` elements of `input`, the first `n_rows` elements of
/// `output` and the first `n_rows * (k / QK_FP8)` blocks are used; longer
/// buffers are accepted and any `output` elements past `n_rows` are left
/// untouched.
///
/// # Errors
///
/// * [`KernelError::NotBlockAligned`] if `k` is not a multiple of `QK_FP8`.
/// * [`KernelError::DimensionMismatch`] if `input` has fewer than `k` elements,
///   or if `blocks` has fewer than `n_rows * (k / QK_FP8)` entries. When that
///   product does not fit in `usize`, `expected` is reported as `usize::MAX`.
/// * [`KernelError::BufferTooSmall`] if `output` has fewer than `n_rows` elements.
pub fn gemv_fp8_e5m2(
    blocks: &[BlockFP8E5M2],
    input: &[f32],
    output: &mut [f32],
    n_rows: usize,
    k: usize,
) -> KernelResult<()> {
    if k % QK_FP8 != 0 {
        return Err(KernelError::NotBlockAligned {
            count: k,
            block_size: QK_FP8,
        });
    }
    if input.len() < k {
        return Err(KernelError::DimensionMismatch {
            expected: k,
            got: input.len(),
        });
    }
    if output.len() < n_rows {
        return Err(KernelError::BufferTooSmall {
            needed: n_rows,
            available: output.len(),
        });
    }

    let blocks_per_row = k / QK_FP8;
    // A plain `n_rows * blocks_per_row` can wrap (notably on 32-bit targets),
    // which would let an undersized `blocks` slice slip past this check and
    // fail later inside the loop. An overflowing count can never be satisfied,
    // so it is reported as a mismatch as well.
    let expected_blocks = match n_rows.checked_mul(blocks_per_row) {
        Some(count) if count <= blocks.len() => count,
        too_many => {
            return Err(KernelError::DimensionMismatch {
                expected: too_many.unwrap_or(usize::MAX),
                got: blocks.len(),
            });
        }
    };

    // Trim every buffer to the validated extent once, up front.
    let blocks = &blocks[..expected_blocks];
    let input = &input[..k];

    for (row, out) in output[..n_rows].iter_mut().enumerate() {
        // Cannot overflow: `(row + 1) * blocks_per_row <= expected_blocks`.
        let row_start = row * blocks_per_row;
        let row_blocks = &blocks[row_start..row_start + blocks_per_row];

        // Explicit `0.0_f32` seeds keep the left-to-right summation order
        // (and therefore the rounding) of a straightforward nested loop.
        *out = row_blocks
            .iter()
            .zip(input.chunks_exact(QK_FP8))
            .fold(0.0_f32, |acc, (block, x)| {
                let block_dot = block
                    .qs
                    .iter()
                    .zip(x)
                    .fold(0.0_f32, |dot, (&q, &xi)| dot + fp8_e5m2_decode(q) * xi);
                // The scale is applied once per block, not per element.
                acc + block.d.to_f32() * block_dot
            });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;
    use oxibonsai_core::fp8_e4m3_encode;

    /// E4M3 byte used as a unit weight; the identity-style tests below rely on
    /// it decoding to (approximately) 1.0.
    const E4M3_ONE: u8 = 0x38;
    /// E5M2 byte used as a unit weight; the identity-style tests below rely on
    /// it decoding to (approximately) 1.0.
    const E5M2_ONE: u8 = 0x3C;

    /// Builds an E4M3 block from an `f32` scale (stored as `f16`) and raw FP8 bytes.
    fn make_e4m3_block(scale: f32, qs: [u8; 32]) -> BlockFP8E4M3 {
        let d = f16::from_f32(scale);
        BlockFP8E4M3 { qs, d }
    }

    /// Builds an E5M2 block from an `f32` scale (stored as `f16`) and raw FP8 bytes.
    fn make_e5m2_block(scale: f32, qs: [u8; 32]) -> BlockFP8E5M2 {
        let d = f16::from_f32(scale);
        BlockFP8E5M2 { qs, d }
    }

    // ---- E4M3: numerical behaviour ----

    #[test]
    fn gemv_e4m3_identity_one_row() {
        let weights = [make_e4m3_block(1.0, [E4M3_ONE; 32])];
        let x = [1.0_f32; 32];
        let mut y = [0.0_f32; 1];
        gemv_fp8_e4m3(&weights, &x, &mut y, 1, 32).expect("gemv should succeed");
        let err = (y[0] - 32.0).abs();
        assert!(err < 0.5, "expected ~32.0, got {}", y[0]);
    }

    #[test]
    fn gemv_e4m3_two_rows_different_scales() {
        // One block per row; the second row carries twice the scale.
        let weights = [
            make_e4m3_block(1.0, [E4M3_ONE; 32]),
            make_e4m3_block(2.0, [E4M3_ONE; 32]),
        ];
        let x = [1.0_f32; 32];
        let mut y = [0.0_f32; 2];
        gemv_fp8_e4m3(&weights, &x, &mut y, 2, 32).expect("gemv should succeed");
        let err_row0 = (y[0] - 32.0).abs();
        assert!(err_row0 < 0.5, "row0: expected ~32.0, got {}", y[0]);
        let err_row1 = (y[1] - 64.0).abs();
        assert!(err_row1 < 1.0, "row1: expected ~64.0, got {}", y[1]);
    }

    #[test]
    fn gemv_e4m3_two_blocks_per_row() {
        // A single row of 64 elements spans two blocks.
        let weights = [
            make_e4m3_block(1.0, [E4M3_ONE; 32]),
            make_e4m3_block(1.0, [E4M3_ONE; 32]),
        ];
        let x = [1.0_f32; 64];
        let mut y = [0.0_f32; 1];
        gemv_fp8_e4m3(&weights, &x, &mut y, 1, 64).expect("gemv should succeed");
        let err = (y[0] - 64.0).abs();
        assert!(err < 1.0, "expected ~64.0, got {}", y[0]);
    }

    #[test]
    fn gemv_e4m3_all_zeros_input() {
        let weights = [make_e4m3_block(1.0, [E4M3_ONE; 32])];
        let x = [0.0_f32; 32];
        // Non-zero sentinel proves the kernel overwrites the output slot.
        let mut y = [99.0_f32; 1];
        gemv_fp8_e4m3(&weights, &x, &mut y, 1, 32).expect("gemv should succeed");
        let magnitude = y[0].abs();
        assert!(
            magnitude < 1e-6,
            "all-zero input → output should be 0.0, got {}",
            y[0]
        );
    }

    #[test]
    fn gemv_e4m3_all_zeros_weights() {
        let weights = [make_e4m3_block(1.0, [0x00u8; 32])];
        let x = [1.0_f32; 32];
        // Non-zero sentinel proves the kernel overwrites the output slot.
        let mut y = [99.0_f32; 1];
        gemv_fp8_e4m3(&weights, &x, &mut y, 1, 32).expect("gemv should succeed");
        let magnitude = y[0].abs();
        assert!(
            magnitude < 1e-6,
            "all-zero weights → output should be 0.0, got {}",
            y[0]
        );
    }

    #[test]
    fn gemv_e4m3_unit_input() {
        // Single non-zero weight against a one-hot input isolates one product.
        let encoded = fp8_e4m3_encode(5.0);
        let expected = fp8_e4m3_decode(encoded);

        let mut qs = [0x00u8; 32];
        qs[0] = encoded;
        let weights = [make_e4m3_block(1.0, qs)];

        let mut x = [0.0_f32; 32];
        x[0] = 1.0;
        let mut y = [0.0_f32; 1];
        gemv_fp8_e4m3(&weights, &x, &mut y, 1, 32).expect("gemv should succeed");

        let err = (y[0] - expected).abs();
        assert!(err < 1e-5, "unit input: expected {expected}, got {}", y[0]);
    }

    // ---- E4M3: argument validation ----

    #[test]
    fn gemv_e4m3_not_block_aligned() {
        let weights = [make_e4m3_block(1.0, [E4M3_ONE; 32])];
        let x = [1.0_f32; 31];
        let mut y = [0.0_f32; 1];
        let result = gemv_fp8_e4m3(&weights, &x, &mut y, 1, 31);
        let is_expected_error = matches!(result, Err(KernelError::NotBlockAligned { .. }));
        assert!(is_expected_error, "expected NotBlockAligned, got {result:?}");
    }

    #[test]
    fn gemv_e4m3_dimension_mismatch_blocks() {
        // Two rows requested, but only one block supplied.
        let weights = [make_e4m3_block(1.0, [E4M3_ONE; 32])];
        let x = [1.0_f32; 32];
        let mut y = [0.0_f32; 2];
        let result = gemv_fp8_e4m3(&weights, &x, &mut y, 2, 32);
        let is_expected_error = matches!(result, Err(KernelError::DimensionMismatch { .. }));
        assert!(
            is_expected_error,
            "expected DimensionMismatch, got {result:?}"
        );
    }

    #[test]
    fn gemv_e4m3_output_buffer_too_small() {
        let weights = [make_e4m3_block(1.0, [E4M3_ONE; 32])];
        let x = [1.0_f32; 32];
        let mut y: [f32; 0] = [];
        let result = gemv_fp8_e4m3(&weights, &x, &mut y, 1, 32);
        let is_expected_error = matches!(result, Err(KernelError::BufferTooSmall { .. }));
        assert!(is_expected_error, "expected BufferTooSmall, got {result:?}");
    }

    // ---- E5M2 ----

    #[test]
    fn gemv_e5m2_identity_one_row() {
        let weights = [make_e5m2_block(1.0, [E5M2_ONE; 32])];
        let x = [1.0_f32; 32];
        let mut y = [0.0_f32; 1];
        gemv_fp8_e5m2(&weights, &x, &mut y, 1, 32).expect("gemv should succeed");
        let err = (y[0] - 32.0).abs();
        assert!(err < 0.5, "expected ~32.0, got {}", y[0]);
    }

    #[test]
    fn gemv_e5m2_not_block_aligned() {
        let weights = [make_e5m2_block(1.0, [E5M2_ONE; 32])];
        let x = [1.0_f32; 33];
        let mut y = [0.0_f32; 1];
        let result = gemv_fp8_e5m2(&weights, &x, &mut y, 1, 33);
        let is_expected_error = matches!(result, Err(KernelError::NotBlockAligned { .. }));
        assert!(is_expected_error, "expected NotBlockAligned, got {result:?}");
    }

    #[test]
    fn gemv_e5m2_dimension_mismatch() {
        // Three rows requested, but only two blocks supplied.
        let weights = [
            make_e5m2_block(1.0, [E5M2_ONE; 32]),
            make_e5m2_block(1.0, [E5M2_ONE; 32]),
        ];
        let x = [1.0_f32; 32];
        let mut y = [0.0_f32; 3];
        let result = gemv_fp8_e5m2(&weights, &x, &mut y, 3, 32);
        let is_expected_error = matches!(result, Err(KernelError::DimensionMismatch { .. }));
        assert!(
            is_expected_error,
            "expected DimensionMismatch, got {result:?}"
        );
    }
}