burn_train/learner/epoch.rs

1use burn_core::{
2    data::dataloader::DataLoader, lr_scheduler::LrScheduler, module::AutodiffModule,
3    optim::GradientsAccumulator, tensor::backend::Backend,
4};
5use std::sync::Arc;
6
7use crate::metric::processor::{Event, EventProcessor, LearnerItem};
8use crate::{components::LearnerComponents, learner::base::TrainingInterrupter};
9use crate::{MultiDevicesTrainStep, TrainStep, ValidStep};
10
11/// A validation epoch.
12#[derive(new)]
13pub struct ValidEpoch<VI> {
14    dataloader: Arc<dyn DataLoader<VI>>,
15    epoch: usize,
16    epoch_total: usize,
17}
18
19/// A training epoch.
20#[derive(new)]
21pub struct TrainEpoch<TI> {
22    dataloader: Arc<dyn DataLoader<TI>>,
23    epoch: usize,
24    epoch_total: usize,
25    grad_accumulation: Option<usize>,
26}
27
28impl<VI> ValidEpoch<VI> {
29    /// Runs the validation epoch.
30    ///
31    /// # Arguments
32    ///
33    /// * `model` - The model to validate.
34    /// * `processor` - The event processor to use.
35    pub fn run<LC: LearnerComponents, VO>(
36        &self,
37        model: &LC::Model,
38        processor: &mut LC::EventProcessor,
39        interrupter: &TrainingInterrupter,
40    ) where
41        LC::EventProcessor: EventProcessor<ItemValid = VO>,
42        <LC::Model as AutodiffModule<LC::Backend>>::InnerModule: ValidStep<VI, VO>,
43    {
44        log::info!("Executing validation step for epoch {}", self.epoch);
45        let model = model.valid();
46
47        let mut iterator = self.dataloader.iter();
48        let mut iteration = 0;
49
50        while let Some(item) = iterator.next() {
51            let progress = iterator.progress();
52            iteration += 1;
53
54            let item = model.step(item);
55            let item = LearnerItem::new(
56                item,
57                progress,
58                self.epoch,
59                self.epoch_total,
60                iteration,
61                None,
62            );
63
64            processor.process_valid(Event::ProcessedItem(item));
65
66            if interrupter.should_stop() {
67                log::info!("Training interrupted.");
68                break;
69            }
70        }
71        processor.process_valid(Event::EndEpoch(self.epoch));
72    }
73}
74
75impl<TI> TrainEpoch<TI> {
76    /// Runs the training epoch.
77    ///
78    /// # Arguments
79    ///
80    /// * `model` - The model to train.
81    /// * `optim` - The optimizer to use.
82    /// * `scheduler` - The learning rate scheduler to use.
83    /// * `processor` - The event processor to use.
84    ///
85    /// # Returns
86    ///
87    /// The trained model and the optimizer.
88    pub fn run<LC: LearnerComponents, TO>(
89        &self,
90        mut model: LC::Model,
91        mut optim: LC::Optimizer,
92        scheduler: &mut LC::LrScheduler,
93        processor: &mut LC::EventProcessor,
94        interrupter: &TrainingInterrupter,
95    ) -> (LC::Model, LC::Optimizer)
96    where
97        LC::EventProcessor: EventProcessor<ItemTrain = TO>,
98        LC::Model: TrainStep<TI, TO>,
99    {
100        log::info!("Executing training step for epoch {}", self.epoch,);
101
102        let mut iterator = self.dataloader.iter();
103        let mut iteration = 0;
104        let mut accumulator = GradientsAccumulator::new();
105        let mut accumulation_current = 0;
106
107        while let Some(item) = iterator.next() {
108            iteration += 1;
109            let lr = scheduler.step();
110            log::info!("Iteration {}", iteration);
111
112            let progress = iterator.progress();
113            let item = model.step(item);
114
115            match self.grad_accumulation {
116                Some(accumulation) => {
117                    accumulator.accumulate(&model, item.grads);
118                    accumulation_current += 1;
119
120                    if accumulation <= accumulation_current {
121                        let grads = accumulator.grads();
122                        model = model.optimize(&mut optim, lr, grads);
123                        accumulation_current = 0;
124                    }
125                }
126                None => model = model.optimize(&mut optim, lr, item.grads),
127            }
128
129            let item = LearnerItem::new(
130                item.item,
131                progress,
132                self.epoch,
133                self.epoch_total,
134                iteration,
135                Some(lr),
136            );
137
138            processor.process_train(Event::ProcessedItem(item));
139
140            if interrupter.should_stop() {
141                log::info!("Training interrupted.");
142                break;
143            }
144        }
145        processor.process_train(Event::EndEpoch(self.epoch));
146
147        (model, optim)
148    }
149}
150
151impl<TI> TrainEpoch<TI> {
152    /// Runs the training epoch on multiple devices.
153    ///
154    /// # Arguments
155    ///
156    /// * `model` - The model to train.
157    /// * `optim` - The optimizer to use.
158    /// * `lr_scheduler` - The learning rate scheduler to use.
159    /// * `processor` - The event processor to use.
160    /// * `devices` - The devices to use.
161    ///
162    /// # Returns
163    ///
164    /// The trained model and the optimizer.
165    pub fn run_multi_device<LC: LearnerComponents, TO>(
166        &self,
167        mut model: LC::Model,
168        mut optim: LC::Optimizer,
169        lr_scheduler: &mut LC::LrScheduler,
170        processor: &mut LC::EventProcessor,
171        devices: Vec<<LC::Backend as Backend>::Device>,
172        interrupter: &TrainingInterrupter,
173    ) -> (LC::Model, LC::Optimizer)
174    where
175        LC::EventProcessor: EventProcessor<ItemTrain = TO>,
176        LC::Model: TrainStep<TI, TO>,
177        TO: Send + 'static,
178        TI: Send + 'static,
179    {
180        log::info!(
181            "Executing training step for epoch {} on devices {:?}",
182            self.epoch,
183            devices
184        );
185
186        let mut iterator = self.dataloader.iter();
187        let mut iteration = 0;
188        let mut accumulator = GradientsAccumulator::new();
189        let mut accumulation_current = 0;
190
191        let accumulation = self.grad_accumulation.unwrap_or(1) * devices.len();
192        let step = MultiDevicesTrainStep::new(&devices);
193
194        // The main device is always the first in the list.
195        let device_main = devices.first().expect("A minimum of one device.").clone();
196        let mut interrupted = false;
197
198        loop {
199            let items = step.step(&mut iterator, &model);
200            if items.is_empty() {
201                break;
202            }
203
204            for item in items {
205                iteration += 1;
206                let lr = lr_scheduler.step();
207                let progress = iterator.progress();
208
209                let grads = item.grads.to_device(&device_main, &model);
210
211                accumulator.accumulate(&model, grads);
212                accumulation_current += 1;
213
214                if accumulation <= accumulation_current {
215                    let grads = accumulator.grads();
216                    model = model.optimize(&mut optim, lr, grads);
217                    accumulation_current = 0;
218                }
219
220                let item = LearnerItem::new(
221                    item.item,
222                    progress,
223                    self.epoch,
224                    self.epoch_total,
225                    iteration,
226                    Some(lr),
227                );
228
229                processor.process_train(Event::ProcessedItem(item));
230
231                if interrupter.should_stop() {
232                    log::info!("Training interrupted.");
233                    interrupted = true;
234                    break;
235                }
236            }
237
238            if interrupted {
239                break;
240            }
241        }
242
243        processor.process_train(Event::EndEpoch(self.epoch));
244
245        (model, optim)
246    }
247}