burn_train/logger/
metric.rs1use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
2use crate::metric::{MetricEntry, NumericEntry};
3use std::{
4 collections::HashMap,
5 fs,
6 path::{Path, PathBuf},
7};
8
/// Prefix of the per-epoch directory names; the epoch number is appended to it
/// (for example `epoch-3`).
const EPOCH_PREFIX: &str = "epoch-";
10
11pub trait MetricLogger: Send {
13 fn log(&mut self, item: &MetricEntry);
19
20 fn end_epoch(&mut self, epoch: usize);
26
27 fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String>;
29}
30
31pub struct FileMetricLogger {
33 loggers: HashMap<String, AsyncLogger<String>>,
34 directory: PathBuf,
35 epoch: usize,
36}
37
38impl FileMetricLogger {
39 pub fn new(directory: impl AsRef<Path>) -> Self {
49 Self {
50 loggers: HashMap::new(),
51 directory: directory.as_ref().to_path_buf(),
52 epoch: 1,
53 }
54 }
55
56 pub(crate) fn epochs(&self) -> usize {
58 let mut max_epoch = 0;
59
60 for path in fs::read_dir(&self.directory).unwrap() {
61 let path = path.unwrap();
62
63 if fs::metadata(path.path()).unwrap().is_dir() {
64 let dir_name = path.file_name().into_string().unwrap();
65
66 if !dir_name.starts_with(EPOCH_PREFIX) {
67 continue;
68 }
69
70 let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
71
72 if let Some(epoch) = epoch {
73 if epoch > max_epoch {
74 max_epoch = epoch;
75 }
76 }
77 }
78 }
79
80 max_epoch
81 }
82
83 fn epoch_directory(&self, epoch: usize) -> PathBuf {
84 let name = format!("{}{}", EPOCH_PREFIX, epoch);
85 self.directory.join(name)
86 }
87
88 fn file_path(&self, name: &str, epoch: usize) -> PathBuf {
89 let directory = self.epoch_directory(epoch);
90 let name = name.replace(' ', "_");
91 let name = format!("{name}.log");
92 directory.join(name)
93 }
94
95 fn create_directory(&self, epoch: usize) {
96 let directory = self.epoch_directory(epoch);
97 std::fs::create_dir_all(directory).ok();
98 }
99}
100
101impl MetricLogger for FileMetricLogger {
102 fn log(&mut self, item: &MetricEntry) {
103 let key = &item.name;
104 let value = &item.serialize;
105
106 let logger = match self.loggers.get_mut(key) {
107 Some(val) => val,
108 None => {
109 self.create_directory(self.epoch);
110
111 let file_path = self.file_path(key, self.epoch);
112 let logger = FileLogger::new(file_path);
113 let logger = AsyncLogger::new(logger);
114
115 self.loggers.insert(key.clone(), logger);
116 self.loggers
117 .get_mut(key)
118 .expect("Can get the previously saved logger.")
119 }
120 };
121
122 logger.log(value.clone());
123 }
124
125 fn end_epoch(&mut self, epoch: usize) {
126 self.loggers.clear();
127 self.epoch = epoch + 1;
128 }
129
130 fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
131 if let Some(value) = self.loggers.get(name) {
132 value.sync()
133 }
134
135 let file_path = self.file_path(name, epoch);
136
137 let mut errors = false;
138
139 let data = std::fs::read_to_string(file_path)
140 .unwrap_or_default()
141 .split('\n')
142 .filter_map(|value| {
143 if value.is_empty() {
144 None
145 } else {
146 match NumericEntry::deserialize(value) {
147 Ok(value) => Some(value),
148 Err(err) => {
149 log::error!("{err}");
150 errors = true;
151 None
152 }
153 }
154 }
155 })
156 .collect();
157
158 if errors {
159 Err("Parsing numeric entry errors".to_string())
160 } else {
161 Ok(data)
162 }
163 }
164}
165
166#[derive(Default)]
168pub struct InMemoryMetricLogger {
169 values: HashMap<String, Vec<InMemoryLogger>>,
170}
171
172impl InMemoryMetricLogger {
173 pub fn new() -> Self {
175 Self::default()
176 }
177}
178impl MetricLogger for InMemoryMetricLogger {
179 fn log(&mut self, item: &MetricEntry) {
180 if !self.values.contains_key(&item.name) {
181 self.values
182 .insert(item.name.clone(), vec![InMemoryLogger::default()]);
183 }
184
185 let values = self.values.get_mut(&item.name).unwrap();
186
187 values.last_mut().unwrap().log(item.serialize.clone());
188 }
189
190 fn end_epoch(&mut self, _epoch: usize) {
191 for (_, values) in self.values.iter_mut() {
192 values.push(InMemoryLogger::default());
193 }
194 }
195
196 fn read_numeric(&mut self, name: &str, epoch: usize) -> Result<Vec<NumericEntry>, String> {
197 let values = match self.values.get(name) {
198 Some(values) => values,
199 None => return Ok(Vec::new()),
200 };
201
202 match values.get(epoch - 1) {
203 Some(logger) => Ok(logger
204 .values
205 .iter()
206 .filter_map(|value| NumericEntry::deserialize(value).ok())
207 .collect()),
208 None => Ok(Vec::new()),
209 }
210 }
211}