1use crate::error::EvalError;
2
3fn check_lengths(predictions: &[f32], targets: &[f32]) -> Result<(), EvalError> {
4 if predictions.len() != targets.len() {
5 return Err(EvalError::CountLengthMismatch {
6 ground_truth: targets.len(),
7 predictions: predictions.len(),
8 });
9 }
10 Ok(())
11}
12
13pub fn r2_score(predictions: &[f32], targets: &[f32]) -> Result<f32, EvalError> {
17 check_lengths(predictions, targets)?;
18 if targets.is_empty() {
19 return Ok(0.0);
20 }
21 let mean = targets.iter().sum::<f32>() / targets.len() as f32;
22 let ss_tot: f32 = targets.iter().map(|t| (t - mean).powi(2)).sum();
23 if ss_tot == 0.0 {
24 let ss_res: f32 = predictions
26 .iter()
27 .zip(targets)
28 .map(|(p, t)| (p - t).powi(2))
29 .sum();
30 return Ok(if ss_res == 0.0 { 1.0 } else { 0.0 });
31 }
32 let ss_res: f32 = predictions
33 .iter()
34 .zip(targets)
35 .map(|(p, t)| (p - t).powi(2))
36 .sum();
37 Ok(1.0 - ss_res / ss_tot)
38}
39
40pub fn mae(predictions: &[f32], targets: &[f32]) -> Result<f32, EvalError> {
44 check_lengths(predictions, targets)?;
45 if targets.is_empty() {
46 return Ok(0.0);
47 }
48 let sum: f32 = predictions
49 .iter()
50 .zip(targets)
51 .map(|(p, t)| (p - t).abs())
52 .sum();
53 Ok(sum / targets.len() as f32)
54}
55
56pub fn rmse(predictions: &[f32], targets: &[f32]) -> Result<f32, EvalError> {
60 check_lengths(predictions, targets)?;
61 if targets.is_empty() {
62 return Ok(0.0);
63 }
64 let sum: f32 = predictions
65 .iter()
66 .zip(targets)
67 .map(|(p, t)| (p - t).powi(2))
68 .sum();
69 Ok((sum / targets.len() as f32).sqrt())
70}
71
72pub fn mape(predictions: &[f32], targets: &[f32]) -> Result<f32, EvalError> {
76 check_lengths(predictions, targets)?;
77 if targets.is_empty() {
78 return Ok(0.0);
79 }
80 let mut sum = 0.0f32;
81 let mut count = 0usize;
82 for (p, t) in predictions.iter().zip(targets) {
83 if *t == 0.0 {
84 continue;
85 }
86 sum += ((p - t) / t).abs();
87 count += 1;
88 }
89 if count == 0 {
90 return Ok(0.0);
91 }
92 Ok(sum / count as f32)
93}