Skip to main content

nodedb_cluster/distributed_graph/
coordinator.rs

//! BSP coordinator for distributed graph algorithms.
//!
//! Runs on the Control Plane. Tracks which shards have acknowledged (ACKed)
//! each superstep and aggregates their convergence metrics.
5
6use std::collections::HashMap;
7
8use super::types::{AlgoComplete, SuperstepAck, SuperstepBarrier};
9
10#[derive(Debug)]
11pub struct BspCoordinator {
12    pub algorithm: String,
13    pub iteration: u32,
14    pub max_iterations: u32,
15    pub tolerance: f64,
16    pub shard_ids: Vec<u16>,
17    pub acks: HashMap<u16, SuperstepAck>,
18    pub completed: bool,
19}
20
21impl BspCoordinator {
22    pub fn new(
23        algorithm: String,
24        max_iterations: u32,
25        tolerance: f64,
26        shard_ids: Vec<u16>,
27    ) -> Self {
28        Self {
29            algorithm,
30            iteration: 0,
31            max_iterations,
32            tolerance,
33            shard_ids,
34            acks: HashMap::new(),
35            completed: false,
36        }
37    }
38
39    pub fn record_ack(&mut self, ack: SuperstepAck) {
40        self.acks.insert(ack.shard_id, ack);
41    }
42
43    pub fn all_acked(&self) -> bool {
44        self.shard_ids.iter().all(|id| self.acks.contains_key(id))
45    }
46
47    /// Sum of per-shard convergence deltas. Only meaningful when `all_acked()`.
48    pub fn global_delta(&self) -> f64 {
49        debug_assert!(
50            self.all_acked(),
51            "global_delta called before all shards ACKed"
52        );
53        self.acks.values().map(|ack| ack.local_delta).sum()
54    }
55
56    /// Total vertex count across all shards. Only meaningful when `all_acked()`.
57    pub fn total_vertices(&self) -> usize {
58        debug_assert!(
59            self.all_acked(),
60            "total_vertices called before all shards ACKed"
61        );
62        self.acks.values().map(|ack| ack.vertex_count).sum()
63    }
64
65    /// Advance to next superstep. Returns `true` if should continue.
66    pub fn advance(&mut self) -> bool {
67        let delta = self.global_delta();
68        self.iteration += 1;
69        self.acks.clear();
70
71        if delta < self.tolerance || self.iteration >= self.max_iterations {
72            self.completed = true;
73            return false;
74        }
75        true
76    }
77
78    pub fn barrier_message(&self) -> SuperstepBarrier {
79        SuperstepBarrier {
80            algorithm: self.algorithm.clone(),
81            iteration: self.iteration + 1,
82            max_iterations: self.max_iterations,
83            params: String::new(),
84        }
85    }
86
87    pub fn completion_message(&self) -> AlgoComplete {
88        AlgoComplete {
89            iterations: self.iteration,
90            converged: self.global_delta() < self.tolerance,
91            final_delta: self.global_delta(),
92        }
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Three shards converge on the second superstep.
    #[test]
    fn coordinator_convergence() {
        // Every shard reports the same counters; only iteration and delta vary.
        let ack_for = |shard_id, iteration, local_delta| SuperstepAck {
            shard_id,
            iteration,
            local_delta,
            vertex_count: 100,
            contributions_sent: 10,
        };
        let mut coord = BspCoordinator::new("pagerank".into(), 20, 1e-6, vec![0, 1, 2]);

        // Superstep 1: summed delta of 0.9 is far above tolerance, so keep going.
        (0..3u16).for_each(|shard| coord.record_ack(ack_for(shard, 1, 0.3)));
        assert!(coord.all_acked());
        assert!((coord.global_delta() - 0.9).abs() < 1e-10);
        assert!(coord.advance());

        // Superstep 2: summed delta of 3e-8 is below 1e-6, so the run completes.
        (0..3u16).for_each(|shard| coord.record_ack(ack_for(shard, 2, 1e-8)));
        assert!(!coord.advance());
        assert!(coord.completed);
    }

    /// A single shard never converges, so the iteration cap ends the run.
    #[test]
    fn coordinator_max_iterations() {
        let single_shard_ack = |iteration, local_delta| SuperstepAck {
            shard_id: 0,
            iteration,
            local_delta,
            vertex_count: 10,
            contributions_sent: 0,
        };
        let mut coord = BspCoordinator::new("pagerank".into(), 2, 1e-10, vec![0]);

        coord.record_ack(single_shard_ack(1, 1.0));
        assert!(coord.advance());

        coord.record_ack(single_shard_ack(2, 0.5));
        assert!(!coord.advance());
        assert!(coord.completed);
    }
}