Skip to main content

entrenar/prune/trainer_integration/
trainer.rs

1//! Prune-finetune trainer implementation.
2//!
3//! Provides the main trainer that orchestrates the full pruning pipeline.
4
5use crate::prune::calibrate::CalibrationCollector;
6use crate::prune::data_loader::CalibrationDataLoader;
7use crate::prune::pipeline::{PruneFinetunePipeline, PruningMetrics, PruningStage};
8
9use super::config::PruneTrainerConfig;
10
/// Prune-finetune trainer that orchestrates the full pipeline.
///
/// The trainer owns a [`PruneFinetunePipeline`] and drives it through its
/// stages: calibration, importance computation, pruning, optional
/// fine-tuning, optional evaluation and export. Each phase is exposed as its
/// own method, and [`PruneTrainer::run`] chains them together.
///
/// # Example
///
/// ```ignore
/// use entrenar::prune::{PruneTrainer, PruneTrainerConfig, PruningConfig};
///
/// fn example() -> Result<(), Box<dyn std::error::Error>> {
///     let config = PruneTrainerConfig::new()
///         .with_pruning(PruningConfig::default().with_target_sparsity(0.5))
///         .with_finetune_epochs(3);
///
///     let mut trainer = PruneTrainer::new(config);
///     trainer.run()?;
///     Ok(())
/// }
/// ```
#[derive(Debug)]
pub struct PruneTrainer {
    /// Trainer configuration. Its `pruning` and `calibration` sub-configs are
    /// cloned into the pipeline and the data loader at construction time.
    config: PruneTrainerConfig,
    /// Pipeline state: current stage, metrics, and error/completion status.
    pipeline: PruneFinetunePipeline,
    /// Loader for calibration batches; populated by `initialize`.
    data_loader: CalibrationDataLoader,
    /// Calibration collector, created only when the pruning config requires
    /// calibration. It is moved into the pipeline (leaving `None`) when the
    /// calibration stage starts.
    pub(crate) calibration: Option<CalibrationCollector>,
    /// Zero-based index of the most recently started fine-tuning epoch
    /// (0 before fine-tuning has run).
    current_epoch: usize,
}
41
42impl PruneTrainer {
43    /// Create a new prune trainer.
44    pub fn new(config: PruneTrainerConfig) -> Self {
45        let pipeline = PruneFinetunePipeline::new(config.pruning.clone());
46        let data_loader = CalibrationDataLoader::new(config.calibration.clone());
47
48        Self { config, pipeline, data_loader, calibration: None, current_epoch: 0 }
49    }
50
51    /// Get the configuration.
52    pub fn config(&self) -> &PruneTrainerConfig {
53        &self.config
54    }
55
56    /// Get the current pipeline state.
57    pub fn pipeline(&self) -> &PruneFinetunePipeline {
58        &self.pipeline
59    }
60
61    /// Get mutable access to the pipeline.
62    pub fn pipeline_mut(&mut self) -> &mut PruneFinetunePipeline {
63        &mut self.pipeline
64    }
65
66    /// Get the current stage.
67    pub fn stage(&self) -> PruningStage {
68        self.pipeline.stage()
69    }
70
71    /// Get the metrics.
72    pub fn metrics(&self) -> &PruningMetrics {
73        self.pipeline.metrics()
74    }
75
76    /// Get the current fine-tuning epoch.
77    pub fn current_epoch(&self) -> usize {
78        self.current_epoch
79    }
80
81    /// Check if the trainer is complete.
82    pub fn is_complete(&self) -> bool {
83        self.pipeline.is_complete()
84    }
85
86    /// Check if training succeeded.
87    pub fn succeeded(&self) -> bool {
88        self.pipeline.succeeded()
89    }
90
91    /// Get the error message if failed.
92    pub fn error(&self) -> Option<&str> {
93        self.pipeline.error()
94    }
95
96    /// Initialize the trainer and load calibration data.
97    pub fn initialize(&mut self) -> Result<(), String> {
98        // Validate configuration
99        self.config.validate()?;
100
101        // Load calibration data
102        self.data_loader.load()?;
103
104        // Initialize calibration collector if needed
105        if self.config.pruning.requires_calibration() {
106            self.calibration = Some(CalibrationCollector::new(
107                crate::prune::calibrate::CalibrationConfig::new()
108                    .with_num_samples(self.config.calibration.num_samples()),
109            ));
110        }
111
112        Ok(())
113    }
114
115    /// Run calibration phase.
116    ///
117    /// In a real implementation, this would:
118    /// 1. Run forward passes through the model
119    /// 2. Collect activation statistics for each layer
120    /// 3. Store statistics in the calibration collector
121    pub fn calibrate(&mut self) -> Result<(), String> {
122        if self.pipeline.stage() != PruningStage::Idle
123            && self.pipeline.stage() != PruningStage::Calibrating
124        {
125            return Err("Cannot calibrate in current stage".to_string());
126        }
127
128        // Initialize calibration if not done
129        if self.calibration.is_none() && self.config.pruning.requires_calibration() {
130            self.calibration = Some(CalibrationCollector::new(
131                crate::prune::calibrate::CalibrationConfig::new()
132                    .with_num_samples(self.config.calibration.num_samples()),
133            ));
134        }
135
136        // Start calibration stage if at Idle
137        if self.pipeline.stage() == PruningStage::Idle {
138            if let Some(cal) = self.calibration.take() {
139                self.pipeline.start_calibration(cal);
140            } else {
141                // No calibration needed, advance from Idle to Calibrating
142                self.pipeline.advance();
143            }
144        }
145
146        // Process calibration batches
147        for _batch in &self.data_loader {
148            // In real implementation:
149            // 1. Forward pass through model
150            // 2. Extract activations at each layer
151            // 3. Update calibration statistics
152        }
153
154        // Advance from Calibrating to ComputingImportance
155        if self.pipeline.stage() == PruningStage::Calibrating {
156            self.pipeline.advance();
157        }
158
159        Ok(())
160    }
161
162    /// Run the pruning phase.
163    ///
164    /// In a real implementation, this would:
165    /// 1. Compute importance scores using calibration data
166    /// 2. Generate sparsity masks
167    /// 3. Apply masks to model weights
168    pub fn prune(&mut self) -> Result<(), String> {
169        // Advance through importance computation
170        while self.pipeline.stage() == PruningStage::ComputingImportance {
171            // Compute importance scores
172            self.pipeline.advance();
173        }
174
175        if self.pipeline.stage() != PruningStage::Pruning {
176            return Err(format!("Cannot prune in stage {:?}", self.pipeline.stage()));
177        }
178
179        // In real implementation:
180        // 1. Generate masks based on importance and target sparsity
181        // 2. Apply masks to model weights
182        // 3. Record metrics
183
184        let target_sparsity = self.config.pruning.target_sparsity();
185        self.pipeline.metrics_mut().target_sparsity = target_sparsity;
186        self.pipeline.metrics_mut().achieved_sparsity = target_sparsity;
187
188        self.pipeline.advance();
189        Ok(())
190    }
191
192    /// Run the fine-tuning phase.
193    ///
194    /// In a real implementation, this would:
195    /// 1. Set up optimizer with fine-tuning learning rate
196    /// 2. Run training epochs
197    /// 3. Track loss and metrics
198    pub fn finetune(&mut self) -> Result<(), String> {
199        if self.pipeline.stage() != PruningStage::FineTuning {
200            return Err(format!("Cannot finetune in stage {:?}", self.pipeline.stage()));
201        }
202
203        for epoch in 0..self.config.finetune_epochs {
204            self.current_epoch = epoch;
205
206            // In real implementation:
207            // 1. Run training epoch
208            // 2. Track loss
209            // 3. Optionally save checkpoint
210
211            // Simulate loss decrease
212            let loss = 1.0 / (epoch + 1) as f32;
213            self.pipeline.metrics_mut().record_finetune_loss(loss);
214        }
215
216        self.pipeline.advance();
217        Ok(())
218    }
219
220    /// Run evaluation phase.
221    pub fn evaluate(&mut self) -> Result<(), String> {
222        if self.pipeline.stage() != PruningStage::Evaluating {
223            return Err(format!("Cannot evaluate in stage {:?}", self.pipeline.stage()));
224        }
225
226        // In real implementation:
227        // 1. Run evaluation on validation set
228        // 2. Compute perplexity/accuracy
229        // 3. Record metrics
230
231        self.pipeline.advance();
232        Ok(())
233    }
234
235    /// Run export phase.
236    pub fn export(&mut self) -> Result<(), String> {
237        if self.pipeline.stage() != PruningStage::Exporting {
238            return Err(format!("Cannot export in stage {:?}", self.pipeline.stage()));
239        }
240
241        // In real implementation:
242        // 1. Save pruned model
243        // 2. Export to desired format (SafeTensors, GGUF, etc.)
244
245        self.pipeline.advance();
246        Ok(())
247    }
248
249    /// Run the full prune-finetune pipeline.
250    pub fn run(&mut self) -> Result<PruningMetrics, String> {
251        self.initialize()?;
252        self.calibrate()?;
253        self.prune()?;
254
255        if self.config.pruning.fine_tune_after_pruning() {
256            self.finetune()?;
257        }
258
259        if self.config.evaluate_pre_post {
260            self.evaluate()?;
261        }
262
263        self.export()?;
264
265        Ok(self.metrics().clone())
266    }
267
268    /// Reset the trainer to initial state.
269    pub fn reset(&mut self) {
270        self.pipeline.reset();
271        self.calibration = None;
272        self.current_epoch = 0;
273        self.data_loader.reset();
274    }
275}
276
277impl Clone for PruneTrainer {
278    fn clone(&self) -> Self {
279        Self {
280            config: self.config.clone(),
281            pipeline: self.pipeline.clone(),
282            data_loader: self.data_loader.clone(),
283            calibration: self.calibration.clone(),
284            current_epoch: self.current_epoch,
285        }
286    }
287}