use crate::{mutex::Mutex, RelaxStrategy, Spin};
/// A reusable barrier that makes callers of `wait` busy-wait until
/// `num_threads` of them have arrived.
///
/// The type parameter `R` selects the relax strategy invoked between polls
/// of the shared state while waiting; it defaults to `Spin`.
pub struct Barrier<R = Spin> {
// Mutex guarding the arrival counter and generation number.
lock: Mutex<BarrierState, R>,
// Number of `wait` calls needed to release one group of waiters.
num_threads: usize,
}
/// Mutable bookkeeping shared by every caller of `Barrier::wait`; only
/// accessed while holding the barrier's mutex.
struct BarrierState {
// How many callers have arrived in the current generation. Reset to 0 by
// the leader when the barrier trips.
count: usize,
// Incremented (with wrap-around) each time the barrier trips; waiters
// compare it with the value seen on arrival to detect their release.
generation_id: usize,
}
/// Value returned by `Barrier::wait` once the caller has been released.
///
/// Wraps a single flag that is `true` only for the call that completed the
/// group (the "leader"); query it with `is_leader`.
///
/// `Debug` is derived so the result can be logged or used in assertions.
#[derive(Debug)]
pub struct BarrierWaitResult(bool);
impl<R: RelaxStrategy> Barrier<R> {
    /// Blocks the caller until `num_threads` calls to `wait` have been made
    /// on this barrier.
    ///
    /// The call that completes the group is the leader: it resets the
    /// arrival counter, advances the generation (wrapping on overflow) and
    /// returns a result for which `is_leader()` is `true`. Every other
    /// caller busy-waits — releasing the lock, calling `R::relax()` and
    /// re-acquiring the lock — until it observes the release, and returns a
    /// result for which `is_leader()` is `false`.
    ///
    /// Because the leader resets the counter, the barrier can be used again
    /// for subsequent rounds.
    pub fn wait(&self) -> BarrierWaitResult {
        let mut state = self.lock.lock();
        state.count += 1;

        // Last arrival: trip the barrier and start a fresh generation.
        if state.count >= self.num_threads {
            state.count = 0;
            state.generation_id = state.generation_id.wrapping_add(1);
            return BarrierWaitResult(true);
        }

        // Not the leader: remember the generation we arrived in and poll
        // until it changes (or the group is seen to be complete).
        let arrival_gen = state.generation_id;
        loop {
            let released =
                state.generation_id != arrival_gen || state.count >= self.num_threads;
            if released {
                break;
            }
            // Release the lock while relaxing so other callers can arrive.
            drop(state);
            R::relax();
            state = self.lock.lock();
        }
        BarrierWaitResult(false)
    }
}
impl<R> Barrier<R> {
pub const fn new(n: usize) -> Self {
Self {
lock: Mutex::new(BarrierState {
count: 0,
generation_id: 0,
}),
num_threads: n,
}
}
}
impl BarrierWaitResult {
pub fn is_leader(&self) -> bool { self.0 }
}
#[cfg(test)]
mod tests {
    use std::prelude::v1::*;

    use std::sync::mpsc::{channel, TryRecvError};
    use std::sync::Arc;
    use std::thread;

    type Barrier = super::Barrier;

    /// Runs one round on `barrier` with `n` participants: `n - 1` spawned
    /// helper threads plus the calling thread, and checks that exactly one
    /// of them is reported as leader.
    fn use_barrier(n: usize, barrier: Arc<Barrier>) {
        let (tx, rx) = channel();
        let helpers = n - 1;

        for _ in 0..helpers {
            let shared = Arc::clone(&barrier);
            let report = tx.clone();
            thread::spawn(move || {
                let led = shared.wait().is_leader();
                report.send(led).unwrap();
            });
        }

        // Only `n - 1` participants have been started, so none of them can
        // have been released yet and nothing may have been sent.
        assert!(match rx.try_recv() {
            Err(TryRecvError::Empty) => true,
            _ => false,
        });

        // This thread is the final participant; its arrival trips the barrier.
        let mut leader_found = barrier.wait().is_leader();

        // Gather every helper's answer, rejecting a second leader.
        for _ in 0..helpers {
            let helper_led = rx.recv().unwrap();
            if helper_led {
                assert!(!leader_found);
                leader_found = true;
            }
        }
        assert!(leader_found);
    }

    #[test]
    fn test_barrier() {
        const N: usize = 10;

        let barrier = Arc::new(Barrier::new(N));

        // Two consecutive rounds on the same barrier check that it is reusable.
        for _round in 0..2 {
            use_barrier(N, Arc::clone(&barrier));
        }
    }
}