burn_train/renderer/
base.rs

1use std::sync::Arc;
2
3use crate::{
4    LearnerSummary,
5    metric::{MetricDefinition, MetricEntry, NumericEntry},
6};
7use burn_core::data::dataloader::Progress;
8
9/// Trait for rendering metrics.
10pub trait MetricsRendererTraining: Send + Sync {
11    /// Updates the training metric state.
12    ///
13    /// # Arguments
14    ///
15    /// * `state` - The metric state.
16    fn update_train(&mut self, state: MetricState);
17
18    /// Updates the validation metric state.
19    ///
20    /// # Arguments
21    ///
22    /// * `state` - The metric state.
23    fn update_valid(&mut self, state: MetricState);
24
25    /// Renders the training progress.
26    ///
27    /// # Arguments
28    ///
29    /// * `item` - The training progress.
30    fn render_train(&mut self, item: TrainingProgress);
31
32    /// Renders the validation progress.
33    ///
34    /// # Arguments
35    ///
36    /// * `item` - The validation progress.
37    fn render_valid(&mut self, item: TrainingProgress);
38
39    /// Callback method invoked when training ends, whether it
40    /// completed successfully or was interrupted.
41    ///
42    /// # Returns
43    ///
44    /// A result indicating whether the end-of-training actions were successful.
45    fn on_train_end(
46        &mut self,
47        summary: Option<LearnerSummary>,
48    ) -> Result<(), Box<dyn core::error::Error>> {
49        if let Some(summary) = summary {
50            println!("{summary}");
51        }
52        Ok(())
53    }
54}
55
56/// A renderer that can be used for both training and evaluation.
57pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {
58    /// Keep the renderer from automatically closing, requiring manual action to close it.
59    fn manual_close(&mut self);
60    /// Register a new metric.
61    fn register_metric(&mut self, _definition: MetricDefinition);
62}
63
/// The name of an evaluation.
///
/// Metrics reported under the same evaluation name are grouped together for
/// easier analysis. Cloning is cheap: it only bumps the reference count of
/// the shared `Arc`.
#[derive(Clone, Debug)]
pub struct EvaluationName {
    pub(crate) name: Arc<String>,
}
71
72impl EvaluationName {
73    /// Creates a new metric name.
74    pub fn new<S: core::fmt::Display>(s: S) -> Self {
75        Self {
76            name: Arc::new(format!("{s}")),
77        }
78    }
79}
80
81/// Trait for rendering metrics.
82pub trait MetricsRendererEvaluation: Send + Sync {
83    /// Updates the testing metric state.
84    ///
85    /// # Arguments
86    ///
87    /// * `state` - The metric state.
88    fn update_test(&mut self, name: EvaluationName, state: MetricState);
89    /// Renders the testing progress.
90    ///
91    /// # Arguments
92    ///
93    /// * `item` - The training progress.
94    fn render_test(&mut self, item: EvaluationProgress);
95
96    /// Callback method invoked when testing ends, whether it
97    /// completed successfully or was interrupted.
98    ///
99    /// # Returns
100    ///
101    /// A result indicating whether the end-of-testing actions were successful.
102    fn on_test_end(&mut self) -> Result<(), Box<dyn core::error::Error>> {
103        Ok(())
104    }
105}
106
107/// The state of a metric.
108#[derive(Debug)]
109pub enum MetricState {
110    /// A generic metric.
111    Generic(MetricEntry),
112    /// A numeric metric.
113    Numeric(MetricEntry, NumericEntry),
114}
115
116/// Training progress.
117#[derive(Debug)]
118pub struct TrainingProgress {
119    /// The progress.
120    pub progress: Progress,
121
122    /// The epoch.
123    pub epoch: usize,
124
125    /// The total number of epochs.
126    pub epoch_total: usize,
127
128    /// The iteration.
129    pub iteration: usize,
130}
131
132/// Evaluation progress.
133#[derive(Debug)]
134pub struct EvaluationProgress {
135    /// The progress.
136    pub progress: Progress,
137
138    /// The iteration.
139    pub iteration: usize,
140}
141
142impl TrainingProgress {
143    /// Creates a new empty training progress.
144    pub fn none() -> Self {
145        Self {
146            progress: Progress {
147                items_processed: 0,
148                items_total: 0,
149            },
150            epoch: 0,
151            epoch_total: 0,
152            iteration: 0,
153        }
154    }
155}