use crate::error::{ClusterError, ClusterResult};
use std::collections::HashMap;
use torsh_tensor::Tensor;
use super::utils::compute_entropy;
/// Computes the normalized mutual information (NMI) between two labelings.
///
/// The mutual information is accumulated with natural logarithms from the
/// joint and marginal label frequencies, then divided by the square root of
/// the product of the two label entropies (as returned by `compute_entropy`).
/// The quotient is clamped to `[0.0, 1.0]`.
///
/// Label values are converted to `i32` with an `as` cast before counting.
///
/// # Errors
///
/// * `ClusterError::TensorError` if either tensor cannot be turned into a vector.
/// * `ClusterError::InvalidInput` if the two labelings differ in length.
///
/// # Special cases
///
/// * With at most one sample the score is `1.0`.
/// * If either entropy is exactly zero, the score is `1.0` when the two
///   entropies are equal and `0.0` otherwise.
pub fn normalized_mutual_info_score(
    labels_true: &Tensor,
    labels_pred: &Tensor,
) -> ClusterResult<f64> {
    let reference = labels_true.to_vec().map_err(ClusterError::TensorError)?;
    let candidate = labels_pred.to_vec().map_err(ClusterError::TensorError)?;
    if reference.len() != candidate.len() {
        return Err(ClusterError::InvalidInput(
            "Labels must have the same length".to_string(),
        ));
    }

    let n = reference.len();
    if n <= 1 {
        return Ok(1.0);
    }

    // Tally joint (class, cluster) pairs together with both marginals.
    let mut joint: HashMap<(i32, i32), usize> = HashMap::new();
    let mut class_sizes: HashMap<i32, usize> = HashMap::new();
    let mut cluster_sizes: HashMap<i32, usize> = HashMap::new();
    for (&raw_class, &raw_cluster) in reference.iter().zip(candidate.iter()) {
        let class_id = raw_class as i32;
        let cluster_id = raw_cluster as i32;
        *joint.entry((class_id, cluster_id)).or_insert(0) += 1;
        *class_sizes.entry(class_id).or_insert(0) += 1;
        *cluster_sizes.entry(cluster_id).or_insert(0) += 1;
    }

    // MI = sum over observed pairs of p(c, k) * ln(p(c, k) / (p(c) * p(k))).
    let mi = joint
        .iter()
        .filter(|&(_, &count)| count > 0)
        .fold(0.0_f64, |acc, (&(class_id, cluster_id), &count)| {
            let p_joint = count as f64 / n as f64;
            let p_true = class_sizes[&class_id] as f64 / n as f64;
            let p_pred = cluster_sizes[&cluster_id] as f64 / n as f64;
            acc + p_joint * (p_joint / (p_true * p_pred)).ln()
        });

    let entropy_true = compute_entropy(&class_sizes, n);
    let entropy_pred = compute_entropy(&cluster_sizes, n);

    let both_informative = entropy_true != 0.0 && entropy_pred != 0.0;
    if both_informative {
        let nmi = mi / (entropy_true * entropy_pred).sqrt();
        return Ok(nmi.clamp(0.0, 1.0));
    }

    // At least one entropy is zero: the geometric mean would be zero, so the
    // score is decided by whether the two entropies agree.
    Ok(if entropy_true == entropy_pred { 1.0 } else { 0.0 })
}
pub fn adjusted_mutual_info_score(
labels_true: &Tensor,
labels_pred: &Tensor,
) -> ClusterResult<f64> {
let true_vec = labels_true.to_vec().map_err(ClusterError::TensorError)?;
let pred_vec = labels_pred.to_vec().map_err(ClusterError::TensorError)?;
if true_vec.len() != pred_vec.len() {
return Err(ClusterError::InvalidInput(
"Labels must have the same length".to_string(),
));
}
let n = true_vec.len() as f64;
if n <= 1.0 {
return Ok(1.0); }
let true_labels: Vec<i32> = true_vec.iter().map(|&x| x as i32).collect();
let pred_labels: Vec<i32> = pred_vec.iter().map(|&x| x as i32).collect();
let mut contingency_table: HashMap<(i32, i32), usize> = HashMap::new();
let mut true_counts: HashMap<i32, usize> = HashMap::new();
let mut pred_counts: HashMap<i32, usize> = HashMap::new();
for (t, p) in true_labels.iter().zip(pred_labels.iter()) {
*contingency_table.entry((*t, *p)).or_insert(0) += 1;
*true_counts.entry(*t).or_insert(0) += 1;
*pred_counts.entry(*p).or_insert(0) += 1;
}
let mut mi = 0.0;
for (&(true_label, pred_label), &joint_count) in &contingency_table {
if joint_count > 0 {
let p_joint = joint_count as f64 / n;
let p_true = true_counts[&true_label] as f64 / n;
let p_pred = pred_counts[&pred_label] as f64 / n;
mi += p_joint * (p_joint / (p_true * p_pred)).ln();
}
}
let entropy_true = compute_entropy(&true_counts, n as usize);
let entropy_pred = compute_entropy(&pred_counts, n as usize);
let emi = compute_expected_mutual_info(&true_counts, &pred_counts, n as usize);
let normalizer = 0.5 * (entropy_true + entropy_pred);
if normalizer.abs() < f64::EPSILON {
Ok(0.0) } else {
let ami = (mi - emi) / (normalizer - emi);
Ok(ami.clamp(0.0, 1.0)) }
}
/// Computes the expected mutual information (EMI) of two labelings whose
/// group sizes are fixed, assuming the contingency table follows the
/// hypergeometric (random permutation) model.
///
/// For every pair of marginal counts `a` (from `true_counts`) and `b` (from
/// `pred_counts`) the function sums, over every feasible joint count
/// `nij` in `max(1, a + b - n) ..= min(a, b)`:
///
/// ```text
/// (nij / n) * ln(n * nij / (a * b)) * P(nij)
/// P(nij) = a! b! (n-a)! (n-b)! / (n! nij! (a-nij)! (b-nij)! (n-a-b+nij)!)
/// ```
///
/// Natural logarithms are used, matching the mutual-information sums in this
/// file. The probabilities are evaluated in log space from a table of
/// `ln(k!)` so large factorials do not overflow.
///
/// # Parameters
///
/// * `true_counts` - size of each group in the first labeling.
/// * `pred_counts` - size of each group in the second labeling.
/// * `n` - total number of samples.
///
/// # Returns
///
/// The expected mutual information. Returns `0.0` when `n` is zero. Marginal
/// counts that are zero or larger than `n` contribute nothing.
fn compute_expected_mutual_info(
    true_counts: &HashMap<i32, usize>,
    pred_counts: &HashMap<i32, usize>,
    n: usize,
) -> f64 {
    if n == 0 {
        return 0.0;
    }

    // ln_fact[k] == ln(k!) for k in 0..=n, built as a running sum of ln(k).
    let mut ln_fact: Vec<f64> = Vec::with_capacity(n + 1);
    let mut running = 0.0_f64;
    ln_fact.push(running);
    for k in 1..=n {
        running += (k as f64).ln();
        ln_fact.push(running);
    }

    let total = n as f64;
    let mut emi = 0.0;
    for &a in true_counts.values() {
        // Skip marginals that cannot belong to a table with `n` samples.
        if a == 0 || a > n {
            continue;
        }
        for &b in pred_counts.values() {
            if b == 0 || b > n {
                continue;
            }

            // Part of ln P(nij) that does not depend on nij.
            let ln_fixed =
                ln_fact[a] + ln_fact[b] + ln_fact[n - a] + ln_fact[n - b] - ln_fact[n];
            let marginal_product = (a as f64) * (b as f64);

            // nij = 0 contributes nothing, so the sum starts at 1 at the lowest.
            let lowest = (a + b).saturating_sub(n).max(1);
            let highest = a.min(b);
            for nij in lowest..=highest {
                let nij_f = nij as f64;
                let information = (nij_f / total) * ((total * nij_f) / marginal_product).ln();
                let ln_probability = ln_fixed
                    - ln_fact[nij]
                    - ln_fact[a - nij]
                    - ln_fact[b - nij]
                    - ln_fact[n + nij - a - b];
                emi += information * ln_probability.exp();
            }
        }
    }
    emi
}
/// Computes the homogeneity of a predicted labeling against true labels.
///
/// The score is `1 - H(true | pred) / H(true)`, clamped to `[0.0, 1.0]`.
/// `H(true)` is the entropy of the true-label counts, and `H(true | pred)` is
/// the sum, over predicted clusters, of the cluster's share of the samples
/// times the entropy of the true labels inside that cluster. All entropies
/// are obtained from `compute_entropy`.
///
/// Label values are converted to `i32` with an `as` cast before counting.
///
/// # Errors
///
/// * `ClusterError::TensorError` if either tensor cannot be turned into a vector.
/// * `ClusterError::InvalidInput` if the two labelings differ in length.
///
/// # Special cases
///
/// * With at most one sample the score is `1.0`.
/// * If the entropy of the true labels is exactly zero the score is `1.0`.
pub fn homogeneity_score(labels_true: &Tensor, labels_pred: &Tensor) -> ClusterResult<f64> {
    let class_values = labels_true.to_vec().map_err(ClusterError::TensorError)?;
    let cluster_values = labels_pred.to_vec().map_err(ClusterError::TensorError)?;
    if class_values.len() != cluster_values.len() {
        return Err(ClusterError::InvalidInput(
            "Labels must have the same length".to_string(),
        ));
    }

    let n = class_values.len();
    if n <= 1 {
        return Ok(1.0);
    }

    // Tally joint (class, cluster) pairs together with both marginals.
    let mut pair_counts: HashMap<(i32, i32), usize> = HashMap::new();
    let mut class_sizes: HashMap<i32, usize> = HashMap::new();
    let mut cluster_sizes: HashMap<i32, usize> = HashMap::new();
    for (&raw_class, &raw_cluster) in class_values.iter().zip(cluster_values.iter()) {
        let class_id = raw_class as i32;
        let cluster_id = raw_cluster as i32;
        *pair_counts.entry((class_id, cluster_id)).or_insert(0) += 1;
        *class_sizes.entry(class_id).or_insert(0) += 1;
        *cluster_sizes.entry(cluster_id).or_insert(0) += 1;
    }

    let entropy_true = compute_entropy(&class_sizes, n);
    if entropy_true == 0.0 {
        return Ok(1.0);
    }

    // H(true | pred): weight each cluster's internal class entropy by the
    // fraction of samples that fall into that cluster.
    let conditional_entropy = cluster_sizes
        .iter()
        .filter(|&(_, &size)| size > 0)
        .fold(0.0_f64, |acc, (&cluster_id, &size)| {
            let weight = size as f64 / n as f64;
            let classes_in_cluster: HashMap<i32, usize> = pair_counts
                .iter()
                .filter(|&(&(_, member_cluster), _)| member_cluster == cluster_id)
                .map(|(&(class_id, _), &count)| (class_id, count))
                .collect();
            acc + weight * compute_entropy(&classes_in_cluster, size)
        });

    let homogeneity = 1.0 - (conditional_entropy / entropy_true);
    Ok(homogeneity.clamp(0.0, 1.0))
}
/// Computes the completeness of a predicted labeling against true labels.
///
/// The score is `1 - H(pred | true) / H(pred)`, clamped to `[0.0, 1.0]`.
/// `H(pred)` is the entropy of the predicted-label counts, and
/// `H(pred | true)` is the sum, over true classes, of the class's share of the
/// samples times the entropy of the predicted labels inside that class. All
/// entropies are obtained from `compute_entropy`.
///
/// Label values are converted to `i32` with an `as` cast before counting.
///
/// # Errors
///
/// * `ClusterError::TensorError` if either tensor cannot be turned into a vector.
/// * `ClusterError::InvalidInput` if the two labelings differ in length.
///
/// # Special cases
///
/// * With at most one sample the score is `1.0`.
/// * If the entropy of the predicted labels is exactly zero the score is `1.0`.
pub fn completeness_score(labels_true: &Tensor, labels_pred: &Tensor) -> ClusterResult<f64> {
    let class_values = labels_true.to_vec().map_err(ClusterError::TensorError)?;
    let cluster_values = labels_pred.to_vec().map_err(ClusterError::TensorError)?;
    if class_values.len() != cluster_values.len() {
        return Err(ClusterError::InvalidInput(
            "Labels must have the same length".to_string(),
        ));
    }

    let n = class_values.len();
    if n <= 1 {
        return Ok(1.0);
    }

    // Tally joint (class, cluster) pairs together with both marginals.
    let mut pair_counts: HashMap<(i32, i32), usize> = HashMap::new();
    let mut class_sizes: HashMap<i32, usize> = HashMap::new();
    let mut cluster_sizes: HashMap<i32, usize> = HashMap::new();
    for (&raw_class, &raw_cluster) in class_values.iter().zip(cluster_values.iter()) {
        let class_id = raw_class as i32;
        let cluster_id = raw_cluster as i32;
        *pair_counts.entry((class_id, cluster_id)).or_insert(0) += 1;
        *class_sizes.entry(class_id).or_insert(0) += 1;
        *cluster_sizes.entry(cluster_id).or_insert(0) += 1;
    }

    let entropy_pred = compute_entropy(&cluster_sizes, n);
    if entropy_pred == 0.0 {
        return Ok(1.0);
    }

    // H(pred | true): weight each class's internal cluster entropy by the
    // fraction of samples that belong to that class.
    let conditional_entropy = class_sizes
        .iter()
        .filter(|&(_, &size)| size > 0)
        .fold(0.0_f64, |acc, (&class_id, &size)| {
            let weight = size as f64 / n as f64;
            let clusters_in_class: HashMap<i32, usize> = pair_counts
                .iter()
                .filter(|&(&(member_class, _), _)| member_class == class_id)
                .map(|(&(_, cluster_id), &count)| (cluster_id, count))
                .collect();
            acc + weight * compute_entropy(&clusters_in_class, size)
        });

    let completeness = 1.0 - (conditional_entropy / entropy_pred);
    Ok(completeness.clamp(0.0, 1.0))
}
/// Computes the V-measure: the harmonic mean of homogeneity and completeness.
///
/// Both component scores are obtained from `homogeneity_score` and
/// `completeness_score`; any error they return is propagated unchanged.
///
/// # Returns
///
/// `2 * h * c / (h + c)`, or `0.0` when `h + c` is exactly zero.
pub fn v_measure_score(labels_true: &Tensor, labels_pred: &Tensor) -> ClusterResult<f64> {
    let h = homogeneity_score(labels_true, labels_pred)?;
    let c = completeness_score(labels_true, labels_pred)?;

    // The harmonic mean is undefined when both components vanish.
    let combined = h + c;
    if combined == 0.0 {
        return Ok(0.0);
    }
    Ok(2.0 * h * c / combined)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Two identical labelings must produce an NMI of 1.
    #[test]
    fn test_nmi_perfect_match() -> Result<(), Box<dyn std::error::Error>> {
        let reference = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;
        let candidate = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;

        let score = normalized_mutual_info_score(&reference, &candidate)?;

        assert_relative_eq!(score, 1.0, epsilon = 1e-6);
        Ok(())
    }

    /// Two identical labelings must produce a V-measure of 1.
    #[test]
    fn test_v_measure_perfect_match() -> Result<(), Box<dyn std::error::Error>> {
        let reference = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;
        let candidate = Tensor::from_vec(vec![0.0, 0.0, 1.0, 1.0], &[4])?;

        let score = v_measure_score(&reference, &candidate)?;

        assert_relative_eq!(score, 1.0, epsilon = 1e-6);
        Ok(())
    }
}