burn_train/metric/processor/
metrics.rs1use std::collections::HashMap;
2
3use super::{ItemLazy, LearnerItem};
4use crate::{
5 metric::{
6 Adaptor, Metric, MetricDefinition, MetricEntry, MetricId, MetricMetadata, Numeric,
7 store::{MetricsUpdate, NumericMetricUpdate},
8 },
9 renderer::{EvaluationProgress, TrainingProgress},
10};
11
12pub(crate) struct MetricsTraining<T: ItemLazy, V: ItemLazy> {
13 train: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
14 valid: Vec<Box<dyn MetricUpdater<V::ItemSync>>>,
15 train_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
16 valid_numeric: Vec<Box<dyn NumericMetricUpdater<V::ItemSync>>>,
17 metric_definitions: HashMap<MetricId, MetricDefinition>,
18}
19
20pub(crate) struct MetricsEvaluation<T: ItemLazy> {
21 test: Vec<Box<dyn MetricUpdater<T::ItemSync>>>,
22 test_numeric: Vec<Box<dyn NumericMetricUpdater<T::ItemSync>>>,
23 metric_definitions: HashMap<MetricId, MetricDefinition>,
24}
25
26impl<T: ItemLazy> Default for MetricsEvaluation<T> {
27 fn default() -> Self {
28 Self {
29 test: Default::default(),
30 test_numeric: Default::default(),
31 metric_definitions: HashMap::default(),
32 }
33 }
34}
35
36impl<T: ItemLazy, V: ItemLazy> Default for MetricsTraining<T, V> {
37 fn default() -> Self {
38 Self {
39 train: Vec::default(),
40 valid: Vec::default(),
41 train_numeric: Vec::default(),
42 valid_numeric: Vec::default(),
43 metric_definitions: HashMap::default(),
44 }
45 }
46}
47
48impl<T: ItemLazy> MetricsEvaluation<T> {
49 pub(crate) fn register_test_metric<Me: Metric + 'static>(&mut self, metric: Me)
51 where
52 T::ItemSync: Adaptor<Me::Input> + 'static,
53 {
54 let metric = MetricWrapper::new(metric);
55 self.register_definition(&metric);
56 self.test.push(Box::new(metric))
57 }
58
59 pub(crate) fn register_test_metric_numeric<Me: Metric + Numeric + 'static>(
61 &mut self,
62 metric: Me,
63 ) where
64 T::ItemSync: Adaptor<Me::Input> + 'static,
65 {
66 let metric = MetricWrapper::new(metric);
67 self.register_definition(&metric);
68 self.test_numeric.push(Box::new(metric))
69 }
70
71 fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
72 self.metric_definitions.insert(
73 metric.id.clone(),
74 MetricDefinition::new(metric.id.clone(), &metric.metric),
75 );
76 }
77
78 pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
80 self.metric_definitions.values().cloned().collect()
81 }
82
83 pub(crate) fn update_test(
85 &mut self,
86 item: &LearnerItem<T::ItemSync>,
87 metadata: &MetricMetadata,
88 ) -> MetricsUpdate {
89 let mut entries = Vec::with_capacity(self.test.len());
90 let mut entries_numeric = Vec::with_capacity(self.test_numeric.len());
91
92 for metric in self.test.iter_mut() {
93 let state = metric.update(item, metadata);
94 entries.push(state);
95 }
96
97 for metric in self.test_numeric.iter_mut() {
98 let numeric_update = metric.update(item, metadata);
99 entries_numeric.push(numeric_update);
100 }
101
102 MetricsUpdate::new(entries, entries_numeric)
103 }
104}
105
106impl<T: ItemLazy, V: ItemLazy> MetricsTraining<T, V> {
107 pub(crate) fn register_train_metric<Me: Metric + 'static>(&mut self, metric: Me)
109 where
110 T::ItemSync: Adaptor<Me::Input> + 'static,
111 {
112 let metric = MetricWrapper::new(metric);
113 self.register_definition(&metric);
114 self.train.push(Box::new(metric))
115 }
116
117 pub(crate) fn register_valid_metric<Me: Metric + 'static>(&mut self, metric: Me)
119 where
120 V::ItemSync: Adaptor<Me::Input> + 'static,
121 {
122 let metric = MetricWrapper::new(metric);
123 self.register_definition(&metric);
124 self.valid.push(Box::new(metric))
125 }
126
127 pub(crate) fn register_train_metric_numeric<Me: Metric + Numeric + 'static>(
129 &mut self,
130 metric: Me,
131 ) where
132 T::ItemSync: Adaptor<Me::Input> + 'static,
133 {
134 let metric = MetricWrapper::new(metric);
135 self.register_definition(&metric);
136 self.train_numeric.push(Box::new(metric))
137 }
138
139 pub(crate) fn register_valid_metric_numeric<Me>(&mut self, metric: Me)
141 where
142 V::ItemSync: Adaptor<Me::Input> + 'static,
143 Me: Metric + Numeric + 'static,
144 {
145 let metric = MetricWrapper::new(metric);
146 self.register_definition(&metric);
147 self.valid_numeric.push(Box::new(metric))
148 }
149
150 fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
151 self.metric_definitions.insert(
152 metric.id.clone(),
153 MetricDefinition::new(metric.id.clone(), &metric.metric),
154 );
155 }
156
157 pub(crate) fn metric_definitions(&mut self) -> Vec<MetricDefinition> {
159 self.metric_definitions.values().cloned().collect()
160 }
161
162 pub(crate) fn update_train(
164 &mut self,
165 item: &LearnerItem<T::ItemSync>,
166 metadata: &MetricMetadata,
167 ) -> MetricsUpdate {
168 let mut entries = Vec::with_capacity(self.train.len());
169 let mut entries_numeric = Vec::with_capacity(self.train_numeric.len());
170
171 for metric in self.train.iter_mut() {
172 let state = metric.update(item, metadata);
173 entries.push(state);
174 }
175
176 for metric in self.train_numeric.iter_mut() {
177 let numeric_update = metric.update(item, metadata);
178 entries_numeric.push(numeric_update);
179 }
180
181 MetricsUpdate::new(entries, entries_numeric)
182 }
183
184 pub(crate) fn update_valid(
186 &mut self,
187 item: &LearnerItem<V::ItemSync>,
188 metadata: &MetricMetadata,
189 ) -> MetricsUpdate {
190 let mut entries = Vec::with_capacity(self.valid.len());
191 let mut entries_numeric = Vec::with_capacity(self.valid_numeric.len());
192
193 for metric in self.valid.iter_mut() {
194 let state = metric.update(item, metadata);
195 entries.push(state);
196 }
197
198 for metric in self.valid_numeric.iter_mut() {
199 let numeric_update = metric.update(item, metadata);
200 entries_numeric.push(numeric_update);
201 }
202
203 MetricsUpdate::new(entries, entries_numeric)
204 }
205
206 pub(crate) fn end_epoch_train(&mut self) {
208 for metric in self.train.iter_mut() {
209 metric.clear();
210 }
211 for metric in self.train_numeric.iter_mut() {
212 metric.clear();
213 }
214 }
215
216 pub(crate) fn end_epoch_valid(&mut self) {
218 for metric in self.valid.iter_mut() {
219 metric.clear();
220 }
221 for metric in self.valid_numeric.iter_mut() {
222 metric.clear();
223 }
224 }
225}
226
227impl<T> From<&LearnerItem<T>> for TrainingProgress {
228 fn from(item: &LearnerItem<T>) -> Self {
229 Self {
230 progress: item.progress.clone(),
231 epoch: item.epoch,
232 epoch_total: item.epoch_total,
233 iteration: item.iteration,
234 }
235 }
236}
237
238impl<T> From<&LearnerItem<T>> for EvaluationProgress {
239 fn from(item: &LearnerItem<T>) -> Self {
240 Self {
241 progress: item.progress.clone(),
242 iteration: item.iteration,
243 }
244 }
245}
246
247impl<T> From<&LearnerItem<T>> for MetricMetadata {
248 fn from(item: &LearnerItem<T>) -> Self {
249 Self {
250 progress: item.progress.clone(),
251 epoch: item.epoch,
252 epoch_total: item.epoch_total,
253 iteration: item.iteration,
254 lr: item.lr,
255 }
256 }
257}
258
259trait NumericMetricUpdater<T>: Send + Sync {
260 fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> NumericMetricUpdate;
261 fn clear(&mut self);
262}
263
264trait MetricUpdater<T>: Send + Sync {
265 fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry;
266 fn clear(&mut self);
267}
268
269struct MetricWrapper<M> {
270 id: MetricId,
271 metric: M,
272}
273
274impl<M: Metric> MetricWrapper<M> {
275 pub fn new(metric: M) -> Self {
276 Self {
277 id: MetricId::new(metric.name()),
278 metric,
279 }
280 }
281}
282
283impl<T, M> NumericMetricUpdater<T> for MetricWrapper<M>
284where
285 T: 'static,
286 M: Metric + Numeric + 'static,
287 T: Adaptor<M::Input>,
288{
289 fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> NumericMetricUpdate {
290 let serialized_entry = self.metric.update(&item.item.adapt(), metadata);
291 let update = MetricEntry::new(self.id.clone(), serialized_entry);
292 let numeric = self.metric.value();
293 let running = self.metric.running_value();
294
295 NumericMetricUpdate {
296 entry: update,
297 numeric_entry: numeric,
298 running_entry: running,
299 }
300 }
301
302 fn clear(&mut self) {
303 self.metric.clear()
304 }
305}
306
307impl<T, M> MetricUpdater<T> for MetricWrapper<M>
308where
309 T: 'static,
310 M: Metric + 'static,
311 T: Adaptor<M::Input>,
312{
313 fn update(&mut self, item: &LearnerItem<T>, metadata: &MetricMetadata) -> MetricEntry {
314 let serialized_entry = self.metric.update(&item.item.adapt(), metadata);
315 MetricEntry::new(self.id.clone(), serialized_entry)
316 }
317
318 fn clear(&mut self) {
319 self.metric.clear()
320 }
321}