burn_train/metric/store/
base.rs1use std::sync::Arc;
2
3use crate::metric::{MetricEntry, NumericEntry};
4
5pub enum Event {
7 MetricsUpdate(MetricsUpdate),
9 EndEpoch(usize),
11}
12
13#[derive(new, Clone, Debug)]
15pub struct MetricsUpdate {
16 pub entries: Vec<MetricEntry>,
18 pub entries_numeric: Vec<(MetricEntry, NumericEntry)>,
20}
21
22impl MetricsUpdate {
23 pub fn tag(&mut self, tag: Arc<String>) {
25 self.entries.iter_mut().for_each(|entry| {
26 entry.tags.push(tag.clone());
27 });
28 self.entries_numeric.iter_mut().for_each(|(entry, _)| {
29 entry.tags.push(tag.clone());
30 });
31 }
32}
33
34pub trait EventStore: Send {
38 fn add_event(&mut self, event: Event, split: Split);
40
41 fn find_epoch(
43 &mut self,
44 name: &str,
45 aggregate: Aggregate,
46 direction: Direction,
47 split: Split,
48 ) -> Option<usize>;
49
50 fn find_metric(
52 &mut self,
53 name: &str,
54 epoch: usize,
55 aggregate: Aggregate,
56 split: Split,
57 ) -> Option<f64>;
58}
59
/// How a metric's recorded values are aggregated when an [`EventStore`] is queried.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Aggregate {
    /// Aggregate by taking the mean.
    Mean,
}
66
/// The data split that an event or an [`EventStore`] query refers to.
// Same trait set as `Aggregate`, so a split can be compared, printed in diagnostics,
// and used as a `HashMap`/`HashSet` key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Split {
    /// The training split.
    Train,
    /// The validation split.
    Valid,
    /// The test split.
    Test,
}
77
/// Which extreme an [`EventStore`] epoch query should look for.
// Same trait set as `Aggregate`, so a direction can be compared, printed in
// diagnostics, and used as a `HashMap`/`HashSet` key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Look for the lowest value.
    Lowest,
    /// Look for the highest value.
    Highest,
}