use std::collections::HashMap;
use super::config::CalibrationConfig;
use super::stats::LayerActivationStats;
/// Accumulates per-layer activation statistics over a calibration run.
///
/// Layers must be registered with [`CalibrationCollector::register_layer`] before
/// their activations are recorded. Collection stops once the sample count reported
/// through [`CalibrationCollector::batch_complete`] reaches `config.num_samples()`.
#[derive(Debug, Clone)]
pub struct CalibrationCollector {
    /// Calibration settings; this type only consults `num_samples()`.
    config: CalibrationConfig,
    /// Running statistics keyed by layer name.
    layer_stats: HashMap<String, LayerActivationStats>,
    /// Samples reported via `batch_complete` since construction or the last `reset`.
    samples_processed: usize,
    /// `true` once the configured sample target has been reached; while set,
    /// `record_activations` is a no-op.
    complete: bool,
}

impl CalibrationCollector {
    /// Creates a collector with no registered layers and no processed samples.
    ///
    /// If `config.num_samples()` is zero the collector starts out complete, so
    /// `is_complete()` agrees with `progress()` (which reports `1.0` in that case).
    pub fn new(config: CalibrationConfig) -> Self {
        let mut collector = Self {
            config,
            layer_stats: HashMap::new(),
            samples_processed: 0,
            complete: false,
        };
        // Derive completion from the target instead of assuming `false`, so a
        // zero-sample target is complete immediately.
        collector.complete = collector.target_reached();
        collector
    }

    /// Registers a layer whose activations will be tracked.
    ///
    /// If a layer with the same name is already registered, its existing
    /// statistics are kept and `input_dim` is ignored.
    pub fn register_layer(&mut self, name: impl Into<String>, input_dim: usize) {
        let name = name.into();
        self.layer_stats
            .entry(name)
            .or_insert_with(|| LayerActivationStats::new(input_dim));
    }

    /// Feeds a batch of activations for `layer_name` into that layer's statistics.
    ///
    /// Does nothing once the collector is complete, and silently ignores layers
    /// that were never registered. The interpretation of `activations` (presumably
    /// one `Vec<f32>` per sample — confirm) is up to `LayerActivationStats::update`.
    pub fn record_activations(&mut self, layer_name: &str, activations: &[Vec<f32>]) {
        if self.complete {
            return;
        }
        if let Some(stats) = self.layer_stats.get_mut(layer_name) {
            stats.update(activations);
        }
    }

    /// Reports that a batch of `batch_size` samples has been processed.
    ///
    /// Marks the collector complete once the running total reaches
    /// `config.num_samples()`. Completion is never undone here; use `reset` for that.
    pub fn batch_complete(&mut self, batch_size: usize) {
        // Saturate instead of overflowing (which would panic in debug builds).
        self.samples_processed = self.samples_processed.saturating_add(batch_size);
        if self.target_reached() {
            self.complete = true;
        }
    }

    /// Returns the statistics collected for `layer_name`, or `None` if that layer
    /// was never registered.
    pub fn get_layer_stats(&self, layer_name: &str) -> Option<&LayerActivationStats> {
        self.layer_stats.get(layer_name)
    }

    /// Returns the names of all registered layers, sorted lexicographically.
    ///
    /// Sorting makes the order deterministic across runs; raw `HashMap`
    /// iteration order is not.
    pub fn layer_names(&self) -> Vec<&String> {
        let mut names: Vec<&String> = self.layer_stats.keys().collect();
        names.sort_unstable();
        names
    }

    /// Returns `true` once the configured number of samples has been processed.
    pub fn is_complete(&self) -> bool {
        self.complete
    }

    /// Returns the number of samples reported via `batch_complete` so far.
    pub fn samples_processed(&self) -> usize {
        self.samples_processed
    }

    /// Returns the calibration configuration this collector was created with.
    pub fn config(&self) -> &CalibrationConfig {
        &self.config
    }

    /// Clears all collected statistics and the sample counter.
    ///
    /// Registered layers stay registered; each one's statistics are reset in place.
    pub fn reset(&mut self) {
        for stats in self.layer_stats.values_mut() {
            stats.reset();
        }
        self.samples_processed = 0;
        // Re-derive rather than hard-code `false` so a zero-sample target stays complete.
        self.complete = self.target_reached();
    }

    /// Returns the fraction of the sample target processed so far, clamped to `1.0`.
    ///
    /// A zero-sample target is reported as fully done (`1.0`).
    pub fn progress(&self) -> f32 {
        let target = self.config.num_samples();
        if target == 0 {
            return 1.0;
        }
        (self.samples_processed as f32 / target as f32).min(1.0)
    }

    /// Returns `true` when the processed-sample count has reached the configured target.
    fn target_reached(&self) -> bool {
        self.samples_processed >= self.config.num_samples()
    }
}