burn_train/renderer/base.rs
1use std::sync::Arc;
2
3use crate::{
4 LearnerSummary,
5 metric::{MetricEntry, NumericEntry},
6};
7use burn_core::data::dataloader::Progress;
8
9/// Trait for rendering metrics.
10pub trait MetricsRendererTraining: Send + Sync {
11 /// Updates the training metric state.
12 ///
13 /// # Arguments
14 ///
15 /// * `state` - The metric state.
16 fn update_train(&mut self, state: MetricState);
17
18 /// Updates the validation metric state.
19 ///
20 /// # Arguments
21 ///
22 /// * `state` - The metric state.
23 fn update_valid(&mut self, state: MetricState);
24
25 /// Renders the training progress.
26 ///
27 /// # Arguments
28 ///
29 /// * `item` - The training progress.
30 fn render_train(&mut self, item: TrainingProgress);
31
32 /// Renders the validation progress.
33 ///
34 /// # Arguments
35 ///
36 /// * `item` - The validation progress.
37 fn render_valid(&mut self, item: TrainingProgress);
38
39 /// Callback method invoked when training ends, whether it
40 /// completed successfully or was interrupted.
41 ///
42 /// # Returns
43 ///
44 /// A result indicating whether the end-of-training actions were successful.
45 fn on_train_end(
46 &mut self,
47 summary: Option<LearnerSummary>,
48 ) -> Result<(), Box<dyn core::error::Error>> {
49 if let Some(summary) = summary {
50 println!("{summary}");
51 }
52 Ok(())
53 }
54}
55
56/// A renderer that can be used for both training and evaluation.
57pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {
58 /// Keep the renderer from automatically closing, requiring manual action to close it.
59 fn manual_close(&mut self);
60}
61
/// The name of an evaluation.
///
/// This is going to group metrics together for easier analysis.
/// Cloning is cheap: the underlying string is shared through an [`Arc`].
#[derive(Clone, Debug)]
pub struct EvaluationName {
    pub(crate) name: Arc<String>,
}

impl EvaluationName {
    /// Creates a new evaluation name from any value implementing
    /// [`Display`](core::fmt::Display).
    ///
    /// # Arguments
    ///
    /// * `s` - The value whose `Display` output becomes the evaluation name.
    pub fn new<S: core::fmt::Display>(s: S) -> Self {
        Self {
            name: Arc::new(s.to_string()),
        }
    }
}
78
79/// Trait for rendering metrics.
80pub trait MetricsRendererEvaluation: Send + Sync {
81 /// Updates the testing metric state.
82 ///
83 /// # Arguments
84 ///
85 /// * `state` - The metric state.
86 fn update_test(&mut self, name: EvaluationName, state: MetricState);
87 /// Renders the testing progress.
88 ///
89 /// # Arguments
90 ///
91 /// * `item` - The training progress.
92 fn render_test(&mut self, item: EvaluationProgress);
93
94 /// Callback method invoked when testing ends, whether it
95 /// completed successfully or was interrupted.
96 ///
97 /// # Returns
98 ///
99 /// A result indicating whether the end-of-testing actions were successful.
100 fn on_test_end(&mut self) -> Result<(), Box<dyn core::error::Error>> {
101 Ok(())
102 }
103}
104
105/// The state of a metric.
106#[derive(Debug)]
107pub enum MetricState {
108 /// A generic metric.
109 Generic(MetricEntry),
110 /// A numeric metric.
111 Numeric(MetricEntry, NumericEntry),
112}
113
114/// Training progress.
115#[derive(Debug)]
116pub struct TrainingProgress {
117 /// The progress.
118 pub progress: Progress,
119
120 /// The epoch.
121 pub epoch: usize,
122
123 /// The total number of epochs.
124 pub epoch_total: usize,
125
126 /// The iteration.
127 pub iteration: usize,
128}
129
130/// Evaluation progress.
131#[derive(Debug)]
132pub struct EvaluationProgress {
133 /// The progress.
134 pub progress: Progress,
135
136 /// The iteration.
137 pub iteration: usize,
138}
139
140impl TrainingProgress {
141 /// Creates a new empty training progress.
142 pub fn none() -> Self {
143 Self {
144 progress: Progress {
145 items_processed: 0,
146 items_total: 0,
147 },
148 epoch: 0,
149 epoch_total: 0,
150 iteration: 0,
151 }
152 }
153}