burn_train/metric/processor/metrics.rs

1use std::collections::HashMap;
2
3use super::{ItemLazy, LearnerItem};
4use crate::{
5    metric::{
6        Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric,
7        store::{MetricsUpdate, NumericMetricUpdate},
8    },
9    renderer::{EvaluationProgress, TrainingProgress},
10};
11
12pub(crate) struct MetricsTraining<T: ItemLazy, V: ItemLazy> {
13    train: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
14    valid: Vec<Box<dyn MetricUpdater<V::ItemSync>>>,
15    train_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
16    valid_numeric: Vec<Box<dyn NumericMetricUpdater<V::ItemSync>>>,
17    metric_definitions: HashMap<MetricId, MetricDefinition>,
18}
19
20pub(crate) struct MetricsEvaluation<T: ItemLazy> {
21    test: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
22    test_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
23    metric_definitions: HashMap<MetricId, MetricDefinition>,
24}
25
26impl<T: ItemLazy> Default for MetricsEvaluation<T> {
27    fn default() -> Self {
28        Self {
29            test: Default::default(),
30            test_numeric: Default::default(),
31            metric_definitions: HashMap::default(),
32        }
33    }
34}
35
36impl<T: ItemLazy, V: ItemLazy> Default for MetricsTraining<T, V> {
37    fn default() -> Self {
38        Self {
39            train: Vec::default(),
40            valid: Vec::default(),
41            train_numeric: Vec::default(),
42            valid_numeric: Vec::default(),
43            metric_definitions: HashMap::default(),
44        }
45    }
46}
47
48impl<T: ItemLazy> MetricsEvaluation<T> {
49    /// Register a testing metric.
50    pub(crate) fn register_test_metric<Me: Metric + 'static>(&mut self, metric: Me)
51    where
52        T::ItemSync: Adaptor<Me::Input> + 'static,
53    {
54        let metric = MetricWrapper::new(metric);
55        self.register_definition(&metric);
56        self.test.push(Box::new(metric))
57    }
58
59    /// Register a numeric testing metric.
60    pub(crate) fn register_test_metric_numeric<Me: Metric + Numeric + 'static>(
61        &mut self,
62        metric: Me,
63    ) where
64        T::ItemSync: Adaptor<Me::Input> + 'static,
65    {
66        let metric = MetricWrapper::new(metric);
67        self.register_definition(&metric);
68        self.test_numeric.push(Box::new(metric))
69    }
70
71    fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
72        self.metric_definitions.insert(
73            metric.id.clone(),
74            MetricDefinition::new(metric.id.clone(), &metric.metric),
75        );
76    }
77
78    /// Get metric definitions.
79    pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
80        self.metric_definitions.values().cloned().collect()
81    }
82
83    /// Update the testing information from the testing item.
84    pub(crate) fn update_test(
85        &mut self,
86        item: &LearnerItem<T::ItemSync>,
87        metadata: &MetricMetadata,
88    ) -> MetricsUpdate {
89        let mut entries = Vec::with_capacity(self.test.len());
90        let mut entries_numeric = Vec::with_capacity(self.test_numeric.len());
91
92        for metric in self.test.iter_mut() {
93            let state = metric.update(item, metadata);
94            entries.push(state);
95        }
96
97        for metric in self.test_numeric.iter_mut() {
98            let numeric_update = metric.update(item, metadata);
99            entries_numeric.push(numeric_update);
100        }
101
102        MetricsUpdate::new(entries, entries_numeric)
103    }
104}
105
106impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
107    /// Register a training metric.
108    pub(crate) fn register_train_metric<Me: Metric + 'static>(&mut self, metric: Me)
109    where
110        T::ItemSync: Adaptor<Me::Input> + 'static,
111    {
112        let metric = MetricWrapper::new(metric);
113        self.register_definition(&metric);
114        self.train.push(Box::new(metric))
115    }
116
117    /// Register a validation metric.
118    pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
119    where
120        V::ItemSync: Adaptor<Me::Input> + 'static,
121    {
122        let metric = MetricWrapper::new(metric);
123        self.register_definition(&metric);
124        self.valid.push(Box::new(metric))
125    }
126
127    /// Register a numeric training metric.
128    pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
129        &mut self,
130        metric: Me,
131    ) where
132        T::ItemSync: Adaptor<Me::Input> + 'static,
133    {
134        let metric = MetricWrapper::new(metric);
135        self.register_definition(&metric);
136        self.train_numeric.push(Box::new(metric))
137    }
138
139    /// Register a numeric validation metric.
140    pub(crate) fn register_valid_metric_numeric<Me>(&mut self, metric: Me)
141    where
142        V::ItemSync: Adaptor<Me::Input> + 'static,
143        Me: Metric + Numeric + 'static,
144    {
145        let metric = MetricWrapper::new(metric);
146        self.register_definition(&metric);
147        self.valid_numeric.push(Box::new(metric))
148    }
149
150    fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
151        self.metric_definitions.insert(
152            metric.id.clone(),
153            MetricDefinition::new(metric.id.clone(), &metric.metric),
154        );
155    }
156
157    /// Get metric definitions for all splits
158    pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
159        self.metric_definitions.values().cloned().collect()
160    }
161
162    /// Update the training information from the training item.
163    pub(crate) fn update_train(
164        &mut self,
165        item: &LearnerItem<T::ItemSync>,
166        metadata: &MetricMetadata,
167    ) -> MetricsUpdate {
168        let mut entries = Vec::with_capacity(self.train.len());
169        let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());
170
171        for metric in self.train.iter_mut() {
172            let state = metric.update(item, metadata);
173            entries.push(state);
174        }
175
176        for metric in self.train_numeric.iter_mut() {
177            let numeric_update = metric.update(item, metadata);
178            entries_numeric.push(numeric_update);
179        }
180
181        MetricsUpdate::new(entries, entries_numeric)
182    }
183
184    /// Update the training information from the validation item.
185    pub(crate) fn update_valid(
186        &mut self,
187        item: &LearnerItem<V::ItemSync>,
188        metadata: &MetricMetadata,
189    ) -> MetricsUpdate {
190        let mut entries = Vec::with_capacity(self.valid.len());
191        let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());
192
193        for metric in self.valid.iter_mut() {
194            let state = metric.update(item, metadata);
195            entries.push(state);
196        }
197
198        for metric in self.valid_numeric.iter_mut() {
199            let numeric_update = metric.update(item, metadata);
200            entries_numeric.push(numeric_update);
201        }
202
203        MetricsUpdate::new(entries, entries_numeric)
204    }
205
206    /// Signal the end of a training epoch.
207    pub(crate) fn end_epoch_train(&mut self) {
208        for metric in self.train.iter_mut() {
209            metric.clear();
210        }
211        for metric in self.train_numeric.iter_mut() {
212            metric.clear();
213        }
214    }
215
216    /// Signal the end of a validation epoch.
217    pub(crate) fn end_epoch_valid(&mut self) {
218        for metric in self.valid.iter_mut() {
219            metric.clear();
220        }
221        for metric in self.valid_numeric.iter_mut() {
222            metric.clear();
223        }
224    }
225}
226
227impl<T> From<&LearnerItem<T>> for TrainingProgress {
228    fn from(item: &LearnerItem<T>) -> Self {
229        Self {
230            progress: item.progress.clone(),
231            epoch: item.epoch,
232            epoch_total: item.epoch_total,
233            iteration: item.iteration,
234        }
235    }
236}
237
238impl<T> From<&LearnerItem<T>> for EvaluationProgress {
239    fn from(item: &LearnerItem<T>) -> Self {
240        Self {
241            progress: item.progress.clone(),
242            iteration: item.iteration,
243        }
244    }
245}
246
247impl<T> From<&LearnerItem<T>> for MetricMetadata {
248    fn from(item: &LearnerItem<T>) -> Self {
249        Self {
250            progress: item.progress.clone(),
251            epoch: item.epoch,
252            epoch_total: item.epoch_total,
253            iteration: item.iteration,
254            lr: item.lr,
255        }
256    }
257}
258
259trait NumericMetricUpdater<T>: Send + Sync {
260    fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> NumericMetricUpdate;
261    fn clear(&mut self);
262}
263
264trait MetricUpdater<T>: Send + Sync {
265    fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
266    fn clear(&mut self);
267}
268
269struct MetricWrapper<M> {
270    id: MetricId,
271    metric: M,
272}
273
274impl<M: Metric> MetricWrapper<M> {
275    pub fn new(metric: M) -> Self {
276        Self {
277            id: MetricId::new(metric.name()),
278            metric,
279        }
280    }
281}
282
283impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
284where
285    T: 'static,
286    M: Metric + Numeric + 'static,
287    T: Adaptor<M::Input>,
288{
289    fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> NumericMetricUpdate {
290        let serialized_entry = self.metric.update(&item.item.adapt(), metadata);
291        let update = MetricEntry::new(self.id.clone(), serialized_entry);
292        let numeric = self.metric.value();
293        let running = self.metric.running_value();
294
295        NumericMetricUpdate {
296            entry: update,
297            numeric_entry: numeric,
298            running_entry: running,
299        }
300    }
301
302    fn clear(&mut self) {
303        self.metric.clear()
304    }
305}
306
307impl<T, M> MetricUpdater<T> for MetricWrapper<M>
308where
309    T: 'static,
310    M: Metric + 'static,
311    T: Adaptor<M::Input>,
312{
313    fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry {
314        let serialized_entry = self.metric.update(&item.item.adapt(), metadata);
315        MetricEntry::new(self.id.clone(), serialized_entry)
316    }
317
318    fn clear(&mut self) {
319        self.metric.clear()
320    }
321}