use anyhow::Result;
use candle_core::{DType, Tensor};
use std::collections::HashMap;

/// Configuration for post-training weight quantization.
#[derive(Debug, Clone)]
pub struct QuantizeConfig {
    /// Layers whose names contain any of these substrings are left unquantized.
    pub skip_layers: Vec<String>,
    /// Tensors with fewer elements than this are left unquantized.
    pub min_size: usize,
    /// Number of quantization levels (256 = 8-bit, 65536 = 16-bit).
    pub num_levels: usize,
}

impl Default for QuantizeConfig {
    fn default() -> Self {
        Self {
            skip_layers: vec![
                "embed".to_string(),
                "lut".to_string(),
                "out_proj".to_string(),
                "eos_head".to_string(),
            ],
            min_size: 1024,
            num_levels: 256,
        }
    }
}

/// A tensor processed with symmetric fake quantization: the values are stored
/// dequantized (snapped to the quantization grid) alongside the scale used.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Dequantized data, i.e. `round(x / scale) * scale`.
    pub data: Tensor,
    /// Scale factor mapping integer levels back to real values.
    pub scale: f32,
    /// Zero point (always 0.0 for symmetric quantization).
    pub zero_point: f32,
    /// Number of quantization levels used (0 marks an unquantized tensor).
    pub num_levels: usize,
}
59
60impl QuantizedTensor {
61 pub fn quantize(tensor: &Tensor, num_levels: usize) -> Result<Self> {
66 let tensor_f32 = tensor.to_dtype(DType::F32)?;
68
69 let abs_max = tensor_f32.abs()?.max_all()?.to_scalar::<f32>()?;
71
72 let half_levels = (num_levels / 2) as f32;
74 let scale = if abs_max > 0.0 {
75 abs_max / (half_levels - 1.0)
76 } else {
77 1.0
78 };
79
80 let scale_tensor = Tensor::new(&[scale], tensor.device())?;
83 let quantized = tensor_f32.broadcast_div(&scale_tensor)?;
84 let quantized = quantized.round()?;
85 let clamped = quantized.clamp(-(half_levels - 1.0) as f64, (half_levels - 1.0) as f64)?;
86 let data = clamped.broadcast_mul(&scale_tensor)?;
87
88 Ok(Self {
89 data,
90 scale,
91 zero_point: 0.0, num_levels,
93 })
94 }
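
    // Worked example of the scheme above with 256 levels (signed 8-bit):
    // abs_max = 4.5 gives scale = 4.5 / 127 ≈ 0.0354, so the value 2.0 maps to
    // round(2.0 / 0.0354) = 56 and dequantizes to 56 * 0.0354 ≈ 1.98.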

    /// Borrow the (dequantized) tensor data.
    pub fn data(&self) -> &Tensor {
        &self.data
    }

    /// The scale factor used during quantization.
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Memory reduction relative to f32 storage, assuming the rounded values
    /// were packed at their native bit width (8 bits for 256 levels, 16 bits
    /// for 65536 levels).
    pub fn theoretical_memory_savings(&self) -> f32 {
        match self.num_levels {
            256 => 4.0,
            65536 => 2.0,
            _ => 1.0,
        }
    }
}
116
117fn should_skip_layer(name: &str, config: &QuantizeConfig) -> bool {
119 config.skip_layers.iter().any(|skip| name.contains(skip))
120}
121
122pub fn quantize_weights(
127 weights: &HashMap<String, Tensor>,
128 config: &QuantizeConfig,
129) -> Result<HashMap<String, QuantizedTensor>> {
130 let mut quantized = HashMap::new();
131
132 for (name, tensor) in weights {
133 if tensor.elem_count() < config.min_size || should_skip_layer(name, config) {
135 quantized.insert(
137 name.clone(),
138 QuantizedTensor {
139 data: tensor.clone(),
140 scale: 1.0,
141 zero_point: 0.0,
142 num_levels: 0, },
144 );
145 } else {
146 quantized.insert(
147 name.clone(),
148 QuantizedTensor::quantize(tensor, config.num_levels)?,
149 );
150 }
151 }
152
153 Ok(quantized)
154}
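
// Usage sketch. The `load_weights` helper is hypothetical and stands in for
// whatever loads a HashMap<String, Tensor> in the host application:
//
//     let weights = load_weights("model.safetensors")?;
//     let quantized = quantize_weights(&weights, &QuantizeConfig::default())?;
//     for (name, q) in &quantized {
//         println!("{name}: {:.1}x theoretical savings", q.theoretical_memory_savings());
//     }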

/// Signal-to-noise ratio of a quantized tensor against the original, in dB.
pub fn calculate_snr(original: &Tensor, quantized: &Tensor) -> Result<f32> {
    let original_f32 = original.to_dtype(DType::F32)?;
    let quantized_f32 = quantized.to_dtype(DType::F32)?;

    let signal_power = original_f32.sqr()?.mean_all()?.to_scalar::<f32>()?;
    let noise = (&original_f32 - &quantized_f32)?;
    let noise_power = noise.sqr()?.mean_all()?.to_scalar::<f32>()?;

    if noise_power <= 0.0 {
        return Ok(f32::INFINITY);
    }

    Ok(10.0 * (signal_power / noise_power).log10())
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_quantize_tensor() {
        let device = Device::Cpu;
        let tensor = Tensor::new(&[1.0f32, 2.0, -3.0, 4.5, -2.1], &device).unwrap();

        let quantized = QuantizedTensor::quantize(&tensor, 256).unwrap();

        let snr = calculate_snr(&tensor, &quantized.data).unwrap();
        assert!(snr > 30.0, "SNR {} is too low", snr);
    }

    #[test]
    fn test_quantize_large_tensor() {
        let device = Device::Cpu;
        let values: Vec<f32> = (0..10000).map(|i| (i as f32 * 0.01).sin() * 10.0).collect();
        let tensor = Tensor::new(&values[..], &device).unwrap();

        let quantized = QuantizedTensor::quantize(&tensor, 256).unwrap();
        let snr = calculate_snr(&tensor, &quantized.data).unwrap();

        assert!(snr > 30.0, "SNR {} is too low", snr);
    }

    #[test]
    fn test_quantize_config_skip_layers() {
        let config = QuantizeConfig::default();
        assert!(should_skip_layer("model.embed_tokens", &config));
        assert!(should_skip_layer("decoder.out_proj", &config));
        assert!(!should_skip_layer("encoder.layers.0.linear", &config));
    }

    #[test]
    fn test_theoretical_savings() {
        let device = Device::Cpu;
        let tensor = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
        let quantized = QuantizedTensor::quantize(&tensor, 256).unwrap();
        assert_eq!(quantized.theoretical_memory_savings(), 4.0);
    }
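
    // End-to-end check of quantize_weights config handling. The tensor names
    // here are illustrative examples, not taken from a real model.
    #[test]
    fn test_quantize_weights_respects_config() {
        let device = Device::Cpu;
        let mut weights = HashMap::new();

        // Below the default min_size of 1024 elements: passed through.
        weights.insert(
            "tiny.bias".to_string(),
            Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap(),
        );

        // Large enough, and not on the skip list: quantized.
        let values: Vec<f32> = (0..2048).map(|i| (i as f32 * 0.01).sin()).collect();
        weights.insert(
            "encoder.layers.0.linear".to_string(),
            Tensor::new(&values[..], &device).unwrap(),
        );

        let quantized = quantize_weights(&weights, &QuantizeConfig::default()).unwrap();
        assert_eq!(quantized["tiny.bias"].num_levels, 0);
        assert_eq!(quantized["encoder.layers.0.linear"].num_levels, 256);
    }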
}