burn_train/learner/strategies/base.rs

use std::sync::Arc;

#[cfg(feature = "ddp")]
use burn_collective::CollectiveConfig;
use burn_core::{module::AutodiffModule, tensor::backend::AutodiffBackend};

use crate::{
    EarlyStoppingStrategyRef, Interrupter, Learner, LearnerCheckpointer, TrainLoader,
    TrainingResult, ValidLoader,
    components::LearnerComponentTypes,
    metric::{
        processor::{EventProcessorTraining, LearnerEvent},
        store::EventStoreClient,
    },
};

/// How the learner should run training for the model.
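///
/// # Example
///
/// A minimal sketch of selecting a strategy; `MyAutodiffBackend` and the device
/// values are placeholders for whatever backend the caller uses:
///
/// ```rust,ignore
/// // Train on a single device.
/// let strategy = LearningStrategy::<MyAutodiffBackend>::SingleDevice(device);
///
/// // Or split the input across several local devices (legacy path).
/// let strategy =
///     LearningStrategy::<MyAutodiffBackend>::MultiDeviceNaive(vec![device_0, device_1]);
/// ```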
#[derive(Clone)]
pub enum LearningStrategy<B: AutodiffBackend> {
    /// Training on one device.
    SingleDevice(B::Device),

    /// Legacy implementation of local multi-device training.
    MultiDeviceNaive(Vec<B::Device>),

    /// Training with the input distributed across devices; each device holds its own copy of the
    /// model. Collective ops are used to sync the gradients after each backward pass.
    #[cfg(feature = "ddp")]
    DistributedDataParallel {
        /// Devices on this node for the DDP.
        devices: Vec<B::Device>,

        /// The configuration for collective operations (`num_devices` is ignored).
        config: CollectiveConfig,
    },
}

/// Constructor for a distributed data parallel (DDP) learning strategy.
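///
/// # Example
///
/// A minimal sketch, assuming a `CollectiveConfig` value is built elsewhere and
/// `MyAutodiffBackend` stands in for the caller's backend:
///
/// ```rust,ignore
/// let strategy = ddp::<MyAutodiffBackend>(vec![device_0, device_1], collective_config);
/// ```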
#[cfg(feature = "ddp")]
pub fn ddp<B: AutodiffBackend>(
    devices: Vec<B::Device>,
    config: CollectiveConfig,
) -> LearningStrategy<B> {
    LearningStrategy::DistributedDataParallel { devices, config }
}

impl<B: AutodiffBackend> Default for LearningStrategy<B> {
    fn default() -> Self {
        Self::SingleDevice(Default::default())
    }
}

/// Provides the `fit` function for any learning strategy.
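///
/// Implementors provide the preparation steps and the training loop; the default
/// [`fit`](Self::fit) implementation handles checkpoint restoration, the end-of-training
/// event, and the conversion to the inner (non-autodiff) model.
///
/// # Implementation sketch
///
/// A hypothetical single-device strategy; the struct name is illustrative, and
/// `Module::fork` is assumed to move the model onto the target device:
///
/// ```rust,ignore
/// struct SingleDeviceStrategy<B: AutodiffBackend> {
///     device: B::Device,
/// }
///
/// impl<LC: LearnerComponentTypes> LearningMethod<LC> for SingleDeviceStrategy<LC::Backend> {
///     // No splitting needed: pass both loaders through unchanged.
///     type PreparedDataloaders = (TrainLoader<LC>, ValidLoader<LC>);
///     type PreparedModel = LC::Model;
///
///     fn prepare_dataloaders(
///         &self,
///         train: TrainLoader<LC>,
///         valid: ValidLoader<LC>,
///     ) -> Self::PreparedDataloaders {
///         (train, valid)
///     }
///
///     fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel {
///         model.fork(&self.device)
///     }
///
///     fn learn(
///         &self,
///         model: Self::PreparedModel,
///         (train, valid): Self::PreparedDataloaders,
///         starting_epoch: usize,
///         components: LearnerComponents<LC>,
///     ) -> (LC::Model, LC::EventProcessor) {
///         // Epoch loop over `train` and `valid` goes here.
///         todo!()
///     }
/// }
/// ```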
pub(crate) trait LearningMethod<LC: LearnerComponentTypes> {
    /// The dataloaders after being prepared for this training strategy
    /// (e.g. split across multiple devices).
    type PreparedDataloaders;

    /// The model after being prepared for this training strategy.
    ///
    /// The prepared model will be correctly initialized on the proper device for training.
    type PreparedModel;

    /// Fit the learner's model with this strategy.
    fn fit(
        &self,
        mut learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> TrainingResult<LC::InnerModel> {
        let mut model = learner.model;
        let mut optim = learner.optim;
        let mut lr_scheduler = learner.lr_scheduler;
        let checkpoint = learner.checkpoint;

        // Resume from the configured checkpoint, if any.
        let starting_epoch = match checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut learner.checkpointer {
                    (model, optim, lr_scheduler) = checkpointer.load_checkpoint(
                        model,
                        optim,
                        lr_scheduler,
                        &Default::default(), // Load the checkpoint on the default device.
                        checkpoint,
                    );
                }
                checkpoint + 1
            }
            None => 1,
        };

        // Prepare the data and the model for this specific strategy.
        let dataloaders = self.prepare_dataloaders(dataloader_train, dataloader_valid);
        let model = self.prepare_model(model);

        let components = LearnerComponents {
            optim,
            lr_scheduler,
            num_epochs: learner.num_epochs,
            checkpointer: learner.checkpointer,
            grad_accumulation: learner.grad_accumulation,
            interrupter: learner.interrupter,
            early_stopping: learner.early_stopping,
            event_processor: learner.event_processor,
            event_store: learner.event_store,
        };

        // Training loop
        let (model, mut event_processor) =
            self.learn(model, dataloaders, starting_epoch, components);

        // Build the training summary, if one was requested.
        let summary = learner.summary.and_then(|summary| {
            summary
                .init()
                .map(|summary| summary.with_model(model.to_string()))
                .ok()
        });

        // Signal training end. For the TUI renderer, this handles the exit & return to the main
        // screen.
        event_processor.process_train(LearnerEvent::End(summary));

        // Convert the trained model to its inner (non-autodiff) representation.
        let model = model.valid();
        let renderer = event_processor.renderer();

        TrainingResult::<LC::InnerModel> { model, renderer }
    }

    /// Prepare the dataloaders for this strategy.
    ///
    /// The output will be used in [the learn function](Self::learn).
    fn prepare_dataloaders(
        &self,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> Self::PreparedDataloaders;

    /// Prepare the model for this training strategy.
    ///
    /// The output will be used in [the learn function](Self::learn).
    fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel;

    /// Training loop for this strategy.
    fn learn(
        &self,
        model: Self::PreparedModel,
        dataloaders: Self::PreparedDataloaders,
        starting_epoch: usize,
        components: LearnerComponents<LC>,
    ) -> (LC::Model, LC::EventProcessor);
}

/// Groups the components used during training, minimising the number of parameters
/// passed to [`LearningMethod::learn`].
pub(crate) struct LearnerComponents<LC: LearnerComponentTypes> {
    /// The optimizer used for training.
    pub optim: LC::Optimizer,
    /// The learning-rate scheduler.
    pub lr_scheduler: LC::LrScheduler,
    /// Total number of training epochs.
    pub num_epochs: usize,
    /// Number of batches over which to accumulate gradients, if enabled.
    pub grad_accumulation: Option<usize>,
    /// Optional checkpointer for saving and restoring training state.
    pub checkpointer: Option<LearnerCheckpointer<LC>>,
    /// Handle used to interrupt training early.
    pub interrupter: Interrupter,
    /// Optional early-stopping strategy.
    pub early_stopping: Option<EarlyStoppingStrategyRef>,
    /// Processor for training and validation events.
    pub event_processor: LC::EventProcessor,
    /// Client for the metrics event store.
    pub event_store: Arc<EventStoreClient>,
}