// ipfrs_tensorlogic/quantization.rs

//! Model Quantization Support
//!
//! This module provides comprehensive quantization support for ML models, enabling
//! efficient deployment on edge devices and reducing model size and inference latency.
//!
//! ## Supported Quantization Schemes
//!
//! - **INT8 Quantization**: 8-bit integer quantization with configurable ranges
//! - **INT4 Quantization**: 4-bit integer quantization for extreme compression
//! - **INT16 Quantization**: 16-bit integer quantization for higher fidelity
//! - **Per-Tensor Quantization**: Single scale/zero-point for entire tensor
//! - **Per-Channel Quantization**: Independent scale/zero-point per channel
//! - **Symmetric Quantization**: Zero-point = 0 (centered around zero)
//! - **Asymmetric Quantization**: Arbitrary zero-point for full range coverage
//! - **Dynamic Quantization**: Runtime quantization of activations
//!
//! ## Examples
//!
//! ```
//! use ipfrs_tensorlogic::{QuantizedTensor, QuantizationScheme, QuantizationConfig};
//!
//! // Per-tensor INT8 symmetric quantization
//! let weights = vec![0.5, -0.3, 0.8, -0.1];
//! let config = QuantizationConfig::int8_symmetric();
//! let quantized = QuantizedTensor::quantize_per_tensor(&weights, vec![4], config).unwrap();
//!
//! // Dequantize back to f32
//! let dequantized = quantized.dequantize();
//! assert_eq!(dequantized.len(), 4);
//!
//! // Per-channel quantization for Conv2D weights
//! let weights = vec![0.5, 0.3, -0.2, -0.4, 0.1, 0.6, -0.5, 0.2]; // 2 channels, 4 elements each
//! let config = QuantizationConfig::int8_per_channel(2);
//! let quantized = QuantizedTensor::quantize_per_channel(&weights, vec![2, 4], config).unwrap();
//! ```

use serde::{Deserialize, Serialize};
use std::fmt;
use thiserror::Error;

/// Errors that can occur during quantization operations
#[derive(Debug, Error)]
pub enum QuantizationError {
    #[error("Invalid quantization bit width: {0}")]
    InvalidBitWidth(u8),

    #[error("Invalid shape: {0}")]
    InvalidShape(String),

    #[error("Invalid number of channels: expected {expected}, got {got}")]
    InvalidChannelCount { expected: usize, got: usize },

    #[error("Empty tensor cannot be quantized")]
    EmptyTensor,

    #[error("Calibration data required for dynamic quantization")]
    CalibrationRequired,

    #[error("Unsupported quantization scheme: {0}")]
    UnsupportedScheme(String),
}

/// Quantization scheme
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// 8-bit integer quantization (INT8)
    Int8,
    /// 4-bit integer quantization (INT4)
    Int4,
    /// 16-bit integer quantization (INT16)
    Int16,
}

impl QuantizationScheme {
    /// Get the bit width for this scheme
    pub fn bit_width(&self) -> u8 {
        match self {
            QuantizationScheme::Int4 => 4,
            QuantizationScheme::Int8 => 8,
            QuantizationScheme::Int16 => 16,
        }
    }

    /// Get the quantization range (min, max) for this scheme
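    ///
    /// A small doctest sketch (assuming the crate re-exports this type, as in
    /// the module-level examples):
    ///
    /// ```
    /// use ipfrs_tensorlogic::QuantizationScheme;
    ///
    /// assert_eq!(QuantizationScheme::Int8.range(true), (-128, 127));
    /// assert_eq!(QuantizationScheme::Int4.range(false), (0, 15));
    /// ```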
    pub fn range(&self, symmetric: bool) -> (i32, i32) {
        match (self, symmetric) {
            (QuantizationScheme::Int4, true) => (-8, 7),
            (QuantizationScheme::Int4, false) => (0, 15),
            (QuantizationScheme::Int8, true) => (-128, 127),
            (QuantizationScheme::Int8, false) => (0, 255),
            (QuantizationScheme::Int16, true) => (-32768, 32767),
            (QuantizationScheme::Int16, false) => (0, 65535),
        }
    }

    /// Calculate compression ratio compared to f32
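    ///
    /// For example, INT4 yields 32/4 = 8x and INT8 yields 32/8 = 4x. A small
    /// doctest sketch, assuming the module-level re-exports:
    ///
    /// ```
    /// use ipfrs_tensorlogic::QuantizationScheme;
    ///
    /// assert_eq!(QuantizationScheme::Int4.compression_ratio(), 8.0);
    /// assert_eq!(QuantizationScheme::Int8.compression_ratio(), 4.0);
    /// ```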
    pub fn compression_ratio(&self) -> f32 {
        32.0 / self.bit_width() as f32
    }
}

/// Quantization granularity
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationGranularity {
    /// Per-tensor quantization (single scale/zero-point)
    PerTensor,
    /// Per-channel quantization (scale/zero-point per output channel)
    PerChannel { num_channels: usize },
    /// Per-group quantization (scale/zero-point per group)
    PerGroup { group_size: usize },
}

/// Quantization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Quantization scheme (INT4, INT8, etc.)
    pub scheme: QuantizationScheme,
    /// Quantization granularity
    pub granularity: QuantizationGranularity,
    /// Use symmetric quantization (zero_point = 0)
    pub symmetric: bool,
    /// Calibration method for determining scale/zero-point
    pub calibration: CalibrationMethod,
}

impl QuantizationConfig {
    /// Create INT8 symmetric per-tensor quantization config
    pub fn int8_symmetric() -> Self {
        Self {
            scheme: QuantizationScheme::Int8,
            granularity: QuantizationGranularity::PerTensor,
            symmetric: true,
            calibration: CalibrationMethod::MinMax,
        }
    }

    /// Create INT8 asymmetric per-tensor quantization config
    pub fn int8_asymmetric() -> Self {
        Self {
            scheme: QuantizationScheme::Int8,
            granularity: QuantizationGranularity::PerTensor,
            symmetric: false,
            calibration: CalibrationMethod::MinMax,
        }
    }

    /// Create INT8 per-channel quantization config
    pub fn int8_per_channel(num_channels: usize) -> Self {
        Self {
            scheme: QuantizationScheme::Int8,
            granularity: QuantizationGranularity::PerChannel { num_channels },
            symmetric: true,
            calibration: CalibrationMethod::MinMax,
        }
    }

    /// Create INT4 symmetric per-tensor quantization config
    pub fn int4_symmetric() -> Self {
        Self {
            scheme: QuantizationScheme::Int4,
            granularity: QuantizationGranularity::PerTensor,
            symmetric: true,
            calibration: CalibrationMethod::MinMax,
        }
    }

    /// Create INT4 per-channel quantization config
    pub fn int4_per_channel(num_channels: usize) -> Self {
        Self {
            scheme: QuantizationScheme::Int4,
            granularity: QuantizationGranularity::PerChannel { num_channels },
            symmetric: true,
            calibration: CalibrationMethod::MinMax,
        }
    }
}

/// Calibration method for determining quantization parameters
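///
/// For example, `CalibrationMethod::Percentile { lower: 1, upper: 99 }` clips
/// the bottom and top 1% of values before computing the quantization range,
/// which keeps a few extreme outliers from inflating the scale.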
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CalibrationMethod {
    /// Min-max calibration (uses actual min/max of data)
    MinMax,
    /// Percentile-based calibration (clips outliers)
    Percentile { lower: u8, upper: u8 },
    /// Entropy-based calibration (minimizes KL divergence)
    Entropy,
    /// MSE-based calibration (minimizes mean squared error)
    Mse,
}

/// Quantization parameters for a single channel/tensor
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Scale factor: real_value = (quantized_value - zero_point) * scale
    pub scale: f32,
    /// Zero point (quantized value corresponding to real 0.0)
    pub zero_point: i32,
    /// Min value in quantized range
    pub qmin: i32,
    /// Max value in quantized range
    pub qmax: i32,
}

impl QuantizationParams {
    /// Create quantization parameters from min/max values
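    ///
    /// A minimal round-trip sketch (assuming the crate re-exports these types
    /// like the ones in the module-level examples):
    ///
    /// ```
    /// use ipfrs_tensorlogic::{QuantizationParams, QuantizationScheme};
    ///
    /// let p = QuantizationParams::from_min_max(-1.0, 1.0, QuantizationScheme::Int8, true);
    /// assert_eq!(p.zero_point, 0); // symmetric quantization centers on zero
    /// let q = p.quantize(0.5);
    /// assert!((p.dequantize(q) - 0.5).abs() <= p.scale);
    /// ```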
    pub fn from_min_max(
        min_val: f32,
        max_val: f32,
        scheme: QuantizationScheme,
        symmetric: bool,
    ) -> Self {
        let (qmin, qmax) = scheme.range(symmetric);

        let (scale, zero_point) = if symmetric {
            // Symmetric: zero_point = 0, scale based on max absolute value
            let abs_max = min_val.abs().max(max_val.abs());
            let scale = if abs_max > 0.0 {
                abs_max / (qmax as f32)
            } else {
                1.0
            };
            (scale, 0)
        } else {
            // Asymmetric: use full range
            let scale = if (max_val - min_val).abs() > 0.0 {
                (max_val - min_val) / ((qmax - qmin) as f32)
            } else {
                1.0
            };
            let zero_point = qmin - (min_val / scale).round() as i32;
            let zero_point = zero_point.clamp(qmin, qmax);
            (scale, zero_point)
        };

        Self {
            scale,
            zero_point,
            qmin,
            qmax,
        }
    }

    /// Quantize a floating-point value
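    ///
    /// Applies `q = clamp(round(x / scale) + zero_point, qmin, qmax)`.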
    #[inline]
    pub fn quantize(&self, value: f32) -> i32 {
        let quantized = (value / self.scale).round() as i32 + self.zero_point;
        quantized.clamp(self.qmin, self.qmax)
    }

    /// Dequantize a quantized value
    #[inline]
    pub fn dequantize(&self, quantized: i32) -> f32 {
        (quantized - self.zero_point) as f32 * self.scale
    }
}

/// Quantized tensor representation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
    /// Quantized data (stored as i32 for all schemes)
    pub data: Vec<i32>,
    /// Tensor shape
    pub shape: Vec<usize>,
    /// Quantization parameters (one per channel for per-channel quantization)
    pub params: Vec<QuantizationParams>,
    /// Quantization configuration
    pub config: QuantizationConfig,
}

impl QuantizedTensor {
    /// Quantize a tensor with per-tensor quantization
    pub fn quantize_per_tensor(
        data: &[f32],
        shape: Vec<usize>,
        config: QuantizationConfig,
    ) -> Result<Self, QuantizationError> {
        if data.is_empty() {
            return Err(QuantizationError::EmptyTensor);
        }

        // Ensure config is per-tensor
        if !matches!(config.granularity, QuantizationGranularity::PerTensor) {
            return Err(QuantizationError::UnsupportedScheme(
                "Expected per-tensor granularity".to_string(),
            ));
        }

        // Calculate min/max
        let (min_val, max_val) = Self::calculate_min_max(data, &config.calibration)?;

        // Create quantization parameters
        let params =
            QuantizationParams::from_min_max(min_val, max_val, config.scheme, config.symmetric);

        // Quantize data
        let quantized_data: Vec<i32> = data.iter().map(|&v| params.quantize(v)).collect();

        Ok(Self {
            data: quantized_data,
            shape,
            params: vec![params],
            config,
        })
    }

    /// Quantize a tensor with per-channel quantization
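    ///
    /// The first dimension of `shape` is treated as the (output) channel
    /// dimension and must equal `num_channels` from the config; each channel
    /// gets its own scale/zero-point.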
    pub fn quantize_per_channel(
        data: &[f32],
        shape: Vec<usize>,
        config: QuantizationConfig,
    ) -> Result<Self, QuantizationError> {
        if data.is_empty() {
            return Err(QuantizationError::EmptyTensor);
        }

        let num_channels = match config.granularity {
            QuantizationGranularity::PerChannel { num_channels } => num_channels,
            _ => {
                return Err(QuantizationError::UnsupportedScheme(
                    "Expected per-channel granularity".to_string(),
                ))
            }
        };

        if shape.is_empty() {
            return Err(QuantizationError::InvalidShape("Empty shape".to_string()));
        }

        // First dimension is typically the output channel dimension
        if shape[0] != num_channels {
            return Err(QuantizationError::InvalidChannelCount {
                expected: num_channels,
                got: shape[0],
            });
        }

        let channel_size = data.len() / num_channels;

        // Calculate parameters for each channel
        let mut params = Vec::with_capacity(num_channels);
        for i in 0..num_channels {
            let start = i * channel_size;
            let end = start + channel_size;
            let channel_data = &data[start..end];

            let (min_val, max_val) = Self::calculate_min_max(channel_data, &config.calibration)?;
            let channel_params =
                QuantizationParams::from_min_max(min_val, max_val, config.scheme, config.symmetric);
            params.push(channel_params);
        }

        // Quantize each channel
        let mut quantized_data = Vec::with_capacity(data.len());
        for (i, chunk) in data.chunks(channel_size).enumerate() {
            for &value in chunk {
                quantized_data.push(params[i].quantize(value));
            }
        }

        Ok(Self {
            data: quantized_data,
            shape,
            params,
            config,
        })
    }

    /// Calculate min/max values based on calibration method
    fn calculate_min_max(
        data: &[f32],
        calibration: &CalibrationMethod,
    ) -> Result<(f32, f32), QuantizationError> {
        match calibration {
            CalibrationMethod::MinMax => {
                let min_val = data.iter().copied().fold(f32::INFINITY, f32::min);
                let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                Ok((min_val, max_val))
            }
            CalibrationMethod::Percentile { lower, upper } => {
                let mut sorted = data.to_vec();
                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

                // Index against (len - 1) so that upper < 100 actually
                // excludes the extreme maximum instead of landing on it
                let last = sorted.len() - 1;
                let lower_idx = (last as f32 * (*lower as f32 / 100.0)).round() as usize;
                let upper_idx = (last as f32 * (*upper as f32 / 100.0)).round() as usize;

                let min_val = sorted[lower_idx.min(last)];
                let max_val = sorted[upper_idx.min(last)];
                Ok((min_val, max_val))
            }
            _ => {
                // Entropy and MSE not yet implemented, fall back to MinMax
                let min_val = data.iter().copied().fold(f32::INFINITY, f32::min);
                let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                Ok((min_val, max_val))
            }
        }
    }

    /// Dequantize the tensor back to f32
    pub fn dequantize(&self) -> Vec<f32> {
        match self.config.granularity {
            QuantizationGranularity::PerTensor => {
                let params = &self.params[0];
                self.data.iter().map(|&q| params.dequantize(q)).collect()
            }
            QuantizationGranularity::PerChannel { num_channels } => {
                let channel_size = self.data.len() / num_channels;
                let mut result = Vec::with_capacity(self.data.len());

                for (i, chunk) in self.data.chunks(channel_size).enumerate() {
                    for &q in chunk {
                        result.push(self.params[i].dequantize(q));
                    }
                }
                result
            }
            QuantizationGranularity::PerGroup { .. } => {
                // Not yet implemented, fall back to per-tensor
                let params = &self.params[0];
                self.data.iter().map(|&q| params.dequantize(q)).collect()
            }
        }
    }

    /// Get the compression ratio compared to f32 storage
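    ///
    /// The ratio accounts for the stored `QuantizationParams`, so it comes out
    /// slightly below the ideal `scheme.compression_ratio()` for small tensors.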
    pub fn compression_ratio(&self) -> f32 {
        let original_bytes = self.data.len() * 4; // f32 = 4 bytes
        let quantized_bytes = self.size_bytes();
        original_bytes as f32 / quantized_bytes as f32
    }

    /// Calculate size in bytes for quantized representation
    pub fn size_bytes(&self) -> usize {
        match self.config.scheme {
            QuantizationScheme::Int4 => {
                // INT4 packs 2 values per byte
                self.data.len().div_ceil(2)
                    + self.params.len() * std::mem::size_of::<QuantizationParams>()
            }
            QuantizationScheme::Int8 => {
                self.data.len() + self.params.len() * std::mem::size_of::<QuantizationParams>()
            }
            QuantizationScheme::Int16 => {
                self.data.len() * 2 + self.params.len() * std::mem::size_of::<QuantizationParams>()
            }
        }
    }

    /// Pack INT4 data into bytes (2 values per byte)
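    ///
    /// Each value is stored as a two's-complement 4-bit nibble: the first
    /// element of a pair goes in the high nibble, the second in the low nibble.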
    pub fn pack_int4(&self) -> Result<Vec<u8>, QuantizationError> {
        if self.config.scheme != QuantizationScheme::Int4 {
            return Err(QuantizationError::InvalidBitWidth(
                self.config.scheme.bit_width(),
            ));
        }

        let mut packed = Vec::with_capacity(self.data.len().div_ceil(2));
        for chunk in self.data.chunks(2) {
            let high = (chunk[0] & 0xF) as u8;
            let low = if chunk.len() > 1 {
                (chunk[1] & 0xF) as u8
            } else {
                0
            };
            packed.push((high << 4) | low);
        }

        Ok(packed)
    }

    /// Unpack INT4 data from bytes
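    ///
    /// A round-trip sketch with hypothetical nibble values (relies on the
    /// signed-INT4 sign extension below, where `0xF` decodes to -1):
    ///
    /// ```
    /// use ipfrs_tensorlogic::QuantizedTensor;
    ///
    /// let packed = vec![0x1Fu8]; // 1 in the high nibble, -1 in the low nibble
    /// assert_eq!(QuantizedTensor::unpack_int4(&packed, 2), vec![1, -1]);
    /// ```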
    pub fn unpack_int4(packed: &[u8], length: usize) -> Vec<i32> {
        // Sign-extend a 4-bit two's-complement nibble into an i32; without
        // this, negative values from symmetric INT4 would not round-trip
        fn sign_extend(nibble: u8) -> i32 {
            let v = (nibble & 0xF) as i32;
            if v > 7 {
                v - 16
            } else {
                v
            }
        }

        let mut unpacked = Vec::with_capacity(length);
        for &byte in packed {
            unpacked.push(sign_extend(byte >> 4));
            if unpacked.len() < length {
                unpacked.push(sign_extend(byte));
            }
        }
        unpacked.truncate(length);
        unpacked
    }

    /// Calculate quantization error (MSE)
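    ///
    /// Computes `MSE = (1/n) * sum_i (original_i - dequantized_i)^2` with
    /// `n = original.len()`.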
    pub fn quantization_error(&self, original: &[f32]) -> f32 {
        let dequantized = self.dequantize();
        original
            .iter()
            .zip(dequantized.iter())
            .map(|(o, d)| {
                let diff = o - d;
                diff * diff
            })
            .sum::<f32>()
            / original.len() as f32
    }
}

impl fmt::Display for QuantizedTensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "QuantizedTensor({:?}, shape={:?}, scheme={:?}, params={})",
            self.config.granularity,
            self.shape,
            self.config.scheme,
            self.params.len()
        )
    }
}

/// Dynamic quantization configuration for runtime quantization
#[derive(Debug, Clone)]
pub struct DynamicQuantizer {
    /// Target quantization scheme
    scheme: QuantizationScheme,
    /// Use symmetric quantization
    symmetric: bool,
    /// Calibration method
    calibration: CalibrationMethod,
}

impl DynamicQuantizer {
    /// Create a new dynamic quantizer
    pub fn new(scheme: QuantizationScheme, symmetric: bool) -> Self {
        Self {
            scheme,
            symmetric,
            calibration: CalibrationMethod::MinMax,
        }
    }

    /// Quantize activation tensor at runtime
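    ///
    /// A minimal sketch (assuming `DynamicQuantizer` is re-exported like the
    /// types in the module-level examples):
    ///
    /// ```
    /// use ipfrs_tensorlogic::{DynamicQuantizer, QuantizationScheme};
    ///
    /// let quantizer = DynamicQuantizer::new(QuantizationScheme::Int8, true);
    /// let q = quantizer.quantize_activation(&[1.0, 2.0, 3.0], vec![3]).unwrap();
    /// assert_eq!(q.data.len(), 3);
    /// ```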
    pub fn quantize_activation(
        &self,
        data: &[f32],
        shape: Vec<usize>,
    ) -> Result<QuantizedTensor, QuantizationError> {
        let config = QuantizationConfig {
            scheme: self.scheme,
            granularity: QuantizationGranularity::PerTensor,
            symmetric: self.symmetric,
            calibration: self.calibration,
        };

        QuantizedTensor::quantize_per_tensor(data, shape, config)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_scheme_ranges() {
        assert_eq!(QuantizationScheme::Int8.range(true), (-128, 127));
        assert_eq!(QuantizationScheme::Int8.range(false), (0, 255));
        assert_eq!(QuantizationScheme::Int4.range(true), (-8, 7));
        assert_eq!(QuantizationScheme::Int4.range(false), (0, 15));
    }

    #[test]
    fn test_quantization_params_symmetric() {
        let params = QuantizationParams::from_min_max(-1.0, 1.0, QuantizationScheme::Int8, true);
        assert_eq!(params.zero_point, 0);
        assert!(params.scale > 0.0);

        // Test quantization
        assert_eq!(params.quantize(0.0), 0);
        assert!(params.quantize(1.0) > 0);
        assert!(params.quantize(-1.0) < 0);
    }

    #[test]
    fn test_quantization_params_asymmetric() {
        // Use a range that doesn't start at zero to ensure non-zero zero_point
        let params = QuantizationParams::from_min_max(0.5, 1.5, QuantizationScheme::Int8, false);
        // For asymmetric quantization, zero_point should be calculated
        assert!(params.scale > 0.0);

        // Test another case with negative values
        let params2 = QuantizationParams::from_min_max(-1.0, 0.5, QuantizationScheme::Int8, false);
        assert!(params2.scale > 0.0);
        assert!(params2.zero_point >= params2.qmin && params2.zero_point <= params2.qmax);
    }

    #[test]
    fn test_per_tensor_quantization() {
        let data = vec![0.5, -0.3, 0.8, -0.1];
        let config = QuantizationConfig::int8_symmetric();
        let quantized = QuantizedTensor::quantize_per_tensor(&data, vec![4], config).unwrap();

        assert_eq!(quantized.data.len(), 4);
        assert_eq!(quantized.params.len(), 1);

        // Dequantize and check
        let dequantized = quantized.dequantize();
        assert_eq!(dequantized.len(), 4);

        // Should be close to original
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.01);
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        // 2 channels, 4 elements each
        let data = vec![0.5, 0.3, -0.2, -0.4, 0.1, 0.6, -0.5, 0.2];
        let config = QuantizationConfig::int8_per_channel(2);
        let quantized = QuantizedTensor::quantize_per_channel(&data, vec![2, 4], config).unwrap();

        assert_eq!(quantized.data.len(), 8);
        assert_eq!(quantized.params.len(), 2);

        // Each channel should have its own parameters
        assert_ne!(quantized.params[0].scale, quantized.params[1].scale);
    }

    #[test]
    fn test_int4_quantization() {
        let data = vec![0.1, 0.2, 0.3, 0.4];
        let config = QuantizationConfig::int4_symmetric();
        let quantized = QuantizedTensor::quantize_per_tensor(&data, vec![4], config).unwrap();

        // INT4 range is -8 to 7
        for &q in &quantized.data {
            assert!(q >= -8 && q <= 7);
        }

        // Test packing
        let packed = quantized.pack_int4().unwrap();
        assert_eq!(packed.len(), 2); // 4 values packed into 2 bytes

        // Test unpacking
        let unpacked = QuantizedTensor::unpack_int4(&packed, 4);
        assert_eq!(unpacked, quantized.data);
    }

    #[test]
    fn test_compression_ratio() {
        let data = vec![1.0; 100];
        let config = QuantizationConfig::int8_symmetric();
        let quantized = QuantizedTensor::quantize_per_tensor(&data, vec![100], config).unwrap();

        let ratio = quantized.compression_ratio();
        assert!(ratio > 1.0); // Should be compressed
    }

    #[test]
    fn test_quantization_error() {
        let data = vec![0.1, 0.5, 0.9, 0.3];
        let config = QuantizationConfig::int8_symmetric();
        let quantized = QuantizedTensor::quantize_per_tensor(&data, vec![4], config).unwrap();

        let error = quantized.quantization_error(&data);
        assert!(error < 0.001); // Error should be small for INT8
    }

    #[test]
    fn test_dynamic_quantizer() {
        let quantizer = DynamicQuantizer::new(QuantizationScheme::Int8, true);
        let data = vec![1.0, 2.0, 3.0, 4.0];

        let quantized = quantizer.quantize_activation(&data, vec![4]).unwrap();
        assert_eq!(quantized.data.len(), 4);

        let dequantized = quantized.dequantize();
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.1);
        }
    }

    #[test]
    fn test_percentile_calibration() {
        let mut data = vec![0.0; 100];
        // Add outliers
        data[0] = -100.0;
        data[99] = 100.0;
        // Normal data
        for i in 1..99 {
            data[i] = (i as f32 - 50.0) / 50.0; // Range roughly -1 to 1
        }

        let config = QuantizationConfig {
            scheme: QuantizationScheme::Int8,
            granularity: QuantizationGranularity::PerTensor,
            symmetric: true,
            calibration: CalibrationMethod::Percentile {
                lower: 1,
                upper: 99,
            },
        };

        let quantized = QuantizedTensor::quantize_per_tensor(&data, vec![100], config).unwrap();

        // The outliers should be clipped in the calibration
        let params = &quantized.params[0];
        assert!(params.scale < 1.0); // Scale should be much less than if we used min/max with outliers
    }

    #[test]
    fn test_empty_tensor_error() {
        let data: Vec<f32> = vec![];
        let config = QuantizationConfig::int8_symmetric();
        let result = QuantizedTensor::quantize_per_tensor(&data, vec![0], config);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_channel_count() {
        let data = vec![1.0; 8];
        let config = QuantizationConfig::int8_per_channel(3); // Wrong number of channels
        let result = QuantizedTensor::quantize_per_channel(&data, vec![2, 4], config);
        assert!(result.is_err());
    }
}