use super::DataSet;
// Starting value for the running sums accumulated by `Loss::calculate`.
const ZERO: f32 = 0.0;
/// Selects which loss metric [`Loss::calculate`] computes over a data set.
///
/// See [`Loss::calculate`] for the exact formula applied by each variant.
#[derive(Debug, Clone, Copy)]
pub enum Loss {
/// Mean squared error: squared differences are summed over every output
/// element and the total is divided by the number of samples.
MSE,
/// Mean absolute error, averaged over the number of samples.
MAE,
/// Cross-entropy between the expected outputs and the predictions.
CrossEntropy,
/// Total absolute difference between expected and predicted outputs
/// (a plain sum, not averaged).
Diff,
}
impl Loss {
    /// Computes this loss over every sample in `samples`.
    ///
    /// `eval_func` is called exactly once per sample with that sample's input;
    /// the vector it returns is compared element-wise with the sample's
    /// expected output. Elements are paired with `zip`, so when the two
    /// sequences differ in length the surplus elements are ignored.
    ///
    /// With `y` the expected value and `p` the predicted value of one output
    /// element, the result per variant is:
    ///
    /// * `MSE`          – `Σ (y - p)²  / number of samples`
    /// * `MAE`          – `Σ |y - p|   / number of samples`
    /// * `CrossEntropy` – `-Σ y · ln(p)` (terms with `y == 0` contribute `0`)
    /// * `Diff`         – `Σ |y - p|`
    ///
    /// Note that the two averaged variants divide by the number of *samples*,
    /// not by the number of output elements.
    ///
    /// An empty data set yields `0.0` for every variant and `eval_func` is
    /// never called.
    pub fn calculate<F>(&self, samples: &DataSet, eval_func: &mut F) -> f32
    where
        F: FnMut(&Vec<f32>) -> Vec<f32>,
    {
        let sample_count = samples.len() as f32;
        if sample_count == ZERO {
            // Nothing to evaluate; this also keeps the averaged variants from
            // returning 0/0 = NaN.
            return ZERO;
        }

        match self {
            Loss::MSE => {
                let total = Self::sum_over_outputs(samples, eval_func, |y_true, y_pred| {
                    let diff = y_true - y_pred;
                    diff * diff
                });
                total / sample_count
            }
            Loss::MAE => {
                // The absolute value is essential: signed errors would cancel
                // each other out and under-report the loss.
                let total = Self::sum_over_outputs(samples, eval_func, |y_true, y_pred| {
                    (y_true - y_pred).abs()
                });
                total / sample_count
            }
            Loss::CrossEntropy => Self::sum_over_outputs(samples, eval_func, |y_true, y_pred| {
                if y_true == ZERO {
                    // Convention 0 · ln(0) = 0; evaluating it literally would
                    // give 0 · -inf = NaN and poison the whole sum.
                    ZERO
                } else {
                    // Negated so that the loss is non-negative for
                    // probabilities and shrinks as predictions improve.
                    -(y_true * y_pred.ln())
                }
            }),
            Loss::Diff => Self::sum_over_outputs(samples, eval_func, |y_true, y_pred| {
                (y_true - y_pred).abs()
            }),
        }
    }

    /// Evaluates every sample with `eval_func` and returns the sum of
    /// `term(expected, predicted)` over all paired output elements.
    fn sum_over_outputs<F, T>(samples: &DataSet, eval_func: &mut F, term: T) -> f32
    where
        F: FnMut(&Vec<f32>) -> Vec<f32>,
        T: Fn(f32, f32) -> f32,
    {
        let mut total = ZERO;
        for sample in samples.iter() {
            let predicted = eval_func(sample.input());
            for (y_true, y_pred) in sample.output().iter().zip(predicted.iter()) {
                total += term(*y_true, *y_pred);
            }
        }
        total
    }
}