ghostflow_nn/quantization.rs

//! Model Quantization
//!
//! Techniques for reducing model size and improving inference speed
//! through quantization to lower precision formats.
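//!
//! The INT8 path in this module uses an affine mapping; as a sketch of the
//! math, matching the code below:
//!
//! ```text
//! quantize:   q = clamp(round(x / scale) + zero_point, -128, 127)
//! dequantize: x ≈ (q - zero_point) * scale
//! ```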

use crate::module::Module;
use ghostflow_core::tensor::Tensor;
use std::collections::HashMap;

/// Quantization scheme
#[derive(Clone, Copy, Debug)]
pub enum QuantizationScheme {
    /// 8-bit integer quantization
    INT8,
    /// 16-bit floating point
    FP16,
    /// Dynamic quantization (quantize at runtime)
    Dynamic,
}

/// Quantization configuration
#[derive(Clone, Debug)]
pub struct QuantizationConfig {
    pub scheme: QuantizationScheme,
    /// Per-channel (vs per-tensor) quantization
    pub per_channel: bool,
    /// Symmetric (vs asymmetric) quantization
    pub symmetric: bool,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            scheme: QuantizationScheme::INT8,
            per_channel: true,
            symmetric: true,
        }
    }
}

/// Quantized tensor representation
#[derive(Clone, Debug)]
pub struct QuantizedTensor {
    /// Quantized values (stored as INT8; the FP16 path is not implemented yet)
    pub data: Vec<i8>,
    /// Scale factors for dequantization (one per tensor, or one per channel)
    pub scales: Vec<f32>,
    /// Zero points for asymmetric quantization
    pub zero_points: Vec<i8>,
    /// Original shape
    pub shape: Vec<usize>,
    /// Quantization scheme used
    pub scheme: QuantizationScheme,
}

impl QuantizedTensor {
    /// Create a new quantized tensor from a float tensor
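    ///
    /// A minimal usage sketch (mirroring the tests at the bottom of this file):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    /// let q = QuantizedTensor::from_tensor(&t, &QuantizationConfig::default());
    /// let restored = q.dequantize();
    /// assert_eq!(restored.shape().dims(), t.shape().dims());
    /// ```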
    pub fn from_tensor(tensor: &Tensor, config: &QuantizationConfig) -> Self {
        match config.scheme {
            QuantizationScheme::INT8 => Self::quantize_int8(tensor, config),
            QuantizationScheme::FP16 => Self::quantize_fp16(tensor, config),
            QuantizationScheme::Dynamic => Self::quantize_int8(tensor, config),
        }
    }

    fn quantize_int8(tensor: &Tensor, config: &QuantizationConfig) -> Self {
        let data_guard = tensor.storage().as_slice::<f32>();
        let data_slice = &*data_guard;
        let shape = tensor.shape().dims().to_vec();

        if config.per_channel {
            // Per-channel quantization (typically for weights)
            Self::quantize_per_channel_int8(data_slice, &shape, config.symmetric)
        } else {
            // Per-tensor quantization
            Self::quantize_per_tensor_int8(data_slice, &shape, config.symmetric)
        }
    }

    fn quantize_per_tensor_int8(data: &[f32], shape: &[usize], symmetric: bool) -> Self {
        let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let (scale, zero_point) = if symmetric {
            // Symmetric quantization: [-127, 127]
            let abs_max = min_val.abs().max(max_val.abs());
            // Guard against a zero range (e.g. an all-zero tensor) to avoid NaNs
            let scale = (abs_max / 127.0).max(f32::EPSILON);
            (scale, 0i8)
        } else {
            // Asymmetric quantization: [-128, 127]
            let scale = ((max_val - min_val) / 255.0).max(f32::EPSILON);
            let zero_point = (-min_val / scale - 128.0).round() as i8;
            (scale, zero_point)
        };
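        // Worked example (symmetric): for data spanning [-5.0, 5.0],
        // abs_max = 5.0 and scale ≈ 5.0 / 127 ≈ 0.0394, so x = 3.0 maps to
        // q = round(3.0 / 0.0394) = 76 and dequantizes back to ≈ 2.99.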

        let quantized_data: Vec<i8> = data
            .iter()
            .map(|&x| {
                let q = (x / scale).round() as i32 + zero_point as i32;
                q.clamp(-128, 127) as i8
            })
            .collect();

        Self {
            data: quantized_data,
            scales: vec![scale],
            zero_points: vec![zero_point],
            shape: shape.to_vec(),
            scheme: QuantizationScheme::INT8,
        }
    }

    fn quantize_per_channel_int8(data: &[f32], shape: &[usize], symmetric: bool) -> Self {
        // Assume the first dimension is the channel dimension
        let num_channels = shape[0];
        let channel_size = data.len() / num_channels;

        let mut scales = Vec::with_capacity(num_channels);
        let mut zero_points = Vec::with_capacity(num_channels);
        let mut quantized_data = Vec::with_capacity(data.len());

        for ch in 0..num_channels {
            let start = ch * channel_size;
            let end = start + channel_size;
            let channel_data = &data[start..end];

            let min_val = channel_data.iter().cloned().fold(f32::INFINITY, f32::min);
            let max_val = channel_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

            let (scale, zero_point) = if symmetric {
                let abs_max = min_val.abs().max(max_val.abs());
                // Guard against a zero range (e.g. an all-zero channel) to avoid NaNs
                let scale = (abs_max / 127.0).max(f32::EPSILON);
                (scale, 0i8)
            } else {
                let scale = ((max_val - min_val) / 255.0).max(f32::EPSILON);
                let zero_point = (-min_val / scale - 128.0).round() as i8;
                (scale, zero_point)
            };

            scales.push(scale);
            zero_points.push(zero_point);

            for &x in channel_data {
                let q = (x / scale).round() as i32 + zero_point as i32;
                quantized_data.push(q.clamp(-128, 127) as i8);
            }
        }

        Self {
            data: quantized_data,
            scales,
            zero_points,
            shape: shape.to_vec(),
            scheme: QuantizationScheme::INT8,
        }
    }

    fn quantize_fp16(_tensor: &Tensor, _config: &QuantizationConfig) -> Self {
        // FP16 quantization requires half-precision storage, which is not
        // available yet, so this path is currently unimplemented.
        unimplemented!("FP16 quantization requires half-precision support")
    }

    /// Dequantize back to a float tensor
    pub fn dequantize(&self) -> Tensor {
        match self.scheme {
            QuantizationScheme::INT8 | QuantizationScheme::Dynamic => {
                self.dequantize_int8()
            }
            QuantizationScheme::FP16 => {
                unimplemented!("FP16 dequantization not yet implemented")
            }
        }
    }

    fn dequantize_int8(&self) -> Tensor {
        if self.scales.len() == 1 {
            // Per-tensor dequantization
            let scale = self.scales[0];
            let zero_point = self.zero_points[0];

            let dequantized: Vec<f32> = self.data
                .iter()
                .map(|&q| (q as f32 - zero_point as f32) * scale)
                .collect();

            Tensor::from_slice::<f32>(&dequantized, &self.shape).unwrap()
        } else {
            // Per-channel dequantization
            let num_channels = self.shape[0];
            let channel_size = self.data.len() / num_channels;
            let mut dequantized = Vec::with_capacity(self.data.len());

            for ch in 0..num_channels {
                let scale = self.scales[ch];
                let zero_point = self.zero_points[ch];
                let start = ch * channel_size;
                let end = start + channel_size;

                for &q in &self.data[start..end] {
                    dequantized.push((q as f32 - zero_point as f32) * scale);
                }
            }

            Tensor::from_slice::<f32>(&dequantized, &self.shape).unwrap()
        }
    }

    /// Compression ratio relative to the original f32 tensor,
    /// including the scale and zero-point overhead
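    ///
    /// For example, 1000 f32 values (4000 B) quantized per-tensor to INT8 take
    /// 1000 B of data + 4 B for the scale + 1 B for the zero point = 1005 B,
    /// a ratio of roughly 3.98x.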
    pub fn compression_ratio(&self) -> f32 {
        let original_size = self.data.len() * std::mem::size_of::<f32>();
        let quantized_size = self.data.len() * std::mem::size_of::<i8>()
            + self.scales.len() * std::mem::size_of::<f32>()
            + self.zero_points.len() * std::mem::size_of::<i8>();
        original_size as f32 / quantized_size as f32
    }
}

/// Quantization-aware training (QAT)
///
/// Simulates quantization during training to make the model robust to quantization errors.
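///
/// A minimal usage sketch (assuming the training loop applies `fake_quantize`
/// to weights and/or activations in the forward pass; that loop is not shown
/// here, and `weights` is a hypothetical tensor):
///
/// ```ignore
/// let qat = QuantizationAwareTraining::new(QuantizationConfig::default());
/// let w_sim = qat.fake_quantize(&weights); // quantize + dequantize, stays f32
/// ```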
pub struct QuantizationAwareTraining {
    config: QuantizationConfig,
    fake_quantize: bool,
}

impl QuantizationAwareTraining {
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            fake_quantize: true,
        }
    }

    /// Apply fake quantization (quantize, then immediately dequantize)
    pub fn fake_quantize(&self, tensor: &Tensor) -> Tensor {
        if !self.fake_quantize {
            return tensor.clone();
        }

        let quantized = QuantizedTensor::from_tensor(tensor, &self.config);
        quantized.dequantize()
    }

    /// Enable/disable fake quantization
    pub fn set_fake_quantize(&mut self, enabled: bool) {
        self.fake_quantize = enabled;
    }
}

/// Dynamic quantization
///
/// Weights are quantized ahead of time; activations are quantized on the fly
/// at runtime.
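///
/// A minimal usage sketch (mirroring the test at the bottom of this file;
/// `weights` and `activation` are hypothetical tensors):
///
/// ```ignore
/// let mut dq = DynamicQuantization::new();
/// dq.quantize_weights("layer1", &weights);        // quantized once, up front
/// let w = dq.get_weights("layer1").unwrap();      // dequantized on demand
/// let a_q = dq.quantize_activation(&activation);  // per-tensor, at runtime
/// ```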
pub struct DynamicQuantization {
    config: QuantizationConfig,
    weight_quantized: HashMap<String, QuantizedTensor>,
}

impl DynamicQuantization {
    pub fn new() -> Self {
        Self {
            config: QuantizationConfig {
                scheme: QuantizationScheme::Dynamic,
                per_channel: true,
                symmetric: true,
            },
            weight_quantized: HashMap::new(),
        }
    }

    /// Quantize and store a layer's weights under the given name
    pub fn quantize_weights(&mut self, name: &str, weights: &Tensor) {
        let quantized = QuantizedTensor::from_tensor(weights, &self.config);
        self.weight_quantized.insert(name.to_string(), quantized);
    }

    /// Get previously quantized weights, dequantized back to f32
    pub fn get_weights(&self, name: &str) -> Option<Tensor> {
        self.weight_quantized.get(name).map(|q| q.dequantize())
    }

    /// Quantize activations dynamically at runtime
    pub fn quantize_activation(&self, activation: &Tensor) -> QuantizedTensor {
        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false, // Per-tensor for activations
            symmetric: false,   // Asymmetric for activations
        };
        QuantizedTensor::from_tensor(activation, &config)
    }
}

impl Default for DynamicQuantization {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_per_tensor_quantization() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_slice(&data, &[2, 3]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: true,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);
        let dequantized = quantized.dequantize();

        // Check shape preserved
        assert_eq!(dequantized.shape().dims(), tensor.shape().dims());

        // Check values are close (within quantization error)
        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.1, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        let data = vec![1.0f32, 2.0, 3.0, 10.0, 20.0, 30.0];
        let tensor = Tensor::from_slice(&data, &[2, 3]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: true,
            symmetric: true,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);

        // Should have 2 scales (one per channel)
        assert_eq!(quantized.scales.len(), 2);

        let dequantized = quantized.dequantize();

        // Check values are close
        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.5, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_asymmetric_quantization() {
        let data = vec![-5.0f32, -3.0, -1.0, 1.0, 3.0, 5.0];
        let tensor = Tensor::from_slice(&data, &[6]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: false,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);
        let dequantized = quantized.dequantize();

        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.1, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..1000).map(|x| x as f32).collect();
        let tensor = Tensor::from_slice(&data, &[1000]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false, // Use per-tensor for better compression on 1D
            symmetric: true,
        };
        let quantized = QuantizedTensor::from_tensor(&tensor, &config);

        let ratio = quantized.compression_ratio();
        // INT8 should give ~4x compression (32-bit to 8-bit)
        assert!(ratio > 3.5 && ratio < 4.5, "Compression ratio: {}", ratio);
    }

    #[test]
    fn test_quantization_aware_training() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_slice(&data, &[4]).unwrap();

        let config = QuantizationConfig::default();
        let qat = QuantizationAwareTraining::new(config);

        let fake_quantized = qat.fake_quantize(&tensor);

        // Should have same shape
        assert_eq!(fake_quantized.shape().dims(), tensor.shape().dims());

        // Values should be close but not exact (due to quantization)
        let original = tensor.storage().as_slice::<f32>();
        let quantized = fake_quantized.storage().as_slice::<f32>();
        for (o, q) in original.iter().zip(quantized.iter()) {
            assert!((o - q).abs() < 0.1);
        }
    }

    #[test]
    fn test_dynamic_quantization() {
        let mut dq = DynamicQuantization::new();

        // Quantize weights
        let weights = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        dq.quantize_weights("layer1", &weights);

        // Retrieve weights
        let retrieved = dq.get_weights("layer1").unwrap();
        assert_eq!(retrieved.shape().dims(), weights.shape().dims());

        // Quantize activation
        let activation = Tensor::from_slice(&[0.5f32, 1.5, 2.5], &[3]).unwrap();
        let q_activation = dq.quantize_activation(&activation);

        assert_eq!(q_activation.shape, vec![3]);
        assert_eq!(q_activation.scales.len(), 1); // Per-tensor for activations
    }
}