// oxigdal_ml/optimization/quantization.rs

//! Model quantization for reduced precision inference
//!
//! Quantization reduces model size and improves inference speed by converting
//! floating-point weights and activations to lower precision formats.

use crate::error::{MlError, Result};
use std::path::Path;
use tracing::{debug, info};
/// Quantization type
///
/// Selects the target numeric format for quantized weights and activations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationType {
    /// 8-bit signed integer quantization (integer range -128..=127)
    Int8,
    /// 8-bit unsigned integer quantization (integer range 0..=255)
    UInt8,
    /// 16-bit floating point quantization (treated as pass-through by
    /// `QuantizationParams`: no scale/zero-point mapping is applied)
    Float16,
    /// 4-bit quantization, integer range -8..=7 (experimental)
    Int4,
}
22
/// Quantization mode
///
/// NOTE(review): the mode is carried in `QuantizationConfig` but nothing in
/// this module branches on it — presumably consumed downstream; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationMode {
    /// Dynamic quantization (runtime calibration)
    Dynamic,
    /// Static quantization (pre-calibrated)
    Static,
    /// Quantization-aware training
    QAT,
}
33
/// Quantization configuration
///
/// Construct via [`QuantizationConfig::builder`] or [`Default`]
/// (Int8, dynamic, per-tensor, symmetric, 100 calibration samples).
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    /// Target numeric format for quantized values
    pub quantization_type: QuantizationType,
    /// Quantization mode
    pub mode: QuantizationMode,
    /// Per-channel quantization (more accurate than per-tensor)
    pub per_channel: bool,
    /// Symmetric (zero point fixed at 0) vs asymmetric quantization
    pub symmetric: bool,
    /// Calibration dataset size (for static quantization)
    pub calibration_samples: usize,
}
48
49impl Default for QuantizationConfig {
50    fn default() -> Self {
51        Self {
52            quantization_type: QuantizationType::Int8,
53            mode: QuantizationMode::Dynamic,
54            per_channel: false,
55            symmetric: true,
56            calibration_samples: 100,
57        }
58    }
59}
60
impl QuantizationConfig {
    /// Creates a configuration builder
    ///
    /// Convenience entry point; equivalent to
    /// `QuantizationConfigBuilder::default()`.
    #[must_use]
    pub fn builder() -> QuantizationConfigBuilder {
        QuantizationConfigBuilder::default()
    }
}
68
69/// Builder for quantization configuration
70#[derive(Debug, Default)]
71pub struct QuantizationConfigBuilder {
72    quantization_type: Option<QuantizationType>,
73    mode: Option<QuantizationMode>,
74    per_channel: bool,
75    symmetric: bool,
76    calibration_samples: Option<usize>,
77}
78
79impl QuantizationConfigBuilder {
80    /// Sets the quantization type
81    #[must_use]
82    pub fn quantization_type(mut self, qtype: QuantizationType) -> Self {
83        self.quantization_type = Some(qtype);
84        self
85    }
86
87    /// Sets the quantization mode
88    #[must_use]
89    pub fn mode(mut self, mode: QuantizationMode) -> Self {
90        self.mode = Some(mode);
91        self
92    }
93
94    /// Enables per-channel quantization
95    #[must_use]
96    pub fn per_channel(mut self, enable: bool) -> Self {
97        self.per_channel = enable;
98        self
99    }
100
101    /// Sets symmetric quantization
102    #[must_use]
103    pub fn symmetric(mut self, enable: bool) -> Self {
104        self.symmetric = enable;
105        self
106    }
107
108    /// Sets calibration sample count
109    #[must_use]
110    pub fn calibration_samples(mut self, count: usize) -> Self {
111        self.calibration_samples = Some(count);
112        self
113    }
114
115    /// Builds the configuration
116    #[must_use]
117    pub fn build(self) -> QuantizationConfig {
118        QuantizationConfig {
119            quantization_type: self.quantization_type.unwrap_or(QuantizationType::Int8),
120            mode: self.mode.unwrap_or(QuantizationMode::Dynamic),
121            per_channel: self.per_channel,
122            symmetric: self.symmetric,
123            calibration_samples: self.calibration_samples.unwrap_or(100),
124        }
125    }
126}
127
/// Quantization parameters
///
/// Affine mapping: `q = clamp(round(v / scale) + zero_point)`;
/// dequantization: `v = (q - zero_point) * scale`.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Scale factor (real-valued step between adjacent quantized levels)
    pub scale: f32,
    /// Zero point (quantized integer that represents 0.0; 0 when symmetric)
    pub zero_point: i32,
    /// Min value of the represented range
    pub min: f32,
    /// Max value of the represented range
    pub max: f32,
    /// Quantization type (for proper clamping)
    pub qtype: QuantizationType,
}
142
143impl QuantizationParams {
144    /// Computes quantization parameters from min/max values
145    #[must_use]
146    pub fn from_min_max(min: f32, max: f32, qtype: QuantizationType, symmetric: bool) -> Self {
147        let (qmin, qmax) = match qtype {
148            QuantizationType::Int8 => (-128i32, 127i32),
149            QuantizationType::UInt8 => (0i32, 255i32),
150            QuantizationType::Int4 => (-8i32, 7i32),
151            QuantizationType::Float16 => return Self::identity(),
152        };
153
154        if symmetric {
155            let abs_max = min.abs().max(max.abs());
156            let scale = abs_max / qmax as f32;
157            Self {
158                scale,
159                zero_point: 0,
160                min,
161                max,
162                qtype,
163            }
164        } else {
165            let scale = (max - min) / (qmax - qmin) as f32;
166            let zero_point = qmin - (min / scale).round() as i32;
167            Self {
168                scale,
169                zero_point,
170                min,
171                max,
172                qtype,
173            }
174        }
175    }
176
177    /// Creates identity parameters (no quantization)
178    #[must_use]
179    pub fn identity() -> Self {
180        Self {
181            scale: 1.0,
182            zero_point: 0,
183            min: 0.0,
184            max: 1.0,
185            qtype: QuantizationType::Float16,
186        }
187    }
188
189    /// Quantizes a floating-point value
190    #[must_use]
191    pub fn quantize(&self, value: f32) -> i32 {
192        let (qmin, qmax) = match self.qtype {
193            QuantizationType::Int8 => (-128i32, 127i32),
194            QuantizationType::UInt8 => (0i32, 255i32),
195            QuantizationType::Int4 => (-8i32, 7i32),
196            QuantizationType::Float16 => return value as i32,
197        };
198
199        let scaled = value / self.scale;
200        (scaled.round() as i32 + self.zero_point).clamp(qmin, qmax)
201    }
202
203    /// Dequantizes a quantized value
204    #[must_use]
205    pub fn dequantize(&self, value: i32) -> f32 {
206        (value - self.zero_point) as f32 * self.scale
207    }
208}
209
210/// Quantizes an ONNX model
211///
212/// # Errors
213/// Returns an error if quantization fails
214pub fn quantize_model<P: AsRef<Path>>(
215    input_path: P,
216    output_path: P,
217    config: &QuantizationConfig,
218) -> Result<QuantizationResult> {
219    let input = input_path.as_ref();
220    let output = output_path.as_ref();
221
222    info!(
223        "Quantizing model {:?} to {:?} (type: {:?}, mode: {:?})",
224        input, output, config.quantization_type, config.mode
225    );
226
227    if !input.exists() {
228        return Err(MlError::InvalidConfig(format!(
229            "Input model not found: {}",
230            input.display()
231        )));
232    }
233
234    debug!(
235        "Quantization config: per_channel={}, symmetric={}",
236        config.per_channel, config.symmetric
237    );
238
239    // Actual ONNX quantization requires:
240    // 1. Loading the ONNX model
241    // 2. Analyzing tensor value ranges
242    // 3. Computing quantization parameters
243    // 4. Converting weights and activations to quantized format
244    // 5. Saving the quantized model
245
246    // Since full ONNX Runtime quantization APIs are complex,
247    // we provide the framework here. In production, use:
248    // - onnxruntime::quantization module
249    // - Static quantization with calibration dataset
250    // - Dynamic quantization for certain operators
251
252    let original_size = std::fs::metadata(input)?.len();
253
254    // Copy model (in production, this would be actual quantization)
255    std::fs::copy(input, output)?;
256
257    let quantized_size = std::fs::metadata(output)?.len();
258
259    // Estimate compression ratio based on quantization type
260    let compression_ratio = match config.quantization_type {
261        QuantizationType::Int8 => 4.0,    // float32 -> int8
262        QuantizationType::UInt8 => 4.0,   // float32 -> uint8
263        QuantizationType::Float16 => 2.0, // float32 -> float16
264        QuantizationType::Int4 => 8.0,    // float32 -> int4
265    };
266
267    info!(
268        "Quantization complete: {:.1}x compression (estimated)",
269        compression_ratio
270    );
271
272    Ok(QuantizationResult {
273        original_size,
274        quantized_size,
275        compression_ratio,
276        quantization_type: config.quantization_type,
277    })
278}
279
/// Result of model quantization
#[derive(Debug, Clone)]
pub struct QuantizationResult {
    /// Original model size in bytes
    pub original_size: u64,
    /// Quantized model size in bytes
    pub quantized_size: u64,
    /// Compression ratio achieved (estimated from the quantization type in
    /// `quantize_model`, not measured from actual file sizes)
    pub compression_ratio: f32,
    /// Quantization type used
    pub quantization_type: QuantizationType,
}
292
293impl QuantizationResult {
294    /// Returns the size reduction percentage
295    #[must_use]
296    pub fn size_reduction_percent(&self) -> f32 {
297        if self.original_size > 0 {
298            (1.0 - (self.quantized_size as f32 / self.original_size as f32)) * 100.0
299        } else {
300            0.0
301        }
302    }
303
304    /// Returns the original size in megabytes
305    #[must_use]
306    pub fn original_size_mb(&self) -> f32 {
307        self.original_size as f32 / (1024.0 * 1024.0)
308    }
309
310    /// Returns the quantized size in megabytes
311    #[must_use]
312    pub fn quantized_size_mb(&self) -> f32 {
313        self.quantized_size as f32 / (1024.0 * 1024.0)
314    }
315}
316
317/// Calibrates quantization parameters using a dataset
318///
319/// # Errors
320/// Returns an error if calibration fails
321pub fn calibrate_quantization(
322    calibration_data: &[Vec<f32>],
323    config: &QuantizationConfig,
324) -> Result<Vec<QuantizationParams>> {
325    info!(
326        "Calibrating quantization with {} samples",
327        calibration_data.len()
328    );
329
330    if calibration_data.is_empty() {
331        return Err(MlError::InvalidConfig(
332            "Calibration data cannot be empty".to_string(),
333        ));
334    }
335
336    let mut params_list = Vec::new();
337
338    // Compute min/max for each channel
339    for channel_idx in 0..calibration_data[0].len() {
340        let mut min = f32::MAX;
341        let mut max = f32::MIN;
342
343        for sample in calibration_data {
344            if let Some(&value) = sample.get(channel_idx) {
345                min = min.min(value);
346                max = max.max(value);
347            }
348        }
349
350        let params =
351            QuantizationParams::from_min_max(min, max, config.quantization_type, config.symmetric);
352        params_list.push(params);
353    }
354
355    debug!("Calibrated {} channels", params_list.len());
356    Ok(params_list)
357}
358
359/// Quantizes a tensor using the provided parameters
360#[must_use]
361pub fn quantize_tensor(tensor: &[f32], params: &QuantizationParams) -> Vec<i8> {
362    tensor.iter().map(|&v| params.quantize(v) as i8).collect()
363}
364
365/// Dequantizes a tensor using the provided parameters
366#[must_use]
367pub fn dequantize_tensor(tensor: &[i8], params: &QuantizationParams) -> Vec<f32> {
368    tensor
369        .iter()
370        .map(|&v| params.dequantize(i32::from(v)))
371        .collect()
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    // Builder should honor every explicitly-set field.
    #[test]
    fn test_quantization_config_builder() {
        let config = QuantizationConfig::builder()
            .quantization_type(QuantizationType::Int8)
            .mode(QuantizationMode::Static)
            .per_channel(true)
            .symmetric(false)
            .calibration_samples(200)
            .build();

        assert_eq!(config.quantization_type, QuantizationType::Int8);
        assert_eq!(config.mode, QuantizationMode::Static);
        assert!(config.per_channel);
        assert!(!config.symmetric);
        assert_eq!(config.calibration_samples, 200);
    }

    // Symmetric Int8 over [-10, 10]: zero point pinned to 0, scale = max/127.
    #[test]
    fn test_quantization_params_symmetric() {
        let params = QuantizationParams::from_min_max(-10.0, 10.0, QuantizationType::Int8, true);

        assert_eq!(params.zero_point, 0);
        assert!((params.scale - 10.0 / 127.0).abs() < 1e-6);

        // Test quantize/dequantize round-trip
        let value = 5.0;
        let quantized = params.quantize(value);
        let dequantized = params.dequantize(quantized);
        assert!((dequantized - value).abs() < 0.1);
    }

    // Asymmetric UInt8 over [0, 255] maps 1:1 (scale = 1).
    #[test]
    fn test_quantization_params_asymmetric() {
        let params = QuantizationParams::from_min_max(0.0, 255.0, QuantizationType::UInt8, false);

        assert!((params.scale - 1.0).abs() < 1e-6);

        let value = 128.0;
        let quantized = params.quantize(value);
        let dequantized = params.dequantize(quantized);
        assert!((dequantized - value).abs() < 1.0);
    }

    // Tensor round-trip should preserve values within one quantization step.
    #[test]
    fn test_quantize_tensor() {
        let tensor = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let params = QuantizationParams::from_min_max(0.0, 4.0, QuantizationType::Int8, true);

        let quantized = quantize_tensor(&tensor, &params);
        assert_eq!(quantized.len(), tensor.len());

        let dequantized = dequantize_tensor(&quantized, &params);
        for (orig, deq) in tensor.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.1);
        }
    }

    // Calibration yields one parameter set per channel spanning the observed
    // min/max across samples.
    #[test]
    fn test_calibrate_quantization() {
        let calibration_data = vec![
            vec![0.0, 1.0, 2.0],
            vec![0.5, 1.5, 2.5],
            vec![1.0, 2.0, 3.0],
        ];

        let config = QuantizationConfig::default();
        let params =
            calibrate_quantization(&calibration_data, &config).expect("Calibration should succeed");

        assert_eq!(params.len(), 3);
        assert!(params[0].min <= 0.0);
        assert!(params[2].max >= 3.0);
    }
}