Skip to main content

entrenar/train/loss/mse/
mse_loss.rs

1//! Mean Squared Error Loss
2//!
3//! Forward scalar computation delegates to [`aprender::loss::mse_loss`].
4//! Gradient computation (backward) is entrenar's autograd concern.
5
6use aprender::primitives::Vector;
7
8use crate::autograd::BackwardOp;
9use crate::Tensor;
10use ndarray::Array1;
11use std::rc::Rc;
12
13use crate::train::loss::LossFn;
14
/// Mean Squared Error loss.
///
/// `L = mean((predictions - targets)^2)`
///
/// The forward scalar is delegated to [`aprender::loss::mse_loss`]; the
/// gradient with respect to the predictions is produced by entrenar's autograd.
///
/// This is a stateless unit type, so it is cheap to copy and can be created
/// with either `MSELoss` or `MSELoss::default()`.
///
/// # Example
///
/// ```
/// use entrenar::train::{MSELoss, LossFn};
/// use entrenar::Tensor;
///
/// let loss_fn = MSELoss;
/// let pred = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
/// let target = Tensor::from_vec(vec![1.5, 2.5, 3.5], false);
///
/// let loss = loss_fn.forward(&pred, &target);
/// assert!(loss.data()[0] > 0.0);
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct MSELoss;
36
37impl LossFn for MSELoss {
38    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
39        assert_eq!(
40            predictions.len(),
41            targets.len(),
42            "Predictions and targets must have same length"
43        );
44
45        // Delegate forward scalar to aprender
46        let pred_vec =
47            Vector::from_slice(predictions.data().as_slice().expect("contiguous tensor data"));
48        let tgt_vec =
49            Vector::from_slice(targets.data().as_slice().expect("contiguous tensor data"));
50        let mse = aprender::loss::mse_loss(&pred_vec, &tgt_vec);
51
52        let mut loss = Tensor::from_vec(vec![mse], true);
53
54        // Gradient computation is entrenar's autograd concern
55        // d(MSE)/d(pred) = 2 * (pred - target) / n
56        let diff = predictions.data() - targets.data();
57        let n = predictions.len() as f32;
58        let grad = &diff * (2.0 / n);
59
60        struct MSEBackward {
61            pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
62            grad: Array1<f32>,
63        }
64
65        impl BackwardOp for MSEBackward {
66            fn backward(&self) {
67                let mut pred_grad = self.pred_grad_cell.borrow_mut();
68                if let Some(existing) = pred_grad.as_mut() {
69                    *existing = &*existing + &self.grad;
70                } else {
71                    *pred_grad = Some(self.grad.clone());
72                }
73            }
74        }
75
76        if predictions.requires_grad() {
77            loss.set_backward_op(Rc::new(MSEBackward {
78                pred_grad_cell: predictions.grad_cell(),
79                grad,
80            }));
81        }
82
83        loss
84    }
85
86    fn name(&self) -> &'static str {
87        "MSE"
88    }
89}