use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info};
/// Strategy for mapping floating-point values onto the integer grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationScheme {
/// Zero-point fixed at 0; the grid is symmetric around zero.
Symmetric,
/// Scale and zero-point are both fitted to the observed min/max.
Asymmetric,
/// One parameter set per channel.
/// NOTE(review): `QuantizedTensor::from_array` always derives per-tensor
/// statistics and only checks for `Symmetric`, so this variant currently
/// behaves like `Asymmetric` — confirm intent.
PerChannel,
/// A single parameter set for the whole tensor.
/// NOTE(review): also treated as `Asymmetric` by `from_array` — confirm.
PerTensor,
}
/// Target integer bit-width for quantized values.
///
/// Note that storage is one `i8` per value regardless of the nominal width
/// (see `QuantizedTensor::size_bytes`); sub-byte widths are not bit-packed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BitWidth {
/// Signed 8-bit grid: [-128, 127].
Int8,
/// Signed 4-bit grid: [-8, 7].
Int4,
/// Unsigned binary grid: {0, 1}.
Binary,
}
impl BitWidth {
pub fn range(&self) -> (i32, i32) {
match self {
BitWidth::Int8 => (-128, 127),
BitWidth::Int4 => (-8, 7),
BitWidth::Binary => (0, 1),
}
}
pub fn bits(&self) -> usize {
match self {
BitWidth::Int8 => 8,
BitWidth::Int4 => 4,
BitWidth::Binary => 1,
}
}
}
/// Settings controlling how a model's embeddings are quantized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
/// How values are mapped onto the integer grid.
pub scheme: QuantizationScheme,
/// Target integer width for quantized values.
pub bit_width: BitWidth,
/// Whether `ModelQuantizer::calibrate` scans value ranges before quantizing.
pub calibration: bool,
/// Maximum number of embeddings sampled during calibration.
pub calibration_samples: usize,
/// Quantize weights only (activations untouched).
/// NOTE(review): not read anywhere in this file — confirm it is consumed elsewhere.
pub weights_only: bool,
/// Quantization-aware training flag.
/// NOTE(review): not read anywhere in this file — confirm it is consumed elsewhere.
pub qat: bool,
}
impl Default for QuantizationConfig {
fn default() -> Self {
Self {
scheme: QuantizationScheme::Symmetric,
bit_width: BitWidth::Int8,
calibration: true,
calibration_samples: 1000,
weights_only: true,
qat: false,
}
}
}
/// Affine quantization parameters: `real ≈ (q - zero_point) * scale`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParams {
/// Step size between adjacent grid points.
pub scale: f32,
/// Integer offset; 0 for symmetric quantization.
pub zero_point: i32,
/// Minimum observed value used to derive these parameters.
pub min_val: f32,
/// Maximum observed value used to derive these parameters.
pub max_val: f32,
}
impl QuantizationParams {
    /// Derive quantization parameters from observed min/max statistics.
    ///
    /// In the symmetric case the zero-point is fixed at 0 and the scale spans
    /// `[-max_abs, max_abs]`; otherwise scale and zero-point are both fitted
    /// to `[min_val, max_val]`.
    ///
    /// A degenerate range (e.g. a constant or all-zero tensor) previously
    /// produced `scale == 0`, making `quantize` divide by zero (NaN/inf, and
    /// a saturating garbage `zero_point` in the asymmetric case). The scale
    /// is now floored to a usable positive value instead.
    pub fn from_statistics(
        min_val: f32,
        max_val: f32,
        bit_width: BitWidth,
        symmetric: bool,
    ) -> Self {
        let (qmin, qmax) = bit_width.range();
        let span = (qmax - qmin) as f32;
        let (scale, zero_point) = if symmetric {
            let max_abs = min_val.abs().max(max_val.abs());
            (Self::sanitize_scale((2.0 * max_abs) / span), 0)
        } else {
            let scale = Self::sanitize_scale((max_val - min_val) / span);
            let zero_point = qmin - (min_val / scale).round() as i32;
            (scale, zero_point)
        };
        Self {
            scale,
            zero_point,
            min_val,
            max_val,
        }
    }
    /// Replace a zero, negative, or non-finite scale with 1.0 so that
    /// `quantize`/`dequantize` never divide by zero or propagate NaN.
    fn sanitize_scale(raw: f32) -> f32 {
        if raw.is_finite() && raw > 0.0 {
            raw
        } else {
            1.0
        }
    }
    /// Map a real value onto the integer grid, clamped to the bit-width range.
    pub fn quantize(&self, value: f32, bit_width: BitWidth) -> i8 {
        let (qmin, qmax) = bit_width.range();
        let quantized = (value / self.scale).round() as i32 + self.zero_point;
        quantized.clamp(qmin, qmax) as i8
    }
    /// Reconstruct the approximate real value for a quantized integer.
    pub fn dequantize(&self, quantized: i8) -> f32 {
        (quantized as i32 - self.zero_point) as f32 * self.scale
    }
}
/// A quantized 1-D tensor: integer values plus the parameters to reconstruct it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedTensor {
/// One `i8` per element, regardless of the configured bit width.
pub values: Vec<i8>,
/// Parameters used to quantize (and needed to dequantize) `values`.
pub params: QuantizationParams,
/// Original tensor shape; always `vec![len]` for arrays built by `from_array`.
pub shape: Vec<usize>,
}
impl QuantizedTensor {
    /// Quantize a 1-D array using min/max statistics derived from the array
    /// itself.
    ///
    /// An empty array previously produced `±infinity` statistics (the fold
    /// identities); it now yields a benign `(0.0, 0.0)` range and an empty
    /// value vector.
    pub fn from_array(array: &Array1<f32>, config: &QuantizationConfig) -> Self {
        // Single pass over the data for both extrema; `f32::min`/`f32::max`
        // ignore NaN operands exactly as the original two-fold version did.
        let (min_val, max_val) = if array.len() == 0 {
            (0.0, 0.0)
        } else {
            array
                .iter()
                .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &v| {
                    (lo.min(v), hi.max(v))
                })
        };
        let symmetric = matches!(config.scheme, QuantizationScheme::Symmetric);
        let params =
            QuantizationParams::from_statistics(min_val, max_val, config.bit_width, symmetric);
        let values: Vec<i8> = array
            .iter()
            .map(|&v| params.quantize(v, config.bit_width))
            .collect();
        Self {
            values,
            params,
            shape: vec![array.len()],
        }
    }
    /// Dequantize back into a dense `f32` array.
    pub fn to_array(&self) -> Array1<f32> {
        Array1::from_vec(
            self.values
                .iter()
                .map(|&v| self.params.dequantize(v))
                .collect(),
        )
    }
    /// Ratio of the original `f32` payload size to the quantized footprint
    /// (values plus parameter overhead).
    ///
    /// Every value occupies one byte regardless of `BitWidth`, so sub-byte
    /// widths do not improve this ratio.
    pub fn compression_ratio(&self) -> f32 {
        let original_size = self.values.len() * std::mem::size_of::<f32>();
        let quantized_size = self.size_bytes();
        original_size as f32 / quantized_size as f32
    }
    /// Storage footprint in bytes: one byte per value plus the parameters.
    pub fn size_bytes(&self) -> usize {
        self.values.len() + std::mem::size_of::<QuantizationParams>()
    }
}
/// Aggregate statistics from the most recent quantization pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationStats {
/// Total number of scalar parameters quantized.
pub total_params: usize,
/// Size of the original `f32` data in bytes.
pub original_size_bytes: usize,
/// Size of the quantized data (values + params) in bytes.
pub quantized_size_bytes: usize,
/// `original_size_bytes / quantized_size_bytes`.
pub compression_ratio: f32,
/// Mean per-embedding RMSE between original and dequantized values.
pub avg_quantization_error: f32,
/// Largest per-embedding RMSE observed.
pub max_quantization_error: f32,
}
impl Default for QuantizationStats {
fn default() -> Self {
Self {
total_params: 0,
original_size_bytes: 0,
quantized_size_bytes: 0,
compression_ratio: 1.0,
avg_quantization_error: 0.0,
max_quantization_error: 0.0,
}
}
}
/// Quantizes named embedding vectors and tracks compression/error statistics.
pub struct ModelQuantizer {
// Immutable quantization settings chosen at construction.
config: QuantizationConfig,
// Statistics from quantization passes; exposed via `get_stats`.
stats: QuantizationStats,
}
impl ModelQuantizer {
    /// Create a quantizer with the given configuration and empty statistics.
    pub fn new(config: QuantizationConfig) -> Self {
        info!(
            "Initialized model quantizer: scheme={:?}, bit_width={:?}",
            config.scheme, config.bit_width
        );
        Self {
            config,
            stats: QuantizationStats::default(),
        }
    }
    /// Quantize a map of named embeddings, recording compression and error
    /// statistics for the pass.
    ///
    /// Statistics are recomputed from scratch on every call. (Previously the
    /// byte counters accumulated with `+=` across calls while `total_params`
    /// and the error fields were overwritten, leaving the stats internally
    /// inconsistent after a second call.)
    ///
    /// # Errors
    /// Returns an error when `embeddings` is empty.
    pub fn quantize_embeddings(
        &mut self,
        embeddings: &HashMap<String, Array1<f32>>,
    ) -> Result<HashMap<String, QuantizedTensor>> {
        if embeddings.is_empty() {
            return Err(anyhow!("No embeddings to quantize"));
        }
        info!("Quantizing {} embeddings", embeddings.len());
        // Start from a clean slate so the ratios/averages derived below are
        // consistent with this call's counters.
        self.stats = QuantizationStats::default();
        let mut quantized_embeddings = HashMap::with_capacity(embeddings.len());
        let mut total_error = 0.0;
        let mut max_error: f32 = 0.0;
        for (entity, embedding) in embeddings {
            let quantized = QuantizedTensor::from_array(embedding, &self.config);
            let dequantized = quantized.to_array();
            let error = self.compute_error(embedding, &dequantized);
            total_error += error;
            max_error = max_error.max(error);
            self.stats.total_params += embedding.len();
            self.stats.original_size_bytes += embedding.len() * std::mem::size_of::<f32>();
            self.stats.quantized_size_bytes += quantized.size_bytes();
            quantized_embeddings.insert(entity.clone(), quantized);
        }
        self.stats.compression_ratio =
            self.stats.original_size_bytes as f32 / self.stats.quantized_size_bytes as f32;
        self.stats.avg_quantization_error = total_error / embeddings.len() as f32;
        self.stats.max_quantization_error = max_error;
        info!(
            "Quantization complete: compression_ratio={:.2}x, avg_error={:.6}",
            self.stats.compression_ratio, self.stats.avg_quantization_error
        );
        Ok(quantized_embeddings)
    }
    /// Reconstruct every tensor back into a dense `f32` array.
    pub fn dequantize_embeddings(
        &self,
        quantized: &HashMap<String, QuantizedTensor>,
    ) -> HashMap<String, Array1<f32>> {
        quantized
            .iter()
            .map(|(entity, q)| (entity.clone(), q.to_array()))
            .collect()
    }
    /// Quantize a single embedding using this quantizer's configuration.
    /// Does not touch the aggregate statistics.
    pub fn quantize_embedding(&self, embedding: &Array1<f32>) -> QuantizedTensor {
        QuantizedTensor::from_array(embedding, &self.config)
    }
    /// Reconstruct a single embedding from its quantized form.
    pub fn dequantize_embedding(&self, quantized: &QuantizedTensor) -> Array1<f32> {
        quantized.to_array()
    }
    /// Root-mean-square error between an original vector and its reconstruction.
    fn compute_error(&self, original: &Array1<f32>, dequantized: &Array1<f32>) -> f32 {
        let diff = original - dequantized;
        let mse = diff.dot(&diff) / original.len() as f32;
        mse.sqrt()
    }
    /// Scan up to `calibration_samples` embeddings for their global value range.
    ///
    /// No-op when `config.calibration` is false.
    ///
    /// NOTE(review): the computed min/max are only logged, never stored or fed
    /// back into parameter derivation — confirm whether calibration is meant
    /// to influence later quantization (storing the range would require a new
    /// field on `ModelQuantizer`).
    pub fn calibrate(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<()> {
        if !self.config.calibration {
            return Ok(());
        }
        info!(
            "Calibrating quantization parameters with {} samples",
            self.config.calibration_samples.min(embeddings.len())
        );
        let samples: Vec<&Array1<f32>> = embeddings
            .values()
            .take(self.config.calibration_samples)
            .collect();
        let mut global_min = f32::INFINITY;
        let mut global_max = f32::NEG_INFINITY;
        for embedding in samples {
            let min = embedding.iter().cloned().fold(f32::INFINITY, f32::min);
            let max = embedding.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            global_min = global_min.min(min);
            global_max = global_max.max(max);
        }
        debug!(
            "Calibration complete: min={:.6}, max={:.6}",
            global_min, global_max
        );
        Ok(())
    }
    /// Statistics from the most recent `quantize_embeddings` call.
    pub fn get_stats(&self) -> &QuantizationStats {
        &self.stats
    }
    /// Rough inference speedup factor for the configured bit width.
    /// NOTE(review): hard-coded heuristics, not measured values.
    pub fn estimate_speedup(&self) -> f32 {
        match self.config.bit_width {
            BitWidth::Int8 => 3.0,
            BitWidth::Int4 => 5.0,
            BitWidth::Binary => 10.0,
        }
    }
    /// The configuration this quantizer was built with.
    pub fn config(&self) -> &QuantizationConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray_ext::array;
// Symmetric derivation yields a positive scale and a zero-point pinned at 0.
#[test]
fn test_quantization_params() {
let min_val = -10.0;
let max_val = 10.0;
let params = QuantizationParams::from_statistics(
min_val,
max_val,
BitWidth::Int8,
true, );
assert!(params.scale > 0.0);
assert_eq!(params.zero_point, 0); }
// A single value round-trips to within one quantization step.
#[test]
fn test_quantize_dequantize() {
let params = QuantizationParams::from_statistics(-10.0, 10.0, BitWidth::Int8, true);
let value = 5.0;
let quantized = params.quantize(value, BitWidth::Int8);
let dequantized = params.dequantize(quantized);
assert!((value - dequantized).abs() < 1.0);
}
// Tensor round-trip preserves length and achieves some compression.
#[test]
fn test_quantized_tensor() {
let array = Array1::from_vec((0..128).map(|i| i as f32 * 0.1).collect());
let config = QuantizationConfig::default();
let quantized = QuantizedTensor::from_array(&array, &config);
let dequantized = quantized.to_array();
assert_eq!(quantized.values.len(), 128);
assert_eq!(dequantized.len(), 128);
assert!(quantized.compression_ratio() > 1.0);
}
// End-to-end pass over a map of embeddings populates the stats sensibly.
#[test]
fn test_model_quantizer() {
let mut embeddings = HashMap::new();
embeddings.insert(
"e1".to_string(),
Array1::from_vec((0..128).map(|i| i as f32 * 0.1).collect()),
);
embeddings.insert(
"e2".to_string(),
Array1::from_vec((0..128).map(|i| (i as f32 * 0.1) + 10.0).collect()),
);
let config = QuantizationConfig::default();
let mut quantizer = ModelQuantizer::new(config);
let quantized = quantizer
.quantize_embeddings(&embeddings)
.expect("should succeed");
assert_eq!(quantized.len(), 2);
assert!(quantizer.stats.compression_ratio > 1.0);
assert!(quantizer.stats.avg_quantization_error >= 0.0);
}
// Quantize-then-dequantize recovers each element to within one step.
#[test]
fn test_roundtrip() {
let mut embeddings = HashMap::new();
embeddings.insert("e1".to_string(), array![1.0, -2.0, 3.5, -4.2]);
let config = QuantizationConfig::default();
let mut quantizer = ModelQuantizer::new(config);
let quantized = quantizer
.quantize_embeddings(&embeddings)
.expect("should succeed");
let dequantized = quantizer.dequantize_embeddings(&quantized);
assert_eq!(dequantized.len(), 1);
let original = &embeddings["e1"];
let recovered = &dequantized["e1"];
for i in 0..original.len() {
let error = (original[i] - recovered[i]).abs();
assert!(error < 1.0);
}
}
// Int8 over f32 should land between ~3.5x (params overhead) and the ideal 4x.
// Note the "e0" embedding is all zeros, exercising the degenerate-range path.
#[test]
fn test_compression_ratio() {
let mut embeddings = HashMap::new();
for i in 0..100 {
let emb = Array1::from_vec(vec![i as f32; 128]);
embeddings.insert(format!("e{}", i), emb);
}
let config = QuantizationConfig::default();
let mut quantizer = ModelQuantizer::new(config);
quantizer
.quantize_embeddings(&embeddings)
.expect("should succeed");
assert!(quantizer.stats.compression_ratio > 3.0);
assert!(quantizer.stats.compression_ratio < 5.0);
}
}