burn_train/logger/metric.rs

1use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
2use crate::metric::{MetricEntry, NumericEntry};
3use std::{
4    collections::HashMap,
5    fs,
6    path::{Path, PathBuf},
7};
8
/// Prefix of the per-epoch directory names (`epoch-1`, `epoch-2`, ...).
const EPOCH_PREFIX: &str = "epoch-";
10
11/// Metric logger.
12pub trait MetricLogger: Send {
13    /// Logs an item.
14    ///
15    /// # Arguments
16    ///
17    /// * `item` - The item.
18    fn log(&mut self, item: &MetricEntry);
19
20    /// Logs an epoch.
21    ///
22    /// # Arguments
23    ///
24    /// * `epoch` - The epoch.
25    fn end_epoch(&mut self, epoch: usize);
26
27    /// Read the logs for an epoch.
28    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String>;
29}
30
31/// The file metric logger.
32pub struct FileMetricLogger {
33    loggers: HashMap<String, AsyncLogger<String>>,
34    directory: PathBuf,
35    epoch: Option<usize>,
36}
37
38impl FileMetricLogger {
39    /// Create a new file metric logger.
40    ///
41    /// # Arguments
42    ///
43    /// * `directory` - The directory.
44    ///
45    /// # Returns
46    ///
47    /// The file metric logger.
48    pub fn new_train(directory: impl AsRef<Path>) -> Self {
49        Self {
50            loggers: HashMap::new(),
51            directory: directory.as_ref().to_path_buf(),
52            epoch: Some(1),
53        }
54    }
55
56    /// Create a new file metric logger.
57    ///
58    /// # Arguments
59    ///
60    /// * `directory` - The directory.
61    ///
62    /// # Returns
63    ///
64    /// The file metric logger.
65    pub fn new_eval(directory: impl AsRef<Path>) -> Self {
66        Self {
67            loggers: HashMap::new(),
68            directory: directory.as_ref().to_path_buf(),
69            epoch: None,
70        }
71    }
72
73    /// Number of epochs recorded.
74    pub(crate) fn epochs(&self) -> usize {
75        if self.epoch.is_none() {
76            log::warn!("Number of epochs not available when testing.");
77            return 0;
78        }
79
80        let mut max_epoch = 0;
81
82        for path in fs::read_dir(&self.directory).unwrap() {
83            let path = path.unwrap();
84
85            if fs::metadata(path.path()).unwrap().is_dir() {
86                let dir_name = path.file_name().into_string().unwrap();
87
88                if !dir_name.starts_with(EPOCH_PREFIX) {
89                    continue;
90                }
91
92                let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
93
94                if let Some(epoch) = epoch
95                    && epoch > max_epoch
96                {
97                    max_epoch = epoch;
98                }
99            }
100        }
101
102        max_epoch
103    }
104
105    fn train_directory(&self, tags: Option<&String>, epoch: usize) -> PathBuf {
106        let name = format!("{EPOCH_PREFIX}{epoch}");
107
108        match tags {
109            Some(tags) => self.directory.join(tags).join(name),
110            None => self.directory.join(name),
111        }
112    }
113
114    fn eval_directory(&self, tags: Option<&String>) -> PathBuf {
115        match tags {
116            Some(tags) => self.directory.join(tags),
117            None => self.directory.clone(),
118        }
119    }
120
121    fn file_path(&self, tags: Option<&String>, name: &str, epoch: Option<usize>) -> PathBuf {
122        let directory = match epoch {
123            Some(epoch) => self.train_directory(tags, epoch),
124            None => self.eval_directory(tags),
125        };
126        let name = name.replace(' ', "_");
127        let name = format!("{name}.log");
128        directory.join(name)
129    }
130
131    fn create_directory(&self, tags: Option<&String>, epoch: Option<usize>) {
132        let directory = match epoch {
133            Some(epoch) => self.train_directory(tags, epoch),
134            None => self.eval_directory(tags),
135        };
136        std::fs::create_dir_all(directory).ok();
137    }
138}
139
140impl FileMetricLogger {
141    fn log_item(&mut self, tags: Option<&String>, item: &MetricEntry) {
142        let key = &item.name;
143        let value = &item.serialize;
144
145        let logger = match self.loggers.get_mut(key.as_ref()) {
146            Some(val) => val,
147            None => {
148                self.create_directory(tags, self.epoch);
149
150                let file_path = self.file_path(tags, key, self.epoch);
151                let logger = FileLogger::new(file_path);
152                let logger = AsyncLogger::new(logger);
153
154                self.loggers.insert(key.to_string(), logger);
155                self.loggers
156                    .get_mut(key.as_ref())
157                    .expect("Can get the previously saved logger.")
158            }
159        };
160
161        logger.log(value.clone());
162    }
163
164    fn log_tags(&mut self, item: &MetricEntry) {
165        let mut tags = String::new();
166        item.tags.iter().for_each(|tag| tags += tag.as_str());
167        let tags = tags.replace(" ", "-").trim().to_lowercase();
168        self.log_item(Some(&tags), item);
169    }
170}
171
172impl MetricLogger for FileMetricLogger {
173    fn log(&mut self, item: &MetricEntry) {
174        match item.tags.is_empty() {
175            true => self.log_item(None, item),
176            false => self.log_tags(item),
177        }
178    }
179
180    fn end_epoch(&mut self, epoch: usize) {
181        self.loggers.clear();
182        if self.epoch.is_none() {
183            panic!("Only evaluation logger supported.");
184        }
185        self.epoch = Some(epoch + 1);
186    }
187
188    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
189        if let Some(value) = self.loggers.get(name) {
190            value.sync()
191        }
192
193        let file_path = self.file_path(None, name, Some(epoch));
194
195        let mut errors = false;
196
197        let data = std::fs::read_to_string(file_path)
198            .unwrap_or_default()
199            .split('\n')
200            .filter_map(|value| {
201                if value.is_empty() {
202                    None
203                } else {
204                    match NumericEntry::deserialize(value) {
205                        Ok(value) => Some(value),
206                        Err(err) => {
207                            log::error!("{err}");
208                            errors = true;
209                            None
210                        }
211                    }
212                }
213            })
214            .collect();
215
216        if errors {
217            Err("Parsing numeric entry errors".to_string())
218        } else {
219            Ok(data)
220        }
221    }
222}
223
224/// In memory metric logger, useful when testing and debugging.
225#[derive(Default)]
226pub struct InMemoryMetricLogger {
227    values: HashMap<String, Vec<InMemoryLogger>>,
228}
229
230impl InMemoryMetricLogger {
231    /// Create a new in-memory metric logger.
232    pub fn new() -> Self {
233        Self::default()
234    }
235}
236impl MetricLogger for InMemoryMetricLogger {
237    fn log(&mut self, item: &MetricEntry) {
238        if !self.values.contains_key(item.name.as_ref()) {
239            self.values
240                .insert(item.name.to_string(), vec![InMemoryLogger::default()]);
241        }
242
243        let values = self.values.get_mut(item.name.as_ref()).unwrap();
244
245        values.last_mut().unwrap().log(item.serialize.clone());
246    }
247
248    fn end_epoch(&mut self, _epoch: usize) {
249        for (_, values) in self.values.iter_mut() {
250            values.push(InMemoryLogger::default());
251        }
252    }
253
254    fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
255        let values = match self.values.get(name) {
256            Some(values) => values,
257            None => return Ok(Vec::new()),
258        };
259
260        match values.get(epoch - 1) {
261            Some(logger) => Ok(logger
262                .values
263                .iter()
264                .filter_map(|value| NumericEntry::deserialize(value).ok())
265                .collect()),
266            None => Ok(Vec::new()),
267        }
268    }
269}