use half::f16;
use crate::error::{BonsaiError, BonsaiResult};
/// Number of weights covered by one block of any of the K-quant formats in this file.
pub const QK_K: usize = 256;
/// Size in bytes of one `BlockQ2K`; enforced against `size_of` by a compile-time assertion.
pub const BLOCK_Q2_K_BYTES: usize = 84;
/// Size in bytes of one `BlockQ3K`; enforced against `size_of` by a compile-time assertion.
pub const BLOCK_Q3K_BYTES: usize = 110;
/// Size in bytes of one `BlockQ4K`; enforced against `size_of` by a compile-time assertion.
pub const BLOCK_Q4_K_BYTES: usize = 144;
/// Size in bytes of one `BlockQ8K`; enforced against `size_of` by a compile-time assertion.
pub const BLOCK_Q8K_BYTES: usize = 292;
/// One 2-bit quantized block holding `QK_K` weights, split into 16 sub-blocks of 16 weights.
///
/// `dequant` reconstructs each weight as `d * scale_code * q - dmin * min_code`.
/// The layout is `repr(C)` so a byte buffer can be reinterpreted by `slice_from_bytes`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ2K {
/// One byte per sub-block: low nibble is the scale code, high nibble is the min code.
pub scales: [u8; 16],
/// 2-bit weight codes, four per byte, with the first weight in the lowest two bits.
pub qs: [u8; 64],
/// Block-wide multiplier applied to the 4-bit scale codes.
pub d: f16,
/// Block-wide multiplier applied to the 4-bit min codes.
pub dmin: f16,
}
// Compilation fails if the struct layout ever stops matching the declared byte size.
const _: () = assert!(std::mem::size_of::<BlockQ2K>() == BLOCK_Q2_K_BYTES);
impl BlockQ2K {
/// Expands `blocks` into `output`, writing `QK_K` floats per block.
///
/// Each weight is reconstructed as `d * scale_code * q - dmin * min_code`. Elements of
/// `output` beyond `blocks.len() * QK_K` are left untouched.
///
/// # Errors
///
/// Returns `BonsaiError::KQuantError` when `output` is shorter than `blocks.len() * QK_K`.
pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
    let needed = blocks.len() * QK_K;
    if output.len() < needed {
        return Err(BonsaiError::KQuantError {
            reason: format!(
                "Q2_K dequant: output len {} < expected {}",
                output.len(),
                needed
            ),
        });
    }
    for (block, out_block) in blocks.iter().zip(output.chunks_exact_mut(QK_K)) {
        let d = block.d.to_f32();
        let dmin = block.dmin.to_f32();
        for (sub, out_sub) in out_block.chunks_exact_mut(16).enumerate() {
            let packed = block.scales[sub];
            // Low nibble scales the codes, high nibble is the offset to subtract.
            let scale = d * f32::from(packed & 0x0F);
            let offset = dmin * f32::from(packed >> 4);
            // 16 weights at 2 bits each occupy exactly 4 bytes of `qs`.
            let codes = &block.qs[sub * 4..sub * 4 + 4];
            for (j, slot) in out_sub.iter_mut().enumerate() {
                let code = (codes[j / 4] >> ((j % 4) * 2)) & 0x03;
                *slot = scale * f32::from(code) - offset;
            }
        }
    }
    Ok(())
}
pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
if input.len() % QK_K != 0 {
return Err(BonsaiError::KQuantError {
reason: format!(
"Q2_K quantize: input len {} not a multiple of {}",
input.len(),
QK_K
),
});
}
let num_blocks = input.len() / QK_K;
let mut blocks = Vec::with_capacity(num_blocks);
for block_idx in 0..num_blocks {
let base = block_idx * QK_K;
let chunk = &input[base..base + QK_K];
let mut sub_scales = [0.0f32; 16];
let mut sub_mins = [0.0f32; 16];
for sub in 0..16 {
let sub_offset = sub * 16;
let sub_chunk = &chunk[sub_offset..sub_offset + 16];
let mut smin = f32::MAX;
let mut smax = f32::MIN;
for &v in sub_chunk {
if v < smin {
smin = v;
}
if v > smax {
smax = v;
}
}
sub_mins[sub] = if smin < 0.0 { -smin } else { 0.0 };
let range = smax + sub_mins[sub];
sub_scales[sub] = if range > 0.0 { range / 3.0 } else { 0.0 };
}
let max_scale = sub_scales.iter().copied().fold(0.0f32, f32::max);
let max_min = sub_mins.iter().copied().fold(0.0f32, f32::max);
let d = if max_scale > 0.0 {
max_scale / 15.0
} else {
0.0
};
let dmin = if max_min > 0.0 { max_min / 15.0 } else { 0.0 };
let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
let inv_dmin = if dmin > 0.0 { 1.0 / dmin } else { 0.0 };
let mut scales = [0u8; 16];
let mut quant_sc = [0u8; 16];
let mut quant_mn = [0u8; 16];
for sub in 0..16 {
let sc = (sub_scales[sub] * inv_d + 0.5).min(15.0) as u8;
let mn = (sub_mins[sub] * inv_dmin + 0.5).min(15.0) as u8;
quant_sc[sub] = sc;
quant_mn[sub] = mn;
scales[sub] = sc | (mn << 4);
}
let mut qs = [0u8; 64];
for sub in 0..16 {
let sub_offset = sub * 16;
let sc_f = d * (quant_sc[sub] as f32);
let mn_f = dmin * (quant_mn[sub] as f32);
let inv_sc = if sc_f > 0.0 { 1.0 / sc_f } else { 0.0 };
for j in 0..16 {
let global_idx = sub_offset + j;
let val = chunk[global_idx] + mn_f;
let q = (val * inv_sc + 0.5).clamp(0.0, 3.0) as u8;
let byte_idx = global_idx / 4;
let shift = (global_idx % 4) * 2;
qs[byte_idx] |= q << shift;
}
}
blocks.push(BlockQ2K {
scales,
qs,
d: f16::from_f32(d),
dmin: f16::from_f32(dmin),
});
}
Ok(blocks)
}
/// Appends the dequantized contents of `blocks_for_row` to the end of `buf`.
///
/// `buf` grows by `blocks_for_row.len() * QK_K` elements; existing contents are kept.
pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
    let old_len = buf.len();
    buf.resize(old_len + blocks_for_row.len() * QK_K, 0.0f32);
    // The tail was just sized to exactly what `dequant` requires, so its length check
    // cannot trip; the result is discarded on purpose.
    let tail = &mut buf[old_len..];
    let _ = Self::dequant(blocks_for_row, tail);
}
/// Views `data` as a slice of `BlockQ2K` without copying.
///
/// An empty `data` yields an empty slice regardless of its address.
///
/// # Errors
///
/// Returns `BonsaiError::KQuantError` when `data.len()` is not a multiple of
/// `BLOCK_Q2_K_BYTES`, or when non-empty `data` does not start on an
/// `align_of::<Self>()` boundary.
pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
    if data.len() % BLOCK_Q2_K_BYTES != 0 {
        return Err(BonsaiError::KQuantError {
            reason: format!(
                "Q2_K slice_from_bytes: byte len {} not a multiple of {}",
                data.len(),
                BLOCK_Q2_K_BYTES
            ),
        });
    }
    if data.is_empty() {
        return Ok(&[]);
    }
    let align = std::mem::align_of::<Self>();
    // Test the address itself. `align_offset` is documented as permitted to return
    // `usize::MAX` for any input, so using it here could reject correctly aligned data.
    if (data.as_ptr() as usize) % align != 0 {
        return Err(BonsaiError::KQuantError {
            reason: format!("Q2_K slice_from_bytes: pointer not {}-byte aligned", align),
        });
    }
    let count = data.len() / BLOCK_Q2_K_BYTES;
    // SAFETY: the pointer comes from a live slice (non-null), was just checked to be
    // aligned for `Self`, and spans `count * size_of::<Self>()` bytes because
    // `size_of::<Self>() == BLOCK_Q2_K_BYTES` is asserted at compile time. The returned
    // slice borrows `data`, so it cannot outlive the bytes. The fields are `u8` arrays
    // and `f16`; NOTE(review): this assumes every bit pattern is a valid `f16` — confirm
    // against the `half` crate.
    Ok(unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<Self>(), count) })
}
}
/// One 3-bit quantized block holding `QK_K` weights, split into 16 sub-blocks of 16 weights.
///
/// `dequant` reconstructs each weight as `d * (scale_nibble - 8) * (q3 - 4)`.
/// The layout is `repr(C)` so a byte buffer can be reinterpreted by `slice_from_bytes`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ3K {
/// Bit `i % 8` of byte `i / 8` is the high (third) bit of weight `i`'s code.
pub hmask: [u8; 32],
/// Low two bits of each weight code, four per byte, first weight in the lowest bits.
pub qs: [u8; 64],
/// 4-bit sub-block scale codes, two per byte (even sub-block in the low nibble).
/// The code in this file reads and writes only the first 8 bytes.
pub scales: [u8; 12],
/// Block-wide multiplier.
pub d: f16,
}
// Compilation fails if the struct layout ever stops matching the declared byte size.
const _: () = assert!(std::mem::size_of::<BlockQ3K>() == BLOCK_Q3K_BYTES);
impl BlockQ3K {
/// Expands `blocks` into `output`, writing `QK_K` floats per block.
///
/// Each weight is `d * (scale_nibble - 8) * (q3 - 4)`, where `q3` combines two low bits
/// from `qs` with one high bit from `hmask`. Elements of `output` beyond
/// `blocks.len() * QK_K` are left untouched.
///
/// # Errors
///
/// Returns `BonsaiError::KQuantError` when `output` is shorter than `blocks.len() * QK_K`.
pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
    let needed = blocks.len() * QK_K;
    if output.len() < needed {
        return Err(BonsaiError::KQuantError {
            reason: format!(
                "Q3_K dequant: output len {} < expected {}",
                output.len(),
                needed
            ),
        });
    }
    for (block, out_block) in blocks.iter().zip(output.chunks_exact_mut(QK_K)) {
        let d = block.d.to_f32();
        for (sub, out_sub) in out_block.chunks_exact_mut(16).enumerate() {
            // Two scale nibbles per byte; stored with a +8 bias.
            let nibble = (block.scales[sub / 2] >> (4 * (sub % 2))) & 0x0F;
            let scale = d * ((i32::from(nibble) - 8) as f32);
            for (j, slot) in out_sub.iter_mut().enumerate() {
                let i = sub * 16 + j;
                let low = (block.qs[i / 4] >> ((i % 4) * 2)) & 0x03;
                let high = (block.hmask[i / 8] >> (i % 8)) & 0x01;
                // Codes are stored with a +4 bias.
                let code = i32::from(low | (high << 2)) - 4;
                *slot = scale * (code as f32);
            }
        }
    }
    Ok(())
}
/// Appends the dequantized contents of `blocks_for_row` to the end of `buf`.
///
/// `buf` grows by `blocks_for_row.len() * QK_K` elements; existing contents are kept.
pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
    let old_len = buf.len();
    buf.resize(old_len + blocks_for_row.len() * QK_K, 0.0f32);
    // The tail was just sized to exactly what `dequant` requires, so its length check
    // cannot trip; the result is discarded on purpose.
    let tail = &mut buf[old_len..];
    let _ = Self::dequant(blocks_for_row, tail);
}
pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
if input.len() % QK_K != 0 {
return Err(BonsaiError::KQuantError {
reason: format!(
"Q3_K quantize: input len {} not a multiple of {}",
input.len(),
QK_K
),
});
}
let num_blocks = input.len() / QK_K;
let mut blocks = Vec::with_capacity(num_blocks);
for block_idx in 0..num_blocks {
let chunk = &input[block_idx * QK_K..block_idx * QK_K + QK_K];
let mut sub_max_abs = [0.0f32; 16];
for (sub, slot) in sub_max_abs.iter_mut().enumerate() {
let sub_chunk = &chunk[sub * 16..(sub + 1) * 16];
*slot = sub_chunk.iter().map(|&v| v.abs()).fold(0.0f32, f32::max);
}
let overall_max = sub_max_abs.iter().copied().fold(0.0f32, f32::max);
let d = if overall_max > 0.0 {
overall_max / 21.0
} else {
0.0
};
let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
let mut scale_nibbles = [0u8; 16];
for (sub, &max_abs) in sub_max_abs.iter().enumerate() {
let sc_f = if d > 0.0 { max_abs * inv_d / 3.0 } else { 0.0 };
let sc_signed = sc_f.round().clamp(-8.0, 7.0) as i32;
scale_nibbles[sub] = (sc_signed + 8).clamp(0, 15) as u8;
}
let mut scales = [0u8; 12];
for (sub, &nibble_val) in scale_nibbles.iter().enumerate() {
let byte_idx = sub / 2;
let nibble = nibble_val & 0x0F;
if sub % 2 == 0 {
scales[byte_idx] |= nibble;
} else {
scales[byte_idx] |= nibble << 4;
}
}
let mut hmask = [0u8; 32];
let mut qs = [0u8; 64];
for i in 0..QK_K {
let sub = i / 16;
let sc_signed = (scale_nibbles[sub] as i32) - 8;
let eff_scale = d * (sc_signed as f32);
let inv_eff = if eff_scale.abs() > 1e-9 {
1.0 / eff_scale
} else {
0.0
};
let q3_signed = (chunk[i] * inv_eff).round() as i32;
let q3 = (q3_signed + 4).clamp(0, 7) as u8;
let lo2 = q3 & 0x03;
let byte_idx = i / 4;
let bit_shift = (i % 4) * 2;
qs[byte_idx] |= lo2 << bit_shift;
let hi1 = (q3 >> 2) & 0x01;
hmask[i / 8] |= hi1 << (i % 8);
}
blocks.push(BlockQ3K {
hmask,
qs,
scales,
d: f16::from_f32(d),
});
}
Ok(blocks)
}
/// Views `data` as a slice of `BlockQ3K` without copying.
///
/// An empty `data` yields an empty slice regardless of its address.
///
/// # Errors
///
/// Returns `BonsaiError::KQuantError` when `data.len()` is not a multiple of
/// `BLOCK_Q3K_BYTES`, or when non-empty `data` does not start on an
/// `align_of::<Self>()` boundary.
pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
    if data.len() % BLOCK_Q3K_BYTES != 0 {
        return Err(BonsaiError::KQuantError {
            reason: format!(
                "Q3_K slice_from_bytes: byte len {} not a multiple of {}",
                data.len(),
                BLOCK_Q3K_BYTES
            ),
        });
    }
    if data.is_empty() {
        return Ok(&[]);
    }
    let align = std::mem::align_of::<Self>();
    // Test the address itself. `align_offset` is documented as permitted to return
    // `usize::MAX` for any input, so using it here could reject correctly aligned data.
    if (data.as_ptr() as usize) % align != 0 {
        return Err(BonsaiError::KQuantError {
            reason: format!("Q3_K slice_from_bytes: pointer not {}-byte aligned", align),
        });
    }
    let count = data.len() / BLOCK_Q3K_BYTES;
    // SAFETY: the pointer comes from a live slice (non-null), was just checked to be
    // aligned for `Self`, and spans `count * size_of::<Self>()` bytes because
    // `size_of::<Self>() == BLOCK_Q3K_BYTES` is asserted at compile time. The returned
    // slice borrows `data`, so it cannot outlive the bytes. The fields are `u8` arrays
    // and `f16`; NOTE(review): this assumes every bit pattern is a valid `f16` — confirm
    // against the `half` crate.
    Ok(unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<Self>(), count) })
}
}
/// One 4-bit quantized block holding `QK_K` weights, split into 8 sub-blocks of 32 weights.
///
/// `dequant` reconstructs each weight as `d * scale_code * q - dmin * min_code`.
/// The layout is `repr(C)` so a byte buffer can be reinterpreted by `slice_from_bytes`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ4K {
/// Block-wide multiplier applied to the 6-bit scale codes.
pub d: f16,
/// Block-wide multiplier applied to the 6-bit min codes.
pub dmin: f16,
/// Eight 6-bit scale codes and eight 6-bit min codes, packed as laid out by
/// `encode_q4k_scales` / `decode_q4k_scales`.
pub scales: [u8; 12],
/// 4-bit weight codes, two per byte, with the even-indexed weight in the low nibble.
pub qs: [u8; 128],
}
// Compilation fails if the struct layout ever stops matching the declared byte size.
const _: () = assert!(std::mem::size_of::<BlockQ4K>() == BLOCK_Q4_K_BYTES);
/// Unpacks the 12-byte Q4_K scale field into eight 6-bit scales and eight 6-bit mins.
///
/// Layout of `scales_raw`:
/// - bytes 0..4: low nibbles of the scales, two per byte (even index in the low nibble);
/// - bytes 4..8: low nibbles of the mins, packed the same way;
/// - bytes 8..10: top two bits of scales 0..4 and 4..8, four entries per byte;
/// - bytes 10..12: top two bits of mins 0..4 and 4..8, four entries per byte.
///
/// Returns `(scales, mins)`.
fn decode_q4k_scales(scales_raw: &[u8; 12]) -> ([u8; 8], [u8; 8]) {
    let mut sc = [0u8; 8];
    let mut mn = [0u8; 8];
    for idx in 0..8 {
        let nibble_shift = 4 * (idx % 2);
        let top_shift = 2 * (idx % 4);
        let sc_low = (scales_raw[idx / 2] >> nibble_shift) & 0x0F;
        let mn_low = (scales_raw[4 + idx / 2] >> nibble_shift) & 0x0F;
        let sc_top = (scales_raw[8 + idx / 4] >> top_shift) & 0x03;
        let mn_top = (scales_raw[10 + idx / 4] >> top_shift) & 0x03;
        sc[idx] = sc_low | (sc_top << 4);
        mn[idx] = mn_low | (mn_top << 4);
    }
    (sc, mn)
}
/// Packs eight 6-bit scales and eight 6-bit mins into the 12-byte Q4_K scale field.
///
/// Layout of the result:
/// - bytes 0..4: low nibbles of the scales, two per byte (even index in the low nibble);
/// - bytes 4..8: low nibbles of the mins, packed the same way;
/// - bytes 8..10: top two bits of scales 0..4 and 4..8, four entries per byte;
/// - bytes 10..12: top two bits of mins 0..4 and 4..8, four entries per byte.
///
/// Only the low six bits of every input value are stored.
fn encode_q4k_scales(sc: &[u8; 8], mn: &[u8; 8]) -> [u8; 12] {
    let mut out = [0u8; 12];
    for idx in 0..8 {
        let nibble_shift = 4 * (idx % 2);
        let top_shift = 2 * (idx % 4);
        out[idx / 2] |= (sc[idx] & 0x0F) << nibble_shift;
        out[4 + idx / 2] |= (mn[idx] & 0x0F) << nibble_shift;
        out[8 + idx / 4] |= ((sc[idx] >> 4) & 0x03) << top_shift;
        out[10 + idx / 4] |= ((mn[idx] >> 4) & 0x03) << top_shift;
    }
    out
}
impl BlockQ4K {
/// Expands `blocks` into `output`, writing `QK_K` floats per block.
///
/// Each weight is reconstructed as `d * scale_code * q - dmin * min_code`, with the
/// 6-bit scale/min codes taken from `decode_q4k_scales`. Elements of `output` beyond
/// `blocks.len() * QK_K` are left untouched.
///
/// # Errors
///
/// Returns `BonsaiError::KQuantError` when `output` is shorter than `blocks.len() * QK_K`.
pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
    let needed = blocks.len() * QK_K;
    if output.len() < needed {
        return Err(BonsaiError::KQuantError {
            reason: format!(
                "Q4_K dequant: output len {} < expected {}",
                output.len(),
                needed
            ),
        });
    }
    for (block, out_block) in blocks.iter().zip(output.chunks_exact_mut(QK_K)) {
        let d = block.d.to_f32();
        let dmin = block.dmin.to_f32();
        let (sc, mn) = decode_q4k_scales(&block.scales);
        for (sub, out_sub) in out_block.chunks_exact_mut(32).enumerate() {
            let scale = d * f32::from(sc[sub]);
            let offset = dmin * f32::from(mn[sub]);
            // 32 weights at 4 bits each occupy exactly 16 bytes of `qs`.
            let packed = &block.qs[sub * 16..(sub + 1) * 16];
            for (pair, &byte) in out_sub.chunks_exact_mut(2).zip(packed) {
                // Even-indexed weight lives in the low nibble, odd-indexed in the high.
                pair[0] = scale * f32::from(byte & 0x0F) - offset;
                pair[1] = scale * f32::from(byte >> 4) - offset;
            }
        }
    }
    Ok(())
}
pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
if input.len() % QK_K != 0 {
return Err(BonsaiError::KQuantError {
reason: format!(
"Q4_K quantize: input len {} not a multiple of {}",
input.len(),
QK_K
),
});
}
let num_blocks = input.len() / QK_K;
let mut blocks = Vec::with_capacity(num_blocks);
for block_idx in 0..num_blocks {
let base = block_idx * QK_K;
let chunk = &input[base..base + QK_K];
let mut sub_scales = [0.0f32; 8];
let mut sub_mins = [0.0f32; 8];
for sub in 0..8 {
let sub_offset = sub * 32;
let sub_chunk = &chunk[sub_offset..sub_offset + 32];
let mut smin = f32::MAX;
let mut smax = f32::MIN;
for &v in sub_chunk {
if v < smin {
smin = v;
}
if v > smax {
smax = v;
}
}
sub_mins[sub] = if smin < 0.0 { -smin } else { 0.0 };
let range = smax + sub_mins[sub];
sub_scales[sub] = if range > 0.0 { range / 15.0 } else { 0.0 };
}
let max_scale = sub_scales.iter().copied().fold(0.0f32, f32::max);
let max_min = sub_mins.iter().copied().fold(0.0f32, f32::max);
let d = if max_scale > 0.0 {
max_scale / 63.0
} else {
0.0
};
let dmin = if max_min > 0.0 { max_min / 63.0 } else { 0.0 };
let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
let inv_dmin = if dmin > 0.0 { 1.0 / dmin } else { 0.0 };
let mut sc = [0u8; 8];
let mut mn = [0u8; 8];
for sub in 0..8 {
sc[sub] = (sub_scales[sub] * inv_d + 0.5).min(63.0) as u8;
mn[sub] = (sub_mins[sub] * inv_dmin + 0.5).min(63.0) as u8;
}
let scales = encode_q4k_scales(&sc, &mn);
let mut qs = [0u8; 128];
for sub in 0..8 {
let sub_offset = sub * 32;
let sc_f = d * (sc[sub] as f32);
let mn_f = dmin * (mn[sub] as f32);
let inv_sc = if sc_f > 0.0 { 1.0 / sc_f } else { 0.0 };
for j in 0..32 {
let global_idx = sub_offset + j;
let val = chunk[global_idx] + mn_f;
let q = (val * inv_sc + 0.5).clamp(0.0, 15.0) as u8;
let byte_idx = global_idx / 2;
if global_idx % 2 == 0 {
qs[byte_idx] |= q & 0x0F;
} else {
qs[byte_idx] |= (q & 0x0F) << 4;
}
}
}
blocks.push(BlockQ4K {
d: f16::from_f32(d),
dmin: f16::from_f32(dmin),
scales,
qs,
});
}
Ok(blocks)
}
/// Appends the dequantized contents of `blocks_for_row` to the end of `buf`.
///
/// `buf` grows by `blocks_for_row.len() * QK_K` elements; existing contents are kept.
pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
    let old_len = buf.len();
    buf.resize(old_len + blocks_for_row.len() * QK_K, 0.0f32);
    // The tail was just sized to exactly what `dequant` requires, so its length check
    // cannot trip; the result is discarded on purpose.
    let tail = &mut buf[old_len..];
    let _ = Self::dequant(blocks_for_row, tail);
}
/// Views `data` as a slice of `BlockQ4K` without copying.
///
/// An empty `data` yields an empty slice regardless of its address.
///
/// # Errors
///
/// Returns `BonsaiError::KQuantError` when `data.len()` is not a multiple of
/// `BLOCK_Q4_K_BYTES`, or when non-empty `data` does not start on an
/// `align_of::<Self>()` boundary.
pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
    if data.len() % BLOCK_Q4_K_BYTES != 0 {
        return Err(BonsaiError::KQuantError {
            reason: format!(
                "Q4_K slice_from_bytes: byte len {} not a multiple of {}",
                data.len(),
                BLOCK_Q4_K_BYTES
            ),
        });
    }
    if data.is_empty() {
        return Ok(&[]);
    }
    let align = std::mem::align_of::<Self>();
    // Test the address itself. `align_offset` is documented as permitted to return
    // `usize::MAX` for any input, so using it here could reject correctly aligned data.
    if (data.as_ptr() as usize) % align != 0 {
        return Err(BonsaiError::KQuantError {
            reason: format!("Q4_K slice_from_bytes: pointer not {}-byte aligned", align),
        });
    }
    let count = data.len() / BLOCK_Q4_K_BYTES;
    // SAFETY: the pointer comes from a live slice (non-null), was just checked to be
    // aligned for `Self`, and spans `count * size_of::<Self>()` bytes because
    // `size_of::<Self>() == BLOCK_Q4_K_BYTES` is asserted at compile time. The returned
    // slice borrows `data`, so it cannot outlive the bytes. The fields are `u8` arrays
    // and `f16`; NOTE(review): this assumes every bit pattern is a valid `f16` — confirm
    // against the `half` crate.
    Ok(unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<Self>(), count) })
}
}
/// One 8-bit quantized block holding `QK_K` weights.
///
/// `dequant` reconstructs each weight as `d * qs[i]`.
/// The layout is `repr(C)` so a byte buffer can be reinterpreted by `slice_from_bytes`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct BlockQ8K {
/// Block-wide multiplier (full `f32`, unlike the other block types in this file).
pub d: f32,
/// Signed 8-bit weight codes, one per weight.
pub qs: [i8; 256],
/// Sum of the codes in each consecutive group of 16, as filled in by `quantize`.
/// `dequant` does not read this field.
pub bsums: [i16; 16],
}
// Compilation fails if the struct layout ever stops matching the declared byte size.
const _: () = assert!(std::mem::size_of::<BlockQ8K>() == BLOCK_Q8K_BYTES);
impl BlockQ8K {
/// Expands `blocks` into `output`, writing `QK_K` floats per block as `d * qs[i]`.
///
/// Elements of `output` beyond `blocks.len() * QK_K` are left untouched.
///
/// # Errors
///
/// Returns `BonsaiError::KQuantError` when `output` is shorter than `blocks.len() * QK_K`.
pub fn dequant(blocks: &[Self], output: &mut [f32]) -> BonsaiResult<()> {
    let needed = blocks.len() * QK_K;
    if output.len() < needed {
        return Err(BonsaiError::KQuantError {
            reason: format!(
                "Q8_K dequant: output len {} < expected {}",
                output.len(),
                needed
            ),
        });
    }
    for (block, out_block) in blocks.iter().zip(output.chunks_exact_mut(QK_K)) {
        let d = block.d;
        for (slot, &code) in out_block.iter_mut().zip(block.qs.iter()) {
            *slot = d * f32::from(code);
        }
    }
    Ok(())
}
/// Appends the dequantized contents of `blocks_for_row` to the end of `buf`.
///
/// `buf` grows by `blocks_for_row.len() * QK_K` elements; existing contents are kept.
pub fn dequant_row_to_buf(blocks_for_row: &[Self], buf: &mut Vec<f32>) {
    let old_len = buf.len();
    buf.resize(old_len + blocks_for_row.len() * QK_K, 0.0f32);
    // The tail was just sized to exactly what `dequant` requires, so its length check
    // cannot trip; the result is discarded on purpose.
    let tail = &mut buf[old_len..];
    let _ = Self::dequant(blocks_for_row, tail);
}
pub fn quantize(input: &[f32]) -> BonsaiResult<Vec<Self>> {
if input.len() % QK_K != 0 {
return Err(BonsaiError::KQuantError {
reason: format!(
"Q8_K quantize: input len {} not a multiple of {}",
input.len(),
QK_K
),
});
}
let num_blocks = input.len() / QK_K;
let mut blocks = Vec::with_capacity(num_blocks);
for block_idx in 0..num_blocks {
let chunk = &input[block_idx * QK_K..block_idx * QK_K + QK_K];
let max_abs = chunk.iter().map(|&v| v.abs()).fold(0.0f32, f32::max);
let d = if max_abs > 0.0 { max_abs / 127.0 } else { 0.0 };
let inv_d = if d > 0.0 { 1.0 / d } else { 0.0 };
let mut qs = [0i8; 256];
for (i, &w) in chunk.iter().enumerate() {
qs[i] = (w * inv_d).round().clamp(-127.0, 127.0) as i8;
}
let mut bsums = [0i16; 16];
for (group, slot) in bsums.iter_mut().enumerate() {
let group_start = group * 16;
let sum: i32 = qs[group_start..group_start + 16]
.iter()
.map(|&q| q as i32)
.sum();
*slot = sum.clamp(i16::MIN as i32, i16::MAX as i32) as i16;
}
blocks.push(BlockQ8K { d, qs, bsums });
}
Ok(blocks)
}
/// Views `data` as a slice of `BlockQ8K` without copying.
///
/// An empty `data` yields an empty slice regardless of its address.
///
/// # Errors
///
/// Returns `BonsaiError::KQuantError` when `data.len()` is not a multiple of
/// `BLOCK_Q8K_BYTES`, or when non-empty `data` does not start on an
/// `align_of::<Self>()` boundary.
pub fn slice_from_bytes(data: &[u8]) -> BonsaiResult<&[Self]> {
    if data.len() % BLOCK_Q8K_BYTES != 0 {
        return Err(BonsaiError::KQuantError {
            reason: format!(
                "Q8_K slice_from_bytes: byte len {} not a multiple of {}",
                data.len(),
                BLOCK_Q8K_BYTES
            ),
        });
    }
    if data.is_empty() {
        return Ok(&[]);
    }
    let align = std::mem::align_of::<Self>();
    // Test the address itself. `align_offset` is documented as permitted to return
    // `usize::MAX` for any input, so using it here could reject correctly aligned data.
    if (data.as_ptr() as usize) % align != 0 {
        return Err(BonsaiError::KQuantError {
            reason: format!("Q8_K slice_from_bytes: pointer not {}-byte aligned", align),
        });
    }
    let count = data.len() / BLOCK_Q8K_BYTES;
    // SAFETY: the pointer comes from a live slice (non-null), was just checked to be
    // aligned for `Self`, and spans `count * size_of::<Self>()` bytes because
    // `size_of::<Self>() == BLOCK_Q8K_BYTES` is asserted at compile time. The returned
    // slice borrows `data`, so it cannot outlive the bytes. Every field is `f32`, `i8`
    // or `i16`, all of which are valid for any bit pattern.
    Ok(unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<Self>(), count) })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a `len`-byte window of `buf` that starts at an address divisible by `align`.
    ///
    /// `Vec<u8>` only guarantees 1-byte alignment, so tests that expect `slice_from_bytes`
    /// to succeed over-allocate by `align` bytes and carve the block out of this window.
    /// Panics if `buf` is too short to contain the window.
    fn aligned_window(buf: &[u8], align: usize, len: usize) -> &[u8] {
        let addr = buf.as_ptr() as usize;
        let skip = (align - addr % align) % align;
        &buf[skip..skip + len]
    }

    // ---- Q2_K ----

    #[test]
    fn q2k_block_size_correct() {
        assert_eq!(std::mem::size_of::<BlockQ2K>(), BLOCK_Q2_K_BYTES);
        assert_eq!(BLOCK_Q2_K_BYTES, 84);
    }

    #[test]
    fn q2k_roundtrip_zero_weights() {
        let blocks = BlockQ2K::quantize(&[0.0f32; 256]).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ2K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            assert!(
                v.abs() < 1e-4,
                "all-zero input should dequant to near-zero, got {v}"
            );
        }
    }

    #[test]
    fn q2k_roundtrip_uniform() {
        let input = vec![1.0f32; 256];
        let blocks = BlockQ2K::quantize(&input).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ2K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            let err = (v - 1.0).abs();
            assert!(err < 0.2, "uniform round-trip error {err} too high");
        }
    }

    #[test]
    fn q2k_quantize_output_length() {
        let input = vec![0.5f32; 256];
        let blocks = BlockQ2K::quantize(&input).expect("quantize ok");
        assert_eq!(blocks.len(), 1);
    }

    #[test]
    fn q2k_slice_from_bytes_empty() {
        let data: Vec<u8> = vec![];
        let result = BlockQ2K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q2k_slice_from_bytes_bad_length() {
        // One byte short of a full block.
        let data = vec![0u8; 83];
        assert!(BlockQ2K::slice_from_bytes(&data).is_err());
    }

    // ---- Q4_K ----

    #[test]
    fn q4k_block_size_correct() {
        assert_eq!(std::mem::size_of::<BlockQ4K>(), BLOCK_Q4_K_BYTES);
        assert_eq!(BLOCK_Q4_K_BYTES, 144);
    }

    #[test]
    fn q4k_scale_encode_decode_roundtrip() {
        let sc = [1, 2, 3, 4, 5, 63, 32, 0];
        let mn = [10, 20, 30, 40, 50, 60, 15, 7];
        let encoded = encode_q4k_scales(&sc, &mn);
        let (sc2, mn2) = decode_q4k_scales(&encoded);
        assert_eq!(sc, sc2);
        assert_eq!(mn, mn2);
    }

    #[test]
    fn q4k_scale_encode_decode_all_zeros() {
        let sc = [0u8; 8];
        let mn = [0u8; 8];
        let encoded = encode_q4k_scales(&sc, &mn);
        let (sc2, mn2) = decode_q4k_scales(&encoded);
        assert_eq!(sc, sc2);
        assert_eq!(mn, mn2);
    }

    #[test]
    fn q4k_scale_encode_decode_max_values() {
        let sc = [63u8; 8];
        let mn = [63u8; 8];
        let encoded = encode_q4k_scales(&sc, &mn);
        let (sc2, mn2) = decode_q4k_scales(&encoded);
        assert_eq!(sc, sc2);
        assert_eq!(mn, mn2);
    }

    #[test]
    fn q4k_slice_from_bytes_empty() {
        let data: Vec<u8> = vec![];
        let result = BlockQ4K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q4k_slice_from_bytes_bad_length() {
        let data = vec![0u8; 100];
        assert!(BlockQ4K::slice_from_bytes(&data).is_err());
    }

    // ---- Q3_K ----

    #[test]
    fn q3k_block_size_assertion() {
        assert_eq!(std::mem::size_of::<BlockQ3K>(), BLOCK_Q3K_BYTES);
        assert_eq!(BLOCK_Q3K_BYTES, 110);
    }

    #[test]
    fn q3k_roundtrip_zero_weights() {
        let blocks = BlockQ3K::quantize(&[0.0f32; 256]).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ3K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            assert!(
                v.abs() < 1e-4,
                "all-zero input should dequant to near-zero, got {v}"
            );
        }
    }

    #[test]
    fn q3k_roundtrip_uniform() {
        let input = vec![1.0f32; 256];
        let blocks = BlockQ3K::quantize(&input).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ3K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            let err = (v - 1.0).abs() / 1.0;
            assert!(
                err < 0.5,
                "uniform round-trip rel error {err} too high, got {v}"
            );
        }
    }

    #[test]
    fn q3k_slice_from_bytes() {
        // Carve an explicitly aligned window instead of relying on where the allocator
        // happens to place a `Vec<u8>`.
        let align = std::mem::align_of::<BlockQ3K>();
        let backing = vec![0u8; BLOCK_Q3K_BYTES + align];
        let data = aligned_window(&backing, align, BLOCK_Q3K_BYTES);
        let result = BlockQ3K::slice_from_bytes(data).expect("single block should parse");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn q3k_slice_from_bytes_empty() {
        let data: Vec<u8> = vec![];
        let result = BlockQ3K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q3k_slice_from_bytes_bad_length() {
        let data = vec![0u8; 100];
        assert!(BlockQ3K::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn q3k_quantize_output_length() {
        let input = vec![0.5f32; 256];
        let blocks = BlockQ3K::quantize(&input).expect("quantize ok");
        assert_eq!(blocks.len(), 1, "256 weights → 1 block");
    }

    #[test]
    fn q3k_quantize_non_multiple_errors() {
        assert!(BlockQ3K::quantize(&[1.0f32; 100]).is_err());
    }

    #[test]
    fn q3k_dequant_output_too_small_errors() {
        let blocks = BlockQ3K::quantize(&[1.0f32; 256]).expect("quantize ok");
        let mut out = vec![0.0f32; 100];
        assert!(BlockQ3K::dequant(&blocks, &mut out).is_err());
    }

    #[test]
    fn q3k_dequant_row_to_buf_works() {
        let input = vec![0.5f32; 256];
        let blocks = BlockQ3K::quantize(&input).expect("quantize ok");
        let mut buf = Vec::new();
        BlockQ3K::dequant_row_to_buf(&blocks, &mut buf);
        assert_eq!(buf.len(), 256);
    }

    // ---- Q8_K ----

    #[test]
    fn q8k_block_size_assertion() {
        assert_eq!(std::mem::size_of::<BlockQ8K>(), BLOCK_Q8K_BYTES);
        assert_eq!(BLOCK_Q8K_BYTES, 292);
    }

    #[test]
    fn q8k_roundtrip_zero_weights() {
        let blocks = BlockQ8K::quantize(&[0.0f32; 256]).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ8K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            assert!(
                v.abs() < 1e-6,
                "all-zero input should dequant to exactly zero, got {v}"
            );
        }
    }

    #[test]
    fn q8k_roundtrip_uniform() {
        let input = vec![1.0f32; 256];
        let blocks = BlockQ8K::quantize(&input).expect("quantize ok");
        let mut out = vec![0.0f32; 256];
        BlockQ8K::dequant(&blocks, &mut out).expect("dequant ok");
        for &v in &out {
            let err = (v - 1.0).abs();
            assert!(err < 0.02, "Q8_K uniform round-trip error {err} too high");
        }
    }

    #[test]
    fn q8k_slice_from_bytes() {
        // Carve an explicitly aligned window instead of relying on where the allocator
        // happens to place a `Vec<u8>`.
        let align = std::mem::align_of::<BlockQ8K>();
        let backing = vec![0u8; BLOCK_Q8K_BYTES + align];
        let data = aligned_window(&backing, align, BLOCK_Q8K_BYTES);
        let result = BlockQ8K::slice_from_bytes(data).expect("single block should parse");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn q8k_slice_from_bytes_misaligned_errors() {
        // Start one byte past an aligned address: the length is valid but the pointer
        // is not, so the alignment branch must reject it.
        let align = std::mem::align_of::<BlockQ8K>();
        assert!(align > 1, "BlockQ8K contains an f32 and must need alignment");
        let backing = vec![0u8; BLOCK_Q8K_BYTES + align];
        let data = &aligned_window(&backing, align, BLOCK_Q8K_BYTES + 1)[1..];
        assert_eq!(data.len(), BLOCK_Q8K_BYTES);
        assert!(BlockQ8K::slice_from_bytes(data).is_err());
    }

    #[test]
    fn q8k_slice_from_bytes_empty() {
        let data: Vec<u8> = vec![];
        let result = BlockQ8K::slice_from_bytes(&data).expect("empty slice ok");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn q8k_slice_from_bytes_bad_length() {
        let data = vec![0u8; 100];
        assert!(BlockQ8K::slice_from_bytes(&data).is_err());
    }

    #[test]
    fn q8k_quantize_output_length() {
        let input = vec![0.5f32; 256];
        let blocks = BlockQ8K::quantize(&input).expect("quantize ok");
        assert_eq!(blocks.len(), 1, "256 weights → 1 block");
    }

    #[test]
    fn q8k_quantize_non_multiple_errors() {
        assert!(BlockQ8K::quantize(&[1.0f32; 100]).is_err());
    }

    #[test]
    fn q8k_dequant_output_too_small_errors() {
        let blocks = BlockQ8K::quantize(&[1.0f32; 256]).expect("quantize ok");
        let mut out = vec![0.0f32; 100];
        assert!(BlockQ8K::dequant(&blocks, &mut out).is_err());
    }

    #[test]
    fn q8k_dequant_row_to_buf_works() {
        let input = vec![0.5f32; 256];
        let blocks = BlockQ8K::quantize(&input).expect("quantize ok");
        let mut buf = Vec::new();
        BlockQ8K::dequant_row_to_buf(&blocks, &mut buf);
        assert_eq!(buf.len(), 256);
        for &v in &buf {
            assert!((v - 0.5).abs() < 0.01, "expected ~0.5, got {v}");
        }
    }

    #[test]
    fn q8k_bsums_roundtrip_sign() {
        let input_pos = vec![0.5f32; 256];
        let blocks_pos = BlockQ8K::quantize(&input_pos).expect("quantize ok");
        for &bs in &blocks_pos[0].bsums {
            assert!(
                bs > 0,
                "positive input should yield positive bsums, got {bs}"
            );
        }
        let input_neg = vec![-0.5f32; 256];
        let blocks_neg = BlockQ8K::quantize(&input_neg).expect("quantize ok");
        for &bs in &blocks_neg[0].bsums {
            assert!(
                bs < 0,
                "negative input should yield negative bsums, got {bs}"
            );
        }
    }
}