//! burn_train/learner/train_val.rs
1use crate::components::{
2 InputTrain, InputValid, LearnerComponentTypes, TrainBackend, ValidBackend,
3};
4#[cfg(feature = "ddp")]
5use crate::ddp::DdpLearningStrategy;
6use crate::multi::MultiDeviceLearningStrategy;
7use crate::renderer::MetricsRenderer;
8use crate::single::SingleDeviceLearningStrategy;
9use crate::{Learner, LearningMethod, LearningStrategy};
10use burn_core::data::dataloader::DataLoader;
11use burn_core::module::AutodiffModule;
12use burn_core::tensor::backend::AutodiffBackend;
13use burn_optim::{GradientsParams, Optimizer};
14use std::sync::Arc;
15
/// A training output.
///
/// Pairs the item produced by a training step with the gradients computed
/// during the backward pass.
pub struct TrainOutput<TO> {
    /// The gradients of the model parameters for this step.
    pub grads: GradientsParams,

    /// The item produced by the training step (the model output).
    pub item: TO,
}
24
25impl<TO> TrainOutput<TO> {
26 /// Creates a new training output.
27 ///
28 /// # Arguments
29 ///
30 /// * `module` - The module.
31 /// * `grads` - The gradients.
32 /// * `item` - The item.
33 ///
34 /// # Returns
35 ///
36 /// A new training output.
37 pub fn new<B: AutodiffBackend, M: AutodiffModule<B>>(
38 module: &M,
39 grads: B::Gradients,
40 item: TO,
41 ) -> Self {
42 let grads = GradientsParams::from_grads(grads, module);
43 Self { grads, item }
44 }
45}
46
/// Trait to be implemented for training models.
///
/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
///
/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
/// optimizer is used to update the model. This can be useful if you want to call custom mutable
/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
///
/// # Notes
///
/// To be used with the [Learner](Learner) struct, the struct which implements this trait must
/// also implement the [AutodiffModule] trait, which is done automatically with the
/// [Module](burn_core::module::Module) derive.
pub trait TrainStep<TI, TO> {
    /// Runs the training step, which executes the forward and backward passes.
    ///
    /// # Arguments
    ///
    /// * `item` - The training input for the model.
    ///
    /// # Returns
    ///
    /// The training output containing the model output and the gradients.
    fn step(&self, item: TI) -> TrainOutput<TO>;

    /// Optimize the current module with the provided gradients and learning rate.
    ///
    /// The default implementation delegates directly to [Optimizer::step],
    /// consuming the current model and returning the updated one.
    ///
    /// # Arguments
    ///
    /// * `optim`: Optimizer used for training this model.
    /// * `lr`: The learning rate used for this step.
    /// * `grads`: The gradients of each parameter in the current model.
    ///
    /// # Returns
    ///
    /// The updated model.
    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
    where
        B: AutodiffBackend,
        O: Optimizer<Self, B>,
        Self: AutodiffModule<B>,
    {
        // Apply the optimizer's parameter-update rule with the given gradients.
        optim.step(lr, self, grads)
    }
}
91
/// Trait to be implemented for validating models.
///
/// Unlike [TrainStep::step], this step maps the input directly to an output
/// item; no gradients appear in the signature.
pub trait ValidStep<VI, VO> {
    /// Runs a validation step.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to validate on.
    ///
    /// # Returns
    ///
    /// The validation output.
    fn step(&self, item: VI) -> VO;
}
105
/// Shared handle to the training data loader, yielding training inputs on the
/// training backend.
pub(crate) type TrainLoader<LC> = Arc<dyn DataLoader<TrainBackend<LC>, InputTrain<LC>>>;
/// Shared handle to the validation data loader, yielding validation inputs on
/// the validation backend.
pub(crate) type ValidLoader<LC> = Arc<dyn DataLoader<ValidBackend<LC>, InputValid<LC>>>;
108
/// The result of a training, containing the model along with the [renderer](MetricsRenderer).
pub struct TrainingResult<M> {
    /// The model trained.
    pub model: M,
    /// The renderer that can be reused for follow-up training and evaluation runs.
    pub renderer: Box<dyn MetricsRenderer>,
}
116
117impl<LC: LearnerComponentTypes + Send + 'static> Learner<LC> {
118 /// Fits the model.
119 ///
120 /// # Arguments
121 ///
122 /// * `dataloader_train` - The training dataloader.
123 /// * `dataloader_valid` - The validation dataloader.
124 ///
125 /// # Returns
126 ///
127 /// The fitted model.
128 pub fn fit(
129 self,
130 dataloader_train: TrainLoader<LC>,
131 dataloader_valid: ValidLoader<LC>,
132 ) -> TrainingResult<LC::InnerModel> {
133 log::info!("Fitting the model:\n {}", self.model);
134
135 match &self.learning_strategy {
136 LearningStrategy::SingleDevice(device) => {
137 let single_device = SingleDeviceLearningStrategy::new(device.clone());
138 single_device.fit(self, dataloader_train, dataloader_valid)
139 }
140 LearningStrategy::MultiDeviceNaive(devices) => {
141 let multi_device = MultiDeviceLearningStrategy::new(devices.clone());
142 multi_device.fit(self, dataloader_train, dataloader_valid)
143 }
144
145 #[cfg(feature = "ddp")]
146 LearningStrategy::DistributedDataParallel { devices, config } => {
147 let ddp = DdpLearningStrategy::new(devices.clone(), config.clone());
148 ddp.fit(self, dataloader_train, dataloader_valid)
149 }
150 }
151 }
152}