// burn_train/renderer/base.rs

use std::sync::Arc;

use crate::{
    LearnerSummary,
    metric::{MetricDefinition, MetricEntry, NumericEntry},
};
use burn_core::data::dataloader::Progress;

9/// Trait for rendering metrics.
10pub trait MetricsRendererTraining: Send + Sync {
11    /// Updates the training metric state.
12    ///
13    /// # Arguments
14    ///
15    /// * `state` - The metric state.
16    fn update_train(&mut self, state: MetricState);
17
18    /// Updates the validation metric state.
19    ///
20    /// # Arguments
21    ///
22    /// * `state` - The metric state.
23    fn update_valid(&mut self, state: MetricState);
24
25    /// Renders the training progress.
26    ///
27    /// # Arguments
28    ///
29    /// * `item` - The training progress.
30    fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);
31
32    /// Renders the validation progress.
33    ///
34    /// # Arguments
35    ///
36    /// * `item` - The validation progress.
37    fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);
38
39    /// Callback method invoked when training ends, whether it
40    /// completed successfully or was interrupted.
41    ///
42    /// # Returns
43    ///
44    /// A result indicating whether the end-of-training actions were successful.
45    fn on_train_end(
46        &mut self,
47        summary: Option<LearnerSummary>,
48    ) -> Result<(), Box<dyn core::error::Error>> {
49        default_summary_action(summary);
50        Ok(())
51    }
52}
53
54/// A renderer that can be used for both training and evaluation.
55pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {
56    /// Keep the renderer from automatically closing, requiring manual action to close it.
57    fn manual_close(&mut self);
58    /// Register a new metric.
59    fn register_metric(&mut self, definition: MetricDefinition);
60}
61
/// The name of an evaluation.
///
/// This is going to group metrics together for easier analysis.
///
/// The name is stored behind an [`Arc`], so cloning shares the same string
/// allocation instead of copying it.
#[derive(Clone, Debug)]
pub struct EvaluationName {
    pub(crate) name: Arc<String>,
}

impl EvaluationName {
    /// Creates a new evaluation name from anything that implements
    /// [`Display`](core::fmt::Display).
    ///
    /// # Arguments
    ///
    /// * `s` - The value whose `Display` output becomes the name.
    pub fn new<S: core::fmt::Display>(s: S) -> Self {
        Self {
            name: Arc::new(s.to_string()),
        }
    }

    /// Returns the evaluation name as a string slice.
    pub fn as_str(&self) -> &str {
        self.name.as_str()
    }
}

impl core::fmt::Display for EvaluationName {
    /// Writes the evaluation name verbatim.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(self.as_str())
    }
}

90/// Trait for rendering metrics.
91pub trait MetricsRendererEvaluation: Send + Sync {
92    /// Updates the testing metric state.
93    ///
94    /// # Arguments
95    ///
96    /// * `state` - The metric state.
97    fn update_test(&mut self, name: EvaluationName, state: MetricState);
98    /// Renders the testing progress.
99    ///
100    /// # Arguments
101    ///
102    /// * `item` - The training progress.
103    fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec<ProgressType>);
104
105    /// Callback method invoked when testing ends, whether it
106    /// completed successfully or was interrupted.
107    ///
108    /// # Returns
109    ///
110    /// A result indicating whether the end-of-testing actions were successful.
111    fn on_test_end(
112        &mut self,
113        summary: Option<LearnerSummary>,
114    ) -> Result<(), Box<dyn core::error::Error>> {
115        default_summary_action(summary);
116        Ok(())
117    }
118}
119
120/// The state of a metric.
121#[derive(Debug)]
122pub enum MetricState {
123    /// A generic metric.
124    Generic(MetricEntry),
125    /// A numeric metric.
126    Numeric(MetricEntry, NumericEntry),
127}
128
129/// Training progress.
130#[derive(Debug)]
131pub struct TrainingProgress {
132    /// The progress.
133    pub progress: Option<Progress>,
134
135    /// The progress of the whole training.
136    pub global_progress: Progress,
137
138    /// The iteration, if it differs from the items processed.
139    pub iteration: Option<usize>,
140}
141
142/// Evaluation progress.
143#[derive(Debug)]
144pub struct EvaluationProgress {
145    /// The progress.
146    pub progress: Progress,
147
148    /// The iteration, if it is different from the processed items.
149    pub iteration: Option<usize>,
150}
151
152impl From<&EvaluationProgress> for TrainingProgress {
153    fn from(value: &EvaluationProgress) -> Self {
154        TrainingProgress {
155            progress: None,
156            global_progress: value.progress.clone(),
157            iteration: value.iteration,
158        }
159    }
160}
161
162impl TrainingProgress {
163    /// Creates a new empty training progress.
164    pub fn none() -> Self {
165        Self {
166            progress: None,
167            global_progress: Progress {
168                items_processed: 0,
169                items_total: 0,
170            },
171            iteration: None,
172        }
173    }
174}
175
176/// Type of progress indicators.
177pub enum ProgressType {
178    /// Detailed progress.
179    Detailed {
180        /// The tag.
181        tag: String,
182        /// The progress.
183        progress: Progress,
184    },
185    /// Simple value.
186    Value {
187        /// The tag.
188        tag: String,
189        /// The value.
190        value: usize,
191    },
192}
193
194fn default_summary_action(summary: Option<LearnerSummary>) {
195    if let Some(summary) = summary {
196        println!("{summary}");
197    }
198}