Skip to main content

burn_nn/loss/
mse.rs

1use burn_core as burn;
2
3use crate::loss::reduction::Reduction;
4
5use burn::module::Module;
6use burn::tensor::{Tensor, backend::Backend};
7
8/// Calculate the mean squared error loss from the input logits and the targets.
9#[derive(Module, Clone, Debug)]
10pub struct MseLoss;
11
12impl Default for MseLoss {
13    fn default() -> Self {
14        Self::new()
15    }
16}
17
18impl MseLoss {
19    /// Create the criterion.
20    pub fn new() -> Self {
21        Self
22    }
23
24    /// Compute the criterion on the input tensor.
25    ///
26    /// # Shapes
27    ///
28    /// - logits: [batch_size, num_targets]
29    /// - targets: [batch_size, num_targets]
30    pub fn forward<const D: usize, B: Backend>(
31        &self,
32        logits: Tensor<B, D>,
33        targets: Tensor<B, D>,
34        reduction: Reduction,
35    ) -> Tensor<B, 1> {
36        let tensor = self.forward_no_reduction(logits, targets);
37        match reduction {
38            Reduction::Mean | Reduction::Auto => tensor.mean(),
39            Reduction::Sum => tensor.sum(),
40            other => panic!("{other:?} reduction is not supported"),
41        }
42    }
43
44    /// Compute the criterion on the input tensor without reducing.
45    pub fn forward_no_reduction<const D: usize, B: Backend>(
46        &self,
47        logits: Tensor<B, D>,
48        targets: Tensor<B, D>,
49    ) -> Tensor<B, D> {
50        logits.sub(targets).square()
51    }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    /// Checks the unreduced, `Auto`-reduced and `Sum`-reduced losses on a 2x2 example.
    #[test]
    fn test_mse_loss() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );
        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
            &device,
        );

        let criterion = MseLoss::new();

        // Element-wise squared differences.
        let unreduced = criterion.forward_no_reduction(logits.clone(), targets.clone());
        unreduced
            .into_data()
            .assert_eq(&TensorData::from([[1.0, 1.0], [0.0, 4.0]]), false);

        // (1 + 1 + 0 + 4) / 4 = 1.5
        let averaged = criterion.forward(logits.clone(), targets.clone(), Reduction::Auto);
        averaged
            .into_data()
            .assert_eq(&TensorData::from([1.5]), false);

        // 1 + 1 + 0 + 4 = 6
        let summed = criterion.forward(logits, targets, Reduction::Sum);
        summed
            .into_data()
            .assert_eq(&TensorData::from([6.0]), false);
    }

    /// The `Display` output of the module is its type name.
    #[test]
    fn display() {
        let criterion = MseLoss::new();
        let rendered = alloc::format!("{criterion}");
        assert_eq!(rendered, "MseLoss");
    }
}