// burn_train/metric/processor/base.rs

use burn_core::data::dataloader::Progress;
use burn_optim::LearningRate;

use crate::{
    LearnerSummary,
    renderer::{EvaluationName, MetricsRenderer},
};

9/// Event happening during the training/validation process.
10pub enum LearnerEvent<T> {
11    /// Signal the start of the process (e.g., training start)
12    Start,
13    /// Signal that an item have been processed.
14    ProcessedItem(LearnerItem<T>),
15    /// Signal the end of an epoch.
16    EndEpoch(usize),
17    /// Signal the end of the process (e.g., training end).
18    End(Option<LearnerSummary>),
19}
20
21/// Event happening during the evaluation process.
22pub enum EvaluatorEvent<T> {
23    /// Signal the start of the process (e.g., training start)
24    Start,
25    /// Signal that an item have been processed.
26    ProcessedItem(EvaluationName, LearnerItem<T>),
27    /// Signal the end of the process (e.g., training end).
28    End,
29}
30
/// A value that still has to be synchronized before metrics can consume it.
///
/// Synchronization is deferred so it can run on a separate thread rather than
/// blocking the training loop.
pub trait ItemLazy: Send {
    /// The synchronized form of this item, ready for metric processing.
    type ItemSync: Send;

    /// Consumes the lazy item and produces its synchronized form.
    fn sync(self) -> Self::ItemSync;
}

42/// Process events happening during training and validation.
43pub trait EventProcessorTraining: Send {
44    /// The training item.
45    type ItemTrain: ItemLazy;
46    /// The validation item.
47    type ItemValid: ItemLazy;
48
49    /// Collect a training event.
50    fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>);
51    /// Collect a validation event.
52    fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>);
53    /// Returns the renderer used for training.
54    fn renderer(self) -> Box<dyn MetricsRenderer>;
55}
56
57/// Process events happening during evaluation.
58pub trait EventProcessorEvaluation: Send {
59    /// The test item.
60    type ItemTest: ItemLazy;
61
62    /// Collect a test event.
63    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>);
64
65    /// Returns the renderer used for evaluation.
66    fn renderer(self) -> Box<dyn MetricsRenderer>;
67}
68
69/// A learner item.
70#[derive(new)]
71pub struct LearnerItem<T> {
72    /// The item.
73    pub item: T,
74
75    /// The progress.
76    pub progress: Progress,
77
78    /// The epoch.
79    pub epoch: usize,
80
81    /// The total number of epochs.
82    pub epoch_total: usize,
83
84    /// The iteration.
85    pub iteration: usize,
86
87    /// The learning rate.
88    pub lr: Option<LearningRate>,
89}
90
91impl<T: ItemLazy> ItemLazy for LearnerItem<T> {
92    type ItemSync = LearnerItem<T::ItemSync>;
93
94    fn sync(self) -> Self::ItemSync {
95        LearnerItem {
96            item: self.item.sync(),
97            progress: self.progress,
98            epoch: self.epoch,
99            epoch_total: self.epoch_total,
100            iteration: self.iteration,
101            lr: self.lr,
102        }
103    }
104}
105
106impl ItemLazy for () {
107    type ItemSync = ();
108
109    fn sync(self) -> Self::ItemSync {}
110}