use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use tracing::{debug, info, warn};

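/// Target numeric precision for model weights and activations.
///
/// Each variant carries rough size, speed, and accuracy trade-offs
/// (see `size_reduction_factor`, `speedup_factor`, and `accuracy_loss`).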
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum QuantizationPrecision {
    /// Full 32-bit floating point (no quantization).
    #[default]
    FP32,
    /// 16-bit floating point.
    FP16,
    /// 8-bit integer.
    INT8,
    /// Mixed precision (a combination of higher- and lower-precision layers).
    Mixed,
}

impl QuantizationPrecision {
    /// Approximate factor by which the model size shrinks at this precision.
    pub fn size_reduction_factor(&self) -> f32 {
        match self {
            Self::FP32 => 1.0,
            Self::FP16 => 2.0,
            Self::INT8 => 4.0,
            Self::Mixed => 2.5,
        }
    }

    /// Approximate inference speedup relative to FP32.
    pub fn speedup_factor(&self) -> f32 {
        match self {
            Self::FP32 => 1.0,
            Self::FP16 => 1.5,
            Self::INT8 => 2.5,
            Self::Mixed => 2.0,
        }
    }

    /// Approximate accuracy loss in percentage points.
    pub fn accuracy_loss(&self) -> f32 {
        match self {
            Self::FP32 => 0.0,
            Self::FP16 => 0.1,
            Self::INT8 => 0.5,
            Self::Mixed => 0.3,
        }
    }

    /// Whether this precision is a good fit for GPU inference.
    pub fn is_gpu_suitable(&self) -> bool {
        matches!(self, Self::FP16 | Self::Mixed)
    }

    /// Whether this precision is a good fit for CPU inference.
    pub fn is_cpu_suitable(&self) -> bool {
        matches!(self, Self::INT8 | Self::Mixed | Self::FP32)
    }
}

impl std::fmt::Display for QuantizationPrecision {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::FP32 => write!(f, "FP32"),
            Self::FP16 => write!(f, "FP16"),
            Self::INT8 => write!(f, "INT8"),
            Self::Mixed => write!(f, "Mixed"),
        }
    }
}

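/// How quantization parameters are determined.
///
/// `Static` calibrates ranges ahead of time and requires a calibration set
/// (see `QuantizationConfig::validate`), `Dynamic` computes ranges at runtime,
/// and `QAT` refers to quantization-aware training.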
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum QuantizationMethod {
    Static,
    #[default]
    Dynamic,
    QAT,
}

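/// Configuration for a quantization run: target precision, method,
/// layer exclusions, and an optional minimum-accuracy requirement.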
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Target precision.
    pub precision: QuantizationPrecision,

    /// Quantization method.
    pub method: QuantizationMethod,

    /// Use symmetric quantization ranges.
    pub symmetric: bool,

    /// Quantize per channel rather than per tensor.
    pub per_channel: bool,

    /// Layer names to leave unquantized.
    pub exclude_layers: Vec<String>,

    /// Number of calibration samples (required for static quantization).
    pub calibration_size: Option<usize>,

    /// Minimum acceptable accuracy (0.0..=1.0) after quantization.
    pub min_accuracy: Option<f32>,
}

impl QuantizationConfig {
    /// Creates a config with the given precision and the default settings
    /// (dynamic, symmetric, per-channel quantization).
    pub fn new(precision: QuantizationPrecision) -> Self {
        Self {
            precision,
            method: QuantizationMethod::Dynamic,
            symmetric: true,
            per_channel: true,
            exclude_layers: vec![],
            calibration_size: None,
            min_accuracy: None,
        }
    }

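    /// Preset: dynamic INT8 quantization tuned for CPU inference.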
    pub fn int8_cpu() -> Self {
        Self {
            precision: QuantizationPrecision::INT8,
            method: QuantizationMethod::Dynamic,
            symmetric: true,
            per_channel: true,
            exclude_layers: vec![],
            calibration_size: None,
            min_accuracy: Some(0.99),
        }
    }

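    /// Preset: FP16 quantization tuned for GPU inference.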
    pub fn fp16_gpu() -> Self {
        Self {
            precision: QuantizationPrecision::FP16,
            method: QuantizationMethod::Dynamic,
            symmetric: false,
            per_channel: false,
            exclude_layers: vec![],
            calibration_size: None,
            min_accuracy: Some(0.999),
        }
    }

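    /// Preset: static quantization with the given calibration set size.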
    pub fn static_quantization(precision: QuantizationPrecision, calibration_size: usize) -> Self {
        Self {
            precision,
            method: QuantizationMethod::Static,
            symmetric: true,
            per_channel: true,
            exclude_layers: vec![],
            calibration_size: Some(calibration_size),
            min_accuracy: Some(0.98),
        }
    }

    /// Adds a layer name to exclude from quantization (builder-style).
    pub fn exclude_layer(mut self, layer_name: impl Into<String>) -> Self {
        self.exclude_layers.push(layer_name.into());
        self
    }

    /// Sets the minimum acceptable accuracy (builder-style).
    pub fn with_min_accuracy(mut self, accuracy: f32) -> Self {
        self.min_accuracy = Some(accuracy);
        self
    }

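    /// Checks internal consistency: static quantization needs a calibration
    /// size, and `min_accuracy` must lie in `0.0..=1.0`.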
    pub fn validate(&self) -> Result<(), String> {
        if self.method == QuantizationMethod::Static && self.calibration_size.is_none() {
            return Err("Static quantization requires calibration_size".to_string());
        }

        if let Some(min_acc) = self.min_accuracy {
            if !(0.0..=1.0).contains(&min_acc) {
                return Err("min_accuracy must be between 0.0 and 1.0".to_string());
            }
        }

        Ok(())
    }
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self::new(QuantizationPrecision::FP32)
    }
}

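/// Metadata describing a quantized model: input/output paths, the config
/// used, and measured or estimated size, speed, and accuracy figures.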
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedModelInfo {
    /// Path to the original model.
    pub original_path: PathBuf,

    /// Path to the quantized model.
    pub quantized_path: PathBuf,

    /// Configuration used for quantization.
    pub config: QuantizationConfig,

    /// Original model size in bytes.
    pub original_size: u64,

    /// Quantized model size in bytes.
    pub quantized_size: u64,

    /// Size reduction factor (original / quantized).
    pub size_reduction: f32,

    /// Measured inference speedup, if benchmarked.
    pub speedup: Option<f32>,

    /// Measured accuracy after quantization, if evaluated.
    pub accuracy: Option<f32>,
}

impl QuantizedModelInfo {
    /// Creates an info record with sizes and metrics left unset.
    pub fn new(
        original_path: PathBuf,
        quantized_path: PathBuf,
        config: QuantizationConfig,
    ) -> Self {
        Self {
            original_path,
            quantized_path,
            config,
            original_size: 0,
            quantized_size: 0,
            size_reduction: 0.0,
            speedup: None,
            accuracy: None,
        }
    }

    /// Records the original and quantized sizes and derives the reduction factor.
    pub fn with_sizes(mut self, original: u64, quantized: u64) -> Self {
        self.original_size = original;
        self.quantized_size = quantized;
        self.size_reduction = original as f32 / quantized as f32;
        self
    }

    /// Records a measured inference speedup.
    pub fn with_speedup(mut self, speedup: f32) -> Self {
        self.speedup = Some(speedup);
        self
    }

    /// Records the measured post-quantization accuracy.
    pub fn with_accuracy(mut self, accuracy: f32) -> Self {
        self.accuracy = Some(accuracy);
        self
    }

    /// One-line human-readable summary of the quantization result.
    pub fn summary(&self) -> String {
        format!(
            "Quantization: {} -> {:.2}x smaller",
            self.config.precision, self.size_reduction
        )
    }
}

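/// Quantizes ONNX models according to a validated [`QuantizationConfig`].
///
/// Note that [`ModelQuantizer::quantize`] currently only estimates the
/// quantized size; it does not yet invoke ONNX Runtime's quantization tools.
///
/// Usage sketch (paths are placeholders, and the `use` path assumes this
/// module is reachable as `crate::quantization`):
///
/// ```ignore
/// use crate::quantization::ModelQuantizer;
///
/// let quantizer = ModelQuantizer::int8_cpu()?;
/// let info = quantizer.quantize("model.onnx", "model_int8.onnx")?;
/// println!("{}", info.summary());
/// ```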
pub struct ModelQuantizer {
    config: QuantizationConfig,
}

impl ModelQuantizer {
    /// Creates a quantizer after validating the config.
    pub fn new(config: QuantizationConfig) -> Result<Self, String> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Convenience constructor using the INT8 CPU preset.
    pub fn int8_cpu() -> Result<Self, String> {
        Self::new(QuantizationConfig::int8_cpu())
    }

    /// Convenience constructor using the FP16 GPU preset.
    pub fn fp16_gpu() -> Result<Self, String> {
        Self::new(QuantizationConfig::fp16_gpu())
    }

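    /// Quantizes the model at `model_path`, returning a [`QuantizedModelInfo`].
    ///
    /// Currently a placeholder: it checks that the model exists, logs a warning,
    /// and estimates the quantized size from the precision's reduction factor
    /// instead of writing a quantized model to `output_path`.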
    pub fn quantize<P: AsRef<Path>>(
        &self,
        model_path: P,
        output_path: P,
    ) -> Result<QuantizedModelInfo, String> {
        let model_path = model_path.as_ref();
        let output_path = output_path.as_ref();

        info!(
            "Quantizing model {} to {} precision",
            model_path.display(),
            self.config.precision
        );

        if !model_path.exists() {
            return Err(format!("Model not found: {}", model_path.display()));
        }

        let original_size = std::fs::metadata(model_path)
            .map_err(|e| format!("Failed to get model size: {}", e))?
            .len();

        debug!("Original model size: {} bytes", original_size);

        warn!(
            "Model quantization is a placeholder. \
             Actual quantization requires ONNX Runtime quantization tools."
        );

        let estimated_size =
            (original_size as f32 / self.config.precision.size_reduction_factor()) as u64;

        let info = QuantizedModelInfo::new(
            model_path.to_path_buf(),
            output_path.to_path_buf(),
            self.config.clone(),
        )
        .with_sizes(original_size, estimated_size);

        Ok(info)
    }

    /// Returns the quantizer's configuration.
    pub fn config(&self) -> &QuantizationConfig {
        &self.config
    }

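    /// Heuristically checks whether a model file looks already quantized,
    /// based solely on markers like "int8", "fp16", or "quant" in the file name.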
    pub fn is_quantized<P: AsRef<Path>>(model_path: P) -> bool {
        let path = model_path.as_ref();
        let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");

        file_name.contains("int8")
            || file_name.contains("fp16")
            || file_name.contains("quantized")
            || file_name.contains("quant")
    }

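    /// Estimates size, speed, and accuracy impact for a model of the given
    /// size in bytes, using the per-precision heuristics above.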
    pub fn estimate_benefits(&self, model_size_bytes: u64) -> QuantizationBenefits {
        let size_reduction = self.config.precision.size_reduction_factor();
        let speedup = self.config.precision.speedup_factor();
        let accuracy_loss = self.config.precision.accuracy_loss();

        QuantizationBenefits {
            original_size_mb: model_size_bytes as f32 / (1024.0 * 1024.0),
            quantized_size_mb: model_size_bytes as f32 / (1024.0 * 1024.0) / size_reduction,
            size_reduction_factor: size_reduction,
            expected_speedup: speedup,
            expected_accuracy_loss: accuracy_loss,
        }
    }
}

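/// Estimated impact of quantization, as produced by
/// [`ModelQuantizer::estimate_benefits`]. Sizes are in MB and the accuracy
/// loss is in percentage points.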
#[derive(Debug, Clone)]
pub struct QuantizationBenefits {
    pub original_size_mb: f32,
    pub quantized_size_mb: f32,
    pub size_reduction_factor: f32,
    pub expected_speedup: f32,
    pub expected_accuracy_loss: f32,
}

impl std::fmt::Display for QuantizationBenefits {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Quantization Benefits Estimate:")?;
        writeln!(f, "  Original Size: {:.2} MB", self.original_size_mb)?;
        writeln!(f, "  Quantized Size: {:.2} MB", self.quantized_size_mb)?;
        writeln!(f, "  Size Reduction: {:.2}x", self.size_reduction_factor)?;
        writeln!(f, "  Expected Speedup: {:.2}x", self.expected_speedup)?;
        writeln!(
            f,
            "  Accuracy Loss: ~{:.1}%",
            self.expected_accuracy_loss
        )?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_quantization_precision() {
        assert_eq!(QuantizationPrecision::FP32.size_reduction_factor(), 1.0);
        assert_eq!(QuantizationPrecision::FP16.size_reduction_factor(), 2.0);
        assert_eq!(QuantizationPrecision::INT8.size_reduction_factor(), 4.0);
    }

    #[test]
    fn test_quantization_precision_gpu_suitable() {
        assert!(QuantizationPrecision::FP16.is_gpu_suitable());
        assert!(!QuantizationPrecision::INT8.is_gpu_suitable());
    }

    #[test]
    fn test_quantization_precision_cpu_suitable() {
        assert!(QuantizationPrecision::INT8.is_cpu_suitable());
        assert!(QuantizationPrecision::FP32.is_cpu_suitable());
    }

    #[test]
    fn test_quantization_config_int8() {
        let config = QuantizationConfig::int8_cpu();
        assert_eq!(config.precision, QuantizationPrecision::INT8);
        assert_eq!(config.method, QuantizationMethod::Dynamic);
        assert!(config.symmetric);
    }

    #[test]
    fn test_quantization_config_fp16() {
        let config = QuantizationConfig::fp16_gpu();
        assert_eq!(config.precision, QuantizationPrecision::FP16);
        assert_eq!(config.method, QuantizationMethod::Dynamic);
    }

    #[test]
    fn test_quantization_config_exclude_layer() {
        let config = QuantizationConfig::int8_cpu().exclude_layer("input");
        assert_eq!(config.exclude_layers, vec!["input"]);
    }

    #[test]
    fn test_quantization_config_validation() {
        let config = QuantizationConfig::int8_cpu();
        assert!(config.validate().is_ok());

        let invalid = QuantizationConfig {
            method: QuantizationMethod::Static,
            calibration_size: None,
            ..QuantizationConfig::int8_cpu()
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_model_quantizer_creation() {
        let config = QuantizationConfig::int8_cpu();
        let quantizer = ModelQuantizer::new(config);
        assert!(quantizer.is_ok());
    }

    #[test]
    fn test_model_quantizer_int8() {
        let quantizer = ModelQuantizer::int8_cpu();
        assert!(quantizer.is_ok());
        assert_eq!(
            quantizer.unwrap().config().precision,
            QuantizationPrecision::INT8
        );
    }

    #[test]
    fn test_model_quantizer_fp16() {
        let quantizer = ModelQuantizer::fp16_gpu();
        assert!(quantizer.is_ok());
        assert_eq!(
            quantizer.unwrap().config().precision,
            QuantizationPrecision::FP16
        );
    }

    #[test]
    fn test_is_quantized() {
        assert!(ModelQuantizer::is_quantized("model_int8.onnx"));
        assert!(ModelQuantizer::is_quantized("model_fp16.onnx"));
        assert!(ModelQuantizer::is_quantized("model_quantized.onnx"));
        assert!(!ModelQuantizer::is_quantized("model.onnx"));
    }

    #[test]
    fn test_quantize_model_not_found() {
        let quantizer = ModelQuantizer::int8_cpu().unwrap();
        let result = quantizer.quantize("nonexistent.onnx", "output.onnx");
        assert!(result.is_err());
    }

    #[test]
    fn test_quantize_model() {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(b"fake model data").unwrap();

        let quantizer = ModelQuantizer::int8_cpu().unwrap();
        let output_path = PathBuf::from("output.onnx");
        let result = quantizer.quantize(temp_file.path(), &output_path);

        assert!(result.is_ok());
        let info = result.unwrap();
        assert!(info.size_reduction > 0.0);
    }

    #[test]
    fn test_quantized_model_info() {
        let info = QuantizedModelInfo::new(
            PathBuf::from("input.onnx"),
            PathBuf::from("output.onnx"),
            QuantizationConfig::int8_cpu(),
        )
        .with_sizes(1000, 250);

        assert_eq!(info.size_reduction, 4.0);
        assert!(info.summary().contains("INT8"));
    }

    #[test]
    fn test_estimate_benefits() {
        let quantizer = ModelQuantizer::int8_cpu().unwrap();
        let benefits = quantizer.estimate_benefits(100 * 1024 * 1024);

        assert!(benefits.original_size_mb > 99.0);
        assert!(benefits.quantized_size_mb < benefits.original_size_mb);
        assert!(benefits.size_reduction_factor > 1.0);
    }

    #[test]
    fn test_quantization_precision_display() {
        assert_eq!(format!("{}", QuantizationPrecision::FP32), "FP32");
        assert_eq!(format!("{}", QuantizationPrecision::FP16), "FP16");
        assert_eq!(format!("{}", QuantizationPrecision::INT8), "INT8");
    }

    #[test]
    fn test_quantization_method_default() {
        let method = QuantizationMethod::default();
        assert_eq!(method, QuantizationMethod::Dynamic);
    }

    #[test]
    fn test_benefits_display() {
        let benefits = QuantizationBenefits {
            original_size_mb: 100.0,
            quantized_size_mb: 25.0,
            size_reduction_factor: 4.0,
            expected_speedup: 2.5,
            expected_accuracy_loss: 0.5,
        };

        let display = format!("{}", benefits);
        assert!(display.contains("100.00 MB"));
        assert!(display.contains("25.00 MB"));
    }
}