use event_listener::{Event, EventListener};
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::futures::Lock;
use crate::Mutex;
/// A barrier that lets a group of asynchronous waiters rendezvous.
///
/// Each call to `wait` registers one arrival; once `n` arrivals have been
/// counted, the current generation ends, every listener on `event` is
/// notified, and the counter starts over for the next generation.
#[derive(Debug)]
pub struct Barrier {
/// Number of arrivals that completes a generation (compared against
/// `State::count` when a waiter registers).
n: usize,
/// Arrival counter and generation number, guarded by an async mutex.
state: Mutex<State>,
/// Event that waiters listen on; notified by the arrival that completes
/// a generation.
event: Event,
}
/// Mutable bookkeeping shared by every waiter of a `Barrier`.
#[derive(Debug)]
struct State {
/// Arrivals registered in the current generation. Incremented once per
/// waiter and reset to zero when the generation completes.
count: usize,
/// Generation number, bumped (with wrap-around) every time the barrier
/// releases. Waiters snapshot it on arrival and compare it after waking
/// to tell whether their generation has ended.
generation_id: u64,
}
impl Barrier {
    /// Builds a barrier whose generations complete after `n` arrivals.
    ///
    /// The new barrier has no registered arrivals and starts at generation
    /// zero. Usable in `const` contexts.
    pub const fn new(n: usize) -> Self {
        // Fresh bookkeeping: nobody has arrived, first generation.
        let initial = State {
            count: 0,
            generation_id: 0,
        };
        Self {
            n,
            state: Mutex::new(initial),
            event: Event::new(),
        }
    }

    /// Returns a future that registers one arrival at the barrier and
    /// resolves once the current generation has completed.
    ///
    /// The future starts in its initial phase, carrying a pending lock
    /// acquisition on the barrier's state. It resolves to a
    /// [`BarrierWaitResult`]; the arrival that completes the generation is
    /// reported as the leader.
    pub fn wait(&self) -> BarrierWait<'_> {
        let pending_lock = self.state.lock();
        BarrierWait {
            barrier: self,
            lock: Some(pending_lock),
            state: WaitState::Initial,
        }
    }
}
/// The future returned by `Barrier::wait`.
///
/// Resolves to a `BarrierWaitResult` once the generation this waiter joined
/// has completed.
pub struct BarrierWait<'a> {
/// The barrier being waited on.
barrier: &'a Barrier,
/// In-flight acquisition of the barrier's state mutex. `Some` while a lock
/// attempt is being polled; reset to `None` as soon as the guard is obtained.
lock: Option<Lock<'a, State>>,
/// Current phase of the wait state machine.
state: WaitState,
}
impl fmt::Debug for BarrierWait<'_> {
    /// Writes a fixed, opaque placeholder; none of the future's fields are
    /// exposed in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let placeholder: &str = "BarrierWait { .. }";
        fmt::Formatter::write_str(f, placeholder)
    }
}
/// Phases of the `BarrierWait` state machine.
enum WaitState {
/// The arrival has not been registered yet; the future is still acquiring
/// the state lock for the first time.
Initial,
/// The arrival is registered and the future is waiting on the barrier's
/// event. `evl` is the listener being polled; `local_gen` is the generation
/// number observed when the arrival was registered.
Waiting { evl: EventListener, local_gen: u64 },
/// The listener fired and the state lock is being re-acquired to check
/// whether the generation has ended. Carries the generation number observed
/// on arrival.
Reacquiring(u64),
}
impl Future for BarrierWait<'_> {
type Output = BarrierWaitResult;
/// Drives the wait state machine.
///
/// Phases: `Initial` (lock the state and register the arrival) ->
/// `Waiting` (poll the event listener) -> `Reacquiring` (lock the state
/// again and check whether the generation ended), looping back to
/// `Waiting` if it has not.
///
/// Returns `is_leader: true` for the arrival that brings the count up to
/// `n`, and `is_leader: false` for waiters released by that arrival.
///
/// # Panics
///
/// Polling again after `Poll::Ready` has been returned panics: the pending
/// lock has been cleared by then, so the `unwrap()` on `this.lock` fails.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// `get_mut` hands out a plain `&mut Self`, which requires `Self: Unpin`.
let this = self.get_mut();
// Loop so that a phase change is acted on immediately instead of
// returning to the executor between phases.
loop {
match this.state {
WaitState::Initial => {
// Poll the lock acquisition created by `Barrier::wait`.
// NOTE(review): `ready!` is not defined in this view; presumably it
// returns `Poll::Pending` from `poll` when the inner future is pending.
let mut state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx));
// The lock future has completed; discard it. The guard in `state`
// stays alive until the end of this arm.
this.lock = None;
// Remember which generation this waiter belongs to.
let local_gen = state.generation_id;
state.count += 1;
if state.count < this.barrier.n {
// Not everyone has arrived yet. The listener is created while the
// state guard is still held, and the leader below also notifies
// while holding the guard, so the two cannot interleave.
this.state = WaitState::Waiting {
evl: this.barrier.event.listen(),
local_gen,
};
} else {
// This arrival completes the generation: reset the counter, advance
// the generation, and notify with the largest possible count so
// that every current listener is covered.
state.count = 0;
state.generation_id = state.generation_id.wrapping_add(1);
this.barrier.event.notify(std::usize::MAX);
return Poll::Ready(BarrierWaitResult { is_leader: true });
}
}
WaitState::Waiting {
ref mut evl,
local_gen,
} => {
// Wait for the listener to fire.
ready!(Pin::new(evl).poll(cx));
// Woken up: start re-acquiring the state lock so the generation can
// be checked under the mutex.
this.lock = Some(this.barrier.state.lock());
this.state = WaitState::Reacquiring(local_gen);
}
WaitState::Reacquiring(local_gen) => {
// Poll the re-acquisition started in the `Waiting` phase.
let state = ready!(Pin::new(this.lock.as_mut().unwrap()).poll(cx));
this.lock = None;
if local_gen == state.generation_id && state.count < this.barrier.n {
// Same generation and still short of `n` arrivals: the wake-up did
// not correspond to a release, so listen again (guard still held).
this.state = WaitState::Waiting {
evl: this.barrier.event.listen(),
local_gen,
};
} else {
// The generation this waiter joined has ended; it was released by
// another arrival, so it is not the leader.
return Poll::Ready(BarrierWaitResult { is_leader: false });
}
}
}
}
}
}
/// Outcome produced when a barrier wait completes.
///
/// Carries a single flag telling whether this waiter was the one whose
/// arrival completed the group.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult {
    /// `true` for the waiter that completed the generation, `false` for
    /// waiters that were released by it.
    is_leader: bool,
}

impl BarrierWaitResult {
    /// Reports whether this waiter was the leader of its generation.
    pub fn is_leader(&self) -> bool {
        // `bool` is `Copy`, so the field can be read out by value.
        let Self { is_leader } = *self;
        is_leader
    }
}