klaster 0.2.0

Machine learning library providing modern clustering algorithms for the Rust programming language
Documentation
// Copyright (C) 2025 Piotr Jabłoński
// Extended copyright information can be found in the LICENSE file.

use crate::metric::ClusteringMetricInput;
use burn::prelude::*;
use burn::train::metric::state::{FormatOptions, NumericMetricState};
use burn::train::metric::{Metric, MetricEntry, MetricMetadata, Numeric};
use std::collections::HashMap;

/// Normalized Mutual Information (NMI) metric.
///
/// Wraps a [`NumericMetricState`] that is fed one NMI score per call to
/// `update`, together with the size of the batch the score was computed on.
#[derive(Default)]
pub(crate) struct NMIMetric<B>
where
    B: Backend,
{
    /// Numeric state that receives every per-batch score and exposes the reported value.
    state: NumericMetricState,
    /// Value of the backend type; it ties the metric to `B` and is not read in this file.
    _b: B,
}

impl<B> NMIMetric<B>
where
    B: Backend,
{
    /// Creates a metric whose fields are all set to their `Default` values.
    pub(crate) fn new() -> Self {
        Default::default()
    }
}

/// Shannon entropy, in bits, of the label distribution described by `counts`.
///
/// `total` is the number of samples the counts were collected from.
fn label_entropy<K>(counts: &HashMap<K, f64>, total: f64) -> f64 {
    counts
        .values()
        .map(|&count| {
            let p = count / total;
            -p * p.log2()
        })
        .sum()
}

/// Compute Normalized Mutual Information (NMI) between two labelings.
///
/// The mutual information of the two label assignments is divided by the
/// geometric mean of their entropies, `sqrt(H(y_true) * H(y_pred))`. The score
/// is symmetric in its arguments and depends only on how the samples are
/// grouped, not on the label values themselves.
///
/// # Arguments
/// * `y_pred` - predicted cluster label of every sample.
/// * `y_true` - reference label of every sample, in the same order as `y_pred`.
///
/// # Returns
/// NMI in \[0, 1\].
///
/// Degenerate inputs are handled as follows:
/// * both labelings put every sample into a single group (or both are empty):
///   the partitions are identical, so the score is `1.0`;
/// * exactly one labeling consists of a single group: it carries no
///   information about the other, so the score is `0.0`.
///
/// # Panics
/// Panics if `y_pred` and `y_true` have different lengths.
pub fn nmi_score<T>(y_pred: &[T], y_true: &[T]) -> f64
where
    T: std::cmp::Eq + std::hash::Hash + Copy,
{
    assert_eq!(
        y_pred.len(),
        y_true.len(),
        "y_pred and y_true must have the same length"
    );
    let n = y_true.len() as f64;

    let mut true_counts: HashMap<T, f64> = HashMap::new();
    let mut pred_counts: HashMap<T, f64> = HashMap::new();
    let mut joint_counts: HashMap<(T, T), f64> = HashMap::new();

    for (&true_label, &pred_label) in y_true.iter().zip(y_pred) {
        *true_counts.entry(true_label).or_default() += 1.0;
        *pred_counts.entry(pred_label).or_default() += 1.0;
        *joint_counts.entry((true_label, pred_label)).or_default() += 1.0;
    }

    let true_is_trivial = true_counts.len() <= 1;
    let pred_is_trivial = pred_counts.len() <= 1;

    // Neither labeling splits the data: the two partitions coincide, which is
    // a perfect match even though both entropies are zero.
    if true_is_trivial && pred_is_trivial {
        return 1.0;
    }
    // Only one side is a single group: its entropy is zero, so the
    // normalisation below would divide by zero. Mutual information is zero too.
    if true_is_trivial || pred_is_trivial {
        return 0.0;
    }

    let h_true = label_entropy(&true_counts, n);
    let h_pred = label_entropy(&pred_counts, n);

    let mi = joint_counts
        .iter()
        .map(|(&(y, c), &joint_count)| {
            let p_joint = joint_count / n;
            let p_y = true_counts[&y] / n;
            let p_c = pred_counts[&c] / n;
            p_joint * (p_joint / (p_y * p_c)).log2()
        })
        .sum::<f64>();

    // Floating-point rounding can push the ratio marginally outside the
    // documented range (e.g. a tiny negative value for independent labelings).
    (mi / (h_true * h_pred).sqrt()).clamp(0.0, 1.0)
}

impl<B: Backend> Metric for NMIMetric<B> {
    type Input = ClusteringMetricInput<B>;

    /// Name under which this metric is reported.
    fn name(&self) -> String {
        "NMI".to_string()
    }

    /// Scores the batch in `input` with [`nmi_score`] and feeds the result,
    /// together with the batch size, into the numeric state.
    ///
    /// The returned entry is formatted with two decimal places.
    fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
        // `nmi_score` is declared as `(y_pred, y_true)`; pass the labels in that
        // order. The score is symmetric, so the reported value is unaffected.
        let score = nmi_score(&input.y_pred(), &input.y_true());

        self.state.update(
            score,
            input.batch_size(),
            FormatOptions::new(self.name()).precision(2),
        )
    }

    /// Resets the numeric state.
    fn clear(&mut self) {
        self.state.reset();
    }
}

impl<B> Numeric for NMIMetric<B>
where
    B: Backend,
{
    /// Returns the value currently held by the underlying numeric state.
    fn value(&self) -> f64 {
        let Self { state, .. } = self;
        state.value()
    }
}

#[cfg(test)]
mod tests {
    use super::nmi_score;

    /// Identical partitions score exactly 1.0, whatever the clusters are called.
    #[test]
    fn perfect_labelings() {
        let reference = [0, 0, 1, 1];
        let same_names = [0, 0, 1, 1];
        let swapped_names = [1, 1, 0, 0];

        for candidate in [same_names, swapped_names] {
            assert_eq!(nmi_score(&reference, &candidate), 1.0);
        }
    }

    /// A single cluster compared against all-singleton labels scores 0.0.
    #[test]
    fn totally_incomplete() {
        let single_cluster = [0, 0, 0, 0];
        let all_singletons = [0, 1, 2, 3];

        let score = nmi_score(&single_cluster, &all_singletons);

        assert_eq!(score, 0.0);
    }
}