burn_train/learner/base.rs

use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy};
use crate::components::LearnerComponentTypes;
use crate::metric::store::EventStoreClient;
use crate::{CloneEarlyStoppingStrategy, LearnerSummaryConfig, LearningStrategy};
use burn_core::module::Module;
use burn_core::tensor::Device;
use burn_optim::Optimizer;
use burn_optim::lr_scheduler::LrScheduler;
use derive_new::new;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

/// Learner struct encapsulating all components necessary to train a neural network model.
///
/// To create a learner, use the [builder](crate::learner::LearnerBuilder) struct.
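///
/// A minimal, illustrative sketch of the builder flow (not compiled as a doc-test;
/// the exact builder methods may differ between versions, and `artifact_dir`, `model`,
/// `optim`, `lr_scheduler` and the dataloaders are assumed to be defined elsewhere):
///
/// ```rust,ignore
/// let learner = LearnerBuilder::new(artifact_dir)
///     .num_epochs(10)
///     .build(model, optim, lr_scheduler);
///
/// // Training consumes the learner and returns the trained model.
/// let trained_model = learner.fit(dataloader_train, dataloader_valid);
/// ```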
pub struct Learner<LC: LearnerComponentTypes> {
    pub(crate) model: LC::Model,
    pub(crate) optim: LC::Optimizer,
    pub(crate) lr_scheduler: LC::LrScheduler,
    pub(crate) num_epochs: usize,
    pub(crate) checkpoint: Option<usize>,
    pub(crate) grad_accumulation: Option<usize>,
    pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
    pub(crate) learning_strategy: LearningStrategy<LC::Backend>,
    pub(crate) interrupter: Interrupter,
    pub(crate) early_stopping: Option<EarlyStoppingStrategyRef>,
    pub(crate) event_processor: LC::EventProcessor,
    pub(crate) event_store: Arc<EventStoreClient>,
    pub(crate) summary: Option<LearnerSummaryConfig>,
}

/// Cloneable reference to an early stopping strategy.
pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;

/// Groups the model, optimizer and learning rate scheduler checkpointers together
/// with the strategy that decides when checkpoints are saved or deleted.
#[derive(new)]
pub(crate) struct LearnerCheckpointer<LC: LearnerComponentTypes> {
    model: LC::CheckpointerModel,
    optim: LC::CheckpointerOptimizer,
    lr_scheduler: LC::CheckpointerLrScheduler,
    strategy: LC::CheckpointerStrategy,
}

impl<LC: LearnerComponentTypes> LearnerCheckpointer<LC> {
    /// Applies the checkpointing strategy for `epoch`, saving new checkpoints and
    /// deleting the ones the strategy marks as obsolete.
    pub(crate) fn checkpoint(
        &mut self,
        model: &LC::Model,
        optim: &LC::Optimizer,
        scheduler: &LC::LrScheduler,
        epoch: usize,
        store: &EventStoreClient,
    ) {
        let actions = self.strategy.checkpointing(epoch, store);

        for action in actions {
            match action {
                CheckpointingAction::Delete(epoch) => {
                    self.model
                        .delete(epoch)
                        .expect("Can delete model checkpoint.");
                    self.optim
                        .delete(epoch)
                        .expect("Can delete optimizer checkpoint.");
                    self.lr_scheduler
                        .delete(epoch)
                        .expect("Can delete learning rate scheduler checkpoint.");
                }
                CheckpointingAction::Save => {
                    self.model
                        .save(epoch, model.clone().into_record())
                        .expect("Can save model checkpoint.");
                    self.optim
                        .save(epoch, optim.to_record())
                        .expect("Can save optimizer checkpoint.");
                    self.lr_scheduler
                        .save(epoch, scheduler.to_record())
                        .expect("Can save learning rate scheduler checkpoint.");
                }
            }
        }
    }

    /// Restores the model, optimizer and learning rate scheduler from the records
    /// saved at `epoch`, returning the updated components.
    pub(crate) fn load_checkpoint(
        &self,
        model: LC::Model,
        optim: LC::Optimizer,
        scheduler: LC::LrScheduler,
        device: &Device<LC::Backend>,
        epoch: usize,
    ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) {
        let record = self
            .model
            .restore(epoch, device)
            .expect("Can load model checkpoint.");
        let model = model.load_record(record);

        let record = self
            .optim
            .restore(epoch, device)
            .expect("Can load optimizer checkpoint.");
        let optim = optim.load_record(record);

        let record = self
            .lr_scheduler
            .restore(epoch, device)
            .expect("Can load learning rate scheduler checkpoint.");
        let scheduler = scheduler.load_record(record);

        (model, optim, scheduler)
    }
}

#[derive(Clone, Default)]
/// A handle that allows aborting the training/evaluation process early.
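///
/// A minimal usage sketch (illustrative, not compiled as a doc-test): the handle is
/// cloned and shared with whatever requests the stop, for example another thread.
///
/// ```rust,ignore
/// let interrupter = Interrupter::new();
/// let handle = interrupter.clone();
///
/// std::thread::spawn(move || {
///     // Request an early stop from another thread.
///     handle.stop();
/// });
///
/// // Checked periodically inside the training loop:
/// if interrupter.should_stop() {
///     // Stop training gracefully.
/// }
/// ```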
pub struct Interrupter {
    state: Arc<AtomicBool>,
}

impl Interrupter {
    /// Create a new instance.
    pub fn new() -> Self {
        Self::default()
    }

    /// Notify the learner that it should stop.
    pub fn stop(&self) {
        self.state.store(true, Ordering::Relaxed);
    }

    /// Reset the interrupter.
    pub fn reset(&self) {
        self.state.store(false, Ordering::Relaxed);
    }

    /// Returns `true` if [`stop`](Self::stop) has been called.
    pub fn should_stop(&self) -> bool {
        self.state.load(Ordering::Relaxed)
    }
}