burn_train/learner/supervised/strategies/base.rs

use std::sync::Arc;

#[cfg(feature = "ddp")]
use burn_collective::CollectiveConfig;
use burn_core::{module::AutodiffModule, prelude::Backend};

use crate::{
    EarlyStoppingStrategyRef, InferenceModel, Interrupter, Learner, LearnerSummaryConfig,
    LearningCheckpointer, LearningResult, SupervisedTrainingEventProcessor, TrainLoader,
    TrainingModel, ValidLoader,
    components::LearningComponentsTypes,
    metric::{
        processor::{EventProcessorTraining, LearnerEvent},
        store::EventStoreClient,
    },
};

type LearnerDevice<LC> = <<LC as LearningComponentsTypes>::Backend as Backend>::Device;

/// A reference to an implementation of [SupervisedLearningStrategy].
pub type CustomLearningStrategy<LC> = Arc<dyn SupervisedLearningStrategy<LC>>;

/// Determines how optimization is performed when training with multiple devices.
#[derive(Clone, Copy, Debug)]
pub enum MultiDeviceOptim {
    /// The optimization is done on an elected device.
    OptimMainDevice,
    /// The optimization is sharded across all devices.
    OptimSharded,
}

/// Determines how the learner runs training for the model.
#[derive(Clone)]
pub enum TrainingStrategy<LC: LearningComponentsTypes> {
    /// Training on a single device.
    SingleDevice(LearnerDevice<LC>),
    /// Data-parallel training across multiple devices; the [MultiDeviceOptim] setting
    /// determines whether the optimization runs on an elected main device or is
    /// sharded across all devices.
    MultiDevice(Vec<LearnerDevice<LC>>, MultiDeviceOptim),
    /// Training using a custom learning strategy.
    Custom(CustomLearningStrategy<LC>),
    /// Training with the input distributed across devices, each device holding its own
    /// copy of the model. Collective ops are used to sync the gradients after each pass.
    #[cfg(feature = "ddp")]
    DistributedDataParallel {
        /// Devices on this node for the DDP.
        devices: Vec<LearnerDevice<LC>>,

        /// The configuration for collective operations.
        /// `num_devices` is ignored.
        config: CollectiveConfig,
    },
}
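// A minimal usage sketch (illustration only): `MyComponents` is a hypothetical
// `LearningComponentsTypes` implementation, and `device_0`/`device_1` are devices
// of its backend.
//
//     // Train on the backend's default device:
//     let strategy = TrainingStrategy::<MyComponents>::default();
//
//     // Data-parallel training on two devices, optimizing on an elected main device:
//     let strategy = TrainingStrategy::<MyComponents>::MultiDevice(
//         vec![device_0, device_1],
//         MultiDeviceOptim::OptimMainDevice,
//     );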

/// Constructor for a distributed data parallel (DDP) learning strategy.
#[cfg(feature = "ddp")]
pub fn ddp<LC: LearningComponentsTypes>(
    devices: Vec<LearnerDevice<LC>>,
    config: CollectiveConfig,
) -> TrainingStrategy<LC> {
    TrainingStrategy::DistributedDataParallel { devices, config }
}
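// Hedged usage sketch: with the `ddp` feature enabled, a DDP strategy can be built
// from the per-node devices and a `CollectiveConfig` (`MyComponents`, `devices`, and
// `collective_config` are assumed to be defined by the caller):
//
//     let strategy = ddp::<MyComponents>(devices, collective_config);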

impl<LC: LearningComponentsTypes> Default for TrainingStrategy<LC> {
    fn default() -> Self {
        Self::SingleDevice(Default::default())
    }
}

/// Groups the components passed to [SupervisedLearningStrategy::train], keeping its
/// parameter list small. These components are used during training.
pub struct TrainingComponents<LC: LearningComponentsTypes> {
    /// The total number of epochs.
    pub num_epochs: usize,
    /// The epoch number from which to continue the training.
    pub checkpoint: Option<usize>,
    /// A checkpointer used to load and save learner checkpoints.
    pub checkpointer: Option<LearningCheckpointer<LC>>,
    /// Enables gradient accumulation.
    pub grad_accumulation: Option<usize>,
    /// An [Interrupter](Interrupter) that allows aborting the training/evaluation process early.
    pub interrupter: Interrupter,
    /// Cloneable reference to an early stopping strategy.
    pub early_stopping: Option<EarlyStoppingStrategyRef>,
    /// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and validation.
    pub event_processor: SupervisedTrainingEventProcessor<LC>,
    /// A reference to an [EventStoreClient](EventStoreClient).
    pub event_store: Arc<EventStoreClient>,
    /// Config for creating a summary of the learning.
    pub summary: Option<LearnerSummaryConfig>,
}

/// Provides the `train` entry point for any learning strategy; implementors supply
/// the strategy-specific training loop via `fit`.
pub trait SupervisedLearningStrategy<LC: LearningComponentsTypes> {
    /// Train the learner's model with this strategy.
    fn train(
        &self,
        mut learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        mut training_components: TrainingComponents<LC>,
    ) -> LearningResult<InferenceModel<LC>> {
        // Resume from the epoch after the checkpoint, or start from epoch 1.
        let starting_epoch = match training_components.checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut training_components.checkpointer {
                    learner =
                        checkpointer.load_checkpoint(learner, &Default::default(), checkpoint);
                }
                checkpoint + 1
            }
            None => 1,
        };

        let summary_config = training_components.summary.clone();

        // Signal training start to the event processor.
        training_components
            .event_processor
            .process_train(LearnerEvent::Start);

        // Run the strategy-specific training loop.
        let (model, mut event_processor) = self.fit(
            training_components,
            learner,
            dataloader_train,
            dataloader_valid,
            starting_epoch,
        );

        // Build the learning summary, if configured.
        let summary = summary_config.and_then(|summary| {
            summary
                .init()
                .map(|summary| summary.with_model(model.to_string()))
                .ok()
        });

        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.
        event_processor.process_train(LearnerEvent::End(summary));

        // Convert the autodiff model into its inference counterpart.
        let model = model.valid();
        let renderer = event_processor.renderer();

        LearningResult::<InferenceModel<LC>> { model, renderer }
    }

    /// Training loop for this strategy.
    fn fit(
        &self,
        training_components: TrainingComponents<LC>,
        learner: Learner<LC>,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
        starting_epoch: usize,
    ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>);
}
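
// A minimal sketch of a custom strategy (`MyStrategy` and its loop are hypothetical).
// Only `fit` needs to be implemented; the provided `train` wraps it with checkpoint
// restoration, start/end events, and summary creation.
//
//     struct MyStrategy;
//
//     impl<LC: LearningComponentsTypes> SupervisedLearningStrategy<LC> for MyStrategy {
//         fn fit(
//             &self,
//             mut training_components: TrainingComponents<LC>,
//             learner: Learner<LC>,
//             dataloader_train: TrainLoader<LC>,
//             dataloader_valid: ValidLoader<LC>,
//             starting_epoch: usize,
//         ) -> (TrainingModel<LC>, SupervisedTrainingEventProcessor<LC>) {
//             // Run epochs from `starting_epoch`, feeding metrics to
//             // `training_components.event_processor` and checking
//             // `training_components.interrupter` between steps, then return the
//             // trained model together with the event processor.
//             todo!()
//         }
//     }
//
//     // Such a strategy is selected with `TrainingStrategy::Custom(Arc::new(MyStrategy))`.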