burn_train/metric/processor/
base.rs1use burn_core::data::dataloader::Progress;
2use burn_optim::LearningRate;
3
4use crate::{
5 LearnerSummary,
6 renderer::{EvaluationName, MetricsRenderer},
7};
8
9pub enum LearnerEvent<T> {
11 Start,
13 ProcessedItem(LearnerItem<T>),
15 EndEpoch(usize),
17 End(Option<LearnerSummary>),
19}
20
21pub enum EvaluatorEvent<T> {
23 ProcessedItem(EvaluationName, LearnerItem<T>),
25 End,
27}
28
/// An item that is not yet ready to be consumed and must first be synced.
///
/// Both the lazy item and its synced form are required to be [`Send`], so the
/// [`ItemLazy::sync`] conversion may run on a different thread than the one
/// that produced the item. NOTE(review): presumably this exists to keep the
/// sync work off the training loop — confirm with the callers.
pub trait ItemLazy: Send {
    /// The synced form of the item, ready to be processed.
    type ItemSync: Send;

    /// Consumes the lazy item and returns its synced form.
    fn sync(self) -> Self::ItemSync;
}
39
40pub trait EventProcessorTraining: Send {
42 type ItemTrain: ItemLazy;
44 type ItemValid: ItemLazy;
46
47 fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>);
49 fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>);
51 fn renderer(self) -> Box<dyn MetricsRenderer>;
53}
54
55pub trait EventProcessorEvaluation: Send {
57 type ItemTest: ItemLazy;
59
60 fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);
62
63 fn renderer(self) -> Box<dyn MetricsRenderer>;
65}
66
67#[derive(new)]
69pub struct LearnerItem<T> {
70 pub item: T,
72
73 pub progress: Progress,
75
76 pub epoch: usize,
78
79 pub epoch_total: usize,
81
82 pub iteration: usize,
84
85 pub lr: Option<LearningRate>,
87}
88
89impl<T: ItemLazy> ItemLazy for LearnerItem<T> {
90 type ItemSync = LearnerItem<T::ItemSync>;
91
92 fn sync(self) -> Self::ItemSync {
93 LearnerItem {
94 item: self.item.sync(),
95 progress: self.progress,
96 epoch: self.epoch,
97 epoch_total: self.epoch_total,
98 iteration: self.iteration,
99 lr: self.lr,
100 }
101 }
102}
103
104impl ItemLazy for () {
105 type ItemSync = ();
106
107 fn sync(self) -> Self::ItemSync {}
108}