use crate::utils::{cosine_similarity, l2_normalize};
use std::collections::HashMap;
/// Clusters `embeddings` bottom-up with a fixed similarity `threshold` and no
/// limit on the number of clusters.
///
/// Delegates to `ahc_impl` with `max_clusters = 0` (cap disabled) and keeps
/// only the labels.
///
/// Returns one cluster id per embedding, in input order. Ids are canonical:
/// `0` is the largest cluster, ties are broken by the smallest member index.
/// Empty input yields an empty vector; embeddings of differing lengths all
/// receive id `0`.
pub fn agglomerative_cluster(embeddings: &[Vec<f32>], threshold: f32) -> Vec<usize> {
    let (labels, _threshold) = ahc_impl(embeddings, threshold, 0);
    labels
}
/// Clusters `embeddings` bottom-up with a fixed similarity `threshold`, also
/// enforcing an upper bound on the number of clusters.
///
/// `max_clusters == 0`, or a value not smaller than `embeddings.len()`,
/// disables the bound. Otherwise merging continues past `threshold` until at
/// most `max_clusters` clusters remain or no mergeable pair is left.
///
/// Returns one canonical cluster id per embedding, in input order (`0` is the
/// largest cluster, ties broken by smallest member index).
pub fn agglomerative_cluster_max_clusters(
    embeddings: &[Vec<f32>],
    threshold: f32,
    max_clusters: usize,
) -> Vec<usize> {
    let (labels, _threshold) = ahc_impl(embeddings, threshold, max_clusters);
    labels
}
/// Clusters `embeddings` with a threshold estimated from the data and no
/// limit on the number of clusters.
///
/// Equivalent to `agglomerative_cluster_auto_max_clusters(embeddings, 0)`.
///
/// Returns `(labels, threshold)`: one cluster id per embedding plus the
/// threshold value reported by the clustering routine.
pub fn agglomerative_cluster_auto(embeddings: &[Vec<f32>]) -> (Vec<usize>, f32) {
    // Zero is the "no cap" sentinel understood by the clustering core.
    const NO_CLUSTER_CAP: usize = 0;
    agglomerative_cluster_auto_max_clusters(embeddings, NO_CLUSTER_CAP)
}
/// Clusters `embeddings` with a threshold estimated from their pairwise
/// similarities, optionally bounding the number of clusters.
///
/// The threshold comes from `estimate_threshold_from_similarities`; the
/// clustering itself is done by `ahc_impl`. `max_clusters == 0` disables the
/// bound.
///
/// Returns `(labels, threshold)`. Empty input short-circuits to
/// `(vec![], 0.0)` without estimating anything.
pub fn agglomerative_cluster_auto_max_clusters(
    embeddings: &[Vec<f32>],
    max_clusters: usize,
) -> (Vec<usize>, f32) {
    if embeddings.is_empty() {
        return (Vec::new(), 0.0);
    }
    let estimated = estimate_threshold_from_similarities(embeddings);
    ahc_impl(embeddings, estimated, max_clusters)
}
/// Greedy bottom-up clustering driven by the cosine similarity of cluster
/// centroids.
///
/// Every embedding starts as its own cluster. Each round picks the pair of
/// active clusters with the highest centroid similarity and merges it. The
/// loop stops when either
///
/// * no mergeable pair is left (fewer than two active clusters, or every
///   remaining similarity is NaN / negative infinity), or
/// * the best similarity is below `threshold` and the cluster cap, if any, is
///   already satisfied.
///
/// # Parameters
/// * `embeddings` - input vectors; all must have the same length.
/// * `threshold` - minimum centroid similarity required for a merge.
/// * `max_clusters` - upper bound on the number of clusters. `0`, or a value
///   not smaller than `embeddings.len()`, disables the bound. While more
///   clusters than this remain, merging continues even below `threshold`.
///
/// # Returns
/// `(labels, threshold)`. `labels[i]` is the cluster id of `embeddings[i]`.
/// Ids are canonical: clusters are numbered by descending size, ties broken
/// by the smallest member index, so `0` is always the largest cluster.
/// Empty input returns `(vec![], 0.0)`; embeddings of differing lengths
/// return all-zero labels and `0.0`.
// Index loops are kept on purpose: the similarity matrix is read and written
// symmetrically (`[i][j]` and `[j][i]`).
#[allow(clippy::needless_range_loop)]
fn ahc_impl(embeddings: &[Vec<f32>], threshold: f32, max_clusters: usize) -> (Vec<usize>, f32) {
    let n = embeddings.len();
    if n == 0 {
        return (Vec::new(), 0.0);
    }
    let dim = embeddings[0].len();
    if embeddings.iter().any(|e| e.len() != dim) {
        return (vec![0; n], 0.0);
    }

    let neg_inf = f32::NEG_INFINITY;
    let mut labels: Vec<usize> = (0..n).collect();
    let mut centroids: Vec<Vec<f32>> = embeddings.to_vec();
    let mut cluster_sizes: Vec<usize> = vec![1; n];
    let mut active: Vec<bool> = vec![true; n];
    // Each merge retires exactly one cluster, so a counter replaces a rescan.
    let mut active_count = n;

    // Symmetric centroid-similarity matrix; negative infinity marks
    // "not a candidate".
    let mut sim_matrix: Vec<Vec<f32>> = vec![vec![neg_inf; n]; n];
    for i in 0..n {
        sim_matrix[i][i] = 1.0;
        for j in (i + 1)..n {
            let sim = cosine_similarity(&centroids[i], &centroids[j]);
            sim_matrix[i][j] = sim;
            sim_matrix[j][i] = sim;
        }
    }

    // The cap only matters when it is set and smaller than the input size.
    let cap_enabled = max_clusters > 0 && max_clusters < n;

    loop {
        // Most similar pair of active clusters. The strict `>` keeps the
        // first pair on ties and never selects NaN or negative infinity.
        let mut best: Option<(usize, usize, f32)> = None;
        for i in 0..n {
            if !active[i] {
                continue;
            }
            for j in (i + 1)..n {
                if !active[j] {
                    continue;
                }
                let sim = sim_matrix[i][j];
                if sim > best.map_or(neg_inf, |(_, _, s)| s) {
                    best = Some((i, j, sim));
                }
            }
        }

        // Nothing left to merge: stop regardless of `threshold`. Without this
        // a NaN or negative-infinity threshold would never terminate.
        let (best_i, best_j, best_sim) = match best {
            Some(found) => found,
            None => break,
        };

        let above_ceiling = cap_enabled && active_count > max_clusters;
        if !above_ceiling && best_sim < threshold {
            break;
        }

        // New centroid: size-weighted mean of the two centroids, then handed
        // to `l2_normalize`.
        let total = cluster_sizes[best_i] + cluster_sizes[best_j];
        let w_i = cluster_sizes[best_i] as f32 / total as f32;
        let w_j = cluster_sizes[best_j] as f32 / total as f32;
        let mut new_centroid: Vec<f32> = centroids[best_i]
            .iter()
            .zip(&centroids[best_j])
            .map(|(a, b)| a * w_i + b * w_j)
            .collect();
        l2_normalize(&mut new_centroid);
        centroids[best_i] = new_centroid;
        cluster_sizes[best_i] = total;

        // Retire `best_j` and remove it from every future comparison.
        active[best_j] = false;
        active_count -= 1;
        for k in 0..n {
            sim_matrix[best_j][k] = neg_inf;
            sim_matrix[k][best_j] = neg_inf;
        }

        // Refresh similarities between the merged cluster and the others.
        for k in 0..n {
            if k == best_i || !active[k] {
                continue;
            }
            let sim = cosine_similarity(&centroids[best_i], &centroids[k]);
            sim_matrix[best_i][k] = sim;
            sim_matrix[k][best_i] = sim;
        }

        for label in &mut labels {
            if *label == best_j {
                *label = best_i;
            }
        }
    }

    // Per surviving label: (member count, smallest member index). Indices are
    // visited in ascending order, so the first index seen is the smallest.
    let mut group: HashMap<usize, (usize, usize)> = HashMap::new();
    for (idx, &label) in labels.iter().enumerate() {
        group.entry(label).or_insert((0, idx)).0 += 1;
    }

    // Largest cluster first; ties go to the cluster with the earliest member.
    // Smallest indices are unique, so the order is fully determined.
    let mut order: Vec<(usize, usize, usize)> = group
        .iter()
        .map(|(&label, &(size, min_idx))| (size, min_idx, label))
        .collect();
    order.sort_by(|a, b| b.0.cmp(&a.0).then(a.1.cmp(&b.1)));

    let canonical: HashMap<usize, usize> = order
        .iter()
        .enumerate()
        .map(|(new_id, &(_, _, label))| (label, new_id))
        .collect();
    for label in &mut labels {
        let old = *label;
        *label = canonical[&old];
    }
    (labels, threshold)
}
/// Derives a merge threshold from the pairwise cosine similarities of
/// `embeddings`.
///
/// All `n * (n - 1) / 2` pairwise similarities are sorted ascending with
/// `f32::total_cmp`. Among the first `len / 2 - 1` gaps between neighbouring
/// values (the part below the median index), the widest strictly positive gap
/// is located; the first such gap wins ties. The similarity just above that
/// gap is returned, clamped to `[0.2, 0.7]`. If no positive gap is found, the
/// second-smallest similarity is used.
///
/// Returns `0.5` when there is at most one pairwise similarity, i.e. for
/// fewer than three embeddings.
fn estimate_threshold_from_similarities(embeddings: &[Vec<f32>]) -> f32 {
    const FALLBACK: f32 = 0.5;

    let n = embeddings.len();
    // Guard first: `n - 1` below must not underflow for empty input.
    if n < 2 {
        return FALLBACK;
    }

    let mut sims: Vec<f32> = Vec::with_capacity(n * (n - 1) / 2);
    for (i, left) in embeddings.iter().enumerate() {
        sims.extend(
            embeddings[i + 1..]
                .iter()
                .map(|right| cosine_similarity(left, right)),
        );
    }
    // A single similarity has no "value above the gap" to return.
    if sims.len() <= 1 {
        return FALLBACK;
    }
    sims.sort_by(f32::total_cmp);

    // Only gaps strictly below the median index are considered.
    let gaps_to_scan = (sims.len() / 2).saturating_sub(1);
    let mut widest = 0.0f32;
    let mut cut = 0usize;
    for (idx, pair) in sims.windows(2).take(gaps_to_scan).enumerate() {
        let gap = pair[1] - pair[0];
        if gap > widest {
            widest = gap;
            cut = idx;
        }
    }
    sims[cut + 1].clamp(0.2, 0.7)
}
// Unit tests for the clustering entry points. The attribute below silences
// `clippy::unwrap_used` for the whole test module.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Four 3-d vectors forming two tight pairs: one near the x axis, one
    /// near the y axis.
    fn two_pairs() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.1, 0.9, 0.0],
        ]
    }

    #[test]
    fn test_agglomerative_cluster_basic() {
        let embeddings = two_pairs();
        let labels = agglomerative_cluster(&embeddings, 0.5);
        assert_eq!(labels.len(), 4);
        // Exactly two clusters: the highest id handed out is 1.
        assert_eq!(labels.iter().max(), Some(&1));
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    #[test]
    fn test_agglomerative_cluster_empty() {
        let labels = agglomerative_cluster(&[], 0.5);
        assert!(labels.is_empty());
    }

    #[test]
    fn test_agglomerative_cluster_single() {
        let embeddings = vec![vec![1.0, 0.0, 0.0]];
        let labels = agglomerative_cluster(&embeddings, 0.5);
        assert_eq!(labels, vec![0]);
    }

    #[test]
    fn test_agglomerative_cluster_auto_basic() {
        let embeddings = two_pairs();
        let (labels, th) = agglomerative_cluster_auto(&embeddings);
        assert_eq!(labels.len(), 4);
        assert!((0.2..=0.7).contains(&th), "threshold {} out of range", th);
    }

    #[test]
    fn test_agglomerative_cluster_auto_max_clusters_caps_count() {
        let embeddings = two_pairs();
        let (labels, _th) = agglomerative_cluster_auto_max_clusters(&embeddings, 2);
        let unique: std::collections::HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(
            unique.len(),
            2,
            "max_clusters=2 must produce exactly 2 clusters"
        );
    }

    #[test]
    fn test_agglomerative_cluster_mismatched_dimensions() {
        // Vectors of different lengths are not clustered; everything gets id 0.
        let embeddings = vec![vec![1.0, 0.0, 0.0], vec![0.9, 0.1]];
        let labels = agglomerative_cluster(&embeddings, 0.5);
        assert_eq!(labels, vec![0, 0]);
    }

    #[test]
    fn cluster_ids_are_canonical_and_shuffle_invariant() {
        let a = vec![1.0, 0.0, 0.0];
        let a2 = vec![0.95, 0.05, 0.0];
        let a3 = vec![0.9, 0.1, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let b2 = vec![0.05, 0.95, 0.0];

        let base = vec![a.clone(), a2.clone(), a3.clone(), b.clone(), b2.clone()];
        let l1 = agglomerative_cluster(&base, 0.5);
        assert_eq!(l1, vec![0, 0, 0, 1, 1], "big cluster must be id 0");

        // Same points in a different order: ids must follow cluster size,
        // not input position.
        let shuffled = vec![b2, a3, b, a, a2];
        let l2 = agglomerative_cluster(&shuffled, 0.5);
        let expectations: [(usize, usize, &str); 5] = [
            (1, 0, "a3 is in the big cluster -> id 0"),
            (3, 0, "a is in the big cluster -> id 0"),
            (4, 0, "a2 is in the big cluster -> id 0"),
            (0, 1, "b2 is in the small cluster -> id 1"),
            (2, 1, "b is in the small cluster -> id 1"),
        ];
        for (position, expected_id, why) in expectations {
            assert_eq!(l2[position], expected_id, "{}", why);
        }
    }
}