use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::{Rng, RngExt, SeedableRng};
use std::fmt::Debug;
use super::{euclidean_distance, vq};
use crate::error::{ClusteringError, Result};
/// Tunable settings shared by the k-means routines in this module.
#[derive(Debug, Clone)]
pub struct KMeansOptions<F: Float> {
/// Upper bound on the number of assignment/update iterations of a single run.
pub max_iter: usize,
/// Convergence threshold: a run stops once the summed centroid movement of
/// an iteration is less than or equal to this value.
pub tol: F,
/// Optional seed forwarded to the centroid initialization routine.
pub random_seed: Option<u64>,
/// Number of independent runs; the run with the lowest inertia is kept.
/// Treated as 1 when `init_method` is `KMeansInit::KMeansParallel`.
pub n_init: usize,
/// Strategy used to pick the starting centroids.
pub init_method: KMeansInit,
}
impl<F: Float + FromPrimitive> Default for KMeansOptions<F> {
fn default() -> Self {
Self {
max_iter: 300,
tol: F::from(1e-4).expect("Failed to convert constant to float"),
random_seed: None,
n_init: 10,
init_method: KMeansInit::KMeansPlusPlus,
}
}
}
/// SciPy-style k-means entry point.
///
/// Runs a single k-means++-initialized clustering of `obs` and returns the
/// centroids together with the total distortion reported by
/// `calculate_distortion` (sum of squared point-to-assigned-centroid distances).
///
/// # Arguments
/// * `obs` - observations, one per row.
/// * `k_or_guess` - number of clusters to form.
/// * `iter` - maximum number of iterations (defaults to 20).
/// * `thresh` - convergence tolerance (defaults to `1e-5`).
/// * `check_finite` - when `true` (the default) the input is rejected if it
///   contains NaN or infinite values; pass `Some(false)` to skip the scan.
/// * `seed` - optional seed for the centroid initialization.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when `obs` is empty, when
/// `k_or_guess` is zero or exceeds the number of rows, or when
/// `check_finite` is enabled and a non-finite value is found. Errors from the
/// underlying clustering run are propagated unchanged.
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn kmeans<F>(
    obs: ArrayView2<F>,
    k_or_guess: usize,
    iter: Option<usize>,
    thresh: Option<F>,
    check_finite: Option<bool>,
    seed: Option<u64>,
) -> Result<(Array2<F>, F)>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum + std::fmt::Display,
{
    let k = k_or_guess;
    let max_iter = iter.unwrap_or(20);
    let tol =
        thresh.unwrap_or_else(|| F::from(1e-5).expect("Failed to convert constant to float"));

    if obs.is_empty() {
        return Err(ClusteringError::InvalidInput(
            "Input data is empty".to_string(),
        ));
    }
    if k == 0 {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters must be greater than 0".to_string(),
        ));
    }
    if k > obs.nrows() {
        return Err(ClusteringError::InvalidInput(format!(
            "Number of clusters ({}) cannot be greater than number of data points ({})",
            k,
            obs.nrows()
        )));
    }
    // Previously the flag was read and then ignored; honour it so NaN/inf
    // input fails loudly instead of silently producing meaningless centroids.
    if check_finite.unwrap_or(true) && obs.iter().any(|value| !value.is_finite()) {
        return Err(ClusteringError::InvalidInput(
            "Input data contains NaN or infinite values".to_string(),
        ));
    }

    let options = KMeansOptions {
        max_iter,
        tol,
        random_seed: seed,
        n_init: 1,
        init_method: KMeansInit::KMeansPlusPlus,
    };
    let (centroids, labels) = kmeans_with_options(obs, k, Some(options))?;
    let distortion = calculate_distortion(obs, centroids.view(), &labels);
    Ok((centroids, distortion))
}
/// Clusters `data` into `k` groups, keeping the best of several restarts.
///
/// Each restart initializes centroids with `opts.init_method`, runs
/// `_kmeans_single`, and the restart with the lowest inertia wins. When
/// `options` is `None` the `Default` configuration is used. The
/// `KMeansParallel` initializer is always run exactly once.
///
/// When a seed is supplied, restart `r` is seeded with `seed + r` (wrapping),
/// so restarts explore different initializations while the overall result
/// stays reproducible. Restart 0 uses the seed unchanged.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when `k` is zero, `data` has no
/// rows, `k` exceeds the number of rows, or the effective number of restarts
/// is zero. Errors from initialization or from a run are propagated.
#[allow(dead_code)]
pub fn kmeans_with_options<F>(
    data: ArrayView2<F>,
    k: usize,
    options: Option<KMeansOptions<F>>,
) -> Result<(Array2<F>, Array1<usize>)>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum,
{
    if k == 0 {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters must be greater than 0".to_string(),
        ));
    }
    let n_samples = data.shape()[0];
    if n_samples == 0 {
        return Err(ClusteringError::InvalidInput(
            "Input data is empty".to_string(),
        ));
    }
    if k > n_samples {
        return Err(ClusteringError::InvalidInput(format!(
            "Number of clusters ({}) cannot be greater than number of data points ({})",
            k, n_samples
        )));
    }

    let opts = options.unwrap_or_default();
    let n_init = if opts.init_method == KMeansInit::KMeansParallel {
        1
    } else {
        opts.n_init
    };
    // Zero restarts used to fall through to a panic on `expect`.
    if n_init == 0 {
        return Err(ClusteringError::InvalidInput(
            "Number of initializations (n_init) must be greater than 0".to_string(),
        ));
    }

    let mut best: Option<(Array2<F>, Array1<usize>, F)> = None;
    for run in 0..n_init {
        // Re-using the identical seed would make every restart repeat run 0.
        let run_seed = opts
            .random_seed
            .map(|seed| seed.wrapping_add(run as u64));
        let init = kmeans_init(data, k, Some(opts.init_method), run_seed)?;
        let candidate = _kmeans_single(data, init.view(), &opts)?;

        // Always keep the first run so a NaN/inf inertia cannot leave `best`
        // empty; afterwards prefer strictly lower (or first non-NaN) inertia.
        let improves = best.as_ref().map_or(true, |(_, _, best_inertia)| {
            candidate.2 < *best_inertia || (best_inertia.is_nan() && !candidate.2.is_nan())
        });
        if improves {
            best = Some(candidate);
        }
    }

    let (centroids, labels, _) =
        best.expect("n_init >= 1 guarantees at least one completed run");
    Ok((centroids, labels))
}
/// Sums, over every row of `data`, the squared `euclidean_distance` between
/// the row and the centroid selected by the row's entry in `labels`.
///
/// Panics if `labels` has fewer entries than `data` has rows or if a label
/// is not a valid row of `centroids`.
#[allow(dead_code)]
fn calculate_distortion<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    labels: &Array1<usize>,
) -> F
where
    F: Float + FromPrimitive + Debug + std::iter::Sum,
{
    data.outer_iter()
        .enumerate()
        .fold(F::zero(), |running_total, (row_idx, observation)| {
            let assigned_centroid = centroids.row(labels[row_idx]);
            running_total + euclidean_distance(observation, assigned_centroid).powi(2)
        })
}
/// Runs one k-means optimization starting from `initcentroids`.
///
/// Each iteration assigns points with `vq`, recomputes every centroid as the
/// mean of its members, and stops when the summed centroid movement is at
/// most `opts.tol`, stops shrinking, or `opts.max_iter` is reached.
///
/// An empty cluster is re-seeded with the point currently farthest from its
/// assigned centroid. Each point is used for at most one re-seed per
/// iteration, so several empty clusters no longer collapse onto the same
/// location.
///
/// Returns `(centroids, labels, inertia)`, where `inertia` is the sum of
/// squared distances between each point and the centroid of its label.
#[allow(dead_code)]
fn _kmeans_single<F>(
    data: ArrayView2<F>,
    initcentroids: ArrayView2<F>,
    opts: &KMeansOptions<F>,
) -> Result<(Array2<F>, Array1<usize>, F)>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    let k = initcentroids.shape()[0];
    let mut centroids = initcentroids.to_owned();
    let mut labels = Array1::zeros(n_samples);
    let mut prev_centroid_diff = F::infinity();

    for _ in 0..opts.max_iter {
        let (new_labels, mut distances) = vq(data, centroids.view())?;
        labels = new_labels;

        // Accumulate per-cluster coordinate sums and member counts.
        let mut newcentroids = Array2::<F>::zeros((k, n_features));
        let mut counts = vec![0usize; k];
        for (point, &cluster) in data.outer_iter().zip(labels.iter()) {
            for (acc, &value) in newcentroids.row_mut(cluster).iter_mut().zip(point.iter()) {
                *acc = *acc + value;
            }
            counts[cluster] += 1;
        }

        for (i, count) in counts.iter_mut().enumerate() {
            if *count == 0 {
                // Re-seed the empty cluster with the worst-fitting point.
                let mut max_dist = F::zero();
                let mut far_idx = 0;
                for (j, &dist) in distances.iter().enumerate() {
                    if dist > max_dist {
                        max_dist = dist;
                        far_idx = j;
                    }
                }
                newcentroids.row_mut(i).assign(&data.row(far_idx));
                *count = 1;
                // Retire the point so another empty cluster picks a different one.
                distances[far_idx] = F::zero();
            } else {
                let members = F::from(*count).expect("Failed to convert to float");
                newcentroids.row_mut(i).mapv_inplace(|sum| sum / members);
            }
        }

        let centroid_diff = centroids
            .outer_iter()
            .zip(newcentroids.outer_iter())
            .fold(F::zero(), |acc, (old, new)| acc + euclidean_distance(old, new));
        centroids = newcentroids;
        if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
            break;
        }
        prev_centroid_diff = centroid_diff;
    }

    let inertia = data
        .outer_iter()
        .zip(labels.iter())
        .fold(F::zero(), |acc, (point, &cluster)| {
            let dist = euclidean_distance(point, centroids.row(cluster));
            acc + dist * dist
        });
    Ok((centroids, labels, inertia))
}
/// Strategy used by `kmeans_init` to choose the starting centroids.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum KMeansInit {
/// Dispatches to `random_init`, which copies `k` distinct rows of the data.
Random,
/// Dispatches to `kmeans_plus_plus`; this is the default variant.
#[default]
KMeansPlusPlus,
/// Dispatches to `kmeans_parallel`.
KMeansParallel,
}
#[allow(dead_code)]
pub fn kmeans_init<F>(
data: ArrayView2<F>,
k: usize,
init_method: Option<KMeansInit>,
random_seed: Option<u64>,
) -> Result<Array2<F>>
where
F: Float + FromPrimitive + Debug + std::iter::Sum,
{
match init_method.unwrap_or_default() {
KMeansInit::Random => random_init(data, k, random_seed),
KMeansInit::KMeansPlusPlus => kmeans_plus_plus(data, k, random_seed),
KMeansInit::KMeansParallel => kmeans_parallel(data, k, random_seed),
}
}
/// Picks `k` distinct rows of `data` uniformly at random as initial centroids.
///
/// `random_seed` is honoured: the same seed yields the same rows (it was
/// previously ignored). With `None`, a seed is drawn from the crate's default
/// generator.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when `k` is zero or larger than
/// the number of rows.
#[allow(dead_code)]
pub fn random_init<F>(data: ArrayView2<F>, k: usize, random_seed: Option<u64>) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    if k == 0 || k > n_samples {
        return Err(ClusteringError::InvalidInput(format!(
            "Number of clusters ({}) must be between 1 and number of samples ({})",
            k, n_samples
        )));
    }

    let mut rng = match random_seed {
        Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
        None => scirs2_core::random::rngs::StdRng::seed_from_u64(
            scirs2_core::random::rng().random::<u64>(),
        ),
    };

    // Partial Fisher-Yates shuffle: after step `i` the first `i + 1` entries
    // are a uniform sample without replacement. Exactly `k` draws, no
    // rejection loop and no linear `contains` scans.
    let mut indices: Vec<usize> = (0..n_samples).collect();
    for i in 0..k {
        let pick = rng.random_range(i..n_samples);
        indices.swap(i, pick);
    }

    let mut centroids = Array2::<F>::zeros((k, n_features));
    // `zip` stops after the `k` centroid rows, i.e. the sampled prefix.
    for (mut row, &idx) in centroids.outer_iter_mut().zip(indices.iter()) {
        row.assign(&data.row(idx));
    }
    Ok(centroids)
}
/// Chooses `k` initial centroids with distance-squared weighted sampling.
///
/// The first centroid is a uniformly random row. Every further centroid is a
/// row drawn with probability proportional to its squared distance from the
/// nearest centroid chosen so far (uniform if all those distances are zero).
///
/// Nearest-centroid distances are maintained incrementally, so each step only
/// measures against the most recently added centroid. The random draws are
/// made in the same order as before, so seeded results are unchanged except
/// where the old code mishandled the boundary cases described in the body.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when `k` is zero or larger than
/// the number of rows.
#[allow(dead_code)]
pub fn kmeans_plus_plus<F>(
    data: ArrayView2<F>,
    k: usize,
    random_seed: Option<u64>,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    if k == 0 || k > n_samples {
        return Err(ClusteringError::InvalidInput(format!(
            "Number of clusters ({}) must be between 1 and number of samples ({})",
            k, n_samples
        )));
    }
    let mut rng = match random_seed {
        Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
        None => scirs2_core::random::rngs::StdRng::seed_from_u64(
            scirs2_core::random::rng().random::<u64>(),
        ),
    };

    let mut centroids = Array2::<F>::zeros((k, n_features));
    let first_idx = rng.random_range(0..n_samples);
    centroids.row_mut(0).assign(&data.row(first_idx));

    // Distance from every sample to its nearest already-chosen centroid.
    let mut min_distances = Array1::from_elem(n_samples, F::infinity());
    for i in 1..k {
        {
            // Only the newest centroid can lower a sample's nearest distance.
            let newest = centroids.row(i - 1);
            for (sample, nearest) in data.outer_iter().zip(min_distances.iter_mut()) {
                let dist = euclidean_distance(sample, newest);
                if dist < *nearest {
                    *nearest = dist;
                }
            }
        }

        let mut weights = min_distances.mapv(|d| d * d);
        let sum_weights = weights.sum();
        if sum_weights > F::zero() {
            weights.mapv_inplace(|w| w / sum_weights);
        } else {
            weights.fill(F::from(1.0 / n_samples as f64).expect("Failed to convert to float"));
        }

        let rand_val = F::from(rng.random_range(0.0..1.0)).expect("Operation failed");

        // Walk the cumulative distribution. `next_idx` tracks the last sample
        // with positive weight so that (a) a zero-weight sample (an existing
        // centroid) is never selected, even when the draw is exactly 0, and
        // (b) if rounding leaves the cumulative sum just below the draw, the
        // fallback is a valid candidate rather than index 0.
        let mut next_idx = 0;
        let mut cumulative = F::zero();
        for (j, &w) in weights.iter().enumerate() {
            if w > F::zero() {
                next_idx = j;
            }
            cumulative = cumulative + w;
            if rand_val < cumulative {
                break;
            }
        }

        centroids.row_mut(i).assign(&data.row(next_idx));
    }
    Ok(centroids)
}
/// Oversampling-based centroid initialization.
///
/// Starting from one random row, up to 8 rounds each add every row
/// independently with probability proportional to its squared distance from
/// the nearest candidate (scaled so that roughly `5 * k` rows are expected per
/// round). The candidate set is then reduced or padded to exactly `k` rows:
///
/// * more than `k` candidates: they are clustered down to `k` centroids with
///   `_weighted_kmeans_single`. Every candidate carries unit weight.
///   NOTE(review): weighting candidates by the number of samples nearest to
///   them may be the intended behaviour - confirm.
/// * fewer than `k` candidates: padded with random rows that are not already
///   candidates.
/// * exactly `k`: returned as-is.
///
/// All random draws now come from a generator seeded by `random_seed` (it was
/// previously ignored for sampling), so a fixed seed gives reproducible
/// centroids.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when `k` is zero or larger than
/// the number of rows; errors from the reduction step are propagated.
#[allow(dead_code)]
pub fn kmeans_parallel<F>(
    data: ArrayView2<F>,
    k: usize,
    random_seed: Option<u64>,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    if k == 0 || k > n_samples {
        return Err(ClusteringError::InvalidInput(format!(
            "Number of clusters ({}) must be between 1 and number of samples ({})",
            k, n_samples
        )));
    }
    let mut rng = match random_seed {
        Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
        None => scirs2_core::random::rngs::StdRng::seed_from_u64(
            scirs2_core::random::rng().random::<u64>(),
        ),
    };
    // Oversampling factor and number of sampling rounds.
    let l = F::from(5.0).expect("Failed to convert constant to float");
    let n_rounds = 8;

    // Copies the listed rows of `data` into a fresh matrix, in order.
    let gather_rows = |indices: &[usize]| -> Array2<F> {
        let mut out = Array2::<F>::zeros((indices.len(), n_features));
        for (mut row, &idx) in out.outer_iter_mut().zip(indices.iter()) {
            row.assign(&data.row(idx));
        }
        out
    };

    // Candidates are tracked by row index, so "already chosen" is cheap to
    // test and rows are copied only once at the end.
    let mut center_indices: Vec<usize> = vec![rng.random_range(0..n_samples)];
    let mut min_distances = Array1::from_elem(n_samples, F::infinity());
    // Number of candidates already folded into `min_distances`.
    let mut processed = 0;

    for _ in 0..n_rounds {
        // Only candidates added since the last round can lower a distance.
        for &center_idx in &center_indices[processed..] {
            let center = data.row(center_idx);
            for (sample, nearest) in data.outer_iter().zip(min_distances.iter_mut()) {
                let dist_sq = sample
                    .iter()
                    .zip(center.iter())
                    .fold(F::zero(), |acc, (&a, &b)| {
                        let diff = a - b;
                        acc + diff * diff
                    });
                let dist = dist_sq.sqrt();
                if dist < *nearest {
                    *nearest = dist;
                }
            }
        }
        processed = center_indices.len();

        let potential: F = min_distances.iter().map(|&d| d * d).sum();
        if potential <= F::epsilon() {
            // Every sample coincides with a candidate; nothing left to add.
            break;
        }
        let expected_new_centers = l * F::from(k).expect("Failed to convert to float");
        let oversampling = F::min(expected_new_centers / potential, F::one());
        for sample_idx in 0..n_samples {
            let d = min_distances[sample_idx];
            // Existing candidates have distance 0 and are never re-drawn.
            let probability = d * d * oversampling;
            let draw = F::from(rng.random_range(0.0..1.0)).expect("Operation failed");
            if draw < probability {
                center_indices.push(sample_idx);
            }
        }
    }

    match center_indices.len().cmp(&k) {
        std::cmp::Ordering::Greater => {
            let n_centers = center_indices.len();
            let centers_array = gather_rows(&center_indices);
            let weights_array = Array1::from_elem(n_centers, F::one());
            let options = KMeansOptions {
                max_iter: 100,
                tol: F::from(1e-4).expect("Failed to convert constant to float"),
                random_seed,
                n_init: 1,
                init_method: KMeansInit::KMeansPlusPlus,
            };
            // Take the first `k` candidates that pass a coin flip as seeds;
            // fall back to the first `k` candidates if too few pass.
            let init_indices: Vec<usize> = (0..n_centers)
                .filter(|_| rng.random_range(0.0..1.0) < 0.5)
                .take(k)
                .collect();
            let actual_indices = if init_indices.len() < k {
                (0..k.min(n_centers)).collect::<Vec<usize>>()
            } else {
                init_indices
            };
            let mut initcentroids = Array2::<F>::zeros((actual_indices.len(), n_features));
            for (mut row, &idx) in initcentroids.outer_iter_mut().zip(actual_indices.iter()) {
                row.assign(&centers_array.row(idx));
            }
            let (finalcentroids, _) = _weighted_kmeans_single(
                centers_array.view(),
                weights_array.view(),
                initcentroids.view(),
                &options,
            )?;
            Ok(finalcentroids)
        }
        std::cmp::Ordering::Less => {
            // Pad with rows that are not already candidates. Terminates
            // because `k <= n_samples` leaves enough unused rows.
            while center_indices.len() < k {
                let idx = rng.random_range(0..n_samples);
                if !center_indices.contains(&idx) {
                    center_indices.push(idx);
                }
            }
            Ok(gather_rows(&center_indices))
        }
        std::cmp::Ordering::Equal => Ok(gather_rows(&center_indices)),
    }
}
/// Runs one weighted k-means optimization starting from `initcentroids`.
///
/// Works like the unweighted single run, except that each centroid is the
/// weighted mean of its members, with `weights[i]` belonging to row `i` of
/// `data`. A cluster whose total weight is at most `F::epsilon()` is
/// re-seeded with the point farthest from its assigned centroid. The
/// distances come from the `vq` call of the same iteration (assumed to be
/// each point's distance to its assigned centroid, as the unweighted run
/// already relies on), and each point is used for at most one re-seed per
/// iteration, so several empty clusters do not collapse onto one location.
///
/// Returns the final centroids and the labels of the last assignment step.
#[allow(dead_code)]
fn _weighted_kmeans_single<F>(
    data: ArrayView2<F>,
    weights: ArrayView1<F>,
    initcentroids: ArrayView2<F>,
    opts: &KMeansOptions<F>,
) -> Result<(Array2<F>, Array1<usize>)>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    let k = initcentroids.shape()[0];
    let mut centroids = initcentroids.to_owned();
    let mut labels = Array1::zeros(n_samples);
    let mut prev_centroid_diff = F::infinity();

    for _ in 0..opts.max_iter {
        let (new_labels, mut distances) = vq(data, centroids.view())?;
        labels = new_labels;

        // Accumulate weighted coordinate sums and total weight per cluster.
        let mut newcentroids = Array2::<F>::zeros((k, n_features));
        let mut total_weights = vec![F::zero(); k];
        for ((point, &cluster), &weight) in
            data.outer_iter().zip(labels.iter()).zip(weights.iter())
        {
            for (acc, &value) in newcentroids.row_mut(cluster).iter_mut().zip(point.iter()) {
                *acc = *acc + value * weight;
            }
            total_weights[cluster] = total_weights[cluster] + weight;
        }

        for (i, &cluster_weight) in total_weights.iter().enumerate() {
            if cluster_weight <= F::epsilon() {
                // Re-seed the (effectively) empty cluster with the worst fit.
                let mut max_dist = F::zero();
                let mut far_idx = 0;
                for (j, &dist) in distances.iter().enumerate() {
                    if dist > max_dist {
                        max_dist = dist;
                        far_idx = j;
                    }
                }
                newcentroids.row_mut(i).assign(&data.row(far_idx));
                // Retire the point so another empty cluster picks a different one.
                distances[far_idx] = F::zero();
            } else {
                newcentroids
                    .row_mut(i)
                    .mapv_inplace(|sum| sum / cluster_weight);
            }
        }

        let centroid_diff = centroids
            .outer_iter()
            .zip(newcentroids.outer_iter())
            .fold(F::zero(), |acc, (old, new)| acc + euclidean_distance(old, new));
        centroids = newcentroids;
        if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
            break;
        }
        prev_centroid_diff = centroid_diff;
    }
    Ok((centroids, labels))
}
/// Clusters `data` into `k` groups using a caller-supplied distance metric.
///
/// Behaves like the option-driven k-means entry point, but assignments and
/// inertia are computed with `metric`. Several restarts are run (exactly one
/// for `KMeansInit::KMeansParallel`) and the one with the lowest inertia is
/// returned. When `options` is `None` the `Default` configuration is used.
///
/// When a seed is supplied, restart `r` is seeded with `seed + r` (wrapping),
/// so restarts differ while the overall result stays reproducible.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when `k` is zero, `data` has no
/// rows, `k` exceeds the number of rows, or the effective number of restarts
/// is zero. Errors from initialization or from a run are propagated.
#[allow(dead_code)]
pub fn kmeans_with_metric<F>(
    data: ArrayView2<F>,
    k: usize,
    metric: Box<dyn crate::vq::VQDistanceMetric<F>>,
    options: Option<KMeansOptions<F>>,
) -> Result<(Array2<F>, Array1<usize>)>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum + Send + Sync + 'static,
{
    if k == 0 {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters must be greater than 0".to_string(),
        ));
    }
    let n_samples = data.shape()[0];
    if n_samples == 0 {
        return Err(ClusteringError::InvalidInput(
            "Input data is empty".to_string(),
        ));
    }
    if k > n_samples {
        return Err(ClusteringError::InvalidInput(format!(
            "Number of clusters ({}) cannot be greater than number of data points ({})",
            k, n_samples
        )));
    }

    let opts = options.unwrap_or_default();
    let n_init = if opts.init_method == KMeansInit::KMeansParallel {
        1
    } else {
        opts.n_init
    };
    // Zero restarts used to fall through to a panic on `expect`.
    if n_init == 0 {
        return Err(ClusteringError::InvalidInput(
            "Number of initializations (n_init) must be greater than 0".to_string(),
        ));
    }

    let mut best: Option<(Array2<F>, Array1<usize>, F)> = None;
    for run in 0..n_init {
        // Re-using the identical seed would make every restart repeat run 0.
        let run_seed = opts
            .random_seed
            .map(|seed| seed.wrapping_add(run as u64));
        let init = kmeans_init(data, k, Some(opts.init_method), run_seed)?;
        let candidate = _kmeans_single_with_metric(data, init.view(), metric.as_ref(), &opts)?;

        // Always keep the first run so a NaN/inf inertia cannot leave `best`
        // empty; afterwards prefer strictly lower (or first non-NaN) inertia.
        let improves = best.as_ref().map_or(true, |(_, _, best_inertia)| {
            candidate.2 < *best_inertia || (best_inertia.is_nan() && !candidate.2.is_nan())
        });
        if improves {
            best = Some(candidate);
        }
    }

    let (centroids, labels, _) =
        best.expect("n_init >= 1 guarantees at least one completed run");
    Ok((centroids, labels))
}
/// Runs one k-means optimization from `initcentroids` using `metric`.
///
/// Each iteration assigns points with `_vq_with_metric`, recomputes every
/// centroid as the arithmetic mean of its members, and stops when the summed
/// centroid movement (measured with `metric`) is at most `opts.tol`, stops
/// shrinking, or `opts.max_iter` is reached.
///
/// An empty cluster is re-seeded with the point currently farthest from its
/// assigned centroid. Each point is used for at most one re-seed per
/// iteration, so several empty clusters no longer collapse onto the same
/// location.
///
/// Returns `(centroids, labels, inertia)`, where `inertia` is the sum of
/// squared `metric` distances between each point and its label's centroid.
#[allow(dead_code)]
fn _kmeans_single_with_metric<F>(
    data: ArrayView2<F>,
    initcentroids: ArrayView2<F>,
    metric: &dyn crate::vq::VQDistanceMetric<F>,
    opts: &KMeansOptions<F>,
) -> Result<(Array2<F>, Array1<usize>, F)>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum + Send + Sync,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    let k = initcentroids.shape()[0];
    let mut centroids = initcentroids.to_owned();
    let mut labels = Array1::zeros(n_samples);
    let mut prev_centroid_diff = F::infinity();

    for _ in 0..opts.max_iter {
        let (new_labels, mut distances) = _vq_with_metric(data, centroids.view(), metric)?;
        labels = new_labels;

        // Accumulate per-cluster coordinate sums and member counts.
        let mut newcentroids = Array2::<F>::zeros((k, n_features));
        let mut counts = vec![0usize; k];
        for (point, &cluster) in data.outer_iter().zip(labels.iter()) {
            for (acc, &value) in newcentroids.row_mut(cluster).iter_mut().zip(point.iter()) {
                *acc = *acc + value;
            }
            counts[cluster] += 1;
        }

        for (i, count) in counts.iter_mut().enumerate() {
            if *count == 0 {
                // Re-seed the empty cluster with the worst-fitting point.
                let mut max_dist = F::zero();
                let mut far_idx = 0;
                for (j, &dist) in distances.iter().enumerate() {
                    if dist > max_dist {
                        max_dist = dist;
                        far_idx = j;
                    }
                }
                newcentroids.row_mut(i).assign(&data.row(far_idx));
                *count = 1;
                // Retire the point so another empty cluster picks a different one.
                distances[far_idx] = F::zero();
            } else {
                let members = F::from(*count).expect("Failed to convert to float");
                newcentroids.row_mut(i).mapv_inplace(|sum| sum / members);
            }
        }

        let centroid_diff = centroids
            .outer_iter()
            .zip(newcentroids.outer_iter())
            .fold(F::zero(), |acc, (old, new)| acc + metric.distance(old, new));
        centroids = newcentroids;
        if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
            break;
        }
        prev_centroid_diff = centroid_diff;
    }

    let inertia = data
        .outer_iter()
        .zip(labels.iter())
        .fold(F::zero(), |acc, (point, &cluster)| {
            let dist = metric.distance(point, centroids.row(cluster));
            acc + dist * dist
        });
    Ok((centroids, labels, inertia))
}
/// Assigns every row of `data` to its closest centroid under `metric`.
///
/// Returns the index of the closest centroid for each row together with the
/// corresponding distance. Ties keep the lowest centroid index; if there are
/// no centroids, every row gets label 0 and an infinite distance.
#[allow(dead_code)]
fn _vq_with_metric<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    metric: &dyn crate::vq::VQDistanceMetric<F>,
) -> Result<(Array1<usize>, Array1<F>)>
where
    F: Float + FromPrimitive + Debug + Send + Sync,
{
    let (assignments, nearest): (Vec<usize>, Vec<F>) = data
        .outer_iter()
        .map(|observation| {
            // Strict `<` keeps the first centroid on ties.
            centroids.outer_iter().enumerate().fold(
                (0usize, F::infinity()),
                |(best_idx, best_dist), (idx, candidate)| {
                    let dist = metric.distance(observation, candidate);
                    if dist < best_dist {
                        (idx, dist)
                    } else {
                        (best_idx, best_dist)
                    }
                },
            )
        })
        .unzip();
    Ok((Array1::from_vec(assignments), Array1::from_vec(nearest)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use std::collections::HashSet;

    /// Six 2-D points forming two tight, well separated groups of three.
    fn six_point_fixture() -> Array2<f64> {
        Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .expect("Operation failed")
    }

    /// Number of distinct values produced by a label iterator.
    fn distinct_label_count<'a>(labels: impl Iterator<Item = &'a usize>) -> usize {
        labels.copied().collect::<HashSet<usize>>().len()
    }

    /// Clusters the six-point fixture into two groups with the given
    /// initializer and checks the shape of the outcome.
    fn check_six_point_clustering(init_method: KMeansInit) {
        let data = six_point_fixture();
        let options = KMeansOptions {
            init_method,
            ..Default::default()
        };
        let result = kmeans_with_options(data.view(), 2, Some(options));
        assert!(result.is_ok());
        let (centroids, labels) = result.expect("Operation failed");
        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 6);
        assert_eq!(distinct_label_count(labels.iter()), 2);
    }

    #[test]
    fn test_kmeans_random_init() {
        check_six_point_clustering(KMeansInit::Random);
    }

    #[test]
    fn test_kmeans_plusplus_init() {
        check_six_point_clustering(KMeansInit::KMeansPlusPlus);
    }

    #[test]
    fn test_kmeans_parallel_init() {
        // Ten points near (1, 2) followed by ten points near (5, 6).
        let data = Array2::from_shape_vec(
            (20, 2),
            vec![
                1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 1.1, 2.2, 0.9, 1.7, 1.3, 2.1, 1.0, 1.9, 0.7, 2.0,
                1.2, 2.3, 1.5, 1.8, 5.0, 6.0, 5.2, 5.8, 4.8, 6.2, 5.1, 5.9, 5.3, 6.1, 4.9, 5.7,
                5.0, 6.3, 5.4, 5.6, 4.7, 5.9, 5.2, 6.2,
            ],
        )
        .expect("Operation failed");
        let options = KMeansOptions {
            init_method: KMeansInit::KMeansParallel,
            ..Default::default()
        };
        let result = kmeans_with_options(data.view(), 2, Some(options));
        assert!(result.is_ok());
        let (centroids, labels) = result.expect("Operation failed");
        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 20);
        assert_eq!(distinct_label_count(labels.iter()), 2);

        // Each half of the data must land in its own single cluster.
        let first_cluster = labels[0];
        let second_cluster = labels[10];
        for &label in labels.iter().take(10) {
            assert_eq!(label, first_cluster);
        }
        assert_ne!(first_cluster, second_cluster);
        for &label in labels.iter().skip(10) {
            assert_eq!(label, second_cluster);
        }
    }

    #[test]
    fn test_scipy_compatible_kmeans() {
        let data = six_point_fixture();

        // Fully specified arguments.
        let explicit = kmeans(data.view(), 2, Some(20), Some(1e-5), Some(true), Some(42));
        assert!(explicit.is_ok());
        let (centroids, distortion) = explicit.expect("Operation failed");
        assert_eq!(centroids.shape(), &[2, 2]);
        assert!(distortion > 0.0);

        // Every optional argument left at its default.
        let defaulted = kmeans(data.view(), 2, None, None, None, None);
        assert!(defaulted.is_ok());
        let (centroids2, distortion2) = defaulted.expect("Operation failed");
        assert_eq!(centroids2.shape(), &[2, 2]);
        assert!(distortion2 > 0.0);
    }

    #[test]
    fn test_scipy_kmeans_check_finite() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.5, 1.5, 8.0, 8.0, 8.5, 8.5])
            .expect("Operation failed");
        // Finite input must be accepted with the check switched on and off.
        for flag in [true, false] {
            let result = kmeans(
                data.view(),
                2,
                Some(10),
                Some(1e-5),
                Some(flag),
                Some(42),
            );
            assert!(result.is_ok());
        }
    }
}