burn_train/learner/supervised/strategies/

use std::sync::Arc;

#[cfg(feature = "ddp")]
use burn_core::tensor::backend::distributed::{DistributedBackend, DistributedConfig};
use burn_core::{module::AutodiffModule, prelude::Backend};

use crate::{
    EarlyStoppingStrategyRef, InferenceModel, Interrupter, Learner, LearnerSummaryConfig,
    LearningCheckpointer, LearningResult, SupervisedTrainingEventProcessor, TrainLoader,
    TrainingModel, ValidLoader,
    components::LearningComponentsTypes,
    metric::{
        processor::{EventProcessorTraining, LearnerEvent},
        store::EventStoreClient,
    },
};

/// A reference to an implementation of [`SupervisedLearningStrategy`].
pub type CustomLearningStrategy<LC> = Arc<dyn SupervisedLearningStrategy<LC>>;

#[derive(Clone, Copy, Debug)]
/// Determines how the optimization is performed when training with multiple devices.
pub enum MultiDeviceOptim {
    /// The optimization is done on an elected main device.
    OptimMainDevice,
    /// The optimization is sharded across all devices.
    OptimSharded,
}

/// Describes where training runs.
pub enum ExecutionStrategy<B: Backend> {
    /// Training on a single device.
    SingleDevice(B::Device),
    /// Data-parallel training across multiple devices; the [`MultiDeviceOptim`]
    /// value determines whether the optimization runs on an elected main device
    /// or is sharded across all devices.
    MultiDevice(Vec<B::Device>, MultiDeviceOptim),
    /// Training with input distributed across devices, where each device holds its own
    /// copy of the model. Collective ops are used to sync the gradients after each pass.
    #[cfg(feature = "ddp")]
    DistributedDataParallel {
        /// Devices on this node used for DDP.
        devices: Vec<B::Device>,
        /// The distributed runtime.
        runtime: Box<dyn DistributedRuntime>,
    },
}

impl<B: Backend> ExecutionStrategy<B> {
    /// Returns the primary device responsible for coordination.
    pub fn main_device(&self) -> &B::Device {
        match self {
            ExecutionStrategy::SingleDevice(device) => device,
            ExecutionStrategy::MultiDevice(devices, _optim) => &devices[0],
            #[cfg(feature = "ddp")]
            ExecutionStrategy::DistributedDataParallel {
                devices,
                runtime: _,
            } => &devices[0],
        }
    }

    /// Creates a strategy for a single device.
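    ///
    /// A minimal sketch, assuming a backend type `MyBackend` whose default
    /// device is the intended target:
    ///
    /// ```ignore
    /// let strategy = ExecutionStrategy::<MyBackend>::single(Default::default());
    /// ```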
    pub fn single(device: B::Device) -> Self {
        Self::SingleDevice(device)
    }

    /// Creates a multi-device strategy.
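    ///
    /// A hedged sketch, assuming `MyBackend` and two hypothetical devices; here
    /// the optimization step is kept on the elected main device:
    ///
    /// ```ignore
    /// let strategy = ExecutionStrategy::<MyBackend>::multi(
    ///     vec![device_0, device_1],
    ///     MultiDeviceOptim::OptimMainDevice,
    /// );
    /// ```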
    pub fn multi(devices: Vec<B::Device>, optim: MultiDeviceOptim) -> Self {
        Self::MultiDevice(devices, optim)
    }
}

#[cfg(feature = "ddp")]
impl<B: DistributedBackend> ExecutionStrategy<B> {
    /// Creates a distributed data parallel (DDP) strategy.
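    ///
    /// A hedged sketch, assuming a distributed backend `MyBackend` and two
    /// hypothetical devices; `config` stands in for a pre-built [`DistributedConfig`]:
    ///
    /// ```ignore
    /// let strategy = ExecutionStrategy::<MyBackend>::ddp(vec![device_0, device_1], config);
    /// ```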
    pub fn ddp(devices: Vec<B::Device>, config: DistributedConfig) -> Self {
        let session = DistributedSession::<B> {
            devices: devices.clone(),
            config,
        };
        Self::DistributedDataParallel {
            devices,
            runtime: Box::new(session),
        }
    }
}

/// Determines how the learner runs training for the model.
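///
/// An [`ExecutionStrategy`] converts into the [`TrainingStrategy::Default`] variant
/// via `From`. A minimal sketch, assuming component types `MyComponents` and a
/// hypothetical `MyStrategy` implementing [`SupervisedLearningStrategy`]:
///
/// ```ignore
/// // Default loop on a single device:
/// let default_strategy: TrainingStrategy<MyComponents> =
///     ExecutionStrategy::single(Default::default()).into();
///
/// // Fully custom loop:
/// let custom_strategy = TrainingStrategy::Custom(Arc::new(MyStrategy));
/// ```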
pub enum TrainingStrategy<LC: LearningComponentsTypes> {
    /// Default training loop with the specified device strategy.
    Default(ExecutionStrategy<LC::Backend>),
    /// Training using a custom learning strategy.
    Custom(CustomLearningStrategy<LC>),
}

impl<LC: LearningComponentsTypes> From<ExecutionStrategy<LC::Backend>> for TrainingStrategy<LC> {
    fn from(value: ExecutionStrategy<LC::Backend>) -> Self {
        Self::Default(value)
    }
}

#[cfg(feature = "ddp")]
/// Manages the orchestration of a distributed training environment.
///
/// This trait provides a generic interface to initialize and finalize
/// the communication infrastructure required for cross-device synchronization.
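///
/// A hedged sketch of a custom implementation (the no-op runtime is an
/// assumption for illustration, not part of the crate):
///
/// ```ignore
/// struct NoopRuntime;
///
/// impl DistributedRuntime for NoopRuntime {
///     fn start(&self) { /* set up the communication infrastructure */ }
///     fn close(&self) { /* tear it down */ }
/// }
/// ```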
pub trait DistributedRuntime: Send + Sync + 'static {
    /// Initializes the distributed environment.
    fn start(&self);

    /// Cleans up the distributed environment.
    fn close(&self);
}

#[cfg(feature = "ddp")]
/// A concrete implementation of [`DistributedRuntime`] for a [distributed backend](DistributedBackend).
///
/// It encapsulates the necessary configuration and device information to
/// manage the resources related to a [`DistributedBackend`].
pub struct DistributedSession<B: DistributedBackend> {
    devices: Vec<B::Device>,
    config: DistributedConfig,
}

#[cfg(feature = "ddp")]
impl<B: DistributedBackend> DistributedRuntime for DistributedSession<B> {
    fn start(&self) {
        B::start_communication_server(&self.devices, self.config.clone());
    }

    fn close(&self) {
        B::close_communication_server(&self.devices[0]);
    }
}

impl<LC: LearningComponentsTypes> Default for TrainingStrategy<LC> {
    fn default() -> Self {
        Self::Default(ExecutionStrategy::SingleDevice(Default::default()))
    }
}

/// Groups the components used during training, minimising the number of
/// parameters passed to [`SupervisedLearningStrategy::train`].
pub struct TrainingComponents<LC: LearningComponentsTypes> {
    /// The total number of epochs.
    pub num_epochs: usize,
    /// The epoch number from which to continue the training.
    pub checkpoint: Option<usize>,
    /// A checkpointer used to load and save learner checkpoints.
    pub checkpointer: Option<LearningCheckpointer<LC>>,
    /// Enables gradient accumulation over the given number of steps.
    pub grad_accumulation: Option<usize>,
    /// An [Interrupter](Interrupter) that allows aborting the training/evaluation process early.
    pub interrupter: Interrupter,
    /// A cloneable reference to an early stopping strategy.
    pub early_stopping: Option<EarlyStoppingStrategyRef>,
    /// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and validation.
    pub event_processor: SupervisedTrainingEventProcessor<LC>,
    /// A reference to an [EventStoreClient](EventStoreClient).
    pub event_store: Arc<EventStoreClient>,
    /// Config for creating a summary of the training run.
    pub summary: Option<LearnerSummaryConfig>,
}

/// Defines a supervised learning strategy: a default `train` entry point and the
/// strategy-specific `fit` training loop.
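///
/// Implementors only need to provide `fit`; `train` wraps it with checkpoint
/// loading, event signaling, and summary generation. A hedged sketch (the
/// struct and loop body are assumptions for illustration):
///
/// ```ignore
/// struct MyStrategy;
///
/// impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC> for MyStrategy {
///     fn fit(
///         &self,
///         training_components: TrainingComponents<LC>,
///         learner: Learner<LC>,
///         dataloader_train: TrainLoader<LC>,
///         dataloader_valid: ValidLoader<LC>,
///         starting_epoch: usize,
///     ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {
///         // Run epochs `starting_epoch..=training_components.num_epochs`,
///         // then return the trained model and the event processor.
///         todo!()
///     }
/// }
/// ```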
pub trait SupervisedLearningStrategy<LC: LearningComponentsTypes> {
    /// Trains the learner's model with this strategy.
    fn train(
        &self,
        mut learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        mut training_components: TrainingComponents<LC>,
    ) -> LearningResult<InferenceModel<LC>> {
        // Resume from the checkpoint if one is provided; otherwise start at epoch 1.
        let starting_epoch = match training_components.checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut training_components.checkpointer {
                    learner =
                        checkpointer.load_checkpoint(learner, &Default::default(), checkpoint);
                }
                checkpoint + 1
            }
            None => 1,
        };

        let summary_config = training_components.summary.clone();

        // Signal training start to the event processor.
        training_components
            .event_processor
            .process_train(LearnerEvent::Start);
        // Run the strategy-specific training loop.
        let (model, mut event_processor) = self.fit(
            training_components,
            learner,
            dataloader_train,
            dataloader_valid,
            starting_epoch,
        );

        // Build the training summary, if configured.
        let summary = summary_config.and_then(|summary| {
            summary
                .init()
                .map(|summary| summary.with_model(model.to_string()))
                .ok()
        });

        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.
        event_processor.process_train(LearnerEvent::End(summary));

        // Convert the trained (autodiff) model into its inference counterpart.
        let model = model.valid();
        let renderer = event_processor.renderer();

        LearningResult::<InferenceModel<LC>> { model, renderer }
    }

    /// The training loop for this strategy.
    fn fit(
        &self,
        training_components: TrainingComponents<LC>,
        learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        starting_epoch: usize,
    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>);
}