use std::sync::atomic::{AtomicUsize, AtomicBool, Ordering};
use std::sync::{Condvar, Mutex};
use std::ptr;
/// A reusable barrier that releases its waiter once `n` check-ins have happened.
///
/// Arm it with [`Barrier::activate`]; the returned [`ActiveBarrier`] issues up
/// to `n` [`Checkpoint`]s (which may be sent to other threads), and
/// [`ActiveBarrier::wait`] blocks until all of them have checked in.
pub struct Barrier {
    /// Number of check-ins required per activation.
    n: usize,
    /// Notified when the current activation finishes.
    cvar: Condvar,
    /// `true` once every check-in of the current activation has happened.
    finished: Mutex<bool>,
    /// Check-ins still outstanding in the current activation.
    checkpoints_remaining: AtomicUsize,
    /// Set when a checkpoint checked in from a panicking thread.
    checkpoint_panicked: AtomicBool,
}

impl Barrier {
    /// Creates a barrier requiring `n` check-ins per activation.
    ///
    /// `n == 0` is allowed: an activation of such a barrier is finished
    /// immediately and `wait` returns `Ok(())` without blocking.
    pub fn new(n: usize) -> Barrier {
        Barrier {
            n,
            cvar: Condvar::new(),
            finished: Mutex::new(false),
            checkpoints_remaining: AtomicUsize::new(n),
            checkpoint_panicked: AtomicBool::new(false),
        }
    }

    /// Changes the number of required check-ins; takes effect on the next
    /// [`Barrier::activate`].
    pub fn set_n(&mut self, n: usize) {
        self.n = n;
    }

    /// Starts a new round and returns the handle that issues checkpoints.
    ///
    /// The barrier stays exclusively borrowed until the returned
    /// [`ActiveBarrier`] is dropped, and that drop blocks until every issued
    /// checkpoint has checked in.
    pub fn activate<'a>(&'a mut self) -> ActiveBarrier<'a> {
        self.reset();
        ActiveBarrier {
            // Downgraded to a shared borrow: checkpoints on other threads read
            // the barrier concurrently, so it must not be written through
            // `&mut` while they exist.
            barrier: self,
            checkpoints_created: 0,
            missing_checked_in: AtomicBool::new(false),
        }
    }

    /// Returns the number of check-ins required per activation.
    pub fn n(&self) -> usize {
        self.n
    }

    /// Restores the per-round state. Having `&mut self`, the fields are
    /// updated directly through `get_mut` rather than through the lock.
    fn reset(&mut self) {
        let n = self.n;
        // With nothing to wait for the round is finished from the start;
        // otherwise `wait` would block forever when `n == 0`.
        *self.finished.get_mut().unwrap() = n == 0;
        *self.checkpoints_remaining.get_mut() = n;
        *self.checkpoint_panicked.get_mut() = false;
    }

    /// Records `x` check-ins. The call that brings the outstanding count to
    /// zero marks the round finished and wakes every waiter.
    fn check_in_x(&self, x: usize) {
        let previous = self.checkpoints_remaining.fetch_sub(x, Ordering::AcqRel);
        debug_assert!(previous >= x);
        debug_assert!(previous <= self.n);
        if previous == x {
            let mut finished = self.finished.lock().unwrap();
            *finished = true;
            // Notify while still holding the lock: a waiter cannot observe
            // `finished` (and go on to release the barrier) until this thread
            // is done with the condvar.
            self.cvar.notify_all();
        }
    }
}

/// An armed [`Barrier`]: issues checkpoints and waits for them to check in.
///
/// Dropping an `ActiveBarrier` runs [`ActiveBarrier::wait`], blocking until
/// every issued checkpoint has checked in; checkpoints refer to the barrier by
/// raw pointer and rely on this to never outlive it.
///
/// NOTE(review): leaking an `ActiveBarrier` (e.g. `std::mem::forget`) skips
/// that wait, letting outstanding checkpoints outlive the barrier. Conversely,
/// a checkpoint that is leaked and never checks in makes the drop block forever.
pub struct ActiveBarrier<'a> {
    barrier: &'a Barrier,
    /// Number of checkpoints issued in this activation (never exceeds `n`).
    checkpoints_created: usize,
    /// Set once `wait` has checked in on behalf of the unissued checkpoints.
    missing_checked_in: AtomicBool,
}

impl<'a> ActiveBarrier<'a> {
    /// Issues the next checkpoint.
    ///
    /// # Panics
    ///
    /// Panics if `n` checkpoints have already been issued, or if `wait` has
    /// already completed the round on behalf of the unissued checkpoints.
    pub fn checkpoint(&mut self) -> Checkpoint {
        if self.checkpoints_created >= self.barrier.n {
            panic!("More than n checkpoints generated.");
        }
        if *self.missing_checked_in.get_mut() {
            // The remaining slots were already counted by `wait`; a new
            // checkpoint would check in a second time and underflow the counter.
            panic!("Checkpoint generated after wait() completed the barrier.");
        }
        self.checkpoints_created += 1;
        Checkpoint {
            barrier: Some(ptr::NonNull::from(self.barrier)),
        }
    }

    /// Returns `true` once every check-in for this activation has happened.
    pub fn finished(&self) -> bool {
        *self.barrier.finished.lock().unwrap()
    }

    /// Blocks until every checkpoint of this activation has checked in.
    ///
    /// Checkpoints that were never issued are checked in on the caller's behalf
    /// exactly once, no matter how often (or from how many threads) `wait` is
    /// called, so the wait ends when all *issued* checkpoints have checked in.
    ///
    /// # Errors
    ///
    /// * [`WaitError::CheckpointPanic`] if any checkpoint checked in while its
    ///   thread was panicking (takes precedence).
    /// * [`WaitError::InsufficientCheckpoints`] if fewer than `n` checkpoints
    ///   were issued.
    pub fn wait(&self) -> WaitResult {
        let missing = self.barrier.n - self.checkpoints_created;
        if missing != 0 && !self.missing_checked_in.swap(true, Ordering::AcqRel) {
            self.barrier.check_in_x(missing);
        }
        {
            let mut finished = self.barrier.finished.lock().unwrap();
            while !*finished {
                finished = self.barrier.cvar.wait(finished).unwrap();
            }
        }
        debug_assert_eq!(0, self.barrier.checkpoints_remaining.load(Ordering::Acquire));
        if self.barrier.checkpoint_panicked.load(Ordering::Acquire) {
            Err(WaitError::CheckpointPanic)
        } else if missing != 0 {
            Err(WaitError::InsufficientCheckpoints)
        } else {
            Ok(())
        }
    }

    /// Returns the number of check-ins required per activation.
    pub fn n(&self) -> usize {
        self.barrier.n
    }
}

impl<'a> Drop for ActiveBarrier<'a> {
    /// Waits for all issued checkpoints so none of them outlives the barrier.
    fn drop(&mut self) {
        // The outcome was either observed through an explicit `wait` already
        // or is deliberately discarded here.
        let _ = self.wait();
    }
}

/// Why [`ActiveBarrier::wait`] reported a failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaitError {
    /// At least one checkpoint checked in while its thread was panicking.
    CheckpointPanic,
    /// Fewer than `n` checkpoints were issued before waiting.
    InsufficientCheckpoints,
}

/// Result of [`ActiveBarrier::wait`].
pub type WaitResult = Result<(), WaitError>;

/// A participant's token for an [`ActiveBarrier`].
///
/// It checks in exactly once: explicitly via [`Checkpoint::check_in`], or
/// implicitly when dropped. A checkpoint that checks in while its thread is
/// panicking makes `wait` report [`WaitError::CheckpointPanic`].
pub struct Checkpoint {
    /// Barrier to check in to; `None` once this checkpoint has checked in.
    barrier: Option<ptr::NonNull<Barrier>>,
}

// SAFETY: the pointer is only ever turned into a shared `&Barrier`, and every
// part of `Barrier` touched through it is either thread-safe (atomics, `Mutex`,
// `Condvar`) or never mutated during an activation (`n`). The pointee outlives
// the checkpoint because `ActiveBarrier` borrows the `Barrier` and its `Drop`
// blocks until every issued checkpoint has checked in (see the leak caveat on
// `ActiveBarrier`).
unsafe impl Send for Checkpoint {}

impl Checkpoint {
    /// Checks in to the barrier; later calls (and the eventual drop) are no-ops.
    ///
    /// If the current thread is panicking, the barrier is flagged first so
    /// that `wait` returns [`WaitError::CheckpointPanic`].
    pub fn check_in(&mut self) {
        // Take the pointer up front so this checkpoint can never check in
        // twice, even if the check-in below panics and `Drop` runs afterwards.
        if let Some(barrier) = self.barrier.take() {
            // SAFETY: see `unsafe impl Send for Checkpoint` — the barrier is
            // alive for as long as this checkpoint has not checked in.
            let barrier = unsafe { barrier.as_ref() };
            if std::thread::panicking() {
                barrier.checkpoint_panicked.store(true, Ordering::Release);
            }
            barrier.check_in_x(1);
        }
    }
}

impl Drop for Checkpoint {
    /// Checks in if that has not happened yet.
    fn drop(&mut self) {
        self.check_in();
    }
}
#[cfg(test)]
mod tests{
    extern crate rand;
    use super::*;
    use tests::rand::Rng;

    const THREADS: usize = 5;

    /// Sleeps for a short random interval (a few tens of milliseconds) so the
    /// order in which threads check in varies between runs.
    fn random_pause(){
        std::thread::sleep(std::time::Duration::new(0,rand::thread_rng().gen_range(1,10)*10_000_000));
    }

    /// Issues `n_threads` checkpoints, each checked in from its own thread
    /// after a random pause, then waits on the barrier and returns the result.
    fn threaded_run(barrier: &mut ActiveBarrier, n_threads: usize) -> WaitResult{
        for id in 0..n_threads {
            let mut token = barrier.checkpoint();
            std::thread::spawn(move || {
                random_pause();
                println!("thread_id: {}", id);
                token.check_in();
            });
        }
        random_pause();
        let outcome = barrier.wait();
        println!("main thread");
        outcome
    }

    /// Like `threaded_run` with `THREADS` threads, except every odd-numbered
    /// thread panics before checking in; asserts the panic is reported.
    fn panic_run(barrier: &mut ActiveBarrier){
        for id in 0..THREADS {
            let mut token = barrier.checkpoint();
            std::thread::spawn(move || {
                random_pause();
                if id % 2 == 1 {
                    panic!("Deliberate panic");
                }
                println!("thread_id: {}", id);
                token.check_in();
            });
        }
        random_pause();
        let outcome = barrier.wait();
        assert_eq!(outcome, Err(WaitError::CheckpointPanic));
        println!("main thread");
    }

    #[test]
    fn same_thread() {
        fn run(mut active: ActiveBarrier){
            (0..THREADS).for_each(|id| {
                let mut token = active.checkpoint();
                println!("thread_id: {}", id);
                token.check_in();
            });
            active.wait().unwrap();
            println!("main thread");
        }
        let mut barrier = Barrier::new(THREADS);
        run(barrier.activate());
    }

    #[test]
    fn single_use() {
        let mut barrier = Barrier::new(THREADS);
        threaded_run(&mut barrier.activate(), THREADS).unwrap();
    }

    #[test]
    fn reuse() {
        let mut barrier = Barrier::new(THREADS);
        for _ in 0..5 {
            threaded_run(&mut barrier.activate(), THREADS).unwrap();
        }
    }

    #[test]
    fn test_checkpoint_panic_detection() {
        let mut barrier = Barrier::new(THREADS);
        panic_run(&mut barrier.activate());
    }

    #[test]
    fn not_enough_checkpoints() {
        let mut barrier = Barrier::new(THREADS);
        let outcome = threaded_run(&mut barrier.activate(), THREADS - 1);
        assert_eq!(outcome, Err(WaitError::InsufficientCheckpoints));
    }

    #[test]
    #[should_panic]
    fn too_many_checkpoints() {
        let mut barrier = Barrier::new(THREADS);
        threaded_run(&mut barrier.activate(), THREADS + 1).unwrap();
    }

    #[test]
    fn test_finished_true() {
        let mut barrier = Barrier::new(THREADS);
        let mut active = barrier.activate();
        threaded_run(&mut active, THREADS).unwrap();
        assert!(active.finished());
    }

    #[test]
    fn test_finished_false() {
        let mut barrier = Barrier::new(THREADS);
        let mut active = barrier.activate();
        assert!(!active.finished());
        threaded_run(&mut active, THREADS).unwrap();
    }
}