1#![forbid(unsafe_code)]
/// Reasons the slice-based loss functions reject their inputs.
///
/// Produced by `mean_absolute_error`, `mean_squared_error` and
/// `root_mean_squared_error` when input validation fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LossError {
    /// At least one of the `actual` / `predicted` slices has no elements.
    EmptyInput,
    /// The `actual` and `predicted` slices have different lengths.
    MismatchedLengths,
    /// A value in either slice is NaN or infinite.
    NonFiniteInput,
}

impl std::fmt::Display for LossError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::EmptyInput => "loss inputs must not be empty",
            Self::MismatchedLengths => "actual and predicted inputs must have the same length",
            Self::NonFiniteInput => "loss inputs must contain only finite values",
        };
        f.write_str(message)
    }
}

// No variant wraps an underlying cause, so the default `source()` (None) is correct.
impl std::error::Error for LossError {}
20
/// Returns the magnitude of the gap between an observed value and a prediction,
/// i.e. `|actual - predicted|`.
///
/// The inputs are not validated.
pub fn absolute_error(actual: f64, predicted: f64) -> f64 {
    let residual = actual - predicted;
    f64::abs(residual)
}
24
/// Returns the squared gap between an observed value and a prediction,
/// i.e. `(actual - predicted)²`.
///
/// The inputs are not validated.
pub fn squared_error(actual: f64, predicted: f64) -> f64 {
    let residual = actual - predicted;
    residual * residual
}
29
30pub fn mean_absolute_error(actual: &[f64], predicted: &[f64]) -> Result<f64, LossError> {
31 validate_inputs(actual, predicted)?;
32
33 Ok(actual
34 .iter()
35 .zip(predicted.iter())
36 .map(|(actual_value, predicted_value)| absolute_error(*actual_value, *predicted_value))
37 .sum::<f64>()
38 / actual.len() as f64)
39}
40
41pub fn mean_squared_error(actual: &[f64], predicted: &[f64]) -> Result<f64, LossError> {
42 validate_inputs(actual, predicted)?;
43
44 Ok(actual
45 .iter()
46 .zip(predicted.iter())
47 .map(|(actual_value, predicted_value)| squared_error(*actual_value, *predicted_value))
48 .sum::<f64>()
49 / actual.len() as f64)
50}
51
52pub fn root_mean_squared_error(actual: &[f64], predicted: &[f64]) -> Result<f64, LossError> {
53 Ok(mean_squared_error(actual, predicted)?.sqrt())
54}
55
56fn validate_inputs(actual: &[f64], predicted: &[f64]) -> Result<(), LossError> {
57 if actual.is_empty() || predicted.is_empty() {
58 return Err(LossError::EmptyInput);
59 }
60
61 if actual.len() != predicted.len() {
62 return Err(LossError::MismatchedLengths);
63 }
64
65 if actual
66 .iter()
67 .chain(predicted.iter())
68 .any(|value| !value.is_finite())
69 {
70 return Err(LossError::NonFiniteInput);
71 }
72
73 Ok(())
74}
75
#[cfg(test)]
mod tests {
    use super::{
        LossError, absolute_error, mean_absolute_error, mean_squared_error,
        root_mean_squared_error, squared_error,
    };

    /// Asserts that `left` and `right` differ by less than `1.0e-10`.
    ///
    /// Named `assert_*` because it panics on mismatch instead of returning a `bool`.
    fn assert_approx_eq(left: f64, right: f64) {
        assert!((left - right).abs() < 1.0e-10, "left={left}, right={right}");
    }

    #[test]
    fn computes_basic_error_terms() {
        assert_eq!(absolute_error(4.0, 3.0), 1.0);
        assert_eq!(squared_error(4.0, 3.0), 1.0);
        // Both per-element terms are symmetric in their arguments.
        assert_eq!(absolute_error(3.0, 4.0), 1.0);
        assert_eq!(squared_error(3.0, 4.0), 1.0);
    }

    #[test]
    fn computes_common_loss_functions() {
        let actual = [1.0, 2.0, 3.0];
        let predicted = [1.5, 2.5, 2.0];

        assert_approx_eq(mean_absolute_error(&actual, &predicted).unwrap(), 2.0 / 3.0);
        assert_approx_eq(mean_squared_error(&actual, &predicted).unwrap(), 0.5);
        assert_approx_eq(
            root_mean_squared_error(&actual, &predicted).unwrap(),
            0.5_f64.sqrt(),
        );
    }

    #[test]
    fn handles_single_value_inputs() {
        assert_approx_eq(mean_absolute_error(&[3.0], &[2.0]).unwrap(), 1.0);
        assert_approx_eq(mean_squared_error(&[3.0], &[2.0]).unwrap(), 1.0);
        assert_approx_eq(root_mean_squared_error(&[3.0], &[2.0]).unwrap(), 1.0);
    }

    #[test]
    fn perfect_predictions_have_zero_loss() {
        let values = [1.0, -2.0, 3.5];

        assert_eq!(mean_absolute_error(&values, &values), Ok(0.0));
        assert_eq!(mean_squared_error(&values, &values), Ok(0.0));
        assert_eq!(root_mean_squared_error(&values, &values), Ok(0.0));
    }

    #[test]
    fn rejects_invalid_loss_inputs() {
        assert_eq!(mean_absolute_error(&[], &[]), Err(LossError::EmptyInput));
        assert_eq!(
            mean_squared_error(&[1.0], &[1.0, 2.0]),
            Err(LossError::MismatchedLengths)
        );
        assert_eq!(
            root_mean_squared_error(&[1.0, f64::NAN], &[1.0, 2.0]),
            Err(LossError::NonFiniteInput)
        );
    }

    #[test]
    fn reports_empty_input_before_length_mismatch() {
        // One empty side also has a different length; the emptiness check wins.
        assert_eq!(mean_absolute_error(&[1.0], &[]), Err(LossError::EmptyInput));
        assert_eq!(mean_squared_error(&[], &[1.0]), Err(LossError::EmptyInput));
    }

    #[test]
    fn rejects_infinite_values_in_either_slice() {
        assert_eq!(
            mean_absolute_error(&[1.0], &[f64::INFINITY]),
            Err(LossError::NonFiniteInput)
        );
        assert_eq!(
            root_mean_squared_error(&[f64::NEG_INFINITY], &[1.0]),
            Err(LossError::NonFiniteInput)
        );
    }
}