use crate::{common::matrix::DMat, error::NetworkError, regression::RegressionEvaluator, MetricEvaluator, Metrics};
use serde::{Deserialize, Serialize};
use super::{LossFunction, LossFunctionClone};

/// Serializable mean squared error loss; constructed via [`MeanSquared::build`].
#[derive(Serialize, Deserialize, Clone)]
struct MeanSquaredErrorLoss;

/// Builder for the mean squared error (MSE) loss function.
pub struct MeanSquared;

impl MeanSquared {
    fn new() -> Self {
        Self
    }

    /// Consumes the builder and returns the loss as a boxed trait object.
    pub fn build(self) -> Result<Box<dyn LossFunction>, NetworkError> {
        Ok(Box::new(MeanSquaredErrorLoss))
    }
}

impl Default for MeanSquared {
    fn default() -> Self {
        Self::new()
    }
}

impl LossFunctionClone for MeanSquaredErrorLoss {
    fn clone_box(&self) -> Box<dyn LossFunction> {
        Box::new(self.clone())
    }
}

#[typetag::serde]
impl LossFunction for MeanSquaredErrorLoss {
    /// Mean squared error: the squared differences between predictions and
    /// targets, summed over every element and averaged over the rows.
    fn forward(&self, predicted: &DMat, target: &DMat) -> f32 {
        let (rows, cols) = (predicted.rows(), predicted.cols());
        let mut loss = 0.0;
        for i in 0..rows {
            for j in 0..cols {
                let diff = predicted.at(i, j) - target.at(i, j);
                loss += diff * diff;
            }
        }
        loss / rows as f32
    }

    /// Gradient of the loss with respect to `predicted`:
    /// d/dŷ [(1/n) Σ (ŷ - y)²] = 2 (ŷ - y) / n, with n the row count.
    fn backward(&self, predicted: &DMat, target: &DMat) -> DMat {
        let (rows, cols) = (predicted.rows(), predicted.cols());
        let mut gradient = DMat::zeros(rows, cols);
        gradient.apply_with_indices(|i, j, v| {
            let diff = predicted.at(i, j) - target.at(i, j);
            *v = 2.0 * diff / rows as f32;
        });
        gradient
    }

    /// Delegates metric reporting to the shared regression evaluator.
    fn calculate_metrics(&self, targets: &DMat, predictions: &DMat) -> Metrics {
        RegressionEvaluator.evaluate(targets, predictions)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{common::matrix::DMat, util};

    #[test]
    fn test_forward() {
        let loss = MeanSquared.build().unwrap();
        let predicted = DMat::new(2, 1, &[0.9, 0.2]);
        let target = DMat::new(2, 1, &[1.0, 0.0]);
        // ((-0.1)² + 0.2²) / 2 = 0.05 / 2 = 0.025
        let result = loss.forward(&predicted, &target);
        assert!((result - 0.025).abs() < 1e-6);
    }
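
    // Supplementary check (a sketch, not in the original suite): forward()
    // should sum squared errors over every column, matching the all-column
    // gradient in backward(). Assumes DMat::new fills a 2x2 matrix from the
    // slice; the expected value does not depend on row- vs column-major
    // layout because both matrices are built the same way.
    #[test]
    fn test_forward_multi_column() {
        let loss = MeanSquared.build().unwrap();
        let predicted = DMat::new(2, 2, &[0.9, 0.2, 0.4, 0.6]);
        let target = DMat::new(2, 2, &[1.0, 0.0, 0.5, 0.5]);
        // (0.01 + 0.04 + 0.01 + 0.01) / 2 rows = 0.07 / 2 = 0.035
        let result = loss.forward(&predicted, &target);
        assert!((result - 0.035).abs() < 1e-6);
    }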

    #[test]
    fn test_backward() {
        let loss = MeanSquared.build().unwrap();
        let predicted = DMat::new(2, 1, &[0.9, 0.2]);
        let target = DMat::new(2, 1, &[1.0, 0.0]);
        let gradient = loss.backward(&predicted, &target);
        // 2 · (-0.1) / 2 = -0.1 and 2 · 0.2 / 2 = 0.2
        let expected_gradient = DMat::new(2, 1, &[-0.1, 0.2]);
        assert!(util::equal_approx(&gradient, &expected_gradient, 1e-6));
    }
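
    // Supplementary check (a sketch, not in the original suite): backward()
    // should agree with a central finite difference of forward(). Uses only
    // the DMat constructors and accessors already exercised by the tests
    // above; since the loss is quadratic, the central difference is exact up
    // to f32 rounding.
    #[test]
    fn test_backward_matches_finite_difference() {
        let loss = MeanSquared.build().unwrap();
        let target = DMat::new(2, 1, &[1.0, 0.0]);
        let base = [0.9f32, 0.2];
        let gradient = loss.backward(&DMat::new(2, 1, &base), &target);
        let eps = 1e-3f32;
        for i in 0..2 {
            let (mut plus, mut minus) = (base, base);
            plus[i] += eps;
            minus[i] -= eps;
            // (f(p + eps) - f(p - eps)) / (2 eps) ≈ df/dp at entry (i, 0)
            let fd = (loss.forward(&DMat::new(2, 1, &plus), &target)
                - loss.forward(&DMat::new(2, 1, &minus), &target))
                / (2.0 * eps);
            assert!((gradient.at(i, 0) - fd).abs() < 1e-3);
        }
    }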
}