//! burn_train/learner/train_val.rs

use crate::components::{LearnerComponents, TrainBackend, ValidBackend};
use crate::metric::processor::{Event, EventProcessor};
use crate::{Learner, TrainEpoch, ValidEpoch};
use burn_core::data::dataloader::DataLoader;
use burn_core::data::dataloader::split::split_dataloader;
use burn_core::module::{AutodiffModule, Module};
use burn_core::optim::{GradientsParams, Optimizer};
use burn_core::tensor::backend::AutodiffBackend;
use std::sync::Arc;
10
11/// A training output.
12pub struct TrainOutput<TO> {
13    /// The gradients.
14    pub grads: GradientsParams,
15
16    /// The item.
17    pub item: TO,
18}
19
20impl<TO> TrainOutput<TO> {
21    /// Creates a new training output.
22    ///
23    /// # Arguments
24    ///
25    /// * `module` - The module.
26    /// * `grads` - The gradients.
27    /// * `item` - The item.
28    ///
29    /// # Returns
30    ///
31    /// A new training output.
32    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
33        module: &M,
34        grads: B::Gradients,
35        item: TO,
36    ) -> Self {
37        let grads = GradientsParams::from_grads(grads, module);
38        Self { grads, item }
39    }
40}
41
42/// Trait to be implemented for training models.
43///
44/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
45///
46/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
47/// optimizer is used to update the model. This can be useful if you want to call custom mutable
48/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
49///
50/// # Notes
51///
52/// To be used with the [Learner](Learner) struct, the struct which implements this trait must
53/// also implement the [AutodiffModule] trait, which is done automatically with the
54/// [Module](burn_core::module::Module) derive.
55pub trait TrainStep<TI, TO> {
56    /// Runs the training step, which executes the forward and backward passes.
57    ///
58    /// # Arguments
59    ///
60    /// * `item` - The training input for the model.
61    ///
62    /// # Returns
63    ///
64    /// The training output containing the model output and the gradients.
65    fn step(&self, item: TI) -> TrainOutput<TO>;
66    /// Optimize the current module with the provided gradients and learning rate.
67    ///
68    /// # Arguments
69    ///
70    /// * `optim`: Optimizer used for training this model.
71    /// * `lr`: The learning rate used for this step.
72    /// * `grads`: The gradients of each parameter in the current model.
73    ///
74    /// # Returns
75    ///
76    /// The updated model.
77    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
78    where
79        B: AutodiffBackend,
80        O: Optimizer<Self, B>,
81        Self: AutodiffModule<B>,
82    {
83        optim.step(lr, self, grads)
84    }
85}
86
/// Trait to be implemented for validating models.
pub trait ValidStep<VI, VO> {
    /// Runs a validation step (forward pass only, no gradients).
    ///
    /// # Arguments
    ///
    /// * `item` - The item to validate on.
    ///
    /// # Returns
    ///
    /// The validation output.
    fn step(&self, item: VI) -> VO;
}
100
101impl<LC: LearnerComponents> Learner<LC> {
102    /// Fits the model.
103    ///
104    /// # Arguments
105    ///
106    /// * `dataloader_train` - The training dataloader.
107    /// * `dataloader_valid` - The validation dataloader.
108    ///
109    /// # Returns
110    ///
111    /// The fitted model.
112    pub fn fit<InputTrain, InputValid, OutputTrain, OutputValid>(
113        mut self,
114        mut dataloader_train: Arc<dyn DataLoader<TrainBackend<LC>, InputTrain>>,
115        mut dataloader_valid: Arc<dyn DataLoader<ValidBackend<LC>, InputValid>>,
116    ) -> LC::Model
117    where
118        InputTrain: Send + 'static,
119        InputValid: Send,
120        OutputTrain: Send + 'static,
121        OutputValid: Send,
122        LC::Model: TrainStep<InputTrain, OutputTrain>,
123        <LC::Model as AutodiffModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
124        LC::EventProcessor: EventProcessor<ItemTrain = OutputTrain, ItemValid = OutputValid>,
125    {
126        log::info!("Fitting the model:\n {}", self.model);
127        // The reference model is always on the first device provided.
128        if let Some(device) = self.devices.first() {
129            self.model = self.model.fork(device);
130            dataloader_train = dataloader_train.to_device(device);
131            dataloader_valid = dataloader_valid.to_device(device);
132        }
133
134        let starting_epoch = match self.checkpoint {
135            Some(checkpoint) => {
136                if let Some(checkpointer) = &mut self.checkpointer {
137                    (self.model, self.optim, self.lr_scheduler) = checkpointer.load_checkpoint(
138                        self.model,
139                        self.optim,
140                        self.lr_scheduler,
141                        &Default::default(), // Load the checkpoint on the default device.
142                        checkpoint,
143                    );
144                }
145                checkpoint + 1
146            }
147            None => 1,
148        };
149
150        // `MultiDevicesTrainStep` has one worker per device, so we use a fixed device strategy
151        // for each (worker) data loader. This matches the expected device on the worker, so we
152        // don't have to move the data between devices.
153        let dataloaders_train = split_dataloader(dataloader_train, &self.devices);
154
155        // Changed the train epoch to keep the dataloaders
156        let mut epoch_train = TrainEpoch::new(
157            dataloaders_train,
158            starting_epoch,
159            self.num_epochs,
160            self.grad_accumulation,
161        );
162
163        for epoch in starting_epoch..self.num_epochs + 1 {
164            if self.devices.len() > 1 {
165                (self.model, self.optim) = epoch_train.run_multi_device::<LC, OutputTrain>(
166                    self.model,
167                    self.optim,
168                    &mut self.lr_scheduler,
169                    &mut self.event_processor,
170                    self.devices.clone(),
171                    &self.interrupter,
172                )
173            } else {
174                (self.model, self.optim) = epoch_train.run::<LC, OutputTrain>(
175                    self.model,
176                    self.optim,
177                    &mut self.lr_scheduler,
178                    &mut self.event_processor,
179                    &self.interrupter,
180                );
181            }
182
183            if self.interrupter.should_stop() {
184                break;
185            }
186
187            // TODO: multi-device validation?
188            let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
189            epoch_valid.run::<LC, OutputValid>(
190                &self.model,
191                &mut self.event_processor,
192                &self.interrupter,
193            );
194
195            if let Some(checkpointer) = &mut self.checkpointer {
196                checkpointer.checkpoint(
197                    &self.model,
198                    &self.optim,
199                    &self.lr_scheduler,
200                    epoch,
201                    &self.event_store,
202                );
203            }
204
205            if let Some(early_stopping) = &mut self.early_stopping {
206                if early_stopping.should_stop(epoch, &self.event_store) {
207                    break;
208                }
209            }
210        }
211
212        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.
213        self.event_processor.process_train(Event::End);
214
215        // Display learner summary
216        if let Some(summary) = self.summary {
217            match summary.init() {
218                Ok(summary) => {
219                    println!("{}", summary.with_model(self.model.to_string()))
220                }
221                Err(err) => log::error!("Could not retrieve learner summary:\n{err}"),
222            }
223        }
224
225        self.model
226    }
227}