burn_train/logger/
metric.rs

1use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
2use crate::metric::{MetricEntry, NumericEntry};
3use std::{
4    collections::HashMap,
5    fs,
6    path::{Path, PathBuf},
7};
8
9const EPOCH_PREFIX: &str = "epoch-";
10
/// Metric logger.
///
/// A sink that receives serialized metric entries, is told when an epoch ends, and can
/// read back the numeric entries recorded for a given metric name and epoch.
pub trait MetricLogger: Send {
    /// Logs an item.
    ///
    /// # Arguments
    ///
    /// * `item` - The item.
    fn log(&mut self, item: &MetricEntry);

    /// Logs an epoch.
    ///
    /// Called to signal that an epoch is over; entries logged afterwards belong to a
    /// later epoch.
    ///
    /// # Arguments
    ///
    /// * `epoch` - The epoch.
    fn end_epoch(&mut self, epoch: usize);

    /// Read the logs for an epoch.
    ///
    /// # Arguments
    ///
    /// * `name` - The name of the metric to read.
    /// * `epoch` - The epoch to read the entries from.
    ///
    /// # Returns
    ///
    /// The numeric entries found for `name` at `epoch`, or an error message.
    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String>;
}
30
/// The file metric logger.
///
/// Writes each metric to its own log file, grouped in one sub-directory per epoch.
pub struct FileMetricLogger {
    // Open loggers for the current epoch, keyed by metric name.
    loggers: HashMap<String, AsyncLogger<String>>,
    // Root directory under which the per-epoch directories are created.
    directory: PathBuf,
    // Epoch that newly created log files are written under (starts at 1).
    epoch: usize,
}
37
38impl FileMetricLogger {
39    /// Create a new file metric logger.
40    ///
41    /// # Arguments
42    ///
43    /// * `directory` - The directory.
44    ///
45    /// # Returns
46    ///
47    /// The file metric logger.
48    pub fn new(directory: impl AsRef<Path>) -> Self {
49        Self {
50            loggers: HashMap::new(),
51            directory: directory.as_ref().to_path_buf(),
52            epoch: 1,
53        }
54    }
55
56    /// Number of epochs recorded.
57    pub(crate) fn epochs(&self) -> usize {
58        let mut max_epoch = 0;
59
60        for path in fs::read_dir(&self.directory).unwrap() {
61            let path = path.unwrap();
62
63            if fs::metadata(path.path()).unwrap().is_dir() {
64                let dir_name = path.file_name().into_string().unwrap();
65
66                if !dir_name.starts_with(EPOCH_PREFIX) {
67                    continue;
68                }
69
70                let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
71
72                if let Some(epoch) = epoch {
73                    if epoch > max_epoch {
74                        max_epoch = epoch;
75                    }
76                }
77            }
78        }
79
80        max_epoch
81    }
82
83    fn epoch_directory(&self, epoch: usize) -> PathBuf {
84        let name = format!("{}{}", EPOCH_PREFIX, epoch);
85        self.directory.join(name)
86    }
87
88    fn file_path(&self, name: &str, epoch: usize) -> PathBuf {
89        let directory = self.epoch_directory(epoch);
90        let name = name.replace(' ', "_");
91        let name = format!("{name}.log");
92        directory.join(name)
93    }
94
95    fn create_directory(&self, epoch: usize) {
96        let directory = self.epoch_directory(epoch);
97        std::fs::create_dir_all(directory).ok();
98    }
99}
100
101impl MetricLogger for FileMetricLogger {
102    fn log(&mut self, item: &MetricEntry) {
103        let key = &item.name;
104        let value = &item.serialize;
105
106        let logger = match self.loggers.get_mut(key) {
107            Some(val) => val,
108            None => {
109                self.create_directory(self.epoch);
110
111                let file_path = self.file_path(key, self.epoch);
112                let logger = FileLogger::new(file_path);
113                let logger = AsyncLogger::new(logger);
114
115                self.loggers.insert(key.clone(), logger);
116                self.loggers
117                    .get_mut(key)
118                    .expect("Can get the previously saved logger.")
119            }
120        };
121
122        logger.log(value.clone());
123    }
124
125    fn end_epoch(&mut self, epoch: usize) {
126        self.loggers.clear();
127        self.epoch = epoch + 1;
128    }
129
130    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
131        if let Some(value) = self.loggers.get(name) {
132            value.sync()
133        }
134
135        let file_path = self.file_path(name, epoch);
136
137        let mut errors = false;
138
139        let data = std::fs::read_to_string(file_path)
140            .unwrap_or_default()
141            .split('\n')
142            .filter_map(|value| {
143                if value.is_empty() {
144                    None
145                } else {
146                    match NumericEntry::deserialize(value) {
147                        Ok(value) => Some(value),
148                        Err(err) => {
149                            log::error!("{err}");
150                            errors = true;
151                            None
152                        }
153                    }
154                }
155            })
156            .collect();
157
158        if errors {
159            Err("Parsing numeric entry errors".to_string())
160        } else {
161            Ok(data)
162        }
163    }
164}
165
/// In memory metric logger, useful when testing and debugging.
#[derive(Default)]
pub struct InMemoryMetricLogger {
    // Per metric name, a list of in-memory loggers; new entries go to the last one, and a
    // fresh logger is appended for every known metric each time an epoch ends.
    values: HashMap<String, Vec<InMemoryLogger>>,
}
171
172impl InMemoryMetricLogger {
173    /// Create a new in-memory metric logger.
174    pub fn new() -> Self {
175        Self::default()
176    }
177}
178impl MetricLogger for InMemoryMetricLogger {
179    fn log(&mut self, item: &MetricEntry) {
180        if !self.values.contains_key(&item.name) {
181            self.values
182                .insert(item.name.clone(), vec![InMemoryLogger::default()]);
183        }
184
185        let values = self.values.get_mut(&item.name).unwrap();
186
187        values.last_mut().unwrap().log(item.serialize.clone());
188    }
189
190    fn end_epoch(&mut self, _epoch: usize) {
191        for (_, values) in self.values.iter_mut() {
192            values.push(InMemoryLogger::default());
193        }
194    }
195
196    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
197        let values = match self.values.get(name) {
198            Some(values) => values,
199            None => return Ok(Vec::new()),
200        };
201
202        match values.get(epoch - 1) {
203            Some(logger) => Ok(logger
204                .values
205                .iter()
206                .filter_map(|value| NumericEntry::deserialize(value).ok())
207                .collect()),
208            None => Ok(Vec::new()),
209        }
210    }
211}