1use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
2use crate::metric::{
3 MetricDefinition, MetricEntry, MetricId, NumericEntry,
4 store::{EpochSummary, MetricsUpdate, Split},
5};
6use std::{
7 collections::HashMap,
8 fs,
9 path::{Path, PathBuf},
10};
11
/// Prefix of the per-epoch directory names: metrics for epoch `n` are stored in a directory
/// named `epoch-<n>`, and directory names starting with this prefix are treated as epoch
/// directories when scanning the log folder.
const EPOCH_PREFIX: &str = "epoch-";
13
/// Interface for recording metric entries and reading numeric entries back.
///
/// Implementors must be [`Send`].
pub trait MetricLogger: Send {
    /// Records the entries carried by `update` for the given `epoch` and `split`.
    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split);

    /// Reads the numeric entries recorded for the metric `name` at `epoch` in `split`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the recorded entries cannot be read back.
    fn read_numeric(
        &mut self,
        name: &str,
        epoch: usize,
        split: &Split,
    ) -> Result<Vec<NumericEntry>, String>;

    /// Registers a metric definition.
    ///
    /// The implementations in this file use the registered definitions to resolve the name of
    /// a metric from the id carried by each logged entry.
    fn log_metric_definition(&mut self, definition: MetricDefinition);

    /// Receives the summary of an epoch.
    fn log_epoch_summary(&mut self, summary: EpochSummary);
}
39
/// Metric logger that writes every metric to its own `<name>.log` file under a root directory.
pub struct FileMetricLogger {
    /// Currently open file loggers, keyed by `logger_key(name, split)`.
    loggers: HashMap<String, AsyncLogger<String>>,
    /// Root directory under which all split / epoch directories are created.
    directory: PathBuf,
    /// Metric definitions registered through `log_metric_definition`, keyed by metric id.
    metric_definitions: HashMap<MetricId, MetricDefinition>,
    /// `true` when built with `new_eval`: the open loggers are then never cleared and
    /// `epochs` reports `0`.
    is_eval: bool,
    /// Epoch seen by the most recent `log` call; only updated when `is_eval` is `false`.
    last_epoch: Option<usize>,
}
48
49impl FileMetricLogger {
50 pub fn new(directory: impl AsRef<Path>) -> Self {
60 Self {
61 loggers: HashMap::new(),
62 directory: directory.as_ref().to_path_buf(),
63 metric_definitions: HashMap::default(),
64 is_eval: false,
65 last_epoch: None,
66 }
67 }
68
69 pub fn new_eval(directory: impl AsRef<Path>) -> Self {
79 Self {
80 loggers: HashMap::new(),
81 directory: directory.as_ref().to_path_buf(),
82 metric_definitions: HashMap::default(),
83 is_eval: true,
84 last_epoch: None,
85 }
86 }
87
88 pub(crate) fn split_exists(&self, split: &Split) -> bool {
89 self.split_dir(split).is_some()
90 }
91
92 pub(crate) fn split_dir(&self, split: &Split) -> Option<PathBuf> {
93 let split_path = match split {
94 Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(tag.as_str()),
95 other => self.directory.join(other.to_string()),
96 };
97 (split_path.exists() && split_path.is_dir()).then_some(split_path)
98 }
99
100 pub(crate) fn is_epoch_dir<P: AsRef<str>>(dirname: P) -> bool {
101 dirname.as_ref().starts_with(EPOCH_PREFIX)
102 }
103
104 pub(crate) fn epochs(&self) -> usize {
106 if self.is_eval {
107 log::warn!("Number of epochs not available when testing.");
108 return 0;
109 }
110
111 let mut max_epoch = 0;
112
113 for path in fs::read_dir(&self.directory).unwrap() {
115 let path = path.unwrap();
116
117 if fs::metadata(path.path()).unwrap().is_dir() {
118 for split_path in fs::read_dir(path.path()).unwrap() {
119 let split_path = split_path.unwrap();
120
121 if fs::metadata(split_path.path()).unwrap().is_dir() {
122 let dir_name = split_path.file_name().into_string().unwrap();
123
124 if !dir_name.starts_with(EPOCH_PREFIX) {
125 continue;
126 }
127
128 let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
129
130 if let Some(epoch) = epoch
131 && epoch > max_epoch
132 {
133 max_epoch = epoch;
134 }
135 }
136 }
137 }
138 }
139
140 max_epoch
141 }
142
143 fn train_directory(&self, epoch: usize, split: &Split) -> PathBuf {
144 let name = format!("{EPOCH_PREFIX}{epoch}");
145
146 match split {
147 Split::Train | Split::Valid | Split::Test(None) => {
148 self.directory.join(split.to_string()).join(name)
149 }
150 Split::Test(Some(tag)) => {
151 let tag = format_tag(tag);
152 self.directory.join(split.to_string()).join(tag).join(name)
153 }
154 }
155 }
156
157 fn eval_directory(&self, split: &Split) -> PathBuf {
158 match split {
159 Split::Train | Split::Valid | Split::Test(None) => self.directory.clone(),
160 Split::Test(Some(tag)) => self.directory.join(split.to_string()).join(format_tag(tag)),
161 }
162 }
163
164 fn file_path(&self, name: &str, epoch: Option<usize>, split: &Split) -> PathBuf {
165 let directory = match epoch {
166 Some(epoch) => self.train_directory(epoch, split),
167 None => self.eval_directory(split),
168 };
169 let name = name.replace(' ', "_");
170 let name = format!("{name}.log");
171 directory.join(name)
172 }
173
174 fn create_directory(&self, epoch: Option<usize>, split: &Split) {
175 let directory = match epoch {
176 Some(epoch) => self.train_directory(epoch, split),
177 None => self.eval_directory(split),
178 };
179 std::fs::create_dir_all(directory).ok();
180 }
181}
182
183impl FileMetricLogger {
184 fn log_item(&mut self, item: &MetricEntry, epoch: Option<usize>, split: &Split) {
185 let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
186 let key = logger_key(name, split);
187 let value = &item.serialized_entry.serialized;
188
189 let logger = match self.loggers.get_mut(&key) {
190 Some(val) => val,
191 None => {
192 self.create_directory(epoch, split);
193
194 let file_path = self.file_path(name, epoch, split);
195 let logger = FileLogger::new(file_path);
196 let logger = AsyncLogger::new(logger);
197
198 self.loggers.insert(key.clone(), logger);
199 self.loggers
200 .get_mut(&key)
201 .expect("Can get the previously saved logger.")
202 }
203 };
204
205 logger.log(value.clone());
206 }
207}
208
/// Normalizes a tag for use as a directory name: surrounding whitespace is trimmed, every
/// space becomes a dash and the result is lowercased.
fn format_tag(tag: &str) -> String {
    let dashed: String = tag
        .trim()
        .chars()
        .map(|c| if c == ' ' { '-' } else { c })
        .collect();

    dashed.to_lowercase()
}
212
213impl MetricLogger for FileMetricLogger {
214 fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
215 if !self.is_eval && self.last_epoch != Some(epoch) {
216 self.loggers.clear();
217 self.last_epoch = Some(epoch);
218 }
219
220 let entries: Vec<_> = update
221 .entries
222 .iter()
223 .chain(
224 update
225 .entries_numeric
226 .iter()
227 .map(|numeric_update| &numeric_update.entry),
228 )
229 .cloned()
230 .collect();
231
232 for item in entries.iter() {
233 self.log_item(item, Some(epoch), split);
234 }
235 }
236
237 fn read_numeric(
238 &mut self,
239 name: &str,
240 epoch: usize,
241 split: &Split,
242 ) -> Result<Vec<NumericEntry>, String> {
243 if let Some(value) = self.loggers.get(name) {
244 value.sync()
245 }
246
247 let file_path = self.file_path(name, Some(epoch), split);
248
249 let mut errors = false;
250
251 let data = std::fs::read_to_string(file_path)
252 .unwrap_or_default()
253 .split('\n')
254 .filter_map(|value| {
255 if value.is_empty() {
256 None
257 } else {
258 match NumericEntry::deserialize(value) {
259 Ok(value) => Some(value),
260 Err(err) => {
261 log::error!("{err}");
262 errors = true;
263 None
264 }
265 }
266 }
267 })
268 .collect();
269
270 if errors {
271 Err("Parsing numeric entry errors".to_string())
272 } else {
273 Ok(data)
274 }
275 }
276
277 fn log_metric_definition(&mut self, definition: MetricDefinition) {
278 self.metric_definitions
279 .insert(definition.metric_id.clone(), definition);
280 }
281
282 fn log_epoch_summary(&mut self, _summary: EpochSummary) {
283 if !self.is_eval {
284 self.loggers.clear();
285 }
286 }
287}
288
289fn logger_key(name: &str, split: &Split) -> String {
290 format!("{name}_{split}")
291}
292
#[derive(Default)]
/// Metric logger that keeps every logged entry in memory.
pub struct InMemoryMetricLogger {
    /// For each `logger_key(name, split)`, the list of in-memory loggers holding the
    /// serialized entries. `log` appends a fresh logger to every list when the epoch changes,
    /// and `read_numeric` reads the logger at index `epoch - 1`.
    values: HashMap<String, Vec<InMemoryLogger>>,
    /// Epoch passed to the most recent `log` call, if any.
    last_epoch: Option<usize>,
    /// Metric definitions registered through `log_metric_definition`, keyed by metric id.
    metric_definitions: HashMap<MetricId, MetricDefinition>,
}
300
301impl InMemoryMetricLogger {
302 pub fn new() -> Self {
304 Self::default()
305 }
306}
307
308impl MetricLogger for InMemoryMetricLogger {
309 fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
310 if self.last_epoch != Some(epoch) {
311 self.values
312 .values_mut()
313 .for_each(|loggers| loggers.push(InMemoryLogger::default()));
314 self.last_epoch = Some(epoch);
315 }
316
317 let entries: Vec<_> = update
318 .entries
319 .iter()
320 .chain(
321 update
322 .entries_numeric
323 .iter()
324 .map(|numeric_update| &numeric_update.entry),
325 )
326 .cloned()
327 .collect();
328
329 for item in entries.iter() {
330 let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
331 let key = logger_key(name, split);
332
333 if !self.values.contains_key(&key) {
334 self.values
335 .insert(key.to_string(), vec![InMemoryLogger::default()]);
336 }
337
338 let values = self.values.get_mut(&key).unwrap();
339
340 values
341 .last_mut()
342 .unwrap()
343 .log(item.serialized_entry.serialized.clone());
344 }
345 }
346
347 fn read_numeric(
348 &mut self,
349 name: &str,
350 epoch: usize,
351 split: &Split,
352 ) -> Result<Vec<NumericEntry>, String> {
353 let key = logger_key(name, split);
354 let values = match self.values.get(&key) {
355 Some(values) => values,
356 None => return Ok(Vec::new()),
357 };
358
359 match values.get(epoch - 1) {
360 Some(logger) => Ok(logger
361 .values
362 .iter()
363 .filter_map(|value| NumericEntry::deserialize(value).ok())
364 .collect()),
365 None => Ok(Vec::new()),
366 }
367 }
368
369 fn log_metric_definition(&mut self, definition: MetricDefinition) {
370 self.metric_definitions
371 .insert(definition.metric_id.clone(), definition);
372 }
373
    /// Does nothing: the in-memory logger ignores epoch summaries.
    fn log_epoch_summary(&mut self, _summary: EpochSummary) {}
375}