1use crate::calibration::methods::CalibrationMethod;
7
/// Number of equal-width bins used by every activation histogram in this module.
const NUM_BINS: usize = 256;
9
/// Running summary of a stream of activation values: extrema, mean, standard
/// deviation and a fixed-width histogram used for percentile queries.
#[derive(Debug, Clone)]
pub struct ActivationStats {
    /// Smallest value recorded so far.
    min: f32,
    /// Largest value recorded so far.
    max: f32,
    /// Arithmetic mean of the recorded values.
    mean: f32,
    /// Population standard deviation, derived from `m2` and `count`.
    std: f32,
    /// Number of samples the statistics were computed from.
    count: usize,

    /// Sum of squared deviations from the mean, kept in `f64` so batches can
    /// be merged without recomputing from scratch.
    m2: f64,

    /// Per-bin sample counts; empty until data has been recorded.
    histogram_bins: Vec<usize>,
    /// Lower edge of the range covered by `histogram_bins`.
    hist_min: f32,
    /// Upper edge of the range covered by `histogram_bins`.
    hist_max: f32,
}
29
30impl ActivationStats {
31 pub fn min(&self) -> f32 { self.min }
33 pub fn max(&self) -> f32 { self.max }
35 pub fn mean(&self) -> f32 { self.mean }
37 pub fn std(&self) -> f32 { self.std }
39 pub fn count(&self) -> usize { self.count }
41}
42
43impl ActivationStats {
44 pub fn from_data(data: &[f32]) -> Self {
46 if data.is_empty() {
47 return Self::default();
48 }
49
50 let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
51 if finite.is_empty() {
52 return Self::default();
53 }
54
55 let min = finite.iter().copied().fold(f32::INFINITY, f32::min);
56 let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
57
58 let sum: f32 = finite.iter().sum();
59 let mean = sum / finite.len() as f32;
60
61 let m2: f64 = finite.iter()
62 .map(|&x| ((x - mean) as f64).powi(2))
63 .sum();
64 let std = (m2 / finite.len() as f64).sqrt() as f32;
65
66 let histogram_bins = build_histogram(data, min, max);
67
68 Self {
69 min,
70 max,
71 mean,
72 std,
73 count: finite.len(),
74 m2,
75 histogram_bins,
76 hist_min: min,
77 hist_max: max,
78 }
79 }
80
81 pub fn update(&mut self, data: &[f32]) {
83 if data.is_empty() {
84 return;
85 }
86
87 let data_min = data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
88 let data_max = data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
89
90 let new_min = self.min.min(data_min);
91 let new_max = self.max.max(data_max);
92
93 let old_count = self.count as f64;
96 let new_count = data.len() as f64;
97 let combined_count = old_count + new_count;
98
99 let data_sum: f64 = data.iter().map(|&x| x as f64).sum();
100 let data_mean = data_sum / new_count;
101
102 let data_m2: f64 = data.iter()
103 .map(|&x| ((x as f64) - data_mean).powi(2))
104 .sum();
105
106 let delta = data_mean - self.mean as f64;
108 self.m2 = self.m2 + data_m2 + delta * delta * old_count * new_count / combined_count;
109
110 self.mean = ((self.mean as f64) * old_count + data_sum) as f32 / combined_count as f32;
111 self.count = combined_count as usize;
112 self.std = (self.m2 / combined_count).sqrt() as f32;
113
114 if new_min < self.hist_min || new_max > self.hist_max {
116 let mut rebinned = vec![0usize; NUM_BINS];
117 rebin(&self.histogram_bins, self.hist_min, self.hist_max, &mut rebinned, new_min, new_max);
118 self.histogram_bins = rebinned;
119 self.hist_min = new_min;
120 self.hist_max = new_max;
121 }
122
123 let new_hist = build_histogram(data, self.hist_min, self.hist_max);
125 for (i, &c) in new_hist.iter().enumerate() {
126 self.histogram_bins[i] += c;
127 }
128
129 self.min = new_min;
130 self.max = new_max;
131 }
132
133 pub fn percentile(&self, p: f32) -> f32 {
135 if self.histogram_bins.is_empty() {
136 return self.min;
137 }
138
139 let total: usize = self.histogram_bins.iter().sum();
140 if total == 0 {
141 return self.min;
142 }
143
144 let target_count = (total as f32 * p / 100.0).ceil() as usize;
147 let mut cumulative = 0;
148
149 let bin_size = if (self.hist_max - self.hist_min).abs() < 1e-8 {
150 0.0
151 } else {
152 (self.hist_max - self.hist_min) / NUM_BINS as f32
153 };
154
155 for (i, &count) in self.histogram_bins.iter().enumerate() {
156 cumulative += count;
157 if cumulative >= target_count {
158 return self.hist_min + (i as f32 + 0.5) * bin_size;
159 }
160 }
161
162 self.max
163 }
164
165 pub fn histogram_data(&self) -> Vec<(f32, usize)> {
167 if (self.hist_max - self.hist_min).abs() < 1e-8 {
168 let total: usize = self.histogram_bins.iter().sum();
169 if total > 0 {
170 return vec![(self.hist_min, total)];
171 }
172 return Vec::new();
173 }
174 let bin_size = (self.hist_max - self.hist_min) / NUM_BINS as f32;
175 self.histogram_bins.iter()
176 .enumerate()
177 .filter(|(_, &count)| count > 0)
178 .map(|(i, &count)| {
179 let value = self.hist_min + (i as f32 + 0.5) * bin_size;
180 (value, count)
181 })
182 .collect()
183 }
184}
185
186impl Default for ActivationStats {
187 fn default() -> Self {
188 Self {
189 min: f32::INFINITY,
190 max: f32::NEG_INFINITY,
191 mean: 0.0,
192 std: 0.0,
193 count: 0,
194 m2: 0.0,
195 histogram_bins: Vec::new(),
196 hist_min: 0.0,
197 hist_max: 0.0,
198 }
199 }
200}
201
202fn build_histogram(data: &[f32], min: f32, max: f32) -> Vec<usize> {
203 let mut bins = vec![0usize; NUM_BINS];
204
205 if (max - min).abs() < 1e-8 {
206 let finite_count = data.iter().filter(|v| v.is_finite()).count();
208 if !bins.is_empty() {
209 bins[0] = finite_count;
210 }
211 return bins;
212 }
213
214 let bin_size = (max - min) / NUM_BINS as f32;
215
216 for &value in data {
217 if !value.is_finite() { continue; }
218 let bin_idx = ((value - min) / bin_size).floor() as usize;
219 let bin_idx = bin_idx.min(NUM_BINS - 1);
220 bins[bin_idx] += 1;
221 }
222
223 bins
224}
225
/// Redistributes the counts of a histogram over `[old_min, old_max]` into
/// `new_bins`, which spans `[new_min, new_max]`.
///
/// Each non-empty source bin is moved as a whole to the destination bin that
/// contains its center. If either range is narrower than `1e-8`, the entire
/// source mass is placed in the bin containing the midpoint of the old range.
/// Counts are added to whatever `new_bins` already holds. Nothing happens if
/// either slice is empty.
fn rebin(
    old_bins: &[usize],
    old_min: f32,
    old_max: f32,
    new_bins: &mut [usize],
    new_min: f32,
    new_max: f32,
) {
    if old_bins.is_empty() || new_bins.is_empty() {
        return;
    }

    let src_span = old_max - old_min;
    let dst_span = new_max - new_min;
    let dst_len = new_bins.len();

    // Maps a value to its destination bin, clamped to the last bin.
    let locate = |center: f32| -> usize {
        let raw = ((center - new_min) / dst_span * dst_len as f32).floor() as usize;
        raw.min(dst_len - 1)
    };

    let degenerate = src_span.abs() < 1e-8 || dst_span.abs() < 1e-8;
    if degenerate {
        let mass: usize = old_bins.iter().sum();
        if mass > 0 {
            let slot = locate((old_min + old_max) * 0.5);
            new_bins[slot] += mass;
        }
        return;
    }

    let src_width = src_span / old_bins.len() as f32;
    old_bins
        .iter()
        .enumerate()
        .filter(|(_, &count)| count != 0)
        .for_each(|(i, &count)| {
            let slot = locate(old_min + (i as f32 + 0.5) * src_width);
            new_bins[slot] += count;
        });
}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks extrema, mean and median on a small symmetric sample.
    #[test]
    fn test_activation_stats() {
        let stats = ActivationStats::from_data(&[-1.0, -0.5, 0.0, 0.5, 1.0]);

        assert_eq!(stats.min(), -1.0);
        assert_eq!(stats.max(), 1.0);
        assert!(stats.mean().abs() < 0.01);

        // The percentile is a bin center, so only require it to be near zero.
        let median = stats.percentile(50.0);
        assert!(median.abs() < 0.3);
    }
}
279
280pub fn calculate_optimal_range(
282 data: &[f32],
283 method: CalibrationMethod,
284) -> (f32, f32) {
285 if data.is_empty() {
286 return (0.0, 0.0);
287 }
288
289 match method {
290 CalibrationMethod::MinMax => {
291 let min = data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
292 let max = data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
293 (min, max)
294 }
295
296 CalibrationMethod::Percentile(p) => {
297 let stats = ActivationStats::from_data(data);
298 let lower = stats.percentile(100.0 - p);
299 let upper = stats.percentile(p);
300 (lower, upper)
301 }
302
303 CalibrationMethod::Entropy => {
304 optimize_kl_divergence(data)
305 }
306
307 CalibrationMethod::MSE => {
308 optimize_mse(data)
309 }
310 }
311}
312
313fn optimize_kl_divergence(data: &[f32]) -> (f32, f32) {
315 let stats = ActivationStats::from_data(data);
316
317 let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
319 let mut best_range = (stats.min, stats.max);
320 let mut best_kl = f32::INFINITY;
321
322 for &percentile in &candidates {
323 let lower = stats.percentile(100.0 - percentile);
324 let upper = stats.percentile(percentile);
325
326 let kl = calculate_kl_divergence(data, lower, upper);
327
328 if kl < best_kl {
329 best_kl = kl;
330 best_range = (lower, upper);
331 }
332 }
333
334 best_range
335}
336
337fn optimize_mse(data: &[f32]) -> (f32, f32) {
339 let stats = ActivationStats::from_data(data);
340
341 let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
343 let mut best_range = (stats.min, stats.max);
344 let mut best_mse = f32::INFINITY;
345
346 for &percentile in &candidates {
347 let lower = stats.percentile(100.0 - percentile);
348 let upper = stats.percentile(percentile);
349
350 let mse = calculate_quantization_mse(data, lower, upper);
351
352 if mse < best_mse {
353 best_mse = mse;
354 best_range = (lower, upper);
355 }
356 }
357
358 best_range
359}
360
/// Scores the clipping range `[min, max]` by the KL divergence between two
/// 128-bin histograms: the clipped values, and the same values after a
/// simulated 256-level quantize/dequantize round trip (`scale = range / 255`).
///
/// Non-finite samples are skipped, so NaN or ±infinity in `data` cannot skew
/// the bin counts or the normalisation. Returns `0.0` when the range is
/// narrower than `1e-8` or when no finite sample is present.
///
/// # Panics
///
/// `f32::clamp` panics if `min > max` or if either bound is NaN, which is
/// reached whenever such a range is combined with at least one finite sample.
fn calculate_kl_divergence(data: &[f32], min: f32, max: f32) -> f32 {
    const KL_BINS: usize = 128;

    let range = max - min;
    if range.abs() < 1e-8 {
        return 0.0;
    }

    let bin_size = range / KL_BINS as f32;
    let scale = range / 255.0;
    let bin_of = |x: f32| (((x - min) / bin_size).floor() as usize).min(KL_BINS - 1);

    let mut orig_bins = [0usize; KL_BINS];
    let mut quant_bins = [0usize; KL_BINS];
    let mut samples = 0usize;

    for v in data.iter().copied().filter(|v| v.is_finite()) {
        let clipped = v.clamp(min, max);
        orig_bins[bin_of(clipped)] += 1;

        let dequant = min + ((clipped - min) / scale).round() * scale;
        // Rounding can push the reconstruction slightly past the range edge.
        quant_bins[bin_of(dequant.clamp(min, max))] += 1;

        samples += 1;
    }

    if samples == 0 {
        return 0.0;
    }

    // Smoothing term so empty bins do not produce ln(0) or a division by zero.
    let epsilon = 1e-10_f32;
    let denom = samples as f32 + epsilon * KL_BINS as f32;

    orig_bins
        .iter()
        .zip(quant_bins.iter())
        .map(|(&orig, &quant)| {
            let p = (orig as f32 + epsilon) / denom;
            let q = (quant as f32 + epsilon) / denom;
            p * (p / q).ln()
        })
        .sum()
}
405
/// Mean squared error between the finite values of `data` and their
/// reconstruction after clipping to `[min, max]` and a simulated 256-level
/// quantize/dequantize round trip (`scale = range / 255`).
///
/// * Non-finite samples are skipped, so one NaN or ±infinity cannot turn the
///   score into NaN/infinity.
/// * When the range is narrower than `1e-8`, every value reconstructs to
///   `min`, so the error is the mean of `(v - min)^2`. It is zero only if the
///   data really is constant at `min`.
/// * Returns `0.0` when there is no finite sample.
///
/// The squared errors are accumulated in `f64` and narrowed once at the end.
///
/// # Panics
///
/// `f32::clamp` panics if `min > max` or if either bound is NaN, which is
/// reached whenever such a range is combined with at least one finite sample.
fn calculate_quantization_mse(data: &[f32], min: f32, max: f32) -> f32 {
    let range = max - min;
    let degenerate = range.abs() < 1e-8;
    let scale = range / 255.0;

    let (sum_sq, samples) = data
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((0.0_f64, 0usize), |(acc, n), v| {
            let reconstructed = if degenerate {
                min
            } else {
                let clipped = v.clamp(min, max);
                min + ((clipped - min) / scale).round() * scale
            };
            let err = f64::from(v) - f64::from(reconstructed);
            (acc + err * err, n + 1)
        });

    if samples == 0 {
        0.0
    } else {
        (sum_sq / samples as f64) as f32
    }
}