use crate as burn;

use crate::nn::loss::reduction::Reduction;

use crate::module::Module;
use crate::tensor::{backend::Backend, Tensor};

/// Calculate the mean squared error loss from the input logits and the targets.
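///
/// The unreduced loss is the element-wise squared difference `(logits - targets)^2`;
/// the reduced variants take the mean or the sum of those values.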
#[derive(Module, Clone, Debug)]
pub struct MseLoss;

impl Default for MseLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl MseLoss {
    /// Create the criterion.
    pub fn new() -> Self {
        Self
    }

    /// Compute the criterion on the input tensors.
    ///
    /// # Shapes
    ///
    /// - logits: [batch_size, num_targets]
    /// - targets: [batch_size, num_targets]
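    ///
    /// # Example
    ///
    /// A minimal usage sketch; the `NdArray` backend and the import paths are
    /// assumptions for illustration only.
    ///
    /// ```ignore
    /// use burn::backend::NdArray;
    /// use burn::nn::loss::{MseLoss, Reduction};
    /// use burn::tensor::Tensor;
    ///
    /// let device = Default::default();
    /// let logits = Tensor::<NdArray, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    /// let targets = Tensor::<NdArray, 2>::from_floats([[2.0, 1.0], [3.0, 2.0]], &device);
    ///
    /// // Squared errors are [[1.0, 1.0], [0.0, 4.0]]; the mean reduction yields 1.5.
    /// let loss = MseLoss::new().forward(logits, targets, Reduction::Mean);
    /// ```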
    pub fn forward<const D: usize, B: Backend>(
        &self,
        logits: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let tensor = self.forward_no_reduction(logits, targets);
        match reduction {
            Reduction::Mean | Reduction::Auto => tensor.mean(),
            Reduction::Sum => tensor.sum(),
        }
    }

    /// Compute the criterion on the input tensors without reducing.
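    ///
    /// # Example
    ///
    /// A brief sketch of the unreduced output; the values follow from the
    /// element-wise definition, and the setup mirrors the example on
    /// [`MseLoss::forward`].
    ///
    /// ```ignore
    /// // With logits [[1.0, 2.0]] and targets [[2.0, 0.0]], the unreduced loss
    /// // keeps the input shape and equals [[1.0, 4.0]].
    /// let per_element = MseLoss::new().forward_no_reduction(logits, targets);
    /// ```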
    pub fn forward_no_reduction<const D: usize, B: Backend>(
        &self,
        logits: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        logits.sub(targets).powf_scalar(2.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;

    #[test]
    fn test_mse_loss() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );

        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
            &device,
        );

        let mse = MseLoss::new();
        let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone());
        let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto);
        let loss_sum = mse.forward(logits, targets, Reduction::Sum);

        let expected = TensorData::from([[1.0, 1.0], [0.0, 4.0]]);
        loss_no_reduction.into_data().assert_eq(&expected, false);

        let expected = TensorData::from([1.5]);
        loss.into_data().assert_eq(&expected, false);

        let expected = TensorData::from([6.0]);
        loss_sum.into_data().assert_eq(&expected, false);
    }

    #[test]
    fn display() {
        let loss = MseLoss::new();
        assert_eq!(alloc::format!("{}", loss), "MseLoss");
    }
}