// axonml_quant/dequantize.rs

//! Dequantization Functions
//!
//! Functions for converting quantized tensors back to floating point.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{Q4Block, Q4_1Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};

14// =============================================================================
15// Public API
16// =============================================================================
17
18/// Dequantizes a quantized tensor back to f32.
19///
20/// # Arguments
21/// * `quantized` - The quantized tensor to dequantize
22///
23/// # Returns
24/// A tensor with f32 values
25///
26/// # Example
27/// ```ignore
28/// use axonml_quant::{quantize_tensor, dequantize_tensor, QuantType};
29///
30/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
31/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
32/// let dequantized = dequantize_tensor(&quantized)?;
33/// ```
34pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
35    let data = match quantized.quant_type {
36        QuantType::Q8_0 => dequantize_q8_0(quantized),
37        QuantType::Q4_0 => dequantize_q4_0(quantized),
38        QuantType::Q4_1 => dequantize_q4_1(quantized),
39        QuantType::Q5_0 | QuantType::Q5_1 => dequantize_q4_0(quantized), // Fallback
40        QuantType::F16 => dequantize_f16(quantized),
41        QuantType::F32 => dequantize_f32(quantized),
42    }?;
43
44    // Truncate to original size
45    let expected_size = quantized.numel;
46    let data = if data.len() > expected_size {
47        data[..expected_size].to_vec()
48    } else {
49        data
50    };
51
52    Tensor::from_vec(data, &quantized.shape)
53        .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
54}
55
56/// Dequantizes a single block.
57pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
58    match block {
59        QuantizedBlock::Q8(b) => dequantize_q8_block(b),
60        QuantizedBlock::Q4(b) => dequantize_q4_block(b),
61        QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
62        QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
63        QuantizedBlock::F32(data) => data.clone(),
64    }
65}
66
67// =============================================================================
68// Q8_0 Dequantization
69// =============================================================================
70
71/// Dequantizes Q8_0 data.
72fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
73    let result: Vec<f32> = quantized
74        .blocks
75        .par_iter()
76        .flat_map(|block| {
77            if let QuantizedBlock::Q8(b) = block {
78                dequantize_q8_block(b)
79            } else {
80                vec![0.0; 32]
81            }
82        })
83        .collect();
84
85    Ok(result)
86}
87
88/// Dequantizes a single Q8 block.
89fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
90    let scale = block.scale.to_f32();
91    block
92        .data
93        .iter()
94        .map(|&q| q as f32 * scale)
95        .collect()
96}
97
98// =============================================================================
99// Q4_0 Dequantization
100// =============================================================================
101
102/// Dequantizes Q4_0 data.
103fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
104    let result: Vec<f32> = quantized
105        .blocks
106        .par_iter()
107        .flat_map(|block| {
108            if let QuantizedBlock::Q4(b) = block {
109                dequantize_q4_block(b)
110            } else {
111                vec![0.0; 32]
112            }
113        })
114        .collect();
115
116    Ok(result)
117}
118
119/// Dequantizes a single Q4 block.
120fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
121    let scale = block.scale.to_f32();
122    let unpacked = block.unpack();
123
124    unpacked
125        .iter()
126        .map(|&q| q as f32 * scale)
127        .collect()
128}
129
130// =============================================================================
131// Q4_1 Dequantization
132// =============================================================================
133
134/// Dequantizes Q4_1 data.
135fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
136    let result: Vec<f32> = quantized
137        .blocks
138        .par_iter()
139        .flat_map(|block| {
140            if let QuantizedBlock::Q4_1(b) = block {
141                dequantize_q4_1_block(b)
142            } else {
143                vec![0.0; 32]
144            }
145        })
146        .collect();
147
148    Ok(result)
149}
150
151/// Dequantizes a single Q4_1 block.
152fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
153    let scale = block.scale.to_f32();
154    let min = block.min.to_f32();
155    let unpacked = block.unpack();
156
157    unpacked
158        .iter()
159        .map(|&q| q as f32 * scale + min)
160        .collect()
161}
162
163// =============================================================================
164// F16 Dequantization
165// =============================================================================
166
167/// Dequantizes F16 data.
168fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
169    let result: Vec<f32> = quantized
170        .blocks
171        .iter()
172        .flat_map(|block| {
173            if let QuantizedBlock::F16(data) = block {
174                data.iter().map(|x| x.to_f32()).collect()
175            } else {
176                vec![]
177            }
178        })
179        .collect();
180
181    Ok(result)
182}
183
184// =============================================================================
185// F32 Dequantization (passthrough)
186// =============================================================================
187
188/// Dequantizes F32 data (passthrough).
189fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
190    let result: Vec<f32> = quantized
191        .blocks
192        .iter()
193        .flat_map(|block| {
194            if let QuantizedBlock::F32(data) = block {
195                data.clone()
196            } else {
197                vec![]
198            }
199        })
200        .collect();
201
202    Ok(result)
203}
204
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    /// Quantizes `original` as `qtype`, dequantizes it back, verifies the
    /// shape survived, and returns the recovered values.
    fn roundtrip(original: &[f32], qtype: QuantType) -> Vec<f32> {
        let tensor = Tensor::from_vec(original.to_vec(), &[original.len()]).unwrap();
        let quantized = quantize_tensor(&tensor, qtype).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();
        assert_eq!(dequantized.shape(), &[original.len()]);
        dequantized.to_vec()
    }

    #[test]
    fn test_roundtrip_q8_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let recovered = roundtrip(&original, QuantType::Q8_0);

        // Some quantization error is expected, but it must stay small.
        for (orig, deq) in original.iter().zip(recovered.iter()) {
            assert!((orig - deq).abs() < 0.1, "Q8 error too large: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_q4_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let recovered = roundtrip(&original, QuantType::Q4_0);

        // Q4 has more error but should still be reasonable.
        let max_error = original
            .iter()
            .zip(recovered.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
    }

    #[test]
    fn test_roundtrip_f16() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let recovered = roundtrip(&original, QuantType::F16);

        // F16 can represent these values almost exactly.
        for (orig, deq) in original.iter().zip(recovered.iter()) {
            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        // F32 is a passthrough, so the roundtrip must be bit-exact.
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let recovered = roundtrip(&original, QuantType::F32);
        assert_eq!(original, recovered);
    }
}