1use std::collections::BTreeSet;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use super::Learner;
6use crate::checkpoint::{
7 AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
8 KeepLastNCheckpoints, MetricCheckpointingStrategy,
9};
10use crate::components::LearnerComponentsMarker;
11use crate::learner::EarlyStoppingStrategy;
12use crate::learner::base::TrainingInterrupter;
13use crate::logger::{FileMetricLogger, MetricLogger};
14use crate::metric::processor::{AsyncProcessor, FullEventProcessor, ItemLazy, Metrics};
15use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
16use crate::metric::{Adaptor, LossMetric, Metric};
17use crate::renderer::{MetricsRenderer, default_renderer};
18use crate::{
19 ApplicationLoggerInstaller, FileApplicationLoggerInstaller, LearnerCheckpointer,
20 LearnerSummaryConfig,
21};
22use burn_core::lr_scheduler::LrScheduler;
23use burn_core::module::AutodiffModule;
24use burn_core::optim::Optimizer;
25use burn_core::record::FileRecorder;
26use burn_core::tensor::backend::AutodiffBackend;
27
/// Builder for a [`Learner`], collecting every training option before `build` is called.
///
/// Type parameters:
/// - `B`: autodiff backend the training runs on.
/// - `T` / `V`: training / validation items (materialized lazily via [`ItemLazy`]).
/// - `M`: the model being trained.
/// - `O`: the optimizer.
/// - `S`: the learning-rate scheduler.
pub struct LearnerBuilder<B, T, V, M, O, S>
where
    T: ItemLazy + 'static,
    V: ItemLazy + 'static,
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: Optimizer<M, B>,
    S: LrScheduler,
{
    // Async checkpointers for (model, optimizer, lr-scheduler) records.
    // `None` until `with_file_checkpointer` is called.
    #[allow(clippy::type_complexity)]
    checkpointers: Option<(
        AsyncCheckpointer<M::Record, B>,
        AsyncCheckpointer<O::Record, B>,
        AsyncCheckpointer<S::Record<B>, B>,
    )>,
    // Number of epochs to train for (defaults to 1 in `new`).
    num_epochs: usize,
    // Epoch to resume training from, if any.
    checkpoint: Option<usize>,
    // Root directory for checkpoints, logs and other artifacts.
    directory: PathBuf,
    // Number of iterations over which gradients are accumulated before an
    // optimizer step, if enabled.
    grad_accumulation: Option<usize>,
    // Devices the training runs on (defaults to `B::Device::default()`).
    devices: Vec<B::Device>,
    // Custom metrics renderer; `build` falls back to `default_renderer` when `None`.
    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
    // Registered train/validation metrics.
    metrics: Metrics<T, V>,
    // Store receiving metric events; loggers are registered on it.
    event_store: LogEventStore,
    // Handle that lets external code request training interruption.
    interrupter: TrainingInterrupter,
    // Application (experiment) logger installer; `Some(file logger)` by default.
    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    // Count of `metric_loggers` calls; when 0, `build` installs default file loggers.
    num_loggers: usize,
    // Strategy deciding which checkpoints to keep.
    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
    // Optional early-stopping strategy.
    early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
    // Names of numeric metrics to include in the training summary.
    summary_metrics: BTreeSet<String>,
    // Whether to produce a summary at the end of training.
    summary: bool,
}
64
65impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
66where
67 B: AutodiffBackend,
68 T: ItemLazy + 'static,
69 V: ItemLazy + 'static,
70 M: AutodiffModule<B> + core::fmt::Display + 'static,
71 O: Optimizer<M, B>,
72 S: LrScheduler,
73{
74 pub fn new(directory: impl AsRef<Path>) -> Self {
80 let directory = directory.as_ref().to_path_buf();
81 let experiment_log_file = directory.join("experiment.log");
82 Self {
83 num_epochs: 1,
84 checkpoint: None,
85 checkpointers: None,
86 directory,
87 grad_accumulation: None,
88 devices: vec![B::Device::default()],
89 metrics: Metrics::default(),
90 event_store: LogEventStore::default(),
91 renderer: None,
92 interrupter: TrainingInterrupter::new(),
93 tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
94 experiment_log_file,
95 ))),
96 num_loggers: 0,
97 checkpointer_strategy: Box::new(
98 ComposedCheckpointingStrategy::builder()
99 .add(KeepLastNCheckpoints::new(2))
100 .add(MetricCheckpointingStrategy::new(
101 &LossMetric::<B>::new(), Aggregate::Mean,
103 Direction::Lowest,
104 Split::Valid,
105 ))
106 .build(),
107 ),
108 early_stopping: None,
109 summary_metrics: BTreeSet::new(),
110 summary: false,
111 }
112 }
113
114 pub fn metric_loggers<MT, MV>(mut self, logger_train: MT, logger_valid: MV) -> Self
121 where
122 MT: MetricLogger + 'static,
123 MV: MetricLogger + 'static,
124 {
125 self.event_store.register_logger_train(logger_train);
126 self.event_store.register_logger_valid(logger_valid);
127 self.num_loggers += 1;
128 self
129 }
130
131 pub fn with_checkpointing_strategy<CS>(mut self, strategy: CS) -> Self
133 where
134 CS: CheckpointingStrategy + 'static,
135 {
136 self.checkpointer_strategy = Box::new(strategy);
137 self
138 }
139
140 pub fn renderer<MR>(mut self, renderer: MR) -> Self
146 where
147 MR: MetricsRenderer + 'static,
148 {
149 self.renderer = Some(Box::new(renderer));
150 self
151 }
152
153 pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
155 where
156 T::ItemSync: Adaptor<Me::Input>,
157 {
158 self.metrics.register_train_metric(metric);
159 self
160 }
161
162 pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
164 where
165 V::ItemSync: Adaptor<Me::Input>,
166 {
167 self.metrics.register_valid_metric(metric);
168 self
169 }
170
171 pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
182 self.grad_accumulation = Some(accumulation);
183 self
184 }
185
186 pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
188 where
189 Me: Metric + crate::metric::Numeric + 'static,
190 T::ItemSync: Adaptor<Me::Input>,
191 {
192 self.summary_metrics.insert(metric.name());
193 self.metrics.register_train_metric_numeric(metric);
194 self
195 }
196
197 pub fn metric_valid_numeric<Me: Metric + crate::metric::Numeric + 'static>(
199 mut self,
200 metric: Me,
201 ) -> Self
202 where
203 V::ItemSync: Adaptor<Me::Input>,
204 {
205 self.summary_metrics.insert(metric.name());
206 self.metrics.register_valid_metric_numeric(metric);
207 self
208 }
209
210 pub fn num_epochs(mut self, num_epochs: usize) -> Self {
212 self.num_epochs = num_epochs;
213 self
214 }
215
216 pub fn devices(mut self, devices: Vec<B::Device>) -> Self {
218 self.devices = devices;
219 self
220 }
221
222 pub fn checkpoint(mut self, checkpoint: usize) -> Self {
224 self.checkpoint = Some(checkpoint);
225 self
226 }
227
228 pub fn interrupter(&self) -> TrainingInterrupter {
230 self.interrupter.clone()
231 }
232
233 pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
236 where
237 Strategy: EarlyStoppingStrategy + 'static,
238 {
239 self.early_stopping = Some(Box::new(strategy));
240 self
241 }
242
243 pub fn with_application_logger(
247 mut self,
248 logger: Option<Box<dyn ApplicationLoggerInstaller>>,
249 ) -> Self {
250 self.tracing_logger = logger;
251 self
252 }
253
254 pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
257 where
258 FR: FileRecorder<B> + 'static,
259 FR: FileRecorder<B::InnerBackend> + 'static,
260 O::Record: 'static,
261 M::Record: 'static,
262 S::Record<B>: 'static,
263 {
264 let checkpoint_dir = self.directory.join("checkpoint");
265 let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
266 let checkpointer_optimizer =
267 FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
268 let checkpointer_scheduler: FileCheckpointer<FR> =
269 FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");
270
271 self.checkpointers = Some((
272 AsyncCheckpointer::new(checkpointer_model),
273 AsyncCheckpointer::new(checkpointer_optimizer),
274 AsyncCheckpointer::new(checkpointer_scheduler),
275 ));
276
277 self
278 }
279
280 pub fn summary(mut self) -> Self {
284 self.summary = true;
285 self
286 }
287
288 #[allow(clippy::type_complexity)] pub fn build(
294 mut self,
295 model: M,
296 optim: O,
297 lr_scheduler: S,
298 ) -> Learner<
299 LearnerComponentsMarker<
300 B,
301 S,
302 M,
303 O,
304 AsyncCheckpointer<M::Record, B>,
305 AsyncCheckpointer<O::Record, B>,
306 AsyncCheckpointer<S::Record<B>, B>,
307 AsyncProcessor<FullEventProcessor<T, V>>,
308 Box<dyn CheckpointingStrategy>,
309 >,
310 >
311 where
312 M::Record: 'static,
313 O::Record: 'static,
314 S::Record<B>: 'static,
315 {
316 if self.tracing_logger.is_some() {
317 if let Err(e) = self.tracing_logger.as_ref().unwrap().install() {
318 log::warn!("Failed to install the experiment logger: {e}");
319 }
320 }
321 let renderer = self
322 .renderer
323 .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
324
325 if self.num_loggers == 0 {
326 self.event_store
327 .register_logger_train(FileMetricLogger::new(self.directory.join("train")));
328 self.event_store
329 .register_logger_valid(FileMetricLogger::new(self.directory.join("valid")));
330 }
331
332 let event_store = Arc::new(EventStoreClient::new(self.event_store));
333 let event_processor = AsyncProcessor::new(FullEventProcessor::new(
334 self.metrics,
335 renderer,
336 event_store.clone(),
337 ));
338
339 let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
340 LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
341 });
342
343 let summary = if self.summary {
344 Some(LearnerSummaryConfig {
345 directory: self.directory,
346 metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
347 })
348 } else {
349 None
350 };
351
352 Learner {
353 model,
354 optim,
355 lr_scheduler,
356 checkpointer,
357 num_epochs: self.num_epochs,
358 event_processor,
359 event_store,
360 checkpoint: self.checkpoint,
361 grad_accumulation: self.grad_accumulation,
362 devices: self.devices,
363 interrupter: self.interrupter,
364 early_stopping: self.early_stopping,
365 summary,
366 }
367 }
368}