Skip to main content

graphmind_graph_algorithms/
community.rs

//! Community detection algorithms
//!
//! Implements REQ-ALGO-004 (Weakly Connected Components) and Strongly
//! Connected Components (Tarjan's algorithm).
4
5use super::common::{GraphView, NodeId};
6use std::collections::HashMap;
7
8/// Result of WCC algorithm
9pub struct WccResult {
10    /// Map of Component ID -> List of NodeIds
11    pub components: HashMap<usize, Vec<NodeId>>,
12    /// Map of NodeId -> Component ID
13    pub node_component: HashMap<NodeId, usize>,
14}
15
/// Disjoint-set forest over the indices `0..size`.
///
/// Sets are merged by rank and lookups compress the path they walk.
struct UnionFind {
    /// `parent[i]` is the parent of `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Per-root rank used to pick the surviving root when two sets merge.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates `size` singleton sets, one per index in `0..size`.
    fn new(size: usize) -> Self {
        Self {
            rank: vec![0; size],
            parent: (0..size).collect(),
        }
    }

    /// Returns the representative (root) of the set containing `i`.
    ///
    /// Every node visited on the way to the root is re-pointed directly at
    /// the root (path compression).
    fn find(&mut self, i: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = i;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: re-point every node on the walked path at the root.
        let mut current = i;
        while current != root {
            let next = self.parent[current];
            self.parent[current] = root;
            current = next;
        }

        root
    }

    /// Merges the sets containing `i` and `j`; a no-op if they already share a root.
    fn union(&mut self, i: usize, j: usize) {
        use std::cmp::Ordering;

        let (first, second) = (self.find(i), self.find(j));
        if first == second {
            return;
        }

        // The root with the lower rank is attached beneath the other one.
        match self.rank[first].cmp(&self.rank[second]) {
            Ordering::Less => self.parent[first] = second,
            Ordering::Greater => self.parent[second] = first,
            Ordering::Equal => {
                // Equal ranks: `first` survives as the root and its rank grows.
                self.parent[second] = first;
                self.rank[first] += 1;
            }
        }
    }
}
53
54/// Weakly Connected Components (WCC)
55///
56/// Finds all disjoint subgraphs in the graph.
57/// Ignores edge direction.
58pub fn weakly_connected_components(view: &GraphView) -> WccResult {
59    let n = view.node_count;
60    let mut uf = UnionFind::new(n);
61
62    // Iterate all edges and Union connected nodes
63    for u_idx in 0..n {
64        for &v_idx in view.successors(u_idx) {
65            uf.union(u_idx, v_idx);
66        }
67    }
68
69    // Build results
70    let mut components = HashMap::new();
71    let mut node_component = HashMap::new();
72
73    for i in 0..n {
74        let root = uf.find(i);
75        let node_id = view.index_to_node[i];
76
77        components
78            .entry(root)
79            .or_insert_with(Vec::new)
80            .push(node_id);
81        node_component.insert(node_id, root);
82    }
83
84    WccResult {
85        components,
86        node_component,
87    }
88}
89
90/// Result of SCC algorithm
91pub struct SccResult {
92    /// Map of Component ID -> List of NodeIds
93    pub components: HashMap<usize, Vec<NodeId>>,
94    /// Map of NodeId -> Component ID
95    pub node_component: HashMap<NodeId, usize>,
96}
97
98/// Strongly Connected Components (SCC) using Tarjan's algorithm
99pub fn strongly_connected_components(view: &GraphView) -> SccResult {
100    let n = view.node_count;
101    let mut ids = vec![-1; n];
102    let mut low = vec![0; n];
103    let mut on_stack = vec![false; n];
104    let mut stack = Vec::new();
105    let mut id_counter = 0;
106    let mut scc_count = 0;
107
108    let mut node_component = HashMap::new();
109    let mut components = HashMap::new();
110
111    #[allow(clippy::too_many_arguments)]
112    fn dfs(
113        u: usize,
114        id_counter: &mut i32,
115        scc_count: &mut usize,
116        ids: &mut Vec<i32>,
117        low: &mut Vec<usize>,
118        on_stack: &mut Vec<bool>,
119        stack: &mut Vec<usize>,
120        view: &GraphView,
121        node_component: &mut HashMap<NodeId, usize>,
122        components: &mut HashMap<usize, Vec<NodeId>>,
123    ) {
124        stack.push(u);
125        on_stack[u] = true;
126        ids[u] = *id_counter;
127        low[u] = *id_counter as usize;
128        *id_counter += 1;
129
130        for &v in view.successors(u) {
131            if ids[v] == -1 {
132                dfs(
133                    v,
134                    id_counter,
135                    scc_count,
136                    ids,
137                    low,
138                    on_stack,
139                    stack,
140                    view,
141                    node_component,
142                    components,
143                );
144                low[u] = low[u].min(low[v]);
145            } else if on_stack[v] {
146                low[u] = low[u].min(ids[v] as usize);
147            }
148        }
149
150        if ids[u] == low[u] as i32 {
151            while let Some(node_idx) = stack.pop() {
152                on_stack[node_idx] = false;
153                low[node_idx] = ids[u] as usize;
154
155                let node_id = view.index_to_node[node_idx];
156                node_component.insert(node_id, *scc_count);
157                components.entry(*scc_count).or_default().push(node_id);
158
159                if node_idx == u {
160                    break;
161                }
162            }
163            *scc_count += 1;
164        }
165    }
166
167    for i in 0..n {
168        if ids[i] == -1 {
169            dfs(
170                i,
171                &mut id_counter,
172                &mut scc_count,
173                &mut ids,
174                &mut low,
175                &mut on_stack,
176                &mut stack,
177                view,
178                &mut node_component,
179                &mut components,
180            );
181        }
182    }
183
184    SccResult {
185        components,
186        node_component,
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// WCC must join nodes that are connected only when direction is ignored.
    #[test]
    fn test_wcc() {
        // Edges: 1->2, 3->2, 4->5; node 6 is isolated.
        // Nodes 1 and 3 are linked only through their shared target 2.
        let node_count = 6;
        let index_to_node = vec![1, 2, 3, 4, 5, 6];
        let mut node_to_index = HashMap::new();
        for (i, &id) in index_to_node.iter().enumerate() {
            node_to_index.insert(id, i);
        }

        let mut outgoing = vec![vec![]; 6];
        let mut incoming = vec![vec![]; 6];
        outgoing[0].push(1); // 1->2
        incoming[1].push(0);
        outgoing[2].push(1); // 3->2
        incoming[1].push(2);
        outgoing[3].push(4); // 4->5
        incoming[4].push(3);

        let view = GraphView::from_adjacency_list(
            node_count,
            index_to_node,
            node_to_index,
            outgoing,
            incoming,
            None,
        );

        let result = weakly_connected_components(&view);
        assert_eq!(result.components.len(), 3);
        assert_eq!(result.node_component.len(), 6);

        let c1 = result.node_component[&1];
        let c2 = result.node_component[&2];
        let c3 = result.node_component[&3];
        let c4 = result.node_component[&4];
        let c5 = result.node_component[&5];
        let c6 = result.node_component[&6];

        assert_eq!(c1, c2);
        assert_eq!(c2, c3);
        assert_eq!(c4, c5);
        assert_ne!(c1, c4);
        assert_ne!(c1, c6);
        assert_ne!(c4, c6);

        // The forward map must agree with the reverse lookup.
        let mut members = result.components[&c1].clone();
        members.sort_unstable();
        assert_eq!(members, vec![1, 2, 3]);
    }

    /// SCC must respect direction: a one-way edge does not merge components.
    #[test]
    fn test_scc() {
        // Graph with cycle: 1->2->3->1, plus a one-way edge 3->4.
        // Node 4 is reachable from the cycle but cannot reach it back.
        let node_count = 4;
        let index_to_node = vec![1, 2, 3, 4];
        let mut node_to_index = HashMap::new();
        for (i, &id) in index_to_node.iter().enumerate() {
            node_to_index.insert(id, i);
        }

        let mut outgoing = vec![vec![]; 4];
        let mut incoming = vec![vec![]; 4];
        outgoing[0].push(1); // 1->2
        incoming[1].push(0);
        outgoing[1].push(2); // 2->3
        incoming[2].push(1);
        outgoing[2].push(0); // 3->1
        incoming[0].push(2);
        outgoing[2].push(3); // 3->4
        incoming[3].push(2);

        let view = GraphView::from_adjacency_list(
            node_count,
            index_to_node,
            node_to_index,
            outgoing,
            incoming,
            None,
        );

        let result = strongly_connected_components(&view);
        assert_eq!(result.components.len(), 2);
        assert_eq!(result.node_component.len(), 4);

        let c1 = result.node_component[&1];
        let c2 = result.node_component[&2];
        let c3 = result.node_component[&3];
        let c4 = result.node_component[&4];

        assert_eq!(c1, c2);
        assert_eq!(c2, c3);
        assert_ne!(c1, c4);

        // The forward map must agree with the reverse lookup.
        let mut cycle = result.components[&c1].clone();
        cycle.sort_unstable();
        assert_eq!(cycle, vec![1, 2, 3]);
        assert_eq!(result.components[&c4], vec![4]);
    }
}