use serde::{Deserialize, Serialize};
/// Result produced by an error metric such as `rel_err` or `abs_err`.
///
/// NOTE(review): variant order is kept as-is on purpose; it matters for serde formats that
/// encode enum variants by index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MetricValue {
    /// A numeric score. The metrics in this file report `NaN` for invalid input
    /// (empty slices or slices of different lengths).
    Numerical(f64),
    /// A non-numeric result carried as text; `rel_err_eps` passes it through unchanged.
    Categorical(String),
}
/// Trait-object type for an error metric: a function of two slices returning a
/// [`MetricValue`]. `rel_err`, `rel_err_eps` and `abs_err` in this file have this shape with
/// the arguments ordered `(actual, expected)`. `T` defaults to `f64`; the `Send + Sync` bounds
/// allow a boxed or referenced metric to be shared across threads.
pub type ErrorMetric<T = f64> = dyn Fn(&[T], &[T]) -> MetricValue + Send + Sync;
/// Relative error between `actual` and `expected`, wrapped in [`MetricValue::Numerical`].
///
/// * Empty inputs, or inputs of different lengths, yield `NaN`.
/// * For a single pair `(a, e)` the error is the symmetric `|a - e| / max(|a|, |e|)`.
/// * For longer inputs it is `‖actual - expected‖₂ / ‖expected‖₂`. When
///   `‖expected‖₂ < 1e-12`, the absolute norm `‖actual - expected‖₂` is returned instead,
///   to avoid dividing by (nearly) zero.
///
/// An absolute difference (or difference norm) below `f64::EPSILON` is reported as exactly
/// `0.0`, regardless of the magnitude of the inputs. NaN inputs propagate to a NaN result.
///
/// Norms are computed with scaling, so finite inputs whose squares would overflow `f64`
/// still produce a finite, accurate result instead of `inf` or `NaN`.
pub fn rel_err(actual: &[f64], expected: &[f64]) -> MetricValue {
    // Euclidean norm computed as `scale * sqrt(Σ (v / scale)²)` with `scale = max |v|`,
    // so large components cannot overflow the intermediate sum of squares.
    fn l2_norm(values: impl Iterator<Item = f64> + Clone) -> f64 {
        // `f64::max` ignores NaN, so a NaN component never becomes the scale.
        let scale = values.clone().fold(0.0_f64, |m, v| m.max(v.abs()));
        if scale == 0.0 || scale.is_infinite() {
            // All components are zero (or NaN), or one is infinite. The plain sum already
            // yields the right 0 / NaN / inf, whereas dividing by `scale` would not.
            return values.map(|v| v * v).sum::<f64>().sqrt();
        }
        let scaled_sq_sum: f64 = values
            .map(|v| {
                let r = v / scale;
                r * r
            })
            .sum();
        scale * scaled_sq_sum.sqrt()
    }

    // Equal lengths plus non-empty `actual` implies non-empty `expected`.
    if actual.is_empty() || actual.len() != expected.len() {
        return MetricValue::Numerical(f64::NAN);
    }

    if let (&[a], &[e]) = (actual, expected) {
        let abs_err = (a - e).abs();
        if abs_err < f64::EPSILON {
            return MetricValue::Numerical(0.0);
        }
        let denom = a.abs().max(e.abs());
        let err = if denom == 0.0 {
            0.0
        } else if abs_err.is_infinite() && a.is_finite() && e.is_finite() {
            // `a - e` overflowed although both operands are finite; rescale before subtracting.
            (a / denom - e / denom).abs()
        } else {
            abs_err / denom
        };
        return MetricValue::Numerical(err);
    }

    let diff_norm = l2_norm(actual.iter().zip(expected).map(|(a, e)| a - e));
    if diff_norm < f64::EPSILON {
        return MetricValue::Numerical(0.0);
    }
    let expected_norm = l2_norm(expected.iter().copied());
    // Fall back to the absolute error when the reference is (numerically) zero.
    let err = if expected_norm < 1e-12 {
        diff_norm
    } else {
        diff_norm / expected_norm
    };
    MetricValue::Numerical(err)
}
/// Relative error from [`rel_err`], expressed in units of machine epsilon (`f64::EPSILON`).
///
/// A `Numerical` result is divided by `f64::EPSILON`; any other variant is returned
/// unchanged.
pub fn rel_err_eps(actual: &[f64], expected: &[f64]) -> MetricValue {
    match rel_err(actual, expected) {
        MetricValue::Numerical(relative) => MetricValue::Numerical(relative / f64::EPSILON),
        passthrough => passthrough,
    }
}
/// Absolute error between `actual` and `expected`, wrapped in [`MetricValue::Numerical`].
///
/// The error is the Euclidean norm `‖actual - expected‖₂`; for a single pair this is just
/// `|a - e|`. Empty inputs, or inputs of different lengths, yield `NaN`. A norm below
/// `f64::EPSILON` is reported as exactly `0.0`. NaN inputs propagate to a NaN result.
///
/// The norm is computed with scaling, so finite differences whose squares would overflow
/// `f64` still produce a finite, accurate result instead of `inf`.
pub fn abs_err(actual: &[f64], expected: &[f64]) -> MetricValue {
    // Euclidean norm computed as `scale * sqrt(Σ (v / scale)²)` with `scale = max |v|`,
    // so large components cannot overflow the intermediate sum of squares. For a single
    // finite non-zero value this returns exactly `|v|`.
    fn l2_norm(values: impl Iterator<Item = f64> + Clone) -> f64 {
        // `f64::max` ignores NaN, so a NaN component never becomes the scale.
        let scale = values.clone().fold(0.0_f64, |m, v| m.max(v.abs()));
        if scale == 0.0 || scale.is_infinite() {
            // All components are zero (or NaN), or one is infinite. The plain sum already
            // yields the right 0 / NaN / inf, whereas dividing by `scale` would not.
            return values.map(|v| v * v).sum::<f64>().sqrt();
        }
        let scaled_sq_sum: f64 = values
            .map(|v| {
                let r = v / scale;
                r * r
            })
            .sum();
        scale * scaled_sq_sum.sqrt()
    }

    // Equal lengths plus non-empty `actual` implies non-empty `expected`.
    if actual.is_empty() || actual.len() != expected.len() {
        return MetricValue::Numerical(f64::NAN);
    }

    let diff_norm = l2_norm(actual.iter().zip(expected).map(|(a, e)| a - e));
    // Differences below machine epsilon are treated as exact agreement.
    let err = if diff_norm < f64::EPSILON { 0.0 } else { diff_norm };
    MetricValue::Numerical(err)
}