kizzasi_model/dynamic_quantization.rs

//! Dynamic Quantization for On-the-Fly Model Compression
//!
//! Provides automatic weight quantization during model loading for memory-efficient inference.
//!
//! # Features
//!
//! - **Weight-Only Quantization**: Quantize weights while keeping activations in FP32
//! - **Dynamic Quantization**: Quantize weights and activations at runtime
//! - **Mixed Precision**: Selective quantization based on layer sensitivity
//! - **Multiple Backends**: INT8, FP16, BF16 support
//! - **HuggingFace Integration**: Automatic quantization on model load
//!
//! # Quantization Strategies
//!
//! ## INT8 Weight-Only Quantization
//! - Quantize weights to INT8 (4x compression)
//! - Keep activations in FP32 for accuracy
//! - Best for memory-bound workloads
//!
//! ## FP16 Mixed Precision
//! - Convert weights to FP16 (2x compression)
//! - Better accuracy than INT8
//! - Hardware acceleration on modern GPUs
//!
//! ## Dynamic Quantization
//! - Quantize both weights and activations
//! - Maximum memory savings (8x with INT8)
//! - Automatic calibration from data
//!
//! # Example
//!
//! ```rust,ignore
//! use kizzasi_model::dynamic_quantization::*;
//!
//! // Load and quantize HuggingFace model
//! let quantizer = DynamicQuantizer::new()
//!     .with_strategy(QuantStrategy::INT8WeightOnly)
//!     .with_calibration_samples(100);
//!
//! let quantized_weights = quantizer.quantize_weights(&weights)?;
//! ```
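//!
//! Mixed precision with a memory report (a further sketch; assumes a `weights`
//! map of layer names to `Array2<f32>` tensors loaded elsewhere):
//!
//! ```rust,ignore
//! let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::MixedPrecision);
//! let quantized = quantizer.quantize_weights(&weights)?;
//! let stats = quantizer.calculate_memory_savings(&weights, &quantized);
//! stats.print_summary();
//! ```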

use crate::error::ModelResult;
use crate::mixed_precision::{BF16Weights, FP16Weights};
use crate::quantization::{
    quantize_symmetric_2d, quantize_symmetric_per_channel, QuantizationGranularity, QuantizedWeight,
};
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;

/// Quantization strategy
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantStrategy {
    /// No quantization (FP32)
    None,
    /// Quantize weights to INT8, keep activations in FP32
    INT8WeightOnly,
    /// Quantize weights to FP16
    FP16,
    /// Quantize weights to BF16
    BF16,
    /// Quantize both weights and activations to INT8 (dynamic)
    INT8Dynamic,
    /// Mixed precision: sensitive layers in FP32, others in INT8
    MixedPrecision,
}

impl QuantStrategy {
    /// Get memory compression ratio
    pub fn compression_ratio(&self) -> f32 {
        match self {
            QuantStrategy::None => 1.0,
            QuantStrategy::INT8WeightOnly => 4.0,
            QuantStrategy::FP16 | QuantStrategy::BF16 => 2.0,
            QuantStrategy::INT8Dynamic => 8.0, // weights + activations
            QuantStrategy::MixedPrecision => 3.0, // average
        }
    }

    /// Check if strategy quantizes weights
    pub fn quantizes_weights(&self) -> bool {
        !matches!(self, QuantStrategy::None)
    }

    /// Check if strategy quantizes activations
    pub fn quantizes_activations(&self) -> bool {
        matches!(self, QuantStrategy::INT8Dynamic)
    }
}

/// Quantized model weights storage
#[derive(Debug, Clone)]
pub enum QuantizedWeightStorage {
    /// Original FP32 weights (no quantization)
    FP32(Array2<f32>),
    /// INT8 quantized weights
    INT8(QuantizedWeight),
    /// FP16 weights
    FP16(FP16Weights),
    /// BF16 weights
    BF16(BF16Weights),
}

impl QuantizedWeightStorage {
    /// Get memory size in bytes
    pub fn memory_size(&self) -> usize {
        match self {
            QuantizedWeightStorage::FP32(array) => array.len() * 4,
            QuantizedWeightStorage::INT8(qw) => qw.memory_size(),
            QuantizedWeightStorage::FP16(fp16) => fp16.memory_size(),
            QuantizedWeightStorage::BF16(bf16) => bf16.data.len() * 2,
        }
    }

    /// Convert to FP32 array for inference
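    ///
    /// A minimal sketch (hypothetical `storage` and `input` values) of
    /// dequantizing stored weights back to FP32 before a matmul:
    ///
    /// ```rust,ignore
    /// let dense: Array2<f32> = storage.to_fp32()?;
    /// let output = input.dot(&dense);
    /// ```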
    pub fn to_fp32(&self) -> ModelResult<Array2<f32>> {
        match self {
            QuantizedWeightStorage::FP32(array) => Ok(array.clone()),
            QuantizedWeightStorage::INT8(qw) => qw.dequantize_2d(),
            QuantizedWeightStorage::FP16(fp16) => fp16.to_f32_2d(),
            QuantizedWeightStorage::BF16(bf16) => bf16.to_f32_2d(),
        }
    }

    /// Get weight storage type as string
    pub fn storage_type(&self) -> &'static str {
        match self {
            QuantizedWeightStorage::FP32(_) => "FP32",
            QuantizedWeightStorage::INT8(_) => "INT8",
            QuantizedWeightStorage::FP16(_) => "FP16",
            QuantizedWeightStorage::BF16(_) => "BF16",
        }
    }
}

/// Layer sensitivity classification for mixed precision
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerSensitivity {
    /// High sensitivity - keep in FP32
    High,
    /// Medium sensitivity - use FP16
    Medium,
    /// Low sensitivity - can use INT8
    Low,
}

/// Dynamic quantizer for automatic model compression
pub struct DynamicQuantizer {
    /// Quantization strategy
    strategy: QuantStrategy,
    /// Number of calibration samples for dynamic quantization
    calibration_samples: usize,
    /// Granularity for INT8 quantization
    granularity: QuantizationGranularity,
    /// Layer sensitivity heuristics
    sensitivity_heuristics: HashMap<String, LayerSensitivity>,
}

impl DynamicQuantizer {
    /// Create a new dynamic quantizer with default settings
    pub fn new() -> Self {
        Self {
            strategy: QuantStrategy::INT8WeightOnly,
            calibration_samples: 100,
            granularity: QuantizationGranularity::PerChannel,
            sensitivity_heuristics: Self::default_sensitivity_heuristics(),
        }
    }

    /// Set quantization strategy
    pub fn with_strategy(mut self, strategy: QuantStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Set number of calibration samples
    pub fn with_calibration_samples(mut self, samples: usize) -> Self {
        self.calibration_samples = samples;
        self
    }

    /// Set quantization granularity for INT8
    pub fn with_granularity(mut self, granularity: QuantizationGranularity) -> Self {
        self.granularity = granularity;
        self
    }

    /// Default layer sensitivity heuristics
    ///
    /// Based on common SSM architecture patterns:
    /// - Input/output projections: High sensitivity (first/last layers)
    /// - SSM parameters (A, B, C matrices): High sensitivity (state dynamics)
    /// - Layer norms: Medium sensitivity
    /// - MLP/FFN layers: Low sensitivity (most compressible)
    fn default_sensitivity_heuristics() -> HashMap<String, LayerSensitivity> {
        let mut heuristics = HashMap::new();

        // High sensitivity layers
        heuristics.insert("input_proj".to_string(), LayerSensitivity::High);
        heuristics.insert("output_proj".to_string(), LayerSensitivity::High);
        heuristics.insert("ssm.log_a".to_string(), LayerSensitivity::High);
        heuristics.insert("ssm.b_proj".to_string(), LayerSensitivity::High);
        heuristics.insert("ssm.c_proj".to_string(), LayerSensitivity::High);

        // Medium sensitivity layers
        heuristics.insert("norm".to_string(), LayerSensitivity::Medium);
        heuristics.insert("ln".to_string(), LayerSensitivity::Medium);
        heuristics.insert("time_mix".to_string(), LayerSensitivity::Medium);

        // Low sensitivity layers (FFN, channel mixing)
        heuristics.insert("channel_mix".to_string(), LayerSensitivity::Low);
        heuristics.insert("ffn".to_string(), LayerSensitivity::Low);
        heuristics.insert("mlp".to_string(), LayerSensitivity::Low);

        heuristics
    }

    /// Classify layer sensitivity based on name
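    ///
    /// Exact matches against the heuristic table are checked first, then
    /// substring matches; anything unmatched defaults to `Medium`. A small
    /// sketch with hypothetical layer names:
    ///
    /// ```rust,ignore
    /// let quantizer = DynamicQuantizer::new();
    /// assert_eq!(quantizer.classify_layer("input_proj"), LayerSensitivity::High);
    /// assert_eq!(quantizer.classify_layer("layers.0.ffn.weight"), LayerSensitivity::Low);
    /// assert_eq!(quantizer.classify_layer("some_unknown_layer"), LayerSensitivity::Medium);
    /// ```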
    pub fn classify_layer(&self, layer_name: &str) -> LayerSensitivity {
        // Check exact matches first
        if let Some(&sensitivity) = self.sensitivity_heuristics.get(layer_name) {
            return sensitivity;
        }

        // Check partial matches
        for (pattern, &sensitivity) in &self.sensitivity_heuristics {
            if layer_name.contains(pattern) {
                return sensitivity;
            }
        }

        // Default to medium sensitivity
        LayerSensitivity::Medium
    }

    /// Quantize a single weight tensor
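    ///
    /// The layer name only matters for `MixedPrecision`, where it feeds
    /// `classify_layer`. A minimal sketch (shape and values chosen arbitrarily):
    ///
    /// ```rust,ignore
    /// let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::INT8WeightOnly);
    /// let weight = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])?;
    /// let stored = quantizer.quantize_weight(&weight, "ffn.weight")?;
    /// assert_eq!(stored.storage_type(), "INT8");
    /// ```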
    pub fn quantize_weight(
        &self,
        weight: &Array2<f32>,
        layer_name: &str,
    ) -> ModelResult<QuantizedWeightStorage> {
        match self.strategy {
            QuantStrategy::None => Ok(QuantizedWeightStorage::FP32(weight.clone())),

            QuantStrategy::INT8WeightOnly => {
                let quantized = match self.granularity {
                    QuantizationGranularity::PerTensor => quantize_symmetric_2d(weight)?,
                    QuantizationGranularity::PerChannel => quantize_symmetric_per_channel(weight)?,
                };
                Ok(QuantizedWeightStorage::INT8(quantized))
            }

            QuantStrategy::FP16 => {
                let fp16_weights = FP16Weights::from_f32_2d(weight);
                Ok(QuantizedWeightStorage::FP16(fp16_weights))
            }

            QuantStrategy::BF16 => {
                let bf16_weights = BF16Weights::from_f32_2d(weight);
                Ok(QuantizedWeightStorage::BF16(bf16_weights))
            }

            QuantStrategy::INT8Dynamic => {
                // Dynamic quantization: same as weight-only for weights
                let quantized = match self.granularity {
                    QuantizationGranularity::PerTensor => quantize_symmetric_2d(weight)?,
                    QuantizationGranularity::PerChannel => quantize_symmetric_per_channel(weight)?,
                };
                Ok(QuantizedWeightStorage::INT8(quantized))
            }

            QuantStrategy::MixedPrecision => {
                // Selective quantization based on layer sensitivity
                let sensitivity = self.classify_layer(layer_name);

                match sensitivity {
                    LayerSensitivity::High => {
                        // Keep high-sensitivity layers in FP32
                        Ok(QuantizedWeightStorage::FP32(weight.clone()))
                    }
                    LayerSensitivity::Medium => {
                        // Medium sensitivity: use FP16
                        let fp16_weights = FP16Weights::from_f32_2d(weight);
                        Ok(QuantizedWeightStorage::FP16(fp16_weights))
                    }
                    LayerSensitivity::Low => {
                        // Low sensitivity: use INT8
                        let quantized = quantize_symmetric_per_channel(weight)?;
                        Ok(QuantizedWeightStorage::INT8(quantized))
                    }
                }
            }
        }
    }

    /// Quantize all weights in a model
    pub fn quantize_weights(
        &self,
        weights: &HashMap<String, Array2<f32>>,
    ) -> ModelResult<HashMap<String, QuantizedWeightStorage>> {
        let mut quantized_weights = HashMap::new();

        for (name, weight) in weights {
            let quantized = self.quantize_weight(weight, name)?;
            quantized_weights.insert(name.clone(), quantized);
        }

        Ok(quantized_weights)
    }

    /// Calculate total memory savings
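    ///
    /// A minimal end-to-end sketch (assumes a `weights` map of layer names to
    /// FP32 tensors):
    ///
    /// ```rust,ignore
    /// let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::INT8WeightOnly);
    /// let quantized = quantizer.quantize_weights(&weights)?;
    /// let stats = quantizer.calculate_memory_savings(&weights, &quantized);
    /// println!("compressed {:.2}x", stats.compression_ratio);
    /// println!("saved {}", QuantizationStats::format_size(stats.memory_saved_bytes));
    /// ```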
    pub fn calculate_memory_savings(
        &self,
        original_weights: &HashMap<String, Array2<f32>>,
        quantized_weights: &HashMap<String, QuantizedWeightStorage>,
    ) -> QuantizationStats {
        let mut original_size = 0;
        let mut quantized_size = 0;

        for (name, original) in original_weights {
            original_size += original.len() * 4; // FP32: 4 bytes

            if let Some(quantized) = quantized_weights.get(name) {
                quantized_size += quantized.memory_size();
            }
        }

        let compression_ratio = original_size as f32 / quantized_size.max(1) as f32;
        let memory_saved = original_size.saturating_sub(quantized_size);

        QuantizationStats {
            original_size_bytes: original_size,
            quantized_size_bytes: quantized_size,
            compression_ratio,
            memory_saved_bytes: memory_saved,
            strategy: self.strategy,
        }
    }

    /// Get quantization strategy
    pub fn strategy(&self) -> QuantStrategy {
        self.strategy
    }

    /// Get calibration sample count
    pub fn calibration_samples(&self) -> usize {
        self.calibration_samples
    }
}

impl Default for DynamicQuantizer {
    fn default() -> Self {
        Self::new()
    }
}

/// Quantization statistics
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Original model size in bytes
    pub original_size_bytes: usize,
    /// Quantized model size in bytes
    pub quantized_size_bytes: usize,
    /// Compression ratio (original / quantized)
    pub compression_ratio: f32,
    /// Memory saved in bytes
    pub memory_saved_bytes: usize,
    /// Strategy used
    pub strategy: QuantStrategy,
}

impl QuantizationStats {
    /// Format size as human-readable string
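    ///
    /// For example (sizes use binary units, 1 KB = 1024 bytes):
    ///
    /// ```rust,ignore
    /// assert_eq!(QuantizationStats::format_size(512), "512 bytes");
    /// assert_eq!(QuantizationStats::format_size(150 * 1024 * 1024), "150.00 MB");
    /// ```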
    pub fn format_size(bytes: usize) -> String {
        const KB: usize = 1024;
        const MB: usize = KB * 1024;
        const GB: usize = MB * 1024;

        if bytes >= GB {
            format!("{:.2} GB", bytes as f64 / GB as f64)
        } else if bytes >= MB {
            format!("{:.2} MB", bytes as f64 / MB as f64)
        } else if bytes >= KB {
            format!("{:.2} KB", bytes as f64 / KB as f64)
        } else {
            format!("{} bytes", bytes)
        }
    }

    /// Print summary
    pub fn print_summary(&self) {
        println!("Quantization Summary");
        println!("====================");
        println!("Strategy: {:?}", self.strategy);
        println!(
            "Original Size: {}",
            Self::format_size(self.original_size_bytes)
        );
        println!(
            "Quantized Size: {}",
            Self::format_size(self.quantized_size_bytes)
        );
        println!("Compression Ratio: {:.2}x", self.compression_ratio);
        println!(
            "Memory Saved: {} ({:.1}%)",
            Self::format_size(self.memory_saved_bytes),
            (self.memory_saved_bytes as f64 / self.original_size_bytes as f64) * 100.0
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_quant_strategy_compression_ratio() {
        assert_eq!(QuantStrategy::None.compression_ratio(), 1.0);
        assert_eq!(QuantStrategy::INT8WeightOnly.compression_ratio(), 4.0);
        assert_eq!(QuantStrategy::FP16.compression_ratio(), 2.0);
        assert_eq!(QuantStrategy::BF16.compression_ratio(), 2.0);
        assert_eq!(QuantStrategy::INT8Dynamic.compression_ratio(), 8.0);
    }

    #[test]
    fn test_dynamic_quantizer_creation() {
        let quantizer = DynamicQuantizer::new();
        assert_eq!(quantizer.strategy(), QuantStrategy::INT8WeightOnly);
        assert_eq!(quantizer.calibration_samples(), 100);
    }

    #[test]
    fn test_quantizer_with_strategy() {
        let quantizer = DynamicQuantizer::new()
            .with_strategy(QuantStrategy::FP16)
            .with_calibration_samples(200);

        assert_eq!(quantizer.strategy(), QuantStrategy::FP16);
        assert_eq!(quantizer.calibration_samples(), 200);
    }

    #[test]
    fn test_layer_sensitivity_classification() {
        let quantizer = DynamicQuantizer::new();

        assert_eq!(
            quantizer.classify_layer("input_proj"),
            LayerSensitivity::High
        );
        assert_eq!(
            quantizer.classify_layer("layers.0.ssm.log_a"),
            LayerSensitivity::High
        );
        assert_eq!(
            quantizer.classify_layer("layers.0.norm.weight"),
            LayerSensitivity::Medium
        );
        assert_eq!(
            quantizer.classify_layer("layers.0.channel_mix.key"),
            LayerSensitivity::Low
        );
        assert_eq!(
            quantizer.classify_layer("unknown_layer"),
            LayerSensitivity::Medium
        ); // default
    }

    #[test]
    fn test_quantize_weight_int8() {
        let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::INT8WeightOnly);

        let weight = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])
            .expect("Failed to create test array");

        let quantized = quantizer
            .quantize_weight(&weight, "test_layer")
            .expect("Failed to quantize weight");

        assert_eq!(quantized.storage_type(), "INT8");
        assert!(quantized.memory_size() < weight.len() * 4);
    }

    #[test]
    fn test_quantize_weight_fp16() {
        let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::FP16);

        let weight = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])
            .expect("Failed to create test array");

        let quantized = quantizer
            .quantize_weight(&weight, "test_layer")
            .expect("Failed to quantize weight");

        assert_eq!(quantized.storage_type(), "FP16");
        assert_eq!(quantized.memory_size(), weight.len() * 2); // FP16 = 2 bytes
    }

    #[test]
    fn test_quantize_weight_mixed_precision() {
        let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::MixedPrecision);

        let weight = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])
            .expect("Failed to create test array");

        // High sensitivity layer - should stay FP32
        let quantized_high = quantizer
            .quantize_weight(&weight, "input_proj")
            .expect("Failed to quantize weight");
        assert_eq!(quantized_high.storage_type(), "FP32");

        // Medium sensitivity layer - should be FP16
        let quantized_medium = quantizer
            .quantize_weight(&weight, "norm")
            .expect("Failed to quantize weight");
        assert_eq!(quantized_medium.storage_type(), "FP16");

        // Low sensitivity layer - should be INT8
        let quantized_low = quantizer
            .quantize_weight(&weight, "channel_mix")
            .expect("Failed to quantize weight");
        assert_eq!(quantized_low.storage_type(), "INT8");
    }

    #[test]
    fn test_quantize_all_weights() {
        let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::INT8WeightOnly);

        let mut weights = HashMap::new();
        weights.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
        );
        weights.insert(
            "layer2".to_string(),
            Array2::from_shape_vec((2, 2), vec![-1.0, -2.0, -3.0, -4.0]).unwrap(),
        );

        let quantized = quantizer
            .quantize_weights(&weights)
            .expect("Failed to quantize weights");

        assert_eq!(quantized.len(), 2);
        assert!(quantized.contains_key("layer1"));
        assert!(quantized.contains_key("layer2"));
    }

    #[test]
    fn test_calculate_memory_savings() {
        let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::INT8WeightOnly);

        let mut weights = HashMap::new();
        weights.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((100, 100), vec![1.0; 10000]).unwrap(),
        );

        let quantized = quantizer.quantize_weights(&weights).unwrap();
        let stats = quantizer.calculate_memory_savings(&weights, &quantized);

        assert_eq!(stats.original_size_bytes, 10000 * 4); // FP32
        assert_eq!(stats.quantized_size_bytes, 10000); // INT8
        assert!((stats.compression_ratio - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_quantization_stats_format() {
        let stats = QuantizationStats {
            original_size_bytes: 1024 * 1024 * 100, // 100 MB
            quantized_size_bytes: 1024 * 1024 * 25, // 25 MB
            compression_ratio: 4.0,
            memory_saved_bytes: 1024 * 1024 * 75, // 75 MB
            strategy: QuantStrategy::INT8WeightOnly,
        };

        let formatted = QuantizationStats::format_size(stats.original_size_bytes);
        assert!(formatted.contains("MB"));
    }

    #[test]
    fn test_storage_to_fp32_roundtrip() {
        let original = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])
            .expect("Failed to create test array");

        // Test FP32 storage
        let storage_fp32 = QuantizedWeightStorage::FP32(original.clone());
        let restored = storage_fp32.to_fp32().expect("Failed to restore");
        assert_eq!(restored, original);

        // Test FP16 storage
        let fp16 = FP16Weights::from_f32_2d(&original);
        let storage_fp16 = QuantizedWeightStorage::FP16(fp16);
        let restored_fp16 = storage_fp16.to_fp32().expect("Failed to restore");
        assert_eq!(restored_fp16.dim(), original.dim());

        // Test INT8 storage
        let int8 = quantize_symmetric_2d(&original).expect("Failed to quantize");
        let storage_int8 = QuantizedWeightStorage::INT8(int8);
        let restored_int8 = storage_int8.to_fp32().expect("Failed to restore");
        assert_eq!(restored_int8.dim(), original.dim());
    }
}