use std::sync::atomic::{AtomicUsize, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// A reusable barrier that releases its waiters once a fixed number of
/// threads have arrived, and counts how many times it has done so.
///
/// Every completed rendezvous is a *generation*. Generations are numbered
/// from `0`, and each successful wait reports the generation it took part in.
///
/// A barrier created with a `thread_count` of `0` behaves like one created
/// with `1`: every arrival completes a generation immediately.
pub struct MkBarrier {
    inner: Arc<BarrierInner>,
}

/// Shared state behind a barrier handle.
///
/// `waiting_count` and `generation` are only ever *modified* while `mutex` is
/// held. They are atomics solely so the read-only accessors can sample them
/// without taking the lock.
struct BarrierInner {
    /// Number of arrivals needed to complete one generation.
    thread_count: usize,
    /// Arrivals registered so far in the current generation.
    waiting_count: AtomicUsize,
    /// Number of generations completed so far.
    generation: AtomicU64,
    /// Serializes arrivals, withdrawals and releases; paired with `condvar`.
    mutex: std::sync::Mutex<()>,
    /// Notified every time a generation completes.
    condvar: std::sync::Condvar,
}

impl BarrierInner {
    /// Acquires the state lock.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn lock(&self) -> std::sync::MutexGuard<'_, ()> {
        self.mutex.lock().expect("barrier mutex poisoned")
    }

    /// Registers one arrival in the current generation.
    ///
    /// Returns `true` when this arrival completed the generation (waiters
    /// have been released and the count reset), `false` when the caller still
    /// has to wait. The guard parameter is proof that the lock is held.
    fn arrive(&self, held: &std::sync::MutexGuard<'_, ()>) -> bool {
        // Cannot overflow: a stored count is always below `thread_count`.
        let arrived = self.waiting_count.load(Ordering::Acquire) + 1;
        // `>=` rather than `==` so a zero-sized barrier trips immediately.
        if arrived >= self.thread_count {
            self.release(held);
            true
        } else {
            self.waiting_count.store(arrived, Ordering::Release);
            false
        }
    }

    /// Completes the current generation and wakes every waiter.
    ///
    /// The count is reset *before* the generation is bumped so that a
    /// lock-free reader that observes the new generation also observes the
    /// cleared count.
    fn release(&self, _held: &std::sync::MutexGuard<'_, ()>) {
        self.waiting_count.store(0, Ordering::Release);
        self.generation.fetch_add(1, Ordering::AcqRel);
        self.condvar.notify_all();
    }

    /// Takes back an arrival registered by `arrive`.
    ///
    /// Must only be called while the lock is held and the generation the
    /// caller arrived in is still current; this guarantees the count is at
    /// least one and still includes the caller.
    fn withdraw(&self, _held: &std::sync::MutexGuard<'_, ()>) {
        self.waiting_count.fetch_sub(1, Ordering::AcqRel);
    }
}

impl MkBarrier {
    /// Creates a barrier that trips once `thread_count` threads have arrived.
    pub fn new(thread_count: usize) -> Self {
        Self {
            inner: Arc::new(BarrierInner {
                thread_count,
                waiting_count: AtomicUsize::new(0),
                generation: AtomicU64::new(0),
                mutex: std::sync::Mutex::new(()),
                condvar: std::sync::Condvar::new(),
            }),
        }
    }

    /// Blocks until `thread_count` threads have arrived at the barrier.
    ///
    /// Returns the number of the generation this call took part in, i.e. the
    /// value [`generation`](Self::generation) had when the caller arrived.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait(&self) -> u64 {
        let inner = &*self.inner;
        // Reading the generation and registering the arrival happen under
        // one lock acquisition, so the barrier cannot trip between the two.
        let mut guard = inner.lock();
        let generation = inner.generation.load(Ordering::Acquire);
        if inner.arrive(&guard) {
            return generation;
        }
        // Loop to absorb spurious wakeups.
        while inner.generation.load(Ordering::Acquire) == generation {
            guard = inner.condvar.wait(guard).expect("barrier mutex poisoned");
        }
        generation
    }

    /// Like [`wait`](Self::wait), but gives up after `timeout`.
    ///
    /// Returns `Some(generation)` if the barrier tripped, or `None` if the
    /// timeout elapsed first. On timeout the caller's arrival is withdrawn,
    /// so it no longer counts toward the current generation. If the barrier
    /// trips at the same moment the timeout expires, the trip wins and
    /// `Some` is returned.
    ///
    /// Arbitrarily large timeouts are accepted.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait_timeout(&self, timeout: Duration) -> Option<u64> {
        let inner = &*self.inner;
        // Measure from function entry so time spent acquiring the lock is
        // charged against the timeout. Tracking elapsed time instead of an
        // absolute deadline avoids `Instant + Duration` overflow.
        let started = Instant::now();
        let mut guard = inner.lock();
        let generation = inner.generation.load(Ordering::Acquire);
        if inner.arrive(&guard) {
            return Some(generation);
        }
        // The generation is re-checked before the clock on every iteration,
        // so an arrival is only withdrawn while it is still counted.
        while inner.generation.load(Ordering::Acquire) == generation {
            let remaining = match timeout.checked_sub(started.elapsed()) {
                Some(left) => left,
                None => {
                    inner.withdraw(&guard);
                    return None;
                }
            };
            guard = inner
                .condvar
                .wait_timeout(guard, remaining)
                .expect("barrier mutex poisoned")
                .0;
        }
        Some(generation)
    }

    /// Passes the barrier only if that can be done without blocking.
    ///
    /// If the caller would be the final arrival of the current generation,
    /// the generation is completed, all waiters are released and
    /// `Some(generation)` is returned. Otherwise nothing is registered and
    /// `None` is returned. The check and the release are performed under the
    /// same lock, so this method never waits for other threads to arrive.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn try_wait(&self) -> Option<u64> {
        let inner = &*self.inner;
        let guard = inner.lock();
        let generation = inner.generation.load(Ordering::Acquire);
        if inner.waiting_count.load(Ordering::Acquire) + 1 >= inner.thread_count {
            inner.release(&guard);
            Some(generation)
        } else {
            None
        }
    }

    /// Returns the number of arrivals required to trip the barrier, exactly
    /// as passed to [`new`](Self::new).
    pub fn thread_count(&self) -> usize {
        self.inner.thread_count
    }

    /// Returns a snapshot of how many threads are currently registered as
    /// waiting in the current generation.
    pub fn waiting_count(&self) -> usize {
        self.inner.waiting_count.load(Ordering::Acquire)
    }

    /// Returns a snapshot of how many generations have completed so far.
    pub fn generation(&self) -> u64 {
        self.inner.generation.load(Ordering::Acquire)
    }
}
impl Default for MkBarrier {
fn default() -> Self {
Self::new(2) }
}
impl Clone for MkBarrier {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Three threads arriving at staggered times all pass together and all
    /// report the same (first) generation.
    #[test]
    fn test_barrier_basic() {
        let barrier = Arc::new(MkBarrier::new(3));
        let handles: Vec<_> = (0..3u64)
            .map(|i| {
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    thread::sleep(Duration::from_millis(10 * i));
                    b.wait()
                })
            })
            .collect();
        let generations: Vec<u64> = handles
            .into_iter()
            .map(|h| h.join().unwrap())
            .collect();
        assert_eq!(generations, vec![0, 0, 0]);
        assert_eq!(barrier.generation(), 1);
        assert_eq!(barrier.waiting_count(), 0);
    }

    /// With only two of three participants present, both waits time out and
    /// leave no arrivals registered.
    #[test]
    fn test_barrier_timeout() {
        let barrier = Arc::new(MkBarrier::new(3));
        let handles: Vec<_> = (0..2)
            .map(|_| {
                let b = Arc::clone(&barrier);
                thread::spawn(move || {
                    let result = b.wait_timeout(Duration::from_millis(50));
                    assert!(result.is_none());
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(barrier.waiting_count(), 0);
        assert_eq!(barrier.generation(), 0);
    }

    /// `try_wait` refuses while the other participant is absent and succeeds
    /// once it is observed waiting.
    #[test]
    fn test_barrier_try_wait() {
        let barrier = MkBarrier::new(2);
        assert!(barrier.try_wait().is_none());

        let barrier2 = barrier.clone();
        let waiter = thread::spawn(move || barrier2.wait());

        // Wait until the other thread has actually arrived instead of
        // guessing with a fixed sleep; bounded so a regression fails the
        // test rather than hanging it.
        let deadline = Instant::now() + Duration::from_secs(5);
        while barrier.waiting_count() == 0 {
            assert!(Instant::now() < deadline, "waiter thread never arrived");
            thread::yield_now();
        }

        assert_eq!(barrier.try_wait(), Some(0));
        assert_eq!(waiter.join().unwrap(), 0);
    }
}