use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
pub fn dunn_index<F>(data: ArrayView2<F>, labels: ArrayView1<i32>) -> Result<F>
where
F: Float + FromPrimitive + Debug + PartialOrd + 'static,
{
if data.shape()[0] != labels.shape()[0] {
return Err(ClusteringError::InvalidInput(
"Data and labels must have the same number of samples".to_string(),
));
}
let n_samples = data.shape()[0];
if n_samples < 2 {
return Ok(F::zero());
}
let mut unique_labels = Vec::new();
for &label in labels.iter() {
if label >= 0 && !unique_labels.contains(&label) {
unique_labels.push(label);
}
}
if unique_labels.len() < 2 {
return Ok(F::zero());
}
let mut min_inter_cluster = F::infinity();
for i in 0..n_samples {
for j in (i + 1)..n_samples {
if labels[i] >= 0 && labels[j] >= 0 && labels[i] != labels[j] {
let distance = euclidean_distance(data.row(i), data.row(j));
if distance < min_inter_cluster {
min_inter_cluster = distance;
}
}
}
}
let mut max_intra_cluster = F::zero();
for &cluster_label in &unique_labels {
let cluster_indices: Vec<usize> = labels
.iter()
.enumerate()
.filter_map(|(i, &label)| {
if label == cluster_label {
Some(i)
} else {
None
}
})
.collect();
for i in 0..cluster_indices.len() {
for j in (i + 1)..cluster_indices.len() {
let distance =
euclidean_distance(data.row(cluster_indices[i]), data.row(cluster_indices[j]));
if distance > max_intra_cluster {
max_intra_cluster = distance;
}
}
}
}
if max_intra_cluster == F::zero() {
return Ok(F::infinity());
}
Ok(min_inter_cluster / max_intra_cluster)
}
pub fn bic_score<F>(data: ArrayView2<F>, labels: ArrayView1<i32>) -> Result<F>
where
F: Float + FromPrimitive + Debug + PartialOrd + ScalarOperand + 'static,
{
if data.shape()[0] != labels.shape()[0] {
return Err(ClusteringError::InvalidInput(
"Data and labels must have the same number of samples".to_string(),
));
}
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
if n_samples < 2 {
return Ok(F::zero());
}
let mut unique_labels = Vec::new();
for &label in labels.iter() {
if label >= 0 && !unique_labels.contains(&label) {
unique_labels.push(label);
}
}
let k = unique_labels.len();
if k < 1 {
return Ok(F::infinity());
}
let mut wcss = F::zero();
for &cluster_label in &unique_labels {
let cluster_indices: Vec<usize> = labels
.iter()
.enumerate()
.filter_map(|(i, &label)| {
if label == cluster_label {
Some(i)
} else {
None
}
})
.collect();
if cluster_indices.is_empty() {
continue;
}
let mut centroid = Array1::zeros(n_features);
for &idx in &cluster_indices {
centroid = ¢roid + &data.row(idx);
}
centroid = centroid / F::from_usize(cluster_indices.len()).expect("Operation failed");
for &idx in &cluster_indices {
let diff = &data.row(idx) - ¢roid;
wcss = wcss + diff.dot(&diff);
}
}
let n_params = k * (n_features + 1) - 1; let log_likelihood = -F::from_usize(n_samples).expect("Operation failed") * wcss.ln()
/ F::from_f64(2.0).expect("Operation failed");
let penalty = F::from_usize(n_params).expect("Operation failed")
* F::from_usize(n_samples).expect("Operation failed").ln()
/ F::from_f64(2.0).expect("Operation failed");
Ok(-log_likelihood + penalty)
}
pub fn aic_score<F>(data: ArrayView2<F>, labels: ArrayView1<i32>) -> Result<F>
where
F: Float + FromPrimitive + Debug + PartialOrd + ScalarOperand + 'static,
{
if data.shape()[0] != labels.shape()[0] {
return Err(ClusteringError::InvalidInput(
"Data and labels must have the same number of samples".to_string(),
));
}
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
if n_samples < 2 {
return Ok(F::zero());
}
let mut unique_labels = Vec::new();
for &label in labels.iter() {
if label >= 0 && !unique_labels.contains(&label) {
unique_labels.push(label);
}
}
let k = unique_labels.len();
if k < 1 {
return Ok(F::infinity());
}
let mut wcss = F::zero();
for &cluster_label in &unique_labels {
let cluster_indices: Vec<usize> = labels
.iter()
.enumerate()
.filter_map(|(i, &label)| {
if label == cluster_label {
Some(i)
} else {
None
}
})
.collect();
if cluster_indices.is_empty() {
continue;
}
let mut centroid = Array1::zeros(n_features);
for &idx in &cluster_indices {
centroid = ¢roid + &data.row(idx);
}
centroid = centroid / F::from_usize(cluster_indices.len()).expect("Operation failed");
for &idx in &cluster_indices {
let diff = &data.row(idx) - ¢roid;
wcss = wcss + diff.dot(&diff);
}
}
let n_params = k * (n_features + 1) - 1; let log_likelihood = -F::from_usize(n_samples).expect("Operation failed") * wcss.ln()
/ F::from_f64(2.0).expect("Operation failed");
Ok(
-F::from_f64(2.0).expect("Operation failed") * log_likelihood
+ F::from_f64(2.0).expect("Operation failed")
* F::from_usize(n_params).expect("Operation failed"),
)
}
fn euclidean_distance<F>(point1: ArrayView1<F>, point2: ArrayView1<F>) -> F
where
F: Float + FromPrimitive + Debug + 'static,
{
let diff = &point1 - &point2;
diff.dot(&diff).sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
#[test]
fn test_dunn_index() {
let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 1.0, 10.0, 10.0, 11.0, 11.0])
.expect("Operation failed");
let labels = Array1::from_vec(vec![0, 0, 1, 1]);
let dunn = dunn_index(data.view(), labels.view()).expect("Operation failed");
assert!(dunn > 0.0);
}
#[test]
fn test_bic_score() {
let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 1.0, 10.0, 10.0, 11.0, 11.0])
.expect("Operation failed");
let labels = Array1::from_vec(vec![0, 0, 1, 1]);
let bic = bic_score(data.view(), labels.view()).expect("Operation failed");
assert!(bic.is_finite());
}
#[test]
fn test_aic_score() {
let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 1.0, 10.0, 10.0, 11.0, 11.0])
.expect("Operation failed");
let labels = Array1::from_vec(vec![0, 0, 1, 1]);
let aic = aic_score(data.view(), labels.view()).expect("Operation failed");
assert!(aic.is_finite());
}
#[test]
fn test_euclidean_distance() {
let point1 = Array1::from_vec(vec![0.0, 0.0]);
let point2 = Array1::from_vec(vec![3.0, 4.0]);
let distance = euclidean_distance(point1.view(), point2.view());
assert!((distance - 5.0).abs() < 1e-10);
}
}