entrenar/quant/calibration/
calibrator.rs1use crate::Tensor;
7
8use super::helpers::rand_simple;
9use super::types::{CalibrationMethod, CalibrationResult};
10
11#[derive(Clone, Debug)]
13pub struct Calibrator {
14 method: CalibrationMethod,
16 symmetric: bool,
18 bits: usize,
20 running_min: Option<f32>,
22 running_max: Option<f32>,
24 samples: Vec<f32>,
26 max_samples: usize,
28 num_batches: usize,
30}
31
32impl Calibrator {
33 pub fn min_max(bits: usize, symmetric: bool) -> Self {
35 Self {
36 method: CalibrationMethod::MinMax,
37 symmetric,
38 bits,
39 running_min: None,
40 running_max: None,
41 samples: Vec::new(),
42 max_samples: 0,
43 num_batches: 0,
44 }
45 }
46
47 pub fn percentile(
56 bits: usize,
57 symmetric: bool,
58 lower: f32,
59 upper: f32,
60 max_samples: usize,
61 ) -> Self {
62 Self {
63 method: CalibrationMethod::Percentile { lower, upper },
64 symmetric,
65 bits,
66 running_min: None,
67 running_max: None,
68 samples: Vec::with_capacity(max_samples.min(10000)),
69 max_samples,
70 num_batches: 0,
71 }
72 }
73
74 pub fn moving_average(bits: usize, symmetric: bool, momentum: f32) -> Self {
76 Self {
77 method: CalibrationMethod::MovingAverage { momentum },
78 symmetric,
79 bits,
80 running_min: None,
81 running_max: None,
82 samples: Vec::new(),
83 max_samples: 0,
84 num_batches: 0,
85 }
86 }
87
88 pub fn observe(&mut self, data: &[f32]) {
90 if data.is_empty() {
91 return;
92 }
93
94 match &self.method {
95 CalibrationMethod::MinMax => {
96 self.observe_min_max(data);
97 }
98 CalibrationMethod::Percentile { .. } => {
99 self.observe_percentile(data);
100 }
101 CalibrationMethod::MovingAverage { momentum } => {
102 let momentum = *momentum;
103 self.observe_moving_average(data, momentum);
104 }
105 }
106
107 self.num_batches += 1;
108 }
109
110 pub fn observe_tensor(&mut self, tensor: &Tensor) {
112 if let Some(slice) = tensor.data().as_slice() {
113 self.observe(slice);
114 }
115 }
116
117 pub fn observe_tensors(&mut self, tensors: &[&Tensor]) {
119 for tensor in tensors {
120 self.observe_tensor(tensor);
121 }
122 }
123
124 pub fn compute(&self) -> CalibrationResult {
126 let (observed_min, observed_max) = match &self.method {
127 CalibrationMethod::MinMax | CalibrationMethod::MovingAverage { .. } => {
128 (self.running_min.unwrap_or(0.0), self.running_max.unwrap_or(0.0))
129 }
130 CalibrationMethod::Percentile { lower, upper } => {
131 self.compute_percentile_bounds(*lower, *upper)
132 }
133 };
134
135 let (scale, zero_point) = self.compute_scale_zero_point(observed_min, observed_max);
136
137 CalibrationResult {
138 scale,
139 zero_point,
140 observed_min,
141 observed_max,
142 method: self.method.clone(),
143 }
144 }
145
146 pub fn num_batches(&self) -> usize {
148 self.num_batches
149 }
150
151 pub fn method(&self) -> &CalibrationMethod {
153 &self.method
154 }
155
156 pub fn has_data(&self) -> bool {
158 self.num_batches > 0
159 }
160
161 pub fn reset(&mut self) {
163 self.running_min = None;
164 self.running_max = None;
165 self.samples.clear();
166 self.num_batches = 0;
167 }
168
169 fn observe_min_max(&mut self, data: &[f32]) {
172 let batch_min = data.iter().copied().fold(f32::INFINITY, f32::min);
173 let batch_max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
174
175 self.running_min = Some(self.running_min.map_or(batch_min, |m| m.min(batch_min)));
176 self.running_max = Some(self.running_max.map_or(batch_max, |m| m.max(batch_max)));
177 }
178
179 fn observe_percentile(&mut self, data: &[f32]) {
180 if self.samples.len() < self.max_samples {
182 let remaining = self.max_samples - self.samples.len();
183 self.samples.extend(data.iter().take(remaining).copied());
184 } else {
185 let total_seen = self.num_batches * data.len() + data.len();
187 for (i, &val) in data.iter().enumerate() {
188 let j = rand_simple(total_seen + i);
189 if j < self.max_samples {
190 self.samples[j] = val;
191 }
192 }
193 }
194
195 self.observe_min_max(data);
197 }
198
199 fn observe_moving_average(&mut self, data: &[f32], momentum: f32) {
200 let batch_min = data.iter().copied().fold(f32::INFINITY, f32::min);
201 let batch_max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
202
203 self.running_min = Some(
204 self.running_min.map_or(batch_min, |m| m * (1.0 - momentum) + batch_min * momentum),
205 );
206 self.running_max = Some(
207 self.running_max.map_or(batch_max, |m| m * (1.0 - momentum) + batch_max * momentum),
208 );
209 }
210
211 fn compute_percentile_bounds(&self, lower: f32, upper: f32) -> (f32, f32) {
212 if self.samples.is_empty() {
213 return (self.running_min.unwrap_or(0.0), self.running_max.unwrap_or(0.0));
214 }
215
216 let mut sorted = self.samples.clone();
217 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
218
219 let n = sorted.len();
220 let lower_idx = ((lower / 100.0) * n as f32) as usize;
221 let upper_idx = ((upper / 100.0) * n as f32).min((n - 1) as f32) as usize;
222
223 (sorted[lower_idx], sorted[upper_idx])
224 }
225
226 fn compute_scale_zero_point(&self, min_val: f32, max_val: f32) -> (f32, i32) {
227 let qmax = (1 << (self.bits - 1)) - 1;
228 let qmin = if self.symmetric { -qmax } else { 0 };
229 let qmax_full = if self.symmetric { qmax } else { (1 << self.bits) - 1 };
230
231 if self.symmetric {
232 let max_abs = min_val.abs().max(max_val.abs());
234 let scale = if max_abs < 1e-10 { 1e-10 } else { max_abs / qmax as f32 };
235 (scale, 0)
236 } else {
237 let range = max_val - min_val;
239 let scale = if range < 1e-10 { 1e-10 } else { range / (qmax_full - qmin) as f32 };
240 let zero_point = (qmin as f32 - min_val / scale).round() as i32;
241 let zero_point = zero_point.clamp(qmin, qmax_full);
242 (scale, zero_point)
243 }
244 }
245}