use crate::KnowledgeGraph;
use std::collections::HashMap;
/// Tuning parameters for [`katz_centrality`].
///
/// Each sweep of the iteration updates every node's score as
/// `alpha * (sum of neighbour scores) + beta`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KatzConfig {
    /// Attenuation factor applied to neighbour scores on each sweep.
    ///
    /// [`katz_centrality`] panics unless this lies strictly inside `(0, 1)`
    /// (for non-empty graphs). Convergence additionally requires it to be
    /// below the reciprocal of the adjacency matrix's spectral radius.
    pub alpha: f64,
    /// Constant baseline added to every node on each sweep; also the initial
    /// score of every node. With `0.0`, every score stays `0.0`.
    pub beta: f64,
    /// Upper bound on the number of update sweeps.
    pub max_iterations: usize,
    /// Iteration stops once the sum of absolute per-node score changes between
    /// two consecutive sweeps falls below this value.
    pub tolerance: f64,
    /// When `true`, edges are followed in both directions; otherwise a node's
    /// score is accumulated from its incoming edges only.
    pub undirected: bool,
    /// When `true` and the largest score is positive, all scores are divided
    /// by that maximum so the top-ranked node scores `1.0`.
    pub normalized: bool,
}
impl Default for KatzConfig {
fn default() -> Self {
Self {
alpha: 0.1,
beta: 1.0,
max_iterations: 100,
tolerance: 1e-6,
undirected: false,
normalized: true,
}
}
}
/// Computes Katz centrality for every node of `kg`.
///
/// Every node starts at `config.beta`, and the scores are refined by repeated
/// sweeps of
///
/// ```text
/// x_i <- alpha * sum_{j in N(i)} x_j + beta
/// ```
///
/// where `N(i)` is the set of in-neighbours of node `i`, or all neighbours
/// when [`KatzConfig::undirected`] is set. Each sweep reads only the previous
/// sweep's scores. Iteration stops when the sum of absolute score changes
/// between two sweeps drops below [`KatzConfig::tolerance`], or after
/// [`KatzConfig::max_iterations`] sweeps, whichever comes first. With
/// `beta == 0.0` every score stays `0.0`.
///
/// The Katz series only converges when `alpha` is below the reciprocal of the
/// adjacency matrix's spectral radius. The `(0, 1)` range check performed here
/// does not enforce that bound. For larger `alpha` the loop simply stops after
/// `max_iterations` sweeps, and the scores may have grown without bound.
///
/// When [`KatzConfig::normalized`] is set and the largest score is positive,
/// all scores are divided by it so the maximum becomes `1.0`.
///
/// Returns a map from each node's `id` to its score. An empty graph yields an
/// empty map.
///
/// # Panics
///
/// Panics if the graph is non-empty and `config.alpha` is not strictly
/// between `0.0` and `1.0`.
#[must_use]
pub fn katz_centrality(kg: &KnowledgeGraph, config: KatzConfig) -> HashMap<String, f64> {
    let graph = kg.as_petgraph();
    let n = graph.node_count();
    if n == 0 {
        return HashMap::new();
    }
    assert!(
        config.alpha > 0.0 && config.alpha < 1.0,
        "Katz alpha must be in (0, 1) for convergence, got {}",
        config.alpha
    );

    // NOTE(review): indexing by `idx.index()` assumes node indices are dense in
    // `0..node_count` (true for `petgraph::Graph`); confirm `as_petgraph` never
    // returns a graph with index holes.
    let mut scores = vec![config.beta; n];
    let mut new_scores = vec![0.0; n];
    for _ in 0..config.max_iterations {
        for idx in graph.node_indices() {
            // Sum straight from the neighbour iterator instead of collecting
            // into a Vec per node per sweep. The branches stay separate so
            // they need not share an iterator type.
            let neighbour_sum: f64 = if config.undirected {
                graph
                    .neighbors_undirected(idx)
                    .map(|p| scores[p.index()])
                    .sum()
            } else {
                graph
                    .neighbors_directed(idx, petgraph::Direction::Incoming)
                    .map(|p| scores[p.index()])
                    .sum()
            };
            new_scores[idx.index()] = config.alpha * neighbour_sum + config.beta;
        }
        // L1 distance between consecutive iterates decides convergence.
        let diff: f64 = scores
            .iter()
            .zip(&new_scores)
            .map(|(old, new)| (old - new).abs())
            .sum();
        std::mem::swap(&mut scores, &mut new_scores);
        if diff < config.tolerance {
            break;
        }
    }

    if config.normalized {
        let max = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        // Skip when the maximum is not positive (e.g. beta == 0), so we never
        // divide by zero or flip signs.
        if max > 0.0 {
            for s in &mut scores {
                *s /= max;
            }
        }
    }

    graph
        .node_indices()
        .map(|idx| (graph[idx].id.as_str().to_owned(), scores[idx.index()]))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Triple;

    /// A -> B -> C: each node accumulates score from its predecessor, so the
    /// raw scores increase strictly along the chain.
    #[test]
    fn test_katz_chain() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "rel", "B"));
        kg.add_triple(Triple::new("B", "rel", "C"));
        let config = KatzConfig {
            normalized: false,
            ..Default::default()
        };
        let scores = katz_centrality(&kg, config);
        let a = *scores.get("A").unwrap();
        let b = *scores.get("B").unwrap();
        let c = *scores.get("C").unwrap();
        assert!(c > b, "C={c} should be > B={b}");
        assert!(b > a, "B={b} should be > A={a}");
    }

    /// Two disconnected components (A -> B, C -> D): the `beta` baseline gives
    /// every node a positive score, including the source-only nodes A and C.
    #[test]
    fn test_katz_isolated() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "rel", "B"));
        kg.add_triple(Triple::new("C", "rel", "D"));
        let config = KatzConfig::default();
        let scores = katz_centrality(&kg, config);
        for (name, score) in &scores {
            assert!(*score > 0.0, "{name} should have positive score: {score}");
        }
    }

    /// A <-> B <-> C: B has two in-neighbours and must strictly outrank the
    /// symmetric endpoints.
    ///
    /// `beta` must be positive here. With `beta = 0.0` every score stays
    /// `0.0` (see `test_katz_zero_beta_is_degenerate`), and the comparisons
    /// would hold vacuously.
    #[test]
    fn test_katz_vs_eigenvector() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "rel", "B"));
        kg.add_triple(Triple::new("B", "rel", "A"));
        kg.add_triple(Triple::new("B", "rel", "C"));
        kg.add_triple(Triple::new("C", "rel", "B"));
        let config = KatzConfig {
            alpha: 0.1,
            beta: 1.0,
            normalized: true,
            ..Default::default()
        };
        let scores = katz_centrality(&kg, config);
        let b = *scores.get("B").unwrap();
        let a = *scores.get("A").unwrap();
        let c = *scores.get("C").unwrap();
        assert!(b > a, "B={b} should be > A={a}");
        assert!(b > c, "B={b} should be > C={c}");
        assert!((a - c).abs() < 1e-12, "A={a} and C={c} should match by symmetry");
        assert!((b - 1.0).abs() < 1e-12, "B should be the normalised max: {b}");
    }

    /// In undirected mode, the middle of a directed chain sees both endpoints
    /// and must outrank them.
    #[test]
    fn test_katz_undirected_chain() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "rel", "B"));
        kg.add_triple(Triple::new("B", "rel", "C"));
        let config = KatzConfig {
            undirected: true,
            normalized: false,
            ..Default::default()
        };
        let scores = katz_centrality(&kg, config);
        let a = *scores.get("A").unwrap();
        let b = *scores.get("B").unwrap();
        let c = *scores.get("C").unwrap();
        assert!(b > a, "B={b} should be > A={a}");
        assert!(b > c, "B={b} should be > C={c}");
        assert!((a - c).abs() < 1e-12, "A={a} and C={c} should match by symmetry");
    }

    /// With no baseline, the iteration never leaves the all-zero vector.
    #[test]
    fn test_katz_zero_beta_is_degenerate() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "rel", "B"));
        kg.add_triple(Triple::new("B", "rel", "A"));
        let config = KatzConfig {
            beta: 0.0,
            ..Default::default()
        };
        let scores = katz_centrality(&kg, config);
        assert!(!scores.is_empty());
        assert!(
            scores.values().all(|s| s.abs() < 1e-12),
            "all scores should be zero: {scores:?}"
        );
    }

    /// An out-of-range `alpha` on a non-empty graph triggers the precondition.
    #[test]
    #[should_panic(expected = "Katz alpha must be in (0, 1)")]
    fn test_katz_rejects_alpha_out_of_range() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "rel", "B"));
        let config = KatzConfig {
            alpha: 1.5,
            ..Default::default()
        };
        let _ = katz_centrality(&kg, config);
    }

    /// Normalisation rescales the largest score to exactly 1.0.
    #[test]
    fn test_katz_normalized() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "rel", "B"));
        kg.add_triple(Triple::new("B", "rel", "C"));
        let config = KatzConfig {
            normalized: true,
            ..Default::default()
        };
        let scores = katz_centrality(&kg, config);
        let max = scores.values().copied().fold(f64::NEG_INFINITY, f64::max);
        assert!((max - 1.0).abs() < 1e-6, "Max should be 1.0: {max}");
    }
}