use crate::{flatten_matrix, ToMatrix};
/// Selects how a set of predictions is scored against its targets.
#[derive(Clone)]
pub enum Loss {
    /// Mean squared error: the average of the squared element-wise
    /// differences between predictions and targets.
    MSE,
    /// A caller-supplied loss function.
    ///
    /// It is invoked with the flattened predictions as the first slice and the
    /// flattened targets as the second, and returns the loss value.
    Custom(fn(predictions: &[f64], targets: &[f64]) -> f64),
}
impl Loss {
    /// Computes the loss between `predictions` and `targets`.
    ///
    /// Both arguments are converted with `ToMatrix::to_matrix` and flattened
    /// with `flatten_matrix`; the selected loss is then applied to the two
    /// resulting `f64` slices (predictions first, targets second).
    ///
    /// The receiver is `&self` because evaluating a loss never modifies it,
    /// so the method can be called through shared references and on
    /// immutable bindings.
    ///
    /// # Panics
    ///
    /// Panics if the flattened predictions and targets differ in length.
    pub fn calculate<T: ToMatrix>(&self, predictions: &T, targets: &T) -> f64 {
        let predicted = flatten_matrix(&predictions.to_matrix());
        let expected = flatten_matrix(&targets.to_matrix());

        // Checked once here so every loss function can rely on equal lengths.
        assert!(
            predicted.len() == expected.len(),
            "Predictions and targets length mismatch."
        );

        match self {
            Loss::MSE => mse(&predicted, &expected),
            Loss::Custom(function) => function(&predicted, &expected),
        }
    }
}
/// Returns the mean squared error between `predictions` and `targets`.
///
/// Elements are paired positionally; if the slices differ in length the extra
/// elements of the longer one are ignored, but the sum is still divided by
/// `predictions.len()`. For empty `predictions` the result is NaN (`0.0 / 0.0`).
fn mse(predictions: &[f64], targets: &[f64]) -> f64 {
    let mut squared_error_sum = 0.0_f64;
    for (predicted, expected) in predictions.iter().zip(targets) {
        squared_error_sum += (predicted - expected).powi(2);
    }
    squared_error_sum / predictions.len() as f64
}