nodedb_cluster/distributed_graph/
coordinator.rs1use std::collections::HashMap;
9
10use super::types::{AlgoComplete, SuperstepAck, SuperstepBarrier};
11
12#[derive(Debug)]
13pub struct BspCoordinator {
14 pub algorithm: String,
15 pub iteration: u32,
16 pub max_iterations: u32,
17 pub tolerance: f64,
18 pub shard_ids: Vec<u32>,
19 pub acks: HashMap<u32, SuperstepAck>,
20 pub completed: bool,
21 pub system_as_of: Option<i64>,
25}
26
27impl BspCoordinator {
28 pub fn new(
29 algorithm: String,
30 max_iterations: u32,
31 tolerance: f64,
32 shard_ids: Vec<u32>,
33 ) -> Self {
34 Self::new_as_of(algorithm, max_iterations, tolerance, shard_ids, None)
35 }
36
37 pub fn new_as_of(
38 algorithm: String,
39 max_iterations: u32,
40 tolerance: f64,
41 shard_ids: Vec<u32>,
42 system_as_of: Option<i64>,
43 ) -> Self {
44 Self {
45 algorithm,
46 iteration: 0,
47 max_iterations,
48 tolerance,
49 shard_ids,
50 acks: HashMap::new(),
51 completed: false,
52 system_as_of,
53 }
54 }
55
56 pub fn record_ack(&mut self, ack: SuperstepAck) {
57 self.acks.insert(ack.shard_id, ack);
58 }
59
60 pub fn all_acked(&self) -> bool {
61 self.shard_ids.iter().all(|id| self.acks.contains_key(id))
62 }
63
64 pub fn global_delta(&self) -> f64 {
66 debug_assert!(
67 self.all_acked(),
68 "global_delta called before all shards ACKed"
69 );
70 self.acks.values().map(|ack| ack.local_delta).sum()
71 }
72
73 pub fn total_vertices(&self) -> usize {
75 debug_assert!(
76 self.all_acked(),
77 "total_vertices called before all shards ACKed"
78 );
79 self.acks.values().map(|ack| ack.vertex_count).sum()
80 }
81
82 pub fn advance(&mut self) -> bool {
84 let delta = self.global_delta();
85 self.iteration += 1;
86 self.acks.clear();
87
88 if delta < self.tolerance || self.iteration >= self.max_iterations {
89 self.completed = true;
90 return false;
91 }
92 true
93 }
94
95 pub fn barrier_message(&self) -> SuperstepBarrier {
96 SuperstepBarrier {
97 algorithm: self.algorithm.clone(),
98 iteration: self.iteration + 1,
99 max_iterations: self.max_iterations,
100 params: String::new(),
101 system_as_of: self.system_as_of,
102 }
103 }
104
105 pub fn completion_message(&self) -> AlgoComplete {
106 AlgoComplete {
107 iterations: self.iteration,
108 converged: self.global_delta() < self.tolerance,
109 final_delta: self.global_delta(),
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Three shards converge on the second superstep.
    #[test]
    fn coordinator_convergence() {
        let mut coord = BspCoordinator::new("pagerank".into(), 20, 1e-6, vec![0, 1, 2]);
        // Every shard reports 100 vertices and 10 contributions; only the delta varies.
        let ack = |shard_id, iteration, local_delta| SuperstepAck {
            shard_id,
            iteration,
            local_delta,
            vertex_count: 100,
            contributions_sent: 10,
        };

        // Superstep 1: summed delta 3 * 0.3 = 0.9 is far above tolerance, so the run continues.
        (0..3u32).for_each(|shard| coord.record_ack(ack(shard, 1, 0.3)));
        assert!(coord.all_acked());
        assert!((coord.global_delta() - 0.9).abs() < 1e-10);
        assert!(coord.advance());

        // Superstep 2: summed delta 3e-8 is below the 1e-6 tolerance, so the run stops.
        (0..3u32).for_each(|shard| coord.record_ack(ack(shard, 2, 1e-8)));
        assert!(!coord.advance());
        assert!(coord.completed);
    }

    /// A single shard that never converges is stopped by the iteration cap.
    #[test]
    fn coordinator_max_iterations() {
        let mut coord = BspCoordinator::new("pagerank".into(), 2, 1e-10, vec![0]);
        let ack = |iteration, local_delta| SuperstepAck {
            shard_id: 0,
            iteration,
            local_delta,
            vertex_count: 10,
            contributions_sent: 0,
        };

        // Deltas stay far above 1e-10, so only max_iterations = 2 can end the run.
        coord.record_ack(ack(1, 1.0));
        assert!(coord.advance());

        coord.record_ack(ack(2, 0.5));
        assert!(!coord.advance());
        assert!(coord.completed);
    }
}