Skip to main content

burn_train/metric/processor/
metrics.rs

1use std::collections::HashMap;
2
3use super::{ItemLazy, TrainingItem};
4use crate::{
5    EvaluationItem,
6    metric::{
7        Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric,
8        store::{MetricsUpdate, NumericMetricUpdate},
9    },
10    renderer::{EvaluationProgress, TrainingProgress},
11};
12
/// Metrics tracked during training, split between the training and validation loops.
///
/// Numeric metrics live in their own collections because their updates carry a
/// `NumericMetricUpdate` (entry plus current and running values) instead of a plain
/// `MetricEntry`.
pub(crate) struct MetricsTraining<T: ItemLazy, V: ItemLazy> {
    /// Non-numeric metrics updated with training items.
    train: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
    /// Non-numeric metrics updated with validation items.
    valid: Vec<Box<dyn MetricUpdater<V::ItemSync>>>,
    /// Numeric metrics updated with training items.
    train_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
    /// Numeric metrics updated with validation items.
    valid_numeric: Vec<Box<dyn NumericMetricUpdater<V::ItemSync>>>,
    /// Definitions of every registered metric (both splits), keyed by metric id.
    metric_definitions: HashMap<MetricId, MetricDefinition>,
}
20
/// Metrics tracked while evaluating a model on the test split.
pub(crate) struct MetricsEvaluation<T: ItemLazy> {
    /// Non-numeric metrics updated with each evaluation item.
    test: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
    /// Numeric metrics updated with each evaluation item; their updates also carry the
    /// metric's current and running values.
    test_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
    /// Definitions of every registered metric, keyed by metric id.
    metric_definitions: HashMap<MetricId, MetricDefinition>,
}
26
27impl<T: ItemLazy> Default for MetricsEvaluation<T> {
28    fn default() -> Self {
29        Self {
30            test: Default::default(),
31            test_numeric: Default::default(),
32            metric_definitions: HashMap::default(),
33        }
34    }
35}
36
37impl<T: ItemLazy, V: ItemLazy> Default for MetricsTraining<T, V> {
38    fn default() -> Self {
39        Self {
40            train: Vec::default(),
41            valid: Vec::default(),
42            train_numeric: Vec::default(),
43            valid_numeric: Vec::default(),
44            metric_definitions: HashMap::default(),
45        }
46    }
47}
48
49impl<T: ItemLazy> MetricsEvaluation<T> {
50    /// Register a testing metric.
51    pub(crate) fn register_test_metric<Me: Metric + 'static>(&mut self, metric: Me)
52    where
53        T::ItemSync: Adaptor<Me::Input> + 'static,
54    {
55        let metric = MetricWrapper::new(metric);
56        self.register_definition(&metric);
57        self.test.push(Box::new(metric))
58    }
59
60    /// Register a numeric testing metric.
61    pub(crate) fn register_test_metric_numeric<Me: Metric + Numeric + 'static>(
62        &mut self,
63        metric: Me,
64    ) where
65        T::ItemSync: Adaptor<Me::Input> + 'static,
66    {
67        let metric = MetricWrapper::new(metric);
68        self.register_definition(&metric);
69        self.test_numeric.push(Box::new(metric))
70    }
71
72    fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
73        self.metric_definitions.insert(
74            metric.id.clone(),
75            MetricDefinition::new(metric.id.clone(), &metric.metric),
76        );
77    }
78
79    /// Get metric definitions.
80    pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
81        self.metric_definitions.values().cloned().collect()
82    }
83
84    /// Update the testing information from the testing item.
85    pub(crate) fn update_test(
86        &mut self,
87        item: &EvaluationItem<T::ItemSync>,
88        metadata: &MetricMetadata,
89    ) -> MetricsUpdate {
90        let mut entries = Vec::with_capacity(self.test.len());
91        let mut entries_numeric = Vec::with_capacity(self.test_numeric.len());
92
93        for metric in self.test.iter_mut() {
94            let state = metric.update(&item.item, metadata);
95            entries.push(state);
96        }
97
98        for metric in self.test_numeric.iter_mut() {
99            let numeric_update = metric.update(&item.item, metadata);
100            entries_numeric.push(numeric_update);
101        }
102
103        MetricsUpdate::new(entries, entries_numeric)
104    }
105}
106
107impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
108    /// Register a training metric.
109    pub(crate) fn register_train_metric<Me: Metric + 'static>(&mut self, metric: Me)
110    where
111        T::ItemSync: Adaptor<Me::Input> + 'static,
112    {
113        let metric = MetricWrapper::new(metric);
114        self.register_definition(&metric);
115        self.train.push(Box::new(metric))
116    }
117
118    /// Register a validation metric.
119    pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
120    where
121        V::ItemSync: Adaptor<Me::Input> + 'static,
122    {
123        let metric = MetricWrapper::new(metric);
124        self.register_definition(&metric);
125        self.valid.push(Box::new(metric))
126    }
127
128    /// Register a numeric training metric.
129    pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
130        &mut self,
131        metric: Me,
132    ) where
133        T::ItemSync: Adaptor<Me::Input> + 'static,
134    {
135        let metric = MetricWrapper::new(metric);
136        self.register_definition(&metric);
137        self.train_numeric.push(Box::new(metric))
138    }
139
140    /// Register a numeric validation metric.
141    pub(crate) fn register_valid_metric_numeric<Me>(&mut self, metric: Me)
142    where
143        V::ItemSync: Adaptor<Me::Input> + 'static,
144        Me: Metric + Numeric + 'static,
145    {
146        let metric = MetricWrapper::new(metric);
147        self.register_definition(&metric);
148        self.valid_numeric.push(Box::new(metric))
149    }
150
151    fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
152        self.metric_definitions.insert(
153            metric.id.clone(),
154            MetricDefinition::new(metric.id.clone(), &metric.metric),
155        );
156    }
157
158    /// Get metric definitions for all splits
159    pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
160        self.metric_definitions.values().cloned().collect()
161    }
162
163    /// Update the training information from the training item.
164    pub(crate) fn update_train(
165        &mut self,
166        item: &TrainingItem<T::ItemSync>,
167        metadata: &MetricMetadata,
168    ) -> MetricsUpdate {
169        let mut entries = Vec::with_capacity(self.train.len());
170        let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());
171
172        for metric in self.train.iter_mut() {
173            let state = metric.update(&item.item, metadata);
174            entries.push(state);
175        }
176
177        for metric in self.train_numeric.iter_mut() {
178            let numeric_update = metric.update(&item.item, metadata);
179            entries_numeric.push(numeric_update);
180        }
181
182        MetricsUpdate::new(entries, entries_numeric)
183    }
184
185    /// Update the training information from the validation item.
186    pub(crate) fn update_valid(
187        &mut self,
188        item: &TrainingItem<V::ItemSync>,
189        metadata: &MetricMetadata,
190    ) -> MetricsUpdate {
191        let mut entries = Vec::with_capacity(self.valid.len());
192        let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());
193
194        for metric in self.valid.iter_mut() {
195            let state = metric.update(&item.item, metadata);
196            entries.push(state);
197        }
198
199        for metric in self.valid_numeric.iter_mut() {
200            let numeric_update = metric.update(&item.item, metadata);
201            entries_numeric.push(numeric_update);
202        }
203
204        MetricsUpdate::new(entries, entries_numeric)
205    }
206
207    /// Signal the end of a training epoch.
208    pub(crate) fn end_epoch_train(&mut self) {
209        for metric in self.train.iter_mut() {
210            metric.clear();
211        }
212        for metric in self.train_numeric.iter_mut() {
213            metric.clear();
214        }
215    }
216
217    /// Signal the end of a validation epoch.
218    pub(crate) fn end_epoch_valid(&mut self) {
219        for metric in self.valid.iter_mut() {
220            metric.clear();
221        }
222        for metric in self.valid_numeric.iter_mut() {
223            metric.clear();
224        }
225    }
226}
227
228impl<T> From<&TrainingItem<T>> for TrainingProgress {
229    fn from(item: &TrainingItem<T>) -> Self {
230        Self {
231            progress: Some(item.progress.clone()),
232            global_progress: item.global_progress.clone(),
233            iteration: item.iteration,
234        }
235    }
236}
237
238impl<T> From<&EvaluationItem<T>> for TrainingProgress {
239    fn from(item: &EvaluationItem<T>) -> Self {
240        Self {
241            progress: None,
242            global_progress: item.progress.clone(),
243            iteration: item.iteration,
244        }
245    }
246}
247
248impl<T> From<&EvaluationItem<T>> for EvaluationProgress {
249    fn from(item: &EvaluationItem<T>) -> Self {
250        Self {
251            progress: item.progress.clone(),
252            iteration: item.iteration,
253        }
254    }
255}
256
257impl<T> From<&TrainingItem<T>> for MetricMetadata {
258    fn from(item: &TrainingItem<T>) -> Self {
259        Self {
260            progress: item.progress.clone(),
261            global_progress: item.global_progress.clone(),
262            iteration: item.iteration,
263            lr: item.lr,
264        }
265    }
266}
267
268impl<T> From<&EvaluationItem<T>> for MetricMetadata {
269    fn from(item: &EvaluationItem<T>) -> Self {
270        Self {
271            progress: item.progress.clone(),
272            global_progress: item.progress.clone(),
273            iteration: item.iteration,
274            lr: None,
275        }
276    }
277}
278
/// Type-erased interface for updating a numeric metric with items of type `T`.
///
/// Implementors must be `Send + Sync`.
pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
    /// Update the metric with `item` and return the resulting `NumericMetricUpdate`.
    fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate;
    /// Clear the metric's state.
    fn clear(&mut self);
}
283
/// Type-erased interface for updating a (non-numeric) metric with items of type `T`.
///
/// Implementors must be `Send + Sync`.
pub(crate) trait MetricUpdater<T>: Send + Sync {
    /// Update the metric with `item` and return the resulting `MetricEntry`.
    fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry;
    /// Clear the metric's state.
    fn clear(&mut self);
}
288
/// A metric paired with the id under which its entries are reported.
pub(crate) struct MetricWrapper<M> {
    /// Id attached to every `MetricEntry` produced by this wrapper.
    pub id: MetricId,
    /// The wrapped metric.
    pub metric: M,
}
293
294impl<M: Metric> MetricWrapper<M> {
295    pub fn new(metric: M) -> Self {
296        Self {
297            id: MetricId::new(metric.name()),
298            metric,
299        }
300    }
301}
302
303impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
304where
305    T: 'static,
306    M: Metric + Numeric + 'static,
307    T: Adaptor<M::Input>,
308{
309    fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate {
310        let serialized_entry = self.metric.update(&item.adapt(), metadata);
311        let update = MetricEntry::new(self.id.clone(), serialized_entry);
312        let numeric = self.metric.value();
313        let running = self.metric.running_value();
314
315        NumericMetricUpdate {
316            entry: update,
317            numeric_entry: numeric,
318            running_entry: running,
319        }
320    }
321
322    fn clear(&mut self) {
323        self.metric.clear()
324    }
325}
326
327impl<T, M> MetricUpdater<T> for MetricWrapper<M>
328where
329    T: 'static,
330    M: Metric + 'static,
331    T: Adaptor<M::Input>,
332{
333    fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry {
334        let serialized_entry = self.metric.update(&item.adapt(), metadata);
335        MetricEntry::new(self.id.clone(), serialized_entry)
336    }
337
338    fn clear(&mut self) {
339        self.metric.clear()
340    }
341}