burn_train/learner/epoch.rs

use burn_core::data::dataloader::DataLoader;
use burn_core::tensor::backend::AutodiffBackend;
use burn_core::{
    lr_scheduler::LrScheduler, module::AutodiffModule, optim::GradientsAccumulator,
    tensor::backend::Backend,
};
use std::sync::Arc;

use crate::metric::processor::{Event, EventProcessor, LearnerItem};
use crate::{MultiDevicesTrainStep, TrainStep, ValidStep};
use crate::{components::LearnerComponents, learner::base::TrainingInterrupter};

/// A validation epoch.
#[derive(new)]
pub struct ValidEpoch<B: Backend, VI> {
    dataloader: Arc<dyn DataLoader<B, VI>>,
    epoch: usize,
    epoch_total: usize,
}

/// A training epoch.
#[derive(new)]
pub struct TrainEpoch<B: AutodiffBackend, TI> {
    dataloader: Vec<Arc<dyn DataLoader<B, TI>>>,
    epoch: usize,
    epoch_total: usize,
    grad_accumulation: Option<usize>,
}
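
// A minimal sketch of how these epochs are typically constructed; the `new`
// constructors are generated by `#[derive(new)]`. The dataloader bindings and
// the accumulation value below are illustrative assumptions, not part of this
// module:
//
//     let train = TrainEpoch::new(vec![dataloader_train], 1, num_epochs, Some(4));
//     let valid = ValidEpoch::new(dataloader_valid, 1, num_epochs);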

impl<B: Backend, VI> ValidEpoch<B, VI> {
    /// Runs the validation epoch.
    ///
    /// # Arguments
    ///
    /// * `model` - The model to validate.
    /// * `processor` - The event processor to use.
    /// * `interrupter` - The handle checked after each step for early interruption.
    pub fn run<LC: LearnerComponents, VO>(
        &self,
        model: &LC::Model,
        processor: &mut LC::EventProcessor,
        interrupter: &TrainingInterrupter,
    ) where
        LC::EventProcessor: EventProcessor<ItemValid = VO>,
        <LC::Model as AutodiffModule<LC::Backend>>::InnerModule: ValidStep<VI, VO>,
        LC::Backend: AutodiffBackend<InnerBackend = B>,
    {
        log::info!("Executing validation step for epoch {}", self.epoch);
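        // `valid()` maps the autodiff model to its inner module on the inner
        // backend, so no computation graph is recorded during validation.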
        let model = model.valid();

        let mut iterator = self.dataloader.iter();
        let mut iteration = 0;

        while let Some(item) = iterator.next() {
            let progress = iterator.progress();
            iteration += 1;

            let item = model.step(item);
            let item = LearnerItem::new(
                item,
                progress,
                self.epoch,
                self.epoch_total,
                iteration,
                None,
            );

            processor.process_valid(Event::ProcessedItem(item));

            if interrupter.should_stop() {
                log::info!("Training interrupted.");
                break;
            }
        }
        processor.process_valid(Event::EndEpoch(self.epoch));
    }
}

impl<B: AutodiffBackend, TI> TrainEpoch<B, TI> {
    /// Runs the training epoch.
    ///
    /// # Arguments
    ///
    /// * `model` - The model to train.
    /// * `optim` - The optimizer to use.
    /// * `scheduler` - The learning rate scheduler to use.
    /// * `processor` - The event processor to use.
    /// * `interrupter` - The handle checked after each step for early interruption.
    ///
    /// # Returns
    ///
    /// The trained model and the optimizer.
    pub fn run<LC: LearnerComponents<Backend = B>, TO>(
        &mut self,
        mut model: LC::Model,
        mut optim: LC::Optimizer,
        scheduler: &mut LC::LrScheduler,
        processor: &mut LC::EventProcessor,
        interrupter: &TrainingInterrupter,
    ) -> (LC::Model, LC::Optimizer)
    where
        LC::EventProcessor: EventProcessor<ItemTrain = TO>,
        LC::Model: TrainStep<TI, TO>,
    {
        log::info!("Executing training step for epoch {}", self.epoch);

        // Single device / dataloader
        let mut iterator = self.dataloader[0].iter();
        let mut iteration = 0;
        let mut accumulator = GradientsAccumulator::new();
        let mut accumulation_current = 0;

        while let Some(item) = iterator.next() {
            iteration += 1;
            let lr = scheduler.step();
            log::info!("Iteration {iteration}");

            let progress = iterator.progress();
            let item = model.step(item);

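            // With `grad_accumulation = Some(n)`, gradients from `n` consecutive
            // steps are summed before a single optimizer update; otherwise the
            // model is updated on every step.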
            match self.grad_accumulation {
                Some(accumulation) => {
                    accumulator.accumulate(&model, item.grads);
                    accumulation_current += 1;

                    if accumulation_current >= accumulation {
                        let grads = accumulator.grads();
                        model = model.optimize(&mut optim, lr, grads);
                        accumulation_current = 0;
                    }
                }
                None => model = model.optimize(&mut optim, lr, item.grads),
            }

            let item = LearnerItem::new(
                item.item,
                progress,
                self.epoch,
                self.epoch_total,
                iteration,
                Some(lr),
            );

            processor.process_train(Event::ProcessedItem(item));

            if interrupter.should_stop() {
                log::info!("Training interrupted.");
                break;
            }
        }
        processor.process_train(Event::EndEpoch(self.epoch));

        self.epoch += 1;

        (model, optim)
    }
}

impl<B: AutodiffBackend, TI> TrainEpoch<B, TI> {
    /// Runs the training epoch on multiple devices.
    ///
    /// # Arguments
    ///
    /// * `model` - The model to train.
    /// * `optim` - The optimizer to use.
    /// * `lr_scheduler` - The learning rate scheduler to use.
    /// * `processor` - The event processor to use.
    /// * `devices` - The devices to use.
    /// * `interrupter` - The handle checked after each step for early interruption.
    ///
    /// # Returns
    ///
    /// The trained model and the optimizer.
    pub fn run_multi_device<LC: LearnerComponents<Backend = B>, TO>(
        &mut self,
        mut model: LC::Model,
        mut optim: LC::Optimizer,
        lr_scheduler: &mut LC::LrScheduler,
        processor: &mut LC::EventProcessor,
        devices: Vec<<LC::Backend as Backend>::Device>,
        interrupter: &TrainingInterrupter,
    ) -> (LC::Model, LC::Optimizer)
    where
        LC::EventProcessor: EventProcessor<ItemTrain = TO>,
        LC::Model: TrainStep<TI, TO>,
        TO: Send + 'static,
        TI: Send + 'static,
    {
        log::info!(
            "Executing training step for epoch {} on devices {:?}",
            self.epoch,
            devices
        );

        let mut iterators = self.dataloader.iter().map(|d| d.iter()).collect::<Vec<_>>();
        let mut iteration = 0;
        let mut accumulator = GradientsAccumulator::new();
        let mut accumulation_current = 0;

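        // Each pass over the devices yields one gradient set per device, so the
        // effective accumulation window scales with the device count; even with
        // no explicit accumulation, the per-device gradients are folded into a
        // single update.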
        let accumulation = self.grad_accumulation.unwrap_or(1) * devices.len();
        let step = MultiDevicesTrainStep::new(&devices);

        // The main device is always the first in the list.
        let device_main = devices.first().expect("There should be at least one device.").clone();
        let mut interrupted = false;

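        // Each call to `step` runs one batch per device in parallel (for every
        // dataloader that still has a batch) and returns the per-device outputs;
        // an empty result means all dataloaders are exhausted.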
        loop {
            let (items, progress) = step.step(iterators.as_mut_slice(), &model);
            if items.is_empty() {
                break;
            }

            for item in items {
                iteration += 1;
                let lr = lr_scheduler.step();

                // TODO: aggregate multi device (all-reduce)
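                // Until then, gradients are moved onto the main device and
                // summed through the accumulator below.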
                let grads = item.grads.to_device(&device_main, &model);

                accumulator.accumulate(&model, grads);
                accumulation_current += 1;

                if accumulation_current >= accumulation {
                    let grads = accumulator.grads();
                    model = model.optimize(&mut optim, lr, grads);
                    accumulation_current = 0;
                }

                let item = LearnerItem::new(
                    item.item,
                    progress.clone(),
                    self.epoch,
                    self.epoch_total,
                    iteration,
                    Some(lr),
                );

                processor.process_train(Event::ProcessedItem(item));

                if interrupter.should_stop() {
                    log::info!("Training interrupted.");
                    interrupted = true;
                    break;
                }
            }

            if interrupted {
                break;
            }
        }

        processor.process_train(Event::EndEpoch(self.epoch));

        self.epoch += 1;

        (model, optim)
    }
}