// burn_train/learner/train_val.rs
1use crate::components::{
2 InputTrain, InputValid, LearnerComponentTypes, TrainBackend, ValidBackend,
3};
4#[cfg(feature = "ddp")]
5use crate::ddp::DdpLearningStrategy;
6use crate::multi::MultiDeviceLearningStrategy;
7use crate::renderer::MetricsRenderer;
8use crate::single::SingleDeviceLearningStrategy;
9use crate::{Learner, LearningMethod, LearningStrategy};
10use burn_core::data::dataloader::DataLoader;
11use burn_core::module::AutodiffModule;
12use burn_core::tensor::backend::AutodiffBackend;
13use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
14use std::sync::Arc;
15
/// A training output.
///
/// Bundles the item produced by one call to [TrainStep::step] together with
/// the gradients computed during that step's backward pass.
pub struct TrainOutput<TO> {
    /// The gradients computed for the model's parameters during this step.
    pub grads: GradientsParams,

    /// The item produced by the training step (model-specific output type).
    pub item: TO,
}
24
25impl<TO> TrainOutput<TO> {
26 /// Creates a new training output.
27 ///
28 /// # Arguments
29 ///
30 /// * `module` - The module.
31 /// * `grads` - The gradients.
32 /// * `item` - The item.
33 ///
34 /// # Returns
35 ///
36 /// A new training output.
37 pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
38 module: &M,
39 grads: B::Gradients,
40 item: TO,
41 ) -> Self {
42 let grads = GradientsParams::from_grads(grads, module);
43 Self { grads, item }
44 }
45}
46
47/// Trait to be implemented for training models.
48///
49/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
50///
51/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
52/// optimizer is used to update the model. This can be useful if you want to call custom mutable
53/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
54///
55/// # Notes
56///
57/// To be used with the [Learner](Learner) struct, the struct which implements this trait must
58/// also implement the [AutodiffModule] trait, which is done automatically with the
59/// [Module](burn_core::module::Module) derive.
/// Trait to be implemented for training models.
///
/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
///
/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
/// optimizer is used to update the model. This can be useful if you want to call custom mutable
/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
///
/// # Notes
///
/// To be used with the [Learner](Learner) struct, the struct which implements this trait must
/// also implement the [AutodiffModule] trait, which is done automatically with the
/// [Module](burn_core::module::Module) derive.
pub trait TrainStep<TI, TO> {
    /// Runs the training step, which executes the forward and backward passes.
    ///
    /// # Arguments
    ///
    /// * `item` - The training input for the model.
    ///
    /// # Returns
    ///
    /// The training output containing the model output and the gradients.
    fn step(&self, item: TI) -> TrainOutput<TO>;
    /// Optimize the current module with the provided gradients and learning rate.
    ///
    /// The default implementation simply delegates to [Optimizer::step]; override it to
    /// run custom mutable logic (e.g., weight clipping) around the optimizer update.
    ///
    /// # Arguments
    ///
    /// * `optim`: Optimizer used for training this model.
    /// * `lr`: The learning rate used for this step.
    /// * `grads`: The gradients of each parameter in the current model.
    ///
    /// # Returns
    ///
    /// The updated model.
    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
    where
        B: AutodiffBackend,
        O: Optimizer<Self, B>,
        Self: AutodiffModule<B>,
    {
        optim.step(lr, self, grads)
    }
    /// Optimize the current module with the provided gradients and learning rate.
    ///
    /// Multi-gradient variant of [TrainStep::optimize]; the default implementation
    /// delegates to `Optimizer::step_multi`.
    ///
    /// # Arguments
    ///
    /// * `optim`: Optimizer used for training this model.
    /// * `lr`: The learning rate used for this step.
    /// * `grads`: Multiple gradients associated to each parameter in the current model.
    ///
    /// # Returns
    ///
    /// The updated model.
    fn optimize_multi<B, O>(self, optim: &mut O, lr: f64, grads: MultiGradientsParams) -> Self
    where
        B: AutodiffBackend,
        O: Optimizer<Self, B>,
        Self: AutodiffModule<B>,
    {
        optim.step_multi(lr, self, grads)
    }
}
110
/// Trait to be implemented for validating models.
///
/// Unlike [TrainStep], only a forward pass is involved: no gradients are
/// produced and the model is taken by shared reference.
pub trait ValidStep<VI, VO> {
    /// Runs a validation step.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to validate on.
    ///
    /// # Returns
    ///
    /// The validation output.
    fn step(&self, item: VI) -> VO;
}
124
/// A reference to the training split [DataLoader](DataLoader).
///
/// Shared via [Arc] so multiple learning strategies/workers can hold the same loader.
pub type TrainLoader<LC> = Arc<dyn DataLoader<TrainBackend<LC>, InputTrain<LC>>>;
/// A reference to the validation split [DataLoader](DataLoader).
///
/// Uses the (non-autodiff) validation backend, since no gradients are needed.
pub type ValidLoader<LC> = Arc<dyn DataLoader<ValidBackend<LC>, InputValid<LC>>>;
129
/// The result of a training, containing the model along with the [renderer](MetricsRenderer).
pub struct TrainingResult<M> {
    /// The model trained.
    pub model: M,
    /// The renderer that can be used for follow up training and evaluation.
    pub renderer: Box<dyn MetricsRenderer>,
}
137
impl<LC: LearnerComponentTypes + Send + 'static> Learner<LC> {
    /// Fits the model.
    ///
    /// Dispatches to the learning strategy configured on this learner
    /// (single-device, multi-device, custom, or — behind the `ddp` feature —
    /// distributed data parallel), consuming `self` in the process.
    ///
    /// # Arguments
    ///
    /// * `dataloader_train` - The training dataloader.
    /// * `dataloader_valid` - The validation dataloader.
    ///
    /// # Returns
    ///
    /// The fitted model.
    pub fn fit(
        self,
        dataloader_train: TrainLoader<LC>,
        dataloader_valid: ValidLoader<LC>,
    ) -> TrainingResult<LC::InnerModel> {
        log::info!("Fitting the model:\n {}", self.model);

        // NOTE: the match borrows `self.learning_strategy`, while every arm moves
        // `self` into `fit`; the borrow must end (after the `clone()` calls) before
        // the move, so keep the strategy construction separate from the `fit` call.
        match &self.learning_strategy {
            LearningStrategy::SingleDevice(device) => {
                let single_device = SingleDeviceLearningStrategy::new(device.clone());
                single_device.fit(self, dataloader_train, dataloader_valid)
            }
            // User-provided strategy; cloned so the owned copy can consume `self`.
            LearningStrategy::CustomSingleDevice(learning_strategy) => learning_strategy
                .clone()
                .fit(self, dataloader_train, dataloader_valid),
            LearningStrategy::MultiDevice(devices, optim) => {
                // `*optim` copies the optimizer-application mode out of the borrow
                // (assumes the type is `Copy` — TODO confirm in `LearningStrategy`).
                let multi_device = MultiDeviceLearningStrategy::new(devices.clone(), *optim);
                multi_device.fit(self, dataloader_train, dataloader_valid)
            }
            LearningStrategy::CustomMultiDevice(learning_strategy) => learning_strategy
                .clone()
                .fit(self, dataloader_train, dataloader_valid),

            // Only compiled when the `ddp` feature is enabled (matches the
            // cfg-gated import of `DdpLearningStrategy` at the top of the file).
            #[cfg(feature = "ddp")]
            LearningStrategy::DistributedDataParallel { devices, config } => {
                let ddp = DdpLearningStrategy::new(devices.clone(), config.clone());
                ddp.fit(self, dataloader_train, dataloader_valid)
            }
        }
    }
}
179}