burn_train/learner/
base.rs1use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy};
2use crate::components::LearnerComponentTypes;
3use crate::metric::store::EventStoreClient;
4use crate::{CloneEarlyStoppingStrategy, LearnerSummaryConfig, LearningStrategy};
5use burn_core::module::Module;
6use burn_core::tensor::Device;
7use burn_optim::Optimizer;
8use burn_optim::lr_scheduler::LrScheduler;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11
/// Bundles everything a training run needs: the model together with its optimizer and
/// learning-rate scheduler, the run configuration (epoch count, optional resume
/// checkpoint, optional gradient accumulation), and the supporting components
/// (checkpointing, interruption, early stopping, event processing/storage and an
/// optional summary).
///
/// All fields are `pub(crate)`, so the struct can only be assembled and consumed from
/// inside this crate.
// NOTE(review): field declaration order is also the drop order of these components;
// keep it stable unless the drop order is known not to matter.
pub struct Learner<LC: LearnerComponentTypes> {
    /// The model owned by the learner.
    pub(crate) model: LC::Model,
    /// Optimizer paired with `model`; checkpointed together with the model and scheduler.
    pub(crate) optim: LC::Optimizer,
    /// Learning-rate scheduler; checkpointed together with the model and optimizer.
    pub(crate) lr_scheduler: LC::LrScheduler,
    /// Number of epochs to run.
    pub(crate) num_epochs: usize,
    /// Optional checkpoint epoch.
    // NOTE(review): presumably the epoch to resume from via
    // `LearnerCheckpointer::load_checkpoint` — confirm against the training loop.
    pub(crate) checkpoint: Option<usize>,
    /// Optional gradient-accumulation setting.
    // NOTE(review): presumably the number of steps to accumulate before each optimizer
    // update — confirm.
    pub(crate) grad_accumulation: Option<usize>,
    /// Saves, deletes and restores model/optimizer/scheduler checkpoints, if configured.
    pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
    /// Learning strategy, parameterized by the learner's backend.
    pub(crate) learning_strategy: LearningStrategy<LC::Backend>,
    /// Shared stop flag that can be used to request an early halt.
    pub(crate) interrupter: Interrupter,
    /// Optional boxed early-stopping strategy.
    pub(crate) early_stopping: Option<EarlyStoppingStrategyRef>,
    /// Processor for the events emitted during learning.
    pub(crate) event_processor: LC::EventProcessor,
    /// Shared client for the metric event store (also handed to the checkpointing
    /// strategy when deciding what to save or delete).
    pub(crate) event_store: Arc<EventStoreClient>,
    /// Optional configuration for the learner summary.
    pub(crate) summary: Option<LearnerSummaryConfig>,
}
30
31pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;
33
/// Groups the checkpointers for the model, the optimizer and the learning-rate scheduler
/// with the strategy that decides, per epoch, which checkpoints to save or delete.
///
/// NOTE: `#[derive(new)]` (presumably the `derive-new` crate) generates a constructor
/// whose parameters follow the field declaration order below
/// (`model, optim, lr_scheduler, strategy`) — do not reorder the fields.
#[derive(new)]
pub(crate) struct LearnerCheckpointer<LC: LearnerComponentTypes> {
    /// Checkpointer for model records.
    model: LC::CheckpointerModel,
    /// Checkpointer for optimizer records.
    optim: LC::CheckpointerOptimizer,
    /// Checkpointer for learning-rate scheduler records.
    lr_scheduler: LC::CheckpointerLrScheduler,
    /// Strategy that turns an epoch (plus the event store) into checkpointing actions.
    strategy: LC::CheckpointerStrategy,
}
41
42impl<LC: LearnerComponentTypes> LearnerCheckpointer<LC> {
43 pub(crate) fn checkpoint(
44 &mut self,
45 model: &LC::Model,
46 optim: &LC::Optimizer,
47 scheduler: &LC::LrScheduler,
48 epoch: usize,
49 store: &EventStoreClient,
50 ) {
51 let actions = self.strategy.checkpointing(epoch, store);
52
53 for action in actions {
54 match action {
55 CheckpointingAction::Delete(epoch) => {
56 self.model
57 .delete(epoch)
58 .expect("Can delete model checkpoint.");
59 self.optim
60 .delete(epoch)
61 .expect("Can delete optimizer checkpoint.");
62 self.lr_scheduler
63 .delete(epoch)
64 .expect("Can delete learning rate scheduler checkpoint.");
65 }
66 CheckpointingAction::Save => {
67 self.model
68 .save(epoch, model.clone().into_record())
69 .expect("Can save model checkpoint.");
70 self.optim
71 .save(epoch, optim.to_record())
72 .expect("Can save optimizer checkpoint.");
73 self.lr_scheduler
74 .save(epoch, scheduler.to_record())
75 .expect("Can save learning rate scheduler checkpoint.");
76 }
77 }
78 }
79 }
80
81 pub(crate) fn load_checkpoint(
82 &self,
83 model: LC::Model,
84 optim: LC::Optimizer,
85 scheduler: LC::LrScheduler,
86 device: &Device<LC::Backend>,
87 epoch: usize,
88 ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) {
89 let record = self
90 .model
91 .restore(epoch, device)
92 .expect("Can load model checkpoint.");
93 let model = model.load_record(record);
94
95 let record = self
96 .optim
97 .restore(epoch, device)
98 .expect("Can load optimizer checkpoint.");
99 let optim = optim.load_record(record);
100
101 let record = self
102 .lr_scheduler
103 .restore(epoch, device)
104 .expect("Can load learning rate scheduler checkpoint.");
105 let scheduler = scheduler.load_record(record);
106
107 (model, optim, scheduler)
108 }
109}
110
/// Cheaply cloneable stop flag.
///
/// Every clone shares a single [`AtomicBool`] behind an [`Arc`] (the derived `Clone`
/// clones the `Arc`, not the flag). A [`stop`](Interrupter::stop) issued through any handle
/// is therefore seen by [`should_stop`](Interrupter::should_stop) on every other handle,
/// including handles held on other threads.
#[derive(Clone, Default)]
pub struct Interrupter {
    /// Shared flag: `false` initially and after [`reset`](Interrupter::reset), `true` after
    /// [`stop`](Interrupter::stop).
    state: Arc<AtomicBool>,
}

impl Interrupter {
    /// Creates a fresh interrupter with the stop flag cleared.
    ///
    /// Equivalent to [`Interrupter::default`]; the new handle shares no state with any
    /// previously created interrupter.
    pub fn new() -> Self {
        Self {
            state: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Requests a stop; this handle and all of its clones report it until
    /// [`reset`](Interrupter::reset) is called.
    pub fn stop(&self) {
        self.write_flag(true);
    }

    /// Clears a previously requested stop for this handle and all of its clones.
    pub fn reset(&self) {
        self.write_flag(false);
    }

    /// Returns `true` if a stop has been requested and not reset since.
    pub fn should_stop(&self) -> bool {
        self.state.load(Ordering::Relaxed)
    }

    /// Stores `value` into the shared flag.
    ///
    /// All accesses use `Relaxed` ordering: the flag is only ever read or written on its
    /// own, and this type orders no other memory accesses around it.
    fn write_flag(&self, value: bool) {
        self.state.store(value, Ordering::Relaxed);
    }
}
137}