nodedb_cluster/distributed_graph/
pagerank.rs1use std::collections::HashMap;
6
/// Per-shard working state for an iterative, superstep-based PageRank run.
///
/// Expected call order per iteration (derived from how the methods touch
/// the buffers):
///
/// 1. [`ShardPageRankState::superstep`] computes the new local ranks, makes
///    them current, clears previously received contributions and returns the
///    contributions destined for other shards.
/// 2. [`ShardPageRankState::add_remote_contribution`] is called once per
///    contribution received from another shard.
/// 3. [`ShardPageRankState::apply_incoming_contributions`] folds the received
///    contributions into the current ranks so the next superstep sees them.
#[derive(Debug)]
pub struct ShardPageRankState {
    /// Number of local vertices; `init` sizes `rank` and `next_rank` to this.
    pub vertex_count: usize,
    /// Current rank of every local vertex, indexed by local vertex id.
    pub rank: Vec<f64>,
    /// Scratch buffer filled by `superstep` and then swapped with `rank`.
    pub next_rank: Vec<f64>,
    /// Out-degree of every local vertex, as supplied to `init`.
    pub out_degrees: Vec<usize>,
    /// `true` for every local vertex whose out-degree is zero.
    pub is_dangling: Vec<bool>,
    /// For each local vertex id, the `(destination name, target shard)` pairs
    /// of its out-edges that `init` was told are ghost (cross-shard) edges.
    pub boundary_edges: HashMap<u32, Vec<(String, u16)>>,
    /// Contributions received from other shards, summed per vertex name.
    /// Emptied at the end of every `superstep`.
    pub incoming_contributions: HashMap<String, f64>,
}

impl ShardPageRankState {
    /// Builds the initial state for a shard holding `vertex_count` vertices.
    ///
    /// * `out_degrees` - out-degree per local vertex; a zero marks the vertex
    ///   as dangling.
    /// * `_ghost_lookup` - accepted for interface compatibility; not used.
    /// * `csr_out_edges` - for a local vertex id, returns
    ///   `(destination name, is_ghost, target shard)` for each out-edge. Only
    ///   entries with `is_ghost == true` are recorded in `boundary_edges`.
    ///
    /// Every vertex starts with rank `1 / vertex_count` (or `0.0` when the
    /// shard is empty).
    // NOTE(review): the initial rank is normalised by the *local* vertex
    // count, while `superstep` normalises by `global_n`; confirm whether the
    // caller rescales, otherwise total rank across shards starts above 1.
    pub fn init<F>(
        vertex_count: usize,
        out_degrees: Vec<usize>,
        _ghost_lookup: F,
        csr_out_edges: &dyn Fn(u32) -> Vec<(String, bool, u16)>,
    ) -> Self
    where
        F: Fn(&str) -> Option<u16>,
    {
        let init_rank = if vertex_count > 0 {
            1.0 / vertex_count as f64
        } else {
            0.0
        };

        let is_dangling: Vec<bool> = out_degrees.iter().map(|&d| d == 0).collect();

        let mut boundary_edges: HashMap<u32, Vec<(String, u16)>> = HashMap::new();
        for node in 0..vertex_count {
            let node_id = node as u32;
            for (dst_name, is_ghost, target_shard) in csr_out_edges(node_id) {
                // Only cross-shard edges are remembered; an entry is created
                // lazily so vertices without ghost edges leave no key behind.
                if is_ghost {
                    boundary_edges
                        .entry(node_id)
                        .or_default()
                        .push((dst_name, target_shard));
                }
            }
        }

        Self {
            vertex_count,
            rank: vec![init_rank; vertex_count],
            next_rank: vec![0.0; vertex_count],
            out_degrees,
            is_dangling,
            boundary_edges,
            incoming_contributions: HashMap::new(),
        }
    }

    /// Runs one local PageRank iteration.
    ///
    /// Every local vertex first receives the base value
    /// `(1 - damping) / global_n + damping * dangling_sum / global_n`, where
    /// `dangling_sum` is the summed rank of this shard's dangling vertices.
    /// Each non-dangling vertex then sends `damping * rank / out_degree` to
    /// every id returned by `local_edge_iter` and queues the same amount for
    /// each of its boundary edges.
    ///
    /// Returns `(delta, outbound)`: `delta` is the L1 distance between the
    /// previous and the new local ranks, and `outbound` maps a target shard
    /// to the `(vertex name, contribution)` pairs it must receive.
    ///
    /// On return the new ranks are in `rank` and `incoming_contributions` is
    /// empty.
    ///
    /// # Panics
    ///
    /// Panics if `out_degrees` has fewer than `vertex_count` entries or if
    /// `local_edge_iter` yields an id outside `next_rank`.
    // NOTE(review): `global_n == 0` is not guarded and produces non-finite
    // values; `dangling_sum` covers local vertices only - confirm whether the
    // coordinator is expected to account for other shards' dangling mass.
    pub fn superstep(
        &mut self,
        damping: f64,
        global_n: usize,
        local_edge_iter: &dyn Fn(u32) -> Vec<u32>,
    ) -> (f64, HashMap<u16, Vec<(String, f64)>>) {
        let n = global_n as f64;
        let teleport = (1.0 - damping) / n;

        let dangling_sum: f64 = self
            .rank
            .iter()
            .zip(self.is_dangling.iter())
            .filter(|(_, dangling)| **dangling)
            .map(|(r, _)| *r)
            .sum();

        let base = teleport + damping * dangling_sum / n;
        self.next_rank.fill(base);

        let mut outbound: HashMap<u16, Vec<(String, f64)>> = HashMap::new();
        for u in 0..self.vertex_count {
            let deg = self.out_degrees[u];
            if deg == 0 {
                continue;
            }
            let contrib = damping * self.rank[u] / deg as f64;

            // Edges that stay on this shard are applied immediately.
            for dst in local_edge_iter(u as u32) {
                self.next_rank[dst as usize] += contrib;
            }

            // Cross-shard edges are handed back to the caller for routing.
            if let Some(boundary) = self.boundary_edges.get(&(u as u32)) {
                for (dst_name, target_shard) in boundary {
                    outbound
                        .entry(*target_shard)
                        .or_default()
                        .push((dst_name.clone(), contrib));
                }
            }
        }

        let delta: f64 = self
            .rank
            .iter()
            .zip(self.next_rank.iter())
            .map(|(old, new)| (old - new).abs())
            .sum();

        std::mem::swap(&mut self.rank, &mut self.next_rank);
        self.incoming_contributions.clear();

        (delta, outbound)
    }

    /// Folds the accumulated remote contributions into the current ranks.
    ///
    /// `node_id_to_local` maps a vertex name to its local id. Names it cannot
    /// resolve, and ids that fall outside the rank buffer, are skipped.
    ///
    /// The contributions are added to `rank`, not `next_rank`: `superstep`
    /// overwrites every `next_rank` slot on entry and swaps the buffers on
    /// exit, so anything written to `next_rank` between supersteps would be
    /// discarded before it could influence a result.
    ///
    /// The contribution map is left untouched (the next `superstep` clears
    /// it), so calling this twice without an intervening `superstep` applies
    /// the same contributions twice.
    pub fn apply_incoming_contributions(&mut self, node_id_to_local: &dyn Fn(&str) -> Option<u32>) {
        for (vertex_name, contrib) in &self.incoming_contributions {
            if let Some(local_id) = node_id_to_local(vertex_name.as_str()) {
                if let Some(slot) = self.rank.get_mut(local_id as usize) {
                    *slot += *contrib;
                }
            }
        }
    }

    /// Records a contribution received from another shard for `vertex_name`.
    ///
    /// Multiple contributions for the same vertex are summed.
    pub fn add_remote_contribution(&mut self, vertex_name: String, value: f64) {
        *self
            .incoming_contributions
            .entry(vertex_name)
            .or_insert(0.0) += value;
    }
}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Ghost lookup that never resolves a name.
    fn no_ghosts(_name: &str) -> Option<u16> {
        None
    }

    /// Out-edge provider for a shard without any recorded edges.
    fn no_out_edges(_node: u32) -> Vec<(String, bool, u16)> {
        Vec::new()
    }

    /// Local adjacency of a three-vertex directed ring 0 -> 1 -> 2 -> 0.
    fn ring_of_three(node: u32) -> Vec<u32> {
        match node {
            0 => vec![1],
            1 => vec![2],
            2 => vec![0],
            _ => Vec::new(),
        }
    }

    /// `init` records the vertex count and flags zero-degree vertices.
    #[test]
    fn shard_state_init() {
        let degrees = vec![2, 1, 0];
        let state = ShardPageRankState::init(3, degrees, no_ghosts, &no_out_edges);

        assert_eq!(state.vertex_count, 3);
        assert!(!state.is_dangling[0]);
        assert!(state.is_dangling[2]);
    }

    /// A ghost out-edge reported for vertex 0 ends up in `boundary_edges`.
    #[test]
    fn shard_state_with_ghost_edges() {
        fn out_edges(node: u32) -> Vec<(String, bool, u16)> {
            match node {
                0 => vec![(String::from("remote"), true, 5)],
                _ => Vec::new(),
            }
        }

        let state = ShardPageRankState::init(
            2,
            vec![2, 1],
            |name: &str| match name {
                "remote" => Some(5),
                _ => None,
            },
            &out_edges,
        );

        assert_eq!(state.boundary_edges.len(), 1);
        let (_, shard) = &state.boundary_edges[&0][0];
        assert_eq!(*shard, 5);
    }

    /// A purely local ring emits no outbound messages and keeps total rank 1.
    #[test]
    fn shard_superstep_local_only() {
        let mut state =
            ShardPageRankState::init(3, vec![1, 1, 1], no_ghosts, &no_out_edges);

        let (delta, outbound) = state.superstep(0.85, 3, &ring_of_three);

        assert!(outbound.is_empty());
        assert!(delta >= 0.0);
        let total: f64 = state.rank.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }

    /// Contributions for the same vertex name are summed.
    #[test]
    fn remote_contribution_accumulation() {
        let mut state = ShardPageRankState::init(2, vec![1, 0], no_ghosts, &no_out_edges);

        for (name, value) in [("n0", 0.1), ("n0", 0.2), ("n1", 0.3)] {
            state.add_remote_contribution(name.into(), value);
        }

        assert!((state.incoming_contributions["n0"] - 0.3).abs() < 1e-10);
        assert!((state.incoming_contributions["n1"] - 0.3).abs() < 1e-10);
    }
}