#![cfg(feature = "metric")]
use approx::assert_abs_diff_eq;
use ndarray::{Array1, array};
use rustyml::metric::*;
#[test]
fn test_root_mean_squared_error() {
let predictions = array![1.0, 2.0, 3.0, 4.0, 5.0];
let targets = array![1.0, 2.0, 3.0, 4.0, 5.0];
let result = root_mean_squared_error(&predictions.view(), &targets.view());
assert_abs_diff_eq!(result, 0.0);
let predictions = array![2.0, 3.0, 4.0, 5.0, 6.0];
let targets = array![1.0, 2.0, 3.0, 4.0, 5.0];
let result = root_mean_squared_error(&predictions.view(), &targets.view());
assert_abs_diff_eq!(result, 1.0);
let predictions = array![1.5, 2.5, 3.5, 4.5, 5.5];
let targets = array![1.0, 2.0, 3.0, 4.0, 5.0];
let result = root_mean_squared_error(&predictions.view(), &targets.view());
assert_abs_diff_eq!(result, 0.5);
let predictions = array![-1.0, -2.0, -3.0];
let targets = array![1.0, 2.0, 3.0];
let result = root_mean_squared_error(&predictions.view(), &targets.view());
assert_abs_diff_eq!(result, 4.320493798938574);
}
#[test]
fn test_rmse_empty_arrays() {
let empty: Array1<f64> = array![];
let rmse = root_mean_squared_error(&empty.view(), &empty.view());
assert_float_eq(rmse, 0.0);
}
#[test]
#[should_panic]
fn test_rmse_mismatched_lengths() {
let predictions = array![1.0, 2.0, 3.0];
let targets = array![1.0, 2.0];
root_mean_squared_error(&predictions.view(), &targets.view());
}
#[test]
fn test_mean_absolute_error() {
let predictions = array![1.0, 2.0, 3.0, 4.0, 5.0];
let targets = array![1.0, 2.0, 3.0, 4.0, 5.0];
let result = mean_absolute_error(&predictions.view(), &targets.view());
assert_abs_diff_eq!(result, 0.0);
let predictions = array![2.0, 3.0, 4.0, 5.0, 6.0];
let targets = array![1.0, 2.0, 3.0, 4.0, 5.0];
let result = mean_absolute_error(&predictions.view(), &targets.view());
assert_abs_diff_eq!(result, 1.0);
let predictions = array![1.0, 3.0, 2.0];
let targets = array![2.0, 1.0, 3.0];
let result = mean_absolute_error(&predictions.view(), &targets.view());
assert_abs_diff_eq!(result, 1.3333333333333333, epsilon = 1e-10);
let predictions = array![-1.0, -2.0, -3.0];
let targets = array![1.0, 2.0, 3.0];
let result = mean_absolute_error(&predictions.view(), &targets.view());
assert_abs_diff_eq!(result, 4.0);
let predictions = array![1.5, 2.5, 3.5];
let targets = array![1.0, 2.0, 3.0];
let result = mean_absolute_error(&predictions.view(), &targets.view());
assert_abs_diff_eq!(result, 0.5);
}
#[test]
fn test_mae_empty_arrays() {
let empty: Array1<f64> = array![];
let mae = mean_absolute_error(&empty.view(), &empty.view());
assert_float_eq(mae, 0.0);
}
#[test]
#[should_panic]
fn test_mae_mismatched_lengths() {
let predictions = array![1.0, 2.0, 3.0];
let targets = array![1.0, 2.0];
mean_absolute_error(&predictions.view(), &targets.view());
}
#[test]
fn test_r2_score() {
let actual = array![3.0, 2.0, 5.0, 7.0, 9.0];
let predicted = array![2.8, 1.9, 5.2, 7.5, 8.9];
let r2 = r2_score(&predicted.view(), &actual.view());
assert!(r2 > 0.95);
let perfect_actual = array![1.0, 2.0, 3.0];
let perfect_predicted = array![1.0, 2.0, 3.0];
assert_abs_diff_eq!(
r2_score(&perfect_predicted.view(), &perfect_actual.view()),
1.0
);
let mean_actual = array![1.0, 2.0, 3.0, 4.0, 5.0]; let mean_predicted = array![3.0, 3.0, 3.0, 3.0, 3.0];
assert_abs_diff_eq!(r2_score(&mean_predicted.view(), &mean_actual.view()), 0.0);
let same_actual = array![7.0, 7.0, 7.0];
let same_predicted = array![6.0, 7.0, 8.0];
assert_abs_diff_eq!(r2_score(&same_predicted.view(), &same_actual.view()), 0.0);
}
fn assert_float_eq(a: f64, b: f64) {
assert!((a - b).abs() < f64::EPSILON, "Expected {}, got {}", b, a);
}
#[test]
fn test_confusion_matrix_new() {
let pred = array![1.0, 0.0, 1.0, 0.0];
let actual = array![1.0, 0.0, 1.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_eq!(cm.get_counts(), (2, 0, 2, 0));
let pred = array![1.0, 1.0, 0.0, 0.0];
let actual = array![1.0, 0.0, 1.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_eq!(cm.get_counts(), (1, 1, 1, 1));
let pred = array![0.6, 0.4, 0.7, 0.3];
let actual = array![0.9, 0.1, 0.2, 0.8];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_eq!(cm.get_counts(), (1, 1, 1, 1));
}
#[test]
#[should_panic]
fn test_confusion_matrix_new_error() {
let pred = array![1.0, 0.0, 1.0];
let actual = array![1.0, 0.0, 1.0, 0.0];
let _result = ConfusionMatrix::new(&pred.view(), &actual.view());
}
#[test]
fn test_get_counts() {
let pred = array![1.0, 1.0, 0.0, 0.0];
let actual = array![1.0, 0.0, 1.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
let counts = cm.get_counts();
assert_eq!(counts, (1, 1, 1, 1));
}
#[test]
fn test_confusion_matrix_accuracy() {
let pred = array![1.0, 0.0, 1.0, 0.0];
let actual = array![1.0, 0.0, 1.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.accuracy(), 1.0);
let pred = array![1.0, 1.0, 0.0, 0.0];
let actual = array![1.0, 0.0, 1.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.accuracy(), 0.5);
let pred = array![1.0, 1.0, 1.0, 1.0];
let actual = array![0.0, 0.0, 0.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.accuracy(), 0.0);
}
#[test]
fn test_error_rate() {
let pred = array![1.0, 0.0, 1.0, 0.0];
let actual = array![1.0, 0.0, 1.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.error_rate(), 0.0);
let pred = array![1.0, 1.0, 0.0, 0.0];
let actual = array![1.0, 0.0, 1.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.error_rate(), 0.5);
let pred = array![1.0, 1.0, 1.0, 1.0];
let actual = array![0.0, 0.0, 0.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.error_rate(), 1.0);
}
#[test]
fn test_precision() {
let pred = array![1.0, 1.0, 0.0, 0.0];
let actual = array![1.0, 1.0, 0.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.precision(), 1.0);
let pred = array![1.0, 1.0, 1.0, 0.0];
let actual = array![1.0, 0.0, 1.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.precision(), 2.0 / 3.0);
let pred = array![1.0, 1.0, 1.0];
let actual = array![0.0, 0.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.precision(), 0.0);
}
#[test]
fn test_recall() {
let pred = array![1.0, 1.0, 0.0, 0.0];
let actual = array![1.0, 1.0, 0.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.recall(), 1.0);
let pred = array![1.0, 0.0, 0.0, 0.0];
let actual = array![1.0, 1.0, 0.0, 0.0];
let _cm = ConfusionMatrix::new(&pred.view(), &actual.view());
let pred = array![0.0, 0.0, 0.0];
let actual = array![0.0, 0.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.recall(), 1.0); }
#[test]
fn test_specificity() {
let pred = array![1.0, 1.0, 0.0, 0.0];
let actual = array![1.0, 1.0, 0.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.specificity(), 1.0);
let pred = array![1.0, 1.0, 0.0, 1.0];
let actual = array![1.0, 1.0, 0.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.specificity(), 0.5);
let pred = array![1.0, 1.0, 1.0];
let actual = array![1.0, 1.0, 1.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.specificity(), 1.0); }
#[test]
fn test_f1_score() {
let pred = array![1.0, 1.0, 0.0, 0.0];
let actual = array![1.0, 1.0, 0.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.f1_score(), 1.0);
let pred = array![1.0, 1.0, 0.0, 0.0];
let actual = array![1.0, 0.0, 0.0, 1.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.f1_score(), 0.5);
let pred = array![1.0, 1.0, 1.0];
let actual = array![0.0, 0.0, 0.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
assert_float_eq(cm.f1_score(), 0.0);
}
#[test]
fn test_summary() {
let pred = array![1.0, 1.0, 0.0, 0.0];
let actual = array![1.0, 0.0, 0.0, 1.0];
let cm = ConfusionMatrix::new(&pred.view(), &actual.view());
let summary = cm.summary();
assert!(summary.contains("TP:"));
assert!(summary.contains("FP:"));
assert!(summary.contains("TN:"));
assert!(summary.contains("FN:"));
assert!(summary.contains("Accuracy:"));
assert!(summary.contains("Error Rate:"));
assert!(summary.contains("Precision:"));
assert!(summary.contains("Recall:"));
assert!(summary.contains("Specificity:"));
assert!(summary.contains("F1 Score:"));
}
#[test]
fn test_accuracy() {
let predicted = array![1.0, 0.0, 1.0, 0.0, 1.0];
let actual = array![1.0, 0.0, 1.0, 0.0, 1.0];
assert_eq!(accuracy(&predicted.view(), &actual.view()), 1.0);
let predicted = array![1.0, 0.0, 1.0, 1.0, 0.0];
let actual = array![1.0, 0.0, 0.0, 1.0, 1.0];
assert_eq!(accuracy(&predicted.view(), &actual.view()), 0.6);
let predicted = array![0.0, 1.0, 2.0, 3.0, 2.0, 1.0];
let actual = array![0.0, 1.0, 2.0, 2.0, 1.0, 1.0];
assert_eq!(accuracy(&predicted.view(), &actual.view()), 2.0 / 3.0);
let predicted = array![0.0000001, 1.0];
let actual = array![0.0, 1.0];
assert!(accuracy(&predicted.view(), &actual.view()) < 1.0);
let predicted = array![1.0, 1.0, 0.0, 0.0];
let actual = array![0.0, 0.0, 1.0, 1.0];
assert_eq!(accuracy(&predicted.view(), &actual.view()), 0.0);
}
#[test]
#[should_panic]
fn test_accuracy_with_different_length_arrays() {
let predicted = array![1.0, 0.0, 1.0];
let actual = array![1.0, 0.0];
accuracy(&predicted.view(), &actual.view()); }
#[test]
fn test_identical_clusterings() {
let labels = array![0, 1, 1, 0, 2, 2];
let nmi = normalized_mutual_info(&labels.view(), &labels.view());
let ami = adjusted_mutual_info(&labels.view(), &labels.view());
assert!(
(nmi - 1.0).abs() < 1e-6,
"NMI of identical clusterings should be 1.0, got {}",
nmi
);
assert!(
(ami - 1.0).abs() < 1e-6,
"AMI of identical clusterings should be 1.0, got {}",
ami
);
}
#[test]
fn test_different_clusterings() {
let labels_true = array![0, 0, 1, 1, 2, 2];
let labels_pred = array![0, 0, 1, 2, 1, 2];
let nmi = normalized_mutual_info(&labels_true.view(), &labels_pred.view());
let expected_nmi = 0.578;
assert!(
(nmi - expected_nmi).abs() < 0.05,
"Expected NMI approx {} but got {}",
expected_nmi,
nmi
);
let ami = adjusted_mutual_info(&labels_true.view(), &labels_pred.view());
assert!(
ami < 1.0,
"AMI should be less than 1.0 for different clusterings, got {}",
ami
);
assert!(ami >= 0.0, "AMI should be non-negative, got {}", ami);
}
#[test]
#[should_panic]
fn test_length_mismatch() {
let labels_true = array![0, 1, 1];
let labels_pred = array![0, 1];
normalized_mutual_info(&labels_true.view(), &labels_pred.view());
}
#[test]
fn test_constant_labels() {
let labels_true = array![0, 0, 0, 0];
let labels_pred = array![1, 1, 1, 1];
let nmi = normalized_mutual_info(&labels_true.view(), &labels_pred.view());
assert_eq!(nmi, 0.0, "NMI for constant labels should be 0.0");
let ami = adjusted_mutual_info(&labels_true.view(), &labels_pred.view());
assert_eq!(ami, 1.0, "AMI for constant labels should be 1.0");
}
#[test]
fn test_calculate_auc_valid() {
let scores = array![0.1, 0.4, 0.35, 0.8];
let labels = array![false, true, true, false];
let auc = calculate_auc(&scores.view(), &labels.view());
assert!((auc - 0.5).abs() < 1e-12);
}
#[test]
#[should_panic]
fn test_calculate_auc_empty_input() {
let scores: Array1<f64> = array![];
let labels: Array1<bool> = array![];
let _result = calculate_auc(&scores.view(), &labels.view());
}
#[test]
#[should_panic]
fn test_calculate_auc_length_mismatch() {
let scores = array![0.1, 0.2];
let labels = array![true];
let _result = calculate_auc(&scores.view(), &labels.view());
}
#[test]
#[should_panic]
fn test_calculate_auc_no_positive() {
let scores = array![0.2, 0.3, 0.4];
let labels = array![false, false, false];
let _result = calculate_auc(&scores.view(), &labels.view());
}
#[test]
#[should_panic]
fn test_calculate_auc_no_negative() {
let scores = array![0.2, 0.3, 0.4];
let labels = array![true, true, true];
let _result = calculate_auc(&scores.view(), &labels.view());
}
#[test]
fn test_calculate_auc_with_ties() {
let scores = array![0.5, 0.5, 0.5, 0.5];
let labels = array![true, false, true, false];
let auc = calculate_auc(&scores.view(), &labels.view());
assert!((auc - 0.5).abs() < 1e-12);
}