1use serde::{Deserialize, Serialize};
37use std::fmt;
38use thiserror::Error;
39
/// Errors reported by the quantization routines in this file.
#[derive(Debug, Error)]
pub enum QuantizationError {
    /// A scheme with this bit width (the payload) cannot be used for the requested operation.
    #[error("Invalid quantization bit width: {0}")]
    InvalidBitWidth(u8),

    /// The tensor shape is unusable; the payload describes the problem.
    #[error("Invalid shape: {0}")]
    InvalidShape(String),

    /// The channel count that was found (`got`) differs from the required one (`expected`).
    #[error("Invalid number of channels: expected {expected}, got {got}")]
    InvalidChannelCount { expected: usize, got: usize },

    /// The input tensor holds no elements.
    #[error("Empty tensor cannot be quantized")]
    EmptyTensor,

    /// Calibration data was required but not available.
    // NOTE(review): no code visible in this file constructs this variant; confirm where it is
    // raised before relying on it.
    #[error("Calibration data required for dynamic quantization")]
    CalibrationRequired,

    /// The requested scheme or granularity is not supported; the payload explains why.
    #[error("Unsupported quantization scheme: {0}")]
    UnsupportedScheme(String),
}
61
/// Integer format that real values are quantized into.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// 8-bit integers.
    Int8,
    /// 4-bit integers.
    Int4,
    /// 16-bit integers.
    Int16,
}
72
73impl QuantizationScheme {
74 pub fn bit_width(&self) -> u8 {
76 match self {
77 QuantizationScheme::Int4 => 4,
78 QuantizationScheme::Int8 => 8,
79 QuantizationScheme::Int16 => 16,
80 }
81 }
82
83 pub fn range(&self, symmetric: bool) -> (i32, i32) {
85 match (self, symmetric) {
86 (QuantizationScheme::Int4, true) => (-8, 7),
87 (QuantizationScheme::Int4, false) => (0, 15),
88 (QuantizationScheme::Int8, true) => (-128, 127),
89 (QuantizationScheme::Int8, false) => (0, 255),
90 (QuantizationScheme::Int16, true) => (-32768, 32767),
91 (QuantizationScheme::Int16, false) => (0, 65535),
92 }
93 }
94
95 pub fn compression_ratio(&self) -> f32 {
97 32.0 / self.bit_width() as f32
98 }
99}
100
/// How many elements share one set of quantization parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationGranularity {
    /// One parameter set for the whole tensor.
    PerTensor,
    /// One parameter set per channel; `num_channels` is the expected number of channels.
    PerChannel { num_channels: usize },
    /// One parameter set per group of `group_size` elements.
    // NOTE(review): no per-group quantization routine is visible in this file — confirm the
    // intended group layout with the code that produces such tensors.
    PerGroup { group_size: usize },
}
111
/// Complete description of how a tensor should be quantized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Target integer format.
    pub scheme: QuantizationScheme,
    /// How many elements share one set of parameters.
    pub granularity: QuantizationGranularity,
    /// `true` selects the signed integer range with a zero point of 0; `false` selects the
    /// unsigned range with a computed zero point.
    pub symmetric: bool,
    /// How the real-valued range to map onto the integers is estimated from the data.
    pub calibration: CalibrationMethod,
}
124
125impl QuantizationConfig {
126 pub fn int8_symmetric() -> Self {
128 Self {
129 scheme: QuantizationScheme::Int8,
130 granularity: QuantizationGranularity::PerTensor,
131 symmetric: true,
132 calibration: CalibrationMethod::MinMax,
133 }
134 }
135
136 pub fn int8_asymmetric() -> Self {
138 Self {
139 scheme: QuantizationScheme::Int8,
140 granularity: QuantizationGranularity::PerTensor,
141 symmetric: false,
142 calibration: CalibrationMethod::MinMax,
143 }
144 }
145
146 pub fn int8_per_channel(num_channels: usize) -> Self {
148 Self {
149 scheme: QuantizationScheme::Int8,
150 granularity: QuantizationGranularity::PerChannel { num_channels },
151 symmetric: true,
152 calibration: CalibrationMethod::MinMax,
153 }
154 }
155
156 pub fn int4_symmetric() -> Self {
158 Self {
159 scheme: QuantizationScheme::Int4,
160 granularity: QuantizationGranularity::PerTensor,
161 symmetric: true,
162 calibration: CalibrationMethod::MinMax,
163 }
164 }
165
166 pub fn int4_per_channel(num_channels: usize) -> Self {
168 Self {
169 scheme: QuantizationScheme::Int4,
170 granularity: QuantizationGranularity::PerChannel { num_channels },
171 symmetric: true,
172 calibration: CalibrationMethod::MinMax,
173 }
174 }
175}
176
/// Strategy for estimating the real-valued range that is mapped onto the integer range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CalibrationMethod {
    /// Use the smallest and largest observed values.
    MinMax,
    /// Use the values found at the `lower` and `upper` percentiles (expressed on a 0-100
    /// scale) of the sorted data.
    Percentile { lower: u8, upper: u8 },
    /// Entropy-based calibration.
    // NOTE(review): the calibration code visible in this file handles this variant the same
    // way as `MinMax`; confirm whether a dedicated implementation exists elsewhere.
    Entropy,
    /// Mean-squared-error-based calibration.
    // NOTE(review): the calibration code visible in this file handles this variant the same
    // way as `MinMax`; confirm whether a dedicated implementation exists elsewhere.
    Mse,
}
189
/// Parameters of the affine mapping between real values and quantized integers.
///
/// Quantization computes `round(value / scale) + zero_point` clamped to `[qmin, qmax]`;
/// dequantization computes `(quantized - zero_point) * scale`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Real-valued size of one integer step.
    pub scale: f32,
    /// Integer offset added after scaling; the integer that dequantizes to `0.0`.
    pub zero_point: i32,
    /// Smallest integer quantization may produce.
    pub qmin: i32,
    /// Largest integer quantization may produce.
    pub qmax: i32,
}
202
203impl QuantizationParams {
204 pub fn from_min_max(
206 min_val: f32,
207 max_val: f32,
208 scheme: QuantizationScheme,
209 symmetric: bool,
210 ) -> Self {
211 let (qmin, qmax) = scheme.range(symmetric);
212
213 let (scale, zero_point) = if symmetric {
214 let abs_max = min_val.abs().max(max_val.abs());
216 let scale = if abs_max > 0.0 {
217 abs_max / (qmax as f32)
218 } else {
219 1.0
220 };
221 (scale, 0)
222 } else {
223 let scale = if (max_val - min_val).abs() > 0.0 {
225 (max_val - min_val) / ((qmax - qmin) as f32)
226 } else {
227 1.0
228 };
229 let zero_point = qmin - (min_val / scale).round() as i32;
230 let zero_point = zero_point.clamp(qmin, qmax);
231 (scale, zero_point)
232 };
233
234 Self {
235 scale,
236 zero_point,
237 qmin,
238 qmax,
239 }
240 }
241
242 #[inline]
244 pub fn quantize(&self, value: f32) -> i32 {
245 let quantized = (value / self.scale).round() as i32 + self.zero_point;
246 quantized.clamp(self.qmin, self.qmax)
247 }
248
249 #[inline]
251 pub fn dequantize(&self, quantized: i32) -> f32 {
252 (quantized - self.zero_point) as f32 * self.scale
253 }
254}
255
/// A tensor stored as quantized integers together with everything needed to decode it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
    /// Quantized values in flattened order, one `i32` per element (not bit-packed).
    pub data: Vec<i32>,
    /// Shape of the tensor the data was taken from.
    pub shape: Vec<usize>,
    /// Quantization parameters: a single entry for per-tensor quantization, one entry per
    /// channel for per-channel quantization.
    pub params: Vec<QuantizationParams>,
    /// Configuration the tensor was quantized with.
    pub config: QuantizationConfig,
}
268
269impl QuantizedTensor {
270 pub fn quantize_per_tensor(
272 data: &[f32],
273 shape: Vec<usize>,
274 config: QuantizationConfig,
275 ) -> Result<Self, QuantizationError> {
276 if data.is_empty() {
277 return Err(QuantizationError::EmptyTensor);
278 }
279
280 if !matches!(config.granularity, QuantizationGranularity::PerTensor) {
282 return Err(QuantizationError::UnsupportedScheme(
283 "Expected per-tensor granularity".to_string(),
284 ));
285 }
286
287 let (min_val, max_val) = Self::calculate_min_max(data, &config.calibration)?;
289
290 let params =
292 QuantizationParams::from_min_max(min_val, max_val, config.scheme, config.symmetric);
293
294 let quantized_data: Vec<i32> = data.iter().map(|&v| params.quantize(v)).collect();
296
297 Ok(Self {
298 data: quantized_data,
299 shape,
300 params: vec![params],
301 config,
302 })
303 }
304
305 pub fn quantize_per_channel(
307 data: &[f32],
308 shape: Vec<usize>,
309 config: QuantizationConfig,
310 ) -> Result<Self, QuantizationError> {
311 if data.is_empty() {
312 return Err(QuantizationError::EmptyTensor);
313 }
314
315 let num_channels = match config.granularity {
316 QuantizationGranularity::PerChannel { num_channels } => num_channels,
317 _ => {
318 return Err(QuantizationError::UnsupportedScheme(
319 "Expected per-channel granularity".to_string(),
320 ))
321 }
322 };
323
324 if shape.is_empty() {
325 return Err(QuantizationError::InvalidShape("Empty shape".to_string()));
326 }
327
328 if shape[0] != num_channels {
330 return Err(QuantizationError::InvalidChannelCount {
331 expected: num_channels,
332 got: shape[0],
333 });
334 }
335
336 let channel_size = data.len() / num_channels;
337
338 let mut params = Vec::with_capacity(num_channels);
340 for i in 0..num_channels {
341 let start = i * channel_size;
342 let end = start + channel_size;
343 let channel_data = &data[start..end];
344
345 let (min_val, max_val) = Self::calculate_min_max(channel_data, &config.calibration)?;
346 let channel_params =
347 QuantizationParams::from_min_max(min_val, max_val, config.scheme, config.symmetric);
348 params.push(channel_params);
349 }
350
351 let mut quantized_data = Vec::with_capacity(data.len());
353 for (i, chunk) in data.chunks(channel_size).enumerate() {
354 for &value in chunk {
355 quantized_data.push(params[i].quantize(value));
356 }
357 }
358
359 Ok(Self {
360 data: quantized_data,
361 shape,
362 params,
363 config,
364 })
365 }
366
367 fn calculate_min_max(
369 data: &[f32],
370 calibration: &CalibrationMethod,
371 ) -> Result<(f32, f32), QuantizationError> {
372 match calibration {
373 CalibrationMethod::MinMax => {
374 let min_val = data.iter().copied().fold(f32::INFINITY, f32::min);
375 let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
376 Ok((min_val, max_val))
377 }
378 CalibrationMethod::Percentile { lower, upper } => {
379 let mut sorted = data.to_vec();
380 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
381
382 let lower_idx = (sorted.len() as f32 * (*lower as f32 / 100.0)) as usize;
383 let upper_idx = (sorted.len() as f32 * (*upper as f32 / 100.0)) as usize;
384
385 let min_val = sorted[lower_idx.min(sorted.len() - 1)];
386 let max_val = sorted[upper_idx.min(sorted.len() - 1)];
387 Ok((min_val, max_val))
388 }
389 _ => {
390 let min_val = data.iter().copied().fold(f32::INFINITY, f32::min);
392 let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
393 Ok((min_val, max_val))
394 }
395 }
396 }
397
398 pub fn dequantize(&self) -> Vec<f32> {
400 match self.config.granularity {
401 QuantizationGranularity::PerTensor => {
402 let params = &self.params[0];
403 self.data.iter().map(|&q| params.dequantize(q)).collect()
404 }
405 QuantizationGranularity::PerChannel { num_channels } => {
406 let channel_size = self.data.len() / num_channels;
407 let mut result = Vec::with_capacity(self.data.len());
408
409 for (i, chunk) in self.data.chunks(channel_size).enumerate() {
410 for &q in chunk {
411 result.push(self.params[i].dequantize(q));
412 }
413 }
414 result
415 }
416 QuantizationGranularity::PerGroup { .. } => {
417 let params = &self.params[0];
419 self.data.iter().map(|&q| params.dequantize(q)).collect()
420 }
421 }
422 }
423
424 pub fn compression_ratio(&self) -> f32 {
426 let original_bytes = self.data.len() * 4; let quantized_bytes = self.size_bytes();
428 original_bytes as f32 / quantized_bytes as f32
429 }
430
431 pub fn size_bytes(&self) -> usize {
433 match self.config.scheme {
434 QuantizationScheme::Int4 => {
435 self.data.len().div_ceil(2)
437 + self.params.len() * std::mem::size_of::<QuantizationParams>()
438 }
439 QuantizationScheme::Int8 => {
440 self.data.len() + self.params.len() * std::mem::size_of::<QuantizationParams>()
441 }
442 QuantizationScheme::Int16 => {
443 self.data.len() * 2 + self.params.len() * std::mem::size_of::<QuantizationParams>()
444 }
445 }
446 }
447
448 pub fn pack_int4(&self) -> Result<Vec<u8>, QuantizationError> {
450 if self.config.scheme != QuantizationScheme::Int4 {
451 return Err(QuantizationError::InvalidBitWidth(
452 self.config.scheme.bit_width(),
453 ));
454 }
455
456 let mut packed = Vec::with_capacity(self.data.len().div_ceil(2));
457 for chunk in self.data.chunks(2) {
458 let high = (chunk[0] & 0xF) as u8;
459 let low = if chunk.len() > 1 {
460 (chunk[1] & 0xF) as u8
461 } else {
462 0
463 };
464 packed.push((high << 4) | low);
465 }
466
467 Ok(packed)
468 }
469
470 pub fn unpack_int4(packed: &[u8], length: usize) -> Vec<i32> {
472 let mut unpacked = Vec::with_capacity(length);
473 for &byte in packed {
474 let high = ((byte >> 4) & 0xF) as i32;
475 let low = (byte & 0xF) as i32;
476 unpacked.push(high);
477 if unpacked.len() < length {
478 unpacked.push(low);
479 }
480 }
481 unpacked.truncate(length);
482 unpacked
483 }
484
485 pub fn quantization_error(&self, original: &[f32]) -> f32 {
487 let dequantized = self.dequantize();
488 let mse: f32 = original
489 .iter()
490 .zip(dequantized.iter())
491 .map(|(o, d)| {
492 let diff = o - d;
493 diff * diff
494 })
495 .sum::<f32>()
496 / original.len() as f32;
497 mse
498 }
499}
500
501impl fmt::Display for QuantizedTensor {
502 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
503 write!(
504 f,
505 "QuantizedTensor({:?}, shape={:?}, scheme={:?}, params={})",
506 self.config.granularity,
507 self.shape,
508 self.config.scheme,
509 self.params.len()
510 )
511 }
512}
513
/// Quantizer that derives per-tensor parameters from the data passed to each call.
#[derive(Debug, Clone)]
pub struct DynamicQuantizer {
    /// Target integer format.
    scheme: QuantizationScheme,
    /// Whether symmetric (signed, zero point 0) quantization is used.
    symmetric: bool,
    /// Range-estimation method; the constructor sets it to min/max.
    calibration: CalibrationMethod,
}
524
525impl DynamicQuantizer {
526 pub fn new(scheme: QuantizationScheme, symmetric: bool) -> Self {
528 Self {
529 scheme,
530 symmetric,
531 calibration: CalibrationMethod::MinMax,
532 }
533 }
534
535 pub fn quantize_activation(
537 &self,
538 data: &[f32],
539 shape: Vec<usize>,
540 ) -> Result<QuantizedTensor, QuantizationError> {
541 let config = QuantizationConfig {
542 scheme: self.scheme,
543 granularity: QuantizationGranularity::PerTensor,
544 symmetric: self.symmetric,
545 calibration: self.calibration,
546 };
547
548 QuantizedTensor::quantize_per_tensor(data, shape, config)
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// The 4- and 8-bit schemes expose the expected signed and unsigned integer ranges.
    #[test]
    fn test_quantization_scheme_ranges() {
        let expectations = [
            (QuantizationScheme::Int8, true, (-128, 127)),
            (QuantizationScheme::Int8, false, (0, 255)),
            (QuantizationScheme::Int4, true, (-8, 7)),
            (QuantizationScheme::Int4, false, (0, 15)),
        ];
        for (scheme, symmetric, expected) in expectations {
            assert_eq!(scheme.range(symmetric), expected);
        }
    }

    /// Symmetric parameters keep the zero point at 0 and preserve the sign of the input.
    #[test]
    fn test_quantization_params_symmetric() {
        let params = QuantizationParams::from_min_max(-1.0, 1.0, QuantizationScheme::Int8, true);
        assert_eq!(params.zero_point, 0);
        assert!(params.scale > 0.0);

        let zero = params.quantize(0.0);
        let positive = params.quantize(1.0);
        let negative = params.quantize(-1.0);
        assert_eq!(zero, 0);
        assert!(positive > 0);
        assert!(negative < 0);
    }

    /// Asymmetric parameters have a positive scale and an in-range zero point.
    #[test]
    fn test_quantization_params_asymmetric() {
        let positive_only =
            QuantizationParams::from_min_max(0.5, 1.5, QuantizationScheme::Int8, false);
        assert!(positive_only.scale > 0.0);

        let mixed = QuantizationParams::from_min_max(-1.0, 0.5, QuantizationScheme::Int8, false);
        assert!(mixed.scale > 0.0);
        assert!((mixed.qmin..=mixed.qmax).contains(&mixed.zero_point));
    }

    /// Per-tensor 8-bit quantization round-trips small values with little error.
    #[test]
    fn test_per_tensor_quantization() {
        let data: [f32; 4] = [0.5, -0.3, 0.8, -0.1];
        let quantized = QuantizedTensor::quantize_per_tensor(
            &data,
            vec![4],
            QuantizationConfig::int8_symmetric(),
        )
        .unwrap();

        assert_eq!(quantized.data.len(), 4);
        assert_eq!(quantized.params.len(), 1);

        let restored = quantized.dequantize();
        assert_eq!(restored.len(), 4);
        assert!(data
            .iter()
            .zip(&restored)
            .all(|(orig, deq)| (orig - deq).abs() < 0.01));
    }

    /// Per-channel quantization produces one parameter set per channel.
    #[test]
    fn test_per_channel_quantization() {
        let data: [f32; 8] = [0.5, 0.3, -0.2, -0.4, 0.1, 0.6, -0.5, 0.2];
        let quantized = QuantizedTensor::quantize_per_channel(
            &data,
            vec![2, 4],
            QuantizationConfig::int8_per_channel(2),
        )
        .unwrap();

        assert_eq!(quantized.data.len(), 8);
        assert_eq!(quantized.params.len(), 2);

        // The two channels have different magnitudes, so their scales must differ.
        let (first, second) = (&quantized.params[0], &quantized.params[1]);
        assert_ne!(first.scale, second.scale);
    }

    /// 4-bit quantization stays in range and survives packing and unpacking.
    #[test]
    fn test_int4_quantization() {
        let data: [f32; 4] = [0.1, 0.2, 0.3, 0.4];
        let quantized = QuantizedTensor::quantize_per_tensor(
            &data,
            vec![4],
            QuantizationConfig::int4_symmetric(),
        )
        .unwrap();

        assert!(quantized.data.iter().all(|q| (-8..=7).contains(q)));

        let packed = quantized.pack_int4().unwrap();
        assert_eq!(packed.len(), 2);

        let unpacked = QuantizedTensor::unpack_int4(&packed, 4);
        assert_eq!(unpacked, quantized.data);
    }

    /// An 8-bit tensor is reported as smaller than its 32-bit original.
    #[test]
    fn test_compression_ratio() {
        let data = [1.0_f32; 100];
        let quantized = QuantizedTensor::quantize_per_tensor(
            &data,
            vec![100],
            QuantizationConfig::int8_symmetric(),
        )
        .unwrap();

        assert!(quantized.compression_ratio() > 1.0);
    }

    /// The mean squared quantization error of 8-bit data is small.
    #[test]
    fn test_quantization_error() {
        let data: [f32; 4] = [0.1, 0.5, 0.9, 0.3];
        let quantized = QuantizedTensor::quantize_per_tensor(
            &data,
            vec![4],
            QuantizationConfig::int8_symmetric(),
        )
        .unwrap();

        assert!(quantized.quantization_error(&data) < 0.001);
    }

    /// The dynamic quantizer round-trips activations within a coarse tolerance.
    #[test]
    fn test_dynamic_quantizer() {
        let data: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let quantized = DynamicQuantizer::new(QuantizationScheme::Int8, true)
            .quantize_activation(&data, vec![4])
            .unwrap();
        assert_eq!(quantized.data.len(), 4);

        let restored = quantized.dequantize();
        assert!(data
            .iter()
            .zip(&restored)
            .all(|(orig, deq)| (orig - deq).abs() < 0.1));
    }

    /// Percentile calibration on data with two large outliers yields a scale below 1.
    #[test]
    fn test_percentile_calibration() {
        let mut data = [0.0_f32; 100];
        data[0] = -100.0;
        data[99] = 100.0;
        for (i, value) in data.iter_mut().enumerate().take(99).skip(1) {
            *value = (i as f32 - 50.0) / 50.0;
        }

        let calibration = CalibrationMethod::Percentile {
            lower: 1,
            upper: 99,
        };
        let config = QuantizationConfig {
            scheme: QuantizationScheme::Int8,
            granularity: QuantizationGranularity::PerTensor,
            symmetric: true,
            calibration,
        };

        let quantized = QuantizedTensor::quantize_per_tensor(&data, vec![100], config).unwrap();

        let scale = quantized.params[0].scale;
        assert!(scale < 1.0);
    }

    /// Quantizing an empty tensor is an error.
    #[test]
    fn test_empty_tensor_error() {
        let data: [f32; 0] = [];
        let outcome = QuantizedTensor::quantize_per_tensor(
            &data,
            vec![0],
            QuantizationConfig::int8_symmetric(),
        );
        assert!(outcome.is_err());
    }

    /// A channel count that disagrees with the leading dimension is an error.
    #[test]
    fn test_invalid_channel_count() {
        let data = [1.0_f32; 8];
        let outcome = QuantizedTensor::quantize_per_channel(
            &data,
            vec![2, 4],
            QuantizationConfig::int8_per_channel(3),
        );
        assert!(outcome.is_err());
    }
}