1use crate::error::{QuantError, QuantResult};
13
/// A calibration observer: it watches floating-point values and derives
/// quantization parameters from what it has seen.
pub trait Observer {
    /// Updates the observer's statistics with a batch of values.
    ///
    /// The implementations in this module skip non-finite values (NaN, ±inf).
    fn observe(&mut self, data: &[f32]);

    /// Computes `(scale, zero_point)` from the statistics gathered so far.
    ///
    /// Symmetric configurations in this module always return a zero point of 0.
    ///
    /// # Errors
    ///
    /// The implementations in this module return
    /// [`QuantError::CalibrationRequired`] while no finite value has been
    /// observed (i.e. while [`Observer::is_calibrated`] is `false`).
    fn compute_params(&self) -> QuantResult<(f32, i32)>;

    /// Discards the gathered statistics and returns the observer to its
    /// uncalibrated state. In this module, configuration such as the bit width
    /// and symmetry flag is left untouched.
    fn reset(&mut self);

    /// Returns `true` once the observer holds statistics it can derive
    /// parameters from. For the implementations in this module, that means at
    /// least one finite value has been observed since construction or the last
    /// [`Observer::reset`].
    fn is_calibrated(&self) -> bool;
}
34
/// Scale for symmetric (zero point 0) quantization to a signed `bits`-wide integer.
///
/// Maps `abs_max` onto the largest positive code `2^(bits-1) - 1`. That code is 0
/// for `bits == 1`, which would make the scale infinite. It is therefore clamped
/// to at least 1, so the result is always finite. `abs_max` is floored at `1e-8`
/// so that all-zero data still yields a positive scale.
fn sym_scale(abs_max: f32, bits: u32) -> f32 {
    // `saturating_sub` avoids a u32 underflow panic should `bits == 0` ever reach here.
    let q_max = ((1i32 << bits.saturating_sub(1)) - 1).max(1) as f32;
    abs_max.max(1e-8) / q_max
}
41
/// Scale and zero point for asymmetric (affine) quantization to an unsigned
/// `bits`-wide integer in `[0, 2^bits - 1]`.
///
/// The observed range `[min_val, max_val]` is first widened to include `0.0`.
/// Without that, all-positive (or all-negative) data would need a zero point
/// below 0 (or above `2^bits - 1`). Clamping that zero point back into range
/// would shift the representable interval and clip the far end of the data.
/// For example, `[1, 3]` would only cover `[0, 2]`. Including zero also makes
/// `0.0` exactly representable.
///
/// Returns `(scale, zero_point)`, with `zero_point` in `[0, 2^bits - 1]`.
fn asym_scale_zp(min_val: f32, max_val: f32, bits: u32) -> (f32, i32) {
    let q_range = ((1u32 << bits) - 1) as f32;
    let lo = min_val.min(0.0);
    let hi = max_val.max(0.0);
    // Floor the width so a degenerate (all-zero) range still gives a positive scale.
    let range = (hi - lo).max(1e-8);
    let scale = range / q_range;
    // Because lo <= 0 <= hi, this already lies in [0, q_range] up to rounding.
    let zp = (-lo / scale).round().clamp(0.0, q_range) as i32;
    (scale, zp)
}
49
/// Observer that tracks the global minimum and maximum of every finite value
/// seen across all batches.
///
/// A single outlier therefore widens the range until the observer is reset.
#[derive(Debug, Clone)]
pub struct MinMaxObserver {
    /// Smallest finite value observed so far; `f32::INFINITY` before any data.
    pub min_val: f32,
    /// Largest finite value observed so far; `f32::NEG_INFINITY` before any data.
    pub max_val: f32,
    /// Target integer bit width (validated to `[1, 16]` by [`MinMaxObserver::new`]).
    pub bits: u32,
    /// `true` for symmetric quantization (zero point 0), `false` for affine.
    pub symmetric: bool,
}

impl MinMaxObserver {
    /// Creates an empty, uncalibrated observer for `bits`-wide quantization.
    ///
    /// # Panics
    ///
    /// Panics if `bits` is outside `[1, 16]`.
    #[must_use]
    pub fn new(bits: u32, symmetric: bool) -> Self {
        assert!((1..=16).contains(&bits), "bits must be in [1, 16]");
        // Start from the identity elements of min/max so the first finite sample wins.
        let (min_val, max_val) = (f32::INFINITY, f32::NEG_INFINITY);
        Self {
            symmetric,
            bits,
            max_val,
            min_val,
        }
    }
}
85
86impl Observer for MinMaxObserver {
87 fn observe(&mut self, data: &[f32]) {
88 for &v in data {
89 if v.is_finite() {
90 if v < self.min_val {
91 self.min_val = v;
92 }
93 if v > self.max_val {
94 self.max_val = v;
95 }
96 }
97 }
98 }
99
100 fn compute_params(&self) -> QuantResult<(f32, i32)> {
101 if !self.is_calibrated() {
102 return Err(QuantError::CalibrationRequired("MinMaxObserver"));
103 }
104 if self.symmetric {
105 let abs_max = self.min_val.abs().max(self.max_val.abs());
106 Ok((sym_scale(abs_max, self.bits), 0))
107 } else {
108 Ok(asym_scale_zp(self.min_val, self.max_val, self.bits))
109 }
110 }
111
112 fn reset(&mut self) {
113 self.min_val = f32::INFINITY;
114 self.max_val = f32::NEG_INFINITY;
115 }
116
117 fn is_calibrated(&self) -> bool {
118 self.min_val.is_finite() && self.max_val.is_finite()
119 }
120}
121
/// Observer that keeps an exponential moving average (EMA) of the per-batch
/// minimum and maximum.
///
/// Compared with a global min/max, occasional outlier batches have a damped,
/// decaying influence on the range.
#[derive(Debug, Clone)]
pub struct MovingAvgObserver {
    /// EMA of per-batch minima; `0.0` until the first batch with finite data.
    pub min_val: f32,
    /// EMA of per-batch maxima; `0.0` until the first batch with finite data.
    pub max_val: f32,
    /// Weight on the previous estimate in each update:
    /// `new = momentum * old + (1 - momentum) * batch`. Validated to `(0, 1)` by `new`.
    pub momentum: f32,
    /// Target integer bit width (validated to `[1, 16]` by [`MovingAvgObserver::new`]).
    pub bits: u32,
    /// `true` for symmetric quantization (zero point 0), `false` for affine.
    pub symmetric: bool,
    /// Set once a batch containing at least one finite value has been observed.
    initialized: bool,
}

impl MovingAvgObserver {
    /// Creates an empty, uncalibrated observer.
    ///
    /// # Panics
    ///
    /// Panics if `bits` is outside `[1, 16]`, or if `momentum` is not strictly
    /// between 0 and 1 (NaN included).
    #[must_use]
    pub fn new(bits: u32, symmetric: bool, momentum: f32) -> Self {
        assert!((1..=16).contains(&bits), "bits must be in [1, 16]");
        assert!(
            0.0 < momentum && momentum < 1.0,
            "momentum must be in (0, 1), got {momentum}"
        );
        Self {
            initialized: false,
            symmetric,
            bits,
            momentum,
            max_val: 0.0,
            min_val: 0.0,
        }
    }
}
169
170impl Observer for MovingAvgObserver {
171 fn observe(&mut self, data: &[f32]) {
172 if data.is_empty() {
173 return;
174 }
175 let batch_min = data
176 .iter()
177 .copied()
178 .filter(|v| v.is_finite())
179 .fold(f32::INFINITY, f32::min);
180 let batch_max = data
181 .iter()
182 .copied()
183 .filter(|v| v.is_finite())
184 .fold(f32::NEG_INFINITY, f32::max);
185 if !batch_min.is_finite() || !batch_max.is_finite() {
186 return;
187 }
188 if !self.initialized {
189 self.min_val = batch_min;
190 self.max_val = batch_max;
191 self.initialized = true;
192 } else {
193 let m = self.momentum;
194 self.min_val = m * self.min_val + (1.0 - m) * batch_min;
195 self.max_val = m * self.max_val + (1.0 - m) * batch_max;
196 }
197 }
198
199 fn compute_params(&self) -> QuantResult<(f32, i32)> {
200 if !self.is_calibrated() {
201 return Err(QuantError::CalibrationRequired("MovingAvgObserver"));
202 }
203 if self.symmetric {
204 let abs_max = self.min_val.abs().max(self.max_val.abs());
205 Ok((sym_scale(abs_max, self.bits), 0))
206 } else {
207 Ok(asym_scale_zp(self.min_val, self.max_val, self.bits))
208 }
209 }
210
211 fn reset(&mut self) {
212 self.min_val = 0.0;
213 self.max_val = 0.0;
214 self.initialized = false;
215 }
216
217 fn is_calibrated(&self) -> bool {
218 self.initialized
219 }
220}
221
/// Observer that accumulates a fixed-size histogram of the observed values.
/// Clipping ranges are chosen by minimising an estimated quantization error
/// over that histogram.
#[derive(Debug, Clone)]
pub struct HistogramObserver {
    /// Sample counts for `n_bins` equal-width bins spanning `[range_min, range_max]`.
    bins: Vec<u64>,
    /// Lower bound of the histogram range; `0.0` before the first observation.
    range_min: f32,
    /// Upper bound of the histogram range; `0.0` before the first observation.
    range_max: f32,
    /// Number of bins (equal to `bins.len()`).
    n_bins: usize,
    /// Target integer bit width (validated to `[1, 16]` by [`HistogramObserver::new`]).
    pub bits: u32,
    /// `true` for symmetric quantization (zero point 0), `false` for affine.
    pub symmetric: bool,
    /// Set once at least one finite value has been binned.
    initialized: bool,
}

impl HistogramObserver {
    /// Creates an empty, uncalibrated observer with `n_bins` zeroed bins.
    ///
    /// # Panics
    ///
    /// Panics if `bits` is outside `[1, 16]` or if `n_bins` is 0.
    #[must_use]
    pub fn new(bits: u32, symmetric: bool, n_bins: usize) -> Self {
        assert!((1..=16).contains(&bits), "bits must be in [1, 16]");
        assert!(n_bins > 0, "n_bins must be > 0");
        Self {
            initialized: false,
            symmetric,
            bits,
            n_bins,
            range_max: 0.0,
            range_min: 0.0,
            bins: vec![0; n_bins],
        }
    }

    /// Width of a single bin for the current range.
    fn bin_width(&self) -> f32 {
        let span = self.range_max - self.range_min;
        span / self.n_bins as f32
    }

    /// Estimated mean squared error of quantizing the histogram onto a uniform grid
    /// of `2^bits - 1` steps spanning `[lo, hi]`.
    ///
    /// Each bin is represented by its centre. Centres outside `[lo, hi]` are clipped
    /// to the nearer bound; the others are rounded to the nearest grid point. Returns
    /// `f32::INFINITY` for an empty histogram or when `lo` and `hi` (nearly) coincide.
    fn estimate_mse(&self, lo: f32, hi: f32) -> f32 {
        let total: u64 = self.bins.iter().sum();
        if total == 0 || (hi - lo).abs() < 1e-12 {
            return f32::INFINITY;
        }

        let bw = self.bin_width();
        let step = (hi - lo) / ((1u32 << self.bits) - 1) as f32;

        // Quantize-dequantize a single value onto the clipped grid.
        let snap = |x: f32| -> f32 {
            if x <= lo {
                lo
            } else if x >= hi {
                hi
            } else {
                lo + ((x - lo) / step).round() * step
            }
        };

        let weighted_sq_err = self
            .bins
            .iter()
            .enumerate()
            .filter(|&(_, &cnt)| cnt != 0)
            .fold(0.0_f32, |acc, (b, &cnt)| {
                let center = self.range_min + (b as f32 + 0.5) * bw;
                let err = center - snap(center);
                acc + cnt as f32 * err * err
            });
        weighted_sq_err / total as f32
    }
}
303
304impl Observer for HistogramObserver {
305 fn observe(&mut self, data: &[f32]) {
306 let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
307 if finite.is_empty() {
308 return;
309 }
310
311 let d_min = finite.iter().copied().fold(f32::INFINITY, f32::min);
312 let d_max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
313
314 if !self.initialized {
315 self.range_min = d_min;
316 self.range_max = d_max;
317 self.initialized = true;
318 } else {
319 if d_min < self.range_min {
321 self.range_min = d_min;
322 }
323 if d_max > self.range_max {
324 self.range_max = d_max;
325 }
326 }
327
328 if (self.range_max - self.range_min).abs() < 1e-8 {
330 self.range_max = self.range_min + 1e-8;
331 }
332
333 let bw = self.bin_width();
334 for &v in &finite {
335 let idx = ((v - self.range_min) / bw) as usize;
336 let idx = idx.min(self.n_bins - 1);
337 self.bins[idx] += 1;
338 }
339 }
340
341 fn compute_params(&self) -> QuantResult<(f32, i32)> {
342 if !self.is_calibrated() {
343 return Err(QuantError::CalibrationRequired("HistogramObserver"));
344 }
345
346 let n_search = 20_usize;
348 let mut best_mse = f32::INFINITY;
349 let mut best_lo = self.range_min;
350 let mut best_hi = self.range_max;
351
352 let total: u64 = self.bins.iter().sum();
353 if total == 0 {
354 return Err(QuantError::CalibrationRequired("HistogramObserver"));
355 }
356
357 let percentiles: Vec<f32> = (1..=n_search).map(|i| i as f32 / n_search as f32).collect();
359
360 for &pct in &percentiles {
361 let threshold = (pct * total as f32) as u64;
362 let mut cum = 0_u64;
363 let mut cut_bin = self.n_bins - 1;
364 for (b, &cnt) in self.bins.iter().enumerate() {
365 cum += cnt;
366 if cum >= threshold {
367 cut_bin = b;
368 break;
369 }
370 }
371 let bw = self.bin_width();
372 let hi = self.range_min + (cut_bin as f32 + 1.0) * bw;
373 let lo = if self.symmetric { -hi } else { self.range_min };
374
375 let mse = self.estimate_mse(lo, hi);
376 if mse < best_mse {
377 best_mse = mse;
378 best_lo = lo;
379 best_hi = hi;
380 }
381 }
382
383 if self.symmetric {
384 let abs_max = best_lo.abs().max(best_hi.abs());
385 Ok((sym_scale(abs_max, self.bits), 0))
386 } else {
387 Ok(asym_scale_zp(best_lo, best_hi, self.bits))
388 }
389 }
390
391 fn reset(&mut self) {
392 self.bins.fill(0);
393 self.range_min = 0.0;
394 self.range_max = 0.0;
395 self.initialized = false;
396 }
397
398 fn is_calibrated(&self) -> bool {
399 self.initialized
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn minmax_symmetric_scale() {
        let mut obs = MinMaxObserver::new(8, true);
        obs.observe(&[-2.0_f32, -1.0, 0.5, 2.0]);
        let (scale, zp) = obs.compute_params().unwrap();
        // abs_max = 2.0, largest signed 8-bit code = 127.
        assert_abs_diff_eq!(scale, 2.0 / 127.0, epsilon = 1e-6);
        assert_eq!(zp, 0);
    }

    #[test]
    fn minmax_asymmetric_scale_zp() {
        let mut obs = MinMaxObserver::new(8, false);
        obs.observe(&[0.0_f32, 1.0, 2.0, 3.0]);
        let (scale, zp) = obs.compute_params().unwrap();
        // Range [0, 3] spread over 255 unsigned steps; 0.0 maps to code 0.
        assert_abs_diff_eq!(scale, 3.0 / 255.0, epsilon = 1e-5);
        assert_eq!(zp, 0);
    }

    #[test]
    fn minmax_calibration_required() {
        let obs = MinMaxObserver::new(8, true);
        assert!(matches!(
            obs.compute_params(),
            Err(QuantError::CalibrationRequired(_))
        ));
    }

    #[test]
    fn minmax_reset() {
        let mut obs = MinMaxObserver::new(8, true);
        obs.observe(&[1.0_f32, 2.0]);
        obs.reset();
        assert!(!obs.is_calibrated());
    }

    #[test]
    fn minmax_ignores_non_finite() {
        let mut obs = MinMaxObserver::new(8, false);
        obs.observe(&[f32::NAN, -1.0, f32::INFINITY, 3.0, f32::NEG_INFINITY]);
        assert_eq!(obs.min_val, -1.0);
        assert_eq!(obs.max_val, 3.0);
    }

    #[test]
    fn moving_avg_first_batch_exact() {
        let mut obs = MovingAvgObserver::new(8, true, 0.9);
        obs.observe(&[-1.0_f32, 1.0]);
        let (scale, zp) = obs.compute_params().unwrap();
        // The first batch is taken verbatim, so abs_max = 1.0.
        assert_abs_diff_eq!(scale, 1.0 / 127.0, epsilon = 1e-5);
        assert_eq!(zp, 0);
    }

    #[test]
    fn moving_avg_ema_update() {
        let mut obs = MovingAvgObserver::new(8, true, 0.9);
        // First batch initialises the estimate directly...
        obs.observe(&[2.0_f32, 2.0]);
        // ...later batches are blended in: 0.9 * 2.0 + 0.1 * 4.0 = 2.2.
        obs.observe(&[4.0_f32, 4.0]);
        assert_abs_diff_eq!(obs.max_val, 2.2, epsilon = 1e-5);
    }

    #[test]
    fn moving_avg_calibration_required() {
        let obs = MovingAvgObserver::new(8, true, 0.9);
        assert!(matches!(
            obs.compute_params(),
            Err(QuantError::CalibrationRequired(_))
        ));
    }

    #[test]
    fn moving_avg_ignores_all_non_finite_batch() {
        let mut obs = MovingAvgObserver::new(8, true, 0.9);
        obs.observe(&[f32::NAN, f32::INFINITY]);
        assert!(!obs.is_calibrated());
    }

    #[test]
    fn histogram_observer_calibrates() {
        let mut obs = HistogramObserver::new(8, true, 256);
        // 1024 evenly spaced values in [-1, 1).
        let data: Vec<f32> = (0..1024).map(|i| (i as f32 / 512.0) - 1.0).collect();
        obs.observe(&data);
        assert!(obs.is_calibrated());
        let (scale, zp) = obs.compute_params().unwrap();
        assert!(scale > 0.0, "scale must be positive: {scale}");
        assert_eq!(zp, 0, "symmetric: zp must be 0");
    }

    #[test]
    fn histogram_observer_reset() {
        let mut obs = HistogramObserver::new(8, true, 128);
        obs.observe(&[1.0_f32, 2.0]);
        obs.reset();
        assert!(!obs.is_calibrated());
    }

    #[test]
    fn histogram_observer_uncalibrated_error() {
        let obs = HistogramObserver::new(8, true, 64);
        assert!(matches!(
            obs.compute_params(),
            Err(QuantError::CalibrationRequired(_))
        ));
    }

    #[test]
    fn histogram_observer_ignores_all_non_finite_batch() {
        let mut obs = HistogramObserver::new(8, true, 16);
        obs.observe(&[f32::NAN, f32::NEG_INFINITY]);
        assert!(!obs.is_calibrated());
    }
}
501}