Skip to main content

entrenar/quant/
quant4bit.rs

1//! 4-bit quantization for QLoRA
2//!
//! Implements block-wise symmetric 4-bit quantization to reduce memory usage
//! of frozen base weights by ~86% (4 bits vs 32 bits per value, plus one f32
//! scale per 64-element block: 36 bytes instead of 256).
5//!
6//! Uses block-wise quantization with 64-element blocks, where each block has:
7//! - 1 scale factor (f32)
8//! - 64 quantized values (4 bits each = 32 bytes total)
9//!
10//! Quantization: q = round(clamp(x / scale, -7, 7))
11//! Dequantization: x ≈ q * scale
12
13use serde::{Deserialize, Serialize};
14
/// Block size for quantization (64 elements per block)
///
/// Every block shares a single scale factor. The packing code stores two
/// 4-bit values per byte and locates a value's byte from its element index,
/// which only lines up across block boundaries because this value is even.
pub const BLOCK_SIZE: usize = 64;
17
/// 4-bit quantized representation with block-wise scale factors
///
/// Memory layout:
/// - scales: `Vec<f32>` with length = ceil(n / BLOCK_SIZE)
/// - data: `Vec<u8>` where each byte stores 2 quantized values (4 bits each);
///   the even-indexed value occupies the upper nibble and the following
///   odd-indexed value the lower nibble, each as a 4-bit two's-complement
///   integer
///
/// All fields are public, so an instance built by hand or obtained through
/// deserialization is not guaranteed to satisfy these length relationships.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Quantized4Bit {
    /// Scale factors (one per block of up to `BLOCK_SIZE` elements)
    pub scales: Vec<f32>,
    /// Quantized data: 2 values per byte (4 bits each); when the element
    /// count is odd, the lower nibble of the final byte is left as zero
    pub data: Vec<u8>,
    /// Original number of elements (before packing)
    pub len: usize,
}
32
33impl Quantized4Bit {
34    /// Get memory usage in bytes
35    pub fn memory_bytes(&self) -> usize {
36        self.scales.len() * 4 + self.data.len()
37    }
38
39    /// Get compression ratio vs f32
40    #[provable_contracts_macros::contract("quantization-v1", equation = "compression_ratio")]
41    pub fn compression_ratio(&self) -> f32 {
42        let original_bytes = self.len * 4; // f32
43        let compressed_bytes = self.memory_bytes();
44        original_bytes as f32 / compressed_bytes as f32
45    }
46}
47
48/// Quantize f32 values to 4-bit with block-wise scaling
49///
50/// # Arguments
51/// * `values` - Input f32 values
52///
53/// # Returns
54/// Quantized4Bit structure with scales and packed 4-bit data
55pub fn quantize_4bit(values: &[f32]) -> Quantized4Bit {
56    let len = values.len();
57    let num_blocks = len.div_ceil(BLOCK_SIZE);
58
59    let mut scales = Vec::with_capacity(num_blocks);
60    let mut data = Vec::with_capacity(len.div_ceil(2)); // 2 values per byte
61
62    for block_idx in 0..num_blocks {
63        let start = block_idx * BLOCK_SIZE;
64        let end = (start + BLOCK_SIZE).min(len);
65        let block = &values[start..end];
66
67        // Compute scale factor for this block: max absolute value / 7
68        // (7 is max representable in 4-bit signed: -7 to 7)
69        let max_abs = block.iter().map(|v| v.abs()).max_by(f32::total_cmp).unwrap_or(1e-8);
70
71        let scale = max_abs / 7.0;
72        scales.push(scale);
73
74        // Quantize each value in the block
75        for (i, &val) in block.iter().enumerate() {
76            let quantized = quantize_value(val, scale);
77
78            // Pack 2 values per byte (mask to 4 bits: 0x0F)
79            if i.is_multiple_of(2) {
80                // First value: store in upper 4 bits
81                data.push(((quantized as u8) & 0x0F) << 4);
82            } else {
83                // Second value: store in lower 4 bits
84                let last_idx = data.len() - 1;
85                data[last_idx] |= (quantized as u8) & 0x0F;
86            }
87        }
88
89        // If block has odd number of elements, the last byte is already pushed
90        // with upper 4 bits filled and lower 4 bits as 0
91    }
92
93    Quantized4Bit { scales, data, len }
94}
95
96/// Dequantize 4-bit values back to f32
97///
98/// # Arguments
99/// * `quantized` - Quantized4Bit data
100///
101/// # Returns
102/// Dequantized f32 values
103pub fn dequantize_4bit(quantized: &Quantized4Bit) -> Vec<f32> {
104    let mut result = Vec::with_capacity(quantized.len);
105
106    let num_blocks = quantized.scales.len();
107
108    for block_idx in 0..num_blocks {
109        let scale = quantized.scales[block_idx];
110        let start = block_idx * BLOCK_SIZE;
111        let end = (start + BLOCK_SIZE).min(quantized.len);
112        let block_len = end - start;
113
114        for i in 0..block_len {
115            let byte_idx = usize::midpoint(start, i);
116            let byte = quantized.data[byte_idx];
117
118            // Extract 4-bit value and sign-extend
119            let q_val = if (start + i).is_multiple_of(2) {
120                // Upper 4 bits - shift right to get value, then sign extend
121                let nibble = (byte >> 4) & 0x0F;
122                // Sign extend: if bit 3 is set, extend with 1s
123                if nibble & 0x08 != 0 {
124                    (nibble | 0xF0) as i8
125                } else {
126                    nibble as i8
127                }
128            } else {
129                // Lower 4 bits
130                let nibble = byte & 0x0F;
131                // Sign extend: if bit 3 is set, extend with 1s
132                if nibble & 0x08 != 0 {
133                    (nibble | 0xF0) as i8
134                } else {
135                    nibble as i8
136                }
137            };
138
139            // Dequantize
140            let deq_val = f32::from(q_val) * scale;
141            result.push(deq_val);
142        }
143    }
144
145    result
146}
147
/// Map one f32 onto the signed 4-bit grid
///
/// Divides `val` by `scale`, limits the quotient to [-7, 7] and rounds it to
/// the nearest integer (ties away from zero). A NaN quotient, whether from a
/// NaN input or from `0.0 / 0.0`, yields 0, which is exactly what the
/// saturating float-to-int cast would produce.
fn quantize_value(val: f32, scale: f32) -> i8 {
    let steps = val / scale;
    if steps.is_nan() {
        return 0;
    }
    steps.clamp(-7.0, 7.0).round() as i8
}
156
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Asserts that every restored value lies within half a quantization step
    /// of its original, using the scale of the block the element belongs to.
    /// A small slack absorbs f32 rounding in the divide/multiply round trip.
    fn assert_within_half_step(original: &[f32], quantized: &Quantized4Bit, restored: &[f32]) {
        assert_eq!(restored.len(), original.len());
        for (idx, (orig, deq)) in original.iter().zip(restored.iter()).enumerate() {
            let scale = quantized.scales[idx / BLOCK_SIZE];
            let error = (orig - deq).abs();
            let bound = scale * 0.5 * 1.001 + 1e-6;
            assert!(
                error <= bound,
                "Element {idx}: {orig} vs {deq} (error: {error}, bound: {bound})"
            );
        }
    }

    #[test]
    fn test_quantize_dequantize_round_trip() {
        let values = vec![1.0, -2.0, 3.5, -4.2, 0.5, -0.8, 2.1, -1.5];
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        assert_eq!(dequantized.len(), values.len());

        // Check approximate equality (quantization introduces some error)
        // 4-bit quantization has limited precision (15 values from -7 to 7)
        for (original, deq) in values.iter().zip(dequantized.iter()) {
            let error = (original - deq).abs();
            let relative_error = error / original.abs().max(1e-6);
            assert!(
                relative_error < 0.3,
                "Relative error too large: {original} vs {deq} (error: {error}, rel_error: {relative_error})"
            );
        }
    }

    #[test]
    fn test_quantize_zeros() {
        let values = vec![0.0; 64];
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        for val in dequantized {
            assert_abs_diff_eq!(val, 0.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_quantize_uniform() {
        let values = vec![1.0; 64];
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        for val in dequantized {
            assert_abs_diff_eq!(val, 1.0, epsilon = 0.2);
        }
    }

    #[test]
    fn test_quantize_range() {
        // Test values spanning the quantization range
        let values: Vec<f32> = (-7..=7).map(|x| x as f32).collect();
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        for (original, deq) in values.iter().zip(dequantized.iter()) {
            assert_abs_diff_eq!(original, deq, epsilon = 0.5);
        }
    }

    #[test]
    fn test_quantize_multiple_blocks() {
        // Test with more than one block (>64 elements)
        let values: Vec<f32> = (0..200).map(|i| (i as f32 * 0.1).sin()).collect();
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        assert_eq!(dequantized.len(), values.len());

        // Verify multiple blocks were created
        let expected_blocks = 200_usize.div_ceil(BLOCK_SIZE);
        assert_eq!(quantized.scales.len(), expected_blocks);

        // Values must also survive the round trip in every block, not just
        // come back with the right count.
        assert_within_half_step(&values, &quantized, &dequantized);
    }

    #[test]
    fn test_memory_savings() {
        let values = vec![1.0; 1024];
        let quantized = quantize_4bit(&values);

        let original_bytes = values.len() * 4; // f32 = 4 bytes
        let compressed_bytes = quantized.memory_bytes();

        // Each full block stores 32 data bytes + a 4-byte scale for 256 bytes
        // of f32 input, so the ratio approaches 256 / 36 ≈ 7.1x.
        let compression = original_bytes as f32 / compressed_bytes as f32;
        assert!(compression > 6.0, "Compression ratio {compression} should be > 6.0");
    }

    #[test]
    fn test_compression_ratio() {
        let values = vec![1.5; 1024];
        let quantized = quantize_4bit(&values);

        let ratio = quantized.compression_ratio();
        assert!(ratio > 6.0, "Compression ratio {ratio} should be > 6.0");
    }

    #[test]
    fn test_quantize_small_values() {
        let values = vec![0.001, 0.002, 0.003, 0.004];
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        // Small values should be preserved relatively well
        for (original, deq) in values.iter().zip(dequantized.iter()) {
            let error = (original - deq).abs();
            assert!(error < 0.001, "Error {error} too large for small value");
        }
    }

    #[test]
    fn test_quantize_mixed_magnitudes() {
        // Use values with similar magnitudes to avoid quantization precision issues
        // When values span 1000x range in one block, small values may quantize to zero
        let values = vec![10.0, 1.0, -5.0, 0.5, 7.5, -2.0];
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        assert_eq!(dequantized.len(), values.len());

        // Each value should be within reasonable error
        // 4-bit quantization has only 15 discrete values, so error can be substantial
        // For values near zero, use absolute error; for larger values, use relative error
        for (original, deq) in values.iter().zip(dequantized.iter()) {
            let error = (original - deq).abs();
            if original.abs() < 1.0 {
                // Small values: use absolute error tolerance
                assert!(
                    error < 1.5,
                    "Absolute error {error} too large for small value {original} vs {deq}"
                );
            } else {
                // Larger values: use relative error
                let relative_error = error / original.abs();
                assert!(
                    relative_error < 0.5,
                    "Relative error {relative_error} too large for {original} vs {deq} (error: {error})"
                );
            }
        }
    }

    #[test]
    fn test_quantize_odd_length() {
        // Test with odd number of elements (not divisible by 2 or BLOCK_SIZE)
        let values: Vec<f32> = (0..77).map(|i| i as f32 * 0.5).collect();
        let quantized = quantize_4bit(&values);
        let dequantized = dequantize_4bit(&quantized);

        assert_eq!(dequantized.len(), 77);

        // 77 values need ceil(77 / 2) bytes; the last byte is half used.
        assert_eq!(quantized.data.len(), 39);
        assert_within_half_step(&values, &quantized, &dequantized);
    }

    #[test]
    fn test_quantized_data_size() {
        let values = vec![1.0; 128];
        let quantized = quantize_4bit(&values);

        // 128 values = 64 bytes (2 values per byte)
        assert_eq!(quantized.data.len(), 64);

        // 128 values = 2 blocks of 64
        assert_eq!(quantized.scales.len(), 2);
    }

    #[test]
    fn test_quantize_empty() {
        let quantized = quantize_4bit(&[]);

        assert_eq!(quantized.len, 0);
        assert!(quantized.scales.is_empty());
        assert!(quantized.data.is_empty());
        assert_eq!(quantized.memory_bytes(), 0);
        assert!(dequantize_4bit(&quantized).is_empty());
    }

    #[test]
    fn test_nibble_packing_layout() {
        // max |x| = 7 gives scale 1.0, so the codes are the values themselves:
        // 7 -> 0x7, -7 -> 0x9 (4-bit two's complement), 1 -> 0x1.
        let values = vec![7.0, -7.0, 1.0];
        let quantized = quantize_4bit(&values);

        assert_eq!(quantized.scales, vec![1.0]);
        // Even index in the upper nibble, odd index in the lower nibble; the
        // unpaired last value leaves its lower nibble as zero.
        assert_eq!(quantized.data, vec![0x79, 0x10]);
        assert_eq!(dequantize_4bit(&quantized), values);
    }

    #[test]
    fn test_block_boundary_packing() {
        // One full block followed by a single element in a second block.
        let mut values = vec![7.0_f32; BLOCK_SIZE];
        values.push(-14.0);
        let quantized = quantize_4bit(&values);

        assert_eq!(quantized.scales, vec![1.0, 2.0]);
        assert_eq!(quantized.data.len(), BLOCK_SIZE / 2 + 1);
        // The second block starts on a fresh byte with its own scale.
        assert_eq!(quantized.data[BLOCK_SIZE / 2], 0x90);

        let dequantized = dequantize_4bit(&quantized);
        assert_eq!(dequantized.len(), BLOCK_SIZE + 1);
        assert_eq!(dequantized[BLOCK_SIZE - 1], 7.0);
        assert_eq!(dequantized[BLOCK_SIZE], -14.0);
    }
}