Skip to main content

grafeo_adapters/plugins/algorithms/
community.rs

1//! Community detection algorithms: Louvain, Label Propagation.
2//!
3//! These algorithms identify clusters or communities of nodes that are
4//! more densely connected to each other than to the rest of the graph.
5
6use std::sync::OnceLock;
7
8use grafeo_common::types::{NodeId, Value};
9use grafeo_common::utils::error::Result;
10use grafeo_common::utils::hash::{FxHashMap, FxHashSet};
11use grafeo_core::graph::Direction;
12use grafeo_core::graph::lpg::LpgStore;
13
14use super::super::{AlgorithmResult, ParameterDef, ParameterType, Parameters};
15use super::traits::{ComponentResultBuilder, GraphAlgorithm};
16
17// ============================================================================
18// Label Propagation
19// ============================================================================
20
21/// Detects communities using the Label Propagation Algorithm.
22///
23/// Each node is initially assigned a unique label. Then, iteratively,
24/// each node adopts the most frequent label among its neighbors until
25/// the labels stabilize.
26///
27/// # Arguments
28///
29/// * `store` - The graph store
30/// * `max_iterations` - Maximum number of iterations (0 for unlimited)
31///
32/// # Returns
33///
34/// A map from node ID to community (label) ID.
35///
36/// # Complexity
37///
38/// O(iterations × E)
39pub fn label_propagation(store: &LpgStore, max_iterations: usize) -> FxHashMap<NodeId, u64> {
40    let nodes = store.node_ids();
41    let n = nodes.len();
42
43    if n == 0 {
44        return FxHashMap::default();
45    }
46
47    // Initialize labels: each node gets its own unique label
48    let mut labels: FxHashMap<NodeId, u64> = FxHashMap::default();
49    for (idx, &node) in nodes.iter().enumerate() {
50        labels.insert(node, idx as u64);
51    }
52
53    let max_iter = if max_iterations == 0 {
54        n * 10
55    } else {
56        max_iterations
57    };
58
59    for _ in 0..max_iter {
60        let mut changed = false;
61
62        // Update labels in random order (here we use insertion order)
63        for &node in &nodes {
64            // Get neighbor labels and their frequencies
65            let mut label_counts: FxHashMap<u64, usize> = FxHashMap::default();
66
67            // Consider both outgoing and incoming edges (undirected community detection)
68            // Outgoing edges: node -> neighbor
69            for (neighbor, _) in store.edges_from(node, Direction::Outgoing) {
70                if let Some(&label) = labels.get(&neighbor) {
71                    *label_counts.entry(label).or_insert(0) += 1;
72                }
73            }
74
75            // Incoming edges: neighbor -> node
76            // Uses backward adjacency index for O(degree) instead of O(V*E)
77            for (incoming_neighbor, _) in store.edges_from(node, Direction::Incoming) {
78                if let Some(&label) = labels.get(&incoming_neighbor) {
79                    *label_counts.entry(label).or_insert(0) += 1;
80                }
81            }
82
83            if label_counts.is_empty() {
84                continue;
85            }
86
87            // Find the most frequent label
88            let max_count = *label_counts.values().max().unwrap_or(&0);
89            let max_labels: Vec<u64> = label_counts
90                .into_iter()
91                .filter(|&(_, count)| count == max_count)
92                .map(|(label, _)| label)
93                .collect();
94
95            // Choose the smallest label in case of tie (deterministic)
96            let new_label = *max_labels.iter().min().unwrap();
97            let current_label = *labels.get(&node).unwrap();
98
99            if new_label != current_label {
100                labels.insert(node, new_label);
101                changed = true;
102            }
103        }
104
105        if !changed {
106            break;
107        }
108    }
109
110    // Normalize labels to be contiguous starting from 0
111    let unique_labels: FxHashSet<u64> = labels.values().copied().collect();
112    let mut label_map: FxHashMap<u64, u64> = FxHashMap::default();
113    for (idx, label) in unique_labels.into_iter().enumerate() {
114        label_map.insert(label, idx as u64);
115    }
116
117    labels
118        .into_iter()
119        .map(|(node, label)| (node, *label_map.get(&label).unwrap()))
120        .collect()
121}
122
123// ============================================================================
124// Louvain Algorithm
125// ============================================================================
126
127/// Result of Louvain algorithm.
128#[derive(Debug, Clone)]
129pub struct LouvainResult {
130    /// Community assignment for each node.
131    pub communities: FxHashMap<NodeId, u64>,
132    /// Final modularity score.
133    pub modularity: f64,
134    /// Number of communities detected.
135    pub num_communities: usize,
136}
137
138/// Detects communities using the Louvain algorithm.
139///
140/// The Louvain algorithm optimizes modularity through a greedy approach,
141/// consisting of two phases that are repeated iteratively:
142/// 1. Local optimization: Move nodes to neighboring communities if it increases modularity
143/// 2. Aggregation: Build a new graph where communities become super-nodes
144///
145/// # Arguments
146///
147/// * `store` - The graph store
148/// * `resolution` - Resolution parameter (higher = smaller communities, default 1.0)
149///
150/// # Returns
151///
152/// Community assignments and modularity score.
153///
154/// # Complexity
155///
156/// O(V log V) on average for sparse graphs
157pub fn louvain(store: &LpgStore, resolution: f64) -> LouvainResult {
158    let nodes = store.node_ids();
159    let n = nodes.len();
160
161    if n == 0 {
162        return LouvainResult {
163            communities: FxHashMap::default(),
164            modularity: 0.0,
165            num_communities: 0,
166        };
167    }
168
169    // Build node index mapping
170    let mut node_to_idx: FxHashMap<NodeId, usize> = FxHashMap::default();
171    for (idx, &node) in nodes.iter().enumerate() {
172        node_to_idx.insert(node, idx);
173    }
174
175    // Build adjacency with weights (for undirected graph)
176    // weights[i][j] = weight of edge between nodes i and j
177    let mut weights: Vec<FxHashMap<usize, f64>> = vec![FxHashMap::default(); n];
178    let mut total_weight = 0.0;
179
180    for &node in &nodes {
181        let i = *node_to_idx.get(&node).unwrap();
182        for (neighbor, _edge_id) in store.edges_from(node, Direction::Outgoing) {
183            if let Some(&j) = node_to_idx.get(&neighbor) {
184                // For undirected: add weight to both directions
185                let w = 1.0; // Could extract from edge property
186                *weights[i].entry(j).or_insert(0.0) += w;
187                *weights[j].entry(i).or_insert(0.0) += w;
188                total_weight += w;
189            }
190        }
191    }
192
193    // Handle isolated nodes
194    if total_weight == 0.0 {
195        let communities: FxHashMap<NodeId, u64> = nodes
196            .iter()
197            .enumerate()
198            .map(|(idx, &node)| (node, idx as u64))
199            .collect();
200        return LouvainResult {
201            communities,
202            modularity: 0.0,
203            num_communities: n,
204        };
205    }
206
207    // Compute node degrees (sum of incident edge weights)
208    let degrees: Vec<f64> = (0..n).map(|i| weights[i].values().sum()).collect();
209
210    // Initialize: each node in its own community
211    let mut community: Vec<usize> = (0..n).collect();
212
213    // Community internal weights and total weights
214    let mut community_internal: FxHashMap<usize, f64> = FxHashMap::default();
215    let mut community_total: FxHashMap<usize, f64> = FxHashMap::default();
216
217    for i in 0..n {
218        community_total.insert(i, degrees[i]);
219        community_internal.insert(i, weights[i].get(&i).copied().unwrap_or(0.0));
220    }
221
222    // Phase 1: Local optimization
223    let mut improved = true;
224    while improved {
225        improved = false;
226
227        for i in 0..n {
228            let current_comm = community[i];
229
230            // Compute links to each neighboring community
231            let mut comm_links: FxHashMap<usize, f64> = FxHashMap::default();
232            for (&j, &w) in &weights[i] {
233                let c = community[j];
234                *comm_links.entry(c).or_insert(0.0) += w;
235            }
236
237            // Try moving to each neighboring community
238            let mut best_delta = 0.0;
239            let mut best_comm = current_comm;
240
241            // Remove node from current community for delta calculation
242            let ki = degrees[i];
243            let ki_in = comm_links.get(&current_comm).copied().unwrap_or(0.0);
244
245            for (&target_comm, &k_i_to_comm) in &comm_links {
246                if target_comm == current_comm {
247                    continue;
248                }
249
250                let sigma_tot = *community_total.get(&target_comm).unwrap_or(&0.0);
251
252                // Modularity delta for moving to target_comm
253                let delta = resolution
254                    * (k_i_to_comm
255                        - ki_in
256                        - ki * (sigma_tot - community_total.get(&current_comm).unwrap_or(&0.0)
257                            + ki)
258                            / (2.0 * total_weight));
259
260                if delta > best_delta {
261                    best_delta = delta;
262                    best_comm = target_comm;
263                }
264            }
265
266            if best_comm != current_comm {
267                // Move node to best community
268                // Update community statistics
269                *community_total.entry(current_comm).or_insert(0.0) -= ki;
270                *community_internal.entry(current_comm).or_insert(0.0) -=
271                    2.0 * ki_in + weights[i].get(&i).copied().unwrap_or(0.0);
272
273                community[i] = best_comm;
274
275                *community_total.entry(best_comm).or_insert(0.0) += ki;
276                let k_i_best = comm_links.get(&best_comm).copied().unwrap_or(0.0);
277                *community_internal.entry(best_comm).or_insert(0.0) +=
278                    2.0 * k_i_best + weights[i].get(&i).copied().unwrap_or(0.0);
279
280                improved = true;
281            }
282        }
283    }
284
285    // Normalize community IDs
286    let unique_comms: FxHashSet<usize> = community.iter().copied().collect();
287    let mut comm_map: FxHashMap<usize, u64> = FxHashMap::default();
288    for (idx, c) in unique_comms.iter().enumerate() {
289        comm_map.insert(*c, idx as u64);
290    }
291
292    let communities: FxHashMap<NodeId, u64> = nodes
293        .iter()
294        .enumerate()
295        .map(|(i, &node)| (node, *comm_map.get(&community[i]).unwrap()))
296        .collect();
297
298    // Compute final modularity
299    let modularity = compute_modularity(&weights, &community, total_weight, resolution);
300
301    LouvainResult {
302        communities,
303        modularity,
304        num_communities: unique_comms.len(),
305    }
306}
307
/// Computes the (resolution-scaled) modularity of a community assignment.
///
/// Uses the per-community form
/// `Q = Σ_c [ Σ_in(c) / 2m − resolution · (Σ_tot(c) / 2m)² ]`, where
/// `Σ_in(c)` sums `weights[i][j]` over all ordered pairs inside community
/// `c` and `Σ_tot(c)` is the summed degree of its members. This includes the
/// null-model term for every pair of nodes in the same community, not only
/// for pairs that share an edge.
///
/// # Arguments
///
/// * `weights` - Adjacency rows; `weights[i][j]` is the weight between `i`
///   and `j`. Expected to be symmetric, with each undirected edge present in
///   both rows.
/// * `community` - Community id of each node, indexed like `weights`
/// * `total_weight` - Total edge weight `m` (half the sum of all degrees)
/// * `resolution` - Multiplier on the null-model term
///
/// # Returns
///
/// The modularity score, or `0.0` when `total_weight` is zero.
fn compute_modularity(
    weights: &[FxHashMap<usize, f64>],
    community: &[usize],
    total_weight: f64,
    resolution: f64,
) -> f64 {
    let m2 = 2.0 * total_weight;

    if m2 == 0.0 {
        return 0.0;
    }

    // Per community: (weight between members, summed degree of members).
    let mut stats: FxHashMap<usize, (f64, f64)> = FxHashMap::default();

    for (row, &comm) in weights.iter().zip(community) {
        let (internal, total) = stats.entry(comm).or_insert((0.0, 0.0));
        for (&j, &a_ij) in row {
            *total += a_ij;
            if community[j] == comm {
                *internal += a_ij;
            }
        }
    }

    stats
        .values()
        .map(|&(internal, total)| {
            let share = total / m2;
            internal / m2 - resolution * share * share
        })
        .sum::<f64>()
}
336
337/// Returns the number of communities detected.
338pub fn community_count(communities: &FxHashMap<NodeId, u64>) -> usize {
339    let unique: FxHashSet<u64> = communities.values().copied().collect();
340    unique.len()
341}
342
343// ============================================================================
344// Algorithm Wrappers for Plugin Registry
345// ============================================================================
346
347/// Static parameter definitions for Label Propagation algorithm.
348static LABEL_PROP_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
349
350fn label_prop_params() -> &'static [ParameterDef] {
351    LABEL_PROP_PARAMS.get_or_init(|| {
352        vec![ParameterDef {
353            name: "max_iterations".to_string(),
354            description: "Maximum iterations (0 for unlimited, default: 100)".to_string(),
355            param_type: ParameterType::Integer,
356            required: false,
357            default: Some("100".to_string()),
358        }]
359    })
360}
361
362/// Label Propagation algorithm wrapper.
363pub struct LabelPropagationAlgorithm;
364
365impl GraphAlgorithm for LabelPropagationAlgorithm {
366    fn name(&self) -> &str {
367        "label_propagation"
368    }
369
370    fn description(&self) -> &str {
371        "Label Propagation community detection"
372    }
373
374    fn parameters(&self) -> &[ParameterDef] {
375        label_prop_params()
376    }
377
378    fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
379        let max_iter = params.get_int("max_iterations").unwrap_or(100) as usize;
380
381        let communities = label_propagation(store, max_iter);
382
383        let mut builder = ComponentResultBuilder::with_capacity(communities.len());
384        for (node, community_id) in communities {
385            builder.push(node, community_id);
386        }
387
388        Ok(builder.build())
389    }
390}
391
392/// Static parameter definitions for Louvain algorithm.
393static LOUVAIN_PARAMS: OnceLock<Vec<ParameterDef>> = OnceLock::new();
394
395fn louvain_params() -> &'static [ParameterDef] {
396    LOUVAIN_PARAMS.get_or_init(|| {
397        vec![ParameterDef {
398            name: "resolution".to_string(),
399            description: "Resolution parameter (default: 1.0)".to_string(),
400            param_type: ParameterType::Float,
401            required: false,
402            default: Some("1.0".to_string()),
403        }]
404    })
405}
406
407/// Louvain algorithm wrapper.
408pub struct LouvainAlgorithm;
409
410impl GraphAlgorithm for LouvainAlgorithm {
411    fn name(&self) -> &str {
412        "louvain"
413    }
414
415    fn description(&self) -> &str {
416        "Louvain community detection (modularity optimization)"
417    }
418
419    fn parameters(&self) -> &[ParameterDef] {
420        louvain_params()
421    }
422
423    fn execute(&self, store: &LpgStore, params: &Parameters) -> Result<AlgorithmResult> {
424        let resolution = params.get_float("resolution").unwrap_or(1.0);
425
426        let result = louvain(store, resolution);
427
428        let mut output = AlgorithmResult::new(vec![
429            "node_id".to_string(),
430            "community_id".to_string(),
431            "modularity".to_string(),
432        ]);
433
434        for (node, community_id) in result.communities {
435            output.add_row(vec![
436                Value::Int64(node.0 as i64),
437                Value::Int64(community_id as i64),
438                Value::Float64(result.modularity),
439            ]);
440        }
441
442        Ok(output)
443    }
444}
445
446// ============================================================================
447// Tests
448// ============================================================================
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds a pair of opposite directed edges between `a` and `b`.
    fn link_both_ways(store: &LpgStore, a: NodeId, b: NodeId) {
        store.create_edge(a, b, "EDGE");
        store.create_edge(b, a, "EDGE");
    }

    /// Builds two fully connected groups of four nodes (0-3 and 4-7) that
    /// are joined by a single bridge between nodes 3 and 4. Every connection
    /// is stored as two opposite directed edges.
    fn create_two_cliques_graph() -> LpgStore {
        let store = LpgStore::new();

        let mut nodes: Vec<NodeId> = Vec::with_capacity(8);
        for _ in 0..8 {
            nodes.push(store.create_node(&["Node"]));
        }

        let (first, second) = nodes.split_at(4);
        for clique in [first, second] {
            for (pos, &a) in clique.iter().enumerate() {
                for &b in &clique[pos + 1..] {
                    link_both_ways(&store, a, b);
                }
            }
        }

        // Bridge between the two groups.
        link_both_ways(&store, nodes[3], nodes[4]);

        store
    }

    /// Builds the directed chain 0 -> 1 -> 2.
    fn create_simple_graph() -> LpgStore {
        let store = LpgStore::new();

        let chain = [
            store.create_node(&["Node"]),
            store.create_node(&["Node"]),
            store.create_node(&["Node"]),
        ];
        for hop in chain.windows(2) {
            store.create_edge(hop[0], hop[1], "EDGE");
        }

        store
    }

    #[test]
    fn test_label_propagation_basic() {
        let graph = create_simple_graph();
        let assignment = label_propagation(&graph, 100);

        assert_eq!(assignment.len(), 3);

        // Community IDs are renumbered, so each must be below the node count.
        assert!(assignment.values().all(|&comm| comm < 3));
    }

    #[test]
    fn test_label_propagation_cliques() {
        let graph = create_two_cliques_graph();
        let assignment = label_propagation(&graph, 100);

        assert_eq!(assignment.len(), 8);

        // Two communities would be ideal; only the valid range is enforced.
        let found = community_count(&assignment);
        assert!((1..=8).contains(&found));
    }

    #[test]
    fn test_label_propagation_empty() {
        let graph = LpgStore::new();
        assert!(label_propagation(&graph, 100).is_empty());
    }

    #[test]
    fn test_label_propagation_single_node() {
        let graph = LpgStore::new();
        graph.create_node(&["Node"]);

        let assignment = label_propagation(&graph, 100);
        assert_eq!(assignment.len(), 1);
    }

    #[test]
    fn test_louvain_basic() {
        let graph = create_simple_graph();
        let outcome = louvain(&graph, 1.0);

        assert_eq!(outcome.communities.len(), 3);
        assert!(outcome.num_communities >= 1);
    }

    #[test]
    fn test_louvain_cliques() {
        let graph = create_two_cliques_graph();
        let outcome = louvain(&graph, 1.0);

        assert_eq!(outcome.communities.len(), 8);

        // Roughly two communities are expected; only the valid range is
        // enforced.
        assert!((1..=8).contains(&outcome.num_communities));
    }

    #[test]
    fn test_louvain_empty() {
        let graph = LpgStore::new();
        let outcome = louvain(&graph, 1.0);

        assert!(outcome.communities.is_empty());
        assert_eq!(outcome.modularity, 0.0);
        assert_eq!(outcome.num_communities, 0);
    }

    #[test]
    fn test_louvain_isolated_nodes() {
        let graph = LpgStore::new();
        for _ in 0..3 {
            graph.create_node(&["Node"]);
        }

        let outcome = louvain(&graph, 1.0);

        // Without edges every node stays in a community of its own.
        assert_eq!(outcome.communities.len(), 3);
        assert_eq!(outcome.num_communities, 3);
    }

    #[test]
    fn test_louvain_resolution_parameter() {
        let graph = create_two_cliques_graph();

        // Lower resolution favours fewer, larger communities; higher
        // resolution favours more, smaller ones.
        let coarse = louvain(&graph, 0.5);
        let fine = louvain(&graph, 2.0);

        // Both runs must produce an assignment.
        assert!(!coarse.communities.is_empty());
        assert!(!fine.communities.is_empty());
    }

    #[test]
    fn test_community_count() {
        let assignment: FxHashMap<NodeId, u64> = [(0, 0), (1, 0), (2, 1), (3, 1), (4, 2)]
            .iter()
            .map(|&(id, comm)| (NodeId::new(id), comm))
            .collect();

        assert_eq!(community_count(&assignment), 3);
    }
}