
//! Quantization Functions
//!
//! # File
//! `crates/axonml-quant/src/quantize.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

18use axonml_tensor::Tensor;
19use half::f16;
20use rayon::prelude::*;
21
22use crate::DEFAULT_BLOCK_SIZE;
23use crate::error::QuantResult;
24use crate::types::{
25    Q4_1Block, Q4Block, Q5_1Block, Q5Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor,
26};
27
28// =============================================================================
29// Public API
30// =============================================================================
31
32/// Quantizes a tensor to the specified quantization type.
33///
34/// # Arguments
35/// * `tensor` - The input tensor to quantize
36/// * `quant_type` - The target quantization type
37///
38/// # Returns
39/// A quantized tensor
40///
41/// # Example
42/// ```ignore
43/// use axonml_quant::{quantize_tensor, QuantType};
44///
45/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
46/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
47/// ```
48pub fn quantize_tensor(
49    tensor: &Tensor<f32>,
50    quant_type: QuantType,
51) -> QuantResult<QuantizedTensor> {
52    let data = tensor.to_vec();
53    let shape = tensor.shape().to_vec();
54
55    match quant_type {
56        QuantType::Q8_0 => quantize_q8_0(&data, shape),
57        QuantType::Q4_0 => quantize_q4_0(&data, shape),
58        QuantType::Q4_1 => quantize_q4_1(&data, shape),
59        QuantType::Q5_0 => quantize_q5_0(&data, shape),
60        QuantType::Q5_1 => quantize_q5_1(&data, shape),
61        QuantType::F16 => quantize_f16(&data, shape),
62        QuantType::F32 => quantize_f32(&data, shape),
63    }
64}
65
66/// Quantizes a model (collection of named tensors).
67///
68/// # Arguments
69/// * `tensors` - Named tensors to quantize
70/// * `quant_type` - The target quantization type
71///
72/// # Returns
73/// A map of quantized tensors
74pub fn quantize_model(
75    tensors: &[(&str, &Tensor<f32>)],
76    quant_type: QuantType,
77) -> QuantResult<Vec<(String, QuantizedTensor)>> {
78    tensors
79        .par_iter()
80        .map(|(name, tensor)| {
81            let quantized = quantize_tensor(tensor, quant_type)?;
82            Ok((name.to_string(), quantized))
83        })
84        .collect()
85}
86
87// =============================================================================
88// Q8_0 Quantization
89// =============================================================================
90
91/// Quantizes data to Q8_0 format (8-bit with per-block scale).
92fn quantize_q8_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
93    let block_size = DEFAULT_BLOCK_SIZE;
94    let n_blocks = data.len().div_ceil(block_size);
95
96    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
97        .into_par_iter()
98        .map(|block_idx| {
99            let start = block_idx * block_size;
100            let end = (start + block_size).min(data.len());
101            let block_data = &data[start..end];
102
103            // Find max absolute value for scale
104            let max_abs = block_data
105                .iter()
106                .map(|x| x.abs())
107                .fold(0.0f32, |a, b| a.max(b));
108
109            // Compute scale (avoid division by zero)
110            let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
111
112            // Quantize to int8
113            let mut quantized = [0i8; 32];
114            for (i, &val) in block_data.iter().enumerate() {
115                let q = (val / scale).round().clamp(-127.0, 127.0) as i8;
116                quantized[i] = q;
117            }
118
119            QuantizedBlock::Q8(Q8Block::new(f16::from_f32(scale), quantized))
120        })
121        .collect();
122
123    Ok(QuantizedTensor::new(shape, QuantType::Q8_0, blocks))
124}
125
126// =============================================================================
127// Q4_0 Quantization
128// =============================================================================
129
130/// Quantizes data to Q4_0 format (4-bit with per-block scale).
131fn quantize_q4_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
132    let block_size = DEFAULT_BLOCK_SIZE;
133    let n_blocks = data.len().div_ceil(block_size);
134
135    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
136        .into_par_iter()
137        .map(|block_idx| {
138            let start = block_idx * block_size;
139            let end = (start + block_size).min(data.len());
140            let block_data = &data[start..end];
141
142            // Find max absolute value for scale
143            let max_abs = block_data
144                .iter()
145                .map(|x| x.abs())
146                .fold(0.0f32, |a, b| a.max(b));
147
148            // Compute scale (4-bit range is -8 to 7)
149            let scale = if max_abs > 0.0 { max_abs / 7.0 } else { 1.0 };
150
151            // Quantize to 4-bit (stored as i8 in range -8 to 7)
152            let mut quantized = [0i8; 32];
153            for (i, &val) in block_data.iter().enumerate() {
154                let q = (val / scale).round().clamp(-8.0, 7.0) as i8;
155                quantized[i] = q;
156            }
157
158            // Pack into bytes
159            let packed = Q4Block::pack(&quantized);
160
161            QuantizedBlock::Q4(Q4Block::new(f16::from_f32(scale), packed))
162        })
163        .collect();
164
165    Ok(QuantizedTensor::new(shape, QuantType::Q4_0, blocks))
166}
167
168// =============================================================================
169// Q4_1 Quantization
170// =============================================================================
171
172/// Quantizes data to Q4_1 format (4-bit with per-block scale and min).
173fn quantize_q4_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
174    let block_size = DEFAULT_BLOCK_SIZE;
175    let n_blocks = data.len().div_ceil(block_size);
176
177    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
178        .into_par_iter()
179        .map(|block_idx| {
180            let start = block_idx * block_size;
181            let end = (start + block_size).min(data.len());
182            let block_data = &data[start..end];
183
184            // Find min and max
185            let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
186            let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
187
188            // Compute scale (4-bit unsigned range is 0 to 15)
189            let scale = if max > min { (max - min) / 15.0 } else { 1.0 };
190
191            // Quantize to 4-bit unsigned
192            let mut quantized = [0u8; 32];
193            for (i, &val) in block_data.iter().enumerate() {
194                let q = ((val - min) / scale).round().clamp(0.0, 15.0) as u8;
195                quantized[i] = q;
196            }
197
198            // Pack into bytes
199            let mut packed = [0u8; 16];
200            for i in 0..16.min(block_data.len() / 2) {
201                let low = quantized[i * 2] & 0x0F;
202                let high = quantized.get(i * 2 + 1).copied().unwrap_or(0) & 0x0F;
203                packed[i] = low | (high << 4);
204            }
205
206            QuantizedBlock::Q4_1(Q4_1Block::new(
207                f16::from_f32(scale),
208                f16::from_f32(min),
209                packed,
210            ))
211        })
212        .collect();
213
214    Ok(QuantizedTensor::new(shape, QuantType::Q4_1, blocks))
215}
216
217// =============================================================================
218// Q5_0 Quantization (5-bit symmetric)
219// =============================================================================
220
221/// Quantizes data to Q5_0 format (5-bit signed with per-block scale).
222fn quantize_q5_0(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
223    let block_size = DEFAULT_BLOCK_SIZE;
224    let n_blocks = data.len().div_ceil(block_size);
225
226    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
227        .into_par_iter()
228        .map(|block_idx| {
229            let start = block_idx * block_size;
230            let end = (start + block_size).min(data.len());
231            let block_data = &data[start..end];
232
233            let max_abs = block_data
234                .iter()
235                .map(|x| x.abs())
236                .fold(0.0f32, |a, b| a.max(b));
237
238            // 5-bit signed range: -16 to 15
239            let scale = if max_abs > 0.0 { max_abs / 15.0 } else { 1.0 };
240
241            let mut quantized = [0i8; 32];
242            for (i, &val) in block_data.iter().enumerate() {
243                let q = (val / scale).round().clamp(-16.0, 15.0) as i8;
244                quantized[i] = q;
245            }
246
247            let packed = Q5Block::pack(&quantized);
248            QuantizedBlock::Q5(Q5Block::new(f16::from_f32(scale), packed))
249        })
250        .collect();
251
252    Ok(QuantizedTensor::new(shape, QuantType::Q5_0, blocks))
253}
254
255// =============================================================================
256// Q5_1 Quantization (5-bit asymmetric)
257// =============================================================================
258
259/// Quantizes data to Q5_1 format (5-bit unsigned with per-block scale and min).
260fn quantize_q5_1(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
261    let block_size = DEFAULT_BLOCK_SIZE;
262    let n_blocks = data.len().div_ceil(block_size);
263
264    let blocks: Vec<QuantizedBlock> = (0..n_blocks)
265        .into_par_iter()
266        .map(|block_idx| {
267            let start = block_idx * block_size;
268            let end = (start + block_size).min(data.len());
269            let block_data = &data[start..end];
270
271            let min = block_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
272            let max = block_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
273
274            // 5-bit unsigned range: 0 to 31
275            let scale = if max > min { (max - min) / 31.0 } else { 1.0 };
276
277            let mut quantized = [0u8; 32];
278            for (i, &val) in block_data.iter().enumerate() {
279                let q = ((val - min) / scale).round().clamp(0.0, 31.0) as u8;
280                quantized[i] = q;
281            }
282
283            let packed = Q5_1Block::pack(&quantized);
284            QuantizedBlock::Q5_1(Q5_1Block::new(
285                f16::from_f32(scale),
286                f16::from_f32(min),
287                packed,
288            ))
289        })
290        .collect();
291
292    Ok(QuantizedTensor::new(shape, QuantType::Q5_1, blocks))
293}
294
295// =============================================================================
296// F16 Quantization
297// =============================================================================
298
299/// Quantizes data to F16 format (half precision).
300fn quantize_f16(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
301    let f16_data: Vec<f16> = data.par_iter().map(|&x| f16::from_f32(x)).collect();
302
303    let blocks = vec![QuantizedBlock::F16(f16_data)];
304
305    Ok(QuantizedTensor::new(shape, QuantType::F16, blocks))
306}
307
308// =============================================================================
309// F32 (No Quantization)
310// =============================================================================
311
312/// Stores data as F32 (no quantization, for comparison).
313fn quantize_f32(data: &[f32], shape: Vec<usize>) -> QuantResult<QuantizedTensor> {
314    let blocks = vec![QuantizedBlock::F32(data.to_vec())];
315    Ok(QuantizedTensor::new(shape, QuantType::F32, blocks))
316}
317
318// =============================================================================
319// Utility Functions
320// =============================================================================
321
/// Computes the quantization error (RMSE) between original and dequantized
/// values.
///
/// Returns `f32::INFINITY` when the slices differ in length or are empty,
/// since no meaningful error can be computed.
pub fn compute_quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
    if original.is_empty() || original.len() != dequantized.len() {
        return f32::INFINITY;
    }

    let sq_sum: f32 = original
        .iter()
        .zip(dequantized)
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum();

    (sq_sum / original.len() as f32).sqrt()
}
337
/// Statistics describing quantization error and storage savings.
// Derives added: a public data-carrier type should be Debug-printable and
// cheaply copyable (all fields are f32). PartialEq eases testing.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantizationStats {
    /// Root mean square error.
    pub rmse: f32,
    /// Maximum absolute error.
    pub max_error: f32,
    /// Mean absolute error.
    pub mean_error: f32,
    /// Compression ratio.
    pub compression_ratio: f32,
}
349
350/// Computes detailed quantization statistics.
351pub fn compute_quantization_stats(
352    original: &[f32],
353    dequantized: &[f32],
354    quant_type: QuantType,
355) -> QuantizationStats {
356    let errors: Vec<f32> = original
357        .iter()
358        .zip(dequantized.iter())
359        .map(|(a, b)| (a - b).abs())
360        .collect();
361
362    let mse: f32 = errors.iter().map(|e| e.powi(2)).sum::<f32>() / errors.len() as f32;
363    let max_error = errors.iter().fold(0.0f32, |a, &b| a.max(b));
364    let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
365
366    QuantizationStats {
367        rmse: mse.sqrt(),
368        max_error,
369        mean_error,
370        compression_ratio: quant_type.compression_ratio(),
371    }
372}
373
374// =============================================================================
375// Tests
376// =============================================================================
377
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_q8_0() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(values, &[8]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();

        assert_eq!(q.quant_type, QuantType::Q8_0);
        assert_eq!(q.shape, vec![8]);
        assert_eq!(q.num_blocks(), 1);
    }

    #[test]
    fn test_quantize_q4_0() {
        let values: Vec<f32> = (0..64).map(|i| i as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(values, &[64]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        assert_eq!(q.quant_type, QuantType::Q4_0);
        assert_eq!(q.num_blocks(), 2);
    }

    #[test]
    fn test_quantize_f16() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let q = quantize_tensor(&tensor, QuantType::F16).unwrap();

        assert_eq!(q.quant_type, QuantType::F16);
    }

    #[test]
    fn test_compression_ratio() {
        let values: Vec<f32> = (0..256).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(values, &[256]).unwrap();

        let q8 = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let q4 = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();

        // Q8 should be about 4x compression, Q4 about 8x.
        assert!(q8.compression_ratio() > 2.0);
        assert!(q4.compression_ratio() > q8.compression_ratio());
    }

    #[test]
    fn test_quantization_error() {
        let original = [1.0, 2.0, 3.0, 4.0];
        let dequantized = [1.1, 2.0, 2.9, 4.1];

        let rmse = compute_quantization_error(&original, &dequantized);
        assert!(rmse > 0.0 && rmse < 0.2);
    }
}