use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;

use super::Learner;
use crate::checkpoint::{
    AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
    KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use crate::components::LearnerComponentsMarker;
use crate::learner::base::TrainingInterrupter;
use crate::learner::EarlyStoppingStrategy;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::processor::{AsyncProcessor, FullEventProcessor, ItemLazy, Metrics};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{default_renderer, MetricsRenderer};
use crate::{
    ApplicationLoggerInstaller, FileApplicationLoggerInstaller, LearnerCheckpointer,
    LearnerSummaryConfig,
};
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::AutodiffModule;
use burn_core::optim::Optimizer;
use burn_core::record::FileRecorder;
use burn_core::tensor::backend::AutodiffBackend;

/// Struct to configure and create a [`Learner`].
pub struct LearnerBuilder<B, T, V, M, O, S>
where
    T: ItemLazy + 'static,
    V: ItemLazy + 'static,
    B: AutodiffBackend,
    M: AutodiffModule<B>,
    O: Optimizer<M, B>,
    S: LrScheduler,
{
    #[allow(clippy::type_complexity)]
    checkpointers: Option<(
        AsyncCheckpointer<M::Record, B>,
        AsyncCheckpointer<O::Record, B>,
        AsyncCheckpointer<S::Record<B>, B>,
    )>,
    num_epochs: usize,
    checkpoint: Option<usize>,
    directory: PathBuf,
    grad_accumulation: Option<usize>,
    devices: Vec<B::Device>,
    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
    metrics: Metrics<T, V>,
    event_store: LogEventStore,
    interrupter: TrainingInterrupter,
    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    num_loggers: usize,
    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
    early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
    summary_metrics: HashSet<String>,
    summary: bool,
}

impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
where
    B: AutodiffBackend,
    T: ItemLazy + 'static,
    V: ItemLazy + 'static,
    M: AutodiffModule<B> + core::fmt::Display + 'static,
    O: Optimizer<M, B>,
    S: LrScheduler,
{
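    /// Creates a new learner builder. Logs, metric files and checkpoints are
    /// written under the given `directory`.
    ///
    /// A minimal usage sketch (not compiled; the model, the optimizer and the
    /// constant learning rate standing in for an [`LrScheduler`] are
    /// placeholders, not values provided by this builder):
    ///
    /// ```ignore
    /// let learner = LearnerBuilder::new("/tmp/my-experiment")
    ///     .metric_train_numeric(LossMetric::new())
    ///     .metric_valid_numeric(LossMetric::new())
    ///     .num_epochs(10)
    ///     .build(model, optimizer, 1e-3);
    /// ```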
    pub fn new(directory: impl AsRef<Path>) -> Self {
        let directory = directory.as_ref().to_path_buf();
        let experiment_log_file = directory.join("experiment.log");
        Self {
            num_epochs: 1,
            checkpoint: None,
            checkpointers: None,
            directory,
            grad_accumulation: None,
            devices: vec![B::Device::default()],
            metrics: Metrics::default(),
            event_store: LogEventStore::default(),
            renderer: None,
            interrupter: TrainingInterrupter::new(),
            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
                experiment_log_file,
            ))),
            num_loggers: 0,
            checkpointer_strategy: Box::new(
                ComposedCheckpointingStrategy::builder()
                    .add(KeepLastNCheckpoints::new(2))
                    .add(MetricCheckpointingStrategy::new::<LossMetric<B>>(
                        Aggregate::Mean,
                        Direction::Lowest,
                        Split::Valid,
                    ))
                    .build(),
            ),
            early_stopping: None,
            summary_metrics: HashSet::new(),
            summary: false,
        }
    }

    /// Replace the default metric loggers with the provided training and
    /// validation loggers.
    pub fn metric_loggers<MT, MV>(mut self, logger_train: MT, logger_valid: MV) -> Self
    where
        MT: MetricLogger + 'static,
        MV: MetricLogger + 'static,
    {
        self.event_store.register_logger_train(logger_train);
        self.event_store.register_logger_valid(logger_valid);
        self.num_loggers += 1;
        self
    }

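    /// Replace the default checkpointing strategy.
    ///
    /// As a sketch, the composed strategy below mirrors the default configured
    /// in [`Self::new`]: keep the last two checkpoints plus the checkpoint with
    /// the lowest mean validation loss.
    ///
    /// ```ignore
    /// let strategy = ComposedCheckpointingStrategy::builder()
    ///     .add(KeepLastNCheckpoints::new(2))
    ///     .add(MetricCheckpointingStrategy::new::<LossMetric<B>>(
    ///         Aggregate::Mean,
    ///         Direction::Lowest,
    ///         Split::Valid,
    ///     ))
    ///     .build();
    /// let builder = builder.with_checkpointing_strategy(strategy);
    /// ```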
    pub fn with_checkpointing_strategy<CS>(mut self, strategy: CS) -> Self
    where
        CS: CheckpointingStrategy + 'static,
    {
        self.checkpointer_strategy = Box::new(strategy);
        self
    }

    /// Replace the default metrics renderer with a custom implementation.
    pub fn renderer<MR>(mut self, renderer: MR) -> Self
    where
        MR: MetricsRenderer + 'static,
    {
        self.renderer = Some(Box::new(renderer));
        self
    }

    /// Register a training metric.
    pub fn metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
    where
        T::ItemSync: Adaptor<Me::Input>,
    {
        self.metrics.register_train_metric(metric);
        self
    }

    /// Register a validation metric.
    pub fn metric_valid<Me: Metric + 'static>(mut self, metric: Me) -> Self
    where
        V::ItemSync: Adaptor<Me::Input>,
    {
        self.metrics.register_valid_metric(metric);
        self
    }

    /// Enable gradient accumulation: gradients are accumulated over
    /// `accumulation` iterations before each optimizer step.
    pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
        self.grad_accumulation = Some(accumulation);
        self
    }

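    /// Register a numeric training metric. Numeric metrics can be rendered by
    /// the metrics renderer and are included in the training summary.
    ///
    /// A short sketch (not compiled; [`LossMetric`] is used here only because
    /// it is already referenced by this module):
    ///
    /// ```ignore
    /// let builder = builder
    ///     .metric_train_numeric(LossMetric::new())
    ///     .metric_valid_numeric(LossMetric::new());
    /// ```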
    pub fn metric_train_numeric<Me>(mut self, metric: Me) -> Self
    where
        Me: Metric + crate::metric::Numeric + 'static,
        T::ItemSync: Adaptor<Me::Input>,
    {
        self.summary_metrics.insert(Me::NAME.to_string());
        self.metrics.register_train_metric_numeric(metric);
        self
    }

    /// Register a numeric validation metric.
    pub fn metric_valid_numeric<Me: Metric + crate::metric::Numeric + 'static>(
        mut self,
        metric: Me,
    ) -> Self
    where
        V::ItemSync: Adaptor<Me::Input>,
    {
        self.summary_metrics.insert(Me::NAME.to_string());
        self.metrics.register_valid_metric_numeric(metric);
        self
    }

    /// The number of epochs the training should last.
    pub fn num_epochs(mut self, num_epochs: usize) -> Self {
        self.num_epochs = num_epochs;
        self
    }

    /// Run the training loop on the given devices.
    pub fn devices(mut self, devices: Vec<B::Device>) -> Self {
        self.devices = devices;
        self
    }

    /// The epoch from which the training must resume.
    pub fn checkpoint(mut self, checkpoint: usize) -> Self {
        self.checkpoint = Some(checkpoint);
        self
    }

    /// Provides a handle that can be used to interrupt the training loop from
    /// another thread.
    pub fn interrupter(&self) -> TrainingInterrupter {
        self.interrupter.clone()
    }

    /// Register an early stopping strategy that ends the training when its
    /// condition is met.
    pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
    where
        Strategy: EarlyStoppingStrategy + 'static,
    {
        self.early_stopping = Some(Box::new(strategy));
        self
    }

    /// By default, Rust logs are captured and written into `experiment.log`.
    /// Pass `None` to disable this behavior and fall back to standard log
    /// handling.
    pub fn with_application_logger(
        mut self,
        logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    ) -> Self {
        self.tracing_logger = logger;
        self
    }

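    /// Register a file checkpointer that saves the model, optimizer and learning
    /// rate scheduler records into separate files under `<directory>/checkpoint`.
    ///
    /// A sketch, assuming a [`FileRecorder`] implementation such as
    /// `burn_core::record::CompactRecorder` is available for the chosen backend:
    ///
    /// ```ignore
    /// let builder = builder.with_file_checkpointer(CompactRecorder::new());
    /// ```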
    pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
    where
        FR: FileRecorder<B> + 'static,
        FR: FileRecorder<B::InnerBackend> + 'static,
        O::Record: 'static,
        M::Record: 'static,
        S::Record<B>: 'static,
    {
        let checkpoint_dir = self.directory.join("checkpoint");
        let checkpointer_model = FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "model");
        let checkpointer_optimizer =
            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "optim");
        let checkpointer_scheduler: FileCheckpointer<FR> =
            FileCheckpointer::new(recorder, &checkpoint_dir, "scheduler");

        self.checkpointers = Some((
            AsyncCheckpointer::new(checkpointer_model),
            AsyncCheckpointer::new(checkpointer_optimizer),
            AsyncCheckpointer::new(checkpointer_scheduler),
        ));

        self
    }

    /// Enable the training summary report, displayed at the end of training.
    pub fn summary(mut self) -> Self {
        self.summary = true;
        self
    }

    /// Create the learner from the given model, optimizer and learning rate
    /// scheduler.
    #[allow(clippy::type_complexity)]
    pub fn build(
        mut self,
        model: M,
        optim: O,
        lr_scheduler: S,
    ) -> Learner<
        LearnerComponentsMarker<
            B,
            S,
            M,
            O,
            AsyncCheckpointer<M::Record, B>,
            AsyncCheckpointer<O::Record, B>,
            AsyncCheckpointer<S::Record<B>, B>,
            AsyncProcessor<FullEventProcessor<T, V>>,
            Box<dyn CheckpointingStrategy>,
        >,
    >
    where
        M::Record: 'static,
        O::Record: 'static,
        S::Record<B>: 'static,
    {
        if let Some(logger) = &self.tracing_logger {
            if let Err(e) = logger.install() {
                log::warn!("Failed to install the experiment logger: {}", e);
            }
        }
        let renderer = self
            .renderer
            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));

        if self.num_loggers == 0 {
            self.event_store
                .register_logger_train(FileMetricLogger::new(self.directory.join("train")));
            self.event_store
                .register_logger_valid(FileMetricLogger::new(self.directory.join("valid")));
        }

        let event_store = Arc::new(EventStoreClient::new(self.event_store));
        let event_processor = AsyncProcessor::new(FullEventProcessor::new(
            self.metrics,
            renderer,
            event_store.clone(),
        ));

        let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
            LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
        });

        let summary = if self.summary {
            Some(LearnerSummaryConfig {
                directory: self.directory,
                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
            })
        } else {
            None
        };

        Learner {
            model,
            optim,
            lr_scheduler,
            checkpointer,
            num_epochs: self.num_epochs,
            event_processor,
            event_store,
            checkpoint: self.checkpoint,
            grad_accumulation: self.grad_accumulation,
            devices: self.devices,
            interrupter: self.interrupter,
            early_stopping: self.early_stopping,
            summary,
        }
    }
}