use super::{AsyncLogger, FileLogger, InMemoryLogger, Logger};
use crate::metric::{
MetricDefinition, MetricEntry, MetricId, NumericEntry,
store::{EpochSummary, MetricsUpdate, Split},
};
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
};
/// Prefix of the per-epoch directory names; the epoch number is appended to it
/// (for example `epoch-3`).
const EPOCH_PREFIX: &str = "epoch-";
/// Sink for metric entries that can also read numeric entries back.
///
/// Implementations must be [`Send`].
pub trait MetricLogger: Send {
/// Logs every entry contained in `update` for the given `epoch` and `split`.
fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split);
/// Reads back the numeric entries recorded for the metric called `name` at
/// `epoch` for `split`.
///
/// # Errors
///
/// Returns a human-readable message when the entries cannot be produced; the
/// exact failure conditions are defined by each implementation.
fn read_numeric(
&mut self,
name: &str,
epoch: usize,
split: &Split,
) -> Result<Vec<NumericEntry>, String>;
/// Registers a metric definition.
///
/// The implementations in this file use the registered definitions to resolve
/// the name of each logged entry from its metric id.
fn log_metric_definition(&mut self, definition: MetricDefinition);
/// Notifies the logger of an epoch summary.
fn log_epoch_summary(&mut self, summary: EpochSummary);
}
/// [`MetricLogger`] that writes each metric to its own `.log` file below a
/// root directory.
pub struct FileMetricLogger {
// Open file loggers, keyed by `logger_key(metric name, split)`.
loggers: HashMap<String, AsyncLogger<String>>,
// Root directory under which every log file is created.
directory: PathBuf,
// Registered metric definitions, used to resolve an entry's metric name.
metric_definitions: HashMap<MetricId, MetricDefinition>,
// `true` when built with `new_eval`, `false` when built with `new`.
is_eval: bool,
// Epoch of the most recent `log` call that reset the loggers (training only).
last_epoch: Option<usize>,
}
impl FileMetricLogger {
    /// Creates a logger for training runs whose files live under `directory`.
    pub fn new(directory: impl AsRef<Path>) -> Self {
        Self::with_eval_flag(directory.as_ref(), false)
    }

    /// Creates a logger for evaluation runs whose files live under `directory`.
    pub fn new_eval(directory: impl AsRef<Path>) -> Self {
        Self::with_eval_flag(directory.as_ref(), true)
    }

    /// Builds an empty logger rooted at `directory` with the given mode.
    fn with_eval_flag(directory: &Path, is_eval: bool) -> Self {
        Self {
            loggers: HashMap::new(),
            directory: directory.to_path_buf(),
            metric_definitions: HashMap::default(),
            is_eval,
            last_epoch: None,
        }
    }

    /// Returns `true` when a directory exists on disk for `split`.
    pub(crate) fn split_exists(&self, split: &Split) -> bool {
        self.split_dir(split).is_some()
    }

    /// Returns the directory of `split` if it exists on disk and is a directory.
    pub(crate) fn split_dir(&self, split: &Split) -> Option<PathBuf> {
        let split_path = match split {
            // Tagged splits are written under the normalized tag (see
            // `train_directory` / `eval_directory`), so the lookup has to apply
            // the same normalization or it would miss tags such as "My Tag".
            Split::Test(Some(tag)) => self
                .directory
                .join(split.to_string())
                .join(format_tag(tag)),
            other => self.directory.join(other.to_string()),
        };
        // `is_dir` is already `false` for paths that do not exist.
        split_path.is_dir().then_some(split_path)
    }

    /// Returns `true` when `dirname` starts with the epoch directory prefix.
    pub(crate) fn is_epoch_dir<P: AsRef<str>>(dirname: P) -> bool {
        dirname.as_ref().starts_with(EPOCH_PREFIX)
    }

    /// Returns the highest epoch number found in `<directory>/<split>/epoch-<n>`
    /// directories.
    ///
    /// Returns `0` for evaluation loggers, when the root directory cannot be
    /// read, or when no epoch directory is found. Only two directory levels are
    /// scanned, so epoch directories nested below a tag directory are not seen.
    /// Entries that cannot be read or whose name is not a valid epoch directory
    /// name are skipped instead of aborting the scan.
    pub(crate) fn epochs(&self) -> usize {
        if self.is_eval {
            log::warn!("Number of epochs not available when testing.");
            return 0;
        }

        let split_entries = match fs::read_dir(&self.directory) {
            Ok(entries) => entries,
            Err(err) => {
                log::warn!(
                    "Unable to read metrics directory {}: {err}",
                    self.directory.display()
                );
                return 0;
            }
        };

        let mut max_epoch = 0;
        for split_entry in split_entries.flatten() {
            let split_path = split_entry.path();
            if !split_path.is_dir() {
                continue;
            }
            let Ok(epoch_entries) = fs::read_dir(&split_path) else {
                continue;
            };
            for epoch_entry in epoch_entries.flatten() {
                if !epoch_entry.path().is_dir() {
                    continue;
                }
                // Strip the prefix only at the start of the name; names that are
                // not UTF-8 or do not end in a number are ignored.
                let epoch = epoch_entry
                    .file_name()
                    .to_str()
                    .and_then(|dir_name| dir_name.strip_prefix(EPOCH_PREFIX))
                    .and_then(|number| number.parse::<usize>().ok());
                if let Some(epoch) = epoch {
                    max_epoch = max_epoch.max(epoch);
                }
            }
        }
        max_epoch
    }

    /// Directory holding the log files of `split` for a given training `epoch`.
    fn train_directory(&self, epoch: usize, split: &Split) -> PathBuf {
        let name = format!("{EPOCH_PREFIX}{epoch}");
        match split {
            Split::Train | Split::Valid | Split::Test(None) => {
                self.directory.join(split.to_string()).join(name)
            }
            Split::Test(Some(tag)) => {
                let tag = format_tag(tag);
                self.directory.join(split.to_string()).join(tag).join(name)
            }
        }
    }

    /// Directory holding the log files of `split` when no epoch is involved.
    fn eval_directory(&self, split: &Split) -> PathBuf {
        match split {
            Split::Train | Split::Valid | Split::Test(None) => self.directory.clone(),
            Split::Test(Some(tag)) => self
                .directory
                .join(split.to_string())
                .join(format_tag(tag)),
        }
    }

    /// Picks the per-epoch directory when `epoch` is set, the epoch-less one otherwise.
    fn log_directory(&self, epoch: Option<usize>, split: &Split) -> PathBuf {
        match epoch {
            Some(epoch) => self.train_directory(epoch, split),
            None => self.eval_directory(split),
        }
    }

    /// Path of the `.log` file for metric `name`; spaces in the name become `_`.
    fn file_path(&self, name: &str, epoch: Option<usize>, split: &Split) -> PathBuf {
        let file_name = format!("{}.log", name.replace(' ', "_"));
        self.log_directory(epoch, split).join(file_name)
    }

    /// Creates the directory that will receive the log files, including parents.
    ///
    /// Failure is not fatal here, but it is reported so that a missing log file
    /// can be traced back to its cause.
    fn create_directory(&self, epoch: Option<usize>, split: &Split) {
        let directory = self.log_directory(epoch, split);
        if let Err(err) = fs::create_dir_all(&directory) {
            log::warn!(
                "Unable to create metrics directory {}: {err}",
                directory.display()
            );
        }
    }
}
impl FileMetricLogger {
    /// Appends the serialized form of `item` to the log file of its metric.
    ///
    /// The file logger for a `(metric name, split)` pair is created on first use,
    /// together with its directory, and then reused.
    ///
    /// # Panics
    ///
    /// Panics if no definition was registered for the entry's metric id.
    fn log_item(&mut self, item: &MetricEntry, epoch: Option<usize>, split: &Split) {
        let name = &self.metric_definitions.get(&item.metric_id).unwrap().name;
        let key = logger_key(name, split);

        // Lazily open the file logger the first time this key is seen.
        if !self.loggers.contains_key(&key) {
            self.create_directory(epoch, split);
            let path = self.file_path(name, epoch, split);
            let file_logger = FileLogger::new(path);
            self.loggers
                .insert(key.clone(), AsyncLogger::new(file_logger));
        }

        let serialized = item.serialized_entry.serialized.clone();
        self.loggers
            .get_mut(&key)
            .expect("Can get the previously saved logger.")
            .log(serialized);
    }
}
/// Normalizes a tag for use as a directory name: surrounding whitespace is
/// trimmed, every space becomes `-`, and the result is lowercased.
fn format_tag(tag: &str) -> String {
    let pieces: Vec<&str> = tag.trim().split(' ').collect();
    pieces.join("-").to_lowercase()
}
impl MetricLogger for FileMetricLogger {
fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
if !self.is_eval && self.last_epoch != Some(epoch) {
self.loggers.clear();
self.last_epoch = Some(epoch);
}
let entries: Vec<_> = update
.entries
.iter()
.chain(
update
.entries_numeric
.iter()
.map(|numeric_update| &numeric_update.entry),
)
.cloned()
.collect();
for item in entries.iter() {
self.log_item(item, Some(epoch), split);
}
}
fn read_numeric(
&mut self,
name: &str,
epoch: usize,
split: &Split,
) -> Result<Vec<NumericEntry>, String> {
if let Some(value) = self.loggers.get(name) {
value.sync()
}
let file_path = self.file_path(name, Some(epoch), split);
let mut errors = false;
let data = std::fs::read_to_string(file_path)
.unwrap_or_default()
.split('\n')
.filter_map(|value| {
if value.is_empty() {
None
} else {
match NumericEntry::deserialize(value) {
Ok(value) => Some(value),
Err(err) => {
log::error!("{err}");
errors = true;
None
}
}
}
})
.collect();
if errors {
Err("Parsing numeric entry errors".to_string())
} else {
Ok(data)
}
}
fn log_metric_definition(&mut self, definition: MetricDefinition) {
self.metric_definitions
.insert(definition.metric_id.clone(), definition);
}
fn log_epoch_summary(&mut self, _summary: EpochSummary) {
if !self.is_eval {
self.loggers.clear();
}
}
}
/// Builds the map key identifying a metric within a split: `<name>_<split>`.
fn logger_key(name: &str, split: &Split) -> String {
    let split_label = split.to_string();
    [name, split_label.as_str()].join("_")
}
/// [`MetricLogger`] that keeps every logged entry in memory.
#[derive(Default)]
pub struct InMemoryMetricLogger {
// Per `logger_key(metric name, split)`: one in-memory logger per epoch slot,
// in the order the epochs were seen.
values: HashMap<String, Vec<InMemoryLogger>>,
// Epoch of the most recent `log` call, used to detect an epoch change.
last_epoch: Option<usize>,
// Registered metric definitions, used to resolve an entry's metric name.
metric_definitions: HashMap<MetricId, MetricDefinition>,
}
impl InMemoryMetricLogger {
    /// Creates an empty in-memory logger; equivalent to the `Default` value.
    pub fn new() -> Self {
        Default::default()
    }
}
impl MetricLogger for InMemoryMetricLogger {
    /// Stores every entry of `update` in the current epoch slot of its metric.
    ///
    /// A new slot is opened for every known metric each time `epoch` differs from
    /// the previous call. Slots are positional: `read_numeric` maps epoch `n` to
    /// slot `n - 1`, which presumably relies on epochs starting at 1 and advancing
    /// one at a time — verify against callers.
    ///
    /// # Panics
    ///
    /// Panics if no definition was registered for an entry's metric id.
    fn log(&mut self, update: MetricsUpdate, epoch: usize, split: &Split) {
        if self.last_epoch != Some(epoch) {
            for history in self.values.values_mut() {
                history.push(InMemoryLogger::default());
            }
            self.last_epoch = Some(epoch);
        }

        // A metric seen for the first time in a later epoch must start with as
        // many slots as the metrics already tracked; otherwise its data would sit
        // at slot 0 and `read_numeric` would report it for the wrong epoch.
        let slot_count = self
            .values
            .values()
            .map(Vec::len)
            .max()
            .unwrap_or(1)
            .max(1);

        // `update` is owned by this call, so its entries can be borrowed while
        // `self` is mutated; no intermediate copy is needed.
        let entries = update.entries.iter().chain(
            update
                .entries_numeric
                .iter()
                .map(|numeric_update| &numeric_update.entry),
        );
        for item in entries {
            let name = &self
                .metric_definitions
                .get(&item.metric_id)
                .expect("The metric definition is registered before its entries are logged.")
                .name;
            let key = logger_key(name, split);
            let history = self.values.entry(key).or_insert_with(|| {
                (0..slot_count)
                    .map(|_| InMemoryLogger::default())
                    .collect()
            });
            history
                .last_mut()
                .expect("Every metric history holds at least one slot.")
                .log(item.serialized_entry.serialized.clone());
        }
    }

    /// Returns the numeric entries stored for metric `name` at `epoch` for `split`.
    ///
    /// An unknown metric, an epoch without a slot, or epoch `0` yields an empty
    /// list. Values that fail to deserialize are skipped.
    fn read_numeric(
        &mut self,
        name: &str,
        epoch: usize,
        split: &Split,
    ) -> Result<Vec<NumericEntry>, String> {
        // Slots are 0-based while epochs are 1-based; epoch 0 has no slot and
        // must not underflow.
        let Some(slot) = epoch.checked_sub(1) else {
            return Ok(Vec::new());
        };

        let key = logger_key(name, split);
        let Some(logger) = self
            .values
            .get(&key)
            .and_then(|history| history.get(slot))
        else {
            return Ok(Vec::new());
        };

        Ok(logger
            .values
            .iter()
            .filter_map(|value| NumericEntry::deserialize(value).ok())
            .collect())
    }

    /// Registers `definition` under its metric id, replacing any previous one.
    fn log_metric_definition(&mut self, definition: MetricDefinition) {
        self.metric_definitions
            .insert(definition.metric_id.clone(), definition);
    }

    /// Epoch summaries are not used by the in-memory logger.
    fn log_epoch_summary(&mut self, _summary: EpochSummary) {}
}