use crate::error::AnamnesisError;
/// Checks that a 4-bit `decode` → `encode` pair reproduces packed weights exactly.
///
/// The harness builds sixteen packed bytes in which byte `i` carries nibble
/// `i` in both its high and low half (`0x00, 0x11, …, 0xFF`) and serialises
/// `codebook` as sixteen consecutive little-endian `f32` values. For each
/// entry of `scales`, in order, it then calls
///
/// 1. `decode(packed, scale_le_bytes, codebook_bytes, 32, block_size)`, and
/// 2. `encode(decoded, scale_le_bytes, codebook_bytes, 32, block_size)`,
///
/// and requires the bytes returned by `encode` to equal the packed input.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] when `scales` is empty or when
/// `block_size` is not 32 (checked in that order), and forwards any error
/// returned by `decode` or `encode`. A failing `decode` short-circuits the
/// iteration: `encode` is not invoked for that scale.
///
/// # Panics
///
/// Panics (via `assert_eq!`) when `encode` returns a different number of
/// bytes than the packed input, or when any returned byte differs from it.
pub fn assert_bnb4_decode_encode_round_trip<D, E>(
    codebook: &[f32; 16],
    scales: &[f32],
    block_size: usize,
    decode: D,
    encode: E,
) -> crate::Result<()>
where
    D: Fn(&[u8], &[u8], &[u8], usize, usize) -> crate::Result<Vec<u8>>,
    E: Fn(&[u8], &[u8], &[u8], usize, usize) -> crate::Result<Vec<u8>>,
{
    // One block of 32 elements, i.e. sixteen bytes of two nibbles each.
    const ELEMENT_COUNT: usize = 32;

    if scales.is_empty() {
        return Err(AnamnesisError::Parse {
            reason: "round-trip harness needs at least one scale".into(),
        });
    }
    if block_size != 32 {
        return Err(AnamnesisError::Parse {
            reason: format!("round-trip harness requires block_size == 32 (got {block_size})"),
        });
    }

    // Byte `n` is `n` repeated in both nibbles, so every codebook index
    // appears once in the low position and once in the high position.
    let mut packed = [0u8; 16];
    for (slot, nibble) in packed.iter_mut().zip(0u8..16) {
        *slot = (nibble << 4) | nibble;
    }

    // Codebook laid out as sixteen little-endian f32 values (64 bytes).
    let mut codebook_bytes = [0u8; 64];
    for (dst, entry) in codebook_bytes.chunks_exact_mut(4).zip(codebook.iter()) {
        dst.copy_from_slice(&entry.to_le_bytes());
    }

    for (scale_idx, scale) in scales.iter().copied().enumerate() {
        let absmax_bytes = scale.to_le_bytes();
        let recovered = decode(
            &packed,
            &absmax_bytes,
            &codebook_bytes,
            ELEMENT_COUNT,
            block_size,
        )
        .and_then(|decoded| {
            encode(
                &decoded,
                &absmax_bytes,
                &codebook_bytes,
                ELEMENT_COUNT,
                block_size,
            )
        })?;

        assert_eq!(
            recovered.len(),
            packed.len(),
            "scale[{scale_idx}] = {scale}: round-trip byte count mismatch (expected {}, got {})",
            packed.len(),
            recovered.len(),
        );
        for (byte_idx, (orig, back)) in packed
            .iter()
            .copied()
            .zip(recovered.iter().copied())
            .enumerate()
        {
            assert_eq!(
                back, orig,
                "scale[{scale_idx}] = {scale}, byte {byte_idx}: round-trip mismatch (expected 0x{orig:02X}, got 0x{back:02X})",
            );
        }
    }
    Ok(())
}
/// Checks that an 8-bit `decode` → `encode` pair reproduces every byte value.
///
/// The harness feeds `decode` a 256-byte buffer holding each value
/// `0x00..=0xFF` exactly once, together with the little-endian bytes of the
/// `f32` value `127.0` and the two size arguments `1` and `256`. The decoded
/// output is handed to `encode` with the same scale bytes and sizes, and the
/// bytes `encode` returns must equal the original buffer.
///
/// # Errors
///
/// Forwards any error returned by `decode` or `encode`. When `decode` fails,
/// `encode` is not invoked.
///
/// # Panics
///
/// Panics (via `assert_eq!`) when `encode` returns a buffer whose length is
/// not 256, or when any returned byte differs from the input byte at the
/// same position.
pub fn assert_bnb_int8_decode_encode_round_trip<D, E>(decode: D, encode: E) -> crate::Result<()>
where
    D: Fn(&[u8], &[u8], usize, usize) -> crate::Result<Vec<u8>>,
    E: Fn(&[u8], &[u8], usize, usize) -> crate::Result<Vec<u8>>,
{
    // Every possible byte value, in ascending order.
    let mut every_byte = [0u8; 256];
    for (slot, value) in every_byte.iter_mut().zip(0u8..=u8::MAX) {
        *slot = value;
    }

    let scb = 127.0_f32.to_le_bytes();
    let recovered =
        decode(&every_byte, &scb, 1, 256).and_then(|decoded| encode(&decoded, &scb, 1, 256))?;

    assert_eq!(
        recovered.len(),
        every_byte.len(),
        "INT8 round-trip byte count mismatch (expected {}, got {})",
        every_byte.len(),
        recovered.len(),
    );
    for (byte_idx, (orig, back)) in every_byte
        .iter()
        .copied()
        .zip(recovered.iter().copied())
        .enumerate()
    {
        assert_eq!(
            back, orig,
            "INT8 byte {byte_idx}: round-trip mismatch (expected 0x{orig:02X}, got 0x{back:02X})",
        );
    }
    Ok(())
}
#[cfg(test)]
#[allow(
    clippy::panic,
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::as_conversions,
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::float_cmp,
    // The toy codecs must keep the Result<Vec<u8>> return type to satisfy
    // the harness closure signatures; they never error.
    clippy::unnecessary_wraps
)]
mod tests {
    //! Self-tests for the round-trip harnesses, driven by small reference
    //! codecs defined in this module. They cover both the accepting path
    //! (a faithful codec passes) and the rejecting path (bad arguments
    //! return `Err`, a lossy codec makes the harness panic).

    use super::*;

    /// Reference 4-bit decoder used to exercise the harness.
    ///
    /// Reads the codebook as sixteen little-endian `f32`s and one
    /// little-endian `f32` scale per block from `absmax`. Each weight byte
    /// yields two elements — low nibble first, then high nibble — computed as
    /// `codebook[nibble] * scale` and stored as the upper 16 bits of the
    /// `f32` bit pattern (truncation), little-endian.
    fn toy_decode(
        weight: &[u8],
        absmax: &[u8],
        codebook: &[u8],
        total_elements: usize,
        block_size: usize,
    ) -> crate::Result<Vec<u8>> {
        assert_eq!(weight.len() * 2, total_elements);
        let mut cb = [0.0f32; 16];
        for (i, slot) in cb.iter_mut().enumerate() {
            let arr: [u8; 4] = codebook[i * 4..i * 4 + 4].try_into().unwrap();
            *slot = f32::from_le_bytes(arr);
        }
        // Two output bytes per decoded element.
        let mut out = vec![0u8; total_elements * 2];
        let num_blocks = total_elements / block_size;
        for block_idx in 0..num_blocks {
            let arr: [u8; 4] = absmax[block_idx * 4..block_idx * 4 + 4].try_into().unwrap();
            let scale = f32::from_le_bytes(arr);
            let w_start = block_idx * (block_size / 2);
            let w_end = w_start + block_size / 2;
            for (offset, &byte) in weight[w_start..w_end].iter().enumerate() {
                let low = (byte & 0x0F) as usize;
                let high = (byte >> 4) as usize;
                let val_low = cb[low] * scale;
                let val_high = cb[high] * scale;
                // Keep only the upper half of the f32 bit pattern.
                let bf16_low = (val_low.to_bits() >> 16) as u16;
                let bf16_high = (val_high.to_bits() >> 16) as u16;
                // Byte offset of the first of the two elements from `byte`.
                let o = (block_idx * block_size + offset * 2) * 2;
                out[o..o + 2].copy_from_slice(&bf16_low.to_le_bytes());
                out[o + 2..o + 4].copy_from_slice(&bf16_high.to_le_bytes());
            }
        }
        Ok(out)
    }

    /// Reference 4-bit encoder, the inverse of [`toy_decode`].
    ///
    /// Widens each 16-bit element back to `f32`, divides by the block's
    /// scale, maps the quotient to a codebook index with [`nearest_idx`],
    /// and packs each element pair into one byte (first element in the low
    /// nibble, second in the high nibble).
    fn toy_encode(
        bf16: &[u8],
        absmax: &[u8],
        codebook: &[u8],
        total_elements: usize,
        block_size: usize,
    ) -> crate::Result<Vec<u8>> {
        assert_eq!(bf16.len(), total_elements * 2);
        let mut cb = [0.0f32; 16];
        for (i, slot) in cb.iter_mut().enumerate() {
            let arr: [u8; 4] = codebook[i * 4..i * 4 + 4].try_into().unwrap();
            *slot = f32::from_le_bytes(arr);
        }
        let mut out = vec![0u8; total_elements / 2];
        let num_blocks = total_elements / block_size;
        for block_idx in 0..num_blocks {
            let arr: [u8; 4] = absmax[block_idx * 4..block_idx * 4 + 4].try_into().unwrap();
            let scale = f32::from_le_bytes(arr);
            for pair_idx in 0..block_size / 2 {
                // Index of the first element of this pair.
                let e0 = block_idx * block_size + pair_idx * 2;
                let arr0: [u8; 2] = bf16[e0 * 2..e0 * 2 + 2].try_into().unwrap();
                let arr1: [u8; 2] = bf16[(e0 + 1) * 2..(e0 + 1) * 2 + 2].try_into().unwrap();
                let f0 = f32::from_bits(u32::from(u16::from_le_bytes(arr0)) << 16) / scale;
                let f1 = f32::from_bits(u32::from(u16::from_le_bytes(arr1)) << 16) / scale;
                let n0 = nearest_idx(f0, &cb);
                let n1 = nearest_idx(f1, &cb);
                let o = block_idx * (block_size / 2) + pair_idx;
                out[o] = (n1 << 4) | n0;
            }
        }
        Ok(out)
    }

    /// Returns the index of the codebook entry that best matches `v`.
    ///
    /// A bit-for-bit identical entry always beats a merely close one, which
    /// is what lets `+0.0` and `-0.0` (distance zero from each other) map to
    /// distinct indices. Among candidates of the same kind, the first entry
    /// with the strictly smallest absolute distance wins.
    fn nearest_idx(v: f32, cb: &[f32; 16]) -> u8 {
        let val_bits = v.to_bits();
        let mut best = 0u8;
        let mut best_d = f32::INFINITY;
        let mut best_exact = false;
        for (i, &c) in cb.iter().enumerate() {
            let exact = c.to_bits() == val_bits;
            let d = (v - c).abs();
            let take = if exact && !best_exact {
                true
            } else if !exact && best_exact {
                false
            } else {
                d < best_d
            };
            if take {
                best = i as u8;
                best_d = d;
                best_exact = exact;
            }
        }
        best
    }

    /// 4-bit encoder that ignores its input and returns all-zero bytes of
    /// the expected length; used to prove the harness notices data loss.
    fn zeroing_encode(
        _bf16: &[u8],
        _absmax: &[u8],
        _codebook: &[u8],
        total_elements: usize,
        _block_size: usize,
    ) -> crate::Result<Vec<u8>> {
        Ok(vec![0u8; total_elements / 2])
    }

    /// Reference 8-bit decoder used to exercise the INT8 harness.
    ///
    /// This toy treats the two size arguments as a row count and a row
    /// length, with one little-endian `f32` scale per row in `scb`. Each
    /// weight byte is reinterpreted as `i8`, scaled by `scale / 127`, and
    /// stored as the upper 16 bits of the `f32` bit pattern, little-endian.
    fn toy_int8_decode(
        weight: &[u8],
        scb: &[u8],
        rows: usize,
        cols: usize,
    ) -> crate::Result<Vec<u8>> {
        assert_eq!(weight.len(), rows * cols);
        let mut out = Vec::with_capacity(weight.len() * 2);
        for (row_idx, row) in weight.chunks_exact(cols).enumerate() {
            let arr: [u8; 4] = scb[row_idx * 4..row_idx * 4 + 4].try_into().unwrap();
            let scale = f32::from_le_bytes(arr);
            for &byte in row {
                // Bit-level reinterpretation of the byte as a signed value.
                let signed = i8::from_le_bytes([byte]);
                let val = f32::from(signed) * scale / 127.0;
                let bf16 = (val.to_bits() >> 16) as u16;
                out.extend_from_slice(&bf16.to_le_bytes());
            }
        }
        Ok(out)
    }

    /// Reference 8-bit encoder, the inverse of [`toy_int8_decode`].
    ///
    /// Widens each 16-bit element to `f32`, rescales by `127 / scale`,
    /// rounds to the nearest integer, and emits the resulting `i8` as a byte.
    fn toy_int8_encode(
        bf16: &[u8],
        scb: &[u8],
        rows: usize,
        cols: usize,
    ) -> crate::Result<Vec<u8>> {
        assert_eq!(bf16.len(), rows * cols * 2);
        let mut out = Vec::with_capacity(rows * cols);
        for (row_idx, row) in bf16.chunks_exact(cols * 2).enumerate() {
            let arr: [u8; 4] = scb[row_idx * 4..row_idx * 4 + 4].try_into().unwrap();
            let scale = f32::from_le_bytes(arr);
            for pair in row.chunks_exact(2) {
                let raw = u16::from_le_bytes([pair[0], pair[1]]);
                let val = f32::from_bits(u32::from(raw) << 16);
                let quantized = (val * 127.0 / scale).round() as i8;
                out.push(quantized.to_le_bytes()[0]);
            }
        }
        Ok(out)
    }

    /// 8-bit encoder that returns no bytes at all; used to prove the INT8
    /// harness checks the output length.
    fn empty_int8_encode(
        _bf16: &[u8],
        _scb: &[u8],
        _rows: usize,
        _cols: usize,
    ) -> crate::Result<Vec<u8>> {
        Ok(Vec::new())
    }

    /// A faithful codec over an evenly spaced codebook passes for several scales.
    #[test]
    fn harness_round_trips_linear_codebook() {
        let mut cb = [0.0f32; 16];
        for (i, slot) in cb.iter_mut().enumerate() {
            *slot = (i as f32 - 7.5) * 0.1;
        }
        let scales = [1.0, 2.0, 0.5, 0.123];
        assert_bnb4_decode_encode_round_trip(&cb, &scales, 32, toy_decode, toy_encode).unwrap();
    }

    /// A codebook containing both `+0.0` and `-0.0` still round-trips,
    /// because the toy encoder prefers bit-exact matches.
    #[test]
    fn harness_round_trips_signed_zero_codebook() {
        let mut cb = [0.0f32; 16];
        cb[0] = 0.0;
        cb[8] = -0.0;
        for (i, slot) in cb.iter_mut().enumerate() {
            if i != 0 && i != 8 {
                *slot = (i as f32 - 7.5) * 0.1;
            }
        }
        let scales = [1.0, 2.5];
        assert_bnb4_decode_encode_round_trip(&cb, &scales, 32, toy_decode, toy_encode).unwrap();
    }

    /// Any block size other than 32 is reported as an error, not a panic.
    #[test]
    fn harness_rejects_bad_block_size() {
        let cb = [0.0f32; 16];
        let scales = [1.0];
        assert!(
            assert_bnb4_decode_encode_round_trip(&cb, &scales, 16, toy_decode, toy_encode).is_err()
        );
    }

    /// An empty scale list is reported as an error, not a panic.
    #[test]
    fn harness_rejects_empty_scales() {
        let cb = [0.0f32; 16];
        let scales: [f32; 0] = [];
        assert!(
            assert_bnb4_decode_encode_round_trip(&cb, &scales, 32, toy_decode, toy_encode).is_err()
        );
    }

    /// The 4-bit harness must fail loudly when the encoder loses data.
    /// Byte 0 of the synthetic input is already 0x00, so byte 1 is the first
    /// position where the zeroing encoder disagrees.
    #[test]
    #[should_panic(expected = "byte 1: round-trip mismatch (expected 0x11, got 0x00)")]
    fn harness_panics_on_lossy_encoder() {
        let mut cb = [0.0f32; 16];
        for (i, slot) in cb.iter_mut().enumerate() {
            *slot = (i as f32 - 7.5) * 0.1;
        }
        let scales = [1.0];
        assert_bnb4_decode_encode_round_trip(&cb, &scales, 32, toy_decode, zeroing_encode)
            .unwrap();
    }

    /// A faithful 8-bit codec passes the INT8 harness for all 256 byte values.
    #[test]
    fn int8_harness_round_trips_toy_codec() {
        assert_bnb_int8_decode_encode_round_trip(toy_int8_decode, toy_int8_encode).unwrap();
    }

    /// The INT8 harness must fail loudly when the encoder returns the wrong
    /// number of bytes.
    #[test]
    #[should_panic(expected = "INT8 round-trip byte count mismatch (expected 256, got 0)")]
    fn int8_harness_panics_on_short_output() {
        assert_bnb_int8_decode_encode_round_trip(toy_int8_decode, empty_int8_encode).unwrap();
    }
}