use burn_core::data::dataloader::Progress;
use burn_optim::LearningRate;
use crate::{
LearnerSummary,
renderer::{EvaluationName, MetricsRenderer},
};
/// Event emitted by the learner and handed to an [`EventProcessorTraining`]
/// through `process_train` / `process_valid`.
///
/// `T` is the per-item payload carried by [`LearnerEvent::ProcessedItem`].
pub enum LearnerEvent<T> {
    /// Start of the event stream (presumably emitted before any `ProcessedItem` — confirm
    /// against the emitter).
    Start,
    /// One item was processed; carries the payload together with its progress metadata.
    ProcessedItem(LearnerItem<T>),
    /// An epoch has finished. NOTE(review): the `usize` is presumably the number of the
    /// epoch that just ended — confirm against the emitter.
    EndEpoch(usize),
    /// End of the event stream, optionally carrying a [`LearnerSummary`].
    End(Option<LearnerSummary>),
}
/// Event emitted during evaluation and handed to an [`EventProcessorEvaluation`]
/// through `process_test`.
///
/// `T` is the per-item payload carried by [`EvaluatorEvent::ProcessedItem`].
pub enum EvaluatorEvent<T> {
    /// Start of the evaluation event stream.
    Start,
    /// One item was processed. The item is tagged with an [`EvaluationName`]
    /// (presumably identifying which evaluation it belongs to — confirm).
    ProcessedItem(EvaluationName, LearnerItem<T>),
    /// End of the evaluation event stream.
    End,
}
/// A value that is converted ("synced") into a concrete form before it is consumed.
///
/// Both the lazy value and its synced form are required to be `Send`.
pub trait ItemLazy: Send {
    /// The form produced by [`ItemLazy::sync`].
    type ItemSync: Send;
    /// Consumes `self` and returns its synced form. What syncing entails is up to each
    /// implementor (NOTE(review): presumably resolves deferred values — confirm).
    fn sync(self) -> Self::ItemSync;
}
/// Consumer of the learner's [`LearnerEvent`]s, with one stream for the training
/// phase and one for the validation phase.
pub trait EventProcessorTraining: Send {
    /// Payload type of training-phase events.
    type ItemTrain: ItemLazy;
    /// Payload type of validation-phase events.
    type ItemValid: ItemLazy;
    /// Handles one event from the training phase.
    fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>);
    /// Handles one event from the validation phase.
    fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>);
    /// Consumes the processor and returns its metrics renderer as a boxed trait object.
    fn renderer(self) -> Box<dyn MetricsRenderer>;
}
/// Consumer of the [`EvaluatorEvent`]s produced while evaluating/testing.
pub trait EventProcessorEvaluation: Send {
    /// Payload type of test events.
    type ItemTest: ItemLazy;
    /// Handles one evaluation event.
    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);
    /// Consumes the processor and returns its metrics renderer as a boxed trait object.
    fn renderer(self) -> Box<dyn MetricsRenderer>;
}
/// A payload produced by the learner, bundled with where it sits in the run.
///
/// `#[derive(new)]` generates a `new` constructor that takes the fields in declaration
/// order (assumes the `derive-new` macro is the one in scope — confirm).
#[derive(new)]
pub struct LearnerItem<T> {
    /// The payload itself.
    pub item: T,
    /// Progress information for the current pass.
    pub progress: Progress,
    /// Current epoch. NOTE(review): 0- vs 1-based numbering is not visible here.
    pub epoch: usize,
    /// Presumably the total number of epochs planned for the run — confirm.
    pub epoch_total: usize,
    /// Iteration counter. NOTE(review): per-epoch vs global is not visible here.
    pub iteration: usize,
    /// Learning rate associated with this item, if one was provided.
    pub lr: Option<LearningRate>,
}
/// Syncing a [`LearnerItem`] syncs only its payload. Every bookkeeping field
/// (progress, epoch counters, iteration, learning rate) is moved across unchanged.
impl<T: ItemLazy> ItemLazy for LearnerItem<T> {
    type ItemSync = LearnerItem<T::ItemSync>;

    fn sync(self) -> Self::ItemSync {
        // Take the item apart so each field is moved out explicitly; only `item` is transformed.
        let LearnerItem {
            item,
            progress,
            epoch,
            epoch_total,
            iteration,
            lr,
        } = self;

        LearnerItem {
            item: item.sync(),
            progress,
            epoch,
            epoch_total,
            iteration,
            lr,
        }
    }
}
/// The unit type syncs to itself: its synced form is also `()`.
impl ItemLazy for () {
    type ItemSync = ();

    fn sync(self) -> Self::ItemSync {
        self
    }
}