//! burn_train/learner/train_val.rs

1use crate::components::{
2    InputTrain, InputValid, LearnerComponentTypes, TrainBackend, ValidBackend,
3};
4#[cfg(feature = "ddp")]
5use crate::ddp::DdpLearningStrategy;
6use crate::multi::MultiDeviceLearningStrategy;
7use crate::renderer::MetricsRenderer;
8use crate::single::SingleDeviceLearningStrategy;
9use crate::{Learner, LearningMethod, LearningStrategy};
10use burn_core::data::dataloader::DataLoader;
11use burn_core::module::AutodiffModule;
12use burn_core::tensor::backend::AutodiffBackend;
13use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
14use std::sync::Arc;
15
/// A training output, bundling the computed gradients with the step's output item.
pub struct TrainOutput<TO> {
    /// The gradients computed during the backward pass, keyed by parameter.
    pub grads: GradientsParams,

    /// The output item produced by the training step.
    pub item: TO,
}
24
25impl<TO> TrainOutput<TO> {
26    /// Creates a new training output.
27    ///
28    /// # Arguments
29    ///
30    /// * `module` - The module.
31    /// * `grads` - The gradients.
32    /// * `item` - The item.
33    ///
34    /// # Returns
35    ///
36    /// A new training output.
37    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
38        module: &M,
39        grads: B::Gradients,
40        item: TO,
41    ) -> Self {
42        let grads = GradientsParams::from_grads(grads, module);
43        Self { grads, item }
44    }
45}
46
/// Trait to be implemented for training models.
///
/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
///
/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
/// optimizer is used to update the model. This can be useful if you want to call custom mutable
/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
///
/// # Notes
///
/// To be used with the [Learner](Learner) struct, the struct which implements this trait must
/// also implement the [AutodiffModule] trait, which is done automatically with the
/// [Module](burn_core::module::Module) derive.
pub trait TrainStep<TI, TO> {
    /// Runs the training step, which executes the forward and backward passes.
    ///
    /// # Arguments
    ///
    /// * `item` - The training input for the model.
    ///
    /// # Returns
    ///
    /// The training output containing the model output and the gradients.
    fn step(&self, item: TI) -> TrainOutput<TO>;

    /// Optimize the current module with the provided gradients and learning rate.
    ///
    /// The default implementation delegates directly to [Optimizer::step].
    ///
    /// # Arguments
    ///
    /// * `optim`: Optimizer used for training this model.
    /// * `lr`: The learning rate used for this step.
    /// * `grads`: The gradients of each parameter in the current model.
    ///
    /// # Returns
    ///
    /// The updated model.
    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
    where
        B: AutodiffBackend,
        O: Optimizer<Self, B>,
        Self: AutodiffModule<B>,
    {
        optim.step(lr, self, grads)
    }

    /// Optimize the current module with the provided multi-device gradients and learning rate.
    ///
    /// The default implementation delegates directly to [Optimizer::step_multi].
    ///
    /// # Arguments
    ///
    /// * `optim`: Optimizer used for training this model.
    /// * `lr`: The learning rate used for this step.
    /// * `grads`: Multiple gradients associated to each parameter in the current model.
    ///
    /// # Returns
    ///
    /// The updated model.
    fn optimize_multi<B, O>(self, optim: &mut O, lr: f64, grads: MultiGradientsParams) -> Self
    where
        B: AutodiffBackend,
        O: Optimizer<Self, B>,
        Self: AutodiffModule<B>,
    {
        optim.step_multi(lr, self, grads)
    }
}
110
/// Trait to be implemented for validating models.
///
/// Unlike [TrainStep], a validation step runs only the forward pass and
/// produces no gradients.
pub trait ValidStep<VI, VO> {
    /// Runs a validation step.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to validate on.
    ///
    /// # Returns
    ///
    /// The validation output.
    fn step(&self, item: VI) -> VO;
}
124
/// A shared reference to the training split [DataLoader](DataLoader),
/// yielding training inputs on the training backend.
pub type TrainLoader<LC> = Arc<dyn DataLoader<TrainBackend<LC>, InputTrain<LC>>>;
/// A shared reference to the validation split [DataLoader](DataLoader),
/// yielding validation inputs on the (non-autodiff) validation backend.
pub type ValidLoader<LC> = Arc<dyn DataLoader<ValidBackend<LC>, InputValid<LC>>>;
129
/// The result of a training, containing the model along with the [renderer](MetricsRenderer).
pub struct TrainingResult<M> {
    /// The model trained.
    pub model: M,
    /// The renderer that can be used for follow-up training and evaluation,
    /// preserving the metrics display state across runs.
    pub renderer: Box<dyn MetricsRenderer>,
}
137
138impl<LC: LearnerComponentTypes + Send + 'static> Learner<LC> {
139    /// Fits the model.
140    ///
141    /// # Arguments
142    ///
143    /// * `dataloader_train` - The training dataloader.
144    /// * `dataloader_valid` - The validation dataloader.
145    ///
146    /// # Returns
147    ///
148    /// The fitted model.
149    pub fn fit(
150        self,
151        dataloader_train: TrainLoader<LC>,
152        dataloader_valid: ValidLoader<LC>,
153    ) -> TrainingResult<LC::InnerModel> {
154        log::info!("Fitting the model:\n {}", self.model);
155
156        match &self.learning_strategy {
157            LearningStrategy::SingleDevice(device) => {
158                let single_device = SingleDeviceLearningStrategy::new(device.clone());
159                single_device.fit(self, dataloader_train, dataloader_valid)
160            }
161            LearningStrategy::CustomSingleDevice(learning_strategy) => learning_strategy
162                .clone()
163                .fit(self, dataloader_train, dataloader_valid),
164            LearningStrategy::MultiDevice(devices, optim) => {
165                let multi_device = MultiDeviceLearningStrategy::new(devices.clone(), *optim);
166                multi_device.fit(self, dataloader_train, dataloader_valid)
167            }
168            LearningStrategy::CustomMultiDevice(learning_strategy) => learning_strategy
169                .clone()
170                .fit(self, dataloader_train, dataloader_valid),
171
172            #[cfg(feature = "ddp")]
173            LearningStrategy::DistributedDataParallel { devices, config } => {
174                let ddp = DdpLearningStrategy::new(devices.clone(), config.clone());
175                ddp.fit(self, dataloader_train, dataloader_valid)
176            }
177        }
178    }
179}