// burn_train/learner/base.rs

use crate::LearningComponentsMarker;
use crate::checkpoint::{
    AsyncCheckpointer, Checkpointer, CheckpointingAction, CheckpointingStrategy,
};
use crate::components::{LearningComponentsTypes, TrainingBackend};
use crate::metric::store::EventStoreClient;
use crate::{
    CloneEarlyStoppingStrategy, InferenceStep, TrainOutput, TrainStep, TrainingModelInput,
    TrainingModelOutput,
};
use burn_core::module::{AutodiffModule, Module};
use burn_core::tensor::Device;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::lr_scheduler::LrScheduler;
use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};

/// Record type of the model being trained.
pub type LearnerModelRecord<LC> =
    <<LC as LearningComponentsTypes>::TrainingModel as Module<TrainingBackend<LC>>>::Record;
/// Record type of the optimizer.
pub type LearnerOptimizerRecord<LC> = <<LC as LearningComponentsTypes>::Optimizer as Optimizer<
    <LC as LearningComponentsTypes>::TrainingModel,
    TrainingBackend<LC>,
>>::Record;
/// Record type of the learning rate scheduler.
pub type LearnerSchedulerRecord<LC> =
    <<LC as LearningComponentsTypes>::LrScheduler as LrScheduler>::Record<TrainingBackend<LC>>;

/// Learner struct encapsulating the model, optimizer, and learning rate
/// scheduler used during training.
pub struct Learner<LC: LearningComponentsTypes> {
    pub(crate) model: LC::TrainingModel,
    optim: LC::Optimizer,
    lr_scheduler: LC::LrScheduler,
    lr: f64,
}

impl<LC: LearningComponentsTypes> Clone for Learner<LC> {
    fn clone(&self) -> Self {
        Self {
            model: self.model.clone(),
            optim: self.optim.clone(),
            lr_scheduler: self.lr_scheduler.clone(),
            lr: self.lr,
        }
    }
}

impl<B, LR, M, O> Learner<LearningComponentsMarker<B, LR, M, O>>
where
    B: AutodiffBackend,
    LR: LrScheduler + 'static,
    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
    M::InnerModule: InferenceStep,
    O: Optimizer<M, B> + 'static,
{
    /// Creates a new learner from a model, an optimizer, and a learning rate
    /// scheduler. The current learning rate starts at `0.0` and only becomes
    /// meaningful after the first call to [`Learner::lr_step`].
    pub fn new(model: M, optim: O, lr_scheduler: LR) -> Self {
        Self {
            model,
            optim,
            lr_scheduler,
            lr: 0.0,
        }
    }
}

impl<LC: LearningComponentsTypes> Learner<LC> {
    /// Moves the model to the given device.
    pub fn fork(&mut self, device: &Device<TrainingBackend<LC>>) {
        self.model = self.model().fork(device);
    }

    /// Returns a clone of the current model.
    pub fn model(&self) -> LC::TrainingModel {
        self.model.clone()
    }

    /// Returns the current learning rate.
    pub fn lr_current(&self) -> f64 {
        self.lr
    }

    /// Advances the learning rate scheduler by one step and stores the new
    /// learning rate.
    pub fn lr_step(&mut self) {
        self.lr = self.lr_scheduler.step();
    }

    /// Executes a single training step on the given item, returning the model
    /// output along with the computed gradients.
    pub fn train_step(&self, item: TrainingModelInput<LC>) -> TrainOutput<TrainingModelOutput<LC>> {
        self.model.step(item)
    }

    /// Applies the given gradients to the model using the optimizer and the
    /// current learning rate.
    pub fn optimizer_step(&mut self, grads: GradientsParams) {
        self.model = self.model().optimize(&mut self.optim, self.lr, grads);
    }

    /// Applies gradients from multiple sources to the model using the
    /// optimizer and the current learning rate.
    pub fn optimizer_step_multi(&mut self, grads: MultiGradientsParams) {
        self.model = self.model().optimize_multi(&mut self.optim, self.lr, grads);
    }

    /// Loads a model record into the learner.
    pub fn load_model(&mut self, record: LearnerModelRecord<LC>) {
        self.model = self.model.clone().load_record(record);
    }

    /// Loads an optimizer record into the learner.
    pub fn load_optim(&mut self, record: LearnerOptimizerRecord<LC>) {
        self.optim = self.optim.clone().load_record(record);
    }

    /// Loads a learning rate scheduler record into the learner.
    pub fn load_scheduler(&mut self, record: LearnerSchedulerRecord<LC>) {
        self.lr_scheduler = self.lr_scheduler.clone().load_record(record);
    }
}
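
// Illustrative sketch (not part of the original file): how the methods above
// compose into one training iteration. It assumes `TrainOutput` exposes the
// computed gradients as a public `grads` field and that the caller supplies
// the batch; the function name is hypothetical.
#[allow(dead_code)]
fn train_iteration_sketch<LC: LearningComponentsTypes>(
    learner: &mut Learner<LC>,
    batch: TrainingModelInput<LC>,
) {
    // Advance the scheduler so the optimizer step uses the up-to-date rate.
    learner.lr_step();
    // Forward and backward pass, producing the output item and gradients.
    let output = learner.train_step(batch);
    // Apply the gradients at the current learning rate.
    learner.optimizer_step(output.grads);
}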

/// Groups the checkpointers for every stateful training component with the
/// strategy that decides when checkpoints are saved or deleted.
#[derive(new)]
pub struct LearningCheckpointer<LC: LearningComponentsTypes> {
    model: AsyncCheckpointer<LearnerModelRecord<LC>, LC::Backend>,
    optim: AsyncCheckpointer<LearnerOptimizerRecord<LC>, LC::Backend>,
    lr_scheduler: AsyncCheckpointer<LearnerSchedulerRecord<LC>, LC::Backend>,
    strategy: Box<dyn CheckpointingStrategy>,
}
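
// Illustrative sketch (not part of the original file): `#[derive(new)]`
// generates a constructor with one argument per field, in declaration order.
// The function name is hypothetical; the checkpointers and strategy are
// assumed to be built elsewhere.
#[allow(dead_code)]
fn build_checkpointer_sketch<LC: LearningComponentsTypes>(
    model: AsyncCheckpointer<LearnerModelRecord<LC>, LC::Backend>,
    optim: AsyncCheckpointer<LearnerOptimizerRecord<LC>, LC::Backend>,
    lr_scheduler: AsyncCheckpointer<LearnerSchedulerRecord<LC>, LC::Backend>,
    strategy: Box<dyn CheckpointingStrategy>,
) -> LearningCheckpointer<LC> {
    LearningCheckpointer::new(model, optim, lr_scheduler, strategy)
}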

impl<LC: LearningComponentsTypes> LearningCheckpointer<LC> {
    /// Runs the checkpointing strategy for the given epoch, saving or deleting
    /// the model, optimizer, and learning rate scheduler checkpoints accordingly.
    pub fn checkpoint(&mut self, learner: &Learner<LC>, epoch: usize, store: &EventStoreClient) {
        let actions = self.strategy.checkpointing(epoch, store);

        for action in actions {
            match action {
                CheckpointingAction::Delete(epoch) => {
                    self.model
                        .delete(epoch)
                        .expect("Can delete model checkpoint.");
                    self.optim
                        .delete(epoch)
                        .expect("Can delete optimizer checkpoint.");
                    self.lr_scheduler
                        .delete(epoch)
                        .expect("Can delete learning rate scheduler checkpoint.");
                }
                CheckpointingAction::Save => {
                    self.model
                        .save(epoch, learner.model.clone().into_record())
                        .expect("Can save model checkpoint.");
                    self.optim
                        .save(epoch, learner.optim.to_record())
                        .expect("Can save optimizer checkpoint.");
                    self.lr_scheduler
                        .save(epoch, learner.lr_scheduler.to_record())
                        .expect("Can save learning rate scheduler checkpoint.");
                }
            }
        }
    }

    /// Restores the learner's model, optimizer, and learning rate scheduler
    /// from the checkpoints saved at the given epoch.
    pub fn load_checkpoint(
        &self,
        mut learner: Learner<LC>,
        device: &Device<LC::Backend>,
        epoch: usize,
    ) -> Learner<LC> {
        let record = self
            .model
            .restore(epoch, device)
            .expect("Can load model checkpoint.");
        learner.load_model(record);

        let record = self
            .optim
            .restore(epoch, device)
            .expect("Can load optimizer checkpoint.");
        learner.load_optim(record);

        let record = self
            .lr_scheduler
            .restore(epoch, device)
            .expect("Can load learning rate scheduler checkpoint.");
        learner.load_scheduler(record);

        learner
    }
}
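
// Illustrative sketch (not part of the original file): restoring training
// state before resuming, then checkpointing again once an epoch completes.
// The function name and `resume_from` parameter are hypothetical; the
// checkpointer, learner, and event store are assumed to be built elsewhere.
#[allow(dead_code)]
fn resume_and_checkpoint_sketch<LC: LearningComponentsTypes>(
    mut checkpointer: LearningCheckpointer<LC>,
    learner: Learner<LC>,
    device: &Device<LC::Backend>,
    store: &EventStoreClient,
    resume_from: usize,
) {
    // Restore the model, optimizer, and scheduler records saved at `resume_from`.
    let learner = checkpointer.load_checkpoint(learner, device, resume_from);

    // ... train one epoch, then let the strategy decide what to keep.
    checkpointer.checkpoint(&learner, resume_from + 1, store);
}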

/// A boxed early stopping strategy.
pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;

/// A cloneable, thread-safe handle used to signal that training should stop,
/// optionally carrying a human-readable reason.
#[derive(Clone, Default)]
pub struct Interrupter {
    state: Arc<AtomicBool>,
    message: Arc<Mutex<Option<String>>>,
}

impl Interrupter {
    /// Creates a new interrupter in the non-stopped state.
    pub fn new() -> Self {
        Self::default()
    }

    /// Signals that training should stop, optionally recording a reason that
    /// can later be retrieved with [`Interrupter::get_message`].
    pub fn stop(&self, reason: Option<&str>) {
        self.state.store(true, Ordering::Relaxed);
        reason.inspect(|r| {
            let mut message = self.message.lock().unwrap();
            *message = Some(String::from(*r));
        });
    }

    /// Clears the stop signal. The previously recorded message, if any, is kept.
    pub fn reset(&self) {
        self.state.store(false, Ordering::Relaxed);
    }

    /// Returns `true` if a stop has been signaled.
    pub fn should_stop(&self) -> bool {
        self.state.load(Ordering::Relaxed)
    }

    /// Returns the reason recorded by the last call to [`Interrupter::stop`], if any.
    pub fn get_message(&self) -> Option<String> {
        let message = self.message.lock().unwrap();
        message.clone()
    }
}
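
// A small self-contained example (added for illustration, not part of the
// original file) exercising the `Interrupter` API defined above.
#[cfg(test)]
mod interrupter_example {
    use super::*;

    #[test]
    fn stop_sets_flag_and_records_message() {
        let interrupter = Interrupter::new();
        assert!(!interrupter.should_stop());

        // Clones share the same state, so a worker thread holding a clone
        // can observe the stop signal raised here.
        let handle = interrupter.clone();
        handle.stop(Some("early stopping triggered"));

        assert!(interrupter.should_stop());
        assert_eq!(
            interrupter.get_message(),
            Some("early stopping triggered".to_string())
        );

        // `reset` clears the flag but keeps the recorded message.
        interrupter.reset();
        assert!(!interrupter.should_stop());
    }
}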