1use crate::errors::{QuantizeError, Result};
8
/// Configuration controlling how tensors are quantized.
#[derive(Debug, Clone)]
pub struct QuantConfig {
    /// Target bit width for quantized values (`Quantizer` supports 4 or 8).
    pub bits: u8,
    /// Quantize each channel along axis 0 with its own scale/zero-point.
    pub per_channel: bool,
    /// Use symmetric quantization (zero-point fixed at 0).
    pub symmetric: bool,
    /// Optional method for deriving ranges from calibration statistics.
    pub calibration_method: Option<crate::calibration::methods::CalibrationMethod>,
    /// Layer names that must never be quantized (exact-match).
    pub excluded_layers: Vec<String>,
    /// Per-layer bit-width overrides (layer name -> bits).
    pub layer_bits: std::collections::HashMap<String, u8>,
    /// Layers with fewer elements than this are skipped (0 disables the check).
    pub min_elements: usize,
}
30
31impl Default for QuantConfig {
32 fn default() -> Self {
33 Self {
34 bits: 8,
35 per_channel: false,
36 symmetric: false,
37 calibration_method: None,
38 excluded_layers: Vec::new(),
39 layer_bits: std::collections::HashMap::new(),
40 min_elements: 0,
41 }
42 }
43}
44
45impl QuantConfig {
46 pub fn int8() -> Self {
48 Self::default()
49 }
50
51 pub fn with_per_channel(mut self, enabled: bool) -> Self {
53 self.per_channel = enabled;
54 self
55 }
56
57 pub fn with_symmetric(mut self, enabled: bool) -> Self {
59 self.symmetric = enabled;
60 self
61 }
62
63 pub fn with_calibration(
65 mut self,
66 method: crate::calibration::methods::CalibrationMethod,
67 ) -> Self {
68 self.calibration_method = Some(method);
69 self
70 }
71
72 pub fn should_quantize(&self, name: &str, num_elements: usize) -> bool {
78 if self.excluded_layers.iter().any(|e| e == name) {
79 return false;
80 }
81 if self.min_elements > 0 && num_elements < self.min_elements {
82 return false;
83 }
84 true
85 }
86
87 pub fn bits_for_layer(&self, name: &str) -> u8 {
92 self.layer_bits.get(name).copied().unwrap_or(self.bits)
93 }
94}
95
/// Compile-time description of a signed integer quantization range.
///
/// Implementors supply the representable bounds used when computing
/// scale/zero-point and when clamping quantized values.
pub trait QuantRange: Clone + std::fmt::Debug + Send + Sync + 'static {
    /// Minimum representable quantized value.
    const QMIN: f32;
    /// Maximum representable quantized value.
    const QMAX: f32;
    /// Bit width of the quantized representation.
    const BITS: u8;
}
109
110#[derive(Debug, Clone)]
112pub struct Int8Range;
113impl QuantRange for Int8Range {
114 const QMIN: f32 = -128.0;
115 const QMAX: f32 = 127.0;
116 const BITS: u8 = 8;
117}
118
/// Signed 4-bit quantization range: `[-8, 7]`.
#[derive(Debug, Clone)]
pub struct Int4Range;
impl QuantRange for Int4Range {
    const QMIN: f32 = -8.0;
    const QMAX: f32 = 7.0;
    const BITS: u8 = 4;
}
127
/// Scale/zero-point pair for a quantization range `R`.
///
/// Maps real values to integers via `q = round(v / scale) + zero_point`.
#[derive(Debug, Clone)]
pub struct QuantParamsGeneric<R: QuantRange> {
    // Step size between adjacent quantized values.
    scale: f32,
    // Quantized integer representing the real value 0.0.
    zero_point: i8,
    // Carries the range type parameter without storing any data.
    _marker: std::marker::PhantomData<R>,
}
143
/// INT8 quantization parameters.
pub type QuantParams = QuantParamsGeneric<Int8Range>;
/// INT4 quantization parameters.
pub type QuantParamsInt4 = QuantParamsGeneric<Int4Range>;
148
149impl<R: QuantRange> QuantParamsGeneric<R> {
150 pub fn scale(&self) -> f32 {
152 self.scale
153 }
154 pub fn zero_point(&self) -> i8 {
156 self.zero_point
157 }
158
159 pub fn from_range(min: f32, max: f32) -> Self {
167 let min = min.min(0.0);
168 let max = max.max(0.0);
169
170 let (min, max) = if (max - min).abs() < 1e-8 {
173 let abs = min.abs().max(max.abs()).max(1e-8);
174 (-abs, abs)
175 } else {
176 (min, max)
177 };
178
179 let scale = (max - min) / (R::QMAX - R::QMIN);
180 let scale = scale.max(1e-8);
181
182 let initial_zero_point = R::QMIN - min / scale;
183 let zero_point = if initial_zero_point.is_finite() {
186 initial_zero_point.round().clamp(R::QMIN, R::QMAX) as i8
187 } else {
188 0i8
189 };
190
191 QuantParamsGeneric {
192 scale,
193 zero_point,
194 _marker: std::marker::PhantomData,
195 }
196 }
197
198 pub fn from_range_symmetric(min: f32, max: f32) -> Self {
208 let abs_max = min.abs().max(max.abs()).max(1e-8);
209 let scale = (abs_max / R::QMAX).max(1e-8);
213 QuantParamsGeneric {
214 scale,
215 zero_point: 0,
216 _marker: std::marker::PhantomData,
217 }
218 }
219
220 pub fn quantize(&self, value: f32) -> i8 {
222 if !value.is_finite() {
223 return self.zero_point;
224 }
225 let quantized = (value / self.scale).round() + (self.zero_point as f32);
226 quantized.clamp(R::QMIN, R::QMAX) as i8
227 }
228
229 pub fn dequantize(&self, value: i8) -> f32 {
231 ((value as i32) - (self.zero_point as i32)) as f32 * self.scale
232 }
233}
234
/// A quantized tensor parameterized by its integer range `R`.
///
/// Values are stored one-per-element as `i8` in `data`; INT4 tensors may
/// additionally carry a packed two-values-per-byte form in `packed_data`.
#[derive(Debug, Clone)]
pub struct QuantizedTensorGeneric<R: QuantRange> {
    // Quantized values, one i8 per element.
    pub(crate) data: Vec<i8>,
    // Optional packed representation (INT4: two nibbles per byte).
    pub(crate) packed_data: Option<Vec<u8>>,
    // Original tensor shape.
    pub(crate) shape: Vec<usize>,
    // Per-tensor parameters (first channel's params in per-channel mode).
    pub(crate) params: QuantParamsGeneric<R>,
    // Whether each channel along axis 0 has its own parameters.
    pub(crate) per_channel: bool,
    // Per-channel parameters, present when `per_channel` is true.
    pub(crate) channel_params: Option<Vec<QuantParamsGeneric<R>>>,
}
253
/// INT8 quantized tensor.
pub type QuantizedTensor = QuantizedTensorGeneric<Int8Range>;

/// INT4 quantized tensor (may be packed two values per byte).
pub type QuantizedTensorInt4 = QuantizedTensorGeneric<Int4Range>;
262
263impl<R: QuantRange> QuantizedTensorGeneric<R> {
268 pub fn shape(&self) -> &[usize] {
270 &self.shape
271 }
272 pub fn params(&self) -> &QuantParamsGeneric<R> {
274 &self.params
275 }
276 pub fn is_per_channel(&self) -> bool {
278 self.per_channel
279 }
280
281 pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Result<Self> {
287 Self::from_f32_with_mode(data, shape, false)
288 }
289
290 pub fn from_f32_symmetric(data: &[f32], shape: Vec<usize>) -> Result<Self> {
295 Self::from_f32_with_mode(data, shape, true)
296 }
297
298 fn from_f32_with_mode(data: &[f32], shape: Vec<usize>, symmetric: bool) -> Result<Self> {
299 if data.is_empty() {
300 return Err(QuantizeError::InvalidTensor {
301 reason: "Cannot quantize empty tensor".into(),
302 });
303 }
304
305 let expected_len: usize = shape.iter().product();
306 if expected_len != data.len() {
307 return Err(QuantizeError::InvalidTensor {
308 reason: format!(
309 "Shape {:?} expects {} elements but got {}",
310 shape,
311 expected_len,
312 data.len()
313 ),
314 });
315 }
316
317 let min = data
318 .iter()
319 .copied()
320 .filter(|v| v.is_finite())
321 .fold(f32::INFINITY, f32::min);
322 let max = data
323 .iter()
324 .copied()
325 .filter(|v| v.is_finite())
326 .fold(f32::NEG_INFINITY, f32::max);
327
328 if !min.is_finite() || !max.is_finite() {
329 return Err(QuantizeError::InvalidTensor {
330 reason: "Tensor contains only non-finite values (NaN/Inf)".into(),
331 });
332 }
333
334 let params = if symmetric {
335 QuantParamsGeneric::<R>::from_range_symmetric(min, max)
336 } else {
337 QuantParamsGeneric::<R>::from_range(min, max)
338 };
339
340 let quantized_data: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
341
342 Ok(QuantizedTensorGeneric {
343 data: quantized_data,
344 packed_data: None,
345 shape,
346 params,
347 per_channel: false,
348 channel_params: None,
349 })
350 }
351
352 pub fn from_f32_with_range(
358 data: &[f32],
359 shape: Vec<usize>,
360 min: f32,
361 max: f32,
362 ) -> Result<Self> {
363 Self::from_f32_with_range_and_mode(data, shape, min, max, false)
364 }
365
366 pub fn from_f32_with_range_symmetric(
369 data: &[f32],
370 shape: Vec<usize>,
371 min: f32,
372 max: f32,
373 ) -> Result<Self> {
374 Self::from_f32_with_range_and_mode(data, shape, min, max, true)
375 }
376
377 fn from_f32_with_range_and_mode(
378 data: &[f32],
379 shape: Vec<usize>,
380 min: f32,
381 max: f32,
382 symmetric: bool,
383 ) -> Result<Self> {
384 if data.is_empty() {
385 return Err(QuantizeError::InvalidTensor {
386 reason: "Cannot quantize empty tensor".into(),
387 });
388 }
389
390 let expected_len: usize = shape.iter().product();
391 if expected_len != data.len() {
392 return Err(QuantizeError::InvalidTensor {
393 reason: format!(
394 "Shape {:?} expects {} elements but got {}",
395 shape,
396 expected_len,
397 data.len()
398 ),
399 });
400 }
401
402 let params = if symmetric {
403 QuantParamsGeneric::<R>::from_range_symmetric(min, max)
404 } else {
405 QuantParamsGeneric::<R>::from_range(min, max)
406 };
407
408 let quantized_data: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
409
410 Ok(QuantizedTensorGeneric {
411 data: quantized_data,
412 packed_data: None,
413 shape,
414 params,
415 per_channel: false,
416 channel_params: None,
417 })
418 }
419
420 pub fn from_f32_per_channel(data: &[f32], shape: Vec<usize>) -> Result<Self> {
427 Self::from_f32_per_channel_with_mode(data, shape, false)
428 }
429
430 pub fn from_f32_per_channel_symmetric(data: &[f32], shape: Vec<usize>) -> Result<Self> {
434 Self::from_f32_per_channel_with_mode(data, shape, true)
435 }
436
437 fn from_f32_per_channel_with_mode(
438 data: &[f32],
439 shape: Vec<usize>,
440 symmetric: bool,
441 ) -> Result<Self> {
442 if data.is_empty() {
443 return Err(QuantizeError::InvalidTensor {
444 reason: "Cannot quantize empty tensor".into(),
445 });
446 }
447
448 if shape.is_empty() {
449 return Err(QuantizeError::InvalidTensor {
450 reason: "Cannot do per-channel quantization on scalar".into(),
451 });
452 }
453
454 let expected_len: usize = shape.iter().product();
455 if expected_len != data.len() {
456 return Err(QuantizeError::InvalidTensor {
457 reason: format!(
458 "Shape {:?} expects {} elements but got {}",
459 shape,
460 expected_len,
461 data.len()
462 ),
463 });
464 }
465
466 let num_channels = shape[0];
467 if num_channels == 0 {
468 return Err(QuantizeError::InvalidTensor {
469 reason: "Number of channels is 0".into(),
470 });
471 }
472 if !data.len().is_multiple_of(num_channels) {
473 return Err(QuantizeError::InvalidTensor {
474 reason: format!(
475 "Data length {} not evenly divisible by {} channels",
476 data.len(),
477 num_channels
478 ),
479 });
480 }
481 let elements_per_channel = data.len() / num_channels;
482
483 let mut channel_params = Vec::with_capacity(num_channels);
484 let mut quantized_data = Vec::with_capacity(data.len());
485
486 for (channel_idx, channel_slice) in data.chunks_exact(elements_per_channel).enumerate() {
490 let mut min = f32::INFINITY;
491 let mut max = f32::NEG_INFINITY;
492 for &v in channel_slice {
493 if v.is_finite() {
494 if v < min {
495 min = v;
496 }
497 if v > max {
498 max = v;
499 }
500 }
501 }
502
503 if !min.is_finite() || !max.is_finite() {
504 return Err(QuantizeError::InvalidTensor {
505 reason: format!(
506 "Channel {} contains only non-finite values (NaN/Inf)",
507 channel_idx
508 ),
509 });
510 }
511
512 let params = if symmetric {
513 QuantParamsGeneric::<R>::from_range_symmetric(min, max)
514 } else {
515 QuantParamsGeneric::<R>::from_range(min, max)
516 };
517
518 quantized_data.extend(channel_slice.iter().map(|&v| params.quantize(v)));
519 channel_params.push(params);
520 }
521
522 let params = channel_params[0].clone();
524
525 Ok(QuantizedTensorGeneric {
526 data: quantized_data,
527 packed_data: None,
528 shape,
529 params,
530 per_channel: true,
531 channel_params: Some(channel_params),
532 })
533 }
534
535 pub fn to_f32(&self) -> Vec<f32> {
537 let data_owned;
539 let data: &[i8] = if let Some(ref packed) = self.packed_data {
540 data_owned = unpack_int4(packed, self.data.len());
541 &data_owned
542 } else {
543 &self.data
544 };
545
546 if self.per_channel {
547 if let Some(ref channel_params) = self.channel_params {
548 if channel_params.is_empty() {
549 return data.iter().map(|&v| self.params.dequantize(v)).collect();
550 }
551 let elements_per_channel = data.len() / channel_params.len();
555 let mut out = Vec::with_capacity(data.len());
556 if elements_per_channel == 0 {
557 return data.iter().map(|&v| self.params.dequantize(v)).collect();
560 }
561 for (chunk, params) in data.chunks(elements_per_channel).zip(channel_params.iter())
562 {
563 out.extend(chunk.iter().map(|&v| params.dequantize(v)));
564 }
565 out
566 } else {
567 data.iter().map(|&v| self.params.dequantize(v)).collect()
568 }
569 } else {
570 data.iter().map(|&v| self.params.dequantize(v)).collect()
571 }
572 }
573
574 pub fn size_bytes(&self) -> usize {
576 if let Some(ref packed) = self.packed_data {
577 packed.len()
578 } else {
579 self.data.len() * std::mem::size_of::<i8>()
580 }
581 }
582
583 pub fn quantization_error(&self, original: &[f32]) -> f32 {
585 if original.is_empty() {
586 return 0.0;
587 }
588
589 let dequantized = self.to_f32();
590
591 let sum: f32 = original
592 .iter()
593 .zip(dequantized.iter())
594 .map(|(a, b)| (a - b).powi(2))
595 .sum();
596
597 sum / original.len() as f32
598 }
599}
600
601impl QuantizedTensorGeneric<Int4Range> {
606 pub fn pack(&mut self) {
608 self.packed_data = Some(pack_int4(&self.data));
609 }
610
611 pub fn ensure_unpacked(&self) -> Vec<i8> {
613 if let Some(ref packed) = self.packed_data {
614 unpack_int4(packed, self.data.len())
615 } else {
616 self.data.clone()
617 }
618 }
619
620 pub fn is_packed(&self) -> bool {
622 self.packed_data.is_some()
623 }
624
625 pub fn packed_size_bytes(&self) -> usize {
627 if let Some(ref packed) = self.packed_data {
628 packed.len()
629 } else {
630 self.data.len().div_ceil(2)
631 }
632 }
633
634 pub fn unpacked_size_bytes(&self) -> usize {
636 self.data.len() * std::mem::size_of::<i8>()
637 }
638}
639
/// Packs two INT4 values (each in `[-8, 7]`) into one byte, `val1` in the
/// high nibble and `val2` in the low nibble.
fn pack_int4_pair(val1: i8, val2: i8) -> u8 {
    debug_assert!((-8..=7).contains(&val1), "val1 out of INT4 range: {}", val1);
    debug_assert!((-8..=7).contains(&val2), "val2 out of INT4 range: {}", val2);

    // Masking to the low nibble discards the sign-extension bits.
    let high = ((val1 & 0x0F) as u8) << 4;
    let low = (val2 & 0x0F) as u8;
    high | low
}
655
/// Unpacks one byte into two INT4 values: high nibble first, then low.
/// Nibbles >= 8 are sign-extended into negative `i8` values.
fn unpack_int4_pair(byte: u8) -> (i8, i8) {
    let decode = |nibble: u8| -> i8 {
        if nibble < 8 {
            nibble as i8
        } else {
            // Sign-extend by setting the high four bits.
            (nibble as i8) | !0x0F
        }
    };

    (decode((byte >> 4) & 0x0F), decode(byte & 0x0F))
}
675
676pub fn pack_int4(values: &[i8]) -> Vec<u8> {
678 let mut packed = Vec::with_capacity(values.len().div_ceil(2));
679
680 for chunk in values.chunks(2) {
681 let val1 = chunk[0];
682 let val2 = if chunk.len() > 1 { chunk[1] } else { 0 };
683
684 packed.push(pack_int4_pair(val1, val2));
685 }
686
687 packed
688}
689
690pub fn unpack_int4(packed: &[u8], num_values: usize) -> Vec<i8> {
692 let mut values = Vec::with_capacity(num_values);
693
694 for &byte in packed {
695 let (val1, val2) = unpack_int4_pair(byte);
696 values.push(val1);
697 if values.len() < num_values {
698 values.push(val2);
699 }
700 }
701
702 values.truncate(num_values);
704 values
705}
706
/// Runtime-dispatched quantized tensor: either INT8 or INT4.
#[derive(Debug, Clone)]
pub enum QuantizedTensorType {
    /// 8-bit quantized tensor.
    Int8(QuantizedTensor),
    /// 4-bit quantized tensor (possibly packed two values per byte).
    Int4(QuantizedTensorInt4),
}
717
718impl QuantizedTensorType {
719 pub fn to_f32(&self) -> Vec<f32> {
721 match self {
722 QuantizedTensorType::Int8(t) => t.to_f32(),
723 QuantizedTensorType::Int4(t) => t.to_f32(),
724 }
725 }
726
727 pub fn size_bytes(&self) -> usize {
729 match self {
730 QuantizedTensorType::Int8(t) => t.size_bytes(),
731 QuantizedTensorType::Int4(t) => t.size_bytes(),
732 }
733 }
734
735 #[must_use]
736 pub fn quantization_error(&self, original: &[f32]) -> f32 {
737 match self {
738 QuantizedTensorType::Int8(t) => t.quantization_error(original),
739 QuantizedTensorType::Int4(t) => t.quantization_error(original),
740 }
741 }
742
743 #[must_use]
744 pub fn data(&self) -> Vec<i8> {
745 match self {
746 QuantizedTensorType::Int8(t) => t.data.clone(),
747 QuantizedTensorType::Int4(t) => t.ensure_unpacked(),
748 }
749 }
750
751 pub fn get_scale_zero_point(&self) -> (f32, i8) {
753 match self {
754 QuantizedTensorType::Int8(t) => (t.params.scale, t.params.zero_point),
755 QuantizedTensorType::Int4(t) => (t.params.scale, t.params.zero_point),
756 }
757 }
758
759 pub fn get_all_scales_zero_points(&self) -> (Vec<f32>, Vec<i8>) {
764 match self {
765 QuantizedTensorType::Int8(t) => {
766 if let Some(ref cp) = t.channel_params {
767 (
768 cp.iter().map(|p| p.scale).collect(),
769 cp.iter().map(|p| p.zero_point).collect(),
770 )
771 } else {
772 (vec![t.params.scale], vec![t.params.zero_point])
773 }
774 }
775 QuantizedTensorType::Int4(t) => {
776 if let Some(ref cp) = t.channel_params {
777 (
778 cp.iter().map(|p| p.scale).collect(),
779 cp.iter().map(|p| p.zero_point).collect(),
780 )
781 } else {
782 (vec![t.params.scale], vec![t.params.zero_point])
783 }
784 }
785 }
786 }
787
788 pub fn is_per_channel(&self) -> bool {
790 match self {
791 QuantizedTensorType::Int8(t) => t.per_channel,
792 QuantizedTensorType::Int4(t) => t.per_channel,
793 }
794 }
795
796 #[must_use]
797 pub fn bits(&self) -> u8 {
798 match self {
799 QuantizedTensorType::Int8(_) => 8,
800 QuantizedTensorType::Int4(_) => 4,
801 }
802 }
803
804 pub fn is_int8(&self) -> bool {
806 matches!(self, QuantizedTensorType::Int8(_))
807 }
808
809 pub fn is_int4(&self) -> bool {
811 matches!(self, QuantizedTensorType::Int4(_))
812 }
813
814 pub fn data_ref(&self) -> Option<&[i8]> {
818 match self {
819 QuantizedTensorType::Int8(t) => Some(&t.data),
820 QuantizedTensorType::Int4(t) => {
821 if t.packed_data.is_some() {
822 None } else {
824 Some(&t.data)
825 }
826 }
827 }
828 }
829}
830
/// Drives tensor and model quantization according to a [`QuantConfig`],
/// optionally using pre-collected calibration statistics.
pub struct Quantizer {
    // Quantization settings.
    config: QuantConfig,
    // Activation statistics keyed by tensor name, when calibration was run.
    calibration_stats:
        Option<std::collections::HashMap<String, crate::calibration::stats::ActivationStats>>,
}
841
842impl std::fmt::Debug for Quantizer {
843 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
844 let stats_count = self.calibration_stats.as_ref().map(|m| m.len());
845 f.debug_struct("Quantizer")
846 .field("config", &self.config)
847 .field("calibration_stats_count", &stats_count)
848 .finish()
849 }
850}
851
852impl Quantizer {
853 pub fn new(config: QuantConfig) -> Self {
855 Self {
856 config,
857 calibration_stats: None,
858 }
859 }
860
861 pub fn with_calibration(
863 config: QuantConfig,
864 stats: std::collections::HashMap<String, crate::calibration::stats::ActivationStats>,
865 ) -> Self {
866 Self {
867 config,
868 calibration_stats: Some(stats),
869 }
870 }
871
872 pub fn quantize_tensor_with_name(
874 &self,
875 name: &str,
876 data: &[f32],
877 shape: Vec<usize>,
878 ) -> Result<QuantizedTensorType> {
879 let (min, max) = if let Some(ref stats_map) = self.calibration_stats {
880 if let Some(stats) = stats_map.get(name) {
881 if let Some(method) = self.config.calibration_method {
882 use crate::calibration::stats::calculate_optimal_range_from_stats;
885 calculate_optimal_range_from_stats(stats, method)
886 } else {
887 (stats.min(), stats.max())
888 }
889 } else {
890 finite_min_max(data, name)?
891 }
892 } else {
893 finite_min_max(data, name)?
894 };
895
896 self.quantize_with_range(data, shape, min, max)
897 }
898
899 pub fn quantize_tensor(&self, data: &[f32], shape: Vec<usize>) -> Result<QuantizedTensorType> {
905 self.build_tensor_with_optional_range(data, shape, None)
906 }
907
908 fn quantize_with_range(
915 &self,
916 data: &[f32],
917 shape: Vec<usize>,
918 min: f32,
919 max: f32,
920 ) -> Result<QuantizedTensorType> {
921 self.build_tensor_with_optional_range(data, shape, Some((min, max)))
922 }
923
924 fn build_tensor_with_optional_range(
926 &self,
927 data: &[f32],
928 shape: Vec<usize>,
929 range: Option<(f32, f32)>,
930 ) -> Result<QuantizedTensorType> {
931 let pc = self.config.per_channel && shape.len() >= 2;
932 let sym = self.config.symmetric;
933 match self.config.bits {
934 8 => {
935 let t = match (pc, range, sym) {
936 (true, _, true) => {
937 QuantizedTensor::from_f32_per_channel_symmetric(data, shape)?
938 }
939 (true, _, false) => QuantizedTensor::from_f32_per_channel(data, shape)?,
940 (false, Some((min, max)), true) => {
941 QuantizedTensor::from_f32_with_range_symmetric(data, shape, min, max)?
942 }
943 (false, Some((min, max)), false) => {
944 QuantizedTensor::from_f32_with_range(data, shape, min, max)?
945 }
946 (false, None, true) => QuantizedTensor::from_f32_symmetric(data, shape)?,
947 (false, None, false) => QuantizedTensor::from_f32(data, shape)?,
948 };
949 Ok(QuantizedTensorType::Int8(t))
950 }
951 4 => {
952 let mut t = match (pc, range, sym) {
953 (true, _, true) => {
954 QuantizedTensorInt4::from_f32_per_channel_symmetric(data, shape)?
955 }
956 (true, _, false) => QuantizedTensorInt4::from_f32_per_channel(data, shape)?,
957 (false, Some((min, max)), true) => {
958 QuantizedTensorInt4::from_f32_with_range_symmetric(data, shape, min, max)?
959 }
960 (false, Some((min, max)), false) => {
961 QuantizedTensorInt4::from_f32_with_range(data, shape, min, max)?
962 }
963 (false, None, true) => QuantizedTensorInt4::from_f32_symmetric(data, shape)?,
964 (false, None, false) => QuantizedTensorInt4::from_f32(data, shape)?,
965 };
966 t.pack();
967 Ok(QuantizedTensorType::Int4(t))
968 }
969 b => Err(QuantizeError::UnsupportedConfig {
970 reason: format!("bits must be 4 or 8, got {b}"),
971 }),
972 }
973 }
974
975 pub fn quantize_model(
985 &self,
986 model: &crate::onnx_utils::OnnxModel,
987 ) -> Result<Vec<QuantizedWeightOutput>> {
988 use rayon::prelude::*;
989
990 let weights = model.extract_weights();
991 let to_quantize: Vec<_> = weights
992 .iter()
993 .filter(|w| self.config.should_quantize(&w.name, w.num_elements()))
994 .collect();
995
996 to_quantize
997 .par_iter()
998 .map(|w| self.quantize_weight_to_output(w))
999 .collect()
1000 }
1001
1002 fn quantize_weight_to_output(
1003 &self,
1004 weight: &crate::onnx_utils::WeightTensor,
1005 ) -> Result<QuantizedWeightOutput> {
1006 let layer_bits = self.config.bits_for_layer(&weight.name);
1007
1008 let quantized = if layer_bits == self.config.bits {
1013 self.quantize_tensor_with_name(&weight.name, &weight.data, weight.shape.clone())?
1014 } else {
1015 let layer_config = QuantConfig {
1016 bits: layer_bits,
1017 per_channel: self.config.per_channel,
1018 symmetric: self.config.symmetric,
1019 ..Default::default()
1020 };
1021 Quantizer::new(layer_config).quantize_tensor(&weight.data, weight.shape.clone())?
1022 };
1023
1024 let mse = quantized.quantization_error(&weight.data);
1025 let (scales, zero_points) = quantized.get_all_scales_zero_points();
1026 let is_per_channel = quantized.is_per_channel();
1027 let bits_used = quantized.bits();
1028 let quantized_size_bytes = quantized.size_bytes();
1029
1030 Ok(QuantizedWeightOutput {
1031 qdq: crate::onnx_utils::graph_builder::QdqWeightInput {
1032 original_name: weight.name.clone(),
1033 quantized_values: quantized.data(),
1034 scales,
1035 zero_points,
1036 bits: bits_used,
1037 axis: if is_per_channel { Some(0) } else { None },
1038 },
1039 quantized_size_bytes,
1040 mse,
1041 })
1042 }
1043}
1044
/// Result of quantizing one weight tensor, ready for QDQ graph insertion.
#[derive(Debug, Clone)]
pub struct QuantizedWeightOutput {
    /// Quantize/dequantize node inputs for the ONNX graph builder.
    pub qdq: crate::onnx_utils::graph_builder::QdqWeightInput,
    /// Size of the quantized payload in bytes (packed size for INT4).
    pub quantized_size_bytes: usize,
    /// Mean squared error between the original and dequantized values.
    pub mse: f32,
}
1061
1062fn finite_min_max(data: &[f32], name: &str) -> Result<(f32, f32)> {
1068 let min = data
1069 .iter()
1070 .copied()
1071 .filter(|v| v.is_finite())
1072 .fold(f32::INFINITY, f32::min);
1073 let max = data
1074 .iter()
1075 .copied()
1076 .filter(|v| v.is_finite())
1077 .fold(f32::NEG_INFINITY, f32::max);
1078 if !min.is_finite() || !max.is_finite() {
1079 return Err(QuantizeError::InvalidTensor {
1080 reason: format!(
1081 "Tensor '{}' contains only non-finite values (NaN/Inf)",
1082 name
1083 ),
1084 });
1085 }
1086 Ok((min, max))
1087}
1088
1089#[cfg(test)]
1090mod tests {
1091 use super::*;
1092
1093 #[test]
1098 fn test_should_quantize_no_restrictions() {
1099 let config = QuantConfig::default();
1100 assert!(config.should_quantize("any.layer", 1));
1101 assert!(config.should_quantize("any.layer", 1_000_000));
1102 }
1103
1104 #[test]
1105 fn test_should_quantize_excluded_layer() {
1106 let config = QuantConfig {
1107 excluded_layers: vec!["head.weight".to_string()],
1108 ..Default::default()
1109 };
1110 assert!(!config.should_quantize("head.weight", 1024));
1111 assert!(config.should_quantize("body.weight", 1024));
1112 }
1113
1114 #[test]
1115 fn test_should_quantize_min_elements() {
1116 let config = QuantConfig {
1117 min_elements: 512,
1118 ..Default::default()
1119 };
1120 assert!(!config.should_quantize("small.bias", 4));
1121 assert!(!config.should_quantize("small.bias", 511));
1122 assert!(config.should_quantize("large.weight", 512));
1123 assert!(config.should_quantize("large.weight", 1024));
1124 }
1125
1126 #[test]
1127 fn test_should_quantize_excluded_takes_priority_over_min_elements() {
1128 let config = QuantConfig {
1129 excluded_layers: vec!["head.weight".to_string()],
1130 min_elements: 1,
1131 ..Default::default()
1132 };
1133 assert!(!config.should_quantize("head.weight", 1_000_000));
1135 }
1136
1137 #[test]
1138 fn test_bits_for_layer_default() {
1139 let config = QuantConfig {
1140 bits: 8,
1141 ..Default::default()
1142 };
1143 assert_eq!(config.bits_for_layer("any.weight"), 8);
1144 }
1145
1146 #[test]
1147 fn test_bits_for_layer_override() {
1148 let mut layer_bits = std::collections::HashMap::new();
1149 layer_bits.insert("head.weight".to_string(), 4u8);
1150 let config = QuantConfig {
1151 bits: 8,
1152 layer_bits,
1153 ..Default::default()
1154 };
1155 assert_eq!(config.bits_for_layer("head.weight"), 4);
1156 assert_eq!(config.bits_for_layer("body.weight"), 8);
1157 }
1158
1159 #[test]
1164 fn test_quant_params() {
1165 let params = QuantParams::from_range(-1.0, 1.0);
1166
1167 assert_eq!(params.quantize(0.0), params.zero_point);
1168
1169 let original = 0.5;
1170 let quantized = params.quantize(original);
1171 let dequantized = params.dequantize(quantized);
1172
1173 assert!((original - dequantized).abs() < 0.01);
1174 }
1175
1176 #[test]
1177 fn test_quantize_tensor() {
1178 let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
1179 let shape = vec![5];
1180
1181 let quantized = QuantizedTensor::from_f32(&data, shape).unwrap();
1182
1183 assert_eq!(quantized.data.len(), 5);
1184 assert_eq!(quantized.size_bytes(), 5);
1185 }
1186
1187 #[test]
1188 fn test_per_channel_quantization() {
1189 let mut data = vec![];
1190 for _ in 0..100 {
1191 data.push(0.5); }
1193 for _ in 0..100 {
1194 data.push(5.0); }
1196
1197 let shape = vec![2, 100];
1198
1199 let quantized = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1200
1201 assert!(quantized.per_channel);
1202 assert!(quantized.channel_params.is_some());
1203 assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
1204
1205 let dequantized = quantized.to_f32();
1206 let error: f32 = data
1207 .iter()
1208 .zip(dequantized.iter())
1209 .map(|(a, b)| (a - b).powi(2))
1210 .sum::<f32>()
1211 / data.len() as f32;
1212
1213 println!("Per-channel MSE: {}", error);
1214 assert!(error < 0.1);
1215 }
1216
1217 #[test]
1218 fn test_per_channel_vs_per_tensor() {
1219 let mut data = vec![];
1220
1221 for _ in 0..1000 {
1222 data.push(0.01);
1223 }
1224
1225 for _ in 0..1000 {
1226 data.push(10.0);
1227 }
1228
1229 let shape = vec![2, 1000];
1230
1231 let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1233 let per_tensor_error = per_tensor.quantization_error(&data);
1234
1235 let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1237 let per_channel_error = per_channel.quantization_error(&data);
1238
1239 println!("Per-tensor error: {:.8}", per_tensor_error);
1240 println!("Per-channel error: {:.8}", per_channel_error);
1241
1242 assert!(per_channel_error < per_tensor_error);
1244 assert!(per_channel_error < per_tensor_error * 0.5);
1245 }
1246
1247 #[test]
1248 fn test_per_channel_benefit() {
1249 let mut data = vec![];
1250
1251 for i in 0..1000 {
1252 data.push(-0.1 + (i as f32 / 1000.0) * 0.2);
1253 }
1254
1255 for i in 0..1000 {
1256 data.push(-10.0 + (i as f32 / 1000.0) * 20.0);
1257 }
1258
1259 let shape = vec![2, 1000];
1260
1261 let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1262 let per_tensor_error = per_tensor.quantization_error(&data);
1263
1264 let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1265 let per_channel_error = per_channel.quantization_error(&data);
1266
1267 println!("Per-tensor MSE: {:.8}", per_tensor_error);
1268 println!("Per-channel MSE: {:.8}", per_channel_error);
1269
1270 assert!(
1271 per_channel_error < per_tensor_error,
1272 "Per-channel ({:.8}) should be better than per-tensor ({:.8})",
1273 per_channel_error,
1274 per_tensor_error
1275 );
1276 }
1277
1278 #[test]
1279 fn test_int4_quant_params() {
1280 let params = QuantParamsInt4::from_range(-1.0, 1.0);
1281
1282 assert!(params.quantize(-10.0) >= -8);
1283 assert!(params.quantize(-10.0) <= 7);
1284 assert!(params.quantize(10.0) >= -8);
1285 assert!(params.quantize(10.0) <= 7);
1286
1287 let zero_quant = params.quantize(0.0);
1288 assert!(zero_quant >= -8 && zero_quant <= 7);
1289
1290 for &original in &[-1.0, -0.5, 0.0, 0.5, 1.0] {
1291 let quantized = params.quantize(original);
1292 let dequantized = params.dequantize(quantized);
1293
1294 println!(
1295 "Original: {:.2}, Quantized: {}, Dequantized: {:.2}, Error: {:.4}",
1296 original,
1297 quantized,
1298 dequantized,
1299 (original - dequantized).abs()
1300 );
1301
1302 assert!((original - dequantized).abs() < params.scale * 2.0);
1303 }
1304 }
1305
1306 #[test]
1307 fn test_int4_extreme_values() {
1308 let params = QuantParamsInt4::from_range(-100.0, 100.0);
1310
1311 let q_neg = params.quantize(-100.0);
1312 let q_pos = params.quantize(100.0);
1313
1314 assert_eq!(q_neg, -8);
1315 assert_eq!(q_pos, 7);
1316 }
1317
1318 #[test]
1319 fn test_int4_vs_int8_error() {
1320 let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1321
1322 let params_int8 = QuantParams::from_range(-1.0, 1.0);
1323 let error_int8: f32 = data
1324 .iter()
1325 .map(|&v| {
1326 let q = params_int8.quantize(v);
1327 let dq = params_int8.dequantize(q);
1328 (v - dq).powi(2)
1329 })
1330 .sum::<f32>()
1331 / data.len() as f32;
1332
1333 let params_int4 = QuantParamsInt4::from_range(-1.0, 1.0);
1334 let error_int4: f32 = data
1335 .iter()
1336 .map(|&v| {
1337 let q = params_int4.quantize(v);
1338 let dq = params_int4.dequantize(q);
1339 (v - dq).powi(2)
1340 })
1341 .sum::<f32>()
1342 / data.len() as f32;
1343
1344 println!("INT8 MSE: {:.8}", error_int8);
1345 println!("INT4 MSE: {:.8}", error_int4);
1346
1347 assert!(error_int4 > error_int8);
1348
1349 assert!(
1350 error_int4 < error_int8 * 500.0,
1351 "INT4 error ({:.8}) is too high compared to INT8 ({:.8})",
1352 error_int4,
1353 error_int8
1354 );
1355
1356 assert!(error_int4.is_finite());
1357 assert!(error_int4 < 0.01);
1358 }
1359
1360 #[test]
1361 fn test_int4_range() {
1362 let params = QuantParamsInt4::from_range(-1.0, 1.0);
1363
1364 assert!(params.quantize(-10.0) == -8);
1365 assert!(params.quantize(10.0) == 7);
1366
1367 for i in -8..=7 {
1369 let value = i as f32 * params.scale;
1370 let quantized = params.quantize(value);
1371 assert!(quantized >= -8 && quantized <= 7);
1372 }
1373 }
1374
1375 #[test]
1376 fn test_int4_optimal_precision() {
1377 let params = QuantParamsInt4::from_range(-1.0, 1.0);
1378
1379 let mut unique_values = std::collections::HashSet::new();
1380
1381 for i in 0..1000 {
1383 let value = -1.0 + (i as f32 / 1000.0) * 2.0;
1384 unique_values.insert(params.quantize(value));
1385 }
1386
1387 println!("Unique quantized values: {}", unique_values.len());
1388 assert!(unique_values.len() >= 14);
1389 }
1390
1391 #[test]
1392 fn test_int4_tensor_quantization() {
1393 let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
1394 let shape = vec![5];
1395
1396 let quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1397
1398 assert_eq!(quantized.data.len(), 5);
1399 assert_eq!(quantized.size_bytes(), 5);
1400 assert_eq!(quantized.packed_size_bytes(), 3);
1401
1402 for &val in &quantized.data {
1403 assert!(val >= -8 && val <= 7, "Value {} out of INT4 range", val);
1404 }
1405 }
1406
1407 #[test]
1408 fn test_int4_round_trip() {
1409 let original = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1410 let shape = vec![5];
1411
1412 let quantized = QuantizedTensorInt4::from_f32(&original, shape).unwrap();
1413 let dequantized = quantized.to_f32();
1414
1415 println!("Original: {:?}", original);
1416 println!("Quantized: {:?}", quantized.data);
1417 println!("Dequantized: {:?}", dequantized);
1418
1419 for (orig, deq) in original.iter().zip(dequantized.iter()) {
1420 let error = (orig - deq).abs();
1421 println!(" {:.2} -> {:.2}, error: {:.4}", orig, deq, error);
1422 assert!(error < 0.15, "Error too large: {}", error);
1423 }
1424 }
1425
1426 #[test]
1427 fn test_int4_per_channel() {
1428 let mut data = vec![];
1429
1430 for i in 0..100 {
1432 data.push(-0.1 + (i as f32 / 100.0) * 0.2);
1433 }
1434
1435 for i in 0..100 {
1437 data.push(-10.0 + (i as f32 / 100.0) * 20.0);
1438 }
1439
1440 let shape = vec![2, 100];
1441
1442 let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1443
1444 assert!(quantized.per_channel);
1445 assert!(quantized.channel_params.is_some());
1446 assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
1447
1448 let error = quantized.quantization_error(&data);
1449 println!("INT4 per-channel MSE: {:.8}", error);
1450
1451 assert!(error < 1.0, "Error too high: {}", error);
1452 }
1453
1454 #[test]
1455 fn test_int4_vs_int8_compression() {
1456 let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
1457 let shape = vec![1000];
1458
1459 let int8_quantized = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1460 let int8_size = int8_quantized.size_bytes();
1461 let int8_error = int8_quantized.quantization_error(&data);
1462
1463 let int4_quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1464 let int4_size = int4_quantized.size_bytes();
1465 let int4_packed_size = int4_quantized.packed_size_bytes();
1466 let int4_error = int4_quantized.quantization_error(&data);
1467
1468 println!("INT8: {} bytes, MSE: {:.8}", int8_size, int8_error);
1469 println!(
1470 "INT4 (unpacked): {} bytes, MSE: {:.8}",
1471 int4_size, int4_error
1472 );
1473 println!(
1474 "INT4 (packed): {} bytes, MSE: {:.8}",
1475 int4_packed_size, int4_error
1476 );
1477
1478 assert_eq!(int4_size, int8_size);
1479
1480 assert!(int4_packed_size <= int8_size / 2 + 1);
1481
1482 assert!(int4_error > int8_error);
1483
1484 assert!(int4_error < 0.01, "INT4 error too high: {}", int4_error);
1485 }
1486
1487 #[test]
1488 fn test_int4_large_tensor() {
1489 let size = 64 * 3 * 3 * 3; let data: Vec<f32> = (0..size)
1491 .map(|i| ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5)
1492 .collect();
1493
1494 let shape = vec![64, 3, 3, 3];
1495
1496 let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1497
1498 assert_eq!(quantized.data.len(), size);
1499 assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 64);
1500
1501 let error = quantized.quantization_error(&data);
1502 println!("Large tensor INT4 error: {:.8}", error);
1503
1504 assert!(error < 0.01, "Error too high for large tensor: {}", error);
1505 }
1506
1507 #[test]
1508 fn test_int4_extreme_ranges() {
1509 let test_cases = vec![
1510 (vec![-0.001, 0.0, 0.001], "tiny range"),
1511 (vec![-100.0, 0.0, 100.0], "large range"),
1512 (vec![0.0, 0.0, 0.0], "all zeros"),
1513 (vec![1.0, 1.0, 1.0], "all same"),
1514 ];
1515
1516 for (data, desc) in test_cases {
1517 println!("\nTesting: {}", desc);
1518 let shape = vec![data.len()];
1519
1520 let result = QuantizedTensorInt4::from_f32(&data, shape);
1521 assert!(result.is_ok(), "Failed on {}", desc);
1522
1523 let quantized = result.unwrap();
1524 let dequantized = quantized.to_f32();
1525
1526 println!(" Original: {:?}", data);
1527 println!(" Dequantized: {:?}", dequantized);
1528
1529 for &val in &quantized.data {
1530 assert!(
1531 val >= -8 && val <= 7,
1532 "Value {} out of range for {}",
1533 val,
1534 desc
1535 );
1536 }
1537 }
1538 }
1539
1540 #[test]
1541 fn test_int4_pack_unpack_pair() {
1542 let test_cases = vec![
1543 (-8, 7),
1544 (-8, -8),
1545 (7, 7),
1546 (0, 0),
1547 (-1, 0),
1548 (0, -1),
1549 (-5, 3),
1550 (6, -4),
1551 ];
1552
1553 for (val1, val2) in test_cases {
1554 println!("\nTesting: ({}, {})", val1, val2);
1555
1556 let packed = pack_int4_pair(val1, val2);
1557 let (unpacked1, unpacked2) = unpack_int4_pair(packed);
1558
1559 println!(" Packed: 0x{:02X} (binary: {:08b})", packed, packed);
1560 println!(" Unpacked: ({}, {})", unpacked1, unpacked2);
1561
1562 assert_eq!(val1, unpacked1, "First value mismatch");
1563 assert_eq!(val2, unpacked2, "Second value mismatch");
1564 }
1565 }
1566
1567 #[test]
1568 fn test_int4_pack_unpack_vector() {
1569 let values = vec![-8, -7, -1, 0, 1, 7];
1570 let packed = pack_int4(&values);
1571 let unpacked = unpack_int4(&packed, values.len());
1572
1573 println!("\nEven length:");
1574 println!(" Original: {:?}", values);
1575 println!(" Packed: {:?} ({} bytes)", packed, packed.len());
1576 println!(" Unpacked: {:?}", unpacked);
1577
1578 assert_eq!(values, unpacked);
1579 assert_eq!(packed.len(), (values.len() + 1) / 2);
1580 }
1581
1582 #[test]
1583 fn test_int4_pack_unpack_odd_length() {
1584 let values = vec![-8, -5, 0, 5, 7];
1585 let packed = pack_int4(&values);
1586 let unpacked = unpack_int4(&packed, values.len());
1587
1588 println!("\nOdd length:");
1589 println!(" Original: {:?}", values);
1590 println!(" Packed: {:?} ({} bytes)", packed, packed.len());
1591 println!(" Unpacked: {:?}", unpacked);
1592
1593 assert_eq!(values, unpacked);
1594 assert_eq!(packed.len(), (values.len() + 1) / 2);
1595 }
1596
1597 #[test]
1598 fn test_int4_pack_all_values() {
1599 let values: Vec<i8> = (-8..=7).collect();
1600 let packed = pack_int4(&values);
1601 let unpacked = unpack_int4(&packed, values.len());
1602
1603 println!("\nAll INT4 values:");
1604 println!(" Original: {:?}", values);
1605 println!(" Packed: {} bytes", packed.len());
1606 println!(" Unpacked: {:?}", unpacked);
1607
1608 assert_eq!(values, unpacked);
1609 assert_eq!(packed.len(), 8);
1610 }
1611
1612 #[test]
1613 fn test_int4_pack_large_vector() {
1614 let values: Vec<i8> = (0..1000).map(|i| ((i % 16) - 8) as i8).collect();
1615 let packed = pack_int4(&values);
1616 let unpacked = unpack_int4(&packed, values.len());
1617
1618 assert_eq!(values, unpacked);
1619 assert_eq!(packed.len(), 500);
1620
1621 println!("\nLarge vector:");
1622 println!(" Original: {} values", values.len());
1623 println!(
1624 " Packed: {} bytes ({}x compression)",
1625 packed.len(),
1626 values.len() / packed.len()
1627 );
1628 println!(" Unpacked: {} values", unpacked.len());
1629 }
1630
1631 #[test]
1632 fn test_int4_compression_ratio() {
1633 let size = 10000;
1634 let values: Vec<i8> = (0..size).map(|i| ((i % 16) - 8) as i8).collect();
1635
1636 let unpacked_size = values.len() * std::mem::size_of::<i8>();
1637
1638 let packed = pack_int4(&values);
1639 let packed_size = packed.len();
1640
1641 let compression_ratio = unpacked_size as f32 / packed_size as f32;
1642
1643 println!("\nCompression test:");
1644 println!(" Values: {}", size);
1645 println!(" Unpacked: {} bytes", unpacked_size);
1646 println!(" Packed: {} bytes", packed_size);
1647 println!(" Compression: {:.2}x", compression_ratio);
1648
1649 assert!(
1650 (compression_ratio - 2.0).abs() < 0.01,
1651 "Expected ~2x compression, got {:.2}x",
1652 compression_ratio
1653 );
1654 }
1655
1656 #[test]
1657 fn test_int4_tensor_packing() {
1658 let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
1659 let shape = vec![1000];
1660
1661 let mut quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1662
1663 println!("Before packing:");
1664 println!(" Unpacked size: {} bytes", quantized.unpacked_size_bytes());
1665 println!(" Is packed: {}", quantized.is_packed());
1666
1667 assert!(!quantized.is_packed());
1668 assert_eq!(quantized.size_bytes(), 1000);
1669
1670 quantized.pack();
1671
1672 println!("\nAfter packing:");
1673 println!(" Packed size: {} bytes", quantized.size_bytes());
1674 println!(" Is packed: {}", quantized.is_packed());
1675 println!(
1676 " Compression: {}x",
1677 quantized.unpacked_size_bytes() / quantized.size_bytes()
1678 );
1679
1680 assert!(quantized.is_packed());
1681 assert_eq!(quantized.size_bytes(), 500);
1682
1683 let dequantized = quantized.to_f32();
1684 assert_eq!(dequantized.len(), 1000);
1685
1686 let error = quantized.quantization_error(&data);
1687 println!(" MSE after packing: {:.8}", error);
1688 assert!(error < 0.01);
1689 }
1690
1691 #[test]
1692 fn test_int4_packed_vs_unpacked_error() {
1693 let data: Vec<f32> = (0..100).map(|i| (i as f32 / 100.0) * 2.0 - 1.0).collect();
1694 let shape = vec![100];
1695
1696 let unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
1697 let error_unpacked = unpacked.quantization_error(&data);
1698
1699 let mut packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1700 packed.pack();
1701 let error_packed = packed.quantization_error(&data);
1702
1703 println!("Unpacked error: {:.8}", error_unpacked);
1704 println!("Packed error: {:.8}", error_packed);
1705
1706 assert!((error_unpacked - error_packed).abs() < 1e-6);
1707 }
1708
1709 #[test]
1710 fn test_int4_per_channel_packing() {
1711 let mut data = vec![];
1712 for i in 0..500 {
1713 data.push((i as f32 / 500.0) * 0.2 - 0.1); }
1715 for i in 0..500 {
1716 data.push((i as f32 / 500.0) * 20.0 - 10.0); }
1718
1719 let shape = vec![2, 500];
1720
1721 let mut quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1722
1723 let error_before = quantized.quantization_error(&data);
1724 println!("Error before packing: {:.8}", error_before);
1725
1726 quantized.pack();
1727
1728 let error_after = quantized.quantization_error(&data);
1729 println!("Error after packing: {:.8}", error_after);
1730 println!(
1731 "Size: {} bytes (packed from {} bytes)",
1732 quantized.size_bytes(),
1733 quantized.unpacked_size_bytes()
1734 );
1735
1736 assert!((error_before - error_after).abs() < 1e-6);
1737
1738 assert_eq!(quantized.size_bytes(), 500);
1739 }
1740
1741 #[test]
1742 fn test_int4_compression_comparison() {
1743 let size = 10000;
1744 let data: Vec<f32> = (0..size)
1745 .map(|i| ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5)
1746 .collect();
1747 let shape = vec![size];
1748
1749 let fp32_size = size * std::mem::size_of::<f32>();
1750
1751 let int8 = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1752 let int8_size = int8.size_bytes();
1753
1754 let int4_unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
1755 let int4_unpacked_size = int4_unpacked.size_bytes();
1756
1757 let mut int4_packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1758 int4_packed.pack();
1759 let int4_packed_size = int4_packed.size_bytes();
1760
1761 println!("\nCompression Comparison:");
1762 println!(" FP32: {} bytes", fp32_size);
1763 println!(
1764 " INT8: {} bytes ({:.1}x)",
1765 int8_size,
1766 fp32_size as f32 / int8_size as f32
1767 );
1768 println!(
1769 " INT4 unpacked: {} bytes ({:.1}x)",
1770 int4_unpacked_size,
1771 fp32_size as f32 / int4_unpacked_size as f32
1772 );
1773 println!(
1774 " INT4 packed: {} bytes ({:.1}x)",
1775 int4_packed_size,
1776 fp32_size as f32 / int4_packed_size as f32
1777 );
1778
1779 assert_eq!(fp32_size / int8_size, 4); assert_eq!(fp32_size / int4_packed_size, 8); }
1782
    /// End-to-end smoke test of INT4 quantization on a real ONNX model.
    ///
    /// Probes a few conventional file locations; when no model is present
    /// the test prints instructions and returns without failing. Otherwise
    /// it checks compression ratios (~4x INT8, ~8x packed INT4), that bit
    /// packing is lossless, and that per-channel INT4 is not significantly
    /// worse than per-tensor.
    #[test]
    #[ignore] // needs a local .onnx file; run with `cargo test -- --ignored`
    fn test_int4_real_model() {
        use crate::onnx_utils::OnnxModel;

        println!("\n{}", "=".repeat(60));
        println!("INT4 Real Model Test");
        println!("\n{}", "=".repeat(60));

        // Candidate model locations, tried in order; first loadable wins.
        let model_paths = vec![
            "test_models/mnist.onnx",
            "mnist.onnx",
            "test_models/resnet18-v1-7.onnx",
            "resnet18-v1-7.onnx",
        ];

        let mut model = None;
        for path in &model_paths {
            if std::path::Path::new(path).exists() {
                println!("Loading model: {}", path);
                match OnnxModel::load(path) {
                    Ok(m) => {
                        model = Some(m);
                        break;
                    }
                    // A corrupt/unreadable file is not fatal; try the next path.
                    Err(e) => println!(" Failed: {}", e),
                }
            }
        }

        // Gracefully skip (without failing) when no model file is available.
        let model = match model {
            Some(m) => m,
            None => {
                println!("No test models found. Skipping test.");
                println!("Place mnist.onnx or resnet18-v1-7.onnx in current directory.");
                return;
            }
        };

        let info = model.info();
        println!("✓ Model loaded: {}", info.name);
        println!(" Nodes: {}", info.num_nodes);
        println!();

        println!("Extracting weights...");
        let weights = model.extract_weights();
        println!("✓ Found {} weight tensors", weights.len());

        if weights.is_empty() {
            println!("No weights to quantize!");
            return;
        }

        println!();
        println!("\n{}", "=".repeat(60));
        println!("Testing Per-Tensor Quantization");
        println!("\n{}", "=".repeat(60));

        // Only sizeable layers are interesting; cap at 5 to keep it fast.
        let test_weights: Vec<_> = weights
            .iter()
            .filter(|w| w.data.len() > 1000)
            .take(5)
            .collect();

        println!("Testing {} large layers:\n", test_weights.len());

        for (idx, weight) in test_weights.iter().enumerate() {
            // Truncate long layer names for tidy console output.
            let name = if weight.name.len() > 40 {
                format!("{}...", &weight.name[..37])
            } else {
                weight.name.clone()
            };

            println!("[{}] {}", idx + 1, name);
            println!(
                " Shape: {:?}, Elements: {}",
                weight.shape,
                weight.data.len()
            );

            // FP32 baseline: 4 bytes per element.
            let fp32_size = weight.data.len() * 4;

            let int8_result = QuantizedTensor::from_f32(&weight.data, weight.shape.clone());
            let (int8_size, int8_error) = if let Ok(q) = int8_result {
                (q.size_bytes(), q.quantization_error(&weight.data))
            } else {
                println!(" INT8 failed!");
                continue;
            };

            let int4_result = QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone());
            let (int4_unpacked_size, int4_error) = if let Ok(q) = int4_result {
                (q.size_bytes(), q.quantization_error(&weight.data))
            } else {
                println!(" INT4 failed!");
                continue;
            };

            // Third quantization pass, this time packed, to compare sizes.
            let mut int4_packed =
                QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
            int4_packed.pack();
            let int4_packed_size = int4_packed.size_bytes();
            let int4_packed_error = int4_packed.quantization_error(&weight.data);

            println!(" FP32: {:7} bytes", fp32_size);
            println!(
                " INT8: {:7} bytes ({:.1}x) MSE: {:.8}",
                int8_size,
                fp32_size as f32 / int8_size as f32,
                int8_error
            );
            println!(
                " INT4 unpacked: {:7} bytes ({:.1}x) MSE: {:.8}",
                int4_unpacked_size,
                fp32_size as f32 / int4_unpacked_size as f32,
                int4_error
            );
            println!(
                " INT4 packed: {:7} bytes ({:.1}x) MSE: {:.8}",
                int4_packed_size,
                fp32_size as f32 / int4_packed_size as f32,
                int4_packed_error
            );

            // Packing only re-encodes storage; the error must be identical.
            assert_eq!(int4_error, int4_packed_error, "Packing changed error!");

            let int8_ratio = fp32_size as f32 / int8_size as f32;
            let int4_ratio = fp32_size as f32 / int4_packed_size as f32;

            assert!(
                (int8_ratio - 4.0).abs() < 0.1,
                "INT8 compression should be ~4x"
            );
            assert!(
                (int4_ratio - 8.0).abs() < 0.1,
                "INT4 compression should be ~8x"
            );

            println!();
        }

        println!("\n{}", "=".repeat(60));
        println!("Testing Per-Channel Quantization");
        println!("\n{}", "=".repeat(60));

        // Layers with >1 channel in dim 0 can use per-channel quantization.
        let conv_weights: Vec<_> = weights
            .iter()
            .filter(|w| w.shape.len() >= 2 && w.shape[0] > 1)
            .take(3)
            .collect();

        if conv_weights.is_empty() {
            println!("No multi-channel layers found for per-channel test.");
        } else {
            println!("Testing {} conv layers:\n", conv_weights.len());

            for (idx, weight) in conv_weights.iter().enumerate() {
                let name = if weight.name.len() > 40 {
                    format!("{}...", &weight.name[..37])
                } else {
                    weight.name.clone()
                };

                println!("[{}] {}", idx + 1, name);
                println!(
                    " Shape: {:?}, Channels: {}",
                    weight.shape, weight.shape[0]
                );

                let per_tensor =
                    QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
                let per_tensor_error = per_tensor.quantization_error(&weight.data);

                let per_channel_result =
                    QuantizedTensorInt4::from_f32_per_channel(&weight.data, weight.shape.clone());

                if let Ok(per_channel) = per_channel_result {
                    let per_channel_error = per_channel.quantization_error(&weight.data);

                    // Relative improvement of per-channel over per-tensor (%).
                    let improvement =
                        ((per_tensor_error - per_channel_error) / per_tensor_error) * 100.0;

                    println!(" Per-tensor: MSE: {:.8}", per_tensor_error);
                    println!(
                        " Per-channel: MSE: {:.8} ({:.1}% better)",
                        per_channel_error, improvement
                    );

                    // 10% slack: per-channel may be marginally worse on some
                    // layers, but never significantly so.
                    assert!(
                        per_channel_error <= per_tensor_error * 1.1,
                        "Per-channel should not be significantly worse"
                    );
                } else {
                    println!(" Per-channel failed!");
                }

                println!();
            }
        }

        println!("\n{}", "=".repeat(60));
        println!("Summary");
        println!("\n{}", "=".repeat(60));

        println!("✓ INT4 quantization works on real model weights");
        println!("✓ Compression ratios correct (4x INT8, 8x INT4)");
        println!("✓ Bit packing is lossless");
        println!("✓ Per-channel quantization works");
        println!("\nINT4 implementation is ready for CLI integration!");
    }
1994
1995 #[test]
2000 fn test_all_nan_returns_error() {
2001 let data = vec![f32::NAN, f32::NAN, f32::NAN];
2002 let result = QuantizedTensor::from_f32(&data, vec![3]);
2003 assert!(result.is_err());
2004 let err = result.unwrap_err().to_string();
2005 assert!(
2006 err.contains("non-finite"),
2007 "error should mention non-finite: {}",
2008 err
2009 );
2010 }
2011
2012 #[test]
2013 fn test_all_inf_returns_error() {
2014 let data = vec![f32::INFINITY, f32::NEG_INFINITY];
2015 let result = QuantizedTensor::from_f32(&data, vec![2]);
2016 assert!(result.is_err());
2017 }
2018
2019 #[test]
2020 fn test_all_nan_int4_returns_error() {
2021 let data = vec![f32::NAN; 4];
2022 let result = QuantizedTensorInt4::from_f32(&data, vec![4]);
2023 assert!(result.is_err());
2024 }
2025
2026 #[test]
2027 fn test_all_nan_per_channel_returns_error() {
2028 let data = vec![f32::NAN; 6];
2029 let result = QuantizedTensor::from_f32_per_channel(&data, vec![2, 3]);
2030 assert!(result.is_err());
2031 let err = result.unwrap_err().to_string();
2032 assert!(
2033 err.contains("Channel 0"),
2034 "error should mention channel: {}",
2035 err
2036 );
2037 }
2038
2039 #[test]
2040 fn test_mixed_nan_finite_succeeds() {
2041 let data = vec![f32::NAN, 1.0, -1.0, f32::NAN];
2043 let result = QuantizedTensor::from_f32(&data, vec![4]);
2044 assert!(result.is_ok());
2045 }
2046
2047 #[test]
2052 fn test_int8_symmetric_params_zero_point_is_zero() {
2053 let params = QuantParams::from_range_symmetric(-0.5, 2.0);
2054 assert_eq!(params.zero_point(), 0, "symmetric must have zp=0");
2055 let expected_scale = 2.0_f32 / 127.0;
2057 assert!(
2058 (params.scale() - expected_scale).abs() < 1e-6,
2059 "scale {} vs expected {}",
2060 params.scale(),
2061 expected_scale
2062 );
2063 }
2064
2065 #[test]
2066 fn test_int4_symmetric_params_zero_point_is_zero() {
2067 let params = QuantParamsInt4::from_range_symmetric(-3.0, 1.0);
2068 assert_eq!(params.zero_point(), 0);
2069 let expected_scale = 3.0_f32 / 7.0;
2071 assert!((params.scale() - expected_scale).abs() < 1e-6);
2072 }
2073
2074 #[test]
2075 fn test_symmetric_zero_dequantizes_to_zero() {
2076 let params = QuantParams::from_range_symmetric(-10.0, 10.0);
2078 let q = params.quantize(0.0);
2079 assert_eq!(q, 0);
2080 let dq = params.dequantize(q);
2081 assert_eq!(dq, 0.0);
2082 }
2083
2084 #[test]
2085 fn test_symmetric_asymmetric_produce_different_scales() {
2086 let asym = QuantParams::from_range(0.0, 10.0);
2088 let sym = QuantParams::from_range_symmetric(0.0, 10.0);
2089 assert_ne!(asym.zero_point(), sym.zero_point());
2090 assert!(
2093 sym.scale() > asym.scale(),
2094 "symmetric scale {} should exceed asymmetric {}",
2095 sym.scale(),
2096 asym.scale()
2097 );
2098 }
2099
2100 #[test]
2101 fn test_symmetric_constant_tensor_handled() {
2102 let params = QuantParams::from_range_symmetric(0.0, 0.0);
2104 assert!(params.scale() > 0.0);
2105 assert_eq!(params.zero_point(), 0);
2106 }
2107
2108 #[test]
2109 fn test_from_f32_symmetric_tensor_has_zero_zp() {
2110 let data: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) * 0.1).collect();
2111 let tensor = QuantizedTensor::from_f32_symmetric(&data, vec![100]).unwrap();
2112 assert_eq!(tensor.params().zero_point(), 0);
2113 }
2114
2115 #[test]
2116 fn test_from_f32_per_channel_symmetric_every_channel_zp_zero() {
2117 let mut data = Vec::new();
2119 for ch in 0..4 {
2120 let scale = (ch + 1) as f32;
2121 for i in 0..16 {
2122 data.push((i as f32 - 8.0) * 0.1 * scale);
2123 }
2124 }
2125 let tensor = QuantizedTensor::from_f32_per_channel_symmetric(&data, vec![4, 16]).unwrap();
2126
2127 let channel_params = tensor
2128 .channel_params
2129 .as_ref()
2130 .expect("per-channel expected");
2131 assert_eq!(channel_params.len(), 4);
2132 for (i, p) in channel_params.iter().enumerate() {
2133 assert_eq!(p.zero_point(), 0, "channel {} zp should be 0", i);
2134 assert!(p.scale() > 0.0, "channel {} scale must be positive", i);
2135 }
2136 }
2137
2138 #[test]
2139 fn test_symmetric_round_trip_error_bounded() {
2140 let data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 250.0).collect();
2141 let tensor = QuantizedTensor::from_f32_symmetric(&data, vec![500]).unwrap();
2142 let mse = tensor.quantization_error(&data);
2143 assert!(mse < 1e-3, "symmetric MSE unexpectedly high: {}", mse);
2145 }
2146
2147 #[test]
2148 fn test_int4_symmetric_round_trip_error_bounded() {
2149 let data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 250.0).collect();
2150 let tensor = QuantizedTensorInt4::from_f32_symmetric(&data, vec![500]).unwrap();
2151 let mse = tensor.quantization_error(&data);
2152 assert!(mse < 0.01, "INT4 symmetric MSE too high: {}", mse);
2154 }
2155
2156 #[test]
2157 fn test_quantizer_symmetric_config_routes_correctly() {
2158 let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
2159 let config = QuantConfig {
2160 bits: 8,
2161 per_channel: true,
2162 symmetric: true,
2163 ..Default::default()
2164 };
2165 let q = Quantizer::new(config)
2166 .quantize_tensor(&data, vec![4, 16])
2167 .unwrap();
2168 let (_, zero_points) = q.get_all_scales_zero_points();
2169 assert!(
2170 zero_points.iter().all(|&z| z == 0),
2171 "all zero_points must be 0 under symmetric config, got {:?}",
2172 zero_points
2173 );
2174 }
2175}