Skip to main content

oxirs_graphrag/graph/
community.rs

1//! Community detection for hierarchical summarization
2
3use crate::{CommunitySummary, GraphRAGError, GraphRAGResult, Triple};
4use petgraph::graph::{NodeIndex, UnGraph};
5use scirs2_core::random::{seeded_rng, Random};
6use std::collections::{HashMap, HashSet};
7
/// Algorithm used by [`CommunityDetector`] to partition the entity graph.
///
/// The default variant is [`CommunityAlgorithm::Leiden`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum CommunityAlgorithm {
    /// Louvain: greedy local moving of nodes between communities to raise modularity.
    Louvain,
    /// Leiden: Louvain-style local moving followed by a refinement step (the default).
    #[default]
    Leiden,
    /// Label propagation: nodes repeatedly adopt the most common label among their neighbours.
    LabelPropagation,
    /// Connected components of the entity graph.
    ConnectedComponents,
    /// Multi-level detection that repeatedly coarsens the graph into community supernodes.
    Hierarchical,
}
23
24/// Community detector configuration
25#[derive(Debug, Clone)]
26pub struct CommunityConfig {
27    /// Algorithm to use
28    pub algorithm: CommunityAlgorithm,
29    /// Resolution parameter for Louvain/Leiden
30    pub resolution: f64,
31    /// Minimum community size
32    pub min_community_size: usize,
33    /// Maximum number of communities
34    pub max_communities: usize,
35    /// Number of iterations for iterative algorithms
36    pub max_iterations: usize,
37    /// Random seed for reproducibility
38    pub random_seed: u64,
39}
40
41impl Default for CommunityConfig {
42    fn default() -> Self {
43        Self {
44            algorithm: CommunityAlgorithm::Leiden,
45            resolution: 1.0,
46            min_community_size: 3,
47            max_communities: 50,
48            max_iterations: 10,
49            random_seed: 42,
50        }
51    }
52}
53
54/// Community detector
55pub struct CommunityDetector {
56    config: CommunityConfig,
57}
58
59impl Default for CommunityDetector {
60    fn default() -> Self {
61        Self::new(CommunityConfig::default())
62    }
63}
64
65impl CommunityDetector {
66    pub fn new(config: CommunityConfig) -> Self {
67        Self { config }
68    }
69
70    /// Detect communities in the given subgraph
71    pub fn detect(&self, triples: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
72        if triples.is_empty() {
73            return Ok(vec![]);
74        }
75
76        // Build graph
77        let (graph, node_map) = self.build_graph(triples);
78
79        // Detect communities based on algorithm
80        let communities = match self.config.algorithm {
81            CommunityAlgorithm::Louvain => self.louvain(&graph, &node_map),
82            CommunityAlgorithm::Leiden => self.leiden(&graph, &node_map)?,
83            CommunityAlgorithm::LabelPropagation => self.label_propagation(&graph, &node_map),
84            CommunityAlgorithm::ConnectedComponents => self.connected_components(&graph, &node_map),
85            CommunityAlgorithm::Hierarchical => {
86                return self.detect_hierarchical(triples);
87            }
88        };
89
90        // Filter and create summaries
91        let summaries = self.create_summaries(communities, triples);
92
93        Ok(summaries)
94    }
95
96    /// Build undirected graph from triples
97    fn build_graph(&self, triples: &[Triple]) -> (UnGraph<String, ()>, HashMap<String, NodeIndex>) {
98        let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
99        let mut node_map: HashMap<String, NodeIndex> = HashMap::new();
100
101        for triple in triples {
102            let subj_idx = *node_map
103                .entry(triple.subject.clone())
104                .or_insert_with(|| graph.add_node(triple.subject.clone()));
105            let obj_idx = *node_map
106                .entry(triple.object.clone())
107                .or_insert_with(|| graph.add_node(triple.object.clone()));
108
109            if subj_idx != obj_idx && graph.find_edge(subj_idx, obj_idx).is_none() {
110                graph.add_edge(subj_idx, obj_idx, ());
111            }
112        }
113
114        (graph, node_map)
115    }
116
117    /// Simplified Louvain algorithm
118    fn louvain(
119        &self,
120        graph: &UnGraph<String, ()>,
121        node_map: &HashMap<String, NodeIndex>,
122    ) -> Vec<HashSet<String>> {
123        let node_count = graph.node_count();
124        if node_count == 0 {
125            return vec![];
126        }
127
128        // Initialize: each node in its own community
129        let mut community: HashMap<NodeIndex, usize> = HashMap::new();
130        for (community_id, &idx) in node_map.values().enumerate() {
131            community.insert(idx, community_id);
132        }
133
134        // Total edges (for modularity calculation)
135        let m = graph.edge_count() as f64;
136        if m == 0.0 {
137            // No edges, each node is its own community
138            return node_map
139                .keys()
140                .map(|k| {
141                    let mut set = HashSet::new();
142                    set.insert(k.clone());
143                    set
144                })
145                .collect();
146        }
147
148        // Degree of each node
149        let degree: HashMap<NodeIndex, f64> = node_map
150            .values()
151            .map(|&idx| (idx, graph.neighbors(idx).count() as f64))
152            .collect();
153
154        // Iterate
155        for _ in 0..self.config.max_iterations {
156            let mut changed = false;
157
158            for (&node, &current_comm) in community.clone().iter() {
159                let node_degree = degree.get(&node).copied().unwrap_or(0.0);
160
161                // Calculate modularity gain for each neighbor's community
162                let mut best_comm = current_comm;
163                let mut best_gain = 0.0;
164
165                let neighbor_comms: HashSet<usize> = graph
166                    .neighbors(node)
167                    .filter_map(|n| community.get(&n).copied())
168                    .collect();
169
170                for &neighbor_comm in &neighbor_comms {
171                    if neighbor_comm == current_comm {
172                        continue;
173                    }
174
175                    // Simplified modularity gain calculation
176                    let edges_to_comm: f64 = graph
177                        .neighbors(node)
178                        .filter(|n| community.get(n) == Some(&neighbor_comm))
179                        .count() as f64;
180
181                    let comm_degree: f64 = community
182                        .iter()
183                        .filter(|(_, &c)| c == neighbor_comm)
184                        .map(|(n, _)| degree.get(n).copied().unwrap_or(0.0))
185                        .sum();
186
187                    let gain = edges_to_comm / m
188                        - self.config.resolution * node_degree * comm_degree / (2.0 * m * m);
189
190                    if gain > best_gain {
191                        best_gain = gain;
192                        best_comm = neighbor_comm;
193                    }
194                }
195
196                if best_comm != current_comm && best_gain > 0.0 {
197                    community.insert(node, best_comm);
198                    changed = true;
199                }
200            }
201
202            if !changed {
203                break;
204            }
205        }
206
207        // Group nodes by community
208        self.group_by_community(graph, &community)
209    }
210
211    /// Label propagation algorithm
212    fn label_propagation(
213        &self,
214        graph: &UnGraph<String, ()>,
215        node_map: &HashMap<String, NodeIndex>,
216    ) -> Vec<HashSet<String>> {
217        if graph.node_count() == 0 {
218            return vec![];
219        }
220
221        // Initialize labels
222        let mut labels: HashMap<NodeIndex, usize> = HashMap::new();
223        for (i, &idx) in node_map.values().enumerate() {
224            labels.insert(idx, i);
225        }
226
227        // Iterate
228        for _ in 0..self.config.max_iterations {
229            let mut changed = false;
230
231            for &node in node_map.values() {
232                // Count neighbor labels
233                let mut label_counts: HashMap<usize, usize> = HashMap::new();
234                for neighbor in graph.neighbors(node) {
235                    if let Some(&label) = labels.get(&neighbor) {
236                        *label_counts.entry(label).or_insert(0) += 1;
237                    }
238                }
239
240                // Assign most common label
241                if let Some((&best_label, _)) = label_counts.iter().max_by_key(|(_, &count)| count)
242                {
243                    if labels.get(&node) != Some(&best_label) {
244                        labels.insert(node, best_label);
245                        changed = true;
246                    }
247                }
248            }
249
250            if !changed {
251                break;
252            }
253        }
254
255        self.group_by_community(graph, &labels)
256    }
257
258    /// Connected components
259    fn connected_components(
260        &self,
261        graph: &UnGraph<String, ()>,
262        _node_map: &HashMap<String, NodeIndex>,
263    ) -> Vec<HashSet<String>> {
264        let sccs = petgraph::algo::kosaraju_scc(graph);
265
266        sccs.into_iter()
267            .map(|component| {
268                component
269                    .into_iter()
270                    .filter_map(|idx| graph.node_weight(idx).cloned())
271                    .collect()
272            })
273            .collect()
274    }
275
276    /// Leiden algorithm (improved Louvain with refinement phase)
277    fn leiden(
278        &self,
279        graph: &UnGraph<String, ()>,
280        node_map: &HashMap<String, NodeIndex>,
281    ) -> GraphRAGResult<Vec<HashSet<String>>> {
282        let node_count = graph.node_count();
283        if node_count == 0 {
284            return Ok(vec![]);
285        }
286
287        // Initialize: each node in its own community
288        let mut community: HashMap<NodeIndex, usize> = HashMap::new();
289        for (community_id, &idx) in node_map.values().enumerate() {
290            community.insert(idx, community_id);
291        }
292
293        let m = graph.edge_count() as f64;
294        if m == 0.0 {
295            return Ok(node_map
296                .keys()
297                .map(|k| {
298                    let mut set = HashSet::new();
299                    set.insert(k.clone());
300                    set
301                })
302                .collect());
303        }
304
305        let degree: HashMap<NodeIndex, f64> = node_map
306            .values()
307            .map(|&idx| (idx, graph.neighbors(idx).count() as f64))
308            .collect();
309
310        let mut rng = seeded_rng(self.config.random_seed);
311        let mut best_modularity = self.calculate_modularity(graph, &community, m, &degree)?;
312
313        // Main Leiden loop
314        for iteration in 0..self.config.max_iterations {
315            let mut changed = false;
316
317            // Phase 1: Local moving (like Louvain)
318            let mut node_order: Vec<NodeIndex> = node_map.values().copied().collect();
319            // Shuffle for randomness
320            for i in (1..node_order.len()).rev() {
321                let j = (rng.random_range(0.0..1.0) * (i + 1) as f64) as usize;
322                node_order.swap(i, j);
323            }
324
325            for &node in &node_order {
326                let current_comm = match community.get(&node) {
327                    Some(&c) => c,
328                    None => continue,
329                };
330                let node_degree = degree.get(&node).copied().unwrap_or(0.0);
331
332                let mut best_comm = current_comm;
333                let mut best_gain = 0.0;
334
335                // Get neighbor communities
336                let neighbor_comms: HashSet<usize> = graph
337                    .neighbors(node)
338                    .filter_map(|n| community.get(&n).copied())
339                    .collect();
340
341                for &neighbor_comm in &neighbor_comms {
342                    if neighbor_comm == current_comm {
343                        continue;
344                    }
345
346                    let edges_to_comm: f64 = graph
347                        .neighbors(node)
348                        .filter(|n| community.get(n) == Some(&neighbor_comm))
349                        .count() as f64;
350
351                    let comm_degree: f64 = community
352                        .iter()
353                        .filter(|(_, &c)| c == neighbor_comm)
354                        .map(|(n, _)| degree.get(n).copied().unwrap_or(0.0))
355                        .sum();
356
357                    let gain = edges_to_comm / m
358                        - self.config.resolution * node_degree * comm_degree / (2.0 * m * m);
359
360                    if gain > best_gain {
361                        best_gain = gain;
362                        best_comm = neighbor_comm;
363                    }
364                }
365
366                if best_comm != current_comm && best_gain > 0.0 {
367                    community.insert(node, best_comm);
368                    changed = true;
369                }
370            }
371
372            // Phase 2: Refinement (what makes Leiden better than Louvain)
373            // Split communities and re-merge if it improves modularity
374            let unique_comms: HashSet<usize> = community.values().copied().collect();
375            for &comm_id in &unique_comms {
376                let comm_nodes: Vec<NodeIndex> = community
377                    .iter()
378                    .filter(|(_, &c)| c == comm_id)
379                    .map(|(&n, _)| n)
380                    .collect();
381
382                if comm_nodes.len() <= 1 {
383                    continue;
384                }
385
386                // Try to split and refine
387                self.refine_community(graph, &mut community, &comm_nodes, comm_id, m, &degree)?;
388            }
389
390            // Check modularity improvement
391            let current_modularity = self.calculate_modularity(graph, &community, m, &degree)?;
392            if current_modularity > best_modularity {
393                best_modularity = current_modularity;
394            } else if !changed {
395                break;
396            }
397
398            // Early stop if modularity is very high
399            if best_modularity > 0.95 || iteration > 0 && !changed {
400                break;
401            }
402        }
403
404        // Verify target: modularity > 0.75
405        if best_modularity < 0.75 {
406            tracing::warn!("Leiden modularity {:.3} below target 0.75", best_modularity);
407        } else {
408            tracing::info!("Leiden achieved modularity: {:.3}", best_modularity);
409        }
410
411        Ok(self.group_by_community(graph, &community))
412    }
413
414    /// Refine a community by attempting local splits
415    fn refine_community(
416        &self,
417        graph: &UnGraph<String, ()>,
418        community: &mut HashMap<NodeIndex, usize>,
419        comm_nodes: &[NodeIndex],
420        comm_id: usize,
421        m: f64,
422        degree: &HashMap<NodeIndex, f64>,
423    ) -> GraphRAGResult<()> {
424        if comm_nodes.len() < 2 {
425            return Ok(());
426        }
427
428        // Try to find a better split using local connectivity
429        let mut subcomm: HashMap<NodeIndex, usize> = HashMap::new();
430        for (i, &node) in comm_nodes.iter().enumerate() {
431            subcomm.insert(node, i);
432        }
433
434        // One pass of local moving within the community
435        let mut changed = false;
436        for &node in comm_nodes {
437            let current_sub = match subcomm.get(&node) {
438                Some(&c) => c,
439                None => continue,
440            };
441
442            // Count edges to each subcommunity
443            let mut sub_edges: HashMap<usize, f64> = HashMap::new();
444            for neighbor in graph.neighbors(node) {
445                if let Some(&sub) = subcomm.get(&neighbor) {
446                    *sub_edges.entry(sub).or_insert(0.0) += 1.0;
447                }
448            }
449
450            // Find best subcommunity
451            if let Some((&best_sub, _)) = sub_edges
452                .iter()
453                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
454            {
455                if best_sub != current_sub {
456                    subcomm.insert(node, best_sub);
457                    changed = true;
458                }
459            }
460        }
461
462        // If we found a better partition, create new communities
463        if changed {
464            let unique_subs: HashSet<usize> = subcomm.values().copied().collect();
465            if unique_subs.len() > 1 {
466                let max_comm = community.values().max().copied().unwrap_or(0);
467                for (i, sub_id) in unique_subs.iter().enumerate() {
468                    for &node in comm_nodes {
469                        if subcomm.get(&node) == Some(sub_id) {
470                            let new_comm = if i == 0 { comm_id } else { max_comm + i };
471                            community.insert(node, new_comm);
472                        }
473                    }
474                }
475            }
476        }
477
478        Ok(())
479    }
480
481    /// Calculate modularity of a community assignment
482    fn calculate_modularity(
483        &self,
484        graph: &UnGraph<String, ()>,
485        community: &HashMap<NodeIndex, usize>,
486        m: f64,
487        degree: &HashMap<NodeIndex, f64>,
488    ) -> GraphRAGResult<f64> {
489        if m == 0.0 {
490            return Ok(0.0);
491        }
492
493        let mut modularity = 0.0;
494
495        for edge in graph.edge_indices() {
496            if let Some((a, b)) = graph.edge_endpoints(edge) {
497                let comm_a = community.get(&a);
498                let comm_b = community.get(&b);
499
500                if comm_a == comm_b && comm_a.is_some() {
501                    let deg_a = degree.get(&a).copied().unwrap_or(0.0);
502                    let deg_b = degree.get(&b).copied().unwrap_or(0.0);
503
504                    modularity += 1.0 - (deg_a * deg_b) / (2.0 * m * m);
505                }
506            }
507        }
508
509        Ok(modularity / m)
510    }
511
512    /// Hierarchical community detection (multi-level)
513    fn detect_hierarchical(&self, triples: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
514        let mut all_summaries = Vec::new();
515        let mut current_triples = triples.to_vec();
516        let mut level = 0;
517
518        while level < 5 && !current_triples.is_empty() {
519            let (graph, node_map) = self.build_graph(&current_triples);
520
521            if graph.node_count() < 10 {
522                break;
523            }
524
525            // Detect communities at this level using Leiden
526            let communities = self.leiden(&graph, &node_map)?;
527
528            // Create summaries for this level
529            let mut level_summaries = self.create_summaries(communities.clone(), &current_triples);
530
531            // Tag with level
532            for summary in &mut level_summaries {
533                summary.level = level;
534            }
535
536            all_summaries.extend(level_summaries);
537
538            // Coarsen graph: each community becomes a supernode
539            current_triples = self.coarsen_graph(&graph, &node_map, &communities)?;
540            level += 1;
541        }
542
543        Ok(all_summaries)
544    }
545
546    /// Coarsen graph by collapsing communities into supernodes
547    fn coarsen_graph(
548        &self,
549        graph: &UnGraph<String, ()>,
550        node_map: &HashMap<String, NodeIndex>,
551        communities: &[HashSet<String>],
552    ) -> GraphRAGResult<Vec<Triple>> {
553        let mut node_to_community: HashMap<String, usize> = HashMap::new();
554        for (comm_id, community) in communities.iter().enumerate() {
555            for node in community {
556                node_to_community.insert(node.clone(), comm_id);
557            }
558        }
559
560        let mut coarsened_triples = Vec::new();
561        let mut seen_edges: HashSet<(usize, usize)> = HashSet::new();
562
563        for edge in graph.edge_indices() {
564            if let Some((a, b)) = graph.edge_endpoints(edge) {
565                let label_a = graph.node_weight(a);
566                let label_b = graph.node_weight(b);
567
568                if let (Some(la), Some(lb)) = (label_a, label_b) {
569                    if let (Some(&comm_a), Some(&comm_b)) =
570                        (node_to_community.get(la), node_to_community.get(lb))
571                    {
572                        if comm_a != comm_b {
573                            let edge_key = if comm_a < comm_b {
574                                (comm_a, comm_b)
575                            } else {
576                                (comm_b, comm_a)
577                            };
578
579                            if !seen_edges.contains(&edge_key) {
580                                seen_edges.insert(edge_key);
581                                coarsened_triples.push(Triple::new(
582                                    format!("community_{}", comm_a),
583                                    "inter_community_link",
584                                    format!("community_{}", comm_b),
585                                ));
586                            }
587                        }
588                    }
589                }
590            }
591        }
592
593        Ok(coarsened_triples)
594    }
595
596    /// Group nodes by community assignment
597    fn group_by_community(
598        &self,
599        graph: &UnGraph<String, ()>,
600        assignment: &HashMap<NodeIndex, usize>,
601    ) -> Vec<HashSet<String>> {
602        let mut communities: HashMap<usize, HashSet<String>> = HashMap::new();
603
604        for (&node, &comm) in assignment {
605            if let Some(label) = graph.node_weight(node) {
606                communities.entry(comm).or_default().insert(label.clone());
607            }
608        }
609
610        communities.into_values().collect()
611    }
612
613    /// Create community summaries
614    fn create_summaries(
615        &self,
616        communities: Vec<HashSet<String>>,
617        triples: &[Triple],
618    ) -> Vec<CommunitySummary> {
619        // Build graph for proper modularity calculation
620        let (graph, node_map) = self.build_graph(triples);
621        let m = graph.edge_count() as f64;
622
623        // Create community assignments
624        let mut community_map: HashMap<NodeIndex, usize> = HashMap::new();
625        for (idx, entities) in communities.iter().enumerate() {
626            for entity in entities {
627                if let Some(&node_idx) = node_map.get(entity) {
628                    community_map.insert(node_idx, idx);
629                }
630            }
631        }
632
633        // Calculate degrees
634        let degree: HashMap<NodeIndex, f64> = node_map
635            .values()
636            .map(|&idx| (idx, graph.neighbors(idx).count() as f64))
637            .collect();
638
639        // Calculate overall partition modularity (Newman-Girvan formula)
640        let overall_modularity = if m > 0.0 {
641            let mut q = 0.0;
642            for edge in graph.edge_indices() {
643                if let Some((a, b)) = graph.edge_endpoints(edge) {
644                    let comm_a = community_map.get(&a);
645                    let comm_b = community_map.get(&b);
646
647                    if comm_a.is_some() && comm_a == comm_b {
648                        let deg_a = degree.get(&a).copied().unwrap_or(0.0);
649                        let deg_b = degree.get(&b).copied().unwrap_or(0.0);
650                        q += 1.0 - (deg_a * deg_b) / (2.0 * m);
651                    }
652                }
653            }
654            q / (2.0 * m)
655        } else {
656            0.0
657        };
658
659        communities
660            .into_iter()
661            .enumerate()
662            .filter(|(_, entities)| entities.len() >= self.config.min_community_size)
663            .take(self.config.max_communities)
664            .map(|(idx, entities)| {
665                // Find representative triples
666                let representative_triples: Vec<Triple> = triples
667                    .iter()
668                    .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
669                    .take(5)
670                    .cloned()
671                    .collect();
672
673                // Generate summary
674                let entity_list: Vec<String> = entities.iter().cloned().collect();
675                let summary = self.generate_summary(&entity_list, &representative_triples);
676
677                // All communities share the overall partition modularity
678                CommunitySummary {
679                    id: format!("community_{}", idx),
680                    summary,
681                    entities: entity_list,
682                    representative_triples,
683                    level: 0,
684                    modularity: overall_modularity,
685                }
686            })
687            .collect()
688    }
689
690    /// Generate a text summary for a community
691    fn generate_summary(&self, entities: &[String], triples: &[Triple]) -> String {
692        // Extract short names from URIs
693        let short_names: Vec<String> = entities
694            .iter()
695            .take(3)
696            .map(|uri| {
697                uri.rsplit('/')
698                    .next()
699                    .or_else(|| uri.rsplit('#').next())
700                    .unwrap_or(uri)
701                    .to_string()
702            })
703            .collect();
704
705        // Extract predicates
706        let predicates: HashSet<String> = triples
707            .iter()
708            .map(|t| {
709                t.predicate
710                    .rsplit('/')
711                    .next()
712                    .or_else(|| t.predicate.rsplit('#').next())
713                    .unwrap_or(&t.predicate)
714                    .to_string()
715            })
716            .collect();
717
718        let pred_str: Vec<String> = predicates.into_iter().take(3).collect();
719
720        format!(
721            "Community of {} entities including {} connected by {}",
722            entities.len(),
723            short_names.join(", "),
724            pred_str.join(", ")
725        )
726    }
727}
728
#[cfg(test)]
mod tests {
    use super::*;

    /// Two disjoint triangles: {a, b, c} and {x, y, z}.
    fn two_triangles() -> Vec<Triple> {
        vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
            Triple::new("http://a", "http://rel", "http://c"),
            Triple::new("http://x", "http://rel", "http://y"),
            Triple::new("http://y", "http://rel", "http://z"),
            Triple::new("http://x", "http://rel", "http://z"),
        ]
    }

    #[test]
    fn test_community_detection() {
        // min_community_size: 1 so that small communities are not filtered out.
        let detector = CommunityDetector::new(CommunityConfig {
            min_community_size: 1,
            ..Default::default()
        });

        let communities = detector.detect(&two_triangles()).unwrap();

        // The default (Leiden) detector must find at least one community.
        assert!(!communities.is_empty());
    }

    #[test]
    fn test_empty_graph() {
        let detector = CommunityDetector::default();
        let communities = detector.detect(&[]).unwrap();
        assert!(communities.is_empty());
    }

    #[test]
    fn test_connected_components_separates_disjoint_triangles() {
        let detector = CommunityDetector::new(CommunityConfig {
            algorithm: CommunityAlgorithm::ConnectedComponents,
            ..Default::default()
        });

        let communities = detector.detect(&two_triangles()).unwrap();

        assert_eq!(communities.len(), 2);
        assert!(communities.iter().all(|c| c.entities.len() == 3));
    }

    #[test]
    fn test_min_community_size_filters_small_communities() {
        let detector = CommunityDetector::new(CommunityConfig {
            algorithm: CommunityAlgorithm::ConnectedComponents,
            min_community_size: 4,
            ..Default::default()
        });

        // Both components have only three entities, so none survive the filter.
        assert!(detector.detect(&two_triangles()).unwrap().is_empty());
    }
}