// burn_train/learner/supervised/strategies/base.rs
use std::sync::Arc;

#[cfg(feature = "ddp")]
use burn_core::tensor::backend::distributed::{DistributedBackend, DistributedConfig};
use burn_core::{module::AutodiffModule, prelude::Backend};

use crate::{
    EarlyStoppingStrategyRef, InferenceModel, Interrupter, Learner, LearnerSummaryConfig,
    LearningCheckpointer, LearningResult, SupervisedTrainingEventProcessor, TrainLoader,
    TrainingModel, ValidLoader,
    components::LearningComponentsTypes,
    metric::{
        processor::{EventProcessorTraining, LearnerEvent},
        store::EventStoreClient,
    },
};
17
/// Shared, dynamically-dispatched handle to a user-provided supervised
/// learning strategy.
pub type CustomLearningStrategy<LC> = Arc<dyn SupervisedLearningStrategy<LC>>;
20
/// Optimizer placement when training on multiple devices.
#[derive(Clone, Copy, Debug)]
pub enum MultiDeviceOptim {
    /// Keep the optimizer on the main device (presumably the first device in
    /// the list — see `ExecutionStrategy::main_device`).
    OptimMainDevice,
    /// Shard the optimizer across the participating devices.
    OptimSharded,
}
29
/// Describes on which device(s) training is executed.
pub enum ExecutionStrategy<B: Backend> {
    /// Train on a single device.
    SingleDevice(B::Device),
    /// Train on multiple devices with the given optimizer placement.
    MultiDevice(Vec<B::Device>, MultiDeviceOptim),
    /// Distributed data-parallel training (only with the `ddp` feature).
    #[cfg(feature = "ddp")]
    DistributedDataParallel {
        /// Devices participating in the distributed run.
        devices: Vec<B::Device>,
        /// Runtime handle controlling the communication server lifecycle.
        runtime: Box<dyn DistributedRuntime>,
    },
}
47
48impl<B: Backend> ExecutionStrategy<B> {
49 pub fn main_device(&self) -> &B::Device {
51 match self {
52 ExecutionStrategy::SingleDevice(device) => device,
53 ExecutionStrategy::MultiDevice(devices, _optim) => &devices[0],
54 #[cfg(feature = "ddp")]
55 ExecutionStrategy::DistributedDataParallel {
56 devices,
57 runtime: _,
58 } => &devices[0],
59 }
60 }
61
62 pub fn single(device: B::Device) -> Self {
64 Self::SingleDevice(device)
65 }
66
67 pub fn multi(devices: Vec<B::Device>, optim: MultiDeviceOptim) -> Self {
69 Self::MultiDevice(devices, optim)
70 }
71}
72
#[cfg(feature = "ddp")]
impl<B: DistributedBackend> ExecutionStrategy<B> {
    /// Creates a distributed data-parallel execution strategy over `devices`
    /// using the provided distributed configuration.
    pub fn ddp(devices: Vec<B::Device>, config: DistributedConfig) -> Self {
        // The session keeps its own copy of the device list so it can manage
        // the communication server independently of the strategy.
        let runtime: Box<dyn DistributedRuntime> = Box::new(DistributedSession::<B> {
            devices: devices.clone(),
            config,
        });

        Self::DistributedDataParallel { devices, runtime }
    }
}
87
/// Selects how the training loop is driven.
pub enum TrainingStrategy<LC: LearningComponentsTypes> {
    /// Use the built-in training loop with the given execution strategy.
    Default(ExecutionStrategy<LC::Backend>),
    /// Use a user-provided [`SupervisedLearningStrategy`] implementation.
    Custom(CustomLearningStrategy<LC>),
}
95
96impl<LC: LearningComponentsTypes> From<ExecutionStrategy<LC::Backend>> for TrainingStrategy<LC> {
97 fn from(value: ExecutionStrategy<LC::Backend>) -> Self {
98 Self::Default(value)
99 }
100}
101
/// Lifecycle hooks for a distributed communication runtime.
///
/// Implementors start and close the backend's communication server (see
/// [`DistributedSession`]).
#[cfg(feature = "ddp")]
pub trait DistributedRuntime: Send + Sync + 'static {
    /// Starts the distributed communication server.
    fn start(&self);

    /// Closes the distributed communication server.
    fn close(&self);
}
114
/// Distributed runtime backed by a [`DistributedBackend`]'s communication
/// server.
#[cfg(feature = "ddp")]
pub struct DistributedSession<B: DistributedBackend> {
    // Devices participating in the distributed session.
    devices: Vec<B::Device>,
    // Backend-specific distributed configuration.
    config: DistributedConfig,
}
124
#[cfg(feature = "ddp")]
impl<B: DistributedBackend> DistributedRuntime for DistributedSession<B> {
    // Starts the backend's communication server for all session devices.
    fn start(&self) {
        B::start_communication_server(&self.devices, self.config.clone());
    }

    // Closes the communication server; the backend API takes a single device,
    // so the first one is used as the anchor.
    fn close(&self) {
        B::close_communication_server(&self.devices[0]);
    }
}
135
136impl<LC: LearningComponentsTypes> Default for TrainingStrategy<LC> {
137 fn default() -> Self {
138 Self::Default(ExecutionStrategy::SingleDevice(Default::default()))
139 }
140}
141
/// Bundles everything a learning strategy needs — besides the model and data
/// loaders — to run the training loop.
pub struct TrainingComponents<LC: LearningComponentsTypes> {
    /// Total number of epochs to train for.
    pub num_epochs: usize,
    /// Epoch of the checkpoint to resume from; training restarts at
    /// `checkpoint + 1` (see `SupervisedLearningStrategy::train`).
    pub checkpoint: Option<usize>,
    /// Checkpointer used to load (and presumably save) learner state.
    pub checkpointer: Option<LearningCheckpointer<LC>>,
    /// Gradient accumulation setting, if enabled.
    pub grad_accumulation: Option<usize>,
    /// Handle used to interrupt training.
    pub interrupter: Interrupter,
    /// Optional early-stopping strategy.
    pub early_stopping: Option<EarlyStoppingStrategyRef>,
    /// Processor receiving training/validation events.
    pub event_processor: SupervisedTrainingEventProcessor<LC>,
    /// Shared client to the metric event store.
    pub event_store: Arc<EventStoreClient>,
    /// Optional configuration for the end-of-training summary.
    pub summary: Option<LearnerSummaryConfig>,
}
164
165pub trait SupervisedLearningStrategy<LC: LearningComponentsTypes> {
167 fn train(
169 &self,
170 mut learner: Learner<LC>,
171 dataloader_train: TrainLoader<LC>,
172 dataloader_valid: ValidLoader<LC>,
173 mut training_components: TrainingComponents<LC>,
174 ) -> LearningResult<InferenceModel<LC>> {
175 let starting_epoch = match training_components.checkpoint {
176 Some(checkpoint) => {
177 if let Some(checkpointer) = &mut training_components.checkpointer {
178 learner =
179 checkpointer.load_checkpoint(learner, &Default::default(), checkpoint);
180 }
181 checkpoint + 1
182 }
183 None => 1,
184 };
185
186 let summary_config = training_components.summary.clone();
187
188 training_components
190 .event_processor
191 .process_train(LearnerEvent::Start);
192 let (model, mut event_processor) = self.fit(
194 training_components,
195 learner,
196 dataloader_train,
197 dataloader_valid,
198 starting_epoch,
199 );
200
201 let summary = summary_config.and_then(|summary| {
202 summary
203 .init()
204 .map(|summary| summary.with_model(model.to_string()))
205 .ok()
206 });
207
208 event_processor.process_train(LearnerEvent::End(summary));
210
211 let model = model.valid();
212 let renderer = event_processor.renderer();
213
214 LearningResult::<InferenceModel<LC>> { model, renderer }
215 }
216
217 fn fit(
219 &self,
220 training_components: TrainingComponents<LC>,
221 learner: Learner<LC>,
222 dataloader_train: TrainLoader<LC>,
223 dataloader_valid: ValidLoader<LC>,
224 starting_epoch: usize,
225 ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>);
226}