#![cfg(feature = "machine_learning")]
use ndarray::{Array2, arr2};
use ndarray_rand::rand::rngs::StdRng;
use ndarray_rand::rand::{Rng, SeedableRng};
use rustyml::machine_learning::meanshift::{MeanShift, estimate_bandwidth};
/// Builds the 300 x 2 fixture shared by the tests in this file.
///
/// Rows 0..100, 100..200 and 200..300 form three groups. Both coordinates of a
/// row are sampled with `random_range` from that group's inclusive interval
/// (`-20..=-10`, `0..=10` and `20..=30`). The generator is seeded with 42.
fn create_test_data() -> Array2<f64> {
    const POINTS_PER_GROUP: usize = 100;
    // Inclusive (low, high) bounds applied to both coordinates of each group.
    const GROUP_BOUNDS: [(f64, f64); 3] = [(-20.0, -10.0), (0.0, 10.0), (20.0, 30.0)];

    let mut rng = StdRng::seed_from_u64(42);
    let mut points = Array2::zeros((POINTS_PER_GROUP * GROUP_BOUNDS.len(), 2));

    // Rows are filled in order, column 0 before column 1, so the sequence of
    // draws taken from the seeded generator stays fixed.
    for (group, &(low, high)) in GROUP_BOUNDS.iter().enumerate() {
        let first_row = group * POINTS_PER_GROUP;
        for row in first_row..first_row + POINTS_PER_GROUP {
            points[[row, 0]] = rng.random_range(low..=high);
            points[[row, 1]] = rng.random_range(low..=high);
        }
    }
    points
}
/// Checks the hyperparameters reported by `MeanShift::default()`.
#[test]
fn test_meanshift_default() {
    let ms = MeanShift::default();
    assert_eq!(ms.get_bandwidth(), 1.0);
    assert_eq!(ms.get_max_iterations(), 300);
    assert_eq!(ms.get_tolerance(), 1e-3);
    // Boolean getters are asserted directly rather than compared to a literal.
    assert!(
        !ms.get_bin_seeding(),
        "bin seeding should be disabled by default"
    );
    assert!(
        ms.get_cluster_all(),
        "cluster_all should be enabled by default"
    );
}
/// Checks that every value passed to `MeanShift::new` is reported back by the
/// matching getter.
#[test]
fn test_meanshift_new() {
    let ms = MeanShift::new(2.0, Some(200), Some(1e-4), Some(true), Some(true)).unwrap();
    assert_eq!(ms.get_bandwidth(), 2.0);
    assert_eq!(ms.get_max_iterations(), 200);
    assert_eq!(ms.get_tolerance(), 1e-4);
    // Boolean getters are asserted directly rather than compared to a literal.
    assert!(
        ms.get_bin_seeding(),
        "bin seeding should be enabled when constructed with Some(true)"
    );
    assert!(
        ms.get_cluster_all(),
        "cluster_all should be enabled when constructed with Some(true)"
    );
}
/// A model that has never been fitted must report `None` from each of its
/// fit-result getters.
#[test]
fn test_meanshift_getters_before_fit() {
    let unfitted = MeanShift::default();

    assert!(unfitted.get_cluster_centers().is_none());
    assert!(unfitted.get_labels().is_none());
    assert!(unfitted.get_actual_iterations().is_none());
    assert!(unfitted.get_n_samples_per_center().is_none());
}
/// Fits a model on the shared fixture and checks that the fit-result getters
/// are populated with correctly shaped values.
#[test]
fn test_meanshift_fit() {
    let samples = create_test_data();
    let mut model = MeanShift::new(2.0, None, None, None, Some(true)).unwrap();
    model.fit(&samples.view()).unwrap();

    // Every fit-result getter must now return a value.
    assert!(model.get_cluster_centers().is_some());
    assert!(model.get_labels().is_some());
    assert!(model.get_actual_iterations().is_some());
    assert!(model.get_n_samples_per_center().is_some());

    // Centers have one column per input feature.
    let fitted_centers = model
        .get_cluster_centers()
        .expect("Cluster centers should be available after fitting");
    assert_eq!(fitted_centers.dim().1, 2);

    // There is exactly one label per input row.
    let fitted_labels = model
        .get_labels()
        .expect("Labels should be available after fitting");
    assert_eq!(fitted_labels.len(), samples.dim().0);
}
/// Fits on the shared fixture, then predicts one probe point per fixture
/// group and checks that the three predictions are pairwise different.
#[test]
fn test_meanshift_predict() {
    let samples = create_test_data();
    let mut model = MeanShift::new(2.0, None, None, None, Some(true)).unwrap();
    model.fit(&samples.view()).unwrap();

    // One probe at the midpoint of each interval used by `create_test_data`.
    let probes = arr2(&[[-15.0, -15.0], [5.0, 5.0], [25.0, 25.0]]);
    let predictions = model.predict(probes.view()).unwrap();
    assert_eq!(predictions.len(), 3);

    let low = &predictions[0];
    let middle = &predictions[1];
    let high = &predictions[2];
    assert_ne!(low, middle);
    assert_ne!(middle, high);
    assert_ne!(low, high);
}
/// Runs `fit_predict` on the shared fixture and checks that it yields one
/// label per row and at least three distinct label values.
#[test]
fn test_meanshift_fit_predict() {
    let samples = create_test_data();
    let mut model = MeanShift::new(2.0, None, None, None, Some(true)).unwrap();

    let assigned = model.fit_predict(&samples.view()).unwrap();
    assert_eq!(assigned.len(), samples.dim().0);

    // Collapse the labels into a set to count the distinct values.
    let distinct: std::collections::HashSet<_> = assigned.iter().copied().collect();
    assert!(distinct.len() >= 3);
}
/// Fits the shared fixture once with bin seeding off and once with it on, and
/// checks that both runs produce cluster centers.
#[test]
fn test_bin_seeding() {
    let samples = create_test_data();

    // Builds and fits a model that differs only in the bin-seeding flag.
    let fit_with = |bin_seeding: bool| {
        let mut model = MeanShift::new(2.0, None, None, Some(bin_seeding), None).unwrap();
        model.fit(&samples.view()).unwrap();
        model
    };

    let unseeded = fit_with(false);
    let seeded = fit_with(true);

    assert!(unseeded.get_cluster_centers().is_some());
    assert!(seeded.get_cluster_centers().is_some());
}
/// Calls `estimate_bandwidth` with no optional arguments and with each
/// optional argument set on its own, checking that every result is positive
/// and that repeating the call with the same last argument gives an equal
/// result.
#[test]
fn test_estimate_bandwidth() {
    let samples = create_test_data();
    let view = samples.view();

    // NOTE(review): the three optional arguments look like quantile, sample
    // count and RNG seed — confirm against the `estimate_bandwidth` docs.
    let all_defaults = estimate_bandwidth(&view, None, None, None).unwrap();
    assert!(all_defaults > 0.0);

    let second_arg_set = estimate_bandwidth(&view, Some(0.3), None, None).unwrap();
    assert!(second_arg_set > 0.0);

    let third_arg_set = estimate_bandwidth(&view, None, Some(50), None).unwrap();
    assert!(third_arg_set > 0.0);

    let fourth_arg_set = estimate_bandwidth(&view, None, None, Some(42)).unwrap();
    assert!(fourth_arg_set > 0.0);

    // The same inputs, including the last argument, must give the same value.
    let fourth_arg_again = estimate_bandwidth(&view, None, None, Some(42)).unwrap();
    assert_eq!(fourth_arg_set, fourth_arg_again);
}
/// Fits the shared fixture with `cluster_all` off and on, and checks that both
/// runs return the same number of labels.
#[test]
fn test_cluster_all_parameter() {
    let samples = create_test_data();

    let mut partial = MeanShift::new(1.0, None, None, None, Some(false)).unwrap();
    partial.fit(&samples.view()).unwrap();
    let partial_labels = partial
        .get_labels()
        .expect("Labels should be available after fitting");

    let mut complete = MeanShift::new(1.0, None, None, None, Some(true)).unwrap();
    complete.fit(&samples.view()).unwrap();
    let complete_labels = complete
        .get_labels()
        .expect("Labels should be available after fitting");

    assert_eq!(partial_labels.len(), complete_labels.len());
}
/// Fits with the iteration limit set to 1 and checks that the reported number
/// of iterations is exactly 1.
#[test]
fn test_fit_with_max_iterations() {
    let samples = create_test_data();
    let mut model = MeanShift::new(2.0, Some(1), None, None, None).unwrap();
    model.fit(&samples.view()).unwrap();

    let iterations_run = model.get_actual_iterations().unwrap();
    assert_eq!(iterations_run, 1);
}