use half::f16;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
use trustformers_core::errors::{invalid_config, runtime_error, tensor_op_error, Result};
use trustformers_core::Tensor;
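/// Quantization schemes supported by the mobile runtime. The `GGUF_*`
/// variants name llama.cpp-style GGUF quantization formats; see
/// `DynamicQuantizer` for how they are currently routed.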
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[allow(non_camel_case_types)]
pub enum QuantizationScheme {
Int4,
Int8,
FP16,
Dynamic,
GGUF_Q2_K,
GGUF_Q3_K,
GGUF_Q4_K,
GGUF_Q5_0,
GGUF_Q6_K,
}
impl std::fmt::Display for QuantizationScheme {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
QuantizationScheme::Int4 => write!(f, "INT4"),
QuantizationScheme::Int8 => write!(f, "INT8"),
QuantizationScheme::FP16 => write!(f, "FP16"),
QuantizationScheme::Dynamic => write!(f, "Dynamic"),
QuantizationScheme::GGUF_Q2_K => write!(f, "GGUF_Q2_K"),
QuantizationScheme::GGUF_Q3_K => write!(f, "GGUF_Q3_K"),
QuantizationScheme::GGUF_Q4_K => write!(f, "GGUF_Q4_K"),
QuantizationScheme::GGUF_Q5_0 => write!(f, "GGUF_Q5_0"),
QuantizationScheme::GGUF_Q6_K => write!(f, "GGUF_Q6_K"),
}
}
}
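/// Strategy for estimating a tensor's dynamic range during calibration.
/// `MovingAverage` and `KLDivergence` are not implemented yet and fall back
/// to plain min/max (see `Int8Quantizer::calibrate`).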
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethod {
MinMax,
Percentile,
MovingAverage,
KLDivergence,
}
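/// Calibration parameters: the method to use, how many samples to collect,
/// the clipping percentile for `Percentile`, and a smoothing factor intended
/// for `MovingAverage`.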
#[derive(Debug, Clone)]
pub struct QuantizationContext {
pub method: CalibrationMethod,
pub num_calibration_samples: usize,
    pub percentile: f32,
    pub smooth_factor: f32,
}
impl Default for QuantizationContext {
fn default() -> Self {
Self {
method: CalibrationMethod::MinMax,
num_calibration_samples: 100,
percentile: 99.9,
smooth_factor: 0.999,
}
}
}
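/// Calibration statistics keyed by tensor name (the built-in quantizers use
/// the single key "global"): observed ranges, derived scales and zero points,
/// and optional histograms.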
#[derive(Debug, Clone, Default)]
pub struct QuantizationCalibration {
pub min_values: HashMap<String, f32>,
pub max_values: HashMap<String, f32>,
pub scales: HashMap<String, f32>,
pub zero_points: HashMap<String, i32>,
pub histogram_bins: HashMap<String, Vec<f32>>,
}
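/// Scheme assignments at several granularities. Lookups resolve per-tensor
/// first, then per-layer, per-model, and finally `default_scheme` (see
/// `QuantizationSchemeStorage::determine_scheme`).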
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationSchemeConfig {
pub default_scheme: QuantizationScheme,
pub layer_schemes: HashMap<String, QuantizationScheme>,
pub tensor_schemes: HashMap<String, QuantizationScheme>,
pub model_schemes: HashMap<String, QuantizationScheme>,
pub performance_schemes: HashMap<String, QuantizationScheme>,
}
impl Default for QuantizationSchemeConfig {
fn default() -> Self {
Self {
default_scheme: QuantizationScheme::Int8,
layer_schemes: HashMap::new(),
tensor_schemes: HashMap::new(),
model_schemes: HashMap::new(),
performance_schemes: HashMap::new(),
}
}
}
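/// Holds a `QuantizationSchemeConfig`, optionally backed by a JSON file, plus
/// a per-tensor cache of resolved schemes.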
#[derive(Debug, Clone)]
pub struct QuantizationSchemeStorage {
pub config_path: Option<PathBuf>,
pub config: QuantizationSchemeConfig,
pub scheme_cache: HashMap<String, QuantizationScheme>,
}
impl Default for QuantizationSchemeStorage {
fn default() -> Self {
Self::new()
}
}
impl QuantizationSchemeStorage {
pub fn new() -> Self {
Self {
config_path: None,
config: QuantizationSchemeConfig::default(),
scheme_cache: HashMap::new(),
}
}
pub fn with_config_file<P: AsRef<Path>>(path: P) -> Result<Self> {
let config_path = path.as_ref().to_path_buf();
let config = Self::load_config(&config_path)?;
Ok(Self {
config_path: Some(config_path),
config,
scheme_cache: HashMap::new(),
})
}
pub fn load_config<P: AsRef<Path>>(path: P) -> Result<QuantizationSchemeConfig> {
let file = File::open(path.as_ref())
.map_err(|e| runtime_error(format!("Failed to open config file: {}", e)))?;
let reader = BufReader::new(file);
serde_json::from_reader(reader)
.map_err(|e| invalid_config("load_config", format!("Failed to parse config: {}", e)))
}
pub fn save_config(&self) -> Result<()> {
if let Some(ref path) = self.config_path {
let file = File::create(path)
.map_err(|e| runtime_error(format!("Failed to create config file: {}", e)))?;
serde_json::to_writer_pretty(file, &self.config)
.map_err(|e| runtime_error(format!("Failed to write config: {}", e)))?;
}
Ok(())
}
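    /// Resolves the scheme for a tensor, consulting the cache first and then
    /// the tensor -> layer -> model -> default precedence chain; the result
    /// is cached under `tensor_id`.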
pub fn determine_scheme(
&mut self,
tensor_id: &str,
layer_name: Option<&str>,
model_name: Option<&str>,
) -> QuantizationScheme {
if let Some(&scheme) = self.scheme_cache.get(tensor_id) {
return scheme;
}
if let Some(&scheme) = self.config.tensor_schemes.get(tensor_id) {
self.scheme_cache.insert(tensor_id.to_string(), scheme);
return scheme;
}
if let Some(layer) = layer_name {
if let Some(&scheme) = self.config.layer_schemes.get(layer) {
self.scheme_cache.insert(tensor_id.to_string(), scheme);
return scheme;
}
}
if let Some(model) = model_name {
if let Some(&scheme) = self.config.model_schemes.get(model) {
self.scheme_cache.insert(tensor_id.to_string(), scheme);
return scheme;
}
}
let default_scheme = self.config.default_scheme;
self.scheme_cache.insert(tensor_id.to_string(), default_scheme);
default_scheme
}
pub fn set_tensor_scheme(&mut self, tensor_id: String, scheme: QuantizationScheme) {
self.config.tensor_schemes.insert(tensor_id.clone(), scheme);
self.scheme_cache.insert(tensor_id, scheme);
}
pub fn set_layer_scheme(&mut self, layer_name: String, scheme: QuantizationScheme) {
self.config.layer_schemes.insert(layer_name, scheme);
}
pub fn set_model_scheme(&mut self, model_name: String, scheme: QuantizationScheme) {
self.config.model_schemes.insert(model_name, scheme);
}
pub fn clear_cache(&mut self) {
self.scheme_cache.clear();
}
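    /// Builds an identifier from the tensor's shape and a sampled hash of its
    /// contents, e.g. `"layer0:2x4:a1b2c3"`. Tensors with identical shapes and
    /// sampled values deliberately collide.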
pub fn generate_tensor_id(tensor: &Tensor, layer_name: Option<&str>) -> String {
let shape_str = tensor.shape().iter().map(|&s| s.to_string()).collect::<Vec<_>>().join("x");
        let data_hash = {
            if let Ok(data) = tensor.data() {
                // Sample roughly 100 evenly spaced elements (stride clamped to
                // [1, 1000]) and fold their bit patterns into a cheap hash.
                let stride = (data.len() / 100).clamp(1, 1000);
                let mut hash = 0u64;
                for i in (0..data.len()).step_by(stride) {
                    hash = hash.wrapping_mul(31).wrapping_add(data[i].to_bits() as u64);
                }
                hash
            } else {
                0u64
            }
        };
match layer_name {
Some(layer) => format!("{}:{}:{:x}", layer, shape_str, data_hash),
None => format!("tensor:{}:{:x}", shape_str, data_hash),
}
}
}
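/// Common interface for mobile quantizers. A minimal usage sketch (assuming a
/// `Tensor::from_vec` constructor, as used in the tests below):
///
/// ```ignore
/// let quantizer = Int8Quantizer::new();
/// let t = Tensor::from_vec(vec![-1.0, 0.0, 1.0], &[3])?;
/// if quantizer.requires_calibration() {
///     quantizer.calibrate(std::slice::from_ref(&t))?;
/// }
/// let q = quantizer.quantize_tensor(&t)?;
/// let restored = quantizer.dequantize_tensor(&q)?;
/// ```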
pub trait MobileQuantizer: Send + Sync {
fn get_scheme(&self) -> QuantizationScheme;
fn requires_calibration(&self) -> bool;
fn calibrate(&self, data: &[Tensor]) -> Result<()>;
fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor>;
fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor>;
}
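/// 4-bit affine quantizer. Quantized values are currently stored as one `f32`
/// per element; no nibble packing is performed, so this measures accuracy
/// impact rather than saving memory until a packed format is wired in.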
pub struct Int4Quantizer {
context: QuantizationContext,
calibration: std::sync::RwLock<QuantizationCalibration>,
}
impl Default for Int4Quantizer {
fn default() -> Self {
Self::new()
}
}
impl Int4Quantizer {
pub fn new() -> Self {
Self {
context: QuantizationContext::default(),
calibration: std::sync::RwLock::new(QuantizationCalibration::default()),
}
}
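    /// Affine f32 -> INT4 mapping: `scale = (max - min) / (qmax - qmin)` and
    /// `zero_point = round(qmin - min / scale)`, clamped to [-8, 7].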
    fn compute_scale_zero_point(&self, min_val: f32, max_val: f32) -> (f32, i32) {
        let qmin = -8.0;
        let qmax = 7.0;
        let range = max_val - min_val;
        // A constant tensor would yield scale == 0 and divide by zero below.
        if range <= f32::EPSILON {
            return (1.0, 0);
        }
        let scale = range / (qmax - qmin);
        let zero_point = ((qmin - min_val / scale).round() as i32).clamp(-8, 7);
        (scale, zero_point)
    }
fn quantize_value(&self, value: f32, scale: f32, zero_point: i32) -> i8 {
let quantized = (value / scale).round() as i32 + zero_point;
quantized.clamp(-8, 7) as i8
}
fn dequantize_value(&self, quantized: i8, scale: f32, zero_point: i32) -> f32 {
(quantized as i32 - zero_point) as f32 * scale
}
}
impl MobileQuantizer for Int4Quantizer {
fn get_scheme(&self) -> QuantizationScheme {
QuantizationScheme::Int4
}
fn requires_calibration(&self) -> bool {
true
}
    fn calibrate(&self, data: &[Tensor]) -> Result<()> {
        let mut calibration = self.calibration.write().expect("RwLock poisoned");
        // Aggregate min/max across all calibration tensors so that earlier
        // samples are not overwritten by the last one.
        let mut min_val = f32::INFINITY;
        let mut max_val = f32::NEG_INFINITY;
        for tensor in data {
            let tensor_data = tensor.data()?;
            min_val = tensor_data.iter().fold(min_val, |a, &b| a.min(b));
            max_val = tensor_data.iter().fold(max_val, |a, &b| a.max(b));
        }
        let (scale, zero_point) = self.compute_scale_zero_point(min_val, max_val);
        calibration.min_values.insert("global".to_string(), min_val);
        calibration.max_values.insert("global".to_string(), max_val);
        calibration.scales.insert("global".to_string(), scale);
        calibration.zero_points.insert("global".to_string(), zero_point);
        Ok(())
    }
fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
let calibration = self.calibration.read().expect("RwLock poisoned");
let tensor_data = tensor.data()?;
let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
(
scale,
*calibration.zero_points.get("global").expect("No global zero point"),
)
} else {
let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
self.compute_scale_zero_point(min_val, max_val)
};
let quantized_data: Vec<i8> =
tensor_data.iter().map(|&x| self.quantize_value(x, scale, zero_point)).collect();
let quantized_f32: Vec<f32> = quantized_data.iter().map(|&x| x as f32).collect();
let quantized_tensor = Tensor::from_vec(quantized_f32, &tensor.shape())?;
Ok(quantized_tensor)
}
fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
let calibration = self.calibration.read().expect("RwLock poisoned");
let tensor_data = tensor.data()?;
let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
(
scale,
*calibration.zero_points.get("global").expect("No global zero point"),
)
        } else {
            // No calibration available: derive a heuristic scale from the
            // observed quantized range (at most 16 INT4 levels). Widen to i32
            // before subtracting to avoid i8 overflow.
            let min_q = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b)) as i32;
            let max_q = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)) as i32;
            let range = (max_q - min_q) as f32;
            let scale = if range > 0.0 { 15.0 / range } else { 1.0 };
            (scale, 0)
        };
let dequantized_data: Vec<f32> = tensor_data
.iter()
.map(|&x| self.dequantize_value(x as i8, scale, zero_point))
.collect();
Tensor::from_vec(dequantized_data, &tensor.shape())
}
}
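/// 8-bit quantizer with symmetric (zero_point = 0) and asymmetric (affine)
/// modes; `new()` defaults to symmetric with min/max calibration. Like
/// `Int4Quantizer`, it stores quantized values as `f32`.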
pub struct Int8Quantizer {
context: QuantizationContext,
calibration: std::sync::RwLock<QuantizationCalibration>,
symmetric: bool,
}
impl Default for Int8Quantizer {
fn default() -> Self {
Self::new()
}
}
impl Int8Quantizer {
pub fn new() -> Self {
Self {
context: QuantizationContext::default(),
calibration: std::sync::RwLock::new(QuantizationCalibration::default()),
            symmetric: true,
        }
}
    fn compute_scale_zero_point(&self, min_val: f32, max_val: f32) -> (f32, i32) {
        if self.symmetric {
            let abs_max = min_val.abs().max(max_val.abs());
            // An all-zero tensor would yield scale == 0; fall back to 1.0.
            let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
            (scale, 0)
        } else {
            let qmin = -128.0;
            let qmax = 127.0;
            let range = max_val - min_val;
            // A constant tensor would yield scale == 0 and divide by zero below.
            if range <= f32::EPSILON {
                return (1.0, 0);
            }
            let scale = range / (qmax - qmin);
            let zero_point = ((qmin - min_val / scale).round() as i32).clamp(-128, 127);
            (scale, zero_point)
        }
    }
}
impl MobileQuantizer for Int8Quantizer {
fn get_scheme(&self) -> QuantizationScheme {
QuantizationScheme::Int8
}
fn requires_calibration(&self) -> bool {
true
}
    fn calibrate(&self, data: &[Tensor]) -> Result<()> {
        let mut calibration = self.calibration.write().expect("RwLock poisoned");
        // Aggregate per-tensor ranges into one global range so that earlier
        // samples are not overwritten by the last one.
        let mut global_min = f32::INFINITY;
        let mut global_max = f32::NEG_INFINITY;
        for tensor in data {
            let tensor_data = tensor.data()?;
            let (min_val, max_val) = match self.context.method {
                CalibrationMethod::Percentile => {
                    let mut sorted = tensor_data.to_vec();
                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                    let max_idx =
                        (sorted.len() as f32 * self.context.percentile / 100.0) as usize;
                    let min_idx =
                        (sorted.len() as f32 * (100.0 - self.context.percentile) / 100.0) as usize;
                    (
                        sorted[min_idx],
                        sorted[max_idx.min(sorted.len() - 1)],
                    )
                },
                // MinMax; MovingAverage and KLDivergence are not implemented
                // yet and fall back to plain min/max as well.
                _ => {
                    let min = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
                    let max = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
                    (min, max)
                },
            };
            global_min = global_min.min(min_val);
            global_max = global_max.max(max_val);
        }
        let (scale, zero_point) = self.compute_scale_zero_point(global_min, global_max);
        calibration.min_values.insert("global".to_string(), global_min);
        calibration.max_values.insert("global".to_string(), global_max);
        calibration.scales.insert("global".to_string(), scale);
        calibration.zero_points.insert("global".to_string(), zero_point);
        Ok(())
    }
fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
let calibration = self.calibration.read().expect("RwLock poisoned");
let tensor_data = tensor.data()?;
let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
(
scale,
*calibration.zero_points.get("global").expect("No global zero point"),
)
} else {
let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
self.compute_scale_zero_point(min_val, max_val)
};
let quantized_data: Vec<i8> = tensor_data
.iter()
.map(|&x| {
let q = (x / scale).round() as i32 + zero_point;
q.clamp(-128, 127) as i8
})
.collect();
let quantized_f32: Vec<f32> = quantized_data.iter().map(|&x| x as f32).collect();
let quantized_tensor = Tensor::from_vec(quantized_f32, &tensor.shape())?;
Ok(quantized_tensor)
}
fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
let calibration = self.calibration.read().expect("RwLock poisoned");
let tensor_data = tensor.data()?;
let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
(
scale,
*calibration.zero_points.get("global").expect("No global zero point"),
)
        } else {
            // No calibration available: derive a heuristic scale from the
            // observed quantized range (at most 256 INT8 levels).
            let min_q = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b)) as i32;
            let max_q = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)) as i32;
            let range = (max_q - min_q) as f32;
            let scale = if range > 0.0 { 255.0 / range } else { 1.0 };
            (scale, 0)
        };
let dequantized_data: Vec<f32> =
tensor_data.iter().map(|&x| ((x as i32) - zero_point) as f32 * scale).collect();
Tensor::from_vec(dequantized_data, &tensor.shape())
}
}
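/// Lossy round-trip through IEEE half precision: values are converted to
/// `f16` and back to `f32`, so `dequantize_tensor` is the identity.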
pub struct FP16Quantizer;
impl Default for FP16Quantizer {
fn default() -> Self {
Self::new()
}
}
impl FP16Quantizer {
pub fn new() -> Self {
Self
}
}
impl MobileQuantizer for FP16Quantizer {
fn get_scheme(&self) -> QuantizationScheme {
QuantizationScheme::FP16
}
    fn requires_calibration(&self) -> bool {
        false
    }
    fn calibrate(&self, _data: &[Tensor]) -> Result<()> {
        Ok(())
    }
fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
let tensor_data = tensor.data()?;
let fp16_data: Vec<f16> = tensor_data.iter().map(|&x| f16::from_f32(x)).collect();
let quantized_data: Vec<f32> = fp16_data.iter().map(|&x| f32::from(x)).collect();
let quantized_tensor = Tensor::from_vec(quantized_data, &tensor.shape())?;
Ok(quantized_tensor)
}
fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
Ok(tensor.clone())
}
}
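/// Per-tensor scheme selection at runtime: explicit assignments in the
/// `QuantizationSchemeStorage` take precedence; otherwise a range/variance
/// heuristic picks between INT8 and FP16.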
pub struct DynamicQuantizer {
int8_quantizer: Int8Quantizer,
fp16_quantizer: FP16Quantizer,
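    // Reserved for future selection heuristics; not currently consulted.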
selection_threshold: f32,
scheme_storage: QuantizationSchemeStorage,
layer_context: Option<String>,
model_context: Option<String>,
}
impl Default for DynamicQuantizer {
fn default() -> Self {
Self::new()
}
}
impl DynamicQuantizer {
pub fn new() -> Self {
Self {
int8_quantizer: Int8Quantizer::new(),
fp16_quantizer: FP16Quantizer::new(),
            selection_threshold: 0.1,
            scheme_storage: QuantizationSchemeStorage::new(),
layer_context: None,
model_context: None,
}
}
pub fn with_config_file<P: AsRef<Path>>(path: P) -> Result<Self> {
let scheme_storage = QuantizationSchemeStorage::with_config_file(path)?;
Ok(Self {
int8_quantizer: Int8Quantizer::new(),
fp16_quantizer: FP16Quantizer::new(),
selection_threshold: 0.1,
scheme_storage,
layer_context: None,
model_context: None,
})
}
pub fn set_layer_context(&mut self, layer_name: String) {
self.layer_context = Some(layer_name);
}
pub fn set_model_context(&mut self, model_name: String) {
self.model_context = Some(model_name);
}
pub fn scheme_storage_mut(&mut self) -> &mut QuantizationSchemeStorage {
&mut self.scheme_storage
}
pub fn scheme_storage(&self) -> &QuantizationSchemeStorage {
&self.scheme_storage
}
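    /// Heuristic: a narrow range (< 1.0) with low variance (< 0.01) suggests
    /// the tensor tolerates INT8; anything wider falls back to FP16.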
fn select_quantization_scheme(&self, tensor: &Tensor) -> Result<QuantizationScheme> {
let tensor_data = tensor.data()?;
let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let range = max_val - min_val;
let mean = tensor_data.iter().sum::<f32>() / tensor_data.len() as f32;
let variance =
tensor_data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / tensor_data.len() as f32;
if range < 1.0 && variance < 0.01 {
Ok(QuantizationScheme::Int8)
} else {
Ok(QuantizationScheme::FP16)
}
}
}
impl MobileQuantizer for DynamicQuantizer {
fn get_scheme(&self) -> QuantizationScheme {
QuantizationScheme::Dynamic
}
    fn requires_calibration(&self) -> bool {
        true
    }
fn calibrate(&self, data: &[Tensor]) -> Result<()> {
self.int8_quantizer.calibrate(data)
}
fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
let tensor_id =
QuantizationSchemeStorage::generate_tensor_id(tensor, self.layer_context.as_deref());
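        // The trait method takes `&self`, so scheme resolution runs on a
        // clone of the storage; cache entries recorded here are discarded.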
let mut storage = self.scheme_storage.clone();
let scheme = storage.determine_scheme(
&tensor_id,
self.layer_context.as_deref(),
self.model_context.as_deref(),
);
let final_scheme = if scheme == QuantizationScheme::Dynamic {
self.select_quantization_scheme(tensor)?
} else {
scheme
};
match final_scheme {
QuantizationScheme::Int4 => {
let int4_quantizer = Int4Quantizer::new();
int4_quantizer.quantize_tensor(tensor)
},
QuantizationScheme::Int8 => self.int8_quantizer.quantize_tensor(tensor),
QuantizationScheme::FP16 => self.fp16_quantizer.quantize_tensor(tensor),
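            // GGUF K-quant kernels are not implemented here yet; route
            // these schemes through the INT8 path.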
QuantizationScheme::GGUF_Q2_K
| QuantizationScheme::GGUF_Q3_K
| QuantizationScheme::GGUF_Q4_K
| QuantizationScheme::GGUF_Q5_0
| QuantizationScheme::GGUF_Q6_K => {
self.int8_quantizer.quantize_tensor(tensor)
},
QuantizationScheme::Dynamic => {
let selected_scheme = self.select_quantization_scheme(tensor)?;
match selected_scheme {
QuantizationScheme::Int4 => {
let int4_quantizer = Int4Quantizer::new();
int4_quantizer.quantize_tensor(tensor)
},
QuantizationScheme::Int8 => self.int8_quantizer.quantize_tensor(tensor),
QuantizationScheme::FP16 => self.fp16_quantizer.quantize_tensor(tensor),
QuantizationScheme::GGUF_Q2_K
| QuantizationScheme::GGUF_Q3_K
| QuantizationScheme::GGUF_Q4_K
| QuantizationScheme::GGUF_Q5_0
| QuantizationScheme::GGUF_Q6_K => self.int8_quantizer.quantize_tensor(tensor),
QuantizationScheme::Dynamic => {
self.int8_quantizer.quantize_tensor(tensor)
},
}
},
}
}
fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
let tensor_id =
QuantizationSchemeStorage::generate_tensor_id(tensor, self.layer_context.as_deref());
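        // As in `quantize_tensor`, resolution runs on a clone, so cache
        // updates do not persist.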
let mut storage = self.scheme_storage.clone();
let scheme = storage.determine_scheme(
&tensor_id,
self.layer_context.as_deref(),
self.model_context.as_deref(),
);
match scheme {
QuantizationScheme::Int8 => self.int8_quantizer.dequantize_tensor(tensor),
QuantizationScheme::FP16 => self.fp16_quantizer.dequantize_tensor(tensor),
QuantizationScheme::Int4 => {
let int4_quantizer = Int4Quantizer::new();
int4_quantizer.dequantize_tensor(tensor)
},
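            // GGUF schemes were quantized via the INT8 path, so dequantize
            // the same way.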
QuantizationScheme::GGUF_Q2_K
| QuantizationScheme::GGUF_Q3_K
| QuantizationScheme::GGUF_Q4_K
| QuantizationScheme::GGUF_Q5_0
| QuantizationScheme::GGUF_Q6_K => {
self.int8_quantizer.dequantize_tensor(tensor)
},
QuantizationScheme::Dynamic => {
let selected_scheme = self.select_quantization_scheme(tensor)?;
match selected_scheme {
QuantizationScheme::Int8 => self.int8_quantizer.dequantize_tensor(tensor),
QuantizationScheme::FP16 => self.fp16_quantizer.dequantize_tensor(tensor),
                    _ => self.int8_quantizer.dequantize_tensor(tensor),
                }
},
}
}
}
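/// Helper metrics for judging quantization quality and memory impact.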
pub struct QuantizationUtils;
impl QuantizationUtils {
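    /// Root-mean-square error between two tensors of equal size, typically an
    /// original tensor and its quantize/dequantize round trip.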
pub fn compute_error(original: &Tensor, quantized: &Tensor) -> Result<f32> {
let orig_data = original.data()?;
let quant_data = quantized.data()?;
if orig_data.len() != quant_data.len() {
return Err(tensor_op_error(
"compute_error",
"Tensors must have same size for error computation",
));
}
let mse = orig_data
.iter()
.zip(quant_data.iter())
.map(|(&o, &q)| (o - q).powi(2))
.sum::<f32>()
/ orig_data.len() as f32;
Ok(mse.sqrt())
}
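    /// Approximate compression ratio relative to f32 (32 bits per weight);
    /// the GGUF entries divide 32 by each format's effective bits per weight.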
pub fn compression_ratio(scheme: QuantizationScheme) -> f32 {
match scheme {
            QuantizationScheme::Int4 => 8.0,
            QuantizationScheme::Int8 => 4.0,
            QuantizationScheme::FP16 => 2.0,
            QuantizationScheme::Dynamic => 3.0,
            QuantizationScheme::GGUF_Q2_K => 32.0 / 2.5625,
            QuantizationScheme::GGUF_Q3_K => 32.0 / 3.4375,
            QuantizationScheme::GGUF_Q4_K => 32.0 / 4.5,
            QuantizationScheme::GGUF_Q5_0 => 32.0 / 5.5,
            QuantizationScheme::GGUF_Q6_K => 32.0 / 6.5,
        }
}
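    /// Memory saved relative to f32: `(1 - 1/ratio) * 100`, e.g. 75% for INT8.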
pub fn memory_savings_percent(scheme: QuantizationScheme) -> f32 {
let ratio = Self::compression_ratio(scheme);
(1.0 - 1.0 / ratio) * 100.0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_int4_quantization() {
let quantizer = Int4Quantizer::new();
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
.expect("Failed to create tensor");
quantizer.calibrate(std::slice::from_ref(&tensor)).expect("Calibration failed");
let quantized = quantizer.quantize_tensor(&tensor).expect("Quantization failed");
assert_eq!(quantized.shape(), tensor.shape());
let dequantized = quantizer.dequantize_tensor(&quantized).expect("Dequantization failed");
assert_eq!(dequantized.shape(), tensor.shape());
let error = QuantizationUtils::compute_error(&tensor, &dequantized)
.expect("Error computation failed");
        assert!(error < 1.0);
    }
#[test]
fn test_int8_quantization() {
let quantizer = Int8Quantizer::new();
let tensor = Tensor::from_vec(vec![-10.0, -5.0, 0.0, 5.0, 10.0], &[5])
.expect("Failed to create tensor");
quantizer.calibrate(std::slice::from_ref(&tensor)).expect("Calibration failed");
let quantized = quantizer.quantize_tensor(&tensor).expect("Quantization failed");
let dequantized = quantizer.dequantize_tensor(&quantized).expect("Dequantization failed");
let error = QuantizationUtils::compute_error(&tensor, &dequantized)
.expect("Error computation failed");
        assert!(error < 0.1);
    }
#[test]
fn test_fp16_quantization() {
let quantizer = FP16Quantizer::new();
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("Operation failed");
assert!(!quantizer.requires_calibration());
let quantized = quantizer.quantize_tensor(&tensor).expect("Quantization failed");
let dequantized = quantizer.dequantize_tensor(&quantized).expect("Dequantization failed");
let error = QuantizationUtils::compute_error(&tensor, &dequantized)
.expect("Error computation failed");
assert!(error < 0.001);
}
#[test]
fn test_dynamic_quantization() {
let mut quantizer = DynamicQuantizer::new();
let small_range =
Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4], &[4]).expect("Operation failed");
quantizer
.calibrate(std::slice::from_ref(&small_range))
.expect("Operation failed");
        let _quantized = quantizer.quantize_tensor(&small_range).expect("Quantization failed");
let tensor_id = QuantizationSchemeStorage::generate_tensor_id(&small_range, None);
quantizer
.scheme_storage_mut()
.set_tensor_scheme(tensor_id.clone(), QuantizationScheme::FP16);
        let _quantized_fp16 = quantizer.quantize_tensor(&small_range).expect("Quantization failed");
let stored_scheme = quantizer.scheme_storage_mut().determine_scheme(&tensor_id, None, None);
assert_eq!(stored_scheme, QuantizationScheme::FP16);
let unknown_tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("Operation failed");
let unknown_id = QuantizationSchemeStorage::generate_tensor_id(&unknown_tensor, None);
let default_scheme =
quantizer.scheme_storage_mut().determine_scheme(&unknown_id, None, None);
        assert_eq!(default_scheme, QuantizationScheme::Int8);
    }
#[test]
fn test_compression_ratios() {
assert_eq!(
QuantizationUtils::compression_ratio(QuantizationScheme::Int4),
8.0
);
assert_eq!(
QuantizationUtils::compression_ratio(QuantizationScheme::Int8),
4.0
);
assert_eq!(
QuantizationUtils::compression_ratio(QuantizationScheme::FP16),
2.0
);
assert_eq!(
QuantizationUtils::memory_savings_percent(QuantizationScheme::Int4),
87.5
);
assert_eq!(
QuantizationUtils::memory_savings_percent(QuantizationScheme::Int8),
75.0
);
assert_eq!(
QuantizationUtils::memory_savings_percent(QuantizationScheme::FP16),
50.0
);
}
}