1use std::collections::HashMap;
9
10use crate::dtype::Float;
11use crate::error::{FerrotorchError, FerrotorchResult};
12use crate::storage::TensorStorage;
13use crate::tensor::Tensor;
14
/// How quantization parameters are shared across a tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantScheme {
    /// A single (scale, zero_point) pair for the whole tensor.
    PerTensor,
    /// An independent (scale, zero_point) pair per slice along the given axis.
    PerChannel(usize),
}
27
/// Target integer representation for quantized values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantDtype {
    /// Signed 8-bit, range [-128, 127].
    Int8,
    /// Signed 4-bit, range [-8, 7] (stored in an i8 slot).
    Int4,
    /// Unsigned 8-bit, range [0, 255] (bit-cast into an i8 slot).
    Uint8,
}
38
39impl QuantDtype {
40 #[inline]
42 fn qmin(self) -> i32 {
43 match self {
44 QuantDtype::Int8 => -128,
45 QuantDtype::Int4 => -8,
46 QuantDtype::Uint8 => 0,
47 }
48 }
49
50 #[inline]
52 fn qmax(self) -> i32 {
53 match self {
54 QuantDtype::Int8 => 127,
55 QuantDtype::Int4 => 7,
56 QuantDtype::Uint8 => 255,
57 }
58 }
59}
60
/// An integer-quantized tensor plus the affine parameters needed to
/// reconstruct approximate real values: `real ≈ (q - zero_point) * scale`.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    // Quantized values; Uint8 payloads are bit-cast into the i8 storage.
    data: Vec<i8>,
    // One scale for PerTensor; one per channel for PerChannel.
    scale: Vec<f32>,
    // Zero points, parallel to `scale`.
    zero_point: Vec<i32>,
    // Logical shape; flat indexing assumes row-major layout.
    shape: Vec<usize>,
    // Scheme used to produce this tensor.
    scheme: QuantScheme,
    // Integer dtype of the stored values.
    dtype: QuantDtype,
}
89
90impl QuantizedTensor {
91 #[inline]
93 pub fn numel(&self) -> usize {
94 self.shape.iter().product()
95 }
96
97 #[inline]
99 pub fn shape(&self) -> &[usize] {
100 &self.shape
101 }
102
103 #[inline]
105 pub fn data(&self) -> &[i8] {
106 &self.data
107 }
108
109 #[inline]
111 pub fn scale(&self) -> &[f32] {
112 &self.scale
113 }
114
115 #[inline]
117 pub fn zero_point(&self) -> &[i32] {
118 &self.zero_point
119 }
120
121 #[inline]
123 pub fn scheme(&self) -> QuantScheme {
124 self.scheme
125 }
126
127 #[inline]
129 pub fn qdtype(&self) -> QuantDtype {
130 self.dtype
131 }
132}
133
134fn compute_scale_zp(min_val: f32, max_val: f32, dtype: QuantDtype) -> (f32, i32) {
149 let qmin = dtype.qmin();
150 let qmax = dtype.qmax();
151
152 let min_val = min_val.min(0.0);
154 let max_val = max_val.max(0.0);
155
156 let range = (max_val - min_val).max(f32::EPSILON);
159 let scale = range / (qmax - qmin) as f32;
160
161 let zp = (qmin as f32 - min_val / scale).round() as i32;
166
167 (scale, zp)
168}
169
#[inline]
/// Quantizes a single real value: scale, shift by the zero point, round,
/// clamp to [qmin, qmax], then store in an i8 slot (bit-cast for unsigned).
fn quantize_val(x: f32, scale: f32, zp: i32, qmin: i32, qmax: i32, is_unsigned: bool) -> i8 {
    let q = (x / scale + zp as f32).round() as i32;
    let q = q.clamp(qmin, qmax);
    // Unsigned values (0..=255) are reinterpreted as the i8 bit pattern.
    if is_unsigned { (q as u8) as i8 } else { q as i8 }
}
185
#[inline]
/// Widens a stored i8 slot to i32, undoing the unsigned bit-cast when needed.
fn stored_to_i32(val: i8, is_unsigned: bool) -> i32 {
    match is_unsigned {
        // The slot actually holds a u8 bit pattern: recover 0..=255.
        true => i32::from(val as u8),
        false => i32::from(val),
    }
}
196
#[inline]
/// Maps a flat (row-major) element index to its channel index along `axis`.
fn channel_index(flat_index: usize, shape: &[usize], axis: usize) -> usize {
    // Elements spanned by one step along `axis` = product of trailing dims.
    let inner: usize = shape[axis + 1..].iter().product();
    flat_index / inner % shape[axis]
}
207
208pub fn quantize<T: Float>(
223 tensor: &Tensor<T>,
224 scheme: QuantScheme,
225 dtype: QuantDtype,
226) -> FerrotorchResult<QuantizedTensor> {
227 let data = tensor.data()?;
228 let shape = tensor.shape().to_vec();
229 let numel = tensor.numel();
230 let qmin = dtype.qmin();
231 let qmax = dtype.qmax();
232
233 let is_unsigned = dtype == QuantDtype::Uint8;
234
235 match scheme {
236 QuantScheme::PerTensor => {
237 let mut min_val = f32::INFINITY;
239 let mut max_val = f32::NEG_INFINITY;
240 for &v in data {
241 let f = v.to_f32().unwrap();
242 if f < min_val {
243 min_val = f;
244 }
245 if f > max_val {
246 max_val = f;
247 }
248 }
249
250 let (scale, zp) = compute_scale_zp(min_val, max_val, dtype);
251
252 let qdata: Vec<i8> = data
253 .iter()
254 .map(|&v| {
255 quantize_val(v.to_f32().unwrap(), scale, zp, qmin, qmax, is_unsigned)
256 })
257 .collect();
258
259 Ok(QuantizedTensor {
260 data: qdata,
261 scale: vec![scale],
262 zero_point: vec![zp],
263 shape,
264 scheme,
265 dtype,
266 })
267 }
268
269 QuantScheme::PerChannel(axis) => {
270 if axis >= shape.len() {
271 return Err(FerrotorchError::InvalidArgument {
272 message: format!(
273 "PerChannel axis {axis} out of range for {}-d tensor",
274 shape.len()
275 ),
276 });
277 }
278
279 let num_channels = shape[axis];
280 let mut mins = vec![f32::INFINITY; num_channels];
281 let mut maxs = vec![f32::NEG_INFINITY; num_channels];
282
283 for (i, &v) in data.iter().enumerate() {
284 let ch = channel_index(i, &shape, axis);
285 let f = v.to_f32().unwrap();
286 if f < mins[ch] {
287 mins[ch] = f;
288 }
289 if f > maxs[ch] {
290 maxs[ch] = f;
291 }
292 }
293
294 let params: Vec<(f32, i32)> = mins
295 .iter()
296 .zip(maxs.iter())
297 .map(|(&mn, &mx)| compute_scale_zp(mn, mx, dtype))
298 .collect();
299
300 let scales: Vec<f32> = params.iter().map(|&(s, _)| s).collect();
301 let zps: Vec<i32> = params.iter().map(|&(_, z)| z).collect();
302
303 let mut qdata = Vec::with_capacity(numel);
304 for (i, &v) in data.iter().enumerate() {
305 let ch = channel_index(i, &shape, axis);
306 qdata.push(quantize_val(
307 v.to_f32().unwrap(),
308 scales[ch],
309 zps[ch],
310 qmin,
311 qmax,
312 is_unsigned,
313 ));
314 }
315
316 Ok(QuantizedTensor {
317 data: qdata,
318 scale: scales,
319 zero_point: zps,
320 shape,
321 scheme,
322 dtype,
323 })
324 }
325 }
326}
327
328pub fn dequantize<T: Float>(qtensor: &QuantizedTensor) -> FerrotorchResult<Tensor<T>> {
336 let numel = qtensor.numel();
337 let mut result = Vec::with_capacity(numel);
338 let is_unsigned = qtensor.dtype == QuantDtype::Uint8;
339
340 match qtensor.scheme {
341 QuantScheme::PerTensor => {
342 let scale = qtensor.scale[0];
343 let zp = qtensor.zero_point[0];
344 for &q in &qtensor.data {
345 let val = (stored_to_i32(q, is_unsigned) - zp) as f32 * scale;
346 result.push(T::from(val).unwrap());
347 }
348 }
349 QuantScheme::PerChannel(axis) => {
350 for (i, &q) in qtensor.data.iter().enumerate() {
351 let ch = channel_index(i, &qtensor.shape, axis);
352 let val = (stored_to_i32(q, is_unsigned) - qtensor.zero_point[ch]) as f32
353 * qtensor.scale[ch];
354 result.push(T::from(val).unwrap());
355 }
356 }
357 }
358
359 Tensor::from_storage(TensorStorage::cpu(result), qtensor.shape.clone(), false)
360}
361
362pub fn quantized_matmul(
375 a: &QuantizedTensor,
376 b: &QuantizedTensor,
377) -> FerrotorchResult<QuantizedTensor> {
378 if a.shape.len() != 2 || b.shape.len() != 2 {
380 return Err(FerrotorchError::InvalidArgument {
381 message: format!(
382 "quantized_matmul requires 2-D tensors, got shapes {:?} and {:?}",
383 a.shape, b.shape
384 ),
385 });
386 }
387
388 let m = a.shape[0];
389 let k = a.shape[1];
390 let k2 = b.shape[0];
391 let n = b.shape[1];
392
393 if k != k2 {
394 return Err(FerrotorchError::ShapeMismatch {
395 message: format!(
396 "quantized_matmul inner dimensions mismatch: [{m}, {k}] x [{k2}, {n}]"
397 ),
398 });
399 }
400
401 if a.scale.len() != 1 || b.scale.len() != 1 {
403 return Err(FerrotorchError::InvalidArgument {
404 message: "quantized_matmul currently requires PerTensor-quantized inputs".into(),
405 });
406 }
407
408 let a_scale = a.scale[0];
409 let a_zp = a.zero_point[0];
410 let b_scale = b.scale[0];
411 let b_zp = b.zero_point[0];
412
413 let a_unsigned = a.dtype == QuantDtype::Uint8;
414 let b_unsigned = b.dtype == QuantDtype::Uint8;
415
416 let mut acc = vec![0i32; m * n];
418 for i in 0..m {
419 for j in 0..n {
420 let mut sum = 0i32;
421 for p in 0..k {
422 let qa = stored_to_i32(a.data[i * k + p], a_unsigned) - a_zp;
423 let qb = stored_to_i32(b.data[p * n + j], b_unsigned) - b_zp;
424 sum += qa * qb;
425 }
426 acc[i * n + j] = sum;
427 }
428 }
429
430 let combined_scale = a_scale * b_scale;
433
434 let mut out_min = f32::INFINITY;
436 let mut out_max = f32::NEG_INFINITY;
437 for &a_val in &acc {
438 let real = a_val as f32 * combined_scale;
439 if real < out_min {
440 out_min = real;
441 }
442 if real > out_max {
443 out_max = real;
444 }
445 }
446
447 let out_dtype = QuantDtype::Int8;
448 let (out_scale, out_zp) = compute_scale_zp(out_min, out_max, out_dtype);
449 let qmin = out_dtype.qmin();
450 let qmax = out_dtype.qmax();
451
452 let qdata: Vec<i8> = acc
453 .iter()
454 .map(|&a_val| {
455 let real = a_val as f32 * combined_scale;
456 quantize_val(real, out_scale, out_zp, qmin, qmax, false)
457 })
458 .collect();
459
460 Ok(QuantizedTensor {
461 data: qdata,
462 scale: vec![out_scale],
463 zero_point: vec![out_zp],
464 shape: vec![m, n],
465 scheme: QuantScheme::PerTensor,
466 dtype: out_dtype,
467 })
468}
469
470pub fn quantize_named_tensors<T: Float>(
481 named_tensors: impl IntoIterator<Item = (String, Tensor<T>)>,
482 scheme: QuantScheme,
483 dtype: QuantDtype,
484) -> FerrotorchResult<HashMap<String, QuantizedTensor>> {
485 let mut result = HashMap::new();
486 for (name, tensor) in named_tensors {
487 let qtensor = quantize(&tensor, scheme, dtype)?;
488 result.insert(name, qtensor);
489 }
490 Ok(result)
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU f32 tensor from a slice; panics on invalid shape.
    fn make_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        crate::from_slice(data, shape).unwrap()
    }

    // Round-trip error for Int8 over [-5, 5] should stay within ~half a step.
    #[test]
    fn test_per_tensor_int8_roundtrip() {
        let data: Vec<f32> = (-10..=10).map(|x| x as f32 * 0.5).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        assert_eq!(rt.shape(), t.shape());
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            assert!(
                err < 0.05,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // Non-negative data exercises the Uint8 bit-cast storage path.
    #[test]
    fn test_per_tensor_uint8_roundtrip() {
        let data: Vec<f32> = (0..=20).map(|x| x as f32 * 0.1).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Uint8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            assert!(
                err < 0.02,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // Int4 has only 16 levels, so the tolerance is a full unit step.
    #[test]
    fn test_per_tensor_int4_roundtrip() {
        let data: Vec<f32> = (-8..=7).map(|x| x as f32).collect();
        let t = make_tensor(&data, &[data.len()]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int4).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            assert!(
                err < 1.01,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // Rows with very different ranges: per-channel params keep each row's
    // error small where a per-tensor scale would not.
    #[test]
    fn test_per_channel_int8_roundtrip() {
        #[rustfmt::skip]
        let data: Vec<f32> = vec![
            0.0, 1.0, 2.0, 3.0,
            -10.0, -5.0, 5.0, 10.0,
            100.0, 130.0, 170.0, 200.0,
        ];
        let t = make_tensor(&data, &[3, 4]);
        let qt = quantize(&t, QuantScheme::PerChannel(0), QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        assert_eq!(qt.scale.len(), 3);
        assert_eq!(qt.zero_point.len(), 3);

        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            assert!(
                err < 0.5,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_per_channel_axis_out_of_bounds() {
        let t = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        let result = quantize(&t, QuantScheme::PerChannel(5), QuantDtype::Int8);
        assert!(result.is_err());
    }

    // A x I should reproduce A (within requantization error).
    #[test]
    fn test_quantized_matmul_identity() {
        let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
        let a = make_tensor(&a_data, &[2, 2]);
        let eye = make_tensor(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qi = quantize(&eye, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qc = quantized_matmul(&qa, &qi).unwrap();
        let c: Tensor<f32> = dequantize(&qc).unwrap();

        assert_eq!(c.shape(), &[2, 2]);
        let c_data = c.data().unwrap();
        for (i, (&expected, &got)) in a_data.iter().zip(c_data.iter()).enumerate() {
            let err = (expected - got).abs();
            assert!(
                err < 0.5,
                "element {i}: expected={expected}, got={got}, error={err}"
            );
        }
    }

    // Compare against the exact float matmul result; quantization noise
    // compounds through the product, hence the loose 3.0 tolerance.
    #[test]
    fn test_quantized_matmul_correctness() {
        let a = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_tensor(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qc = quantized_matmul(&qa, &qb).unwrap();
        let c: Tensor<f32> = dequantize(&qc).unwrap();

        let expected = [58.0f32, 64.0, 139.0, 154.0];
        let c_data = c.data().unwrap();
        assert_eq!(c.shape(), &[2, 2]);
        for (i, (&e, &g)) in expected.iter().zip(c_data.iter()).enumerate() {
            let err = (e - g).abs();
            assert!(
                err < 3.0,
                "element {i}: expected={e}, got={g}, error={err}"
            );
        }
    }

    #[test]
    fn test_quantized_matmul_shape_mismatch() {
        let a = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let b = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let result = quantized_matmul(&qa, &qb);
        assert!(result.is_err());
    }

    #[test]
    fn test_quantized_matmul_non_2d() {
        let a = make_tensor(&[1.0, 2.0, 3.0], &[3]);
        let b = make_tensor(&[4.0, 5.0, 6.0], &[3]);

        let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let result = quantized_matmul(&qa, &qb);
        assert!(result.is_err());
    }

    #[test]
    fn test_quantize_named_tensors() {
        let w1 = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let w2 = make_tensor(&[-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], &[3, 2]);

        let named = vec![
            ("layer.weight".to_string(), w1),
            ("layer2.weight".to_string(), w2),
        ];

        let qmap =
            quantize_named_tensors(named, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();

        assert_eq!(qmap.len(), 2);
        assert!(qmap.contains_key("layer.weight"));
        assert!(qmap.contains_key("layer2.weight"));
        assert_eq!(qmap["layer.weight"].shape(), &[2, 2]);
        assert_eq!(qmap["layer2.weight"].shape(), &[3, 2]);
    }

    // Zero-width data range: the EPSILON floor in compute_scale_zp must keep
    // the round trip stable rather than dividing by zero.
    #[test]
    fn test_quantize_constant_tensor() {
        let t = make_tensor(&[5.0, 5.0, 5.0, 5.0], &[4]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();

        let recovered = rt.data().unwrap();
        for &r in recovered {
            assert!(
                (r - 5.0).abs() < 0.1,
                "constant tensor dequantized to {r}, expected 5.0"
            );
        }
    }

    #[test]
    fn test_quantize_single_element() {
        let t = make_tensor(&[42.0], &[1]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        assert!((rt.data().unwrap()[0] - 42.0).abs() < 0.5);
    }

    #[test]
    fn test_per_channel_int4() {
        let data = vec![0.0, 1.0, 2.0, -4.0, 0.0, 4.0];
        let t = make_tensor(&data, &[2, 3]);
        let qt = quantize(&t, QuantScheme::PerChannel(0), QuantDtype::Int4).unwrap();

        assert_eq!(qt.scale.len(), 2);
        assert_eq!(qt.zero_point.len(), 2);

        let rt: Tensor<f32> = dequantize(&qt).unwrap();
        let orig = t.data().unwrap();
        let recovered = rt.data().unwrap();
        for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
            let err = (o - r).abs();
            assert!(
                err < 1.0,
                "element {i}: original={o}, recovered={r}, error={err}"
            );
        }
    }

    // Dequantize is generic over Float: check the f64 output path too.
    #[test]
    fn test_dequantize_f64() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let t = crate::from_slice(&data, &[4]).unwrap();
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
        let rt: Tensor<f64> = dequantize(&qt).unwrap();

        assert_eq!(rt.shape(), &[4]);
        let recovered = rt.data().unwrap();
        for (i, &r) in recovered.iter().enumerate() {
            let expected = data[i] as f64;
            let err = (expected - r).abs();
            assert!(
                err < 0.05,
                "element {i}: expected={expected}, recovered={r}, error={err}"
            );
        }
    }

    #[test]
    fn test_quantized_tensor_accessors() {
        let t = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();

        assert_eq!(qt.numel(), 6);
        assert_eq!(qt.shape(), &[2, 3]);
        assert_eq!(qt.data().len(), 6);
        assert_eq!(qt.scale().len(), 1);
        assert_eq!(qt.zero_point().len(), 1);
        assert_eq!(qt.scheme(), QuantScheme::PerTensor);
        assert_eq!(qt.qdtype(), QuantDtype::Int8);
    }
}