oxirs_embed/clustering.rs
1//! Clustering Support for Knowledge Graph Embeddings
2//!
3//! This module provides comprehensive clustering algorithms for analyzing and grouping
4//! entities based on their learned embeddings. Clustering helps discover latent
5//! structure in knowledge graphs and can improve downstream tasks such as entity
6//! type discovery, knowledge organization, and recommendation systems.
7//!
8//! # Overview
9//!
10//! The module provides four powerful clustering algorithms:
11//! - **K-Means**: Fast, spherical clusters with K-Means++ initialization
12//! - **Hierarchical**: Bottom-up agglomerative clustering with linkage methods
13//! - **DBSCAN**: Density-based clustering that discovers arbitrary shapes and handles noise
14//! - **Spectral**: Graph-based clustering using eigenvalues of similarity matrices
15//!
16//! Each algorithm is suited for different data characteristics and use cases.
17//!
18//! # Quick Start
19//!
20//! ```rust,no_run
21//! use oxirs_embed::{
22//! TransE, ModelConfig, Triple, NamedNode, EmbeddingModel,
23//! clustering::{EntityClustering, ClusteringConfig, ClusteringAlgorithm},
24//! };
25//! use std::collections::HashMap;
26//! use scirs2_core::ndarray_ext::Array1;
27//!
28//! # async fn example() -> anyhow::Result<()> {
29//! // 1. Train an embedding model
30//! let config = ModelConfig::default().with_dimensions(128);
31//! let mut model = TransE::new(config);
32//!
33//! model.add_triple(Triple::new(
34//! NamedNode::new("paris")?,
35//! NamedNode::new("capital_of")?,
36//! NamedNode::new("france")?,
37//! ))?;
38//! model.add_triple(Triple::new(
39//! NamedNode::new("london")?,
40//! NamedNode::new("capital_of")?,
41//! NamedNode::new("uk")?,
42//! ))?;
43//!
44//! model.train(Some(100)).await?;
45//!
46//! // 2. Extract embeddings
47//! let mut embeddings = HashMap::new();
48//! for entity in model.get_entities() {
49//! if let Ok(emb) = model.get_entity_embedding(&entity) {
50//! let array = Array1::from_vec(emb.values);
51//! embeddings.insert(entity, array);
52//! }
53//! }
54//!
55//! // 3. Perform clustering
56//! let cluster_config = ClusteringConfig {
57//! algorithm: ClusteringAlgorithm::KMeans,
58//! num_clusters: 3,
59//! max_iterations: 50,
60//! ..Default::default()
61//! };
62//!
63//! let mut clustering = EntityClustering::new(cluster_config);
64//! let result = clustering.cluster(&embeddings)?;
65//!
66//! println!("Silhouette score: {:.3}", result.silhouette_score);
67//! println!("Cluster 0: {} entities", result.cluster_sizes[0]);
68//! # Ok(())
69//! # }
70//! ```
71//!
72//! # Clustering Algorithms
73//!
74//! ## K-Means Clustering
75//!
76//! Fast and efficient for spherical clusters. Uses K-Means++ initialization for
77//! better convergence. Best for when you know the number of clusters.
78//!
79//! ```rust,no_run
80//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
81//! use std::collections::HashMap;
82//! use scirs2_core::ndarray_ext::Array1;
83//!
84//! # fn example() -> anyhow::Result<()> {
85//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
86//! let config = ClusteringConfig {
87//! algorithm: ClusteringAlgorithm::KMeans,
88//! num_clusters: 5,
89//! max_iterations: 100,
90//! tolerance: 0.0001,
91//! ..Default::default()
92//! };
93//!
94//! let mut clustering = EntityClustering::new(config);
95//! let result = clustering.cluster(&embeddings)?;
96//! # Ok(())
97//! # }
98//! ```
99//!
100//! ## Hierarchical Clustering
101//!
//! Builds clusters bottom-up (agglomerative): every entity starts in its own
//! cluster and the closest pair of clusters is merged, using average linkage,
//! until `num_clusters` clusters remain.
105//!
106//! ```rust,no_run
107//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
108//! use std::collections::HashMap;
109//! use scirs2_core::ndarray_ext::Array1;
110//!
111//! # fn example() -> anyhow::Result<()> {
112//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
113//! let config = ClusteringConfig {
114//! algorithm: ClusteringAlgorithm::Hierarchical,
115//! num_clusters: 4,
116//! ..Default::default()
117//! };
118//!
119//! let mut clustering = EntityClustering::new(config);
120//! let result = clustering.cluster(&embeddings)?;
121//! # Ok(())
122//! # }
123//! ```
124//!
125//! ## DBSCAN (Density-Based Clustering)
126//!
127//! Discovers clusters of arbitrary shape and automatically identifies noise/outliers.
128//! Does not require specifying the number of clusters. Best for non-spherical clusters.
129//!
130//! ```rust,no_run
131//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
132//! use std::collections::HashMap;
133//! use scirs2_core::ndarray_ext::Array1;
134//!
135//! # fn example() -> anyhow::Result<()> {
136//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
137//! let config = ClusteringConfig {
138//! algorithm: ClusteringAlgorithm::DBSCAN,
139//! epsilon: 0.5, // Neighborhood radius
140//! min_points: 5, // Minimum points to form cluster
141//! ..Default::default()
142//! };
143//!
144//! let mut clustering = EntityClustering::new(config);
145//! let result = clustering.cluster(&embeddings)?;
146//!
147//! // Check for noise points (cluster_id == usize::MAX)
148//! let noise_count = result.assignments.values()
149//! .filter(|&&id| id == usize::MAX)
150//! .count();
151//! println!("Noise points: {}", noise_count);
152//! # Ok(())
153//! # }
154//! ```
155//!
156//! ## Spectral Clustering
157//!
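//! Spectral clustering is normally a graph-based method built on the eigenvectors of a
//! similarity-graph Laplacian, which makes it effective for non-convex clusters and
//! complex geometric structures. The current implementation is a simplification: it
//! L2-normalizes the embeddings and runs K-Means on them, without an eigendecomposition.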
160//!
161//! ```rust,no_run
162//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
163//! use std::collections::HashMap;
164//! use scirs2_core::ndarray_ext::Array1;
165//!
166//! # fn example() -> anyhow::Result<()> {
167//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
168//! let config = ClusteringConfig {
169//! algorithm: ClusteringAlgorithm::Spectral,
170//! num_clusters: 3,
171//! ..Default::default()
172//! };
173//!
174//! let mut clustering = EntityClustering::new(config);
175//! let result = clustering.cluster(&embeddings)?;
176//! # Ok(())
177//! # }
178//! ```
179//!
180//! # Cluster Quality Metrics
181//!
182//! The module computes several metrics to assess clustering quality:
183//!
184//! ## Silhouette Score
185//!
186//! Measures how similar entities are to their own cluster compared to other clusters.
187//! Range: [-1, 1], where:
188//! - 1: Perfect clustering
189//! - 0: Overlapping clusters
190//! - -1: Incorrect clustering
191//!
192//! ```rust,no_run
193//! # use oxirs_embed::clustering::*;
194//! # use std::collections::HashMap;
195//! # use scirs2_core::ndarray_ext::Array1;
196//! # fn example() -> anyhow::Result<()> {
197//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
198//! # let mut clustering = EntityClustering::new(ClusteringConfig::default());
199//! let result = clustering.cluster(&embeddings)?;
200//!
201//! if result.silhouette_score > 0.7 {
202//! println!("Excellent clustering!");
203//! } else if result.silhouette_score > 0.5 {
204//! println!("Good clustering");
205//! } else {
206//! println!("Weak clustering - consider different parameters");
207//! }
208//! # Ok(())
209//! # }
210//! ```
211//!
212//! ## Inertia
213//!
214//! Sum of squared distances from entities to their cluster centroids.
215//! Lower values indicate tighter clusters (for K-Means).
216//!
217//! # Analyzing Cluster Results
218//!
219//! ```rust,no_run
220//! # use oxirs_embed::clustering::*;
221//! # use std::collections::HashMap;
222//! # use scirs2_core::ndarray_ext::Array1;
223//! # fn example() -> anyhow::Result<()> {
224//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
225//! # let mut clustering = EntityClustering::new(ClusteringConfig::default());
226//! let result = clustering.cluster(&embeddings)?;
227//!
228//! // Analyze cluster composition
229//! for (entity, cluster_id) in &result.assignments {
230//! println!("Entity '{}' belongs to cluster {}", entity, cluster_id);
231//! }
232//!
233//! // Cluster statistics
234//! for (i, size) in result.cluster_sizes.iter().enumerate() {
235//! println!("Cluster {}: {} entities", i, size);
236//! }
237//!
238//! // Find entities closest to cluster centroids
239//! for (cluster_id, centroid) in result.centroids.iter().enumerate() {
240//! println!("Cluster {} centroid: {:?}", cluster_id, centroid);
241//! }
242//! # Ok(())
243//! # }
244//! ```
245//!
246//! # Use Cases
247//!
248//! ## Entity Type Discovery
249//!
250//! Automatically discover entity types without explicit labels:
251//!
252//! ```text
253//! Cluster 0: [paris, london, berlin] -> Cities
254//! Cluster 1: [france, germany, uk] -> Countries
255//! Cluster 2: [euro, dollar, pound] -> Currencies
256//! ```
257//!
258//! ## Knowledge Graph Organization
259//!
260//! Group related entities for improved navigation and querying.
261//!
262//! ## Recommendation Systems
263//!
264//! Find similar users or items based on learned embeddings.
265//!
266//! ## Anomaly Detection
267//!
268//! Identify outliers using DBSCAN's noise detection (cluster_id == usize::MAX).
269//!
270//! # Performance Considerations
271//!
272//! - **K-Means**: O(n*k*d*i) where n=entities, k=clusters, d=dimensions, i=iterations
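//! - **Hierarchical**: roughly O(n^3) as implemented (cross-cluster distances are recomputed at every merge) - slow for large datasets
//! - **DBSCAN**: O(n^2) as implemented (neighborhoods are found by linear scan; O(n * log n) would require spatial indexing)
//! - **Spectral**: same cost as K-Means as implemented (normalization followed by K-Means); a full eigendecomposition would be O(n^3)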
276//!
277//! For large knowledge graphs (>10,000 entities), K-Means or DBSCAN are recommended.
278//!
279//! # Choosing the Right Algorithm
280//!
281//! | Algorithm | When to Use | Pros | Cons |
282//! |--------------|-------------------------------------------|--------------------------------|-------------------------|
283//! | K-Means | Known cluster count, spherical clusters | Fast, scalable | Requires K, spherical |
284//! | Hierarchical | Nested structure, small datasets | No K needed, hierarchical | Slow, memory intensive |
285//! | DBSCAN | Arbitrary shapes, noise handling | Finds outliers, no K needed | Sensitive to parameters |
286//! | Spectral | Non-convex clusters, graph structure | Handles complex shapes | Slow, requires K |
287//!
288//! # See Also
289//!
290//! - [`EntityClustering`]: Main clustering interface
291//! - [`ClusteringConfig`]: Configuration options
292//! - [`ClusteringResult`]: Clustering results and metrics
293//! - [`ClusteringAlgorithm`]: Available algorithms
294
295use anyhow::{anyhow, Result};
296use scirs2_core::ndarray_ext::Array1;
297use scirs2_core::random::Random;
298use serde::{Deserialize, Serialize};
299use std::collections::{HashMap, HashSet};
300use tracing::{debug, info};
301
/// Clustering algorithm type.
///
/// Selects which routine [`EntityClustering::cluster`] dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClusteringAlgorithm {
    /// K-Means clustering with K-Means++ seeding; uses `num_clusters`,
    /// `max_iterations` and `tolerance` from the configuration.
    KMeans,
    /// Agglomerative (bottom-up) hierarchical clustering with average linkage;
    /// clusters are merged until `num_clusters` remain.
    Hierarchical,
    /// DBSCAN (Density-Based Spatial Clustering); driven by `epsilon` and
    /// `min_points`. Noise points are reported with cluster id `usize::MAX`.
    DBSCAN,
    /// Spectral clustering. The current implementation is a simplification:
    /// embeddings are L2-normalized and then clustered with K-Means.
    Spectral,
}
314
/// Clustering configuration.
///
/// Not every field is used by every algorithm; see the per-field notes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringConfig {
    /// Clustering algorithm to use
    pub algorithm: ClusteringAlgorithm,
    /// Number of clusters (used by K-Means and Spectral, and as the stopping
    /// point for Hierarchical merging; not read by DBSCAN)
    pub num_clusters: usize,
    /// Maximum iterations (for iterative algorithms; read by the K-Means loop)
    pub max_iterations: usize,
    /// Convergence tolerance: K-Means stops once the absolute change in inertia
    /// between two consecutive iterations drops below this value
    pub tolerance: f32,
    /// Random seed for reproducibility.
    ///
    /// NOTE(review): nothing in the code visible in this file reads this field;
    /// `EntityClustering::new` always builds its generator with `Random::default()`.
    pub random_seed: Option<u64>,
    /// DBSCAN epsilon (neighborhood radius, compared against Euclidean distance)
    pub epsilon: f32,
    /// DBSCAN minimum points: number of neighbors (the point itself is not
    /// counted) required for a point to seed or extend a cluster
    pub min_points: usize,
}
333
334impl Default for ClusteringConfig {
335 fn default() -> Self {
336 Self {
337 algorithm: ClusteringAlgorithm::KMeans,
338 num_clusters: 10,
339 max_iterations: 100,
340 tolerance: 1e-4,
341 random_seed: None,
342 epsilon: 0.5,
343 min_points: 5,
344 }
345 }
346}
347
/// Clustering result.
///
/// Returned by [`EntityClustering::cluster`] for every algorithm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResult {
    /// Cluster assignments for each entity (entity_id -> cluster_id).
    /// DBSCAN marks noise points with the sentinel `usize::MAX`.
    pub assignments: HashMap<String, usize>,
    /// Cluster centroids (mean embedding of each cluster's members).
    /// DBSCAN produces one centroid per discovered cluster; noise is excluded.
    pub centroids: Vec<Array1<f32>>,
    /// Cluster sizes, indexed by cluster id (noise points are not counted)
    pub cluster_sizes: Vec<usize>,
    /// Inertia/objective function value: sum of squared Euclidean distances
    /// from each assigned entity to its cluster centroid
    pub inertia: f32,
    /// Number of iterations performed (algorithm-specific; DBSCAN always
    /// reports 1)
    pub num_iterations: usize,
    /// Silhouette score (quality metric, -1 to 1, higher is better)
    pub silhouette_score: f32,
}
364
/// Entity clustering for knowledge graph embeddings.
///
/// Holds the clustering configuration together with the random generator
/// used for K-Means++ seeding.
pub struct EntityClustering {
    /// Algorithm choice and parameters supplied at construction time.
    config: ClusteringConfig,
    /// Random source used when picking initial K-Means centroids.
    rng: Random,
}
370
371impl EntityClustering {
372 /// Create new entity clustering
373 pub fn new(config: ClusteringConfig) -> Self {
374 let rng = Random::default();
375
376 Self { config, rng }
377 }
378
379 /// Cluster entities based on their embeddings
380 pub fn cluster(
381 &mut self,
382 entity_embeddings: &HashMap<String, Array1<f32>>,
383 ) -> Result<ClusteringResult> {
384 if entity_embeddings.is_empty() {
385 return Err(anyhow!("No entity embeddings provided"));
386 }
387
388 info!(
389 "Clustering {} entities using {:?}",
390 entity_embeddings.len(),
391 self.config.algorithm
392 );
393
394 match self.config.algorithm {
395 ClusteringAlgorithm::KMeans => self.kmeans_clustering(entity_embeddings),
396 ClusteringAlgorithm::Hierarchical => self.hierarchical_clustering(entity_embeddings),
397 ClusteringAlgorithm::DBSCAN => self.dbscan_clustering(entity_embeddings),
398 ClusteringAlgorithm::Spectral => self.spectral_clustering(entity_embeddings),
399 }
400 }
401
402 /// K-Means clustering implementation
403 fn kmeans_clustering(
404 &mut self,
405 entity_embeddings: &HashMap<String, Array1<f32>>,
406 ) -> Result<ClusteringResult> {
407 let k = self.config.num_clusters;
408 let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
409 let n = entity_list.len();
410
411 if k > n {
412 return Err(anyhow!("Number of clusters exceeds number of entities"));
413 }
414
415 // Initialize centroids randomly
416 let dim = entity_embeddings
417 .values()
418 .next()
419 .expect("entity_embeddings should not be empty")
420 .len();
421 let mut centroids: Vec<Array1<f32>> = Vec::new();
422
423 // K-Means++ initialization for better convergence
424 let first_idx = self.rng.random_range(0..n);
425 centroids.push(entity_embeddings[&entity_list[first_idx]].clone());
426
427 for _ in 1..k {
428 // Compute distances to nearest centroid
429 let distances: Vec<f32> = entity_list
430 .iter()
431 .map(|entity| {
432 let emb = &entity_embeddings[entity];
433 centroids
434 .iter()
435 .map(|c| self.euclidean_distance(emb, c))
436 .fold(f32::INFINITY, f32::min)
437 .powi(2)
438 })
439 .collect();
440
441 // Sample proportional to distance squared
442 let sum: f32 = distances.iter().sum();
443 let mut prob = self.rng.random_range(0.0..sum);
444 let mut next_idx = 0;
445
446 for (i, &dist) in distances.iter().enumerate() {
447 prob -= dist;
448 if prob <= 0.0 {
449 next_idx = i;
450 break;
451 }
452 }
453
454 centroids.push(entity_embeddings[&entity_list[next_idx]].clone());
455 }
456
457 // Iterative refinement
458 let mut assignments: HashMap<String, usize> = HashMap::new();
459 let mut prev_inertia = f32::INFINITY;
460
461 for iteration in 0..self.config.max_iterations {
462 // Assignment step
463 assignments.clear();
464 for entity in &entity_list {
465 let emb = &entity_embeddings[entity];
466 let cluster = self.nearest_centroid(emb, ¢roids);
467 assignments.insert(entity.clone(), cluster);
468 }
469
470 // Update step
471 let mut new_centroids: Vec<Array1<f32>> = vec![Array1::zeros(dim); k];
472 let mut counts = vec![0; k];
473
474 for entity in &entity_list {
475 if let Some(&cluster) = assignments.get(entity) {
476 new_centroids[cluster] = &new_centroids[cluster] + &entity_embeddings[entity];
477 counts[cluster] += 1;
478 }
479 }
480
481 for (i, count) in counts.iter().enumerate() {
482 if *count > 0 {
483 new_centroids[i] = &new_centroids[i] / (*count as f32);
484 }
485 }
486
487 centroids = new_centroids;
488
489 // Compute inertia
490 let inertia =
491 self.compute_inertia(&entity_list, entity_embeddings, &assignments, ¢roids);
492
493 debug!("Iteration {}: inertia = {:.6}", iteration + 1, inertia);
494
495 // Check convergence
496 if (prev_inertia - inertia).abs() < self.config.tolerance {
497 info!("K-Means converged at iteration {}", iteration + 1);
498 break;
499 }
500
501 prev_inertia = inertia;
502 }
503
504 let final_inertia =
505 self.compute_inertia(&entity_list, entity_embeddings, &assignments, ¢roids);
506 let cluster_sizes = self.compute_cluster_sizes(&assignments, k);
507 let silhouette =
508 self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
509
510 Ok(ClusteringResult {
511 assignments,
512 centroids,
513 cluster_sizes,
514 inertia: final_inertia,
515 num_iterations: self.config.max_iterations,
516 silhouette_score: silhouette,
517 })
518 }
519
520 /// Hierarchical clustering (agglomerative)
521 fn hierarchical_clustering(
522 &mut self,
523 entity_embeddings: &HashMap<String, Array1<f32>>,
524 ) -> Result<ClusteringResult> {
525 let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
526 let n = entity_list.len();
527
528 // Start with each entity in its own cluster
529 let mut clusters: Vec<HashSet<usize>> = (0..n)
530 .map(|i| {
531 let mut set = HashSet::new();
532 set.insert(i);
533 set
534 })
535 .collect();
536
537 // Merge clusters until we reach desired number
538 while clusters.len() > self.config.num_clusters {
539 // Find closest pair of clusters
540 let (i, j) = self.find_closest_clusters(&clusters, &entity_list, entity_embeddings);
541
542 // Merge clusters
543 let cluster_j = clusters.remove(j);
544 clusters[i].extend(cluster_j);
545 }
546
547 // Convert to assignments
548 let mut assignments = HashMap::new();
549 for (cluster_id, cluster) in clusters.iter().enumerate() {
550 for &entity_idx in cluster {
551 assignments.insert(entity_list[entity_idx].clone(), cluster_id);
552 }
553 }
554
555 // Compute centroids
556 let dim = entity_embeddings
557 .values()
558 .next()
559 .expect("entity_embeddings should not be empty")
560 .len();
561 let mut centroids = vec![Array1::zeros(dim); self.config.num_clusters];
562 let mut counts = vec![0; self.config.num_clusters];
563
564 for (entity, &cluster) in &assignments {
565 centroids[cluster] = ¢roids[cluster] + &entity_embeddings[entity];
566 counts[cluster] += 1;
567 }
568
569 for (i, count) in counts.iter().enumerate() {
570 if *count > 0 {
571 centroids[i] = ¢roids[i] / (*count as f32);
572 }
573 }
574
575 let inertia =
576 self.compute_inertia(&entity_list, entity_embeddings, &assignments, ¢roids);
577 let cluster_sizes = self.compute_cluster_sizes(&assignments, self.config.num_clusters);
578 let silhouette =
579 self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
580
581 Ok(ClusteringResult {
582 assignments,
583 centroids,
584 cluster_sizes,
585 inertia,
586 num_iterations: n - self.config.num_clusters,
587 silhouette_score: silhouette,
588 })
589 }
590
591 /// DBSCAN clustering implementation
592 fn dbscan_clustering(
593 &mut self,
594 entity_embeddings: &HashMap<String, Array1<f32>>,
595 ) -> Result<ClusteringResult> {
596 let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
597 let n = entity_list.len();
598
599 let mut assignments: HashMap<String, usize> = HashMap::new();
600 let mut visited = HashSet::new();
601 let mut cluster_id = 0;
602
603 for i in 0..n {
604 let entity = &entity_list[i];
605 if visited.contains(&i) {
606 continue;
607 }
608
609 visited.insert(i);
610
611 // Find neighbors
612 let neighbors = self.find_neighbors(i, &entity_list, entity_embeddings);
613
614 if neighbors.len() < self.config.min_points {
615 // Mark as noise (-1 represented as max usize)
616 assignments.insert(entity.clone(), usize::MAX);
617 } else {
618 // Start new cluster
619 self.expand_cluster(
620 i,
621 &neighbors,
622 cluster_id,
623 &entity_list,
624 entity_embeddings,
625 &mut assignments,
626 &mut visited,
627 );
628 cluster_id += 1;
629 }
630 }
631
632 // Compute centroids for non-noise clusters
633 let dim = entity_embeddings
634 .values()
635 .next()
636 .expect("entity_embeddings should not be empty")
637 .len();
638 let mut centroids = vec![Array1::zeros(dim); cluster_id];
639 let mut counts = vec![0; cluster_id];
640
641 for (entity, &cluster) in &assignments {
642 if cluster != usize::MAX {
643 centroids[cluster] = ¢roids[cluster] + &entity_embeddings[entity];
644 counts[cluster] += 1;
645 }
646 }
647
648 for (i, count) in counts.iter().enumerate() {
649 if *count > 0 {
650 centroids[i] = ¢roids[i] / (*count as f32);
651 }
652 }
653
654 let inertia =
655 self.compute_inertia(&entity_list, entity_embeddings, &assignments, ¢roids);
656 let cluster_sizes = self.compute_cluster_sizes(&assignments, cluster_id);
657 let silhouette =
658 self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
659
660 Ok(ClusteringResult {
661 assignments,
662 centroids,
663 cluster_sizes,
664 inertia,
665 num_iterations: 1,
666 silhouette_score: silhouette,
667 })
668 }
669
670 /// Spectral clustering (simplified implementation)
671 fn spectral_clustering(
672 &mut self,
673 entity_embeddings: &HashMap<String, Array1<f32>>,
674 ) -> Result<ClusteringResult> {
675 // For simplicity, use K-Means on normalized embeddings
676 // Full spectral clustering requires eigendecomposition of graph Laplacian
677
678 let mut normalized_embeddings = HashMap::new();
679 for (entity, emb) in entity_embeddings {
680 let norm = emb.dot(emb).sqrt();
681 if norm > 0.0 {
682 normalized_embeddings.insert(entity.clone(), emb / norm);
683 } else {
684 normalized_embeddings.insert(entity.clone(), emb.clone());
685 }
686 }
687
688 self.kmeans_clustering(&normalized_embeddings)
689 }
690
691 /// Find nearest centroid for an embedding
692 fn nearest_centroid(&self, embedding: &Array1<f32>, centroids: &[Array1<f32>]) -> usize {
693 centroids
694 .iter()
695 .enumerate()
696 .map(|(i, c)| (i, self.euclidean_distance(embedding, c)))
697 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
698 .map(|(i, _)| i)
699 .unwrap_or(0)
700 }
701
702 /// Compute Euclidean distance
703 fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
704 let diff = a - b;
705 diff.dot(&diff).sqrt()
706 }
707
708 /// Compute inertia (sum of squared distances to centroids)
709 fn compute_inertia(
710 &self,
711 entity_list: &[String],
712 embeddings: &HashMap<String, Array1<f32>>,
713 assignments: &HashMap<String, usize>,
714 centroids: &[Array1<f32>],
715 ) -> f32 {
716 entity_list
717 .iter()
718 .filter_map(|entity| {
719 assignments.get(entity).and_then(|&cluster| {
720 if cluster < centroids.len() {
721 Some(
722 self.euclidean_distance(&embeddings[entity], ¢roids[cluster])
723 .powi(2),
724 )
725 } else {
726 None
727 }
728 })
729 })
730 .sum()
731 }
732
733 /// Compute cluster sizes
734 fn compute_cluster_sizes(
735 &self,
736 assignments: &HashMap<String, usize>,
737 num_clusters: usize,
738 ) -> Vec<usize> {
739 let mut sizes = vec![0; num_clusters];
740 for &cluster in assignments.values() {
741 if cluster < num_clusters {
742 sizes[cluster] += 1;
743 }
744 }
745 sizes
746 }
747
748 /// Compute silhouette score
749 fn compute_silhouette_score(
750 &self,
751 entity_list: &[String],
752 embeddings: &HashMap<String, Array1<f32>>,
753 assignments: &HashMap<String, usize>,
754 ) -> f32 {
755 if entity_list.len() < 2 {
756 return 0.0;
757 }
758
759 let scores: Vec<f32> = entity_list
760 .iter()
761 .filter_map(|entity| {
762 assignments.get(entity).map(|&cluster| {
763 let emb = &embeddings[entity];
764
765 // Compute average distance to same cluster (a)
766 let same_cluster: Vec<f32> = entity_list
767 .iter()
768 .filter_map(|other| {
769 if other != entity && assignments.get(other) == Some(&cluster) {
770 Some(self.euclidean_distance(emb, &embeddings[other]))
771 } else {
772 None
773 }
774 })
775 .collect();
776
777 let a = if !same_cluster.is_empty() {
778 same_cluster.iter().sum::<f32>() / same_cluster.len() as f32
779 } else {
780 0.0
781 };
782
783 // Compute minimum average distance to other clusters (b)
784 let unique_clusters: HashSet<usize> = assignments.values().copied().collect();
785 let b = unique_clusters
786 .iter()
787 .filter(|&&c| c != cluster)
788 .map(|&other_cluster| {
789 let distances: Vec<f32> = entity_list
790 .iter()
791 .filter_map(|other| {
792 if assignments.get(other) == Some(&other_cluster) {
793 Some(self.euclidean_distance(emb, &embeddings[other]))
794 } else {
795 None
796 }
797 })
798 .collect();
799
800 if !distances.is_empty() {
801 distances.iter().sum::<f32>() / distances.len() as f32
802 } else {
803 f32::INFINITY
804 }
805 })
806 .fold(f32::INFINITY, f32::min);
807
808 (b - a) / a.max(b).max(1e-10)
809 })
810 })
811 .collect();
812
813 if scores.is_empty() {
814 0.0
815 } else {
816 scores.iter().sum::<f32>() / scores.len() as f32
817 }
818 }
819
820 /// Find closest pair of clusters for hierarchical clustering
821 fn find_closest_clusters(
822 &self,
823 clusters: &[HashSet<usize>],
824 entity_list: &[String],
825 embeddings: &HashMap<String, Array1<f32>>,
826 ) -> (usize, usize) {
827 let mut min_dist = f32::INFINITY;
828 let mut closest_pair = (0, 1);
829
830 for i in 0..clusters.len() {
831 for j in (i + 1)..clusters.len() {
832 // Average linkage
833 let mut total_dist = 0.0;
834 let mut count = 0;
835
836 for &idx_i in &clusters[i] {
837 for &idx_j in &clusters[j] {
838 let dist = self.euclidean_distance(
839 &embeddings[&entity_list[idx_i]],
840 &embeddings[&entity_list[idx_j]],
841 );
842 total_dist += dist;
843 count += 1;
844 }
845 }
846
847 let avg_dist = if count > 0 {
848 total_dist / count as f32
849 } else {
850 f32::INFINITY
851 };
852
853 if avg_dist < min_dist {
854 min_dist = avg_dist;
855 closest_pair = (i, j);
856 }
857 }
858 }
859
860 closest_pair
861 }
862
863 /// Find neighbors within epsilon distance for DBSCAN
864 fn find_neighbors(
865 &self,
866 idx: usize,
867 entity_list: &[String],
868 embeddings: &HashMap<String, Array1<f32>>,
869 ) -> Vec<usize> {
870 let entity = &entity_list[idx];
871 let emb = &embeddings[entity];
872
873 entity_list
874 .iter()
875 .enumerate()
876 .filter_map(|(i, other)| {
877 if i != idx
878 && self.euclidean_distance(emb, &embeddings[other]) <= self.config.epsilon
879 {
880 Some(i)
881 } else {
882 None
883 }
884 })
885 .collect()
886 }
887
888 /// Expand cluster for DBSCAN
889 #[allow(clippy::too_many_arguments)]
890 fn expand_cluster(
891 &self,
892 idx: usize,
893 neighbors: &[usize],
894 cluster_id: usize,
895 entity_list: &[String],
896 embeddings: &HashMap<String, Array1<f32>>,
897 assignments: &mut HashMap<String, usize>,
898 visited: &mut HashSet<usize>,
899 ) {
900 assignments.insert(entity_list[idx].clone(), cluster_id);
901
902 let mut queue: Vec<usize> = neighbors.to_vec();
903 let mut processed = 0;
904
905 while processed < queue.len() {
906 let neighbor_idx = queue[processed];
907 processed += 1;
908
909 if !visited.contains(&neighbor_idx) {
910 visited.insert(neighbor_idx);
911
912 let neighbor_neighbors = self.find_neighbors(neighbor_idx, entity_list, embeddings);
913
914 if neighbor_neighbors.len() >= self.config.min_points {
915 queue.extend(neighbor_neighbors);
916 }
917 }
918
919 if !assignments.contains_key(&entity_list[neighbor_idx]) {
920 assignments.insert(entity_list[neighbor_idx].clone(), cluster_id);
921 }
922 }
923 }
924}
925
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Turn `(name, embedding)` pairs into the map shape `cluster` expects.
    fn embeddings_from(points: Vec<(&str, Array1<f32>)>) -> HashMap<String, Array1<f32>> {
        points
            .into_iter()
            .map(|(name, emb)| (name.to_string(), emb))
            .collect()
    }

    /// Two well-separated pairs of points must end up in two clusters, with
    /// each pair sharing a cluster id.
    #[test]
    fn test_kmeans_clustering() {
        let embeddings = embeddings_from(vec![
            ("e1", array![1.0, 1.0]),
            ("e2", array![1.1, 0.9]),
            ("e3", array![5.0, 5.0]),
            ("e4", array![5.1, 4.9]),
        ]);

        let mut clustering = EntityClustering::new(ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: 2,
            ..Default::default()
        });
        let result = clustering.cluster(&embeddings).unwrap();

        assert_eq!(result.assignments.len(), 4);
        assert_eq!(result.centroids.len(), 2);
        assert_eq!(result.cluster_sizes.len(), 2);

        // Check that similar entities are in the same cluster
        assert_eq!(result.assignments["e1"], result.assignments["e2"]);
        assert_eq!(result.assignments["e3"], result.assignments["e4"]);
    }

    /// The reported silhouette score must stay inside its defined range.
    #[test]
    fn test_silhouette_score() {
        let embeddings = embeddings_from(vec![
            ("e1", array![0.0, 0.0]),
            ("e2", array![1.0, 1.0]),
            ("e3", array![5.0, 5.0]),
        ]);

        let mut clustering = EntityClustering::new(ClusteringConfig {
            num_clusters: 2,
            ..Default::default()
        });
        let result = clustering.cluster(&embeddings).unwrap();

        // Silhouette score should be between -1 and 1
        assert!((-1.0..=1.0).contains(&result.silhouette_score));
    }
}
975}