1use crate::{CoreError, CoreResult};
20use scirs2_core::ndarray::{Array1, Array2, Axis};
21use serde::{Deserialize, Serialize};
22
/// Target numeric format for quantized tensor data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationType {
    /// 8-bit signed integers; representable range -128..=127.
    INT8,
    /// 4-bit signed integers; representable range -8..=7 (stored in an `i8`).
    INT4,
    /// 16-bit floating point. NOTE(review): `qrange()` reports `(0, 0)` for
    /// this variant and the `i8`-backed quantizers cannot store it.
    FP16,
}
33
/// Granularity at which scale/zero-point pairs are applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// A single scale and zero-point for the whole tensor.
    PerTensor,
    /// One scale and zero-point per channel (axis 0 / `shape[0]` entries).
    PerChannel,
}
42
/// Parameters describing how a tensor was quantized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Numeric format the data was quantized to.
    pub qtype: QuantizationType,
    /// Granularity of the scale/zero-point pairs.
    pub scheme: QuantizationScheme,
    /// Scale factors: 1 entry for `PerTensor`, `shape[0]` entries for `PerChannel`.
    pub scales: Vec<f32>,
    /// Zero points, parallel to `scales`.
    pub zero_points: Vec<i32>,
    /// Logical shape of the original (unquantized) tensor.
    pub shape: Vec<usize>,
}
57
58impl QuantizationParams {
59 pub fn new(
61 qtype: QuantizationType,
62 scheme: QuantizationScheme,
63 scales: Vec<f32>,
64 zero_points: Vec<i32>,
65 shape: Vec<usize>,
66 ) -> Self {
67 Self {
68 qtype,
69 scheme,
70 scales,
71 zero_points,
72 shape,
73 }
74 }
75
76 pub fn qrange(&self) -> (i32, i32) {
78 match self.qtype {
79 QuantizationType::INT8 => (-128, 127),
80 QuantizationType::INT4 => (-8, 7),
81 QuantizationType::FP16 => (0, 0), }
83 }
84
85 pub fn validate(&self) -> CoreResult<()> {
87 match self.scheme {
88 QuantizationScheme::PerTensor => {
89 if self.scales.len() != 1 || self.zero_points.len() != 1 {
90 return Err(CoreError::InvalidConfig(
91 "PerTensor scheme requires exactly 1 scale and zero-point".into(),
92 ));
93 }
94 }
95 QuantizationScheme::PerChannel => {
96 if self.shape.is_empty() {
97 return Err(CoreError::InvalidConfig(
98 "PerChannel scheme requires shape information".into(),
99 ));
100 }
101 let num_channels = self.shape[0];
102 if self.scales.len() != num_channels || self.zero_points.len() != num_channels {
103 return Err(CoreError::InvalidConfig(format!(
104 "PerChannel scheme requires {} scales and zero-points, got {} and {}",
105 num_channels,
106 self.scales.len(),
107 self.zero_points.len()
108 )));
109 }
110 }
111 }
112 Ok(())
113 }
114}
115
/// A quantized tensor: raw `i8` values plus the parameters needed to
/// reconstruct (dequantize) the original `f32` data.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized values in row-major order (INT4 values also occupy one `i8` each).
    pub data: Vec<i8>,
    /// Scales, zero-points, scheme and original shape.
    pub params: QuantizationParams,
}
124
125impl QuantizedTensor {
126 pub fn new(data: Vec<i8>, params: QuantizationParams) -> CoreResult<Self> {
128 params.validate()?;
129 Ok(Self { data, params })
130 }
131
132 pub fn dequantize_1d(&self) -> CoreResult<Array1<f32>> {
134 if self.params.shape.len() != 1 {
135 return Err(CoreError::InvalidConfig(
136 "Expected 1D tensor for dequantize_1d".into(),
137 ));
138 }
139
140 let size = self.params.shape[0];
141 let mut result = Array1::zeros(size);
142
143 match self.params.scheme {
144 QuantizationScheme::PerTensor => {
145 let scale = self.params.scales[0];
146 let zero_point = self.params.zero_points[0];
147
148 for (i, &q_val) in self.data.iter().enumerate() {
149 result[i] = (q_val as i32 - zero_point) as f32 * scale;
150 }
151 }
152 QuantizationScheme::PerChannel => {
153 let scale = self.params.scales[0];
155 let zero_point = self.params.zero_points[0];
156
157 for (i, &q_val) in self.data.iter().enumerate() {
158 result[i] = (q_val as i32 - zero_point) as f32 * scale;
159 }
160 }
161 }
162
163 Ok(result)
164 }
165
166 pub fn dequantize_2d(&self) -> CoreResult<Array2<f32>> {
168 if self.params.shape.len() != 2 {
169 return Err(CoreError::InvalidConfig(
170 "Expected 2D tensor for dequantize_2d".into(),
171 ));
172 }
173
174 let rows = self.params.shape[0];
175 let cols = self.params.shape[1];
176 let mut result = Array2::zeros((rows, cols));
177
178 match self.params.scheme {
179 QuantizationScheme::PerTensor => {
180 let scale = self.params.scales[0];
181 let zero_point = self.params.zero_points[0];
182
183 for i in 0..rows {
184 for j in 0..cols {
185 let idx = i * cols + j;
186 let q_val = self.data[idx];
187 result[[i, j]] = (q_val as i32 - zero_point) as f32 * scale;
188 }
189 }
190 }
191 QuantizationScheme::PerChannel => {
192 for i in 0..rows {
194 let scale = self.params.scales[i];
195 let zero_point = self.params.zero_points[i];
196
197 for j in 0..cols {
198 let idx = i * cols + j;
199 let q_val = self.data[idx];
200 result[[i, j]] = (q_val as i32 - zero_point) as f32 * scale;
201 }
202 }
203 }
204 }
205
206 Ok(result)
207 }
208
209 pub fn compression_ratio(&self) -> f32 {
211 let original_size = self.data.len() * std::mem::size_of::<f32>();
212 let quantized_size = self.data.len() * std::mem::size_of::<i8>()
213 + self.params.scales.len() * std::mem::size_of::<f32>()
214 + self.params.zero_points.len() * std::mem::size_of::<i32>();
215 original_size as f32 / quantized_size as f32
216 }
217}
218
/// Quantizer that derives scales/zero-points dynamically from each input
/// tensor's observed min/max range.
pub struct DynamicQuantizer {
    // Target numeric format (INT8 / INT4 / FP16).
    qtype: QuantizationType,
    // Per-tensor or per-channel scale granularity.
    scheme: QuantizationScheme,
}
226
227impl DynamicQuantizer {
228 pub fn new(qtype: QuantizationType, scheme: QuantizationScheme) -> Self {
230 Self { qtype, scheme }
231 }
232
233 pub fn int8_per_tensor() -> Self {
235 Self::new(QuantizationType::INT8, QuantizationScheme::PerTensor)
236 }
237
238 pub fn int8_per_channel() -> Self {
240 Self::new(QuantizationType::INT8, QuantizationScheme::PerChannel)
241 }
242
243 pub fn int4_per_channel() -> Self {
245 Self::new(QuantizationType::INT4, QuantizationScheme::PerChannel)
246 }
247
248 pub fn quantize_1d(&self, data: &Array1<f32>) -> CoreResult<QuantizedTensor> {
250 let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
251 let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
252
253 let (qmin, qmax) = self.get_qrange();
254
255 let scale = if (max_val - min_val).abs() < 1e-8 {
257 1.0
258 } else {
259 (max_val - min_val) / (qmax - qmin) as f32
260 };
261
262 let zero_point = if (max_val - min_val).abs() < 1e-8 {
263 0
264 } else {
265 qmin - (min_val / scale).round() as i32
266 };
267
268 let mut quantized = Vec::with_capacity(data.len());
269 for &val in data.iter() {
270 let q_val = (val / scale).round() as i32 + zero_point;
271 let q_val_clamped = q_val.clamp(qmin, qmax);
272 quantized.push(q_val_clamped as i8);
273 }
274
275 let params = QuantizationParams::new(
276 self.qtype,
277 self.scheme,
278 vec![scale],
279 vec![zero_point],
280 vec![data.len()],
281 );
282
283 QuantizedTensor::new(quantized, params)
284 }
285
286 pub fn quantize_2d(&self, data: &Array2<f32>) -> CoreResult<QuantizedTensor> {
288 let (rows, cols) = data.dim();
289 let (qmin, qmax) = self.get_qrange();
290
291 match self.scheme {
292 QuantizationScheme::PerTensor => {
293 let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
295 let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
296
297 let scale = (max_val - min_val) / (qmax - qmin) as f32;
298 let zero_point = qmin - (min_val / scale).round() as i32;
299
300 let mut quantized = Vec::with_capacity(rows * cols);
301 for &val in data.iter() {
302 let q_val = (val / scale).round() as i32 + zero_point;
303 let q_val_clamped = q_val.clamp(qmin, qmax);
304 quantized.push(q_val_clamped as i8);
305 }
306
307 let params = QuantizationParams::new(
308 self.qtype,
309 self.scheme,
310 vec![scale],
311 vec![zero_point],
312 vec![rows, cols],
313 );
314
315 QuantizedTensor::new(quantized, params)
316 }
317 QuantizationScheme::PerChannel => {
318 let mut scales = Vec::with_capacity(rows);
320 let mut zero_points = Vec::with_capacity(rows);
321 let mut quantized = Vec::with_capacity(rows * cols);
322
323 for row in data.axis_iter(Axis(0)) {
324 let min_val = row.iter().cloned().fold(f32::INFINITY, f32::min);
325 let max_val = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
326
327 let scale = (max_val - min_val) / (qmax - qmin) as f32;
328 let zero_point = qmin - (min_val / scale).round() as i32;
329
330 scales.push(scale);
331 zero_points.push(zero_point);
332
333 for &val in row.iter() {
334 let q_val = (val / scale).round() as i32 + zero_point;
335 let q_val_clamped = q_val.clamp(qmin, qmax);
336 quantized.push(q_val_clamped as i8);
337 }
338 }
339
340 let params = QuantizationParams::new(
341 self.qtype,
342 self.scheme,
343 scales,
344 zero_points,
345 vec![rows, cols],
346 );
347
348 QuantizedTensor::new(quantized, params)
349 }
350 }
351 }
352
353 fn get_qrange(&self) -> (i32, i32) {
355 match self.qtype {
356 QuantizationType::INT8 => (-128, 127),
357 QuantizationType::INT4 => (-8, 7),
358 QuantizationType::FP16 => (0, 0),
359 }
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience constructor for the validation tests (always INT8).
    fn make_params(
        scheme: QuantizationScheme,
        scales: Vec<f32>,
        zero_points: Vec<i32>,
        shape: Vec<usize>,
    ) -> QuantizationParams {
        QuantizationParams::new(QuantizationType::INT8, scheme, scales, zero_points, shape)
    }

    #[test]
    fn test_quantization_types() {
        assert_eq!(QuantizationType::INT8, QuantizationType::INT8);
        assert_eq!(QuantizationScheme::PerTensor, QuantizationScheme::PerTensor);
    }

    #[test]
    fn test_quantization_params() {
        let params = make_params(QuantizationScheme::PerTensor, vec![0.1], vec![0], vec![100]);

        assert_eq!(params.qtype, QuantizationType::INT8);
        assert_eq!(params.qrange(), (-128, 127));
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_params_validation() {
        // PerTensor with two scales must be rejected.
        let too_many_scales =
            make_params(QuantizationScheme::PerTensor, vec![0.1, 0.2], vec![0], vec![100]);
        assert!(too_many_scales.validate().is_err());

        // PerChannel with a scale count that does not match shape[0].
        let mismatched =
            make_params(QuantizationScheme::PerChannel, vec![0.1], vec![0, 1], vec![2, 100]);
        assert!(mismatched.validate().is_err());

        // PerChannel with matching counts is accepted.
        let well_formed =
            make_params(QuantizationScheme::PerChannel, vec![0.1, 0.2], vec![0, 1], vec![2, 100]);
        assert!(well_formed.validate().is_ok());
    }

    #[test]
    fn test_dynamic_quantizer_creation() {
        let q8 = DynamicQuantizer::int8_per_tensor();
        assert_eq!(q8.qtype, QuantizationType::INT8);
        assert_eq!(q8.scheme, QuantizationScheme::PerTensor);

        let q4 = DynamicQuantizer::int4_per_channel();
        assert_eq!(q4.qtype, QuantizationType::INT4);
        assert_eq!(q4.scheme, QuantizationScheme::PerChannel);
    }

    #[test]
    fn test_quantize_dequantize_1d() {
        let quantizer = DynamicQuantizer::int8_per_tensor();
        let input = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let qt = quantizer.quantize_1d(&input).unwrap();
        assert_eq!(qt.data.len(), 5);

        let roundtrip = qt.dequantize_1d().unwrap();
        assert_eq!(roundtrip.len(), 5);

        for (got, want) in roundtrip.iter().zip(input.iter()) {
            let error = (got - want).abs();
            assert!(error < 0.1, "Reconstruction error too large: {}", error);
        }
    }

    #[test]
    fn test_quantize_dequantize_2d() {
        let quantizer = DynamicQuantizer::int8_per_tensor();
        let input = Array2::from_shape_fn((4, 4), |(i, j)| (i * 4 + j) as f32);

        let qt = quantizer.quantize_2d(&input).unwrap();
        assert_eq!(qt.data.len(), 16);

        let roundtrip = qt.dequantize_2d().unwrap();
        assert_eq!(roundtrip.shape(), &[4, 4]);

        for ((i, j), &want) in input.indexed_iter() {
            let error = (roundtrip[[i, j]] - want).abs();
            assert!(error < 0.5, "Reconstruction error too large: {}", error);
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        let quantizer = DynamicQuantizer::int8_per_channel();
        let input = Array2::from_shape_fn((3, 4), |(i, j)| (i * 10 + j) as f32);

        let qt = quantizer.quantize_2d(&input).unwrap();
        assert_eq!(qt.params.scales.len(), 3);
        assert_eq!(qt.params.zero_points.len(), 3);

        let roundtrip = qt.dequantize_2d().unwrap();
        assert_eq!(roundtrip.shape(), &[3, 4]);

        for ((i, j), &want) in input.indexed_iter() {
            let error = (roundtrip[[i, j]] - want).abs();
            assert!(error < 1.0, "Error at [{}, {}]: {}", i, j, error);
        }
    }

    #[test]
    fn test_compression_ratio() {
        let quantizer = DynamicQuantizer::int8_per_tensor();
        let input = Array2::from_shape_fn((100, 100), |(i, j)| (i + j) as f32);

        let ratio = quantizer.quantize_2d(&input).unwrap().compression_ratio();

        // i8 payload plus one scale/zero-point pair: just under 4x.
        assert!(
            ratio > 3.5 && ratio < 4.1,
            "Unexpected compression ratio: {}",
            ratio
        );
    }

    #[test]
    fn test_qrange() {
        assert_eq!(DynamicQuantizer::int8_per_tensor().get_qrange(), (-128, 127));
        assert_eq!(DynamicQuantizer::int4_per_channel().get_qrange(), (-8, 7));
    }

    #[test]
    fn test_extreme_values() {
        let quantizer = DynamicQuantizer::int8_per_tensor();
        let input = Array1::from_vec(vec![-100.0, -50.0, 0.0, 50.0, 100.0]);

        let roundtrip = quantizer
            .quantize_1d(&input)
            .unwrap()
            .dequantize_1d()
            .unwrap();

        for (idx, (got, want)) in roundtrip.iter().zip(input.iter()).enumerate() {
            let error_pct = ((got - want) / want.abs().max(1.0)).abs();
            assert!(
                error_pct < 0.05,
                "Large error at index {}: {}%",
                idx,
                error_pct * 100.0
            );
        }
    }
}