//! axonml_quant — dequantization of quantized tensors back to `f32`.

use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{QuantType, QuantizedTensor, QuantizedBlock, Q8Block, Q4Block, Q4_1Block};
14pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
35 let data = match quantized.quant_type {
36 QuantType::Q8_0 => dequantize_q8_0(quantized),
37 QuantType::Q4_0 => dequantize_q4_0(quantized),
38 QuantType::Q4_1 => dequantize_q4_1(quantized),
39 QuantType::Q5_0 | QuantType::Q5_1 => dequantize_q4_0(quantized), QuantType::F16 => dequantize_f16(quantized),
41 QuantType::F32 => dequantize_f32(quantized),
42 }?;
43
44 let expected_size = quantized.numel;
46 let data = if data.len() > expected_size {
47 data[..expected_size].to_vec()
48 } else {
49 data
50 };
51
52 Tensor::from_vec(data, &quantized.shape)
53 .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
54}
55
56pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
58 match block {
59 QuantizedBlock::Q8(b) => dequantize_q8_block(b),
60 QuantizedBlock::Q4(b) => dequantize_q4_block(b),
61 QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
62 QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
63 QuantizedBlock::F32(data) => data.clone(),
64 }
65}
66
67fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
73 let result: Vec<f32> = quantized
74 .blocks
75 .par_iter()
76 .flat_map(|block| {
77 if let QuantizedBlock::Q8(b) = block {
78 dequantize_q8_block(b)
79 } else {
80 vec![0.0; 32]
81 }
82 })
83 .collect();
84
85 Ok(result)
86}
87
88fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
90 let scale = block.scale.to_f32();
91 block
92 .data
93 .iter()
94 .map(|&q| q as f32 * scale)
95 .collect()
96}
97
98fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
104 let result: Vec<f32> = quantized
105 .blocks
106 .par_iter()
107 .flat_map(|block| {
108 if let QuantizedBlock::Q4(b) = block {
109 dequantize_q4_block(b)
110 } else {
111 vec![0.0; 32]
112 }
113 })
114 .collect();
115
116 Ok(result)
117}
118
119fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
121 let scale = block.scale.to_f32();
122 let unpacked = block.unpack();
123
124 unpacked
125 .iter()
126 .map(|&q| q as f32 * scale)
127 .collect()
128}
129
130fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
136 let result: Vec<f32> = quantized
137 .blocks
138 .par_iter()
139 .flat_map(|block| {
140 if let QuantizedBlock::Q4_1(b) = block {
141 dequantize_q4_1_block(b)
142 } else {
143 vec![0.0; 32]
144 }
145 })
146 .collect();
147
148 Ok(result)
149}
150
151fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
153 let scale = block.scale.to_f32();
154 let min = block.min.to_f32();
155 let unpacked = block.unpack();
156
157 unpacked
158 .iter()
159 .map(|&q| q as f32 * scale + min)
160 .collect()
161}
162
163fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
169 let result: Vec<f32> = quantized
170 .blocks
171 .iter()
172 .flat_map(|block| {
173 if let QuantizedBlock::F16(data) = block {
174 data.iter().map(|x| x.to_f32()).collect()
175 } else {
176 vec![]
177 }
178 })
179 .collect();
180
181 Ok(result)
182}
183
184fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
190 let result: Vec<f32> = quantized
191 .blocks
192 .iter()
193 .flat_map(|block| {
194 if let QuantizedBlock::F32(data) = block {
195 data.clone()
196 } else {
197 vec![]
198 }
199 })
200 .collect();
201
202 Ok(result)
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    /// Runs a full quantize → dequantize cycle and returns the recovered tensor.
    fn roundtrip(values: Vec<f32>, shape: &[usize], qt: QuantType) -> Tensor<f32> {
        let tensor = Tensor::from_vec(values, shape).unwrap();
        let quantized = quantize_tensor(&tensor, qt).unwrap();
        dequantize_tensor(&quantized).unwrap()
    }

    #[test]
    fn test_roundtrip_q8_0() {
        let input: Vec<f32> = (0..64).map(|i| i as f32 / 10.0).collect();
        let recovered = roundtrip(input.clone(), &[64], QuantType::Q8_0);

        assert_eq!(recovered.shape(), &[64]);
        for (want, got) in input.iter().zip(recovered.to_vec().iter()) {
            assert!((want - got).abs() < 0.1, "Q8 error too large: {} vs {}", want, got);
        }
    }

    #[test]
    fn test_roundtrip_q4_0() {
        let input: Vec<f32> = (0..64).map(|i| i as f32 / 10.0).collect();
        let recovered = roundtrip(input.clone(), &[64], QuantType::Q4_0);

        assert_eq!(recovered.shape(), &[64]);
        let worst = input
            .iter()
            .zip(recovered.to_vec().iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f32, f32::max);
        assert!(worst < 2.0, "Q4 max error too large: {}", worst);
    }

    #[test]
    fn test_roundtrip_f16() {
        let input = vec![1.0f32, 2.5, -3.0, 4.25];
        let recovered = roundtrip(input.clone(), &[4], QuantType::F16);

        for (want, got) in input.iter().zip(recovered.to_vec().iter()) {
            assert!((want - got).abs() < 0.01, "F16 error: {} vs {}", want, got);
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let input = vec![1.0f32, 2.5, -3.0, 4.25];
        let recovered = roundtrip(input.clone(), &[4], QuantType::F32);
        assert_eq!(input, recovered.to_vec());
    }
}