ferrite/network/loss/mean.rs

1use super::loss::*;
2use crate::tensor::*;
3
/// Mean squared error loss: `(x - y)^2`, reduced with `mean()` or `sum()`.
///
/// The reduction mode is fixed at construction time; see [`MSELoss::new`]
/// and [`MSELoss::try_new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MSELoss {
  /// `true` reduces the element-wise errors with `mean()`, `false` with `sum()`.
  is_mean_reduction: bool,
}

impl MSELoss {
  /// Creates an MSE loss with the given reduction.
  ///
  /// `reduction` must be exactly `"mean"` or `"sum"` (case-sensitive).
  ///
  /// # Panics
  ///
  /// Panics if `reduction` is any other string. Use [`MSELoss::try_new`] to
  /// handle an invalid reduction without panicking.
  pub fn new(reduction: &str) -> Self {
    match Self::try_new(reduction) {
      Ok(loss) => loss,
      Err(msg) => panic!("{}", msg),
    }
  }

  /// Fallible counterpart of [`MSELoss::new`].
  ///
  /// # Errors
  ///
  /// Returns an error message naming the rejected value when `reduction` is
  /// neither `"mean"` nor `"sum"`.
  pub fn try_new(reduction: &str) -> Result<Self, String> {
    let is_mean_reduction = match reduction {
      "mean" => true,
      "sum" => false,
      other => {
        // Keep the historical message prefix so existing
        // `should_panic(expected = ...)` checks still match.
        return Err(format!(
          "Reduction must be either 'mean' or 'sum', got '{}'",
          other
        ));
      }
    };

    Ok(Self { is_mean_reduction })
  }
}
20impl LossTrait for MSELoss {
21  fn loss(&self, x: &Tensor, y: &Tensor) -> Tensor {
22    let z_1 = x.sub_tensor(y);
23    let z_2 = z_1.pow_f32(2.); 
24
25    if self.is_mean_reduction {
26      z_2.mean()
27    } else {
28      z_2.sum()
29    }
30  }
31}
32
33
/// Mean absolute error loss: `|x - y|`, reduced with `mean()` or `sum()`.
///
/// The reduction mode is fixed at construction time; see [`MAELoss::new`]
/// and [`MAELoss::try_new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MAELoss {
  /// `true` reduces the element-wise errors with `mean()`, `false` with `sum()`.
  is_mean_reduction: bool,
}

impl MAELoss {
  /// Creates an MAE loss with the given reduction.
  ///
  /// `reduction` must be exactly `"mean"` or `"sum"` (case-sensitive).
  ///
  /// # Panics
  ///
  /// Panics if `reduction` is any other string. Use [`MAELoss::try_new`] to
  /// handle an invalid reduction without panicking.
  pub fn new(reduction: &str) -> Self {
    match Self::try_new(reduction) {
      Ok(loss) => loss,
      Err(msg) => panic!("{}", msg),
    }
  }

  /// Fallible counterpart of [`MAELoss::new`].
  ///
  /// # Errors
  ///
  /// Returns an error message naming the rejected value when `reduction` is
  /// neither `"mean"` nor `"sum"`.
  pub fn try_new(reduction: &str) -> Result<Self, String> {
    let is_mean_reduction = match reduction {
      "mean" => true,
      "sum" => false,
      other => {
        // Keep the historical message prefix so existing
        // `should_panic(expected = ...)` checks still match.
        return Err(format!(
          "Reduction must be either 'mean' or 'sum', got '{}'",
          other
        ));
      }
    };

    Ok(Self { is_mean_reduction })
  }
}
50impl LossTrait for MAELoss {
51  fn loss(&self, x: &Tensor, y: &Tensor) -> Tensor {
52    let z_1 = x.sub_tensor(y);
53    let z_2 = z_1.abs(); 
54
55    if self.is_mean_reduction {
56      z_2.mean()
57    } else {
58      z_2.sum()
59    }
60  }
61}