burn_train/learner/train_val.rs
use crate::components::{LearnerComponents, TrainBackend, ValidBackend};
use crate::metric::processor::{Event, EventProcessor};
use crate::{Learner, TrainEpoch, ValidEpoch};
use burn_core::data::dataloader::DataLoader;
use burn_core::data::dataloader::split::split_dataloader;
use burn_core::module::{AutodiffModule, Module};
use burn_core::optim::{GradientsParams, Optimizer};
use burn_core::tensor::backend::AutodiffBackend;
use std::sync::Arc;

/// A training output.
pub struct TrainOutput<TO> {
    /// The gradients.
    pub grads: GradientsParams,

    /// The item.
    pub item: TO,
}

impl<TO> TrainOutput<TO> {
    /// Creates a new training output.
    ///
    /// # Arguments
    ///
    /// * `module` - The module from which the gradients were computed.
    /// * `grads` - The raw gradients returned by the backward pass.
    /// * `item` - The training output item, typically the model output and loss.
    ///
    /// # Returns
    ///
    /// A new training output.
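    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a model whose loss supports a backward pass
    /// (`output` stands in for your model's output type):
    ///
    /// ```ignore
    /// let grads = output.loss.backward();
    /// let train_output = TrainOutput::new(&model, grads, output);
    /// ```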
    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
        module: &M,
        grads: B::Gradients,
        item: TO,
    ) -> Self {
        let grads = GradientsParams::from_grads(grads, module);
        Self { grads, item }
    }
}

/// Trait to be implemented for training models.
///
/// The [step](TrainStep::step) method must be implemented manually for each struct.
///
/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
/// optimizer is used to update the model. This can be useful if you want to call custom mutable
/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
///
/// # Notes
///
/// To be used with the [Learner] struct, the implementing struct must also implement the
/// [AutodiffModule] trait, which the [Module](burn_core::module::Module) derive provides
/// automatically.
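///
/// # Example
///
/// A sketch of a typical implementation, assuming a model with a
/// `forward_classification` helper; `MyModel`, `MyBatch`, and
/// `ClassificationOutput` stand in for your own types:
///
/// ```ignore
/// impl<B: AutodiffBackend> TrainStep<MyBatch<B>, ClassificationOutput<B>> for MyModel<B> {
///     fn step(&self, batch: MyBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
///         let output = self.forward_classification(batch);
///         // The backward pass computes the gradients consumed by `optimize`.
///         TrainOutput::new(self, output.loss.backward(), output)
///     }
/// }
/// ```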
pub trait TrainStep<TI, TO> {
    /// Runs the training step, which executes the forward and backward passes.
    ///
    /// # Arguments
    ///
    /// * `item` - The training input for the model.
    ///
    /// # Returns
    ///
    /// The training output containing the model output and the gradients.
    fn step(&self, item: TI) -> TrainOutput<TO>;

    /// Optimize the current module with the provided gradients and learning rate.
    ///
    /// # Arguments
    ///
    /// * `optim` - The optimizer used for training this model.
    /// * `lr` - The learning rate used for this step.
    /// * `grads` - The gradients of each parameter in the current model.
    ///
    /// # Returns
    ///
    /// The updated model.
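    ///
    /// # Example
    ///
    /// A sketch of a custom override that mutates the model after the optimizer
    /// step; `clamp_weights` is a hypothetical method on your model:
    ///
    /// ```ignore
    /// fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
    /// where
    ///     B: AutodiffBackend,
    ///     O: Optimizer<Self, B>,
    ///     Self: AutodiffModule<B>,
    /// {
    ///     // Apply the regular update, then clip the weights.
    ///     optim.step(lr, self, grads).clamp_weights()
    /// }
    /// ```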
    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
    where
        B: AutodiffBackend,
        O: Optimizer<Self, B>,
        Self: AutodiffModule<B>,
    {
        optim.step(lr, self, grads)
    }
}

/// Trait to be implemented for validating models.
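///
/// # Example
///
/// A sketch of a typical implementation; validation runs on the inner
/// (non-autodiff) module, and `MyModel`, `MyBatch`, and `ClassificationOutput`
/// stand in for your own types:
///
/// ```ignore
/// impl<B: Backend> ValidStep<MyBatch<B>, ClassificationOutput<B>> for MyModel<B> {
///     fn step(&self, batch: MyBatch<B>) -> ClassificationOutput<B> {
///         // No backward pass: validation only needs the forward output.
///         self.forward_classification(batch)
///     }
/// }
/// ```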
pub trait ValidStep<VI, VO> {
    /// Runs a validation step.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to validate on.
    ///
    /// # Returns
    ///
    /// The validation output.
    fn step(&self, item: VI) -> VO;
}

impl<LC: LearnerComponents> Learner<LC> {
    /// Fits the model.
    ///
    /// # Arguments
    ///
    /// * `dataloader_train` - The training dataloader.
    /// * `dataloader_valid` - The validation dataloader.
    ///
    /// # Returns
    ///
    /// The fitted model.
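    ///
    /// # Example
    ///
    /// A sketch of a typical call site, assuming a learner already configured
    /// through a builder (the dataloaders and the builder configuration are
    /// application-specific):
    ///
    /// ```ignore
    /// // Consumes the learner and returns the trained model.
    /// let model_trained = learner.fit(dataloader_train, dataloader_valid);
    /// ```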
    pub fn fit<InputTrain, InputValid, OutputTrain, OutputValid>(
        mut self,
        mut dataloader_train: Arc<dyn DataLoader<TrainBackend<LC>, InputTrain>>,
        mut dataloader_valid: Arc<dyn DataLoader<ValidBackend<LC>, InputValid>>,
    ) -> LC::Model
    where
        InputTrain: Send + 'static,
        InputValid: Send,
        OutputTrain: Send + 'static,
        OutputValid: Send,
        LC::Model: TrainStep<InputTrain, OutputTrain>,
        <LC::Model as AutodiffModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
        LC::EventProcessor: EventProcessor<ItemTrain = OutputTrain, ItemValid = OutputValid>,
    {
        log::info!("Fitting the model:\n {}", self.model);

        // The reference model is always on the first device provided.
        if let Some(device) = self.devices.first() {
            self.model = self.model.fork(device);
            dataloader_train = dataloader_train.to_device(device);
            dataloader_valid = dataloader_valid.to_device(device);
        }

        let starting_epoch = match self.checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut self.checkpointer {
                    (self.model, self.optim, self.lr_scheduler) = checkpointer.load_checkpoint(
                        self.model,
                        self.optim,
                        self.lr_scheduler,
                        &Default::default(), // Load the checkpoint on the default device.
                        checkpoint,
                    );
                }
                checkpoint + 1
            }
            None => 1,
        };

        // `MultiDevicesTrainStep` has one worker per device, so we use a fixed device strategy
        // for each (worker) data loader. This matches the expected device on the worker, so we
        // don't have to move the data between devices.
        let dataloaders_train = split_dataloader(dataloader_train, &self.devices);
        // The train epoch keeps its dataloaders so they are reused across epochs.
        let mut epoch_train = TrainEpoch::new(
            dataloaders_train,
            starting_epoch,
            self.num_epochs,
            self.grad_accumulation,
        );

        for epoch in starting_epoch..=self.num_epochs {
            if self.devices.len() > 1 {
                (self.model, self.optim) = epoch_train.run_multi_device::<LC, OutputTrain>(
                    self.model,
                    self.optim,
                    &mut self.lr_scheduler,
                    &mut self.event_processor,
                    self.devices.clone(),
                    &self.interrupter,
                )
            } else {
                (self.model, self.optim) = epoch_train.run::<LC, OutputTrain>(
                    self.model,
                    self.optim,
                    &mut self.lr_scheduler,
                    &mut self.event_processor,
                    &self.interrupter,
                );
            }

            if self.interrupter.should_stop() {
                break;
            }

            // TODO: multi-device validation?
            let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
            epoch_valid.run::<LC, OutputValid>(
                &self.model,
                &mut self.event_processor,
                &self.interrupter,
            );

            if let Some(checkpointer) = &mut self.checkpointer {
                checkpointer.checkpoint(
                    &self.model,
                    &self.optim,
                    &self.lr_scheduler,
                    epoch,
                    &self.event_store,
                );
            }

            if let Some(early_stopping) = &mut self.early_stopping {
                if early_stopping.should_stop(epoch, &self.event_store) {
                    break;
                }
            }
        }

        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.
        self.event_processor.process_train(Event::End);

        // Display the learner summary.
        if let Some(summary) = self.summary {
            match summary.init() {
                Ok(summary) => {
                    println!("{}", summary.with_model(self.model.to_string()))
                }
                Err(err) => log::error!("Could not retrieve learner summary:\n{err}"),
            }
        }

        self.model
    }
}
227}