use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
use std::collections::HashMap;
use std::fmt::Debug;
use crate::error::{MetricsError, Result};
/// Computes homogeneity, completeness and the V-measure of a clustering
/// against ground-truth class labels.
///
/// * **Homogeneity** is `1 - H(C|K) / H(C)`: it equals 1 when every predicted
///   cluster contains members of a single true class.
/// * **Completeness** is `1 - H(K|C) / H(K)`: it equals 1 when all members of
///   a true class are assigned to the same predicted cluster.
/// * **V-measure** is the weighted harmonic mean
///   `(1 + beta) * h * c / (beta * h + c)`.
///
/// Here `C` denotes the true classes, `K` the predicted clusters and `H` the
/// Shannon entropy. Either score is defined as 1 when its normalising entropy
/// is zero (a single class, respectively a single cluster).
///
/// Labels are compared through their `Hash`/`Eq` implementations. The two
/// arrays are paired element by element in iteration order.
///
/// # Arguments
///
/// * `labels_true` - ground-truth class label of each sample.
/// * `labels_pred` - predicted cluster label of each sample.
/// * `beta` - weight of homogeneity in the denominator of the harmonic mean;
///   values above 1 move the V-measure towards completeness, values below 1
///   towards homogeneity. Must be finite and non-negative.
///
/// # Returns
///
/// `(homogeneity, completeness, v_measure)`, each clamped to `[0, 1]`.
///
/// # Errors
///
/// Returns `MetricsError::InvalidInput` when the arrays differ in length, when
/// they are empty, or when `beta` is negative, NaN or infinite.
#[allow(dead_code)]
pub fn homogeneity_completeness_v_measure<T, U, S1, S2, D1, D2>(
    labels_true: &ArrayBase<S1, D1>,
    labels_pred: &ArrayBase<S2, D2>,
    beta: f64,
) -> Result<(f64, f64, f64)>
where
    T: Clone + std::hash::Hash + Eq + Debug,
    U: Clone + std::hash::Hash + Eq + Debug,
    S1: Data<Elem = T>,
    S2: Data<Elem = U>,
    D1: Dimension,
    D2: Dimension,
{
    // Shannon entropy (natural logarithm) of a distribution given as absolute
    // counts summing to `total`. Every count is at least 1, so `ln` is finite.
    fn entropy(counts: impl Iterator<Item = usize>, total: f64) -> f64 {
        counts.fold(0.0, |acc, count| {
            let p = count as f64 / total;
            acc - p * p.ln()
        })
    }

    if labels_true.len() != labels_pred.len() {
        return Err(MetricsError::InvalidInput(format!(
            "labels_true and labels_pred have different lengths: {} vs {}",
            labels_true.len(),
            labels_pred.len()
        )));
    }
    let n_samples = labels_true.len();
    if n_samples == 0 {
        return Err(MetricsError::InvalidInput(
            "Empty arrays provided".to_string(),
        ));
    }
    if beta < 0.0 {
        return Err(MetricsError::InvalidInput(
            "beta must be non-negative".to_string(),
        ));
    }
    // `beta < 0.0` is false for NaN, and +inf would turn the harmonic mean
    // into inf / inf, so both have to be rejected explicitly.
    if !beta.is_finite() {
        return Err(MetricsError::InvalidInput(
            "beta must be a finite number".to_string(),
        ));
    }

    // Joint and marginal label counts, built in a single pass. Keys borrow
    // the labels directly, so identity follows `Eq`/`Hash` rather than the
    // `Debug` rendering, and no per-sample allocation is needed.
    let mut contingency: HashMap<(&T, &U), usize> = HashMap::new();
    let mut true_counts: HashMap<&T, usize> = HashMap::new();
    let mut pred_counts: HashMap<&U, usize> = HashMap::new();
    for (lt, lp) in labels_true.iter().zip(labels_pred.iter()) {
        *contingency.entry((lt, lp)).or_insert(0) += 1;
        *true_counts.entry(lt).or_insert(0) += 1;
        *pred_counts.entry(lp).or_insert(0) += 1;
    }

    let n = n_samples as f64;
    let h_true = entropy(true_counts.values().copied(), n);
    let h_pred = entropy(pred_counts.values().copied(), n);

    // Conditional entropies in one sweep over the contingency table:
    //   H(C|K) = -sum_{c,k} n_ck / n * ln(n_ck / n_k)
    //   H(K|C) = -sum_{c,k} n_ck / n * ln(n_ck / n_c)
    let mut h_true_given_pred = 0.0;
    let mut h_pred_given_true = 0.0;
    for (&(lt, lp), &count) in &contingency {
        let joint = count as f64;
        let weight = joint / n;
        // Indexing cannot fail: every label in the table was inserted into
        // the matching marginal map in the loop above.
        h_true_given_pred -= weight * (joint / pred_counts[lp] as f64).ln();
        h_pred_given_true -= weight * (joint / true_counts[lt] as f64).ln();
    }

    let homogeneity = if h_true == 0.0 {
        1.0
    } else {
        1.0 - h_true_given_pred / h_true
    };
    let completeness = if h_pred == 0.0 {
        1.0
    } else {
        1.0 - h_pred_given_true / h_pred
    };

    // Clamp before combining so rounding noise just outside [0, 1] cannot
    // leak into the harmonic mean.
    let homogeneity = homogeneity.clamp(0.0, 1.0);
    let completeness = completeness.clamp(0.0, 1.0);

    // Guard on the actual denominator: with `beta == 0` and zero completeness
    // it vanishes even though `homogeneity + completeness` does not.
    let denominator = beta * homogeneity + completeness;
    let v_measure = if denominator > 0.0 {
        (1.0 + beta) * homogeneity * completeness / denominator
    } else {
        0.0
    };

    Ok((homogeneity, completeness, v_measure.clamp(0.0, 1.0)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// A clustering that only renames the classes is perfect; a clustering
    /// that splits and merges classes scores strictly between 0 and 1.
    #[test]
    fn test_homogeneity_completeness_v_measure() {
        // Same partition, permuted label names.
        let labels_true = array![0, 0, 1, 1, 2, 2];
        let labels_pred = array![1, 1, 0, 0, 2, 2];
        let (h, c, v) = homogeneity_completeness_v_measure(&labels_true, &labels_pred, 1.0)
            .expect("permuted labels are valid input");
        assert!((h - 1.0).abs() < 1e-10);
        assert!((c - 1.0).abs() < 1e-10);
        assert!((v - 1.0).abs() < 1e-10);

        // Two clusters cutting across three classes.
        let labels_true = array![0, 0, 1, 1, 2, 2];
        let labels_pred = array![0, 0, 0, 1, 1, 1];
        let (h, c, v) = homogeneity_completeness_v_measure(&labels_true, &labels_pred, 1.0)
            .expect("partially matching labels are valid input");
        assert!(h > 0.0 && h < 1.0);
        assert!(c > 0.0 && c < 1.0);
        assert!(v > 0.0 && v < 1.0);
    }

    /// Arrays of different lengths are rejected.
    #[test]
    fn test_mismatched_lengths_are_rejected() {
        let labels_true = array![0, 0, 1];
        let labels_pred = array![0, 1];
        assert!(homogeneity_completeness_v_measure(&labels_true, &labels_pred, 1.0).is_err());
    }

    /// A negative `beta` is rejected.
    #[test]
    fn test_negative_beta_is_rejected() {
        let labels_true = array![0, 0, 1, 1];
        let labels_pred = array![0, 0, 1, 1];
        assert!(homogeneity_completeness_v_measure(&labels_true, &labels_pred, -1.0).is_err());
    }

    /// With a single class and a single cluster both entropies are zero and
    /// every score is defined as 1.
    #[test]
    fn test_single_class_single_cluster() {
        let labels_true = array![7, 7, 7, 7];
        let labels_pred = array![3, 3, 3, 3];
        let (h, c, v) = homogeneity_completeness_v_measure(&labels_true, &labels_pred, 1.0)
            .expect("constant labels are valid input");
        assert!((h - 1.0).abs() < 1e-10);
        assert!((c - 1.0).abs() < 1e-10);
        assert!((v - 1.0).abs() < 1e-10);
    }

    /// Lumping two classes into one cluster is complete but not homogeneous,
    /// so the V-measure collapses to 0.
    #[test]
    fn test_single_cluster_over_two_classes() {
        let labels_true = array![0, 0, 1, 1];
        let labels_pred = array![0, 0, 0, 0];
        let (h, c, v) = homogeneity_completeness_v_measure(&labels_true, &labels_pred, 1.0)
            .expect("single-cluster prediction is valid input");
        assert!(h.abs() < 1e-10);
        assert!((c - 1.0).abs() < 1e-10);
        assert!(v.abs() < 1e-10);
    }
}