// radiate_extensions/problems/error_functions.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
use num_traits::cast::FromPrimitive;
use num_traits::float::Float;
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, Sub, SubAssign};

use super::DataSet;

/// Selects how [`ErrorFunction::calculate`] aggregates the difference between
/// a model's outputs and the expected outputs of a data set.
///
/// See [`ErrorFunction::calculate`] for the exact formula used by each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorFunction {
    /// Mean squared error.
    MSE,
    /// Mean absolute error.
    MAE,
    /// Cross-entropy between expected values and model outputs.
    CrossEntropy,
    /// Absolute differences summed over all samples (not averaged).
    Diff,
}

impl ErrorFunction {
    /// Runs `eval_func` on every sample in `samples` and aggregates the
    /// element-wise error between its output and the sample's expected values.
    ///
    /// For each sample, `sample.1` is passed to `eval_func` and the returned
    /// vector is compared position-by-position with `sample.2`. Pairs are
    /// formed with `zip`, so if the two vectors differ in length the trailing
    /// elements of the longer one are ignored.
    ///
    /// * `MSE` – sum of squared differences, divided by the number of samples.
    /// * `MAE` – sum of absolute differences, divided by the number of samples.
    /// * `CrossEntropy` – `-Σ target * ln(output)` over every element of every
    ///   sample (not averaged). An output of `0` yields an infinite loss and a
    ///   negative output yields `NaN`, since `ln` is undefined there.
    /// * `Diff` – sum of absolute differences (not averaged).
    ///
    /// An empty data set yields zero for every variant.
    ///
    /// # Panics
    ///
    /// Panics if the number of samples cannot be represented as `T`
    /// (`FromPrimitive::from_usize` returning `None`).
    pub fn calculate<T, F>(&self, samples: &DataSet<T>, eval_func: &mut F) -> T
    where
        T: Clone
            + PartialEq
            + Default
            + Add<Output = T>
            + Sub<Output = T>
            + Mul<Output = T>
            + Div<Output = T>
            + AddAssign
            + SubAssign
            + DivAssign
            + Float
            + FromPrimitive,
        F: FnMut(&Vec<T>) -> Vec<T>,
    {
        let sample_count = samples.get_samples().len();

        // Sums `term(target, output)` over every element of every sample.
        // Shared by all variants so they pair targets and outputs identically.
        let mut total = |term: fn(T, T) -> T| -> T {
            let mut sum = T::zero();
            for sample in samples.get_samples().iter() {
                let output = eval_func(&sample.1);
                for (&target, &predicted) in sample.2.iter().zip(output.iter()) {
                    sum += term(target, predicted);
                }
            }
            sum
        };

        // Averages over the number of samples; an empty set stays at zero
        // instead of producing 0 / 0 = NaN.
        let per_sample = |sum: T| -> T {
            if sample_count == 0 {
                sum
            } else {
                sum / T::from_usize(sample_count)
                    .expect("sample count must be representable in T")
            }
        };

        match self {
            ErrorFunction::MSE => per_sample(total(|t, p| (t - p) * (t - p))),
            // Absolute value so that over- and under-predictions do not cancel.
            ErrorFunction::MAE => per_sample(total(|t, p| (t - p).abs())),
            // Negated so the loss is non-negative for targets >= 0 and
            // outputs in (0, 1]; lower is better, like the other variants.
            ErrorFunction::CrossEntropy => -total(|t, p| t * p.ln()),
            ErrorFunction::Diff => total(|t, p| (t - p).abs()),
        }
    }
}