//! Tensor quantization: post-training INT8 (per-tensor and per-channel),
//! fake quantization for quantization-aware training, and dynamic quantization.

use ghostflow_core::tensor::Tensor;
use std::collections::HashMap;

/// Supported quantization schemes.
#[derive(Clone, Copy, Debug)]
pub enum QuantizationScheme {
    /// 8-bit signed integer quantization.
    INT8,
    /// 16-bit floating-point quantization.
    FP16,
    /// Weights quantized ahead of time, activations quantized at runtime.
    Dynamic,
}

/// Settings controlling how a tensor is quantized.
#[derive(Clone, Debug)]
pub struct QuantizationConfig {
    /// The quantization scheme to apply.
    pub scheme: QuantizationScheme,
    /// Compute a separate scale/zero-point pair per channel (dimension 0)
    /// rather than one pair for the whole tensor.
    pub per_channel: bool,
    /// Force a symmetric range around zero, fixing the zero point at 0.
    pub symmetric: bool,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            scheme: QuantizationScheme::INT8,
            per_channel: true,
            symmetric: true,
        }
    }
}

/// A tensor stored as INT8 values plus the parameters needed to dequantize it.
#[derive(Clone, Debug)]
pub struct QuantizedTensor {
    /// Quantized values, flattened in row-major order.
    pub data: Vec<i8>,
    /// One scale for per-tensor quantization, or one per channel.
    pub scales: Vec<f32>,
    /// Zero points matching `scales` (all zero for symmetric quantization).
    pub zero_points: Vec<i8>,
    /// Original tensor shape.
    pub shape: Vec<usize>,
    /// Scheme this tensor was quantized with.
    pub scheme: QuantizationScheme,
}

impl QuantizedTensor {
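    /// Quantize `tensor` according to `config`. Note that `Dynamic` currently
    /// reuses the INT8 path here; runtime activation handling lives in
    /// `DynamicQuantization::quantize_activation`.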
    pub fn from_tensor(tensor: &Tensor, config: &QuantizationConfig) -> Self {
        match config.scheme {
            QuantizationScheme::INT8 => Self::quantize_int8(tensor, config),
            QuantizationScheme::FP16 => Self::quantize_fp16(tensor, config),
            QuantizationScheme::Dynamic => Self::quantize_int8(tensor, config),
        }
    }

    fn quantize_int8(tensor: &Tensor, config: &QuantizationConfig) -> Self {
        let data_guard = tensor.storage().as_slice::<f32>();
        let data_slice = &*data_guard;
        let shape = tensor.shape().dims().to_vec();

        if config.per_channel {
            Self::quantize_per_channel_int8(data_slice, &shape, config.symmetric)
        } else {
            Self::quantize_per_tensor_int8(data_slice, &shape, config.symmetric)
        }
    }

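    /// Per-tensor affine quantization: `q = clamp(round(x / s) + z, -128, 127)`.
    /// Symmetric: `s = max|x| / 127` and `z = 0`. Asymmetric: `s = (max - min) / 255`
    /// and `z` is chosen so that `min` maps to -128.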
    fn quantize_per_tensor_int8(data: &[f32], shape: &[usize], symmetric: bool) -> Self {
        let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let (scale, zero_point) = if symmetric {
            let abs_max = min_val.abs().max(max_val.abs());
            // Guard against a zero range (constant tensor) to avoid dividing by zero.
            let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
            (scale, 0i8)
        } else {
            let range = max_val - min_val;
            let scale = if range > 0.0 { range / 255.0 } else { 1.0 };
            let zero_point = (-min_val / scale - 128.0).round() as i8;
            (scale, zero_point)
        };

        let quantized_data: Vec<i8> = data
            .iter()
            .map(|&x| {
                let q = (x / scale).round() as i32 + zero_point as i32;
                q.clamp(-128, 127) as i8
            })
            .collect();

        Self {
            data: quantized_data,
            scales: vec![scale],
            zero_points: vec![zero_point],
            shape: shape.to_vec(),
            scheme: QuantizationScheme::INT8,
        }
    }

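    /// Per-channel variant: dimension 0 is treated as the channel axis, and each
    /// channel gets its own scale and zero point. Assumes the element count is
    /// divisible by `shape[0]`.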
    fn quantize_per_channel_int8(data: &[f32], shape: &[usize], symmetric: bool) -> Self {
        let num_channels = shape[0];
        let channel_size = data.len() / num_channels;

        let mut scales = Vec::with_capacity(num_channels);
        let mut zero_points = Vec::with_capacity(num_channels);
        let mut quantized_data = Vec::with_capacity(data.len());

        for ch in 0..num_channels {
            let start = ch * channel_size;
            let end = start + channel_size;
            let channel_data = &data[start..end];

            let min_val = channel_data.iter().cloned().fold(f32::INFINITY, f32::min);
            let max_val = channel_data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

            let (scale, zero_point) = if symmetric {
                let abs_max = min_val.abs().max(max_val.abs());
                // Same zero-range guard as the per-tensor path.
                let scale = if abs_max > 0.0 { abs_max / 127.0 } else { 1.0 };
                (scale, 0i8)
            } else {
                let range = max_val - min_val;
                let scale = if range > 0.0 { range / 255.0 } else { 1.0 };
                let zero_point = (-min_val / scale - 128.0).round() as i8;
                (scale, zero_point)
            };

            scales.push(scale);
            zero_points.push(zero_point);

            for &x in channel_data {
                let q = (x / scale).round() as i32 + zero_point as i32;
                quantized_data.push(q.clamp(-128, 127) as i8);
            }
        }

        Self {
            data: quantized_data,
            scales,
            zero_points,
            shape: shape.to_vec(),
            scheme: QuantizationScheme::INT8,
        }
    }

    fn quantize_fp16(_tensor: &Tensor, _config: &QuantizationConfig) -> Self {
        unimplemented!("FP16 quantization requires half-precision support")
    }

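    /// Reconstruct an `f32` tensor via `x ≈ (q - z) * s`, using the per-tensor or
    /// per-channel parameters recorded at quantization time.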
    pub fn dequantize(&self) -> Tensor {
        match self.scheme {
            QuantizationScheme::INT8 | QuantizationScheme::Dynamic => {
                self.dequantize_int8()
            }
            QuantizationScheme::FP16 => {
                unimplemented!("FP16 dequantization not yet implemented")
            }
        }
    }

    fn dequantize_int8(&self) -> Tensor {
        if self.scales.len() == 1 {
            // Per-tensor: a single scale and zero point cover all elements.
            let scale = self.scales[0];
            let zero_point = self.zero_points[0];

            let dequantized: Vec<f32> = self.data
                .iter()
                .map(|&q| (q as f32 - zero_point as f32) * scale)
                .collect();

            Tensor::from_slice::<f32>(&dequantized, &self.shape).unwrap()
        } else {
            // Per-channel: apply each channel's scale and zero point in turn.
            let num_channels = self.shape[0];
            let channel_size = self.data.len() / num_channels;
            let mut dequantized = Vec::with_capacity(self.data.len());

            for ch in 0..num_channels {
                let scale = self.scales[ch];
                let zero_point = self.zero_points[ch];
                let start = ch * channel_size;
                let end = start + channel_size;

                for &q in &self.data[start..end] {
                    dequantized.push((q as f32 - zero_point as f32) * scale);
                }
            }

            Tensor::from_slice::<f32>(&dequantized, &self.shape).unwrap()
        }
    }

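    /// Ratio of the original f32 payload to the quantized payload (INT8 data
    /// plus the stored scales and zero points); close to 4x for INT8.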
    pub fn compression_ratio(&self) -> f32 {
        let original_size = self.data.len() * std::mem::size_of::<f32>();
        let quantized_size = self.data.len() * std::mem::size_of::<i8>()
            + self.scales.len() * std::mem::size_of::<f32>()
            + self.zero_points.len() * std::mem::size_of::<i8>();
        original_size as f32 / quantized_size as f32
    }
}

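/// Quantization-aware training helper: passes tensors through a quantize/
/// dequantize round trip ("fake quantization") so the training loop observes
/// quantization error while still working in f32.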
pub struct QuantizationAwareTraining {
    config: QuantizationConfig,
    fake_quantize: bool,
}

impl QuantizationAwareTraining {
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            fake_quantize: true,
        }
    }

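    /// Quantize then immediately dequantize, so the result carries the rounding
    /// error of the configured scheme. Returns a plain clone when disabled.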
    pub fn fake_quantize(&self, tensor: &Tensor) -> Tensor {
        if !self.fake_quantize {
            return tensor.clone();
        }

        let quantized = QuantizedTensor::from_tensor(tensor, &self.config);
        quantized.dequantize()
    }

    /// Enable or disable fake quantization.
    pub fn set_fake_quantize(&mut self, enabled: bool) {
        self.fake_quantize = enabled;
    }
}

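/// Dynamic quantization: weights are quantized once and cached by name, while
/// activations are quantized per-tensor and asymmetrically at call time.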
pub struct DynamicQuantization {
    config: QuantizationConfig,
    weight_quantized: HashMap<String, QuantizedTensor>,
}

impl DynamicQuantization {
    pub fn new() -> Self {
        Self {
            config: QuantizationConfig {
                scheme: QuantizationScheme::Dynamic,
                per_channel: true,
                symmetric: true,
            },
            weight_quantized: HashMap::new(),
        }
    }

    /// Quantize `weights` and cache the result under `name`.
    pub fn quantize_weights(&mut self, name: &str, weights: &Tensor) {
        let quantized = QuantizedTensor::from_tensor(weights, &self.config);
        self.weight_quantized.insert(name.to_string(), quantized);
    }

    /// Dequantized copy of the cached weights, if `name` has been quantized.
    pub fn get_weights(&self, name: &str) -> Option<Tensor> {
        self.weight_quantized.get(name).map(|q| q.dequantize())
    }

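    /// Quantize an activation tensor on the fly. Uses per-tensor asymmetric
    /// INT8, which fits activation ranges that are not centered on zero.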
    pub fn quantize_activation(&self, activation: &Tensor) -> QuantizedTensor {
        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: false,
        };
        QuantizedTensor::from_tensor(activation, &config)
    }
}

impl Default for DynamicQuantization {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_per_tensor_quantization() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_slice(&data, &[2, 3]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: true,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);
        let dequantized = quantized.dequantize();

        assert_eq!(dequantized.shape().dims(), tensor.shape().dims());

        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.1, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_per_channel_quantization() {
        let data = vec![1.0f32, 2.0, 3.0, 10.0, 20.0, 30.0];
        let tensor = Tensor::from_slice(&data, &[2, 3]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: true,
            symmetric: true,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);

        assert_eq!(quantized.scales.len(), 2);

        let dequantized = quantized.dequantize();

        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.5, "Original: {}, Recovered: {}", o, r);
        }
    }

    #[test]
    fn test_asymmetric_quantization() {
        let data = vec![-5.0f32, -3.0, -1.0, 1.0, 3.0, 5.0];
        let tensor = Tensor::from_slice(&data, &[6]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: false,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);
        let dequantized = quantized.dequantize();

        let original = tensor.storage().as_slice::<f32>();
        let recovered = dequantized.storage().as_slice::<f32>();
        for (o, r) in original.iter().zip(recovered.iter()) {
            assert!((o - r).abs() < 0.1, "Original: {}, Recovered: {}", o, r);
        }
    }

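    // Added regression check (not in the original suite): a constant tensor has a
    // zero value range, which the guarded scale computation maps to a scale of 1.0
    // instead of dividing by zero.
    #[test]
    fn test_constant_tensor_quantization() {
        let tensor = Tensor::from_slice(&[2.0f32; 8], &[8]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: false,
        };

        let quantized = QuantizedTensor::from_tensor(&tensor, &config);
        let dequantized = quantized.dequantize();

        let recovered = dequantized.storage().as_slice::<f32>();
        for r in recovered.iter() {
            assert!((r - 2.0).abs() < 0.5, "Recovered: {}", r);
        }
    }
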
    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..1000).map(|x| x as f32).collect();
        let tensor = Tensor::from_slice(&data, &[1000]).unwrap();

        let config = QuantizationConfig {
            scheme: QuantizationScheme::INT8,
            per_channel: false,
            symmetric: true,
        };
        let quantized = QuantizedTensor::from_tensor(&tensor, &config);

        let ratio = quantized.compression_ratio();
        assert!(ratio > 3.5 && ratio < 4.5, "Compression ratio: {}", ratio);
    }

    #[test]
    fn test_quantization_aware_training() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_slice(&data, &[4]).unwrap();

        let config = QuantizationConfig::default();
        let qat = QuantizationAwareTraining::new(config);

        let fake_quantized = qat.fake_quantize(&tensor);

        assert_eq!(fake_quantized.shape().dims(), tensor.shape().dims());

        let original = tensor.storage().as_slice::<f32>();
        let quantized = fake_quantized.storage().as_slice::<f32>();
        for (o, q) in original.iter().zip(quantized.iter()) {
            assert!((o - q).abs() < 0.1);
        }
    }

    #[test]
    fn test_dynamic_quantization() {
        let mut dq = DynamicQuantization::new();

        let weights = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        dq.quantize_weights("layer1", &weights);

        let retrieved = dq.get_weights("layer1").unwrap();
        assert_eq!(retrieved.shape().dims(), weights.shape().dims());

        let activation = Tensor::from_slice(&[0.5f32, 1.5, 2.5], &[3]).unwrap();
        let q_activation = dq.quantize_activation(&activation);

        assert_eq!(q_activation.shape, vec![3]);
        assert_eq!(q_activation.scales.len(), 1);
    }
}