use std::collections::HashMap;
/// Graph over memory ids with directed, deduplicated adjacency lists.
///
/// Outgoing (`edges`) and incoming (`incoming`) adjacency are both stored so that
/// [`MemoryGraph::pagerank`] can pull contributions from a node's predecessors directly.
#[derive(Debug, Clone, Default)]
pub struct MemoryGraph {
    /// Outgoing neighbours per node, in insertion order, without duplicates.
    edges: HashMap<u64, Vec<u64>>,
    /// Incoming neighbours per node; mirrors `edges` (`b` in `edges[a]` iff `a` in `incoming[b]`).
    incoming: HashMap<u64, Vec<u64>>,
    /// Cached `edges.len()`, refreshed by `add_edge`.
    node_count: usize,
}

impl MemoryGraph {
    /// Creates an empty graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a directed edge `from -> to`, registering both endpoints as nodes.
    ///
    /// Self-loops (`from == to`) are ignored entirely: neither the edge nor the node is
    /// recorded. Re-adding an existing edge changes nothing. Deduplication is a linear scan
    /// of `from`'s outgoing list.
    pub fn add_edge(&mut self, from: u64, to: u64) {
        if from == to {
            return;
        }
        // Give every node an entry in both maps, even with no edges in that direction,
        // so `pagerank` and `neighbors` see one consistent node set.
        self.edges.entry(to).or_default();
        self.incoming.entry(from).or_default();
        let outgoing = self.edges.entry(from).or_default();
        if !outgoing.contains(&to) {
            outgoing.push(to);
            self.incoming.entry(to).or_default().push(from);
        }
        self.node_count = self.edges.len();
    }

    /// Links `a` and `b` with edges in both directions.
    pub fn link(&mut self, a: u64, b: u64) {
        self.add_edge(a, b);
        self.add_edge(b, a);
    }

    /// Number of distinct nodes that are an endpoint of at least one (non-self-loop) edge.
    pub fn node_count(&self) -> usize {
        self.node_count
    }

    /// Outgoing neighbours of `node` in insertion order; empty if `node` is unknown.
    pub fn neighbors(&self, node: u64) -> &[u64] {
        self.edges.get(&node).map(|v| v.as_slice()).unwrap_or(&[])
    }

    /// Computes PageRank scores for every node by power iteration.
    ///
    /// Each iteration applies
    /// `rank(v) = (1 - damping) / n + damping * (sum_{u -> v} rank(u) / out(u) + sink_mass / n)`,
    /// where `sink_mass` is the total score of nodes without outgoing edges. Spreading that
    /// mass uniformly means a total score of 1 is preserved from one iteration to the next.
    ///
    /// * `damping` - probability of following an edge (commonly `0.85`). It is not
    ///   validated; values outside `[0, 1]` give meaningless scores.
    /// * `iterations` - number of power-iteration steps. There is no early convergence
    ///   check, and `0` returns the starting scores unchanged.
    /// * `initial_scores` - optional warm start. Graph nodes missing from the map start at
    ///   `1 / n`, and entries for ids not in the graph are ignored.
    ///
    /// Starting scores are not normalised. If their total differs from 1, the deviation
    /// shrinks by a factor of `damping` per iteration.
    ///
    /// Returns an empty map for an empty graph.
    pub fn pagerank(
        &self,
        damping: f64,
        iterations: usize,
        initial_scores: Option<&HashMap<u64, f64>>,
    ) -> HashMap<u64, f64> {
        if self.node_count == 0 {
            return HashMap::new();
        }
        let n = self.node_count as f64;
        let base = 1.0 / n;
        // Identical for every node and every iteration.
        let teleport = (1.0 - damping) / n;

        let mut scores: HashMap<u64, f64> = self
            .edges
            .keys()
            .map(|&k| {
                let init = initial_scores
                    .and_then(|m| m.get(&k))
                    .copied()
                    .unwrap_or(base);
                (k, init)
            })
            .collect();
        let out_degree: HashMap<u64, usize> =
            self.edges.iter().map(|(&k, v)| (k, v.len())).collect();

        // The sink set is fixed by the graph, so find it once instead of rescanning every
        // node each iteration. Sorting pins the order of the floating-point sum below.
        // HashMap iteration order is randomised per process, and without sorting the
        // results could differ in the last bits from run to run.
        let mut sinks: Vec<u64> = out_degree
            .iter()
            .filter(|&(_, &deg)| deg == 0)
            .map(|(&k, _)| k)
            .collect();
        sinks.sort_unstable();

        for _ in 0..iterations {
            let sink_sum: f64 = sinks
                .iter()
                .map(|k| scores.get(k).copied().unwrap_or(0.0))
                .sum();
            let sink_share = sink_sum / n;

            let mut new_scores = HashMap::with_capacity(self.node_count);
            for &node in self.edges.keys() {
                let incoming_sum: f64 = self.incoming.get(&node).map_or(0.0, |sources| {
                    sources
                        .iter()
                        .map(|src| {
                            // `src` links to `node`, so its out-degree is at least 1;
                            // the fallback only rules out a division by zero.
                            let src_out = out_degree.get(src).copied().unwrap_or(1) as f64;
                            scores.get(src).copied().unwrap_or(0.0) / src_out
                        })
                        .sum()
                });
                new_scores.insert(node, teleport + damping * (incoming_sum + sink_share));
            }
            scores = new_scores;
        }
        scores
    }

    /// Builds a co-access graph in which every pair of distinct ids that appear together in
    /// a session is linked in both directions.
    ///
    /// A pair that repeats, within or across sessions, yields a single edge per direction.
    /// Ids that never co-occur with a different id (e.g. single-element sessions) do not
    /// become nodes.
    pub fn from_co_access(sessions: &[Vec<u64>]) -> Self {
        let mut graph = Self::new();
        for session in sessions {
            for (i, &a) in session.iter().enumerate() {
                for &b in &session[i + 1..] {
                    graph.link(a, b);
                }
            }
        }
        graph
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_graph() {
        let graph = MemoryGraph::new();
        let scores = graph.pagerank(0.85, 20, None);
        assert!(scores.is_empty());
    }

    #[test]
    fn test_single_node() {
        let mut graph = MemoryGraph::new();
        // `add_edge` ignores self-loops, so the graph may stay empty; either outcome is
        // acceptable as long as any returned score is positive.
        graph.add_edge(1, 1);
        let scores = graph.pagerank(0.85, 20, None);
        assert!(scores.is_empty() || scores.values().all(|&v| v > 0.0));
    }

    #[test]
    fn test_two_nodes() {
        let mut graph = MemoryGraph::new();
        graph.link(1, 2);
        let scores = graph.pagerank(0.85, 50, None);
        assert_eq!(scores.len(), 2);
        let s1 = scores.get(&1).unwrap();
        let s2 = scores.get(&2).unwrap();
        assert!(
            (s1 - s2).abs() < 0.01,
            "Symmetric graph should have equal scores"
        );
    }

    #[test]
    fn test_hub_authority() {
        let mut graph = MemoryGraph::new();
        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(1, 4);
        graph.add_edge(2, 1);
        graph.add_edge(3, 1);
        graph.add_edge(4, 1);
        let scores = graph.pagerank(0.85, 50, None);
        let s1 = scores.get(&1).unwrap();
        for &node in &[2u64, 3, 4] {
            let sn = scores.get(&node).unwrap();
            assert!(*s1 >= *sn, "Hub node should have >= score than leaf");
        }
    }

    #[test]
    fn test_from_co_access() {
        let sessions = vec![vec![1, 2, 3], vec![2, 4]];
        let graph = MemoryGraph::from_co_access(&sessions);
        assert_eq!(graph.node_count(), 4);
        let scores = graph.pagerank(0.85, 50, None);
        let s2 = scores.get(&2).unwrap();
        for &node in &[1u64, 3, 4] {
            let sn = scores.get(&node).unwrap();
            assert!(*s2 >= *sn, "Node 2 should have highest score");
        }
    }

    #[test]
    fn test_initial_scores_influence() {
        let mut graph = MemoryGraph::new();
        graph.add_edge(1, 2);
        let initial = HashMap::from([(1u64, 10.0), (2u64, 0.1)]);

        let warm = graph.pagerank(0.85, 5, Some(&initial));
        assert!(warm[&1] > 0.0, "Node 1 should have positive score");
        assert!(warm[&2] > 0.0, "Node 2 should have positive score");

        // A warm start is visible after a single step...
        let warm_one = graph.pagerank(0.85, 1, Some(&initial));
        let cold_one = graph.pagerank(0.85, 1, None);
        assert!(
            warm_one[&2] > cold_one[&2],
            "Initial scores should influence early iterations"
        );

        // ...but power iteration reaches the same fixed point regardless of the start.
        let warm_long = graph.pagerank(0.85, 300, Some(&initial));
        let cold_long = graph.pagerank(0.85, 300, None);
        for node in [1u64, 2] {
            assert!(
                (warm_long[&node] - cold_long[&node]).abs() < 1e-9,
                "Fixed point should not depend on initial scores"
            );
        }
    }

    #[test]
    fn test_scores_sum_to_one_with_sinks() {
        let mut graph = MemoryGraph::new();
        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(3, 4);
        let scores = graph.pagerank(0.85, 30, None);
        let total: f64 = scores.values().sum();
        assert!(
            (total - 1.0).abs() < 1e-9,
            "Sink mass redistribution should keep the total at 1"
        );
    }

    #[test]
    fn test_duplicate_edges_are_ignored() {
        let mut graph = MemoryGraph::new();
        graph.add_edge(1, 2);
        graph.add_edge(1, 2);
        assert_eq!(graph.neighbors(1), [2u64].as_slice());
        assert!(graph.neighbors(2).is_empty());
        assert!(graph.neighbors(99).is_empty());
        assert_eq!(graph.node_count(), 2);
    }

    #[test]
    fn test_zero_iterations_returns_starting_scores() {
        let mut graph = MemoryGraph::new();
        graph.add_edge(1, 2);
        let initial = HashMap::from([(1u64, 0.7)]);
        let scores = graph.pagerank(0.85, 0, Some(&initial));
        assert_eq!(scores[&1], 0.7);
        assert_eq!(scores[&2], 0.5);
    }
}