burn_train/learner/
train_val.rs

1use crate::components::{
2    InputTrain, InputValid, LearnerComponentTypes, TrainBackend, ValidBackend,
3};
4#[cfg(feature = "ddp")]
5use crate::ddp::DdpLearningStrategy;
6use crate::multi::MultiDeviceLearningStrategy;
7use crate::renderer::MetricsRenderer;
8use crate::single::SingleDeviceLearningStrategy;
9use crate::{Learner, LearningMethod, LearningStrategy};
10use burn_core::data::dataloader::DataLoader;
11use burn_core::module::AutodiffModule;
12use burn_core::tensor::backend::AutodiffBackend;
13use burn_optim::{GradientsParams, Optimizer};
14use std::sync::Arc;
15
16/// A training output.
17pub struct TrainOutput<TO> {
18    /// The gradients.
19    pub grads: GradientsParams,
20
21    /// The item.
22    pub item: TO,
23}
24
25impl<TO> TrainOutput<TO> {
26    /// Creates a new training output.
27    ///
28    /// # Arguments
29    ///
30    /// * `module` - The module.
31    /// * `grads` - The gradients.
32    /// * `item` - The item.
33    ///
34    /// # Returns
35    ///
36    /// A new training output.
37    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
38        module: &M,
39        grads: B::Gradients,
40        item: TO,
41    ) -> Self {
42        let grads = GradientsParams::from_grads(grads, module);
43        Self { grads, item }
44    }
45}
46
47/// Trait to be implemented for training models.
48///
49/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
50///
51/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
52/// optimizer is used to update the model. This can be useful if you want to call custom mutable
53/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
54///
55/// # Notes
56///
57/// To be used with the [Learner](Learner) struct, the struct which implements this trait must
58/// also implement the [AutodiffModule] trait, which is done automatically with the
59/// [Module](burn_core::module::Module) derive.
60pub trait TrainStep<TI, TO> {
61    /// Runs the training step, which executes the forward and backward passes.
62    ///
63    /// # Arguments
64    ///
65    /// * `item` - The training input for the model.
66    ///
67    /// # Returns
68    ///
69    /// The training output containing the model output and the gradients.
70    fn step(&self, item: TI) -> TrainOutput<TO>;
71    /// Optimize the current module with the provided gradients and learning rate.
72    ///
73    /// # Arguments
74    ///
75    /// * `optim`: Optimizer used for training this model.
76    /// * `lr`: The learning rate used for this step.
77    /// * `grads`: The gradients of each parameter in the current model.
78    ///
79    /// # Returns
80    ///
81    /// The updated model.
82    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
83    where
84        B: AutodiffBackend,
85        O: Optimizer<Self, B>,
86        Self: AutodiffModule<B>,
87    {
88        optim.step(lr, self, grads)
89    }
90}
91
/// Trait implemented by models that can be validated.
pub trait ValidStep<VI, VO> {
    /// Executes one validation step.
    ///
    /// # Arguments
    ///
    /// * `item` - The validation input.
    ///
    /// # Returns
    ///
    /// The output of the validation step.
    fn step(&self, item: VI) -> VO;
}
105
106pub(crate) type TrainLoader<LC> = Arc<dyn DataLoader<TrainBackend<LC>, InputTrain<LC>>>;
107pub(crate) type ValidLoader<LC> = Arc<dyn DataLoader<ValidBackend<LC>, InputValid<LC>>>;
108
109/// The result of a training, containing the model along with the [renderer](MetricsRenderer).
110pub struct TrainingResult<M> {
111    /// The model trained.
112    pub model: M,
113    /// The renderer that can be used for follow up training and evaluation.
114    pub renderer: Box<dyn MetricsRenderer>,
115}
116
117impl<LC: LearnerComponentTypes + Send + 'static> Learner<LC> {
118    /// Fits the model.
119    ///
120    /// # Arguments
121    ///
122    /// * `dataloader_train` - The training dataloader.
123    /// * `dataloader_valid` - The validation dataloader.
124    ///
125    /// # Returns
126    ///
127    /// The fitted model.
128    pub fn fit(
129        self,
130        dataloader_train: TrainLoader<LC>,
131        dataloader_valid: ValidLoader<LC>,
132    ) -> TrainingResult<LC::InnerModel> {
133        log::info!("Fitting the model:\n {}", self.model);
134
135        match &self.learning_strategy {
136            LearningStrategy::SingleDevice(device) => {
137                let single_device = SingleDeviceLearningStrategy::new(device.clone());
138                single_device.fit(self, dataloader_train, dataloader_valid)
139            }
140            LearningStrategy::MultiDeviceNaive(devices) => {
141                let multi_device = MultiDeviceLearningStrategy::new(devices.clone());
142                multi_device.fit(self, dataloader_train, dataloader_valid)
143            }
144
145            #[cfg(feature = "ddp")]
146            LearningStrategy::DistributedDataParallel { devices, config } => {
147                let ddp = DdpLearningStrategy::new(devices.clone(), config.clone());
148                ddp.fit(self, dataloader_train, dataloader_valid)
149            }
150        }
151    }
152}