use super::distance::{
should_use_simd, simd_distance, simd_pairwise_condensed_distances,
simd_squared_euclidean_distance, SimdClusterConfig, SimdDistanceMetric,
};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities, SimdUnifiedOps};
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
/// Assigns every row of `data` to its nearest centroid under `metric`.
///
/// The work is dispatched to the parallel implementation when the
/// configuration enables parallelism, `is_parallel_enabled()` reports true and
/// the number of samples exceeds `parallel_chunk_size`; otherwise the
/// sequential implementation is used.
///
/// # Arguments
///
/// * `data` - Samples, one per row.
/// * `centroids` - Centroids, one per row; must have as many columns as `data`.
/// * `metric` - Distance metric forwarded to `simd_distance`.
/// * `config` - Optional configuration; `SimdClusterConfig::default()` is used
///   when `None`.
///
/// # Returns
///
/// A pair `(labels, distances)` holding, for each sample, the index of the
/// chosen centroid and the distance to it.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when the feature counts of `data`
/// and `centroids` differ, or when there are samples to assign but no
/// centroids at all. Errors reported by the selected implementation are
/// returned to the caller.
pub fn simd_assign_clusters<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    metric: SimdDistanceMetric,
    config: Option<&SimdClusterConfig>,
) -> Result<(Array1<usize>, Array1<F>)>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    let n_centroids = centroids.shape()[0];
    if centroids.shape()[1] != n_features {
        return Err(ClusteringError::InvalidInput(format!(
            "Data has {} features but centroids have {} features",
            n_features,
            centroids.shape()[1]
        )));
    }
    // Without this guard every sample would be labelled 0 with an infinite
    // distance, i.e. it would reference a centroid that does not exist.
    if n_centroids == 0 && n_samples > 0 {
        return Err(ClusteringError::InvalidInput(
            "Cannot assign samples to clusters: no centroids were provided".to_string(),
        ));
    }
    let default_config = SimdClusterConfig::default();
    let cfg = config.unwrap_or(&default_config);
    let use_parallel =
        cfg.enable_parallel && is_parallel_enabled() && n_samples > cfg.parallel_chunk_size;
    if use_parallel {
        simd_assign_clusters_parallel(data, centroids, metric, cfg)
    } else {
        simd_assign_clusters_sequential(data, centroids, metric, cfg)
    }
}
/// Single-threaded nearest-centroid assignment.
///
/// For each row of `data` the centroids are scanned in order and the first one
/// with the strictly smallest `simd_distance` wins. A sample for which no
/// distance compares below infinity keeps label `0` and distance infinity.
///
/// # Errors
///
/// The first error returned by `simd_distance` aborts the scan and is
/// returned unchanged.
fn simd_assign_clusters_sequential<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    metric: SimdDistanceMetric,
    config: &SimdClusterConfig,
) -> Result<(Array1<usize>, Array1<F>)>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let sample_count = data.shape()[0];
    let mut labels: Array1<usize> = Array1::zeros(sample_count);
    let mut distances: Array1<F> = Array1::from_elem(sample_count, F::infinity());

    for (sample, point) in data.outer_iter().enumerate() {
        // Track the running winner locally and write it back once per sample.
        let mut nearest = 0usize;
        let mut nearest_dist = F::infinity();
        for (candidate, centroid) in centroids.outer_iter().enumerate() {
            let dist = simd_distance(point, centroid, metric, Some(config))?;
            if dist < nearest_dist {
                nearest_dist = dist;
                nearest = candidate;
            }
        }
        labels[sample] = nearest;
        distances[sample] = nearest_dist;
    }

    Ok((labels, distances))
}
/// Parallel nearest-centroid assignment; one task per sample.
///
/// Each task scans the centroids in order and keeps the first one with the
/// strictly smallest `simd_distance`. A sample for which no distance compares
/// below infinity keeps label `0` and distance infinity.
///
/// # Errors
///
/// A failure of `simd_distance` for any sample is returned to the caller, the
/// same way the sequential implementation reports it, instead of being
/// replaced by an infinite distance.
fn simd_assign_clusters_parallel<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    metric: SimdDistanceMetric,
    config: &SimdClusterConfig,
) -> Result<(Array1<usize>, Array1<F>)>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let n_samples = data.shape()[0];
    let n_centroids = centroids.shape()[0];

    // Collecting into `Result` lets a failing distance computation surface as
    // an error rather than silently producing a bogus assignment.
    let results = (0..n_samples)
        .into_par_iter()
        .map(|i| -> Result<(usize, F)> {
            let point = data.row(i);
            let mut best_label = 0;
            let mut best_dist = F::infinity();
            for j in 0..n_centroids {
                let dist = simd_distance(point, centroids.row(j), metric, Some(config))?;
                if dist < best_dist {
                    best_dist = dist;
                    best_label = j;
                }
            }
            Ok((best_label, best_dist))
        })
        .collect::<Result<Vec<(usize, F)>>>()?;

    let labels = Array1::from_iter(results.iter().map(|&(label, _)| label));
    let distances = Array1::from_iter(results.iter().map(|&(_, dist)| dist));
    Ok((labels, distances))
}
/// Recomputes `k` centroids as the per-cluster mean of the rows of `data`.
///
/// Samples whose label is `>= k` are ignored by both implementations, and a
/// cluster without any member ends up as an all-zero row. The parallel
/// implementation is chosen when the configuration enables parallelism,
/// `is_parallel_enabled()` reports true and `k > 4`.
///
/// # Arguments
///
/// * `data` - Samples, one per row.
/// * `labels` - Cluster index of every sample; must have one entry per row.
/// * `k` - Number of centroids to produce.
/// * `config` - Optional configuration; `SimdClusterConfig::default()` is used
///   when `None`.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `labels` and `data` disagree
/// on the number of samples.
pub fn simd_centroid_update<F>(
    data: ArrayView2<F>,
    labels: &Array1<usize>,
    k: usize,
    config: Option<&SimdClusterConfig>,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let (n_samples, n_features) = data.dim();
    if labels.len() != n_samples {
        return Err(ClusteringError::InvalidInput(format!(
            "Labels length ({}) must match number of samples ({})",
            labels.len(),
            n_samples
        )));
    }

    let fallback = SimdClusterConfig::default();
    let settings = match config {
        Some(custom) => custom,
        None => &fallback,
    };
    let use_simd = should_use_simd(n_features, settings);
    let run_parallel = settings.enable_parallel && is_parallel_enabled() && k > 4;

    if !run_parallel {
        return simd_centroid_update_sequential(data, labels, k, n_features, use_simd);
    }
    simd_centroid_update_parallel(data, labels, k, n_features, use_simd)
}
/// Single-threaded centroid update: accumulates per-cluster sums and counts,
/// then divides each non-empty cluster's sum by its member count.
///
/// # Arguments
///
/// * `data` - Samples, one per row.
/// * `labels` - Cluster index of every sample (indexed by row number).
/// * `k` - Number of centroids to produce.
/// * `n_features` - Number of columns of the returned matrix.
/// * `use_simd` - Selects the `SimdUnifiedOps` kernels over scalar loops.
///
/// # Returns
///
/// A `k x n_features` matrix. Samples labelled `>= k` are skipped and clusters
/// without members stay all-zero.
fn simd_centroid_update_sequential<F>(
    data: ArrayView2<F>,
    labels: &Array1<usize>,
    k: usize,
    n_features: usize,
    use_simd: bool,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let n_samples = data.shape()[0];
    let mut centroids: Array2<F> = Array2::zeros((k, n_features));
    let mut counts = vec![0usize; k];

    // Pass 1: per-cluster running sums and member counts.
    for i in 0..n_samples {
        let cluster = labels[i];
        if cluster >= k {
            continue;
        }
        counts[cluster] += 1;
        if use_simd {
            // The kernel returns an owned array, so the immutable row borrow
            // ends before the row is overwritten; no intermediate copy needed.
            let sum = F::simd_add(&centroids.row(cluster), &data.row(i));
            centroids.row_mut(cluster).assign(&sum);
        } else {
            for j in 0..n_features {
                centroids[[cluster, j]] = centroids[[cluster, j]] + data[[i, j]];
            }
        }
    }

    // Pass 2: turn sums into means; empty clusters are left as zeros.
    for cluster in 0..k {
        let count = counts[cluster];
        if count == 0 {
            continue;
        }
        let count_f = F::from_usize(count).unwrap_or_else(|| F::one());
        if use_simd {
            let count_arr = Array1::from_elem(n_features, count_f);
            let mean = F::simd_div(&centroids.row(cluster), &count_arr.view());
            centroids.row_mut(cluster).assign(&mean);
        } else {
            for j in 0..n_features {
                centroids[[cluster, j]] = centroids[[cluster, j]] / count_f;
            }
        }
    }

    Ok(centroids)
}
/// Parallel centroid update: one task per cluster.
///
/// Sample indices are first bucketed by label (labels `>= k` are dropped);
/// each task then averages the rows belonging to its cluster. A cluster with
/// no members yields an all-zero row.
///
/// # Arguments
///
/// * `data` - Samples, one per row.
/// * `labels` - Cluster index of every sample.
/// * `k` - Number of centroids to produce.
/// * `n_features` - Number of columns of the returned matrix.
/// * `use_simd` - Selects the `SimdUnifiedOps` kernels over scalar loops.
fn simd_centroid_update_parallel<F>(
    data: ArrayView2<F>,
    labels: &Array1<usize>,
    k: usize,
    n_features: usize,
    use_simd: bool,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    // Group sample indices by the cluster they belong to.
    let mut members: Vec<Vec<usize>> = vec![Vec::new(); k];
    for (sample, &label) in labels.iter().enumerate() {
        if label >= k {
            continue;
        }
        members[label].push(sample);
    }

    let means: Vec<Array1<F>> = (0..k)
        .into_par_iter()
        .map(|cluster| {
            let samples = &members[cluster];
            if samples.is_empty() {
                return Array1::zeros(n_features);
            }

            let mut total: Array1<F> = Array1::zeros(n_features);
            if use_simd {
                for &sample in samples {
                    let next = F::simd_add(&total.view(), &data.row(sample));
                    total = next;
                }
            } else {
                for &sample in samples {
                    for j in 0..n_features {
                        total[j] = total[j] + data[[sample, j]];
                    }
                }
            }

            let divisor = F::from_usize(samples.len()).unwrap_or_else(|| F::one());
            if use_simd {
                let divisors = Array1::from_elem(n_features, divisor);
                F::simd_div(&total.view(), &divisors.view())
            } else {
                total.mapv(|v| v / divisor)
            }
        })
        .collect();

    // Stack the per-cluster means into the output matrix.
    let mut centroids: Array2<F> = Array2::zeros((k, n_features));
    for (cluster, mean) in means.into_iter().enumerate() {
        centroids.row_mut(cluster).assign(&mean);
    }
    Ok(centroids)
}
/// Computes `ln(sum(exp(values)))`, shifting by the maximum element first.
///
/// The SIMD kernels are used when `should_use_simd(values.len(), config)`
/// says so; otherwise plain scalar loops are used.
///
/// # Arguments
///
/// * `values` - Input values; must not be empty.
/// * `config` - Optional configuration; `SimdClusterConfig::default()` is used
///   when `None`.
///
/// # Returns
///
/// * Negative infinity when the maximum is negative infinity.
/// * Positive infinity when the maximum is positive infinity and no element is
///   NaN (NaN if one is), instead of the NaN that `inf - inf` would produce.
/// * `max + ln(sum(exp(values - max)))` otherwise.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` for an empty input.
pub fn simd_logsumexp<F>(values: ArrayView1<F>, config: Option<&SimdClusterConfig>) -> Result<F>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    if values.is_empty() {
        return Err(ClusteringError::InvalidInput(
            "Cannot compute logsumexp of empty array".to_string(),
        ));
    }
    let default_config = SimdClusterConfig::default();
    let cfg = config.unwrap_or(&default_config);
    let use_simd = should_use_simd(values.len(), cfg);

    let max_val = if use_simd {
        F::simd_max_element(&values)
    } else {
        let mut running_max = F::neg_infinity();
        for &v in values.iter() {
            if v > running_max {
                running_max = v;
            }
        }
        running_max
    };

    // An infinite maximum cannot be used as a shift: `inf - inf` is NaN.
    if max_val.is_infinite() {
        if max_val.is_sign_negative() {
            return Ok(F::neg_infinity());
        }
        // A +inf term dominates the sum unless a NaN poisons it.
        for &v in values.iter() {
            if v.is_nan() {
                return Ok(F::nan());
            }
        }
        return Ok(F::infinity());
    }

    let sum_exp = if use_simd {
        let max_arr = Array1::from_elem(values.len(), max_val);
        let shifted = F::simd_sub(&values, &max_arr.view());
        let exp_vals = F::simd_exp(&shifted.view());
        F::simd_sum(&exp_vals.view())
    } else {
        let mut sum = F::zero();
        for &v in values.iter() {
            sum = sum + (v - max_val).exp();
        }
        sum
    };
    Ok(max_val + sum_exp.ln())
}
/// Applies `simd_logsumexp` to every row of `values`.
///
/// # Returns
///
/// A vector with one entry per row, in row order.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by `simd_logsumexp`.
pub fn simd_logsumexp_rows<F>(
    values: ArrayView2<F>,
    config: Option<&SimdClusterConfig>,
) -> Result<Array1<F>>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let per_row = (0..values.nrows())
        .map(|row| simd_logsumexp(values.row(row), config))
        .collect::<Result<Vec<F>>>()?;
    Ok(Array1::from_vec(per_row))
}
/// Turns per-component log-probabilities into normalised responsibilities.
///
/// For every sample `i` and component `k` the result is
/// `exp(log_prob[i, k] + log_weights[k] - log_norm[i])`, where `log_norm[i]`
/// is the log-sum-exp over `k` of `log_prob[i, k] + log_weights[k]`.
///
/// # Arguments
///
/// * `log_prob` - Matrix with one row per sample and one column per component.
/// * `log_weights` - One entry per component.
/// * `config` - Optional configuration; `SimdClusterConfig::default()` is used
///   when `None`.
///
/// # Returns
///
/// `(responsibilities, lower_bound)` where `lower_bound` is the sum of
/// `log_norm` divided by the number of samples.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `log_weights` does not have
/// one entry per column of `log_prob`, and forwards any error from
/// `simd_logsumexp_rows`.
pub fn simd_gmm_log_responsibilities<F>(
    log_prob: ArrayView2<F>,
    log_weights: ArrayView1<F>,
    config: Option<&SimdClusterConfig>,
) -> Result<(Array2<F>, F)>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let (n_samples, n_components) = log_prob.dim();
    if log_weights.len() != n_components {
        return Err(ClusteringError::InvalidInput(format!(
            "log_weights length ({}) must match number of components ({})",
            log_weights.len(),
            n_components
        )));
    }

    let fallback = SimdClusterConfig::default();
    let settings = match config {
        Some(custom) => custom,
        None => &fallback,
    };
    let use_simd = should_use_simd(n_components, settings);

    // Step 1: add the log-weights to every row.
    let mut weighted = log_prob.to_owned();
    if use_simd {
        for sample in 0..n_samples {
            let current = weighted.row(sample).to_owned();
            let shifted = F::simd_add(&current.view(), &log_weights);
            weighted.row_mut(sample).assign(&shifted);
        }
    } else {
        for sample in 0..n_samples {
            for component in 0..n_components {
                weighted[[sample, component]] =
                    weighted[[sample, component]] + log_weights[component];
            }
        }
    }

    // Step 2: per-sample normaliser.
    let log_norm = simd_logsumexp_rows(weighted.view(), Some(settings))?;

    // Step 3: exponentiate the normalised log-values.
    let mut resp: Array2<F> = Array2::zeros((n_samples, n_components));
    if use_simd {
        for sample in 0..n_samples {
            let current = weighted.row(sample).to_owned();
            let norms = Array1::from_elem(n_components, log_norm[sample]);
            let centred = F::simd_sub(&current.view(), &norms.view());
            let exponentiated = F::simd_exp(&centred.view());
            resp.row_mut(sample).assign(&exponentiated);
        }
    } else {
        for sample in 0..n_samples {
            for component in 0..n_components {
                resp[[sample, component]] =
                    (weighted[[sample, component]] - log_norm[sample]).exp();
            }
        }
    }

    // Step 4: average normaliser over the samples.
    let norm_total = if use_simd {
        F::simd_sum(&log_norm.view())
    } else {
        let mut running = F::zero();
        for &value in log_norm.iter() {
            running = running + value;
        }
        running
    };
    let sample_count = F::from_usize(n_samples).unwrap_or_else(|| F::one());

    Ok((resp, norm_total / sample_count))
}
/// Computes the responsibility-weighted mean of the rows of `data`.
///
/// The result is `sum_i(responsibilities[i] * data[i, ..])` divided by
/// `sum_i(responsibilities[i])`.
///
/// # Arguments
///
/// * `data` - Samples, one per row.
/// * `responsibilities` - One weight per sample.
/// * `config` - Optional configuration; `SimdClusterConfig::default()` is used
///   when `None`.
///
/// # Errors
///
/// * `ClusteringError::InvalidInput` when `responsibilities` does not have one
///   entry per row of `data`.
/// * `ClusteringError::ComputationError` when the weights sum to zero or less.
pub fn simd_gmm_weighted_mean<F>(
    data: ArrayView2<F>,
    responsibilities: ArrayView1<F>,
    config: Option<&SimdClusterConfig>,
) -> Result<Array1<F>>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let (n_samples, n_features) = data.dim();
    if responsibilities.len() != n_samples {
        return Err(ClusteringError::InvalidInput(format!(
            "Responsibilities length ({}) must match number of samples ({})",
            responsibilities.len(),
            n_samples
        )));
    }

    let fallback = SimdClusterConfig::default();
    let settings = match config {
        Some(custom) => custom,
        None => &fallback,
    };
    let use_simd = should_use_simd(n_features, settings);

    let weight_total = if use_simd {
        F::simd_sum(&responsibilities)
    } else {
        let mut running = F::zero();
        for &weight in responsibilities.iter() {
            running = running + weight;
        }
        running
    };
    if weight_total <= F::zero() {
        return Err(ClusteringError::ComputationError(
            "Total responsibility is zero or negative; cannot compute weighted mean".to_string(),
        ));
    }

    let mut accumulated: Array1<F> = Array1::zeros(n_features);
    if use_simd {
        for sample in 0..n_samples {
            // Broadcast the scalar weight so the element-wise kernel can be used.
            let weights = Array1::from_elem(n_features, responsibilities[sample]);
            let scaled = F::simd_mul(&data.row(sample), &weights.view());
            let next = F::simd_add(&accumulated.view(), &scaled.view());
            accumulated = next;
        }
        let divisors = Array1::from_elem(n_features, weight_total);
        Ok(F::simd_div(&accumulated.view(), &divisors.view()))
    } else {
        for sample in 0..n_samples {
            let weight = responsibilities[sample];
            for feature in 0..n_features {
                accumulated[feature] = accumulated[feature] + data[[sample, feature]] * weight;
            }
        }
        Ok(accumulated.mapv(|value| value / weight_total))
    }
}
/// Lists the samples lying within distance `eps` of the sample at `query_idx`.
///
/// The query sample itself is never part of the result; the returned indices
/// are in ascending order.
///
/// # Arguments
///
/// * `data` - Samples, one per row.
/// * `query_idx` - Row index of the query sample.
/// * `eps` - Inclusive distance threshold.
/// * `metric` - Distance metric forwarded to `simd_distance`.
/// * `config` - Optional configuration; `SimdClusterConfig::default()` is used
///   when `None`.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `query_idx` is not a valid row
/// index, and forwards the first error from `simd_distance`.
pub fn simd_epsilon_neighborhood<F>(
    data: ArrayView2<F>,
    query_idx: usize,
    eps: F,
    metric: SimdDistanceMetric,
    config: Option<&SimdClusterConfig>,
) -> Result<Vec<usize>>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let sample_count = data.shape()[0];
    if query_idx >= sample_count {
        return Err(ClusteringError::InvalidInput(format!(
            "Query index {} is out of bounds (data has {} samples)",
            query_idx, sample_count
        )));
    }

    let fallback = SimdClusterConfig::default();
    let settings = match config {
        Some(custom) => custom,
        None => &fallback,
    };

    let query = data.row(query_idx);
    let mut neighbors = Vec::new();
    for (candidate_idx, candidate) in data.outer_iter().enumerate() {
        // Short-circuit keeps the query from being compared with itself.
        if candidate_idx != query_idx
            && simd_distance(query, candidate, metric, Some(settings))? <= eps
        {
            neighbors.push(candidate_idx);
        }
    }
    Ok(neighbors)
}
/// Computes the `eps`-neighbourhood of every sample in `data`.
///
/// The parallel implementation is used when the configuration enables
/// parallelism, `is_parallel_enabled()` reports true and the number of samples
/// exceeds `parallel_chunk_size`; otherwise the sequential one is used.
///
/// # Returns
///
/// One index list per sample, in sample order.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `data` has no rows; errors of
/// the selected implementation are returned unchanged.
pub fn simd_batch_epsilon_neighborhood<F>(
    data: ArrayView2<F>,
    eps: F,
    metric: SimdDistanceMetric,
    config: Option<&SimdClusterConfig>,
) -> Result<Vec<Vec<usize>>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let sample_count = data.nrows();
    if sample_count == 0 {
        return Err(ClusteringError::InvalidInput(
            "Data must have at least one sample".to_string(),
        ));
    }

    let fallback = SimdClusterConfig::default();
    let settings = match config {
        Some(custom) => custom,
        None => &fallback,
    };

    let run_parallel = settings.enable_parallel
        && is_parallel_enabled()
        && sample_count > settings.parallel_chunk_size;
    if !run_parallel {
        return simd_batch_neighborhood_sequential(data, eps, metric, settings);
    }
    simd_batch_neighborhood_parallel(data, eps, metric, settings)
}
/// Single-threaded all-pairs `eps`-neighbourhood search.
///
/// Every unordered pair `(i, j)` with `i < j` is evaluated once and, when its
/// distance is at most `eps`, recorded in both samples' lists.
///
/// # Errors
///
/// The first error returned by `simd_distance` is forwarded to the caller.
fn simd_batch_neighborhood_sequential<F>(
    data: ArrayView2<F>,
    eps: F,
    metric: SimdDistanceMetric,
    config: &SimdClusterConfig,
) -> Result<Vec<Vec<usize>>>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let sample_count = data.shape()[0];
    let mut neighborhoods: Vec<Vec<usize>> = vec![Vec::new(); sample_count];

    for first in 0..sample_count {
        let anchor = data.row(first);
        for second in (first + 1)..sample_count {
            let within_eps = simd_distance(anchor, data.row(second), metric, Some(config))? <= eps;
            if within_eps {
                // Record the pair from both ends.
                neighborhoods[first].push(second);
                neighborhoods[second].push(first);
            }
        }
    }

    Ok(neighborhoods)
}
/// Parallel all-pairs `eps`-neighbourhood search; one task per query sample.
///
/// Each task compares its sample against every other sample and keeps the
/// indices whose distance is at most `eps`, in ascending order.
///
/// # Errors
///
/// A failure of `simd_distance` is returned to the caller, matching the
/// sequential implementation, instead of being treated as an infinite
/// distance.
fn simd_batch_neighborhood_parallel<F>(
    data: ArrayView2<F>,
    eps: F,
    metric: SimdDistanceMetric,
    config: &SimdClusterConfig,
) -> Result<Vec<Vec<usize>>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let n_samples = data.shape()[0];
    // Collecting into `Result` propagates a failing distance computation.
    (0..n_samples)
        .into_par_iter()
        .map(|i| -> Result<Vec<usize>> {
            let query = data.row(i);
            let mut neighbors = Vec::new();
            for j in 0..n_samples {
                if j == i {
                    continue;
                }
                let dist = simd_distance(query, data.row(j), metric, Some(config))?;
                if dist <= eps {
                    neighbors.push(j);
                }
            }
            Ok(neighbors)
        })
        .collect()
}
/// Sums, over all samples, the value `simd_squared_euclidean_distance` reports
/// between a sample and the centroid its label points to.
///
/// Samples whose label is not a valid row index of `centroids` contribute
/// nothing to the total.
///
/// # Arguments
///
/// * `data` - Samples, one per row.
/// * `centroids` - Centroids, one per row.
/// * `labels` - Centroid index of every sample; must have one entry per row.
/// * `config` - Optional configuration; `SimdClusterConfig::default()` is used
///   when `None`.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `labels` and `data` disagree
/// on the number of samples, and forwards the first error from
/// `simd_squared_euclidean_distance`.
pub fn simd_compute_distortion<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    labels: &Array1<usize>,
    config: Option<&SimdClusterConfig>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let sample_count = data.shape()[0];
    if labels.len() != sample_count {
        return Err(ClusteringError::InvalidInput(format!(
            "Labels length ({}) must match number of samples ({})",
            labels.len(),
            sample_count
        )));
    }

    let fallback = SimdClusterConfig::default();
    let settings = match config {
        Some(custom) => custom,
        None => &fallback,
    };

    let centroid_count = centroids.shape()[0];
    let mut distortion = F::zero();
    for (sample, &cluster) in labels.iter().enumerate() {
        if cluster >= centroid_count {
            continue;
        }
        let contribution = simd_squared_euclidean_distance(
            data.row(sample),
            centroids.row(cluster),
            Some(settings),
        )?;
        distortion = distortion + contribution;
    }
    Ok(distortion)
}
/// Computes the pairwise distances used as input for linkage.
///
/// Thin wrapper: all arguments are forwarded unchanged to
/// `simd_pairwise_condensed_distances` and its result is returned as is.
///
/// # Arguments
///
/// * `data` - Samples, one per row.
/// * `metric` - Distance metric forwarded to the helper.
/// * `config` - Optional configuration forwarded to the helper.
///
/// # Returns
///
/// Whatever `simd_pairwise_condensed_distances` returns.
/// NOTE(review): presumably the condensed (upper-triangular) distance vector,
/// judging by the helper's name; confirm against its definition.
///
/// # Errors
///
/// Any error produced by `simd_pairwise_condensed_distances`.
pub fn simd_linkage_distances<F>(
data: ArrayView2<F>,
metric: SimdDistanceMetric,
config: Option<&SimdClusterConfig>,
) -> Result<Array1<F>>
where
F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
simd_pairwise_condensed_distances(data, metric, config)
}