Skip to main content

entrenar/quant/
double_quant.rs

1//! Double quantization for QLoRA (ENT-LoRA-008)
2//!
//! Quantizes the FP32 absmax scale factors from 4-bit quantization to 8-bit,
//! saving ~0.37 bits/param (~0.33 GB for a 7B model).
5//!
6//! Two-level quantization:
7//! - Level 1: Values → 4-bit with FP32 scales (standard, 64-element blocks)
8//! - Level 2: FP32 scales → 8-bit unsigned with FP32 super-scales (256-scale blocks)
9//!
10//! Memory per 64 values:
11//! - Without double quant: 32 bytes (data) + 4 bytes (scale) = 36 bytes = 4.50 bits/param
12//! - With double quant: 32 bytes (data) + 1 byte (scale) + 0.016 bytes (super) ≈ 33 bytes = 4.13 bits/param
13
14use serde::{Deserialize, Serialize};
15
16use super::quant4bit::{quantize_4bit, BLOCK_SIZE};
17
/// Number of first-level scale factors that share one FP32 super-scale in the
/// second quantization level (256 scales per super-block).
pub const DOUBLE_QUANT_BLOCK_SIZE: usize = 1 << 8;
20
/// Double-quantized 4-bit representation
///
/// Same 4-bit packed data as `Quantized4Bit`, but scale factors are stored as
/// 8-bit unsigned values with a second-level FP32 super-scale per group of 256
/// (`DOUBLE_QUANT_BLOCK_SIZE`).
///
/// First-level scale `i` is recovered as
/// `quantized_scales[i] / 255 * super_scales[i / DOUBLE_QUANT_BLOCK_SIZE]`.
///
/// NOTE(review): field order may be part of the serialized layout for
/// non-self-describing serde formats — keep it stable.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DoubleQuantized4Bit {
    /// 8-bit quantized scale factors (one per first-level block); code `i` is
    /// relative to `super_scales[i / DOUBLE_QUANT_BLOCK_SIZE]`
    pub quantized_scales: Vec<u8>,
    /// Super-scale factors (one FP32 per DOUBLE_QUANT_BLOCK_SIZE first-level blocks);
    /// each is the largest scale in its group, floored at a small positive value
    /// so encoding never divides by zero
    pub super_scales: Vec<f32>,
    /// Quantized data: 2 values per byte (4 bits each) — identical to Quantized4Bit.
    /// Element `k` lives in byte `k / 2`: even `k` in the high nibble, odd `k` in
    /// the low nibble, as signed two's-complement 4-bit values
    pub data: Vec<u8>,
    /// Original number of elements
    pub len: usize,
    /// Number of first-level blocks (equals `quantized_scales.len()` when built by
    /// `quantize_4bit_double`)
    pub num_blocks: usize,
}
38
39impl DoubleQuantized4Bit {
40    /// Memory usage in bytes
41    pub fn memory_bytes(&self) -> usize {
42        self.quantized_scales.len() // 1 byte per quantized scale
43            + self.super_scales.len() * 4 // 4 bytes per f32 super-scale
44            + self.data.len() // packed 4-bit data
45    }
46
47    /// Compression ratio vs f32
48    pub fn compression_ratio(&self) -> f32 {
49        let original_bytes = self.len * 4;
50        let compressed_bytes = self.memory_bytes();
51        if compressed_bytes == 0 {
52            return 1.0;
53        }
54        original_bytes as f32 / compressed_bytes as f32
55    }
56
57    /// Memory saved compared to single quantization (bytes)
58    pub fn double_quant_savings(&self) -> usize {
59        // Single quant: num_blocks * 4 bytes per f32 scale
60        // Double quant: num_blocks * 1 byte + super_scales * 4 bytes
61        let single_scale_bytes = self.num_blocks * 4;
62        let double_scale_bytes = self.quantized_scales.len() + self.super_scales.len() * 4;
63        single_scale_bytes.saturating_sub(double_scale_bytes)
64    }
65}
66
67/// Quantize values to 4-bit with double quantization of scale factors
68///
69/// First applies standard 4-bit quantization, then quantizes the resulting
70/// FP32 scale factors to 8-bit with a second-level block size of 256.
71pub fn quantize_4bit_double(values: &[f32]) -> DoubleQuantized4Bit {
72    // Step 1: Standard 4-bit quantization
73    let single = quantize_4bit(values);
74    let num_blocks = single.scales.len();
75
76    // Step 2: Double-quantize the scale factors
77    let num_super_blocks = num_blocks.div_ceil(DOUBLE_QUANT_BLOCK_SIZE);
78    let mut quantized_scales = Vec::with_capacity(num_blocks);
79    let mut super_scales = Vec::with_capacity(num_super_blocks);
80
81    for sb in 0..num_super_blocks {
82        let start = sb * DOUBLE_QUANT_BLOCK_SIZE;
83        let end = (start + DOUBLE_QUANT_BLOCK_SIZE).min(num_blocks);
84        let scale_block = &single.scales[start..end];
85
86        // Super-scale = max of this block of scales
87        // Scales are always non-negative (absmax / 7)
88        let max_scale =
89            scale_block.iter().copied().max_by(f32::total_cmp).unwrap_or(1e-8).max(1e-8); // avoid division by zero
90        super_scales.push(max_scale);
91
92        // Quantize each scale to u8: q = round(scale / max_scale * 255)
93        for &scale in scale_block {
94            let normalized = scale / max_scale;
95            let q = (normalized * 255.0).round().clamp(0.0, 255.0) as u8;
96            quantized_scales.push(q);
97        }
98    }
99
100    DoubleQuantized4Bit {
101        quantized_scales,
102        super_scales,
103        data: single.data,
104        len: single.len,
105        num_blocks,
106    }
107}
108
109/// Dequantize double-quantized 4-bit values back to f32
110pub fn dequantize_4bit_double(dq: &DoubleQuantized4Bit) -> Vec<f32> {
111    // Step 1: Reconstruct FP32 scales from double-quantized representation
112    let scales = reconstruct_scales(dq);
113
114    // Step 2: Standard 4-bit dequantization using reconstructed scales
115    let mut result = Vec::with_capacity(dq.len);
116
117    for block_idx in 0..dq.num_blocks {
118        let scale = scales[block_idx];
119        let start = block_idx * BLOCK_SIZE;
120        let end = (start + BLOCK_SIZE).min(dq.len);
121        let block_len = end - start;
122
123        for i in 0..block_len {
124            let byte_idx = usize::midpoint(start, i);
125            let byte = dq.data[byte_idx];
126
127            let q_val = if (start + i).is_multiple_of(2) {
128                let nibble = (byte >> 4) & 0x0F;
129                if nibble & 0x08 != 0 {
130                    (nibble | 0xF0) as i8
131                } else {
132                    nibble as i8
133                }
134            } else {
135                let nibble = byte & 0x0F;
136                if nibble & 0x08 != 0 {
137                    (nibble | 0xF0) as i8
138                } else {
139                    nibble as i8
140                }
141            };
142
143            result.push(f32::from(q_val) * scale);
144        }
145    }
146
147    result
148}
149
150/// Reconstruct FP32 scales from double-quantized representation
151fn reconstruct_scales(dq: &DoubleQuantized4Bit) -> Vec<f32> {
152    let mut scales = Vec::with_capacity(dq.num_blocks);
153
154    for (i, &q_scale) in dq.quantized_scales.iter().enumerate() {
155        let super_idx = i / DOUBLE_QUANT_BLOCK_SIZE;
156        let super_scale = dq.super_scales[super_idx];
157        // Dequantize: scale = q / 255 * super_scale
158        let scale = f32::from(q_scale) / 255.0 * super_scale;
159        scales.push(scale);
160    }
161
162    scales
163}
164
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    // Tests for ENT-LoRA-008 double quantization: accuracy relative to single
    // 4-bit quantization, memory accounting, and edge cases.
    use super::*;
    use crate::quant::{dequantize_4bit, quantize_4bit};
    use proptest::prelude::*;

    // ========================================================================
    // SPEC REQUIREMENT: dequantize(double_quant(x)) within 1% of dequantize(single_quant(x))
    // ========================================================================

    #[test]
    fn test_ent_lora_008_double_quant_within_1pct_of_single() {
        // Smooth synthetic signal standing in for a transformer weight distribution
        let values: Vec<f32> = (0..4096).map(|i| (i as f32 * 0.1).sin() * 2.0).collect();

        let single = quantize_4bit(&values);
        let single_deq = dequantize_4bit(&single);

        let double = quantize_4bit_double(&values);
        let double_deq = dequantize_4bit_double(&double);

        assert_eq!(single_deq.len(), double_deq.len());

        for (i, (s, d)) in single_deq.iter().zip(double_deq.iter()).enumerate() {
            let diff = (s - d).abs();
            let tolerance = s.abs() * 0.01 + 1e-6; // 1% relative + small absolute
            assert!(
                diff <= tolerance,
                "Double quant diverged at [{i}]: single={s}, double={d}, diff={diff}, tol={tolerance}"
            );
        }
    }

    #[test]
    fn test_ent_lora_008_memory_savings() {
        // 7B model params: ~7 billion. Test with representative block count.
        // 65536 values = 1024 blocks of 64
        let values: Vec<f32> = (0..65536).map(|i| (i as f32 * 0.01).sin()).collect();

        let single = quantize_4bit(&values);
        let double = quantize_4bit_double(&values);

        let single_bytes = single.memory_bytes();
        let double_bytes = double.memory_bytes();

        // Double quant should use fewer bytes
        assert!(
            double_bytes < single_bytes,
            "Double quant ({double_bytes}B) should be smaller than single ({single_bytes}B)"
        );

        // Savings should be ~3 bytes per block (4 bytes f32 → 1 byte u8)
        let savings = double.double_quant_savings();
        assert!(savings > 0, "Should have positive savings, got {savings}");

        // Savings per param: (savings * 8 bits) / num_values
        let savings_bits_per_param = (savings as f64 * 8.0) / values.len() as f64;
        assert!(
            savings_bits_per_param > 0.3,
            "Expected ~0.37 bits/param savings, got {savings_bits_per_param:.3}"
        );
    }

    #[test]
    fn test_ent_lora_008_round_trip_preserves_length() {
        let values: Vec<f32> = (0..200).map(|i| i as f32 * 0.5).collect();
        let dq = quantize_4bit_double(&values);
        let result = dequantize_4bit_double(&dq);
        assert_eq!(result.len(), values.len());
    }

    #[test]
    fn test_ent_lora_008_zeros() {
        let values = vec![0.0; 128];
        let dq = quantize_4bit_double(&values);
        let result = dequantize_4bit_double(&dq);

        for val in result {
            assert!(val.abs() < 1e-6, "Zero input should dequantize to ~0, got {val}");
        }
    }

    #[test]
    fn test_ent_lora_008_compression_ratio_better_than_single() {
        let values: Vec<f32> = (0..65536).map(|i| (i as f32 * 0.01).cos()).collect();

        let single = quantize_4bit(&values);
        let double = quantize_4bit_double(&values);

        assert!(
            double.compression_ratio() > single.compression_ratio(),
            "Double quant ratio ({:.2}) should exceed single ({:.2})",
            double.compression_ratio(),
            single.compression_ratio()
        );
    }

    #[test]
    fn test_ent_lora_008_small_input() {
        // Fewer values than one scale block
        let values = vec![1.0, -2.0, 3.0, -4.0];
        let dq = quantize_4bit_double(&values);
        let result = dequantize_4bit_double(&dq);
        assert_eq!(result.len(), 4);
        assert_eq!(dq.num_blocks, 1);
        assert_eq!(dq.super_scales.len(), 1);
    }

    #[test]
    fn test_ent_lora_008_scale_reconstruction_accuracy() {
        // Verify scale factors survive double quantization well
        let values: Vec<f32> = (0..4096).map(|i| (i as f32 * 0.05).sin() * 5.0).collect();

        let single = quantize_4bit(&values);
        let double = quantize_4bit_double(&values);
        let reconstructed = reconstruct_scales(&double);

        assert_eq!(reconstructed.len(), single.scales.len());

        for (i, (orig, recon)) in single.scales.iter().zip(reconstructed.iter()).enumerate() {
            let diff = (orig - recon).abs();
            let tolerance = orig.abs() * 0.01 + 1e-8; // 1% relative
            assert!(
                diff <= tolerance,
                "Scale [{i}] diverged: orig={orig}, recon={recon}, diff={diff}"
            );
        }
    }

    // ========================================================================
    // PROPERTY TESTS
    // ========================================================================

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(100))]

        #[test]
        fn prop_double_quant_within_1pct(
            n in (64usize..1024).prop_map(|n| n - (n % 64)), // multiple of 64
            magnitude in 0.1f32..10.0,
        ) {
            let values: Vec<f32> = (0..n)
                .map(|i| (i as f32 * 0.1).sin() * magnitude)
                .collect();

            let single_deq = dequantize_4bit(&quantize_4bit(&values));
            let double_deq = dequantize_4bit_double(&quantize_4bit_double(&values));

            prop_assert_eq!(single_deq.len(), double_deq.len());

            for (s, d) in single_deq.iter().zip(double_deq.iter()) {
                let diff = (s - d).abs();
                let tolerance = s.abs() * 0.01 + 1e-5;
                prop_assert!(
                    diff <= tolerance,
                    "single={s}, double={d}, diff={diff}, tol={tolerance}"
                );
            }
        }

        #[test]
        fn prop_double_quant_preserves_length(n in 1usize..2048) {
            let values: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
            let dq = quantize_4bit_double(&values);
            let result = dequantize_4bit_double(&dq);
            prop_assert_eq!(result.len(), n);
        }

        #[test]
        fn prop_double_quant_uses_less_memory(
            // Upper bound spans more than DOUBLE_QUANT_BLOCK_SIZE first-level
            // blocks, so multi-super-block inputs are exercised too.
            n in (256usize..32768).prop_map(|n| n - (n % 64)),
        ) {
            let values: Vec<f32> = (0..n).map(|i| (i as f32 * 0.01).sin()).collect();
            let single = quantize_4bit(&values);
            let double = quantize_4bit_double(&values);

            // With B first-level blocks, double quant stores B one-byte codes plus
            // ceil(B / DOUBLE_QUANT_BLOCK_SIZE) four-byte super-scales, versus 4 * B
            // bytes of FP32 scales. B + 4 * ceil(B / 256) < 4 * B for every B >= 2;
            // only a single block (1 + 4 > 4) is a loss. With 64-element blocks every
            // generated n yields at least 4 blocks, so the assertion always runs.
            if single.scales.len() >= 2 {
                prop_assert!(double.memory_bytes() < single.memory_bytes());
            }
        }
    }
}