entrenar/prune/calibrate/
collector.rs1use std::collections::HashMap;
4
5use super::config::CalibrationConfig;
6use super::stats::LayerActivationStats;
7
8#[derive(Debug, Clone)]
13pub struct CalibrationCollector {
14 config: CalibrationConfig,
16 layer_stats: HashMap<String, LayerActivationStats>,
18 samples_processed: usize,
20 complete: bool,
22}
23
24impl CalibrationCollector {
25 pub fn new(config: CalibrationConfig) -> Self {
27 Self { config, layer_stats: HashMap::new(), samples_processed: 0, complete: false }
28 }
29
30 pub fn register_layer(&mut self, name: impl Into<String>, input_dim: usize) {
37 let name = name.into();
38 self.layer_stats.entry(name).or_insert_with(|| LayerActivationStats::new(input_dim));
39 }
40
41 pub fn record_activations(&mut self, layer_name: &str, activations: &[Vec<f32>]) {
48 if self.complete {
49 return;
50 }
51
52 if let Some(stats) = self.layer_stats.get_mut(layer_name) {
53 stats.update(activations);
54 }
55 }
56
57 pub fn batch_complete(&mut self, batch_size: usize) {
59 self.samples_processed += batch_size;
60
61 if self.samples_processed >= self.config.num_samples() {
62 self.complete = true;
63 }
64 }
65
66 pub fn get_layer_stats(&self, layer_name: &str) -> Option<&LayerActivationStats> {
68 self.layer_stats.get(layer_name)
69 }
70
71 pub fn layer_names(&self) -> Vec<&String> {
73 self.layer_stats.keys().collect()
74 }
75
76 pub fn is_complete(&self) -> bool {
78 self.complete
79 }
80
81 pub fn samples_processed(&self) -> usize {
83 self.samples_processed
84 }
85
86 pub fn config(&self) -> &CalibrationConfig {
88 &self.config
89 }
90
91 pub fn reset(&mut self) {
93 for stats in self.layer_stats.values_mut() {
94 stats.reset();
95 }
96 self.samples_processed = 0;
97 self.complete = false;
98 }
99
100 pub fn progress(&self) -> f32 {
102 if self.config.num_samples() == 0 {
103 return 1.0;
104 }
105 (self.samples_processed as f32 / self.config.num_samples() as f32).min(1.0)
106 }
107}