axonml_quant/dequantize.rs

//! Dequantization: converts quantized tensors back to dense `f32` tensors.

use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};

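/// Dequantizes a `QuantizedTensor` back into a dense `f32` tensor.
///
/// Dispatches on `quant_type`, dequantizes every block, trims any padding
/// introduced by block alignment down to `numel` elements, and reshapes the
/// result to the original `shape`.
///
/// # Example
///
/// A minimal round-trip sketch, assuming `quantize_tensor` and
/// `dequantize_tensor` are visible from the caller's scope (the paths here
/// are illustrative; see the tests below for in-crate usage):
///
/// ```ignore
/// let data: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
/// let tensor = Tensor::from_vec(data, &[64]).unwrap();
/// let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
/// let restored = dequantize_tensor(&quantized).unwrap();
/// assert_eq!(restored.shape(), &[64]);
/// ```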
pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
    let data = match quantized.quant_type {
        QuantType::Q8_0 => dequantize_q8_0(quantized),
        QuantType::Q4_0 => dequantize_q4_0(quantized),
        QuantType::Q4_1 => dequantize_q4_1(quantized),
        QuantType::Q5_0 => dequantize_q5_0(quantized),
        QuantType::Q5_1 => dequantize_q5_1(quantized),
        QuantType::F16 => dequantize_f16(quantized),
        QuantType::F32 => dequantize_f32(quantized),
    }?;

    // The last block may be padded to a full block, so the dequantized data
    // can be longer than the logical element count; trim it back to `numel`.
    let expected_size = quantized.numel;
    let data = if data.len() > expected_size {
        data[..expected_size].to_vec()
    } else {
        data
    };

    Tensor::from_vec(data, &quantized.shape)
        .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
}

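/// Dequantizes a single block into a `Vec<f32>`.
///
/// Q8/Q4/Q5 values are scaled by the block scale; the `_1` variants
/// additionally add the block minimum (an affine mapping, `q * scale + min`);
/// F16 values are widened and F32 values copied through unchanged.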
pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
    match block {
        QuantizedBlock::Q8(b) => dequantize_q8_block(b),
        QuantizedBlock::Q4(b) => dequantize_q4_block(b),
        QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
        QuantizedBlock::Q5(b) => {
            let scale = b.scale.to_f32();
            b.unpack().iter().map(|&v| v as f32 * scale).collect()
        }
        QuantizedBlock::Q5_1(b) => {
            let scale = b.scale.to_f32();
            let min = b.min.to_f32();
            b.unpack().iter().map(|&v| v as f32 * scale + min).collect()
        }
        QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
        QuantizedBlock::F32(data) => data.clone(),
    }
}

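/// Dequantizes all blocks of a Q8_0 tensor in parallel via rayon.
///
/// A non-Q8 block should never appear in a Q8_0 tensor; as a defensive
/// fallback, any such block is emitted as 32 zeros.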
fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .par_iter()
        .flat_map(|block| {
            if let QuantizedBlock::Q8(b) = block {
                dequantize_q8_block(b)
            } else {
                vec![0.0; 32]
            }
        })
        .collect();

    Ok(result)
}

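/// Dequantizes one Q8 block: each stored integer is scaled by the block
/// scale (`value = q * scale`).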
fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
    let scale = block.scale.to_f32();
    block.data.iter().map(|&q| q as f32 * scale).collect()
}

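/// Dequantizes all blocks of a Q4_0 tensor in parallel via rayon, with the
/// same zero-fill fallback for mismatched block variants as the Q8_0 path.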
fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .par_iter()
        .flat_map(|block| {
            if let QuantizedBlock::Q4(b) = block {
                dequantize_q4_block(b)
            } else {
                vec![0.0; 32]
            }
        })
        .collect();

    Ok(result)
}

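/// Dequantizes one Q4 block: unpacks the packed 4-bit values, then scales
/// each by the block scale.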
fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
    let scale = block.scale.to_f32();
    let unpacked = block.unpack();

    unpacked.iter().map(|&q| q as f32 * scale).collect()
}

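/// Dequantizes all blocks of a Q4_1 tensor in parallel via rayon, with the
/// same zero-fill fallback for mismatched block variants.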
fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .par_iter()
        .flat_map(|block| {
            if let QuantizedBlock::Q4_1(b) = block {
                dequantize_q4_1_block(b)
            } else {
                vec![0.0; 32]
            }
        })
        .collect();

    Ok(result)
}

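/// Dequantizes one Q4_1 block: unpacks the 4-bit values and applies the
/// affine mapping `value = q * scale + min`.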
fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
    let scale = block.scale.to_f32();
    let min = block.min.to_f32();
    let unpacked = block.unpack();

    unpacked.iter().map(|&q| q as f32 * scale + min).collect()
}

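/// Dequantizes a Q5_0 tensor sequentially. Note that, unlike the parallel
/// paths above, mismatched block variants are skipped rather than zero-filled.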
fn dequantize_q5_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let mut result = Vec::new();
    for block in &quantized.blocks {
        if let QuantizedBlock::Q5(q5) = block {
            let scale = q5.scale.to_f32();
            let values = q5.unpack();
            for &v in &values {
                result.push(v as f32 * scale);
            }
        }
    }
    Ok(result)
}

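/// Dequantizes a Q5_1 tensor sequentially, applying `value = q * scale + min`
/// per block; mismatched block variants are skipped.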
fn dequantize_q5_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let mut result = Vec::new();
    for block in &quantized.blocks {
        if let QuantizedBlock::Q5_1(q5) = block {
            let scale = q5.scale.to_f32();
            let min = q5.min.to_f32();
            let values = q5.unpack();
            for &v in &values {
                result.push(v as f32 * scale + min);
            }
        }
    }
    Ok(result)
}

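/// Widens F16 blocks to `f32`; non-F16 blocks contribute nothing.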
fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .iter()
        .flat_map(|block| {
            if let QuantizedBlock::F16(data) = block {
                data.iter().map(|x| x.to_f32()).collect()
            } else {
                vec![]
            }
        })
        .collect();

    Ok(result)
}

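/// Copies F32 blocks through unchanged; non-F32 blocks contribute nothing.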
fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .iter()
        .flat_map(|block| {
            if let QuantizedBlock::F32(data) = block {
                data.clone()
            } else {
                vec![]
            }
        })
        .collect();

    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    #[test]
    fn test_roundtrip_q8_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        assert_eq!(dequantized.shape(), &[64]);

        let deq_data = dequantized.to_vec();
        for (orig, deq) in original.iter().zip(deq_data.iter()) {
            assert!(
                (orig - deq).abs() < 0.1,
                "Q8 error too large: {} vs {}",
                orig,
                deq
            );
        }
    }

    #[test]
    fn test_roundtrip_q4_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        assert_eq!(dequantized.shape(), &[64]);

        let deq_data = dequantized.to_vec();
        let max_error: f32 = original
            .iter()
            .zip(deq_data.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f32::max);

        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
    }

    #[test]
    fn test_roundtrip_f16() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let tensor = Tensor::from_vec(original.clone(), &[4]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::F16).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        let deq_data = dequantized.to_vec();
        for (orig, deq) in original.iter().zip(deq_data.iter()) {
            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let tensor = Tensor::from_vec(original.clone(), &[4]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::F32).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        let deq_data = dequantized.to_vec();
        assert_eq!(original, deq_data);
    }
}