axonml_quant/dequantize.rs

use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{Q4_1Block, Q4Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};

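/// Dequantizes a `QuantizedTensor` back into an f32 `Tensor`.
///
/// Dispatches on the tensor's quantization type, drops any padding introduced
/// by block alignment, and rebuilds the tensor in its original shape.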
pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
    let data = match quantized.quant_type {
        QuantType::Q8_0 => dequantize_q8_0(quantized),
        QuantType::Q4_0 => dequantize_q4_0(quantized),
        QuantType::Q4_1 => dequantize_q4_1(quantized),
        QuantType::Q5_0 => dequantize_q5_0(quantized),
        QuantType::Q5_1 => dequantize_q5_1(quantized),
        QuantType::F16 => dequantize_f16(quantized),
        QuantType::F32 => dequantize_f32(quantized),
    }?;

    let expected_size = quantized.numel;
    // Blocks are padded to a fixed block size, so the flattened data may be
    // longer than the tensor's element count; truncate the excess.
    let data = if data.len() > expected_size {
        data[..expected_size].to_vec()
    } else {
        data
    };

    Tensor::from_vec(data, &quantized.shape)
        .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
}

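/// Dequantizes a single `QuantizedBlock` into a vector of f32 values.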
pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
    match block {
        QuantizedBlock::Q8(b) => dequantize_q8_block(b),
        QuantizedBlock::Q4(b) => dequantize_q4_block(b),
        QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
        QuantizedBlock::Q5(b) => {
            let scale = b.scale.to_f32();
            b.unpack().iter().map(|&v| v as f32 * scale).collect()
        }
        QuantizedBlock::Q5_1(b) => {
            let scale = b.scale.to_f32();
            let min = b.min.to_f32();
            b.unpack().iter().map(|&v| v as f32 * scale + min).collect()
        }
        QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
        QuantizedBlock::F32(data) => data.clone(),
    }
}

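/// Dequantizes all Q8_0 blocks in parallel. Blocks of an unexpected type are
/// replaced with a zero-filled block so output offsets stay aligned.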
fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .par_iter()
        .flat_map(|block| {
            if let QuantizedBlock::Q8(b) = block {
                dequantize_q8_block(b)
            } else {
                vec![0.0; 32]
            }
        })
        .collect();

    Ok(result)
}

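/// Dequantizes one Q8 block: each i8 value is multiplied by the block's scale.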
fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
    let scale = block.scale.to_f32();
    block.data.iter().map(|&q| q as f32 * scale).collect()
}

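/// Dequantizes all Q4_0 blocks in parallel, substituting zero-filled blocks
/// for any unexpected block type.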
fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .par_iter()
        .flat_map(|block| {
            if let QuantizedBlock::Q4(b) = block {
                dequantize_q4_block(b)
            } else {
                vec![0.0; 32]
            }
        })
        .collect();

    Ok(result)
}

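/// Dequantizes one Q4 block: unpacks the 4-bit values and scales them.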
fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
    let scale = block.scale.to_f32();
    let unpacked = block.unpack();

    unpacked.iter().map(|&q| q as f32 * scale).collect()
}

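/// Dequantizes all Q4_1 blocks in parallel, substituting zero-filled blocks
/// for any unexpected block type.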
fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .par_iter()
        .flat_map(|block| {
            if let QuantizedBlock::Q4_1(b) = block {
                dequantize_q4_1_block(b)
            } else {
                vec![0.0; 32]
            }
        })
        .collect();

    Ok(result)
}

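/// Dequantizes one Q4_1 block: unpacks the 4-bit values, then applies the
/// block's scale and minimum offset.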
fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
    let scale = block.scale.to_f32();
    let min = block.min.to_f32();
    let unpacked = block.unpack();

    unpacked.iter().map(|&q| q as f32 * scale + min).collect()
}

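/// Dequantizes all Q5_0 blocks sequentially; blocks of any other type are skipped.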
fn dequantize_q5_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let mut result = Vec::new();
    for block in &quantized.blocks {
        if let QuantizedBlock::Q5(q5) = block {
            let scale = q5.scale.to_f32();
            let values = q5.unpack();
            for &v in &values {
                result.push(v as f32 * scale);
            }
        }
    }
    Ok(result)
}

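/// Dequantizes all Q5_1 blocks sequentially, applying each block's scale and
/// minimum offset; blocks of any other type are skipped.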
fn dequantize_q5_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let mut result = Vec::new();
    for block in &quantized.blocks {
        if let QuantizedBlock::Q5_1(q5) = block {
            let scale = q5.scale.to_f32();
            let min = q5.min.to_f32();
            let values = q5.unpack();
            for &v in &values {
                result.push(v as f32 * scale + min);
            }
        }
    }
    Ok(result)
}

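/// Converts F16 blocks back to f32 values; non-F16 blocks contribute nothing.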
fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .iter()
        .flat_map(|block| {
            if let QuantizedBlock::F16(data) = block {
                data.iter().map(|x| x.to_f32()).collect()
            } else {
                vec![]
            }
        })
        .collect();

    Ok(result)
}

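/// Returns F32 block data unchanged; non-F32 blocks contribute nothing.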
fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
    let result: Vec<f32> = quantized
        .blocks
        .iter()
        .flat_map(|block| {
            if let QuantizedBlock::F32(data) = block {
                data.clone()
            } else {
                vec![]
            }
        })
        .collect();

    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    #[test]
    fn test_roundtrip_q8_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        assert_eq!(dequantized.shape(), &[64]);

        let deq_data = dequantized.to_vec();
        for (orig, deq) in original.iter().zip(deq_data.iter()) {
            assert!(
                (orig - deq).abs() < 0.1,
                "Q8 error too large: {} vs {}",
                orig,
                deq
            );
        }
    }

    #[test]
    fn test_roundtrip_q4_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        assert_eq!(dequantized.shape(), &[64]);

        let deq_data = dequantized.to_vec();
        let max_error: f32 = original
            .iter()
            .zip(deq_data.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f32::max);

        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
    }

    #[test]
    fn test_roundtrip_f16() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let tensor = Tensor::from_vec(original.clone(), &[4]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::F16).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        let deq_data = dequantized.to_vec();
        for (orig, deq) in original.iter().zip(deq_data.iter()) {
            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let tensor = Tensor::from_vec(original.clone(), &[4]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::F32).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        let deq_data = dequantized.to_vec();
        assert_eq!(original, deq_data);
    }
}