use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::parallel_ops::*;
use std::fmt::Debug;
use std::sync::Mutex;
use super::{euclidean_distance, kmeans_init, KMeansInit};
use crate::error::{ClusteringError, Result};
/// Configuration for [`parallel_kmeans`].
///
/// Construct with [`Default::default`] and override individual fields with struct-update
/// syntax, e.g. `ParallelKMeansOptions { n_init: 1, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct ParallelKMeansOptions<F: Float> {
    /// Maximum number of assign/update iterations performed by each run.
    pub max_iter: usize,
    /// A run stops early once the absolute change in inertia between two consecutive
    /// iterations is at most this value.
    pub tol: F,
    /// Optional seed forwarded to the centroid initialiser (`kmeans_init`).
    pub random_seed: Option<u64>,
    /// Number of independent runs; the run with the lowest inertia is returned.
    pub n_init: usize,
    /// Initialisation strategy forwarded to `kmeans_init`.
    pub init_method: KMeansInit,
    /// Requested number of worker threads.
    ///
    /// NOTE: this value is currently not applied by `parallel_kmeans`; the parallel
    /// steps run on whatever thread pool the parallel iterators use by default.
    pub n_threads: Option<usize>,
}
impl<F: Float + FromPrimitive> Default for ParallelKMeansOptions<F> {
    /// Returns the standard settings: at most 300 iterations per run, a tolerance of
    /// `1e-4`, 10 runs, `KMeansInit::KMeansPlusPlus` initialisation, no fixed seed and
    /// no thread-count override.
    fn default() -> Self {
        let tol = F::from(1e-4).expect("Failed to convert constant to float");
        Self {
            n_init: 10,
            max_iter: 300,
            init_method: KMeansInit::KMeansPlusPlus,
            tol,
            random_seed: None,
            n_threads: None,
        }
    }
}
/// Clusters the rows of `data` into `k` groups with k-means, running the assignment,
/// update and inertia steps in parallel.
///
/// The optimisation is restarted `n_init` times from fresh initialisations and the run
/// with the lowest inertia (sum of squared distances to the assigned centroid) is kept.
///
/// # Arguments
///
/// * `data` - samples, one per row.
/// * `k` - number of clusters; must satisfy `1 <= k <= data.nrows()`.
/// * `options` - tuning parameters; `None` uses [`ParallelKMeansOptions::default`].
///
/// # Returns
///
/// `(centroids, labels)`: a `k x n_features` centroid matrix and one cluster index per row.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `k == 0`, `data` has no rows, `k` exceeds
/// the number of rows, or `n_init == 0`. Errors from the initialiser and the iteration
/// helpers are propagated.
#[allow(dead_code)]
pub fn parallel_kmeans<F>(
    data: ArrayView2<F>,
    k: usize,
    options: Option<ParallelKMeansOptions<F>>,
) -> Result<(Array2<F>, Array1<usize>)>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum + Send + Sync,
{
    if k == 0 {
        return Err(ClusteringError::InvalidInput(
            "Number of clusters must be greater than 0".to_string(),
        ));
    }
    let n_samples = data.shape()[0];
    if n_samples == 0 {
        return Err(ClusteringError::InvalidInput(
            "Input data is empty".to_string(),
        ));
    }
    if k > n_samples {
        return Err(ClusteringError::InvalidInput(format!(
            "Number of clusters ({}) cannot be greater than number of data points ({})",
            k, n_samples
        )));
    }
    let opts = options.unwrap_or_default();
    if opts.n_init == 0 {
        return Err(ClusteringError::InvalidInput(
            "Number of initializations (n_init) must be greater than 0".to_string(),
        ));
    }
    // NOTE: `opts.n_threads` is not applied here; parallel steps use the default pool.

    let mut best: Option<(Array2<F>, Array1<usize>, F)> = None;
    for run in 0..opts.n_init {
        // Give every restart its own (still reproducible) seed. Passing the same seed to
        // each run would presumably reproduce the same initialisation every time and make
        // `n_init > 1` pointless. Run 0 keeps the user's seed unchanged.
        let seed = opts
            .random_seed
            .map(|base| base.wrapping_add(run as u64));
        let init = kmeans_init(data, k, Some(opts.init_method), seed)?;
        let (centroids, labels, inertia) = parallel_kmeans_single(data, init.view(), &opts)?;

        // The first run is always accepted so a NaN inertia can never leave us without a
        // result; afterwards a strictly lower inertia (or any real value replacing NaN) wins.
        let is_better = match &best {
            None => true,
            Some((_, _, best_inertia)) => {
                inertia < *best_inertia || (best_inertia.is_nan() && !inertia.is_nan())
            }
        };
        if is_better {
            best = Some((centroids, labels, inertia));
        }
    }

    let (centroids, labels, _) = best.expect("n_init > 0 guarantees at least one run");
    Ok((centroids, labels))
}
#[allow(dead_code)]
fn parallel_kmeans_single<F>(
data: ArrayView2<F>,
initcentroids: ArrayView2<F>,
opts: &ParallelKMeansOptions<F>,
) -> Result<(Array2<F>, Array1<usize>, F)>
where
F: Float + FromPrimitive + Debug + std::iter::Sum + Send + Sync,
{
let n_samples = data.shape()[0];
let _n_features = data.shape()[1];
let k = initcentroids.shape()[0];
let mut centroids = initcentroids.to_owned();
let mut labels = Array1::zeros(n_samples);
let mut prev_inertia = F::infinity();
for _iter in 0..opts.max_iter {
let (new_labels, distances) = parallel_assign_labels(data, centroids.view())?;
labels = new_labels;
let newcentroids = parallel_updatecentroids(data, &labels, k)?;
let cluster_counts = count_clusters(&labels, k);
let mut finalcentroids = newcentroids;
for (i, &count) in cluster_counts.iter().enumerate() {
if count == 0 {
let (far_idx, _) = distances
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
.expect("Operation failed");
finalcentroids
.slice_mut(s![i, ..])
.assign(&data.slice(s![far_idx, ..]));
}
}
let inertia = parallel_compute_inertia(data, &labels, finalcentroids.view())?;
if (prev_inertia - inertia).abs() <= opts.tol {
return Ok((finalcentroids, labels, inertia));
}
centroids = finalcentroids;
prev_inertia = inertia;
}
let final_inertia = parallel_compute_inertia(data, &labels, centroids.view())?;
Ok((centroids, labels, final_inertia))
}
/// Finds the nearest centroid for every sample, processing rows in parallel.
///
/// Returns `(labels, distances)`:
/// - `labels[i]` is the index of the centroid closest to row `i` of `data`. On ties the
///   lowest index wins.
/// - `distances[i]` is the Euclidean distance to that centroid.
///
/// When no distance compares below infinity (e.g. all are NaN), the label stays `0` and
/// the distance stays infinite.
#[allow(dead_code)]
fn parallel_assign_labels<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
) -> Result<(Array1<usize>, Array1<F>)>
where
    F: Float + FromPrimitive + Debug + Send + Sync,
{
    let nearest: Vec<(usize, F)> = (0..data.nrows())
        .into_par_iter()
        .map(|row| {
            let point = data.row(row);
            centroids.outer_iter().enumerate().fold(
                (0, F::infinity()),
                |(best_idx, best_dist), (idx, centroid)| {
                    let candidate = euclidean_distance(point, centroid);
                    // Strict `<` keeps the earliest centroid on ties.
                    if candidate < best_dist {
                        (idx, candidate)
                    } else {
                        (best_idx, best_dist)
                    }
                },
            )
        })
        .collect();

    let labels: Array1<usize> = nearest.iter().map(|&(label, _)| label).collect();
    let distances: Array1<F> = nearest.iter().map(|&(_, dist)| dist).collect();
    Ok((labels, distances))
}
#[allow(dead_code)]
fn parallel_updatecentroids<F>(
data: ArrayView2<F>,
labels: &Array1<usize>,
k: usize,
) -> Result<Array2<F>>
where
F: Float + FromPrimitive + Debug + Send + Sync + std::iter::Sum,
{
let n_features = data.shape()[1];
let sums: Vec<Mutex<Array1<F>>> = (0..k)
.map(|_| Mutex::new(Array1::zeros(n_features)))
.collect();
let counts: Vec<Mutex<usize>> = (0..k).map(|_| Mutex::new(0)).collect();
data.axis_iter(Axis(0))
.zip(labels.iter())
.par_bridge()
.for_each(|(sample, &label)| {
let mut sum = sums[label].lock().expect("Operation failed");
for i in 0..n_features {
sum[i] = sum[i] + sample[i];
}
let mut count = counts[label].lock().expect("Operation failed");
*count += 1;
});
let mut newcentroids = Array2::zeros((k, n_features));
for i in 0..k {
let sum = sums[i].lock().expect("Operation failed");
let count = *counts[i].lock().expect("Operation failed");
if count > 0 {
for j in 0..n_features {
newcentroids[[i, j]] = sum[j] / F::from(count).expect("Failed to convert to float");
}
}
}
Ok(newcentroids)
}
/// Tallies how many entries of `labels` fall into each cluster `0..k`.
///
/// # Panics
///
/// Panics (index out of bounds) if any label is `>= k`.
#[allow(dead_code)]
fn count_clusters(labels: &Array1<usize>, k: usize) -> Vec<usize> {
    labels.iter().fold(vec![0; k], |mut tally, &cluster| {
        tally[cluster] += 1;
        tally
    })
}
#[allow(dead_code)]
fn parallel_compute_inertia<F>(
data: ArrayView2<F>,
labels: &Array1<usize>,
centroids: ArrayView2<F>,
) -> Result<F>
where
F: Float + FromPrimitive + Debug + Send + Sync + std::iter::Sum,
{
let inertia: F = data
.axis_iter(Axis(0))
.zip(labels.iter())
.par_bridge()
.map(|(sample, &label)| {
let centroid = centroids.slice(s![label, ..]);
let dist = euclidean_distance(sample.view(), centroid);
dist * dist
})
.sum();
Ok(inertia)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_parallel_kmeans_simple() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .expect("Operation failed");
        let options = ParallelKMeansOptions {
            n_init: 1,
            random_seed: Some(42),
            ..Default::default()
        };
        let (centroids, labels) =
            parallel_kmeans(data.view(), 2, Some(options)).expect("Operation failed");
        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 6);
        let unique_labels: Vec<_> = labels
            .iter()
            .copied()
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        assert_eq!(unique_labels.len(), 2);
    }

    #[test]
    fn test_parallel_kmeans_large_dataset() {
        let n_samples = 1000;
        let n_features = 10;
        let mut data_vec = Vec::with_capacity(n_samples * n_features);
        for i in 0..n_samples {
            // Maps rows evenly onto exactly three groups 0..=2; `i / (n_samples / 3)`
            // would put row 999 into a fourth group.
            let cluster = i * 3 / n_samples;
            for j in 0..n_features {
                let value = (cluster * 10) as f64 + (j as f64 + i as f64 * 0.01);
                data_vec.push(value);
            }
        }
        let data =
            Array2::from_shape_vec((n_samples, n_features), data_vec).expect("Operation failed");
        let options = ParallelKMeansOptions {
            n_init: 3,
            max_iter: 50,
            random_seed: Some(42),
            ..Default::default()
        };
        let start_time = std::time::Instant::now();
        let (centroids, labels) =
            parallel_kmeans(data.view(), 3, Some(options)).expect("Operation failed");
        let duration = start_time.elapsed();
        println!("Parallel K-means took: {duration:?}");
        assert_eq!(centroids.shape(), &[3, n_features]);
        assert_eq!(labels.len(), n_samples);
        for &label in labels.iter() {
            assert!(label < 3);
        }
    }

    #[test]
    fn test_parallel_kmeans_rejects_invalid_input() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
            .expect("Operation failed");
        // k == 0
        assert!(parallel_kmeans(data.view(), 0, None).is_err());
        // k greater than the number of samples
        assert!(parallel_kmeans(data.view(), 4, None).is_err());
        // no samples at all
        let empty = Array2::<f64>::zeros((0, 2));
        assert!(parallel_kmeans(empty.view(), 1, None).is_err());
    }
}