// burn-train 0.21.0-pre.3
//
// Training crate for the Burn framework.
// Documentation
use std::collections::HashMap;

use super::{Aggregate, Direction, Event, EventStore, Split, aggregate::NumericMetricsAggregate};
use crate::logger::MetricLogger;

/// Event store that forwards training events to a set of registered
/// [`MetricLogger`]s and answers metric queries through a
/// [`NumericMetricsAggregate`].
#[derive(Default)]
pub(crate) struct LogEventStore {
    /// Loggers that receive every event; populated via `register_logger`.
    loggers: Vec<Box<dyn MetricLogger>>,
    /// Aggregator that `find_epoch` and `find_metric` delegate to.
    aggregate: NumericMetricsAggregate,
    /// Current epoch number per split. A split starts at epoch 1 the first
    /// time an event arrives for it and is set to `epoch_number + 1` when an
    /// `Event::EndEpoch` is received.
    epochs: HashMap<Split, usize>,
}

impl EventStore for LogEventStore {
    /// Dispatches `event` for the given `split` to every registered logger.
    ///
    /// The epoch attached to metric updates is the one currently tracked for
    /// `split`; a split that has not been seen before starts at epoch 1.
    ///
    /// * `Event::MetricsInit` - each definition is logged to each logger.
    /// * `Event::MetricsUpdate` - the update is logged with the split's
    ///   current epoch.
    /// * `Event::EndEpoch` - the split's tracked epoch becomes
    ///   `summary.epoch_number + 1`, then the summary is logged.
    fn add_event(&mut self, event: Event, split: Split) {
        // Looked up (and initialised to 1 if absent) before dispatch, so an
        // update is always tagged with the epoch that is still in progress.
        let current_epoch = *self.epochs.entry(split.clone()).or_insert(1);

        match event {
            Event::MetricsInit(definitions) => {
                for definition in definitions.iter() {
                    for logger in self.loggers.iter_mut() {
                        logger.log_metric_definition(definition.clone());
                    }
                }
            }
            Event::MetricsUpdate(update) => {
                for logger in self.loggers.iter_mut() {
                    logger.log(update.clone(), current_epoch, &split);
                }
            }
            Event::EndEpoch(summary) => {
                // Advance the split's epoch before notifying the loggers.
                let next_epoch = summary.epoch_number + 1;
                self.epochs.insert(split, next_epoch);

                for logger in self.loggers.iter_mut() {
                    logger.log_epoch_summary(summary.clone());
                }
            }
        }
    }

    /// Delegates to the numeric aggregate's `find_epoch` for the metric
    /// `name` on `split`, handing it the registered loggers.
    fn find_epoch(
        &mut self,
        name: &str,
        aggregate: Aggregate,
        direction: Direction,
        split: &Split,
    ) -> Option<usize> {
        let loggers = &mut self.loggers;

        self.aggregate
            .find_epoch(name, split, aggregate, direction, loggers)
    }

    /// Delegates to the numeric aggregate's `aggregate` for the metric `name`
    /// at `epoch` on `split`, handing it the registered loggers.
    fn find_metric(
        &mut self,
        name: &str,
        epoch: usize,
        aggregate: Aggregate,
        split: &Split,
    ) -> Option<f64> {
        let loggers = &mut self.loggers;

        self.aggregate
            .aggregate(name, epoch, split, aggregate, loggers)
    }
}

impl LogEventStore {
    /// Boxes `logger` and appends it to the store's list of metric loggers.
    pub(crate) fn register_logger<ML: MetricLogger + 'static>(&mut self, logger: ML) {
        let boxed: Box<dyn MetricLogger> = Box::new(logger);
        self.loggers.push(boxed);
    }

    /// Returns `true` when at least one logger has been registered.
    pub(crate) fn has_loggers(&self) -> bool {
        let no_loggers = self.loggers.is_empty();
        !no_loggers
    }
}