use oxibonsai_core::{fp8_e4m3_decode, fp8_e5m2_decode, BlockFP8E4M3, BlockFP8E5M2, QK_FP8};
use crate::error::{KernelError, KernelResult};
/// Dequantizes a slice of FP8 E4M3 blocks into `f32` values.
///
/// Each block contributes `QK_FP8` outputs: element `i` of block `b` is written to
/// `output[b * QK_FP8 + i]` as `d * fp8_e4m3_decode(block.qs[i])`, where `d` is the
/// block's scale converted to `f32`.
///
/// Only the first `blocks.len() * QK_FP8` elements of `output` are written; any
/// trailing elements are left untouched. An empty `blocks` slice is a no-op.
///
/// # Errors
///
/// Returns [`KernelError::BufferTooSmall`] if `output` holds fewer than
/// `blocks.len() * QK_FP8` elements. In that case `output` is not modified.
pub fn dequant_fp8_e4m3(blocks: &[BlockFP8E4M3], output: &mut [f32]) -> KernelResult<()> {
    let needed = blocks.len() * QK_FP8;
    if output.len() < needed {
        return Err(KernelError::BufferTooSmall {
            needed,
            available: output.len(),
        });
    }
    // Pair every block with its own exactly-`QK_FP8`-long output window. Iterating
    // instead of indexing lets the compiler drop the per-element bounds checks.
    for (block, out) in blocks
        .iter()
        .zip(output[..needed].chunks_exact_mut(QK_FP8))
    {
        let d = block.d.to_f32();
        for (dst, &q) in out.iter_mut().zip(&block.qs[..QK_FP8]) {
            *dst = d * fp8_e4m3_decode(q);
        }
    }
    Ok(())
}
/// Dequantizes a slice of FP8 E5M2 blocks into `f32` values.
///
/// Each block contributes `QK_FP8` outputs: element `i` of block `b` is written to
/// `output[b * QK_FP8 + i]` as `d * fp8_e5m2_decode(block.qs[i])`, where `d` is the
/// block's scale converted to `f32`.
///
/// Only the first `blocks.len() * QK_FP8` elements of `output` are written; any
/// trailing elements are left untouched. An empty `blocks` slice is a no-op.
///
/// # Errors
///
/// Returns [`KernelError::BufferTooSmall`] if `output` holds fewer than
/// `blocks.len() * QK_FP8` elements. In that case `output` is not modified.
pub fn dequant_fp8_e5m2(blocks: &[BlockFP8E5M2], output: &mut [f32]) -> KernelResult<()> {
    let needed = blocks.len() * QK_FP8;
    if output.len() < needed {
        return Err(KernelError::BufferTooSmall {
            needed,
            available: output.len(),
        });
    }
    // Pair every block with its own exactly-`QK_FP8`-long output window. Iterating
    // instead of indexing lets the compiler drop the per-element bounds checks.
    for (block, out) in blocks
        .iter()
        .zip(output[..needed].chunks_exact_mut(QK_FP8))
    {
        let d = block.d.to_f32();
        for (dst, &q) in out.iter_mut().zip(&block.qs[..QK_FP8]) {
            *dst = d * fp8_e5m2_decode(q);
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use half::f16;

    /// Builds an E4M3 block with scale `scale` (stored as f16) and raw codes `qs`.
    fn make_e4m3_block(scale: f32, qs: [u8; 32]) -> BlockFP8E4M3 {
        BlockFP8E4M3 {
            qs,
            d: f16::from_f32(scale),
        }
    }

    /// Builds an E5M2 block with scale `scale` (stored as f16) and raw codes `qs`.
    fn make_e5m2_block(scale: f32, qs: [u8; 32]) -> BlockFP8E5M2 {
        BlockFP8E5M2 {
            qs,
            d: f16::from_f32(scale),
        }
    }

    // ---------------------------------------------------------------- E4M3
    // Codes used below (E4M3, bias 7): 0x00 = 0.0, 0x30 = 0.5, 0x38 = 1.0,
    // 0x40 = 2.0, 0xB8 = -1.0.

    #[test]
    fn e4m3_dequant_all_zeros() {
        let block = make_e4m3_block(2.0, [0x00u8; 32]);
        let mut output = vec![0.0f32; QK_FP8];
        dequant_fp8_e4m3(&[block], &mut output).expect("dequant should succeed");
        for (i, &v) in output.iter().enumerate() {
            assert!(v.abs() < 1e-6, "index {i}: expected 0.0, got {v}");
        }
    }

    #[test]
    fn e4m3_dequant_ones_with_scale() {
        let block = make_e4m3_block(3.0, [0x38u8; 32]);
        let mut output = vec![0.0f32; QK_FP8];
        dequant_fp8_e4m3(&[block], &mut output).expect("dequant should succeed");
        for (i, &v) in output.iter().enumerate() {
            assert!((v - 3.0).abs() < 0.02, "index {i}: expected ~3.0, got {v}");
        }
    }

    #[test]
    fn e4m3_dequant_negative_scale() {
        let block = make_e4m3_block(-1.0, [0x38u8; 32]);
        let mut output = vec![0.0f32; QK_FP8];
        dequant_fp8_e4m3(&[block], &mut output).expect("dequant should succeed");
        for (i, &v) in output.iter().enumerate() {
            assert!((v + 1.0).abs() < 0.02, "index {i}: expected ~-1.0, got {v}");
        }
    }

    #[test]
    fn e4m3_dequant_two_blocks_independent() {
        let blocks = [
            make_e4m3_block(1.0, [0x38u8; 32]),
            make_e4m3_block(2.0, [0x38u8; 32]),
        ];
        let mut output = vec![0.0f32; QK_FP8 * 2];
        dequant_fp8_e4m3(&blocks, &mut output).expect("dequant should succeed");
        for (i, &v) in output[..QK_FP8].iter().enumerate() {
            assert!(
                (v - 1.0).abs() < 0.02,
                "block0[{i}]: expected ~1.0, got {v}"
            );
        }
        // Enumerate the second block's slice so `i` is block-relative in messages.
        for (i, &v) in output[QK_FP8..].iter().enumerate() {
            assert!(
                (v - 2.0).abs() < 0.02,
                "block1[{i}]: expected ~2.0, got {v}"
            );
        }
    }

    #[test]
    fn e4m3_dequant_preserves_element_order() {
        // Distinct codes at distinct positions, so a shuffled or shifted write is caught
        // (all other tests use uniform blocks and would not notice).
        let mut qs = [0x00u8; 32];
        qs[0] = 0x38;
        qs[1] = 0x40;
        qs[2] = 0x30;
        qs[31] = 0xB8;
        let block = make_e4m3_block(2.0, qs);
        let mut output = vec![0.0f32; QK_FP8];
        dequant_fp8_e4m3(&[block], &mut output).expect("dequant should succeed");
        let mut expected = vec![0.0f32; QK_FP8];
        expected[0] = 2.0;
        expected[1] = 4.0;
        expected[2] = 1.0;
        expected[31] = -2.0;
        for (i, (&v, &e)) in output.iter().zip(&expected).enumerate() {
            assert!((v - e).abs() < 1e-6, "index {i}: expected {e}, got {v}");
        }
    }

    #[test]
    fn e4m3_dequant_empty_blocks() {
        let mut output: Vec<f32> = Vec::new();
        dequant_fp8_e4m3(&[], &mut output).expect("empty dequant should succeed");
    }

    #[test]
    fn e4m3_dequant_buffer_too_small() {
        let block = make_e4m3_block(1.0, [0x38u8; 32]);
        let mut output = vec![0.0f32; QK_FP8 - 1];
        let result = dequant_fp8_e4m3(&[block], &mut output);
        assert!(
            matches!(result, Err(KernelError::BufferTooSmall { .. })),
            "expected BufferTooSmall, got {result:?}"
        );
    }

    #[test]
    fn e4m3_dequant_exact_buffer_size() {
        let block = make_e4m3_block(1.0, [0x38u8; 32]);
        let mut output = vec![0.0f32; QK_FP8];
        dequant_fp8_e4m3(&[block], &mut output).expect("exact-size buffer should succeed");
    }

    #[test]
    fn e4m3_dequant_oversized_buffer() {
        let block = make_e4m3_block(1.0, [0x38u8; 32]);
        let mut output = vec![99.0f32; QK_FP8 + 10];
        dequant_fp8_e4m3(&[block], &mut output).expect("oversized buffer should succeed");
        for (i, &v) in output.iter().enumerate().skip(QK_FP8) {
            assert_eq!(v, 99.0, "trailing element {i} was modified");
        }
    }

    #[test]
    fn e4m3_dequant_negative_weights() {
        let block = make_e4m3_block(1.0, [0xB8u8; 32]);
        let mut output = vec![0.0f32; QK_FP8];
        dequant_fp8_e4m3(&[block], &mut output).expect("dequant should succeed");
        for (i, &v) in output.iter().enumerate() {
            assert!((v + 1.0).abs() < 0.02, "index {i}: expected ~-1.0, got {v}");
        }
    }

    // ---------------------------------------------------------------- E5M2
    // Codes used below (E5M2, bias 15): 0x00 = 0.0, 0x38 = 0.5, 0x3C = 1.0,
    // 0x40 = 2.0, 0xBC = -1.0.

    #[test]
    fn e5m2_dequant_all_zeros() {
        let block = make_e5m2_block(5.0, [0x00u8; 32]);
        let mut output = vec![0.0f32; QK_FP8];
        dequant_fp8_e5m2(&[block], &mut output).expect("dequant should succeed");
        for (i, &v) in output.iter().enumerate() {
            assert!(v.abs() < 1e-6, "index {i}: expected 0.0, got {v}");
        }
    }

    #[test]
    fn e5m2_dequant_ones_with_scale() {
        let block = make_e5m2_block(2.0, [0x3Cu8; 32]);
        let mut output = vec![0.0f32; QK_FP8];
        dequant_fp8_e5m2(&[block], &mut output).expect("dequant should succeed");
        for (i, &v) in output.iter().enumerate() {
            assert!((v - 2.0).abs() < 0.02, "index {i}: expected ~2.0, got {v}");
        }
    }

    #[test]
    fn e5m2_dequant_negative_weights() {
        let block = make_e5m2_block(1.0, [0xBCu8; 32]);
        let mut output = vec![0.0f32; QK_FP8];
        dequant_fp8_e5m2(&[block], &mut output).expect("dequant should succeed");
        for (i, &v) in output.iter().enumerate() {
            assert!((v + 1.0).abs() < 0.02, "index {i}: expected ~-1.0, got {v}");
        }
    }

    #[test]
    fn e5m2_dequant_preserves_element_order() {
        // Distinct codes at distinct positions, so a shuffled or shifted write is caught.
        let mut qs = [0x00u8; 32];
        qs[0] = 0x3C;
        qs[1] = 0x40;
        qs[2] = 0x38;
        qs[31] = 0xBC;
        let block = make_e5m2_block(2.0, qs);
        let mut output = vec![0.0f32; QK_FP8];
        dequant_fp8_e5m2(&[block], &mut output).expect("dequant should succeed");
        let mut expected = vec![0.0f32; QK_FP8];
        expected[0] = 2.0;
        expected[1] = 4.0;
        expected[2] = 1.0;
        expected[31] = -2.0;
        for (i, (&v, &e)) in output.iter().zip(&expected).enumerate() {
            assert!((v - e).abs() < 1e-6, "index {i}: expected {e}, got {v}");
        }
    }

    #[test]
    fn e5m2_dequant_empty_blocks() {
        let mut output: Vec<f32> = Vec::new();
        dequant_fp8_e5m2(&[], &mut output).expect("empty dequant should succeed");
    }

    #[test]
    fn e5m2_dequant_buffer_too_small() {
        let block = make_e5m2_block(1.0, [0x00u8; 32]);
        let mut output = vec![0.0f32; 0];
        let result = dequant_fp8_e5m2(&[block], &mut output);
        assert!(
            matches!(result, Err(KernelError::BufferTooSmall { .. })),
            "expected BufferTooSmall, got {result:?}"
        );
    }

    #[test]
    fn e5m2_dequant_oversized_buffer() {
        let block = make_e5m2_block(1.0, [0x3Cu8; 32]);
        let mut output = vec![99.0f32; QK_FP8 + 10];
        dequant_fp8_e5m2(&[block], &mut output).expect("oversized buffer should succeed");
        for (i, &v) in output.iter().enumerate().skip(QK_FP8) {
            assert_eq!(v, 99.0, "trailing element {i} was modified");
        }
    }

    #[test]
    fn e5m2_dequant_two_blocks_independent() {
        let blocks = [
            make_e5m2_block(1.0, [0x3Cu8; 32]),
            make_e5m2_block(4.0, [0x3Cu8; 32]),
        ];
        let mut output = vec![0.0f32; QK_FP8 * 2];
        dequant_fp8_e5m2(&blocks, &mut output).expect("dequant should succeed");
        for (i, &v) in output[..QK_FP8].iter().enumerate() {
            assert!(
                (v - 1.0).abs() < 0.02,
                "block0[{i}]: expected ~1.0, got {v}"
            );
        }
        // Enumerate the second block's slice so `i` is block-relative in messages.
        for (i, &v) in output[QK_FP8..].iter().enumerate() {
            assert!(
                (v - 4.0).abs() < 0.02,
                "block1[{i}]: expected ~4.0, got {v}"
            );
        }
    }
}