burn_train/renderer/
base.rs1use std::sync::Arc;
2
3use crate::{
4 LearnerSummary,
5 metric::{MetricDefinition, MetricEntry, NumericEntry},
6};
7use burn_core::data::dataloader::Progress;
8
9pub trait MetricsRendererTraining: Send + Sync {
11 fn update_train(&mut self, state: MetricState);
17
18 fn update_valid(&mut self, state: MetricState);
24
25 fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);
31
32 fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);
38
39 fn on_train_end(
46 &mut self,
47 summary: Option<LearnerSummary>,
48 ) -> Result<(), Box<dyn core::error::Error>> {
49 default_summary_action(summary);
50 Ok(())
51 }
52}
53
54pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {
56 fn manual_close(&mut self);
58 fn register_metric(&mut self, definition: MetricDefinition);
60}
61
/// Name attached to an evaluation.
///
/// The text is stored behind an [`Arc`], so cloning only bumps a reference
/// count instead of copying the string.
#[derive(Clone, Debug)]
pub struct EvaluationName {
    /// Shared backing string holding the name.
    pub(crate) name: Arc<String>,
}

impl EvaluationName {
    /// Builds a name from any value implementing [`core::fmt::Display`].
    ///
    /// The value is rendered to a `String` once, at construction time.
    pub fn new<S: core::fmt::Display>(s: S) -> Self {
        Self {
            name: Arc::new(s.to_string()),
        }
    }

    /// Borrows the name as a string slice.
    pub fn as_str(&self) -> &str {
        self.name.as_str()
    }
}

impl core::fmt::Display for EvaluationName {
    /// Writes the name, honoring the formatter's width, fill, alignment and
    /// precision flags.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `pad` applies the formatter flags; a bare `write_str` would silently
        // ignore requests such as `{:>10}` or `{:.3}`.
        f.pad(self.as_str())
    }
}
89
90pub trait MetricsRendererEvaluation: Send + Sync {
92 fn update_test(&mut self, name: EvaluationName, state: MetricState);
98 fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec<ProgressType>);
104
105 fn on_test_end(
112 &mut self,
113 summary: Option<LearnerSummary>,
114 ) -> Result<(), Box<dyn core::error::Error>> {
115 default_summary_action(summary);
116 Ok(())
117 }
118}
119
120#[derive(Debug)]
122pub enum MetricState {
123 Generic(MetricEntry),
125 Numeric(MetricEntry, NumericEntry),
127}
128
129#[derive(Debug)]
131pub struct TrainingProgress {
132 pub progress: Option<Progress>,
134
135 pub global_progress: Progress,
137
138 pub iteration: Option<usize>,
140}
141
142#[derive(Debug)]
144pub struct EvaluationProgress {
145 pub progress: Progress,
147
148 pub iteration: Option<usize>,
150}
151
152impl From<&EvaluationProgress> for TrainingProgress {
153 fn from(value: &EvaluationProgress) -> Self {
154 TrainingProgress {
155 progress: None,
156 global_progress: value.progress.clone(),
157 iteration: value.iteration,
158 }
159 }
160}
161
162impl TrainingProgress {
163 pub fn none() -> Self {
165 Self {
166 progress: None,
167 global_progress: Progress {
168 items_processed: 0,
169 items_total: 0,
170 },
171 iteration: None,
172 }
173 }
174}
175
176pub enum ProgressType {
178 Detailed {
180 tag: String,
182 progress: Progress,
184 },
185 Value {
187 tag: String,
189 value: usize,
191 },
192}
193
194fn default_summary_action(summary: Option<LearnerSummary>) {
195 if let Some(summary) = summary {
196 println!("{summary}");
197 }
198}