burn_train/learner/base.rs

1use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy};
2use crate::components::LearnerComponents;
3use crate::learner::EarlyStoppingStrategy;
4use crate::metric::store::EventStoreClient;
5use crate::LearnerSummaryConfig;
6use burn_core::lr_scheduler::LrScheduler;
7use burn_core::module::Module;
8use burn_core::optim::Optimizer;
9use burn_core::tensor::backend::Backend;
10use burn_core::tensor::Device;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
/// Learner struct encapsulating all components necessary to train a Neural Network model.
///
/// To create a learner, use the [builder](crate::learner::LearnerBuilder) struct.
pub struct Learner<LC: LearnerComponents> {
    /// The model being trained.
    pub(crate) model: LC::Model,
    /// Optimizer updating the model's parameters.
    pub(crate) optim: LC::Optimizer,
    /// Scheduler controlling the learning rate over training.
    pub(crate) lr_scheduler: LC::LrScheduler,
    /// Total number of epochs to train for.
    pub(crate) num_epochs: usize,
    /// Epoch to restore training state from, if resuming.
    /// NOTE(review): presumably an epoch index passed to `LearnerCheckpointer::load_checkpoint` — confirm against the training loop.
    pub(crate) checkpoint: Option<usize>,
    /// Gradient accumulation setting; `None` disables accumulation.
    /// NOTE(review): assumed to be the number of accumulation steps — confirm at usage site.
    pub(crate) grad_accumulation: Option<usize>,
    /// Handles saving/deleting/restoring checkpoints; `None` disables checkpointing.
    pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
    /// Devices used for training.
    pub(crate) devices: Vec<<LC::Backend as Backend>::Device>,
    /// Shared flag allowing external code to abort training early.
    pub(crate) interrupter: TrainingInterrupter,
    /// Optional strategy that may end training before `num_epochs` completes.
    pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
    /// Processes training/validation events (e.g. for metrics rendering).
    pub(crate) event_processor: LC::EventProcessor,
    /// Shared store of recorded events, also consulted by the checkpointing strategy.
    pub(crate) event_store: Arc<EventStoreClient>,
    /// Configuration for the end-of-training summary; `None` disables it.
    pub(crate) summary: Option<LearnerSummaryConfig>,
}
32
/// Groups the per-component checkpointers together with the strategy that
/// decides when checkpoints are saved or deleted.
///
/// `#[derive(new)]` generates a `new` constructor taking all fields in order.
#[derive(new)]
pub(crate) struct LearnerCheckpointer<LC: LearnerComponents> {
    /// Checkpointer for the model's record.
    model: LC::CheckpointerModel,
    /// Checkpointer for the optimizer's record.
    optim: LC::CheckpointerOptimizer,
    /// Checkpointer for the learning rate scheduler's record.
    lr_scheduler: LC::CheckpointerLrScheduler,
    /// Strategy that yields the checkpointing actions to perform at each epoch.
    strategy: LC::CheckpointerStrategy,
}
40
41impl<LC: LearnerComponents> LearnerCheckpointer<LC> {
42    pub(crate) fn checkpoint(
43        &mut self,
44        model: &LC::Model,
45        optim: &LC::Optimizer,
46        scheduler: &LC::LrScheduler,
47        epoch: usize,
48        store: &EventStoreClient,
49    ) {
50        let actions = self.strategy.checkpointing(epoch, store);
51
52        for action in actions {
53            match action {
54                CheckpointingAction::Delete(epoch) => {
55                    self.model
56                        .delete(epoch)
57                        .expect("Can delete model checkpoint.");
58                    self.optim
59                        .delete(epoch)
60                        .expect("Can delete optimizer checkpoint.");
61                    self.lr_scheduler
62                        .delete(epoch)
63                        .expect("Can delete learning rate scheduler checkpoint.");
64                }
65                CheckpointingAction::Save => {
66                    self.model
67                        .save(epoch, model.clone().into_record())
68                        .expect("Can save model checkpoint.");
69                    self.optim
70                        .save(epoch, optim.to_record())
71                        .expect("Can save optimizer checkpoint.");
72                    self.lr_scheduler
73                        .save(epoch, scheduler.to_record())
74                        .expect("Can save learning rate scheduler checkpoint.");
75                }
76            }
77        }
78    }
79
80    pub(crate) fn load_checkpoint(
81        &self,
82        model: LC::Model,
83        optim: LC::Optimizer,
84        scheduler: LC::LrScheduler,
85        device: &Device<LC::Backend>,
86        epoch: usize,
87    ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) {
88        let record = self
89            .model
90            .restore(epoch, device)
91            .expect("Can load model checkpoint.");
92        let model = model.load_record(record);
93
94        let record = self
95            .optim
96            .restore(epoch, device)
97            .expect("Can load optimizer checkpoint.");
98        let optim = optim.load_record(record);
99
100        let record = self
101            .lr_scheduler
102            .restore(epoch, device)
103            .expect("Can load learning rate scheduler checkpoint.");
104        let scheduler = scheduler.load_record(record);
105
106        (model, optim, scheduler)
107    }
108}
109
#[derive(Clone, Default)]
/// A handle that allows aborting the training process early.
pub struct TrainingInterrupter {
    /// Shared stop flag: clones of this handle share the same `Arc`, so a
    /// `stop()` on any clone is visible to every `should_stop()` check.
    state: Arc<AtomicBool>,
}
115
116impl TrainingInterrupter {
117    /// Create a new instance.
118    pub fn new() -> Self {
119        Self::default()
120    }
121
122    /// Notify the learner that it should stop.
123    pub fn stop(&self) {
124        self.state.store(true, Ordering::Relaxed);
125    }
126
127    /// True if .stop() has been called.
128    pub fn should_stop(&self) -> bool {
129        self.state.load(Ordering::Relaxed)
130    }
131}