use crate::core::error::{GraphinaError, Result};
use crate::core::types::{BaseGraph, GraphConstructor, NodeId, NodeMap};
use nalgebra::{DMatrix, DVector};
/// Computes Katz centrality by fixed-point iteration of
/// `x = alpha * A * x + beta`, starting from the zero vector.
///
/// `A` is a dense `n x n` weighted adjacency matrix in which `A[u][v]`
/// accumulates the weights of all edges `u -> v` (parallel edges add up).
/// For undirected graphs each edge is mirrored into `A[v][u]` as well, so the
/// matrix is symmetric.
///
/// NOTE(review): with this orientation a node's score is fed by the nodes it
/// points *to*. Many references define Katz centrality over incoming edges
/// (i.e. they iterate with the transpose of `A`) — confirm which convention
/// callers of this function expect for directed graphs.
///
/// # Arguments
///
/// * `graph` - Graph whose nodes are scored.
/// * `alpha` - Attenuation factor multiplied into the adjacency matrix.
///   The value is not validated here; presumably it has to be small enough
///   for the iteration to settle, otherwise the call ends in the error below.
/// * `beta` - Optional per-node base score. When `None`, every node gets `1.0`.
/// * `max_iter` - Upper bound on the number of iterations.
/// * `tolerance` - The iteration stops as soon as the Euclidean norm of the
///   difference between two successive iterates is strictly below this value.
///
/// # Returns
///
/// A map from every node to its score. An empty graph yields an empty map.
///
/// # Errors
///
/// Returns a convergence-failed [`GraphinaError`] when the change between
/// iterates never drops below `tolerance` within `max_iter` iterations
/// (this includes `max_iter == 0` on a non-empty graph).
pub fn katz_centrality<A, W, Ty>(
    graph: &BaseGraph<A, W, Ty>,
    alpha: f64,
    beta: Option<&dyn Fn(NodeId) -> f64>,
    max_iter: usize,
    tolerance: f64,
) -> Result<NodeMap<f64>>
where
    W: Copy + PartialOrd + Into<f64>,
    Ty: GraphConstructor<A, W>,
{
    let n = graph.node_count();
    if n == 0 {
        return Ok(NodeMap::new());
    }

    // Fix an ordering of the nodes so they can be addressed by matrix index.
    let node_list: Vec<NodeId> = graph.nodes().map(|(node, _)| node).collect();
    let node_to_idx: std::collections::HashMap<NodeId, usize> = node_list
        .iter()
        .enumerate()
        .map(|(idx, &node)| (node, idx))
        .collect();

    // Build the weighted adjacency matrix. An undirected edge must be visible
    // from both endpoints; a self-loop already sits on the diagonal and must
    // not be counted twice.
    let undirected = !graph.is_directed();
    let mut adj = DMatrix::<f64>::zeros(n, n);
    for (u, v, w) in graph.edges() {
        let ui = node_to_idx[&u];
        let vi = node_to_idx[&v];
        let weight: f64 = (*w).into();
        adj[(ui, vi)] += weight;
        if undirected && ui != vi {
            adj[(vi, ui)] += weight;
        }
    }

    // Scale once up front instead of re-scaling the whole matrix on every
    // iteration; the per-iteration arithmetic is otherwise unchanged.
    let scaled_adj = alpha * adj;

    let beta_vec = match beta {
        Some(b) => DVector::from_fn(n, |idx, _| b(node_list[idx])),
        None => DVector::from_element(n, 1.0),
    };

    let mut x = DVector::<f64>::zeros(n);
    let mut converged = false;
    for _ in 0..max_iter {
        let x_new = &scaled_adj * &x + &beta_vec;
        let delta = (&x_new - &x).norm();
        x = x_new;
        if delta < tolerance {
            converged = true;
            break;
        }
    }

    if !converged {
        return Err(GraphinaError::convergence_failed(
            max_iter,
            "Katz centrality failed to converge within maximum iterations",
        ));
    }

    let mut centrality = NodeMap::new();
    for (&node, &score) in node_list.iter().zip(x.iter()) {
        centrality.insert(node, score);
    }
    Ok(centrality)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::types::{Digraph, Graph};

    /// All nodes of a directed 3-cycle are structurally equivalent, so their
    /// scores must agree up to the comparison tolerance.
    #[test]
    fn test_katz_simple_directed() {
        let mut cycle: Digraph<i32, f64> = Digraph::new();
        let a = cycle.add_node(1);
        let b = cycle.add_node(2);
        let c = cycle.add_node(3);
        for (src, dst) in [(a, b), (b, c), (c, a)] {
            cycle.add_edge(src, dst, 1.0);
        }

        let scores = katz_centrality(&cycle, 0.1, None, 100, 1e-6).unwrap();

        let (score_a, score_b, score_c) = (scores[&a], scores[&b], scores[&c]);
        assert!((score_a - score_b).abs() < 1e-3);
        assert!((score_b - score_c).abs() < 1e-3);
    }

    /// A graph without nodes produces an empty result instead of an error.
    #[test]
    fn test_katz_empty_graph() {
        let empty: Graph<i32, f64> = Graph::new();

        let scores = katz_centrality(&empty, 0.1, None, 100, 1e-6).unwrap();

        assert!(scores.is_empty());
    }

    /// A larger base score for one endpoint of a single edge must leave that
    /// endpoint with the higher centrality.
    #[test]
    fn test_katz_with_beta() {
        let mut pair: Graph<i32, f64> = Graph::new();
        let favoured = pair.add_node(1);
        let other = pair.add_node(2);
        pair.add_edge(favoured, other, 1.0);

        let base_score = |node: NodeId| {
            if node == favoured {
                2.0
            } else {
                1.0
            }
        };
        let scores = katz_centrality(&pair, 0.1, Some(&base_score), 100, 1e-6).unwrap();

        assert!(scores[&favoured] > scores[&other]);
    }
}