// nodedb_cluster/distributed_graph/coordinator.rs
use std::collections::HashMap;

use super::types::{AlgoComplete, SuperstepAck, SuperstepBarrier};

10#[derive(Debug)]
11pub struct BspCoordinator {
12 pub algorithm: String,
13 pub iteration: u32,
14 pub max_iterations: u32,
15 pub tolerance: f64,
16 pub shard_ids: Vec<u16>,
17 pub acks: HashMap<u16, SuperstepAck>,
18 pub completed: bool,
19}
20
21impl BspCoordinator {
22 pub fn new(
23 algorithm: String,
24 max_iterations: u32,
25 tolerance: f64,
26 shard_ids: Vec<u16>,
27 ) -> Self {
28 Self {
29 algorithm,
30 iteration: 0,
31 max_iterations,
32 tolerance,
33 shard_ids,
34 acks: HashMap::new(),
35 completed: false,
36 }
37 }
38
39 pub fn record_ack(&mut self, ack: SuperstepAck) {
40 self.acks.insert(ack.shard_id, ack);
41 }
42
43 pub fn all_acked(&self) -> bool {
44 self.shard_ids.iter().all(|id| self.acks.contains_key(id))
45 }
46
47 pub fn global_delta(&self) -> f64 {
49 debug_assert!(
50 self.all_acked(),
51 "global_delta called before all shards ACKed"
52 );
53 self.acks.values().map(|ack| ack.local_delta).sum()
54 }
55
56 pub fn total_vertices(&self) -> usize {
58 debug_assert!(
59 self.all_acked(),
60 "total_vertices called before all shards ACKed"
61 );
62 self.acks.values().map(|ack| ack.vertex_count).sum()
63 }
64
65 pub fn advance(&mut self) -> bool {
67 let delta = self.global_delta();
68 self.iteration += 1;
69 self.acks.clear();
70
71 if delta < self.tolerance || self.iteration >= self.max_iterations {
72 self.completed = true;
73 return false;
74 }
75 true
76 }
77
78 pub fn barrier_message(&self) -> SuperstepBarrier {
79 SuperstepBarrier {
80 algorithm: self.algorithm.clone(),
81 iteration: self.iteration + 1,
82 max_iterations: self.max_iterations,
83 params: String::new(),
84 }
85 }
86
87 pub fn completion_message(&self) -> AlgoComplete {
88 AlgoComplete {
89 iterations: self.iteration,
90 converged: self.global_delta() < self.tolerance,
91 final_delta: self.global_delta(),
92 }
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Three shards report a large delta, then a tiny one: the run continues
    /// after the first superstep and stops as converged after the second.
    #[test]
    fn coordinator_convergence() {
        let shard_ids: Vec<u16> = vec![0, 1, 2];
        let mut coord = BspCoordinator::new("pagerank".into(), 20, 1e-6, shard_ids);

        // Every shard sends the same payload apart from its id.
        let uniform_ack = |shard_id, iteration, local_delta| SuperstepAck {
            shard_id,
            iteration,
            local_delta,
            vertex_count: 100,
            contributions_sent: 10,
        };

        // Superstep 1: summed delta of 0.9 is far above the tolerance.
        (0..3u16).for_each(|shard| coord.record_ack(uniform_ack(shard, 1, 0.3)));
        assert!(coord.all_acked());
        assert!((coord.global_delta() - 0.9).abs() < 1e-10);
        assert!(coord.advance());

        // Superstep 2: summed delta of 3e-8 is below the tolerance.
        (0..3u16).for_each(|shard| coord.record_ack(uniform_ack(shard, 2, 1e-8)));
        assert!(!coord.advance());
        assert!(coord.completed);
    }

    /// With a cap of two iterations and an unreachable tolerance, the run
    /// stops on the cap rather than on convergence.
    #[test]
    fn coordinator_max_iterations() {
        let mut coord = BspCoordinator::new("pagerank".into(), 2, 1e-10, vec![0]);

        let single_shard_ack = |iteration, local_delta| SuperstepAck {
            shard_id: 0,
            iteration,
            local_delta,
            vertex_count: 10,
            contributions_sent: 0,
        };

        coord.record_ack(single_shard_ack(1, 1.0));
        assert!(coord.advance());

        coord.record_ack(single_shard_ack(2, 0.5));
        assert!(!coord.advance());
        assert!(coord.completed);
    }
}
153}