1use axonml_tensor::Tensor;
18use half::f16;
19use rayon::prelude::*;
20
21use crate::DEFAULT_BLOCK_SIZE;
22use crate::error::QuantResult;
23use crate::types::{Q4_1Block, Q4Block, Q5Block, Q5_1Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
24
25pub fn quantize_tensor(
46 tensor: &Tensor<f32>,
47 quant_type: QuantType,
48) -> QuantResult<QuantizedTensor> {
49 let data = tensor.to_vec();
50 let shape = tensor.shape().to_vec();
51
52 match quant_type {
53 QuantType::Q8_0 => quantize_q8_0(&data, shape),
54 QuantType::Q4_0 => quantize_q4_0(&data, shape),
55 QuantType::Q4_1 => quantize_q4_1(&data, shape),
56 QuantType::Q5_0 => quantize_q5_0(&data, shape),
57 QuantType::Q5_1 => quantize_q5_1(&data, shape),
58 QuantType::F16 => quantize_f16(&data, shape),
59 QuantType::F32 => quantize_f32(&data, shape),
60 }
61}
62
63pub fn quantize_model(
72 tensors: &[(&str, &Tensor<f32>)],
73 quant_type: QuantType,
74) -> QuantResult<Vec<(String, QuantizedTensor)>> {
75 tensors
76 .par_iter()
77 .map(|(name, tensor)| {
78 let quantized = quantize_tensor(tensor, quant_type)?;
79 Ok((name.to_string(), quantized))
80 })
81 .collect()
82}
83
84fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
90 let block_size = DEFAULT_BLOCK_SIZE;
91 let n_blocks = data.len().div_ceil(block_size);
92
93 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
94 .into_par_iter()
95 .map(|block_idx| {
96 let start = block_idx * block_size;
97 let end = (start + block_size).min(data.len());
98 let block_data = &data[start..end];
99
100 let max_abs = block_data
102 .iter()
103 .map(|x| x.abs())
104 .fold(0.0f32, |a, b| a.max(b));
105
106 let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
108
109 let mut quantized = [0i8; 32];
111 for (i, &val) in block_data.iter().enumerate() {
112 let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
113 quantized[i] = q;
114 }
115
116 QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
117 })
118 .collect();
119
120 Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
121}
122
123fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
129 let block_size = DEFAULT_BLOCK_SIZE;
130 let n_blocks = data.len().div_ceil(block_size);
131
132 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
133 .into_par_iter()
134 .map(|block_idx| {
135 let start = block_idx * block_size;
136 let end = (start + block_size).min(data.len());
137 let block_data = &data[start..end];
138
139 let max_abs = block_data
141 .iter()
142 .map(|x| x.abs())
143 .fold(0.0f32, |a, b| a.max(b));
144
145 let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1.0 };
147
148 let mut quantized = [0i8; 32];
150 for (i, &val) in block_data.iter().enumerate() {
151 let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
152 quantized[i] = q;
153 }
154
155 let packed = Q4Block::pack(&quantized);
157
158 QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
159 })
160 .collect();
161
162 Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
163}
164
165fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
171 let block_size = DEFAULT_BLOCK_SIZE;
172 let n_blocks = data.len().div_ceil(block_size);
173
174 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
175 .into_par_iter()
176 .map(|block_idx| {
177 let start = block_idx * block_size;
178 let end = (start + block_size).min(data.len());
179 let block_data = &data[start..end];
180
181 let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
183 let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
184
185 let scale = if max > min { (max - min) / 15.0 } else { 1.0 };
187
188 let mut quantized = [0u8; 32];
190 for (i, &val) in block_data.iter().enumerate() {
191 let q = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
192 quantized[i] = q;
193 }
194
195 let mut packed = [0u8; 16];
197 for i in 0..16.min(block_data.len() / 2) {
198 let low = quantized[i * 2] & 0x0F;
199 let high = quantized.get(i * 2 + 1).copied().unwrap_or(0) & 0x0F;
200 packed[i] = low | (high << 4);
201 }
202
203 QuantizedBlock::Q4_1(Q4_1Block::new(
204 f16::from_f32(scale),
205 f16::from_f32(min),
206 packed,
207 ))
208 })
209 .collect();
210
211 Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
212}
213
214fn quantize_q5_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
220 let block_size = DEFAULT_BLOCK_SIZE;
221 let n_blocks = data.len().div_ceil(block_size);
222
223 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
224 .into_par_iter()
225 .map(|block_idx| {
226 let start = block_idx * block_size;
227 let end = (start + block_size).min(data.len());
228 let block_data = &data[start..end];
229
230 let max_abs = block_data
231 .iter()
232 .map(|x| x.abs())
233 .fold(0.0f32, |a, b| a.max(b));
234
235 let scale = if max_abs > 0.0 { max_abs / 15.0 } else { 1.0 };
237
238 let mut quantized = [0i8; 32];
239 for (i, &val) in block_data.iter().enumerate() {
240 let q = (val / scale).round().clamp(-16.0, 15.0) as i8;
241 quantized[i] = q;
242 }
243
244 let packed = Q5Block::pack(&quantized);
245 QuantizedBlock::Q5(Q5Block::new(f16::from_f32(scale), packed))
246 })
247 .collect();
248
249 Ok(QuantizedTensor::new(shape, QuantType::Q5_0, blocks))
250}
251
252fn quantize_q5_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
258 let block_size = DEFAULT_BLOCK_SIZE;
259 let n_blocks = data.len().div_ceil(block_size);
260
261 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
262 .into_par_iter()
263 .map(|block_idx| {
264 let start = block_idx * block_size;
265 let end = (start + block_size).min(data.len());
266 let block_data = &data[start..end];
267
268 let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
269 let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270
271 let scale = if max > min { (max - min) / 31.0 } else { 1.0 };
273
274 let mut quantized = [0u8; 32];
275 for (i, &val) in block_data.iter().enumerate() {
276 let q = ((val - min) / scale).round().clamp(0.0, 31.0) as u8;
277 quantized[i] = q;
278 }
279
280 let packed = Q5_1Block::pack(&quantized);
281 QuantizedBlock::Q5_1(Q5_1Block::new(f16::from_f32(scale), f16::from_f32(min), packed))
282 })
283 .collect();
284
285 Ok(QuantizedTensor::new(shape, QuantType::Q5_1, blocks))
286}
287
288fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
294 let f16_data: Vec<f16> = data.par_iter().map(|&x| f16::from_f32(x)).collect();
295
296 let blocks = vec![QuantizedBlock::F16(f16_data)];
297
298 Ok(QuantizedTensor::new(shape, QuantType::F16, blocks))
299}
300
301fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
307 let blocks = vec![QuantizedBlock::F32(data.to_vec())];
308 Ok(QuantizedTensor::new(shape, QuantType::F32, blocks))
309}
310
/// Root-mean-square error between `original` and `dequantized`.
///
/// Returns `f32::INFINITY` when the slices differ in length or are empty,
/// since no meaningful error can be computed in those cases.
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.is_empty() || original.len() != dequantized.len() {
        return f32::INFINITY;
    }

    let total_sq: f32 = original
        .iter()
        .zip(dequantized)
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum();

    (total_sq / original.len() as f32).sqrt()
}
330
/// Summary statistics describing the distortion introduced by quantization,
/// as computed by `compute_quantization_stats`.
pub struct QuantizationStats {
    /// Root-mean-square error between original and dequantized values.
    pub rmse: f32,
    /// Largest absolute per-element error.
    pub max_error: f32,
    /// Mean absolute per-element error.
    pub mean_error: f32,
    /// Storage compression ratio reported by the quantization type.
    pub compression_ratio: f32,
}
342
343pub fn compute_quantization_stats(
345 original: &[f32],
346 dequantized: &[f32],
347 quant_type: QuantType,
348) -> QuantizationStats {
349 let errors: Vec<f32> = original
350 .iter()
351 .zip(dequantized.iter())
352 .map(|(a, b)| (a - b).abs())
353 .collect();
354
355 let mse: f32 = errors.iter().map(|e| e.powi(2)).sum::<f32>() / errors.len() as f32;
356 let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
357 let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
358
359 QuantizationStats {
360 rmse: mse.sqrt(),
361 max_error,
362 mean_error,
363 compression_ratio: quant_type.compression_ratio(),
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_q8_0() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(values, &[8]).unwrap();
        let result = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();

        assert_eq!(result.quant_type, QuantType::Q8_0);
        assert_eq!(result.shape, vec![8]);
        assert_eq!(result.num_blocks(), 1);
    }

    #[test]
    fn test_quantize_q4_0() {
        let values: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(values, &[64]).unwrap();
        let result = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert_eq!(result.quant_type, QuantType::Q4_0);
        assert_eq!(result.num_blocks(), 2);
    }

    #[test]
    fn test_quantize_f16() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let result = quantize_tensor(&tensor, QuantType::F16).unwrap();

        assert_eq!(result.quant_type, QuantType::F16);
    }

    #[test]
    fn test_compression_ratio() {
        let values: Vec<f32> = (0..256).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(values, &[256]).unwrap();

        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        // 4-bit storage should compress strictly better than 8-bit.
        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    #[test]
    fn test_quantization_error() {
        let reference = vec![1.0, 2.0, 3.0, 4.0];
        let roundtrip = vec![1.1, 2.0, 2.9, 4.1];

        let rmse = compute_quantization_error(&reference, &roundtrip);
        assert!(rmse > 0.0);
        assert!(rmse < 0.2);
    }
}