radiate_gp/regression/
loss.rs

1use super::DataSet;
2use crate::EvalMut;
3
/// Loss functions used to score a model's predictions against a data set.
///
/// Every variant accumulates a per-element term over all output elements of
/// all samples and divides the total by the number of samples (not by the
/// number of elements).
#[derive(Clone, Copy, Debug)]
pub enum Loss {
    /// Squared error: each element contributes `(target - predicted)^2`.
    MSE,
    /// Absolute error: each element contributes `|target - predicted|`.
    MAE,
    /// Cross entropy: each element contributes `-target * ln(predicted)`,
    /// with the prediction clamped to `[1e-7, 1.0]` before the logarithm.
    CrossEntropy,
    /// Absolute difference: each element contributes `|target - predicted|`
    /// (currently computed the same way as [`Loss::MAE`]).
    Diff,
}
11
12impl Loss {
13    #[inline]
14    pub fn calc(&self, data_set: &DataSet, eval: &mut impl EvalMut<[f32], Vec<f32>>) -> f32 {
15        let out_len = data_set.shape().2;
16        let mut buffer = vec![0.0; out_len];
17
18        self.calculate(
19            data_set,
20            |x, y| {
21                let v = eval.eval_mut(x);
22                y.copy_from_slice(&v);
23            },
24            &mut buffer[..out_len],
25        )
26    }
27
28    #[inline]
29    pub fn calculate<F>(&self, data_set: &DataSet, mut eval_into_buf: F, buffer: &mut [f32]) -> f32
30    where
31        F: FnMut(&[f32], &mut [f32]),
32    {
33        let n = data_set.len() as f32;
34
35        match self {
36            Loss::MSE => {
37                let mut sum = 0.0;
38                for sample in data_set.iter() {
39                    eval_into_buf(sample.input(), buffer);
40                    let target = sample.output();
41                    for i in 0..target.len() {
42                        let d = target[i] - buffer[i];
43                        sum += d * d;
44                    }
45                }
46                sum / n
47            }
48            Loss::MAE => {
49                let mut sum = 0.0;
50                for sample in data_set.iter() {
51                    eval_into_buf(sample.input(), buffer);
52                    let target = sample.output();
53                    for i in 0..target.len() {
54                        let d = target[i] - buffer[i];
55                        sum += d.abs();
56                    }
57                }
58                sum / n
59            }
60            Loss::CrossEntropy => {
61                const EPS: f32 = 1e-7;
62                let mut sum = 0.0;
63                for sample in data_set.iter() {
64                    eval_into_buf(sample.input(), buffer);
65                    let target = sample.output();
66                    for i in 0..target.len() {
67                        let p = target[i];
68                        let q = buffer[i].clamp(EPS, 1.0);
69                        sum += -p * q.ln();
70                    }
71                }
72                sum / n
73            }
74            Loss::Diff => {
75                let mut sum = 0.0;
76                for sample in data_set.iter() {
77                    eval_into_buf(sample.input(), buffer);
78                    let target = sample.output();
79                    for i in 0..target.len() {
80                        sum += (target[i] - buffer[i]).abs();
81                    }
82                }
83                sum / n
84            }
85        }
86    }
87}