//! Dequantization: converts quantized tensors back to f32.

use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
23pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
44 let data = match quantized.quant_type {
45 QuantType::Q8_0 => dequantize_q8_0(quantized),
46 QuantType::Q4_0 => dequantize_q4_0(quantized),
47 QuantType::Q4_1 => dequantize_q4_1(quantized),
48 QuantType::Q5_0 | QuantType::Q5_1 => dequantize_q4_0(quantized), QuantType::F16 => dequantize_f16(quantized),
50 QuantType::F32 => dequantize_f32(quantized),
51 }?;
52
53 let expected_size = quantized.numel;
55 let data = if data.len() > expected_size {
56 data[..expected_size].to_vec()
57 } else {
58 data
59 };
60
61 Tensor::from_vec(data, &quantized.shape)
62 .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
63}
64
65pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
67 match block {
68 QuantizedBlock::Q8(b) => dequantize_q8_block(b),
69 QuantizedBlock::Q4(b) => dequantize_q4_block(b),
70 QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
71 QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
72 QuantizedBlock::F32(data) => data.clone(),
73 }
74}
75
76fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
82 let result: Vec<f32> = quantized
83 .blocks
84 .par_iter()
85 .flat_map(|block| {
86 if let QuantizedBlock::Q8(b) = block {
87 dequantize_q8_block(b)
88 } else {
89 vec![0.0; 32]
90 }
91 })
92 .collect();
93
94 Ok(result)
95}
96
97fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
99 let scale = block.scale.to_f32();
100 block.data.iter().map(|&q| q as f32 * scale).collect()
101}
102
103fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
109 let result: Vec<f32> = quantized
110 .blocks
111 .par_iter()
112 .flat_map(|block| {
113 if let QuantizedBlock::Q4(b) = block {
114 dequantize_q4_block(b)
115 } else {
116 vec![0.0; 32]
117 }
118 })
119 .collect();
120
121 Ok(result)
122}
123
124fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
126 let scale = block.scale.to_f32();
127 let unpacked = block.unpack();
128
129 unpacked.iter().map(|&q| q as f32 * scale).collect()
130}
131
132fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
138 let result: Vec<f32> = quantized
139 .blocks
140 .par_iter()
141 .flat_map(|block| {
142 if let QuantizedBlock::Q4_1(b) = block {
143 dequantize_q4_1_block(b)
144 } else {
145 vec![0.0; 32]
146 }
147 })
148 .collect();
149
150 Ok(result)
151}
152
153fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
155 let scale = block.scale.to_f32();
156 let min = block.min.to_f32();
157 let unpacked = block.unpack();
158
159 unpacked.iter().map(|&q| q as f32 * scale + min).collect()
160}
161
162fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
168 let result: Vec<f32> = quantized
169 .blocks
170 .iter()
171 .flat_map(|block| {
172 if let QuantizedBlock::F16(data) = block {
173 data.iter().map(|x| x.to_f32()).collect()
174 } else {
175 vec![]
176 }
177 })
178 .collect();
179
180 Ok(result)
181}
182
183fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
189 let result: Vec<f32> = quantized
190 .blocks
191 .iter()
192 .flat_map(|block| {
193 if let QuantizedBlock::F32(data) = block {
194 data.clone()
195 } else {
196 vec![]
197 }
198 })
199 .collect();
200
201 Ok(result)
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    /// Quantizes `values` with `qt`, dequantizes, checks the shape
    /// survived, and returns the recovered data.
    fn roundtrip(values: &[f32], shape: &[usize], qt: QuantType) -> Vec<f32> {
        let tensor = Tensor::from_vec(values.to_vec(), shape).unwrap();
        let quantized = quantize_tensor(&tensor, qt).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();
        assert_eq!(dequantized.shape(), shape);
        dequantized.to_vec()
    }

    #[test]
    fn test_roundtrip_q8_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let restored = roundtrip(&original, &[64], QuantType::Q8_0);

        for (orig, deq) in original.iter().zip(restored.iter()) {
            assert!(
                (orig - deq).abs() < 0.1,
                "Q8 error too large: {} vs {}",
                orig,
                deq
            );
        }
    }

    #[test]
    fn test_roundtrip_q4_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let restored = roundtrip(&original, &[64], QuantType::Q4_0);

        let max_error = original
            .iter()
            .zip(restored.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);

        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
    }

    #[test]
    fn test_roundtrip_f16() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let restored = roundtrip(&original, &[4], QuantType::F16);

        for (orig, deq) in original.iter().zip(restored.iter()) {
            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        assert_eq!(roundtrip(&original, &[4], QuantType::F32), original);
    }
}