Skip to main content

burn_train/metric/processor/
base.rs

1use burn_core::data::dataloader::Progress;
2use burn_optim::LearningRate;
3
4use crate::{
5    LearnerSummary,
6    renderer::{EvaluationName, MetricsRenderer},
7};
8
9/// Event happening during the training/validation process.
10pub enum LearnerEvent<T> {
11    /// Signal the start of the process (e.g., training start)
12    Start,
13    /// Signal that an item have been processed.
14    ProcessedItem(TrainingItem<T>),
15    /// Signal the end of an epoch.
16    EndEpoch(usize),
17    /// Signal the end of the process (e.g., training end).
18    End(Option<LearnerSummary>),
19}
20
21/// Event happening during the evaluation process.
22pub enum EvaluatorEvent<T> {
23    /// Signal the start of the process (e.g., evaluation start)
24    Start,
25    /// Signal that an item have been processed.
26    ProcessedItem(EvaluationName, EvaluationItem<T>),
27    /// Signal the end of the process (e.g., evaluation end).
28    End(Option<LearnerSummary>),
29}
30
/// An item that still needs a synchronization step before metrics can consume it.
///
/// Lazy items are not yet ready to be processed by metrics. The intent is to perform the
/// synchronization on a different thread so that training is not blocked, which is why both
/// the lazy item and its synchronized form must be [`Send`].
pub trait ItemLazy
where
    Self: Send,
{
    /// The synchronized form of the item, ready to be processed by metrics.
    type ItemSync: Send;

    /// Consumes the lazy item and produces its synchronized form.
    fn sync(self) -> Self::ItemSync;
}
41
42/// Process events happening during training and validation.
43pub trait EventProcessorTraining<TrainEvent, ValidEvent>: Send {
44    /// Collect a training event.
45    fn process_train(&mut self, event: TrainEvent);
46    /// Collect a validation event.
47    fn process_valid(&mut self, event: ValidEvent);
48    /// Returns the renderer used for training.
49    fn renderer(self) -> Box<dyn MetricsRenderer>;
50}
51
52/// Process events happening during evaluation.
53pub trait EventProcessorEvaluation: Send {
54    /// The test item.
55    type ItemTest: ItemLazy;
56
57    /// Collect a test event.
58    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);
59
60    /// Returns the renderer used for evaluation.
61    fn renderer(self) -> Box<dyn MetricsRenderer>;
62}
63
64/// A learner item.
65#[derive(new)]
66pub struct TrainingItem<T> {
67    /// The item.
68    pub item: T,
69
70    /// The progress.
71    pub progress: Progress,
72
73    /// The global progress of the training (e.g. epochs).
74    pub global_progress: Progress,
75
76    /// The iteration, if it it different from the items processed.
77    pub iteration: Option<usize>,
78
79    /// The learning rate.
80    pub lr: Option<LearningRate>,
81}
82
83impl<T: ItemLazy> ItemLazy for TrainingItem<T> {
84    type ItemSync = TrainingItem<T::ItemSync>;
85
86    fn sync(self) -> Self::ItemSync {
87        TrainingItem {
88            item: self.item.sync(),
89            progress: self.progress,
90            global_progress: self.global_progress,
91            iteration: self.iteration,
92            lr: self.lr,
93        }
94    }
95}
96
97/// An evaluation item.
98#[derive(new)]
99pub struct EvaluationItem<T> {
100    /// The item.
101    pub item: T,
102
103    /// The progress.
104    pub progress: Progress,
105
106    /// The iteration, if it it different from the items processed.
107    pub iteration: Option<usize>,
108}
109
110impl<T: ItemLazy> ItemLazy for EvaluationItem<T> {
111    type ItemSync = EvaluationItem<T::ItemSync>;
112
113    fn sync(self) -> Self::ItemSync {
114        EvaluationItem {
115            item: self.item.sync(),
116            progress: self.progress,
117            iteration: self.iteration,
118        }
119    }
120}
121
122impl ItemLazy for () {
123    type ItemSync = ();
124
125    fn sync(self) -> Self::ItemSync {}
126}