// Source path: entrenar/quant/fake_quantize/quantize.rs

//! Fake quantization (quantize → dequantize in floating point) with a
//! Straight-Through Estimator (STE) backward pass.

use crate::Tensor;

use super::config::FakeQuantConfig;

7/// Fake quantization operation with Straight-Through Estimator (STE)
8///
9/// This struct holds the state for fake quantization including learned
10/// or calibrated scale and zero_point parameters.
11#[derive(Clone, Debug)]
12pub struct FakeQuantize {
13    /// Quantization configuration
14    pub config: FakeQuantConfig,
15    /// Scale factor for quantization
16    pub scale: f32,
17    /// Zero point for asymmetric quantization
18    pub zero_point: i32,
19    /// Whether scale has been initialized
20    pub initialized: bool,
21}
22
23impl FakeQuantize {
24    /// Create new fake quantization operation
25    pub fn new(config: FakeQuantConfig) -> Self {
26        Self { config, scale: 1.0, zero_point: 0, initialized: false }
27    }
28
29    /// Create with 4-bit symmetric quantization
30    pub fn q4() -> Self {
31        Self::new(FakeQuantConfig::q4_symmetric())
32    }
33
34    /// Create with 8-bit symmetric quantization
35    pub fn q8() -> Self {
36        Self::new(FakeQuantConfig::q8_symmetric())
37    }
38
39    /// Initialize scale from data (min-max calibration)
40    ///
41    /// For symmetric: scale = max(|min|, |max|) / qmax
42    /// For asymmetric: scale = (max - min) / (qmax - qmin)
43    pub fn calibrate(&mut self, data: &[f32]) {
44        if data.is_empty() {
45            return;
46        }
47
48        let min_val = data.iter().copied().fold(f32::INFINITY, f32::min);
49        let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
50
51        if self.config.symmetric {
52            // Symmetric: scale from max absolute value
53            let max_abs = min_val.abs().max(max_val.abs());
54            self.scale = max_abs / self.config.qmax as f32;
55            self.zero_point = 0;
56        } else {
57            // Asymmetric: scale from range
58            self.scale = (max_val - min_val) / (self.config.qmax - self.config.qmin) as f32;
59            self.zero_point = (self.config.qmin as f32 - min_val / self.scale).round() as i32;
60            self.zero_point = self.zero_point.clamp(self.config.qmin, self.config.qmax);
61        }
62
63        // Prevent division by zero
64        if self.scale < 1e-10 {
65            self.scale = 1e-10;
66        }
67
68        self.initialized = true;
69    }
70
71    /// Forward pass: fake quantize (quantize → dequantize)
72    ///
73    /// Simulates quantization effects while keeping values in floating point.
74    /// Output = dequantize(quantize(input))
75    pub fn forward(&self, input: &Tensor) -> Tensor {
76        let data: Vec<f32> = input.data().iter().map(|&x| self.fake_quantize_value(x)).collect();
77
78        Tensor::new(ndarray::arr1(&data), input.requires_grad())
79    }
80
81    /// Forward pass with auto-calibration
82    ///
83    /// If not initialized, calibrates from input data first.
84    pub fn forward_with_calibration(&mut self, input: &Tensor) -> Tensor {
85        if !self.initialized {
86            self.calibrate(input.data().as_slice().unwrap_or(&[]));
87        }
88        self.forward(input)
89    }
90
91    /// Backward pass: Straight-Through Estimator (STE)
92    ///
93    /// The gradient passes through unchanged:
94    /// ∂L/∂x = ∂L/∂y (where y = fake_quantize(x))
95    ///
96    /// This allows gradients to flow during training despite the
97    /// non-differentiable quantization operation.
98    pub fn backward(&self, grad_output: &Tensor) -> Tensor {
99        // STE: gradient passes through unchanged
100        grad_output.clone()
101    }
102
103    /// Backward pass with gradient clipping (clamped STE)
104    ///
105    /// Clips gradients to zero outside the quantization range.
106    /// This can improve training stability.
107    pub fn backward_clamped(&self, grad_output: &Tensor, input: &Tensor) -> Tensor {
108        let qmin_float = self.config.qmin as f32 * self.scale;
109        let qmax_float = self.config.qmax as f32 * self.scale;
110
111        let data: Vec<f32> = grad_output
112            .data()
113            .iter()
114            .zip(input.data().iter())
115            .map(|(&grad, &x)| {
116                // Zero gradient outside quantization range
117                if x < qmin_float || x > qmax_float {
118                    0.0
119                } else {
120                    grad
121                }
122            })
123            .collect();
124
125        Tensor::new(ndarray::arr1(&data), grad_output.requires_grad())
126    }
127
128    /// Fake quantize a single value
129    fn fake_quantize_value(&self, x: f32) -> f32 {
130        // Quantize
131        let q = if self.config.symmetric {
132            (x / self.scale).round().clamp(self.config.qmin as f32, self.config.qmax as f32) as i32
133        } else {
134            ((x / self.scale) + self.zero_point as f32)
135                .round()
136                .clamp(self.config.qmin as f32, self.config.qmax as f32) as i32
137        };
138
139        // Dequantize
140        if self.config.symmetric {
141            q as f32 * self.scale
142        } else {
143            (q - self.zero_point) as f32 * self.scale
144        }
145    }
146
147    /// Get the quantization scale
148    pub fn scale(&self) -> f32 {
149        self.scale
150    }
151
152    /// Get the zero point
153    pub fn zero_point(&self) -> i32 {
154        self.zero_point
155    }
156
157    /// Check if calibrated
158    pub fn is_initialized(&self) -> bool {
159        self.initialized
160    }
161
162    /// Get number of quantization levels
163    pub fn num_levels(&self) -> usize {
164        (self.config.qmax - self.config.qmin + 1) as usize
165    }
166}