use std::collections::HashMap;
/// Normalized mutual information between two flat clusterings, using the
/// arithmetic-mean normalization: `2 * I(P;T) / (H(P) + H(T))`.
///
/// Returns 0.0 on length mismatch or empty input, and 1.0 when both
/// labelings are constant (zero total entropy), which is treated as
/// trivially perfect agreement.
pub fn nmi(pred: &[usize], truth: &[usize]) -> f64 {
    if pred.len() != truth.len() || pred.is_empty() {
        return 0.0;
    }
    let n = pred.len() as f64;

    // One pass builds the joint contingency table and both marginal histograms.
    let mut joint: HashMap<(usize, usize), usize> = HashMap::new();
    let mut pred_counts: HashMap<usize, usize> = HashMap::new();
    let mut truth_counts: HashMap<usize, usize> = HashMap::new();
    for (&p, &t) in pred.iter().zip(truth) {
        *joint.entry((p, t)).or_default() += 1;
        *pred_counts.entry(p).or_default() += 1;
        *truth_counts.entry(t).or_default() += 1;
    }

    // Shannon entropy of a histogram; every stored count is >= 1, so the
    // probability is always positive and ln() is well defined.
    let entropy = |counts: &HashMap<usize, usize>| -> f64 {
        counts
            .values()
            .map(|&c| {
                let p = c as f64 / n;
                -p * p.ln()
            })
            .sum()
    };
    let h_pred = entropy(&pred_counts);
    let h_truth = entropy(&truth_counts);

    // Mutual information, summed over the occupied cells of the joint table
    // (cells that exist in the map necessarily have positive counts).
    let mut mi = 0.0;
    for (&(p, t), &count) in &joint {
        let p_joint = count as f64 / n;
        let p_p = pred_counts[&p] as f64 / n;
        let p_t = truth_counts[&t] as f64 / n;
        mi += p_joint * (p_joint / (p_p * p_t)).ln();
    }

    let denom = h_pred + h_truth;
    if denom > 0.0 {
        2.0 * mi / denom
    } else {
        // Both labelings collapse to a single cluster.
        1.0
    }
}
/// Adjusted Rand index between two flat clusterings:
/// `(Index - Expected) / (MaxIndex - Expected)`, chance-corrected so that a
/// random labeling scores near 0 and identical partitions score 1.
///
/// Returns 0.0 on length mismatch or empty input. A single sample (`n < 2`)
/// has no pairs to disagree on and returns 1.0 — previously this case
/// divided by `C(1, 2) = 0` and produced NaN.
pub fn ari(pred: &[usize], truth: &[usize]) -> f64 {
    if pred.len() != truth.len() || pred.is_empty() {
        return 0.0;
    }
    let n = pred.len();
    if n < 2 {
        // No pairs exist: trivially perfect agreement (avoids 0/0 below).
        return 1.0;
    }

    // C(m, 2) as f64; saturating_sub keeps m = 0 safe.
    let c2 = |m: usize| (m * m.saturating_sub(1) / 2) as f64;

    // Joint contingency table plus its row (pred) and column (truth) sums.
    let mut joint: HashMap<(usize, usize), usize> = HashMap::new();
    let mut rows: HashMap<usize, usize> = HashMap::new();
    let mut cols: HashMap<usize, usize> = HashMap::new();
    for (&p, &t) in pred.iter().zip(truth) {
        *joint.entry((p, t)).or_default() += 1;
        *rows.entry(p).or_default() += 1;
        *cols.entry(t).or_default() += 1;
    }

    let sum_comb_ij: f64 = joint.values().map(|&c| c2(c)).sum();
    let sum_comb_a: f64 = rows.values().map(|&a| c2(a)).sum();
    let sum_comb_b: f64 = cols.values().map(|&b| c2(b)).sum();

    let expected = sum_comb_a * sum_comb_b / c2(n);
    let max_index = (sum_comb_a + sum_comb_b) / 2.0;
    let denom = max_index - expected;
    if denom.abs() < 1e-10 {
        // Degenerate partitions where index and expectation coincide.
        return 1.0;
    }
    (sum_comb_ij - expected) / denom
}
/// Purity: each predicted cluster is credited with its most frequent true
/// label, and the score is the fraction of samples so credited. Note that
/// over-clustering (one cluster per sample) trivially reaches 1.0.
///
/// Returns 0.0 on length mismatch or empty input.
pub fn purity(pred: &[usize], truth: &[usize]) -> f64 {
    if pred.len() != truth.len() || pred.is_empty() {
        return 0.0;
    }

    // Count co-occurrences of (cluster, class) pairs.
    let mut joint: HashMap<(usize, usize), usize> = HashMap::new();
    for (&p, &t) in pred.iter().zip(truth) {
        *joint.entry((p, t)).or_default() += 1;
    }

    // For every cluster keep only the count of its dominant class.
    let mut dominant: HashMap<usize, usize> = HashMap::new();
    for (&(cluster, _), &count) in &joint {
        let best = dominant.entry(cluster).or_insert(0);
        if count > *best {
            *best = count;
        }
    }

    let correct: usize = dominant.values().sum();
    correct as f64 / pred.len() as f64
}
/// Homogeneity: 1 iff every predicted cluster contains members of a single
/// true class. Defined as `1 - H(C|K) / H(C)`, where `C` are the true class
/// labels and `K` the predicted cluster labels (Rosenberg & Hirschberg, 2007).
///
/// Returns 0.0 on length mismatch or empty input, and 1.0 when the true
/// labeling carries no entropy (a single class), where the score is
/// conventionally perfect.
pub fn homogeneity(pred: &[usize], truth: &[usize]) -> f64 {
    if pred.len() != truth.len() || pred.is_empty() {
        return 0.0;
    }
    // conditional_entropies(a, b) returns (H(a), H(a|b)). Homogeneity needs
    // H(C) and H(C|K), i.e. the TRUE labels conditioned on the clusters, so
    // `truth` must be the first argument. (The arguments were previously
    // swapped, which made this function compute completeness instead — e.g.
    // homogeneity([0,0,0,0], [0,0,1,1]) returned 1.0 rather than 0.0.)
    let (h_c, h_c_given_k) = conditional_entropies(truth, pred);
    if h_c < 1e-10 {
        return 1.0;
    }
    1.0 - h_c_given_k / h_c
}
/// Completeness: 1 iff every true class is assigned entirely to a single
/// predicted cluster. Defined as `1 - H(K|C) / H(K)`, where `K` are the
/// predicted cluster labels and `C` the true class labels.
///
/// Returns 0.0 on length mismatch or empty input, and 1.0 when the
/// predicted labeling carries no entropy (a single cluster), where the
/// score is conventionally perfect.
pub fn completeness(pred: &[usize], truth: &[usize]) -> f64 {
    if pred.len() != truth.len() || pred.is_empty() {
        return 0.0;
    }
    // conditional_entropies(a, b) returns (H(a), H(a|b)). Completeness needs
    // H(K) and H(K|C), i.e. the CLUSTER labels conditioned on the classes,
    // so `pred` must be the first argument. (The arguments were previously
    // swapped, which made this function compute homogeneity instead.)
    let (h_k, h_k_given_c) = conditional_entropies(pred, truth);
    if h_k < 1e-10 {
        return 1.0;
    }
    1.0 - h_k_given_c / h_k
}
/// V-measure: the harmonic mean of homogeneity and completeness.
///
/// Returns 0.0 when both components vanish (avoiding 0/0); all other
/// degenerate-input handling is inherited from the two components.
pub fn v_measure(pred: &[usize], truth: &[usize]) -> f64 {
    let h = homogeneity(pred, truth);
    let c = completeness(pred, truth);
    let total = h + c;
    if total < 1e-10 {
        0.0
    } else {
        2.0 * h * c / total
    }
}
/// Fowlkes–Mallows index: the geometric mean of pairwise precision and
/// recall, `TP / sqrt((TP + FP) * (TP + FN))` over all unordered sample
/// pairs.
///
/// Computed in O(n + k) from the contingency table instead of the previous
/// O(n^2) loop over all pairs: `TP = sum C(n_ij, 2)`,
/// `TP + FP = sum C(row_i, 2)` (pairs co-clustered in `pred`), and
/// `TP + FN = sum C(col_j, 2)` (pairs co-labeled in `truth`).
///
/// Returns 0.0 on length mismatch or fewer than 2 samples (no pairs exist),
/// and 0.0 when either labeling produces no co-clustered pairs at all.
pub fn fowlkes_mallows(pred: &[usize], truth: &[usize]) -> f64 {
    if pred.len() != truth.len() || pred.len() < 2 {
        return 0.0;
    }

    // C(m, 2); saturating_sub keeps m = 0 from underflowing.
    let c2 = |m: usize| m * m.saturating_sub(1) / 2;

    // Joint contingency table plus its row (pred) and column (truth) sums.
    let mut joint: HashMap<(usize, usize), usize> = HashMap::new();
    let mut rows: HashMap<usize, usize> = HashMap::new();
    let mut cols: HashMap<usize, usize> = HashMap::new();
    for (&p, &t) in pred.iter().zip(truth) {
        *joint.entry((p, t)).or_default() += 1;
        *rows.entry(p).or_default() += 1;
        *cols.entry(t).or_default() += 1;
    }

    let tp: usize = joint.values().map(|&c| c2(c)).sum();
    let pred_pairs: usize = rows.values().map(|&c| c2(c)).sum(); // TP + FP
    let truth_pairs: usize = cols.values().map(|&c| c2(c)).sum(); // TP + FN

    let precision = if pred_pairs > 0 {
        tp as f64 / pred_pairs as f64
    } else {
        0.0
    };
    let recall = if truth_pairs > 0 {
        tp as f64 / truth_pairs as f64
    } else {
        0.0
    };
    (precision * recall).sqrt()
}
/// Builds the contingency table of a pair of flat labelings: maps each
/// `(pred_label, truth_label)` pair to its co-occurrence count, and returns
/// it together with `pred.len()` as the sample count.
///
/// Callers are expected to pass equally long slices; `zip` silently ignores
/// any surplus elements of the longer one.
fn build_contingency_table(
    pred: &[usize],
    truth: &[usize],
) -> (HashMap<(usize, usize), usize>, usize) {
    let mut table: HashMap<(usize, usize), usize> = HashMap::new();
    for pair in pred.iter().copied().zip(truth.iter().copied()) {
        *table.entry(pair).or_default() += 1;
    }
    (table, pred.len())
}
/// Number of unordered pairs among `n` items: C(n, 2) = n(n-1)/2.
/// `saturating_sub` makes `n = 0` yield 0 instead of underflowing, so no
/// branch is needed for `n < 2`.
fn comb2(n: usize) -> usize {
    n * n.saturating_sub(1) / 2
}
/// Returns `(H(A), H(A|B))` in nats for the labelings `a` and `b`.
///
/// The conditional entropy is accumulated directly over the occupied cells
/// of the joint distribution, `H(A|B) = -sum_{a,b} p(a,b) * ln(n_ab / n_b)`,
/// which is algebraically identical to averaging the per-`b` entropies.
/// Assumes `a` and `b` have the same length; empty input yields (0.0, 0.0).
fn conditional_entropies(a: &[usize], b: &[usize]) -> (f64, f64) {
    let n = a.len() as f64;

    // One pass over the data collects both marginals and the joint counts.
    let mut count_a: HashMap<usize, usize> = HashMap::new();
    let mut count_b: HashMap<usize, usize> = HashMap::new();
    let mut joint: HashMap<(usize, usize), usize> = HashMap::new();
    for (&va, &vb) in a.iter().zip(b) {
        *count_a.entry(va).or_default() += 1;
        *count_b.entry(vb).or_default() += 1;
        *joint.entry((va, vb)).or_default() += 1;
    }

    // Marginal entropy of A; every stored count is positive, so ln() is safe.
    let h_a: f64 = count_a
        .values()
        .map(|&c| {
            let p = c as f64 / n;
            -p * p.ln()
        })
        .sum();

    // Conditional entropy, cell by cell over the occupied joint entries.
    let mut h_a_given_b = 0.0;
    for (&(_, vb), &n_ab) in &joint {
        let n_b = count_b[&vb] as f64;
        h_a_given_b -= (n_ab as f64 / n) * (n_ab as f64 / n_b).ln();
    }

    (h_a, h_a_given_b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared absolute tolerance for the floating-point metric checks.
    const EPS: f64 = 0.01;

    #[test]
    fn test_nmi_perfect() {
        // Identical labelings: maximal agreement.
        assert!((nmi(&[0, 0, 1, 1, 2, 2], &[0, 0, 1, 1, 2, 2]) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_nmi_permuted() {
        // NMI is invariant to relabeling the clusters.
        assert!((nmi(&[1, 1, 0, 0, 2, 2], &[0, 0, 1, 1, 2, 2]) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_ari_perfect() {
        assert!((ari(&[0, 0, 1, 1], &[0, 0, 1, 1]) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_purity_perfect() {
        assert!((purity(&[0, 0, 1, 1], &[0, 0, 1, 1]) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_purity_overclustering() {
        // Singleton clusters are trivially pure.
        assert!((purity(&[0, 1, 2, 3], &[0, 0, 1, 1]) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_homogeneity_completeness() {
        let (pred, truth) = ([0, 0, 1, 1], [0, 0, 1, 1]);
        assert!((homogeneity(&pred, &truth) - 1.0).abs() < EPS);
        assert!((completeness(&pred, &truth) - 1.0).abs() < EPS);
        assert!((v_measure(&pred, &truth) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_fowlkes_mallows_perfect() {
        assert!((fowlkes_mallows(&[0, 0, 1, 1], &[0, 0, 1, 1]) - 1.0).abs() < EPS);
    }
}