use crate::error::{ClusterError, ClusterResult};
use std::collections::{HashMap, HashSet};
use torsh_tensor::Tensor;
use super::utils::combinations;
/// Computes the Adjusted Rand Index (ARI) between a reference and a predicted labelling.
///
/// The index counts sample pairs that both labellings place in the same cluster and
/// corrects that count for the agreement expected by chance, given the cluster sizes on
/// each side. It returns `1.0` when the two partitions are identical up to a renaming of
/// labels. Values can be negative when agreement is below the chance expectation.
///
/// Label values are converted with `as i32`, so for floating-point tensors any fractional
/// part is discarded before labels are compared.
///
/// Inputs with zero or one sample return `1.0`, because there are no pairs to disagree on.
///
/// # Errors
///
/// Returns [`ClusterError::TensorError`] if either tensor cannot be converted to a vector,
/// and [`ClusterError::InvalidInput`] if the two label vectors differ in length.
pub fn adjusted_rand_score(labels_true: &Tensor, labels_pred: &Tensor) -> ClusterResult<f64> {
    let true_vec = labels_true.to_vec().map_err(ClusterError::TensorError)?;
    let pred_vec = labels_pred.to_vec().map_err(ClusterError::TensorError)?;
    if true_vec.len() != pred_vec.len() {
        return Err(ClusterError::InvalidInput(
            "Labels must have the same length".to_string(),
        ));
    }
    let n = true_vec.len();
    if n <= 1 {
        // With zero or one sample there are no pairs, so the labellings trivially agree.
        return Ok(1.0);
    }

    // C(k, 2) as f64; groups with fewer than two samples contribute no pairs.
    let pair_count = |k: usize| -> f64 {
        if k >= 2 {
            combinations(k as u64, 2) as f64
        } else {
            0.0
        }
    };

    // Contingency table: how many samples carry each (true label, predicted label) pair.
    let mut contingency_table: HashMap<(i32, i32), usize> = HashMap::new();
    for (&t, &p) in true_vec.iter().zip(pred_vec.iter()) {
        *contingency_table.entry((t as i32, p as i32)).or_insert(0) += 1;
    }

    // Row sums are the true-cluster sizes; column sums are the predicted-cluster sizes.
    let mut true_sizes: HashMap<i32, usize> = HashMap::new();
    let mut pred_sizes: HashMap<i32, usize> = HashMap::new();
    let mut sum_comb_c = 0.0_f64;
    for (&(true_label, pred_label), &count) in &contingency_table {
        sum_comb_c += pair_count(count);
        *true_sizes.entry(true_label).or_insert(0) += count;
        *pred_sizes.entry(pred_label).or_insert(0) += count;
    }

    // The partitions are identical up to relabelling exactly when every true cluster maps
    // onto one predicted cluster and vice versa. Equivalently, the table has one non-zero
    // cell per row and per column. This also covers the degenerate inputs where the
    // chance-corrected denominator below is zero: a single cluster on both sides, or all
    // singletons on both sides. The formula is undefined there, but the agreement is
    // perfect.
    let cells = contingency_table.len();
    if cells == true_sizes.len() && cells == pred_sizes.len() {
        return Ok(1.0);
    }

    let sum_comb_a: f64 = true_sizes.values().map(|&k| pair_count(k)).sum();
    let sum_comb_b: f64 = pred_sizes.values().map(|&k| pair_count(k)).sum();
    // n >= 2 at this point, so there is at least one pair.
    let total_pairs = combinations(n as u64, 2) as f64;

    let expected_index = (sum_comb_a * sum_comb_b) / total_pairs;
    let max_index = 0.5 * (sum_comb_a + sum_comb_b);
    let denominator = max_index - expected_index;
    if denominator.abs() < f64::EPSILON {
        // In exact arithmetic a zero denominator implies identical partitions, which
        // returned above. This guard only protects against a rounding-level divisor.
        return Ok(0.0);
    }
    Ok((sum_comb_c - expected_index) / denominator)
}
pub fn fowlkes_mallows_score(labels_true: &Tensor, labels_pred: &Tensor) -> ClusterResult<f64> {
let true_vec = labels_true.to_vec().map_err(ClusterError::TensorError)?;
let pred_vec = labels_pred.to_vec().map_err(ClusterError::TensorError)?;
if true_vec.len() != pred_vec.len() {
return Err(ClusterError::InvalidInput(
"Labels must have the same length".to_string(),
));
}
let n = true_vec.len();
if n <= 1 {
return Ok(1.0); }
let true_labels: Vec<i32> = true_vec.iter().map(|&x| x as i32).collect();
let pred_labels: Vec<i32> = pred_vec.iter().map(|&x| x as i32).collect();
let mut tp = 0_u64; let mut fp = 0_u64; let mut fn_count = 0_u64;
for i in 0..n {
for j in (i + 1)..n {
let same_true = true_labels[i] == true_labels[j];
let same_pred = pred_labels[i] == pred_labels[j];
match (same_true, same_pred) {
(true, true) => tp += 1, (false, true) => fp += 1, (true, false) => fn_count += 1, (false, false) => {} }
}
}
let precision = if tp + fp == 0 {
1.0 } else {
tp as f64 / (tp + fp) as f64
};
let recall = if tp + fn_count == 0 {
1.0 } else {
tp as f64 / (tp + fn_count) as f64
};
let fm_score = (precision * recall).sqrt();
Ok(fm_score.clamp(0.0, 1.0))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_ari_perfect_match() -> Result<(), Box<dyn std::error::Error>> {
        let labels_true = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;
        let labels_pred = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;
        let ari = adjusted_rand_score(&labels_true, &labels_pred)?;
        assert_relative_eq!(ari, 1.0, epsilon = 1e-6);
        Ok(())
    }

    #[test]
    fn test_ari_label_permutation() -> Result<(), Box<dyn std::error::Error>> {
        // Renaming cluster labels must not change the score.
        let labels_true = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;
        let labels_pred = Tensor::from_vec(vec![1.0, 1.0, 0.0, 0.0], &[4])?;
        let ari = adjusted_rand_score(&labels_true, &labels_pred)?;
        assert_relative_eq!(ari, 1.0, epsilon = 1e-6);
        Ok(())
    }

    #[test]
    fn test_fm_perfect_match() -> Result<(), Box<dyn std::error::Error>> {
        let labels_true = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;
        let labels_pred = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;
        let fm = fowlkes_mallows_score(&labels_true, &labels_pred)?;
        assert_relative_eq!(fm, 1.0, epsilon = 1e-6);
        Ok(())
    }

    #[test]
    fn test_fm_partial_match() -> Result<(), Box<dyn std::error::Error>> {
        // Pairs: tp = 1, fp = 2, fn = 1 -> precision 1/3, recall 1/2, FM = sqrt(1/6).
        let labels_true = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;
        let labels_pred = Tensor::from_vec(vec![0.0, 0.0, 0.0, 1.0], &[4])?;
        let fm = fowlkes_mallows_score(&labels_true, &labels_pred)?;
        assert_relative_eq!(fm, (1.0_f64 / 6.0).sqrt(), epsilon = 1e-9);
        Ok(())
    }

    #[test]
    fn test_ari_random_assignment() -> Result<(), Box<dyn std::error::Error>> {
        let labels_true = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0], &[6])?;
        let labels_pred = Tensor::from_vec(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0], &[6])?;
        let ari = adjusted_rand_score(&labels_true, &labels_pred)?;
        assert!(
            ari < 0.5,
            "ARI should be low for random assignment: got {}",
            ari
        );
        // sum_comb_c = 0, sum_comb_a = 3, sum_comb_b = 6, total pairs = 15
        // -> expected = 1.2, ARI = -1.2 / (4.5 - 1.2) = -4/11.
        assert_relative_eq!(ari, -4.0 / 11.0, epsilon = 1e-9);
        Ok(())
    }

    #[test]
    fn test_length_mismatch_is_error() -> Result<(), Box<dyn std::error::Error>> {
        let labels_true = Tensor::from_vec(vec![0.0, 0.0, 1.0], &[3])?;
        let labels_pred = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;
        assert!(matches!(
            adjusted_rand_score(&labels_true, &labels_pred),
            Err(ClusterError::InvalidInput(_))
        ));
        assert!(matches!(
            fowlkes_mallows_score(&labels_true, &labels_pred),
            Err(ClusterError::InvalidInput(_))
        ));
        Ok(())
    }
}