burn-train 0.21.0-pre.4

Training crate for the Burn framework
Documentation
use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining};
use crate::{
    metric::store::{EpochSummary, EventStoreClient, Split},
    renderer::cli::CliMetricsRenderer,
};
use std::sync::Arc;

/// A minimal [event processor](EventProcessorTraining) that handles:
///   - Computing and storing metrics in an [event store](crate::metric::store::EventStore).
///
/// The processor holds no renderer of its own: it only keeps the metric state and a
/// client through which metric events are sent to the store.
// NOTE(review): `dead_code` is allowed presumably because this processor is not
// constructed in every build configuration — confirm and document the reason.
#[allow(dead_code)]
// NOTE(review): `new` is presumably the `derive-new` macro, which would generate
// `new(metrics, store)` with parameters in field declaration order — keep the
// field order stable.
#[derive(new)]
pub(crate) struct MinimalEventProcessor<T: ItemLazy, V: ItemLazy> {
    /// Metric state, updated for every processed item and at the end of each epoch.
    metrics: MetricsTraining<T, V>,
    /// Shared client used to send metric events for the train and valid splits.
    store: Arc<EventStoreClient>,
}

// Routes learner events to the metric state and forwards the resulting metric
// events to the event store, separately for the train and valid splits.
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>
    for MinimalEventProcessor<T, V>
{
    /// Handles one event emitted for the training split.
    ///
    /// - `Start`: publishes the metric definitions to the store.
    /// - `ProcessedItem`: syncs the item, updates the training metrics and
    ///   publishes the update.
    /// - `EndEpoch`: closes the epoch on the metrics, then publishes an epoch
    ///   summary tagged with [`Split::Train`].
    /// - `End`: currently ignored.
    fn process_train(&mut self, event: LearnerEvent<T>) {
        match event {
            LearnerEvent::Start => {
                let init =
                    crate::metric::store::Event::MetricsInit(self.metrics.metric_definitions());
                self.store.add_event_train(init);
            }

            LearnerEvent::ProcessedItem(lazy_item) => {
                let synced = lazy_item.sync();
                let metadata = (&synced).into();

                let metrics_update = self.metrics.update_train(&synced, &metadata);
                let update_event = crate::metric::store::Event::MetricsUpdate(metrics_update);

                self.store.add_event_train(update_event);
            }
            LearnerEvent::EndEpoch(epoch) => {
                // The metrics must be told about the epoch end before the summary is sent.
                self.metrics.end_epoch_train();

                let summary = EpochSummary::new(epoch, Split::Train);
                self.store
                    .add_event_train(crate::metric::store::Event::EndEpoch(summary));
            }
            LearnerEvent::End(_summary) => {} // no-op for now
        }
    }

    /// Handles one event emitted for the validation split.
    ///
    /// - `ProcessedItem`: syncs the item, updates the validation metrics and
    ///   publishes the update.
    /// - `EndEpoch`: closes the epoch on the metrics, then publishes an epoch
    ///   summary tagged with [`Split::Valid`].
    /// - `Start` and `End`: currently ignored.
    fn process_valid(&mut self, event: LearnerEvent<V>) {
        match event {
            LearnerEvent::ProcessedItem(lazy_item) => {
                let synced = lazy_item.sync();
                let metadata = (&synced).into();

                let metrics_update = self.metrics.update_valid(&synced, &metadata);
                let update_event = crate::metric::store::Event::MetricsUpdate(metrics_update);

                self.store.add_event_valid(update_event);
            }
            LearnerEvent::EndEpoch(epoch) => {
                // The metrics must be told about the epoch end before the summary is sent.
                self.metrics.end_epoch_valid();

                let summary = EpochSummary::new(epoch, Split::Valid);
                self.store
                    .add_event_valid(crate::metric::store::Event::EndEpoch(summary));
            }
            // no-op for now
            LearnerEvent::Start | LearnerEvent::End(_) => {}
        }
    }

    /// Consumes the processor and returns a freshly created CLI renderer.
    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
        // TODO: Check for another default.
        let cli_renderer = CliMetricsRenderer::new();
        Box::new(cli_renderer)
    }
}