use std::collections::VecDeque;
use super::params::AlgoParams;
use super::result::AlgoResultBatch;
use crate::engine::graph::algo::GraphAlgorithm;
use crate::engine::graph::csr::CsrIndex;
/// Computes betweenness centrality for every node of `csr`.
///
/// One single-source pass (`brandes_from_source`) is run per source node and the
/// per-node dependencies are summed. Adjacency is traversed as undirected by that
/// pass; apart from the sampling extrapolation described below, no normalization
/// is applied to the summed scores.
///
/// # Sampling
///
/// When `params.sample_size` is `Some(k)` with `k < node_count`, only `k` distinct
/// source nodes are used. They are chosen by a deterministic generator seeded from
/// the node count, so two runs over graphs with the same node count pick the same
/// source indices. The accumulated scores are then multiplied by `node_count / k`
/// to extrapolate to the full source set.
///
/// `Some(0)` selects no sources: every score is reported as `0.0` and no scaling
/// is applied (scaling by `node_count / 0` would turn every score into NaN).
///
/// # Returns
///
/// A batch tagged `GraphAlgorithm::Betweenness` with one `(node name, score)` row
/// per node, ordered by descending score; ties keep ascending node-index order.
/// An empty graph yields an empty batch.
pub fn run(csr: &CsrIndex, params: &AlgoParams) -> AlgoResultBatch {
    let n = csr.node_count();
    if n == 0 {
        return AlgoResultBatch::new(GraphAlgorithm::Betweenness);
    }

    let sources: Vec<usize> = match params.sample_size {
        Some(sample) if sample < n => sample_sources(n, sample),
        _ => (0..n).collect(),
    };

    let mut cb = vec![0.0f64; n];
    for &s in &sources {
        brandes_from_source(csr, s, n, &mut cb);
    }

    // `sources.len() < n` can only hold on the sampling path. The emptiness guard
    // keeps a zero-sized sample from producing an infinite scale factor.
    if !sources.is_empty() && sources.len() < n {
        let scale = n as f64 / sources.len() as f64;
        for c in cb.iter_mut() {
            *c *= scale;
        }
    }

    let mut scored: Vec<(usize, f64)> = cb.into_iter().enumerate().collect();
    // `total_cmp` is a total order even if a score is NaN, so the comparator always
    // satisfies the contract `sort_by` requires. The sort is stable.
    scored.sort_by(|a, b| b.1.total_cmp(&a.1));

    let mut batch = AlgoResultBatch::new(GraphAlgorithm::Betweenness);
    for (node, centrality) in scored {
        // NOTE(review): assumes node indices fit in `u32`, as `node_name` takes one.
        batch.push_node_f64(csr.node_name(node as u32).to_string(), centrality);
    }
    batch
}

/// Picks `sample` distinct node indices in `0..n`, in the order they are drawn.
///
/// Indices come from a 64-bit linear congruential generator seeded from `n`, so the
/// selection is fully determined by `(n, sample)`. Already-drawn indices are
/// rejected and redrawn. `sample` must not exceed `n`, otherwise the loop could
/// never collect enough distinct indices.
fn sample_sources(n: usize, sample: usize) -> Vec<usize> {
    debug_assert!(sample <= n, "cannot pick more distinct nodes than exist");

    let mut state: u64 = (n as u64).wrapping_mul(0x517cc1b727220a95).wrapping_add(1);
    let mut selected = Vec::with_capacity(sample);
    let mut used = vec![false; n];
    while selected.len() < sample {
        state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1);
        // Only the upper bits of the state are used to derive the index.
        let idx = (state >> 33) as usize % n;
        if !used[idx] {
            used[idx] = true;
            selected.push(idx);
        }
    }
    selected
}
/// Runs one single-source pass of Brandes' algorithm from `s` and adds the
/// dependency of `s` on each reached node into `cb`.
///
/// * `csr` - graph whose adjacency is read through `undirected_neighbors`.
/// * `s` - index of the source node.
/// * `n` - length of the per-node scratch vectors. Neighbor ids are used as indices
///   into them, so they are assumed to be `< n`; `run` passes `csr.node_count()`.
/// * `cb` - running centrality totals indexed by node. Only nodes reached from `s`,
///   other than `s` itself, are updated.
fn brandes_from_source(csr: &CsrIndex, s: usize, n: usize, cb: &mut [f64]) {
    // -1 marks a node the search has not reached yet.
    let mut dist = vec![-1i64; n];
    // Number of shortest paths from `s` to each node.
    let mut sigma = vec![0.0f64; n];
    // For each node, the neighbors that precede it on some shortest path from `s`.
    let mut preds: Vec<Vec<usize>> = vec![Vec::new(); n];
    // Nodes in the order the search dequeues them (non-decreasing distance).
    let mut visit_order: Vec<usize> = Vec::with_capacity(n);

    dist[s] = 0;
    sigma[s] = 1.0;

    // Phase 1: breadth-first search counting shortest paths.
    let mut frontier = VecDeque::new();
    frontier.push_back(s);
    while let Some(v) = frontier.pop_front() {
        visit_order.push(v);
        let next_level = dist[v] + 1;
        for w in undirected_neighbors(csr, v as u32) {
            let w = w as usize;
            if dist[w] < 0 {
                // First time `w` is seen: it sits one level below `v`.
                dist[w] = next_level;
                frontier.push_back(w);
            }
            if dist[w] == next_level {
                // `v` lies on a shortest path to `w`.
                sigma[w] += sigma[v];
                preds[w].push(v);
            }
        }
    }

    // Phase 2: walk nodes farthest-first, pushing each node's dependency back onto
    // its predecessors.
    let mut delta = vec![0.0f64; n];
    for &w in visit_order.iter().rev() {
        // Predecessors are strictly closer to `s` than `w`, so `delta[w]` is final here.
        let through_w = 1.0 + delta[w];
        for &v in &preds[w] {
            delta[v] += (sigma[v] / sigma[w]) * through_w;
        }
        // The source itself never counts as an intermediate node.
        if w != s {
            cb[w] += delta[w];
        }
    }
}
use super::util::undirected_neighbors;
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Serializes `batch` to JSON and returns each row's `centrality` keyed by its
    /// `node_id`.
    fn centrality_by_node(batch: AlgoResultBatch) -> HashMap<String, f64> {
        let json = batch.to_json().unwrap();
        let rows: Vec<serde_json::Value> = serde_json::from_slice(&json).unwrap();
        rows.iter()
            .map(|r| {
                (
                    r["node_id"].as_str().unwrap().to_string(),
                    r["centrality"].as_f64().unwrap(),
                )
            })
            .collect()
    }

    /// Path a - b - c - d: inner nodes outrank the endpoints, which score zero.
    #[test]
    fn betweenness_path_graph() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b");
        csr.add_edge("b", "L", "c");
        csr.add_edge("c", "L", "d");
        csr.compact();

        let map = centrality_by_node(run(&csr, &AlgoParams::default()));

        assert!(map["b"] > map["a"]);
        assert!(map["c"] > map["d"]);
        assert!(map["a"].abs() < 1e-9);
        assert!(map["d"].abs() < 1e-9);
    }

    /// Triangle with edges in both directions: no node lies between two others.
    #[test]
    fn betweenness_triangle() {
        let mut csr = CsrIndex::new();
        for (s, d) in &[
            ("a", "b"),
            ("b", "a"),
            ("b", "c"),
            ("c", "b"),
            ("a", "c"),
            ("c", "a"),
        ] {
            csr.add_edge(s, "L", d);
        }
        csr.compact();

        let map = centrality_by_node(run(&csr, &AlgoParams::default()));

        for (node, &c) in &map {
            assert!(
                c.abs() < 1e-9,
                "node {} has BC {c}, expected 0",
                node
            );
        }
    }

    /// Star centered on `a`: the hub outranks the leaves, which score zero.
    #[test]
    fn betweenness_star() {
        let mut csr = CsrIndex::new();
        csr.add_edge("a", "L", "b");
        csr.add_edge("a", "L", "c");
        csr.add_edge("a", "L", "d");
        csr.compact();

        let map = centrality_by_node(run(&csr, &AlgoParams::default()));

        assert!(map["a"] > map["b"]);
        assert!(map["b"].abs() < 1e-9);
    }

    /// Sampling 5 of 21 sources still reports one finite score per node.
    #[test]
    fn betweenness_with_sampling() {
        let mut csr = CsrIndex::new();
        for i in 0..20 {
            csr.add_edge(&format!("n{i}"), "L", &format!("n{}", i + 1));
        }
        csr.compact();

        let params = AlgoParams {
            sample_size: Some(5),
            ..Default::default()
        };
        let batch = run(&csr, &params);
        assert_eq!(batch.len(), 21);

        let map = centrality_by_node(batch);
        assert!(map.values().all(|c| c.is_finite()));
    }

    /// An empty graph produces an empty batch.
    #[test]
    fn betweenness_empty() {
        let csr = CsrIndex::new();
        assert!(run(&csr, &AlgoParams::default()).is_empty());
    }
}