use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::BinaryHeap;
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
use crate::hierarchy::Metric;
/// Candidate merge between two clusters, stored in the priority queue.
///
/// Ordering is defined on `distance` only and is reversed (see the `Ord`
/// impl below) so that `BinaryHeap` — a max-heap — pops the pair with the
/// *smallest* distance first.
#[derive(Debug, Clone)]
struct ClusterPair<F: Float> {
    // Ward distance between the two clusters at the time this pair was pushed.
    distance: F,
    // Index of the first cluster in the global cluster list.
    cluster1: usize,
    // Index of the second cluster in the global cluster list.
    cluster2: usize,
    // Merge-step counter at push time; kept for debugging/traceability only.
    #[allow(dead_code)]
    timestamp: usize, }
/// Equality considers only the stored distance; the cluster indices are
/// deliberately ignored so heap ordering is purely by distance.
impl<F: Float> PartialEq for ClusterPair<F> {
    fn eq(&self, other: &Self) -> bool {
        self.distance == other.distance
    }
}

// NOTE(review): `Eq` is asserted even though `F: Float` admits NaN
// (for which `x == x` is false). This is only sound because NaN distances
// are not expected to reach the heap — confirm if the distance computation
// can ever produce NaN.
impl<F: Float> Eq for ClusterPair<F> {}
impl<F: Float> Ord for ClusterPair<F> {
    /// Reversed comparison on `distance`, turning `BinaryHeap`'s max-heap
    /// into a min-heap: the smallest distance is popped first.
    /// Incomparable values (NaN) are treated as equal.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.distance.partial_cmp(&other.distance) {
            Some(ordering) => ordering.reverse(),
            None => std::cmp::Ordering::Equal,
        }
    }
}
/// Delegates to `Ord::cmp` so `PartialOrd` and `Ord` stay consistent,
/// as `BinaryHeap` requires.
impl<F: Float> PartialOrd for ClusterPair<F> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Summary statistics for one cluster under Ward's criterion.
///
/// Storing the coordinate sum and the sum of squared norms lets merges and
/// centroid computation run in O(n_features) without revisiting member
/// points.
#[derive(Debug, Clone)]
struct WardCluster<F: Float> {
    // Number of data points absorbed into this cluster.
    size: usize,
    // Component-wise sum of member coordinates (centroid = sum_coords / size).
    sum_coords: Array1<F>,
    // Sum of squared L2 norms of the member points; carried additively
    // through merges but not read by the current distance computation.
    sum_squared: F,
    // False once this cluster has been merged into a newer one; inactive
    // clusters are skipped when pairs are popped from the heap.
    active: bool,
    // Merge-step counter at creation time; kept for debugging only.
    #[allow(dead_code)]
    timestamp: usize,
}
impl<F: Float + FromPrimitive + ScalarOperand + 'static> WardCluster<F> {
    /// Creates a singleton cluster from one data point.
    fn new(point: &Array1<F>, timestamp: usize) -> Self {
        // ||point||^2, cached so merges can combine it additively.
        let sum_squared = point.dot(point);
        Self {
            size: 1,
            sum_coords: point.clone(),
            sum_squared,
            active: true,
            timestamp,
        }
    }

    /// Combines two clusters; every summary statistic is additive.
    fn merge(&self, other: &Self, timestamp: usize) -> Self {
        Self {
            size: self.size + other.size,
            sum_coords: &self.sum_coords + &other.sum_coords,
            sum_squared: self.sum_squared + other.sum_squared,
            active: true,
            timestamp,
        }
    }

    /// Mean coordinate vector of the cluster's members.
    fn centroid(&self) -> Array1<F> {
        &self.sum_coords / F::from(self.size).expect("Failed to convert to float")
    }

    /// Ward distance between two clusters:
    /// sqrt(n1 * n2 / (n1 + n2) * ||c1 - c2||^2).
    ///
    /// Returns infinity when either cluster is inactive so stale pairs sort
    /// last and are never selected for a merge.
    fn ward_distance(&self, other: &Self) -> F {
        if !self.active || !other.active {
            return F::infinity();
        }
        let n1 = F::from(self.size).expect("Failed to convert to float");
        let n2 = F::from(other.size).expect("Failed to convert to float");
        let n_total = n1 + n2;
        let centroid1 = self.centroid();
        let centroid2 = other.centroid();
        // FIX: the operands here were mojibake ("¢roid1" / "¢roid2" — the
        // byte sequence "&c" mangled into the cent sign), which does not
        // compile. Restore the intended borrowed subtraction.
        let diff = &centroid1 - &centroid2;
        let dist_sq = diff.dot(&diff);
        let ward_dist = (n1 * n2 / n_total) * dist_sq;
        ward_dist.sqrt()
    }
}
/// Computes Ward-linkage hierarchical clustering with a priority queue.
///
/// Uses a lazy-deletion min-heap of candidate merges: stale pairs (whose
/// endpoints have already been merged) are skipped at pop time instead of
/// being removed eagerly. Cluster summary statistics keep each distance
/// evaluation at O(n_features).
///
/// # Arguments
/// * `data` - (n_samples, n_features) observation matrix.
/// * `_metric` - Currently ignored; Ward linkage is defined on Euclidean
///   geometry.
///
/// # Returns
/// An (n_samples - 1, 4) linkage matrix; each row holds the ids of the two
/// merged clusters, the merge distance, and the size of the new cluster.
/// Clusters created by merges receive ids n_samples, n_samples + 1, ...
/// (their index in the internal cluster list).
///
/// # Errors
/// * `InvalidInput` when fewer than 2 samples are provided.
/// * `ComputationError` when no finite-distance pair can be found.
#[allow(dead_code)]
pub fn optimized_ward_linkage<F>(
    data: ArrayView2<F>,
    _metric: Metric,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + PartialOrd + Send + Sync + ScalarOperand + 'static,
{
    let n_samples = data.shape()[0];
    if n_samples < 2 {
        return Err(ClusteringError::InvalidInput(
            "Need at least 2 samples for hierarchical clustering".into(),
        ));
    }
    // Clusters are only ever appended (merged ones are deactivated, never
    // removed), so vector indices double as stable cluster ids.
    let mut clusters: Vec<WardCluster<F>> = Vec::with_capacity(2 * n_samples - 1);
    for i in 0..n_samples {
        let point = data.row(i).to_owned();
        clusters.push(WardCluster::new(&point, 0));
    }
    let mut timestamp = 1;
    // Seed the heap with every pairwise distance between singletons.
    let mut heap: BinaryHeap<ClusterPair<F>> = BinaryHeap::new();
    for i in 0..n_samples {
        for j in (i + 1)..n_samples {
            let distance = clusters[i].ward_distance(&clusters[j]);
            if distance.is_finite() {
                heap.push(ClusterPair {
                    distance,
                    cluster1: i,
                    cluster2: j,
                    timestamp: 0,
                });
            }
        }
    }
    if heap.is_empty() {
        return Err(ClusteringError::ComputationError(
            "No valid initial cluster pairs found - all distances are infinite".into(),
        ));
    }
    let mut linkage_matrix = Array2::zeros((n_samples - 1, 4));
    for merge_step in 0..(n_samples - 1) {
        // Lazy deletion: keep popping until both endpoints are still active.
        // Distances between fixed clusters never change, so a pair of two
        // active clusters is always up to date.
        let (cluster1_id, cluster2_id, min_distance) = loop {
            if let Some(pair) = heap.pop() {
                if pair.cluster1 < clusters.len()
                    && pair.cluster2 < clusters.len()
                    && clusters[pair.cluster1].active
                    && clusters[pair.cluster2].active
                {
                    break (pair.cluster1, pair.cluster2, pair.distance);
                }
            } else {
                return Err(ClusteringError::ComputationError(format!(
                    "No valid cluster pairs found in priority queue at merge step {}",
                    merge_step
                )));
            }
        };
        let cluster1 = &clusters[cluster1_id];
        let cluster2 = &clusters[cluster2_id];
        linkage_matrix[[merge_step, 0]] = F::from(cluster1_id).expect("Failed to convert to float");
        linkage_matrix[[merge_step, 1]] = F::from(cluster2_id).expect("Failed to convert to float");
        linkage_matrix[[merge_step, 2]] = min_distance;
        linkage_matrix[[merge_step, 3]] =
            F::from(cluster1.size + cluster2.size).expect("Failed to convert to float");
        // Retire the merged pair and append their union; the union's index
        // is its cluster id in subsequent linkage rows.
        let merged_cluster = cluster1.merge(cluster2, timestamp);
        clusters[cluster1_id].active = false;
        clusters[cluster2_id].active = false;
        clusters.push(merged_cluster);
        let new_id = clusters.len() - 1;
        // Enqueue distances from every surviving cluster to the new one.
        for i in 0..new_id {
            if clusters[i].active {
                let distance = clusters[i].ward_distance(&clusters[new_id]);
                heap.push(ClusterPair {
                    distance,
                    cluster1: i,
                    cluster2: new_id,
                    timestamp,
                });
            }
        }
        timestamp += 1;
    }
    Ok(linkage_matrix)
}
/// Lance-Williams distance update for Ward linkage.
///
/// Given the pairwise distances d(i,k), d(j,k) and d(i,j) and the cluster
/// sizes, returns the distance between cluster k and the cluster formed by
/// merging i and j — without access to the underlying points.
///
/// Ward coefficients: alpha_i = (n_i + n_k) / N, alpha_j = (n_j + n_k) / N,
/// beta = -n_k / N, where N = n_i + n_j + n_k.
#[allow(dead_code)]
pub fn lance_williams_ward_update<F: Float + FromPrimitive>(
    dist_ik: F,
    dist_jk: F,
    dist_ij: F,
    size_i: usize,
    size_j: usize,
    size_k: usize,
) -> F {
    let to_float = |n: usize| F::from(n).expect("Failed to convert to float");
    let ni = to_float(size_i);
    let nj = to_float(size_j);
    let nk = to_float(size_k);
    let denom = ni + nj + nk;
    // One term of the squared-distance recurrence: coefficient * d^2.
    let term = |coefficient: F, d: F| coefficient * d * d;
    let new_dist_sq = term((ni + nk) / denom, dist_ik)
        + term((nj + nk) / denom, dist_jk)
        + term(-nk / denom, dist_ij);
    // Clamp tiny negative values caused by floating-point cancellation
    // before taking the square root.
    new_dist_sq.max(F::zero()).sqrt()
}
/// Ward linkage with an up-front memory budget check.
///
/// Estimates the memory a condensed pairwise-distance structure would need
/// (n*(n-1)/2 floats) and refuses to run when the estimate exceeds
/// `max_memory_mb`, instead of failing mid-computation.
///
/// # Arguments
/// * `data` - (n_samples, n_features) observation matrix.
/// * `max_memory_mb` - Budget in mebibytes; the estimate is rounded up.
///
/// # Errors
/// * `InvalidInput` when the estimated footprint exceeds the budget, or
///   (via the delegate) when fewer than 2 samples are provided.
#[allow(dead_code)]
pub fn memory_efficient_ward_linkage<F>(
    data: ArrayView2<F>,
    max_memory_mb: usize,
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + PartialOrd + Send + Sync + ScalarOperand + 'static,
{
    let n_samples = data.shape()[0];
    // FIX: `n_samples - 1` underflowed (panicking in debug builds) for an
    // empty input; saturating_sub yields 0 so the delegate can report the
    // proper InvalidInput error instead.
    let distance_matrix_size = n_samples * n_samples.saturating_sub(1) / 2;
    let memory_per_float = std::mem::size_of::<F>();
    let estimated_memory_mb = (distance_matrix_size * memory_per_float).div_ceil(1024 * 1024);
    if estimated_memory_mb > max_memory_mb {
        return Err(ClusteringError::InvalidInput(format!(
            "Dataset requires approximately {} MB but limit is {} MB. \
             Consider using a different clustering algorithm for large datasets.",
            estimated_memory_mb, max_memory_mb
        )));
    }
    optimized_ward_linkage(data, Metric::Euclidean)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Four corners of a unit square: expect n_samples - 1 = 3 merge rows,
    /// non-negative and monotone non-decreasing merge distances, and every
    /// merged cluster containing at least 2 points.
    #[test]
    fn test_optimized_ward_simple() {
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("Operation failed");
        let linkage_matrix =
            optimized_ward_linkage(data.view(), Metric::Euclidean).expect("Operation failed");
        assert_eq!(linkage_matrix.shape(), &[3, 4]);
        // Column 2 holds merge distances; Ward linkage is monotone.
        for i in 0..2 {
            assert!(linkage_matrix[[i, 2]] >= 0.0);
            if i > 0 {
                assert!(linkage_matrix[[i, 2]] >= linkage_matrix[[i - 1, 2]]);
            }
        }
        // Column 3 holds the size of the newly formed cluster.
        for i in 0..3 {
            assert!(linkage_matrix[[i, 3]] >= 2.0); }
    }

    /// A singleton cluster carries the point itself as its coordinate sum
    /// and its squared norm (1^2 + 2^2 = 5) as sum_squared.
    #[test]
    fn test_ward_cluster_creation() {
        let point = Array1::from_vec(vec![1.0, 2.0]);
        let cluster = WardCluster::new(&point, 0);
        assert_eq!(cluster.size, 1);
        assert_eq!(cluster.sum_coords, point);
        assert_eq!(cluster.sum_squared, 5.0); assert!(cluster.active);
    }

    /// Merging adds sizes, coordinate sums (4, 6) and squared norms
    /// (5 + 25 = 30), and records the merge timestamp.
    #[test]
    fn test_ward_cluster_merge() {
        let point1 = Array1::from_vec(vec![1.0, 2.0]);
        let point2 = Array1::from_vec(vec![3.0, 4.0]);
        let cluster1 = WardCluster::new(&point1, 0);
        let cluster2 = WardCluster::new(&point2, 1);
        let merged = cluster1.merge(&cluster2, 2);
        assert_eq!(merged.size, 2);
        assert_eq!(merged.sum_coords, Array1::from_vec(vec![4.0, 6.0]));
        assert_eq!(merged.sum_squared, 30.0); assert!(merged.active);
        assert_eq!(merged.timestamp, 2);
    }

    /// The Lance-Williams recurrence yields a positive finite distance for
    /// ordinary positive inputs.
    #[test]
    fn test_lance_williams_update() {
        let dist_ik = 2.0;
        let dist_jk = 3.0;
        let dist_ij = 1.0;
        let updated_dist = lance_williams_ward_update(
            dist_ik, dist_jk, dist_ij, 2, 3, 4, );
        assert!(updated_dist > 0.0);
        assert!(updated_dist.is_finite());
    }

    /// A generous budget succeeds; a zero budget is rejected up front
    /// (the 4-sample estimate rounds up to 1 MB).
    #[test]
    fn test_memory_efficient_ward() {
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("Operation failed");
        let result = memory_efficient_ward_linkage(data.view(), 100);
        assert!(result.is_ok());
        let result = memory_efficient_ward_linkage(data.view(), 0);
        assert!(result.is_err());
    }

    /// Two distinct singleton clusters must have a positive, finite Ward
    /// distance.
    #[test]
    fn test_ward_distance_calculation() {
        let point1 = Array1::from_vec(vec![0.0, 0.0]);
        let point2 = Array1::from_vec(vec![1.0, 1.0]);
        let cluster1 = WardCluster::new(&point1, 0);
        let cluster2 = WardCluster::new(&point2, 1);
        let distance = cluster1.ward_distance(&cluster2);
        assert!(distance > 0.0);
        assert!(distance.is_finite());
    }

    /// Duplicate points merge first at distance exactly zero.
    #[test]
    fn test_optimized_ward_identical_points() {
        let data = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 1.0, 1.0, 2.0, 2.0])
            .expect("Operation failed");
        let result = optimized_ward_linkage(data.view(), Metric::Euclidean);
        assert!(result.is_ok());
        let linkage_matrix = result.expect("Operation failed");
        assert_eq!(linkage_matrix.shape(), &[2, 4]);
        assert_eq!(linkage_matrix[[0, 2]], 0.0);
    }
}