// burn_train/renderer/base.rs

1use burn_core::data::dataloader::Progress;
2
3use crate::metric::MetricEntry;
4
5/// Trait for rendering metrics.
6pub trait MetricsRenderer: Send + Sync {
7    /// Updates the training metric state.
8    ///
9    /// # Arguments
10    ///
11    /// * `state` - The metric state.
12    fn update_train(&mut self, state: MetricState);
13
14    /// Updates the validation metric state.
15    ///
16    /// # Arguments
17    ///
18    /// * `state` - The metric state.
19    fn update_valid(&mut self, state: MetricState);
20
21    /// Renders the training progress.
22    ///
23    /// # Arguments
24    ///
25    /// * `item` - The training progress.
26    fn render_train(&mut self, item: TrainingProgress);
27
28    /// Renders the validation progress.
29    ///
30    /// # Arguments
31    ///
32    /// * `item` - The validation progress.
33    fn render_valid(&mut self, item: TrainingProgress);
34
35    /// Callback method invoked when training ends, whether it
36    /// completed successfully or was interrupted.
37    ///
38    /// # Returns
39    ///
40    /// A result indicating whether the end-of-training actions were successful.
41    fn on_train_end(&mut self) -> Result<(), Box<dyn core::error::Error>> {
42        Ok(())
43    }
44}
45
46/// The state of a metric.
47#[derive(Debug)]
48pub enum MetricState {
49    /// A generic metric.
50    Generic(MetricEntry),
51
52    /// A numeric metric.
53    Numeric(MetricEntry, f64),
54}
55
56/// Training progress.
57#[derive(Debug)]
58pub struct TrainingProgress {
59    /// The progress.
60    pub progress: Progress,
61
62    /// The epoch.
63    pub epoch: usize,
64
65    /// The total number of epochs.
66    pub epoch_total: usize,
67
68    /// The iteration.
69    pub iteration: usize,
70}
71
72impl TrainingProgress {
73    /// Creates a new empty training progress.
74    pub fn none() -> Self {
75        Self {
76            progress: Progress {
77                items_processed: 0,
78                items_total: 0,
79            },
80            epoch: 0,
81            epoch_total: 0,
82            iteration: 0,
83        }
84    }
85}