//! axonml_quant/dequantize.rs — dequantization of quantized tensors back to f32.

use axonml_tensor::Tensor;
use half::f16;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{QuantType, QuantizedTensor, QuantizedBlock, Q8Block, Q4Block, Q4_1Block};
15pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
36 let data = match quantized.quant_type {
37 QuantType::Q8_0 => dequantize_q8_0(quantized),
38 QuantType::Q4_0 => dequantize_q4_0(quantized),
39 QuantType::Q4_1 => dequantize_q4_1(quantized),
40 QuantType::Q5_0 | QuantType::Q5_1 => dequantize_q4_0(quantized), QuantType::F16 => dequantize_f16(quantized),
42 QuantType::F32 => dequantize_f32(quantized),
43 }?;
44
45 let expected_size = quantized.numel;
47 let data = if data.len() > expected_size {
48 data[..expected_size].to_vec()
49 } else {
50 data
51 };
52
53 Tensor::from_vec(data, &quantized.shape)
54 .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
55}
56
57pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
59 match block {
60 QuantizedBlock::Q8(b) => dequantize_q8_block(b),
61 QuantizedBlock::Q4(b) => dequantize_q4_block(b),
62 QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
63 QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
64 QuantizedBlock::F32(data) => data.clone(),
65 }
66}
67
68fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
74 let result: Vec<f32> = quantized
75 .blocks
76 .par_iter()
77 .flat_map(|block| {
78 if let QuantizedBlock::Q8(b) = block {
79 dequantize_q8_block(b)
80 } else {
81 vec![0.0; 32]
82 }
83 })
84 .collect();
85
86 Ok(result)
87}
88
89fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
91 let scale = block.scale.to_f32();
92 block
93 .data
94 .iter()
95 .map(|&q| q as f32 * scale)
96 .collect()
97}
98
99fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
105 let result: Vec<f32> = quantized
106 .blocks
107 .par_iter()
108 .flat_map(|block| {
109 if let QuantizedBlock::Q4(b) = block {
110 dequantize_q4_block(b)
111 } else {
112 vec![0.0; 32]
113 }
114 })
115 .collect();
116
117 Ok(result)
118}
119
120fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
122 let scale = block.scale.to_f32();
123 let unpacked = block.unpack();
124
125 unpacked
126 .iter()
127 .map(|&q| q as f32 * scale)
128 .collect()
129}
130
131fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
137 let result: Vec<f32> = quantized
138 .blocks
139 .par_iter()
140 .flat_map(|block| {
141 if let QuantizedBlock::Q4_1(b) = block {
142 dequantize_q4_1_block(b)
143 } else {
144 vec![0.0; 32]
145 }
146 })
147 .collect();
148
149 Ok(result)
150}
151
152fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
154 let scale = block.scale.to_f32();
155 let min = block.min.to_f32();
156 let unpacked = block.unpack();
157
158 unpacked
159 .iter()
160 .map(|&q| q as f32 * scale + min)
161 .collect()
162}
163
164fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
170 let result: Vec<f32> = quantized
171 .blocks
172 .iter()
173 .flat_map(|block| {
174 if let QuantizedBlock::F16(data) = block {
175 data.iter().map(|x| x.to_f32()).collect()
176 } else {
177 vec![]
178 }
179 })
180 .collect();
181
182 Ok(result)
183}
184
185fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
191 let result: Vec<f32> = quantized
192 .blocks
193 .iter()
194 .flat_map(|block| {
195 if let QuantizedBlock::F32(data) = block {
196 data.clone()
197 } else {
198 vec![]
199 }
200 })
201 .collect();
202
203 Ok(result)
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    /// Q8_0 round-trip: 8-bit quantization should stay within 0.1 per element.
    #[test]
    fn test_roundtrip_q8_0() {
        // 64 evenly spaced values spanning two 32-element blocks.
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::Q8_0).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        assert_eq!(dequantized.shape(), &[64]);

        let recovered = dequantized.to_vec();
        for (orig, deq) in original.iter().zip(recovered.iter()) {
            assert!((orig - deq).abs() < 0.1, "Q8 error too large: {} vs {}", orig, deq);
        }
    }

    /// Q4_0 round-trip: 4-bit quantization is coarse, so only bound the
    /// worst-case element error.
    #[test]
    fn test_roundtrip_q4_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let tensor = Tensor::from_vec(original.clone(), &[64]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::Q4_0).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        assert_eq!(dequantized.shape(), &[64]);

        let recovered = dequantized.to_vec();
        let mut max_error = 0.0f32;
        for (a, b) in original.iter().zip(recovered.iter()) {
            max_error = max_error.max((a - b).abs());
        }

        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
    }

    /// F16 round-trip: half precision easily represents these values to 0.01.
    #[test]
    fn test_roundtrip_f16() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let tensor = Tensor::from_vec(original.clone(), &[4]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::F16).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        let recovered = dequantized.to_vec();
        for (orig, deq) in original.iter().zip(recovered.iter()) {
            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
        }
    }

    /// F32 round-trip is a passthrough and must be bit-exact.
    #[test]
    fn test_roundtrip_f32() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let tensor = Tensor::from_vec(original.clone(), &[4]).unwrap();

        let quantized = quantize_tensor(&tensor, QuantType::F32).unwrap();
        let dequantized = dequantize_tensor(&quantized).unwrap();

        assert_eq!(original, dequantized.to_vec());
    }
}