Skip to main content

nodedb_cluster/distributed_graph/
coordinator.rs

// SPDX-License-Identifier: BUSL-1.1

//! BSP coordinator for distributed graph algorithms.
//!
//! Runs on the Control Plane. Tracks which shards have completed each
//! superstep and aggregates convergence metrics.
7
8use std::collections::HashMap;
9
10use super::types::{AlgoComplete, SuperstepAck, SuperstepBarrier};
11
/// Coordinator-side state for one distributed BSP algorithm run.
///
/// Tracks the current superstep, which shards have ACKed it, and the
/// limits (`max_iterations`, `tolerance`) that decide when the run ends.
#[derive(Debug)]
pub struct BspCoordinator {
    /// Algorithm name; copied into every barrier message sent to shards.
    pub algorithm: String,
    /// Number of supersteps closed so far via `advance()` (starts at 0).
    pub iteration: u32,
    /// Hard cap on supersteps; the run stops once `iteration` reaches it.
    pub max_iterations: u32,
    /// Convergence threshold compared against the summed per-shard delta.
    pub tolerance: f64,
    /// Shards that must ACK before a superstep counts as finished.
    pub shard_ids: Vec<u32>,
    /// ACKs recorded so far, keyed by the reporting shard's id.
    pub acks: HashMap<u32, SuperstepAck>,
    /// Set once the run has converged or reached `max_iterations`.
    pub completed: bool,
    /// Bitemporal system-time ordinal for this run. Stamped onto every
    /// `SuperstepBarrier` so all shards materialize the same historical
    /// topology. `None` means current state.
    pub system_as_of: Option<i64>,
}
26
27impl BspCoordinator {
28    pub fn new(
29        algorithm: String,
30        max_iterations: u32,
31        tolerance: f64,
32        shard_ids: Vec<u32>,
33    ) -> Self {
34        Self::new_as_of(algorithm, max_iterations, tolerance, shard_ids, None)
35    }
36
37    pub fn new_as_of(
38        algorithm: String,
39        max_iterations: u32,
40        tolerance: f64,
41        shard_ids: Vec<u32>,
42        system_as_of: Option<i64>,
43    ) -> Self {
44        Self {
45            algorithm,
46            iteration: 0,
47            max_iterations,
48            tolerance,
49            shard_ids,
50            acks: HashMap::new(),
51            completed: false,
52            system_as_of,
53        }
54    }
55
56    pub fn record_ack(&mut self, ack: SuperstepAck) {
57        self.acks.insert(ack.shard_id, ack);
58    }
59
60    pub fn all_acked(&self) -> bool {
61        self.shard_ids.iter().all(|id| self.acks.contains_key(id))
62    }
63
64    /// Sum of per-shard convergence deltas. Only meaningful when `all_acked()`.
65    pub fn global_delta(&self) -> f64 {
66        debug_assert!(
67            self.all_acked(),
68            "global_delta called before all shards ACKed"
69        );
70        self.acks.values().map(|ack| ack.local_delta).sum()
71    }
72
73    /// Total vertex count across all shards. Only meaningful when `all_acked()`.
74    pub fn total_vertices(&self) -> usize {
75        debug_assert!(
76            self.all_acked(),
77            "total_vertices called before all shards ACKed"
78        );
79        self.acks.values().map(|ack| ack.vertex_count).sum()
80    }
81
82    /// Advance to next superstep. Returns `true` if should continue.
83    pub fn advance(&mut self) -> bool {
84        let delta = self.global_delta();
85        self.iteration += 1;
86        self.acks.clear();
87
88        if delta < self.tolerance || self.iteration >= self.max_iterations {
89            self.completed = true;
90            return false;
91        }
92        true
93    }
94
95    pub fn barrier_message(&self) -> SuperstepBarrier {
96        SuperstepBarrier {
97            algorithm: self.algorithm.clone(),
98            iteration: self.iteration + 1,
99            max_iterations: self.max_iterations,
100            params: String::new(),
101            system_as_of: self.system_as_of,
102        }
103    }
104
105    pub fn completion_message(&self) -> AlgoComplete {
106        AlgoComplete {
107            iterations: self.iteration,
108            converged: self.global_delta() < self.tolerance,
109            final_delta: self.global_delta(),
110        }
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Three shards, two supersteps:
    /// - Superstep 1: the summed delta (3 × 0.3) keeps the run going.
    /// - Superstep 2: the summed delta (3 × 1e-8) drops below the 1e-6
    ///   tolerance and ends it.
    #[test]
    fn coordinator_convergence() {
        const SHARDS: [u32; 3] = [0, 1, 2];
        let make_ack = |shard_id, iteration, local_delta| SuperstepAck {
            shard_id,
            iteration,
            local_delta,
            vertex_count: 100,
            contributions_sent: 10,
        };
        let mut coord = BspCoordinator::new("pagerank".into(), 20, 1e-6, SHARDS.to_vec());

        for &shard in &SHARDS {
            coord.record_ack(make_ack(shard, 1, 0.3));
        }
        assert!(coord.all_acked());
        let first_delta = coord.global_delta();
        assert!((first_delta - 0.9).abs() < 1e-10);
        assert!(coord.advance());

        for &shard in &SHARDS {
            coord.record_ack(make_ack(shard, 2, 1e-8));
        }
        assert!(!coord.advance());
        assert!(coord.completed);
    }

    /// A single shard whose delta never reaches the (tiny) tolerance must
    /// still stop once `max_iterations` (2) supersteps have run.
    #[test]
    fn coordinator_max_iterations() {
        let make_ack = |iteration, local_delta| SuperstepAck {
            shard_id: 0,
            iteration,
            local_delta,
            vertex_count: 10,
            contributions_sent: 0,
        };
        let mut coord = BspCoordinator::new("pagerank".into(), 2, 1e-10, vec![0]);

        coord.record_ack(make_ack(1, 1.0));
        assert!(coord.advance(), "superstep 1 of 2 must continue");

        coord.record_ack(make_ack(2, 0.5));
        assert!(!coord.advance(), "reaching max_iterations must stop the run");
        assert!(coord.completed);
    }
}