
god_graph/transformer/quantization/mod.rs

//! Quantization module for efficient inference
//!
//! This module provides:
//! - INT8 quantization
//! - INT4 quantization (experimental)
//! - Quantized matrix multiplication
//! - Post-training quantization (PTQ)

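//! # Example
//!
//! A minimal usage sketch. The `god_graph` crate path is assumed from this
//! file's location, and `DenseTensor::new(data, shape)` follows the
//! constructor used in the tests at the bottom of this module.
//!
//! ```ignore
//! use god_graph::tensor::DenseTensor;
//! use god_graph::tensor::traits::TensorBase;
//! use god_graph::transformer::quantization::{
//!     QuantizationConfig, QuantizedMatMul, QuantizedTensor,
//! };
//!
//! // Quantize two small matrices to INT8 and multiply them.
//! let a = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
//! let b = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]);
//! let a_q = QuantizedTensor::from_tensor(&a, QuantizationConfig::int8());
//! let b_q = QuantizedTensor::from_tensor(&b, QuantizationConfig::int8());
//!
//! // The result is dequantized back to a dense tensor, approximately A @ B.
//! let c = QuantizedMatMul::matmul(&a_q, &b_q);
//! assert_eq!(c.shape(), &[2, 2]);
//! ```
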
use crate::tensor::DenseTensor;
use crate::tensor::traits::TensorBase;

/// Quantization data type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantDtype {
    /// 32-bit floating point
    F32,
    /// 8-bit integer
    INT8,
    /// 4-bit integer
    INT4,
}

/// Quantization configuration
#[derive(Debug, Clone)]
pub struct QuantizationConfig {
    /// Target data type
    pub dtype: QuantDtype,
    /// Whether to use symmetric quantization
    pub symmetric: bool,
    /// Whether to use per-channel quantization
    pub per_channel: bool,
    /// Quantization axis (for per-channel quantization)
    pub axis: Option<usize>,
}

impl QuantizationConfig {
    /// Create default INT8 quantization config
    pub fn int8() -> Self {
        Self {
            dtype: QuantDtype::INT8,
            symmetric: true,
            per_channel: false,
            axis: None,
        }
    }

    /// Create INT4 quantization config
    pub fn int4() -> Self {
        Self {
            dtype: QuantDtype::INT4,
            symmetric: true,
            per_channel: false,
            axis: None,
        }
    }

    /// Create per-channel INT8 config
    pub fn per_channel_int8(axis: usize) -> Self {
        Self {
            dtype: QuantDtype::INT8,
            symmetric: true,
            per_channel: true,
            axis: Some(axis),
        }
    }
}

/// Quantized tensor (INT8, or packed INT4)
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized data
    pub data: Vec<i8>,
    /// Scale factor(s)
    pub scale: Vec<f64>,
    /// Zero point(s)
    pub zero_point: Vec<i8>,
    /// Original shape
    pub shape: Vec<usize>,
    /// Quantization configuration
    pub config: QuantizationConfig,
    /// Channel-wise scales (for per-channel quantization)
    pub channel_scales: Option<Vec<f64>>,
    /// Channel-wise zero points (for per-channel quantization)
    pub channel_zero_points: Option<Vec<i8>>,
}

impl QuantizedTensor {
    /// Quantize a dense tensor according to the given configuration
    ///
    /// # Arguments
    /// * `tensor` - Input tensor to quantize
    /// * `config` - Quantization configuration
    pub fn from_tensor(tensor: &DenseTensor, config: QuantizationConfig) -> Self {
        match config.dtype {
            QuantDtype::INT8 => Self::quantize_int8(tensor, &config),
            QuantDtype::INT4 => Self::quantize_int4(tensor, &config),
            QuantDtype::F32 => {
                // F32 passthrough: values are saturating-cast to i8, so this
                // path is lossy and exists only to keep the API uniform
                let data = tensor.data().iter().map(|&x| x as i8).collect();
                Self {
                    data,
                    scale: vec![1.0],
                    zero_point: vec![0],
                    shape: tensor.shape().to_vec(),
                    config,
                    channel_scales: None,
                    channel_zero_points: None,
                }
            }
        }
    }

    /// Quantize to INT8
    fn quantize_int8(tensor: &DenseTensor, config: &QuantizationConfig) -> Self {
        if config.per_channel {
            Self::quantize_int8_per_channel(tensor, config.axis.unwrap_or(0))
        } else {
            Self::quantize_int8_per_tensor(tensor)
        }
    }

    /// Per-tensor INT8 quantization (symmetric)
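    ///
    /// As a numeric illustration: for values in `[-1.0, 1.0]`, `max_abs = 1.0`
    /// and `scale = 1/127`, so 0.5 quantizes to `round(0.5 * 127) = 64`, which
    /// dequantizes back to `64 / 127 ≈ 0.504`.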
    fn quantize_int8_per_tensor(tensor: &DenseTensor) -> Self {
        let data = tensor.data();

        // Find max absolute value for symmetric quantization
        let max_abs = data.iter().fold(0.0_f64, |max, &x: &f64| max.max(x.abs()));

        // Compute scale for symmetric quantization into [-127, 127];
        // guard against an all-zero tensor to avoid dividing by zero
        let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };

        // Quantize
        let quantized: Vec<i8> = data
            .iter()
            .map(|&x| {
                let q = (x / scale).round() as i32;
                q.clamp(-128, 127) as i8
            })
            .collect();

        Self {
            data: quantized,
            scale: vec![scale],
            zero_point: vec![0],
            shape: tensor.shape().to_vec(),
            config: QuantizationConfig::int8(),
            channel_scales: None,
            channel_zero_points: None,
        }
    }

    /// Per-channel INT8 quantization (symmetric, matching the config)
    fn quantize_int8_per_channel(tensor: &DenseTensor, axis: usize) -> Self {
        let data = tensor.data();
        let shape = tensor.shape();

        if axis >= shape.len() {
            return Self::quantize_int8_per_tensor(tensor);
        }

        let channel_dim = shape[axis];
        let channels_before: usize = shape[..axis].iter().product();
        let channels_after: usize = shape[axis + 1..].iter().product();

        let mut channel_scales = Vec::with_capacity(channel_dim);
        let mut channel_zero_points = Vec::with_capacity(channel_dim);
        let mut quantized = Vec::with_capacity(data.len());

        for c in 0..channel_dim {
            // Find the max absolute value of this channel
            let mut channel_max_abs = 0.0_f64;

            for cb in 0..channels_before {
                for ca in 0..channels_after {
                    let offset = (cb * channel_dim + c) * channels_after + ca;
                    channel_max_abs = channel_max_abs.max(data[offset].abs());
                }
            }

            // Symmetric scale into [-127, 127]; guard against all-zero channels
            let scale = if channel_max_abs > 0.0 {
                channel_max_abs / 127.0
            } else {
                1.0
            };
            let zero_point = 0i8;

            channel_scales.push(scale);
            channel_zero_points.push(zero_point);
        }

        // Quantize all data
        for (i, &val) in data.iter().enumerate() {
            let c = (i / channels_after) % channel_dim;
            let scale = channel_scales[c];

            let q = (val / scale).round() as i32;
            let q = q.clamp(-128, 127) as i8;
            quantized.push(q);
        }

        Self {
            data: quantized,
            scale: vec![1.0],
            zero_point: vec![0],
            shape: shape.to_vec(),
            config: QuantizationConfig::per_channel_int8(axis),
            channel_scales: Some(channel_scales),
            channel_zero_points: Some(channel_zero_points),
        }
    }

    /// Quantize to INT4 (experimental)
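    ///
    /// Two 4-bit codes are packed per byte, low nibble first: for example,
    /// codes `q0 = 3` and `q1 = 12` pack into `(12 << 4) | 3 = 0xC3`.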
    fn quantize_int4(tensor: &DenseTensor, config: &QuantizationConfig) -> Self {
        // INT4 quantization packs two values per byte
        let data = tensor.data();

        // Find min and max
        let (min, max) = data.iter().fold(
            (f64::INFINITY, f64::NEG_INFINITY),
            |(min, max): (f64, f64), &x| (min.min(x), max.max(x)),
        );

        // INT4 has 16 levels; guard against a constant tensor
        let range = max - min;
        let scale = if range > 0.0 { range / 15.0 } else { 1.0 };
        // Asymmetric zero point so that `min` maps to code 0 and `max` to 15
        let zero_point = (-min / scale).round().clamp(0.0, 15.0) as i8;

        // Quantize to INT4 and pack
        let mut packed_data = Vec::with_capacity(data.len().div_ceil(2));

        for i in (0..data.len()).step_by(2) {
            let q0 = ((data[i] / scale).round() as i32 + zero_point as i32).clamp(0, 15) as u8;

            let q1 = if i + 1 < data.len() {
                ((data[i + 1] / scale).round() as i32 + zero_point as i32).clamp(0, 15) as u8
            } else {
                0
            };

            // Pack two INT4 codes into one byte: low nibble = q0, high nibble = q1
            let packed = (q1 << 4) | q0;
            packed_data.push(packed as i8);
        }

        Self {
            data: packed_data,
            scale: vec![scale],
            zero_point: vec![zero_point],
            shape: tensor.shape().to_vec(),
            config: config.clone(),
            channel_scales: None,
            channel_zero_points: None,
        }
    }


    /// Dequantize to dense tensor
    pub fn dequantize(&self) -> DenseTensor {
        match self.config.dtype {
            QuantDtype::INT8 => self.dequantize_int8(),
            QuantDtype::INT4 => self.dequantize_int4(),
            QuantDtype::F32 => {
                let data = self.data.iter().map(|&x| x as f64).collect();
                DenseTensor::new(data, self.shape.clone())
            }
        }
    }

    /// Dequantize INT8
    fn dequantize_int8(&self) -> DenseTensor {
        let data = if let Some(scales) = &self.channel_scales {
            // Per-channel dequantization
            let shape = &self.shape;
            let axis = self.config.axis.unwrap_or(0);
            let channel_dim = shape[axis];
            let _channels_before: usize = shape[..axis].iter().product();
            let channels_after: usize = shape[axis + 1..].iter().product();

            self.data
                .iter()
                .enumerate()
                .map(|(i, &q)| {
                    let c = (i / channels_after) % channel_dim;
                    let scale = scales[c];
                    q as f64 * scale
                })
                .collect()
        } else {
            // Per-tensor dequantization
            let scale = self.scale[0];

            self.data
                .iter()
                .map(|&q| q as f64 * scale)
                .collect()
        };

        DenseTensor::new(data, self.shape.clone())
    }

    /// Dequantize INT4
    fn dequantize_int4(&self) -> DenseTensor {
        let scale = self.scale[0];
        let zero_point = self.zero_point[0] as i32;
        let mut data = Vec::with_capacity(self.shape.iter().product::<usize>());

        for &packed in &self.data {
            let q0 = ((packed as u8) & 0x0F) as i32;
            let q1 = ((packed as u8) >> 4) as i32;

            data.push((q0 - zero_point) as f64 * scale);
            data.push((q1 - zero_point) as f64 * scale);
        }

        // Trim to original size (the last byte may contain a padding nibble)
        let total: usize = self.shape.iter().product();
        data.truncate(total);

        DenseTensor::new(data, self.shape.clone())
    }

    /// Get quantized data
    pub fn quantized_data(&self) -> &[i8] {
        &self.data
    }

    /// Get scale
    pub fn scale(&self) -> f64 {
        self.scale[0]
    }

    /// Get memory size in bytes
    pub fn memory_bytes(&self) -> usize {
        let total_elements = self.shape.iter().product::<usize>();
        match self.config.dtype {
            QuantDtype::INT8 => total_elements, // 1 byte per element
            QuantDtype::INT4 => total_elements.div_ceil(2), // Packed: 2 elements per byte
            QuantDtype::F32 => total_elements * 4, // 4 bytes per element
        }
    }

    /// Get compression ratio compared to F32
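    ///
    /// For example, a `[10, 10]` tensor takes 400 bytes in F32; quantized to
    /// INT8 it takes 100 bytes, a ratio of 4.0 (8.0 for packed INT4).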
    pub fn compression_ratio(&self) -> f64 {
        let original_bytes = self.shape.iter().product::<usize>() * 4; // F32 = 4 bytes
        original_bytes as f64 / self.memory_bytes() as f64
    }
}

/// Quantized matrix multiplication
pub struct QuantizedMatMul;

impl QuantizedMatMul {
    /// Multiply quantized matrix with quantized matrix
    ///
    /// # Arguments
    /// * `a` - Quantized matrix [M, K]
    /// * `b` - Quantized matrix [K, N]
    ///
    /// # Returns
    /// Dequantized result [M, N]
    pub fn matmul(a: &QuantizedTensor, b: &QuantizedTensor) -> DenseTensor {
        // Use pure INT8 GEMM for better performance
        Self::gemm_int8(a, b)
    }

    /// Multiply quantized matrix with dense matrix
    ///
    /// # Arguments
    /// * `a` - Quantized matrix [M, K]
    /// * `b` - Dense matrix [K, N]
    ///
    /// # Returns
    /// Dequantized result [M, N]
    pub fn matmul_qd(a: &QuantizedTensor, b: &DenseTensor) -> DenseTensor {
        // Quantize b temporarily and use INT8 GEMM
        let b_q = QuantizedTensor::from_tensor(b, QuantizationConfig::int8());
        Self::gemm_int8(a, &b_q)
    }

    /// Multiply dense matrix with quantized matrix
    ///
    /// # Arguments
    /// * `a` - Dense matrix [M, K]
    /// * `b` - Quantized matrix [K, N]
    ///
    /// # Returns
    /// Dequantized result [M, N]
    pub fn matmul_dq(a: &DenseTensor, b: &QuantizedTensor) -> DenseTensor {
        // Quantize a temporarily and use INT8 GEMM
        let a_q = QuantizedTensor::from_tensor(a, QuantizationConfig::int8());
        Self::gemm_int8(&a_q, b)
    }

    /// Pure INT8 GEMM implementation (no dequantization during computation)
    ///
    /// This is the performance-critical path that avoids dequantizing
    /// until the final result, enabling potential SIMD optimizations.
    ///
    /// # Algorithm
    /// For C = A @ B where A, B are INT8:
    /// 1. Compute INT32 accumulator: acc = sum(a_ik * b_kj)
    /// 2. Dequantize: C_ij = acc_ij * scale_a * scale_b
    ///
    /// # Arguments
    /// * `a` - Quantized matrix [M, K]
    /// * `b` - Quantized matrix [K, N]
    ///
    /// # Returns
    /// Dequantized result [M, N]
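    ///
    /// As a numeric sketch with per-tensor scales: 0.5 quantized at
    /// `scale = 1/127` becomes `q = 64`; multiplying two such values gives
    /// `acc = 64 * 64 = 4096`, which dequantizes to
    /// `4096 * (1/127) * (1/127) ≈ 0.254 ≈ 0.5 * 0.5`.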
    pub fn gemm_int8(a: &QuantizedTensor, b: &QuantizedTensor) -> DenseTensor {
        let m = a.shape[0];
        let k = a.shape[1];
        let n = b.shape[1];

        assert_eq!(a.shape[1], b.shape[0], "Inner dimensions must match");

        // Per-output scales for dequantization. Per-channel scales can only be
        // factored out of the INT32 accumulator when they vary along an output
        // dimension (rows of A, axis 0, or columns of B, axis 1); scales along
        // the reduction dimension are approximated by their average.
        let avg = |s: &[f64]| s.iter().sum::<f64>() / s.len() as f64;
        let row_scales: Vec<f64> = match (&a.channel_scales, a.config.axis) {
            (Some(s), Some(0)) if s.len() == m => s.clone(),
            (Some(s), _) => vec![avg(s.as_slice()); m],
            (None, _) => vec![a.scale[0]; m],
        };
        let col_scales: Vec<f64> = match (&b.channel_scales, b.config.axis) {
            (Some(s), Some(1)) if s.len() == n => s.clone(),
            (Some(s), _) => vec![avg(s.as_slice()); n],
            (None, _) => vec![b.scale[0]; n],
        };

        // INT8 GEMM kernel: compute INT32 accumulators
        let mut result = Vec::with_capacity(m * n);

        for i in 0..m {
            for j in 0..n {
                let mut acc: i32 = 0;

                // Dot product in INT8, accumulate in INT32
                for p in 0..k {
                    let a_val = a.data[i * k + p];
                    let b_val = b.data[p * n + j];
                    acc += (a_val as i32) * (b_val as i32);
                }

                // Dequantize the accumulator
                result.push(acc as f64 * row_scales[i] * col_scales[j]);
            }
        }

        DenseTensor::new(result, vec![m, n])
    }

    /// Optimized INT8 GEMM with loop unrolling and better cache locality
    ///
    /// This version uses:
    /// - Loop unrolling (4x) for better ILP
    /// - Row-major access pattern for better cache utilization
    ///
    /// Only per-tensor scales are applied here (`a.scale[0] * b.scale[0]`);
    /// per-channel scales are ignored.
    ///
    /// # Arguments
    /// * `a` - Quantized matrix [M, K]
    /// * `b` - Quantized matrix [K, N]
    ///
    /// # Returns
    /// Dequantized result [M, N]
    pub fn gemm_int8_optimized(a: &QuantizedTensor, b: &QuantizedTensor) -> DenseTensor {
        let m = a.shape[0];
        let k = a.shape[1];
        let n = b.shape[1];

        assert_eq!(a.shape[1], b.shape[0], "Inner dimensions must match");

        // Combined scale
        let scale = a.scale[0] * b.scale[0];

        let mut result = vec![0.0f64; m * n];

        // Block processing for better cache utilization
        const BLOCK_SIZE: usize = 32;

        for i_block in (0..m).step_by(BLOCK_SIZE) {
            for j_block in (0..n).step_by(BLOCK_SIZE) {
                let i_end = (i_block + BLOCK_SIZE).min(m);
                let j_end = (j_block + BLOCK_SIZE).min(n);

                for p in 0..k {
                    // Load and replicate a[p] for this row block
                    for i in i_block..i_end {
                        let a_val = a.data[i * k + p] as i32;

                        // Process b row with loop unrolling
                        let mut j = j_block;
                        while j + 4 <= j_end {
                            let b0 = b.data[p * n + j] as i32;
                            let b1 = b.data[p * n + j + 1] as i32;
                            let b2 = b.data[p * n + j + 2] as i32;
                            let b3 = b.data[p * n + j + 3] as i32;

                            // Accumulate (will dequantize later)
                            // Note: We're storing f64 directly for simplicity
                            // A production implementation would use INT32 accumulators
                            result[i * n + j] += (a_val * b0) as f64;
                            result[i * n + j + 1] += (a_val * b1) as f64;
                            result[i * n + j + 2] += (a_val * b2) as f64;
                            result[i * n + j + 3] += (a_val * b3) as f64;

                            j += 4;
                        }

                        // Handle remainder
                        while j < j_end {
                            let b_val = b.data[p * n + j] as i32;
                            result[i * n + j] += (a_val * b_val) as f64;
                            j += 1;
                        }
                    }
                }
            }
        }

        // Final dequantization
        for val in &mut result {
            *val *= scale;
        }

        DenseTensor::new(result, vec![m, n])
    }
}

/// Quantization utilities for model weights
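///
/// A minimal usage sketch (imports omitted; assumes the
/// `[in_features, out_features]` weight layout implied by
/// `quantize_linear_weights` quantizing along axis 1):
///
/// ```ignore
/// let weights = DenseTensor::new(vec![0.5; 12], vec![3, 4]);
/// let q = weight_quantization::quantize_linear_weights(&weights);
/// assert_eq!(q.config.dtype, QuantDtype::INT8);
/// assert_eq!(q.channel_scales.as_ref().unwrap().len(), 4);
/// ```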
pub mod weight_quantization {
    use super::*;

    /// Quantize model weights to INT8
    pub fn quantize_weights(weights: &DenseTensor) -> QuantizedTensor {
        QuantizedTensor::from_tensor(weights, QuantizationConfig::int8())
    }

    /// Quantize model weights with per-channel quantization
    pub fn quantize_weights_per_channel(weights: &DenseTensor, axis: usize) -> QuantizedTensor {
        QuantizedTensor::from_tensor(weights, QuantizationConfig::per_channel_int8(axis))
    }

    /// Quantize embedding weights
    pub fn quantize_embeddings(embeddings: &DenseTensor) -> QuantizedTensor {
        // Embeddings often benefit from per-row quantization
        QuantizedTensor::from_tensor(embeddings, QuantizationConfig::per_channel_int8(0))
    }

    /// Quantize linear layer weights (output channel quantization)
    pub fn quantize_linear_weights(weights: &DenseTensor) -> QuantizedTensor {
        // For linear layers, per-output-channel quantization is common
        QuantizedTensor::from_tensor(weights, QuantizationConfig::per_channel_int8(1))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_int8_quantization() {
        let tensor = DenseTensor::new(vec![0.0, 0.25, 0.5, 0.75, 1.0], vec![1, 5]);
        let config = QuantizationConfig::int8();

        let quantized = QuantizedTensor::from_tensor(&tensor, config);

        assert_eq!(quantized.shape, vec![1, 5]);
        assert_eq!(quantized.data.len(), 5);

        // Dequantize and check error
        let dequantized = quantized.dequantize();
        let original = tensor.data();
        let reconstructed = dequantized.data();

        for (orig, recon) in original.iter().zip(reconstructed.iter()) {
            // INT8 quantization error should be within 1/255 of the range
            assert!((orig - recon).abs() < 0.1, "Quantization error too large: orig={}, recon={}", orig, recon);
        }
    }

    #[test]
    fn test_int8_per_channel_quantization() {
        let tensor = DenseTensor::new(vec![0.0, 1.0, 2.0, 10.0, 20.0, 30.0], vec![2, 3]);
        let config = QuantizationConfig::per_channel_int8(1);

        let quantized = QuantizedTensor::from_tensor(&tensor, config);

        assert!(quantized.channel_scales.is_some());
        assert_eq!(quantized.channel_scales.unwrap().len(), 3);
    }

    #[test]
    fn test_int4_quantization() {
        let tensor = DenseTensor::new(vec![0.0, 0.5, 1.0], vec![1, 3]);
        let config = QuantizationConfig::int4();

        let quantized = QuantizedTensor::from_tensor(&tensor, config);

        // INT4 packs 2 values per byte, so 3 values need 2 bytes
        assert_eq!(quantized.data.len(), 2);
    }

    #[test]
    fn test_compression_ratio() {
        let tensor = DenseTensor::new(vec![0.0; 100], vec![10, 10]);

        let int8 = QuantizedTensor::from_tensor(&tensor, QuantizationConfig::int8());
        assert!((int8.compression_ratio() - 4.0).abs() < 0.1); // INT8 is 4x smaller than F32

        let int4 = QuantizedTensor::from_tensor(&tensor, QuantizationConfig::int4());
        assert!((int4.compression_ratio() - 8.0).abs() < 0.1); // INT4 is 8x smaller than F32
    }

    #[test]
    fn test_quantized_matmul() {
        let a = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = DenseTensor::new(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]);

        let a_q = QuantizedTensor::from_tensor(&a, QuantizationConfig::int8());
        let b_q = QuantizedTensor::from_tensor(&b, QuantizationConfig::int8());

        let result = QuantizedMatMul::matmul(&a_q, &b_q);

        assert_eq!(result.shape(), &[2, 2]);
    }

    #[test]
    fn test_weight_quantization() {
        let weights = DenseTensor::new(vec![-1.0, -0.5, 0.0, 0.5, 1.0], vec![1, 5]);

        let quantized = weight_quantization::quantize_weights(&weights);

        assert_eq!(quantized.config.dtype, QuantDtype::INT8);

        let dequantized = quantized.dequantize();
        let original = weights.data();
        let reconstructed = dequantized.data();

        for (orig, recon) in original.iter().zip(reconstructed.iter()) {
            // Weight quantization error should be within acceptable range
            assert!((orig - recon).abs() < 0.15, "Weight quantization error too large: orig={}, recon={}", orig, recon);
        }
    }
}