1use axonml_tensor::Tensor;
19use half::f16;
20use rayon::prelude::*;
21
22use crate::DEFAULT_BLOCK_SIZE;
23use crate::error::QuantResult;
24use crate::types::{
25 Q4_1Block, Q4Block, Q5_1Block, Q5Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor,
26};
27
28pub fn quantize_tensor(
49 tensor: &Tensor<f32>,
50 quant_type: QuantType,
51) -> QuantResult<QuantizedTensor> {
52 let data = tensor.to_vec();
53 let shape = tensor.shape().to_vec();
54
55 match quant_type {
56 QuantType::Q8_0 => quantize_q8_0(&data, shape),
57 QuantType::Q4_0 => quantize_q4_0(&data, shape),
58 QuantType::Q4_1 => quantize_q4_1(&data, shape),
59 QuantType::Q5_0 => quantize_q5_0(&data, shape),
60 QuantType::Q5_1 => quantize_q5_1(&data, shape),
61 QuantType::F16 => quantize_f16(&data, shape),
62 QuantType::F32 => quantize_f32(&data, shape),
63 }
64}
65
66pub fn quantize_model(
75 tensors: &[(&str, &Tensor<f32>)],
76 quant_type: QuantType,
77) -> QuantResult<Vec<(String, QuantizedTensor)>> {
78 tensors
79 .par_iter()
80 .map(|(name, tensor)| {
81 let quantized = quantize_tensor(tensor, quant_type)?;
82 Ok((name.to_string(), quantized))
83 })
84 .collect()
85}
86
87fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
93 let block_size = DEFAULT_BLOCK_SIZE;
94 let n_blocks = data.len().div_ceil(block_size);
95
96 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
97 .into_par_iter()
98 .map(|block_idx| {
99 let start = block_idx * block_size;
100 let end = (start + block_size).min(data.len());
101 let block_data = &data[start..end];
102
103 let max_abs = block_data
105 .iter()
106 .map(|x| x.abs())
107 .fold(0.0f32, |a, b| a.max(b));
108
109 let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
111
112 let mut quantized = [0i8; 32];
114 for (i, &val) in block_data.iter().enumerate() {
115 let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
116 quantized[i] = q;
117 }
118
119 QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
120 })
121 .collect();
122
123 Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
124}
125
126fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
132 let block_size = DEFAULT_BLOCK_SIZE;
133 let n_blocks = data.len().div_ceil(block_size);
134
135 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
136 .into_par_iter()
137 .map(|block_idx| {
138 let start = block_idx * block_size;
139 let end = (start + block_size).min(data.len());
140 let block_data = &data[start..end];
141
142 let max_abs = block_data
144 .iter()
145 .map(|x| x.abs())
146 .fold(0.0f32, |a, b| a.max(b));
147
148 let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1.0 };
150
151 let mut quantized = [0i8; 32];
153 for (i, &val) in block_data.iter().enumerate() {
154 let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
155 quantized[i] = q;
156 }
157
158 let packed = Q4Block::pack(&quantized);
160
161 QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
162 })
163 .collect();
164
165 Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
166}
167
168fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
174 let block_size = DEFAULT_BLOCK_SIZE;
175 let n_blocks = data.len().div_ceil(block_size);
176
177 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
178 .into_par_iter()
179 .map(|block_idx| {
180 let start = block_idx * block_size;
181 let end = (start + block_size).min(data.len());
182 let block_data = &data[start..end];
183
184 let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
186 let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
187
188 let scale = if max > min { (max - min) / 15.0 } else { 1.0 };
190
191 let mut quantized = [0u8; 32];
193 for (i, &val) in block_data.iter().enumerate() {
194 let q = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
195 quantized[i] = q;
196 }
197
198 let mut packed = [0u8; 16];
200 for i in 0..16.min(block_data.len() / 2) {
201 let low = quantized[i * 2] & 0x0F;
202 let high = quantized.get(i * 2 + 1).copied().unwrap_or(0) & 0x0F;
203 packed[i] = low | (high << 4);
204 }
205
206 QuantizedBlock::Q4_1(Q4_1Block::new(
207 f16::from_f32(scale),
208 f16::from_f32(min),
209 packed,
210 ))
211 })
212 .collect();
213
214 Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
215}
216
217fn quantize_q5_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
223 let block_size = DEFAULT_BLOCK_SIZE;
224 let n_blocks = data.len().div_ceil(block_size);
225
226 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
227 .into_par_iter()
228 .map(|block_idx| {
229 let start = block_idx * block_size;
230 let end = (start + block_size).min(data.len());
231 let block_data = &data[start..end];
232
233 let max_abs = block_data
234 .iter()
235 .map(|x| x.abs())
236 .fold(0.0f32, |a, b| a.max(b));
237
238 let scale = if max_abs > 0.0 { max_abs / 15.0 } else { 1.0 };
240
241 let mut quantized = [0i8; 32];
242 for (i, &val) in block_data.iter().enumerate() {
243 let q = (val / scale).round().clamp(-16.0, 15.0) as i8;
244 quantized[i] = q;
245 }
246
247 let packed = Q5Block::pack(&quantized);
248 QuantizedBlock::Q5(Q5Block::new(f16::from_f32(scale), packed))
249 })
250 .collect();
251
252 Ok(QuantizedTensor::new(shape, QuantType::Q5_0, blocks))
253}
254
255fn quantize_q5_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
261 let block_size = DEFAULT_BLOCK_SIZE;
262 let n_blocks = data.len().div_ceil(block_size);
263
264 let blocks: Vec<QuantizedBlock> = (0..n_blocks)
265 .into_par_iter()
266 .map(|block_idx| {
267 let start = block_idx * block_size;
268 let end = (start + block_size).min(data.len());
269 let block_data = &data[start..end];
270
271 let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
272 let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
273
274 let scale = if max > min { (max - min) / 31.0 } else { 1.0 };
276
277 let mut quantized = [0u8; 32];
278 for (i, &val) in block_data.iter().enumerate() {
279 let q = ((val - min) / scale).round().clamp(0.0, 31.0) as u8;
280 quantized[i] = q;
281 }
282
283 let packed = Q5_1Block::pack(&quantized);
284 QuantizedBlock::Q5_1(Q5_1Block::new(
285 f16::from_f32(scale),
286 f16::from_f32(min),
287 packed,
288 ))
289 })
290 .collect();
291
292 Ok(QuantizedTensor::new(shape, QuantType::Q5_1, blocks))
293}
294
295fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
301 let f16_data: Vec<f16> = data.par_iter().map(|&x| f16::from_f32(x)).collect();
302
303 let blocks = vec![QuantizedBlock::F16(f16_data)];
304
305 Ok(QuantizedTensor::new(shape, QuantType::F16, blocks))
306}
307
308fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
314 let blocks = vec![QuantizedBlock::F32(data.to_vec())];
315 Ok(QuantizedTensor::new(shape, QuantType::F32, blocks))
316}
317
/// Root-mean-square error between `original` and `dequantized`.
///
/// Returns `f32::INFINITY` when the slices are empty or differ in length,
/// since no meaningful error can be computed in either case.
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.is_empty() || original.len() != dequantized.len() {
        return f32::INFINITY;
    }

    let sum_sq: f32 = original
        .iter()
        .zip(dequantized)
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum();

    (sum_sq / original.len() as f32).sqrt()
}
337
/// Summary statistics describing the fidelity of a quantization pass.
pub struct QuantizationStats {
    /// Root-mean-square error between original and dequantized values.
    pub rmse: f32,
    /// Largest absolute elementwise error.
    pub max_error: f32,
    /// Mean absolute elementwise error.
    pub mean_error: f32,
    /// Size-reduction factor reported by the quantization type.
    pub compression_ratio: f32,
}
349
350pub fn compute_quantization_stats(
352 original: &[f32],
353 dequantized: &[f32],
354 quant_type: QuantType,
355) -> QuantizationStats {
356 let errors: Vec<f32> = original
357 .iter()
358 .zip(dequantized.iter())
359 .map(|(a, b)| (a - b).abs())
360 .collect();
361
362 let mse: f32 = errors.iter().map(|e| e.powi(2)).sum::<f32>() / errors.len() as f32;
363 let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
364 let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
365
366 QuantizationStats {
367 rmse: mse.sqrt(),
368 max_error,
369 mean_error,
370 compression_ratio: quant_type.compression_ratio(),
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    // 8 values fit in a single 32-wide block; shape/type metadata must survive.
    #[test]
    fn test_quantize_q8_0() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(data.clone(), &[8]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();

        assert_eq!(quantized.quant_type, QuantType::Q8_0);
        assert_eq!(quantized.shape, vec![8]);
        assert_eq!(quantized.num_blocks(), 1);
    }

    // 64 values span exactly two 32-wide Q4_0 blocks.
    #[test]
    fn test_quantize_q4_0() {
        let data: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(data.clone(), &[64]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert_eq!(quantized.quant_type, QuantType::Q4_0);
        assert_eq!(quantized.num_blocks(), 2);
    }

    // F16 conversion keeps the quantization type tag.
    #[test]
    fn test_quantize_f16() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_vec(data.clone(), &[4]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::F16).unwrap();

        assert_eq!(quantized.quant_type, QuantType::F16);
    }

    // Q4 packs two values per byte, so it must compress better than Q8.
    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..256).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(data, &[256]).unwrap();

        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    // RMSE of a near-match is small but nonzero.
    #[test]
    fn test_quantization_error() {
        let original = vec![1.0, 2.0, 3.0, 4.0];
        let dequantized = vec![1.1, 2.0, 2.9, 4.1];

        let rmse = compute_quantization_error(&original, &dequantized);
        assert!(rmse > 0.0);
        assert!(rmse < 0.2);
    }
}