use crate::sync::{Condvar, Mutex};
use std::fmt;
/// A reusable rendezvous point for a fixed number of participants.
///
/// Each call to [`Barrier::wait`] blocks until `n` callers (the `n` passed to
/// [`Barrier::new`]) have arrived. The last arrival releases everyone and the
/// barrier resets itself, so the same `Barrier` can be waited on repeatedly.
pub struct Barrier {
    lock: Mutex<BarrierState>,
    cvar: Condvar,
    num_threads: usize,
}

/// Mutable bookkeeping protected by `Barrier::lock`.
struct BarrierState {
    /// Number of callers that have arrived in the current generation.
    count: usize,
    /// Incremented (wrapping) every time the barrier trips. Waiters remember
    /// the value they saw on arrival and keep sleeping while it is unchanged,
    /// which distinguishes a real release from a spurious wakeup.
    generation_id: usize,
}

/// Returned by [`Barrier::wait`]; tells the caller whether it was the leader.
///
/// Exactly one caller per barrier generation receives a result whose
/// [`is_leader`](BarrierWaitResult::is_leader) is `true`.
pub struct BarrierWaitResult(bool);

impl fmt::Debug for Barrier {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Barrier").finish_non_exhaustive()
    }
}

impl Barrier {
    /// Creates a barrier that releases its waiters once `n` of them have
    /// called [`wait`](Barrier::wait).
    ///
    /// With `n == 0` or `n == 1`, every call to `wait` returns immediately
    /// and reports itself as the leader.
    #[must_use]
    #[inline]
    pub fn new(n: usize) -> Barrier {
        Barrier {
            lock: Mutex::new(BarrierState {
                count: 0,
                generation_id: 0,
            }),
            cvar: Condvar::new(),
            num_threads: n,
        }
    }

    /// Blocks until all `n` participants of the current generation have
    /// called `wait`, then releases them together.
    ///
    /// The caller whose arrival completes the group does not block. It resets
    /// the counter, starts a new generation, wakes all other waiters and is
    /// returned as the leader. Every other caller gets a non-leader result.
    ///
    /// # Panics
    ///
    /// Panics if acquiring the internal mutex or waiting on the condition
    /// variable reports an error (for example a poisoned lock, depending on
    /// the semantics of `crate::sync`).
    pub fn wait(&self) -> BarrierWaitResult {
        let mut lock = self.lock.lock().unwrap();
        let local_gen = lock.generation_id;
        lock.count += 1;
        if lock.count < self.num_threads {
            // Sleep until the generation advances; re-checking the generation
            // makes spurious wakeups harmless.
            let _guard = self
                .cvar
                .wait_while(lock, |state| local_gen == state.generation_id)
                .unwrap();
            BarrierWaitResult(false)
        } else {
            // Last arrival: reset for reuse and release this generation.
            lock.count = 0;
            lock.generation_id = lock.generation_id.wrapping_add(1);
            self.cvar.notify_all();
            BarrierWaitResult(true)
        }
    }
}

impl BarrierWaitResult {
    /// Returns `true` for the single caller per generation whose arrival
    /// tripped the barrier, and `false` for every other caller.
    #[must_use]
    pub fn is_leader(&self) -> bool {
        self.0
    }
}

// Public types should be debuggable; mirror the `Barrier` impl's hand-written
// style and expose the only piece of state this type carries.
impl fmt::Debug for BarrierWaitResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BarrierWaitResult")
            .field("is_leader", &self.is_leader())
            .finish()
    }
}
/// Spawns `N - 1` coroutines that block on a shared barrier, checks that none
/// of them can get through before the final participant arrives, then verifies
/// that exactly one of the `N` participants is reported as leader.
#[test]
fn test_barrier() {
    use crate::sync::mpsc::channel;
    use std::sync::mpsc::TryRecvError;
    use std::sync::Arc;

    const N: usize = 10;

    let shared = Arc::new(Barrier::new(N));
    let (tx, rx) = channel();

    for _ in 1..N {
        let (barrier, sender) = (Arc::clone(&shared), tx.clone());
        go!(move || {
            sender.send(barrier.wait().is_leader()).unwrap();
        });
    }

    // Only N - 1 participants have arrived, so nobody can have been released.
    assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

    // The main task is the N-th participant and trips the barrier.
    let mut leader_found = shared.wait().is_leader();

    // Lazily receive one report per spawned coroutine; a second leader fails
    // the test as soon as it is observed.
    for reported_leader in (1..N).map(|_| rx.recv().unwrap()) {
        if reported_leader {
            assert!(!leader_found);
            leader_found = true;
        }
    }

    assert!(leader_found);
}