burn_train/learner/supervised/paradigm.rs

use crate::checkpoint::{
    AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
    KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use crate::components::{InferenceModelOutput, TrainingModelOutput};
use crate::learner::EarlyStoppingStrategy;
use crate::learner::base::Interrupter;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::processor::{
    AsyncProcessorTraining, FullEventProcessorTraining, ItemLazy, MetricsTraining,
};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric, Numeric};
use crate::multi::MultiDeviceLearningStrategy;
use crate::renderer::{MetricsRenderer, default_renderer};
use crate::single::SingleDeviceTrainingStrategy;
use crate::{
    ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller,
    InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerModelRecord,
    LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig, LearningCheckpointer,
    LearningComponentsMarker, LearningComponentsTypes, LearningResult, TrainStep, TrainingBackend,
    TrainingComponents, TrainingModelInput, TrainingStrategy,
};
use crate::{Learner, SupervisedLearningStrategy};
use burn_core::data::dataloader::DataLoader;
use burn_core::module::AutodiffModule;
use burn_core::record::FileRecorder;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::Optimizer;
use burn_optim::lr_scheduler::LrScheduler;
use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;

/// A reference to the training split [DataLoader](DataLoader).
pub type TrainLoader<LC> = Arc<dyn DataLoader<TrainingBackend<LC>, TrainingModelInput<LC>>>;
/// A reference to the validation split [DataLoader](DataLoader).
pub type ValidLoader<LC> = Arc<dyn DataLoader<InferenceBackend<LC>, InferenceModelInput<LC>>>;
/// The event processor type for supervised learning.
pub type SupervisedTrainingEventProcessor<LC> = AsyncProcessorTraining<
    FullEventProcessorTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,
>;

/// Structure to configure and launch supervised training runs.
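///
/// # Example
///
/// A minimal sketch of the builder flow, assuming the dataloaders and a `learner`
/// (model, optimizer and learning-rate scheduler) have already been constructed:
///
/// ```rust,ignore
/// let training = SupervisedTraining::new("/tmp/experiment", dataloader_train, dataloader_valid)
///     .num_epochs(10)
///     .summary();
///
/// // Consumes the configuration and the learner, returning a `LearningResult`
/// // that wraps the trained inference model.
/// let result = training.launch(learner);
/// ```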
pub struct SupervisedTraining<LC>
where
    LC: LearningComponentsTypes,
{
    // Not that complex. Extracting into another type would only make it more confusing.
    #[allow(clippy::type_complexity)]
    checkpointers: Option<(
        AsyncCheckpointer<LearnerModelRecord<LC>, TrainingBackend<LC>>,
        AsyncCheckpointer<LearnerOptimizerRecord<LC>, TrainingBackend<LC>>,
        AsyncCheckpointer<LearnerSchedulerRecord<LC>, TrainingBackend<LC>>,
    )>,
    num_epochs: usize,
    checkpoint: Option<usize>,
    directory: PathBuf,
    grad_accumulation: Option<usize>,
    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
    metrics: MetricsTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,
    event_store: LogEventStore,
    interrupter: Interrupter,
    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
    early_stopping: Option<EarlyStoppingStrategyRef>,
    training_strategy: TrainingStrategy<LC>,
    dataloader_train: TrainLoader<LC>,
    dataloader_valid: ValidLoader<LC>,
    // Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
    summary_metrics: BTreeSet<String>,
    summary: bool,
}

impl<B, LR, M, O> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>
where
    B: AutodiffBackend,
    LR: LrScheduler + 'static,
    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
    M::InnerModule: InferenceStep,
    O: Optimizer<M, B> + 'static,
{
    /// Creates a new runner for a supervised training run.
    ///
    /// # Arguments
    ///
    /// * `directory` - The directory to save the checkpoints.
    /// * `dataloader_train` - The dataloader for the training split.
    /// * `dataloader_valid` - The dataloader for the validation split.
    pub fn new(
        directory: impl AsRef<Path>,
        dataloader_train: Arc<dyn DataLoader<B, M::Input>>,
        dataloader_valid: Arc<
            dyn DataLoader<B::InnerBackend, <M::InnerModule as InferenceStep>::Input>,
        >,
    ) -> Self {
        let directory = directory.as_ref().to_path_buf();
        let experiment_log_file = directory.join("experiment.log");
        Self {
            num_epochs: 1,
            checkpoint: None,
            checkpointers: None,
            directory,
            grad_accumulation: None,
            metrics: MetricsTraining::default(),
            event_store: LogEventStore::default(),
            renderer: None,
            interrupter: Interrupter::new(),
            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
                experiment_log_file,
            ))),
            checkpointer_strategy: Box::new(
                ComposedCheckpointingStrategy::builder()
                    .add(KeepLastNCheckpoints::new(2))
                    .add(MetricCheckpointingStrategy::new(
                        &LossMetric::<B>::new(), // default to valid loss
                        Aggregate::Mean,
                        Direction::Lowest,
                        Split::Valid,
                    ))
                    .build(),
            ),
            early_stopping: None,
            training_strategy: TrainingStrategy::SingleDevice(Default::default()),
            summary_metrics: BTreeSet::new(),
            summary: false,
            dataloader_train,
            dataloader_valid,
        }
    }
}

impl<LC: LearningComponentsTypes> SupervisedTraining<LC> {
    /// Replace the default training strategy (SingleDeviceTrainingStrategy) with the provided one.
    ///
    /// # Arguments
    ///
    /// * `training_strategy` - The training strategy.
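    ///
    /// A minimal sketch, assuming a backend device `device` is in scope:
    ///
    /// ```rust,ignore
    /// // Train on a single, explicitly chosen device instead of the default one.
    /// let training = training.with_training_strategy(TrainingStrategy::SingleDevice(device));
    /// ```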
    pub fn with_training_strategy(mut self, training_strategy: TrainingStrategy<LC>) -> Self {
        self.training_strategy = training_strategy;
        self
    }

    /// Replace the default metric logger with the provided one.
    ///
    /// # Arguments
    ///
    /// * `logger` - The metric logger.
    pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
    where
        ML: MetricLogger + 'static,
    {
        self.event_store.register_logger(logger);
        self
    }

    /// Update the checkpointing strategy.
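    ///
    /// A sketch mirroring the default strategy set up in [`SupervisedTraining::new`]
    /// (keep the last two checkpoints plus the one with the lowest mean validation loss);
    /// the metric used here is assumed to match a registered metric:
    ///
    /// ```rust,ignore
    /// let training = training.with_checkpointing_strategy(
    ///     ComposedCheckpointingStrategy::builder()
    ///         .add(KeepLastNCheckpoints::new(2))
    ///         .add(MetricCheckpointingStrategy::new(
    ///             &LossMetric::<B>::new(),
    ///             Aggregate::Mean,
    ///             Direction::Lowest,
    ///             Split::Valid,
    ///         ))
    ///         .build(),
    /// );
    /// ```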
    pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
        mut self,
        strategy: CS,
    ) -> Self {
        self.checkpointer_strategy = Box::new(strategy);
        self
    }

    /// Replace the default CLI renderer with a custom one.
    ///
    /// # Arguments
    ///
    /// * `renderer` - The custom renderer.
    pub fn renderer<MR>(mut self, renderer: MR) -> Self
    where
        MR: MetricsRenderer + 'static,
    {
        self.renderer = Some(Box::new(renderer));
        self
    }

    /// Register all metrics as numeric for both the training and validation splits.
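    ///
    /// A minimal sketch; `AccuracyMetric` is only illustrative and assumes such a metric
    /// is available and adapted to the model output:
    ///
    /// ```rust,ignore
    /// // Each metric in the tuple is registered for both splits.
    /// let training = training.metrics((LossMetric::new(), AccuracyMetric::new()));
    /// ```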
    pub fn metrics<Me: MetricRegistration<LC>>(self, metrics: Me) -> Self {
        metrics.register(self)
    }

    /// Register all non-numeric (text) metrics for both the training and validation splits.
    pub fn metrics_text<Me: TextMetricRegistration<LC>>(self, metrics: Me) -> Self {
        metrics.register(self)
    }

    /// Register a training metric.
    pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
    where
        <TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.metrics.register_train_metric(metric);
        self
    }

    /// Register a validation metric.
    pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
    where
        <InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.metrics.register_valid_metric(metric);
        self
    }

    /// Enable gradient accumulation.
    ///
    /// # Notes
    ///
    /// When you enable gradient accumulation, the gradients object used by the optimizer will be
    /// the sum of all gradients generated by each backward pass. It might be a good idea to
    /// reduce the learning rate to compensate.
    ///
    /// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation`
    /// amount.
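    ///
    /// A minimal sketch; accumulating over 4 steps makes the effective batch size roughly
    /// 4 times the dataloader batch size:
    ///
    /// ```rust,ignore
    /// let training = training.grads_accumulation(4);
    /// ```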
    pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
        self.grad_accumulation = Some(accumulation);
        self
    }

    /// Register a [numeric](crate::metric::Numeric) training [metric](Metric).
    pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
    where
        Me: Metric + Numeric + 'static,
        <TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.summary_metrics.insert(metric.name().to_string());
        self.metrics.register_train_metric_numeric(metric);
        self
    }

    /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).
    pub fn metric_valid_numeric<Me: Metric + Numeric + 'static>(mut self, metric: Me) -> Self
    where
        <InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.summary_metrics.insert(metric.name().to_string());
        self.metrics.register_valid_metric_numeric(metric);
        self
    }

    /// The number of epochs the training should last.
    pub fn num_epochs(mut self, num_epochs: usize) -> Self {
        self.num_epochs = num_epochs;
        self
    }

    /// The epoch from which the training must resume.
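    ///
    /// A minimal sketch; resuming assumes matching checkpoint files were previously saved
    /// (e.g. via [Self::with_file_checkpointer]):
    ///
    /// ```rust,ignore
    /// // Resume training from the checkpoint saved at the end of epoch 10.
    /// let training = training.checkpoint(10);
    /// ```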
    pub fn checkpoint(mut self, checkpoint: usize) -> Self {
        self.checkpoint = Some(checkpoint);
        self
    }

    /// Provides a handle that can be used to interrupt training.
    pub fn interrupter(&self) -> Interrupter {
        self.interrupter.clone()
    }

    /// Override the handle for stopping training with an externally provided handle.
    pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
        self.interrupter = interrupter;
        self
    }

    /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the
    /// conditions are met.
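    ///
    /// A minimal sketch; `MyEarlyStopping` is a hypothetical type implementing
    /// [EarlyStoppingStrategy] (plus `Clone + Send + Sync`):
    ///
    /// ```rust,ignore
    /// // Stop training as soon as the strategy reports that its conditions are met.
    /// let training = training.early_stopping(MyEarlyStopping::default());
    /// ```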
    pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
    where
        Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,
    {
        self.early_stopping = Some(Box::new(strategy));
        self
    }

    /// By default, Rust logs are captured and written into
    /// `experiment.log`. If disabled, standard Rust log handling
    /// will apply.
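    ///
    /// A minimal sketch; passing `None` disables log capture entirely:
    ///
    /// ```rust,ignore
    /// let training = training.with_application_logger(None);
    /// ```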
    pub fn with_application_logger(
        mut self,
        logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    ) -> Self {
        self.tracing_logger = logger;
        self
    }

    /// Register a checkpointer that will save the [optimizer](Optimizer), the
    /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files.
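    ///
    /// A minimal sketch, assuming a `CompactRecorder`-style [FileRecorder] is available for
    /// the chosen backend:
    ///
    /// ```rust,ignore
    /// // Checkpoints are written under `<directory>/checkpoint/`.
    /// let training = training.with_file_checkpointer(CompactRecorder::new());
    /// ```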
    pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
    where
        FR: FileRecorder<<LC as LearningComponentsTypes>::Backend> + 'static,
        FR: FileRecorder<
                <<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend,
            > + 'static,
    {
        let checkpoint_dir = self.directory.join("checkpoint");
        let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
        let checkpointer_optimizer =
            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
        let checkpointer_scheduler: FileCheckpointer<FR> =
            FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");

        self.checkpointers = Some((
            AsyncCheckpointer::new(checkpointer_model),
            AsyncCheckpointer::new(checkpointer_optimizer),
            AsyncCheckpointer::new(checkpointer_scheduler),
        ));

        self
    }

    /// Enable the training summary report.
    ///
    /// The summary will be displayed once training completes, when the renderer is dropped.
    pub fn summary(mut self) -> Self {
        self.summary = true;
        self
    }
}

impl<LC: LearningComponentsTypes + Send + 'static> SupervisedTraining<LC> {
    /// Launch this training with the given [Learner](Learner).
    pub fn launch(mut self, learner: Learner<LC>) -> LearningResult<InferenceModel<LC>> {
        if self.tracing_logger.is_some()
            && let Err(e) = self.tracing_logger.as_ref().unwrap().install()
        {
            log::warn!("Failed to install the experiment logger: {e}");
        }
        let renderer = self
            .renderer
            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));

        if !self.event_store.has_loggers() {
            self.event_store
                .register_logger(FileMetricLogger::new(self.directory.clone()));
        }

        let event_store = Arc::new(EventStoreClient::new(self.event_store));
        let event_processor = AsyncProcessorTraining::new(FullEventProcessorTraining::new(
            self.metrics,
            renderer,
            event_store.clone(),
        ));

        let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
            LearningCheckpointer::new(
                model.with_interrupter(self.interrupter.clone()),
                optim.with_interrupter(self.interrupter.clone()),
                scheduler.with_interrupter(self.interrupter.clone()),
                self.checkpointer_strategy,
            )
        });

        let summary = if self.summary {
            Some(LearnerSummaryConfig {
                directory: self.directory,
                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
            })
        } else {
            None
        };

        let components = TrainingComponents {
            checkpoint: self.checkpoint,
            checkpointer,
            interrupter: self.interrupter,
            early_stopping: self.early_stopping,
            event_processor,
            event_store,
            num_epochs: self.num_epochs,
            grad_accumulation: self.grad_accumulation,
            summary,
        };

        match self.training_strategy {
            TrainingStrategy::SingleDevice(device) => {
                let single_device: SingleDeviceTrainingStrategy<LC> =
                    SingleDeviceTrainingStrategy::new(device);
                single_device.train(
                    learner,
                    self.dataloader_train,
                    self.dataloader_valid,
                    components,
                )
            }
            TrainingStrategy::Custom(learning_paradigm) => learning_paradigm.train(
                learner,
                self.dataloader_train,
                self.dataloader_valid,
                components,
            ),
            TrainingStrategy::MultiDevice(devices, multi_device_optim) => {
                let multi_device = MultiDeviceLearningStrategy::new(devices, multi_device_optim);
                multi_device.train(
                    learner,
                    self.dataloader_train,
                    self.dataloader_valid,
                    components,
                )
            }
            #[cfg(feature = "ddp")]
            TrainingStrategy::DistributedDataParallel { devices, config } => {
                use crate::ddp::DdpTrainingStrategy;

                let ddp = DdpTrainingStrategy::new(devices.clone(), config.clone());
                ddp.train(
                    learner,
                    self.dataloader_train,
                    self.dataloader_valid,
                    components,
                )
            }
        }
    }
}

/// Trait to fake variadic generics.
pub trait MetricRegistration<LC: LearningComponentsTypes>: Sized {
    /// Register the metrics.
    fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
}

/// Trait to fake variadic generics.
pub trait TextMetricRegistration<LC: LearningComponentsTypes>: Sized {
    /// Register the metrics.
    fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
}

macro_rules! gen_tuple {
    ($($M:ident),*) => {
        impl<$($M,)* LC: LearningComponentsTypes> TextMetricRegistration<LC> for ($($M,)*)
        where
            $(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: SupervisedTraining<LC>,
            ) -> SupervisedTraining<LC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_train($M.clone());)*
                $(let builder = builder.metric_valid($M);)*
                builder
            }
        }

        impl<$($M,)* LC: LearningComponentsTypes> MetricRegistration<LC> for ($($M,)*)
        where
            $(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: SupervisedTraining<LC>,
            ) -> SupervisedTraining<LC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_train_numeric($M.clone());)*
                $(let builder = builder.metric_valid_numeric($M);)*
                builder
            }
        }
    };
}

gen_tuple!(M1);
gen_tuple!(M1, M2);
gen_tuple!(M1, M2, M3);
gen_tuple!(M1, M2, M3, M4);
gen_tuple!(M1, M2, M3, M4, M5);
gen_tuple!(M1, M2, M3, M4, M5, M6);