ghostflow_nn/quantization.rs

//! Model Quantization
//!
//! Techniques for reducing model size and improving inference speed
//! through quantization to lower precision formats.
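//!
//! A minimal round-trip sketch (module paths are assumed here; adjust them to the
//! actual crate layout):
//!
//! ```no_run
//! use ghostflow_core::tensor::Tensor;
//! use ghostflow_nn::quantization::{QuantizationConfig, QuantizedTensor};
//!
//! // Quantize a small weight tensor with the default (per-channel, symmetric INT8) config.
//! let weights = Tensor::from_slice(&[0.1f32, -0.4, 0.7, 1.2], &[2, 2]).unwrap();
//! let q = QuantizedTensor::from_tensor(&weights, &QuantizationConfig::default());
//!
//! // Dequantize back to f32; values match the originals up to quantization error.
//! let recovered = q.dequantize();
//! assert_eq!(recovered.shape().dims(), weights.shape().dims());
//! ```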

use ghostflow_core::tensor::Tensor;
use std::collections::HashMap;

/// Quantization scheme
#[derive(Clone, Copy, Debug)]
pub enum QuantizationScheme {
    /// 8-bit integer quantization
    INT8,
    /// 16-bit floating point
    FP16,
    /// Dynamic quantization (quantize at runtime)
    Dynamic,
}

/// Quantization configuration
#[derive(Clone, Debug)]
pub struct QuantizationConfig {
    pub scheme: QuantizationScheme,
    pub per_channel: bool,  // Per-channel vs per-tensor quantization
    pub symmetric: bool,    // Symmetric vs asymmetric quantization
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            scheme: QuantizationScheme::INT8,
            per_channel: true,
            symmetric: true,
        }
    }
}

/// Quantized tensor representation
#[derive(Clone, Debug)]
pub struct QuantizedTensor {
    /// Quantized values, stored as signed 8-bit integers
    pub data: Vec<i8>,
    /// Scale factors for dequantization
    pub scales: Vec<f32>,
    /// Zero points for asymmetric quantization
    pub zero_points: Vec<i8>,
    /// Original shape
    pub shape: Vec<usize>,
    /// Quantization scheme used
    pub scheme: QuantizationScheme,
}

impl QuantizedTensor {
    /// Create a new quantized tensor from a float tensor
    pub fn from_tensor(tensor: &Tensor, config: &QuantizationConfig) -> Self {
        match config.scheme {
            QuantizationScheme::INT8 => Self::quantize_int8(tensor, config),
            QuantizationScheme::FP16 => Self::quantize_fp16(tensor, config),
            QuantizationScheme::Dynamic => Self::quantize_int8(tensor, config),
        }
    }

    fn quantize_int8(tensor: &Tensor, config: &QuantizationConfig) -> Self {
        let data_guard = tensor.storage().as_slice::<f32>();
        let data_slice = &*data_guard;
        let shape = tensor.shape().dims().to_vec();

        if config.per_channel {
            // Per-channel quantization (typically for weights)
            Self::quantize_per_channel_int8(data_slice, &shape, config.symmetric)
        } else {
            // Per-tensor quantization
            Self::quantize_per_tensor_int8(data_slice, &shape, config.symmetric)
        }
    }

    fn quantize_per_tensor_int8(data: &[f32], shape: &[usize], symmetric: bool) -> Self {
        let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let (scale, zero_point) = if symmetric {
            // Symmetric quantization: map [-abs_max, abs_max] onto [-127, 127]
            let abs_max = min_val.abs().max(max_val.abs());
            // Guard against an all-zero tensor to avoid division by zero
            let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
            (scale, 0i8)
        } else {
            // Asymmetric quantization: map [min, max] onto [-128, 127]
            let range = max_val - min_val;
            // Guard against a constant tensor to avoid division by zero
            let scale = if range > 0.0 { range / 255.0 } else { 1.0 };
            let zero_point = (-min_val / scale - 128.0).round() as i8;
            (scale, zero_point)
        };
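        // The mapping below is q = round(x / scale) + zero_point, inverted at
        // dequantization as x ≈ (q - zero_point) * scale. Illustrative example
        // (not from the original source): for data spanning [-1.0, 3.0], asymmetric
        // mode gives scale = 4.0 / 255 ≈ 0.0157 and zero_point = round(63.75 - 128) = -64,
        // so x = 3.0 maps to q = 191 - 64 = 127 and x = -1.0 maps to q = -64 - 64 = -128.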

        let quantized_data: Vec<i8> = data
            .iter()
            .map(|&x| {
                let q = (x / scale).round() as i32 + zero_point as i32;
                q.clamp(-128, 127) as i8
            })
            .collect();

        Self {
            data: quantized_data,
            scales: vec![scale],
            zero_points: vec![zero_point],
            shape: shape.to_vec(),
            scheme: QuantizationScheme::INT8,
        }
    }

    fn quantize_per_channel_int8(data: &[f32], shape: &[usize], symmetric: bool) -> Self {
        // Assume first dimension is the channel dimension
        let num_channels = shape[0];
        let channel_size = data.len() / num_channels;
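        // Illustrative example (not from the original source): a [64, 128] weight
        // matrix produces 64 independent (scale, zero_point) pairs, one per output
        // channel, so small-magnitude channels do not share a scale with large ones.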

        let mut scales = Vec::with_capacity(num_channels);
        let mut zero_points = Vec::with_capacity(num_channels);
        let mut quantized_data = Vec::with_capacity(data.len());

        for ch in 0..num_channels {
            let start = ch * channel_size;
            let end = start + channel_size;
            let channel_data = &data[start..end];

            let min_val = channel_data.iter().cloned().fold(f32::INFINITY, f32::min);
            let max_val = channel_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

            let (scale, zero_point) = if symmetric {
                let abs_max = min_val.abs().max(max_val.abs());
                // Guard against an all-zero channel to avoid division by zero
                let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
                (scale, 0i8)
            } else {
                let range = max_val - min_val;
                // Guard against a constant channel to avoid division by zero
                let scale = if range > 0.0 { range / 255.0 } else { 1.0 };
                let zero_point = (-min_val / scale - 128.0).round() as i8;
                (scale, zero_point)
            };

            scales.push(scale);
            zero_points.push(zero_point);

            for &x in channel_data {
                let q = (x / scale).round() as i32 + zero_point as i32;
                quantized_data.push(q.clamp(-128, 127) as i8);
            }
        }

        Self {
            data: quantized_data,
            scales,
            zero_points,
            shape: shape.to_vec(),
            scheme: QuantizationScheme::INT8,
        }
    }

    fn quantize_fp16(_tensor: &Tensor, _config: &QuantizationConfig) -> Self {
        // FP16 quantization requires half-precision storage, which is not
        // available yet, so this fails loudly rather than silently falling
        // back to INT8.
        unimplemented!("FP16 quantization requires half-precision support")
    }

    /// Dequantize back to float tensor
    pub fn dequantize(&self) -> Tensor {
        match self.scheme {
            QuantizationScheme::INT8 | QuantizationScheme::Dynamic => {
                self.dequantize_int8()
            }
            QuantizationScheme::FP16 => {
                unimplemented!("FP16 dequantization not yet implemented")
            }
        }
    }

    fn dequantize_int8(&self) -> Tensor {
        if self.scales.len() == 1 {
            // Per-tensor dequantization
            let scale = self.scales[0];
            let zero_point = self.zero_points[0];

            let dequantized: Vec<f32> = self.data
                .iter()
                .map(|&q| (q as f32 - zero_point as f32) * scale)
                .collect();

            Tensor::from_slice::<f32>(&dequantized, &self.shape).unwrap()
        } else {
            // Per-channel dequantization
            let num_channels = self.shape[0];
            let channel_size = self.data.len() / num_channels;
            let mut dequantized = Vec::with_capacity(self.data.len());

            for ch in 0..num_channels {
                let scale = self.scales[ch];
                let zero_point = self.zero_points[ch];
                let start = ch * channel_size;
                let end = start + channel_size;

                for &q in &self.data[start..end] {
                    dequantized.push((q as f32 - zero_point as f32) * scale);
                }
            }

            Tensor::from_slice::<f32>(&dequantized, &self.shape).unwrap()
        }
    }

    /// Compression ratio of the original f32 data relative to the quantized representation
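    ///
    /// For example (illustrative numbers): 1,000 per-tensor INT8 values take
    /// 1,000 bytes of data plus one f32 scale and one i8 zero point (1,005 bytes),
    /// compared to 4,000 bytes for the original f32 data, i.e. a ratio of ~3.98.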
    pub fn compression_ratio(&self) -> f32 {
        let original_size = self.data.len() * std::mem::size_of::<f32>();
        let quantized_size = self.data.len() * std::mem::size_of::<i8>()
            + self.scales.len() * std::mem::size_of::<f32>()
            + self.zero_points.len() * std::mem::size_of::<i8>();
        original_size as f32 / quantized_size as f32
    }
}

/// Quantization-aware training (QAT)
///
/// Simulates quantization during training to make the model robust to quantization errors.
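///
/// A minimal usage sketch (module paths are assumed, not part of the original docs):
///
/// ```no_run
/// use ghostflow_core::tensor::Tensor;
/// use ghostflow_nn::quantization::{QuantizationAwareTraining, QuantizationConfig};
///
/// let qat = QuantizationAwareTraining::new(QuantizationConfig::default());
/// let activations = Tensor::from_slice(&[0.2f32, -1.3, 0.8, 2.1], &[4]).unwrap();
///
/// // During training, pass tensors through fake quantization so the forward pass
/// // sees the same rounding error that real INT8 inference would introduce.
/// let simulated = qat.fake_quantize(&activations);
/// assert_eq!(simulated.shape().dims(), activations.shape().dims());
/// ```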
pub struct QuantizationAwareTraining {
    config: QuantizationConfig,
    fake_quantize: bool,
}

impl QuantizationAwareTraining {
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            fake_quantize: true,
        }
    }

    /// Apply fake quantization (quantize then dequantize)
    pub fn fake_quantize(&self, tensor: &Tensor) -> Tensor {
        if !self.fake_quantize {
            return tensor.clone();
        }

        let quantized = QuantizedTensor::from_tensor(tensor, &self.config);
        quantized.dequantize()
    }

    /// Enable/disable fake quantization
    pub fn set_fake_quantize(&mut self, enabled: bool) {
        self.fake_quantize = enabled;
    }
}

/// Dynamic quantization
///
/// Keeps weights statically quantized ahead of time and quantizes activations
/// on the fly at inference time.
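///
/// A minimal usage sketch (module paths and the parameter name are assumed):
///
/// ```no_run
/// use ghostflow_core::tensor::Tensor;
/// use ghostflow_nn::quantization::DynamicQuantization;
///
/// let mut dq = DynamicQuantization::new();
///
/// // Weights are quantized once, ahead of time.
/// let weights = Tensor::from_slice(&[0.5f32, -1.0, 1.5, 2.0], &[2, 2]).unwrap();
/// dq.quantize_weights("linear1.weight", &weights);
///
/// // Activations are quantized on the fly, per batch, with per-tensor asymmetric INT8.
/// let activation = Tensor::from_slice(&[0.1f32, 0.9, 1.7], &[3]).unwrap();
/// let q_act = dq.quantize_activation(&activation);
/// assert_eq!(q_act.scales.len(), 1);
/// ```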
pub struct DynamicQuantization {
    config: QuantizationConfig,
    weight_quantized: HashMap<String, QuantizedTensor>,
}

impl DynamicQuantization {
    pub fn new() -> Self {
        Self {
            config: QuantizationConfig {
                scheme: QuantizationScheme::Dynamic,
                per_channel: true,
                symmetric: true,
            },
            weight_quantized: HashMap::new(),
        }
    }

    /// Quantize model weights
    pub fn quantize_weights(&mut self, name: &str, weights: &Tensor) {
        let quantized = QuantizedTensor::from_tensor(weights, &self.config);
        self.weight_quantized.insert(name.to_string(), quantized);
    }

    /// Retrieve stored weights, dequantized back to an f32 tensor
    pub fn get_weights(&self, name: &str) -> Option<Tensor> {
        self.weight_quantized.get(name).map(|q| q.dequantize())
    }

    /// Quantize activations dynamically
    pub fn quantize_activation(&self, activation: &Tensor) -> QuantizedTensor {
        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,  // Per-tensor for activations
            symmetric: false,    // Asymmetric for activations
        };
        QuantizedTensor::from_tensor(activation, &config)
    }
}

impl Default for DynamicQuantization {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_per_tensor_quantization() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_slice(&data, &[2, 3]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: true,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);
        let dequantized = quantized.dequantize();

        // Check shape preserved
        assert_eq!(dequantized.shape().dims(), tensor.shape().dims());

        // Check values are close (within quantization error)
        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.1, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        let data = vec![1.0f32, 2.0, 3.0, 10.0, 20.0, 30.0];
        let tensor = Tensor::from_slice(&data, &[2, 3]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: true,
            symmetric: true,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);

        // Should have 2 scales (one per channel)
        assert_eq!(quantized.scales.len(), 2);

        let dequantized = quantized.dequantize();

        // Check values are close
        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.5, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_asymmetric_quantization() {
        let data = vec![-5.0f32, -3.0, -1.0, 1.0, 3.0, 5.0];
        let tensor = Tensor::from_slice(&data, &[6]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: false,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);
        let dequantized = quantized.dequantize();

        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.1, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..1000).map(|x| x as f32).collect();
        let tensor = Tensor::from_slice(&data, &[1000]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,  // Use per-tensor for better compression on 1D
            symmetric: true,
        };
        let quantized = QuantizedTensor::from_tensor(&tensor, &config);

        let ratio = quantized.compression_ratio();
        // INT8 should give ~4x compression (32-bit to 8-bit)
        assert!(ratio > 3.5 && ratio < 4.5, "Compression ratio: {}", ratio);
    }

    #[test]
    fn test_quantization_aware_training() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_slice(&data, &[4]).unwrap();

        let config = QuantizationConfig::default();
        let qat = QuantizationAwareTraining::new(config);

        let fake_quantized = qat.fake_quantize(&tensor);

        // Should have same shape
        assert_eq!(fake_quantized.shape().dims(), tensor.shape().dims());

        // Values should be close but not exact (due to quantization)
        let original = tensor.storage().as_slice::<f32>();
        let quantized = fake_quantized.storage().as_slice::<f32>();
        for (o, q) in original.iter().zip(quantized.iter()) {
            assert!((o - q).abs() < 0.1);
        }
    }

    #[test]
    fn test_dynamic_quantization() {
        let mut dq = DynamicQuantization::new();

        // Quantize weights
        let weights = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        dq.quantize_weights("layer1", &weights);

        // Retrieve weights
        let retrieved = dq.get_weights("layer1").unwrap();
        assert_eq!(retrieved.shape().dims(), weights.shape().dims());

        // Quantize activation
        let activation = Tensor::from_slice(&[0.5f32, 1.5, 2.5], &[3]).unwrap();
        let q_activation = dq.quantize_activation(&activation);

        assert_eq!(q_activation.shape, vec![3]);
        assert_eq!(q_activation.scales.len(), 1);  // Per-tensor for activations
    }
}