Skip to main content

entrenar/prune/pipeline/
orchestrator.rs

1//! Prune-Finetune pipeline orchestrator
2//!
3//! Manages the full pruning workflow from calibration through export.
4
5use super::metrics::PruningMetrics;
6use super::stage::PruningStage;
7use crate::prune::calibrate::CalibrationCollector;
8use crate::prune::config::PruningConfig;
9
10/// Prune-Finetune pipeline orchestrator.
11///
12/// Manages the full pruning workflow from calibration through export.
13#[derive(Debug)]
14pub struct PruneFinetunePipeline {
15    /// Configuration.
16    config: PruningConfig,
17    /// Current stage.
18    stage: PruningStage,
19    /// Collected metrics.
20    metrics: PruningMetrics,
21    /// Calibration collector.
22    calibration: Option<CalibrationCollector>,
23    /// Error message if failed.
24    error: Option<String>,
25}
26
27impl PruneFinetunePipeline {
28    /// Create a new pipeline with the given configuration.
29    pub fn new(config: PruningConfig) -> Self {
30        let metrics = PruningMetrics::new(config.target_sparsity());
31        Self { config, stage: PruningStage::Idle, metrics, calibration: None, error: None }
32    }
33
34    /// Get the current stage.
35    pub fn stage(&self) -> PruningStage {
36        self.stage
37    }
38
39    /// Get the configuration.
40    pub fn config(&self) -> &PruningConfig {
41        &self.config
42    }
43
44    /// Get the collected metrics.
45    pub fn metrics(&self) -> &PruningMetrics {
46        &self.metrics
47    }
48
49    /// Get mutable access to metrics.
50    pub fn metrics_mut(&mut self) -> &mut PruningMetrics {
51        &mut self.metrics
52    }
53
54    /// Get the error message if failed.
55    pub fn error(&self) -> Option<&str> {
56        self.error.as_deref()
57    }
58
59    /// Start the calibration stage.
60    pub fn start_calibration(&mut self, calibration: CalibrationCollector) {
61        if self.stage != PruningStage::Idle {
62            return;
63        }
64        self.calibration = Some(calibration);
65        self.stage = PruningStage::Calibrating;
66    }
67
68    /// Advance to the next stage.
69    pub fn advance(&mut self) {
70        self.stage = match self.stage {
71            PruningStage::Idle => PruningStage::Calibrating,
72            PruningStage::Calibrating => PruningStage::ComputingImportance,
73            PruningStage::ComputingImportance => PruningStage::Pruning,
74            PruningStage::Pruning => {
75                if self.config.fine_tune_after_pruning() {
76                    PruningStage::FineTuning
77                } else {
78                    PruningStage::Evaluating
79                }
80            }
81            PruningStage::FineTuning => PruningStage::Evaluating,
82            PruningStage::Evaluating => PruningStage::Exporting,
83            PruningStage::Exporting => PruningStage::Complete,
84            // Terminal states don't advance
85            PruningStage::Complete | PruningStage::Failed => self.stage,
86        };
87    }
88
89    /// Mark the pipeline as failed with an error message.
90    pub fn fail(&mut self, error: impl Into<String>) {
91        self.error = Some(error.into());
92        self.stage = PruningStage::Failed;
93    }
94
95    /// Execute the export stage.
96    ///
97    /// This is called when the pipeline reaches the `Exporting` stage.
98    /// Exports the pruned model weights and sparsity metadata.
99    ///
100    /// Returns `Ok(())` and advances to `Complete` on success, or
101    /// sets the pipeline to `Failed` on error.
102    pub fn execute_export(
103        &mut self,
104        weights: &std::collections::HashMap<String, Vec<f32>>,
105        shapes: &std::collections::HashMap<String, Vec<usize>>,
106        output_dir: impl AsRef<std::path::Path>,
107        filename: &str,
108    ) -> Result<super::sparse_export::SparseExportResult, String> {
109        if self.stage != PruningStage::Exporting {
110            return Err(format!("Cannot export in stage {:?}, expected Exporting", self.stage));
111        }
112
113        match super::sparse_export::export_sparse_model(
114            weights,
115            shapes,
116            &self.metrics,
117            output_dir,
118            filename,
119        ) {
120            Ok(result) => {
121                self.advance(); // -> Complete
122                Ok(result)
123            }
124            Err(e) => {
125                self.fail(format!("Export failed: {e}"));
126                Err(format!("Export failed: {e}"))
127            }
128        }
129    }
130
131    /// Reset the pipeline to idle state.
132    pub fn reset(&mut self) {
133        self.stage = PruningStage::Idle;
134        self.metrics = PruningMetrics::new(self.config.target_sparsity());
135        self.calibration = None;
136        self.error = None;
137    }
138
139    /// Check if the pipeline is complete (success or failure).
140    pub fn is_complete(&self) -> bool {
141        self.stage.is_terminal()
142    }
143
144    /// Check if the pipeline succeeded.
145    pub fn succeeded(&self) -> bool {
146        self.stage == PruningStage::Complete
147    }
148
149    /// Check if the pipeline failed.
150    pub fn failed(&self) -> bool {
151        self.stage == PruningStage::Failed
152    }
153
154    /// Get calibration collector if available.
155    pub fn calibration(&self) -> Option<&CalibrationCollector> {
156        self.calibration.as_ref()
157    }
158
159    /// Get calibration progress (0.0 to 1.0).
160    pub fn calibration_progress(&self) -> f32 {
161        self.calibration.as_ref().map_or(0.0, CalibrationCollector::progress)
162    }
163
164    /// Get overall pipeline progress (0.0 to 1.0).
165    pub fn overall_progress(&self) -> f32 {
166        match self.stage {
167            PruningStage::Idle => 0.0,
168            PruningStage::Calibrating => 0.1 + 0.1 * self.calibration_progress(),
169            PruningStage::ComputingImportance => 0.25,
170            PruningStage::Pruning => 0.4,
171            PruningStage::FineTuning => 0.6,
172            PruningStage::Evaluating => 0.8,
173            PruningStage::Exporting => 0.95,
174            PruningStage::Complete => 1.0,
175            PruningStage::Failed => 0.0, // Reset on failure
176        }
177    }
178}
179
180impl Clone for PruneFinetunePipeline {
181    fn clone(&self) -> Self {
182        Self {
183            config: self.config.clone(),
184            stage: self.stage,
185            metrics: self.metrics.clone(),
186            calibration: self.calibration.clone(),
187            error: self.error.clone(),
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a pipeline from the default pruning configuration.
    fn make_pipeline() -> PruneFinetunePipeline {
        let config = PruningConfig::new();
        PruneFinetunePipeline::new(config)
    }

    #[test]
    fn test_advance_from_idle() {
        let mut pipeline = make_pipeline();
        assert_eq!(pipeline.stage(), PruningStage::Idle);

        pipeline.advance();
        assert_eq!(pipeline.stage(), PruningStage::Calibrating);
    }

    #[test]
    fn test_advance_full_pipeline_with_finetune() {
        // Default config has fine_tune_after_pruning=true, so FineTuning is visited.
        let mut pipeline = make_pipeline();

        // Stage expected after each successive call to `advance`.
        let expected_stages = [
            PruningStage::Calibrating,
            PruningStage::ComputingImportance,
            PruningStage::Pruning,
            PruningStage::FineTuning,
            PruningStage::Evaluating,
            PruningStage::Exporting,
            PruningStage::Complete,
            // Complete stays Complete
            PruningStage::Complete,
        ];

        for expected in expected_stages {
            pipeline.advance();
            assert_eq!(pipeline.stage(), expected);
        }
    }

    #[test]
    fn test_advance_skip_finetune() {
        let mut pipeline = PruneFinetunePipeline::new(PruningConfig::new().with_fine_tune(false));

        // Idle → Calibrating → ComputingImportance → Pruning
        for _ in 0..3 {
            pipeline.advance();
        }

        // Pruning → Evaluating (fine_tune_after_pruning=false)
        pipeline.advance();
        assert_eq!(pipeline.stage(), PruningStage::Evaluating);
    }

    #[test]
    fn test_advance_failed_stays_failed() {
        let mut pipeline = make_pipeline();

        pipeline.fail("test error");
        assert_eq!(pipeline.stage(), PruningStage::Failed);

        pipeline.advance();
        assert_eq!(pipeline.stage(), PruningStage::Failed);
    }

    #[test]
    fn test_overall_progress_all_stages() {
        // Default config has fine_tune_after_pruning=true
        let mut pipeline = make_pipeline();

        // Idle → 0.0
        assert_eq!(pipeline.overall_progress(), 0.0);

        // Calibrating → at least 0.1
        pipeline.advance();
        assert!(pipeline.overall_progress() >= 0.1);

        // ComputingImportance, Pruning, FineTuning, Evaluating, Exporting, Complete
        for expected in [0.25_f32, 0.4, 0.6, 0.8, 0.95, 1.0] {
            pipeline.advance();
            assert_eq!(pipeline.overall_progress(), expected);
        }
    }

    #[test]
    fn test_overall_progress_failed() {
        let mut pipeline = make_pipeline();
        pipeline.fail("test");
        assert_eq!(pipeline.overall_progress(), 0.0);
    }

    #[test]
    fn test_reset_to_idle() {
        let mut pipeline = make_pipeline();
        for _ in 0..2 {
            pipeline.advance();
        }

        pipeline.reset();

        assert_eq!(pipeline.stage(), PruningStage::Idle);
        assert!(pipeline.error().is_none());
    }
}