//! NF4 (4-bit NormalFloat) block-wise quantization scheme.

use crate::error::{QuantError, QuantResult};
25
/// The 16 NF4 (NormalFloat4) code points, sorted ascending.
///
/// These are the normal-distribution quantile values rescaled to `[-1, 1]`
/// introduced by the QLoRA paper (Dettmers et al., 2023); the constants
/// match the `bitsandbytes` reference implementation. Index 0 is -1.0,
/// index 7 is exactly 0.0, and index 15 is 1.0, so a 4-bit code indexes
/// directly into this table.
pub const NF4_LUT: [f32; 16] = [
    -1.0,
    -0.696_192_86,
    -0.525_073_05,
    -0.394_917_5,
    -0.284_441_38,
    -0.184_773_43,
    -0.091_050_03,
    0.0,
    0.079_580_3,
    0.160_930_2,
    0.246_112_3,
    0.337_915_24,
    0.440_709_83,
    0.562_617,
    0.722_956_84,
    1.0,
];
51
/// Block-wise NF4 quantizer.
///
/// Splits a tensor into contiguous blocks of `block_size` f32 elements,
/// stores one f32 absmax scale per block, and packs two 4-bit NF4 codes
/// into each output byte.
#[derive(Debug, Clone)]
pub struct Nf4Quantizer {
    // Number of f32 elements per quantization block.
    // NOTE(review): `encode` computes `block_size / 2` bytes per block, so
    // an odd value would make adjacent blocks overlap in the packed buffer;
    // callers should use even sizes (the default is 64).
    pub block_size: usize,
}
63
64impl Default for Nf4Quantizer {
65 fn default() -> Self {
66 Self { block_size: 64 }
67 }
68}
69
70impl Nf4Quantizer {
71 #[must_use]
77 pub fn new(block_size: usize) -> Self {
78 assert!(block_size > 0, "block_size must be > 0");
79 Self { block_size }
80 }
81
82 pub fn encode(&self, tensor: &[f32]) -> QuantResult<(Vec<u8>, Vec<f32>)> {
94 if tensor.is_empty() {
95 return Err(QuantError::EmptyInput("Nf4Quantizer::encode"));
96 }
97 if tensor.len() % self.block_size != 0 {
98 return Err(QuantError::GroupSizeMismatch {
99 len: tensor.len(),
100 group: self.block_size,
101 });
102 }
103 let n_blocks = tensor.len() / self.block_size;
104 let n_bytes = tensor.len() / 2; let mut packed = vec![0u8; n_bytes];
106 let mut absmaxs = Vec::with_capacity(n_blocks);
107
108 for (blk_idx, block) in tensor.chunks_exact(self.block_size).enumerate() {
109 let absmax = block.iter().map(|&v| v.abs()).fold(0.0_f32, f32::max);
111 let absmax = if absmax < 1e-8 { 1e-8 } else { absmax };
112 absmaxs.push(absmax);
113
114 let base_byte = blk_idx * self.block_size / 2;
116 for (i, &v) in block.iter().enumerate() {
117 let normed = (v / absmax).clamp(-1.0, 1.0);
118 let code = nearest_nf4(normed) as u8;
119 let byte_idx = base_byte + i / 2;
120 if i % 2 == 0 {
121 packed[byte_idx] = code; } else {
123 packed[byte_idx] |= code << 4; }
125 }
126 }
127 Ok((packed, absmaxs))
128 }
129
130 pub fn decode(&self, packed: &[u8], absmaxs: &[f32]) -> QuantResult<Vec<f32>> {
136 let n_floats = packed.len() * 2;
137 let n_blocks_expected = n_floats / self.block_size;
138 if absmaxs.len() != n_blocks_expected {
139 return Err(QuantError::DimensionMismatch {
140 expected: n_blocks_expected,
141 got: absmaxs.len(),
142 });
143 }
144 let mut out = Vec::with_capacity(n_floats);
145 for (blk_idx, block_bytes) in packed.chunks_exact(self.block_size / 2).enumerate() {
146 let absmax = absmaxs[blk_idx];
147 for &byte in block_bytes {
148 let lo = (byte & 0x0F) as usize;
149 let hi = (byte >> 4) as usize;
150 out.push(NF4_LUT[lo] * absmax);
151 out.push(NF4_LUT[hi] * absmax);
152 }
153 }
154 Ok(out)
155 }
156
157 pub fn quantization_mse(&self, tensor: &[f32]) -> QuantResult<f32> {
163 let (packed, absmaxs) = self.encode(tensor)?;
164 let decoded = self.decode(&packed, &absmaxs)?;
165 let mse = tensor
166 .iter()
167 .zip(decoded.iter())
168 .map(|(&a, &b)| (a - b).powi(2))
169 .sum::<f32>()
170 / tensor.len() as f32;
171 Ok(mse)
172 }
173}
174
175fn nearest_nf4(v: f32) -> usize {
182 let mut lo = 0_usize;
184 let mut hi = NF4_LUT.len();
185 while lo < hi {
186 let mid = lo + (hi - lo) / 2;
187 if NF4_LUT[mid] < v {
188 lo = mid + 1;
189 } else {
190 hi = mid;
191 }
192 }
193 if lo == 0 {
195 return 0;
196 }
197 if lo == NF4_LUT.len() {
198 return NF4_LUT.len() - 1;
199 }
200 let d_lo = (v - NF4_LUT[lo - 1]).abs();
202 let d_hi = (NF4_LUT[lo] - v).abs();
203 if d_lo <= d_hi { lo - 1 } else { lo }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // The binary search in `nearest_nf4` requires a strictly ascending LUT.
    #[test]
    fn lut_is_sorted_ascending() {
        for w in NF4_LUT.windows(2) {
            assert!(w[0] < w[1], "LUT must be sorted: {} >= {}", w[0], w[1]);
        }
    }

    // The code book must span exactly [-1, 1] with an exact zero at index 7.
    #[test]
    fn lut_endpoints() {
        assert_abs_diff_eq!(NF4_LUT[0], -1.0, epsilon = 1e-9);
        assert_abs_diff_eq!(NF4_LUT[15], 1.0, epsilon = 1e-9);
        assert_abs_diff_eq!(NF4_LUT[7], 0.0, epsilon = 1e-9);
    }

    // Exact LUT values must map to their own indices.
    #[test]
    fn nearest_nf4_endpoints() {
        assert_eq!(nearest_nf4(-1.0), 0, "exactly -1 → index 0");
        assert_eq!(nearest_nf4(1.0), 15, "exactly 1 → index 15");
        assert_eq!(nearest_nf4(0.0), 7, "exactly 0 → index 7");
    }

    // A value exactly between two code points may round either way; both
    // neighbors are acceptable.
    #[test]
    fn nearest_nf4_midpoint() {
        let mid = (NF4_LUT[7] + NF4_LUT[8]) / 2.0;
        let idx = nearest_nf4(mid);
        assert!(idx == 7 || idx == 8, "midpoint should map to 7 or 8");
    }

    // 128 values, block 64 → 64 packed bytes (2 codes/byte) and 2 scales;
    // the round trip should reconstruct within NF4's quantization error.
    #[test]
    fn encode_decode_round_trip() {
        let q = Nf4Quantizer::new(64);
        let t: Vec<f32> = (0..128).map(|i| (i as f32 / 64.0) - 1.0).collect();
        let (packed, absmaxs) = q.encode(&t).unwrap();
        assert_eq!(packed.len(), 64);
        assert_eq!(absmaxs.len(), 2);
        let decoded = q.decode(&packed, &absmaxs).unwrap();
        let mse = t
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            / 128.0;
        assert!(mse < 0.01, "MSE too large: {mse}");
    }

    // An all-zero block exercises the absmax floor (no divide-by-zero) and
    // must decode back to (near-)zero values.
    #[test]
    fn all_zeros_encodes_cleanly() {
        let q = Nf4Quantizer::default();
        let t = vec![0.0_f32; 64];
        let (packed, absmaxs) = q.encode(&t).unwrap();
        assert_eq!(absmaxs.len(), 1);
        let decoded = q.decode(&packed, &absmaxs).unwrap();
        for v in decoded {
            assert!(v.abs() < 1e-5, "decoded zero should be near zero");
        }
    }

    // Uniform data over [-1, 1]; the MSE bound here is a loose sanity
    // ceiling, not a tight theoretical one.
    #[test]
    fn mse_within_nf4_theory() {
        let q = Nf4Quantizer::new(64);
        let t: Vec<f32> = (0..1024)
            .map(|i| {
                let u = (i % 64) as f32 / 64.0;
                2.0 * u - 1.0
            })
            .collect();
        let mse = q.quantization_mse(&t).unwrap();
        assert!(mse < 0.05, "NF4 MSE unexpectedly large: {mse}");
    }

    // 65 elements is not a multiple of the 64-element block size.
    #[test]
    fn group_size_mismatch_error() {
        let q = Nf4Quantizer::new(64);
        let t = vec![0.5_f32; 65];
        assert!(matches!(
            q.encode(&t),
            Err(QuantError::GroupSizeMismatch { .. })
        ));
    }

    // 32 bytes → 64 floats → exactly 1 block; 5 scales must be rejected.
    #[test]
    fn decode_length_mismatch_error() {
        let q = Nf4Quantizer::new(64);
        let packed = vec![0u8; 32];
        let absmaxs = vec![1.0_f32; 5];
        assert!(matches!(
            q.decode(&packed, &absmaxs),
            Err(QuantError::DimensionMismatch { .. })
        ));
    }
}