use crate::graph::GraphRef;
use crate::{Error, Result};
/// Parameters for Katz centrality computed by power iteration.
///
/// Every node starts at `beta`; each sweep sets
/// `x_next[v] = beta + alpha * (sum of x[u] over edges u -> v)`.
/// Call [`KatzConfig::validate`] (or use `katz_centrality_checked`) to reject
/// out-of-range values before running.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KatzConfig {
    /// Attenuation applied per hop. `validate` requires it to be finite,
    /// `> 0` and `< 1`. A tighter bound `alpha < 1 / spectral_radius(A)` may
    /// be needed for the iteration to converge; that bound is not checked.
    pub alpha: f64,
    /// Baseline score added to every node on every sweep (and the initial
    /// score of every node). `validate` requires it to be finite.
    pub beta: f64,
    /// Maximum number of sweeps to perform. `validate` requires it to be `> 0`.
    pub max_iterations: usize,
    /// Iteration stops once the L1 distance between two successive score
    /// vectors is strictly below this value. `validate` requires it to be
    /// finite and `> 0`.
    pub tolerance: f64,
}
impl Default for KatzConfig {
    /// Returns `alpha = 0.1`, `beta = 1.0`, at most 100 sweeps and an L1
    /// tolerance of `1e-6`. These values pass [`KatzConfig::validate`].
    fn default() -> Self {
        let alpha = 0.1;
        let beta = 1.0;
        let max_iterations = 100;
        let tolerance = 1e-6;
        Self {
            alpha,
            beta,
            max_iterations,
            tolerance,
        }
    }
}
impl KatzConfig {
pub fn validate(&self) -> Result<()> {
if !self.alpha.is_finite() || self.alpha <= 0.0 {
return Err(Error::InvalidParameter(
"alpha must be finite and > 0".to_string(),
));
}
if self.alpha >= 1.0 {
return Err(Error::InvalidParameter(
"alpha must be < 1 (required for convergence; tighter bound \
alpha < 1/spectral_radius(A) may be needed for large graphs)"
.to_string(),
));
}
if !self.beta.is_finite() {
return Err(Error::InvalidParameter("beta must be finite".to_string()));
}
if self.max_iterations == 0 {
return Err(Error::InvalidParameter(
"max_iterations must be > 0".to_string(),
));
}
if !self.tolerance.is_finite() || self.tolerance <= 0.0 {
return Err(Error::InvalidParameter(
"tolerance must be finite and > 0".to_string(),
));
}
Ok(())
}
}
/// Output of a Katz power iteration, together with convergence diagnostics.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KatzRun {
    /// One score per node; `scores[v]` belongs to node `v`.
    pub scores: Vec<f64>,
    /// Number of sweeps actually performed (`0` for an empty graph or when
    /// `max_iterations == 0`).
    pub iterations: usize,
    /// L1 distance between the last two iterates. It is `0.0` for an empty
    /// graph and `f64::INFINITY` if no sweep ran on a non-empty graph.
    pub diff_l1: f64,
    /// `true` if `diff_l1` dropped strictly below the configured tolerance
    /// (always `true` for an empty graph).
    pub converged: bool,
}
pub fn katz_centrality<G: GraphRef>(graph: &G, config: KatzConfig) -> Vec<f64> {
katz_centrality_run(graph, config).scores
}
/// Runs Katz centrality by power iteration and reports convergence diagnostics.
///
/// Every node starts at `config.beta`. Each sweep resets the next iterate to
/// `beta`. Then, for every node `u` and every `v` in
/// `graph.neighbors_ref(u)`, it adds `alpha * x[u]` to `x_next[v]`, so a node
/// is credited through the edges that point into it. Sweeping stops as soon
/// as the L1 distance between two successive iterates is strictly below
/// `config.tolerance`, or after `config.max_iterations` sweeps.
///
/// `config` is not validated here (see [`katz_centrality_checked`]).
/// Neighbor indices `>= graph.node_count()` are silently skipped.
///
/// Edge cases:
/// * empty graph: no scores, `iterations == 0`, `diff_l1 == 0.0`,
///   `converged == true`;
/// * `max_iterations == 0`: every score is `beta`, `iterations == 0`,
///   `diff_l1 == f64::INFINITY`, `converged == false`.
pub fn katz_centrality_run<G: GraphRef>(graph: &G, config: KatzConfig) -> KatzRun {
    let node_count = graph.node_count();
    if node_count == 0 {
        return KatzRun {
            scores: Vec::new(),
            iterations: 0,
            diff_l1: 0.0,
            converged: true,
        };
    }

    let alpha = config.alpha;
    let beta = config.beta;
    let tolerance = config.tolerance;

    // Resolve every adjacency slice once instead of on each sweep.
    let adjacency: Vec<&[usize]> = (0..node_count).map(|u| graph.neighbors_ref(u)).collect();

    let mut current = vec![beta; node_count];
    let mut next = vec![0.0_f64; node_count];
    let mut sweeps = 0usize;
    let mut diff_l1 = f64::INFINITY;
    let mut converged = false;

    while sweeps < config.max_iterations {
        sweeps += 1;

        next.fill(beta);
        for (&score, targets) in current.iter().zip(&adjacency) {
            let contrib = alpha * score;
            // Out-of-range targets are dropped rather than indexed.
            for &v in targets.iter().filter(|&&v| v < node_count) {
                next[v] += contrib;
            }
        }

        #[cfg(feature = "simd")]
        let step: f64 = innr::dense_f64::l1_distance_f64(&current, &next);
        #[cfg(not(feature = "simd"))]
        let step: f64 = current
            .iter()
            .zip(&next)
            .map(|(prev, curr)| (prev - curr).abs())
            .sum();

        diff_l1 = step;
        // After the swap `current` holds the newest iterate and `next` is scratch.
        std::mem::swap(&mut current, &mut next);
        if step < tolerance {
            converged = true;
            break;
        }
    }

    KatzRun {
        scores: current,
        iterations: sweeps,
        diff_l1,
        converged,
    }
}
/// Validates `config`, then computes Katz centrality scores.
///
/// Passing validation does not guarantee convergence. Use
/// [`katz_centrality_run`] when the convergence diagnostics matter.
///
/// # Errors
///
/// Returns the error produced by [`KatzConfig::validate`] if any parameter
/// is out of range. In that case no iteration is performed.
pub fn katz_centrality_checked<G: GraphRef>(graph: &G, config: KatzConfig) -> Result<Vec<f64>> {
    config
        .validate()
        .map(|()| katz_centrality(graph, config))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphRef;
    use proptest::prelude::*;

    /// Minimal adjacency-list graph used as a `GraphRef` test fixture.
    struct VecGraph {
        adj: Vec<Vec<usize>>,
    }

    impl GraphRef for VecGraph {
        fn node_count(&self) -> usize {
            self.adj.len()
        }
        fn neighbors_ref(&self, node: usize) -> &[usize] {
            &self.adj[node]
        }
    }

    #[test]
    fn path_4_nodes_center_outscores_endpoints() {
        // Undirected path 0 - 1 - 2 - 3.
        let g = VecGraph {
            adj: vec![vec![1], vec![0, 2], vec![1, 3], vec![2]],
        };
        let scores = katz_centrality(&g, KatzConfig::default());
        assert_eq!(scores.len(), 4);
        assert!(
            scores[1] > scores[0],
            "scores[1]={} should exceed scores[0]={}",
            scores[1],
            scores[0]
        );
        assert!(
            scores[2] > scores[3],
            "scores[2]={} should exceed scores[3]={}",
            scores[2],
            scores[3]
        );
    }

    #[test]
    fn triangle_clique_all_equal() {
        let g = VecGraph {
            adj: vec![vec![1, 2], vec![0, 2], vec![0, 1]],
        };
        let run = katz_centrality_run(&g, KatzConfig::default());
        assert!(run.converged, "triangle should converge");
        let s0 = run.scores[0];
        for (i, &s) in run.scores.iter().enumerate() {
            assert!(
                (s - s0).abs() < 1e-6,
                "scores[{i}]={s} diverges from scores[0]={s0}"
            );
        }
    }

    #[test]
    fn triangle_matches_closed_form() {
        // On a k-regular graph the fixed point is beta / (1 - alpha * k).
        // The triangle is 2-regular, so with the defaults every score is 1 / 0.8.
        let g = VecGraph {
            adj: vec![vec![1, 2], vec![0, 2], vec![0, 1]],
        };
        let run = katz_centrality_run(&g, KatzConfig::default());
        assert!(run.converged);
        for &s in &run.scores {
            assert!((s - 1.25).abs() < 1e-5, "score {s} should be ~1.25");
        }
    }

    #[test]
    fn directed_edge_credits_target() {
        // Only edge 0 -> 1: node 0 has no incoming edges and stays at beta,
        // node 1 settles at beta + alpha * beta.
        let g = VecGraph {
            adj: vec![vec![1], Vec::new()],
        };
        let run = katz_centrality_run(&g, KatzConfig::default());
        assert!(run.converged);
        assert_eq!(run.iterations, 2);
        assert!((run.scores[0] - 1.0).abs() < 1e-12);
        assert!((run.scores[1] - 1.1).abs() < 1e-12);
    }

    #[test]
    fn edgeless_graph_scores_beta_after_one_sweep() {
        let g = VecGraph {
            adj: vec![Vec::new(); 3],
        };
        let cfg = KatzConfig {
            beta: 2.5,
            ..KatzConfig::default()
        };
        let run = katz_centrality_run(&g, cfg);
        assert!(run.converged);
        assert_eq!(run.iterations, 1);
        assert_eq!(run.scores, vec![2.5; 3]);
        assert_eq!(run.diff_l1, 0.0);
    }

    #[test]
    fn out_of_range_neighbors_are_ignored() {
        let g = VecGraph {
            adj: vec![vec![7], Vec::new()],
        };
        let run = katz_centrality_run(&g, KatzConfig::default());
        assert!(run.converged);
        assert_eq!(run.scores, vec![1.0, 1.0]);
    }

    #[test]
    fn empty_graph_returns_empty_converged_run() {
        let g = VecGraph { adj: Vec::new() };
        let run = katz_centrality_run(&g, KatzConfig::default());
        assert!(run.scores.is_empty());
        assert_eq!(run.iterations, 0);
        assert_eq!(run.diff_l1, 0.0);
        assert!(run.converged);
    }

    #[test]
    fn zero_max_iterations_unchecked_returns_beta_unconverged() {
        let g = VecGraph {
            adj: vec![vec![1], vec![0]],
        };
        let cfg = KatzConfig {
            max_iterations: 0,
            ..KatzConfig::default()
        };
        let run = katz_centrality_run(&g, cfg);
        assert_eq!(run.scores, vec![1.0, 1.0]);
        assert_eq!(run.iterations, 0);
        assert!(run.diff_l1.is_infinite());
        assert!(!run.converged);
    }

    #[test]
    fn divergent_alpha_reports_not_converged() {
        // Two mutually linked nodes have spectral radius 1, so alpha = 2 diverges
        // (scores go 1, 3, 7, 15, 31, 63).
        let g = VecGraph {
            adj: vec![vec![1], vec![0]],
        };
        let cfg = KatzConfig {
            alpha: 2.0,
            max_iterations: 5,
            ..KatzConfig::default()
        };
        let run = katz_centrality_run(&g, cfg);
        assert!(!run.converged);
        assert_eq!(run.iterations, 5);
        assert!(run.scores.iter().all(|s| s.is_finite() && *s > 1.0));
    }

    #[test]
    fn default_config_validates() {
        assert!(KatzConfig::default().validate().is_ok());
    }

    #[test]
    fn alpha_geq_one_rejected_by_checked() {
        let g = VecGraph {
            adj: vec![vec![1], vec![0]],
        };
        let cfg = KatzConfig {
            alpha: 1.0,
            ..KatzConfig::default()
        };
        assert!(
            katz_centrality_checked(&g, cfg).is_err(),
            "alpha=1.0 should return Err"
        );
        let cfg_large = KatzConfig {
            alpha: 2.5,
            ..KatzConfig::default()
        };
        assert!(
            katz_centrality_checked(&g, cfg_large).is_err(),
            "alpha=2.5 should return Err"
        );
    }

    #[test]
    fn checked_rejects_each_invalid_parameter() {
        let g = VecGraph {
            adj: vec![vec![1], vec![0]],
        };
        let base = KatzConfig::default();
        let bad = [
            ("alpha = 0", KatzConfig { alpha: 0.0, ..base }),
            ("alpha < 0", KatzConfig { alpha: -0.1, ..base }),
            ("alpha NaN", KatzConfig { alpha: f64::NAN, ..base }),
            ("beta inf", KatzConfig { beta: f64::INFINITY, ..base }),
            ("max_iterations = 0", KatzConfig { max_iterations: 0, ..base }),
            ("tolerance = 0", KatzConfig { tolerance: 0.0, ..base }),
            ("tolerance < 0", KatzConfig { tolerance: -1e-6, ..base }),
            ("tolerance NaN", KatzConfig { tolerance: f64::NAN, ..base }),
        ];
        for (label, cfg) in bad {
            assert!(
                matches!(
                    katz_centrality_checked(&g, cfg),
                    Err(crate::Error::InvalidParameter(_))
                ),
                "{label} should be rejected"
            );
        }
    }

    proptest! {
        #[test]
        fn proptest_scores_finite_nonneg(
            n in 1usize..10,
            edges in proptest::collection::vec((0usize..10, 0usize..10), 0..30),
        ) {
            // Build an undirected simple graph (no self-loops, no duplicate edges).
            let mut adj = vec![vec![]; n];
            for (u, v) in edges {
                if u < n && v < n && u != v {
                    adj[u].push(v);
                    adj[v].push(u);
                }
            }
            for row in &mut adj {
                row.sort_unstable();
                row.dedup();
            }
            let g = VecGraph { adj };
            let scores = katz_centrality(&g, KatzConfig::default());
            prop_assert_eq!(scores.len(), n);
            for &s in &scores {
                prop_assert!(s.is_finite(), "score is not finite: {s}");
                prop_assert!(s >= 0.0, "score is negative: {s}");
            }
        }
    }
}