1use burn_core as burn;
2
3use crate::loss::reduction::Reduction;
4
5use burn::module::Module;
6use burn::tensor::{Tensor, backend::Backend};
7
/// Mean squared error (MSE) loss module.
///
/// Stateless unit struct: it holds no parameters, so it derives [`Module`]
/// trivially and can be freely cloned. Construct via [`MseLoss::new`] or
/// `Default`, then call `forward` / `forward_no_reduction`.
#[derive(Module, Clone, Debug)]
pub struct MseLoss;
11
12impl Default for MseLoss {
13 fn default() -> Self {
14 Self::new()
15 }
16}
17
18impl MseLoss {
19 pub fn new() -> Self {
21 Self
22 }
23
24 pub fn forward<const D: usize, B: Backend>(
31 &self,
32 logits: Tensor<B, D>,
33 targets: Tensor<B, D>,
34 reduction: Reduction,
35 ) -> Tensor<B, 1> {
36 let tensor = self.forward_no_reduction(logits, targets);
37 match reduction {
38 Reduction::Mean | Reduction::Auto => tensor.mean(),
39 Reduction::Sum => tensor.sum(),
40 other => panic!("{other:?} reduction is not supported"),
41 }
42 }
43
44 pub fn forward_no_reduction<const D: usize, B: Backend>(
46 &self,
47 logits: Tensor<B, D>,
48 targets: Tensor<B, D>,
49 ) -> Tensor<B, D> {
50 logits.sub(targets).square()
51 }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::TensorData;

    #[test]
    fn test_mse_loss() {
        // Hand-computed squared errors for the fixture below:
        // (1-2)^2 = 1, (2-1)^2 = 1, (3-3)^2 = 0, (4-2)^2 = 4.
        let device = Default::default();
        let predictions = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[1.0, 2.0], [3.0, 4.0]]),
            &device,
        );
        let expectations = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([[2.0, 1.0], [3.0, 2.0]]),
            &device,
        );

        let criterion = MseLoss::new();
        let elementwise =
            criterion.forward_no_reduction(predictions.clone(), expectations.clone());
        let averaged = criterion.forward(predictions.clone(), expectations.clone(), Reduction::Auto);
        let summed = criterion.forward(predictions, expectations, Reduction::Sum);

        elementwise
            .into_data()
            .assert_eq(&TensorData::from([[1.0, 1.0], [0.0, 4.0]]), false);

        // Mean of [1, 1, 0, 4] is 6 / 4 = 1.5; Auto behaves like Mean here.
        averaged
            .into_data()
            .assert_eq(&TensorData::from([1.5]), false);

        summed
            .into_data()
            .assert_eq(&TensorData::from([6.0]), false);
    }

    #[test]
    fn display() {
        // The derived Module implementation renders the bare struct name.
        let criterion = MseLoss::new();
        assert_eq!(alloc::format!("{criterion}"), "MseLoss");
    }
}