1use std::collections::{HashMap, HashSet};
2use super::graph_core::{Graph, GraphKind, NodeId};
3
4#[derive(Debug, Clone)]
5pub struct Community {
6 pub members: HashSet<NodeId>,
7}
8
9impl Community {
10 pub fn new() -> Self {
11 Self { members: HashSet::new() }
12 }
13
14 pub fn from_members(members: impl IntoIterator<Item = NodeId>) -> Self {
15 Self { members: members.into_iter().collect() }
16 }
17
18 pub fn contains(&self, id: NodeId) -> bool {
19 self.members.contains(&id)
20 }
21
22 pub fn len(&self) -> usize {
23 self.members.len()
24 }
25
26 pub fn is_empty(&self) -> bool {
27 self.members.is_empty()
28 }
29}
30
/// Outcome of a community-detection run (`louvain` or `label_propagation`).
#[derive(Debug, Clone)]
pub struct CommunityResult {
    /// The detected communities.
    pub communities: Vec<Community>,
    /// Score of `communities` as returned by `modularity`; the algorithms
    /// report `0.0` directly when they return early (no nodes, or no edges
    /// in `louvain`).
    pub modularity: f32,
    /// Number of full passes over the node list that were executed; `0` when
    /// the algorithm returned early without iterating.
    pub iterations: usize,
}
37
38pub fn modularity<N, E>(graph: &Graph<N, E>, communities: &[Community]) -> f32 {
41 let m = graph.edge_count() as f32;
42 if m == 0.0 { return 0.0; }
43
44 let m2 = if graph.kind == GraphKind::Undirected { 2.0 * m } else { m };
45
46 let mut community_of: HashMap<NodeId, usize> = HashMap::new();
48 for (ci, comm) in communities.iter().enumerate() {
49 for &nid in &comm.members {
50 community_of.insert(nid, ci);
51 }
52 }
53
54 let mut q = 0.0f32;
55 let node_ids = graph.node_ids();
56
57 let degrees: HashMap<NodeId, f32> = node_ids.iter()
59 .map(|&nid| (nid, graph.degree(nid) as f32))
60 .collect();
61
62 for edge in graph.edges() {
63 let ci = community_of.get(&edge.from).copied().unwrap_or(usize::MAX);
64 let cj = community_of.get(&edge.to).copied().unwrap_or(usize::MAX);
65 if ci == cj {
66 q += 1.0 - degrees[&edge.from] * degrees[&edge.to] / m2;
67 if graph.kind == GraphKind::Undirected {
68 q += 1.0 - degrees[&edge.to] * degrees[&edge.from] / m2;
69 }
70 }
71 }
72
73 q / m2
74}
75
76pub fn louvain<N: Clone, E: Clone>(graph: &Graph<N, E>) -> CommunityResult {
79 let node_ids = graph.node_ids();
80 let n = node_ids.len();
81 if n == 0 {
82 return CommunityResult { communities: Vec::new(), modularity: 0.0, iterations: 0 };
83 }
84
85 let m = graph.edge_count() as f32;
86 if m == 0.0 {
87 let communities: Vec<Community> = node_ids.iter()
88 .map(|&nid| Community::from_members(std::iter::once(nid)))
89 .collect();
90 return CommunityResult { communities, modularity: 0.0, iterations: 0 };
91 }
92
93 let m2 = if graph.kind == GraphKind::Undirected { 2.0 * m } else { m };
94
95 let mut comm_of: HashMap<NodeId, usize> = HashMap::new();
97 for (i, &nid) in node_ids.iter().enumerate() {
98 comm_of.insert(nid, i);
99 }
100 let mut num_communities = n;
101
102 let degrees: HashMap<NodeId, f32> = node_ids.iter()
104 .map(|&nid| (nid, graph.degree(nid) as f32))
105 .collect();
106
107 let mut adj_weights: HashMap<NodeId, Vec<(NodeId, f32)>> = HashMap::new();
109 for &nid in &node_ids {
110 let mut ws = Vec::new();
111 for (nbr, eid) in graph.neighbor_edges(nid) {
112 ws.push((nbr, graph.edge_weight(eid)));
113 }
114 adj_weights.insert(nid, ws);
115 }
116
117 let mut sigma_tot: HashMap<usize, f32> = HashMap::new();
119 for &nid in &node_ids {
120 let c = comm_of[&nid];
121 *sigma_tot.entry(c).or_insert(0.0) += degrees[&nid];
122 }
123
124 let mut iterations = 0;
125 let max_iterations = 100;
126
127 loop {
128 iterations += 1;
129 let mut improved = false;
130
131 for &nid in &node_ids {
132 let current_comm = comm_of[&nid];
133 let ki = degrees[&nid];
134
135 let mut comm_weights: HashMap<usize, f32> = HashMap::new();
137 for &(nbr, w) in adj_weights.get(&nid).unwrap_or(&Vec::new()) {
138 let nc = comm_of[&nbr];
139 *comm_weights.entry(nc).or_insert(0.0) += w;
140 }
141
142 *sigma_tot.get_mut(¤t_comm).unwrap() -= ki;
144
145 let ki_in_current = comm_weights.get(¤t_comm).copied().unwrap_or(0.0);
147 let mut best_comm = current_comm;
148 let mut best_gain = 0.0f32;
149
150 for (&c, &ki_in) in &comm_weights {
151 let st = sigma_tot.get(&c).copied().unwrap_or(0.0);
152 let gain = ki_in / m2 - st * ki / (m2 * m2);
153 let loss = ki_in_current / m2 - sigma_tot.get(¤t_comm).copied().unwrap_or(0.0) * ki / (m2 * m2);
154 let delta_q = gain - loss;
155 if delta_q > best_gain {
156 best_gain = delta_q;
157 best_comm = c;
158 }
159 }
160
161 comm_of.insert(nid, best_comm);
163 *sigma_tot.get_mut(&best_comm).unwrap_or(&mut 0.0) += ki;
164 if !sigma_tot.contains_key(&best_comm) {
165 sigma_tot.insert(best_comm, ki);
166 }
167
168 if best_comm != current_comm {
169 improved = true;
170 }
171 }
172
173 if !improved || iterations >= max_iterations {
174 break;
175 }
176 }
177
178 let mut comm_map: HashMap<usize, Vec<NodeId>> = HashMap::new();
180 for (&nid, &c) in &comm_of {
181 comm_map.entry(c).or_default().push(nid);
182 }
183
184 let communities: Vec<Community> = comm_map.into_values()
185 .map(|members| Community::from_members(members))
186 .collect();
187
188 let mod_val = modularity(graph, &communities);
189
190 CommunityResult {
191 communities,
192 modularity: mod_val,
193 iterations,
194 }
195}
196
197pub fn label_propagation<N, E>(graph: &Graph<N, E>) -> CommunityResult {
200 let node_ids = graph.node_ids();
201 let n = node_ids.len();
202 if n == 0 {
203 return CommunityResult { communities: Vec::new(), modularity: 0.0, iterations: 0 };
204 }
205
206 let mut labels: HashMap<NodeId, u32> = HashMap::new();
208 for (i, &nid) in node_ids.iter().enumerate() {
209 labels.insert(nid, i as u32);
210 }
211
212 let max_iterations = 100;
213 let mut iterations = 0;
214
215 loop {
217 iterations += 1;
218 let mut changed = false;
219
220 for &nid in &node_ids {
221 let neighbors = graph.neighbors(nid);
222 if neighbors.is_empty() { continue; }
223
224 let mut freq: HashMap<u32, usize> = HashMap::new();
226 for nbr in &neighbors {
227 let lbl = labels[nbr];
228 *freq.entry(lbl).or_insert(0) += 1;
229 }
230
231 let max_count = freq.values().copied().max().unwrap_or(0);
233 let best_label = freq.iter()
234 .filter(|(_, &c)| c == max_count)
235 .map(|(&l, _)| l)
236 .min()
237 .unwrap_or(labels[&nid]);
238
239 if labels[&nid] != best_label {
240 labels.insert(nid, best_label);
241 changed = true;
242 }
243 }
244
245 if !changed || iterations >= max_iterations {
246 break;
247 }
248 }
249
250 let mut comm_map: HashMap<u32, Vec<NodeId>> = HashMap::new();
252 for (&nid, &lbl) in &labels {
253 comm_map.entry(lbl).or_default().push(nid);
254 }
255
256 let communities: Vec<Community> = comm_map.into_values()
257 .map(|members| Community::from_members(members))
258 .collect();
259
260 let mod_val = modularity(graph, &communities);
261
262 CommunityResult {
263 communities,
264 modularity: mod_val,
265 iterations,
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::graph_core::GraphKind;

    /// Builds two triangles (a-b-c and d-e-f) joined by the single bridge c-d.
    fn make_two_cliques() -> Graph<(), ()> {
        let mut g = Graph::new(GraphKind::Undirected);

        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(b, c, ());
        g.add_edge(a, c, ());

        let d = g.add_node(());
        let e = g.add_node(());
        let f = g.add_node(());
        g.add_edge(d, e, ());
        g.add_edge(e, f, ());
        g.add_edge(d, f, ());

        // Bridge between the two triangles.
        g.add_edge(c, d, ());
        g
    }

    /// Sum of community sizes; equals the node count when the result is a partition.
    fn total_members(result: &CommunityResult) -> usize {
        result.communities.iter().map(|c| c.len()).sum()
    }

    #[test]
    fn test_modularity_single_community() {
        let mut g = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(b, c, ());
        g.add_edge(a, c, ());
        let comms = vec![Community::from_members(vec![a, b, c])];
        let q = modularity(&g, &comms);
        // Putting the whole graph into one community carries no structure.
        assert!((q - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_louvain_two_cliques() {
        let g = make_two_cliques();
        let result = louvain(&g);
        assert!(result.communities.len() >= 2);
        assert!(result.modularity >= 0.0);
        // Every one of the six nodes must be assigned exactly once.
        assert!(result.communities.iter().all(|c| !c.is_empty()));
        assert_eq!(total_members(&result), 6);
    }

    #[test]
    fn test_label_propagation_two_cliques() {
        let g = make_two_cliques();
        let result = label_propagation(&g);
        assert!(!result.communities.is_empty());
        assert_eq!(total_members(&result), 6);
    }

    #[test]
    fn test_louvain_empty() {
        let g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        let result = louvain(&g);
        assert!(result.communities.is_empty());
    }

    #[test]
    fn test_label_propagation_disconnected() {
        // No edge is ever added, so the edge type must be spelled out for inference.
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        g.add_node(());
        g.add_node(());
        g.add_node(());
        let result = label_propagation(&g);
        // Isolated nodes never adopt another label.
        assert_eq!(result.communities.len(), 3);
    }

    #[test]
    fn test_community_struct() {
        let c = Community::from_members(vec![NodeId(0), NodeId(1), NodeId(2)]);
        assert_eq!(c.len(), 3);
        assert!(c.contains(NodeId(1)));
        assert!(!c.contains(NodeId(5)));
        assert!(!c.is_empty());
    }

    #[test]
    fn test_louvain_single_node() {
        let mut g: Graph<(), ()> = Graph::new(GraphKind::Undirected);
        g.add_node(());
        let result = louvain(&g);
        assert_eq!(result.communities.len(), 1);
    }

    #[test]
    fn test_modularity_two_perfect_communities() {
        let mut g = Graph::new(GraphKind::Undirected);
        let a = g.add_node(());
        let b = g.add_node(());
        let c = g.add_node(());
        let d = g.add_node(());
        g.add_edge(a, b, ());
        g.add_edge(c, d, ());
        let comms = vec![
            Community::from_members(vec![a, b]),
            Community::from_members(vec![c, d]),
        ];
        let q = modularity(&g, &comms);
        assert!(q > 0.0, "Modularity should be positive for good partition, got {}", q);
    }
}