eval_metrics/
regression.rs

1//!
2//! Provides support for regression metrics
3//!
4
5use crate::util;
6use crate::numeric::Scalar;
7use crate::error::EvalError;
8
9///
10/// Computes the mean squared error between scores and labels
11///
12/// # Arguments
13///
14/// * `scores` - score vector
15/// * `labels` - label vector
16///
17/// # Examples
18///
19/// ```
20/// # use eval_metrics::error::EvalError;
21/// # fn main() -> Result<(), EvalError> {
22/// use eval_metrics::regression::mse;
23/// let scores = vec![2.3, 5.1, -3.2, 7.1, -4.4];
24/// let labels = vec![1.7, 4.3, -4.1, 6.5, -3.2];
25/// let metric = mse(&scores, &labels)?;
26/// # Ok(())}
27/// ```
28///
29pub fn mse<T: Scalar>(scores: &Vec<T>, labels: &Vec<T>) -> Result<T, EvalError> {
30    util::validate_input_dims(scores, labels).and_then(|()| {
31        Ok(scores.iter().zip(labels.iter()).fold(T::zero(), |sum, (&a, &b)| {
32            let diff = a - b;
33            sum + (diff * diff)
34        }) / T::from_usize(scores.len()))
35    }).and_then(util::check_finite)
36}
37
38///
39/// Computes the root mean squared error between scores and labels
40///
41/// # Arguments
42///
43/// * `scores` - score vector
44/// * `labels` - label vector
45///
46/// # Examples
47///
48/// ```
49/// # use eval_metrics::error::EvalError;
50/// # fn main() -> Result<(), EvalError> {
51/// use eval_metrics::regression::rmse;
52/// let scores = vec![2.3, 5.1, -3.2, 7.1, -4.4];
53/// let labels = vec![1.7, 4.3, -4.1, 6.5, -3.2];
54/// let metric = rmse(&scores, &labels)?;
55/// # Ok(())}
56/// ```
57///
58pub fn rmse<T: Scalar>(scores: &Vec<T>, labels: &Vec<T>) -> Result<T, EvalError> {
59    mse(scores, labels).map(|m| m.sqrt())
60}
61
62///
63/// Computes the mean absolute error between scores and labels
64///
65/// # Arguments
66///
67/// * `scores` - score vector
68/// * `labels` - label vector
69///
70/// # Examples
71///
72/// ```
73/// # use eval_metrics::error::EvalError;
74/// # fn main() -> Result<(), EvalError> {
75/// use eval_metrics::regression::mae;
76/// let scores = vec![2.3, 5.1, -3.2, 7.1, -4.4];
77/// let labels = vec![1.7, 4.3, -4.1, 6.5, -3.2];
78/// let metric = mae(&scores, &labels)?;
79/// # Ok(())}
80/// ```
81///
82pub fn mae<T: Scalar>(scores: &Vec<T>, labels: &Vec<T>) -> Result<T, EvalError> {
83    util::validate_input_dims(scores, labels).and_then(|()| {
84        Ok(scores.iter().zip(labels.iter()).fold(T::zero(), |sum, (&a, &b)| {
85            sum + (a - b).abs()
86        }) / T::from_usize(scores.len()))
87    }).and_then(util::check_finite)
88}
89
90///
91/// Computes the coefficient of determination between scores and labels
92///
93/// # Arguments
94///
95/// * `scores` - score vector
96/// * `labels` - label vector
97///
98/// # Examples
99///
100/// ```
101/// # use eval_metrics::error::EvalError;
102/// # fn main() -> Result<(), EvalError> {
103/// use eval_metrics::regression::rsq;
104/// let scores = vec![2.3, 5.1, -3.2, 7.1, -4.4];
105/// let labels = vec![1.7, 4.3, -4.1, 6.5, -3.2];
106/// let metric = rsq(&scores, &labels)?;
107/// # Ok(())}
108/// ```
109///
110pub fn rsq<T: Scalar>(scores: &Vec<T>, labels: &Vec<T>) -> Result<T, EvalError> {
111    util::validate_input_dims(scores, labels).and_then(|()| {
112        let length = scores.len();
113        let label_sum = labels.iter().fold(T::zero(), |s, &v| {s + v});
114        let label_mean =  label_sum / T::from_usize(length);
115        let den = labels.iter().fold(T::zero(), |sse, &label| {
116            sse + (label - label_mean) * (label - label_mean)
117        }) / T::from_usize(length);
118        if den == T::zero() {
119            Err(EvalError::constant_input_data())
120        } else {
121            mse(scores, labels).map(|m| T::one() - (m / den))
122        }
123    })
124}
125
126///
127/// Computes the linear correlation between scores and labels
128///
129/// # Arguments
130///
131/// * `scores` - score vector
132/// * `labels` - label vector
133///
134/// # Examples
135///
136/// ```
137/// # use eval_metrics::error::EvalError;
138/// # fn main() -> Result<(), EvalError> {
139/// use eval_metrics::regression::corr;
140/// let scores = vec![2.3, 5.1, -3.2, 7.1, -4.4];
141/// let labels = vec![1.7, 4.3, -4.1, 6.5, -3.2];
142/// let metric = corr(&scores, &labels)?;
143/// # Ok(())}
144/// ```
145///
146pub fn corr<T: Scalar>(scores: &Vec<T>, labels: &Vec<T>) -> Result<T, EvalError> {
147    util::validate_input_dims(scores, labels).and_then(|()| {
148        let length = scores.len();
149        let x_mean = scores.iter().fold(T::zero(), |sum, &v| {sum + v}) / T::from_usize(length);
150        let y_mean = labels.iter().fold(T::zero(), |sum, &v| {sum + v}) / T::from_usize(length);
151        let mut sxx = T::zero();
152        let mut syy = T::zero();
153        let mut sxy = T::zero();
154
155        scores.iter().zip(labels.iter()).for_each(|(&x, &y)| {
156            let x_diff = x - x_mean;
157            let y_diff = y - y_mean;
158            sxx += x_diff * x_diff;
159            syy += y_diff * y_diff;
160            sxy += x_diff * y_diff;
161        });
162
163        match (sxx * syy).sqrt() {
164            den if den == T::zero() => Err(EvalError::constant_input_data()),
165            den => util::check_finite(sxy / den)
166        }
167    })
168}
169
#[cfg(test)]
mod tests {

    use super::*;
    use assert_approx_eq::assert_approx_eq;

    // Eight-element score/label fixture shared by the value tests
    fn data() -> (Vec<f64>, Vec<f64>) {
        (
            vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9],
            vec![0.3, 0.1, 0.5, 0.6, 0.2, 0.5, 0.7, 0.6],
        )
    }

    // Two empty vectors
    fn empty() -> (Vec<f64>, Vec<f64>) {
        (Vec::<f64>::new(), Vec::<f64>::new())
    }

    // Vectors whose lengths differ
    fn unequal_length() -> (Vec<f64>, Vec<f64>) {
        (vec![0.1, 0.2], vec![0.3, 0.5, 0.8])
    }

    // Identical constant vectors
    fn constant() -> (Vec<f64>, Vec<f64>) {
        (vec![1.0; 10], vec![1.0; 10])
    }

    // Vectors where one label is NaN
    fn with_nan() -> (Vec<f64>, Vec<f64>) {
        (vec![0.2, 0.5, 0.4], vec![0.1, 0.4, f64::NAN])
    }

    #[test]
    fn test_mse() {
        let (scores, labels) = data();
        let metric = mse(&scores, &labels).unwrap();
        assert_approx_eq!(metric, 0.035);
    }

    #[test]
    fn test_mse_empty() {
        let (scores, labels) = empty();
        assert!(mse(&scores, &labels).is_err());
    }

    #[test]
    fn test_mse_unequal_length() {
        let (scores, labels) = unequal_length();
        assert!(mse(&scores, &labels).is_err());
    }

    #[test]
    fn test_mse_constant() {
        let (scores, labels) = constant();
        let metric = mse(&scores, &labels).unwrap();
        assert_approx_eq!(metric, 0.0);
    }

    #[test]
    fn test_mse_nan() {
        let (scores, labels) = with_nan();
        assert!(mse(&scores, &labels).is_err());
    }

    #[test]
    fn test_rmse() {
        let (scores, labels) = data();
        let metric = rmse(&scores, &labels).unwrap();
        assert_approx_eq!(metric, 0.035.sqrt());
    }

    #[test]
    fn test_rmse_empty() {
        let (scores, labels) = empty();
        assert!(rmse(&scores, &labels).is_err());
    }

    #[test]
    fn test_rmse_unequal_length() {
        let (scores, labels) = unequal_length();
        assert!(rmse(&scores, &labels).is_err());
    }

    #[test]
    fn test_rmse_constant() {
        let (scores, labels) = constant();
        let metric = rmse(&scores, &labels).unwrap();
        assert_approx_eq!(metric, 0.0);
    }

    #[test]
    fn test_rmse_nan() {
        let (scores, labels) = with_nan();
        assert!(rmse(&scores, &labels).is_err());
    }

    #[test]
    fn test_mae() {
        let (scores, labels) = data();
        let metric = mae(&scores, &labels).unwrap();
        assert_approx_eq!(metric, 0.175);
    }

    #[test]
    fn test_mae_empty() {
        let (scores, labels) = empty();
        assert!(mae(&scores, &labels).is_err());
    }

    #[test]
    fn test_mae_unequal_length() {
        let (scores, labels) = unequal_length();
        assert!(mae(&scores, &labels).is_err());
    }

    #[test]
    fn test_mae_constant() {
        let (scores, labels) = constant();
        let metric = mae(&scores, &labels).unwrap();
        assert_approx_eq!(metric, 0.0);
    }

    #[test]
    fn test_mae_nan() {
        let (scores, labels) = with_nan();
        assert!(mae(&scores, &labels).is_err());
    }

    #[test]
    fn test_rsq() {
        let (scores, labels) = data();
        let metric = rsq(&scores, &labels).unwrap();
        assert_approx_eq!(metric, 0.12156862745098007);
    }

    #[test]
    fn test_rsq_empty() {
        let (scores, labels) = empty();
        assert!(rsq(&scores, &labels).is_err());
    }

    #[test]
    fn test_rsq_unequal_length() {
        let (scores, labels) = unequal_length();
        assert!(rsq(&scores, &labels).is_err());
    }

    #[test]
    fn test_rsq_constant() {
        let (scores, labels) = constant();
        assert!(rsq(&scores, &labels).is_err());
    }

    #[test]
    fn test_rsq_nan() {
        let (scores, labels) = with_nan();
        assert!(rsq(&scores, &labels).is_err());
    }

    #[test]
    fn test_corr() {
        let (scores, labels) = data();
        let metric = corr(&scores, &labels).unwrap();
        assert_approx_eq!(metric, 0.7473417080949364);
    }

    #[test]
    fn test_corr_empty() {
        let (scores, labels) = empty();
        assert!(corr(&scores, &labels).is_err());
    }

    #[test]
    fn test_corr_unequal_length() {
        let (scores, labels) = unequal_length();
        assert!(corr(&scores, &labels).is_err());
    }

    #[test]
    fn test_corr_constant() {
        let (scores, labels) = constant();
        assert!(corr(&scores, &labels).is_err());
    }

    #[test]
    fn test_corr_nan() {
        let (scores, labels) = with_nan();
        assert!(corr(&scores, &labels).is_err());
    }
}