//! Learner definition and training checkpointing utilities
//! (burn_train/learner/base.rs).
1use crate::LearningComponentsMarker;
2use crate::checkpoint::{
3    AsyncCheckpointer, Checkpointer, CheckpointingAction, CheckpointingStrategy,
4};
5use crate::components::{LearningComponentsTypes, TrainingBackend};
6use crate::metric::store::EventStoreClient;
7use crate::{
8    CloneEarlyStoppingStrategy, InferenceStep, TrainOutput, TrainStep, TrainingModelInput,
9    TrainingModelOutput,
10};
11use burn_core::module::{AutodiffModule, Module};
12use burn_core::tensor::Device;
13use burn_core::tensor::backend::AutodiffBackend;
14use burn_optim::lr_scheduler::LrScheduler;
15use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{Arc, Mutex};
18
/// The record of the learner's model, as produced by [`Module::into_record`].
/// Used when saving and restoring model checkpoints.
pub type LearnerModelRecord<LC> =
    <<LC as LearningComponentsTypes>::TrainingModel as Module<TrainingBackend<LC>>>::Record;
/// The record of the optimizer state (e.g. momentum buffers), used when
/// saving and restoring optimizer checkpoints.
pub type LearnerOptimizerRecord<LC> = <<LC as LearningComponentsTypes>::Optimizer as Optimizer<
    <LC as LearningComponentsTypes>::TrainingModel,
    TrainingBackend<LC>,
>>::Record;
/// The record of the LR scheduler state, used when saving and restoring
/// scheduler checkpoints.
pub type LearnerSchedulerRecord<LC> =
    <<LC as LearningComponentsTypes>::LrScheduler as LrScheduler>::Record<TrainingBackend<LC>>;
30
/// Learner struct encapsulating all components necessary to train a Neural Network model.
pub struct Learner<LC: LearningComponentsTypes> {
    // The model being trained; crate-visible so the training loop and the
    // checkpointer (below) can read it directly.
    pub(crate) model: LC::TrainingModel,
    // Optimizer applying gradients to the model parameters.
    optim: LC::Optimizer,
    // Scheduler that produces the learning rate for each step.
    lr_scheduler: LC::LrScheduler,
    // Learning rate returned by the last `lr_scheduler.step()` call;
    // 0.0 until `lr_step` is called for the first time (see `new`).
    lr: f64,
}
38
39impl<LC: LearningComponentsTypes> Clone for Learner<LC> {
40    fn clone(&self) -> Self {
41        Self {
42            model: self.model.clone(),
43            optim: self.optim.clone(),
44            lr_scheduler: self.lr_scheduler.clone(),
45            lr: self.lr,
46        }
47    }
48}
49
50impl<B, LR, M, O> Learner<LearningComponentsMarker<B, LR, M, O>>
51where
52    B: AutodiffBackend,
53    LR: LrScheduler + 'static,
54    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
55    M::InnerModule: InferenceStep,
56    O: Optimizer<M, B> + 'static,
57{
58    /// Create a learner.
59    pub fn new(model: M, optim: O, lr_scheduler: LR) -> Self {
60        Self {
61            model,
62            optim,
63            lr_scheduler,
64            lr: 0.0,
65        }
66    }
67}
68
69impl<LC: LearningComponentsTypes> Learner<LC> {
70    /// Fork the learner's model to the given device.
71    pub fn fork(&mut self, device: &Device<TrainingBackend<LC>>) {
72        self.model = self.model().fork(device);
73    }
74
75    /// Returns the current model.
76    pub fn model(&self) -> LC::TrainingModel {
77        self.model.clone()
78    }
79
80    /// Returns the current learning rate.
81    pub fn lr_current(&self) -> f64 {
82        self.lr
83    }
84
85    /// Executes a step of the learning rate scheduler.
86    pub fn lr_step(&mut self) {
87        self.lr = self.lr_scheduler.step();
88    }
89
90    /// Runs a step of the model for training, which executes the forward and backward passes.
91    ///
92    /// # Arguments
93    ///
94    /// * `item` - The input for the model.
95    ///
96    /// # Returns
97    ///
98    /// The output containing the model output and the gradients.
99    pub fn train_step(&self, item: TrainingModelInput<LC>) -> TrainOutput<TrainingModelOutput<LC>> {
100        self.model.step(item)
101    }
102
103    /// Optimize the current module with the provided gradients and learning rate.
104    ///
105    /// # Arguments
106    ///
107    /// * `optim`: Optimizer used for learning.
108    /// * `lr`: The learning rate used for this step.
109    /// * `grads`: The gradients of each parameter in the current model.
110    pub fn optimizer_step(&mut self, grads: GradientsParams) {
111        self.model = self.model().optimize(&mut self.optim, self.lr, grads);
112    }
113
114    /// Optimize the current module with the provided gradients and learning rate.
115    ///
116    /// # Arguments
117    ///
118    /// * `optim`: Optimizer used for learning.
119    /// * `lr`: The learning rate used for this step.
120    /// * `grads`: Multiple gradients associated to each parameter in the current model.
121    pub fn optimizer_step_multi(&mut self, grads: MultiGradientsParams) {
122        self.model = self.model().optimize_multi(&mut self.optim, self.lr, grads);
123    }
124
125    /// Load the module state from a [record](LearnerModelRecord<LC>).
126    pub fn load_model(&mut self, record: LearnerModelRecord<LC>) {
127        self.model = self.model.clone().load_record(record);
128    }
129
130    /// Load the state of the learner's optimizer as a [record](LearnerOptimizerRecord<LC>).
131    pub fn load_optim(&mut self, record: LearnerOptimizerRecord<LC>) {
132        self.optim = self.optim.clone().load_record(record);
133    }
134
135    /// Load the state of the learner's scheduler as a [record](LearnerSchedulerRecord<LC>).
136    pub fn load_scheduler(&mut self, record: LearnerSchedulerRecord<LC>) {
137        self.lr_scheduler = self.lr_scheduler.clone().load_record(record);
138    }
139}
140
/// Used to create, delete, or load checkpoints of the training process.
#[derive(new)]
pub struct LearningCheckpointer<LC: LearningComponentsTypes> {
    // Asynchronous checkpointer persisting the model record.
    model: AsyncCheckpointer<LearnerModelRecord<LC>, LC::Backend>,
    // Asynchronous checkpointer persisting the optimizer record.
    optim: AsyncCheckpointer<LearnerOptimizerRecord<LC>, LC::Backend>,
    // Asynchronous checkpointer persisting the LR scheduler record.
    lr_scheduler: AsyncCheckpointer<LearnerSchedulerRecord<LC>, LC::Backend>,
    // Strategy that decides, per epoch, which checkpoints to save or delete.
    strategy: Box<dyn CheckpointingStrategy>,
}
149
150impl<LC: LearningComponentsTypes> LearningCheckpointer<LC> {
151    /// Create checkpoint for the training process.
152    pub fn checkpoint(&mut self, learner: &Learner<LC>, epoch: usize, store: &EventStoreClient) {
153        let actions = self.strategy.checkpointing(epoch, store);
154
155        for action in actions {
156            match action {
157                CheckpointingAction::Delete(epoch) => {
158                    self.model
159                        .delete(epoch)
160                        .expect("Can delete model checkpoint.");
161                    self.optim
162                        .delete(epoch)
163                        .expect("Can delete optimizer checkpoint.");
164                    self.lr_scheduler
165                        .delete(epoch)
166                        .expect("Can delete learning rate scheduler checkpoint.");
167                }
168                CheckpointingAction::Save => {
169                    self.model
170                        .save(epoch, learner.model.clone().into_record())
171                        .expect("Can save model checkpoint.");
172                    self.optim
173                        .save(epoch, learner.optim.to_record())
174                        .expect("Can save optimizer checkpoint.");
175                    self.lr_scheduler
176                        .save(epoch, learner.lr_scheduler.to_record())
177                        .expect("Can save learning rate scheduler checkpoint.");
178                }
179            }
180        }
181    }
182
183    /// Load a training checkpoint.
184    pub fn load_checkpoint(
185        &self,
186        mut learner: Learner<LC>,
187        device: &Device<LC::Backend>,
188        epoch: usize,
189    ) -> Learner<LC> {
190        let record = self
191            .model
192            .restore(epoch, device)
193            .expect("Can load model checkpoint.");
194        learner.load_model(record);
195
196        let record = self
197            .optim
198            .restore(epoch, device)
199            .expect("Can load optimizer checkpoint.");
200        learner.load_optim(record);
201
202        let record = self
203            .lr_scheduler
204            .restore(epoch, device)
205            .expect("Can load learning rate scheduler checkpoint.");
206        learner.load_scheduler(record);
207
208        learner
209    }
210}
211
/// Cloneable, boxed reference to an early stopping strategy.
pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;
214
/// A handle that allows aborting the training/evaluation process early.
///
/// Clones share the same underlying flag and message, so any clone can
/// request a stop that all holders observe.
#[derive(Clone, Default)]
pub struct Interrupter {
    // Set to `true` once a stop has been requested.
    state: Arc<AtomicBool>,
    // Optional human-readable reason recorded by the last `stop` call.
    message: Arc<Mutex<Option<String>>>,
}

impl Interrupter {
    /// Create a new, non-interrupted handle.
    pub fn new() -> Self {
        Self::default()
    }

    /// Notify the learner that it should stop.
    ///
    /// # Arguments
    ///
    /// * `reason` - A string describing the reason the training was stopped.
    pub fn stop(&self, reason: Option<&str>) {
        self.state.store(true, Ordering::Relaxed);
        if let Some(reason) = reason {
            *self.message.lock().unwrap() = Some(String::from(reason));
        }
    }

    /// Clear the stop flag.
    ///
    /// Note: any message recorded by a previous `stop` call is kept.
    pub fn reset(&self) {
        self.state.store(false, Ordering::Relaxed);
    }

    /// True if `.stop()` has been called since the last reset.
    pub fn should_stop(&self) -> bool {
        self.state.load(Ordering::Relaxed)
    }

    /// Get the message associated with the interrupt, if one was provided.
    pub fn get_message(&self) -> Option<String> {
        self.message.lock().unwrap().clone()
    }
}
254}