// manifoldb_graph/analytics/community.rs

//! Community Detection using Label Propagation Algorithm.
//!
//! This module implements community detection using the Label Propagation
//! Algorithm (LPA). LPA is a fast, near-linear time algorithm for detecting
//! communities in large networks.
//!
//! # Algorithm
//!
//! Label Propagation works by:
//! 1. Initialize each node with a unique community label
//! 2. For each iteration, visit the nodes in a freshly shuffled order and:
//!    - Update each node's label to the most frequent label among its neighbors
//!    - If several labels are tied for most frequent, keep the node's current
//!      label when it is one of them; otherwise pick one of the tied labels
//!      pseudo-randomly
//! 3. Repeat until no labels change (convergence) or max iterations reached
//!
//! # Properties
//!
//! - Near-linear time complexity: O(m) per iteration, where m is the number of edges
//! - No prior knowledge of the number of communities required
//! - Non-deterministic in general (results may vary with the random seed and
//!   between runs)
//! - Works well for networks with clear community structure
//!
//! # Example
//!
//! ```ignore
//! use manifoldb_graph::analytics::{CommunityDetection, CommunityDetectionConfig};
//!
//! let config = CommunityDetectionConfig::default()
//!     .with_max_iterations(100);
//!
//! let result = CommunityDetection::label_propagation(&tx, &config)?;
//!
//! // Get community assignments
//! for (node, community) in result.assignments.iter() {
//!     println!("Node {:?} belongs to community {}", node, community);
//! }
//!
//! // Get community sizes
//! for (community_id, size) in result.community_sizes() {
//!     println!("Community {} has {} members", community_id, size);
//! }
//! ```
43
44use std::collections::HashMap;
45
46use manifoldb_core::EntityId;
47use manifoldb_storage::Transaction;
48
49use crate::index::AdjacencyIndex;
50use crate::store::{EdgeStore, GraphError, GraphResult, NodeStore};
51use crate::traversal::Direction;
52
53use super::pagerank::DEFAULT_MAX_GRAPH_NODES;
54
55/// Configuration for Community Detection.
56#[derive(Debug, Clone)]
57pub struct CommunityDetectionConfig {
58    /// Maximum number of iterations.
59    /// Default: 100
60    pub max_iterations: usize,
61
62    /// Direction of edges to follow.
63    /// Default: Both (treat as undirected)
64    pub direction: Direction,
65
66    /// Seed for random number generation (for reproducibility).
67    /// Default: None (use system entropy)
68    pub seed: Option<u64>,
69
70    /// Minimum improvement in modularity to continue iterating.
71    /// Set to 0.0 to only check for label stability.
72    /// Default: 0.0
73    pub min_improvement: f64,
74
75    /// Maximum number of nodes allowed before returning an error.
76    /// Set to `None` to disable the check.
77    /// Default: 10,000,000 (10M nodes)
78    pub max_graph_nodes: Option<usize>,
79}
80
81impl Default for CommunityDetectionConfig {
82    fn default() -> Self {
83        Self {
84            max_iterations: 100,
85            direction: Direction::Both,
86            seed: None,
87            min_improvement: 0.0,
88            max_graph_nodes: Some(DEFAULT_MAX_GRAPH_NODES),
89        }
90    }
91}
92
93impl CommunityDetectionConfig {
94    /// Create a new configuration with default values.
95    pub fn new() -> Self {
96        Self::default()
97    }
98
99    /// Set the maximum number of iterations.
100    pub const fn with_max_iterations(mut self, max_iterations: usize) -> Self {
101        self.max_iterations = max_iterations;
102        self
103    }
104
105    /// Set the direction to follow edges.
106    pub const fn with_direction(mut self, direction: Direction) -> Self {
107        self.direction = direction;
108        self
109    }
110
111    /// Set the seed for reproducible results.
112    pub const fn with_seed(mut self, seed: u64) -> Self {
113        self.seed = Some(seed);
114        self
115    }
116
117    /// Set the minimum improvement threshold.
118    pub const fn with_min_improvement(mut self, min_improvement: f64) -> Self {
119        self.min_improvement = min_improvement;
120        self
121    }
122
123    /// Set the maximum number of nodes allowed.
124    ///
125    /// If the graph has more nodes than this limit, the algorithm will
126    /// return a [`GraphError::GraphTooLarge`] error instead of attempting
127    /// to allocate potentially gigabytes of memory.
128    ///
129    /// Set to `None` to disable the check (use with caution).
130    ///
131    /// [`GraphError::GraphTooLarge`]: crate::store::GraphError::GraphTooLarge
132    pub const fn with_max_graph_nodes(mut self, limit: Option<usize>) -> Self {
133        self.max_graph_nodes = limit;
134        self
135    }
136}
137
138/// Result of community detection.
139#[derive(Debug, Clone)]
140pub struct CommunityResult {
141    /// Community assignments: node -> community ID
142    pub assignments: HashMap<EntityId, u64>,
143
144    /// Number of iterations performed.
145    pub iterations: usize,
146
147    /// Whether the algorithm converged.
148    pub converged: bool,
149
150    /// Number of distinct communities found.
151    pub num_communities: usize,
152}
153
154impl CommunityResult {
155    /// Get the community ID for a specific node.
156    pub fn community(&self, node: EntityId) -> Option<u64> {
157        self.assignments.get(&node).copied()
158    }
159
160    /// Get all nodes in a specific community.
161    pub fn members(&self, community_id: u64) -> Vec<EntityId> {
162        self.assignments.iter().filter(|(_, &c)| c == community_id).map(|(&node, _)| node).collect()
163    }
164
165    /// Get community sizes.
166    pub fn community_sizes(&self) -> HashMap<u64, usize> {
167        let mut sizes: HashMap<u64, usize> = HashMap::new();
168        for &community in self.assignments.values() {
169            *sizes.entry(community).or_insert(0) += 1;
170        }
171        sizes
172    }
173
174    /// Get communities sorted by size (descending).
175    pub fn communities_by_size(&self) -> Vec<(u64, usize)> {
176        let mut sizes: Vec<_> = self.community_sizes().into_iter().collect();
177        sizes.sort_by(|a, b| b.1.cmp(&a.1));
178        sizes
179    }
180
181    /// Get the largest community.
182    pub fn largest_community(&self) -> Option<(u64, usize)> {
183        self.communities_by_size().into_iter().next()
184    }
185
186    /// Get the smallest community.
187    pub fn smallest_community(&self) -> Option<(u64, usize)> {
188        self.communities_by_size().into_iter().last()
189    }
190
191    /// Check if two nodes are in the same community.
192    pub fn same_community(&self, node1: EntityId, node2: EntityId) -> bool {
193        match (self.community(node1), self.community(node2)) {
194            (Some(c1), Some(c2)) => c1 == c2,
195            _ => false,
196        }
197    }
198}
199
200/// Community Detection algorithm implementations.
201pub struct CommunityDetection;
202
203impl CommunityDetection {
204    /// Detect communities using Label Propagation Algorithm.
205    ///
206    /// This is a fast, near-linear time algorithm that doesn't require
207    /// knowing the number of communities in advance.
208    ///
209    /// # Arguments
210    ///
211    /// * `tx` - The transaction to use for graph access
212    /// * `config` - Configuration parameters for the algorithm
213    ///
214    /// # Returns
215    ///
216    /// A `CommunityResult` containing community assignments for all nodes.
217    pub fn label_propagation<T: Transaction>(
218        tx: &T,
219        config: &CommunityDetectionConfig,
220    ) -> GraphResult<CommunityResult> {
221        // Check graph size before allocating large data structures
222        if let Some(limit) = config.max_graph_nodes {
223            let node_count = NodeStore::count(tx)?;
224            if node_count > limit {
225                return Err(GraphError::GraphTooLarge { node_count, limit });
226            }
227        }
228
229        // Collect all nodes
230        let mut nodes: Vec<EntityId> = Vec::new();
231        NodeStore::for_each(tx, |entity| {
232            nodes.push(entity.id);
233            true
234        })?;
235
236        let n = nodes.len();
237        if n == 0 {
238            return Ok(CommunityResult {
239                assignments: HashMap::new(),
240                iterations: 0,
241                converged: true,
242                num_communities: 0,
243            });
244        }
245
246        // Build node index
247        let node_index: HashMap<EntityId, usize> =
248            nodes.iter().enumerate().map(|(i, &id)| (id, i)).collect();
249
250        // Build adjacency lists
251        let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n];
252        for (i, &node) in nodes.iter().enumerate() {
253            let neighbor_ids = Self::get_neighbors(tx, node, config.direction)?;
254            for neighbor_id in neighbor_ids {
255                if let Some(&j) = node_index.get(&neighbor_id) {
256                    neighbors[i].push(j);
257                }
258            }
259        }
260
261        // Initialize labels (each node starts with its own label)
262        let mut labels: Vec<u64> = (0..n as u64).collect();
263
264        // Simple pseudo-random number generator
265        let mut rng_state = config.seed.unwrap_or(12345);
266        let simple_random = |state: &mut u64| -> u64 {
267            // Linear congruential generator (constants from PCG)
268            *state = state
269                .wrapping_mul(6_364_136_223_846_793_005)
270                .wrapping_add(1_442_695_040_888_963_407);
271            *state
272        };
273
274        let mut iterations = 0;
275        let mut converged = false;
276
277        // Create a shuffled order for processing nodes
278        let mut order: Vec<usize> = (0..n).collect();
279
280        while iterations < config.max_iterations {
281            iterations += 1;
282
283            // Shuffle node processing order for better convergence
284            for i in (1..n).rev() {
285                let j = (simple_random(&mut rng_state) as usize) % (i + 1);
286                order.swap(i, j);
287            }
288
289            let mut changed = false;
290
291            for &i in &order {
292                if neighbors[i].is_empty() {
293                    continue;
294                }
295
296                // Count neighbor labels
297                let mut label_counts: HashMap<u64, usize> = HashMap::new();
298                for &j in &neighbors[i] {
299                    *label_counts.entry(labels[j]).or_insert(0) += 1;
300                }
301
302                // Find maximum count
303                let max_count = *label_counts.values().max().unwrap_or(&0);
304
305                // Get all labels with maximum count
306                let max_labels: Vec<u64> = label_counts
307                    .iter()
308                    .filter(|(_, &count)| count == max_count)
309                    .map(|(&label, _)| label)
310                    .collect();
311
312                // Select one of the max labels (preferring current label if tied)
313                let new_label = if max_labels.contains(&labels[i]) {
314                    labels[i]
315                } else if max_labels.len() == 1 {
316                    max_labels[0]
317                } else {
318                    // Random selection among tied labels
319                    let idx = (simple_random(&mut rng_state) as usize) % max_labels.len();
320                    max_labels[idx]
321                };
322
323                if new_label != labels[i] {
324                    labels[i] = new_label;
325                    changed = true;
326                }
327            }
328
329            if !changed {
330                converged = true;
331                break;
332            }
333        }
334
335        // Renumber communities to be contiguous starting from 0
336        let mut community_map: HashMap<u64, u64> = HashMap::new();
337        let mut next_id = 0u64;
338        for label in &mut labels {
339            let new_id = *community_map.entry(*label).or_insert_with(|| {
340                let id = next_id;
341                next_id += 1;
342                id
343            });
344            *label = new_id;
345        }
346
347        // Build result
348        let assignments: HashMap<EntityId, u64> = nodes.into_iter().zip(labels).collect();
349
350        let num_communities = community_map.len();
351
352        Ok(CommunityResult { assignments, iterations, converged, num_communities })
353    }
354
355    /// Detect communities for a subset of nodes.
356    ///
357    /// Only considers edges within the specified subgraph.
358    ///
359    /// # Arguments
360    ///
361    /// * `tx` - The transaction to use for graph access
362    /// * `nodes` - The nodes to include in the computation
363    /// * `config` - Configuration parameters for the algorithm
364    pub fn label_propagation_for_nodes<T: Transaction>(
365        tx: &T,
366        nodes: &[EntityId],
367        config: &CommunityDetectionConfig,
368    ) -> GraphResult<CommunityResult> {
369        let n = nodes.len();
370        if n == 0 {
371            return Ok(CommunityResult {
372                assignments: HashMap::new(),
373                iterations: 0,
374                converged: true,
375                num_communities: 0,
376            });
377        }
378
379        // Build node index and set
380        let node_set: std::collections::HashSet<EntityId> = nodes.iter().copied().collect();
381        let node_index: HashMap<EntityId, usize> =
382            nodes.iter().enumerate().map(|(i, &id)| (id, i)).collect();
383
384        // Build adjacency lists (only including edges within the subgraph)
385        let mut neighbors: Vec<Vec<usize>> = vec![Vec::new(); n];
386        for (i, &node) in nodes.iter().enumerate() {
387            let neighbor_ids = Self::get_neighbors(tx, node, config.direction)?;
388            for neighbor_id in neighbor_ids {
389                if node_set.contains(&neighbor_id) {
390                    if let Some(&j) = node_index.get(&neighbor_id) {
391                        neighbors[i].push(j);
392                    }
393                }
394            }
395        }
396
397        // Initialize labels
398        let mut labels: Vec<u64> = (0..n as u64).collect();
399
400        // Simple pseudo-random number generator
401        let mut rng_state = config.seed.unwrap_or(12345);
402        let simple_random = |state: &mut u64| -> u64 {
403            *state = state
404                .wrapping_mul(6_364_136_223_846_793_005)
405                .wrapping_add(1_442_695_040_888_963_407);
406            *state
407        };
408
409        let mut iterations = 0;
410        let mut converged = false;
411        let mut order: Vec<usize> = (0..n).collect();
412
413        while iterations < config.max_iterations {
414            iterations += 1;
415
416            // Shuffle order
417            for i in (1..n).rev() {
418                let j = (simple_random(&mut rng_state) as usize) % (i + 1);
419                order.swap(i, j);
420            }
421
422            let mut changed = false;
423
424            for &i in &order {
425                if neighbors[i].is_empty() {
426                    continue;
427                }
428
429                let mut label_counts: HashMap<u64, usize> = HashMap::new();
430                for &j in &neighbors[i] {
431                    *label_counts.entry(labels[j]).or_insert(0) += 1;
432                }
433
434                let max_count = *label_counts.values().max().unwrap_or(&0);
435                let max_labels: Vec<u64> = label_counts
436                    .iter()
437                    .filter(|(_, &count)| count == max_count)
438                    .map(|(&label, _)| label)
439                    .collect();
440
441                let new_label = if max_labels.contains(&labels[i]) {
442                    labels[i]
443                } else if max_labels.len() == 1 {
444                    max_labels[0]
445                } else {
446                    let idx = (simple_random(&mut rng_state) as usize) % max_labels.len();
447                    max_labels[idx]
448                };
449
450                if new_label != labels[i] {
451                    labels[i] = new_label;
452                    changed = true;
453                }
454            }
455
456            if !changed {
457                converged = true;
458                break;
459            }
460        }
461
462        // Renumber communities
463        let mut community_map: HashMap<u64, u64> = HashMap::new();
464        let mut next_id = 0u64;
465        for label in &mut labels {
466            let new_id = *community_map.entry(*label).or_insert_with(|| {
467                let id = next_id;
468                next_id += 1;
469                id
470            });
471            *label = new_id;
472        }
473
474        let assignments: HashMap<EntityId, u64> = nodes.iter().copied().zip(labels).collect();
475
476        let num_communities = community_map.len();
477
478        Ok(CommunityResult { assignments, iterations, converged, num_communities })
479    }
480
481    /// Get neighbors of a node based on direction.
482    fn get_neighbors<T: Transaction>(
483        tx: &T,
484        node: EntityId,
485        direction: Direction,
486    ) -> GraphResult<Vec<EntityId>> {
487        let mut neighbors = Vec::new();
488
489        if direction.includes_outgoing() {
490            let edges = AdjacencyIndex::get_outgoing_edge_ids(tx, node)?;
491            for edge_id in edges {
492                if let Some(edge) = EdgeStore::get(tx, edge_id)? {
493                    neighbors.push(edge.target);
494                }
495            }
496        }
497
498        if direction.includes_incoming() {
499            let edges = AdjacencyIndex::get_incoming_edge_ids(tx, node)?;
500            for edge_id in edges {
501                if let Some(edge) = EdgeStore::get(tx, edge_id)? {
502                    neighbors.push(edge.source);
503                }
504            }
505        }
506
507        Ok(neighbors)
508    }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Five nodes: 1 and 2 in community 0; 3, 4 and 5 in community 1.
    fn sample_result() -> CommunityResult {
        let mut assignments = HashMap::new();
        assignments.insert(EntityId::new(1), 0);
        assignments.insert(EntityId::new(2), 0);
        assignments.insert(EntityId::new(3), 1);
        assignments.insert(EntityId::new(4), 1);
        assignments.insert(EntityId::new(5), 1);

        CommunityResult { assignments, iterations: 10, converged: true, num_communities: 2 }
    }

    /// Member IDs of a community as sorted raw integers, so the assertion
    /// does not depend on hash-map iteration order.
    fn sorted_member_ids(result: &CommunityResult, community_id: u64) -> Vec<u64> {
        let mut ids: Vec<u64> = result.members(community_id).iter().map(|e| e.as_u64()).collect();
        ids.sort_unstable();
        ids
    }

    #[test]
    fn config_defaults() {
        let config = CommunityDetectionConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.direction, Direction::Both);
        assert!(config.seed.is_none());
        assert!((config.min_improvement - 0.0).abs() < f64::EPSILON);
        assert_eq!(config.max_graph_nodes, Some(DEFAULT_MAX_GRAPH_NODES));
    }

    #[test]
    fn config_builder() {
        let config = CommunityDetectionConfig::new()
            .with_max_iterations(50)
            .with_direction(Direction::Outgoing)
            .with_seed(42)
            .with_min_improvement(0.001);

        assert_eq!(config.max_iterations, 50);
        assert_eq!(config.direction, Direction::Outgoing);
        assert_eq!(config.seed, Some(42));
        assert!((config.min_improvement - 0.001).abs() < f64::EPSILON);
    }

    #[test]
    fn config_builder_max_graph_nodes() {
        let unlimited = CommunityDetectionConfig::new().with_max_graph_nodes(None);
        assert_eq!(unlimited.max_graph_nodes, None);

        let limited = CommunityDetectionConfig::new().with_max_graph_nodes(Some(10));
        assert_eq!(limited.max_graph_nodes, Some(10));
    }

    #[test]
    fn result_empty() {
        let result = CommunityResult {
            assignments: HashMap::new(),
            iterations: 0,
            converged: true,
            num_communities: 0,
        };

        assert!(result.community(EntityId::new(1)).is_none());
        assert!(result.members(0).is_empty());
        assert!(result.community_sizes().is_empty());
        assert!(result.communities_by_size().is_empty());
        assert!(result.largest_community().is_none());
        assert!(result.smallest_community().is_none());
        assert!(!result.same_community(EntityId::new(1), EntityId::new(1)));
    }

    #[test]
    fn result_community_operations() {
        let result = sample_result();

        // Test community lookup
        assert_eq!(result.community(EntityId::new(1)), Some(0));
        assert_eq!(result.community(EntityId::new(3)), Some(1));
        assert_eq!(result.community(EntityId::new(99)), None);

        // Test members (contents, not just counts)
        assert_eq!(sorted_member_ids(&result, 0), vec![1, 2]);
        assert_eq!(sorted_member_ids(&result, 1), vec![3, 4, 5]);
        assert!(result.members(2).is_empty());

        // Test sizes
        let sizes = result.community_sizes();
        assert_eq!(sizes.len(), 2);
        assert_eq!(sizes.get(&0), Some(&2));
        assert_eq!(sizes.get(&1), Some(&3));

        // Test ordering by size, largest and smallest
        assert_eq!(result.communities_by_size(), vec![(1, 3), (0, 2)]);
        assert_eq!(result.largest_community(), Some((1, 3)));
        assert_eq!(result.smallest_community(), Some((0, 2)));
    }

    #[test]
    fn result_same_community() {
        let result = sample_result();

        assert!(result.same_community(EntityId::new(1), EntityId::new(2)));
        assert!(result.same_community(EntityId::new(3), EntityId::new(4)));
        assert!(!result.same_community(EntityId::new(1), EntityId::new(3)));

        // Nodes without an assignment are never in the same community.
        assert!(!result.same_community(EntityId::new(1), EntityId::new(99)));
        assert!(!result.same_community(EntityId::new(98), EntityId::new(99)));
    }
}