burn_train/learner/strategies/multi/
method.rs1use crate::{
2 LearnerComponents, LearningMethod, TrainLoader, ValidLoader, components::LearnerComponentTypes,
3 learner::strategies::single::epoch::SingleDeviceValidEpoch,
4 multi::epoch::MultiDeviceTrainEpoch,
5};
6use burn_core::{data::dataloader::split::split_dataloader, module::Module, prelude::Backend};
7use std::{marker::PhantomData, sync::Arc};
8
/// Selects how the optimizer step is performed when training on multiple devices.
///
/// The value is forwarded unchanged to the multi-device training epoch on every epoch.
///
/// NOTE(review): the per-variant descriptions below are inferred from the variant names
/// only — confirm against the training epoch implementation that consumes this enum.
#[derive(Debug, Clone, Copy)]
pub enum MultiDeviceOptim {
    /// Presumably runs the optimizer on the main device only.
    OptimMainDevice,
    /// Presumably shards the optimizer work across the devices.
    OptimSharded,
}
17
18pub struct MultiDeviceLearningStrategy<LC: LearnerComponentTypes> {
19 devices: Vec<<LC::Backend as Backend>::Device>,
20 optim: MultiDeviceOptim,
21 _p: PhantomData<LC>,
22}
23impl<LC: LearnerComponentTypes> MultiDeviceLearningStrategy<LC> {
24 pub fn new(devices: Vec<<LC::Backend as Backend>::Device>, optim: MultiDeviceOptim) -> Self {
25 Self {
26 devices,
27 optim,
28 _p: PhantomData,
29 }
30 }
31}
32
33pub type CustomMultiDeviceLearningStrategy<LC> = Arc<
34 dyn LearningMethod<
35 LC,
36 PreparedDataloaders = (Vec<TrainLoader<LC>>, ValidLoader<LC>),
37 PreparedModel = <LC as LearnerComponentTypes>::Model,
38 >,
39>;
40
41impl<LC: LearnerComponentTypes> LearningMethod<LC> for MultiDeviceLearningStrategy<LC> {
42 type PreparedDataloaders = (Vec<TrainLoader<LC>>, ValidLoader<LC>);
43
44 type PreparedModel = LC::Model;
45
46 fn prepare_dataloaders(
47 &self,
48 dataloader_train: TrainLoader<LC>,
49 dataloader_valid: ValidLoader<LC>,
50 ) -> Self::PreparedDataloaders {
51 let train = split_dataloader(dataloader_train, &self.devices);
55 let main_device = self.devices.first().unwrap();
56 let valid = dataloader_valid.to_device(main_device);
57
58 (train, valid)
59 }
60
61 fn prepare_model(&self, model: LC::Model) -> Self::PreparedModel {
62 let main_device = self.devices.first().unwrap();
63 model.fork(main_device)
64 }
65
66 fn learn(
67 &self,
68 mut model: LC::Model,
69 (dataloader_train, dataloader_valid): Self::PreparedDataloaders,
70 starting_epoch: usize,
71 mut components: LearnerComponents<LC>,
72 ) -> (LC::Model, LC::EventProcessor) {
73 let mut epoch_train = MultiDeviceTrainEpoch::<LC>::new(
74 dataloader_train,
75 starting_epoch,
76 components.num_epochs,
77 components.grad_accumulation,
78 );
79
80 for epoch in starting_epoch..components.num_epochs + 1 {
81 (model, components.optim) = epoch_train.run(
82 model,
83 components.optim,
84 &mut components.lr_scheduler,
85 &mut components.event_processor,
86 self.devices.to_vec(),
87 &components.interrupter,
88 self.optim,
89 );
90
91 if components.interrupter.should_stop() {
92 break;
93 }
94
95 let epoch_valid = SingleDeviceValidEpoch::<LC>::new(
96 dataloader_valid.clone(),
97 epoch,
98 components.num_epochs,
99 );
100 epoch_valid.run(
101 &model,
102 &mut components.event_processor,
103 &components.interrupter,
104 );
105
106 if let Some(checkpointer) = &mut components.checkpointer {
107 checkpointer.checkpoint(
108 &model,
109 &components.optim,
110 &components.lr_scheduler,
111 epoch,
112 &components.event_store,
113 );
114 }
115
116 if let Some(early_stopping) = &mut components.early_stopping
117 && early_stopping.should_stop(epoch, &components.event_store)
118 {
119 break;
120 }
121 }
122
123 (model, components.event_processor)
124 }
125}