use crate::node::NodeIndex;
use crate::plugins::algorithm::{
AlgorithmData, AlgorithmResult, FastHashMap, GraphAlgorithm, PluginContext, PluginInfo,
};
use crate::vgi::{Capability, GraphType, VgiResult, VirtualGraph};
use std::any::Any;
use std::collections::VecDeque;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Returns the raw `usize` slot of a node, the position used to index the
/// per-node scratch vectors (distances, path counts, dependencies) below.
#[inline]
fn node_to_idx(id: NodeIndex) -> usize {
    id.index()
}
/// Betweenness-centrality plugin based on Brandes' algorithm.
///
/// The only configuration is whether computed scores are normalized.
#[derive(Debug, Clone, Copy)]
pub struct BetweennessCentralityPlugin {
    /// `true` to report normalized scores, `false` for raw scores.
    normalized: bool,
}
impl BetweennessCentralityPlugin {
    /// Creates a plugin; `normalized` selects whether [`Self::compute`]
    /// rescales scores into `[0, 1]`.
    pub fn new(normalized: bool) -> Self {
        Self { normalized }
    }

    /// Creates a plugin that reports normalized scores.
    pub fn normalized() -> Self {
        Self { normalized: true }
    }

    /// Creates a plugin that reports raw (unnormalized) scores.
    pub fn unnormalized() -> Self {
        Self { normalized: false }
    }

    /// Computes the betweenness centrality of every node with Brandes' algorithm.
    ///
    /// Each source runs one BFS plus one dependency back-propagation. Edges are
    /// treated as unweighted. The returned map is keyed by raw node index
    /// (`NodeIndex::index()`).
    ///
    /// Dependencies are accumulated over ordered `(source, target)` pairs. For an
    /// undirected graph whose `neighbors` is symmetric, every unordered pair is
    /// therefore counted twice in the unnormalized score. When normalization is
    /// enabled and there are more than two nodes, scores are divided by
    /// `(n - 1) * (n - 2)`. That is the largest number of ordered pairs a single
    /// node can lie between, so both directed and undirected results land in
    /// `[0, 1]`.
    ///
    /// Node indices need not be contiguous: per-node scratch storage is sized by
    /// the largest index present, not by `node_count()`.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    pub fn compute<G>(&self, graph: &G) -> VgiResult<FastHashMap<usize, f64>>
    where
        G: VirtualGraph + ?Sized,
    {
        if graph.node_count() == 0 {
            return Ok(FastHashMap::default());
        }
        let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
        let bound = Self::index_bound(&node_indices);

        let mut centrality: Vec<f64> = vec![0.0; bound];
        for source_idx in &node_indices {
            let delta = Self::single_source_dependencies(graph, source_idx.index(), bound);
            for (acc, d) in centrality.iter_mut().zip(&delta) {
                *acc += *d;
            }
        }
        Ok(self.finalize(&node_indices, &centrality))
    }

    /// Parallel version of [`Self::compute`]: each source's BFS runs on the
    /// rayon thread pool, and per-source dependencies are summed with a
    /// parallel reduction.
    ///
    /// Scores follow the same definition and normalization as `compute`. The
    /// summation order differs, so results may differ from `compute` in the
    /// last floating-point bits.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    #[cfg(feature = "parallel")]
    pub fn compute_parallel<G>(&self, graph: &G) -> VgiResult<FastHashMap<usize, f64>>
    where
        G: VirtualGraph + Sync,
    {
        if graph.node_count() == 0 {
            return Ok(FastHashMap::default());
        }
        let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
        let bound = Self::index_bound(&node_indices);

        // Reduce instead of collecting every per-source vector, which would
        // need O(n^2) memory.
        let centrality: Vec<f64> = node_indices
            .par_iter()
            .map(|source_idx| Self::single_source_dependencies(graph, source_idx.index(), bound))
            .reduce(
                || vec![0.0; bound],
                |mut acc, delta| {
                    for (a, d) in acc.iter_mut().zip(&delta) {
                        *a += *d;
                    }
                    acc
                },
            );
        Ok(self.finalize(&node_indices, &centrality))
    }

    /// Returns up to `k` `(node index, score)` pairs with the highest scores.
    ///
    /// Results are sorted by descending score. Ties are broken by ascending node
    /// index, so the output does not depend on hash-map iteration order. If `k`
    /// is at least the node count, every node is returned (sorted).
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::compute`].
    pub fn top_k<G>(&self, graph: &G, k: usize) -> VgiResult<Vec<(usize, f64)>>
    where
        G: VirtualGraph + ?Sized,
    {
        let by_rank = |a: &(usize, f64), b: &(usize, f64)| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        };
        let mut nodes: Vec<(usize, f64)> = self.compute(graph)?.into_iter().collect();
        if k < nodes.len() {
            // Move the k best to the front in linear time, then sort only those.
            nodes.select_nth_unstable_by(k, by_rank);
            nodes.truncate(k);
        }
        nodes.sort_unstable_by(by_rank);
        Ok(nodes)
    }

    /// One past the largest raw node index. This is the length the per-node
    /// scratch vectors need when indexed by `NodeIndex::index()`.
    fn index_bound(node_indices: &[NodeIndex]) -> usize {
        node_indices
            .iter()
            .map(|idx| idx.index() + 1)
            .max()
            .unwrap_or(0)
    }

    /// Runs one Brandes single-source pass from `source`.
    ///
    /// Returns the dependency of `source` on every raw node index `< bound`.
    /// The entry for `source` itself is zeroed, so the vector can be added
    /// directly into the centrality accumulator.
    fn single_source_dependencies<G>(graph: &G, source: usize, bound: usize) -> Vec<f64>
    where
        G: VirtualGraph + ?Sized,
    {
        // Hop distance from `source`; -1 marks a node not reached yet.
        let mut dist: Vec<i64> = vec![-1; bound];
        // Number of shortest paths from `source` to each node.
        let mut sigma: Vec<f64> = vec![0.0; bound];
        let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); bound];
        let mut queue: VecDeque<usize> = VecDeque::new();
        // Nodes in BFS (non-decreasing distance) order; popped in reverse below.
        let mut stack: Vec<usize> = Vec::with_capacity(bound);

        dist[source] = 0;
        sigma[source] = 1.0;
        queue.push_back(source);
        while let Some(v) = queue.pop_front() {
            stack.push(v);
            for w_idx in graph.neighbors(NodeIndex::new_public(v)) {
                let w = node_to_idx(w_idx);
                if dist[w] < 0 {
                    dist[w] = dist[v] + 1;
                    queue.push_back(w);
                }
                if dist[w] == dist[v] + 1 {
                    sigma[w] += sigma[v];
                    predecessors[w].push(v);
                }
            }
        }

        // Back-propagate dependencies from the farthest nodes toward the source.
        let mut delta: Vec<f64> = vec![0.0; bound];
        while let Some(w) = stack.pop() {
            let sigma_w = sigma[w];
            for &v in &predecessors[w] {
                if sigma_w > 0.0 {
                    delta[v] += (sigma[v] / sigma_w) * (1.0 + delta[w]);
                }
            }
        }
        // A node is not "between" paths that start at itself.
        delta[source] = 0.0;
        delta
    }

    /// Converts accumulated raw scores (indexed by raw node index) into the
    /// public map, applying normalization when enabled and `n > 2`.
    fn finalize(&self, node_indices: &[NodeIndex], centrality: &[f64]) -> FastHashMap<usize, f64> {
        let n = node_indices.len();
        // Ordered-pair accumulation makes (n - 1)(n - 2) the correct divisor for
        // both directed and undirected graphs. Convert to f64 before multiplying
        // to avoid usize overflow.
        let scale = if self.normalized && n > 2 {
            (n - 1) as f64 * (n - 2) as f64
        } else {
            1.0
        };
        let mut result = FastHashMap::default();
        result.reserve(n);
        for idx in node_indices {
            let raw = idx.index();
            result.insert(raw, centrality[raw] / scale);
        }
        result
    }
}
impl Default for BetweennessCentralityPlugin {
    /// The default plugin reports normalized scores.
    fn default() -> Self {
        Self::new(true)
    }
}
impl GraphAlgorithm for BetweennessCentralityPlugin {
    /// Describes the plugin: identifier, version, description, author,
    /// required capabilities, supported graph types and search tags.
    fn info(&self) -> PluginInfo {
        PluginInfo::new(
            "betweenness-centrality",
            "1.0.0",
            // Description: "Betweenness centrality algorithm (Brandes algorithm)".
            "介数中心性算法(Brandes 算法)",
        )
        .with_author("God-Graph Team")
        // NOTE(review): the computation only reads `nodes()`/`neighbors()`.
        // Confirm that `IncrementalUpdate` is really required, since it limits
        // which graphs this plugin can be scheduled on.
        .with_required_capabilities(&[Capability::IncrementalUpdate])
        .with_supported_graph_types(&[GraphType::Directed, GraphType::Undirected])
        .with_tags(&["centrality", "betweenness", "importance", "bridge"])
    }

    /// Runs betweenness centrality on `ctx.graph`.
    ///
    /// Config keys:
    /// - `normalized`: the string `"true"` enables normalization; any other
    ///   value disables it. This instance's own setting is passed to
    ///   `get_config_or` as the fallback.
    /// - `top_k` (usize, fallback `0`): when non-zero, the result is a node list
    ///   of the `top_k` best node ids, with their scores in the `scores`
    ///   metadata. Otherwise the result holds every node's score.
    ///
    /// # Errors
    ///
    /// Propagates errors from `compute` / `top_k`.
    fn execute<G>(&self, ctx: &mut PluginContext<G>) -> VgiResult<AlgorithmResult>
    where
        G: VirtualGraph + ?Sized,
    {
        // Honour the instance's configuration instead of always defaulting to
        // normalized output.
        let default_normalized = if self.normalized { "true" } else { "false" };
        let normalized = ctx.get_config_or("normalized", default_normalized) == "true";
        let top_k = ctx.get_config_as("top_k", 0usize);
        let plugin = BetweennessCentralityPlugin::new(normalized);

        ctx.report_progress(0.1);
        let result = if top_k > 0 {
            let top_nodes = plugin.top_k(ctx.graph, top_k)?;
            let nodes: Vec<usize> = top_nodes.iter().map(|&(id, _)| id).collect();
            let scores: FastHashMap<usize, f64> = top_nodes.into_iter().collect();
            AlgorithmResult::new("betweenness_top_k", AlgorithmData::NodeList(nodes))
                .with_metadata("top_k", top_k.to_string())
                .with_metadata("scores", format!("{:?}", scores))
        } else {
            // `centrality` is not used afterwards, so move it instead of cloning.
            let centrality = plugin.compute(ctx.graph)?;
            AlgorithmResult::new(
                "betweenness_centrality",
                AlgorithmData::NodeValues(centrality),
            )
        }
        .with_metadata("normalized", normalized.to_string())
        .with_metadata("algorithm", "betweenness-centrality")
        .with_metadata("node_count", ctx.graph.node_count().to_string());
        ctx.report_progress(1.0);
        Ok(result)
    }

    /// Exposes `self` as `Any` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::traits::GraphOps;
    use crate::graph::Graph;

    /// Path A - B - C - D. Nodes are added in that order, and each edge is
    /// inserted in both directions.
    fn create_line_graph() -> Graph<String, ()> {
        let mut graph = Graph::<String, ()>::undirected();
        let ids: Vec<NodeIndex> = ["A", "B", "C", "D"]
            .iter()
            .map(|label| graph.add_node(label.to_string()).unwrap())
            .collect();
        for hop in ids.windows(2) {
            let (from, to) = (hop[0], hop[1]);
            graph.add_edge(from, to, ()).unwrap();
            graph.add_edge(to, from, ()).unwrap();
        }
        graph
    }

    /// Star with one hub ("Center") and five leaves. All nodes are added first,
    /// then hub <-> leaf edges in both directions.
    fn create_star_graph() -> Graph<String, ()> {
        let mut graph = Graph::<String, ()>::undirected();
        let hub = graph.add_node("Center".to_string()).unwrap();
        let spokes: Vec<NodeIndex> = (0..5)
            .map(|i| graph.add_node(format!("Leaf{}", i)).unwrap())
            .collect();
        spokes.iter().for_each(|&spoke| {
            graph.add_edge(hub, spoke, ()).unwrap();
            graph.add_edge(spoke, hub, ()).unwrap();
        });
        graph
    }

    /// Score of `node`, treating a missing entry as zero.
    fn score(centrality: &FastHashMap<usize, f64>, node: usize) -> f64 {
        centrality.get(&node).copied().unwrap_or(0.0)
    }

    #[test]
    fn test_betweenness_line_graph() {
        let centrality = BetweennessCentralityPlugin::unnormalized()
            .compute(&create_line_graph())
            .unwrap();
        // The endpoints lie on no shortest path between two other nodes.
        assert_eq!(score(&centrality, 0), 0.0);
        assert_eq!(score(&centrality, 3), 0.0);
        let (inner_b, inner_c) = (score(&centrality, 1), score(&centrality, 2));
        assert!(inner_b > 0.0);
        assert!(inner_c > 0.0);
        // The path is symmetric, so both inner nodes score the same.
        assert!((inner_b - inner_c).abs() < 1e-10);
    }

    #[test]
    fn test_betweenness_star_graph() {
        let centrality = BetweennessCentralityPlugin::unnormalized()
            .compute(&create_star_graph())
            .unwrap();
        for leaf in 1..=5 {
            assert_eq!(score(&centrality, leaf), 0.0, "leaf {} should score zero", leaf);
        }
        assert!(score(&centrality, 0) > 0.0);
    }

    #[test]
    fn test_betweenness_empty_graph() {
        let empty = Graph::<String, ()>::undirected();
        let centrality = BetweennessCentralityPlugin::default().compute(&empty).unwrap();
        assert!(centrality.is_empty());
    }

    #[test]
    fn test_betweenness_normalized() {
        let centrality = BetweennessCentralityPlugin::normalized()
            .compute(&create_line_graph())
            .unwrap();
        assert!(centrality.values().all(|v| (0.0..=1.0).contains(v)));
    }

    #[test]
    fn test_betweenness_top_k() {
        let graph = create_line_graph();
        let top_ids: Vec<usize> = BetweennessCentralityPlugin::unnormalized()
            .top_k(&graph, 2)
            .unwrap()
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert_eq!(top_ids.len(), 2);
        assert!(top_ids.contains(&1));
        assert!(top_ids.contains(&2));
    }

    #[test]
    fn test_betweenness_plugin_info() {
        let info = BetweennessCentralityPlugin::default().info();
        assert_eq!(info.name, "betweenness-centrality");
        assert_eq!(info.version, "1.0.0");
        assert!(info.tags.contains(&String::from("centrality")));
    }
}