Skip to main content

entrenar/prune/calibrate/
collector.rs

//! Calibration data collector for pruning.
2
3use std::collections::HashMap;
4
5use super::config::CalibrationConfig;
6use super::stats::LayerActivationStats;
7
8/// Calibration data collector for pruning.
9///
10/// Collects and stores per-layer activation statistics during
11/// calibration forward passes.
12#[derive(Debug, Clone)]
13pub struct CalibrationCollector {
14    /// Configuration for calibration.
15    config: CalibrationConfig,
16    /// Per-layer activation statistics.
17    layer_stats: HashMap<String, LayerActivationStats>,
18    /// Total samples processed.
19    samples_processed: usize,
20    /// Whether calibration is complete.
21    complete: bool,
22}
23
24impl CalibrationCollector {
25    /// Create a new calibration collector.
26    pub fn new(config: CalibrationConfig) -> Self {
27        Self { config, layer_stats: HashMap::new(), samples_processed: 0, complete: false }
28    }
29
30    /// Register a layer for calibration.
31    ///
32    /// # Arguments
33    ///
34    /// * `name` - Layer name
35    /// * `input_dim` - Input feature dimension
36    pub fn register_layer(&mut self, name: impl Into<String>, input_dim: usize) {
37        let name = name.into();
38        self.layer_stats.entry(name).or_insert_with(|| LayerActivationStats::new(input_dim));
39    }
40
41    /// Record activations for a layer.
42    ///
43    /// # Arguments
44    ///
45    /// * `layer_name` - Layer name
46    /// * `activations` - Batch of activations
47    pub fn record_activations(&mut self, layer_name: &str, activations: &[Vec<f32>]) {
48        if self.complete {
49            return;
50        }
51
52        if let Some(stats) = self.layer_stats.get_mut(layer_name) {
53            stats.update(activations);
54        }
55    }
56
57    /// Mark a batch as processed.
58    pub fn batch_complete(&mut self, batch_size: usize) {
59        self.samples_processed += batch_size;
60
61        if self.samples_processed >= self.config.num_samples() {
62            self.complete = true;
63        }
64    }
65
66    /// Get activation statistics for a layer.
67    pub fn get_layer_stats(&self, layer_name: &str) -> Option<&LayerActivationStats> {
68        self.layer_stats.get(layer_name)
69    }
70
71    /// Get all layer names.
72    pub fn layer_names(&self) -> Vec<&String> {
73        self.layer_stats.keys().collect()
74    }
75
76    /// Check if calibration is complete.
77    pub fn is_complete(&self) -> bool {
78        self.complete
79    }
80
81    /// Get the number of samples processed.
82    pub fn samples_processed(&self) -> usize {
83        self.samples_processed
84    }
85
86    /// Get the configuration.
87    pub fn config(&self) -> &CalibrationConfig {
88        &self.config
89    }
90
91    /// Reset all collected statistics.
92    pub fn reset(&mut self) {
93        for stats in self.layer_stats.values_mut() {
94            stats.reset();
95        }
96        self.samples_processed = 0;
97        self.complete = false;
98    }
99
100    /// Get progress as a fraction (0.0 to 1.0).
101    pub fn progress(&self) -> f32 {
102        if self.config.num_samples() == 0 {
103            return 1.0;
104        }
105        (self.samples_processed as f32 / self.config.num_samples() as f32).min(1.0)
106    }
107}