use std::sync::Arc;
use crate::{
LearnerSummary,
metric::{MetricDefinition, MetricEntry, NumericEntry},
};
use burn_core::data::dataloader::Progress;
/// Renderer for metrics and progress of the training and validation phases.
///
/// Implementors must be `Send + Sync`, so a renderer can be shared across threads.
pub trait MetricsRendererTraining: Send + Sync {
/// Receives a metric update for the training phase.
fn update_train(&mut self, state: MetricState);
/// Receives a metric update for the validation phase.
fn update_valid(&mut self, state: MetricState);
/// Renders training progress (`item`) together with additional progress indicators.
fn render_train(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);
/// Renders validation progress (`item`) together with additional progress indicators.
fn render_valid(&mut self, item: TrainingProgress, progress_indicators: Vec<ProgressType>);
/// End-of-training hook. NOTE(review): presumably called once when training finishes — confirm with callers.
///
/// The default implementation prints `summary` to stdout through its `Display` impl when it
/// is `Some`, prints nothing when it is `None`, and always returns `Ok(())`.
///
/// # Errors
///
/// The default implementation never fails; overriding implementations may return an error.
fn on_train_end(
&mut self,
summary: Option<LearnerSummary>,
) -> Result<(), Box<dyn core::error::Error>> {
default_summary_action(summary);
Ok(())
}
}
/// Full renderer: combines training/validation rendering ([`MetricsRendererTraining`]) with
/// evaluation rendering ([`MetricsRendererEvaluation`]).
pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {
/// Explicit close request. NOTE(review): what "closing" means is implementation-defined — confirm.
fn manual_close(&mut self);
/// Registers a metric definition with the renderer. NOTE(review): presumably expected before
/// updates for that metric arrive — confirm with callers.
fn register_metric(&mut self, definition: MetricDefinition);
}
/// Name attached to evaluation metric updates (see `MetricsRendererEvaluation::update_test`).
///
/// The string is stored behind an [`Arc`], so cloning a name only bumps a reference count.
#[derive(Clone, Debug)]
pub struct EvaluationName {
    pub(crate) name: Arc<String>,
}

impl EvaluationName {
    /// Creates a name from the `Display` rendering of `s`.
    pub fn new<S: core::fmt::Display>(s: S) -> Self {
        Self {
            name: Arc::new(s.to_string()),
        }
    }

    /// Returns the name as a string slice.
    pub fn as_str(&self) -> &str {
        self.name.as_str()
    }
}

impl core::fmt::Display for EvaluationName {
    /// Writes the name, honouring width/fill/alignment/precision flags exactly like `str`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `pad` (unlike `write_str`) applies `{:>10}`-style padding and `{:.3}`-style truncation,
        // so names line up correctly when renderers format them into columns.
        f.pad(self.as_str())
    }
}
/// Renderer for metrics and progress of evaluation (test) runs.
///
/// Implementors must be `Send + Sync`, so a renderer can be shared across threads.
pub trait MetricsRendererEvaluation: Send + Sync {
/// Receives a metric update for the evaluation identified by `name`.
fn update_test(&mut self, name: EvaluationName, state: MetricState);
/// Renders evaluation progress (`item`) together with additional progress indicators.
fn render_test(&mut self, item: EvaluationProgress, progress_indicators: Vec<ProgressType>);
/// End-of-evaluation hook. NOTE(review): presumably called once when testing finishes — confirm with callers.
///
/// The default implementation prints `summary` to stdout through its `Display` impl when it
/// is `Some`, prints nothing when it is `None`, and always returns `Ok(())`.
///
/// # Errors
///
/// The default implementation never fails; overriding implementations may return an error.
fn on_test_end(
&mut self,
summary: Option<LearnerSummary>,
) -> Result<(), Box<dyn core::error::Error>> {
default_summary_action(summary);
Ok(())
}
}
/// A metric update passed to the renderer `update_*` methods.
#[derive(Debug)]
pub enum MetricState {
/// Update carrying only a generic [`MetricEntry`].
Generic(MetricEntry),
/// Update carrying a [`MetricEntry`] together with its [`NumericEntry`] value.
Numeric(MetricEntry, NumericEntry),
}
/// Progress snapshot passed to `render_train` / `render_valid`.
#[derive(Debug)]
pub struct TrainingProgress {
/// Optional local progress; `None` when there is none (e.g. when converted from an
/// `EvaluationProgress`). NOTE(review): presumably progress within the current epoch — confirm.
pub progress: Option<Progress>,
/// Overall progress.
pub global_progress: Progress,
/// Current iteration, if known.
pub iteration: Option<usize>,
}
/// Progress snapshot passed to `render_test`.
#[derive(Debug)]
pub struct EvaluationProgress {
/// Progress of the evaluation.
pub progress: Progress,
/// Current iteration, if known.
pub iteration: Option<usize>,
}
impl From<&EvaluationProgress> for TrainingProgress {
fn from(value: &EvaluationProgress) -> Self {
TrainingProgress {
progress: None,
global_progress: value.progress.clone(),
iteration: value.iteration,
}
}
}
impl TrainingProgress {
    /// Returns an empty placeholder snapshot.
    ///
    /// There is no local progress and no iteration, and the global progress reports zero items
    /// processed out of zero total.
    pub fn none() -> Self {
        let no_items = Progress {
            items_processed: 0,
            items_total: 0,
        };

        Self {
            global_progress: no_items,
            iteration: None,
            progress: None,
        }
    }
}
/// An additional progress indicator passed to the `render_*` methods alongside the main
/// progress snapshot.
// `Debug` for consistency with the other public types in this module and to allow
// `{:?}` diagnostics of the indicator lists handed to renderers.
#[derive(Debug)]
pub enum ProgressType {
    /// Indicator backed by a full [`Progress`] (items processed out of a total).
    Detailed {
        /// Label identifying this indicator.
        tag: String,
        /// Current progress.
        progress: Progress,
    },
    /// Indicator reporting a single count.
    Value {
        /// Label identifying this indicator.
        tag: String,
        /// Current value.
        value: usize,
    },
}
/// Shared default for the `on_train_end` / `on_test_end` hooks.
///
/// Prints the summary to stdout using its `Display` impl. When there is no summary, nothing
/// is printed.
fn default_summary_action(summary: Option<LearnerSummary>) {
    let Some(report) = summary else {
        return;
    };
    println!("{report}");
}