Skip to main content

phago_distributed/coordinator/
tick_barrier.rs

1//! Tick barrier for phase synchronization.
2//!
3//! This module implements a barrier mechanism that ensures all shards
4//! complete each phase of a tick before any shard proceeds to the next phase.
5//! This is essential for maintaining consistency in the distributed colony.
6
7use crate::types::*;
8use phago_core::types::Tick;
9use std::collections::HashSet;
10use tokio::sync::{Mutex, Notify};
11
12/// Barrier ensuring all shards complete a phase before any proceeds.
13///
14/// The tick barrier coordinates the phases within each simulation tick:
15/// 1. Sense - agents sense the substrate (read-only phase)
16/// 2. Act - process agent actions (write phase)
17/// 3. Decay - decay signals, traces, and edges (maintenance phase)
18/// 4. Advance - advance tick counter (finalization phase)
19///
20/// Each shard must signal completion of each phase, and all shards must
21/// complete before any can proceed to the next phase.
22pub struct TickBarrier {
23    /// Number of shards expected to participate.
24    shard_count: Mutex<usize>,
25    /// Set of (shard, phase, tick) tuples that have completed.
26    completed: Mutex<HashSet<(ShardId, TickPhase, Tick)>>,
27    /// Notification channel for waiters.
28    notify: Notify,
29    /// Default timeout for phase completion in seconds.
30    phase_timeout_secs: u64,
31}
32
33impl TickBarrier {
34    /// Create a new tick barrier for the specified number of shards.
35    pub fn new(shard_count: usize) -> Self {
36        Self {
37            shard_count: Mutex::new(shard_count),
38            completed: Mutex::new(HashSet::new()),
39            notify: Notify::new(),
40            phase_timeout_secs: 30,
41        }
42    }
43
44    /// Create a tick barrier with custom timeout.
45    pub fn with_timeout(shard_count: usize, timeout_secs: u64) -> Self {
46        Self {
47            shard_count: Mutex::new(shard_count),
48            completed: Mutex::new(HashSet::new()),
49            notify: Notify::new(),
50            phase_timeout_secs: timeout_secs,
51        }
52    }
53
54    /// Update the expected shard count.
55    ///
56    /// This should be called when shards are added or removed from the cluster.
57    pub async fn set_shard_count(&self, count: usize) {
58        let mut sc = self.shard_count.lock().await;
59        *sc = count;
60    }
61
62    /// Get the current shard count.
63    pub async fn shard_count(&self) -> usize {
64        *self.shard_count.lock().await
65    }
66
67    /// Mark a shard as having completed a phase.
68    ///
69    /// # Arguments
70    ///
71    /// * `shard_id` - The shard that completed
72    /// * `phase` - The phase that was completed
73    /// * `tick` - The tick number
74    pub async fn complete(
75        &self,
76        shard_id: ShardId,
77        phase: TickPhase,
78        tick: Tick,
79    ) -> DistributedResult<()> {
80        let mut completed = self.completed.lock().await;
81        completed.insert((shard_id, phase, tick));
82        drop(completed);
83
84        // Notify all waiters that progress was made
85        self.notify.notify_waiters();
86
87        Ok(())
88    }
89
90    /// Check if a specific shard has completed a phase.
91    pub async fn is_complete(&self, shard_id: ShardId, phase: TickPhase, tick: Tick) -> bool {
92        let completed = self.completed.lock().await;
93        completed.contains(&(shard_id, phase, tick))
94    }
95
96    /// Get the number of shards that have completed a phase.
97    pub async fn completed_count(&self, phase: TickPhase, tick: Tick) -> usize {
98        let completed = self.completed.lock().await;
99        completed
100            .iter()
101            .filter(|(_, p, t)| *p == phase && *t == tick)
102            .count()
103    }
104
105    /// Wait for all shards to complete a phase.
106    ///
107    /// This will block until all registered shards have signaled completion
108    /// of the specified phase, or until the timeout is reached.
109    ///
110    /// # Arguments
111    ///
112    /// * `phase` - The phase to wait for
113    /// * `tick` - The tick number
114    ///
115    /// # Errors
116    ///
117    /// Returns `DistributedError::PhaseTimeout` if the timeout is reached
118    /// before all shards complete.
119    pub async fn wait_all(&self, phase: TickPhase, tick: Tick) -> DistributedResult<()> {
120        let timeout = tokio::time::Duration::from_secs(self.phase_timeout_secs);
121
122        loop {
123            // Check if all shards have completed
124            {
125                let completed = self.completed.lock().await;
126                let shard_count = *self.shard_count.lock().await;
127
128                let count = completed
129                    .iter()
130                    .filter(|(_, p, t)| *p == phase && *t == tick)
131                    .count();
132
133                if count >= shard_count && shard_count > 0 {
134                    return Ok(());
135                }
136            }
137
138            // Wait for notification or timeout
139            tokio::select! {
140                _ = self.notify.notified() => {
141                    // A shard completed, loop to check if all are done
142                    continue;
143                }
144                _ = tokio::time::sleep(timeout) => {
145                    return Err(DistributedError::PhaseTimeout(phase));
146                }
147            }
148        }
149    }
150
151    /// Wait for all shards with custom timeout.
152    pub async fn wait_all_with_timeout(
153        &self,
154        phase: TickPhase,
155        tick: Tick,
156        timeout: tokio::time::Duration,
157    ) -> DistributedResult<()> {
158        loop {
159            {
160                let completed = self.completed.lock().await;
161                let shard_count = *self.shard_count.lock().await;
162
163                let count = completed
164                    .iter()
165                    .filter(|(_, p, t)| *p == phase && *t == tick)
166                    .count();
167
168                if count >= shard_count && shard_count > 0 {
169                    return Ok(());
170                }
171            }
172
173            tokio::select! {
174                _ = self.notify.notified() => continue,
175                _ = tokio::time::sleep(timeout) => {
176                    return Err(DistributedError::PhaseTimeout(phase));
177                }
178            }
179        }
180    }
181
182    /// Reset the barrier for a new tick.
183    ///
184    /// This clears all completion records. Should be called before
185    /// starting a new tick.
186    pub async fn reset_for_tick(&self, _tick: Tick) {
187        let mut completed = self.completed.lock().await;
188        completed.clear();
189    }
190
191    /// Get all shards that have completed a specific phase.
192    pub async fn get_completed_shards(&self, phase: TickPhase, tick: Tick) -> Vec<ShardId> {
193        let completed = self.completed.lock().await;
194        completed
195            .iter()
196            .filter(|(_, p, t)| *p == phase && *t == tick)
197            .map(|(s, _, _)| *s)
198            .collect()
199    }
200
201    /// Get all shards that have NOT completed a specific phase.
202    pub async fn get_pending_shards(
203        &self,
204        phase: TickPhase,
205        tick: Tick,
206        all_shards: &[ShardId],
207    ) -> Vec<ShardId> {
208        let completed = self.completed.lock().await;
209        let completed_set: HashSet<_> = completed
210            .iter()
211            .filter(|(_, p, t)| *p == phase && *t == tick)
212            .map(|(s, _, _)| *s)
213            .collect();
214
215        all_shards
216            .iter()
217            .filter(|s| !completed_set.contains(s))
218            .copied()
219            .collect()
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[tokio::test]
    async fn test_barrier_creation() {
        let barrier = TickBarrier::new(3);
        assert_eq!(barrier.shard_count().await, 3);
    }

    #[tokio::test]
    async fn test_phase_completion() {
        let barrier = TickBarrier::new(2);

        // Complete phase for shard 0
        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert!(
            barrier
                .is_complete(ShardId::new(0), TickPhase::Sense, 1)
                .await
        );
        assert!(
            !barrier
                .is_complete(ShardId::new(1), TickPhase::Sense, 1)
                .await
        );

        // Complete phase for shard 1
        barrier
            .complete(ShardId::new(1), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert!(
            barrier
                .is_complete(ShardId::new(1), TickPhase::Sense, 1)
                .await
        );
    }

    #[tokio::test]
    async fn test_wait_all_completes() {
        let barrier = Arc::new(TickBarrier::with_timeout(2, 5));
        let barrier_ref = Arc::clone(&barrier);

        // Complete both shards from another task after a short delay.
        tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            barrier_ref
                .complete(ShardId::new(0), TickPhase::Sense, 1)
                .await
                .unwrap();
            barrier_ref
                .complete(ShardId::new(1), TickPhase::Sense, 1)
                .await
                .unwrap();
        });

        // Wait should succeed
        barrier.wait_all(TickPhase::Sense, 1).await.unwrap();
    }

    #[tokio::test]
    async fn test_wait_all_times_out_when_shards_missing() {
        // Zero-second timeout: the wait must give up rather than hang.
        let barrier = TickBarrier::with_timeout(2, 0);
        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();

        let result = barrier.wait_all(TickPhase::Sense, 1).await;
        assert!(matches!(result, Err(DistributedError::PhaseTimeout(_))));
    }

    #[tokio::test]
    async fn test_wait_all_with_timeout_already_complete() {
        let barrier = TickBarrier::new(1);
        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();

        barrier
            .wait_all_with_timeout(
                TickPhase::Sense,
                1,
                tokio::time::Duration::from_millis(50),
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_reset_for_tick() {
        let barrier = TickBarrier::new(1);

        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert!(
            barrier
                .is_complete(ShardId::new(0), TickPhase::Sense, 1)
                .await
        );

        barrier.reset_for_tick(2).await;
        assert!(
            !barrier
                .is_complete(ShardId::new(0), TickPhase::Sense, 1)
                .await
        );
    }

    #[tokio::test]
    async fn test_completed_count() {
        let barrier = TickBarrier::new(3);

        assert_eq!(barrier.completed_count(TickPhase::Sense, 1).await, 0);

        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert_eq!(barrier.completed_count(TickPhase::Sense, 1).await, 1);

        barrier
            .complete(ShardId::new(1), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert_eq!(barrier.completed_count(TickPhase::Sense, 1).await, 2);
    }

    #[tokio::test]
    async fn test_completed_and_pending_shards_are_scoped_by_tick() {
        let barrier = TickBarrier::new(2);
        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();
        // Same phase, different tick: must not count toward tick 1.
        barrier
            .complete(ShardId::new(1), TickPhase::Sense, 2)
            .await
            .unwrap();

        let done = barrier.get_completed_shards(TickPhase::Sense, 1).await;
        assert_eq!(done.len(), 1);
        assert!(done[0] == ShardId::new(0));

        let all = [ShardId::new(0), ShardId::new(1)];
        let pending = barrier.get_pending_shards(TickPhase::Sense, 1, &all).await;
        assert_eq!(pending.len(), 1);
        assert!(pending[0] == ShardId::new(1));
    }

    #[tokio::test]
    async fn test_update_shard_count() {
        let barrier = TickBarrier::new(2);
        assert_eq!(barrier.shard_count().await, 2);

        barrier.set_shard_count(5).await;
        assert_eq!(barrier.shard_count().await, 5);
    }
}