use std::collections::HashMap;
use crate::storage::Hexastore;
/// Tuning parameters for [`pagerank`].
///
/// Use [`PageRankOptions::default`] for the conventional settings and override individual
/// fields with struct-update syntax, e.g.
/// `PageRankOptions { damping: 0.5, ..PageRankOptions::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct PageRankOptions {
    /// Weight given to following links versus jumping to a uniformly random node
    /// (the jump share is `1 - damping`). Defaults to `0.85`. The value is not validated;
    /// callers are expected to keep it within `[0, 1]`.
    pub damping: f64,
    /// Maximum number of power-iteration rounds to run. Defaults to `100`.
    pub max_iterations: usize,
    /// Iteration stops once the largest absolute per-node score change in a round is strictly
    /// below this value. Defaults to `1e-6`.
    pub tolerance: f64,
}

impl Default for PageRankOptions {
    /// Returns `damping = 0.85`, `max_iterations = 100` and `tolerance = 1e-6`.
    fn default() -> Self {
        Self {
            damping: 0.85,
            max_iterations: 100,
            tolerance: 1e-6,
        }
    }
}
/// Outcome of a [`pagerank`] run.
#[derive(Debug, Clone, PartialEq)]
pub struct PageRankResult {
    /// Final score for every node id in the graph. When the scores' total is positive they
    /// are rescaled to sum to 1.
    pub scores: HashMap<u64, f64>,
    /// Number of power-iteration rounds actually performed (0 for an empty graph or when
    /// `max_iterations` is 0).
    pub iterations: usize,
    /// `true` if a round's largest per-node change fell below the tolerance.
    /// An empty graph is reported as converged.
    pub converged: bool,
    /// Largest absolute per-node change in the last round performed, measured before the
    /// final rescaling. `0.0` for an empty graph and `f64::MAX` if no round ran.
    pub max_change: f64,
}
/// Computes PageRank scores over the directed graph stored in `store`.
///
/// Each triple `(s, p, o)` returned by `store.query(None, predicate_id, None)` is treated as
/// an edge `s -> o`. Pass `Some(id)` to restrict the graph to one predicate, or `None` to use
/// every triple. Every subject and object becomes a node.
///
/// Scores start uniform (`1 / n`) and are refined by power iteration:
///
/// `new[v] = (1 - d) / n + d * dangling / n + d * Σ_{u -> v} score[u] / out_degree(u)`
///
/// Here `d` is `options.damping` and `dangling` is the total score currently held by nodes
/// without outgoing edges. That mass is spread uniformly over all nodes so no rank is lost.
/// Iteration stops once the largest absolute per-node change is below `options.tolerance`,
/// or after `options.max_iterations` rounds. Final scores are rescaled to sum to 1 when
/// their total is positive.
///
/// Edge cases:
/// - An empty graph returns no scores, with `iterations == 0`, `converged == true` and
///   `max_change == 0.0`.
/// - `max_iterations == 0` returns the uniform initial scores, with `converged == false` and
///   `max_change == f64::MAX`.
/// - A NaN change (e.g. caused by a NaN `damping`) is never reported as convergence.
pub fn pagerank(
    store: &dyn Hexastore,
    predicate_id: Option<u64>,
    options: PageRankOptions,
) -> PageRankResult {
    let (nodes, outgoing, incoming) = build_graph_structure(store, predicate_id);
    let n = nodes.len();
    if n == 0 {
        return PageRankResult {
            scores: HashMap::new(),
            iterations: 0,
            converged: true,
            max_change: 0.0,
        };
    }
    let n_f64 = n as f64;
    let damping = options.damping;
    // Teleport share every node receives regardless of link structure.
    let random_jump = (1.0 - damping) / n_f64;

    // The graph is fixed for the whole run, so collect the dangling nodes (no outgoing
    // edges) once instead of re-scanning every node on every round. Order follows `nodes`,
    // so the per-round summation order is the same as a full scan would produce.
    let dangling_nodes: Vec<u64> = nodes
        .iter()
        .copied()
        .filter(|id| outgoing.get(id).is_none_or(|targets| targets.is_empty()))
        .collect();

    let initial_score = 1.0 / n_f64;
    let mut scores: HashMap<u64, f64> = nodes.iter().map(|&id| (id, initial_score)).collect();
    let mut new_scores: HashMap<u64, f64> = HashMap::with_capacity(n);
    let mut iterations = 0;
    let mut converged = false;
    // Remains f64::MAX only when no round runs (max_iterations == 0).
    let mut max_change = f64::MAX;

    for _ in 0..options.max_iterations {
        iterations += 1;
        max_change = 0.0;

        let dangling_sum: f64 = dangling_nodes
            .iter()
            .map(|id| scores.get(id).copied().unwrap_or(0.0))
            .sum();
        let dangling_contribution = damping * dangling_sum / n_f64;

        for &node_id in &nodes {
            let mut link_score = 0.0;
            if let Some(in_neighbors) = incoming.get(&node_id) {
                for &neighbor_id in in_neighbors {
                    let neighbor_score = scores.get(&neighbor_id).copied().unwrap_or(0.0);
                    // `max(1)` guards against division by zero if a neighbour has no
                    // recorded outgoing edges.
                    let neighbor_out_degree =
                        outgoing.get(&neighbor_id).map_or(1, |v| v.len().max(1));
                    link_score += neighbor_score / neighbor_out_degree as f64;
                }
            }
            let new_score = random_jump + dangling_contribution + damping * link_score;
            let old_score = scores.get(&node_id).copied().unwrap_or(0.0);
            let change = (new_score - old_score).abs();
            // `f64::max` discards NaN operands, which would make a NaN score look converged.
            // Keep NaN sticky instead so the tolerance check below fails.
            if change > max_change || change.is_nan() {
                max_change = change;
            }
            new_scores.insert(node_id, new_score);
        }

        std::mem::swap(&mut scores, &mut new_scores);
        new_scores.clear();

        if max_change < options.tolerance {
            converged = true;
            break;
        }
    }

    // The update preserves total mass in exact arithmetic; rescaling removes floating-point
    // drift so the returned scores sum to 1.
    let total: f64 = scores.values().sum();
    if total > 0.0 {
        for score in scores.values_mut() {
            *score /= total;
        }
    }

    PageRankResult {
        scores,
        iterations,
        converged,
        max_change,
    }
}
/// Graph produced by [`build_graph_structure`]: `(nodes, outgoing, incoming)`, where
/// `outgoing[s]` lists the targets of `s` and `incoming[o]` lists the sources pointing at `o`.
type GraphStructure = (Vec<u64>, HashMap<u64, Vec<u64>>, HashMap<u64, Vec<u64>>);

/// Builds the directed graph from the triples matching `(None, predicate_id, None)`.
///
/// Every triple adds its subject and object to the node list, and an edge
/// `subject -> object` to both adjacency maps. Repeated edges are kept, so a pair linked
/// several times appears once per occurrence.
///
/// The node list is sorted ascending. `HashSet` iteration order is not stable across runs,
/// and callers that sum over the nodes in list order would otherwise get run-to-run
/// differences in floating-point results.
fn build_graph_structure(store: &dyn Hexastore, predicate_id: Option<u64>) -> GraphStructure {
    let mut nodes_set = std::collections::HashSet::new();
    let mut outgoing: HashMap<u64, Vec<u64>> = HashMap::new();
    let mut incoming: HashMap<u64, Vec<u64>> = HashMap::new();
    for triple in store.query(None, predicate_id, None) {
        let subject = triple.subject_id;
        let object = triple.object_id;
        nodes_set.insert(subject);
        nodes_set.insert(object);
        outgoing.entry(subject).or_default().push(object);
        incoming.entry(object).or_default().push(subject);
    }
    let mut nodes: Vec<u64> = nodes_set.into_iter().collect();
    nodes.sort_unstable();
    (nodes, outgoing, incoming)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Triple;
    use crate::storage::memory::MemoryHexastore;

    /// Graph `1 -> 2`, `2 -> 3`, `2 -> 4`, all under predicate 10.
    fn create_simple_graph() -> MemoryHexastore {
        let mut store = MemoryHexastore::new();
        store.insert(&Triple::new(1, 10, 2)).unwrap();
        store.insert(&Triple::new(2, 10, 3)).unwrap();
        store.insert(&Triple::new(2, 10, 4)).unwrap();
        store
    }

    /// Three-node cycle `1 -> 2 -> 3 -> 1`, all under predicate 10.
    fn create_cyclic_graph() -> MemoryHexastore {
        let mut store = MemoryHexastore::new();
        store.insert(&Triple::new(1, 10, 2)).unwrap();
        store.insert(&Triple::new(2, 10, 3)).unwrap();
        store.insert(&Triple::new(3, 10, 1)).unwrap();
        store
    }

    /// Score of `id` in `result`, or 0.0 when the node is absent.
    fn score_of(result: &PageRankResult, id: u64) -> f64 {
        result.scores.get(&id).copied().unwrap_or(0.0)
    }

    /// Asserts that the result's scores sum to 1 within a small tolerance.
    fn assert_sums_to_one(result: &PageRankResult) {
        let total: f64 = result.scores.values().sum();
        assert!((total - 1.0).abs() < 0.001, "scores sum to {total}");
    }

    #[test]
    fn test_pagerank_simple() {
        let store = create_simple_graph();
        let result = pagerank(&store, None, PageRankOptions::default());
        assert!(result.converged);
        assert!(result.iterations > 0);
        assert_eq!(result.scores.len(), 4);
        assert!(result.scores.values().all(|&score| score > 0.0));
        assert_sums_to_one(&result);

        let score_a = score_of(&result, 1);
        let score_b = score_of(&result, 2);
        let score_c = score_of(&result, 3);
        let score_d = score_of(&result, 4);
        // Node 1 has no in-links, so it only receives the teleport and dangling shares.
        assert!(score_b > score_a);
        assert!(score_c > score_a);
        // Nodes 3 and 4 are structurally identical: each is fed only by node 2.
        assert!((score_c - score_d).abs() < 1e-12);
        // Node 2 receives all of node 1's rank; nodes 3 and 4 each receive half of node 2's.
        assert!(score_b > score_c);
    }

    #[test]
    fn test_pagerank_cyclic() {
        let store = create_cyclic_graph();
        let result = pagerank(&store, None, PageRankOptions::default());
        assert!(result.converged);
        assert_eq!(result.scores.len(), 3);
        for &score in result.scores.values() {
            assert!((score - 1.0 / 3.0).abs() < 1e-6, "unexpected score {score}");
        }
    }

    #[test]
    fn test_pagerank_empty_graph() {
        let store = MemoryHexastore::new();
        let result = pagerank(&store, None, PageRankOptions::default());
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
        assert!(result.scores.is_empty());
    }

    #[test]
    fn test_pagerank_custom_options() {
        let store = create_simple_graph();
        let options = PageRankOptions {
            damping: 0.5,
            max_iterations: 10,
            tolerance: 1e-3,
        };
        let result = pagerank(&store, None, options);
        assert!(result.iterations <= 10);
        assert_sums_to_one(&result);
    }

    #[test]
    fn test_pagerank_zero_iterations_returns_uniform_scores() {
        let store = create_simple_graph();
        let options = PageRankOptions {
            max_iterations: 0,
            ..PageRankOptions::default()
        };
        let result = pagerank(&store, None, options);
        assert_eq!(result.iterations, 0);
        assert!(!result.converged);
        assert_eq!(result.scores.len(), 4);
        for &score in result.scores.values() {
            assert!((score - 0.25).abs() < 1e-12);
        }
    }

    #[test]
    fn test_pagerank_predicate_filter() {
        let mut store = create_simple_graph();
        store.insert(&Triple::new(4, 20, 99)).unwrap();

        let all = pagerank(&store, None, PageRankOptions::default());
        assert!(all.scores.contains_key(&99));

        let filtered = pagerank(&store, Some(10), PageRankOptions::default());
        assert_eq!(filtered.scores.len(), 4);
        assert!(!filtered.scores.contains_key(&99));
        assert_sums_to_one(&filtered);
    }
}