Skip to main content

burn_train/logger/
metric.rs

1use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
2use crate::metric::{
3    MetricDefinition, MetricEntry, MetricId, NumericEntry,
4    store::{EpochSummary, MetricsUpdate, Split},
5};
6use std::{
7    collections::HashMap,
8    fs,
9    path::{Path, PathBuf},
10    sync::Arc,
11};
12
/// Prefix of the per-epoch directory names written by the file logger
/// (the full name is the prefix followed by the epoch number, e.g. `epoch-3`).
const EPOCH_PREFIX: &str = "epoch-";
14
15/// Metric logger.
16pub trait MetricLogger: Send {
17    /// Logs an item.
18    ///
19    /// # Arguments
20    ///
21    /// * `update` - Update information for all registered metrics.
22    /// * `epoch` - Current epoch.
23    /// * `split` - Current dataset split.
24    /// * `iteration` - Current iteration.
25    /// * `tag` - Optional, additional tag for the split.
26    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: Split, tag: Option<Arc<String>>);
27
28    /// Read the logs for an epoch.
29    fn read_numeric(
30        &mut self,
31        name: &str,
32        epoch: usize,
33        split: Split,
34    ) -> Result<Vec<NumericEntry>, String>;
35
36    /// Logs the metric definition information (name, description, unit, etc.)
37    fn log_metric_definition(&mut self, definition: MetricDefinition);
38
39    /// Logs summary at the end of the epoch.
40    fn log_epoch_summary(&mut self, summary: EpochSummary);
41}
42
43/// The file metric logger.
44pub struct FileMetricLogger {
45    loggers: HashMap<String, AsyncLogger<String>>,
46    directory: PathBuf,
47    metric_definitions: HashMap<MetricId, MetricDefinition>,
48    is_eval: bool,
49    last_epoch: Option<usize>,
50}
51
52impl FileMetricLogger {
53    /// Create a new file metric logger.
54    ///
55    /// # Arguments
56    ///
57    /// * `directory` - The directory.
58    ///
59    /// # Returns
60    ///
61    /// The file metric logger.
62    pub fn new(directory: impl AsRef<Path>) -> Self {
63        Self {
64            loggers: HashMap::new(),
65            directory: directory.as_ref().to_path_buf(),
66            metric_definitions: HashMap::default(),
67            is_eval: false,
68            last_epoch: None,
69        }
70    }
71
72    /// Create a new file metric logger.
73    ///
74    /// # Arguments
75    ///
76    /// * `directory` - The directory.
77    ///
78    /// # Returns
79    ///
80    /// The file metric logger.
81    pub fn new_eval(directory: impl AsRef<Path>) -> Self {
82        Self {
83            loggers: HashMap::new(),
84            directory: directory.as_ref().to_path_buf(),
85            metric_definitions: HashMap::default(),
86            is_eval: true,
87            last_epoch: None,
88        }
89    }
90
91    pub(crate) fn split_exists(&self, split: Split) -> bool {
92        let split_path = self.directory.join(split.to_string());
93        split_path.exists() && split_path.is_dir()
94    }
95
96    /// Number of epochs recorded.
97    pub(crate) fn epochs(&self) -> usize {
98        if self.is_eval {
99            log::warn!("Number of epochs not available when testing.");
100            return 0;
101        }
102
103        let mut max_epoch = 0;
104
105        // with split
106        for path in fs::read_dir(&self.directory).unwrap() {
107            let path = path.unwrap();
108
109            if fs::metadata(path.path()).unwrap().is_dir() {
110                for split_path in fs::read_dir(path.path()).unwrap() {
111                    let split_path = split_path.unwrap();
112
113                    if fs::metadata(split_path.path()).unwrap().is_dir() {
114                        let dir_name = split_path.file_name().into_string().unwrap();
115
116                        if !dir_name.starts_with(EPOCH_PREFIX) {
117                            continue;
118                        }
119
120                        let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
121
122                        if let Some(epoch) = epoch
123                            && epoch > max_epoch
124                        {
125                            max_epoch = epoch;
126                        }
127                    }
128                }
129            }
130        }
131
132        max_epoch
133    }
134
135    fn train_directory(&self, tag: Option<&String>, epoch: usize, split: Split) -> PathBuf {
136        let name = format!("{EPOCH_PREFIX}{epoch}");
137
138        match tag {
139            Some(tag) => self.directory.join(split.to_string()).join(tag).join(name),
140            None => self.directory.join(split.to_string()).join(name),
141        }
142    }
143
144    fn eval_directory(&self, tag: Option<&String>, split: Split) -> PathBuf {
145        match tag {
146            Some(tag) => self.directory.join(split.to_string()).join(tag),
147            None => self.directory.clone(),
148        }
149    }
150
151    fn file_path(
152        &self,
153        tag: Option<&String>,
154        name: &str,
155        epoch: Option<usize>,
156        split: Split,
157    ) -> PathBuf {
158        let directory = match epoch {
159            Some(epoch) => self.train_directory(tag, epoch, split),
160            None => self.eval_directory(tag, split),
161        };
162        let name = name.replace(' ', "_");
163        let name = format!("{name}.log");
164        directory.join(name)
165    }
166
167    fn create_directory(&self, tag: Option<&String>, epoch: Option<usize>, split: Split) {
168        let directory = match epoch {
169            Some(epoch) => self.train_directory(tag, epoch, split),
170            None => self.eval_directory(tag, split),
171        };
172        std::fs::create_dir_all(directory).ok();
173    }
174}
175
176impl FileMetricLogger {
177    fn log_item(
178        &mut self,
179        tag: Option<&String>,
180        item: &MetricEntry,
181        epoch: Option<usize>,
182        split: Split,
183    ) {
184        let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
185        let key = logger_key(name, split);
186        let value = &item.serialized_entry.serialized;
187
188        let logger = match self.loggers.get_mut(&key) {
189            Some(val) => val,
190            None => {
191                self.create_directory(tag, epoch, split);
192
193                let file_path = self.file_path(tag, name, epoch, split);
194                let logger = FileLogger::new(file_path);
195                let logger = AsyncLogger::new(logger);
196
197                self.loggers.insert(key.clone(), logger);
198                self.loggers
199                    .get_mut(&key)
200                    .expect("Can get the previously saved logger.")
201            }
202        };
203
204        logger.log(value.clone());
205    }
206}
207
208impl MetricLogger for FileMetricLogger {
209    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: Split, tag: Option<Arc<String>>) {
210        if !self.is_eval && self.last_epoch != Some(epoch) {
211            self.loggers.clear();
212            self.last_epoch = Some(epoch);
213        }
214
215        let entries: Vec<_> = update
216            .entries
217            .iter()
218            .chain(
219                update
220                    .entries_numeric
221                    .iter()
222                    .map(|numeric_update| &numeric_update.entry),
223            )
224            .cloned()
225            .collect();
226
227        for item in entries.iter() {
228            match tag {
229                Some(ref tag) => {
230                    let tag = tag.trim().replace(' ', "-").to_lowercase();
231                    self.log_item(Some(&tag), item, Some(epoch), split);
232                }
233                None => self.log_item(None, item, Some(epoch), split),
234            }
235        }
236    }
237
238    fn read_numeric(
239        &mut self,
240        name: &str,
241        epoch: usize,
242        split: Split,
243    ) -> Result<Vec<NumericEntry>, String> {
244        if let Some(value) = self.loggers.get(name) {
245            value.sync()
246        }
247
248        let file_path = self.file_path(None, name, Some(epoch), split);
249
250        let mut errors = false;
251
252        let data = std::fs::read_to_string(file_path)
253            .unwrap_or_default()
254            .split('\n')
255            .filter_map(|value| {
256                if value.is_empty() {
257                    None
258                } else {
259                    match NumericEntry::deserialize(value) {
260                        Ok(value) => Some(value),
261                        Err(err) => {
262                            log::error!("{err}");
263                            errors = true;
264                            None
265                        }
266                    }
267                }
268            })
269            .collect();
270
271        if errors {
272            Err("Parsing numeric entry errors".to_string())
273        } else {
274            Ok(data)
275        }
276    }
277
278    fn log_metric_definition(&mut self, definition: MetricDefinition) {
279        self.metric_definitions
280            .insert(definition.metric_id.clone(), definition);
281    }
282
283    fn log_epoch_summary(&mut self, _summary: EpochSummary) {
284        if !self.is_eval {
285            self.loggers.clear();
286        }
287    }
288}
289
290fn logger_key(name: &str, split: Split) -> String {
291    format!("{name}_{split}")
292}
293
294/// In memory metric logger, useful when testing and debugging.
295#[derive(Default)]
296pub struct InMemoryMetricLogger {
297    values: HashMap<String, Vec<InMemoryLogger>>,
298    last_epoch: Option<usize>,
299    metric_definitions: HashMap<MetricId, MetricDefinition>,
300}
301
302impl InMemoryMetricLogger {
303    /// Create a new in-memory metric logger.
304    pub fn new() -> Self {
305        Self::default()
306    }
307}
308
309impl MetricLogger for InMemoryMetricLogger {
310    fn log(
311        &mut self,
312        update: MetricsUpdate,
313        epoch: usize,
314        split: Split,
315        _tag: Option<Arc<String>>,
316    ) {
317        if self.last_epoch != Some(epoch) {
318            self.values
319                .values_mut()
320                .for_each(|loggers| loggers.push(InMemoryLogger::default()));
321            self.last_epoch = Some(epoch);
322        }
323
324        let entries: Vec<_> = update
325            .entries
326            .iter()
327            .chain(
328                update
329                    .entries_numeric
330                    .iter()
331                    .map(|numeric_update| &numeric_update.entry),
332            )
333            .cloned()
334            .collect();
335
336        for item in entries.iter() {
337            let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
338            let key = logger_key(name, split);
339
340            if !self.values.contains_key(&key) {
341                self.values
342                    .insert(key.to_string(), vec![InMemoryLogger::default()]);
343            }
344
345            let values = self.values.get_mut(&key).unwrap();
346
347            values
348                .last_mut()
349                .unwrap()
350                .log(item.serialized_entry.serialized.clone());
351        }
352    }
353
354    fn read_numeric(
355        &mut self,
356        name: &str,
357        epoch: usize,
358        split: Split,
359    ) -> Result<Vec<NumericEntry>, String> {
360        let key = logger_key(name, split);
361        let values = match self.values.get(&key) {
362            Some(values) => values,
363            None => return Ok(Vec::new()),
364        };
365
366        match values.get(epoch - 1) {
367            Some(logger) => Ok(logger
368                .values
369                .iter()
370                .filter_map(|value| NumericEntry::deserialize(value).ok())
371                .collect()),
372            None => Ok(Vec::new()),
373        }
374    }
375
376    fn log_metric_definition(&mut self, definition: MetricDefinition) {
377        self.metric_definitions
378            .insert(definition.metric_id.clone(), definition);
379    }
380
381    fn log_epoch_summary(&mut self, _summary: EpochSummary) {}
382}