burn_train/learner/base.rs

1use crate::LearningComponentsMarker;
2use crate::checkpoint::{
3    AsyncCheckpointer, Checkpointer, CheckpointingAction, CheckpointingStrategy,
4};
5use crate::components::{LearningComponentsTypes, TrainingBackend};
6use crate::metric::store::EventStoreClient;
7use crate::{
8    CloneEarlyStoppingStrategy, InferenceStep, TrainOutput, TrainStep, TrainingModelInput,
9    TrainingModelOutput,
10};
11use burn_core::module::{AutodiffModule, Module};
12use burn_core::prelude::Backend;
13use burn_core::tensor::Device;
14use burn_core::tensor::backend::AutodiffBackend;
15use burn_optim::lr_scheduler::LrScheduler;
16use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::{Arc, Mutex};
19
/// The record (serializable state) of the learner's model.
pub type LearnerModelRecord<LC> =
    <<LC as LearningComponentsTypes>::TrainingModel as Module<TrainingBackend<LC>>>::Record;
/// The record (serializable state) of the optimizer.
pub type LearnerOptimizerRecord<LC> = <<LC as LearningComponentsTypes>::Optimizer as Optimizer<
    <LC as LearningComponentsTypes>::TrainingModel,
    TrainingBackend<LC>,
>>::Record;
/// The record (serializable state) of the LR scheduler.
pub type LearnerSchedulerRecord<LC> =
    <<LC as LearningComponentsTypes>::LrScheduler as LrScheduler>::Record<TrainingBackend<LC>>;
31
/// Learner struct encapsulating all components necessary to train a Neural Network model.
pub struct Learner<LC: LearningComponentsTypes> {
    // The model being trained.
    model: LC::TrainingModel,
    // Optimizer that applies gradients to the model parameters.
    optim: LC::Optimizer,
    // Scheduler that produces the learning rate for each step.
    lr_scheduler: LC::LrScheduler,
    // Learning rate used for the current step; updated by `lr_step`.
    lr: f64,
}
39
40impl<LC: LearningComponentsTypes> Clone for Learner<LC> {
41    fn clone(&self) -> Self {
42        Self {
43            model: self.model.clone(),
44            optim: self.optim.clone(),
45            lr_scheduler: self.lr_scheduler.clone(),
46            lr: self.lr,
47        }
48    }
49}
50
51impl<B, LR, M, O> Learner<LearningComponentsMarker<B, LR, M, O>>
52where
53    B: AutodiffBackend,
54    LR: LrScheduler + 'static,
55    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
56    M::InnerModule: InferenceStep,
57    O: Optimizer<M, B> + 'static,
58{
59    /// Create a learner.
60    pub fn new(model: M, optim: O, lr_scheduler: LR) -> Self {
61        Self {
62            model,
63            optim,
64            lr_scheduler,
65            lr: 0.0,
66        }
67    }
68}
69
impl<LC: LearningComponentsTypes> Learner<LC> {
    /// Fork the learner's model to the given device.
    ///
    /// # Arguments
    ///
    /// * `device` - The device the model should be moved to.
    pub fn fork(&mut self, device: &<TrainingBackend<LC> as Backend>::Device) {
        self.model = self.model().fork(device);
    }

    /// Returns a clone of the current model.
    pub fn model(&self) -> LC::TrainingModel {
        self.model.clone()
    }

    /// Returns the current learning rate, i.e. the value produced by the most
    /// recent call to [`lr_step`](Self::lr_step) (0.0 before the first step).
    pub fn lr_current(&self) -> f64 {
        self.lr
    }

    /// Executes a step of the learning rate scheduler, updating the current
    /// learning rate.
    pub fn lr_step(&mut self) {
        self.lr = self.lr_scheduler.step();
    }

    /// Runs a step of the model for training, which executes the forward and backward passes.
    ///
    /// # Arguments
    ///
    /// * `item` - The input for the model.
    ///
    /// # Returns
    ///
    /// The output containing the model output and the gradients.
    pub fn train_step(&self, item: TrainingModelInput<LC>) -> TrainOutput<TrainingModelOutput<LC>> {
        self.model.step(item)
    }

    /// Optimize the current model with the provided gradients, using the
    /// learner's optimizer and the current learning rate.
    ///
    /// # Arguments
    ///
    /// * `grads` - The gradients of each parameter in the current model.
    pub fn optimizer_step(&mut self, grads: GradientsParams) {
        self.model = self.model().optimize(&mut self.optim, self.lr, grads);
    }

    /// Optimize the current model with the provided gradients, using the
    /// learner's optimizer and the current learning rate.
    ///
    /// # Arguments
    ///
    /// * `grads` - Multiple gradients associated to each parameter in the current model.
    pub fn optimizer_step_multi(&mut self, grads: MultiGradientsParams) {
        self.model = self.model().optimize_multi(&mut self.optim, self.lr, grads);
    }

    /// Load the model state from a [`LearnerModelRecord`].
    pub fn load_model(&mut self, record: LearnerModelRecord<LC>) {
        self.model = self.model.clone().load_record(record);
    }

    /// Load the optimizer state from a [`LearnerOptimizerRecord`].
    pub fn load_optim(&mut self, record: LearnerOptimizerRecord<LC>) {
        self.optim = self.optim.clone().load_record(record);
    }

    /// Load the learning rate scheduler state from a [`LearnerSchedulerRecord`].
    pub fn load_scheduler(&mut self, record: LearnerSchedulerRecord<LC>) {
        self.lr_scheduler = self.lr_scheduler.clone().load_record(record);
    }
}
141
/// Used to create, delete, or load checkpoints of the training process.
#[derive(new)]
pub struct LearningCheckpointer<LC: LearningComponentsTypes> {
    // Checkpointer for the model record.
    model: AsyncCheckpointer<LearnerModelRecord<LC>, LC::Backend>,
    // Checkpointer for the optimizer record.
    optim: AsyncCheckpointer<LearnerOptimizerRecord<LC>, LC::Backend>,
    // Checkpointer for the learning rate scheduler record.
    lr_scheduler: AsyncCheckpointer<LearnerSchedulerRecord<LC>, LC::Backend>,
    // Strategy deciding, per epoch, which checkpoints to save or delete.
    strategy: Box<dyn CheckpointingStrategy>,
}
150
151impl<LC: LearningComponentsTypes> LearningCheckpointer<LC> {
152    /// Create checkpoint for the training process.
153    pub fn checkpoint(&mut self, learner: &Learner<LC>, epoch: usize, store: &EventStoreClient) {
154        let actions = self.strategy.checkpointing(epoch, store);
155
156        for action in actions {
157            match action {
158                CheckpointingAction::Delete(epoch) => {
159                    self.model
160                        .delete(epoch)
161                        .expect("Can delete model checkpoint.");
162                    self.optim
163                        .delete(epoch)
164                        .expect("Can delete optimizer checkpoint.");
165                    self.lr_scheduler
166                        .delete(epoch)
167                        .expect("Can delete learning rate scheduler checkpoint.");
168                }
169                CheckpointingAction::Save => {
170                    self.model
171                        .save(epoch, learner.model.clone().into_record())
172                        .expect("Can save model checkpoint.");
173                    self.optim
174                        .save(epoch, learner.optim.to_record())
175                        .expect("Can save optimizer checkpoint.");
176                    self.lr_scheduler
177                        .save(epoch, learner.lr_scheduler.to_record())
178                        .expect("Can save learning rate scheduler checkpoint.");
179                }
180            }
181        }
182    }
183
184    /// Load a training checkpoint.
185    pub fn load_checkpoint(
186        &self,
187        mut learner: Learner<LC>,
188        device: &Device<LC::Backend>,
189        epoch: usize,
190    ) -> Learner<LC> {
191        let record = self
192            .model
193            .restore(epoch, device)
194            .expect("Can load model checkpoint.");
195        learner.load_model(record);
196
197        let record = self
198            .optim
199            .restore(epoch, device)
200            .expect("Can load optimizer checkpoint.");
201        learner.load_optim(record);
202
203        let record = self
204            .lr_scheduler
205            .restore(epoch, device)
206            .expect("Can load learning rate scheduler checkpoint.");
207        learner.load_scheduler(record);
208
209        learner
210    }
211}
212
/// Cloneable, owned reference to an early stopping strategy.
pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;
215
/// A handle that allows aborting the training/evaluation process early.
///
/// Clones share the same underlying flag and message, so any clone can
/// request (or observe) a stop.
#[derive(Clone, Default)]
pub struct Interrupter {
    // True once `stop` has been called; cleared by `reset`.
    state: Arc<AtomicBool>,
    // Optional human-readable reason supplied to the latest `stop` call.
    message: Arc<Mutex<Option<String>>>,
}

impl Interrupter {
    /// Create a new, non-interrupted instance.
    pub fn new() -> Self {
        Self::default()
    }

    /// Notify the learner that it should stop.
    ///
    /// # Arguments
    ///
    /// * `reason` - A string describing the reason the training was stopped.
    ///   When `None`, any previously stored reason is left untouched.
    pub fn stop(&self, reason: Option<&str>) {
        self.state.store(true, Ordering::Relaxed);
        if let Some(reason) = reason {
            let mut message = self.message.lock().unwrap();
            *message = Some(reason.to_string());
        }
    }

    /// Reset the interrupter.
    ///
    /// Clears both the stop flag and the stored message, so a stale reason
    /// from a previous stop is not reported after the reset.
    pub fn reset(&self) {
        self.state.store(false, Ordering::Relaxed);
        *self.message.lock().unwrap() = None;
    }

    /// True if [`stop`](Self::stop) has been called since the last reset.
    pub fn should_stop(&self) -> bool {
        self.state.load(Ordering::Relaxed)
    }

    /// Get the message associated with the interrupt, if any.
    pub fn get_message(&self) -> Option<String> {
        self.message.lock().unwrap().clone()
    }
}
255}