// burn_train/renderer/base.rs

1use std::sync::Arc;
2
3use crate::{
4    LearnerSummary,
5    metric::{MetricEntry, NumericEntry},
6};
7use burn_core::data::dataloader::Progress;
8
9/// Trait for rendering metrics.
10pub trait MetricsRendererTraining: Send + Sync {
11    /// Updates the training metric state.
12    ///
13    /// # Arguments
14    ///
15    /// * `state` - The metric state.
16    fn update_train(&mut self, state: MetricState);
17
18    /// Updates the validation metric state.
19    ///
20    /// # Arguments
21    ///
22    /// * `state` - The metric state.
23    fn update_valid(&mut self, state: MetricState);
24
25    /// Renders the training progress.
26    ///
27    /// # Arguments
28    ///
29    /// * `item` - The training progress.
30    fn render_train(&mut self, item: TrainingProgress);
31
32    /// Renders the validation progress.
33    ///
34    /// # Arguments
35    ///
36    /// * `item` - The validation progress.
37    fn render_valid(&mut self, item: TrainingProgress);
38
39    /// Callback method invoked when training ends, whether it
40    /// completed successfully or was interrupted.
41    ///
42    /// # Returns
43    ///
44    /// A result indicating whether the end-of-training actions were successful.
45    fn on_train_end(
46        &mut self,
47        summary: Option<LearnerSummary>,
48    ) -> Result<(), Box<dyn core::error::Error>> {
49        if let Some(summary) = summary {
50            println!("{summary}");
51        }
52        Ok(())
53    }
54}
55
56/// A renderer that can be used for both training and evaluation.
57pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {
58    /// Keep the renderer from automatically closing, requiring manual action to close it.
59    fn manual_close(&mut self);
60}
61
/// The name of an evaluation.
///
/// Metrics reported under the same evaluation name are grouped together for easier
/// analysis. Cloning is cheap: the underlying string is shared through an [`Arc`].
#[derive(Clone, Debug)]
pub struct EvaluationName {
    pub(crate) name: Arc<String>,
}

impl EvaluationName {
    /// Creates a new evaluation name from any value implementing
    /// [`Display`](core::fmt::Display).
    ///
    /// # Arguments
    ///
    /// * `s` - The value whose formatted text becomes the name.
    pub fn new<S: core::fmt::Display>(s: S) -> Self {
        Self {
            name: Arc::new(s.to_string()),
        }
    }
}
78
79/// Trait for rendering metrics.
80pub trait MetricsRendererEvaluation: Send + Sync {
81    /// Updates the testing metric state.
82    ///
83    /// # Arguments
84    ///
85    /// * `state` - The metric state.
86    fn update_test(&mut self, name: EvaluationName, state: MetricState);
87    /// Renders the testing progress.
88    ///
89    /// # Arguments
90    ///
91    /// * `item` - The training progress.
92    fn render_test(&mut self, item: EvaluationProgress);
93
94    /// Callback method invoked when testing ends, whether it
95    /// completed successfully or was interrupted.
96    ///
97    /// # Returns
98    ///
99    /// A result indicating whether the end-of-testing actions were successful.
100    fn on_test_end(&mut self) -> Result<(), Box<dyn core::error::Error>> {
101        Ok(())
102    }
103}
104
105/// The state of a metric.
106#[derive(Debug)]
107pub enum MetricState {
108    /// A generic metric.
109    Generic(MetricEntry),
110    /// A numeric metric.
111    Numeric(MetricEntry, NumericEntry),
112}
113
114/// Training progress.
115#[derive(Debug)]
116pub struct TrainingProgress {
117    /// The progress.
118    pub progress: Progress,
119
120    /// The epoch.
121    pub epoch: usize,
122
123    /// The total number of epochs.
124    pub epoch_total: usize,
125
126    /// The iteration.
127    pub iteration: usize,
128}
129
130/// Evaluation progress.
131#[derive(Debug)]
132pub struct EvaluationProgress {
133    /// The progress.
134    pub progress: Progress,
135
136    /// The iteration.
137    pub iteration: usize,
138}
139
140impl TrainingProgress {
141    /// Creates a new empty training progress.
142    pub fn none() -> Self {
143        Self {
144            progress: Progress {
145                items_processed: 0,
146                items_total: 0,
147            },
148            epoch: 0,
149            epoch_total: 0,
150            iteration: 0,
151        }
152    }
153}