// burn_train/metric/processor/base.rs
use burn_core::data::dataloader::Progress;
use burn_optim::LearningRate;

use crate::{
    LearnerSummary,
    renderer::{EvaluationName, MetricsRenderer},
};
8
9pub enum LearnerEvent<T> {
11 Start,
13 ProcessedItem(TrainingItem<T>),
15 EndEpoch(usize),
17 End(Option<LearnerSummary>),
19}
20
21pub enum EvaluatorEvent<T> {
23 Start,
25 ProcessedItem(EvaluationName, EvaluationItem<T>),
27 End(Option<LearnerSummary>),
29}
30
/// An item that must be converted into another form before it is used.
///
/// The item itself and its converted form are both required to be [`Send`],
/// so either can be moved to another thread.
pub trait ItemLazy: Send {
    /// The type produced by [`ItemLazy::sync`].
    type ItemSync: Send;

    /// Consumes the item and returns its converted form.
    fn sync(self) -> Self::ItemSync;
}
41
42pub trait EventProcessorTraining<TrainEvent, ValidEvent>: Send {
44 fn process_train(&mut self, event: TrainEvent);
46 fn process_valid(&mut self, event: ValidEvent);
48 fn renderer(self) -> Box<dyn MetricsRenderer>;
50}
51
52pub trait EventProcessorEvaluation: Send {
54 type ItemTest: ItemLazy;
56
57 fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);
59
60 fn renderer(self) -> Box<dyn MetricsRenderer>;
62}
63
64#[derive(new)]
66pub struct TrainingItem<T> {
67 pub item: T,
69
70 pub progress: Progress,
72
73 pub global_progress: Progress,
75
76 pub iteration: Option<usize>,
78
79 pub lr: Option<LearningRate>,
81}
82
83impl<T: ItemLazy> ItemLazy for TrainingItem<T> {
84 type ItemSync = TrainingItem<T::ItemSync>;
85
86 fn sync(self) -> Self::ItemSync {
87 TrainingItem {
88 item: self.item.sync(),
89 progress: self.progress,
90 global_progress: self.global_progress,
91 iteration: self.iteration,
92 lr: self.lr,
93 }
94 }
95}
96
97#[derive(new)]
99pub struct EvaluationItem<T> {
100 pub item: T,
102
103 pub progress: Progress,
105
106 pub iteration: Option<usize>,
108}
109
110impl<T: ItemLazy> ItemLazy for EvaluationItem<T> {
111 type ItemSync = EvaluationItem<T::ItemSync>;
112
113 fn sync(self) -> Self::ItemSync {
114 EvaluationItem {
115 item: self.item.sync(),
116 progress: self.progress,
117 iteration: self.iteration,
118 }
119 }
120}
121
122impl ItemLazy for () {
123 type ItemSync = ();
124
125 fn sync(self) -> Self::ItemSync {}
126}