burn_train/learner/
base.rs1use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy};
2use crate::components::LearnerComponentTypes;
3use crate::metric::store::EventStoreClient;
4use crate::{CloneEarlyStoppingStrategy, LearnerSummaryConfig, LearningStrategy};
5use burn_core::module::Module;
6use burn_core::tensor::Device;
7use burn_optim::Optimizer;
8use burn_optim::lr_scheduler::LrScheduler;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11
12pub struct Learner<LC: LearnerComponentTypes> {
16 pub(crate) model: LC::Model,
17 pub(crate) optim: LC::Optimizer,
18 pub(crate) lr_scheduler: LC::LrScheduler,
19 pub(crate) num_epochs: usize,
20 pub(crate) checkpoint: Option<usize>,
21 pub(crate) grad_accumulation: Option<usize>,
22 pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
23 pub(crate) learning_strategy: LearningStrategy<LC>,
24 pub(crate) interrupter: Interrupter,
25 pub(crate) early_stopping: Option<EarlyStoppingStrategyRef>,
26 pub(crate) event_processor: LC::EventProcessor,
27 pub(crate) event_store: Arc<EventStoreClient>,
28 pub(crate) summary: Option<LearnerSummaryConfig>,
29}
30
31pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;
33
34#[derive(new)]
35pub struct LearnerCheckpointer<LC: LearnerComponentTypes> {
37 model: LC::CheckpointerModel,
38 optim: LC::CheckpointerOptimizer,
39 lr_scheduler: LC::CheckpointerLrScheduler,
40 strategy: LC::CheckpointerStrategy,
41}
42
43impl<LC: LearnerComponentTypes> LearnerCheckpointer<LC> {
44 pub fn checkpoint(
46 &mut self,
47 model: &LC::Model,
48 optim: &LC::Optimizer,
49 scheduler: &LC::LrScheduler,
50 epoch: usize,
51 store: &EventStoreClient,
52 ) {
53 let actions = self.strategy.checkpointing(epoch, store);
54
55 for action in actions {
56 match action {
57 CheckpointingAction::Delete(epoch) => {
58 self.model
59 .delete(epoch)
60 .expect("Can delete model checkpoint.");
61 self.optim
62 .delete(epoch)
63 .expect("Can delete optimizer checkpoint.");
64 self.lr_scheduler
65 .delete(epoch)
66 .expect("Can delete learning rate scheduler checkpoint.");
67 }
68 CheckpointingAction::Save => {
69 self.model
70 .save(epoch, model.clone().into_record())
71 .expect("Can save model checkpoint.");
72 self.optim
73 .save(epoch, optim.to_record())
74 .expect("Can save optimizer checkpoint.");
75 self.lr_scheduler
76 .save(epoch, scheduler.to_record())
77 .expect("Can save learning rate scheduler checkpoint.");
78 }
79 }
80 }
81 }
82
83 pub fn load_checkpoint(
85 &self,
86 model: LC::Model,
87 optim: LC::Optimizer,
88 scheduler: LC::LrScheduler,
89 device: &Device<LC::Backend>,
90 epoch: usize,
91 ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) {
92 let record = self
93 .model
94 .restore(epoch, device)
95 .expect("Can load model checkpoint.");
96 let model = model.load_record(record);
97
98 let record = self
99 .optim
100 .restore(epoch, device)
101 .expect("Can load optimizer checkpoint.");
102 let optim = optim.load_record(record);
103
104 let record = self
105 .lr_scheduler
106 .restore(epoch, device)
107 .expect("Can load learning rate scheduler checkpoint.");
108 let scheduler = scheduler.load_record(record);
109
110 (model, optim, scheduler)
111 }
112}
113
/// A cheaply cloneable stop signal backed by a shared atomic flag.
///
/// Every clone refers to the same flag, so a [`stop`](Self::stop) issued
/// through one handle is visible to [`should_stop`](Self::should_stop) on all
/// the others.
#[derive(Clone, Debug, Default)]
pub struct Interrupter {
    /// Shared flag; `true` while a stop request is pending.
    state: Arc<AtomicBool>,
}

impl Interrupter {
    /// Creates a handle whose stop flag is initially cleared.
    pub fn new() -> Self {
        Self::default()
    }

    /// Raises the stop flag for this handle and all of its clones.
    pub fn stop(&self) {
        // Relaxed is enough here: only the flag's own value is communicated,
        // no other memory is published through this store.
        self.state.store(true, Ordering::Relaxed);
    }

    /// Clears the stop flag so the handle can be reused.
    pub fn reset(&self) {
        self.state.store(false, Ordering::Relaxed);
    }

    /// Returns `true` if a stop was requested and has not been reset since.
    #[must_use]
    pub fn should_stop(&self) -> bool {
        self.state.load(Ordering::Relaxed)
    }
}