entrenar/prune/pipeline/
orchestrator.rs1use super::metrics::PruningMetrics;
6use super::stage::PruningStage;
7use crate::prune::calibrate::CalibrationCollector;
8use crate::prune::config::PruningConfig;
9
10#[derive(Debug)]
14pub struct PruneFinetunePipeline {
15 config: PruningConfig,
17 stage: PruningStage,
19 metrics: PruningMetrics,
21 calibration: Option<CalibrationCollector>,
23 error: Option<String>,
25}
26
27impl PruneFinetunePipeline {
28 pub fn new(config: PruningConfig) -> Self {
30 let metrics = PruningMetrics::new(config.target_sparsity());
31 Self { config, stage: PruningStage::Idle, metrics, calibration: None, error: None }
32 }
33
34 pub fn stage(&self) -> PruningStage {
36 self.stage
37 }
38
39 pub fn config(&self) -> &PruningConfig {
41 &self.config
42 }
43
44 pub fn metrics(&self) -> &PruningMetrics {
46 &self.metrics
47 }
48
49 pub fn metrics_mut(&mut self) -> &mut PruningMetrics {
51 &mut self.metrics
52 }
53
54 pub fn error(&self) -> Option<&str> {
56 self.error.as_deref()
57 }
58
59 pub fn start_calibration(&mut self, calibration: CalibrationCollector) {
61 if self.stage != PruningStage::Idle {
62 return;
63 }
64 self.calibration = Some(calibration);
65 self.stage = PruningStage::Calibrating;
66 }
67
68 pub fn advance(&mut self) {
70 self.stage = match self.stage {
71 PruningStage::Idle => PruningStage::Calibrating,
72 PruningStage::Calibrating => PruningStage::ComputingImportance,
73 PruningStage::ComputingImportance => PruningStage::Pruning,
74 PruningStage::Pruning => {
75 if self.config.fine_tune_after_pruning() {
76 PruningStage::FineTuning
77 } else {
78 PruningStage::Evaluating
79 }
80 }
81 PruningStage::FineTuning => PruningStage::Evaluating,
82 PruningStage::Evaluating => PruningStage::Exporting,
83 PruningStage::Exporting => PruningStage::Complete,
84 PruningStage::Complete | PruningStage::Failed => self.stage,
86 };
87 }
88
89 pub fn fail(&mut self, error: impl Into<String>) {
91 self.error = Some(error.into());
92 self.stage = PruningStage::Failed;
93 }
94
95 pub fn execute_export(
103 &mut self,
104 weights: &std::collections::HashMap<String, Vec<f32>>,
105 shapes: &std::collections::HashMap<String, Vec<usize>>,
106 output_dir: impl AsRef<std::path::Path>,
107 filename: &str,
108 ) -> Result<super::sparse_export::SparseExportResult, String> {
109 if self.stage != PruningStage::Exporting {
110 return Err(format!("Cannot export in stage {:?}, expected Exporting", self.stage));
111 }
112
113 match super::sparse_export::export_sparse_model(
114 weights,
115 shapes,
116 &self.metrics,
117 output_dir,
118 filename,
119 ) {
120 Ok(result) => {
121 self.advance(); Ok(result)
123 }
124 Err(e) => {
125 self.fail(format!("Export failed: {e}"));
126 Err(format!("Export failed: {e}"))
127 }
128 }
129 }
130
131 pub fn reset(&mut self) {
133 self.stage = PruningStage::Idle;
134 self.metrics = PruningMetrics::new(self.config.target_sparsity());
135 self.calibration = None;
136 self.error = None;
137 }
138
139 pub fn is_complete(&self) -> bool {
141 self.stage.is_terminal()
142 }
143
144 pub fn succeeded(&self) -> bool {
146 self.stage == PruningStage::Complete
147 }
148
149 pub fn failed(&self) -> bool {
151 self.stage == PruningStage::Failed
152 }
153
154 pub fn calibration(&self) -> Option<&CalibrationCollector> {
156 self.calibration.as_ref()
157 }
158
159 pub fn calibration_progress(&self) -> f32 {
161 self.calibration.as_ref().map_or(0.0, CalibrationCollector::progress)
162 }
163
164 pub fn overall_progress(&self) -> f32 {
166 match self.stage {
167 PruningStage::Idle => 0.0,
168 PruningStage::Calibrating => 0.1 + 0.1 * self.calibration_progress(),
169 PruningStage::ComputingImportance => 0.25,
170 PruningStage::Pruning => 0.4,
171 PruningStage::FineTuning => 0.6,
172 PruningStage::Evaluating => 0.8,
173 PruningStage::Exporting => 0.95,
174 PruningStage::Complete => 1.0,
175 PruningStage::Failed => 0.0, }
177 }
178}
179
180impl Clone for PruneFinetunePipeline {
181 fn clone(&self) -> Self {
182 Self {
183 config: self.config.clone(),
184 stage: self.stage,
185 metrics: self.metrics.clone(),
186 calibration: self.calibration.clone(),
187 error: self.error.clone(),
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    fn make_pipeline() -> PruneFinetunePipeline {
        PruneFinetunePipeline::new(PruningConfig::new())
    }

    /// Calls `advance` on `p` exactly `n` times.
    fn advance_n(p: &mut PruneFinetunePipeline, n: usize) {
        for _ in 0..n {
            p.advance();
        }
    }

    #[test]
    fn test_advance_from_idle() {
        let mut p = make_pipeline();
        assert_eq!(p.stage(), PruningStage::Idle);
        p.advance();
        assert_eq!(p.stage(), PruningStage::Calibrating);
    }

    #[test]
    fn test_advance_full_pipeline_with_finetune() {
        let mut p = make_pipeline();
        let expected = [
            PruningStage::Calibrating,
            PruningStage::ComputingImportance,
            PruningStage::Pruning,
            PruningStage::FineTuning,
            PruningStage::Evaluating,
            PruningStage::Exporting,
            PruningStage::Complete,
            // Complete is absorbing: a further advance keeps it there.
            PruningStage::Complete,
        ];
        for &stage in &expected {
            p.advance();
            assert_eq!(p.stage(), stage);
        }
    }

    #[test]
    fn test_advance_skip_finetune() {
        let config = PruningConfig::new().with_fine_tune(false);
        let mut p = PruneFinetunePipeline::new(config);
        advance_n(&mut p, 4);
        assert_eq!(p.stage(), PruningStage::Evaluating);
        advance_n(&mut p, 2);
        assert_eq!(p.stage(), PruningStage::Complete);
    }

    #[test]
    fn test_advance_failed_stays_failed() {
        let mut p = make_pipeline();
        p.fail("test error");
        assert_eq!(p.stage(), PruningStage::Failed);
        p.advance();
        assert_eq!(p.stage(), PruningStage::Failed);
    }

    #[test]
    fn test_overall_progress_all_stages() {
        let mut p = make_pipeline();
        assert_eq!(p.overall_progress(), 0.0);

        p.advance();
        assert!(p.overall_progress() >= 0.1);

        p.advance();
        assert_eq!(p.overall_progress(), 0.25);

        p.advance();
        assert_eq!(p.overall_progress(), 0.4);

        p.advance();
        assert_eq!(p.overall_progress(), 0.6);

        p.advance();
        assert_eq!(p.overall_progress(), 0.8);

        p.advance();
        assert_eq!(p.overall_progress(), 0.95);

        p.advance();
        assert_eq!(p.overall_progress(), 1.0);
    }

    #[test]
    fn test_calibrating_without_collector() {
        let mut p = make_pipeline();
        p.advance();
        assert_eq!(p.stage(), PruningStage::Calibrating);
        assert!(p.calibration().is_none());
        assert_eq!(p.calibration_progress(), 0.0);
        assert_eq!(p.overall_progress(), 0.1);
    }

    #[test]
    fn test_overall_progress_failed() {
        let mut p = make_pipeline();
        p.fail("test");
        assert_eq!(p.overall_progress(), 0.0);
    }

    #[test]
    fn test_terminal_predicates() {
        let mut p = make_pipeline();
        assert!(!p.succeeded());
        assert!(!p.failed());

        advance_n(&mut p, 7);
        assert_eq!(p.stage(), PruningStage::Complete);
        assert!(p.succeeded());
        assert!(!p.failed());
        assert!(p.is_complete());

        let mut q = make_pipeline();
        q.fail("boom");
        assert!(q.failed());
        assert!(!q.succeeded());
        assert!(q.is_complete());
    }

    #[test]
    fn test_fail_records_error_and_reset_clears_it() {
        let mut p = make_pipeline();
        p.fail(String::from("boom"));
        assert_eq!(p.error(), Some("boom"));

        p.reset();
        assert_eq!(p.stage(), PruningStage::Idle);
        assert!(p.error().is_none());
        assert!(!p.failed());
        assert!(p.calibration().is_none());
    }

    #[test]
    fn test_export_rejected_outside_exporting_stage() {
        let mut p = make_pipeline();
        let weights = std::collections::HashMap::new();
        let shapes = std::collections::HashMap::new();

        // The stage guard returns before any export work, so nothing is written.
        match p.execute_export(&weights, &shapes, std::env::temp_dir(), "unused_model") {
            Ok(_) => panic!("export must be rejected outside the Exporting stage"),
            Err(e) => assert!(e.contains("expected Exporting"), "unexpected error: {e}"),
        }

        // A rejected export is not a pipeline failure.
        assert_eq!(p.stage(), PruningStage::Idle);
        assert!(p.error().is_none());
    }

    #[test]
    fn test_reset_to_idle() {
        let mut p = make_pipeline();
        p.advance();
        p.advance();
        p.reset();
        assert_eq!(p.stage(), PruningStage::Idle);
        assert!(p.error().is_none());
    }
}