entrenar/prune/pipeline/
metrics.rs1use super::stage::PruningStage;
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Default, Serialize, Deserialize)]
10pub struct PruningMetrics {
11 pub achieved_sparsity: f32,
13 pub target_sparsity: f32,
15 pub total_parameters: usize,
17 pub parameters_pruned: usize,
19 pub parameters_remaining: usize,
21 pub layer_sparsity: Vec<(String, f32)>,
23 pub pre_prune_ppl: Option<f32>,
25 pub post_prune_ppl: Option<f32>,
27 pub ppl_increase_pct: Option<f32>,
29 pub finetune_losses: Vec<f32>,
31 pub stage_durations: Vec<(PruningStage, f64)>,
33}
34
35impl PruningMetrics {
36 pub fn new(target_sparsity: f32) -> Self {
38 Self { target_sparsity, ..Default::default() }
39 }
40
41 pub fn update_sparsity(&mut self, pruned: usize, total: usize) {
43 self.total_parameters = total;
44 self.parameters_pruned = pruned;
45 self.parameters_remaining = total.saturating_sub(pruned);
46 self.achieved_sparsity = if total > 0 { pruned as f32 / total as f32 } else { 0.0 };
47 }
48
49 pub fn add_layer_sparsity(&mut self, name: impl Into<String>, sparsity: f32) {
51 self.layer_sparsity.push((name.into(), sparsity));
52 }
53
54 pub fn set_pre_prune_ppl(&mut self, ppl: f32) {
56 self.pre_prune_ppl = Some(ppl);
57 }
58
59 pub fn set_post_prune_ppl(&mut self, ppl: f32) {
61 self.post_prune_ppl = Some(ppl);
62 if let Some(pre) = self.pre_prune_ppl {
63 if pre > 0.0 {
64 self.ppl_increase_pct = Some((ppl - pre) / pre * 100.0);
65 }
66 }
67 }
68
69 pub fn record_finetune_loss(&mut self, loss: f32) {
71 self.finetune_losses.push(loss);
72 }
73
74 pub fn record_stage_duration(&mut self, stage: PruningStage, duration_secs: f64) {
76 self.stage_durations.push((stage, duration_secs));
77 }
78
79 pub fn sparsity_gap(&self) -> f32 {
81 self.target_sparsity - self.achieved_sparsity
82 }
83
84 pub fn target_achieved(&self) -> bool {
86 self.achieved_sparsity >= self.target_sparsity - 1e-4
87 }
88
89 pub fn mean_layer_sparsity(&self) -> f32 {
91 if self.layer_sparsity.is_empty() {
92 return self.achieved_sparsity;
93 }
94 let sum: f32 = self.layer_sparsity.iter().map(|(_, s)| s).sum();
95 sum / self.layer_sparsity.len() as f32
96 }
97
98 pub fn layer_sparsity_variance(&self) -> f32 {
100 if self.layer_sparsity.is_empty() {
101 return 0.0;
102 }
103 let mean = self.mean_layer_sparsity();
104 let variance: f32 =
105 self.layer_sparsity.iter().map(|(_, s)| (s - mean).powi(2)).sum::<f32>()
106 / self.layer_sparsity.len().max(1) as f32;
107 variance
108 }
109
110 pub fn total_duration_secs(&self) -> f64 {
112 self.stage_durations.iter().map(|(_, d)| d).sum()
113 }
114}