use crate::error::{MetricsError, Result};
/// A regression benchmark: a prediction/target pair plus the values that
/// several regression metrics are expected to produce on it.
#[derive(Debug, Clone, PartialEq)]
pub struct MetricBenchmark {
    /// Identifier used to report and look up the case.
    pub name: &'static str,
    /// Model outputs, paired element-wise with `targets`.
    pub predictions: Vec<f64>,
    /// Ground-truth values; should have the same length as `predictions`.
    pub targets: Vec<f64>,
    /// Expected mean squared error of `predictions` against `targets`.
    pub expected_mse: f64,
    /// Expected mean absolute error of `predictions` against `targets`.
    pub expected_mae: f64,
    /// Expected coefficient of determination (R²).
    pub expected_r2: f64,
}
/// A classification benchmark: predicted and true class indices plus the
/// expected accuracy and macro-averaged precision/recall on them.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationBenchmark {
    /// Identifier used to report and look up the case.
    pub name: &'static str,
    /// Predicted class index for each sample.
    pub predictions: Vec<usize>,
    /// True class index for each sample; should match `predictions` in length.
    pub targets: Vec<usize>,
    /// Expected fraction of samples whose prediction equals the target.
    pub expected_accuracy: f64,
    /// Expected macro-averaged precision.
    pub expected_precision_macro: f64,
    /// Expected macro-averaged recall.
    pub expected_recall_macro: f64,
}
/// A ranking benchmark: relevance grades in ranked order plus the expected
/// NDCG and MAP for that ordering.
#[derive(Debug, Clone, PartialEq)]
pub struct RankingBenchmark {
    /// Identifier used to report and look up the case.
    pub name: &'static str,
    /// Relevance of each item, listed from the top-ranked item down.
    pub relevance_scores: Vec<f64>,
    /// Expected normalised discounted cumulative gain.
    pub expected_ndcg: f64,
    /// Expected (mean) average precision.
    pub expected_map: f64,
}
/// Outcome of comparing one metric evaluation with a benchmark's recorded
/// expectation, as produced by [`validate_regression_metric_detailed`].
#[derive(Debug, Clone, PartialEq)]
pub struct ValidationResult {
    /// Name of the benchmark that was evaluated.
    pub benchmark_name: String,
    /// Whether `difference` was within the requested tolerance.
    pub passed: bool,
    /// Value returned by the metric under test.
    pub actual: f64,
    /// Value recorded in the benchmark.
    pub expected: f64,
    /// Absolute difference `|actual - expected|`.
    pub difference: f64,
}
/// Builds the reference suite of regression benchmarks.
///
/// Each case pairs `predictions` with `targets` and records the MSE, MAE and
/// R² those vectors should produce. A few expectations are written as the
/// arithmetic that yields them rather than as decimals.
pub fn standard_benchmarks() -> Vec<MetricBenchmark> {
    // Bundles one case; the expectations tuple is `(mse, mae, r2)`.
    fn case(
        name: &'static str,
        predictions: Vec<f64>,
        targets: Vec<f64>,
        (expected_mse, expected_mae, expected_r2): (f64, f64, f64),
    ) -> MetricBenchmark {
        MetricBenchmark {
            name,
            predictions,
            targets,
            expected_mse,
            expected_mae,
            expected_r2,
        }
    }

    // Targets 1..=5 are shared by most cases (mean 3, total sum of squares 10).
    // Every transform applied to them below is exact in f64.
    let ramp = || vec![1.0, 2.0, 3.0, 4.0, 5.0];

    vec![
        case("perfect_prediction", ramp(), ramp(), (0.0, 0.0, 1.0)),
        // Errors -2..=2: MSE = 10/5, MAE = 6/5, SS_res = SS_tot so R² = 0.
        case(
            "constant_mean_prediction",
            vec![3.0; 5],
            ramp(),
            (2.0, 1.2, 0.0),
        ),
        // Every error is +1: SS_res = 5, R² = 1 - 5/10.
        case(
            "linear_offset_plus_one",
            ramp().iter().map(|t| t + 1.0).collect(),
            ramp(),
            (1.0, 1.0, 0.5),
        ),
        // Errors 1..=5: MSE = 55/5, MAE = 15/5, R² = 1 - 55/10.
        case(
            "scaled_double",
            ramp().iter().map(|t| t * 2.0).collect(),
            ramp(),
            (11.0, 3.0, -4.5),
        ),
        // |error| = 0.1 everywhere; targets 10..=50 step 10 give SS_tot = 1000.
        case(
            "small_symmetric_noise",
            vec![10.1, 19.9, 30.1, 39.9, 50.1],
            vec![10.0, 20.0, 30.0, 40.0, 50.0],
            (0.01, 0.1, 1.0 - 0.05 / 1000.0),
        ),
        // Errors 4, 2, 0, -2, -4: MSE = 40/5, MAE = 12/5, R² = 1 - 40/10.
        case(
            "inverse_prediction",
            ramp().into_iter().rev().collect(),
            ramp(),
            (8.0, 2.4, -3.0),
        ),
        // Tiny errors on huge values: SS_res = 3e-6, SS_tot = 2e12.
        case(
            "large_magnitude",
            vec![1e6 + 0.001, 2e6 + 0.001, 3e6 + 0.001],
            vec![1e6, 2e6, 3e6],
            (1e-6, 0.001, 1.0 - 3e-6 / 2e12),
        ),
        // Constant offset of 0.001: SS_res = 5e-6, SS_tot = 1e-5.
        case(
            "near_zero_values",
            vec![0.002, 0.003, 0.004, 0.005, 0.006],
            vec![0.001, 0.002, 0.003, 0.004, 0.005],
            (1e-6, 0.001, 0.5),
        ),
    ]
}
/// Builds the reference suite of classification benchmarks.
///
/// Labels are class indices. Each case records the accuracy and the
/// macro-averaged precision and recall that `predictions` should score
/// against `targets`.
pub fn classification_benchmarks() -> Vec<ClassificationBenchmark> {
    // Bundles one case; `scores` is `(accuracy, precision_macro, recall_macro)`.
    fn case(
        name: &'static str,
        predictions: Vec<usize>,
        targets: Vec<usize>,
        scores: (f64, f64, f64),
    ) -> ClassificationBenchmark {
        let (expected_accuracy, expected_precision_macro, expected_recall_macro) = scores;
        ClassificationBenchmark {
            name,
            predictions,
            targets,
            expected_accuracy,
            expected_precision_macro,
            expected_recall_macro,
        }
    }

    // Three classes, each appearing twice.
    let three_classes = || vec![0, 1, 2, 0, 1, 2];

    vec![
        case(
            "perfect_classification",
            three_classes(),
            three_classes(),
            (1.0, 1.0, 1.0),
        ),
        // No position matches, so every per-class true-positive count is zero.
        case(
            "all_wrong",
            vec![1, 0, 0, 1, 0, 0],
            three_classes(),
            (0.0, 0.0, 0.0),
        ),
        // One false positive and one false negative per class: 2 of 3 right on each side.
        case(
            "binary_balanced",
            vec![0, 0, 1, 0, 1, 1],
            vec![0, 0, 0, 1, 1, 1],
            (4.0 / 6.0, 2.0 / 3.0, 2.0 / 3.0),
        ),
        case("single_class", vec![0; 4], vec![0; 4], (1.0, 1.0, 1.0)),
        // First half correct; second half swaps 0<->1 and 2<->3.
        case(
            "multiclass_half_correct",
            vec![0, 1, 2, 3, 1, 0, 3, 2],
            vec![0, 1, 2, 3, 0, 1, 2, 3],
            (0.5, 0.5, 0.5),
        ),
    ]
}
/// Builds the reference suite of ranking benchmarks.
///
/// `relevance_scores` holds graded relevance in ranked order, with index 0 as
/// the top-ranked item. The NDCG expectations use linear gains discounted by
/// `log2(position + 1)` (1-based positions); the MAP expectations count every
/// non-zero score as a relevant item.
pub fn ranking_benchmarks() -> Vec<RankingBenchmark> {
    // DCG of `gains` listed in ranked order: the gain at 0-based index `i` is
    // divided by log2(i + 2).
    fn dcg(gains: &[f64]) -> f64 {
        gains
            .iter()
            .enumerate()
            .map(|(i, &gain)| gain / ((i + 2) as f64).log2())
            .sum()
    }

    // NDCG of `ranked`, normalised by the DCG of its ideal ordering `ideal`.
    fn ndcg(ranked: &[f64], ideal: &[f64]) -> f64 {
        dcg(ranked) / dcg(ideal)
    }

    fn case(
        name: &'static str,
        relevance_scores: Vec<f64>,
        expected_ndcg: f64,
        expected_map: f64,
    ) -> RankingBenchmark {
        RankingBenchmark {
            name,
            relevance_scores,
            expected_ndcg,
            expected_map,
        }
    }

    let reversed = [1.0, 2.0, 3.0];
    let single_second = [0.0, 1.0, 0.0, 0.0];
    let alternating = [1.0, 0.0, 1.0, 0.0, 1.0];

    vec![
        case("perfect_descending", vec![3.0, 2.0, 1.0], 1.0, 1.0),
        // Every item is relevant, so average precision is 1 regardless of order.
        case(
            "reverse_ranking",
            reversed.to_vec(),
            ndcg(&reversed, &[3.0, 2.0, 1.0]),
            1.0,
        ),
        // The lone relevant item is at position 2: AP = 1/2.
        case(
            "single_relevant_second",
            single_second.to_vec(),
            ndcg(&single_second, &[1.0, 0.0, 0.0, 0.0]),
            0.5,
        ),
        // No relevant items: both metrics are recorded as 0.
        case("no_relevant_items", vec![0.0, 0.0, 0.0], 0.0, 0.0),
        // Relevant at positions 1, 3, 5, giving precisions 1/1, 2/3, 3/5.
        case(
            "binary_alternating",
            alternating.to_vec(),
            ndcg(&alternating, &[1.0, 1.0, 1.0, 0.0, 0.0]),
            (1.0 + 2.0 / 3.0 + 3.0 / 5.0) / 3.0,
        ),
    ]
}
/// Runs `metric_fn` over every regression benchmark and compares its output
/// with the benchmark's recorded value for `expected_field`.
///
/// # Arguments
/// * `metric_fn` - metric under test, called as `metric_fn(predictions, targets)`.
/// * `benchmarks` - cases to evaluate, in order.
/// * `expected_field` - recorded value to compare against: `"mse"`, `"mae"` or `"r2"`.
/// * `tolerance` - largest accepted absolute difference; `f64::INFINITY` accepts any
///   finite difference.
///
/// # Returns
/// One `(benchmark name, passed, |actual - expected|)` tuple per benchmark, in input
/// order. A metric that returns NaN yields a NaN difference and is reported as failed.
///
/// # Errors
/// `MetricsError::InvalidArgument` if `expected_field` is not recognised, or if
/// `tolerance` is NaN or negative (either would make every benchmark fail silently).
pub fn validate_regression_metric<F: Fn(&[f64], &[f64]) -> f64>(
    metric_fn: F,
    benchmarks: &[MetricBenchmark],
    expected_field: &str,
    tolerance: f64,
) -> Result<Vec<(String, bool, f64)>> {
    // Resolve the field name once so the per-benchmark loop needs no fallback arm.
    let expected_of: fn(&MetricBenchmark) -> f64 = match expected_field {
        "mse" => |b: &MetricBenchmark| b.expected_mse,
        "mae" => |b: &MetricBenchmark| b.expected_mae,
        "r2" => |b: &MetricBenchmark| b.expected_r2,
        other => {
            return Err(MetricsError::InvalidArgument(format!(
                "expected_field must be 'mse', 'mae', or 'r2', got '{}'",
                other
            )));
        }
    };
    if tolerance.is_nan() || tolerance < 0.0 {
        return Err(MetricsError::InvalidArgument(format!(
            "tolerance must be a non-negative number, got {}",
            tolerance
        )));
    }
    Ok(benchmarks
        .iter()
        .map(|bench| {
            let actual = metric_fn(&bench.predictions, &bench.targets);
            let diff = (actual - expected_of(bench)).abs();
            (bench.name.to_string(), diff <= tolerance, diff)
        })
        .collect())
}
/// Like `validate_regression_metric`, but reports the actual and expected
/// values alongside the pass/fail verdict for each benchmark.
///
/// # Arguments
/// * `metric_fn` - metric under test, called as `metric_fn(predictions, targets)`.
/// * `benchmarks` - cases to evaluate, in order.
/// * `expected_field` - recorded value to compare against: `"mse"`, `"mae"` or `"r2"`.
/// * `tolerance` - largest accepted absolute difference; `f64::INFINITY` accepts any
///   finite difference.
///
/// # Returns
/// One [`ValidationResult`] per benchmark, in input order. A metric that returns
/// NaN yields a NaN `difference` and `passed == false`.
///
/// # Errors
/// `MetricsError::InvalidArgument` if `expected_field` is not recognised, or if
/// `tolerance` is NaN or negative (either would make every benchmark fail silently).
pub fn validate_regression_metric_detailed<F: Fn(&[f64], &[f64]) -> f64>(
    metric_fn: F,
    benchmarks: &[MetricBenchmark],
    expected_field: &str,
    tolerance: f64,
) -> Result<Vec<ValidationResult>> {
    // Resolve the field name once so the per-benchmark loop needs no fallback arm.
    let expected_of: fn(&MetricBenchmark) -> f64 = match expected_field {
        "mse" => |b: &MetricBenchmark| b.expected_mse,
        "mae" => |b: &MetricBenchmark| b.expected_mae,
        "r2" => |b: &MetricBenchmark| b.expected_r2,
        other => {
            return Err(MetricsError::InvalidArgument(format!(
                "expected_field must be 'mse', 'mae', or 'r2', got '{}'",
                other
            )));
        }
    };
    if tolerance.is_nan() || tolerance < 0.0 {
        return Err(MetricsError::InvalidArgument(format!(
            "tolerance must be a non-negative number, got {}",
            tolerance
        )));
    }
    Ok(benchmarks
        .iter()
        .map(|bench| {
            let actual = metric_fn(&bench.predictions, &bench.targets);
            let expected = expected_of(bench);
            let difference = (actual - expected).abs();
            ValidationResult {
                benchmark_name: bench.name.to_string(),
                passed: difference <= tolerance,
                actual,
                expected,
                difference,
            }
        })
        .collect())
}
/// Runs `metric_fn` over every classification benchmark and compares its output
/// with the benchmark's recorded value for `expected_field`.
///
/// # Arguments
/// * `metric_fn` - metric under test, called as `metric_fn(predictions, targets)`.
/// * `benchmarks` - cases to evaluate, in order.
/// * `expected_field` - recorded value to compare against: `"accuracy"`,
///   `"precision_macro"` or `"recall_macro"`.
/// * `tolerance` - largest accepted absolute difference; `f64::INFINITY` accepts any
///   finite difference.
///
/// # Returns
/// One `(benchmark name, passed, |actual - expected|)` tuple per benchmark, in input
/// order. A metric that returns NaN yields a NaN difference and is reported as failed.
///
/// # Errors
/// `MetricsError::InvalidArgument` if `expected_field` is not recognised, or if
/// `tolerance` is NaN or negative (either would make every benchmark fail silently).
pub fn validate_classification_metric<F: Fn(&[usize], &[usize]) -> f64>(
    metric_fn: F,
    benchmarks: &[ClassificationBenchmark],
    expected_field: &str,
    tolerance: f64,
) -> Result<Vec<(String, bool, f64)>> {
    // Resolve the field name once so the per-benchmark loop needs no fallback arm.
    let expected_of: fn(&ClassificationBenchmark) -> f64 = match expected_field {
        "accuracy" => |b: &ClassificationBenchmark| b.expected_accuracy,
        "precision_macro" => |b: &ClassificationBenchmark| b.expected_precision_macro,
        "recall_macro" => |b: &ClassificationBenchmark| b.expected_recall_macro,
        other => {
            return Err(MetricsError::InvalidArgument(format!(
                "expected_field must be 'accuracy', 'precision_macro', or 'recall_macro', got '{}'",
                other
            )));
        }
    };
    if tolerance.is_nan() || tolerance < 0.0 {
        return Err(MetricsError::InvalidArgument(format!(
            "tolerance must be a non-negative number, got {}",
            tolerance
        )));
    }
    Ok(benchmarks
        .iter()
        .map(|bench| {
            let actual = metric_fn(&bench.predictions, &bench.targets);
            let diff = (actual - expected_of(bench)).abs();
            (bench.name.to_string(), diff <= tolerance, diff)
        })
        .collect())
}
/// Runs `metric_fn` over every ranking benchmark and compares its output with
/// the benchmark's recorded value for `expected_field`.
///
/// # Arguments
/// * `metric_fn` - metric under test, called with the benchmark's relevance scores.
/// * `benchmarks` - cases to evaluate, in order.
/// * `expected_field` - recorded value to compare against: `"ndcg"` or `"map"`.
/// * `tolerance` - largest accepted absolute difference; `f64::INFINITY` accepts any
///   finite difference.
///
/// # Returns
/// One `(benchmark name, passed, |actual - expected|)` tuple per benchmark, in input
/// order. A metric that returns NaN yields a NaN difference and is reported as failed.
///
/// # Errors
/// `MetricsError::InvalidArgument` if `expected_field` is not recognised, or if
/// `tolerance` is NaN or negative (either would make every benchmark fail silently).
pub fn validate_ranking_metric<F: Fn(&[f64]) -> f64>(
    metric_fn: F,
    benchmarks: &[RankingBenchmark],
    expected_field: &str,
    tolerance: f64,
) -> Result<Vec<(String, bool, f64)>> {
    // Resolve the field name once so the per-benchmark loop needs no fallback arm.
    let expected_of: fn(&RankingBenchmark) -> f64 = match expected_field {
        "ndcg" => |b: &RankingBenchmark| b.expected_ndcg,
        "map" => |b: &RankingBenchmark| b.expected_map,
        other => {
            return Err(MetricsError::InvalidArgument(format!(
                "expected_field must be 'ndcg' or 'map', got '{}'",
                other
            )));
        }
    };
    if tolerance.is_nan() || tolerance < 0.0 {
        return Err(MetricsError::InvalidArgument(format!(
            "tolerance must be a non-negative number, got {}",
            tolerance
        )));
    }
    Ok(benchmarks
        .iter()
        .map(|bench| {
            let actual = metric_fn(&bench.relevance_scores);
            let diff = (actual - expected_of(bench)).abs();
            (bench.name.to_string(), diff <= tolerance, diff)
        })
        .collect())
}
/// Checks that every regression benchmark is internally well-formed.
///
/// A benchmark is rejected when:
/// * its prediction and target vectors differ in length or are empty;
/// * any recorded expectation is not finite;
/// * any prediction or target value is not finite;
/// * an expectation lies outside its metric's range. MSE and MAE are means of
///   non-negative terms, and R² = 1 - SS_res / SS_tot cannot exceed 1.
///
/// # Errors
/// `MetricsError::InvalidInput` describing the first problem found, prefixed with
/// the benchmark's name.
pub fn check_benchmark_consistency(benchmarks: &[MetricBenchmark]) -> Result<()> {
    for bench in benchmarks {
        // Every message is prefixed with the offending benchmark's name.
        let invalid = |detail: String| {
            MetricsError::InvalidInput(format!("Benchmark '{}': {}", bench.name, detail))
        };
        if bench.predictions.len() != bench.targets.len() {
            return Err(invalid(format!(
                "predictions length ({}) != targets length ({})",
                bench.predictions.len(),
                bench.targets.len()
            )));
        }
        if bench.predictions.is_empty() {
            return Err(invalid("empty predictions/targets".to_string()));
        }
        let expected = [
            ("expected_mse", bench.expected_mse),
            ("expected_mae", bench.expected_mae),
            ("expected_r2", bench.expected_r2),
        ];
        for (field, value) in expected {
            if !value.is_finite() {
                return Err(invalid(format!("{} is not finite", field)));
            }
        }
        // NaN or infinite inputs would make every recorded expectation meaningless.
        for (label, values) in [
            ("predictions", &bench.predictions),
            ("targets", &bench.targets),
        ] {
            if let Some(index) = values.iter().position(|v| !v.is_finite()) {
                return Err(invalid(format!("{}[{}] is not finite", label, index)));
            }
        }
        for (field, value) in [
            ("expected_mse", bench.expected_mse),
            ("expected_mae", bench.expected_mae),
        ] {
            if value < 0.0 {
                return Err(invalid(format!(
                    "{} must be non-negative, got {}",
                    field, value
                )));
            }
        }
        if bench.expected_r2 > 1.0 {
            return Err(invalid(format!(
                "expected_r2 must not exceed 1, got {}",
                bench.expected_r2
            )));
        }
    }
    Ok(())
}
/// Checks that every classification benchmark is internally well-formed.
///
/// A benchmark is rejected when:
/// * its prediction and label vectors differ in length or are empty;
/// * any expected score is not finite;
/// * any expected score lies outside `[0, 1]`. Accuracy, precision and recall
///   are all ratios of counts.
///
/// # Errors
/// `MetricsError::InvalidInput` describing the first problem found, prefixed with
/// the benchmark's name.
pub fn check_classification_benchmark_consistency(
    benchmarks: &[ClassificationBenchmark],
) -> Result<()> {
    for bench in benchmarks {
        // Every message is prefixed with the offending benchmark's name.
        let invalid = |detail: String| {
            MetricsError::InvalidInput(format!("Benchmark '{}': {}", bench.name, detail))
        };
        if bench.predictions.len() != bench.targets.len() {
            return Err(invalid(format!(
                "predictions length ({}) != targets length ({})",
                bench.predictions.len(),
                bench.targets.len()
            )));
        }
        if bench.predictions.is_empty() {
            return Err(invalid("empty predictions/targets".to_string()));
        }
        let expected = [
            ("expected_accuracy", bench.expected_accuracy),
            ("expected_precision_macro", bench.expected_precision_macro),
            ("expected_recall_macro", bench.expected_recall_macro),
        ];
        if expected.iter().any(|(_, value)| !value.is_finite()) {
            return Err(invalid("expected metric values must be finite".to_string()));
        }
        for (field, value) in expected {
            if !(0.0..=1.0).contains(&value) {
                return Err(invalid(format!(
                    "{} must lie in [0, 1], got {}",
                    field, value
                )));
            }
        }
    }
    Ok(())
}
/// Checks that every ranking benchmark is internally well-formed.
///
/// A benchmark is rejected when:
/// * it has no relevance scores;
/// * an expected value is not finite;
/// * a relevance score is not finite;
/// * expected NDCG or MAP lies outside `[0, 1]`. MAP is a mean of precisions;
///   for NDCG the bound assumes non-negative relevance grades, as in the
///   built-in suite.
///
/// # Errors
/// `MetricsError::InvalidInput` describing the first problem found, prefixed with
/// the benchmark's name.
pub fn check_ranking_benchmark_consistency(benchmarks: &[RankingBenchmark]) -> Result<()> {
    for bench in benchmarks {
        // Every message is prefixed with the offending benchmark's name.
        let invalid = |detail: String| {
            MetricsError::InvalidInput(format!("Benchmark '{}': {}", bench.name, detail))
        };
        if bench.relevance_scores.is_empty() {
            return Err(invalid("empty relevance scores".to_string()));
        }
        if !bench.expected_ndcg.is_finite() || !bench.expected_map.is_finite() {
            return Err(invalid("expected metric values must be finite".to_string()));
        }
        if let Some(index) = bench.relevance_scores.iter().position(|s| !s.is_finite()) {
            return Err(invalid(format!(
                "relevance_scores[{}] is not finite",
                index
            )));
        }
        for (field, value) in [
            ("expected_ndcg", bench.expected_ndcg),
            ("expected_map", bench.expected_map),
        ] {
            if !(0.0..=1.0).contains(&value) {
                return Err(invalid(format!(
                    "{} must lie in [0, 1], got {}",
                    field, value
                )));
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // Independent reference implementations, used to cross-check every recorded
    // expectation in the benchmark suites rather than only the trivial cases.

    fn reference_mse(preds: &[f64], targets: &[f64]) -> f64 {
        let n = preds.len() as f64;
        preds
            .iter()
            .zip(targets)
            .map(|(p, t)| (p - t).powi(2))
            .sum::<f64>()
            / n
    }

    fn reference_mae(preds: &[f64], targets: &[f64]) -> f64 {
        let n = preds.len() as f64;
        preds
            .iter()
            .zip(targets)
            .map(|(p, t)| (p - t).abs())
            .sum::<f64>()
            / n
    }

    fn reference_r2(preds: &[f64], targets: &[f64]) -> f64 {
        let n = targets.len() as f64;
        let mean = targets.iter().sum::<f64>() / n;
        let ss_res: f64 = preds.iter().zip(targets).map(|(p, t)| (t - p).powi(2)).sum();
        let ss_tot: f64 = targets.iter().map(|t| (t - mean).powi(2)).sum();
        1.0 - ss_res / ss_tot
    }

    fn reference_accuracy(preds: &[usize], targets: &[usize]) -> f64 {
        let correct = preds.iter().zip(targets).filter(|(p, t)| p == t).count();
        correct as f64 / preds.len() as f64
    }

    /// Macro-averaged precision (`by_prediction == true`) or recall (`false`) over the
    /// union of labels in both vectors; a class with an empty denominator scores 0.
    fn reference_macro(preds: &[usize], targets: &[usize], by_prediction: bool) -> f64 {
        let mut labels: Vec<usize> = preds.iter().chain(targets).copied().collect();
        labels.sort_unstable();
        labels.dedup();
        let denominators = if by_prediction { preds } else { targets };
        let total: f64 = labels
            .iter()
            .map(|&label| {
                let tp = preds
                    .iter()
                    .zip(targets)
                    .filter(|&(&p, &t)| p == label && t == label)
                    .count();
                let support = denominators.iter().filter(|&&x| x == label).count();
                if support == 0 {
                    0.0
                } else {
                    tp as f64 / support as f64
                }
            })
            .sum();
        total / labels.len() as f64
    }

    /// NDCG with linear gains; 0 when the ideal DCG is 0 (no relevant items).
    fn reference_ndcg(scores: &[f64]) -> f64 {
        let dcg_of = |gains: &[f64]| -> f64 {
            gains
                .iter()
                .enumerate()
                .map(|(i, &rel)| rel / ((i + 2) as f64).log2())
                .sum()
        };
        let mut ideal = scores.to_vec();
        ideal.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
        let idcg = dcg_of(&ideal);
        if idcg == 0.0 {
            0.0
        } else {
            dcg_of(scores) / idcg
        }
    }

    /// Average precision treating every score above zero as relevant; 0 if none are.
    fn reference_average_precision(scores: &[f64]) -> f64 {
        let mut hits = 0usize;
        let mut precision_sum = 0.0;
        for (rank, &rel) in scores.iter().enumerate() {
            if rel > 0.0 {
                hits += 1;
                precision_sum += hits as f64 / (rank + 1) as f64;
            }
        }
        if hits == 0 {
            0.0
        } else {
            precision_sum / hits as f64
        }
    }

    #[test]
    fn test_standard_benchmarks_not_empty() {
        let benchmarks = standard_benchmarks();
        assert!(!benchmarks.is_empty());
        assert!(benchmarks.len() >= 8, "Expected at least 8 benchmarks");
    }

    #[test]
    fn test_perfect_prediction_mse_zero() {
        let benchmarks = standard_benchmarks();
        let perfect = benchmarks
            .iter()
            .find(|b| b.name == "perfect_prediction")
            .expect("perfect_prediction benchmark should exist");
        assert!((perfect.expected_mse - 0.0).abs() < 1e-15);
        assert!((perfect.expected_mae - 0.0).abs() < 1e-15);
        assert!((perfect.expected_r2 - 1.0).abs() < 1e-15);
    }

    #[test]
    fn test_constant_prediction_r2_zero() {
        let benchmarks = standard_benchmarks();
        let constant = benchmarks
            .iter()
            .find(|b| b.name == "constant_mean_prediction")
            .expect("constant_mean_prediction benchmark should exist");
        assert!(
            (constant.expected_r2 - 0.0).abs() < 1e-15,
            "R2 for constant mean prediction should be 0"
        );
        assert!(
            (constant.expected_mse - 2.0).abs() < 1e-15,
            "MSE should be 2.0"
        );
    }

    #[test]
    fn test_classification_benchmarks_not_empty() {
        let benchmarks = classification_benchmarks();
        assert!(!benchmarks.is_empty());
        assert!(benchmarks.len() >= 4);
    }

    #[test]
    fn test_validate_regression_metric_perfect() {
        let benchmarks = standard_benchmarks();
        let results = validate_regression_metric(reference_mse, &benchmarks, "mse", 1e-8)
            .expect("validation should succeed");
        assert!(!results.is_empty());
        let perfect_result = results
            .iter()
            .find(|(name, _, _)| name == "perfect_prediction")
            .expect("should find perfect_prediction");
        assert!(perfect_result.1, "Perfect prediction MSE should pass");
    }

    #[test]
    fn test_regression_expectations_match_reference_metrics() {
        let benchmarks = standard_benchmarks();
        let metrics: [(&str, fn(&[f64], &[f64]) -> f64); 3] = [
            ("mse", reference_mse),
            ("mae", reference_mae),
            ("r2", reference_r2),
        ];
        for (field, metric) in metrics {
            let results = validate_regression_metric(metric, &benchmarks, field, 1e-9)
                .expect("field name is valid");
            for (name, passed, diff) in results {
                assert!(passed, "{} on '{}' is off by {}", field, name, diff);
            }
        }
    }

    #[test]
    fn test_validate_regression_metric_with_known_bad() {
        let benchmarks = standard_benchmarks();
        let bad_fn = |_preds: &[f64], _targets: &[f64]| -> f64 { 42.0 };
        let results = validate_regression_metric(bad_fn, &benchmarks, "mse", 1e-10)
            .expect("validation should succeed");
        let failures: Vec<_> = results.iter().filter(|(_, passed, _)| !passed).collect();
        assert!(
            !failures.is_empty(),
            "A bad metric should fail some benchmarks"
        );
    }

    #[test]
    fn test_validate_regression_metric_invalid_field() {
        let benchmarks = standard_benchmarks();
        let f = |_: &[f64], _: &[f64]| -> f64 { 0.0 };
        let result = validate_regression_metric(f, &benchmarks, "invalid", 1e-10);
        assert!(result.is_err());
    }

    #[test]
    fn test_ranking_benchmarks_not_empty() {
        let benchmarks = ranking_benchmarks();
        assert!(!benchmarks.is_empty());
        assert!(benchmarks.len() >= 4);
    }

    #[test]
    fn test_benchmark_consistency() {
        let reg = standard_benchmarks();
        assert!(check_benchmark_consistency(&reg).is_ok());
        let cls = classification_benchmarks();
        assert!(check_classification_benchmark_consistency(&cls).is_ok());
        let rank = ranking_benchmarks();
        assert!(check_ranking_benchmark_consistency(&rank).is_ok());
    }

    #[test]
    fn test_benchmark_consistency_catches_length_mismatch() {
        let bad_bench = vec![MetricBenchmark {
            name: "bad",
            predictions: vec![1.0, 2.0],
            targets: vec![1.0],
            expected_mse: 0.0,
            expected_mae: 0.0,
            expected_r2: 0.0,
        }];
        assert!(check_benchmark_consistency(&bad_bench).is_err());
    }

    #[test]
    fn test_validate_classification_metric() {
        let benchmarks = classification_benchmarks();
        let results =
            validate_classification_metric(reference_accuracy, &benchmarks, "accuracy", 1e-10)
                .expect("validation should succeed");
        assert!(!results.is_empty());
        let perfect = results
            .iter()
            .find(|(name, _, _)| name == "perfect_classification")
            .expect("should find perfect_classification");
        assert!(perfect.1, "Perfect classification accuracy should pass");
    }

    #[test]
    fn test_classification_expectations_match_reference_metrics() {
        let benchmarks = classification_benchmarks();
        let precision = |p: &[usize], t: &[usize]| reference_macro(p, t, true);
        let recall = |p: &[usize], t: &[usize]| reference_macro(p, t, false);
        let checks = [
            validate_classification_metric(reference_accuracy, &benchmarks, "accuracy", 1e-12),
            validate_classification_metric(precision, &benchmarks, "precision_macro", 1e-12),
            validate_classification_metric(recall, &benchmarks, "recall_macro", 1e-12),
        ];
        for results in checks {
            for (name, passed, diff) in results.expect("field name is valid") {
                assert!(passed, "'{}' is off by {}", name, diff);
            }
        }
    }

    #[test]
    fn test_validate_ranking_metric() {
        let benchmarks = ranking_benchmarks();
        let results = validate_ranking_metric(reference_ndcg, &benchmarks, "ndcg", 1e-10)
            .expect("validation should succeed");
        assert!(!results.is_empty());
        let perfect = results
            .iter()
            .find(|(name, _, _)| name == "perfect_descending")
            .expect("should find perfect_descending");
        assert!(perfect.1, "Perfect descending should have NDCG = 1.0");
        for (name, passed, diff) in &results {
            assert!(*passed, "NDCG on '{}' is off by {}", name, diff);
        }
    }

    #[test]
    fn test_ranking_map_expectations_match_reference() {
        let benchmarks = ranking_benchmarks();
        let results =
            validate_ranking_metric(reference_average_precision, &benchmarks, "map", 1e-12)
                .expect("validation should succeed");
        for (name, passed, diff) in results {
            assert!(passed, "MAP on '{}' is off by {}", name, diff);
        }
    }

    #[test]
    fn test_detailed_validation() {
        let benchmarks = standard_benchmarks();
        let results =
            validate_regression_metric_detailed(reference_mse, &benchmarks, "mse", 1e-8)
                .expect("should succeed");
        assert!(!results.is_empty());
        for result in &results {
            assert!(result.difference.is_finite());
            assert!(result.expected.is_finite());
            assert!(result.actual.is_finite());
            assert!(!result.benchmark_name.is_empty());
        }
    }
}