burn_train/learner/strategies/
base.rs1use std::sync::Arc;
2
3#[cfg(feature = "ddp")]
4use burn_collective::CollectiveConfig;
5use burn_core::{module::AutodiffModule, tensor::backend::AutodiffBackend};
6
7use crate::{
8 EarlyStoppingStrategyRef, Interrupter, Learner, LearnerCheckpointer, TrainLoader,
9 TrainingResult, ValidLoader,
10 components::LearnerComponentTypes,
11 metric::{
12 processor::{EventProcessorTraining, LearnerEvent},
13 store::EventStoreClient,
14 },
15};
16
/// Selects how training is distributed across devices.
#[derive(Clone)]
pub enum LearningStrategy<B: AutodiffBackend> {
    /// Train on a single device.
    SingleDevice(B::Device),

    /// Train across multiple devices without collective operations.
    /// NOTE(review): the exact gradient-aggregation scheme is implemented
    /// elsewhere — confirm against the strategy implementation.
    MultiDeviceNaive(Vec<B::Device>),

    /// Distributed data parallel training backed by `burn_collective`.
    #[cfg(feature = "ddp")]
    DistributedDataParallel {
        /// Devices participating in training.
        devices: Vec<B::Device>,

        /// Configuration for the collective communication backend.
        config: CollectiveConfig,
    },
}
38
/// Convenience constructor for [`LearningStrategy::DistributedDataParallel`].
///
/// Bundles the participating `devices` and the collective-communication
/// `config` into the DDP strategy variant.
#[cfg(feature = "ddp")]
pub fn ddp<B: AutodiffBackend>(
    devices: Vec<B::Device>,
    config: CollectiveConfig,
) -> LearningStrategy<B> {
    let strategy = LearningStrategy::DistributedDataParallel { devices, config };
    strategy
}
47
48impl<B: AutodiffBackend> Default for LearningStrategy<B> {
49 fn default() -> Self {
50 Self::SingleDevice(Default::default())
51 }
52}
53
/// Contract a concrete learning strategy implements to run training.
///
/// The provided [`LearningMethod::fit`] drives the whole session:
/// checkpoint restore, strategy-specific preparation, the training loop
/// (delegated to [`LearningMethod::learn`]), and final event emission.
pub(crate) trait LearningMethod<LC: LearnerComponentTypes> {
    /// Dataloaders after strategy-specific preparation.
    /// NOTE(review): shape depends on the implementor (e.g. split per
    /// device for multi-device strategies) — confirm per strategy.
    type PreparedDataloaders;

    /// Model after strategy-specific preparation.
    type PreparedModel;

    /// Runs a full training session and returns the trained (validated)
    /// inner model together with the renderer taken from the event
    /// processor.
    ///
    /// Steps: restore a checkpoint when one was requested, prepare the
    /// dataloaders and model via the strategy hooks, run `learn`, emit the
    /// end-of-training event (with an optional summary), then strip
    /// autodiff from the model with `valid()`.
    fn fit(
        &self,
        mut learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> TrainingResult<LC::InnerModel> {
        // Move the pieces out of the learner that the training loop
        // mutates; the remaining fields are consumed further below.
        let mut model = learner.model;
        let mut optim = learner.optim;
        let mut lr_scheduler = learner.lr_scheduler;
        let checkpoint = learner.checkpoint;

        // Epochs are 1-based: resume at the epoch after the restored
        // checkpoint, or start at epoch 1 when there is none.
        let starting_epoch = match checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut learner.checkpointer {
                    // NOTE(review): the `&Default::default()` argument is
                    // presumably the device to load onto — confirm that the
                    // default device is intended here.
                    (model, optim, lr_scheduler) = checkpointer.load_checkpoint(
                        model,
                        optim,
                        lr_scheduler,
                        &Default::default(), checkpoint,
                    );
                }
                checkpoint + 1
            }
            None => 1,
        };

        // Strategy-specific preparation (hooks implemented by the strategy).
        let dataloaders = self.prepare_dataloaders(dataloader_train, dataloader_valid);
        let model = self.prepare_model(model);

        // Bundle everything the training loop needs; this consumes the
        // remaining fields of `learner` except `summary`.
        let components = LearnerComponents {
            optim,
            lr_scheduler,
            num_epochs: learner.num_epochs,
            checkpointer: learner.checkpointer,
            grad_accumulation: learner.grad_accumulation,
            interrupter: learner.interrupter,
            early_stopping: learner.early_stopping,
            event_processor: learner.event_processor,
            event_store: learner.event_store,
        };
        let (model, mut event_processor) =
            self.learn(model, dataloaders, starting_epoch, components);

        // Build the summary if one was configured; initialization failures
        // are swallowed (`.ok()`) so they never abort a finished run.
        let summary = learner.summary.and_then(|summary| {
            summary
                .init()
                .map(|summary| summary.with_model(model.to_string()))
                .ok()
        });

        // Signal end of training to the event pipeline.
        event_processor.process_train(LearnerEvent::End(summary));

        // Drop the autodiff graph to obtain the inference-ready model.
        let model = model.valid();
        let renderer = event_processor.renderer();

        TrainingResult::<LC::InnerModel> { model, renderer }
    }

    /// Prepares the train/valid dataloaders for this strategy
    /// (called once by [`LearningMethod::fit`] before `learn`).
    fn prepare_dataloaders(
        &self,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> Self::PreparedDataloaders;

    /// Prepares the model for this strategy
    /// (called once by [`LearningMethod::fit`] before `learn`).
    fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel;

    /// Runs the actual training loop starting at `starting_epoch`,
    /// returning the trained model and the event processor (handed back so
    /// `fit` can emit the end event and recover the renderer).
    fn learn(
        &self,
        model: Self::PreparedModel,
        dataloaders: Self::PreparedDataloaders,
        starting_epoch: usize,
        components: LearnerComponents<LC>,
    ) -> (LC::Model, LC::EventProcessor);
}
148
/// Everything a strategy's training loop needs besides the model and the
/// dataloaders; assembled by `LearningMethod::fit` from the `Learner`.
pub(crate) struct LearnerComponents<LC: LearnerComponentTypes> {
    // Optimizer driving the parameter updates.
    pub optim: LC::Optimizer,
    // Learning-rate scheduler consulted during training.
    pub lr_scheduler: LC::LrScheduler,
    // Total number of epochs to train for.
    pub num_epochs: usize,
    // Number of steps to accumulate gradients over, when enabled.
    pub grad_accumulation: Option<usize>,
    // Checkpointer used to persist training state, when configured.
    pub checkpointer: Option<LearnerCheckpointer<LC>>,
    // Handle allowing training to be interrupted externally.
    pub interrupter: Interrupter,
    // Early-stopping strategy, when configured.
    pub early_stopping: Option<EarlyStoppingStrategyRef>,
    // Processor receiving training/validation events.
    pub event_processor: LC::EventProcessor,
    // Shared client for the metric event store.
    pub event_store: Arc<EventStoreClient>,
}