1//! Quantization Functions
2//!
3//! Functions for quantizing tensors to various formats.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use axonml_tensor::Tensor;
9use half::f16;
10use rayon::prelude::*;
11
12use crate::error::QuantResult;
13use crate::types::{Q4Block, Q4_1Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
14use crate::DEFAULT_BLOCK_SIZE;
15
16// =============================================================================
17// Public API
18// =============================================================================
19
20/// Quantizes a tensor to the specified quantization type.
21///
22/// # Arguments
23/// * `tensor` - The input tensor to quantize
24/// * `quant_type` - The target quantization type
25///
26/// # Returns
27/// A quantized tensor
28///
29/// # Example
30/// ```ignore
31/// use axonml_quant::{quantize_tensor, QuantType};
32///
33/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
34/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
35/// ```
36pub fn quantize_tensor(
37    tensor: &Tensor<f32>,
38    quant_type: QuantType,
39) -> QuantResult<QuantizedTensor> {
40    let data = tensor.to_vec();
41    let shape = tensor.shape().to_vec();
42
43    match quant_type {
44        QuantType::Q8_0 => quantize_q8_0(&data, shape),
45        QuantType::Q4_0 => quantize_q4_0(&data, shape),
46        QuantType::Q4_1 => quantize_q4_1(&data, shape),
47        QuantType::Q5_0 | QuantType::Q5_1 => {
48            // Fall back to Q4 for now
49            quantize_q4_0(&data, shape)
50        }
51        QuantType::F16 => quantize_f16(&data, shape),
52        QuantType::F32 => quantize_f32(&data, shape),
53    }
54}
55
56/// Quantizes a model (collection of named tensors).
57///
58/// # Arguments
59/// * `tensors` - Named tensors to quantize
60/// * `quant_type` - The target quantization type
61///
62/// # Returns
63/// A map of quantized tensors
64pub fn quantize_model(
65    tensors: &[(&str, &Tensor<f32>)],
66    quant_type: QuantType,
67) -> QuantResult<Vec<(String, QuantizedTensor)>> {
68    tensors
69        .par_iter()
70        .map(|(name, tensor)| {
71            let quantized = quantize_tensor(tensor, quant_type)?;
72            Ok((name.to_string(), quantized))
73        })
74        .collect()
75}
76
77// =============================================================================
78// Q8_0 Quantization
79// =============================================================================
80
81/// Quantizes data to Q8_0 format (8-bit with per-block scale).
82fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
83    let block_size = DEFAULT_BLOCK_SIZE;
84    let n_blocks = (data.len() + block_size - 1) / block_size;
85
86    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
87        .into_par_iter()
88        .map(|block_idx| {
89            let start = block_idx * block_size;
90            let end = (start + block_size).min(data.len());
91            let block_data = &data[start..end];
92
93            // Find max absolute value for scale
94            let max_abs = block_data
95                .iter()
96                .map(|x| x.abs())
97                .fold(0.0f32, |a, b| a.max(b));
98
99            // Compute scale (avoid division by zero)
100            let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
101
102            // Quantize to int8
103            let mut quantized = [0i8; 32];
104            for (i, &val) in block_data.iter().enumerate() {
105                let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
106                quantized[i] = q;
107            }
108
109            QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
110        })
111        .collect();
112
113    Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
114}
115
116// =============================================================================
117// Q4_0 Quantization
118// =============================================================================
119
120/// Quantizes data to Q4_0 format (4-bit with per-block scale).
121fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
122    let block_size = DEFAULT_BLOCK_SIZE;
123    let n_blocks = (data.len() + block_size - 1) / block_size;
124
125    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
126        .into_par_iter()
127        .map(|block_idx| {
128            let start = block_idx * block_size;
129            let end = (start + block_size).min(data.len());
130            let block_data = &data[start..end];
131
132            // Find max absolute value for scale
133            let max_abs = block_data
134                .iter()
135                .map(|x| x.abs())
136                .fold(0.0f32, |a, b| a.max(b));
137
138            // Compute scale (4-bit range is -8 to 7)
139            let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1.0 };
140
141            // Quantize to 4-bit (stored as i8 in range -8 to 7)
142            let mut quantized = [0i8; 32];
143            for (i, &val) in block_data.iter().enumerate() {
144                let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
145                quantized[i] = q;
146            }
147
148            // Pack into bytes
149            let packed = Q4Block::pack(&quantized);
150
151            QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
152        })
153        .collect();
154
155    Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
156}
157
158// =============================================================================
159// Q4_1 Quantization
160// =============================================================================
161
162/// Quantizes data to Q4_1 format (4-bit with per-block scale and min).
163fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
164    let block_size = DEFAULT_BLOCK_SIZE;
165    let n_blocks = (data.len() + block_size - 1) / block_size;
166
167    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
168        .into_par_iter()
169        .map(|block_idx| {
170            let start = block_idx * block_size;
171            let end = (start + block_size).min(data.len());
172            let block_data = &data[start..end];
173
174            // Find min and max
175            let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
176            let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
177
178            // Compute scale (4-bit unsigned range is 0 to 15)
179            let scale = if max > min { (max - min) / 15.0 } else { 1.0 };
180
181            // Quantize to 4-bit unsigned
182            let mut quantized = [0u8; 32];
183            for (i, &val) in block_data.iter().enumerate() {
184                let q = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
185                quantized[i] = q;
186            }
187
188            // Pack into bytes
189            let mut packed = [0u8; 16];
190            for i in 0..16.min(block_data.len() / 2) {
191                let low = quantized[i * 2] & 0x0F;
192                let high = quantized.get(i * 2 + 1).copied().unwrap_or(0) & 0x0F;
193                packed[i] = low | (high << 4);
194            }
195
196            QuantizedBlock::Q4_1(Q4_1Block::new(
197                f16::from_f32(scale),
198                f16::from_f32(min),
199                packed,
200            ))
201        })
202        .collect();
203
204    Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
205}
206
207// =============================================================================
208// F16 Quantization
209// =============================================================================
210
211/// Quantizes data to F16 format (half precision).
212fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
213    let f16_data: Vec<f16> = data.par_iter().map(|&x| f16::from_f32(x)).collect();
214
215    let blocks = vec![QuantizedBlock::F16(f16_data)];
216
217    Ok(QuantizedTensor::new(shape, QuantType::F16, blocks))
218}
219
220// =============================================================================
221// F32 (No Quantization)
222// =============================================================================
223
224/// Stores data as F32 (no quantization, for comparison).
225fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
226    let blocks = vec![QuantizedBlock::F32(data.to_vec())];
227    Ok(QuantizedTensor::new(shape, QuantType::F32, blocks))
228}
229
230// =============================================================================
231// Utility Functions
232// =============================================================================
233
/// Computes the quantization error (RMSE) between original and quantized.
///
/// Returns `f32::INFINITY` when the slices are empty or differ in length.
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    // Degenerate input: no meaningful error can be computed.
    if original.is_empty() || original.len() != dequantized.len() {
        return f32::INFINITY;
    }

    let sum_sq: f32 = original
        .iter()
        .zip(dequantized)
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum();

    (sum_sq / original.len() as f32).sqrt()
}
249
/// Statistics describing quantization error and storage savings.
///
/// Derives added so callers can log (`Debug`), copy, and compare stats;
/// all fields are `f32`, so `Copy` is free.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantizationStats {
    /// Root mean square error.
    pub rmse: f32,
    /// Maximum absolute error.
    pub max_error: f32,
    /// Mean absolute error.
    pub mean_error: f32,
    /// Compression ratio relative to f32 storage.
    pub compression_ratio: f32,
}
261
262/// Computes detailed quantization statistics.
263pub fn compute_quantization_stats(
264    original: &[f32],
265    dequantized: &[f32],
266    quant_type: QuantType,
267) -> QuantizationStats {
268    let errors: Vec<f32> = original
269        .iter()
270        .zip(dequantized.iter())
271        .map(|(a, b)| (a - b).abs())
272        .collect();
273
274    let mse: f32 = errors.iter().map(|e| e.powi(2)).sum::<f32>() / errors.len() as f32;
275    let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
276    let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
277
278    QuantizationStats {
279        rmse: mse.sqrt(),
280        max_error,
281        mean_error,
282        compression_ratio: quant_type.compression_ratio(),
283    }
284}
285
286// =============================================================================
287// Tests
288// =============================================================================
289
#[cfg(test)]
mod tests {
    use super::*;

    // Eight values fit in a single quantization block.
    #[test]
    fn test_quantize_q8_0() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(data.clone(), &[8]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();

        assert_eq!(quantized.quant_type, QuantType::Q8_0);
        assert_eq!(quantized.shape, vec![8]);
        assert_eq!(quantized.num_blocks(), 1);
    }

    // 64 values should span exactly two blocks (assumes DEFAULT_BLOCK_SIZE == 32).
    #[test]
    fn test_quantize_q4_0() {
        let data: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(data.clone(), &[64]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert_eq!(quantized.quant_type, QuantType::Q4_0);
        assert_eq!(quantized.num_blocks(), 2);
    }

    // F16 path only checks the recorded type; data is stored in one block.
    #[test]
    fn test_quantize_f16() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_vec(data.clone(), &[4]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::F16).unwrap();

        assert_eq!(quantized.quant_type, QuantType::F16);
    }

    // Sanity-check relative compression: Q4 must compress more than Q8.
    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..256).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(data, &[256]).unwrap();

        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        // Q8 should be about 4x compression, Q4 about 8x
        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    // RMSE of small perturbations should be small but nonzero.
    #[test]
    fn test_quantization_error() {
        let original = vec![1.0, 2.0, 3.0, 4.0];
        let dequantized = vec![1.1, 2.0, 2.9, 4.1];

        let rmse = compute_quantization_error(&original, &dequantized);
        assert!(rmse > 0.0);
        assert!(rmse < 0.2);
    }
}