use crate::misc::{PhantomBarrier, PhantomBarrierWaitResult};
use std::fmt;
use std::sync::atomic::{AtomicUsize, Ordering};
/// A reusable barrier that lets a fixed number of threads rendezvous by
/// spinning on an atomic word instead of parking.
pub struct Barrier {
    /// Number of `wait` calls that make up one generation.
    num_threads: usize,
    /// Arrivals recorded so far in the current generation.
    count: AtomicUsize,
    /// Generation counter; its top bit (`BarrierLockGuard::MSB`) is used as
    /// a lock flag while a thread updates the barrier state.
    generation_id: AtomicUsize,
    /// Marker value from `crate::misc`.
    // NOTE(review): the purpose of this marker is not visible from this file — confirm.
    _phantom: PhantomBarrier,
}
impl Barrier {
    /// Creates a barrier that releases its waiters once `n` threads have
    /// called [`Barrier::wait`].
    ///
    /// With `n` equal to 0 or 1 every call to `wait` returns immediately as
    /// the leader, because the first arrival already satisfies the count.
    pub const fn new(n: usize) -> Self {
        Self {
            num_threads: n,
            count: AtomicUsize::new(0),
            generation_id: AtomicUsize::new(0),
            _phantom: PhantomBarrier {},
        }
    }

    /// Blocks the calling thread until `num_threads` calls to `wait` have
    /// arrived for the current generation.
    ///
    /// Waiting is done by polling the generation counter and calling
    /// [`std::thread::yield_now`] between polls. The call that completes the
    /// generation resets the arrival count, advances the generation and gets
    /// a result whose `is_leader()` is `true`; every other call gets `false`.
    /// The barrier can be reused for the next generation afterwards.
    pub fn wait(&self) -> BarrierWaitResult {
        let (guard, generation_id) = self.lock();

        // `count` is only read or written while the lock bit is held, so the
        // lock's acquire/release pair already orders these accesses.
        let count = self.count.load(Ordering::Relaxed) + 1;
        self.count.store(count, Ordering::Relaxed);

        if count < self.num_threads {
            drop(guard);
            loop {
                // Acquire pairs with the leader's Release store below (and
                // with the guard's Release unlock), so everything the other
                // participants did before arriving is visible once the new
                // generation is observed.
                let observed = self.generation_id.load(Ordering::Acquire);
                // The lock bit may be set by whoever currently holds the
                // lock; only the generation number matters here.
                let current_id = observed & !BarrierLockGuard::MSB;
                if generation_id != current_id {
                    return BarrierWaitResult(false, PhantomBarrierWaitResult {});
                }
                std::thread::yield_now();
            }
        } else {
            self.count.store(0, Ordering::Relaxed);
            // Keep the lock bit set: the guard clears it when dropped.
            let next_generation = (generation_id + 1) | BarrierLockGuard::MSB;
            // Release so that a waiter which observes this value (before the
            // guard unlocks) still synchronizes with the leader.
            self.generation_id
                .store(next_generation, Ordering::Release);
            drop(guard);
            BarrierWaitResult(true, PhantomBarrierWaitResult {})
        }
    }

    /// Acquires the lock bit of `generation_id` by spinning.
    ///
    /// Returns the guard that clears the lock bit on drop, together with the
    /// generation number that was current when the lock was taken.
    fn lock(&self) -> (BarrierLockGuard<'_>, usize) {
        let mut expected = 0;
        loop {
            let desired = expected + BarrierLockGuard::MSB;
            // `compare_exchange` replaces the deprecated `compare_and_swap`.
            // A failed attempt only feeds the next guess, so it needs no
            // ordering stronger than Relaxed.
            match self.generation_id.compare_exchange(
                expected,
                desired,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(current) => {
                    if (current & BarrierLockGuard::MSB) != 0 {
                        // Someone else holds the lock: retry against the
                        // unlocked form of what we saw, after yielding.
                        expected = current - BarrierLockGuard::MSB;
                        std::thread::yield_now();
                    } else {
                        // Unlocked, but the generation guess was stale.
                        expected = current;
                    }
                }
            }
        }
        (
            BarrierLockGuard {
                generation_id: &self.generation_id,
            },
            expected,
        )
    }
}
impl fmt::Debug for Barrier {
    /// Writes a fixed placeholder; none of the barrier's fields are shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        const OPAQUE: &str = "Barrier { .. }";
        // `pad` honours any width/alignment flags given by the caller.
        formatter.pad(OPAQUE)
    }
}
/// Value returned by `Barrier::wait`.
///
/// The `bool` is the leader flag reported by `is_leader`; the second field
/// is a marker value from `crate::misc`.
pub struct BarrierWaitResult(bool, PhantomBarrierWaitResult);
impl BarrierWaitResult {
pub fn is_leader(&self) -> bool {
self.0
}
}
impl fmt::Debug for BarrierWaitResult {
    /// Formats as `BarrierWaitResult { is_leader: <bool> }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let leader = self.is_leader();
        let mut repr = f.debug_struct("BarrierWaitResult");
        repr.field("is_leader", &leader);
        repr.finish()
    }
}
/// Scope guard for the lock bit stored in a barrier's generation word.
///
/// Dropping the guard clears the lock bit again.
struct BarrierLockGuard<'a> {
    /// The atomic word whose top bit is currently held as a lock.
    generation_id: &'a AtomicUsize,
}

impl BarrierLockGuard<'_> {
    /// Mask with only the most significant bit of `usize` set; used as the
    /// lock flag.
    pub const MSB: usize = !(usize::MAX >> 1);
}

impl Drop for BarrierLockGuard<'_> {
    /// Clears the lock bit, publishing the holder's writes with `Release`.
    fn drop(&mut self) {
        let word = self.generation_id;
        // A plain load is enough: while the lock bit is set, only the
        // holder of this guard stores to the word.
        let locked = word.load(Ordering::Relaxed);
        debug_assert_eq!(Self::MSB, locked & Self::MSB);
        let unlocked = locked - Self::MSB;
        word.store(unlocked, Ordering::Release);
    }
}