//! Dynamic quantization: strategy selection, quantized weight storage, and
//! sensitivity-based mixed precision for model weights.

use crate::error::ModelResult;
use crate::mixed_precision::{BF16Weights, FP16Weights};
use crate::quantization::{
    quantize_symmetric_2d, quantize_symmetric_per_channel, QuantizationGranularity, QuantizedWeight,
};
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantStrategy {
    /// No quantization; weights are kept in FP32.
    None,
    /// Weights quantized to INT8; activations stay in FP32.
    INT8WeightOnly,
    /// Weights stored in IEEE FP16 half precision.
    FP16,
    /// Weights stored in BF16 (bfloat16).
    BF16,
    /// INT8 weights plus dynamic INT8 activation quantization at runtime.
    INT8Dynamic,
    /// Per-layer precision (FP32/FP16/INT8) chosen by layer sensitivity.
    MixedPrecision,
}

impl QuantStrategy {
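    /// Approximate compression ratio relative to FP32 storage.
    ///
    /// Illustrative sketch of how the ratio might be used to estimate
    /// post-quantization size (the 400 MB figure is hypothetical):
    ///
    /// ```ignore
    /// let fp32_bytes = 400_000_000.0_f32;                            // assumed model size
    /// let ratio = QuantStrategy::INT8WeightOnly.compression_ratio(); // 4.0
    /// let estimated_bytes = fp32_bytes / ratio;                      // ~100 MB
    /// ```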
    pub fn compression_ratio(&self) -> f32 {
        match self {
            QuantStrategy::None => 1.0,
            QuantStrategy::INT8WeightOnly => 4.0,
            QuantStrategy::FP16 | QuantStrategy::BF16 => 2.0,
            QuantStrategy::INT8Dynamic => 8.0,
            QuantStrategy::MixedPrecision => 3.0,
        }
    }

    /// Whether this strategy changes the stored weight representation.
    pub fn quantizes_weights(&self) -> bool {
        !matches!(self, QuantStrategy::None)
    }

    /// Whether activations are quantized at inference time (only
    /// `INT8Dynamic` does).
    pub fn quantizes_activations(&self) -> bool {
        matches!(self, QuantStrategy::INT8Dynamic)
    }
}

#[derive(Debug, Clone)]
pub enum QuantizedWeightStorage {
    /// Unquantized FP32 weights.
    FP32(Array2<f32>),
    /// Symmetric INT8 quantized weights with scale metadata.
    INT8(QuantizedWeight),
    /// FP16 half-precision weights.
    FP16(FP16Weights),
    /// BF16 weights.
    BF16(BF16Weights),
}

impl QuantizedWeightStorage {
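    /// Size in bytes of the stored representation.
    ///
    /// A minimal sketch of the FP32 case (the 2×3 matrix is assumed):
    ///
    /// ```ignore
    /// let w = Array2::<f32>::zeros((2, 3));
    /// let storage = QuantizedWeightStorage::FP32(w);
    /// assert_eq!(storage.memory_size(), 6 * 4); // 6 elements × 4 bytes each
    /// ```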
    pub fn memory_size(&self) -> usize {
        match self {
            QuantizedWeightStorage::FP32(array) => array.len() * 4,
            QuantizedWeightStorage::INT8(qw) => qw.memory_size(),
            QuantizedWeightStorage::FP16(fp16) => fp16.memory_size(),
            QuantizedWeightStorage::BF16(bf16) => bf16.data.len() * 2,
        }
    }

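    /// Dequantizes the stored weights back to FP32. Exact for the FP32
    /// variant; lossy (rounding) for INT8, FP16, and BF16.
    ///
    /// Roundtrip sketch, assuming `FP16Weights::from_f32_2d` as used
    /// elsewhere in this module:
    ///
    /// ```ignore
    /// let w = Array2::from_shape_vec((1, 2), vec![0.5_f32, -0.5]).unwrap();
    /// let storage = QuantizedWeightStorage::FP16(FP16Weights::from_f32_2d(&w));
    /// let restored = storage.to_fp32()?; // same shape, values ≈ originals
    /// ```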
    pub fn to_fp32(&self) -> ModelResult<Array2<f32>> {
        match self {
            QuantizedWeightStorage::FP32(array) => Ok(array.clone()),
            QuantizedWeightStorage::INT8(qw) => qw.dequantize_2d(),
            QuantizedWeightStorage::FP16(fp16) => fp16.to_f32_2d(),
            QuantizedWeightStorage::BF16(bf16) => bf16.to_f32_2d(),
        }
    }

    /// Short label for the storage variant, e.g. for logging.
    pub fn storage_type(&self) -> &'static str {
        match self {
            QuantizedWeightStorage::FP32(_) => "FP32",
            QuantizedWeightStorage::INT8(_) => "INT8",
            QuantizedWeightStorage::FP16(_) => "FP16",
            QuantizedWeightStorage::BF16(_) => "BF16",
        }
    }
}

/// How sensitive a layer's output quality is to quantization error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerSensitivity {
    /// Kept in full FP32 precision under `MixedPrecision`.
    High,
    /// Quantized to FP16 under `MixedPrecision`.
    Medium,
    /// Quantized to INT8 under `MixedPrecision`.
    Low,
}

/// Applies a quantization strategy to model weights, using per-layer
/// sensitivity heuristics when the strategy is `MixedPrecision`.
pub struct DynamicQuantizer {
    /// Quantization strategy to apply.
    strategy: QuantStrategy,
    /// Number of samples to use for calibration.
    calibration_samples: usize,
    /// Scale granularity (per-tensor or per-channel) for INT8 paths.
    granularity: QuantizationGranularity,
    /// Layer-name substring patterns mapped to sensitivity classes.
    sensitivity_heuristics: HashMap<String, LayerSensitivity>,
}

impl DynamicQuantizer {
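    /// Creates a quantizer with defaults: INT8 weight-only strategy,
    /// 100 calibration samples, and per-channel granularity.
    ///
    /// Builder-style configuration, for illustration:
    ///
    /// ```ignore
    /// let quantizer = DynamicQuantizer::new()
    ///     .with_strategy(QuantStrategy::MixedPrecision)
    ///     .with_calibration_samples(200)
    ///     .with_granularity(QuantizationGranularity::PerTensor);
    /// ```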
    pub fn new() -> Self {
        Self {
            strategy: QuantStrategy::INT8WeightOnly,
            calibration_samples: 100,
            granularity: QuantizationGranularity::PerChannel,
            sensitivity_heuristics: Self::default_sensitivity_heuristics(),
        }
    }

    /// Sets the quantization strategy.
    pub fn with_strategy(mut self, strategy: QuantStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the number of calibration samples.
    pub fn with_calibration_samples(mut self, samples: usize) -> Self {
        self.calibration_samples = samples;
        self
    }

    /// Sets the INT8 scale granularity.
    pub fn with_granularity(mut self, granularity: QuantizationGranularity) -> Self {
        self.granularity = granularity;
        self
    }

    /// Default sensitivity heuristics keyed by layer-name substring:
    /// projection and SSM parameters are High, normalization and
    /// time-mixing layers are Medium, and feed-forward/channel-mixing
    /// layers are Low.
    fn default_sensitivity_heuristics() -> HashMap<String, LayerSensitivity> {
        let mut heuristics = HashMap::new();

        heuristics.insert("input_proj".to_string(), LayerSensitivity::High);
        heuristics.insert("output_proj".to_string(), LayerSensitivity::High);
        heuristics.insert("ssm.log_a".to_string(), LayerSensitivity::High);
        heuristics.insert("ssm.b_proj".to_string(), LayerSensitivity::High);
        heuristics.insert("ssm.c_proj".to_string(), LayerSensitivity::High);

        heuristics.insert("norm".to_string(), LayerSensitivity::Medium);
        heuristics.insert("ln".to_string(), LayerSensitivity::Medium);
        heuristics.insert("time_mix".to_string(), LayerSensitivity::Medium);

        heuristics.insert("channel_mix".to_string(), LayerSensitivity::Low);
        heuristics.insert("ffn".to_string(), LayerSensitivity::Low);
        heuristics.insert("mlp".to_string(), LayerSensitivity::Low);

        heuristics
    }

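    /// Classifies a layer by exact name first, then by substring match
    /// against the heuristics, falling back to `Medium`.
    ///
    /// For example (the layer names here are illustrative):
    ///
    /// ```ignore
    /// let q = DynamicQuantizer::new();
    /// assert_eq!(q.classify_layer("layers.3.ffn.up"), LayerSensitivity::Low);
    /// assert_eq!(q.classify_layer("embedding"), LayerSensitivity::Medium); // fallback
    /// ```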
    pub fn classify_layer(&self, layer_name: &str) -> LayerSensitivity {
        // Exact match first.
        if let Some(&sensitivity) = self.sensitivity_heuristics.get(layer_name) {
            return sensitivity;
        }

        // Otherwise, substring match. Note that HashMap iteration order is
        // arbitrary, so if a name matches several patterns, which one wins
        // is unspecified.
        for (pattern, &sensitivity) in &self.sensitivity_heuristics {
            if layer_name.contains(pattern) {
                return sensitivity;
            }
        }

        LayerSensitivity::Medium
    }

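    /// Quantizes a single weight matrix according to the configured
    /// strategy; `layer_name` only affects `MixedPrecision`.
    ///
    /// Sketch (the layer name is illustrative):
    ///
    /// ```ignore
    /// let q = DynamicQuantizer::new().with_strategy(QuantStrategy::MixedPrecision);
    /// let w = Array2::<f32>::ones((4, 4));
    /// // "input_proj" is classified High, so the weights stay FP32.
    /// let stored = q.quantize_weight(&w, "input_proj")?;
    /// assert_eq!(stored.storage_type(), "FP32");
    /// ```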
    pub fn quantize_weight(
        &self,
        weight: &Array2<f32>,
        layer_name: &str,
    ) -> ModelResult<QuantizedWeightStorage> {
        match self.strategy {
            QuantStrategy::None => Ok(QuantizedWeightStorage::FP32(weight.clone())),

            QuantStrategy::INT8WeightOnly => {
                let quantized = match self.granularity {
                    QuantizationGranularity::PerTensor => quantize_symmetric_2d(weight)?,
                    QuantizationGranularity::PerChannel => quantize_symmetric_per_channel(weight)?,
                };
                Ok(QuantizedWeightStorage::INT8(quantized))
            }

            QuantStrategy::FP16 => {
                let fp16_weights = FP16Weights::from_f32_2d(weight);
                Ok(QuantizedWeightStorage::FP16(fp16_weights))
            }

            QuantStrategy::BF16 => {
                let bf16_weights = BF16Weights::from_f32_2d(weight);
                Ok(QuantizedWeightStorage::BF16(bf16_weights))
            }

            QuantStrategy::INT8Dynamic => {
                // Weights are stored the same way as weight-only INT8; the
                // dynamic part (activation quantization) happens at runtime.
                let quantized = match self.granularity {
                    QuantizationGranularity::PerTensor => quantize_symmetric_2d(weight)?,
                    QuantizationGranularity::PerChannel => quantize_symmetric_per_channel(weight)?,
                };
                Ok(QuantizedWeightStorage::INT8(quantized))
            }

            QuantStrategy::MixedPrecision => {
                let sensitivity = self.classify_layer(layer_name);

                match sensitivity {
                    // High sensitivity: keep full precision.
                    LayerSensitivity::High => Ok(QuantizedWeightStorage::FP32(weight.clone())),
                    // Medium sensitivity: FP16 halves the footprint.
                    LayerSensitivity::Medium => {
                        let fp16_weights = FP16Weights::from_f32_2d(weight);
                        Ok(QuantizedWeightStorage::FP16(fp16_weights))
                    }
                    // Low sensitivity: per-channel INT8.
                    LayerSensitivity::Low => {
                        let quantized = quantize_symmetric_per_channel(weight)?;
                        Ok(QuantizedWeightStorage::INT8(quantized))
                    }
                }
            }
        }
    }

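    /// Quantizes every weight matrix in the map, keyed by layer name.
    ///
    /// Illustrative use (the layer name is assumed):
    ///
    /// ```ignore
    /// let quantizer = DynamicQuantizer::new();
    /// let mut weights = HashMap::new();
    /// weights.insert("ffn.w1".to_string(), Array2::<f32>::ones((8, 8)));
    /// let quantized = quantizer.quantize_weights(&weights)?;
    /// assert_eq!(quantized["ffn.w1"].storage_type(), "INT8");
    /// ```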
    pub fn quantize_weights(
        &self,
        weights: &HashMap<String, Array2<f32>>,
    ) -> ModelResult<HashMap<String, QuantizedWeightStorage>> {
        let mut quantized_weights = HashMap::new();

        for (name, weight) in weights {
            let quantized = self.quantize_weight(weight, name)?;
            quantized_weights.insert(name.clone(), quantized);
        }

        Ok(quantized_weights)
    }

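    /// Compares original FP32 sizes against quantized storage sizes.
    ///
    /// A sketch, continuing the `quantize_weights` example above:
    ///
    /// ```ignore
    /// let stats = quantizer.calculate_memory_savings(&weights, &quantized);
    /// println!("{:.2}x smaller", stats.compression_ratio); // ~4x for INT8
    /// stats.print_summary();
    /// ```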
    pub fn calculate_memory_savings(
        &self,
        original_weights: &HashMap<String, Array2<f32>>,
        quantized_weights: &HashMap<String, QuantizedWeightStorage>,
    ) -> QuantizationStats {
        let mut original_size = 0;
        let mut quantized_size = 0;

        for (name, original) in original_weights {
            original_size += original.len() * 4; // 4 bytes per f32
            if let Some(quantized) = quantized_weights.get(name) {
                quantized_size += quantized.memory_size();
            }
        }

        let compression_ratio = original_size as f32 / quantized_size.max(1) as f32;
        let memory_saved = original_size.saturating_sub(quantized_size);

        QuantizationStats {
            original_size_bytes: original_size,
            quantized_size_bytes: quantized_size,
            compression_ratio,
            memory_saved_bytes: memory_saved,
            strategy: self.strategy,
        }
    }

    /// Returns the configured strategy.
    pub fn strategy(&self) -> QuantStrategy {
        self.strategy
    }

    /// Returns the configured number of calibration samples.
    pub fn calibration_samples(&self) -> usize {
        self.calibration_samples
    }
}

impl Default for DynamicQuantizer {
    fn default() -> Self {
        Self::new()
    }
}

/// Memory statistics produced by `DynamicQuantizer::calculate_memory_savings`.
#[derive(Debug, Clone)]
pub struct QuantizationStats {
    /// Total FP32 size of the original weights, in bytes.
    pub original_size_bytes: usize,
    /// Total size of the quantized storage, in bytes.
    pub quantized_size_bytes: usize,
    /// Ratio of original size to quantized size.
    pub compression_ratio: f32,
    /// Bytes saved by quantization (saturating at zero).
    pub memory_saved_bytes: usize,
    /// Strategy that produced these statistics.
    pub strategy: QuantStrategy,
}

impl QuantizationStats {
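    /// Formats a byte count as a human-readable string using binary units.
    ///
    /// For example:
    ///
    /// ```ignore
    /// assert_eq!(QuantizationStats::format_size(512), "512 bytes");
    /// assert_eq!(QuantizationStats::format_size(1536), "1.50 KB");
    /// assert_eq!(QuantizationStats::format_size(100 * 1024 * 1024), "100.00 MB");
    /// ```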
    pub fn format_size(bytes: usize) -> String {
        const KB: usize = 1024;
        const MB: usize = KB * 1024;
        const GB: usize = MB * 1024;

        if bytes >= GB {
            format!("{:.2} GB", bytes as f64 / GB as f64)
        } else if bytes >= MB {
            format!("{:.2} MB", bytes as f64 / MB as f64)
        } else if bytes >= KB {
            format!("{:.2} KB", bytes as f64 / KB as f64)
        } else {
            format!("{} bytes", bytes)
        }
    }

    /// Prints a human-readable summary of the quantization run to stdout.
    pub fn print_summary(&self) {
        println!("Quantization Summary");
        println!("====================");
        println!("Strategy: {:?}", self.strategy);
        println!(
            "Original Size: {}",
            Self::format_size(self.original_size_bytes)
        );
        println!(
            "Quantized Size: {}",
            Self::format_size(self.quantized_size_bytes)
        );
        println!("Compression Ratio: {:.2}x", self.compression_ratio);
        println!(
            "Memory Saved: {} ({:.1}%)",
            Self::format_size(self.memory_saved_bytes),
            (self.memory_saved_bytes as f64 / self.original_size_bytes as f64) * 100.0
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_quant_strategy_compression_ratio() {
        assert_eq!(QuantStrategy::None.compression_ratio(), 1.0);
        assert_eq!(QuantStrategy::INT8WeightOnly.compression_ratio(), 4.0);
        assert_eq!(QuantStrategy::FP16.compression_ratio(), 2.0);
        assert_eq!(QuantStrategy::BF16.compression_ratio(), 2.0);
        assert_eq!(QuantStrategy::INT8Dynamic.compression_ratio(), 8.0);
    }

    #[test]
    fn test_dynamic_quantizer_creation() {
        let quantizer = DynamicQuantizer::new();
        assert_eq!(quantizer.strategy(), QuantStrategy::INT8WeightOnly);
        assert_eq!(quantizer.calibration_samples(), 100);
    }

    #[test]
    fn test_quantizer_with_strategy() {
        let quantizer = DynamicQuantizer::new()
            .with_strategy(QuantStrategy::FP16)
            .with_calibration_samples(200);

        assert_eq!(quantizer.strategy(), QuantStrategy::FP16);
        assert_eq!(quantizer.calibration_samples(), 200);
    }

    #[test]
    fn test_layer_sensitivity_classification() {
        let quantizer = DynamicQuantizer::new();

        assert_eq!(
            quantizer.classify_layer("input_proj"),
            LayerSensitivity::High
        );
        assert_eq!(
            quantizer.classify_layer("layers.0.ssm.log_a"),
            LayerSensitivity::High
        );
        assert_eq!(
            quantizer.classify_layer("layers.0.norm.weight"),
            LayerSensitivity::Medium
        );
        assert_eq!(
            quantizer.classify_layer("layers.0.channel_mix.key"),
            LayerSensitivity::Low
        );
        assert_eq!(
            quantizer.classify_layer("unknown_layer"),
            LayerSensitivity::Medium
        );
    }

    #[test]
    fn test_quantize_weight_int8() {
        let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::INT8WeightOnly);

        let weight = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])
            .expect("Failed to create test array");

        let quantized = quantizer
            .quantize_weight(&weight, "test_layer")
            .expect("Failed to quantize weight");

        assert_eq!(quantized.storage_type(), "INT8");
        assert!(quantized.memory_size() < weight.len() * 4);
    }

    #[test]
    fn test_quantize_weight_fp16() {
        let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::FP16);

        let weight = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])
            .expect("Failed to create test array");

        let quantized = quantizer
            .quantize_weight(&weight, "test_layer")
            .expect("Failed to quantize weight");

        assert_eq!(quantized.storage_type(), "FP16");
        assert_eq!(quantized.memory_size(), weight.len() * 2);
    }

    #[test]
    fn test_quantize_weight_mixed_precision() {
        let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::MixedPrecision);

        let weight = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])
            .expect("Failed to create test array");

        let quantized_high = quantizer
            .quantize_weight(&weight, "input_proj")
            .expect("Failed to quantize weight");
        assert_eq!(quantized_high.storage_type(), "FP32");

        let quantized_medium = quantizer
            .quantize_weight(&weight, "norm")
            .expect("Failed to quantize weight");
        assert_eq!(quantized_medium.storage_type(), "FP16");

        let quantized_low = quantizer
            .quantize_weight(&weight, "channel_mix")
            .expect("Failed to quantize weight");
        assert_eq!(quantized_low.storage_type(), "INT8");
    }

    #[test]
    fn test_quantize_all_weights() {
        let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::INT8WeightOnly);

        let mut weights = HashMap::new();
        weights.insert(
            "layer1".to_string(),
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(),
        );
        weights.insert(
            "layer2".to_string(),
            Array2::from_shape_vec((2, 2), vec![-1.0, -2.0, -3.0, -4.0]).unwrap(),
        );

        let quantized = quantizer
            .quantize_weights(&weights)
            .expect("Failed to quantize weights");

        assert_eq!(quantized.len(), 2);
        assert!(quantized.contains_key("layer1"));
        assert!(quantized.contains_key("layer2"));
    }
547
548 #[test]
549 fn test_calculate_memory_savings() {
550 let quantizer = DynamicQuantizer::new().with_strategy(QuantStrategy::INT8WeightOnly);
551
552 let mut weights = HashMap::new();
553 weights.insert(
554 "layer1".to_string(),
555 Array2::from_shape_vec((100, 100), vec![1.0; 10000]).unwrap(),
556 );
557
558 let quantized = quantizer.quantize_weights(&weights).unwrap();
559 let stats = quantizer.calculate_memory_savings(&weights, &quantized);
560
561 assert_eq!(stats.original_size_bytes, 10000 * 4); assert_eq!(stats.quantized_size_bytes, 10000); assert!((stats.compression_ratio - 4.0).abs() < 0.01);
564 }
565
566 #[test]
567 fn test_quantization_stats_format() {
568 let stats = QuantizationStats {
569 original_size_bytes: 1024 * 1024 * 100, quantized_size_bytes: 1024 * 1024 * 25, compression_ratio: 4.0,
572 memory_saved_bytes: 1024 * 1024 * 75, strategy: QuantStrategy::INT8WeightOnly,
574 };
575
576 let formatted = QuantizationStats::format_size(stats.original_size_bytes);
577 assert!(formatted.contains("MB"));
578 }
579
580 #[test]
581 fn test_storage_to_fp32_roundtrip() {
582 let original = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0])
583 .expect("Failed to create test array");
584
585 let storage_fp32 = QuantizedWeightStorage::FP32(original.clone());
587 let restored = storage_fp32.to_fp32().expect("Failed to restore");
588 assert_eq!(restored, original);
589
590 let fp16 = FP16Weights::from_f32_2d(&original);
592 let storage_fp16 = QuantizedWeightStorage::FP16(fp16);
593 let restored_fp16 = storage_fp16.to_fp32().expect("Failed to restore");
594 assert_eq!(restored_fp16.dim(), original.dim());
595
596 let int8 = quantize_symmetric_2d(&original).expect("Failed to quantize");
598 let storage_int8 = QuantizedWeightStorage::INT8(int8);
599 let restored_int8 = storage_int8.to_fp32().expect("Failed to restore");
600 assert_eq!(restored_int8.dim(), original.dim());
601 }
602}