// oxicuda_quant/scheme/nf4.rs
1//! # NF4 — NormalFloat4 Quantization
2//!
3//! NF4 (Dettmers et al., 2023 — "QLoRA: Efficient Finetuning of Quantized LLMs")
4//! is a data type that is **information-theoretically optimal** for normally
5//! distributed weights.  It stores 4-bit indices into a 16-entry lookup table
6//! whose values are the quantiles of N(0, 1) at equal probability mass points.
7//!
8//! ## Encoding
9//!
10//! ```text
11//! absmax = max(|W|)
12//! W_norm = W / absmax            ∈ [-1, 1]
13//! code   = argmin_{v ∈ LUT} |W_norm - v|   (nearest-neighbour in LUT)
14//! ```
15//!
16//! Two codes are packed per byte (lo nibble = first element).
17//!
18//! ## Decoding
19//!
20//! ```text
21//! W_approx = LUT[code] * absmax
22//! ```
23
24use crate::error::{QuantError, QuantResult};
25
26// ─── NF4 Lookup Table ────────────────────────────────────────────────────────
27
/// The 16 NF4 quantization levels (sorted ascending).
///
/// These are the quantiles of the standard normal distribution at probability
/// mass points `{0.5/16, 1.5/16, ..., 15.5/16}`, scaled so that the extreme
/// values are exactly ±1. The table is asymmetric around zero by design:
/// index 7 is exactly 0.0 so that zero-valued weights round-trip losslessly.
///
/// Encoders store a 4-bit index into this table; decoders multiply the looked-up
/// value by the block's `absmax`. Ordering matters: `nearest_nf4` relies on the
/// ascending sort for its binary search.
pub const NF4_LUT: [f32; 16] = [
    -1.0,
    -0.696_192_86,
    -0.525_073_05,
    -0.394_917_5,
    -0.284_441_38,
    -0.184_773_43,
    -0.091_050_03,
    0.0,
    0.079_580_3,
    0.160_930_2,
    0.246_112_3,
    0.337_915_24,
    0.440_709_83,
    0.562_617,
    0.722_956_84,
    1.0,
];
51
52// ─── Nf4Quantizer ─────────────────────────────────────────────────────────────
53
54/// NF4 quantizer — encodes tensors to packed 4-bit NF4 codes.
55///
56/// Blocks of `block_size` elements share an `absmax` scaling factor.
57/// The default `block_size` of 64 matches the QLoRA paper.
58#[derive(Debug, Clone)]
59pub struct Nf4Quantizer {
60    /// Number of elements per absmax scaling block.
61    pub block_size: usize,
62}
63
64impl Default for Nf4Quantizer {
65    fn default() -> Self {
66        Self { block_size: 64 }
67    }
68}
69
70impl Nf4Quantizer {
71    /// Create an NF4 quantizer with the given block size.
72    ///
73    /// # Panics
74    ///
75    /// Panics if `block_size == 0`.
76    #[must_use]
77    pub fn new(block_size: usize) -> Self {
78        assert!(block_size > 0, "block_size must be > 0");
79        Self { block_size }
80    }
81
82    /// Encode a flat tensor to packed NF4 bytes and per-block absmax values.
83    ///
84    /// The number of elements must be a multiple of `block_size`, and
85    /// `block_size` must be even (pairs of nibbles pack into bytes).
86    ///
87    /// Returns `(packed_bytes, absmax_per_block)`.
88    ///
89    /// # Errors
90    ///
91    /// * [`QuantError::GroupSizeMismatch`] — if `len` is not divisible by `block_size`.
92    /// * [`QuantError::EmptyInput`] — if `tensor` is empty.
93    pub fn encode(&self, tensor: &[f32]) -> QuantResult<(Vec<u8>, Vec<f32>)> {
94        if tensor.is_empty() {
95            return Err(QuantError::EmptyInput("Nf4Quantizer::encode"));
96        }
97        if tensor.len() % self.block_size != 0 {
98            return Err(QuantError::GroupSizeMismatch {
99                len: tensor.len(),
100                group: self.block_size,
101            });
102        }
103        let n_blocks = tensor.len() / self.block_size;
104        let n_bytes = tensor.len() / 2; // 2 codes per byte
105        let mut packed = vec![0u8; n_bytes];
106        let mut absmaxs = Vec::with_capacity(n_blocks);
107
108        for (blk_idx, block) in tensor.chunks_exact(self.block_size).enumerate() {
109            // Compute absmax for this block.
110            let absmax = block.iter().map(|&v| v.abs()).fold(0.0_f32, f32::max);
111            let absmax = if absmax < 1e-8 { 1e-8 } else { absmax };
112            absmaxs.push(absmax);
113
114            // Encode each element.
115            let base_byte = blk_idx * self.block_size / 2;
116            for (i, &v) in block.iter().enumerate() {
117                let normed = (v / absmax).clamp(-1.0, 1.0);
118                let code = nearest_nf4(normed) as u8;
119                let byte_idx = base_byte + i / 2;
120                if i % 2 == 0 {
121                    packed[byte_idx] = code; // lo nibble
122                } else {
123                    packed[byte_idx] |= code << 4; // hi nibble
124                }
125            }
126        }
127        Ok((packed, absmaxs))
128    }
129
130    /// Decode packed NF4 bytes back to f32 using the stored absmax values.
131    ///
132    /// # Errors
133    ///
134    /// * [`QuantError::DimensionMismatch`] — if packed/absmax lengths are inconsistent.
135    pub fn decode(&self, packed: &[u8], absmaxs: &[f32]) -> QuantResult<Vec<f32>> {
136        let n_floats = packed.len() * 2;
137        let n_blocks_expected = n_floats / self.block_size;
138        if absmaxs.len() != n_blocks_expected {
139            return Err(QuantError::DimensionMismatch {
140                expected: n_blocks_expected,
141                got: absmaxs.len(),
142            });
143        }
144        let mut out = Vec::with_capacity(n_floats);
145        for (blk_idx, block_bytes) in packed.chunks_exact(self.block_size / 2).enumerate() {
146            let absmax = absmaxs[blk_idx];
147            for &byte in block_bytes {
148                let lo = (byte & 0x0F) as usize;
149                let hi = (byte >> 4) as usize;
150                out.push(NF4_LUT[lo] * absmax);
151                out.push(NF4_LUT[hi] * absmax);
152            }
153        }
154        Ok(out)
155    }
156
157    /// Estimate the quantization error (mean squared error) for a tensor.
158    ///
159    /// # Errors
160    ///
161    /// Propagates errors from [`encode`](Self::encode) / [`decode`](Self::decode).
162    pub fn quantization_mse(&self, tensor: &[f32]) -> QuantResult<f32> {
163        let (packed, absmaxs) = self.encode(tensor)?;
164        let decoded = self.decode(&packed, &absmaxs)?;
165        let mse = tensor
166            .iter()
167            .zip(decoded.iter())
168            .map(|(&a, &b)| (a - b).powi(2))
169            .sum::<f32>()
170            / tensor.len() as f32;
171        Ok(mse)
172    }
173}
174
175// ─── Helpers ─────────────────────────────────────────────────────────────────
176
177/// Find the index of the nearest NF4 level using binary search.
178///
179/// Since `NF4_LUT` is sorted, we use the standard partition-point approach
180/// and compare both neighbours.
181fn nearest_nf4(v: f32) -> usize {
182    // Binary search for the insertion point.
183    let mut lo = 0_usize;
184    let mut hi = NF4_LUT.len();
185    while lo < hi {
186        let mid = lo + (hi - lo) / 2;
187        if NF4_LUT[mid] < v {
188            lo = mid + 1;
189        } else {
190            hi = mid;
191        }
192    }
193    // `lo` is the first index where NF4_LUT[lo] >= v.
194    if lo == 0 {
195        return 0;
196    }
197    if lo == NF4_LUT.len() {
198        return NF4_LUT.len() - 1;
199    }
200    // Compare the two neighbours.
201    let d_lo = (v - NF4_LUT[lo - 1]).abs();
202    let d_hi = (NF4_LUT[lo] - v).abs();
203    if d_lo <= d_hi { lo - 1 } else { lo }
204}
205
206// ─── Tests ───────────────────────────────────────────────────────────────────
207
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // The LUT must be strictly increasing or `nearest_nf4`'s search breaks.
    #[test]
    fn lut_is_sorted_ascending() {
        for w in NF4_LUT.windows(2) {
            assert!(w[0] < w[1], "LUT must be sorted: {} >= {}", w[0], w[1]);
        }
    }

    // Extremes are pinned to ±1 and the middle entry is exactly zero.
    #[test]
    fn lut_endpoints() {
        assert_abs_diff_eq!(NF4_LUT[0], -1.0, epsilon = 1e-9);
        assert_abs_diff_eq!(NF4_LUT[15], 1.0, epsilon = 1e-9);
        assert_abs_diff_eq!(NF4_LUT[7], 0.0, epsilon = 1e-9);
    }

    #[test]
    fn nearest_nf4_endpoints() {
        assert_eq!(nearest_nf4(-1.0), 0, "exactly -1 → index 0");
        assert_eq!(nearest_nf4(1.0), 15, "exactly 1 → index 15");
        assert_eq!(nearest_nf4(0.0), 7, "exactly 0 → index 7");
    }

    #[test]
    fn nearest_nf4_midpoint() {
        // Between LUT[7]=0 and LUT[8]=0.0796: midpoint ≈ 0.0398.
        // Either neighbour is acceptable; we only care that rounding stays
        // within the bracketing pair.
        let mid = (NF4_LUT[7] + NF4_LUT[8]) / 2.0;
        let idx = nearest_nf4(mid);
        assert!(idx == 7 || idx == 8, "midpoint should map to 7 or 8");
    }

    #[test]
    fn encode_decode_round_trip() {
        let q = Nf4Quantizer::new(64);
        // 128 elements = 2 blocks of 64; a linear ramp over [-1, 1).
        let t: Vec<f32> = (0..128).map(|i| (i as f32 / 64.0) - 1.0).collect();
        let (packed, absmaxs) = q.encode(&t).unwrap();
        assert_eq!(packed.len(), 64);
        assert_eq!(absmaxs.len(), 2);
        let decoded = q.decode(&packed, &absmaxs).unwrap();
        // NF4 is lossy; error should be small but non-zero.
        let mse = t
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            / 128.0;
        assert!(mse < 0.01, "MSE too large: {mse}");
    }

    #[test]
    fn all_zeros_encodes_cleanly() {
        let q = Nf4Quantizer::default();
        let t = vec![0.0_f32; 64];
        let (packed, absmaxs) = q.encode(&t).unwrap();
        // With all-zero input, absmax = 1e-8, codes all map to index 7 (value 0).
        assert_eq!(absmaxs.len(), 1);
        let decoded = q.decode(&packed, &absmaxs).unwrap();
        for v in decoded {
            assert!(v.abs() < 1e-5, "decoded zero should be near zero");
        }
    }

    #[test]
    fn mse_within_nf4_theory() {
        // NF4 gives very small relative MSE for N(0,1) data; this test uses a
        // deterministic sample and only checks a loose upper bound.
        let q = Nf4Quantizer::new(64);
        // Deterministic repeating ramp over [-1, 1) — uniform, not Gaussian,
        // which is why the bound below is deliberately slack.
        let t: Vec<f32> = (0..1024)
            .map(|i| {
                let u = (i % 64) as f32 / 64.0;
                2.0 * u - 1.0
            })
            .collect();
        let mse = q.quantization_mse(&t).unwrap();
        assert!(mse < 0.05, "NF4 MSE unexpectedly large: {mse}");
    }

    #[test]
    fn group_size_mismatch_error() {
        let q = Nf4Quantizer::new(64);
        let t = vec![0.5_f32; 65]; // 65 % 64 != 0
        assert!(matches!(
            q.encode(&t),
            Err(QuantError::GroupSizeMismatch { .. })
        ));
    }

    #[test]
    fn decode_length_mismatch_error() {
        let q = Nf4Quantizer::new(64);
        // 32 bytes = exactly one 64-element block, so one absmax is expected.
        let packed = vec![0u8; 32];
        let absmaxs = vec![1.0_f32; 5]; // wrong: expected 1
        assert!(matches!(
            q.decode(&packed, &absmaxs),
            Err(QuantError::DimensionMismatch { .. })
        ));
    }
}