1use crate::errors::{QuantizeError, Result};
8
/// Configuration controlling how tensors are quantized.
#[derive(Debug, Clone)]
pub struct QuantConfig {
    /// Default bit width for quantized values (the `Quantizer` supports 4 or 8).
    pub bits: u8,
    /// Give each leading-dimension channel its own scale/zero-point.
    pub per_channel: bool,
    /// Optional calibration method used to choose quantization ranges.
    pub calibration_method: Option<crate::calibration::methods::CalibrationMethod>,
    /// Layer names (exact match) that must never be quantized.
    pub excluded_layers: Vec<String>,
    /// Per-layer bit-width overrides, keyed by layer name.
    pub layer_bits: std::collections::HashMap<String, u8>,
    /// Layers with fewer elements than this are skipped; 0 disables the check.
    pub min_elements: usize,
}
26
27impl Default for QuantConfig {
28 fn default() -> Self {
29 Self {
30 bits: 8,
31 per_channel: false,
32 calibration_method: None,
33 excluded_layers: Vec::new(),
34 layer_bits: std::collections::HashMap::new(),
35 min_elements: 0,
36 }
37 }
38}
39
40impl QuantConfig {
41 pub fn int8() -> Self {
43 Self::default()
44 }
45
46 pub fn with_per_channel(mut self, enabled: bool) -> Self {
48 self.per_channel = enabled;
49 self
50 }
51
52 pub fn with_calibration(
54 mut self,
55 method: crate::calibration::methods::CalibrationMethod,
56 ) -> Self {
57 self.calibration_method = Some(method);
58 self
59 }
60
61 pub fn should_quantize(&self, name: &str, num_elements: usize) -> bool {
67 if self.excluded_layers.iter().any(|e| e == name) {
68 return false;
69 }
70 if self.min_elements > 0 && num_elements < self.min_elements {
71 return false;
72 }
73 true
74 }
75
76 pub fn bits_for_layer(&self, name: &str) -> u8 {
81 self.layer_bits.get(name).copied().unwrap_or(self.bits)
82 }
83}
84
/// Compile-time description of a quantized integer range.
///
/// Bounds are stored as `f32` so they can be used directly in
/// scale/zero-point arithmetic without casts.
pub trait QuantRange: Clone + std::fmt::Debug + Send + Sync + 'static {
    /// Smallest representable quantized value.
    const QMIN: f32;
    /// Largest representable quantized value.
    const QMAX: f32;
    /// Bit width of the quantized representation.
    const BITS: u8;
}
98
/// Signed 8-bit quantization range: [-128, 127].
#[derive(Debug, Clone)]
pub struct Int8Range;
impl QuantRange for Int8Range {
    const QMIN: f32 = -128.0;
    const QMAX: f32 = 127.0;
    const BITS: u8 = 8;
}
107
/// Signed 4-bit quantization range: [-8, 7].
#[derive(Debug, Clone)]
pub struct Int4Range;
impl QuantRange for Int4Range {
    const QMIN: f32 = -8.0;
    const QMAX: f32 = 7.0;
    const BITS: u8 = 4;
}
116
/// Affine quantization parameters for range `R`:
/// `real ≈ (quantized - zero_point) * scale`.
#[derive(Debug, Clone)]
pub struct QuantParamsGeneric<R: QuantRange> {
    /// Real-value width of one quantized step.
    scale: f32,
    /// Quantized value that represents real 0.0.
    zero_point: i8,
    /// Ties the parameters to their range type without storing a value.
    _marker: std::marker::PhantomData<R>,
}
132
/// INT8 quantization parameters.
pub type QuantParams = QuantParamsGeneric<Int8Range>;
/// INT4 quantization parameters.
pub type QuantParamsInt4 = QuantParamsGeneric<Int4Range>;
137
138impl<R: QuantRange> QuantParamsGeneric<R> {
139 pub fn scale(&self) -> f32 {
141 self.scale
142 }
143 pub fn zero_point(&self) -> i8 {
145 self.zero_point
146 }
147
148 pub fn from_range(min: f32, max: f32) -> Self {
150 let min = min.min(0.0);
151 let max = max.max(0.0);
152
153 let (min, max) = if (max - min).abs() < 1e-8 {
156 let abs = min.abs().max(max.abs()).max(1e-8);
157 (-abs, abs)
158 } else {
159 (min, max)
160 };
161
162 let scale = (max - min) / (R::QMAX - R::QMIN);
163 let scale = scale.max(1e-8);
164
165 let initial_zero_point = R::QMIN - min / scale;
166 let zero_point = if initial_zero_point.is_finite() {
169 initial_zero_point.round().clamp(R::QMIN, R::QMAX) as i8
170 } else {
171 0i8
172 };
173
174 QuantParamsGeneric {
175 scale,
176 zero_point,
177 _marker: std::marker::PhantomData,
178 }
179 }
180
181 pub fn quantize(&self, value: f32) -> i8 {
183 if !value.is_finite() {
184 return self.zero_point;
185 }
186 let quantized = (value / self.scale).round() + (self.zero_point as f32);
187 quantized.clamp(R::QMIN, R::QMAX) as i8
188 }
189
190 pub fn dequantize(&self, value: i8) -> f32 {
192 ((value as i32) - (self.zero_point as i32)) as f32 * self.scale
193 }
194}
195
/// A quantized tensor: `i8` values plus the parameters needed to
/// dequantize them.
#[derive(Debug, Clone)]
pub struct QuantizedTensorGeneric<R: QuantRange> {
    /// Quantized values, one `i8` per element (retained even after packing
    /// so the element count is preserved).
    pub(crate) data: Vec<i8>,
    /// Optional two-values-per-byte packed form (used for INT4).
    pub(crate) packed_data: Option<Vec<u8>>,
    /// Tensor dimensions; `shape[0]` is the channel axis in per-channel mode.
    pub(crate) shape: Vec<usize>,
    /// Per-tensor parameters (the first channel's when per-channel).
    pub(crate) params: QuantParamsGeneric<R>,
    /// Whether `channel_params` should drive dequantization.
    pub(crate) per_channel: bool,
    /// One parameter set per leading-dimension channel, when per-channel.
    pub(crate) channel_params: Option<Vec<QuantParamsGeneric<R>>>,
}
214
/// INT8 quantized tensor.
pub type QuantizedTensor = QuantizedTensorGeneric<Int8Range>;

/// INT4 quantized tensor (supports packing two values per byte).
pub type QuantizedTensorInt4 = QuantizedTensorGeneric<Int4Range>;
223
224impl<R: QuantRange> QuantizedTensorGeneric<R> {
229 pub fn shape(&self) -> &[usize] {
231 &self.shape
232 }
233 pub fn params(&self) -> &QuantParamsGeneric<R> {
235 &self.params
236 }
237 pub fn is_per_channel(&self) -> bool {
239 self.per_channel
240 }
241
242 pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Result<Self> {
248 if data.is_empty() {
249 return Err(QuantizeError::InvalidTensor {
250 reason: "Cannot quantize empty tensor".into(),
251 });
252 }
253
254 let expected_len: usize = shape.iter().product();
255 if expected_len != data.len() {
256 return Err(QuantizeError::InvalidTensor {
257 reason: format!(
258 "Shape {:?} expects {} elements but got {}",
259 shape,
260 expected_len,
261 data.len()
262 ),
263 });
264 }
265
266 let min = data
267 .iter()
268 .copied()
269 .filter(|v| v.is_finite())
270 .fold(f32::INFINITY, f32::min);
271 let max = data
272 .iter()
273 .copied()
274 .filter(|v| v.is_finite())
275 .fold(f32::NEG_INFINITY, f32::max);
276
277 if !min.is_finite() || !max.is_finite() {
278 return Err(QuantizeError::InvalidTensor {
279 reason: "Tensor contains only non-finite values (NaN/Inf)".into(),
280 });
281 }
282
283 let params = QuantParamsGeneric::<R>::from_range(min, max);
284
285 let quantized_data: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
286
287 Ok(QuantizedTensorGeneric {
288 data: quantized_data,
289 packed_data: None,
290 shape,
291 params,
292 per_channel: false,
293 channel_params: None,
294 })
295 }
296
297 pub fn from_f32_with_range(
303 data: &[f32],
304 shape: Vec<usize>,
305 min: f32,
306 max: f32,
307 ) -> Result<Self> {
308 if data.is_empty() {
309 return Err(QuantizeError::InvalidTensor {
310 reason: "Cannot quantize empty tensor".into(),
311 });
312 }
313
314 let expected_len: usize = shape.iter().product();
315 if expected_len != data.len() {
316 return Err(QuantizeError::InvalidTensor {
317 reason: format!(
318 "Shape {:?} expects {} elements but got {}",
319 shape,
320 expected_len,
321 data.len()
322 ),
323 });
324 }
325
326 let params = QuantParamsGeneric::<R>::from_range(min, max);
327
328 let quantized_data: Vec<i8> = data.iter().map(|&v| params.quantize(v)).collect();
329
330 Ok(QuantizedTensorGeneric {
331 data: quantized_data,
332 packed_data: None,
333 shape,
334 params,
335 per_channel: false,
336 channel_params: None,
337 })
338 }
339
340 pub fn from_f32_per_channel(data: &[f32], shape: Vec<usize>) -> Result<Self> {
347 if data.is_empty() {
348 return Err(QuantizeError::InvalidTensor {
349 reason: "Cannot quantize empty tensor".into(),
350 });
351 }
352
353 if shape.is_empty() {
354 return Err(QuantizeError::InvalidTensor {
355 reason: "Cannot do per-channel quantization on scalar".into(),
356 });
357 }
358
359 let expected_len: usize = shape.iter().product();
360 if expected_len != data.len() {
361 return Err(QuantizeError::InvalidTensor {
362 reason: format!(
363 "Shape {:?} expects {} elements but got {}",
364 shape,
365 expected_len,
366 data.len()
367 ),
368 });
369 }
370
371 let num_channels = shape[0];
372
373 let mut channel_params = Vec::new();
374 let mut quantized_data = Vec::with_capacity(data.len());
375
376 for channel_idx in 0..num_channels {
377 let channel_data = extract_channel(data, &shape, channel_idx)?;
378
379 let min = channel_data
380 .iter()
381 .copied()
382 .filter(|v| v.is_finite())
383 .fold(f32::INFINITY, f32::min);
384 let max = channel_data
385 .iter()
386 .copied()
387 .filter(|v| v.is_finite())
388 .fold(f32::NEG_INFINITY, f32::max);
389
390 if !min.is_finite() || !max.is_finite() {
391 return Err(QuantizeError::InvalidTensor {
392 reason: format!(
393 "Channel {} contains only non-finite values (NaN/Inf)",
394 channel_idx
395 ),
396 });
397 }
398
399 let params = QuantParamsGeneric::<R>::from_range(min, max);
400 channel_params.push(params.clone());
401
402 for &value in &channel_data {
403 quantized_data.push(params.quantize(value));
404 }
405 }
406
407 let params = channel_params[0].clone();
409
410 Ok(QuantizedTensorGeneric {
411 data: quantized_data,
412 packed_data: None,
413 shape,
414 params,
415 per_channel: true,
416 channel_params: Some(channel_params),
417 })
418 }
419
420 pub fn to_f32(&self) -> Vec<f32> {
422 let data_owned;
424 let data: &[i8] = if let Some(ref packed) = self.packed_data {
425 data_owned = unpack_int4(packed, self.data.len());
426 &data_owned
427 } else {
428 &self.data
429 };
430
431 if self.per_channel {
432 if let Some(ref channel_params) = self.channel_params {
433 if channel_params.is_empty() {
434 return data.iter().map(|&v| self.params.dequantize(v)).collect();
435 }
436 let elements_per_channel = data.len() / channel_params.len();
437 data.iter()
438 .enumerate()
439 .map(|(i, &v)| {
440 let channel_idx = (i / elements_per_channel).min(channel_params.len() - 1);
441 channel_params[channel_idx].dequantize(v)
442 })
443 .collect()
444 } else {
445 data.iter().map(|&v| self.params.dequantize(v)).collect()
446 }
447 } else {
448 data.iter().map(|&v| self.params.dequantize(v)).collect()
449 }
450 }
451
452 pub fn size_bytes(&self) -> usize {
454 if let Some(ref packed) = self.packed_data {
455 packed.len()
456 } else {
457 self.data.len() * std::mem::size_of::<i8>()
458 }
459 }
460
461 pub fn quantization_error(&self, original: &[f32]) -> f32 {
463 if original.is_empty() {
464 return 0.0;
465 }
466
467 let dequantized = self.to_f32();
468
469 let sum: f32 = original
470 .iter()
471 .zip(dequantized.iter())
472 .map(|(a, b)| (a - b).powi(2))
473 .sum();
474
475 sum / original.len() as f32
476 }
477}
478
479impl QuantizedTensorGeneric<Int4Range> {
484 pub fn pack(&mut self) {
486 self.packed_data = Some(pack_int4(&self.data));
487 }
488
489 pub fn ensure_unpacked(&self) -> Vec<i8> {
491 if let Some(ref packed) = self.packed_data {
492 unpack_int4(packed, self.data.len())
493 } else {
494 self.data.clone()
495 }
496 }
497
498 pub fn is_packed(&self) -> bool {
500 self.packed_data.is_some()
501 }
502
503 pub fn packed_size_bytes(&self) -> usize {
505 if let Some(ref packed) = self.packed_data {
506 packed.len()
507 } else {
508 self.data.len().div_ceil(2)
509 }
510 }
511
512 pub fn unpacked_size_bytes(&self) -> usize {
514 self.data.len() * std::mem::size_of::<i8>()
515 }
516}
517
/// Packs two INT4 values ([-8, 7]) into one byte: `val1` in the high
/// nibble, `val2` in the low nibble.
fn pack_int4_pair(val1: i8, val2: i8) -> u8 {
    debug_assert!((-8..=7).contains(&val1), "val1 out of INT4 range: {}", val1);
    debug_assert!((-8..=7).contains(&val2), "val2 out of INT4 range: {}", val2);

    // Masking with 0x0F keeps the low 4 bits of the two's-complement form.
    let high = ((val1 & 0x0F) as u8) << 4;
    let low = (val2 & 0x0F) as u8;
    high | low
}
533
/// Splits one byte into two INT4 values: high nibble first, low nibble
/// second, each sign-extended from 4 bits to `i8`.
fn unpack_int4_pair(byte: u8) -> (i8, i8) {
    // Sign-extend a 4-bit two's-complement nibble to i8.
    fn sign_extend(nibble: u8) -> i8 {
        if nibble >= 8 {
            (nibble as i8) | !0x0F
        } else {
            nibble as i8
        }
    }

    let high = (byte >> 4) & 0x0F;
    let low = byte & 0x0F;
    (sign_extend(high), sign_extend(low))
}
553
/// Packs a slice of INT4 values ([-8, 7]) two-per-byte, high nibble first.
/// An odd trailing value is paired with 0.
pub fn pack_int4(values: &[i8]) -> Vec<u8> {
    values
        .chunks(2)
        .map(|pair| {
            let first = pair[0];
            // Odd-length input: pad the final low nibble with zero.
            let second = *pair.get(1).unwrap_or(&0);

            debug_assert!((-8..=7).contains(&first), "val1 out of INT4 range: {}", first);
            debug_assert!((-8..=7).contains(&second), "val2 out of INT4 range: {}", second);

            (((first & 0x0F) as u8) << 4) | ((second & 0x0F) as u8)
        })
        .collect()
}
567
/// Unpacks `num_values` INT4 values from bytes packed high-nibble-first,
/// sign-extending each nibble to `i8` and discarding any padding nibble.
pub fn unpack_int4(packed: &[u8], num_values: usize) -> Vec<i8> {
    // Sign-extend a 4-bit two's-complement nibble to i8.
    fn nibble_to_i8(nibble: u8) -> i8 {
        if nibble >= 8 {
            (nibble as i8) | !0x0F
        } else {
            nibble as i8
        }
    }

    let mut values = Vec::with_capacity(num_values);
    for &byte in packed {
        values.push(nibble_to_i8((byte >> 4) & 0x0F));
        values.push(nibble_to_i8(byte & 0x0F));
    }

    // Drop the padding nibble (odd counts) or any surplus bytes.
    values.truncate(num_values);
    values
}
584
585fn extract_channel(data: &[f32], shape: &[usize], channel_idx: usize) -> Result<Vec<f32>> {
590 if shape.is_empty() {
591 return Err(QuantizeError::InvalidTensor {
592 reason: "Cannot extract channel from empty shape".into(),
593 });
594 }
595 let num_channels = shape[0];
596 if num_channels == 0 {
597 return Err(QuantizeError::InvalidTensor {
598 reason: "Number of channels is 0".into(),
599 });
600 }
601 if channel_idx >= num_channels {
602 return Err(QuantizeError::InvalidTensor {
603 reason: format!(
604 "Channel index {} out of bounds for {} channels",
605 channel_idx, num_channels
606 ),
607 });
608 }
609 if !data.len().is_multiple_of(num_channels) {
610 return Err(QuantizeError::InvalidTensor {
611 reason: format!(
612 "Data length {} not evenly divisible by {} channels",
613 data.len(),
614 num_channels
615 ),
616 });
617 }
618 let elements_per_channel = data.len() / num_channels;
619 let start = channel_idx * elements_per_channel;
620 let end = start + elements_per_channel;
621 Ok(data[start..end].to_vec())
622}
623
/// Runtime-dispatched quantized tensor: either INT8 or INT4.
#[derive(Debug, Clone)]
pub enum QuantizedTensorType {
    /// 8-bit quantized tensor.
    Int8(QuantizedTensor),
    /// 4-bit quantized tensor (possibly packed two values per byte).
    Int4(QuantizedTensorInt4),
}
634
635impl QuantizedTensorType {
636 pub fn to_f32(&self) -> Vec<f32> {
638 match self {
639 QuantizedTensorType::Int8(t) => t.to_f32(),
640 QuantizedTensorType::Int4(t) => t.to_f32(),
641 }
642 }
643
644 pub fn size_bytes(&self) -> usize {
646 match self {
647 QuantizedTensorType::Int8(t) => t.size_bytes(),
648 QuantizedTensorType::Int4(t) => t.size_bytes(),
649 }
650 }
651
652 #[must_use]
653 pub fn quantization_error(&self, original: &[f32]) -> f32 {
654 match self {
655 QuantizedTensorType::Int8(t) => t.quantization_error(original),
656 QuantizedTensorType::Int4(t) => t.quantization_error(original),
657 }
658 }
659
660 #[must_use]
661 pub fn data(&self) -> Vec<i8> {
662 match self {
663 QuantizedTensorType::Int8(t) => t.data.clone(),
664 QuantizedTensorType::Int4(t) => t.ensure_unpacked(),
665 }
666 }
667
668 pub fn get_scale_zero_point(&self) -> (f32, i8) {
670 match self {
671 QuantizedTensorType::Int8(t) => (t.params.scale, t.params.zero_point),
672 QuantizedTensorType::Int4(t) => (t.params.scale, t.params.zero_point),
673 }
674 }
675
676 pub fn get_all_scales_zero_points(&self) -> (Vec<f32>, Vec<i8>) {
681 match self {
682 QuantizedTensorType::Int8(t) => {
683 if let Some(ref cp) = t.channel_params {
684 (
685 cp.iter().map(|p| p.scale).collect(),
686 cp.iter().map(|p| p.zero_point).collect(),
687 )
688 } else {
689 (vec![t.params.scale], vec![t.params.zero_point])
690 }
691 }
692 QuantizedTensorType::Int4(t) => {
693 if let Some(ref cp) = t.channel_params {
694 (
695 cp.iter().map(|p| p.scale).collect(),
696 cp.iter().map(|p| p.zero_point).collect(),
697 )
698 } else {
699 (vec![t.params.scale], vec![t.params.zero_point])
700 }
701 }
702 }
703 }
704
705 pub fn is_per_channel(&self) -> bool {
707 match self {
708 QuantizedTensorType::Int8(t) => t.per_channel,
709 QuantizedTensorType::Int4(t) => t.per_channel,
710 }
711 }
712
713 #[must_use]
714 pub fn bits(&self) -> u8 {
715 match self {
716 QuantizedTensorType::Int8(_) => 8,
717 QuantizedTensorType::Int4(_) => 4,
718 }
719 }
720
721 pub fn is_int8(&self) -> bool {
723 matches!(self, QuantizedTensorType::Int8(_))
724 }
725
726 pub fn is_int4(&self) -> bool {
728 matches!(self, QuantizedTensorType::Int4(_))
729 }
730
731 pub fn data_ref(&self) -> Option<&[i8]> {
735 match self {
736 QuantizedTensorType::Int8(t) => Some(&t.data),
737 QuantizedTensorType::Int4(t) => {
738 if t.packed_data.is_some() {
739 None } else {
741 Some(&t.data)
742 }
743 }
744 }
745 }
746}
747
/// Orchestrates tensor quantization according to a `QuantConfig`, optionally
/// using pre-collected activation statistics for range calibration.
pub struct Quantizer {
    /// Quantization settings (bit width, per-channel, exclusions, ...).
    config: QuantConfig,
    /// Activation statistics keyed by layer name, when calibrated.
    calibration_stats:
        Option<std::collections::HashMap<String, crate::calibration::stats::ActivationStats>>,
}
758
759impl std::fmt::Debug for Quantizer {
760 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
761 let stats_count = self.calibration_stats.as_ref().map(|m| m.len());
762 f.debug_struct("Quantizer")
763 .field("config", &self.config)
764 .field("calibration_stats_count", &stats_count)
765 .finish()
766 }
767}
768
769impl Quantizer {
770 pub fn new(config: QuantConfig) -> Self {
772 Self {
773 config,
774 calibration_stats: None,
775 }
776 }
777
778 pub fn with_calibration(
780 config: QuantConfig,
781 stats: std::collections::HashMap<String, crate::calibration::stats::ActivationStats>,
782 ) -> Self {
783 Self {
784 config,
785 calibration_stats: Some(stats),
786 }
787 }
788
789 pub fn quantize_tensor_with_name(
791 &self,
792 name: &str,
793 data: &[f32],
794 shape: Vec<usize>,
795 ) -> Result<QuantizedTensorType> {
796 let (min, max) = if let Some(ref stats_map) = self.calibration_stats {
797 if let Some(stats) = stats_map.get(name) {
798 if let Some(method) = self.config.calibration_method {
799 use crate::calibration::stats::calculate_optimal_range;
800
801 let sample_data = sample_from_activation_stats(stats, 1000);
802 calculate_optimal_range(&sample_data, method)
803 } else {
804 (stats.min(), stats.max())
805 }
806 } else {
807 finite_min_max(data, name)?
808 }
809 } else {
810 finite_min_max(data, name)?
811 };
812
813 self.quantize_with_range(data, shape, min, max)
814 }
815
816 pub fn quantize_tensor(&self, data: &[f32], shape: Vec<usize>) -> Result<QuantizedTensorType> {
822 self.build_tensor_with_optional_range(data, shape, None)
823 }
824
825 fn quantize_with_range(
832 &self,
833 data: &[f32],
834 shape: Vec<usize>,
835 min: f32,
836 max: f32,
837 ) -> Result<QuantizedTensorType> {
838 self.build_tensor_with_optional_range(data, shape, Some((min, max)))
839 }
840
841 fn build_tensor_with_optional_range(
843 &self,
844 data: &[f32],
845 shape: Vec<usize>,
846 range: Option<(f32, f32)>,
847 ) -> Result<QuantizedTensorType> {
848 let pc = self.config.per_channel && shape.len() >= 2;
849 match self.config.bits {
850 8 => {
851 let t = match (pc, range) {
852 (true, _) => QuantizedTensor::from_f32_per_channel(data, shape)?,
853 (false, Some((min, max))) => {
854 QuantizedTensor::from_f32_with_range(data, shape, min, max)?
855 }
856 (false, None) => QuantizedTensor::from_f32(data, shape)?,
857 };
858 Ok(QuantizedTensorType::Int8(t))
859 }
860 4 => {
861 let mut t = match (pc, range) {
862 (true, _) => QuantizedTensorInt4::from_f32_per_channel(data, shape)?,
863 (false, Some((min, max))) => {
864 QuantizedTensorInt4::from_f32_with_range(data, shape, min, max)?
865 }
866 (false, None) => QuantizedTensorInt4::from_f32(data, shape)?,
867 };
868 t.pack();
869 Ok(QuantizedTensorType::Int4(t))
870 }
871 b => Err(QuantizeError::UnsupportedConfig {
872 reason: format!("bits must be 4 or 8, got {b}"),
873 }),
874 }
875 }
876}
877
878fn finite_min_max(data: &[f32], name: &str) -> Result<(f32, f32)> {
884 let min = data
885 .iter()
886 .copied()
887 .filter(|v| v.is_finite())
888 .fold(f32::INFINITY, f32::min);
889 let max = data
890 .iter()
891 .copied()
892 .filter(|v| v.is_finite())
893 .fold(f32::NEG_INFINITY, f32::max);
894 if !min.is_finite() || !max.is_finite() {
895 return Err(QuantizeError::InvalidTensor {
896 reason: format!(
897 "Tensor '{}' contains only non-finite values (NaN/Inf)",
898 name
899 ),
900 });
901 }
902 Ok((min, max))
903}
904
905fn sample_from_activation_stats(
907 stats: &crate::calibration::stats::ActivationStats,
908 n: usize,
909) -> Vec<f32> {
910 use rand::Rng;
911
912 let histogram = stats.histogram_data();
913 if histogram.is_empty() {
914 let mut rng = rand::thread_rng();
916 let range = stats.max() - stats.min();
917 if !range.is_finite() || range.abs() < 1e-8 {
918 return vec![stats.mean(); n];
919 }
920 return (0..n)
921 .map(|_| rng.gen::<f32>() * range + stats.min())
922 .collect();
923 }
924
925 let total_count: usize = histogram.iter().map(|&(_, c)| c).sum();
926 if total_count == 0 {
927 let mut rng = rand::thread_rng();
928 let range = stats.max() - stats.min();
929 if !range.is_finite() || range.abs() < 1e-8 {
930 return vec![stats.mean(); n];
931 }
932 return (0..n)
933 .map(|_| rng.gen::<f32>() * range + stats.min())
934 .collect();
935 }
936
937 let mut samples = Vec::with_capacity(n);
938 for &(value, count) in &histogram {
939 let num_samples = ((count as f64 / total_count as f64) * n as f64).round() as usize;
940 for _ in 0..num_samples {
941 samples.push(value);
942 }
943 }
944
945 samples.truncate(n);
947 while samples.len() < n {
948 samples.push(stats.mean());
949 }
950
951 samples
952}
953
954#[cfg(test)]
955mod tests {
956 use super::*;
957
958 #[test]
963 fn test_should_quantize_no_restrictions() {
964 let config = QuantConfig::default();
965 assert!(config.should_quantize("any.layer", 1));
966 assert!(config.should_quantize("any.layer", 1_000_000));
967 }
968
969 #[test]
970 fn test_should_quantize_excluded_layer() {
971 let config = QuantConfig {
972 excluded_layers: vec!["head.weight".to_string()],
973 ..Default::default()
974 };
975 assert!(!config.should_quantize("head.weight", 1024));
976 assert!(config.should_quantize("body.weight", 1024));
977 }
978
979 #[test]
980 fn test_should_quantize_min_elements() {
981 let config = QuantConfig {
982 min_elements: 512,
983 ..Default::default()
984 };
985 assert!(!config.should_quantize("small.bias", 4));
986 assert!(!config.should_quantize("small.bias", 511));
987 assert!(config.should_quantize("large.weight", 512));
988 assert!(config.should_quantize("large.weight", 1024));
989 }
990
991 #[test]
992 fn test_should_quantize_excluded_takes_priority_over_min_elements() {
993 let config = QuantConfig {
994 excluded_layers: vec!["head.weight".to_string()],
995 min_elements: 1,
996 ..Default::default()
997 };
998 assert!(!config.should_quantize("head.weight", 1_000_000));
1000 }
1001
1002 #[test]
1003 fn test_bits_for_layer_default() {
1004 let config = QuantConfig {
1005 bits: 8,
1006 ..Default::default()
1007 };
1008 assert_eq!(config.bits_for_layer("any.weight"), 8);
1009 }
1010
1011 #[test]
1012 fn test_bits_for_layer_override() {
1013 let mut layer_bits = std::collections::HashMap::new();
1014 layer_bits.insert("head.weight".to_string(), 4u8);
1015 let config = QuantConfig {
1016 bits: 8,
1017 layer_bits,
1018 ..Default::default()
1019 };
1020 assert_eq!(config.bits_for_layer("head.weight"), 4);
1021 assert_eq!(config.bits_for_layer("body.weight"), 8);
1022 }
1023
1024 #[test]
1029 fn test_quant_params() {
1030 let params = QuantParams::from_range(-1.0, 1.0);
1031
1032 assert_eq!(params.quantize(0.0), params.zero_point);
1033
1034 let original = 0.5;
1035 let quantized = params.quantize(original);
1036 let dequantized = params.dequantize(quantized);
1037
1038 assert!((original - dequantized).abs() < 0.01);
1039 }
1040
1041 #[test]
1042 fn test_quantize_tensor() {
1043 let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
1044 let shape = vec![5];
1045
1046 let quantized = QuantizedTensor::from_f32(&data, shape).unwrap();
1047
1048 assert_eq!(quantized.data.len(), 5);
1049 assert_eq!(quantized.size_bytes(), 5);
1050 }
1051
1052 #[test]
1053 fn test_per_channel_quantization() {
1054 let mut data = vec![];
1055 for _ in 0..100 {
1056 data.push(0.5); }
1058 for _ in 0..100 {
1059 data.push(5.0); }
1061
1062 let shape = vec![2, 100];
1063
1064 let quantized = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1065
1066 assert!(quantized.per_channel);
1067 assert!(quantized.channel_params.is_some());
1068 assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
1069
1070 let dequantized = quantized.to_f32();
1071 let error: f32 = data
1072 .iter()
1073 .zip(dequantized.iter())
1074 .map(|(a, b)| (a - b).powi(2))
1075 .sum::<f32>()
1076 / data.len() as f32;
1077
1078 println!("Per-channel MSE: {}", error);
1079 assert!(error < 0.1);
1080 }
1081
1082 #[test]
1083 fn test_per_channel_vs_per_tensor() {
1084 let mut data = vec![];
1085
1086 for _ in 0..1000 {
1087 data.push(0.01);
1088 }
1089
1090 for _ in 0..1000 {
1091 data.push(10.0);
1092 }
1093
1094 let shape = vec![2, 1000];
1095
1096 let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1098 let per_tensor_error = per_tensor.quantization_error(&data);
1099
1100 let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1102 let per_channel_error = per_channel.quantization_error(&data);
1103
1104 println!("Per-tensor error: {:.8}", per_tensor_error);
1105 println!("Per-channel error: {:.8}", per_channel_error);
1106
1107 assert!(per_channel_error < per_tensor_error);
1109 assert!(per_channel_error < per_tensor_error * 0.5);
1110 }
1111
1112 #[test]
1113 fn test_per_channel_benefit() {
1114 let mut data = vec![];
1115
1116 for i in 0..1000 {
1117 data.push(-0.1 + (i as f32 / 1000.0) * 0.2);
1118 }
1119
1120 for i in 0..1000 {
1121 data.push(-10.0 + (i as f32 / 1000.0) * 20.0);
1122 }
1123
1124 let shape = vec![2, 1000];
1125
1126 let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1127 let per_tensor_error = per_tensor.quantization_error(&data);
1128
1129 let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1130 let per_channel_error = per_channel.quantization_error(&data);
1131
1132 println!("Per-tensor MSE: {:.8}", per_tensor_error);
1133 println!("Per-channel MSE: {:.8}", per_channel_error);
1134
1135 assert!(
1136 per_channel_error < per_tensor_error,
1137 "Per-channel ({:.8}) should be better than per-tensor ({:.8})",
1138 per_channel_error,
1139 per_tensor_error
1140 );
1141 }
1142
1143 #[test]
1144 fn test_int4_quant_params() {
1145 let params = QuantParamsInt4::from_range(-1.0, 1.0);
1146
1147 assert!(params.quantize(-10.0) >= -8);
1148 assert!(params.quantize(-10.0) <= 7);
1149 assert!(params.quantize(10.0) >= -8);
1150 assert!(params.quantize(10.0) <= 7);
1151
1152 let zero_quant = params.quantize(0.0);
1153 assert!(zero_quant >= -8 && zero_quant <= 7);
1154
1155 for &original in &[-1.0, -0.5, 0.0, 0.5, 1.0] {
1156 let quantized = params.quantize(original);
1157 let dequantized = params.dequantize(quantized);
1158
1159 println!(
1160 "Original: {:.2}, Quantized: {}, Dequantized: {:.2}, Error: {:.4}",
1161 original,
1162 quantized,
1163 dequantized,
1164 (original - dequantized).abs()
1165 );
1166
1167 assert!((original - dequantized).abs() < params.scale * 2.0);
1168 }
1169 }
1170
1171 #[test]
1172 fn test_int4_extreme_values() {
1173 let params = QuantParamsInt4::from_range(-100.0, 100.0);
1175
1176 let q_neg = params.quantize(-100.0);
1177 let q_pos = params.quantize(100.0);
1178
1179 assert_eq!(q_neg, -8);
1180 assert_eq!(q_pos, 7);
1181 }
1182
1183 #[test]
1184 fn test_int4_vs_int8_error() {
1185 let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1186
1187 let params_int8 = QuantParams::from_range(-1.0, 1.0);
1188 let error_int8: f32 = data
1189 .iter()
1190 .map(|&v| {
1191 let q = params_int8.quantize(v);
1192 let dq = params_int8.dequantize(q);
1193 (v - dq).powi(2)
1194 })
1195 .sum::<f32>()
1196 / data.len() as f32;
1197
1198 let params_int4 = QuantParamsInt4::from_range(-1.0, 1.0);
1199 let error_int4: f32 = data
1200 .iter()
1201 .map(|&v| {
1202 let q = params_int4.quantize(v);
1203 let dq = params_int4.dequantize(q);
1204 (v - dq).powi(2)
1205 })
1206 .sum::<f32>()
1207 / data.len() as f32;
1208
1209 println!("INT8 MSE: {:.8}", error_int8);
1210 println!("INT4 MSE: {:.8}", error_int4);
1211
1212 assert!(error_int4 > error_int8);
1213
1214 assert!(
1215 error_int4 < error_int8 * 500.0,
1216 "INT4 error ({:.8}) is too high compared to INT8 ({:.8})",
1217 error_int4,
1218 error_int8
1219 );
1220
1221 assert!(error_int4.is_finite());
1222 assert!(error_int4 < 0.01);
1223 }
1224
1225 #[test]
1226 fn test_int4_range() {
1227 let params = QuantParamsInt4::from_range(-1.0, 1.0);
1228
1229 assert!(params.quantize(-10.0) == -8);
1230 assert!(params.quantize(10.0) == 7);
1231
1232 for i in -8..=7 {
1234 let value = i as f32 * params.scale;
1235 let quantized = params.quantize(value);
1236 assert!(quantized >= -8 && quantized <= 7);
1237 }
1238 }
1239
1240 #[test]
1241 fn test_int4_optimal_precision() {
1242 let params = QuantParamsInt4::from_range(-1.0, 1.0);
1243
1244 let mut unique_values = std::collections::HashSet::new();
1245
1246 for i in 0..1000 {
1248 let value = -1.0 + (i as f32 / 1000.0) * 2.0;
1249 unique_values.insert(params.quantize(value));
1250 }
1251
1252 println!("Unique quantized values: {}", unique_values.len());
1253 assert!(unique_values.len() >= 14);
1254 }
1255
1256 #[test]
1257 fn test_int4_tensor_quantization() {
1258 let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
1259 let shape = vec![5];
1260
1261 let quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1262
1263 assert_eq!(quantized.data.len(), 5);
1264 assert_eq!(quantized.size_bytes(), 5);
1265 assert_eq!(quantized.packed_size_bytes(), 3);
1266
1267 for &val in &quantized.data {
1268 assert!(val >= -8 && val <= 7, "Value {} out of INT4 range", val);
1269 }
1270 }
1271
1272 #[test]
1273 fn test_int4_round_trip() {
1274 let original = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1275 let shape = vec![5];
1276
1277 let quantized = QuantizedTensorInt4::from_f32(&original, shape).unwrap();
1278 let dequantized = quantized.to_f32();
1279
1280 println!("Original: {:?}", original);
1281 println!("Quantized: {:?}", quantized.data);
1282 println!("Dequantized: {:?}", dequantized);
1283
1284 for (orig, deq) in original.iter().zip(dequantized.iter()) {
1285 let error = (orig - deq).abs();
1286 println!(" {:.2} -> {:.2}, error: {:.4}", orig, deq, error);
1287 assert!(error < 0.15, "Error too large: {}", error);
1288 }
1289 }
1290
1291 #[test]
1292 fn test_int4_per_channel() {
1293 let mut data = vec![];
1294
1295 for i in 0..100 {
1297 data.push(-0.1 + (i as f32 / 100.0) * 0.2);
1298 }
1299
1300 for i in 0..100 {
1302 data.push(-10.0 + (i as f32 / 100.0) * 20.0);
1303 }
1304
1305 let shape = vec![2, 100];
1306
1307 let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1308
1309 assert!(quantized.per_channel);
1310 assert!(quantized.channel_params.is_some());
1311 assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
1312
1313 let error = quantized.quantization_error(&data);
1314 println!("INT4 per-channel MSE: {:.8}", error);
1315
1316 assert!(error < 1.0, "Error too high: {}", error);
1317 }
1318
1319 #[test]
1320 fn test_int4_vs_int8_compression() {
1321 let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
1322 let shape = vec![1000];
1323
1324 let int8_quantized = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1325 let int8_size = int8_quantized.size_bytes();
1326 let int8_error = int8_quantized.quantization_error(&data);
1327
1328 let int4_quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1329 let int4_size = int4_quantized.size_bytes();
1330 let int4_packed_size = int4_quantized.packed_size_bytes();
1331 let int4_error = int4_quantized.quantization_error(&data);
1332
1333 println!("INT8: {} bytes, MSE: {:.8}", int8_size, int8_error);
1334 println!(
1335 "INT4 (unpacked): {} bytes, MSE: {:.8}",
1336 int4_size, int4_error
1337 );
1338 println!(
1339 "INT4 (packed): {} bytes, MSE: {:.8}",
1340 int4_packed_size, int4_error
1341 );
1342
1343 assert_eq!(int4_size, int8_size);
1344
1345 assert!(int4_packed_size <= int8_size / 2 + 1);
1346
1347 assert!(int4_error > int8_error);
1348
1349 assert!(int4_error < 0.01, "INT4 error too high: {}", int4_error);
1350 }
1351
1352 #[test]
1353 fn test_int4_large_tensor() {
1354 let size = 64 * 3 * 3 * 3; let data: Vec<f32> = (0..size)
1356 .map(|i| ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5)
1357 .collect();
1358
1359 let shape = vec![64, 3, 3, 3];
1360
1361 let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1362
1363 assert_eq!(quantized.data.len(), size);
1364 assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 64);
1365
1366 let error = quantized.quantization_error(&data);
1367 println!("Large tensor INT4 error: {:.8}", error);
1368
1369 assert!(error < 0.01, "Error too high for large tensor: {}", error);
1370 }
1371
1372 #[test]
1373 fn test_int4_extreme_ranges() {
1374 let test_cases = vec![
1375 (vec![-0.001, 0.0, 0.001], "tiny range"),
1376 (vec![-100.0, 0.0, 100.0], "large range"),
1377 (vec![0.0, 0.0, 0.0], "all zeros"),
1378 (vec![1.0, 1.0, 1.0], "all same"),
1379 ];
1380
1381 for (data, desc) in test_cases {
1382 println!("\nTesting: {}", desc);
1383 let shape = vec![data.len()];
1384
1385 let result = QuantizedTensorInt4::from_f32(&data, shape);
1386 assert!(result.is_ok(), "Failed on {}", desc);
1387
1388 let quantized = result.unwrap();
1389 let dequantized = quantized.to_f32();
1390
1391 println!(" Original: {:?}", data);
1392 println!(" Dequantized: {:?}", dequantized);
1393
1394 for &val in &quantized.data {
1395 assert!(
1396 val >= -8 && val <= 7,
1397 "Value {} out of range for {}",
1398 val,
1399 desc
1400 );
1401 }
1402 }
1403 }
1404
1405 #[test]
1406 fn test_int4_pack_unpack_pair() {
1407 let test_cases = vec![
1408 (-8, 7),
1409 (-8, -8),
1410 (7, 7),
1411 (0, 0),
1412 (-1, 0),
1413 (0, -1),
1414 (-5, 3),
1415 (6, -4),
1416 ];
1417
1418 for (val1, val2) in test_cases {
1419 println!("\nTesting: ({}, {})", val1, val2);
1420
1421 let packed = pack_int4_pair(val1, val2);
1422 let (unpacked1, unpacked2) = unpack_int4_pair(packed);
1423
1424 println!(" Packed: 0x{:02X} (binary: {:08b})", packed, packed);
1425 println!(" Unpacked: ({}, {})", unpacked1, unpacked2);
1426
1427 assert_eq!(val1, unpacked1, "First value mismatch");
1428 assert_eq!(val2, unpacked2, "Second value mismatch");
1429 }
1430 }
1431
1432 #[test]
1433 fn test_int4_pack_unpack_vector() {
1434 let values = vec![-8, -7, -1, 0, 1, 7];
1435 let packed = pack_int4(&values);
1436 let unpacked = unpack_int4(&packed, values.len());
1437
1438 println!("\nEven length:");
1439 println!(" Original: {:?}", values);
1440 println!(" Packed: {:?} ({} bytes)", packed, packed.len());
1441 println!(" Unpacked: {:?}", unpacked);
1442
1443 assert_eq!(values, unpacked);
1444 assert_eq!(packed.len(), (values.len() + 1) / 2);
1445 }
1446
1447 #[test]
1448 fn test_int4_pack_unpack_odd_length() {
1449 let values = vec![-8, -5, 0, 5, 7];
1450 let packed = pack_int4(&values);
1451 let unpacked = unpack_int4(&packed, values.len());
1452
1453 println!("\nOdd length:");
1454 println!(" Original: {:?}", values);
1455 println!(" Packed: {:?} ({} bytes)", packed, packed.len());
1456 println!(" Unpacked: {:?}", unpacked);
1457
1458 assert_eq!(values, unpacked);
1459 assert_eq!(packed.len(), (values.len() + 1) / 2);
1460 }
1461
1462 #[test]
1463 fn test_int4_pack_all_values() {
1464 let values: Vec<i8> = (-8..=7).collect();
1465 let packed = pack_int4(&values);
1466 let unpacked = unpack_int4(&packed, values.len());
1467
1468 println!("\nAll INT4 values:");
1469 println!(" Original: {:?}", values);
1470 println!(" Packed: {} bytes", packed.len());
1471 println!(" Unpacked: {:?}", unpacked);
1472
1473 assert_eq!(values, unpacked);
1474 assert_eq!(packed.len(), 8);
1475 }
1476
1477 #[test]
1478 fn test_int4_pack_large_vector() {
1479 let values: Vec<i8> = (0..1000).map(|i| ((i % 16) - 8) as i8).collect();
1480 let packed = pack_int4(&values);
1481 let unpacked = unpack_int4(&packed, values.len());
1482
1483 assert_eq!(values, unpacked);
1484 assert_eq!(packed.len(), 500);
1485
1486 println!("\nLarge vector:");
1487 println!(" Original: {} values", values.len());
1488 println!(
1489 " Packed: {} bytes ({}x compression)",
1490 packed.len(),
1491 values.len() / packed.len()
1492 );
1493 println!(" Unpacked: {} values", unpacked.len());
1494 }
1495
1496 #[test]
1497 fn test_int4_compression_ratio() {
1498 let size = 10000;
1499 let values: Vec<i8> = (0..size).map(|i| ((i % 16) - 8) as i8).collect();
1500
1501 let unpacked_size = values.len() * std::mem::size_of::<i8>();
1502
1503 let packed = pack_int4(&values);
1504 let packed_size = packed.len();
1505
1506 let compression_ratio = unpacked_size as f32 / packed_size as f32;
1507
1508 println!("\nCompression test:");
1509 println!(" Values: {}", size);
1510 println!(" Unpacked: {} bytes", unpacked_size);
1511 println!(" Packed: {} bytes", packed_size);
1512 println!(" Compression: {:.2}x", compression_ratio);
1513
1514 assert!(
1515 (compression_ratio - 2.0).abs() < 0.01,
1516 "Expected ~2x compression, got {:.2}x",
1517 compression_ratio
1518 );
1519 }
1520
1521 #[test]
1522 fn test_int4_tensor_packing() {
1523 let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
1524 let shape = vec![1000];
1525
1526 let mut quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1527
1528 println!("Before packing:");
1529 println!(" Unpacked size: {} bytes", quantized.unpacked_size_bytes());
1530 println!(" Is packed: {}", quantized.is_packed());
1531
1532 assert!(!quantized.is_packed());
1533 assert_eq!(quantized.size_bytes(), 1000);
1534
1535 quantized.pack();
1536
1537 println!("\nAfter packing:");
1538 println!(" Packed size: {} bytes", quantized.size_bytes());
1539 println!(" Is packed: {}", quantized.is_packed());
1540 println!(
1541 " Compression: {}x",
1542 quantized.unpacked_size_bytes() / quantized.size_bytes()
1543 );
1544
1545 assert!(quantized.is_packed());
1546 assert_eq!(quantized.size_bytes(), 500);
1547
1548 let dequantized = quantized.to_f32();
1549 assert_eq!(dequantized.len(), 1000);
1550
1551 let error = quantized.quantization_error(&data);
1552 println!(" MSE after packing: {:.8}", error);
1553 assert!(error < 0.01);
1554 }
1555
1556 #[test]
1557 fn test_int4_packed_vs_unpacked_error() {
1558 let data: Vec<f32> = (0..100).map(|i| (i as f32 / 100.0) * 2.0 - 1.0).collect();
1559 let shape = vec![100];
1560
1561 let unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
1562 let error_unpacked = unpacked.quantization_error(&data);
1563
1564 let mut packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1565 packed.pack();
1566 let error_packed = packed.quantization_error(&data);
1567
1568 println!("Unpacked error: {:.8}", error_unpacked);
1569 println!("Packed error: {:.8}", error_packed);
1570
1571 assert!((error_unpacked - error_packed).abs() < 1e-6);
1572 }
1573
1574 #[test]
1575 fn test_int4_per_channel_packing() {
1576 let mut data = vec![];
1577 for i in 0..500 {
1578 data.push((i as f32 / 500.0) * 0.2 - 0.1); }
1580 for i in 0..500 {
1581 data.push((i as f32 / 500.0) * 20.0 - 10.0); }
1583
1584 let shape = vec![2, 500];
1585
1586 let mut quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1587
1588 let error_before = quantized.quantization_error(&data);
1589 println!("Error before packing: {:.8}", error_before);
1590
1591 quantized.pack();
1592
1593 let error_after = quantized.quantization_error(&data);
1594 println!("Error after packing: {:.8}", error_after);
1595 println!(
1596 "Size: {} bytes (packed from {} bytes)",
1597 quantized.size_bytes(),
1598 quantized.unpacked_size_bytes()
1599 );
1600
1601 assert!((error_before - error_after).abs() < 1e-6);
1602
1603 assert_eq!(quantized.size_bytes(), 500);
1604 }
1605
1606 #[test]
1607 fn test_int4_compression_comparison() {
1608 let size = 10000;
1609 let data: Vec<f32> = (0..size)
1610 .map(|i| ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5)
1611 .collect();
1612 let shape = vec![size];
1613
1614 let fp32_size = size * std::mem::size_of::<f32>();
1615
1616 let int8 = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1617 let int8_size = int8.size_bytes();
1618
1619 let int4_unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
1620 let int4_unpacked_size = int4_unpacked.size_bytes();
1621
1622 let mut int4_packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1623 int4_packed.pack();
1624 let int4_packed_size = int4_packed.size_bytes();
1625
1626 println!("\nCompression Comparison:");
1627 println!(" FP32: {} bytes", fp32_size);
1628 println!(
1629 " INT8: {} bytes ({:.1}x)",
1630 int8_size,
1631 fp32_size as f32 / int8_size as f32
1632 );
1633 println!(
1634 " INT4 unpacked: {} bytes ({:.1}x)",
1635 int4_unpacked_size,
1636 fp32_size as f32 / int4_unpacked_size as f32
1637 );
1638 println!(
1639 " INT4 packed: {} bytes ({:.1}x)",
1640 int4_packed_size,
1641 fp32_size as f32 / int4_packed_size as f32
1642 );
1643
1644 assert_eq!(fp32_size / int8_size, 4); assert_eq!(fp32_size / int4_packed_size, 8); }
1647
    // End-to-end smoke test against a real ONNX model on disk. Ignored by
    // default: it needs mnist.onnx or resnet18-v1-7.onnx present locally and
    // exits early (without failing) when no model file is found.
    #[test]
    #[ignore] fn test_int4_real_model() {
        use crate::onnx_utils::OnnxModel;

        println!("\n{}", "=".repeat(60));
        println!("INT4 Real Model Test");
        println!("\n{}", "=".repeat(60));

        // Candidate model locations, checked in order; the first loadable wins.
        let model_paths = vec![
            "test_models/mnist.onnx",
            "mnist.onnx",
            "test_models/resnet18-v1-7.onnx",
            "resnet18-v1-7.onnx",
        ];

        let mut model = None;
        for path in &model_paths {
            if std::path::Path::new(path).exists() {
                println!("Loading model: {}", path);
                match OnnxModel::load(path) {
                    Ok(m) => {
                        model = Some(m);
                        break;
                    }
                    // A file that exists but fails to parse is reported, and
                    // the search continues with the next candidate.
                    Err(e) => println!("  Failed: {}", e),
                }
            }
        }

        // No model available: treat as a soft skip, not a test failure.
        let model = match model {
            Some(m) => m,
            None => {
                println!("No test models found. Skipping test.");
                println!("Place mnist.onnx or resnet18-v1-7.onnx in current directory.");
                return;
            }
        };

        let info = model.info();
        println!("✓ Model loaded: {}", info.name);
        println!("  Nodes: {}", info.num_nodes);
        println!();

        println!("Extracting weights...");
        let weights = model.extract_weights();
        println!("✓ Found {} weight tensors", weights.len());

        if weights.is_empty() {
            println!("No weights to quantize!");
            return;
        }

        println!();
        println!("\n{}", "=".repeat(60));
        println!("Testing Per-Tensor Quantization");
        println!("\n{}", "=".repeat(60));

        // Restrict to a handful of non-trivial layers to keep the run short.
        let test_weights: Vec<_> = weights
            .iter()
            .filter(|w| w.data.len() > 1000)
            .take(5)
            .collect();

        println!("Testing {} large layers:\n", test_weights.len());

        for (idx, weight) in test_weights.iter().enumerate() {
            // Truncate long layer names for display.
            // NOTE(review): `&weight.name[..37]` slices by byte index and would
            // panic on a non-char-boundary; fine for ASCII ONNX names — confirm.
            let name = if weight.name.len() > 40 {
                format!("{}...", &weight.name[..37])
            } else {
                weight.name.clone()
            };

            println!("[{}] {}", idx + 1, name);
            println!(
                "  Shape: {:?}, Elements: {}",
                weight.shape,
                weight.data.len()
            );

            // 4 bytes per f32 element.
            let fp32_size = weight.data.len() * 4;

            // INT8 baseline; a failing layer is reported and skipped.
            let int8_result = QuantizedTensor::from_f32(&weight.data, weight.shape.clone());
            let (int8_size, int8_error) = if let Ok(q) = int8_result {
                (q.size_bytes(), q.quantization_error(&weight.data))
            } else {
                println!("  INT8 failed!");
                continue;
            };

            // INT4 unpacked (one byte per nibble).
            let int4_result = QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone());
            let (int4_unpacked_size, int4_error) = if let Ok(q) = int4_result {
                (q.size_bytes(), q.quantization_error(&weight.data))
            } else {
                println!("  INT4 failed!");
                continue;
            };

            // INT4 with bit packing applied; unwrap is safe here because the
            // identical conversion just succeeded above.
            let mut int4_packed =
                QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
            int4_packed.pack();
            let int4_packed_size = int4_packed.size_bytes();
            let int4_packed_error = int4_packed.quantization_error(&weight.data);

            println!("  FP32: {:7} bytes", fp32_size);
            println!(
                "  INT8: {:7} bytes ({:.1}x)  MSE: {:.8}",
                int8_size,
                fp32_size as f32 / int8_size as f32,
                int8_error
            );
            println!(
                "  INT4 unpacked: {:7} bytes ({:.1}x)  MSE: {:.8}",
                int4_unpacked_size,
                fp32_size as f32 / int4_unpacked_size as f32,
                int4_error
            );
            println!(
                "  INT4 packed: {:7} bytes ({:.1}x)  MSE: {:.8}",
                int4_packed_size,
                fp32_size as f32 / int4_packed_size as f32,
                int4_packed_error
            );

            // Packing must be lossless, so the errors are exactly equal.
            assert_eq!(int4_error, int4_packed_error, "Packing changed error!");

            let int8_ratio = fp32_size as f32 / int8_size as f32;
            let int4_ratio = fp32_size as f32 / int4_packed_size as f32;

            assert!(
                (int8_ratio - 4.0).abs() < 0.1,
                "INT8 compression should be ~4x"
            );
            assert!(
                (int4_ratio - 8.0).abs() < 0.1,
                "INT4 compression should be ~8x"
            );

            println!();
        }

        println!("\n{}", "=".repeat(60));
        println!("Testing Per-Channel Quantization");
        println!("\n{}", "=".repeat(60));

        // Layers with a leading multi-channel axis are candidates for
        // per-channel quantization.
        let conv_weights: Vec<_> = weights
            .iter()
            .filter(|w| w.shape.len() >= 2 && w.shape[0] > 1)
            .take(3)
            .collect();

        if conv_weights.is_empty() {
            println!("No multi-channel layers found for per-channel test.");
        } else {
            println!("Testing {} conv layers:\n", conv_weights.len());

            for (idx, weight) in conv_weights.iter().enumerate() {
                // Same byte-index truncation as above — see note there.
                let name = if weight.name.len() > 40 {
                    format!("{}...", &weight.name[..37])
                } else {
                    weight.name.clone()
                };

                println!("[{}] {}", idx + 1, name);
                println!(
                    "  Shape: {:?}, Channels: {}",
                    weight.shape, weight.shape[0]
                );

                // Per-tensor reference error for comparison.
                let per_tensor =
                    QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
                let per_tensor_error = per_tensor.quantization_error(&weight.data);

                let per_channel_result =
                    QuantizedTensorInt4::from_f32_per_channel(&weight.data, weight.shape.clone());

                if let Ok(per_channel) = per_channel_result {
                    let per_channel_error = per_channel.quantization_error(&weight.data);

                    // Relative improvement in percent (negative means worse).
                    let improvement =
                        ((per_tensor_error - per_channel_error) / per_tensor_error) * 100.0;

                    println!("  Per-tensor:  MSE: {:.8}", per_tensor_error);
                    println!(
                        "  Per-channel: MSE: {:.8} ({:.1}% better)",
                        per_channel_error, improvement
                    );

                    // Allow a 10% tolerance: per-channel is usually better but
                    // must never be dramatically worse.
                    assert!(
                        per_channel_error <= per_tensor_error * 1.1,
                        "Per-channel should not be significantly worse"
                    );
                } else {
                    println!("  Per-channel failed!");
                }

                println!();
            }
        }

        println!("\n{}", "=".repeat(60));
        println!("Summary");
        println!("\n{}", "=".repeat(60));

        println!("✓ INT4 quantization works on real model weights");
        println!("✓ Compression ratios correct (4x INT8, 8x INT4)");
        println!("✓ Bit packing is lossless");
        println!("✓ Per-channel quantization works");
        println!("\nINT4 implementation is ready for CLI integration!");
    }
1859
1860 #[test]
1865 fn test_all_nan_returns_error() {
1866 let data = vec![f32::NAN, f32::NAN, f32::NAN];
1867 let result = QuantizedTensor::from_f32(&data, vec![3]);
1868 assert!(result.is_err());
1869 let err = result.unwrap_err().to_string();
1870 assert!(
1871 err.contains("non-finite"),
1872 "error should mention non-finite: {}",
1873 err
1874 );
1875 }
1876
1877 #[test]
1878 fn test_all_inf_returns_error() {
1879 let data = vec![f32::INFINITY, f32::NEG_INFINITY];
1880 let result = QuantizedTensor::from_f32(&data, vec![2]);
1881 assert!(result.is_err());
1882 }
1883
1884 #[test]
1885 fn test_all_nan_int4_returns_error() {
1886 let data = vec![f32::NAN; 4];
1887 let result = QuantizedTensorInt4::from_f32(&data, vec![4]);
1888 assert!(result.is_err());
1889 }
1890
1891 #[test]
1892 fn test_all_nan_per_channel_returns_error() {
1893 let data = vec![f32::NAN; 6];
1894 let result = QuantizedTensor::from_f32_per_channel(&data, vec![2, 3]);
1895 assert!(result.is_err());
1896 let err = result.unwrap_err().to_string();
1897 assert!(
1898 err.contains("Channel 0"),
1899 "error should mention channel: {}",
1900 err
1901 );
1902 }
1903
1904 #[test]
1905 fn test_mixed_nan_finite_succeeds() {
1906 let data = vec![f32::NAN, 1.0, -1.0, f32::NAN];
1908 let result = QuantizedTensor::from_f32(&data, vec![4]);
1909 assert!(result.is_ok());
1910 }
1911}