Skip to main content

burn_train/logger/
metric.rs

1use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
2use crate::metric::{
3    MetricDefinition, MetricEntry, MetricId, NumericEntry,
4    store::{EpochSummary, MetricsUpdate, Split},
5};
6use std::{
7    collections::HashMap,
8    fs,
9    path::{Path, PathBuf},
10};
11
/// Prefix of the per-epoch log directory names (`epoch-<number>`).
const EPOCH_PREFIX: &str = "epoch-";
13
14/// Metric logger.
15pub trait MetricLogger: Send {
16    /// Logs an item.
17    ///
18    /// # Arguments
19    ///
20    /// * `update` - Update information for all registered metrics.
21    /// * `epoch` - Current epoch.
22    /// * `split` - Current dataset split.
23    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split);
24
25    /// Read the logs for an epoch.
26    fn read_numeric(
27        &mut self,
28        name: &str,
29        epoch: usize,
30        split: &Split,
31    ) -> Result<Vec<NumericEntry>, String>;
32
33    /// Logs the metric definition information (name, description, unit, etc.)
34    fn log_metric_definition(&mut self, definition: MetricDefinition);
35
36    /// Logs summary at the end of the epoch.
37    fn log_epoch_summary(&mut self, summary: EpochSummary);
38}
39
40/// The file metric logger.
41pub struct FileMetricLogger {
42    loggers: HashMap<String, AsyncLogger<String>>,
43    directory: PathBuf,
44    metric_definitions: HashMap<MetricId, MetricDefinition>,
45    is_eval: bool,
46    last_epoch: Option<usize>,
47}
48
49impl FileMetricLogger {
50    /// Create a new file metric logger.
51    ///
52    /// # Arguments
53    ///
54    /// * `directory` - The directory.
55    ///
56    /// # Returns
57    ///
58    /// The file metric logger.
59    pub fn new(directory: impl AsRef<Path>) -> Self {
60        Self {
61            loggers: HashMap::new(),
62            directory: directory.as_ref().to_path_buf(),
63            metric_definitions: HashMap::default(),
64            is_eval: false,
65            last_epoch: None,
66        }
67    }
68
69    /// Create a new file metric logger.
70    ///
71    /// # Arguments
72    ///
73    /// * `directory` - The directory.
74    ///
75    /// # Returns
76    ///
77    /// The file metric logger.
78    pub fn new_eval(directory: impl AsRef<Path>) -> Self {
79        Self {
80            loggers: HashMap::new(),
81            directory: directory.as_ref().to_path_buf(),
82            metric_definitions: HashMap::default(),
83            is_eval: true,
84            last_epoch: None,
85        }
86    }
87
88    pub(crate) fn split_exists(&self, split: &Split) -> bool {
89        self.split_dir(split).is_some()
90    }
91
92    pub(crate) fn split_dir(&self, split: &Split) -> Option<PathBuf> {
93        let split_path = match split {
94            Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(tag.as_str()),
95            other => self.directory.join(other.to_string()),
96        };
97        (split_path.exists() && split_path.is_dir()).then_some(split_path)
98    }
99
100    pub(crate) fn is_epoch_dir<P: AsRef<str>>(dirname: P) -> bool {
101        dirname.as_ref().starts_with(EPOCH_PREFIX)
102    }
103
104    /// Number of epochs recorded.
105    pub(crate) fn epochs(&self) -> usize {
106        if self.is_eval {
107            log::warn!("Number of epochs not available when testing.");
108            return 0;
109        }
110
111        let mut max_epoch = 0;
112
113        // with split
114        for path in fs::read_dir(&self.directory).unwrap() {
115            let path = path.unwrap();
116
117            if fs::metadata(path.path()).unwrap().is_dir() {
118                for split_path in fs::read_dir(path.path()).unwrap() {
119                    let split_path = split_path.unwrap();
120
121                    if fs::metadata(split_path.path()).unwrap().is_dir() {
122                        let dir_name = split_path.file_name().into_string().unwrap();
123
124                        if !dir_name.starts_with(EPOCH_PREFIX) {
125                            continue;
126                        }
127
128                        let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
129
130                        if let Some(epoch) = epoch
131                            && epoch > max_epoch
132                        {
133                            max_epoch = epoch;
134                        }
135                    }
136                }
137            }
138        }
139
140        max_epoch
141    }
142
143    fn train_directory(&self, epoch: usize, split: &Split) -> PathBuf {
144        let name = format!("{EPOCH_PREFIX}{epoch}");
145
146        match split {
147            Split::Train | Split::Valid | Split::Test(None) => {
148                self.directory.join(split.to_string()).join(name)
149            }
150            Split::Test(Some(tag)) => {
151                let tag = format_tag(tag);
152                self.directory.join(split.to_string()).join(tag).join(name)
153            }
154        }
155    }
156
157    fn eval_directory(&self, split: &Split) -> PathBuf {
158        match split {
159            Split::Train | Split::Valid | Split::Test(None) => self.directory.clone(),
160            Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(format_tag(tag)),
161        }
162    }
163
164    fn file_path(&self, name: &str, epoch: Option<usize>, split: &Split) -> PathBuf {
165        let directory = match epoch {
166            Some(epoch) => self.train_directory(epoch, split),
167            None => self.eval_directory(split),
168        };
169        let name = name.replace(' ', "_");
170        let name = format!("{name}.log");
171        directory.join(name)
172    }
173
174    fn create_directory(&self, epoch: Option<usize>, split: &Split) {
175        let directory = match epoch {
176            Some(epoch) => self.train_directory(epoch, split),
177            None => self.eval_directory(split),
178        };
179        std::fs::create_dir_all(directory).ok();
180    }
181}
182
183impl FileMetricLogger {
184    fn log_item(&mut self, item: &MetricEntry, epoch: Option<usize>, split: &Split) {
185        let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
186        let key = logger_key(name, split);
187        let value = &item.serialized_entry.serialized;
188
189        let logger = match self.loggers.get_mut(&key) {
190            Some(val) => val,
191            None => {
192                self.create_directory(epoch, split);
193
194                let file_path = self.file_path(name, epoch, split);
195                let logger = FileLogger::new(file_path);
196                let logger = AsyncLogger::new(logger);
197
198                self.loggers.insert(key.clone(), logger);
199                self.loggers
200                    .get_mut(&key)
201                    .expect("Can get the previously saved logger.")
202            }
203        };
204
205        logger.log(value.clone());
206    }
207}
208
/// Normalizes a tag so it can be used as a directory name: surrounding
/// whitespace is removed, every remaining space becomes `-`, and the result is lowercased.
fn format_tag(tag: &str) -> String {
    tag.trim()
        .chars()
        .map(|c| if c == ' ' { '-' } else { c })
        .collect::<String>()
        .to_lowercase()
}
212
213impl MetricLogger for FileMetricLogger {
214    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
215        if !self.is_eval && self.last_epoch != Some(epoch) {
216            self.loggers.clear();
217            self.last_epoch = Some(epoch);
218        }
219
220        let entries: Vec<_> = update
221            .entries
222            .iter()
223            .chain(
224                update
225                    .entries_numeric
226                    .iter()
227                    .map(|numeric_update| &numeric_update.entry),
228            )
229            .cloned()
230            .collect();
231
232        for item in entries.iter() {
233            self.log_item(item, Some(epoch), split);
234        }
235    }
236
237    fn read_numeric(
238        &mut self,
239        name: &str,
240        epoch: usize,
241        split: &Split,
242    ) -> Result<Vec<NumericEntry>, String> {
243        if let Some(value) = self.loggers.get(name) {
244            value.sync()
245        }
246
247        let file_path = self.file_path(name, Some(epoch), split);
248
249        let mut errors = false;
250
251        let data = std::fs::read_to_string(file_path)
252            .unwrap_or_default()
253            .split('\n')
254            .filter_map(|value| {
255                if value.is_empty() {
256                    None
257                } else {
258                    match NumericEntry::deserialize(value) {
259                        Ok(value) => Some(value),
260                        Err(err) => {
261                            log::error!("{err}");
262                            errors = true;
263                            None
264                        }
265                    }
266                }
267            })
268            .collect();
269
270        if errors {
271            Err("Parsing numeric entry errors".to_string())
272        } else {
273            Ok(data)
274        }
275    }
276
277    fn log_metric_definition(&mut self, definition: MetricDefinition) {
278        self.metric_definitions
279            .insert(definition.metric_id.clone(), definition);
280    }
281
282    fn log_epoch_summary(&mut self, _summary: EpochSummary) {
283        if !self.is_eval {
284            self.loggers.clear();
285        }
286    }
287}
288
289fn logger_key(name: &str, split: &Split) -> String {
290    format!("{name}_{split}")
291}
292
293/// In memory metric logger, useful when testing and debugging.
294#[derive(Default)]
295pub struct InMemoryMetricLogger {
296    values: HashMap<String, Vec<InMemoryLogger>>,
297    last_epoch: Option<usize>,
298    metric_definitions: HashMap<MetricId, MetricDefinition>,
299}
300
301impl InMemoryMetricLogger {
302    /// Create a new in-memory metric logger.
303    pub fn new() -> Self {
304        Self::default()
305    }
306}
307
308impl MetricLogger for InMemoryMetricLogger {
309    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
310        if self.last_epoch != Some(epoch) {
311            self.values
312                .values_mut()
313                .for_each(|loggers| loggers.push(InMemoryLogger::default()));
314            self.last_epoch = Some(epoch);
315        }
316
317        let entries: Vec<_> = update
318            .entries
319            .iter()
320            .chain(
321                update
322                    .entries_numeric
323                    .iter()
324                    .map(|numeric_update| &numeric_update.entry),
325            )
326            .cloned()
327            .collect();
328
329        for item in entries.iter() {
330            let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
331            let key = logger_key(name, split);
332
333            if !self.values.contains_key(&key) {
334                self.values
335                    .insert(key.to_string(), vec![InMemoryLogger::default()]);
336            }
337
338            let values = self.values.get_mut(&key).unwrap();
339
340            values
341                .last_mut()
342                .unwrap()
343                .log(item.serialized_entry.serialized.clone());
344        }
345    }
346
347    fn read_numeric(
348        &mut self,
349        name: &str,
350        epoch: usize,
351        split: &Split,
352    ) -> Result<Vec<NumericEntry>, String> {
353        let key = logger_key(name, split);
354        let values = match self.values.get(&key) {
355            Some(values) => values,
356            None => return Ok(Vec::new()),
357        };
358
359        match values.get(epoch - 1) {
360            Some(logger) => Ok(logger
361                .values
362                .iter()
363                .filter_map(|value| NumericEntry::deserialize(value).ok())
364                .collect()),
365            None => Ok(Vec::new()),
366        }
367    }
368
369    fn log_metric_definition(&mut self, definition: MetricDefinition) {
370        self.metric_definitions
371            .insert(definition.metric_id.clone(), definition);
372    }
373
374    fn log_epoch_summary(&mut self, _summary: EpochSummary) {}
375}