// npu_rs/quantization.rs

1use crate::error::{NpuError, Result};
2use ndarray::ArrayD;
3
/// Quantization statistics for calibration.
///
/// Produced by `QuantStats::from_tensor` and consumed when deriving a quantization
/// scale and zero point.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct QuantStats {
    /// Smallest observed value.
    pub min_val: f32,
    /// Largest observed value.
    pub max_val: f32,
    /// Arithmetic mean of the observed values.
    pub mean_val: f32,
    /// Population standard deviation (variance divided by the element count).
    pub std_val: f32,
}
12
13impl QuantStats {
14    /// Compute statistics from a tensor.
15    pub fn from_tensor(data: &ArrayD<f32>) -> Self {
16        let values: Vec<f32> = data.iter().cloned().collect();
17        
18        let min_val = values.iter().cloned().fold(f32::INFINITY, f32::min);
19        let max_val = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
20        
21        let mean_val = values.iter().sum::<f32>() / values.len() as f32;
22        let variance = values.iter()
23            .map(|v| (v - mean_val).powi(2))
24            .sum::<f32>() / values.len() as f32;
25        let std_val = variance.sqrt();
26
27        Self {
28            min_val,
29            max_val,
30            mean_val,
31            std_val,
32        }
33    }
34
35    /// Get scale factor for quantization.
36    pub fn get_scale(&self, num_bits: u32) -> f32 {
37        let levels = (1u64 << num_bits) as f32;
38        self.max_val / (levels - 1.0)
39    }
40
41    /// Get zero point.
42    pub fn get_zero_point(&self, num_bits: u32, signed: bool) -> i32 {
43        let levels = (1u64 << num_bits) as f32;
44        if signed {
45            (-(levels / 2.0)) as i32
46        } else {
47            0
48        }
49    }
50}
51
52/// Quantization converter.
53pub struct QuantConverter {
54    scale: f32,
55    zero_point: i32,
56    num_bits: u32,
57}
58
59impl QuantConverter {
60    /// Create a new quantization converter.
61    pub fn new(stats: &QuantStats, num_bits: u32, signed: bool) -> Self {
62        Self {
63            scale: stats.get_scale(num_bits),
64            zero_point: stats.get_zero_point(num_bits, signed),
65            num_bits,
66        }
67    }
68
69    /// Quantize float32 to integer.
70    pub fn quantize(&self, value: f32) -> i32 {
71        let quantized = (value / self.scale) as i32 + self.zero_point;
72        let max_val = (1i64 << self.num_bits) as i32 - 1;
73        let min_val = -(1i64 << (self.num_bits - 1)) as i32;
74        quantized.max(min_val).min(max_val)
75    }
76
77    /// Dequantize integer to float32.
78    pub fn dequantize(&self, value: i32) -> f32 {
79        ((value - self.zero_point) as f32) * self.scale
80    }
81
82    /// Quantize entire tensor.
83    pub fn quantize_tensor(&self, tensor: &ArrayD<f32>) -> Result<Vec<i32>> {
84        Ok(tensor.iter().map(|&v| self.quantize(v)).collect())
85    }
86
87    /// Dequantize tensor.
88    pub fn dequantize_tensor(&self, quantized: &[i32]) -> Result<ArrayD<f32>> {
89        let values: Vec<f32> = quantized.iter().map(|&v| self.dequantize(v)).collect();
90        Ok(ArrayD::from_shape_vec(
91            ndarray::IxDyn(&[quantized.len()]),
92            values,
93        ).map_err(|_| NpuError::InvalidShape("Failed to reshape".to_string()))?)
94    }
95}
96
/// Post-training quantization engine.
///
/// Holds the target bit width and signedness; `PTQEngine::calibrate` derives a
/// `QuantConverter` from sample data using these settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PTQEngine {
    /// Target code width in bits, forwarded to the converter.
    num_bits: u32,
    /// Whether codes are signed, forwarded to the converter.
    signed: bool,
}
102
103impl PTQEngine {
104    /// Create a new PTQ engine.
105    pub fn new(num_bits: u32, signed: bool) -> Self {
106        Self { num_bits, signed }
107    }
108
109    /// Calibrate on sample data.
110    pub fn calibrate(&self, sample_data: &[ArrayD<f32>]) -> Result<QuantConverter> {
111        if sample_data.is_empty() {
112            return Err(NpuError::InvalidConfiguration(
113                "No calibration data provided".to_string(),
114            ));
115        }
116
117        let mut all_values = Vec::new();
118        for tensor in sample_data {
119            all_values.extend(tensor.iter().cloned());
120        }
121
122        let combined = ArrayD::from_shape_vec(
123            ndarray::IxDyn(&[all_values.len()]),
124            all_values,
125        ).map_err(|_| NpuError::InvalidShape("Failed to calibrate".to_string()))?;
126
127        let stats = QuantStats::from_tensor(&combined);
128        Ok(QuantConverter::new(&stats, self.num_bits, self.signed))
129    }
130}