burn_train/learner/base.rs

use crate::LearningComponentsMarker;
use crate::checkpoint::{
    AsyncCheckpointer, Checkpointer, CheckpointingAction, CheckpointingStrategy,
};
use crate::components::{LearningComponentsTypes, TrainingBackend};
use crate::metric::store::EventStoreClient;
use crate::{
    CloneEarlyStoppingStrategy, InferenceStep, TrainOutput, TrainStep, TrainingModelInput,
    TrainingModelOutput,
};
use burn_core::module::{AutodiffModule, Module};
use burn_core::prelude::Backend;
use burn_core::tensor::Device;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::lr_scheduler::LrScheduler;
use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};

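/// Record of the learner's model, as stored in checkpoints.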
pub type LearnerModelRecord<LC> =
    <<LC as LearningComponentsTypes>::TrainingModel as Module<TrainingBackend<LC>>>::Record;
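/// Record of the learner's optimizer state.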
pub type LearnerOptimizerRecord<LC> = <<LC as LearningComponentsTypes>::Optimizer as Optimizer<
    <LC as LearningComponentsTypes>::TrainingModel,
    TrainingBackend<LC>,
>>::Record;
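/// Record of the learner's learning-rate scheduler state.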
pub type LearnerSchedulerRecord<LC> =
    <<LC as LearningComponentsTypes>::LrScheduler as LrScheduler>::Record<TrainingBackend<LC>>;

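/// Mutable training state: the model being trained, its optimizer, the
/// learning-rate scheduler, and the most recently scheduled learning rate.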
pub struct Learner<LC: LearningComponentsTypes> {
    model: LC::TrainingModel,
    optim: LC::Optimizer,
    lr_scheduler: LC::LrScheduler,
    lr: f64,
}

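// `Clone` is implemented by hand so that `LC` itself does not have to be
// `Clone`; only the component values held by the learner are cloned.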
impl<LC: LearningComponentsTypes> Clone for Learner<LC> {
    fn clone(&self) -> Self {
        Self {
            model: self.model.clone(),
            optim: self.optim.clone(),
            lr_scheduler: self.lr_scheduler.clone(),
            lr: self.lr,
        }
    }
}

impl<B, LR, M, O> Learner<LearningComponentsMarker<B, LR, M, O>>
where
    B: AutodiffBackend,
    LR: LrScheduler + 'static,
    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
    M::InnerModule: InferenceStep,
    O: Optimizer<M, B> + 'static,
{
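    /// Creates a learner from a model, an optimizer, and a learning-rate
    /// scheduler.
    ///
    /// The stored learning rate starts at `0.0`; call [`Learner::lr_step`]
    /// to fetch the first scheduled value before stepping the optimizer.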
    pub fn new(model: M, optim: O, lr_scheduler: LR) -> Self {
        Self {
            model,
            optim,
            lr_scheduler,
            lr: 0.0,
        }
    }
}

impl<LC: LearningComponentsTypes> Learner<LC> {
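    /// Moves the model to the given device by forking it.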
    pub fn fork(&mut self, device: &<TrainingBackend<LC> as Backend>::Device) {
        self.model = self.model().fork(device);
    }

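    /// Returns a clone of the current model.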
    pub fn model(&self) -> LC::TrainingModel {
        self.model.clone()
    }

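    /// Returns the learning rate produced by the most recent call to
    /// [`Learner::lr_step`] (`0.0` before the first step).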
    pub fn lr_current(&self) -> f64 {
        self.lr
    }

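    /// Advances the learning-rate scheduler and stores the value it yields.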
    pub fn lr_step(&mut self) {
        self.lr = self.lr_scheduler.step();
    }

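    /// Runs a single training step of the model on the given input.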
    pub fn train_step(&self, item: TrainingModelInput<LC>) -> TrainOutput<TrainingModelOutput<LC>> {
        self.model.step(item)
    }

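    /// Updates the model by applying the gradients with the optimizer at the
    /// current learning rate.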
    pub fn optimizer_step(&mut self, grads: GradientsParams) {
        self.model = self.model().optimize(&mut self.optim, self.lr, grads);
    }

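    /// Updates the model by applying multiple sets of gradients at the
    /// current learning rate.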
    pub fn optimizer_step_multi(&mut self, grads: MultiGradientsParams) {
        self.model = self.model().optimize_multi(&mut self.optim, self.lr, grads);
    }

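    /// Restores the model from a checkpoint record.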
    pub fn load_model(&mut self, record: LearnerModelRecord<LC>) {
        self.model = self.model.clone().load_record(record);
    }

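    /// Restores the optimizer state from a checkpoint record.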
    pub fn load_optim(&mut self, record: LearnerOptimizerRecord<LC>) {
        self.optim = self.optim.clone().load_record(record);
    }

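    /// Restores the learning-rate scheduler from a checkpoint record.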
    pub fn load_scheduler(&mut self, record: LearnerSchedulerRecord<LC>) {
        self.lr_scheduler = self.lr_scheduler.clone().load_record(record);
    }
}
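
// A minimal sketch of how these methods compose into a training loop
// (hypothetical `batches` iterator yielding `TrainingModelInput<LC>` items,
// and assuming `TrainOutput` exposes its gradients as `grads`; the real loop
// also handles metrics, checkpointing, and interruption):
//
//     for item in batches {
//         learner.lr_step();
//         let output = learner.train_step(item);
//         learner.optimizer_step(output.grads);
//     }

/// Saves and restores the three pieces of learner state through asynchronous
/// checkpointers, with a pluggable strategy deciding when to save and which
/// old epochs to delete.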
#[derive(new)]
pub struct LearningCheckpointer<LC: LearningComponentsTypes> {
    model: AsyncCheckpointer<LearnerModelRecord<LC>, LC::Backend>,
    optim: AsyncCheckpointer<LearnerOptimizerRecord<LC>, LC::Backend>,
    lr_scheduler: AsyncCheckpointer<LearnerSchedulerRecord<LC>, LC::Backend>,
    strategy: Box<dyn CheckpointingStrategy>,
}

impl<LC: LearningComponentsTypes> LearningCheckpointer<LC> {
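    /// Asks the strategy which checkpointing actions apply at `epoch` and
    /// executes them, saving or deleting the model, optimizer, and scheduler
    /// records together.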
    pub fn checkpoint(&mut self, learner: &Learner<LC>, epoch: usize, store: &EventStoreClient) {
        let actions = self.strategy.checkpointing(epoch, store);

        for action in actions {
            match action {
                CheckpointingAction::Delete(epoch) => {
                    self.model
                        .delete(epoch)
                        .expect("Can delete model checkpoint.");
                    self.optim
                        .delete(epoch)
                        .expect("Can delete optimizer checkpoint.");
                    self.lr_scheduler
                        .delete(epoch)
                        .expect("Can delete learning rate scheduler checkpoint.");
                }
                CheckpointingAction::Save => {
                    self.model
                        .save(epoch, learner.model.clone().into_record())
                        .expect("Can save model checkpoint.");
                    self.optim
                        .save(epoch, learner.optim.to_record())
                        .expect("Can save optimizer checkpoint.");
                    self.lr_scheduler
                        .save(epoch, learner.lr_scheduler.to_record())
                        .expect("Can save learning rate scheduler checkpoint.");
                }
            }
        }
    }

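    /// Restores the learner's model, optimizer, and learning-rate scheduler
    /// from the records saved at `epoch`, panicking if any record cannot be
    /// loaded.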
    pub fn load_checkpoint(
        &self,
        mut learner: Learner<LC>,
        device: &Device<LC::Backend>,
        epoch: usize,
    ) -> Learner<LC> {
        let record = self
            .model
            .restore(epoch, device)
            .expect("Can load model checkpoint.");
        learner.load_model(record);

        let record = self
            .optim
            .restore(epoch, device)
            .expect("Can load optimizer checkpoint.");
        learner.load_optim(record);

        let record = self
            .lr_scheduler
            .restore(epoch, device)
            .expect("Can load learning rate scheduler checkpoint.");
        learner.load_scheduler(record);

        learner
    }
}
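
/// Boxed, clonable early-stopping strategy used by the training loop.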
pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;
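
/// Thread-safe flag for requesting that training stop early, with an optional
/// human-readable reason. Clones share the same underlying state, so a handle
/// can be moved to another thread and used to stop training from there.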
#[derive(Clone, Default)]
pub struct Interrupter {
    state: Arc<AtomicBool>,
    message: Arc<Mutex<Option<String>>>,
}

impl Interrupter {
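    /// Creates a new interrupter with the stop flag cleared.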
    pub fn new() -> Self {
        Self::default()
    }

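    /// Sets the stop flag and, if a reason is given, records it as the
    /// interruption message.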
    pub fn stop(&self, reason: Option<&str>) {
        self.state.store(true, Ordering::Relaxed);
        reason.inspect(|r| {
            let mut message = self.message.lock().unwrap();
            *message = Some(String::from(*r));
        });
    }

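    /// Clears the stop flag; any previously recorded message is kept.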
    pub fn reset(&self) {
        self.state.store(false, Ordering::Relaxed);
    }

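    /// Returns `true` if a stop has been requested.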
    pub fn should_stop(&self) -> bool {
        self.state.load(Ordering::Relaxed)
    }

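    /// Returns the reason passed to the most recent call to
    /// [`Interrupter::stop`], if any.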
    pub fn get_message(&self) -> Option<String> {
        let message = self.message.lock().unwrap();
        message.clone()
    }
}
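
// Sketch of typical `Interrupter` use (hypothetical setup; every clone shares
// the same flag, so one handle can stop a loop driven through another):
//
//     let interrupter = Interrupter::new();
//     let handle = interrupter.clone();
//     std::thread::spawn(move || handle.stop(Some("validation loss diverged")));
//     while !interrupter.should_stop() {
//         // ... run training steps ...
//     }
//     if let Some(reason) = interrupter.get_message() {
//         println!("training interrupted: {reason}");
//     }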