oxirs_embed/clustering.rs
1//! Clustering Support for Knowledge Graph Embeddings
2//!
3//! This module provides comprehensive clustering algorithms for analyzing and grouping
4//! entities based on their learned embeddings. Clustering helps discover latent
5//! structure in knowledge graphs and can improve downstream tasks such as entity
6//! type discovery, knowledge organization, and recommendation systems.
7//!
8//! # Overview
9//!
10//! The module provides four powerful clustering algorithms:
11//! - **K-Means**: Fast, spherical clusters with K-Means++ initialization
12//! - **Hierarchical**: Bottom-up agglomerative clustering with linkage methods
13//! - **DBSCAN**: Density-based clustering that discovers arbitrary shapes and handles noise
14//! - **Spectral**: Graph-based clustering using eigenvalues of similarity matrices
15//!
16//! Each algorithm is suited for different data characteristics and use cases.
17//!
18//! # Quick Start
19//!
20//! ```rust,no_run
21//! use oxirs_embed::{
22//! TransE, ModelConfig, Triple, NamedNode, EmbeddingModel,
23//! clustering::{EntityClustering, ClusteringConfig, ClusteringAlgorithm},
24//! };
25//! use std::collections::HashMap;
26//! use scirs2_core::ndarray_ext::Array1;
27//!
28//! # async fn example() -> anyhow::Result<()> {
29//! // 1. Train an embedding model
30//! let config = ModelConfig::default().with_dimensions(128);
31//! let mut model = TransE::new(config);
32//!
33//! model.add_triple(Triple::new(
34//! NamedNode::new("paris")?,
35//! NamedNode::new("capital_of")?,
36//! NamedNode::new("france")?,
37//! ))?;
38//! model.add_triple(Triple::new(
39//! NamedNode::new("london")?,
40//! NamedNode::new("capital_of")?,
41//! NamedNode::new("uk")?,
42//! ))?;
43//!
44//! model.train(Some(100)).await?;
45//!
46//! // 2. Extract embeddings
47//! let mut embeddings = HashMap::new();
48//! for entity in model.get_entities() {
49//! if let Ok(emb) = model.get_entity_embedding(&entity) {
50//! let array = Array1::from_vec(emb.values);
51//! embeddings.insert(entity, array);
52//! }
53//! }
54//!
55//! // 3. Perform clustering
56//! let cluster_config = ClusteringConfig {
57//! algorithm: ClusteringAlgorithm::KMeans,
58//! num_clusters: 3,
59//! max_iterations: 50,
60//! ..Default::default()
61//! };
62//!
63//! let mut clustering = EntityClustering::new(cluster_config);
64//! let result = clustering.cluster(&embeddings)?;
65//!
66//! println!("Silhouette score: {:.3}", result.silhouette_score);
67//! println!("Cluster 0: {} entities", result.cluster_sizes[0]);
68//! # Ok(())
69//! # }
70//! ```
71//!
72//! # Clustering Algorithms
73//!
74//! ## K-Means Clustering
75//!
76//! Fast and efficient for spherical clusters. Uses K-Means++ initialization for
77//! better convergence. Best for when you know the number of clusters.
78//!
79//! ```rust,no_run
80//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
81//! use std::collections::HashMap;
82//! use scirs2_core::ndarray_ext::Array1;
83//!
84//! # fn example() -> anyhow::Result<()> {
85//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
86//! let config = ClusteringConfig {
87//! algorithm: ClusteringAlgorithm::KMeans,
88//! num_clusters: 5,
89//! max_iterations: 100,
90//! tolerance: 0.0001,
91//! ..Default::default()
92//! };
93//!
94//! let mut clustering = EntityClustering::new(config);
95//! let result = clustering.cluster(&embeddings)?;
96//! # Ok(())
97//! # }
98//! ```
99//!
100//! ## Hierarchical Clustering
101//!
//! Builds a hierarchy of clusters using a bottom-up (agglomerative) approach.
//! The current implementation uses average linkage and keeps merging the two
//! closest clusters until `num_clusters` clusters remain.
105//!
106//! ```rust,no_run
107//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
108//! use std::collections::HashMap;
109//! use scirs2_core::ndarray_ext::Array1;
110//!
111//! # fn example() -> anyhow::Result<()> {
112//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
113//! let config = ClusteringConfig {
114//! algorithm: ClusteringAlgorithm::Hierarchical,
115//! num_clusters: 4,
116//! ..Default::default()
117//! };
118//!
119//! let mut clustering = EntityClustering::new(config);
120//! let result = clustering.cluster(&embeddings)?;
121//! # Ok(())
122//! # }
123//! ```
124//!
125//! ## DBSCAN (Density-Based Clustering)
126//!
127//! Discovers clusters of arbitrary shape and automatically identifies noise/outliers.
128//! Does not require specifying the number of clusters. Best for non-spherical clusters.
129//!
130//! ```rust,no_run
131//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
132//! use std::collections::HashMap;
133//! use scirs2_core::ndarray_ext::Array1;
134//!
135//! # fn example() -> anyhow::Result<()> {
136//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
137//! let config = ClusteringConfig {
138//! algorithm: ClusteringAlgorithm::DBSCAN,
139//! epsilon: 0.5, // Neighborhood radius
140//! min_points: 5, // Minimum points to form cluster
141//! ..Default::default()
142//! };
143//!
144//! let mut clustering = EntityClustering::new(config);
145//! let result = clustering.cluster(&embeddings)?;
146//!
147//! // Check for noise points (cluster_id == usize::MAX)
148//! let noise_count = result.assignments.values()
149//! .filter(|&&id| id == usize::MAX)
150//! .count();
151//! println!("Noise points: {}", noise_count);
152//! # Ok(())
153//! # }
154//! ```
155//!
156//! ## Spectral Clustering
157//!
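//! Graph-based clustering using eigenvalues of the similarity matrix. Effective for
//! non-convex clusters and can capture complex geometric structures.
//!
//! Note: the current implementation is a simplified approximation that runs K-Means
//! on L2-normalized embeddings; it does not perform an eigendecomposition.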
160//!
161//! ```rust,no_run
162//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
163//! use std::collections::HashMap;
164//! use scirs2_core::ndarray_ext::Array1;
165//!
166//! # fn example() -> anyhow::Result<()> {
167//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
168//! let config = ClusteringConfig {
169//! algorithm: ClusteringAlgorithm::Spectral,
170//! num_clusters: 3,
171//! ..Default::default()
172//! };
173//!
174//! let mut clustering = EntityClustering::new(config);
175//! let result = clustering.cluster(&embeddings)?;
176//! # Ok(())
177//! # }
178//! ```
179//!
180//! # Cluster Quality Metrics
181//!
182//! The module computes several metrics to assess clustering quality:
183//!
184//! ## Silhouette Score
185//!
186//! Measures how similar entities are to their own cluster compared to other clusters.
187//! Range: [-1, 1], where:
188//! - 1: Perfect clustering
189//! - 0: Overlapping clusters
190//! - -1: Incorrect clustering
191//!
192//! ```rust,no_run
193//! # use oxirs_embed::clustering::*;
194//! # use std::collections::HashMap;
195//! # use scirs2_core::ndarray_ext::Array1;
196//! # fn example() -> anyhow::Result<()> {
197//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
198//! # let mut clustering = EntityClustering::new(ClusteringConfig::default());
199//! let result = clustering.cluster(&embeddings)?;
200//!
201//! if result.silhouette_score > 0.7 {
202//! println!("Excellent clustering!");
203//! } else if result.silhouette_score > 0.5 {
204//! println!("Good clustering");
205//! } else {
206//! println!("Weak clustering - consider different parameters");
207//! }
208//! # Ok(())
209//! # }
210//! ```
211//!
212//! ## Inertia
213//!
214//! Sum of squared distances from entities to their cluster centroids.
215//! Lower values indicate tighter clusters (for K-Means).
216//!
217//! # Analyzing Cluster Results
218//!
219//! ```rust,no_run
220//! # use oxirs_embed::clustering::*;
221//! # use std::collections::HashMap;
222//! # use scirs2_core::ndarray_ext::Array1;
223//! # fn example() -> anyhow::Result<()> {
224//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
225//! # let mut clustering = EntityClustering::new(ClusteringConfig::default());
226//! let result = clustering.cluster(&embeddings)?;
227//!
228//! // Analyze cluster composition
229//! for (entity, cluster_id) in &result.assignments {
230//! println!("Entity '{}' belongs to cluster {}", entity, cluster_id);
231//! }
232//!
233//! // Cluster statistics
234//! for (i, size) in result.cluster_sizes.iter().enumerate() {
235//! println!("Cluster {}: {} entities", i, size);
236//! }
237//!
238//! // Find entities closest to cluster centroids
239//! for (cluster_id, centroid) in result.centroids.iter().enumerate() {
240//! println!("Cluster {} centroid: {:?}", cluster_id, centroid);
241//! }
242//! # Ok(())
243//! # }
244//! ```
245//!
246//! # Use Cases
247//!
248//! ## Entity Type Discovery
249//!
250//! Automatically discover entity types without explicit labels:
251//!
252//! ```text
253//! Cluster 0: [paris, london, berlin] -> Cities
254//! Cluster 1: [france, germany, uk] -> Countries
255//! Cluster 2: [euro, dollar, pound] -> Currencies
256//! ```
257//!
258//! ## Knowledge Graph Organization
259//!
260//! Group related entities for improved navigation and querying.
261//!
262//! ## Recommendation Systems
263//!
264//! Find similar users or items based on learned embeddings.
265//!
266//! ## Anomaly Detection
267//!
268//! Identify outliers using DBSCAN's noise detection (cluster_id == usize::MAX).
269//!
270//! # Performance Considerations
271//!
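//! - **K-Means**: O(n*k*d*i) where n=entities, k=clusters, d=dimensions, i=iterations
//! - **Hierarchical**: O(n^3 * d) in the current implementation (average linkage is
//!   recomputed from scratch on every merge) - slow for large datasets
//! - **DBSCAN**: O(n^2 * d) in the current implementation (neighbors are found by a
//!   linear scan; O(n * log n) would require a spatial index)
//! - **Spectral**: same cost as K-Means in the current simplified implementation; a
//!   full eigendecomposition-based version would be O(n^3)
//!
//! Every algorithm also computes the silhouette score, which evaluates all pairwise
//! distances and therefore adds O(n^2 * d) work.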
276//!
277//! For large knowledge graphs (>10,000 entities), K-Means or DBSCAN are recommended.
278//!
279//! # Choosing the Right Algorithm
280//!
281//! | Algorithm | When to Use | Pros | Cons |
282//! |--------------|-------------------------------------------|--------------------------------|-------------------------|
283//! | K-Means | Known cluster count, spherical clusters | Fast, scalable | Requires K, spherical |
//! | Hierarchical | Nested structure, small datasets         | No random initialization       | Slow, needs `num_clusters` |
285//! | DBSCAN | Arbitrary shapes, noise handling | Finds outliers, no K needed | Sensitive to parameters |
286//! | Spectral | Non-convex clusters, graph structure | Handles complex shapes | Slow, requires K |
287//!
288//! # See Also
289//!
290//! - [`EntityClustering`]: Main clustering interface
291//! - [`ClusteringConfig`]: Configuration options
292//! - [`ClusteringResult`]: Clustering results and metrics
293//! - [`ClusteringAlgorithm`]: Available algorithms
294
295use anyhow::{anyhow, Result};
296use scirs2_core::ndarray_ext::Array1;
297use scirs2_core::random::Random;
298use serde::{Deserialize, Serialize};
299use std::collections::{HashMap, HashSet};
300use tracing::{debug, info};
301
302/// Clustering algorithm type
303#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
304pub enum ClusteringAlgorithm {
305 /// K-Means clustering
306 KMeans,
307 /// Hierarchical clustering
308 Hierarchical,
309 /// DBSCAN (Density-Based Spatial Clustering)
310 DBSCAN,
311 /// Spectral clustering
312 Spectral,
313}
314
315/// Clustering configuration
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct ClusteringConfig {
318 /// Clustering algorithm to use
319 pub algorithm: ClusteringAlgorithm,
320 /// Number of clusters (for K-Means, Spectral)
321 pub num_clusters: usize,
322 /// Maximum iterations (for iterative algorithms)
323 pub max_iterations: usize,
324 /// Convergence tolerance
325 pub tolerance: f32,
326 /// Random seed for reproducibility
327 pub random_seed: Option<u64>,
328 /// DBSCAN epsilon (neighborhood radius)
329 pub epsilon: f32,
330 /// DBSCAN minimum points
331 pub min_points: usize,
332}
333
334impl Default for ClusteringConfig {
335 fn default() -> Self {
336 Self {
337 algorithm: ClusteringAlgorithm::KMeans,
338 num_clusters: 10,
339 max_iterations: 100,
340 tolerance: 1e-4,
341 random_seed: None,
342 epsilon: 0.5,
343 min_points: 5,
344 }
345 }
346}
347
348/// Clustering result
349#[derive(Debug, Clone, Serialize, Deserialize)]
350pub struct ClusteringResult {
351 /// Cluster assignments for each entity (entity_id -> cluster_id)
352 pub assignments: HashMap<String, usize>,
353 /// Cluster centroids (for K-Means, Spectral)
354 pub centroids: Vec<Array1<f32>>,
355 /// Cluster sizes
356 pub cluster_sizes: Vec<usize>,
357 /// Inertia/objective function value
358 pub inertia: f32,
359 /// Number of iterations performed
360 pub num_iterations: usize,
361 /// Silhouette score (quality metric, -1 to 1, higher is better)
362 pub silhouette_score: f32,
363}
364
365/// Entity clustering for knowledge graph embeddings
366pub struct EntityClustering {
367 config: ClusteringConfig,
368 rng: Random,
369}
370
371impl EntityClustering {
372 /// Create new entity clustering
373 pub fn new(config: ClusteringConfig) -> Self {
374 let rng = Random::default();
375
376 Self { config, rng }
377 }
378
379 /// Cluster entities based on their embeddings
380 pub fn cluster(
381 &mut self,
382 entity_embeddings: &HashMap<String, Array1<f32>>,
383 ) -> Result<ClusteringResult> {
384 if entity_embeddings.is_empty() {
385 return Err(anyhow!("No entity embeddings provided"));
386 }
387
388 info!(
389 "Clustering {} entities using {:?}",
390 entity_embeddings.len(),
391 self.config.algorithm
392 );
393
394 match self.config.algorithm {
395 ClusteringAlgorithm::KMeans => self.kmeans_clustering(entity_embeddings),
396 ClusteringAlgorithm::Hierarchical => self.hierarchical_clustering(entity_embeddings),
397 ClusteringAlgorithm::DBSCAN => self.dbscan_clustering(entity_embeddings),
398 ClusteringAlgorithm::Spectral => self.spectral_clustering(entity_embeddings),
399 }
400 }
401
402 /// K-Means clustering implementation
403 fn kmeans_clustering(
404 &mut self,
405 entity_embeddings: &HashMap<String, Array1<f32>>,
406 ) -> Result<ClusteringResult> {
407 let k = self.config.num_clusters;
408 let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
409 let n = entity_list.len();
410
411 if k > n {
412 return Err(anyhow!("Number of clusters exceeds number of entities"));
413 }
414
415 // Initialize centroids randomly
416 let dim = entity_embeddings.values().next().unwrap().len();
417 let mut centroids: Vec<Array1<f32>> = Vec::new();
418
419 // K-Means++ initialization for better convergence
420 let first_idx = self.rng.random_range(0..n);
421 centroids.push(entity_embeddings[&entity_list[first_idx]].clone());
422
423 for _ in 1..k {
424 // Compute distances to nearest centroid
425 let distances: Vec<f32> = entity_list
426 .iter()
427 .map(|entity| {
428 let emb = &entity_embeddings[entity];
429 centroids
430 .iter()
431 .map(|c| self.euclidean_distance(emb, c))
432 .fold(f32::INFINITY, f32::min)
433 .powi(2)
434 })
435 .collect();
436
437 // Sample proportional to distance squared
438 let sum: f32 = distances.iter().sum();
439 let mut prob = self.rng.random_range(0.0..sum);
440 let mut next_idx = 0;
441
442 for (i, &dist) in distances.iter().enumerate() {
443 prob -= dist;
444 if prob <= 0.0 {
445 next_idx = i;
446 break;
447 }
448 }
449
450 centroids.push(entity_embeddings[&entity_list[next_idx]].clone());
451 }
452
453 // Iterative refinement
454 let mut assignments: HashMap<String, usize> = HashMap::new();
455 let mut prev_inertia = f32::INFINITY;
456
457 for iteration in 0..self.config.max_iterations {
458 // Assignment step
459 assignments.clear();
460 for entity in &entity_list {
461 let emb = &entity_embeddings[entity];
462 let cluster = self.nearest_centroid(emb, ¢roids);
463 assignments.insert(entity.clone(), cluster);
464 }
465
466 // Update step
467 let mut new_centroids: Vec<Array1<f32>> = vec![Array1::zeros(dim); k];
468 let mut counts = vec![0; k];
469
470 for entity in &entity_list {
471 if let Some(&cluster) = assignments.get(entity) {
472 new_centroids[cluster] = &new_centroids[cluster] + &entity_embeddings[entity];
473 counts[cluster] += 1;
474 }
475 }
476
477 for (i, count) in counts.iter().enumerate() {
478 if *count > 0 {
479 new_centroids[i] = &new_centroids[i] / (*count as f32);
480 }
481 }
482
483 centroids = new_centroids;
484
485 // Compute inertia
486 let inertia =
487 self.compute_inertia(&entity_list, entity_embeddings, &assignments, ¢roids);
488
489 debug!("Iteration {}: inertia = {:.6}", iteration + 1, inertia);
490
491 // Check convergence
492 if (prev_inertia - inertia).abs() < self.config.tolerance {
493 info!("K-Means converged at iteration {}", iteration + 1);
494 break;
495 }
496
497 prev_inertia = inertia;
498 }
499
500 let final_inertia =
501 self.compute_inertia(&entity_list, entity_embeddings, &assignments, ¢roids);
502 let cluster_sizes = self.compute_cluster_sizes(&assignments, k);
503 let silhouette =
504 self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
505
506 Ok(ClusteringResult {
507 assignments,
508 centroids,
509 cluster_sizes,
510 inertia: final_inertia,
511 num_iterations: self.config.max_iterations,
512 silhouette_score: silhouette,
513 })
514 }
515
516 /// Hierarchical clustering (agglomerative)
517 fn hierarchical_clustering(
518 &mut self,
519 entity_embeddings: &HashMap<String, Array1<f32>>,
520 ) -> Result<ClusteringResult> {
521 let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
522 let n = entity_list.len();
523
524 // Start with each entity in its own cluster
525 let mut clusters: Vec<HashSet<usize>> = (0..n)
526 .map(|i| {
527 let mut set = HashSet::new();
528 set.insert(i);
529 set
530 })
531 .collect();
532
533 // Merge clusters until we reach desired number
534 while clusters.len() > self.config.num_clusters {
535 // Find closest pair of clusters
536 let (i, j) = self.find_closest_clusters(&clusters, &entity_list, entity_embeddings);
537
538 // Merge clusters
539 let cluster_j = clusters.remove(j);
540 clusters[i].extend(cluster_j);
541 }
542
543 // Convert to assignments
544 let mut assignments = HashMap::new();
545 for (cluster_id, cluster) in clusters.iter().enumerate() {
546 for &entity_idx in cluster {
547 assignments.insert(entity_list[entity_idx].clone(), cluster_id);
548 }
549 }
550
551 // Compute centroids
552 let dim = entity_embeddings.values().next().unwrap().len();
553 let mut centroids = vec![Array1::zeros(dim); self.config.num_clusters];
554 let mut counts = vec![0; self.config.num_clusters];
555
556 for (entity, &cluster) in &assignments {
557 centroids[cluster] = ¢roids[cluster] + &entity_embeddings[entity];
558 counts[cluster] += 1;
559 }
560
561 for (i, count) in counts.iter().enumerate() {
562 if *count > 0 {
563 centroids[i] = ¢roids[i] / (*count as f32);
564 }
565 }
566
567 let inertia =
568 self.compute_inertia(&entity_list, entity_embeddings, &assignments, ¢roids);
569 let cluster_sizes = self.compute_cluster_sizes(&assignments, self.config.num_clusters);
570 let silhouette =
571 self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
572
573 Ok(ClusteringResult {
574 assignments,
575 centroids,
576 cluster_sizes,
577 inertia,
578 num_iterations: n - self.config.num_clusters,
579 silhouette_score: silhouette,
580 })
581 }
582
583 /// DBSCAN clustering implementation
584 fn dbscan_clustering(
585 &mut self,
586 entity_embeddings: &HashMap<String, Array1<f32>>,
587 ) -> Result<ClusteringResult> {
588 let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
589 let n = entity_list.len();
590
591 let mut assignments: HashMap<String, usize> = HashMap::new();
592 let mut visited = HashSet::new();
593 let mut cluster_id = 0;
594
595 for i in 0..n {
596 let entity = &entity_list[i];
597 if visited.contains(&i) {
598 continue;
599 }
600
601 visited.insert(i);
602
603 // Find neighbors
604 let neighbors = self.find_neighbors(i, &entity_list, entity_embeddings);
605
606 if neighbors.len() < self.config.min_points {
607 // Mark as noise (-1 represented as max usize)
608 assignments.insert(entity.clone(), usize::MAX);
609 } else {
610 // Start new cluster
611 self.expand_cluster(
612 i,
613 &neighbors,
614 cluster_id,
615 &entity_list,
616 entity_embeddings,
617 &mut assignments,
618 &mut visited,
619 );
620 cluster_id += 1;
621 }
622 }
623
624 // Compute centroids for non-noise clusters
625 let dim = entity_embeddings.values().next().unwrap().len();
626 let mut centroids = vec![Array1::zeros(dim); cluster_id];
627 let mut counts = vec![0; cluster_id];
628
629 for (entity, &cluster) in &assignments {
630 if cluster != usize::MAX {
631 centroids[cluster] = ¢roids[cluster] + &entity_embeddings[entity];
632 counts[cluster] += 1;
633 }
634 }
635
636 for (i, count) in counts.iter().enumerate() {
637 if *count > 0 {
638 centroids[i] = ¢roids[i] / (*count as f32);
639 }
640 }
641
642 let inertia =
643 self.compute_inertia(&entity_list, entity_embeddings, &assignments, ¢roids);
644 let cluster_sizes = self.compute_cluster_sizes(&assignments, cluster_id);
645 let silhouette =
646 self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
647
648 Ok(ClusteringResult {
649 assignments,
650 centroids,
651 cluster_sizes,
652 inertia,
653 num_iterations: 1,
654 silhouette_score: silhouette,
655 })
656 }
657
658 /// Spectral clustering (simplified implementation)
659 fn spectral_clustering(
660 &mut self,
661 entity_embeddings: &HashMap<String, Array1<f32>>,
662 ) -> Result<ClusteringResult> {
663 // For simplicity, use K-Means on normalized embeddings
664 // Full spectral clustering requires eigendecomposition of graph Laplacian
665
666 let mut normalized_embeddings = HashMap::new();
667 for (entity, emb) in entity_embeddings {
668 let norm = emb.dot(emb).sqrt();
669 if norm > 0.0 {
670 normalized_embeddings.insert(entity.clone(), emb / norm);
671 } else {
672 normalized_embeddings.insert(entity.clone(), emb.clone());
673 }
674 }
675
676 self.kmeans_clustering(&normalized_embeddings)
677 }
678
679 /// Find nearest centroid for an embedding
680 fn nearest_centroid(&self, embedding: &Array1<f32>, centroids: &[Array1<f32>]) -> usize {
681 centroids
682 .iter()
683 .enumerate()
684 .map(|(i, c)| (i, self.euclidean_distance(embedding, c)))
685 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
686 .map(|(i, _)| i)
687 .unwrap_or(0)
688 }
689
690 /// Compute Euclidean distance
691 fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
692 let diff = a - b;
693 diff.dot(&diff).sqrt()
694 }
695
696 /// Compute inertia (sum of squared distances to centroids)
697 fn compute_inertia(
698 &self,
699 entity_list: &[String],
700 embeddings: &HashMap<String, Array1<f32>>,
701 assignments: &HashMap<String, usize>,
702 centroids: &[Array1<f32>],
703 ) -> f32 {
704 entity_list
705 .iter()
706 .filter_map(|entity| {
707 assignments.get(entity).and_then(|&cluster| {
708 if cluster < centroids.len() {
709 Some(
710 self.euclidean_distance(&embeddings[entity], ¢roids[cluster])
711 .powi(2),
712 )
713 } else {
714 None
715 }
716 })
717 })
718 .sum()
719 }
720
721 /// Compute cluster sizes
722 fn compute_cluster_sizes(
723 &self,
724 assignments: &HashMap<String, usize>,
725 num_clusters: usize,
726 ) -> Vec<usize> {
727 let mut sizes = vec![0; num_clusters];
728 for &cluster in assignments.values() {
729 if cluster < num_clusters {
730 sizes[cluster] += 1;
731 }
732 }
733 sizes
734 }
735
736 /// Compute silhouette score
737 fn compute_silhouette_score(
738 &self,
739 entity_list: &[String],
740 embeddings: &HashMap<String, Array1<f32>>,
741 assignments: &HashMap<String, usize>,
742 ) -> f32 {
743 if entity_list.len() < 2 {
744 return 0.0;
745 }
746
747 let scores: Vec<f32> = entity_list
748 .iter()
749 .filter_map(|entity| {
750 assignments.get(entity).map(|&cluster| {
751 let emb = &embeddings[entity];
752
753 // Compute average distance to same cluster (a)
754 let same_cluster: Vec<f32> = entity_list
755 .iter()
756 .filter_map(|other| {
757 if other != entity && assignments.get(other) == Some(&cluster) {
758 Some(self.euclidean_distance(emb, &embeddings[other]))
759 } else {
760 None
761 }
762 })
763 .collect();
764
765 let a = if !same_cluster.is_empty() {
766 same_cluster.iter().sum::<f32>() / same_cluster.len() as f32
767 } else {
768 0.0
769 };
770
771 // Compute minimum average distance to other clusters (b)
772 let unique_clusters: HashSet<usize> = assignments.values().copied().collect();
773 let b = unique_clusters
774 .iter()
775 .filter(|&&c| c != cluster)
776 .map(|&other_cluster| {
777 let distances: Vec<f32> = entity_list
778 .iter()
779 .filter_map(|other| {
780 if assignments.get(other) == Some(&other_cluster) {
781 Some(self.euclidean_distance(emb, &embeddings[other]))
782 } else {
783 None
784 }
785 })
786 .collect();
787
788 if !distances.is_empty() {
789 distances.iter().sum::<f32>() / distances.len() as f32
790 } else {
791 f32::INFINITY
792 }
793 })
794 .fold(f32::INFINITY, f32::min);
795
796 (b - a) / a.max(b).max(1e-10)
797 })
798 })
799 .collect();
800
801 if scores.is_empty() {
802 0.0
803 } else {
804 scores.iter().sum::<f32>() / scores.len() as f32
805 }
806 }
807
808 /// Find closest pair of clusters for hierarchical clustering
809 fn find_closest_clusters(
810 &self,
811 clusters: &[HashSet<usize>],
812 entity_list: &[String],
813 embeddings: &HashMap<String, Array1<f32>>,
814 ) -> (usize, usize) {
815 let mut min_dist = f32::INFINITY;
816 let mut closest_pair = (0, 1);
817
818 for i in 0..clusters.len() {
819 for j in (i + 1)..clusters.len() {
820 // Average linkage
821 let mut total_dist = 0.0;
822 let mut count = 0;
823
824 for &idx_i in &clusters[i] {
825 for &idx_j in &clusters[j] {
826 let dist = self.euclidean_distance(
827 &embeddings[&entity_list[idx_i]],
828 &embeddings[&entity_list[idx_j]],
829 );
830 total_dist += dist;
831 count += 1;
832 }
833 }
834
835 let avg_dist = if count > 0 {
836 total_dist / count as f32
837 } else {
838 f32::INFINITY
839 };
840
841 if avg_dist < min_dist {
842 min_dist = avg_dist;
843 closest_pair = (i, j);
844 }
845 }
846 }
847
848 closest_pair
849 }
850
851 /// Find neighbors within epsilon distance for DBSCAN
852 fn find_neighbors(
853 &self,
854 idx: usize,
855 entity_list: &[String],
856 embeddings: &HashMap<String, Array1<f32>>,
857 ) -> Vec<usize> {
858 let entity = &entity_list[idx];
859 let emb = &embeddings[entity];
860
861 entity_list
862 .iter()
863 .enumerate()
864 .filter_map(|(i, other)| {
865 if i != idx
866 && self.euclidean_distance(emb, &embeddings[other]) <= self.config.epsilon
867 {
868 Some(i)
869 } else {
870 None
871 }
872 })
873 .collect()
874 }
875
876 /// Expand cluster for DBSCAN
877 #[allow(clippy::too_many_arguments)]
878 fn expand_cluster(
879 &self,
880 idx: usize,
881 neighbors: &[usize],
882 cluster_id: usize,
883 entity_list: &[String],
884 embeddings: &HashMap<String, Array1<f32>>,
885 assignments: &mut HashMap<String, usize>,
886 visited: &mut HashSet<usize>,
887 ) {
888 assignments.insert(entity_list[idx].clone(), cluster_id);
889
890 let mut queue: Vec<usize> = neighbors.to_vec();
891 let mut processed = 0;
892
893 while processed < queue.len() {
894 let neighbor_idx = queue[processed];
895 processed += 1;
896
897 if !visited.contains(&neighbor_idx) {
898 visited.insert(neighbor_idx);
899
900 let neighbor_neighbors = self.find_neighbors(neighbor_idx, entity_list, embeddings);
901
902 if neighbor_neighbors.len() >= self.config.min_points {
903 queue.extend(neighbor_neighbors);
904 }
905 }
906
907 if !assignments.contains_key(&entity_list[neighbor_idx]) {
908 assignments.insert(entity_list[neighbor_idx].clone(), cluster_id);
909 }
910 }
911 }
912}
913
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Builds an embedding map from `(entity name, vector)` pairs.
    fn embedding_map(points: Vec<(&str, Array1<f32>)>) -> HashMap<String, Array1<f32>> {
        points
            .into_iter()
            .map(|(name, vector)| (name.to_string(), vector))
            .collect()
    }

    /// Two well-separated pairs of points must end up in two clusters, with each
    /// pair sharing a label.
    #[test]
    fn test_kmeans_clustering() {
        let embeddings = embedding_map(vec![
            ("e1", array![1.0, 1.0]),
            ("e2", array![1.1, 0.9]),
            ("e3", array![5.0, 5.0]),
            ("e4", array![5.1, 4.9]),
        ]);

        let mut clustering = EntityClustering::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: 2,
            ..Default::default()
        });
        let result = clustering.cluster(&embeddings).unwrap();

        assert_eq!(result.assignments.len(), 4);
        assert_eq!(result.centroids.len(), 2);
        assert_eq!(result.cluster_sizes.len(), 2);

        // Check that similar entities are in the same cluster
        assert_eq!(result.assignments["e1"], result.assignments["e2"]);
        assert_eq!(result.assignments["e3"], result.assignments["e4"]);
    }

    /// The silhouette score reported by a run must lie in its valid range.
    #[test]
    fn test_silhouette_score() {
        let embeddings = embedding_map(vec![
            ("e1", array![0.0, 0.0]),
            ("e2", array![1.0, 1.0]),
            ("e3", array![5.0, 5.0]),
        ]);

        let mut clustering = EntityClustering::new(ClusteringConfig {
            num_clusters: 2,
            ..Default::default()
        });
        let result = clustering.cluster(&embeddings).unwrap();

        // Silhouette score should be between -1 and 1
        assert!(result.silhouette_score >= -1.0 && result.silhouette_score <= 1.0);
    }
}
963}