1use anyhow::{anyhow, Result};
8use scirs2_core::ndarray_ext::Array1;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use tracing::{debug, info};
12
/// How floating-point values are mapped onto the quantized integer grid.
///
/// NOTE(review): `QuantizedTensor::from_array` only distinguishes
/// `Symmetric` from everything else — `PerChannel`/`PerTensor` currently
/// fall through to the asymmetric path; confirm intended semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
    /// Zero-centered range; the zero point is fixed at 0.
    Symmetric,
    /// Affine mapping over [min, max] with a computed zero point.
    Asymmetric,
    /// Per-channel parameters.
    PerChannel,
    /// A single parameter set for the whole tensor.
    PerTensor,
}
25
/// Target integer precision for quantized values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BitWidth {
    /// 8-bit signed integers, range [-128, 127].
    Int8,
    /// 4-bit signed integers, range [-8, 7].
    Int4,
    /// Binary values, range [0, 1].
    Binary,
}
36
37impl BitWidth {
38 pub fn range(&self) -> (i32, i32) {
40 match self {
41 BitWidth::Int8 => (-128, 127),
42 BitWidth::Int4 => (-8, 7),
43 BitWidth::Binary => (0, 1),
44 }
45 }
46
47 pub fn bits(&self) -> usize {
49 match self {
50 BitWidth::Int8 => 8,
51 BitWidth::Int4 => 4,
52 BitWidth::Binary => 1,
53 }
54 }
55}
56
/// Configuration controlling how embeddings are quantized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Value-to-grid mapping strategy.
    pub scheme: QuantizationScheme,
    /// Target integer precision.
    pub bit_width: BitWidth,
    /// Whether `ModelQuantizer::calibrate` should run (it is a no-op if false).
    pub calibration: bool,
    /// Maximum number of embeddings sampled during calibration.
    pub calibration_samples: usize,
    /// Quantize weights only.
    /// NOTE(review): never read in this module — presumably consumed elsewhere.
    pub weights_only: bool,
    /// Quantization-aware-training flag.
    /// NOTE(review): never read in this module — presumably consumed elsewhere.
    pub qat: bool,
}
73
74impl Default for QuantizationConfig {
75 fn default() -> Self {
76 Self {
77 scheme: QuantizationScheme::Symmetric,
78 bit_width: BitWidth::Int8,
79 calibration: true,
80 calibration_samples: 1000,
81 weights_only: true,
82 qat: false,
83 }
84 }
85}
86
/// Affine quantization parameters: `real ≈ (q - zero_point) * scale`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
    /// Step size between adjacent quantization levels.
    pub scale: f32,
    /// Integer grid value that maps to real 0.0 (always 0 when symmetric).
    pub zero_point: i32,
    /// Minimum observed value the parameters were derived from.
    pub min_val: f32,
    /// Maximum observed value the parameters were derived from.
    pub max_val: f32,
}
99
100impl QuantizationParams {
101 pub fn from_statistics(
103 min_val: f32,
104 max_val: f32,
105 bit_width: BitWidth,
106 symmetric: bool,
107 ) -> Self {
108 let (qmin, qmax) = bit_width.range();
109
110 let (scale, zero_point) = if symmetric {
111 let max_abs = min_val.abs().max(max_val.abs());
113 let scale = (2.0 * max_abs) / (qmax - qmin) as f32;
114 (scale, 0)
115 } else {
116 let scale = (max_val - min_val) / (qmax - qmin) as f32;
118 let zero_point = qmin - (min_val / scale).round() as i32;
119 (scale, zero_point)
120 };
121
122 Self {
123 scale,
124 zero_point,
125 min_val,
126 max_val,
127 }
128 }
129
130 pub fn quantize(&self, value: f32, bit_width: BitWidth) -> i8 {
132 let (qmin, qmax) = bit_width.range();
133 let quantized = (value / self.scale).round() as i32 + self.zero_point;
134 quantized.clamp(qmin, qmax) as i8
135 }
136
137 pub fn dequantize(&self, quantized: i8) -> f32 {
139 (quantized as i32 - self.zero_point) as f32 * self.scale
140 }
141}
142
/// A quantized tensor: i8 storage plus the parameters needed to dequantize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
    /// Quantized values — one i8 per element regardless of bit width
    /// (sub-byte widths are not bit-packed).
    pub values: Vec<i8>,
    /// Parameters used to (de)quantize the values.
    pub params: QuantizationParams,
    /// Original tensor shape (always `[len]` for the 1-D constructor here).
    pub shape: Vec<usize>,
}
153
154impl QuantizedTensor {
155 pub fn from_array(array: &Array1<f32>, config: &QuantizationConfig) -> Self {
157 let min_val = array.iter().cloned().fold(f32::INFINITY, f32::min);
158 let max_val = array.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
159
160 let symmetric = matches!(config.scheme, QuantizationScheme::Symmetric);
161 let params =
162 QuantizationParams::from_statistics(min_val, max_val, config.bit_width, symmetric);
163
164 let values: Vec<i8> = array
165 .iter()
166 .map(|&v| params.quantize(v, config.bit_width))
167 .collect();
168
169 Self {
170 values,
171 params,
172 shape: vec![array.len()],
173 }
174 }
175
176 pub fn to_array(&self) -> Array1<f32> {
178 Array1::from_vec(
179 self.values
180 .iter()
181 .map(|&v| self.params.dequantize(v))
182 .collect(),
183 )
184 }
185
186 pub fn compression_ratio(&self) -> f32 {
188 let original_size = self.values.len() * 4;
191 let quantized_size = self.values.len() + std::mem::size_of::<QuantizationParams>();
192 original_size as f32 / quantized_size as f32
193 }
194
195 pub fn size_bytes(&self) -> usize {
197 self.values.len() + std::mem::size_of::<QuantizationParams>()
198 }
199}
200
/// Aggregate statistics gathered while quantizing a set of embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationStats {
    /// Total number of scalar parameters quantized.
    pub total_params: usize,
    /// Size of the original f32 data in bytes (4 bytes per element).
    pub original_size_bytes: usize,
    /// Size of the quantized data (values + params structs) in bytes.
    pub quantized_size_bytes: usize,
    /// `original_size_bytes / quantized_size_bytes`.
    pub compression_ratio: f32,
    /// Mean per-embedding RMSE between original and dequantized values.
    pub avg_quantization_error: f32,
    /// Largest per-embedding RMSE observed.
    pub max_quantization_error: f32,
}
217
218impl Default for QuantizationStats {
219 fn default() -> Self {
220 Self {
221 total_params: 0,
222 original_size_bytes: 0,
223 quantized_size_bytes: 0,
224 compression_ratio: 1.0,
225 avg_quantization_error: 0.0,
226 max_quantization_error: 0.0,
227 }
228 }
229}
230
/// Quantizes maps of named embeddings and tracks compression/error stats.
pub struct ModelQuantizer {
    /// Quantization settings applied to every embedding.
    config: QuantizationConfig,
    /// Statistics accumulated by `quantize_embeddings`.
    stats: QuantizationStats,
}
236
237impl ModelQuantizer {
238 pub fn new(config: QuantizationConfig) -> Self {
240 info!(
241 "Initialized model quantizer: scheme={:?}, bit_width={:?}",
242 config.scheme, config.bit_width
243 );
244
245 Self {
246 config,
247 stats: QuantizationStats::default(),
248 }
249 }
250
251 pub fn quantize_embeddings(
253 &mut self,
254 embeddings: &HashMap<String, Array1<f32>>,
255 ) -> Result<HashMap<String, QuantizedTensor>> {
256 if embeddings.is_empty() {
257 return Err(anyhow!("No embeddings to quantize"));
258 }
259
260 info!("Quantizing {} embeddings", embeddings.len());
261
262 let mut quantized_embeddings = HashMap::new();
263 let mut total_error = 0.0;
264 let mut max_error: f32 = 0.0;
265
266 for (entity, embedding) in embeddings {
267 let quantized = QuantizedTensor::from_array(embedding, &self.config);
268
269 let dequantized = quantized.to_array();
271 let error = self.compute_error(embedding, &dequantized);
272 total_error += error;
273 max_error = max_error.max(error);
274
275 self.stats.original_size_bytes += embedding.len() * 4;
277 self.stats.quantized_size_bytes += quantized.size_bytes();
278
279 quantized_embeddings.insert(entity.clone(), quantized);
280 }
281
282 self.stats.total_params = embeddings.values().map(|e| e.len()).sum();
283 self.stats.compression_ratio =
284 self.stats.original_size_bytes as f32 / self.stats.quantized_size_bytes as f32;
285 self.stats.avg_quantization_error = total_error / embeddings.len() as f32;
286 self.stats.max_quantization_error = max_error;
287
288 info!(
289 "Quantization complete: compression_ratio={:.2}x, avg_error={:.6}",
290 self.stats.compression_ratio, self.stats.avg_quantization_error
291 );
292
293 Ok(quantized_embeddings)
294 }
295
296 pub fn dequantize_embeddings(
298 &self,
299 quantized: &HashMap<String, QuantizedTensor>,
300 ) -> HashMap<String, Array1<f32>> {
301 quantized
302 .iter()
303 .map(|(entity, q)| (entity.clone(), q.to_array()))
304 .collect()
305 }
306
307 pub fn quantize_embedding(&self, embedding: &Array1<f32>) -> QuantizedTensor {
309 QuantizedTensor::from_array(embedding, &self.config)
310 }
311
312 pub fn dequantize_embedding(&self, quantized: &QuantizedTensor) -> Array1<f32> {
314 quantized.to_array()
315 }
316
317 fn compute_error(&self, original: &Array1<f32>, dequantized: &Array1<f32>) -> f32 {
319 let diff = original - dequantized;
320 let mse = diff.dot(&diff) / original.len() as f32;
321 mse.sqrt() }
323
324 pub fn calibrate(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<()> {
326 if !self.config.calibration {
327 return Ok(());
328 }
329
330 info!(
331 "Calibrating quantization parameters with {} samples",
332 self.config.calibration_samples.min(embeddings.len())
333 );
334
335 let samples: Vec<&Array1<f32>> = embeddings
337 .values()
338 .take(self.config.calibration_samples)
339 .collect();
340
341 let mut global_min = f32::INFINITY;
343 let mut global_max = f32::NEG_INFINITY;
344
345 for embedding in samples {
346 let min = embedding.iter().cloned().fold(f32::INFINITY, f32::min);
347 let max = embedding.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
348 global_min = global_min.min(min);
349 global_max = global_max.max(max);
350 }
351
352 debug!(
353 "Calibration complete: min={:.6}, max={:.6}",
354 global_min, global_max
355 );
356
357 Ok(())
358 }
359
360 pub fn get_stats(&self) -> &QuantizationStats {
362 &self.stats
363 }
364
365 pub fn estimate_speedup(&self) -> f32 {
367 match self.config.bit_width {
369 BitWidth::Int8 => 3.0,
370 BitWidth::Int4 => 5.0,
371 BitWidth::Binary => 10.0,
372 }
373 }
374
375 pub fn config(&self) -> &QuantizationConfig {
377 &self.config
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    // Symmetric params over [-10, 10] must have a positive scale and a
    // zero point pinned at 0.
    #[test]
    fn test_quantization_params() {
        let min_val = -10.0;
        let max_val = 10.0;

        let params = QuantizationParams::from_statistics(
            min_val,
            max_val,
            BitWidth::Int8,
            true,
        );

        assert!(params.scale > 0.0);
        assert_eq!(params.zero_point, 0);
    }

    // A quantize/dequantize round trip should stay within one quantization
    // step of the original value.
    #[test]
    fn test_quantize_dequantize() {
        let params = QuantizationParams::from_statistics(-10.0, 10.0, BitWidth::Int8, true);

        let value = 5.0;
        let quantized = params.quantize(value, BitWidth::Int8);
        let dequantized = params.dequantize(quantized);

        assert!((value - dequantized).abs() < 1.0);
    }

    // Tensor-level round trip: element count preserved, and i8 storage
    // compresses relative to f32.
    #[test]
    fn test_quantized_tensor() {
        let array = Array1::from_vec((0..128).map(|i| i as f32 * 0.1).collect());
        let config = QuantizationConfig::default();

        let quantized = QuantizedTensor::from_array(&array, &config);
        let dequantized = quantized.to_array();

        assert_eq!(quantized.values.len(), 128);
        assert_eq!(dequantized.len(), 128);

        assert!(quantized.compression_ratio() > 1.0);
    }

    // End-to-end quantizer over two embeddings; stats must be populated.
    #[test]
    fn test_model_quantizer() {
        let mut embeddings = HashMap::new();
        embeddings.insert(
            "e1".to_string(),
            Array1::from_vec((0..128).map(|i| i as f32 * 0.1).collect()),
        );
        embeddings.insert(
            "e2".to_string(),
            Array1::from_vec((0..128).map(|i| (i as f32 * 0.1) + 10.0).collect()),
        );

        let config = QuantizationConfig::default();
        let mut quantizer = ModelQuantizer::new(config);

        let quantized = quantizer.quantize_embeddings(&embeddings).unwrap();

        assert_eq!(quantized.len(), 2);
        assert!(quantizer.stats.compression_ratio > 1.0);
        assert!(quantizer.stats.avg_quantization_error >= 0.0);
    }

    // Per-element round-trip error stays below one quantization step.
    #[test]
    fn test_roundtrip() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, -2.0, 3.5, -4.2]);

        let config = QuantizationConfig::default();
        let mut quantizer = ModelQuantizer::new(config);

        let quantized = quantizer.quantize_embeddings(&embeddings).unwrap();
        let dequantized = quantizer.dequantize_embeddings(&quantized);

        assert_eq!(dequantized.len(), 1);

        let original = &embeddings["e1"];
        let recovered = &dequantized["e1"];

        for i in 0..original.len() {
            let error = (original[i] - recovered[i]).abs();
            assert!(error < 1.0);
        }
    }

    // With int8 storage + per-tensor params overhead, 100x128 f32 values
    // should land near 4x compression (between 3x and 5x).
    #[test]
    fn test_compression_ratio() {
        let mut embeddings = HashMap::new();
        for i in 0..100 {
            let emb = Array1::from_vec(vec![i as f32; 128]);
            embeddings.insert(format!("e{}", i), emb);
        }

        let config = QuantizationConfig::default();
        let mut quantizer = ModelQuantizer::new(config);

        quantizer.quantize_embeddings(&embeddings).unwrap();

        assert!(quantizer.stats.compression_ratio > 3.0);
        assert!(quantizer.stats.compression_ratio < 5.0);
    }
}