// axonml_quant/dequantize.rs

//! Dequantization Functions
//!
//! Functions for converting quantized tensors back to floating point.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use axonml_tensor::Tensor;
use half::f16;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{QuantType, QuantizedTensor, QuantizedBlock, Q8Block, Q4Block, Q4_1Block};

// =============================================================================
// Public API
// =============================================================================

19/// Dequantizes a quantized tensor back to f32.
20///
21/// # Arguments
22/// * `quantized` - The quantized tensor to dequantize
23///
24/// # Returns
25/// A tensor with f32 values
26///
27/// # Example
28/// ```ignore
29/// use axonml_quant::{quantize_tensor, dequantize_tensor, QuantType};
30///
31/// let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
32/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0)?;
33/// let dequantized = dequantize_tensor(&quantized)?;
34/// ```
35pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
36    let data = match quantized.quant_type {
37        QuantType::Q8_0 => dequantize_q8_0(quantized),
38        QuantType::Q4_0 => dequantize_q4_0(quantized),
39        QuantType::Q4_1 => dequantize_q4_1(quantized),
40        QuantType::Q5_0 | QuantType::Q5_1 => dequantize_q4_0(quantized), // Fallback
41        QuantType::F16 => dequantize_f16(quantized),
42        QuantType::F32 => dequantize_f32(quantized),
43    }?;
44
45    // Truncate to original size
46    let expected_size = quantized.numel;
47    let data = if data.len() > expected_size {
48        data[..expected_size].to_vec()
49    } else {
50        data
51    };
52
53    Tensor::from_vec(data, &quantized.shape)
54        .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
55}
56
57/// Dequantizes a single block.
58pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
59    match block {
60        QuantizedBlock::Q8(b) => dequantize_q8_block(b),
61        QuantizedBlock::Q4(b) => dequantize_q4_block(b),
62        QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
63        QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
64        QuantizedBlock::F32(data) => data.clone(),
65    }
66}
67
// =============================================================================
// Q8_0 Dequantization
// =============================================================================

72/// Dequantizes Q8_0 data.
73fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
74    let result: Vec<f32> = quantized
75        .blocks
76        .par_iter()
77        .flat_map(|block| {
78            if let QuantizedBlock::Q8(b) = block {
79                dequantize_q8_block(b)
80            } else {
81                vec![0.0; 32]
82            }
83        })
84        .collect();
85
86    Ok(result)
87}
88
89/// Dequantizes a single Q8 block.
90fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
91    let scale = block.scale.to_f32();
92    block
93        .data
94        .iter()
95        .map(|&q| q as f32 * scale)
96        .collect()
97}
98
// =============================================================================
// Q4_0 Dequantization
// =============================================================================

103/// Dequantizes Q4_0 data.
104fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
105    let result: Vec<f32> = quantized
106        .blocks
107        .par_iter()
108        .flat_map(|block| {
109            if let QuantizedBlock::Q4(b) = block {
110                dequantize_q4_block(b)
111            } else {
112                vec![0.0; 32]
113            }
114        })
115        .collect();
116
117    Ok(result)
118}
119
120/// Dequantizes a single Q4 block.
121fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
122    let scale = block.scale.to_f32();
123    let unpacked = block.unpack();
124
125    unpacked
126        .iter()
127        .map(|&q| q as f32 * scale)
128        .collect()
129}
130
// =============================================================================
// Q4_1 Dequantization
// =============================================================================

135/// Dequantizes Q4_1 data.
136fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
137    let result: Vec<f32> = quantized
138        .blocks
139        .par_iter()
140        .flat_map(|block| {
141            if let QuantizedBlock::Q4_1(b) = block {
142                dequantize_q4_1_block(b)
143            } else {
144                vec![0.0; 32]
145            }
146        })
147        .collect();
148
149    Ok(result)
150}
151
152/// Dequantizes a single Q4_1 block.
153fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
154    let scale = block.scale.to_f32();
155    let min = block.min.to_f32();
156    let unpacked = block.unpack();
157
158    unpacked
159        .iter()
160        .map(|&q| q as f32 * scale + min)
161        .collect()
162}
163
// =============================================================================
// F16 Dequantization
// =============================================================================

168/// Dequantizes F16 data.
169fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
170    let result: Vec<f32> = quantized
171        .blocks
172        .iter()
173        .flat_map(|block| {
174            if let QuantizedBlock::F16(data) = block {
175                data.iter().map(|x| x.to_f32()).collect()
176            } else {
177                vec![]
178            }
179        })
180        .collect();
181
182    Ok(result)
183}
184
// =============================================================================
// F32 Dequantization (passthrough)
// =============================================================================

189/// Dequantizes F32 data (passthrough).
190fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
191    let result: Vec<f32> = quantized
192        .blocks
193        .iter()
194        .flat_map(|block| {
195            if let QuantizedBlock::F32(data) = block {
196                data.clone()
197            } else {
198                vec![]
199            }
200        })
201        .collect();
202
203    Ok(result)
204}
205
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    /// Quantize-then-dequantize helper shared by the roundtrip tests.
    fn roundtrip(values: &[f32], shape: &[usize], qt: QuantType) -> Tensor<f32> {
        let tensor = Tensor::from_vec(values.to_vec(), shape).unwrap();
        let quantized = quantize_tensor(&tensor, qt).unwrap();
        dequantize_tensor(&quantized).unwrap()
    }

    #[test]
    fn test_roundtrip_q8_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let dequantized = roundtrip(&original, &[64], QuantType::Q8_0);

        // Shape must survive the roundtrip.
        assert_eq!(dequantized.shape(), &[64]);

        // Q8 introduces only a small per-element quantization error.
        let deq_data = dequantized.to_vec();
        for (orig, deq) in original.iter().zip(deq_data.iter()) {
            assert!((orig - deq).abs() < 0.1, "Q8 error too large: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_q4_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let dequantized = roundtrip(&original, &[64], QuantType::Q4_0);

        assert_eq!(dequantized.shape(), &[64]);

        // Q4 has more error but should still be reasonable.
        let deq_data = dequantized.to_vec();
        let max_error = original
            .iter()
            .zip(deq_data.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f32, f32::max);

        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
    }

    #[test]
    fn test_roundtrip_f16() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let dequantized = roundtrip(&original, &[4], QuantType::F16);

        // f16 represents these values almost exactly.
        let deq_data = dequantized.to_vec();
        for (orig, deq) in original.iter().zip(deq_data.iter()) {
            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let dequantized = roundtrip(&original, &[4], QuantType::F32);

        // F32 is a passthrough, so the roundtrip is exact.
        assert_eq!(original, dequantized.to_vec());
    }
}