use super::metrics::PruningMetrics;
use super::stage::PruningStage;
use crate::prune::calibrate::CalibrationCollector;
use crate::prune::config::PruningConfig;
/// State machine that tracks a prune → (optional) fine-tune → evaluate → export workflow.
///
/// Apart from [`PruneFinetunePipeline::execute_export`], the pipeline does not perform the
/// work of each stage itself. It records which stage the workflow is in as callers invoke
/// [`PruneFinetunePipeline::advance`] or [`PruneFinetunePipeline::fail`].
#[derive(Debug)]
pub struct PruneFinetunePipeline {
/// Configuration passed to `new`. It is read for the target sparsity (metrics seeding)
/// and for whether the `FineTuning` stage runs after `Pruning`.
config: PruningConfig,
/// Current lifecycle stage; starts at `PruningStage::Idle`.
stage: PruningStage,
/// Metrics seeded from the config's target sparsity and passed to the sparse exporter.
metrics: PruningMetrics,
/// Collector installed by `start_calibration`; `None` before that and after `reset`.
calibration: Option<CalibrationCollector>,
/// Message from the most recent `fail` call; cleared by `reset`.
error: Option<String>,
}
impl PruneFinetunePipeline {
    /// Creates a pipeline in [`PruningStage::Idle`].
    ///
    /// Metrics are seeded from `config.target_sparsity()`. There is no calibration
    /// collector and no recorded error.
    pub fn new(config: PruningConfig) -> Self {
        let metrics = PruningMetrics::new(config.target_sparsity());
        Self { config, stage: PruningStage::Idle, metrics, calibration: None, error: None }
    }

    /// Returns the current stage.
    pub fn stage(&self) -> PruningStage {
        self.stage
    }

    /// Returns the configuration this pipeline was created with.
    pub fn config(&self) -> &PruningConfig {
        &self.config
    }

    /// Returns the pruning metrics gathered so far.
    pub fn metrics(&self) -> &PruningMetrics {
        &self.metrics
    }

    /// Returns mutable access to the pruning metrics so callers can record results.
    pub fn metrics_mut(&mut self) -> &mut PruningMetrics {
        &mut self.metrics
    }

    /// Returns the message passed to the most recent [`fail`](Self::fail) call, if any.
    pub fn error(&self) -> Option<&str> {
        self.error.as_deref()
    }

    /// Installs `calibration` and moves from `Idle` to `Calibrating`.
    ///
    /// If the pipeline is not `Idle`, this is a no-op and `calibration` is dropped.
    pub fn start_calibration(&mut self, calibration: CalibrationCollector) {
        if self.stage != PruningStage::Idle {
            return;
        }
        self.calibration = Some(calibration);
        self.stage = PruningStage::Calibrating;
    }

    /// Moves to the next stage in the fixed sequence.
    ///
    /// The sequence is `Idle → Calibrating → ComputingImportance → Pruning →
    /// [FineTuning →] Evaluating → Exporting → Complete`. `FineTuning` is skipped when
    /// `config.fine_tune_after_pruning()` is false. `Complete` and `Failed` are absorbing:
    /// advancing from them leaves the stage unchanged.
    pub fn advance(&mut self) {
        self.stage = match self.stage {
            PruningStage::Idle => PruningStage::Calibrating,
            PruningStage::Calibrating => PruningStage::ComputingImportance,
            PruningStage::ComputingImportance => PruningStage::Pruning,
            PruningStage::Pruning => {
                if self.config.fine_tune_after_pruning() {
                    PruningStage::FineTuning
                } else {
                    PruningStage::Evaluating
                }
            }
            PruningStage::FineTuning => PruningStage::Evaluating,
            PruningStage::Evaluating => PruningStage::Exporting,
            PruningStage::Exporting => PruningStage::Complete,
            PruningStage::Complete | PruningStage::Failed => self.stage,
        };
    }

    /// Records `error` and moves to [`PruningStage::Failed`], whatever the current stage.
    pub fn fail(&mut self, error: impl Into<String>) {
        self.error = Some(error.into());
        self.stage = PruningStage::Failed;
    }

    /// Exports the sparse model. Only valid in the `Exporting` stage.
    ///
    /// On success the pipeline advances (to `Complete`) and the exporter's result is
    /// returned.
    ///
    /// # Errors
    ///
    /// - If the stage is not `Exporting`, returns an error without changing state.
    /// - If the exporter fails, the pipeline moves to `Failed`. The same
    ///   `"Export failed: …"` message is both stored (see [`error`](Self::error)) and
    ///   returned.
    pub fn execute_export(
        &mut self,
        weights: &std::collections::HashMap<String, Vec<f32>>,
        shapes: &std::collections::HashMap<String, Vec<usize>>,
        output_dir: impl AsRef<std::path::Path>,
        filename: &str,
    ) -> Result<super::sparse_export::SparseExportResult, String> {
        if self.stage != PruningStage::Exporting {
            return Err(format!("Cannot export in stage {:?}, expected Exporting", self.stage));
        }
        match super::sparse_export::export_sparse_model(
            weights,
            shapes,
            &self.metrics,
            output_dir,
            filename,
        ) {
            Ok(result) => {
                self.advance();
                Ok(result)
            }
            Err(e) => {
                // Format once so the stored error and the returned error are guaranteed identical.
                let message = format!("Export failed: {e}");
                self.fail(message.clone());
                Err(message)
            }
        }
    }

    /// Returns to `Idle` with fresh metrics, no calibration collector, and no error.
    pub fn reset(&mut self) {
        self.stage = PruningStage::Idle;
        self.metrics = PruningMetrics::new(self.config.target_sparsity());
        self.calibration = None;
        self.error = None;
    }

    /// Returns whether the current stage is terminal, as defined by
    /// `PruningStage::is_terminal`.
    pub fn is_complete(&self) -> bool {
        self.stage.is_terminal()
    }

    /// Returns `true` only when the pipeline reached [`PruningStage::Complete`].
    pub fn succeeded(&self) -> bool {
        self.stage == PruningStage::Complete
    }

    /// Returns `true` when the pipeline is in [`PruningStage::Failed`].
    pub fn failed(&self) -> bool {
        self.stage == PruningStage::Failed
    }

    /// Returns the installed calibration collector, if any.
    pub fn calibration(&self) -> Option<&CalibrationCollector> {
        self.calibration.as_ref()
    }

    /// Returns the progress reported by the installed collector, or `0.0` when none
    /// is installed.
    pub fn calibration_progress(&self) -> f32 {
        self.calibration.as_ref().map_or(0.0, CalibrationCollector::progress)
    }

    /// Returns a coarse overall progress estimate in `[0.0, 1.0]` for the current stage.
    ///
    /// Each stage maps to a fixed value. `Calibrating` interpolates between 0.1 and 0.2
    /// using the calibration progress. `Failed` reports `0.0`.
    pub fn overall_progress(&self) -> f32 {
        match self.stage {
            PruningStage::Idle => 0.0,
            // Clamp so a collector reporting outside [0, 1] cannot push this band past
            // the 0.25 of `ComputingImportance` (or below `Idle`).
            PruningStage::Calibrating => {
                0.1 + 0.1 * self.calibration_progress().clamp(0.0, 1.0)
            }
            PruningStage::ComputingImportance => 0.25,
            PruningStage::Pruning => 0.4,
            PruningStage::FineTuning => 0.6,
            PruningStage::Evaluating => 0.8,
            PruningStage::Exporting => 0.95,
            PruningStage::Complete => 1.0,
            PruningStage::Failed => 0.0,
        }
    }
}
impl Clone for PruneFinetunePipeline {
    /// Returns an independent copy of the pipeline: same config, stage, metrics,
    /// calibration collector and recorded error.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct without cloning it here
        // becomes a compile error rather than a silently-shared or missing value.
        let Self { config, stage, metrics, calibration, error } = self;
        Self {
            config: config.clone(),
            stage: *stage,
            metrics: metrics.clone(),
            calibration: calibration.clone(),
            error: error.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a pipeline from the default configuration.
    fn make_pipeline() -> PruneFinetunePipeline {
        PruneFinetunePipeline::new(PruningConfig::new())
    }

    /// Calls `advance` on `p` exactly `n` times.
    fn advance_n(p: &mut PruneFinetunePipeline, n: usize) {
        for _ in 0..n {
            p.advance();
        }
    }

    #[test]
    fn test_advance_from_idle() {
        let mut p = make_pipeline();
        assert_eq!(p.stage(), PruningStage::Idle);
        assert_eq!(p.calibration_progress(), 0.0);
        p.advance();
        assert_eq!(p.stage(), PruningStage::Calibrating);
    }

    #[test]
    fn test_advance_full_pipeline_with_finetune() {
        let mut p = make_pipeline();
        let expected = [
            PruningStage::Calibrating,
            PruningStage::ComputingImportance,
            PruningStage::Pruning,
            PruningStage::FineTuning,
            PruningStage::Evaluating,
            PruningStage::Exporting,
            PruningStage::Complete,
            // Complete is absorbing: one more advance must not move the pipeline.
            PruningStage::Complete,
        ];
        for &stage in &expected {
            p.advance();
            assert_eq!(p.stage(), stage);
        }
        assert!(p.succeeded());
        assert!(!p.failed());
    }

    #[test]
    fn test_advance_skip_finetune() {
        let config = PruningConfig::new().with_fine_tune(false);
        let mut p = PruneFinetunePipeline::new(config);
        advance_n(&mut p, 4);
        assert_eq!(p.stage(), PruningStage::Evaluating);
    }

    #[test]
    fn test_advance_failed_stays_failed() {
        let mut p = make_pipeline();
        p.fail("test error");
        assert_eq!(p.stage(), PruningStage::Failed);
        assert_eq!(p.error(), Some("test error"));
        assert!(p.failed());
        assert!(!p.succeeded());
        p.advance();
        assert_eq!(p.stage(), PruningStage::Failed);
    }

    #[test]
    fn test_overall_progress_all_stages() {
        let mut p = make_pipeline();
        assert_eq!(p.overall_progress(), 0.0);
        p.advance();
        assert!(p.overall_progress() >= 0.1);
        for &expected in &[0.25_f32, 0.4, 0.6, 0.8, 0.95, 1.0] {
            p.advance();
            assert_eq!(p.overall_progress(), expected);
        }
    }

    #[test]
    fn test_overall_progress_failed() {
        let mut p = make_pipeline();
        p.fail("test");
        assert_eq!(p.overall_progress(), 0.0);
    }

    #[test]
    fn test_reset_to_idle() {
        let mut p = make_pipeline();
        advance_n(&mut p, 2);
        p.reset();
        assert_eq!(p.stage(), PruningStage::Idle);
        assert!(p.error().is_none());
    }

    #[test]
    fn test_reset_clears_failure() {
        let mut p = make_pipeline();
        p.fail("boom");
        assert!(p.error().is_some());
        p.reset();
        assert_eq!(p.stage(), PruningStage::Idle);
        assert!(p.error().is_none());
        assert!(p.calibration().is_none());
        assert!(!p.failed());
    }

    #[test]
    fn test_export_rejected_outside_exporting_stage() {
        let mut p = make_pipeline();
        let weights: HashMap<String, Vec<f32>> = HashMap::new();
        let shapes: HashMap<String, Vec<usize>> = HashMap::new();
        // The stage check runs before any export work, so the path is never touched.
        let err = p
            .execute_export(&weights, &shapes, "unused_dir", "model")
            .err()
            .expect("export must be rejected while Idle");
        assert!(err.contains("expected Exporting"));
        // A stage mismatch is reported to the caller but is not a pipeline failure.
        assert_eq!(p.stage(), PruningStage::Idle);
        assert!(p.error().is_none());
    }
}