burn_train/learner/supervised/strategies/base.rs

use std::sync::Arc;

#[cfg(feature = "ddp")]
use burn_collective::CollectiveConfig;
use burn_core::{module::AutodiffModule, prelude::Backend};

use crate::{
    EarlyStoppingStrategyRef, InferenceModel, Interrupter, Learner, LearnerSummaryConfig,
    LearningCheckpointer, LearningResult, SupervisedTrainingEventProcessor, TrainLoader,
    TrainingModel, ValidLoader,
    components::LearningComponentsTypes,
    metric::{
        processor::{EventProcessorTraining, LearnerEvent},
        store::EventStoreClient,
    },
};

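/// The device type of the learner's backend.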
type LearnerDevice<LC> = <<LC as LearningComponentsTypes>::Backend as Backend>::Device;

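/// A shared, dynamically dispatched, user-provided supervised learning strategy.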
pub type CustomLearningStrategy<LC> = Arc<dyn SupervisedLearningStrategy<LC>>;

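/// Where the optimizer step runs when training on multiple devices.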
#[derive(Clone, Copy, Debug)]
pub enum MultiDeviceOptim {
    /// Run the optimizer step on the main device.
    OptimMainDevice,
    /// Shard the optimizer step across the devices.
    OptimSharded,
}

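/// The strategy used to execute supervised training.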
#[derive(Clone)]
pub enum TrainingStrategy<LC: LearningComponentsTypes> {
    /// Training runs on a single device.
    SingleDevice(LearnerDevice<LC>),
    /// Training runs on multiple devices, with the given optimizer placement.
    MultiDevice(Vec<LearnerDevice<LC>>, MultiDeviceOptim),
    /// Training is delegated to a user-provided strategy.
    Custom(CustomLearningStrategy<LC>),
    /// Training runs with distributed data parallelism over collective operations.
    #[cfg(feature = "ddp")]
    DistributedDataParallel {
        /// The devices to train on.
        devices: Vec<LearnerDevice<LC>>,
        /// The collective operations configuration.
        config: CollectiveConfig,
    },
}

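/// Creates a [`TrainingStrategy::DistributedDataParallel`] strategy from the given devices and
/// collective configuration.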
#[cfg(feature = "ddp")]
pub fn ddp<LC: LearningComponentsTypes>(
    devices: Vec<LearnerDevice<LC>>,
    config: CollectiveConfig,
) -> TrainingStrategy<LC> {
    TrainingStrategy::DistributedDataParallel { devices, config }
}

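// By default, training runs on a single device: the backend's default device.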
impl<LC: LearningComponentsTypes> Default for TrainingStrategy<LC> {
    fn default() -> Self {
        Self::SingleDevice(Default::default())
    }
}

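/// Components and settings shared by supervised training strategies.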
pub struct TrainingComponents<LC: LearningComponentsTypes> {
    /// Total number of epochs to train for.
    pub num_epochs: usize,
    /// Checkpoint (epoch) to resume from, if any.
    pub checkpoint: Option<usize>,
    /// Checkpointer used to save and restore training state.
    pub checkpointer: Option<LearningCheckpointer<LC>>,
    /// Number of steps over which gradients are accumulated, if enabled.
    pub grad_accumulation: Option<usize>,
    /// Handle used to interrupt training early.
    pub interrupter: Interrupter,
    /// Early stopping strategy, if any.
    pub early_stopping: Option<EarlyStoppingStrategyRef>,
    /// Processor that handles learner events (metrics and rendering).
    pub event_processor: SupervisedTrainingEventProcessor<LC>,
    /// Client to the event store collecting metrics.
    pub event_store: Arc<EventStoreClient>,
    /// Configuration for the end-of-training summary, if enabled.
    pub summary: Option<LearnerSummaryConfig>,
}

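/// Defines how supervised training is executed for a set of learning components.
///
/// Implementors provide [`SupervisedLearningStrategy::fit`]; the provided
/// [`SupervisedLearningStrategy::train`] method wraps it with checkpoint restoration, start/end
/// events, optional summary creation, and conversion of the trained model for inference.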
pub trait SupervisedLearningStrategy<LC: LearningComponentsTypes> {
    /// Runs the full supervised training loop and returns the trained model for inference.
    ///
    /// Restores the learner from a checkpoint when one is configured, emits start/end events,
    /// delegates the actual epoch loop to [`SupervisedLearningStrategy::fit`], and builds the
    /// optional training summary.
    fn train(
        &self,
        mut learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        mut training_components: TrainingComponents<LC>,
    ) -> LearningResult<InferenceModel<LC>> {
        // Resume from the checkpoint if one is configured.
        let starting_epoch = match training_components.checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut training_components.checkpointer {
                    learner =
                        checkpointer.load_checkpoint(learner, &Default::default(), checkpoint);
                }
                checkpoint + 1
            }
            None => 1,
        };

        let summary_config = training_components.summary.clone();

        training_components
            .event_processor
            .process_train(LearnerEvent::Start);
        let (model, mut event_processor) = self.fit(
            training_components,
            learner,
            dataloader_train,
            dataloader_valid,
            starting_epoch,
        );

        // Build the optional end-of-training summary from the trained model.
        let summary = summary_config.and_then(|summary| {
            summary
                .init()
                .map(|summary| summary.with_model(model.to_string()))
                .ok()
        });

        event_processor.process_train(LearnerEvent::End(summary));

        // Convert the autodiff model into its inference counterpart.
        let model = model.valid();
        let renderer = event_processor.renderer();

        LearningResult::<InferenceModel<LC>> { model, renderer }
    }

    /// Runs the strategy-specific training loop, starting at `starting_epoch`, and returns the
    /// trained model together with the event processor.
    fn fit(
        &self,
        training_components: TrainingComponents<LC>,
        learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        starting_epoch: usize,
    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>);
}
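
// Illustrative sketch (not part of the original file): how a hypothetical custom strategy could
// plug into `TrainingStrategy::Custom`. Only the required `fit` method is shown, and its body is
// left as `todo!()` because the epoch loop details live elsewhere in the crate; `MyComponents`
// stands in for a concrete `LearningComponentsTypes` implementation.
//
// struct MyStrategy;
//
// impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC> for MyStrategy {
//     fn fit(
//         &self,
//         _training_components: TrainingComponents<LC>,
//         _learner: Learner<LC>,
//         _dataloader_train: TrainLoader<LC>,
//         _dataloader_valid: ValidLoader<LC>,
//         _starting_epoch: usize,
//     ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {
//         todo!("run the epoch loop and return the trained model with its event processor")
//     }
// }
//
// let strategy: TrainingStrategy<MyComponents> = TrainingStrategy::Custom(Arc::new(MyStrategy));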