burn_train/learner/strategies/base.rs

use std::sync::Arc;

#[cfg(feature = "ddp")]
use burn_collective::CollectiveConfig;
#[cfg(feature = "ddp")]
use burn_core::tensor::backend::AutodiffBackend;
use burn_core::{module::AutodiffModule, prelude::Backend};

use crate::{
    EarlyStoppingStrategyRef, Interrupter, Learner, LearnerCheckpointer, TrainLoader,
    TrainingResult, ValidLoader,
    components::LearnerComponentTypes,
    metric::{
        processor::{EventProcessorTraining, LearnerEvent},
        store::EventStoreClient,
    },
    multi::CustomMultiDeviceLearningStrategy,
    single::CustomSingleDeviceLearningStrategy,
};

pub use crate::multi::MultiDeviceOptim;

type LearnerDevice<LC> = <<LC as LearnerComponentTypes>::Backend as Backend>::Device;

/// How the learner should run training for the model.
#[derive(Clone)]
pub enum LearningStrategy<LC: LearnerComponentTypes> {
    /// Training on one device.
    SingleDevice(LearnerDevice<LC>),

    /// Training on one device with a custom learning strategy.
    CustomSingleDevice(CustomSingleDeviceLearningStrategy<LC>),

    /// Performs data-parallel distributed training where the optimization is
    /// done on an elected master device.
    MultiDevice(Vec<LearnerDevice<LC>>, MultiDeviceOptim),

    /// Training on multiple devices with a custom learning strategy.
    CustomMultiDevice(CustomMultiDeviceLearningStrategy<LC>),

    /// Training with the input distributed across devices, where each device holds its own copy of the model.
    /// Collective ops are used to sync the gradients after each pass.
    #[cfg(feature = "ddp")]
    DistributedDataParallel {
        /// Devices on this node for the DDP.
        devices: Vec<LearnerDevice<LC>>,

        /// The configuration for collective operations.
        /// `num_devices` is ignored.
        config: CollectiveConfig,
    },
}
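
// Example (a minimal sketch, not part of this module's API): selecting a strategy.
// `MyComponents` is a hypothetical type implementing `LearnerComponentTypes`, and
// `multi_device_optim` stands in for whichever `MultiDeviceOptim` option is wanted;
// only the enum variants themselves are defined above.
//
//     // Train on a single device.
//     let strategy = LearningStrategy::<MyComponents>::SingleDevice(device.clone());
//
//     // Data-parallel training on two devices, optimizing on an elected master device.
//     let strategy = LearningStrategy::<MyComponents>::MultiDevice(
//         vec![device_0.clone(), device_1.clone()],
//         multi_device_optim,
//     );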

/// Constructor for a distributed data parallel (DDP) learning strategy.
#[cfg(feature = "ddp")]
pub fn ddp<B: AutodiffBackend, LC: LearnerComponentTypes>(
    devices: Vec<LearnerDevice<LC>>,
    config: CollectiveConfig,
) -> LearningStrategy<LC> {
    LearningStrategy::DistributedDataParallel { devices, config }
}
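
// Example (sketch only): building the DDP strategy with this constructor. The backend,
// the component types, and `CollectiveConfig::default()` are assumptions made for
// illustration; consult `burn_collective::CollectiveConfig` for the actual options.
//
//     #[cfg(feature = "ddp")]
//     let strategy = ddp::<MyAutodiffBackend, MyComponents>(
//         vec![device_0, device_1],
//         CollectiveConfig::default(),
//     );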

impl<LC: LearnerComponentTypes> Default for LearningStrategy<LC> {
    fn default() -> Self {
        Self::SingleDevice(Default::default())
    }
}

/// Provides the `fit` function for any learning strategy.
pub trait LearningMethod<LC: LearnerComponentTypes> {
    /// The dataloaders after being prepared for this training strategy
    /// (e.g. split across multiple devices).
    type PreparedDataloaders;

    /// The model after being prepared for this training strategy.
    ///
    /// The prepared model will be correctly initialized on the proper device for training.
    type PreparedModel;

    /// Fit the learner's model with this strategy.
    fn fit(
        &self,
        mut learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> TrainingResult<LC::InnerModel> {
        let mut model = learner.model;
        let mut optim = learner.optim;
        let mut lr_scheduler = learner.lr_scheduler;
        let checkpoint = learner.checkpoint;

        let starting_epoch = match checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut learner.checkpointer {
                    (model, optim, lr_scheduler) = checkpointer.load_checkpoint(
                        model,
                        optim,
                        lr_scheduler,
                        &Default::default(), // Load the checkpoint on the default device.
                        checkpoint,
                    );
                }
                checkpoint + 1
            }
            None => 1,
        };
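
        // Let the strategy prepare the dataloaders (e.g. split them per device)
        // and place the model on the device(s) it will train on.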
        let dataloaders = self.prepare_dataloaders(dataloader_train, dataloader_valid);
        let model = self.prepare_model(model);

        // Gather the components used by the training loop.
        let mut components = LearnerComponents {
            optim,
            lr_scheduler,
            num_epochs: learner.num_epochs,
            checkpointer: learner.checkpointer,
            grad_accumulation: learner.grad_accumulation,
            interrupter: learner.interrupter,
            early_stopping: learner.early_stopping,
            event_processor: learner.event_processor,
            event_store: learner.event_store,
        };

        // Signal the event processor that training is starting.
        components
            .event_processor
            .process_train(LearnerEvent::Start);
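
        // Delegate the actual epoch loop to the concrete strategy implementation.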
        let (model, mut event_processor) =
            self.learn(model, dataloaders, starting_epoch, components);

        let summary = learner.summary.and_then(|summary| {
            summary
                .init()
                .map(|summary| summary.with_model(model.to_string()))
                .ok()
        });

        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.
        event_processor.process_train(LearnerEvent::End(summary));
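
        // Return the inner (non-autodiff) model along with the renderer.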
        let model = model.valid();
        let renderer = event_processor.renderer();

        TrainingResult::<LC::InnerModel> { model, renderer }
    }

    /// Prepare the dataloaders for this strategy.
    /// The output will be used in [the learn function](Self::learn).
    fn prepare_dataloaders(
        &self,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> Self::PreparedDataloaders;

    /// Prepare the model for this training strategy.
    /// The output will be used in [the learn function](Self::learn).
    fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel;

    /// Training loop for this strategy.
    fn learn(
        &self,
        model: Self::PreparedModel,
        dataloaders: Self::PreparedDataloaders,
        starting_epoch: usize,
        components: LearnerComponents<LC>,
    ) -> (LC::Model, LC::EventProcessor);
}
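
// Example (skeleton only): the shape of a custom strategy. The struct name is
// hypothetical, and `Module::fork` is an assumption about how the model gets moved
// to the target device; only the trait signatures themselves come from this file.
//
//     struct MyCustomSingleDevice<LC: LearnerComponentTypes> {
//         device: LearnerDevice<LC>,
//     }
//
//     impl<LC: LearnerComponentTypes> LearningMethod<LC> for MyCustomSingleDevice<LC> {
//         type PreparedDataloaders = (TrainLoader<LC>, ValidLoader<LC>);
//         type PreparedModel = LC::Model;
//
//         fn prepare_dataloaders(
//             &self,
//             dataloader_train: TrainLoader<LC>,
//             dataloader_valid: ValidLoader<LC>,
//         ) -> Self::PreparedDataloaders {
//             // A single device needs no splitting.
//             (dataloader_train, dataloader_valid)
//         }
//
//         fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel {
//             model.fork(&self.device)
//         }
//
//         fn learn(
//             &self,
//             model: Self::PreparedModel,
//             dataloaders: Self::PreparedDataloaders,
//             starting_epoch: usize,
//             components: LearnerComponents<LC>,
//         ) -> (LC::Model, LC::EventProcessor) {
//             // Run the epoch loop here, then return the trained model and event processor.
//             todo!()
//         }
//     }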

/// Groups the components passed to [LearningMethod::learn], keeping its parameter list small.
/// These components are used during training.
pub struct LearnerComponents<LC: LearnerComponentTypes> {
    /// The [Optimizer](LearnerComponentTypes::Optimizer) used for the training.
    pub optim: LC::Optimizer,
    /// The [learning rate scheduler](LearnerComponentTypes::LrScheduler) used for the training.
    pub lr_scheduler: LC::LrScheduler,
    /// The number of epochs the training should last.
    pub num_epochs: usize,
    /// Enables gradient accumulation.
    pub grad_accumulation: Option<usize>,
    /// A [LearnerCheckpointer](LearnerCheckpointer) used to save and load training checkpoints.
    pub checkpointer: Option<LearnerCheckpointer<LC>>,
    /// An [Interrupter](Interrupter) that allows aborting the training/evaluation process early.
    pub interrupter: Interrupter,
    /// A [cloneable reference to an early stopping strategy](EarlyStoppingStrategyRef).
    pub early_stopping: Option<EarlyStoppingStrategyRef>,
    /// An [EventProcessor](LearnerComponentTypes::EventProcessor) that processes events happening during training and validation.
    pub event_processor: LC::EventProcessor,
    /// A reference to an [EventStoreClient](EventStoreClient).
    pub event_store: Arc<EventStoreClient>,
}