1use crate::checkpoint::{
2 AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
3 KeepLastNCheckpoints, MetricCheckpointingStrategy,
4};
5use crate::components::{InferenceModelOutput, TrainingModelOutput};
6use crate::learner::EarlyStoppingStrategy;
7use crate::learner::base::Interrupter;
8use crate::logger::{FileMetricLogger, MetricLogger};
9use crate::metric::processor::{
10 AsyncProcessorTraining, FullEventProcessorTraining, ItemLazy, MetricsTraining,
11};
12use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
13use crate::metric::{Adaptor, LossMetric, Metric, Numeric};
14use crate::multi::MultiDeviceLearningStrategy;
15use crate::renderer::{MetricsRenderer, default_renderer};
16use crate::single::SingleDevicetrainingStrategy;
17use crate::{
18 ApplicationLoggerInstaller, EarlyStoppingStrategyRef, FileApplicationLoggerInstaller,
19 InferenceBackend, InferenceModel, InferenceModelInput, InferenceStep, LearnerModelRecord,
20 LearnerOptimizerRecord, LearnerSchedulerRecord, LearnerSummaryConfig, LearningCheckpointer,
21 LearningComponentsMarker, LearningComponentsTypes, LearningResult, TrainStep, TrainingBackend,
22 TrainingComponents, TrainingModelInput, TrainingStrategy,
23};
24use crate::{Learner, SupervisedLearningStrategy};
25use burn_core::data::dataloader::DataLoader;
26use burn_core::module::AutodiffModule;
27use burn_core::record::FileRecorder;
28use burn_core::tensor::backend::AutodiffBackend;
29use burn_optim::Optimizer;
30use burn_optim::lr_scheduler::LrScheduler;
31use std::collections::BTreeSet;
32use std::path::{Path, PathBuf};
33use std::sync::Arc;
34
35pub type TrainLoader<LC> = Arc<dyn DataLoader<TrainingBackend<LC>, TrainingModelInput<LC>>>;
37pub type ValidLoader<LC> = Arc<dyn DataLoader<InferenceBackend<LC>, InferenceModelInput<LC>>>;
39pub type SupervisedTrainingEventProcessor<LC> = AsyncProcessorTraining<
41 FullEventProcessorTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,
42>;
43
44pub struct SupervisedTraining<LC>
46where
47 LC: LearningComponentsTypes,
48{
49 #[allow(clippy::type_complexity)]
51 checkpointers: Option<(
52 AsyncCheckpointer<LearnerModelRecord<LC>, TrainingBackend<LC>>,
53 AsyncCheckpointer<LearnerOptimizerRecord<LC>, TrainingBackend<LC>>,
54 AsyncCheckpointer<LearnerSchedulerRecord<LC>, TrainingBackend<LC>>,
55 )>,
56 num_epochs: usize,
57 checkpoint: Option<usize>,
58 directory: PathBuf,
59 grad_accumulation: Option<usize>,
60 renderer: Option<Box<dyn MetricsRenderer + 'static>>,
61 metrics: MetricsTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,
62 event_store: LogEventStore,
63 interrupter: Interrupter,
64 tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
65 checkpointer_strategy: Box<dyn CheckpointingStrategy>,
66 early_stopping: Option<EarlyStoppingStrategyRef>,
67 training_strategy: TrainingStrategy<LC>,
68 dataloader_train: TrainLoader<LC>,
69 dataloader_valid: ValidLoader<LC>,
70 summary_metrics: BTreeSet<String>,
72 summary: bool,
73}
74
75impl<B, LR, M, O> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>
76where
77 B: AutodiffBackend,
78 LR: LrScheduler + 'static,
79 M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
80 M::InnerModule: InferenceStep,
81 O: Optimizer<M, B> + 'static,
82{
83 pub fn new(
91 directory: impl AsRef<Path>,
92 dataloader_train: Arc<dyn DataLoader<B, M::Input>>,
93 dataloader_valid: Arc<
94 dyn DataLoader<B::InnerBackend, <M::InnerModule as InferenceStep>::Input>,
95 >,
96 ) -> Self {
97 let directory = directory.as_ref().to_path_buf();
98 let experiment_log_file = directory.join("experiment.log");
99 Self {
100 num_epochs: 1,
101 checkpoint: None,
102 checkpointers: None,
103 directory,
104 grad_accumulation: None,
105 metrics: MetricsTraining::default(),
106 event_store: LogEventStore::default(),
107 renderer: None,
108 interrupter: Interrupter::new(),
109 tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
110 experiment_log_file,
111 ))),
112 checkpointer_strategy: Box::new(
113 ComposedCheckpointingStrategy::builder()
114 .add(KeepLastNCheckpoints::new(2))
115 .add(MetricCheckpointingStrategy::new(
116 &LossMetric::<B>::new(), Aggregate::Mean,
118 Direction::Lowest,
119 Split::Valid,
120 ))
121 .build(),
122 ),
123 early_stopping: None,
124 training_strategy: TrainingStrategy::SingleDevice(Default::default()),
125 summary_metrics: BTreeSet::new(),
126 summary: false,
127 dataloader_train,
128 dataloader_valid,
129 }
130 }
131}
132
133impl<LC: LearningComponentsTypes> SupervisedTraining<LC> {
134 pub fn with_training_strategy(mut self, training_strategy: TrainingStrategy<LC>) -> Self {
140 self.training_strategy = training_strategy;
141 self
142 }
143
144 pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
150 where
151 ML: MetricLogger + 'static,
152 {
153 self.event_store.register_logger(logger);
154 self
155 }
156
157 pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
159 mut self,
160 strategy: CS,
161 ) -> Self {
162 self.checkpointer_strategy = Box::new(strategy);
163 self
164 }
165
166 pub fn renderer<MR>(mut self, renderer: MR) -> Self
172 where
173 MR: MetricsRenderer + 'static,
174 {
175 self.renderer = Some(Box::new(renderer));
176 self
177 }
178
179 pub fn metrics<Me: MetricRegistration<LC>>(self, metrics: Me) -> Self {
181 metrics.register(self)
182 }
183
184 pub fn metrics_text<Me: TextMetricRegistration<LC>>(self, metrics: Me) -> Self {
186 metrics.register(self)
187 }
188
189 pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
191 where
192 <TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
193 {
194 self.metrics.register_train_metric(metric);
195 self
196 }
197
198 pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
200 where
201 <InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
202 {
203 self.metrics.register_valid_metric(metric);
204 self
205 }
206
207 pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
218 self.grad_accumulation = Some(accumulation);
219 self
220 }
221
222 pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
224 where
225 Me: Metric + Numeric + 'static,
226 <TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
227 {
228 self.summary_metrics.insert(metric.name().to_string());
229 self.metrics.register_train_metric_numeric(metric);
230 self
231 }
232
233 pub fn metric_valid_numeric<Me: Metric + Numeric + 'static>(mut self, metric: Me) -> Self
235 where
236 <InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
237 {
238 self.summary_metrics.insert(metric.name().to_string());
239 self.metrics.register_valid_metric_numeric(metric);
240 self
241 }
242
243 pub fn num_epochs(mut self, num_epochs: usize) -> Self {
245 self.num_epochs = num_epochs;
246 self
247 }
248
249 pub fn checkpoint(mut self, checkpoint: usize) -> Self {
251 self.checkpoint = Some(checkpoint);
252 self
253 }
254
255 pub fn interrupter(&self) -> Interrupter {
257 self.interrupter.clone()
258 }
259
260 pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
262 self.interrupter = interrupter;
263 self
264 }
265
266 pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
269 where
270 Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,
271 {
272 self.early_stopping = Some(Box::new(strategy));
273 self
274 }
275
276 pub fn with_application_logger(
280 mut self,
281 logger: Option<Box<dyn ApplicationLoggerInstaller>>,
282 ) -> Self {
283 self.tracing_logger = logger;
284 self
285 }
286
287 pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
290 where
291 FR: FileRecorder<<LC as LearningComponentsTypes>::Backend> + 'static,
292 FR: FileRecorder<
293 <<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend,
294 > + 'static,
295 {
296 let checkpoint_dir = self.directory.join("checkpoint");
297 let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
298 let checkpointer_optimizer =
299 FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
300 let checkpointer_scheduler: FileCheckpointer<FR> =
301 FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");
302
303 self.checkpointers = Some((
304 AsyncCheckpointer::new(checkpointer_model),
305 AsyncCheckpointer::new(checkpointer_optimizer),
306 AsyncCheckpointer::new(checkpointer_scheduler),
307 ));
308
309 self
310 }
311
312 pub fn summary(mut self) -> Self {
316 self.summary = true;
317 self
318 }
319}
320
321impl<LC: LearningComponentsTypes + Send + 'static> SupervisedTraining<LC> {
322 pub fn launch(mut self, learner: Learner<LC>) -> LearningResult<InferenceModel<LC>> {
324 if self.tracing_logger.is_some()
325 && let Err(e) = self.tracing_logger.as_ref().unwrap().install()
326 {
327 log::warn!("Failed to install the experiment logger: {e}");
328 }
329 let renderer = self
330 .renderer
331 .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
332
333 if !self.event_store.has_loggers() {
334 self.event_store
335 .register_logger(FileMetricLogger::new(self.directory.clone()));
336 }
337
338 let event_store = Arc::new(EventStoreClient::new(self.event_store));
339 let event_processor = AsyncProcessorTraining::new(FullEventProcessorTraining::new(
340 self.metrics,
341 renderer,
342 event_store.clone(),
343 ));
344
345 let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
346 LearningCheckpointer::new(
347 model.with_interrupter(self.interrupter.clone()),
348 optim.with_interrupter(self.interrupter.clone()),
349 scheduler.with_interrupter(self.interrupter.clone()),
350 self.checkpointer_strategy,
351 )
352 });
353
354 let summary = if self.summary {
355 Some(LearnerSummaryConfig {
356 directory: self.directory,
357 metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
358 })
359 } else {
360 None
361 };
362
363 let components = TrainingComponents {
364 checkpoint: self.checkpoint,
365 checkpointer,
366 interrupter: self.interrupter,
367 early_stopping: self.early_stopping,
368 event_processor,
369 event_store,
370 num_epochs: self.num_epochs,
371 grad_accumulation: self.grad_accumulation,
372 summary,
373 };
374
375 match self.training_strategy {
376 TrainingStrategy::SingleDevice(device) => {
377 let single_device: SingleDevicetrainingStrategy<LC> =
378 SingleDevicetrainingStrategy::new(device);
379 single_device.train(
380 learner,
381 self.dataloader_train,
382 self.dataloader_valid,
383 components,
384 )
385 }
386 TrainingStrategy::Custom(learning_paradigm) => learning_paradigm.train(
387 learner,
388 self.dataloader_train,
389 self.dataloader_valid,
390 components,
391 ),
392 TrainingStrategy::MultiDevice(devices, multi_device_optim) => {
393 let multi_device = MultiDeviceLearningStrategy::new(devices, multi_device_optim);
394 multi_device.train(
395 learner,
396 self.dataloader_train,
397 self.dataloader_valid,
398 components,
399 )
400 }
401 #[cfg(feature = "ddp")]
402 TrainingStrategy::DistributedDataParallel { devices, config } => {
403 use crate::ddp::DdpTrainingStrategy;
404
405 let ddp = DdpTrainingStrategy::new(devices.clone(), config.clone());
406 ddp.train(
407 learner,
408 self.dataloader_train,
409 self.dataloader_valid,
410 components,
411 )
412 }
413 }
414 }
415}
416
417pub trait MetricRegistration<LC: LearningComponentsTypes>: Sized {
419 fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
421}
422
423pub trait TextMetricRegistration<LC: LearningComponentsTypes>: Sized {
425 fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
427}
428
/// Implements [`TextMetricRegistration`] and [`MetricRegistration`] for a tuple of metrics.
///
/// Every metric is first registered on the training split (through a clone). The originals are
/// then moved into the validation split. The numeric variant uses the `*_numeric` registration
/// methods, which also record metric names for the summary.
macro_rules! gen_tuple {
    ($($M:ident),*) => {
        impl<$($M,)* LC: LearningComponentsTypes> TextMetricRegistration<LC> for ($($M,)*)
        where
            $($M: Metric + 'static,)*
            $(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
        {
            // Tuple elements are bound to the upper-case generic parameter names, hence the allow.
            #[allow(non_snake_case)]
            fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC> {
                let ($($M,)*) = self;
                builder
                    $(.metric_train($M.clone()))*
                    $(.metric_valid($M))*
            }
        }

        impl<$($M,)* LC: LearningComponentsTypes> MetricRegistration<LC> for ($($M,)*)
        where
            $($M: Metric + Numeric + 'static,)*
            $(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
        {
            // Tuple elements are bound to the upper-case generic parameter names, hence the allow.
            #[allow(non_snake_case)]
            fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC> {
                let ($($M,)*) = self;
                builder
                    $(.metric_train_numeric($M.clone()))*
                    $(.metric_valid_numeric($M))*
            }
        }
    };
}
468
469gen_tuple!(M1);
470gen_tuple!(M1, M2);
471gen_tuple!(M1, M2, M3);
472gen_tuple!(M1, M2, M3, M4);
473gen_tuple!(M1, M2, M3, M4, M5);
474gen_tuple!(M1, M2, M3, M4, M5, M6);