infomeasure 0.1.0-alpha.1

Information theory measures and entropy calculations for Rust
Documentation
// SPDX-FileCopyrightText: 2025-2026 Carlson Büth <code@cbueth.de>
//
// SPDX-License-Identifier: MIT OR Apache-2.0

use kiddo::{ImmutableKdTree, SquaredEuclidean};
use ndarray::{Array2, ArrayView2};
use rand::prelude::*;
use rand_distr::Normal;
use std::num::NonZeroUsize;

/// Perturb every element of a 2D array with independent zero-mean Gaussian noise.
///
/// `noise_level` is used as the standard deviation of the normal distribution.
/// When `noise_level <= 0.0` the input is handed back unchanged and no random
/// numbers are drawn.
///
/// # Panics
///
/// Panics if `Normal::new` returns an error for the given `noise_level`.
pub(crate) fn add_noise(mut data: Array2<f64>, noise_level: f64) -> Array2<f64> {
    if noise_level <= 0.0 {
        // Non-positive level: nothing to add.
        return data;
    }
    let gaussian = Normal::new(0.0, noise_level).unwrap();
    let mut generator = thread_rng();
    data.iter_mut()
        .for_each(|value| *value += gaussian.sample(&mut generator));
    data
}

/// Compute the volume of the unit m-ball in R^m.
///
/// The closed form is `c_m = pi^{m/2} / Gamma(m/2 + 1)`. Instead of evaluating
/// the gamma function (whose value overflows `f64` for large `m`, which would
/// force the quotient to zero long before the true volume underflows), this uses
/// the equivalent two-step recurrence
///
/// `c_0 = 1`, `c_1 = 2`, `c_m = (2 * pi / m) * c_{m-2}`.
///
/// Every intermediate value is itself a unit-ball volume, so nothing overflows
/// and the result is accurate to floating-point rounding.
///
/// Returns `0.0` only once the true volume is too small to be represented.
pub(crate) fn unit_ball_volume(m: usize) -> f64 {
    use std::f64::consts::TAU;

    // Seed with the base case that has the same parity as `m`.
    let (mut volume, mut dim) = if m % 2 == 0 { (1.0_f64, 2_usize) } else { (2.0_f64, 3_usize) };

    while dim <= m {
        volume *= TAU / dim as f64;
        if volume == 0.0 {
            // Underflowed: all higher dimensions are zero as well. Stopping here
            // also bounds the loop for arbitrarily large `m`.
            break;
        }
        dim += 2;
    }
    volume
}

/// Copy the rows of a `(rows, K)` view into a vector of fixed-size `[f64; K]` points.
///
/// When the view exposes its elements as one slice (`as_slice()` returns `Some`),
/// rows are copied chunk-wise; otherwise each element is read by `(row, col)` index.
///
/// # Panics
///
/// Panics if the view does not have exactly `K` columns.
fn to_points<const K: usize>(data: ArrayView2<'_, f64>) -> Vec<[f64; K]> {
    assert!(data.ncols() == K, "data.ncols() must equal K");
    match data.as_slice() {
        // Fast path: consecutive runs of K values are the rows.
        Some(flat) => flat
            .chunks_exact(K)
            .map(|row| {
                let mut point = [0.0; K];
                point.copy_from_slice(row);
                point
            })
            .collect(),
        // Fallback: gather every coordinate through 2D indexing.
        None => (0..data.nrows())
            .map(|row| std::array::from_fn(|col| data[(row, col)]))
            .collect(),
    }
}

/// Compute kNN radii (Euclidean distances to the k-th nearest neighbor).
///
/// A k-d tree is built from `data` and queried with squared Euclidean distances;
/// the square root of the k-th smallest one is reported per query point.
///
/// If `at` is None, it computes distances within the same dataset `data` (excluding self):
/// `k + 1` neighbors are requested because the query point itself is part of the tree.
/// If `at` is Some(target), it computes distances from points in `target` to their
/// k-th neighbors in `data`.
///
/// - data is shape (N, K)
/// - target is shape (M, K)
/// - k >= 1
///
/// Returns one radius per query point (N values for `at = None`, M values otherwise).
/// If `data` has no rows, an empty vector is returned regardless of `at`.
///
/// # Panics
///
/// - if `k == 0`;
/// - if `data` (or the target) does not have exactly `K` columns;
/// - if `at` is None and `k >= N`;
/// - if `at` is Some and `k > N`.
pub(crate) fn knn_radii_at<const K: usize>(
    data: ArrayView2<'_, f64>,
    k: usize,
    at: Option<ArrayView2<'_, f64>>,
) -> Vec<f64> {
    assert!(k >= 1, "k must be >= 1");
    assert!(data.ncols() == K, "data.ncols() must equal K");
    let n = data.nrows();
    if n == 0 {
        return Vec::new();
    }

    let points = to_points::<K>(data);
    let tree: ImmutableKdTree<f64, K> = ImmutableKdTree::new_from_slice(&points);

    // Euclidean distance from `query` to its `rank`-th closest tree point (1-based).
    let radius_at_rank = |query: &[f64; K], rank: usize| -> f64 {
        let qty = NonZeroUsize::new(rank).expect("rank is at least 1 because k >= 1");
        let neighbours = tree.nearest_n::<SquaredEuclidean>(query, qty);
        // Results are ordered by distance, so the last requested one is the k-th.
        neighbours[rank - 1].distance.sqrt()
    };

    match at {
        Some(target_data) => {
            assert!(target_data.ncols() == K, "target_data.ncols() must equal K");
            // Fail with a clear message instead of an out-of-bounds index below.
            assert!(k <= n, "k must be <= N when querying another dataset");
            // No need to exclude self if we are querying another dataset
            to_points::<K>(target_data)
                .iter()
                .map(|query| radius_at_rank(query, k))
                .collect()
        }
        None => {
            assert!(
                k < n,
                "k must be <= N-1 when querying within the same dataset"
            );
            // Ask for k+1 neighbors: the first hit is the query point itself.
            points
                .iter()
                .map(|query| radius_at_rank(query, k + 1))
                .collect()
        }
    }
}

/// Distance from every row of `data` to its k-th nearest *other* row.
///
/// Convenience wrapper around `knn_radii_at` with no separate target set, so the
/// query points are the rows of `data` themselves and each point is excluded
/// from its own neighbor list.
///
/// `data` must have exactly `K` columns; the result holds one radius per row.
pub fn knn_radii<const K: usize>(data: ArrayView2<'_, f64>, k: usize) -> Vec<f64> {
    let no_target: Option<ArrayView2<'_, f64>> = None;
    knn_radii_at::<K>(data, k, no_target)
}

/// Gather the quantities shared by the exponential-family kNN estimators.
/// Counterpart of the Python `calculate_common_entropy_components`.
///
/// Returns the tuple `(v_m, rho_k, n, m)` where
/// - `v_m` is the unit-ball volume in `K` dimensions,
/// - `rho_k` are the k-th nearest-neighbor radii from `knn_radii_at`,
/// - `n` is the number of radii (N if `at` is None, M if `at` is Some(target)),
/// - `m` is the dimensionality `K`.
pub(crate) fn calculate_common_entropy_components_at<const K: usize>(
    data: ArrayView2<'_, f64>,
    k: usize,
    at: Option<ArrayView2<'_, f64>>,
) -> (f64, Vec<f64>, usize, usize) {
    let radii = knn_radii_at::<K>(data, k, at);
    // Take the count before the vector is moved into the result tuple.
    let evaluated = radii.len();
    (unit_ball_volume(K), radii, evaluated, K)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array2, array};

    /// Self-evaluation shorthand: common estimator components with no separate target set.
    pub(crate) fn calculate_common_entropy_components<const K: usize>(
        data: ArrayView2<'_, f64>,
        k: usize,
    ) -> (f64, Vec<f64>, usize, usize) {
        let no_target: Option<ArrayView2<'_, f64>> = None;
        calculate_common_entropy_components_at::<K>(data, k, no_target)
    }

    /// Compare the leading entries of `actual` with `expected` (tolerance 1e-12).
    /// Only as many entries as `expected` holds are inspected.
    fn assert_radii(actual: &[f64], expected: &[f64]) {
        for (i, &want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(actual[i], want, epsilon = 1e-12);
        }
    }

    /// Report whether the self-evaluation helper panics for the given input.
    fn self_query_panics<const K: usize>(data: &Array2<f64>, k: usize) -> bool {
        std::panic::catch_unwind(|| {
            let _ = calculate_common_entropy_components::<K>(data.view(), k);
        })
        .is_err()
    }

    #[test]
    fn test_add_noise_basic() {
        let original: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];

        // Positive noise level: the output should differ from the input
        // (identical noise-free output is vanishingly unlikely).
        let perturbed = add_noise(original.clone(), 0.1);
        assert!(perturbed != original);

        // Zero noise level: the output must equal the input.
        let untouched = add_noise(original.clone(), 0.0);
        assert_eq!(untouched, original);
    }

    #[test]
    fn unit_ball_volume_known_values() {
        use std::f64::consts::PI;

        // (dimension, expected volume, tolerance)
        let cases = [
            (1, 2.0, 1e-12),            // length of [-1, 1]
            (2, PI, 1e-12),             // area of the unit disc
            (3, 4.0 * PI / 3.0, 1e-6),  // volume of the unit sphere
        ];
        for &(m, expected, tol) in cases.iter() {
            assert_abs_diff_eq!(unit_ball_volume(m), expected, epsilon = tol);
        }
    }

    #[test]
    fn knn_radii_1d_simple_cases() {
        // Two points one unit apart, k=1 -> [1.0, 1.0]
        let pair: Array2<f64> = array![[1.0], [2.0]];
        assert_radii(&knn_radii::<1>(pair.view(), 1), &[1.0, 1.0]);

        // Three equally spaced points, k=1 -> [1.0, 1.0, 1.0]
        let triple: Array2<f64> = array![[1.0], [2.0], [3.0]];
        let nearest = knn_radii::<1>(triple.view(), 1);
        assert_eq!(nearest.len(), 3);
        assert_radii(&nearest, &[1.0, 1.0, 1.0]);

        // Same points, k=2 -> [2.0, 1.0, 2.0]
        let triple: Array2<f64> = array![[1.0], [2.0], [3.0]];
        assert_radii(&knn_radii::<1>(triple.view(), 2), &[2.0, 1.0, 2.0]);
    }

    #[test]
    #[should_panic]
    fn knn_radii_panics_when_k_too_large() {
        // N=3, but a self-query requires k <= N-1.
        let samples: Array2<f64> = array![[1.0], [2.0], [3.0]];
        let _ = knn_radii::<1>(samples.view(), 3);
    }

    #[test]
    fn calculate_common_entropy_components_matches_python_cases() {
        use std::f64::consts::PI;

        // ([[1.0],[2.0]], k=1) -> V_m=2.0, rho_k=[1,1], N=2, m=1
        let samples: Array2<f64> = array![[1.0], [2.0]];
        let (volume, radii, count, dim) =
            calculate_common_entropy_components::<1>(samples.view(), 1);
        assert_abs_diff_eq!(volume, 2.0, epsilon = 1e-12);
        assert_eq!(count, 2);
        assert_eq!(dim, 1);
        assert_eq!(radii.len(), 2);
        assert_radii(&radii, &[1.0, 1.0]);

        // ([[1.0],[2.0],[3.0]], k=1) -> V_m=2.0, rho_k=[1,1,1], N=3, m=1
        let samples: Array2<f64> = array![[1.0], [2.0], [3.0]];
        let (volume, radii, count, dim) =
            calculate_common_entropy_components::<1>(samples.view(), 1);
        assert_abs_diff_eq!(volume, 2.0, epsilon = 1e-12);
        assert_eq!(count, 3);
        assert_eq!(dim, 1);
        assert_eq!(radii.len(), 3);
        assert_radii(&radii, &[1.0, 1.0, 1.0]);

        // ([[1.0],[2.0],[3.0]], k=2) -> rho_k=[2,1,2]
        let samples: Array2<f64> = array![[1.0], [2.0], [3.0]];
        let (_volume, radii, _count, _dim) =
            calculate_common_entropy_components::<1>(samples.view(), 2);
        assert_radii(&radii, &[2.0, 1.0, 2.0]);

        // ([[1.0,2.0],[2.0,3.0]], k=1) -> V_m=pi, rho_k=[sqrt(2), sqrt(2)], N=2, m=2
        let samples: Array2<f64> = array![[1.0, 2.0], [2.0, 3.0]];
        let (volume, radii, count, dim) =
            calculate_common_entropy_components::<2>(samples.view(), 1);
        assert_abs_diff_eq!(volume, PI, epsilon = 1e-12);
        assert_eq!(count, 2);
        assert_eq!(dim, 2);
        assert_eq!(radii.len(), 2);
        let root2 = 2.0_f64.sqrt();
        assert_radii(&radii, &[root2, root2]);

        // ([[1.0,2.0,3.0],[2.0,3.0,4.0]], k=1) -> V_m≈4.188790, rho_k=[sqrt(3), sqrt(3)], N=2, m=3
        let samples: Array2<f64> = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let (volume, radii, count, dim) =
            calculate_common_entropy_components::<3>(samples.view(), 1);
        assert_abs_diff_eq!(volume, 4.0 * PI / 3.0, epsilon = 1e-6);
        assert_eq!(count, 2);
        assert_eq!(dim, 3);
        assert_eq!(radii.len(), 2);
        let root3 = 3.0_f64.sqrt();
        assert_radii(&radii, &[root3, root3]);
    }

    #[test]
    fn calculate_common_entropy_components_panics_when_k_too_large() {
        // ([[0.0]], 1) -> k too large because N-1 = 0
        let single: Array2<f64> = array![[0.0]];
        assert!(self_query_panics::<1>(&single, 1));

        // ([[1.0],[2.0],[3.0]], 3) and 4
        let triple: Array2<f64> = array![[1.0], [2.0], [3.0]];
        assert!(self_query_panics::<1>(&triple, 3));
        assert!(self_query_panics::<1>(&triple, 4));

        // ([[1.0,2.0,3.0],[2.0,3.0,4.0]], 3)
        let pair_3d: Array2<f64> = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        assert!(self_query_panics::<3>(&pair_3d, 3));
    }
}