use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use rayon::prelude::*;
use super::checkpoint::{ExtractionPhase, TopicExtractionCheckpoint};
use super::config::{ClusteringConfig, LinkageMethod};
use super::dendrogram::Dendrogram;
use super::{Result, TopicError};
/// Condensed pairwise distance matrix whose cells can be written through `&self`.
///
/// Only the `n * (n - 1) / 2` entries above the diagonal are stored, row by row.
/// Every cell is an [`AtomicU64`] holding the bit pattern of an `f64`; all loads
/// and stores use [`Ordering::Relaxed`].
pub struct AtomicDistanceMatrix {
    /// Upper-triangle cells in condensed (row-major) order.
    distances: Vec<AtomicU64>,
    /// Number of points the matrix describes.
    n: usize,
}

impl AtomicDistanceMatrix {
    /// Creates a matrix for `n` points with every distance initialised to `0.0`.
    ///
    /// `n == 0` and `n == 1` both yield a matrix without any cells.
    pub fn new(n: usize) -> Self {
        // `n - 1` would underflow for an empty point set (a panic when overflow
        // checks are enabled), so saturate instead.
        let size = n * n.saturating_sub(1) / 2;
        let distances = (0..size).map(|_| AtomicU64::new(0)).collect();
        Self { distances, n }
    }

    /// Rebuilds a matrix from condensed `f32` distances, widening each to `f64`.
    ///
    /// The length of `distances` is not checked against `n`; a slice shorter than
    /// `n * (n - 1) / 2` makes later [`get`](Self::get) / [`set`](Self::set) calls
    /// panic on indexing.
    pub fn from_checkpoint(distances: &[f32], n: usize) -> Self {
        let atomic_distances: Vec<AtomicU64> = distances
            .iter()
            .map(|&d| AtomicU64::new(f64::from(d).to_bits()))
            .collect();
        Self {
            distances: atomic_distances,
            n,
        }
    }

    /// Maps the pair `(i, j)` with `i < j` to its position in the condensed layout.
    ///
    /// The ordering requirement is only asserted in debug builds.
    #[inline]
    pub const fn condensed_index(i: usize, j: usize, n: usize) -> usize {
        debug_assert!(i < j);
        n * i - i * (i + 1) / 2 + j - i - 1
    }

    /// Returns the distance between `i` and `j`; the diagonal is always `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if the pair maps outside the stored cells.
    #[inline]
    pub fn get(&self, i: usize, j: usize) -> f64 {
        if i == j {
            return 0.0;
        }
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        let idx = Self::condensed_index(lo, hi, self.n);
        f64::from_bits(self.distances[idx].load(Ordering::Relaxed))
    }

    /// Stores the distance between `i` and `j`; writes to the diagonal are ignored.
    ///
    /// # Panics
    ///
    /// Panics if the pair maps outside the stored cells.
    #[inline]
    pub fn set(&self, i: usize, j: usize, dist: f64) {
        if i == j {
            return;
        }
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        let idx = Self::condensed_index(lo, hi, self.n);
        self.distances[idx].store(dist.to_bits(), Ordering::Relaxed);
    }

    /// Number of points described by the matrix.
    #[inline]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Number of stored pairwise distances.
    #[inline]
    pub fn len(&self) -> usize {
        self.distances.len()
    }

    /// Returns `true` when no pairwise distances are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.distances.is_empty()
    }

    /// Snapshots the condensed distances, narrowing each value to `f32`.
    pub fn to_vec(&self) -> Vec<f32> {
        self.distances
            .iter()
            .map(|d| f64::from_bits(d.load(Ordering::Relaxed)) as f32)
            .collect()
    }
}
/// Cosine distance `1 - cos(a, b)` between two vectors, accumulated in `f64`.
///
/// The product of the two norms is floored at `1e-10` before dividing, and the
/// result is clamped so it is never below `0.0`. Equal lengths are only asserted
/// in debug builds; otherwise the extra tail of the longer slice is ignored.
#[inline]
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f64 {
    debug_assert_eq!(a.len(), b.len());

    // One pass collecting the dot product and both squared norms.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let denom = f64::max(sq_norm_a.sqrt() * sq_norm_b.sqrt(), 1e-10);
    f64::max(1.0 - dot / denom, 0.0)
}
/// Fills a condensed cosine-distance matrix for `embeddings`, one row per rayon task.
///
/// When `progress` is supplied, each finished row adds the number of pairs it
/// produced, and the counter is set to the exact pair total before returning.
/// An empty `embeddings` slice yields an empty matrix and a progress value of `0`.
pub fn compute_distance_matrix_parallel(
    embeddings: &[Vec<f32>],
    progress: Option<&AtomicUsize>,
) -> AtomicDistanceMatrix {
    let n = embeddings.len();
    if n == 0 {
        // No pairs exist. Build the empty matrix directly rather than evaluating
        // `n * (n - 1) / 2`, which underflows for `n == 0`.
        if let Some(prog) = progress {
            prog.store(0, Ordering::Release);
        }
        return AtomicDistanceMatrix::from_checkpoint(&[], 0);
    }

    let matrix = AtomicDistanceMatrix::new(n);
    let total_pairs = n * (n - 1) / 2;

    (0..n).into_par_iter().for_each(|i| {
        // Row `i` owns every pair `(i, j)` with `j > i`, so rows never write the same cell.
        for j in (i + 1)..n {
            let dist = cosine_distance(&embeddings[i], &embeddings[j]);
            matrix.set(i, j, dist);
        }
        if let Some(prog) = progress {
            prog.fetch_add(n - i - 1, Ordering::Relaxed);
        }
    });

    // Publish the exact total once every row has been processed.
    std::sync::atomic::fence(Ordering::Release);
    if let Some(prog) = progress {
        prog.store(total_pairs, Ordering::Release);
    }
    matrix
}
/// Bookkeeping for agglomerative merging: point labels, cluster sizes and id allocation.
#[derive(Clone)]
pub struct ClusterState {
    /// Current cluster id of every original point.
    pub assignments: Vec<u32>,
    /// Size of each cluster, indexed by cluster id; `0` marks a cluster that is not active.
    pub sizes: Vec<usize>,
    /// Number of clusters that are still active.
    pub num_active: usize,
    /// Id that the next merge will hand out.
    pub next_cluster_id: u32,
}

impl ClusterState {
    /// Starts with `n` singleton clusters labelled `0..n`; merged clusters get ids from `n` upward.
    pub fn new(n: usize) -> Self {
        let first_merged_id = n as u32;
        Self {
            assignments: (0..first_merged_id).collect(),
            sizes: vec![1; n],
            num_active: n,
            next_cluster_id: first_merged_id,
        }
    }

    /// Merges clusters `i` and `j` into a freshly numbered cluster and returns its id.
    ///
    /// Both source clusters get size `0`, and every point labelled `i` or `j` is relabelled.
    ///
    /// # Panics
    ///
    /// Panics if `i` or `j` is not a valid index into `sizes`.
    pub fn merge(&mut self, i: u32, j: u32) -> u32 {
        let merged = self.next_cluster_id;
        self.next_cluster_id += 1;

        let (left, right, target) = (i as usize, j as usize, merged as usize);
        let combined = self.sizes[left] + self.sizes[right];

        // Grow the size table (zero-filled) until the new id has a slot.
        if self.sizes.len() <= target {
            self.sizes.resize(target + 1, 0);
        }
        self.sizes[target] = combined;
        self.sizes[left] = 0;
        self.sizes[right] = 0;

        self.assignments
            .iter_mut()
            .filter(|label| **label == i || **label == j)
            .for_each(|label| *label = merged);

        self.num_active -= 1;
        merged
    }

    /// Returns `true` when `cluster` has a slot in `sizes` with a non-zero size.
    #[inline]
    pub fn is_active(&self, cluster: u32) -> bool {
        matches!(self.sizes.get(cluster as usize), Some(&size) if size > 0)
    }

    /// Ids of all clusters with a non-zero size, in ascending order.
    pub fn active_clusters(&self) -> Vec<u32> {
        self.sizes
            .iter()
            .enumerate()
            .filter_map(|(id, &size)| (size > 0).then_some(id as u32))
            .collect()
    }
}
/// Sparse distances between currently relevant clusters, with a cached minimum.
///
/// Pairs are keyed as `(smaller_id, larger_id)`. The cached minimum is kept
/// consistent by [`set`](Self::set) and [`remove_cluster`](Self::remove_cluster),
/// so calling [`invalidate_minimum`](Self::invalidate_minimum) is optional.
///
/// The minimum is chosen by a total order: smaller distance first, NaN after
/// every number, and ties broken by the smaller key. The result therefore does
/// not depend on `HashMap` iteration order.
pub struct ActiveDistanceMatrix {
    distances: std::collections::HashMap<(u32, u32), f64>,
    /// Distance of `min_pair`; only meaningful while `min_pair` is `Some`.
    min_dist: f64,
    /// Cached closest pair, or `None` when it must be recomputed (or the map is empty).
    min_pair: Option<(u32, u32)>,
}

impl ActiveDistanceMatrix {
    /// Copies every pair `(i, j)` with `i < j` out of `matrix` and caches the closest one.
    pub fn from_initial(matrix: &AtomicDistanceMatrix) -> Self {
        let n = matrix.n();
        // Saturating so an empty matrix does not underflow `n - 1`.
        let pair_count = n * n.saturating_sub(1) / 2;
        let mut distances = std::collections::HashMap::with_capacity(pair_count);
        let mut best: Option<((u32, u32), f64)> = None;
        for i in 0..n {
            for j in (i + 1)..n {
                let key = (i as u32, j as u32);
                let dist = matrix.get(i, j);
                distances.insert(key, dist);
                if Self::beats(key, dist, best) {
                    best = Some((key, dist));
                }
            }
        }
        Self {
            distances,
            min_dist: best.map_or(f64::MAX, |(_, dist)| dist),
            min_pair: best.map(|(key, _)| key),
        }
    }

    /// Returns the stored distance between `i` and `j` in either argument order.
    pub fn get(&self, i: u32, j: u32) -> Option<f64> {
        self.distances.get(&Self::ordered(i, j)).copied()
    }

    /// Inserts or overwrites the distance between `i` and `j`, updating the cached minimum.
    pub fn set(&mut self, i: u32, j: u32, dist: f64) {
        let key = Self::ordered(i, j);
        self.distances.insert(key, dist);
        if let Some(best_key) = self.min_pair {
            if best_key == key {
                // The cached pair itself changed, so another pair may now be closer.
                self.min_pair = None;
            } else if Self::beats(key, dist, Some((best_key, self.min_dist))) {
                self.min_dist = dist;
                self.min_pair = Some(key);
            }
        }
    }

    /// Drops every pair that involves `cluster`.
    pub fn remove_cluster(&mut self, cluster: u32) {
        self.distances
            .retain(|&(i, j), _| i != cluster && j != cluster);
        // A cached minimum that referenced the removed cluster no longer exists.
        if matches!(self.min_pair, Some((i, j)) if i == cluster || j == cluster) {
            self.min_pair = None;
        }
    }

    /// Returns the closest pair and its distance, or `None` when no pairs are stored.
    ///
    /// Rescans all pairs only when no cached minimum is available.
    pub fn find_minimum(&mut self) -> Option<(u32, u32, f64)> {
        if self.min_pair.is_none() {
            let mut best: Option<((u32, u32), f64)> = None;
            for (&key, &dist) in &self.distances {
                if Self::beats(key, dist, best) {
                    best = Some((key, dist));
                }
            }
            self.min_dist = best.map_or(f64::MAX, |(_, dist)| dist);
            self.min_pair = best.map(|(key, _)| key);
        }
        self.min_pair.map(|(i, j)| (i, j, self.min_dist))
    }

    /// Forces the next [`find_minimum`](Self::find_minimum) call to rescan all pairs.
    pub fn invalidate_minimum(&mut self) {
        self.min_pair = None;
    }

    /// Canonical key for an unordered pair: smaller id first.
    fn ordered(i: u32, j: u32) -> (u32, u32) {
        if i < j {
            (i, j)
        } else {
            (j, i)
        }
    }

    /// Decides whether `(key, dist)` should replace `incumbent` as the closest pair.
    fn beats(key: (u32, u32), dist: f64, incumbent: Option<((u32, u32), f64)>) -> bool {
        let Some((best_key, best_dist)) = incumbent else {
            return true;
        };
        if best_dist.is_nan() {
            // Any number beats NaN; two NaNs fall back to the key.
            !dist.is_nan() || key < best_key
        } else {
            dist < best_dist || (dist == best_dist && key < best_key)
        }
    }
}
/// Distance from cluster `k` to the cluster formed by merging `i` and `j`.
///
/// * `dist_ik`, `dist_jk`, `dist_ij` - current distances between the named clusters.
/// * `size_i`, `size_j`, `size_k` - number of points in each cluster.
///
/// `Single` and `Complete` take the smaller / larger of the two distances,
/// `Average` weights them by `size_i` and `size_j`, and `Ward` combines all
/// three distances with size-dependent weights.
///
/// NOTE(review): the Ward update is applied to the distances exactly as passed in;
/// the textbook Lance-Williams form is usually stated for squared distances.
/// Confirm which one callers intend.
#[inline]
pub fn linkage_distance(
    method: LinkageMethod,
    dist_ik: f64,
    dist_jk: f64,
    size_i: usize,
    size_j: usize,
    size_k: usize,
    dist_ij: f64,
) -> f64 {
    let (n_i, n_j, n_k) = (size_i as f64, size_j as f64, size_k as f64);
    match method {
        LinkageMethod::Single => f64::min(dist_ik, dist_jk),
        LinkageMethod::Complete => f64::max(dist_ik, dist_jk),
        LinkageMethod::Average => (n_i * dist_ik + n_j * dist_jk) / (n_i + n_j),
        LinkageMethod::Ward => {
            let n_total = n_i + n_j + n_k;
            let weighted = (n_i + n_k) * dist_ik + (n_j + n_k) * dist_jk - n_k * dist_ij;
            weighted / n_total
        }
    }
}
/// Outcome of a hierarchical clustering run.
#[derive(Clone, Debug)]
pub struct ClusteringResult {
/// Merge history. Rows appended by `HierarchicalClustering` have the form
/// `(cluster_a, cluster_b, distance as f32, merged_size)`, one per merge in merge order.
pub linkage: Vec<(u32, u32, f32, u32)>,
/// Dendrogram built from `linkage` with `Dendrogram::from_linkage`.
pub dendrogram: Dendrogram,
/// Cluster label per input point: a dendrogram cut when the config requests one,
/// the identity labelling `0..n` otherwise, or the labels stored in a checkpoint.
pub assignments: Vec<u32>,
/// Number of input points the result was computed for.
pub num_points: usize,
}
pub struct HierarchicalClustering {
config: ClusteringConfig,
}
impl HierarchicalClustering {
pub fn new(config: ClusteringConfig) -> Self {
Self { config }
}
pub fn cluster(&self, embeddings: &[Vec<f32>]) -> Result<ClusteringResult> {
let n = embeddings.len();
if n < 2 {
return Err(TopicError::ClusteringError(
"Need at least 2 points for clustering".to_string(),
));
}
let progress = if self.config.verbose {
Some(AtomicUsize::new(0))
} else {
None
};
let dist_matrix = compute_distance_matrix_parallel(embeddings, progress.as_ref());
self.cluster_from_distances(&dist_matrix)
}
pub fn cluster_from_distances(
&self,
dist_matrix: &AtomicDistanceMatrix,
) -> Result<ClusteringResult> {
let n = dist_matrix.n();
if n < 2 {
return Err(TopicError::ClusteringError(
"Need at least 2 points for clustering".to_string(),
));
}
let mut state = ClusterState::new(n);
let mut active_distances = ActiveDistanceMatrix::from_initial(dist_matrix);
let mut linkage: Vec<(u32, u32, f32, u32)> = Vec::with_capacity(n - 1);
for _ in 0..(n - 1) {
let Some((i, j, dist)) = active_distances.find_minimum() else {
break;
};
let size_i = state.sizes[i as usize];
let size_j = state.sizes[j as usize];
linkage.push((i, j, dist as f32, (size_i + size_j) as u32));
let matrix_n = dist_matrix.n();
let mut new_distances: Vec<(u32, f64)> = Vec::new();
let other_clusters: Vec<u32> = state
.active_clusters()
.into_iter()
.filter(|&k| k != i && k != j)
.collect();
for k in &other_clusters {
let k = *k;
let dist_ik = active_distances
.get(i, k)
.or_else(|| {
if (i as usize) < matrix_n && (k as usize) < matrix_n {
Some(dist_matrix.get(i as usize, k as usize))
} else {
None
}
})
.unwrap_or(f64::MAX);
let dist_jk = active_distances
.get(j, k)
.or_else(|| {
if (j as usize) < matrix_n && (k as usize) < matrix_n {
Some(dist_matrix.get(j as usize, k as usize))
} else {
None
}
})
.unwrap_or(f64::MAX);
let size_k = state.sizes[k as usize];
let new_dist = linkage_distance(
self.config.linkage,
dist_ik,
dist_jk,
size_i,
size_j,
size_k,
dist,
);
new_distances.push((k, new_dist));
}
let new_cluster = state.merge(i, j);
active_distances.remove_cluster(i);
active_distances.remove_cluster(j);
active_distances.invalidate_minimum();
for (k, new_dist) in new_distances {
active_distances.set(new_cluster, k, new_dist);
}
}
let dendrogram = Dendrogram::from_linkage(&linkage, n);
let assignments = if let Some(k) = self.config.num_clusters {
dendrogram.cut_to_k_clusters(k)
} else if let Some(threshold) = self.config.distance_threshold {
dendrogram.cut_at_distance(threshold)
} else {
(0..n as u32).collect()
};
Ok(ClusteringResult {
linkage,
dendrogram,
assignments,
num_points: n,
})
}
pub fn cluster_from_checkpoint(
&self,
embeddings: &[Vec<f32>],
checkpoint: &TopicExtractionCheckpoint,
) -> Result<ClusteringResult> {
let n = embeddings.len();
if n != checkpoint.num_documents {
return Err(TopicError::ClusteringError(format!(
"Document count mismatch: expected {}, got {}",
checkpoint.num_documents, n
)));
}
match checkpoint.phase {
ExtractionPhase::DistanceMatrix => {
self.cluster(embeddings)
}
ExtractionPhase::Clustering => {
if let Some(ref distances) = checkpoint.distance_matrix {
let dist_matrix = AtomicDistanceMatrix::from_checkpoint(distances, n);
self.resume_from_linkage(&dist_matrix, &checkpoint.linkage_matrix, n)
} else {
self.cluster(embeddings)
}
}
_ => {
let dendrogram = Dendrogram::from_linkage(&checkpoint.linkage_matrix, n);
let assignments = checkpoint.cluster_assignments.clone();
Ok(ClusteringResult {
linkage: checkpoint.linkage_matrix.clone(),
dendrogram,
assignments,
num_points: n,
})
}
}
}
fn resume_from_linkage(
&self,
dist_matrix: &AtomicDistanceMatrix,
partial_linkage: &[(u32, u32, f32, u32)],
n: usize,
) -> Result<ClusteringResult> {
let mut state = ClusterState::new(n);
let mut linkage = partial_linkage.to_vec();
for &(i, j, _, _) in partial_linkage {
state.merge(i, j);
}
let mut active_distances = ActiveDistanceMatrix::from_initial(dist_matrix);
for &(i, j, _, _) in partial_linkage {
active_distances.remove_cluster(i);
active_distances.remove_cluster(j);
}
active_distances.invalidate_minimum();
while state.num_active > 1 {
let Some((i, j, dist)) = active_distances.find_minimum() else {
break;
};
let size_i = state.sizes[i as usize];
let size_j = state.sizes[j as usize];
linkage.push((i, j, dist as f32, (size_i + size_j) as u32));
let matrix_n = dist_matrix.n();
let mut new_distances: Vec<(u32, f64)> = Vec::new();
let other_clusters: Vec<u32> = state
.active_clusters()
.into_iter()
.filter(|&k| k != i && k != j)
.collect();
for k in &other_clusters {
let k = *k;
let dist_ik = active_distances
.get(i, k)
.or_else(|| {
if (i as usize) < matrix_n && (k as usize) < matrix_n {
Some(dist_matrix.get(i as usize, k as usize))
} else {
None
}
})
.unwrap_or(f64::MAX);
let dist_jk = active_distances
.get(j, k)
.or_else(|| {
if (j as usize) < matrix_n && (k as usize) < matrix_n {
Some(dist_matrix.get(j as usize, k as usize))
} else {
None
}
})
.unwrap_or(f64::MAX);
let size_k = state.sizes[k as usize];
let new_dist = linkage_distance(
self.config.linkage,
dist_ik,
dist_jk,
size_i,
size_j,
size_k,
dist,
);
new_distances.push((k, new_dist));
}
let new_cluster = state.merge(i, j);
active_distances.remove_cluster(i);
active_distances.remove_cluster(j);
active_distances.invalidate_minimum();
for (k, new_dist) in new_distances {
active_distances.set(new_cluster, k, new_dist);
}
}
let dendrogram = Dendrogram::from_linkage(&linkage, n);
let assignments = if let Some(k) = self.config.num_clusters {
dendrogram.cut_to_k_clusters(k)
} else if let Some(threshold) = self.config.distance_threshold {
dendrogram.cut_at_distance(threshold)
} else {
(0..n as u32).collect()
};
Ok(ClusteringResult {
linkage,
dendrogram,
assignments,
num_points: n,
})
}
pub fn config(&self) -> &ClusteringConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in this module.
    const EPS: f64 = 1e-6;

    #[test]
    fn test_cosine_distance() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        // Orthogonal vectors.
        let dist = cosine_distance(&a, &b);
        assert!((dist - 1.0).abs() < EPS);
        // Identical vectors.
        let dist = cosine_distance(&a, &a);
        assert!(dist.abs() < EPS);
        // Opposite vectors.
        let c = vec![-1.0, 0.0, 0.0];
        let dist = cosine_distance(&a, &c);
        assert!((dist - 2.0).abs() < EPS);
    }

    #[test]
    fn test_atomic_distance_matrix() {
        let matrix = AtomicDistanceMatrix::new(4);
        matrix.set(0, 1, 0.5);
        matrix.set(0, 2, 0.8);
        matrix.set(1, 2, 0.3);
        assert!((matrix.get(0, 1) - 0.5).abs() < EPS);
        // Lookups are symmetric.
        assert!((matrix.get(1, 0) - 0.5).abs() < EPS);
        assert!((matrix.get(1, 2) - 0.3).abs() < EPS);
        // The diagonal is always zero.
        assert!(matrix.get(0, 0).abs() < EPS);
    }

    #[test]
    fn test_condensed_index() {
        assert_eq!(AtomicDistanceMatrix::condensed_index(0, 1, 4), 0);
        assert_eq!(AtomicDistanceMatrix::condensed_index(0, 2, 4), 1);
        assert_eq!(AtomicDistanceMatrix::condensed_index(0, 3, 4), 2);
        assert_eq!(AtomicDistanceMatrix::condensed_index(1, 2, 4), 3);
        assert_eq!(AtomicDistanceMatrix::condensed_index(1, 3, 4), 4);
        assert_eq!(AtomicDistanceMatrix::condensed_index(2, 3, 4), 5);
    }

    #[test]
    fn test_compute_distance_matrix_parallel() {
        let embeddings = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
        ];
        let matrix = compute_distance_matrix_parallel(&embeddings, None);
        assert_eq!(matrix.n(), 4);
        assert_eq!(matrix.len(), 6);
        // Orthogonal pair.
        assert!((matrix.get(0, 1) - 1.0).abs() < EPS);
        // Partially aligned pair.
        assert!(matrix.get(0, 3) < 1.0);
    }

    #[test]
    fn test_cluster_state() {
        let mut state = ClusterState::new(4);
        assert_eq!(state.num_active, 4);
        assert!(state.is_active(0));
        assert!(state.is_active(3));
        let new_id = state.merge(0, 1);
        assert_eq!(new_id, 4);
        assert_eq!(state.num_active, 3);
        assert!(!state.is_active(0));
        assert!(!state.is_active(1));
        assert!(state.is_active(4));
        assert_eq!(state.assignments[0], 4);
        assert_eq!(state.assignments[1], 4);
        assert_eq!(state.assignments[2], 2);
        assert_eq!(state.assignments[3], 3);
    }

    #[test]
    fn test_linkage_methods() {
        let dist_ik = 1.0;
        let dist_jk = 2.0;
        let size_i = 2;
        let size_j = 3;
        let size_k = 4;
        let dist_ij = 0.5;
        let update = |method| {
            linkage_distance(method, dist_ik, dist_jk, size_i, size_j, size_k, dist_ij)
        };

        assert!((update(LinkageMethod::Single) - 1.0).abs() < EPS);
        assert!((update(LinkageMethod::Complete) - 2.0).abs() < EPS);

        let expected_avg = (2.0 * 1.0 + 3.0 * 2.0) / 5.0;
        assert!((update(LinkageMethod::Average) - expected_avg).abs() < EPS);

        // ((2 + 4) * 1.0 + (3 + 4) * 2.0 - 4 * 0.5) / (2 + 3 + 4) = 18 / 9
        let expected_ward = 2.0;
        assert!((update(LinkageMethod::Ward) - expected_ward).abs() < EPS);
    }

    #[test]
    fn test_hierarchical_clustering() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![1.1, 0.0],
            vec![0.0, 1.0],
            vec![0.0, 1.1],
        ];
        let config = ClusteringConfig {
            num_clusters: Some(2),
            linkage: LinkageMethod::Single,
            ..Default::default()
        };
        let clustering = HierarchicalClustering::new(config);
        let result = clustering.cluster(&embeddings).expect("clustering failed");
        assert_eq!(result.linkage.len(), 3);
        assert_eq!(result.num_points, 4);
        assert_eq!(result.assignments[0], result.assignments[1]);
        assert_eq!(result.assignments[2], result.assignments[3]);
        assert_ne!(result.assignments[0], result.assignments[2]);
    }

    #[test]
    fn test_clustering_single_linkage() {
        // Distances are cosine-based, so the points are spread by angle: the first two
        // are nearly parallel and the third is orthogonal to the first.
        let embeddings = vec![vec![1.0, 0.0], vec![1.0, 0.1], vec![0.0, 1.0]];
        let config = ClusteringConfig {
            linkage: LinkageMethod::Single,
            ..Default::default()
        };
        let clustering = HierarchicalClustering::new(config);
        let result = clustering.cluster(&embeddings).expect("clustering failed");
        assert_eq!(result.linkage.len(), 2);

        let (c1, c2, first_dist, first_size) = result.linkage[0];
        assert_eq!(
            (c1.min(c2), c1.max(c2)),
            (0, 1),
            "First merge should join the nearly parallel pair: ({}, {})",
            c1,
            c2
        );
        assert_eq!(first_size, 2);

        // The second merge joins the new cluster (id 3) with the remaining point.
        let (c1, c2, second_dist, second_size) = result.linkage[1];
        assert_eq!((c1.min(c2), c1.max(c2)), (2, 3));
        assert_eq!(second_size, 3);
        assert!(first_dist < second_dist);
    }

    #[test]
    fn test_distance_matrix_to_vec() {
        let matrix = AtomicDistanceMatrix::new(3);
        matrix.set(0, 1, 0.5);
        matrix.set(0, 2, 0.8);
        matrix.set(1, 2, 0.3);
        let vec = matrix.to_vec();
        assert_eq!(vec.len(), 3);
        let matrix2 = AtomicDistanceMatrix::from_checkpoint(&vec, 3);
        assert!((matrix2.get(0, 1) - 0.5).abs() < EPS);
        assert!((matrix2.get(0, 2) - 0.8).abs() < EPS);
        assert!((matrix2.get(1, 2) - 0.3).abs() < EPS);
    }

    #[test]
    fn test_active_distance_matrix() {
        let initial = AtomicDistanceMatrix::new(4);
        initial.set(0, 1, 0.5);
        initial.set(0, 2, 0.8);
        initial.set(0, 3, 1.0);
        initial.set(1, 2, 0.3);
        initial.set(1, 3, 0.9);
        initial.set(2, 3, 0.7);
        let mut active = ActiveDistanceMatrix::from_initial(&initial);
        let min = active.find_minimum().expect("should have minimum");
        assert_eq!(min.0, 1);
        assert_eq!(min.1, 2);
        assert!((min.2 - 0.3).abs() < EPS);
        active.remove_cluster(1);
        active.invalidate_minimum();
        let min = active.find_minimum().expect("should have minimum");
        assert_ne!(min.0, 1);
        assert_ne!(min.1, 1);
    }
}