use std::sync::{Arc, Mutex};
use tokio::sync::oneshot;
/// Async barrier: callers of `wait` are held until `n` of them have arrived.
///
/// All state lives in a single mutex-protected `BarrierInner` behind an `Arc`.
#[derive(Debug)]
pub struct Barrier {
// Shared, lock-guarded bookkeeping (arrival count and parked waiters).
inner: Arc<Mutex<BarrierInner>>,
}
/// Mutable bookkeeping for a `Barrier`, stored inside its mutex.
#[derive(Debug)]
struct BarrierInner {
/// Number of `wait` calls that must arrive before the barrier releases.
n: usize,
/// Arrivals seen in the current cycle; reset to zero when the barrier releases.
count: usize,
/// Count of completed cycles; bumped (wrapping) on every release.
generation: usize,
/// One sender per parked caller; on release each is sent `false` (not the leader).
waiters: Vec<oneshot::Sender<bool>>, }
/// Outcome handed back to each caller once a barrier cycle completes.
///
/// Carries a single flag telling the caller whether it was the designated
/// leader of that cycle.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult {
    /// `true` for the leader of the cycle, `false` for everyone else.
    is_leader: bool,
}

impl BarrierWaitResult {
    /// Reports whether this caller was the leader of its barrier cycle.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn is_leader(&self) -> bool {
        self.is_leader
    }
}
impl Barrier {
    /// Builds a barrier that releases once `n` calls to [`Barrier::wait`] have arrived.
    ///
    /// # Panics
    ///
    /// Panics with "barrier size must be positive" when `n` is zero.
    #[must_use]
    pub fn new(n: usize) -> Self {
        assert!(n > 0, "barrier size must be positive");
        // At most `n - 1` callers are parked at any time: the n-th arrival
        // releases the others instead of registering itself.
        let state = BarrierInner {
            n,
            count: 0,
            generation: 0,
            waiters: Vec::with_capacity(n.saturating_sub(1)),
        };
        Self {
            inner: Arc::new(Mutex::new(state)),
        }
    }

    /// Waits until `n` callers have reached the barrier.
    ///
    /// The arrival that completes the group does not suspend: it resets the
    /// arrival count, bumps the generation, wakes every parked caller and
    /// returns with `is_leader() == true`. Every earlier arrival is parked on a
    /// oneshot channel and resumes with `is_leader() == false`.
    ///
    /// An arrival is counted when this future first runs and is never
    /// un-counted, so a `wait` future dropped while parked still contributes to
    /// the current cycle.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub async fn wait(&self) -> BarrierWaitResult {
        // Decide under the lock whether this caller parks (`Some(receiver)`) or
        // completes the cycle (`None`). The guard is released before awaiting.
        let parked = {
            let mut state = self.inner.lock().unwrap();

            // First arrival of a cycle: discard any senders left in the list.
            if state.count == 0 {
                state.waiters.clear();
            }
            state.count += 1;

            if state.count == state.n {
                // Last arrival: start a fresh cycle and release everyone else.
                state.count = 0;
                state.generation = state.generation.wrapping_add(1);
                state.waiters.drain(..).for_each(|sender| {
                    // A failed send means that waiter's future was dropped; ignore it.
                    let _ = sender.send(false);
                });
                None
            } else {
                let (sender, receiver) = oneshot::channel();
                state.waiters.push(sender);
                Some(receiver)
            }
        };

        match parked {
            None => BarrierWaitResult { is_leader: true },
            Some(receiver) => BarrierWaitResult {
                // A dropped sender is treated the same as "not the leader".
                is_leader: receiver.await.unwrap_or(false),
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Builder;

    /// Two concurrent waiters on a barrier of size 2: exactly one is the leader.
    #[test]
    fn test_barrier_basic() {
        let runtime = crate::simulator::runtime::build_runtime(&Builder::new()).unwrap();
        runtime.block_on(async {
            let barrier = Arc::new(Barrier::new(2));
            let b1 = barrier.clone();
            let b2 = barrier.clone();
            let (r1, r2) = futures::future::join(
                async move { b1.wait().await },
                async move { b2.wait().await },
            )
            .await;
            assert_ne!(r1.is_leader(), r2.is_leader());
            assert!(r1.is_leader() || r2.is_leader());
        });
        runtime.wait().unwrap();
    }

    /// Five spawned tasks meet at a barrier of size 5: exactly one leader overall.
    #[test]
    fn test_barrier_multiple_tasks() {
        let runtime = crate::simulator::runtime::build_runtime(&Builder::new()).unwrap();
        runtime.block_on(async {
            let barrier = Arc::new(Barrier::new(5));
            let mut handles = vec![];
            for i in 0..5 {
                let b = barrier.clone();
                handles.push(crate::task::spawn(async move {
                    let result = b.wait().await;
                    (i, result.is_leader())
                }));
            }
            let results: Vec<_> = futures::future::join_all(handles)
                .await
                .into_iter()
                .map(Result::unwrap)
                .collect();
            let leader_count = results.iter().filter(|(_, is_leader)| *is_leader).count();
            assert_eq!(leader_count, 1);
        });
        runtime.wait().unwrap();
    }

    /// A barrier of size 1 never parks: the sole caller is immediately the leader.
    #[test]
    fn test_barrier_single_task() {
        let runtime = crate::simulator::runtime::build_runtime(&Builder::new()).unwrap();
        runtime.block_on(async {
            let barrier = Barrier::new(1);
            let result = barrier.wait().await;
            assert!(result.is_leader());
        });
        runtime.wait().unwrap();
    }

    /// Constructing a barrier of size zero is rejected with a panic.
    #[test]
    #[should_panic(expected = "barrier size must be positive")]
    fn test_barrier_zero_size() {
        let _ = Barrier::new(0);
    }

    /// Cloning a wait result preserves the leader flag for both values of the flag.
    #[test]
    fn test_barrier_wait_result_clone() {
        let result = BarrierWaitResult { is_leader: true };
        let cloned = result.clone();
        assert_eq!(result.is_leader(), cloned.is_leader());
        let result2 = BarrierWaitResult { is_leader: false };
        let cloned2 = result2.clone();
        assert_eq!(result2.is_leader(), cloned2.is_leader());
    }

    /// The barrier is reusable: each of five cycles elects exactly one leader and
    /// advances the generation counter by one.
    #[test]
    fn test_barrier_generation_advances_on_each_cycle() {
        let runtime = crate::simulator::runtime::build_runtime(&Builder::new()).unwrap();
        runtime.block_on(async {
            let barrier = Arc::new(Barrier::new(2));
            for cycle in 0..5 {
                let b1 = barrier.clone();
                let b2 = barrier.clone();
                let (r1, r2) = futures::future::join(async move { b1.wait().await }, async move {
                    b2.wait().await
                })
                .await;
                assert_ne!(
                    r1.is_leader(),
                    r2.is_leader(),
                    "Cycle {cycle}: expected exactly one leader"
                );
                // Each completed cycle releases the barrier exactly once, so the
                // counter must equal the number of cycles finished so far. The
                // guard is a statement temporary and is gone before the next await.
                let generation = barrier.inner.lock().unwrap().generation;
                assert_eq!(
                    generation,
                    cycle + 1,
                    "Cycle {cycle}: generation did not advance"
                );
            }
        });
        runtime.wait().unwrap();
    }

    /// A freshly built barrier records its size and starts with empty bookkeeping.
    #[test]
    fn test_barrier_inner_state_initialization() {
        let barrier = Barrier::new(5);
        let inner = barrier.inner.lock().unwrap();
        assert_eq!(inner.n, 5);
        assert_eq!(inner.count, 0);
        assert_eq!(inner.generation, 0);
        assert!(inner.waiters.is_empty());
        drop(inner);
    }
}