burn_train/metric/processor/
base.rs1use burn_core::data::dataloader::Progress;
2use burn_optim::LearningRate;
3
4use crate::{
5 LearnerSummary,
6 renderer::{EvaluationName, MetricsRenderer},
7};
8
9pub enum LearnerEvent<T> {
11 Start,
13 ProcessedItem(LearnerItem<T>),
15 EndEpoch(usize),
17 End(Option<LearnerSummary>),
19}
20
21pub enum EvaluatorEvent<T> {
23 Start,
25 ProcessedItem(EvaluationName, LearnerItem<T>),
27 End,
29}
30
/// An item that can be converted ("synced") into an associated
/// [`ItemLazy::ItemSync`] form by consuming it.
///
/// Both the item itself and its synced form are required to be [`Send`], so
/// either may be moved across threads.
pub trait ItemLazy: Send {
    /// The synced form produced by [`ItemLazy::sync`].
    type ItemSync: Send;

    /// Consumes the item and returns its synced form.
    fn sync(self) -> Self::ItemSync;
}
41
42pub trait EventProcessorTraining: Send {
44 type ItemTrain: ItemLazy;
46 type ItemValid: ItemLazy;
48
49 fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>);
51 fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>);
53 fn renderer(self) -> Box<dyn MetricsRenderer>;
55}
56
57pub trait EventProcessorEvaluation: Send {
59 type ItemTest: ItemLazy;
61
62 fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);
64
65 fn renderer(self) -> Box<dyn MetricsRenderer>;
67}
68
69#[derive(new)]
71pub struct LearnerItem<T> {
72 pub item: T,
74
75 pub progress: Progress,
77
78 pub epoch: usize,
80
81 pub epoch_total: usize,
83
84 pub iteration: usize,
86
87 pub lr: Option<LearningRate>,
89}
90
91impl<T: ItemLazy> ItemLazy for LearnerItem<T> {
92 type ItemSync = LearnerItem<T::ItemSync>;
93
94 fn sync(self) -> Self::ItemSync {
95 LearnerItem {
96 item: self.item.sync(),
97 progress: self.progress,
98 epoch: self.epoch,
99 epoch_total: self.epoch_total,
100 iteration: self.iteration,
101 lr: self.lr,
102 }
103 }
104}
105
106impl ItemLazy for () {
107 type ItemSync = ();
108
109 fn sync(self) -> Self::ItemSync {}
110}