#[cfg(test)]
mod test {
    //! Unit tests for vector quantization (`whiten`, `kmeans2`) and the
    //! silhouette clustering metric.

    use crate::metrics::silhouette_score;
    use crate::vq::{kmeans2, whiten, MinitMethod, MissingMethod};
    use scirs2_core::ndarray::{array, Array2};

    /// `whiten` must preserve the input shape, and every output column must
    /// have mean ~0 and sample variance (ddof = 1) ~1.
    #[test]
    fn test_whiten() {
        let data: Array2<f64> =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.5, 2.5, 0.5, 1.5, 2.0, 3.0])
                .expect("4x2 shape matches the 8 supplied elements");
        let whitened = whiten(&data).expect("whiten failed on non-degenerate input");
        assert_eq!(whitened.shape(), data.shape(), "whiten must preserve shape");

        let n_features = whitened.shape()[1];
        for j in 0..n_features {
            let column = whitened.column(j);
            let mean: f64 = column.mean().expect("column is non-empty");
            // ddof = 1: the expectation here is unit *sample* variance.
            let var: f64 = column.var(1.0);
            assert!(
                mean.abs() < 1e-6,
                "Mean should be close to 0 (column {j}, got {mean})"
            );
            assert!(
                (var - 1.0).abs() < 1e-6,
                "Variance should be close to 1 (column {j}, got {var})"
            );
        }
    }

    /// Every initialization method must yield `k` centroids with one
    /// coordinate per feature, one label per sample, and labels in `0..k`.
    #[test]
    fn test_kmeans2_all_init_methods() {
        // 20 points on a line: (0.0, 0.1), (0.2, 0.3), ..., (3.8, 3.9).
        let data = Array2::from_shape_vec((20, 2), (0..40).map(|i| i as f64 / 10.0).collect())
            .expect("20x2 shape matches the 40 generated elements");
        let k: usize = 3;

        let init_methods = [
            ("Random", MinitMethod::Random),
            ("Points", MinitMethod::Points),
            ("PlusPlus", MinitMethod::PlusPlus),
        ];
        for (name, method) in init_methods {
            let (centroids, labels) = kmeans2(
                data.view(),
                k,
                Some(10),
                Some(1e-4),
                Some(method),
                Some(MissingMethod::Warn),
                Some(true),
                Some(42),
            )
            .unwrap_or_else(|e| panic!("kmeans2 failed with init method {name}: {e}"));

            assert_eq!(centroids.shape()[0], k, "centroid count ({name})");
            assert_eq!(centroids.shape()[1], 2, "centroid dimensionality ({name})");
            assert_eq!(labels.len(), 20, "one label per sample ({name})");
            assert!(
                labels.iter().all(|&label| label < k),
                "all labels must be cluster indices in 0..{k} ({name})"
            );
        }
    }

    /// Asking for more clusters than there are natural groups can leave
    /// clusters empty. With `MissingMethod::Warn` the call must still succeed.
    /// With `MissingMethod::Raise` either an error or a success is acceptable,
    /// but a successful result must still be well-formed.
    #[test]
    fn test_kmeans2_empty_cluster_handling() {
        // Two tight groups of three points each, far apart from one another.
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 10.0, 10.0, 10.1, 10.1, 10.2, 10.2,
            ],
        )
        .expect("6x2 shape matches the 12 supplied elements");

        let (centroids, labels) = kmeans2(
            data.view(),
            3,
            Some(5),
            Some(1e-4),
            Some(MinitMethod::Random),
            Some(MissingMethod::Warn),
            Some(true),
            Some(123),
        )
        .unwrap_or_else(|e| panic!("kmeans2 with MissingMethod::Warn must not fail: {e}"));
        assert_eq!(centroids.shape(), &[3, 2]);
        assert_eq!(labels.len(), 6);
        assert!(labels.iter().all(|&label| label < 3));

        let result2 = kmeans2(
            data.view(),
            4,
            Some(5),
            Some(1e-4),
            Some(MinitMethod::Random),
            Some(MissingMethod::Raise),
            Some(true),
            Some(456),
        );
        match result2 {
            // No cluster ended up empty: the output must be complete and consistent.
            Ok((centroids, labels)) => {
                assert_eq!(centroids.shape(), &[4, 2]);
                assert_eq!(labels.len(), 6);
                assert!(labels.iter().all(|&label| label < 4));
            }
            // Reporting an empty cluster is what `Raise` is for, so an error
            // is an accepted outcome here.
            Err(e) => println!("Failed as expected: {e}"),
        }
    }

    /// Two compact, well-separated clusters must score a high silhouette.
    #[test]
    fn test_silhouette_score() {
        let data = array![
            [1.0, 1.0],
            [1.5, 1.5],
            [1.2, 1.3],
            [10.0, 10.0],
            [10.5, 10.5],
            [10.2, 10.3],
        ];
        let labels = array![0, 0, 0, 1, 1, 1];
        let score = silhouette_score(data.view(), labels.view())
            .expect("silhouette_score failed on valid two-cluster input");
        assert!(
            score > 0.8,
            "Silhouette score should be high for well-separated clusters (got {score})"
        );
    }

    /// A point labelled `-1` must not make `silhouette_score` fail.
    /// Presumably `-1` marks noise and is excluded from scoring — confirm
    /// against the `silhouette_score` implementation. The two real clusters
    /// should still give a positive score.
    #[test]
    fn test_silhouette_with_noise() {
        let data = array![
            [1.0, 1.0],
            [1.5, 1.5],
            [10.0, 10.0],
            [10.5, 10.5],
            [50.0, 50.0],
        ];
        let labels = array![0, 0, 1, 1, -1];
        let score = silhouette_score(data.view(), labels.view())
            .expect("silhouette_score failed on input containing a -1 label");
        assert!(
            score > 0.0,
            "Silhouette score should be positive despite the -1 point (got {score})"
        );
    }
}