use serde::{Deserialize, Serialize};
use crate::{
core::entities::VID,
db::{
api::state::{GenericNodeState, TypedNodeState},
graph::node::NodeView,
},
prelude::{GraphViewOps, NodeViewOps},
};
use std::collections::{HashMap, VecDeque};
/// Per-node result of [`betweenness_centrality`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct BetweennessCentrality {
    /// The node's betweenness score (scaled only if normalisation was requested).
    pub betweenness_centrality: f64,
}
/// Computes the betweenness centrality of every node in `g` using Brandes' algorithm.
///
/// Shortest paths are discovered by breadth-first search along *outgoing* edges, so all
/// edges have unit weight and edge direction is respected. For each source node the
/// number of shortest paths (`sigma`) to every reachable node is counted. Dependencies
/// (`delta`) are then accumulated in reverse BFS order.
///
/// # Arguments
///
/// * `g` - the graph (or graph view) to analyse.
/// * `k` - when `Some(k)`, only the first `k` nodes yielded by `g.nodes()` are used as
///   BFS sources. This is a deterministic prefix, not a random sample, and scores are
///   not rescaled by `n / k`. `None` uses every node as a source.
/// * `normalized` - when `true`, each score is multiplied by `1 / ((n - 1) * (n - 2))`,
///   where `n = g.count_nodes()`. Graphs with two or fewer nodes are left unscaled,
///   because all of their scores are zero.
///
/// # Returns
///
/// A [`TypedNodeState`] mapping each node to its [`BetweennessCentrality`].
pub fn betweenness_centrality<'graph, G: GraphViewOps<'graph>>(
    g: &G,
    k: Option<usize>,
    normalized: bool,
) -> TypedNodeState<'graph, BetweennessCentrality, G> {
    // Indexed by raw node id (VID) and sized to the unfiltered graph so every id fits.
    let mut betweenness: Vec<f64> = vec![0.0; g.unfiltered_num_nodes()];
    let nodes = g.nodes();
    let n = g.count_nodes();
    let k_sample = k.unwrap_or(n);

    for source in nodes.iter().take(k_sample) {
        let source_id = source.node.0;

        // BFS state. A node counts as visited iff it has an entry in `dist`.
        // A missing entry in `sigma` or `delta` is treated as 0.
        let mut dist: HashMap<usize, usize> = HashMap::from([(source_id, 0)]);
        let mut sigma: HashMap<usize, f64> = HashMap::from([(source_id, 1.0)]);
        let mut predecessors: HashMap<usize, Vec<usize>> = HashMap::new();
        let mut queue: VecDeque<usize> = VecDeque::from([source_id]);
        // Nodes in the order they were dequeued, i.e. non-decreasing distance from `source`.
        let mut stack: Vec<usize> = Vec::new();

        while let Some(current_id) = queue.pop_front() {
            stack.push(current_id);
            let current_dist = dist[&current_id];
            // Safe to read once: only nodes at distance `current_dist + 1` are updated below.
            let current_sigma = sigma[&current_id];

            for neighbour in
                NodeView::new_internal(g.clone(), VID::from(current_id)).out_neighbours()
            {
                let neighbour_id = neighbour.node.0;
                let neighbour_dist = match dist.get(&neighbour_id) {
                    Some(&d) => d,
                    None => {
                        // First discovery: record the distance and schedule the node for expansion.
                        dist.insert(neighbour_id, current_dist + 1);
                        queue.push_back(neighbour_id);
                        current_dist + 1
                    }
                };
                // `current_id` lies on a shortest path from `source` to `neighbour_id`.
                if neighbour_dist == current_dist + 1 {
                    *sigma.entry(neighbour_id).or_insert(0.0) += current_sigma;
                    predecessors
                        .entry(neighbour_id)
                        .or_default()
                        .push(current_id);
                }
            }
        }

        // Dependency accumulation. Popping yields the farthest nodes first, so each
        // node's `delta` is final by the time the node itself is popped.
        let mut delta: HashMap<usize, f64> = HashMap::new();
        while let Some(w) = stack.pop() {
            let delta_w = delta.get(&w).copied().unwrap_or(0.0);
            if let Some(preds) = predecessors.get(&w) {
                let sigma_w = sigma[&w];
                for &v in preds {
                    let coeff = (sigma[&v] / sigma_w) * (1.0 + delta_w);
                    *delta.entry(v).or_insert(0.0) += coeff;
                }
            }
            // The source is an endpoint, not an intermediary, of its own paths.
            if w != source_id {
                betweenness[w] += delta_w;
            }
        }
    }

    // For n <= 2 no node can lie strictly between two others, so every score is 0.
    // The factor, however, would be +inf (n == 2) or -inf (n == 1), turning those
    // zeros into NaN.
    if normalized && n > 2 {
        let factor = 1.0 / ((n as f64 - 1.0) * (n as f64 - 2.0));
        for node in nodes.iter() {
            betweenness[node.node.0] *= factor;
        }
    }

    TypedNodeState::new(GenericNodeState::new_from_eval_mapped(
        g.clone(),
        betweenness,
        |value| BetweennessCentrality {
            betweenness_centrality: value,
        },
        None,
    ))
}