//! Dequantization Functions
//!
//! # File
//! `crates/axonml-quant/src/dequantize.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};

// =============================================================================
// Public API
// =============================================================================

28/// Dequantizes a quantized tensor back to f32.
29///
30/// # Arguments
31/// * `quantized` - The quantized tensor to dequantize
32///
33/// # Returns
34/// A tensor with f32 values
35///
36/// # Example
37/// ```ignore
38/// use axonml_quant::{quantize_tensor, dequantize_tensor, QuantType};
39///
40/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
41/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
42/// let dequantized = dequantize_tensor(&quantized)?;
43/// ```
44pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
45    let data = match quantized.quant_type {
46        QuantType::Q8_0 => dequantize_q8_0(quantized),
47        QuantType::Q4_0 => dequantize_q4_0(quantized),
48        QuantType::Q4_1 => dequantize_q4_1(quantized),
49        QuantType::Q5_0 => dequantize_q5_0(quantized),
50        QuantType::Q5_1 => dequantize_q5_1(quantized),
51        QuantType::F16 => dequantize_f16(quantized),
52        QuantType::F32 => dequantize_f32(quantized),
53    }?;
54
55    // Truncate to original size
56    let expected_size = quantized.numel;
57    let data = if data.len() > expected_size {
58        data[..expected_size].to_vec()
59    } else {
60        data
61    };
62
63    Tensor::from_vec(data, &quantized.shape)
64        .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
65}
66
67/// Dequantizes a single block.
68pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
69    match block {
70        QuantizedBlock::Q8(b) => dequantize_q8_block(b),
71        QuantizedBlock::Q4(b) => dequantize_q4_block(b),
72        QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
73        QuantizedBlock::Q5(b) => {
74            let scale = b.scale.to_f32();
75            b.unpack().iter().map(|&v| v as f32 * scale).collect()
76        }
77        QuantizedBlock::Q5_1(b) => {
78            let scale = b.scale.to_f32();
79            let min = b.min.to_f32();
80            b.unpack().iter().map(|&v| v as f32 * scale + min).collect()
81        }
82        QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
83        QuantizedBlock::F32(data) => data.clone(),
84    }
85}
86
// =============================================================================
// Q8_0 Dequantization
// =============================================================================

91/// Dequantizes Q8_0 data.
92fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
93    let result: Vec<f32> = quantized
94        .blocks
95        .par_iter()
96        .flat_map(|block| {
97            if let QuantizedBlock::Q8(b) = block {
98                dequantize_q8_block(b)
99            } else {
100                vec![0.0; 32]
101            }
102        })
103        .collect();
104
105    Ok(result)
106}
107
108/// Dequantizes a single Q8 block.
109fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
110    let scale = block.scale.to_f32();
111    block.data.iter().map(|&q| q as f32 * scale).collect()
112}
113
// =============================================================================
// Q4_0 Dequantization
// =============================================================================

118/// Dequantizes Q4_0 data.
119fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
120    let result: Vec<f32> = quantized
121        .blocks
122        .par_iter()
123        .flat_map(|block| {
124            if let QuantizedBlock::Q4(b) = block {
125                dequantize_q4_block(b)
126            } else {
127                vec![0.0; 32]
128            }
129        })
130        .collect();
131
132    Ok(result)
133}
134
135/// Dequantizes a single Q4 block.
136fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
137    let scale = block.scale.to_f32();
138    let unpacked = block.unpack();
139
140    unpacked.iter().map(|&q| q as f32 * scale).collect()
141}
142
// =============================================================================
// Q4_1 Dequantization
// =============================================================================

147/// Dequantizes Q4_1 data.
148fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
149    let result: Vec<f32> = quantized
150        .blocks
151        .par_iter()
152        .flat_map(|block| {
153            if let QuantizedBlock::Q4_1(b) = block {
154                dequantize_q4_1_block(b)
155            } else {
156                vec![0.0; 32]
157            }
158        })
159        .collect();
160
161    Ok(result)
162}
163
164/// Dequantizes a single Q4_1 block.
165fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
166    let scale = block.scale.to_f32();
167    let min = block.min.to_f32();
168    let unpacked = block.unpack();
169
170    unpacked.iter().map(|&q| q as f32 * scale + min).collect()
171}
172
// =============================================================================
// Q5_0 Dequantization
// =============================================================================

177/// Dequantizes Q5_0 data (5-bit signed with per-block scale).
178fn dequantize_q5_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
179    let mut result = Vec::new();
180    for block in &quantized.blocks {
181        if let QuantizedBlock::Q5(q5) = block {
182            let scale = q5.scale.to_f32();
183            let values = q5.unpack();
184            for &v in &values {
185                result.push(v as f32 * scale);
186            }
187        }
188    }
189    Ok(result)
190}
191
// =============================================================================
// Q5_1 Dequantization
// =============================================================================

196/// Dequantizes Q5_1 data (5-bit unsigned with per-block scale and min).
197fn dequantize_q5_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
198    let mut result = Vec::new();
199    for block in &quantized.blocks {
200        if let QuantizedBlock::Q5_1(q5) = block {
201            let scale = q5.scale.to_f32();
202            let min = q5.min.to_f32();
203            let values = q5.unpack();
204            for &v in &values {
205                result.push(v as f32 * scale + min);
206            }
207        }
208    }
209    Ok(result)
210}
211
// =============================================================================
// F16 Dequantization
// =============================================================================

216/// Dequantizes F16 data.
217fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
218    let result: Vec<f32> = quantized
219        .blocks
220        .iter()
221        .flat_map(|block| {
222            if let QuantizedBlock::F16(data) = block {
223                data.iter().map(|x| x.to_f32()).collect()
224            } else {
225                vec![]
226            }
227        })
228        .collect();
229
230    Ok(result)
231}
232
// =============================================================================
// F32 Dequantization (passthrough)
// =============================================================================

237/// Dequantizes F32 data (passthrough).
238fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
239    let result: Vec<f32> = quantized
240        .blocks
241        .iter()
242        .flat_map(|block| {
243            if let QuantizedBlock::F32(data) = block {
244                data.clone()
245            } else {
246                vec![]
247            }
248        })
249        .collect();
250
251    Ok(result)
252}
253
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    /// Runs a quantize → dequantize round trip and returns the recovered tensor.
    fn roundtrip(data: Vec<f32>, shape: &[usize], qtype: QuantType) -> Tensor<f32> {
        let tensor = Tensor::from_vec(data, shape).unwrap();
        let quantized = quantize_tensor(&tensor, qtype).unwrap();
        dequantize_tensor(&quantized).unwrap()
    }

    #[test]
    fn test_roundtrip_q8_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let recovered = roundtrip(original.clone(), &[64], QuantType::Q8_0);

        // Shape must survive the round trip.
        assert_eq!(recovered.shape(), &[64]);

        // Q8 quantization is lossy, but per-element error should stay small.
        let deq_data = recovered.to_vec();
        for (orig, deq) in original.iter().zip(deq_data.iter()) {
            assert!(
                (orig - deq).abs() < 0.1,
                "Q8 error too large: {} vs {}",
                orig,
                deq
            );
        }
    }

    #[test]
    fn test_roundtrip_q4_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let recovered = roundtrip(original.clone(), &[64], QuantType::Q4_0);

        assert_eq!(recovered.shape(), &[64]);

        // 4-bit storage loses more precision; bound only the worst case.
        let deq_data = recovered.to_vec();
        let max_error = original
            .iter()
            .zip(deq_data.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f32, f32::max);

        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
    }

    #[test]
    fn test_roundtrip_f16() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let recovered = roundtrip(original.clone(), &[4], QuantType::F16);

        let deq_data = recovered.to_vec();
        for (orig, deq) in original.iter().zip(deq_data.iter()) {
            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let recovered = roundtrip(original.clone(), &[4], QuantType::F32);

        // F32 is a passthrough: values must be bit-exact.
        assert_eq!(original, recovered.to_vec());
    }
}