Skip to main content

entrenar/quant/calibration/
calibrator.rs

//! PTQ (post-training quantization) calibrator implementation
//!
//! The main `Calibrator` struct for collecting statistics and computing
//! quantization parameters.

6use crate::Tensor;
7
8use super::helpers::rand_simple;
9use super::types::{CalibrationMethod, CalibrationResult};
10
11/// PTQ Calibrator for collecting statistics and computing quantization parameters
12#[derive(Clone, Debug)]
13pub struct Calibrator {
14    /// Calibration method
15    method: CalibrationMethod,
16    /// Whether quantization is symmetric
17    symmetric: bool,
18    /// Number of bits for quantization
19    bits: usize,
20    /// Running minimum (for moving average)
21    running_min: Option<f32>,
22    /// Running maximum (for moving average)
23    running_max: Option<f32>,
24    /// Collected samples (for percentile)
25    samples: Vec<f32>,
26    /// Maximum samples to collect (for percentile)
27    max_samples: usize,
28    /// Number of batches observed
29    num_batches: usize,
30}
31
32impl Calibrator {
33    /// Create new calibrator with min-max method
34    pub fn min_max(bits: usize, symmetric: bool) -> Self {
35        Self {
36            method: CalibrationMethod::MinMax,
37            symmetric,
38            bits,
39            running_min: None,
40            running_max: None,
41            samples: Vec::new(),
42            max_samples: 0,
43            num_batches: 0,
44        }
45    }
46
47    /// Create new calibrator with percentile method
48    ///
49    /// # Arguments
50    /// * `bits` - Number of quantization bits
51    /// * `symmetric` - Whether to use symmetric quantization
52    /// * `lower` - Lower percentile (e.g., 0.01 for 0.01%)
53    /// * `upper` - Upper percentile (e.g., 99.99 for 99.99%)
54    /// * `max_samples` - Maximum number of samples to collect
55    pub fn percentile(
56        bits: usize,
57        symmetric: bool,
58        lower: f32,
59        upper: f32,
60        max_samples: usize,
61    ) -> Self {
62        Self {
63            method: CalibrationMethod::Percentile { lower, upper },
64            symmetric,
65            bits,
66            running_min: None,
67            running_max: None,
68            samples: Vec::with_capacity(max_samples.min(10000)),
69            max_samples,
70            num_batches: 0,
71        }
72    }
73
74    /// Create new calibrator with moving average method
75    pub fn moving_average(bits: usize, symmetric: bool, momentum: f32) -> Self {
76        Self {
77            method: CalibrationMethod::MovingAverage { momentum },
78            symmetric,
79            bits,
80            running_min: None,
81            running_max: None,
82            samples: Vec::new(),
83            max_samples: 0,
84            num_batches: 0,
85        }
86    }
87
88    /// Observe a batch of data for calibration
89    pub fn observe(&mut self, data: &[f32]) {
90        if data.is_empty() {
91            return;
92        }
93
94        match &self.method {
95            CalibrationMethod::MinMax => {
96                self.observe_min_max(data);
97            }
98            CalibrationMethod::Percentile { .. } => {
99                self.observe_percentile(data);
100            }
101            CalibrationMethod::MovingAverage { momentum } => {
102                let momentum = *momentum;
103                self.observe_moving_average(data, momentum);
104            }
105        }
106
107        self.num_batches += 1;
108    }
109
110    /// Observe a tensor for calibration
111    pub fn observe_tensor(&mut self, tensor: &Tensor) {
112        if let Some(slice) = tensor.data().as_slice() {
113            self.observe(slice);
114        }
115    }
116
117    /// Observe multiple tensors
118    pub fn observe_tensors(&mut self, tensors: &[&Tensor]) {
119        for tensor in tensors {
120            self.observe_tensor(tensor);
121        }
122    }
123
124    /// Compute calibration result
125    pub fn compute(&self) -> CalibrationResult {
126        let (observed_min, observed_max) = match &self.method {
127            CalibrationMethod::MinMax | CalibrationMethod::MovingAverage { .. } => {
128                (self.running_min.unwrap_or(0.0), self.running_max.unwrap_or(0.0))
129            }
130            CalibrationMethod::Percentile { lower, upper } => {
131                self.compute_percentile_bounds(*lower, *upper)
132            }
133        };
134
135        let (scale, zero_point) = self.compute_scale_zero_point(observed_min, observed_max);
136
137        CalibrationResult {
138            scale,
139            zero_point,
140            observed_min,
141            observed_max,
142            method: self.method.clone(),
143        }
144    }
145
146    /// Get number of batches observed
147    pub fn num_batches(&self) -> usize {
148        self.num_batches
149    }
150
151    /// Get calibration method
152    pub fn method(&self) -> &CalibrationMethod {
153        &self.method
154    }
155
156    /// Check if any data has been observed
157    pub fn has_data(&self) -> bool {
158        self.num_batches > 0
159    }
160
161    /// Reset calibration state
162    pub fn reset(&mut self) {
163        self.running_min = None;
164        self.running_max = None;
165        self.samples.clear();
166        self.num_batches = 0;
167    }
168
169    // Internal methods
170
171    fn observe_min_max(&mut self, data: &[f32]) {
172        let batch_min = data.iter().copied().fold(f32::INFINITY, f32::min);
173        let batch_max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
174
175        self.running_min = Some(self.running_min.map_or(batch_min, |m| m.min(batch_min)));
176        self.running_max = Some(self.running_max.map_or(batch_max, |m| m.max(batch_max)));
177    }
178
179    fn observe_percentile(&mut self, data: &[f32]) {
180        // Collect samples (with reservoir sampling if needed)
181        if self.samples.len() < self.max_samples {
182            let remaining = self.max_samples - self.samples.len();
183            self.samples.extend(data.iter().take(remaining).copied());
184        } else {
185            // Reservoir sampling for samples beyond max_samples
186            let total_seen = self.num_batches * data.len() + data.len();
187            for (i, &val) in data.iter().enumerate() {
188                let j = rand_simple(total_seen + i);
189                if j < self.max_samples {
190                    self.samples[j] = val;
191                }
192            }
193        }
194
195        // Also track min/max for fallback
196        self.observe_min_max(data);
197    }
198
199    fn observe_moving_average(&mut self, data: &[f32], momentum: f32) {
200        let batch_min = data.iter().copied().fold(f32::INFINITY, f32::min);
201        let batch_max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
202
203        self.running_min = Some(
204            self.running_min.map_or(batch_min, |m| m * (1.0 - momentum) + batch_min * momentum),
205        );
206        self.running_max = Some(
207            self.running_max.map_or(batch_max, |m| m * (1.0 - momentum) + batch_max * momentum),
208        );
209    }
210
211    fn compute_percentile_bounds(&self, lower: f32, upper: f32) -> (f32, f32) {
212        if self.samples.is_empty() {
213            return (self.running_min.unwrap_or(0.0), self.running_max.unwrap_or(0.0));
214        }
215
216        let mut sorted = self.samples.clone();
217        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
218
219        let n = sorted.len();
220        let lower_idx = ((lower / 100.0) * n as f32) as usize;
221        let upper_idx = ((upper / 100.0) * n as f32).min((n - 1) as f32) as usize;
222
223        (sorted[lower_idx], sorted[upper_idx])
224    }
225
226    fn compute_scale_zero_point(&self, min_val: f32, max_val: f32) -> (f32, i32) {
227        let qmax = (1 << (self.bits - 1)) - 1;
228        let qmin = if self.symmetric { -qmax } else { 0 };
229        let qmax_full = if self.symmetric { qmax } else { (1 << self.bits) - 1 };
230
231        if self.symmetric {
232            // Symmetric: scale from max absolute value
233            let max_abs = min_val.abs().max(max_val.abs());
234            let scale = if max_abs < 1e-10 { 1e-10 } else { max_abs / qmax as f32 };
235            (scale, 0)
236        } else {
237            // Asymmetric: scale from range
238            let range = max_val - min_val;
239            let scale = if range < 1e-10 { 1e-10 } else { range / (qmax_full - qmin) as f32 };
240            let zero_point = (qmin as f32 - min_val / scale).round() as i32;
241            let zero_point = zero_point.clamp(qmin, qmax_full);
242            (scale, zero_point)
243        }
244    }
245}