Skip to main content

reddb_server/storage/engine/algorithms/
community.rs

1//! Community Detection Algorithms
2//!
3//! Algorithms for detecting communities/clusters in graphs:
4//! - Label Propagation: Fast, simple community detection
5//! - Louvain: Modularity optimization for quality communities
6
7use std::collections::{HashMap, HashSet};
8
9use super::super::graph_store::GraphStore;
10
11// ============================================================================
12// Label Propagation
13// ============================================================================
14
/// Label Propagation Algorithm for community detection
///
/// Nodes adopt the most common label among their neighbors.
/// Fast and scales well, but results can be non-deterministic.
///
/// The type is plain configuration, so it derives the common traits
/// (`Debug`, `Clone`, `PartialEq`, `Eq`) to be loggable, copyable between
/// callers and comparable in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LabelPropagation {
    /// Maximum iterations
    pub max_iterations: usize,
}
23
24impl Default for LabelPropagation {
25    fn default() -> Self {
26        Self {
27            max_iterations: 100,
28        }
29    }
30}
31
/// A community of nodes
///
/// Derives `PartialEq`/`Eq` in addition to `Debug`/`Clone` so that two
/// communities can be compared directly (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Community {
    /// Community label (typically the ID of a founding member)
    pub label: String,
    /// Nodes in this community
    pub nodes: Vec<String>,
    /// Size of the community
    pub size: usize,
}
42
43/// Result of community detection
44#[derive(Debug, Clone)]
45pub struct CommunitiesResult {
46    /// Detected communities, sorted by size descending
47    pub communities: Vec<Community>,
48    /// Number of iterations until convergence
49    pub iterations: usize,
50    /// Whether the algorithm converged
51    pub converged: bool,
52}
53
54impl CommunitiesResult {
55    /// Get the largest community
56    pub fn largest(&self) -> Option<&Community> {
57        self.communities.first()
58    }
59
60    /// Find which community a node belongs to
61    pub fn community_of(&self, node_id: &str) -> Option<&Community> {
62        self.communities
63            .iter()
64            .find(|c| c.nodes.contains(&node_id.to_string()))
65    }
66}
67
68impl LabelPropagation {
69    /// Create with default parameters
70    pub fn new() -> Self {
71        Self::default()
72    }
73
74    /// Set maximum iterations
75    pub fn max_iterations(mut self, max: usize) -> Self {
76        self.max_iterations = max;
77        self
78    }
79
80    /// Run label propagation on the graph
81    pub fn run(&self, graph: &GraphStore) -> CommunitiesResult {
82        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
83
84        if nodes.is_empty() {
85            return CommunitiesResult {
86                communities: Vec::new(),
87                iterations: 0,
88                converged: true,
89            };
90        }
91
92        // Initialize: each node gets its own label
93        let mut labels: HashMap<String, String> =
94            nodes.iter().map(|id| (id.clone(), id.clone())).collect();
95
96        let mut converged = false;
97        let mut iterations = 0;
98
99        for iter in 0..self.max_iterations {
100            iterations = iter + 1;
101            let mut changed = false;
102
103            // Process nodes in order (could shuffle for randomization)
104            for node_id in &nodes {
105                // Count neighbor labels
106                let mut label_counts: HashMap<String, usize> = HashMap::new();
107
108                // Outgoing edges
109                for (_, neighbor, _) in graph.outgoing_edges(node_id) {
110                    if let Some(label) = labels.get(&neighbor) {
111                        *label_counts.entry(label.clone()).or_insert(0) += 1;
112                    }
113                }
114
115                // Incoming edges (treat as undirected)
116                for (_, neighbor, _) in graph.incoming_edges(node_id) {
117                    if let Some(label) = labels.get(&neighbor) {
118                        *label_counts.entry(label.clone()).or_insert(0) += 1;
119                    }
120                }
121
122                // Find most common label
123                if let Some((best_label, _)) =
124                    label_counts.into_iter().max_by_key(|(_, count)| *count)
125                {
126                    let current = labels.get(node_id).cloned().unwrap_or_default();
127                    if best_label != current {
128                        labels.insert(node_id.clone(), best_label);
129                        changed = true;
130                    }
131                }
132            }
133
134            if !changed {
135                converged = true;
136                break;
137            }
138        }
139
140        // Group nodes by label
141        let mut groups: HashMap<String, Vec<String>> = HashMap::new();
142        for (node_id, label) in &labels {
143            groups
144                .entry(label.clone())
145                .or_default()
146                .push(node_id.clone());
147        }
148
149        // Build communities
150        let mut communities: Vec<Community> = groups
151            .into_iter()
152            .map(|(label, nodes)| {
153                let size = nodes.len();
154                Community { label, nodes, size }
155            })
156            .collect();
157
158        // Sort by size descending
159        communities.sort_by_key(|b| std::cmp::Reverse(b.size));
160
161        CommunitiesResult {
162            communities,
163            iterations,
164            converged,
165        }
166    }
167}
168
169// ============================================================================
170// Louvain Community Detection
171// ============================================================================
172
/// Louvain algorithm for community detection
///
/// A greedy algorithm that optimizes modularity - a measure of how well
/// the network is partitioned into communities where nodes are densely
/// connected within communities but sparsely between them.
///
/// The type is plain configuration, so it derives `Debug`, `Clone` and
/// `PartialEq` (no `Eq`, because it holds floating-point fields).
#[derive(Debug, Clone, PartialEq)]
pub struct Louvain {
    /// Resolution parameter (higher = smaller communities)
    pub resolution: f64,
    /// Maximum iterations per phase
    pub max_iterations: usize,
    /// Minimum modularity improvement to continue
    pub min_improvement: f64,
}
186
187impl Default for Louvain {
188    fn default() -> Self {
189        Self {
190            resolution: 1.0,
191            max_iterations: 10,
192            min_improvement: 1e-6,
193        }
194    }
195}
196
/// Result of Louvain community detection
///
/// Derives `PartialEq` in addition to `Debug`/`Clone` so results can be
/// compared directly (no `Eq`, because `modularity` is floating point).
#[derive(Debug, Clone, PartialEq)]
pub struct LouvainResult {
    /// Node ID → community ID
    pub communities: HashMap<String, usize>,
    /// Number of communities found
    pub count: usize,
    /// Final modularity score (-0.5 to 1.0, higher = better)
    pub modularity: f64,
    /// Number of passes/phases completed
    pub passes: usize,
}
209
210impl LouvainResult {
211    /// Get all nodes in a specific community
212    pub fn get_community(&self, community_id: usize) -> Vec<String> {
213        self.communities
214            .iter()
215            .filter(|(_, &c)| c == community_id)
216            .map(|(n, _)| n.clone())
217            .collect()
218    }
219
220    /// Get community sizes
221    pub fn community_sizes(&self) -> HashMap<usize, usize> {
222        let mut sizes: HashMap<usize, usize> = HashMap::new();
223        for &c in self.communities.values() {
224            *sizes.entry(c).or_insert(0) += 1;
225        }
226        sizes
227    }
228}
229
230impl Louvain {
231    /// Create new Louvain with default parameters
232    pub fn new() -> Self {
233        Self::default()
234    }
235
236    /// Set resolution parameter (default: 1.0)
237    pub fn resolution(mut self, resolution: f64) -> Self {
238        self.resolution = resolution;
239        self
240    }
241
242    /// Set maximum iterations per phase (default: 10)
243    pub fn max_iterations(mut self, max: usize) -> Self {
244        self.max_iterations = max;
245        self
246    }
247
248    /// Run Louvain community detection (treats graph as undirected)
249    pub fn run(&self, graph: &GraphStore) -> LouvainResult {
250        let nodes: Vec<String> = graph.iter_nodes().map(|n| n.id.clone()).collect();
251
252        if nodes.is_empty() {
253            return LouvainResult {
254                communities: HashMap::new(),
255                count: 0,
256                modularity: 0.0,
257                passes: 0,
258            };
259        }
260
261        // Build undirected weighted adjacency
262        let mut weights: HashMap<(String, String), f64> = HashMap::new();
263        let mut node_strength: HashMap<String, f64> = HashMap::new();
264        let mut total_weight = 0.0;
265
266        for node in &nodes {
267            for (_, target, _) in graph.outgoing_edges(node) {
268                if node != &target {
269                    let key = if node < &target {
270                        (node.clone(), target.clone())
271                    } else {
272                        (target.clone(), node.clone())
273                    };
274
275                    let w = weights.entry(key).or_insert(0.0);
276                    *w += 1.0; // Can use edge weight if available
277                }
278            }
279        }
280
281        // Calculate node strengths and total weight
282        for ((a, b), w) in &weights {
283            *node_strength.entry(a.clone()).or_insert(0.0) += w;
284            *node_strength.entry(b.clone()).or_insert(0.0) += w;
285            total_weight += w;
286        }
287
288        if total_weight == 0.0 {
289            // No edges - each node is its own community
290            let communities: HashMap<String, usize> = nodes
291                .iter()
292                .enumerate()
293                .map(|(i, n)| (n.clone(), i))
294                .collect();
295            return LouvainResult {
296                count: nodes.len(),
297                communities,
298                modularity: 0.0,
299                passes: 0,
300            };
301        }
302
303        // Initialize: each node in its own community
304        let mut communities: HashMap<String, usize> = nodes
305            .iter()
306            .enumerate()
307            .map(|(i, n)| (n.clone(), i))
308            .collect();
309
310        // Community total weights
311        let mut comm_total: HashMap<usize, f64> = nodes
312            .iter()
313            .enumerate()
314            .map(|(i, n)| (i, *node_strength.get(n).unwrap_or(&0.0)))
315            .collect();
316
317        // Community internal weights
318        let mut comm_internal: HashMap<usize, f64> = HashMap::new();
319
320        let mut passes = 0;
321        let mut improved = true;
322
323        while improved && passes < self.max_iterations {
324            improved = false;
325            passes += 1;
326
327            for node in &nodes {
328                let current_comm = *communities.get(node).unwrap();
329                let node_w = *node_strength.get(node).unwrap_or(&0.0);
330
331                // Calculate weights to each neighboring community
332                let mut neighbor_comm_weights: HashMap<usize, f64> = HashMap::new();
333
334                for ((a, b), w) in &weights {
335                    if a == node {
336                        let neighbor_comm = *communities.get(b).unwrap();
337                        *neighbor_comm_weights.entry(neighbor_comm).or_insert(0.0) += w;
338                    } else if b == node {
339                        let neighbor_comm = *communities.get(a).unwrap();
340                        *neighbor_comm_weights.entry(neighbor_comm).or_insert(0.0) += w;
341                    }
342                }
343
344                // Try moving to each neighboring community
345                let mut best_comm = current_comm;
346                let mut best_delta = 0.0;
347
348                // First, calculate delta for removing from current community
349                let current_internal = neighbor_comm_weights
350                    .get(&current_comm)
351                    .copied()
352                    .unwrap_or(0.0);
353                let current_total = *comm_total.get(&current_comm).unwrap_or(&0.0);
354
355                for (&target_comm, &weight_to_target) in &neighbor_comm_weights {
356                    if target_comm == current_comm {
357                        continue;
358                    }
359
360                    let target_total = *comm_total.get(&target_comm).unwrap_or(&0.0);
361
362                    let delta = (weight_to_target - current_internal) / total_weight
363                        - self.resolution * node_w * (target_total - current_total + node_w)
364                            / (2.0 * total_weight * total_weight);
365
366                    if delta > best_delta + self.min_improvement {
367                        best_delta = delta;
368                        best_comm = target_comm;
369                    }
370                }
371
372                // Move node if beneficial
373                if best_comm != current_comm {
374                    improved = true;
375
376                    // Update community totals
377                    *comm_total.entry(current_comm).or_insert(0.0) -= node_w;
378                    *comm_total.entry(best_comm).or_insert(0.0) += node_w;
379
380                    // Update community internals
381                    let current_internal = neighbor_comm_weights
382                        .get(&current_comm)
383                        .copied()
384                        .unwrap_or(0.0);
385                    *comm_internal.entry(current_comm).or_insert(0.0) -= current_internal;
386
387                    let new_internal = neighbor_comm_weights
388                        .get(&best_comm)
389                        .copied()
390                        .unwrap_or(0.0);
391                    *comm_internal.entry(best_comm).or_insert(0.0) += new_internal;
392
393                    communities.insert(node.clone(), best_comm);
394                }
395            }
396        }
397
398        // Renumber communities to be contiguous
399        let unique_communities: Vec<usize> = {
400            let c: HashSet<usize> = communities.values().copied().collect();
401            let mut v: Vec<usize> = c.into_iter().collect();
402            v.sort();
403            v
404        };
405
406        let comm_map: HashMap<usize, usize> = unique_communities
407            .iter()
408            .enumerate()
409            .map(|(new, &old)| (old, new))
410            .collect();
411
412        let remapped: HashMap<String, usize> = communities
413            .into_iter()
414            .map(|(n, c)| (n, *comm_map.get(&c).unwrap_or(&0)))
415            .collect();
416
417        // Calculate final modularity
418        let modularity =
419            self.calculate_modularity(&remapped, &weights, &node_strength, total_weight);
420
421        LouvainResult {
422            count: unique_communities.len(),
423            communities: remapped,
424            modularity,
425            passes,
426        }
427    }
428
429    /// Calculate modularity of a partition
430    fn calculate_modularity(
431        &self,
432        communities: &HashMap<String, usize>,
433        weights: &HashMap<(String, String), f64>,
434        node_strength: &HashMap<String, f64>,
435        total_weight: f64,
436    ) -> f64 {
437        if total_weight == 0.0 {
438            return 0.0;
439        }
440
441        let mut q = 0.0;
442
443        // Sum over all edges within same community
444        for ((a, b), w) in weights {
445            let ca = communities.get(a).unwrap();
446            let cb = communities.get(b).unwrap();
447
448            if ca == cb {
449                let ka = node_strength.get(a).unwrap_or(&0.0);
450                let kb = node_strength.get(b).unwrap_or(&0.0);
451                q += w - self.resolution * ka * kb / (2.0 * total_weight);
452            }
453        }
454
455        q / total_weight
456    }
457}