use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use std::collections::{HashMap, HashSet};
use crate::error::{ClusteringError, Result};
/// One agglomerative merge step in a dendrogram.
///
/// Ids follow the linkage convention used by [`Ward::fit`]: ids below
/// `n_samples` are leaf samples; merge `k` creates internal node id
/// `n_samples + k`, which later merges may reference.
#[derive(Debug, Clone)]
pub struct Merge {
    // Id of the first cluster being merged (leaf or internal node).
    pub cluster_a: usize,
    // Id of the second cluster being merged.
    pub cluster_b: usize,
    // Linkage distance at which the two clusters were joined.
    pub distance: f64,
    // Number of samples contained in the merged cluster.
    pub size: usize,
}
/// Full merge history of an agglomerative clustering run over
/// `n_samples` leaves; `merges` holds `n_samples - 1` steps in the order
/// they were performed.
#[derive(Debug, Clone)]
pub struct Dendrogram {
    // Merge steps in execution order; merge `k` creates node `n_samples + k`.
    pub merges: Vec<Merge>,
    // Number of original (leaf) samples.
    pub n_samples: usize,
}
impl Dendrogram {
    /// Cut the dendrogram so that exactly `n_clusters` flat clusters remain.
    ///
    /// Applies the first `n_samples - n_clusters` merges and labels each
    /// sample by its resulting component, with labels densely numbered in
    /// order of first appearance. Merge `k` is assumed to create internal
    /// node `n_samples + k`, matching the ids produced by [`Ward::fit`].
    ///
    /// # Errors
    /// `ClusteringError::InvalidInput` when `n_clusters` is 0 or larger
    /// than the number of samples.
    pub fn cut_at_n_clusters(&self, n_clusters: usize) -> Result<Vec<usize>> {
        let n = self.n_samples;
        if n_clusters == 0 || n_clusters > n {
            return Err(ClusteringError::InvalidInput(format!(
                "n_clusters must be in 1..={n}, got {n_clusters}"
            )));
        }
        let n_merges_to_do = n.saturating_sub(n_clusters);
        // One union-find slot per leaf and per internal node.
        let mut parent: Vec<usize> = (0..(2 * n - 1)).collect();
        for (k, merge) in self.merges.iter().take(n_merges_to_do).enumerate() {
            // Bug fix: link BOTH merged components under the internal node
            // this merge creates (`n + k`). The previous code only linked
            // `b_root` under `a_root`, so when a later merge referenced the
            // internal node id, that id resolved to itself and was
            // disconnected from the samples it contains, yielding too many
            // clusters after the cut.
            let new_node = n + k;
            let a_root = find_root(&parent, merge.cluster_a);
            let b_root = find_root(&parent, merge.cluster_b);
            parent[a_root] = new_node;
            parent[b_root] = new_node;
        }
        Ok(Self::dense_labels(&parent, n))
    }

    /// Cut the dendrogram at a distance threshold: every merge with
    /// `distance <= max_distance` is applied.
    ///
    /// Merges are assumed to be in non-decreasing distance order (true for
    /// Ward linkage, which is monotone), so iteration stops at the first
    /// merge above the threshold.
    pub fn cut_at_distance(&self, max_distance: f64) -> Result<Vec<usize>> {
        let n = self.n_samples;
        if n == 0 {
            // Guard: `2 * n - 1` below would underflow on an empty tree.
            return Ok(Vec::new());
        }
        let mut parent: Vec<usize> = (0..(2 * n - 1)).collect();
        for (k, merge) in self.merges.iter().enumerate() {
            if merge.distance > max_distance {
                break;
            }
            // Same linkage fix as in `cut_at_n_clusters`: both roots are
            // parented under the internal node created by this merge.
            let new_node = n + k;
            let a_root = find_root(&parent, merge.cluster_a);
            let b_root = find_root(&parent, merge.cluster_b);
            parent[a_root] = new_node;
            parent[b_root] = new_node;
        }
        Ok(Self::dense_labels(&parent, n))
    }

    /// Map each sample's union-find root to a dense label in `0..k`,
    /// numbering clusters in order of first appearance over samples `0..n`.
    fn dense_labels(parent: &[usize], n: usize) -> Vec<usize> {
        let mut id_map: HashMap<usize, usize> = HashMap::new();
        let mut next_label = 0usize;
        (0..n)
            .map(|i| {
                let root = find_root(parent, i);
                *id_map.entry(root).or_insert_with(|| {
                    let l = next_label;
                    next_label += 1;
                    l
                })
            })
            .collect()
    }
}
/// Follow parent links until reaching a self-parented node (the set root).
/// No path compression is performed, so `parent` can remain immutable.
fn find_root(parent: &[usize], x: usize) -> usize {
    let mut node = x;
    loop {
        let up = parent[node];
        if up == node {
            return node;
        }
        node = up;
    }
}
/// Output of an agglomerative clustering run.
#[derive(Debug, Clone)]
pub struct HierarchicalResult {
    // Flat cluster label per sample, densely numbered from 0.
    pub labels: Array1<usize>,
    // Number of clusters requested for the flat cut.
    pub n_clusters: usize,
    // Full merge history; can be re-cut at other levels or distances.
    pub dendrogram: Dendrogram,
    // Total within-cluster sum of squared distances for the flat labels.
    pub inertia: f64,
}
/// Agglomerative hierarchical clustering with Ward's minimum-variance
/// linkage, implemented naively over explicit centroids (O(n^3) time,
/// O(n^2) memory).
pub struct Ward;

impl Ward {
    /// Cluster the rows of `x` into `n_clusters` groups.
    ///
    /// Builds the complete dendrogram (all `n_samples - 1` merges), then
    /// cuts it at `n_clusters` and computes the resulting inertia.
    ///
    /// # Errors
    /// `ClusteringError::InvalidInput` for empty data or an out-of-range
    /// `n_clusters`.
    pub fn fit(x: ArrayView2<f64>, n_clusters: usize) -> Result<HierarchicalResult> {
        let (n_samples, n_features) = (x.shape()[0], x.shape()[1]);
        if n_samples == 0 {
            return Err(ClusteringError::InvalidInput("Empty input data".into()));
        }
        if n_clusters == 0 || n_clusters > n_samples {
            return Err(ClusteringError::InvalidInput(format!(
                "n_clusters must be in 1..={n_samples}, got {n_clusters}"
            )));
        }
        // Node storage covers every cluster that will ever exist:
        // n_samples leaves plus n_samples - 1 internal merge nodes.
        let capacity = 2 * n_samples - 1;
        let mut all_centroids: Vec<Vec<f64>> = Vec::with_capacity(capacity);
        let mut all_sizes: Vec<f64> = Vec::with_capacity(capacity);
        for i in 0..n_samples {
            all_centroids.push(x.row(i).to_vec());
            all_sizes.push(1.0);
        }
        // Symmetric matrix of Ward distances between current clusters;
        // it is grown by one row/column per merge.
        let mut ward_dist: Vec<Vec<f64>> = vec![vec![0.0; n_samples]; n_samples];
        for i in 0..n_samples {
            for j in (i + 1)..n_samples {
                let d = ward_dist_between(&all_centroids[i], &all_centroids[j], 1.0, 1.0);
                ward_dist[i][j] = d;
                ward_dist[j][i] = d;
            }
        }
        let mut merges: Vec<Merge> = Vec::with_capacity(n_samples - 1);
        let mut active_ids: Vec<usize> = (0..n_samples).collect();
        // Internal node ids continue after the leaves, so merge k creates
        // node `n_samples + k` — the convention `Dendrogram` relies on.
        let mut next_node = n_samples;
        for _ in 0..(n_samples - 1) {
            // Exhaustive scan for the closest active pair (O(n^2) per merge).
            let n_active = active_ids.len();
            let mut min_dist = f64::INFINITY;
            let mut best_ai = 0usize;
            let mut best_aj = 1usize;
            for a in 0..n_active {
                for b in (a + 1)..n_active {
                    let ia = active_ids[a];
                    let ib = active_ids[b];
                    let d = ward_dist[ia][ib];
                    if d < min_dist {
                        min_dist = d;
                        best_ai = a;
                        best_aj = b;
                    }
                }
            }
            let ia = active_ids[best_ai];
            let ib = active_ids[best_aj];
            let sa = all_sizes[ia];
            let sb = all_sizes[ib];
            let new_size = sa + sb;
            // Size-weighted mean of the two merged centroids.
            let new_centroid: Vec<f64> = (0..n_features)
                .map(|k| (all_centroids[ia][k] * sa + all_centroids[ib][k] * sb) / new_size)
                .collect();
            merges.push(Merge {
                cluster_a: ia,
                cluster_b: ib,
                distance: min_dist,
                size: new_size as usize,
            });
            let new_id = next_node;
            next_node += 1;
            all_centroids.push(new_centroid.clone());
            all_sizes.push(new_size);
            // Grow the distance matrix to make room for the new node;
            // stale entries for ia/ib remain but are never read again
            // because those ids leave `active_ids` below.
            let current_len = ward_dist.len();
            for row in ward_dist.iter_mut() {
                row.push(0.0);
            }
            ward_dist.push(vec![0.0; current_len + 1]);
            // Distances from the new cluster to every other active cluster,
            // computed directly from centroids and sizes (equivalent to the
            // Lance-Williams update for Ward linkage).
            let remaining: Vec<usize> = active_ids
                .iter()
                .enumerate()
                .filter(|&(idx, _)| idx != best_ai && idx != best_aj)
                .map(|(_, &id)| id)
                .collect();
            for &ik in &remaining {
                let sk = all_sizes[ik];
                let d = ward_dist_between(&new_centroid, &all_centroids[ik], new_size, sk);
                ward_dist[new_id][ik] = d;
                ward_dist[ik][new_id] = d;
            }
            // Remove the higher positional index first so the lower one
            // stays valid, then activate the freshly created node.
            let remove_high = best_ai.max(best_aj);
            let remove_low = best_ai.min(best_aj);
            active_ids.remove(remove_high);
            active_ids.remove(remove_low);
            active_ids.push(new_id);
        }
        let dendrogram = Dendrogram { merges, n_samples };
        let label_vec = dendrogram.cut_at_n_clusters(n_clusters)?;
        let labels = Array1::from_vec(label_vec.clone());
        let inertia = compute_inertia(x, &label_vec, n_clusters);
        Ok(HierarchicalResult {
            labels,
            n_clusters,
            dendrogram,
            inertia,
        })
    }
}
/// Ward linkage distance between two clusters given their centroids and
/// sizes: `(|A||B| / (|A| + |B|)) * squared_euclidean(ca, cb)` — the
/// increase in total within-cluster sum of squares caused by merging the
/// two clusters. Returns 0.0 when both sizes are zero.
fn ward_dist_between(ca: &[f64], cb: &[f64], sa: f64, sb: f64) -> f64 {
    let total = sa + sb;
    if total == 0.0 {
        return 0.0;
    }
    let mut sq_dist = 0.0;
    for (a, b) in ca.iter().zip(cb.iter()) {
        let diff = a - b;
        sq_dist += diff * diff;
    }
    sq_dist * ((sa * sb) / total)
}
/// Sum of squared Euclidean distances from each sample to the centroid of
/// its assigned cluster. Samples whose label is `>= n_clusters` are
/// skipped in both the centroid computation and the final sum; empty
/// clusters keep an all-zero centroid that is never referenced.
fn compute_inertia(x: ArrayView2<f64>, labels: &[usize], n_clusters: usize) -> f64 {
    let n_samples = x.shape()[0];
    let n_features = x.shape()[1];
    // First pass: per-cluster feature sums and member counts.
    let mut centroids = vec![vec![0.0f64; n_features]; n_clusters];
    let mut counts = vec![0usize; n_clusters];
    for i in 0..n_samples {
        let c = labels[i];
        if c >= n_clusters {
            continue;
        }
        counts[c] += 1;
        for j in 0..n_features {
            centroids[c][j] += x[[i, j]];
        }
    }
    // Convert sums into means in place.
    for (acc, &count) in centroids.iter_mut().zip(counts.iter()) {
        if count > 0 {
            for v in acc.iter_mut() {
                *v /= count as f64;
            }
        }
    }
    // Second pass: accumulate squared distances to the owning centroid.
    let mut total = 0.0;
    for i in 0..n_samples {
        let c = labels[i];
        if c < n_clusters {
            total += centroids[c]
                .iter()
                .zip(x.row(i).iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum::<f64>();
        }
    }
    total
}
/// Output of a divisive (DIANA-style) clustering run.
#[derive(Debug, Clone)]
pub struct DivisiveResult {
    // Flat cluster label per sample.
    pub labels: Array1<usize>,
    // Number of clusters actually produced (may be fewer than requested
    // when no cluster can be split further).
    pub n_clusters: usize,
    // Sizes `(|A|, |B|)` of the two groups produced by each split, in order.
    pub split_history: Vec<(usize, usize)>,
    // Total within-cluster sum of squared distances for the final labels.
    pub inertia: f64,
}
/// Divisive hierarchical clustering: start from one all-inclusive cluster
/// and repeatedly split the cluster with the largest average diameter
/// using the DIANA splinter-group heuristic.
pub struct Divisive;

impl Divisive {
    /// Split the rows of `x` into (up to) `n_clusters` groups.
    ///
    /// The reported `n_clusters` in the result is the count actually
    /// achieved, which can be lower than requested only in degenerate
    /// cases (e.g. duplicate points that cannot be separated further).
    ///
    /// # Errors
    /// `ClusteringError::InvalidInput` for empty data or an out-of-range
    /// `n_clusters`; errors from `diana_split` are propagated.
    pub fn fit(x: ArrayView2<f64>, n_clusters: usize) -> Result<DivisiveResult> {
        let n_samples = x.shape()[0];
        if n_samples == 0 {
            return Err(ClusteringError::InvalidInput("Empty input data".into()));
        }
        if n_clusters == 0 || n_clusters > n_samples {
            return Err(ClusteringError::InvalidInput(format!(
                "n_clusters must be in 1..={n_samples}, got {n_clusters}"
            )));
        }
        let mut clusters: Vec<Vec<usize>> = vec![(0..n_samples).collect()];
        let mut split_history: Vec<(usize, usize)> = Vec::new();
        let dist = precompute_distances(x);
        while clusters.len() < n_clusters {
            let mut split_idx = find_cluster_to_split(&clusters, &dist);
            if clusters[split_idx].len() < 2 {
                // Bug fix: the max-average-diameter pick can be a singleton
                // (every diameter is 0 when the data contains duplicates),
                // and the old code aborted splitting in that case even
                // though multi-member clusters were still splittable. Fall
                // back to any cluster that can actually be split.
                match clusters.iter().position(|c| c.len() >= 2) {
                    Some(idx) => split_idx = idx,
                    None => break, // only singletons remain; nothing to split
                }
            }
            let old_cluster = clusters.remove(split_idx);
            let (group_a, group_b) = diana_split(&old_cluster, &dist)?;
            split_history.push((group_a.len(), group_b.len()));
            clusters.push(group_a);
            clusters.push(group_b);
        }
        // Label every sample by the index of its final cluster.
        let mut labels = vec![0usize; n_samples];
        for (cluster_id, cluster) in clusters.iter().enumerate() {
            for &idx in cluster {
                labels[idx] = cluster_id;
            }
        }
        let actual_n = clusters.len();
        let inertia = compute_inertia(x, &labels, actual_n);
        Ok(DivisiveResult {
            labels: Array1::from_vec(labels),
            n_clusters: actual_n,
            split_history,
            inertia,
        })
    }
}
fn precompute_distances(x: ArrayView2<f64>) -> Vec<Vec<f64>> {
let n = x.shape()[0];
let mut dist = vec![vec![0.0f64; n]; n];
for i in 0..n {
for j in (i + 1)..n {
let d: f64 = x.row(i)
.iter()
.zip(x.row(j).iter())
.map(|(a, b)| (a - b) * (a - b))
.sum::<f64>()
.sqrt();
dist[i][j] = d;
dist[j][i] = d;
}
}
dist
}
/// Index of the cluster with the largest average diameter. Ties keep the
/// earliest cluster; an all-singleton list yields index 0 (every diameter
/// is 0, which still beats the -1 sentinel).
fn find_cluster_to_split(clusters: &[Vec<usize>], dist: &[Vec<f64>]) -> usize {
    let mut best = (0usize, -1.0f64);
    for (idx, members) in clusters.iter().enumerate() {
        let diam = average_diameter(members, dist);
        if diam > best.1 {
            best = (idx, diam);
        }
    }
    best.0
}

/// Mean pairwise distance over all unordered pairs in `cluster`;
/// 0.0 for clusters with fewer than two members.
fn average_diameter(cluster: &[usize], dist: &[Vec<f64>]) -> f64 {
    let n = cluster.len();
    if n < 2 {
        return 0.0;
    }
    let mut sum = 0.0;
    let mut pairs = 0u64;
    for (pos, &i) in cluster.iter().enumerate() {
        for &j in &cluster[pos + 1..] {
            sum += dist[i][j];
            pairs += 1;
        }
    }
    // `pairs` is at least 1 here because n >= 2.
    sum / pairs as f64
}
/// Split one cluster into two using the DIANA "splinter group" heuristic:
/// seed the splinter with the object having the largest average
/// dissimilarity to its cluster mates, then repeatedly move over every
/// object that is, on average, closer to the splinter group than to the
/// remaining main party.
///
/// Returns `(main_party, splinter_group)`. The main-party ordering comes
/// from `HashSet` iteration and is therefore unspecified; callers only use
/// the groups as sets.
///
/// # Errors
/// `ClusteringError::InvalidInput` if `cluster` has fewer than 2 elements.
fn diana_split(cluster: &[usize], dist: &[Vec<f64>]) -> Result<(Vec<usize>, Vec<usize>)> {
    if cluster.len() < 2 {
        return Err(ClusteringError::InvalidInput(
            "Cannot split a cluster with fewer than 2 elements".into(),
        ));
    }
    // Average dissimilarity of each object to every other object.
    let avg_diss: Vec<f64> = cluster
        .iter()
        .map(|&i| {
            let sum: f64 = cluster.iter().filter(|&&j| j != i).map(|&j| dist[i][j]).sum();
            if cluster.len() <= 1 {
                // Unreachable (len >= 2 was checked above); kept as a guard.
                0.0
            } else {
                sum / (cluster.len() - 1) as f64
            }
        })
        .collect();
    // The most "dissatisfied" object seeds the splinter group.
    let splinter_local_idx = avg_diss
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i)
        .unwrap_or(0);
    let mut main_party: HashSet<usize> = cluster.iter().cloned().collect();
    let mut splinter_group: Vec<usize> = vec![cluster[splinter_local_idx]];
    main_party.remove(&cluster[splinter_local_idx]);
    loop {
        let main_vec: Vec<usize> = main_party.iter().cloned().collect();
        let sg_len = splinter_group.len() as f64;
        let main_len = main_vec.len() as f64;
        if main_len == 0.0 {
            break;
        }
        // Batch all moves for this round before applying them, so every
        // object's decision uses the same snapshot of both groups (this
        // also makes the outcome independent of HashSet iteration order).
        let mut to_move: Vec<usize> = Vec::new();
        for &obj in &main_vec {
            // Mean distance to the splinter group...
            let d_sg = splinter_group.iter().map(|&s| dist[obj][s]).sum::<f64>() / sg_len;
            // ...versus mean distance to the other main-party members.
            // The denominator is floored at 1 so a lone remaining member
            // yields 0 rather than NaN (and is then never moved).
            let other_main_len = (main_len - 1.0).max(1.0);
            let d_main = main_vec
                .iter()
                .filter(|&&o| o != obj)
                .map(|&o| dist[obj][o])
                .sum::<f64>()
                / other_main_len;
            if d_sg < d_main {
                to_move.push(obj);
            }
        }
        if to_move.is_empty() {
            break;
        }
        for obj in to_move {
            main_party.remove(&obj);
            splinter_group.push(obj);
        }
    }
    // Degenerate outcome (defensive; the loop above cannot actually empty
    // either side): fall back to a positional half split.
    if splinter_group.is_empty() || main_party.is_empty() {
        let half = cluster.len() / 2;
        return Ok((cluster[..half].to_vec(), cluster[half..].to_vec()));
    }
    let main_vec: Vec<usize> = main_party.into_iter().collect();
    Ok((main_vec, splinter_group))
}
/// Pairwise co-clustering statistics accumulated over resampled runs.
#[derive(Debug, Clone)]
pub struct ConsensusMatrix {
    // Consensus values in [0, 1]: cooccurrence / selection per pair, or
    // 0 for pairs never drawn together.
    pub matrix: Array2<f64>,
    // Number of resamples in which each pair landed in the same cluster.
    pub cooccurrence: Array2<f64>,
    // Number of resamples in which each pair was drawn together.
    pub selection: Array2<f64>,
    // Number of samples in the original data set.
    pub n_samples: usize,
    // Number of resampling rounds performed.
    pub n_resamples: usize,
}
impl ConsensusMatrix {
    /// Consensus value for the sample pair `(i, j)`.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        self.matrix[[i, j]]
    }

    /// Derive a flat clustering from the consensus matrix by running Ward
    /// agglomeration on the rows of the dissimilarity matrix
    /// `1 - consensus` (each sample is represented by its vector of
    /// dissimilarities to all samples). Consensus values are clamped to
    /// [0, 1] before conversion.
    pub fn extract_clusters(&self, n_clusters: usize) -> Result<Vec<usize>> {
        let n = self.n_samples;
        let dissimilarity = Array2::from_shape_fn((n, n), |(i, j)| {
            1.0 - self.matrix[[i, j]].clamp(0.0, 1.0)
        });
        let fitted = Ward::fit(dissimilarity.view(), n_clusters)?;
        Ok(fitted.labels.to_vec())
    }
}
/// A clustering algorithm usable as the base learner for consensus
/// clustering: given a data matrix, return one label per row.
pub trait BaseClusterer: Send + Sync {
    fn fit_predict(&self, x: ArrayView2<f64>) -> Result<Vec<usize>>;
}
/// Configuration for consensus (resampling-based) clustering.
pub struct ConsensusClustering {
    // Fraction of samples drawn per resample; clamped to [0.1, 1.0] by `new`.
    pub subsample_fraction: f64,
    // Seed for the internal deterministic RNG; `None` falls back to a
    // fixed default seed (42), so runs are reproducible either way.
    pub seed: Option<u64>,
}
impl Default for ConsensusClustering {
    /// Default configuration: 80% subsampling, no explicit seed.
    fn default() -> Self {
        Self {
            subsample_fraction: 0.8,
            seed: None,
        }
    }
}
impl ConsensusClustering {
    /// Build a consensus-clustering configuration; `subsample_fraction`
    /// is clamped into [0.1, 1.0].
    pub fn new(subsample_fraction: f64, seed: Option<u64>) -> Self {
        Self {
            subsample_fraction: subsample_fraction.clamp(0.1, 1.0),
            seed,
        }
    }

    /// Run `base_clusterer` on `n_resamples` random subsamples of `x` and
    /// record, for every pair of samples, how often they landed in the
    /// same cluster relative to how often they were drawn together.
    ///
    /// # Errors
    /// `ClusteringError::InvalidInput` for empty data or `n_resamples == 0`;
    /// any error from the base clusterer is propagated.
    pub fn fit(
        &self,
        x: ArrayView2<f64>,
        base_clusterer: &dyn BaseClusterer,
        n_resamples: usize,
    ) -> Result<ConsensusMatrix> {
        let n_samples = x.shape()[0];
        if n_samples == 0 {
            return Err(ClusteringError::InvalidInput("Empty input data".into()));
        }
        if n_resamples == 0 {
            return Err(ClusteringError::InvalidInput(
                "n_resamples must be at least 1".into(),
            ));
        }
        let mut cooccurrence = Array2::<f64>::zeros((n_samples, n_samples));
        let mut selection = Array2::<f64>::zeros((n_samples, n_samples));
        // At least two points are needed for the base clusterer to produce
        // anything meaningful.
        let subsample_size =
            ((n_samples as f64 * self.subsample_fraction).ceil() as usize).max(2);
        let mut rng_state = self.seed.unwrap_or(42u64);
        for _ in 0..n_resamples {
            let chosen =
                lcg_sample_without_replacement(&mut rng_state, n_samples, subsample_size);
            let sub_labels = base_clusterer.fit_predict(build_submatrix(x, &chosen).view())?;
            // Tally every ordered pair (diagonal included) of the drawn
            // samples; symmetry falls out of iterating both orders.
            for (a, &ia) in chosen.iter().enumerate() {
                for (b, &ib) in chosen.iter().enumerate() {
                    selection[[ia, ib]] += 1.0;
                    if sub_labels[a] == sub_labels[b] {
                        cooccurrence[[ia, ib]] += 1.0;
                    }
                }
            }
        }
        // Normalize: pairs never drawn together keep a consensus of 0.
        let matrix = Array2::from_shape_fn((n_samples, n_samples), |(i, j)| {
            let times_selected = selection[[i, j]];
            if times_selected > 0.0 {
                cooccurrence[[i, j]] / times_selected
            } else {
                0.0
            }
        });
        Ok(ConsensusMatrix {
            matrix,
            cooccurrence,
            selection,
            n_samples,
            n_resamples,
        })
    }
}
/// Draw `min(k, n)` distinct indices from `0..n` via a partial
/// Fisher-Yates shuffle driven by a 64-bit LCG (Knuth's MMIX constants).
/// Deterministic for a given `state`, which is advanced once per drawn
/// index. Note: `% (n - i)` has a small modulo bias, which is acceptable
/// for resampling purposes.
fn lcg_sample_without_replacement(state: &mut u64, n: usize, k: usize) -> Vec<usize> {
    let take = k.min(n);
    let mut pool: Vec<usize> = (0..n).collect();
    for i in 0..take {
        // Advance the LCG, then pick a slot in the not-yet-fixed suffix.
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        let pick = i + (*state as usize % (n - i));
        pool.swap(i, pick);
    }
    pool.truncate(take);
    pool
}
/// Copy the rows of `x` named by `indices` (in order) into a fresh matrix.
fn build_submatrix(x: ArrayView2<f64>, indices: &[usize]) -> Array2<f64> {
    let mut sub = Array2::<f64>::zeros((indices.len(), x.shape()[1]));
    for (dst, &src) in indices.iter().enumerate() {
        sub.row_mut(dst).assign(&x.row(src));
    }
    sub
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// 10 points in 2-D forming two tight, well-separated groups:
    /// five near the origin and five near (5, 5).
    fn two_cluster_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (10, 2),
            vec![
                0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.1,
                5.0, 5.1, 5.2, 5.0, 5.1, 5.2, 5.0, 5.1, 5.2, 5.0,
            ],
        )
        .expect("valid shape")
    }

    /// 12 points in four tight corner groups; despite the name, cutting at
    /// 3 clusters merely merges two of the four groups, so 3 distinct
    /// labels are still expected.
    fn three_cluster_data() -> Array2<f64> {
        Array2::from_shape_vec(
            (12, 2),
            vec![
                0.0, 0.0, 0.1, 0.1, 0.2, 0.0,
                5.0, 0.0, 5.1, 0.1, 5.2, 0.0,
                0.0, 5.0, 0.1, 5.1, 0.2, 5.0,
                5.0, 5.0, 5.1, 5.1, 5.2, 5.0,
            ],
        )
        .expect("valid shape")
    }

    // Ward on clearly separated data must recover exactly two clusters.
    #[test]
    fn test_ward_basic_two_clusters() {
        let data = two_cluster_data();
        let result = Ward::fit(data.view(), 2).expect("ward fit");
        assert_eq!(result.labels.len(), 10);
        assert_eq!(result.n_clusters, 2);
        let unique: std::collections::HashSet<usize> = result.labels.iter().cloned().collect();
        assert_eq!(unique.len(), 2);
    }

    // A single-cluster cut puts every sample in label 0.
    #[test]
    fn test_ward_single_cluster() {
        let data = two_cluster_data();
        let result = Ward::fit(data.view(), 1).expect("ward fit 1 cluster");
        assert!(result.labels.iter().all(|&l| l == 0));
    }

    // The requested number of clusters equals the distinct label count.
    #[test]
    fn test_ward_returns_correct_number_of_clusters() {
        let data = three_cluster_data();
        let result = Ward::fit(data.view(), 3).expect("ward fit 3 clusters");
        let unique: std::collections::HashSet<usize> = result.labels.iter().cloned().collect();
        assert_eq!(unique.len(), 3);
    }

    // Re-cutting an existing dendrogram at a finer level must work
    // without refitting.
    #[test]
    fn test_dendrogram_cut_at_n_clusters() {
        let data = two_cluster_data();
        let result = Ward::fit(data.view(), 1).expect("ward fit");
        let labels2 = result.dendrogram.cut_at_n_clusters(2).expect("cut 2");
        assert_eq!(labels2.len(), 10);
        let unique: std::collections::HashSet<usize> = labels2.iter().cloned().collect();
        assert_eq!(unique.len(), 2);
    }

    // A threshold above every merge distance collapses everything into
    // one cluster.
    #[test]
    fn test_dendrogram_cut_at_distance() {
        let data = two_cluster_data();
        let result = Ward::fit(data.view(), 1).expect("ward fit");
        let labels = result.dendrogram.cut_at_distance(1e9).expect("cut dist");
        let unique: std::collections::HashSet<usize> = labels.iter().cloned().collect();
        assert_eq!(unique.len(), 1);
    }

    // DIANA splitting on separated data also recovers two clusters.
    #[test]
    fn test_divisive_basic() {
        let data = two_cluster_data();
        let result = Divisive::fit(data.view(), 2).expect("divisive fit");
        assert_eq!(result.labels.len(), 10);
        let unique: std::collections::HashSet<usize> = result.labels.iter().cloned().collect();
        assert_eq!(unique.len(), 2);
    }

    // With one cluster requested, no split happens at all.
    #[test]
    fn test_divisive_single_cluster() {
        let data = two_cluster_data();
        let result = Divisive::fit(data.view(), 1).expect("divisive 1 cluster");
        assert!(result.labels.iter().all(|&l| l == 0));
    }

    // Consensus clustering with a deterministic stand-in base clusterer:
    // shapes are right and every sample always co-clusters with itself.
    #[test]
    fn test_consensus_clustering_basic() {
        // Toy clusterer: sort by the first feature and assign labels in
        // rank blocks — deterministic, no real k-means needed.
        struct SimpleKMeans {
            k: usize,
        }
        impl BaseClusterer for SimpleKMeans {
            fn fit_predict(&self, x: ArrayView2<f64>) -> Result<Vec<usize>> {
                let n = x.shape()[0];
                if n == 0 {
                    return Ok(vec![]);
                }
                let mut vals: Vec<(f64, usize)> = (0..n).map(|i| (x[[i, 0]], i)).collect();
                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
                let mut labels = vec![0usize; n];
                let half = n / self.k.max(1);
                for (rank, (_, orig)) in vals.iter().enumerate() {
                    labels[*orig] = (rank / half.max(1)).min(self.k - 1);
                }
                Ok(labels)
            }
        }
        let data = two_cluster_data();
        let clusterer = SimpleKMeans { k: 2 };
        let cc = ConsensusClustering::new(0.8, Some(7));
        let result = cc.fit(data.view(), &clusterer, 10).expect("consensus fit");
        assert_eq!(result.n_samples, 10);
        assert_eq!(result.n_resamples, 10);
        assert_eq!(result.matrix.shape(), [10, 10]);
        for i in 0..10 {
            assert!(
                (result.matrix[[i, i]] - 1.0).abs() < 1e-9,
                "diagonal must be 1"
            );
        }
    }

    // End-to-end: consensus matrix -> Ward cut returns one label per sample.
    #[test]
    fn test_consensus_extract_clusters() {
        // Degenerate clusterer that alternates labels by row parity.
        struct TrivialClusterer;
        impl BaseClusterer for TrivialClusterer {
            fn fit_predict(&self, x: ArrayView2<f64>) -> Result<Vec<usize>> {
                let n = x.shape()[0];
                Ok((0..n).map(|i| i % 2).collect())
            }
        }
        let data = two_cluster_data();
        let clusterer = TrivialClusterer;
        let cc = ConsensusClustering::new(1.0, Some(3));
        let result = cc.fit(data.view(), &clusterer, 5).expect("consensus fit");
        let labels = result.extract_clusters(2).expect("extract clusters");
        assert_eq!(labels.len(), 10);
    }

    // Out-of-range cluster counts are rejected, not clamped.
    #[test]
    fn test_ward_invalid_n_clusters() {
        let data = two_cluster_data();
        assert!(Ward::fit(data.view(), 0).is_err());
        assert!(Ward::fit(data.view(), 100).is_err());
    }
}