entrenar/prune/trainer_integration/
trainer.rs1use crate::prune::calibrate::CalibrationCollector;
6use crate::prune::data_loader::CalibrationDataLoader;
7use crate::prune::pipeline::{PruneFinetunePipeline, PruningMetrics, PruningStage};
8
9use super::config::PruneTrainerConfig;
10
/// Drives a [`PruneFinetunePipeline`] through its stages using calibration data
/// from a [`CalibrationDataLoader`].
#[derive(Debug)]
pub struct PruneTrainer {
    /// Configuration supplied to `new`; validated in `initialize`.
    config: PruneTrainerConfig,
    /// Pipeline whose stage, metrics and error state this trainer exposes and advances.
    pipeline: PruneFinetunePipeline,
    /// Calibration data source; loaded in `initialize`, iterated in `calibrate`.
    data_loader: CalibrationDataLoader,
    /// Collector created when the pruning config requires calibration. It is moved
    /// into the pipeline when calibration starts, leaving `None` behind.
    pub(crate) calibration: Option<CalibrationCollector>,
    /// Index of the fine-tuning epoch most recently started; `0` after `new` and `reset`.
    current_epoch: usize,
}
41
42impl PruneTrainer {
43 pub fn new(config: PruneTrainerConfig) -> Self {
45 let pipeline = PruneFinetunePipeline::new(config.pruning.clone());
46 let data_loader = CalibrationDataLoader::new(config.calibration.clone());
47
48 Self { config, pipeline, data_loader, calibration: None, current_epoch: 0 }
49 }
50
51 pub fn config(&self) -> &PruneTrainerConfig {
53 &self.config
54 }
55
56 pub fn pipeline(&self) -> &PruneFinetunePipeline {
58 &self.pipeline
59 }
60
61 pub fn pipeline_mut(&mut self) -> &mut PruneFinetunePipeline {
63 &mut self.pipeline
64 }
65
66 pub fn stage(&self) -> PruningStage {
68 self.pipeline.stage()
69 }
70
71 pub fn metrics(&self) -> &PruningMetrics {
73 self.pipeline.metrics()
74 }
75
76 pub fn current_epoch(&self) -> usize {
78 self.current_epoch
79 }
80
81 pub fn is_complete(&self) -> bool {
83 self.pipeline.is_complete()
84 }
85
86 pub fn succeeded(&self) -> bool {
88 self.pipeline.succeeded()
89 }
90
91 pub fn error(&self) -> Option<&str> {
93 self.pipeline.error()
94 }
95
96 pub fn initialize(&mut self) -> Result<(), String> {
98 self.config.validate()?;
100
101 self.data_loader.load()?;
103
104 if self.config.pruning.requires_calibration() {
106 self.calibration = Some(CalibrationCollector::new(
107 crate::prune::calibrate::CalibrationConfig::new()
108 .with_num_samples(self.config.calibration.num_samples()),
109 ));
110 }
111
112 Ok(())
113 }
114
115 pub fn calibrate(&mut self) -> Result<(), String> {
122 if self.pipeline.stage() != PruningStage::Idle
123 && self.pipeline.stage() != PruningStage::Calibrating
124 {
125 return Err("Cannot calibrate in current stage".to_string());
126 }
127
128 if self.calibration.is_none() && self.config.pruning.requires_calibration() {
130 self.calibration = Some(CalibrationCollector::new(
131 crate::prune::calibrate::CalibrationConfig::new()
132 .with_num_samples(self.config.calibration.num_samples()),
133 ));
134 }
135
136 if self.pipeline.stage() == PruningStage::Idle {
138 if let Some(cal) = self.calibration.take() {
139 self.pipeline.start_calibration(cal);
140 } else {
141 self.pipeline.advance();
143 }
144 }
145
146 for _batch in &self.data_loader {
148 }
153
154 if self.pipeline.stage() == PruningStage::Calibrating {
156 self.pipeline.advance();
157 }
158
159 Ok(())
160 }
161
162 pub fn prune(&mut self) -> Result<(), String> {
169 while self.pipeline.stage() == PruningStage::ComputingImportance {
171 self.pipeline.advance();
173 }
174
175 if self.pipeline.stage() != PruningStage::Pruning {
176 return Err(format!("Cannot prune in stage {:?}", self.pipeline.stage()));
177 }
178
179 let target_sparsity = self.config.pruning.target_sparsity();
185 self.pipeline.metrics_mut().target_sparsity = target_sparsity;
186 self.pipeline.metrics_mut().achieved_sparsity = target_sparsity;
187
188 self.pipeline.advance();
189 Ok(())
190 }
191
192 pub fn finetune(&mut self) -> Result<(), String> {
199 if self.pipeline.stage() != PruningStage::FineTuning {
200 return Err(format!("Cannot finetune in stage {:?}", self.pipeline.stage()));
201 }
202
203 for epoch in 0..self.config.finetune_epochs {
204 self.current_epoch = epoch;
205
206 let loss = 1.0 / (epoch + 1) as f32;
213 self.pipeline.metrics_mut().record_finetune_loss(loss);
214 }
215
216 self.pipeline.advance();
217 Ok(())
218 }
219
220 pub fn evaluate(&mut self) -> Result<(), String> {
222 if self.pipeline.stage() != PruningStage::Evaluating {
223 return Err(format!("Cannot evaluate in stage {:?}", self.pipeline.stage()));
224 }
225
226 self.pipeline.advance();
232 Ok(())
233 }
234
235 pub fn export(&mut self) -> Result<(), String> {
237 if self.pipeline.stage() != PruningStage::Exporting {
238 return Err(format!("Cannot export in stage {:?}", self.pipeline.stage()));
239 }
240
241 self.pipeline.advance();
246 Ok(())
247 }
248
249 pub fn run(&mut self) -> Result<PruningMetrics, String> {
251 self.initialize()?;
252 self.calibrate()?;
253 self.prune()?;
254
255 if self.config.pruning.fine_tune_after_pruning() {
256 self.finetune()?;
257 }
258
259 if self.config.evaluate_pre_post {
260 self.evaluate()?;
261 }
262
263 self.export()?;
264
265 Ok(self.metrics().clone())
266 }
267
268 pub fn reset(&mut self) {
270 self.pipeline.reset();
271 self.calibration = None;
272 self.current_epoch = 0;
273 self.data_loader.reset();
274 }
275}
276
277impl Clone for PruneTrainer {
278 fn clone(&self) -> Self {
279 Self {
280 config: self.config.clone(),
281 pipeline: self.pipeline.clone(),
282 data_loader: self.data_loader.clone(),
283 calibration: self.calibration.clone(),
284 current_epoch: self.current_epoch,
285 }
286 }
287}