use scirs2_core::ndarray::{ArrayBase, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use super::{calculate_distance, group_by_labels, pairwise_distances};
use crate::error::{MetricsError, Result};
/// Computes the Dunn index of a clustering from point-to-point Euclidean distances.
///
/// The result is the smallest distance between two samples that belong to
/// different clusters, divided by the largest distance between two samples
/// that share a cluster.
///
/// # Arguments
///
/// * `x` - Sample matrix; the length of its first axis is taken as the number of samples.
/// * `labels` - One cluster label per sample; its length must equal the number of rows of `x`.
///
/// # Returns
///
/// The ratio described above, or `F::infinity()` when the largest
/// within-cluster distance is exactly zero.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when the number of labels differs
/// from the number of samples, or when the grouping yields at most one
/// cluster. Errors raised by `group_by_labels` and `pairwise_distances` are
/// propagated unchanged.
#[allow(dead_code)]
pub fn dunn_index_enhanced<F, S1, S2, D>(
    x: &ArrayBase<S1, Ix2>,
    labels: &ArrayBase<S2, D>,
) -> Result<F>
where
    F: Float + NumCast + std::fmt::Debug,
    S1: Data<Elem = F>,
    S2: Data<Elem = usize>,
    D: scirs2_core::ndarray::Dimension,
{
    let sample_count = x.shape()[0];
    let label_count = labels.len();
    if sample_count != label_count {
        return Err(MetricsError::InvalidInput(format!(
            "Number of samples in x ({}) does not match number of labels ({})",
            sample_count, label_count
        )));
    }

    // NOTE(review): presumably associates each label with the row indices that
    // carry it — confirm against `group_by_labels`.
    let clusters = group_by_labels(x, labels)?;
    if clusters.len() < 2 {
        return Err(MetricsError::InvalidInput(
            "Dunn index is only defined for more than one cluster".to_string(),
        ));
    }

    // Every distance needed below is looked up in this matrix by sample index.
    let distances = pairwise_distances::<F, S1>(x, "euclidean")?;

    // Separation: the closest pair of samples drawn from two different clusters.
    let mut nearest_between = F::infinity();
    for (label_a, members_a) in clusters.iter() {
        for (label_b, members_b) in clusters.iter() {
            if label_a == label_b {
                continue;
            }
            for &row in members_a {
                for &col in members_b {
                    nearest_between = nearest_between.min(distances[[row, col]]);
                }
            }
        }
    }

    // Diameter: the farthest pair of samples sharing a cluster. Each unordered
    // pair inside a cluster is visited once.
    let widest_within = clusters
        .iter()
        .flat_map(|(_, members)| {
            members.iter().enumerate().flat_map(move |(pos, &row)| {
                members.iter().skip(pos + 1).map(move |&col| (row, col))
            })
        })
        .fold(F::zero(), |widest, (row, col)| {
            widest.max(distances[[row, col]])
        });

    // A zero diameter would make the ratio a division by zero; report infinity.
    if widest_within == F::zero() {
        return Ok(F::infinity());
    }
    Ok(nearest_between / widest_within)
}
/// Collects one inertia value per candidate cluster count for an elbow analysis.
///
/// `kmeans_fn` is invoked once for every `k` in `k_range`, in ascending order,
/// and its return values are gathered in that same order.
///
/// # Arguments
///
/// * `x` - Data matrix handed unchanged to `kmeans_fn`.
/// * `k_range` - Inclusive range of cluster counts to evaluate; must start at 1 or more.
/// * `kmeans_fn` - Callback that receives `x` and a cluster count and returns a value of type `F`.
///
/// # Returns
///
/// A vector holding the callback's result for each `k` in `k_range`.
///
/// # Errors
///
/// Returns [`MetricsError::InvalidInput`] when `k_range` starts below 1, when
/// its end is smaller than its start, or when either dimension of `x` is zero.
#[allow(dead_code)]
pub fn elbow_method<F, S>(
    x: &ArrayBase<S, Ix2>,
    k_range: std::ops::RangeInclusive<usize>,
    kmeans_fn: impl Fn(&ArrayBase<S, Ix2>, usize) -> F,
) -> Result<Vec<F>>
where
    F: Float + NumCast + std::fmt::Debug,
    S: Data<Elem = F>,
{
    let (k_min, k_max) = (*k_range.start(), *k_range.end());

    if k_min == 0 {
        return Err(MetricsError::InvalidInput(
            "k_range must start at 1 or greater".to_string(),
        ));
    }
    if k_min > k_max {
        return Err(MetricsError::InvalidInput(
            "k_range end must be greater than or equal to start".to_string(),
        ));
    }

    let (rows, cols) = x.dim();
    if rows == 0 || cols == 0 {
        return Err(MetricsError::InvalidInput(
            "Input data is empty".to_string(),
        ));
    }

    // The range is inclusive, hence the `+ 1` when sizing the output.
    let mut inertias = Vec::with_capacity(k_max - k_min + 1);
    inertias.extend(k_range.map(|k| kmeans_fn(x, k)));
    Ok(inertias)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::clustering::dunn_index;
    use scirs2_core::ndarray::{array, Array2};

    /// Well-separated clusters must score above 0.5, and a layout whose two
    /// clusters nearly touch must score strictly lower.
    ///
    /// This exercises `crate::clustering::dunn_index`.
    #[test]
    fn test_dunn_index() {
        // Both layouts use the same assignment: first three rows vs. last three.
        let labels = array![0, 0, 0, 1, 1, 1];

        let separated_coords = vec![1.0, 2.0, 1.5, 1.8, 1.2, 2.2, 5.0, 6.0, 5.2, 5.8, 5.5, 6.2];
        let separated =
            Array2::from_shape_vec((6, 2), separated_coords).expect("Operation failed");
        let separated_score = dunn_index(&separated, &labels).expect("Operation failed");
        assert!(separated_score > 0.5);

        // Rows 2 and 3 sit at (3.2, 3.2) and (3.0, 3.0) but carry different labels.
        let touching_coords = vec![1.0, 2.0, 1.5, 1.8, 3.2, 3.2, 3.0, 3.0, 5.2, 5.8, 5.5, 6.2];
        let touching = Array2::from_shape_vec((6, 2), touching_coords).expect("Operation failed");
        let touching_score = dunn_index(&touching, &labels).expect("Operation failed");
        assert!(touching_score < separated_score);
    }

    /// Feeds `elbow_method` a synthetic inertia curve and checks both the
    /// number of results and the diminishing-returns shape of the curve.
    #[test]
    fn test_elbow_method() {
        let x = Array2::<f64>::zeros((10, 2));

        // Stand-in for k-means: ignores the data and yields a value that
        // shrinks quickly up to k = 3 and only slowly afterwards.
        let kmeans_mock = |_: &Array2<f64>, k: usize| {
            let base = 100.0;
            match k {
                1 => base,
                2 => base / 2.0,
                3 => base / 3.0,
                4 => base / 3.2,
                5 => base / 3.4,
                _ => base / (3.5 + (k as f64 - 5.0) * 0.1),
            }
        };

        let inertias = elbow_method(&x, 1..=6, kmeans_mock).expect("Operation failed");
        assert_eq!(inertias.len(), 6);

        // The curve must be strictly decreasing.
        for pair in inertias.windows(2) {
            assert!(pair[1] < pair[0]);
        }

        // Successive improvements must shrink over the first three steps.
        let drops: Vec<f64> = inertias.windows(2).map(|pair| pair[0] - pair[1]).collect();
        assert!(drops[0] > drops[1]);
        assert!(drops[1] > drops[2]);
    }
}