Skip to main content

god_graph/algorithms/
centrality.rs

1//! 中心性算法模块
2//!
3//! 包含度中心性、介数中心性、接近中心性、PageRank 等算法
4
5use crate::graph::traits::{GraphBase, GraphQuery};
6use crate::graph::Graph;
7use crate::node::NodeIndex;
8use std::collections::{HashMap, VecDeque};
9
10/// 度中心性
11///
12/// 计算每个节点的度中心性(归一化的度数)
13pub fn degree_centrality<T>(graph: &Graph<T, impl Clone>) -> HashMap<NodeIndex, f64> {
14    let n = graph.node_count();
15    if n <= 1 {
16        return HashMap::new();
17    }
18
19    let mut centrality = HashMap::new();
20    let norm = 1.0 / (n - 1) as f64;
21
22    for node in graph.nodes() {
23        let degree = graph.out_degree(node.index()).unwrap_or(0) as f64;
24        centrality.insert(node.index(), degree * norm);
25    }
26
27    centrality
28}
29
30/// PageRank 算法
31///
32/// 计算每个节点的 PageRank 分数
33///
34/// # 参数
35/// * `graph` - 图
36/// * `damping` - 阻尼系数(通常 0.85)
37/// * `iterations` - 迭代次数
38///
39/// # 返回
40/// HashMap,键为节点索引,值为 PageRank 分数
41pub fn pagerank<T>(
42    graph: &Graph<T, impl Clone>,
43    damping: f64,
44    iterations: usize,
45) -> HashMap<NodeIndex, f64> {
46    let n = graph.node_count();
47    if n == 0 {
48        return HashMap::new();
49    }
50
51    // 收集所有有效节点
52    let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
53
54    // 初始化:均匀分布
55    let mut scores: HashMap<NodeIndex, f64> = node_indices
56        .iter()
57        .map(|&ni| (ni, 1.0 / n as f64))
58        .collect();
59
60    for _ in 0..iterations {
61        let mut new_scores: HashMap<NodeIndex, f64> =
62            node_indices.iter().map(|&ni| (ni, 0.0)).collect();
63
64        // 计算每个节点的 PageRank
65        for &node in &node_indices {
66            // 基础分数:随机跳转贡献
67            let mut rank = (1.0 - damping) / n as f64;
68
69            // 收集指向当前节点的邻居贡献
70            for neighbor in graph.nodes() {
71                // 检查 neighbor 是否指向 node
72                if graph.has_edge(neighbor.index(), node) {
73                    let out_degree = graph.out_degree(neighbor.index()).unwrap_or(1);
74                    if out_degree > 0 {
75                        let contribution = scores.get(&neighbor.index()).copied().unwrap_or(0.0);
76                        rank += damping * contribution / out_degree as f64;
77                    }
78                }
79            }
80
81            new_scores.insert(node, rank);
82        }
83
84        scores = new_scores;
85    }
86
87    scores
88}
89
90/// 介数中心性(基于 Brandes 算法)
91///
92/// 计算每个节点的介数中心性
93/// 介数中心性衡量节点在所有最短路径中出现的频率
94pub fn betweenness_centrality<T>(graph: &Graph<T, impl Clone>) -> HashMap<NodeIndex, f64> {
95    let n = graph.node_count();
96    if n == 0 {
97        return HashMap::new();
98    }
99
100    let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
101    let mut centrality: HashMap<NodeIndex, f64> =
102        node_indices.iter().map(|&ni| (ni, 0.0)).collect();
103
104    for s in &node_indices {
105        // 单源最短路径
106        let mut dist: HashMap<NodeIndex, i32> = node_indices.iter().map(|&ni| (ni, -1)).collect();
107        let mut sigma: HashMap<NodeIndex, usize> = node_indices.iter().map(|&ni| (ni, 0)).collect();
108        let mut predecessors: HashMap<NodeIndex, Vec<NodeIndex>> =
109            node_indices.iter().map(|&ni| (ni, Vec::new())).collect();
110        let mut stack = Vec::new();
111
112        dist.insert(*s, 0);
113        sigma.insert(*s, 1);
114        let mut queue = VecDeque::new();
115        queue.push_back(*s);
116
117        while let Some(v) = queue.pop_front() {
118            stack.push(v);
119
120            for w in graph.neighbors(v) {
121                // w 第一次访问
122                if dist.get(&w).copied().unwrap_or(-1) < 0 {
123                    dist.insert(w, dist.get(&v).copied().unwrap_or(0) + 1);
124                    queue.push_back(w);
125                }
126
127                // 最短路径经过 v 到 w
128                if dist.get(&w).copied().unwrap_or(0) == dist.get(&v).copied().unwrap_or(0) + 1 {
129                    *sigma.get_mut(&w).unwrap() += sigma.get(&v).copied().unwrap_or(0);
130                    predecessors.get_mut(&w).unwrap().push(v);
131                }
132            }
133        }
134
135        // 反向累加依赖值
136        let mut delta: HashMap<NodeIndex, f64> = node_indices.iter().map(|&ni| (ni, 0.0)).collect();
137
138        while let Some(w) = stack.pop() {
139            for &v in predecessors.get(&w).unwrap_or(&Vec::new()) {
140                let delta_w = delta.get(&w).copied().unwrap_or(0.0);
141                let sigma_w = sigma.get(&w).copied().unwrap_or(1);
142                let sigma_v = sigma.get(&v).copied().unwrap_or(1);
143
144                if sigma_w > 0 {
145                    let contrib = (sigma_v as f64 / sigma_w as f64) * (1.0 + delta_w);
146                    *delta.get_mut(&v).unwrap() += contrib;
147                }
148            }
149
150            if w != *s {
151                let centrality_w = centrality.get_mut(&w).unwrap();
152                *centrality_w += delta.get(&w).copied().unwrap_or(0.0);
153            }
154        }
155    }
156
157    // 归一化(有向图)
158    if n > 2 {
159        let norm = 1.0 / ((n - 1) * (n - 2)) as f64;
160        for val in centrality.values_mut() {
161            *val *= norm;
162        }
163    }
164
165    centrality
166}
167
168/// 接近中心性(基于 BFS 计算平均最短距离)
169///
170/// 计算每个节点的接近中心性
171/// 接近中心性衡量节点到所有其他节点的平均最短距离的倒数
172pub fn closeness_centrality<T>(graph: &Graph<T, impl Clone>) -> HashMap<NodeIndex, f64> {
173    let n = graph.node_count();
174    if n == 0 {
175        return HashMap::new();
176    }
177
178    let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
179    let mut centrality: HashMap<NodeIndex, f64> = HashMap::new();
180
181    for &source in &node_indices {
182        // BFS 计算从 source 到所有节点的最短距离
183        let mut dist: HashMap<NodeIndex, usize> =
184            node_indices.iter().map(|&ni| (ni, usize::MAX)).collect();
185        let mut queue = VecDeque::new();
186
187        dist.insert(source, 0);
188        queue.push_back(source);
189
190        while let Some(v) = queue.pop_front() {
191            let d = dist.get(&v).copied().unwrap_or(usize::MAX);
192
193            for w in graph.neighbors(v) {
194                if dist.get(&w).copied().unwrap_or(usize::MAX) == usize::MAX {
195                    dist.insert(w, d + 1);
196                    queue.push_back(w);
197                }
198            }
199        }
200
201        // 计算可达节点的总距离
202        let mut total_dist = 0usize;
203        let mut reachable = 0usize;
204
205        for &node in &node_indices {
206            if node != source {
207                let d = dist.get(&node).copied().unwrap_or(usize::MAX);
208                if d != usize::MAX {
209                    total_dist += d;
210                    reachable += 1;
211                }
212            }
213        }
214
215        // 接近中心性 = 可达节点数 / 总距离
216        let closeness = if total_dist > 0 {
217            reachable as f64 / total_dist as f64
218        } else if reachable == 0 {
219            0.0
220        } else {
221            1.0 // 所有节点距离都为 0(只有孤立点)
222        };
223
224        centrality.insert(source, closeness);
225    }
226
227    centrality
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::builders::GraphBuilder;

    /// Tolerance for comparing floating-point centrality scores.
    const EPS: f64 = 1e-9;

    #[test]
    fn test_pagerank_basic() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, 1.0), (2, 0, 1.0)])
            .build()
            .unwrap();

        let ranks = pagerank(&graph, 0.85, 20);
        assert_eq!(ranks.len(), 3);

        // On a directed 3-cycle every node gives away exactly what it receives, so the
        // uniform start vector is already the fixed point: every score stays at 1/3.
        for (&node, &rank) in &ranks {
            assert!(
                (rank - 1.0 / 3.0).abs() < 1e-6,
                "Node {:?} has rank {}",
                node,
                rank
            );
        }
    }

    #[test]
    fn test_degree_centrality() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C", "D"])
            .with_edges(vec![(0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0)])
            .build()
            .unwrap();

        let centrality = degree_centrality(&graph);
        assert_eq!(centrality.len(), 4);

        // A has out-degree 3 = n - 1, so its normalised centrality is exactly 1.
        let a_idx = graph.nodes().next().unwrap().index();
        assert!((centrality[&a_idx] - 1.0).abs() < EPS);

        // B, C and D have no outgoing edges.
        for node in graph.nodes().skip(1) {
            assert!(centrality[&node.index()].abs() < EPS);
        }
    }

    #[test]
    fn test_betweenness_centrality() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, 1.0)])
            .build()
            .unwrap();

        let centrality = betweenness_centrality(&graph);
        assert_eq!(centrality.len(), 3);

        let idx = |i: usize| graph.nodes().nth(i).unwrap().index();

        // B lies on the only A -> C shortest path: raw 1, normalised by 1 / (2 * 1).
        assert!((centrality[&idx(1)] - 0.5).abs() < EPS);
        // The endpoints A and C are never strictly between two other nodes.
        assert!(centrality[&idx(0)].abs() < EPS);
        assert!(centrality[&idx(2)].abs() < EPS);
    }

    #[test]
    fn test_closeness_centrality() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, 1.0)])
            .build()
            .unwrap();

        let centrality = closeness_centrality(&graph);
        assert_eq!(centrality.len(), 3);

        let idx = |i: usize| graph.nodes().nth(i).unwrap().index();

        // A reaches B (1 hop) and C (2 hops): 2 / 3.
        assert!((centrality[&idx(0)] - 2.0 / 3.0).abs() < EPS);
        // B reaches only C (1 hop): 1 / 1.
        assert!((centrality[&idx(1)] - 1.0).abs() < EPS);
        // C reaches nothing.
        assert!(centrality[&idx(2)].abs() < EPS);
    }

    #[test]
    fn test_betweenness_centrality_star() {
        // Star graph: every leaf-to-leaf shortest path goes through the centre.
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["center", "A", "B", "C"])
            .with_edges(vec![
                (1, 0, 1.0),
                (2, 0, 1.0),
                (3, 0, 1.0), // A, B, C -> center
                (0, 1, 1.0),
                (0, 2, 1.0),
                (0, 3, 1.0), // center -> A, B, C
            ])
            .build()
            .unwrap();

        let centrality = betweenness_centrality(&graph);
        assert_eq!(centrality.len(), 4);

        // The centre sits on all 3 * 2 = 6 ordered leaf pairs; normalised by 1 / (3 * 2).
        let center_idx = graph.nodes().next().unwrap().index();
        assert!((centrality[&center_idx] - 1.0).abs() < EPS);

        // Leaves are never intermediate nodes.
        for node in graph.nodes().skip(1) {
            assert!(centrality[&node.index()].abs() < EPS);
        }
    }

    #[test]
    fn test_closeness_centrality_connected() {
        // In a complete directed graph every other node is one hop away, so every
        // closeness score is exactly 1.
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C", "D"])
            .with_edges(vec![
                (0, 1, 1.0),
                (0, 2, 1.0),
                (0, 3, 1.0),
                (1, 0, 1.0),
                (1, 2, 1.0),
                (1, 3, 1.0),
                (2, 0, 1.0),
                (2, 1, 1.0),
                (2, 3, 1.0),
                (3, 0, 1.0),
                (3, 1, 1.0),
                (3, 2, 1.0),
            ])
            .build()
            .unwrap();

        let centrality = closeness_centrality(&graph);
        assert_eq!(centrality.len(), 4);

        for (&node, &cent) in &centrality {
            assert!(
                (cent - 1.0).abs() < EPS,
                "Node {:?} has closeness {}",
                node,
                cent
            );
        }
    }
}