use oxibonsai_core::{BlockFP8E4M3, BlockFP8E5M2, QK_FP8};
use crate::error::{KernelError, KernelResult};
use crate::gemv_fp8::{gemv_fp8_e4m3, gemv_fp8_e5m2};
pub fn gemm_fp8_e4m3(
blocks: &[BlockFP8E4M3],
inputs: &[f32],
outputs: &mut [f32],
n_rows: usize,
k: usize,
batch: usize,
) -> KernelResult<()> {
if k % QK_FP8 != 0 {
return Err(KernelError::NotBlockAligned {
count: k,
block_size: QK_FP8,
});
}
if inputs.len() < batch * k {
return Err(KernelError::DimensionMismatch {
expected: batch * k,
got: inputs.len(),
});
}
if outputs.len() < batch * n_rows {
return Err(KernelError::BufferTooSmall {
needed: batch * n_rows,
available: outputs.len(),
});
}
let blocks_per_row = k / QK_FP8;
let expected_blocks = n_rows * blocks_per_row;
if blocks.len() < expected_blocks {
return Err(KernelError::DimensionMismatch {
expected: expected_blocks,
got: blocks.len(),
});
}
for b in 0..batch {
let input_row = &inputs[b * k..(b + 1) * k];
let output_row = &mut outputs[b * n_rows..(b + 1) * n_rows];
gemv_fp8_e4m3(blocks, input_row, output_row, n_rows, k)?;
}
Ok(())
}
pub fn gemm_fp8_e5m2(
blocks: &[BlockFP8E5M2],
inputs: &[f32],
outputs: &mut [f32],
n_rows: usize,
k: usize,
batch: usize,
) -> KernelResult<()> {
if k % QK_FP8 != 0 {
return Err(KernelError::NotBlockAligned {
count: k,
block_size: QK_FP8,
});
}
if inputs.len() < batch * k {
return Err(KernelError::DimensionMismatch {
expected: batch * k,
got: inputs.len(),
});
}
if outputs.len() < batch * n_rows {
return Err(KernelError::BufferTooSmall {
needed: batch * n_rows,
available: outputs.len(),
});
}
let blocks_per_row = k / QK_FP8;
let expected_blocks = n_rows * blocks_per_row;
if blocks.len() < expected_blocks {
return Err(KernelError::DimensionMismatch {
expected: expected_blocks,
got: blocks.len(),
});
}
for b in 0..batch {
let input_row = &inputs[b * k..(b + 1) * k];
let output_row = &mut outputs[b * n_rows..(b + 1) * n_rows];
gemv_fp8_e5m2(blocks, input_row, output_row, n_rows, k)?;
}
Ok(())
}
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's private imports (`KernelError`,
    // `gemv_fp8_e4m3`, `gemv_fp8_e5m2`, the block types), so no separate imports are needed.
    use super::*;
    use half::f16;

    // Under the standard FP8 encodings, 0x38 in E4M3 (exponent bias 7) and 0x3C in E5M2
    // (exponent bias 15) both decode to 1.0. A block filled with them therefore has every
    // weight equal to its scale.

    /// Builds an E4M3 block with the given scale (stored as `f16`) and raw FP8 codes.
    fn make_e4m3_block(scale: f32, qs: [u8; 32]) -> BlockFP8E4M3 {
        BlockFP8E4M3 {
            qs,
            d: f16::from_f32(scale),
        }
    }

    /// Builds an E5M2 block with the given scale (stored as `f16`) and raw FP8 codes.
    fn make_e5m2_block(scale: f32, qs: [u8; 32]) -> BlockFP8E5M2 {
        BlockFP8E5M2 {
            qs,
            d: f16::from_f32(scale),
        }
    }

    #[test]
    fn gemm_e4m3_matches_gemv() {
        let blocks = vec![
            make_e4m3_block(1.0, [0x38u8; 32]), // row 0: weights ~1.0
            make_e4m3_block(2.0, [0x38u8; 32]), // row 1: weights ~2.0
        ];
        let batch = 3;
        let n_rows = 2;
        let k = 32;
        let mut inputs = vec![0.0_f32; batch * k];
        for b in 0..batch {
            for j in 0..k {
                inputs[b * k + j] = (b + 1) as f32 * 0.5;
            }
        }
        let mut gemm_out = vec![0.0_f32; batch * n_rows];
        gemm_fp8_e4m3(&blocks, &inputs, &mut gemm_out, n_rows, k, batch)
            .expect("gemm should succeed");
        for b in 0..batch {
            let input_row = &inputs[b * k..(b + 1) * k];
            let mut gemv_out = vec![0.0_f32; n_rows];
            gemv_fp8_e4m3(&blocks, input_row, &mut gemv_out, n_rows, k)
                .expect("gemv should succeed");
            for r in 0..n_rows {
                let gm = gemm_out[b * n_rows + r];
                let gv = gemv_out[r];
                assert!(
                    (gm - gv).abs() < 1e-4,
                    "batch={b} row={r}: gemm={gm} vs gemv={gv}"
                );
            }
        }
    }

    #[test]
    fn gemm_e4m3_batch_one_equals_gemv() {
        let blocks = vec![
            make_e4m3_block(1.0, [0x38u8; 32]),
            make_e4m3_block(1.5, [0x38u8; 32]),
            make_e4m3_block(0.5, [0x38u8; 32]),
        ];
        let n_rows = 3;
        let k = 32;
        let inputs: Vec<f32> = (0..k).map(|i| (i as f32) * 0.1).collect();
        let mut gemm_out = vec![0.0_f32; n_rows];
        let mut gemv_out = vec![0.0_f32; n_rows];
        gemm_fp8_e4m3(&blocks, &inputs, &mut gemm_out, n_rows, k, 1)
            .expect("gemm(batch=1) should succeed");
        gemv_fp8_e4m3(&blocks, &inputs, &mut gemv_out, n_rows, k).expect("gemv should succeed");
        for r in 0..n_rows {
            assert!(
                (gemm_out[r] - gemv_out[r]).abs() < 1e-4,
                "row={r}: gemm={} vs gemv={}",
                gemm_out[r],
                gemv_out[r]
            );
        }
    }

    #[test]
    fn gemm_e4m3_all_positive_weights() {
        let n_rows = 4;
        let k = 32;
        let batch = 3;
        let blocks = vec![make_e4m3_block(1.0, [0x38u8; 32]); n_rows];
        let inputs = vec![1.0_f32; batch * k];
        let mut outputs = vec![0.0_f32; batch * n_rows];
        gemm_fp8_e4m3(&blocks, &inputs, &mut outputs, n_rows, k, batch)
            .expect("gemm should succeed");
        for b in 0..batch {
            for r in 0..n_rows {
                let v = outputs[b * n_rows + r];
                assert!(
                    (v - 32.0).abs() < 1.0,
                    "batch={b} row={r}: expected ~32.0, got {v}"
                );
            }
        }
    }

    #[test]
    fn gemm_e4m3_not_block_aligned() {
        let blocks = vec![make_e4m3_block(1.0, [0x38u8; 32])];
        let inputs = vec![1.0_f32; 33];
        let mut outputs = vec![0.0_f32; 1];
        let result = gemm_fp8_e4m3(&blocks, &inputs, &mut outputs, 1, 33, 1);
        assert!(
            matches!(result, Err(KernelError::NotBlockAligned { .. })),
            "expected NotBlockAligned, got {result:?}"
        );
    }

    #[test]
    fn gemm_e4m3_output_buffer_too_small() {
        let blocks = vec![
            make_e4m3_block(1.0, [0x38u8; 32]),
            make_e4m3_block(1.0, [0x38u8; 32]),
        ];
        // batch = 2, n_rows = 2 needs 4 output slots; only 3 are provided.
        let inputs = vec![1.0_f32; 64];
        let mut outputs = vec![0.0_f32; 3];
        let result = gemm_fp8_e4m3(&blocks, &inputs, &mut outputs, 2, 32, 2);
        assert!(
            matches!(result, Err(KernelError::BufferTooSmall { .. })),
            "expected BufferTooSmall, got {result:?}"
        );
    }

    #[test]
    fn gemm_e4m3_inputs_too_short() {
        let blocks = vec![make_e4m3_block(1.0, [0x38u8; 32])];
        // batch = 2, k = 32 needs 64 inputs; only a single row is provided.
        let inputs = vec![1.0_f32; 32];
        let mut outputs = vec![0.0_f32; 2];
        let result = gemm_fp8_e4m3(&blocks, &inputs, &mut outputs, 1, 32, 2);
        assert!(
            matches!(result, Err(KernelError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {result:?}"
        );
    }

    #[test]
    fn gemm_e4m3_dimension_mismatch_blocks() {
        let blocks = vec![
            make_e4m3_block(1.0, [0x38u8; 32]),
            make_e4m3_block(1.0, [0x38u8; 32]),
        ];
        // n_rows = 3 with k = 32 needs 3 blocks; only 2 are provided.
        let inputs = vec![1.0_f32; 32];
        let mut outputs = vec![0.0_f32; 3];
        let result = gemm_fp8_e4m3(&blocks, &inputs, &mut outputs, 3, 32, 1);
        assert!(
            matches!(result, Err(KernelError::DimensionMismatch { .. })),
            "expected DimensionMismatch, got {result:?}"
        );
    }

    #[test]
    fn gemm_e4m3_zero_batch_leaves_outputs_untouched() {
        let blocks = vec![make_e4m3_block(1.0, [0x38u8; 32]); 2];
        let inputs: Vec<f32> = Vec::new();
        let mut outputs = vec![7.0_f32; 2];
        gemm_fp8_e4m3(&blocks, &inputs, &mut outputs, 2, 32, 0)
            .expect("an empty batch should succeed");
        assert_eq!(outputs, vec![7.0_f32; 2]);
    }

    #[test]
    fn gemm_e4m3_two_blocks_per_row_batched() {
        let n_rows = 2;
        let k = 64;
        let batch = 2;
        // Two rows, each made of two 32-wide blocks.
        let blocks = vec![
            make_e4m3_block(1.0, [0x38u8; 32]),
            make_e4m3_block(1.0, [0x38u8; 32]),
            make_e4m3_block(1.0, [0x38u8; 32]),
            make_e4m3_block(1.0, [0x38u8; 32]),
        ];
        let inputs = vec![1.0_f32; batch * k];
        let mut outputs = vec![0.0_f32; batch * n_rows];
        gemm_fp8_e4m3(&blocks, &inputs, &mut outputs, n_rows, k, batch)
            .expect("gemm should succeed");
        for b in 0..batch {
            for r in 0..n_rows {
                let v = outputs[b * n_rows + r];
                assert!(
                    (v - 64.0).abs() < 1.5,
                    "batch={b} row={r}: expected ~64.0, got {v}"
                );
            }
        }
    }

    #[test]
    fn gemm_e5m2_basic() {
        let blocks = vec![
            make_e5m2_block(1.0, [0x3Cu8; 32]),
            make_e5m2_block(1.0, [0x3Cu8; 32]),
        ];
        let batch = 2;
        let n_rows = 2;
        let k = 32;
        let inputs = vec![1.0_f32; batch * k];
        let mut outputs = vec![0.0_f32; batch * n_rows];
        gemm_fp8_e5m2(&blocks, &inputs, &mut outputs, n_rows, k, batch)
            .expect("e5m2 gemm should succeed");
        for b in 0..batch {
            for r in 0..n_rows {
                let v = outputs[b * n_rows + r];
                assert!(
                    (v - 32.0).abs() < 1.0,
                    "batch={b} row={r}: expected ~32.0, got {v}"
                );
            }
        }
    }

    #[test]
    fn gemm_e5m2_matches_gemv() {
        let blocks = vec![
            make_e5m2_block(1.0, [0x3Cu8; 32]),
            make_e5m2_block(2.0, [0x3Cu8; 32]),
        ];
        let batch = 3;
        let n_rows = 2;
        let k = 32;
        let inputs: Vec<f32> = (0..batch * k)
            .map(|i| ((i / k) + 1) as f32 * 0.5)
            .collect();
        let mut gemm_out = vec![0.0_f32; batch * n_rows];
        gemm_fp8_e5m2(&blocks, &inputs, &mut gemm_out, n_rows, k, batch)
            .expect("e5m2 gemm should succeed");
        for b in 0..batch {
            let input_row = &inputs[b * k..(b + 1) * k];
            let mut gemv_out = vec![0.0_f32; n_rows];
            gemv_fp8_e5m2(&blocks, input_row, &mut gemv_out, n_rows, k)
                .expect("e5m2 gemv should succeed");
            for r in 0..n_rows {
                let gm = gemm_out[b * n_rows + r];
                let gv = gemv_out[r];
                assert!(
                    (gm - gv).abs() < 1e-4,
                    "batch={b} row={r}: gemm={gm} vs gemv={gv}"
                );
            }
        }
    }

    #[test]
    fn gemm_e5m2_not_block_aligned() {
        let blocks = vec![make_e5m2_block(1.0, [0x3Cu8; 32])];
        let inputs = vec![1.0_f32; 33];
        let mut outputs = vec![0.0_f32; 1];
        let result = gemm_fp8_e5m2(&blocks, &inputs, &mut outputs, 1, 33, 1);
        assert!(
            matches!(result, Err(KernelError::NotBlockAligned { .. })),
            "expected NotBlockAligned, got {result:?}"
        );
    }

    #[test]
    fn gemm_e5m2_output_buffer_too_small() {
        let blocks = vec![
            make_e5m2_block(1.0, [0x3Cu8; 32]),
            make_e5m2_block(1.0, [0x3Cu8; 32]),
        ];
        // batch = 2, n_rows = 2 needs 4 output slots; only 3 are provided.
        let inputs = vec![1.0_f32; 64];
        let mut outputs = vec![0.0_f32; 3];
        let result = gemm_fp8_e5m2(&blocks, &inputs, &mut outputs, 2, 32, 2);
        assert!(
            matches!(result, Err(KernelError::BufferTooSmall { .. })),
            "expected BufferTooSmall, got {result:?}"
        );
    }
}