1use crate::checkpoint::{
2 AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
3 KeepLastNCheckpoints, MetricCheckpointingStrategy,
4};
5use crate::components::{InferenceModelOutput, TrainingModelOutput};
6use crate::learner::EarlyStoppingStrategy;
7use crate::learner::base::Interrupter;
8use crate::logger::{FileMetricLogger, MetricLogger};
9use crate::metric::processor::{
10 AsyncProcessorTraining, FullEventProcessorTraining, ItemLazy, MetricsTraining,
11};
12use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
13use crate::metric::{Adaptor, LossMetric, Metric, Numeric};
14use crate::multi::MultiDeviceLearningStrategy;
15use crate::renderer::{MetricsRenderer, default_renderer};
16use crate::single::SingleDeviceTrainingStrategy;
17use crate::{
18 ApplicationLoggerInstaller, EarlyStoppingStrategyRef, ExecutionStrategy,
19 FileApplicationLoggerInstaller, InferenceBackend, InferenceModel, InferenceModelInput,
20 InferenceStep, LearnerEvent, LearnerModelRecord, LearnerOptimizerRecord,
21 LearnerSchedulerRecord, LearnerSummaryConfig, LearningCheckpointer, LearningComponentsMarker,
22 LearningComponentsTypes, LearningResult, TrainStep, TrainingBackend, TrainingComponents,
23 TrainingModelInput, TrainingStrategy,
24};
25use crate::{Learner, SupervisedLearningStrategy};
26use burn_core::data::dataloader::DataLoader;
27use burn_core::module::{AutodiffModule, Module};
28use burn_core::record::FileRecorder;
29use burn_core::tensor::backend::AutodiffBackend;
30use burn_optim::Optimizer;
31use burn_optim::lr_scheduler::LrScheduler;
32use std::collections::BTreeSet;
33use std::path::{Path, PathBuf};
34use std::sync::Arc;
35
/// Shared data loader over the training split, producing training-model inputs
/// on the autodiff (training) backend.
pub type TrainLoader<LC> = Arc<dyn DataLoader<TrainingBackend<LC>, TrainingModelInput<LC>>>;
/// Shared data loader over the validation split, producing inference-model
/// inputs on the inner (inference) backend.
pub type ValidLoader<LC> = Arc<dyn DataLoader<InferenceBackend<LC>, InferenceModelInput<LC>>>;
/// Asynchronous processor for learner events emitted by both the training and
/// the validation passes.
pub type SupervisedTrainingEventProcessor<LC> = AsyncProcessorTraining<
    LearnerEvent<TrainingModelOutput<LC>>,
    LearnerEvent<InferenceModelOutput<LC>>,
>;
45
/// Builder-style configuration for a supervised training session.
///
/// Constructed with [`SupervisedTraining::new`], refined through its builder
/// methods, and consumed by `launch`, which runs the training loop.
pub struct SupervisedTraining<LC>
where
    LC: LearningComponentsTypes,
{
    /// Optional asynchronous checkpointers for the model, optimizer and
    /// LR-scheduler records; all three are registered together (see
    /// `with_file_checkpointer`).
    #[allow(clippy::type_complexity)]
    checkpointers: Option<(
        AsyncCheckpointer<LearnerModelRecord<LC>, TrainingBackend<LC>>,
        AsyncCheckpointer<LearnerOptimizerRecord<LC>, TrainingBackend<LC>>,
        AsyncCheckpointer<LearnerSchedulerRecord<LC>, TrainingBackend<LC>>,
    )>,
    /// Number of epochs to run (defaults to 1).
    num_epochs: usize,
    /// Checkpoint to restore before training — presumably an epoch index;
    /// confirm against `TrainingComponents`.
    checkpoint: Option<usize>,
    /// Root directory for training artifacts (logs, metrics, checkpoints).
    directory: PathBuf,
    /// Number of batches to accumulate gradients over, if enabled.
    grad_accumulation: Option<usize>,
    /// Custom metrics renderer; `launch` falls back to `default_renderer`.
    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
    /// Metrics registered for the training and validation splits.
    metrics: MetricsTraining<TrainingModelOutput<LC>, InferenceModelOutput<LC>>,
    /// Metric-event loggers; a `FileMetricLogger` is added at launch if empty.
    event_store: LogEventStore,
    /// Cooperative handle used to interrupt the run.
    interrupter: Interrupter,
    /// Application-level logger installer (`experiment.log` by default);
    /// `None` disables application logging.
    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    /// Policy deciding which checkpoints are kept or deleted.
    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
    /// Optional early-stopping strategy.
    early_stopping: Option<EarlyStoppingStrategyRef>,
    /// Execution strategy; `launch` defaults to single-device when unset.
    training_strategy: Option<TrainingStrategy<LC>>,
    /// Data loader for the training split.
    dataloader_train: TrainLoader<LC>,
    /// Data loader for the validation split.
    dataloader_valid: ValidLoader<LC>,
    /// Names of numeric metrics to include in the end-of-run summary.
    summary_metrics: BTreeSet<String>,
    /// Whether to produce a summary at the end of training.
    summary: bool,
}
76
77impl<B, LR, M, O> SupervisedTraining<LearningComponentsMarker<B, LR, M, O>>
78where
79 B: AutodiffBackend,
80 LR: LrScheduler + 'static,
81 M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
82 M::InnerModule: InferenceStep,
83 O: Optimizer<M, B> + 'static,
84{
85 pub fn new(
93 directory: impl AsRef<Path>,
94 dataloader_train: Arc<dyn DataLoader<B, M::Input>>,
95 dataloader_valid: Arc<
96 dyn DataLoader<B::InnerBackend, <M::InnerModule as InferenceStep>::Input>,
97 >,
98 ) -> Self {
99 let directory = directory.as_ref().to_path_buf();
100 let experiment_log_file = directory.join("experiment.log");
101 Self {
102 num_epochs: 1,
103 checkpoint: None,
104 checkpointers: None,
105 directory,
106 grad_accumulation: None,
107 metrics: MetricsTraining::default(),
108 event_store: LogEventStore::default(),
109 renderer: None,
110 interrupter: Interrupter::new(),
111 tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
112 experiment_log_file,
113 ))),
114 checkpointer_strategy: Box::new(
115 ComposedCheckpointingStrategy::builder()
116 .add(KeepLastNCheckpoints::new(2))
117 .add(MetricCheckpointingStrategy::new(
118 &LossMetric::<B>::new(), Aggregate::Mean,
120 Direction::Lowest,
121 Split::Valid,
122 ))
123 .build(),
124 ),
125 early_stopping: None,
126 training_strategy: None,
127 summary_metrics: BTreeSet::new(),
128 summary: false,
129 dataloader_train,
130 dataloader_valid,
131 }
132 }
133}
134
135impl<LC: LearningComponentsTypes> SupervisedTraining<LC> {
136 pub fn with_training_strategy(mut self, training_strategy: TrainingStrategy<LC>) -> Self {
142 self.training_strategy = Some(training_strategy);
143 self
144 }
145
146 pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
152 where
153 ML: MetricLogger + 'static,
154 {
155 self.event_store.register_logger(logger);
156 self
157 }
158
159 pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
161 mut self,
162 strategy: CS,
163 ) -> Self {
164 self.checkpointer_strategy = Box::new(strategy);
165 self
166 }
167
168 pub fn renderer<MR>(mut self, renderer: MR) -> Self
174 where
175 MR: MetricsRenderer + 'static,
176 {
177 self.renderer = Some(Box::new(renderer));
178 self
179 }
180
181 pub fn metrics<Me: MetricRegistration<LC>>(self, metrics: Me) -> Self {
183 metrics.register(self)
184 }
185
186 pub fn metrics_text<Me: TextMetricRegistration<LC>>(self, metrics: Me) -> Self {
188 metrics.register(self)
189 }
190
191 pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
193 where
194 <TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
195 {
196 self.metrics.register_train_metric(metric);
197 self
198 }
199
200 pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
202 where
203 <InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
204 {
205 self.metrics.register_valid_metric(metric);
206 self
207 }
208
209 pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
220 self.grad_accumulation = Some(accumulation);
221 self
222 }
223
224 pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
226 where
227 Me: Metric + Numeric + 'static,
228 <TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
229 {
230 self.summary_metrics.insert(metric.name().to_string());
231 self.metrics.register_train_metric_numeric(metric);
232 self
233 }
234
235 pub fn metric_valid_numeric<Me: Metric + Numeric + 'static>(mut self, metric: Me) -> Self
237 where
238 <InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<Me::Input>,
239 {
240 self.summary_metrics.insert(metric.name().to_string());
241 self.metrics.register_valid_metric_numeric(metric);
242 self
243 }
244
245 pub fn num_epochs(mut self, num_epochs: usize) -> Self {
247 self.num_epochs = num_epochs;
248 self
249 }
250
251 pub fn checkpoint(mut self, checkpoint: usize) -> Self {
253 self.checkpoint = Some(checkpoint);
254 self
255 }
256
257 pub fn interrupter(&self) -> Interrupter {
259 self.interrupter.clone()
260 }
261
262 pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
264 self.interrupter = interrupter;
265 self
266 }
267
268 pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
271 where
272 Strategy: EarlyStoppingStrategy + Clone + Send + Sync + 'static,
273 {
274 self.early_stopping = Some(Box::new(strategy));
275 self
276 }
277
278 pub fn with_application_logger(
282 mut self,
283 logger: Option<Box<dyn ApplicationLoggerInstaller>>,
284 ) -> Self {
285 self.tracing_logger = logger;
286 self
287 }
288
289 pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
292 where
293 FR: FileRecorder<<LC as LearningComponentsTypes>::Backend> + 'static,
294 FR: FileRecorder<
295 <<LC as LearningComponentsTypes>::Backend as AutodiffBackend>::InnerBackend,
296 > + 'static,
297 {
298 let checkpoint_dir = self.directory.join("checkpoint");
299 let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
300 let checkpointer_optimizer =
301 FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
302 let checkpointer_scheduler: FileCheckpointer<FR> =
303 FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");
304
305 self.checkpointers = Some((
306 AsyncCheckpointer::new(checkpointer_model),
307 AsyncCheckpointer::new(checkpointer_optimizer),
308 AsyncCheckpointer::new(checkpointer_scheduler),
309 ));
310
311 self
312 }
313
314 pub fn summary(mut self) -> Self {
318 self.summary = true;
319 self
320 }
321}
322
323impl<LC> SupervisedTraining<LC>
324where
325 LC: LearningComponentsTypes + Send + 'static,
326{
327 pub fn launch(mut self, learner: Learner<LC>) -> LearningResult<InferenceModel<LC>> {
329 if self.tracing_logger.is_some()
330 && let Err(e) = self.tracing_logger.as_ref().unwrap().install()
331 {
332 log::warn!("Failed to install the experiment logger: {e}");
333 }
334 let renderer = self
335 .renderer
336 .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
337
338 if !self.event_store.has_loggers() {
339 self.event_store
340 .register_logger(FileMetricLogger::new(self.directory.clone()));
341 }
342
343 let event_store = Arc::new(EventStoreClient::new(self.event_store));
344 let event_processor = AsyncProcessorTraining::new(FullEventProcessorTraining::new(
345 self.metrics,
346 renderer,
347 event_store.clone(),
348 ));
349
350 let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
351 LearningCheckpointer::new(
352 model.with_interrupter(self.interrupter.clone()),
353 optim.with_interrupter(self.interrupter.clone()),
354 scheduler.with_interrupter(self.interrupter.clone()),
355 self.checkpointer_strategy,
356 )
357 });
358
359 let summary = if self.summary {
360 Some(LearnerSummaryConfig {
361 directory: self.directory,
362 metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
363 })
364 } else {
365 None
366 };
367
368 let components = TrainingComponents {
369 checkpoint: self.checkpoint,
370 checkpointer,
371 interrupter: self.interrupter,
372 early_stopping: self.early_stopping,
373 event_processor,
374 event_store,
375 num_epochs: self.num_epochs,
376 grad_accumulation: self.grad_accumulation,
377 summary,
378 };
379
380 let training_strategy = self.training_strategy.unwrap_or(TrainingStrategy::Default(
382 ExecutionStrategy::SingleDevice(learner.model.devices()[0].clone()),
383 ));
384
385 match training_strategy {
386 TrainingStrategy::Custom(learning_paradigm) => learning_paradigm.train(
387 learner,
388 self.dataloader_train,
389 self.dataloader_valid,
390 components,
391 ),
392 TrainingStrategy::Default(strategy) => match strategy {
393 ExecutionStrategy::SingleDevice(device) => {
394 let single_device: SingleDeviceTrainingStrategy<LC> =
395 SingleDeviceTrainingStrategy::new(device);
396 single_device.train(
397 learner,
398 self.dataloader_train,
399 self.dataloader_valid,
400 components,
401 )
402 }
403 ExecutionStrategy::MultiDevice(devices, multi_device_optim) => {
404 let strategy: Box<dyn SupervisedLearningStrategy<LC>> = match devices.len() == 1
405 {
406 true => Box::new(SingleDeviceTrainingStrategy::new(devices[0].clone())),
407 false => Box::new(MultiDeviceLearningStrategy::new(
408 devices,
409 multi_device_optim,
410 )),
411 };
412 strategy.train(
413 learner,
414 self.dataloader_train,
415 self.dataloader_valid,
416 components,
417 )
418 }
419 #[cfg(feature = "ddp")]
420 ExecutionStrategy::DistributedDataParallel { devices, runtime } => {
421 use crate::ddp::DdpTrainingStrategy;
422
423 let ddp = DdpTrainingStrategy::new(devices, runtime);
424 ddp.train(
425 learner,
426 self.dataloader_train,
427 self.dataloader_valid,
428 components,
429 )
430 }
431 },
432 }
433 }
434}
435
/// Registration of a group (tuple) of numeric metrics on both the training and
/// validation splits of a [`SupervisedTraining`] builder.
pub trait MetricRegistration<LC: LearningComponentsTypes>: Sized {
    /// Register `self`'s metrics on `builder` and return the updated builder.
    fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
}
441
/// Registration of a group (tuple) of text (non-numeric) metrics on both the
/// training and validation splits of a [`SupervisedTraining`] builder.
pub trait TextMetricRegistration<LC: LearningComponentsTypes>: Sized {
    /// Register `self`'s metrics on `builder` and return the updated builder.
    fn register(self, builder: SupervisedTraining<LC>) -> SupervisedTraining<LC>;
}
447
// Generates `TextMetricRegistration` and `MetricRegistration` implementations
// for a metric tuple of a given arity, so `builder.metrics((a, b))` (or
// `metrics_text`) registers every tuple element on both splits.
macro_rules! gen_tuple {
    ($($M:ident),*) => {
        // Text metrics: registered through `metric_train` / `metric_valid`.
        impl<$($M,)* LC: LearningComponentsTypes> TextMetricRegistration<LC> for ($($M,)*)
        where
            // Each metric's input must be adaptable from both the training and
            // the validation output item types.
            $(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + 'static,)*
        {
            // Tuple elements are destructured into bindings named after their
            // type parameters, hence the snake-case allowance.
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: SupervisedTraining<LC>,
            ) -> SupervisedTraining<LC> {
                let ($($M,)*) = self;
                // Clone each metric for the training split, then move the
                // original into the validation split.
                $(let builder = builder.metric_train($M.clone());)*
                $(let builder = builder.metric_valid($M);)*
                builder
            }
        }

        // Numeric metrics: registered through the `_numeric` variants, which
        // also record the metric names for the end-of-run summary.
        impl<$($M,)* LC: LearningComponentsTypes> MetricRegistration<LC> for ($($M,)*)
        where
            $(<TrainingModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $(<InferenceModelOutput<LC> as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: SupervisedTraining<LC>,
            ) -> SupervisedTraining<LC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_train_numeric($M.clone());)*
                $(let builder = builder.metric_valid_numeric($M);)*
                builder
            }
        }
    };
}
487
// Implement both registration traits for metric tuples of arity 1 through 6.
gen_tuple!(M1);
gen_tuple!(M1, M2);
gen_tuple!(M1, M2, M3);
gen_tuple!(M1, M2, M3, M4);
gen_tuple!(M1, M2, M3, M4, M5);
gen_tuple!(M1, M2, M3, M4, M5, M6);