use anyhow::Result;
use candle_core::{DType, Tensor};
use std::collections::HashMap;
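
/// Controls which tensors get quantized: layers whose names contain any of
/// `skip_layers`, and tensors with fewer than `min_size` elements, are left
/// in their original precision. `num_levels` is the number of uniform
/// quantization levels (256 for 8-bit, 65536 for 16-bit).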
#[derive(Debug, Clone)]
pub struct QuantizeConfig {
    pub skip_layers: Vec<String>,
    pub min_size: usize,
    pub num_levels: usize,
}

impl Default for QuantizeConfig {
    fn default() -> Self {
        Self {
            skip_layers: vec![
                "embed".to_string(),
                "lut".to_string(),
                "out_proj".to_string(),
                "eos_head".to_string(),
            ],
            min_size: 1024,
            num_levels: 256,
        }
    }
}
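
/// Result of simulated ("fake") quantization: `data` holds the tensor after
/// rounding to `num_levels` uniform levels and scaling back to f32, so it
/// carries the quantization error while remaining a float tensor. `scale` and
/// `zero_point` record the affine parameters; `zero_point` is always 0.0 in
/// this symmetric scheme.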
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    pub data: Tensor,
    pub scale: f32,
    pub zero_point: f32,
    pub num_levels: usize,
}

impl QuantizedTensor {
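    /// Symmetric absmax quantization: scale = max|x| / (num_levels/2 - 1),
    /// then q = clamp(round(x / scale)) * scale. The result is immediately
    /// dequantized back to f32.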
    pub fn quantize(tensor: &Tensor, num_levels: usize) -> Result<Self> {
        let tensor_f32 = tensor.to_dtype(DType::F32)?;

        // Symmetric scale from the absolute maximum; guard against all-zero tensors.
        let abs_max = tensor_f32.abs()?.max_all()?.to_scalar::<f32>()?;
        let half_levels = (num_levels / 2) as f32;
        let scale = if abs_max > 0.0 {
            abs_max / (half_levels - 1.0)
        } else {
            1.0
        };

        // Divide by the scale, round to the nearest level, clamp to the
        // representable range, then dequantize back to f32.
        let scale_tensor = Tensor::new(&[scale], tensor.device())?;
        let quantized = tensor_f32.broadcast_div(&scale_tensor)?;
        let quantized = quantized.round()?;
        let clamped = quantized.clamp(-(half_levels - 1.0) as f64, (half_levels - 1.0) as f64)?;
        let data = clamped.broadcast_mul(&scale_tensor)?;

        Ok(Self {
            data,
            scale,
            zero_point: 0.0,
            num_levels,
        })
    }

    pub fn data(&self) -> &Tensor {
        &self.data
    }

    pub fn scale(&self) -> f32 {
        self.scale
    }
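
    /// Memory reduction relative to f32 storage if the quantized values were
    /// actually stored at reduced precision: 4x for 256 levels (8-bit),
    /// 2x for 65536 levels (16-bit), 1x otherwise (e.g. skipped tensors).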
    pub fn theoretical_memory_savings(&self) -> f32 {
        match self.num_levels {
            256 => 4.0,
            65536 => 2.0,
            _ => 1.0,
        }
    }
}
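
/// Returns true if the layer name contains any of the configured skip substrings.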
fn should_skip_layer(name: &str, config: &QuantizeConfig) -> bool {
    config.skip_layers.iter().any(|skip| name.contains(skip))
}
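
/// Quantizes every tensor in `weights` according to `config`. Tensors smaller
/// than `min_size` or whose names match a skip pattern are passed through
/// unchanged (scale 1.0, num_levels 0).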
pub fn quantize_weights(
    weights: &HashMap<String, Tensor>,
    config: &QuantizeConfig,
) -> Result<HashMap<String, QuantizedTensor>> {
    let mut quantized = HashMap::new();

    for (name, tensor) in weights {
        if tensor.elem_count() < config.min_size || should_skip_layer(name, config) {
            // Pass the tensor through unchanged; num_levels == 0 marks it as unquantized.
            quantized.insert(
                name.clone(),
                QuantizedTensor {
                    data: tensor.clone(),
                    scale: 1.0,
                    zero_point: 0.0,
                    num_levels: 0,
                },
            );
        } else {
            quantized.insert(
                name.clone(),
                QuantizedTensor::quantize(tensor, config.num_levels)?,
            );
        }
    }

    Ok(quantized)
}
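
/// Signal-to-noise ratio in dB: 10 * log10(mean(x^2) / mean((x - q)^2)).
/// Returns infinity when the quantization error is exactly zero.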
pub fn calculate_snr(original: &Tensor, quantized: &Tensor) -> Result<f32> {
    let original_f32 = original.to_dtype(DType::F32)?;
    let quantized_f32 = quantized.to_dtype(DType::F32)?;

    let signal_power = original_f32.sqr()?.mean_all()?.to_scalar::<f32>()?;
    let noise = (&original_f32 - &quantized_f32)?;
    let noise_power = noise.sqr()?.mean_all()?.to_scalar::<f32>()?;

    if noise_power <= 0.0 {
        return Ok(f32::INFINITY);
    }

    Ok(10.0 * (signal_power / noise_power).log10())
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_quantize_tensor() {
        let device = Device::Cpu;
        let tensor = Tensor::new(&[1.0f32, 2.0, -3.0, 4.5, -2.1], &device).unwrap();
        let quantized = QuantizedTensor::quantize(&tensor, 256).unwrap();
        let snr = calculate_snr(&tensor, &quantized.data).unwrap();
        assert!(snr > 30.0, "SNR {} is too low", snr);
    }

    #[test]
    fn test_quantize_large_tensor() {
        let device = Device::Cpu;
        let values: Vec<f32> = (0..10000).map(|i| (i as f32 * 0.01).sin() * 10.0).collect();
        let tensor = Tensor::new(&values[..], &device).unwrap();
        let quantized = QuantizedTensor::quantize(&tensor, 256).unwrap();
        let snr = calculate_snr(&tensor, &quantized.data).unwrap();
        assert!(snr > 30.0, "SNR {} is too low", snr);
    }

    #[test]
    fn test_quantize_config_skip_layers() {
        let config = QuantizeConfig::default();
        assert!(should_skip_layer("model.embed_tokens", &config));
        assert!(should_skip_layer("decoder.out_proj", &config));
        assert!(!should_skip_layer("encoder.layers.0.linear", &config));
    }

    #[test]
    fn test_theoretical_savings() {
        let device = Device::Cpu;
        let tensor = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let quantized = QuantizedTensor::quantize(&tensor, 256).unwrap();
        assert_eq!(quantized.theoretical_memory_savings(), 4.0);
    }
}