Skip to main content

entrenar/prune/pipeline/
metrics.rs

//! Pruning metrics collection.
//!
//! Tracks the metrics collected while the pruning pipeline runs.
4
5use super::stage::PruningStage;
6use serde::{Deserialize, Serialize};
7
8/// Metrics collected during pruning.
9#[derive(Debug, Clone, Default, Serialize, Deserialize)]
10pub struct PruningMetrics {
11    /// Achieved sparsity (0.0 to 1.0).
12    pub achieved_sparsity: f32,
13    /// Target sparsity.
14    pub target_sparsity: f32,
15    /// Total parameters in model.
16    pub total_parameters: usize,
17    /// Parameters pruned (set to zero).
18    pub parameters_pruned: usize,
19    /// Parameters remaining (non-zero).
20    pub parameters_remaining: usize,
21    /// Per-layer sparsity.
22    pub layer_sparsity: Vec<(String, f32)>,
23    /// Pre-pruning perplexity (if evaluated).
24    pub pre_prune_ppl: Option<f32>,
25    /// Post-pruning perplexity (if evaluated).
26    pub post_prune_ppl: Option<f32>,
27    /// Perplexity increase percentage.
28    pub ppl_increase_pct: Option<f32>,
29    /// Fine-tuning loss curve.
30    pub finetune_losses: Vec<f32>,
31    /// Duration of each stage in seconds.
32    pub stage_durations: Vec<(PruningStage, f64)>,
33}
34
35impl PruningMetrics {
36    /// Create new metrics with target sparsity.
37    pub fn new(target_sparsity: f32) -> Self {
38        Self { target_sparsity, ..Default::default() }
39    }
40
41    /// Update achieved sparsity and parameter counts.
42    pub fn update_sparsity(&mut self, pruned: usize, total: usize) {
43        self.total_parameters = total;
44        self.parameters_pruned = pruned;
45        self.parameters_remaining = total.saturating_sub(pruned);
46        self.achieved_sparsity = if total > 0 { pruned as f32 / total as f32 } else { 0.0 };
47    }
48
49    /// Add layer sparsity.
50    pub fn add_layer_sparsity(&mut self, name: impl Into<String>, sparsity: f32) {
51        self.layer_sparsity.push((name.into(), sparsity));
52    }
53
54    /// Set pre-pruning perplexity.
55    pub fn set_pre_prune_ppl(&mut self, ppl: f32) {
56        self.pre_prune_ppl = Some(ppl);
57    }
58
59    /// Set post-pruning perplexity and compute increase.
60    pub fn set_post_prune_ppl(&mut self, ppl: f32) {
61        self.post_prune_ppl = Some(ppl);
62        if let Some(pre) = self.pre_prune_ppl {
63            if pre > 0.0 {
64                self.ppl_increase_pct = Some((ppl - pre) / pre * 100.0);
65            }
66        }
67    }
68
69    /// Record a fine-tuning loss.
70    pub fn record_finetune_loss(&mut self, loss: f32) {
71        self.finetune_losses.push(loss);
72    }
73
74    /// Record stage duration.
75    pub fn record_stage_duration(&mut self, stage: PruningStage, duration_secs: f64) {
76        self.stage_durations.push((stage, duration_secs));
77    }
78
79    /// Get sparsity gap (target - achieved).
80    pub fn sparsity_gap(&self) -> f32 {
81        self.target_sparsity - self.achieved_sparsity
82    }
83
84    /// Check if target sparsity was achieved.
85    pub fn target_achieved(&self) -> bool {
86        self.achieved_sparsity >= self.target_sparsity - 1e-4
87    }
88
89    /// Get mean layer sparsity.
90    pub fn mean_layer_sparsity(&self) -> f32 {
91        if self.layer_sparsity.is_empty() {
92            return self.achieved_sparsity;
93        }
94        let sum: f32 = self.layer_sparsity.iter().map(|(_, s)| s).sum();
95        sum / self.layer_sparsity.len() as f32
96    }
97
98    /// Get sparsity variance across layers.
99    pub fn layer_sparsity_variance(&self) -> f32 {
100        if self.layer_sparsity.is_empty() {
101            return 0.0;
102        }
103        let mean = self.mean_layer_sparsity();
104        let variance: f32 =
105            self.layer_sparsity.iter().map(|(_, s)| (s - mean).powi(2)).sum::<f32>()
106                / self.layer_sparsity.len().max(1) as f32;
107        variance
108    }
109
110    /// Get total pipeline duration.
111    pub fn total_duration_secs(&self) -> f64 {
112        self.stage_durations.iter().map(|(_, d)| d).sum()
113    }
114}