1use crate::util;
6use crate::numeric::Scalar;
7use crate::error::EvalError;
8
9pub fn mse<T: Scalar>(scores: &Vec<T>, labels: &Vec<T>) -> Result<T, EvalError> {
30 util::validate_input_dims(scores, labels).and_then(|()| {
31 Ok(scores.iter().zip(labels.iter()).fold(T::zero(), |sum, (&a, &b)| {
32 let diff = a - b;
33 sum + (diff * diff)
34 }) / T::from_usize(scores.len()))
35 }).and_then(util::check_finite)
36}
37
38pub fn rmse<T: Scalar>(scores: &Vec<T>, labels: &Vec<T>) -> Result<T, EvalError> {
59 mse(scores, labels).map(|m| m.sqrt())
60}
61
62pub fn mae<T: Scalar>(scores: &Vec<T>, labels: &Vec<T>) -> Result<T, EvalError> {
83 util::validate_input_dims(scores, labels).and_then(|()| {
84 Ok(scores.iter().zip(labels.iter()).fold(T::zero(), |sum, (&a, &b)| {
85 sum + (a - b).abs()
86 }) / T::from_usize(scores.len()))
87 }).and_then(util::check_finite)
88}
89
90pub fn rsq<T: Scalar>(scores: &Vec<T>, labels: &Vec<T>) -> Result<T, EvalError> {
111 util::validate_input_dims(scores, labels).and_then(|()| {
112 let length = scores.len();
113 let label_sum = labels.iter().fold(T::zero(), |s, &v| {s + v});
114 let label_mean = label_sum / T::from_usize(length);
115 let den = labels.iter().fold(T::zero(), |sse, &label| {
116 sse + (label - label_mean) * (label - label_mean)
117 }) / T::from_usize(length);
118 if den == T::zero() {
119 Err(EvalError::constant_input_data())
120 } else {
121 mse(scores, labels).map(|m| T::one() - (m / den))
122 }
123 })
124}
125
126pub fn corr<T: Scalar>(scores: &Vec<T>, labels: &Vec<T>) -> Result<T, EvalError> {
147 util::validate_input_dims(scores, labels).and_then(|()| {
148 let length = scores.len();
149 let x_mean = scores.iter().fold(T::zero(), |sum, &v| {sum + v}) / T::from_usize(length);
150 let y_mean = labels.iter().fold(T::zero(), |sum, &v| {sum + v}) / T::from_usize(length);
151 let mut sxx = T::zero();
152 let mut syy = T::zero();
153 let mut sxy = T::zero();
154
155 scores.iter().zip(labels.iter()).for_each(|(&x, &y)| {
156 let x_diff = x - x_mean;
157 let y_diff = y - y_mean;
158 sxx += x_diff * x_diff;
159 syy += y_diff * y_diff;
160 sxy += x_diff * y_diff;
161 });
162
163 match (sxx * syy).sqrt() {
164 den if den == T::zero() => Err(EvalError::constant_input_data()),
165 den => util::check_finite(sxy / den)
166 }
167 })
168}
169
#[cfg(test)]
mod tests {

    use assert_approx_eq::assert_approx_eq;
    use super::*;

    fn data() -> (Vec<f64>, Vec<f64>) {
        let scores = vec![0.5, 0.2, 0.7, 0.4, 0.1, 0.3, 0.8, 0.9];
        let labels = vec![0.3, 0.1, 0.5, 0.6, 0.2, 0.5, 0.7, 0.6];
        (scores, labels)
    }

    #[test]
    fn test_mse() {
        let (scores, labels) = data();
        assert_approx_eq!(mse(&scores, &labels).unwrap(), 0.035)
    }

    #[test]
    fn test_mse_empty() {
        assert!(mse(&Vec::<f64>::new(), &Vec::<f64>::new()).is_err())
    }

    #[test]
    fn test_mse_unequal_length() {
        assert!(mse(&vec![0.1, 0.2], &vec![0.3, 0.5, 0.8]).is_err())
    }

    #[test]
    fn test_mse_constant() {
        assert_approx_eq!(mse(&vec![1.0; 10], &vec![1.0; 10]).unwrap(), 0.0)
    }

    #[test]
    fn test_mse_nan() {
        assert!(mse(&vec![0.2, 0.5, 0.4], &vec![0.1, 0.4, f64::NAN]).is_err())
    }

    #[test]
    fn test_rmse() {
        let (scores, labels) = data();
        // The literal needs an explicit type: calling `.sqrt()` on an unsuffixed float
        // literal is an ambiguous-numeric-type error (E0689).
        assert_approx_eq!(rmse(&scores, &labels).unwrap(), 0.035_f64.sqrt())
    }

    #[test]
    fn test_rmse_empty() {
        assert!(rmse(&Vec::<f64>::new(), &Vec::<f64>::new()).is_err())
    }

    #[test]
    fn test_rmse_unequal_length() {
        assert!(rmse(&vec![0.1, 0.2], &vec![0.3, 0.5, 0.8]).is_err())
    }

    #[test]
    fn test_rmse_constant() {
        assert_approx_eq!(rmse(&vec![1.0; 10], &vec![1.0; 10]).unwrap(), 0.0)
    }

    #[test]
    fn test_rmse_nan() {
        assert!(rmse(&vec![0.2, 0.5, 0.4], &vec![0.1, 0.4, f64::NAN]).is_err())
    }

    #[test]
    fn test_mae() {
        let (scores, labels) = data();
        assert_approx_eq!(mae(&scores, &labels).unwrap(), 0.175)
    }

    #[test]
    fn test_mae_empty() {
        assert!(mae(&Vec::<f64>::new(), &Vec::<f64>::new()).is_err())
    }

    #[test]
    fn test_mae_unequal_length() {
        assert!(mae(&vec![0.1, 0.2], &vec![0.3, 0.5, 0.8]).is_err())
    }

    #[test]
    fn test_mae_constant() {
        assert_approx_eq!(mae(&vec![1.0; 10], &vec![1.0; 10]).unwrap(), 0.0)
    }

    #[test]
    fn test_mae_nan() {
        assert!(mae(&vec![0.2, 0.5, 0.4], &vec![0.1, 0.4, f64::NAN]).is_err())
    }

    #[test]
    fn test_rsq() {
        let (scores, labels) = data();
        assert_approx_eq!(rsq(&scores, &labels).unwrap(), 0.12156862745098007)
    }

    #[test]
    fn test_rsq_empty() {
        assert!(rsq(&Vec::<f64>::new(), &Vec::<f64>::new()).is_err())
    }

    #[test]
    fn test_rsq_unequal_length() {
        assert!(rsq(&vec![0.1, 0.2], &vec![0.3, 0.5, 0.8]).is_err())
    }

    #[test]
    fn test_rsq_constant() {
        assert!(rsq(&vec![1.0; 10], &vec![1.0; 10]).is_err())
    }

    #[test]
    fn test_rsq_nan() {
        assert!(rsq(&vec![0.2, 0.5, 0.4], &vec![0.1, 0.4, f64::NAN]).is_err())
    }

    #[test]
    fn test_corr() {
        let (scores, labels) = data();
        assert_approx_eq!(corr(&scores, &labels).unwrap(), 0.7473417080949364)
    }

    #[test]
    fn test_corr_empty() {
        assert!(corr(&Vec::<f64>::new(), &Vec::<f64>::new()).is_err())
    }

    #[test]
    fn test_corr_unequal_length() {
        assert!(corr(&vec![0.1, 0.2], &vec![0.3, 0.5, 0.8]).is_err())
    }

    #[test]
    fn test_corr_constant() {
        assert!(corr(&vec![1.0; 10], &vec![1.0; 10]).is_err())
    }

    #[test]
    fn test_corr_nan() {
        assert!(corr(&vec![0.2, 0.5, 0.4], &vec![0.1, 0.4, f64::NAN]).is_err())
    }
}