burn_train/metric/processor/
metrics.rs1use std::collections::HashMap;
2
3use super::{ItemLazy, TrainingItem};
4use crate::{
5 EvaluationItem,
6 metric::{
7 Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric,
8 store::{MetricsUpdate, NumericMetricUpdate},
9 },
10 renderer::{EvaluationProgress, TrainingProgress},
11};
12
13pub(crate) struct MetricsTraining<T: ItemLazy, V: ItemLazy> {
14 train: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
15 valid: Vec<Box<dyn MetricUpdater<V::ItemSync>>>,
16 train_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
17 valid_numeric: Vec<Box<dyn NumericMetricUpdater<V::ItemSync>>>,
18 metric_definitions: HashMap<MetricId, MetricDefinition>,
19}
20
21pub(crate) struct MetricsEvaluation<T: ItemLazy> {
22 test: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
23 test_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
24 metric_definitions: HashMap<MetricId, MetricDefinition>,
25}
26
27impl<T: ItemLazy> Default for MetricsEvaluation<T> {
28 fn default() -> Self {
29 Self {
30 test: Default::default(),
31 test_numeric: Default::default(),
32 metric_definitions: HashMap::default(),
33 }
34 }
35}
36
37impl<T: ItemLazy, V: ItemLazy> Default for MetricsTraining<T, V> {
38 fn default() -> Self {
39 Self {
40 train: Vec::default(),
41 valid: Vec::default(),
42 train_numeric: Vec::default(),
43 valid_numeric: Vec::default(),
44 metric_definitions: HashMap::default(),
45 }
46 }
47}
48
49impl<T: ItemLazy> MetricsEvaluation<T> {
50 pub(crate) fn register_test_metric<Me: Metric + 'static>(&mut self, metric: Me)
52 where
53 T::ItemSync: Adaptor<Me::Input> + 'static,
54 {
55 let metric = MetricWrapper::new(metric);
56 self.register_definition(&metric);
57 self.test.push(Box::new(metric))
58 }
59
60 pub(crate) fn register_test_metric_numeric<Me: Metric + Numeric + 'static>(
62 &mut self,
63 metric: Me,
64 ) where
65 T::ItemSync: Adaptor<Me::Input> + 'static,
66 {
67 let metric = MetricWrapper::new(metric);
68 self.register_definition(&metric);
69 self.test_numeric.push(Box::new(metric))
70 }
71
72 fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
73 self.metric_definitions.insert(
74 metric.id.clone(),
75 MetricDefinition::new(metric.id.clone(), &metric.metric),
76 );
77 }
78
79 pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
81 self.metric_definitions.values().cloned().collect()
82 }
83
84 pub(crate) fn update_test(
86 &mut self,
87 item: &EvaluationItem<T::ItemSync>,
88 metadata: &MetricMetadata,
89 ) -> MetricsUpdate {
90 let mut entries = Vec::with_capacity(self.test.len());
91 let mut entries_numeric = Vec::with_capacity(self.test_numeric.len());
92
93 for metric in self.test.iter_mut() {
94 let state = metric.update(&item.item, metadata);
95 entries.push(state);
96 }
97
98 for metric in self.test_numeric.iter_mut() {
99 let numeric_update = metric.update(&item.item, metadata);
100 entries_numeric.push(numeric_update);
101 }
102
103 MetricsUpdate::new(entries, entries_numeric)
104 }
105}
106
107impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
108 pub(crate) fn register_train_metric<Me: Metric + 'static>(&mut self, metric: Me)
110 where
111 T::ItemSync: Adaptor<Me::Input> + 'static,
112 {
113 let metric = MetricWrapper::new(metric);
114 self.register_definition(&metric);
115 self.train.push(Box::new(metric))
116 }
117
118 pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
120 where
121 V::ItemSync: Adaptor<Me::Input> + 'static,
122 {
123 let metric = MetricWrapper::new(metric);
124 self.register_definition(&metric);
125 self.valid.push(Box::new(metric))
126 }
127
128 pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
130 &mut self,
131 metric: Me,
132 ) where
133 T::ItemSync: Adaptor<Me::Input> + 'static,
134 {
135 let metric = MetricWrapper::new(metric);
136 self.register_definition(&metric);
137 self.train_numeric.push(Box::new(metric))
138 }
139
140 pub(crate) fn register_valid_metric_numeric<Me>(&mut self, metric: Me)
142 where
143 V::ItemSync: Adaptor<Me::Input> + 'static,
144 Me: Metric + Numeric + 'static,
145 {
146 let metric = MetricWrapper::new(metric);
147 self.register_definition(&metric);
148 self.valid_numeric.push(Box::new(metric))
149 }
150
151 fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
152 self.metric_definitions.insert(
153 metric.id.clone(),
154 MetricDefinition::new(metric.id.clone(), &metric.metric),
155 );
156 }
157
158 pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
160 self.metric_definitions.values().cloned().collect()
161 }
162
163 pub(crate) fn update_train(
165 &mut self,
166 item: &TrainingItem<T::ItemSync>,
167 metadata: &MetricMetadata,
168 ) -> MetricsUpdate {
169 let mut entries = Vec::with_capacity(self.train.len());
170 let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());
171
172 for metric in self.train.iter_mut() {
173 let state = metric.update(&item.item, metadata);
174 entries.push(state);
175 }
176
177 for metric in self.train_numeric.iter_mut() {
178 let numeric_update = metric.update(&item.item, metadata);
179 entries_numeric.push(numeric_update);
180 }
181
182 MetricsUpdate::new(entries, entries_numeric)
183 }
184
185 pub(crate) fn update_valid(
187 &mut self,
188 item: &TrainingItem<V::ItemSync>,
189 metadata: &MetricMetadata,
190 ) -> MetricsUpdate {
191 let mut entries = Vec::with_capacity(self.valid.len());
192 let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());
193
194 for metric in self.valid.iter_mut() {
195 let state = metric.update(&item.item, metadata);
196 entries.push(state);
197 }
198
199 for metric in self.valid_numeric.iter_mut() {
200 let numeric_update = metric.update(&item.item, metadata);
201 entries_numeric.push(numeric_update);
202 }
203
204 MetricsUpdate::new(entries, entries_numeric)
205 }
206
207 pub(crate) fn end_epoch_train(&mut self) {
209 for metric in self.train.iter_mut() {
210 metric.clear();
211 }
212 for metric in self.train_numeric.iter_mut() {
213 metric.clear();
214 }
215 }
216
217 pub(crate) fn end_epoch_valid(&mut self) {
219 for metric in self.valid.iter_mut() {
220 metric.clear();
221 }
222 for metric in self.valid_numeric.iter_mut() {
223 metric.clear();
224 }
225 }
226}
227
228impl<T> From<&TrainingItem<T>> for TrainingProgress {
229 fn from(item: &TrainingItem<T>) -> Self {
230 Self {
231 progress: Some(item.progress.clone()),
232 global_progress: item.global_progress.clone(),
233 iteration: item.iteration,
234 }
235 }
236}
237
238impl<T> From<&EvaluationItem<T>> for TrainingProgress {
239 fn from(item: &EvaluationItem<T>) -> Self {
240 Self {
241 progress: None,
242 global_progress: item.progress.clone(),
243 iteration: item.iteration,
244 }
245 }
246}
247
248impl<T> From<&EvaluationItem<T>> for EvaluationProgress {
249 fn from(item: &EvaluationItem<T>) -> Self {
250 Self {
251 progress: item.progress.clone(),
252 iteration: item.iteration,
253 }
254 }
255}
256
257impl<T> From<&TrainingItem<T>> for MetricMetadata {
258 fn from(item: &TrainingItem<T>) -> Self {
259 Self {
260 progress: item.progress.clone(),
261 global_progress: item.global_progress.clone(),
262 iteration: item.iteration,
263 lr: item.lr,
264 }
265 }
266}
267
268impl<T> From<&EvaluationItem<T>> for MetricMetadata {
269 fn from(item: &EvaluationItem<T>) -> Self {
270 Self {
271 progress: item.progress.clone(),
272 global_progress: item.progress.clone(),
273 iteration: item.iteration,
274 lr: None,
275 }
276 }
277}
278
279pub(crate) trait NumericMetricUpdater<T>: Send + Sync {
280 fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate;
281 fn clear(&mut self);
282}
283
284pub(crate) trait MetricUpdater<T>: Send + Sync {
285 fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry;
286 fn clear(&mut self);
287}
288
289pub(crate) struct MetricWrapper<M> {
290 pub id: MetricId,
291 pub metric: M,
292}
293
294impl<M: Metric> MetricWrapper<M> {
295 pub fn new(metric: M) -> Self {
296 Self {
297 id: MetricId::new(metric.name()),
298 metric,
299 }
300 }
301}
302
303impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
304where
305 T: 'static,
306 M: Metric + Numeric + 'static,
307 T: Adaptor<M::Input>,
308{
309 fn update(&mut self, item: &T, metadata: &MetricMetadata) -> NumericMetricUpdate {
310 let serialized_entry = self.metric.update(&item.adapt(), metadata);
311 let update = MetricEntry::new(self.id.clone(), serialized_entry);
312 let numeric = self.metric.value();
313 let running = self.metric.running_value();
314
315 NumericMetricUpdate {
316 entry: update,
317 numeric_entry: numeric,
318 running_entry: running,
319 }
320 }
321
322 fn clear(&mut self) {
323 self.metric.clear()
324 }
325}
326
327impl<T, M> MetricUpdater<T> for MetricWrapper<M>
328where
329 T: 'static,
330 M: Metric + 'static,
331 T: Adaptor<M::Input>,
332{
333 fn update(&mut self, item: &T, metadata: &MetricMetadata) -> MetricEntry {
334 let serialized_entry = self.metric.update(&item.adapt(), metadata);
335 MetricEntry::new(self.id.clone(), serialized_entry)
336 }
337
338 fn clear(&mut self) {
339 self.metric.clear()
340 }
341}