1use crate::{CommunitySummary, GraphRAGResult, Triple};
4use petgraph::graph::{NodeIndex, UnGraph};
5use std::collections::{HashMap, HashSet};
6
/// Strategy used by [`CommunityDetector`] to partition the entity graph.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CommunityAlgorithm {
    /// Greedy modularity optimisation (local-moving phase of Louvain); the default.
    #[default]
    Louvain,
    /// Nodes repeatedly adopt the label most common among their neighbours.
    LabelPropagation,
    /// Each connected component of the undirected entity graph is one community.
    ConnectedComponents,
}
18
19#[derive(Debug, Clone)]
21pub struct CommunityConfig {
22 pub algorithm: CommunityAlgorithm,
24 pub resolution: f64,
26 pub min_community_size: usize,
28 pub max_communities: usize,
30 pub max_iterations: usize,
32}
33
34impl Default for CommunityConfig {
35 fn default() -> Self {
36 Self {
37 algorithm: CommunityAlgorithm::Louvain,
38 resolution: 1.0,
39 min_community_size: 2,
40 max_communities: 50,
41 max_iterations: 100,
42 }
43 }
44}
45
46pub struct CommunityDetector {
48 config: CommunityConfig,
49}
50
51impl Default for CommunityDetector {
52 fn default() -> Self {
53 Self::new(CommunityConfig::default())
54 }
55}
56
57impl CommunityDetector {
58 pub fn new(config: CommunityConfig) -> Self {
59 Self { config }
60 }
61
62 pub fn detect(&self, triples: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
64 if triples.is_empty() {
65 return Ok(vec![]);
66 }
67
68 let (graph, node_map) = self.build_graph(triples);
70
71 let communities = match self.config.algorithm {
73 CommunityAlgorithm::Louvain => self.louvain(&graph, &node_map),
74 CommunityAlgorithm::LabelPropagation => self.label_propagation(&graph, &node_map),
75 CommunityAlgorithm::ConnectedComponents => self.connected_components(&graph, &node_map),
76 };
77
78 let summaries = self.create_summaries(communities, triples);
80
81 Ok(summaries)
82 }
83
84 fn build_graph(&self, triples: &[Triple]) -> (UnGraph<String, ()>, HashMap<String, NodeIndex>) {
86 let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
87 let mut node_map: HashMap<String, NodeIndex> = HashMap::new();
88
89 for triple in triples {
90 let subj_idx = *node_map
91 .entry(triple.subject.clone())
92 .or_insert_with(|| graph.add_node(triple.subject.clone()));
93 let obj_idx = *node_map
94 .entry(triple.object.clone())
95 .or_insert_with(|| graph.add_node(triple.object.clone()));
96
97 if subj_idx != obj_idx && graph.find_edge(subj_idx, obj_idx).is_none() {
98 graph.add_edge(subj_idx, obj_idx, ());
99 }
100 }
101
102 (graph, node_map)
103 }
104
105 fn louvain(
107 &self,
108 graph: &UnGraph<String, ()>,
109 node_map: &HashMap<String, NodeIndex>,
110 ) -> Vec<HashSet<String>> {
111 let node_count = graph.node_count();
112 if node_count == 0 {
113 return vec![];
114 }
115
116 let mut community: HashMap<NodeIndex, usize> = HashMap::new();
118 for (community_id, &idx) in node_map.values().enumerate() {
119 community.insert(idx, community_id);
120 }
121
122 let m = graph.edge_count() as f64;
124 if m == 0.0 {
125 return node_map
127 .keys()
128 .map(|k| {
129 let mut set = HashSet::new();
130 set.insert(k.clone());
131 set
132 })
133 .collect();
134 }
135
136 let degree: HashMap<NodeIndex, f64> = node_map
138 .values()
139 .map(|&idx| (idx, graph.neighbors(idx).count() as f64))
140 .collect();
141
142 for _ in 0..self.config.max_iterations {
144 let mut changed = false;
145
146 for (&node, ¤t_comm) in community.clone().iter() {
147 let node_degree = degree.get(&node).copied().unwrap_or(0.0);
148
149 let mut best_comm = current_comm;
151 let mut best_gain = 0.0;
152
153 let neighbor_comms: HashSet<usize> = graph
154 .neighbors(node)
155 .filter_map(|n| community.get(&n).copied())
156 .collect();
157
158 for &neighbor_comm in &neighbor_comms {
159 if neighbor_comm == current_comm {
160 continue;
161 }
162
163 let edges_to_comm: f64 = graph
165 .neighbors(node)
166 .filter(|n| community.get(n) == Some(&neighbor_comm))
167 .count() as f64;
168
169 let comm_degree: f64 = community
170 .iter()
171 .filter(|(_, &c)| c == neighbor_comm)
172 .map(|(n, _)| degree.get(n).copied().unwrap_or(0.0))
173 .sum();
174
175 let gain = edges_to_comm / m
176 - self.config.resolution * node_degree * comm_degree / (2.0 * m * m);
177
178 if gain > best_gain {
179 best_gain = gain;
180 best_comm = neighbor_comm;
181 }
182 }
183
184 if best_comm != current_comm && best_gain > 0.0 {
185 community.insert(node, best_comm);
186 changed = true;
187 }
188 }
189
190 if !changed {
191 break;
192 }
193 }
194
195 self.group_by_community(graph, &community)
197 }
198
199 fn label_propagation(
201 &self,
202 graph: &UnGraph<String, ()>,
203 node_map: &HashMap<String, NodeIndex>,
204 ) -> Vec<HashSet<String>> {
205 if graph.node_count() == 0 {
206 return vec![];
207 }
208
209 let mut labels: HashMap<NodeIndex, usize> = HashMap::new();
211 for (i, &idx) in node_map.values().enumerate() {
212 labels.insert(idx, i);
213 }
214
215 for _ in 0..self.config.max_iterations {
217 let mut changed = false;
218
219 for &node in node_map.values() {
220 let mut label_counts: HashMap<usize, usize> = HashMap::new();
222 for neighbor in graph.neighbors(node) {
223 if let Some(&label) = labels.get(&neighbor) {
224 *label_counts.entry(label).or_insert(0) += 1;
225 }
226 }
227
228 if let Some((&best_label, _)) = label_counts.iter().max_by_key(|(_, &count)| count)
230 {
231 if labels.get(&node) != Some(&best_label) {
232 labels.insert(node, best_label);
233 changed = true;
234 }
235 }
236 }
237
238 if !changed {
239 break;
240 }
241 }
242
243 self.group_by_community(graph, &labels)
244 }
245
246 fn connected_components(
248 &self,
249 graph: &UnGraph<String, ()>,
250 _node_map: &HashMap<String, NodeIndex>,
251 ) -> Vec<HashSet<String>> {
252 let sccs = petgraph::algo::kosaraju_scc(graph);
253
254 sccs.into_iter()
255 .map(|component| {
256 component
257 .into_iter()
258 .filter_map(|idx| graph.node_weight(idx).cloned())
259 .collect()
260 })
261 .collect()
262 }
263
264 fn group_by_community(
266 &self,
267 graph: &UnGraph<String, ()>,
268 assignment: &HashMap<NodeIndex, usize>,
269 ) -> Vec<HashSet<String>> {
270 let mut communities: HashMap<usize, HashSet<String>> = HashMap::new();
271
272 for (&node, &comm) in assignment {
273 if let Some(label) = graph.node_weight(node) {
274 communities.entry(comm).or_default().insert(label.clone());
275 }
276 }
277
278 communities.into_values().collect()
279 }
280
281 fn create_summaries(
283 &self,
284 communities: Vec<HashSet<String>>,
285 triples: &[Triple],
286 ) -> Vec<CommunitySummary> {
287 communities
288 .into_iter()
289 .enumerate()
290 .filter(|(_, entities)| entities.len() >= self.config.min_community_size)
291 .take(self.config.max_communities)
292 .map(|(idx, entities)| {
293 let representative_triples: Vec<Triple> = triples
295 .iter()
296 .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
297 .take(5)
298 .cloned()
299 .collect();
300
301 let internal_edges = triples
303 .iter()
304 .filter(|t| entities.contains(&t.subject) && entities.contains(&t.object))
305 .count() as f64;
306 let total_edges = triples.len().max(1) as f64;
307 let modularity = internal_edges / total_edges;
308
309 let entity_list: Vec<String> = entities.iter().cloned().collect();
311 let summary = self.generate_summary(&entity_list, &representative_triples);
312
313 CommunitySummary {
314 id: format!("community_{}", idx),
315 summary,
316 entities: entity_list,
317 representative_triples,
318 level: 0,
319 modularity,
320 }
321 })
322 .collect()
323 }
324
325 fn generate_summary(&self, entities: &[String], triples: &[Triple]) -> String {
327 let short_names: Vec<String> = entities
329 .iter()
330 .take(3)
331 .map(|uri| {
332 uri.rsplit('/')
333 .next()
334 .or_else(|| uri.rsplit('#').next())
335 .unwrap_or(uri)
336 .to_string()
337 })
338 .collect();
339
340 let predicates: HashSet<String> = triples
342 .iter()
343 .map(|t| {
344 t.predicate
345 .rsplit('/')
346 .next()
347 .or_else(|| t.predicate.rsplit('#').next())
348 .unwrap_or(&t.predicate)
349 .to_string()
350 })
351 .collect();
352
353 let pred_str: Vec<String> = predicates.into_iter().take(3).collect();
354
355 format!(
356 "Community of {} entities including {} connected by {}",
357 entities.len(),
358 short_names.join(", "),
359 pred_str.join(", ")
360 )
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Two disjoint triangles: {a, b, c} and {x, y, z}.
    fn two_triangles() -> Vec<Triple> {
        vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
            Triple::new("http://a", "http://rel", "http://c"),
            Triple::new("http://x", "http://rel", "http://y"),
            Triple::new("http://y", "http://rel", "http://z"),
            Triple::new("http://x", "http://rel", "http://z"),
        ]
    }

    /// Builds a detector that differs from the default only in its algorithm.
    fn detector_with(algorithm: CommunityAlgorithm) -> CommunityDetector {
        CommunityDetector::new(CommunityConfig {
            algorithm,
            ..CommunityConfig::default()
        })
    }

    /// Asserts that no entity is shared between communities and that every
    /// community meets the minimum size.
    fn assert_disjoint(communities: &[CommunitySummary], min_size: usize) {
        let mut seen = HashSet::new();
        for community in communities {
            assert!(community.entities.len() >= min_size);
            for entity in &community.entities {
                assert!(
                    seen.insert(entity.clone()),
                    "{} appears in two communities",
                    entity
                );
            }
        }
    }

    #[test]
    fn test_community_detection() {
        let communities = CommunityDetector::default()
            .detect(&two_triangles())
            .unwrap();
        assert!(!communities.is_empty());
        assert_disjoint(&communities, 2);
    }

    #[test]
    fn test_empty_graph() {
        let detector = CommunityDetector::default();
        let communities = detector.detect(&[]).unwrap();
        assert!(communities.is_empty());
    }

    #[test]
    fn test_every_algorithm_yields_a_valid_partition() {
        let algorithms = [
            CommunityAlgorithm::Louvain,
            CommunityAlgorithm::LabelPropagation,
            CommunityAlgorithm::ConnectedComponents,
        ];
        for algorithm in algorithms {
            let communities = detector_with(algorithm).detect(&two_triangles()).unwrap();
            // A community can never span the two disconnected triangles.
            assert!(communities.len() <= 2, "{:?}", algorithm);
            assert_disjoint(&communities, 2);
        }
    }

    #[test]
    fn test_connected_components_finds_both_triangles() {
        let communities = detector_with(CommunityAlgorithm::ConnectedComponents)
            .detect(&two_triangles())
            .unwrap();
        assert_eq!(communities.len(), 2);
        assert!(communities.iter().all(|c| c.entities.len() == 3));
    }

    #[test]
    fn test_max_communities_caps_output() {
        let detector = CommunityDetector::new(CommunityConfig {
            algorithm: CommunityAlgorithm::ConnectedComponents,
            max_communities: 1,
            ..CommunityConfig::default()
        });
        assert_eq!(detector.detect(&two_triangles()).unwrap().len(), 1);
    }

    #[test]
    fn test_min_community_size_filters_small_groups() {
        let detector = CommunityDetector::new(CommunityConfig {
            min_community_size: 3,
            ..CommunityConfig::default()
        });
        let triples = [Triple::new("http://a", "http://rel", "http://b")];
        assert!(detector.detect(&triples).unwrap().is_empty());
    }

    #[test]
    fn test_summary_fields() {
        let communities = detector_with(CommunityAlgorithm::ConnectedComponents)
            .detect(&two_triangles())
            .unwrap();
        for community in &communities {
            assert!(community.id.starts_with("community_"));
            assert!(community.summary.contains("connected by rel"));
            assert!(!community.representative_triples.is_empty());
        }
    }
}