1use ndarray::Array1;
2
3pub fn mean_squared_error(
5 observed_array: Array1<f64>,
6 predicted_array: Array1<f64>,
7) -> Array1<f64> {
8 (&observed_array - &predicted_array).mapv(|value| value.powi(2))
9}
10
11pub fn deriv_mean_squared_error(
13 observed_array: Array1<f64>,
14 predicted_array: Array1<f64>,
15) -> Array1<f64> {
16 -2f64 * (&observed_array - &predicted_array)
17}
18
#[cfg(test)]
mod cost_tests {
    use super::*;
    use ndarray::{arr1, Array1};

    /// Relative tolerance for the regression values below. It is tight enough to catch real
    /// changes in results while tolerating last-bit differences from reordered arithmetic.
    const TOLERANCE: f64 = 1e-12;

    /// Asserts that `actual` and `expected` have the same length and agree element-wise.
    ///
    /// The allowed error is `tolerance` relative to `expected`, or absolute when `|expected| < 1`.
    fn assert_close_within(actual: &Array1<f64>, expected: &Array1<f64>, tolerance: f64) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {} vs {}",
            actual,
            expected
        );
        for (index, (got, want)) in actual.iter().zip(expected.iter()).enumerate() {
            let bound = tolerance * want.abs().max(1.0);
            assert!(
                (got - want).abs() <= bound,
                "element {}: got {}, expected {} (tolerance {})",
                index,
                got,
                want,
                bound
            );
        }
    }

    /// [`assert_close_within`] with the default regression [`TOLERANCE`].
    fn assert_close(actual: &Array1<f64>, expected: &Array1<f64>) {
        assert_close_within(actual, expected, TOLERANCE);
    }

    #[test]
    fn mse_1() {
        let observed = arr1(&[0.88651179, 0.59085182, 0.78865531]);
        let predicted = arr1(&[0.37609094, 0.04389782, 0.27988027]);

        assert_close(
            &mean_squared_error(observed, predicted),
            &arr1(&[
                0.26052944411472256,
                0.29915867811600005,
                0.25885204132700157,
            ]),
        );
    }

    #[test]
    fn mse_2() {
        let observed = arr1(&[32.321, -0.32, -1.232]);
        let predicted = arr1(&[0.69953402, 0.07279993, 0.25552055]);

        assert_close(
            &mean_squared_error(observed, predicted),
            &arr1(&[999.9171107242971, 0.15429178500800492, 2.2127173866723022]),
        );
    }

    #[test]
    fn mse_3() {
        let observed = arr1(&[]);
        let predicted = arr1(&[]);

        assert_close(&mean_squared_error(observed, predicted), &arr1(&[]));
    }

    #[test]
    fn mse_identical_inputs_is_zero() {
        let values = arr1(&[1.5, -2.25, 0.0]);

        assert_close(
            &mean_squared_error(values.clone(), values),
            &arr1(&[0.0, 0.0, 0.0]),
        );
    }

    #[test]
    fn deriv_mse_1() {
        let observed = arr1(&[-0.52198585, -2.27179003, -0.14017833]);
        let predicted = arr1(&[0.81674329, -1.07071564, 2.20337672]);

        assert_close(
            &deriv_mean_squared_error(observed, predicted),
            &arr1(&[2.6774582799999997, 2.40214878, 4.6871101]),
        );
    }

    #[test]
    fn deriv_mse_2() {
        let observed = arr1(&[-0.76362711, -1.83292557, -0.16423367]);
        let predicted = arr1(&[-1.3829452, 0.2221366, -0.27885796]);

        assert_close(
            &deriv_mean_squared_error(observed, predicted),
            &arr1(&[-1.2386361799999999, 4.11012434, -0.22924858000000004]),
        );
    }

    #[test]
    fn deriv_mse_3() {
        let observed = arr1(&[]);
        let predicted = arr1(&[]);

        assert_close(&deriv_mean_squared_error(observed, predicted), &arr1(&[]));
    }

    #[test]
    fn deriv_mse_identical_inputs_is_zero() {
        let values = arr1(&[1.5, -2.25, 0.0]);

        assert_close(
            &deriv_mean_squared_error(values.clone(), values),
            &arr1(&[0.0, 0.0, 0.0]),
        );
    }

    #[test]
    fn deriv_mse_matches_central_difference() {
        // Ties the derivative to the loss: a sign or factor error in either function fails here.
        let observed = arr1(&[0.88651179, -2.27179003, 32.321]);
        let predicted = arr1(&[0.37609094, -1.07071564, 0.69953402]);
        let step: f64 = 1e-6;

        let plus = mean_squared_error(observed.clone(), &predicted + step);
        let minus = mean_squared_error(observed.clone(), &predicted - step);
        let numeric = (plus - &minus) / (2.0 * step);

        // Finite differences lose precision to cancellation, so use a looser tolerance.
        assert_close_within(
            &deriv_mean_squared_error(observed, predicted),
            &numeric,
            1e-6,
        );
    }
}