burn_train/metric/processor/
base.rs

1use burn_core::data::dataloader::Progress;
2use burn_optim::LearningRate;
3
4use crate::{
5    LearnerSummary,
6    renderer::{EvaluationName, MetricsRenderer},
7};
8
9/// Event happening during the training/validation process.
10pub enum LearnerEvent<T> {
11    /// Signal the start of the process (e.g., training start)
12    Start,
13    /// Signal that an item have been processed.
14    ProcessedItem(LearnerItem<T>),
15    /// Signal the end of an epoch.
16    EndEpoch(usize),
17    /// Signal the end of the process (e.g., training end).
18    End(Option<LearnerSummary>),
19}
20
21/// Event happening during the evaluation process.
22pub enum EvaluatorEvent<T> {
23    /// Signal that an item have been processed.
24    ProcessedItem(EvaluationName, LearnerItem<T>),
25    /// Signal the end of the process (e.g., training end).
26    End,
27}
28
/// An item that must be synchronized before metrics can consume it.
///
/// Syncing is a separate step so that it can run on another thread
/// instead of blocking training.
pub trait ItemLazy: Send {
    /// The synced form of the item, ready to be processed by metrics.
    type ItemSync: Send;

    /// Consumes the lazy item and produces its synced form.
    fn sync(self) -> <Self as ItemLazy>::ItemSync;
}
39
40/// Process events happening during training and validation.
41pub trait EventProcessorTraining: Send {
42    /// The training item.
43    type ItemTrain: ItemLazy;
44    /// The validation item.
45    type ItemValid: ItemLazy;
46
47    /// Collect a training event.
48    fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>);
49    /// Collect a validation event.
50    fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>);
51    /// Returns the renderer used for training.
52    fn renderer(self) -> Box<dyn MetricsRenderer>;
53}
54
55/// Process events happening during evaluation.
56pub trait EventProcessorEvaluation: Send {
57    /// The test item.
58    type ItemTest: ItemLazy;
59
60    /// Collect a test event.
61    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);
62
63    /// Returns the renderer used for evaluation.
64    fn renderer(self) -> Box<dyn MetricsRenderer>;
65}
66
67/// A learner item.
68#[derive(new)]
69pub struct LearnerItem<T> {
70    /// The item.
71    pub item: T,
72
73    /// The progress.
74    pub progress: Progress,
75
76    /// The epoch.
77    pub epoch: usize,
78
79    /// The total number of epochs.
80    pub epoch_total: usize,
81
82    /// The iteration.
83    pub iteration: usize,
84
85    /// The learning rate.
86    pub lr: Option<LearningRate>,
87}
88
89impl<T: ItemLazy> ItemLazy for LearnerItem<T> {
90    type ItemSync = LearnerItem<T::ItemSync>;
91
92    fn sync(self) -> Self::ItemSync {
93        LearnerItem {
94            item: self.item.sync(),
95            progress: self.progress,
96            epoch: self.epoch,
97            epoch_total: self.epoch_total,
98            iteration: self.iteration,
99            lr: self.lr,
100        }
101    }
102}
103
104impl ItemLazy for () {
105    type ItemSync = ();
106
107    fn sync(self) -> Self::ItemSync {}
108}