use crate as burn;

use crate::nn::loss::reduction::Reduction;

use crate::module::Module;
use crate::tensor::{Tensor, backend::Backend};
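
/// Calculate the mean squared error loss from the input logits and the targets.
///
/// A minimal usage sketch (illustrative only; assumes `logits` and `targets` are
/// tensors of the same shape on some backend):
///
/// ```rust,ignore
/// let mse = MseLoss::new();
/// let loss = mse.forward(logits, targets, Reduction::Mean);
/// ```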
#[derive(Module, Clone, Debug)]
pub struct MseLoss;

impl Default for MseLoss {
    fn default() -> Self {
        Self::new()
    }
}

impl MseLoss {
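    /// Create the criterion.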
    pub fn new() -> Self {
        Self
    }
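
    /// Compute the criterion on the input tensors and reduce the element-wise
    /// squared errors to a single-element tensor using the given [`Reduction`]
    /// strategy (`Mean`/`Auto` averages them, `Sum` adds them up).
    ///
    /// # Shapes
    ///
    /// - logits: `[...dims]`
    /// - targets: `[...dims]` (same shape as `logits`)
    /// - output: `[1]`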
    pub fn forward<const D: usize, B: Backend>(
        &self,
        logits: Tensor<B, D>,
        targets: Tensor<B, D>,
        reduction: Reduction,
    ) -> Tensor<B, 1> {
        let tensor = self.forward_no_reduction(logits, targets);
        match reduction {
            Reduction::Mean | Reduction::Auto => tensor.mean(),
            Reduction::Sum => tensor.sum(),
        }
    }
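
    /// Compute the element-wise squared error `(logits - targets)^2` without any
    /// reduction; the output has the same shape as the inputs.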
    pub fn forward_no_reduction<const D: usize, B: Backend>(
        &self,
        logits: Tensor<B, D>,
        targets: Tensor<B, D>,
    ) -> Tensor<B, D> {
        logits.sub(targets).powi_scalar(2)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;

    #[test]
    fn test_mse_loss() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );

        let targets = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
            &device,
        );

        let mse = MseLoss::new();
        let loss_no_reduction = mse.forward_no_reduction(logits.clone(), targets.clone());
        let loss = mse.forward(logits.clone(), targets.clone(), Reduction::Auto);
        let loss_sum = mse.forward(logits, targets, Reduction::Sum);
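
        // Element-wise squared errors: (1-2)^2 = 1, (2-1)^2 = 1, (3-3)^2 = 0,
        // (4-2)^2 = 4; their mean is 1.5 and their sum is 6.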
        let expected = TensorData::from([[1.0, 1.0], [0.0, 4.0]]);
        loss_no_reduction.into_data().assert_eq(&expected, false);

        let expected = TensorData::from([1.5]);
        loss.into_data().assert_eq(&expected, false);

        let expected = TensorData::from([6.0]);
        loss_sum.into_data().assert_eq(&expected, false);
    }

    #[test]
    fn display() {
        let loss = MseLoss::new();
        assert_eq!(alloc::format!("{loss}"), "MseLoss");
    }
}