//! Dequantization Functions
//!
//! # File
//! `crates/axonml-quant/src/dequantize.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};

// =============================================================================
// Public API
// =============================================================================

/// Dequantizes a quantized tensor back to f32.
///
/// # Arguments
/// * `quantized` - The quantized tensor to dequantize
///
/// # Returns
/// A `Tensor<f32>` with the reconstructed values, truncated to the original
/// element count and reshaped to the original shape.
///
/// # Errors
/// Returns an error if the dequantized values cannot be converted back into a
/// tensor of the original shape.
///
/// # Example
/// ```ignore
/// use axonml_quant::{quantize_tensor, dequantize_tensor, QuantType};
/// use axonml_tensor::Tensor;
///
/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
/// let dequantized = dequantize_tensor(&quantized)?;
/// ```
pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
    let data = match quantized.quant_type {
        QuantType::Q8_0 => dequantize_q8_0(quantized),
        QuantType::Q4_0 => dequantize_q4_0(quantized),
        QuantType::Q4_1 => dequantize_q4_1(quantized),
        QuantType::Q5_0 => dequantize_q5_0(quantized),
        QuantType::Q5_1 => dequantize_q5_1(quantized),
        QuantType::F16 => dequantize_f16(quantized),
        QuantType::F32 => dequantize_f32(quantized),
    }?;

    // Drop any block padding so the output matches the original element count.
    let expected_size = quantized.numel;
    let data = if data.len() > expected_size {
        data[..expected_size].to_vec()
    } else {
        data
    };

    Tensor::from_vec(data, &quantized.shape)
        .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
}

/// Dequantizes a single block.
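///
/// # Example
/// A sketch (assuming `dequantize_block` is re-exported like `dequantize_tensor`,
/// that the tensor's `blocks` field is accessible to calling code, and Q8_0's
/// 32-element block size):
/// ```ignore
/// use axonml_quant::{quantize_tensor, dequantize_block, QuantType};
/// use axonml_tensor::Tensor;
///
/// let tensor = Tensor::from_vec(vec![1.0f32; 32], &[32])?;
/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
/// // Dequantize just the first block.
/// let values = dequantize_block(&quantized.blocks[0]);
/// assert_eq!(values.len(), 32);
/// ```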
pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
    match block {
        QuantizedBlock::Q8(b) => dequantize_q8_block(b),
        QuantizedBlock::Q4(b) => dequantize_q4_block(b),
        QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
        QuantizedBlock::Q5(b) => {
            let scale = b.scale.to_f32();
            b.unpack().iter().map(|&v| v as f32 * scale).collect()
        }
        QuantizedBlock::Q5_1(b) => {
            let scale = b.scale.to_f32();
            let min = b.min.to_f32();
            b.unpack().iter().map(|&v| v as f32 * scale + min).collect()
        }
        QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
        QuantizedBlock::F32(data) => data.clone(),
    }
}

// =============================================================================
// Q8_0 Dequantization
// =============================================================================

/// Dequantizes Q8_0 data.
fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .par_iter()
        .flat_map(|block| {
            if let QuantizedBlock::Q8(b) = block {
                dequantize_q8_block(b)
            } else {
                vec![0.0; 32]
            }
        })
        .collect();

    Ok(result)
}

/// Dequantizes a single Q8 block.
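///
/// Reconstruction is `value = q as f32 * scale`, where `scale` is the block's
/// f16 scale (e.g. with `scale = 0.5`, a stored `q = 4` becomes `2.0`).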
fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
    let scale = block.scale.to_f32();
    block.data.iter().map(|&q| q as f32 * scale).collect()
}

// =============================================================================
// Q4_0 Dequantization
// =============================================================================

/// Dequantizes Q4_0 data.
fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .par_iter()
        .flat_map(|block| {
            if let QuantizedBlock::Q4(b) = block {
                dequantize_q4_block(b)
            } else {
                vec![0.0; 32]
            }
        })
        .collect();

    Ok(result)
}

/// Dequantizes a single Q4 block.
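///
/// The packed 4-bit values are expanded via `unpack()` and each is
/// reconstructed as `value = q as f32 * scale`.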
fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
    let scale = block.scale.to_f32();
    let unpacked = block.unpack();

    unpacked.iter().map(|&q| q as f32 * scale).collect()
}

// =============================================================================
// Q4_1 Dequantization
// =============================================================================

/// Dequantizes Q4_1 data.
fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .par_iter()
        .flat_map(|block| {
            if let QuantizedBlock::Q4_1(b) = block {
                dequantize_q4_1_block(b)
            } else {
                vec![0.0; 32]
            }
        })
        .collect();

    Ok(result)
}

/// Dequantizes a single Q4_1 block.
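///
/// Q4_1 stores a per-block `min` in addition to the scale, so reconstruction
/// is `value = q as f32 * scale + min` (e.g. `scale = 0.1`, `min = -1.0`,
/// `q = 5` gives `-0.5`).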
fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
    let scale = block.scale.to_f32();
    let min = block.min.to_f32();
    let unpacked = block.unpack();

    unpacked.iter().map(|&q| q as f32 * scale + min).collect()
}

// =============================================================================
// Q5_0 Dequantization
// =============================================================================

/// Dequantizes Q5_0 data (5-bit signed with per-block scale).
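///
/// Each unpacked signed 5-bit value is reconstructed as
/// `value = q as f32 * scale`.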
fn dequantize_q5_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let mut result = Vec::new();
    for block in &quantized.blocks {
        if let QuantizedBlock::Q5(q5) = block {
            let scale = q5.scale.to_f32();
            let values = q5.unpack();
            for &v in &values {
                result.push(v as f32 * scale);
            }
        }
    }
    Ok(result)
}

// =============================================================================
// Q5_1 Dequantization
// =============================================================================

/// Dequantizes Q5_1 data (5-bit unsigned with per-block scale and min).
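///
/// Each unpacked unsigned 5-bit value is reconstructed as
/// `value = q as f32 * scale + min`.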
fn dequantize_q5_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let mut result = Vec::new();
    for block in &quantized.blocks {
        if let QuantizedBlock::Q5_1(q5) = block {
            let scale = q5.scale.to_f32();
            let min = q5.min.to_f32();
            let values = q5.unpack();
            for &v in &values {
                result.push(v as f32 * scale + min);
            }
        }
    }
    Ok(result)
}

// =============================================================================
// F16 Dequantization
// =============================================================================

/// Dequantizes F16 data.
fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .iter()
        .flat_map(|block| {
            if let QuantizedBlock::F16(data) = block {
                data.iter().map(|x| x.to_f32()).collect()
            } else {
                vec![]
            }
        })
        .collect();

    Ok(result)
}

// =============================================================================
// F32 Dequantization (passthrough)
// =============================================================================

/// Dequantizes F32 data (passthrough).
fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .iter()
        .flat_map(|block| {
            if let QuantizedBlock::F32(data) = block {
                data.clone()
            } else {
                vec![]
            }
        })
        .collect();

    Ok(result)
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    #[test]
    fn test_roundtrip_q8_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        // Check shape preserved
        assert_eq!(dequantized.shape(), &[64]);

        // Check values are close (some error expected)
        let deq_data = dequantized.to_vec();
        for (orig, deq) in original.iter().zip(deq_data.iter()) {
            assert!(
                (orig - deq).abs() < 0.1,
                "Q8 error too large: {} vs {}",
                orig,
                deq
            );
        }
    }

    #[test]
    fn test_roundtrip_q4_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        assert_eq!(dequantized.shape(), &[64]);

        // Q4 has more error but should still be reasonable
        let deq_data = dequantized.to_vec();
        let max_error: f32 = original
            .iter()
            .zip(deq_data.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f32::max);

        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
    }

    #[test]
    fn test_roundtrip_f16() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let tensor = Tensor::from_vec(original.clone(), &[4]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::F16).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        let deq_data = dequantized.to_vec();
        for (orig, deq) in original.iter().zip(deq_data.iter()) {
            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let tensor = Tensor::from_vec(original.clone(), &[4]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::F32).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        let deq_data = dequantized.to_vec();
        assert_eq!(original, deq_data);
    }
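
    // A sketch of a Q4_1 roundtrip check in the same style as the tests above.
    // It assumes quantize_tensor supports QuantType::Q4_1 end to end; the error
    // bound mirrors the Q4_0 test and is a guess, not a measured value.
    #[test]
    fn test_roundtrip_q4_1() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::Q4_1).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        assert_eq!(dequantized.shape(), &[64]);

        // Q4_1 keeps a per-block min, so error should be comparable to Q4_0.
        let deq_data = dequantized.to_vec();
        let max_error: f32 = original
            .iter()
            .zip(deq_data.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f32::max);

        assert!(max_error < 2.0, "Q4_1 max error too large: {}", max_error);
    }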
}