1use crate::error::{MlError, Result};
7use std::path::Path;
8use tracing::{debug, info};
9
/// Numeric representation targeted by quantization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationType {
    /// Signed 8-bit integers (range -128..=127).
    Int8,
    /// Unsigned 8-bit integers (range 0..=255).
    UInt8,
    /// Half-precision floats; no integer scale/zero-point applies.
    Float16,
    /// Signed 4-bit integers (range -8..=7).
    Int4,
}
22
/// How quantization parameters are obtained.
///
/// NOTE(review): within this file the mode is only logged, never branched
/// on — confirm its semantics against the consuming pipeline.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationMode {
    /// Dynamic (runtime) quantization.
    Dynamic,
    /// Static (calibration-based) quantization.
    Static,
    /// Quantization-aware training.
    QAT,
}
33
/// Settings controlling how a model is quantized.
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    /// Target numeric representation (see [`QuantizationType`]).
    pub quantization_type: QuantizationType,
    /// Parameter-derivation strategy (see [`QuantizationMode`]).
    pub mode: QuantizationMode,
    /// Per-channel (vs. per-tensor) parameters — NOTE(review): only logged
    /// in this file; `calibrate_quantization` always works per channel.
    pub per_channel: bool,
    /// Use a symmetric range around zero (zero_point forced to 0).
    pub symmetric: bool,
    /// Number of samples to use during calibration.
    pub calibration_samples: usize,
}
48
impl Default for QuantizationConfig {
    /// Dynamic, symmetric, per-tensor INT8 with 100 calibration samples.
    fn default() -> Self {
        Self {
            quantization_type: QuantizationType::Int8,
            mode: QuantizationMode::Dynamic,
            per_channel: false,
            symmetric: true,
            calibration_samples: 100,
        }
    }
}
60
impl QuantizationConfig {
    /// Starts a [`QuantizationConfigBuilder`] with no fields set.
    #[must_use]
    pub fn builder() -> QuantizationConfigBuilder {
        QuantizationConfigBuilder::default()
    }
}
68
69#[derive(Debug, Default)]
71pub struct QuantizationConfigBuilder {
72 quantization_type: Option<QuantizationType>,
73 mode: Option<QuantizationMode>,
74 per_channel: bool,
75 symmetric: bool,
76 calibration_samples: Option<usize>,
77}
78
79impl QuantizationConfigBuilder {
80 #[must_use]
82 pub fn quantization_type(mut self, qtype: QuantizationType) -> Self {
83 self.quantization_type = Some(qtype);
84 self
85 }
86
87 #[must_use]
89 pub fn mode(mut self, mode: QuantizationMode) -> Self {
90 self.mode = Some(mode);
91 self
92 }
93
94 #[must_use]
96 pub fn per_channel(mut self, enable: bool) -> Self {
97 self.per_channel = enable;
98 self
99 }
100
101 #[must_use]
103 pub fn symmetric(mut self, enable: bool) -> Self {
104 self.symmetric = enable;
105 self
106 }
107
108 #[must_use]
110 pub fn calibration_samples(mut self, count: usize) -> Self {
111 self.calibration_samples = Some(count);
112 self
113 }
114
115 #[must_use]
117 pub fn build(self) -> QuantizationConfig {
118 QuantizationConfig {
119 quantization_type: self.quantization_type.unwrap_or(QuantizationType::Int8),
120 mode: self.mode.unwrap_or(QuantizationMode::Dynamic),
121 per_channel: self.per_channel,
122 symmetric: self.symmetric,
123 calibration_samples: self.calibration_samples.unwrap_or(100),
124 }
125 }
126}
127
/// Affine quantization parameters: `real = (q - zero_point) * scale`.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Step size between adjacent quantized values.
    pub scale: f32,
    /// Quantized value that maps to real 0.0 (always 0 when symmetric).
    pub zero_point: i32,
    /// Minimum of the observed real-value range.
    pub min: f32,
    /// Maximum of the observed real-value range.
    pub max: f32,
    /// Target representation these parameters were derived for.
    pub qtype: QuantizationType,
}
142
143impl QuantizationParams {
144 #[must_use]
146 pub fn from_min_max(min: f32, max: f32, qtype: QuantizationType, symmetric: bool) -> Self {
147 let (qmin, qmax) = match qtype {
148 QuantizationType::Int8 => (-128i32, 127i32),
149 QuantizationType::UInt8 => (0i32, 255i32),
150 QuantizationType::Int4 => (-8i32, 7i32),
151 QuantizationType::Float16 => return Self::identity(),
152 };
153
154 if symmetric {
155 let abs_max = min.abs().max(max.abs());
156 let scale = abs_max / qmax as f32;
157 Self {
158 scale,
159 zero_point: 0,
160 min,
161 max,
162 qtype,
163 }
164 } else {
165 let scale = (max - min) / (qmax - qmin) as f32;
166 let zero_point = qmin - (min / scale).round() as i32;
167 Self {
168 scale,
169 zero_point,
170 min,
171 max,
172 qtype,
173 }
174 }
175 }
176
177 #[must_use]
179 pub fn identity() -> Self {
180 Self {
181 scale: 1.0,
182 zero_point: 0,
183 min: 0.0,
184 max: 1.0,
185 qtype: QuantizationType::Float16,
186 }
187 }
188
189 #[must_use]
191 pub fn quantize(&self, value: f32) -> i32 {
192 let (qmin, qmax) = match self.qtype {
193 QuantizationType::Int8 => (-128i32, 127i32),
194 QuantizationType::UInt8 => (0i32, 255i32),
195 QuantizationType::Int4 => (-8i32, 7i32),
196 QuantizationType::Float16 => return value as i32,
197 };
198
199 let scaled = value / self.scale;
200 (scaled.round() as i32 + self.zero_point).clamp(qmin, qmax)
201 }
202
203 #[must_use]
205 pub fn dequantize(&self, value: i32) -> f32 {
206 (value - self.zero_point) as f32 * self.scale
207 }
208}
209
210pub fn quantize_model<P: AsRef<Path>>(
215 input_path: P,
216 output_path: P,
217 config: &QuantizationConfig,
218) -> Result<QuantizationResult> {
219 let input = input_path.as_ref();
220 let output = output_path.as_ref();
221
222 info!(
223 "Quantizing model {:?} to {:?} (type: {:?}, mode: {:?})",
224 input, output, config.quantization_type, config.mode
225 );
226
227 if !input.exists() {
228 return Err(MlError::InvalidConfig(format!(
229 "Input model not found: {}",
230 input.display()
231 )));
232 }
233
234 debug!(
235 "Quantization config: per_channel={}, symmetric={}",
236 config.per_channel, config.symmetric
237 );
238
239 let original_size = std::fs::metadata(input)?.len();
253
254 std::fs::copy(input, output)?;
256
257 let quantized_size = std::fs::metadata(output)?.len();
258
259 let compression_ratio = match config.quantization_type {
261 QuantizationType::Int8 => 4.0, QuantizationType::UInt8 => 4.0, QuantizationType::Float16 => 2.0, QuantizationType::Int4 => 8.0, };
266
267 info!(
268 "Quantization complete: {:.1}x compression (estimated)",
269 compression_ratio
270 );
271
272 Ok(QuantizationResult {
273 original_size,
274 quantized_size,
275 compression_ratio,
276 quantization_type: config.quantization_type,
277 })
278}
279
/// Summary of a completed `quantize_model` run.
#[derive(Debug, Clone)]
pub struct QuantizationResult {
    /// Input file size in bytes.
    pub original_size: u64,
    /// Output file size in bytes.
    pub quantized_size: u64,
    /// Estimated compression ratio versus the original (not measured).
    pub compression_ratio: f32,
    /// Quantization type that was applied.
    pub quantization_type: QuantizationType,
}
292
293impl QuantizationResult {
294 #[must_use]
296 pub fn size_reduction_percent(&self) -> f32 {
297 if self.original_size > 0 {
298 (1.0 - (self.quantized_size as f32 / self.original_size as f32)) * 100.0
299 } else {
300 0.0
301 }
302 }
303
304 #[must_use]
306 pub fn original_size_mb(&self) -> f32 {
307 self.original_size as f32 / (1024.0 * 1024.0)
308 }
309
310 #[must_use]
312 pub fn quantized_size_mb(&self) -> f32 {
313 self.quantized_size as f32 / (1024.0 * 1024.0)
314 }
315}
316
317pub fn calibrate_quantization(
322 calibration_data: &[Vec<f32>],
323 config: &QuantizationConfig,
324) -> Result<Vec<QuantizationParams>> {
325 info!(
326 "Calibrating quantization with {} samples",
327 calibration_data.len()
328 );
329
330 if calibration_data.is_empty() {
331 return Err(MlError::InvalidConfig(
332 "Calibration data cannot be empty".to_string(),
333 ));
334 }
335
336 let mut params_list = Vec::new();
337
338 for channel_idx in 0..calibration_data[0].len() {
340 let mut min = f32::MAX;
341 let mut max = f32::MIN;
342
343 for sample in calibration_data {
344 if let Some(&value) = sample.get(channel_idx) {
345 min = min.min(value);
346 max = max.max(value);
347 }
348 }
349
350 let params =
351 QuantizationParams::from_min_max(min, max, config.quantization_type, config.symmetric);
352 params_list.push(params);
353 }
354
355 debug!("Calibrated {} channels", params_list.len());
356 Ok(params_list)
357}
358
359#[must_use]
361pub fn quantize_tensor(tensor: &[f32], params: &QuantizationParams) -> Vec<i8> {
362 tensor.iter().map(|&v| params.quantize(v) as i8).collect()
363}
364
365#[must_use]
367pub fn dequantize_tensor(tensor: &[i8], params: &QuantizationParams) -> Vec<f32> {
368 tensor
369 .iter()
370 .map(|&v| params.dequantize(i32::from(v)))
371 .collect()
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_config_builder() {
        let config = QuantizationConfig::builder()
            .quantization_type(QuantizationType::Int8)
            .mode(QuantizationMode::Static)
            .per_channel(true)
            .symmetric(false)
            .calibration_samples(200)
            .build();

        assert_eq!(config.quantization_type, QuantizationType::Int8);
        assert_eq!(config.mode, QuantizationMode::Static);
        assert!(config.per_channel);
        assert!(!config.symmetric);
        assert_eq!(config.calibration_samples, 200);
    }

    #[test]
    fn test_quantization_params_symmetric() {
        let params = QuantizationParams::from_min_max(-10.0, 10.0, QuantizationType::Int8, true);

        assert_eq!(params.zero_point, 0);
        assert!((params.scale - 10.0 / 127.0).abs() < 1e-6);

        // Round-trip error should stay within one quantization step.
        let value = 5.0;
        let quantized = params.quantize(value);
        let dequantized = params.dequantize(quantized);
        assert!((dequantized - value).abs() < 0.1);
    }

    #[test]
    fn test_quantization_params_asymmetric() {
        let params = QuantizationParams::from_min_max(0.0, 255.0, QuantizationType::UInt8, false);

        assert!((params.scale - 1.0).abs() < 1e-6);

        let value = 128.0;
        let quantized = params.quantize(value);
        let dequantized = params.dequantize(quantized);
        assert!((dequantized - value).abs() < 1.0);
    }

    #[test]
    fn test_quantize_tensor() {
        let tensor = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let params = QuantizationParams::from_min_max(0.0, 4.0, QuantizationType::Int8, true);

        // Fixed: `&params` had been mangled to `¶ms` (HTML-entity
        // corruption of `&para`), which does not compile.
        let quantized = quantize_tensor(&tensor, &params);
        assert_eq!(quantized.len(), tensor.len());

        let dequantized = dequantize_tensor(&quantized, &params);
        for (orig, deq) in tensor.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.1);
        }
    }

    #[test]
    fn test_calibrate_quantization() {
        let calibration_data = vec![
            vec![0.0, 1.0, 2.0],
            vec![0.5, 1.5, 2.5],
            vec![1.0, 2.0, 3.0],
        ];

        let config = QuantizationConfig::default();
        let params =
            calibrate_quantization(&calibration_data, &config).expect("Calibration should succeed");

        assert_eq!(params.len(), 3);
        assert!(params[0].min <= 0.0);
        assert!(params[2].max >= 3.0);
    }
}
451}