use crate::metric::ClusteringMetricInput;
use burn::prelude::*;
use burn::train::metric::state::{FormatOptions, NumericMetricState};
use burn::train::metric::{Metric, MetricEntry, MetricMetadata, Numeric};
use std::collections::HashMap;
/// Metric tracking the normalized mutual information (NMI) between predicted
/// and ground-truth cluster assignments.
///
/// The backend type parameter is only needed to name the metric's input type,
/// so it is carried as a zero-sized marker rather than a stored backend value.
#[derive(Default)]
pub(crate) struct NMIMetric<B: Backend> {
    /// Accumulates the per-batch scores reported by this metric.
    state: NumericMetricState,
    /// Ties the struct to `B` without storing an instance of it.
    _b: std::marker::PhantomData<B>,
}

impl<B: Backend> NMIMetric<B> {
    /// Creates a new metric with default (fresh) state.
    pub(crate) fn new() -> Self {
        Self::default()
    }
}
/// Normalized mutual information (NMI) between two labelings of the same items.
///
/// Computes `I(U; V) / sqrt(H(U) * H(V))` with base-2 logarithms, i.e. the
/// mutual information normalized by the geometric mean of the two entropies.
/// The score is symmetric in its arguments and depends only on the partitions
/// the labelings induce, not on the concrete label values, so `[0, 0, 1]` and
/// `[5, 5, 2]` describe the same clustering.
///
/// Limit cases:
/// - Both labelings have at most one distinct label (including empty input):
///   they describe the same trivial partition, so the score is `1.0`.
/// - Exactly one labeling has a single distinct label: its entropy is zero,
///   no information is shared, and the score is `0.0`.
///
/// The result is clamped to `[0.0, 1.0]` to absorb floating-point rounding.
///
/// # Panics
///
/// Panics if `y_pred` and `y_true` have different lengths.
pub fn nmi_score<T>(y_pred: &[T], y_true: &[T]) -> f64
where
    T: std::cmp::Eq + std::hash::Hash,
{
    // Shannon entropy, in bits, of a labeling given its per-label counts.
    fn entropy<K>(counts: &HashMap<K, usize>, n: f64) -> f64 {
        counts
            .values()
            .map(|&count| {
                let p = count as f64 / n;
                -p * p.log2()
            })
            .sum()
    }

    assert_eq!(
        y_pred.len(),
        y_true.len(),
        "nmi_score: both labelings must have the same length"
    );

    // Marginal and joint label counts. Keys borrow from the inputs, so labels
    // only need `Eq + Hash` (no `Copy`).
    let mut true_counts: HashMap<&T, usize> = HashMap::new();
    let mut pred_counts: HashMap<&T, usize> = HashMap::new();
    let mut joint_counts: HashMap<(&T, &T), usize> = HashMap::new();
    for (true_label, pred_label) in y_true.iter().zip(y_pred) {
        *true_counts.entry(true_label).or_insert(0) += 1;
        *pred_counts.entry(pred_label).or_insert(0) += 1;
        *joint_counts.entry((true_label, pred_label)).or_insert(0) += 1;
    }

    // A labeling with <= 1 distinct label has zero entropy, which would make the
    // normalizer zero. Two such labelings (of equal length) are the same trivial
    // partition, a perfect match; if only one is trivial, nothing is shared.
    match (true_counts.len() <= 1, pred_counts.len() <= 1) {
        (true, true) => return 1.0,
        (true, false) | (false, true) => return 0.0,
        (false, false) => {}
    }

    let n = y_true.len() as f64;
    let h_true = entropy(&true_counts, n);
    let h_pred = entropy(&pred_counts, n);
    let mi: f64 = joint_counts
        .iter()
        .map(|(&(true_label, pred_label), &joint_count)| {
            let p_joint = joint_count as f64 / n;
            let p_true = true_counts[true_label] as f64 / n;
            let p_pred = pred_counts[pred_label] as f64 / n;
            p_joint * (p_joint / (p_true * p_pred)).log2()
        })
        .sum();

    // Rounding can push the ratio marginally outside [0, 1].
    (mi / (h_true * h_pred).sqrt()).clamp(0.0, 1.0)
}
impl<B: Backend> Metric for NMIMetric<B> {
    type Input = ClusteringMetricInput<B>;

    /// Name under which this metric's values are formatted.
    fn name(&self) -> String {
        "NMI".to_string()
    }

    /// Scores one batch and folds it into the accumulated state.
    fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
        // Pass arguments in `nmi_score`'s declared order (predictions first).
        // NMI is symmetric, so this is for readability, not correctness.
        let score = nmi_score(&input.y_pred(), &input.y_true());
        self.state.update(
            score,
            input.batch_size(),
            FormatOptions::new(self.name()).precision(2),
        )
    }

    /// Discards all accumulated state.
    fn clear(&mut self) {
        self.state.reset();
    }
}
impl<B: Backend> Numeric for NMIMetric<B> {
    /// Current numeric value, as reported by the underlying metric state.
    fn value(&self) -> f64 {
        let Self { state, .. } = self;
        state.value()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Labelings describing the same partition score exactly 1, regardless of
    /// which concrete label values are used.
    #[test]
    fn perfect_labelings() {
        let cases: [(&[i32], &[i32]); 2] = [
            (&[0, 0, 1, 1], &[0, 0, 1, 1]),
            (&[0, 0, 1, 1], &[1, 1, 0, 0]),
        ];
        for &(pred, truth) in cases.iter() {
            assert_eq!(nmi_score(pred, truth), 1.0);
        }
    }

    /// A labeling that lumps every item into one cluster shares no
    /// information with one that gives every item its own label.
    #[test]
    fn totally_incomplete() {
        let single_cluster = [0; 4];
        let all_distinct = [0, 1, 2, 3];
        assert_eq!(nmi_score(&single_cluster, &all_distinct), 0.0);
    }
}