oxirs_embed/quantization.rs

//! Quantization Support for Model Compression
//!
//! This module provides quantization techniques to compress knowledge graph
//! embeddings by reducing precision from float32 to int8/int4, significantly
//! reducing model size and improving inference speed.
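//!
//! A minimal end-to-end sketch (hypothetical usage, assuming this module is
//! exposed as `oxirs_embed::quantization`):
//!
//! ```ignore
//! use oxirs_embed::quantization::{ModelQuantizer, QuantizationConfig};
//!
//! // Default config: symmetric int8, weights-only, calibration enabled.
//! let mut quantizer = ModelQuantizer::new(QuantizationConfig::default());
//! let quantized = quantizer.quantize_embeddings(&embeddings)?;
//! let stats = quantizer.get_stats();
//! println!("compression: {:.2}x", stats.compression_ratio);
//! ```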

use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info};

/// Quantization scheme
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// Symmetric quantization (zero point = 0)
    Symmetric,
    /// Asymmetric quantization (learnable zero point)
    Asymmetric,
    /// Per-channel quantization
    PerChannel,
    /// Per-tensor quantization
    PerTensor,
}

/// Quantization bit width
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BitWidth {
    /// 8-bit quantization
    Int8,
    /// 4-bit quantization
    Int4,
    /// Binary quantization (1-bit)
    Binary,
}

impl BitWidth {
    /// Get quantization range
    pub fn range(&self) -> (i32, i32) {
        match self {
            BitWidth::Int8 => (-128, 127),
            BitWidth::Int4 => (-8, 7),
            BitWidth::Binary => (0, 1),
        }
    }

    /// Get number of bits
    pub fn bits(&self) -> usize {
        match self {
            BitWidth::Int8 => 8,
            BitWidth::Int4 => 4,
            BitWidth::Binary => 1,
        }
    }
}

/// Quantization configuration
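///
/// A config can be customized via struct update syntax, e.g. a hypothetical
/// 4-bit setup:
///
/// ```ignore
/// let config = QuantizationConfig {
///     bit_width: BitWidth::Int4,
///     ..Default::default()
/// };
/// ```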
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Quantization scheme to use
    pub scheme: QuantizationScheme,
    /// Bit width for quantization
    pub bit_width: BitWidth,
    /// Enable calibration for better quantization
    pub calibration: bool,
    /// Number of calibration samples
    pub calibration_samples: usize,
    /// Quantize only weights (keep activations in float)
    pub weights_only: bool,
    /// Use quantization-aware training
    pub qat: bool,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            scheme: QuantizationScheme::Symmetric,
            bit_width: BitWidth::Int8,
            calibration: true,
            calibration_samples: 1000,
            weights_only: true,
            qat: false,
        }
    }
}

/// Quantization parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Scale factor
    pub scale: f32,
    /// Zero point
    pub zero_point: i32,
    /// Min value observed during calibration
    pub min_val: f32,
    /// Max value observed during calibration
    pub max_val: f32,
}

impl QuantizationParams {
    /// Compute quantization parameters from tensor statistics
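    ///
    /// Asymmetric: `scale = (max - min) / (qmax - qmin)` and
    /// `zero_point = qmin - round(min / scale)`; symmetric fixes the zero
    /// point at 0 and scales by the largest absolute value. For example,
    /// `Int8` over `[-10.0, 10.0]` yields `scale = 20 / 255 ≈ 0.0784` with a
    /// zero point of 0 under either scheme.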
    pub fn from_statistics(
        min_val: f32,
        max_val: f32,
        bit_width: BitWidth,
        symmetric: bool,
    ) -> Self {
        let (qmin, qmax) = bit_width.range();

        let (scale, zero_point) = if symmetric {
            // Symmetric quantization: zero point fixed at 0.
            let max_abs = min_val.abs().max(max_val.abs());
            // Guard against a zero scale (and the resulting NaNs/overflow)
            // for all-zero tensors.
            let scale = ((2.0 * max_abs) / (qmax - qmin) as f32).max(f32::EPSILON);
            (scale, 0)
        } else {
            // Asymmetric quantization: shift the range with a zero point.
            // Guard against a zero scale for constant tensors.
            let scale = ((max_val - min_val) / (qmax - qmin) as f32).max(f32::EPSILON);
            let zero_point = qmin - (min_val / scale).round() as i32;
            (scale, zero_point)
        };

        Self {
            scale,
            zero_point,
            min_val,
            max_val,
        }
    }

    /// Quantize a float value
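    ///
    /// Computes `q = clamp(round(value / scale) + zero_point, qmin, qmax)`.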
    pub fn quantize(&self, value: f32, bit_width: BitWidth) -> i8 {
        let (qmin, qmax) = bit_width.range();
        let quantized = (value / self.scale).round() as i32 + self.zero_point;
        quantized.clamp(qmin, qmax) as i8
    }

    /// Dequantize an int value back to float
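    ///
    /// Computes the approximate original value as `(q - zero_point) * scale`.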
    pub fn dequantize(&self, quantized: i8) -> f32 {
        (quantized as i32 - self.zero_point) as f32 * self.scale
    }
}

/// Quantized tensor representation
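///
/// Note that values are stored one `i8` per element even for the `Int4` and
/// `Binary` bit widths; no bit packing is performed, so those widths reduce
/// precision but not storage.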
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
    /// Quantized values (int8)
    pub values: Vec<i8>,
    /// Quantization parameters
    pub params: QuantizationParams,
    /// Original shape
    pub shape: Vec<usize>,
}

impl QuantizedTensor {
    /// Create quantized tensor from float array
    pub fn from_array(array: &Array1<f32>, config: &QuantizationConfig) -> Self {
        let min_val = array.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = array.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let symmetric = matches!(config.scheme, QuantizationScheme::Symmetric);
        let params =
            QuantizationParams::from_statistics(min_val, max_val, config.bit_width, symmetric);

        let values: Vec<i8> = array
            .iter()
            .map(|&v| params.quantize(v, config.bit_width))
            .collect();

        Self {
            values,
            params,
            shape: vec![array.len()],
        }
    }

    /// Dequantize back to float array
    pub fn to_array(&self) -> Array1<f32> {
        Array1::from_vec(
            self.values
                .iter()
                .map(|&v| self.params.dequantize(v))
                .collect(),
        )
    }

    /// Get compression ratio
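    ///
    /// For an `N`-element tensor this is `4N / (N + 16)` with the 16-byte
    /// `QuantizationParams`, approaching 4x as `N` grows (e.g. ~3.6x at
    /// `N = 128`).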
    pub fn compression_ratio(&self) -> f32 {
        // Original: 4 bytes per float32
        // Quantized: 1 byte per int8 + overhead for params
        let original_size = self.values.len() * 4;
        let quantized_size = self.values.len() + std::mem::size_of::<QuantizationParams>();
        original_size as f32 / quantized_size as f32
    }

    /// Get size in bytes
    pub fn size_bytes(&self) -> usize {
        self.values.len() + std::mem::size_of::<QuantizationParams>()
    }
}

/// Quantization statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationStats {
    /// Total parameters quantized
    pub total_params: usize,
    /// Original model size (bytes)
    pub original_size_bytes: usize,
    /// Quantized model size (bytes)
    pub quantized_size_bytes: usize,
    /// Compression ratio
    pub compression_ratio: f32,
    /// Average quantization error
    pub avg_quantization_error: f32,
    /// Maximum quantization error
    pub max_quantization_error: f32,
}

impl Default for QuantizationStats {
    fn default() -> Self {
        Self {
            total_params: 0,
            original_size_bytes: 0,
            quantized_size_bytes: 0,
            compression_ratio: 1.0,
            avg_quantization_error: 0.0,
            max_quantization_error: 0.0,
        }
    }
}

/// Model quantizer
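///
/// A minimal sketch for single embeddings (error handling elided):
///
/// ```ignore
/// let quantizer = ModelQuantizer::new(QuantizationConfig::default());
/// let q = quantizer.quantize_embedding(&embedding);
/// let restored = quantizer.dequantize_embedding(&q);
/// ```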
pub struct ModelQuantizer {
    config: QuantizationConfig,
    stats: QuantizationStats,
}

impl ModelQuantizer {
    /// Create new model quantizer
    pub fn new(config: QuantizationConfig) -> Self {
        info!(
            "Initialized model quantizer: scheme={:?}, bit_width={:?}",
            config.scheme, config.bit_width
        );

        Self {
            config,
            stats: QuantizationStats::default(),
        }
    }

    /// Quantize entity embeddings
    pub fn quantize_embeddings(
        &mut self,
        embeddings: &HashMap<String, Array1<f32>>,
    ) -> Result<HashMap<String, QuantizedTensor>> {
        if embeddings.is_empty() {
            return Err(anyhow!("No embeddings to quantize"));
        }

        info!("Quantizing {} embeddings", embeddings.len());

        let mut quantized_embeddings = HashMap::new();
        let mut total_error = 0.0;
        let mut max_error: f32 = 0.0;

        for (entity, embedding) in embeddings {
            let quantized = QuantizedTensor::from_array(embedding, &self.config);

            // Compute quantization error
            let dequantized = quantized.to_array();
            let error = self.compute_error(embedding, &dequantized);
            total_error += error;
            max_error = max_error.max(error);

            // Update stats
            self.stats.original_size_bytes += embedding.len() * 4;
            self.stats.quantized_size_bytes += quantized.size_bytes();

            quantized_embeddings.insert(entity.clone(), quantized);
        }

        self.stats.total_params = embeddings.values().map(|e| e.len()).sum();
        self.stats.compression_ratio =
            self.stats.original_size_bytes as f32 / self.stats.quantized_size_bytes as f32;
        self.stats.avg_quantization_error = total_error / embeddings.len() as f32;
        self.stats.max_quantization_error = max_error;

        info!(
            "Quantization complete: compression_ratio={:.2}x, avg_error={:.6}",
            self.stats.compression_ratio, self.stats.avg_quantization_error
        );

        Ok(quantized_embeddings)
    }

    /// Dequantize embeddings
    pub fn dequantize_embeddings(
        &self,
        quantized: &HashMap<String, QuantizedTensor>,
    ) -> HashMap<String, Array1<f32>> {
        quantized
            .iter()
            .map(|(entity, q)| (entity.clone(), q.to_array()))
            .collect()
    }

    /// Quantize a single embedding
    pub fn quantize_embedding(&self, embedding: &Array1<f32>) -> QuantizedTensor {
        QuantizedTensor::from_array(embedding, &self.config)
    }

    /// Dequantize a single embedding
    pub fn dequantize_embedding(&self, quantized: &QuantizedTensor) -> Array1<f32> {
        quantized.to_array()
    }

    /// Compute the root mean squared error (RMSE) between original and dequantized
    fn compute_error(&self, original: &Array1<f32>, dequantized: &Array1<f32>) -> f32 {
        let diff = original - dequantized;
        let mse = diff.dot(&diff) / original.len() as f32;
        mse.sqrt() // RMSE
    }

    /// Calibrate quantization parameters using sample data
    pub fn calibrate(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<()> {
        if !self.config.calibration {
            return Ok(());
        }

        info!(
            "Calibrating quantization parameters with {} samples",
            self.config.calibration_samples.min(embeddings.len())
        );

        // Collect statistics from sample embeddings
        let samples: Vec<&Array1<f32>> = embeddings
            .values()
            .take(self.config.calibration_samples)
            .collect();

        // Find global min/max for per-tensor quantization.
        // NOTE: the calibrated range is currently only logged; applying it
        // during quantization would require storing it on the quantizer.
        let mut global_min = f32::INFINITY;
        let mut global_max = f32::NEG_INFINITY;

        for embedding in samples {
            let min = embedding.iter().cloned().fold(f32::INFINITY, f32::min);
            let max = embedding.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            global_min = global_min.min(min);
            global_max = global_max.max(max);
        }

        debug!(
            "Calibration complete: min={:.6}, max={:.6}",
            global_min, global_max
        );

        Ok(())
    }

    /// Get quantization statistics
    pub fn get_stats(&self) -> &QuantizationStats {
        &self.stats
    }

    /// Estimate inference speedup (rough, hardware-dependent heuristic)
    pub fn estimate_speedup(&self) -> f32 {
        // Heuristic estimates: int8 operations are typically 2-4x faster
        // than float32, with lower bit widths gaining further on supporting
        // hardware.
        match self.config.bit_width {
            BitWidth::Int8 => 3.0,
            BitWidth::Int4 => 5.0,
            BitWidth::Binary => 10.0,
        }
    }

    /// Get configuration
    pub fn config(&self) -> &QuantizationConfig {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_quantization_params() {
        let min_val = -10.0;
        let max_val = 10.0;

        let params = QuantizationParams::from_statistics(
            min_val,
            max_val,
            BitWidth::Int8,
            true, // symmetric
        );

        assert!(params.scale > 0.0);
        assert_eq!(params.zero_point, 0); // Symmetric should have zero point = 0
    }

    #[test]
    fn test_quantize_dequantize() {
        let params = QuantizationParams::from_statistics(-10.0, 10.0, BitWidth::Int8, true);

        let value = 5.0;
        let quantized = params.quantize(value, BitWidth::Int8);
        let dequantized = params.dequantize(quantized);

        // Should be approximately equal (within quantization error)
        assert!((value - dequantized).abs() < 1.0);
    }

    #[test]
    fn test_quantized_tensor() {
        // Use larger array (128 elements) so compression ratio > 1.0
        // With small arrays, quantization params overhead dominates
        let array = Array1::from_vec((0..128).map(|i| i as f32 * 0.1).collect());
        let config = QuantizationConfig::default();

        let quantized = QuantizedTensor::from_array(&array, &config);
        let dequantized = quantized.to_array();

        assert_eq!(quantized.values.len(), 128);
        assert_eq!(dequantized.len(), 128);

        // Check compression ratio (should be ~3.6x for 128-dim:
        // 512 bytes / (128 + 16) bytes)
        assert!(quantized.compression_ratio() > 1.0);
    }

    #[test]
    fn test_model_quantizer() {
        let mut embeddings = HashMap::new();
        // Use larger embeddings (128-dim) for meaningful compression
        embeddings.insert(
            "e1".to_string(),
            Array1::from_vec((0..128).map(|i| i as f32 * 0.1).collect()),
        );
        embeddings.insert(
            "e2".to_string(),
            Array1::from_vec((0..128).map(|i| (i as f32 * 0.1) + 10.0).collect()),
        );

        let config = QuantizationConfig::default();
        let mut quantizer = ModelQuantizer::new(config);

        let quantized = quantizer.quantize_embeddings(&embeddings).unwrap();

        assert_eq!(quantized.len(), 2);
        assert!(quantizer.stats.compression_ratio > 1.0);
        assert!(quantizer.stats.avg_quantization_error >= 0.0);
    }

    #[test]
    fn test_roundtrip() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, -2.0, 3.5, -4.2]);

        let config = QuantizationConfig::default();
        let mut quantizer = ModelQuantizer::new(config);

        let quantized = quantizer.quantize_embeddings(&embeddings).unwrap();
        let dequantized = quantizer.dequantize_embeddings(&quantized);

        assert_eq!(dequantized.len(), 1);

        // Values should be close to original
        let original = &embeddings["e1"];
        let recovered = &dequantized["e1"];

        for i in 0..original.len() {
            let error = (original[i] - recovered[i]).abs();
            // Quantization error should be small (bounded by the step size)
            assert!(error < 1.0);
        }
    }

    #[test]
    fn test_compression_ratio() {
        let mut embeddings = HashMap::new();
        for i in 0..100 {
            let emb = Array1::from_vec(vec![i as f32; 128]);
            embeddings.insert(format!("e{}", i), emb);
        }

        let config = QuantizationConfig::default();
        let mut quantizer = ModelQuantizer::new(config);

        quantizer.quantize_embeddings(&embeddings).unwrap();

        // Int8 should give close to 4x compression (32-bit to 8-bit,
        // minus the per-tensor params overhead)
        assert!(quantizer.stats.compression_ratio > 3.0);
        assert!(quantizer.stats.compression_ratio < 5.0);
    }
}