use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
use crate::hierarchy::disjoint_set::DisjointSet;
#[allow(dead_code)]
pub fn extract_clusters_multi_criteria<F: Float + FromPrimitive + Debug + PartialOrd>(
linkage_matrix: ArrayView2<F>,
max_clusters: Option<usize>,
distance_threshold: Option<F>,
inconsistency_threshold: Option<F>,
) -> Result<Array1<usize>> {
let n_observations = linkage_matrix.shape()[0] + 1;
let mut disjoint_set = DisjointSet::new();
for i in 0..n_observations {
disjoint_set.make_set(i);
}
for merge_idx in 0..linkage_matrix.shape()[0] {
let cluster1 = linkage_matrix[[merge_idx, 0]]
.to_usize()
.expect("Operation failed");
let cluster2 = linkage_matrix[[merge_idx, 1]]
.to_usize()
.expect("Operation failed");
let distance = linkage_matrix[[merge_idx, 2]];
if let Some(dist_thresh) = distance_threshold {
if distance > dist_thresh {
break; }
}
if let Some(inconsist_thresh) = inconsistency_threshold {
let inconsistency = calculate_merge_inconsistency(linkage_matrix, merge_idx)?;
if inconsistency > inconsist_thresh {
break; }
}
if let Some(max_clust) = max_clusters {
let current_clusters = disjoint_set.num_sets();
if current_clusters <= max_clust {
break; }
}
if cluster1 < n_observations && cluster2 < n_observations {
disjoint_set.union(cluster1, cluster2);
} else {
}
}
let mut cluster_id_map = HashMap::new();
let mut next_cluster_id = 0;
let mut assignments = Array1::zeros(n_observations);
for i in 0..n_observations {
if let Some(root) = disjoint_set.find(&i) {
let cluster_id = cluster_id_map.entry(root).or_insert_with(|| {
let id = next_cluster_id;
next_cluster_id += 1;
id
});
assignments[i] = *cluster_id;
}
}
Ok(assignments)
}
#[allow(dead_code)]
fn calculate_merge_inconsistency<F: Float + FromPrimitive + Debug + PartialOrd>(
linkage_matrix: ArrayView2<F>,
merge_idx: usize,
) -> Result<F> {
if merge_idx >= linkage_matrix.shape()[0] {
return Err(ClusteringError::InvalidInput(
"Merge index out of bounds".to_string(),
));
}
let current_distance = linkage_matrix[[merge_idx, 2]];
let mut distances = Vec::new();
let start_idx = merge_idx.saturating_sub(5);
for i in start_idx..merge_idx {
distances.push(linkage_matrix[[i, 2]]);
}
if distances.is_empty() {
return Ok(F::zero()); }
let mean = distances.iter().fold(F::zero(), |acc, &x| acc + x)
/ F::from_usize(distances.len()).expect("Operation failed");
let variance = distances
.iter()
.fold(F::zero(), |acc, &x| acc + (x - mean) * (x - mean))
/ F::from_usize(distances.len()).expect("Operation failed");
let std_dev = variance.sqrt();
if std_dev < F::from_f64(1e-10).expect("Operation failed") {
return Ok(F::zero()); }
Ok((current_distance - mean) / std_dev)
}
#[allow(dead_code)]
pub fn estimate_optimal_clusters<F: Float + FromPrimitive + Debug + PartialOrd>(
linkage_matrix: ArrayView2<F>,
data: Option<ArrayView2<F>>,
max_clusters: Option<usize>,
) -> Result<usize> {
let n_observations = linkage_matrix.shape()[0] + 1;
let max_k = max_clusters.unwrap_or(n_observations.min(20));
if max_k < 2 {
return Ok(1);
}
let elbow_k = estimate_clusters_elbow_method(linkage_matrix, max_k)?;
let gap_k = estimate_clusters_distance_gap(linkage_matrix, max_k)?;
let silhouette_k = if let Some(data_view) = data {
Some(estimate_clusters_silhouette(
linkage_matrix,
data_view,
max_k,
)?)
} else {
None
};
let mut candidates = vec![elbow_k, gap_k];
if let Some(sil_k) = silhouette_k {
candidates.push(sil_k);
}
let mut counts = HashMap::new();
for &k in &candidates {
*counts.entry(k).or_insert(0) += 1;
}
let optimal_k = counts
.into_iter()
.max_by_key(|(_, count)| *count)
.map(|(k_, _)| k_)
.unwrap_or(2);
Ok(optimal_k.max(1).min(n_observations))
}
#[allow(dead_code)]
fn estimate_clusters_elbow_method<F: Float + FromPrimitive + Debug + PartialOrd>(
linkage_matrix: ArrayView2<F>,
max_clusters: usize,
) -> Result<usize> {
let n_observations = linkage_matrix.shape()[0] + 1;
let mut wcss_values = Vec::new();
for k in 1..=max_clusters.min(n_observations) {
let _clusters = extract_clusters_by_count(linkage_matrix, k)?;
let wcss = calculate_wcss_approximation(linkage_matrix, &_clusters, k);
wcss_values.push(wcss);
}
let elbow_k = find_elbow_point(&wcss_values);
Ok(elbow_k + 1) }
#[allow(dead_code)]
fn estimate_clusters_distance_gap<F: Float + FromPrimitive + Debug + PartialOrd>(
linkage_matrix: ArrayView2<F>,
max_clusters: usize,
) -> Result<usize> {
let n_merges = linkage_matrix.shape()[0];
let mut gaps = Vec::new();
for i in 1..n_merges {
let current_dist = linkage_matrix[[i, 2]];
let prev_dist = linkage_matrix[[i - 1, 2]];
gaps.push(current_dist - prev_dist);
}
let max_gap_idx = gaps
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(idx, _)| idx)
.unwrap_or(0);
let optimal_k = (n_merges - max_gap_idx).min(max_clusters);
Ok(optimal_k.max(1))
}
#[allow(dead_code)]
fn estimate_clusters_silhouette<F: Float + FromPrimitive + Debug + PartialOrd>(
linkage_matrix: ArrayView2<F>,
data: ArrayView2<F>,
max_clusters: usize,
) -> Result<usize> {
let n_observations = data.shape()[0];
let mut best_silhouette = F::neg_infinity();
let mut best_k = 2;
for k in 2..=max_clusters.min(n_observations) {
let _clusters = extract_clusters_by_count(linkage_matrix, k)?;
let silhouette_score = calculate_silhouette_score(data, &_clusters);
if silhouette_score > best_silhouette {
best_silhouette = silhouette_score;
best_k = k;
}
}
Ok(best_k)
}
#[allow(dead_code)]
fn extract_clusters_by_count<F: Float + FromPrimitive + Debug + PartialOrd>(
linkage_matrix: ArrayView2<F>,
n_clusters: usize,
) -> Result<Array1<usize>> {
let n_observations = linkage_matrix.shape()[0] + 1;
if n_clusters > n_observations {
return Err(ClusteringError::InvalidInput(format!(
"Number of _clusters {} cannot exceed number of observations {}",
n_clusters, n_observations
)));
}
if n_clusters == n_observations {
return Ok(Array1::from_iter(0..n_observations));
}
let mut disjoint_set = DisjointSet::new();
for i in 0..n_observations {
disjoint_set.make_set(i);
}
let n_merges_to_perform = n_observations - n_clusters;
for merge_idx in 0..n_merges_to_perform {
if merge_idx >= linkage_matrix.shape()[0] {
break;
}
let cluster1 = linkage_matrix[[merge_idx, 0]]
.to_usize()
.expect("Operation failed");
let cluster2 = linkage_matrix[[merge_idx, 1]]
.to_usize()
.expect("Operation failed");
let obs1 = if cluster1 < n_observations {
cluster1
} else {
continue;
};
let obs2 = if cluster2 < n_observations {
cluster2
} else {
continue;
};
disjoint_set.union(obs1, obs2);
}
let mut cluster_id_map = HashMap::new();
let mut next_cluster_id = 0;
let mut assignments = Array1::zeros(n_observations);
for i in 0..n_observations {
if let Some(root) = disjoint_set.find(&i) {
let cluster_id = cluster_id_map.entry(root).or_insert_with(|| {
let id = next_cluster_id;
next_cluster_id += 1;
id
});
assignments[i] = *cluster_id;
}
}
Ok(assignments)
}
#[allow(dead_code)]
fn calculate_wcss_approximation<F: Float + FromPrimitive + Debug + PartialOrd>(
linkage_matrix: ArrayView2<F>,
_clusters: &Array1<usize>,
_k: usize,
) -> F {
let n_merges = linkage_matrix.shape()[0];
let mut total_wcss = F::zero();
for i in 0..n_merges {
let distance = linkage_matrix[[i, 2]];
let cluster_size = linkage_matrix[[i, 3]];
total_wcss = total_wcss + distance * cluster_size;
}
total_wcss
}
#[allow(dead_code)]
fn find_elbow_point<F: Float + FromPrimitive + Debug + PartialOrd>(values: &[F]) -> usize {
if values.len() < 3 {
return 0;
}
let mut second_derivatives = Vec::new();
for i in 1..(values.len() - 1) {
let second_deriv =
values[i + 1] - F::from_f64(2.0).expect("Operation failed") * values[i] + values[i - 1];
second_derivatives.push(second_deriv.abs());
}
second_derivatives
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(idx_, _)| idx_ + 1) .unwrap_or(0)
}
#[allow(dead_code)]
fn calculate_silhouette_score<F: Float + FromPrimitive + Debug + PartialOrd>(
data: ArrayView2<F>,
clusters: &Array1<usize>,
) -> F {
let n_observations = data.shape()[0];
if n_observations != clusters.len() {
return F::neg_infinity(); }
let unique_clusters: HashSet<_> = clusters.iter().copied().collect();
if unique_clusters.len() < 2 {
return F::zero(); }
let mut total_silhouette = F::zero();
for i in 0..n_observations {
let cluster_i = clusters[i];
let mut intra_cluster_distances = Vec::new();
for j in 0..n_observations {
if i != j && clusters[j] == cluster_i {
let distance = euclidean_distance(data.row(i), data.row(j));
intra_cluster_distances.push(distance);
}
}
let a = if intra_cluster_distances.is_empty() {
F::zero()
} else {
intra_cluster_distances
.iter()
.fold(F::zero(), |acc, &x| acc + x)
/ F::from_usize(intra_cluster_distances.len()).expect("Operation failed")
};
let mut min_inter_cluster_distance = F::infinity();
for &other_cluster in &unique_clusters {
if other_cluster != cluster_i {
let mut inter_cluster_distances = Vec::new();
for j in 0..n_observations {
if clusters[j] == other_cluster {
let distance = euclidean_distance(data.row(i), data.row(j));
inter_cluster_distances.push(distance);
}
}
if !inter_cluster_distances.is_empty() {
let avg_distance = inter_cluster_distances
.iter()
.fold(F::zero(), |acc, &x| acc + x)
/ F::from_usize(inter_cluster_distances.len()).expect("Operation failed");
if avg_distance < min_inter_cluster_distance {
min_inter_cluster_distance = avg_distance;
}
}
}
}
let b = min_inter_cluster_distance;
let silhouette = if a == b {
F::zero()
} else {
(b - a) / a.max(b)
};
total_silhouette = total_silhouette + silhouette;
}
total_silhouette / F::from_usize(n_observations).expect("Operation failed")
}
#[allow(dead_code)]
fn euclidean_distance<F: Float + FromPrimitive>(point1: ArrayView1<F>, point2: ArrayView1<F>) -> F {
let mut sum = F::zero();
for (a, b) in point1.iter().zip(point2.iter()) {
let diff = *a - *b;
sum = sum + diff * diff;
}
sum.sqrt()
}
#[allow(dead_code)]
pub fn prune_clusters<F: Float + FromPrimitive + Debug + PartialOrd>(
clusters: &Array1<usize>,
data: ArrayView2<F>,
min_cluster_size: usize,
max_merge_distance: F,
) -> Result<Array1<usize>> {
let n_observations = clusters.len();
if n_observations != data.shape()[0] {
return Err(ClusteringError::InvalidInput(
"Cluster assignments and data dimensions don't match".to_string(),
));
}
let mut cluster_sizes = HashMap::new();
for &cluster_id in clusters.iter() {
*cluster_sizes.entry(cluster_id).or_insert(0) += 1;
}
let small_clusters: Vec<_> = cluster_sizes
.iter()
.filter_map(|(&cluster_id, &_size)| {
if _size < min_cluster_size {
Some(cluster_id)
} else {
None
}
})
.collect();
if small_clusters.is_empty() {
return Ok(clusters.clone()); }
let mut new_clusters = clusters.clone();
for &small_cluster_id in &small_clusters {
let small_cluster_points: Vec<_> = (0..n_observations)
.filter(|&i| clusters[i] == small_cluster_id)
.collect();
if small_cluster_points.is_empty() {
continue;
}
let mut min_distance = F::infinity();
let mut nearest_large_cluster = None;
for (&large_cluster_id, &_size) in &cluster_sizes {
if _size >= min_cluster_size && large_cluster_id != small_cluster_id {
let large_cluster_points: Vec<_> = (0..n_observations)
.filter(|&i| clusters[i] == large_cluster_id)
.collect();
for &small_point in &small_cluster_points {
for &large_point in &large_cluster_points {
let distance =
euclidean_distance(data.row(small_point), data.row(large_point));
if distance < min_distance {
min_distance = distance;
nearest_large_cluster = Some(large_cluster_id);
}
}
}
}
}
if let Some(target_cluster) = nearest_large_cluster {
if min_distance <= max_merge_distance {
for &point_idx in &small_cluster_points {
new_clusters[point_idx] = target_cluster;
}
}
}
}
let mut cluster_map = HashMap::new();
let mut next_id = 0;
for cluster_id in new_clusters.iter_mut() {
let new_id = cluster_map.entry(*cluster_id).or_insert_with(|| {
let id = next_id;
next_id += 1;
id
});
*cluster_id = *new_id;
}
Ok(new_clusters)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
#[test]
fn test_extract_clusters_by_count() {
let linkage = Array2::from_shape_vec(
(3, 4),
vec![
0.0, 1.0, 0.5, 2.0, 2.0, 3.0, 0.8, 2.0, 4.0, 5.0, 1.2, 4.0, ],
)
.expect("Operation failed");
let clusters = extract_clusters_by_count(linkage.view(), 2).expect("Operation failed");
assert_eq!(clusters.len(), 4);
assert_eq!(clusters[0], clusters[1]);
assert_eq!(clusters[2], clusters[3]);
assert_ne!(clusters[0], clusters[2]);
}
#[test]
fn test_estimate_optimal_clusters() {
let data = Array2::from_shape_vec(
(6, 2),
vec![
0.0, 0.0, 1.0, 0.0, 0.5, 0.5, 10.0, 10.0, 11.0, 10.0, 10.5, 10.5, ],
)
.expect("Operation failed");
let linkage = Array2::from_shape_vec(
(5, 4),
vec![
0.0, 1.0, 1.0, 2.0, 1.0, 2.0, 0.7, 2.0, 3.0, 4.0, 1.0, 2.0, 4.0, 5.0, 0.7, 2.0,
6.0, 7.0, 8.0, 6.0,
],
)
.expect("Operation failed");
let optimal_k = estimate_optimal_clusters(linkage.view(), Some(data.view()), Some(4))
.expect("Operation failed");
assert!(optimal_k >= 1);
assert!(optimal_k <= 4);
}
#[test]
fn test_prune_clusters() {
let clusters = Array1::from_vec(vec![0, 0, 0, 1, 2, 2]); let data = Array2::from_shape_vec(
(6, 2),
vec![
0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 10.0, 10.0, 10.1, 10.1,
],
)
.expect("Operation failed");
let pruned = prune_clusters(&clusters, data.view(), 2, 1.0).expect("Operation failed");
assert_eq!(pruned.len(), 6);
let unique_clusters: std::collections::HashSet<_> = pruned.iter().copied().collect();
assert!(unique_clusters.len() < 3); }
#[test]
fn test_calculate_silhouette_score() {
let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 10.0, 10.0, 11.0, 10.0])
.expect("Operation failed");
let clusters = Array1::from_vec(vec![0, 0, 1, 1]);
let score = calculate_silhouette_score(data.view(), &clusters);
assert!(score > 0.0);
}
}