use scirs2_core::ndarray::ArrayView1;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::HashMap;
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
pub fn rand_score<F>(labels_true: ArrayView1<i32>, labels_pred: ArrayView1<i32>) -> Result<F>
where
F: Float + FromPrimitive + Debug,
{
if labels_true.len() != labels_pred.len() {
return Err(ClusteringError::InvalidInput(
"True and predicted labels must have the same length".to_string(),
));
}
let n = labels_true.len();
if n <= 1 {
return Ok(F::one());
}
let mut agreements = 0;
for i in 0..n {
for j in (i + 1)..n {
let same_true = labels_true[i] == labels_true[j];
let same_pred = labels_pred[i] == labels_pred[j];
if same_true == same_pred {
agreements += 1;
}
}
}
let total_pairs = n * (n - 1) / 2;
if total_pairs == 0 {
Ok(F::one())
} else {
Ok(F::from(agreements as f64 / total_pairs as f64).expect("Failed to convert to float"))
}
}
pub fn adjusted_rand_score<F>(
labels_true: ArrayView1<i32>,
labels_pred: ArrayView1<i32>,
) -> Result<F>
where
F: Float + FromPrimitive + Debug,
{
if labels_true.len() != labels_pred.len() {
return Err(ClusteringError::InvalidInput(
"True and predicted labels must have the same length".to_string(),
));
}
let n_samples = labels_true.len();
if n_samples <= 1 {
return Ok(F::one());
}
let contingency = build_contingency_table(labels_true, labels_pred);
let mut row_sums = HashMap::new();
let mut col_sums = HashMap::new();
for (&(i, j), &count) in &contingency {
*row_sums.entry(i).or_insert(0) += count;
*col_sums.entry(j).or_insert(0) += count;
}
let mut sum_comb_c = F::zero(); for &n_ij in contingency.values() {
if n_ij >= 2 {
sum_comb_c = sum_comb_c + F::from(combination(n_ij, 2)).expect("Operation failed");
}
}
let mut sum_comb_a = F::zero(); for &n_i in row_sums.values() {
if n_i >= 2 {
sum_comb_a = sum_comb_a + F::from(combination(n_i, 2)).expect("Operation failed");
}
}
let mut sum_comb_b = F::zero(); for &n_j in col_sums.values() {
if n_j >= 2 {
sum_comb_b = sum_comb_b + F::from(combination(n_j, 2)).expect("Operation failed");
}
}
let comb_n = F::from(combination(n_samples, 2)).expect("Operation failed");
let expected_index = (sum_comb_a * sum_comb_b) / comb_n;
let max_index =
(sum_comb_a + sum_comb_b) / F::from(2).expect("Failed to convert constant to float");
let denominator = max_index - expected_index;
if denominator.abs() < F::from(1e-10).expect("Failed to convert constant to float") {
Ok(F::zero())
} else {
Ok((sum_comb_c - expected_index) / denominator)
}
}
pub fn fowlkes_mallows_score<F>(
labels_true: ArrayView1<i32>,
labels_pred: ArrayView1<i32>,
) -> Result<F>
where
F: Float + FromPrimitive + Debug,
{
if labels_true.len() != labels_pred.len() {
return Err(ClusteringError::InvalidInput(
"True and predicted labels must have the same length".to_string(),
));
}
let n = labels_true.len();
if n <= 1 {
return Ok(F::one());
}
let mut tp = 0; let mut fp = 0; let mut fn_count = 0;
for i in 0..n {
for j in (i + 1)..n {
let same_true = labels_true[i] == labels_true[j];
let same_pred = labels_pred[i] == labels_pred[j];
match (same_true, same_pred) {
(true, true) => tp += 1,
(false, true) => fp += 1,
(true, false) => fn_count += 1,
(false, false) => {} }
}
}
if tp == 0 {
return Ok(F::zero());
}
let precision = tp as f64 / (tp + fp) as f64;
let recall = tp as f64 / (tp + fn_count) as f64;
if precision == 0.0 || recall == 0.0 {
Ok(F::zero())
} else {
Ok(F::from((precision * recall).sqrt()).expect("Operation failed"))
}
}
pub fn pair_confusion_matrix(
labels_true: ArrayView1<i32>,
labels_pred: ArrayView1<i32>,
) -> Result<(usize, usize, usize, usize)> {
if labels_true.len() != labels_pred.len() {
return Err(ClusteringError::InvalidInput(
"True and predicted labels must have the same length".to_string(),
));
}
let n = labels_true.len();
let mut tp = 0; let mut tn = 0; let mut fp = 0; let mut fn_count = 0;
for i in 0..n {
for j in (i + 1)..n {
let same_true = labels_true[i] == labels_true[j];
let same_pred = labels_pred[i] == labels_pred[j];
match (same_true, same_pred) {
(true, true) => tp += 1,
(false, false) => tn += 1,
(false, true) => fp += 1,
(true, false) => fn_count += 1,
}
}
}
Ok((tp, tn, fp, fn_count))
}
pub fn pair_precision_recall<F>(
labels_true: ArrayView1<i32>,
labels_pred: ArrayView1<i32>,
) -> Result<(F, F)>
where
F: Float + FromPrimitive + Debug,
{
let (tp, _tn, fp, fn_count) = pair_confusion_matrix(labels_true, labels_pred)?;
let precision = if tp + fp > 0 {
F::from(tp as f64 / (tp + fp) as f64).expect("Operation failed")
} else {
F::zero()
};
let recall = if tp + fn_count > 0 {
F::from(tp as f64 / (tp + fn_count) as f64).expect("Operation failed")
} else {
F::zero()
};
Ok((precision, recall))
}
fn build_contingency_table(
labels_true: ArrayView1<i32>,
labels_pred: ArrayView1<i32>,
) -> HashMap<(i32, i32), usize> {
let mut contingency = HashMap::new();
for (&true_label, &pred_label) in labels_true.iter().zip(labels_pred.iter()) {
*contingency.entry((true_label, pred_label)).or_insert(0) += 1;
}
contingency
}
fn combination(n: usize, k: usize) -> usize {
if k > n || k == 0 {
return if k == 0 { 1 } else { 0 };
}
if k == 1 {
return n;
}
if k == 2 {
return n * (n - 1) / 2;
}
let k = k.min(n - k); let mut result = 1;
for i in 0..k {
result = result * (n - i) / (i + 1);
}
result
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array1;
#[test]
fn test_rand_score_identical() {
let labels = Array1::from_vec(vec![0, 0, 1, 1, 2, 2]);
let rand: f64 = rand_score(labels.view(), labels.view()).expect("Operation failed");
assert!((rand - 1.0).abs() < 1e-10);
}
#[test]
fn test_adjusted_rand_score_identical() {
let labels = Array1::from_vec(vec![0, 0, 1, 1, 2, 2]);
let ari: f64 = adjusted_rand_score(labels.view(), labels.view()).expect("Operation failed");
assert!((ari - 1.0).abs() < 1e-10);
}
#[test]
fn test_adjusted_rand_score_random() {
let true_labels = Array1::from_vec(vec![0, 0, 1, 1, 2, 2]);
let pred_labels = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
let ari: f64 =
adjusted_rand_score(true_labels.view(), pred_labels.view()).expect("Operation failed");
assert!(ari >= -1.0 && ari <= 1.0);
}
#[test]
fn test_fowlkes_mallows_score() {
let true_labels = Array1::from_vec(vec![0, 0, 1, 1]);
let pred_labels = Array1::from_vec(vec![0, 1, 0, 1]);
let fm: f64 = fowlkes_mallows_score(true_labels.view(), pred_labels.view())
.expect("Operation failed");
assert!(fm >= 0.0 && fm <= 1.0);
}
#[test]
fn test_pair_confusion_matrix() {
let true_labels = Array1::from_vec(vec![0, 0, 1, 1]);
let pred_labels = Array1::from_vec(vec![0, 1, 0, 1]);
let (tp, tn, fp, fn_count) = pair_confusion_matrix(true_labels.view(), pred_labels.view())
.expect("Operation failed");
assert_eq!(tp + tn + fp + fn_count, 6);
}
#[test]
fn test_pair_precision_recall() {
let true_labels = Array1::from_vec(vec![0, 0, 1, 1]);
let pred_labels = Array1::from_vec(vec![0, 0, 1, 1]);
let (precision, recall): (f64, f64) =
pair_precision_recall(true_labels.view(), pred_labels.view())
.expect("Operation failed");
assert!((precision - 1.0).abs() < 1e-10);
assert!((recall - 1.0).abs() < 1e-10);
}
#[test]
fn test_combination() {
assert_eq!(combination(4, 2), 6);
assert_eq!(combination(5, 3), 10);
assert_eq!(combination(6, 0), 1);
assert_eq!(combination(6, 6), 1);
assert_eq!(combination(3, 5), 0); }
}