use crate::types::*;
use phago_core::types::Tick;
use std::collections::HashSet;
use tokio::sync::{Mutex, Notify};
/// Synchronisation point that lets callers wait until every shard has reported
/// completion of a particular [`TickPhase`] of a particular [`Tick`].
///
/// Completions are stored as `(ShardId, TickPhase, Tick)` triples. A wait for
/// `(phase, tick)` succeeds once the number of recorded triples matching that
/// pair reaches the configured shard count (and that count is non-zero).
pub struct TickBarrier {
    /// How many shards must report before a `(phase, tick)` pair is complete.
    shard_count: Mutex<usize>,
    /// Every completion recorded since the last [`TickBarrier::reset_for_tick`].
    completed: Mutex<HashSet<(ShardId, TickPhase, Tick)>>,
    /// Signalled whenever a completion is recorded or the shard count changes.
    notify: Notify,
    /// Timeout, in seconds, used by [`TickBarrier::wait_all`].
    phase_timeout_secs: u64,
}

impl TickBarrier {
    /// Phase timeout, in seconds, used by [`TickBarrier::new`].
    const DEFAULT_PHASE_TIMEOUT_SECS: u64 = 30;

    /// Creates a barrier for `shard_count` shards with a 30-second phase timeout.
    pub fn new(shard_count: usize) -> Self {
        Self::with_timeout(shard_count, Self::DEFAULT_PHASE_TIMEOUT_SECS)
    }

    /// Creates a barrier for `shard_count` shards whose [`TickBarrier::wait_all`]
    /// gives up after `timeout_secs` seconds.
    pub fn with_timeout(shard_count: usize, timeout_secs: u64) -> Self {
        Self {
            shard_count: Mutex::new(shard_count),
            completed: Mutex::new(HashSet::new()),
            notify: Notify::new(),
            phase_timeout_secs: timeout_secs,
        }
    }

    /// Sets the number of shards a `(phase, tick)` pair must hear from.
    ///
    /// Current waiters are woken so they re-evaluate against the new count;
    /// lowering the count can therefore release a wait that is now satisfied.
    pub async fn set_shard_count(&self, count: usize) {
        // The guard is a temporary, released at the end of this statement,
        // before waiters are woken.
        *self.shard_count.lock().await = count;
        self.notify.notify_waiters();
    }

    /// Returns the number of shards a `(phase, tick)` pair must hear from.
    pub async fn shard_count(&self) -> usize {
        *self.shard_count.lock().await
    }

    /// Records that `shard_id` finished `phase` of `tick` and wakes all waiters.
    ///
    /// Reporting the same completion more than once has no additional effect.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub async fn complete(
        &self,
        shard_id: ShardId,
        phase: TickPhase,
        tick: Tick,
    ) -> DistributedResult<()> {
        // Temporary guard: dropped at the end of the statement, so woken waiters
        // do not immediately contend with us for the lock.
        self.completed.lock().await.insert((shard_id, phase, tick));
        self.notify.notify_waiters();
        Ok(())
    }

    /// Returns `true` if `shard_id` has reported completing `phase` of `tick`.
    pub async fn is_complete(&self, shard_id: ShardId, phase: TickPhase, tick: Tick) -> bool {
        self.completed.lock().await.contains(&(shard_id, phase, tick))
    }

    /// Returns how many shards have reported completing `phase` of `tick`.
    pub async fn completed_count(&self, phase: TickPhase, tick: Tick) -> usize {
        let completed = self.completed.lock().await;
        Self::shards_done(&completed, &phase, &tick).count()
    }

    /// Waits until every shard has completed `phase` of `tick`, using the
    /// barrier's configured phase timeout.
    ///
    /// # Errors
    ///
    /// Returns [`DistributedError::PhaseTimeout`] if the phase is not complete
    /// within the configured timeout.
    pub async fn wait_all(&self, phase: TickPhase, tick: Tick) -> DistributedResult<()> {
        let timeout = tokio::time::Duration::from_secs(self.phase_timeout_secs);
        self.wait_all_with_timeout(phase, tick, timeout).await
    }

    /// Waits until every shard has completed `phase` of `tick`, giving up after
    /// `timeout` in total (not per wakeup).
    ///
    /// A shard count of zero never satisfies the barrier, so such a wait always
    /// times out.
    ///
    /// # Errors
    ///
    /// Returns [`DistributedError::PhaseTimeout`] if the phase is not complete
    /// before `timeout` elapses.
    pub async fn wait_all_with_timeout(
        &self,
        phase: TickPhase,
        tick: Tick,
        timeout: tokio::time::Duration,
    ) -> DistributedResult<()> {
        let wait = async {
            loop {
                // Register for wakeups *before* checking the condition. Tokio
                // guarantees a `Notified` future observes `notify_waiters()` from
                // the moment it is created, even before it is first polled.
                // Because `notify_waiters` stores no permit, creating the future
                // after the check could miss a completion landing in between.
                let notified = self.notify.notified();
                if self.all_completed(&phase, &tick).await {
                    break;
                }
                notified.await;
            }
        };
        // A single overall deadline: re-arming a fresh sleep on every wakeup would
        // let unrelated notifications (other phases/ticks) extend the wait forever.
        let outcome = tokio::time::timeout(timeout, wait).await;
        outcome.map_err(|_| DistributedError::PhaseTimeout(phase))
    }

    /// Discards every recorded completion.
    ///
    /// The tick argument is currently ignored: completions for *all* ticks and
    /// phases are cleared, including any already recorded for that tick.
    /// NOTE(review): if a shard can report for the new tick before this is
    /// called, that report is lost — consider retaining entries for later ticks.
    pub async fn reset_for_tick(&self, _tick: Tick) {
        self.completed.lock().await.clear();
    }

    /// Returns the shards that have reported completing `phase` of `tick`, in
    /// unspecified order.
    pub async fn get_completed_shards(&self, phase: TickPhase, tick: Tick) -> Vec<ShardId> {
        let completed = self.completed.lock().await;
        Self::shards_done(&completed, &phase, &tick).collect()
    }

    /// Returns the members of `all_shards` that have not yet reported completing
    /// `phase` of `tick`, preserving the order of `all_shards`.
    pub async fn get_pending_shards(
        &self,
        phase: TickPhase,
        tick: Tick,
        all_shards: &[ShardId],
    ) -> Vec<ShardId> {
        // Snapshot the finished shards, then release the lock before filtering.
        let done: HashSet<ShardId> = {
            let completed = self.completed.lock().await;
            Self::shards_done(&completed, &phase, &tick).collect()
        };
        all_shards
            .iter()
            .copied()
            .filter(|shard| !done.contains(shard))
            .collect()
    }

    /// Returns `true` once the number of shards that completed `phase` of `tick`
    /// reaches the (non-zero) configured shard count.
    async fn all_completed(&self, phase: &TickPhase, tick: &Tick) -> bool {
        let required = *self.shard_count.lock().await;
        if required == 0 {
            return false;
        }
        let completed = self.completed.lock().await;
        Self::shards_done(&completed, phase, tick).count() >= required
    }

    /// Iterates over the shards recorded in `completed` for `(phase, tick)`.
    fn shards_done<'a>(
        completed: &'a HashSet<(ShardId, TickPhase, Tick)>,
        phase: &'a TickPhase,
        tick: &'a Tick,
    ) -> impl Iterator<Item = ShardId> + 'a {
        completed
            .iter()
            .filter(move |(_, p, t)| p == phase && t == tick)
            .map(|(shard, _, _)| *shard)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::Duration;

    #[tokio::test]
    async fn test_barrier_creation() {
        let barrier = TickBarrier::new(3);
        assert_eq!(barrier.shard_count().await, 3);
    }

    #[tokio::test]
    async fn test_phase_completion() {
        let barrier = TickBarrier::new(2);
        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert!(
            barrier
                .is_complete(ShardId::new(0), TickPhase::Sense, 1)
                .await
        );
        assert!(
            !barrier
                .is_complete(ShardId::new(1), TickPhase::Sense, 1)
                .await
        );
        barrier
            .complete(ShardId::new(1), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert!(
            barrier
                .is_complete(ShardId::new(1), TickPhase::Sense, 1)
                .await
        );
    }

    #[tokio::test]
    async fn test_wait_all_completes() {
        let barrier = Arc::new(TickBarrier::with_timeout(2, 5));
        let reporter = Arc::clone(&barrier);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            reporter
                .complete(ShardId::new(0), TickPhase::Sense, 1)
                .await
                .unwrap();
            reporter
                .complete(ShardId::new(1), TickPhase::Sense, 1)
                .await
                .unwrap();
        });
        barrier.wait_all(TickPhase::Sense, 1).await.unwrap();
    }

    #[tokio::test]
    async fn test_wait_all_returns_when_already_complete() {
        let barrier = TickBarrier::with_timeout(2, 5);
        for id in 0..2 {
            barrier
                .complete(ShardId::new(id), TickPhase::Sense, 1)
                .await
                .unwrap();
        }
        barrier.wait_all(TickPhase::Sense, 1).await.unwrap();
    }

    #[tokio::test]
    async fn test_wait_times_out_when_shard_missing() {
        let barrier = TickBarrier::new(2);
        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();
        let result = barrier
            .wait_all_with_timeout(TickPhase::Sense, 1, Duration::from_millis(20))
            .await;
        assert!(matches!(result, Err(DistributedError::PhaseTimeout(_))));
    }

    #[tokio::test]
    async fn test_reset_for_tick() {
        let barrier = TickBarrier::new(1);
        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert!(
            barrier
                .is_complete(ShardId::new(0), TickPhase::Sense, 1)
                .await
        );
        barrier.reset_for_tick(2).await;
        assert!(
            !barrier
                .is_complete(ShardId::new(0), TickPhase::Sense, 1)
                .await
        );
    }

    #[tokio::test]
    async fn test_completed_count() {
        let barrier = TickBarrier::new(3);
        assert_eq!(barrier.completed_count(TickPhase::Sense, 1).await, 0);
        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert_eq!(barrier.completed_count(TickPhase::Sense, 1).await, 1);
        barrier
            .complete(ShardId::new(1), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert_eq!(barrier.completed_count(TickPhase::Sense, 1).await, 2);
    }

    #[tokio::test]
    async fn test_completions_are_scoped_to_tick() {
        let barrier = TickBarrier::new(1);
        barrier
            .complete(ShardId::new(0), TickPhase::Sense, 1)
            .await
            .unwrap();
        assert_eq!(barrier.completed_count(TickPhase::Sense, 2).await, 0);
        assert!(
            !barrier
                .is_complete(ShardId::new(0), TickPhase::Sense, 2)
                .await
        );
    }

    #[tokio::test]
    async fn test_completed_and_pending_shards() {
        let barrier = TickBarrier::new(3);
        let all = [ShardId::new(0), ShardId::new(1), ShardId::new(2)];
        barrier.complete(all[0], TickPhase::Sense, 1).await.unwrap();
        barrier.complete(all[2], TickPhase::Sense, 1).await.unwrap();

        let done: HashSet<ShardId> = barrier
            .get_completed_shards(TickPhase::Sense, 1)
            .await
            .into_iter()
            .collect();
        assert!(done == HashSet::from([all[0], all[2]]));

        let pending = barrier
            .get_pending_shards(TickPhase::Sense, 1, &all)
            .await;
        assert!(pending == vec![all[1]]);
    }

    #[tokio::test]
    async fn test_update_shard_count() {
        let barrier = TickBarrier::new(2);
        assert_eq!(barrier.shard_count().await, 2);
        barrier.set_shard_count(5).await;
        assert_eq!(barrier.shard_count().await, 5);
    }
}