// axonml_quant/quantize.rs
1//! Quantization Functions
2//!
3//! # File
4//! `crates/axonml-quant/src/quantize.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use axonml_tensor::Tensor;
18use half::f16;
19use rayon::prelude::*;
20
21use crate::DEFAULT_BLOCK_SIZE;
22use crate::error::QuantResult;
23use crate::types::{
24    Q4_1Block, Q4Block, Q5_1Block, Q5Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor,
25};
26
27// =============================================================================
28// Public API
29// =============================================================================
30
31/// Quantizes a tensor to the specified quantization type.
32///
33/// # Arguments
34/// * `tensor` - The input tensor to quantize
35/// * `quant_type` - The target quantization type
36///
37/// # Returns
38/// A quantized tensor
39///
40/// # Example
41/// ```ignore
42/// use axonml_quant::{quantize_tensor, QuantType};
43///
44/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
45/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
46/// ```
47pub fn quantize_tensor(
48    tensor: &Tensor<f32>,
49    quant_type: QuantType,
50) -> QuantResult<QuantizedTensor> {
51    let data = tensor.to_vec();
52    let shape = tensor.shape().to_vec();
53
54    match quant_type {
55        QuantType::Q8_0 => quantize_q8_0(&data, shape),
56        QuantType::Q4_0 => quantize_q4_0(&data, shape),
57        QuantType::Q4_1 => quantize_q4_1(&data, shape),
58        QuantType::Q5_0 => quantize_q5_0(&data, shape),
59        QuantType::Q5_1 => quantize_q5_1(&data, shape),
60        QuantType::F16 => quantize_f16(&data, shape),
61        QuantType::F32 => quantize_f32(&data, shape),
62    }
63}
64
65/// Quantizes a model (collection of named tensors).
66///
67/// # Arguments
68/// * `tensors` - Named tensors to quantize
69/// * `quant_type` - The target quantization type
70///
71/// # Returns
72/// A map of quantized tensors
73pub fn quantize_model(
74    tensors: &[(&str, &Tensor<f32>)],
75    quant_type: QuantType,
76) -> QuantResult<Vec<(String, QuantizedTensor)>> {
77    tensors
78        .par_iter()
79        .map(|(name, tensor)| {
80            let quantized = quantize_tensor(tensor, quant_type)?;
81            Ok((name.to_string(), quantized))
82        })
83        .collect()
84}
85
86// =============================================================================
87// Q8_0 Quantization
88// =============================================================================
89
90/// Quantizes data to Q8_0 format (8-bit with per-block scale).
91fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
92    let block_size = DEFAULT_BLOCK_SIZE;
93    let n_blocks = data.len().div_ceil(block_size);
94
95    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
96        .into_par_iter()
97        .map(|block_idx| {
98            let start = block_idx * block_size;
99            let end = (start + block_size).min(data.len());
100            let block_data = &data[start..end];
101
102            // Find max absolute value for scale
103            let max_abs = block_data
104                .iter()
105                .map(|x| x.abs())
106                .fold(0.0f32, |a, b| a.max(b));
107
108            // Compute scale (avoid division by zero)
109            let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
110
111            // Quantize to int8
112            let mut quantized = [0i8; 32];
113            for (i, &val) in block_data.iter().enumerate() {
114                let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
115                quantized[i] = q;
116            }
117
118            QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
119        })
120        .collect();
121
122    Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
123}
124
125// =============================================================================
126// Q4_0 Quantization
127// =============================================================================
128
129/// Quantizes data to Q4_0 format (4-bit with per-block scale).
130fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
131    let block_size = DEFAULT_BLOCK_SIZE;
132    let n_blocks = data.len().div_ceil(block_size);
133
134    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
135        .into_par_iter()
136        .map(|block_idx| {
137            let start = block_idx * block_size;
138            let end = (start + block_size).min(data.len());
139            let block_data = &data[start..end];
140
141            // Find max absolute value for scale
142            let max_abs = block_data
143                .iter()
144                .map(|x| x.abs())
145                .fold(0.0f32, |a, b| a.max(b));
146
147            // Compute scale (4-bit range is -8 to 7)
148            let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1.0 };
149
150            // Quantize to 4-bit (stored as i8 in range -8 to 7)
151            let mut quantized = [0i8; 32];
152            for (i, &val) in block_data.iter().enumerate() {
153                let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
154                quantized[i] = q;
155            }
156
157            // Pack into bytes
158            let packed = Q4Block::pack(&quantized);
159
160            QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
161        })
162        .collect();
163
164    Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
165}
166
167// =============================================================================
168// Q4_1 Quantization
169// =============================================================================
170
171/// Quantizes data to Q4_1 format (4-bit with per-block scale and min).
172fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
173    let block_size = DEFAULT_BLOCK_SIZE;
174    let n_blocks = data.len().div_ceil(block_size);
175
176    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
177        .into_par_iter()
178        .map(|block_idx| {
179            let start = block_idx * block_size;
180            let end = (start + block_size).min(data.len());
181            let block_data = &data[start..end];
182
183            // Find min and max
184            let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
185            let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
186
187            // Compute scale (4-bit unsigned range is 0 to 15)
188            let scale = if max > min { (max - min) / 15.0 } else { 1.0 };
189
190            // Quantize to 4-bit unsigned
191            let mut quantized = [0u8; 32];
192            for (i, &val) in block_data.iter().enumerate() {
193                let q = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
194                quantized[i] = q;
195            }
196
197            // Pack into bytes
198            let mut packed = [0u8; 16];
199            for i in 0..16.min(block_data.len() / 2) {
200                let low = quantized[i * 2] & 0x0F;
201                let high = quantized.get(i * 2 + 1).copied().unwrap_or(0) & 0x0F;
202                packed[i] = low | (high << 4);
203            }
204
205            QuantizedBlock::Q4_1(Q4_1Block::new(
206                f16::from_f32(scale),
207                f16::from_f32(min),
208                packed,
209            ))
210        })
211        .collect();
212
213    Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
214}
215
216// =============================================================================
217// Q5_0 Quantization (5-bit symmetric)
218// =============================================================================
219
220/// Quantizes data to Q5_0 format (5-bit signed with per-block scale).
221fn quantize_q5_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
222    let block_size = DEFAULT_BLOCK_SIZE;
223    let n_blocks = data.len().div_ceil(block_size);
224
225    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
226        .into_par_iter()
227        .map(|block_idx| {
228            let start = block_idx * block_size;
229            let end = (start + block_size).min(data.len());
230            let block_data = &data[start..end];
231
232            let max_abs = block_data
233                .iter()
234                .map(|x| x.abs())
235                .fold(0.0f32, |a, b| a.max(b));
236
237            // 5-bit signed range: -16 to 15
238            let scale = if max_abs > 0.0 { max_abs / 15.0 } else { 1.0 };
239
240            let mut quantized = [0i8; 32];
241            for (i, &val) in block_data.iter().enumerate() {
242                let q = (val / scale).round().clamp(-16.0, 15.0) as i8;
243                quantized[i] = q;
244            }
245
246            let packed = Q5Block::pack(&quantized);
247            QuantizedBlock::Q5(Q5Block::new(f16::from_f32(scale), packed))
248        })
249        .collect();
250
251    Ok(QuantizedTensor::new(shape, QuantType::Q5_0, blocks))
252}
253
254// =============================================================================
255// Q5_1 Quantization (5-bit asymmetric)
256// =============================================================================
257
258/// Quantizes data to Q5_1 format (5-bit unsigned with per-block scale and min).
259fn quantize_q5_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
260    let block_size = DEFAULT_BLOCK_SIZE;
261    let n_blocks = data.len().div_ceil(block_size);
262
263    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
264        .into_par_iter()
265        .map(|block_idx| {
266            let start = block_idx * block_size;
267            let end = (start + block_size).min(data.len());
268            let block_data = &data[start..end];
269
270            let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
271            let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
272
273            // 5-bit unsigned range: 0 to 31
274            let scale = if max > min { (max - min) / 31.0 } else { 1.0 };
275
276            let mut quantized = [0u8; 32];
277            for (i, &val) in block_data.iter().enumerate() {
278                let q = ((val - min) / scale).round().clamp(0.0, 31.0) as u8;
279                quantized[i] = q;
280            }
281
282            let packed = Q5_1Block::pack(&quantized);
283            QuantizedBlock::Q5_1(Q5_1Block::new(
284                f16::from_f32(scale),
285                f16::from_f32(min),
286                packed,
287            ))
288        })
289        .collect();
290
291    Ok(QuantizedTensor::new(shape, QuantType::Q5_1, blocks))
292}
293
294// =============================================================================
295// F16 Quantization
296// =============================================================================
297
298/// Quantizes data to F16 format (half precision).
299fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
300    let f16_data: Vec<f16> = data.par_iter().map(|&x| f16::from_f32(x)).collect();
301
302    let blocks = vec![QuantizedBlock::F16(f16_data)];
303
304    Ok(QuantizedTensor::new(shape, QuantType::F16, blocks))
305}
306
307// =============================================================================
308// F32 (No Quantization)
309// =============================================================================
310
311/// Stores data as F32 (no quantization, for comparison).
312fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
313    let blocks = vec![QuantizedBlock::F32(data.to_vec())];
314    Ok(QuantizedTensor::new(shape, QuantType::F32, blocks))
315}
316
317// =============================================================================
318// Utility Functions
319// =============================================================================
320
/// Computes the quantization error (RMSE) between original and quantized.
///
/// Returns `f32::INFINITY` when the slices are empty or their lengths
/// differ, since no meaningful error exists in either case.
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.is_empty() || original.len() != dequantized.len() {
        return f32::INFINITY;
    }

    // Sum of squared differences, then normalize and take the root.
    let sum_sq: f32 = original
        .iter()
        .zip(dequantized)
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum();

    (sum_sq / original.len() as f32).sqrt()
}
336
/// Statistics describing how well a quantization round trip preserved data.
///
/// All fields are plain `f32`, so the type is cheap to copy; `Debug`,
/// `Clone`, `Copy`, and `PartialEq` are derived so callers can log and
/// compare results (public types should be at least `Debug`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantizationStats {
    /// Root mean square error.
    pub rmse: f32,
    /// Maximum absolute error.
    pub max_error: f32,
    /// Mean absolute error.
    pub mean_error: f32,
    /// Compression ratio.
    pub compression_ratio: f32,
}
348
349/// Computes detailed quantization statistics.
350pub fn compute_quantization_stats(
351    original: &[f32],
352    dequantized: &[f32],
353    quant_type: QuantType,
354) -> QuantizationStats {
355    let errors: Vec<f32> = original
356        .iter()
357        .zip(dequantized.iter())
358        .map(|(a, b)| (a - b).abs())
359        .collect();
360
361    let mse: f32 = errors.iter().map(|e| e.powi(2)).sum::<f32>() / errors.len() as f32;
362    let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
363    let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
364
365    QuantizationStats {
366        rmse: mse.sqrt(),
367        max_error,
368        mean_error,
369        compression_ratio: quant_type.compression_ratio(),
370    }
371}
372
373// =============================================================================
374// Tests
375// =============================================================================
376
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_q8_0() {
        // Eight values fit inside a single default-size block.
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(values, &[8]).unwrap();
        let out = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();

        assert_eq!(out.quant_type, QuantType::Q8_0);
        assert_eq!(out.shape, vec![8]);
        assert_eq!(out.num_blocks(), 1);
    }

    #[test]
    fn test_quantize_q4_0() {
        // 64 values span exactly two default-size blocks.
        let values: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(values, &[64]).unwrap();
        let out = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert_eq!(out.quant_type, QuantType::Q4_0);
        assert_eq!(out.num_blocks(), 2);
    }

    #[test]
    fn test_quantize_f16() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let out = quantize_tensor(&tensor, QuantType::F16).unwrap();

        assert_eq!(out.quant_type, QuantType::F16);
    }

    #[test]
    fn test_compression_ratio() {
        let values: Vec<f32> = (0..256).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(values, &[256]).unwrap();

        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        // Q8 should be about 4x compression, Q4 about 8x.
        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    #[test]
    fn test_quantization_error() {
        let reference = vec![1.0, 2.0, 3.0, 4.0];
        let roundtrip = vec![1.1, 2.0, 2.9, 4.1];

        let rmse = compute_quantization_error(&reference, &roundtrip);
        assert!(rmse > 0.0);
        assert!(rmse < 0.2);
    }
}