#![cfg(feature = "machine_learning")]
use ndarray::{Array1, Array2};
use ndarray_rand::rand::rngs::StdRng;
use ndarray_rand::rand::{Rng, SeedableRng};
use rustyml::machine_learning::kmeans::KMeans;
/// Builds a deterministic 20x2 fixture made of two well-separated blobs.
///
/// Rows `0..10` are drawn from `[-30, -20)` on both axes and rows `10..20`
/// from `[20, 30)`. Each row uses its own `StdRng` seeded with the row index,
/// drawing the first coordinate and then the second.
fn create_test_data() -> Array2<f64> {
    let mut data = Array2::<f64>::zeros((20, 2));
    for (row_idx, mut row) in data.outer_iter_mut().enumerate() {
        let (low, high) = if row_idx < 10 {
            (-30.0, -20.0)
        } else {
            (20.0, 30.0)
        };
        let mut rng = StdRng::seed_from_u64(row_idx as u64);
        row[0] = rng.random_range(low..high);
        row[1] = rng.random_range(low..high);
    }
    data
}
#[test]
fn test_new_and_default() {
    // A freshly constructed model must not expose any fitted state.
    fn assert_unfitted(model: &KMeans) {
        assert!(model.get_centroids().is_none());
        assert!(model.get_labels().is_none());
        assert!(model.get_inertia().is_none());
        assert!(model.get_actual_iterations().is_none());
    }

    let explicit = KMeans::new(3, 100, 0.0001, Some(42)).unwrap();
    assert_unfitted(&explicit);

    let defaulted = KMeans::default();
    assert_unfitted(&defaulted);
}
#[test]
fn test_fit() {
    let mut kmeans = KMeans::new(2, 100, 0.0001, Some(42)).unwrap();
    let data = create_test_data();
    kmeans.fit(&data.view()).unwrap();

    let centroids = kmeans
        .get_centroids()
        .expect("Centroids should be available after fitting");
    // One centroid per requested cluster, each with the input's dimensionality.
    assert_eq!(centroids.shape(), &[2, data.ncols()]);

    let labels = kmeans
        .get_labels()
        .expect("Labels should be available after fitting");
    // Fitting must assign a label to every training sample.
    assert_eq!(labels.len(), data.nrows());

    assert!(kmeans.get_inertia().is_some());
    assert!(kmeans.get_actual_iterations().is_some());
}
#[test]
fn test_predict() {
    let data = create_test_data();
    let mut model = KMeans::new(2, 1000, 1e-7, Some(42)).unwrap();
    model.fit(&data.view()).unwrap();

    let labels = model.predict(&data.view()).unwrap();
    assert_eq!(labels.len(), 20);

    // Cluster ids are arbitrary: anchor on the id given to the first sample
    // and expect the other blob to carry the complementary id (k = 2).
    let anchor = labels[0];
    let first_blob = Array1::from_elem(10, anchor);
    let second_blob = Array1::from_elem(10, 1 - anchor);

    // Allow up to two stray assignments per blob.
    let first_hits = (0..10).filter(|&i| labels[i] == first_blob[i]).count();
    assert!(first_hits >= 8);

    let second_hits = (10..20)
        .filter(|&i| labels[i] == second_blob[i - 10])
        .count();
    assert!(second_hits >= 8);
}
#[test]
fn test_fit_predict() {
    let data = create_test_data();
    let mut model = KMeans::new(2, 100, 0.0001, Some(42)).unwrap();
    let predicted = model.fit_predict(&data.view()).unwrap();
    assert_eq!(predicted.len(), 20);

    // fit_predict must leave the model in a fully fitted state.
    assert!(model.get_centroids().is_some());
    assert!(model.get_labels().is_some());
    assert!(model.get_inertia().is_some());
    assert!(model.get_actual_iterations().is_some());

    // The returned labels must match the ones the model stored.
    let Some(stored) = model.get_labels() else {
        panic!("Labels should be available after fitting");
    };
    assert_eq!(predicted, stored);
}
#[test]
fn test_getters() {
    let data = create_test_data();
    let mut model = KMeans::new(2, 100, 0.0001, Some(42)).unwrap();

    // Before fitting, every getter reports that nothing has been computed.
    assert!(model.get_centroids().is_none());
    assert!(model.get_labels().is_none());
    assert!(model.get_inertia().is_none());
    assert!(model.get_actual_iterations().is_none());

    model.fit(&data.view()).unwrap();

    // After fitting, every getter exposes a value.
    assert!(model.get_centroids().is_some());
    assert!(model.get_labels().is_some());
    assert!(model.get_inertia().is_some());
    assert!(model.get_actual_iterations().is_some());
}
#[test]
fn test_different_cluster_counts() {
    let data = create_test_data();

    // The centroid matrix has one row per cluster and one column per feature.
    for n_clusters in [1, 3] {
        let mut model = KMeans::new(n_clusters, 100, 0.0001, Some(42)).unwrap();
        model.fit(&data.view()).unwrap();

        let Some(centroids) = model.get_centroids() else {
            panic!("Centroids should be available after fitting");
        };
        assert_eq!(centroids.shape(), &[n_clusters, 2]);
    }
}