/// Averages a per-sample error function over paired observations.
///
/// `f` is called once for every `(y_true[i], y_pred[i])` pair and the results
/// are averaged. If the slices differ in length, only the leading pairs
/// present in both are used.
///
/// # Returns
///
/// The arithmetic mean of `f` over the compared pairs, or `0.0` when there is
/// no pair to compare (either slice is empty).
fn compute_mean_metric<F>(y_true: &[f64], y_pred: &[f64], f: F) -> f64
where
    F: Fn(f64, f64) -> f64,
{
    // `zip` stops at the shorter slice, so the divisor has to be the number of
    // pairs actually visited rather than `y_true.len()`; otherwise a short
    // `y_pred` would understate the mean.
    let pair_count = y_true.len().min(y_pred.len());
    if pair_count == 0 {
        return 0.0;
    }

    let total: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(&yt, &yp)| f(yt, yp))
        .sum();

    total / pair_count as f64
}
/// Mean squared error between `y_true` and `y_pred`.
///
/// Each pair contributes `(y_true - y_pred)^2`; pairing the inputs and
/// averaging the contributions is delegated to `compute_mean_metric`.
pub fn compute_mse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let squared_error = |actual: f64, predicted: f64| {
        let residual = actual - predicted;
        residual.powi(2)
    };
    compute_mean_metric(y_true, y_pred, squared_error)
}
/// Root mean squared error between `y_true` and `y_pred`.
///
/// Defined as the square root of [`compute_mse`] on the same inputs.
pub fn compute_rmse(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let mean_squared_error = compute_mse(y_true, y_pred);
    mean_squared_error.sqrt()
}
/// Mean absolute error between `y_true` and `y_pred`.
///
/// Each pair contributes `|y_true - y_pred|`; pairing the inputs and
/// averaging the contributions is delegated to `compute_mean_metric`.
pub fn compute_mae(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let absolute_error = |actual: f64, predicted: f64| {
        let residual = actual - predicted;
        residual.abs()
    };
    compute_mean_metric(y_true, y_pred, absolute_error)
}
/// Coefficient of determination (R²) of `y_pred` with respect to `y_true`.
///
/// Computed as `1 - ss_res / ss_tot`, where `ss_res` is the sum of squared
/// residuals and `ss_tot` is the sum of squared deviations of the targets from
/// their mean. If the slices differ in length, only the leading samples
/// present in both are used, for every term of the formula.
///
/// # Returns
///
/// * `0.0` when fewer than two paired samples are available.
/// * When the targets are constant (`ss_tot == 0.0`): `1.0` if the predictions
///   match them exactly, `0.0` otherwise.
/// * Otherwise `1 - ss_res / ss_tot`, which is negative when the predictions
///   fit worse than the mean of the targets.
pub fn compute_r_squared(y_true: &[f64], y_pred: &[f64]) -> f64 {
    // Restrict both inputs to the samples that have a counterpart. Without
    // this, `ss_res` (zipped) and `ss_tot` (all of `y_true`) would be computed
    // over different sample sets, and an empty `y_pred` would score 1.0.
    let n = y_true.len().min(y_pred.len());
    if n <= 1 {
        return 0.0;
    }
    let (y_true, y_pred) = (&y_true[..n], &y_pred[..n]);

    let mean_true: f64 = y_true.iter().sum::<f64>() / n as f64;

    let ss_res: f64 = y_true
        .iter()
        .zip(y_pred)
        .map(|(yt, yp)| (yt - yp).powi(2))
        .sum();
    let ss_tot: f64 = y_true.iter().map(|yt| (yt - mean_true).powi(2)).sum();

    // Constant targets make the ratio undefined; fall back to an exact-match
    // check instead of dividing by zero.
    if ss_tot == 0.0 {
        return if ss_res == 0.0 { 1.0 } else { 0.0 };
    }

    1.0 - (ss_res / ss_tot)
}
/// Returns `(mean, standard deviation)` of `values`.
///
/// The standard deviation is the sample form: the squared deviations are
/// divided by `n - 1`.
///
/// # Returns
///
/// * `(0.0, 0.0)` for an empty slice.
/// * `(mean, 0.0)` for a single value, where the `n - 1` divisor would be zero.
/// * `(mean, sample standard deviation)` otherwise.
pub fn mean_std(values: &[f64]) -> (f64, f64) {
    let count = values.len();
    if count == 0 {
        return (0.0, 0.0);
    }

    let total: f64 = values.iter().sum();
    let mean = total / count as f64;

    match count {
        1 => (mean, 0.0),
        _ => {
            let squared_deviations: f64 = values
                .iter()
                .map(|value| (value - mean).powi(2))
                .sum();
            let sample_variance = squared_deviations / (count - 1) as f64;
            (mean, sample_variance.sqrt())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Targets and predictions shared by the basic MSE, RMSE and MAE tests.
    // Residuals are 0.5, -0.5, 0.0 and -1.0.
    const SAMPLE_TRUE: [f64; 4] = [3.0, -0.5, 2.0, 7.0];
    const SAMPLE_PRED: [f64; 4] = [2.5, 0.0, 2.0, 8.0];

    // Targets shared by the R² tests.
    const RAMP: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];

    // An empty input, typed explicitly so it can be borrowed as a slice.
    const NOTHING: [f64; 0] = [];

    #[test]
    fn test_compute_mse_basic() {
        // Squared residuals 0.25 + 0.25 + 0 + 1 average to 0.375.
        let result = compute_mse(&SAMPLE_TRUE, &SAMPLE_PRED);
        let deviation = (result - 0.375).abs();
        assert!(deviation < 1e-10);
    }

    #[test]
    fn test_compute_mse_perfect_prediction() {
        let observations = [1.0, 2.0, 3.0];
        assert_eq!(compute_mse(&observations, &observations), 0.0);
    }

    #[test]
    fn test_compute_mse_empty() {
        assert_eq!(compute_mse(&NOTHING, &NOTHING), 0.0);
    }

    #[test]
    fn test_compute_rmse_basic() {
        // sqrt(0.375) is approximately 0.61237.
        let result = compute_rmse(&SAMPLE_TRUE, &SAMPLE_PRED);
        let deviation = (result - 0.61237).abs();
        assert!(deviation < 1e-5);
    }

    #[test]
    fn test_compute_mae_basic() {
        // Absolute residuals 0.5 + 0.5 + 0 + 1 average to 0.5.
        assert_eq!(compute_mae(&SAMPLE_TRUE, &SAMPLE_PRED), 0.5);
    }

    #[test]
    fn test_compute_mae_empty() {
        assert_eq!(compute_mae(&NOTHING, &NOTHING), 0.0);
    }

    #[test]
    fn test_compute_r_squared_perfect_fit() {
        assert_eq!(compute_r_squared(&RAMP, &RAMP), 1.0);
    }

    #[test]
    fn test_compute_r_squared_good_fit() {
        let near_ramp = [1.1, 1.9, 3.1, 3.9, 5.1];
        let score = compute_r_squared(&RAMP, &near_ramp);
        assert!(score > 0.99);
    }

    #[test]
    fn test_compute_r_squared_negative() {
        // Predictions ten times too large fit worse than the target mean.
        let inflated = [10.0, 20.0, 30.0, 40.0, 50.0];
        let score = compute_r_squared(&RAMP, &inflated);
        assert!(score < 0.0);
    }

    #[test]
    fn test_compute_r_squared_constant_values() {
        // Constant targets reproduced exactly by the predictions.
        let flat = [5.0; 4];
        assert_eq!(compute_r_squared(&flat, &flat), 1.0);
    }

    #[test]
    fn test_compute_r_squared_empty() {
        assert_eq!(compute_r_squared(&NOTHING, &NOTHING), 0.0);
    }

    #[test]
    fn test_compute_r_squared_single_element() {
        let lone = [1.0];
        assert_eq!(compute_r_squared(&lone, &lone), 0.0);
    }

    #[test]
    fn test_mean_std_basic() {
        // Sample standard deviation of 1..=5 is sqrt(2.5), about 1.5811.
        let (average, spread) = mean_std(&RAMP);
        assert_eq!(average, 3.0);
        assert!((spread - 1.5811).abs() < 1e-4);
    }

    #[test]
    fn test_mean_std_constant_values() {
        let (average, spread) = mean_std(&[5.0; 5]);
        assert_eq!(average, 5.0);
        assert_eq!(spread, 0.0);
    }

    #[test]
    fn test_mean_std_empty() {
        let (average, spread) = mean_std(&NOTHING);
        assert_eq!(average, 0.0);
        assert_eq!(spread, 0.0);
    }

    #[test]
    fn test_mean_std_single_element() {
        let (average, spread) = mean_std(&[42.0]);
        assert_eq!(average, 42.0);
        assert_eq!(spread, 0.0);
    }

    #[test]
    fn test_metrics_consistency() {
        // RMSE must agree with the square root of MSE on the same data.
        let predictions = [1.2, 1.8, 3.1, 4.2, 4.8];
        let mean_squared = compute_mse(&RAMP, &predictions);
        let root_mean_squared = compute_rmse(&RAMP, &predictions);
        assert!((root_mean_squared - mean_squared.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_different_length_arrays() {
        // Every pair that can be compared matches exactly, so the error is zero.
        let longer = [1.0, 2.0, 3.0];
        let shorter = [1.0, 2.0];
        assert_eq!(compute_mse(&longer, &shorter), 0.0);
    }
}