burn_train/learner/strategies/
base.rs1use std::sync::Arc;
2
3#[cfg(feature = "ddp")]
4use burn_collective::CollectiveConfig;
5#[cfg(feature = "ddp")]
6use burn_core::tensor::backend::AutodiffBackend;
7use burn_core::{module::AutodiffModule, prelude::Backend};
8
9use crate::{
10 EarlyStoppingStrategyRef, Interrupter, Learner, LearnerCheckpointer, TrainLoader,
11 TrainingResult, ValidLoader,
12 components::LearnerComponentTypes,
13 metric::{
14 processor::{EventProcessorTraining, LearnerEvent},
15 store::EventStoreClient,
16 },
17 multi::CustomMultiDeviceLearningStrategy,
18 single::CustomSingleDeviceLearningStrategy,
19};
20
21pub use crate::multi::MultiDeviceOptim;
22
23type LearnerDevice<LC> = <<LC as LearnerComponentTypes>::Backend as Backend>::Device;
24
25#[derive(Clone)]
27pub enum LearningStrategy<LC: LearnerComponentTypes> {
28 SingleDevice(LearnerDevice<LC>),
30
31 CustomSingleDevice(CustomSingleDeviceLearningStrategy<LC>),
33
34 MultiDevice(Vec<LearnerDevice<LC>>, MultiDeviceOptim),
37
38 CustomMultiDevice(CustomMultiDeviceLearningStrategy<LC>),
40
41 #[cfg(feature = "ddp")]
44 DistributedDataParallel {
45 devices: Vec<LearnerDevice<LC>>,
47
48 config: CollectiveConfig,
51 },
52}
53
/// Builds a [`LearningStrategy::DistributedDataParallel`] from the given
/// devices and collective configuration.
///
/// Only available when the `ddp` feature is enabled.
///
/// The type parameter `B` is not referenced by the arguments, the return type
/// or the body, so it cannot be inferred and callers have to name it
/// explicitly. It is kept so existing call sites keep compiling.
#[cfg(feature = "ddp")]
pub fn ddp<B: AutodiffBackend, LC: LearnerComponentTypes>(
    devices: Vec<LearnerDevice<LC>>,
    config: CollectiveConfig,
) -> LearningStrategy<LC> {
    LearningStrategy::DistributedDataParallel { devices, config }
}
62
63impl<LC: LearnerComponentTypes> Default for LearningStrategy<LC> {
64 fn default() -> Self {
65 Self::SingleDevice(Default::default())
66 }
67}
68
69pub trait LearningMethod<LC: LearnerComponentTypes> {
71 type PreparedDataloaders;
75 type PreparedModel;
79
80 fn fit(
82 &self,
83 mut learner: Learner<LC>,
84 dataloader_train: TrainLoader<LC>,
85 dataloader_valid: ValidLoader<LC>,
86 ) -> TrainingResult<LC::InnerModel> {
87 let mut model = learner.model;
88 let mut optim = learner.optim;
89 let mut lr_scheduler = learner.lr_scheduler;
90 let checkpoint = learner.checkpoint;
91
92 let starting_epoch = match checkpoint {
93 Some(checkpoint) => {
94 if let Some(checkpointer) = &mut learner.checkpointer {
95 (model, optim, lr_scheduler) = checkpointer.load_checkpoint(
96 model,
97 optim,
98 lr_scheduler,
99 &Default::default(), checkpoint,
101 );
102 }
103 checkpoint + 1
104 }
105 None => 1,
106 };
107
108 let dataloaders = self.prepare_dataloaders(dataloader_train, dataloader_valid);
109 let model = self.prepare_model(model);
110
111 let mut components = LearnerComponents {
113 optim,
114 lr_scheduler,
115 num_epochs: learner.num_epochs,
116 checkpointer: learner.checkpointer,
117 grad_accumulation: learner.grad_accumulation,
118 interrupter: learner.interrupter,
119 early_stopping: learner.early_stopping,
120 event_processor: learner.event_processor,
121 event_store: learner.event_store,
122 };
123 components
125 .event_processor
126 .process_train(LearnerEvent::Start);
127 let (model, mut event_processor) =
128 self.learn(model, dataloaders, starting_epoch, components);
129
130 let summary = learner.summary.and_then(|summary| {
131 summary
132 .init()
133 .map(|summary| summary.with_model(model.to_string()))
134 .ok()
135 });
136
137 event_processor.process_train(LearnerEvent::End(summary));
139
140 let model = model.valid();
141 let renderer = event_processor.renderer();
142
143 TrainingResult::<LC::InnerModel> { model, renderer }
144 }
145
146 fn prepare_dataloaders(
149 &self,
150 dataloader_train: TrainLoader<LC>,
151 dataloader_valid: ValidLoader<LC>,
152 ) -> Self::PreparedDataloaders;
153
154 fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel;
157
158 fn learn(
160 &self,
161 model: Self::PreparedModel,
162 dataloaders: Self::PreparedDataloaders,
163 starting_epoch: usize,
164 components: LearnerComponents<LC>,
165 ) -> (LC::Model, LC::EventProcessor);
166}
167
168pub struct LearnerComponents<LC: LearnerComponentTypes> {
171 pub optim: LC::Optimizer,
173 pub lr_scheduler: LC::LrScheduler,
175 pub num_epochs: usize,
177 pub grad_accumulation: Option<usize>,
179 pub checkpointer: Option<LearnerCheckpointer<LC>>,
181 pub interrupter: Interrupter,
183 pub early_stopping: Option<EarlyStoppingStrategyRef>,
185 pub event_processor: LC::EventProcessor,
187 pub event_store: Arc<EventStoreClient>,
189}