/// The 16 NF4 (4-bit NormalFloat) reconstruction levels, indexed by 4-bit code.
///
/// Levels are sorted strictly ascending, span `[-1.0, 1.0]`, and include an
/// exact `0.0` at code 7. A block-wise quantized value is reconstructed as
/// `NF4_LUT[code] * absmax`.
// The literals are kept verbatim from the reference table (written with f64
// precision), so clippy's precision lint is silenced rather than rounding them.
#[allow(clippy::excessive_precision)]
pub const NF4_LUT: [f32; 16] = [
    // Codes 0..=6: negative levels.
    -1.0,
    -0.6961928009986877,
    -0.5250730514526367,
    -0.39491748809814453,
    -0.28444138169288635,
    -0.18477343022823334,
    -0.09105003625154495,
    // Code 7: exact zero.
    0.0,
    // Codes 8..=15: positive levels.
    0.07958029955625534,
    0.16093020141124725,
    0.24611230194568634,
    0.33791524171829224,
    0.44070982933044434,
    0.5626170039176941,
    0.7229568362236023,
    1.0,
];
/// Decodes a single 4-bit NF4 code into its normalised value in `[-1.0, 1.0]`.
///
/// Only the low nibble of `val` is significant; the upper four bits are
/// discarded, so every `u8` is accepted without panicking.
#[inline]
pub fn dequantize_nf4(val: u8) -> f32 {
    let code = usize::from(val % 16);
    NF4_LUT[code]
}
/// Encodes a normalised value as the nearest 4-bit NF4 code (`0..=15`).
///
/// The decision boundaries are the midpoints between adjacent [`NF4_LUT`]
/// levels. A value exactly on a boundary rounds to the lower code. Inputs
/// above `1.0` saturate to code 15 and inputs below `-1.0` to code 0. Every
/// comparison against NaN is false, so NaN encodes as code 0.
#[inline]
pub fn quantize_nf4(x: f32) -> u8 {
    // Midpoints between consecutive NF4 levels, ascending: entry `i` separates
    // code `i` from code `i + 1`.
    const MIDPOINTS: [f32; 15] = [
        -0.848_096_4,
        -0.610_632_93,
        -0.459_995_27,
        -0.339_679_43,
        -0.234_607_41,
        -0.137_911_73,
        -0.045_525_018,
        0.039_790_15,
        0.120_255_25,
        0.203_521_25,
        0.292_013_77,
        0.389_312_54,
        0.501_663_4,
        0.642_786_9,
        0.861_478_4,
    ];
    // Because the boundaries are sorted, the code is simply the number of
    // midpoints lying strictly below `x`. The count is at most 15, so the
    // narrowing cast is lossless.
    MIDPOINTS.iter().take_while(|&&m| x > m).count() as u8
}
pub fn dequantize_blockwise(packed: &[u8], absmax: &[f32], blocksize: usize, output: &mut [f32]) {
assert_eq!(output.len(), packed.len() * 2, "output must be 2× packed length");
let half_block = blocksize / 2;
for (byte_idx, &byte) in packed.iter().enumerate() {
let elem_idx = byte_idx * 2;
let block_idx = elem_idx / blocksize;
let scale = absmax[block_idx];
let high = (byte >> 4) & 0x0F;
let low = byte & 0x0F;
output[elem_idx] = NF4_LUT[high as usize] * scale;
if elem_idx + 1 < output.len() {
output[elem_idx + 1] = NF4_LUT[low as usize] * scale;
}
}
}
pub fn quantize_blockwise(input: &[f32], blocksize: usize, packed: &mut [u8], absmax: &mut [f32]) {
assert_eq!(packed.len(), (input.len() + 1) / 2, "packed must be ceil(input/2)");
let num_blocks = (input.len() + blocksize - 1) / blocksize;
assert!(absmax.len() >= num_blocks, "absmax too small");
for block in 0..num_blocks {
let start = block * blocksize;
let end = (start + blocksize).min(input.len());
let mut max_val: f32 = 0.0;
for &v in &input[start..end] {
let abs = v.abs();
if abs > max_val {
max_val = abs;
}
}
absmax[block] = max_val;
}
for byte_idx in 0..packed.len() {
let elem_idx = byte_idx * 2;
let block_idx = elem_idx / blocksize;
let scale = absmax[block_idx];
let inv_scale = if scale > 0.0 { 1.0 / scale } else { 0.0 };
let high = quantize_nf4(input[elem_idx] * inv_scale);
let low = if elem_idx + 1 < input.len() {
let block_idx_low = (elem_idx + 1) / blocksize;
let inv_scale_low =
if absmax[block_idx_low] > 0.0 { 1.0 / absmax[block_idx_low] } else { 0.0 };
quantize_nf4(input[elem_idx + 1] * inv_scale_low)
} else {
0
};
packed[byte_idx] = (high << 4) | low;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Half of the widest gap between adjacent NF4 levels: the largest possible
    /// distance from any value in `[-1, 1]` to its nearest level.
    fn max_half_gap() -> f32 {
        NF4_LUT
            .windows(2)
            .map(|w| (w[1] - w[0]) / 2.0)
            .fold(0.0, f32::max)
    }

    #[test]
    fn test_nf4_lut_monotonic() {
        for (i, pair) in NF4_LUT.windows(2).enumerate() {
            assert!(
                pair[0] < pair[1],
                "NF4_LUT[{}]={} >= NF4_LUT[{}]={}",
                i,
                pair[0],
                i + 1,
                pair[1]
            );
        }
    }

    #[test]
    fn test_nf4_lut_boundaries() {
        assert_eq!(NF4_LUT[0], -1.0);
        assert_eq!(NF4_LUT[7], 0.0);
        assert_eq!(NF4_LUT[15], 1.0);
    }

    #[test]
    fn test_nf4_roundtrip_exhaustive() {
        for code in 0u8..16 {
            let dequantized = dequantize_nf4(code);
            let requantized = quantize_nf4(dequantized);
            assert_eq!(
                code, requantized,
                "Roundtrip failed: code={} → dequant={} → requant={}",
                code, dequantized, requantized
            );
        }
    }

    #[test]
    fn test_nf4_nibble_order() {
        // `dequantize_nf4` must ignore the high nibble of its argument.
        assert_eq!(dequantize_nf4(0xAB >> 4), NF4_LUT[10]);
        assert_eq!(dequantize_nf4(0xAB), NF4_LUT[11]);

        // The block decoder maps the high nibble to the even element and the
        // low nibble to the odd element.
        let mut output = [0.0f32; 2];
        dequantize_blockwise(&[0xAB], &[1.0], 2, &mut output);
        assert_eq!(output, [NF4_LUT[10], NF4_LUT[11]]);

        // The block encoder packs in the same order.
        let mut packed = [0u8; 1];
        let mut absmax = [0.0f32; 1];
        quantize_blockwise(&[1.0, -1.0], 2, &mut packed, &mut absmax);
        assert_eq!(absmax, [1.0]);
        assert_eq!(packed, [0xF0]);
    }

    #[test]
    fn test_nf4_blockwise_roundtrip() {
        let blocksize = 64;
        let n = 256;
        let input: Vec<f32> = (0..n).map(|i| (i as f32 / n as f32) * 2.0 - 1.0).collect();
        let mut packed = vec![0u8; n / 2];
        let mut absmax = vec![0.0f32; n / blocksize];
        quantize_blockwise(&input, blocksize, &mut packed, &mut absmax);
        let mut output = vec![0.0f32; n];
        dequantize_blockwise(&packed, &absmax, blocksize, &mut output);

        // Nearest-level rounding is off by at most half the widest LUT gap,
        // scaled by the block's absmax (plus a little float slack).
        let bound = max_half_gap();
        for (i, (&orig, &deq)) in input.iter().zip(&output).enumerate() {
            let block = i / blocksize;
            let max_error = bound * absmax[block];
            let error = (orig - deq).abs();
            assert!(
                error <= max_error + 1e-6,
                "Block {} elem {}: error {} > max_error {}",
                block,
                i,
                error,
                max_error
            );
        }
    }

    #[test]
    fn test_quantize_nf4_range() {
        // Sweep [-1, 1): every code must be valid and must never decrease.
        let mut prev = 0u8;
        for i in 0..1000 {
            let x = (i as f32 / 500.0) - 1.0;
            let code = quantize_nf4(x);
            assert!(code < 16, "quantize_nf4({}) = {} >= 16", x, code);
            assert!(
                code >= prev,
                "quantize_nf4 not monotonic at x={}: {} < {}",
                x,
                code,
                prev
            );
            prev = code;
        }
        assert_eq!(quantize_nf4(-1.0), 0);
        assert_eq!(prev, 15);
    }
}