use std::sync::Arc;
use burn_core::data::dataloader::Progress;
use burn_optim::LearningRate;
/// Metadata handed to [`Metric::update`] alongside each item.
pub struct MetricMetadata {
/// Dataloader progress (`items_processed` out of `items_total`).
pub progress: Progress,
/// Epoch counter.
/// NOTE(review): whether this is 0- or 1-based is not visible here — confirm with callers.
pub epoch: usize,
/// Total number of epochs.
pub epoch_total: usize,
/// Iteration counter.
/// NOTE(review): presumably the current training iteration — confirm with callers.
pub iteration: usize,
/// Learning rate, when one is available (`None` otherwise).
pub lr: Option<LearningRate>,
}
impl MetricMetadata {
    /// Builds a placeholder [`MetricMetadata`] for unit tests.
    ///
    /// Only compiled under `cfg(test)`. The returned value reports one item
    /// processed out of one, epoch `0` of `1`, iteration `0`, and no learning
    /// rate.
    #[cfg(test)]
    pub fn fake() -> Self {
        // Name the nested value first so the final literal stays flat.
        let progress = Progress {
            items_processed: 1,
            items_total: 1,
        };

        Self {
            progress,
            epoch: 0,
            epoch_total: 1,
            iteration: 0,
            lr: None,
        }
    }
}
/// A metric that is fed one item at a time and produces a [`MetricEntry`] on every update.
///
/// Implementors must be `Send + Sync + Clone`.
pub trait Metric: Send + Sync + Clone {
/// The type of item consumed by [`Metric::update`].
type Input;
/// Returns the name of the metric.
fn name(&self) -> MetricName;
/// Updates the metric with `item` and its accompanying `metadata`, returning the resulting entry.
fn update(&mut self, item: &Self::Input, metadata: &MetricMetadata) -> MetricEntry;
/// Clears the metric.
/// NOTE(review): presumably resets any state accumulated by `update` — confirm with implementors.
fn clear(&mut self);
}
/// Name of a metric, held as a shared, reference-counted `String`.
pub type MetricName = Arc<String>;
/// Conversion from a borrowed value into a `T`.
///
/// NOTE(review): presumably used to turn a step output into a metric's `Input` — confirm against callers.
pub trait Adaptor<T> {
/// Produces a `T` from a borrow of `self`.
fn adapt(&self) -> T;
}
/// Blanket implementation: every type can be adapted to the unit type `()`.
impl<T> Adaptor<()> for T {
// Nothing to compute; the empty body evaluates to `()`.
fn adapt(&self) {}
}
/// Implemented by types that can report a numeric value as a [`NumericEntry`].
pub trait Numeric {
/// Returns the numeric value.
fn value(&self) -> NumericEntry;
}
/// The result of a metric update: the metric's name, two string renderings
/// of its value, and a list of tags.
#[derive(Debug, Clone)]
pub struct MetricEntry {
    /// Name of the metric that produced this entry.
    pub name: Arc<String>,
    /// Formatted rendering of the value.
    /// NOTE(review): presumably the human-readable form shown to the user — confirm.
    pub formatted: String,
    /// Serialized rendering of the value.
    /// NOTE(review): presumably the form written to logs/storage — confirm.
    pub serialize: String,
    /// Tags attached to the entry; empty when built through [`MetricEntry::new`].
    pub tags: Vec<Arc<String>>,
}

impl MetricEntry {
    /// Creates an entry from its name and its two string renderings.
    ///
    /// The `tags` list starts out empty.
    pub fn new(name: Arc<String>, formatted: String, serialize: String) -> Self {
        let tags = Vec::new();

        Self {
            name,
            formatted,
            serialize,
            tags,
        }
    }
}
/// A numeric metric value: either a single number or an aggregate.
#[derive(Debug, Clone)]
pub enum NumericEntry {
    /// A single numeric value.
    Value(f64),
    /// An aggregated value.
    Aggregated {
        /// Sum component; written as the first field by [`NumericEntry::serialize`].
        sum: f64,
        /// Count component; written as the second field by [`NumericEntry::serialize`].
        count: usize,
        /// Value reported by [`NumericEntry::current`]. It is not part of the
        /// serialized form.
        current: f64,
    },
}

impl NumericEntry {
    /// Returns the current value: the wrapped number for `Value`, or the
    /// `current` field for `Aggregated`.
    pub fn current(&self) -> f64 {
        match *self {
            Self::Value(value) | Self::Aggregated { current: value, .. } => value,
        }
    }

    /// Serializes the entry to a string.
    ///
    /// `Value(v)` becomes `v` alone; `Aggregated` becomes `"{sum},{count}"`
    /// (the `current` field is dropped).
    pub fn serialize(&self) -> String {
        match self {
            NumericEntry::Aggregated { sum, count, .. } => format!("{sum},{count}"),
            NumericEntry::Value(value) => value.to_string(),
        }
    }

    /// Parses a string produced by [`NumericEntry::serialize`].
    ///
    /// The input is split on `,`:
    /// - one field is parsed as an `f64` and yields `Value`;
    /// - two fields are parsed as `f64` then `usize` and yield `Aggregated`,
    ///   with `current` set to the parsed sum.
    ///
    /// # Errors
    ///
    /// Returns the parse error's message when a field is not a valid number
    /// (the float field is checked before the integer field), or a fixed
    /// message when there are more than two fields.
    pub fn deserialize(entry: &str) -> Result<Self, String> {
        let fields: Vec<&str> = entry.split(',').collect();

        match fields.as_slice() {
            [single] => single
                .parse::<f64>()
                .map(NumericEntry::Value)
                .map_err(|err| err.to_string()),
            [sum, count] => {
                // Parse the float first so its error takes precedence.
                let sum = sum.parse::<f64>().map_err(|err| err.to_string())?;
                let count = count.parse::<usize>().map_err(|err| err.to_string())?;

                Ok(NumericEntry::Aggregated {
                    sum,
                    count,
                    current: sum,
                })
            }
            _ => Err("Invalid number of values for numeric entry".to_string()),
        }
    }
}
/// Formats `float` with `precision` digits after the decimal point, switching
/// to scientific notation for small magnitudes.
///
/// Scientific notation (`{:.precision$e}`) is used when the magnitude of
/// `float` is at or below `0.1^(precision - 1)`; otherwise plain fixed-point
/// notation (`{:.precision$}`) is used.
///
/// The decision is based on the absolute value, so the sign of `float` does
/// not influence which notation is chosen (e.g. `-5.0` with precision `2`
/// gives `"-5.00"`, mirroring `"5.00"` for `5.0`).
pub fn format_float(float: f64, precision: usize) -> String {
    let scientific_notation_threshold = 0.1_f64.powf(precision as f64 - 1.0);

    // Compare the magnitude: comparing the signed value would push every
    // negative number into scientific notation regardless of its size.
    if scientific_notation_threshold >= float.abs() {
        format!("{float:.precision$e}")
    } else {
        format!("{float:.precision$}")
    }
}