1use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
2use crate::metric::{
3 MetricDefinition, MetricEntry, MetricId, NumericEntry,
4 store::{EpochSummary, MetricsUpdate, Split},
5};
6use std::{
7 collections::HashMap,
8 fs,
9 path::{Path, PathBuf},
10 sync::Arc,
11};
12
/// Prefix of the per-epoch directory names (`epoch-<N>`) that the file logger writes to and
/// scans when counting epochs.
const EPOCH_PREFIX: &str = "epoch-";
14
/// A sink for metric updates that can also read back the numeric entries it recorded.
pub trait MetricLogger: Send {
    /// Logs every entry contained in `update`.
    ///
    /// # Arguments
    ///
    /// * `update` - The metric entries to record.
    /// * `epoch` - The epoch the entries belong to.
    /// * `split` - The dataset split the entries belong to.
    /// * `tag` - Optional tag attached to the entries; its use is implementation-defined.
    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: Split, tag: Option<Arc<String>>);

    /// Reads back the numeric entries recorded for the metric called `name`.
    ///
    /// # Arguments
    ///
    /// * `name` - The metric name.
    /// * `epoch` - The epoch to read.
    /// * `split` - The dataset split to read.
    ///
    /// # Errors
    ///
    /// Returns a message describing the failure; which failures are reported is up to the
    /// implementation.
    fn read_numeric(
        &mut self,
        name: &str,
        epoch: usize,
        split: Split,
    ) -> Result<Vec<NumericEntry>, String>;

    /// Registers a metric definition with the logger.
    fn log_metric_definition(&mut self, definition: MetricDefinition);

    /// Hands the logger the summary of an epoch.
    fn log_epoch_summary(&mut self, summary: EpochSummary);
}
42
/// Metric logger that writes each metric to its own `.log` file under a root directory.
pub struct FileMetricLogger {
    /// Currently open loggers, keyed by `logger_key(metric name, split)`.
    loggers: HashMap<String, AsyncLogger<String>>,
    /// Root directory under which all log files are placed.
    directory: PathBuf,
    /// Metric definitions registered so far, keyed by metric id.
    metric_definitions: HashMap<MetricId, MetricDefinition>,
    /// `true` when built with `new_eval`; in that mode the open loggers are never reset and
    /// `epochs` reports `0`.
    is_eval: bool,
    /// Epoch of the most recent `log` call that reset the open loggers, if any.
    last_epoch: Option<usize>,
}
51
52impl FileMetricLogger {
53 pub fn new(directory: impl AsRef<Path>) -> Self {
63 Self {
64 loggers: HashMap::new(),
65 directory: directory.as_ref().to_path_buf(),
66 metric_definitions: HashMap::default(),
67 is_eval: false,
68 last_epoch: None,
69 }
70 }
71
72 pub fn new_eval(directory: impl AsRef<Path>) -> Self {
82 Self {
83 loggers: HashMap::new(),
84 directory: directory.as_ref().to_path_buf(),
85 metric_definitions: HashMap::default(),
86 is_eval: true,
87 last_epoch: None,
88 }
89 }
90
91 pub(crate) fn split_exists(&self, split: Split) -> bool {
92 let split_path = self.directory.join(split.to_string());
93 split_path.exists() && split_path.is_dir()
94 }
95
96 pub(crate) fn epochs(&self) -> usize {
98 if self.is_eval {
99 log::warn!("Number of epochs not available when testing.");
100 return 0;
101 }
102
103 let mut max_epoch = 0;
104
105 for path in fs::read_dir(&self.directory).unwrap() {
107 let path = path.unwrap();
108
109 if fs::metadata(path.path()).unwrap().is_dir() {
110 for split_path in fs::read_dir(path.path()).unwrap() {
111 let split_path = split_path.unwrap();
112
113 if fs::metadata(split_path.path()).unwrap().is_dir() {
114 let dir_name = split_path.file_name().into_string().unwrap();
115
116 if !dir_name.starts_with(EPOCH_PREFIX) {
117 continue;
118 }
119
120 let epoch = dir_name.replace(EPOCH_PREFIX, "").parse::<usize>().ok();
121
122 if let Some(epoch) = epoch
123 && epoch > max_epoch
124 {
125 max_epoch = epoch;
126 }
127 }
128 }
129 }
130 }
131
132 max_epoch
133 }
134
135 fn train_directory(&self, tag: Option<&String>, epoch: usize, split: Split) -> PathBuf {
136 let name = format!("{EPOCH_PREFIX}{epoch}");
137
138 match tag {
139 Some(tag) => self.directory.join(split.to_string()).join(tag).join(name),
140 None => self.directory.join(split.to_string()).join(name),
141 }
142 }
143
144 fn eval_directory(&self, tag: Option<&String>, split: Split) -> PathBuf {
145 match tag {
146 Some(tag) => self.directory.join(split.to_string()).join(tag),
147 None => self.directory.clone(),
148 }
149 }
150
151 fn file_path(
152 &self,
153 tag: Option<&String>,
154 name: &str,
155 epoch: Option<usize>,
156 split: Split,
157 ) -> PathBuf {
158 let directory = match epoch {
159 Some(epoch) => self.train_directory(tag, epoch, split),
160 None => self.eval_directory(tag, split),
161 };
162 let name = name.replace(' ', "_");
163 let name = format!("{name}.log");
164 directory.join(name)
165 }
166
167 fn create_directory(&self, tag: Option<&String>, epoch: Option<usize>, split: Split) {
168 let directory = match epoch {
169 Some(epoch) => self.train_directory(tag, epoch, split),
170 None => self.eval_directory(tag, split),
171 };
172 std::fs::create_dir_all(directory).ok();
173 }
174}
175
176impl FileMetricLogger {
177 fn log_item(
178 &mut self,
179 tag: Option<&String>,
180 item: &MetricEntry,
181 epoch: Option<usize>,
182 split: Split,
183 ) {
184 let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
185 let key = logger_key(name, split);
186 let value = &item.serialized_entry.serialized;
187
188 let logger = match self.loggers.get_mut(&key) {
189 Some(val) => val,
190 None => {
191 self.create_directory(tag, epoch, split);
192
193 let file_path = self.file_path(tag, name, epoch, split);
194 let logger = FileLogger::new(file_path);
195 let logger = AsyncLogger::new(logger);
196
197 self.loggers.insert(key.clone(), logger);
198 self.loggers
199 .get_mut(&key)
200 .expect("Can get the previously saved logger.")
201 }
202 };
203
204 logger.log(value.clone());
205 }
206}
207
208impl MetricLogger for FileMetricLogger {
209 fn log(&mut self, update: MetricsUpdate, epoch: usize, split: Split, tag: Option<Arc<String>>) {
210 if !self.is_eval && self.last_epoch != Some(epoch) {
211 self.loggers.clear();
212 self.last_epoch = Some(epoch);
213 }
214
215 let entries: Vec<_> = update
216 .entries
217 .iter()
218 .chain(
219 update
220 .entries_numeric
221 .iter()
222 .map(|numeric_update| &numeric_update.entry),
223 )
224 .cloned()
225 .collect();
226
227 for item in entries.iter() {
228 match tag {
229 Some(ref tag) => {
230 let tag = tag.trim().replace(' ', "-").to_lowercase();
231 self.log_item(Some(&tag), item, Some(epoch), split);
232 }
233 None => self.log_item(None, item, Some(epoch), split),
234 }
235 }
236 }
237
238 fn read_numeric(
239 &mut self,
240 name: &str,
241 epoch: usize,
242 split: Split,
243 ) -> Result<Vec<NumericEntry>, String> {
244 if let Some(value) = self.loggers.get(name) {
245 value.sync()
246 }
247
248 let file_path = self.file_path(None, name, Some(epoch), split);
249
250 let mut errors = false;
251
252 let data = std::fs::read_to_string(file_path)
253 .unwrap_or_default()
254 .split('\n')
255 .filter_map(|value| {
256 if value.is_empty() {
257 None
258 } else {
259 match NumericEntry::deserialize(value) {
260 Ok(value) => Some(value),
261 Err(err) => {
262 log::error!("{err}");
263 errors = true;
264 None
265 }
266 }
267 }
268 })
269 .collect();
270
271 if errors {
272 Err("Parsing numeric entry errors".to_string())
273 } else {
274 Ok(data)
275 }
276 }
277
278 fn log_metric_definition(&mut self, definition: MetricDefinition) {
279 self.metric_definitions
280 .insert(definition.metric_id.clone(), definition);
281 }
282
283 fn log_epoch_summary(&mut self, _summary: EpochSummary) {
284 if !self.is_eval {
285 self.loggers.clear();
286 }
287 }
288}
289
290fn logger_key(name: &str, split: Split) -> String {
291 format!("{name}_{split}")
292}
293
/// Metric logger that keeps every logged entry in memory.
#[derive(Default)]
pub struct InMemoryMetricLogger {
    /// In-memory loggers per `logger_key(metric name, split)`; a fresh logger is appended to
    /// every existing key each time `log` sees a new epoch.
    values: HashMap<String, Vec<InMemoryLogger>>,
    /// Epoch of the most recent `log` call, if any.
    last_epoch: Option<usize>,
    /// Metric definitions registered so far, keyed by metric id.
    metric_definitions: HashMap<MetricId, MetricDefinition>,
}
301
302impl InMemoryMetricLogger {
303 pub fn new() -> Self {
305 Self::default()
306 }
307}
308
309impl MetricLogger for InMemoryMetricLogger {
310 fn log(
311 &mut self,
312 update: MetricsUpdate,
313 epoch: usize,
314 split: Split,
315 _tag: Option<Arc<String>>,
316 ) {
317 if self.last_epoch != Some(epoch) {
318 self.values
319 .values_mut()
320 .for_each(|loggers| loggers.push(InMemoryLogger::default()));
321 self.last_epoch = Some(epoch);
322 }
323
324 let entries: Vec<_> = update
325 .entries
326 .iter()
327 .chain(
328 update
329 .entries_numeric
330 .iter()
331 .map(|numeric_update| &numeric_update.entry),
332 )
333 .cloned()
334 .collect();
335
336 for item in entries.iter() {
337 let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
338 let key = logger_key(name, split);
339
340 if !self.values.contains_key(&key) {
341 self.values
342 .insert(key.to_string(), vec![InMemoryLogger::default()]);
343 }
344
345 let values = self.values.get_mut(&key).unwrap();
346
347 values
348 .last_mut()
349 .unwrap()
350 .log(item.serialized_entry.serialized.clone());
351 }
352 }
353
354 fn read_numeric(
355 &mut self,
356 name: &str,
357 epoch: usize,
358 split: Split,
359 ) -> Result<Vec<NumericEntry>, String> {
360 let key = logger_key(name, split);
361 let values = match self.values.get(&key) {
362 Some(values) => values,
363 None => return Ok(Vec::new()),
364 };
365
366 match values.get(epoch - 1) {
367 Some(logger) => Ok(logger
368 .values
369 .iter()
370 .filter_map(|value| NumericEntry::deserialize(value).ok())
371 .collect()),
372 None => Ok(Vec::new()),
373 }
374 }
375
376 fn log_metric_definition(&mut self, definition: MetricDefinition) {
377 self.metric_definitions
378 .insert(definition.metric_id.clone(), definition);
379 }
380
381 fn log_epoch_summary(&mut self, _summary: EpochSummary) {}
382}