use crate::error::AnamnesisError;
use crate::remember::fp8::f32_bits_to_bf16_bits;
/// Reads a little-endian `f32` starting at byte `offset` of `data`.
///
/// Returns `None` when fewer than four bytes are available at `offset`, including
/// when `offset + 4` would overflow `usize` (previously a debug-build panic).
fn read_f32_le(data: &[u8], offset: usize) -> Option<f32> {
    let end = offset.checked_add(4)?;
    let bytes: [u8; 4] = data.get(offset..end)?.try_into().ok()?;
    Some(f32::from_le_bytes(bytes))
}
/// Shared BnB4 decode loop: expands packed 4-bit codes through `quant_map`, scales
/// each block by its `absmax` entry, and emits one little-endian BF16 value
/// (2 bytes) per element.
///
/// Layout handled here: each weight byte carries two consecutive elements, the low
/// nibble first and the high nibble second, and element `k` belongs to block
/// `k / block_size`.
///
/// NOTE(review): bitsandbytes' reference kernels appear to store the *first*
/// element of each pair in the high nibble (`A >> 4`). This function and its tests
/// use low-nibble-first; confirm against a real checkpoint before relying on it.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] if `block_size` is zero or odd, if
/// `absmax.len() * block_size != total_elements`, if `weight_data` does not hold
/// exactly `total_elements / 2` bytes, or if the output byte count overflows `usize`.
fn dequantize_bnb4_core(
    weight_data: &[u8],
    absmax: &[f32],
    quant_map: &[f32; 16],
    total_elements: usize,
    block_size: usize,
) -> crate::Result<Vec<u8>> {
    // Two codes share a byte, so an odd block would start mid-byte and the
    // byte-chunked loop below would misalign every subsequent block.
    if block_size == 0 || !block_size.is_multiple_of(2) {
        return Err(AnamnesisError::Parse {
            reason: format!("BnB4 block_size must be a positive even number, got {block_size}"),
        });
    }
    // Every element needs a scale; a short `absmax` would leave a zero-filled tail.
    if absmax.len().checked_mul(block_size) != Some(total_elements) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 absmax block count mismatch: {} blocks of {block_size} elements do not cover {total_elements} elements",
                absmax.len()
            ),
        });
    }
    // total_elements is even here (a whole number of even-sized blocks).
    if weight_data.len().checked_mul(2) != Some(total_elements) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 weight byte count mismatch: expected {} for {total_elements} elements, got {}",
                total_elements / 2,
                weight_data.len()
            ),
        });
    }
    let out_byte_len = total_elements
        .checked_mul(2)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "BnB4 output byte count overflow".into(),
        })?;
    let mut output = vec![0u8; out_byte_len];
    if absmax.is_empty() {
        // Zero blocks imply zero elements (checked above): nothing to decode.
        return Ok(output);
    }

    let bytes_per_block = block_size / 2;
    // Cannot overflow: absmax is non-empty, so block_size <= total_elements, and
    // total_elements * 2 was shown to fit in usize above.
    let out_bytes_per_block = block_size * 2;

    // Lengths were validated above, so all three sequences yield exactly
    // absmax.len() items and the zip truncates nothing.
    let blocks = weight_data
        .chunks_exact(bytes_per_block)
        .zip(output.chunks_exact_mut(out_bytes_per_block))
        .zip(absmax);
    for ((packed, out_block), &block_absmax) in blocks {
        // One packed byte -> two elements -> four output bytes.
        for (&byte, out_quad) in packed.iter().zip(out_block.chunks_exact_mut(4)) {
            // Both indices are 4-bit values (< 16), so they are always in bounds.
            #[allow(clippy::indexing_slicing)]
            let (low, high) = (
                quant_map[usize::from(byte & 0x0F)],
                quant_map[usize::from(byte >> 4)],
            );
            let [b0, b1] = f32_bits_to_bf16_bits((low * block_absmax).to_bits()).to_le_bytes();
            let [b2, b3] = f32_bits_to_bf16_bits((high * block_absmax).to_bits()).to_le_bytes();
            out_quad.copy_from_slice(&[b0, b1, b2, b3]);
        }
    }
    Ok(output)
}
/// Dequantizes a single-quantized BnB 4-bit tensor to little-endian BF16 bytes.
///
/// * `weight_data` — packed 4-bit codes, two per byte (`total_elements / 2` bytes).
/// * `absmax_data` — one little-endian `f32` scale per block of `block_size` elements.
/// * `quant_map_data` — the 16-entry code book as little-endian `f32` (64 bytes).
///
/// Returns `total_elements * 2` bytes: one little-endian BF16 value per element.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] if `block_size` is zero or odd, if
/// `total_elements` is odd or not a multiple of `block_size`, if any input buffer
/// has the wrong length, or if a byte count overflows `usize`.
pub fn dequantize_bnb4_to_bf16(
    weight_data: &[u8],
    absmax_data: &[u8],
    quant_map_data: &[u8],
    total_elements: usize,
    block_size: usize,
) -> crate::Result<Vec<u8>> {
    if block_size == 0 {
        return Err(AnamnesisError::Parse {
            reason: "BnB block_size must be > 0".into(),
        });
    }
    // Two codes share a byte, so an odd block would start mid-byte and the
    // byte-granular block decode would misalign every block after the first.
    if !block_size.is_multiple_of(2) {
        return Err(AnamnesisError::Parse {
            reason: format!("BnB4 block_size must be even, got {block_size}"),
        });
    }
    // Report an odd element count directly instead of a misleading "expected 0 bytes".
    if !total_elements.is_multiple_of(2) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 total_elements must be even (two 4-bit codes per byte), got {total_elements}"
            ),
        });
    }
    let expected_weight_bytes = total_elements / 2;
    if weight_data.len() != expected_weight_bytes {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 weight byte count mismatch: expected {expected_weight_bytes} for {total_elements} elements, got {}",
                weight_data.len()
            ),
        });
    }
    if !total_elements.is_multiple_of(block_size) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 total_elements ({total_elements}) not divisible by block_size ({block_size})"
            ),
        });
    }
    let num_blocks = total_elements / block_size;
    let expected_absmax_bytes = num_blocks
        .checked_mul(4)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: "absmax byte count overflow".into(),
        })?;
    if absmax_data.len() != expected_absmax_bytes {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 absmax byte count mismatch: expected {expected_absmax_bytes}, got {}",
                absmax_data.len()
            ),
        });
    }
    if quant_map_data.len() != 64 {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 quant_map must be 64 bytes (16×F32), got {}",
                quant_map_data.len()
            ),
        });
    }

    let mut quant_map = [0.0f32; 16];
    for (i, val) in quant_map.iter_mut().enumerate() {
        *val = read_f32_le(quant_map_data, i * 4).ok_or_else(|| AnamnesisError::Parse {
            reason: "BnB4 quant_map read out of bounds".into(),
        })?;
    }

    let mut absmax_f32 = vec![0.0f32; num_blocks];
    for (i, val) in absmax_f32.iter_mut().enumerate() {
        *val = read_f32_le(absmax_data, i * 4).ok_or_else(|| AnamnesisError::Parse {
            reason: format!("BnB4 absmax read out of bounds at block {i}"),
        })?;
    }

    dequantize_bnb4_core(
        weight_data,
        &absmax_f32,
        &quant_map,
        total_elements,
        block_size,
    )
}
/// Dequantizes a double-quantized BnB 4-bit tensor to little-endian BF16 bytes.
///
/// The per-block scales are themselves 8-bit codes. `absmax_data` holds one `u8` per
/// block of `block_size` elements. Block `b` decodes to
/// `nested_quant_map[code] * nested_absmax[b / nested_block_size]`, where:
///
/// * `nested_quant_map_data` is a 256-entry little-endian `f32` code book (1024 bytes).
/// * `nested_absmax_data` holds one little-endian `f32` per group of
///   `nested_block_size` blocks; the last group may be partial.
///
/// The decoded scales then drive the same 4-bit decode as the single-quant path,
/// using `weight_data` (two codes per byte) and the 16-entry `quant_map_data`.
///
/// NOTE(review): bitsandbytes appears to subtract the mean absmax (an `offset`)
/// before the nested quantization and add it back on dequantization. No offset is
/// applied here; confirm whether real checkpoints need one.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] if either block size is zero, if `block_size`
/// is odd, if `total_elements` is not a multiple of `block_size`, if any input
/// buffer has the wrong length, or if a byte count overflows `usize`.
#[allow(clippy::too_many_arguments)]
pub fn dequantize_bnb4_double_quant_to_bf16(
    weight_data: &[u8],
    absmax_data: &[u8],
    quant_map_data: &[u8],
    nested_absmax_data: &[u8],
    nested_quant_map_data: &[u8],
    total_elements: usize,
    block_size: usize,
    nested_block_size: usize,
) -> crate::Result<Vec<u8>> {
    if block_size == 0 || nested_block_size == 0 {
        return Err(AnamnesisError::Parse {
            reason: "BnB block_size and nested_block_size must be > 0".into(),
        });
    }
    // Two codes share a byte, so an odd block would start mid-byte.
    if !block_size.is_multiple_of(2) {
        return Err(AnamnesisError::Parse {
            reason: format!("BnB4 block_size must be even, got {block_size}"),
        });
    }
    if !total_elements.is_multiple_of(block_size) {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 total_elements ({total_elements}) not divisible by block_size ({block_size})"
            ),
        });
    }
    // total_elements is even (a whole number of even-sized blocks). Reject an
    // oversized buffer too, matching the single-quant path.
    let expected_weight_bytes = total_elements / 2;
    if weight_data.len() != expected_weight_bytes {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 double-quant: weight byte count mismatch: expected {expected_weight_bytes} for {total_elements} elements, got {}",
                weight_data.len()
            ),
        });
    }
    let num_blocks = total_elements / block_size;
    if absmax_data.len() != num_blocks {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 double-quant: absmax byte count mismatch: expected {num_blocks}, got {}",
                absmax_data.len()
            ),
        });
    }
    // Validate every input size before doing any decode work.
    if quant_map_data.len() != 64 {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 quant_map must be 64 bytes (16×F32), got {}",
                quant_map_data.len()
            ),
        });
    }
    if nested_quant_map_data.len() != 1024 {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 nested_quant_map must be 1024 bytes (256×F32), got {}",
                nested_quant_map_data.len()
            ),
        });
    }
    let num_nested_blocks = num_blocks.div_ceil(nested_block_size);
    let expected_nested_absmax_bytes =
        num_nested_blocks
            .checked_mul(4)
            .ok_or_else(|| AnamnesisError::Parse {
                reason: "nested absmax byte count overflow".into(),
            })?;
    if nested_absmax_data.len() != expected_nested_absmax_bytes {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB4 nested_absmax byte count mismatch: expected {expected_nested_absmax_bytes}, got {}",
                nested_absmax_data.len()
            ),
        });
    }

    let mut quant_map = [0.0f32; 16];
    for (i, val) in quant_map.iter_mut().enumerate() {
        *val = read_f32_le(quant_map_data, i * 4).ok_or_else(|| AnamnesisError::Parse {
            reason: "BnB4 quant_map read out of bounds".into(),
        })?;
    }
    let mut nested_quant_map = [0.0f32; 256];
    for (i, val) in nested_quant_map.iter_mut().enumerate() {
        *val = read_f32_le(nested_quant_map_data, i * 4).ok_or_else(|| AnamnesisError::Parse {
            reason: "BnB4 nested_quant_map read out of bounds".into(),
        })?;
    }

    // Decode each nested scale once and apply it to its whole group of absmax codes.
    // Both sides yield exactly num_nested_blocks items (lengths checked above).
    let mut dequantized_absmax = Vec::with_capacity(num_blocks);
    let groups = absmax_data
        .chunks(nested_block_size)
        .zip(nested_absmax_data.chunks_exact(4));
    for (nested_block_idx, (codes, scale_bytes)) in groups.enumerate() {
        let nested_absmax_val = read_f32_le(scale_bytes, 0).ok_or_else(|| AnamnesisError::Parse {
            reason: format!("BnB4 nested_absmax read out of bounds at block {nested_block_idx}"),
        })?;
        // `code` is a u8, so the index is always < 256 and in bounds.
        #[allow(clippy::indexing_slicing)]
        {
            dequantized_absmax.extend(
                codes
                    .iter()
                    .map(|&code| nested_quant_map[usize::from(code)] * nested_absmax_val),
            );
        }
    }

    dequantize_bnb4_core(
        weight_data,
        &dequantized_absmax,
        &quant_map,
        total_elements,
        block_size,
    )
}
/// Dequantizes a BnB INT8 weight matrix to little-endian BF16 bytes.
///
/// `weight_data` is read row-major as `out_features` rows of `in_features` signed
/// bytes. `scb_data` supplies one little-endian `f32` per row. Each element becomes
/// `w * (scb / 127.0)`, so code 127 maps exactly to the row's SCB value. The result
/// is then rounded to BF16 by `f32_bits_to_bf16_bits`.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] when an element or byte count overflows `usize`
/// or when `weight_data` / `scb_data` do not have exactly the expected lengths.
pub fn dequantize_bnb_int8_to_bf16(
    weight_data: &[u8],
    scb_data: &[u8],
    out_features: usize,
    in_features: usize,
) -> crate::Result<Vec<u8>> {
    let Some(total_elements) = out_features.checked_mul(in_features) else {
        return Err(AnamnesisError::Parse {
            reason: "BnB INT8 element count overflow".into(),
        });
    };
    if weight_data.len() != total_elements {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB INT8 weight byte count mismatch: expected {total_elements}, got {}",
                weight_data.len()
            ),
        });
    }

    let Some(expected_scb_bytes) = out_features.checked_mul(4) else {
        return Err(AnamnesisError::Parse {
            reason: "SCB byte count overflow".into(),
        });
    };
    if scb_data.len() != expected_scb_bytes {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB INT8 SCB byte count mismatch: expected {expected_scb_bytes}, got {}",
                scb_data.len()
            ),
        });
    }

    let Some(out_byte_len) = total_elements.checked_mul(2) else {
        return Err(AnamnesisError::Parse {
            reason: "BnB INT8 output byte count overflow".into(),
        });
    };
    let mut output = vec![0u8; out_byte_len];

    for row in 0..out_features {
        let Some(scb_val) = read_f32_le(scb_data, row * 4) else {
            return Err(AnamnesisError::Parse {
                reason: format!("BnB INT8 SCB read out of bounds at row {row}"),
            });
        };
        let step = scb_val / 127.0;

        let first = row * in_features;
        let Some(codes) = weight_data.get(first..first + in_features) else {
            return Err(AnamnesisError::Parse {
                reason: format!("BnB INT8 weight row {row} out of bounds"),
            });
        };
        let byte_first = first * 2;
        let Some(dst_row) = output.get_mut(byte_first..byte_first + in_features * 2) else {
            return Err(AnamnesisError::Parse {
                reason: format!("BnB INT8 output row {row} out of bounds"),
            });
        };

        for (dst, &code) in dst_row.chunks_exact_mut(2).zip(codes) {
            // Reinterpret the stored byte as a two's-complement i8 (no `as` cast needed).
            let signed = i8::from_le_bytes([code]);
            let bf16 = f32_bits_to_bf16_bits((f32::from(signed) * step).to_bits());
            dst.copy_from_slice(&bf16.to_le_bytes());
        }
    }
    Ok(output)
}
#[cfg(test)]
#[allow(
    clippy::panic,
    clippy::indexing_slicing,
    clippy::unwrap_used,
    clippy::float_cmp,
    clippy::as_conversions,
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss
)]
mod tests {
    use super::*;

    /// Serializes `values` as consecutive little-endian `f32` bytes.
    fn f32_to_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Widens the BF16 element at index `idx` of `output` back to `f32` (exact).
    fn read_bf16(output: &[u8], idx: usize) -> f32 {
        let offset = idx * 2;
        let bits = u16::from_le_bytes([output[offset], output[offset + 1]]);
        let f32_bits = u32::from(bits) << 16;
        f32::from_bits(f32_bits)
    }

    #[test]
    fn bnb4_uniform_lookup() {
        // Code `i` maps to `i * 0.25`, which is exactly representable in BF16.
        let quant_map: Vec<f32> = (0..16).map(|i| i as f32 * 0.25).collect();
        let quant_map_bytes = f32_to_bytes(&quant_map);
        let absmax_bytes = f32_to_bytes(&[2.0]);

        // All-zero codes decode to quant_map[0] * absmax = 0.
        let zeros =
            dequantize_bnb4_to_bf16(&[0x00, 0x00], &absmax_bytes, &quant_map_bytes, 4, 4)
                .unwrap();
        for i in 0..4 {
            assert_eq!(read_bf16(&zeros, i), 0.0, "element {i}");
        }

        // Both nibbles of each byte are equal, so this checks lookup + scaling
        // independently of nibble order: code 3 -> 0.75 * 2, code 10 -> 2.5 * 2.
        let out = dequantize_bnb4_to_bf16(&[0x33, 0xAA], &absmax_bytes, &quant_map_bytes, 4, 4)
            .unwrap();
        assert_eq!(out.len(), 8);
        assert_eq!(read_bf16(&out, 0), 1.5);
        assert_eq!(read_bf16(&out, 1), 1.5);
        assert_eq!(read_bf16(&out, 2), 5.0);
        assert_eq!(read_bf16(&out, 3), 5.0);
    }

    #[test]
    fn bnb4_nibble_extraction() {
        // Identity code book: each element decodes to its raw 4-bit code.
        // NOTE(review): this pins low-nibble-first ordering (0x31 -> 1, 3).
        let quant_map: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let quant_map_bytes = f32_to_bytes(&quant_map);
        let weight_data = vec![0x31, 0x42];
        let absmax_bytes = f32_to_bytes(&[1.0]);
        let out =
            dequantize_bnb4_to_bf16(&weight_data, &absmax_bytes, &quant_map_bytes, 4, 4).unwrap();
        assert_eq!(read_bf16(&out, 0), 1.0);
        assert_eq!(read_bf16(&out, 1), 3.0);
        assert_eq!(read_bf16(&out, 2), 2.0);
        assert_eq!(read_bf16(&out, 3), 4.0);
    }

    #[test]
    fn bnb4_absmax_scaling() {
        // Only code 5 is non-zero (0.5); absmax 4.0 scales it to 2.0.
        let mut quant_map = [0.0f32; 16];
        quant_map[5] = 0.5;
        let quant_map_bytes = f32_to_bytes(&quant_map);
        let weight_data = vec![0x55];
        let absmax_bytes = f32_to_bytes(&[4.0]);
        let out =
            dequantize_bnb4_to_bf16(&weight_data, &absmax_bytes, &quant_map_bytes, 2, 2).unwrap();
        assert_eq!(read_bf16(&out, 0), 2.0);
        assert_eq!(read_bf16(&out, 1), 2.0);
    }

    #[test]
    fn bnb4_multi_block() {
        // Two blocks of two elements, each with its own absmax (1.0 then 3.0).
        let quant_map: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let quant_map_bytes = f32_to_bytes(&quant_map);
        let weight_data = vec![0x10, 0x10];
        let absmax_bytes = f32_to_bytes(&[1.0, 3.0]);
        let out =
            dequantize_bnb4_to_bf16(&weight_data, &absmax_bytes, &quant_map_bytes, 4, 2).unwrap();
        assert_eq!(read_bf16(&out, 0), 0.0);
        assert_eq!(read_bf16(&out, 1), 1.0);
        assert_eq!(read_bf16(&out, 2), 0.0);
        assert_eq!(read_bf16(&out, 3), 3.0);
    }

    #[test]
    fn bnb4_empty_tensor() {
        let quant_map_bytes = f32_to_bytes(&[0.0; 16]);
        let out = dequantize_bnb4_to_bf16(&[], &[], &quant_map_bytes, 0, 64).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn bnb4_validation_errors() {
        let quant_map_bytes = f32_to_bytes(&[0.0; 16]);
        let absmax_bytes = f32_to_bytes(&[1.0]);
        // block_size of zero.
        assert!(dequantize_bnb4_to_bf16(&[0], &absmax_bytes, &quant_map_bytes, 2, 0).is_err());
        // Two elements need one weight byte, not two.
        assert!(dequantize_bnb4_to_bf16(&[0, 0], &absmax_bytes, &quant_map_bytes, 2, 2).is_err());
        // quant_map must be 64 bytes.
        assert!(dequantize_bnb4_to_bf16(&[0], &absmax_bytes, &[0; 32], 2, 2).is_err());
        // Odd element count cannot be packed two codes per byte.
        assert!(dequantize_bnb4_to_bf16(&[0], &absmax_bytes, &quant_map_bytes, 3, 3).is_err());
        // total_elements not a multiple of block_size.
        assert!(
            dequantize_bnb4_to_bf16(&[0, 0, 0], &absmax_bytes, &quant_map_bytes, 6, 4).is_err()
        );
        // Two blocks need 8 absmax bytes; only 4 supplied.
        assert!(dequantize_bnb4_to_bf16(&[0, 0], &absmax_bytes, &quant_map_bytes, 4, 2).is_err());
    }

    #[test]
    fn bnb4_double_quant_basic() {
        let quant_map: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let quant_map_bytes = f32_to_bytes(&quant_map);
        // Absmax code 2 -> 0.5 * nested absmax 4.0 = block absmax 2.0.
        let mut nested_quant_map = [0.0f32; 256];
        nested_quant_map[2] = 0.5;
        let nested_quant_map_bytes = f32_to_bytes(&nested_quant_map);
        let nested_absmax_bytes = f32_to_bytes(&[4.0]);
        let absmax_data = vec![2u8];
        let weight_data = vec![0x10];
        let out = dequantize_bnb4_double_quant_to_bf16(
            &weight_data,
            &absmax_data,
            &quant_map_bytes,
            &nested_absmax_bytes,
            &nested_quant_map_bytes,
            2,
            2,
            256,
        )
        .unwrap();
        assert_eq!(read_bf16(&out, 0), 0.0);
        assert_eq!(read_bf16(&out, 1), 2.0);
    }

    #[test]
    fn bnb4_double_quant_validation_errors() {
        let quant_map_bytes = f32_to_bytes(&[0.0; 16]);
        let nested_quant_map_bytes = f32_to_bytes(&[0.0; 256]);
        let nested_absmax_bytes = f32_to_bytes(&[1.0]);
        // nested_block_size of zero.
        assert!(dequantize_bnb4_double_quant_to_bf16(
            &[0],
            &[0],
            &quant_map_bytes,
            &nested_absmax_bytes,
            &nested_quant_map_bytes,
            2,
            2,
            0,
        )
        .is_err());
        // One block but two absmax codes.
        assert!(dequantize_bnb4_double_quant_to_bf16(
            &[0],
            &[0, 0],
            &quant_map_bytes,
            &nested_absmax_bytes,
            &nested_quant_map_bytes,
            2,
            2,
            256,
        )
        .is_err());
        // Nested code book must be 256 f32 entries (1024 bytes).
        assert!(dequantize_bnb4_double_quant_to_bf16(
            &[0],
            &[0],
            &quant_map_bytes,
            &nested_absmax_bytes,
            &quant_map_bytes,
            2,
            2,
            256,
        )
        .is_err());
        // A single nested group needs exactly 4 nested_absmax bytes.
        assert!(dequantize_bnb4_double_quant_to_bf16(
            &[0],
            &[0],
            &quant_map_bytes,
            &f32_to_bytes(&[1.0, 1.0]),
            &nested_quant_map_bytes,
            2,
            2,
            256,
        )
        .is_err());
    }

    #[test]
    fn bnb_int8_basic() {
        // Row 0: codes 1, -1 with SCB 127 (scale 1). Row 1: codes 2, -2 with SCB 254 (scale 2).
        let weight_data: Vec<u8> = vec![1u8, 0xFF, 2u8, 0xFE];
        let scb_bytes = f32_to_bytes(&[127.0, 254.0]);
        let out = dequantize_bnb_int8_to_bf16(&weight_data, &scb_bytes, 2, 2).unwrap();
        assert_eq!(read_bf16(&out, 0), 1.0);
        assert_eq!(read_bf16(&out, 1), -1.0);
        assert_eq!(read_bf16(&out, 2), 4.0);
        assert_eq!(read_bf16(&out, 3), -4.0);
    }

    #[test]
    fn bnb_int8_zero_scale() {
        let weight_data = vec![127u8, 1u8];
        let scb_bytes = f32_to_bytes(&[0.0]);
        let out = dequantize_bnb_int8_to_bf16(&weight_data, &scb_bytes, 1, 2).unwrap();
        assert_eq!(read_bf16(&out, 0), 0.0);
        assert_eq!(read_bf16(&out, 1), 0.0);
    }

    #[test]
    fn bnb_int8_empty() {
        let out = dequantize_bnb_int8_to_bf16(&[], &[], 0, 0).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn bnb_int8_validation_errors() {
        let scb_bytes = f32_to_bytes(&[1.0]);
        // Weight length must equal out_features * in_features.
        assert!(dequantize_bnb_int8_to_bf16(&[0], &scb_bytes, 1, 2).is_err());
        // Two rows need 8 SCB bytes; only 4 supplied.
        assert!(dequantize_bnb_int8_to_bf16(&[0; 4], &scb_bytes, 2, 2).is_err());
        // out_features * in_features overflows usize.
        assert!(dequantize_bnb_int8_to_bf16(&[], &[], usize::MAX, 2).is_err());
    }
}