1use std::collections::HashMap;
9
10use crate::dtype::Float;
11use crate::error::{FerrotorchError, FerrotorchResult};
12use crate::storage::TensorStorage;
13use crate::tensor::Tensor;
14
/// Granularity at which quantization parameters (scale, zero-point) are
/// computed and stored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantScheme {
    /// A single (scale, zero-point) pair for the whole tensor.
    PerTensor,
    /// One (scale, zero-point) pair per slice along the given axis.
    PerChannel(usize),
}
27
/// Target quantized dtype. All variants are stored one value per byte in an
/// `i8` backing array (Int4 is not bit-packed; Uint8 is bit-cast into i8).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantDtype {
    /// Signed 8-bit: quantized range [-128, 127].
    Int8,
    /// Signed 4-bit: quantized range [-8, 7].
    Int4,
    /// Unsigned 8-bit: quantized range [0, 255].
    Uint8,
}
38
39impl QuantDtype {
40 #[inline]
42 fn qmin(self) -> i32 {
43 match self {
44 QuantDtype::Int8 => -128,
45 QuantDtype::Int4 => -8,
46 QuantDtype::Uint8 => 0,
47 }
48 }
49
50 #[inline]
52 fn qmax(self) -> i32 {
53 match self {
54 QuantDtype::Int8 => 127,
55 QuantDtype::Int4 => 7,
56 QuantDtype::Uint8 => 255,
57 }
58 }
59}
60
/// An affine-quantized tensor: i8-backed values plus the scale/zero-point
/// parameters required to map them back to real values.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    // Quantized values; Uint8 values are bit-cast into the i8 storage.
    data: Vec<i8>,
    // One entry for PerTensor, one per channel for PerChannel.
    scale: Vec<f32>,
    // Zero-points, parallel to `scale`.
    zero_point: Vec<i32>,
    // Logical tensor shape; `channel_index` assumes row-major (C-order) layout.
    shape: Vec<usize>,
    // Quantization granularity used to produce this tensor.
    scheme: QuantScheme,
    // Quantized storage dtype.
    dtype: QuantDtype,
}
89
90impl QuantizedTensor {
91 #[inline]
93 pub fn numel(&self) -> usize {
94 self.shape.iter().product()
95 }
96
97 #[inline]
99 pub fn shape(&self) -> &[usize] {
100 &self.shape
101 }
102
103 #[inline]
105 pub fn data(&self) -> &[i8] {
106 &self.data
107 }
108
109 #[inline]
111 pub fn scale(&self) -> &[f32] {
112 &self.scale
113 }
114
115 #[inline]
117 pub fn zero_point(&self) -> &[i32] {
118 &self.zero_point
119 }
120
121 #[inline]
123 pub fn scheme(&self) -> QuantScheme {
124 self.scheme
125 }
126
127 #[inline]
129 pub fn qdtype(&self) -> QuantDtype {
130 self.dtype
131 }
132}
133
134fn compute_scale_zp(min_val: f32, max_val: f32, dtype: QuantDtype) -> (f32, i32) {
149 let qmin = dtype.qmin();
150 let qmax = dtype.qmax();
151
152 let min_val = min_val.min(0.0);
154 let max_val = max_val.max(0.0);
155
156 let range = (max_val - min_val).max(f32::EPSILON);
159 let scale = range / (qmax - qmin) as f32;
160
161 let zp = (qmin as f32 - min_val / scale).round() as i32;
166
167 (scale, zp)
168}
169
/// Quantizes a single real value into its stored i8 representation.
///
/// The value is affine-mapped (`x / scale + zp`), rounded to nearest and
/// clamped to `[qmin, qmax]`. For unsigned dtypes the clamped value is
/// bit-cast into i8 storage (e.g. 255 is stored as -1); `stored_to_i32`
/// performs the inverse.
#[inline]
fn quantize_val(x: f32, scale: f32, zp: i32, qmin: i32, qmax: i32, is_unsigned: bool) -> i8 {
    let quantized = (x / scale + zp as f32).round() as i32;
    let bounded = quantized.clamp(qmin, qmax);
    match is_unsigned {
        true => (bounded as u8) as i8,
        false => bounded as i8,
    }
}
185
/// Widens a stored i8 back to i32, undoing the unsigned bit-cast applied by
/// `quantize_val` when `is_unsigned` is true (so -1 decodes as 255).
#[inline]
fn stored_to_i32(val: i8, is_unsigned: bool) -> i32 {
    if is_unsigned {
        i32::from(val as u8)
    } else {
        i32::from(val)
    }
}
196
/// Maps a flat (row-major) element index to its channel along `axis`.
///
/// The stride of `axis` is the product of all trailing dimensions; dividing
/// by it and taking the remainder modulo the axis length yields the
/// coordinate of the element on that axis.
#[inline]
fn channel_index(flat_index: usize, shape: &[usize], axis: usize) -> usize {
    let axis_stride: usize = shape.iter().skip(axis + 1).product();
    flat_index / axis_stride % shape[axis]
}
207
208pub fn quantize<T: Float>(
223 tensor: &Tensor<T>,
224 scheme: QuantScheme,
225 dtype: QuantDtype,
226) -> FerrotorchResult<QuantizedTensor> {
227 let data = tensor.data()?;
228 let shape = tensor.shape().to_vec();
229 let numel = tensor.numel();
230 let qmin = dtype.qmin();
231 let qmax = dtype.qmax();
232
233 let is_unsigned = dtype == QuantDtype::Uint8;
234
235 match scheme {
236 QuantScheme::PerTensor => {
237 let mut min_val = f32::INFINITY;
239 let mut max_val = f32::NEG_INFINITY;
240 for &v in data {
241 let f = v.to_f32().unwrap();
242 if f < min_val {
243 min_val = f;
244 }
245 if f > max_val {
246 max_val = f;
247 }
248 }
249
250 let (scale, zp) = compute_scale_zp(min_val, max_val, dtype);
251
252 let qdata: Vec<i8> = data
253 .iter()
254 .map(|&v| quantize_val(v.to_f32().unwrap(), scale, zp, qmin, qmax, is_unsigned))
255 .collect();
256
257 Ok(QuantizedTensor {
258 data: qdata,
259 scale: vec![scale],
260 zero_point: vec![zp],
261 shape,
262 scheme,
263 dtype,
264 })
265 }
266
267 QuantScheme::PerChannel(axis) => {
268 if axis >= shape.len() {
269 return Err(FerrotorchError::InvalidArgument {
270 message: format!(
271 "PerChannel axis {axis} out of range for {}-d tensor",
272 shape.len()
273 ),
274 });
275 }
276
277 let num_channels = shape[axis];
278 let mut mins = vec![f32::INFINITY; num_channels];
279 let mut maxs = vec![f32::NEG_INFINITY; num_channels];
280
281 for (i, &v) in data.iter().enumerate() {
282 let ch = channel_index(i, &shape, axis);
283 let f = v.to_f32().unwrap();
284 if f < mins[ch] {
285 mins[ch] = f;
286 }
287 if f > maxs[ch] {
288 maxs[ch] = f;
289 }
290 }
291
292 let params: Vec<(f32, i32)> = mins
293 .iter()
294 .zip(maxs.iter())
295 .map(|(&mn, &mx)| compute_scale_zp(mn, mx, dtype))
296 .collect();
297
298 let scales: Vec<f32> = params.iter().map(|&(s, _)| s).collect();
299 let zps: Vec<i32> = params.iter().map(|&(_, z)| z).collect();
300
301 let mut qdata = Vec::with_capacity(numel);
302 for (i, &v) in data.iter().enumerate() {
303 let ch = channel_index(i, &shape, axis);
304 qdata.push(quantize_val(
305 v.to_f32().unwrap(),
306 scales[ch],
307 zps[ch],
308 qmin,
309 qmax,
310 is_unsigned,
311 ));
312 }
313
314 Ok(QuantizedTensor {
315 data: qdata,
316 scale: scales,
317 zero_point: zps,
318 shape,
319 scheme,
320 dtype,
321 })
322 }
323 }
324}
325
326pub fn dequantize<T: Float>(qtensor: &QuantizedTensor) -> FerrotorchResult<Tensor<T>> {
334 let numel = qtensor.numel();
335 let mut result = Vec::with_capacity(numel);
336 let is_unsigned = qtensor.dtype == QuantDtype::Uint8;
337
338 match qtensor.scheme {
339 QuantScheme::PerTensor => {
340 let scale = qtensor.scale[0];
341 let zp = qtensor.zero_point[0];
342 for &q in &qtensor.data {
343 let val = (stored_to_i32(q, is_unsigned) - zp) as f32 * scale;
344 result.push(T::from(val).unwrap());
345 }
346 }
347 QuantScheme::PerChannel(axis) => {
348 for (i, &q) in qtensor.data.iter().enumerate() {
349 let ch = channel_index(i, &qtensor.shape, axis);
350 let val = (stored_to_i32(q, is_unsigned) - qtensor.zero_point[ch]) as f32
351 * qtensor.scale[ch];
352 result.push(T::from(val).unwrap());
353 }
354 }
355 }
356
357 Tensor::from_storage(TensorStorage::cpu(result), qtensor.shape.clone(), false)
358}
359
360pub fn quantized_matmul(
373 a: &QuantizedTensor,
374 b: &QuantizedTensor,
375) -> FerrotorchResult<QuantizedTensor> {
376 if a.shape.len() != 2 || b.shape.len() != 2 {
378 return Err(FerrotorchError::InvalidArgument {
379 message: format!(
380 "quantized_matmul requires 2-D tensors, got shapes {:?} and {:?}",
381 a.shape, b.shape
382 ),
383 });
384 }
385
386 let m = a.shape[0];
387 let k = a.shape[1];
388 let k2 = b.shape[0];
389 let n = b.shape[1];
390
391 if k != k2 {
392 return Err(FerrotorchError::ShapeMismatch {
393 message: format!(
394 "quantized_matmul inner dimensions mismatch: [{m}, {k}] x [{k2}, {n}]"
395 ),
396 });
397 }
398
399 if a.scale.len() != 1 || b.scale.len() != 1 {
401 return Err(FerrotorchError::InvalidArgument {
402 message: "quantized_matmul currently requires PerTensor-quantized inputs".into(),
403 });
404 }
405
406 let a_scale = a.scale[0];
407 let a_zp = a.zero_point[0];
408 let b_scale = b.scale[0];
409 let b_zp = b.zero_point[0];
410
411 let a_unsigned = a.dtype == QuantDtype::Uint8;
412 let b_unsigned = b.dtype == QuantDtype::Uint8;
413
414 let mut acc = vec![0i32; m * n];
416 for i in 0..m {
417 for j in 0..n {
418 let mut sum = 0i32;
419 for p in 0..k {
420 let qa = stored_to_i32(a.data[i * k + p], a_unsigned) - a_zp;
421 let qb = stored_to_i32(b.data[p * n + j], b_unsigned) - b_zp;
422 sum += qa * qb;
423 }
424 acc[i * n + j] = sum;
425 }
426 }
427
428 let combined_scale = a_scale * b_scale;
431
432 let mut out_min = f32::INFINITY;
434 let mut out_max = f32::NEG_INFINITY;
435 for &a_val in &acc {
436 let real = a_val as f32 * combined_scale;
437 if real < out_min {
438 out_min = real;
439 }
440 if real > out_max {
441 out_max = real;
442 }
443 }
444
445 let out_dtype = QuantDtype::Int8;
446 let (out_scale, out_zp) = compute_scale_zp(out_min, out_max, out_dtype);
447 let qmin = out_dtype.qmin();
448 let qmax = out_dtype.qmax();
449
450 let qdata: Vec<i8> = acc
451 .iter()
452 .map(|&a_val| {
453 let real = a_val as f32 * combined_scale;
454 quantize_val(real, out_scale, out_zp, qmin, qmax, false)
455 })
456 .collect();
457
458 Ok(QuantizedTensor {
459 data: qdata,
460 scale: vec![out_scale],
461 zero_point: vec![out_zp],
462 shape: vec![m, n],
463 scheme: QuantScheme::PerTensor,
464 dtype: out_dtype,
465 })
466}
467
468pub fn quantize_named_tensors<T: Float>(
479 named_tensors: impl IntoIterator<Item = (String, Tensor<T>)>,
480 scheme: QuantScheme,
481 dtype: QuantDtype,
482) -> FerrotorchResult<HashMap<String, QuantizedTensor>> {
483 let mut result = HashMap::new();
484 for (name, tensor) in named_tensors {
485 let qtensor = quantize(&tensor, scheme, dtype)?;
486 result.insert(name, qtensor);
487 }
488 Ok(result)
489}
490
/// Quantization parameters produced by observers: a single (scale,
/// zero-point) entry for per-tensor statistics, or one entry per channel
/// for per-channel statistics.
#[derive(Debug, Clone)]
pub struct QParams {
    /// Multiplicative step size(s) mapping quantized units to real values.
    pub scale: Vec<f32>,
    /// Quantized value(s) that represent real 0.0.
    pub zero_point: Vec<i32>,
}
503
504impl QParams {
505 pub fn symmetric(max_abs: f32, dtype: QuantDtype) -> Self {
512 let max_abs = max_abs.max(f32::EPSILON);
513 match dtype {
514 QuantDtype::Int8 => QParams {
515 scale: vec![max_abs / 127.0],
516 zero_point: vec![0],
517 },
518 QuantDtype::Int4 => QParams {
519 scale: vec![max_abs / 7.0],
520 zero_point: vec![0],
521 },
522 QuantDtype::Uint8 => QParams {
523 scale: vec![max_abs / 128.0],
524 zero_point: vec![128],
525 },
526 }
527 }
528
529 pub fn asymmetric(min_val: f32, max_val: f32, dtype: QuantDtype) -> Self {
531 let (scale, zp) = compute_scale_zp(min_val, max_val, dtype);
532 QParams {
533 scale: vec![scale],
534 zero_point: vec![zp],
535 }
536 }
537}
538
/// Collects statistics over observed tensor data in order to derive
/// quantization parameters (calibration / QAT).
pub trait Observer {
    /// Folds a batch of values into the running statistics.
    fn observe(&mut self, data: &[f32]);
    /// Derives quantization parameters from the statistics gathered so far.
    fn calculate_qparams(&self, dtype: QuantDtype) -> QParams;
    /// Clears all gathered statistics.
    fn reset(&mut self);
}
552
/// Observer tracking the global min/max over all observed finite values.
#[derive(Debug, Clone)]
pub struct MinMaxObserver {
    // Smallest finite value seen so far (+inf when nothing observed yet).
    min_val: f32,
    // Largest finite value seen so far (-inf when nothing observed yet).
    max_val: f32,
}
565
566impl MinMaxObserver {
567 pub fn new() -> Self {
568 Self {
569 min_val: f32::INFINITY,
570 max_val: f32::NEG_INFINITY,
571 }
572 }
573}
574
impl Default for MinMaxObserver {
    /// Equivalent to [`MinMaxObserver::new`].
    fn default() -> Self {
        Self::new()
    }
}
580
581impl Observer for MinMaxObserver {
582 fn observe(&mut self, data: &[f32]) {
583 for &x in data {
584 if !x.is_finite() {
585 continue;
586 }
587 if x < self.min_val {
588 self.min_val = x;
589 }
590 if x > self.max_val {
591 self.max_val = x;
592 }
593 }
594 }
595
596 fn calculate_qparams(&self, dtype: QuantDtype) -> QParams {
597 QParams::asymmetric(self.min_val, self.max_val, dtype)
598 }
599
600 fn reset(&mut self) {
601 self.min_val = f32::INFINITY;
602 self.max_val = f32::NEG_INFINITY;
603 }
604}
605
/// Observer tracking an independent min/max range per channel along a
/// fixed axis.
#[derive(Debug, Clone)]
pub struct PerChannelMinMaxObserver {
    // Expected number of channels along `axis`.
    num_channels: usize,
    // Tensor axis the channels run along (used by `observe_with_shape`).
    axis: usize,
    // Per-channel minimum of observed finite values (+inf when empty).
    min_vals: Vec<f32>,
    // Per-channel maximum of observed finite values (-inf when empty).
    max_vals: Vec<f32>,
}
622
623impl PerChannelMinMaxObserver {
624 pub fn new(num_channels: usize, axis: usize) -> Self {
629 Self {
630 num_channels,
631 axis,
632 min_vals: vec![f32::INFINITY; num_channels],
633 max_vals: vec![f32::NEG_INFINITY; num_channels],
634 }
635 }
636
637 pub fn observe_with_shape(&mut self, data: &[f32], shape: &[usize]) -> FerrotorchResult<()> {
641 if self.axis >= shape.len() {
642 return Err(FerrotorchError::InvalidArgument {
643 message: format!(
644 "PerChannelMinMaxObserver axis {} out of range for {}-d tensor",
645 self.axis,
646 shape.len()
647 ),
648 });
649 }
650 let actual_channels = shape[self.axis];
651 if actual_channels != self.num_channels {
652 eprintln!(
653 "WARNING: PerChannelMinMaxObserver expected {} channels on axis {}, got {}",
654 self.num_channels, self.axis, actual_channels
655 );
656 return Err(FerrotorchError::InvalidArgument {
657 message: format!(
658 "PerChannelMinMaxObserver expected {} channels on axis {}, got {}",
659 self.num_channels, self.axis, actual_channels
660 ),
661 });
662 }
663
664 for (i, &x) in data.iter().enumerate() {
665 if !x.is_finite() {
666 continue;
667 }
668 let ch = channel_index(i, shape, self.axis);
669 if x < self.min_vals[ch] {
670 self.min_vals[ch] = x;
671 }
672 if x > self.max_vals[ch] {
673 self.max_vals[ch] = x;
674 }
675 }
676 Ok(())
677 }
678}
679
680impl Observer for PerChannelMinMaxObserver {
681 fn observe(&mut self, data: &[f32]) {
682 if data.len() % self.num_channels != 0 {
684 eprintln!(
685 "WARNING: PerChannelMinMaxObserver data length {} not divisible by {} channels",
686 data.len(),
687 self.num_channels
688 );
689 return;
690 }
691 let per_channel = data.len() / self.num_channels;
692 for (i, &x) in data.iter().enumerate() {
693 if !x.is_finite() {
694 continue;
695 }
696 let ch = i / per_channel;
697 if ch >= self.num_channels {
698 continue;
699 }
700 if x < self.min_vals[ch] {
701 self.min_vals[ch] = x;
702 }
703 if x > self.max_vals[ch] {
704 self.max_vals[ch] = x;
705 }
706 }
707 }
708
709 fn calculate_qparams(&self, dtype: QuantDtype) -> QParams {
710 let params: Vec<(f32, i32)> = self
711 .min_vals
712 .iter()
713 .zip(self.max_vals.iter())
714 .map(|(&mn, &mx)| compute_scale_zp(mn, mx, dtype))
715 .collect();
716 QParams {
717 scale: params.iter().map(|&(s, _)| s).collect(),
718 zero_point: params.iter().map(|&(_, z)| z).collect(),
719 }
720 }
721
722 fn reset(&mut self) {
723 self.min_vals.fill(f32::INFINITY);
724 self.max_vals.fill(f32::NEG_INFINITY);
725 }
726}
727
/// Observer that bins observed values into a fixed-width histogram over a
/// dynamically growing `[min_val, max_val]` range.
#[derive(Debug, Clone)]
pub struct HistogramObserver {
    // Number of histogram bins (fixed at construction).
    num_bins: usize,
    // Bin counts over the current [min_val, max_val] range.
    bins: Vec<u64>,
    // Lower edge of the histogram range.
    min_val: f32,
    // Upper edge of the histogram range.
    max_val: f32,
    // False until the first batch containing a finite value is observed.
    initialized: bool,
}
745
impl HistogramObserver {
    /// Creates a histogram observer with `num_bins` empty bins and an
    /// empty (inverted-infinite) range.
    pub fn new(num_bins: usize) -> Self {
        Self {
            num_bins,
            bins: vec![0u64; num_bins],
            min_val: f32::INFINITY,
            max_val: f32::NEG_INFINITY,
            initialized: false,
        }
    }

    /// Re-maps the existing bin counts onto the range `[new_min, new_max]`.
    ///
    /// Each old bin's count is moved wholesale into the new bin containing
    /// the old bin's center (counts are not split proportionally), so
    /// repeated range expansion coarsens the histogram slightly.
    fn redistribute(&mut self, new_min: f32, new_max: f32) {
        // Nothing recorded yet: just adopt the new range.
        if !self.initialized || self.bins.iter().all(|&c| c == 0) {
            self.min_val = new_min;
            self.max_val = new_max;
            return;
        }

        let old_min = self.min_val;
        let old_max = self.max_val;
        let old_range = old_max - old_min;
        let new_range = new_max - new_min;

        // Degenerate ranges cannot be mapped; fall back to adopting the new
        // range and keeping the counts where they are.
        if old_range <= 0.0 || new_range <= 0.0 {
            self.min_val = new_min;
            self.max_val = new_max;
            return;
        }

        let n = self.num_bins;
        let old_bins = self.bins.clone();
        self.bins.fill(0);

        let old_bin_width = old_range / n as f32;
        let new_bin_width = new_range / n as f32;

        for (old_idx, &old_count) in old_bins.iter().enumerate().take(n) {
            if old_count == 0 {
                continue;
            }
            // Map the old bin's center into the new binning; min(n - 1)
            // clamps float rounding at the upper edge into the last bin.
            let old_center = old_min + (old_idx as f32 + 0.5) * old_bin_width;
            let new_frac = (old_center - new_min) / new_bin_width;
            let new_idx = (new_frac as usize).min(n - 1);
            self.bins[new_idx] += old_count;
        }

        self.min_val = new_min;
        self.max_val = new_max;
    }
}
799
impl Observer for HistogramObserver {
    /// Bins a batch of values, first growing the histogram range if the
    /// batch extends beyond it (existing counts are re-mapped via
    /// `redistribute`). Non-finite values are ignored throughout.
    fn observe(&mut self, data: &[f32]) {
        // Batch-local min/max over finite values only.
        let mut batch_min = f32::INFINITY;
        let mut batch_max = f32::NEG_INFINITY;
        for &x in data {
            if !x.is_finite() {
                continue;
            }
            if x < batch_min {
                batch_min = x;
            }
            if x > batch_max {
                batch_max = x;
            }
        }

        // No finite value in the batch: nothing to record.
        if batch_min > batch_max {
            return;
        }

        // Union of the existing range and this batch's range.
        let new_min = if self.initialized {
            self.min_val.min(batch_min)
        } else {
            batch_min
        };
        let new_max = if self.initialized {
            self.max_val.max(batch_max)
        } else {
            batch_max
        };

        if self.initialized && (new_min < self.min_val || new_max > self.max_val) {
            // Range grew: re-map existing counts before binning this batch.
            self.redistribute(new_min, new_max);
        } else if !self.initialized {
            self.min_val = new_min;
            self.max_val = new_max;
            self.initialized = true;
        }

        // Bin the batch over the (possibly expanded) range; the min(n - 1)
        // clamp places x == max_val into the last bin.
        let range = (self.max_val - self.min_val).max(f32::EPSILON);
        let n = self.num_bins;
        for &x in data {
            if !x.is_finite() {
                continue;
            }
            let frac = (x - self.min_val) / range;
            let idx = ((frac * n as f32) as usize).min(n - 1);
            self.bins[idx] += 1;
        }
    }

    /// Derives asymmetric qparams from the histogram's full range.
    fn calculate_qparams(&self, dtype: QuantDtype) -> QParams {
        QParams::asymmetric(self.min_val, self.max_val, dtype)
    }

    /// Clears all counts and the range; behaves as freshly constructed.
    fn reset(&mut self) {
        self.bins.fill(0);
        self.min_val = f32::INFINITY;
        self.max_val = f32::NEG_INFINITY;
        self.initialized = false;
    }
}
867
/// Simulates quantization in the forward pass (quantize then dequantize)
/// while keeping values in f32 — the building block for quantization-aware
/// training.
#[derive(Debug, Clone)]
pub struct FakeQuantize {
    /// Target quantized dtype.
    pub dtype: QuantDtype,
    /// Most recently computed qparams; reused by `forward` only while the
    /// observer is disabled.
    pub qparams: Option<QParams>,
    /// When true, `forward` feeds its input into the internal observer.
    pub observer_enabled: bool,
    /// When false, `forward` is a pass-through with an all-ones mask.
    pub fake_quant_enabled: bool,
    // Running min/max statistics backing the qparams.
    observer: MinMaxObserver,
}
892
impl FakeQuantize {
    /// Creates a fake-quantize module with observing and quantization both
    /// enabled and no cached qparams.
    pub fn new(dtype: QuantDtype) -> Self {
        Self {
            dtype,
            qparams: None,
            observer_enabled: true,
            fake_quant_enabled: true,
            observer: MinMaxObserver::new(),
        }
    }

    /// Applies quantize-dequantize to `data`.
    ///
    /// Returns `(output, grad_mask)` where `grad_mask[i]` is 1.0 when the
    /// input fell inside the representable range (the straight-through
    /// estimator passes the gradient) and 0.0 when it was clipped.
    ///
    /// When `fake_quant_enabled` is false the input is returned unchanged
    /// with an all-ones mask. When `observer_enabled` is true the observer
    /// is updated and qparams are recomputed from it; otherwise previously
    /// cached qparams are reused (falling back to a fresh computation if
    /// none are cached yet).
    pub fn forward(&mut self, data: &[f32]) -> (Vec<f32>, Vec<f32>) {
        if !self.fake_quant_enabled {
            let ones = vec![1.0f32; data.len()];
            return (data.to_vec(), ones);
        }

        if self.observer_enabled {
            self.observer.observe(data);
        }

        // Cached qparams are only trusted while the observer is frozen;
        // while observing, recompute so new data is reflected immediately.
        let qparams = if let Some(cached) = self.qparams.as_ref().filter(|_| !self.observer_enabled)
        {
            cached.clone()
        } else {
            let qp = self.observer.calculate_qparams(self.dtype);
            self.qparams = Some(qp.clone());
            qp
        };

        let scale = qparams.scale[0];
        let zp = qparams.zero_point[0];
        let qmin = self.dtype.qmin();
        let qmax = self.dtype.qmax();

        // Real-valued bounds of the representable quantized range, used to
        // decide which inputs were clipped.
        let range_min = (qmin as f32 - zp as f32) * scale;
        let range_max = (qmax as f32 - zp as f32) * scale;

        let mut output = Vec::with_capacity(data.len());
        let mut grad_mask = Vec::with_capacity(data.len());

        for &x in data {
            // Quantize (round + clamp) then immediately dequantize.
            let q = (x / scale + zp as f32)
                .round()
                .clamp(qmin as f32, qmax as f32);
            let dq = (q - zp as f32) * scale;
            output.push(dq);

            if x >= range_min && x <= range_max {
                grad_mask.push(1.0);
            } else {
                grad_mask.push(0.0);
            }
        }

        (output, grad_mask)
    }
}
962
/// Fake-quantize state for one layer: one module for its weights and one
/// for its activations.
#[derive(Debug, Clone)]
pub struct QatLayer {
    /// Fake-quantizer applied to the layer's weights.
    pub weight_fq: FakeQuantize,
    /// Fake-quantizer applied to the layer's activations.
    pub activation_fq: FakeQuantize,
}
975
/// Container mapping layer names to their QAT fake-quantize state.
#[derive(Debug)]
pub struct QatModel {
    /// Per-layer fake-quantize modules, keyed by layer name.
    pub layers: HashMap<String, QatLayer>,
    /// Quantized dtype used for every layer registered on this model.
    pub dtype: QuantDtype,
}
989
990impl QatModel {
991 pub fn new(dtype: QuantDtype) -> Self {
993 Self {
994 layers: HashMap::new(),
995 dtype,
996 }
997 }
998
999 pub fn register_layer(&mut self, name: &str) {
1001 self.layers.insert(
1002 name.to_string(),
1003 QatLayer {
1004 weight_fq: FakeQuantize::new(self.dtype),
1005 activation_fq: FakeQuantize::new(self.dtype),
1006 },
1007 );
1008 }
1009
1010 pub fn fake_quantize_weights(
1015 &mut self,
1016 layer_name: &str,
1017 weights: &[f32],
1018 ) -> FerrotorchResult<(Vec<f32>, Vec<f32>)> {
1019 let layer =
1020 self.layers
1021 .get_mut(layer_name)
1022 .ok_or_else(|| FerrotorchError::InvalidArgument {
1023 message: format!("layer '{layer_name}' not registered for QAT"),
1024 })?;
1025
1026 let originals = weights.to_vec();
1028
1029 let (fq_weights, _mask) = layer.weight_fq.forward(weights);
1031
1032 Ok((fq_weights, originals))
1033 }
1034
1035 pub fn fake_quantize_activations(
1039 &mut self,
1040 layer_name: &str,
1041 activations: &[f32],
1042 ) -> FerrotorchResult<(Vec<f32>, Vec<f32>)> {
1043 let layer =
1044 self.layers
1045 .get_mut(layer_name)
1046 .ok_or_else(|| FerrotorchError::InvalidArgument {
1047 message: format!("layer '{layer_name}' not registered for QAT"),
1048 })?;
1049
1050 let (fq_activations, grad_mask) = layer.activation_fq.forward(activations);
1051 Ok((fq_activations, grad_mask))
1052 }
1053}
1054
1055pub fn prepare_qat(param_names: &[&str], dtype: QuantDtype) -> QatModel {
1060 let mut model = QatModel::new(dtype);
1061
1062 for &name in param_names {
1063 let layer_name = if let Some(prefix) = name.strip_suffix(".weight") {
1065 prefix
1066 } else if let Some(prefix) = name.strip_suffix(".bias") {
1067 if !model.layers.contains_key(prefix) {
1070 model.register_layer(prefix);
1071 }
1072 continue;
1073 } else {
1074 name
1075 };
1076
1077 model.register_layer(layer_name);
1078 }
1079
1080 model
1081}
1082
pub mod cuda_rng {
    //! Process-wide RNG seed bookkeeping with fork/join (save/restore)
    //! semantics. Poisoned locks are recovered rather than propagated.
    use std::sync::Mutex;

    /// Global RNG state shared across the process.
    static RNG_STATE: Mutex<u64> = Mutex::new(0xdeadbeef_cafebabe);

    /// Stack of saved states for nested `fork_rng` / `join_rng` pairs.
    static RNG_STACK: Mutex<Vec<u64>> = Mutex::new(Vec::new());

    /// Returns the current RNG state.
    pub fn get_state() -> u64 {
        *RNG_STATE.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Overwrites the current RNG state.
    pub fn set_state(state: u64) {
        *RNG_STATE.lock().unwrap_or_else(|e| e.into_inner()) = state;
    }

    /// Saves the current state on an internal stack and switches to
    /// `new_seed`; undo with `join_rng`.
    pub fn fork_rng(new_seed: u64) {
        let saved = get_state();
        RNG_STACK
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .push(saved);
        set_state(new_seed);
    }

    /// Restores the most recently forked state; no-op when nothing was
    /// forked.
    pub fn join_rng() {
        let restored = RNG_STACK.lock().unwrap_or_else(|e| e.into_inner()).pop();
        if let Some(state) = restored {
            set_state(state);
        }
    }

    /// Advances the global state by the golden-ratio increment and returns
    /// a splitmix64-style mixed seed derived from it.
    pub fn next_seed() -> u64 {
        let mut guard = RNG_STATE.lock().unwrap_or_else(|e| e.into_inner());
        *guard = guard.wrapping_add(0x9e3779b97f4a7c15);
        let mut z = *guard;
        z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
        z ^ (z >> 31)
    }
}
1156
1157#[cfg(test)]
1162mod tests {
1163 use super::*;
1164
1165 fn make_tensor(data: &[f32], shape: &[usize]) -> Tensor<f32> {
1167 crate::from_slice(data, shape).unwrap()
1168 }
1169
1170 #[test]
1173 fn test_per_tensor_int8_roundtrip() {
1174 let data: Vec<f32> = (-10..=10).map(|x| x as f32 * 0.5).collect();
1175 let t = make_tensor(&data, &[data.len()]);
1176 let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1177 let rt: Tensor<f32> = dequantize(&qt).unwrap();
1178
1179 assert_eq!(rt.shape(), t.shape());
1180 let orig = t.data().unwrap();
1181 let recovered = rt.data().unwrap();
1182 for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
1183 let err = (o - r).abs();
1184 assert!(
1186 err < 0.05,
1187 "element {i}: original={o}, recovered={r}, error={err}"
1188 );
1189 }
1190 }
1191
1192 #[test]
1193 fn test_per_tensor_uint8_roundtrip() {
1194 let data: Vec<f32> = (0..=20).map(|x| x as f32 * 0.1).collect();
1195 let t = make_tensor(&data, &[data.len()]);
1196 let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Uint8).unwrap();
1197 let rt: Tensor<f32> = dequantize(&qt).unwrap();
1198
1199 let orig = t.data().unwrap();
1200 let recovered = rt.data().unwrap();
1201 for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
1202 let err = (o - r).abs();
1203 assert!(
1205 err < 0.02,
1206 "element {i}: original={o}, recovered={r}, error={err}"
1207 );
1208 }
1209 }
1210
1211 #[test]
1212 fn test_per_tensor_int4_roundtrip() {
1213 let data: Vec<f32> = (-8..=7).map(|x| x as f32).collect();
1215 let t = make_tensor(&data, &[data.len()]);
1216 let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int4).unwrap();
1217 let rt: Tensor<f32> = dequantize(&qt).unwrap();
1218
1219 let orig = t.data().unwrap();
1220 let recovered = rt.data().unwrap();
1221 for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
1222 let err = (o - r).abs();
1223 assert!(
1225 err < 1.01,
1226 "element {i}: original={o}, recovered={r}, error={err}"
1227 );
1228 }
1229 }
1230
1231 #[test]
1234 fn test_per_channel_int8_roundtrip() {
1235 #[rustfmt::skip]
1237 let data: Vec<f32> = vec![
1238 0.0, 1.0, 2.0, 3.0,
1240 -10.0, -5.0, 5.0, 10.0,
1242 100.0, 130.0, 170.0, 200.0,
1244 ];
1245 let t = make_tensor(&data, &[3, 4]);
1246 let qt = quantize(&t, QuantScheme::PerChannel(0), QuantDtype::Int8).unwrap();
1247 let rt: Tensor<f32> = dequantize(&qt).unwrap();
1248
1249 assert_eq!(qt.scale.len(), 3);
1250 assert_eq!(qt.zero_point.len(), 3);
1251
1252 let orig = t.data().unwrap();
1253 let recovered = rt.data().unwrap();
1254 for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
1255 let err = (o - r).abs();
1256 assert!(
1259 err < 0.5,
1260 "element {i}: original={o}, recovered={r}, error={err}"
1261 );
1262 }
1263 }
1264
1265 #[test]
1266 fn test_per_channel_axis_out_of_bounds() {
1267 let t = make_tensor(&[1.0, 2.0, 3.0], &[3]);
1268 let result = quantize(&t, QuantScheme::PerChannel(5), QuantDtype::Int8);
1269 assert!(result.is_err());
1270 }
1271
1272 #[test]
1275 fn test_quantized_matmul_identity() {
1276 let a_data = vec![1.0f32, 2.0, 3.0, 4.0];
1278 let a = make_tensor(&a_data, &[2, 2]);
1279 let eye = make_tensor(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);
1280
1281 let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1282 let qi = quantize(&eye, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1283 let qc = quantized_matmul(&qa, &qi).unwrap();
1284 let c: Tensor<f32> = dequantize(&qc).unwrap();
1285
1286 assert_eq!(c.shape(), &[2, 2]);
1287 let c_data = c.data().unwrap();
1288 for (i, (&expected, &got)) in a_data.iter().zip(c_data.iter()).enumerate() {
1289 let err = (expected - got).abs();
1290 assert!(
1291 err < 0.5,
1292 "element {i}: expected={expected}, got={got}, error={err}"
1293 );
1294 }
1295 }
1296
1297 #[test]
1298 fn test_quantized_matmul_correctness() {
1299 let a = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
1308 let b = make_tensor(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[3, 2]);
1309
1310 let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1311 let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1312 let qc = quantized_matmul(&qa, &qb).unwrap();
1313 let c: Tensor<f32> = dequantize(&qc).unwrap();
1314
1315 let expected = [58.0f32, 64.0, 139.0, 154.0];
1316 let c_data = c.data().unwrap();
1317 assert_eq!(c.shape(), &[2, 2]);
1318 for (i, (&e, &g)) in expected.iter().zip(c_data.iter()).enumerate() {
1319 let err = (e - g).abs();
1320 assert!(err < 3.0, "element {i}: expected={e}, got={g}, error={err}");
1323 }
1324 }
1325
1326 #[test]
1327 fn test_quantized_matmul_shape_mismatch() {
1328 let a = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
1329 let b = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
1330
1331 let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1332 let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1333 let result = quantized_matmul(&qa, &qb);
1334 assert!(result.is_err());
1335 }
1336
1337 #[test]
1338 fn test_quantized_matmul_non_2d() {
1339 let a = make_tensor(&[1.0, 2.0, 3.0], &[3]);
1340 let b = make_tensor(&[4.0, 5.0, 6.0], &[3]);
1341
1342 let qa = quantize(&a, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1343 let qb = quantize(&b, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1344 let result = quantized_matmul(&qa, &qb);
1345 assert!(result.is_err());
1346 }
1347
1348 #[test]
1351 fn test_quantize_named_tensors() {
1352 let w1 = make_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
1353 let w2 = make_tensor(&[-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], &[3, 2]);
1354
1355 let named = vec![
1356 ("layer.weight".to_string(), w1),
1357 ("layer2.weight".to_string(), w2),
1358 ];
1359
1360 let qmap = quantize_named_tensors(named, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1361
1362 assert_eq!(qmap.len(), 2);
1363 assert!(qmap.contains_key("layer.weight"));
1364 assert!(qmap.contains_key("layer2.weight"));
1365 assert_eq!(qmap["layer.weight"].shape(), &[2, 2]);
1366 assert_eq!(qmap["layer2.weight"].shape(), &[3, 2]);
1367 }
1368
1369 #[test]
1372 fn test_quantize_constant_tensor() {
1373 let t = make_tensor(&[5.0, 5.0, 5.0, 5.0], &[4]);
1375 let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1376 let rt: Tensor<f32> = dequantize(&qt).unwrap();
1377
1378 let recovered = rt.data().unwrap();
1379 for &r in recovered {
1380 assert!(
1381 (r - 5.0).abs() < 0.1,
1382 "constant tensor dequantized to {r}, expected 5.0"
1383 );
1384 }
1385 }
1386
1387 #[test]
1388 fn test_quantize_single_element() {
1389 let t = make_tensor(&[42.0], &[1]);
1390 let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1391 let rt: Tensor<f32> = dequantize(&qt).unwrap();
1392 assert!((rt.data().unwrap()[0] - 42.0).abs() < 0.5);
1393 }
1394
1395 #[test]
1396 fn test_per_channel_int4() {
1397 let data = vec![0.0, 1.0, 2.0, -4.0, 0.0, 4.0];
1399 let t = make_tensor(&data, &[2, 3]);
1400 let qt = quantize(&t, QuantScheme::PerChannel(0), QuantDtype::Int4).unwrap();
1401
1402 assert_eq!(qt.scale.len(), 2);
1403 assert_eq!(qt.zero_point.len(), 2);
1404
1405 let rt: Tensor<f32> = dequantize(&qt).unwrap();
1406 let orig = t.data().unwrap();
1407 let recovered = rt.data().unwrap();
1408 for (i, (&o, &r)) in orig.iter().zip(recovered.iter()).enumerate() {
1409 let err = (o - r).abs();
1410 assert!(
1412 err < 1.0,
1413 "element {i}: original={o}, recovered={r}, error={err}"
1414 );
1415 }
1416 }
1417
1418 #[test]
1419 fn test_dequantize_f64() {
1420 let data = vec![1.0f32, 2.0, 3.0, 4.0];
1421 let t = crate::from_slice(&data, &[4]).unwrap();
1422 let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1423 let rt: Tensor<f64> = dequantize(&qt).unwrap();
1424
1425 assert_eq!(rt.shape(), &[4]);
1426 let recovered = rt.data().unwrap();
1427 for (i, &r) in recovered.iter().enumerate() {
1428 let expected = data[i] as f64;
1429 let err = (expected - r).abs();
1430 assert!(
1431 err < 0.05,
1432 "element {i}: expected={expected}, recovered={r}, error={err}"
1433 );
1434 }
1435 }
1436
1437 #[test]
1438 fn test_quantized_tensor_accessors() {
1439 let t = make_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
1440 let qt = quantize(&t, QuantScheme::PerTensor, QuantDtype::Int8).unwrap();
1441
1442 assert_eq!(qt.numel(), 6);
1443 assert_eq!(qt.shape(), &[2, 3]);
1444 assert_eq!(qt.data().len(), 6);
1445 assert_eq!(qt.scale().len(), 1);
1446 assert_eq!(qt.zero_point().len(), 1);
1447 assert_eq!(qt.scheme(), QuantScheme::PerTensor);
1448 assert_eq!(qt.qdtype(), QuantDtype::Int8);
1449 }
1450
1451 #[test]
1454 fn test_qparams_symmetric_int8() {
1455 let qp = QParams::symmetric(5.0, QuantDtype::Int8);
1456 assert_eq!(qp.zero_point, vec![0]);
1457 assert!((qp.scale[0] - 5.0 / 127.0).abs() < 1e-7);
1458 }
1459
1460 #[test]
1461 fn test_qparams_symmetric_uint8() {
1462 let qp = QParams::symmetric(5.0, QuantDtype::Uint8);
1463 assert_eq!(qp.zero_point, vec![128]);
1464 assert!((qp.scale[0] - 5.0 / 128.0).abs() < 1e-7);
1465 }
1466
1467 #[test]
1468 fn test_qparams_symmetric_int4() {
1469 let qp = QParams::symmetric(7.0, QuantDtype::Int4);
1470 assert_eq!(qp.zero_point, vec![0]);
1471 assert!((qp.scale[0] - 1.0).abs() < 1e-7);
1472 }
1473
1474 #[test]
1477 fn test_minmax_observer() {
1478 let mut obs = MinMaxObserver::new();
1479 obs.observe(&[1.0, 2.0, 3.0]);
1480 obs.observe(&[-1.0, 5.0]);
1481 let qp = obs.calculate_qparams(QuantDtype::Int8);
1482 assert_eq!(qp.scale.len(), 1);
1484 assert_eq!(qp.zero_point.len(), 1);
1485 }
1486
1487 #[test]
1488 fn test_minmax_observer_filters_nan_inf() {
1489 let mut obs = MinMaxObserver::new();
1490 obs.observe(&[1.0, f32::NAN, 2.0, f32::INFINITY, -1.0, f32::NEG_INFINITY]);
1491 let qp = obs.calculate_qparams(QuantDtype::Int8);
1492 let expected_range = 2.0 - (-1.0); let expected_scale = expected_range / 255.0;
1495 assert!((qp.scale[0] - expected_scale).abs() < 1e-5);
1496 }
1497
1498 #[test]
1501 fn test_per_channel_observer_with_shape() {
1502 let mut obs = PerChannelMinMaxObserver::new(2, 0);
1503 obs.observe_with_shape(&[0.0, 1.0, 2.0, 10.0, 20.0, 30.0], &[2, 3])
1505 .unwrap();
1506 let qp = obs.calculate_qparams(QuantDtype::Int8);
1507 assert_eq!(qp.scale.len(), 2);
1508 assert_eq!(qp.zero_point.len(), 2);
1509 }
1510
1511 #[test]
1512 fn test_per_channel_observer_shape_mismatch() {
1513 let mut obs = PerChannelMinMaxObserver::new(3, 0);
1514 let result = obs.observe_with_shape(&[1.0; 6], &[2, 3]);
1516 assert!(result.is_err());
1517 }
1518
1519 #[test]
1520 fn test_per_channel_observer_axis() {
1521 let mut obs = PerChannelMinMaxObserver::new(3, 1);
1522 obs.observe_with_shape(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
1524 .unwrap();
1525 let qp = obs.calculate_qparams(QuantDtype::Int8);
1526 assert_eq!(qp.scale.len(), 3);
1527 }
1528
1529 #[test]
1530 fn test_per_channel_observer_filters_nan_inf() {
1531 let mut obs = PerChannelMinMaxObserver::new(2, 0);
1532 obs.observe_with_shape(&[f32::NAN, 1.0, 2.0, 10.0, f32::INFINITY, 30.0], &[2, 3])
1533 .unwrap();
1534 let qp = obs.calculate_qparams(QuantDtype::Int8);
1536 assert_eq!(qp.scale.len(), 2);
1537 }
1538
1539 #[test]
1542 fn test_histogram_observer_basic() {
1543 let mut obs = HistogramObserver::new(100);
1544 obs.observe(&[0.0, 0.5, 1.0]);
1545 let qp = obs.calculate_qparams(QuantDtype::Int8);
1546 assert_eq!(qp.scale.len(), 1);
1547 }
1548
1549 #[test]
1550 fn test_histogram_observer_range_expansion() {
1551 let mut obs = HistogramObserver::new(100);
1552 obs.observe(&[0.0, 1.0]);
1553 let bins_after_first = obs.bins.clone();
1555 let total_first: u64 = bins_after_first.iter().sum();
1556 assert_eq!(total_first, 2);
1557
1558 obs.observe(&[-1.0, 2.0]);
1559 let total_second: u64 = obs.bins.iter().sum();
1561 assert_eq!(total_second, 4);
1563 }
1564
1565 #[test]
1566 fn test_histogram_observer_filters_nan_inf() {
1567 let mut obs = HistogramObserver::new(50);
1568 obs.observe(&[f32::NAN, 1.0, f32::INFINITY, 2.0]);
1569 let total: u64 = obs.bins.iter().sum();
1570 assert_eq!(total, 2);
1572 }
1573
1574 #[test]
1577 fn test_fake_quantize_roundtrip() {
1578 let mut fq = FakeQuantize::new(QuantDtype::Int8);
1579 let data = vec![0.0, 0.5, 1.0, 1.5, 2.0];
1580 let (output, mask) = fq.forward(&data);
1581 assert_eq!(output.len(), 5);
1582 assert_eq!(mask.len(), 5);
1583
1584 for (i, (&o, &d)) in output.iter().zip(data.iter()).enumerate() {
1586 assert!((o - d).abs() < 0.1, "element {i}: output={o}, data={d}");
1587 }
1588 }
1589
1590 #[test]
1591 fn test_fake_quantize_ste_clipping() {
1592 let mut fq = FakeQuantize::new(QuantDtype::Int8);
1593 let (_, _) = fq.forward(&[0.0, 1.0, 2.0]);
1595
1596 fq.observer_enabled = false;
1598
1599 let (_, mask) = fq.forward(&[0.5, 1.0, 100.0, -100.0]);
1601 assert_eq!(mask[0], 1.0);
1603 assert_eq!(mask[1], 1.0);
1604 assert_eq!(mask[2], 0.0);
1606 assert_eq!(mask[3], 0.0);
1607 }
1608
1609 #[test]
1610 fn test_fake_quantize_observer_disabled_uses_cached() {
1611 let mut fq = FakeQuantize::new(QuantDtype::Int8);
1612 let (_, _) = fq.forward(&[0.0, 10.0]);
1614 let cached_scale = fq.qparams.as_ref().unwrap().scale[0];
1615
1616 fq.observer_enabled = false;
1618
1619 let (_, _) = fq.forward(&[0.0, 1000.0]);
1621 let scale_after = fq.qparams.as_ref().unwrap().scale[0];
1622 assert!(
1623 (scale_after - cached_scale).abs() < 1e-10,
1624 "scale should not change when observer is disabled"
1625 );
1626 }
1627
1628 #[test]
1629 fn test_fake_quantize_disabled_is_identity() {
1630 let mut fq = FakeQuantize::new(QuantDtype::Int8);
1631 fq.fake_quant_enabled = false;
1632 let data = vec![1.234, 5.678, -9.012];
1633 let (output, mask) = fq.forward(&data);
1634 assert_eq!(output, data);
1635 assert!(mask.iter().all(|&m| m == 1.0));
1636 }
1637
1638 #[test]
1641 fn test_qat_model_register_and_fq_weights() {
1642 let mut model = QatModel::new(QuantDtype::Int8);
1643 model.register_layer("fc1");
1644
1645 let weights = vec![0.1, 0.2, 0.3, 0.4];
1646 let (fq_weights, originals) = model.fake_quantize_weights("fc1", &weights).unwrap();
1647
1648 assert_eq!(originals, weights);
1650 for (i, (&fq, &orig)) in fq_weights.iter().zip(weights.iter()).enumerate() {
1652 assert!((fq - orig).abs() < 0.1, "weight {i}: fq={fq}, orig={orig}");
1653 }
1654 }
1655
1656 #[test]
1657 fn test_qat_model_activation_fq_per_layer() {
1658 let mut model = QatModel::new(QuantDtype::Int8);
1659 model.register_layer("layer1");
1660 model.register_layer("layer2");
1661
1662 let (act1, _) = model
1664 .fake_quantize_activations("layer1", &[1.0, 2.0])
1665 .unwrap();
1666 let (act2, _) = model
1667 .fake_quantize_activations("layer2", &[10.0, 20.0])
1668 .unwrap();
1669 assert_eq!(act1.len(), 2);
1670 assert_eq!(act2.len(), 2);
1671 }
1672
1673 #[test]
1674 fn test_qat_model_unregistered_layer_errors() {
1675 let mut model = QatModel::new(QuantDtype::Int8);
1676 let result = model.fake_quantize_weights("nonexistent", &[1.0]);
1677 assert!(result.is_err());
1678 }
1679
1680 #[test]
1683 fn test_prepare_qat_skips_bias() {
1684 let names = &["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"];
1685 let model = prepare_qat(names, QuantDtype::Int8);
1686
1687 assert!(model.layers.contains_key("fc1"));
1688 assert!(model.layers.contains_key("fc2"));
1689 assert_eq!(model.layers.len(), 2);
1690 }
1691
1692 #[test]
1693 fn test_prepare_qat_bias_only_still_registers() {
1694 let names = &["fc1.bias"];
1695 let model = prepare_qat(names, QuantDtype::Int8);
1696 assert!(model.layers.contains_key("fc1"));
1698 }
1699
1700 #[test]
1703 fn test_cuda_rng_fork_join() {
1704 let initial = cuda_rng::get_state();
1705 cuda_rng::fork_rng(0x12345678);
1706 assert_eq!(cuda_rng::get_state(), 0x12345678);
1707 cuda_rng::join_rng();
1708 assert_eq!(cuda_rng::get_state(), initial);
1709 }
1710
1711 #[test]
1712 fn test_cuda_rng_next_seed() {
1713 let s1 = cuda_rng::next_seed();
1714 let s2 = cuda_rng::next_seed();
1715 assert_ne!(s1, s2, "consecutive seeds should differ");
1716 }
1717}