grafeo_adapters/plugins/algorithms/
community.rs1use std::sync::OnceLock;
7
8use grafeo_common::types::{NodeId, Value};
9use grafeo_common::utils::error::Result;
10use grafeo_common::utils::hash::{FxHashMap, FxHashSet};
11use grafeo_core::graph::Direction;
12use grafeo_core::graph::lpg::LpgStore;
13
14use super::super::{AlgorithmResult, ParameterDef, ParameterType, Parameters};
15use super::traits::{ComponentResultBuilder, GraphAlgorithm};
16
17pub fn label_propagation(store: &LpgStore, max_iterations: usize) -> FxHashMap<NodeId, u64> {
40 let nodes = store.node_ids();
41 let n = nodes.len();
42
43 if n == 0 {
44 return FxHashMap::default();
45 }
46
47 let mut labels: FxHashMap<NodeId, u64> = FxHashMap::default();
49 for (idx, &node) in nodes.iter().enumerate() {
50 labels.insert(node, idx as u64);
51 }
52
53 let max_iter = if max_iterations == 0 {
54 n * 10
55 } else {
56 max_iterations
57 };
58
59 for _ in 0..max_iter {
60 let mut changed = false;
61
62 for &node in &nodes {
64 let mut label_counts: FxHashMap<u64, usize> = FxHashMap::default();
66
67 for (neighbor, _) in store.edges_from(node, Direction::Outgoing) {
70 if let Some(&label) = labels.get(&neighbor) {
71 *label_counts.entry(label).or_insert(0) += 1;
72 }
73 }
74
75 for (incoming_neighbor, _) in store.edges_from(node, Direction::Incoming) {
78 if let Some(&label) = labels.get(&incoming_neighbor) {
79 *label_counts.entry(label).or_insert(0) += 1;
80 }
81 }
82
83 if label_counts.is_empty() {
84 continue;
85 }
86
87 let max_count = *label_counts.values().max().unwrap_or(&0);
89 let max_labels: Vec<u64> = label_counts
90 .into_iter()
91 .filter(|&(_, count)| count == max_count)
92 .map(|(label, _)| label)
93 .collect();
94
95 let new_label = *max_labels.iter().min().unwrap();
97 let current_label = *labels.get(&node).unwrap();
98
99 if new_label != current_label {
100 labels.insert(node, new_label);
101 changed = true;
102 }
103 }
104
105 if !changed {
106 break;
107 }
108 }
109
110 let unique_labels: FxHashSet<u64> = labels.values().copied().collect();
112 let mut label_map: FxHashMap<u64, u64> = FxHashMap::default();
113 for (idx, label) in unique_labels.into_iter().enumerate() {
114 label_map.insert(label, idx as u64);
115 }
116
117 labels
118 .into_iter()
119 .map(|(node, label)| (node, *label_map.get(&label).unwrap()))
120 .collect()
121}
122
123#[derive(Debug, Clone)]
129pub struct LouvainResult {
130 pub communities: FxHashMap<NodeId, u64>,
132 pub modularity: f64,
134 pub num_communities: usize,
136}
137
138pub fn louvain(store: &LpgStore, resolution: f64) -> LouvainResult {
158 let nodes = store.node_ids();
159 let n = nodes.len();
160
161 if n == 0 {
162 return LouvainResult {
163 communities: FxHashMap::default(),
164 modularity: 0.0,
165 num_communities: 0,
166 };
167 }
168
169 let mut node_to_idx: FxHashMap<NodeId, usize> = FxHashMap::default();
171 for (idx, &node) in nodes.iter().enumerate() {
172 node_to_idx.insert(node, idx);
173 }
174
175 let mut weights: Vec<FxHashMap<usize, f64>> = vec![FxHashMap::default(); n];
178 let mut total_weight = 0.0;
179
180 for &node in &nodes {
181 let i = *node_to_idx.get(&node).unwrap();
182 for (neighbor, _edge_id) in store.edges_from(node, Direction::Outgoing) {
183 if let Some(&j) = node_to_idx.get(&neighbor) {
184 let w = 1.0; *weights[i].entry(j).or_insert(0.0) += w;
187 *weights[j].entry(i).or_insert(0.0) += w;
188 total_weight += w;
189 }
190 }
191 }
192
193 if total_weight == 0.0 {
195 let communities: FxHashMap<NodeId, u64> = nodes
196 .iter()
197 .enumerate()
198 .map(|(idx, &node)| (node, idx as u64))
199 .collect();
200 return LouvainResult {
201 communities,
202 modularity: 0.0,
203 num_communities: n,
204 };
205 }
206
207 let degrees: Vec<f64> = (0..n).map(|i| weights[i].values().sum()).collect();
209
210 let mut community: Vec<usize> = (0..n).collect();
212
213 let mut community_internal: FxHashMap<usize, f64> = FxHashMap::default();
215 let mut community_total: FxHashMap<usize, f64> = FxHashMap::default();
216
217 for i in 0..n {
218 community_total.insert(i, degrees[i]);
219 community_internal.insert(i, weights[i].get(&i).copied().unwrap_or(0.0));
220 }
221
222 let mut improved = true;
224 while improved {
225 improved = false;
226
227 for i in 0..n {
228 let current_comm = community[i];
229
230 let mut comm_links: FxHashMap<usize, f64> = FxHashMap::default();
232 for (&j, &w) in &weights[i] {
233 let c = community[j];
234 *comm_links.entry(c).or_insert(0.0) += w;
235 }
236
237 let mut best_delta = 0.0;
239 let mut best_comm = current_comm;
240
241 let ki = degrees[i];
243 let ki_in = comm_links.get(¤t_comm).copied().unwrap_or(0.0);
244
245 for (&target_comm, &k_i_to_comm) in &comm_links {
246 if target_comm == current_comm {
247 continue;
248 }
249
250 let sigma_tot = *community_total.get(&target_comm).unwrap_or(&0.0);
251
252 let delta = resolution
254 * (k_i_to_comm
255 - ki_in
256 - ki * (sigma_tot - community_total.get(¤t_comm).unwrap_or(&0.0)
257 + ki)
258 / (2.0 * total_weight));
259
260 if delta > best_delta {
261 best_delta = delta;
262 best_comm = target_comm;
263 }
264 }
265
266 if best_comm != current_comm {
267 *community_total.entry(current_comm).or_insert(0.0) -= ki;
270 *community_internal.entry(current_comm).or_insert(0.0) -=
271 2.0 * ki_in + weights[i].get(&i).copied().unwrap_or(0.0);
272
273 community[i] = best_comm;
274
275 *community_total.entry(best_comm).or_insert(0.0) += ki;
276 let k_i_best = comm_links.get(&best_comm).copied().unwrap_or(0.0);
277 *community_internal.entry(best_comm).or_insert(0.0) +=
278 2.0 * k_i_best + weights[i].get(&i).copied().unwrap_or(0.0);
279
280 improved = true;
281 }
282 }
283 }
284
285 let unique_comms: FxHashSet<usize> = community.iter().copied().collect();
287 let mut comm_map: FxHashMap<usize, u64> = FxHashMap::default();
288 for (idx, c) in unique_comms.iter().enumerate() {
289 comm_map.insert(*c, idx as u64);
290 }
291
292 let communities: FxHashMap<NodeId, u64> = nodes
293 .iter()
294 .enumerate()
295 .map(|(i, &node)| (node, *comm_map.get(&community[i]).unwrap()))
296 .collect();
297
298 let modularity = compute_modularity(&weights, &community, total_weight, resolution);
300
301 LouvainResult {
302 communities,
303 modularity,
304 num_communities: unique_comms.len(),
305 }
306}
307
/// Computes the generalised modularity of a partition.
///
/// `Q = Σ_c [ L_c / 2m − γ · (Σ_c / 2m)² ]`, where:
/// - `L_c` is the sum of `A_ij` over ordered pairs with both endpoints in community `c`,
/// - `Σ_c` is the sum of the degrees of `c`'s members,
/// - `γ` is `resolution`.
///
/// The null-model penalty applies to *every* pair of nodes in a community, not only to
/// adjacent ones.
///
/// # Parameters
/// - `weights`: symmetric adjacency rows; `weights[i]` maps neighbour index `j` to `A_ij`.
///   Any map whose shared reference iterates as `(&usize, &f64)` works.
/// - `community`: community id of each node. Ids are used as vector indices, so they must be
///   small. `louvain` satisfies this, since every id it produces is a node index
///   `< community.len()`.
/// - `total_weight`: `m`, i.e. half the sum of all `A_ij`.
/// - `resolution`: `γ`.
///
/// # Returns
/// The modularity, or `0.0` when `total_weight` is zero.
///
/// # Panics
/// Panics if `weights` is shorter than `community`, or if a row references a neighbour index
/// `>= community.len()`.
fn compute_modularity<W>(
    weights: &[W],
    community: &[usize],
    total_weight: f64,
    resolution: f64,
) -> f64
where
    for<'a> &'a W: IntoIterator<Item = (&'a usize, &'a f64)>,
{
    let m2 = 2.0 * total_weight;

    if m2 == 0.0 {
        return 0.0;
    }

    let num_ids = community.iter().max().map_or(0, |&c| c + 1);
    // internal[c] = L_c, total[c] = Σ_c.
    let mut internal = vec![0.0_f64; num_ids];
    let mut total = vec![0.0_f64; num_ids];

    for (i, &ci) in community.iter().enumerate() {
        for (&j, &a_ij) in &weights[i] {
            total[ci] += a_ij;
            if community[j] == ci {
                internal[ci] += a_ij;
            }
        }
    }

    internal
        .iter()
        .zip(&total)
        .map(|(&l_c, &sigma_c)| l_c / m2 - resolution * (sigma_c / m2).powi(2))
        .sum()
}
336
337pub fn community_count(communities: &FxHashMap<NodeId, u64>) -> usize {
339 let unique: FxHashSet<u64> = communities.values().copied().collect();
340 unique.len()
341}
342
343static LABEL_PROP_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
349
350fn label_prop_params() -> &'static [ParameterDef] {
351 LABEL_PROP_PARAMS.get_or_init(|| {
352 vec![ParameterDef {
353 name: "max_iterations".to_string(),
354 description: "Maximum iterations (0 for unlimited, default: 100)".to_string(),
355 param_type: ParameterType::Integer,
356 required: false,
357 default: Some("100".to_string()),
358 }]
359 })
360}
361
362pub struct LabelPropagationAlgorithm;
364
365impl GraphAlgorithm for LabelPropagationAlgorithm {
366 fn name(&self) -> &str {
367 "label_propagation"
368 }
369
370 fn description(&self) -> &str {
371 "Label Propagation community detection"
372 }
373
374 fn parameters(&self) -> &[ParameterDef] {
375 label_prop_params()
376 }
377
378 fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
379 let max_iter = params.get_int("max_iterations").unwrap_or(100) as usize;
380
381 let communities = label_propagation(store, max_iter);
382
383 let mut builder = ComponentResultBuilder::with_capacity(communities.len());
384 for (node, community_id) in communities {
385 builder.push(node, community_id);
386 }
387
388 Ok(builder.build())
389 }
390}
391
392static LOUVAIN_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
394
395fn louvain_params() -> &'static [ParameterDef] {
396 LOUVAIN_PARAMS.get_or_init(|| {
397 vec![ParameterDef {
398 name: "resolution".to_string(),
399 description: "Resolution parameter (default: 1.0)".to_string(),
400 param_type: ParameterType::Float,
401 required: false,
402 default: Some("1.0".to_string()),
403 }]
404 })
405}
406
407pub struct LouvainAlgorithm;
409
410impl GraphAlgorithm for LouvainAlgorithm {
411 fn name(&self) -> &str {
412 "louvain"
413 }
414
415 fn description(&self) -> &str {
416 "Louvain community detection (modularity optimization)"
417 }
418
419 fn parameters(&self) -> &[ParameterDef] {
420 louvain_params()
421 }
422
423 fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
424 let resolution = params.get_float("resolution").unwrap_or(1.0);
425
426 let result = louvain(store, resolution);
427
428 let mut output = AlgorithmResult::new(vec![
429 "node_id".to_string(),
430 "community_id".to_string(),
431 "modularity".to_string(),
432 ]);
433
434 for (node, community_id) in result.communities {
435 output.add_row(vec![
436 Value::Int64(node.0 as i64),
437 Value::Int64(community_id as i64),
438 Value::Float64(result.modularity),
439 ]);
440 }
441
442 Ok(output)
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds two 4-node cliques (nodes 0-3 and 4-7), each pair linked in both directions,
    /// joined by a single bidirectional bridge between nodes 3 and 4.
    fn create_two_cliques_graph() -> LpgStore {
        let store = LpgStore::new();

        let nodes: Vec<NodeId> = (0..8).map(|_| store.create_node(&["Node"])).collect();

        for i in 0..4 {
            for j in (i + 1)..4 {
                store.create_edge(nodes[i], nodes[j], "EDGE");
                store.create_edge(nodes[j], nodes[i], "EDGE");
            }
        }

        for i in 4..8 {
            for j in (i + 1)..8 {
                store.create_edge(nodes[i], nodes[j], "EDGE");
                store.create_edge(nodes[j], nodes[i], "EDGE");
            }
        }

        // Bridge between the two cliques.
        store.create_edge(nodes[3], nodes[4], "EDGE");
        store.create_edge(nodes[4], nodes[3], "EDGE");

        store
    }

    /// Builds the directed path `n0 -> n1 -> n2`.
    fn create_simple_graph() -> LpgStore {
        let store = LpgStore::new();

        let n0 = store.create_node(&["Node"]);
        let n1 = store.create_node(&["Node"]);
        let n2 = store.create_node(&["Node"]);

        store.create_edge(n0, n1, "EDGE");
        store.create_edge(n1, n2, "EDGE");

        store
    }

    #[test]
    fn test_label_propagation_basic() {
        let store = create_simple_graph();
        let communities = label_propagation(&store, 100);

        assert_eq!(communities.len(), 3);
        assert!(communities.values().all(|&comm| comm < 3));
        // A 3-node path collapses into one community whatever the node visit order.
        assert_eq!(community_count(&communities), 1);
    }

    #[test]
    fn test_label_propagation_cliques() {
        let store = create_two_cliques_graph();
        let communities = label_propagation(&store, 100);

        assert_eq!(communities.len(), 8);

        let num_comms = community_count(&communities);
        assert!((1..=8).contains(&num_comms));
    }

    #[test]
    fn test_label_propagation_empty() {
        let store = LpgStore::new();
        let communities = label_propagation(&store, 100);
        assert!(communities.is_empty());
    }

    #[test]
    fn test_label_propagation_single_node() {
        let store = LpgStore::new();
        store.create_node(&["Node"]);

        let communities = label_propagation(&store, 100);
        assert_eq!(communities.len(), 1);
        assert_eq!(community_count(&communities), 1);
    }

    #[test]
    fn test_louvain_basic() {
        let store = create_simple_graph();
        let result = louvain(&store, 1.0);

        assert_eq!(result.communities.len(), 3);
        assert!(result.num_communities >= 1);
        assert_eq!(result.num_communities, community_count(&result.communities));
    }

    #[test]
    fn test_louvain_cliques() {
        let store = create_two_cliques_graph();
        let result = louvain(&store, 1.0);

        assert_eq!(result.communities.len(), 8);
        assert!((1..=8).contains(&result.num_communities));
        assert_eq!(result.num_communities, community_count(&result.communities));
        assert!(result.modularity.is_finite());
        assert!(result.modularity <= 1.0);
    }

    #[test]
    fn test_louvain_empty() {
        let store = LpgStore::new();
        let result = louvain(&store, 1.0);

        assert!(result.communities.is_empty());
        assert_eq!(result.modularity, 0.0);
        assert_eq!(result.num_communities, 0);
    }

    #[test]
    fn test_louvain_isolated_nodes() {
        let store = LpgStore::new();
        store.create_node(&["Node"]);
        store.create_node(&["Node"]);
        store.create_node(&["Node"]);

        let result = louvain(&store, 1.0);

        // Without edges every node is its own community and modularity is zero.
        assert_eq!(result.communities.len(), 3);
        assert_eq!(result.num_communities, 3);
        assert_eq!(result.modularity, 0.0);
    }

    #[test]
    fn test_louvain_resolution_parameter() {
        let store = create_two_cliques_graph();

        let result_low = louvain(&store, 0.5);
        let result_high = louvain(&store, 2.0);

        assert_eq!(result_low.communities.len(), 8);
        assert_eq!(result_high.communities.len(), 8);
    }

    #[test]
    fn test_community_count() {
        let mut communities: FxHashMap<NodeId, u64> = FxHashMap::default();
        assert_eq!(community_count(&communities), 0);

        communities.insert(NodeId::new(0), 0);
        communities.insert(NodeId::new(1), 0);
        communities.insert(NodeId::new(2), 1);
        communities.insert(NodeId::new(3), 1);
        communities.insert(NodeId::new(4), 2);

        assert_eq!(community_count(&communities), 3);
    }
}