oxirs_embed/
clustering.rs

1//! Clustering Support for Knowledge Graph Embeddings
2//!
3//! This module provides comprehensive clustering algorithms for analyzing and grouping
4//! entities based on their learned embeddings. Clustering helps discover latent
5//! structure in knowledge graphs and can improve downstream tasks such as entity
6//! type discovery, knowledge organization, and recommendation systems.
7//!
8//! # Overview
9//!
10//! The module provides four powerful clustering algorithms:
11//! - **K-Means**: Fast, spherical clusters with K-Means++ initialization
//! - **Hierarchical**: Bottom-up agglomerative clustering with average linkage
13//! - **DBSCAN**: Density-based clustering that discovers arbitrary shapes and handles noise
//! - **Spectral** (simplified): K-Means on L2-normalized embeddings, approximating spectral clustering
15//!
16//! Each algorithm is suited for different data characteristics and use cases.
17//!
18//! # Quick Start
19//!
20//! ```rust,no_run
21//! use oxirs_embed::{
22//!     TransE, ModelConfig, Triple, NamedNode, EmbeddingModel,
23//!     clustering::{EntityClustering, ClusteringConfig, ClusteringAlgorithm},
24//! };
25//! use std::collections::HashMap;
26//! use scirs2_core::ndarray_ext::Array1;
27//!
28//! # async fn example() -> anyhow::Result<()> {
29//! // 1. Train an embedding model
30//! let config = ModelConfig::default().with_dimensions(128);
31//! let mut model = TransE::new(config);
32//!
33//! model.add_triple(Triple::new(
34//!     NamedNode::new("paris")?,
35//!     NamedNode::new("capital_of")?,
36//!     NamedNode::new("france")?,
37//! ))?;
38//! model.add_triple(Triple::new(
39//!     NamedNode::new("london")?,
40//!     NamedNode::new("capital_of")?,
41//!     NamedNode::new("uk")?,
42//! ))?;
43//!
44//! model.train(Some(100)).await?;
45//!
46//! // 2. Extract embeddings
47//! let mut embeddings = HashMap::new();
48//! for entity in model.get_entities() {
49//!     if let Ok(emb) = model.get_entity_embedding(&entity) {
50//!         let array = Array1::from_vec(emb.values);
51//!         embeddings.insert(entity, array);
52//!     }
53//! }
54//!
55//! // 3. Perform clustering
56//! let cluster_config = ClusteringConfig {
57//!     algorithm: ClusteringAlgorithm::KMeans,
58//!     num_clusters: 3,
59//!     max_iterations: 50,
60//!     ..Default::default()
61//! };
62//!
63//! let mut clustering = EntityClustering::new(cluster_config);
64//! let result = clustering.cluster(&embeddings)?;
65//!
66//! println!("Silhouette score: {:.3}", result.silhouette_score);
67//! println!("Cluster 0: {} entities", result.cluster_sizes[0]);
68//! # Ok(())
69//! # }
70//! ```
71//!
72//! # Clustering Algorithms
73//!
74//! ## K-Means Clustering
75//!
76//! Fast and efficient for spherical clusters. Uses K-Means++ initialization for
//! better convergence. Best when you know the number of clusters.
78//!
79//! ```rust,no_run
80//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
81//! use std::collections::HashMap;
82//! use scirs2_core::ndarray_ext::Array1;
83//!
84//! # fn example() -> anyhow::Result<()> {
85//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
86//! let config = ClusteringConfig {
87//!     algorithm: ClusteringAlgorithm::KMeans,
88//!     num_clusters: 5,
89//!     max_iterations: 100,
90//!     tolerance: 0.0001,
91//!     ..Default::default()
92//! };
93//!
94//! let mut clustering = EntityClustering::new(config);
95//! let result = clustering.cluster(&embeddings)?;
96//! # Ok(())
97//! # }
98//! ```
99//!
100//! ## Hierarchical Clustering
101//!
102//! Builds a hierarchy of clusters using bottom-up approach. Supports different
103//! linkage methods (single, average, complete). Does not require specifying the
104//! number of clusters upfront.
105//!
106//! ```rust,no_run
107//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
108//! use std::collections::HashMap;
109//! use scirs2_core::ndarray_ext::Array1;
110//!
111//! # fn example() -> anyhow::Result<()> {
112//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
113//! let config = ClusteringConfig {
114//!     algorithm: ClusteringAlgorithm::Hierarchical,
115//!     num_clusters: 4,
116//!     ..Default::default()
117//! };
118//!
119//! let mut clustering = EntityClustering::new(config);
120//! let result = clustering.cluster(&embeddings)?;
121//! # Ok(())
122//! # }
123//! ```
124//!
125//! ## DBSCAN (Density-Based Clustering)
126//!
127//! Discovers clusters of arbitrary shape and automatically identifies noise/outliers.
128//! Does not require specifying the number of clusters. Best for non-spherical clusters.
129//!
130//! ```rust,no_run
131//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
132//! use std::collections::HashMap;
133//! use scirs2_core::ndarray_ext::Array1;
134//!
135//! # fn example() -> anyhow::Result<()> {
136//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
137//! let config = ClusteringConfig {
138//!     algorithm: ClusteringAlgorithm::DBSCAN,
139//!     epsilon: 0.5,        // Neighborhood radius
//!     min_points: 5,       // Minimum neighbours (excluding the point) for a core point
141//!     ..Default::default()
142//! };
143//!
144//! let mut clustering = EntityClustering::new(config);
145//! let result = clustering.cluster(&embeddings)?;
146//!
147//! // Check for noise points (cluster_id == usize::MAX)
148//! let noise_count = result.assignments.values()
149//!     .filter(|&&id| id == usize::MAX)
150//!     .count();
151//! println!("Noise points: {}", noise_count);
152//! # Ok(())
153//! # }
154//! ```
155//!
156//! ## Spectral Clustering
157//!
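//! Intended as graph-based clustering on a similarity matrix. The current implementation
//! is simplified: it L2-normalizes each embedding and runs K-Means on the result (no
//! eigendecomposition of a graph Laplacian is performed), so it groups entities by
//! embedding direction rather than magnitude.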
160//!
161//! ```rust,no_run
162//! use oxirs_embed::clustering::{ClusteringConfig, ClusteringAlgorithm, EntityClustering};
163//! use std::collections::HashMap;
164//! use scirs2_core::ndarray_ext::Array1;
165//!
166//! # fn example() -> anyhow::Result<()> {
167//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
168//! let config = ClusteringConfig {
169//!     algorithm: ClusteringAlgorithm::Spectral,
170//!     num_clusters: 3,
171//!     ..Default::default()
172//! };
173//!
174//! let mut clustering = EntityClustering::new(config);
175//! let result = clustering.cluster(&embeddings)?;
176//! # Ok(())
177//! # }
178//! ```
179//!
180//! # Cluster Quality Metrics
181//!
182//! The module computes several metrics to assess clustering quality:
183//!
184//! ## Silhouette Score
185//!
186//! Measures how similar entities are to their own cluster compared to other clusters.
187//! Range: [-1, 1], where:
188//! - 1: Perfect clustering
189//! - 0: Overlapping clusters
190//! - -1: Incorrect clustering
191//!
192//! ```rust,no_run
193//! # use oxirs_embed::clustering::*;
194//! # use std::collections::HashMap;
195//! # use scirs2_core::ndarray_ext::Array1;
196//! # fn example() -> anyhow::Result<()> {
197//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
198//! # let mut clustering = EntityClustering::new(ClusteringConfig::default());
199//! let result = clustering.cluster(&embeddings)?;
200//!
201//! if result.silhouette_score > 0.7 {
202//!     println!("Excellent clustering!");
203//! } else if result.silhouette_score > 0.5 {
204//!     println!("Good clustering");
205//! } else {
206//!     println!("Weak clustering - consider different parameters");
207//! }
208//! # Ok(())
209//! # }
210//! ```
211//!
212//! ## Inertia
213//!
214//! Sum of squared distances from entities to their cluster centroids.
215//! Lower values indicate tighter clusters (for K-Means).
216//!
217//! # Analyzing Cluster Results
218//!
219//! ```rust,no_run
220//! # use oxirs_embed::clustering::*;
221//! # use std::collections::HashMap;
222//! # use scirs2_core::ndarray_ext::Array1;
223//! # fn example() -> anyhow::Result<()> {
224//! # let embeddings: HashMap<String, Array1<f32>> = HashMap::new();
225//! # let mut clustering = EntityClustering::new(ClusteringConfig::default());
226//! let result = clustering.cluster(&embeddings)?;
227//!
228//! // Analyze cluster composition
229//! for (entity, cluster_id) in &result.assignments {
230//!     println!("Entity '{}' belongs to cluster {}", entity, cluster_id);
231//! }
232//!
233//! // Cluster statistics
234//! for (i, size) in result.cluster_sizes.iter().enumerate() {
235//!     println!("Cluster {}: {} entities", i, size);
236//! }
237//!
//! // Inspect cluster centroids
239//! for (cluster_id, centroid) in result.centroids.iter().enumerate() {
240//!     println!("Cluster {} centroid: {:?}", cluster_id, centroid);
241//! }
242//! # Ok(())
243//! # }
244//! ```
245//!
246//! # Use Cases
247//!
248//! ## Entity Type Discovery
249//!
250//! Automatically discover entity types without explicit labels:
251//!
252//! ```text
253//! Cluster 0: [paris, london, berlin]  -> Cities
254//! Cluster 1: [france, germany, uk]    -> Countries
255//! Cluster 2: [euro, dollar, pound]    -> Currencies
256//! ```
257//!
258//! ## Knowledge Graph Organization
259//!
260//! Group related entities for improved navigation and querying.
261//!
262//! ## Recommendation Systems
263//!
264//! Find similar users or items based on learned embeddings.
265//!
266//! ## Anomaly Detection
267//!
268//! Identify outliers using DBSCAN's noise detection (cluster_id == usize::MAX).
269//!
270//! # Performance Considerations
271//!
272//! - **K-Means**: O(n*k*d*i) where n=entities, k=clusters, d=dimensions, i=iterations
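//! - **Hierarchical**: O(n^3 * d) - average linkage is recomputed from all member pairs
//!   at every merge, so this is slow for large datasets
//! - **DBSCAN**: O(n^2 * d) - neighbourhood queries are brute force (no spatial index)
//! - **Spectral**: same as K-Means, which it runs on L2-normalized embeddings
//!
//! Every algorithm also computes a silhouette score, which needs all pairwise distances
//! (at least O(n^2 * d)). For large knowledge graphs (>10,000 entities), K-Means (or the
//! K-Means-based Spectral variant) is recommended.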
278//!
279//! # Choosing the Right Algorithm
280//!
//! | Algorithm    | When to Use                               | Pros                           | Cons                             |
//! |--------------|-------------------------------------------|--------------------------------|----------------------------------|
//! | K-Means      | Known cluster count, spherical clusters   | Fast, scalable                 | Requires K, spherical            |
//! | Hierarchical | Small datasets, compact groups            | No random initialization       | Slow (O(n^3)), requires K        |
//! | DBSCAN       | Arbitrary shapes, noise handling          | Finds outliers, no K needed    | Sensitive to parameters, O(n^2)  |
//! | Spectral     | Direction-based (cosine-like) grouping    | Ignores embedding magnitude    | Requires K, simplified algorithm |
287//!
288//! # See Also
289//!
290//! - [`EntityClustering`]: Main clustering interface
291//! - [`ClusteringConfig`]: Configuration options
292//! - [`ClusteringResult`]: Clustering results and metrics
293//! - [`ClusteringAlgorithm`]: Available algorithms
294
295use anyhow::{anyhow, Result};
296use scirs2_core::ndarray_ext::Array1;
297use scirs2_core::random::Random;
298use serde::{Deserialize, Serialize};
299use std::collections::{HashMap, HashSet};
300use tracing::{debug, info};
301
302/// Clustering algorithm type
303#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
304pub enum ClusteringAlgorithm {
305    /// K-Means clustering
306    KMeans,
307    /// Hierarchical clustering
308    Hierarchical,
309    /// DBSCAN (Density-Based Spatial Clustering)
310    DBSCAN,
311    /// Spectral clustering
312    Spectral,
313}
314
315/// Clustering configuration
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct ClusteringConfig {
318    /// Clustering algorithm to use
319    pub algorithm: ClusteringAlgorithm,
320    /// Number of clusters (for K-Means, Spectral)
321    pub num_clusters: usize,
322    /// Maximum iterations (for iterative algorithms)
323    pub max_iterations: usize,
324    /// Convergence tolerance
325    pub tolerance: f32,
326    /// Random seed for reproducibility
327    pub random_seed: Option<u64>,
328    /// DBSCAN epsilon (neighborhood radius)
329    pub epsilon: f32,
330    /// DBSCAN minimum points
331    pub min_points: usize,
332}
333
334impl Default for ClusteringConfig {
335    fn default() -> Self {
336        Self {
337            algorithm: ClusteringAlgorithm::KMeans,
338            num_clusters: 10,
339            max_iterations: 100,
340            tolerance: 1e-4,
341            random_seed: None,
342            epsilon: 0.5,
343            min_points: 5,
344        }
345    }
346}
347
348/// Clustering result
349#[derive(Debug, Clone, Serialize, Deserialize)]
350pub struct ClusteringResult {
351    /// Cluster assignments for each entity (entity_id -> cluster_id)
352    pub assignments: HashMap<String, usize>,
353    /// Cluster centroids (for K-Means, Spectral)
354    pub centroids: Vec<Array1<f32>>,
355    /// Cluster sizes
356    pub cluster_sizes: Vec<usize>,
357    /// Inertia/objective function value
358    pub inertia: f32,
359    /// Number of iterations performed
360    pub num_iterations: usize,
361    /// Silhouette score (quality metric, -1 to 1, higher is better)
362    pub silhouette_score: f32,
363}
364
365/// Entity clustering for knowledge graph embeddings
366pub struct EntityClustering {
367    config: ClusteringConfig,
368    rng: Random,
369}
370
371impl EntityClustering {
372    /// Create new entity clustering
373    pub fn new(config: ClusteringConfig) -> Self {
374        let rng = Random::default();
375
376        Self { config, rng }
377    }
378
379    /// Cluster entities based on their embeddings
380    pub fn cluster(
381        &mut self,
382        entity_embeddings: &HashMap<String, Array1<f32>>,
383    ) -> Result<ClusteringResult> {
384        if entity_embeddings.is_empty() {
385            return Err(anyhow!("No entity embeddings provided"));
386        }
387
388        info!(
389            "Clustering {} entities using {:?}",
390            entity_embeddings.len(),
391            self.config.algorithm
392        );
393
394        match self.config.algorithm {
395            ClusteringAlgorithm::KMeans => self.kmeans_clustering(entity_embeddings),
396            ClusteringAlgorithm::Hierarchical => self.hierarchical_clustering(entity_embeddings),
397            ClusteringAlgorithm::DBSCAN => self.dbscan_clustering(entity_embeddings),
398            ClusteringAlgorithm::Spectral => self.spectral_clustering(entity_embeddings),
399        }
400    }
401
402    /// K-Means clustering implementation
403    fn kmeans_clustering(
404        &mut self,
405        entity_embeddings: &HashMap<String, Array1<f32>>,
406    ) -> Result<ClusteringResult> {
407        let k = self.config.num_clusters;
408        let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
409        let n = entity_list.len();
410
411        if k > n {
412            return Err(anyhow!("Number of clusters exceeds number of entities"));
413        }
414
415        // Initialize centroids randomly
416        let dim = entity_embeddings.values().next().unwrap().len();
417        let mut centroids: Vec<Array1<f32>> = Vec::new();
418
419        // K-Means++ initialization for better convergence
420        let first_idx = self.rng.random_range(0..n);
421        centroids.push(entity_embeddings[&entity_list[first_idx]].clone());
422
423        for _ in 1..k {
424            // Compute distances to nearest centroid
425            let distances: Vec<f32> = entity_list
426                .iter()
427                .map(|entity| {
428                    let emb = &entity_embeddings[entity];
429                    centroids
430                        .iter()
431                        .map(|c| self.euclidean_distance(emb, c))
432                        .fold(f32::INFINITY, f32::min)
433                        .powi(2)
434                })
435                .collect();
436
437            // Sample proportional to distance squared
438            let sum: f32 = distances.iter().sum();
439            let mut prob = self.rng.random_range(0.0..sum);
440            let mut next_idx = 0;
441
442            for (i, &dist) in distances.iter().enumerate() {
443                prob -= dist;
444                if prob <= 0.0 {
445                    next_idx = i;
446                    break;
447                }
448            }
449
450            centroids.push(entity_embeddings[&entity_list[next_idx]].clone());
451        }
452
453        // Iterative refinement
454        let mut assignments: HashMap<String, usize> = HashMap::new();
455        let mut prev_inertia = f32::INFINITY;
456
457        for iteration in 0..self.config.max_iterations {
458            // Assignment step
459            assignments.clear();
460            for entity in &entity_list {
461                let emb = &entity_embeddings[entity];
462                let cluster = self.nearest_centroid(emb, &centroids);
463                assignments.insert(entity.clone(), cluster);
464            }
465
466            // Update step
467            let mut new_centroids: Vec<Array1<f32>> = vec![Array1::zeros(dim); k];
468            let mut counts = vec![0; k];
469
470            for entity in &entity_list {
471                if let Some(&cluster) = assignments.get(entity) {
472                    new_centroids[cluster] = &new_centroids[cluster] + &entity_embeddings[entity];
473                    counts[cluster] += 1;
474                }
475            }
476
477            for (i, count) in counts.iter().enumerate() {
478                if *count > 0 {
479                    new_centroids[i] = &new_centroids[i] / (*count as f32);
480                }
481            }
482
483            centroids = new_centroids;
484
485            // Compute inertia
486            let inertia =
487                self.compute_inertia(&entity_list, entity_embeddings, &assignments, &centroids);
488
489            debug!("Iteration {}: inertia = {:.6}", iteration + 1, inertia);
490
491            // Check convergence
492            if (prev_inertia - inertia).abs() < self.config.tolerance {
493                info!("K-Means converged at iteration {}", iteration + 1);
494                break;
495            }
496
497            prev_inertia = inertia;
498        }
499
500        let final_inertia =
501            self.compute_inertia(&entity_list, entity_embeddings, &assignments, &centroids);
502        let cluster_sizes = self.compute_cluster_sizes(&assignments, k);
503        let silhouette =
504            self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
505
506        Ok(ClusteringResult {
507            assignments,
508            centroids,
509            cluster_sizes,
510            inertia: final_inertia,
511            num_iterations: self.config.max_iterations,
512            silhouette_score: silhouette,
513        })
514    }
515
516    /// Hierarchical clustering (agglomerative)
517    fn hierarchical_clustering(
518        &mut self,
519        entity_embeddings: &HashMap<String, Array1<f32>>,
520    ) -> Result<ClusteringResult> {
521        let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
522        let n = entity_list.len();
523
524        // Start with each entity in its own cluster
525        let mut clusters: Vec<HashSet<usize>> = (0..n)
526            .map(|i| {
527                let mut set = HashSet::new();
528                set.insert(i);
529                set
530            })
531            .collect();
532
533        // Merge clusters until we reach desired number
534        while clusters.len() > self.config.num_clusters {
535            // Find closest pair of clusters
536            let (i, j) = self.find_closest_clusters(&clusters, &entity_list, entity_embeddings);
537
538            // Merge clusters
539            let cluster_j = clusters.remove(j);
540            clusters[i].extend(cluster_j);
541        }
542
543        // Convert to assignments
544        let mut assignments = HashMap::new();
545        for (cluster_id, cluster) in clusters.iter().enumerate() {
546            for &entity_idx in cluster {
547                assignments.insert(entity_list[entity_idx].clone(), cluster_id);
548            }
549        }
550
551        // Compute centroids
552        let dim = entity_embeddings.values().next().unwrap().len();
553        let mut centroids = vec![Array1::zeros(dim); self.config.num_clusters];
554        let mut counts = vec![0; self.config.num_clusters];
555
556        for (entity, &cluster) in &assignments {
557            centroids[cluster] = &centroids[cluster] + &entity_embeddings[entity];
558            counts[cluster] += 1;
559        }
560
561        for (i, count) in counts.iter().enumerate() {
562            if *count > 0 {
563                centroids[i] = &centroids[i] / (*count as f32);
564            }
565        }
566
567        let inertia =
568            self.compute_inertia(&entity_list, entity_embeddings, &assignments, &centroids);
569        let cluster_sizes = self.compute_cluster_sizes(&assignments, self.config.num_clusters);
570        let silhouette =
571            self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
572
573        Ok(ClusteringResult {
574            assignments,
575            centroids,
576            cluster_sizes,
577            inertia,
578            num_iterations: n - self.config.num_clusters,
579            silhouette_score: silhouette,
580        })
581    }
582
583    /// DBSCAN clustering implementation
584    fn dbscan_clustering(
585        &mut self,
586        entity_embeddings: &HashMap<String, Array1<f32>>,
587    ) -> Result<ClusteringResult> {
588        let entity_list: Vec<String> = entity_embeddings.keys().cloned().collect();
589        let n = entity_list.len();
590
591        let mut assignments: HashMap<String, usize> = HashMap::new();
592        let mut visited = HashSet::new();
593        let mut cluster_id = 0;
594
595        for i in 0..n {
596            let entity = &entity_list[i];
597            if visited.contains(&i) {
598                continue;
599            }
600
601            visited.insert(i);
602
603            // Find neighbors
604            let neighbors = self.find_neighbors(i, &entity_list, entity_embeddings);
605
606            if neighbors.len() < self.config.min_points {
607                // Mark as noise (-1 represented as max usize)
608                assignments.insert(entity.clone(), usize::MAX);
609            } else {
610                // Start new cluster
611                self.expand_cluster(
612                    i,
613                    &neighbors,
614                    cluster_id,
615                    &entity_list,
616                    entity_embeddings,
617                    &mut assignments,
618                    &mut visited,
619                );
620                cluster_id += 1;
621            }
622        }
623
624        // Compute centroids for non-noise clusters
625        let dim = entity_embeddings.values().next().unwrap().len();
626        let mut centroids = vec![Array1::zeros(dim); cluster_id];
627        let mut counts = vec![0; cluster_id];
628
629        for (entity, &cluster) in &assignments {
630            if cluster != usize::MAX {
631                centroids[cluster] = &centroids[cluster] + &entity_embeddings[entity];
632                counts[cluster] += 1;
633            }
634        }
635
636        for (i, count) in counts.iter().enumerate() {
637            if *count > 0 {
638                centroids[i] = &centroids[i] / (*count as f32);
639            }
640        }
641
642        let inertia =
643            self.compute_inertia(&entity_list, entity_embeddings, &assignments, &centroids);
644        let cluster_sizes = self.compute_cluster_sizes(&assignments, cluster_id);
645        let silhouette =
646            self.compute_silhouette_score(&entity_list, entity_embeddings, &assignments);
647
648        Ok(ClusteringResult {
649            assignments,
650            centroids,
651            cluster_sizes,
652            inertia,
653            num_iterations: 1,
654            silhouette_score: silhouette,
655        })
656    }
657
658    /// Spectral clustering (simplified implementation)
659    fn spectral_clustering(
660        &mut self,
661        entity_embeddings: &HashMap<String, Array1<f32>>,
662    ) -> Result<ClusteringResult> {
663        // For simplicity, use K-Means on normalized embeddings
664        // Full spectral clustering requires eigendecomposition of graph Laplacian
665
666        let mut normalized_embeddings = HashMap::new();
667        for (entity, emb) in entity_embeddings {
668            let norm = emb.dot(emb).sqrt();
669            if norm > 0.0 {
670                normalized_embeddings.insert(entity.clone(), emb / norm);
671            } else {
672                normalized_embeddings.insert(entity.clone(), emb.clone());
673            }
674        }
675
676        self.kmeans_clustering(&normalized_embeddings)
677    }
678
679    /// Find nearest centroid for an embedding
680    fn nearest_centroid(&self, embedding: &Array1<f32>, centroids: &[Array1<f32>]) -> usize {
681        centroids
682            .iter()
683            .enumerate()
684            .map(|(i, c)| (i, self.euclidean_distance(embedding, c)))
685            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
686            .map(|(i, _)| i)
687            .unwrap_or(0)
688    }
689
690    /// Compute Euclidean distance
691    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
692        let diff = a - b;
693        diff.dot(&diff).sqrt()
694    }
695
696    /// Compute inertia (sum of squared distances to centroids)
697    fn compute_inertia(
698        &self,
699        entity_list: &[String],
700        embeddings: &HashMap<String, Array1<f32>>,
701        assignments: &HashMap<String, usize>,
702        centroids: &[Array1<f32>],
703    ) -> f32 {
704        entity_list
705            .iter()
706            .filter_map(|entity| {
707                assignments.get(entity).and_then(|&cluster| {
708                    if cluster < centroids.len() {
709                        Some(
710                            self.euclidean_distance(&embeddings[entity], &centroids[cluster])
711                                .powi(2),
712                        )
713                    } else {
714                        None
715                    }
716                })
717            })
718            .sum()
719    }
720
721    /// Compute cluster sizes
722    fn compute_cluster_sizes(
723        &self,
724        assignments: &HashMap<String, usize>,
725        num_clusters: usize,
726    ) -> Vec<usize> {
727        let mut sizes = vec![0; num_clusters];
728        for &cluster in assignments.values() {
729            if cluster < num_clusters {
730                sizes[cluster] += 1;
731            }
732        }
733        sizes
734    }
735
736    /// Compute silhouette score
737    fn compute_silhouette_score(
738        &self,
739        entity_list: &[String],
740        embeddings: &HashMap<String, Array1<f32>>,
741        assignments: &HashMap<String, usize>,
742    ) -> f32 {
743        if entity_list.len() < 2 {
744            return 0.0;
745        }
746
747        let scores: Vec<f32> = entity_list
748            .iter()
749            .filter_map(|entity| {
750                assignments.get(entity).map(|&cluster| {
751                    let emb = &embeddings[entity];
752
753                    // Compute average distance to same cluster (a)
754                    let same_cluster: Vec<f32> = entity_list
755                        .iter()
756                        .filter_map(|other| {
757                            if other != entity && assignments.get(other) == Some(&cluster) {
758                                Some(self.euclidean_distance(emb, &embeddings[other]))
759                            } else {
760                                None
761                            }
762                        })
763                        .collect();
764
765                    let a = if !same_cluster.is_empty() {
766                        same_cluster.iter().sum::<f32>() / same_cluster.len() as f32
767                    } else {
768                        0.0
769                    };
770
771                    // Compute minimum average distance to other clusters (b)
772                    let unique_clusters: HashSet<usize> = assignments.values().copied().collect();
773                    let b = unique_clusters
774                        .iter()
775                        .filter(|&&c| c != cluster)
776                        .map(|&other_cluster| {
777                            let distances: Vec<f32> = entity_list
778                                .iter()
779                                .filter_map(|other| {
780                                    if assignments.get(other) == Some(&other_cluster) {
781                                        Some(self.euclidean_distance(emb, &embeddings[other]))
782                                    } else {
783                                        None
784                                    }
785                                })
786                                .collect();
787
788                            if !distances.is_empty() {
789                                distances.iter().sum::<f32>() / distances.len() as f32
790                            } else {
791                                f32::INFINITY
792                            }
793                        })
794                        .fold(f32::INFINITY, f32::min);
795
796                    (b - a) / a.max(b).max(1e-10)
797                })
798            })
799            .collect();
800
801        if scores.is_empty() {
802            0.0
803        } else {
804            scores.iter().sum::<f32>() / scores.len() as f32
805        }
806    }
807
808    /// Find closest pair of clusters for hierarchical clustering
809    fn find_closest_clusters(
810        &self,
811        clusters: &[HashSet<usize>],
812        entity_list: &[String],
813        embeddings: &HashMap<String, Array1<f32>>,
814    ) -> (usize, usize) {
815        let mut min_dist = f32::INFINITY;
816        let mut closest_pair = (0, 1);
817
818        for i in 0..clusters.len() {
819            for j in (i + 1)..clusters.len() {
820                // Average linkage
821                let mut total_dist = 0.0;
822                let mut count = 0;
823
824                for &idx_i in &clusters[i] {
825                    for &idx_j in &clusters[j] {
826                        let dist = self.euclidean_distance(
827                            &embeddings[&entity_list[idx_i]],
828                            &embeddings[&entity_list[idx_j]],
829                        );
830                        total_dist += dist;
831                        count += 1;
832                    }
833                }
834
835                let avg_dist = if count > 0 {
836                    total_dist / count as f32
837                } else {
838                    f32::INFINITY
839                };
840
841                if avg_dist < min_dist {
842                    min_dist = avg_dist;
843                    closest_pair = (i, j);
844                }
845            }
846        }
847
848        closest_pair
849    }
850
851    /// Find neighbors within epsilon distance for DBSCAN
852    fn find_neighbors(
853        &self,
854        idx: usize,
855        entity_list: &[String],
856        embeddings: &HashMap<String, Array1<f32>>,
857    ) -> Vec<usize> {
858        let entity = &entity_list[idx];
859        let emb = &embeddings[entity];
860
861        entity_list
862            .iter()
863            .enumerate()
864            .filter_map(|(i, other)| {
865                if i != idx
866                    && self.euclidean_distance(emb, &embeddings[other]) <= self.config.epsilon
867                {
868                    Some(i)
869                } else {
870                    None
871                }
872            })
873            .collect()
874    }
875
876    /// Expand cluster for DBSCAN
877    #[allow(clippy::too_many_arguments)]
878    fn expand_cluster(
879        &self,
880        idx: usize,
881        neighbors: &[usize],
882        cluster_id: usize,
883        entity_list: &[String],
884        embeddings: &HashMap<String, Array1<f32>>,
885        assignments: &mut HashMap<String, usize>,
886        visited: &mut HashSet<usize>,
887    ) {
888        assignments.insert(entity_list[idx].clone(), cluster_id);
889
890        let mut queue: Vec<usize> = neighbors.to_vec();
891        let mut processed = 0;
892
893        while processed < queue.len() {
894            let neighbor_idx = queue[processed];
895            processed += 1;
896
897            if !visited.contains(&neighbor_idx) {
898                visited.insert(neighbor_idx);
899
900                let neighbor_neighbors = self.find_neighbors(neighbor_idx, entity_list, embeddings);
901
902                if neighbor_neighbors.len() >= self.config.min_points {
903                    queue.extend(neighbor_neighbors);
904                }
905            }
906
907            if !assignments.contains_key(&entity_list[neighbor_idx]) {
908                assignments.insert(entity_list[neighbor_idx].clone(), cluster_id);
909            }
910        }
911    }
912}
913
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Two well-separated pairs: {e1, e2} near (1, 1) and {e3, e4} near (5, 5).
    fn two_groups() -> HashMap<String, Array1<f32>> {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 1.0]);
        embeddings.insert("e2".to_string(), array![1.1, 0.9]);
        embeddings.insert("e3".to_string(), array![5.0, 5.0]);
        embeddings.insert("e4".to_string(), array![5.1, 4.9]);
        embeddings
    }

    #[test]
    fn test_kmeans_clustering() {
        let embeddings = two_groups();

        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: 2,
            ..Default::default()
        };

        let mut clustering = EntityClustering::new(config);
        let result = clustering.cluster(&embeddings).unwrap();

        assert_eq!(result.assignments.len(), 4);
        assert_eq!(result.centroids.len(), 2);
        assert_eq!(result.cluster_sizes.len(), 2);

        // Check that similar entities are in the same cluster
        assert_eq!(result.assignments["e1"], result.assignments["e2"]);
        assert_eq!(result.assignments["e3"], result.assignments["e4"]);
    }

    #[test]
    fn test_silhouette_score() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![0.0, 0.0]);
        embeddings.insert("e2".to_string(), array![1.0, 1.0]);
        embeddings.insert("e3".to_string(), array![5.0, 5.0]);

        let config = ClusteringConfig {
            num_clusters: 2,
            ..Default::default()
        };

        let mut clustering = EntityClustering::new(config);
        let result = clustering.cluster(&embeddings).unwrap();

        // Silhouette score should be between -1 and 1
        assert!(result.silhouette_score >= -1.0 && result.silhouette_score <= 1.0);
    }

    #[test]
    fn test_hierarchical_clustering() {
        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::Hierarchical,
            num_clusters: 2,
            ..Default::default()
        };

        let result = EntityClustering::new(config)
            .cluster(&two_groups())
            .unwrap();

        assert_eq!(result.assignments["e1"], result.assignments["e2"]);
        assert_eq!(result.assignments["e3"], result.assignments["e4"]);
        assert_ne!(result.assignments["e1"], result.assignments["e3"]);
        assert_eq!(result.cluster_sizes, vec![2, 2]);
        assert_eq!(result.centroids.len(), 2);
    }

    #[test]
    fn test_dbscan_marks_outlier_as_noise() {
        let mut embeddings = two_groups();
        embeddings.insert("outlier".to_string(), array![20.0, 20.0]);

        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::DBSCAN,
            epsilon: 0.5,
            min_points: 1,
            ..Default::default()
        };

        let result = EntityClustering::new(config).cluster(&embeddings).unwrap();

        assert_eq!(result.assignments["e1"], result.assignments["e2"]);
        assert_eq!(result.assignments["e3"], result.assignments["e4"]);
        assert_ne!(result.assignments["e1"], result.assignments["e3"]);
        assert_eq!(result.assignments["outlier"], usize::MAX);
        assert_eq!(result.cluster_sizes, vec![2, 2]);
    }

    #[test]
    fn test_empty_input_is_rejected() {
        let mut clustering = EntityClustering::new(ClusteringConfig::default());
        assert!(clustering.cluster(&HashMap::new()).is_err());
    }

    #[test]
    fn test_too_many_clusters_is_rejected() {
        let config = ClusteringConfig {
            algorithm: ClusteringAlgorithm::KMeans,
            num_clusters: 5,
            ..Default::default()
        };

        assert!(EntityClustering::new(config).cluster(&two_groups()).is_err());
    }
}