
burn_train/learner/supervised/paradigm.rs

use crate::checkpoint::{
    AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
    KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use crate::components::{InferenceModelOutput, TrainingModelOutput};
use crate::learner::EarlyStoppingStrategy;
use crate::learner::base::Interrupter;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::processor::{
    AsyncProcessorTraining, FullEventProcessorTraining, ItemLazy, MetricsTraining,
};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric, Numeric};
use crate::multi::MultiDeviceLearningStrategy;
use crate::renderer::{MetricsRenderer, default_renderer};
use crate::single::SingleDeviceTrainingStrategy;
use crate::{
    ApplicationLoggerInstaller, EarlyStoppingStrategyRef, ExecutionStrategy,
    FileApplicationLoggerInstaller, InferenceBackend, InferenceModel, InferenceModelInput,
    InferenceStep, LearnerEvent, LearnerModelRecord, LearnerOptimizerRecord,
    LearnerSchedulerRecord, LearnerSummaryConfig, LearningCheckpointer, LearningComponentsMarker,
    LearningComponentsTypes, LearningResult, TrainStep, TrainingBackend, TrainingComponents,
    TrainingModelInput, TrainingStrategy,
};
use crate::{Learner, SupervisedLearningStrategy};
use burn_core::data::dataloader::DataLoader;
use burn_core::module::{AutodiffModule, Module};
use burn_core::record::FileRecorder;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::Optimizer;
use burn_optim::lr_scheduler::LrScheduler;
use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;

/// A reference to the training split [DataLoader](DataLoader).
pub type TrainLoader<LC> = Arc<dyn DataLoader<TrainingBackend<LC>, TrainingModelInput<LC>>>;
/// A reference to the validation split [DataLoader](DataLoader).
pub type ValidLoader<LC> = Arc<dyn DataLoader<InferenceBackend<LC>, InferenceModelInput<LC>>>;
/// The event processor type for supervised learning.
pub type SupervisedTrainingEventProcessor<LC> = AsyncProcessorTraining<
    LearnerEvent<TrainingModelOutput<LC>>,
    LearnerEvent<InferenceModelOutput<LC>>,
>;

/// Structure to configure and launch supervised training runs.
pub struct SupervisedTraining<LC>
where
    LC: LearningComponentsTypes,
{
    // Not that complex. Extracting into another type would only make it more confusing.
    #[allow(clippy::type_complexity)]
    checkpointers: Option<(
        AsyncCheckpointer<LearnerModelRecord<LC>, TrainingBackend<LC>>,
        AsyncCheckpointer<LearnerOptimizerRecord<LC>, TrainingBackend<LC>>,
        AsyncCheckpointer<LearnerSchedulerRecord<LC>, TrainingBackend<LC>>,
    )>,
    num_epochs: usize,
    checkpoint: Option<usize>,
    directory: PathBuf,
    grad_accumulation: Option<usize>,
    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
    metrics: MetricsTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,
    event_store: LogEventStore,
    interrupter: Interrupter,
    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
    early_stopping: Option<EarlyStoppingStrategyRef>,
    training_strategy: Option<TrainingStrategy<LC>>,
    dataloader_train: TrainLoader<LC>,
    dataloader_valid: ValidLoader<LC>,
    // Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
    summary_metrics: BTreeSet<String>,
    summary: bool,
}

impl<B, LR, M, O> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>
where
    B: AutodiffBackend,
    LR: LrScheduler + 'static,
    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
    M::InnerModule: InferenceStep,
    O: Optimizer<M, B> + 'static,
{
    /// Creates a new runner for supervised training.
    ///
    /// # Arguments
    ///
    /// * `directory` - The directory where checkpoints and logs are saved.
    /// * `dataloader_train` - The dataloader for the training split.
    /// * `dataloader_valid` - The dataloader for the validation split.
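    ///
    /// # Example
    ///
    /// A minimal sketch; `train_loader`, `valid_loader`, and the artifact
    /// directory are placeholders you must provide yourself.
    ///
    /// ```rust,ignore
    /// let training = SupervisedTraining::new(
    ///     "/tmp/experiment", // checkpoints and logs are written here
    ///     train_loader,      // Arc<dyn DataLoader<B, M::Input>>
    ///     valid_loader,      // Arc<dyn DataLoader<B::InnerBackend, _>>
    /// )
    /// .num_epochs(10);
    /// ```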
    pub fn new(
        directory: impl AsRef<Path>,
        dataloader_train: Arc<dyn DataLoader<B, M::Input>>,
        dataloader_valid: Arc<
            dyn DataLoader<B::InnerBackend, <M::InnerModule as InferenceStep>::Input>,
        >,
    ) -> Self {
        let directory = directory.as_ref().to_path_buf();
        let experiment_log_file = directory.join("experiment.log");
        Self {
            num_epochs: 1,
            checkpoint: None,
            checkpointers: None,
            directory,
            grad_accumulation: None,
            metrics: MetricsTraining::default(),
            event_store: LogEventStore::default(),
            renderer: None,
            interrupter: Interrupter::new(),
            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
                experiment_log_file,
            ))),
            checkpointer_strategy: Box::new(
                ComposedCheckpointingStrategy::builder()
                    .add(KeepLastNCheckpoints::new(2))
                    .add(MetricCheckpointingStrategy::new(
                        &LossMetric::<B>::new(), // default to valid loss
                        Aggregate::Mean,
                        Direction::Lowest,
                        Split::Valid,
                    ))
                    .build(),
            ),
            early_stopping: None,
            training_strategy: None,
            summary_metrics: BTreeSet::new(),
            summary: false,
            dataloader_train,
            dataloader_valid,
        }
    }
}

impl<LC: LearningComponentsTypes> SupervisedTraining<LC> {
    /// Replace the default training strategy (SingleDeviceTrainingStrategy) with the provided one.
    ///
    /// # Arguments
    ///
    /// * `training_strategy` - The training strategy.
    pub fn with_training_strategy(mut self, training_strategy: TrainingStrategy<LC>) -> Self {
        self.training_strategy = Some(training_strategy);
        self
    }

    /// Replace the default metric logger with the provided one.
    ///
    /// # Arguments
    ///
    /// * `logger` - The metric logger.
    pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
    where
        ML: MetricLogger + 'static,
    {
        self.event_store.register_logger(logger);
        self
    }

    /// Replace the default checkpointing strategy with the provided one.
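    ///
    /// # Example
    ///
    /// A sketch mirroring the default strategy built in `new`, which keeps the
    /// last two checkpoints plus the one with the lowest mean validation loss;
    /// the metric and aggregation choices are illustrative.
    ///
    /// ```rust,ignore
    /// let training = training.with_checkpointing_strategy(
    ///     ComposedCheckpointingStrategy::builder()
    ///         .add(KeepLastNCheckpoints::new(2))
    ///         .add(MetricCheckpointingStrategy::new(
    ///             &LossMetric::<B>::new(),
    ///             Aggregate::Mean,
    ///             Direction::Lowest,
    ///             Split::Valid,
    ///         ))
    ///         .build(),
    /// );
    /// ```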
    pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
        mut self,
        strategy: CS,
    ) -> Self {
        self.checkpointer_strategy = Box::new(strategy);
        self
    }

    /// Replace the default CLI renderer with a custom one.
    ///
    /// # Arguments
    ///
    /// * `renderer` - The custom renderer.
    pub fn renderer<MR>(mut self, renderer: MR) -> Self
    where
        MR: MetricsRenderer + 'static,
    {
        self.renderer = Some(Box::new(renderer));
        self
    }

    /// Register all metrics as numeric for the training and validation splits.
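    ///
    /// # Example
    ///
    /// A sketch, assuming an `AccuracyMetric` is available alongside the
    /// imported [`LossMetric`]; each metric in the tuple is registered for
    /// both splits.
    ///
    /// ```rust,ignore
    /// let training = training.metrics((LossMetric::new(), AccuracyMetric::new()));
    /// ```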
    pub fn metrics<Me: MetricRegistration<LC>>(self, metrics: Me) -> Self {
        metrics.register(self)
    }

    /// Register all metrics as text for the training and validation splits.
    pub fn metrics_text<Me: TextMetricRegistration<LC>>(self, metrics: Me) -> Self {
        metrics.register(self)
    }

    /// Register a training metric.
    pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
    where
        <TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.metrics.register_train_metric(metric);
        self
    }

    /// Register a validation metric.
    pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
    where
        <InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.metrics.register_valid_metric(metric);
        self
    }

    /// Enable gradient accumulation.
    ///
    /// # Notes
    ///
    /// When gradient accumulation is enabled, the gradients object used by the optimizer is
    /// the sum of all gradients generated by each backward pass. It might be a good idea to
    /// reduce the learning rate to compensate.
    ///
    /// The effect is similar to multiplying the `batch size` and the `learning rate` by the
    /// `accumulation` amount.
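    ///
    /// For example, with a dataloader batch size of 8, the following sketch
    /// yields an effective batch size of 32 (8 x 4):
    ///
    /// ```rust,ignore
    /// let training = training.grads_accumulation(4);
    /// ```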
    pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
        self.grad_accumulation = Some(accumulation);
        self
    }

    /// Register a [numeric](crate::metric::Numeric) training [metric](Metric).
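    ///
    /// For example, tracking the training loss with the [`LossMetric`] already
    /// imported by this module (a sketch):
    ///
    /// ```rust,ignore
    /// let training = training.metric_train_numeric(LossMetric::new());
    /// ```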
    pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
    where
        Me: Metric + Numeric + 'static,
        <TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.summary_metrics.insert(metric.name().to_string());
        self.metrics.register_train_metric_numeric(metric);
        self
    }

    /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).
    pub fn metric_valid_numeric<Me: Metric + Numeric + 'static>(mut self, metric: Me) -> Self
    where
        <InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.summary_metrics.insert(metric.name().to_string());
        self.metrics.register_valid_metric_numeric(metric);
        self
    }

    /// The number of epochs the training should last.
    pub fn num_epochs(mut self, num_epochs: usize) -> Self {
        self.num_epochs = num_epochs;
        self
    }

    /// The epoch from which the training must resume.
    pub fn checkpoint(mut self, checkpoint: usize) -> Self {
        self.checkpoint = Some(checkpoint);
        self
    }

    /// Provides a handle that can be used to interrupt training.
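    ///
    /// A sketch, assuming [`Interrupter`] exposes a `stop` method; check its
    /// docs for the exact API.
    ///
    /// ```rust,ignore
    /// let handle = training.interrupter();
    /// std::thread::spawn(move || {
    ///     // ... wait for some external condition ...
    ///     handle.stop();
    /// });
    /// ```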
    pub fn interrupter(&self) -> Interrupter {
        self.interrupter.clone()
    }

    /// Override the handle used to stop training with an externally provided one.
    pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
        self.interrupter = interrupter;
        self
    }

    /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when its
    /// conditions are met.
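    ///
    /// A sketch, assuming a `MetricEarlyStoppingStrategy` constructor of this
    /// shape; verify against the crate's early-stopping module.
    ///
    /// ```rust,ignore
    /// let training = training.early_stopping(MetricEarlyStoppingStrategy::new(
    ///     &LossMetric::<B>::new(),
    ///     Aggregate::Mean,
    ///     Direction::Lowest,
    ///     Split::Valid,
    ///     StoppingCondition::NoImprovementSince { n_epochs: 3 },
    /// ));
    /// ```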
    pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
    where
        Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,
    {
        self.early_stopping = Some(Box::new(strategy));
        self
    }

    /// By default, Rust logs are captured and written into
    /// `experiment.log`. If disabled, standard Rust log handling
    /// will apply.
    pub fn with_application_logger(
        mut self,
        logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    ) -> Self {
        self.tracing_logger = logger;
        self
    }

    /// Register a checkpointer that will save the [optimizer](Optimizer), the
    /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files.
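    ///
    /// A sketch, assuming a file recorder such as `CompactRecorder` from
    /// `burn_core::record`; any [`FileRecorder`] implementation works.
    ///
    /// ```rust,ignore
    /// let training = training.with_file_checkpointer(CompactRecorder::new());
    /// ```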
    pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
    where
        FR: FileRecorder<<LC as LearningComponentsTypes>::Backend> + 'static,
        FR: FileRecorder<
                <<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend,
            > + 'static,
    {
        let checkpoint_dir = self.directory.join("checkpoint");
        let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
        let checkpointer_optimizer =
            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
        let checkpointer_scheduler: FileCheckpointer<FR> =
            FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");

        self.checkpointers = Some((
            AsyncCheckpointer::new(checkpointer_model),
            AsyncCheckpointer::new(checkpointer_optimizer),
            AsyncCheckpointer::new(checkpointer_scheduler),
        ));

        self
    }

    /// Enable the training summary report.
    ///
    /// The summary will be displayed after `launch(..)` completes, when the renderer is dropped.
    pub fn summary(mut self) -> Self {
        self.summary = true;
        self
    }
}

impl<LC> SupervisedTraining<LC>
where
    LC: LearningComponentsTypes + Send + 'static,
{
    /// Launch this training with the given [Learner](Learner).
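    ///
    /// A sketch of a full run; `learner` is assumed to have been built
    /// elsewhere with a model, optimizer, and learning-rate scheduler.
    ///
    /// ```rust,ignore
    /// let result = training.launch(learner);
    /// // `result` wraps the trained inference model (see `LearningResult`).
    /// ```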
    pub fn launch(mut self, learner: Learner<LC>) -> LearningResult<InferenceModel<LC>> {
        if let Some(logger) = self.tracing_logger.as_ref()
            && let Err(e) = logger.install()
        {
            log::warn!("Failed to install the experiment logger: {e}");
        }
        let renderer = self
            .renderer
            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));

        if !self.event_store.has_loggers() {
            self.event_store
                .register_logger(FileMetricLogger::new(self.directory.clone()));
        }

        let event_store = Arc::new(EventStoreClient::new(self.event_store));
        let event_processor = AsyncProcessorTraining::new(FullEventProcessorTraining::new(
            self.metrics,
            renderer,
            event_store.clone(),
        ));

        let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
            LearningCheckpointer::new(
                model.with_interrupter(self.interrupter.clone()),
                optim.with_interrupter(self.interrupter.clone()),
                scheduler.with_interrupter(self.interrupter.clone()),
                self.checkpointer_strategy,
            )
        });

        let summary = if self.summary {
            Some(LearnerSummaryConfig {
                directory: self.directory,
                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
            })
        } else {
            None
        };

        let components = TrainingComponents {
            checkpoint: self.checkpoint,
            checkpointer,
            interrupter: self.interrupter,
            early_stopping: self.early_stopping,
            event_processor,
            event_store,
            num_epochs: self.num_epochs,
            grad_accumulation: self.grad_accumulation,
            summary,
        };

        // Default to single device based on model
        let training_strategy = self.training_strategy.unwrap_or(TrainingStrategy::Default(
            ExecutionStrategy::SingleDevice(learner.model.devices()[0].clone()),
        ));

        match training_strategy {
            TrainingStrategy::Custom(learning_paradigm) => learning_paradigm.train(
                learner,
                self.dataloader_train,
                self.dataloader_valid,
                components,
            ),
            TrainingStrategy::Default(strategy) => match strategy {
                ExecutionStrategy::SingleDevice(device) => {
                    let single_device: SingleDeviceTrainingStrategy<LC> =
                        SingleDeviceTrainingStrategy::new(device);
                    single_device.train(
                        learner,
                        self.dataloader_train,
                        self.dataloader_valid,
                        components,
                    )
                }
                ExecutionStrategy::MultiDevice(devices, multi_device_optim) => {
                    let strategy: Box<dyn SupervisedLearningStrategy<LC>> = if devices.len() == 1 {
                        Box::new(SingleDeviceTrainingStrategy::new(devices[0].clone()))
                    } else {
                        Box::new(MultiDeviceLearningStrategy::new(devices, multi_device_optim))
                    };
                    strategy.train(
                        learner,
                        self.dataloader_train,
                        self.dataloader_valid,
                        components,
                    )
                }
                #[cfg(feature = "ddp")]
                ExecutionStrategy::DistributedDataParallel { devices, runtime } => {
                    use crate::ddp::DdpTrainingStrategy;

                    let ddp = DdpTrainingStrategy::new(devices, runtime);
                    ddp.train(
                        learner,
                        self.dataloader_train,
                        self.dataloader_valid,
                        components,
                    )
                }
            },
        }
    }
}

/// Trait to fake variadic generics.
pub trait MetricRegistration<LC: LearningComponentsTypes>: Sized {
    /// Register the metrics.
    fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
}

/// Trait to fake variadic generics.
pub trait TextMetricRegistration<LC: LearningComponentsTypes>: Sized {
    /// Register the metrics.
    fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
}

macro_rules! gen_tuple {
    ($($M:ident),*) => {
        impl<$($M,)* LC: LearningComponentsTypes> TextMetricRegistration<LC> for ($($M,)*)
        where
            $(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: SupervisedTraining<LC>,
            ) -> SupervisedTraining<LC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_train($M.clone());)*
                $(let builder = builder.metric_valid($M);)*
                builder
            }
        }

        impl<$($M,)* LC: LearningComponentsTypes> MetricRegistration<LC> for ($($M,)*)
        where
            $(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: SupervisedTraining<LC>,
            ) -> SupervisedTraining<LC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_train_numeric($M.clone());)*
                $(let builder = builder.metric_valid_numeric($M);)*
                builder
            }
        }
    };
}

gen_tuple!(M1);
gen_tuple!(M1, M2);
gen_tuple!(M1, M2, M3);
gen_tuple!(M1, M2, M3, M4);
gen_tuple!(M1, M2, M3, M4, M5);
gen_tuple!(M1, M2, M3, M4, M5, M6);