use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive, Zero};
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities, SimdUnifiedOps};
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
use statrs::statistics::Statistics;
/// Tuning knobs controlling when the SIMD and parallel code paths in this
/// module are selected.
#[derive(Debug, Clone)]
pub struct SimdOptimizationConfig {
    /// Minimum vector length at which the SIMD path is taken even when the
    /// auto-optimizer declines.
    pub simd_threshold: usize,
    /// Whether multi-threaded execution may be used at all.
    pub enable_parallel: bool,
    /// Problem-size threshold above which work is split across threads.
    pub parallel_chunk_size: usize,
    /// Hint to prefer cache-friendly access patterns (advisory; not read by
    /// the routines visible in this file).
    pub cache_friendly: bool,
    /// Force the SIMD path regardless of the auto-optimizer's recommendation.
    pub force_simd: bool,
}
impl Default for SimdOptimizationConfig {
    /// Conservative defaults: SIMD from 64 elements, parallelism enabled with
    /// a 1024-element chunk threshold, cache-friendly hint on, SIMD not forced.
    fn default() -> Self {
        Self {
            simd_threshold: 64,
            enable_parallel: true,
            parallel_chunk_size: 1024,
            cache_friendly: true,
            force_simd: false,
        }
    }
}
/// Euclidean distance between two equal-length vectors, using SIMD kernels
/// when the platform and configuration allow it.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when the vectors differ in length.
#[allow(dead_code)]
pub fn euclidean_distance_simd<F>(
    x: ArrayView1<F>,
    y: ArrayView1<F>,
    config: Option<&SimdOptimizationConfig>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    if x.len() != y.len() {
        return Err(ClusteringError::InvalidInput(format!(
            "Vectors must have the same length: got {} and {}",
            x.len(),
            y.len()
        )));
    }
    let default_config = SimdOptimizationConfig::default();
    let cfg = config.unwrap_or(&default_config);
    let caps = PlatformCapabilities::detect();
    let optimizer = AutoOptimizer::new();
    // SIMD path: platform supports it and (optimizer recommends it or it is
    // forced), or the vectors are at least `simd_threshold` long.
    let simd_recommended =
        caps.simd_available && (optimizer.should_use_simd(x.len()) || cfg.force_simd);
    if simd_recommended || x.len() >= cfg.simd_threshold {
        let diff = F::simd_sub(&x, &y);
        Ok(F::simd_norm(&diff.view()))
    } else {
        // Scalar fallback: accumulate squared differences, then take the root.
        let sum_sq = x.iter().zip(y.iter()).fold(F::zero(), |acc, (&a, &b)| {
            let d = a - b;
            acc + d * d
        });
        Ok(sum_sq.sqrt())
    }
}
/// Whitens `obs` so each feature column has zero mean and unit (sample)
/// standard deviation, dispatching to parallel SIMD, sequential SIMD, or a
/// scalar fallback depending on platform capabilities and `config`.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when `obs` has no rows or columns.
#[allow(dead_code)]
pub fn whiten_simd<F>(obs: &Array2<F>, config: Option<&SimdOptimizationConfig>) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let default_config = SimdOptimizationConfig::default();
    let cfg = config.unwrap_or(&default_config);
    let n_samples = obs.shape()[0];
    let n_features = obs.shape()[1];
    if n_samples == 0 || n_features == 0 {
        return Err(ClusteringError::InvalidInput(
            "Input data cannot be empty".to_string(),
        ));
    }
    let caps = PlatformCapabilities::detect();
    let optimizer = AutoOptimizer::new();
    let use_simd = caps.simd_available
        && (optimizer.should_use_simd(n_samples * n_features) || cfg.force_simd);
    // Dispatch: parallel SIMD when the feature count is large enough,
    // sequential SIMD otherwise, scalar when SIMD is unavailable.
    let wide_enough = cfg.enable_parallel && n_features > cfg.parallel_chunk_size;
    match (use_simd, wide_enough) {
        (true, true) => whiten_simd_parallel(obs, cfg),
        (true, false) => whiten_simd_sequential(obs),
        (false, _) => whiten_scalar_fallback(obs),
    }
}
/// Sequential SIMD whitening: standardizes each column of `obs` to zero mean
/// and unit sample standard deviation using SIMD kernels.
#[allow(dead_code)]
fn whiten_simd_sequential<F>(obs: &Array2<F>) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let n_samples = obs.shape()[0];
    let n_features = obs.shape()[1];
    let n_samples_f = F::from(n_samples).expect("Failed to convert to float");
    // Per-column means via SIMD sum.
    let mut means = Array1::<F>::zeros(n_features);
    for j in 0..n_features {
        let column = obs.column(j);
        means[j] = F::simd_sum(&column) / n_samples_f;
    }
    // Per-column sample standard deviations; near-zero values are clamped to
    // one so constant columns are centered but not divided by ~0.
    let mut stds = Array1::<F>::zeros(n_features);
    for j in 0..n_features {
        let column = obs.column(j);
        let mean_array = Array1::from_elem(n_samples, means[j]);
        let diff = F::simd_sub(&column, &mean_array.view());
        let squared_diff = F::simd_mul(&diff.view(), &diff.view());
        let variance = F::simd_sum(&squared_diff.view())
            / F::from(n_samples - 1).expect("Failed to convert to float");
        stds[j] = variance.sqrt();
        if stds[j] < F::from(1e-10).expect("Failed to convert constant to float") {
            stds[j] = F::one();
        }
    }
    // Center and scale each column, writing results column-by-column.
    let mut whitened = Array2::<F>::zeros((n_samples, n_features));
    for j in 0..n_features {
        let column = obs.column(j);
        let mean_array = Array1::from_elem(n_samples, means[j]);
        let std_array = Array1::from_elem(n_samples, stds[j]);
        let centered = F::simd_sub(&column, &mean_array.view());
        // BUGFIX: `&centered` had been mangled to `¢ered` (HTML-entity
        // corruption of `&cent`), which does not compile.
        let normalized = F::simd_div(&centered.view(), &std_array.view());
        for i in 0..n_samples {
            whitened[[i, j]] = normalized[i];
        }
    }
    Ok(whitened)
}
/// Parallel SIMD whitening: per-column means, standard deviations, and
/// standardized values are each computed across threads when the parallel
/// runtime is enabled; otherwise the work is delegated to the sequential path.
#[allow(dead_code)]
fn whiten_simd_parallel<F>(obs: &Array2<F>, _config: &SimdOptimizationConfig) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let n_samples = obs.shape()[0];
    let n_features = obs.shape()[1];
    let n_samples_f = F::from(n_samples).expect("Failed to convert to float");
    // Per-column means, parallel over columns when possible.
    let means: Array1<F> = if is_parallel_enabled() {
        (0..n_features)
            .into_par_iter()
            .map(|j| {
                let column = obs.column(j);
                F::simd_sum(&column) / n_samples_f
            })
            .collect::<Vec<_>>()
            .into()
    } else {
        let mut means = Array1::<F>::zeros(n_features);
        for j in 0..n_features {
            let column = obs.column(j);
            means[j] = F::simd_sum(&column) / n_samples_f;
        }
        means
    };
    // Per-column sample standard deviations (clamped to 1 for ~constant columns).
    let stds: Array1<F> = if is_parallel_enabled() {
        (0..n_features)
            .into_par_iter()
            .map(|j| {
                let column = obs.column(j);
                let mean_array = Array1::from_elem(n_samples, means[j]);
                let diff = F::simd_sub(&column, &mean_array.view());
                let squared_diff = F::simd_mul(&diff.view(), &diff.view());
                let variance = F::simd_sum(&squared_diff.view())
                    / F::from(n_samples - 1).expect("Failed to convert to float");
                let std = variance.sqrt();
                if std < F::from(1e-10).expect("Failed to convert constant to float") {
                    F::one()
                } else {
                    std
                }
            })
            .collect::<Vec<_>>()
            .into()
    } else {
        // BUGFIX: the original computed the full sequential result, reshaped
        // it, discarded it, and then ran the sequential path a second time.
        // A single delegation is sufficient (and half the cost).
        return whiten_simd_sequential(obs);
    };
    // Standardize every column and gather into the output matrix.
    let mut whitened = Array2::<F>::zeros((n_samples, n_features));
    if is_parallel_enabled() {
        let normalized_columns: Vec<Array1<F>> = (0..n_features)
            .into_par_iter()
            .map(|j| {
                let column = obs.column(j);
                let mean_array = Array1::from_elem(n_samples, means[j]);
                let std_array = Array1::from_elem(n_samples, stds[j]);
                let centered = F::simd_sub(&column, &mean_array.view());
                // BUGFIX: `&centered` had been mangled to `¢ered`.
                F::simd_div(&centered.view(), &std_array.view())
            })
            .collect();
        for (j, normalized_column) in normalized_columns.iter().enumerate() {
            for i in 0..n_samples {
                whitened[[i, j]] = normalized_column[i];
            }
        }
    } else {
        for j in 0..n_features {
            let column = obs.column(j);
            let mean_array = Array1::from_elem(n_samples, means[j]);
            let std_array = Array1::from_elem(n_samples, stds[j]);
            let centered = F::simd_sub(&column, &mean_array.view());
            // BUGFIX: `&centered` had been mangled to `¢ered`.
            let normalized = F::simd_div(&centered.view(), &std_array.view());
            for i in 0..n_samples {
                whitened[[i, j]] = normalized[i];
            }
        }
    }
    Ok(whitened)
}
/// Scalar (non-SIMD) whitening fallback: standardizes each column of `obs`
/// to zero mean and unit sample standard deviation with plain loops.
#[allow(dead_code)]
fn whiten_scalar_fallback<F>(obs: &Array2<F>) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug,
{
    let n_samples = obs.shape()[0];
    let n_features = obs.shape()[1];
    let n_f = F::from(n_samples).expect("Failed to convert to float");
    let denom = F::from(n_samples - 1).expect("Failed to convert to float");
    let eps = F::from(1e-10).expect("Failed to convert constant to float");
    // Per-column means.
    let mut means = Array1::<F>::zeros(n_features);
    for j in 0..n_features {
        let mut acc = F::zero();
        for i in 0..n_samples {
            acc = acc + obs[[i, j]];
        }
        means[j] = acc / n_f;
    }
    // Per-column sample standard deviations, clamped away from zero so that
    // constant columns are centered but not divided by ~0.
    let mut stds = Array1::<F>::zeros(n_features);
    for j in 0..n_features {
        let mut acc = F::zero();
        for i in 0..n_samples {
            let d = obs[[i, j]] - means[j];
            acc = acc + d * d;
        }
        let sd = (acc / denom).sqrt();
        stds[j] = if sd < eps { F::one() } else { sd };
    }
    // Standardize every entry.
    let mut whitened = Array2::<F>::zeros((n_samples, n_features));
    for i in 0..n_samples {
        for j in 0..n_features {
            whitened[[i, j]] = (obs[[i, j]] - means[j]) / stds[j];
        }
    }
    Ok(whitened)
}
/// Vector quantization: assigns each row of `data` to its nearest centroid
/// (Euclidean) and returns `(labels, distances)`. Dispatches to a parallel
/// implementation for large sample counts.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when `data` and `centroids`
/// disagree on the number of features.
#[allow(dead_code)]
pub fn vq_simd<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    config: Option<&SimdOptimizationConfig>,
) -> Result<(Array1<usize>, Array1<F>)>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    if data.shape()[1] != centroids.shape()[1] {
        return Err(ClusteringError::InvalidInput(format!(
            "Data and centroids must have the same number of features: {} vs {}",
            data.shape()[1],
            centroids.shape()[1]
        )));
    }
    let default_config = SimdOptimizationConfig::default();
    let config = config.unwrap_or(&default_config);
    // NOTE: the unused `n_centroids` binding was removed; only the sample
    // count matters for the parallel/sequential dispatch.
    let n_samples = data.shape()[0];
    if config.enable_parallel && is_parallel_enabled() && n_samples > config.parallel_chunk_size {
        vq_simd_parallel(data, centroids, config)
    } else {
        vq_simd_sequential(data, centroids, config)
    }
}
/// Sequential vector quantization: for each sample, scans all centroids and
/// records the index of and distance to the nearest one.
#[allow(dead_code)]
fn vq_simd_sequential<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    config: &SimdOptimizationConfig,
) -> Result<(Array1<usize>, Array1<F>)>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let n_samples = data.shape()[0];
    let n_centroids = centroids.shape()[0];
    let mut labels = Array1::zeros(n_samples);
    let mut distances = Array1::zeros(n_samples);
    let caps = PlatformCapabilities::detect();
    let use_simd = caps.simd_available || config.force_simd;
    for i in 0..n_samples {
        let point = data.slice(s![i, ..]);
        let mut min_dist = F::infinity();
        let mut closest_centroid = 0;
        for j in 0..n_centroids {
            let centroid = centroids.slice(s![j, ..]);
            let dist = if use_simd {
                // BUGFIX: `&centroid` had been mangled to `¢roid`
                // (HTML-entity corruption), which does not compile.
                let diff = F::simd_sub(&point, &centroid);
                F::simd_norm(&diff.view())
            } else {
                let mut sum = F::zero();
                for k in 0..point.len() {
                    let diff = point[k] - centroid[k];
                    sum = sum + diff * diff;
                }
                sum.sqrt()
            };
            if dist < min_dist {
                min_dist = dist;
                closest_centroid = j;
            }
        }
        labels[i] = closest_centroid;
        distances[i] = min_dist;
    }
    Ok((labels, distances))
}
/// Parallel vector quantization: samples are assigned to nearest centroids
/// concurrently; results are gathered in order afterwards.
#[allow(dead_code)]
fn vq_simd_parallel<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    config: &SimdOptimizationConfig,
) -> Result<(Array1<usize>, Array1<F>)>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let n_samples = data.shape()[0];
    let n_centroids = centroids.shape()[0];
    let caps = PlatformCapabilities::detect();
    let use_simd = caps.simd_available || config.force_simd;
    let results: Vec<(usize, F)> = (0..n_samples)
        .into_par_iter()
        .map(|i| {
            let point = data.slice(s![i, ..]);
            let mut min_dist = F::infinity();
            let mut closest_centroid = 0;
            for j in 0..n_centroids {
                let centroid = centroids.slice(s![j, ..]);
                let dist = if use_simd {
                    // BUGFIX: `&centroid` had been mangled to `¢roid`
                    // (HTML-entity corruption), which does not compile.
                    let diff = F::simd_sub(&point, &centroid);
                    F::simd_norm(&diff.view())
                } else {
                    let mut sum = F::zero();
                    for k in 0..point.len() {
                        let diff = point[k] - centroid[k];
                        sum = sum + diff * diff;
                    }
                    sum.sqrt()
                };
                if dist < min_dist {
                    min_dist = dist;
                    closest_centroid = j;
                }
            }
            (closest_centroid, min_dist)
        })
        .collect();
    let mut labels = Array1::zeros(n_samples);
    let mut distances = Array1::zeros(n_samples);
    for (i, (label, distance)) in results.into_iter().enumerate() {
        labels[i] = label;
        distances[i] = distance;
    }
    Ok((labels, distances))
}
/// Computes the `k` cluster centroids as the per-cluster means of `data`,
/// dispatching to a parallel implementation when enabled and the cluster
/// count is large enough to amortize the overhead.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when `labels` does not have one
/// entry per data row.
#[allow(dead_code)]
pub fn compute_centroids_simd<F>(
    data: ArrayView2<F>,
    labels: &Array1<usize>,
    k: usize,
    config: Option<&SimdOptimizationConfig>,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
{
    let default_config = SimdOptimizationConfig::default();
    let config = config.unwrap_or(&default_config);
    // NOTE: the unused `n_features` binding was removed; the feature count is
    // re-derived inside the worker functions.
    let n_samples = data.shape()[0];
    if labels.len() != n_samples {
        return Err(ClusteringError::InvalidInput(
            "Labels array length must match number of data points".to_string(),
        ));
    }
    let caps = PlatformCapabilities::detect();
    let use_simd = caps.simd_available || config.force_simd;
    if config.enable_parallel && is_parallel_enabled() && k > 4 {
        compute_centroids_simd_parallel(data, labels, k, use_simd)
    } else {
        compute_centroids_simd_sequential(data, labels, k, use_simd)
    }
}
/// Sequential centroid computation: accumulates per-cluster sums and counts
/// in one pass, then divides; empty clusters are reseeded deterministically
/// from an existing data point.
#[allow(dead_code)]
fn compute_centroids_simd_sequential<F>(
    data: ArrayView2<F>,
    labels: &Array1<usize>,
    k: usize,
    use_simd: bool,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps + std::iter::Sum,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    let mut centroids = Array2::zeros((k, n_features));
    let mut counts = Array1::<usize>::zeros(k);
    for i in 0..n_samples {
        let cluster = labels[i];
        if cluster >= k {
            return Err(ClusteringError::InvalidInput(format!(
                "Label {} exceeds number of clusters {}",
                cluster, k
            )));
        }
        counts[cluster] += 1;
        if use_simd {
            let point = data.slice(s![i, ..]);
            // Read-only view of the running sum. BUGFIX: `&centroid_row` had
            // been mangled to `¢roid_row` (HTML-entity corruption); the
            // needless `slice_mut(...).view()` is now an immutable slice.
            let centroid_row = centroids.slice(s![cluster, ..]);
            let updated_centroid = F::simd_add(&centroid_row, &point);
            for j in 0..n_features {
                centroids[[cluster, j]] = updated_centroid[j];
            }
        } else {
            for j in 0..n_features {
                centroids[[cluster, j]] = centroids[[cluster, j]] + data[[i, j]];
            }
        }
    }
    for i in 0..k {
        if counts[i] == 0 {
            // Empty cluster: reseed from an existing point (deterministic,
            // despite the variable name).
            if n_samples > 0 {
                let random_idx = i % n_samples;
                for j in 0..n_features {
                    centroids[[i, j]] = data[[random_idx, j]];
                }
            }
        } else {
            let count_f = F::from(counts[i]).expect("Failed to convert to float");
            if use_simd {
                let centroid_row = centroids.slice(s![i, ..]);
                let count_array = Array1::from_elem(n_features, count_f);
                // BUGFIX: `&centroid_row` had been mangled to `¢roid_row`.
                let normalized = F::simd_div(&centroid_row, &count_array.view());
                for j in 0..n_features {
                    centroids[[i, j]] = normalized[j];
                }
            } else {
                for j in 0..n_features {
                    centroids[[i, j]] = centroids[[i, j]] / count_f;
                }
            }
        }
    }
    Ok(centroids)
}
/// Parallel centroid computation: each cluster's mean is accumulated by an
/// independent task, then the per-cluster rows are gathered into a matrix.
/// Empty clusters are reseeded deterministically from an existing data point.
#[allow(dead_code)]
fn compute_centroids_simd_parallel<F>(
    data: ArrayView2<F>,
    labels: &Array1<usize>,
    k: usize,
    use_simd: bool,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
{
    let n_features = data.shape()[1];
    let n_samples = data.shape()[0];
    let rows: Vec<Array1<F>> = (0..k)
        .into_par_iter()
        .map(|cluster_id| {
            // Sum all points assigned to this cluster.
            let mut acc = Array1::zeros(n_features);
            let mut members = 0;
            for idx in 0..n_samples {
                if labels[idx] != cluster_id {
                    continue;
                }
                members += 1;
                let point = data.slice(s![idx, ..]);
                if use_simd {
                    let summed = F::simd_add(&acc.view(), &point);
                    for f in 0..n_features {
                        acc[f] = summed[f];
                    }
                } else {
                    for f in 0..n_features {
                        acc[f] = acc[f] + point[f];
                    }
                }
            }
            if members == 0 {
                // Empty cluster: reuse an existing point as its centroid.
                if n_samples > 0 {
                    data.slice(s![cluster_id % n_samples, ..]).to_owned()
                } else {
                    acc
                }
            } else {
                let members_f = F::from(members).expect("Failed to convert to float");
                if use_simd {
                    let divisor = Array1::from_elem(n_features, members_f);
                    F::simd_div(&acc.view(), &divisor.view())
                } else {
                    acc.mapv(|v| v / members_f)
                }
            }
        })
        .collect();
    // Gather the per-cluster rows into the output matrix.
    let mut result = Array2::zeros((k, n_features));
    for (row_idx, row) in rows.into_iter().enumerate() {
        for f in 0..n_features {
            result[[row_idx, f]] = row[f];
        }
    }
    Ok(result)
}
/// Total distortion: the sum of squared Euclidean distances from each sample
/// to its assigned centroid. Dispatches to a parallel implementation for
/// large sample counts.
///
/// # Errors
/// Returns `ClusteringError::InvalidInput` when `labels` does not have one
/// entry per data row.
#[allow(dead_code)]
pub fn calculate_distortion_simd<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    labels: &Array1<usize>,
    config: Option<&SimdOptimizationConfig>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
{
    let default_config = SimdOptimizationConfig::default();
    let cfg = config.unwrap_or(&default_config);
    let sample_count = data.shape()[0];
    if labels.len() != sample_count {
        return Err(ClusteringError::InvalidInput(
            "Labels array length must match number of data points".to_string(),
        ));
    }
    let caps = PlatformCapabilities::detect();
    let simd = caps.simd_available || cfg.force_simd;
    let parallel =
        cfg.enable_parallel && is_parallel_enabled() && sample_count > cfg.parallel_chunk_size;
    if parallel {
        calculate_distortion_simd_parallel(data, centroids, labels, simd)
    } else {
        calculate_distortion_simd_sequential(data, centroids, labels, simd)
    }
}
/// Sequential distortion: accumulates the squared distance from every sample
/// to its assigned centroid, validating labels as it goes.
#[allow(dead_code)]
fn calculate_distortion_simd_sequential<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    labels: &Array1<usize>,
    use_simd: bool,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let n_samples = data.shape()[0];
    let mut total_distortion = F::zero();
    for i in 0..n_samples {
        let cluster = labels[i];
        if cluster >= centroids.shape()[0] {
            return Err(ClusteringError::InvalidInput(format!(
                "Label {} exceeds number of centroids {}",
                cluster,
                centroids.shape()[0]
            )));
        }
        let point = data.slice(s![i, ..]);
        let centroid = centroids.slice(s![cluster, ..]);
        let squared_distance = if use_simd {
            // BUGFIX: `&centroid` had been mangled to `¢roid`
            // (HTML-entity corruption), which does not compile.
            let diff = F::simd_sub(&point, &centroid);
            let squared_diff = F::simd_mul(&diff.view(), &diff.view());
            F::simd_sum(&squared_diff.view())
        } else {
            let mut sum = F::zero();
            for j in 0..point.len() {
                let diff = point[j] - centroid[j];
                sum = sum + diff * diff;
            }
            sum
        };
        total_distortion = total_distortion + squared_distance;
    }
    Ok(total_distortion)
}
/// Parallel distortion: labels are validated up-front (so the parallel map
/// cannot fail), then per-sample squared distances are computed concurrently
/// and summed.
#[allow(dead_code)]
fn calculate_distortion_simd_parallel<F>(
    data: ArrayView2<F>,
    centroids: ArrayView2<F>,
    labels: &Array1<usize>,
    use_simd: bool,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps + std::iter::Sum,
{
    let n_samples = data.shape()[0];
    for &label in labels.iter() {
        if label >= centroids.shape()[0] {
            return Err(ClusteringError::InvalidInput(format!(
                "Label {} exceeds number of centroids {}",
                label,
                centroids.shape()[0]
            )));
        }
    }
    let squared_distances: Vec<F> = (0..n_samples)
        .into_par_iter()
        .map(|i| {
            let cluster = labels[i];
            let point = data.slice(s![i, ..]);
            let centroid = centroids.slice(s![cluster, ..]);
            if use_simd {
                // BUGFIX: `&centroid` had been mangled to `¢roid`
                // (HTML-entity corruption), which does not compile.
                let diff = F::simd_sub(&point, &centroid);
                let squared_diff = F::simd_mul(&diff.view(), &diff.view());
                F::simd_sum(&squared_diff.view())
            } else {
                let mut sum = F::zero();
                for j in 0..point.len() {
                    let diff = point[j] - centroid[j];
                    sum = sum + diff * diff;
                }
                sum
            }
        })
        .collect();
    Ok(squared_distances.into_iter().sum())
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    // Distance between (1,2,3) and (4,5,6) must match the hand-computed value.
    #[test]
    fn test_euclidean_distance_simd() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let distance = euclidean_distance_simd(x.view(), y.view(), None).expect("Operation failed");
        let expected = ((4.0 - 1.0).powi(2) + (5.0 - 2.0).powi(2) + (6.0 - 3.0).powi(2)).sqrt();
        assert_abs_diff_eq!(distance, expected, epsilon = 1e-10);
    }

    // Whitened columns should have (approximately) zero mean.
    #[test]
    fn test_whiten_simd() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 1.5, 2.5, 0.5, 1.5])
            .expect("Operation failed");
        let config = SimdOptimizationConfig {
            enable_parallel: false,
            force_simd: false,
            ..Default::default()
        };
        let whitened = whiten_simd(&data, Some(&config)).expect("Operation failed");
        let col_means: Vec<f64> = (0..2).map(|j| whitened.column(j).mean()).collect();
        for mean in col_means {
            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-8);
        }
    }

    // Sanity checks on vq output shapes and non-negative distances.
    // NOTE(review): marked #[ignore = "timeout"] — presumably flaky/slow in CI.
    #[test]
    #[ignore = "timeout"]
    fn test_vq_simd() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0])
            .expect("Operation failed");
        let centroids =
            Array2::from_shape_vec((2, 2), vec![0.25, 0.25, 0.75, 0.75]).expect("Operation failed");
        let config = SimdOptimizationConfig {
            enable_parallel: false,
            force_simd: false,
            ..Default::default()
        };
        let (labels, distances) =
            vq_simd(data.view(), centroids.view(), Some(&config)).expect("Operation failed");
        assert_eq!(labels.len(), 3);
        assert_eq!(distances.len(), 3);
        for &distance in distances.iter() {
            assert!(distance >= 0.0);
        }
    }

    // Two points in cluster 0, one in cluster 1 — centroids are their means.
    #[test]
    #[ignore = "timeout"]
    fn test_compute_centroids_simd() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0])
            .expect("Operation failed");
        let labels = Array1::from_vec(vec![0, 0, 1]);
        let config = SimdOptimizationConfig {
            enable_parallel: false,
            force_simd: false,
            ..Default::default()
        };
        let centroids = compute_centroids_simd(data.view(), &labels, 2, Some(&config))
            .expect("Operation failed");
        assert_eq!(centroids.shape(), &[2, 2]);
        assert_abs_diff_eq!(centroids[[0, 0]], 0.5, epsilon = 1e-8);
        assert_abs_diff_eq!(centroids[[0, 1]], 0.0, epsilon = 1e-8);
        assert_abs_diff_eq!(centroids[[1, 0]], 0.0, epsilon = 1e-8);
        assert_abs_diff_eq!(centroids[[1, 1]], 1.0, epsilon = 1e-8);
    }

    // Distortion = 0.5^2 + 0.5^2 for cluster 0, 0 for the exact match in cluster 1.
    #[test]
    #[ignore = "timeout"]
    fn test_calculate_distortion_simd() {
        let data = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0])
            .expect("Operation failed");
        let centroids =
            Array2::from_shape_vec((2, 2), vec![0.5, 0.0, 0.0, 1.0]).expect("Operation failed");
        let labels = Array1::from_vec(vec![0, 0, 1]);
        let config = SimdOptimizationConfig {
            enable_parallel: false,
            force_simd: false,
            ..Default::default()
        };
        let distortion =
            calculate_distortion_simd(data.view(), centroids.view(), &labels, Some(&config))
                .expect("Operation failed");
        let expected = 0.5 * 0.5 + 0.5 * 0.5 + 0.0;
        assert_abs_diff_eq!(distortion, expected, epsilon = 1e-8);
    }
}