1use crate::errors::{QuantizeError, Result};
8
/// Configuration controlling how tensors are quantized.
#[derive(Debug, Clone)]
pub struct QuantConfig {
    /// Default bit width for layers without a per-layer override (4 or 8).
    pub bits: u8,
    /// Quantize each channel (axis 0) with its own scale/zero-point.
    pub per_channel: bool,
    /// Optional method for deriving a clipped range from calibration stats.
    pub calibration_method: Option<crate::calibration::methods::CalibrationMethod>,
    /// Layer names that are never quantized.
    pub excluded_layers: Vec<String>,
    /// Per-layer bit-width overrides, keyed by layer name.
    pub layer_bits: std::collections::HashMap<String, u8>,
    /// Tensors with fewer elements than this are skipped (0 disables the floor).
    pub min_elements: usize,
}
26
27impl Default for QuantConfig {
28 fn default() -> Self {
29 Self {
30 bits: 8,
31 per_channel: false,
32 calibration_method: None,
33 excluded_layers: Vec::new(),
34 layer_bits: std::collections::HashMap::new(),
35 min_elements: 0,
36 }
37 }
38}
39
40impl QuantConfig {
41 pub fn int8() -> Self {
43 Self::default()
44 }
45
46 pub fn with_per_channel(mut self, enabled: bool) -> Self {
48 self.per_channel = enabled;
49 self
50 }
51
52 pub fn with_calibration(mut self, method: crate::calibration::methods::CalibrationMethod) -> Self {
54 self.calibration_method = Some(method);
55 self
56 }
57
58 pub fn should_quantize(&self, name: &str, num_elements: usize) -> bool {
64 if self.excluded_layers.iter().any(|e| e == name) {
65 return false;
66 }
67 if self.min_elements > 0 && num_elements < self.min_elements {
68 return false;
69 }
70 true
71 }
72
73 pub fn bits_for_layer(&self, name: &str) -> u8 {
78 self.layer_bits.get(name).copied().unwrap_or(self.bits)
79 }
80}
81
/// Compile-time description of a quantized integer range (codes stored in `i8`).
pub trait QuantRange: Clone + std::fmt::Debug + Send + Sync + 'static {
    /// Smallest representable quantized code, as `f32` for range math.
    const QMIN: f32;
    /// Largest representable quantized code, as `f32` for range math.
    const QMAX: f32;
    /// Bit width of the quantized representation.
    const BITS: u8;
}
95
/// Signed 8-bit quantization range: codes in [-128, 127].
#[derive(Debug, Clone)]
pub struct Int8Range;
impl QuantRange for Int8Range {
    const QMIN: f32 = -128.0;
    const QMAX: f32 = 127.0;
    const BITS: u8 = 8;
}
104
/// Signed 4-bit quantization range: codes in [-8, 7] (stored in `i8`, packable
/// two-per-byte via `pack_int4`).
#[derive(Debug, Clone)]
pub struct Int4Range;
impl QuantRange for Int4Range {
    const QMIN: f32 = -8.0;
    const QMAX: f32 = 7.0;
    const BITS: u8 = 4;
}
113
/// Affine (scale + zero-point) quantization parameters for range `R`.
#[derive(Debug, Clone)]
pub struct QuantParamsGeneric<R: QuantRange> {
    // real_value ≈ (code - zero_point) * scale
    scale: f32,
    zero_point: i8,
    // Ties the parameters to a specific QuantRange without storing it.
    _marker: std::marker::PhantomData<R>,
}
129
/// INT8 affine quantization parameters.
pub type QuantParams = QuantParamsGeneric<Int8Range>;
/// INT4 affine quantization parameters.
pub type QuantParamsInt4 = QuantParamsGeneric<Int4Range>;
134
135impl<R: QuantRange> QuantParamsGeneric<R> {
136 pub fn scale(&self) -> f32 { self.scale }
138 pub fn zero_point(&self) -> i8 { self.zero_point }
140
141 pub fn from_range(min: f32, max: f32) -> Self {
143 let min = min.min(0.0);
144 let max = max.max(0.0);
145
146 let (min, max) = if (max - min).abs() < 1e-8 {
148 (min - 0.01, max + 0.01)
149 } else {
150 (min, max)
151 };
152
153 let scale = (max - min) / (R::QMAX - R::QMIN);
154 let scale = scale.max(1e-8);
155
156 let initial_zero_point = R::QMIN - min / scale;
157 let zero_point = initial_zero_point.round().clamp(R::QMIN, R::QMAX) as i8;
158
159 QuantParamsGeneric {
160 scale,
161 zero_point,
162 _marker: std::marker::PhantomData,
163 }
164 }
165
166 pub fn quantize(&self, value: f32) -> i8 {
168 if !value.is_finite() {
169 return self.zero_point;
170 }
171 let quantized = (value / self.scale).round() + (self.zero_point as f32);
172 quantized.clamp(R::QMIN, R::QMAX) as i8
173 }
174
175 pub fn dequantize(&self, value: i8) -> f32 {
177 ((value as i32) - (self.zero_point as i32)) as f32 * self.scale
178 }
179}
180
/// A quantized tensor: `i8` codes plus the parameters to dequantize them.
#[derive(Debug, Clone)]
pub struct QuantizedTensorGeneric<R: QuantRange> {
    // Unpacked codes, one per element (retained even after INT4 packing —
    // its length is used as the element count when unpacking).
    pub(crate) data: Vec<i8>,
    // Nibble-packed representation (INT4 only), two codes per byte.
    pub(crate) packed_data: Option<Vec<u8>>,
    // Row-major tensor shape; axis 0 is the channel axis in per-channel mode.
    pub(crate) shape: Vec<usize>,
    // Per-tensor parameters (channel 0's parameters in per-channel mode).
    pub(crate) params: QuantParamsGeneric<R>,
    // True when `channel_params` holds one entry per channel.
    pub(crate) per_channel: bool,
    // Per-channel parameters, present only in per-channel mode.
    pub(crate) channel_params: Option<Vec<QuantParamsGeneric<R>>>,
}
199
/// INT8 quantized tensor.
pub type QuantizedTensor = QuantizedTensorGeneric<Int8Range>;

/// INT4 quantized tensor (optionally nibble-packed via `pack`).
pub type QuantizedTensorInt4 = QuantizedTensorGeneric<Int4Range>;
208
209impl<R: QuantRange> QuantizedTensorGeneric<R> {
214 pub fn shape(&self) -> &[usize] { &self.shape }
216 pub fn params(&self) -> &QuantParamsGeneric<R> { &self.params }
218 pub fn is_per_channel(&self) -> bool { self.per_channel }
220
221 pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Result<Self> {
227 if data.is_empty() {
228 return Err(QuantizeError::InvalidTensor { reason: "Cannot quantize empty tensor".into() });
229 }
230
231 let expected_len: usize = shape.iter().product();
232 if expected_len != data.len() {
233 return Err(QuantizeError::InvalidTensor { reason: format!("Shape {:?} expects {} elements but got {}", shape, expected_len, data.len()) });
234 }
235
236 let min = data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
237 let max = data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
238
239 if !min.is_finite() || !max.is_finite() {
240 return Err(QuantizeError::InvalidTensor { reason: "Tensor contains only non-finite values (NaN/Inf)".into() });
241 }
242
243 let params = QuantParamsGeneric::<R>::from_range(min, max);
244
245 let quantized_data: Vec<i8> = data.iter()
246 .map(|&v| params.quantize(v))
247 .collect();
248
249 Ok(QuantizedTensorGeneric {
250 data: quantized_data,
251 packed_data: None,
252 shape,
253 params,
254 per_channel: false,
255 channel_params: None,
256 })
257 }
258
259 pub fn from_f32_with_range(data: &[f32], shape: Vec<usize>, min: f32, max: f32) -> Result<Self> {
265 if data.is_empty() {
266 return Err(QuantizeError::InvalidTensor { reason: "Cannot quantize empty tensor".into() });
267 }
268
269 let expected_len: usize = shape.iter().product();
270 if expected_len != data.len() {
271 return Err(QuantizeError::InvalidTensor { reason: format!("Shape {:?} expects {} elements but got {}", shape, expected_len, data.len()) });
272 }
273
274 let params = QuantParamsGeneric::<R>::from_range(min, max);
275
276 let quantized_data: Vec<i8> = data.iter()
277 .map(|&v| params.quantize(v))
278 .collect();
279
280 Ok(QuantizedTensorGeneric {
281 data: quantized_data,
282 packed_data: None,
283 shape,
284 params,
285 per_channel: false,
286 channel_params: None,
287 })
288 }
289
290 pub fn from_f32_per_channel(
297 data: &[f32],
298 shape: Vec<usize>,
299 ) -> Result<Self> {
300 if data.is_empty() {
301 return Err(QuantizeError::InvalidTensor { reason: "Cannot quantize empty tensor".into() });
302 }
303
304 if shape.is_empty() {
305 return Err(QuantizeError::InvalidTensor { reason: "Cannot do per-channel quantization on scalar".into() });
306 }
307
308 let expected_len: usize = shape.iter().product();
309 if expected_len != data.len() {
310 return Err(QuantizeError::InvalidTensor { reason: format!("Shape {:?} expects {} elements but got {}", shape, expected_len, data.len()) });
311 }
312
313 let num_channels = shape[0];
314
315 let mut channel_params = Vec::new();
316 let mut quantized_data = Vec::with_capacity(data.len());
317
318 for channel_idx in 0..num_channels {
319 let channel_data = extract_channel(data, &shape, channel_idx)?;
320
321 let min = channel_data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
322 let max = channel_data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
323
324 if !min.is_finite() || !max.is_finite() {
325 return Err(QuantizeError::InvalidTensor {
326 reason: format!("Channel {} contains only non-finite values (NaN/Inf)", channel_idx),
327 });
328 }
329
330 let params = QuantParamsGeneric::<R>::from_range(min, max);
331 channel_params.push(params.clone());
332
333 for &value in &channel_data {
334 quantized_data.push(params.quantize(value));
335 }
336 }
337
338 let params = channel_params[0].clone();
340
341 Ok(QuantizedTensorGeneric {
342 data: quantized_data,
343 packed_data: None,
344 shape,
345 params,
346 per_channel: true,
347 channel_params: Some(channel_params),
348 })
349 }
350
351 pub fn to_f32(&self) -> Vec<f32> {
353 let data_owned;
355 let data: &[i8] = if let Some(ref packed) = self.packed_data {
356 data_owned = unpack_int4(packed, self.data.len());
357 &data_owned
358 } else {
359 &self.data
360 };
361
362 if self.per_channel {
363 if let Some(ref channel_params) = self.channel_params {
364 if channel_params.is_empty() {
365 return data.iter().map(|&v| self.params.dequantize(v)).collect();
366 }
367 let elements_per_channel = data.len() / channel_params.len();
368 data.iter()
369 .enumerate()
370 .map(|(i, &v)| {
371 let channel_idx = (i / elements_per_channel).min(channel_params.len() - 1);
372 channel_params[channel_idx].dequantize(v)
373 })
374 .collect()
375 } else {
376 data.iter().map(|&v| self.params.dequantize(v)).collect()
377 }
378 } else {
379 data.iter().map(|&v| self.params.dequantize(v)).collect()
380 }
381 }
382
383 pub fn size_bytes(&self) -> usize {
385 if let Some(ref packed) = self.packed_data {
386 packed.len()
387 } else {
388 self.data.len() * std::mem::size_of::<i8>()
389 }
390 }
391
392 pub fn quantization_error(&self, original: &[f32]) -> f32 {
394 if original.is_empty() {
395 return 0.0;
396 }
397
398 let dequantized = self.to_f32();
399
400 let sum: f32 = original.iter()
401 .zip(dequantized.iter())
402 .map(|(a, b)| (a - b).powi(2))
403 .sum();
404
405 sum / original.len() as f32
406 }
407}
408
409impl QuantizedTensorGeneric<Int4Range> {
414 pub fn pack(&mut self) {
416 self.packed_data = Some(pack_int4(&self.data));
417 }
418
419 pub fn ensure_unpacked(&self) -> Vec<i8> {
421 if let Some(ref packed) = self.packed_data {
422 unpack_int4(packed, self.data.len())
423 } else {
424 self.data.clone()
425 }
426 }
427
428 pub fn is_packed(&self) -> bool {
430 self.packed_data.is_some()
431 }
432
433 pub fn packed_size_bytes(&self) -> usize {
435 if let Some(ref packed) = self.packed_data {
436 packed.len()
437 } else {
438 self.data.len().div_ceil(2)
439 }
440 }
441
442 pub fn unpacked_size_bytes(&self) -> usize {
444 self.data.len() * std::mem::size_of::<i8>()
445 }
446}
447
/// Packs two INT4 values into one byte: `val1` in the high nibble, `val2` low.
fn pack_int4_pair(val1: i8, val2: i8) -> u8 {
    debug_assert!((-8..=7).contains(&val1), "val1 out of INT4 range: {}", val1);
    debug_assert!((-8..=7).contains(&val2), "val2 out of INT4 range: {}", val2);

    // Masking to the low nibble discards the sign-extension bits of negatives.
    let hi = ((val1 as u8) & 0x0F) << 4;
    let lo = (val2 as u8) & 0x0F;
    hi | lo
}
463
/// Splits a packed byte into two sign-extended INT4 values (high nibble first).
fn unpack_int4_pair(byte: u8) -> (i8, i8) {
    // Arithmetic right shift on i8 sign-extends the nibble into a full i8:
    // reinterpret so the nibble sits in the top 4 bits, then shift back down.
    let first = (byte as i8) >> 4;
    let second = ((byte << 4) as i8) >> 4;
    (first, second)
}
483
484pub fn pack_int4(values: &[i8]) -> Vec<u8> {
486 let mut packed = Vec::with_capacity(values.len().div_ceil(2));
487
488 for chunk in values.chunks(2) {
489 let val1 = chunk[0];
490 let val2 = if chunk.len() > 1 { chunk[1] } else { 0 };
491
492 packed.push(pack_int4_pair(val1, val2));
493 }
494
495 packed
496}
497
498pub fn unpack_int4(packed: &[u8], num_values: usize) -> Vec<i8> {
500 let mut values = Vec::with_capacity(num_values);
501
502 for &byte in packed {
503 let (val1, val2) = unpack_int4_pair(byte);
504 values.push(val1);
505 if values.len() < num_values {
506 values.push(val2);
507 }
508 }
509
510 values.truncate(num_values);
512 values
513}
514
515fn extract_channel(data: &[f32], shape: &[usize], channel_idx: usize) -> Result<Vec<f32>> {
520 if shape.is_empty() {
521 return Err(QuantizeError::InvalidTensor { reason: "Cannot extract channel from empty shape".into() });
522 }
523 let num_channels = shape[0];
524 if num_channels == 0 {
525 return Err(QuantizeError::InvalidTensor { reason: "Number of channels is 0".into() });
526 }
527 if channel_idx >= num_channels {
528 return Err(QuantizeError::InvalidTensor { reason: format!("Channel index {} out of bounds for {} channels", channel_idx, num_channels) });
529 }
530 if data.len() % num_channels != 0 {
531 return Err(QuantizeError::InvalidTensor { reason: format!("Data length {} not evenly divisible by {} channels", data.len(), num_channels) });
532 }
533 let elements_per_channel = data.len() / num_channels;
534 let start = channel_idx * elements_per_channel;
535 let end = start + elements_per_channel;
536 Ok(data[start..end].to_vec())
537}
538
/// Runtime-tagged quantized tensor: either INT8 or INT4.
#[derive(Debug, Clone)]
pub enum QuantizedTensorType {
    /// 8-bit quantized tensor.
    Int8(QuantizedTensor),
    /// 4-bit quantized tensor (nibble-packed when built via `Quantizer`).
    Int4(QuantizedTensorInt4),
}
549
550impl QuantizedTensorType {
551 pub fn to_f32(&self) -> Vec<f32> {
553 match self {
554 QuantizedTensorType::Int8(t) => t.to_f32(),
555 QuantizedTensorType::Int4(t) => t.to_f32(),
556 }
557 }
558
559 pub fn size_bytes(&self) -> usize {
561 match self {
562 QuantizedTensorType::Int8(t) => t.size_bytes(),
563 QuantizedTensorType::Int4(t) => t.size_bytes(),
564 }
565 }
566
567 #[must_use]
568 pub fn quantization_error(&self, original: &[f32]) -> f32 {
569 match self {
570 QuantizedTensorType::Int8(t) => t.quantization_error(original),
571 QuantizedTensorType::Int4(t) => t.quantization_error(original),
572 }
573 }
574
575 #[must_use]
576 pub fn data(&self) -> Vec<i8> {
577 match self {
578 QuantizedTensorType::Int8(t) => t.data.clone(),
579 QuantizedTensorType::Int4(t) => t.ensure_unpacked(),
580 }
581 }
582
583 pub fn get_scale_zero_point(&self) -> (f32, i8) {
585 match self {
586 QuantizedTensorType::Int8(t) => (t.params.scale, t.params.zero_point),
587 QuantizedTensorType::Int4(t) => (t.params.scale, t.params.zero_point),
588 }
589 }
590
591 pub fn get_all_scales_zero_points(&self) -> (Vec<f32>, Vec<i8>) {
596 match self {
597 QuantizedTensorType::Int8(t) => {
598 if let Some(ref cp) = t.channel_params {
599 (
600 cp.iter().map(|p| p.scale).collect(),
601 cp.iter().map(|p| p.zero_point).collect(),
602 )
603 } else {
604 (vec![t.params.scale], vec![t.params.zero_point])
605 }
606 }
607 QuantizedTensorType::Int4(t) => {
608 if let Some(ref cp) = t.channel_params {
609 (
610 cp.iter().map(|p| p.scale).collect(),
611 cp.iter().map(|p| p.zero_point).collect(),
612 )
613 } else {
614 (vec![t.params.scale], vec![t.params.zero_point])
615 }
616 }
617 }
618 }
619
620 pub fn is_per_channel(&self) -> bool {
622 match self {
623 QuantizedTensorType::Int8(t) => t.per_channel,
624 QuantizedTensorType::Int4(t) => t.per_channel,
625 }
626 }
627
628 #[must_use]
629 pub fn bits(&self) -> u8 {
630 match self {
631 QuantizedTensorType::Int8(_) => 8,
632 QuantizedTensorType::Int4(_) => 4,
633 }
634 }
635
636 pub fn is_int8(&self) -> bool {
638 matches!(self, QuantizedTensorType::Int8(_))
639 }
640
641 pub fn is_int4(&self) -> bool {
643 matches!(self, QuantizedTensorType::Int4(_))
644 }
645
646 pub fn data_ref(&self) -> Option<&[i8]> {
650 match self {
651 QuantizedTensorType::Int8(t) => Some(&t.data),
652 QuantizedTensorType::Int4(t) => {
653 if t.packed_data.is_some() {
654 None } else {
656 Some(&t.data)
657 }
658 }
659 }
660 }
661}
662
/// High-level entry point: applies a `QuantConfig` (plus optional calibration
/// statistics) to raw `f32` tensors.
pub struct Quantizer {
    // Quantization settings (bit width, per-channel, calibration method, ...).
    config: QuantConfig,
    // Calibrated activation statistics keyed by tensor name, if collected.
    calibration_stats: Option<std::collections::HashMap<String, crate::calibration::stats::ActivationStats>>,
}
672
673impl std::fmt::Debug for Quantizer {
674 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
675 let stats_count = self.calibration_stats.as_ref().map(|m| m.len());
676 f.debug_struct("Quantizer")
677 .field("config", &self.config)
678 .field("calibration_stats_count", &stats_count)
679 .finish()
680 }
681}
682
683impl Quantizer {
684 pub fn new(config: QuantConfig) -> Self {
686 Self {
687 config,
688 calibration_stats: None,
689 }
690 }
691
692 pub fn with_calibration(
694 config: QuantConfig,
695 stats: std::collections::HashMap<String, crate::calibration::stats::ActivationStats>,
696 ) -> Self {
697 Self {
698 config,
699 calibration_stats: Some(stats),
700 }
701 }
702
703 pub fn quantize_tensor_with_name(
705 &self,
706 name: &str,
707 data: &[f32],
708 shape: Vec<usize>,
709 ) -> Result<QuantizedTensorType> {
710 let (min, max) = if let Some(ref stats_map) = self.calibration_stats {
711 if let Some(stats) = stats_map.get(name) {
712 if let Some(method) = self.config.calibration_method {
713 use crate::calibration::stats::calculate_optimal_range;
714
715 let sample_data = sample_from_activation_stats(stats, 1000);
716 calculate_optimal_range(&sample_data, method)
717 } else {
718 (stats.min(), stats.max())
719 }
720 } else {
721 let min = data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
723 let max = data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
724 if !min.is_finite() || !max.is_finite() {
725 return Err(QuantizeError::InvalidTensor {
726 reason: format!("Tensor '{}' contains only non-finite values (NaN/Inf)", name),
727 });
728 }
729 (min, max)
730 }
731 } else {
732 let min = data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
734 let max = data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
735 if !min.is_finite() || !max.is_finite() {
736 return Err(QuantizeError::InvalidTensor {
737 reason: format!("Tensor '{}' contains only non-finite values (NaN/Inf)", name),
738 });
739 }
740 (min, max)
741 };
742
743 self.quantize_with_range(data, shape, min, max)
744 }
745
746 pub fn quantize_tensor(&self, data: &[f32], shape: Vec<usize>) -> Result<QuantizedTensorType> {
752 self.build_tensor_with_optional_range(data, shape, None)
753 }
754
755 fn quantize_with_range(
762 &self,
763 data: &[f32],
764 shape: Vec<usize>,
765 min: f32,
766 max: f32,
767 ) -> Result<QuantizedTensorType> {
768 self.build_tensor_with_optional_range(data, shape, Some((min, max)))
769 }
770
771 fn build_tensor_with_optional_range(
773 &self,
774 data: &[f32],
775 shape: Vec<usize>,
776 range: Option<(f32, f32)>,
777 ) -> Result<QuantizedTensorType> {
778 let pc = self.config.per_channel && shape.len() >= 2;
779 match self.config.bits {
780 8 => {
781 let t = match (pc, range) {
782 (true, _) => QuantizedTensor::from_f32_per_channel(data, shape)?,
783 (false, Some((min, max))) => QuantizedTensor::from_f32_with_range(data, shape, min, max)?,
784 (false, None) => QuantizedTensor::from_f32(data, shape)?,
785 };
786 Ok(QuantizedTensorType::Int8(t))
787 }
788 4 => {
789 let mut t = match (pc, range) {
790 (true, _) => QuantizedTensorInt4::from_f32_per_channel(data, shape)?,
791 (false, Some((min, max))) => QuantizedTensorInt4::from_f32_with_range(data, shape, min, max)?,
792 (false, None) => QuantizedTensorInt4::from_f32(data, shape)?,
793 };
794 t.pack();
795 Ok(QuantizedTensorType::Int4(t))
796 }
797 b => Err(QuantizeError::UnsupportedConfig {
798 reason: format!("bits must be 4 or 8, got {b}"),
799 }),
800 }
801 }
802}
803
804fn sample_from_activation_stats(stats: &crate::calibration::stats::ActivationStats, n: usize) -> Vec<f32> {
810 use rand::Rng;
811
812 let histogram = stats.histogram_data();
813 if histogram.is_empty() {
814 let mut rng = rand::thread_rng();
816 let range = stats.max() - stats.min();
817 if !range.is_finite() || range.abs() < 1e-8 {
818 return vec![stats.mean(); n];
819 }
820 return (0..n).map(|_| rng.gen::<f32>() * range + stats.min()).collect();
821 }
822
823 let total_count: usize = histogram.iter().map(|&(_, c)| c).sum();
824 if total_count == 0 {
825 let mut rng = rand::thread_rng();
826 let range = stats.max() - stats.min();
827 if !range.is_finite() || range.abs() < 1e-8 {
828 return vec![stats.mean(); n];
829 }
830 return (0..n).map(|_| rng.gen::<f32>() * range + stats.min()).collect();
831 }
832
833 let mut samples = Vec::with_capacity(n);
834 for &(value, count) in &histogram {
835 let num_samples = ((count as f64 / total_count as f64) * n as f64).round() as usize;
836 for _ in 0..num_samples {
837 samples.push(value);
838 }
839 }
840
841 samples.truncate(n);
843 while samples.len() < n {
844 samples.push(stats.mean());
845 }
846
847 samples
848}
849
850#[cfg(test)]
851mod tests {
852 use super::*;
853
854 #[test]
859 fn test_should_quantize_no_restrictions() {
860 let config = QuantConfig::default();
861 assert!(config.should_quantize("any.layer", 1));
862 assert!(config.should_quantize("any.layer", 1_000_000));
863 }
864
865 #[test]
866 fn test_should_quantize_excluded_layer() {
867 let config = QuantConfig {
868 excluded_layers: vec!["head.weight".to_string()],
869 ..Default::default()
870 };
871 assert!(!config.should_quantize("head.weight", 1024));
872 assert!(config.should_quantize("body.weight", 1024));
873 }
874
875 #[test]
876 fn test_should_quantize_min_elements() {
877 let config = QuantConfig {
878 min_elements: 512,
879 ..Default::default()
880 };
881 assert!(!config.should_quantize("small.bias", 4));
882 assert!(!config.should_quantize("small.bias", 511));
883 assert!(config.should_quantize("large.weight", 512));
884 assert!(config.should_quantize("large.weight", 1024));
885 }
886
887 #[test]
888 fn test_should_quantize_excluded_takes_priority_over_min_elements() {
889 let config = QuantConfig {
890 excluded_layers: vec!["head.weight".to_string()],
891 min_elements: 1,
892 ..Default::default()
893 };
894 assert!(!config.should_quantize("head.weight", 1_000_000));
896 }
897
898 #[test]
899 fn test_bits_for_layer_default() {
900 let config = QuantConfig { bits: 8, ..Default::default() };
901 assert_eq!(config.bits_for_layer("any.weight"), 8);
902 }
903
904 #[test]
905 fn test_bits_for_layer_override() {
906 let mut layer_bits = std::collections::HashMap::new();
907 layer_bits.insert("head.weight".to_string(), 4u8);
908 let config = QuantConfig {
909 bits: 8,
910 layer_bits,
911 ..Default::default()
912 };
913 assert_eq!(config.bits_for_layer("head.weight"), 4);
914 assert_eq!(config.bits_for_layer("body.weight"), 8);
915 }
916
917 #[test]
922 fn test_quant_params() {
923 let params = QuantParams::from_range(-1.0, 1.0);
924
925 assert_eq!(params.quantize(0.0), params.zero_point);
926
927 let original = 0.5;
928 let quantized = params.quantize(original);
929 let dequantized = params.dequantize(quantized);
930
931 assert!((original - dequantized).abs() < 0.01);
932 }
933
934 #[test]
935 fn test_quantize_tensor() {
936 let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
937 let shape = vec![5];
938
939 let quantized = QuantizedTensor::from_f32(&data, shape).unwrap();
940
941 assert_eq!(quantized.data.len(), 5);
942 assert_eq!(quantized.size_bytes(), 5);
943 }
944
945 #[test]
946 fn test_per_channel_quantization() {
947 let mut data = vec![];
948 for _ in 0..100 {
949 data.push(0.5); }
951 for _ in 0..100 {
952 data.push(5.0); }
954
955 let shape = vec![2, 100];
956
957 let quantized = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
958
959 assert!(quantized.per_channel);
960 assert!(quantized.channel_params.is_some());
961 assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
962
963 let dequantized = quantized.to_f32();
964 let error: f32 = data.iter()
965 .zip(dequantized.iter())
966 .map(|(a, b)| (a - b).powi(2))
967 .sum::<f32>() / data.len() as f32;
968
969 println!("Per-channel MSE: {}", error);
970 assert!(error < 0.1);
971 }
972
973 #[test]
974 fn test_per_channel_vs_per_tensor() {
975 let mut data = vec![];
976
977 for _ in 0..1000 {
978 data.push(0.01);
979 }
980
981 for _ in 0..1000 {
982 data.push(10.0);
983 }
984
985 let shape = vec![2, 1000];
986
987 let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
989 let per_tensor_error = per_tensor.quantization_error(&data);
990
991 let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
993 let per_channel_error = per_channel.quantization_error(&data);
994
995 println!("Per-tensor error: {:.8}", per_tensor_error);
996 println!("Per-channel error: {:.8}", per_channel_error);
997
998 assert!(per_channel_error < per_tensor_error);
1000 assert!(per_channel_error < per_tensor_error * 0.5);
1001 }
1002
1003 #[test]
1004 fn test_per_channel_benefit() {
1005 let mut data = vec![];
1006
1007 for i in 0..1000 {
1008 data.push(-0.1 + (i as f32 / 1000.0) * 0.2);
1009 }
1010
1011 for i in 0..1000 {
1012 data.push(-10.0 + (i as f32 / 1000.0) * 20.0);
1013 }
1014
1015 let shape = vec![2, 1000];
1016
1017 let per_tensor = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1018 let per_tensor_error = per_tensor.quantization_error(&data);
1019
1020 let per_channel = QuantizedTensor::from_f32_per_channel(&data, shape).unwrap();
1021 let per_channel_error = per_channel.quantization_error(&data);
1022
1023 println!("Per-tensor MSE: {:.8}", per_tensor_error);
1024 println!("Per-channel MSE: {:.8}", per_channel_error);
1025
1026 assert!(per_channel_error < per_tensor_error,
1027 "Per-channel ({:.8}) should be better than per-tensor ({:.8})",
1028 per_channel_error, per_tensor_error);
1029 }
1030
1031 #[test]
1032 fn test_int4_quant_params() {
1033 let params = QuantParamsInt4::from_range(-1.0, 1.0);
1034
1035 assert!(params.quantize(-10.0) >= -8);
1036 assert!(params.quantize(-10.0) <= 7);
1037 assert!(params.quantize(10.0) >= -8);
1038 assert!(params.quantize(10.0) <= 7);
1039
1040 let zero_quant = params.quantize(0.0);
1041 assert!(zero_quant >= -8 && zero_quant <= 7);
1042
1043 for &original in &[-1.0, -0.5, 0.0, 0.5, 1.0] {
1044 let quantized = params.quantize(original);
1045 let dequantized = params.dequantize(quantized);
1046
1047 println!("Original: {:.2}, Quantized: {}, Dequantized: {:.2}, Error: {:.4}",
1048 original, quantized, dequantized, (original - dequantized).abs());
1049
1050 assert!((original - dequantized).abs() < params.scale * 2.0);
1051 }
1052 }
1053
1054 #[test]
1055 fn test_int4_extreme_values() {
1056 let params = QuantParamsInt4::from_range(-100.0, 100.0);
1058
1059 let q_neg = params.quantize(-100.0);
1060 let q_pos = params.quantize(100.0);
1061
1062 assert_eq!(q_neg, -8);
1063 assert_eq!(q_pos, 7);
1064 }
1065
1066 #[test]
1067 fn test_int4_vs_int8_error() {
1068 let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1069
1070 let params_int8 = QuantParams::from_range(-1.0, 1.0);
1071 let error_int8: f32 = data.iter()
1072 .map(|&v| {
1073 let q = params_int8.quantize(v);
1074 let dq = params_int8.dequantize(q);
1075 (v - dq).powi(2)
1076 })
1077 .sum::<f32>() / data.len() as f32;
1078
1079 let params_int4 = QuantParamsInt4::from_range(-1.0, 1.0);
1080 let error_int4: f32 = data.iter()
1081 .map(|&v| {
1082 let q = params_int4.quantize(v);
1083 let dq = params_int4.dequantize(q);
1084 (v - dq).powi(2)
1085 })
1086 .sum::<f32>() / data.len() as f32;
1087
1088 println!("INT8 MSE: {:.8}", error_int8);
1089 println!("INT4 MSE: {:.8}", error_int4);
1090
1091 assert!(error_int4 > error_int8);
1092
1093 assert!(error_int4 < error_int8 * 500.0,
1094 "INT4 error ({:.8}) is too high compared to INT8 ({:.8})",
1095 error_int4, error_int8);
1096
1097 assert!(error_int4.is_finite());
1098 assert!(error_int4 < 0.01);
1099
1100 }
1101
1102 #[test]
1103 fn test_int4_range() {
1104 let params = QuantParamsInt4::from_range(-1.0, 1.0);
1105
1106 assert!(params.quantize(-10.0) == -8);
1107 assert!(params.quantize(10.0) == 7);
1108
1109 for i in -8..=7 {
1111 let value = i as f32 * params.scale;
1112 let quantized = params.quantize(value);
1113 assert!(quantized >= -8 && quantized <= 7);
1114 }
1115 }
1116
1117 #[test]
1118 fn test_int4_optimal_precision() {
1119 let params = QuantParamsInt4::from_range(-1.0, 1.0);
1120
1121 let mut unique_values = std::collections::HashSet::new();
1122
1123 for i in 0..1000 {
1125 let value = -1.0 + (i as f32 / 1000.0) * 2.0;
1126 unique_values.insert(params.quantize(value));
1127 }
1128
1129 println!("Unique quantized values: {}", unique_values.len());
1130 assert!(unique_values.len() >= 14);
1131 }
1132
1133 #[test]
1134 fn test_int4_tensor_quantization() {
1135 let data = vec![0.0, 0.5, 1.0, -0.5, -1.0];
1136 let shape = vec![5];
1137
1138 let quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1139
1140 assert_eq!(quantized.data.len(), 5);
1141 assert_eq!(quantized.size_bytes(), 5);
1142 assert_eq!(quantized.packed_size_bytes(), 3);
1143
1144 for &val in &quantized.data {
1145 assert!(val >= -8 && val <= 7, "Value {} out of INT4 range", val);
1146 }
1147 }
1148
1149 #[test]
1150 fn test_int4_round_trip() {
1151 let original = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
1152 let shape = vec![5];
1153
1154 let quantized = QuantizedTensorInt4::from_f32(&original, shape).unwrap();
1155 let dequantized = quantized.to_f32();
1156
1157 println!("Original: {:?}", original);
1158 println!("Quantized: {:?}", quantized.data);
1159 println!("Dequantized: {:?}", dequantized);
1160
1161 for (orig, deq) in original.iter().zip(dequantized.iter()) {
1162 let error = (orig - deq).abs();
1163 println!(" {:.2} -> {:.2}, error: {:.4}", orig, deq, error);
1164 assert!(error < 0.15, "Error too large: {}", error);
1165 }
1166 }
1167
1168 #[test]
1169 fn test_int4_per_channel() {
1170 let mut data = vec![];
1171
1172 for i in 0..100 {
1174 data.push(-0.1 + (i as f32 / 100.0) * 0.2);
1175 }
1176
1177 for i in 0..100 {
1179 data.push(-10.0 + (i as f32 / 100.0) * 20.0);
1180 }
1181
1182 let shape = vec![2, 100];
1183
1184 let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1185
1186 assert!(quantized.per_channel);
1187 assert!(quantized.channel_params.is_some());
1188 assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 2);
1189
1190 let error = quantized.quantization_error(&data);
1191 println!("INT4 per-channel MSE: {:.8}", error);
1192
1193 assert!(error < 1.0, "Error too high: {}", error);
1194 }
1195
1196 #[test]
1197 fn test_int4_vs_int8_compression() {
1198 let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
1199 let shape = vec![1000];
1200
1201 let int8_quantized = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1202 let int8_size = int8_quantized.size_bytes();
1203 let int8_error = int8_quantized.quantization_error(&data);
1204
1205 let int4_quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1206 let int4_size = int4_quantized.size_bytes();
1207 let int4_packed_size = int4_quantized.packed_size_bytes();
1208 let int4_error = int4_quantized.quantization_error(&data);
1209
1210 println!("INT8: {} bytes, MSE: {:.8}", int8_size, int8_error);
1211 println!("INT4 (unpacked): {} bytes, MSE: {:.8}", int4_size, int4_error);
1212 println!("INT4 (packed): {} bytes, MSE: {:.8}", int4_packed_size, int4_error);
1213
1214 assert_eq!(int4_size, int8_size);
1215
1216 assert!(int4_packed_size <= int8_size / 2 + 1);
1217
1218 assert!(int4_error > int8_error);
1219
1220 assert!(int4_error < 0.01, "INT4 error too high: {}", int4_error);
1221 }
1222
1223 #[test]
1224 fn test_int4_large_tensor() {
1225 let size = 64 * 3 * 3 * 3; let data: Vec<f32> = (0..size).map(|i| {
1227 ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5
1228 }).collect();
1229
1230 let shape = vec![64, 3, 3, 3];
1231
1232 let quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1233
1234 assert_eq!(quantized.data.len(), size);
1235 assert_eq!(quantized.channel_params.as_ref().unwrap().len(), 64);
1236
1237 let error = quantized.quantization_error(&data);
1238 println!("Large tensor INT4 error: {:.8}", error);
1239
1240 assert!(error < 0.01, "Error too high for large tensor: {}", error);
1241 }
1242
1243 #[test]
1244 fn test_int4_extreme_ranges() {
1245 let test_cases = vec![
1246 (vec![-0.001, 0.0, 0.001], "tiny range"),
1247 (vec![-100.0, 0.0, 100.0], "large range"),
1248 (vec![0.0, 0.0, 0.0], "all zeros"),
1249 (vec![1.0, 1.0, 1.0], "all same"),
1250 ];
1251
1252 for (data, desc) in test_cases {
1253 println!("\nTesting: {}", desc);
1254 let shape = vec![data.len()];
1255
1256 let result = QuantizedTensorInt4::from_f32(&data, shape);
1257 assert!(result.is_ok(), "Failed on {}", desc);
1258
1259 let quantized = result.unwrap();
1260 let dequantized = quantized.to_f32();
1261
1262 println!(" Original: {:?}", data);
1263 println!(" Dequantized: {:?}", dequantized);
1264
1265 for &val in &quantized.data {
1266 assert!(val >= -8 && val <= 7, "Value {} out of range for {}", val, desc);
1267 }
1268 }
1269 }
1270
1271 #[test]
1272 fn test_int4_pack_unpack_pair() {
1273 let test_cases = vec![
1274 (-8, 7),
1275 (-8, -8),
1276 (7, 7),
1277 (0, 0),
1278 (-1, 0),
1279 (0, -1),
1280 (-5, 3),
1281 (6, -4),
1282 ];
1283
1284 for (val1, val2) in test_cases {
1285 println!("\nTesting: ({}, {})", val1, val2);
1286
1287 let packed = pack_int4_pair(val1, val2);
1288 let (unpacked1, unpacked2) = unpack_int4_pair(packed);
1289
1290 println!(" Packed: 0x{:02X} (binary: {:08b})", packed, packed);
1291 println!(" Unpacked: ({}, {})", unpacked1, unpacked2);
1292
1293 assert_eq!(val1, unpacked1, "First value mismatch");
1294 assert_eq!(val2, unpacked2, "Second value mismatch");
1295 }
1296 }
1297
1298 #[test]
1299 fn test_int4_pack_unpack_vector() {
1300 let values = vec![-8, -7, -1, 0, 1, 7];
1301 let packed = pack_int4(&values);
1302 let unpacked = unpack_int4(&packed, values.len());
1303
1304 println!("\nEven length:");
1305 println!(" Original: {:?}", values);
1306 println!(" Packed: {:?} ({} bytes)", packed, packed.len());
1307 println!(" Unpacked: {:?}", unpacked);
1308
1309 assert_eq!(values, unpacked);
1310 assert_eq!(packed.len(), (values.len() + 1) / 2);
1311 }
1312
1313 #[test]
1314 fn test_int4_pack_unpack_odd_length() {
1315 let values = vec![-8, -5, 0, 5, 7];
1316 let packed = pack_int4(&values);
1317 let unpacked = unpack_int4(&packed, values.len());
1318
1319 println!("\nOdd length:");
1320 println!(" Original: {:?}", values);
1321 println!(" Packed: {:?} ({} bytes)", packed, packed.len());
1322 println!(" Unpacked: {:?}", unpacked);
1323
1324 assert_eq!(values, unpacked);
1325 assert_eq!(packed.len(), (values.len() + 1) / 2);
1326 }
1327
1328 #[test]
1329 fn test_int4_pack_all_values() {
1330 let values: Vec<i8> = (-8..=7).collect();
1331 let packed = pack_int4(&values);
1332 let unpacked = unpack_int4(&packed, values.len());
1333
1334 println!("\nAll INT4 values:");
1335 println!(" Original: {:?}", values);
1336 println!(" Packed: {} bytes", packed.len());
1337 println!(" Unpacked: {:?}", unpacked);
1338
1339 assert_eq!(values, unpacked);
1340 assert_eq!(packed.len(), 8);
1341 }
1342
1343 #[test]
1344 fn test_int4_pack_large_vector() {
1345 let values: Vec<i8> = (0..1000).map(|i| ((i % 16) - 8) as i8).collect();
1346 let packed = pack_int4(&values);
1347 let unpacked = unpack_int4(&packed, values.len());
1348
1349 assert_eq!(values, unpacked);
1350 assert_eq!(packed.len(), 500);
1351
1352 println!("\nLarge vector:");
1353 println!(" Original: {} values", values.len());
1354 println!(" Packed: {} bytes ({}x compression)", packed.len(),
1355 values.len() / packed.len());
1356 println!(" Unpacked: {} values", unpacked.len());
1357 }
1358
1359 #[test]
1360 fn test_int4_compression_ratio() {
1361 let size = 10000;
1362 let values: Vec<i8> = (0..size).map(|i| ((i % 16) - 8) as i8).collect();
1363
1364 let unpacked_size = values.len() * std::mem::size_of::<i8>();
1365
1366 let packed = pack_int4(&values);
1367 let packed_size = packed.len();
1368
1369 let compression_ratio = unpacked_size as f32 / packed_size as f32;
1370
1371 println!("\nCompression test:");
1372 println!(" Values: {}", size);
1373 println!(" Unpacked: {} bytes", unpacked_size);
1374 println!(" Packed: {} bytes", packed_size);
1375 println!(" Compression: {:.2}x", compression_ratio);
1376
1377 assert!((compression_ratio - 2.0).abs() < 0.01,
1378 "Expected ~2x compression, got {:.2}x", compression_ratio);
1379 }
1380
1381 #[test]
1382 fn test_int4_tensor_packing() {
1383 let data: Vec<f32> = (0..1000).map(|i| (i as f32 / 1000.0) * 2.0 - 1.0).collect();
1384 let shape = vec![1000];
1385
1386 let mut quantized = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1387
1388 println!("Before packing:");
1389 println!(" Unpacked size: {} bytes", quantized.unpacked_size_bytes());
1390 println!(" Is packed: {}", quantized.is_packed());
1391
1392 assert!(!quantized.is_packed());
1393 assert_eq!(quantized.size_bytes(), 1000);
1394
1395 quantized.pack();
1396
1397 println!("\nAfter packing:");
1398 println!(" Packed size: {} bytes", quantized.size_bytes());
1399 println!(" Is packed: {}", quantized.is_packed());
1400 println!(" Compression: {}x", quantized.unpacked_size_bytes() / quantized.size_bytes());
1401
1402 assert!(quantized.is_packed());
1403 assert_eq!(quantized.size_bytes(), 500);
1404
1405 let dequantized = quantized.to_f32();
1406 assert_eq!(dequantized.len(), 1000);
1407
1408 let error = quantized.quantization_error(&data);
1409 println!(" MSE after packing: {:.8}", error);
1410 assert!(error < 0.01);
1411 }
1412
1413 #[test]
1414 fn test_int4_packed_vs_unpacked_error() {
1415 let data: Vec<f32> = (0..100).map(|i| (i as f32 / 100.0) * 2.0 - 1.0).collect();
1416 let shape = vec![100];
1417
1418 let unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
1419 let error_unpacked = unpacked.quantization_error(&data);
1420
1421 let mut packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1422 packed.pack();
1423 let error_packed = packed.quantization_error(&data);
1424
1425 println!("Unpacked error: {:.8}", error_unpacked);
1426 println!("Packed error: {:.8}", error_packed);
1427
1428 assert!((error_unpacked - error_packed).abs() < 1e-6);
1429 }
1430
1431 #[test]
1432 fn test_int4_per_channel_packing() {
1433 let mut data = vec![];
1434 for i in 0..500 {
1435 data.push((i as f32 / 500.0) * 0.2 - 0.1); }
1437 for i in 0..500 {
1438 data.push((i as f32 / 500.0) * 20.0 - 10.0); }
1440
1441 let shape = vec![2, 500];
1442
1443 let mut quantized = QuantizedTensorInt4::from_f32_per_channel(&data, shape).unwrap();
1444
1445 let error_before = quantized.quantization_error(&data);
1446 println!("Error before packing: {:.8}", error_before);
1447
1448 quantized.pack();
1449
1450 let error_after = quantized.quantization_error(&data);
1451 println!("Error after packing: {:.8}", error_after);
1452 println!("Size: {} bytes (packed from {} bytes)",
1453 quantized.size_bytes(),
1454 quantized.unpacked_size_bytes());
1455
1456 assert!((error_before - error_after).abs() < 1e-6);
1457
1458 assert_eq!(quantized.size_bytes(), 500);
1459 }
1460
1461 #[test]
1462 fn test_int4_compression_comparison() {
1463 let size = 10000;
1464 let data: Vec<f32> = (0..size).map(|i| {
1465 ((i as f32 / size as f32) * 2.0 - 1.0) * 0.5
1466 }).collect();
1467 let shape = vec![size];
1468
1469 let fp32_size = size * std::mem::size_of::<f32>();
1470
1471 let int8 = QuantizedTensor::from_f32(&data, shape.clone()).unwrap();
1472 let int8_size = int8.size_bytes();
1473
1474 let int4_unpacked = QuantizedTensorInt4::from_f32(&data, shape.clone()).unwrap();
1475 let int4_unpacked_size = int4_unpacked.size_bytes();
1476
1477 let mut int4_packed = QuantizedTensorInt4::from_f32(&data, shape).unwrap();
1478 int4_packed.pack();
1479 let int4_packed_size = int4_packed.size_bytes();
1480
1481 println!("\nCompression Comparison:");
1482 println!(" FP32: {} bytes", fp32_size);
1483 println!(" INT8: {} bytes ({:.1}x)", int8_size, fp32_size as f32 / int8_size as f32);
1484 println!(" INT4 unpacked: {} bytes ({:.1}x)", int4_unpacked_size, fp32_size as f32 / int4_unpacked_size as f32);
1485 println!(" INT4 packed: {} bytes ({:.1}x)", int4_packed_size, fp32_size as f32 / int4_packed_size as f32);
1486
1487 assert_eq!(fp32_size / int8_size, 4); assert_eq!(fp32_size / int4_packed_size, 8); }
1490
    /// End-to-end INT4 smoke test against a real ONNX model found on disk.
    /// Ignored by default: it needs mnist.onnx or resnet18-v1-7.onnx locally
    /// (run with `cargo test -- --ignored`). Exercises per-tensor and
    /// per-channel quantization plus bit packing on real weight tensors.
    #[test]
    #[ignore] fn test_int4_real_model() {
        use crate::onnx_utils::OnnxModel;

        println!("\n{}", "=".repeat(60));
        println!("INT4 Real Model Test");
        println!("\n{}", "=".repeat(60));

        // Candidate model locations, tried in order; the first loadable wins.
        let model_paths = vec![
            "test_models/mnist.onnx",
            "mnist.onnx",
            "test_models/resnet18-v1-7.onnx",
            "resnet18-v1-7.onnx",
        ];

        let mut model = None;
        for path in &model_paths {
            if std::path::Path::new(path).exists() {
                println!("Loading model: {}", path);
                match OnnxModel::load(path) {
                    Ok(m) => {
                        model = Some(m);
                        break;
                    }
                    // A corrupt/unreadable model is logged, not fatal.
                    Err(e) => println!("  Failed: {}", e),
                }
            }
        }

        // Soft skip when no model is available (returns instead of failing).
        let model = match model {
            Some(m) => m,
            None => {
                println!("No test models found. Skipping test.");
                println!("Place mnist.onnx or resnet18-v1-7.onnx in current directory.");
                return;
            }
        };

        let info = model.info();
        println!("✓ Model loaded: {}", info.name);
        println!("  Nodes: {}", info.num_nodes);
        println!();

        println!("Extracting weights...");
        let weights = model.extract_weights();
        println!("✓ Found {} weight tensors", weights.len());

        if weights.is_empty() {
            println!("No weights to quantize!");
            return;
        }

        println!();
        println!("\n{}", "=".repeat(60));
        println!("Testing Per-Tensor Quantization");
        println!("\n{}", "=".repeat(60));

        // Only the 5 largest-ish tensors: tiny tensors make ratios noisy.
        let test_weights: Vec<_> = weights.iter()
            .filter(|w| w.data.len() > 1000)
            .take(5)
            .collect();

        println!("Testing {} large layers:\n", test_weights.len());

        for (idx, weight) in test_weights.iter().enumerate() {
            // NOTE(review): byte-index slicing `[..37]` can panic if byte 37
            // is not a UTF-8 char boundary in the layer name — consider a
            // char-boundary-safe truncation. Confirm layer names are ASCII.
            let name = if weight.name.len() > 40 {
                format!("{}...", &weight.name[..37])
            } else {
                weight.name.clone()
            };

            println!("[{}] {}", idx + 1, name);
            println!("    Shape: {:?}, Elements: {}", weight.shape, weight.data.len());

            // FP32 baseline: 4 bytes per element.
            let fp32_size = weight.data.len() * 4;

            let int8_result = QuantizedTensor::from_f32(&weight.data, weight.shape.clone());
            let (int8_size, int8_error) = if let Ok(q) = int8_result {
                (q.size_bytes(), q.quantization_error(&weight.data))
            } else {
                println!("    INT8 failed!");
                continue;
            };

            let int4_result = QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone());
            let (int4_unpacked_size, int4_error) = if let Ok(q) = int4_result {
                (q.size_bytes(), q.quantization_error(&weight.data))
            } else {
                println!("    INT4 failed!");
                continue;
            };

            // Third quantization of the same weights, then bit-packed.
            let mut int4_packed = QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
            int4_packed.pack();
            let int4_packed_size = int4_packed.size_bytes();
            let int4_packed_error = int4_packed.quantization_error(&weight.data);

            println!("    FP32:          {:7} bytes", fp32_size);
            println!("    INT8:          {:7} bytes ({:.1}x)  MSE: {:.8}",
                int8_size, fp32_size as f32 / int8_size as f32, int8_error);
            println!("    INT4 unpacked: {:7} bytes ({:.1}x)  MSE: {:.8}",
                int4_unpacked_size, fp32_size as f32 / int4_unpacked_size as f32, int4_error);
            println!("    INT4 packed:   {:7} bytes ({:.1}x)  MSE: {:.8}",
                int4_packed_size, fp32_size as f32 / int4_packed_size as f32, int4_packed_error);

            // Packing must be exactly lossless, hence strict equality.
            assert_eq!(int4_error, int4_packed_error, "Packing changed error!");

            let int8_ratio = fp32_size as f32 / int8_size as f32;
            let int4_ratio = fp32_size as f32 / int4_packed_size as f32;

            assert!((int8_ratio - 4.0).abs() < 0.1, "INT8 compression should be ~4x");
            assert!((int4_ratio - 8.0).abs() < 0.1, "INT4 compression should be ~8x");

            println!();
        }

        println!("\n{}", "=".repeat(60));
        println!("Testing Per-Channel Quantization");
        println!("\n{}", "=".repeat(60));

        // Multi-channel layers only: shape[0] is treated as the channel axis.
        let conv_weights: Vec<_> = weights.iter()
            .filter(|w| w.shape.len() >= 2 && w.shape[0] > 1)
            .take(3)
            .collect();

        if conv_weights.is_empty() {
            println!("No multi-channel layers found for per-channel test.");
        } else {
            println!("Testing {} conv layers:\n", conv_weights.len());

            for (idx, weight) in conv_weights.iter().enumerate() {
                // NOTE(review): same byte-boundary truncation caveat as above.
                let name = if weight.name.len() > 40 {
                    format!("{}...", &weight.name[..37])
                } else {
                    weight.name.clone()
                };

                println!("[{}] {}", idx + 1, name);
                println!("    Shape: {:?}, Channels: {}", weight.shape, weight.shape[0]);

                let per_tensor = QuantizedTensorInt4::from_f32(&weight.data, weight.shape.clone()).unwrap();
                let per_tensor_error = per_tensor.quantization_error(&weight.data);

                let per_channel_result = QuantizedTensorInt4::from_f32_per_channel(
                    &weight.data,
                    weight.shape.clone(),
                );

                if let Ok(per_channel) = per_channel_result {
                    let per_channel_error = per_channel.quantization_error(&weight.data);

                    // Relative improvement in percent; NaN if per-tensor error is 0.
                    let improvement = ((per_tensor_error - per_channel_error) / per_tensor_error) * 100.0;

                    println!("    Per-tensor:  MSE: {:.8}", per_tensor_error);
                    println!("    Per-channel: MSE: {:.8} ({:.1}% better)",
                        per_channel_error, improvement);

                    // 10% slack: per-channel is usually better, never much worse.
                    assert!(per_channel_error <= per_tensor_error * 1.1,
                        "Per-channel should not be significantly worse");
                } else {
                    println!("    Per-channel failed!");
                }

                println!();
            }
        }

        println!("\n{}", "=".repeat(60));
        println!("Summary");
        println!("\n{}", "=".repeat(60));

        println!("✓ INT4 quantization works on real model weights");
        println!("✓ Compression ratios correct (4x INT8, 8x INT4)");
        println!("✓ Bit packing is lossless");
        println!("✓ Per-channel quantization works");
        println!("\nINT4 implementation is ready for CLI integration!");
    }
1670
1671 #[test]
1676 fn test_all_nan_returns_error() {
1677 let data = vec![f32::NAN, f32::NAN, f32::NAN];
1678 let result = QuantizedTensor::from_f32(&data, vec![3]);
1679 assert!(result.is_err());
1680 let err = result.unwrap_err().to_string();
1681 assert!(err.contains("non-finite"), "error should mention non-finite: {}", err);
1682 }
1683
1684 #[test]
1685 fn test_all_inf_returns_error() {
1686 let data = vec![f32::INFINITY, f32::NEG_INFINITY];
1687 let result = QuantizedTensor::from_f32(&data, vec![2]);
1688 assert!(result.is_err());
1689 }
1690
1691 #[test]
1692 fn test_all_nan_int4_returns_error() {
1693 let data = vec![f32::NAN; 4];
1694 let result = QuantizedTensorInt4::from_f32(&data, vec![4]);
1695 assert!(result.is_err());
1696 }
1697
1698 #[test]
1699 fn test_all_nan_per_channel_returns_error() {
1700 let data = vec![f32::NAN; 6];
1701 let result = QuantizedTensor::from_f32_per_channel(&data, vec![2, 3]);
1702 assert!(result.is_err());
1703 let err = result.unwrap_err().to_string();
1704 assert!(err.contains("Channel 0"), "error should mention channel: {}", err);
1705 }
1706
1707 #[test]
1708 fn test_mixed_nan_finite_succeeds() {
1709 let data = vec![f32::NAN, 1.0, -1.0, f32::NAN];
1711 let result = QuantizedTensor::from_f32(&data, vec![4]);
1712 assert!(result.is_ok());
1713 }
1714}