use crate::shared::{fence_acquire, invalid_mut, StrictProvenance, Waiter};
use std::{
fmt,
ptr::{self, NonNull},
sync::atomic::{AtomicPtr, Ordering},
};
/// Low tag bit of the state word: set when the word holds a pointer to the
/// head of the waiter queue rather than an arrival count.
const QUEUED: usize = 1;
/// Second tag bit of the state word: set by a waiter that pushes itself onto an
/// already-existing queue, and cleared again once the queue has been linked.
const QUEUE_LOCKED: usize = 2;
/// Value of the state word once the barrier has been completed.
const COMPLETED: usize = 0;
/// Number of bits an arrival count is shifted left by before being stored.
///
/// The count must sit above the `QUEUED` bit: `wait_slow` tests that bit to
/// tell an encoded count apart from a queue pointer, so an encoded count must
/// never have it set. Deriving the shift from `QUEUED` itself yields `0`, which
/// makes every odd count look like a queue with a null head.
const COUNT_SHIFT: u32 = QUEUE_LOCKED.trailing_zeros();
/// A barrier whose whole state is packed into one atomic, pointer-sized word.
///
/// Depending on its tag bits the word holds an arrival count (stored by `new`
/// as `n << COUNT_SHIFT`), a tagged pointer to a queue of `Waiter`s (`QUEUED`
/// bit set), or `COMPLETED`.
///
/// The derived `Default` produces a null `state`, whose address equals
/// `COMPLETED`, so `wait` on a default barrier returns without blocking.
#[derive(Default)]
pub struct Barrier {
// Arrival count, tagged waiter-queue pointer, or `COMPLETED`.
state: AtomicPtr<Waiter>,
}
// SAFETY: `Barrier`'s only field is an `AtomicPtr<Waiter>`, and the methods in
// this file read and update it only through atomic loads, compare-exchanges and
// swaps.
// NOTE(review): the `Waiter` nodes reachable through that pointer are
// dereferenced in `link_queue_or_complete` and `complete`; soundness presumably
// relies on the `QUEUE_LOCKED` bit serialising those accesses — confirm
// against the `Waiter` definition.
unsafe impl Send for Barrier {}
unsafe impl Sync for Barrier {}
/// Formats the barrier as an opaque struct; the internal state word is not shown.
impl fmt::Debug for Barrier {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("Barrier");
        // No fields are listed, so the output only names the type and marks
        // it as non-exhaustive.
        builder.finish_non_exhaustive()
    }
}
impl Barrier {
/// Creates a barrier whose initial state encodes the arrival count `n`.
///
/// The count is stored as `n << COUNT_SHIFT`; bits shifted out of the word
/// are discarded. A count of `0` encodes to the same value as `COMPLETED`
/// (presumably `invalid_mut` yields a pointer with exactly that address —
/// confirm), in which case `wait` returns without blocking.
#[must_use]
pub const fn new(n: usize) -> Self {
    let encoded_count = n << COUNT_SHIFT;
    Self {
        state: AtomicPtr::new(invalid_mut(encoded_count)),
    }
}
/// Waits on the barrier and reports whether the caller was its leader.
///
/// If the state already reads `COMPLETED` the call returns immediately with
/// `is_leader() == false`; otherwise the work is delegated to `wait_slow`,
/// whose boolean result becomes the leader flag.
#[inline]
pub fn wait(&self) -> BarrierWaitResult {
    // Acquire load: pairs with the releasing store that completed the barrier.
    let observed = self.state.load(Ordering::Acquire);
    let is_leader = if observed.address() == COMPLETED {
        false
    } else {
        self.wait_slow(observed)
    };
    BarrierWaitResult(is_leader)
}
// Slow path of `wait`.
//
// `state` is the value of `self.state` last observed by the caller. The
// function either completes the barrier itself or publishes the calling
// thread's `Waiter` in the state word and parks until the state reads
// `COMPLETED`.
//
// Returns `true` only for the thread that completed the barrier (the leader),
// `false` for every other caller.
#[cold]
fn wait_slow(&self, mut state: *mut Waiter) -> bool {
Waiter::with(|waiter| {
// Clear these two fields before the waiter can become visible to other threads.
waiter.waiting_on.set(None);
waiter.prev.set(None);
loop {
// Already completed: nothing to wait for, and this caller is not the leader.
if state.address() == COMPLETED {
fence_acquire(&self.state);
return false;
}
// The state is the encoding of a count of exactly one, so this arrival
// completes the barrier directly without touching the queue.
if state.address() == (1 << COUNT_SHIFT) {
match self.state.compare_exchange_weak(
state,
state.with_address(COMPLETED),
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => return true,
// Lost the race (or spurious failure): retry with the value that was seen.
Err(e) => state = e,
}
continue;
}
// Candidate new state: a pointer to this waiter tagged with `QUEUED`.
let waiter_ptr = NonNull::from(&*waiter).as_ptr();
let mut new_state = waiter_ptr.map_address(|addr| addr | QUEUED);
if state.address() & QUEUED == 0 {
// `QUEUED` is clear, so the state still holds a count. This waiter becomes
// the only node of a new queue (its own tail) and takes over the count,
// minus one for its own arrival.
let counter = (state.address() >> COUNT_SHIFT)
.checked_sub(1)
.expect("Barrier counter with zero value when waiting");
waiter.counter.store(counter, Ordering::Relaxed);
waiter.next.set(None);
waiter.tail.set(Some(NonNull::from(&*waiter)));
} else {
// `QUEUED` is set, so the masked state is the current queue head. Link this
// waiter in front of it and request `QUEUE_LOCKED` in the new state; the
// tail is left unset here.
let head = NonNull::new(state.map_address(|addr| addr & Waiter::MASK));
new_state = new_state.map_address(|addr| addr | QUEUE_LOCKED);
waiter.next.set(head);
waiter.tail.set(None);
}
// Publish the waiter. Release ordering makes the field writes above visible
// to whichever thread later acquires the state.
if let Err(e) = self.state.compare_exchange_weak(
state,
new_state,
Ordering::Release,
Ordering::Relaxed,
) {
state = e;
continue;
}
// The previous state was a queue without `QUEUE_LOCKED`, so this thread's
// compare-exchange is the one that set the bit: it now links the queue and
// becomes the leader if that brings the remaining count to zero.
if (state.address() & QUEUED != 0) && (state.address() & QUEUE_LOCKED == 0) {
if unsafe { self.link_queue_or_complete(new_state) } {
return true;
}
}
// Park until woken (presumably `None` means no timeout — confirm against
// the parker), then check that the barrier really is completed.
assert!(waiter.parker.park(None));
state = self.state.load(Ordering::Acquire);
assert_eq!(state.address(), COMPLETED);
return false;
}
})
}
// Runs with `QUEUE_LOCKED` set: accounts for newly pushed waiters and either
// completes the barrier or clears `QUEUE_LOCKED` again.
//
// `state` is the queue state the caller just installed; both `QUEUED` and
// `QUEUE_LOCKED` are asserted to be set on every iteration.
//
// Returns the result of `complete` (always `true`) when the remaining count
// reaches zero, and `false` once `QUEUE_LOCKED` has been cleared successfully.
//
// # Safety
//
// NOTE(review): the tail returned by `Waiter::get_and_link_queue` is
// dereferenced, so the caller presumably must be the thread that set
// `QUEUE_LOCKED` on a state pointing at a valid queue — confirm against
// `Waiter::get_and_link_queue`.
#[cold]
unsafe fn link_queue_or_complete(&self, mut state: *mut Waiter) -> bool {
loop {
assert_ne!(state.address() & QUEUED, 0);
assert_ne!(state.address() & QUEUE_LOCKED, 0);
// Acquire the writes of the waiters that published themselves into `state`.
fence_acquire(&self.state);
// The callback bumps `discovered` each time it is invoked (presumably once
// per waiter that had not been linked yet — confirm).
let mut discovered = 0;
let (_, tail) = Waiter::get_and_link_queue(state, |_| discovered += 1);
// The remaining count lives in the tail waiter. Subtract the new arrivals,
// clamping at zero, and write the result back.
let mut counter = tail.as_ref().counter.load(Ordering::Relaxed);
counter = counter.saturating_sub(discovered);
tail.as_ref().counter.store(counter, Ordering::Relaxed);
if counter == 0 {
return self.complete();
}
// More arrivals are needed: try to clear `QUEUE_LOCKED`. On failure the
// state changed in the meantime, so go around again with the observed value.
match self.state.compare_exchange_weak(
state,
state.map_address(|addr| addr & !QUEUE_LOCKED),
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => return false,
Err(e) => state = e,
}
}
}
/// Marks the barrier as completed and wakes every waiter in the queue.
///
/// Swaps `COMPLETED` into the state, asserts that the previous state was a
/// locked queue (`QUEUED` and `QUEUE_LOCKED` both set), then follows the
/// `next` links from the queue head and unparks each waiter. Always returns
/// `true`.
///
/// # Safety
///
/// NOTE(review): the previous state is dereferenced as a chain of `Waiter`s,
/// so the caller presumably must hold the queue lock over a valid queue —
/// confirm against `link_queue_or_complete`.
#[cold]
unsafe fn complete(&self) -> bool {
    let done = ptr::null_mut::<Waiter>().with_address(COMPLETED);
    let previous = self.state.swap(done, Ordering::AcqRel);
    assert_ne!(previous.address() & QUEUED, 0);
    assert_ne!(previous.address() & QUEUE_LOCKED, 0);

    let head = previous.map_address(|bits| bits & Waiter::MASK);
    let mut cursor = NonNull::new(head);
    while let Some(current) = cursor {
        // Fetch the successor before waking this waiter; the node is not
        // touched again once it has been unparked.
        cursor = current.as_ref().next.get();
        current.as_ref().parker.unpark();
    }
    true
}
}
/// Result of `Barrier::wait`, recording whether the calling thread was the
/// leader of the barrier.
pub struct BarrierWaitResult(bool);

impl fmt::Debug for BarrierWaitResult {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let leader = self.is_leader();
        let mut builder = formatter.debug_struct("BarrierWaitResult");
        builder.field("is_leader", &leader);
        builder.finish()
    }
}

impl BarrierWaitResult {
    /// Returns `true` if this thread was reported as the leader by the
    /// `wait` call that produced this result.
    #[must_use]
    pub fn is_leader(&self) -> bool {
        let Self(leader) = self;
        *leader
    }
}