entrenar/train/loss/mse/
mse_loss.rs1use aprender::primitives::Vector;
7
8use crate::autograd::BackwardOp;
9use crate::Tensor;
10use ndarray::Array1;
11use std::rc::Rc;
12
13use crate::train::loss::LossFn;
14
/// Mean squared error loss.
///
/// A stateless unit struct; the loss value and its gradient are produced by
/// the `LossFn` implementation for this type.
#[derive(Debug, Clone, Copy)]
pub struct MSELoss;
37impl LossFn for MSELoss {
38 fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
39 assert_eq!(
40 predictions.len(),
41 targets.len(),
42 "Predictions and targets must have same length"
43 );
44
45 let pred_vec =
47 Vector::from_slice(predictions.data().as_slice().expect("contiguous tensor data"));
48 let tgt_vec =
49 Vector::from_slice(targets.data().as_slice().expect("contiguous tensor data"));
50 let mse = aprender::loss::mse_loss(&pred_vec, &tgt_vec);
51
52 let mut loss = Tensor::from_vec(vec![mse], true);
53
54 let diff = predictions.data() - targets.data();
57 let n = predictions.len() as f32;
58 let grad = &diff * (2.0 / n);
59
60 struct MSEBackward {
61 pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
62 grad: Array1<f32>,
63 }
64
65 impl BackwardOp for MSEBackward {
66 fn backward(&self) {
67 let mut pred_grad = self.pred_grad_cell.borrow_mut();
68 if let Some(existing) = pred_grad.as_mut() {
69 *existing = &*existing + &self.grad;
70 } else {
71 *pred_grad = Some(self.grad.clone());
72 }
73 }
74 }
75
76 if predictions.requires_grad() {
77 loss.set_backward_op(Rc::new(MSEBackward {
78 pred_grad_cell: predictions.grad_cell(),
79 grad,
80 }));
81 }
82
83 loss
84 }
85
86 fn name(&self) -> &'static str {
87 "MSE"
88 }
89}