use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering::*};
use std::sync::{Condvar, Mutex};
use std::time::{Duration, Instant};
/// Busy-wait iterations a waiter performs before parking on the condvar.
const SPIN_BUDGET: u32 = 1 << 9;
/// Spin iterations between clock reads in the spin phase. Kept a power of two because
/// the spin loop masks its iteration counter with `WATCHDOG_STRIDE - 1`.
/// Note: while this is larger than `SPIN_BUDGET`, the spin phase never reads the clock.
const WATCHDOG_STRIDE: u32 = 1 << 10;
/// Per-call watchdog deadline used by `SpinBarrier::new`.
pub const DEFAULT_DEADLINE: Duration = Duration::from_millis(5_000);
/// Outcome of one `SpinBarrier::wait` call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BarrierResult {
    /// All parties arrived and the round was released.
    Released,
    /// The barrier has been poisoned and no watchdog timeout was recorded.
    Poisoned,
    /// The watchdog deadline expired and the barrier is now poisoned.
    TimedOut {
        /// Elapsed milliseconds reported for the timeout, saturated at `u32::MAX`.
        elapsed_ms: u32,
    },
}
/// A reusable barrier for a fixed number of threads. Waiters spin briefly, then park
/// on a condvar. The barrier can be poisoned explicitly or by a per-wait watchdog
/// deadline.
#[derive(Debug)]
pub struct SpinBarrier {
    /// Round counter, bumped (wrapping) when a round is released. Waiters compare it
    /// with the value they read on entry to detect their release.
    generation: AtomicU32,
    /// Threads that have arrived in the current round; reset to 0 on release.
    count: AtomicU32,
    /// Number of threads that must arrive to release a round (at least 2).
    parties: u32,
    /// Set once the barrier is poisoned; checked by every waiter.
    poisoned: AtomicBool,
    /// Set when a waiter's watchdog deadline expired.
    timed_out: AtomicBool,
    /// Elapsed milliseconds recorded by the watchdog when it tripped; 0 otherwise.
    timeout_elapsed_ms: AtomicU64,
    /// How long a single `wait` call may block before the watchdog trips.
    deadline: Duration,
    /// Guards no data. Release, poison and a parked waiter's checks are serialized on
    /// it so that condvar wakeups cannot be lost.
    park_mu: Mutex<()>,
    /// Parked waiters sleep here until release, poison, or their deadline.
    park_cv: Condvar,
}
impl SpinBarrier {
    /// Creates a barrier for `parties` threads that uses [`DEFAULT_DEADLINE`] as the
    /// per-call watchdog deadline.
    ///
    /// # Panics
    ///
    /// Panics if `parties < 2`.
    pub fn new(parties: u32) -> Self {
        Self::with_deadline(parties, DEFAULT_DEADLINE)
    }

    /// Creates a barrier for `parties` threads. A waiter that has been blocked for at
    /// least `deadline` trips the watchdog, which poisons the barrier.
    ///
    /// # Panics
    ///
    /// Panics if `parties < 2`.
    pub fn with_deadline(parties: u32, deadline: Duration) -> Self {
        assert!(parties >= 2, "SpinBarrier requires parties >= 2, got {parties}");
        Self {
            generation: AtomicU32::new(0),
            count: AtomicU32::new(0),
            parties,
            poisoned: AtomicBool::new(false),
            timed_out: AtomicBool::new(false),
            timeout_elapsed_ms: AtomicU64::new(0),
            deadline,
            park_mu: Mutex::new(()),
            park_cv: Condvar::new(),
        }
    }

    /// Blocks until all `parties` threads have called `wait` for the current round.
    ///
    /// Returns:
    /// - [`BarrierResult::Released`] when the round completed; the last arriver returns
    ///   this without blocking.
    /// - [`BarrierResult::Poisoned`] when [`poison`](Self::poison) ran before the round
    ///   completed.
    /// - [`BarrierResult::TimedOut`] when a waiter's deadline expired before the round
    ///   completed. `elapsed_ms` is the value recorded by the watchdog.
    ///
    /// Once poisoned, the barrier never releases again, and every later call returns
    /// immediately without joining the count.
    ///
    /// Release, poison and the watchdog trip all update state while holding `park_mu`.
    /// Exactly one of them decides each round, so every party that joined the round
    /// observes the same outcome.
    pub fn wait(&self) -> BarrierResult {
        let start = Instant::now();
        if self.poisoned.load(Acquire) {
            return self.poisoned_or_timed_out_result();
        }
        let cur_gen = self.generation.load(Acquire);
        let arrived = self.count.fetch_add(1, AcqRel) + 1;
        if arrived == self.parties {
            return self.release_round(cur_gen);
        }

        // Phase 1: spin briefly. The generation is checked before poison, so a round that
        // already completed is reported as released even if the barrier was poisoned later.
        for spin in 1..=SPIN_BUDGET {
            if self.generation.load(Acquire) != cur_gen {
                return BarrierResult::Released;
            }
            if self.poisoned.load(Acquire) {
                return self.poisoned_or_timed_out_result();
            }
            // Read the clock only every WATCHDOG_STRIDE spins. An expired deadline is not
            // acted on here: fall through to the park phase, which trips the watchdog
            // under `park_mu`. (While SPIN_BUDGET < WATCHDOG_STRIDE this never fires.)
            if (spin & (WATCHDOG_STRIDE - 1)) == 0 && start.elapsed() >= self.deadline {
                break;
            }
            std::hint::spin_loop();
        }

        // Phase 2: park. Every check below runs with `park_mu` held. Release and poison
        // also update state under it, so no wakeup can slip in between check and wait.
        let mut guard = self.lock_park();
        loop {
            if self.generation.load(Acquire) != cur_gen {
                return BarrierResult::Released;
            }
            if self.poisoned.load(Acquire) {
                return self.poisoned_or_timed_out_result();
            }
            let remaining = self.deadline.saturating_sub(start.elapsed());
            if remaining.is_zero() {
                // Trip while still holding the lock, so a concurrent last arriver cannot
                // release this round after we have decided it timed out.
                let result = self.trip_watchdog_locked(&guard, start);
                drop(guard);
                self.park_cv.notify_all();
                return result;
            }
            // Spurious or early wakeups are handled by re-checking at the top of the loop.
            guard = self
                .park_cv
                .wait_timeout(guard, remaining)
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .0;
        }
    }

    /// Completes round `cur_gen` on behalf of the last arriver.
    ///
    /// If the barrier was already poisoned (explicitly or by a peer's watchdog), the
    /// round is not released. The generation stays unchanged, and the caller gets the
    /// same poisoned/timed-out result its peers observe.
    fn release_round(&self, cur_gen: u32) -> BarrierResult {
        let guard = self.lock_park();
        if self.poisoned.load(Acquire) {
            drop(guard);
            return self.poisoned_or_timed_out_result();
        }
        // Reset the count before publishing the new generation. A waiter that observes
        // the new generation (Acquire) is then guaranteed to see the reset.
        self.count.store(0, Relaxed);
        self.generation.store(cur_gen.wrapping_add(1), Release);
        drop(guard);
        self.park_cv.notify_all();
        BarrierResult::Released
    }

    /// Maps a poisoned barrier to the result a caller should see. Returns `TimedOut`
    /// with the recorded watchdog time if the watchdog poisoned it, otherwise `Poisoned`.
    fn poisoned_or_timed_out_result(&self) -> BarrierResult {
        // `timeout_elapsed_ms` is stored before `timed_out` is set with Release ordering.
        // Once `timed_out` reads true, the recorded value is visible, even a genuine 0.
        if self.timed_out.load(Acquire) {
            BarrierResult::TimedOut {
                elapsed_ms: self.timeout_elapsed_ms(),
            }
        } else {
            BarrierResult::Poisoned
        }
    }

    /// Records a watchdog timeout and poisons the barrier.
    ///
    /// The caller must hold `park_mu` (witnessed by `_held`) and must already have
    /// checked that the barrier is not poisoned, so this runs at most once per barrier.
    /// The caller must notify parked waiters after releasing the lock.
    fn trip_watchdog_locked(
        &self,
        _held: &std::sync::MutexGuard<'_, ()>,
        start: Instant,
    ) -> BarrierResult {
        let elapsed_ms = start.elapsed().as_millis().min(u32::MAX as u128) as u32;
        self.timeout_elapsed_ms.store(u64::from(elapsed_ms), Release);
        self.timed_out.store(true, Release);
        self.poisoned.store(true, Release);
        BarrierResult::TimedOut { elapsed_ms }
    }

    /// Locks `park_mu` and ignores std mutex poisoning. The mutex protects no data, so a
    /// panic while it was held leaves no invariant to repair.
    fn lock_park(&self) -> std::sync::MutexGuard<'_, ()> {
        self.park_mu
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns `true` if the watchdog has tripped on this barrier.
    pub fn timed_out(&self) -> bool {
        self.timed_out.load(Acquire)
    }

    /// Returns the elapsed milliseconds recorded when the watchdog tripped, or 0 if it
    /// has not tripped.
    pub fn timeout_elapsed_ms(&self) -> u32 {
        self.timeout_elapsed_ms.load(Acquire).min(u32::MAX as u64) as u32
    }

    /// Poisons the barrier and wakes all parked waiters.
    ///
    /// Every waiter whose round has not completed, and every later caller, returns
    /// `Poisoned` (or `TimedOut` if the watchdog tripped first). Calling this again has
    /// no further effect.
    pub fn poison(&self) {
        {
            let _held = self.lock_park();
            self.poisoned.store(true, Release);
        }
        self.park_cv.notify_all();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{
        AtomicU32, AtomicUsize,
        Ordering::{self, SeqCst},
    };
    use std::thread;
    use std::time::Duration;

    /// Drives `parties` threads through `rounds` rounds of one barrier and returns the
    /// total number of `Released` results. Any other result panics the worker, which
    /// fails the `join` below. When `delay_last` is set, the highest thread id
    /// busy-waits ~2µs before each `wait` so it tends to arrive last.
    fn run_rounds(parties: u32, rounds: u32, delay_last: bool) -> usize {
        let barrier = Arc::new(SpinBarrier::new(parties));
        let released = Arc::new(AtomicUsize::new(0));
        let handles: Vec<_> = (0..parties)
            .map(|tid| {
                let b = Arc::clone(&barrier);
                let c = Arc::clone(&released);
                thread::spawn(move || {
                    for _ in 0..rounds {
                        if delay_last && tid == parties - 1 {
                            let t0 = std::time::Instant::now();
                            while t0.elapsed() < Duration::from_micros(2) {
                                std::hint::spin_loop();
                            }
                        }
                        match b.wait() {
                            BarrierResult::Released => {
                                c.fetch_add(1, SeqCst);
                            }
                            BarrierResult::Poisoned => panic!("unexpected poison"),
                            BarrierResult::TimedOut { .. } => panic!("unexpected timeout"),
                        }
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        released.load(SeqCst)
    }

    /// Spawns a thread that performs a single `wait` on `barrier`.
    fn spawn_waiter(barrier: &Arc<SpinBarrier>) -> thread::JoinHandle<BarrierResult> {
        let b = Arc::clone(barrier);
        thread::spawn(move || b.wait())
    }

    /// Asserts that `result` is `TimedOut` and returns its `elapsed_ms`.
    fn expect_timed_out(result: BarrierResult) -> u32 {
        match result {
            BarrierResult::TimedOut { elapsed_ms } => elapsed_ms,
            other => panic!("expected TimedOut, got {other:?}"),
        }
    }

    #[test]
    fn all_threads_released() {
        assert_eq!(run_rounds(4, 1, false), 4, "every thread should be released once");
    }

    #[test]
    fn multiple_rounds() {
        assert_eq!(
            run_rounds(4, 10, false),
            40,
            "counter should equal parties * rounds"
        );
    }

    #[test]
    fn multiple_rounds_6way() {
        assert_eq!(
            run_rounds(6, 10, false),
            60,
            "counter should equal parties * rounds"
        );
    }

    /// The last thread arrives slightly late every round. Nothing here observes whether
    /// the other waiters parked; the test only checks that every round still releases.
    #[test]
    fn parties_6_asymmetric_arrival_does_not_park() {
        assert_eq!(run_rounds(6, 10, true), 60);
    }

    #[test]
    fn poison_breaks_waiters() {
        let barrier = Arc::new(SpinBarrier::new(4));
        let entered = Arc::new(AtomicU32::new(0));
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let b = Arc::clone(&barrier);
                let e = Arc::clone(&entered);
                thread::spawn(move || {
                    e.fetch_add(1, Ordering::Release);
                    b.wait()
                })
            })
            .collect();
        while entered.load(Ordering::Acquire) < 3 {
            thread::sleep(Duration::from_millis(1));
        }
        thread::sleep(Duration::from_millis(10));
        barrier.poison();
        for h in handles {
            let result = h.join().expect("thread panicked");
            assert_eq!(
                result,
                BarrierResult::Poisoned,
                "waiter should have returned Poisoned"
            );
        }
    }

    #[test]
    fn barrier_watchdog_fires_when_worker_stalls() {
        const DEADLINE: Duration = Duration::from_millis(150);
        const SLACK: Duration = Duration::from_millis(500);
        let barrier = Arc::new(SpinBarrier::with_deadline(2, DEADLINE));
        let start = std::time::Instant::now();
        let result = spawn_waiter(&barrier).join().expect("worker panicked");
        let elapsed = start.elapsed();
        assert!(
            elapsed >= DEADLINE,
            "watchdog fired too early: {elapsed:?} < {DEADLINE:?}"
        );
        assert!(
            elapsed < DEADLINE + SLACK,
            "watchdog took too long: {elapsed:?} > {:?}",
            DEADLINE + SLACK
        );
        let elapsed_ms = expect_timed_out(result);
        assert!(
            elapsed_ms >= DEADLINE.as_millis() as u32,
            "reported elapsed_ms {elapsed_ms} < deadline {}",
            DEADLINE.as_millis()
        );
        match barrier.wait() {
            BarrierResult::TimedOut { .. } => {}
            other => panic!("re-entry after timeout should yield TimedOut, got {other:?}"),
        }
        assert!(barrier.timed_out());
        assert!(barrier.timeout_elapsed_ms() >= DEADLINE.as_millis() as u32);
    }

    #[test]
    fn wait_after_poison_returns_poisoned_immediately() {
        let barrier = SpinBarrier::new(4);
        barrier.poison();
        assert_eq!(barrier.wait(), BarrierResult::Poisoned);
        assert!(!barrier.timed_out());
        assert_eq!(barrier.timeout_elapsed_ms(), 0);
        assert_eq!(barrier.wait(), BarrierResult::Poisoned);
        assert!(!barrier.timed_out());
    }

    #[test]
    fn waiter_parks_then_observes_release() {
        let barrier = Arc::new(SpinBarrier::new(2));
        let early = spawn_waiter(&barrier);
        thread::sleep(Duration::from_millis(50));
        assert_eq!(barrier.wait(), BarrierResult::Released);
        assert_eq!(
            early.join().expect("early panicked"),
            BarrierResult::Released
        );
        assert!(!barrier.timed_out());
    }

    #[test]
    fn waiter_parks_then_observes_poison() {
        let barrier = Arc::new(SpinBarrier::with_deadline(2, Duration::from_secs(30)));
        let h = spawn_waiter(&barrier);
        thread::sleep(Duration::from_millis(50));
        barrier.poison();
        assert_eq!(h.join().expect("waiter panicked"), BarrierResult::Poisoned);
        assert!(!barrier.timed_out());
    }

    #[test]
    fn watchdog_fires_from_park_path() {
        const DEADLINE: Duration = Duration::from_millis(80);
        let barrier = Arc::new(SpinBarrier::with_deadline(2, DEADLINE));
        let elapsed_ms = expect_timed_out(spawn_waiter(&barrier).join().expect("waiter panicked"));
        assert!(elapsed_ms as u128 >= DEADLINE.as_millis());
        assert!(barrier.timed_out());
        assert!(barrier.timeout_elapsed_ms() >= DEADLINE.as_millis() as u32);
    }

    /// Short-deadline timeout. Because `SPIN_BUDGET` (512) is below `WATCHDOG_STRIDE`
    /// (1024), the spin phase never reads the clock. The timeout is therefore raised
    /// from the park phase; the name reflects the original intent.
    #[test]
    fn watchdog_trips_during_spin_2party_50ms() {
        const DEADLINE: Duration = Duration::from_millis(50);
        let barrier = Arc::new(SpinBarrier::with_deadline(2, DEADLINE));
        let elapsed_ms = expect_timed_out(spawn_waiter(&barrier).join().expect("waiter panicked"));
        assert!(
            elapsed_ms >= 30,
            "elapsed_ms {elapsed_ms} should be at least 30 (deadline 50)"
        );
        assert!(barrier.timed_out());
    }

    #[test]
    fn watchdog_trips_during_sleep_2party_200ms() {
        const DEADLINE: Duration = Duration::from_millis(200);
        let barrier = Arc::new(SpinBarrier::with_deadline(2, DEADLINE));
        let elapsed_ms = expect_timed_out(spawn_waiter(&barrier).join().expect("waiter panicked"));
        assert!(
            elapsed_ms >= 150,
            "elapsed_ms {elapsed_ms} should be at least 150 (deadline 200)"
        );
        assert!(barrier.timed_out());
        assert!(barrier.timeout_elapsed_ms() >= 150);
    }

    #[test]
    fn poisoned_on_entry_via_separate_thread() {
        let barrier = Arc::new(SpinBarrier::new(2));
        let b_poisoner = Arc::clone(&barrier);
        thread::spawn(move || b_poisoner.poison())
            .join()
            .expect("poisoner panicked");
        assert_eq!(barrier.wait(), BarrierResult::Poisoned);
        assert!(!barrier.timed_out());
    }

    /// Poisons shortly after the waiter enters. The waiter may already have left the
    /// spin phase by then; either way it must report `Poisoned`.
    #[test]
    fn poisoned_during_spin_loop() {
        let barrier = Arc::new(SpinBarrier::new(2));
        let entered = Arc::new(AtomicU32::new(0));
        let b = Arc::clone(&barrier);
        let e = Arc::clone(&entered);
        let h = thread::spawn(move || {
            e.store(1, SeqCst);
            b.wait()
        });
        while entered.load(SeqCst) == 0 {
            thread::sleep(Duration::from_micros(10));
        }
        thread::sleep(Duration::from_millis(1));
        barrier.poison();
        assert_eq!(h.join().expect("waiter panicked"), BarrierResult::Poisoned);
        assert!(!barrier.timed_out());
    }

    #[test]
    fn poisoned_or_timed_out_result_distinguishes_variants() {
        const DEADLINE: Duration = Duration::from_millis(50);
        let timed_out_barrier = SpinBarrier::with_deadline(2, DEADLINE);
        let first_elapsed = match timed_out_barrier.wait() {
            BarrierResult::TimedOut { elapsed_ms } => elapsed_ms,
            other => panic!("expected TimedOut on first wait, got {other:?}"),
        };
        assert!(timed_out_barrier.timed_out());
        match timed_out_barrier.wait() {
            BarrierResult::TimedOut { elapsed_ms } => {
                assert_eq!(
                    elapsed_ms, first_elapsed,
                    "re-entry should surface the recorded watchdog elapsed_ms"
                );
            }
            other => panic!("expected TimedOut on re-entry, got {other:?}"),
        }
        let poison_only_barrier = SpinBarrier::new(2);
        poison_only_barrier.poison();
        assert!(!poison_only_barrier.timed_out());
        assert_eq!(poison_only_barrier.wait(), BarrierResult::Poisoned);
        assert_eq!(poison_only_barrier.timeout_elapsed_ms(), 0);
    }

    /// Two consecutive two-party rounds on the same barrier both release.
    #[test]
    fn solo_round_then_normal_round() {
        let barrier = Arc::new(SpinBarrier::new(2));
        for _ in 0..2 {
            let other = spawn_waiter(&barrier);
            assert_eq!(barrier.wait(), BarrierResult::Released);
            assert_eq!(
                other.join().expect("other panicked"),
                BarrierResult::Released
            );
        }
        assert!(!barrier.timed_out());
    }
}