Skip to main content

oxicuda_quant/qat/
observer.rs

1//! # Quantization Observers
2//!
3//! Observers calibrate the quantization range by accumulating statistics
4//! over a calibration dataset and then deriving scale / zero-point.
5//!
6//! | Observer           | Calibration strategy                       |
7//! |--------------------|--------------------------------------------|
8//! | `MinMaxObserver`   | Global min/max over all observed data      |
9//! | `MovingAvgObserver`| Exponential moving average of min/max      |
10//! | `HistogramObserver`| Histogram + min-MSE clipping range search  |
11
12use crate::error::{QuantError, QuantResult};
13
14// ─── Observer trait ───────────────────────────────────────────────────────────
15
16/// Common interface for all quantization observers.
17pub trait Observer {
18    /// Observe a batch of values.
19    fn observe(&mut self, data: &[f32]);
20
21    /// Compute `(scale, zero_point)` from accumulated statistics.
22    ///
23    /// # Errors
24    ///
25    /// Returns [`QuantError::CalibrationRequired`] if no data has been observed.
26    fn compute_params(&self) -> QuantResult<(f32, i32)>;
27
28    /// Reset all accumulated statistics.
29    fn reset(&mut self);
30
31    /// Whether any data has been observed.
32    fn is_calibrated(&self) -> bool;
33}
34
35// ─── Shared helpers ───────────────────────────────────────────────────────────
36
/// Symmetric scale for a signed `bits`-wide grid: `abs_max / q_max` with
/// `q_max = 2^(bits−1) − 1`.
///
/// `abs_max` is floored at `1e-8`, so an all-zero calibration still yields a
/// positive, finite scale. For `bits == 1` the signed formula gives `q_max = 0`,
/// which would divide by zero and return an infinite scale. `q_max` is therefore
/// floored at 1, which keeps the scale finite (grid `{−1, 0, 1} · scale`).
fn sym_scale(abs_max: f32, bits: u32) -> f32 {
    let q_max = ((1i32 << (bits - 1)) - 1).max(1) as f32;
    abs_max.max(1e-8) / q_max
}
41
/// Asymmetric (affine) scale and zero-point for an unsigned `bits`-wide grid
/// `[0, 2^bits − 1]`.
///
/// The observed range is first extended to contain `0.0`. This keeps zero exactly
/// representable, and it guarantees that the zero-point lies in `[0, 2^bits − 1]`
/// without clamping. Otherwise an all-positive range such as `[2, 3]` would get
/// zero-point 0 with a scale that only spans a width of 1, and every value would be
/// squashed into `[0, 1]`. An all-negative range would be clipped in the same way
/// at the other end. The range width is floored at `1e-8` so the scale is
/// never zero.
fn asym_scale_zp(min_val: f32, max_val: f32, bits: u32) -> (f32, i32) {
    let q_range = ((1u32 << bits) - 1) as f32;
    let lo = min_val.min(0.0);
    let hi = max_val.max(0.0);
    let scale = (hi - lo).max(1e-8) / q_range;
    // With `lo <= 0` this is already within range; the clamp only guards rounding.
    let zp = (-lo / scale).round().clamp(0.0, q_range) as i32;
    (scale, zp)
}
49
50// ─── MinMaxObserver ──────────────────────────────────────────────────────────
51
/// Observer that records the extreme values seen across every calibration batch.
///
/// In **symmetric** mode the scale is `max(|min|, |max|) / q_max` and the
/// zero-point is fixed at `0`. In **asymmetric** mode an affine scale and
/// zero-point are derived from the observed range using `2^bits − 1` steps.
#[derive(Clone, Debug)]
pub struct MinMaxObserver {
    /// Smallest finite value observed so far (`+∞` before any data).
    pub min_val: f32,
    /// Largest finite value observed so far (`−∞` before any data).
    pub max_val: f32,
    /// Bit-width of the target integer grid.
    pub bits: u32,
    /// `true` for a zero-point pinned at 0, `false` for an affine mapping.
    pub symmetric: bool,
}
67
68impl MinMaxObserver {
69    /// Create a new MinMaxObserver.
70    ///
71    /// # Panics
72    ///
73    /// Panics if `bits` is 0 or > 16.
74    #[must_use]
75    pub fn new(bits: u32, symmetric: bool) -> Self {
76        assert!(bits > 0 && bits <= 16, "bits must be in [1, 16]");
77        Self {
78            min_val: f32::INFINITY,
79            max_val: f32::NEG_INFINITY,
80            bits,
81            symmetric,
82        }
83    }
84}
85
86impl Observer for MinMaxObserver {
87    fn observe(&mut self, data: &[f32]) {
88        for &v in data {
89            if v.is_finite() {
90                if v < self.min_val {
91                    self.min_val = v;
92                }
93                if v > self.max_val {
94                    self.max_val = v;
95                }
96            }
97        }
98    }
99
100    fn compute_params(&self) -> QuantResult<(f32, i32)> {
101        if !self.is_calibrated() {
102            return Err(QuantError::CalibrationRequired("MinMaxObserver"));
103        }
104        if self.symmetric {
105            let abs_max = self.min_val.abs().max(self.max_val.abs());
106            Ok((sym_scale(abs_max, self.bits), 0))
107        } else {
108            Ok(asym_scale_zp(self.min_val, self.max_val, self.bits))
109        }
110    }
111
112    fn reset(&mut self) {
113        self.min_val = f32::INFINITY;
114        self.max_val = f32::NEG_INFINITY;
115    }
116
117    fn is_calibrated(&self) -> bool {
118        self.min_val.is_finite() && self.max_val.is_finite()
119    }
120}
121
122// ─── MovingAvgObserver ───────────────────────────────────────────────────────
123
/// Observer that smooths per-batch extremes with an exponential moving average.
///
/// The first usable batch seeds the statistics directly. Every later batch
/// updates them as
/// ```text
/// min_val ← momentum × min_val + (1 − momentum) × batch_min
/// max_val ← momentum × max_val + (1 − momentum) × batch_max
/// ```
#[derive(Clone, Debug)]
pub struct MovingAvgObserver {
    /// Smoothed minimum.
    pub min_val: f32,
    /// Smoothed maximum.
    pub max_val: f32,
    /// Weight kept from the previous statistics on each update (typically 0.9–0.99).
    pub momentum: f32,
    /// Bit-width of the target integer grid.
    pub bits: u32,
    /// `true` for a zero-point pinned at 0, `false` for an affine mapping.
    pub symmetric: bool,
    /// Set once a batch containing a finite value has been absorbed.
    initialized: bool,
}
145
146impl MovingAvgObserver {
147    /// Create a new MovingAvgObserver.
148    ///
149    /// # Panics
150    ///
151    /// Panics if `bits` is 0 or > 16 or if `momentum` is not in (0, 1).
152    #[must_use]
153    pub fn new(bits: u32, symmetric: bool, momentum: f32) -> Self {
154        assert!(bits > 0 && bits <= 16, "bits must be in [1, 16]");
155        assert!(
156            momentum > 0.0 && momentum < 1.0,
157            "momentum must be in (0, 1), got {momentum}"
158        );
159        Self {
160            min_val: 0.0,
161            max_val: 0.0,
162            momentum,
163            bits,
164            symmetric,
165            initialized: false,
166        }
167    }
168}
169
170impl Observer for MovingAvgObserver {
171    fn observe(&mut self, data: &[f32]) {
172        if data.is_empty() {
173            return;
174        }
175        let batch_min = data
176            .iter()
177            .copied()
178            .filter(|v| v.is_finite())
179            .fold(f32::INFINITY, f32::min);
180        let batch_max = data
181            .iter()
182            .copied()
183            .filter(|v| v.is_finite())
184            .fold(f32::NEG_INFINITY, f32::max);
185        if !batch_min.is_finite() || !batch_max.is_finite() {
186            return;
187        }
188        if !self.initialized {
189            self.min_val = batch_min;
190            self.max_val = batch_max;
191            self.initialized = true;
192        } else {
193            let m = self.momentum;
194            self.min_val = m * self.min_val + (1.0 - m) * batch_min;
195            self.max_val = m * self.max_val + (1.0 - m) * batch_max;
196        }
197    }
198
199    fn compute_params(&self) -> QuantResult<(f32, i32)> {
200        if !self.is_calibrated() {
201            return Err(QuantError::CalibrationRequired("MovingAvgObserver"));
202        }
203        if self.symmetric {
204            let abs_max = self.min_val.abs().max(self.max_val.abs());
205            Ok((sym_scale(abs_max, self.bits), 0))
206        } else {
207            Ok(asym_scale_zp(self.min_val, self.max_val, self.bits))
208        }
209    }
210
211    fn reset(&mut self) {
212        self.min_val = 0.0;
213        self.max_val = 0.0;
214        self.initialized = false;
215    }
216
217    fn is_calibrated(&self) -> bool {
218        self.initialized
219    }
220}
221
222// ─── HistogramObserver ───────────────────────────────────────────────────────
223
/// Observer that builds a fixed-size histogram and picks a clipping range by
/// minimising the estimated quantization MSE.
///
/// The histogram spans the `[min, max]` range of the finite values observed so far.
/// [`Observer::compute_params`] evaluates a set of quantile-based clipping
/// thresholds and keeps the one with the lowest estimated error.
#[derive(Clone, Debug)]
pub struct HistogramObserver {
    /// Per-bin sample counts.
    bins: Vec<u64>,
    /// Lower edge of the histogram.
    range_min: f32,
    /// Upper edge of the histogram.
    range_max: f32,
    /// Number of equal-width bins (the length of `bins`).
    n_bins: usize,
    /// Bit-width of the target integer grid.
    pub bits: u32,
    /// `true` for a zero-point pinned at 0, `false` for an affine mapping.
    pub symmetric: bool,
    /// Set once a batch containing a finite value has been absorbed.
    initialized: bool,
}
245
246impl HistogramObserver {
247    /// Create a new HistogramObserver with `n_bins` bins.
248    ///
249    /// # Panics
250    ///
251    /// Panics if `bits` is 0 or > 16 or `n_bins` is 0.
252    #[must_use]
253    pub fn new(bits: u32, symmetric: bool, n_bins: usize) -> Self {
254        assert!(bits > 0 && bits <= 16, "bits must be in [1, 16]");
255        assert!(n_bins > 0, "n_bins must be > 0");
256        Self {
257            bins: vec![0_u64; n_bins],
258            range_min: 0.0,
259            range_max: 0.0,
260            n_bins,
261            bits,
262            symmetric,
263            initialized: false,
264        }
265    }
266
267    /// Bin width of the current histogram.
268    fn bin_width(&self) -> f32 {
269        (self.range_max - self.range_min) / self.n_bins as f32
270    }
271
272    /// Estimate the quantization MSE for the clipping range `[lo, hi]`.
273    fn estimate_mse(&self, lo: f32, hi: f32) -> f32 {
274        let bw = self.bin_width();
275        let total: u64 = self.bins.iter().sum();
276        if total == 0 || (hi - lo).abs() < 1e-12 {
277            return f32::INFINITY;
278        }
279
280        let n_levels = ((1u32 << self.bits) - 1) as f32;
281        let step = (hi - lo) / n_levels;
282
283        let mut mse = 0.0_f32;
284        for (b, &cnt) in self.bins.iter().enumerate() {
285            if cnt == 0 {
286                continue;
287            }
288            let center = self.range_min + (b as f32 + 0.5) * bw;
289            let quant_val = if center <= lo {
290                lo
291            } else if center >= hi {
292                hi
293            } else {
294                let idx = ((center - lo) / step).round();
295                lo + idx * step
296            };
297            let err = center - quant_val;
298            mse += cnt as f32 * err * err;
299        }
300        mse / total as f32
301    }
302}
303
304impl Observer for HistogramObserver {
305    fn observe(&mut self, data: &[f32]) {
306        let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
307        if finite.is_empty() {
308            return;
309        }
310
311        let d_min = finite.iter().copied().fold(f32::INFINITY, f32::min);
312        let d_max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
313
314        if !self.initialized {
315            self.range_min = d_min;
316            self.range_max = d_max;
317            self.initialized = true;
318        } else {
319            // Expand range if needed (histogram bins are not re-bucketed for simplicity).
320            if d_min < self.range_min {
321                self.range_min = d_min;
322            }
323            if d_max > self.range_max {
324                self.range_max = d_max;
325            }
326        }
327
328        // Ensure non-trivial range.
329        if (self.range_max - self.range_min).abs() < 1e-8 {
330            self.range_max = self.range_min + 1e-8;
331        }
332
333        let bw = self.bin_width();
334        for &v in &finite {
335            let idx = ((v - self.range_min) / bw) as usize;
336            let idx = idx.min(self.n_bins - 1);
337            self.bins[idx] += 1;
338        }
339    }
340
341    fn compute_params(&self) -> QuantResult<(f32, i32)> {
342        if !self.is_calibrated() {
343            return Err(QuantError::CalibrationRequired("HistogramObserver"));
344        }
345
346        // Search over 20 percentile thresholds (0.5% to 100% of histogram range).
347        let n_search = 20_usize;
348        let mut best_mse = f32::INFINITY;
349        let mut best_lo = self.range_min;
350        let mut best_hi = self.range_max;
351
352        let total: u64 = self.bins.iter().sum();
353        if total == 0 {
354            return Err(QuantError::CalibrationRequired("HistogramObserver"));
355        }
356
357        // Find quantile boundaries.
358        let percentiles: Vec<f32> = (1..=n_search).map(|i| i as f32 / n_search as f32).collect();
359
360        for &pct in &percentiles {
361            let threshold = (pct * total as f32) as u64;
362            let mut cum = 0_u64;
363            let mut cut_bin = self.n_bins - 1;
364            for (b, &cnt) in self.bins.iter().enumerate() {
365                cum += cnt;
366                if cum >= threshold {
367                    cut_bin = b;
368                    break;
369                }
370            }
371            let bw = self.bin_width();
372            let hi = self.range_min + (cut_bin as f32 + 1.0) * bw;
373            let lo = if self.symmetric { -hi } else { self.range_min };
374
375            let mse = self.estimate_mse(lo, hi);
376            if mse < best_mse {
377                best_mse = mse;
378                best_lo = lo;
379                best_hi = hi;
380            }
381        }
382
383        if self.symmetric {
384            let abs_max = best_lo.abs().max(best_hi.abs());
385            Ok((sym_scale(abs_max, self.bits), 0))
386        } else {
387            Ok(asym_scale_zp(best_lo, best_hi, self.bits))
388        }
389    }
390
391    fn reset(&mut self) {
392        self.bins.fill(0);
393        self.range_min = 0.0;
394        self.range_max = 0.0;
395        self.initialized = false;
396    }
397
398    fn is_calibrated(&self) -> bool {
399        self.initialized
400    }
401}
402
403// ─── Tests ───────────────────────────────────────────────────────────────────
404
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// `compute_params` must refuse to run before any calibration data arrives.
    fn assert_needs_calibration(obs: &impl Observer) {
        assert!(matches!(
            obs.compute_params(),
            Err(QuantError::CalibrationRequired(_))
        ));
    }

    #[test]
    fn minmax_symmetric_scale() {
        let mut obs = MinMaxObserver::new(8, true);
        obs.observe(&[-2.0_f32, -1.0, 0.5, 2.0]);
        let (scale, zp) = obs.compute_params().unwrap();
        // Largest magnitude is 2.0 and q_max is 127.
        assert_abs_diff_eq!(scale, 2.0 / 127.0, epsilon = 1e-6);
        assert_eq!(zp, 0);
    }

    #[test]
    fn minmax_asymmetric_scale_zp() {
        let mut obs = MinMaxObserver::new(8, false);
        obs.observe(&[0.0_f32, 1.0, 2.0, 3.0]);
        let (scale, zp) = obs.compute_params().unwrap();
        // Range [0, 3] spread across 255 steps; 0 maps to level 0.
        assert_abs_diff_eq!(scale, 3.0 / 255.0, epsilon = 1e-5);
        assert_eq!(zp, 0);
    }

    #[test]
    fn minmax_calibration_required() {
        assert_needs_calibration(&MinMaxObserver::new(8, true));
    }

    #[test]
    fn minmax_reset() {
        let mut obs = MinMaxObserver::new(8, true);
        obs.observe(&[1.0_f32, 2.0]);
        obs.reset();
        assert!(!obs.is_calibrated());
    }

    #[test]
    fn moving_avg_first_batch_exact() {
        let mut obs = MovingAvgObserver::new(8, true, 0.9);
        // The first batch is taken verbatim: min = -1, max = 1.
        obs.observe(&[-1.0_f32, 1.0]);
        let (scale, zp) = obs.compute_params().unwrap();
        assert_abs_diff_eq!(scale, 1.0 / 127.0, epsilon = 1e-5);
        assert_eq!(zp, 0);
    }

    #[test]
    fn moving_avg_ema_update() {
        let mut obs = MovingAvgObserver::new(8, true, 0.9);
        obs.observe(&[2.0_f32, 2.0]);
        obs.observe(&[4.0_f32, 4.0]);
        // 0.9 · 2 + 0.1 · 4 = 2.2
        assert_abs_diff_eq!(obs.max_val, 2.2, epsilon = 1e-5);
    }

    #[test]
    fn moving_avg_calibration_required() {
        assert_needs_calibration(&MovingAvgObserver::new(8, true, 0.9));
    }

    #[test]
    fn histogram_observer_calibrates() {
        let mut obs = HistogramObserver::new(8, true, 256);
        let ramp: Vec<f32> = (0..1024).map(|i| i as f32 / 512.0 - 1.0).collect();
        obs.observe(&ramp);
        assert!(obs.is_calibrated());
        let (scale, zp) = obs.compute_params().unwrap();
        assert!(scale > 0.0, "scale must be positive: {scale}");
        assert_eq!(zp, 0, "symmetric: zp must be 0");
    }

    #[test]
    fn histogram_observer_reset() {
        let mut obs = HistogramObserver::new(8, true, 128);
        obs.observe(&[1.0_f32, 2.0]);
        obs.reset();
        assert!(!obs.is_calibrated());
    }

    #[test]
    fn histogram_observer_uncalibrated_error() {
        assert_needs_calibration(&HistogramObserver::new(8, true, 64));
    }
}