use burn_core::data::dataloader::DataLoader;
use burn_core::tensor::backend::AutodiffBackend;
use burn_core::{
    lr_scheduler::LrScheduler, module::AutodiffModule, optim::GradientsAccumulator,
    tensor::backend::Backend,
};
use std::sync::Arc;

use crate::metric::processor::{Event, EventProcessor, LearnerItem};
use crate::{MultiDevicesTrainStep, TrainStep, ValidStep};
use crate::{components::LearnerComponents, learner::base::TrainingInterrupter};

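/// A validation epoch over a dataloader, running the model in inference
/// (non-autodiff) mode.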
#[derive(new)]
pub struct ValidEpoch<B: Backend, VI> {
    dataloader: Arc<dyn DataLoader<B, VI>>,
    epoch: usize,
    epoch_total: usize,
}

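/// A training epoch, holding one dataloader per device along with an optional
/// gradient accumulation factor.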
#[derive(new)]
pub struct TrainEpoch<B: AutodiffBackend, TI> {
    dataloader: Vec<Arc<dyn DataLoader<B, TI>>>,
    epoch: usize,
    epoch_total: usize,
    grad_accumulation: Option<usize>,
}

impl<B: Backend, VI> ValidEpoch<B, VI> {
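    /// Runs the validation epoch.
    ///
    /// The autodiff model is converted to its inner (inference) module, each
    /// batch is forwarded to the event processor as a `ProcessedItem` event,
    /// and an `EndEpoch` event is emitted once the dataloader is exhausted.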
    pub fn run<LC: LearnerComponents, VO>(
        &self,
        model: &LC::Model,
        processor: &mut LC::EventProcessor,
        interrupter: &TrainingInterrupter,
    ) where
        LC::EventProcessor: EventProcessor<ItemValid = VO>,
        <LC::Model as AutodiffModule<LC::Backend>>::InnerModule: ValidStep<VI, VO>,
        LC::Backend: AutodiffBackend<InnerBackend = B>,
    {
        log::info!("Executing validation step for epoch {}", self.epoch);
        let model = model.valid();

        let mut iterator = self.dataloader.iter();
        let mut iteration = 0;

        while let Some(item) = iterator.next() {
            let progress = iterator.progress();
            iteration += 1;

            let item = model.step(item);
            let item = LearnerItem::new(
                item,
                progress,
                self.epoch,
                self.epoch_total,
                iteration,
                None,
            );

            processor.process_valid(Event::ProcessedItem(item));

            if interrupter.should_stop() {
                log::info!("Training interrupted.");
                break;
            }
        }
        processor.process_valid(Event::EndEpoch(self.epoch));
    }
}

impl<B: AutodiffBackend, TI> TrainEpoch<B, TI> {
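    /// Runs the training epoch on a single device and returns the updated
    /// model and optimizer. With `grad_accumulation: Some(n)`, gradients are
    /// accumulated across `n` iterations before each optimizer step.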
    pub fn run<LC: LearnerComponents<Backend = B>, TO>(
        &mut self,
        mut model: LC::Model,
        mut optim: LC::Optimizer,
        scheduler: &mut LC::LrScheduler,
        processor: &mut LC::EventProcessor,
        interrupter: &TrainingInterrupter,
    ) -> (LC::Model, LC::Optimizer)
    where
        LC::EventProcessor: EventProcessor<ItemTrain = TO>,
        LC::Model: TrainStep<TI, TO>,
    {
        log::info!("Executing training step for epoch {}", self.epoch);

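        // Single-device training steps through the first dataloader only.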
        let mut iterator = self.dataloader[0].iter();
        let mut iteration = 0;
        let mut accumulator = GradientsAccumulator::new();
        let mut accumulation_current = 0;

        while let Some(item) = iterator.next() {
            iteration += 1;
            let lr = scheduler.step();
            log::info!("Iteration {iteration}");

            let progress = iterator.progress();
            let item = model.step(item);

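            // Either accumulate gradients across iterations before applying
            // them, or apply them immediately when no accumulation is set.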
            match self.grad_accumulation {
                Some(accumulation) => {
                    accumulator.accumulate(&model, item.grads);
                    accumulation_current += 1;

                    if accumulation <= accumulation_current {
                        let grads = accumulator.grads();
                        model = model.optimize(&mut optim, lr, grads);
                        accumulation_current = 0;
                    }
                }
                None => model = model.optimize(&mut optim, lr, item.grads),
            }

            let item = LearnerItem::new(
                item.item,
                progress,
                self.epoch,
                self.epoch_total,
                iteration,
                Some(lr),
            );

            processor.process_train(Event::ProcessedItem(item));

            if interrupter.should_stop() {
                log::info!("Training interrupted.");
                break;
            }
        }
        processor.process_train(Event::EndEpoch(self.epoch));

        self.epoch += 1;

        (model, optim)
    }
}

impl<B: AutodiffBackend, TI> TrainEpoch<B, TI> {
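    /// Runs the training epoch across multiple devices and returns the
    /// updated model and optimizer. Each device draws batches from its own
    /// dataloader; gradients are moved back to the main device and
    /// accumulated before every optimizer step.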
    pub fn run_multi_device<LC: LearnerComponents<Backend = B>, TO>(
        &mut self,
        mut model: LC::Model,
        mut optim: LC::Optimizer,
        lr_scheduler: &mut LC::LrScheduler,
        processor: &mut LC::EventProcessor,
        devices: Vec<<LC::Backend as Backend>::Device>,
        interrupter: &TrainingInterrupter,
    ) -> (LC::Model, LC::Optimizer)
    where
        LC::EventProcessor: EventProcessor<ItemTrain = TO>,
        LC::Model: TrainStep<TI, TO>,
        TO: Send + 'static,
        TI: Send + 'static,
    {
        log::info!(
            "Executing training step for epoch {} on devices {:?}",
            self.epoch,
            devices
        );

        let mut iterators = self.dataloader.iter().map(|d| d.iter()).collect::<Vec<_>>();
        let mut iteration = 0;
        let mut accumulator = GradientsAccumulator::new();
        let mut accumulation_current = 0;

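        // Every device contributes one gradient per step, so the effective
        // accumulation factor is scaled by the number of devices: e.g. with
        // `grad_accumulation = Some(4)` and 2 devices, the optimizer steps
        // after every 8 accumulated gradients.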
        let accumulation = self.grad_accumulation.unwrap_or(1) * devices.len();
        let step = MultiDevicesTrainStep::new(&devices);

        let device_main = devices
            .first()
            .expect("At least one device is required.")
            .clone();
        let mut interrupted = false;

        loop {
            let (items, progress) = step.step(iterators.as_mut_slice(), &model);
            if items.is_empty() {
                break;
            }

            for item in items {
                iteration += 1;
                let lr = lr_scheduler.step();

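                // Gradients were computed on a worker device; move them to
                // the main device before accumulating.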
                let grads = item.grads.to_device(&device_main, &model);

                accumulator.accumulate(&model, grads);
                accumulation_current += 1;

                if accumulation <= accumulation_current {
                    let grads = accumulator.grads();
                    model = model.optimize(&mut optim, lr, grads);
                    accumulation_current = 0;
                }

                let item = LearnerItem::new(
                    item.item,
                    progress.clone(),
                    self.epoch,
                    self.epoch_total,
                    iteration,
                    Some(lr),
                );

                processor.process_train(Event::ProcessedItem(item));

                if interrupter.should_stop() {
                    log::info!("Training interrupted.");
                    interrupted = true;
                    break;
                }
            }

            if interrupted {
                break;
            }
        }

        processor.process_train(Event::EndEpoch(self.epoch));

        self.epoch += 1;

        (model, optim)
    }
}