burn_train/learner/
builder.rs

1use std::collections::HashSet;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use super::Learner;
6use crate::checkpoint::{
7    AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
8    KeepLastNCheckpoints, MetricCheckpointingStrategy,
9};
10use crate::components::LearnerComponentsMarker;
11use crate::learner::base::TrainingInterrupter;
12use crate::learner::EarlyStoppingStrategy;
13use crate::logger::{FileMetricLogger, MetricLogger};
14use crate::metric::processor::{AsyncProcessor, FullEventProcessor, ItemLazy, Metrics};
15use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
16use crate::metric::{Adaptor, LossMetric, Metric};
17use crate::renderer::{default_renderer, MetricsRenderer};
18use crate::{
19    ApplicationLoggerInstaller, FileApplicationLoggerInstaller, LearnerCheckpointer,
20    LearnerSummaryConfig,
21};
22use burn_core::lr_scheduler::LrScheduler;
23use burn_core::module::AutodiffModule;
24use burn_core::optim::Optimizer;
25use burn_core::record::FileRecorder;
26use burn_core::tensor::backend::AutodiffBackend;
27
28/// Struct to configure and create a [learner](Learner).
29pub struct LearnerBuilder<B, T, V, M, O, S>
30where
31    T: ItemLazy + 'static,
32    V: ItemLazy + 'static,
33    B: AutodiffBackend,
34    M: AutodiffModule<B>,
35    O: Optimizer<M, B>,
36    S: LrScheduler,
37{
38    // Not that complex and very convenient when the traits are
39    // already constrained correctly. Extracting in another type
40    // would be more complex.
41    #[allow(clippy::type_complexity)]
42    checkpointers: Option<(
43        AsyncCheckpointer<M::Record, B>,
44        AsyncCheckpointer<O::Record, B>,
45        AsyncCheckpointer<S::Record<B>, B>,
46    )>,
47    num_epochs: usize,
48    checkpoint: Option<usize>,
49    directory: PathBuf,
50    grad_accumulation: Option<usize>,
51    devices: Vec<B::Device>,
52    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
53    metrics: Metrics<T, V>,
54    event_store: LogEventStore,
55    interrupter: TrainingInterrupter,
56    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
57    num_loggers: usize,
58    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
59    early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
60    summary_metrics: HashSet<String>,
61    summary: bool,
62}
63
64impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
65where
66    B: AutodiffBackend,
67    T: ItemLazy + 'static,
68    V: ItemLazy + 'static,
69    M: AutodiffModule<B> + core::fmt::Display + 'static,
70    O: Optimizer<M, B>,
71    S: LrScheduler,
72{
73    /// Creates a new learner builder.
74    ///
75    /// # Arguments
76    ///
77    /// * `directory` - The directory to save the checkpoints.
78    pub fn new(directory: impl AsRef<Path>) -> Self {
79        let directory = directory.as_ref().to_path_buf();
80        let experiment_log_file = directory.join("experiment.log");
81        Self {
82            num_epochs: 1,
83            checkpoint: None,
84            checkpointers: None,
85            directory,
86            grad_accumulation: None,
87            devices: vec![B::Device::default()],
88            metrics: Metrics::default(),
89            event_store: LogEventStore::default(),
90            renderer: None,
91            interrupter: TrainingInterrupter::new(),
92            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
93                experiment_log_file,
94            ))),
95            num_loggers: 0,
96            checkpointer_strategy: Box::new(
97                ComposedCheckpointingStrategy::builder()
98                    .add(KeepLastNCheckpoints::new(2))
99                    .add(MetricCheckpointingStrategy::new::<LossMetric<B>>(
100                        Aggregate::Mean,
101                        Direction::Lowest,
102                        Split::Valid,
103                    ))
104                    .build(),
105            ),
106            early_stopping: None,
107            summary_metrics: HashSet::new(),
108            summary: false,
109        }
110    }
111
112    /// Replace the default metric loggers with the provided ones.
113    ///
114    /// # Arguments
115    ///
116    /// * `logger_train` - The training logger.
117    /// * `logger_valid` - The validation logger.
118    pub fn metric_loggers<MT, MV>(mut self, logger_train: MT, logger_valid: MV) -> Self
119    where
120        MT: MetricLogger + 'static,
121        MV: MetricLogger + 'static,
122    {
123        self.event_store.register_logger_train(logger_train);
124        self.event_store.register_logger_valid(logger_valid);
125        self.num_loggers += 1;
126        self
127    }
128
129    /// Update the checkpointing_strategy.
130    pub fn with_checkpointing_strategy<CS>(mut self, strategy: CS) -> Self
131    where
132        CS: CheckpointingStrategy + 'static,
133    {
134        self.checkpointer_strategy = Box::new(strategy);
135        self
136    }
137
138    /// Replace the default CLI renderer with a custom one.
139    ///
140    /// # Arguments
141    ///
142    /// * `renderer` - The custom renderer.
143    pub fn renderer<MR>(mut self, renderer: MR) -> Self
144    where
145        MR: MetricsRenderer + 'static,
146    {
147        self.renderer = Some(Box::new(renderer));
148        self
149    }
150
151    /// Register a training metric.
152    pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
153    where
154        T::ItemSync: Adaptor<Me::Input>,
155    {
156        self.metrics.register_train_metric(metric);
157        self
158    }
159
160    /// Register a validation metric.
161    pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
162    where
163        V::ItemSync: Adaptor<Me::Input>,
164    {
165        self.metrics.register_valid_metric(metric);
166        self
167    }
168
169    /// Enable gradients accumulation.
170    ///
171    /// # Notes
172    ///
173    /// When you enable gradients accumulation, the gradients object used by the optimizer will be
174    /// the sum of all gradients generated by each backward pass. It might be a good idea to
175    /// reduce the learning to compensate.
176    ///
177    /// The effect is similar to increasing the `batch size` and the `learning rate` by the `accumulation`
178    /// amount.
179    pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
180        self.grad_accumulation = Some(accumulation);
181        self
182    }
183
184    /// Register a [numeric](crate::metric::Numeric) training [metric](Metric).
185    pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
186    where
187        Me: Metric + crate::metric::Numeric + 'static,
188        T::ItemSync: Adaptor<Me::Input>,
189    {
190        self.summary_metrics.insert(Me::NAME.to_string());
191        self.metrics.register_train_metric_numeric(metric);
192        self
193    }
194
195    /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).
196    pub fn metric_valid_numeric<Me: Metric + crate::metric::Numeric + 'static>(
197        mut self,
198        metric: Me,
199    ) -> Self
200    where
201        V::ItemSync: Adaptor<Me::Input>,
202    {
203        self.summary_metrics.insert(Me::NAME.to_string());
204        self.metrics.register_valid_metric_numeric(metric);
205        self
206    }
207
208    /// The number of epochs the training should last.
209    pub fn num_epochs(mut self, num_epochs: usize) -> Self {
210        self.num_epochs = num_epochs;
211        self
212    }
213
214    /// Run the training loop on multiple devices.
215    pub fn devices(mut self, devices: Vec<B::Device>) -> Self {
216        self.devices = devices;
217        self
218    }
219
220    /// The epoch from which the training must resume.
221    pub fn checkpoint(mut self, checkpoint: usize) -> Self {
222        self.checkpoint = Some(checkpoint);
223        self
224    }
225
226    /// Provides a handle that can be used to interrupt training.
227    pub fn interrupter(&self) -> TrainingInterrupter {
228        self.interrupter.clone()
229    }
230
231    /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the
232    /// conditions are meet.
233    pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
234    where
235        Strategy: EarlyStoppingStrategy + 'static,
236    {
237        self.early_stopping = Some(Box::new(strategy));
238        self
239    }
240
241    /// By default, Rust logs are captured and written into
242    /// `experiment.log`. If disabled, standard Rust log handling
243    /// will apply.
244    pub fn with_application_logger(
245        mut self,
246        logger: Option<Box<dyn ApplicationLoggerInstaller>>,
247    ) -> Self {
248        self.tracing_logger = logger;
249        self
250    }
251
252    /// Register a checkpointer that will save the [optimizer](Optimizer), the
253    /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files.
254    pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
255    where
256        FR: FileRecorder<B> + 'static,
257        FR: FileRecorder<B::InnerBackend> + 'static,
258        O::Record: 'static,
259        M::Record: 'static,
260        S::Record<B>: 'static,
261    {
262        let checkpoint_dir = self.directory.join("checkpoint");
263        let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
264        let checkpointer_optimizer =
265            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
266        let checkpointer_scheduler: FileCheckpointer<FR> =
267            FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");
268
269        self.checkpointers = Some((
270            AsyncCheckpointer::new(checkpointer_model),
271            AsyncCheckpointer::new(checkpointer_optimizer),
272            AsyncCheckpointer::new(checkpointer_scheduler),
273        ));
274
275        self
276    }
277
278    /// Enable the training summary report.
279    ///
280    /// The summary will be displayed at the end of `.fit()`.
281    pub fn summary(mut self) -> Self {
282        self.summary = true;
283        self
284    }
285
286    /// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer).
287    /// The [learning rate scheduler](LrScheduler) can also be a simple
288    /// [learning rate](burn_core::LearningRate).
289    #[allow(clippy::type_complexity)] // The goal for the builder is to handle all types and
290                                      // creates a clean learner.
291    pub fn build(
292        mut self,
293        model: M,
294        optim: O,
295        lr_scheduler: S,
296    ) -> Learner<
297        LearnerComponentsMarker<
298            B,
299            S,
300            M,
301            O,
302            AsyncCheckpointer<M::Record, B>,
303            AsyncCheckpointer<O::Record, B>,
304            AsyncCheckpointer<S::Record<B>, B>,
305            AsyncProcessor<FullEventProcessor<T, V>>,
306            Box<dyn CheckpointingStrategy>,
307        >,
308    >
309    where
310        M::Record: 'static,
311        O::Record: 'static,
312        S::Record<B>: 'static,
313    {
314        if self.tracing_logger.is_some() {
315            if let Err(e) = self.tracing_logger.as_ref().unwrap().install() {
316                log::warn!("Failed to install the experiment logger: {}", e);
317            }
318        }
319        let renderer = self
320            .renderer
321            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
322
323        if self.num_loggers == 0 {
324            self.event_store
325                .register_logger_train(FileMetricLogger::new(self.directory.join("train")));
326            self.event_store
327                .register_logger_valid(FileMetricLogger::new(self.directory.join("valid")));
328        }
329
330        let event_store = Arc::new(EventStoreClient::new(self.event_store));
331        let event_processor = AsyncProcessor::new(FullEventProcessor::new(
332            self.metrics,
333            renderer,
334            event_store.clone(),
335        ));
336
337        let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
338            LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
339        });
340
341        let summary = if self.summary {
342            Some(LearnerSummaryConfig {
343                directory: self.directory,
344                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
345            })
346        } else {
347            None
348        };
349
350        Learner {
351            model,
352            optim,
353            lr_scheduler,
354            checkpointer,
355            num_epochs: self.num_epochs,
356            event_processor,
357            event_store,
358            checkpoint: self.checkpoint,
359            grad_accumulation: self.grad_accumulation,
360            devices: self.devices,
361            interrupter: self.interrupter,
362            early_stopping: self.early_stopping,
363            summary,
364        }
365    }
366}