//! Model quantization: post-training INT8 quantization (per-tensor and
//! per-channel), quantization-aware training, and dynamic quantization.

use ghostflow_core::tensor::Tensor;
use std::collections::HashMap;

/// Supported quantization schemes.
#[derive(Clone, Copy, Debug)]
pub enum QuantizationScheme {
    /// 8-bit integer quantization.
    INT8,
    /// 16-bit floating-point quantization (not yet implemented).
    FP16,
    /// Dynamic quantization: weights quantized ahead of time, activations
    /// quantized on the fly. Currently backed by the INT8 path.
    Dynamic,
}

/// Configuration controlling how tensors are quantized.
#[derive(Clone, Debug)]
pub struct QuantizationConfig {
    /// Quantization scheme to apply.
    pub scheme: QuantizationScheme,
    /// Use one scale/zero point per channel (first dimension) instead of
    /// one for the whole tensor.
    pub per_channel: bool,
    /// Use a symmetric range around zero (zero point fixed at 0).
    pub symmetric: bool,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            scheme: QuantizationScheme::INT8,
            per_channel: true,
            symmetric: true,
        }
    }
}

/// A tensor stored in quantized form, together with the parameters needed
/// to dequantize it.
#[derive(Clone, Debug)]
pub struct QuantizedTensor {
    /// Quantized values.
    pub data: Vec<i8>,
    /// One scale for the whole tensor, or one per channel.
    pub scales: Vec<f32>,
    /// One zero point for the whole tensor, or one per channel.
    pub zero_points: Vec<i8>,
    /// Shape of the original tensor.
    pub shape: Vec<usize>,
    /// Scheme used to produce this tensor.
    pub scheme: QuantizationScheme,
}

impl QuantizedTensor {
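    /// Quantize `tensor` according to `config`.
    ///
    /// A minimal round-trip sketch, using the `Tensor::from_slice`
    /// constructor the tests below also rely on:
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]).unwrap();
    /// let q = QuantizedTensor::from_tensor(&t, &QuantizationConfig::default());
    /// let restored = q.dequantize(); // close to `t`, within quantization error
    /// ```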
    pub fn from_tensor(tensor: &Tensor, config: &QuantizationConfig) -> Self {
        match config.scheme {
            QuantizationScheme::INT8 => Self::quantize_int8(tensor, config),
            QuantizationScheme::FP16 => Self::quantize_fp16(tensor, config),
            // Dynamic quantization reuses the INT8 kernel for the stored data.
            QuantizationScheme::Dynamic => Self::quantize_int8(tensor, config),
        }
    }

    fn quantize_int8(tensor: &Tensor, config: &QuantizationConfig) -> Self {
        let data_guard = tensor.storage().as_slice::<f32>();
        let data_slice = &*data_guard;
        let shape = tensor.shape().dims().to_vec();

        if config.per_channel {
            Self::quantize_per_channel_int8(data_slice, &shape, config.symmetric)
        } else {
            Self::quantize_per_tensor_int8(data_slice, &shape, config.symmetric)
        }
    }

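    /// Per-tensor INT8 quantization: one scale/zero point for all values.
    ///
    /// Symmetric:  scale = max|x| / 127, zero_point = 0.
    /// Asymmetric: scale = (max - min) / 255,
    ///             zero_point = round(-min / scale - 128),
    ///             so min maps to -128 and max maps to 127.
    /// Either way, q = clamp(round(x / scale) + zero_point, -128, 127).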
    fn quantize_per_tensor_int8(data: &[f32], shape: &[usize], symmetric: bool) -> Self {
        let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let (scale, zero_point) = if symmetric {
            let abs_max = min_val.abs().max(max_val.abs());
            // Guard against an all-zero tensor: a zero scale would produce
            // NaN in the division below.
            let scale = if abs_max == 0.0 { 1.0 } else { abs_max / 127.0 };
            (scale, 0i8)
        } else {
            let range = max_val - min_val;
            // Same guard for a constant tensor (max == min).
            let scale = if range == 0.0 { 1.0 } else { range / 255.0 };
            let zero_point = (-min_val / scale - 128.0).round() as i8;
            (scale, zero_point)
        };

        let quantized_data: Vec<i8> = data
            .iter()
            .map(|&x| {
                let q = (x / scale).round() as i32 + zero_point as i32;
                q.clamp(-128, 127) as i8
            })
            .collect();

        Self {
            data: quantized_data,
            scales: vec![scale],
            zero_points: vec![zero_point],
            shape: shape.to_vec(),
            scheme: QuantizationScheme::INT8,
        }
    }

    /// Per-channel INT8 quantization: the first dimension is treated as the
    /// channel axis, and each channel gets its own scale and zero point.
    fn quantize_per_channel_int8(data: &[f32], shape: &[usize], symmetric: bool) -> Self {
        let num_channels = shape[0];
        let channel_size = data.len() / num_channels;

        let mut scales = Vec::with_capacity(num_channels);
        let mut zero_points = Vec::with_capacity(num_channels);
        let mut quantized_data = Vec::with_capacity(data.len());

        for ch in 0..num_channels {
            let start = ch * channel_size;
            let end = start + channel_size;
            let channel_data = &data[start..end];

            let min_val = channel_data.iter().cloned().fold(f32::INFINITY, f32::min);
            let max_val = channel_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

            let (scale, zero_point) = if symmetric {
                let abs_max = min_val.abs().max(max_val.abs());
                // Same zero-scale guard as the per-tensor path.
                let scale = if abs_max == 0.0 { 1.0 } else { abs_max / 127.0 };
                (scale, 0i8)
            } else {
                let range = max_val - min_val;
                let scale = if range == 0.0 { 1.0 } else { range / 255.0 };
                let zero_point = (-min_val / scale - 128.0).round() as i8;
                (scale, zero_point)
            };

            scales.push(scale);
            zero_points.push(zero_point);

            for &x in channel_data {
                let q = (x / scale).round() as i32 + zero_point as i32;
                quantized_data.push(q.clamp(-128, 127) as i8);
            }
        }

        Self {
            data: quantized_data,
            scales,
            zero_points,
            shape: shape.to_vec(),
            scheme: QuantizationScheme::INT8,
        }
    }

    fn quantize_fp16(_tensor: &Tensor, _config: &QuantizationConfig) -> Self {
        unimplemented!("FP16 quantization requires half-precision support")
    }

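    /// Reconstruct an f32 tensor from the quantized data: each value is
    /// recovered as x ≈ (q - zero_point) * scale, using the per-tensor or
    /// per-channel parameters recorded at quantization time.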
    pub fn dequantize(&self) -> Tensor {
        match self.scheme {
            QuantizationScheme::INT8 | QuantizationScheme::Dynamic => {
                self.dequantize_int8()
            }
            QuantizationScheme::FP16 => {
                unimplemented!("FP16 dequantization not yet implemented")
            }
        }
    }

    fn dequantize_int8(&self) -> Tensor {
        if self.scales.len() == 1 {
            // Per-tensor: one scale/zero point for all values.
            let scale = self.scales[0];
            let zero_point = self.zero_points[0];

            let dequantized: Vec<f32> = self.data
                .iter()
                .map(|&q| (q as f32 - zero_point as f32) * scale)
                .collect();

            Tensor::from_slice::<f32>(&dequantized, &self.shape).unwrap()
        } else {
            // Per-channel: apply each channel's own scale/zero point.
            let num_channels = self.shape[0];
            let channel_size = self.data.len() / num_channels;
            let mut dequantized = Vec::with_capacity(self.data.len());

            for ch in 0..num_channels {
                let scale = self.scales[ch];
                let zero_point = self.zero_points[ch];
                let start = ch * channel_size;
                let end = start + channel_size;

                for &q in &self.data[start..end] {
                    dequantized.push((q as f32 - zero_point as f32) * scale);
                }
            }

            Tensor::from_slice::<f32>(&dequantized, &self.shape).unwrap()
        }
    }

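    /// Ratio of the original f32 payload to the quantized payload (data plus
    /// scales and zero points). For example, 1000 f32 values occupy 4000
    /// bytes, while per-tensor INT8 needs 1000 + 4 + 1 = 1005 bytes, a ratio
    /// of about 3.98.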
    pub fn compression_ratio(&self) -> f32 {
        // Quantization preserves the element count, so the original payload
        // is simply `len` f32 values.
        let original_size = self.data.len() * std::mem::size_of::<f32>();
        let quantized_size = self.data.len() * std::mem::size_of::<i8>()
            + self.scales.len() * std::mem::size_of::<f32>()
            + self.zero_points.len() * std::mem::size_of::<i8>();
        original_size as f32 / quantized_size as f32
    }
}

/// Helper for quantization-aware training (QAT): tensors are "fake
/// quantized" during training so the model learns to tolerate
/// quantization error.
pub struct QuantizationAwareTraining {
    config: QuantizationConfig,
    /// When false, `fake_quantize` passes tensors through unchanged.
    fake_quantize: bool,
}

impl QuantizationAwareTraining {
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            fake_quantize: true,
        }
    }

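    /// Quantize and immediately dequantize `tensor`, so the result keeps f32
    /// precision but carries the quantization error the deployed model will
    /// see. A minimal usage sketch:
    ///
    /// ```ignore
    /// // given some f32 weight tensor `weights`:
    /// let qat = QuantizationAwareTraining::new(QuantizationConfig::default());
    /// let noisy = qat.fake_quantize(&weights); // same shape, rounded values
    /// ```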
    pub fn fake_quantize(&self, tensor: &Tensor) -> Tensor {
        if !self.fake_quantize {
            return tensor.clone();
        }

        let quantized = QuantizedTensor::from_tensor(tensor, &self.config);
        quantized.dequantize()
    }

    /// Enable or disable fake quantization (e.g. disable for evaluation).
    pub fn set_fake_quantize(&mut self, enabled: bool) {
        self.fake_quantize = enabled;
    }
}

/// Dynamic quantization: weights are quantized once and cached, while
/// activations are quantized on the fly at inference time.
pub struct DynamicQuantization {
    config: QuantizationConfig,
    /// Quantized weights, keyed by layer name.
    weight_quantized: HashMap<String, QuantizedTensor>,
}

impl DynamicQuantization {
    pub fn new() -> Self {
        Self {
            config: QuantizationConfig {
                scheme: QuantizationScheme::Dynamic,
                per_channel: true,
                symmetric: true,
            },
            weight_quantized: HashMap::new(),
        }
    }

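    /// Quantize a named weight tensor and cache it. A typical flow (sketch,
    /// with a hypothetical layer name):
    ///
    /// ```ignore
    /// // given some f32 weight tensor `weights`:
    /// let mut dq = DynamicQuantization::new();
    /// dq.quantize_weights("fc1", &weights);
    /// let w = dq.get_weights("fc1").unwrap(); // dequantized on demand
    /// ```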
    pub fn quantize_weights(&mut self, name: &str, weights: &Tensor) {
        let quantized = QuantizedTensor::from_tensor(weights, &self.config);
        self.weight_quantized.insert(name.to_string(), quantized);
    }

    /// Retrieve a cached weight tensor, dequantized back to f32.
    pub fn get_weights(&self, name: &str) -> Option<Tensor> {
        self.weight_quantized.get(name).map(|q| q.dequantize())
    }

    /// Quantize an activation tensor on the fly. Activations use per-tensor,
    /// asymmetric INT8 since their range is not known ahead of time.
    pub fn quantize_activation(&self, activation: &Tensor) -> QuantizedTensor {
        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: false,
        };
        QuantizedTensor::from_tensor(activation, &config)
    }
}

impl Default for DynamicQuantization {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_per_tensor_quantization() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_slice(&data, &[2, 3]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: true,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);
        let dequantized = quantized.dequantize();

        assert_eq!(dequantized.shape().dims(), tensor.shape().dims());

        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.1, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        // Two channels with very different ranges benefit from separate scales.
        let data = vec![1.0f32, 2.0, 3.0, 10.0, 20.0, 30.0];
        let tensor = Tensor::from_slice(&data, &[2, 3]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: true,
            symmetric: true,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);

        // One scale per channel.
        assert_eq!(quantized.scales.len(), 2);

        let dequantized = quantized.dequantize();

        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.5, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_asymmetric_quantization() {
        let data = vec![-5.0f32, -3.0, -1.0, 1.0, 3.0, 5.0];
        let tensor = Tensor::from_slice(&data, &[6]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: false,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);
        let dequantized = quantized.dequantize();

        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.1, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..1000).map(|x| x as f32).collect();
        let tensor = Tensor::from_slice(&data, &[1000]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: true,
        };
        let quantized = QuantizedTensor::from_tensor(&tensor, &config);

        // Per-tensor INT8 should approach 4x compression.
        let ratio = quantized.compression_ratio();
        assert!(ratio > 3.5 && ratio < 4.5, "Compression ratio: {}", ratio);
    }

    #[test]
    fn test_quantization_aware_training() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_slice(&data, &[4]).unwrap();

        let config = QuantizationConfig::default();
        let qat = QuantizationAwareTraining::new(config);

        let fake_quantized = qat.fake_quantize(&tensor);

        assert_eq!(fake_quantized.shape().dims(), tensor.shape().dims());

        let original = tensor.storage().as_slice::<f32>();
        let quantized = fake_quantized.storage().as_slice::<f32>();
        for (o, q) in original.iter().zip(quantized.iter()) {
            assert!((o - q).abs() < 0.1);
        }
    }

    #[test]
    fn test_dynamic_quantization() {
        let mut dq = DynamicQuantization::new();

        let weights = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        dq.quantize_weights("layer1", &weights);

        let retrieved = dq.get_weights("layer1").unwrap();
        assert_eq!(retrieved.shape().dims(), weights.shape().dims());

        let activation = Tensor::from_slice(&[0.5f32, 1.5, 2.5], &[3]).unwrap();
        let q_activation = dq.quantize_activation(&activation);

        assert_eq!(q_activation.shape, vec![3]);
        // Per-tensor activation quantization: a single scale.
        assert_eq!(q_activation.scales.len(), 1);
    }
}