burn_core/nn/loss/mse.rs

1use crate as burn;
2
3use crate::nn::loss::reduction::Reduction;
4
5use crate::module::Module;
6use crate::tensor::{Tensor, backend::Backend};
7
8/// Calculate the mean squared error loss from the input logits and the targets.
9#[derive(Module, Clone, Debug)]
10pub struct MseLoss;
11
12impl Default for MseLoss {
13    fn default() -> Self {
14        Self::new()
15    }
16}
17
18impl MseLoss {
19    /// Create the criterion.
20    pub fn new() -> Self {
21        Self
22    }
23
24    /// Compute the criterion on the input tensor.
25    ///
26    /// # Shapes
27    ///
28    /// - logits: [batch_size, num_targets]
29    /// - targets: [batch_size, num_targets]
30    pub fn forward<const D: usize, B: Backend>(
31        &self,
32        logits: Tensor<B, D>,
33        targets: Tensor<B, D>,
34        reduction: Reduction,
35    ) -> Tensor<B, 1> {
36        let tensor = self.forward_no_reduction(logits, targets);
37        match reduction {
38            Reduction::Mean | Reduction::Auto => tensor.mean(),
39            Reduction::Sum => tensor.sum(),
40        }
41    }
42
43    /// Compute the criterion on the input tensor without reducing.
44    pub fn forward_no_reduction<const D: usize, B: Backend>(
45        &self,
46        logits: Tensor<B, D>,
47        targets: Tensor<B, D>,
48    ) -> Tensor<B, D> {
49        logits.sub(targets).powi_scalar(2)
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;

    #[test]
    fn test_mse_loss() {
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );
        let reference = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
            &device,
        );

        let criterion = MseLoss::new();

        // Squared differences: (1-2)^2, (2-1)^2, (3-3)^2, (4-2)^2.
        criterion
            .forward_no_reduction(predictions.clone(), reference.clone())
            .into_data()
            .assert_eq(&TensorData::from([[1.0, 1.0], [0.0, 4.0]]), false);

        // `Auto` averages the squared errors: (1 + 1 + 0 + 4) / 4.
        criterion
            .forward(predictions.clone(), reference.clone(), Reduction::Auto)
            .into_data()
            .assert_eq(&TensorData::from([1.5]), false);

        // `Sum` adds them: 1 + 1 + 0 + 4.
        criterion
            .forward(predictions, reference, Reduction::Sum)
            .into_data()
            .assert_eq(&TensorData::from([6.0]), false);
    }

    #[test]
    fn display() {
        let rendered = alloc::format!("{}", MseLoss::new());
        assert_eq!(rendered, "MseLoss");
    }
}
92}