oxirs_embed/
community_detection.rs

1//! Community Detection for Knowledge Graphs
2//!
3//! This module provides community detection algorithms for identifying densely
4//! connected groups of entities in knowledge graphs. Communities represent
5//! semantic groups that can improve understanding and navigation of large graphs.
6
7use anyhow::{anyhow, Result};
8use scirs2_core::ndarray_ext::Array1;
9use scirs2_core::random::Random;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12use tracing::{debug, info};
13
14use crate::Triple;
15
16/// Community detection algorithm
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18pub enum CommunityAlgorithm {
19    /// Louvain modularity-based algorithm
20    Louvain,
21    /// Label propagation algorithm
22    LabelPropagation,
23    /// Girvan-Newman edge betweenness
24    GirvanNewman,
25    /// Embedding-based communities (using embeddings similarity)
26    EmbeddingBased,
27}
28
29/// Community detection configuration
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct CommunityConfig {
32    /// Algorithm to use
33    pub algorithm: CommunityAlgorithm,
34    /// Maximum iterations (for iterative algorithms)
35    pub max_iterations: usize,
36    /// Resolution parameter for Louvain
37    pub resolution: f32,
38    /// Minimum community size
39    pub min_community_size: usize,
40    /// Similarity threshold for embedding-based detection
41    pub similarity_threshold: f32,
42    /// Random seed
43    pub random_seed: Option<u64>,
44}
45
46impl Default for CommunityConfig {
47    fn default() -> Self {
48        Self {
49            algorithm: CommunityAlgorithm::Louvain,
50            max_iterations: 100,
51            resolution: 1.0,
52            min_community_size: 2,
53            similarity_threshold: 0.7,
54            random_seed: None,
55        }
56    }
57}
58
59/// Community detection result
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct CommunityResult {
62    /// Community assignments (entity_id -> community_id)
63    pub assignments: HashMap<String, usize>,
64    /// Number of communities found
65    pub num_communities: usize,
66    /// Community sizes
67    pub community_sizes: Vec<usize>,
68    /// Modularity score (quality metric)
69    pub modularity: f32,
70    /// Coverage (fraction of edges within communities)
71    pub coverage: f32,
72    /// Community members (community_id -> set of entity_ids)
73    pub communities: HashMap<usize, HashSet<String>>,
74}
75
/// Undirected graph used internally by the community detection algorithms.
///
/// Adjacency is stored as a set per node, so parallel edges between the same
/// pair of nodes collapse into one.
struct Graph {
    /// Adjacency sets: node -> neighbouring nodes (both directions stored).
    edges: HashMap<String, HashSet<String>>,
    /// Edge weights keyed by `(from, to)`; every edge is stored under both
    /// orientations.
    weights: HashMap<(String, String), f32>,
    /// Number of distinct undirected edges.
    num_edges: usize,
}

impl Graph {
    /// Create an empty graph.
    fn new() -> Self {
        Self {
            edges: HashMap::new(),
            weights: HashMap::new(),
            num_edges: 0,
        }
    }

    /// Insert the undirected edge `from -- to` with the given weight.
    ///
    /// Adding an edge that already exists (in either orientation) overwrites
    /// its weight but does not increase `num_edges`, which keeps the edge
    /// count consistent with the set-based adjacency used for degrees.
    fn add_edge(&mut self, from: &str, to: &str, weight: f32) {
        // `HashSet::insert` reports whether the value was newly added; the
        // forward direction alone tells us whether this edge is new because
        // both directions are always inserted together.
        let is_new_edge = self
            .edges
            .entry(from.to_string())
            .or_default()
            .insert(to.to_string());

        self.edges
            .entry(to.to_string())
            .or_default()
            .insert(from.to_string());

        self.weights
            .insert((from.to_string(), to.to_string()), weight);
        self.weights
            .insert((to.to_string(), from.to_string()), weight);

        if is_new_edge {
            self.num_edges += 1;
        }
    }

    /// Neighbours of `node`, or `None` when the node is unknown.
    fn get_neighbors(&self, node: &str) -> Option<&HashSet<String>> {
        self.edges.get(node)
    }

    /// Weight of the edge `from -- to`; falls back to `1.0` when no weight
    /// has been recorded for that pair.
    fn get_weight(&self, from: &str, to: &str) -> f32 {
        self.weights
            .get(&(from.to_string(), to.to_string()))
            .copied()
            .unwrap_or(1.0)
    }

    /// Number of distinct neighbours of `node` (0 for unknown nodes).
    fn degree(&self, node: &str) -> usize {
        self.edges.get(node).map_or(0, HashSet::len)
    }

    /// All node identifiers, in arbitrary order.
    fn nodes(&self) -> Vec<String> {
        self.edges.keys().cloned().collect()
    }
}
133
134/// Community detector
135pub struct CommunityDetector {
136    config: CommunityConfig,
137    rng: Random,
138}
139
140impl CommunityDetector {
141    /// Create new community detector
142    pub fn new(config: CommunityConfig) -> Self {
143        let rng = Random::default();
144
145        Self { config, rng }
146    }
147
148    /// Detect communities from knowledge graph triples
149    pub fn detect_from_triples(&mut self, triples: &[Triple]) -> Result<CommunityResult> {
150        // Build graph from triples
151        let mut graph = Graph::new();
152
153        for triple in triples {
154            // Add edge from subject to object (undirected)
155            graph.add_edge(&triple.subject.to_string(), &triple.object.to_string(), 1.0);
156        }
157
158        info!(
159            "Detecting communities in graph with {} nodes and {} edges using {:?}",
160            graph.nodes().len(),
161            graph.num_edges,
162            self.config.algorithm
163        );
164
165        self.detect_from_graph(&graph)
166    }
167
168    /// Detect communities from entity embeddings
169    pub fn detect_from_embeddings(
170        &mut self,
171        embeddings: &HashMap<String, Array1<f32>>,
172    ) -> Result<CommunityResult> {
173        info!("Detecting communities from {} embeddings", embeddings.len());
174
175        match self.config.algorithm {
176            CommunityAlgorithm::EmbeddingBased => self.embedding_based_detection(embeddings),
177            _ => {
178                // Build similarity graph
179                let graph = self.build_similarity_graph(embeddings);
180                self.detect_from_graph(&graph)
181            }
182        }
183    }
184
185    /// Detect communities from graph structure
186    fn detect_from_graph(&mut self, graph: &Graph) -> Result<CommunityResult> {
187        match self.config.algorithm {
188            CommunityAlgorithm::Louvain => self.louvain_detection(graph),
189            CommunityAlgorithm::LabelPropagation => self.label_propagation(graph),
190            CommunityAlgorithm::GirvanNewman => self.girvan_newman(graph),
191            CommunityAlgorithm::EmbeddingBased => {
192                Err(anyhow!("Embedding-based detection requires embeddings"))
193            }
194        }
195    }
196
197    /// Louvain modularity optimization
198    fn louvain_detection(&mut self, graph: &Graph) -> Result<CommunityResult> {
199        let nodes = graph.nodes();
200        let m = graph.num_edges as f32;
201
202        // Initialize: each node in its own community
203        let mut community: HashMap<String, usize> = nodes
204            .iter()
205            .enumerate()
206            .map(|(i, node)| (node.clone(), i))
207            .collect();
208
209        let mut improved = true;
210        let mut iteration = 0;
211
212        while improved && iteration < self.config.max_iterations {
213            improved = false;
214            iteration += 1;
215
216            // For each node, try moving to neighbor's community
217            for node in &nodes {
218                let current_comm = community[node];
219                let best_comm = self.find_best_community(node, current_comm, &community, graph, m);
220
221                if best_comm != current_comm {
222                    community.insert(node.clone(), best_comm);
223                    improved = true;
224                }
225            }
226
227            debug!("Louvain iteration {}: improved = {}", iteration, improved);
228        }
229
230        self.create_result(&community, graph)
231    }
232
233    /// Find best community for a node (max modularity gain)
234    fn find_best_community(
235        &self,
236        node: &str,
237        current_comm: usize,
238        community: &HashMap<String, usize>,
239        graph: &Graph,
240        m: f32,
241    ) -> usize {
242        let neighbors = match graph.get_neighbors(node) {
243            Some(n) => n,
244            None => return current_comm,
245        };
246
247        // Get neighboring communities
248        let mut neighbor_comms: HashSet<usize> = HashSet::new();
249        for neighbor in neighbors {
250            if let Some(&comm) = community.get(neighbor) {
251                neighbor_comms.insert(comm);
252            }
253        }
254
255        // Compute modularity gain for each community
256        let current_modularity =
257            self.compute_modularity_contribution(node, current_comm, community, graph, m);
258
259        let mut best_comm = current_comm;
260        let mut best_modularity = current_modularity;
261
262        for &comm in &neighbor_comms {
263            if comm == current_comm {
264                continue;
265            }
266
267            let modularity = self.compute_modularity_contribution(node, comm, community, graph, m);
268
269            if modularity > best_modularity {
270                best_modularity = modularity;
271                best_comm = comm;
272            }
273        }
274
275        best_comm
276    }
277
278    /// Compute modularity contribution for a node in a community
279    fn compute_modularity_contribution(
280        &self,
281        node: &str,
282        comm: usize,
283        community: &HashMap<String, usize>,
284        graph: &Graph,
285        m: f32,
286    ) -> f32 {
287        let neighbors = match graph.get_neighbors(node) {
288            Some(n) => n,
289            None => return 0.0,
290        };
291
292        let k_i = graph.degree(node) as f32;
293
294        // Sum of weights to nodes in community
295        let mut e_ic = 0.0;
296        let mut k_c = 0.0;
297
298        for neighbor in neighbors {
299            if let Some(&neighbor_comm) = community.get(neighbor) {
300                if neighbor_comm == comm {
301                    e_ic += graph.get_weight(node, neighbor);
302                    k_c += graph.degree(neighbor) as f32;
303                }
304            }
305        }
306
307        (e_ic - (self.config.resolution * k_i * k_c) / (2.0 * m)) / m
308    }
309
310    /// Label propagation algorithm
311    fn label_propagation(&mut self, graph: &Graph) -> Result<CommunityResult> {
312        let nodes = graph.nodes();
313
314        // Initialize: each node with unique label
315        let mut labels: HashMap<String, usize> = nodes
316            .iter()
317            .enumerate()
318            .map(|(i, node)| (node.clone(), i))
319            .collect();
320
321        for iteration in 0..self.config.max_iterations {
322            let mut changed = false;
323
324            // Randomize node order
325            let mut node_order = nodes.clone();
326            for i in (1..node_order.len()).rev() {
327                let j = self.rng.random_range(0..i + 1);
328                node_order.swap(i, j);
329            }
330
331            // Update labels
332            for node in &node_order {
333                let old_label = labels[node];
334                let new_label = self.majority_label(node, &labels, graph);
335
336                if new_label != old_label {
337                    labels.insert(node.clone(), new_label);
338                    changed = true;
339                }
340            }
341
342            debug!(
343                "Label propagation iteration {}: changed = {}",
344                iteration + 1,
345                changed
346            );
347
348            if !changed {
349                info!("Label propagation converged at iteration {}", iteration + 1);
350                break;
351            }
352        }
353
354        self.create_result(&labels, graph)
355    }
356
357    /// Get majority label from neighbors
358    fn majority_label(&self, node: &str, labels: &HashMap<String, usize>, graph: &Graph) -> usize {
359        let neighbors = match graph.get_neighbors(node) {
360            Some(n) => n,
361            None => return labels[node],
362        };
363
364        let mut label_counts: HashMap<usize, usize> = HashMap::new();
365
366        for neighbor in neighbors {
367            if let Some(&label) = labels.get(neighbor) {
368                *label_counts.entry(label).or_insert(0) += 1;
369            }
370        }
371
372        // Return most frequent label
373        label_counts
374            .into_iter()
375            .max_by_key(|(_, count)| *count)
376            .map(|(label, _)| label)
377            .unwrap_or_else(|| labels[node])
378    }
379
380    /// Girvan-Newman edge betweenness clustering
381    fn girvan_newman(&mut self, graph: &Graph) -> Result<CommunityResult> {
382        // Simplified implementation: iteratively remove highest betweenness edges
383        // Full implementation would require connected components analysis
384
385        let nodes = graph.nodes();
386        let mut assignments: HashMap<String, usize> = HashMap::new();
387
388        // Use BFS to identify connected components
389        let mut visited = HashSet::new();
390        let mut community_id = 0;
391
392        for node in &nodes {
393            if visited.contains(node) {
394                continue;
395            }
396
397            // BFS from this node
398            let component = self.bfs_component(node, graph, &visited);
399
400            for comp_node in &component {
401                assignments.insert(comp_node.clone(), community_id);
402                visited.insert(comp_node.clone());
403            }
404
405            community_id += 1;
406        }
407
408        self.create_result(&assignments, graph)
409    }
410
411    /// BFS to find connected component
412    fn bfs_component(
413        &self,
414        start: &str,
415        graph: &Graph,
416        visited: &HashSet<String>,
417    ) -> HashSet<String> {
418        let mut component = HashSet::new();
419        let mut queue = VecDeque::new();
420        queue.push_back(start.to_string());
421        component.insert(start.to_string());
422
423        while let Some(node) = queue.pop_front() {
424            if let Some(neighbors) = graph.get_neighbors(&node) {
425                for neighbor in neighbors {
426                    if !visited.contains(neighbor) && !component.contains(neighbor) {
427                        component.insert(neighbor.clone());
428                        queue.push_back(neighbor.clone());
429                    }
430                }
431            }
432        }
433
434        component
435    }
436
437    /// Embedding-based community detection
438    fn embedding_based_detection(
439        &mut self,
440        embeddings: &HashMap<String, Array1<f32>>,
441    ) -> Result<CommunityResult> {
442        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
443        let mut assignments: HashMap<String, usize> = HashMap::new();
444        let mut community_id = 0;
445
446        let mut unassigned: HashSet<String> = entity_list.iter().cloned().collect();
447
448        while !unassigned.is_empty() {
449            // Pick random seed
450            let seed = unassigned.iter().next().unwrap().clone();
451            let mut community = HashSet::new();
452            community.insert(seed.clone());
453            unassigned.remove(&seed);
454
455            // Grow community by similarity
456            let mut changed = true;
457            while changed {
458                changed = false;
459
460                for entity in &entity_list {
461                    if community.contains(entity) || !unassigned.contains(entity) {
462                        continue;
463                    }
464
465                    // Check similarity to community members
466                    let avg_similarity =
467                        self.average_similarity_to_community(entity, &community, embeddings);
468
469                    if avg_similarity >= self.config.similarity_threshold {
470                        community.insert(entity.clone());
471                        unassigned.remove(entity);
472                        changed = true;
473                    }
474                }
475            }
476
477            // Assign community
478            if community.len() >= self.config.min_community_size {
479                for member in community {
480                    assignments.insert(member, community_id);
481                }
482                community_id += 1;
483            } else {
484                // Assign to noise/outlier community
485                for member in community {
486                    assignments.insert(member, usize::MAX);
487                }
488            }
489        }
490
491        // Build dummy graph for result creation
492        let mut graph = Graph::new();
493        for entity in &entity_list {
494            graph.edges.insert(entity.clone(), HashSet::new());
495        }
496
497        self.create_result(&assignments, &graph)
498    }
499
500    /// Compute average similarity to community members
501    fn average_similarity_to_community(
502        &self,
503        entity: &str,
504        community: &HashSet<String>,
505        embeddings: &HashMap<String, Array1<f32>>,
506    ) -> f32 {
507        if community.is_empty() {
508            return 0.0;
509        }
510
511        let entity_emb = &embeddings[entity];
512
513        let total_sim: f32 = community
514            .iter()
515            .map(|member| {
516                let member_emb = &embeddings[member];
517                self.cosine_similarity(entity_emb, member_emb)
518            })
519            .sum();
520
521        total_sim / community.len() as f32
522    }
523
524    /// Cosine similarity
525    fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
526        let dot = a.dot(b);
527        let norm_a = a.dot(a).sqrt();
528        let norm_b = b.dot(b).sqrt();
529
530        if norm_a == 0.0 || norm_b == 0.0 {
531            0.0
532        } else {
533            dot / (norm_a * norm_b)
534        }
535    }
536
537    /// Build similarity graph from embeddings
538    fn build_similarity_graph(&self, embeddings: &HashMap<String, Array1<f32>>) -> Graph {
539        let mut graph = Graph::new();
540        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
541
542        for i in 0..entity_list.len() {
543            for j in (i + 1)..entity_list.len() {
544                let sim = self
545                    .cosine_similarity(&embeddings[&entity_list[i]], &embeddings[&entity_list[j]]);
546
547                if sim >= self.config.similarity_threshold {
548                    graph.add_edge(&entity_list[i], &entity_list[j], sim);
549                }
550            }
551        }
552
553        graph
554    }
555
556    /// Create community result from assignments
557    fn create_result(
558        &self,
559        assignments: &HashMap<String, usize>,
560        graph: &Graph,
561    ) -> Result<CommunityResult> {
562        // Compute community sizes
563        let mut community_sizes: HashMap<usize, usize> = HashMap::new();
564        let mut communities: HashMap<usize, HashSet<String>> = HashMap::new();
565
566        for (entity, &comm) in assignments {
567            if comm != usize::MAX {
568                *community_sizes.entry(comm).or_insert(0) += 1;
569                communities.entry(comm).or_default().insert(entity.clone());
570            }
571        }
572
573        let num_communities = community_sizes.len();
574        let sizes: Vec<usize> = (0..num_communities)
575            .map(|i| community_sizes.get(&i).copied().unwrap_or(0))
576            .collect();
577
578        // Compute modularity
579        let modularity = self.compute_modularity(assignments, graph);
580
581        // Compute coverage
582        let coverage = self.compute_coverage(assignments, graph);
583
584        Ok(CommunityResult {
585            assignments: assignments.clone(),
586            num_communities,
587            community_sizes: sizes,
588            modularity,
589            coverage,
590            communities,
591        })
592    }
593
594    /// Compute overall modularity
595    fn compute_modularity(&self, assignments: &HashMap<String, usize>, graph: &Graph) -> f32 {
596        let m = graph.num_edges as f32;
597        if m == 0.0 {
598            return 0.0;
599        }
600
601        let nodes = graph.nodes();
602        let mut modularity = 0.0;
603
604        for node_i in &nodes {
605            for node_j in &nodes {
606                if let (Some(&comm_i), Some(&comm_j)) =
607                    (assignments.get(node_i), assignments.get(node_j))
608                {
609                    if comm_i == comm_j && comm_i != usize::MAX {
610                        let a_ij = if graph
611                            .get_neighbors(node_i)
612                            .map(|n| n.contains(node_j))
613                            .unwrap_or(false)
614                        {
615                            1.0
616                        } else {
617                            0.0
618                        };
619
620                        let k_i = graph.degree(node_i) as f32;
621                        let k_j = graph.degree(node_j) as f32;
622
623                        modularity += a_ij - (k_i * k_j) / (2.0 * m);
624                    }
625                }
626            }
627        }
628
629        modularity / (2.0 * m)
630    }
631
632    /// Compute coverage (fraction of edges within communities)
633    fn compute_coverage(&self, assignments: &HashMap<String, usize>, graph: &Graph) -> f32 {
634        if graph.num_edges == 0 {
635            return 0.0;
636        }
637
638        let mut internal_edges = 0;
639
640        for (node, neighbors) in &graph.edges {
641            if let Some(&comm) = assignments.get(node) {
642                if comm == usize::MAX {
643                    continue;
644                }
645
646                for neighbor in neighbors {
647                    if let Some(&neighbor_comm) = assignments.get(neighbor) {
648                        if comm == neighbor_comm {
649                            internal_edges += 1;
650                        }
651                    }
652                }
653            }
654        }
655
656        // Each edge is counted twice
657        (internal_edges / 2) as f32 / graph.num_edges as f32
658    }
659}
660
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;
    use scirs2_core::ndarray_ext::array;

    /// Builds the triple `subject -r-> object` from bare identifiers.
    fn linked(subject: &str, object: &str) -> Triple {
        Triple::new(
            NamedNode::new(subject).unwrap(),
            NamedNode::new("r").unwrap(),
            NamedNode::new(object).unwrap(),
        )
    }

    /// Two disconnected groups (a-b-c and d-e) must yield at least one
    /// community and an assignment for each of the five entities.
    #[test]
    fn test_community_detection_from_triples() {
        let triples = vec![linked("a", "b"), linked("b", "c"), linked("d", "e")];

        let mut detector = CommunityDetector::new(CommunityConfig::default());
        let result = detector.detect_from_triples(&triples).unwrap();

        assert!(result.num_communities > 0);
        assert_eq!(result.assignments.len(), 5); // a, b, c, d, e
    }

    /// Nearly parallel embeddings must end up in the same community.
    #[test]
    fn test_embedding_based_detection() {
        let mut embeddings = HashMap::new();
        let samples = vec![
            ("e1", array![1.0, 0.0]),
            ("e2", array![0.9, 0.1]),
            ("e3", array![0.0, 1.0]),
            ("e4", array![0.1, 0.9]),
        ];
        for (name, vector) in samples {
            embeddings.insert(name.to_string(), vector);
        }

        let config = CommunityConfig {
            algorithm: CommunityAlgorithm::EmbeddingBased,
            similarity_threshold: 0.8,
            ..Default::default()
        };

        let mut detector = CommunityDetector::new(config);
        let result = detector.detect_from_embeddings(&embeddings).unwrap();

        assert!(result.num_communities >= 1);
        // Similar embeddings should be in same community
        assert_eq!(result.assignments.get("e1"), result.assignments.get("e2"));
    }
}
716}