use crate::KnowledgeGraph;
use std::collections::HashMap;
/// Tuning parameters for [`personalized_pagerank`].
///
/// The three fields are forwarded unchanged to the `graphops` PageRank solver.
/// Use [`Default`] for the stock settings, or struct-update syntax to override
/// individual fields: `PprConfig { damping: 0.5, ..PprConfig::default() }`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PprConfig {
    /// Damping factor passed to the solver. Conventionally this is the
    /// probability of following an edge rather than restarting at the seed.
    /// Confirm against `graphops` before relying on that reading.
    pub damping: f64,
    /// Iteration cap passed to the solver.
    pub max_iterations: usize,
    /// Convergence tolerance passed to the solver.
    pub tolerance: f64,
}
impl Default for PprConfig {
    /// Returns the stock settings: `damping` of `0.85`, `max_iterations` of
    /// `100`, and a `tolerance` of `1e-6`.
    fn default() -> Self {
        const DAMPING: f64 = 0.85;
        const MAX_ITERATIONS: usize = 100;
        const TOLERANCE: f64 = 1e-6;

        Self {
            tolerance: TOLERANCE,
            max_iterations: MAX_ITERATIONS,
            damping: DAMPING,
        }
    }
}
/// Computes Personalized PageRank scores over `kg`, restarting at the entity `seed`.
///
/// The personalization vector is one-hot: all restart mass is placed on the
/// node whose id is `seed`, and every other node gets zero. The vector and
/// `config` are handed to `graphops::ppr::personalized_pagerank`. Each
/// resulting score is keyed by the id string of the entity at that node index.
///
/// Returns an empty map when the graph has no nodes or when `seed` does not
/// name an entity in `kg`.
#[must_use]
pub fn personalized_pagerank(
    kg: &KnowledgeGraph,
    seed: &str,
    config: PprConfig,
) -> HashMap<String, f64> {
    let graph = kg.as_petgraph();
    let node_count = graph.node_count();
    if node_count == 0 {
        return HashMap::new();
    }

    let seed_index = match kg.get_node_index(&crate::EntityId::from(seed)) {
        Some(found) => found,
        None => return HashMap::new(),
    };

    // One-hot restart distribution: every teleport jumps back to the seed.
    let mut restart = vec![0.0; node_count];
    restart[seed_index.index()] = 1.0;

    let solver_config = graphops::pagerank::PageRankConfig {
        damping: config.damping,
        max_iterations: config.max_iterations,
        tolerance: config.tolerance,
    };

    // Scores come back positionally; translate each position into its entity id.
    graphops::ppr::personalized_pagerank(graph, solver_config, &restart)
        .into_iter()
        .enumerate()
        .map(|(position, score)| {
            let entity = &graph[petgraph::graph::NodeIndex::new(position)];
            (entity.id.as_str().to_owned(), score)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Triple;

    /// Builds the directed three-cycle `A -> B -> C -> A`.
    fn three_cycle() -> KnowledgeGraph {
        let mut kg = KnowledgeGraph::new();
        for (head, tail) in [("A", "B"), ("B", "C"), ("C", "A")] {
            kg.add_triple(Triple::new(head, "rel", tail));
        }
        kg
    }

    #[test]
    fn ppr_seed_scores_highest() {
        let kg = three_cycle();
        let scores = personalized_pagerank(&kg, "A", PprConfig::default());
        let (a, b, c) = (scores["A"], scores["B"], scores["C"]);
        assert!(a > b, "Seed A ({a}) should score higher than B ({b})");
        assert!(a > c, "Seed A ({a}) should score higher than C ({c})");
    }

    #[test]
    fn ppr_missing_seed_returns_empty() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "rel", "B"));
        assert!(personalized_pagerank(&kg, "Z", PprConfig::default()).is_empty());
    }

    #[test]
    fn ppr_empty_graph() {
        let kg = KnowledgeGraph::new();
        assert!(personalized_pagerank(&kg, "A", PprConfig::default()).is_empty());
    }

    #[test]
    fn ppr_scores_sum_to_one() {
        // The cycle plus an extra edge A -> D.
        let mut kg = three_cycle();
        kg.add_triple(Triple::new("A", "rel", "D"));
        let total: f64 = personalized_pagerank(&kg, "A", PprConfig::default())
            .values()
            .sum();
        assert!(
            (total - 1.0).abs() < 1e-6,
            "Scores should sum to 1.0, got {total}",
        );
    }
}