use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use crate::internal::Mutex;
use crate::internal::WaitSet;
#[cfg(test)]
mod tests;
/// An async barrier: tasks calling [`Barrier::wait`] are held until `n` of them
/// have arrived, at which point the whole group is released.
#[derive(Debug)]
pub struct Barrier {
    /// Number of arrivals needed to release a generation (never less than 1).
    n: u32,
    /// Arrival count, generation counter and parked waiters, behind one lock.
    state: Mutex<BarrierState>,
}
/// Mutable state of a [`Barrier`], guarded by its `state` mutex.
struct BarrierState {
    /// Tasks that have arrived in the current generation; reset to 0 on release.
    arrived: u32,
    /// Advanced each time a full set of `n` tasks has arrived.
    generation: usize,
    /// Wakers registered by tasks parked in the current generation.
    waiters: WaitSet,
}
/// Shows the two counters; the wait set is elided and rendered as `..`.
impl fmt::Debug for BarrierState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("BarrierState");
        out.field("arrived", &self.arrived);
        out.field("generation", &self.generation);
        out.finish_non_exhaustive()
    }
}
/// Outcome of [`Barrier::wait`]; in each generation exactly one task — the one
/// whose arrival completed it — is reported as the leader.
pub struct BarrierWaitResult(bool);

impl fmt::Debug for BarrierWaitResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let is_leader = self.is_leader();
        let mut out = f.debug_struct("BarrierWaitResult");
        out.field("is_leader", &is_leader);
        out.finish()
    }
}

impl BarrierWaitResult {
    /// Returns `true` only for the task whose arrival released the barrier.
    #[must_use]
    pub fn is_leader(&self) -> bool {
        let &Self(leader) = self;
        leader
    }
}
impl Barrier {
    /// Creates a barrier that releases once `n` tasks have called [`Barrier::wait`].
    ///
    /// `n == 0` is treated as `1`, so such a barrier never blocks.
    pub fn new(n: u32) -> Self {
        let n = n.max(1);
        Self {
            n,
            state: Mutex::new(BarrierState {
                arrived: 0,
                generation: 0,
                waiters: WaitSet::with_capacity(n as usize),
            }),
        }
    }

    /// Waits until `n` tasks, including this one, have reached the barrier.
    ///
    /// The task whose arrival completes the current generation resets the
    /// arrival count, advances the generation, wakes every parked task and
    /// returns immediately with [`BarrierWaitResult::is_leader`] `== true`.
    /// Every other task parks until that happens and then returns `false`.
    ///
    /// NOTE(review): the arrival is counted on the first poll and nothing here
    /// decrements it, so dropping this future while it is parked still counts
    /// as an arrival — confirm this is the intended cancellation behaviour.
    pub async fn wait(&self) -> BarrierWaitResult {
        let generation = {
            let mut state = self.state.lock();
            let generation = state.generation;
            state.arrived += 1;
            if state.arrived == self.n {
                state.arrived = 0;
                // Wrap rather than overflow: a plain `+= 1` panics in debug
                // builds once a long-lived barrier exhausts the counter.
                state.generation = state.generation.wrapping_add(1);
                state.waiters.wake_all();
                return BarrierWaitResult(true);
            }
            generation
        };
        BarrierWait {
            idx: None,
            generation,
            barrier: self,
        }
        .await;
        BarrierWaitResult(false)
    }
}
/// Future that keeps a non-leader task pending until the barrier's generation
/// moves past the one the task arrived in.
///
/// NOTE(review): no `Drop` impl is visible for this type, so a slot registered
/// in `BarrierState::waiters` is not explicitly removed if the future is dropped
/// before completing — confirm `WaitSet` tolerates such stale entries.
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct BarrierWait<'a> {
    /// Slot handle passed to `WaitSet::register_waker`; presumably filled in on
    /// first registration and reused afterwards — confirm against `WaitSet`.
    idx: Option<usize>,
    /// Generation this task arrived in.
    generation: usize,
    /// Barrier whose state is checked on every poll.
    barrier: &'a Barrier,
}
/// Shows the awaited generation; the slot index and barrier reference are elided.
impl fmt::Debug for BarrierWait<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("BarrierWait");
        builder.field("generation", &self.generation);
        builder.finish_non_exhaustive()
    }
}
impl Future for BarrierWait<'_> {
    type Output = ();

    /// Resolves once the barrier has left the generation this task arrived in;
    /// otherwise registers the current waker and returns `Poll::Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Every field is `Unpin`, so projecting through `get_mut` is fine.
        let Self {
            idx,
            generation,
            barrier,
        } = self.get_mut();
        let mut state = barrier.state.lock();
        // The barrier's generation only ever moves forward, so any change means
        // this task's generation was released. Testing `!=` instead of `<` keeps
        // the check correct even if the counter wraps around.
        if state.generation != *generation {
            return Poll::Ready(());
        }
        // Register on every pending poll: the executor may supply a different
        // waker each time, and only the most recent one must be woken.
        state.waiters.register_waker(idx, cx);
        Poll::Pending
    }
}