burn_train/learner/train_val.rs

1use crate::{ItemLazy, renderer::MetricsRenderer};
2use burn_core::module::AutodiffModule;
3use burn_core::tensor::backend::AutodiffBackend;
4use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
5
6/// A training output.
7pub struct TrainOutput<TO> {
8    /// The gradients.
9    pub grads: GradientsParams,
10
11    /// The item.
12    pub item: TO,
13}
14
15impl<TO> TrainOutput<TO> {
16    /// Creates a new training output.
17    ///
18    /// # Arguments
19    ///
20    /// * `module` - The module.
21    /// * `grads` - The gradients.
22    /// * `item` - The item.
23    ///
24    /// # Returns
25    ///
26    /// A new training output.
27    pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
28        module: &M,
29        grads: B::Gradients,
30        item: TO,
31    ) -> Self {
32        let grads = GradientsParams::from_grads(grads, module);
33        Self { grads, item }
34    }
35}
36
37/// Trait to be implemented for models to be able to be trained.
38///
39/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
40///
41/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
42/// optimizer is used to update the model. This can be useful if you want to call custom mutable
43/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
44///
45/// # Notes
46///
47/// To be used with the [Learner](crate::Learner) struct, the struct which implements this trait must
48/// also implement the [AutodiffModule] trait, which is done automatically with the
49/// [Module](burn_core::module::Module) derive.
50pub trait TrainStep {
51    /// Type of input for a step of the training stage.
52    type Input: Send + 'static;
53    /// Type of output for a step of the training stage.
54    type Output: ItemLazy + 'static;
55    /// Runs a step for training, which executes the forward and backward passes.
56    ///
57    /// # Arguments
58    ///
59    /// * `item` - The input for the model.
60    ///
61    /// # Returns
62    ///
63    /// The output containing the model output and the gradients.
64    fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>;
65    /// Optimize the current module with the provided gradients and learning rate.
66    ///
67    /// # Arguments
68    ///
69    /// * `optim`: Optimizer used for learning.
70    /// * `lr`: The learning rate used for this step.
71    /// * `grads`: The gradients of each parameter in the current model.
72    ///
73    /// # Returns
74    ///
75    /// The updated model.
76    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
77    where
78        B: AutodiffBackend,
79        O: Optimizer<Self, B>,
80        Self: AutodiffModule<B>,
81    {
82        optim.step(lr, self, grads)
83    }
84    /// Optimize the current module with the provided gradients and learning rate.
85    ///
86    /// # Arguments
87    ///
88    /// * `optim`: Optimizer used for learning.
89    /// * `lr`: The learning rate used for this step.
90    /// * `grads`: Multiple gradients associated to each parameter in the current model.
91    ///
92    /// # Returns
93    ///
94    /// The updated model.
95    fn optimize_multi<B, O>(self, optim: &mut O, lr: f64, grads: MultiGradientsParams) -> Self
96    where
97        B: AutodiffBackend,
98        O: Optimizer<Self, B>,
99        Self: AutodiffModule<B>,
100    {
101        optim.step_multi(lr, self, grads)
102    }
103}
104
105/// Trait to be implemented for validating models.
106pub trait InferenceStep {
107    /// Type of input for an inference step.
108    type Input: Send + 'static;
109    /// Type of output for an inference step.
110    type Output: ItemLazy + 'static;
111    /// Runs a validation step.
112    ///
113    /// # Arguments
114    ///
115    /// * `item` - The item to validate on.
116    ///
117    /// # Returns
118    ///
119    /// The validation output.
120    fn step(&self, item: Self::Input) -> Self::Output;
121}
122
123/// The result of a training, containing the model along with the [renderer](MetricsRenderer).
124pub struct LearningResult<M> {
125    /// The model with the learned weights.
126    pub model: M,
127    /// The renderer that can be used for follow up training and evaluation.
128    pub renderer: Box<dyn MetricsRenderer>,
129}