use super::stage::PruningStage;
use serde::{Deserialize, Serialize};
/// Metrics record for a pruning run: sparsity figures, parameter counts,
/// optional before/after `ppl` measurements, fine-tuning losses and
/// per-stage durations.
///
/// All fields are public and start at their `Default` values; the methods on
/// this type keep the derived fields consistent with the raw ones.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PruningMetrics {
/// Fraction of parameters pruned (`pruned / total`), written by
/// `update_sparsity`; `0.0` when the total is zero.
pub achieved_sparsity: f32,
/// Sparsity target supplied to `new`.
pub target_sparsity: f32,
/// Total parameter count from the most recent `update_sparsity` call.
pub total_parameters: usize,
/// Pruned parameter count from the most recent `update_sparsity` call.
pub parameters_pruned: usize,
/// `total_parameters - parameters_pruned`, saturating at zero.
pub parameters_remaining: usize,
/// `(layer name, sparsity)` pairs in the order they were added.
pub layer_sparsity: Vec<(String, f32)>,
/// `ppl` measured before pruning, if recorded.
/// NOTE(review): `ppl` is presumably perplexity — confirm.
pub pre_prune_ppl: Option<f32>,
/// `ppl` measured after pruning, if recorded.
pub post_prune_ppl: Option<f32>,
/// Relative change from `pre_prune_ppl` to `post_prune_ppl`, in percent:
/// `(post - pre) / pre * 100`.
pub ppl_increase_pct: Option<f32>,
/// Loss values in the order they were passed to `record_finetune_loss`.
pub finetune_losses: Vec<f32>,
/// `(stage, duration)` pairs in the order they were recorded; the duration
/// is in seconds per the `duration_secs` parameter of `record_stage_duration`.
pub stage_durations: Vec<(PruningStage, f64)>,
}
impl PruningMetrics {
    /// Absolute slack allowed when comparing achieved and target sparsity.
    const SPARSITY_TOLERANCE: f32 = 1e-4;

    /// Creates an empty metrics record with the given sparsity target.
    ///
    /// Every other field starts at its `Default` value.
    pub fn new(target_sparsity: f32) -> Self {
        Self { target_sparsity, ..Default::default() }
    }

    /// Records the global parameter counts and derives the achieved sparsity.
    ///
    /// `parameters_remaining` is `total - pruned`, saturating at zero.
    /// `achieved_sparsity` is `pruned / total`, or `0.0` when `total` is zero.
    /// The inputs are stored as given: `pruned > total` is not rejected and
    /// yields a sparsity above `1.0`.
    pub fn update_sparsity(&mut self, pruned: usize, total: usize) {
        self.total_parameters = total;
        self.parameters_pruned = pruned;
        self.parameters_remaining = total.saturating_sub(pruned);
        // Divide in f64 and narrow once: counts above 2^24 are not exactly
        // representable in f32, so casting each operand first would round the
        // inputs before the ratio is even taken.
        self.achieved_sparsity = if total > 0 {
            (pruned as f64 / total as f64) as f32
        } else {
            0.0
        };
    }

    /// Appends a `(name, sparsity)` entry for a single layer.
    pub fn add_layer_sparsity(&mut self, name: impl Into<String>, sparsity: f32) {
        self.layer_sparsity.push((name.into(), sparsity));
    }

    /// Stores the pre-pruning `ppl` value and refreshes `ppl_increase_pct`.
    pub fn set_pre_prune_ppl(&mut self, ppl: f32) {
        self.pre_prune_ppl = Some(ppl);
        self.refresh_ppl_increase();
    }

    /// Stores the post-pruning `ppl` value and refreshes `ppl_increase_pct`.
    pub fn set_post_prune_ppl(&mut self, ppl: f32) {
        self.post_prune_ppl = Some(ppl);
        self.refresh_ppl_increase();
    }

    /// Recomputes `ppl_increase_pct` from the stored pre/post values.
    ///
    /// The result is `(post - pre) / pre * 100` when both values are present
    /// and `pre` is strictly positive. Otherwise the field is cleared, so it
    /// can never hold a percentage computed from an older pair of values.
    fn refresh_ppl_increase(&mut self) {
        self.ppl_increase_pct = match (self.pre_prune_ppl, self.post_prune_ppl) {
            // `pre > 0.0` also rejects NaN, which would poison the ratio.
            (Some(pre), Some(post)) if pre > 0.0 => Some((post - pre) / pre * 100.0),
            _ => None,
        };
    }

    /// Appends one fine-tuning loss value.
    pub fn record_finetune_loss(&mut self, loss: f32) {
        self.finetune_losses.push(loss);
    }

    /// Appends a `(stage, duration_secs)` entry.
    pub fn record_stage_duration(&mut self, stage: PruningStage, duration_secs: f64) {
        self.stage_durations.push((stage, duration_secs));
    }

    /// Returns `target_sparsity - achieved_sparsity`.
    ///
    /// Positive while the target has not been reached, negative once it has
    /// been exceeded.
    pub fn sparsity_gap(&self) -> f32 {
        self.target_sparsity - self.achieved_sparsity
    }

    /// Returns `true` when the achieved sparsity is at least the target,
    /// allowing an absolute shortfall of `1e-4`.
    pub fn target_achieved(&self) -> bool {
        self.achieved_sparsity >= self.target_sparsity - Self::SPARSITY_TOLERANCE
    }

    /// Returns the arithmetic mean of the recorded per-layer sparsities.
    ///
    /// Falls back to `achieved_sparsity` when no layers have been recorded.
    pub fn mean_layer_sparsity(&self) -> f32 {
        if self.layer_sparsity.is_empty() {
            return self.achieved_sparsity;
        }
        let sum: f32 = self.layer_sparsity.iter().map(|(_, s)| *s).sum();
        sum / self.layer_sparsity.len() as f32
    }

    /// Returns the population variance (divisor `N`, not `N - 1`) of the
    /// recorded per-layer sparsities, or `0.0` when none are recorded.
    pub fn layer_sparsity_variance(&self) -> f32 {
        if self.layer_sparsity.is_empty() {
            return 0.0;
        }
        let mean = self.mean_layer_sparsity();
        let squared_deviations: f32 = self
            .layer_sparsity
            .iter()
            .map(|(_, s)| (*s - mean).powi(2))
            .sum();
        // The list is non-empty here, so the divisor is never zero.
        squared_deviations / self.layer_sparsity.len() as f32
    }

    /// Returns the sum of all recorded stage durations.
    pub fn total_duration_secs(&self) -> f64 {
        self.stage_durations.iter().map(|(_, d)| *d).sum()
    }
}