1use crate::{CommunitySummary, GraphRAGError, GraphRAGResult, Triple};
4use petgraph::graph::{NodeIndex, UnGraph};
5use scirs2_core::random::{seeded_rng, Random};
6use std::collections::{HashMap, HashSet};
7
/// Algorithm used by `CommunityDetector` to partition the entity graph.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum CommunityAlgorithm {
    /// Greedy modularity optimisation starting from singleton communities
    /// (local-moving phase only).
    Louvain,
    /// Leiden-style optimisation: shuffled local moving followed by refinement.
    /// This is the default algorithm.
    #[default]
    Leiden,
    /// Iterative label propagation: each node adopts its neighbours' majority label.
    LabelPropagation,
    /// Every connected component of the graph becomes one community.
    ConnectedComponents,
    /// Repeated Leiden runs on successively coarsened graphs, producing summaries
    /// for several levels.
    Hierarchical,
}
23
24#[derive(Debug, Clone)]
26pub struct CommunityConfig {
27 pub algorithm: CommunityAlgorithm,
29 pub resolution: f64,
31 pub min_community_size: usize,
33 pub max_communities: usize,
35 pub max_iterations: usize,
37 pub random_seed: u64,
39}
40
41impl Default for CommunityConfig {
42 fn default() -> Self {
43 Self {
44 algorithm: CommunityAlgorithm::Leiden,
45 resolution: 1.0,
46 min_community_size: 3,
47 max_communities: 50,
48 max_iterations: 10,
49 random_seed: 42,
50 }
51 }
52}
53
54pub struct CommunityDetector {
56 config: CommunityConfig,
57}
58
59impl Default for CommunityDetector {
60 fn default() -> Self {
61 Self::new(CommunityConfig::default())
62 }
63}
64
65impl CommunityDetector {
66 pub fn new(config: CommunityConfig) -> Self {
67 Self { config }
68 }
69
70 pub fn detect(&self, triples: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
72 if triples.is_empty() {
73 return Ok(vec![]);
74 }
75
76 let (graph, node_map) = self.build_graph(triples);
78
79 let communities = match self.config.algorithm {
81 CommunityAlgorithm::Louvain => self.louvain(&graph, &node_map),
82 CommunityAlgorithm::Leiden => self.leiden(&graph, &node_map)?,
83 CommunityAlgorithm::LabelPropagation => self.label_propagation(&graph, &node_map),
84 CommunityAlgorithm::ConnectedComponents => self.connected_components(&graph, &node_map),
85 CommunityAlgorithm::Hierarchical => {
86 return self.detect_hierarchical(triples);
87 }
88 };
89
90 let summaries = self.create_summaries(communities, triples);
92
93 Ok(summaries)
94 }
95
96 fn build_graph(&self, triples: &[Triple]) -> (UnGraph<String, ()>, HashMap<String, NodeIndex>) {
98 let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
99 let mut node_map: HashMap<String, NodeIndex> = HashMap::new();
100
101 for triple in triples {
102 let subj_idx = *node_map
103 .entry(triple.subject.clone())
104 .or_insert_with(|| graph.add_node(triple.subject.clone()));
105 let obj_idx = *node_map
106 .entry(triple.object.clone())
107 .or_insert_with(|| graph.add_node(triple.object.clone()));
108
109 if subj_idx != obj_idx && graph.find_edge(subj_idx, obj_idx).is_none() {
110 graph.add_edge(subj_idx, obj_idx, ());
111 }
112 }
113
114 (graph, node_map)
115 }
116
117 fn louvain(
119 &self,
120 graph: &UnGraph<String, ()>,
121 node_map: &HashMap<String, NodeIndex>,
122 ) -> Vec<HashSet<String>> {
123 let node_count = graph.node_count();
124 if node_count == 0 {
125 return vec![];
126 }
127
128 let mut community: HashMap<NodeIndex, usize> = HashMap::new();
130 for (community_id, &idx) in node_map.values().enumerate() {
131 community.insert(idx, community_id);
132 }
133
134 let m = graph.edge_count() as f64;
136 if m == 0.0 {
137 return node_map
139 .keys()
140 .map(|k| {
141 let mut set = HashSet::new();
142 set.insert(k.clone());
143 set
144 })
145 .collect();
146 }
147
148 let degree: HashMap<NodeIndex, f64> = node_map
150 .values()
151 .map(|&idx| (idx, graph.neighbors(idx).count() as f64))
152 .collect();
153
154 for _ in 0..self.config.max_iterations {
156 let mut changed = false;
157
158 for (&node, ¤t_comm) in community.clone().iter() {
159 let node_degree = degree.get(&node).copied().unwrap_or(0.0);
160
161 let mut best_comm = current_comm;
163 let mut best_gain = 0.0;
164
165 let neighbor_comms: HashSet<usize> = graph
166 .neighbors(node)
167 .filter_map(|n| community.get(&n).copied())
168 .collect();
169
170 for &neighbor_comm in &neighbor_comms {
171 if neighbor_comm == current_comm {
172 continue;
173 }
174
175 let edges_to_comm: f64 = graph
177 .neighbors(node)
178 .filter(|n| community.get(n) == Some(&neighbor_comm))
179 .count() as f64;
180
181 let comm_degree: f64 = community
182 .iter()
183 .filter(|(_, &c)| c == neighbor_comm)
184 .map(|(n, _)| degree.get(n).copied().unwrap_or(0.0))
185 .sum();
186
187 let gain = edges_to_comm / m
188 - self.config.resolution * node_degree * comm_degree / (2.0 * m * m);
189
190 if gain > best_gain {
191 best_gain = gain;
192 best_comm = neighbor_comm;
193 }
194 }
195
196 if best_comm != current_comm && best_gain > 0.0 {
197 community.insert(node, best_comm);
198 changed = true;
199 }
200 }
201
202 if !changed {
203 break;
204 }
205 }
206
207 self.group_by_community(graph, &community)
209 }
210
211 fn label_propagation(
213 &self,
214 graph: &UnGraph<String, ()>,
215 node_map: &HashMap<String, NodeIndex>,
216 ) -> Vec<HashSet<String>> {
217 if graph.node_count() == 0 {
218 return vec![];
219 }
220
221 let mut labels: HashMap<NodeIndex, usize> = HashMap::new();
223 for (i, &idx) in node_map.values().enumerate() {
224 labels.insert(idx, i);
225 }
226
227 for _ in 0..self.config.max_iterations {
229 let mut changed = false;
230
231 for &node in node_map.values() {
232 let mut label_counts: HashMap<usize, usize> = HashMap::new();
234 for neighbor in graph.neighbors(node) {
235 if let Some(&label) = labels.get(&neighbor) {
236 *label_counts.entry(label).or_insert(0) += 1;
237 }
238 }
239
240 if let Some((&best_label, _)) = label_counts.iter().max_by_key(|(_, &count)| count)
242 {
243 if labels.get(&node) != Some(&best_label) {
244 labels.insert(node, best_label);
245 changed = true;
246 }
247 }
248 }
249
250 if !changed {
251 break;
252 }
253 }
254
255 self.group_by_community(graph, &labels)
256 }
257
258 fn connected_components(
260 &self,
261 graph: &UnGraph<String, ()>,
262 _node_map: &HashMap<String, NodeIndex>,
263 ) -> Vec<HashSet<String>> {
264 let sccs = petgraph::algo::kosaraju_scc(graph);
265
266 sccs.into_iter()
267 .map(|component| {
268 component
269 .into_iter()
270 .filter_map(|idx| graph.node_weight(idx).cloned())
271 .collect()
272 })
273 .collect()
274 }
275
276 fn leiden(
278 &self,
279 graph: &UnGraph<String, ()>,
280 node_map: &HashMap<String, NodeIndex>,
281 ) -> GraphRAGResult<Vec<HashSet<String>>> {
282 let node_count = graph.node_count();
283 if node_count == 0 {
284 return Ok(vec![]);
285 }
286
287 let mut community: HashMap<NodeIndex, usize> = HashMap::new();
289 for (community_id, &idx) in node_map.values().enumerate() {
290 community.insert(idx, community_id);
291 }
292
293 let m = graph.edge_count() as f64;
294 if m == 0.0 {
295 return Ok(node_map
296 .keys()
297 .map(|k| {
298 let mut set = HashSet::new();
299 set.insert(k.clone());
300 set
301 })
302 .collect());
303 }
304
305 let degree: HashMap<NodeIndex, f64> = node_map
306 .values()
307 .map(|&idx| (idx, graph.neighbors(idx).count() as f64))
308 .collect();
309
310 let mut rng = seeded_rng(self.config.random_seed);
311 let mut best_modularity = self.calculate_modularity(graph, &community, m, °ree)?;
312
313 for iteration in 0..self.config.max_iterations {
315 let mut changed = false;
316
317 let mut node_order: Vec<NodeIndex> = node_map.values().copied().collect();
319 for i in (1..node_order.len()).rev() {
321 let j = (rng.random_range(0.0..1.0) * (i + 1) as f64) as usize;
322 node_order.swap(i, j);
323 }
324
325 for &node in &node_order {
326 let current_comm = match community.get(&node) {
327 Some(&c) => c,
328 None => continue,
329 };
330 let node_degree = degree.get(&node).copied().unwrap_or(0.0);
331
332 let mut best_comm = current_comm;
333 let mut best_gain = 0.0;
334
335 let neighbor_comms: HashSet<usize> = graph
337 .neighbors(node)
338 .filter_map(|n| community.get(&n).copied())
339 .collect();
340
341 for &neighbor_comm in &neighbor_comms {
342 if neighbor_comm == current_comm {
343 continue;
344 }
345
346 let edges_to_comm: f64 = graph
347 .neighbors(node)
348 .filter(|n| community.get(n) == Some(&neighbor_comm))
349 .count() as f64;
350
351 let comm_degree: f64 = community
352 .iter()
353 .filter(|(_, &c)| c == neighbor_comm)
354 .map(|(n, _)| degree.get(n).copied().unwrap_or(0.0))
355 .sum();
356
357 let gain = edges_to_comm / m
358 - self.config.resolution * node_degree * comm_degree / (2.0 * m * m);
359
360 if gain > best_gain {
361 best_gain = gain;
362 best_comm = neighbor_comm;
363 }
364 }
365
366 if best_comm != current_comm && best_gain > 0.0 {
367 community.insert(node, best_comm);
368 changed = true;
369 }
370 }
371
372 let unique_comms: HashSet<usize> = community.values().copied().collect();
375 for &comm_id in &unique_comms {
376 let comm_nodes: Vec<NodeIndex> = community
377 .iter()
378 .filter(|(_, &c)| c == comm_id)
379 .map(|(&n, _)| n)
380 .collect();
381
382 if comm_nodes.len() <= 1 {
383 continue;
384 }
385
386 self.refine_community(graph, &mut community, &comm_nodes, comm_id, m, °ree)?;
388 }
389
390 let current_modularity = self.calculate_modularity(graph, &community, m, °ree)?;
392 if current_modularity > best_modularity {
393 best_modularity = current_modularity;
394 } else if !changed {
395 break;
396 }
397
398 if best_modularity > 0.95 || iteration > 0 && !changed {
400 break;
401 }
402 }
403
404 if best_modularity < 0.75 {
406 tracing::warn!("Leiden modularity {:.3} below target 0.75", best_modularity);
407 } else {
408 tracing::info!("Leiden achieved modularity: {:.3}", best_modularity);
409 }
410
411 Ok(self.group_by_community(graph, &community))
412 }
413
414 fn refine_community(
416 &self,
417 graph: &UnGraph<String, ()>,
418 community: &mut HashMap<NodeIndex, usize>,
419 comm_nodes: &[NodeIndex],
420 comm_id: usize,
421 m: f64,
422 degree: &HashMap<NodeIndex, f64>,
423 ) -> GraphRAGResult<()> {
424 if comm_nodes.len() < 2 {
425 return Ok(());
426 }
427
428 let mut subcomm: HashMap<NodeIndex, usize> = HashMap::new();
430 for (i, &node) in comm_nodes.iter().enumerate() {
431 subcomm.insert(node, i);
432 }
433
434 let mut changed = false;
436 for &node in comm_nodes {
437 let current_sub = match subcomm.get(&node) {
438 Some(&c) => c,
439 None => continue,
440 };
441
442 let mut sub_edges: HashMap<usize, f64> = HashMap::new();
444 for neighbor in graph.neighbors(node) {
445 if let Some(&sub) = subcomm.get(&neighbor) {
446 *sub_edges.entry(sub).or_insert(0.0) += 1.0;
447 }
448 }
449
450 if let Some((&best_sub, _)) = sub_edges
452 .iter()
453 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
454 {
455 if best_sub != current_sub {
456 subcomm.insert(node, best_sub);
457 changed = true;
458 }
459 }
460 }
461
462 if changed {
464 let unique_subs: HashSet<usize> = subcomm.values().copied().collect();
465 if unique_subs.len() > 1 {
466 let max_comm = community.values().max().copied().unwrap_or(0);
467 for (i, sub_id) in unique_subs.iter().enumerate() {
468 for &node in comm_nodes {
469 if subcomm.get(&node) == Some(sub_id) {
470 let new_comm = if i == 0 { comm_id } else { max_comm + i };
471 community.insert(node, new_comm);
472 }
473 }
474 }
475 }
476 }
477
478 Ok(())
479 }
480
481 fn calculate_modularity(
483 &self,
484 graph: &UnGraph<String, ()>,
485 community: &HashMap<NodeIndex, usize>,
486 m: f64,
487 degree: &HashMap<NodeIndex, f64>,
488 ) -> GraphRAGResult<f64> {
489 if m == 0.0 {
490 return Ok(0.0);
491 }
492
493 let mut modularity = 0.0;
494
495 for edge in graph.edge_indices() {
496 if let Some((a, b)) = graph.edge_endpoints(edge) {
497 let comm_a = community.get(&a);
498 let comm_b = community.get(&b);
499
500 if comm_a == comm_b && comm_a.is_some() {
501 let deg_a = degree.get(&a).copied().unwrap_or(0.0);
502 let deg_b = degree.get(&b).copied().unwrap_or(0.0);
503
504 modularity += 1.0 - (deg_a * deg_b) / (2.0 * m * m);
505 }
506 }
507 }
508
509 Ok(modularity / m)
510 }
511
512 fn detect_hierarchical(&self, triples: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
514 let mut all_summaries = Vec::new();
515 let mut current_triples = triples.to_vec();
516 let mut level = 0;
517
518 while level < 5 && !current_triples.is_empty() {
519 let (graph, node_map) = self.build_graph(¤t_triples);
520
521 if graph.node_count() < 10 {
522 break;
523 }
524
525 let communities = self.leiden(&graph, &node_map)?;
527
528 let mut level_summaries = self.create_summaries(communities.clone(), ¤t_triples);
530
531 for summary in &mut level_summaries {
533 summary.level = level;
534 }
535
536 all_summaries.extend(level_summaries);
537
538 current_triples = self.coarsen_graph(&graph, &node_map, &communities)?;
540 level += 1;
541 }
542
543 Ok(all_summaries)
544 }
545
546 fn coarsen_graph(
548 &self,
549 graph: &UnGraph<String, ()>,
550 node_map: &HashMap<String, NodeIndex>,
551 communities: &[HashSet<String>],
552 ) -> GraphRAGResult<Vec<Triple>> {
553 let mut node_to_community: HashMap<String, usize> = HashMap::new();
554 for (comm_id, community) in communities.iter().enumerate() {
555 for node in community {
556 node_to_community.insert(node.clone(), comm_id);
557 }
558 }
559
560 let mut coarsened_triples = Vec::new();
561 let mut seen_edges: HashSet<(usize, usize)> = HashSet::new();
562
563 for edge in graph.edge_indices() {
564 if let Some((a, b)) = graph.edge_endpoints(edge) {
565 let label_a = graph.node_weight(a);
566 let label_b = graph.node_weight(b);
567
568 if let (Some(la), Some(lb)) = (label_a, label_b) {
569 if let (Some(&comm_a), Some(&comm_b)) =
570 (node_to_community.get(la), node_to_community.get(lb))
571 {
572 if comm_a != comm_b {
573 let edge_key = if comm_a < comm_b {
574 (comm_a, comm_b)
575 } else {
576 (comm_b, comm_a)
577 };
578
579 if !seen_edges.contains(&edge_key) {
580 seen_edges.insert(edge_key);
581 coarsened_triples.push(Triple::new(
582 format!("community_{}", comm_a),
583 "inter_community_link",
584 format!("community_{}", comm_b),
585 ));
586 }
587 }
588 }
589 }
590 }
591 }
592
593 Ok(coarsened_triples)
594 }
595
596 fn group_by_community(
598 &self,
599 graph: &UnGraph<String, ()>,
600 assignment: &HashMap<NodeIndex, usize>,
601 ) -> Vec<HashSet<String>> {
602 let mut communities: HashMap<usize, HashSet<String>> = HashMap::new();
603
604 for (&node, &comm) in assignment {
605 if let Some(label) = graph.node_weight(node) {
606 communities.entry(comm).or_default().insert(label.clone());
607 }
608 }
609
610 communities.into_values().collect()
611 }
612
613 fn create_summaries(
615 &self,
616 communities: Vec<HashSet<String>>,
617 triples: &[Triple],
618 ) -> Vec<CommunitySummary> {
619 let (graph, node_map) = self.build_graph(triples);
621 let m = graph.edge_count() as f64;
622
623 let mut community_map: HashMap<NodeIndex, usize> = HashMap::new();
625 for (idx, entities) in communities.iter().enumerate() {
626 for entity in entities {
627 if let Some(&node_idx) = node_map.get(entity) {
628 community_map.insert(node_idx, idx);
629 }
630 }
631 }
632
633 let degree: HashMap<NodeIndex, f64> = node_map
635 .values()
636 .map(|&idx| (idx, graph.neighbors(idx).count() as f64))
637 .collect();
638
639 let overall_modularity = if m > 0.0 {
641 let mut q = 0.0;
642 for edge in graph.edge_indices() {
643 if let Some((a, b)) = graph.edge_endpoints(edge) {
644 let comm_a = community_map.get(&a);
645 let comm_b = community_map.get(&b);
646
647 if comm_a.is_some() && comm_a == comm_b {
648 let deg_a = degree.get(&a).copied().unwrap_or(0.0);
649 let deg_b = degree.get(&b).copied().unwrap_or(0.0);
650 q += 1.0 - (deg_a * deg_b) / (2.0 * m);
651 }
652 }
653 }
654 q / (2.0 * m)
655 } else {
656 0.0
657 };
658
659 communities
660 .into_iter()
661 .enumerate()
662 .filter(|(_, entities)| entities.len() >= self.config.min_community_size)
663 .take(self.config.max_communities)
664 .map(|(idx, entities)| {
665 let representative_triples: Vec<Triple> = triples
667 .iter()
668 .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
669 .take(5)
670 .cloned()
671 .collect();
672
673 let entity_list: Vec<String> = entities.iter().cloned().collect();
675 let summary = self.generate_summary(&entity_list, &representative_triples);
676
677 CommunitySummary {
679 id: format!("community_{}", idx),
680 summary,
681 entities: entity_list,
682 representative_triples,
683 level: 0,
684 modularity: overall_modularity,
685 }
686 })
687 .collect()
688 }
689
690 fn generate_summary(&self, entities: &[String], triples: &[Triple]) -> String {
692 let short_names: Vec<String> = entities
694 .iter()
695 .take(3)
696 .map(|uri| {
697 uri.rsplit('/')
698 .next()
699 .or_else(|| uri.rsplit('#').next())
700 .unwrap_or(uri)
701 .to_string()
702 })
703 .collect();
704
705 let predicates: HashSet<String> = triples
707 .iter()
708 .map(|t| {
709 t.predicate
710 .rsplit('/')
711 .next()
712 .or_else(|| t.predicate.rsplit('#').next())
713 .unwrap_or(&t.predicate)
714 .to_string()
715 })
716 .collect();
717
718 let pred_str: Vec<String> = predicates.into_iter().take(3).collect();
719
720 format!(
721 "Community of {} entities including {} connected by {}",
722 entities.len(),
723 short_names.join(", "),
724 pred_str.join(", ")
725 )
726 }
727}
728
#[cfg(test)]
mod tests {
    use super::*;

    /// Two triangles with no edge between them: {a, b, c} and {x, y, z}.
    fn two_triangles() -> Vec<Triple> {
        vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
            Triple::new("http://a", "http://rel", "http://c"),
            Triple::new("http://x", "http://rel", "http://y"),
            Triple::new("http://y", "http://rel", "http://z"),
            Triple::new("http://x", "http://rel", "http://z"),
        ]
    }

    /// True if any of `names` appears in `entities`.
    fn contains_any(entities: &[String], names: &[&str]) -> bool {
        entities.iter().any(|e| names.contains(&e.as_str()))
    }

    #[test]
    fn test_community_detection() {
        let detector = CommunityDetector::new(CommunityConfig {
            min_community_size: 1,
            ..Default::default()
        });

        let communities = detector.detect(&two_triangles()).unwrap();

        assert!(!communities.is_empty());

        // Without a size filter every entity belongs to exactly one community.
        let total: usize = communities.iter().map(|c| c.entities.len()).sum();
        assert_eq!(total, 6);

        // The triangles share no edge, so no community may contain nodes of both.
        for community in &communities {
            let first = contains_any(&community.entities, &["http://a", "http://b", "http://c"]);
            let second = contains_any(&community.entities, &["http://x", "http://y", "http://z"]);
            assert!(!(first && second), "community mixes disconnected triangles");
        }
    }

    #[test]
    fn test_connected_components_finds_both_triangles() {
        let detector = CommunityDetector::new(CommunityConfig {
            algorithm: CommunityAlgorithm::ConnectedComponents,
            min_community_size: 1,
            ..Default::default()
        });

        let communities = detector.detect(&two_triangles()).unwrap();

        assert_eq!(communities.len(), 2);
        assert!(communities.iter().all(|c| c.entities.len() == 3));
    }

    #[test]
    fn test_empty_graph() {
        let detector = CommunityDetector::default();
        let communities = detector.detect(&[]).unwrap();
        assert!(communities.is_empty());
    }
}