ferrite/network/loss/
mean.rs1use super::loss::*;
2use crate::tensor::*;
3
/// Squared-error loss with a configurable reduction.
///
/// The per-element squared difference between prediction and target is
/// reduced either by averaging (`"mean"`) or by summing (`"sum"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MSELoss {
    /// `true` when constructed with `"mean"`, `false` when constructed with `"sum"`.
    is_mean_reduction: bool,
}

impl MSELoss {
    /// Creates a new `MSELoss` with the given reduction mode.
    ///
    /// `reduction` must be exactly `"mean"` or `"sum"` (case-sensitive).
    ///
    /// # Panics
    ///
    /// Panics if `reduction` is anything other than `"mean"` or `"sum"`.
    /// The panic message includes the rejected value to ease debugging.
    pub fn new(reduction: &str) -> Self {
        let is_mean_reduction = match reduction {
            "mean" => true,
            "sum" => false,
            other => panic!("Reduction must be either 'mean' or 'sum', got '{other}'"),
        };

        Self { is_mean_reduction }
    }
}
19
20impl LossTrait for MSELoss {
21 fn loss(&self, x: &Tensor, y: &Tensor) -> Tensor {
22 let z_1 = x.sub_tensor(y);
23 let z_2 = z_1.pow_f32(2.);
24
25 if self.is_mean_reduction {
26 z_2.mean()
27 } else {
28 z_2.sum()
29 }
30 }
31}
32
33
/// Absolute-error loss with a configurable reduction.
///
/// The per-element absolute difference between prediction and target is
/// reduced either by averaging (`"mean"`) or by summing (`"sum"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MAELoss {
    /// `true` when constructed with `"mean"`, `false` when constructed with `"sum"`.
    is_mean_reduction: bool,
}

impl MAELoss {
    /// Creates a new `MAELoss` with the given reduction mode.
    ///
    /// `reduction` must be exactly `"mean"` or `"sum"` (case-sensitive).
    ///
    /// # Panics
    ///
    /// Panics if `reduction` is anything other than `"mean"` or `"sum"`.
    /// The panic message includes the rejected value to ease debugging.
    pub fn new(reduction: &str) -> Self {
        let is_mean_reduction = match reduction {
            "mean" => true,
            "sum" => false,
            other => panic!("Reduction must be either 'mean' or 'sum', got '{other}'"),
        };

        Self { is_mean_reduction }
    }
}
49
50impl LossTrait for MAELoss {
51 fn loss(&self, x: &Tensor, y: &Tensor) -> Tensor {
52 let z_1 = x.sub_tensor(y);
53 let z_2 = z_1.abs();
54
55 if self.is_mean_reduction {
56 z_2.mean()
57 } else {
58 z_2.sum()
59 }
60 }
61}