use crate::error::AnamnesisError;
/// 16-entry NF4 lookup table: the array index is the 4-bit code and the
/// element is the normalised `f32` value that code stands for.
///
/// Entries are listed in ascending order, from `-1.0` at index 0 to `1.0` at
/// index 15, with an exact `0.0` at index 7. The table is asymmetric: seven
/// negative entries, zero, and eight positive entries.
// NOTE(review): presumably mirrors the bitsandbytes NF4 quant map — confirm against upstream.
pub const NF4_CODEBOOK: [f32; 16] = [
-1.0,
-0.696_192_8,
-0.525_073_05,
-0.394_917_5,
-0.284_441_38,
-0.184_773_43,
-0.091_050_036,
0.0,
0.079_580_3,
0.160_930_2,
0.246_112_3,
0.337_915_24,
0.440_709_83,
0.562_617,
0.722_956_84,
1.0,
];
/// 16-entry FP4 lookup table: the array index is the 4-bit code and the
/// element is the normalised `f32` value that code stands for.
///
/// The layout is sign-magnitude: indices `0..8` hold the non-negative values
/// and index `i + 8` holds the negation of index `i`. In particular index 0 is
/// `0.0` and index 8 is `-0.0`, which differ only in their sign bit.
/// Unlike [`NF4_CODEBOOK`], the entries are not sorted by value.
// NOTE(review): presumably mirrors the bitsandbytes FP4 quant map — confirm against upstream.
pub const FP4_CODEBOOK: [f32; 16] = [
0.0,
0.005_208_333_5,
0.666_666_7,
1.0,
0.333_333_34,
0.5,
0.166_666_67,
0.25,
-0.0,
-0.005_208_333_5,
-0.666_666_7,
-1.0,
-0.333_333_34,
-0.5,
-0.166_666_67,
-0.25,
];
/// Reads a little-endian `f32` from `data`, starting at byte `offset`.
///
/// Returns `None` when fewer than four bytes are available at `offset`.
/// This includes the case where `offset + 4` would overflow `usize`, which is
/// reported as `None` rather than panicking in overflow-checked builds.
fn read_f32_le(data: &[u8], offset: usize) -> Option<f32> {
    // `checked_add` keeps a hostile or corrupted offset from overflowing.
    let end = offset.checked_add(4)?;
    let raw: [u8; 4] = data.get(offset..end)?.try_into().ok()?;
    Some(f32::from_le_bytes(raw))
}
/// Expands a raw BF16 bit pattern into an `f32`.
///
/// The 16 input bits become the upper half of the 32-bit word and the lower
/// half is zero-filled, so the conversion is exact.
#[inline]
#[must_use]
fn bf16_bits_to_f32(bits: u16) -> f32 {
    let widened = u32::from(bits);
    f32::from_bits(widened << 16)
}
/// Chooses between a lower-half code and its sign-magnitude twin.
///
/// When `value` has its sign bit set and `nibble` lies in `0..8`, the entry
/// at `nibble + 8` is compared with the entry at `nibble`. If both entries
/// have identical bit patterns the upper index is returned; in every other
/// case `nibble` is returned unchanged.
#[inline]
#[must_use]
fn apply_sign_magnitude_encode_correction(value: f32, nibble: u8, codebook: &[f32; 16]) -> u8 {
    // Only negative-signed inputs that landed in the lower half are candidates.
    if nibble >= 8 || value.is_sign_positive() {
        return nibble;
    }
    let mirrored = nibble + 8;
    #[allow(clippy::indexing_slicing)]
    let lower_bits = codebook[usize::from(nibble)].to_bits();
    #[allow(clippy::indexing_slicing)]
    let mirrored_bits = codebook[usize::from(mirrored)].to_bits();
    if mirrored_bits == lower_bits {
        mirrored
    } else {
        nibble
    }
}
/// Finds the 4-bit code whose codebook entry best matches `value`.
///
/// Selection rules, in priority order:
/// 1. An entry whose bit pattern equals that of `value` beats any entry that
///    is not bit-exact (this is what tells `0.0` apart from `-0.0`).
/// 2. Otherwise the entry with the strictly smallest absolute distance wins,
///    so ties keep the lowest index.
///
/// When no candidate ever compares as better (for example `value` is NaN or
/// infinite), index 0 is returned.
#[inline]
#[must_use]
fn nearest_codebook_index(value: f32, codebook: &[f32; 16]) -> u8 {
    let target_bits = value.to_bits();
    // (index, distance, bit-exact?) of the best candidate seen so far.
    let mut winner: (u8, f32, bool) = (0, f32::INFINITY, false);
    for (code, &candidate) in (0u8..).zip(codebook.iter()) {
        let is_exact = candidate.to_bits() == target_bits;
        let distance = (value - candidate).abs();
        let replaces = match (is_exact, winner.2) {
            (true, false) => true,
            (false, true) => false,
            _ => distance < winner.1,
        };
        if replaces {
            winner = (code, distance, is_exact);
        }
    }
    winner.0
}
/// Decodes a 4-bit quantisation map from raw bytes.
///
/// `quant_map_data` must hold exactly 16 little-endian `f32` values
/// (64 bytes); entry `i` of the result is the value stored at byte `4 * i`.
///
/// # Errors
/// Returns [`AnamnesisError::Parse`] when the buffer is not exactly 64 bytes
/// long, or when an entry cannot be read.
fn parse_codebook(quant_map_data: &[u8]) -> crate::Result<[f32; 16]> {
    let byte_len = quant_map_data.len();
    if byte_len != 64 {
        return Err(AnamnesisError::Parse {
            reason: format!("BnB4 quant_map must be 64 bytes (16xF32), got {byte_len}"),
        });
    }

    let mut entries = [0.0f32; 16];
    for (idx, entry) in entries.iter_mut().enumerate() {
        match read_f32_le(quant_map_data, idx * 4) {
            Some(value) => *entry = value,
            None => {
                return Err(AnamnesisError::Parse {
                    reason: "BnB4 quant_map read out of bounds".into(),
                });
            }
        }
    }
    Ok(entries)
}
/// Decodes the per-block absmax scales from raw bytes.
///
/// `absmax_data` must hold exactly `num_blocks` little-endian `f32` values;
/// element `i` of the result is the value stored at byte `4 * i`.
///
/// # Errors
/// Returns [`AnamnesisError::Parse`] when `num_blocks * 4` overflows, when
/// the buffer length differs from that product, or when an entry cannot be
/// read.
fn parse_absmax(absmax_data: &[u8], num_blocks: usize) -> crate::Result<Vec<f32>> {
    let Some(required_len) = num_blocks.checked_mul(4) else {
        return Err(AnamnesisError::Parse {
            reason: "absmax byte count overflow".into(),
        });
    };
    let actual_len = absmax_data.len();
    if actual_len != required_len {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 absmax byte count mismatch: expected {required_len}, got {actual_len}"
            ),
        });
    }

    (0..num_blocks)
        .map(|block| {
            read_f32_le(absmax_data, block * 4).ok_or_else(|| AnamnesisError::Parse {
                reason: format!("BnB4 absmax read out of bounds at block {block}"),
            })
        })
        .collect()
}
/// Quantises BF16 values to packed 4-bit codes, two codes per output byte.
///
/// For every entry of `absmax`, the matching block of `block_size`
/// consecutive BF16 elements (little-endian, two bytes each) is processed:
/// each element is divided by the block's absmax (a block whose absmax is
/// exactly `0.0` is treated as all-zero), mapped to a code with
/// [`nearest_codebook_index`] and [`apply_sign_magnitude_encode_correction`],
/// and packed so that the first element of each pair occupies the low nibble
/// and the second the high nibble.
///
/// The returned buffer has `total_elements / 2` bytes. Callers are expected to
/// pass one `absmax` entry per block; bytes belonging to blocks without an
/// entry stay zero.
///
/// # Errors
/// Returns [`AnamnesisError::Parse`] when:
/// * `block_size` is zero or odd (an odd block cannot be packed into whole
///   bytes and would otherwise lose its last element silently),
/// * `total_elements` is odd,
/// * a block's byte range overflows `usize` or lies outside `bf16_data` or
///   the output buffer.
fn encode_bnb4_core(
    bf16_data: &[u8],
    absmax: &[f32],
    codebook: &[f32; 16],
    total_elements: usize,
    block_size: usize,
) -> crate::Result<Vec<u8>> {
    if block_size == 0 || !block_size.is_multiple_of(2) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 encode block_size ({block_size}) must be a non-zero even number \
                 (two nibbles per byte)"
            ),
        });
    }
    if !total_elements.is_multiple_of(2) {
        return Err(AnamnesisError::Parse {
            reason: "BnB4 encode total_elements must be even".into(),
        });
    }

    let mut output = vec![0u8; total_elements / 2];
    let bytes_per_block = block_size / 2;

    for (block_idx, &block_absmax) in absmax.iter().enumerate() {
        // Source range: `block_size` BF16 values, two bytes each.
        let bf16_byte_start = block_idx
            .checked_mul(block_size)
            .and_then(|n| n.checked_mul(2))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("BnB4 encode bf16 start overflow at block {block_idx}"),
            })?;
        let bf16_byte_end = block_size
            .checked_mul(2)
            .and_then(|len| bf16_byte_start.checked_add(len))
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("BnB4 encode bf16 end overflow at block {block_idx}"),
            })?;
        let bf16_block = bf16_data
            .get(bf16_byte_start..bf16_byte_end)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: format!("BnB4 encode bf16 block {block_idx} out of bounds"),
            })?;

        // Destination range: one byte per pair of elements.
        let out_of_bounds = || AnamnesisError::Parse {
            reason: format!("BnB4 encode output block {block_idx} out of bounds"),
        };
        let o_start = block_idx
            .checked_mul(bytes_per_block)
            .ok_or_else(|| out_of_bounds())?;
        let o_end = o_start
            .checked_add(bytes_per_block)
            .ok_or_else(|| out_of_bounds())?;
        let out_block = output
            .get_mut(o_start..o_end)
            .ok_or_else(|| out_of_bounds())?;

        // Normalise by the block scale, then pick the 4-bit code.
        let quantise = |value: f32| -> u8 {
            let normalised = if block_absmax == 0.0 {
                0.0_f32
            } else {
                value / block_absmax
            };
            let raw = nearest_codebook_index(normalised, codebook);
            apply_sign_magnitude_encode_correction(normalised, raw, codebook)
        };

        // Four source bytes (two BF16 values) produce one packed output byte.
        // Decoding straight from the input avoids a scratch buffer whose size
        // would be dictated by an unvalidated `block_size`.
        for (quad, out_byte) in bf16_block.chunks_exact(4).zip(out_block.iter_mut()) {
            #[allow(clippy::indexing_slicing)]
            let low_bits = u16::from_le_bytes([quad[0], quad[1]]);
            #[allow(clippy::indexing_slicing)]
            let high_bits = u16::from_le_bytes([quad[2], quad[3]]);
            let low_nibble = quantise(bf16_bits_to_f32(low_bits));
            let high_nibble = quantise(bf16_bits_to_f32(high_bits));
            *out_byte = (high_nibble << 4) | (low_nibble & 0x0F);
        }
    }
    Ok(output)
}
/// Encodes BF16 weights as packed BnB 4-bit codes using caller-supplied
/// per-block absmax scales.
///
/// * `bf16_data` — `total_elements` little-endian BF16 values.
/// * `absmax_data` — one little-endian `f32` per block.
/// * `quant_map_data` — the 16-entry `f32` codebook (64 bytes).
/// * `total_elements` — number of weights; must be even and a multiple of
///   `block_size`.
/// * `block_size` — number of weights sharing one absmax; must be non-zero.
///
/// Returns `total_elements / 2` bytes of packed codes.
///
/// # Errors
/// Returns [`AnamnesisError::Parse`] when any of the size constraints above is
/// violated, when a byte count overflows, or when the codebook or absmax
/// buffers have the wrong length.
pub fn encode_bnb4(
    bf16_data: &[u8],
    absmax_data: &[u8],
    quant_map_data: &[u8],
    total_elements: usize,
    block_size: usize,
) -> crate::Result<Vec<u8>> {
    let parse_err = |reason: String| AnamnesisError::Parse { reason };

    if block_size == 0 {
        return Err(parse_err("BnB encode block_size must be > 0".into()));
    }
    if !total_elements.is_multiple_of(2) {
        return Err(parse_err(format!(
            "BnB4 encode total_elements ({total_elements}) must be even (two nibbles per byte)"
        )));
    }

    let Some(required_bf16_bytes) = total_elements.checked_mul(2) else {
        return Err(parse_err("BnB4 encode bf16 byte count overflow".into()));
    };
    let actual_bf16_bytes = bf16_data.len();
    if actual_bf16_bytes != required_bf16_bytes {
        return Err(parse_err(format!(
            "BnB4 encode bf16 byte count mismatch: expected {required_bf16_bytes} for \
             {total_elements} elements, got {actual_bf16_bytes}"
        )));
    }

    if !total_elements.is_multiple_of(block_size) {
        return Err(parse_err(format!(
            "BnB4 encode total_elements ({total_elements}) not divisible by \
             block_size ({block_size})"
        )));
    }

    let block_count = total_elements / block_size;
    let codebook = parse_codebook(quant_map_data)?;
    let scales = parse_absmax(absmax_data, block_count)?;
    encode_bnb4_core(bf16_data, &scales, &codebook, total_elements, block_size)
}
/// Encodes BF16 weights as packed BnB 4-bit codes, deriving the per-block
/// absmax scales from the data itself.
///
/// * `bf16_data` — `total_elements` little-endian BF16 values.
/// * `quant_map_data` — the 16-entry `f32` codebook (64 bytes).
/// * `total_elements` — number of weights; must be even and a multiple of
///   `block_size`.
/// * `block_size` — number of weights sharing one absmax; must be non-zero.
///
/// Each block's absmax is the largest absolute value in that block, starting
/// from `0.0`. NaN elements never compare greater and are therefore ignored.
///
/// Returns `(weight_bytes, absmax_bytes)`: the packed codes and the computed
/// scales serialised as little-endian `f32`, one per block.
///
/// # Errors
/// Returns [`AnamnesisError::Parse`] when `block_size` is zero, when
/// `total_elements` is odd or not a multiple of `block_size`, when
/// `bf16_data` has the wrong length, or when the codebook is malformed.
pub fn encode_bnb4_compute_absmax(
    bf16_data: &[u8],
    quant_map_data: &[u8],
    total_elements: usize,
    block_size: usize,
) -> crate::Result<(Vec<u8>, Vec<u8>)> {
    if block_size == 0 {
        return Err(AnamnesisError::Parse {
            reason: "BnB encode block_size must be > 0".into(),
        });
    }
    if !total_elements.is_multiple_of(block_size) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 encode total_elements ({total_elements}) not divisible by \
                 block_size ({block_size})"
            ),
        });
    }
    let expected_bf16_bytes =
        total_elements
            .checked_mul(2)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "BnB4 encode bf16 byte count overflow".into(),
            })?;
    if bf16_data.len() != expected_bf16_bytes {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 encode bf16 byte count mismatch: expected {expected_bf16_bytes} for \
                 {total_elements} elements, got {}",
                bf16_data.len()
            ),
        });
    }
    // Two codes share one byte, so an odd element count cannot be packed.
    // The sibling encoders reject it; without this check the output would be
    // silently truncated.
    if !total_elements.is_multiple_of(2) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 encode total_elements ({total_elements}) must be even \
                 (two nibbles per byte)"
            ),
        });
    }
    // Validate the small codebook before scanning the whole tensor.
    let codebook = parse_codebook(quant_map_data)?;

    let num_blocks = total_elements / block_size;
    let absmax = (0..num_blocks)
        .map(|block_idx| -> crate::Result<f32> {
            // In range: the end is at most `total_elements * 2`, validated above.
            let bf16_byte_start = block_idx * block_size * 2;
            let bf16_byte_end = bf16_byte_start + block_size * 2;
            let bf16_block = bf16_data
                .get(bf16_byte_start..bf16_byte_end)
                .ok_or_else(|| AnamnesisError::Parse {
                    reason: format!("BnB4 encode bf16 block {block_idx} out of bounds"),
                })?;
            let max_abs = bf16_block
                .chunks_exact(2)
                .map(|pair| {
                    #[allow(clippy::indexing_slicing)]
                    let bits = u16::from_le_bytes([pair[0], pair[1]]);
                    bf16_bits_to_f32(bits).abs()
                })
                .fold(0.0_f32, |best, v| if v > best { v } else { best });
            Ok(max_abs)
        })
        .collect::<crate::Result<Vec<f32>>>()?;

    let absmax_bytes: Vec<u8> = absmax.iter().flat_map(|v| v.to_le_bytes()).collect();
    let weight_bytes = encode_bnb4_core(bf16_data, &absmax, &codebook, total_elements, block_size)?;
    Ok((weight_bytes, absmax_bytes))
}
/// Reconstructs the per-block `f32` absmax scales of a double-quantised
/// tensor.
///
/// * `absmax_data` — one byte per block; each byte indexes the nested
///   codebook.
/// * `nested_absmax_data` — one little-endian `f32` scale per group of
///   `nested_block_size` blocks (the last group may be partial).
/// * `nested_quant_map_data` — the 256-entry nested `f32` codebook
///   (1024 bytes).
///
/// The scale of block `i` is
/// `nested_codebook[absmax_data[i]] * nested_absmax[i / nested_block_size]`.
///
/// # Errors
/// Returns [`AnamnesisError::Parse`] when `nested_block_size` is zero, when
/// the nested codebook is not 1024 bytes, when `nested_absmax_data` does not
/// hold exactly one `f32` per nested group, or when a read falls outside its
/// buffer.
fn recover_double_quant_absmax(
    absmax_data: &[u8],
    nested_absmax_data: &[u8],
    nested_quant_map_data: &[u8],
    nested_block_size: usize,
) -> crate::Result<Vec<f32>> {
    let parse_err = |reason: String| AnamnesisError::Parse { reason };

    if nested_block_size == 0 {
        return Err(parse_err("BnB nested_block_size must be > 0".into()));
    }
    let map_len = nested_quant_map_data.len();
    if map_len != 1024 {
        return Err(parse_err(format!(
            "BnB4 nested_quant_map must be 1024 bytes (256xF32), got {map_len}"
        )));
    }

    // One nested scale per group of blocks, rounding the group count up.
    let block_count = absmax_data.len();
    let nested_group_count = block_count.div_ceil(nested_block_size);
    let Some(required_nested_bytes) = nested_group_count.checked_mul(4) else {
        return Err(parse_err(
            "BnB4 encode nested absmax byte count overflow".into(),
        ));
    };
    let actual_nested_bytes = nested_absmax_data.len();
    if actual_nested_bytes != required_nested_bytes {
        return Err(parse_err(format!(
            "BnB4 encode nested_absmax byte count mismatch: expected \
             {required_nested_bytes}, got {actual_nested_bytes}"
        )));
    }

    let mut nested_codebook = [0.0_f32; 256];
    for (idx, entry) in nested_codebook.iter_mut().enumerate() {
        let Some(value) = read_f32_le(nested_quant_map_data, idx * 4) else {
            return Err(parse_err(
                "BnB4 encode nested_quant_map read out of bounds".into(),
            ));
        };
        *entry = value;
    }

    absmax_data
        .iter()
        .enumerate()
        .map(|(block, &code)| -> crate::Result<f32> {
            let nested_block_idx = block / nested_block_size;
            let nested_scale = read_f32_le(nested_absmax_data, nested_block_idx * 4)
                .ok_or_else(|| AnamnesisError::Parse {
                    reason: format!(
                        "BnB4 encode nested_absmax read out of bounds at block {nested_block_idx}"
                    ),
                })?;
            // A `u8` code is always a valid index into the 256-entry table.
            #[allow(clippy::indexing_slicing)]
            let entry = nested_codebook[usize::from(code)];
            Ok(entry * nested_scale)
        })
        .collect()
}
/// Encodes BF16 weights as packed BnB 4-bit codes for a double-quantised
/// tensor, where the per-block absmax scales are themselves stored as 8-bit
/// codes.
///
/// * `bf16_data` — `total_elements` little-endian BF16 values.
/// * `absmax_data` — one byte per block, indexing the nested codebook.
/// * `quant_map_data` — the 16-entry `f32` weight codebook (64 bytes).
/// * `nested_absmax_data` / `nested_quant_map_data` / `nested_block_size` —
///   inputs used to rebuild the `f32` scales; see
///   [`recover_double_quant_absmax`].
/// * `total_elements` — number of weights; must be even and a multiple of
///   `block_size`.
/// * `block_size` — number of weights sharing one absmax; must be non-zero.
///
/// Returns `total_elements / 2` bytes of packed codes.
///
/// # Errors
/// Returns [`AnamnesisError::Parse`] when any size constraint is violated,
/// when a byte count overflows, or when the scales or codebook cannot be
/// decoded.
#[allow(clippy::too_many_arguments)]
pub fn encode_bnb4_double_quant(
    bf16_data: &[u8],
    absmax_data: &[u8],
    quant_map_data: &[u8],
    nested_absmax_data: &[u8],
    nested_quant_map_data: &[u8],
    total_elements: usize,
    block_size: usize,
    nested_block_size: usize,
) -> crate::Result<Vec<u8>> {
    let parse_err = |reason: String| AnamnesisError::Parse { reason };

    if block_size == 0 {
        return Err(parse_err("BnB encode block_size must be > 0".into()));
    }
    if !total_elements.is_multiple_of(2) {
        return Err(parse_err(format!(
            "BnB4 encode total_elements ({total_elements}) must be even (two nibbles per byte)"
        )));
    }

    let Some(required_bf16_bytes) = total_elements.checked_mul(2) else {
        return Err(parse_err("BnB4 encode bf16 byte count overflow".into()));
    };
    let actual_bf16_bytes = bf16_data.len();
    if actual_bf16_bytes != required_bf16_bytes {
        return Err(parse_err(format!(
            "BnB4 encode bf16 byte count mismatch: expected {required_bf16_bytes} for \
             {total_elements} elements, got {actual_bf16_bytes}"
        )));
    }

    if !total_elements.is_multiple_of(block_size) {
        return Err(parse_err(format!(
            "BnB4 encode total_elements ({total_elements}) not divisible by \
             block_size ({block_size})"
        )));
    }

    // The quantised absmax stream carries exactly one byte per block.
    let block_count = total_elements / block_size;
    let absmax_len = absmax_data.len();
    if absmax_len != block_count {
        return Err(parse_err(format!(
            "BnB4 double-quant encode: absmax byte count mismatch: expected \
             {block_count} (one byte per block), got {absmax_len}"
        )));
    }

    let scales = recover_double_quant_absmax(
        absmax_data,
        nested_absmax_data,
        nested_quant_map_data,
        nested_block_size,
    )?;
    let codebook = parse_codebook(quant_map_data)?;
    encode_bnb4_core(bf16_data, &scales, &codebook, total_elements, block_size)
}
/// Encodes a row-major BF16 matrix as BnB INT8 using caller-supplied per-row
/// scales.
///
/// * `bf16_data` — `out_features * in_features` little-endian BF16 values.
/// * `scb_data` — one little-endian `f32` per row.
/// * `out_features` / `in_features` — number of rows / columns.
///
/// For each row the step size is `scb / 127`. Every element is divided by
/// that step, rounded with `f32::round`, clamped to `-128..=127`, and stored
/// as the two's-complement byte of the resulting `i8`. A row whose step is
/// exactly `0.0` encodes to all zeros.
///
/// Returns one byte per element, in the same row-major order.
///
/// # Errors
/// Returns [`AnamnesisError::Parse`] when an element or byte count overflows,
/// when `bf16_data` or `scb_data` has the wrong length, or when a row lies
/// outside its buffer.
pub fn encode_bnb_int8(
    bf16_data: &[u8],
    scb_data: &[u8],
    out_features: usize,
    in_features: usize,
) -> crate::Result<Vec<u8>> {
    let parse_err = |reason: String| AnamnesisError::Parse { reason };

    let Some(total_elements) = out_features.checked_mul(in_features) else {
        return Err(parse_err("BnB INT8 encode element count overflow".into()));
    };
    let Some(required_bf16_bytes) = total_elements.checked_mul(2) else {
        return Err(parse_err(
            "BnB INT8 encode bf16 byte count overflow".into(),
        ));
    };
    if bf16_data.len() != required_bf16_bytes {
        return Err(parse_err(format!(
            "BnB INT8 encode bf16 byte count mismatch: expected {required_bf16_bytes}, got {}",
            bf16_data.len()
        )));
    }
    let Some(required_scb_bytes) = out_features.checked_mul(4) else {
        return Err(parse_err("BnB INT8 encode SCB byte count overflow".into()));
    };
    if scb_data.len() != required_scb_bytes {
        return Err(parse_err(format!(
            "BnB INT8 encode SCB byte count mismatch: expected {required_scb_bytes}, got {}",
            scb_data.len()
        )));
    }

    let mut output = vec![0u8; total_elements];
    for row in 0..out_features {
        let Some(row_scb) = read_f32_le(scb_data, row * 4) else {
            return Err(parse_err(format!(
                "BnB INT8 encode SCB read out of bounds at row {row}"
            )));
        };
        let scale = row_scb / 127.0;

        // Source bytes for this row (two per element).
        let Some(src_start) = row.checked_mul(in_features * 2) else {
            return Err(parse_err(format!(
                "BnB INT8 encode bf16 row start overflow at row {row}"
            )));
        };
        let src_end = src_start + in_features * 2;
        let Some(src_row) = bf16_data.get(src_start..src_end) else {
            return Err(parse_err(format!(
                "BnB INT8 encode bf16 row {row} out of bounds"
            )));
        };

        // Destination bytes for this row (one per element).
        let dst_start = row * in_features;
        let dst_end = dst_start + in_features;
        let Some(dst_row) = output.get_mut(dst_start..dst_end) else {
            return Err(parse_err(format!(
                "BnB INT8 encode output row {row} out of bounds"
            )));
        };

        for (pair, dst) in src_row.chunks_exact(2).zip(dst_row.iter_mut()) {
            #[allow(clippy::indexing_slicing)]
            let value = bf16_bits_to_f32(u16::from_le_bytes([pair[0], pair[1]]));
            let ratio = if scale == 0.0 {
                0.0_f32
            } else {
                value / scale
            };
            let bounded = ratio.round().clamp(-128.0, 127.0);
            #[allow(
                clippy::as_conversions,
                clippy::cast_possible_truncation,
                clippy::cast_possible_wrap
            )]
            let quantised = bounded as i32 as i8;
            *dst = quantised.cast_unsigned();
        }
    }
    Ok(output)
}
/// Encodes a row-major BF16 matrix as BnB INT8, deriving the per-row scales
/// from the data itself.
///
/// * `bf16_data` — `out_features * in_features` little-endian BF16 values.
/// * `out_features` / `in_features` — number of rows / columns.
///
/// Each row's scale is the largest absolute value in that row, starting from
/// `0.0`. NaN elements never compare greater and are therefore ignored. The
/// weights are then produced by [`encode_bnb_int8`] with those scales.
///
/// Returns `(weight_bytes, scb_bytes)`: one byte per element, and the scales
/// serialised as little-endian `f32`, one per row.
///
/// # Errors
/// Returns [`AnamnesisError::Parse`] when an element or byte count overflows,
/// when `bf16_data` has the wrong length, or when a row lies outside the
/// buffer.
pub fn encode_bnb_int8_compute_scb(
    bf16_data: &[u8],
    out_features: usize,
    in_features: usize,
) -> crate::Result<(Vec<u8>, Vec<u8>)> {
    let parse_err = |reason: String| AnamnesisError::Parse { reason };

    let Some(total_elements) = out_features.checked_mul(in_features) else {
        return Err(parse_err("BnB INT8 encode element count overflow".into()));
    };
    let Some(required_bf16_bytes) = total_elements.checked_mul(2) else {
        return Err(parse_err(
            "BnB INT8 encode bf16 byte count overflow".into(),
        ));
    };
    if bf16_data.len() != required_bf16_bytes {
        return Err(parse_err(format!(
            "BnB INT8 encode bf16 byte count mismatch: expected {required_bf16_bytes}, got {}",
            bf16_data.len()
        )));
    }

    let mut row_scales = vec![0.0f32; out_features];
    for (row, scale_slot) in row_scales.iter_mut().enumerate() {
        let row_start = row * in_features * 2;
        let row_end = row_start + in_features * 2;
        let Some(row_bytes) = bf16_data.get(row_start..row_end) else {
            return Err(parse_err(format!(
                "BnB INT8 encode bf16 row {row} out of bounds"
            )));
        };
        *scale_slot = row_bytes
            .chunks_exact(2)
            .map(|pair| {
                #[allow(clippy::indexing_slicing)]
                let bits = u16::from_le_bytes([pair[0], pair[1]]);
                bf16_bits_to_f32(bits).abs()
            })
            .fold(0.0_f32, |best, v| if v > best { v } else { best });
    }

    let scb_bytes: Vec<u8> = row_scales
        .iter()
        .flat_map(|scale| scale.to_le_bytes())
        .collect();
    let weight_bytes = encode_bnb_int8(bf16_data, &scb_bytes, out_features, in_features)?;
    Ok((weight_bytes, scb_bytes))
}
// Unit tests for the BnB 4-bit and INT8 encoders.
//
// Several tests rely on round-trip helpers from `crate::lethe::round_trip` and
// on decoders from `crate::remember::bnb`. Those live in other files, so the
// comments below describe only what this module passes in and asserts.
#[cfg(test)]
#[allow(
clippy::panic,
clippy::indexing_slicing,
clippy::unwrap_used,
clippy::float_cmp,
clippy::as_conversions,
clippy::cast_possible_truncation,
clippy::cast_precision_loss,
clippy::cast_possible_wrap
)]
mod tests {
use super::*;
use crate::lethe::round_trip::{
assert_bnb4_decode_encode_round_trip, assert_bnb_int8_decode_encode_round_trip,
};
use crate::remember::bnb::{dequantize_bnb4_to_bf16, dequantize_bnb_int8_to_bf16};
// Serialises `f32` values as consecutive little-endian bytes.
fn f32_to_bytes(values: &[f32]) -> Vec<u8> {
values.iter().flat_map(|v| v.to_le_bytes()).collect()
}
// Converts `f32` values to little-endian BF16 bytes. The bias is 0x7FFF plus
// the lowest retained bit, i.e. round-to-nearest with ties to even, before
// the low 16 bits are dropped.
fn bf16_bytes_from_f32(values: &[f32]) -> Vec<u8> {
values
.iter()
.flat_map(|v| {
let bits = v.to_bits();
let lsb = (bits >> 16) & 1;
let rounding_bias = 0x7FFF_u32 + lsb;
let bf16 = (bits.wrapping_add(rounding_bias) >> 16) as u16;
bf16.to_le_bytes()
})
.collect()
}
// Nearest-entry search on an evenly spaced codebook: exact hits, ties (which
// keep the lower index), and values beyond either end of the table.
#[test]
fn nearest_picks_closest_entry() {
let cb = [
-0.5, -0.375, -0.25, -0.125, 0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0,
1.125, 1.25, 1.375,
];
assert_eq!(nearest_codebook_index(0.0, &cb), 4);
assert_eq!(nearest_codebook_index(0.0625, &cb), 4);
assert_eq!(nearest_codebook_index(0.07, &cb), 5);
assert_eq!(nearest_codebook_index(1.3125, &cb), 14);
assert_eq!(nearest_codebook_index(1.32, &cb), 15);
assert_eq!(nearest_codebook_index(-10.0, &cb), 0);
assert_eq!(nearest_codebook_index(10.0, &cb), 15);
}
// With `0.0` at index 0 and `-0.0` at index 8, the bit-exact preference must
// send each signed zero to its own entry even though both distances are zero.
#[test]
fn nearest_preserves_signed_zero() {
let mut cb = [0.0f32; 16];
cb[0] = 0.0;
cb[8] = -0.0;
for (i, slot) in cb.iter_mut().enumerate().take(8).skip(1) {
*slot = i as f32 * 0.1;
}
for (i, slot) in cb.iter_mut().enumerate().take(16).skip(9) {
*slot = -(i as f32 - 8.0) * 0.1;
}
assert_eq!(nearest_codebook_index(0.0, &cb), 0);
assert_eq!(nearest_codebook_index(-0.0, &cb), 8);
}
// Round trip through the external helper with a linear 16-entry codebook,
// four absmax values and block size 32.
// NOTE(review): the helper presumably decodes then re-encodes and compares — confirm in round_trip.
#[test]
fn encode_bnb4_round_trips_linear_codebook() {
let mut cb = [0.0f32; 16];
for (i, slot) in cb.iter_mut().enumerate() {
*slot = (i as f32 - 7.5) * 0.1;
}
assert_bnb4_decode_encode_round_trip(
&cb,
&[1.0, 2.0, 0.5, 8.0],
32,
dequantize_bnb4_to_bf16,
encode_bnb4,
)
.unwrap();
}
// Same round-trip helper, using the NF4 table.
#[test]
fn encode_bnb4_round_trips_nf4_codebook() {
assert_bnb4_decode_encode_round_trip(
&NF4_CODEBOOK,
&[1.0, 0.5, 2.0, 0.0123],
32,
dequantize_bnb4_to_bf16,
encode_bnb4,
)
.unwrap();
}
// Same round-trip helper, using the FP4 table (which contains both `0.0`
// and `-0.0`).
#[test]
fn encode_bnb4_round_trips_fp4_codebook() {
assert_bnb4_decode_encode_round_trip(
&FP4_CODEBOOK,
&[1.0, 0.5, 2.0, 0.0123],
32,
dequantize_bnb4_to_bf16,
encode_bnb4,
)
.unwrap();
}
// FP4 table with index 8 overwritten by `+0.0`, so indices 0 and 8 share a
// bit pattern; the round trip must still succeed.
#[test]
fn encode_bnb4_round_trips_collapsed_fp4_codebook() {
let mut collapsed = FP4_CODEBOOK;
collapsed[8] = 0.0; assert_eq!(
collapsed[0].to_bits(),
collapsed[8].to_bits(),
"test pre-condition: indices 0 and 8 must share bits",
);
assert_bnb4_decode_encode_round_trip(
&collapsed,
&[1.0, 0.5, 2.0, 0.0123],
32,
dequantize_bnb4_to_bf16,
encode_bnb4,
)
.unwrap();
}
// When entries 0 and 8 are bit-identical, a negative-signed value that
// landed on index 0 is moved to index 8; positive values and indices that
// are already in the upper half are left alone.
#[test]
fn apply_sign_magnitude_encode_correction_lifts_to_upper_when_duplicated() {
let mut cb = FP4_CODEBOOK;
cb[8] = 0.0;
assert_eq!(apply_sign_magnitude_encode_correction(-0.0, 0, &cb), 8);
assert_eq!(apply_sign_magnitude_encode_correction(0.0, 0, &cb), 0);
assert_eq!(apply_sign_magnitude_encode_correction(-1e-30, 0, &cb), 8);
assert_eq!(apply_sign_magnitude_encode_correction(-1.0, 11, &cb), 11);
}
// In the NF4 table entries 7 and 15 differ, so the correction is a no-op.
#[test]
fn apply_sign_magnitude_encode_correction_noop_when_bits_differ() {
assert_eq!(
apply_sign_magnitude_encode_correction(-0.0, 7, &NF4_CODEBOOK),
7,
"NF4 codebook has codebook[15]=1.0 ≠ codebook[7]=0.0; \
correction must not fire",
);
}
// Four zero inputs with a codebook whose entry 0 is `0.0` encode to two
// zero bytes.
#[test]
fn encode_bnb4_uniform_codebook_zero_byte() {
let cb: Vec<f32> = (0..16).map(|i| i as f32 * 0.1).collect();
let cb_bytes = f32_to_bytes(&cb);
let absmax_bytes = f32_to_bytes(&[1.0]);
let bf16 = bf16_bytes_from_f32(&[0.0; 4]);
let out = encode_bnb4(&bf16, &absmax_bytes, &cb_bytes, 4, 4).unwrap();
assert_eq!(out, vec![0x00, 0x00]);
}
// With codebook entry `i` equal to `i`, inputs [1, 3, 2, 4] pack to 0x31
// and 0x42: the first element of each pair is the low nibble, the second
// the high nibble.
#[test]
fn encode_bnb4_nibble_extraction_inverse() {
let cb: Vec<f32> = (0..16).map(|i| i as f32).collect();
let cb_bytes = f32_to_bytes(&cb);
let absmax_bytes = f32_to_bytes(&[1.0]);
let bf16 = bf16_bytes_from_f32(&[1.0, 3.0, 2.0, 4.0]);
let out = encode_bnb4(&bf16, &absmax_bytes, &cb_bytes, 4, 4).unwrap();
assert_eq!(out, vec![0x31, 0x42]);
}
// Rejected inputs, in order: zero block size; BF16 buffer shorter than
// `total_elements * 2`; 32-byte codebook; odd element count.
#[test]
fn encode_bnb4_validation_errors() {
let cb_bytes = f32_to_bytes(&[0.0; 16]);
let absmax_bytes = f32_to_bytes(&[1.0]);
assert!(encode_bnb4(&[0; 4], &absmax_bytes, &cb_bytes, 2, 0).is_err());
assert!(encode_bnb4(&[0; 2], &absmax_bytes, &cb_bytes, 4, 4).is_err());
assert!(encode_bnb4(&[0; 8], &absmax_bytes, &[0; 32], 4, 4).is_err());
assert!(encode_bnb4(&[0; 6], &absmax_bytes, &cb_bytes, 3, 3).is_err());
}
// Decodes two packed bytes with the external double-quant decoder, then
// re-encodes the result and expects the original bytes back. The nested
// codebook is all zeros except entries 64 and 200, which the two absmax
// bytes select.
#[test]
fn encode_bnb4_double_quant_round_trips_synthetic() {
let mut nested_cb = [0.0_f32; 256];
nested_cb[64] = 0.5;
nested_cb[200] = 1.5;
let nested_cb_bytes: Vec<u8> = nested_cb.iter().flat_map(|v| v.to_le_bytes()).collect();
let nested_absmax_bytes = f32_to_bytes(&[2.0]);
let absmax_data = vec![64u8, 200u8];
let codebook_bytes = f32_to_bytes(&NF4_CODEBOOK);
let weight_data = vec![0x53u8, 0x9Cu8];
let bf16 = crate::remember::bnb::dequantize_bnb4_double_quant_to_bf16(
&weight_data,
&absmax_data,
&codebook_bytes,
&nested_absmax_bytes,
&nested_cb_bytes,
4,
2,
256,
)
.unwrap();
let re_encoded = encode_bnb4_double_quant(
&bf16,
&absmax_data,
&codebook_bytes,
&nested_absmax_bytes,
&nested_cb_bytes,
4,
2,
256,
)
.unwrap();
assert_eq!(re_encoded, weight_data);
}
// Rejected inputs, in order: zero block size; odd element count; 512-byte
// nested codebook; two absmax bytes for a single block; zero nested block
// size.
#[test]
fn encode_bnb4_double_quant_validation_errors() {
let codebook_bytes = f32_to_bytes(&NF4_CODEBOOK);
let nested_cb_bytes = f32_to_bytes(&[0.0_f32; 256]);
let nested_absmax_bytes = f32_to_bytes(&[1.0]);
let absmax = vec![0u8];
assert!(encode_bnb4_double_quant(
&[0; 4],
&absmax,
&codebook_bytes,
&nested_absmax_bytes,
&nested_cb_bytes,
2,
0,
256,
)
.is_err());
assert!(encode_bnb4_double_quant(
&[0; 6],
&absmax,
&codebook_bytes,
&nested_absmax_bytes,
&nested_cb_bytes,
3,
3,
256,
)
.is_err());
assert!(encode_bnb4_double_quant(
&[0; 4],
&absmax,
&codebook_bytes,
&nested_absmax_bytes,
&[0; 512],
2,
2,
256,
)
.is_err());
assert!(encode_bnb4_double_quant(
&[0; 4],
&[0u8, 0u8],
&codebook_bytes,
&nested_absmax_bytes,
&nested_cb_bytes,
2,
2,
256,
)
.is_err());
assert!(encode_bnb4_double_quant(
&[0; 4],
&absmax,
&codebook_bytes,
&nested_absmax_bytes,
&nested_cb_bytes,
2,
2,
0,
)
.is_err());
}
// Encodes 64 values while computing absmax, decodes them, and checks that
// re-encoding the decoded data with the same absmax gives identical bytes.
#[test]
fn encode_bnb4_compute_absmax_round_trips_self() {
let cb_bytes = f32_to_bytes(&NF4_CODEBOOK);
let values: Vec<f32> = (0..64).map(|i| (i as f32 - 31.5) / 31.5).collect();
let bf16 = bf16_bytes_from_f32(&values);
let (weight, absmax) = encode_bnb4_compute_absmax(&bf16, &cb_bytes, 64, 32).unwrap();
let decoded = dequantize_bnb4_to_bf16(&weight, &absmax, &cb_bytes, 64, 32).unwrap();
let re_encoded = encode_bnb4(&decoded, &absmax, &cb_bytes, 64, 32).unwrap();
assert_eq!(weight, re_encoded);
}
// INT8 round trip through the external helper.
// NOTE(review): the name suggests it covers every `i8` value — confirm in round_trip.
#[test]
fn encode_bnb_int8_round_trips_every_i8() {
assert_bnb_int8_decode_encode_round_trip(dequantize_bnb_int8_to_bf16, encode_bnb_int8)
.unwrap();
}
// Row scales 127 and 254 give step sizes 1 and 2, so [1, -1] encodes to
// [1, -1] and [4, -4] to [2, -2] (as two's-complement bytes).
#[test]
fn encode_bnb_int8_basic_inverse() {
let scb_bytes = f32_to_bytes(&[127.0, 254.0]);
let bf16 = bf16_bytes_from_f32(&[1.0, -1.0, 4.0, -4.0]);
let out = encode_bnb_int8(&bf16, &scb_bytes, 2, 2).unwrap();
assert_eq!(out, vec![1u8, 0xFF, 2u8, 0xFE]);
}
// With scale 1.0, the input 1.0 lands on the top code 0x7F and 0.5 on 0x40.
#[test]
fn encode_bnb_int8_clamps_positive_overflow() {
let scb_bytes = f32_to_bytes(&[1.0]);
let bf16 = bf16_bytes_from_f32(&[1.0, 0.5]);
let out = encode_bnb_int8(&bf16, &scb_bytes, 1, 2).unwrap();
assert_eq!(out[0], 0x7F);
assert_eq!(out[1], 0x40);
}
// With scale 1.0, the input -2.0 is far below range and clamps to -128
// (byte 0x80).
#[test]
fn encode_bnb_int8_clamps_negative_overflow() {
let scb_bytes = f32_to_bytes(&[1.0]);
let bf16 = bf16_bytes_from_f32(&[-2.0]);
let out = encode_bnb_int8(&bf16, &scb_bytes, 1, 1).unwrap();
assert_eq!(out[0], 0x80);
}
// A zero row scale encodes every element of that row as 0.
#[test]
fn encode_bnb_int8_zero_scb() {
let scb_bytes = f32_to_bytes(&[0.0]);
let bf16 = bf16_bytes_from_f32(&[1.0, -1.0, 100.0]);
let out = encode_bnb_int8(&bf16, &scb_bytes, 1, 3).unwrap();
assert_eq!(out, vec![0, 0, 0]);
}
// Rejected inputs: a BF16 buffer that is too short, then an SCB buffer with
// one entry for two rows.
#[test]
fn encode_bnb_int8_validation_errors() {
let scb_bytes = f32_to_bytes(&[1.0]);
assert!(encode_bnb_int8(&[0; 2], &scb_bytes, 1, 2).is_err());
assert!(encode_bnb_int8(&[0; 4], &scb_bytes, 2, 1).is_err());
}
// Encodes a 2x16 matrix while computing SCB, decodes it, and checks that
// re-encoding the decoded data with the same SCB gives identical bytes.
#[test]
fn encode_bnb_int8_compute_scb_round_trips_self() {
let values: Vec<f32> = (-16..16).map(|i| i as f32 * 0.5).collect(); let bf16 = bf16_bytes_from_f32(&values);
let (weight, scb) = encode_bnb_int8_compute_scb(&bf16, 2, 16).unwrap();
let decoded = dequantize_bnb_int8_to_bf16(&weight, &scb, 2, 16).unwrap();
let re_encoded = encode_bnb_int8(&decoded, &scb, 2, 16).unwrap();
assert_eq!(weight, re_encoded);
}
}