kizzasi_core/quantization.rs

//! # Dynamic Quantization
//!
//! Post-training quantization for model compression and faster inference.
//!
//! ## Features
//!
//! - **INT8 Quantization**: 8-bit integer quantization with calibration
//! - **INT4 Quantization**: 4-bit quantization for extreme compression
//! - **Per-Tensor Quantization**: Single scale/zero-point per tensor
//! - **Per-Channel Quantization**: Separate scale/zero-point per channel
//! - **Dynamic Range**: Automatic dynamic range calculation
//! - **Calibration**: Statistics collection for better quantization
//!
//! ## References
//!
//! - "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference"
//! - "ZeroQuant: Efficient and Affordable Post-Training Quantization for Large-Scale Transformers"

use crate::{CoreError, CoreResult};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use serde::{Deserialize, Serialize};

/// Quantization data type
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationType {
    /// 8-bit signed integer
    INT8,
    /// 4-bit signed integer
    INT4,
    /// 16-bit floating point
    FP16,
}

/// Quantization scheme
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// Per-tensor quantization (single scale and zero-point)
    PerTensor,
    /// Per-channel quantization (separate scale and zero-point per output channel)
    PerChannel,
}

/// Quantization parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Quantization type
    pub qtype: QuantizationType,
    /// Quantization scheme
    pub scheme: QuantizationScheme,
    /// Scale factors (per-tensor or per-channel)
    pub scales: Vec<f32>,
    /// Zero points (per-tensor or per-channel)
    pub zero_points: Vec<i32>,
    /// Original shape
    pub shape: Vec<usize>,
}

impl QuantizationParams {
    /// Create new quantization parameters
    pub fn new(
        qtype: QuantizationType,
        scheme: QuantizationScheme,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
        shape: Vec<usize>,
    ) -> Self {
        Self {
            qtype,
            scheme,
            scales,
            zero_points,
            shape,
        }
    }

    /// Get quantization range
    pub fn qrange(&self) -> (i32, i32) {
        match self.qtype {
            QuantizationType::INT8 => (-128, 127),
            QuantizationType::INT4 => (-8, 7),
            QuantizationType::FP16 => (0, 0), // Not applicable for FP16
        }
    }

    /// Validate parameters
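    ///
    /// Checks that the number of scales and zero-points matches the scheme;
    /// returns `CoreError::InvalidConfig` otherwise.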
    pub fn validate(&self) -> CoreResult<()> {
        match self.scheme {
            QuantizationScheme::PerTensor => {
                if self.scales.len() != 1 || self.zero_points.len() != 1 {
                    return Err(CoreError::InvalidConfig(
                        "PerTensor scheme requires exactly 1 scale and zero-point".into(),
                    ));
                }
            }
            QuantizationScheme::PerChannel => {
                if self.shape.is_empty() {
                    return Err(CoreError::InvalidConfig(
                        "PerChannel scheme requires shape information".into(),
                    ));
                }
                let num_channels = self.shape[0];
                if self.scales.len() != num_channels || self.zero_points.len() != num_channels {
                    return Err(CoreError::InvalidConfig(format!(
                        "PerChannel scheme requires {} scales and zero-points, got {} and {}",
                        num_channels,
                        self.scales.len(),
                        self.zero_points.len()
                    )));
                }
            }
        }
        Ok(())
    }
}

/// Quantized tensor representation
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized data (one `i8` per element; INT4 values are not bit-packed)
    pub data: Vec<i8>,
    /// Quantization parameters
    pub params: QuantizationParams,
}

impl QuantizedTensor {
    /// Create a new quantized tensor
    pub fn new(data: Vec<i8>, params: QuantizationParams) -> CoreResult<Self> {
        params.validate()?;
        Ok(Self { data, params })
    }

    /// Dequantize to f32 array
    pub fn dequantize_1d(&self) -> CoreResult<Array1<f32>> {
        if self.params.shape.len() != 1 {
            return Err(CoreError::InvalidConfig(
                "Expected 1D tensor for dequantize_1d".into(),
            ));
        }

        let size = self.params.shape[0];
        let mut result = Array1::zeros(size);

        // A 1D tensor always carries a single scale/zero-point pair
        // (per-channel is folded to per-tensor at quantization time),
        // so both schemes dequantize identically.
        let scale = self.params.scales[0];
        let zero_point = self.params.zero_points[0];

        for (i, &q_val) in self.data.iter().enumerate() {
            result[i] = (q_val as i32 - zero_point) as f32 * scale;
        }

        Ok(result)
    }

    /// Dequantize to f32 2D array
    pub fn dequantize_2d(&self) -> CoreResult<Array2<f32>> {
        if self.params.shape.len() != 2 {
            return Err(CoreError::InvalidConfig(
                "Expected 2D tensor for dequantize_2d".into(),
            ));
        }

        let rows = self.params.shape[0];
        let cols = self.params.shape[1];
        let mut result = Array2::zeros((rows, cols));

        match self.params.scheme {
            QuantizationScheme::PerTensor => {
                let scale = self.params.scales[0];
                let zero_point = self.params.zero_points[0];

                for i in 0..rows {
                    for j in 0..cols {
                        let idx = i * cols + j;
                        let q_val = self.data[idx];
                        result[[i, j]] = (q_val as i32 - zero_point) as f32 * scale;
                    }
                }
            }
            QuantizationScheme::PerChannel => {
                // Per-channel: one scale/zero-point per output channel (row)
                for i in 0..rows {
                    let scale = self.params.scales[i];
                    let zero_point = self.params.zero_points[i];

                    for j in 0..cols {
                        let idx = i * cols + j;
                        let q_val = self.data[idx];
                        result[[i, j]] = (q_val as i32 - zero_point) as f32 * scale;
                    }
                }
            }
        }

        Ok(result)
    }

    /// Get compression ratio
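    ///
    /// The baseline assumes the original values were `f32`. Note that INT4
    /// values are stored one per `i8` (not bit-packed), so the reported
    /// ratio for INT4 matches INT8 rather than the theoretical ~8x.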
    pub fn compression_ratio(&self) -> f32 {
        let original_size = self.data.len() * std::mem::size_of::<f32>();
        let quantized_size = self.data.len() * std::mem::size_of::<i8>()
            + self.params.scales.len() * std::mem::size_of::<f32>()
            + self.params.zero_points.len() * std::mem::size_of::<i32>();
        original_size as f32 / quantized_size as f32
    }
}

/// Dynamic quantizer
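///
/// Implements asymmetric affine quantization:
/// `q = clamp(round(x / scale) + zero_point, qmin, qmax)`,
/// with reconstruction `x ≈ (q - zero_point) * scale`.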
pub struct DynamicQuantizer {
    /// Quantization type
    qtype: QuantizationType,
    /// Quantization scheme
    scheme: QuantizationScheme,
}

impl DynamicQuantizer {
    /// Create a new dynamic quantizer
    pub fn new(qtype: QuantizationType, scheme: QuantizationScheme) -> Self {
        Self { qtype, scheme }
    }

    /// Create INT8 per-tensor quantizer
    pub fn int8_per_tensor() -> Self {
        Self::new(QuantizationType::INT8, QuantizationScheme::PerTensor)
    }

    /// Create INT8 per-channel quantizer
    pub fn int8_per_channel() -> Self {
        Self::new(QuantizationType::INT8, QuantizationScheme::PerChannel)
    }

    /// Create INT4 per-channel quantizer
    pub fn int4_per_channel() -> Self {
        Self::new(QuantizationType::INT4, QuantizationScheme::PerChannel)
    }

    /// Quantize a 1D array
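    ///
    /// Scale and zero-point are derived from the observed dynamic range:
    /// `scale = (max - min) / (qmax - qmin)` and
    /// `zero_point = qmin - round(min / scale)`.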
    pub fn quantize_1d(&self, data: &Array1<f32>) -> CoreResult<QuantizedTensor> {
        // FP16 has no integer range; only INT8/INT4 are handled here.
        if self.qtype == QuantizationType::FP16 {
            return Err(CoreError::InvalidConfig(
                "FP16 is not supported by integer quantization".into(),
            ));
        }

        let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let (qmin, qmax) = self.get_qrange();

        // Handle the degenerate case where all values are equal: a zero
        // range would yield a zero scale and a division by zero below.
        let range = max_val - min_val;
        let (scale, zero_point) = if range.abs() < 1e-8 {
            (1.0, 0)
        } else {
            let scale = range / (qmax - qmin) as f32;
            (scale, qmin - (min_val / scale).round() as i32)
        };

        let mut quantized = Vec::with_capacity(data.len());
        for &val in data.iter() {
            let q_val = (val / scale).round() as i32 + zero_point;
            let q_val_clamped = q_val.clamp(qmin, qmax);
            quantized.push(q_val_clamped as i8);
        }

        // Per-channel has no meaning for a 1D tensor, so the stored
        // parameters always use the per-tensor scheme (a single scale
        // and zero-point); this also keeps `validate()` satisfied.
        let params = QuantizationParams::new(
            self.qtype,
            QuantizationScheme::PerTensor,
            vec![scale],
            vec![zero_point],
            vec![data.len()],
        );

        QuantizedTensor::new(quantized, params)
    }

    /// Quantize a 2D array
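    ///
    /// For `PerChannel`, each row (axis 0) is treated as one output
    /// channel and receives its own scale and zero-point.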
    pub fn quantize_2d(&self, data: &Array2<f32>) -> CoreResult<QuantizedTensor> {
        // FP16 has no integer range; only INT8/INT4 are handled here.
        if self.qtype == QuantizationType::FP16 {
            return Err(CoreError::InvalidConfig(
                "FP16 is not supported by integer quantization".into(),
            ));
        }

        let (rows, cols) = data.dim();
        let (qmin, qmax) = self.get_qrange();

        match self.scheme {
            QuantizationScheme::PerTensor => {
                // Find global min/max
                let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
                let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

                // Same zero-range guard as the 1D path: a constant tensor
                // would otherwise produce a zero scale and a NaN zero-point.
                let range = max_val - min_val;
                let (scale, zero_point) = if range.abs() < 1e-8 {
                    (1.0, 0)
                } else {
                    let scale = range / (qmax - qmin) as f32;
                    (scale, qmin - (min_val / scale).round() as i32)
                };

                let mut quantized = Vec::with_capacity(rows * cols);
                for &val in data.iter() {
                    let q_val = (val / scale).round() as i32 + zero_point;
                    let q_val_clamped = q_val.clamp(qmin, qmax);
                    quantized.push(q_val_clamped as i8);
                }

                let params = QuantizationParams::new(
                    self.qtype,
                    self.scheme,
                    vec![scale],
                    vec![zero_point],
                    vec![rows, cols],
                );

                QuantizedTensor::new(quantized, params)
            }
            QuantizationScheme::PerChannel => {
                // Per-channel: compute scale/zero-point for each row
                let mut scales = Vec::with_capacity(rows);
                let mut zero_points = Vec::with_capacity(rows);
                let mut quantized = Vec::with_capacity(rows * cols);

                for row in data.axis_iter(Axis(0)) {
                    let min_val = row.iter().cloned().fold(f32::INFINITY, f32::min);
                    let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

                    // Same zero-range guard, applied per row.
                    let range = max_val - min_val;
                    let (scale, zero_point) = if range.abs() < 1e-8 {
                        (1.0, 0)
                    } else {
                        let scale = range / (qmax - qmin) as f32;
                        (scale, qmin - (min_val / scale).round() as i32)
                    };

                    scales.push(scale);
                    zero_points.push(zero_point);

                    for &val in row.iter() {
                        let q_val = (val / scale).round() as i32 + zero_point;
                        let q_val_clamped = q_val.clamp(qmin, qmax);
                        quantized.push(q_val_clamped as i8);
                    }
                }

                let params = QuantizationParams::new(
                    self.qtype,
                    self.scheme,
                    scales,
                    zero_points,
                    vec![rows, cols],
                );

                QuantizedTensor::new(quantized, params)
            }
        }
    }

    /// Get quantization range
    fn get_qrange(&self) -> (i32, i32) {
        match self.qtype {
            QuantizationType::INT8 => (-128, 127),
            QuantizationType::INT4 => (-8, 7),
            QuantizationType::FP16 => (0, 0),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_types() {
        let qt = QuantizationType::INT8;
        assert_eq!(qt, QuantizationType::INT8);

        let qs = QuantizationScheme::PerTensor;
        assert_eq!(qs, QuantizationScheme::PerTensor);
    }

    #[test]
    fn test_quantization_params() {
        let params = QuantizationParams::new(
            QuantizationType::INT8,
            QuantizationScheme::PerTensor,
            vec![0.1],
            vec![0],
            vec![100],
        );

        assert_eq!(params.qtype, QuantizationType::INT8);
        assert_eq!(params.qrange(), (-128, 127));
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_params_validation() {
        // PerTensor with wrong number of scales
        let mut params = QuantizationParams::new(
            QuantizationType::INT8,
            QuantizationScheme::PerTensor,
            vec![0.1, 0.2],
            vec![0],
            vec![100],
        );
        assert!(params.validate().is_err());

        // PerChannel with wrong number of scales
        params = QuantizationParams::new(
            QuantizationType::INT8,
            QuantizationScheme::PerChannel,
            vec![0.1],
            vec![0, 1],
            vec![2, 100],
        );
        assert!(params.validate().is_err());

        // Correct PerChannel
        params = QuantizationParams::new(
            QuantizationType::INT8,
            QuantizationScheme::PerChannel,
            vec![0.1, 0.2],
            vec![0, 1],
            vec![2, 100],
        );
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_dynamic_quantizer_creation() {
        let quantizer = DynamicQuantizer::int8_per_tensor();
        assert_eq!(quantizer.qtype, QuantizationType::INT8);
        assert_eq!(quantizer.scheme, QuantizationScheme::PerTensor);

        let quantizer = DynamicQuantizer::int4_per_channel();
        assert_eq!(quantizer.qtype, QuantizationType::INT4);
        assert_eq!(quantizer.scheme, QuantizationScheme::PerChannel);
    }

    #[test]
    fn test_quantize_dequantize_1d() {
        let quantizer = DynamicQuantizer::int8_per_tensor();
        let data = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let quantized = quantizer.quantize_1d(&data).unwrap();
        assert_eq!(quantized.data.len(), 5);

        let dequantized = quantized.dequantize_1d().unwrap();
        assert_eq!(dequantized.len(), 5);

        // Check approximate reconstruction
        for i in 0..5 {
            let error = (dequantized[i] - data[i]).abs();
            assert!(error < 0.1, "Reconstruction error too large: {}", error);
        }
    }

    #[test]
    fn test_quantize_dequantize_2d() {
        let quantizer = DynamicQuantizer::int8_per_tensor();
        let data = Array2::from_shape_fn((4, 4), |(i, j)| (i * 4 + j) as f32);

        let quantized = quantizer.quantize_2d(&data).unwrap();
        assert_eq!(quantized.data.len(), 16);

        let dequantized = quantized.dequantize_2d().unwrap();
        assert_eq!(dequantized.shape(), &[4, 4]);

        // Check approximate reconstruction
        for i in 0..4 {
            for j in 0..4 {
                let error = (dequantized[[i, j]] - data[[i, j]]).abs();
                assert!(error < 0.5, "Reconstruction error too large: {}", error);
            }
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        let quantizer = DynamicQuantizer::int8_per_channel();
        let data = Array2::from_shape_fn((3, 4), |(i, j)| (i * 10 + j) as f32);

        let quantized = quantizer.quantize_2d(&data).unwrap();
        assert_eq!(quantized.params.scales.len(), 3); // One per channel (row)
        assert_eq!(quantized.params.zero_points.len(), 3);

        let dequantized = quantized.dequantize_2d().unwrap();
        assert_eq!(dequantized.shape(), &[3, 4]);

        // Check reconstruction
        for i in 0..3 {
            for j in 0..4 {
                let error = (dequantized[[i, j]] - data[[i, j]]).abs();
                assert!(error < 1.0, "Error at [{}, {}]: {}", i, j, error);
            }
        }
    }

    #[test]
    fn test_compression_ratio() {
        let quantizer = DynamicQuantizer::int8_per_tensor();
        let data = Array2::from_shape_fn((100, 100), |(i, j)| (i + j) as f32);

        let quantized = quantizer.quantize_2d(&data).unwrap();
        let ratio = quantized.compression_ratio();

        // INT8 should give ~4x compression (32-bit f32 -> 8-bit i8).
        // The scale/zero-point overhead is only 8 bytes here, so the
        // ratio lands just under 4x.
        assert!(
            ratio > 3.5 && ratio < 4.1,
            "Unexpected compression ratio: {}",
            ratio
        );
    }

    #[test]
    fn test_qrange() {
        let quantizer = DynamicQuantizer::int8_per_tensor();
        assert_eq!(quantizer.get_qrange(), (-128, 127));

        let quantizer = DynamicQuantizer::int4_per_channel();
        assert_eq!(quantizer.get_qrange(), (-8, 7));
    }

    #[test]
    fn test_extreme_values() {
        let quantizer = DynamicQuantizer::int8_per_tensor();
        let data = Array1::from_vec(vec![-100.0, -50.0, 0.0, 50.0, 100.0]);

        let quantized = quantizer.quantize_1d(&data).unwrap();
        let dequantized = quantized.dequantize_1d().unwrap();

        // Extreme values should be preserved reasonably well
        for i in 0..5 {
            let error_pct = ((dequantized[i] - data[i]) / data[i].abs().max(1.0)).abs();
            assert!(
                error_pct < 0.05,
                "Large error at index {}: {}%",
                i,
                error_pct * 100.0
            );
        }
    }
}