Skip to main content

yscv_eval/
regression.rs

1use crate::error::EvalError;
2
3fn check_lengths(predictions: &[f32], targets: &[f32]) -> Result<(), EvalError> {
4    if predictions.len() != targets.len() {
5        return Err(EvalError::CountLengthMismatch {
6            ground_truth: targets.len(),
7            predictions: predictions.len(),
8        });
9    }
10    Ok(())
11}
12
13/// Coefficient of determination: 1 - SS_res / SS_tot.
14///
15/// Returns `0.0` for empty inputs (no data to explain).
16pub fn r2_score(predictions: &[f32], targets: &[f32]) -> Result<f32, EvalError> {
17    check_lengths(predictions, targets)?;
18    if targets.is_empty() {
19        return Ok(0.0);
20    }
21    let mean = targets.iter().sum::<f32>() / targets.len() as f32;
22    let ss_tot: f32 = targets.iter().map(|t| (t - mean).powi(2)).sum();
23    if ss_tot == 0.0 {
24        // All targets identical. If predictions match, perfect; otherwise undefined — return 0.
25        let ss_res: f32 = predictions
26            .iter()
27            .zip(targets)
28            .map(|(p, t)| (p - t).powi(2))
29            .sum();
30        return Ok(if ss_res == 0.0 { 1.0 } else { 0.0 });
31    }
32    let ss_res: f32 = predictions
33        .iter()
34        .zip(targets)
35        .map(|(p, t)| (p - t).powi(2))
36        .sum();
37    Ok(1.0 - ss_res / ss_tot)
38}
39
40/// Mean absolute error.
41///
42/// Returns `0.0` for empty inputs.
43pub fn mae(predictions: &[f32], targets: &[f32]) -> Result<f32, EvalError> {
44    check_lengths(predictions, targets)?;
45    if targets.is_empty() {
46        return Ok(0.0);
47    }
48    let sum: f32 = predictions
49        .iter()
50        .zip(targets)
51        .map(|(p, t)| (p - t).abs())
52        .sum();
53    Ok(sum / targets.len() as f32)
54}
55
56/// Root mean squared error.
57///
58/// Returns `0.0` for empty inputs.
59pub fn rmse(predictions: &[f32], targets: &[f32]) -> Result<f32, EvalError> {
60    check_lengths(predictions, targets)?;
61    if targets.is_empty() {
62        return Ok(0.0);
63    }
64    let sum: f32 = predictions
65        .iter()
66        .zip(targets)
67        .map(|(p, t)| (p - t).powi(2))
68        .sum();
69    Ok((sum / targets.len() as f32).sqrt())
70}
71
72/// Mean absolute percentage error, skipping pairs where the target is zero.
73///
74/// Returns `0.0` for empty inputs or when all targets are zero.
75pub fn mape(predictions: &[f32], targets: &[f32]) -> Result<f32, EvalError> {
76    check_lengths(predictions, targets)?;
77    if targets.is_empty() {
78        return Ok(0.0);
79    }
80    let mut sum = 0.0f32;
81    let mut count = 0usize;
82    for (p, t) in predictions.iter().zip(targets) {
83        if *t == 0.0 {
84            continue;
85        }
86        sum += ((p - t) / t).abs();
87        count += 1;
88    }
89    if count == 0 {
90        return Ok(0.0);
91    }
92    Ok(sum / count as f32)
93}