use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};
use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
/// Builds a `CoreError::MutexError` describing a poisoned mutex.
///
/// Marked `#[track_caller]` so the recorded `ErrorLocation` points at the
/// call site that failed to lock, instead of always reporting this helper's
/// own `line!()` (which made every lock error carry the same location).
#[track_caller]
fn lock_err(context: &'static str, e: impl std::fmt::Display) -> CoreError {
    // `Location::caller()` resolves to the nearest caller not itself marked
    // `#[track_caller]` — i.e. the method where the lock attempt failed.
    let loc = std::panic::Location::caller();
    CoreError::MutexError(
        ErrorContext::new(format!("{context}: mutex poisoned: {e}"))
            .with_location(ErrorLocation::new(loc.file(), loc.line())),
    )
}
/// Builds a `CoreError::MutexError` describing a poisoned condvar wait.
///
/// Marked `#[track_caller]` so the recorded `ErrorLocation` points at the
/// call site whose `Condvar::wait` failed, instead of always reporting this
/// helper's own `line!()` (which made every wait error carry the same location).
#[track_caller]
fn wait_err(context: &'static str, e: impl std::fmt::Display) -> CoreError {
    // `Location::caller()` resolves to the nearest caller not itself marked
    // `#[track_caller]` — i.e. the method where the wait failed.
    let loc = std::panic::Location::caller();
    CoreError::MutexError(
        ErrorContext::new(format!("{context}: condvar wait poisoned: {e}"))
            .with_location(ErrorLocation::new(loc.file(), loc.line())),
    )
}
/// Mutable state of a [`CyclicBarrier`], guarded by `CyclicBarrier::inner`.
struct CyclicBarrierInner {
// Parties that still have to call `wait` in the current generation;
// counts DOWN from `parties` to 0.
waiting: usize,
// Total party count; `waiting` is reset to this each time the barrier trips.
parties: usize,
// Bumped every time the barrier trips (or is reset); sleepers compare
// against their snapshot to detect that their generation has been released.
generation: u64,
// Set on timeout or `break_barrier`; waiters observing it return an error
// until `reset` clears it.
broken: bool,
}
/// A reusable barrier: `parties` threads block in `wait` until all have
/// arrived, then the barrier resets itself for the next cycle.
pub struct CyclicBarrier {
// Counts, generation, and broken flag behind a mutex.
inner: Mutex<CyclicBarrierInner>,
// Waiters sleep here; notified when the barrier trips, breaks, or resets.
condvar: Condvar,
}
impl CyclicBarrier {
    /// Creates a barrier for `parties` threads.
    ///
    /// `parties` is clamped to at least 1 (matching `SpinBarrier::new`):
    /// a zero-party barrier would underflow the `waiting` counter on the
    /// first call to `wait` (panic in debug builds, wrap in release).
    pub fn new(parties: usize) -> Self {
        let parties = parties.max(1);
        Self {
            inner: Mutex::new(CyclicBarrierInner {
                waiting: parties,
                parties,
                generation: 0,
                broken: false,
            }),
            condvar: Condvar::new(),
        }
    }

    /// Blocks until all `parties` threads have called `wait`.
    ///
    /// Returns `Ok(true)` for the single thread that trips the barrier and
    /// `Ok(false)` for every other released waiter.
    ///
    /// # Errors
    /// Returns an error if the barrier is (or becomes) broken, or if the
    /// mutex/condvar is poisoned.
    pub fn wait(&self) -> CoreResult<bool> {
        let mut g = self
            .inner
            .lock()
            .map_err(|e| lock_err("CyclicBarrier::wait", e))?;
        if g.broken {
            return Err(CoreError::MutexError(ErrorContext::new(
                "CyclicBarrier: barrier is broken",
            )));
        }
        // Snapshot the generation before sleeping; a change signals release.
        let gen = g.generation;
        g.waiting -= 1;
        if g.waiting == 0 {
            // Last arrival: start the next generation and release everyone.
            g.waiting = g.parties;
            g.generation = gen.wrapping_add(1);
            self.condvar.notify_all();
            return Ok(true);
        }
        loop {
            g = self
                .condvar
                .wait(g)
                .map_err(|e| wait_err("CyclicBarrier::wait", e))?;
            if g.broken {
                return Err(CoreError::MutexError(ErrorContext::new(
                    "CyclicBarrier: barrier broken while waiting",
                )));
            }
            // Generation advanced while we slept: our cycle has tripped
            // (or `reset` was called). Spurious wakeups simply loop again.
            if g.generation != gen {
                return Ok(false);
            }
        }
    }

    /// Like [`CyclicBarrier::wait`], but gives up after `timeout`.
    ///
    /// # Errors
    /// On timeout the barrier is marked broken (so the remaining waiters
    /// fail fast rather than blocking forever) and a `TimeoutError` is
    /// returned. Also errors on a broken barrier or poisoned lock/condvar.
    pub fn wait_timeout(&self, timeout: Duration) -> CoreResult<bool> {
        let deadline = Instant::now() + timeout;
        let mut g = self
            .inner
            .lock()
            .map_err(|e| lock_err("CyclicBarrier::wait_timeout", e))?;
        if g.broken {
            return Err(CoreError::MutexError(ErrorContext::new(
                "CyclicBarrier: barrier is broken",
            )));
        }
        let gen = g.generation;
        g.waiting -= 1;
        if g.waiting == 0 {
            g.waiting = g.parties;
            g.generation = gen.wrapping_add(1);
            self.condvar.notify_all();
            return Ok(true);
        }
        loop {
            // Recompute the remaining budget every pass so spurious wakeups
            // don't extend the total wait beyond `timeout`.
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                g.broken = true;
                self.condvar.notify_all();
                return Err(CoreError::TimeoutError(ErrorContext::new(
                    "CyclicBarrier: timed out waiting for all parties",
                )));
            }
            let (next_g, _timeout_result) = self
                .condvar
                .wait_timeout(g, remaining)
                .map_err(|e| wait_err("CyclicBarrier::wait_timeout", e))?;
            g = next_g;
            if g.broken {
                return Err(CoreError::MutexError(ErrorContext::new(
                    "CyclicBarrier: barrier broken while waiting",
                )));
            }
            if g.generation != gen {
                return Ok(false);
            }
        }
    }

    /// Marks the barrier broken and wakes all waiters, which then return
    /// errors. A poisoned lock is silently ignored (best-effort).
    pub fn break_barrier(&self) {
        if let Ok(mut g) = self.inner.lock() {
            g.broken = true;
            self.condvar.notify_all();
        }
    }

    /// Restores the barrier to a fresh, unbroken generation.
    ///
    /// The generation bump releases any current waiters, which return
    /// `Ok(false)`. A poisoned lock is silently ignored (best-effort).
    pub fn reset(&self) {
        if let Ok(mut g) = self.inner.lock() {
            g.waiting = g.parties;
            g.broken = false;
            g.generation = g.generation.wrapping_add(1);
            self.condvar.notify_all();
        }
    }

    /// Whether the barrier is broken. A poisoned lock is reported as broken.
    pub fn is_broken(&self) -> bool {
        self.inner.lock().map(|g| g.broken).unwrap_or(true)
    }

    /// Parties still required to trip the current generation
    /// (0 if the lock is poisoned).
    pub fn waiting(&self) -> usize {
        self.inner.lock().map(|g| g.waiting).unwrap_or(0)
    }

    /// Total party count (0 only if the lock is poisoned).
    pub fn parties(&self) -> usize {
        self.inner.lock().map(|g| g.parties).unwrap_or(0)
    }
}
/// Mutable state of a [`PhaseBarrier`], guarded by `PhaseBarrier::inner`.
struct PhaseBarrierInner {
// Current phase number; bumped each time all parties have arrived.
phase: u64,
// Parties that still have to arrive in the current phase.
waiting: usize,
// Total party count; `waiting` is reset to this at each phase change.
parties: usize,
}
/// A phased barrier: threads repeatedly arrive and (optionally) wait for the
/// current phase to complete, at which point the phase counter advances.
pub struct PhaseBarrier {
// Phase/arrival bookkeeping behind a mutex.
inner: Mutex<PhaseBarrierInner>,
// Waiters in `arrive_and_wait` sleep here until the phase advances.
condvar: Condvar,
}
impl PhaseBarrier {
    /// Creates a phase barrier for `parties` threads.
    ///
    /// `parties` is clamped to at least 1 (matching `SpinBarrier::new`):
    /// a zero-party barrier would underflow the `waiting` counter on the
    /// first arrival (panic in debug builds, wrap in release).
    pub fn new(parties: usize) -> Self {
        let parties = parties.max(1);
        Self {
            inner: Mutex::new(PhaseBarrierInner {
                phase: 0,
                waiting: parties,
                parties,
            }),
            condvar: Condvar::new(),
        }
    }

    /// Arrives at the barrier and blocks until the current phase completes.
    ///
    /// Returns the phase number the caller arrived in. The last arrival
    /// advances the phase and wakes the other waiters.
    ///
    /// # Errors
    /// Returns an error if the mutex or condvar wait is poisoned.
    pub fn arrive_and_wait(&self) -> CoreResult<u64> {
        let mut g = self
            .inner
            .lock()
            .map_err(|e| lock_err("PhaseBarrier::arrive_and_wait", e))?;
        let current_phase = g.phase;
        g.waiting -= 1;
        if g.waiting == 0 {
            // Last arrival: advance the phase and release the other waiters.
            g.phase = current_phase.wrapping_add(1);
            g.waiting = g.parties;
            self.condvar.notify_all();
            return Ok(current_phase);
        }
        loop {
            g = self
                .condvar
                .wait(g)
                .map_err(|e| wait_err("PhaseBarrier::arrive_and_wait", e))?;
            // Any phase change means our phase completed while we slept;
            // spurious wakeups simply loop again.
            if g.phase != current_phase {
                return Ok(current_phase);
            }
        }
    }

    /// Arrives at the barrier without waiting.
    ///
    /// Returns the phase the arrival was counted toward. The arrival that
    /// completes a phase advances the counter and wakes any waiters.
    ///
    /// # Errors
    /// Returns an error if the mutex is poisoned.
    pub fn arrive(&self) -> CoreResult<u64> {
        let mut g = self
            .inner
            .lock()
            .map_err(|e| lock_err("PhaseBarrier::arrive", e))?;
        g.waiting -= 1;
        if g.waiting == 0 {
            let completed = g.phase;
            g.phase = completed.wrapping_add(1);
            g.waiting = g.parties;
            self.condvar.notify_all();
            Ok(completed)
        } else {
            Ok(g.phase)
        }
    }

    /// Current phase number (0 if the lock is poisoned).
    pub fn phase(&self) -> u64 {
        self.inner.lock().map(|g| g.phase).unwrap_or(0)
    }

    /// Parties that still have to arrive in the current phase
    /// (0 if the lock is poisoned).
    pub fn waiting(&self) -> usize {
        self.inner.lock().map(|g| g.waiting).unwrap_or(0)
    }
}
/// A one-shot latch: starts at `n` and opens permanently once `count_down`
/// has been called `n` times.
pub struct CountDownLatch {
// Remaining count; 0 means the latch is open.
inner: Mutex<usize>,
// Waiters sleep here; notified once, when the count reaches zero.
condvar: Condvar,
}
impl CountDownLatch {
    /// Creates a latch that opens after `n` calls to `count_down`.
    /// A latch created with `n == 0` is open immediately.
    pub fn new(n: usize) -> Self {
        Self {
            inner: Mutex::new(n),
            condvar: Condvar::new(),
        }
    }

    /// Decrements the count, waking every waiter when it reaches zero.
    /// Calls past zero are no-ops; a poisoned lock is silently ignored.
    pub fn count_down(&self) {
        match self.inner.lock() {
            Ok(mut count) => {
                if *count > 0 {
                    *count -= 1;
                    if *count == 0 {
                        self.condvar.notify_all();
                    }
                }
            }
            Err(_) => {}
        }
    }

    /// Blocks until the latch opens (count reaches zero).
    ///
    /// # Errors
    /// Returns an error if the mutex or condvar wait is poisoned.
    pub fn wait(&self) -> CoreResult<()> {
        let mut count = self
            .inner
            .lock()
            .map_err(|e| lock_err("CountDownLatch::wait", e))?;
        // Spurious wakeups just re-test the count.
        while *count != 0 {
            count = self
                .condvar
                .wait(count)
                .map_err(|e| wait_err("CountDownLatch::wait", e))?;
        }
        Ok(())
    }

    /// Blocks until the latch opens or `timeout` elapses.
    ///
    /// Returns `Ok(true)` if the latch opened, `Ok(false)` on timeout.
    ///
    /// # Errors
    /// Returns an error if the mutex or condvar wait is poisoned.
    pub fn wait_timeout(&self, timeout: Duration) -> CoreResult<bool> {
        let deadline = Instant::now() + timeout;
        let mut count = self
            .inner
            .lock()
            .map_err(|e| lock_err("CountDownLatch::wait_timeout", e))?;
        while *count != 0 {
            // Recompute the budget each pass so spurious wakeups don't
            // extend the total wait beyond `timeout`.
            let left = deadline.saturating_duration_since(Instant::now());
            if left.is_zero() {
                return Ok(false);
            }
            let (reacquired, _) = self
                .condvar
                .wait_timeout(count, left)
                .map_err(|e| wait_err("CountDownLatch::wait_timeout", e))?;
            count = reacquired;
        }
        Ok(true)
    }

    /// Remaining count (0 if the lock is poisoned).
    pub fn count(&self) -> usize {
        match self.inner.lock() {
            Ok(count) => *count,
            Err(_) => 0,
        }
    }

    /// Whether the latch has opened.
    pub fn is_open(&self) -> bool {
        self.count() == 0
    }
}
/// A lock-free, busy-waiting barrier for `parties` threads.
///
/// Best suited to very short waits; after a brief spin it yields to the
/// scheduler so long waits don't monopolise a core.
pub struct SpinBarrier {
    // Threads that have arrived in the current cycle.
    arrived: AtomicUsize,
    // Completed-cycle counter; a bump releases the spinners.
    epoch: AtomicUsize,
    // Threads required per cycle (always >= 1).
    parties: usize,
}

impl SpinBarrier {
    /// Creates a spin barrier for `parties` threads (clamped to at least 1).
    pub fn new(parties: usize) -> Self {
        Self {
            arrived: AtomicUsize::new(0),
            epoch: AtomicUsize::new(0),
            parties: parties.max(1),
        }
    }

    /// Blocks (busy-waiting) until `parties` threads have called `wait`.
    ///
    /// The last arrival resets the arrival counter and bumps the epoch,
    /// which is the release signal every spinner is watching for.
    pub fn wait(&self) {
        // Snapshot the epoch BEFORE registering arrival: if we loaded it
        // afterwards, the final arrival could bump it in between and we
        // would spin forever waiting for a change that already happened.
        let ticket = self.epoch.load(Ordering::Acquire);
        if self.arrived.fetch_add(1, Ordering::AcqRel) + 1 == self.parties {
            // Final arrival: reset the counter for the next cycle first,
            // then publish the new epoch. The Release on `epoch` orders the
            // reset before any re-entering thread's Acquire load of `epoch`.
            self.arrived.store(0, Ordering::Release);
            self.epoch.fetch_add(1, Ordering::Release);
            return;
        }
        // Spin briefly, then fall back to yielding to the OS scheduler.
        let mut attempts = 0usize;
        while self.epoch.load(Ordering::Acquire) == ticket {
            if attempts < 32 {
                std::hint::spin_loop();
            } else {
                std::thread::yield_now();
            }
            attempts = attempts.saturating_add(1);
        }
    }

    /// Number of threads required to release the barrier.
    pub fn parties(&self) -> usize {
        self.parties
    }

    /// How many complete cycles the barrier has run so far.
    pub fn epoch(&self) -> usize {
        self.epoch.load(Ordering::Relaxed)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU64, Ordering as AO};
use std::thread;
// All N threads record their work and pass the barrier; joins prove release.
#[test]
fn cyclic_barrier_all_proceed() {
const N: usize = 5;
let barrier = Arc::new(CyclicBarrier::new(N));
let counter = Arc::new(AtomicU64::new(0));
let mut handles = Vec::new();
for _ in 0..N {
let b = Arc::clone(&barrier);
let c = Arc::clone(&counter);
handles.push(thread::spawn(move || {
c.fetch_add(1, AO::Relaxed);
b.wait().expect("barrier wait");
}));
}
for h in handles {
h.join().expect("thread");
}
assert_eq!(counter.load(AO::Relaxed), N as u64);
}
// `wait` returns true for exactly one thread per trip (the last arrival).
#[test]
fn cyclic_barrier_trip_thread_count() {
const N: usize = 4;
let barrier = Arc::new(CyclicBarrier::new(N));
let trips = Arc::new(AtomicU64::new(0));
let mut handles = Vec::new();
for _ in 0..N {
let b = Arc::clone(&barrier);
let t = Arc::clone(&trips);
handles.push(thread::spawn(move || {
let trip = b.wait().expect("wait");
if trip {
t.fetch_add(1, AO::Relaxed);
}
}));
}
for h in handles {
h.join().expect("thread");
}
assert_eq!(trips.load(AO::Relaxed), 1, "exactly one trip thread");
}
// The barrier is reusable: the same threads pass it twice in succession.
#[test]
fn cyclic_barrier_two_cycles() {
const N: usize = 3;
let barrier = Arc::new(CyclicBarrier::new(N));
let phase_counter = Arc::new(AtomicU64::new(0));
let mut handles = Vec::new();
for _ in 0..N {
let b = Arc::clone(&barrier);
let p = Arc::clone(&phase_counter);
handles.push(thread::spawn(move || {
b.wait().expect("phase 1 wait");
p.fetch_add(1, AO::Relaxed);
b.wait().expect("phase 2 wait");
p.fetch_add(1, AO::Relaxed);
}));
}
for h in handles {
h.join().expect("thread");
}
assert_eq!(phase_counter.load(AO::Relaxed), (N * 2) as u64);
}
// Two full arrive_and_wait rounds leave the phase counter at 2.
#[test]
fn phase_barrier_advances_phase() {
const N: usize = 4;
let pb = Arc::new(PhaseBarrier::new(N));
let mut handles = Vec::new();
for _ in 0..N {
let p = Arc::clone(&pb);
handles.push(thread::spawn(move || {
p.arrive_and_wait().expect("arrive phase 1");
p.arrive_and_wait().expect("arrive phase 2");
}));
}
for h in handles {
h.join().expect("thread");
}
assert_eq!(pb.phase(), 2);
}
// `wait` returns only after all N count_down calls; latch reports open.
#[test]
fn countdown_latch_basic() {
const N: usize = 5;
let latch = Arc::new(CountDownLatch::new(N));
let counter = Arc::new(AtomicU64::new(0));
let mut handles = Vec::new();
for _ in 0..N {
let l = Arc::clone(&latch);
let c = Arc::clone(&counter);
handles.push(thread::spawn(move || {
c.fetch_add(1, AO::Relaxed);
l.count_down();
}));
}
latch.wait().expect("latch wait");
assert!(latch.is_open());
assert_eq!(counter.load(AO::Relaxed), N as u64);
for h in handles {
h.join().expect("thread");
}
}
// A zero-count latch is open from the start and `wait` returns immediately.
#[test]
fn countdown_latch_already_open() {
let latch = CountDownLatch::new(0);
assert!(latch.is_open());
latch.wait().expect("already open wait");
}
// The latch opens within the (generous) timeout once count_down arrives.
#[test]
fn countdown_latch_timeout_opens() {
let latch = Arc::new(CountDownLatch::new(1));
let l2 = Arc::clone(&latch);
thread::spawn(move || {
thread::sleep(Duration::from_millis(20));
l2.count_down();
});
let opened = latch
.wait_timeout(Duration::from_secs(5))
.expect("wait_timeout");
assert!(opened);
}
// With no count_down, wait_timeout reports Ok(false) after the deadline.
#[test]
fn countdown_latch_timeout_expires() {
let latch = CountDownLatch::new(1); let opened = latch
.wait_timeout(Duration::from_millis(10))
.expect("wait_timeout");
assert!(!opened);
}
// All N spinners pass one cycle; epoch counts exactly one release.
#[test]
fn spin_barrier_basic() {
const N: usize = 4;
let b = Arc::new(SpinBarrier::new(N));
let counter = Arc::new(AtomicU64::new(0));
let mut handles = Vec::new();
for _ in 0..N {
let bar = Arc::clone(&b);
let c = Arc::clone(&counter);
handles.push(thread::spawn(move || {
bar.wait();
c.fetch_add(1, AO::Relaxed);
}));
}
for h in handles {
h.join().expect("thread");
}
assert_eq!(counter.load(AO::Relaxed), N as u64);
assert_eq!(b.epoch(), 1);
}
// Three back-to-back cycles exercise counter reset between epochs.
#[test]
fn spin_barrier_multiple_epochs() {
const N: usize = 3;
let b = Arc::new(SpinBarrier::new(N));
let mut handles = Vec::new();
for _ in 0..N {
let bar = Arc::clone(&b);
handles.push(thread::spawn(move || {
bar.wait(); bar.wait(); bar.wait(); }));
}
for h in handles {
h.join().expect("thread");
}
assert_eq!(b.epoch(), 3);
}
}