//! Community detection on graphs: modularity scoring, Louvain, and label propagation.
//!
//! Source path: `proof_engine/graph/community.rs`

1use std::collections::{HashMap, HashSet};
2use super::graph_core::{Graph, GraphKind, NodeId};
3
4#[derive(Debug, Clone)]
5pub struct Community {
6    pub members: HashSet<NodeId>,
7}
8
9impl Community {
10    pub fn new() -> Self {
11        Self { members: HashSet::new() }
12    }
13
14    pub fn from_members(members: impl IntoIterator<Item = NodeId>) -> Self {
15        Self { members: members.into_iter().collect() }
16    }
17
18    pub fn contains(&self, id: NodeId) -> bool {
19        self.members.contains(&id)
20    }
21
22    pub fn len(&self) -> usize {
23        self.members.len()
24    }
25
26    pub fn is_empty(&self) -> bool {
27        self.members.is_empty()
28    }
29}
30
31#[derive(Debug, Clone)]
32pub struct CommunityResult {
33    pub communities: Vec<Community>,
34    pub modularity: f32,
35    pub iterations: usize,
36}
37
38/// Compute modularity Q for a given partitioning of the graph.
39/// Q = (1/2m) * sum_ij [ A_ij - k_i*k_j/(2m) ] * delta(c_i, c_j)
40pub fn modularity<N, E>(graph: &Graph<N, E>, communities: &[Community]) -> f32 {
41    let m = graph.edge_count() as f32;
42    if m == 0.0 { return 0.0; }
43
44    let m2 = if graph.kind == GraphKind::Undirected { 2.0 * m } else { m };
45
46    // Build community assignment map
47    let mut community_of: HashMap<NodeId, usize> = HashMap::new();
48    for (ci, comm) in communities.iter().enumerate() {
49        for &nid in &comm.members {
50            community_of.insert(nid, ci);
51        }
52    }
53
54    let mut q = 0.0f32;
55    let node_ids = graph.node_ids();
56
57    // Precompute degrees
58    let degrees: HashMap<NodeId, f32> = node_ids.iter()
59        .map(|&nid| (nid, graph.degree(nid) as f32))
60        .collect();
61
62    for edge in graph.edges() {
63        let ci = community_of.get(&edge.from).copied().unwrap_or(usize::MAX);
64        let cj = community_of.get(&edge.to).copied().unwrap_or(usize::MAX);
65        if ci == cj {
66            q += 1.0 - degrees[&edge.from] * degrees[&edge.to] / m2;
67            if graph.kind == GraphKind::Undirected {
68                q += 1.0 - degrees[&edge.to] * degrees[&edge.from] / m2;
69            }
70        }
71    }
72
73    q / m2
74}
75
76/// Louvain method for community detection.
77/// Iteratively moves nodes to communities that maximize modularity gain.
78pub fn louvain<N: Clone, E: Clone>(graph: &Graph<N, E>) -> CommunityResult {
79    let node_ids = graph.node_ids();
80    let n = node_ids.len();
81    if n == 0 {
82        return CommunityResult { communities: Vec::new(), modularity: 0.0, iterations: 0 };
83    }
84
85    let m = graph.edge_count() as f32;
86    if m == 0.0 {
87        let communities: Vec<Community> = node_ids.iter()
88            .map(|&nid| Community::from_members(std::iter::once(nid)))
89            .collect();
90        return CommunityResult { communities, modularity: 0.0, iterations: 0 };
91    }
92
93    let m2 = if graph.kind == GraphKind::Undirected { 2.0 * m } else { m };
94
95    // Each node starts in its own community
96    let mut comm_of: HashMap<NodeId, usize> = HashMap::new();
97    for (i, &nid) in node_ids.iter().enumerate() {
98        comm_of.insert(nid, i);
99    }
100    let mut num_communities = n;
101
102    // Precompute degrees and adjacency weights
103    let degrees: HashMap<NodeId, f32> = node_ids.iter()
104        .map(|&nid| (nid, graph.degree(nid) as f32))
105        .collect();
106
107    // Weighted adjacency
108    let mut adj_weights: HashMap<NodeId, Vec<(NodeId, f32)>> = HashMap::new();
109    for &nid in &node_ids {
110        let mut ws = Vec::new();
111        for (nbr, eid) in graph.neighbor_edges(nid) {
112            ws.push((nbr, graph.edge_weight(eid)));
113        }
114        adj_weights.insert(nid, ws);
115    }
116
117    // Sum of weights in each community
118    let mut sigma_tot: HashMap<usize, f32> = HashMap::new();
119    for &nid in &node_ids {
120        let c = comm_of[&nid];
121        *sigma_tot.entry(c).or_insert(0.0) += degrees[&nid];
122    }
123
124    let mut iterations = 0;
125    let max_iterations = 100;
126
127    loop {
128        iterations += 1;
129        let mut improved = false;
130
131        for &nid in &node_ids {
132            let current_comm = comm_of[&nid];
133            let ki = degrees[&nid];
134
135            // Compute weights to each neighboring community
136            let mut comm_weights: HashMap<usize, f32> = HashMap::new();
137            for &(nbr, w) in adj_weights.get(&nid).unwrap_or(&Vec::new()) {
138                let nc = comm_of[&nbr];
139                *comm_weights.entry(nc).or_insert(0.0) += w;
140            }
141
142            // Remove node from current community
143            *sigma_tot.get_mut(&current_comm).unwrap() -= ki;
144
145            // Find best community
146            let ki_in_current = comm_weights.get(&current_comm).copied().unwrap_or(0.0);
147            let mut best_comm = current_comm;
148            let mut best_gain = 0.0f32;
149
150            for (&c, &ki_in) in &comm_weights {
151                let st = sigma_tot.get(&c).copied().unwrap_or(0.0);
152                let gain = ki_in / m2 - st * ki / (m2 * m2);
153                let loss = ki_in_current / m2 - sigma_tot.get(&current_comm).copied().unwrap_or(0.0) * ki / (m2 * m2);
154                let delta_q = gain - loss;
155                if delta_q > best_gain {
156                    best_gain = delta_q;
157                    best_comm = c;
158                }
159            }
160
161            // Move node to best community
162            comm_of.insert(nid, best_comm);
163            *sigma_tot.get_mut(&best_comm).unwrap_or(&mut 0.0) += ki;
164            if !sigma_tot.contains_key(&best_comm) {
165                sigma_tot.insert(best_comm, ki);
166            }
167
168            if best_comm != current_comm {
169                improved = true;
170            }
171        }
172
173        if !improved || iterations >= max_iterations {
174            break;
175        }
176    }
177
178    // Build communities from assignments
179    let mut comm_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
180    for (&nid, &c) in &comm_of {
181        comm_map.entry(c).or_default().push(nid);
182    }
183
184    let communities: Vec<Community> = comm_map.into_values()
185        .map(|members| Community::from_members(members))
186        .collect();
187
188    let mod_val = modularity(graph, &communities);
189
190    CommunityResult {
191        communities,
192        modularity: mod_val,
193        iterations,
194    }
195}
196
197/// Label propagation community detection.
198/// Each node adopts the label most common among its neighbors.
199pub fn label_propagation<N, E>(graph: &Graph<N, E>) -> CommunityResult {
200    let node_ids = graph.node_ids();
201    let n = node_ids.len();
202    if n == 0 {
203        return CommunityResult { communities: Vec::new(), modularity: 0.0, iterations: 0 };
204    }
205
206    // Initialize each node with its own label
207    let mut labels: HashMap<NodeId, u32> = HashMap::new();
208    for (i, &nid) in node_ids.iter().enumerate() {
209        labels.insert(nid, i as u32);
210    }
211
212    let max_iterations = 100;
213    let mut iterations = 0;
214
215    // Simple deterministic ordering (could shuffle for randomness)
216    loop {
217        iterations += 1;
218        let mut changed = false;
219
220        for &nid in &node_ids {
221            let neighbors = graph.neighbors(nid);
222            if neighbors.is_empty() { continue; }
223
224            // Count label frequencies among neighbors
225            let mut freq: HashMap<u32, usize> = HashMap::new();
226            for nbr in &neighbors {
227                let lbl = labels[nbr];
228                *freq.entry(lbl).or_insert(0) += 1;
229            }
230
231            // Pick most frequent label (ties broken by smallest label)
232            let max_count = freq.values().copied().max().unwrap_or(0);
233            let best_label = freq.iter()
234                .filter(|(_, &c)| c == max_count)
235                .map(|(&l, _)| l)
236                .min()
237                .unwrap_or(labels[&nid]);
238
239            if labels[&nid] != best_label {
240                labels.insert(nid, best_label);
241                changed = true;
242            }
243        }
244
245        if !changed || iterations >= max_iterations {
246            break;
247        }
248    }
249
250    // Build communities from labels
251    let mut comm_map: HashMap<u32, Vec<NodeId>> = HashMap::new();
252    for (&nid, &lbl) in &labels {
253        comm_map.entry(lbl).or_default().push(nid);
254    }
255
256    let communities: Vec<Community> = comm_map.into_values()
257        .map(|members| Community::from_members(members))
258        .collect();
259
260    let mod_val = modularity(graph, &communities);
261
262    CommunityResult {
263        communities,
264        modularity: mod_val,
265        iterations,
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::graph_core::GraphKind;

    /// Builds two triangles (a-b-c and d-e-f) joined by the single bridge c-d.
    fn make_two_cliques() -> Graph<(), ()> {
        let mut g = Graph::new(GraphKind::Undirected);
        // Clique 1: 0,1,2
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(b, c, ());
        g.add_edge(a, c, ());
        // Clique 2: 3,4,5
        let d = g.add_node(());
        let e = g.add_node(());
        let f = g.add_node(());
        g.add_edge(d, e, ());
        g.add_edge(e, f, ());
        g.add_edge(d, f, ());
        // Bridge
        g.add_edge(c, d, ());
        g
    }

    /// Total number of node memberships across all communities of `result`.
    fn total_members(result: &CommunityResult) -> usize {
        result.communities.iter().map(|c| c.len()).sum()
    }

    #[test]
    fn test_modularity_single_community() {
        let mut g = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(b, c, ());
        g.add_edge(a, c, ());
        let comms = vec![Community::from_members(vec![a, b, c])];
        let q = modularity(&g, &comms);
        // All in one community: modularity should be 0
        assert!(q.abs() < 0.01, "expected ~0, got {}", q);
    }

    #[test]
    fn test_louvain_two_cliques() {
        let g = make_two_cliques();
        let result = louvain(&g);
        // Should find at least the 2 cliques, and every node must be placed.
        assert!(result.communities.len() >= 2);
        assert_eq!(total_members(&result), 6);
        assert!(result.modularity >= 0.0);
    }

    #[test]
    fn test_label_propagation_two_cliques() {
        let g = make_two_cliques();
        let result = label_propagation(&g);
        assert!(!result.communities.is_empty());
        assert_eq!(total_members(&result), 6);
    }

    #[test]
    fn test_louvain_empty() {
        let g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let result = louvain(&g);
        assert!(result.communities.is_empty());
    }

    #[test]
    fn test_label_propagation_disconnected() {
        // The edge type is never constrained by an `add_edge` call here, so
        // the graph type has to be spelled out.
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        g.add_node(());
        g.add_node(());
        g.add_node(());
        // No edges => each node is its own community
        let result = label_propagation(&g);
        assert_eq!(result.communities.len(), 3);
    }

    #[test]
    fn test_community_struct() {
        let c = Community::from_members(vec![NodeId(0), NodeId(1), NodeId(2)]);
        assert_eq!(c.len(), 3);
        assert!(c.contains(NodeId(1)));
        assert!(!c.contains(NodeId(5)));
        assert!(!c.is_empty());
    }

    #[test]
    fn test_louvain_single_node() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        g.add_node(());
        let result = louvain(&g);
        assert_eq!(result.communities.len(), 1);
    }

    #[test]
    fn test_modularity_two_perfect_communities() {
        let mut g = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        let d = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(c, d, ());
        let comms = vec![
            Community::from_members(vec![a, b]),
            Community::from_members(vec![c, d]),
        ];
        let q = modularity(&g, &comms);
        assert!(q > 0.0, "Modularity should be positive for good partition, got {}", q);
    }
}