use crate::utils::{cosine_similarity, l2_normalize};
use std::collections::HashMap;
/// Clusters `embeddings` using a caller-supplied similarity `threshold`.
///
/// Delegates to `ahc_impl` and keeps only the label vector; the threshold that
/// `ahc_impl` reports alongside the labels is discarded.
///
/// Returns one cluster label per input embedding, in input order.
pub fn agglomerative_cluster(embeddings: &[Vec<f32>], threshold: f32) -> Vec<usize> {
    let (labels, _reported_threshold) = ahc_impl(embeddings, threshold);
    labels
}
/// Clusters `embeddings` with a threshold derived from the data itself.
///
/// Empty input short-circuits to `(vec![], 0.0)`. Otherwise the threshold is
/// obtained from `estimate_threshold_from_similarities` and passed to `ahc_impl`.
///
/// Returns the cluster labels together with the threshold reported by `ahc_impl`.
pub fn agglomerative_cluster_auto(embeddings: &[Vec<f32>]) -> (Vec<usize>, f32) {
    if embeddings.is_empty() {
        return (Vec::new(), 0.0);
    }
    ahc_impl(
        embeddings,
        estimate_threshold_from_similarities(embeddings),
    )
}
/// Greedy agglomerative clustering driven by centroid cosine similarity.
///
/// Every embedding starts as its own cluster. Each round scans all pairs of live
/// clusters, selects the pair whose centroids have the highest
/// `cosine_similarity`, and merges them. The merged centroid is the size-weighted
/// mean of the two centroids, passed through `l2_normalize`. Merging stops as soon
/// as the best similarity is below `threshold`, or when no comparable pair is left
/// (fewer than two live clusters, or every remaining similarity is NaN).
///
/// Returns `(labels, threshold)`: `labels[i]` is the cluster id of `embeddings[i]`,
/// renumbered to `0..k` in order of first appearance, and `threshold` is echoed
/// back unchanged. Empty input yields `(vec![], 0.0)`.
///
/// NOTE: embeddings are expected to share one dimensionality; this is not checked
/// here, and a shorter absorbed centroid will panic on indexing.
fn ahc_impl(embeddings: &[Vec<f32>], threshold: f32) -> (Vec<usize>, f32) {
    let n = embeddings.len();
    if n == 0 {
        return (Vec::new(), 0.0);
    }

    // `labels[p]` is the index of the cluster slot that currently owns point `p`.
    let mut labels: Vec<usize> = (0..n).collect();
    let mut centroids: Vec<Vec<f32>> = embeddings.to_vec();
    let mut cluster_sizes: Vec<usize> = vec![1; n];
    let mut active: Vec<bool> = vec![true; n];

    loop {
        let active_indices: Vec<usize> = (0..n).filter(|&i| active[i]).collect();

        let mut best_sim = f32::NEG_INFINITY;
        let mut best_pair: Option<(usize, usize)> = None;
        for (pos, &i) in active_indices.iter().enumerate() {
            for &j in &active_indices[pos + 1..] {
                let sim = cosine_similarity(&centroids[i], &centroids[j]);
                // A NaN similarity compares false and is therefore never selected.
                if sim > best_sim {
                    best_sim = sim;
                    best_pair = Some((i, j));
                }
            }
        }

        // Terminate explicitly when nothing can be merged. Relying on the threshold
        // comparison alone never stops for a NaN or negative-infinite threshold.
        let (keep, absorb) = match best_pair {
            Some(pair) => pair,
            None => break,
        };
        if best_sim < threshold {
            break;
        }

        // Size-weighted mean of the two centroids, then re-normalized.
        let merged_size = cluster_sizes[keep] + cluster_sizes[absorb];
        let w_keep = cluster_sizes[keep] as f32 / merged_size as f32;
        let w_absorb = cluster_sizes[absorb] as f32 / merged_size as f32;
        let mut merged = vec![0.0f32; centroids[keep].len()];
        for (k, slot) in merged.iter_mut().enumerate() {
            *slot = centroids[keep][k] * w_keep + centroids[absorb][k] * w_absorb;
        }
        l2_normalize(&mut merged);

        centroids[keep] = merged;
        cluster_sizes[keep] = merged_size;
        active[absorb] = false;

        // Re-home every point of the absorbed cluster.
        for label in labels.iter_mut().filter(|label| **label == absorb) {
            *label = keep;
        }
    }

    // Compact the surviving slot indices to 0..k in order of first appearance.
    let mut label_map: HashMap<usize, usize> = HashMap::new();
    for label in &mut labels {
        let next_label = label_map.len();
        *label = *label_map.entry(*label).or_insert(next_label);
    }

    (labels, threshold)
}
/// Heuristically derives a merge threshold from the pairwise similarity distribution.
///
/// All pairwise `cosine_similarity` values are collected and sorted ascending. NaN
/// results are skipped so they can neither be chosen as the threshold nor distort
/// the ordering. Among the first `len / 2 - 1` adjacent gaps of the sorted list,
/// the widest is located (the first one wins on ties), and the value on its upper
/// side becomes the threshold, clamped to `[0.2, 0.7]`.
///
/// Returns `0.5` when there are fewer than two embeddings or fewer than two usable
/// similarities.
fn estimate_threshold_from_similarities(embeddings: &[Vec<f32>]) -> f32 {
    const FALLBACK: f32 = 0.5;

    let n = embeddings.len();
    if n < 2 {
        return FALLBACK;
    }

    let mut sims: Vec<f32> = Vec::with_capacity(n * (n - 1) / 2);
    for (i, a) in embeddings.iter().enumerate() {
        for b in &embeddings[i + 1..] {
            let sim = cosine_similarity(a, b);
            // `clamp` passes NaN through, so a NaN must never reach the result.
            if !sim.is_nan() {
                sims.push(sim);
            }
        }
    }

    // A single similarity has no gap to inspect.
    if sims.len() <= 1 {
        return FALLBACK;
    }
    sims.sort_unstable_by(f32::total_cmp);

    let median_idx = sims.len() / 2;
    let mut best_gap = 0.0f32;
    let mut best_idx = 0usize;
    for (i, pair) in sims
        .windows(2)
        .take(median_idx.saturating_sub(1))
        .enumerate()
    {
        let gap = pair[1] - pair[0];
        if gap > best_gap {
            best_gap = gap;
            best_idx = i;
        }
    }

    sims[best_idx + 1].clamp(0.2, 0.7)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: rows 0 and 1 lean towards the x axis, rows 2 and 3 towards y.
    fn two_pair_embeddings() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.1, 0.9, 0.0],
        ]
    }

    /// With threshold 0.5 the fixture must split into exactly two clusters,
    /// pairing rows 0/1 and rows 2/3.
    #[test]
    fn test_agglomerative_cluster_basic() {
        let embeddings = two_pair_embeddings();
        let labels = agglomerative_cluster(&embeddings, 0.5);

        assert_eq!(labels.len(), 4);
        assert_eq!(labels.iter().copied().max(), Some(1));
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    /// No embeddings in, no labels out.
    #[test]
    fn test_agglomerative_cluster_empty() {
        let labels = agglomerative_cluster(&[], 0.5);
        assert!(labels.is_empty());
    }

    /// A lone embedding forms cluster 0 on its own.
    #[test]
    fn test_agglomerative_cluster_single() {
        let embeddings = vec![vec![1.0, 0.0, 0.0]];
        let labels = agglomerative_cluster(&embeddings, 0.5);
        assert_eq!(labels, vec![0]);
    }

    /// The automatic variant labels every row and reports a threshold in [0.2, 0.7].
    #[test]
    fn test_agglomerative_cluster_auto_basic() {
        let embeddings = two_pair_embeddings();
        let (labels, th) = agglomerative_cluster_auto(&embeddings);

        assert_eq!(labels.len(), 4);
        assert!((0.2..=0.7).contains(&th), "threshold {} out of range", th);
    }
}