1use axonml_tensor::Tensor;
18use half::f16;
19use rayon::prelude::*;
20
21use crate::DEFAULT_BLOCK_SIZE;
22use crate::error::QuantResult;
23use crate::types::{
24 Q4_1Block, Q4Block, Q5_1Block, Q5Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor,
25};
26
27pub fn quantize_tensor(
48 tensor: &Tensor<f32>,
49 quant_type: QuantType,
50) -> QuantResult<QuantizedTensor> {
51 let data = tensor.to_vec();
52 let shape = tensor.shape().to_vec();
53
54 match quant_type {
55 QuantType::Q8_0 => quantize_q8_0(&data, shape),
56 QuantType::Q4_0 => quantize_q4_0(&data, shape),
57 QuantType::Q4_1 => quantize_q4_1(&data, shape),
58 QuantType::Q5_0 => quantize_q5_0(&data, shape),
59 QuantType::Q5_1 => quantize_q5_1(&data, shape),
60 QuantType::F16 => quantize_f16(&data, shape),
61 QuantType::F32 => quantize_f32(&data, shape),
62 }
63}
64
65pub fn quantize_model(
74 tensors: &[(&str, &Tensor<f32>)],
75 quant_type: QuantType,
76) -> QuantResult<Vec<(String, QuantizedTensor)>> {
77 tensors
78 .par_iter()
79 .map(|(name, tensor)| {
80 let quantized = quantize_tensor(tensor, quant_type)?;
81 Ok((name.to_string(), quantized))
82 })
83 .collect()
84}
85
86fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
92 let block_size = DEFAULT_BLOCK_SIZE;
93 let n_blocks = data.len().div_ceil(block_size);
94
95 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
96 .into_par_iter()
97 .map(|block_idx| {
98 let start = block_idx * block_size;
99 let end = (start + block_size).min(data.len());
100 let block_data = &data[start..end];
101
102 let max_abs = block_data
104 .iter()
105 .map(|x| x.abs())
106 .fold(0.0f32, |a, b| a.max(b));
107
108 let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
110
111 let mut quantized = [0i8; 32];
113 for (i, &val) in block_data.iter().enumerate() {
114 let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
115 quantized[i] = q;
116 }
117
118 QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
119 })
120 .collect();
121
122 Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
123}
124
125fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
131 let block_size = DEFAULT_BLOCK_SIZE;
132 let n_blocks = data.len().div_ceil(block_size);
133
134 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
135 .into_par_iter()
136 .map(|block_idx| {
137 let start = block_idx * block_size;
138 let end = (start + block_size).min(data.len());
139 let block_data = &data[start..end];
140
141 let max_abs = block_data
143 .iter()
144 .map(|x| x.abs())
145 .fold(0.0f32, |a, b| a.max(b));
146
147 let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1.0 };
149
150 let mut quantized = [0i8; 32];
152 for (i, &val) in block_data.iter().enumerate() {
153 let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
154 quantized[i] = q;
155 }
156
157 let packed = Q4Block::pack(&quantized);
159
160 QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
161 })
162 .collect();
163
164 Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
165}
166
167fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
173 let block_size = DEFAULT_BLOCK_SIZE;
174 let n_blocks = data.len().div_ceil(block_size);
175
176 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
177 .into_par_iter()
178 .map(|block_idx| {
179 let start = block_idx * block_size;
180 let end = (start + block_size).min(data.len());
181 let block_data = &data[start..end];
182
183 let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
185 let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
186
187 let scale = if max > min { (max - min) / 15.0 } else { 1.0 };
189
190 let mut quantized = [0u8; 32];
192 for (i, &val) in block_data.iter().enumerate() {
193 let q = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
194 quantized[i] = q;
195 }
196
197 let mut packed = [0u8; 16];
199 for i in 0..16.min(block_data.len() / 2) {
200 let low = quantized[i * 2] & 0x0F;
201 let high = quantized.get(i * 2 + 1).copied().unwrap_or(0) & 0x0F;
202 packed[i] = low | (high << 4);
203 }
204
205 QuantizedBlock::Q4_1(Q4_1Block::new(
206 f16::from_f32(scale),
207 f16::from_f32(min),
208 packed,
209 ))
210 })
211 .collect();
212
213 Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
214}
215
216fn quantize_q5_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
222 let block_size = DEFAULT_BLOCK_SIZE;
223 let n_blocks = data.len().div_ceil(block_size);
224
225 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
226 .into_par_iter()
227 .map(|block_idx| {
228 let start = block_idx * block_size;
229 let end = (start + block_size).min(data.len());
230 let block_data = &data[start..end];
231
232 let max_abs = block_data
233 .iter()
234 .map(|x| x.abs())
235 .fold(0.0f32, |a, b| a.max(b));
236
237 let scale = if max_abs > 0.0 { max_abs / 15.0 } else { 1.0 };
239
240 let mut quantized = [0i8; 32];
241 for (i, &val) in block_data.iter().enumerate() {
242 let q = (val / scale).round().clamp(-16.0, 15.0) as i8;
243 quantized[i] = q;
244 }
245
246 let packed = Q5Block::pack(&quantized);
247 QuantizedBlock::Q5(Q5Block::new(f16::from_f32(scale), packed))
248 })
249 .collect();
250
251 Ok(QuantizedTensor::new(shape, QuantType::Q5_0, blocks))
252}
253
254fn quantize_q5_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
260 let block_size = DEFAULT_BLOCK_SIZE;
261 let n_blocks = data.len().div_ceil(block_size);
262
263 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
264 .into_par_iter()
265 .map(|block_idx| {
266 let start = block_idx * block_size;
267 let end = (start + block_size).min(data.len());
268 let block_data = &data[start..end];
269
270 let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
271 let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
272
273 let scale = if max > min { (max - min) / 31.0 } else { 1.0 };
275
276 let mut quantized = [0u8; 32];
277 for (i, &val) in block_data.iter().enumerate() {
278 let q = ((val - min) / scale).round().clamp(0.0, 31.0) as u8;
279 quantized[i] = q;
280 }
281
282 let packed = Q5_1Block::pack(&quantized);
283 QuantizedBlock::Q5_1(Q5_1Block::new(
284 f16::from_f32(scale),
285 f16::from_f32(min),
286 packed,
287 ))
288 })
289 .collect();
290
291 Ok(QuantizedTensor::new(shape, QuantType::Q5_1, blocks))
292}
293
294fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
300 let f16_data: Vec<f16> = data.par_iter().map(|&x| f16::from_f32(x)).collect();
301
302 let blocks = vec![QuantizedBlock::F16(f16_data)];
303
304 Ok(QuantizedTensor::new(shape, QuantType::F16, blocks))
305}
306
307fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
313 let blocks = vec![QuantizedBlock::F32(data.to_vec())];
314 Ok(QuantizedTensor::new(shape, QuantType::F32, blocks))
315}
316
/// Root-mean-square error between an original tensor and its dequantized
/// reconstruction.
///
/// Returns `f32::INFINITY` when the slices are empty or differ in length,
/// so degenerate comparisons never look like a perfect match.
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.is_empty() || original.len() != dequantized.len() {
        return f32::INFINITY;
    }

    let sum_sq: f32 = original
        .iter()
        .zip(dequantized)
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum();

    (sum_sq / original.len() as f32).sqrt()
}
336
/// Error and size statistics for one quantize/dequantize round trip.
pub struct QuantizationStats {
    /// Root-mean-square error between original and dequantized values.
    pub rmse: f32,
    /// Largest absolute element-wise error.
    pub max_error: f32,
    /// Mean absolute element-wise error.
    pub mean_error: f32,
    /// Size ratio reported by the quantization type (f32 bytes / quantized bytes).
    pub compression_ratio: f32,
}
348
349pub fn compute_quantization_stats(
351 original: &[f32],
352 dequantized: &[f32],
353 quant_type: QuantType,
354) -> QuantizationStats {
355 let errors: Vec<f32> = original
356 .iter()
357 .zip(dequantized.iter())
358 .map(|(a, b)| (a - b).abs())
359 .collect();
360
361 let mse: f32 = errors.iter().map(|e| e.powi(2)).sum::<f32>() / errors.len() as f32;
362 let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
363 let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
364
365 QuantizationStats {
366 rmse: mse.sqrt(),
367 max_error,
368 mean_error,
369 compression_ratio: quant_type.compression_ratio(),
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_q8_0() {
        // Eight values fit in a single block.
        let tensor =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[8]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();

        assert_eq!(q.quant_type, QuantType::Q8_0);
        assert_eq!(q.shape, vec![8]);
        assert_eq!(q.num_blocks(), 1);
    }

    #[test]
    fn test_quantize_q4_0() {
        // 64 values span exactly two 32-element blocks.
        let values: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(values, &[64]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert_eq!(q.quant_type, QuantType::Q4_0);
        assert_eq!(q.num_blocks(), 2);
    }

    #[test]
    fn test_quantize_f16() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::F16).unwrap();
        assert_eq!(q.quant_type, QuantType::F16);
    }

    #[test]
    fn test_compression_ratio() {
        let tensor =
            Tensor::from_vec((0..256).map(|x| x as f32).collect::<Vec<f32>>(), &[256]).unwrap();

        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        // Q8 should better than halve the size, and Q4 should beat Q8.
        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    #[test]
    fn test_quantization_error() {
        let rmse =
            compute_quantization_error(&[1.0, 2.0, 3.0, 4.0], &[1.1, 2.0, 2.9, 4.1]);
        assert!(rmse > 0.0);
        assert!(rmse < 0.2);
    }
}