ducky_learn/
cost.rs

1use ndarray::Array1;
2
3// TODO: Documentation
4pub fn mean_squared_error(
5    observed_array: Array1<f64>,
6    predicted_array: Array1<f64>,
7) -> Array1<f64> {
8    (&observed_array - &predicted_array).mapv(|value| value.powi(2))
9}
10
11// TODO: Documentation
12pub fn deriv_mean_squared_error(
13    observed_array: Array1<f64>,
14    predicted_array: Array1<f64>,
15) -> Array1<f64> {
16    -2f64 * (&observed_array - &predicted_array)
17}
18
#[cfg(test)]
mod cost_tests {
    use super::*;
    use ndarray::arr1;

    /// Asserts that `actual` has the same length as `expected` and that every
    /// element agrees within a tight relative tolerance.
    ///
    /// Exact `==` on floats ties a test to one specific rounding sequence
    /// (`f64::powi`, for instance, is documented as having non-deterministic
    /// precision). Taking `expected` as `&[f64]` also gives empty expected
    /// values a concrete element type.
    fn assert_close(actual: Array1<f64>, expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: got {}, expected {:?}",
            actual,
            expected
        );
        for (i, (&got, &want)) in actual.iter().zip(expected.iter()).enumerate() {
            // Relative tolerance for large magnitudes, absolute near zero.
            let tolerance = 1e-12 * want.abs().max(1.0);
            assert!(
                (got - want).abs() <= tolerance,
                "element {}: got {}, expected {}",
                i,
                got,
                want
            );
        }
    }

    #[test]
    fn mse_1() {
        let observed = arr1(&[0.88651179, 0.59085182, 0.78865531]);
        let predicted = arr1(&[0.37609094, 0.04389782, 0.27988027]);

        assert_close(
            mean_squared_error(observed, predicted),
            &[
                0.26052944411472256,
                0.29915867811600005,
                0.25885204132700157,
            ],
        );
    }

    #[test]
    fn mse_2() {
        let observed = arr1(&[32.321, -0.32, -1.232]);
        let predicted = arr1(&[0.69953402, 0.07279993, 0.25552055]);

        assert_close(
            mean_squared_error(observed, predicted),
            &[999.9171107242971, 0.15429178500800492, 2.2127173866723022],
        );
    }

    #[test]
    fn mse_3() {
        let observed = arr1(&[]);
        let predicted = arr1(&[]);

        assert_close(mean_squared_error(observed, predicted), &[]);
    }

    #[test]
    fn deriv_mse_1() {
        let observed = arr1(&[-0.52198585, -2.27179003, -0.14017833]);
        let predicted = arr1(&[0.81674329, -1.07071564, 2.20337672]);

        assert_close(
            deriv_mean_squared_error(observed, predicted),
            &[2.6774582799999997, 2.40214878, 4.6871101],
        );
    }

    #[test]
    fn deriv_mse_2() {
        let observed = arr1(&[-0.76362711, -1.83292557, -0.16423367]);
        let predicted = arr1(&[-1.3829452, 0.2221366, -0.27885796]);

        assert_close(
            deriv_mean_squared_error(observed, predicted),
            &[-1.2386361799999999, 4.11012434, -0.22924858000000004],
        );
    }

    #[test]
    fn deriv_mse_3() {
        let observed = arr1(&[]);
        let predicted = arr1(&[]);

        assert_close(deriv_mean_squared_error(observed, predicted), &[]);
    }
}