Skip to main content

nodedb_cluster/distributed_graph/
pagerank.rs

1// SPDX-License-Identifier: BUSL-1.1
2
3//! Per-shard PageRank execution state for distributed BSP.
4
5use std::collections::HashMap;
6
/// Per-shard PageRank state carried from one BSP superstep to the next.
///
/// All per-vertex vectors are indexed by the shard-local vertex id
/// (`0..vertex_count`). Vertices living on other shards are referred to by
/// name, because this shard has no local id for them.
#[derive(Debug)]
pub struct ShardPageRankState {
    /// Number of vertices owned by this shard.
    pub vertex_count: usize,
    /// Current rank of every local vertex.
    pub rank: Vec<f64>,
    /// Scratch buffer the next iteration is accumulated into; swapped with
    /// `rank` at the end of every [`superstep`](Self::superstep).
    pub next_rank: Vec<f64>,
    /// Out-degree of every local vertex, exactly as handed to
    /// [`init`](Self::init).
    ///
    /// NOTE(review): the same per-edge share is sent along local *and*
    /// boundary edges, so this presumably counts both kinds — confirm with
    /// the code that builds the partition.
    pub out_degrees: Vec<usize>,
    /// `true` for vertices whose out-degree is zero.
    pub is_dangling: Vec<bool>,
    /// Cross-shard edges: local source id -> `(destination name, target shard)`.
    /// Only vertices with at least one such edge have an entry.
    pub boundary_edges: HashMap<u32, Vec<(String, u16)>>,
    /// Contributions received from other shards, summed per destination
    /// vertex name, waiting for
    /// [`apply_incoming_contributions`](Self::apply_incoming_contributions).
    pub incoming_contributions: HashMap<String, f64>,
}

impl ShardPageRankState {
    /// Builds the initial state for a shard from its local CSR partition.
    ///
    /// Every local vertex starts with rank `1 / vertex_count` (or `0.0` when
    /// the shard is empty).
    ///
    /// * `vertex_count` – number of vertices owned by this shard.
    /// * `out_degrees` – out-degree per local vertex; expected to hold
    ///   `vertex_count` entries (not validated here).
    /// * `_ghost_lookup` – currently unused; kept so existing callers compile.
    /// * `csr_out_edges` – for a local vertex id, yields one
    ///   `(destination name, is_ghost, target shard)` triple per outgoing
    ///   edge. Only triples with `is_ghost == true` are recorded, as boundary
    ///   edges.
    pub fn init<F>(
        vertex_count: usize,
        out_degrees: Vec<usize>,
        _ghost_lookup: F,
        csr_out_edges: &dyn Fn(u32) -> Vec<(String, bool, u16)>,
    ) -> Self
    where
        F: Fn(&str) -> Option<u16>,
    {
        let init_rank = if vertex_count > 0 {
            1.0 / vertex_count as f64
        } else {
            0.0
        };

        let is_dangling: Vec<bool> = out_degrees.iter().map(|&d| d == 0).collect();

        let mut boundary_edges: HashMap<u32, Vec<(String, u16)>> = HashMap::new();
        for node in 0..vertex_count {
            let node = node as u32;
            let ghosts: Vec<(String, u16)> = csr_out_edges(node)
                .into_iter()
                .filter(|(_, is_ghost, _)| *is_ghost)
                .map(|(dst_name, _, target_shard)| (dst_name, target_shard))
                .collect();
            // Vertices without cross-shard edges get no map entry at all.
            if !ghosts.is_empty() {
                boundary_edges.entry(node).or_default().extend(ghosts);
            }
        }

        Self {
            vertex_count,
            rank: vec![init_rank; vertex_count],
            next_rank: vec![0.0; vertex_count],
            out_degrees,
            is_dangling,
            boundary_edges,
            incoming_contributions: HashMap::new(),
        }
    }

    /// Executes one superstep. Returns `(local_delta, outbound_contributions)`.
    ///
    /// Every local vertex is first set to
    /// `(1 - damping) / global_n + damping * dangling_sum / global_n`, then
    /// each non-dangling vertex `u` pushes `damping * rank[u] / out_degree[u]`
    /// along every edge: local destinations are accumulated directly, while
    /// boundary edges are collected into the returned map, keyed by target
    /// shard, as `(destination name, contribution)` pairs.
    ///
    /// `local_delta` is the L1 distance between the previous and the freshly
    /// computed local ranks. It is computed before any remote contribution
    /// for this iteration has arrived, so it reflects local flow only.
    ///
    /// On return `rank` holds the new values and `incoming_contributions` has
    /// been emptied, ready for the next exchange.
    ///
    /// When `global_n` is zero there is nothing to rank: the call leaves the
    /// ranks untouched (instead of dividing by zero and filling them with
    /// NaN/inf) and returns a zero delta with no outbound messages.
    ///
    /// NOTE(review): `dangling_sum` only covers dangling vertices of *this*
    /// shard; if a global dangling mass is intended the coordinator has to
    /// supply it — confirm against the caller.
    ///
    /// # Panics
    ///
    /// Panics if `out_degrees` has fewer than `vertex_count` entries, or if
    /// `local_edge_iter` yields an id outside `0..next_rank.len()`.
    pub fn superstep(
        &mut self,
        damping: f64,
        global_n: usize,
        local_edge_iter: &dyn Fn(u32) -> Vec<u32>,
    ) -> (f64, HashMap<u16, Vec<(String, f64)>>) {
        if global_n == 0 {
            self.incoming_contributions.clear();
            return (0.0, HashMap::new());
        }

        let n = global_n as f64;
        let teleport = (1.0 - damping) / n;

        // Rank mass sitting on local vertices without outgoing edges.
        let dangling_sum: f64 = self
            .rank
            .iter()
            .zip(self.is_dangling.iter())
            .filter(|(_, dangling)| **dangling)
            .map(|(r, _)| *r)
            .sum();

        let base = teleport + damping * dangling_sum / n;
        self.next_rank.fill(base);

        let mut outbound: HashMap<u16, Vec<(String, f64)>> = HashMap::new();
        for u in 0..self.vertex_count {
            let deg = self.out_degrees[u];
            if deg == 0 {
                continue;
            }
            let contrib = damping * self.rank[u] / deg as f64;

            // Scatter to local edges.
            for dst in local_edge_iter(u as u32) {
                self.next_rank[dst as usize] += contrib;
            }

            // Scatter to boundary edges (cross-shard).
            if let Some(boundary) = self.boundary_edges.get(&(u as u32)) {
                for (dst_name, target_shard) in boundary {
                    outbound
                        .entry(*target_shard)
                        .or_default()
                        .push((dst_name.clone(), contrib));
                }
            }
        }

        let delta: f64 = self
            .rank
            .iter()
            .zip(self.next_rank.iter())
            .map(|(old, new)| (old - new).abs())
            .sum();

        std::mem::swap(&mut self.rank, &mut self.next_rank);
        self.incoming_contributions.clear();

        (delta, outbound)
    }

    /// Folds the buffered remote contributions into the current ranks.
    ///
    /// Intended to run between supersteps, once the contributions emitted by
    /// the other shards have been delivered through
    /// [`add_remote_contribution`](Self::add_remote_contribution).
    ///
    /// The values are added to `rank`, not `next_rank`: after a superstep
    /// `next_rank` only holds the stale previous iteration and is overwritten
    /// at the start of the next one, so anything written there would be lost.
    ///
    /// `node_id_to_local` maps a vertex name to its local id; names it cannot
    /// resolve (`None`) are skipped. The buffer itself is left intact and is
    /// emptied by the next [`superstep`](Self::superstep), so calling this
    /// twice in a row applies the same contributions twice.
    ///
    /// # Panics
    ///
    /// Panics if `node_id_to_local` returns an id outside `0..rank.len()`.
    pub fn apply_incoming_contributions(&mut self, node_id_to_local: &dyn Fn(&str) -> Option<u32>) {
        for (vertex_name, contrib) in &self.incoming_contributions {
            if let Some(local_id) = node_id_to_local(vertex_name) {
                self.rank[local_id as usize] += contrib;
            }
        }
    }

    /// Buffers a contribution sent by another shard for `vertex_name`.
    ///
    /// Multiple contributions for the same vertex are summed.
    pub fn add_remote_contribution(&mut self, vertex_name: String, value: f64) {
        *self
            .incoming_contributions
            .entry(vertex_name)
            .or_insert(0.0) += value;
    }
}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing ranks computed with a handful of f64 operations.
    const EPS: f64 = 1e-12;

    #[test]
    fn shard_state_init() {
        let state = ShardPageRankState::init(3, vec![2, 1, 0], |_| None, &|_node| Vec::new());
        assert_eq!(state.vertex_count, 3);
        assert_eq!(state.is_dangling, vec![false, false, true]);
        // Ranks start uniform; the scratch buffer starts zeroed.
        assert!(state.rank.iter().all(|r| (r - 1.0 / 3.0).abs() < EPS));
        assert_eq!(state.next_rank, vec![0.0; 3]);
        assert!(state.boundary_edges.is_empty());
        assert!(state.incoming_contributions.is_empty());
    }

    #[test]
    fn shard_state_with_ghost_edges() {
        let state = ShardPageRankState::init(
            2,
            vec![2, 1],
            |node| if node == "remote" { Some(5) } else { None },
            &|node| {
                if node == 0 {
                    vec![("remote".into(), true, 5)]
                } else {
                    Vec::new()
                }
            },
        );
        assert_eq!(state.boundary_edges.len(), 1);
        assert_eq!(state.boundary_edges[&0].len(), 1);
        assert_eq!(state.boundary_edges[&0][0].0, "remote");
        assert_eq!(state.boundary_edges[&0][0].1, 5);
    }

    #[test]
    fn shard_superstep_local_only() {
        let mut state = ShardPageRankState::init(3, vec![1, 1, 1], |_| None, &|_| Vec::new());
        let (delta, outbound) = state.superstep(0.85, 3, &|node| match node {
            0 => vec![1],
            1 => vec![2],
            2 => vec![0],
            _ => Vec::new(),
        });
        assert!(outbound.is_empty());
        // A directed 3-cycle with uniform ranks is already at the fixed point,
        // so nothing may move. This also rejects NaN, unlike `delta >= 0.0`
        // alone, which is nearly vacuous for a sum of absolute values.
        assert!(delta.abs() < EPS);
        assert!(state.rank.iter().all(|r| (r - 1.0 / 3.0).abs() < EPS));
        let sum: f64 = state.rank.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn shard_superstep_scatters_boundary_edges() {
        // Vertex 0 has one local edge (-> 1) and one ghost edge (-> "remote"
        // on shard 5); vertex 1 points back at vertex 0.
        let mut state = ShardPageRankState::init(2, vec![2, 1], |_| None, &|node| {
            if node == 0 {
                vec![("remote".into(), true, 5)]
            } else {
                Vec::new()
            }
        });
        let (_delta, outbound) = state.superstep(0.85, 3, &|node| match node {
            0 => vec![1],
            1 => vec![0],
            _ => Vec::new(),
        });

        // Vertex 0 splits 0.85 * 0.5 across its two edges.
        assert_eq!(outbound.len(), 1);
        assert_eq!(outbound[&5].len(), 1);
        assert_eq!(outbound[&5][0].0, "remote");
        assert!((outbound[&5][0].1 - 0.2125).abs() < EPS);

        // base = 0.15 / 3 = 0.05; vertex 0 receives 0.425, vertex 1 receives 0.2125.
        assert!((state.rank[0] - 0.475).abs() < EPS);
        assert!((state.rank[1] - 0.2625).abs() < EPS);
    }

    #[test]
    fn remote_contribution_accumulation() {
        let mut state = ShardPageRankState::init(2, vec![1, 0], |_| None, &|_| Vec::new());
        state.add_remote_contribution("n0".into(), 0.1);
        state.add_remote_contribution("n0".into(), 0.2);
        state.add_remote_contribution("n1".into(), 0.3);
        assert_eq!(state.incoming_contributions.len(), 2);
        assert!((state.incoming_contributions["n0"] - 0.3).abs() < 1e-10);
        assert!((state.incoming_contributions["n1"] - 0.3).abs() < 1e-10);
    }
}