Skip to main content

oxicuda_quant/qat/
fake_quant.rs

1//! # Fake Quantization with Straight-Through Estimator (STE)
2//!
3//! Fake quantization simulates the effect of quantization during training by
4//! applying a quantize-then-dequantize operation in the forward pass while
5//! passing gradients through unchanged (STE) where the value is within the
6//! representable range.
7//!
8//! ## Forward pass
9//!
10//! ```text
11//! q  = clamp(round(x / scale + zp), q_min, q_max)
12//! x̂  = (q − zp) × scale
13//! ```
14//!
15//! ## Backward pass (STE)
16//!
17//! ```text
18//! ∂L/∂x = ∂L/∂x̂  if  q_min_val ≤ x ≤ q_max_val
19//!        = 0       otherwise (clipped region)
20//! ```
21//!
22//! where `q_min_val = (q_min − zp) × scale`, `q_max_val = (q_max − zp) × scale`.
23
24use crate::error::{QuantError, QuantResult};
25
26// ─── FakeQuantize ─────────────────────────────────────────────────────────────
27
/// Fake quantization operator for quantization-aware training (QAT).
///
/// Maintains the current scale and zero-point that are updated during
/// calibration / training via an associated observer.
///
/// All fields are public, so the checks performed by [`FakeQuantize::new`]
/// and [`FakeQuantize::update_params`] are not re-applied when a field is
/// assigned directly; callers that do so must keep `bits` within `1..=16`
/// and `scale` finite and strictly positive themselves.
#[derive(Debug, Clone)]
pub struct FakeQuantize {
    /// Quantization bit-width. [`FakeQuantize::new`] accepts `1..=16`.
    pub bits: u32,
    /// Whether to use symmetric quantization (zp = 0).
    ///
    /// Selects the integer range returned by `quant_range`: signed
    /// `[-2^(bits-1), 2^(bits-1) - 1]` when `true`, unsigned
    /// `[0, 2^bits - 1]` when `false`. Note that the constructor does not
    /// force `zero_point` to 0 in symmetric mode; it stores whatever it is given.
    pub symmetric: bool,
    /// Current quantization scale (must be > 0).
    pub scale: f32,
    /// Current zero-point, i.e. the integer code that represents float 0.0
    /// in the `q = round(x / scale + zp)` mapping used by `forward`.
    pub zero_point: i32,
    /// Whether fake quantization is enabled.
    /// When disabled, `forward` returns the input unchanged and `backward`
    /// returns the incoming gradient unchanged.
    pub enabled: bool,
}
46
47impl FakeQuantize {
48    /// Create a new fake quantizer with the given scale and zero-point.
49    ///
50    /// # Errors
51    ///
52    /// * [`QuantError::InvalidBitWidth`] — `bits` is 0 or > 16.
53    /// * [`QuantError::InvalidScale`]   — `scale` is ≤ 0 or non-finite.
54    pub fn new(bits: u32, symmetric: bool, scale: f32, zero_point: i32) -> QuantResult<Self> {
55        if bits == 0 || bits > 16 {
56            return Err(QuantError::InvalidBitWidth { bits });
57        }
58        if !scale.is_finite() || scale <= 0.0 {
59            return Err(QuantError::InvalidScale { scale });
60        }
61        Ok(Self {
62            bits,
63            symmetric,
64            scale,
65            zero_point,
66            enabled: true,
67        })
68    }
69
70    /// Create with default scale=1.0, zp=0 for the given bit-width.
71    ///
72    /// # Errors
73    ///
74    /// * [`QuantError::InvalidBitWidth`] — `bits` is 0 or > 16.
75    pub fn with_defaults(bits: u32, symmetric: bool) -> QuantResult<Self> {
76        Self::new(bits, symmetric, 1.0, 0)
77    }
78
79    /// Update scale and zero-point (e.g., from an observer).
80    ///
81    /// # Errors
82    ///
83    /// * [`QuantError::InvalidScale`] — `scale` is ≤ 0 or non-finite.
84    pub fn update_params(&mut self, scale: f32, zero_point: i32) -> QuantResult<()> {
85        if !scale.is_finite() || scale <= 0.0 {
86            return Err(QuantError::InvalidScale { scale });
87        }
88        self.scale = scale;
89        self.zero_point = zero_point;
90        Ok(())
91    }
92
93    /// Integer quantization bounds [q_min, q_max].
94    #[must_use]
95    pub fn quant_range(&self) -> (i32, i32) {
96        if self.symmetric {
97            let half = 1i32 << (self.bits - 1);
98            (-half, half - 1)
99        } else {
100            (0i32, (1i32 << self.bits) - 1)
101        }
102    }
103
104    /// Float clipping bounds `[x_min, x_max]` corresponding to the integer range.
105    #[must_use]
106    pub fn float_range(&self) -> (f32, f32) {
107        let (q_min, q_max) = self.quant_range();
108        let zp = self.zero_point as f32;
109        let lo = (q_min as f32 - zp) * self.scale;
110        let hi = (q_max as f32 - zp) * self.scale;
111        (lo, hi)
112    }
113
114    /// Forward pass: quantize-then-dequantize.
115    ///
116    /// If `enabled = false`, returns the input unchanged.
117    #[must_use]
118    pub fn forward(&self, x: &[f32]) -> Vec<f32> {
119        if !self.enabled {
120            return x.to_vec();
121        }
122        let (q_min, q_max) = self.quant_range();
123        let zp = self.zero_point as f32;
124        x.iter()
125            .map(|&v| {
126                let q = (v / self.scale + zp)
127                    .round()
128                    .clamp(q_min as f32, q_max as f32);
129                (q - zp) * self.scale
130            })
131            .collect()
132    }
133
134    /// Backward pass (Straight-Through Estimator).
135    ///
136    /// Passes `grad_output` through where `x` is inside the representable
137    /// float range; zeros the gradient where `x` is clipped.
138    ///
139    /// # Errors
140    ///
141    /// * [`QuantError::DimensionMismatch`] — `grad_output` and `x` lengths differ.
142    pub fn backward(&self, grad_output: &[f32], x: &[f32]) -> QuantResult<Vec<f32>> {
143        if grad_output.len() != x.len() {
144            return Err(QuantError::DimensionMismatch {
145                expected: x.len(),
146                got: grad_output.len(),
147            });
148        }
149        if !self.enabled {
150            return Ok(grad_output.to_vec());
151        }
152        let (x_min, x_max) = self.float_range();
153        let grad = grad_output
154            .iter()
155            .zip(x.iter())
156            .map(|(&g, &v)| if v >= x_min && v <= x_max { g } else { 0.0 })
157            .collect();
158        Ok(grad)
159    }
160
161    /// Estimate quantization noise (MSE between input and fake-quantized output).
162    ///
163    /// Useful for measuring quantization error at the current scale/zp.
164    #[must_use]
165    pub fn quantization_noise(&self, x: &[f32]) -> f32 {
166        if x.is_empty() {
167            return 0.0;
168        }
169        let fq = self.forward(x);
170        let mse = x
171            .iter()
172            .zip(fq.iter())
173            .map(|(a, b)| (a - b).powi(2))
174            .sum::<f32>();
175        mse / x.len() as f32
176    }
177}
178
179// ─── Tests ───────────────────────────────────────────────────────────────────
180
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn forward_quantize_dequantize_int8() {
        let fq = FakeQuantize::new(8, true, 1.0 / 127.0, 0).unwrap();
        // Input 1.0 → q = 127 → dequant = 127 / 127 ≈ 1.0
        let out = fq.forward(&[1.0_f32]);
        assert_abs_diff_eq!(out[0], 1.0, epsilon = 0.01);
    }

    #[test]
    fn forward_passthrough_when_disabled() {
        let mut fq = FakeQuantize::new(8, true, 0.01, 0).unwrap();
        fq.enabled = false;
        let data = vec![1.5_f32, -2.3, 0.7];
        let out = fq.forward(&data);
        assert_eq!(out, data);
    }

    #[test]
    fn backward_ste_passthrough() {
        let fq = FakeQuantize::new(8, true, 1.0 / 127.0, 0).unwrap();
        let x = vec![0.5_f32, -0.5];
        let grad = vec![1.0_f32, -1.0];
        let ste = fq.backward(&grad, &x).unwrap();
        // x is within the float range [-128/127, 1]: gradient passed through unchanged.
        assert_abs_diff_eq!(ste[0], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(ste[1], -1.0, epsilon = 1e-6);
    }

    #[test]
    fn backward_ste_zero_outside_range() {
        let fq = FakeQuantize::new(8, true, 1.0 / 127.0, 0).unwrap();
        // x = ±2.0 is outside the float range [-128/127, 127/127].
        let x = vec![2.0_f32, -2.0];
        let grad = vec![1.0_f32, 1.0];
        let ste = fq.backward(&grad, &x).unwrap();
        assert_abs_diff_eq!(ste[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(ste[1], 0.0, epsilon = 1e-6);
    }

    #[test]
    fn backward_dimension_mismatch_error() {
        let fq = FakeQuantize::new(8, true, 0.01, 0).unwrap();
        let x = vec![0.5_f32; 3];
        let grad = vec![1.0_f32; 4];
        assert!(matches!(
            fq.backward(&grad, &x),
            Err(QuantError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn invalid_scale_error() {
        assert!(matches!(
            FakeQuantize::new(8, true, -0.01, 0),
            Err(QuantError::InvalidScale { .. })
        ));
        assert!(matches!(
            FakeQuantize::new(8, true, 0.0, 0),
            Err(QuantError::InvalidScale { .. })
        ));
    }

    #[test]
    fn invalid_bit_width_error() {
        assert!(matches!(
            FakeQuantize::new(0, true, 0.01, 0),
            Err(QuantError::InvalidBitWidth { bits: 0 })
        ));
        assert!(matches!(
            FakeQuantize::new(17, true, 0.01, 0),
            Err(QuantError::InvalidBitWidth { bits: 17 })
        ));
    }

    #[test]
    fn quant_range_int8_symmetric() {
        let fq = FakeQuantize::new(8, true, 0.01, 0).unwrap();
        assert_eq!(fq.quant_range(), (-128, 127));
    }

    #[test]
    fn quant_range_int4_asymmetric() {
        let fq = FakeQuantize::new(4, false, 0.01, 0).unwrap();
        assert_eq!(fq.quant_range(), (0, 15));
    }

    #[test]
    fn quantization_noise_zero_for_fine_scale() {
        // With scale = 1/127 and INT8, values in [-0.5, 0.5) should have small noise.
        let fq = FakeQuantize::new(8, true, 1.0 / 127.0, 0).unwrap();
        let data: Vec<f32> = (0..128).map(|i| i as f32 / 128.0 - 0.5).collect();
        let noise = fq.quantization_noise(&data);
        assert!(noise < 1e-5, "noise too high: {noise}");
    }

    #[test]
    fn update_params_works() {
        let mut fq = FakeQuantize::with_defaults(8, true).unwrap();
        fq.update_params(0.5, 0).unwrap();
        assert_abs_diff_eq!(fq.scale, 0.5, epsilon = 1e-7);
    }

    #[test]
    fn asymmetric_forward_with_zero_zp() {
        // scale=1/15, zp=0 for [0, 1] range with INT4 asymmetric
        let fq = FakeQuantize::new(4, false, 1.0 / 15.0, 0).unwrap();
        let out = fq.forward(&[0.0_f32, 0.5, 1.0]);
        // 0 → q=0, 0.5 → q≈7 or 8 → ≈0.5, 1.0 → q=15 → 1.0
        assert_abs_diff_eq!(out[0], 0.0, epsilon = 0.001);
        assert!(out[1] > 0.4 && out[1] < 0.6, "midpoint: {}", out[1]);
        assert_abs_diff_eq!(out[2], 1.0, epsilon = 0.001);
    }

    #[test]
    fn asymmetric_forward_with_nonzero_zp() {
        // INT4 asymmetric, scale=1, zp=3: codes 0..=15 map to floats -3..=12.
        let fq = FakeQuantize::new(4, false, 1.0, 3).unwrap();
        assert_eq!(fq.float_range(), (-3.0, 12.0));
        let out = fq.forward(&[-10.0_f32, 2.4, 20.0]);
        // -10 → q clamps to 0 → -3; 2.4 → q=5 → 2; 20 → q clamps to 15 → 12
        assert_abs_diff_eq!(out[0], -3.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[1], 2.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[2], 12.0, epsilon = 1e-6);
    }
}