burn_train/learner/strategies/multi/

use crate::{
    LearnerComponents, LearningMethod, TrainLoader, ValidLoader, components::LearnerComponentTypes,
    learner::strategies::single::epoch::SingleDeviceValidEpoch,
    multi::epoch::MultiDeviceTrainEpoch,
};
use burn_core::{data::dataloader::split::split_dataloader, module::Module, prelude::Backend};
use std::{marker::PhantomData, sync::Arc};

/// Determines how optimization is performed when training with multiple devices.
#[derive(Clone, Copy, Debug)]
pub enum MultiDeviceOptim {
    /// The optimization is done on an elected device.
    OptimMainDevice,
    /// The optimization is sharded across all devices.
    OptimSharded,
}

/// Learning strategy that trains on multiple devices, with one worker per device.
pub struct MultiDeviceLearningStrategy<LC: LearnerComponentTypes> {
    devices: Vec<<LC::Backend as Backend>::Device>,
    optim: MultiDeviceOptim,
    _p: PhantomData<LC>,
}

impl<LC: LearnerComponentTypes> MultiDeviceLearningStrategy<LC> {
    /// Creates a multi-device strategy for the given devices and optimization mode.
    pub fn new(devices: Vec<<LC::Backend as Backend>::Device>, optim: MultiDeviceOptim) -> Self {
        Self {
            devices,
            optim,
            _p: PhantomData,
        }
    }
}
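
// A minimal construction sketch. `MyComponents`, `device_0`, and `device_1` are
// hypothetical stand-ins for a concrete `LearnerComponentTypes` implementation
// and its backend devices; they are not defined in this crate:
//
//     let strategy = MultiDeviceLearningStrategy::<MyComponents>::new(
//         vec![device_0, device_1],
//         MultiDeviceOptim::OptimSharded,
//     );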

/// A user-provided learning strategy, type-erased behind the [`LearningMethod`]
/// trait with the same prepared dataloader and model types as
/// [`MultiDeviceLearningStrategy`].
pub type CustomMultiDeviceLearningStrategy<LC> = Arc<
    dyn LearningMethod<
            LC,
            PreparedDataloaders = (Vec<TrainLoader<LC>>, ValidLoader<LC>),
            PreparedModel = <LC as LearnerComponentTypes>::Model,
        >,
>;
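
// Any strategy with matching associated types can be stored behind the alias,
// e.g. (a sketch, reusing the hypothetical `MyComponents` from above):
//
//     let custom: CustomMultiDeviceLearningStrategy<MyComponents> =
//         Arc::new(MultiDeviceLearningStrategy::new(devices, MultiDeviceOptim::OptimMainDevice));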

impl<LC: LearnerComponentTypes> LearningMethod<LC> for MultiDeviceLearningStrategy<LC> {
    type PreparedDataloaders = (Vec<TrainLoader<LC>>, ValidLoader<LC>);

    type PreparedModel = LC::Model;

    fn prepare_dataloaders(
        &self,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> Self::PreparedDataloaders {
        // `MultiDevicesTrainStep` has one worker per device, so we use a fixed device strategy
        // for each (worker) data loader. This matches the expected device on the worker, so we
        // don't have to move the data between devices.
        let train = split_dataloader(dataloader_train, &self.devices);
        let main_device = self
            .devices
            .first()
            .expect("There should be at least one device.");
        let valid = dataloader_valid.to_device(main_device);

        (train, valid)
    }
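
    // Illustrative split semantics (a sketch of the resulting shape, not the
    // `split_dataloader` implementation): with `devices == [d0, d1]`, the train
    // loader is split into one partial loader per device, each pinned to its
    // device so batches land where the corresponding worker consumes them:
    //
    //     train == vec![partial_loader_on_d0, partial_loader_on_d1]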

    fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel {
        // Initialize the model on the elected main (first) device.
        let main_device = self
            .devices
            .first()
            .expect("There should be at least one device.");
        model.fork(main_device)
    }

    fn learn(
        &self,
        mut model: LC::Model,
        (dataloader_train, dataloader_valid): Self::PreparedDataloaders,
        starting_epoch: usize,
        mut components: LearnerComponents<LC>,
    ) -> (LC::Model, LC::EventProcessor) {
        let mut epoch_train = MultiDeviceTrainEpoch::<LC>::new(
            dataloader_train,
            starting_epoch,
            components.num_epochs,
            components.grad_accumulation,
        );

        for epoch in starting_epoch..=components.num_epochs {
            (model, components.optim) = epoch_train.run(
                model,
                components.optim,
                &mut components.lr_scheduler,
                &mut components.event_processor,
                self.devices.clone(),
                &components.interrupter,
                self.optim,
            );

            if components.interrupter.should_stop() {
                break;
            }

            let epoch_valid = SingleDeviceValidEpoch::<LC>::new(
                dataloader_valid.clone(),
                epoch,
                components.num_epochs,
            );
            epoch_valid.run(
                &model,
                &mut components.event_processor,
                &components.interrupter,
            );

            if let Some(checkpointer) = &mut components.checkpointer {
                checkpointer.checkpoint(
                    &model,
                    &components.optim,
                    &components.lr_scheduler,
                    epoch,
                    &components.event_store,
                );
            }

            if let Some(early_stopping) = &mut components.early_stopping
                && early_stopping.should_stop(epoch, &components.event_store)
            {
                break;
            }
        }

        (model, components.event_processor)
    }
}
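
// Sketch of a custom strategy that could back `CustomMultiDeviceLearningStrategy`.
// The skeleton is hypothetical (only the trait shape is taken from this file);
// the method bodies are elided:
//
//     struct MyCustomStrategy<LC: LearnerComponentTypes> {
//         devices: Vec<<LC::Backend as Backend>::Device>,
//     }
//
//     impl<LC: LearnerComponentTypes> LearningMethod<LC> for MyCustomStrategy<LC> {
//         type PreparedDataloaders = (Vec<TrainLoader<LC>>, ValidLoader<LC>);
//         type PreparedModel = LC::Model;
//
//         // prepare_dataloaders, prepare_model, and learn as above.
//     }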