1use std::collections::BTreeSet;
2use std::marker::PhantomData;
3use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
6use super::Learner;
7use crate::checkpoint::{
8 AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
9 KeepLastNCheckpoints, MetricCheckpointingStrategy,
10};
11use crate::components::{LearnerComponentsMarker, LearningDataMarker};
12use crate::learner::EarlyStoppingStrategy;
13use crate::learner::base::Interrupter;
14use crate::logger::{FileMetricLogger, MetricLogger};
15use crate::metric::processor::{
16 AsyncProcessorTraining, FullEventProcessorTraining, ItemLazy, MetricsTraining,
17};
18use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
19use crate::metric::{Adaptor, LossMetric, Metric};
20use crate::renderer::{MetricsRenderer, default_renderer};
21use crate::{
22 ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller,
23 LearnerCheckpointer, LearnerSummaryConfig, LearningStrategy, TrainStep, ValidStep,
24};
25use burn_core::module::AutodiffModule;
26use burn_core::record::FileRecorder;
27use burn_core::tensor::backend::AutodiffBackend;
28use burn_optim::Optimizer;
29use burn_optim::lr_scheduler::LrScheduler;
30
/// Builder used to configure and assemble a [`Learner`] for training.
///
/// Type parameters:
/// - `B`: autodiff backend.
/// - `M`: model; must implement [`TrainStep`] and its inner (non-autodiff)
///   module must implement [`ValidStep`].
/// - `O`: optimizer, `S`: learning-rate scheduler.
/// - `TI` / `VI`: training / validation input item types.
/// - `TO` / `VO`: training / validation output item types.
pub struct LearnerBuilder<B, M, O, S, TI, VI, TO, VO>
where
    B: AutodiffBackend,
    M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
    M::InnerModule: ValidStep<VI, VO>,
    O: Optimizer<M, B>,
    S: LrScheduler,
    TI: Send + 'static,
    VI: Send + 'static,
    TO: ItemLazy + 'static,
    VO: ItemLazy + 'static,
{
    // Async checkpointers for the (model, optimizer, scheduler) records.
    // `None` means checkpointing is disabled; set via `with_file_checkpointer`.
    #[allow(clippy::type_complexity)]
    checkpointers: Option<(
        AsyncCheckpointer<M::Record, B>,
        AsyncCheckpointer<O::Record, B>,
        AsyncCheckpointer<S::Record<B>, B>,
    )>,
    // Total number of training epochs (defaults to 1).
    num_epochs: usize,
    // Epoch to restart from, if resuming from a checkpoint.
    checkpoint: Option<usize>,
    // Root artifact directory (logs, checkpoints, metric files).
    directory: PathBuf,
    // Number of gradient-accumulation steps; `None` disables accumulation.
    grad_accumulation: Option<usize>,
    // Single/multi-device execution strategy.
    learning_strategy: LearningStrategy<B>,
    // Custom metrics renderer; falls back to `default_renderer` when `None`.
    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
    // Registered training/validation metrics.
    metrics: MetricsTraining<TO, VO>,
    // Store that forwards metric events to the registered loggers.
    event_store: LogEventStore,
    // Handle used to cooperatively stop training from another thread.
    interrupter: Interrupter,
    // Application-level (tracing) logger installer; `None` disables it.
    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    // Count of user-registered logger pairs; 0 triggers default file loggers.
    num_loggers: usize,
    // Strategy deciding which checkpoints are kept or deleted.
    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
    // Optional early-stopping strategy.
    early_stopping: Option<EarlyStoppingStrategyRef>,
    // Names of numeric metrics to include in the final training summary.
    summary_metrics: BTreeSet<String>,
    // Whether to produce a summary report at the end of training.
    summary: bool,
    _p: PhantomData<(TI, VI, TO, VO)>,
}
74
75impl<B, M, O, S, TI, VI, TO, VO> LearnerBuilder<B, M, O, S, TI, VI, TO, VO>
76where
77 B: AutodiffBackend,
78 M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
79 M::InnerModule: ValidStep<VI, VO>,
80 O: Optimizer<M, B>,
81 S: LrScheduler,
82 TI: Send + 'static,
83 VI: Send + 'static,
84 TO: ItemLazy + 'static,
85 VO: ItemLazy + 'static,
86{
87 pub fn new(directory: impl AsRef<Path>) -> Self {
93 let directory = directory.as_ref().to_path_buf();
94 let experiment_log_file = directory.join("experiment.log");
95 Self {
96 num_epochs: 1,
97 checkpoint: None,
98 checkpointers: None,
99 directory,
100 grad_accumulation: None,
101 learning_strategy: LearningStrategy::default(),
102 metrics: MetricsTraining::default(),
103 event_store: LogEventStore::default(),
104 renderer: None,
105 interrupter: Interrupter::new(),
106 tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
107 experiment_log_file,
108 ))),
109 num_loggers: 0,
110 checkpointer_strategy: Box::new(
111 ComposedCheckpointingStrategy::builder()
112 .add(KeepLastNCheckpoints::new(2))
113 .add(MetricCheckpointingStrategy::new(
114 &LossMetric::<B>::new(), Aggregate::Mean,
116 Direction::Lowest,
117 Split::Valid,
118 ))
119 .build(),
120 ),
121 early_stopping: None,
122 summary_metrics: BTreeSet::new(),
123 summary: false,
124 _p: PhantomData,
125 }
126 }
127
128 pub fn metric_loggers<MT, MV>(mut self, logger_train: MT, logger_valid: MV) -> Self
135 where
136 MT: MetricLogger + 'static,
137 MV: MetricLogger + 'static,
138 {
139 self.event_store.register_logger_train(logger_train);
140 self.event_store.register_logger_valid(logger_valid);
141 self.num_loggers += 1;
142 self
143 }
144
145 pub fn with_checkpointing_strategy<CS>(mut self, strategy: CS) -> Self
147 where
148 CS: CheckpointingStrategy + 'static,
149 {
150 self.checkpointer_strategy = Box::new(strategy);
151 self
152 }
153
154 pub fn renderer<MR>(mut self, renderer: MR) -> Self
160 where
161 MR: MetricsRenderer + 'static,
162 {
163 self.renderer = Some(Box::new(renderer));
164 self
165 }
166
167 pub fn metrics<Me: MetricRegistration<B, M, O, S, TI, VI, TO, VO>>(self, metrics: Me) -> Self {
169 metrics.register(self)
170 }
171
172 pub fn metrics_text<Me: TextMetricRegistration<B, M, O, S, TI, VI, TO, VO>>(
174 self,
175 metrics: Me,
176 ) -> Self {
177 metrics.register(self)
178 }
179
180 pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
182 where
183 <TO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
184 {
185 self.metrics.register_train_metric(metric);
186 self
187 }
188
189 pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
191 where
192 <VO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
193 {
194 self.metrics.register_valid_metric(metric);
195 self
196 }
197
198 pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
209 self.grad_accumulation = Some(accumulation);
210 self
211 }
212
213 pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
215 where
216 Me: Metric + crate::metric::Numeric + 'static,
217 <TO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
218 {
219 self.summary_metrics.insert(metric.name().to_string());
220 self.metrics.register_train_metric_numeric(metric);
221 self
222 }
223
224 pub fn metric_valid_numeric<Me: Metric + crate::metric::Numeric + 'static>(
226 mut self,
227 metric: Me,
228 ) -> Self
229 where
230 <VO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
231 {
232 self.summary_metrics.insert(metric.name().to_string());
233 self.metrics.register_valid_metric_numeric(metric);
234 self
235 }
236
237 pub fn num_epochs(mut self, num_epochs: usize) -> Self {
239 self.num_epochs = num_epochs;
240 self
241 }
242
243 pub fn learning_strategy(mut self, learning_strategy: LearningStrategy<B>) -> Self {
245 self.learning_strategy = learning_strategy;
246 self
247 }
248
249 pub fn checkpoint(mut self, checkpoint: usize) -> Self {
251 self.checkpoint = Some(checkpoint);
252 self
253 }
254
255 pub fn interrupter(&self) -> Interrupter {
257 self.interrupter.clone()
258 }
259
260 pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
262 self.interrupter = interrupter;
263 self
264 }
265
266 pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
269 where
270 Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,
271 {
272 self.early_stopping = Some(Box::new(strategy));
273 self
274 }
275
276 pub fn with_application_logger(
280 mut self,
281 logger: Option<Box<dyn ApplicationLoggerInstaller>>,
282 ) -> Self {
283 self.tracing_logger = logger;
284 self
285 }
286
287 pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
290 where
291 FR: FileRecorder<B> + 'static,
292 FR: FileRecorder<B::InnerBackend> + 'static,
293 O::Record: 'static,
294 M::Record: 'static,
295 S::Record<B>: 'static,
296 {
297 let checkpoint_dir = self.directory.join("checkpoint");
298 let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
299 let checkpointer_optimizer =
300 FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
301 let checkpointer_scheduler: FileCheckpointer<FR> =
302 FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");
303
304 self.checkpointers = Some((
305 AsyncCheckpointer::new(checkpointer_model),
306 AsyncCheckpointer::new(checkpointer_optimizer),
307 AsyncCheckpointer::new(checkpointer_scheduler),
308 ));
309
310 self
311 }
312
313 pub fn summary(mut self) -> Self {
317 self.summary = true;
318 self
319 }
320
321 #[allow(clippy::type_complexity)] pub fn build(
327 mut self,
328 model: M,
329 optim: O,
330 lr_scheduler: S,
331 ) -> Learner<
332 LearnerComponentsMarker<
333 B,
334 S,
335 M,
336 O,
337 AsyncCheckpointer<M::Record, B>,
338 AsyncCheckpointer<O::Record, B>,
339 AsyncCheckpointer<S::Record<B>, B>,
340 AsyncProcessorTraining<FullEventProcessorTraining<TO, VO>>,
341 Box<dyn CheckpointingStrategy>,
342 LearningDataMarker<TI, VI, TO, VO>,
343 >,
344 >
345 where
346 M::Record: 'static,
347 O::Record: 'static,
348 S::Record<B>: 'static,
349 {
350 if self.tracing_logger.is_some()
351 && let Err(e) = self.tracing_logger.as_ref().unwrap().install()
352 {
353 log::warn!("Failed to install the experiment logger: {e}");
354 }
355 let renderer = self
356 .renderer
357 .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
358
359 if self.num_loggers == 0 {
360 self.event_store
361 .register_logger_train(FileMetricLogger::new_train(self.directory.join("train")));
362 self.event_store
363 .register_logger_valid(FileMetricLogger::new_train(self.directory.join("valid")));
364 }
365
366 let event_store = Arc::new(EventStoreClient::new(self.event_store));
367 let event_processor = AsyncProcessorTraining::new(FullEventProcessorTraining::new(
368 self.metrics,
369 renderer,
370 event_store.clone(),
371 ));
372
373 let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
374 LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
375 });
376
377 let summary = if self.summary {
378 Some(LearnerSummaryConfig {
379 directory: self.directory,
380 metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
381 })
382 } else {
383 None
384 };
385
386 let learning_strategy = Self::prepare_learning_strategy(self.learning_strategy);
387
388 Learner {
389 model,
390 optim,
391 lr_scheduler,
392 checkpointer,
393 num_epochs: self.num_epochs,
394 event_processor,
395 event_store,
396 checkpoint: self.checkpoint,
397 grad_accumulation: self.grad_accumulation,
398 learning_strategy,
399 interrupter: self.interrupter,
400 early_stopping: self.early_stopping,
401 summary,
402 }
403 }
404
405 fn prepare_learning_strategy(learning_strategy: LearningStrategy<B>) -> LearningStrategy<B> {
406 if let LearningStrategy::MultiDeviceNaive(devices) = &learning_strategy
407 && devices.len() == 1
408 {
409 return LearningStrategy::SingleDevice(devices[0].clone());
410 }
411
412 learning_strategy
413 }
414}
415
/// Registration hook for tuples of *numeric* metrics, letting
/// `LearnerBuilder::metrics` accept `(m1, m2, ...)` and register each element
/// for both training and validation. Implemented for tuples via `gen_tuple!`.
pub trait MetricRegistration<B, M, O, S, TI, VI, TO, VO>: Sized
where
    B: AutodiffBackend,
    M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
    M::InnerModule: ValidStep<VI, VO>,
    O: Optimizer<M, B>,
    S: LrScheduler,
    TI: Send + 'static,
    VI: Send + 'static,
    TO: ItemLazy + 'static,
    VO: ItemLazy + 'static,
{
    /// Registers every metric in `self` on the builder and returns it.
    fn register(
        self,
        builder: LearnerBuilder<B, M, O, S, TI, VI, TO, VO>,
    ) -> LearnerBuilder<B, M, O, S, TI, VI, TO, VO>;
}
435
/// Registration hook for tuples of plain (non-numeric) metrics, letting
/// `LearnerBuilder::metrics_text` accept `(m1, m2, ...)` and register each
/// element for both training and validation. Implemented via `gen_tuple!`.
pub trait TextMetricRegistration<B, M, O, S, TI, VI, TO, VO>: Sized
where
    B: AutodiffBackend,
    M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
    M::InnerModule: ValidStep<VI, VO>,
    O: Optimizer<M, B>,
    S: LrScheduler,
    TI: Send + 'static,
    VI: Send + 'static,
    TO: ItemLazy + 'static,
    VO: ItemLazy + 'static,
{
    /// Registers every metric in `self` on the builder and returns it.
    fn register(
        self,
        builder: LearnerBuilder<B, M, O, S, TI, VI, TO, VO>,
    ) -> LearnerBuilder<B, M, O, S, TI, VI, TO, VO>;
}
455
// Generates both registration-trait impls for a metric tuple of the given
// arity. Each metric must be `Clone` in practice: it is cloned once so the
// same metric instance is registered for training AND validation.
macro_rules! gen_tuple {
    ($($M:ident),*) => {
        // Plain (non-numeric) metrics: registered via `metric_train`/`metric_valid`.
        impl<$($M,)* B, M, O, S, TI, VI, TO, VO> TextMetricRegistration<B, M, O, S, TI, VI, TO, VO> for ($($M,)*)
        where
            B: AutodiffBackend,
            M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
            M::InnerModule: ValidStep<VI, VO>,
            O: Optimizer<M, B>,
            S: LrScheduler,
            TI: Send + 'static,
            VI: Send + 'static,
            TO: ItemLazy + 'static,
            VO: ItemLazy + 'static,
            // Both the train and valid item types must adapt to every metric's input.
            $(TO::ItemSync: Adaptor<$M::Input>,)*
            $(VO::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: LearnerBuilder<B, M, O, S, TI, VI, TO, VO>,
            ) -> LearnerBuilder<B, M, O, S, TI, VI, TO, VO> {
                // Destructure the tuple, then register a clone for training and
                // move the original into validation.
                let ($($M,)*) = self;
                $(let builder = builder.metric_train($M.clone());)*
                $(let builder = builder.metric_valid($M);)*
                builder
            }
        }

        // Numeric metrics: registered via the `*_numeric` methods so they also
        // feed the summary report.
        impl<$($M,)* B, M, O, S, TI, VI, TO, VO> MetricRegistration<B, M, O, S, TI, VI, TO, VO> for ($($M,)*)
        where
            B: AutodiffBackend,
            M: AutodiffModule<B> + TrainStep<TI, TO> + core::fmt::Display + 'static,
            M::InnerModule: ValidStep<VI, VO>,
            O: Optimizer<M, B>,
            S: LrScheduler,
            TI: Send + 'static,
            VI: Send + 'static,
            TO: ItemLazy + 'static,
            VO: ItemLazy + 'static,
            $(TO::ItemSync: Adaptor<$M::Input>,)*
            $(VO::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + $crate::metric::Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: LearnerBuilder<B, M, O, S, TI, VI, TO, VO>,
            ) -> LearnerBuilder<B, M, O, S, TI, VI, TO, VO> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_train_numeric($M.clone());)*
                $(let builder = builder.metric_valid_numeric($M);)*
                builder
            }
        }
    };
}
513
// Implement the registration traits for metric tuples of arity 1 through 6.
gen_tuple!(M1);
gen_tuple!(M1, M2);
gen_tuple!(M1, M2, M3);
gen_tuple!(M1, M2, M3, M4);
gen_tuple!(M1, M2, M3, M4, M5);
gen_tuple!(M1, M2, M3, M4, M5, M6);