Skip to main content

cuda_rust_wasm/runtime/
quantization.rs

1//! INT8/INT4 Quantization for inference acceleration
2//!
3//! Provides quantization and dequantization primitives used in neural network
4//! inference to reduce memory bandwidth and leverage integer arithmetic units.
5//! Supports symmetric and asymmetric quantization schemes.
6//!
7//! Reference: "Quantization and Training of Neural Networks for Efficient
8//! Integer-Arithmetic-Only Inference" — Jacob et al., CVPR 2018
9
10use std::fmt;
11
/// Quantization scheme: how real values are mapped onto the signed integer grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantScheme {
    /// Symmetric: `zero_point = 0`, so the representable range is
    /// `[-scale * qmax, scale * qmax]` (`qmax` is 127 for INT8, 7 for INT4).
    Symmetric,
    /// Asymmetric (affine): a generally non-zero `zero_point` shifts the grid so
    /// the full signed range (`[-128, 127]` for INT8, `[-8, 7]` for INT4) is used.
    Asymmetric,
}
20
/// Quantization bit width.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantBits {
    /// 8-bit signed integers (INT8), quantized range `[-128, 127]`.
    Int8,
    /// 4-bit signed integers (INT4), quantized range `[-8, 7]`, packed 2 per byte.
    Int4,
}
29
30/// Quantization parameters computed from calibration.
31#[derive(Debug, Clone)]
32pub struct QuantParams {
33    pub scale: f32,
34    pub zero_point: i32,
35    pub bits: QuantBits,
36    pub scheme: QuantScheme,
37    /// Per-channel scales (if per-channel quantization).
38    pub per_channel_scales: Option<Vec<f32>>,
39}
40
41impl QuantParams {
42    /// Compute quantization parameters from data range.
43    pub fn from_range(min_val: f32, max_val: f32, bits: QuantBits, scheme: QuantScheme) -> Self {
44        let (qmin, qmax) = match bits {
45            QuantBits::Int8 => (-128i32, 127i32),
46            QuantBits::Int4 => (-8i32, 7i32),
47        };
48
49        match scheme {
50            QuantScheme::Symmetric => {
51                let abs_max = min_val.abs().max(max_val.abs());
52                let scale = abs_max / qmax as f32;
53                Self {
54                    scale: if scale == 0.0 { 1.0 } else { scale },
55                    zero_point: 0,
56                    bits,
57                    scheme,
58                    per_channel_scales: None,
59                }
60            }
61            QuantScheme::Asymmetric => {
62                let range = max_val - min_val;
63                let scale = range / (qmax - qmin) as f32;
64                let zero_point = (qmin as f32 - min_val / scale).round() as i32;
65                Self {
66                    scale: if scale == 0.0 { 1.0 } else { scale },
67                    zero_point: zero_point.clamp(qmin, qmax),
68                    bits,
69                    scheme,
70                    per_channel_scales: None,
71                }
72            }
73        }
74    }
75
76    /// Compute parameters from data using min/max calibration.
77    pub fn calibrate(data: &[f32], bits: QuantBits, scheme: QuantScheme) -> Self {
78        if data.is_empty() {
79            return Self { scale: 1.0, zero_point: 0, bits, scheme, per_channel_scales: None };
80        }
81        let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
82        let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
83        Self::from_range(min_val, max_val, bits, scheme)
84    }
85}
86
87/// Quantize an f32 tensor to INT8.
88pub fn quantize_int8(data: &[f32], params: &QuantParams) -> Vec<i8> {
89    data.iter().map(|&x| {
90        let q = (x / params.scale).round() as i32 + params.zero_point;
91        q.clamp(-128, 127) as i8
92    }).collect()
93}
94
95/// Dequantize INT8 to f32.
96pub fn dequantize_int8(data: &[i8], params: &QuantParams) -> Vec<f32> {
97    data.iter().map(|&q| {
98        (q as i32 - params.zero_point) as f32 * params.scale
99    }).collect()
100}
101
102/// Quantize an f32 tensor to INT4 (packed, 2 values per byte).
103pub fn quantize_int4(data: &[f32], params: &QuantParams) -> Vec<u8> {
104    let mut packed = Vec::with_capacity((data.len() + 1) / 2);
105    for chunk in data.chunks(2) {
106        let lo = {
107            let q = (chunk[0] / params.scale).round() as i32 + params.zero_point;
108            (q.clamp(-8, 7) & 0x0F) as u8
109        };
110        let hi = if chunk.len() > 1 {
111            let q = (chunk[1] / params.scale).round() as i32 + params.zero_point;
112            ((q.clamp(-8, 7) & 0x0F) as u8) << 4
113        } else {
114            0
115        };
116        packed.push(lo | hi);
117    }
118    packed
119}
120
121/// Dequantize INT4 (packed) to f32.
122pub fn dequantize_int4(data: &[u8], count: usize, params: &QuantParams) -> Vec<f32> {
123    let mut result = Vec::with_capacity(count);
124    for &byte in data {
125        if result.len() >= count { break; }
126        // Low nibble (sign-extend from 4 bits)
127        let lo = (byte & 0x0F) as i8;
128        let lo = if lo & 0x08 != 0 { lo | !0x0F_u8 as i8 } else { lo }; // sign extend
129        result.push((lo as i32 - params.zero_point) as f32 * params.scale);
130
131        if result.len() >= count { break; }
132        // High nibble
133        let hi = ((byte >> 4) & 0x0F) as i8;
134        let hi = if hi & 0x08 != 0 { hi | !0x0F_u8 as i8 } else { hi };
135        result.push((hi as i32 - params.zero_point) as f32 * params.scale);
136    }
137    result
138}
139
140/// INT8 matrix multiply with f32 accumulation: C = A · B
141/// A: (m × k) as i8, B: (k × n) as i8, C: (m × n) as i32 → f32
142pub fn quantized_gemm_int8(
143    a: &[i8], b: &[i8],
144    m: usize, k: usize, n: usize,
145    a_params: &QuantParams, b_params: &QuantParams,
146) -> Vec<f32> {
147    let mut c = vec![0i32; m * n];
148    for i in 0..m {
149        for p in 0..k {
150            let a_val = a[i * k + p] as i32 - a_params.zero_point;
151            for j in 0..n {
152                let b_val = b[p * n + j] as i32 - b_params.zero_point;
153                c[i * n + j] += a_val * b_val;
154            }
155        }
156    }
157    // Dequantize result
158    let output_scale = a_params.scale * b_params.scale;
159    c.iter().map(|&v| v as f32 * output_scale).collect()
160}
161
162/// Compute quantization error (MSE) between original and quantized-dequantized.
163pub fn quantization_error(original: &[f32], params: &QuantParams) -> QuantError {
164    let quantized = quantize_int8(original, params);
165    let dequantized = dequantize_int8(&quantized, params);
166
167    let mse: f64 = original.iter().zip(dequantized.iter())
168        .map(|(&o, &d)| ((o - d) as f64).powi(2))
169        .sum::<f64>() / original.len() as f64;
170
171    let max_error = original.iter().zip(dequantized.iter())
172        .map(|(&o, &d)| (o - d).abs())
173        .fold(0.0f32, f32::max);
174
175    let signal_power: f64 = original.iter().map(|&x| (x as f64).powi(2)).sum::<f64>() / original.len() as f64;
176    let snr = if mse > 0.0 { 10.0 * (signal_power / mse).log10() } else { f64::INFINITY };
177
178    QuantError {
179        mse: mse as f32,
180        max_error,
181        snr_db: snr as f32,
182        compression_ratio: match params.bits {
183            QuantBits::Int8 => 4.0,  // f32 → i8
184            QuantBits::Int4 => 8.0,  // f32 → i4
185        },
186    }
187}
188
/// Quantization error statistics.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantError {
    /// Mean squared error between original and round-tripped values.
    pub mse: f32,
    /// Largest absolute error of any single value.
    pub max_error: f32,
    /// Signal-to-noise ratio in decibels (+∞ when the error is zero).
    pub snr_db: f32,
    /// Storage reduction relative to f32 (4.0 for INT8, 8.0 for INT4).
    pub compression_ratio: f32,
}
197
198impl fmt::Display for QuantError {
199    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200        write!(f, "QuantError: MSE={:.6}, MaxErr={:.4}, SNR={:.1}dB, {}x compression",
201            self.mse, self.max_error, self.snr_db, self.compression_ratio)
202    }
203}
204
205// ── Tests ──────────────────────────────────────────────────────────
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Symmetric INT8 round trip stays within 0.02 of the original values.
    #[test]
    fn test_symmetric_int8_roundtrip() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Symmetric);
        let dequantized = dequantize_int8(&quantize_int8(&data, &params), &params);

        for (i, &expected) in data.iter().enumerate() {
            let actual = dequantized[i];
            assert!((expected - actual).abs() < 0.02,
                "Mismatch at {}: original={}, dequantized={}", i, expected, actual);
        }
    }

    /// Asymmetric INT8 round trip on a non-negative range.
    #[test]
    fn test_asymmetric_int8() {
        let data = vec![0.0, 0.25, 0.5, 0.75, 1.0];
        let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Asymmetric);
        let dequantized = dequantize_int8(&quantize_int8(&data, &params), &params);

        for (i, &expected) in data.iter().enumerate() {
            let actual = dequantized[i];
            assert!((expected - actual).abs() < 0.02,
                "Asymmetric mismatch at {}: {} vs {}", i, expected, actual);
        }
    }

    /// INT4 pack/unpack round trip with a coarser tolerance.
    #[test]
    fn test_int4_quantization() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0, 1.5];
        let params = QuantParams::calibrate(&data, QuantBits::Int4, QuantScheme::Symmetric);
        let packed = quantize_int4(&data, &params);
        let dequantized = dequantize_int4(&packed, data.len(), &params);

        assert_eq!(dequantized.len(), data.len());
        // INT4 has less precision, allow larger tolerance
        for (i, &expected) in data.iter().enumerate() {
            let actual = dequantized[i];
            assert!((expected - actual).abs() < 0.5,
                "INT4 mismatch at {}: {} vs {}", i, expected, actual);
        }
    }

    /// Four values must pack into exactly two bytes.
    #[test]
    fn test_int4_packing() {
        let params = QuantParams::from_range(-1.0, 1.0, QuantBits::Int4, QuantScheme::Symmetric);
        let data = vec![0.0, 0.0, 0.0, 0.0]; // 4 values → 2 bytes
        assert_eq!(quantize_int4(&data, &params).len(), 2);
    }

    /// Quantized 2×2 GEMM approximates the f32 product.
    #[test]
    fn test_quantized_gemm() {
        // A: 2×2, B: 2×2
        let a_f32 = vec![1.0f32, 2.0, 3.0, 4.0];
        let b_f32 = vec![5.0f32, 6.0, 7.0, 8.0];

        let a_params = QuantParams::calibrate(&a_f32, QuantBits::Int8, QuantScheme::Symmetric);
        let a_q = quantize_int8(&a_f32, &a_params);
        let b_params = QuantParams::calibrate(&b_f32, QuantBits::Int8, QuantScheme::Symmetric);
        let b_q = quantize_int8(&b_f32, &b_params);

        let c = quantized_gemm_int8(&a_q, &b_q, 2, 2, 2, &a_params, &b_params);
        // Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]] = [[19, 22], [43, 50]]
        assert!((c[0] - 19.0).abs() < 1.0, "Got {}", c[0]);
        assert!((c[1] - 22.0).abs() < 1.0, "Got {}", c[1]);
        assert!((c[2] - 43.0).abs() < 1.5, "Got {}", c[2]);
        assert!((c[3] - 50.0).abs() < 1.5, "Got {}", c[3]);
    }

    /// INT8 error statistics on a linear ramp in [-1, 1).
    #[test]
    fn test_quantization_error() {
        let data: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) / 50.0).collect();
        let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Symmetric);
        let error = quantization_error(&data, &params);

        assert!(error.mse < 0.001, "MSE too high: {}", error.mse);
        assert!(error.snr_db > 30.0, "SNR too low: {}dB", error.snr_db);
        assert_eq!(error.compression_ratio, 4.0);
    }

    /// INT4 parameters report an 8x compression ratio.
    #[test]
    fn test_quantization_error_int4() {
        let data: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) / 50.0).collect();
        let params = QuantParams::calibrate(&data, QuantBits::Int4, QuantScheme::Symmetric);
        let error = quantization_error(&data, &params);
        // INT4 will have higher error than INT8
        assert_eq!(error.compression_ratio, 8.0);
    }

    /// All-zero data must not produce a zero scale.
    #[test]
    fn test_zero_range_calibration() {
        let data = vec![0.0, 0.0, 0.0];
        let params = QuantParams::calibrate(&data, QuantBits::Int8, QuantScheme::Symmetric);
        assert_eq!(params.scale, 1.0); // Should not be zero
    }

    /// Empty data falls back to the identity scale.
    #[test]
    fn test_empty_calibration() {
        let params = QuantParams::calibrate(&[], QuantBits::Int8, QuantScheme::Symmetric);
        assert_eq!(params.scale, 1.0);
    }

    /// The Display output mentions the MSE and the compression factor.
    #[test]
    fn test_quant_error_display() {
        let error = QuantError { mse: 0.001, max_error: 0.01, snr_db: 40.0, compression_ratio: 4.0 };
        let s = error.to_string();
        assert!(s.contains("MSE"));
        assert!(s.contains("4x"));
    }
}