use burn_core::{
    data::dataloader::DataLoader, lr_scheduler::LrScheduler, module::AutodiffModule,
    optim::GradientsAccumulator, tensor::backend::Backend,
};
use std::sync::Arc;

use crate::metric::processor::{Event, EventProcessor, LearnerItem};
use crate::{components::LearnerComponents, learner::base::TrainingInterrupter};
use crate::{MultiDevicesTrainStep, TrainStep, ValidStep};

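/// A validation epoch over a dataloader.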
#[derive(new)]
pub struct ValidEpoch<VI> {
    dataloader: Arc<dyn DataLoader<VI>>,
    epoch: usize,
    epoch_total: usize,
}

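/// A training epoch over a dataloader, with optional gradient accumulation.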
#[derive(new)]
pub struct TrainEpoch<TI> {
    dataloader: Arc<dyn DataLoader<TI>>,
    epoch: usize,
    epoch_total: usize,
    grad_accumulation: Option<usize>,
}

impl<VI> ValidEpoch<VI> {
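    /// Runs the validation epoch on the given model.
    ///
    /// Each batch from the dataloader is passed through the inner (non-autodiff) module's
    /// `ValidStep::step`, and the resulting items are forwarded to the event processor.
    /// The loop stops early if the interrupter is triggered.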
    pub fn run<LC: LearnerComponents, VO>(
        &self,
        model: &LC::Model,
        processor: &mut LC::EventProcessor,
        interrupter: &TrainingInterrupter,
    ) where
        LC::EventProcessor: EventProcessor<ItemValid = VO>,
        <LC::Model as AutodiffModule<LC::Backend>>::InnerModule: ValidStep<VI, VO>,
    {
        log::info!("Executing validation step for epoch {}", self.epoch);

        // Validation runs on the inner module, without autodiff.
        let model = model.valid();

        let mut iterator = self.dataloader.iter();
        let mut iteration = 0;

        while let Some(item) = iterator.next() {
            let progress = iterator.progress();
            iteration += 1;

            let item = model.step(item);
            // No learning rate is associated with validation items.
            let item = LearnerItem::new(
                item,
                progress,
                self.epoch,
                self.epoch_total,
                iteration,
                None,
            );

            processor.process_valid(Event::ProcessedItem(item));

            if interrupter.should_stop() {
                log::info!("Training interrupted.");
                break;
            }
        }
        processor.process_valid(Event::EndEpoch(self.epoch));
    }
}

impl<TI> TrainEpoch<TI> {
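    /// Runs the training epoch on a single device.
    ///
    /// Each batch is passed through `TrainStep::step`, gradients are optionally accumulated
    /// over `grad_accumulation` steps, and the model is updated with the optimizer using the
    /// learning rate produced by the scheduler. Returns the updated model and optimizer.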
    pub fn run<LC: LearnerComponents, TO>(
        &self,
        mut model: LC::Model,
        mut optim: LC::Optimizer,
        scheduler: &mut LC::LrScheduler,
        processor: &mut LC::EventProcessor,
        interrupter: &TrainingInterrupter,
    ) -> (LC::Model, LC::Optimizer)
    where
        LC::EventProcessor: EventProcessor<ItemTrain = TO>,
        LC::Model: TrainStep<TI, TO>,
    {
        log::info!("Executing training step for epoch {}", self.epoch);

        let mut iterator = self.dataloader.iter();
        let mut iteration = 0;
        let mut accumulator = GradientsAccumulator::new();
        let mut accumulation_current = 0;

        while let Some(item) = iterator.next() {
            iteration += 1;
            let lr = scheduler.step();
            log::info!("Iteration {}", iteration);

            let progress = iterator.progress();
            let item = model.step(item);

            match self.grad_accumulation {
                Some(accumulation) => {
                    // Accumulate gradients and only update the model once
                    // `accumulation` steps have been collected.
                    accumulator.accumulate(&model, item.grads);
                    accumulation_current += 1;

                    if accumulation <= accumulation_current {
                        let grads = accumulator.grads();
                        model = model.optimize(&mut optim, lr, grads);
                        accumulation_current = 0;
                    }
                }
                None => model = model.optimize(&mut optim, lr, item.grads),
            }

            let item = LearnerItem::new(
                item.item,
                progress,
                self.epoch,
                self.epoch_total,
                iteration,
                Some(lr),
            );

            processor.process_train(Event::ProcessedItem(item));

            if interrupter.should_stop() {
                log::info!("Training interrupted.");
                break;
            }
        }
        processor.process_train(Event::EndEpoch(self.epoch));

        (model, optim)
    }
}

impl<TI> TrainEpoch<TI> {
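    /// Runs the training epoch on multiple devices.
    ///
    /// Batches are dispatched to the devices through `MultiDevicesTrainStep`, the resulting
    /// gradients are moved back to the main (first) device and accumulated there, and the
    /// model is updated once enough gradients have been collected. Returns the updated model
    /// and optimizer.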
    pub fn run_multi_device<LC: LearnerComponents, TO>(
        &self,
        mut model: LC::Model,
        mut optim: LC::Optimizer,
        lr_scheduler: &mut LC::LrScheduler,
        processor: &mut LC::EventProcessor,
        devices: Vec<<LC::Backend as Backend>::Device>,
        interrupter: &TrainingInterrupter,
    ) -> (LC::Model, LC::Optimizer)
    where
        LC::EventProcessor: EventProcessor<ItemTrain = TO>,
        LC::Model: TrainStep<TI, TO>,
        TO: Send + 'static,
        TI: Send + 'static,
    {
        log::info!(
            "Executing training step for epoch {} on devices {:?}",
            self.epoch,
            devices
        );

        let mut iterator = self.dataloader.iter();
        let mut iteration = 0;
        let mut accumulator = GradientsAccumulator::new();
        let mut accumulation_current = 0;

        // Gradients are always accumulated across devices, so the effective
        // accumulation is scaled by the number of devices.
        let accumulation = self.grad_accumulation.unwrap_or(1) * devices.len();
        let step = MultiDevicesTrainStep::new(&devices);

        // Gradients are gathered and the optimizer step is applied on the main device.
        let device_main = devices.first().expect("A minimum of one device.").clone();
        let mut interrupted = false;

        loop {
            let items = step.step(&mut iterator, &model);
            if items.is_empty() {
                break;
            }

            for item in items {
                iteration += 1;
                let lr = lr_scheduler.step();
                let progress = iterator.progress();

                // Move the gradients computed on each device back to the main device.
                let grads = item.grads.to_device(&device_main, &model);

                accumulator.accumulate(&model, grads);
                accumulation_current += 1;

                if accumulation <= accumulation_current {
                    let grads = accumulator.grads();
                    model = model.optimize(&mut optim, lr, grads);
                    accumulation_current = 0;
                }

                let item = LearnerItem::new(
                    item.item,
                    progress,
                    self.epoch,
                    self.epoch_total,
                    iteration,
                    Some(lr),
                );

                processor.process_train(Event::ProcessedItem(item));

                if interrupter.should_stop() {
                    log::info!("Training interrupted.");
                    interrupted = true;
                    break;
                }
            }

            if interrupted {
                break;
            }
        }

        processor.process_train(Event::EndEpoch(self.epoch));

        (model, optim)
    }
}