entrenar/quant/
double_quant.rs1use serde::{Deserialize, Serialize};
15
16use super::quant4bit::{quantize_4bit, BLOCK_SIZE};
17
/// Number of first-level block scales grouped under one `f32` super-scale.
///
/// Both the quantizer and the scale reconstruction derive the super-block
/// index of a block as `block_index / DOUBLE_QUANT_BLOCK_SIZE`.
pub const DOUBLE_QUANT_BLOCK_SIZE: usize = 256;
20
21#[derive(Clone, Debug, Serialize, Deserialize)]
26pub struct DoubleQuantized4Bit {
27 pub quantized_scales: Vec<u8>,
29 pub super_scales: Vec<f32>,
31 pub data: Vec<u8>,
33 pub len: usize,
35 pub num_blocks: usize,
37}
38
39impl DoubleQuantized4Bit {
40 pub fn memory_bytes(&self) -> usize {
42 self.quantized_scales.len() + self.super_scales.len() * 4 + self.data.len() }
46
47 pub fn compression_ratio(&self) -> f32 {
49 let original_bytes = self.len * 4;
50 let compressed_bytes = self.memory_bytes();
51 if compressed_bytes == 0 {
52 return 1.0;
53 }
54 original_bytes as f32 / compressed_bytes as f32
55 }
56
57 pub fn double_quant_savings(&self) -> usize {
59 let single_scale_bytes = self.num_blocks * 4;
62 let double_scale_bytes = self.quantized_scales.len() + self.super_scales.len() * 4;
63 single_scale_bytes.saturating_sub(double_scale_bytes)
64 }
65}
66
67pub fn quantize_4bit_double(values: &[f32]) -> DoubleQuantized4Bit {
72 let single = quantize_4bit(values);
74 let num_blocks = single.scales.len();
75
76 let num_super_blocks = num_blocks.div_ceil(DOUBLE_QUANT_BLOCK_SIZE);
78 let mut quantized_scales = Vec::with_capacity(num_blocks);
79 let mut super_scales = Vec::with_capacity(num_super_blocks);
80
81 for sb in 0..num_super_blocks {
82 let start = sb * DOUBLE_QUANT_BLOCK_SIZE;
83 let end = (start + DOUBLE_QUANT_BLOCK_SIZE).min(num_blocks);
84 let scale_block = &single.scales[start..end];
85
86 let max_scale =
89 scale_block.iter().copied().max_by(f32::total_cmp).unwrap_or(1e-8).max(1e-8); super_scales.push(max_scale);
91
92 for &scale in scale_block {
94 let normalized = scale / max_scale;
95 let q = (normalized * 255.0).round().clamp(0.0, 255.0) as u8;
96 quantized_scales.push(q);
97 }
98 }
99
100 DoubleQuantized4Bit {
101 quantized_scales,
102 super_scales,
103 data: single.data,
104 len: single.len,
105 num_blocks,
106 }
107}
108
109pub fn dequantize_4bit_double(dq: &DoubleQuantized4Bit) -> Vec<f32> {
111 let scales = reconstruct_scales(dq);
113
114 let mut result = Vec::with_capacity(dq.len);
116
117 for block_idx in 0..dq.num_blocks {
118 let scale = scales[block_idx];
119 let start = block_idx * BLOCK_SIZE;
120 let end = (start + BLOCK_SIZE).min(dq.len);
121 let block_len = end - start;
122
123 for i in 0..block_len {
124 let byte_idx = usize::midpoint(start, i);
125 let byte = dq.data[byte_idx];
126
127 let q_val = if (start + i).is_multiple_of(2) {
128 let nibble = (byte >> 4) & 0x0F;
129 if nibble & 0x08 != 0 {
130 (nibble | 0xF0) as i8
131 } else {
132 nibble as i8
133 }
134 } else {
135 let nibble = byte & 0x0F;
136 if nibble & 0x08 != 0 {
137 (nibble | 0xF0) as i8
138 } else {
139 nibble as i8
140 }
141 };
142
143 result.push(f32::from(q_val) * scale);
144 }
145 }
146
147 result
148}
149
150fn reconstruct_scales(dq: &DoubleQuantized4Bit) -> Vec<f32> {
152 let mut scales = Vec::with_capacity(dq.num_blocks);
153
154 for (i, &q_scale) in dq.quantized_scales.iter().enumerate() {
155 let super_idx = i / DOUBLE_QUANT_BLOCK_SIZE;
156 let super_scale = dq.super_scales[super_idx];
157 let scale = f32::from(q_scale) / 255.0 * super_scale;
159 scales.push(scale);
160 }
161
162 scales
163}
164
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::quant::{dequantize_4bit, quantize_4bit};
    use proptest::prelude::*;

    /// Dequantized output of the double-quantized form stays within 1% of the
    /// single-level result, element by element.
    #[test]
    fn test_ent_lora_008_double_quant_within_1pct_of_single() {
        let values: Vec<f32> = (0..4096).map(|i| (i as f32 * 0.1).sin() * 2.0).collect();

        let single = quantize_4bit(&values);
        let single_deq = dequantize_4bit(&single);

        let double = quantize_4bit_double(&values);
        let double_deq = dequantize_4bit_double(&double);

        assert_eq!(single_deq.len(), double_deq.len());

        for (i, (s, d)) in single_deq.iter().zip(double_deq.iter()).enumerate() {
            let diff = (s - d).abs();
            let tolerance = s.abs() * 0.01 + 1e-6;
            assert!(
                diff <= tolerance,
                "Double quant diverged at [{i}]: single={s}, double={d}, diff={diff}, tol={tolerance}"
            );
        }
    }

    /// Double quantization takes less memory than single-level quantization
    /// and saves more than 0.3 bits per parameter on 65536 values.
    #[test]
    fn test_ent_lora_008_memory_savings() {
        let values: Vec<f32> = (0..65536).map(|i| (i as f32 * 0.01).sin()).collect();

        let single = quantize_4bit(&values);
        let double = quantize_4bit_double(&values);

        let single_bytes = single.memory_bytes();
        let double_bytes = double.memory_bytes();

        assert!(
            double_bytes < single_bytes,
            "Double quant ({double_bytes}B) should be smaller than single ({single_bytes}B)"
        );

        let savings = double.double_quant_savings();
        assert!(savings > 0, "Should have positive savings, got {savings}");

        let savings_bits_per_param = (savings as f64 * 8.0) / values.len() as f64;
        assert!(
            savings_bits_per_param > 0.3,
            "Expected ~0.37 bits/param savings, got {savings_bits_per_param:.3}"
        );
    }

    /// A length that is not a multiple of the block size survives the round trip.
    #[test]
    fn test_ent_lora_008_round_trip_preserves_length() {
        let values: Vec<f32> = (0..200).map(|i| i as f32 * 0.5).collect();
        let dq = quantize_4bit_double(&values);
        let result = dequantize_4bit_double(&dq);
        assert_eq!(result.len(), values.len());
    }

    /// All-zero input dequantizes to (approximately) zero.
    #[test]
    fn test_ent_lora_008_zeros() {
        let values = vec![0.0; 128];
        let dq = quantize_4bit_double(&values);
        let result = dequantize_4bit_double(&dq);

        for val in result {
            assert!(val.abs() < 1e-6, "Zero input should dequantize to ~0, got {val}");
        }
    }

    /// The compression ratio of the double-quantized form beats the single-level one.
    #[test]
    fn test_ent_lora_008_compression_ratio_better_than_single() {
        let values: Vec<f32> = (0..65536).map(|i| (i as f32 * 0.01).cos()).collect();

        let single = quantize_4bit(&values);
        let double = quantize_4bit_double(&values);

        assert!(
            double.compression_ratio() > single.compression_ratio(),
            "Double quant ratio ({:.2}) should exceed single ({:.2})",
            double.compression_ratio(),
            single.compression_ratio()
        );
    }

    /// Four values fit into one block and one super-block.
    #[test]
    fn test_ent_lora_008_small_input() {
        let values = vec![1.0, -2.0, 3.0, -4.0];
        let dq = quantize_4bit_double(&values);
        let result = dequantize_4bit_double(&dq);
        assert_eq!(result.len(), 4);
        assert_eq!(dq.num_blocks, 1);
        assert_eq!(dq.super_scales.len(), 1);
    }

    /// Reconstructed scales stay within 1% of the original `f32` scales.
    #[test]
    fn test_ent_lora_008_scale_reconstruction_accuracy() {
        let values: Vec<f32> = (0..4096).map(|i| (i as f32 * 0.05).sin() * 5.0).collect();

        let single = quantize_4bit(&values);
        let double = quantize_4bit_double(&values);
        let reconstructed = reconstruct_scales(&double);

        assert_eq!(reconstructed.len(), single.scales.len());

        for (i, (orig, recon)) in single.scales.iter().zip(reconstructed.iter()).enumerate() {
            let diff = (orig - recon).abs();
            let tolerance = orig.abs() * 0.01 + 1e-8;
            assert!(
                diff <= tolerance,
                "Scale [{i}] diverged: orig={orig}, recon={recon}, diff={diff}"
            );
        }
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(100))]

        /// Double and single quantization agree within 1% for sine inputs of
        /// varying length and magnitude.
        #[test]
        fn prop_double_quant_within_1pct(
            n in (64usize..1024).prop_map(|n| n - (n % 64)),
            magnitude in 0.1f32..10.0,
        ) {
            let values: Vec<f32> = (0..n)
                .map(|i| (i as f32 * 0.1).sin() * magnitude)
                .collect();

            let single_deq = dequantize_4bit(&quantize_4bit(&values));
            let double_deq = dequantize_4bit_double(&quantize_4bit_double(&values));

            prop_assert_eq!(single_deq.len(), double_deq.len());

            for (s, d) in single_deq.iter().zip(double_deq.iter()) {
                let diff = (s - d).abs();
                let tolerance = s.abs() * 0.01 + 1e-5;
                prop_assert!(
                    diff <= tolerance,
                    "single={s}, double={d}, diff={diff}, tol={tolerance}"
                );
            }
        }

        /// The round trip yields exactly as many values as went in.
        #[test]
        fn prop_double_quant_preserves_length(n in 1usize..2048) {
            let values: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
            let dq = quantize_4bit_double(&values);
            let result = dequantize_4bit_double(&dq);
            prop_assert_eq!(result.len(), n);
        }

        /// Once there are more block scales than fit in one super-block,
        /// double quantization uses less memory than single quantization.
        ///
        /// The range reaches 65536 values so the guard below is actually hit.
        /// Assuming 64-value blocks (as the `n % 64` rounding suggests — TODO
        /// confirm against `BLOCK_SIZE`), anything shorter than 64 * 256 values
        /// has at most `DOUBLE_QUANT_BLOCK_SIZE` scales, so the previous upper
        /// bound of 8192 never reached the assertion.
        #[test]
        fn prop_double_quant_uses_less_memory(
            n in (256usize..65_536).prop_map(|n| n - (n % 64)),
        ) {
            let values: Vec<f32> = (0..n).map(|i| (i as f32 * 0.01).sin()).collect();
            let single = quantize_4bit(&values);
            let double = quantize_4bit_double(&values);

            if single.scales.len() > DOUBLE_QUANT_BLOCK_SIZE {
                prop_assert!(double.memory_bytes() < single.memory_bytes());
            }
        }
    }
}