//! INT8 quantization utilities: symmetric and asymmetric quantization of
//! weights and activations, per-tensor and per-channel granularity, and
//! calibration statistics for deriving quantization parameters.

use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::{Array1, Array2};

/// Quantization scheme used when mapping f32 values to i8.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationMethod {
    /// Zero point fixed at 0; the representable range is symmetric around zero.
    Symmetric,
    /// A non-zero zero point shifts the range to cover asymmetric data.
    Asymmetric,
}

/// Granularity at which quantization parameters apply.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationGranularity {
    /// One scale/zero-point pair for the whole tensor.
    PerTensor,
    /// One scale/zero-point pair per output channel (row for 2D weights).
    PerChannel,
}

/// Parameters that define an i8 quantization mapping.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Scale factor(s); one entry per tensor or per channel.
    pub scale: Vec<f32>,
    /// Zero point(s); same length as `scale`.
    pub zero_point: Vec<i8>,
    /// Symmetric or asymmetric quantization.
    pub method: QuantizationMethod,
    /// Per-tensor or per-channel granularity.
    pub granularity: QuantizationGranularity,
}

impl QuantizationParams {
    /// Per-tensor symmetric parameters with the given scale (zero point 0).
    pub fn symmetric_per_tensor(scale: f32) -> Self {
        Self {
            scale: vec![scale],
            zero_point: vec![0],
            method: QuantizationMethod::Symmetric,
            granularity: QuantizationGranularity::PerTensor,
        }
    }

    /// Per-tensor asymmetric parameters with the given scale and zero point.
    pub fn asymmetric_per_tensor(scale: f32, zero_point: i8) -> Self {
        Self {
            scale: vec![scale],
            zero_point: vec![zero_point],
            method: QuantizationMethod::Asymmetric,
            granularity: QuantizationGranularity::PerTensor,
        }
    }

    /// Per-channel symmetric parameters, one scale per channel (all zero points 0).
    pub fn symmetric_per_channel(scales: Vec<f32>) -> Self {
        let n = scales.len();
        Self {
            scale: scales,
            zero_point: vec![0; n],
            method: QuantizationMethod::Symmetric,
            granularity: QuantizationGranularity::PerChannel,
        }
    }

    /// Checks that scales are non-empty, finite, and positive, and that
    /// `scale` and `zero_point` have the same length.
    pub fn validate(&self) -> ModelResult<()> {
        if self.scale.is_empty() {
            return Err(ModelError::invalid_config("scale cannot be empty"));
        }
        if self.scale.len() != self.zero_point.len() {
            return Err(ModelError::invalid_config(
                "scale and zero_point must have same length",
            ));
        }
        for &s in &self.scale {
            if s <= 0.0 || !s.is_finite() {
                return Err(ModelError::invalid_config(format!("invalid scale: {}", s)));
            }
        }
        Ok(())
    }
}

/// A weight tensor stored as i8 values plus the parameters needed to dequantize it.
#[derive(Debug, Clone)]
pub struct QuantizedWeight {
    /// Quantized values in row-major order.
    pub data: Vec<i8>,
    /// Tensor shape (e.g. `[rows, cols]` for 2D weights).
    pub shape: Vec<usize>,
    /// Quantization parameters used to produce `data`.
    pub params: QuantizationParams,
}

impl QuantizedWeight {
    /// Builds a quantized weight, validating the parameters and that `data`
    /// has exactly `shape.iter().product()` elements.
    pub fn new(data: Vec<i8>, shape: Vec<usize>, params: QuantizationParams) -> ModelResult<Self> {
        params.validate()?;

        let total_size: usize = shape.iter().product();
        if data.len() != total_size {
            return Err(ModelError::invalid_config(format!(
                "data length {} does not match shape {:?}",
                data.len(),
                shape
            )));
        }

        Ok(Self {
            data,
            shape,
            params,
        })
    }

    /// Dequantizes a 1D tensor back to f32. Per-channel parameters are not
    /// supported for 1D tensors.
    pub fn dequantize_1d(&self) -> ModelResult<Array1<f32>> {
        if self.shape.len() != 1 {
            return Err(ModelError::invalid_config(format!(
                "expected 1D shape, got {:?}",
                self.shape
            )));
        }

        let n = self.shape[0];
        let mut result = Array1::zeros(n);

        match self.params.granularity {
            QuantizationGranularity::PerTensor => {
                let scale = self.params.scale[0];
                let zero_point = self.params.zero_point[0];

                for i in 0..n {
                    result[i] = (self.data[i] as i32 - zero_point as i32) as f32 * scale;
                }
            }
            QuantizationGranularity::PerChannel => {
                return Err(ModelError::invalid_config(
                    "per-channel quantization not supported for 1D tensors",
                ));
            }
        }

        Ok(result)
    }

    /// Dequantizes a 2D tensor back to f32. With per-channel parameters, each
    /// row uses its own scale and zero point.
    pub fn dequantize_2d(&self) -> ModelResult<Array2<f32>> {
        if self.shape.len() != 2 {
            return Err(ModelError::invalid_config(format!(
                "expected 2D shape, got {:?}",
                self.shape
            )));
        }

        let (rows, cols) = (self.shape[0], self.shape[1]);
        let mut result = Array2::zeros((rows, cols));

        match self.params.granularity {
            QuantizationGranularity::PerTensor => {
                let scale = self.params.scale[0];
                let zero_point = self.params.zero_point[0];

                for i in 0..rows {
                    for j in 0..cols {
                        let idx = i * cols + j;
                        result[[i, j]] = (self.data[idx] as i32 - zero_point as i32) as f32 * scale;
                    }
                }
            }
            QuantizationGranularity::PerChannel => {
                if self.params.scale.len() != rows {
                    return Err(ModelError::invalid_config(format!(
                        "expected {} scales for per-channel, got {}",
                        rows,
                        self.params.scale.len()
                    )));
                }

                for i in 0..rows {
                    let scale = self.params.scale[i];
                    let zero_point = self.params.zero_point[i];

                    for j in 0..cols {
                        let idx = i * cols + j;
                        result[[i, j]] = (self.data[idx] as i32 - zero_point as i32) as f32 * scale;
                    }
                }
            }
        }

        Ok(result)
    }

    /// Size of the quantized data in bytes (one byte per i8 element).
    pub fn memory_size(&self) -> usize {
        self.data.len()
    }
}

/// Symmetric per-tensor quantization of a 1D array: scale = max(|x|) / 127.
/// An all-zero input is stored with a unit scale.
pub fn quantize_symmetric_1d(array: &Array1<f32>) -> ModelResult<QuantizedWeight> {
    let max_val = array.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);

    if max_val == 0.0 {
        let data = vec![0i8; array.len()];
        let params = QuantizationParams::symmetric_per_tensor(1.0);
        return QuantizedWeight::new(data, vec![array.len()], params);
    }

    let scale = max_val / 127.0;
    let mut data = Vec::with_capacity(array.len());

    for &x in array.iter() {
        let q = (x / scale).round() as i32;
        let q_clamped = q.clamp(-128, 127) as i8;
        data.push(q_clamped);
    }

    let params = QuantizationParams::symmetric_per_tensor(scale);
    QuantizedWeight::new(data, vec![array.len()], params)
}

/// Symmetric per-tensor quantization of a 2D array: scale = max(|x|) / 127.
/// An all-zero input is stored with a unit scale.
pub fn quantize_symmetric_2d(array: &Array2<f32>) -> ModelResult<QuantizedWeight> {
    let max_val = array.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);

    if max_val == 0.0 {
        let (rows, cols) = array.dim();
        let data = vec![0i8; rows * cols];
        let params = QuantizationParams::symmetric_per_tensor(1.0);
        return QuantizedWeight::new(data, vec![rows, cols], params);
    }

    let scale = max_val / 127.0;
    let (rows, cols) = array.dim();
    let mut data = Vec::with_capacity(rows * cols);

    for i in 0..rows {
        for j in 0..cols {
            let x = array[[i, j]];
            let q = (x / scale).round() as i32;
            let q_clamped = q.clamp(-128, 127) as i8;
            data.push(q_clamped);
        }
    }

    let params = QuantizationParams::symmetric_per_tensor(scale);
    QuantizedWeight::new(data, vec![rows, cols], params)
}

/// Symmetric per-channel quantization of a 2D array, with one scale per row.
/// Rows that are all zero fall back to a unit scale.
pub fn quantize_symmetric_per_channel(array: &Array2<f32>) -> ModelResult<QuantizedWeight> {
    let (rows, cols) = array.dim();
    let mut scales = Vec::with_capacity(rows);
    let mut data = Vec::with_capacity(rows * cols);

    for i in 0..rows {
        let row = array.row(i);
        let max_val = row.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);

        let scale = if max_val == 0.0 {
            1.0
        } else {
            max_val / 127.0
        };

        scales.push(scale);
    }

    for i in 0..rows {
        let scale = scales[i];
        for j in 0..cols {
            let x = array[[i, j]];
            let q = (x / scale).round() as i32;
            let q_clamped = q.clamp(-128, 127) as i8;
            data.push(q_clamped);
        }
    }

    let params = QuantizationParams::symmetric_per_channel(scales);
    QuantizedWeight::new(data, vec![rows, cols], params)
}

/// Asymmetric per-tensor quantization of a 1D array: scale = (max - min) / 255
/// with a zero point that maps `min` near -128. A (near-)constant input is
/// stored with unit scale and zero point 0.
pub fn quantize_asymmetric_1d(array: &Array1<f32>) -> ModelResult<QuantizedWeight> {
    let min_val = array.iter().copied().fold(f32::INFINITY, f32::min);
    let max_val = array.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    if (max_val - min_val).abs() < 1e-8 {
        let data = vec![0i8; array.len()];
        let params = QuantizationParams::asymmetric_per_tensor(1.0, 0);
        return QuantizedWeight::new(data, vec![array.len()], params);
    }

    let scale = (max_val - min_val) / 255.0;
    let zero_point_f = -128.0 - min_val / scale;
    let zero_point = zero_point_f.round().clamp(-128.0, 127.0) as i8;

    let mut data = Vec::with_capacity(array.len());

    for &x in array.iter() {
        let q_f = x / scale + zero_point as f32;
        let q = q_f.round().clamp(-128.0, 127.0) as i8;
        data.push(q);
    }

    let params = QuantizationParams::asymmetric_per_tensor(scale, zero_point);
    QuantizedWeight::new(data, vec![array.len()], params)
}

/// Running min/max/count statistics collected over calibration data.
#[derive(Debug, Clone)]
pub struct CalibrationStats {
    /// Smallest value observed so far.
    pub min: f32,
    /// Largest value observed so far.
    pub max: f32,
    /// Total number of values observed.
    pub count: usize,
}

impl CalibrationStats {
    /// Creates empty statistics (min = +inf, max = -inf, count = 0).
    pub fn new() -> Self {
        Self {
            min: f32::INFINITY,
            max: f32::NEG_INFINITY,
            count: 0,
        }
    }

    /// Folds a 1D activation into the running min/max/count.
    pub fn update_1d(&mut self, data: &Array1<f32>) {
        for &x in data.iter() {
            self.min = self.min.min(x);
            self.max = self.max.max(x);
        }
        self.count += data.len();
    }

    /// Folds a 2D activation into the running min/max/count.
    pub fn update_2d(&mut self, data: &Array2<f32>) {
        for &x in data.iter() {
            self.min = self.min.min(x);
            self.max = self.max.max(x);
        }
        self.count += data.len();
    }

    /// Derives per-tensor symmetric parameters from the observed range.
    pub fn to_symmetric_params(&self) -> ModelResult<QuantizationParams> {
        let max_abs = self.max.abs().max(self.min.abs());
        if max_abs == 0.0 {
            Ok(QuantizationParams::symmetric_per_tensor(1.0))
        } else {
            Ok(QuantizationParams::symmetric_per_tensor(max_abs / 127.0))
        }
    }

    /// Derives per-tensor asymmetric parameters from the observed range.
    pub fn to_asymmetric_params(&self) -> ModelResult<QuantizationParams> {
        if (self.max - self.min).abs() < 1e-8 {
            Ok(QuantizationParams::asymmetric_per_tensor(1.0, 0))
        } else {
            let scale = (self.max - self.min) / 255.0;
            let zero_point_f = -128.0 - self.min / scale;
            let zero_point = zero_point_f.round().clamp(-128.0, 127.0) as i8;
            Ok(QuantizationParams::asymmetric_per_tensor(scale, zero_point))
        }
    }
}

impl Default for CalibrationStats {
    fn default() -> Self {
        Self::new()
    }
}

/// Quantizes activations to i8 at runtime, optionally using calibrated parameters.
#[derive(Debug, Clone)]
pub struct ActivationQuantizer {
    method: QuantizationMethod,
    #[allow(dead_code)]
    granularity: QuantizationGranularity,
    calibration: Option<QuantizationParams>,
}

impl ActivationQuantizer {
    /// Creates a symmetric per-tensor activation quantizer with no calibration.
    pub fn new_symmetric() -> Self {
        Self {
            method: QuantizationMethod::Symmetric,
            granularity: QuantizationGranularity::PerTensor,
            calibration: None,
        }
    }

    /// Creates an asymmetric per-tensor activation quantizer with no calibration.
    pub fn new_asymmetric() -> Self {
        Self {
            method: QuantizationMethod::Asymmetric,
            granularity: QuantizationGranularity::PerTensor,
            calibration: None,
        }
    }

    /// Fixes the quantization parameters from calibration statistics.
    pub fn calibrate(&mut self, stats: &CalibrationStats) -> ModelResult<()> {
        self.calibration = Some(match self.method {
            QuantizationMethod::Symmetric => stats.to_symmetric_params()?,
            QuantizationMethod::Asymmetric => stats.to_asymmetric_params()?,
        });
        Ok(())
    }

    /// Quantizes a 1D activation to i8. Uses calibrated parameters when
    /// available, otherwise derives them from the activation's own range.
    pub fn quantize_activation_1d(&self, activation: &Array1<f32>) -> ModelResult<Vec<i8>> {
        let params = if let Some(ref cal) = self.calibration {
            cal.clone()
        } else {
            let min_val = activation.iter().copied().fold(f32::INFINITY, f32::min);
            let max_val = activation.iter().copied().fold(f32::NEG_INFINITY, f32::max);

            match self.method {
                QuantizationMethod::Symmetric => {
                    let max_abs = max_val.abs().max(min_val.abs());
                    // Fall back to a unit scale for an all-zero activation,
                    // mirroring the zero-range handling in the weight paths above.
                    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
                    QuantizationParams::symmetric_per_tensor(scale)
                }
                QuantizationMethod::Asymmetric => {
                    let range = max_val - min_val;
                    if range.abs() < 1e-8 {
                        // Degenerate (constant) activation: avoid a zero scale.
                        QuantizationParams::asymmetric_per_tensor(1.0, 0)
                    } else {
                        let scale = range / 255.0;
                        let zero_point =
                            (-128.0 - min_val / scale).round().clamp(-128.0, 127.0) as i8;
                        QuantizationParams::asymmetric_per_tensor(scale, zero_point)
                    }
                }
            }
        };

        let scale = params.scale[0];
        let zero_point = params.zero_point[0];

        let mut quantized = Vec::with_capacity(activation.len());
        for &x in activation.iter() {
            let q = match self.method {
                QuantizationMethod::Symmetric => (x / scale).round().clamp(-128.0, 127.0) as i8,
                QuantizationMethod::Asymmetric => {
                    let q_f = x / scale + zero_point as f32;
                    q_f.round().clamp(-128.0, 127.0) as i8
                }
            };
            quantized.push(q);
        }

        Ok(quantized)
    }

    /// Dequantizes a previously quantized 1D activation. Requires calibration,
    /// since the quantization parameters are not stored with the data.
    pub fn dequantize_activation_1d(
        &self,
        quantized: &[i8],
        original_len: usize,
    ) -> ModelResult<Array1<f32>> {
        if quantized.len() != original_len {
            return Err(ModelError::invalid_config(format!(
                "quantized length {} doesn't match expected {}",
                quantized.len(),
                original_len
            )));
        }

        let params = self
            .calibration
            .as_ref()
            .ok_or_else(|| ModelError::invalid_config("calibration required for dequantization"))?;

        let scale = params.scale[0];
        let zero_point = params.zero_point[0];

        let mut result = Array1::zeros(original_len);
        for (i, &q) in quantized.iter().enumerate() {
            result[i] = (q as i32 - zero_point as i32) as f32 * scale;
        }

        Ok(result)
    }

    /// Fake-quantizes an activation: quantizes to i8 and immediately
    /// dequantizes, returning the f32 values with quantization error applied.
    pub fn simulate_quantization(&self, activation: &Array1<f32>) -> ModelResult<Array1<f32>> {
        let min_val = activation.iter().copied().fold(f32::INFINITY, f32::min);
        let max_val = activation.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        let (scale, zero_point) = match self.method {
            QuantizationMethod::Symmetric => {
                let max_abs = max_val.abs().max(min_val.abs());
                // Fall back to a unit scale for an all-zero activation to avoid a zero scale.
                (if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 }, 0)
            }
            QuantizationMethod::Asymmetric => {
                let range = max_val - min_val;
                if range.abs() < 1e-8 {
                    // Degenerate (constant) activation: avoid a zero scale.
                    (1.0, 0)
                } else {
                    let scale = range / 255.0;
                    let zp = (-128.0 - min_val / scale).round().clamp(-128.0, 127.0) as i8;
                    (scale, zp)
                }
            }
        };

        let mut result = Array1::zeros(activation.len());
        for (i, &x) in activation.iter().enumerate() {
            let q = match self.method {
                QuantizationMethod::Symmetric => (x / scale).round().clamp(-128.0, 127.0) as i8,
                QuantizationMethod::Asymmetric => {
                    let q_f = x / scale + zero_point as f32;
                    q_f.round().clamp(-128.0, 127.0) as i8
                }
            };
            result[i] = (q as i32 - zero_point as i32) as f32 * scale;
        }

        Ok(result)
    }

    /// Memory savings of i8 storage relative to f32, as a percentage.
    pub fn memory_savings(&self) -> f32 {
        75.0
    }
}

impl Default for ActivationQuantizer {
    fn default() -> Self {
        Self::new_symmetric()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f32, b: f32, epsilon: f32) -> bool {
        (a - b).abs() < epsilon
    }

    #[test]
    fn test_symmetric_quantization_1d() {
        let array = Array1::from_vec(vec![-10.0, -5.0, 0.0, 5.0, 10.0]);
        let quantized = quantize_symmetric_1d(&array).expect("Failed to quantize 1d array");

        assert_eq!(quantized.shape, vec![5]);
        assert_eq!(quantized.params.method, QuantizationMethod::Symmetric);

        let dequantized = quantized
            .dequantize_1d()
            .expect("Failed to dequantize 1d array");
        for i in 0..5 {
            assert!(approx_eq(array[i], dequantized[i], 0.1));
        }
    }

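    // Sketch check: an all-zero input takes the zero-max fallback path, is
    // stored with a unit scale, and dequantizes back to exact zeros.
    #[test]
    fn test_symmetric_quantization_1d_zero_input() {
        let array = Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0]);
        let quantized = quantize_symmetric_1d(&array).expect("Failed to quantize zero array");

        assert_eq!(quantized.params.scale, vec![1.0]);
        assert!(quantized.data.iter().all(|&q| q == 0));

        let dequantized = quantized.dequantize_1d().expect("Failed to dequantize");
        for i in 0..4 {
            assert_eq!(dequantized[i], 0.0);
        }
    }
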
    #[test]
    fn test_symmetric_quantization_2d() {
        let array = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])
            .expect("Failed to create test array");
        let quantized = quantize_symmetric_2d(&array).expect("Failed to quantize 2d array");

        assert_eq!(quantized.shape, vec![2, 3]);

        let dequantized = quantized
            .dequantize_2d()
            .expect("Failed to dequantize 2d array");
        for i in 0..2 {
            for j in 0..3 {
                assert!(approx_eq(array[[i, j]], dequantized[[i, j]], 0.05));
            }
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        let array = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0])
            .expect("Failed to create test array");
        let quantized =
            quantize_symmetric_per_channel(&array).expect("Failed to quantize per channel");

        assert_eq!(
            quantized.params.granularity,
            QuantizationGranularity::PerChannel
        );
        assert_eq!(quantized.params.scale.len(), 2);

        let dequantized = quantized
            .dequantize_2d()
            .expect("Failed to dequantize 2d array");

        for i in 0..2 {
            for j in 0..3 {
                assert!(approx_eq(array[[i, j]], dequantized[[i, j]], 0.3));
            }
        }
    }

    #[test]
    fn test_asymmetric_quantization() {
        let array = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let quantized = quantize_asymmetric_1d(&array).expect("Failed to quantize asymmetric");

        assert_eq!(quantized.params.method, QuantizationMethod::Asymmetric);

        let dequantized = quantized.dequantize_1d().expect("Failed to dequantize");
        for i in 0..5 {
            assert!(approx_eq(array[i], dequantized[i], 0.05));
        }
    }

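    // Sketch checks on QuantizationParams::validate: mismatched lengths and
    // non-positive scales should be rejected, well-formed params accepted.
    #[test]
    fn test_quantization_params_validation() {
        let mismatched = QuantizationParams {
            scale: vec![1.0, 2.0],
            zero_point: vec![0],
            method: QuantizationMethod::Symmetric,
            granularity: QuantizationGranularity::PerChannel,
        };
        assert!(mismatched.validate().is_err());

        let zero_scale = QuantizationParams::symmetric_per_tensor(0.0);
        assert!(zero_scale.validate().is_err());

        assert!(QuantizationParams::symmetric_per_tensor(0.1)
            .validate()
            .is_ok());
    }
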
    #[test]
    fn test_calibration_stats() {
        let mut stats = CalibrationStats::new();

        let data1 = Array1::from_vec(vec![-5.0, 0.0, 5.0]);
        let data2 = Array1::from_vec(vec![-10.0, -2.0, 8.0]);

        stats.update_1d(&data1);
        stats.update_1d(&data2);

        assert_eq!(stats.min, -10.0);
        assert_eq!(stats.max, 8.0);
        assert_eq!(stats.count, 6);

        let params = stats.to_symmetric_params().expect("Failed to get params");
        assert!(approx_eq(params.scale[0], 10.0 / 127.0, 1e-6));
    }

    #[test]
    fn test_memory_savings() {
        let array = Array2::from_shape_vec((100, 100), vec![1.0; 10000])
            .expect("Failed to create test array");
        let quantized = quantize_symmetric_2d(&array).expect("Failed to quantize");

        let original_size = 10000 * 4;
        let quantized_size = quantized.memory_size();

        assert_eq!(quantized_size, 10000);
        assert!(quantized_size < original_size / 3);
    }

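    // Sketch check: QuantizedWeight::new rejects data whose length does not
    // match the product of the shape dimensions.
    #[test]
    fn test_quantized_weight_shape_mismatch() {
        let params = QuantizationParams::symmetric_per_tensor(0.1);
        let result = QuantizedWeight::new(vec![0i8; 5], vec![2, 3], params);
        assert!(result.is_err());
    }
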
    #[test]
    fn test_activation_quantizer_symmetric() {
        let quantizer = ActivationQuantizer::new_symmetric();
        let activation = Array1::from_vec(vec![-10.0, -5.0, 0.0, 5.0, 10.0]);

        let quantized = quantizer
            .quantize_activation_1d(&activation)
            .expect("Failed to quantize activation");
        assert_eq!(quantized.len(), activation.len());

        assert_eq!(quantizer.memory_savings(), 75.0);
    }

    #[test]
    fn test_activation_quantizer_asymmetric() {
        let quantizer = ActivationQuantizer::new_asymmetric();
        let activation = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let quantized = quantizer
            .quantize_activation_1d(&activation)
            .expect("Failed to quantize activation");
        assert_eq!(quantized.len(), activation.len());
    }

    #[test]
    fn test_activation_quantizer_with_calibration() {
        let mut quantizer = ActivationQuantizer::new_symmetric();

        let mut stats = CalibrationStats::new();
        stats.update_1d(&Array1::from_vec(vec![-10.0, 0.0, 10.0]));
        stats.update_1d(&Array1::from_vec(vec![-5.0, 0.0, 5.0]));

        quantizer.calibrate(&stats).expect("Failed to calibrate");

        let activation = Array1::from_vec(vec![-8.0, 0.0, 8.0]);
        let quantized = quantizer
            .quantize_activation_1d(&activation)
            .expect("Failed to quantize activation");

        let dequantized = quantizer
            .dequantize_activation_1d(&quantized, activation.len())
            .expect("Failed to dequantize activation");

        for i in 0..activation.len() {
            assert!((activation[i] - dequantized[i]).abs() < 1.0);
        }
    }

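    // Sketch check: dequantizing activations requires prior calibration, since
    // the quantizer does not store per-call parameters.
    #[test]
    fn test_dequantize_without_calibration_fails() {
        let quantizer = ActivationQuantizer::new_symmetric();
        let result = quantizer.dequantize_activation_1d(&[0i8, 1, 2], 3);
        assert!(result.is_err());
    }
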
    #[test]
    fn test_simulate_quantization() {
        let quantizer = ActivationQuantizer::new_symmetric();
        let activation = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let simulated = quantizer
            .simulate_quantization(&activation)
            .expect("Failed to simulate quantization");

        for i in 0..activation.len() {
            assert!((activation[i] - simulated[i]).abs() < 0.1);
        }
    }
}
769}