burn_train/learner/train_val.rs

use crate::components::LearnerComponents;
use crate::metric::processor::{Event, EventProcessor};
use crate::{Learner, TrainEpoch, ValidEpoch};
use burn_core::data::dataloader::DataLoader;
use burn_core::module::{AutodiffModule, Module};
use burn_core::optim::{GradientsParams, Optimizer};
use burn_core::tensor::backend::AutodiffBackend;
use std::sync::Arc;

/// A training output.
pub struct TrainOutput<TO> {
    /// The gradients.
    pub grads: GradientsParams,

    /// The item.
    pub item: TO,
}

impl<TO> TrainOutput<TO> {
    /// Creates a new training output.
    ///
    /// # Arguments
    ///
    /// * `module` - The module.
    /// * `grads` - The gradients.
    /// * `item` - The item.
    ///
    /// # Returns
    ///
    /// A new training output.
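    ///
    /// # Example
    ///
    /// A minimal sketch of building a [TrainOutput] inside a
    /// [step](TrainStep::step) implementation, assuming a hypothetical `model`
    /// whose forward pass returns an output with a `loss` tensor:
    ///
    /// ```rust,ignore
    /// let output = model.forward(batch);
    /// let grads = output.loss.backward();
    /// let train_output = TrainOutput::new(&model, grads, output);
    /// ```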
    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
        module: &M,
        grads: B::Gradients,
        item: TO,
    ) -> Self {
        let grads = GradientsParams::from_grads(grads, module);
        Self { grads, item }
    }
}

/// Trait to be implemented for training models.
///
/// The [step](TrainStep::step) method must be implemented manually for each struct.
///
/// The [optimize](TrainStep::optimize) method can be overridden to control how the optimizer
/// updates the model. This is useful for calling custom mutable functions on the model
/// (e.g., clipping the weights) before or after the optimizer is used.
///
/// # Notes
///
/// To be used with the [Learner](Learner) struct, the struct implementing this trait must
/// also implement the [AutodiffModule] trait, which is done automatically with the
/// [Module](burn_core::module::Module) derive.
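///
/// # Example
///
/// A minimal sketch of a manual [step](TrainStep::step) implementation, assuming
/// hypothetical `MyModel` and `MyBatch` types and a `forward_regression` method
/// that returns a `RegressionOutput`:
///
/// ```rust,ignore
/// impl<B: AutodiffBackend> TrainStep<MyBatch<B>, RegressionOutput<B>> for MyModel<B> {
///     fn step(&self, batch: MyBatch<B>) -> TrainOutput<RegressionOutput<B>> {
///         let output = self.forward_regression(batch);
///
///         TrainOutput::new(self, output.loss.backward(), output)
///     }
/// }
/// ```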
pub trait TrainStep<TI, TO> {
    /// Runs the training step, which executes the forward and backward passes.
    ///
    /// # Arguments
    ///
    /// * `item` - The training input for the model.
    ///
    /// # Returns
    ///
    /// The training output containing the model output and the gradients.
    fn step(&self, item: TI) -> TrainOutput<TO>;

    /// Optimizes the current module with the provided gradients and learning rate.
    ///
    /// # Arguments
    ///
    /// * `optim`: The optimizer used for training this model.
    /// * `lr`: The learning rate used for this step.
    /// * `grads`: The gradients of each parameter in the current model.
    ///
    /// # Returns
    ///
    /// The updated model.
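    ///
    /// # Example
    ///
    /// A minimal sketch of overriding `optimize` inside a [TrainStep]
    /// implementation to mutate the model before the optimizer step
    /// (hypothetical `clip_weights` method):
    ///
    /// ```rust,ignore
    /// fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
    /// where
    ///     B: AutodiffBackend,
    ///     O: Optimizer<Self, B>,
    ///     Self: AutodiffModule<B>,
    /// {
    ///     // Hypothetical helper that clamps the model weights before the update.
    ///     let model = self.clip_weights();
    ///     optim.step(lr, model, grads)
    /// }
    /// ```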
    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
    where
        B: AutodiffBackend,
        O: Optimizer<Self, B>,
        Self: AutodiffModule<B>,
    {
        optim.step(lr, self, grads)
    }
}

/// Trait to be implemented for validating models.
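///
/// # Example
///
/// A minimal sketch mirroring the [TrainStep] example above; validation runs on
/// the inner (non-autodiff) module, so no gradients are computed (hypothetical
/// `MyModel` and `MyBatch` types):
///
/// ```rust,ignore
/// impl<B: Backend> ValidStep<MyBatch<B>, RegressionOutput<B>> for MyModel<B> {
///     fn step(&self, batch: MyBatch<B>) -> RegressionOutput<B> {
///         self.forward_regression(batch)
///     }
/// }
/// ```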
pub trait ValidStep<VI, VO> {
    /// Runs a validation step.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to validate on.
    ///
    /// # Returns
    ///
    /// The validation output.
    fn step(&self, item: VI) -> VO;
}

impl<LC: LearnerComponents> Learner<LC> {
    /// Fits the model.
    ///
    /// # Arguments
    ///
    /// * `dataloader_train` - The training dataloader.
    /// * `dataloader_valid` - The validation dataloader.
    ///
    /// # Returns
    ///
    /// The fitted model.
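    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a learner already configured with a
    /// `LearnerBuilder` and two hypothetical dataloaders:
    ///
    /// ```rust,ignore
    /// let model_trained = learner.fit(dataloader_train, dataloader_valid);
    /// ```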
    pub fn fit<InputTrain, InputValid, OutputTrain, OutputValid>(
        mut self,
        dataloader_train: Arc<dyn DataLoader<InputTrain>>,
        dataloader_valid: Arc<dyn DataLoader<InputValid>>,
    ) -> LC::Model
    where
        InputTrain: Send + 'static,
        InputValid: Send,
        OutputTrain: Send + 'static,
        OutputValid: Send,
        LC::Model: TrainStep<InputTrain, OutputTrain>,
        <LC::Model as AutodiffModule<LC::Backend>>::InnerModule: ValidStep<InputValid, OutputValid>,
        LC::EventProcessor: EventProcessor<ItemTrain = OutputTrain, ItemValid = OutputValid>,
    {
        log::info!("Fitting the model:\n {}", self.model);

        // The reference model is always on the first device provided.
        if let Some(device) = self.devices.first() {
            self.model = self.model.fork(device);
        }

        // Resume from the checkpoint if one was configured, otherwise start from epoch 1.
        let starting_epoch = match self.checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut self.checkpointer {
                    (self.model, self.optim, self.lr_scheduler) = checkpointer.load_checkpoint(
                        self.model,
                        self.optim,
                        self.lr_scheduler,
                        &Default::default(), // Load the checkpoint on the default device.
                        checkpoint,
                    );
                }
                checkpoint + 1
            }
            None => 1,
        };

        for epoch in starting_epoch..=self.num_epochs {
            let epoch_train = TrainEpoch::new(
                dataloader_train.clone(),
                epoch,
                self.num_epochs,
                self.grad_accumulation,
            );

            // Run the training epoch across all configured devices, or on a single one.
            if self.devices.len() > 1 {
                (self.model, self.optim) = epoch_train.run_multi_device::<LC, OutputTrain>(
                    self.model,
                    self.optim,
                    &mut self.lr_scheduler,
                    &mut self.event_processor,
                    self.devices.clone(),
                    &self.interrupter,
                );
            } else {
                (self.model, self.optim) = epoch_train.run::<LC, OutputTrain>(
                    self.model,
                    self.optim,
                    &mut self.lr_scheduler,
                    &mut self.event_processor,
                    &self.interrupter,
                );
            }

            if self.interrupter.should_stop() {
                break;
            }

            // Validation runs on the inner (non-autodiff) module.
            let epoch_valid = ValidEpoch::new(dataloader_valid.clone(), epoch, self.num_epochs);
            epoch_valid.run::<LC, OutputValid>(
                &self.model,
                &mut self.event_processor,
                &self.interrupter,
            );

            if let Some(checkpointer) = &mut self.checkpointer {
                checkpointer.checkpoint(
                    &self.model,
                    &self.optim,
                    &self.lr_scheduler,
                    epoch,
                    &self.event_store,
                );
            }

            if let Some(early_stopping) = &mut self.early_stopping {
                if early_stopping.should_stop(epoch, &self.event_store) {
                    break;
                }
            }
        }

        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.
        self.event_processor.process_train(Event::End);

        // Display the learner summary.
        if let Some(summary) = self.summary {
            match summary.init() {
                Ok(summary) => {
                    println!("{}", summary.with_model(self.model.to_string()))
                }
                Err(err) => log::error!("Could not retrieve learner summary:\n{err}"),
            }
        }

        self.model
    }
}
217}