// burn_train/learner/train_val.rs
use crate::components::LearnerComponents;
use crate::metric::processor::{Event, EventProcessor};
use crate::{Learner, TrainEpoch, ValidEpoch};
use burn_core::data::dataloader::DataLoader;
use burn_core::module::{AutodiffModule, Module};
use burn_core::optim::{GradientsParams, Optimizer};
use burn_core::tensor::backend::AutodiffBackend;
use std::sync::Arc;

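/// The output of a single training step.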
pub struct TrainOutput<TO> {
    /// The gradients computed during the step.
    pub grads: GradientsParams,

    /// The item produced by the step, typically holding the loss and the model output.
    pub item: TO,
}

impl<TO> TrainOutput<TO> {
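    /// Creates a new training output, converting the raw backend gradients into
    /// [GradientsParams] tied to the given module's parameters.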
    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
        module: &M,
        grads: B::Gradients,
        item: TO,
    ) -> Self {
        let grads = GradientsParams::from_grads(grads, module);
        Self { grads, item }
    }
}

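/// Trait to be implemented for models that can be trained with the [Learner].
///
/// A minimal sketch of a typical implementation (the `MyModel` and `MyBatch` types and the
/// `forward_classification` helper are illustrative assumptions, not part of this file):
///
/// ```ignore
/// impl<B: AutodiffBackend> TrainStep<MyBatch<B>, ClassificationOutput<B>> for MyModel<B> {
///     fn step(&self, batch: MyBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
///         // Forward pass producing the loss alongside the model output.
///         let output = self.forward_classification(batch);
///         // The backward pass computes the gradients, which `TrainOutput::new` links
///         // back to this module's parameters.
///         TrainOutput::new(self, output.loss.backward(), output)
///     }
/// }
/// ```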
pub trait TrainStep<TI, TO> {
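    /// Runs a single training step on one item, returning the computed gradients together
    /// with the step's output.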
    fn step(&self, item: TI) -> TrainOutput<TO>;
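
    /// Optimizes the model with the given gradients and learning rate, returning the
    /// updated model. Override this default implementation to customize how updates are
    /// applied (e.g., to transform the gradients before the optimizer step).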
    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
    where
        B: AutodiffBackend,
        O: Optimizer<Self, B>,
        Self: AutodiffModule<B>,
    {
        optim.step(lr, self, grads)
    }
}

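/// Trait to be implemented for models that can be validated.
///
/// A sketch mirroring the `TrainStep` example above (`MyModel`, `MyBatch`, and
/// `forward_classification` remain illustrative assumptions):
///
/// ```ignore
/// impl<B: Backend> ValidStep<MyBatch<B>, ClassificationOutput<B>> for MyModel<B> {
///     fn step(&self, batch: MyBatch<B>) -> ClassificationOutput<B> {
///         // Forward pass only: validation runs on the inner module, so no gradients
///         // are computed.
///         self.forward_classification(batch)
///     }
/// }
/// ```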
pub trait ValidStep<VI, VO> {
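    /// Runs a single validation step on one item, producing the output used to update the
    /// validation metrics.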
    fn step(&self, item: VI) -> VO;
}

impl<LC: LearnerComponents> Learner<LC> {
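    /// Fits the model.
    ///
    /// For each epoch, runs a full pass over `dataloader_train` (on one or several devices),
    /// then a validation pass over `dataloader_valid`, checkpointing and checking for early
    /// stopping along the way. Returns the trained model.
    ///
    /// A minimal usage sketch (assuming a `learner` built via `LearnerBuilder` and two
    /// prepared dataloaders):
    ///
    /// ```ignore
    /// let model_trained = learner.fit(dataloader_train, dataloader_valid);
    /// ```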
    pub fn fit<InputTrain, InputValid, OutputTrain, OutputValid>(
        mut self,
        dataloader_train: Arc<dyn DataLoader<InputTrain>>,
        dataloader_valid: Arc<dyn DataLoader<InputValid>>,
    ) -> LC::Model
    where
        InputTrain: Send + 'static,
        InputValid: Send,
        OutputTrain: Send + 'static,
        OutputValid: Send,
        LC::Model: TrainStep<InputTrain, OutputTrain>,
        <LC::Model as AutodiffModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
        LC::EventProcessor: EventProcessor<ItemTrain = OutputTrain, ItemValid = OutputValid>,
    {
        log::info!("Fitting the model:\n {}", self.model);
        if let Some(device) = self.devices.first() {
            // Move the model to the first configured device before training starts.
            self.model = self.model.fork(device);
        }

        // Resume from the requested checkpoint when one is set; otherwise start at epoch 1.
        let starting_epoch = match self.checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut self.checkpointer {
                    (self.model, self.optim, self.lr_scheduler) = checkpointer.load_checkpoint(
                        self.model,
                        self.optim,
                        self.lr_scheduler,
                        &Default::default(),
                        checkpoint,
                    );
                }
                checkpoint + 1
            }
            None => 1,
        };

        for epoch in starting_epoch..self.num_epochs + 1 {
            let epoch_train = TrainEpoch::new(
                dataloader_train.clone(),
                epoch,
                self.num_epochs,
                self.grad_accumulation,
            );

            // Run the training epoch, splitting the work across devices when more than
            // one is configured.
            if self.devices.len() > 1 {
                (self.model, self.optim) = epoch_train.run_multi_device::<LC, OutputTrain>(
                    self.model,
                    self.optim,
                    &mut self.lr_scheduler,
                    &mut self.event_processor,
                    self.devices.clone(),
                    &self.interrupter,
                )
            } else {
                (self.model, self.optim) = epoch_train.run::<LC, OutputTrain>(
                    self.model,
                    self.optim,
                    &mut self.lr_scheduler,
                    &mut self.event_processor,
                    &self.interrupter,
                );
            }

            if self.interrupter.should_stop() {
                break;
            }

            // Validate on the inner (non-autodiff) module after each training epoch.
            let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
            epoch_valid.run::<LC, OutputValid>(
                &self.model,
                &mut self.event_processor,
                &self.interrupter,
            );

            if let Some(checkpointer) = &mut self.checkpointer {
                checkpointer.checkpoint(
                    &self.model,
                    &self.optim,
                    &self.lr_scheduler,
                    epoch,
                    &self.event_store,
                );
            }

            if let Some(early_stopping) = &mut self.early_stopping {
                if early_stopping.should_stop(epoch, &self.event_store) {
                    break;
                }
            }
        }

        self.event_processor.process_train(Event::End);

        // Display the learner summary, if one was configured.
        if let Some(summary) = self.summary {
            match summary.init() {
                Ok(summary) => {
                    println!("{}", summary.with_model(self.model.to_string()))
                }
                Err(err) => log::error!("Could not retrieve learner summary:\n{err}"),
            }
        }

        self.model
    }
}