1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
use burn_core::data::dataloader::Progress;
use crate::metric::MetricEntry;
/// Trait for rendering metrics.
pub trait MetricsRenderer: Send + Sync {
/// Updates the training metric state.
///
/// # Arguments
///
/// * `state` - The metric state.
fn update_train(&mut self, state: MetricState);
/// Updates the validation metric state.
///
/// # Arguments
///
/// * `state` - The metric state.
fn update_valid(&mut self, state: MetricState);
/// Renders the training progress.
///
/// # Arguments
///
/// * `item` - The training progress.
fn render_train(&mut self, item: TrainingProgress);
/// Renders the validation progress.
///
/// # Arguments
///
/// * `item` - The validation progress.
fn render_valid(&mut self, item: TrainingProgress);
}
/// The state of a metric.
#[derive(Debug)]
pub enum MetricState {
/// A generic metric.
Generic(MetricEntry),
/// A numeric metric.
Numeric(MetricEntry, f64),
}
/// Training progress.
#[derive(Debug)]
pub struct TrainingProgress {
/// The progress.
pub progress: Progress,
/// The epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
}
impl TrainingProgress {
/// Creates a new empty training progress.
pub fn none() -> Self {
Self {
progress: Progress {
items_processed: 0,
items_total: 0,
},
epoch: 0,
epoch_total: 0,
iteration: 0,
}
}
}