burn_train/renderer/base.rs
1use std::sync::Arc;
2
3use crate::{
4 LearnerSummary,
5 metric::{MetricDefinition, MetricEntry, NumericEntry},
6};
7use burn_core::data::dataloader::Progress;
8
9/// Trait for rendering metrics.
10pub trait MetricsRendererTraining: Send + Sync {
11 /// Updates the training metric state.
12 ///
13 /// # Arguments
14 ///
15 /// * `state` - The metric state.
16 fn update_train(&mut self, state: MetricState);
17
18 /// Updates the validation metric state.
19 ///
20 /// # Arguments
21 ///
22 /// * `state` - The metric state.
23 fn update_valid(&mut self, state: MetricState);
24
25 /// Renders the training progress.
26 ///
27 /// # Arguments
28 ///
29 /// * `item` - The training progress.
30 fn render_train(&mut self, item: TrainingProgress);
31
32 /// Renders the validation progress.
33 ///
34 /// # Arguments
35 ///
36 /// * `item` - The validation progress.
37 fn render_valid(&mut self, item: TrainingProgress);
38
39 /// Callback method invoked when training ends, whether it
40 /// completed successfully or was interrupted.
41 ///
42 /// # Returns
43 ///
44 /// A result indicating whether the end-of-training actions were successful.
45 fn on_train_end(
46 &mut self,
47 summary: Option<LearnerSummary>,
48 ) -> Result<(), Box<dyn core::error::Error>> {
49 if let Some(summary) = summary {
50 println!("{summary}");
51 }
52 Ok(())
53 }
54}
55
56/// A renderer that can be used for both training and evaluation.
57pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {
58 /// Keep the renderer from automatically closing, requiring manual action to close it.
59 fn manual_close(&mut self);
60 /// Register a new metric.
61 fn register_metric(&mut self, _definition: MetricDefinition);
62}
63
/// The name of an evaluation.
///
/// This is used to group metrics together for easier analysis. Cloning is cheap:
/// clones share the same underlying string through an [`Arc`].
#[derive(Clone)]
pub struct EvaluationName {
    pub(crate) name: Arc<String>,
}

impl EvaluationName {
    /// Creates a new evaluation name.
    ///
    /// # Arguments
    ///
    /// * `s` - Any value implementing [`Display`](core::fmt::Display); its formatted
    ///   output becomes the name.
    pub fn new<S: core::fmt::Display>(s: S) -> Self {
        Self {
            // `ToString` is blanket-implemented for `Display` types and produces the
            // same text as `format!("{s}")`.
            name: Arc::new(s.to_string()),
        }
    }
}
80
81/// Trait for rendering metrics.
82pub trait MetricsRendererEvaluation: Send + Sync {
83 /// Updates the testing metric state.
84 ///
85 /// # Arguments
86 ///
87 /// * `state` - The metric state.
88 fn update_test(&mut self, name: EvaluationName, state: MetricState);
89 /// Renders the testing progress.
90 ///
91 /// # Arguments
92 ///
93 /// * `item` - The training progress.
94 fn render_test(&mut self, item: EvaluationProgress);
95
96 /// Callback method invoked when testing ends, whether it
97 /// completed successfully or was interrupted.
98 ///
99 /// # Returns
100 ///
101 /// A result indicating whether the end-of-testing actions were successful.
102 fn on_test_end(&mut self) -> Result<(), Box<dyn core::error::Error>> {
103 Ok(())
104 }
105}
106
107/// The state of a metric.
108#[derive(Debug)]
109pub enum MetricState {
110 /// A generic metric.
111 Generic(MetricEntry),
112 /// A numeric metric.
113 Numeric(MetricEntry, NumericEntry),
114}
115
116/// Training progress.
117#[derive(Debug)]
118pub struct TrainingProgress {
119 /// The progress.
120 pub progress: Progress,
121
122 /// The epoch.
123 pub epoch: usize,
124
125 /// The total number of epochs.
126 pub epoch_total: usize,
127
128 /// The iteration.
129 pub iteration: usize,
130}
131
132/// Evaluation progress.
133#[derive(Debug)]
134pub struct EvaluationProgress {
135 /// The progress.
136 pub progress: Progress,
137
138 /// The iteration.
139 pub iteration: usize,
140}
141
142impl TrainingProgress {
143 /// Creates a new empty training progress.
144 pub fn none() -> Self {
145 Self {
146 progress: Progress {
147 items_processed: 0,
148 items_total: 0,
149 },
150 epoch: 0,
151 epoch_total: 0,
152 iteration: 0,
153 }
154 }
155}