use axonml_tensor::Tensor;
use half::f16;
use rayon::prelude::*;

use crate::error::QuantResult;
use crate::types::{QuantType, QuantizedTensor, QuantizedBlock, Q8Block, Q4Block, Q4_1Block};
use crate::DEFAULT_BLOCK_SIZE;

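/// Quantize an `f32` tensor into the requested format.
///
/// Each format is dispatched to its own block-wise kernel below; `Q5_0`/`Q5_1`
/// currently fall back to the `Q4_0` path.
///
/// # Examples
///
/// A minimal sketch (marked `ignore` because it assumes the `Tensor::from_vec`
/// constructor used by the tests in this module):
///
/// ```ignore
/// use axonml_tensor::Tensor;
///
/// let tensor = Tensor::from_vec(vec![0.5_f32; 64], &[64]).unwrap();
/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
/// assert_eq!(quantized.quant_type, QuantType::Q8_0);
/// ```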
pub fn quantize_tensor(tensor: &Tensor<f32>, quant_type: QuantType) -> QuantResult<QuantizedTensor> {
    let data = tensor.to_vec();
    let shape = tensor.shape().to_vec();

    match quant_type {
        QuantType::Q8_0 => quantize_q8_0(&data, shape),
        QuantType::Q4_0 => quantize_q4_0(&data, shape),
        QuantType::Q4_1 => quantize_q4_1(&data, shape),
        QuantType::Q5_0 | QuantType::Q5_1 => {
            // No dedicated Q5 kernel yet; fall back to Q4_0.
            quantize_q4_0(&data, shape)
        }
        QuantType::F16 => quantize_f16(&data, shape),
        QuantType::F32 => quantize_f32(&data, shape),
    }
}

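/// Quantize every tensor in a model in parallel using Rayon.
///
/// # Examples
///
/// A minimal sketch (marked `ignore`; tensor construction follows the tests
/// in this module):
///
/// ```ignore
/// use axonml_tensor::Tensor;
///
/// let weight = Tensor::from_vec(vec![0.1_f32; 64], &[64]).unwrap();
/// let tensors = vec![("layer0.weight", &weight)];
/// let quantized = quantize_model(&tensors, QuantType::Q4_0).unwrap();
/// assert_eq!(quantized.len(), 1);
/// ```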
pub fn quantize_model(
    tensors: &[(&str, &Tensor<f32>)],
    quant_type: QuantType,
) -> QuantResult<Vec<(String, QuantizedTensor)>> {
    tensors
        .par_iter()
        .map(|(name, tensor)| {
            let quantized = quantize_tensor(tensor, quant_type)?;
            Ok((name.to_string(), quantized))
        })
        .collect()
}

/// Symmetric 8-bit quantization: one `f16` scale per block of `DEFAULT_BLOCK_SIZE` values.
fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let block_size = DEFAULT_BLOCK_SIZE;
    let n_blocks = (data.len() + block_size - 1) / block_size;

    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
        .into_par_iter()
        .map(|block_idx| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(data.len());
            let block_data = &data[start..end];

            // Scale maps the largest magnitude in the block onto the i8 range.
            let max_abs = block_data
                .iter()
                .map(|x| x.abs())
                .fold(0.0f32, |a, b| a.max(b));

            let scale = if max_abs > 0.0 {
                max_abs / 127.0
            } else {
                1.0
            };

            // The fixed-size block array assumes `DEFAULT_BLOCK_SIZE == 32`.
            let mut quantized = [0i8; 32];
            for (i, &val) in block_data.iter().enumerate() {
                let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
                quantized[i] = q;
            }

            QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
        })
        .collect();

    Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
}

/// Symmetric 4-bit quantization: one `f16` scale per block, two values packed per byte.
fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let block_size = DEFAULT_BLOCK_SIZE;
    let n_blocks = (data.len() + block_size - 1) / block_size;

    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
        .into_par_iter()
        .map(|block_idx| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(data.len());
            let block_data = &data[start..end];

            let max_abs = block_data
                .iter()
                .map(|x| x.abs())
                .fold(0.0f32, |a, b| a.max(b));

            let scale = if max_abs > 0.0 {
                max_abs / 7.0
            } else {
                1.0
            };

            let mut quantized = [0i8; 32];
            for (i, &val) in block_data.iter().enumerate() {
                let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
                quantized[i] = q;
            }

            // Pack two signed 4-bit values per byte.
            let packed = Q4Block::pack(&quantized);

            QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
        })
        .collect();

    Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
}

/// Asymmetric 4-bit quantization: per-block `f16` scale and minimum, values in `[0, 15]`.
fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let block_size = DEFAULT_BLOCK_SIZE;
    let n_blocks = (data.len() + block_size - 1) / block_size;

    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
        .into_par_iter()
        .map(|block_idx| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(data.len());
            let block_data = &data[start..end];

            let min = block_data
                .iter()
                .fold(f32::INFINITY, |a, &b| a.min(b));
            let max = block_data
                .iter()
                .fold(f32::NEG_INFINITY, |a, &b| a.max(b));

            let scale = if max > min {
                (max - min) / 15.0
            } else {
                1.0
            };

            let mut quantized = [0u8; 32];
            for (i, &val) in block_data.iter().enumerate() {
                let q = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
                quantized[i] = q;
            }

            // Pack two 4-bit values per byte; round up so an odd-length tail
            // block still packs its last value (the unused high nibble stays 0).
            let mut packed = [0u8; 16];
            for i in 0..((block_data.len() + 1) / 2).min(16) {
                let low = quantized[i * 2] & 0x0F;
                let high = quantized[i * 2 + 1] & 0x0F;
                packed[i] = low | (high << 4);
            }

            QuantizedBlock::Q4_1(Q4_1Block::new(
                f16::from_f32(scale),
                f16::from_f32(min),
                packed,
            ))
        })
        .collect();

    Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
}

/// Convert to half precision; stored as a single `F16` block.
fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let f16_data: Vec<f16> = data
        .par_iter()
        .map(|&x| f16::from_f32(x))
        .collect();

    let blocks = vec![QuantizedBlock::F16(f16_data)];

    Ok(QuantizedTensor::new(shape, QuantType::F16, blocks))
}

/// Pass-through: keeps the data as a single `F32` block.
fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let blocks = vec![QuantizedBlock::F32(data.to_vec())];
    Ok(QuantizedTensor::new(shape, QuantType::F32, blocks))
}

/// Root-mean-square error between the original values and their dequantized
/// counterparts. Returns `f32::INFINITY` if the slices are empty or their
/// lengths differ.
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.len() != dequantized.len() || original.is_empty() {
        return f32::INFINITY;
    }

    let mse: f32 = original
        .iter()
        .zip(dequantized.iter())
        .map(|(a, b)| (a - b).powi(2))
        .sum::<f32>()
        / original.len() as f32;

    mse.sqrt()
}

/// Summary statistics describing quantization quality for one tensor.
pub struct QuantizationStats {
    /// Root-mean-square error between original and dequantized values.
    pub rmse: f32,
    /// Largest absolute per-element error.
    pub max_error: f32,
    /// Mean absolute per-element error.
    pub mean_error: f32,
    /// Compression ratio of the quantization format relative to `f32`.
    pub compression_ratio: f32,
}
pub fn compute_quantization_stats(
    original: &[f32],
    dequantized: &[f32],
    quant_type: QuantType,
) -> QuantizationStats {
    let errors: Vec<f32> = original
        .iter()
        .zip(dequantized.iter())
        .map(|(a, b)| (a - b).abs())
        .collect();

    let mse: f32 = errors.iter().map(|e| e.powi(2)).sum::<f32>() / errors.len() as f32;
    let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
    let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;

    QuantizationStats {
        rmse: mse.sqrt(),
        max_error,
        mean_error,
        compression_ratio: quant_type.compression_ratio(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_q8_0() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(data, &[8]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();

        assert_eq!(quantized.quant_type, QuantType::Q8_0);
        assert_eq!(quantized.shape, vec![8]);
        assert_eq!(quantized.num_blocks(), 1);
    }

    #[test]
    fn test_quantize_q4_0() {
        let data: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(data, &[64]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert_eq!(quantized.quant_type, QuantType::Q4_0);
        assert_eq!(quantized.num_blocks(), 2);
    }

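    // Sketch of a Q4_1 test mirroring the Q4_0 case above; asserts only on
    // fields already exercised by the other tests.
    #[test]
    fn test_quantize_q4_1() {
        let data: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(data, &[64]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q4_1).unwrap();

        assert_eq!(quantized.quant_type, QuantType::Q4_1);
        assert_eq!(quantized.num_blocks(), 2);
    }
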
    #[test]
    fn test_quantize_f16() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_vec(data, &[4]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::F16).unwrap();

        assert_eq!(quantized.quant_type, QuantType::F16);
    }

    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..256).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(data, &[256]).unwrap();

        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        // Q4_0 should compress more aggressively than Q8_0.
        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    #[test]
    fn test_quantization_error() {
        let original = vec![1.0, 2.0, 3.0, 4.0];
        let dequantized = vec![1.1, 2.0, 2.9, 4.1];

        let rmse = compute_quantization_error(&original, &dequantized);
        assert!(rmse > 0.0);
        assert!(rmse < 0.2);
    }
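
    // Sketch of a quantize_model test; exercises the parallel path with a
    // single named tensor, reusing the assertions from the per-tensor tests.
    #[test]
    fn test_quantize_model() {
        let data: Vec<f32> = (0..64).map(|x| x as f32 / 8.0).collect();
        let tensor = Tensor::from_vec(data, &[64]).unwrap();
        let tensors = vec![("layer0.weight", &tensor)];

        let quantized = quantize_model(&tensors, QuantType::Q8_0).unwrap();

        assert_eq!(quantized.len(), 1);
        assert_eq!(quantized[0].0, "layer0.weight");
        assert_eq!(quantized[0].1.quant_type, QuantType::Q8_0);
    }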
}