//! Quantization Functions
//!
//! # File
//! `crates/axonml-quant/src/quantize.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

17use axonml_tensor::Tensor;
18use half::f16;
19use rayon::prelude::*;
20
21use crate::DEFAULT_BLOCK_SIZE;
22use crate::error::QuantResult;
23use crate::types::{Q4_1Block, Q4Block, Q5Block, Q5_1Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
24
25// =============================================================================
26// Public API
27// =============================================================================
28
29/// Quantizes a tensor to the specified quantization type.
30///
31/// # Arguments
32/// * `tensor` - The input tensor to quantize
33/// * `quant_type` - The target quantization type
34///
35/// # Returns
36/// A quantized tensor
37///
38/// # Example
39/// ```ignore
40/// use axonml_quant::{quantize_tensor, QuantType};
41///
42/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
43/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
44/// ```
45pub fn quantize_tensor(
46    tensor: &Tensor<f32>,
47    quant_type: QuantType,
48) -> QuantResult<QuantizedTensor> {
49    let data = tensor.to_vec();
50    let shape = tensor.shape().to_vec();
51
52    match quant_type {
53        QuantType::Q8_0 => quantize_q8_0(&data, shape),
54        QuantType::Q4_0 => quantize_q4_0(&data, shape),
55        QuantType::Q4_1 => quantize_q4_1(&data, shape),
56        QuantType::Q5_0 => quantize_q5_0(&data, shape),
57        QuantType::Q5_1 => quantize_q5_1(&data, shape),
58        QuantType::F16 => quantize_f16(&data, shape),
59        QuantType::F32 => quantize_f32(&data, shape),
60    }
61}
62
63/// Quantizes a model (collection of named tensors).
64///
65/// # Arguments
66/// * `tensors` - Named tensors to quantize
67/// * `quant_type` - The target quantization type
68///
69/// # Returns
70/// A map of quantized tensors
71pub fn quantize_model(
72    tensors: &[(&str, &Tensor<f32>)],
73    quant_type: QuantType,
74) -> QuantResult<Vec<(String, QuantizedTensor)>> {
75    tensors
76        .par_iter()
77        .map(|(name, tensor)| {
78            let quantized = quantize_tensor(tensor, quant_type)?;
79            Ok((name.to_string(), quantized))
80        })
81        .collect()
82}
83
84// =============================================================================
85// Q8_0 Quantization
86// =============================================================================
87
88/// Quantizes data to Q8_0 format (8-bit with per-block scale).
89fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
90    let block_size = DEFAULT_BLOCK_SIZE;
91    let n_blocks = data.len().div_ceil(block_size);
92
93    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
94        .into_par_iter()
95        .map(|block_idx| {
96            let start = block_idx * block_size;
97            let end = (start + block_size).min(data.len());
98            let block_data = &data[start..end];
99
100            // Find max absolute value for scale
101            let max_abs = block_data
102                .iter()
103                .map(|x| x.abs())
104                .fold(0.0f32, |a, b| a.max(b));
105
106            // Compute scale (avoid division by zero)
107            let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
108
109            // Quantize to int8
110            let mut quantized = [0i8; 32];
111            for (i, &val) in block_data.iter().enumerate() {
112                let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
113                quantized[i] = q;
114            }
115
116            QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
117        })
118        .collect();
119
120    Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
121}
122
123// =============================================================================
124// Q4_0 Quantization
125// =============================================================================
126
127/// Quantizes data to Q4_0 format (4-bit with per-block scale).
128fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
129    let block_size = DEFAULT_BLOCK_SIZE;
130    let n_blocks = data.len().div_ceil(block_size);
131
132    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
133        .into_par_iter()
134        .map(|block_idx| {
135            let start = block_idx * block_size;
136            let end = (start + block_size).min(data.len());
137            let block_data = &data[start..end];
138
139            // Find max absolute value for scale
140            let max_abs = block_data
141                .iter()
142                .map(|x| x.abs())
143                .fold(0.0f32, |a, b| a.max(b));
144
145            // Compute scale (4-bit range is -8 to 7)
146            let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1.0 };
147
148            // Quantize to 4-bit (stored as i8 in range -8 to 7)
149            let mut quantized = [0i8; 32];
150            for (i, &val) in block_data.iter().enumerate() {
151                let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
152                quantized[i] = q;
153            }
154
155            // Pack into bytes
156            let packed = Q4Block::pack(&quantized);
157
158            QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
159        })
160        .collect();
161
162    Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
163}
164
165// =============================================================================
166// Q4_1 Quantization
167// =============================================================================
168
169/// Quantizes data to Q4_1 format (4-bit with per-block scale and min).
170fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
171    let block_size = DEFAULT_BLOCK_SIZE;
172    let n_blocks = data.len().div_ceil(block_size);
173
174    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
175        .into_par_iter()
176        .map(|block_idx| {
177            let start = block_idx * block_size;
178            let end = (start + block_size).min(data.len());
179            let block_data = &data[start..end];
180
181            // Find min and max
182            let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
183            let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
184
185            // Compute scale (4-bit unsigned range is 0 to 15)
186            let scale = if max > min { (max - min) / 15.0 } else { 1.0 };
187
188            // Quantize to 4-bit unsigned
189            let mut quantized = [0u8; 32];
190            for (i, &val) in block_data.iter().enumerate() {
191                let q = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
192                quantized[i] = q;
193            }
194
195            // Pack into bytes
196            let mut packed = [0u8; 16];
197            for i in 0..16.min(block_data.len() / 2) {
198                let low = quantized[i * 2] & 0x0F;
199                let high = quantized.get(i * 2 + 1).copied().unwrap_or(0) & 0x0F;
200                packed[i] = low | (high << 4);
201            }
202
203            QuantizedBlock::Q4_1(Q4_1Block::new(
204                f16::from_f32(scale),
205                f16::from_f32(min),
206                packed,
207            ))
208        })
209        .collect();
210
211    Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
212}
213
214// =============================================================================
215// Q5_0 Quantization (5-bit symmetric)
216// =============================================================================
217
218/// Quantizes data to Q5_0 format (5-bit signed with per-block scale).
219fn quantize_q5_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
220    let block_size = DEFAULT_BLOCK_SIZE;
221    let n_blocks = data.len().div_ceil(block_size);
222
223    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
224        .into_par_iter()
225        .map(|block_idx| {
226            let start = block_idx * block_size;
227            let end = (start + block_size).min(data.len());
228            let block_data = &data[start..end];
229
230            let max_abs = block_data
231                .iter()
232                .map(|x| x.abs())
233                .fold(0.0f32, |a, b| a.max(b));
234
235            // 5-bit signed range: -16 to 15
236            let scale = if max_abs > 0.0 { max_abs / 15.0 } else { 1.0 };
237
238            let mut quantized = [0i8; 32];
239            for (i, &val) in block_data.iter().enumerate() {
240                let q = (val / scale).round().clamp(-16.0, 15.0) as i8;
241                quantized[i] = q;
242            }
243
244            let packed = Q5Block::pack(&quantized);
245            QuantizedBlock::Q5(Q5Block::new(f16::from_f32(scale), packed))
246        })
247        .collect();
248
249    Ok(QuantizedTensor::new(shape, QuantType::Q5_0, blocks))
250}
251
252// =============================================================================
253// Q5_1 Quantization (5-bit asymmetric)
254// =============================================================================
255
256/// Quantizes data to Q5_1 format (5-bit unsigned with per-block scale and min).
257fn quantize_q5_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
258    let block_size = DEFAULT_BLOCK_SIZE;
259    let n_blocks = data.len().div_ceil(block_size);
260
261    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
262        .into_par_iter()
263        .map(|block_idx| {
264            let start = block_idx * block_size;
265            let end = (start + block_size).min(data.len());
266            let block_data = &data[start..end];
267
268            let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
269            let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270
271            // 5-bit unsigned range: 0 to 31
272            let scale = if max > min { (max - min) / 31.0 } else { 1.0 };
273
274            let mut quantized = [0u8; 32];
275            for (i, &val) in block_data.iter().enumerate() {
276                let q = ((val - min) / scale).round().clamp(0.0, 31.0) as u8;
277                quantized[i] = q;
278            }
279
280            let packed = Q5_1Block::pack(&quantized);
281            QuantizedBlock::Q5_1(Q5_1Block::new(f16::from_f32(scale), f16::from_f32(min), packed))
282        })
283        .collect();
284
285    Ok(QuantizedTensor::new(shape, QuantType::Q5_1, blocks))
286}
287
288// =============================================================================
289// F16 Quantization
290// =============================================================================
291
292/// Quantizes data to F16 format (half precision).
293fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
294    let f16_data: Vec<f16> = data.par_iter().map(|&x| f16::from_f32(x)).collect();
295
296    let blocks = vec![QuantizedBlock::F16(f16_data)];
297
298    Ok(QuantizedTensor::new(shape, QuantType::F16, blocks))
299}
300
301// =============================================================================
302// F32 (No Quantization)
303// =============================================================================
304
305/// Stores data as F32 (no quantization, for comparison).
306fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
307    let blocks = vec![QuantizedBlock::F32(data.to_vec())];
308    Ok(QuantizedTensor::new(shape, QuantType::F32, blocks))
309}
310
311// =============================================================================
312// Utility Functions
313// =============================================================================
314
/// Computes the quantization error (RMSE) between original and quantized.
///
/// Returns `f32::INFINITY` when the slices are empty or differ in length,
/// since no meaningful error can be computed.
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.is_empty() || original.len() != dequantized.len() {
        return f32::INFINITY;
    }

    let sum_sq: f32 = original
        .iter()
        .zip(dequantized)
        .map(|(a, b)| {
            let diff = a - b;
            diff * diff
        })
        .sum();

    (sum_sq / original.len() as f32).sqrt()
}
330
/// Statistics describing the error introduced by quantization.
///
/// All fields are plain `f32`, so the type is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantizationStats {
    /// Root mean square error.
    pub rmse: f32,
    /// Maximum absolute error.
    pub max_error: f32,
    /// Mean absolute error.
    pub mean_error: f32,
    /// Compression ratio.
    pub compression_ratio: f32,
}
342
343/// Computes detailed quantization statistics.
344pub fn compute_quantization_stats(
345    original: &[f32],
346    dequantized: &[f32],
347    quant_type: QuantType,
348) -> QuantizationStats {
349    let errors: Vec<f32> = original
350        .iter()
351        .zip(dequantized.iter())
352        .map(|(a, b)| (a - b).abs())
353        .collect();
354
355    let mse: f32 = errors.iter().map(|e| e.powi(2)).sum::<f32>() / errors.len() as f32;
356    let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
357    let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
358
359    QuantizationStats {
360        rmse: mse.sqrt(),
361        max_error,
362        mean_error,
363        compression_ratio: quant_type.compression_ratio(),
364    }
365}
366
367// =============================================================================
368// Tests
369// =============================================================================
370
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_q8_0() {
        // Eight values fit in a single 32-element block.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(values, &[8]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();

        assert_eq!(q.quant_type, QuantType::Q8_0);
        assert_eq!(q.shape, vec![8]);
        assert_eq!(q.num_blocks(), 1);
    }

    #[test]
    fn test_quantize_q4_0() {
        // 64 values => two blocks.
        let values: Vec<f32> = (0..64).map(|i| i as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(values, &[64]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert_eq!(q.quant_type, QuantType::Q4_0);
        assert_eq!(q.num_blocks(), 2);
    }

    #[test]
    fn test_quantize_f16() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::F16).unwrap();

        assert_eq!(q.quant_type, QuantType::F16);
    }

    #[test]
    fn test_compression_ratio() {
        let values: Vec<f32> = (0..256).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(values, &[256]).unwrap();

        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        // Q8 should be about 4x compression, Q4 about 8x.
        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    #[test]
    fn test_quantization_error() {
        let rmse =
            compute_quantization_error(&[1.0, 2.0, 3.0, 4.0], &[1.1, 2.0, 2.9, 4.1]);

        assert!(rmse > 0.0);
        assert!(rmse < 0.2);
    }
}