use std::sync::Arc;
use burn_core::data::dataloader::Progress;
use burn_optim::LearningRate;
/// Contextual information passed to [`Metric::update`] together with each item.
pub struct MetricMetadata {
    /// Data-loader progress (items processed out of the total).
    pub progress: Progress,
    /// Current epoch.
    pub epoch: usize,
    /// Total number of epochs.
    pub epoch_total: usize,
    /// Current iteration.
    pub iteration: usize,
    /// Learning rate, when one is available.
    pub lr: Option<LearningRate>,
}

impl MetricMetadata {
    /// Builds placeholder metadata for unit tests: one item processed out of one,
    /// epoch 0 of 1, iteration 0 and no learning rate.
    #[cfg(test)]
    pub fn fake() -> Self {
        let progress = Progress {
            items_processed: 1,
            items_total: 1,
        };
        Self {
            progress,
            iteration: 0,
            epoch: 0,
            epoch_total: 1,
            lr: None,
        }
    }
}
/// Identifier of a metric.
///
/// The id string lives behind an `Arc`, so cloning a `MetricId` only bumps a
/// reference count instead of copying the string.
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
pub struct MetricId {
    id: Arc<String>,
}
/// Kind-specific attributes attached to a [`MetricDefinition`].
#[derive(Debug, Clone)]
pub enum MetricAttributes {
    /// Attributes of a numeric metric.
    Numeric(NumericAttributes),
    /// The metric carries no attributes.
    None,
}
/// Description of a metric: its id plus the name, description and attributes
/// reported by the [`Metric`] implementation.
#[derive(Debug, Clone)]
pub struct MetricDefinition {
    /// Identifier of the metric.
    pub metric_id: MetricId,
    /// Name returned by [`Metric::name`].
    pub name: String,
    /// Description returned by [`Metric::description`].
    pub description: Option<String>,
    /// Attributes returned by [`Metric::attributes`].
    pub attributes: MetricAttributes,
}

impl MetricDefinition {
    /// Captures the definition of `metric` under the identifier `metric_id`.
    pub fn new<Me: Metric>(metric_id: MetricId, metric: &Me) -> Self {
        // Query the metric in the same order as the fields are declared.
        let name = metric.name().to_string();
        let description = metric.description();
        let attributes = metric.attributes();
        Self {
            metric_id,
            name,
            description,
            attributes,
        }
    }
}
/// A metric fed with items of type [`Metric::Input`].
pub trait Metric: Send + Sync + Clone {
    /// Type of the item passed to [`Metric::update`].
    type Input;

    /// Name of the metric.
    fn name(&self) -> MetricName;

    /// Optional human-readable description; `None` unless overridden.
    fn description(&self) -> Option<String> {
        None
    }

    /// Attributes of the metric; [`MetricAttributes::None`] unless overridden.
    fn attributes(&self) -> MetricAttributes {
        MetricAttributes::None
    }

    /// Feeds `item` (with its `metadata`) into the metric and returns the
    /// [`SerializedEntry`] produced for this update.
    fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> SerializedEntry;

    /// Clears the metric; the exact reset semantics are up to the implementor.
    fn clear(&mut self);
}

/// Name of a metric, shared behind an `Arc` so clones are cheap.
pub type MetricName = Arc<String>;
/// Conversion from `self` into a value of type `T`.
pub trait Adaptor<T> {
    /// Produces the adapted value.
    fn adapt(&self) -> T;
}

/// Blanket implementation: every (sized) type adapts to the unit type `()`.
impl<Source> Adaptor<()> for Source {
    fn adapt(&self) {}
}
#[derive(Clone, Debug)]
pub struct NumericAttributes {
pub unit: Option<String>,
pub higher_is_better: bool,
}
impl From<NumericAttributes> for MetricAttributes {
fn from(attr: NumericAttributes) -> Self {
MetricAttributes::Numeric(attr)
}
}
impl Default for NumericAttributes {
fn default() -> Self {
Self {
unit: None,
higher_is_better: true,
}
}
}
/// A metric whose state can be read back as a [`NumericEntry`].
pub trait Numeric {
    /// Current value of the metric.
    fn value(&self) -> NumericEntry;

    /// Running value of the metric (presumably accumulated across updates —
    /// confirm against the implementors).
    fn running_value(&self) -> NumericEntry;
}
/// Result of a metric update, kept in two string forms.
///
/// Field order matters: the derived `new` takes `(formatted, serialized)`.
#[derive(new, Debug, Clone)]
pub struct SerializedEntry {
    /// Display-oriented rendering of the update.
    pub formatted: String,
    /// Serialized rendering of the update (presumably meant to be parsed back
    /// later — confirm with the consumers).
    pub serialized: String,
}
/// A [`SerializedEntry`] tagged with the [`MetricId`] of the metric it belongs to.
#[derive(Debug, Clone)]
pub struct MetricEntry {
    /// Identifier of the metric that produced the entry.
    pub metric_id: MetricId,
    /// The serialized update itself.
    pub serialized_entry: SerializedEntry,
}

impl MetricEntry {
    /// Pairs `serialized_entry` with the `metric_id` it belongs to.
    pub fn new(metric_id: MetricId, serialized_entry: SerializedEntry) -> Self {
        MetricEntry {
            serialized_entry,
            metric_id,
        }
    }
}
/// A numeric metric value: either a single value or an aggregated value with a count.
#[derive(Debug, Clone, PartialEq)]
pub enum NumericEntry {
    /// A single value.
    Value(f64),
    /// An aggregated value together with the number of items it covers.
    Aggregated {
        /// The aggregated value.
        aggregated_value: f64,
        /// Number of items behind `aggregated_value`.
        count: usize,
    },
}

impl NumericEntry {
    /// Returns the value carried by the entry.
    ///
    /// For [`NumericEntry::Aggregated`] this is `aggregated_value`; `count` is ignored.
    pub fn current(&self) -> f64 {
        match self {
            NumericEntry::Value(val) => *val,
            NumericEntry::Aggregated {
                aggregated_value, ..
            } => *aggregated_value,
        }
    }

    /// Serializes the entry into the format read by [`NumericEntry::deserialize`].
    ///
    /// `Value(v)` becomes `"v"`, and `Aggregated { aggregated_value, count }` becomes
    /// `"aggregated_value,count"`.
    pub fn serialize(&self) -> String {
        match self {
            Self::Value(v) => v.to_string(),
            Self::Aggregated {
                aggregated_value,
                count,
            } => format!("{aggregated_value},{count}"),
        }
    }

    /// Parses an entry written by [`NumericEntry::serialize`].
    ///
    /// One comma-separated component yields [`NumericEntry::Value`]. Two components
    /// (`f64` then `usize`) yield [`NumericEntry::Aggregated`].
    ///
    /// # Errors
    ///
    /// Returns one of the following error strings:
    /// - the parse error message of the first component that is not a valid number;
    /// - `"Invalid number of values for numeric entry"` for any other component count.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        let values: Vec<&str> = entry.split(',').collect();
        match values.as_slice() {
            [value] => value
                .parse::<f64>()
                .map(NumericEntry::Value)
                .map_err(|err| err.to_string()),
            [value, count] => {
                // Parse the value before the count so the reported error stays the same.
                let aggregated_value = value.parse::<f64>().map_err(|err| err.to_string())?;
                let count = count.parse::<usize>().map_err(|err| err.to_string())?;
                Ok(NumericEntry::Aggregated {
                    aggregated_value,
                    count,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }

    /// Returns `true` when `self` is strictly better than `other`.
    ///
    /// With `higher_is_better` the larger [`current`](NumericEntry::current) value wins,
    /// otherwise the smaller one does. Ties, and comparisons involving NaN, are never
    /// "better" in either direction.
    pub fn better_than(&self, other: &NumericEntry, higher_is_better: bool) -> bool {
        let (candidate, reference) = (self.current(), other.current());
        // Use a strict comparison in both directions. Negating `>` would turn ties and NaN
        // into "improvements" whenever lower is better.
        if higher_is_better {
            candidate > reference
        } else {
            candidate < reference
        }
    }
}
/// Formats `float` with `precision` digits after the decimal point.
///
/// Non-zero values whose magnitude is at most `0.1^(precision - 1)` (e.g. `0.1` when
/// `precision == 2`) are written in scientific notation, so small values do not collapse
/// to a run of zeros. All other values use fixed notation: larger magnitudes, exact
/// zero, and NaN/infinity.
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);
    // Compare the magnitude, not the signed value. Otherwise every negative number
    // would fall below the threshold and be printed in scientific notation.
    let use_scientific = float != 0.0 && float.abs() <= scientific_notation_threshold;
    if use_scientific {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}