// axonml_quant/dequantize.rs
use axonml_tensor::Tensor;
use rayon::prelude::*;

use crate::error::{QuantError, QuantResult};
use crate::types::{Q4Block, Q4_1Block, Q8Block, QuantType, QuantizedBlock, QuantizedTensor};
14pub fn dequantize_tensor(quantized: &QuantizedTensor) -> QuantResult<Tensor<f32>> {
35 let data = match quantized.quant_type {
36 QuantType::Q8_0 => dequantize_q8_0(quantized),
37 QuantType::Q4_0 => dequantize_q4_0(quantized),
38 QuantType::Q4_1 => dequantize_q4_1(quantized),
39 QuantType::Q5_0 | QuantType::Q5_1 => dequantize_q4_0(quantized), QuantType::F16 => dequantize_f16(quantized),
41 QuantType::F32 => dequantize_f32(quantized),
42 }?;
43
44 let expected_size = quantized.numel;
46 let data = if data.len() > expected_size {
47 data[..expected_size].to_vec()
48 } else {
49 data
50 };
51
52 Tensor::from_vec(data, &quantized.shape)
53 .map_err(|e| QuantError::TensorConversion(format!("{:?}", e)))
54}
55
56pub fn dequantize_block(block: &QuantizedBlock) -> Vec<f32> {
58 match block {
59 QuantizedBlock::Q8(b) => dequantize_q8_block(b),
60 QuantizedBlock::Q4(b) => dequantize_q4_block(b),
61 QuantizedBlock::Q4_1(b) => dequantize_q4_1_block(b),
62 QuantizedBlock::F16(data) => data.iter().map(|x| x.to_f32()).collect(),
63 QuantizedBlock::F32(data) => data.clone(),
64 }
65}
66
67fn dequantize_q8_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
73 let result: Vec<f32> = quantized
74 .blocks
75 .par_iter()
76 .flat_map(|block| {
77 if let QuantizedBlock::Q8(b) = block {
78 dequantize_q8_block(b)
79 } else {
80 vec![0.0; 32]
81 }
82 })
83 .collect();
84
85 Ok(result)
86}
87
88fn dequantize_q8_block(block: &Q8Block) -> Vec<f32> {
90 let scale = block.scale.to_f32();
91 block.data.iter().map(|&q| q as f32 * scale).collect()
92}
93
94fn dequantize_q4_0(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
100 let result: Vec<f32> = quantized
101 .blocks
102 .par_iter()
103 .flat_map(|block| {
104 if let QuantizedBlock::Q4(b) = block {
105 dequantize_q4_block(b)
106 } else {
107 vec![0.0; 32]
108 }
109 })
110 .collect();
111
112 Ok(result)
113}
114
115fn dequantize_q4_block(block: &Q4Block) -> Vec<f32> {
117 let scale = block.scale.to_f32();
118 let unpacked = block.unpack();
119
120 unpacked.iter().map(|&q| q as f32 * scale).collect()
121}
122
123fn dequantize_q4_1(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
129 let result: Vec<f32> = quantized
130 .blocks
131 .par_iter()
132 .flat_map(|block| {
133 if let QuantizedBlock::Q4_1(b) = block {
134 dequantize_q4_1_block(b)
135 } else {
136 vec![0.0; 32]
137 }
138 })
139 .collect();
140
141 Ok(result)
142}
143
144fn dequantize_q4_1_block(block: &Q4_1Block) -> Vec<f32> {
146 let scale = block.scale.to_f32();
147 let min = block.min.to_f32();
148 let unpacked = block.unpack();
149
150 unpacked.iter().map(|&q| q as f32 * scale + min).collect()
151}
152
153fn dequantize_f16(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
159 let result: Vec<f32> = quantized
160 .blocks
161 .iter()
162 .flat_map(|block| {
163 if let QuantizedBlock::F16(data) = block {
164 data.iter().map(|x| x.to_f32()).collect()
165 } else {
166 vec![]
167 }
168 })
169 .collect();
170
171 Ok(result)
172}
173
174fn dequantize_f32(quantized: &QuantizedTensor) -> QuantResult<Vec<f32>> {
180 let result: Vec<f32> = quantized
181 .blocks
182 .iter()
183 .flat_map(|block| {
184 if let QuantizedBlock::F32(data) = block {
185 data.clone()
186 } else {
187 vec![]
188 }
189 })
190 .collect();
191
192 Ok(result)
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::quantize_tensor;

    // Quantize `values` with `qt`, dequantize, and return the recovered tensor.
    fn roundtrip(values: Vec<f32>, shape: &[usize], qt: QuantType) -> Tensor<f32> {
        let tensor = Tensor::from_vec(values, shape).unwrap();
        let quantized = quantize_tensor(&tensor, qt).unwrap();
        dequantize_tensor(&quantized).unwrap()
    }

    #[test]
    fn test_roundtrip_q8_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let recovered = roundtrip(original.clone(), &[64], QuantType::Q8_0);

        assert_eq!(recovered.shape(), &[64]);

        for (orig, deq) in original.iter().zip(recovered.to_vec().iter()) {
            assert!(
                (orig - deq).abs() < 0.1,
                "Q8 error too large: {} vs {}",
                orig,
                deq
            );
        }
    }

    #[test]
    fn test_roundtrip_q4_0() {
        let original: Vec<f32> = (0..64).map(|x| x as f32 / 10.0).collect();
        let recovered = roundtrip(original.clone(), &[64], QuantType::Q4_0);

        assert_eq!(recovered.shape(), &[64]);

        let max_error = original
            .iter()
            .zip(recovered.to_vec().iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);

        assert!(max_error < 2.0, "Q4 max error too large: {}", max_error);
    }

    #[test]
    fn test_roundtrip_f16() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let recovered = roundtrip(original.clone(), &[4], QuantType::F16).to_vec();

        for (orig, deq) in original.iter().zip(recovered.iter()) {
            assert!((orig - deq).abs() < 0.01, "F16 error: {} vs {}", orig, deq);
        }
    }

    #[test]
    fn test_roundtrip_f32() {
        let original = vec![1.0f32, 2.5, -3.0, 4.25];
        let recovered = roundtrip(original.clone(), &[4], QuantType::F32).to_vec();

        assert_eq!(original, recovered);
    }
}