use super::Loss;
use candle_core::{Result, Tensor};
/// Mean squared error loss.
///
/// A field-less marker type; all behaviour lives in its [`Loss`] implementation,
/// which averages the squared difference between predictions and targets.
pub struct MSE;
impl Loss for MSE {
    /// Computes the mean squared error between `y_pred` and `y_true` together
    /// with the gradient of that loss with respect to `y_pred`.
    ///
    /// The loss is the mean of `(y_pred - y_true)^2` taken over every element
    /// of the tensor. The returned gradient has the same shape as `y_pred` and
    /// equals `2 * (y_pred - y_true) / elem_count`, i.e. the exact derivative
    /// of the returned loss value.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by the underlying tensor operations: the
    /// element-wise subtraction, the squaring/mean reduction, reading the
    /// reduced loss back as an `f32` scalar, or the final scaling.
    fn compute(&self, y_pred: &Tensor, y_true: &Tensor) -> Result<(f32, Tensor)> {
        let diff = y_pred.sub(y_true)?;
        let loss_value = diff.sqr()?.mean_all()?.to_scalar::<f32>()?;

        // `mean_all` averages over every element, not just the leading
        // dimension, so the derivative must be scaled by the total element
        // count. Using only `dim(0)` over-scales the gradient for tensors with
        // more than one column and fails outright for rank-0 tensors.
        let n = diff.elem_count() as f64;
        let d_loss_d_y_pred = (diff * (2.0 / n))?;

        Ok((loss_value, d_loss_d_y_pred))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Device;

    /// Absolute tolerance applied to every floating-point comparison below.
    const TOLERANCE: f32 = 1e-5;

    /// Checks the loss value and per-row gradient for a three-row, single-column
    /// input whose differences are `[0, 1, -2]`.
    #[test]
    fn test_mse_compute() -> Result<()> {
        let is_close = |actual: f32, expected: f32| (actual - expected).abs() < TOLERANCE;

        let device = Device::Cpu;
        let candle_device = device.as_candle().unwrap();
        let predictions = Tensor::new(&[[1.0f32], [2.0], [3.0]], &candle_device)?;
        let targets = Tensor::new(&[[1.0f32], [1.0], [5.0]], &candle_device)?;

        let loss_fn = MSE;
        let (loss, gradient) = loss_fn.compute(&predictions, &targets)?;

        // Mean of the squared differences: (0 + 1 + 4) / 3.
        assert!(is_close(loss, 1.6666666));

        // Gradient rows: 2 * diff / 3 for diff = 0, 1, -2.
        let rows = gradient.to_vec2::<f32>()?;
        assert!(is_close(rows[0][0], 0.0));
        assert!(is_close(rows[1][0], 0.6666666));
        assert!(is_close(rows[2][0], -1.3333333));

        Ok(())
    }
}