use crate::error::{QuantError, QuantResult};
/// The 16-entry NF4 (4-bit NormalFloat) dequantization lookup table.
///
/// Entries are sorted ascending, span [-1.0, 1.0], and index 7 is exactly 0.0
/// (the sorted order is what `nearest_nf4`'s binary search relies on).
/// Spacing is intentionally non-uniform — denser near zero.
pub const NF4_LUT: [f32; 16] = [
    -1.0,
    -0.696_192_86,
    -0.525_073_05,
    -0.394_917_5,
    -0.284_441_38,
    -0.184_773_43,
    -0.091_050_03,
    0.0,
    0.079_580_3,
    0.160_930_2,
    0.246_112_3,
    0.337_915_24,
    0.440_709_83,
    0.562_617,
    0.722_956_84,
    1.0,
];
/// Block-wise NF4 quantizer: splits a tensor into fixed-size blocks, scales
/// each block by its absolute maximum, and maps each scaled value to the
/// nearest entry of [`NF4_LUT`], packing two 4-bit codes per byte.
#[derive(Debug, Clone)]
pub struct Nf4Quantizer {
    // Number of f32 values per quantization block (one absmax scale per
    // block). Must be > 0; the nibble-packing in encode/decode also assumes
    // an even block size (two codes per byte).
    pub block_size: usize,
}
impl Default for Nf4Quantizer {
fn default() -> Self {
Self { block_size: 64 }
}
}
impl Nf4Quantizer {
    /// Creates a quantizer with the given block size.
    ///
    /// # Panics
    /// Panics if `block_size` is zero or odd. Two 4-bit codes are packed
    /// per byte, so an odd block size would make adjacent blocks share a
    /// byte and overwrite each other's low nibble.
    #[must_use]
    pub fn new(block_size: usize) -> Self {
        assert!(block_size > 0, "block_size must be > 0");
        assert!(block_size % 2 == 0, "block_size must be even");
        Self { block_size }
    }

    /// Quantizes `tensor` into packed NF4 codes plus one absmax scale per block.
    ///
    /// Returns `(packed, absmaxs)`: `packed` holds two 4-bit codes per byte
    /// (the value at the even index goes in the low nibble), and `absmaxs[b]`
    /// is the scale for block `b`.
    ///
    /// # Errors
    /// - [`QuantError::EmptyInput`] if `tensor` is empty.
    /// - [`QuantError::GroupSizeMismatch`] if `tensor.len()` is not a
    ///   multiple of `self.block_size`.
    ///
    /// # Panics
    /// Panics if `self.block_size` is zero or odd. `block_size` is a public
    /// field, so the invariant from [`Self::new`] is re-checked here.
    pub fn encode(&self, tensor: &[f32]) -> QuantResult<(Vec<u8>, Vec<f32>)> {
        assert!(
            self.block_size > 0 && self.block_size % 2 == 0,
            "block_size must be positive and even"
        );
        if tensor.is_empty() {
            return Err(QuantError::EmptyInput("Nf4Quantizer::encode"));
        }
        if tensor.len() % self.block_size != 0 {
            return Err(QuantError::GroupSizeMismatch {
                len: tensor.len(),
                group: self.block_size,
            });
        }
        let n_blocks = tensor.len() / self.block_size;
        // Two codes per byte; tensor.len() is even because block_size is.
        let mut packed = vec![0u8; tensor.len() / 2];
        let mut absmaxs = Vec::with_capacity(n_blocks);
        for (blk_idx, block) in tensor.chunks_exact(self.block_size).enumerate() {
            // Per-block scale, floored at 1e-8 so all-zero blocks don't
            // divide by zero.
            let absmax = block
                .iter()
                .map(|&v| v.abs())
                .fold(0.0_f32, f32::max)
                .max(1e-8);
            absmaxs.push(absmax);
            let base_byte = blk_idx * self.block_size / 2;
            for (i, &v) in block.iter().enumerate() {
                let normed = (v / absmax).clamp(-1.0, 1.0);
                let code = nearest_nf4(normed) as u8;
                let byte_idx = base_byte + i / 2;
                if i % 2 == 0 {
                    packed[byte_idx] = code; // low nibble
                } else {
                    packed[byte_idx] |= code << 4; // high nibble
                }
            }
        }
        Ok((packed, absmaxs))
    }

    /// Reconstructs f32 values from packed NF4 codes and per-block scales.
    ///
    /// # Errors
    /// - [`QuantError::GroupSizeMismatch`] if `packed` does not decode to a
    ///   whole number of blocks. (Partial trailing blocks are rejected
    ///   rather than dropped, so no input data is silently lost.)
    /// - [`QuantError::DimensionMismatch`] if `absmaxs` does not contain
    ///   exactly one entry per block.
    ///
    /// # Panics
    /// Panics if `self.block_size` is zero or odd (public-field re-check,
    /// as in [`Self::encode`]).
    pub fn decode(&self, packed: &[u8], absmaxs: &[f32]) -> QuantResult<Vec<f32>> {
        assert!(
            self.block_size > 0 && self.block_size % 2 == 0,
            "block_size must be positive and even"
        );
        let n_floats = packed.len() * 2;
        // Reject partial blocks instead of letting the truncating division
        // below (and chunks_exact) silently discard trailing bytes.
        if n_floats % self.block_size != 0 {
            return Err(QuantError::GroupSizeMismatch {
                len: n_floats,
                group: self.block_size,
            });
        }
        let n_blocks_expected = n_floats / self.block_size;
        if absmaxs.len() != n_blocks_expected {
            return Err(QuantError::DimensionMismatch {
                expected: n_blocks_expected,
                got: absmaxs.len(),
            });
        }
        let mut out = Vec::with_capacity(n_floats);
        for (blk_idx, block_bytes) in packed.chunks_exact(self.block_size / 2).enumerate() {
            let absmax = absmaxs[blk_idx];
            for &byte in block_bytes {
                // Low nibble was stored first by `encode`.
                out.push(NF4_LUT[(byte & 0x0F) as usize] * absmax);
                out.push(NF4_LUT[(byte >> 4) as usize] * absmax);
            }
        }
        Ok(out)
    }

    /// Mean squared error of an encode → decode round trip over `tensor`.
    ///
    /// # Errors
    /// Propagates any error from [`Self::encode`] or [`Self::decode`].
    pub fn quantization_mse(&self, tensor: &[f32]) -> QuantResult<f32> {
        let (packed, absmaxs) = self.encode(tensor)?;
        let decoded = self.decode(&packed, &absmaxs)?;
        let sum_sq: f32 = tensor
            .iter()
            .zip(decoded.iter())
            .map(|(&a, &b)| (a - b).powi(2))
            .sum();
        Ok(sum_sq / tensor.len() as f32)
    }
}
/// Returns the index of the `NF4_LUT` entry closest to `v`.
///
/// Finds the first entry not less than `v` (the table is sorted ascending),
/// then compares it with its left neighbour; ties go to the lower index.
/// A NaN input compares false against every entry and so maps to index 0.
fn nearest_nf4(v: f32) -> usize {
    // Lower bound: index of the first LUT entry >= v.
    let upper = NF4_LUT.partition_point(|&x| x < v);
    if upper == 0 {
        return 0;
    }
    if upper == NF4_LUT.len() {
        return NF4_LUT.len() - 1;
    }
    let dist_below = (v - NF4_LUT[upper - 1]).abs();
    let dist_above = (NF4_LUT[upper] - v).abs();
    if dist_below <= dist_above {
        upper - 1
    } else {
        upper
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The codebook must be strictly increasing for the lower-bound search.
    #[test]
    fn lut_is_sorted_ascending() {
        for pair in NF4_LUT.windows(2) {
            assert!(
                pair[0] < pair[1],
                "LUT must be sorted: {} >= {}",
                pair[0],
                pair[1]
            );
        }
    }

    /// Anchor points of the codebook: extremes and the exact zero.
    #[test]
    fn lut_endpoints() {
        assert_abs_diff_eq!(NF4_LUT[0], -1.0, epsilon = 1e-9);
        assert_abs_diff_eq!(NF4_LUT[15], 1.0, epsilon = 1e-9);
        assert_abs_diff_eq!(NF4_LUT[7], 0.0, epsilon = 1e-9);
    }

    /// Exact LUT values must round-trip to their own indices.
    #[test]
    fn nearest_nf4_endpoints() {
        assert_eq!(nearest_nf4(-1.0), 0, "exactly -1 → index 0");
        assert_eq!(nearest_nf4(1.0), 15, "exactly 1 → index 15");
        assert_eq!(nearest_nf4(0.0), 7, "exactly 0 → index 7");
    }

    /// A value halfway between two entries must snap to one of them.
    #[test]
    fn nearest_nf4_midpoint() {
        let midpoint = 0.5 * (NF4_LUT[7] + NF4_LUT[8]);
        let idx = nearest_nf4(midpoint);
        assert!(matches!(idx, 7 | 8), "midpoint should map to 7 or 8");
    }

    /// Encode then decode a two-block ramp; the reconstruction error
    /// should be small for in-range inputs.
    #[test]
    fn encode_decode_round_trip() {
        let quantizer = Nf4Quantizer::new(64);
        let input: Vec<f32> = (0..128).map(|i| i as f32 / 64.0 - 1.0).collect();
        let (packed, absmaxs) = quantizer.encode(&input).unwrap();
        assert_eq!(packed.len(), 64);
        assert_eq!(absmaxs.len(), 2);
        let decoded = quantizer.decode(&packed, &absmaxs).unwrap();
        let sse: f32 = input
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        let mse = sse / 128.0;
        assert!(mse < 0.01, "MSE too large: {mse}");
    }

    /// An all-zero block must not divide by zero and must decode back to
    /// (near) zeros.
    #[test]
    fn all_zeros_encodes_cleanly() {
        let quantizer = Nf4Quantizer::default();
        let zeros = vec![0.0_f32; 64];
        let (packed, absmaxs) = quantizer.encode(&zeros).unwrap();
        assert_eq!(absmaxs.len(), 1);
        for value in quantizer.decode(&packed, &absmaxs).unwrap() {
            assert!(value.abs() < 1e-5, "decoded zero should be near zero");
        }
    }

    /// Round-trip MSE over a repeating full-range ramp stays within a
    /// loose 4-bit error budget.
    #[test]
    fn mse_within_nf4_theory() {
        let quantizer = Nf4Quantizer::new(64);
        let samples: Vec<f32> = (0..1024)
            .map(|i| 2.0 * ((i % 64) as f32 / 64.0) - 1.0)
            .collect();
        let mse = quantizer.quantization_mse(&samples).unwrap();
        assert!(mse < 0.05, "NF4 MSE unexpectedly large: {mse}");
    }

    /// A length that is not a multiple of the block size must be rejected.
    #[test]
    fn group_size_mismatch_error() {
        let quantizer = Nf4Quantizer::new(64);
        let input = vec![0.5_f32; 65];
        assert!(matches!(
            quantizer.encode(&input),
            Err(QuantError::GroupSizeMismatch { .. })
        ));
    }

    /// A scale vector of the wrong length must be rejected by decode.
    #[test]
    fn decode_length_mismatch_error() {
        let quantizer = Nf4Quantizer::new(64);
        let packed = vec![0u8; 32];
        let absmaxs = vec![1.0_f32; 5];
        assert!(matches!(
            quantizer.decode(&packed, &absmaxs),
            Err(QuantError::DimensionMismatch { .. })
        ));
    }
}