#![allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
use std::collections::HashMap;
use crate::core::error::{IgraphError, IgraphResult};
use super::reindex_membership::reindex_membership;
/// Selects which measure [`compare_communities`] computes for a pair of
/// membership vectors.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum CommunityComparison {
/// Variation of information: `h1 + h2 - 2 * mi`, where `h1`/`h2` are the
/// entropies of the two partitions and `mi` their mutual information
/// (natural logarithm). Empty input yields `0.0`.
VariationOfInformation,
/// Normalized mutual information: `2 * mi / (h1 + h2)`. Defined as `1.0`
/// when both entropies are zero and for empty input.
NormalizedMutualInformation,
/// Split-join distance: the sum of the two directed distances returned by
/// `split_join_distances`. Empty input yields `0.0`.
SplitJoin,
/// Rand index. An error is returned for fewer than two vertices.
Rand,
/// Rand index rescaled as `(rand - expected) / (1 - expected)`. An error
/// is returned for fewer than two vertices.
AdjustedRand,
}
/// Compares two community structures given as per-vertex membership vectors.
///
/// `comm1[i]` and `comm2[i]` are the community labels of vertex `i` in the
/// two structures. Both vectors are passed through `reindex_membership`
/// before the selected `method` is evaluated.
///
/// Empty input short-circuits: `1.0` for normalized mutual information and
/// `0.0` for variation of information and split-join distance.
///
/// # Errors
///
/// * `IgraphError::InvalidArgument` if the two vectors differ in length.
/// * `IgraphError::InvalidArgument` if a Rand measure is requested for fewer
///   than two vertices.
/// * Any error reported by `reindex_membership`.
pub fn compare_communities(
    comm1: &[u32],
    comm2: &[u32],
    method: CommunityComparison,
) -> IgraphResult<f64> {
    let (len1, len2) = (comm1.len(), comm2.len());
    if len1 != len2 {
        return Err(IgraphError::InvalidArgument(format!(
            "community membership vectors have different lengths: {} and {}",
            len1, len2,
        )));
    }
    let n = len1;

    // Both Rand variants reject tiny inputs with the same message.
    let rand_undefined = || {
        IgraphError::InvalidArgument(format!(
            "Rand indices not defined for zero or one vertices. \
             Found membership vector of size {n}.",
        ))
    };

    if n == 0 {
        return match method {
            CommunityComparison::Rand | CommunityComparison::AdjustedRand => {
                Err(rand_undefined())
            }
            CommunityComparison::NormalizedMutualInformation => Ok(1.0),
            CommunityComparison::VariationOfInformation | CommunityComparison::SplitJoin => {
                Ok(0.0)
            }
        };
    }

    let c1 = reindex_membership(comm1)?;
    let c2 = reindex_membership(comm2)?;
    let labels1: &[u32] = &c1.membership;
    let labels2: &[u32] = &c2.membership;

    let value = match method {
        CommunityComparison::VariationOfInformation => {
            let (entropy1, entropy2, mutual) =
                entropy_and_mutual_information(labels1, labels2, n);
            entropy1 + entropy2 - 2.0 * mutual
        }
        CommunityComparison::NormalizedMutualInformation => {
            let (entropy1, entropy2, mutual) =
                entropy_and_mutual_information(labels1, labels2, n);
            if entropy1 != 0.0 || entropy2 != 0.0 {
                2.0 * mutual / (entropy1 + entropy2)
            } else {
                // Two single-cluster partitions: the ratio would be 0/0, so
                // it is defined as perfect agreement.
                1.0
            }
        }
        CommunityComparison::SplitJoin => {
            let (d12, d21) = split_join_distances(labels1, labels2, n);
            (d12 + d21) as f64
        }
        CommunityComparison::Rand | CommunityComparison::AdjustedRand => {
            if n < 2 {
                return Err(rand_undefined());
            }
            rand_index(
                labels1,
                labels2,
                n,
                method == CommunityComparison::AdjustedRand,
            )
        }
    };
    Ok(value)
}
/// Computes the entropies of two labelings and their mutual information.
///
/// Only the first `n` entries of `v1` and `v2` are considered; entry `i` of
/// each slice is the cluster label of element `i`. All logarithms are
/// natural.
///
/// Returns `(h1, h2, mi)`: the entropy of `v1`, the entropy of `v2`, and the
/// mutual information between them. For `n == 0` all three are `0.0`.
///
/// Labels need not be dense: a label value that never occurs contributes
/// nothing to the entropy instead of poisoning it with `0 * ln(0) = NaN`.
/// The result is bit-for-bit reproducible across calls because the joint
/// distribution is summed in sorted label order rather than in hash-map
/// iteration order.
///
/// # Panics
///
/// Panics if `n` exceeds the length of either slice.
fn entropy_and_mutual_information(v1: &[u32], v2: &[u32], n: usize) -> (f64, f64, f64) {
    // Entropy of one labeling plus `ln(p)` for every label value, indexed by
    // label. Unused labels get `ln(0) = -inf`, which is never looked up
    // because such labels cannot appear in any joint cell.
    fn entropy_and_log_probabilities(labels: &[u32], n_f: f64) -> (f64, Vec<f64>) {
        let table_len = labels.iter().max().map_or(0, |&m| m as usize + 1);
        let mut sizes = vec![0.0_f64; table_len];
        for &label in labels {
            sizes[label as usize] += 1.0;
        }
        let mut entropy = 0.0;
        let mut log_p = Vec::with_capacity(table_len);
        for size in sizes {
            let p = size / n_f;
            let ln_p = p.ln();
            // Skip empty clusters: 0 * -inf would be NaN.
            if size > 0.0 {
                entropy -= p * ln_p;
            }
            log_p.push(ln_p);
        }
        (entropy, log_p)
    }

    // Nothing to measure, and it avoids 0/0 in the probabilities below.
    if n == 0 {
        return (0.0, 0.0, 0.0);
    }

    let (v1, v2) = (&v1[..n], &v2[..n]);
    let n_f = n as f64;
    let (h1, log_p1) = entropy_and_log_probabilities(v1, n_f);
    let (h2, log_p2) = entropy_and_log_probabilities(v2, n_f);

    // Sparse contingency table: (label in v1, label in v2) -> element count.
    let mut counts: HashMap<(u32, u32), u32> = HashMap::new();
    for (&r, &c) in v1.iter().zip(v2) {
        *counts.entry((r, c)).or_insert(0) += 1;
    }

    // Hash-map iteration order is randomized; fix the summation order so the
    // floating-point result does not vary from call to call.
    let mut cells: Vec<((u32, u32), u32)> = counts.into_iter().collect();
    cells.sort_unstable();

    let mut mut_inf = 0.0;
    for ((r, c), cnt) in cells {
        let p = f64::from(cnt) / n_f;
        mut_inf += p * (p.ln() - log_p1[r as usize] - log_p2[c as usize]);
    }
    (h1, h2, mut_inf)
}
/// Computes the two directed split-join distances between two labelings.
///
/// Entry `i` of `v1` and `v2` is the cluster label of element `i`; the first
/// `n` entries of each are paired up. For every cluster of `v1` the largest
/// overlap with any cluster of `v2` is found, and the first component is `n`
/// minus the sum of those maxima. The second component is the same with the
/// roles of `v1` and `v2` swapped.
///
/// # Panics
///
/// Panics if `n` exceeds the length of either slice.
pub(crate) fn split_join_distances(v1: &[u32], v2: &[u32], n: usize) -> (u64, u64) {
    let mut best_in_row = vec![0_u32; max_plus_one(v1)];
    let mut best_in_col = vec![0_u32; max_plus_one(v2)];

    // Sparse contingency table: (label in v1, label in v2) -> overlap size.
    let mut overlap: HashMap<(u32, u32), u32> = HashMap::new();
    for key in (0..n).map(|idx| (v1[idx], v2[idx])) {
        *overlap.entry(key).or_insert(0) += 1;
    }

    for (&(row, col), &size) in &overlap {
        let (row, col) = (row as usize, col as usize);
        best_in_row[row] = best_in_row[row].max(size);
        best_in_col[col] = best_in_col[col].max(size);
    }

    let total = n as u64;
    let matched = |best: &[u32]| best.iter().fold(0_u64, |acc, &m| acc + u64::from(m));
    (total - matched(&best_in_row), total - matched(&best_in_col))
}
/// Computes the Rand index, or its adjusted form, of two labelings.
///
/// Only the first `n` entries of `v1` and `v2` are considered; entry `i` of
/// each slice is the cluster label of element `i`. With `adjust == false` the
/// plain Rand index is returned. With `adjust == true` the index is rescaled
/// as `(rand - expected) / (1 - expected)`, and defined as `1.0` when that
/// denominator is exactly zero.
///
/// The formula divides by `n - 1`, so the result is only meaningful for
/// `n >= 2`; `compare_communities` rejects smaller inputs before calling this.
///
/// The result is bit-for-bit reproducible across calls: the per-cell terms
/// are summed in sorted order rather than in hash-map iteration order, and
/// cluster sizes are counted directly from the labels.
///
/// # Panics
///
/// Panics if `n` exceeds the length of either slice.
fn rand_index(v1: &[u32], v2: &[u32], n: usize, adjust: bool) -> f64 {
    // Size of every cluster, indexed by label value (unused labels stay 0).
    fn cluster_sizes(labels: &[u32]) -> Vec<f64> {
        let table_len = labels.iter().max().map_or(0, |&m| m as usize + 1);
        let mut sizes = vec![0.0_f64; table_len];
        for &label in labels {
            sizes[label as usize] += 1.0;
        }
        sizes
    }

    let (v1, v2) = (&v1[..n], &v2[..n]);
    let n_f = n as f64;

    // Fraction of all unordered vertex pairs that lie inside one group of
    // the given size: size * (size - 1) / (n * (n - 1)).
    let pair_fraction = |size: f64| (size / n_f) * (size - 1.0) / (n_f - 1.0);
    let pair_fraction_total =
        |sizes: &[f64]| sizes.iter().fold(0.0, |acc, &size| acc + pair_fraction(size));

    // Sparse contingency table: (label in v1, label in v2) -> element count.
    let mut counts: HashMap<(u32, u32), u32> = HashMap::new();
    for (&r, &c) in v1.iter().zip(v2) {
        *counts.entry((r, c)).or_insert(0) += 1;
    }

    // Each term depends only on the cell size, so sorting the sizes pins the
    // floating-point summation order regardless of hash-map iteration order.
    let mut cell_sizes: Vec<u32> = counts.into_iter().map(|(_, cnt)| cnt).collect();
    cell_sizes.sort_unstable();
    let joint = cell_sizes
        .iter()
        .fold(0.0, |acc, &cnt| acc + pair_fraction(f64::from(cnt)));

    let frac_in_1 = pair_fraction_total(&cluster_sizes(v1));
    let frac_in_2 = pair_fraction_total(&cluster_sizes(v2));

    let rand = 1.0 + 2.0 * joint - frac_in_1 - frac_in_2;
    if !adjust {
        return rand;
    }

    let expected = frac_in_1 * frac_in_2 + (1.0 - frac_in_1) * (1.0 - frac_in_2);
    let denom = 1.0 - expected;
    if denom == 0.0 {
        // Both labelings put every pair in the same relation (for example a
        // single cluster on each side).
        1.0
    } else {
        (rand - expected) / denom
    }
}
/// Returns one more than the largest label in `v`, i.e. the length a table
/// must have to be indexable by every label. An empty slice yields `1`.
fn max_plus_one(v: &[u32]) -> usize {
    let largest = v.iter().fold(0_u32, |best, &label| best.max(label));
    largest as usize + 1
}
/// Unit tests for `compare_communities` and `split_join_distances`, plus an
/// optional property-test suite behind the `proptest-harness` feature.
#[cfg(test)]
mod tests {
use super::*;
/// Returns `true` when `a` and `b` differ by strictly less than `tol`.
fn close(a: f64, b: f64, tol: f64) -> bool {
(a - b).abs() < tol
}
/// Vectors of lengths 2 and 1 must be rejected with `InvalidArgument`.
#[test]
fn err_on_length_mismatch() {
let err = compare_communities(&[0, 1], &[0], CommunityComparison::VariationOfInformation)
.unwrap_err();
match err {
IgraphError::InvalidArgument(_) => (),
other => panic!("expected InvalidArgument, got {other:?}"),
}
}
/// Empty input: VI and split-join are 0, NMI is 1, both Rand variants error.
#[test]
fn empty_input_returns_method_defaults() {
for (m, expected) in [
(CommunityComparison::VariationOfInformation, 0.0),
(CommunityComparison::NormalizedMutualInformation, 1.0),
(CommunityComparison::SplitJoin, 0.0),
] {
let q = compare_communities(&[], &[], m).unwrap();
assert!(close(q, expected, 1e-12), "method {m:?} got {q}");
}
for m in [CommunityComparison::Rand, CommunityComparison::AdjustedRand] {
assert!(compare_communities(&[], &[], m).is_err());
}
}
/// Comparing a partition with itself gives the best value of every measure
/// (despite the name, all five measures are checked).
#[test]
fn identical_partitions_have_nmi_1_and_vi_0() {
let v = [0, 0, 1, 1, 2, 2];
assert!(close(
compare_communities(&v, &v, CommunityComparison::NormalizedMutualInformation).unwrap(),
1.0,
1e-12,
));
assert!(close(
compare_communities(&v, &v, CommunityComparison::VariationOfInformation).unwrap(),
0.0,
1e-12,
));
assert!(close(
compare_communities(&v, &v, CommunityComparison::Rand).unwrap(),
1.0,
1e-12,
));
assert!(close(
compare_communities(&v, &v, CommunityComparison::AdjustedRand).unwrap(),
1.0,
1e-12,
));
assert!(close(
compare_communities(&v, &v, CommunityComparison::SplitJoin).unwrap(),
0.0,
1e-12,
));
}
/// `b` is `a` with different label values, so comparing `a` with `b` must
/// give the same result as comparing `a` with itself.
#[test]
fn relabel_invariance() {
let a = [0, 0, 1, 1, 2, 2];
let b = [7, 7, 3, 3, 9, 9];
for m in [
CommunityComparison::VariationOfInformation,
CommunityComparison::NormalizedMutualInformation,
CommunityComparison::SplitJoin,
CommunityComparison::Rand,
CommunityComparison::AdjustedRand,
] {
let q1 = compare_communities(&a, &a, m).unwrap();
let q2 = compare_communities(&a, &b, m).unwrap();
assert!(close(q1, q2, 1e-12), "method {m:?}: {q1} vs {q2}");
}
}
/// Every vertex in its own cluster on both sides: NMI is 1, and the Rand
/// index is 1 even when the labels are assigned in reverse order.
#[test]
fn singletons_vs_singletons() {
let v: Vec<u32> = (0..6).collect();
assert!(close(
compare_communities(&v, &v, CommunityComparison::NormalizedMutualInformation).unwrap(),
1.0,
1e-12,
));
let w: Vec<u32> = (0..6).rev().collect();
assert!(close(
compare_communities(&v, &w, CommunityComparison::Rand).unwrap(),
1.0,
1e-12,
));
}
/// One cluster on each side (with different labels): both entropies are
/// zero, which exercises the special case that defines NMI as 1.
#[test]
fn one_cluster_each_side_is_nmi_one_per_spec() {
let v = [0u32; 5];
let w = [9u32; 5];
assert!(close(
compare_communities(&v, &w, CommunityComparison::NormalizedMutualInformation).unwrap(),
1.0,
1e-12,
));
assert!(close(
compare_communities(&v, &w, CommunityComparison::VariationOfInformation).unwrap(),
0.0,
1e-12,
));
assert!(close(
compare_communities(&v, &w, CommunityComparison::SplitJoin).unwrap(),
0.0,
1e-12,
));
assert!(close(
compare_communities(&v, &w, CommunityComparison::Rand).unwrap(),
1.0,
1e-12,
));
}
/// Two 2+2 splits of four vertices that cross each other: every pair of
/// clusters overlaps in exactly one vertex. Pins the exact value of every
/// measure for that case.
#[test]
fn full_disagreement_two_clusters() {
let a = [0u32, 0, 1, 1];
let b = [0u32, 1, 0, 1];
let nmi =
compare_communities(&a, &b, CommunityComparison::NormalizedMutualInformation).unwrap();
assert!(close(nmi, 0.0, 1e-12), "NMI = {nmi}");
let vi = compare_communities(&a, &b, CommunityComparison::VariationOfInformation).unwrap();
assert!(close(vi, 2.0 * 2f64.ln(), 1e-12), "VI = {vi}");
let sj = compare_communities(&a, &b, CommunityComparison::SplitJoin).unwrap();
assert!(close(sj, 4.0, 1e-12), "SJ = {sj}");
let rand = compare_communities(&a, &b, CommunityComparison::Rand).unwrap();
assert!(close(rand, 1.0 / 3.0, 1e-12), "Rand = {rand}");
let ar = compare_communities(&a, &b, CommunityComparison::AdjustedRand).unwrap();
assert!(close(ar, -0.5, 1e-12), "AR = {ar}");
}
/// `b` splits each cluster of `a` further. NOTE(review): the name is
/// slightly misleading — only one directed distance (`d21`) is zero; the
/// other is 2, and so is the total reported by `compare_communities`.
#[test]
fn split_join_is_zero_for_subpartition() {
let a = [0u32, 0, 0, 1, 1, 1];
let b = [5u32, 5, 6, 7, 7, 8];
let r1 = reindex_membership(&a).unwrap();
let r2 = reindex_membership(&b).unwrap();
let (d12, d21) = split_join_distances(&r1.membership, &r2.membership, a.len());
assert_eq!(d12, 2);
assert_eq!(d21, 0);
let sj = compare_communities(&a, &b, CommunityComparison::SplitJoin).unwrap();
assert!(close(sj, 2.0, 1e-12));
}
/// Swapping the two arguments must not change the NMI.
#[test]
fn nmi_is_symmetric() {
let a = [0u32, 0, 1, 1, 2, 2, 0, 1];
let b = [3u32, 4, 4, 3, 3, 4, 4, 3];
let n_ab =
compare_communities(&a, &b, CommunityComparison::NormalizedMutualInformation).unwrap();
let n_ba =
compare_communities(&b, &a, CommunityComparison::NormalizedMutualInformation).unwrap();
assert!(close(n_ab, n_ba, 1e-12));
}
/// A single vertex is too small for either Rand variant.
#[test]
fn rand_requires_at_least_two_vertices() {
let v = [0u32];
assert!(compare_communities(&v, &v, CommunityComparison::Rand).is_err());
assert!(compare_communities(&v, &v, CommunityComparison::AdjustedRand).is_err());
}
/// `b` is `a` with the two labels swapped, so VI must be 0.
/// NOTE(review): only the "same partition => 0" direction of the "iff" in
/// the name is exercised here.
#[test]
fn variation_of_information_zero_iff_same_partition() {
let a = [0u32, 0, 1, 1];
let b = [1u32, 1, 0, 0]; let vi = compare_communities(&a, &b, CommunityComparison::VariationOfInformation).unwrap();
assert!(close(vi, 0.0, 1e-12));
}
/// Property tests; compiled only when the `proptest-harness` feature is on.
#[cfg(all(test, feature = "proptest-harness"))]
mod prop {
use super::*;
use proptest::prelude::*;
prop_compose! {
// Strategy producing two membership vectors of the same length
// (2..=24) with labels in `0..k1` and `0..k2` respectively. The labels
// come from a multiply-and-add step on a 64-bit state derived from
// `seed`, taking the upper 32 bits of each state.
fn arb_pair()(
n in 2usize..=24,
k1 in 1u32..=5,
k2 in 1u32..=5,
seed in any::<u64>(),
) -> (Vec<u32>, Vec<u32>) {
let mut rng: u64 = seed.wrapping_add(0xDEAD_BEEF_C0FF_EE00);
let mut step = || -> u32 {
rng = rng.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1);
(rng >> 32) as u32
};
let v1: Vec<u32> = (0..n).map(|_| step() % k1).collect();
let v2: Vec<u32> = (0..n).map(|_| step() % k2).collect();
(v1, v2)
}
}
proptest! {
#![proptest_config(ProptestConfig { cases: 60, ..ProptestConfig::default() })]
// NMI stays within [0, 1], allowing 1e-9 of floating-point slack.
#[test]
fn nmi_in_unit_interval((a, b) in arb_pair()) {
let q = compare_communities(
&a, &b, CommunityComparison::NormalizedMutualInformation,
).unwrap();
prop_assert!((-1e-9..=1.0 + 1e-9).contains(&q), "NMI out of [0,1]: {}", q);
}
// VI is never negative beyond 1e-9 of floating-point slack.
#[test]
fn vi_non_negative((a, b) in arb_pair()) {
let q = compare_communities(
&a, &b, CommunityComparison::VariationOfInformation,
).unwrap();
prop_assert!(q >= -1e-9, "VI < 0: {}", q);
}
// The plain Rand index stays within [0, 1].
#[test]
fn rand_in_unit_interval((a, b) in arb_pair()) {
let q = compare_communities(
&a, &b, CommunityComparison::Rand,
).unwrap();
prop_assert!((-1e-9..=1.0 + 1e-9).contains(&q), "Rand out of [0,1]: {}", q);
}
// The adjusted Rand index is only bounded from above here; the
// `full_disagreement_two_clusters` test shows it can be negative.
#[test]
fn adjusted_rand_capped_at_one((a, b) in arb_pair()) {
let q = compare_communities(
&a, &b, CommunityComparison::AdjustedRand,
).unwrap();
prop_assert!(q <= 1.0 + 1e-9, "AR > 1: {}", q);
}
// Mapping every label `x` to `(x + offset) * 7` renames clusters without
// merging any, so every measure must be unchanged.
#[test]
fn measures_are_relabel_invariant((a, b) in arb_pair()) {
let bump = |v: &[u32], offset: u32| -> Vec<u32> {
v.iter().map(|&x| x.wrapping_add(offset).wrapping_mul(7)).collect()
};
let a2 = bump(&a, 100);
let b2 = bump(&b, 50);
for m in [
CommunityComparison::VariationOfInformation,
CommunityComparison::NormalizedMutualInformation,
CommunityComparison::SplitJoin,
CommunityComparison::Rand,
CommunityComparison::AdjustedRand,
] {
let q1 = compare_communities(&a, &b, m).unwrap();
let q2 = compare_communities(&a2, &b2, m).unwrap();
prop_assert!((q1 - q2).abs() < 1e-9, "method {:?}: {} vs {}", m, q1, q2);
}
}
// Swapping the arguments must not change the NMI.
#[test]
fn nmi_symmetric((a, b) in arb_pair()) {
let ab = compare_communities(
&a, &b, CommunityComparison::NormalizedMutualInformation,
).unwrap();
let ba = compare_communities(
&b, &a, CommunityComparison::NormalizedMutualInformation,
).unwrap();
prop_assert!((ab - ba).abs() < 1e-9);
}
// A partition compared with itself attains the best value of every
// measure; the second generated vector is unused.
#[test]
fn identical_partition_is_extremal((a, _b) in arb_pair()) {
for (m, expected) in [
(CommunityComparison::VariationOfInformation, 0.0_f64),
(CommunityComparison::NormalizedMutualInformation, 1.0),
(CommunityComparison::SplitJoin, 0.0),
(CommunityComparison::Rand, 1.0),
(CommunityComparison::AdjustedRand, 1.0),
] {
let q = compare_communities(&a, &a, m).unwrap();
prop_assert!((q - expected).abs() < 1e-9, "method {:?}: {} vs {}", m, q, expected);
}
}
}
}
}