burn_train/learner/builder.rs

use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use super::Learner;
use crate::checkpoint::{
    AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
    KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use crate::components::LearnerComponentsMarker;
use crate::learner::EarlyStoppingStrategy;
use crate::learner::base::TrainingInterrupter;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::processor::{AsyncProcessor, FullEventProcessor, ItemLazy, Metrics};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{MetricsRenderer, default_renderer};
use crate::{
    ApplicationLoggerInstaller, FileApplicationLoggerInstaller, LearnerCheckpointer,
    LearnerSummaryConfig,
};
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::AutodiffModule;
use burn_core::optim::Optimizer;
use burn_core::record::FileRecorder;
use burn_core::tensor::backend::AutodiffBackend;

/// Struct to configure and create a [learner](Learner).
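///
/// # Example
///
/// A minimal sketch of the builder flow; `my_model` and `my_optim` are
/// placeholders, and the constant `1e-3` is used as a simple learning rate:
///
/// ```rust,ignore
/// let learner = LearnerBuilder::new("/tmp/my-experiment")
///     .num_epochs(10)
///     .build(my_model, my_optim, 1e-3);
/// ```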
pub struct LearnerBuilder<B, T, V, M, O, S>
where
    T: ItemLazy + 'static,
    V: ItemLazy + 'static,
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: Optimizer<M, B>,
    S: LrScheduler,
{
    // This type is not that complex, and it is very convenient when the traits
    // are already constrained correctly. Extracting it into another type would
    // add more complexity than it removes.
    #[allow(clippy::type_complexity)]
    checkpointers: Option<(
        AsyncCheckpointer<M::Record, B>,
        AsyncCheckpointer<O::Record, B>,
        AsyncCheckpointer<S::Record<B>, B>,
    )>,
    num_epochs: usize,
    checkpoint: Option<usize>,
    directory: PathBuf,
    grad_accumulation: Option<usize>,
    devices: Vec<B::Device>,
    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
    metrics: Metrics<T, V>,
    event_store: LogEventStore,
    interrupter: TrainingInterrupter,
    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    num_loggers: usize,
    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
    early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
    // Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
    summary_metrics: BTreeSet<String>,
    summary: bool,
}

impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
where
    B: AutodiffBackend,
    T: ItemLazy + 'static,
    V: ItemLazy + 'static,
    M: AutodiffModule<B> + core::fmt::Display + 'static,
    O: Optimizer<M, B>,
    S: LrScheduler,
{
    /// Creates a new learner builder.
    ///
    /// # Arguments
    ///
    /// * `directory` - The directory to save the checkpoints.
    pub fn new(directory: impl AsRef<Path>) -> Self {
        let directory = directory.as_ref().to_path_buf();
        let experiment_log_file = directory.join("experiment.log");
        Self {
            num_epochs: 1,
            checkpoint: None,
            checkpointers: None,
            directory,
            grad_accumulation: None,
            devices: vec![B::Device::default()],
            metrics: Metrics::default(),
            event_store: LogEventStore::default(),
            renderer: None,
            interrupter: TrainingInterrupter::new(),
            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
                experiment_log_file,
            ))),
            num_loggers: 0,
            checkpointer_strategy: Box::new(
                ComposedCheckpointingStrategy::builder()
                    .add(KeepLastNCheckpoints::new(2))
                    .add(MetricCheckpointingStrategy::new(
                        &LossMetric::<B>::new(), // default to valid loss
                        Aggregate::Mean,
                        Direction::Lowest,
                        Split::Valid,
                    ))
                    .build(),
            ),
            early_stopping: None,
            summary_metrics: BTreeSet::new(),
            summary: false,
        }
    }

    /// Replace the default metric loggers with the provided ones.
    ///
    /// # Arguments
    ///
    /// * `logger_train` - The training logger.
    /// * `logger_valid` - The validation logger.
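    ///
    /// # Example
    ///
    /// A minimal sketch using [`FileMetricLogger`] (the same logger type that
    /// `build` falls back to); the paths are placeholders:
    ///
    /// ```rust,ignore
    /// let builder = builder.metric_loggers(
    ///     FileMetricLogger::new(PathBuf::from("/tmp/my-experiment/train")),
    ///     FileMetricLogger::new(PathBuf::from("/tmp/my-experiment/valid")),
    /// );
    /// ```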
    pub fn metric_loggers<MT, MV>(mut self, logger_train: MT, logger_valid: MV) -> Self
    where
        MT: MetricLogger + 'static,
        MV: MetricLogger + 'static,
    {
        self.event_store.register_logger_train(logger_train);
        self.event_store.register_logger_valid(logger_valid);
        self.num_loggers += 1;
        self
    }

    /// Replace the default [checkpointing strategy](CheckpointingStrategy).
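    ///
    /// # Example
    ///
    /// A minimal sketch that keeps only the last checkpoint, using the same
    /// [`KeepLastNCheckpoints`] strategy that the default configuration composes:
    ///
    /// ```rust,ignore
    /// let builder = builder.with_checkpointing_strategy(KeepLastNCheckpoints::new(1));
    /// ```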
    pub fn with_checkpointing_strategy<CS>(mut self, strategy: CS) -> Self
    where
        CS: CheckpointingStrategy + 'static,
    {
        self.checkpointer_strategy = Box::new(strategy);
        self
    }

    /// Replace the default CLI renderer with a custom one.
    ///
    /// # Arguments
    ///
    /// * `renderer` - The custom renderer.
    pub fn renderer<MR>(mut self, renderer: MR) -> Self
    where
        MR: MetricsRenderer + 'static,
    {
        self.renderer = Some(Box::new(renderer));
        self
    }

    /// Register a training metric.
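    ///
    /// # Example
    ///
    /// A minimal sketch registering the [`LossMetric`] for training (the same
    /// pattern applies to `metric_valid` and the `*_numeric` variants):
    ///
    /// ```rust,ignore
    /// let builder = builder.metric_train(LossMetric::new());
    /// ```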
    pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
    where
        T::ItemSync: Adaptor<Me::Input>,
    {
        self.metrics.register_train_metric(metric);
        self
    }

    /// Register a validation metric.
    pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
    where
        V::ItemSync: Adaptor<Me::Input>,
    {
        self.metrics.register_valid_metric(metric);
        self
    }

    /// Enable gradient accumulation.
    ///
    /// # Notes
    ///
    /// When gradient accumulation is enabled, the gradients object used by the optimizer will be
    /// the sum of all gradients generated by each backward pass. It might be a good idea to
    /// reduce the learning rate to compensate.
    ///
    /// The effect is similar to multiplying the `batch size` and the `learning rate` by the
    /// `accumulation` amount.
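    ///
    /// # Example
    ///
    /// A minimal sketch: with `accumulation = 4`, the optimizer steps on the sum
    /// of four backward passes, roughly a 4x effective batch size:
    ///
    /// ```rust,ignore
    /// let builder = builder.grads_accumulation(4);
    /// ```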
    pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
        self.grad_accumulation = Some(accumulation);
        self
    }

    /// Register a [numeric](crate::metric::Numeric) training [metric](Metric).
    pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
    where
        Me: Metric + crate::metric::Numeric + 'static,
        T::ItemSync: Adaptor<Me::Input>,
    {
        self.summary_metrics.insert(metric.name());
        self.metrics.register_train_metric_numeric(metric);
        self
    }

    /// Register a [numeric](crate::metric::Numeric) validation [metric](Metric).
    pub fn metric_valid_numeric<Me: Metric + crate::metric::Numeric + 'static>(
        mut self,
        metric: Me,
    ) -> Self
    where
        V::ItemSync: Adaptor<Me::Input>,
    {
        self.summary_metrics.insert(metric.name());
        self.metrics.register_valid_metric_numeric(metric);
        self
    }

    /// The number of epochs the training should last.
    pub fn num_epochs(mut self, num_epochs: usize) -> Self {
        self.num_epochs = num_epochs;
        self
    }

    /// Run the training loop on multiple devices.
    pub fn devices(mut self, devices: Vec<B::Device>) -> Self {
        self.devices = devices;
        self
    }

    /// The epoch from which the training must resume.
    pub fn checkpoint(mut self, checkpoint: usize) -> Self {
        self.checkpoint = Some(checkpoint);
        self
    }

    /// Provides a handle that can be used to interrupt training.
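    ///
    /// # Example
    ///
    /// A minimal sketch: grab the handle before training so another thread can
    /// request a stop later (assuming [`TrainingInterrupter`]'s `stop` method):
    ///
    /// ```rust,ignore
    /// let handle = builder.interrupter();
    /// // From another thread, once some external condition is met:
    /// handle.stop();
    /// ```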
    pub fn interrupter(&self) -> TrainingInterrupter {
        self.interrupter.clone()
    }

    /// Register an [early stopping strategy](EarlyStoppingStrategy) to stop the training when the
    /// conditions are met.
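    ///
    /// # Example
    ///
    /// A minimal sketch, assuming this crate's `MetricEarlyStoppingStrategy` and
    /// `StoppingCondition` types (check their current signatures): stop when the
    /// validation loss has not improved for two epochs.
    ///
    /// ```rust,ignore
    /// let builder = builder.early_stopping(MetricEarlyStoppingStrategy::new(
    ///     &LossMetric::<B>::new(),
    ///     Aggregate::Mean,
    ///     Direction::Lowest,
    ///     Split::Valid,
    ///     StoppingCondition::NoImprovementSince { n_epochs: 2 },
    /// ));
    /// ```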
    pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
    where
        Strategy: EarlyStoppingStrategy + 'static,
    {
        self.early_stopping = Some(Box::new(strategy));
        self
    }

    /// By default, Rust logs are captured and written into
    /// `experiment.log`. If disabled, standard Rust log handling
    /// will apply.
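    ///
    /// # Example
    ///
    /// A minimal sketch: pass `None` to disable the capture entirely.
    ///
    /// ```rust,ignore
    /// let builder = builder.with_application_logger(None);
    /// ```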
    pub fn with_application_logger(
        mut self,
        logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    ) -> Self {
        self.tracing_logger = logger;
        self
    }

    /// Register a checkpointer that will save the [optimizer](Optimizer), the
    /// [model](AutodiffModule) and the [scheduler](LrScheduler) to different files.
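    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a [`FileRecorder`] implementation such as
    /// `burn_core`'s `CompactRecorder`:
    ///
    /// ```rust,ignore
    /// let builder = builder.with_file_checkpointer(CompactRecorder::new());
    /// ```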
    pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
    where
        FR: FileRecorder<B> + 'static,
        FR: FileRecorder<B::InnerBackend> + 'static,
        O::Record: 'static,
        M::Record: 'static,
        S::Record<B>: 'static,
    {
        let checkpoint_dir = self.directory.join("checkpoint");
        let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
        let checkpointer_optimizer =
            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
        let checkpointer_scheduler: FileCheckpointer<FR> =
            FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");

        self.checkpointers = Some((
            AsyncCheckpointer::new(checkpointer_model),
            AsyncCheckpointer::new(checkpointer_optimizer),
            AsyncCheckpointer::new(checkpointer_scheduler),
        ));

        self
    }

    /// Enable the training summary report.
    ///
    /// The summary will be displayed at the end of `.fit()`.
    pub fn summary(mut self) -> Self {
        self.summary = true;
        self
    }

    /// Create the [learner](Learner) from a [model](AutodiffModule) and an [optimizer](Optimizer).
    /// The [learning rate scheduler](LrScheduler) can also be a simple
    /// [learning rate](burn_core::LearningRate).
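    ///
    /// # Example
    ///
    /// A minimal end-to-end sketch; `my_model` and `my_optim` are placeholders,
    /// `CompactRecorder` is assumed from `burn_core`, and the constant `1e-3`
    /// stands in for a full learning rate scheduler:
    ///
    /// ```rust,ignore
    /// let learner = LearnerBuilder::new("/tmp/my-experiment")
    ///     .metric_train_numeric(LossMetric::new())
    ///     .metric_valid_numeric(LossMetric::new())
    ///     .with_file_checkpointer(CompactRecorder::new())
    ///     .num_epochs(10)
    ///     .summary()
    ///     .build(my_model, my_optim, 1e-3);
    /// ```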
    #[allow(clippy::type_complexity)] // The goal of the builder is to handle all the types and
    // create a clean learner.
    pub fn build(
        mut self,
        model: M,
        optim: O,
        lr_scheduler: S,
    ) -> Learner<
        LearnerComponentsMarker<
            B,
            S,
            M,
            O,
            AsyncCheckpointer<M::Record, B>,
            AsyncCheckpointer<O::Record, B>,
            AsyncCheckpointer<S::Record<B>, B>,
            AsyncProcessor<FullEventProcessor<T, V>>,
            Box<dyn CheckpointingStrategy>,
        >,
    >
    where
        M::Record: 'static,
        O::Record: 'static,
        S::Record<B>: 'static,
    {
        if let Some(logger) = self.tracing_logger.as_ref() {
            if let Err(e) = logger.install() {
                log::warn!("Failed to install the experiment logger: {e}");
            }
        }
        let renderer = self
            .renderer
            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));

        if self.num_loggers == 0 {
            self.event_store
                .register_logger_train(FileMetricLogger::new(self.directory.join("train")));
            self.event_store
                .register_logger_valid(FileMetricLogger::new(self.directory.join("valid")));
        }

        let event_store = Arc::new(EventStoreClient::new(self.event_store));
        let event_processor = AsyncProcessor::new(FullEventProcessor::new(
            self.metrics,
            renderer,
            event_store.clone(),
        ));

        let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
            LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
        });

        let summary = if self.summary {
            Some(LearnerSummaryConfig {
                directory: self.directory,
                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
            })
        } else {
            None
        };

        Learner {
            model,
            optim,
            lr_scheduler,
            checkpointer,
            num_epochs: self.num_epochs,
            event_processor,
            event_store,
            checkpoint: self.checkpoint,
            grad_accumulation: self.grad_accumulation,
            devices: self.devices,
            interrupter: self.interrupter,
            early_stopping: self.early_stopping,
            summary,
        }
    }
}