Skip to main content

oxirs_embed/
clustering.rs

1//! Clustering Support for Knowledge Graph Embeddings
2//!
3//! This module provides comprehensive clustering algorithms for analyzing and grouping
4//! entities based on their learned embeddings. Clustering helps discover latent
5//! structure in knowledge graphs and can improve downstream tasks such as entity
6//! type discovery, knowledge organization, and recommendation systems.
7//!
8//! # Overview
9//!
//! The module provides four clustering algorithms:
//! - **K-Means**: Fast, spherical clusters with K-Means++ initialization
//! - **Hierarchical**: Bottom-up agglomerative clustering with average linkage
//! - **DBSCAN**: Density-based clustering that discovers arbitrary shapes and handles noise
//! - **Spectral**: A simplified variant that runs K-Means on L2-normalized embeddings
//!   (no eigendecomposition of a similarity matrix is performed)
15//!
16//! Each algorithm is suited for different data characteristics and use cases.
17//!
18//! # Quick Start
19//!
20//! ```rust,no_run
21//! use oxirs_embed::{
22//!     TransE, ModelConfig, Triple, NamedNode, EmbeddingModel,
23//!     clustering::{EntityClustering, ClusteringConfig, ClusteringAlgorithm},
24//! };
25//! use std::collections::HashMap;
26//! use scirs2_core::ndarray_ext::Array1;
27//!
28//! # async fn example() -> anyhow::Result<()> {
29//! // 1. Train an embedding model
30//! let config = ModelConfig::default().with_dimensions(128);
31//! let mut model = TransE::new(config);
32//!
33//! model.add_triple(Triple::new(
34//!     NamedNode::new("paris")?,
35//!     NamedNode::new("capital_of")?,
36//!     NamedNode::new("france")?,
37//! ))?;
38//! model.add_triple(Triple::new(
39//!     NamedNode::new("london")?,
40//!     NamedNode::new("capital_of")?,
41//!     NamedNode::new("uk")?,
42//! ))?;
43//!
44//! model.train(Some(100)).await?;
45//!
46//! // 2. Extract embeddings
47//! let mut embeddings = HashMap::new();
48//! for entity in model.get_entities() {
49//!     if let Ok(emb) = model.get_entity_embedding(&entity) {
50//!         let array = Array1::from_vec(emb.values);
51//!         embeddings.insert(entity, array);
52//!     }
53//! }
54//!
55//! // 3. Perform clustering
56//! let cluster_config = ClusteringConfig {
57//!     algorithm: ClusteringAlgorithm::KMeans,
58//!     num_clusters: 3,
59//!     max_iterations: 50,
60//!     ..Default::default()
61//! };
62//!
63//! let mut clustering = EntityClustering::new(cluster_config);
64//! let result = clustering.cluster(&embeddings)?;
65//!
66//! println!("Silhouette score: {:.3}", result.silhouette_score);
67//! println!("Cluster 0: {} entities", result.cluster_sizes[0]);
68//! # Ok(())
69//! # }
70//! ```
71//!
72//! # Clustering Algorithms
73//!
74//! ## K-Means Clustering
75//!
76//! Fast and efficient for spherical clusters. Uses K-Means++ initialization for
77//! better convergence. Best for when you know the number of clusters.
78//!
79//! ```rust,no_run
80//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
81//! use std::collections::HashMap;
82//! use scirs2_core::ndarray_ext::Array1;
83//!
84//! # fn example() -> anyhow::Result<()> {
85//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
86//! let config = ClusteringConfig {
87//!     algorithm: ClusteringAlgorithm::KMeans,
88//!     num_clusters: 5,
89//!     max_iterations: 100,
90//!     tolerance: 0.0001,
91//!     ..Default::default()
92//! };
93//!
94//! let mut clustering = EntityClustering::new(config);
95//! let result = clustering.cluster(&embeddings)?;
96//! # Ok(())
97//! # }
98//! ```
99//!
100//! ## Hierarchical Clustering
101//!
//! Builds a hierarchy of clusters bottom-up, repeatedly merging the two clusters with the
//! smallest average pairwise distance (average linkage) until `num_clusters` clusters
//! remain. Other linkage methods (single, complete) are not currently implemented.
105//!
106//! ```rust,no_run
107//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
108//! use std::collections::HashMap;
109//! use scirs2_core::ndarray_ext::Array1;
110//!
111//! # fn example() -> anyhow::Result<()> {
112//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
113//! let config = ClusteringConfig {
114//!     algorithm: ClusteringAlgorithm::Hierarchical,
115//!     num_clusters: 4,
116//!     ..Default::default()
117//! };
118//!
119//! let mut clustering = EntityClustering::new(config);
120//! let result = clustering.cluster(&embeddings)?;
121//! # Ok(())
122//! # }
123//! ```
124//!
125//! ## DBSCAN (Density-Based Clustering)
126//!
127//! Discovers clusters of arbitrary shape and automatically identifies noise/outliers.
128//! Does not require specifying the number of clusters. Best for non-spherical clusters.
129//!
130//! ```rust,no_run
131//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
132//! use std::collections::HashMap;
133//! use scirs2_core::ndarray_ext::Array1;
134//!
135//! # fn example() -> anyhow::Result<()> {
136//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
137//! let config = ClusteringConfig {
138//!     algorithm: ClusteringAlgorithm::DBSCAN,
139//!     epsilon: 0.5,        // Neighborhood radius
//!     min_points: 5,       // Minimum neighbors (excluding the point itself) for a core point
141//!     ..Default::default()
142//! };
143//!
144//! let mut clustering = EntityClustering::new(config);
145//! let result = clustering.cluster(&embeddings)?;
146//!
147//! // Check for noise points (cluster_id == usize::MAX)
148//! let noise_count = result.assignments.values()
149//!     .filter(|&&id| id == usize::MAX)
150//!     .count();
151//! println!("Noise points: {}", noise_count);
152//! # Ok(())
153//! # }
154//! ```
155//!
156//! ## Spectral Clustering
157//!
//! A simplified spectral-style clustering: embeddings are L2-normalized and then clustered
//! with K-Means, so entities are grouped by direction (cosine similarity) rather than by
//! magnitude. No eigendecomposition of a similarity matrix is performed.
160//!
161//! ```rust,no_run
162//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
163//! use std::collections::HashMap;
164//! use scirs2_core::ndarray_ext::Array1;
165//!
166//! # fn example() -> anyhow::Result<()> {
167//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
168//! let config = ClusteringConfig {
169//!     algorithm: ClusteringAlgorithm::Spectral,
170//!     num_clusters: 3,
171//!     ..Default::default()
172//! };
173//!
174//! let mut clustering = EntityClustering::new(config);
175//! let result = clustering.cluster(&embeddings)?;
176//! # Ok(())
177//! # }
178//! ```
179//!
180//! # Cluster Quality Metrics
181//!
182//! The module computes several metrics to assess clustering quality:
183//!
184//! ## Silhouette Score
185//!
186//! Measures how similar entities are to their own cluster compared to other clusters.
187//! Range: [-1, 1], where:
188//! - 1: Perfect clustering
189//! - 0: Overlapping clusters
190//! - -1: Incorrect clustering
191//!
192//! ```rust,no_run
193//! # use oxirs_embed::clustering::*;
194//! # use std::collections::HashMap;
195//! # use scirs2_core::ndarray_ext::Array1;
196//! # fn example() -> anyhow::Result<()> {
197//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
198//! # let mut clustering = EntityClustering::new(ClusteringConfig::default());
199//! let result = clustering.cluster(&embeddings)?;
200//!
201//! if result.silhouette_score > 0.7 {
202//!     println!("Excellent clustering!");
203//! } else if result.silhouette_score > 0.5 {
204//!     println!("Good clustering");
205//! } else {
206//!     println!("Weak clustering - consider different parameters");
207//! }
208//! # Ok(())
209//! # }
210//! ```
211//!
212//! ## Inertia
213//!
214//! Sum of squared distances from entities to their cluster centroids.
215//! Lower values indicate tighter clusters (for K-Means).
216//!
217//! # Analyzing Cluster Results
218//!
219//! ```rust,no_run
220//! # use oxirs_embed::clustering::*;
221//! # use std::collections::HashMap;
222//! # use scirs2_core::ndarray_ext::Array1;
223//! # fn example() -> anyhow::Result<()> {
224//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
225//! # let mut clustering = EntityClustering::new(ClusteringConfig::default());
226//! let result = clustering.cluster(&embeddings)?;
227//!
228//! // Analyze cluster composition
229//! for (entity, cluster_id) in &result.assignments {
230//!     println!("Entity '{}' belongs to cluster {}", entity, cluster_id);
231//! }
232//!
233//! // Cluster statistics
234//! for (i, size) in result.cluster_sizes.iter().enumerate() {
235//!     println!("Cluster {}: {} entities", i, size);
236//! }
237//!
//! // Inspect cluster centroids
239//! for (cluster_id, centroid) in result.centroids.iter().enumerate() {
240//!     println!("Cluster {} centroid: {:?}", cluster_id, centroid);
241//! }
242//! # Ok(())
243//! # }
244//! ```
245//!
246//! # Use Cases
247//!
248//! ## Entity Type Discovery
249//!
250//! Automatically discover entity types without explicit labels:
251//!
252//! ```text
253//! Cluster 0: [paris, london, berlin]  -> Cities
254//! Cluster 1: [france, germany, uk]    -> Countries
255//! Cluster 2: [euro, dollar, pound]    -> Currencies
256//! ```
257//!
258//! ## Knowledge Graph Organization
259//!
260//! Group related entities for improved navigation and querying.
261//!
262//! ## Recommendation Systems
263//!
264//! Find similar users or items based on learned embeddings.
265//!
266//! ## Anomaly Detection
267//!
268//! Identify outliers using DBSCAN's noise detection (cluster_id == usize::MAX).
269//!
270//! # Performance Considerations
271//!
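//! - **K-Means**: O(n*k*d*i) where n=entities, k=clusters, d=dimensions, i=iterations
//! - **Hierarchical**: at least O(n^3) - naive average-linkage merging, slow for large datasets
//! - **DBSCAN**: O(n^2 * d) - neighbor search is a linear scan (no spatial index)
//! - **Spectral**: same as K-Means, plus one normalization pass over the embeddings
//!
//! Every run also computes the silhouette score, which compares pairs of entities and so
//! costs at least O(n^2 * d) whatever the algorithm. For large knowledge graphs
//! (>10,000 entities), prefer K-Means and expect the quality metric to be a significant cost.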
278//!
279//! # Choosing the Right Algorithm
280//!
281//! | Algorithm    | When to Use                                | Pros                           | Cons                    |
282//! |--------------|-------------------------------------------|--------------------------------|-------------------------|
283//! | K-Means      | Known cluster count, spherical clusters   | Fast, scalable                 | Requires K, spherical   |
//! | Hierarchical | Nested structure, small datasets          | No random initialization       | Slow, requires K        |
//! | DBSCAN       | Arbitrary shapes, noise handling          | Finds outliers, no K needed    | Sensitive to parameters |
//! | Spectral     | Direction matters more than magnitude     | Ignores embedding scale        | Requires K, simplified  |
287//!
288//! # See Also
289//!
290//! - [`EntityClustering`]: Main clustering interface
291//! - [`ClusteringConfig`]: Configuration options
292//! - [`ClusteringResult`]: Clustering results and metrics
293//! - [`ClusteringAlgorithm`]: Available algorithms
294
295use anyhow::{anyhow, Result};
296use scirs2_core::ndarray_ext::Array1;
297use scirs2_core::random::Random;
298use serde::{Deserialize, Serialize};
299use std::collections::{HashMap, HashSet};
300use tracing::{debug, info};
301
302/// Clustering algorithm type
303#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
304pub enum ClusteringAlgorithm {
305    /// K-Means clustering
306    KMeans,
307    /// Hierarchical clustering
308    Hierarchical,
309    /// DBSCAN (Density-Based Spatial Clustering)
310    DBSCAN,
311    /// Spectral clustering
312    Spectral,
313}
314
315/// Clustering configuration
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct ClusteringConfig {
318    /// Clustering algorithm to use
319    pub algorithm: ClusteringAlgorithm,
320    /// Number of clusters (for K-Means, Spectral)
321    pub num_clusters: usize,
322    /// Maximum iterations (for iterative algorithms)
323    pub max_iterations: usize,
324    /// Convergence tolerance
325    pub tolerance: f32,
326    /// Random seed for reproducibility
327    pub random_seed: Option<u64>,
328    /// DBSCAN epsilon (neighborhood radius)
329    pub epsilon: f32,
330    /// DBSCAN minimum points
331    pub min_points: usize,
332}
333
334impl Default for ClusteringConfig {
335    fn default() -> Self {
336        Self {
337            algorithm: ClusteringAlgorithm::KMeans,
338            num_clusters: 10,
339            max_iterations: 100,
340            tolerance: 1e-4,
341            random_seed: None,
342            epsilon: 0.5,
343            min_points: 5,
344        }
345    }
346}
347
348/// Clustering result
349#[derive(Debug, Clone, Serialize, Deserialize)]
350pub struct ClusteringResult {
351    /// Cluster assignments for each entity (entity_id -> cluster_id)
352    pub assignments: HashMap<String, usize>,
353    /// Cluster centroids (for K-Means, Spectral)
354    pub centroids: Vec<Array1<f32>>,
355    /// Cluster sizes
356    pub cluster_sizes: Vec<usize>,
357    /// Inertia/objective function value
358    pub inertia: f32,
359    /// Number of iterations performed
360    pub num_iterations: usize,
361    /// Silhouette score (quality metric, -1 to 1, higher is better)
362    pub silhouette_score: f32,
363}
364
365/// Entity clustering for knowledge graph embeddings
366pub struct EntityClustering {
367    config: ClusteringConfig,
368    rng: Random,
369}
370
371impl EntityClustering {
372    /// Create new entity clustering
373    pub fn new(config: ClusteringConfig) -> Self {
374        let rng = Random::default();
375
376        Self { config, rng }
377    }
378
379    /// Cluster entities based on their embeddings
380    pub fn cluster(
381        &mut self,
382        entity_embeddings: &HashMap<String, Array1<f32>>,
383    ) -> Result<ClusteringResult> {
384        if entity_embeddings.is_empty() {
385            return Err(anyhow!("No entity embeddings provided"));
386        }
387
388        info!(
389            "Clustering {} entities using {:?}",
390            entity_embeddings.len(),
391            self.config.algorithm
392        );
393
394        match self.config.algorithm {
395            ClusteringAlgorithm::KMeans => self.kmeans_clustering(entity_embeddings),
396            ClusteringAlgorithm::Hierarchical => self.hierarchical_clustering(entity_embeddings),
397            ClusteringAlgorithm::DBSCAN => self.dbscan_clustering(entity_embeddings),
398            ClusteringAlgorithm::Spectral => self.spectral_clustering(entity_embeddings),
399        }
400    }
401
402    /// K-Means clustering implementation
403    fn kmeans_clustering(
404        &mut self,
405        entity_embeddings: &HashMap<String, Array1<f32>>,
406    ) -> Result<ClusteringResult> {
407        let k = self.config.num_clusters;
408        let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
409        let n = entity_list.len();
410
411        if k > n {
412            return Err(anyhow!("Number of clusters exceeds number of entities"));
413        }
414
415        // Initialize centroids randomly
416        let dim = entity_embeddings
417            .values()
418            .next()
419            .expect("entity_embeddings should not be empty")
420            .len();
421        let mut centroids: Vec<Array1<f32>> = Vec::new();
422
423        // K-Means++ initialization for better convergence
424        let first_idx = self.rng.random_range(0..n);
425        centroids.push(entity_embeddings[&entity_list[first_idx]].clone());
426
427        for _ in 1..k {
428            // Compute distances to nearest centroid
429            let distances: Vec<f32> = entity_list
430                .iter()
431                .map(|entity| {
432                    let emb = &entity_embeddings[entity];
433                    centroids
434                        .iter()
435                        .map(|c| self.euclidean_distance(emb, c))
436                        .fold(f32::INFINITY, f32::min)
437                        .powi(2)
438                })
439                .collect();
440
441            // Sample proportional to distance squared
442            let sum: f32 = distances.iter().sum();
443            let mut prob = self.rng.random_range(0.0..sum);
444            let mut next_idx = 0;
445
446            for (i, &dist) in distances.iter().enumerate() {
447                prob -= dist;
448                if prob <= 0.0 {
449                    next_idx = i;
450                    break;
451                }
452            }
453
454            centroids.push(entity_embeddings[&entity_list[next_idx]].clone());
455        }
456
457        // Iterative refinement
458        let mut assignments: HashMap<String, usize> = HashMap::new();
459        let mut prev_inertia = f32::INFINITY;
460
461        for iteration in 0..self.config.max_iterations {
462            // Assignment step
463            assignments.clear();
464            for entity in &entity_list {
465                let emb = &entity_embeddings[entity];
466                let cluster = self.nearest_centroid(emb, &centroids);
467                assignments.insert(entity.clone(), cluster);
468            }
469
470            // Update step
471            let mut new_centroids: Vec<Array1<f32>> = vec![Array1::zeros(dim); k];
472            let mut counts = vec![0; k];
473
474            for entity in &entity_list {
475                if let Some(&cluster) = assignments.get(entity) {
476                    new_centroids[cluster] = &new_centroids[cluster] + &entity_embeddings[entity];
477                    counts[cluster] += 1;
478                }
479            }
480
481            for (i, count) in counts.iter().enumerate() {
482                if *count > 0 {
483                    new_centroids[i] = &new_centroids[i] / (*count as f32);
484                }
485            }
486
487            centroids = new_centroids;
488
489            // Compute inertia
490            let inertia =
491                self.compute_inertia(&entity_list, entity_embeddings, &assignments, &centroids);
492
493            debug!("Iteration {}: inertia = {:.6}", iteration + 1, inertia);
494
495            // Check convergence
496            if (prev_inertia - inertia).abs() < self.config.tolerance {
497                info!("K-Means converged at iteration {}", iteration + 1);
498                break;
499            }
500
501            prev_inertia = inertia;
502        }
503
504        let final_inertia =
505            self.compute_inertia(&entity_list, entity_embeddings, &assignments, &centroids);
506        let cluster_sizes = self.compute_cluster_sizes(&assignments, k);
507        let silhouette =
508            self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
509
510        Ok(ClusteringResult {
511            assignments,
512            centroids,
513            cluster_sizes,
514            inertia: final_inertia,
515            num_iterations: self.config.max_iterations,
516            silhouette_score: silhouette,
517        })
518    }
519
520    /// Hierarchical clustering (agglomerative)
521    fn hierarchical_clustering(
522        &mut self,
523        entity_embeddings: &HashMap<String, Array1<f32>>,
524    ) -> Result<ClusteringResult> {
525        let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
526        let n = entity_list.len();
527
528        // Start with each entity in its own cluster
529        let mut clusters: Vec<HashSet<usize>> = (0..n)
530            .map(|i| {
531                let mut set = HashSet::new();
532                set.insert(i);
533                set
534            })
535            .collect();
536
537        // Merge clusters until we reach desired number
538        while clusters.len() > self.config.num_clusters {
539            // Find closest pair of clusters
540            let (i, j) = self.find_closest_clusters(&clusters, &entity_list, entity_embeddings);
541
542            // Merge clusters
543            let cluster_j = clusters.remove(j);
544            clusters[i].extend(cluster_j);
545        }
546
547        // Convert to assignments
548        let mut assignments = HashMap::new();
549        for (cluster_id, cluster) in clusters.iter().enumerate() {
550            for &entity_idx in cluster {
551                assignments.insert(entity_list[entity_idx].clone(), cluster_id);
552            }
553        }
554
555        // Compute centroids
556        let dim = entity_embeddings
557            .values()
558            .next()
559            .expect("entity_embeddings should not be empty")
560            .len();
561        let mut centroids = vec![Array1::zeros(dim); self.config.num_clusters];
562        let mut counts = vec![0; self.config.num_clusters];
563
564        for (entity, &cluster) in &assignments {
565            centroids[cluster] = &centroids[cluster] + &entity_embeddings[entity];
566            counts[cluster] += 1;
567        }
568
569        for (i, count) in counts.iter().enumerate() {
570            if *count > 0 {
571                centroids[i] = &centroids[i] / (*count as f32);
572            }
573        }
574
575        let inertia =
576            self.compute_inertia(&entity_list, entity_embeddings, &assignments, &centroids);
577        let cluster_sizes = self.compute_cluster_sizes(&assignments, self.config.num_clusters);
578        let silhouette =
579            self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
580
581        Ok(ClusteringResult {
582            assignments,
583            centroids,
584            cluster_sizes,
585            inertia,
586            num_iterations: n - self.config.num_clusters,
587            silhouette_score: silhouette,
588        })
589    }
590
591    /// DBSCAN clustering implementation
592    fn dbscan_clustering(
593        &mut self,
594        entity_embeddings: &HashMap<String, Array1<f32>>,
595    ) -> Result<ClusteringResult> {
596        let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
597        let n = entity_list.len();
598
599        let mut assignments: HashMap<String, usize> = HashMap::new();
600        let mut visited = HashSet::new();
601        let mut cluster_id = 0;
602
603        for i in 0..n {
604            let entity = &entity_list[i];
605            if visited.contains(&i) {
606                continue;
607            }
608
609            visited.insert(i);
610
611            // Find neighbors
612            let neighbors = self.find_neighbors(i, &entity_list, entity_embeddings);
613
614            if neighbors.len() < self.config.min_points {
615                // Mark as noise (-1 represented as max usize)
616                assignments.insert(entity.clone(), usize::MAX);
617            } else {
618                // Start new cluster
619                self.expand_cluster(
620                    i,
621                    &neighbors,
622                    cluster_id,
623                    &entity_list,
624                    entity_embeddings,
625                    &mut assignments,
626                    &mut visited,
627                );
628                cluster_id += 1;
629            }
630        }
631
632        // Compute centroids for non-noise clusters
633        let dim = entity_embeddings
634            .values()
635            .next()
636            .expect("entity_embeddings should not be empty")
637            .len();
638        let mut centroids = vec![Array1::zeros(dim); cluster_id];
639        let mut counts = vec![0; cluster_id];
640
641        for (entity, &cluster) in &assignments {
642            if cluster != usize::MAX {
643                centroids[cluster] = &centroids[cluster] + &entity_embeddings[entity];
644                counts[cluster] += 1;
645            }
646        }
647
648        for (i, count) in counts.iter().enumerate() {
649            if *count > 0 {
650                centroids[i] = &centroids[i] / (*count as f32);
651            }
652        }
653
654        let inertia =
655            self.compute_inertia(&entity_list, entity_embeddings, &assignments, &centroids);
656        let cluster_sizes = self.compute_cluster_sizes(&assignments, cluster_id);
657        let silhouette =
658            self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
659
660        Ok(ClusteringResult {
661            assignments,
662            centroids,
663            cluster_sizes,
664            inertia,
665            num_iterations: 1,
666            silhouette_score: silhouette,
667        })
668    }
669
670    /// Spectral clustering (simplified implementation)
671    fn spectral_clustering(
672        &mut self,
673        entity_embeddings: &HashMap<String, Array1<f32>>,
674    ) -> Result<ClusteringResult> {
675        // For simplicity, use K-Means on normalized embeddings
676        // Full spectral clustering requires eigendecomposition of graph Laplacian
677
678        let mut normalized_embeddings = HashMap::new();
679        for (entity, emb) in entity_embeddings {
680            let norm = emb.dot(emb).sqrt();
681            if norm > 0.0 {
682                normalized_embeddings.insert(entity.clone(), emb / norm);
683            } else {
684                normalized_embeddings.insert(entity.clone(), emb.clone());
685            }
686        }
687
688        self.kmeans_clustering(&normalized_embeddings)
689    }
690
691    /// Find nearest centroid for an embedding
692    fn nearest_centroid(&self, embedding: &Array1<f32>, centroids: &[Array1<f32>]) -> usize {
693        centroids
694            .iter()
695            .enumerate()
696            .map(|(i, c)| (i, self.euclidean_distance(embedding, c)))
697            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
698            .map(|(i, _)| i)
699            .unwrap_or(0)
700    }
701
702    /// Compute Euclidean distance
703    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
704        let diff = a - b;
705        diff.dot(&diff).sqrt()
706    }
707
708    /// Compute inertia (sum of squared distances to centroids)
709    fn compute_inertia(
710        &self,
711        entity_list: &[String],
712        embeddings: &HashMap<String, Array1<f32>>,
713        assignments: &HashMap<String, usize>,
714        centroids: &[Array1<f32>],
715    ) -> f32 {
716        entity_list
717            .iter()
718            .filter_map(|entity| {
719                assignments.get(entity).and_then(|&cluster| {
720                    if cluster < centroids.len() {
721                        Some(
722                            self.euclidean_distance(&embeddings[entity], &centroids[cluster])
723                                .powi(2),
724                        )
725                    } else {
726                        None
727                    }
728                })
729            })
730            .sum()
731    }
732
733    /// Compute cluster sizes
734    fn compute_cluster_sizes(
735        &self,
736        assignments: &HashMap<String, usize>,
737        num_clusters: usize,
738    ) -> Vec<usize> {
739        let mut sizes = vec![0; num_clusters];
740        for &cluster in assignments.values() {
741            if cluster < num_clusters {
742                sizes[cluster] += 1;
743            }
744        }
745        sizes
746    }
747
748    /// Compute silhouette score
749    fn compute_silhouette_score(
750        &self,
751        entity_list: &[String],
752        embeddings: &HashMap<String, Array1<f32>>,
753        assignments: &HashMap<String, usize>,
754    ) -> f32 {
755        if entity_list.len() < 2 {
756            return 0.0;
757        }
758
759        let scores: Vec<f32> = entity_list
760            .iter()
761            .filter_map(|entity| {
762                assignments.get(entity).map(|&cluster| {
763                    let emb = &embeddings[entity];
764
765                    // Compute average distance to same cluster (a)
766                    let same_cluster: Vec<f32> = entity_list
767                        .iter()
768                        .filter_map(|other| {
769                            if other != entity && assignments.get(other) == Some(&cluster) {
770                                Some(self.euclidean_distance(emb, &embeddings[other]))
771                            } else {
772                                None
773                            }
774                        })
775                        .collect();
776
777                    let a = if !same_cluster.is_empty() {
778                        same_cluster.iter().sum::<f32>() / same_cluster.len() as f32
779                    } else {
780                        0.0
781                    };
782
783                    // Compute minimum average distance to other clusters (b)
784                    let unique_clusters: HashSet<usize> = assignments.values().copied().collect();
785                    let b = unique_clusters
786                        .iter()
787                        .filter(|&&c| c != cluster)
788                        .map(|&other_cluster| {
789                            let distances: Vec<f32> = entity_list
790                                .iter()
791                                .filter_map(|other| {
792                                    if assignments.get(other) == Some(&other_cluster) {
793                                        Some(self.euclidean_distance(emb, &embeddings[other]))
794                                    } else {
795                                        None
796                                    }
797                                })
798                                .collect();
799
800                            if !distances.is_empty() {
801                                distances.iter().sum::<f32>() / distances.len() as f32
802                            } else {
803                                f32::INFINITY
804                            }
805                        })
806                        .fold(f32::INFINITY, f32::min);
807
808                    (b - a) / a.max(b).max(1e-10)
809                })
810            })
811            .collect();
812
813        if scores.is_empty() {
814            0.0
815        } else {
816            scores.iter().sum::<f32>() / scores.len() as f32
817        }
818    }
819
820    /// Find closest pair of clusters for hierarchical clustering
821    fn find_closest_clusters(
822        &self,
823        clusters: &[HashSet<usize>],
824        entity_list: &[String],
825        embeddings: &HashMap<String, Array1<f32>>,
826    ) -> (usize, usize) {
827        let mut min_dist = f32::INFINITY;
828        let mut closest_pair = (0, 1);
829
830        for i in 0..clusters.len() {
831            for j in (i + 1)..clusters.len() {
832                // Average linkage
833                let mut total_dist = 0.0;
834                let mut count = 0;
835
836                for &idx_i in &clusters[i] {
837                    for &idx_j in &clusters[j] {
838                        let dist = self.euclidean_distance(
839                            &embeddings[&entity_list[idx_i]],
840                            &embeddings[&entity_list[idx_j]],
841                        );
842                        total_dist += dist;
843                        count += 1;
844                    }
845                }
846
847                let avg_dist = if count > 0 {
848                    total_dist / count as f32
849                } else {
850                    f32::INFINITY
851                };
852
853                if avg_dist < min_dist {
854                    min_dist = avg_dist;
855                    closest_pair = (i, j);
856                }
857            }
858        }
859
860        closest_pair
861    }
862
863    /// Find neighbors within epsilon distance for DBSCAN
864    fn find_neighbors(
865        &self,
866        idx: usize,
867        entity_list: &[String],
868        embeddings: &HashMap<String, Array1<f32>>,
869    ) -> Vec<usize> {
870        let entity = &entity_list[idx];
871        let emb = &embeddings[entity];
872
873        entity_list
874            .iter()
875            .enumerate()
876            .filter_map(|(i, other)| {
877                if i != idx
878                    && self.euclidean_distance(emb, &embeddings[other]) <= self.config.epsilon
879                {
880                    Some(i)
881                } else {
882                    None
883                }
884            })
885            .collect()
886    }
887
888    /// Expand cluster for DBSCAN
889    #[allow(clippy::too_many_arguments)]
890    fn expand_cluster(
891        &self,
892        idx: usize,
893        neighbors: &[usize],
894        cluster_id: usize,
895        entity_list: &[String],
896        embeddings: &HashMap<String, Array1<f32>>,
897        assignments: &mut HashMap<String, usize>,
898        visited: &mut HashSet<usize>,
899    ) {
900        assignments.insert(entity_list[idx].clone(), cluster_id);
901
902        let mut queue: Vec<usize> = neighbors.to_vec();
903        let mut processed = 0;
904
905        while processed < queue.len() {
906            let neighbor_idx = queue[processed];
907            processed += 1;
908
909            if !visited.contains(&neighbor_idx) {
910                visited.insert(neighbor_idx);
911
912                let neighbor_neighbors = self.find_neighbors(neighbor_idx, entity_list, embeddings);
913
914                if neighbor_neighbors.len() >= self.config.min_points {
915                    queue.extend(neighbor_neighbors);
916                }
917            }
918
919            if !assignments.contains_key(&entity_list[neighbor_idx]) {
920                assignments.insert(entity_list[neighbor_idx].clone(), cluster_id);
921            }
922        }
923    }
924}
925
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Two tight, well-separated pairs: {e1, e2} near (1, 1) and {e3, e4} near (5, 5).
    fn two_pairs() -> HashMap<String, Array1<f32>> {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 1.0]);
        embeddings.insert("e2".to_string(), array![1.1, 0.9]);
        embeddings.insert("e3".to_string(), array![5.0, 5.0]);
        embeddings.insert("e4".to_string(), array![5.1, 4.9]);
        embeddings
    }

    #[test]
    fn test_kmeans_clustering() {
        let embeddings = two_pairs();

        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: 2,
            ..Default::default()
        };

        let mut clustering = EntityClustering::new(config);
        let result = clustering.cluster(&embeddings).unwrap();

        assert_eq!(result.assignments.len(), 4);
        assert_eq!(result.centroids.len(), 2);
        assert_eq!(result.cluster_sizes.len(), 2);

        // Check that similar entities are in the same cluster
        assert_eq!(result.assignments["e1"], result.assignments["e2"]);
        assert_eq!(result.assignments["e3"], result.assignments["e4"]);
    }

    #[test]
    fn test_kmeans_rejects_more_clusters_than_entities() {
        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: 5,
            ..Default::default()
        };

        let mut clustering = EntityClustering::new(config);
        assert!(clustering.cluster(&two_pairs()).is_err());
    }

    #[test]
    fn test_hierarchical_clustering() {
        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::Hierarchical,
            num_clusters: 2,
            ..Default::default()
        };

        let mut clustering = EntityClustering::new(config);
        let result = clustering.cluster(&two_pairs()).unwrap();

        assert_eq!(result.assignments["e1"], result.assignments["e2"]);
        assert_eq!(result.assignments["e3"], result.assignments["e4"]);
        assert_ne!(result.assignments["e1"], result.assignments["e3"]);
        assert_eq!(result.cluster_sizes, vec![2, 2]);
        // Four singletons merged down to two clusters.
        assert_eq!(result.num_iterations, 2);
    }

    #[test]
    fn test_dbscan_marks_isolated_point_as_noise() {
        let mut embeddings: HashMap<String, Array1<f32>> = HashMap::new();
        embeddings.insert("a1".to_string(), array![0.0, 0.0]);
        embeddings.insert("a2".to_string(), array![0.1, 0.0]);
        embeddings.insert("a3".to_string(), array![0.0, 0.1]);
        embeddings.insert("b1".to_string(), array![10.0, 10.0]);
        embeddings.insert("b2".to_string(), array![10.1, 10.0]);
        embeddings.insert("b3".to_string(), array![10.0, 10.1]);
        embeddings.insert("outlier".to_string(), array![50.0, -50.0]);

        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::DBSCAN,
            epsilon: 0.5,
            min_points: 2,
            ..Default::default()
        };

        let mut clustering = EntityClustering::new(config);
        let result = clustering.cluster(&embeddings).unwrap();

        assert_eq!(result.assignments["outlier"], usize::MAX);

        let a = result.assignments["a1"];
        let b = result.assignments["b1"];
        assert_ne!(a, usize::MAX);
        assert_ne!(b, usize::MAX);
        assert_ne!(a, b);
        assert_eq!(result.assignments["a2"], a);
        assert_eq!(result.assignments["a3"], a);
        assert_eq!(result.assignments["b2"], b);
        assert_eq!(result.assignments["b3"], b);
        assert_eq!(result.cluster_sizes, vec![3, 3]);
    }

    #[test]
    fn test_silhouette_score() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![0.0, 0.0]);
        embeddings.insert("e2".to_string(), array![1.0, 1.0]);
        embeddings.insert("e3".to_string(), array![5.0, 5.0]);

        let config = ClusteringConfig {
            num_clusters: 2,
            ..Default::default()
        };

        let mut clustering = EntityClustering::new(config);
        let result = clustering.cluster(&embeddings).unwrap();

        // Silhouette score should be between -1 and 1
        assert!(result.silhouette_score >= -1.0 && result.silhouette_score <= 1.0);
    }
}
975}