burn_train/learner/base.rs

use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy};
use crate::components::LearnerComponentTypes;
use crate::metric::store::EventStoreClient;
use crate::{CloneEarlyStoppingStrategy, LearnerSummaryConfig, LearningStrategy};
use burn_core::module::Module;
use burn_core::tensor::Device;
use burn_optim::Optimizer;
use burn_optim::lr_scheduler::LrScheduler;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

/// Learner struct encapsulating all components necessary to train a neural network model.
///
/// To create a learner, use the [builder](crate::learner::LearnerBuilder) struct.
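///
/// # Example
///
/// A minimal sketch of building a learner with the builder (illustrative only;
/// the exact builder methods and their signatures depend on the Burn version in use):
///
/// ```ignore
/// let learner = LearnerBuilder::new("/tmp/artifacts")
///     .num_epochs(10)
///     .devices(vec![device])
///     .build(model, optimizer, lr_scheduler);
/// ```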
pub struct Learner<LC: LearnerComponentTypes> {
    pub(crate) model: LC::Model,
    pub(crate) optim: LC::Optimizer,
    pub(crate) lr_scheduler: LC::LrScheduler,
    pub(crate) num_epochs: usize,
    pub(crate) checkpoint: Option<usize>,
    pub(crate) grad_accumulation: Option<usize>,
    pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
    pub(crate) learning_strategy: LearningStrategy<LC>,
    pub(crate) interrupter: Interrupter,
    pub(crate) early_stopping: Option<EarlyStoppingStrategyRef>,
    pub(crate) event_processor: LC::EventProcessor,
    pub(crate) event_store: Arc<EventStoreClient>,
    pub(crate) summary: Option<LearnerSummaryConfig>,
}

/// Cloneable reference to an early stopping strategy.
pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;

#[derive(new)]
/// Used to create, delete, or load checkpoints of the training process.
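///
/// A rough usage sketch (illustrative; in practice the [`Learner`] drives these
/// calls internally during training):
///
/// ```ignore
/// // Save a checkpoint at the end of an epoch.
/// checkpointer.checkpoint(&model, &optim, &scheduler, epoch, &event_store);
///
/// // Later, resume from that epoch on a given device.
/// let (model, optim, scheduler) =
///     checkpointer.load_checkpoint(model, optim, scheduler, &device, epoch);
/// ```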
pub struct LearnerCheckpointer<LC: LearnerComponentTypes> {
    model: LC::CheckpointerModel,
    optim: LC::CheckpointerOptimizer,
    lr_scheduler: LC::CheckpointerLrScheduler,
    strategy: LC::CheckpointerStrategy,
}

impl<LC: LearnerComponentTypes> LearnerCheckpointer<LC> {
    /// Create a checkpoint for the training process.
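    ///
    /// The configured strategy decides, for the given epoch, which checkpoints to
    /// save or delete. A hypothetical strategy could look like the sketch below
    /// (the trait signature is inferred from the call site and may differ):
    ///
    /// ```ignore
    /// // Keep only the two most recent checkpoints.
    /// struct KeepLastTwo;
    ///
    /// impl CheckpointingStrategy for KeepLastTwo {
    ///     fn checkpointing(&mut self, epoch: usize, _store: &EventStoreClient) -> Vec<CheckpointingAction> {
    ///         let mut actions = vec![CheckpointingAction::Save];
    ///         if epoch > 2 {
    ///             actions.push(CheckpointingAction::Delete(epoch - 2));
    ///         }
    ///         actions
    ///     }
    /// }
    /// ```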
    pub fn checkpoint(
        &mut self,
        model: &LC::Model,
        optim: &LC::Optimizer,
        scheduler: &LC::LrScheduler,
        epoch: usize,
        store: &EventStoreClient,
    ) {
        let actions = self.strategy.checkpointing(epoch, store);

        for action in actions {
            match action {
                CheckpointingAction::Delete(epoch) => {
                    self.model
                        .delete(epoch)
                        .expect("Can delete model checkpoint.");
                    self.optim
                        .delete(epoch)
                        .expect("Can delete optimizer checkpoint.");
                    self.lr_scheduler
                        .delete(epoch)
                        .expect("Can delete learning rate scheduler checkpoint.");
                }
                CheckpointingAction::Save => {
                    self.model
                        .save(epoch, model.clone().into_record())
                        .expect("Can save model checkpoint.");
                    self.optim
                        .save(epoch, optim.to_record())
                        .expect("Can save optimizer checkpoint.");
                    self.lr_scheduler
                        .save(epoch, scheduler.to_record())
                        .expect("Can save learning rate scheduler checkpoint.");
                }
            }
        }
    }

    /// Load a training checkpoint.
    pub fn load_checkpoint(
        &self,
        model: LC::Model,
        optim: LC::Optimizer,
        scheduler: LC::LrScheduler,
        device: &Device<LC::Backend>,
        epoch: usize,
    ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) {
        let record = self
            .model
            .restore(epoch, device)
            .expect("Can load model checkpoint.");
        let model = model.load_record(record);

        let record = self
            .optim
            .restore(epoch, device)
            .expect("Can load optimizer checkpoint.");
        let optim = optim.load_record(record);

        let record = self
            .lr_scheduler
            .restore(epoch, device)
            .expect("Can load learning rate scheduler checkpoint.");
        let scheduler = scheduler.load_record(record);

        (model, optim, scheduler)
    }
}

#[derive(Clone, Default)]
/// A handle that allows aborting the training/evaluation process early.
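///
/// A minimal sketch (illustrative): clone the handle, give the clone to another
/// thread or signal handler, and let the training loop poll it.
///
/// ```ignore
/// let interrupter = Interrupter::new();
/// let handle = interrupter.clone();
///
/// std::thread::spawn(move || {
///     // Ask training to stop after roughly one minute.
///     std::thread::sleep(std::time::Duration::from_secs(60));
///     handle.stop();
/// });
///
/// while !interrupter.should_stop() {
///     // ... run one training step ...
/// }
/// ```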
pub struct Interrupter {
    state: Arc<AtomicBool>,
}

impl Interrupter {
    /// Create a new instance.
    pub fn new() -> Self {
        Self::default()
    }

    /// Notify the learner that it should stop.
    pub fn stop(&self) {
        self.state.store(true, Ordering::Relaxed);
    }

    /// Reset the interrupter.
    pub fn reset(&self) {
        self.state.store(false, Ordering::Relaxed);
    }

    /// Returns `true` if [`Self::stop`] has been called.
    pub fn should_stop(&self) -> bool {
        self.state.load(Ordering::Relaxed)
    }
}