use super::LossFunction;
/// Squared-error loss.
///
/// Per sample the loss is `0.5 * (target - prediction)^2`. Its gradient with
/// respect to `prediction` is `prediction - target`, and its hessian is the
/// constant `1.0`.
///
/// The type is a stateless unit struct, so every instance is interchangeable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct MseLoss;

impl MseLoss {
    /// Creates a new [`MseLoss`].
    ///
    /// Equivalent to [`MseLoss::default`]. It is a `const fn` so the loss can
    /// be constructed in `const` and `static` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl LossFunction for MseLoss {
    /// Half squared error: `0.5 * (target - prediction)^2`.
    #[inline]
    fn loss(&self, target: f32, prediction: f32) -> f32 {
        let err = target - prediction;
        // Scale by 0.5 before the second multiply. Computing `err * err` first
        // could overflow to infinity even when the halved result fits in f32.
        0.5 * err * err
    }

    /// Derivative of [`loss`](Self::loss) with respect to `prediction`.
    #[inline]
    fn gradient(&self, target: f32, prediction: f32) -> f32 {
        // Written as `prediction - target` (not a negated `target - prediction`)
        // so an exact match yields +0.0 rather than -0.0.
        prediction - target
    }

    /// Second derivative of the loss with respect to `prediction`.
    ///
    /// It is constant, so both inputs are ignored.
    #[inline]
    fn hessian(&self, _target: f32, _prediction: f32) -> f32 {
        1.0_f32
    }

    /// Returns `(gradient, hessian)` in a single call.
    #[inline]
    fn gradient_hessian(&self, target: f32, prediction: f32) -> (f32, f32) {
        let grad = self.gradient(target, prediction);
        let hess = self.hessian(target, prediction);
        (grad, hess)
    }

    /// Short identifier for this loss.
    fn name(&self) -> &'static str {
        "mse"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mse_loss() {
        let loss = MseLoss::new();
        assert_eq!(loss.loss(10.0, 10.0), 0.0);
        assert_eq!(loss.loss(10.0, 12.0), 2.0);
        // Swapping target and prediction flips only the residual's sign,
        // which squaring removes.
        assert_eq!(loss.loss(12.0, 10.0), 2.0);
    }

    #[test]
    fn test_mse_gradient() {
        let loss = MseLoss::new();
        assert_eq!(loss.gradient(10.0, 12.0), 2.0);
        assert_eq!(loss.gradient(10.0, 8.0), -2.0);
        assert_eq!(loss.gradient(5.0, 5.0), 0.0);
    }

    #[test]
    fn test_mse_hessian() {
        let loss = MseLoss::new();
        assert_eq!(loss.hessian(10.0, 12.0), 1.0);
        assert_eq!(loss.hessian(0.0, 100.0), 1.0);
    }

    /// The combined accessor must agree with the individual methods.
    #[test]
    fn test_mse_gradient_hessian_matches_parts() {
        let loss = MseLoss::new();
        let cases: &[(f32, f32)] = &[(10.0, 12.0), (10.0, 8.0), (0.0, 0.0), (-3.5, 4.25)];
        for &(target, prediction) in cases {
            assert_eq!(
                loss.gradient_hessian(target, prediction),
                (
                    loss.gradient(target, prediction),
                    loss.hessian(target, prediction)
                )
            );
        }
    }

    #[test]
    fn test_mse_name() {
        assert_eq!(MseLoss::new().name(), "mse");
    }

    #[test]
    fn test_initial_prediction() {
        let loss = MseLoss::new();
        let targets = vec![10.0, 20.0, 30.0];
        // 20.0 is the arithmetic mean of `targets`.
        assert_eq!(loss.initial_prediction(&targets), 20.0);
    }
}