phago_distributed/coordinator/
tick_barrier.rs1use crate::types::*;
8use phago_core::types::Tick;
9use std::collections::HashSet;
10use tokio::sync::{Mutex, Notify};
11
12pub struct TickBarrier {
23 shard_count: Mutex<usize>,
25 completed: Mutex<HashSet<(ShardId, TickPhase, Tick)>>,
27 notify: Notify,
29 phase_timeout_secs: u64,
31}
32
33impl TickBarrier {
34 pub fn new(shard_count: usize) -> Self {
36 Self {
37 shard_count: Mutex::new(shard_count),
38 completed: Mutex::new(HashSet::new()),
39 notify: Notify::new(),
40 phase_timeout_secs: 30,
41 }
42 }
43
44 pub fn with_timeout(shard_count: usize, timeout_secs: u64) -> Self {
46 Self {
47 shard_count: Mutex::new(shard_count),
48 completed: Mutex::new(HashSet::new()),
49 notify: Notify::new(),
50 phase_timeout_secs: timeout_secs,
51 }
52 }
53
54 pub async fn set_shard_count(&self, count: usize) {
58 let mut sc = self.shard_count.lock().await;
59 *sc = count;
60 }
61
62 pub async fn shard_count(&self) -> usize {
64 *self.shard_count.lock().await
65 }
66
67 pub async fn complete(
75 &self,
76 shard_id: ShardId,
77 phase: TickPhase,
78 tick: Tick,
79 ) -> DistributedResult<()> {
80 let mut completed = self.completed.lock().await;
81 completed.insert((shard_id, phase, tick));
82 drop(completed);
83
84 self.notify.notify_waiters();
86
87 Ok(())
88 }
89
90 pub async fn is_complete(&self, shard_id: ShardId, phase: TickPhase, tick: Tick) -> bool {
92 let completed = self.completed.lock().await;
93 completed.contains(&(shard_id, phase, tick))
94 }
95
96 pub async fn completed_count(&self, phase: TickPhase, tick: Tick) -> usize {
98 let completed = self.completed.lock().await;
99 completed
100 .iter()
101 .filter(|(_, p, t)| *p == phase && *t == tick)
102 .count()
103 }
104
105 pub async fn wait_all(&self, phase: TickPhase, tick: Tick) -> DistributedResult<()> {
120 let timeout = tokio::time::Duration::from_secs(self.phase_timeout_secs);
121
122 loop {
123 {
125 let completed = self.completed.lock().await;
126 let shard_count = *self.shard_count.lock().await;
127
128 let count = completed
129 .iter()
130 .filter(|(_, p, t)| *p == phase && *t == tick)
131 .count();
132
133 if count >= shard_count && shard_count > 0 {
134 return Ok(());
135 }
136 }
137
138 tokio::select! {
140 _ = self.notify.notified() => {
141 continue;
143 }
144 _ = tokio::time::sleep(timeout) => {
145 return Err(DistributedError::PhaseTimeout(phase));
146 }
147 }
148 }
149 }
150
151 pub async fn wait_all_with_timeout(
153 &self,
154 phase: TickPhase,
155 tick: Tick,
156 timeout: tokio::time::Duration,
157 ) -> DistributedResult<()> {
158 loop {
159 {
160 let completed = self.completed.lock().await;
161 let shard_count = *self.shard_count.lock().await;
162
163 let count = completed
164 .iter()
165 .filter(|(_, p, t)| *p == phase && *t == tick)
166 .count();
167
168 if count >= shard_count && shard_count > 0 {
169 return Ok(());
170 }
171 }
172
173 tokio::select! {
174 _ = self.notify.notified() => continue,
175 _ = tokio::time::sleep(timeout) => {
176 return Err(DistributedError::PhaseTimeout(phase));
177 }
178 }
179 }
180 }
181
182 pub async fn reset_for_tick(&self, _tick: Tick) {
187 let mut completed = self.completed.lock().await;
188 completed.clear();
189 }
190
191 pub async fn get_completed_shards(&self, phase: TickPhase, tick: Tick) -> Vec<ShardId> {
193 let completed = self.completed.lock().await;
194 completed
195 .iter()
196 .filter(|(_, p, t)| *p == phase && *t == tick)
197 .map(|(s, _, _)| *s)
198 .collect()
199 }
200
201 pub async fn get_pending_shards(
203 &self,
204 phase: TickPhase,
205 tick: Tick,
206 all_shards: &[ShardId],
207 ) -> Vec<ShardId> {
208 let completed = self.completed.lock().await;
209 let completed_set: HashSet<_> = completed
210 .iter()
211 .filter(|(_, p, t)| *p == phase && *t == tick)
212 .map(|(s, _, _)| *s)
213 .collect();
214
215 all_shards
216 .iter()
217 .filter(|s| !completed_set.contains(s))
218 .copied()
219 .collect()
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    /// Reports the Sense phase of `tick` as finished for `shard`, panicking on error.
    async fn finish_sense(barrier: &TickBarrier, shard: ShardId, tick: Tick) {
        barrier
            .complete(shard, TickPhase::Sense, tick)
            .await
            .unwrap();
    }

    /// Whether `shard` has been recorded for the Sense phase of `tick`.
    async fn sense_done(barrier: &TickBarrier, shard: ShardId, tick: Tick) -> bool {
        barrier.is_complete(shard, TickPhase::Sense, tick).await
    }

    #[tokio::test]
    async fn test_barrier_creation() {
        let expected = 3;
        let barrier = TickBarrier::new(expected);
        assert_eq!(barrier.shard_count().await, expected);
    }

    #[tokio::test]
    async fn test_phase_completion() {
        let barrier = TickBarrier::new(2);
        let first = ShardId::new(0);
        let second = ShardId::new(1);

        // Only the first shard has reported so far.
        finish_sense(&barrier, first, 1).await;
        assert!(sense_done(&barrier, first, 1).await);
        assert!(!sense_done(&barrier, second, 1).await);

        // Now the second shard reports as well.
        finish_sense(&barrier, second, 1).await;
        assert!(sense_done(&barrier, second, 1).await);
    }

    #[tokio::test]
    async fn test_wait_all_completes() {
        let barrier = Arc::new(TickBarrier::with_timeout(2, 5));
        let reporter = Arc::clone(&barrier);

        // Both shards report from a background task after a short delay.
        tokio::spawn(async move {
            sleep(Duration::from_millis(10)).await;
            finish_sense(&reporter, ShardId::new(0), 1).await;
            finish_sense(&reporter, ShardId::new(1), 1).await;
        });

        barrier.wait_all(TickPhase::Sense, 1).await.unwrap();
    }

    #[tokio::test]
    async fn test_reset_for_tick() {
        let barrier = TickBarrier::new(1);
        let shard = ShardId::new(0);

        finish_sense(&barrier, shard, 1).await;
        assert!(sense_done(&barrier, shard, 1).await);

        // Resetting drops the completion recorded for tick 1.
        barrier.reset_for_tick(2).await;
        assert!(!sense_done(&barrier, shard, 1).await);
    }

    #[tokio::test]
    async fn test_completed_count() {
        let barrier = TickBarrier::new(3);
        assert_eq!(barrier.completed_count(TickPhase::Sense, 1).await, 0);

        finish_sense(&barrier, ShardId::new(0), 1).await;
        assert_eq!(barrier.completed_count(TickPhase::Sense, 1).await, 1);

        finish_sense(&barrier, ShardId::new(1), 1).await;
        assert_eq!(barrier.completed_count(TickPhase::Sense, 1).await, 2);
    }

    #[tokio::test]
    async fn test_update_shard_count() {
        let barrier = TickBarrier::new(2);
        assert_eq!(barrier.shard_count().await, 2);

        barrier.set_shard_count(5).await;
        assert_eq!(barrier.shard_count().await, 5);
    }
}