// burn_train/learner/base.rs
use crate::checkpoint::{Checkpointer, CheckpointingAction, CheckpointingStrategy};
use crate::components::LearnerComponents;
use crate::learner::EarlyStoppingStrategy;
use crate::metric::store::EventStoreClient;
use crate::LearnerSummaryConfig;
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::Module;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::Backend;
use burn_core::tensor::Device;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
13
/// Bundles everything required to run a training loop: the model, optimizer
/// and learning-rate scheduler, plus checkpointing, early stopping, event
/// processing and device configuration.
pub struct Learner<LC: LearnerComponents> {
    /// The model being trained.
    pub(crate) model: LC::Model,
    /// Optimizer that updates the model parameters.
    pub(crate) optim: LC::Optimizer,
    /// Learning-rate scheduler.
    pub(crate) lr_scheduler: LC::LrScheduler,
    /// Total number of epochs to train.
    pub(crate) num_epochs: usize,
    /// Checkpoint epoch to restore from before training starts
    /// (presumably for resuming a run — see `LearnerCheckpointer::load_checkpoint`).
    pub(crate) checkpoint: Option<usize>,
    /// Number of steps to accumulate gradients over before an optimizer
    /// update; `None` assumes an update on every step — TODO confirm with the loop impl.
    pub(crate) grad_accumulation: Option<usize>,
    /// Saves and restores model/optimizer/scheduler state; `None` disables checkpointing.
    pub(crate) checkpointer: Option<LearnerCheckpointer<LC>>,
    /// Devices on which training runs.
    pub(crate) devices: Vec<<LC::Backend as Backend>::Device>,
    /// Shared flag allowing training to be stopped from another thread.
    pub(crate) interrupter: TrainingInterrupter,
    /// Optional strategy deciding whether training should stop early.
    pub(crate) early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
    /// Consumer of training/validation events (metrics, rendering, …).
    pub(crate) event_processor: LC::EventProcessor,
    /// Shared client to the event store collecting metric events.
    pub(crate) event_store: Arc<EventStoreClient>,
    /// Configuration for the end-of-training summary, if one is requested.
    pub(crate) summary: Option<LearnerSummaryConfig>,
}
32
/// Groups the per-component checkpointers (model, optimizer, LR scheduler)
/// together with the strategy that decides when checkpoints are saved or deleted.
#[derive(new)]
pub(crate) struct LearnerCheckpointer<LC: LearnerComponents> {
    /// Checkpointer for the model record.
    model: LC::CheckpointerModel,
    /// Checkpointer for the optimizer record.
    optim: LC::CheckpointerOptimizer,
    /// Checkpointer for the learning-rate scheduler record.
    lr_scheduler: LC::CheckpointerLrScheduler,
    /// Strategy producing save/delete actions for a given epoch.
    strategy: LC::CheckpointerStrategy,
}
40
41impl<LC: LearnerComponents> LearnerCheckpointer<LC> {
42 pub(crate) fn checkpoint(
43 &mut self,
44 model: &LC::Model,
45 optim: &LC::Optimizer,
46 scheduler: &LC::LrScheduler,
47 epoch: usize,
48 store: &EventStoreClient,
49 ) {
50 let actions = self.strategy.checkpointing(epoch, store);
51
52 for action in actions {
53 match action {
54 CheckpointingAction::Delete(epoch) => {
55 self.model
56 .delete(epoch)
57 .expect("Can delete model checkpoint.");
58 self.optim
59 .delete(epoch)
60 .expect("Can delete optimizer checkpoint.");
61 self.lr_scheduler
62 .delete(epoch)
63 .expect("Can delete learning rate scheduler checkpoint.");
64 }
65 CheckpointingAction::Save => {
66 self.model
67 .save(epoch, model.clone().into_record())
68 .expect("Can save model checkpoint.");
69 self.optim
70 .save(epoch, optim.to_record())
71 .expect("Can save optimizer checkpoint.");
72 self.lr_scheduler
73 .save(epoch, scheduler.to_record())
74 .expect("Can save learning rate scheduler checkpoint.");
75 }
76 }
77 }
78 }
79
80 pub(crate) fn load_checkpoint(
81 &self,
82 model: LC::Model,
83 optim: LC::Optimizer,
84 scheduler: LC::LrScheduler,
85 device: &Device<LC::Backend>,
86 epoch: usize,
87 ) -> (LC::Model, LC::Optimizer, LC::LrScheduler) {
88 let record = self
89 .model
90 .restore(epoch, device)
91 .expect("Can load model checkpoint.");
92 let model = model.load_record(record);
93
94 let record = self
95 .optim
96 .restore(epoch, device)
97 .expect("Can load optimizer checkpoint.");
98 let optim = optim.load_record(record);
99
100 let record = self
101 .lr_scheduler
102 .restore(epoch, device)
103 .expect("Can load learning rate scheduler checkpoint.");
104 let scheduler = scheduler.load_record(record);
105
106 (model, optim, scheduler)
107 }
108}
109
/// Thread-safe, clonable handle used to request that training stop.
///
/// Cloning shares the underlying flag, so any clone can signal every other
/// holder (e.g. a UI thread stopping the training loop).
#[derive(Clone, Default)]
pub struct TrainingInterrupter {
    // Shared boolean flipped once by `stop`; `Relaxed` ordering is sufficient
    // because the flag carries no other data to synchronize.
    flag: Arc<AtomicBool>,
}

impl TrainingInterrupter {
    /// Creates a handle whose flag starts unset.
    pub fn new() -> Self {
        Default::default()
    }

    /// Signals every clone of this handle that training should stop.
    pub fn stop(&self) {
        self.flag.store(true, Ordering::Relaxed);
    }

    /// Returns `true` once `stop` has been called on any clone.
    pub fn should_stop(&self) -> bool {
        self.flag.load(Ordering::Relaxed)
    }
}