// axonml_quant/dequantize.rs
//! Dequantization Functions
//!
//! Functions for converting quantized tensors back to floating point.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{Q4Block, Q4_1Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
13
14// =============================================================================
15// Public API
16// =============================================================================
17
18/// Dequantizes a quantized tensor back to f32.
19///
20/// # Arguments
21/// * `quantized` - The quantized tensor to dequantize
22///
23/// # Returns
24/// A tensor with f32 values
25///
26/// # Example
27/// ```ignore
28/// use axonml_quant::{quantize_tensor, dequantize_tensor, QuantType};
29///
30/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
31/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
32/// let dequantized = dequantize_tensor(&quantized)?;
33/// ```
34pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
35    let data = match quantized.quant_type {
36        QuantType::Q8_0 => dequantize_q8_0(quantized),
37        QuantType::Q4_0 => dequantize_q4_0(quantized),
38        QuantType::Q4_1 => dequantize_q4_1(quantized),
39        QuantType::Q5_0 | QuantType::Q5_1 => dequantize_q4_0(quantized), // Fallback
40        QuantType::F16 => dequantize_f16(quantized),
41        QuantType::F32 => dequantize_f32(quantized),
42    }?;
43
44    // Truncate to original size
45    let expected_size = quantized.numel;
46    let data = if data.len() > expected_size {
47        data[..expected_size].to_vec()
48    } else {
49        data
50    };
51
52    Tensor::from_vec(data, &quantized.shape)
53        .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
54}
55
56/// Dequantizes a single block.
57pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
58    match block {
59        QuantizedBlock::Q8(b) => dequantize_q8_block(b),
60        QuantizedBlock::Q4(b) => dequantize_q4_block(b),
61        QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
62        QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
63        QuantizedBlock::F32(data) => data.clone(),
64    }
65}
66
67// =============================================================================
68// Q8_0 Dequantization
69// =============================================================================
70
71/// Dequantizes Q8_0 data.
72fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
73    let result: Vec<f32> = quantized
74        .blocks
75        .par_iter()
76        .flat_map(|block| {
77            if let QuantizedBlock::Q8(b) = block {
78                dequantize_q8_block(b)
79            } else {
80                vec![0.0; 32]
81            }
82        })
83        .collect();
84
85    Ok(result)
86}
87
88/// Dequantizes a single Q8 block.
89fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
90    let scale = block.scale.to_f32();
91    block.data.iter().map(|&q| q as f32 * scale).collect()
92}
93
94// =============================================================================
95// Q4_0 Dequantization
96// =============================================================================
97
98/// Dequantizes Q4_0 data.
99fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
100    let result: Vec<f32> = quantized
101        .blocks
102        .par_iter()
103        .flat_map(|block| {
104            if let QuantizedBlock::Q4(b) = block {
105                dequantize_q4_block(b)
106            } else {
107                vec![0.0; 32]
108            }
109        })
110        .collect();
111
112    Ok(result)
113}
114
115/// Dequantizes a single Q4 block.
116fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
117    let scale = block.scale.to_f32();
118    let unpacked = block.unpack();
119
120    unpacked.iter().map(|&q| q as f32 * scale).collect()
121}
122
123// =============================================================================
124// Q4_1 Dequantization
125// =============================================================================
126
127/// Dequantizes Q4_1 data.
128fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
129    let result: Vec<f32> = quantized
130        .blocks
131        .par_iter()
132        .flat_map(|block| {
133            if let QuantizedBlock::Q4_1(b) = block {
134                dequantize_q4_1_block(b)
135            } else {
136                vec![0.0; 32]
137            }
138        })
139        .collect();
140
141    Ok(result)
142}
143
144/// Dequantizes a single Q4_1 block.
145fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
146    let scale = block.scale.to_f32();
147    let min = block.min.to_f32();
148    let unpacked = block.unpack();
149
150    unpacked.iter().map(|&q| q as f32 * scale + min).collect()
151}
152
153// =============================================================================
154// F16 Dequantization
155// =============================================================================
156
157/// Dequantizes F16 data.
158fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
159    let result: Vec<f32> = quantized
160        .blocks
161        .iter()
162        .flat_map(|block| {
163            if let QuantizedBlock::F16(data) = block {
164                data.iter().map(|x| x.to_f32()).collect()
165            } else {
166                vec![]
167            }
168        })
169        .collect();
170
171    Ok(result)
172}
173
174// =============================================================================
175// F32 Dequantization (passthrough)
176// =============================================================================
177
178/// Dequantizes F32 data (passthrough).
179fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
180    let result: Vec<f32> = quantized
181        .blocks
182        .iter()
183        .flat_map(|block| {
184            if let QuantizedBlock::F32(data) = block {
185                data.clone()
186            } else {
187                vec![]
188            }
189        })
190        .collect();
191
192    Ok(result)
193}
194
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    #[test]
    fn test_roundtrip_q8_0() {
        // 64 evenly spaced values in [0, 6.3].
        let source: Vec<f32> = (0..64).map(|i| i as f32 / 10.0).collect();
        let input = Tensor::from_vec(source.clone(), &[64]).unwrap();

        let packed = quantize_tensor(&input, QuantType::Q8_0).unwrap();
        let restored = dequantize_tensor(&packed).unwrap();

        // Shape must survive the roundtrip.
        assert_eq!(restored.shape(), &[64]);

        // Values are lossy but should stay within Q8 tolerance.
        let restored_data = restored.to_vec();
        for (orig, deq) in source.iter().zip(restored_data.iter()) {
            let err = (orig - deq).abs();
            assert!(err < 0.1, "Q8 error too large: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_q4_0() {
        let source: Vec<f32> = (0..64).map(|i| i as f32 / 10.0).collect();
        let input = Tensor::from_vec(source.clone(), &[64]).unwrap();

        let packed = quantize_tensor(&input, QuantType::Q4_0).unwrap();
        let restored = dequantize_tensor(&packed).unwrap();

        assert_eq!(restored.shape(), &[64]);

        // Q4 is coarser than Q8; only bound the worst-case error.
        let restored_data = restored.to_vec();
        let mut max_error = 0.0f32;
        for (a, b) in source.iter().zip(restored_data.iter()) {
            max_error = max_error.max((a - b).abs());
        }

        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
    }

    #[test]
    fn test_roundtrip_f16() {
        let source = vec![1.0f32, 2.5, -3.0, 4.25];
        let input = Tensor::from_vec(source.clone(), &[4]).unwrap();

        let packed = quantize_tensor(&input, QuantType::F16).unwrap();
        let restored = dequantize_tensor(&packed).unwrap();

        // All chosen values are exactly representable-ish in f16; allow a
        // small tolerance anyway.
        let restored_data = restored.to_vec();
        for (orig, deq) in source.iter().zip(restored_data.iter()) {
            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let source = vec![1.0f32, 2.5, -3.0, 4.25];
        let input = Tensor::from_vec(source.clone(), &[4]).unwrap();

        let packed = quantize_tensor(&input, QuantType::F32).unwrap();
        let restored = dequantize_tensor(&packed).unwrap();

        // F32 passthrough must be bit-exact.
        assert_eq!(source, restored.to_vec());
    }
}