burn_train/renderer/base.rs
1use burn_core::data::dataloader::Progress;
2
3use crate::metric::MetricEntry;
4
5/// Trait for rendering metrics.
6pub trait MetricsRenderer: Send + Sync {
7 /// Updates the training metric state.
8 ///
9 /// # Arguments
10 ///
11 /// * `state` - The metric state.
12 fn update_train(&mut self, state: MetricState);
13
14 /// Updates the validation metric state.
15 ///
16 /// # Arguments
17 ///
18 /// * `state` - The metric state.
19 fn update_valid(&mut self, state: MetricState);
20
21 /// Renders the training progress.
22 ///
23 /// # Arguments
24 ///
25 /// * `item` - The training progress.
26 fn render_train(&mut self, item: TrainingProgress);
27
28 /// Renders the validation progress.
29 ///
30 /// # Arguments
31 ///
32 /// * `item` - The validation progress.
33 fn render_valid(&mut self, item: TrainingProgress);
34
35 /// Callback method invoked when training ends, whether it
36 /// completed successfully or was interrupted.
37 ///
38 /// # Returns
39 ///
40 /// A result indicating whether the end-of-training actions were successful.
41 fn on_train_end(&mut self) -> Result<(), Box<dyn core::error::Error>> {
42 Ok(())
43 }
44}
45
46/// The state of a metric.
47#[derive(Debug)]
48pub enum MetricState {
49 /// A generic metric.
50 Generic(MetricEntry),
51
52 /// A numeric metric.
53 Numeric(MetricEntry, f64),
54}
55
56/// Training progress.
57#[derive(Debug)]
58pub struct TrainingProgress {
59 /// The progress.
60 pub progress: Progress,
61
62 /// The epoch.
63 pub epoch: usize,
64
65 /// The total number of epochs.
66 pub epoch_total: usize,
67
68 /// The iteration.
69 pub iteration: usize,
70}
71
72impl TrainingProgress {
73 /// Creates a new empty training progress.
74 pub fn none() -> Self {
75 Self {
76 progress: Progress {
77 items_processed: 0,
78 items_total: 0,
79 },
80 epoch: 0,
81 epoch_total: 0,
82 iteration: 0,
83 }
84 }
85}