1use axonml_tensor::Tensor;
9use half::f16;
10use rayon::prelude::*;
11
12use crate::error::QuantResult;
13use crate::types::{Q4Block, Q4_1Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
14use crate::DEFAULT_BLOCK_SIZE;
15
16pub fn quantize_tensor(
37 tensor: &Tensor<f32>,
38 quant_type: QuantType,
39) -> QuantResult<QuantizedTensor> {
40 let data = tensor.to_vec();
41 let shape = tensor.shape().to_vec();
42
43 match quant_type {
44 QuantType::Q8_0 => quantize_q8_0(&data, shape),
45 QuantType::Q4_0 => quantize_q4_0(&data, shape),
46 QuantType::Q4_1 => quantize_q4_1(&data, shape),
47 QuantType::Q5_0 | QuantType::Q5_1 => {
48 quantize_q4_0(&data, shape)
50 }
51 QuantType::F16 => quantize_f16(&data, shape),
52 QuantType::F32 => quantize_f32(&data, shape),
53 }
54}
55
56pub fn quantize_model(
65 tensors: &[(&str, &Tensor<f32>)],
66 quant_type: QuantType,
67) -> QuantResult<Vec<(String, QuantizedTensor)>> {
68 tensors
69 .par_iter()
70 .map(|(name, tensor)| {
71 let quantized = quantize_tensor(tensor, quant_type)?;
72 Ok((name.to_string(), quantized))
73 })
74 .collect()
75}
76
77fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
83 let block_size = DEFAULT_BLOCK_SIZE;
84 let n_blocks = (data.len() + block_size - 1) / block_size;
85
86 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
87 .into_par_iter()
88 .map(|block_idx| {
89 let start = block_idx * block_size;
90 let end = (start + block_size).min(data.len());
91 let block_data = &data[start..end];
92
93 let max_abs = block_data
95 .iter()
96 .map(|x| x.abs())
97 .fold(0.0f32, |a, b| a.max(b));
98
99 let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
101
102 let mut quantized = [0i8; 32];
104 for (i, &val) in block_data.iter().enumerate() {
105 let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
106 quantized[i] = q;
107 }
108
109 QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
110 })
111 .collect();
112
113 Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
114}
115
116fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
122 let block_size = DEFAULT_BLOCK_SIZE;
123 let n_blocks = (data.len() + block_size - 1) / block_size;
124
125 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
126 .into_par_iter()
127 .map(|block_idx| {
128 let start = block_idx * block_size;
129 let end = (start + block_size).min(data.len());
130 let block_data = &data[start..end];
131
132 let max_abs = block_data
134 .iter()
135 .map(|x| x.abs())
136 .fold(0.0f32, |a, b| a.max(b));
137
138 let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1.0 };
140
141 let mut quantized = [0i8; 32];
143 for (i, &val) in block_data.iter().enumerate() {
144 let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
145 quantized[i] = q;
146 }
147
148 let packed = Q4Block::pack(&quantized);
150
151 QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
152 })
153 .collect();
154
155 Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
156}
157
158fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
164 let block_size = DEFAULT_BLOCK_SIZE;
165 let n_blocks = (data.len() + block_size - 1) / block_size;
166
167 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
168 .into_par_iter()
169 .map(|block_idx| {
170 let start = block_idx * block_size;
171 let end = (start + block_size).min(data.len());
172 let block_data = &data[start..end];
173
174 let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
176 let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
177
178 let scale = if max > min { (max - min) / 15.0 } else { 1.0 };
180
181 let mut quantized = [0u8; 32];
183 for (i, &val) in block_data.iter().enumerate() {
184 let q = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
185 quantized[i] = q;
186 }
187
188 let mut packed = [0u8; 16];
190 for i in 0..16.min(block_data.len() / 2) {
191 let low = quantized[i * 2] & 0x0F;
192 let high = quantized.get(i * 2 + 1).copied().unwrap_or(0) & 0x0F;
193 packed[i] = low | (high << 4);
194 }
195
196 QuantizedBlock::Q4_1(Q4_1Block::new(
197 f16::from_f32(scale),
198 f16::from_f32(min),
199 packed,
200 ))
201 })
202 .collect();
203
204 Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
205}
206
207fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
213 let f16_data: Vec<f16> = data.par_iter().map(|&x| f16::from_f32(x)).collect();
214
215 let blocks = vec![QuantizedBlock::F16(f16_data)];
216
217 Ok(QuantizedTensor::new(shape, QuantType::F16, blocks))
218}
219
220fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
226 let blocks = vec![QuantizedBlock::F32(data.to_vec())];
227 Ok(QuantizedTensor::new(shape, QuantType::F32, blocks))
228}
229
/// Root-mean-square error between an original tensor and its
/// quantize → dequantize round trip.
///
/// Returns `f32::INFINITY` for mismatched lengths or empty input, so an
/// invalid comparison reads as "maximally bad" rather than panicking.
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.is_empty() || original.len() != dequantized.len() {
        return f32::INFINITY;
    }

    let sum_sq: f32 = original
        .iter()
        .zip(dequantized)
        .map(|(&a, &b)| {
            let d = a - b;
            d * d
        })
        .sum();

    (sum_sq / original.len() as f32).sqrt()
}
249
/// Summary statistics describing quantization error and storage savings
/// for one tensor.
pub struct QuantizationStats {
    /// Root-mean-square error between original and dequantized values.
    pub rmse: f32,
    /// Largest absolute per-element error.
    pub max_error: f32,
    /// Mean absolute per-element error.
    pub mean_error: f32,
    /// Storage compression ratio reported by the quantization format.
    pub compression_ratio: f32,
}
261
262pub fn compute_quantization_stats(
264 original: &[f32],
265 dequantized: &[f32],
266 quant_type: QuantType,
267) -> QuantizationStats {
268 let errors: Vec<f32> = original
269 .iter()
270 .zip(dequantized.iter())
271 .map(|(a, b)| (a - b).abs())
272 .collect();
273
274 let mse: f32 = errors.iter().map(|e| e.powi(2)).sum::<f32>() / errors.len() as f32;
275 let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
276 let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
277
278 QuantizationStats {
279 rmse: mse.sqrt(),
280 max_error,
281 mean_error,
282 compression_ratio: quant_type.compression_ratio(),
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_q8_0() {
        // Eight values -> exactly one 32-wide block.
        let values: Vec<f32> = (1..=8).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(values, &[8]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();

        assert_eq!(q.quant_type, QuantType::Q8_0);
        assert_eq!(q.shape, vec![8]);
        assert_eq!(q.num_blocks(), 1);
    }

    #[test]
    fn test_quantize_q4_0() {
        // 64 values -> two 32-wide blocks.
        let values: Vec<f32> = (0..64).map(|x| x as f32 * 0.1).collect();
        let tensor = Tensor::from_vec(values, &[64]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert_eq!(q.quant_type, QuantType::Q4_0);
        assert_eq!(q.num_blocks(), 2);
    }

    #[test]
    fn test_quantize_f16() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::F16).unwrap();

        assert_eq!(q.quant_type, QuantType::F16);
    }

    #[test]
    fn test_compression_ratio() {
        let values: Vec<f32> = (0..256).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(values, &[256]).unwrap();

        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        // Q8 should at least halve storage; Q4 should beat Q8.
        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    #[test]
    fn test_quantization_error() {
        let original = [1.0, 2.0, 3.0, 4.0];
        let dequantized = [1.1, 2.0, 2.9, 4.1];

        let rmse = compute_quantization_error(&original, &dequantized);
        assert!(rmse > 0.0 && rmse < 0.2);
    }
}
346}