use crate::{
logger::MetricLogger,
metric::{NumericEntry, store::Split},
};
use std::collections::HashMap;
use super::{Aggregate, Direction};
/// Lazily computes and memoizes aggregated numeric metric values so that
/// repeated lookups (e.g. while scanning epochs in `find_epoch`) do not
/// re-read the underlying metric logs.
#[derive(Default, Debug)]
pub(crate) struct NumericMetricsAggregate {
    // Cache of already-computed aggregations, keyed by
    // (metric name, epoch, split, aggregation kind); filled by `aggregate`.
    value_for_each_epoch: HashMap<Key, f64>,
}
/// Cache key identifying one aggregated value: a single metric, at a single
/// epoch, for a single split, under a single aggregation strategy.
// `new` here is the `derive_new`-style constructor derive, used by
// `NumericMetricsAggregate::aggregate` via `Key::new(..)`.
#[derive(new, Hash, PartialEq, Eq, Debug)]
struct Key {
    name: String,
    epoch: usize,
    split: Split,
    aggregate: Aggregate,
}
impl NumericMetricsAggregate {
    /// Compute (and memoize) the aggregated value of metric `name` for one
    /// epoch/split, reading the raw points from the first logger that can
    /// provide them.
    ///
    /// Returns `None` when the epoch yields no data points at all.
    ///
    /// # Panics
    ///
    /// Panics (with the joined logger error messages) when every logger fails
    /// to read the requested values.
    pub(crate) fn aggregate(
        &mut self,
        name: &str,
        epoch: usize,
        split: &Split,
        aggregate: Aggregate,
        loggers: &mut [Box<dyn MetricLogger>],
    ) -> Option<f64> {
        let key = Key::new(name.to_string(), epoch, split.clone(), aggregate);

        // Serve from the per-epoch cache when the value was computed before.
        if let Some(cached) = self.value_for_each_epoch.get(&key) {
            return Some(*cached);
        }

        // Ask each logger in turn; the first successful read wins. Error
        // messages are collected so a total failure can report all of them.
        let mut read_errors = Vec::new();
        let mut entries = None;
        for logger in loggers.iter_mut() {
            match logger.read_numeric(name, epoch, split) {
                Ok(points) => {
                    entries = Some(points);
                    break;
                }
                Err(err) => read_errors.push(err),
            }
        }
        let entries = entries
            .ok_or_else(|| read_errors.join(" "))
            .expect("Can read values");

        if entries.is_empty() {
            return None;
        }

        // Weighted sum: an `Aggregated` entry stands for `count` raw points,
        // so it contributes `aggregated_value * count` to the total.
        let mut sum = 0.0;
        let mut total_count = 0;
        for entry in entries {
            match entry {
                NumericEntry::Value(v) => {
                    sum += v;
                    total_count += 1;
                }
                NumericEntry::Aggregated {
                    aggregated_value,
                    count,
                } => {
                    sum += aggregated_value * count as f64;
                    total_count += count;
                }
            }
        }

        let value = match aggregate {
            Aggregate::Mean => sum / total_count as f64,
        };

        self.value_for_each_epoch.insert(key, value);
        Some(value)
    }

    /// Find the epoch whose aggregated metric value is best according to
    /// `direction`. Epochs are probed starting at 1 until one yields no data.
    ///
    /// Returns `None` when not even epoch 1 has any data.
    pub(crate) fn find_epoch(
        &mut self,
        name: &str,
        split: &Split,
        aggregate: Aggregate,
        direction: Direction,
        loggers: &mut [Box<dyn MetricLogger>],
    ) -> Option<usize> {
        // Gather one aggregated value per epoch, starting at epoch 1.
        let mut values = Vec::new();
        let mut best_epoch = 1;
        while let Some(value) = self.aggregate(name, best_epoch, split, aggregate, loggers) {
            values.push(value);
            best_epoch += 1;
        }

        if values.is_empty() {
            return None;
        }

        // Track the best value seen so far, starting from the worst possible
        // sentinel for the requested direction.
        // NOTE(review): with these strict comparisons, all-NaN data would
        // never update `best_epoch`, leaving it one past the last epoch —
        // identical to the original behavior; confirm whether that matters.
        let mut best_value = match &direction {
            Direction::Lowest => f64::MAX,
            Direction::Highest => f64::MIN,
        };
        for (index, value) in values.into_iter().enumerate() {
            let improved = match &direction {
                Direction::Lowest => value < best_value,
                Direction::Highest => value > best_value,
            };
            if improved {
                best_value = value;
                best_epoch = index + 1; // epochs are 1-based
            }
        }

        Some(best_epoch)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        logger::{FileMetricLogger, InMemoryMetricLogger},
        metric::{MetricDefinition, MetricEntry, MetricId, SerializedEntry, store::MetricsUpdate},
    };

    use super::*;

    const NAME: &str = "test-logger";

    /// Small harness that writes metric entries through a [`FileMetricLogger`]
    /// while keeping track of the current epoch.
    struct TestLogger {
        logger: FileMetricLogger,
        epoch: usize,
    }

    impl TestLogger {
        fn new() -> Self {
            Self {
                // NOTE(review): logs go to the shared "/tmp" directory;
                // stale files from earlier runs could interfere — consider a
                // unique temp subdirectory per test run.
                logger: FileMetricLogger::new("/tmp"),
                epoch: 1,
            }
        }

        /// Record a single numeric value under the current epoch.
        fn log(&mut self, num: f64) {
            let metric_id = MetricId::new(Arc::new(NAME.into()));
            let serialized = SerializedEntry::new(num.to_string(), num.to_string());
            let update = MetricsUpdate::new(vec![MetricEntry::new(metric_id, serialized)], vec![]);
            self.logger.log(update, self.epoch, &Split::Train);
        }

        /// Register the metric definition so the logger knows about `NAME`.
        fn log_definition(&mut self) {
            self.logger.log_metric_definition(MetricDefinition {
                metric_id: MetricId::new(Arc::new(NAME.into())),
                name: NAME.into(),
                attributes: crate::metric::MetricAttributes::None,
                description: None,
            });
        }

        fn new_epoch(&mut self) {
            self.epoch += 1;
        }
    }

    #[test]
    fn should_find_epoch() {
        let mut test_logger = TestLogger::new();
        let mut aggregate = NumericMetricsAggregate::default();

        test_logger.log_definition();

        // Epoch 1: mean 750.
        test_logger.log(500.);
        test_logger.log(1000.);
        test_logger.new_epoch();
        // Epoch 2: mean 600 — the lowest.
        test_logger.log(200.);
        test_logger.log(1000.);
        test_logger.new_epoch();
        // Epoch 3: mean 10000.
        test_logger.log(10000.);

        let best_epoch = aggregate
            .find_epoch(
                NAME,
                &Split::Train,
                Aggregate::Mean,
                Direction::Lowest,
                &mut [Box::new(test_logger.logger)],
            )
            .unwrap();

        assert_eq!(best_epoch, 2);
    }

    #[test]
    fn should_aggregate_numeric_entry() {
        let mut logger = InMemoryMetricLogger::default();
        let mut aggregate = NumericMetricsAggregate::default();

        let metric_name = Arc::new("Loss".to_string());
        let metric_id = MetricId::new(metric_name.clone());

        logger.log_metric_definition(MetricDefinition {
            metric_id: metric_id.clone(),
            name: metric_name.to_string(),
            attributes: crate::metric::MetricAttributes::None,
            description: None,
        });

        // One raw value (0.5, weight 1)...
        let loss_1 = 0.5;
        let raw_entry = MetricEntry::new(
            metric_id.clone(),
            SerializedEntry::new(loss_1.to_string(), NumericEntry::Value(loss_1).serialize()),
        );
        logger.log(MetricsUpdate::new(vec![raw_entry], vec![]), 1, &Split::Train);

        // ...and one pre-aggregated mean (1.25 over 2 points, weight 2).
        let loss_2 = 1.25;
        let aggregated_entry = MetricEntry::new(
            metric_id.clone(),
            SerializedEntry::new(
                loss_2.to_string(),
                NumericEntry::Aggregated {
                    aggregated_value: loss_2,
                    count: 2,
                }
                .serialize(),
            ),
        );
        logger.log(
            MetricsUpdate::new(vec![aggregated_entry], vec![]),
            1,
            &Split::Train,
        );

        // Weighted mean: (0.5 * 1 + 1.25 * 2) / 3 = 1.0.
        let value = aggregate
            .aggregate(
                &metric_name,
                1,
                &Split::Train,
                Aggregate::Mean,
                &mut [Box::new(logger)],
            )
            .unwrap();

        assert_eq!(value, 1.0);
    }
}