use std::sync::Arc;
use crate::{
LearnerSummary,
metric::{MetricDefinition, MetricEntry, NumericEntry},
};
use burn_core::data::dataloader::Progress;
/// Renderer for metric updates and progress reports emitted while training and validating.
///
/// Implementors must be `Send + Sync`.
pub trait MetricsRendererTraining: Send + Sync {
    /// Receives an updated metric state for the training phase.
    fn update_train(&mut self, state: MetricState);

    /// Receives an updated metric state for the validation phase.
    fn update_valid(&mut self, state: MetricState);

    /// Renders a progress report for the training phase.
    fn render_train(&mut self, item: TrainingProgress);

    /// Renders a progress report for the validation phase.
    fn render_valid(&mut self, item: TrainingProgress);

    /// Hook for the end of training.
    ///
    /// The default implementation prints `summary` to stdout using its `Display`
    /// implementation when one is supplied, and prints nothing otherwise.
    ///
    /// # Errors
    ///
    /// Implementations may report a failure; the default implementation always
    /// returns `Ok(())`.
    fn on_train_end(
        &mut self,
        summary: Option<LearnerSummary>,
    ) -> Result<(), Box<dyn core::error::Error>> {
        // Nothing to print without a summary.
        let Some(summary) = summary else {
            return Ok(());
        };
        println!("{summary}");
        Ok(())
    }
}
/// Renderer that handles both training and evaluation output.
///
/// Combines [`MetricsRendererEvaluation`] and [`MetricsRendererTraining`] and adds
/// renderer-level operations.
pub trait MetricsRenderer: MetricsRendererEvaluation + MetricsRendererTraining {
    /// Closes the renderer manually.
    // NOTE(review): the precise semantics (e.g. whether this waits for user input
    // before returning) are not visible here — confirm against implementors.
    fn manual_close(&mut self);

    /// Registers a metric definition with the renderer.
    // NOTE(review): presumably called before updates for that metric arrive — confirm.
    // The leading underscore in `_definition` has no effect on a required method
    // without a body; implementors may name the parameter freely.
    fn register_metric(&mut self, _definition: MetricDefinition);
}
/// Name identifying an evaluation, passed to [`MetricsRendererEvaluation::update_test`].
///
/// The string is stored behind an [`Arc`], so cloning an `EvaluationName` only bumps a
/// reference count instead of copying the text.
// `Debug` is derived for consistency with the other public types in this module
// (`MetricState`, `TrainingProgress`, `EvaluationProgress`).
#[derive(Clone, Debug)]
pub struct EvaluationName {
    pub(crate) name: Arc<String>,
}

impl EvaluationName {
    /// Creates a new evaluation name from any value implementing
    /// [`Display`](core::fmt::Display).
    ///
    /// The value is rendered once with its `Display` implementation, and the
    /// resulting string is stored.
    pub fn new<S: core::fmt::Display>(s: S) -> Self {
        Self {
            name: Arc::new(s.to_string()),
        }
    }
}
/// Renderer for metric updates and progress reports emitted during evaluation.
///
/// Implementors must be `Send + Sync`.
pub trait MetricsRendererEvaluation: Send + Sync {
    /// Receives an updated metric state for the evaluation identified by `name`.
    fn update_test(&mut self, name: EvaluationName, state: MetricState);
    /// Renders a progress report for evaluation.
    fn render_test(&mut self, item: EvaluationProgress);
    /// Hook for the end of evaluation.
    ///
    /// The default implementation does nothing.
    ///
    /// # Errors
    ///
    /// Implementations may report a failure; the default implementation always
    /// returns `Ok(())`.
    fn on_test_end(&mut self) -> Result<(), Box<dyn core::error::Error>> {
        Ok(())
    }
}
/// Metric update delivered to a renderer via `update_train`, `update_valid`, or
/// `update_test`.
#[derive(Debug)]
pub enum MetricState {
    /// A metric carrying only a [`MetricEntry`].
    Generic(MetricEntry),
    /// A metric carrying a [`MetricEntry`] together with a [`NumericEntry`].
    Numeric(MetricEntry, NumericEntry),
}
/// Progress report rendered during training and validation (see `render_train` and
/// `render_valid`).
#[derive(Debug)]
pub struct TrainingProgress {
    /// Data-loader progress (items processed / items total).
    // NOTE(review): presumably scoped to the current epoch — confirm.
    pub progress: Progress,
    /// Current epoch.
    // NOTE(review): whether this is 0- or 1-based is not visible here — confirm.
    pub epoch: usize,
    /// Total number of epochs.
    pub epoch_total: usize,
    /// Current iteration counter.
    // NOTE(review): unclear whether this is per-epoch or global — confirm with the producer.
    pub iteration: usize,
}
/// Progress report rendered during evaluation (see `render_test`).
#[derive(Debug)]
pub struct EvaluationProgress {
    /// Data-loader progress (items processed / items total).
    pub progress: Progress,
    /// Current iteration counter.
    // NOTE(review): the base and scope of this counter are not visible here — confirm.
    pub iteration: usize,
}
impl TrainingProgress {
    /// Returns a placeholder progress report in which every counter is zero.
    ///
    /// This covers `epoch`, `epoch_total`, `iteration`, and both fields of the inner
    /// [`Progress`] (`items_processed` and `items_total`).
    pub fn none() -> Self {
        let empty = Progress {
            items_processed: 0,
            items_total: 0,
        };

        Self {
            epoch: 0,
            epoch_total: 0,
            iteration: 0,
            progress: empty,
        }
    }
}