kizzasi_model/
quantization.rs

//! Weight and Activation Quantization for Efficient Inference
//!
//! Provides INT8 quantization of model weights (and, via [`ActivationQuantizer`],
//! of activations) to reduce memory usage and improve inference speed on
//! integer-optimized hardware.
//!
//! # Features
//!
//! - **INT8 quantization**: Symmetric and asymmetric quantization
//! - **Per-tensor quantization**: Single scale/zero-point for an entire tensor
//! - **Per-channel quantization**: Independent scale/zero-point per output channel
//! - **Calibration**: Automatic scale/zero-point computation from observed data
//! - **Mixed precision**: Selective quantization of layers
//!
//! # Theory
//!
//! Quantization maps floating-point values to integers:
//! ```text
//! q = clamp(round(x / scale) + zero_point, -128, 127)
//! x_approx = (q - zero_point) * scale
//! ```
//!
//! For symmetric quantization (zero_point = 0):
//! ```text
//! scale = max(|x|) / 127
//! q = clamp(round(x / scale), -128, 127)
//! ```
//!
//! For asymmetric quantization the full INT8 range `[-128, 127]` covers `[min, max]`:
//! ```text
//! scale = (max - min) / 255
//! zero_point = round(-128 - min / scale)
//! ```
27
28use crate::error::{ModelError, ModelResult};
29use scirs2_core::ndarray::{Array1, Array2};
30
/// How real values are mapped onto the INT8 grid.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QuantizationMethod {
    /// The zero point is pinned at 0, so the representable range is centred on zero.
    Symmetric,
    /// The zero point is chosen freely so that an arbitrary `[min, max]` range fits.
    Asymmetric,
}
39
/// How many scale/zero-point pairs a quantized tensor carries.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QuantizationGranularity {
    /// One scale/zero-point pair shared by every element.
    PerTensor,
    /// One scale/zero-point pair per output channel (index along dim 0).
    PerChannel,
}
48
49/// Quantization parameters
50#[derive(Debug, Clone)]
51pub struct QuantizationParams {
52    /// Scale factor(s)
53    pub scale: Vec<f32>,
54    /// Zero point(s)
55    pub zero_point: Vec<i8>,
56    /// Quantization method
57    pub method: QuantizationMethod,
58    /// Granularity
59    pub granularity: QuantizationGranularity,
60}
61
62impl QuantizationParams {
63    /// Create symmetric per-tensor quantization params
64    pub fn symmetric_per_tensor(scale: f32) -> Self {
65        Self {
66            scale: vec![scale],
67            zero_point: vec![0],
68            method: QuantizationMethod::Symmetric,
69            granularity: QuantizationGranularity::PerTensor,
70        }
71    }
72
73    /// Create asymmetric per-tensor quantization params
74    pub fn asymmetric_per_tensor(scale: f32, zero_point: i8) -> Self {
75        Self {
76            scale: vec![scale],
77            zero_point: vec![zero_point],
78            method: QuantizationMethod::Asymmetric,
79            granularity: QuantizationGranularity::PerTensor,
80        }
81    }
82
83    /// Create symmetric per-channel quantization params
84    pub fn symmetric_per_channel(scales: Vec<f32>) -> Self {
85        let n = scales.len();
86        Self {
87            scale: scales,
88            zero_point: vec![0; n],
89            method: QuantizationMethod::Symmetric,
90            granularity: QuantizationGranularity::PerChannel,
91        }
92    }
93
94    /// Validate parameters
95    pub fn validate(&self) -> ModelResult<()> {
96        if self.scale.is_empty() {
97            return Err(ModelError::invalid_config("scale cannot be empty"));
98        }
99        if self.scale.len() != self.zero_point.len() {
100            return Err(ModelError::invalid_config(
101                "scale and zero_point must have same length",
102            ));
103        }
104        for &s in &self.scale {
105            if s <= 0.0 || !s.is_finite() {
106                return Err(ModelError::invalid_config(format!("invalid scale: {}", s)));
107            }
108        }
109        Ok(())
110    }
111}
112
113/// Quantized weight tensor
114#[derive(Debug, Clone)]
115pub struct QuantizedWeight {
116    /// Quantized data (INT8)
117    pub data: Vec<i8>,
118    /// Original shape
119    pub shape: Vec<usize>,
120    /// Quantization parameters
121    pub params: QuantizationParams,
122}
123
124impl QuantizedWeight {
125    /// Create a new quantized weight
126    pub fn new(data: Vec<i8>, shape: Vec<usize>, params: QuantizationParams) -> ModelResult<Self> {
127        params.validate()?;
128
129        let total_size: usize = shape.iter().product();
130        if data.len() != total_size {
131            return Err(ModelError::invalid_config(format!(
132                "data length {} does not match shape {:?}",
133                data.len(),
134                shape
135            )));
136        }
137
138        Ok(Self {
139            data,
140            shape,
141            params,
142        })
143    }
144
145    /// Dequantize to f32 array (1D)
146    pub fn dequantize_1d(&self) -> ModelResult<Array1<f32>> {
147        if self.shape.len() != 1 {
148            return Err(ModelError::invalid_config(format!(
149                "expected 1D shape, got {:?}",
150                self.shape
151            )));
152        }
153
154        let n = self.shape[0];
155        let mut result = Array1::zeros(n);
156
157        match self.params.granularity {
158            QuantizationGranularity::PerTensor => {
159                let scale = self.params.scale[0];
160                let zero_point = self.params.zero_point[0];
161
162                for i in 0..n {
163                    result[i] = (self.data[i] as i32 - zero_point as i32) as f32 * scale;
164                }
165            }
166            QuantizationGranularity::PerChannel => {
167                return Err(ModelError::invalid_config(
168                    "per-channel quantization not supported for 1D tensors",
169                ));
170            }
171        }
172
173        Ok(result)
174    }
175
176    /// Dequantize to f32 array (2D)
177    pub fn dequantize_2d(&self) -> ModelResult<Array2<f32>> {
178        if self.shape.len() != 2 {
179            return Err(ModelError::invalid_config(format!(
180                "expected 2D shape, got {:?}",
181                self.shape
182            )));
183        }
184
185        let (rows, cols) = (self.shape[0], self.shape[1]);
186        let mut result = Array2::zeros((rows, cols));
187
188        match self.params.granularity {
189            QuantizationGranularity::PerTensor => {
190                let scale = self.params.scale[0];
191                let zero_point = self.params.zero_point[0];
192
193                for i in 0..rows {
194                    for j in 0..cols {
195                        let idx = i * cols + j;
196                        result[[i, j]] = (self.data[idx] as i32 - zero_point as i32) as f32 * scale;
197                    }
198                }
199            }
200            QuantizationGranularity::PerChannel => {
201                // Per-channel quantization: independent scale/zero_point per row (output channel)
202                if self.params.scale.len() != rows {
203                    return Err(ModelError::invalid_config(format!(
204                        "expected {} scales for per-channel, got {}",
205                        rows,
206                        self.params.scale.len()
207                    )));
208                }
209
210                for i in 0..rows {
211                    let scale = self.params.scale[i];
212                    let zero_point = self.params.zero_point[i];
213
214                    for j in 0..cols {
215                        let idx = i * cols + j;
216                        result[[i, j]] = (self.data[idx] as i32 - zero_point as i32) as f32 * scale;
217                    }
218                }
219            }
220        }
221
222        Ok(result)
223    }
224
225    /// Get memory size in bytes
226    pub fn memory_size(&self) -> usize {
227        self.data.len() // INT8: 1 byte per element
228    }
229}
230
231/// Quantize f32 array to INT8 using symmetric quantization
232///
233/// # Arguments
234///
235/// * `array` - Input array
236///
237/// # Returns
238///
239/// Quantized weight with symmetric per-tensor quantization
240pub fn quantize_symmetric_1d(array: &Array1<f32>) -> ModelResult<QuantizedWeight> {
241    let max_val = array.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
242
243    if max_val == 0.0 {
244        // All zeros - use scale of 1.0 to avoid division by zero
245        let data = vec![0i8; array.len()];
246        let params = QuantizationParams::symmetric_per_tensor(1.0);
247        return QuantizedWeight::new(data, vec![array.len()], params);
248    }
249
250    let scale = max_val / 127.0;
251    let mut data = Vec::with_capacity(array.len());
252
253    for &x in array.iter() {
254        let q = (x / scale).round() as i32;
255        let q_clamped = q.clamp(-128, 127) as i8;
256        data.push(q_clamped);
257    }
258
259    let params = QuantizationParams::symmetric_per_tensor(scale);
260    QuantizedWeight::new(data, vec![array.len()], params)
261}
262
263/// Quantize f32 2D array to INT8 using symmetric per-tensor quantization
264pub fn quantize_symmetric_2d(array: &Array2<f32>) -> ModelResult<QuantizedWeight> {
265    let max_val = array.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
266
267    if max_val == 0.0 {
268        let (rows, cols) = array.dim();
269        let data = vec![0i8; rows * cols];
270        let params = QuantizationParams::symmetric_per_tensor(1.0);
271        return QuantizedWeight::new(data, vec![rows, cols], params);
272    }
273
274    let scale = max_val / 127.0;
275    let (rows, cols) = array.dim();
276    let mut data = Vec::with_capacity(rows * cols);
277
278    for i in 0..rows {
279        for j in 0..cols {
280            let x = array[[i, j]];
281            let q = (x / scale).round() as i32;
282            let q_clamped = q.clamp(-128, 127) as i8;
283            data.push(q_clamped);
284        }
285    }
286
287    let params = QuantizationParams::symmetric_per_tensor(scale);
288    QuantizedWeight::new(data, vec![rows, cols], params)
289}
290
291/// Quantize f32 2D array to INT8 using symmetric per-channel quantization
292///
293/// Each output channel (row) has independent scale factor
294pub fn quantize_symmetric_per_channel(array: &Array2<f32>) -> ModelResult<QuantizedWeight> {
295    let (rows, cols) = array.dim();
296    let mut scales = Vec::with_capacity(rows);
297    let mut data = Vec::with_capacity(rows * cols);
298
299    // Compute scale per row (output channel)
300    for i in 0..rows {
301        let row = array.row(i);
302        let max_val = row.iter().map(|&x| x.abs()).fold(0.0f32, f32::max);
303
304        let scale = if max_val == 0.0 {
305            1.0 // Avoid division by zero
306        } else {
307            max_val / 127.0
308        };
309
310        scales.push(scale);
311    }
312
313    // Quantize each row with its scale
314    for i in 0..rows {
315        let scale = scales[i];
316        for j in 0..cols {
317            let x = array[[i, j]];
318            let q = (x / scale).round() as i32;
319            let q_clamped = q.clamp(-128, 127) as i8;
320            data.push(q_clamped);
321        }
322    }
323
324    let params = QuantizationParams::symmetric_per_channel(scales);
325    QuantizedWeight::new(data, vec![rows, cols], params)
326}
327
328/// Quantize f32 array to INT8 using asymmetric quantization
329///
330/// Computes optimal scale and zero-point from min/max values
331pub fn quantize_asymmetric_1d(array: &Array1<f32>) -> ModelResult<QuantizedWeight> {
332    let min_val = array.iter().copied().fold(f32::INFINITY, f32::min);
333    let max_val = array.iter().copied().fold(f32::NEG_INFINITY, f32::max);
334
335    if (max_val - min_val).abs() < 1e-8 {
336        // Constant array
337        let data = vec![0i8; array.len()];
338        let params = QuantizationParams::asymmetric_per_tensor(1.0, 0);
339        return QuantizedWeight::new(data, vec![array.len()], params);
340    }
341
342    let scale = (max_val - min_val) / 255.0;
343    let zero_point_f = -128.0 - min_val / scale;
344    let zero_point = zero_point_f.round().clamp(-128.0, 127.0) as i8;
345
346    let mut data = Vec::with_capacity(array.len());
347
348    for &x in array.iter() {
349        let q_f = x / scale + zero_point as f32;
350        let q = q_f.round().clamp(-128.0, 127.0) as i8;
351        data.push(q);
352    }
353
354    let params = QuantizationParams::asymmetric_per_tensor(scale, zero_point);
355    QuantizedWeight::new(data, vec![array.len()], params)
356}
357
/// Running range statistics gathered from sample data to calibrate quantization.
#[derive(Clone, Debug)]
pub struct CalibrationStats {
    /// Smallest value seen so far.
    pub min: f32,
    /// Largest value seen so far.
    pub max: f32,
    /// Total number of elements folded in.
    pub count: usize,
}
368
369impl CalibrationStats {
370    /// Create new calibration stats
371    pub fn new() -> Self {
372        Self {
373            min: f32::INFINITY,
374            max: f32::NEG_INFINITY,
375            count: 0,
376        }
377    }
378
379    /// Update statistics with new data
380    pub fn update_1d(&mut self, data: &Array1<f32>) {
381        for &x in data.iter() {
382            self.min = self.min.min(x);
383            self.max = self.max.max(x);
384        }
385        self.count += data.len();
386    }
387
388    /// Update statistics with 2D data
389    pub fn update_2d(&mut self, data: &Array2<f32>) {
390        for &x in data.iter() {
391            self.min = self.min.min(x);
392            self.max = self.max.max(x);
393        }
394        self.count += data.len();
395    }
396
397    /// Compute symmetric quantization params from statistics
398    pub fn to_symmetric_params(&self) -> ModelResult<QuantizationParams> {
399        let max_abs = self.max.abs().max(self.min.abs());
400        if max_abs == 0.0 {
401            Ok(QuantizationParams::symmetric_per_tensor(1.0))
402        } else {
403            Ok(QuantizationParams::symmetric_per_tensor(max_abs / 127.0))
404        }
405    }
406
407    /// Compute asymmetric quantization params from statistics
408    pub fn to_asymmetric_params(&self) -> ModelResult<QuantizationParams> {
409        if (self.max - self.min).abs() < 1e-8 {
410            Ok(QuantizationParams::asymmetric_per_tensor(1.0, 0))
411        } else {
412            let scale = (self.max - self.min) / 255.0;
413            let zero_point_f = -128.0 - self.min / scale;
414            let zero_point = zero_point_f.round().clamp(-128.0, 127.0) as i8;
415            Ok(QuantizationParams::asymmetric_per_tensor(scale, zero_point))
416        }
417    }
418}
419
420impl Default for CalibrationStats {
421    fn default() -> Self {
422        Self::new()
423    }
424}
425
426/// Activation quantization for dynamic runtime quantization
427///
428/// Quantizes activations on-the-fly during inference for additional memory savings
429#[derive(Debug, Clone)]
430pub struct ActivationQuantizer {
431    /// Method to use
432    method: QuantizationMethod,
433    /// Granularity
434    #[allow(dead_code)]
435    granularity: QuantizationGranularity,
436    /// Calibration statistics (optional, for static quantization)
437    calibration: Option<QuantizationParams>,
438}
439
440impl ActivationQuantizer {
441    /// Create a new activation quantizer with symmetric per-tensor quantization
442    pub fn new_symmetric() -> Self {
443        Self {
444            method: QuantizationMethod::Symmetric,
445            granularity: QuantizationGranularity::PerTensor,
446            calibration: None,
447        }
448    }
449
450    /// Create a new activation quantizer with asymmetric per-tensor quantization
451    pub fn new_asymmetric() -> Self {
452        Self {
453            method: QuantizationMethod::Asymmetric,
454            granularity: QuantizationGranularity::PerTensor,
455            calibration: None,
456        }
457    }
458
459    /// Set calibration parameters from statistics
460    pub fn calibrate(&mut self, stats: &CalibrationStats) -> ModelResult<()> {
461        self.calibration = Some(match self.method {
462            QuantizationMethod::Symmetric => stats.to_symmetric_params()?,
463            QuantizationMethod::Asymmetric => stats.to_asymmetric_params()?,
464        });
465        Ok(())
466    }
467
468    /// Quantize activation (1D)
469    pub fn quantize_activation_1d(&self, activation: &Array1<f32>) -> ModelResult<Vec<i8>> {
470        let params = if let Some(ref cal) = self.calibration {
471            // Use calibrated parameters
472            cal.clone()
473        } else {
474            // Compute parameters on-the-fly
475            let min_val = activation.iter().copied().fold(f32::INFINITY, f32::min);
476            let max_val = activation.iter().copied().fold(f32::NEG_INFINITY, f32::max);
477
478            match self.method {
479                QuantizationMethod::Symmetric => {
480                    let max_abs = max_val.abs().max(min_val.abs());
481                    QuantizationParams::symmetric_per_tensor(max_abs / 127.0)
482                }
483                QuantizationMethod::Asymmetric => {
484                    let scale = (max_val - min_val) / 255.0;
485                    let zero_point = (-128.0 - min_val / scale).round().clamp(-128.0, 127.0) as i8;
486                    QuantizationParams::asymmetric_per_tensor(scale, zero_point)
487                }
488            }
489        };
490
491        let scale = params.scale[0];
492        let zero_point = params.zero_point[0];
493
494        let mut quantized = Vec::with_capacity(activation.len());
495        for &x in activation.iter() {
496            let q = match self.method {
497                QuantizationMethod::Symmetric => (x / scale).round().clamp(-128.0, 127.0) as i8,
498                QuantizationMethod::Asymmetric => {
499                    let q_f = x / scale + zero_point as f32;
500                    q_f.round().clamp(-128.0, 127.0) as i8
501                }
502            };
503            quantized.push(q);
504        }
505
506        Ok(quantized)
507    }
508
509    /// Dequantize activation (1D)
510    pub fn dequantize_activation_1d(
511        &self,
512        quantized: &[i8],
513        original_len: usize,
514    ) -> ModelResult<Array1<f32>> {
515        if quantized.len() != original_len {
516            return Err(ModelError::invalid_config(format!(
517                "quantized length {} doesn't match expected {}",
518                quantized.len(),
519                original_len
520            )));
521        }
522
523        let params = self
524            .calibration
525            .as_ref()
526            .ok_or_else(|| ModelError::invalid_config("calibration required for dequantization"))?;
527
528        let scale = params.scale[0];
529        let zero_point = params.zero_point[0];
530
531        let mut result = Array1::zeros(original_len);
532        for (i, &q) in quantized.iter().enumerate() {
533            result[i] = (q as i32 - zero_point as i32) as f32 * scale;
534        }
535
536        Ok(result)
537    }
538
539    /// Quantize and immediately dequantize (simulates quantization error)
540    pub fn simulate_quantization(&self, activation: &Array1<f32>) -> ModelResult<Array1<f32>> {
541        // Compute quantization parameters
542        let min_val = activation.iter().copied().fold(f32::INFINITY, f32::min);
543        let max_val = activation.iter().copied().fold(f32::NEG_INFINITY, f32::max);
544
545        let (scale, zero_point) = match self.method {
546            QuantizationMethod::Symmetric => {
547                let max_abs = max_val.abs().max(min_val.abs());
548                (max_abs / 127.0, 0)
549            }
550            QuantizationMethod::Asymmetric => {
551                let scale = (max_val - min_val) / 255.0;
552                let zp = (-128.0 - min_val / scale).round().clamp(-128.0, 127.0) as i8;
553                (scale, zp)
554            }
555        };
556
557        // Quantize and dequantize in one pass
558        let mut result = Array1::zeros(activation.len());
559        for (i, &x) in activation.iter().enumerate() {
560            let q = match self.method {
561                QuantizationMethod::Symmetric => (x / scale).round().clamp(-128.0, 127.0) as i8,
562                QuantizationMethod::Asymmetric => {
563                    let q_f = x / scale + zero_point as f32;
564                    q_f.round().clamp(-128.0, 127.0) as i8
565                }
566            };
567            result[i] = (q as i32 - zero_point as i32) as f32 * scale;
568        }
569
570        Ok(result)
571    }
572
573    /// Get memory savings estimate (percentage)
574    pub fn memory_savings(&self) -> f32 {
575        // INT8 uses 1 byte vs FP32's 4 bytes
576        75.0 // 75% savings
577    }
578}
579
580impl Default for ActivationQuantizer {
581    fn default() -> Self {
582        Self::new_symmetric()
583    }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `a` and `b` differ by strictly less than `epsilon`.
    fn approx_eq(a: f32, b: f32, epsilon: f32) -> bool {
        (a - b).abs() < epsilon
    }

    /// Assert that two sequences have equal length and agree element-wise within `epsilon`.
    fn assert_close<'a>(
        expected: impl IntoIterator<Item = &'a f32>,
        actual: impl IntoIterator<Item = &'a f32>,
        epsilon: f32,
    ) {
        let expected: Vec<f32> = expected.into_iter().copied().collect();
        let actual: Vec<f32> = actual.into_iter().copied().collect();
        assert_eq!(expected.len(), actual.len(), "length mismatch");
        for (&want, &got) in expected.iter().zip(actual.iter()) {
            assert!(
                approx_eq(want, got, epsilon),
                "expected {}, got {}",
                want,
                got
            );
        }
    }

    #[test]
    fn test_symmetric_quantization_1d() {
        let original = Array1::from_vec(vec![-10.0, -5.0, 0.0, 5.0, 10.0]);
        let q = quantize_symmetric_1d(&original).expect("Failed to quantize 1d array");

        assert_eq!(q.shape, vec![5]);
        assert_eq!(q.params.method, QuantizationMethod::Symmetric);

        // The round trip must stay within tolerance of the input.
        let restored = q.dequantize_1d().expect("Failed to dequantize 1d array");
        assert_close(original.iter(), restored.iter(), 0.1);
    }

    #[test]
    fn test_symmetric_quantization_2d() {
        let original = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])
            .expect("Failed to create test array");
        let q = quantize_symmetric_2d(&original).expect("Failed to quantize 2d array");

        assert_eq!(q.shape, vec![2, 3]);

        let restored = q.dequantize_2d().expect("Failed to dequantize 2d array");
        assert_close(original.iter(), restored.iter(), 0.05);
    }

    #[test]
    fn test_per_channel_quantization() {
        // The two rows span very different ranges.
        let original = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 10.0, 20.0, 30.0])
            .expect("Failed to create test array");
        let q = quantize_symmetric_per_channel(&original).expect("Failed to quantize per channel");

        assert_eq!(q.params.granularity, QuantizationGranularity::PerChannel);
        assert_eq!(q.params.scale.len(), 2);

        // Per-row scales keep the small-range row accurate.
        let restored = q.dequantize_2d().expect("Failed to dequantize 2d array");
        assert_close(original.iter(), restored.iter(), 0.3);
    }

    #[test]
    fn test_asymmetric_quantization() {
        let original = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let q = quantize_asymmetric_1d(&original).expect("Failed to quantize asymmetric");

        assert_eq!(q.params.method, QuantizationMethod::Asymmetric);

        let restored = q.dequantize_1d().expect("Failed to dequantize");
        assert_close(original.iter(), restored.iter(), 0.05);
    }

    #[test]
    fn test_calibration_stats() {
        let mut stats = CalibrationStats::new();
        stats.update_1d(&Array1::from_vec(vec![-5.0, 0.0, 5.0]));
        stats.update_1d(&Array1::from_vec(vec![-10.0, -2.0, 8.0]));

        assert_eq!(stats.min, -10.0);
        assert_eq!(stats.max, 8.0);
        assert_eq!(stats.count, 6);

        let params = stats.to_symmetric_params().expect("Failed to get params");
        assert!(approx_eq(params.scale[0], 10.0 / 127.0, 1e-6));
    }

    #[test]
    fn test_memory_savings() {
        let ones = Array2::from_shape_vec((100, 100), vec![1.0; 10000])
            .expect("Failed to create test array");
        let q = quantize_symmetric_2d(&ones).expect("Failed to quantize");

        // INT8 stores 1 byte per element; FP32 stores 4.
        let fp32_bytes = 10000 * 4;
        let int8_bytes = q.memory_size();

        assert_eq!(int8_bytes, 10000);
        // More than 3x smaller (exactly 4x for INT8 vs FP32).
        assert!(int8_bytes < fp32_bytes / 3);
    }

    #[test]
    fn test_activation_quantizer_symmetric() {
        let quantizer = ActivationQuantizer::new_symmetric();
        let activation = Array1::from_vec(vec![-10.0, -5.0, 0.0, 5.0, 10.0]);

        let codes = quantizer
            .quantize_activation_1d(&activation)
            .expect("Failed to quantize activation");
        assert_eq!(codes.len(), activation.len());

        assert_eq!(quantizer.memory_savings(), 75.0);
    }

    #[test]
    fn test_activation_quantizer_asymmetric() {
        let quantizer = ActivationQuantizer::new_asymmetric();
        let activation = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let codes = quantizer
            .quantize_activation_1d(&activation)
            .expect("Failed to quantize activation");
        assert_eq!(codes.len(), activation.len());
    }

    #[test]
    fn test_activation_quantizer_with_calibration() {
        let mut quantizer = ActivationQuantizer::new_symmetric();

        // Gather range statistics from sample activations.
        let mut stats = CalibrationStats::new();
        stats.update_1d(&Array1::from_vec(vec![-10.0, 0.0, 10.0]));
        stats.update_1d(&Array1::from_vec(vec![-5.0, 0.0, 5.0]));

        quantizer.calibrate(&stats).expect("Failed to calibrate");

        // Round-trip an activation through the calibrated parameters.
        let activation = Array1::from_vec(vec![-8.0, 0.0, 8.0]);
        let codes = quantizer
            .quantize_activation_1d(&activation)
            .expect("Failed to quantize activation");
        let restored = quantizer
            .dequantize_activation_1d(&codes, activation.len())
            .expect("Failed to dequantize activation");

        assert_close(activation.iter(), restored.iter(), 1.0);
    }

    #[test]
    fn test_simulate_quantization() {
        let quantizer = ActivationQuantizer::new_symmetric();
        let activation = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let simulated = quantizer
            .simulate_quantization(&activation)
            .expect("Failed to simulate quantization");

        // Simulated values must stay close to the originals.
        assert_close(activation.iter(), simulated.iter(), 0.1);
    }
}