use crate::neural_network::Tensor;
use crate::neural_network::neural_network_trait::LossFunction;
/// Mean squared error (MSE) loss: the average of the squared element-wise
/// differences between predictions and targets.
///
/// The type is a stateless unit struct. Construct it with
/// [`MeanSquaredError::new`] or [`Default::default`]; both produce the same value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct MeanSquaredError;

impl MeanSquaredError {
    /// Creates a new mean-squared-error loss.
    ///
    /// This is a `const fn`, so it can also initialise `const`/`static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl LossFunction for MeanSquaredError {
    /// Returns the mean of the squared element-wise differences,
    /// `sum((y_pred - y_true)^2) / n`, where `n` is the element count of the
    /// difference tensor.
    ///
    /// If the difference tensor is empty, `n` is zero and the result is `NaN`
    /// (`0.0 / 0.0`).
    ///
    /// NOTE(review): the difference is formed with `-` on the two tensors. If
    /// `Tensor` is an `ndarray` array, broadcast-compatible but unequal shapes
    /// are broadcast rather than rejected (e.g. `(N, 1)` vs `(N,)` gives an
    /// `(N, N)` difference). Confirm that callers always pass equal shapes.
    fn compute_loss(&self, y_true: &Tensor, y_pred: &Tensor) -> f32 {
        // The subtraction already yields a fresh owned tensor, so it is squared
        // in place rather than allocating a second tensor with `mapv`.
        let mut squared_diff = y_pred - y_true;
        squared_diff.mapv_inplace(|x| x * x);
        // The usize -> f32 conversion loses precision only above 2^24 elements,
        // which is negligible for a divisor.
        let n = squared_diff.len() as f32;
        squared_diff.sum() / n
    }

    /// Returns the element-wise gradient of `compute_loss` with respect to
    /// `y_pred`: `2 * (y_pred - y_true) / n`, where `n` is the element count of
    /// the difference tensor.
    ///
    /// An empty input yields an empty gradient.
    fn compute_grad(&self, y_true: &Tensor, y_pred: &Tensor) -> Tensor {
        // `y_pred - y_true` is already an owned temporary. Scale it in place
        // instead of cloning it first; the clone was a redundant allocation and copy.
        let mut grad = y_pred - y_true;
        let n = grad.len() as f32;
        grad.par_mapv_inplace(|x| 2.0 * x / n);
        grad
    }
}