oxify_connect_vision/
quantization.rs

//! Model quantization support for performance optimization.
//!
//! This module provides support for quantized ONNX models:
//! - INT8 quantization for CPU inference
//! - FP16 quantization for GPU inference
//! - Dynamic quantization
//! - Quantization configuration and validation
//!
//! Quantized models can significantly reduce:
//! - Model size (2-4x smaller)
//! - Memory usage (2-4x less)
//! - Inference latency (1.5-3x faster)
//!
//! With minimal accuracy loss (typically <1%).
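//!
//! A sketch of typical usage, assuming this module is exposed as
//! `oxify_connect_vision::quantization` (the crate and module paths here are
//! inferred from the file location, not confirmed, so the example is marked
//! `ignore`):
//!
//! ```ignore
//! use oxify_connect_vision::quantization::ModelQuantizer;
//!
//! // Build an INT8 quantizer for CPU inference and preview the expected
//! // payoff for a hypothetical 100 MB model before doing any real work.
//! let quantizer = ModelQuantizer::int8_cpu().expect("valid config");
//! let benefits = quantizer.estimate_benefits(100 * 1024 * 1024);
//! println!("{}", benefits);
//! ```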

use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use tracing::{debug, info, warn};

/// Quantization precision level
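///
/// A sketch of inspecting a precision's rough planning numbers (the assertions
/// mirror the tables in the methods below; module paths are assumed, hence
/// `ignore`):
///
/// ```ignore
/// use oxify_connect_vision::quantization::QuantizationPrecision;
///
/// let p = QuantizationPrecision::INT8;
/// assert!(p.is_cpu_suitable());
/// assert!(!p.is_gpu_suitable());
/// assert_eq!(p.size_reduction_factor(), 4.0); // ~4x smaller than FP32
/// ```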
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum QuantizationPrecision {
    /// Full precision (FP32) - no quantization
    #[default]
    FP32,
    /// Half precision (FP16) - good for GPU
    FP16,
    /// 8-bit integer (INT8) - good for CPU
    INT8,
    /// Mixed precision (automatic selection)
    Mixed,
}

impl QuantizationPrecision {
    /// Get the size reduction factor compared to FP32
    pub fn size_reduction_factor(&self) -> f32 {
        match self {
            Self::FP32 => 1.0,
            Self::FP16 => 2.0,
            Self::INT8 => 4.0,
            Self::Mixed => 2.5, // Approximate
        }
    }

    /// Get the expected speedup factor compared to FP32
    pub fn speedup_factor(&self) -> f32 {
        match self {
            Self::FP32 => 1.0,
            Self::FP16 => 1.5,
            Self::INT8 => 2.5,
            Self::Mixed => 2.0, // Approximate
        }
    }

    /// Get the typical accuracy loss percentage
    pub fn accuracy_loss(&self) -> f32 {
        match self {
            Self::FP32 => 0.0,
            Self::FP16 => 0.1,  // ~0.1% loss
            Self::INT8 => 0.5,  // ~0.5% loss
            Self::Mixed => 0.3, // ~0.3% loss
        }
    }

    /// Check if this precision is suitable for GPU
    pub fn is_gpu_suitable(&self) -> bool {
        matches!(self, Self::FP16 | Self::Mixed)
    }

    /// Check if this precision is suitable for CPU
    pub fn is_cpu_suitable(&self) -> bool {
        matches!(self, Self::INT8 | Self::Mixed | Self::FP32)
    }
}

impl std::fmt::Display for QuantizationPrecision {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::FP32 => write!(f, "FP32"),
            Self::FP16 => write!(f, "FP16"),
            Self::INT8 => write!(f, "INT8"),
            Self::Mixed => write!(f, "Mixed"),
        }
    }
}

/// Quantization method
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum QuantizationMethod {
    /// Static quantization (requires calibration data)
    Static,
    /// Dynamic quantization (no calibration needed)
    #[default]
    Dynamic,
    /// Quantization-aware training (QAT)
    QAT,
}

/// Configuration for model quantization
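///
/// A sketch of building a config with the fluent helpers (the layer name and
/// threshold below are placeholders, not values taken from a real model):
///
/// ```ignore
/// use oxify_connect_vision::quantization::QuantizationConfig;
///
/// let config = QuantizationConfig::int8_cpu()
///     .exclude_layer("first_conv")  // keep a sensitive layer in full precision
///     .with_min_accuracy(0.985);    // fail if accuracy drops below 98.5%
/// assert!(config.validate().is_ok());
/// ```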
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Quantization precision
    pub precision: QuantizationPrecision,

    /// Quantization method
    pub method: QuantizationMethod,

    /// Whether to use symmetric quantization
    pub symmetric: bool,

    /// Whether to quantize per-channel
    pub per_channel: bool,

    /// Layers to exclude from quantization (e.g., first/last layers)
    pub exclude_layers: Vec<String>,

    /// Calibration data size (for static quantization)
    pub calibration_size: Option<usize>,

    /// Target accuracy threshold (fail if accuracy drops more than this)
    pub min_accuracy: Option<f32>,
}

impl QuantizationConfig {
    /// Create a new quantization config with defaults
    pub fn new(precision: QuantizationPrecision) -> Self {
        Self {
            precision,
            method: QuantizationMethod::Dynamic,
            symmetric: true,
            per_channel: true,
            exclude_layers: vec![],
            calibration_size: None,
            min_accuracy: None,
        }
    }

    /// Create a config for INT8 CPU inference
    pub fn int8_cpu() -> Self {
        Self {
            precision: QuantizationPrecision::INT8,
            method: QuantizationMethod::Dynamic,
            symmetric: true,
            per_channel: true,
            exclude_layers: vec![],
            calibration_size: None,
            min_accuracy: Some(0.99), // Allow 1% accuracy loss
        }
    }

    /// Create a config for FP16 GPU inference
    pub fn fp16_gpu() -> Self {
        Self {
            precision: QuantizationPrecision::FP16,
            method: QuantizationMethod::Dynamic,
            symmetric: false,
            per_channel: false,
            exclude_layers: vec![],
            calibration_size: None,
            min_accuracy: Some(0.999), // Allow 0.1% accuracy loss
        }
    }

    /// Create a config for static quantization
    pub fn static_quantization(precision: QuantizationPrecision, calibration_size: usize) -> Self {
        Self {
            precision,
            method: QuantizationMethod::Static,
            symmetric: true,
            per_channel: true,
            exclude_layers: vec![],
            calibration_size: Some(calibration_size),
            min_accuracy: Some(0.98),
        }
    }

    /// Exclude specific layers from quantization
    pub fn exclude_layer(mut self, layer_name: impl Into<String>) -> Self {
        self.exclude_layers.push(layer_name.into());
        self
    }

    /// Set minimum accuracy threshold
    pub fn with_min_accuracy(mut self, accuracy: f32) -> Self {
        self.min_accuracy = Some(accuracy);
        self
    }

    /// Validate the configuration
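    ///
    /// A sketch of the failure case: static quantization without a
    /// calibration size is rejected.
    ///
    /// ```ignore
    /// use oxify_connect_vision::quantization::{QuantizationConfig, QuantizationMethod};
    ///
    /// let mut config = QuantizationConfig::int8_cpu();
    /// config.method = QuantizationMethod::Static;
    /// config.calibration_size = None;
    /// assert!(config.validate().is_err());
    /// ```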
    pub fn validate(&self) -> Result<(), String> {
        // Check if calibration is required but not provided
        if self.method == QuantizationMethod::Static && self.calibration_size.is_none() {
            return Err("Static quantization requires calibration_size".to_string());
        }

        // Check minimum accuracy
        if let Some(min_acc) = self.min_accuracy {
            if !(0.0..=1.0).contains(&min_acc) {
                return Err("min_accuracy must be between 0.0 and 1.0".to_string());
            }
        }

        Ok(())
    }
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self::new(QuantizationPrecision::FP32)
    }
}

/// Information about a quantized model
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedModelInfo {
    /// Original model path
    pub original_path: PathBuf,

    /// Quantized model path
    pub quantized_path: PathBuf,

    /// Quantization configuration used
    pub config: QuantizationConfig,

    /// Original model size in bytes
    pub original_size: u64,

    /// Quantized model size in bytes
    pub quantized_size: u64,

    /// Size reduction ratio
    pub size_reduction: f32,

    /// Measured speedup (if available)
    pub speedup: Option<f32>,

    /// Accuracy on validation set (if available)
    pub accuracy: Option<f32>,
}

impl QuantizedModelInfo {
    /// Create a new quantized model info
    pub fn new(
        original_path: PathBuf,
        quantized_path: PathBuf,
        config: QuantizationConfig,
    ) -> Self {
        Self {
            original_path,
            quantized_path,
            config,
            original_size: 0,
            quantized_size: 0,
            size_reduction: 0.0,
            speedup: None,
            accuracy: None,
        }
    }

    /// Set the model sizes
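    ///
    /// `size_reduction` is computed as `original / quantized`; a sketch with
    /// made-up byte counts:
    ///
    /// ```ignore
    /// use std::path::PathBuf;
    /// use oxify_connect_vision::quantization::{QuantizationConfig, QuantizedModelInfo};
    ///
    /// let info = QuantizedModelInfo::new(
    ///     PathBuf::from("model.onnx"),
    ///     PathBuf::from("model_int8.onnx"),
    ///     QuantizationConfig::int8_cpu(),
    /// )
    /// .with_sizes(1_000_000, 250_000);
    /// assert_eq!(info.size_reduction, 4.0);
    /// ```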
    pub fn with_sizes(mut self, original: u64, quantized: u64) -> Self {
        self.original_size = original;
        self.quantized_size = quantized;
        self.size_reduction = original as f32 / quantized as f32;
        self
    }

    /// Set the speedup factor
    pub fn with_speedup(mut self, speedup: f32) -> Self {
        self.speedup = Some(speedup);
        self
    }

    /// Set the accuracy
    pub fn with_accuracy(mut self, accuracy: f32) -> Self {
        self.accuracy = Some(accuracy);
        self
    }

    /// Get a summary string
    pub fn summary(&self) -> String {
        format!(
            "Quantization: {} -> {:.2}x smaller",
            self.config.precision, self.size_reduction
        )
    }
}

/// Model quantizer (placeholder for actual quantization logic)
pub struct ModelQuantizer {
    config: QuantizationConfig,
}

impl ModelQuantizer {
    /// Create a new quantizer with the given configuration
    pub fn new(config: QuantizationConfig) -> Result<Self, String> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Create a quantizer for INT8 CPU inference
    pub fn int8_cpu() -> Result<Self, String> {
        Self::new(QuantizationConfig::int8_cpu())
    }

    /// Create a quantizer for FP16 GPU inference
    pub fn fp16_gpu() -> Result<Self, String> {
        Self::new(QuantizationConfig::fp16_gpu())
    }

    /// Quantize a model
    ///
    /// Note: This is a placeholder. Actual quantization would require:
    /// 1. Loading the ONNX model
    /// 2. Converting weights to lower precision
    /// 3. Optionally calibrating with sample data
    /// 4. Saving the quantized model
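    ///
    /// A sketch of calling the placeholder (the paths are illustrative; only
    /// the input file must exist, and no quantized file is actually written
    /// yet):
    ///
    /// ```ignore
    /// use oxify_connect_vision::quantization::ModelQuantizer;
    ///
    /// let quantizer = ModelQuantizer::int8_cpu().expect("valid config");
    /// let info = quantizer
    ///     .quantize("models/detector.onnx", "models/detector_int8.onnx")
    ///     .expect("input model exists");
    /// println!("{}", info.summary());
    /// ```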
    pub fn quantize(
        &self,
        model_path: impl AsRef<Path>,
        output_path: impl AsRef<Path>,
    ) -> Result<QuantizedModelInfo, String> {
        let model_path = model_path.as_ref();
        let output_path = output_path.as_ref();

        info!(
            "Quantizing model {} to {} precision",
            model_path.display(),
            self.config.precision
        );

        // Validate input
        if !model_path.exists() {
            return Err(format!("Model not found: {}", model_path.display()));
        }

        // Get original model size
        let original_size = std::fs::metadata(model_path)
            .map_err(|e| format!("Failed to get model size: {}", e))?
            .len();

        debug!("Original model size: {} bytes", original_size);

        // In a real implementation, this would:
        // 1. Load the ONNX model with ort
        // 2. Apply quantization transformations
        // 3. Save the quantized model
        //
        // For now, we just log a warning
        warn!(
            "Model quantization is a placeholder. \
             Actual quantization requires ONNX Runtime quantization tools."
        );

        // Estimate quantized size based on precision
        let estimated_size =
            (original_size as f32 / self.config.precision.size_reduction_factor()) as u64;

        let info = QuantizedModelInfo::new(
            model_path.to_path_buf(),
            output_path.to_path_buf(),
            self.config.clone(),
        )
        .with_sizes(original_size, estimated_size);

        Ok(info)
    }

    /// Get the configuration
    pub fn config(&self) -> &QuantizationConfig {
        &self.config
    }

    /// Check if a model is already quantized
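    ///
    /// This is a filename heuristic only; it does not inspect the model's
    /// tensors. For example:
    ///
    /// ```ignore
    /// use oxify_connect_vision::quantization::ModelQuantizer;
    ///
    /// assert!(ModelQuantizer::is_quantized("detector_int8.onnx"));
    /// assert!(!ModelQuantizer::is_quantized("detector.onnx"));
    /// ```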
    pub fn is_quantized<P: AsRef<Path>>(model_path: P) -> bool {
        let path = model_path.as_ref();
        let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");

        // Check for common quantization indicators in filename
        file_name.contains("int8")
            || file_name.contains("fp16")
            || file_name.contains("quantized")
            || file_name.contains("quant")
    }

    /// Estimate the benefits of quantization
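    ///
    /// A sketch of using the estimate to gate a deployment decision (the
    /// 20 MB budget is illustrative):
    ///
    /// ```ignore
    /// use oxify_connect_vision::quantization::ModelQuantizer;
    ///
    /// let quantizer = ModelQuantizer::int8_cpu().expect("valid config");
    /// let benefits = quantizer.estimate_benefits(50 * 1024 * 1024);
    /// if benefits.quantized_size_mb <= 20.0 {
    ///     println!("fits the on-device budget: {:.2} MB", benefits.quantized_size_mb);
    /// }
    /// ```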
    pub fn estimate_benefits(&self, model_size_bytes: u64) -> QuantizationBenefits {
        let size_reduction = self.config.precision.size_reduction_factor();
        let speedup = self.config.precision.speedup_factor();
        let accuracy_loss = self.config.precision.accuracy_loss();

        QuantizationBenefits {
            original_size_mb: model_size_bytes as f32 / (1024.0 * 1024.0),
            quantized_size_mb: model_size_bytes as f32 / (1024.0 * 1024.0) / size_reduction,
            size_reduction_factor: size_reduction,
            expected_speedup: speedup,
            expected_accuracy_loss: accuracy_loss,
        }
    }
}

/// Estimated benefits of quantization
#[derive(Debug, Clone)]
pub struct QuantizationBenefits {
    pub original_size_mb: f32,
    pub quantized_size_mb: f32,
    pub size_reduction_factor: f32,
    pub expected_speedup: f32,
    pub expected_accuracy_loss: f32,
}

impl std::fmt::Display for QuantizationBenefits {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Quantization Benefits Estimate:")?;
        writeln!(f, "  Original Size:     {:.2} MB", self.original_size_mb)?;
        writeln!(f, "  Quantized Size:    {:.2} MB", self.quantized_size_mb)?;
        writeln!(f, "  Size Reduction:    {:.2}x", self.size_reduction_factor)?;
        writeln!(f, "  Expected Speedup:  {:.2}x", self.expected_speedup)?;
        writeln!(
            f,
            "  Accuracy Loss:     ~{:.1}%",
            self.expected_accuracy_loss
        )?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_quantization_precision() {
        assert_eq!(QuantizationPrecision::FP32.size_reduction_factor(), 1.0);
        assert_eq!(QuantizationPrecision::FP16.size_reduction_factor(), 2.0);
        assert_eq!(QuantizationPrecision::INT8.size_reduction_factor(), 4.0);
    }

    #[test]
    fn test_quantization_precision_gpu_suitable() {
        assert!(QuantizationPrecision::FP16.is_gpu_suitable());
        assert!(!QuantizationPrecision::INT8.is_gpu_suitable());
    }

    #[test]
    fn test_quantization_precision_cpu_suitable() {
        assert!(QuantizationPrecision::INT8.is_cpu_suitable());
        assert!(QuantizationPrecision::FP32.is_cpu_suitable());
    }

    #[test]
    fn test_quantization_config_int8() {
        let config = QuantizationConfig::int8_cpu();
        assert_eq!(config.precision, QuantizationPrecision::INT8);
        assert_eq!(config.method, QuantizationMethod::Dynamic);
        assert!(config.symmetric);
    }

    #[test]
    fn test_quantization_config_fp16() {
        let config = QuantizationConfig::fp16_gpu();
        assert_eq!(config.precision, QuantizationPrecision::FP16);
        assert_eq!(config.method, QuantizationMethod::Dynamic);
    }

    #[test]
    fn test_quantization_config_exclude_layer() {
        let config = QuantizationConfig::int8_cpu().exclude_layer("input");
        assert_eq!(config.exclude_layers, vec!["input"]);
    }

    #[test]
    fn test_quantization_config_validation() {
        let config = QuantizationConfig::int8_cpu();
        assert!(config.validate().is_ok());

        let invalid = QuantizationConfig {
            method: QuantizationMethod::Static,
            calibration_size: None,
            ..QuantizationConfig::int8_cpu()
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_model_quantizer_creation() {
        let config = QuantizationConfig::int8_cpu();
        let quantizer = ModelQuantizer::new(config);
        assert!(quantizer.is_ok());
    }

    #[test]
    fn test_model_quantizer_int8() {
        let quantizer = ModelQuantizer::int8_cpu();
        assert!(quantizer.is_ok());
        assert_eq!(
            quantizer.unwrap().config().precision,
            QuantizationPrecision::INT8
        );
    }

    #[test]
    fn test_model_quantizer_fp16() {
        let quantizer = ModelQuantizer::fp16_gpu();
        assert!(quantizer.is_ok());
        assert_eq!(
            quantizer.unwrap().config().precision,
            QuantizationPrecision::FP16
        );
    }

    #[test]
    fn test_is_quantized() {
        assert!(ModelQuantizer::is_quantized("model_int8.onnx"));
        assert!(ModelQuantizer::is_quantized("model_fp16.onnx"));
        assert!(ModelQuantizer::is_quantized("model_quantized.onnx"));
        assert!(!ModelQuantizer::is_quantized("model.onnx"));
    }

    #[test]
    fn test_quantize_model_not_found() {
        let quantizer = ModelQuantizer::int8_cpu().unwrap();
        let result = quantizer.quantize("nonexistent.onnx", "output.onnx");
        assert!(result.is_err());
    }

    #[test]
    fn test_quantize_model() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"fake model data").unwrap();

        let quantizer = ModelQuantizer::int8_cpu().unwrap();
        let output_path = PathBuf::from("output.onnx");
        let result = quantizer.quantize(temp_file.path(), &output_path);

        assert!(result.is_ok());
        let info = result.unwrap();
        assert!(info.size_reduction > 0.0);
    }

    #[test]
    fn test_quantized_model_info() {
        let info = QuantizedModelInfo::new(
            PathBuf::from("input.onnx"),
            PathBuf::from("output.onnx"),
            QuantizationConfig::int8_cpu(),
        )
        .with_sizes(1000, 250);

        assert_eq!(info.size_reduction, 4.0);
        assert!(info.summary().contains("INT8"));
    }

    #[test]
    fn test_estimate_benefits() {
        let quantizer = ModelQuantizer::int8_cpu().unwrap();
        let benefits = quantizer.estimate_benefits(100 * 1024 * 1024); // 100 MB

        assert!(benefits.original_size_mb > 99.0);
        assert!(benefits.quantized_size_mb < benefits.original_size_mb);
        assert!(benefits.size_reduction_factor > 1.0);
    }

    #[test]
    fn test_quantization_precision_display() {
        assert_eq!(format!("{}", QuantizationPrecision::FP32), "FP32");
        assert_eq!(format!("{}", QuantizationPrecision::FP16), "FP16");
        assert_eq!(format!("{}", QuantizationPrecision::INT8), "INT8");
    }

    #[test]
    fn test_quantization_method_default() {
        let method = QuantizationMethod::default();
        assert_eq!(method, QuantizationMethod::Dynamic);
    }

    #[test]
    fn test_benefits_display() {
        let benefits = QuantizationBenefits {
            original_size_mb: 100.0,
            quantized_size_mb: 25.0,
            size_reduction_factor: 4.0,
            expected_speedup: 2.5,
            expected_accuracy_loss: 0.5,
        };

        let display = format!("{}", benefits);
        assert!(display.contains("100.00 MB"));
        assert!(display.contains("25.00 MB"));
    }
}