use crate::error::{ClusterError, ClusterResult};
use crate::traits::{
ClusteringAlgorithm, ClusteringResult, Fit, FitPredict, HierarchicalClustering,
};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use torsh_tensor::Tensor;
/// Criterion used to measure the distance between two clusters when deciding
/// which pair to merge next (see `AgglomerativeClustering::compute_cluster_distance`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Linkage {
/// Increase in within-cluster sum of squares caused by a merge:
/// `n1 * n2 / (n1 + n2) * ||centroid1 - centroid2||²`.
Ward,
/// Largest Euclidean distance between a point of one cluster and a point of the other.
Complete,
/// Mean Euclidean distance over all cross-cluster point pairs.
Average,
/// Smallest Euclidean distance between a point of one cluster and a point of the other.
Single,
}
/// Output of fitting an [`AgglomerativeClustering`] estimator.
#[derive(Debug, Clone)]
pub struct HierarchicalResult {
/// Cluster id of each sample. `AgglomerativeClustering::fit` stores these as a
/// 1-D tensor of length `n_samples` holding the ids as `f32` values.
pub labels: Tensor,
/// Number of clusters the estimator was configured to produce.
pub n_clusters: usize,
/// Reserved for a merge-history (linkage) matrix; `AgglomerativeClustering::fit`
/// currently always leaves this as `None`.
pub linkage_matrix: Option<Tensor>,
}
impl ClusteringResult for HierarchicalResult {
    /// Borrows the per-sample label tensor.
    fn labels(&self) -> &Tensor {
        let Self { labels, .. } = self;
        labels
    }

    /// Number of clusters recorded in this result.
    fn n_clusters(&self) -> usize {
        let Self { n_clusters, .. } = self;
        *n_clusters
    }
}
/// Bottom-up (agglomerative) hierarchical clustering with a configurable [`Linkage`].
///
/// Construct with [`AgglomerativeClustering::new`] and optionally chain
/// [`AgglomerativeClustering::linkage`].
#[derive(Debug, Clone)]
pub struct AgglomerativeClustering {
/// Number of clusters at which merging stops.
n_clusters: usize,
/// Cluster-distance criterion; `new` defaults it to `Linkage::Ward`.
linkage: Linkage,
/// Reported by `is_fitted`. NOTE(review): nothing in this file sets this to
/// `true` (`fit` takes `&self`), so it currently always stays `false`.
fitted: bool,
}
impl AgglomerativeClustering {
    /// Creates an unfitted estimator targeting `n_clusters` clusters with Ward linkage.
    ///
    /// `n_clusters` is not validated here.
    pub fn new(n_clusters: usize) -> Self {
        let linkage = Linkage::Ward;
        let fitted = false;
        Self {
            n_clusters,
            linkage,
            fitted,
        }
    }

    /// Builder-style setter: returns the estimator with its linkage criterion replaced.
    pub fn linkage(self, linkage: Linkage) -> Self {
        Self { linkage, ..self }
    }
}
impl ClusteringAlgorithm for AgglomerativeClustering {
fn name(&self) -> &str {
"Agglomerative Clustering"
}
fn get_params(&self) -> HashMap<String, String> {
let mut params = HashMap::new();
params.insert("n_clusters".to_string(), self.n_clusters.to_string());
params.insert("linkage".to_string(), format!("{:?}", self.linkage));
params
}
fn set_params(&mut self, _params: HashMap<String, String>) -> ClusterResult<()> {
Ok(())
}
fn is_fitted(&self) -> bool {
self.fitted
}
}
impl Fit for AgglomerativeClustering {
    type Result = HierarchicalResult;

    /// Clusters the rows of a 2-D `data` tensor (samples × features).
    ///
    /// The returned labels are a 1-D tensor of length `n_samples` with the
    /// cluster ids stored as `f32`. `linkage_matrix` is not computed and is
    /// always `None`.
    ///
    /// # Errors
    /// - Whatever `validate_input` reports for `data`.
    /// - [`ClusterError::InvalidInput`] if `data` is not 2-D, if `n_clusters`
    ///   is 0, or if `n_clusters` exceeds the number of samples.
    /// - [`ClusterError::TensorError`] if the tensor data cannot be read.
    fn fit(&self, data: &Tensor) -> ClusterResult<Self::Result> {
        self.validate_input(data)?;

        // Run the cheap shape/parameter checks before copying the whole tensor.
        let shape = data.shape();
        let data_shape = shape.dims();
        if data_shape.len() != 2 {
            return Err(ClusterError::InvalidInput(
                "Data tensor must be 2-dimensional".to_string(),
            ));
        }
        let n_samples = data_shape[0];
        let n_features = data_shape[1];
        if self.n_clusters == 0 {
            // Without this check, merging would run down to a single cluster and
            // then fail with an unrelated "Could not find closest clusters" error.
            return Err(ClusterError::InvalidInput(
                "Number of clusters must be at least 1".to_string(),
            ));
        }
        if self.n_clusters > n_samples {
            return Err(ClusterError::InvalidInput(
                "Number of clusters cannot exceed number of samples".to_string(),
            ));
        }

        let data_vec = data.to_vec().map_err(ClusterError::TensorError)?;
        let data_array =
            Array2::from_shape_vec((n_samples, n_features), data_vec).map_err(|e| {
                ClusterError::InvalidInput(format!("Failed to reshape data array: {}", e))
            })?;

        let labels = self.perform_agglomerative_clustering(&data_array)?;
        let labels_vec: Vec<f32> = labels.iter().map(|&x| x as f32).collect();
        let labels_tensor = Tensor::from_vec(labels_vec, &[n_samples])?;

        Ok(HierarchicalResult {
            labels: labels_tensor,
            n_clusters: self.n_clusters,
            linkage_matrix: None,
        })
    }
}
impl AgglomerativeClustering {
    /// Merges clusters bottom-up until `self.n_clusters` remain and returns one
    /// cluster id per row of `data`.
    ///
    /// Every sample starts as its own cluster. Each iteration merges the closest
    /// pair under `self.linkage`. The merged cluster is kept at the lower index
    /// and the higher one is removed. Final ids are the positions of the
    /// surviving clusters.
    ///
    /// Cost: a dense cluster-distance matrix (O(n²) memory), rescanned for the
    /// closest pair after every merge (roughly O(n³) time overall).
    ///
    /// # Errors
    /// Propagates the error from `find_closest_clusters` when no pair with a
    /// distance below infinity remains. This happens, for example, when
    /// `self.n_clusters == 0` drives merging down to a single cluster, or when
    /// all remaining distances are NaN or infinite.
    fn perform_agglomerative_clustering(&self, data: &Array2<f32>) -> ClusterResult<Vec<usize>> {
        let n_samples = data.nrows();
        let mut clusters: Vec<Vec<usize>> = (0..n_samples).map(|i| vec![i]).collect();
        let mut distance_matrix = self.compute_initial_distances(data)?;

        while clusters.len() > self.n_clusters {
            // `find_closest_clusters` guarantees keep_idx < drop_idx.
            let (keep_idx, drop_idx) = self.find_closest_clusters(&distance_matrix)?;
            let merged = self.merge_clusters(&clusters, keep_idx, drop_idx);
            clusters[keep_idx] = merged;
            clusters.remove(drop_idx);
            self.update_distance_matrix(data, &mut distance_matrix, &clusters, keep_idx, drop_idx)?;
        }

        let mut final_labels = vec![0; n_samples];
        for (cluster_id, cluster) in clusters.iter().enumerate() {
            for &point_idx in cluster {
                final_labels[point_idx] = cluster_id;
            }
        }
        Ok(final_labels)
    }

    /// Builds the symmetric singleton-to-singleton distance matrix.
    ///
    /// Entries are in the same units that `compute_cluster_distance` produces
    /// for the configured linkage. For singletons, Single/Complete/Average all
    /// reduce to the Euclidean distance `d`. Ward reduces to `1*1/(1+1) * d² = d²/2`.
    /// Using raw `d` for Ward would compare singleton pairs and merged clusters
    /// on different scales.
    fn compute_initial_distances(&self, data: &Array2<f32>) -> ClusterResult<Vec<Vec<f64>>> {
        let n_samples = data.nrows();
        let mut distances = vec![vec![0.0; n_samples]; n_samples];
        for i in 0..n_samples {
            for j in (i + 1)..n_samples {
                let euclidean = self.euclidean_distance(&data.row(i), &data.row(j));
                let dist = match self.linkage {
                    Linkage::Ward => 0.5 * euclidean * euclidean,
                    Linkage::Single | Linkage::Complete | Linkage::Average => euclidean,
                };
                distances[i][j] = dist;
                distances[j][i] = dist;
            }
        }
        Ok(distances)
    }

    /// Euclidean distance between two `f32` rows, accumulated in `f64`.
    fn euclidean_distance(&self, point1: &ArrayView1<f32>, point2: &ArrayView1<f32>) -> f64 {
        let mut sum_sq = 0.0_f64;
        for (&a, &b) in point1.iter().zip(point2.iter()) {
            let diff = a as f64 - b as f64;
            sum_sq += diff * diff;
        }
        sum_sq.sqrt()
    }

    /// Returns `(i, j)` with `i < j` for the smallest entry in the upper
    /// triangle of `distance_matrix`. Ties go to the first pair in row-major order.
    ///
    /// # Errors
    /// [`ClusterError::InvalidInput`] if no entry is below `f64::INFINITY`
    /// (fewer than two clusters, or only NaN/infinite distances).
    fn find_closest_clusters(&self, distance_matrix: &[Vec<f64>]) -> ClusterResult<(usize, usize)> {
        let mut min_distance = f64::INFINITY;
        let mut closest_pair = None;
        for (i, row) in distance_matrix.iter().enumerate() {
            for (offset, &dist) in row.iter().skip(i + 1).enumerate() {
                if dist < min_distance {
                    min_distance = dist;
                    closest_pair = Some((i, i + 1 + offset));
                }
            }
        }
        closest_pair.ok_or_else(|| {
            ClusterError::InvalidInput("Could not find closest clusters".to_string())
        })
    }

    /// Concatenates the members of clusters `idx1` and `idx2` (idx1's first).
    fn merge_clusters(&self, clusters: &[Vec<usize>], idx1: usize, idx2: usize) -> Vec<usize> {
        let mut merged = clusters[idx1].clone();
        merged.extend_from_slice(&clusters[idx2]);
        merged
    }

    /// Brings `distance_matrix` in line with `clusters` after a merge.
    ///
    /// Expects `clusters` to already reflect the merge: the merged cluster is
    /// stored at `merged_idx` and the cluster at `removed_idx` has been removed.
    /// It also expects `merged_idx < removed_idx` (as returned by
    /// `find_closest_clusters`), so the merged index is not shifted by the removal.
    ///
    /// The stale row and column are dropped *before* the merged cluster's
    /// distances are recomputed. Otherwise the new values, indexed like the
    /// already-shrunk `clusters`, would land one slot off for every cluster
    /// after `removed_idx`, and the last entry would stay stale.
    fn update_distance_matrix(
        &self,
        data: &Array2<f32>,
        distance_matrix: &mut Vec<Vec<f64>>,
        clusters: &[Vec<usize>],
        merged_idx: usize,
        removed_idx: usize,
    ) -> ClusterResult<()> {
        debug_assert!(merged_idx < removed_idx);
        distance_matrix.remove(removed_idx);
        for row in distance_matrix.iter_mut() {
            row.remove(removed_idx);
        }

        let merged = &clusters[merged_idx];
        for (i, other) in clusters.iter().enumerate() {
            if i == merged_idx {
                continue;
            }
            let new_distance = self.compute_cluster_distance(data, merged, other)?;
            distance_matrix[merged_idx][i] = new_distance;
            distance_matrix[i][merged_idx] = new_distance;
        }
        Ok(())
    }

    /// Distance between two (non-empty) clusters of row indices under `self.linkage`:
    /// - Single: minimum pairwise Euclidean distance;
    /// - Complete: maximum pairwise Euclidean distance;
    /// - Average: mean pairwise Euclidean distance;
    /// - Ward: `n1 * n2 / (n1 + n2) * ||centroid1 - centroid2||²`, the increase in
    ///   within-cluster sum of squares caused by merging.
    fn compute_cluster_distance(
        &self,
        data: &Array2<f32>,
        cluster1: &[usize],
        cluster2: &[usize],
    ) -> ClusterResult<f64> {
        match self.linkage {
            Linkage::Single => {
                let mut min_dist = f64::INFINITY;
                for &i in cluster1 {
                    for &j in cluster2 {
                        let dist = self.euclidean_distance(&data.row(i), &data.row(j));
                        min_dist = min_dist.min(dist);
                    }
                }
                Ok(min_dist)
            }
            Linkage::Complete => {
                let mut max_dist = 0.0_f64;
                for &i in cluster1 {
                    for &j in cluster2 {
                        let dist = self.euclidean_distance(&data.row(i), &data.row(j));
                        max_dist = max_dist.max(dist);
                    }
                }
                Ok(max_dist)
            }
            Linkage::Average => {
                let mut total_dist = 0.0_f64;
                for &i in cluster1 {
                    for &j in cluster2 {
                        total_dist += self.euclidean_distance(&data.row(i), &data.row(j));
                    }
                }
                Ok(total_dist / (cluster1.len() * cluster2.len()) as f64)
            }
            Linkage::Ward => {
                let centroid1 = self.compute_centroid(data, cluster1);
                let centroid2 = self.compute_centroid(data, cluster2);
                let centroid_dist = self.euclidean_distance_arrays(&centroid1, &centroid2);
                let n1 = cluster1.len() as f64;
                let n2 = cluster2.len() as f64;
                let weight = (n1 * n2) / (n1 + n2);
                Ok(weight * centroid_dist * centroid_dist)
            }
        }
    }

    /// Mean (in `f64`) of the rows of `data` listed in `cluster`.
    fn compute_centroid(&self, data: &Array2<f32>, cluster: &[usize]) -> Array1<f64> {
        let n_features = data.ncols();
        let mut centroid = Array1::zeros(n_features);
        for &point_idx in cluster {
            let point = data.row(point_idx);
            for (i, &value) in point.iter().enumerate() {
                centroid[i] += value as f64;
            }
        }
        let cluster_size = cluster.len() as f64;
        for value in centroid.iter_mut() {
            *value /= cluster_size;
        }
        centroid
    }

    /// Euclidean distance between two `f64` vectors.
    fn euclidean_distance_arrays(&self, a: &Array1<f64>, b: &Array1<f64>) -> f64 {
        let mut sum_sq = 0.0;
        for (&x, &y) in a.iter().zip(b.iter()) {
            let diff = x - y;
            sum_sq += diff * diff;
        }
        sum_sq.sqrt()
    }
}
impl FitPredict for AgglomerativeClustering {
type Result = HierarchicalResult;
fn fit_predict(&self, data: &Tensor) -> ClusterResult<Self::Result> {
self.fit(data)
}
}
impl HierarchicalClustering for AgglomerativeClustering {
    type Tree = Tensor;

    /// Not implemented yet; always returns [`ClusterError::NotImplemented`].
    fn extract_flat_clustering(&self, _n_clusters: usize) -> ClusterResult<Tensor> {
        let message = String::from("extract_flat_clustering not yet implemented");
        Err(ClusterError::NotImplemented(message))
    }

    /// Not implemented yet; always returns [`ClusterError::NotImplemented`].
    fn extract_clustering_by_distance(&self, _threshold: f64) -> ClusterResult<Tensor> {
        let message = String::from("extract_clustering_by_distance not yet implemented");
        Err(ClusterError::NotImplemented(message))
    }
}