//! Dequantization Functions
//!
//! # File
//! `crates/axonml-quant/src/dequantize.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};

// =============================================================================
// Public API
// =============================================================================

27/// Dequantizes a quantized tensor back to f32.
28///
29/// # Arguments
30/// * `quantized` - The quantized tensor to dequantize
31///
32/// # Returns
33/// A tensor with f32 values
34///
35/// # Example
36/// ```ignore
37/// use axonml_quant::{quantize_tensor, dequantize_tensor, QuantType};
38///
39/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
40/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
41/// let dequantized = dequantize_tensor(&quantized)?;
42/// ```
43pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
44    let data = match quantized.quant_type {
45        QuantType::Q8_0 => dequantize_q8_0(quantized),
46        QuantType::Q4_0 => dequantize_q4_0(quantized),
47        QuantType::Q4_1 => dequantize_q4_1(quantized),
48        QuantType::Q5_0 | QuantType::Q5_1 => dequantize_q4_0(quantized), // Fallback
49        QuantType::F16 => dequantize_f16(quantized),
50        QuantType::F32 => dequantize_f32(quantized),
51    }?;
52
53    // Truncate to original size
54    let expected_size = quantized.numel;
55    let data = if data.len() > expected_size {
56        data[..expected_size].to_vec()
57    } else {
58        data
59    };
60
61    Tensor::from_vec(data, &quantized.shape)
62        .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
63}
64
65/// Dequantizes a single block.
66pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
67    match block {
68        QuantizedBlock::Q8(b) => dequantize_q8_block(b),
69        QuantizedBlock::Q4(b) => dequantize_q4_block(b),
70        QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
71        QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
72        QuantizedBlock::F32(data) => data.clone(),
73    }
74}
75
// =============================================================================
// Q8_0 Dequantization
// =============================================================================

80/// Dequantizes Q8_0 data.
81fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
82    let result: Vec<f32> = quantized
83        .blocks
84        .par_iter()
85        .flat_map(|block| {
86            if let QuantizedBlock::Q8(b) = block {
87                dequantize_q8_block(b)
88            } else {
89                vec![0.0; 32]
90            }
91        })
92        .collect();
93
94    Ok(result)
95}
96
97/// Dequantizes a single Q8 block.
98fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
99    let scale = block.scale.to_f32();
100    block.data.iter().map(|&q| q as f32 * scale).collect()
101}
102
// =============================================================================
// Q4_0 Dequantization
// =============================================================================

107/// Dequantizes Q4_0 data.
108fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
109    let result: Vec<f32> = quantized
110        .blocks
111        .par_iter()
112        .flat_map(|block| {
113            if let QuantizedBlock::Q4(b) = block {
114                dequantize_q4_block(b)
115            } else {
116                vec![0.0; 32]
117            }
118        })
119        .collect();
120
121    Ok(result)
122}
123
124/// Dequantizes a single Q4 block.
125fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
126    let scale = block.scale.to_f32();
127    let unpacked = block.unpack();
128
129    unpacked.iter().map(|&q| q as f32 * scale).collect()
130}
131
// =============================================================================
// Q4_1 Dequantization
// =============================================================================

136/// Dequantizes Q4_1 data.
137fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
138    let result: Vec<f32> = quantized
139        .blocks
140        .par_iter()
141        .flat_map(|block| {
142            if let QuantizedBlock::Q4_1(b) = block {
143                dequantize_q4_1_block(b)
144            } else {
145                vec![0.0; 32]
146            }
147        })
148        .collect();
149
150    Ok(result)
151}
152
153/// Dequantizes a single Q4_1 block.
154fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
155    let scale = block.scale.to_f32();
156    let min = block.min.to_f32();
157    let unpacked = block.unpack();
158
159    unpacked.iter().map(|&q| q as f32 * scale + min).collect()
160}
161
// =============================================================================
// F16 Dequantization
// =============================================================================

166/// Dequantizes F16 data.
167fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
168    let result: Vec<f32> = quantized
169        .blocks
170        .iter()
171        .flat_map(|block| {
172            if let QuantizedBlock::F16(data) = block {
173                data.iter().map(|x| x.to_f32()).collect()
174            } else {
175                vec![]
176            }
177        })
178        .collect();
179
180    Ok(result)
181}
182
// =============================================================================
// F32 Dequantization (passthrough)
// =============================================================================

187/// Dequantizes F32 data (passthrough).
188fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
189    let result: Vec<f32> = quantized
190        .blocks
191        .iter()
192        .flat_map(|block| {
193            if let QuantizedBlock::F32(data) = block {
194                data.clone()
195            } else {
196                vec![]
197            }
198        })
199        .collect();
200
201    Ok(result)
202}
203
// =============================================================================
// Tests
// =============================================================================

208#[cfg(test)]
209mod tests {
210    use super::*;
211    use crate::quantize::quantize_tensor;
212
213    #[test]
214    fn test_roundtrip_q8_0() {
215        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
216        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();
217
218        let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
219        let dequantized = dequantize_tensor(&quantized).unwrap();
220
221        // Check shape preserved
222        assert_eq!(dequantized.shape(), &[64]);
223
224        // Check values are close (some error expected)
225        let deq_data = dequantized.to_vec();
226        for (orig, deq) in original.iter().zip(deq_data.iter()) {
227            assert!(
228                (orig - deq).abs() < 0.1,
229                "Q8 error too large: {} vs {}",
230                orig,
231                deq
232            );
233        }
234    }
235
236    #[test]
237    fn test_roundtrip_q4_0() {
238        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
239        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();
240
241        let quantized = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();
242        let dequantized = dequantize_tensor(&quantized).unwrap();
243
244        assert_eq!(dequantized.shape(), &[64]);
245
246        // Q4 has more error but should still be reasonable
247        let deq_data = dequantized.to_vec();
248        let max_error: f32 = original
249            .iter()
250            .zip(deq_data.iter())
251            .map(|(a, b)| (a - b).abs())
252            .fold(0.0, f32::max);
253
254        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
255    }
256
257    #[test]
258    fn test_roundtrip_f16() {
259        let original = vec![1.0f32, 2.5, -3.0, 4.25];
260        let tensor = Tensor::from_vec(original.clone(), &[4]).unwrap();
261
262        let quantized = quantize_tensor(&tensor, QuantType::F16).unwrap();
263        let dequantized = dequantize_tensor(&quantized).unwrap();
264
265        let deq_data = dequantized.to_vec();
266        for (orig, deq) in original.iter().zip(deq_data.iter()) {
267            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
268        }
269    }
270
271    #[test]
272    fn test_roundtrip_f32() {
273        let original = vec![1.0f32, 2.5, -3.0, 4.25];
274        let tensor = Tensor::from_vec(original.clone(), &[4]).unwrap();
275
276        let quantized = quantize_tensor(&tensor, QuantType::F32).unwrap();
277        let dequantized = dequantize_tensor(&quantized).unwrap();
278
279        let deq_data = dequantized.to_vec();
280        assert_eq!(original, deq_data);
281    }
282}