burn_train/learner/builder.rs

use std::collections::BTreeSet;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use super::Learner;
use crate::checkpoint::{
    AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
    KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use crate::components::{LearnerComponentsMarker, LearningDataMarker};
use crate::learner::EarlyStoppingStrategy;
use crate::learner::base::Interrupter;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::processor::{
    AsyncProcessorTraining, FullEventProcessorTraining, ItemLazy, MetricsTraining,
};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{MetricsRenderer, default_renderer};
use crate::{
    ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller,
    LearnerCheckpointer, LearnerSummaryConfig, LearningStrategy, TrainStep, ValidStep,
};
use burn_core::module::AutodiffModule;
use burn_core::record::FileRecorder;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::Optimizer;
use burn_optim::lr_scheduler::LrScheduler;

/// Struct to configure and create a [learner](Learner).
///
/// The generic components of the builder should not normally be set manually, as they are
/// optimized for Rust type inference.
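///
/// # Example
///
/// A minimal usage sketch, assuming a `model`, `optim`, and `lr_scheduler` that satisfy the
/// trait bounds below (the names and the artifact directory are illustrative only):
///
/// ```ignore
/// let learner = LearnerBuilder::new("/tmp/my-experiment")
///     // Track the loss on both splits so it appears in the renderer and the summary.
///     .metric_train_numeric(LossMetric::new())
///     .metric_valid_numeric(LossMetric::new())
///     .num_epochs(10)
///     .summary()
///     .build(model, optim, lr_scheduler);
/// ```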
pub struct LearnerBuilder<B, M, O, S, TI, VI, TO, VO>
where
    B: AutodiffBackend,
    M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
    M::InnerModule: ValidStep<VI, VO>,
    O: Optimizer<M, B>,
    S: LrScheduler,
    TI: Send + 'static,
    VI: Send + 'static,
    TO: ItemLazy + 'static,
    VO: ItemLazy + 'static,
{
    // Not that complex and very convenient when the traits are
    // already constrained correctly. Extracting this into another type
    // would be more complex.
    #[allow(clippy::type_complexity)]
    checkpointers: Option<(
        AsyncCheckpointer<M::Record, B>,
        AsyncCheckpointer<O::Record, B>,
        AsyncCheckpointer<S::Record<B>, B>,
    )>,
    num_epochs: usize,
    checkpoint: Option<usize>,
    directory: PathBuf,
    grad_accumulation: Option<usize>,
    learning_strategy: LearningStrategy<B>,
    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
    metrics: MetricsTraining<TO, VO>,
    event_store: LogEventStore,
    interrupter: Interrupter,
    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    num_loggers: usize,
    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
    early_stopping: Option<EarlyStoppingStrategyRef>,
    // Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
    summary_metrics: BTreeSet<String>,
    summary: bool,
    _p: PhantomData<(TI, VI, TO, VO)>,
}

impl<B, M, O, S, TI, VI, TO, VO> LearnerBuilder<B, M, O, S, TI, VI, TO, VO>
where
    B: AutodiffBackend,
    M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
    M::InnerModule: ValidStep<VI, VO>,
    O: Optimizer<M, B>,
    S: LrScheduler,
    TI: Send + 'static,
    VI: Send + 'static,
    TO: ItemLazy + 'static,
    VO: ItemLazy + 'static,
{
    /// Creates a new learner builder.
    ///
    /// # Arguments
    ///
    /// * `directory` - The directory to save the checkpoints.
    pub fn new(directory: impl AsRef<Path>) -> Self {
        let directory = directory.as_ref().to_path_buf();
        let experiment_log_file = directory.join("experiment.log");
        Self {
            num_epochs: 1,
            checkpoint: None,
            checkpointers: None,
            directory,
            grad_accumulation: None,
            learning_strategy: LearningStrategy::default(),
            metrics: MetricsTraining::default(),
            event_store: LogEventStore::default(),
            renderer: None,
            interrupter: Interrupter::new(),
            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
                experiment_log_file,
            ))),
            num_loggers: 0,
            checkpointer_strategy: Box::new(
                ComposedCheckpointingStrategy::builder()
                    .add(KeepLastNCheckpoints::new(2))
                    .add(MetricCheckpointingStrategy::new(
                        &LossMetric::<B>::new(), // default to valid loss
                        Aggregate::Mean,
                        Direction::Lowest,
                        Split::Valid,
                    ))
                    .build(),
            ),
            early_stopping: None,
            summary_metrics: BTreeSet::new(),
            summary: false,
            _p: PhantomData,
        }
    }

    /// Replace the default metric loggers with the provided ones.
    ///
    /// # Arguments
    ///
    /// * `logger_train` - The training logger.
    /// * `logger_valid` - The validation logger.
    pub fn metric_loggers<MT, MV>(mut self, logger_train: MT, logger_valid: MV) -> Self
    where
        MT: MetricLogger + 'static,
        MV: MetricLogger + 'static,
    {
        self.event_store.register_logger_train(logger_train);
        self.event_store.register_logger_valid(logger_valid);
        self.num_loggers += 1;
        self
    }

    /// Update the checkpointing strategy.
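    ///
    /// A sketch mirroring the default configured in [`LearnerBuilder::new`] (keep the last two
    /// checkpoints plus the best one according to the mean validation loss); adjust as needed:
    ///
    /// ```ignore
    /// let builder = builder.with_checkpointing_strategy(
    ///     ComposedCheckpointingStrategy::builder()
    ///         .add(KeepLastNCheckpoints::new(2))
    ///         .add(MetricCheckpointingStrategy::new(
    ///             &LossMetric::<B>::new(),
    ///             Aggregate::Mean,
    ///             Direction::Lowest,
    ///             Split::Valid,
    ///         ))
    ///         .build(),
    /// );
    /// ```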
    pub fn with_checkpointing_strategy<CS>(mut self, strategy: CS) -> Self
    where
        CS: CheckpointingStrategy + 'static,
    {
        self.checkpointer_strategy = Box::new(strategy);
        self
    }

    /// Replace the default CLI renderer with a custom one.
    ///
    /// # Arguments
    ///
    /// * `renderer` - The custom renderer.
    pub fn renderer<MR>(mut self, renderer: MR) -> Self
    where
        MR: MetricsRenderer + 'static,
    {
        self.renderer = Some(Box::new(renderer));
        self
    }

    /// Register all metrics as numeric for the training and validation set.
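    ///
    /// A sketch, assuming the training and validation outputs can be adapted to the loss
    /// metric's input; metrics are passed as a tuple, which emulates variadic arguments:
    ///
    /// ```ignore
    /// let builder = builder.metrics((LossMetric::new(),));
    /// ```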
    pub fn metrics<Me: MetricRegistration<B, M, O, S, TI, VI, TO, VO>>(self, metrics: Me) -> Self {
        metrics.register(self)
    }

    /// Register all metrics for the training and validation set, without numeric tracking.
    pub fn metrics_text<Me: TextMetricRegistration<B, M, O, S, TI, VI, TO, VO>>(
        self,
        metrics: Me,
    ) -> Self {
        metrics.register(self)
    }

    /// Register a training metric.
    pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
    where
        <TO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.metrics.register_train_metric(metric);
        self
    }

    /// Register a validation metric.
    pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
    where
        <VO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.metrics.register_valid_metric(metric);
        self
    }

    /// Enable gradient accumulation.
    ///
    /// # Notes
    ///
    /// When you enable gradient accumulation, the gradients object used by the optimizer will be
    /// the sum of all gradients generated by each backward pass. It might be a good idea to
    /// reduce the learning rate to compensate.
    ///
    /// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation`
    /// amount.
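    ///
    /// For instance (an illustrative sketch, not tied to any particular dataloader): with a
    /// batch size of 8 and an accumulation of 4, the optimizer steps once every 4 batches, so
    /// each update aggregates gradients computed over 32 samples.
    ///
    /// ```ignore
    /// let builder = builder.grads_accumulation(4);
    /// ```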
    pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
        self.grad_accumulation = Some(accumulation);
        self
    }

    /// Register a [numeric](crate::metric::Numeric) training [metric](Metric).
    pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
    where
        Me: Metric + crate::metric::Numeric + 'static,
        <TO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.summary_metrics.insert(metric.name().to_string());
        self.metrics.register_train_metric_numeric(metric);
        self
    }

    /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).
    pub fn metric_valid_numeric<Me: Metric + crate::metric::Numeric + 'static>(
        mut self,
        metric: Me,
    ) -> Self
    where
        <VO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
    {
        self.summary_metrics.insert(metric.name().to_string());
        self.metrics.register_valid_metric_numeric(metric);
        self
    }

    /// The number of epochs the training should last.
    pub fn num_epochs(mut self, num_epochs: usize) -> Self {
        self.num_epochs = num_epochs;
        self
    }

    /// Set the [learning strategy](LearningStrategy) used to run the training loop.
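    ///
    /// A sketch, assuming `device` and `devices` are backend devices available at runtime
    /// (a single-device list passed to the naive multi-device strategy is normalized back to
    /// [`LearningStrategy::SingleDevice`] when the learner is built):
    ///
    /// ```ignore
    /// let builder = builder.learning_strategy(LearningStrategy::SingleDevice(device));
    /// // or
    /// let builder = builder.learning_strategy(LearningStrategy::MultiDeviceNaive(devices));
    /// ```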
    pub fn learning_strategy(mut self, learning_strategy: LearningStrategy<B>) -> Self {
        self.learning_strategy = learning_strategy;
        self
    }

    /// The epoch from which the training must resume.
    pub fn checkpoint(mut self, checkpoint: usize) -> Self {
        self.checkpoint = Some(checkpoint);
        self
    }

    /// Provides a handle that can be used to interrupt training.
    pub fn interrupter(&self) -> Interrupter {
        self.interrupter.clone()
    }

    /// Override the handle for stopping training with an externally provided handle.
    pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
        self.interrupter = interrupter;
        self
    }

    /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the
    /// conditions are met.
    pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
    where
        Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,
    {
        self.early_stopping = Some(Box::new(strategy));
        self
    }

    /// By default, Rust logs are captured and written into
    /// `experiment.log`. If disabled, standard Rust log handling
    /// will apply.
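    ///
    /// A sketch of disabling the file-based application logger (passing `None` leaves log
    /// handling to whatever the application has installed):
    ///
    /// ```ignore
    /// let builder = builder.with_application_logger(None);
    /// ```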
    pub fn with_application_logger(
        mut self,
        logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    ) -> Self {
        self.tracing_logger = logger;
        self
    }

    /// Register a checkpointer that will save the [optimizer](Optimizer), the
    /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files.
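    ///
    /// A sketch, assuming a [`FileRecorder`] implementation such as `CompactRecorder` from
    /// `burn_core::record` (any recorder valid for both the autodiff backend and its inner
    /// backend works):
    ///
    /// ```ignore
    /// let builder = builder.with_file_checkpointer(CompactRecorder::new());
    /// ```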
    pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
    where
        FR: FileRecorder<B> + 'static,
        FR: FileRecorder<B::InnerBackend> + 'static,
        O::Record: 'static,
        M::Record: 'static,
        S::Record<B>: 'static,
    {
        let checkpoint_dir = self.directory.join("checkpoint");
        let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
        let checkpointer_optimizer =
            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
        let checkpointer_scheduler: FileCheckpointer<FR> =
            FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");

        self.checkpointers = Some((
            AsyncCheckpointer::new(checkpointer_model),
            AsyncCheckpointer::new(checkpointer_optimizer),
            AsyncCheckpointer::new(checkpointer_scheduler),
        ));

        self
    }

    /// Enable the training summary report.
    ///
    /// The summary will be displayed after `.fit()`, when the renderer is dropped.
    pub fn summary(mut self) -> Self {
        self.summary = true;
        self
    }

    /// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer).
    /// The [learning rate scheduler](LrScheduler) can also be a simple
    /// [learning rate](burn_optim::LearningRate).
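    ///
    /// A sketch, assuming `model` and `optim` satisfy the bounds above; since a plain
    /// [learning rate](burn_optim::LearningRate) can stand in for a scheduler, a constant
    /// float can be passed directly:
    ///
    /// ```ignore
    /// let learner = builder.build(model, optim, 1e-3);
    /// ```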
    #[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and
    // create a clean learner.
    pub fn build(
        mut self,
        model: M,
        optim: O,
        lr_scheduler: S,
    ) -> Learner<
        LearnerComponentsMarker<
            B,
            S,
            M,
            O,
            AsyncCheckpointer<M::Record, B>,
            AsyncCheckpointer<O::Record, B>,
            AsyncCheckpointer<S::Record<B>, B>,
            AsyncProcessorTraining<FullEventProcessorTraining<TO, VO>>,
            Box<dyn CheckpointingStrategy>,
            LearningDataMarker<TI, VI, TO, VO>,
        >,
    >
    where
        M::Record: 'static,
        O::Record: 'static,
        S::Record<B>: 'static,
    {
        if let Some(logger) = self.tracing_logger.as_ref()
            && let Err(e) = logger.install()
        {
            log::warn!("Failed to install the experiment logger: {e}");
        }
        let renderer = self
            .renderer
            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));

        if self.num_loggers == 0 {
            self.event_store
                .register_logger_train(FileMetricLogger::new_train(self.directory.join("train")));
            self.event_store
                .register_logger_valid(FileMetricLogger::new_train(self.directory.join("valid")));
        }

        let event_store = Arc::new(EventStoreClient::new(self.event_store));
        let event_processor = AsyncProcessorTraining::new(FullEventProcessorTraining::new(
            self.metrics,
            renderer,
            event_store.clone(),
        ));

        let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
            LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
        });

        let summary = if self.summary {
            Some(LearnerSummaryConfig {
                directory: self.directory,
                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
            })
        } else {
            None
        };

        let learning_strategy = Self::prepare_learning_strategy(self.learning_strategy);

        Learner {
            model,
            optim,
            lr_scheduler,
            checkpointer,
            num_epochs: self.num_epochs,
            event_processor,
            event_store,
            checkpoint: self.checkpoint,
            grad_accumulation: self.grad_accumulation,
            learning_strategy,
            interrupter: self.interrupter,
            early_stopping: self.early_stopping,
            summary,
        }
    }

    fn prepare_learning_strategy(learning_strategy: LearningStrategy<B>) -> LearningStrategy<B> {
        if let LearningStrategy::MultiDeviceNaive(devices) = &learning_strategy
            && devices.len() == 1
        {
            return LearningStrategy::SingleDevice(devices[0].clone());
        }

        learning_strategy
    }
}

/// Trait to fake variadic generics.
pub trait MetricRegistration<B, M, O, S, TI, VI, TO, VO>: Sized
where
    B: AutodiffBackend,
    M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
    M::InnerModule: ValidStep<VI, VO>,
    O: Optimizer<M, B>,
    S: LrScheduler,
    TI: Send + 'static,
    VI: Send + 'static,
    TO: ItemLazy + 'static,
    VO: ItemLazy + 'static,
{
    /// Register the metrics.
    fn register(
        self,
        builder: LearnerBuilder<B, M, O, S, TI, VI, TO, VO>,
    ) -> LearnerBuilder<B, M, O, S, TI, VI, TO, VO>;
}

/// Trait to fake variadic generics.
pub trait TextMetricRegistration<B, M, O, S, TI, VI, TO, VO>: Sized
where
    B: AutodiffBackend,
    M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
    M::InnerModule: ValidStep<VI, VO>,
    O: Optimizer<M, B>,
    S: LrScheduler,
    TI: Send + 'static,
    VI: Send + 'static,
    TO: ItemLazy + 'static,
    VO: ItemLazy + 'static,
{
    /// Register the metrics.
    fn register(
        self,
        builder: LearnerBuilder<B, M, O, S, TI, VI, TO, VO>,
    ) -> LearnerBuilder<B, M, O, S, TI, VI, TO, VO>;
}

macro_rules! gen_tuple {
    ($($M:ident),*) => {
        impl<$($M,)* B, M, O, S, TI, VI, TO, VO> TextMetricRegistration<B, M, O, S, TI, VI, TO, VO> for ($($M,)*)
        where
            B: AutodiffBackend,
            M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
            M::InnerModule: ValidStep<VI, VO>,
            O: Optimizer<M, B>,
            S: LrScheduler,
            TI: Send + 'static,
            VI: Send + 'static,
            TO: ItemLazy + 'static,
            VO: ItemLazy + 'static,
            $(TO::ItemSync: Adaptor<$M::Input>,)*
            $(VO::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: LearnerBuilder<B, M, O, S, TI, VI, TO, VO>,
            ) -> LearnerBuilder<B, M, O, S, TI, VI, TO, VO> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_train($M.clone());)*
                $(let builder = builder.metric_valid($M);)*
                builder
            }
        }

        impl<$($M,)* B, M, O, S, TI, VI, TO, VO> MetricRegistration<B, M, O, S, TI, VI, TO, VO> for ($($M,)*)
        where
            B: AutodiffBackend,
            M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
            M::InnerModule: ValidStep<VI, VO>,
            O: Optimizer<M, B>,
            S: LrScheduler,
            TI: Send + 'static,
            VI: Send + 'static,
            TO: ItemLazy + 'static,
            VO: ItemLazy + 'static,
            $(TO::ItemSync: Adaptor<$M::Input>,)*
            $(VO::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + $crate::metric::Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: LearnerBuilder<B, M, O, S, TI, VI, TO, VO>,
            ) -> LearnerBuilder<B, M, O, S, TI, VI, TO, VO> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_train_numeric($M.clone());)*
                $(let builder = builder.metric_valid_numeric($M);)*
                builder
            }
        }
    };
}

gen_tuple!(M1);
gen_tuple!(M1, M2);
gen_tuple!(M1, M2, M3);
gen_tuple!(M1, M2, M3, M4);
gen_tuple!(M1, M2, M3, M4, M5);
gen_tuple!(M1, M2, M3, M4, M5, M6);