use burn_core::data::dataloader::Progress;
use burn_optim::LearningRate;
use crate::{
LearnerSummary,
renderer::{EvaluationName, MetricsRenderer},
};
/// Event emitted while a learner runs, generic over the per-item payload `T`.
pub enum LearnerEvent<T> {
    /// Start marker; carries no payload.
    Start,
    /// One processed item, wrapped with its progress, iteration and
    /// learning-rate metadata (see [`TrainingItem`]).
    ProcessedItem(TrainingItem<T>),
    /// End-of-epoch marker.
    // NOTE(review): the `usize` is presumably the index of the epoch that just
    // finished — confirm against the code that emits this event.
    EndEpoch(usize),
    /// End marker, optionally carrying a [`LearnerSummary`].
    End(Option<LearnerSummary>),
}
/// Event emitted during evaluation, generic over the per-item payload `T`.
pub enum EvaluatorEvent<T> {
    /// Start marker; carries no payload.
    Start,
    /// One processed item (see [`EvaluationItem`]) paired with an [`EvaluationName`].
    // NOTE(review): the name presumably identifies which evaluation produced the
    // item — confirm with the emitter.
    ProcessedItem(EvaluationName, EvaluationItem<T>),
    /// End marker, optionally carrying a [`LearnerSummary`].
    End(Option<LearnerSummary>),
}
/// A value that is converted into a "synchronized" form by consuming it.
///
/// Both the implementor and its synced form are required to be [`Send`].
// NOTE(review): the name suggests the unsynced form defers some work until
// `sync` is called — confirm with the concrete implementors.
pub trait ItemLazy: Send {
    /// The type produced by [`ItemLazy::sync`].
    type ItemSync: Send;
    /// Consumes `self` and returns its synchronized form.
    fn sync(self) -> Self::ItemSync;
}
/// Receives the events produced by the training and validation phases.
///
/// `TrainEvent` and `ValidEvent` are the event types of the two phases; this
/// trait places no bounds on them.
pub trait EventProcessorTraining<TrainEvent, ValidEvent>: Send {
    /// Handles a single training-phase event.
    fn process_train(&mut self, event: TrainEvent);
    /// Handles a single validation-phase event.
    fn process_valid(&mut self, event: ValidEvent);
    /// Consumes the processor and returns a boxed [`MetricsRenderer`].
    // NOTE(review): presumably the renderer this processor was feeding — confirm
    // against implementors.
    fn renderer(self) -> Box<dyn MetricsRenderer>;
}
/// Receives the [`EvaluatorEvent`]s produced while evaluating.
pub trait EventProcessorEvaluation: Send {
    /// Payload type carried by the [`EvaluatorEvent`]s this processor accepts.
    type ItemTest: ItemLazy;
    /// Handles a single evaluation event.
    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);
    /// Consumes the processor and returns a boxed [`MetricsRenderer`].
    fn renderer(self) -> Box<dyn MetricsRenderer>;
}
/// A training payload together with its progress metadata.
///
/// `#[derive(new)]` generates a `new` constructor; with the `derive-new` crate it
/// takes every field in declaration order, so the field order below is part of
/// the constructor's signature — TODO confirm which `new` derive is in scope.
#[derive(new)]
pub struct TrainingItem<T> {
    /// The payload itself.
    pub item: T,
    /// Progress counter attached to this item.
    // NOTE(review): presumably progress within the current epoch, as opposed to
    // `global_progress` — confirm with the producer.
    pub progress: Progress,
    /// Second progress counter, presumably spanning the whole run — confirm.
    pub global_progress: Progress,
    /// Optional iteration number.
    pub iteration: Option<usize>,
    /// Optional learning rate associated with this item.
    pub lr: Option<LearningRate>,
}
/// A [`TrainingItem`] is lazy whenever its payload is; syncing it syncs the
/// payload and keeps all metadata as-is.
impl<T: ItemLazy> ItemLazy for TrainingItem<T> {
    type ItemSync = TrainingItem<T::ItemSync>;

    /// Runs [`ItemLazy::sync`] on the wrapped payload and carries the progress
    /// counters, iteration and learning rate over unchanged.
    fn sync(self) -> Self::ItemSync {
        // Take the struct apart up front so every field is moved exactly once.
        let Self {
            item,
            progress,
            global_progress,
            iteration,
            lr,
        } = self;
        let synced_item = item.sync();

        TrainingItem {
            item: synced_item,
            progress,
            global_progress,
            iteration,
            lr,
        }
    }
}
/// An evaluation payload together with its progress metadata.
///
/// `#[derive(new)]` generates a `new` constructor; with the `derive-new` crate it
/// takes every field in declaration order, so the field order below is part of
/// the constructor's signature — TODO confirm which `new` derive is in scope.
#[derive(new)]
pub struct EvaluationItem<T> {
    /// The payload itself.
    pub item: T,
    /// Progress counter attached to this item.
    pub progress: Progress,
    /// Optional iteration number.
    pub iteration: Option<usize>,
}
/// An [`EvaluationItem`] is lazy whenever its payload is; syncing it syncs the
/// payload and keeps all metadata as-is.
impl<T: ItemLazy> ItemLazy for EvaluationItem<T> {
    type ItemSync = EvaluationItem<T::ItemSync>;

    /// Runs [`ItemLazy::sync`] on the wrapped payload and carries the progress
    /// counter and iteration over unchanged.
    fn sync(self) -> Self::ItemSync {
        // Take the struct apart up front so every field is moved exactly once.
        let Self {
            item,
            progress,
            iteration,
        } = self;
        let synced_item = item.sync();

        EvaluationItem {
            item: synced_item,
            progress,
            iteration,
        }
    }
}
/// The unit type is trivially "synced": it carries no data to convert.
impl ItemLazy for () {
    type ItemSync = ();

    /// Hands the unit value straight back.
    fn sync(self) -> Self::ItemSync {
        self
    }
}