use super::Loss;
use feos_core::{FeosError, Residual};
use ndarray::Array1;
use std::fmt;
use std::sync::Arc;
/// A set of target values together with the means to predict them from an
/// equation of state `E`.
///
/// Implementors supply the target values, their labels and the prediction
/// itself; the deviation and cost helpers are provided on top of those.
/// Implementors must be `Send + Sync`.
pub trait DataSet<E: Residual>: Send + Sync {
    /// Returns the target values the predictions are compared against.
    fn target(&self) -> &Array1<f64>;

    /// Returns the label of the target quantity (shown as `target` when the
    /// data set is displayed).
    fn target_str(&self) -> &str;

    /// Returns the labels of the input quantities (shown as `input` when the
    /// data set is displayed).
    fn input_str(&self) -> Vec<&str>;

    /// Computes the predicted values for the given equation of state.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports for a failed
    /// prediction.
    fn predict(&self, eos: &Arc<E>) -> Result<Array1<f64>, FeosError>;

    /// Evaluates the cost: the relative differences, passed through `loss`,
    /// with every entry divided by the number of entries.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`DataSet::relative_difference`].
    fn cost(&self, eos: &Arc<E>, loss: Loss) -> Result<Array1<f64>, FeosError> {
        let mut residuals = self.relative_difference(eos)?;
        loss.apply(&mut residuals);
        // Scale by the length of the residual array itself rather than by
        // `self.datapoints()`, which implementors are free to override.
        let count = residuals.len() as f64;
        Ok(residuals / count)
    }

    /// Returns the number of entries in the target array.
    fn datapoints(&self) -> usize {
        let target = self.target();
        target.len()
    }

    /// Computes `(prediction - target) / target` element-wise.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`DataSet::predict`].
    fn relative_difference(&self, eos: &Arc<E>) -> Result<Array1<f64>, FeosError> {
        let prediction = self.predict(eos)?;
        let target = self.target();
        let deviation = &prediction - target;
        Ok(deviation / target)
    }

    /// Computes the mean of the absolute relative differences, skipping
    /// entries that are not finite. Yields `0.0` when no finite entry exists.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`DataSet::relative_difference`].
    fn mean_absolute_relative_difference(&self, eos: &Arc<E>) -> Result<f64, FeosError> {
        let deviations = self.relative_difference(eos)?;
        let mut mean = 0.0_f64;
        let mut count = 0_usize;
        // Incremental mean: each finite value moves the running mean by its
        // distance from it, divided by the number of values seen so far.
        for value in deviations.into_iter().filter(|value| value.is_finite()) {
            count += 1;
            mean += (value.abs() - mean) / count as f64;
        }
        Ok(mean)
    }
}
/// One-line, human-readable summary of a data set trait object.
impl<E: Residual> fmt::Display for dyn DataSet<E> {
    /// Writes `DataSet(target: <target>, input: <inputs>, datapoints: <n>)`,
    /// where `<inputs>` are the input labels joined by `", "` and `<n>` is
    /// the value returned by `datapoints()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inputs = self.input_str().join(", ");
        // The closing parenthesis balances the opening `DataSet(`.
        write!(
            f,
            "DataSet(target: {}, input: {}, datapoints: {})",
            self.target_str(),
            inputs,
            self.datapoints()
        )
    }
}