//! burn_train/learner/train_val.rs — training and validation step traits.
use crate::{ItemLazy, renderer::MetricsRenderer};
use burn_core::module::AutodiffModule;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};

6/// A training output.
7pub struct TrainOutput<TO> {
8 /// The gradients.
9 pub grads: GradientsParams,
10
11 /// The item.
12 pub item: TO,
13}
14
15impl<TO> TrainOutput<TO> {
16 /// Creates a new training output.
17 ///
18 /// # Arguments
19 ///
20 /// * `module` - The module.
21 /// * `grads` - The gradients.
22 /// * `item` - The item.
23 ///
24 /// # Returns
25 ///
26 /// A new training output.
27 pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
28 module: &M,
29 grads: B::Gradients,
30 item: TO,
31 ) -> Self {
32 let grads = GradientsParams::from_grads(grads, module);
33 Self { grads, item }
34 }
35}
36
37/// Trait to be implemented for models to be able to be trained.
38///
39/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
40///
41/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
42/// optimizer is used to update the model. This can be useful if you want to call custom mutable
43/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
44///
45/// # Notes
46///
47/// To be used with the [Learner](crate::Learner) struct, the struct which implements this trait must
48/// also implement the [AutodiffModule] trait, which is done automatically with the
49/// [Module](burn_core::module::Module) derive.
50pub trait TrainStep {
51 /// Type of input for a step of the training stage.
52 type Input: Send + 'static;
53 /// Type of output for a step of the training stage.
54 type Output: ItemLazy + 'static;
55 /// Runs a step for training, which executes the forward and backward passes.
56 ///
57 /// # Arguments
58 ///
59 /// * `item` - The input for the model.
60 ///
61 /// # Returns
62 ///
63 /// The output containing the model output and the gradients.
64 fn step(&self, item: Self::Input) -> TrainOutput<Self::Output>;
65 /// Optimize the current module with the provided gradients and learning rate.
66 ///
67 /// # Arguments
68 ///
69 /// * `optim`: Optimizer used for learning.
70 /// * `lr`: The learning rate used for this step.
71 /// * `grads`: The gradients of each parameter in the current model.
72 ///
73 /// # Returns
74 ///
75 /// The updated model.
76 fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
77 where
78 B: AutodiffBackend,
79 O: Optimizer<Self, B>,
80 Self: AutodiffModule<B>,
81 {
82 optim.step(lr, self, grads)
83 }
84 /// Optimize the current module with the provided gradients and learning rate.
85 ///
86 /// # Arguments
87 ///
88 /// * `optim`: Optimizer used for learning.
89 /// * `lr`: The learning rate used for this step.
90 /// * `grads`: Multiple gradients associated to each parameter in the current model.
91 ///
92 /// # Returns
93 ///
94 /// The updated model.
95 fn optimize_multi<B, O>(self, optim: &mut O, lr: f64, grads: MultiGradientsParams) -> Self
96 where
97 B: AutodiffBackend,
98 O: Optimizer<Self, B>,
99 Self: AutodiffModule<B>,
100 {
101 optim.step_multi(lr, self, grads)
102 }
103}
104
105/// Trait to be implemented for validating models.
106pub trait InferenceStep {
107 /// Type of input for an inference step.
108 type Input: Send + 'static;
109 /// Type of output for an inference step.
110 type Output: ItemLazy + 'static;
111 /// Runs a validation step.
112 ///
113 /// # Arguments
114 ///
115 /// * `item` - The item to validate on.
116 ///
117 /// # Returns
118 ///
119 /// The validation output.
120 fn step(&self, item: Self::Input) -> Self::Output;
121}
122
123/// The result of a training, containing the model along with the [renderer](MetricsRenderer).
124pub struct LearningResult<M> {
125 /// The model with the learned weights.
126 pub model: M,
127 /// The renderer that can be used for follow up training and evaluation.
128 pub renderer: Box<dyn MetricsRenderer>,
129}