axonml_quant/quantize.rs

//! Quantization Functions
//!
//! Functions for quantizing tensors to various formats.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use axonml_tensor::Tensor;
use half::f16;
use rayon::prelude::*;

use crate::error::QuantResult;
use crate::types::{QuantType, QuantizedTensor, QuantizedBlock, Q8Block, Q4Block, Q4_1Block};
use crate::DEFAULT_BLOCK_SIZE;

// =============================================================================
// Public API
// =============================================================================

/// Quantizes a tensor to the specified quantization type.
///
/// # Arguments
/// * `tensor` - The input tensor to quantize
/// * `quant_type` - The target quantization type
///
/// # Returns
/// A quantized tensor
///
/// # Notes
/// `Q5_0` and `Q5_1` are not yet implemented and currently fall back to `Q4_0`.
///
/// # Example
/// ```ignore
/// use axonml_quant::{quantize_tensor, QuantType};
/// use axonml_tensor::Tensor;
///
/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
/// ```
pub fn quantize_tensor(tensor: &Tensor<f32>, quant_type: QuantType) -> QuantResult<QuantizedTensor> {
    let data = tensor.to_vec();
    let shape = tensor.shape().to_vec();

    match quant_type {
        QuantType::Q8_0 => quantize_q8_0(&data, shape),
        QuantType::Q4_0 => quantize_q4_0(&data, shape),
        QuantType::Q4_1 => quantize_q4_1(&data, shape),
        QuantType::Q5_0 | QuantType::Q5_1 => {
            // Q5 is not yet implemented; fall back to Q4_0 for now
            quantize_q4_0(&data, shape)
        }
        QuantType::F16 => quantize_f16(&data, shape),
        QuantType::F32 => quantize_f32(&data, shape),
    }
}

/// Quantizes a model (collection of named tensors).
///
/// # Arguments
/// * `tensors` - Named tensors to quantize
/// * `quant_type` - The target quantization type
///
/// # Returns
/// A vector of `(name, quantized tensor)` pairs
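///
/// # Example
/// ```ignore
/// use axonml_quant::{quantize_model, QuantType};
///
/// // `w1` and `w2` are hypothetical `Tensor<f32>` values.
/// let quantized = quantize_model(&[("w1", &w1), ("w2", &w2)], QuantType::Q8_0)?;
/// ```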
pub fn quantize_model(
    tensors: &[(&str, &Tensor<f32>)],
    quant_type: QuantType,
) -> QuantResult<Vec<(String, QuantizedTensor)>> {
    tensors
        .par_iter()
        .map(|(name, tensor)| {
            let quantized = quantize_tensor(tensor, quant_type)?;
            Ok((name.to_string(), quantized))
        })
        .collect()
}

// =============================================================================
// Q8_0 Quantization
// =============================================================================

/// Quantizes data to Q8_0 format (8-bit with per-block scale).
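///
/// Each block of `DEFAULT_BLOCK_SIZE` values is stored as an `f16` scale
/// `s = max|x| / 127` plus `round(x / s)` clamped to `[-127, 127]` as `i8`;
/// dequantization recovers `s * q`.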
fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let block_size = DEFAULT_BLOCK_SIZE;
    // Ceiling division: a short trailing block is zero-padded below
    let n_blocks = (data.len() + block_size - 1) / block_size;

    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
        .into_par_iter()
        .map(|block_idx| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(data.len());
            let block_data = &data[start..end];

            // Find max absolute value for scale
            let max_abs = block_data
                .iter()
                .map(|x| x.abs())
                .fold(0.0f32, |a, b| a.max(b));

            // Compute scale (avoid division by zero)
            let scale = if max_abs > 0.0 {
                max_abs / 127.0
            } else {
                1.0
            };

            // Quantize to int8 (the array width matches DEFAULT_BLOCK_SIZE;
            // unused slots of a short final block stay zero)
            let mut quantized = [0i8; 32];
            for (i, &val) in block_data.iter().enumerate() {
                let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
                quantized[i] = q;
            }

            QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
        })
        .collect();

    Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
}

// =============================================================================
// Q4_0 Quantization
// =============================================================================

/// Quantizes data to Q4_0 format (4-bit with per-block scale).
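///
/// Each block stores an `f16` scale `s = max|x| / 7` plus signed nibbles
/// `round(x / s)` clamped to `[-8, 7]`, packed two values per byte.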
fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let block_size = DEFAULT_BLOCK_SIZE;
    let n_blocks = (data.len() + block_size - 1) / block_size;

    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
        .into_par_iter()
        .map(|block_idx| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(data.len());
            let block_data = &data[start..end];

            // Find max absolute value for scale
            let max_abs = block_data
                .iter()
                .map(|x| x.abs())
                .fold(0.0f32, |a, b| a.max(b));

            // Compute scale (4-bit signed range is -8 to 7)
            let scale = if max_abs > 0.0 {
                max_abs / 7.0
            } else {
                1.0
            };

            // Quantize to 4-bit (stored as i8 in range -8 to 7)
            let mut quantized = [0i8; 32];
            for (i, &val) in block_data.iter().enumerate() {
                let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
                quantized[i] = q;
            }

            // Pack two nibbles per byte
            let packed = Q4Block::pack(&quantized);

            QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
        })
        .collect();

    Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
}

// =============================================================================
// Q4_1 Quantization
// =============================================================================

/// Quantizes data to Q4_1 format (4-bit with per-block scale and min).
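///
/// Each block stores an `f16` scale `s = (max - min) / 15` and an `f16` min
/// plus unsigned nibbles `round((x - min) / s)` clamped to `[0, 15]`, packed
/// two values per byte.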
fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let block_size = DEFAULT_BLOCK_SIZE;
    let n_blocks = (data.len() + block_size - 1) / block_size;

    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
        .into_par_iter()
        .map(|block_idx| {
            let start = block_idx * block_size;
            let end = (start + block_size).min(data.len());
            let block_data = &data[start..end];

            // Find min and max
            let min = block_data
                .iter()
                .fold(f32::INFINITY, |a, &b| a.min(b));
            let max = block_data
                .iter()
                .fold(f32::NEG_INFINITY, |a, &b| a.max(b));

            // Compute scale (4-bit unsigned range is 0 to 15)
            let scale = if max > min {
                (max - min) / 15.0
            } else {
                1.0
            };

            // Quantize to 4-bit unsigned
            let mut quantized = [0u8; 32];
            for (i, &val) in block_data.iter().enumerate() {
                let q = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
                quantized[i] = q;
            }

            // Pack two nibbles per byte; round the byte count up so the last
            // value of an odd-length block is not dropped
            let mut packed = [0u8; 16];
            for i in 0..16.min((block_data.len() + 1) / 2) {
                let low = quantized[i * 2] & 0x0F;
                let high = quantized.get(i * 2 + 1).copied().unwrap_or(0) & 0x0F;
                packed[i] = low | (high << 4);
            }

            QuantizedBlock::Q4_1(Q4_1Block::new(
                f16::from_f32(scale),
                f16::from_f32(min),
                packed,
            ))
        })
        .collect();

    Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
}

// =============================================================================
// F16 Quantization
// =============================================================================

/// Quantizes data to F16 format (half precision).
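///
/// Converts each `f32` to IEEE 754 half precision; the whole tensor is kept
/// in a single `F16` block rather than fixed-size blocks.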
fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let f16_data: Vec<f16> = data
        .par_iter()
        .map(|&x| f16::from_f32(x))
        .collect();

    let blocks = vec![QuantizedBlock::F16(f16_data)];

    Ok(QuantizedTensor::new(shape, QuantType::F16, blocks))
}

// =============================================================================
// F32 (No Quantization)
// =============================================================================

/// Stores data as F32 (no quantization, for comparison).
fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
    let blocks = vec![QuantizedBlock::F32(data.to_vec())];
    Ok(QuantizedTensor::new(shape, QuantType::F32, blocks))
}

// =============================================================================
// Utility Functions
// =============================================================================

/// Computes the quantization error (RMSE) between original and dequantized values.
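///
/// Returns `f32::INFINITY` for empty or mismatched-length inputs.
///
/// # Example
/// ```ignore
/// let rmse = compute_quantization_error(&[1.0, 2.0], &[1.1, 1.9]);
/// assert!(rmse > 0.0 && rmse < 0.2);
/// ```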
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.len() != dequantized.len() || original.is_empty() {
        return f32::INFINITY;
    }

    let mse: f32 = original
        .iter()
        .zip(dequantized.iter())
        .map(|(a, b)| (a - b).powi(2))
        .sum::<f32>()
        / original.len() as f32;

    mse.sqrt()
}

/// Statistics about quantization error.
pub struct QuantizationStats {
    /// Root mean square error.
    pub rmse: f32,
    /// Maximum absolute error.
    pub max_error: f32,
    /// Mean absolute error.
    pub mean_error: f32,
    /// Compression ratio of the quantization type.
    pub compression_ratio: f32,
}

/// Computes detailed quantization statistics.
pub fn compute_quantization_stats(
    original: &[f32],
    dequantized: &[f32],
    quant_type: QuantType,
) -> QuantizationStats {
    // Mirror compute_quantization_error's guard: report infinite error rather
    // than dividing by zero on empty or mismatched inputs
    if original.len() != dequantized.len() || original.is_empty() {
        return QuantizationStats {
            rmse: f32::INFINITY,
            max_error: f32::INFINITY,
            mean_error: f32::INFINITY,
            compression_ratio: quant_type.compression_ratio(),
        };
    }

    let errors: Vec<f32> = original
        .iter()
        .zip(dequantized.iter())
        .map(|(a, b)| (a - b).abs())
        .collect();

    let mse: f32 = errors.iter().map(|e| e.powi(2)).sum::<f32>() / errors.len() as f32;
    let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
    let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;

    QuantizationStats {
        rmse: mse.sqrt(),
        max_error,
        mean_error,
        compression_ratio: quant_type.compression_ratio(),
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_q8_0() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(data, &[8]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();

        assert_eq!(quantized.quant_type, QuantType::Q8_0);
        assert_eq!(quantized.shape, vec![8]);
        assert_eq!(quantized.num_blocks(), 1);
    }

    #[test]
    fn test_quantize_q4_0() {
        let data: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(data, &[64]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert_eq!(quantized.quant_type, QuantType::Q4_0);
        assert_eq!(quantized.num_blocks(), 2);
    }

    #[test]
    fn test_quantize_f16() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_vec(data, &[4]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::F16).unwrap();

        assert_eq!(quantized.quant_type, QuantType::F16);
    }

    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..256).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(data, &[256]).unwrap();

        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        // Q8 should give about 4x compression, Q4 about 8x
        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    #[test]
    fn test_quantization_error() {
        let original = vec![1.0, 2.0, 3.0, 4.0];
        let dequantized = vec![1.1, 2.0, 2.9, 4.1];

        let rmse = compute_quantization_error(&original, &dequantized);
        assert!(rmse > 0.0);
        assert!(rmse < 0.2);
    }
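
    // Additional coverage sketches, using only the API exercised above: Q4_1
    // gets the same shape/block assertions as the other formats, and the
    // error helper's guard paths are pinned down.
    #[test]
    fn test_quantize_q4_1() {
        let data: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(data, &[64]).unwrap();
        let quantized = quantize_tensor(&tensor, QuantType::Q4_1).unwrap();

        assert_eq!(quantized.quant_type, QuantType::Q4_1);
        assert_eq!(quantized.num_blocks(), 2);
    }

    #[test]
    fn test_quantization_error_guards() {
        // Mismatched lengths and empty inputs report infinite error
        assert!(compute_quantization_error(&[1.0], &[1.0, 2.0]).is_infinite());
        assert!(compute_quantization_error(&[], &[]).is_infinite());
        // Identical inputs have zero error
        assert_eq!(compute_quantization_error(&[1.0, 2.0], &[1.0, 2.0]), 0.0);
    }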
}