use std::fmt::Write as _;
use std::path::Path;
use crate::error::AnamnesisError;
use crate::lethe::bnb::{encode_bnb4_compute_absmax, NF4_CODEBOOK};
/// Number of consecutive elements that share one `absmax` scale in NF4 encoding.
///
/// A tensor is only NF4-encoded when its element count is a non-zero multiple of
/// this value, and the writer emits one `absmax` entry per block of this size.
pub const NF4_BLOCK_SIZE: usize = 64;
/// One BF16 tensor handed to the BnB-NF4 writer.
///
/// Every field borrows from the caller; the writer copies what it needs.
#[derive(Clone, Debug)]
pub struct BnbWriteInput<'a> {
    /// Base tensor name. NF4-encoded tensors are emitted as `{name}.weight` plus
    /// `.absmax`, `.quant_map` and `.quant_state.bitsandbytes__nf4` companions;
    /// pass-through tensors keep this exact name.
    pub name: &'a str,
    /// Logical tensor shape; its element count must equal `bf16_data.len() / 2`.
    pub shape: &'a [usize],
    /// Raw BF16 payload, two bytes per element (presumably little-endian, as the
    /// tests produce it — confirm against the encoder).
    pub bf16_data: &'a [u8],
}
/// Reports whether a tensor of `shape` is NF4-encoded by this writer.
///
/// Only rank-2 shapes qualify, and only when their element count is a non-zero
/// multiple of [`NF4_BLOCK_SIZE`], so every `absmax` block is full.
///
/// A shape whose element count overflows `usize` is reported as ineligible. It does
/// not panic in debug builds, and in release builds it does not wrap around to a
/// bogus count that might look like a valid multiple of the block size.
#[must_use]
pub fn is_eligible_for_nf4(shape: &[usize]) -> bool {
    let &[rows, cols] = shape else {
        return false;
    };
    rows.checked_mul(cols)
        .is_some_and(|total| total >= NF4_BLOCK_SIZE && total.is_multiple_of(NF4_BLOCK_SIZE))
}
/// Serializes every [`NF4_CODEBOOK`] entry as little-endian bytes, in codebook order.
fn codebook_bytes() -> Vec<u8> {
    NF4_CODEBOOK.iter().fold(Vec::new(), |mut bytes, entry| {
        bytes.extend_from_slice(&entry.to_le_bytes());
        bytes
    })
}
/// Builds the UTF-8 JSON blob stored as `{name}.weight.quant_state.bitsandbytes__nf4`.
///
/// The object records `quant_type` (`"nf4"`), `blocksize` ([`NF4_BLOCK_SIZE`]),
/// the original `shape` as an integer array, and `nested: false`.
///
/// NOTE(review): bitsandbytes' `QuantState.from_dict` appears to also expect a
/// `dtype` key and to validate keys against a fixed list that may not include
/// `nested` — confirm interoperability before loading these files with bitsandbytes.
fn quant_state_json_bytes(shape: &[usize]) -> Vec<u8> {
    let mut json = String::from(r#"{"quant_type":"nf4","blocksize":"#);
    let _ = write!(json, "{NF4_BLOCK_SIZE}");
    json.push_str(r#","shape":["#);
    let mut separator = "";
    for dim in shape {
        let _ = write!(json, "{separator}{dim}");
        separator = ",";
    }
    json.push_str(r#"],"nested":false}"#);
    json.into_bytes()
}
/// Encodes `inputs` with [`write_bnb_nf4_safetensors_bytes`] and writes the result
/// to `output` via `std::fs::write`.
///
/// The whole file is serialized in memory first and written directly to `output`.
/// There is no temp-file rename. If encoding fails, nothing is written.
///
/// # Errors
///
/// Propagates every error from [`write_bnb_nf4_safetensors_bytes`]. A filesystem
/// failure is returned as [`AnamnesisError::Io`].
pub fn write_bnb_nf4_safetensors(
    inputs: &[BnbWriteInput<'_>],
    output: impl AsRef<Path>,
) -> crate::Result<()> {
    let serialized = write_bnb_nf4_safetensors_bytes(inputs)?;
    let destination: &Path = output.as_ref();
    std::fs::write(destination, serialized).map_err(AnamnesisError::Io)
}
/// A fully-owned output tensor: `(name, dtype, shape, raw bytes)`.
type OwnedTensor = (String, safetensors::Dtype, Vec<usize>, Vec<u8>);

/// Serializes `inputs` into an in-memory safetensors file using the BnB-NF4 layout.
///
/// Each input accepted by [`is_eligible_for_nf4`] is NF4-encoded and expands into four
/// tensors:
/// - `{name}.weight`: U8, shape `[n / 2, 1]`.
/// - `{name}.weight.absmax`: F32, one entry per [`NF4_BLOCK_SIZE`] block.
/// - `{name}.weight.quant_map`: F32, shape `[16]`, the codebook.
/// - `{name}.weight.quant_state.bitsandbytes__nf4`: U8 JSON blob.
///
/// Every other input is stored unchanged as a BF16 tensor named `{name}`.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] when:
/// - a shape has a zero dimension, or an element count that overflows `usize`;
/// - `bf16_data` is not exactly two bytes per element;
/// - two output tensors would share a name. This covers duplicate inputs, and also
///   a pass-through input named e.g. `x.weight` next to an NF4 input `x`;
/// - safetensors rejects a tensor view or the final serialization.
///
/// Errors from `encode_bnb4_compute_absmax` are propagated through `?`.
pub fn write_bnb_nf4_safetensors_bytes(inputs: &[BnbWriteInput<'_>]) -> crate::Result<Vec<u8>> {
    let codebook = codebook_bytes();
    let mut owned_storage: Vec<OwnedTensor> = Vec::new();

    // Visiting inputs in name order makes the first reported validation error
    // independent of the order the caller supplied them in.
    let mut sorted_inputs: Vec<&BnbWriteInput<'_>> = inputs.iter().collect();
    sorted_inputs.sort_by_key(|t| t.name);

    for input in sorted_inputs {
        let total_elements = validated_element_count(input)?;
        if is_eligible_for_nf4(input.shape) {
            push_nf4_tensors(&mut owned_storage, input, total_elements, &codebook)?;
        } else {
            owned_storage.push((
                input.name.to_owned(),
                safetensors::Dtype::BF16,
                input.shape.to_vec(),
                input.bf16_data.to_vec(),
            ));
        }
    }

    owned_storage.sort_by(|a, b| a.0.cmp(&b.0));
    // Sorted order puts any colliding names next to each other.
    ensure_unique_tensor_names(&owned_storage)?;

    let mut views: Vec<(String, safetensors::tensor::TensorView<'_>)> =
        Vec::with_capacity(owned_storage.len());
    for (name, dtype, shape, data) in &owned_storage {
        let view =
            safetensors::tensor::TensorView::new(*dtype, shape.clone(), data).map_err(|e| {
                AnamnesisError::Parse {
                    reason: format!("failed to create TensorView for `{name}`: {e}"),
                }
            })?;
        views.push((name.clone(), view));
    }

    safetensors::tensor::serialize(views, &None).map_err(|e| AnamnesisError::Parse {
        reason: format!("failed to serialize BnB-NF4 safetensors: {e}"),
    })
}

/// Returns the element count of `input.shape` after checking it against `bf16_data`.
///
/// Rejects a zero dimension, a `usize` overflow of the element or byte count, and
/// any `bf16_data` whose length is not exactly two bytes per element.
fn validated_element_count(input: &BnbWriteInput<'_>) -> crate::Result<usize> {
    let total_elements = input
        .shape
        .iter()
        .copied()
        .try_fold(1usize, |acc, d| if d == 0 { None } else { acc.checked_mul(d) })
        .ok_or_else(|| AnamnesisError::Parse {
            reason: format!(
                "BnB-NF4 write `{}`: shape {:?} element-count overflow or zero dimension",
                input.name, input.shape
            ),
        })?;
    let expected_bytes = total_elements
        .checked_mul(2)
        .ok_or_else(|| AnamnesisError::Parse {
            reason: format!("BnB-NF4 write `{}`: BF16 byte count overflow", input.name),
        })?;
    if input.bf16_data.len() != expected_bytes {
        return Err(AnamnesisError::Parse {
            reason: format!(
                "BnB-NF4 write `{}`: bf16_data length {} != expected {expected_bytes} \
                 bytes (shape {:?})",
                input.name,
                input.bf16_data.len(),
                input.shape
            ),
        });
    }
    Ok(total_elements)
}

/// NF4-encodes `input` and appends its weight, absmax, quant_map and quant_state
/// tensors to `storage`.
fn push_nf4_tensors(
    storage: &mut Vec<OwnedTensor>,
    input: &BnbWriteInput<'_>,
    total_elements: usize,
    codebook: &[u8],
) -> crate::Result<()> {
    let (weight, absmax) =
        encode_bnb4_compute_absmax(input.bf16_data, codebook, total_elements, NF4_BLOCK_SIZE)?;
    storage.push((
        format!("{}.weight", input.name),
        safetensors::Dtype::U8,
        vec![total_elements / 2, 1],
        weight,
    ));
    storage.push((
        format!("{}.weight.absmax", input.name),
        safetensors::Dtype::F32,
        vec![total_elements / NF4_BLOCK_SIZE],
        absmax,
    ));
    storage.push((
        format!("{}.weight.quant_map", input.name),
        safetensors::Dtype::F32,
        vec![16],
        codebook.to_vec(),
    ));
    let quant_state = quant_state_json_bytes(input.shape);
    let quant_state_len = quant_state.len();
    storage.push((
        format!("{}.weight.quant_state.bitsandbytes__nf4", input.name),
        safetensors::Dtype::U8,
        vec![quant_state_len],
        quant_state,
    ));
    Ok(())
}

/// Fails if any two adjacent entries of the name-sorted `sorted` slice share a name.
fn ensure_unique_tensor_names(sorted: &[OwnedTensor]) -> crate::Result<()> {
    let duplicate = sorted
        .iter()
        .zip(sorted.iter().skip(1))
        .find(|(prev, next)| prev.0 == next.0);
    if let Some((dup, _)) = duplicate {
        return Err(AnamnesisError::Parse {
            reason: format!("BnB-NF4 write: duplicate output tensor name `{}`", dup.0),
        });
    }
    Ok(())
}
/// Tally of how a batch of inputs would be routed by the BnB-NF4 writer.
#[derive(Clone, Copy, Debug, Default)]
#[must_use]
pub struct BnbNf4WriteStats {
    /// Inputs whose shape qualifies for NF4 encoding.
    pub quantized: usize,
    /// Inputs that would be stored unchanged as BF16.
    pub passthrough: usize,
}
/// Counts how many `inputs` would be NF4-encoded versus passed through as BF16.
///
/// The decision uses the shape alone, via [`is_eligible_for_nf4`]. No data
/// validation happens here: an input with a mismatched `bf16_data` length is still
/// counted, even though the writer would reject it.
pub fn classify_inputs(inputs: &[BnbWriteInput<'_>]) -> BnbNf4WriteStats {
    inputs
        .iter()
        .fold(BnbNf4WriteStats::default(), |mut tally, input| {
            if is_eligible_for_nf4(input.shape) {
                tally.quantized += 1;
            } else {
                tally.passthrough += 1;
            }
            tally
        })
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::indexing_slicing,
    clippy::expect_used,
    clippy::panic,
    clippy::float_cmp,
    clippy::as_conversions,
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::wildcard_enum_match_arm
)]
mod tests {
    use super::*;
    use crate::parse::safetensors::{parse_safetensors_header, QuantScheme};
    use crate::remember::bnb::dequantize_bnb4_to_bf16;

    /// Synthesizes `rows * cols` BF16 values on a linear ramp starting at -1.0.
    ///
    /// Each value is an `f32` truncated to its upper 16 bits and emitted as
    /// little-endian bytes.
    fn synth_bf16(rows: usize, cols: usize) -> Vec<u8> {
        let count = rows * cols;
        (0..count)
            .flat_map(|idx| {
                let value = (idx as f32) / (count as f32) * 2.0 - 1.0;
                ((value.to_bits() >> 16) as u16).to_le_bytes()
            })
            .collect()
    }

    /// Runs the in-memory writer on a single input.
    fn write_single(name: &str, shape: &[usize], bf16_data: &[u8]) -> crate::Result<Vec<u8>> {
        write_bnb_nf4_safetensors_bytes(&[BnbWriteInput {
            name,
            shape,
            bf16_data,
        }])
    }

    #[test]
    fn eligibility_only_2d_multiples_of_64() {
        assert!(is_eligible_for_nf4(&[64, 1]));
        assert!(is_eligible_for_nf4(&[8, 8]));
        assert!(is_eligible_for_nf4(&[128, 256]));
        assert!(!is_eligible_for_nf4(&[63, 1]));
        assert!(!is_eligible_for_nf4(&[64]));
        assert!(!is_eligible_for_nf4(&[4, 4, 4]));
        assert!(!is_eligible_for_nf4(&[]));
    }

    #[test]
    fn quant_state_json_shape_recovery() {
        let blob = quant_state_json_bytes(&[256, 64]);
        let parsed: serde_json::Value = serde_json::from_slice(&blob).unwrap();
        let dims = parsed["shape"].as_array().unwrap();
        assert_eq!(dims[0].as_u64(), Some(256));
        assert_eq!(dims[1].as_u64(), Some(64));
        assert_eq!(parsed["quant_type"].as_str(), Some("nf4"));
        assert_eq!(parsed["blocksize"].as_u64(), Some(64));
        assert_eq!(parsed["nested"].as_bool(), Some(false));
    }

    #[test]
    fn write_then_parse_detects_bnb_nf4_scheme() {
        let source = synth_bf16(64, 1);
        let serialized = write_single("linear", &[64, 1], &source).unwrap();
        let header = parse_safetensors_header(&serialized).unwrap();
        assert_eq!(
            header.scheme,
            QuantScheme::Bnb4,
            "scheme should be detected as Bnb4"
        );
        let names: Vec<&str> = header.tensors.iter().map(|t| t.name.as_str()).collect();
        for expected in [
            "linear.weight",
            "linear.weight.absmax",
            "linear.weight.quant_map",
            "linear.weight.quant_state.bitsandbytes__nf4",
        ] {
            assert!(names.contains(&expected));
        }
    }

    #[test]
    fn passthrough_for_ineligible_shapes() {
        let norm_data = synth_bf16(1, 8);
        let serialized = write_single("norm", &[8], &norm_data).unwrap();
        let header = parse_safetensors_header(&serialized).unwrap();
        let names: Vec<&str> = header.tensors.iter().map(|t| t.name.as_str()).collect();
        assert_eq!(names, vec!["norm"]);
    }

    #[test]
    fn round_trip_decode_recovers_within_codebook_error() {
        let source = synth_bf16(64, 1);
        let first_pass = write_single("linear", &[64, 1], &source).unwrap();
        let parsed = safetensors::SafeTensors::deserialize(&first_pass).unwrap();
        let packed = parsed.tensor("linear.weight").unwrap();
        let scales = parsed.tensor("linear.weight.absmax").unwrap();
        let quant_map = parsed.tensor("linear.weight.quant_map").unwrap();
        let element_count = 64;
        let decoded = dequantize_bnb4_to_bf16(
            packed.data(),
            scales.data(),
            quant_map.data(),
            element_count,
            NF4_BLOCK_SIZE,
        )
        .unwrap();

        let second_pass = write_single("linear", &[64, 1], &decoded).unwrap();
        let reparsed = safetensors::SafeTensors::deserialize(&second_pass).unwrap();
        let repacked = reparsed.tensor("linear.weight").unwrap();
        assert_eq!(
            packed.data(),
            repacked.data(),
            "BnB-NF4 encode is not idempotent on already-quantized BF16"
        );
    }

    #[test]
    fn rejects_bf16_length_mismatch() {
        let undersized = [0u8; 16];
        match write_single("w", &[64, 1], &undersized).expect_err("should reject") {
            AnamnesisError::Parse { reason } => assert!(
                reason.contains("bf16_data length"),
                "unexpected reason: {reason}"
            ),
            other => panic!("expected Parse, got {other:?}"),
        }
    }

    #[test]
    fn classify_counts() {
        let weight_data = synth_bf16(64, 1);
        let bias_data = synth_bf16(1, 8);
        let stats = classify_inputs(&[
            BnbWriteInput {
                name: "w",
                shape: &[64, 1],
                bf16_data: &weight_data,
            },
            BnbWriteInput {
                name: "b",
                shape: &[8],
                bf16_data: &bias_data,
            },
        ]);
        assert_eq!(stats.quantized, 1);
        assert_eq!(stats.passthrough, 1);
    }
}