oxirs_graphrag/graph/
community.rs

1//! Community detection for hierarchical summarization
2
3use crate::{CommunitySummary, GraphRAGResult, Triple};
4use petgraph::graph::{NodeIndex, UnGraph};
5use std::collections::{HashMap, HashSet};
6
/// Strategy used to partition the entity graph into communities.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CommunityAlgorithm {
    /// Modularity-driven Louvain clustering; this is the default variant.
    #[default]
    Louvain,
    /// Iterative propagation of labels between neighbouring nodes.
    LabelPropagation,
    /// One community per connected component of the graph.
    ConnectedComponents,
}
18
19/// Community detector configuration
20#[derive(Debug, Clone)]
21pub struct CommunityConfig {
22    /// Algorithm to use
23    pub algorithm: CommunityAlgorithm,
24    /// Resolution parameter for Louvain
25    pub resolution: f64,
26    /// Minimum community size
27    pub min_community_size: usize,
28    /// Maximum number of communities
29    pub max_communities: usize,
30    /// Number of iterations for iterative algorithms
31    pub max_iterations: usize,
32}
33
34impl Default for CommunityConfig {
35    fn default() -> Self {
36        Self {
37            algorithm: CommunityAlgorithm::Louvain,
38            resolution: 1.0,
39            min_community_size: 2,
40            max_communities: 50,
41            max_iterations: 100,
42        }
43    }
44}
45
46/// Community detector
47pub struct CommunityDetector {
48    config: CommunityConfig,
49}
50
51impl Default for CommunityDetector {
52    fn default() -> Self {
53        Self::new(CommunityConfig::default())
54    }
55}
56
57impl CommunityDetector {
58    pub fn new(config: CommunityConfig) -> Self {
59        Self { config }
60    }
61
62    /// Detect communities in the given subgraph
63    pub fn detect(&self, triples: &[Triple]) -> GraphRAGResult<Vec<CommunitySummary>> {
64        if triples.is_empty() {
65            return Ok(vec![]);
66        }
67
68        // Build graph
69        let (graph, node_map) = self.build_graph(triples);
70
71        // Detect communities based on algorithm
72        let communities = match self.config.algorithm {
73            CommunityAlgorithm::Louvain => self.louvain(&graph, &node_map),
74            CommunityAlgorithm::LabelPropagation => self.label_propagation(&graph, &node_map),
75            CommunityAlgorithm::ConnectedComponents => self.connected_components(&graph, &node_map),
76        };
77
78        // Filter and create summaries
79        let summaries = self.create_summaries(communities, triples);
80
81        Ok(summaries)
82    }
83
84    /// Build undirected graph from triples
85    fn build_graph(&self, triples: &[Triple]) -> (UnGraph<String, ()>, HashMap<String, NodeIndex>) {
86        let mut graph: UnGraph<String, ()> = UnGraph::new_undirected();
87        let mut node_map: HashMap<String, NodeIndex> = HashMap::new();
88
89        for triple in triples {
90            let subj_idx = *node_map
91                .entry(triple.subject.clone())
92                .or_insert_with(|| graph.add_node(triple.subject.clone()));
93            let obj_idx = *node_map
94                .entry(triple.object.clone())
95                .or_insert_with(|| graph.add_node(triple.object.clone()));
96
97            if subj_idx != obj_idx && graph.find_edge(subj_idx, obj_idx).is_none() {
98                graph.add_edge(subj_idx, obj_idx, ());
99            }
100        }
101
102        (graph, node_map)
103    }
104
105    /// Simplified Louvain algorithm
106    fn louvain(
107        &self,
108        graph: &UnGraph<String, ()>,
109        node_map: &HashMap<String, NodeIndex>,
110    ) -> Vec<HashSet<String>> {
111        let node_count = graph.node_count();
112        if node_count == 0 {
113            return vec![];
114        }
115
116        // Initialize: each node in its own community
117        let mut community: HashMap<NodeIndex, usize> = HashMap::new();
118        for (community_id, &idx) in node_map.values().enumerate() {
119            community.insert(idx, community_id);
120        }
121
122        // Total edges (for modularity calculation)
123        let m = graph.edge_count() as f64;
124        if m == 0.0 {
125            // No edges, each node is its own community
126            return node_map
127                .keys()
128                .map(|k| {
129                    let mut set = HashSet::new();
130                    set.insert(k.clone());
131                    set
132                })
133                .collect();
134        }
135
136        // Degree of each node
137        let degree: HashMap<NodeIndex, f64> = node_map
138            .values()
139            .map(|&idx| (idx, graph.neighbors(idx).count() as f64))
140            .collect();
141
142        // Iterate
143        for _ in 0..self.config.max_iterations {
144            let mut changed = false;
145
146            for (&node, &current_comm) in community.clone().iter() {
147                let node_degree = degree.get(&node).copied().unwrap_or(0.0);
148
149                // Calculate modularity gain for each neighbor's community
150                let mut best_comm = current_comm;
151                let mut best_gain = 0.0;
152
153                let neighbor_comms: HashSet<usize> = graph
154                    .neighbors(node)
155                    .filter_map(|n| community.get(&n).copied())
156                    .collect();
157
158                for &neighbor_comm in &neighbor_comms {
159                    if neighbor_comm == current_comm {
160                        continue;
161                    }
162
163                    // Simplified modularity gain calculation
164                    let edges_to_comm: f64 = graph
165                        .neighbors(node)
166                        .filter(|n| community.get(n) == Some(&neighbor_comm))
167                        .count() as f64;
168
169                    let comm_degree: f64 = community
170                        .iter()
171                        .filter(|(_, &c)| c == neighbor_comm)
172                        .map(|(n, _)| degree.get(n).copied().unwrap_or(0.0))
173                        .sum();
174
175                    let gain = edges_to_comm / m
176                        - self.config.resolution * node_degree * comm_degree / (2.0 * m * m);
177
178                    if gain > best_gain {
179                        best_gain = gain;
180                        best_comm = neighbor_comm;
181                    }
182                }
183
184                if best_comm != current_comm && best_gain > 0.0 {
185                    community.insert(node, best_comm);
186                    changed = true;
187                }
188            }
189
190            if !changed {
191                break;
192            }
193        }
194
195        // Group nodes by community
196        self.group_by_community(graph, &community)
197    }
198
199    /// Label propagation algorithm
200    fn label_propagation(
201        &self,
202        graph: &UnGraph<String, ()>,
203        node_map: &HashMap<String, NodeIndex>,
204    ) -> Vec<HashSet<String>> {
205        if graph.node_count() == 0 {
206            return vec![];
207        }
208
209        // Initialize labels
210        let mut labels: HashMap<NodeIndex, usize> = HashMap::new();
211        for (i, &idx) in node_map.values().enumerate() {
212            labels.insert(idx, i);
213        }
214
215        // Iterate
216        for _ in 0..self.config.max_iterations {
217            let mut changed = false;
218
219            for &node in node_map.values() {
220                // Count neighbor labels
221                let mut label_counts: HashMap<usize, usize> = HashMap::new();
222                for neighbor in graph.neighbors(node) {
223                    if let Some(&label) = labels.get(&neighbor) {
224                        *label_counts.entry(label).or_insert(0) += 1;
225                    }
226                }
227
228                // Assign most common label
229                if let Some((&best_label, _)) = label_counts.iter().max_by_key(|(_, &count)| count)
230                {
231                    if labels.get(&node) != Some(&best_label) {
232                        labels.insert(node, best_label);
233                        changed = true;
234                    }
235                }
236            }
237
238            if !changed {
239                break;
240            }
241        }
242
243        self.group_by_community(graph, &labels)
244    }
245
246    /// Connected components
247    fn connected_components(
248        &self,
249        graph: &UnGraph<String, ()>,
250        _node_map: &HashMap<String, NodeIndex>,
251    ) -> Vec<HashSet<String>> {
252        let sccs = petgraph::algo::kosaraju_scc(graph);
253
254        sccs.into_iter()
255            .map(|component| {
256                component
257                    .into_iter()
258                    .filter_map(|idx| graph.node_weight(idx).cloned())
259                    .collect()
260            })
261            .collect()
262    }
263
264    /// Group nodes by community assignment
265    fn group_by_community(
266        &self,
267        graph: &UnGraph<String, ()>,
268        assignment: &HashMap<NodeIndex, usize>,
269    ) -> Vec<HashSet<String>> {
270        let mut communities: HashMap<usize, HashSet<String>> = HashMap::new();
271
272        for (&node, &comm) in assignment {
273            if let Some(label) = graph.node_weight(node) {
274                communities.entry(comm).or_default().insert(label.clone());
275            }
276        }
277
278        communities.into_values().collect()
279    }
280
281    /// Create community summaries
282    fn create_summaries(
283        &self,
284        communities: Vec<HashSet<String>>,
285        triples: &[Triple],
286    ) -> Vec<CommunitySummary> {
287        communities
288            .into_iter()
289            .enumerate()
290            .filter(|(_, entities)| entities.len() >= self.config.min_community_size)
291            .take(self.config.max_communities)
292            .map(|(idx, entities)| {
293                // Find representative triples
294                let representative_triples: Vec<Triple> = triples
295                    .iter()
296                    .filter(|t| entities.contains(&t.subject) || entities.contains(&t.object))
297                    .take(5)
298                    .cloned()
299                    .collect();
300
301                // Calculate modularity (simplified)
302                let internal_edges = triples
303                    .iter()
304                    .filter(|t| entities.contains(&t.subject) && entities.contains(&t.object))
305                    .count() as f64;
306                let total_edges = triples.len().max(1) as f64;
307                let modularity = internal_edges / total_edges;
308
309                // Generate summary
310                let entity_list: Vec<String> = entities.iter().cloned().collect();
311                let summary = self.generate_summary(&entity_list, &representative_triples);
312
313                CommunitySummary {
314                    id: format!("community_{}", idx),
315                    summary,
316                    entities: entity_list,
317                    representative_triples,
318                    level: 0,
319                    modularity,
320                }
321            })
322            .collect()
323    }
324
325    /// Generate a text summary for a community
326    fn generate_summary(&self, entities: &[String], triples: &[Triple]) -> String {
327        // Extract short names from URIs
328        let short_names: Vec<String> = entities
329            .iter()
330            .take(3)
331            .map(|uri| {
332                uri.rsplit('/')
333                    .next()
334                    .or_else(|| uri.rsplit('#').next())
335                    .unwrap_or(uri)
336                    .to_string()
337            })
338            .collect();
339
340        // Extract predicates
341        let predicates: HashSet<String> = triples
342            .iter()
343            .map(|t| {
344                t.predicate
345                    .rsplit('/')
346                    .next()
347                    .or_else(|| t.predicate.rsplit('#').next())
348                    .unwrap_or(&t.predicate)
349                    .to_string()
350            })
351            .collect();
352
353        let pred_str: Vec<String> = predicates.into_iter().take(3).collect();
354
355        format!(
356            "Community of {} entities including {} connected by {}",
357            entities.len(),
358            short_names.join(", "),
359            pred_str.join(", ")
360        )
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Two disjoint triangles: a-b-c and x-y-z.
    fn two_triangles() -> Vec<Triple> {
        vec![
            Triple::new("http://a", "http://rel", "http://b"),
            Triple::new("http://b", "http://rel", "http://c"),
            Triple::new("http://a", "http://rel", "http://c"),
            Triple::new("http://x", "http://rel", "http://y"),
            Triple::new("http://y", "http://rel", "http://z"),
            Triple::new("http://x", "http://rel", "http://z"),
        ]
    }

    #[test]
    fn test_community_detection() {
        let detector = CommunityDetector::default();

        let communities = detector.detect(&two_triangles()).unwrap();

        // One community per triangle; the triangles share no edge, so they
        // can never be merged.
        assert_eq!(communities.len(), 2);
        for community in &communities {
            assert!(community.entities.len() >= 2);
        }
    }

    #[test]
    fn test_connected_components() {
        let detector = CommunityDetector::new(CommunityConfig {
            algorithm: CommunityAlgorithm::ConnectedComponents,
            ..CommunityConfig::default()
        });

        let communities = detector.detect(&two_triangles()).unwrap();

        // Sort members and groups so the check ignores output ordering.
        let mut groups: Vec<Vec<String>> = communities
            .iter()
            .map(|community| {
                let mut members = community.entities.clone();
                members.sort();
                members
            })
            .collect();
        groups.sort();

        let expected: Vec<Vec<String>> = vec![
            vec!["http://a".into(), "http://b".into(), "http://c".into()],
            vec!["http://x".into(), "http://y".into(), "http://z".into()],
        ];
        assert_eq!(groups, expected);
    }

    #[test]
    fn test_empty_graph() {
        let detector = CommunityDetector::default();
        let communities = detector.detect(&[]).unwrap();
        assert!(communities.is_empty());
    }
}
393}