use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities, SimdUnifiedOps};
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
/// Tuning options shared by the SIMD-accelerated distance routines in this module.
#[derive(Debug, Clone)]
pub struct SimdClusterConfig {
    /// Element count at or above which `should_use_simd` picks the SIMD path on
    /// platforms that report SIMD support.
    pub simd_threshold: usize,
    /// Allows `simd_pairwise_distance_matrix` to take its parallel path.
    pub enable_parallel: bool,
    /// The full pairwise matrix is only built in parallel when the number of
    /// samples exceeds this value.
    pub parallel_chunk_size: usize,
    /// Tile edge length used by the sequential pairwise-matrix builder.
    pub block_size: usize,
    /// When set, `should_use_simd` answers `true` regardless of platform
    /// detection or input size.
    pub force_simd: bool,
}

impl Default for SimdClusterConfig {
    /// Builds the stock configuration: SIMD from 32 elements, parallel execution
    /// enabled above 512 samples, 128-wide tiles, and no forced SIMD.
    fn default() -> Self {
        SimdClusterConfig {
            force_simd: false,
            enable_parallel: true,
            simd_threshold: 32,
            block_size: 128,
            parallel_chunk_size: 512,
        }
    }
}
/// Decides whether a kernel over `n_elements` values should take the SIMD path.
///
/// Returns `true` when `config.force_simd` is set, or when the platform reports
/// SIMD support and either `n_elements >= config.simd_threshold` or the core
/// `AutoOptimizer` heuristic approves the size.
///
/// The checks are evaluated lazily because this function is called once per
/// distance computation: platform detection is skipped when SIMD is forced, and
/// the optimizer is only constructed when the threshold comparison has not
/// already decided the answer. The result matches the eager formulation as long
/// as detection and the optimizer query are side-effect free (assumed; their
/// implementations live in `scirs2_core`).
pub(super) fn should_use_simd(n_elements: usize, config: &SimdClusterConfig) -> bool {
    if config.force_simd {
        return true;
    }
    if !PlatformCapabilities::detect().simd_available {
        return false;
    }
    n_elements >= config.simd_threshold || AutoOptimizer::new().should_use_simd(n_elements)
}
/// Distance metrics that `simd_distance` can dispatch to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdDistanceMetric {
/// Dispatched to `simd_euclidean_distance` (square root of the summed squared differences).
Euclidean,
/// Dispatched to `simd_squared_euclidean_distance` (summed squared differences, no square root).
SquaredEuclidean,
/// Dispatched to `simd_manhattan_distance` (summed absolute differences).
Manhattan,
}
/// Computes the Euclidean distance between two vectors.
///
/// Uses the `SimdUnifiedOps` subtraction and norm kernels when `should_use_simd`
/// approves the vector length, and a plain scalar accumulation otherwise.
/// Passing `None` for `config` selects `SimdClusterConfig::default()`.
///
/// Two empty vectors have distance zero.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `a` and `b` differ in length.
pub fn simd_euclidean_distance<F>(
    a: ArrayView1<F>,
    b: ArrayView1<F>,
    config: Option<&SimdClusterConfig>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(ClusteringError::InvalidInput(format!(
            "Vectors must have the same length: got {} and {}",
            len_a, len_b
        )));
    }
    if len_a == 0 {
        return Ok(F::zero());
    }

    let fallback = SimdClusterConfig::default();
    let settings = config.unwrap_or(&fallback);

    if !should_use_simd(len_a, settings) {
        // Scalar path: accumulate squared differences in element order.
        let sum_sq = a.iter().zip(b.iter()).fold(F::zero(), |acc, (&x, &y)| {
            let delta = x - y;
            acc + delta * delta
        });
        return Ok(sum_sq.sqrt());
    }

    let delta = F::simd_sub(&a, &b);
    Ok(F::simd_norm(&delta.view()))
}
/// Computes the squared Euclidean distance (sum of squared differences).
///
/// Uses the `SimdUnifiedOps` subtract, multiply and sum kernels when
/// `should_use_simd` approves the vector length, and a scalar accumulation
/// otherwise. Passing `None` for `config` selects `SimdClusterConfig::default()`.
///
/// Two empty vectors have distance zero.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `a` and `b` differ in length.
pub fn simd_squared_euclidean_distance<F>(
    a: ArrayView1<F>,
    b: ArrayView1<F>,
    config: Option<&SimdClusterConfig>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(ClusteringError::InvalidInput(format!(
            "Vectors must have the same length: got {} and {}",
            len_a, len_b
        )));
    }
    if len_a == 0 {
        return Ok(F::zero());
    }

    let fallback = SimdClusterConfig::default();
    let settings = config.unwrap_or(&fallback);

    if !should_use_simd(len_a, settings) {
        // Scalar path: accumulate squared differences in element order.
        let sum_sq = a.iter().zip(b.iter()).fold(F::zero(), |acc, (&x, &y)| {
            let delta = x - y;
            acc + delta * delta
        });
        return Ok(sum_sq);
    }

    let delta = F::simd_sub(&a, &b);
    let squared = F::simd_mul(&delta.view(), &delta.view());
    Ok(F::simd_sum(&squared.view()))
}
/// Computes the Manhattan (L1) distance: the sum of absolute differences.
///
/// Uses the `SimdUnifiedOps` subtract, abs and sum kernels when
/// `should_use_simd` approves the vector length, and a scalar accumulation
/// otherwise. Passing `None` for `config` selects `SimdClusterConfig::default()`.
///
/// Two empty vectors have distance zero.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `a` and `b` differ in length.
pub fn simd_manhattan_distance<F>(
    a: ArrayView1<F>,
    b: ArrayView1<F>,
    config: Option<&SimdClusterConfig>,
) -> Result<F>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(ClusteringError::InvalidInput(format!(
            "Vectors must have the same length: got {} and {}",
            len_a, len_b
        )));
    }
    if len_a == 0 {
        return Ok(F::zero());
    }

    let fallback = SimdClusterConfig::default();
    let settings = config.unwrap_or(&fallback);

    if !should_use_simd(len_a, settings) {
        // Scalar path: accumulate absolute differences in element order.
        let total = a
            .iter()
            .zip(b.iter())
            .fold(F::zero(), |acc, (&x, &y)| acc + (x - y).abs());
        return Ok(total);
    }

    let delta = F::simd_sub(&a, &b);
    let magnitudes = F::simd_abs(&delta.view());
    Ok(F::simd_sum(&magnitudes.view()))
}
pub fn simd_distance<F>(
a: ArrayView1<F>,
b: ArrayView1<F>,
metric: SimdDistanceMetric,
config: Option<&SimdClusterConfig>,
) -> Result<F>
where
F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
match metric {
SimdDistanceMetric::Euclidean => simd_euclidean_distance(a, b, config),
SimdDistanceMetric::SquaredEuclidean => simd_squared_euclidean_distance(a, b, config),
SimdDistanceMetric::Manhattan => simd_manhattan_distance(a, b, config),
}
}
/// Builds the full square matrix of pairwise distances between the rows of `data`.
///
/// The result has one row and one column per sample. The matrix starts out
/// zero-filled and the helpers only write off-diagonal entries, so the diagonal
/// stays zero.
///
/// The parallel builder is used when `enable_parallel` is set,
/// `is_parallel_enabled()` reports true and the sample count exceeds
/// `parallel_chunk_size`; otherwise the blocked sequential builder runs.
/// Passing `None` for `config` selects `SimdClusterConfig::default()`.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `data` has no rows, and
/// propagates any error from the sequential builder.
pub fn simd_pairwise_distance_matrix<F>(
    data: ArrayView2<F>,
    metric: SimdDistanceMetric,
    config: Option<&SimdClusterConfig>,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let sample_count = data.shape()[0];
    if sample_count == 0 {
        return Err(ClusteringError::InvalidInput(
            "Data must have at least one sample".to_string(),
        ));
    }

    let fallback = SimdClusterConfig::default();
    let settings = config.unwrap_or(&fallback);

    let mut matrix: Array2<F> = Array2::zeros((sample_count, sample_count));

    let go_parallel = settings.enable_parallel
        && is_parallel_enabled()
        && sample_count > settings.parallel_chunk_size;

    if !go_parallel {
        simd_pairwise_matrix_sequential(data, metric, settings, &mut matrix)?;
    } else {
        simd_pairwise_matrix_parallel(data, metric, settings, &mut matrix);
    }

    Ok(matrix)
}
/// Fills `distances` with the pairwise distances between rows of `data`,
/// walking the upper triangle in square tiles of `config.block_size` rows.
///
/// Each computed distance is written to both `[i, j]` and `[j, i]`. Diagonal
/// entries are never written, so they keep whatever value the caller
/// initialised them with.
///
/// A `block_size` of zero is treated as one: `Iterator::step_by` panics on a
/// zero step, and the field is public, so it cannot be assumed positive.
///
/// # Errors
///
/// Propagates any error returned by `simd_distance`.
fn simd_pairwise_matrix_sequential<F>(
    data: ArrayView2<F>,
    metric: SimdDistanceMetric,
    config: &SimdClusterConfig,
    distances: &mut Array2<F>,
) -> Result<()>
where
    F: Float + FromPrimitive + Debug + SimdUnifiedOps,
{
    let n_samples = data.shape()[0];
    // Clamp to at least 1 so `step_by` never receives a zero step.
    let block_size = config.block_size.max(1);

    for block_i in (0..n_samples).step_by(block_size) {
        // Saturating add keeps a huge block size from overflowing the tile end.
        let end_i = block_i.saturating_add(block_size).min(n_samples);
        for block_j in (block_i..n_samples).step_by(block_size) {
            let end_j = block_j.saturating_add(block_size).min(n_samples);
            for i in block_i..end_i {
                // Inside a diagonal tile only the strictly-upper part is
                // visited; off-diagonal tiles are visited in full.
                let start_j = if block_i == block_j { i + 1 } else { block_j };
                for j in start_j..end_j {
                    let dist = simd_distance(data.row(i), data.row(j), metric, Some(config))?;
                    distances[[i, j]] = dist;
                    distances[[j, i]] = dist;
                }
            }
        }
    }
    Ok(())
}
/// Fills `distances` with the pairwise distances between rows of `data`,
/// computing each row of the strict upper triangle in a parallel iterator.
///
/// Workers return `(column, distance)` pairs per row. The pairs are then
/// written into `distances` on the calling thread, mirrored across the
/// diagonal. Diagonal entries are never written.
///
/// A failed distance computation is replaced by zero rather than reported. The
/// only error visible in the distance functions is a length mismatch, and two
/// rows of the same 2-D view always have the same length.
fn simd_pairwise_matrix_parallel<F>(
    data: ArrayView2<F>,
    metric: SimdDistanceMetric,
    config: &SimdClusterConfig,
    distances: &mut Array2<F>,
) where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let sample_count = data.shape()[0];

    let upper_rows: Vec<Vec<(usize, F)>> = (0..sample_count)
        .into_par_iter()
        .map(|row| {
            ((row + 1)..sample_count)
                .map(|col| {
                    let value =
                        match simd_distance(data.row(row), data.row(col), metric, Some(config)) {
                            Ok(d) => d,
                            Err(_) => F::zero(),
                        };
                    (col, value)
                })
                .collect::<Vec<_>>()
        })
        .collect();

    for (row, entries) in upper_rows.into_iter().enumerate() {
        for (col, value) in entries {
            distances[[row, col]] = value;
            distances[[col, row]] = value;
        }
    }
}
/// Computes pairwise distances between the rows of `data` in condensed form.
///
/// The result holds one entry per unordered pair `(i, j)` with `i < j`, in
/// row-major order of the strict upper triangle: `(0,1), (0,2), ..., (1,2), ...`.
/// Its length is `n * (n - 1) / 2` for `n` samples. Passing `None` for `config`
/// selects `SimdClusterConfig::default()`.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `data` has no rows, and
/// propagates any error returned by `simd_distance`.
pub fn simd_pairwise_condensed_distances<F>(
    data: ArrayView2<F>,
    metric: SimdDistanceMetric,
    config: Option<&SimdClusterConfig>,
) -> Result<Array1<F>>
where
    F: Float + FromPrimitive + Debug + Send + Sync + SimdUnifiedOps,
{
    let sample_count = data.shape()[0];
    if sample_count == 0 {
        return Err(ClusteringError::InvalidInput(
            "Data must have at least one sample".to_string(),
        ));
    }

    let pair_count = sample_count * (sample_count - 1) / 2;
    let fallback = SimdClusterConfig::default();
    let settings = config.unwrap_or(&fallback);

    // Pairs are appended in exactly the order the condensed layout requires.
    let mut condensed: Vec<F> = Vec::with_capacity(pair_count);
    for row in 0..sample_count {
        for col in (row + 1)..sample_count {
            let value = simd_distance(data.row(row), data.row(col), metric, Some(settings))?;
            condensed.push(value);
        }
    }

    Ok(Array1::from_vec(condensed))
}