use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::{Rng, RngExt};
use std::fmt::Debug;
use super::simd_optimizations::{
calculate_distortion_simd, compute_centroids_simd, vq_simd, SimdOptimizationConfig,
};
use super::{kmeans_init, KMeansInit, KMeansOptions};
use crate::error::{ClusteringError, Result};
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
use scirs2_core::validation::clustering::*;
#[allow(dead_code)]
pub fn kmeans_simd<F>(
data: ArrayView2<F>,
k: usize,
options: Option<KMeansOptions<F>>,
simd_config: Option<SimdOptimizationConfig>,
) -> Result<(Array2<F>, Array1<usize>, F)>
where
F: Float
+ FromPrimitive
+ Debug
+ Send
+ Sync
+ SimdUnifiedOps
+ std::iter::Sum
+ std::fmt::Display,
{
if k == 0 {
return Err(ClusteringError::InvalidInput(
"Number of clusters must be greater than 0".to_string(),
));
}
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
if n_samples == 0 {
return Err(ClusteringError::InvalidInput(
"Input data is empty".to_string(),
));
}
if k > n_samples {
return Err(ClusteringError::InvalidInput(format!(
"Number of clusters ({}) cannot be greater than number of data points ({})",
k, n_samples
)));
}
validate_clustering_data(&data, "SIMD K-means", true, Some(k))
.map_err(|e| ClusteringError::InvalidInput(format!("SIMD K-means: {}", e)))?;
let opts = options.unwrap_or_default();
let simd_config = simd_config.unwrap_or_default();
let caps = PlatformCapabilities::detect();
let use_simd = caps.simd_available || simd_config.force_simd;
if !use_simd {
eprintln!(
"Warning: SIMD not available or not beneficial, falling back to standard K-means"
);
}
let mut bestcentroids = None;
let mut best_labels = None;
let mut best_inertia = F::infinity();
let n_init = if opts.init_method == KMeansInit::KMeansParallel {
1
} else {
opts.n_init
};
for init_idx in 0..n_init {
let centroids = kmeans_init(data, k, Some(opts.init_method), opts.random_seed)?;
let (finalcentroids, labels, inertia) =
simd_kmeans_single(data, centroids.view(), &opts, &simd_config)?;
if inertia < best_inertia {
bestcentroids = Some(finalcentroids);
best_labels = Some(labels);
best_inertia = inertia;
}
if n_init == 1 {
break;
}
}
Ok((
bestcentroids.expect("Operation failed"),
best_labels.expect("Operation failed"),
best_inertia,
))
}
#[allow(dead_code)]
fn simd_kmeans_single<F>(
data: ArrayView2<F>,
initcentroids: ArrayView2<F>,
opts: &KMeansOptions<F>,
simd_config: &SimdOptimizationConfig,
) -> Result<(Array2<F>, Array1<usize>, F)>
where
F: Float
+ FromPrimitive
+ Debug
+ Send
+ Sync
+ SimdUnifiedOps
+ std::iter::Sum
+ std::fmt::Display,
{
let n_samples = data.shape()[0];
let _n_features = data.shape()[1];
let k = initcentroids.shape()[0];
let mut centroids = initcentroids.to_owned();
let mut labels = Array1::zeros(n_samples);
let mut prev_inertia = F::infinity();
let mut _converged = false;
for iteration in 0..opts.max_iter {
let (new_labels, _distances) = vq_simd(data, centroids.view(), Some(simd_config))?;
labels = new_labels;
let newcentroids = compute_centroids_simd(data, &labels, k, Some(simd_config))?;
let centroid_shift =
compute_centroid_shift_simd(centroids.view(), newcentroids.view(), simd_config)?;
centroids = newcentroids;
let inertia =
calculate_distortion_simd(data, centroids.view(), &labels, Some(simd_config))?;
if centroid_shift <= opts.tol {
_converged = true;
break;
}
if iteration > 0 && (inertia >= prev_inertia || (prev_inertia - inertia) < opts.tol) {
_converged = true;
break;
}
prev_inertia = inertia;
}
let final_inertia =
calculate_distortion_simd(data, centroids.view(), &labels, Some(simd_config))?;
Ok((centroids, labels, final_inertia))
}
#[allow(dead_code)]
fn compute_centroid_shift_simd<F>(
oldcentroids: ArrayView2<F>,
newcentroids: ArrayView2<F>,
simd_config: &SimdOptimizationConfig,
) -> Result<F>
where
F: Float
+ FromPrimitive
+ Debug
+ Send
+ Sync
+ SimdUnifiedOps
+ std::iter::Sum
+ std::fmt::Display,
{
let k = oldcentroids.shape()[0];
let n_features = oldcentroids.shape()[1];
if newcentroids.shape() != oldcentroids.shape() {
return Err(ClusteringError::ComputationError(
"Centroid arrays must have the same shape".to_string(),
));
}
let caps = PlatformCapabilities::detect();
let use_simd = caps.simd_available || simd_config.force_simd;
if simd_config.enable_parallel && is_parallel_enabled() && k > 4 {
let shifts: Vec<F> = (0..k)
.into_par_iter()
.map(|i| {
let old_centroid = oldcentroids.slice(s![i, ..]);
let new_centroid = newcentroids.slice(s![i, ..]);
if use_simd {
let diff = F::simd_sub(&new_centroid, &old_centroid);
F::simd_norm(&diff.view())
} else {
let mut sum = F::zero();
for j in 0..n_features {
let diff = new_centroid[j] - old_centroid[j];
sum = sum + diff * diff;
}
sum.sqrt()
}
})
.collect();
Ok(shifts.into_iter().sum())
} else {
let mut total_shift = F::zero();
for i in 0..k {
let old_centroid = oldcentroids.slice(s![i, ..]);
let new_centroid = newcentroids.slice(s![i, ..]);
let shift = if use_simd {
let diff = F::simd_sub(&new_centroid, &old_centroid);
F::simd_norm(&diff.view())
} else {
let mut sum = F::zero();
for j in 0..n_features {
let diff = new_centroid[j] - old_centroid[j];
sum = sum + diff * diff;
}
sum.sqrt()
};
total_shift = total_shift + shift;
}
Ok(total_shift)
}
}
#[allow(dead_code)]
pub fn mini_batch_kmeans_simd<F>(
data: ArrayView2<F>,
k: usize,
batch_size: usize,
options: Option<KMeansOptions<F>>,
simd_config: Option<SimdOptimizationConfig>,
) -> Result<(Array2<F>, Array1<usize>, F)>
where
F: Float
+ FromPrimitive
+ Debug
+ Send
+ Sync
+ SimdUnifiedOps
+ std::iter::Sum
+ std::fmt::Display,
{
let opts = options.unwrap_or_default();
let simd_config = simd_config.unwrap_or_default();
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
if batch_size >= n_samples {
return kmeans_simd(data, k, Some(opts), Some(simd_config));
}
let mut centroids = kmeans_init(data, k, Some(opts.init_method), opts.random_seed)?;
let mut counts = Array1::<usize>::zeros(k);
let mut rng = scirs2_core::random::rng();
for iteration in 0..opts.max_iter {
let mut batch_indices = Vec::with_capacity(batch_size);
for _ in 0..batch_size {
batch_indices.push(rng.random_range(0..n_samples));
}
let mut batch_data = Array2::<F>::zeros((batch_size, n_features));
for (i, &idx) in batch_indices.iter().enumerate() {
for j in 0..n_features {
batch_data[[i, j]] = data[[idx, j]];
}
}
let (batch_labels, _) = vq_simd(batch_data.view(), centroids.view(), Some(&simd_config))?;
for i in 0..batch_size {
let cluster = batch_labels[i];
counts[cluster] += 1;
let eta = F::one() / F::from(counts[cluster]).expect("Failed to convert to float");
let point = batch_data.slice(s![i, ..]);
let centroid = centroids.slice_mut(s![cluster, ..]);
let caps = PlatformCapabilities::detect();
if caps.simd_available || simd_config.force_simd {
let diff = F::simd_sub(&point, ¢roid.view());
let eta_array = Array1::from_elem(n_features, eta);
let weighted_diff = F::simd_mul(&diff.view(), &eta_array.view());
let updated_centroid = F::simd_add(¢roid.view(), &weighted_diff.view());
for j in 0..n_features {
centroids[[cluster, j]] = updated_centroid[j];
}
} else {
for j in 0..n_features {
centroids[[cluster, j]] =
centroids[[cluster, j]] + eta * (point[j] - centroids[[cluster, j]]);
}
}
}
if iteration % 10 == 0 && iteration > 0 {
let (labels, _) = vq_simd(data, centroids.view(), Some(&simd_config))?;
let inertia =
calculate_distortion_simd(data, centroids.view(), &labels, Some(&simd_config))?;
if iteration > 20 {
break;
}
}
}
let (final_labels, _) = vq_simd(data, centroids.view(), Some(&simd_config))?;
let final_inertia =
calculate_distortion_simd(data, centroids.view(), &final_labels, Some(&simd_config))?;
Ok((centroids, final_labels, final_inertia))
}
#[allow(dead_code)]
pub fn kmeans_plus_plus_simd<F>(
data: ArrayView2<F>,
k: usize,
simd_config: Option<&SimdOptimizationConfig>,
random_seed: Option<u64>,
) -> Result<Array2<F>>
where
F: Float
+ FromPrimitive
+ Debug
+ Send
+ Sync
+ SimdUnifiedOps
+ std::iter::Sum
+ std::fmt::Display,
{
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
let default_config = SimdOptimizationConfig::default();
let simd_config = simd_config.unwrap_or(&default_config);
if k == 0 || k > n_samples {
return Err(ClusteringError::InvalidInput(format!(
"Number of clusters ({}) must be between 1 and number of samples ({})",
k, n_samples
)));
}
let mut rng = scirs2_core::random::rng();
let mut centroids = Array2::zeros((k, n_features));
let first_idx = rng.random_range(0..n_samples);
for j in 0..n_features {
centroids[[0, j]] = data[[first_idx, j]];
}
if k == 1 {
return Ok(centroids);
}
let caps = PlatformCapabilities::detect();
let use_simd = caps.simd_available || simd_config.force_simd;
for i in 1..k {
let mut min_distances = Array1::from_elem(n_samples, F::infinity());
if simd_config.enable_parallel
&& is_parallel_enabled()
&& n_samples > simd_config.parallel_chunk_size
{
let distances: Vec<F> = (0..n_samples)
.into_par_iter()
.map(|sample_idx| {
let sample = data.slice(s![sample_idx, ..]);
let mut min_dist = F::infinity();
for centroid_idx in 0..i {
let centroid = centroids.slice(s![centroid_idx, ..]);
let dist = if use_simd {
let diff = F::simd_sub(&sample, ¢roid);
F::simd_norm(&diff.view())
} else {
let mut sum = F::zero();
for k in 0..n_features {
let diff = sample[k] - centroid[k];
sum = sum + diff * diff;
}
sum.sqrt()
};
if dist < min_dist {
min_dist = dist;
}
}
min_dist
})
.collect();
for (idx, dist) in distances.into_iter().enumerate() {
min_distances[idx] = dist;
}
} else {
for sample_idx in 0..n_samples {
let sample = data.slice(s![sample_idx, ..]);
for centroid_idx in 0..i {
let centroid = centroids.slice(s![centroid_idx, ..]);
let dist = if use_simd {
let diff = F::simd_sub(&sample, ¢roid);
F::simd_norm(&diff.view())
} else {
let mut sum = F::zero();
for k in 0..n_features {
let diff = sample[k] - centroid[k];
sum = sum + diff * diff;
}
sum.sqrt()
};
if dist < min_distances[sample_idx] {
min_distances[sample_idx] = dist;
}
}
}
}
let mut weights = min_distances.mapv(|d| d * d);
let sum_weights = weights.sum();
if sum_weights > F::zero() {
weights.mapv_inplace(|w| w / sum_weights);
} else {
weights.fill(F::from(1.0 / n_samples as f64).expect("Failed to convert to float"));
}
let mut cum_weights = weights.clone();
for j in 1..n_samples {
cum_weights[j] = cum_weights[j] + cum_weights[j - 1];
}
let r = rng.random_range(0.0..1.0);
let r_f = F::from(r).expect("Failed to convert to float");
let mut selected_idx = n_samples - 1;
for j in 0..n_samples {
if cum_weights[j] >= r_f {
selected_idx = j;
break;
}
}
for j in 0..n_features {
centroids[[i, j]] = data[[selected_idx, j]];
}
}
Ok(centroids)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::Array2;
#[test]
fn test_kmeans_simd() {
let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.2, 1.8, 3.7, 4.2, 3.9, 3.9])
.expect("Operation failed");
let options = KMeansOptions {
max_iter: 10,
n_init: 1,
tol: 0.1,
..Default::default()
};
let (centroids, labels, inertia) =
kmeans_simd(data.view(), 2, Some(options), None).expect("Operation failed");
assert_eq!(centroids.shape(), &[2, 2]);
assert_eq!(labels.len(), 4);
assert!(inertia >= 0.0);
for &label in labels.iter() {
assert!(label < 2);
}
}
#[test]
fn test_mini_batch_kmeans_simd() {
let data = Array2::from_shape_vec((8, 2), (0..16).map(|x| x as f64).collect())
.expect("Operation failed");
let options = KMeansOptions {
max_iter: 5,
n_init: 1,
tol: 0.1,
..Default::default()
};
let (centroids, labels, inertia) =
mini_batch_kmeans_simd(data.view(), 2, 3, Some(options), None)
.expect("Operation failed");
assert_eq!(centroids.shape(), &[2, 2]);
assert_eq!(labels.len(), 8);
assert!(inertia >= 0.0);
for &label in labels.iter() {
assert!(label < 2);
}
}
#[test]
fn test_kmeans_plus_plus_simd() {
let data = Array2::from_shape_vec((6, 2), (0..12).map(|x| x as f64).collect())
.expect("Operation failed");
let centroids =
kmeans_plus_plus_simd(data.view(), 2, None, Some(42)).expect("Operation failed");
assert_eq!(centroids.shape(), &[2, 2]);
for i in 0..2 {
for j in (i + 1)..2 {
let dist = ((centroids[[i, 0]] - centroids[[j, 0]]).powi(2)
+ (centroids[[i, 1]] - centroids[[j, 1]]).powi(2))
.sqrt();
assert!(dist > 0.0, "Centroids should be different");
}
}
}
#[test]
fn test_compute_centroid_shift_simd() {
let oldcentroids =
Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).expect("Operation failed");
let newcentroids =
Array2::from_shape_vec((2, 2), vec![0.1, 0.1, 1.1, 1.1]).expect("Operation failed");
let config = SimdOptimizationConfig::default();
let shift = compute_centroid_shift_simd(oldcentroids.view(), newcentroids.view(), &config)
.expect("Operation failed");
let expected = 2.0 * (0.1f64.powi(2) + 0.1f64.powi(2)).sqrt();
assert_abs_diff_eq!(shift, expected, epsilon = 1e-10);
}
}