use half::f16;
use crate::error::{BonsaiError, BonsaiResult};
/// Number of f32 weights stored in one Q4_0 block.
pub const QK_Q4_0: usize = 32;
/// Byte size of one Q4_0 block: 2-byte f16 scale + 16 bytes of packed nibbles.
pub const BLOCK_Q4_0_BYTES: usize = 18;
/// Number of f32 weights stored in one Q8_0 block.
pub const QK_Q8_0: usize = 32;
/// Byte size of one Q8_0 block: 2-byte f16 scale + 32 i8 values.
pub const BLOCK_Q8_0_BYTES: usize = 34;
/// A Q4_0 quantization block: 32 weights sharing one f16 scale, each
/// stored as a 4-bit code with a +8 offset (two codes per byte, the
/// even-index code in the low nibble).
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ4_0 {
/// Shared scale factor applied to every decoded code in the block.
pub d: f16,
/// 16 bytes holding 32 packed 4-bit codes; code 8 decodes to 0.0.
pub qs: [u8; 16],
}
// Layout guard: `slice_from_bytes` reinterprets raw bytes as blocks, which
// is only sound if the struct is exactly this size with no padding.
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == BLOCK_Q4_0_BYTES);
impl BlockQ4_0 {
    /// Dequantizes `blocks` into `output`, producing `QK_Q4_0` f32 values
    /// per block. `output` may be longer than required; any tail entries
    /// beyond `blocks.len() * QK_Q4_0` are left untouched.
    ///
    /// # Errors
    /// Returns `BonsaiError::KQuantError` when `output` holds fewer than
    /// `blocks.len() * QK_Q4_0` elements.
    pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
        let expected_len = blocks.len() * QK_Q4_0;
        if output.len() < expected_len {
            return Err(BonsaiError::KQuantError {
                reason: format!(
                    "Q4_0 dequant: output len {} < expected {}",
                    output.len(),
                    expected_len
                ),
            });
        }
        // `chunks_exact_mut` + `zip` avoids per-element index arithmetic and
        // bounds checks; zipping with `blocks` stops after the last block.
        for (block, out) in blocks.iter().zip(output.chunks_exact_mut(QK_Q4_0)) {
            let d = block.d.to_f32();
            for (j, slot) in out.iter_mut().enumerate() {
                // Two codes per byte: even j in the low nibble, odd j in the high.
                let byte = block.qs[j / 2];
                let code = if j % 2 == 0 { byte & 0x0F } else { byte >> 4 };
                // Codes carry a +8 offset, so 0..=15 decodes to -8..=7 steps of d.
                *slot = d * (code as f32 - 8.0);
            }
        }
        Ok(())
    }

    /// Quantizes `input` into 4-bit blocks of `QK_Q4_0` values each.
    ///
    /// Each block shares one f16 scale `d = max_abs / 7` and stores 32
    /// 4-bit codes with a +8 offset, so representable values span ±7·d.
    /// Non-finite inputs are handled defensively: Inf/NaN are excluded
    /// from the scale search, ±Inf saturate to the extreme codes, and
    /// NaN encodes as the zero code (8).
    ///
    /// # Errors
    /// Returns `BonsaiError::KQuantError` when `input.len()` is not a
    /// multiple of `QK_Q4_0`.
    pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
        if input.len() % QK_Q4_0 != 0 {
            return Err(BonsaiError::KQuantError {
                reason: format!(
                    "Q4_0 quantize: input len {} not a multiple of {}",
                    input.len(),
                    QK_Q4_0
                ),
            });
        }
        let mut blocks = Vec::with_capacity(input.len() / QK_Q4_0);
        for chunk in input.chunks_exact(QK_Q4_0) {
            // Only finite values may set the scale: a stray Inf would blow
            // `d` up to infinity and destroy every finite weight in the
            // block (the previous filter rejected only NaN).
            let max_abs = chunk
                .iter()
                .filter(|v| v.is_finite())
                .map(|v| v.abs())
                .fold(0.0f32, f32::max);
            if max_abs == 0.0 {
                // All-zero (or all-non-finite) block: zero scale plus the
                // offset-8 "zero" code packed into both nibbles.
                blocks.push(BlockQ4_0 {
                    d: f16::ZERO,
                    qs: [0x88u8; 16],
                });
                continue;
            }
            let d = f16::from_f32(max_abs / 7.0);
            // Invert the f16-rounded scale so encode and decode agree on
            // the exact step size; guard against f16 underflow to zero.
            let scale_actual = d.to_f32();
            let inv_scale = if scale_actual == 0.0 {
                0.0
            } else {
                scale_actual.recip()
            };
            let mut qs = [0u8; 16];
            for (j, &v) in chunk.iter().enumerate() {
                let code = if v.is_nan() {
                    // NaN must map to the zero code explicitly: the saturating
                    // float->int cast would turn NaN into 0, which decodes to
                    // -8*d instead of 0.
                    8u8
                } else {
                    // +0.5 then truncation == round-half-up; the clamp also
                    // saturates +/-Inf to the extreme codes.
                    (v * inv_scale + 8.5).clamp(0.0, 15.0) as u8
                };
                if j % 2 == 0 {
                    qs[j / 2] = code & 0x0F;
                } else {
                    qs[j / 2] |= (code & 0x0F) << 4;
                }
            }
            blocks.push(BlockQ4_0 { d, qs });
        }
        Ok(blocks)
    }

    /// Reinterprets a raw byte buffer as a slice of `BlockQ4_0` without
    /// copying.
    ///
    /// # Errors
    /// Returns `BonsaiError::KQuantError` when the byte length is not a
    /// multiple of `BLOCK_Q4_0_BYTES` or the pointer is insufficiently
    /// aligned for `BlockQ4_0`.
    pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
        if data.len() % BLOCK_Q4_0_BYTES != 0 {
            return Err(BonsaiError::KQuantError {
                reason: format!(
                    "Q4_0 slice_from_bytes: byte len {} not a multiple of {}",
                    data.len(),
                    BLOCK_Q4_0_BYTES
                ),
            });
        }
        let align = std::mem::align_of::<Self>();
        if data.as_ptr().align_offset(align) != 0 {
            return Err(BonsaiError::KQuantError {
                reason: format!("Q4_0 slice_from_bytes: pointer not {}-byte aligned", align),
            });
        }
        let count = data.len() / BLOCK_Q4_0_BYTES;
        let ptr = data.as_ptr() as *const Self;
        // SAFETY: length and alignment are checked above; the type is
        // `repr(C)` with no padding (see the compile-time size assert) and
        // every bit pattern is a valid `f16`/`u8`, so viewing `count` blocks
        // inside `data`'s allocation is sound for the borrowed lifetime.
        Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
    }

    /// Dequantizes this single block into a caller-provided 32-element buffer.
    #[inline]
    pub fn dequant_to_buf(&self, buf: &mut [f32; 32]) {
        let d = self.d.to_f32();
        for (j, out) in buf.iter_mut().enumerate() {
            let byte = self.qs[j / 2];
            let code = if j % 2 == 0 { byte & 0x0F } else { byte >> 4 };
            *out = d * (code as f32 - 8.0);
        }
    }
}
/// A Q8_0 quantization block: 32 weights sharing one f16 scale, each
/// stored as a signed 8-bit code.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ8_0 {
/// Shared scale factor applied to every decoded code in the block.
pub d: f16,
/// 32 signed 8-bit codes; element i decodes as `d * qs[i]`.
pub qs: [i8; 32],
}
// Layout guard: `slice_from_bytes` reinterprets raw bytes as blocks, which
// is only sound if the struct is exactly this size with no padding.
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == BLOCK_Q8_0_BYTES);
impl BlockQ8_0 {
    /// Dequantizes `blocks` into `output`, producing `QK_Q8_0` f32 values
    /// per block. `output` may be longer than required; any tail entries
    /// beyond `blocks.len() * QK_Q8_0` are left untouched.
    ///
    /// # Errors
    /// Returns `BonsaiError::KQuantError` when `output` holds fewer than
    /// `blocks.len() * QK_Q8_0` elements.
    pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
        let expected_len = blocks.len() * QK_Q8_0;
        if output.len() < expected_len {
            return Err(BonsaiError::KQuantError {
                reason: format!(
                    "Q8_0 dequant: output len {} < expected {}",
                    output.len(),
                    expected_len
                ),
            });
        }
        // Iterator zip over exact chunks removes per-element bounds checks.
        for (block, out) in blocks.iter().zip(output.chunks_exact_mut(QK_Q8_0)) {
            let d = block.d.to_f32();
            for (slot, &q) in out.iter_mut().zip(block.qs.iter()) {
                *slot = d * f32::from(q);
            }
        }
        Ok(())
    }

    /// Quantizes `input` into 8-bit blocks of `QK_Q8_0` values each.
    ///
    /// Each block shares one f16 scale `d = max_abs / 127` with codes in
    /// -127..=127. Non-finite inputs are handled defensively: Inf/NaN are
    /// excluded from the scale search, ±Inf saturate to ±127 via the
    /// clamp, and NaN encodes as 0.
    ///
    /// # Errors
    /// Returns `BonsaiError::KQuantError` when `input.len()` is not a
    /// multiple of `QK_Q8_0`.
    pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
        if input.len() % QK_Q8_0 != 0 {
            return Err(BonsaiError::KQuantError {
                reason: format!(
                    "Q8_0 quantize: input len {} not a multiple of {}",
                    input.len(),
                    QK_Q8_0
                ),
            });
        }
        let mut blocks = Vec::with_capacity(input.len() / QK_Q8_0);
        for chunk in input.chunks_exact(QK_Q8_0) {
            // Only finite values may set the scale: a stray Inf would make
            // `d` infinite and destroy every finite weight in the block
            // (the previous filter rejected only NaN).
            let max_abs = chunk
                .iter()
                .filter(|v| v.is_finite())
                .map(|v| v.abs())
                .fold(0.0f32, f32::max);
            if max_abs == 0.0 {
                // All-zero (or all-non-finite) block.
                blocks.push(BlockQ8_0 {
                    d: f16::ZERO,
                    qs: [0i8; 32],
                });
                continue;
            }
            let d = f16::from_f32(max_abs / 127.0);
            // Invert the f16-rounded scale so encode and decode agree on
            // the exact step size; guard against f16 underflow to zero.
            let scale_actual = d.to_f32();
            let inv_scale = if scale_actual == 0.0 {
                0.0
            } else {
                scale_actual.recip()
            };
            let mut qs = [0i8; 32];
            for (j, &v) in chunk.iter().enumerate() {
                qs[j] = if v.is_nan() {
                    // Encode NaN as 0 explicitly rather than relying on the
                    // implicit NaN -> 0 saturating float->int cast.
                    0
                } else {
                    // The clamp saturates +/-Inf (and f16 rounding overshoot)
                    // to the extreme codes.
                    (v * inv_scale).round().clamp(-127.0, 127.0) as i8
                };
            }
            blocks.push(BlockQ8_0 { d, qs });
        }
        Ok(blocks)
    }

    /// Reinterprets a raw byte buffer as a slice of `BlockQ8_0` without
    /// copying.
    ///
    /// # Errors
    /// Returns `BonsaiError::KQuantError` when the byte length is not a
    /// multiple of `BLOCK_Q8_0_BYTES` or the pointer is insufficiently
    /// aligned for `BlockQ8_0`.
    pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
        if data.len() % BLOCK_Q8_0_BYTES != 0 {
            return Err(BonsaiError::KQuantError {
                reason: format!(
                    "Q8_0 slice_from_bytes: byte len {} not a multiple of {}",
                    data.len(),
                    BLOCK_Q8_0_BYTES
                ),
            });
        }
        let align = std::mem::align_of::<Self>();
        if data.as_ptr().align_offset(align) != 0 {
            return Err(BonsaiError::KQuantError {
                reason: format!("Q8_0 slice_from_bytes: pointer not {}-byte aligned", align),
            });
        }
        let count = data.len() / BLOCK_Q8_0_BYTES;
        let ptr = data.as_ptr() as *const Self;
        // SAFETY: length and alignment are checked above; the type is
        // `repr(C)` with no padding (see the compile-time size assert) and
        // every bit pattern is a valid `f16`/`i8`, so viewing `count` blocks
        // inside `data`'s allocation is sound for the borrowed lifetime.
        Ok(unsafe { std::slice::from_raw_parts(ptr, count) })
    }

    /// Dequantizes this single block into a caller-provided 32-element buffer.
    #[inline]
    pub fn dequant_to_buf(&self, buf: &mut [f32; 32]) {
        let d = self.d.to_f32();
        for (slot, &q) in buf.iter_mut().zip(self.qs.iter()) {
            *slot = d * f32::from(q);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest absolute element-wise difference between two slices.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b)
            .fold(0.0f32, |acc, (x, y)| acc.max((x - y).abs()))
    }

    #[test]
    fn q4_0_block_size_correct() {
        assert_eq!(std::mem::size_of::<BlockQ4_0>(), BLOCK_Q4_0_BYTES);
        assert_eq!(BLOCK_Q4_0_BYTES, 18);
    }

    #[test]
    fn q8_0_block_size_correct() {
        assert_eq!(std::mem::size_of::<BlockQ8_0>(), BLOCK_Q8_0_BYTES);
        assert_eq!(BLOCK_Q8_0_BYTES, 34);
    }

    #[test]
    fn qk_constants_correct() {
        assert_eq!(QK_Q4_0, 32);
        assert_eq!(QK_Q8_0, 32);
    }

    #[test]
    fn q4_0_dequant_roundtrip() {
        let src: Vec<f32> = (0..32).map(|i| (i as f32) * 0.5 - 7.5).collect();
        let qblocks = BlockQ4_0::quantize(&src).unwrap();
        assert_eq!(qblocks.len(), 1);
        let mut decoded = vec![0.0f32; 32];
        BlockQ4_0::dequant(&qblocks, &mut decoded).unwrap();
        let max_err = max_abs_diff(&src, &decoded);
        assert!(
            max_err < 1.5,
            "Q4_0 round-trip max error: {max_err} (values range ±7.5)"
        );
    }

    #[test]
    fn q4_0_all_zeros() {
        let qblocks = BlockQ4_0::quantize(&[0.0f32; 32]).unwrap();
        let mut decoded = [0.0f32; 32];
        BlockQ4_0::dequant(&qblocks, &mut decoded).unwrap();
        assert!(
            decoded.iter().all(|&x| x == 0.0),
            "all-zero input should give all-zero output"
        );
    }

    #[test]
    fn q4_0_nibble_extremes() {
        let mut src = [0.0f32; 32];
        src[0] = 7.0;
        src[1] = -7.0;
        let qblocks = BlockQ4_0::quantize(&src).unwrap();
        let mut decoded = [0.0f32; 32];
        BlockQ4_0::dequant(&qblocks, &mut decoded).unwrap();
        assert!(
            (decoded[0] - 7.0).abs() < 1.1,
            "max weight round-trip: got {}",
            decoded[0]
        );
        assert!(
            (decoded[1] + 7.0).abs() < 1.1,
            "min weight round-trip: got {}",
            decoded[1]
        );
    }

    #[test]
    fn q4_0_slice_from_bytes_valid() {
        let sample = BlockQ4_0 {
            d: f16::from_f32(1.0),
            qs: [0x88u8; 16],
        };
        let raw: &[u8] = unsafe {
            std::slice::from_raw_parts((&sample as *const BlockQ4_0).cast::<u8>(), BLOCK_Q4_0_BYTES)
        };
        let parsed = BlockQ4_0::slice_from_bytes(raw).expect("aligned slice should succeed");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].d, f16::from_f32(1.0));
    }

    #[test]
    fn q4_0_slice_from_bytes_bad_len() {
        // 17 bytes is one byte short of a whole block.
        let raw = vec![0u8; 17];
        assert!(
            BlockQ4_0::slice_from_bytes(&raw).is_err(),
            "bad length should be rejected"
        );
    }

    #[test]
    fn q4_0_block_count_validation() {
        let src = vec![1.0f32; 96];
        let qblocks = BlockQ4_0::quantize(&src).unwrap();
        assert_eq!(qblocks.len(), 3);
    }

    #[test]
    fn q4_0_quantize_wrong_len() {
        assert!(
            BlockQ4_0::quantize(&[1.0f32; 15]).is_err(),
            "non-multiple of 32 should be rejected"
        );
    }

    #[test]
    fn q4_0_dequant_too_small_buffer() {
        let qblocks = BlockQ4_0::quantize(&[1.0f32; 32]).unwrap();
        let mut small = vec![0.0f32; 10];
        assert!(
            BlockQ4_0::dequant(&qblocks, &mut small).is_err(),
            "output too small should be rejected"
        );
    }

    #[test]
    fn q4_0_dequant_to_buf_matches_dequant() {
        let src: Vec<f32> = (0..32).map(|i| (i as f32) - 16.0).collect();
        let qblocks = BlockQ4_0::quantize(&src).unwrap();
        let mut via_slice = vec![0.0f32; 32];
        BlockQ4_0::dequant(&qblocks, &mut via_slice).unwrap();
        let mut via_buf = [0.0f32; 32];
        qblocks[0].dequant_to_buf(&mut via_buf);
        for (a, b) in via_slice.iter().zip(via_buf.iter()) {
            assert!((a - b).abs() < 1e-6, "dequant_to_buf must match dequant");
        }
    }

    #[test]
    fn q4_0_multi_block_no_nan() {
        let src: Vec<f32> = (0..64).map(|i| (i as f32) * 0.25 - 8.0).collect();
        let qblocks = BlockQ4_0::quantize(&src).unwrap();
        assert_eq!(qblocks.len(), 2);
        let mut decoded = vec![0.0f32; 64];
        BlockQ4_0::dequant(&qblocks, &mut decoded).unwrap();
        assert!(decoded.iter().all(|x| !x.is_nan()), "no NaN in output");
    }

    #[test]
    fn q4_0_scale_nonzero_for_nonzero_input() {
        let qblocks = BlockQ4_0::quantize(&[1.0f32; 32]).unwrap();
        assert_ne!(qblocks[0].d, f16::ZERO, "scale must be non-zero");
    }

    #[test]
    fn q8_0_dequant_roundtrip() {
        let src: Vec<f32> = (0..32).map(|i| (i as f32) * 0.1 - 1.6).collect();
        let qblocks = BlockQ8_0::quantize(&src).unwrap();
        let mut decoded = vec![0.0f32; 32];
        BlockQ8_0::dequant(&qblocks, &mut decoded).unwrap();
        let max_err = max_abs_diff(&src, &decoded);
        assert!(
            max_err < 0.05,
            "Q8_0 round-trip max error: {max_err} (8-bit should be very accurate)"
        );
    }

    #[test]
    fn q8_0_all_zeros() {
        let qblocks = BlockQ8_0::quantize(&[0.0f32; 32]).unwrap();
        let mut decoded = [0.0f32; 32];
        BlockQ8_0::dequant(&qblocks, &mut decoded).unwrap();
        assert!(decoded.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn q8_0_int8_extremes() {
        let mut src = [0.0f32; 32];
        src[0] = 127.0;
        src[1] = -127.0;
        let qblocks = BlockQ8_0::quantize(&src).unwrap();
        let scale = qblocks[0].d.to_f32();
        assert!((scale - 1.0).abs() < 0.01, "scale should be ~1.0: {scale}");
        assert_eq!(qblocks[0].qs[0], 127, "max quantized to 127");
        assert_eq!(qblocks[0].qs[1], -127, "min quantized to -127");
    }

    #[test]
    fn q8_0_slice_alignment() {
        let sample = BlockQ8_0 {
            d: f16::from_f32(2.0),
            qs: [0i8; 32],
        };
        let raw: &[u8] = unsafe {
            std::slice::from_raw_parts((&sample as *const BlockQ8_0).cast::<u8>(), BLOCK_Q8_0_BYTES)
        };
        let parsed = BlockQ8_0::slice_from_bytes(raw).expect("aligned slice should succeed");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].d, f16::from_f32(2.0));
    }

    #[test]
    fn q8_0_quantize_scale() {
        let mut src = [0.0f32; 32];
        src[5] = 63.5;
        let qblocks = BlockQ8_0::quantize(&src).unwrap();
        let scale = qblocks[0].d.to_f32();
        assert!(
            (scale - 0.5).abs() < 0.02,
            "scale should be ~0.5 for max=63.5, got {scale}"
        );
    }

    #[test]
    fn q8_0_slice_bad_len() {
        // 35 bytes is one byte past a whole block.
        let raw = vec![0u8; 35];
        assert!(BlockQ8_0::slice_from_bytes(&raw).is_err());
    }

    #[test]
    fn q8_0_quantize_wrong_len() {
        assert!(BlockQ8_0::quantize(&[1.0f32; 17]).is_err());
    }

    #[test]
    fn q8_0_dequant_too_small_buffer() {
        let qblocks = BlockQ8_0::quantize(&[0.0f32; 32]).unwrap();
        let mut small = vec![0.0f32; 5];
        assert!(BlockQ8_0::dequant(&qblocks, &mut small).is_err());
    }

    #[test]
    fn q8_0_dequant_to_buf_matches_dequant() {
        let src: Vec<f32> = (0..32).map(|i| (i as f32) * 3.0 - 48.0).collect();
        let qblocks = BlockQ8_0::quantize(&src).unwrap();
        let mut via_slice = vec![0.0f32; 32];
        BlockQ8_0::dequant(&qblocks, &mut via_slice).unwrap();
        let mut via_buf = [0.0f32; 32];
        qblocks[0].dequant_to_buf(&mut via_buf);
        for (a, b) in via_slice.iter().zip(via_buf.iter()) {
            assert!((a - b).abs() < 1e-6, "dequant_to_buf must match dequant");
        }
    }

    #[test]
    fn q8_0_positive_negative_mix() {
        let src: Vec<f32> = (0..32)
            .map(|i| if i % 2 == 0 { i as f32 } else { -(i as f32) })
            .collect();
        let qblocks = BlockQ8_0::quantize(&src).unwrap();
        let mut decoded = vec![0.0f32; 32];
        BlockQ8_0::dequant(&qblocks, &mut decoded).unwrap();
        for i in (2..32).step_by(2) {
            assert!(
                decoded[i] >= 0.0,
                "even index should be non-negative: {}",
                decoded[i]
            );
        }
        for i in (1..32).step_by(2) {
            assert!(
                decoded[i] <= 0.0,
                "odd index should be non-positive: {}",
                decoded[i]
            );
        }
    }

    #[test]
    fn q8_0_block_count_correct() {
        let qblocks = BlockQ8_0::quantize(&[1.0f32; 96]).unwrap();
        assert_eq!(qblocks.len(), 3);
    }
}