use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::parallel_ops::*;
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
use crate::hierarchy::{coords_to_condensed_index, LinkageMethod};
/// A cluster tracked while the parallel agglomerative algorithm runs.
///
/// In this file, singleton clusters are created with one member and merged
/// clusters are created by concatenating the members of the two merged clusters.
#[derive(Debug)]
pub(crate) struct ParallelCluster {
/// Number of samples in the cluster: 1 for a singleton, and the sum of the two
/// merged sizes for a cluster produced by a merge.
pub size: usize,
/// Sample indices belonging to the cluster; the linkage functions use them to
/// look up pairwise distances in the condensed distance vector.
pub members: Vec<usize>,
}
pub(crate) fn parallel_hierarchical_clustering<
F: Float + FromPrimitive + Debug + PartialOrd + Send + Sync + std::iter::Sum,
>(
distances: &Array1<F>,
n_samples: usize,
method: LinkageMethod,
) -> Result<Array2<F>> {
let mut clusters: Vec<ParallelCluster> = (0..n_samples)
.map(|i| ParallelCluster {
size: 1,
members: vec![i],
})
.collect();
let mut linkage_matrix = Array2::zeros((n_samples - 1, 4));
let mut activeclusters: Vec<usize> = (0..n_samples).collect();
let mut centroids: Option<Array2<F>> = None;
if matches!(method, LinkageMethod::Centroid | LinkageMethod::Median) {
centroids = Some(Array2::from_elem(
(2 * n_samples - 1, distances.len()),
F::zero(),
));
}
for i in 0..(n_samples - 1) {
let (cluster1_idx, cluster2_idx, min_dist) = parallel_find_closestclusters(
&activeclusters,
&clusters,
distances,
method,
centroids.as_ref(),
n_samples,
)?;
let cluster1 = activeclusters[cluster1_idx];
let cluster2 = activeclusters[cluster2_idx];
let (cluster1, cluster2) = if cluster1 < cluster2 {
(cluster1, cluster2)
} else {
(cluster2, cluster1)
};
let new_cluster_id = n_samples + i;
let mut new_members = clusters[cluster1].members.clone();
new_members.extend(clusters[cluster2].members.clone());
let new_cluster = ParallelCluster {
size: clusters[cluster1].size + clusters[cluster2].size,
members: new_members,
};
if let Some(ref mut cents) = centroids {
update_centroid(cents, method, n_samples, new_cluster_id);
}
clusters.push(new_cluster);
activeclusters.remove(cluster1_idx.max(cluster2_idx));
activeclusters.remove(cluster1_idx.min(cluster2_idx));
activeclusters.push(new_cluster_id);
linkage_matrix[[i, 0]] = F::from_usize(cluster1).expect("Operation failed");
linkage_matrix[[i, 1]] = F::from_usize(cluster2).expect("Operation failed");
linkage_matrix[[i, 2]] = min_dist;
linkage_matrix[[i, 3]] =
F::from_usize(clusters[new_cluster_id].size).expect("Operation failed");
}
Ok(linkage_matrix)
}
#[allow(dead_code)]
fn parallel_find_closestclusters<
F: Float + FromPrimitive + Debug + PartialOrd + Send + Sync + std::iter::Sum,
>(
activeclusters: &[usize],
clusters: &[ParallelCluster],
distances: &Array1<F>,
method: LinkageMethod,
centroids: Option<&Array2<F>>,
n_samples: usize,
) -> Result<(usize, usize, F)> {
let mut pairs = Vec::new();
for i in 0..activeclusters.len() {
for j in (i + 1)..activeclusters.len() {
pairs.push((i, j));
}
}
if pairs.is_empty() {
return Err(ClusteringError::ComputationError(
"No cluster pairs to process".into(),
));
}
let results: Result<Vec<(usize, usize, F)>> = pairs
.par_iter()
.map(|&(i, j)| {
let cluster_i = activeclusters[i];
let cluster_j = activeclusters[j];
let dist = match method {
LinkageMethod::Single => parallel_single_linkage(
&clusters[cluster_i],
&clusters[cluster_j],
distances,
n_samples,
),
LinkageMethod::Complete => parallel_complete_linkage(
&clusters[cluster_i],
&clusters[cluster_j],
distances,
n_samples,
),
LinkageMethod::Average => parallel_average_linkage(
&clusters[cluster_i],
&clusters[cluster_j],
distances,
n_samples,
),
LinkageMethod::Ward => parallel_ward_linkage(
&clusters[cluster_i],
&clusters[cluster_j],
distances,
n_samples,
),
LinkageMethod::Centroid => Ok(centroid_linkage(
cluster_i,
cluster_i,
cluster_j,
centroids.expect("Operation failed"),
)),
LinkageMethod::Median => Ok(median_linkage(
cluster_i,
cluster_i,
cluster_j,
centroids.expect("Operation failed"),
)),
LinkageMethod::Weighted => parallel_weighted_linkage(
&clusters[cluster_i],
&clusters[cluster_j],
distances,
n_samples,
),
};
dist.map(|d| (i, j, d))
})
.collect();
let min_result = results?
.into_iter()
.min_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal))
.ok_or_else(|| {
ClusteringError::ComputationError("Could not find minimum distance".into())
})?;
Ok((min_result.0, min_result.1, min_result.2))
}
/// Single-linkage distance: the smallest pairwise distance between a member of
/// `cluster1` and a member of `cluster2`.
///
/// The members of `cluster1` are processed in parallel; each one scans all
/// members of `cluster2` sequentially.
///
/// # Arguments
///
/// * `cluster1`, `cluster2` - The two clusters to compare.
/// * `distances` - Condensed pairwise distances between the original samples.
/// * `n_samples` - Number of original samples, used to compute condensed indices.
///
/// # Returns
///
/// The minimum cross-cluster distance, or positive infinity when either cluster
/// has no members.
///
/// # Errors
///
/// Propagates any error from `coords_to_condensed_index`.
pub(crate) fn parallel_single_linkage<F: Float + PartialOrd + Send + Sync>(
    cluster1: &ParallelCluster,
    cluster2: &ParallelCluster,
    distances: &Array1<F>,
    n_samples: usize,
) -> Result<F> {
    use std::cmp::Ordering;

    // Smallest distance from one member of `cluster1` to any member of `cluster2`.
    let nearest_partner = |member: usize| -> Result<F> {
        cluster2
            .members
            .iter()
            .try_fold(F::infinity(), |best, &other| -> Result<F> {
                let (lo, hi) = if member < other {
                    (member, other)
                } else {
                    (other, member)
                };
                let condensed = coords_to_condensed_index(n_samples, lo, hi)?;
                Ok(best.min(distances[condensed]))
            })
    };

    let per_member = cluster1
        .members
        .par_iter()
        .map(|&member| nearest_partner(member))
        .collect::<Result<Vec<F>>>()?;

    let closest = per_member
        .into_iter()
        .min_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal));
    Ok(closest.unwrap_or_else(F::infinity))
}
/// Complete-linkage distance: the largest pairwise distance between a member of
/// `cluster1` and a member of `cluster2`.
///
/// The members of `cluster1` are processed in parallel; each one scans all
/// members of `cluster2` sequentially.
///
/// # Arguments
///
/// * `cluster1`, `cluster2` - The two clusters to compare.
/// * `distances` - Condensed pairwise distances between the original samples.
/// * `n_samples` - Number of original samples, used to compute condensed indices.
///
/// # Returns
///
/// The maximum cross-cluster distance, or negative infinity when either cluster
/// has no members.
///
/// # Errors
///
/// Propagates any error from `coords_to_condensed_index`.
pub(crate) fn parallel_complete_linkage<F: Float + PartialOrd + Send + Sync>(
    cluster1: &ParallelCluster,
    cluster2: &ParallelCluster,
    distances: &Array1<F>,
    n_samples: usize,
) -> Result<F> {
    use std::cmp::Ordering;

    // Largest distance from one member of `cluster1` to any member of `cluster2`.
    let farthest_partner = |member: usize| -> Result<F> {
        cluster2
            .members
            .iter()
            .try_fold(F::neg_infinity(), |worst, &other| -> Result<F> {
                let (lo, hi) = if member < other {
                    (member, other)
                } else {
                    (other, member)
                };
                let condensed = coords_to_condensed_index(n_samples, lo, hi)?;
                Ok(worst.max(distances[condensed]))
            })
    };

    let per_member = cluster1
        .members
        .par_iter()
        .map(|&member| farthest_partner(member))
        .collect::<Result<Vec<F>>>()?;

    let farthest = per_member
        .into_iter()
        .max_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(Ordering::Equal));
    Ok(farthest.unwrap_or_else(F::neg_infinity))
}
/// Average-linkage distance: the mean of all pairwise distances between members
/// of `cluster1` and members of `cluster2`.
///
/// Each member of `cluster1` accumulates its own partial sum in parallel; the
/// partial sums are then combined in member order.
///
/// # Arguments
///
/// * `cluster1`, `cluster2` - The two clusters to compare.
/// * `distances` - Condensed pairwise distances between the original samples.
/// * `n_samples` - Number of original samples, used to compute condensed indices.
///
/// # Returns
///
/// The mean cross-cluster distance, or positive infinity when there are no
/// cross-cluster pairs.
///
/// # Errors
///
/// Propagates any error from `coords_to_condensed_index`.
///
/// # Panics
///
/// Panics if the pair count cannot be converted to `F`.
pub(crate) fn parallel_average_linkage<F: Float + FromPrimitive + Send + Sync>(
    cluster1: &ParallelCluster,
    cluster2: &ParallelCluster,
    distances: &Array1<F>,
    n_samples: usize,
) -> Result<F> {
    // Sum of distances from one member of `cluster1` to every member of
    // `cluster2`, together with the number of distances that went into it.
    let row_total = |member: usize| -> Result<(F, usize)> {
        let total = cluster2
            .members
            .iter()
            .try_fold(F::zero(), |acc, &other| -> Result<F> {
                let (lo, hi) = if member < other {
                    (member, other)
                } else {
                    (other, member)
                };
                let condensed = coords_to_condensed_index(n_samples, lo, hi)?;
                Ok(acc + distances[condensed])
            })?;
        Ok((total, cluster2.members.len()))
    };

    let partials = cluster1
        .members
        .par_iter()
        .map(|&member| row_total(member))
        .collect::<Result<Vec<(F, usize)>>>()?;

    let mut total_sum = F::zero();
    let mut total_count: usize = 0;
    for (partial_sum, partial_count) in partials {
        total_sum = total_sum + partial_sum;
        total_count += partial_count;
    }

    if total_count == 0 {
        return Ok(F::infinity());
    }
    Ok(total_sum / F::from_usize(total_count).expect("Operation failed"))
}
/// Ward-style distance between two clusters, computed from cross-cluster
/// pairwise distances only.
///
/// With `S` the sum of squared distances over all cross-cluster pairs, the
/// result is `sqrt((size1 * size2 / (size1 + size2)) * (S / (size1 * size2)))`.
/// The sizes come from the clusters' `size` fields, not from `members.len()`.
///
/// NOTE(review): only cross-cluster distances are used here; within-cluster
/// distances do not enter the formula. Confirm this matches the intended Ward
/// criterion.
///
/// # Arguments
///
/// * `cluster1`, `cluster2` - The two clusters to compare.
/// * `distances` - Condensed pairwise distances between the original samples.
/// * `n_samples` - Number of original samples, used to compute condensed indices.
///
/// # Errors
///
/// Propagates any error from `coords_to_condensed_index`.
///
/// # Panics
///
/// Panics if a cluster size cannot be converted to `F`.
pub(crate) fn parallel_ward_linkage<F: Float + FromPrimitive + Send + Sync + std::iter::Sum>(
    cluster1: &ParallelCluster,
    cluster2: &ParallelCluster,
    distances: &Array1<F>,
    n_samples: usize,
) -> Result<F> {
    let n1 = F::from_usize(cluster1.size).expect("Operation failed");
    let n2 = F::from_usize(cluster2.size).expect("Operation failed");

    // Sum of squared distances from one member of `cluster1` to all of `cluster2`.
    let row_squares = |member: usize| -> Result<F> {
        cluster2
            .members
            .iter()
            .try_fold(F::zero(), |acc, &other| -> Result<F> {
                let (lo, hi) = if member < other {
                    (member, other)
                } else {
                    (other, member)
                };
                let condensed = coords_to_condensed_index(n_samples, lo, hi)?;
                let d = distances[condensed];
                Ok(acc + d * d)
            })
    };

    let per_member = cluster1
        .members
        .par_iter()
        .map(|&member| row_squares(member))
        .collect::<Result<Vec<F>>>()?;

    let cross_sum: F = per_member.into_iter().sum();
    let mean_squared = cross_sum / (n1 * n2);
    let weight = (n1 * n2) / (n1 + n2);
    Ok((weight * mean_squared).sqrt())
}
/// Weighted-linkage distance between two clusters.
///
/// Currently forwards all arguments unchanged to `parallel_average_linkage`,
/// so it returns exactly the same value (and errors) as average linkage.
///
/// NOTE(review): weighted (WPGMA) linkage usually differs from plain average
/// linkage once the merged clusters have unequal sizes; confirm that this
/// delegation is intentional.
pub(crate) fn parallel_weighted_linkage<F: Float + FromPrimitive + Send + Sync>(
cluster1: &ParallelCluster,
cluster2: &ParallelCluster,
distances: &Array1<F>,
n_samples: usize,
) -> Result<F> {
parallel_average_linkage(cluster1, cluster2, distances, n_samples)
}
/// Placeholder for the centroid-linkage distance between two clusters.
///
/// Every argument is currently ignored and the function always returns zero.
/// All parameters carry a leading underscore so the unused-variable lint stays
/// quiet until a real implementation exists.
///
/// TODO: compute the distance between the two cluster centroids.
#[allow(dead_code)]
fn centroid_linkage<F: Float>(
    _cluster_i: usize,
    _i: usize,
    _cluster_j: usize,
    _centroids: &Array2<F>,
) -> F {
    F::zero()
}
/// Placeholder for the median-linkage distance between two clusters.
///
/// Every argument is currently ignored and the function always returns zero.
/// All parameters carry a leading underscore so the unused-variable lint stays
/// quiet until a real implementation exists.
///
/// TODO: compute the median-linkage distance from the stored centroids.
#[allow(dead_code)]
fn median_linkage<F: Float>(
    _cluster_i: usize,
    _i: usize,
    _cluster_j: usize,
    _centroids: &Array2<F>,
) -> F {
    F::zero()
}
/// Placeholder for refreshing centroid storage after a merge.
///
/// Currently a no-op: none of the arguments are read and `_centroids` is left
/// untouched. All parameters carry a leading underscore so the unused-variable
/// lint stays quiet until a real implementation exists.
///
/// TODO: write the centroid of the newly created cluster into `_centroids`.
#[allow(dead_code)]
fn update_centroid<F: Float>(
    _centroids: &mut Array2<F>,
    _method: LinkageMethod,
    _n_samples: usize,
    _cluster_id: usize,
) {
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Smallest entry of the fixture's condensed distance vector.
    const MIN_DISTANCE: f64 = 1.0;
    /// Largest entry of the fixture's condensed distance vector.
    const MAX_DISTANCE: f64 = 3.0;

    /// Two disjoint two-member clusters over four samples, plus the condensed
    /// distance vector (six entries) for those four samples.
    fn fixture() -> (ParallelCluster, ParallelCluster, Array1<f64>) {
        let cluster1 = ParallelCluster {
            size: 2,
            members: vec![0, 1],
        };
        let cluster2 = ParallelCluster {
            size: 2,
            members: vec![2, 3],
        };
        let distances = Array1::from(vec![1.0, 2.0, 3.0, 1.5, 2.5, 1.8]);
        (cluster1, cluster2, distances)
    }

    /// A cluster without members.
    fn empty_cluster() -> ParallelCluster {
        ParallelCluster {
            size: 0,
            members: Vec::new(),
        }
    }

    /// Any min/max/mean of entries of the distance vector must lie within the
    /// range of that vector, regardless of how pairs map to condensed indices.
    fn assert_within_fixture_range(value: f64) {
        assert!(
            (MIN_DISTANCE..=MAX_DISTANCE).contains(&value),
            "distance {value} lies outside [{MIN_DISTANCE}, {MAX_DISTANCE}]"
        );
    }

    #[test]
    fn test_parallel_single_linkage() {
        let (cluster1, cluster2, distances) = fixture();
        let result = parallel_single_linkage(&cluster1, &cluster2, &distances, 4);
        assert_within_fixture_range(result.expect("Operation failed"));
    }

    #[test]
    fn test_parallel_complete_linkage() {
        let (cluster1, cluster2, distances) = fixture();
        let result = parallel_complete_linkage(&cluster1, &cluster2, &distances, 4);
        assert_within_fixture_range(result.expect("Operation failed"));
    }

    #[test]
    fn test_parallel_average_linkage() {
        let (cluster1, cluster2, distances) = fixture();
        let result = parallel_average_linkage(&cluster1, &cluster2, &distances, 4);
        assert_within_fixture_range(result.expect("Operation failed"));
    }

    #[test]
    fn test_linkage_ordering() {
        // The minimum of a set never exceeds its mean, which never exceeds its maximum.
        let (cluster1, cluster2, distances) = fixture();
        let single =
            parallel_single_linkage(&cluster1, &cluster2, &distances, 4).expect("single linkage");
        let average =
            parallel_average_linkage(&cluster1, &cluster2, &distances, 4).expect("average linkage");
        let complete = parallel_complete_linkage(&cluster1, &cluster2, &distances, 4)
            .expect("complete linkage");
        assert!(single <= average, "single {single} > average {average}");
        assert!(average <= complete, "average {average} > complete {complete}");
    }

    #[test]
    fn test_empty_cluster_sentinels() {
        // With no members to compare, each linkage falls back to its sentinel value.
        let (_, cluster2, distances) = fixture();
        let empty = empty_cluster();
        let single =
            parallel_single_linkage(&empty, &cluster2, &distances, 4).expect("single linkage");
        let complete =
            parallel_complete_linkage(&empty, &cluster2, &distances, 4).expect("complete linkage");
        let average =
            parallel_average_linkage(&empty, &cluster2, &distances, 4).expect("average linkage");
        assert_eq!(single, f64::INFINITY);
        assert_eq!(complete, f64::NEG_INFINITY);
        assert_eq!(average, f64::INFINITY);
    }
}