use kiddo::{ImmutableKdTree, SquaredEuclidean};
use ndarray::{Array2, ArrayView2};
use rand::prelude::*;
use rand_distr::Normal;
use std::num::NonZeroUsize;
/// Adds independent Gaussian noise with mean `0.0` and standard deviation `noise_level`
/// to every element of `data`, returning the modified array.
///
/// If `noise_level <= 0.0` the array is returned unchanged. Samples come from
/// `thread_rng()`, so the result is not reproducible between calls.
///
/// # Panics
/// Panics if `noise_level` is not accepted by `Normal::new` as a standard deviation
/// (e.g. NaN, which slips past the `<= 0.0` check).
pub(crate) fn add_noise(mut data: Array2<f64>, noise_level: f64) -> Array2<f64> {
    if noise_level <= 0.0 {
        return data;
    }
    let normal = Normal::new(0.0, noise_level)
        .expect("noise_level must be a valid standard deviation (finite, non-NaN)");
    let mut rng = thread_rng();
    data.mapv_inplace(|x| x + normal.sample(&mut rng));
    data
}
/// Volume of the unit ball in `m` dimensions: `pi^(m/2) / Gamma(m/2 + 1)`.
pub(crate) fn unit_ball_volume(m: usize) -> f64 {
    use statrs::function::gamma::gamma;
    let half_dim = m as f64 / 2.0;
    std::f64::consts::PI.powf(half_dim) / gamma(half_dim + 1.0)
}
/// Copies each row of `data` into a fixed-size `[f64; K]`, preserving row order.
///
/// # Panics
/// Panics if `data.ncols() != K`.
fn to_points<const K: usize>(data: ArrayView2<'_, f64>) -> Vec<[f64; K]> {
    assert!(data.ncols() == K, "data.ncols() must equal K");
    let mut out: Vec<[f64; K]> = Vec::with_capacity(data.nrows());
    match data.as_slice() {
        // Standard (row-major, contiguous) layout: every K consecutive values are one row.
        Some(flat) => out.extend(flat.chunks_exact(K).map(|row| {
            let mut point = [0.0; K];
            point.copy_from_slice(row);
            point
        })),
        // Any other layout (e.g. a strided slice): copy row by row, element by element.
        None => out.extend(data.outer_iter().map(|row| {
            let mut point = [0.0; K];
            for (slot, value) in point.iter_mut().zip(row.iter()) {
                *slot = *value;
            }
            point
        })),
    }
    out
}
/// Euclidean distance from each query point to its `k`-th nearest neighbour in `data`.
///
/// * `at = None`: every row of `data` is queried against `data` itself. `k + 1` neighbours
///   are requested so that the point's own zero-distance match does not count. The result
///   has one entry per row of `data`.
/// * `at = Some(targets)`: every row of `targets` is queried against `data` and nothing is
///   skipped, so a target that coincides with a data point can get radius `0.0`. The result
///   has one entry per row of `targets`.
///
/// If `data` has no rows, an empty vector is returned (even when `at` has rows).
///
/// # Panics
/// * `k == 0`;
/// * `data` (or `targets`) does not have exactly `K` columns;
/// * querying `data` itself with `k >= data.nrows()`;
/// * querying `targets` with `k > data.nrows()`.
pub(crate) fn knn_radii_at<const K: usize>(
    data: ArrayView2<'_, f64>,
    k: usize,
    at: Option<ArrayView2<'_, f64>>,
) -> Vec<f64> {
    assert!(k >= 1, "k must be >= 1");
    assert!(data.ncols() == K, "data.ncols() must equal K");
    let n = data.nrows();
    if n == 0 {
        return Vec::new();
    }
    let points = to_points::<K>(data);
    let tree: ImmutableKdTree<f64, K> = ImmutableKdTree::new_from_slice(&points);
    match at {
        Some(target_data) => {
            assert!(target_data.ncols() == K, "target_data.ncols() must equal K");
            // Without this check the tree silently returns fewer than k neighbours and the
            // failure surfaces as an unrelated index-out-of-bounds panic.
            assert!(k <= n, "k must be <= N when querying target points");
            to_points::<K>(target_data)
                .iter()
                .map(|p| kth_neighbour_distance(&tree, p, k))
                .collect()
        }
        None => {
            assert!(
                k < n,
                "k must be <= N-1 when querying within the same dataset"
            );
            // Rank k + 1: the query point is itself in the tree at distance 0.
            points
                .iter()
                .map(|p| kth_neighbour_distance(&tree, p, k + 1))
                .collect()
        }
    }
}

/// Euclidean distance from `query` to its `rank`-th nearest point stored in `tree`
/// (1-based; a stored copy of `query` itself counts as a neighbour at distance 0).
///
/// # Panics
/// Panics if `rank == 0` or if the tree holds fewer than `rank` points.
fn kth_neighbour_distance<const K: usize>(
    tree: &ImmutableKdTree<f64, K>,
    query: &[f64; K],
    rank: usize,
) -> f64 {
    let max_qty = NonZeroUsize::new(rank).expect("rank must be >= 1");
    // Results are ordered by increasing distance, so the last requested one is the rank-th.
    let neighbours = tree.nearest_n::<SquaredEuclidean>(query, max_qty);
    let (dist2, _idx): (f64, u64) = neighbours
        .into_iter()
        .nth(rank - 1)
        .expect("tree returned fewer neighbours than requested")
        .into();
    dist2.sqrt()
}
/// Euclidean distance from every row of `data` to its `k`-th nearest *other* row.
///
/// Convenience wrapper around `knn_radii_at` without a separate set of query points.
///
/// # Panics
/// Panics if `k == 0`, if `data.ncols() != K`, or if `data` is non-empty and
/// `k >= data.nrows()`.
pub fn knn_radii<const K: usize>(data: ArrayView2<'_, f64>, k: usize) -> Vec<f64> {
    let query_points: Option<ArrayView2<'_, f64>> = None;
    knn_radii_at::<K>(data, k, query_points)
}
/// Computes the common entropy components for `data`.
///
/// Returns `(v_m, rho_k, n, m)`:
/// * `v_m` — volume of the unit ball in `K` dimensions (`unit_ball_volume(K)`);
/// * `rho_k` — k-th nearest-neighbour distances from `knn_radii_at(data, k, at)`;
/// * `n` — `rho_k.len()`;
/// * `m` — the dimension `K`.
///
/// # Panics
/// Propagates any panic from `knn_radii_at` (invalid `k` or column count).
pub(crate) fn calculate_common_entropy_components_at<const K: usize>(
    data: ArrayView2<'_, f64>,
    k: usize,
    at: Option<ArrayView2<'_, f64>>,
) -> (f64, Vec<f64>, usize, usize) {
    let radii = knn_radii_at::<K>(data, k, at);
    let sample_count = radii.len();
    let unit_volume = unit_ball_volume(K);
    (unit_volume, radii, sample_count, K)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{Array2, array, s};

    /// Test-only shorthand for the self-query form of
    /// `calculate_common_entropy_components_at`.
    pub(crate) fn calculate_common_entropy_components<const K: usize>(
        data: ArrayView2<'_, f64>,
        k: usize,
    ) -> (f64, Vec<f64>, usize, usize) {
        calculate_common_entropy_components_at::<K>(data, k, None)
    }

    #[test]
    fn test_add_noise_basic() {
        let data: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let noisy = add_noise(data.clone(), 0.1);
        assert!(noisy != data);
        let no_noise = add_noise(data.clone(), 0.0);
        assert_eq!(no_noise, data);
    }

    #[test]
    fn unit_ball_volume_known_values() {
        assert_abs_diff_eq!(unit_ball_volume(1), 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(unit_ball_volume(2), std::f64::consts::PI, epsilon = 1e-12);
        assert_abs_diff_eq!(
            unit_ball_volume(3),
            4.0 * std::f64::consts::PI / 3.0,
            epsilon = 1e-6
        );
    }

    #[test]
    fn knn_radii_1d_simple_cases() {
        let d2: Array2<f64> = array![[1.0], [2.0]];
        let r = knn_radii::<1>(d2.view(), 1);
        assert_abs_diff_eq!(r[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(r[1], 1.0, epsilon = 1e-12);
        let d3: Array2<f64> = array![[1.0], [2.0], [3.0]];
        let r = knn_radii::<1>(d3.view(), 1);
        assert_eq!(r.len(), 3);
        for v in r {
            assert_abs_diff_eq!(v, 1.0, epsilon = 1e-12);
        }
        let d3: Array2<f64> = array![[1.0], [2.0], [3.0]];
        let r = knn_radii::<1>(d3.view(), 2);
        assert_abs_diff_eq!(r[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(r[1], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(r[2], 2.0, epsilon = 1e-12);
    }

    #[test]
    fn knn_radii_handles_non_contiguous_views() {
        // Slicing off one column of a 2-column array gives a strided view, which forces
        // `to_points` onto its element-wise fallback path.
        let wide: Array2<f64> = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];
        let first_col = wide.slice(s![.., ..1]);
        assert!(first_col.as_slice().is_none());
        let r = knn_radii::<1>(first_col, 2);
        assert_eq!(r.len(), 3);
        assert_abs_diff_eq!(r[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(r[1], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(r[2], 2.0, epsilon = 1e-12);
    }

    #[test]
    #[should_panic]
    fn knn_radii_panics_when_k_too_large() {
        let d: Array2<f64> = array![[1.0], [2.0], [3.0]];
        let _ = knn_radii::<1>(d.view(), 3);
    }

    #[test]
    fn knn_radii_at_queries_external_points() {
        let data: Array2<f64> = array![[0.0], [1.0], [3.0]];
        let targets: Array2<f64> = array![[0.5], [3.0]];
        let r = knn_radii_at::<1>(data.view(), 1, Some(targets.view()));
        assert_eq!(r.len(), 2);
        assert_abs_diff_eq!(r[0], 0.5, epsilon = 1e-12);
        // A target that coincides with a data point is not excluded from its own query.
        assert_abs_diff_eq!(r[1], 0.0, epsilon = 1e-12);
        // k == N is allowed for external queries.
        let r = knn_radii_at::<1>(data.view(), 3, Some(targets.view()));
        assert_abs_diff_eq!(r[0], 2.5, epsilon = 1e-12);
        assert_abs_diff_eq!(r[1], 3.0, epsilon = 1e-12);
    }

    #[test]
    #[should_panic]
    fn knn_radii_at_panics_when_k_exceeds_data_size() {
        let data: Array2<f64> = array![[0.0], [1.0]];
        let targets: Array2<f64> = array![[0.5]];
        let _ = knn_radii_at::<1>(data.view(), 3, Some(targets.view()));
    }

    #[test]
    fn calculate_common_entropy_components_matches_python_cases() {
        let d: Array2<f64> = array![[1.0], [2.0]];
        let (v_m, rho_k, n, m) = calculate_common_entropy_components::<1>(d.view(), 1);
        assert_abs_diff_eq!(v_m, 2.0, epsilon = 1e-12);
        assert_eq!(n, 2);
        assert_eq!(m, 1);
        assert_eq!(rho_k.len(), 2);
        assert_abs_diff_eq!(rho_k[0], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(rho_k[1], 1.0, epsilon = 1e-12);
        let d: Array2<f64> = array![[1.0], [2.0], [3.0]];
        let (v_m, rho_k, n, m) = calculate_common_entropy_components::<1>(d.view(), 1);
        assert_abs_diff_eq!(v_m, 2.0, epsilon = 1e-12);
        assert_eq!(n, 3);
        assert_eq!(m, 1);
        assert_eq!(rho_k.len(), 3);
        for v in rho_k {
            assert_abs_diff_eq!(v, 1.0, epsilon = 1e-12);
        }
        let d: Array2<f64> = array![[1.0], [2.0], [3.0]];
        let (_v_m, rho_k, _n, _m) = calculate_common_entropy_components::<1>(d.view(), 2);
        assert_abs_diff_eq!(rho_k[0], 2.0, epsilon = 1e-12);
        assert_abs_diff_eq!(rho_k[1], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(rho_k[2], 2.0, epsilon = 1e-12);
        let d: Array2<f64> = array![[1.0, 2.0], [2.0, 3.0]];
        let (v_m, rho_k, n, m) = calculate_common_entropy_components::<2>(d.view(), 1);
        assert_abs_diff_eq!(v_m, std::f64::consts::PI, epsilon = 1e-12);
        assert_eq!(n, 2);
        assert_eq!(m, 2);
        assert_eq!(rho_k.len(), 2);
        let s2 = 2.0_f64.sqrt();
        assert_abs_diff_eq!(rho_k[0], s2, epsilon = 1e-12);
        assert_abs_diff_eq!(rho_k[1], s2, epsilon = 1e-12);
        let d: Array2<f64> = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let (v_m, rho_k, n, m) = calculate_common_entropy_components::<3>(d.view(), 1);
        assert_abs_diff_eq!(v_m, 4.0 * std::f64::consts::PI / 3.0, epsilon = 1e-6);
        assert_eq!(n, 2);
        assert_eq!(m, 3);
        assert_eq!(rho_k.len(), 2);
        let s3 = 3.0_f64.sqrt();
        assert_abs_diff_eq!(rho_k[0], s3, epsilon = 1e-12);
        assert_abs_diff_eq!(rho_k[1], s3, epsilon = 1e-12);
    }

    #[test]
    fn calculate_common_entropy_components_panics_when_k_too_large() {
        // A single point has no "other" neighbour, so even k = 1 is too large.
        let d: Array2<f64> = array![[0.0]];
        assert!(
            std::panic::catch_unwind(|| {
                let _ = calculate_common_entropy_components::<1>(d.view(), 1);
            })
            .is_err()
        );
        let d: Array2<f64> = array![[1.0], [2.0], [3.0]];
        assert!(
            std::panic::catch_unwind(|| {
                let _ = calculate_common_entropy_components::<1>(d.view(), 3);
            })
            .is_err()
        );
        assert!(
            std::panic::catch_unwind(|| {
                let _ = calculate_common_entropy_components::<1>(d.view(), 4);
            })
            .is_err()
        );
        let d: Array2<f64> = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        assert!(
            std::panic::catch_unwind(|| {
                let _ = calculate_common_entropy_components::<3>(d.view(), 3);
            })
            .is_err()
        );
    }
}