use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
use crate::hierarchy::{coords_to_condensed_index, LinkageMethod};
/// A group of observations tracked while the linkage matrix is being built.
///
/// `hierarchical_clustering` creates one singleton per observation and appends
/// a new `Cluster` for every merge it performs.
#[derive(Debug)]
pub(crate) struct Cluster {
/// Number of observations counted in this cluster (the sum of the two merged
/// clusters' sizes, or 1 for a singleton).
pub size: usize,
/// Observation indices belonging to this cluster; the linkage helpers use
/// them to look up pairwise distances in the condensed distance vector.
pub members: Vec<usize>,
}
pub(crate) fn hierarchical_clustering<F: Float + FromPrimitive + Debug + PartialOrd>(
distances: &Array1<F>,
n_samples: usize,
method: LinkageMethod,
) -> Result<Array2<F>> {
let mut clusters: Vec<Cluster> = (0..n_samples)
.map(|i| Cluster {
size: 1,
members: vec![i],
})
.collect();
let mut linkage_matrix = Array2::zeros((n_samples - 1, 4));
let mut activeclusters: Vec<usize> = (0..n_samples).collect();
let mut centroids: Option<Array2<F>> = None;
if matches!(method, LinkageMethod::Centroid | LinkageMethod::Median) {
centroids = Some(Array2::from_elem(
(2 * n_samples - 1, distances.len()),
F::zero(),
));
}
for i in 0..(n_samples - 1) {
let (cluster1_idx, cluster2_idx, min_dist) = find_closestclusters(
&activeclusters,
&clusters,
distances,
method,
centroids.as_ref(),
n_samples,
)?;
let cluster1 = activeclusters[cluster1_idx];
let cluster2 = activeclusters[cluster2_idx];
let (cluster1, cluster2) = if cluster1 < cluster2 {
(cluster1, cluster2)
} else {
(cluster2, cluster1)
};
let new_cluster_id = n_samples + i;
let mut new_members = clusters[cluster1].members.clone();
new_members.extend(clusters[cluster2].members.clone());
let new_cluster = Cluster {
size: clusters[cluster1].size + clusters[cluster2].size,
members: new_members,
};
if let Some(ref mut cents) = centroids {
update_centroid(cents, method, n_samples, new_cluster_id);
}
clusters.push(new_cluster);
activeclusters.remove(cluster1_idx.max(cluster2_idx));
activeclusters.remove(cluster1_idx.min(cluster2_idx));
activeclusters.push(new_cluster_id);
linkage_matrix[[i, 0]] = F::from_usize(cluster1).expect("Operation failed");
linkage_matrix[[i, 1]] = F::from_usize(cluster2).expect("Operation failed");
linkage_matrix[[i, 2]] = min_dist;
linkage_matrix[[i, 3]] =
F::from_usize(clusters[new_cluster_id].size).expect("Operation failed");
}
Ok(linkage_matrix)
}
/// Finds the pair of active clusters with the smallest linkage distance.
///
/// Every unordered pair of entries in `activeclusters` is evaluated with the
/// linkage selected by `method`. The result is `(i, j, dist)` where `i < j` are
/// *positions* inside `activeclusters` (not cluster ids). On ties the pair
/// encountered first wins, because only a strictly smaller distance replaces
/// the current best.
///
/// # Errors
/// * Propagates any error from the linkage helpers.
/// * `ClusteringError::ComputationError` when no pair has a distance strictly
///   below `+inf` (this includes having fewer than two active clusters).
///
/// # Panics
/// Panics if `method` is `Centroid` or `Median` and `centroids` is `None`.
#[allow(dead_code)]
fn find_closestclusters<F: Float + FromPrimitive + Debug + PartialOrd>(
    activeclusters: &[usize],
    clusters: &[Cluster],
    distances: &Array1<F>,
    method: LinkageMethod,
    centroids: Option<&Array2<F>>,
    n_samples: usize,
) -> Result<(usize, usize, F)> {
    let mut best_pair = (0usize, 0usize);
    let mut best_dist = F::infinity();

    for (i, &id_a) in activeclusters.iter().enumerate() {
        // Only look at positions after `i`, so each unordered pair is seen once.
        for (j, &id_b) in activeclusters.iter().enumerate().skip(i + 1) {
            let candidate = match method {
                LinkageMethod::Centroid => {
                    centroid_linkage(id_a, id_b, centroids.expect("Operation failed"))
                }
                LinkageMethod::Median => {
                    median_linkage(id_a, id_b, centroids.expect("Operation failed"))
                }
                LinkageMethod::Single => {
                    single_linkage(&clusters[id_a], &clusters[id_b], distances, n_samples)?
                }
                LinkageMethod::Complete => {
                    complete_linkage(&clusters[id_a], &clusters[id_b], distances, n_samples)?
                }
                LinkageMethod::Average => {
                    average_linkage(&clusters[id_a], &clusters[id_b], distances, n_samples)?
                }
                LinkageMethod::Ward => {
                    ward_linkage(&clusters[id_a], &clusters[id_b], distances, n_samples)?
                }
                LinkageMethod::Weighted => {
                    weighted_linkage(&clusters[id_a], &clusters[id_b], distances, n_samples)?
                }
            };

            // Strict comparison: NaN and +inf candidates never become the best.
            if candidate < best_dist {
                best_dist = candidate;
                best_pair = (i, j);
            }
        }
    }

    // Still at the sentinel value: nothing usable was found.
    if best_dist == F::infinity() {
        return Err(ClusteringError::ComputationError(
            "Could not find minimum distance between clusters".into(),
        ));
    }

    let (first, second) = best_pair;
    Ok((first, second, best_dist))
}
/// Single-linkage distance: the smallest pairwise distance between a member of
/// `cluster1` and a member of `cluster2`.
///
/// Distances are read from the condensed vector `distances` at the index
/// returned by `coords_to_condensed_index(n_samples, lo, hi)`, where `lo`/`hi`
/// are the two member indices in ascending order.
///
/// Returns `+inf` when either member list is empty.
///
/// # Errors
/// Propagates any error from `coords_to_condensed_index`.
///
/// # Panics
/// Panics if a computed index lies outside `distances`.
pub(crate) fn single_linkage<F: Float + PartialOrd>(
    cluster1: &Cluster,
    cluster2: &Cluster,
    distances: &Array1<F>,
    n_samples: usize,
) -> Result<F> {
    // All (a, b) combinations, visited in the same order as nested loops would.
    let member_pairs = cluster1
        .members
        .iter()
        .flat_map(|&a| cluster2.members.iter().map(move |&b| (a, b)));

    let mut nearest = F::infinity();
    for (a, b) in member_pairs {
        let condensed = coords_to_condensed_index(n_samples, a.min(b), a.max(b))?;
        let candidate = distances[condensed];
        if candidate < nearest {
            nearest = candidate;
        }
    }
    Ok(nearest)
}
/// Complete-linkage distance: the largest pairwise distance between a member
/// of `cluster1` and a member of `cluster2`.
///
/// Distances are read from the condensed vector `distances` at the index
/// returned by `coords_to_condensed_index(n_samples, lo, hi)`, where `lo`/`hi`
/// are the two member indices in ascending order.
///
/// Returns `-inf` when either member list is empty.
///
/// # Errors
/// Propagates any error from `coords_to_condensed_index`.
///
/// # Panics
/// Panics if a computed index lies outside `distances`.
pub(crate) fn complete_linkage<F: Float + PartialOrd>(
    cluster1: &Cluster,
    cluster2: &Cluster,
    distances: &Array1<F>,
    n_samples: usize,
) -> Result<F> {
    cluster1
        .members
        .iter()
        .try_fold(F::neg_infinity(), |outer_max, &a| {
            cluster2
                .members
                .iter()
                .try_fold(outer_max, |running_max, &b| -> Result<F> {
                    let condensed = coords_to_condensed_index(n_samples, a.min(b), a.max(b))?;
                    let candidate = distances[condensed];
                    // Strict comparison keeps the running value on ties and NaN.
                    Ok(if candidate > running_max {
                        candidate
                    } else {
                        running_max
                    })
                })
        })
}
/// Average-linkage distance: the arithmetic mean of all pairwise distances
/// between a member of `cluster1` and a member of `cluster2`.
///
/// Distances are read from the condensed vector `distances` at the index
/// returned by `coords_to_condensed_index(n_samples, lo, hi)`, where `lo`/`hi`
/// are the two member indices in ascending order.
///
/// When either member list is empty the result is `0 / 0` in `F`.
///
/// # Errors
/// Propagates any error from `coords_to_condensed_index`.
///
/// # Panics
/// Panics if a computed index lies outside `distances`, or if the pair count
/// cannot be converted to `F`.
pub(crate) fn average_linkage<F: Float + FromPrimitive>(
    cluster1: &Cluster,
    cluster2: &Cluster,
    distances: &Array1<F>,
    n_samples: usize,
) -> Result<F> {
    let mut total = F::zero();
    for &a in cluster1.members.iter() {
        for &b in cluster2.members.iter() {
            let condensed = coords_to_condensed_index(n_samples, a.min(b), a.max(b))?;
            total = total + distances[condensed];
        }
    }

    // One term was added per (a, b) combination, so the number of terms is
    // the product of the two member counts.
    let pair_count = cluster1.members.len() * cluster2.members.len();
    let divisor = F::from_usize(pair_count).expect("Operation failed");
    Ok(total / divisor)
}
/// Ward-style distance between two clusters computed from pairwise distances.
///
/// With `n1 = cluster1.size`, `n2 = cluster2.size` and `S` the sum of squared
/// pairwise distances between the members of the two clusters, the result is
/// `sqrt((n1 * n2 / (n1 + n2)) * (S / (n1 * n2)))`.
///
/// Distances are read from the condensed vector `distances` at the index
/// returned by `coords_to_condensed_index(n_samples, lo, hi)`, where `lo`/`hi`
/// are the two member indices in ascending order.
///
/// # Errors
/// Propagates any error from `coords_to_condensed_index`.
///
/// # Panics
/// Panics if a computed index lies outside `distances`, or if a cluster size
/// cannot be converted to `F`.
pub(crate) fn ward_linkage<F: Float + FromPrimitive>(
    cluster1: &Cluster,
    cluster2: &Cluster,
    distances: &Array1<F>,
    n_samples: usize,
) -> Result<F> {
    let n1 = F::from_usize(cluster1.size).expect("Operation failed");
    let n2 = F::from_usize(cluster2.size).expect("Operation failed");

    let member_pairs = cluster1
        .members
        .iter()
        .flat_map(|&a| cluster2.members.iter().map(move |&b| (a, b)));

    let mut squared_total = F::zero();
    for (a, b) in member_pairs {
        let condensed = coords_to_condensed_index(n_samples, a.min(b), a.max(b))?;
        let d = distances[condensed];
        squared_total = squared_total + d * d;
    }

    let size_product = n1 * n2;
    let mean_squared = squared_total / size_product;
    let scale = size_product / (n1 + n2);
    Ok((scale * mean_squared).sqrt())
}
/// Centroid-linkage lookup: returns the value stored in `centroids` at row
/// `cluster1`, column `cluster2`.
///
/// No distance is computed here; the function only reads the table it is given.
///
/// # Panics
/// Panics if `(cluster1, cluster2)` is outside the bounds of `centroids`.
pub(crate) fn centroid_linkage<F: Float>(
    cluster1: usize,
    cluster2: usize,
    centroids: &Array2<F>,
) -> F {
    let cell = (cluster1, cluster2);
    centroids[cell]
}
/// Median-linkage lookup: returns the value stored in `centroids` at row
/// `cluster1`, column `cluster2`.
///
/// The body performs the same table read as `centroid_linkage`; no
/// median-specific computation happens here.
///
/// # Panics
/// Panics if `(cluster1, cluster2)` is outside the bounds of `centroids`.
pub(crate) fn median_linkage<F: Float>(
    cluster1: usize,
    cluster2: usize,
    centroids: &Array2<F>,
) -> F {
    let cell = (cluster1, cluster2);
    centroids[cell]
}
/// Weighted-linkage distance between two clusters.
///
/// Currently forwards all arguments unchanged to `average_linkage`, so the
/// result is the plain mean of all pairwise member distances.
///
/// NOTE(review): a weighted (WPGMA-style) linkage would normally differ from
/// the unweighted average — confirm that this delegation is intended.
///
/// # Errors
/// Propagates any error returned by `average_linkage`.
pub(crate) fn weighted_linkage<F: Float + FromPrimitive>(
    cluster1: &Cluster,
    cluster2: &Cluster,
    distances: &Array1<F>,
    n_samples: usize,
) -> Result<F> {
    let mean_distance = average_linkage(cluster1, cluster2, distances, n_samples)?;
    Ok(mean_distance)
}
/// Records a newly created cluster in the centroid table.
///
/// The only effect is writing `1` into column 0 of row `new_cluster_id`;
/// `_method` and `n_samples` are accepted but not used by the body.
///
/// NOTE(review): this looks like a placeholder rather than a real centroid
/// update — confirm before relying on Centroid/Median results.
///
/// # Panics
/// Panics if `(new_cluster_id, 0)` is outside the bounds of `centroids`.
pub(crate) fn update_centroid<F: Float + FromPrimitive>(
    centroids: &mut Array2<F>,
    _method: LinkageMethod,
    n_samples: usize,
    new_cluster_id: usize,
) {
    let marker_cell = (new_cluster_id, 0usize);
    centroids[marker_cell] = F::one();
}