Skip to main content

god_graph/algorithms/
community.rs

1//! 社区发现算法模块
2//!
3//! 包含 Label Propagation、Louvain、连通分量等算法
4
5use crate::graph::traits::{GraphBase, GraphQuery};
6use crate::graph::Graph;
7use crate::node::NodeIndex;
8use std::collections::{HashMap, VecDeque};
9
10/// 连通分量算法(基于 BFS)
11///
12/// 返回所有连通分量,每个分量是节点索引的向量
13pub fn connected_components<T>(graph: &Graph<T, impl Clone>) -> Vec<Vec<NodeIndex>> {
14    // 收集所有有效节点
15    let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
16
17    // 创建索引到 NodeIndex 的映射
18    let index_to_node: std::collections::HashMap<usize, NodeIndex> =
19        node_indices.iter().map(|&ni| (ni.index(), ni)).collect();
20
21    let n = graph.node_count();
22    let mut visited = vec![false; n];
23    let mut components = Vec::new();
24
25    for &node in &node_indices {
26        if !visited[node.index()] {
27            let mut component = Vec::new();
28            bfs_component(graph, node, &mut visited, &mut component, &index_to_node);
29            components.push(component);
30        }
31    }
32
33    components
34}
35
36fn bfs_component<T>(
37    graph: &Graph<T, impl Clone>,
38    start: NodeIndex,
39    visited: &mut [bool],
40    component: &mut Vec<NodeIndex>,
41    index_to_node: &std::collections::HashMap<usize, NodeIndex>,
42) {
43    let mut queue = VecDeque::new();
44    queue.push_back(start);
45    visited[start.index()] = true;
46
47    while let Some(node) = queue.pop_front() {
48        component.push(node);
49
50        for neighbor in graph.neighbors(node) {
51            if !visited[neighbor.index()] {
52                visited[neighbor.index()] = true;
53                // 使用映射获取正确的 NodeIndex
54                if let Some(&neighbor_ni) = index_to_node.get(&neighbor.index()) {
55                    queue.push_back(neighbor_ni);
56                }
57            }
58        }
59    }
60}
61
62/// 强连通分量算法(基于 Kosaraju 算法)
63///
64/// 返回所有强连通分量(仅适用于有向图)
65pub fn strongly_connected_components<T>(graph: &Graph<T, impl Clone>) -> Vec<Vec<NodeIndex>> {
66    let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
67    let index_to_node: std::collections::HashMap<usize, NodeIndex> =
68        node_indices.iter().map(|&ni| (ni.index(), ni)).collect();
69
70    let n = graph.node_count();
71    let mut visited = vec![false; n];
72    let mut finish_order = Vec::with_capacity(n);
73
74    // 第一次 DFS,记录完成顺序
75    for &node in &node_indices {
76        if !visited[node.index()] {
77            dfs_finish_order(graph, node, &mut visited, &mut finish_order);
78        }
79    }
80
81    // 构建反向图(隐式)
82    // 在反向图上按完成顺序的逆序进行 DFS
83    let mut visited = vec![false; n];
84    let mut components = Vec::new();
85
86    for &node in finish_order.iter().rev() {
87        if !visited[node.index()] {
88            let mut component = Vec::new();
89            dfs_reverse(graph, node, &mut visited, &mut component, &index_to_node);
90            components.push(component);
91        }
92    }
93
94    components
95}
96
97fn dfs_finish_order<T>(
98    graph: &Graph<T, impl Clone>,
99    node: NodeIndex,
100    visited: &mut [bool],
101    finish_order: &mut Vec<NodeIndex>,
102) {
103    visited[node.index()] = true;
104
105    for neighbor in graph.neighbors(node) {
106        if !visited[neighbor.index()] {
107            dfs_finish_order(graph, neighbor, visited, finish_order);
108        }
109    }
110
111    finish_order.push(node);
112}
113
114fn dfs_reverse<T>(
115    graph: &Graph<T, impl Clone>,
116    node: NodeIndex,
117    visited: &mut [bool],
118    component: &mut Vec<NodeIndex>,
119    index_to_node: &std::collections::HashMap<usize, NodeIndex>,
120) {
121    component.push(node);
122    visited[node.index()] = true;
123
124    // 在反向图中,我们需要找到所有指向当前节点的节点
125    for potential_source in graph.nodes() {
126        if graph.has_edge(potential_source.index(), node) {
127            let source_idx = potential_source.index();
128            if !visited[source_idx.index()] {
129                if let Some(&source_ni) = index_to_node.get(&source_idx.index()) {
130                    dfs_reverse(graph, source_ni, visited, component, index_to_node);
131                }
132            }
133        }
134    }
135}
136
137/// Label Propagation 社区发现算法
138///
139/// 基于标签传播的社区发现,每个节点最终被分配到一个社区标签
140pub fn label_propagation<T>(
141    graph: &Graph<T, impl Clone>,
142    max_iterations: usize,
143) -> HashMap<NodeIndex, usize> {
144    let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
145    let n = node_indices.len();
146
147    if n == 0 {
148        return HashMap::new();
149    }
150
151    // 初始化:每个节点一个唯一标签
152    let mut labels: HashMap<NodeIndex, usize> = node_indices
153        .iter()
154        .enumerate()
155        .map(|(i, &ni)| (ni, i))
156        .collect();
157
158    for _ in 0..max_iterations {
159        let mut changed = false;
160
161        for &node in &node_indices {
162            // 收集邻居的标签
163            let mut label_counts: HashMap<usize, usize> = HashMap::new();
164
165            // 无向图:考虑所有邻居(出边和入边)
166            for neighbor in graph.neighbors(node) {
167                if let Some(&label) = labels.get(&neighbor) {
168                    *label_counts.entry(label).or_insert(0) += 1;
169                }
170            }
171
172            // 找到出现频率最高的标签
173            if let Some((&max_label, _)) = label_counts.iter().max_by_key(|&(_, count)| count) {
174                let current_label = labels.get(&node).copied().unwrap_or(usize::MAX);
175                if max_label != current_label {
176                    labels.insert(node, max_label);
177                    changed = true;
178                }
179            }
180        }
181
182        if !changed {
183            break;
184        }
185    }
186
187    labels
188}
189
190/// Louvain 社区发现算法(简化版)
191///
192/// 基于模块度优化的社区发现算法
193pub fn louvain<T>(graph: &Graph<T, impl Clone>, resolution: f64) -> HashMap<NodeIndex, usize>
194where
195    T: Clone,
196{
197    let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
198    let n = node_indices.len();
199
200    if n == 0 {
201        return HashMap::new();
202    }
203
204    // 初始化:每个节点一个社区
205    let mut communities: HashMap<NodeIndex, usize> = node_indices
206        .iter()
207        .enumerate()
208        .map(|(i, &ni)| (ni, i))
209        .collect();
210
211    // 计算总边数
212    let total_edges = graph.edge_count() as f64;
213    if total_edges == 0.0 {
214        return communities;
215    }
216
217    // 计算每个节点的度数
218    let mut degrees: HashMap<NodeIndex, usize> = HashMap::new();
219    for &node in &node_indices {
220        degrees.insert(node, graph.out_degree(node).unwrap_or(0));
221    }
222
223    let mut improved = true;
224    while improved {
225        improved = false;
226
227        for &node in &node_indices {
228            let current_comm = *communities.get(&node).unwrap();
229            let node_degree = *degrees.get(&node).unwrap();
230
231            // 计算移动到每个邻居社区的模块度增益
232            let mut comm_delta_q: HashMap<usize, f64> = HashMap::new();
233
234            for neighbor in graph.neighbors(node) {
235                let neighbor_comm = *communities.get(&neighbor).unwrap();
236                if neighbor_comm != current_comm {
237                    // 简化:假设无权重边
238                    let delta_q = 1.0 / (2.0 * total_edges)
239                        - resolution
240                            * (node_degree as f64
241                                * degrees.get(&neighbor).copied().unwrap_or(0) as f64)
242                            / (4.0 * total_edges * total_edges);
243
244                    *comm_delta_q.entry(neighbor_comm).or_insert(0.0) += delta_q;
245                }
246            }
247
248            // 选择最大增益的社区(使用 partial_cmp 处理 f64)
249            if let Some((&best_comm, &max_delta)) = comm_delta_q
250                .iter()
251                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
252            {
253                if max_delta > 0.0 {
254                    communities.insert(node, best_comm);
255                    improved = true;
256                }
257            }
258        }
259    }
260
261    // 重新编号社区,使标签从 0 开始连续
262    let mut comm_remap: HashMap<usize, usize> = HashMap::new();
263    let mut next_comm = 0usize;
264    for comm in communities.values_mut() {
265        if !comm_remap.contains_key(comm) {
266            comm_remap.insert(*comm, next_comm);
267            next_comm += 1;
268        }
269        *comm = comm_remap[comm];
270    }
271
272    communities
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::builders::GraphBuilder;

    /// Component sizes in ascending order, so assertions do not depend on discovery order.
    fn sorted_sizes(components: &[Vec<NodeIndex>]) -> Vec<usize> {
        let mut sizes: Vec<usize> = components.iter().map(Vec::len).collect();
        sizes.sort_unstable();
        sizes
    }

    /// Number of distinct values in a node -> label map.
    fn distinct_labels(labels: &HashMap<NodeIndex, usize>) -> Vec<usize> {
        let mut values: Vec<usize> = labels.values().copied().collect();
        values.sort_unstable();
        values.dedup();
        values
    }

    #[test]
    fn test_connected_components() {
        let graph = GraphBuilder::undirected()
            .with_nodes(vec![1, 2, 3, 4, 5, 6])
            .with_edges(vec![(0, 1, 1.0), (1, 2, 1.0), (3, 4, 1.0)])
            .build()
            .unwrap();

        let components = connected_components(&graph);
        // {0,1,2}, {3,4}, {5}
        assert_eq!(components.len(), 3);
        assert_eq!(sorted_sizes(&components), vec![1, 2, 3]);
    }

    #[test]
    fn test_connected_components_empty_graph() {
        // Nodes but no edges.
        let graph: Graph<i32, f64> = GraphBuilder::undirected()
            .with_nodes(vec![1, 2, 3])
            .build()
            .unwrap();

        let components = connected_components(&graph);
        // Every node is its own component.
        assert_eq!(components.len(), 3);
        assert!(components.iter().all(|c| c.len() == 1));
    }

    #[test]
    fn test_connected_components_single_node() {
        let graph: Graph<i32, f64> = GraphBuilder::undirected()
            .with_nodes(vec![1])
            .build()
            .unwrap();

        let components = connected_components(&graph);
        assert_eq!(components.len(), 1);
        assert_eq!(components[0].len(), 1);
    }

    #[test]
    fn test_connected_components_fully_connected() {
        let graph: Graph<i32, f64> = GraphBuilder::undirected()
            .with_nodes(vec![1, 2, 3, 4])
            .with_edges(vec![
                (0, 1, 1.0),
                (0, 2, 1.0),
                (0, 3, 1.0),
                (1, 2, 1.0),
                (1, 3, 1.0),
                (2, 3, 1.0),
            ])
            .build()
            .unwrap();

        let components = connected_components(&graph);
        assert_eq!(components.len(), 1);
        assert_eq!(components[0].len(), 4);
    }

    #[test]
    fn test_strongly_connected_components() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec![1, 2, 3, 4])
            .with_edges(vec![(0, 1, 1.0), (1, 2, 1.0), (2, 0, 1.0), (2, 3, 1.0)])
            .build()
            .unwrap();

        let components = strongly_connected_components(&graph);
        // {0,1,2} form one strongly connected component, {3} is alone.
        assert_eq!(components.len(), 2);
        assert_eq!(sorted_sizes(&components), vec![1, 3]);
    }

    #[test]
    fn test_strongly_connected_single_node() {
        let graph: Graph<i32, f64> = GraphBuilder::directed()
            .with_nodes(vec![1])
            .build()
            .unwrap();

        let components = strongly_connected_components(&graph);
        assert_eq!(components.len(), 1);
        assert_eq!(components[0].len(), 1);
    }

    #[test]
    fn test_strongly_connected_dag() {
        // A DAG has only singleton strongly connected components.
        let graph: Graph<i32, f64> = GraphBuilder::directed()
            .with_nodes(vec![1, 2, 3, 4])
            .with_edges(vec![(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)])
            .build()
            .unwrap();

        let components = strongly_connected_components(&graph);
        assert_eq!(components.len(), 4);
        assert!(components.iter().all(|c| c.len() == 1));
    }

    #[test]
    fn test_label_propagation() {
        let graph = GraphBuilder::undirected()
            .with_nodes(vec![1, 2, 3, 4])
            .with_edges(vec![(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)])
            .build()
            .unwrap();

        let labels = label_propagation(&graph, 10);
        assert_eq!(labels.len(), 4);
    }

    #[test]
    fn test_label_propagation_empty_graph() {
        // Nodes but no edges.
        let graph: Graph<i32, f64> = GraphBuilder::undirected()
            .with_nodes(vec![1, 2, 3])
            .build()
            .unwrap();

        let labels = label_propagation(&graph, 10);
        // Without edges every node keeps its own, distinct label.
        assert_eq!(labels.len(), 3);
        assert_eq!(distinct_labels(&labels).len(), 3);
    }

    #[test]
    fn test_label_propagation_single_node() {
        let graph: Graph<i32, f64> = GraphBuilder::undirected()
            .with_nodes(vec![1])
            .build()
            .unwrap();

        let labels = label_propagation(&graph, 10);
        assert_eq!(labels.len(), 1);
    }

    #[test]
    fn test_louvain_basic() {
        let graph: Graph<i32, f64> = GraphBuilder::undirected()
            .with_nodes(vec![1, 2, 3, 4])
            .with_edges(vec![(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)])
            .build()
            .unwrap();

        let communities = louvain(&graph, 1.0);
        assert_eq!(communities.len(), 4);
        // Community ids are renumbered to a contiguous range starting at 0.
        let ids = distinct_labels(&communities);
        assert_eq!(ids, (0..ids.len()).collect::<Vec<_>>());
    }

    #[test]
    fn test_louvain_empty_graph() {
        // Nodes but no edges.
        let graph: Graph<i32, f64> = GraphBuilder::undirected()
            .with_nodes(vec![1, 2, 3])
            .build()
            .unwrap();

        let communities = louvain(&graph, 1.0);
        // Without edges every node is its own community.
        assert_eq!(communities.len(), 3);
        assert_eq!(distinct_labels(&communities).len(), 3);
    }
}