use std::marker::PhantomData;
use crate::cell::UnsafeCell;
use crate::oneshot;
/// Outcome of [`Barrier::wait`].
///
/// Exactly one waiter per round receives a result whose
/// [`is_leader`](Self::is_leader) returns `true`: the task whose arrival
/// completed the group.
#[derive(Clone, Debug)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Returns `true` if this task was the leader for its round of the barrier.
    pub fn is_leader(&self) -> bool {
        let Self(leader) = *self;
        leader
    }
}
/// Mutable state shared by every caller of a [`Barrier`].
#[derive(Debug)]
struct Inner {
    /// One sender per task currently parked at the barrier.
    wakers: Vec<oneshot::Sender<()>>,
    /// Number of arrivals required to release the barrier.
    n: usize,
    /// Raw-pointer marker: makes `Inner` (and anything holding it) `!Send` and `!Sync`.
    _marker: PhantomData<*const ()>,
}

impl Inner {
    /// Registers one arrival at the barrier.
    ///
    /// If this arrival completes the group, every parked task is signalled,
    /// the queue is emptied so the barrier can be reused, and `None` is
    /// returned (the caller is the leader). Otherwise a fresh channel is
    /// created, its sender is queued, and the receiver is handed back for the
    /// caller to await.
    fn wait_impl(&mut self) -> Option<oneshot::Receiver<()>> {
        let remaining = self.n - self.wakers.len();
        if remaining != 1 {
            let (sender, receiver) = oneshot::channel();
            self.wakers.push(sender);
            return Some(receiver);
        }
        // Last arrival: release everyone. Send errors are deliberately ignored
        // (presumably a receiver whose waiting future was dropped).
        self.wakers.drain(..).for_each(|sender| {
            let _ = sender.send(());
        });
        None
    }
}
/// A single-threaded, reusable barrier that lets `n` tasks rendezvous.
///
/// The first `n - 1` calls to [`Barrier::wait`] suspend. The `n`th call wakes
/// all of them and returns immediately as the leader. After each release the
/// barrier is empty again and can be used for another round.
#[derive(Debug)]
pub struct Barrier {
    inner: UnsafeCell<Inner>,
}

impl Barrier {
    /// Creates a barrier that releases its waiters once `n` tasks have called
    /// [`Barrier::wait`].
    ///
    /// A value of `0` is treated as `1`, so every call is immediately the
    /// leader. This matches `tokio::sync::Barrier::new` and avoids the `n - 1`
    /// underflow that a zero count would otherwise cause.
    pub fn new(n: usize) -> Self {
        // Both `n - 1` below and the `n - len` check in `Inner::wait_impl`
        // require `n >= 1`.
        let n = n.max(1);
        Self {
            inner: UnsafeCell::new(Inner {
                // At most `n - 1` tasks are ever parked at once.
                wakers: Vec::with_capacity(n - 1),
                n,
                _marker: PhantomData,
            }),
        }
    }

    /// Waits until `n` tasks, including this one, have reached the barrier.
    ///
    /// Exactly one caller per round gets a [`BarrierWaitResult`] whose
    /// `is_leader()` is `true`: the caller whose arrival completes the group.
    /// That caller returns without suspending. Every other caller suspends
    /// until it is released.
    ///
    /// A waiting future that is dropped before release still counts toward
    /// `n`, because its sender stays queued until the round completes.
    ///
    /// # Panics
    ///
    /// Panics with `"channel failed"` if awaiting the wake-up channel yields an
    /// error. That presumably happens only if the sender is dropped without
    /// sending; confirm against `crate::oneshot`.
    pub async fn wait(&self) -> BarrierWaitResult {
        // SAFETY: `Inner` holds a `PhantomData<*const ()>`, so `Barrier` is
        // neither `Send` nor `Sync`, and `inner` is only accessed from one
        // thread. The closure runs to completion without awaiting, so no other
        // `wait` call can observe `inner` while it is mutably borrowed.
        // NOTE(review): this assumes `oneshot::Sender::send` never re-enters
        // `Barrier::wait` synchronously; confirm against `crate::oneshot`.
        let maybe_recv = unsafe { self.inner.with_mut(|inner| inner.wait_impl()) };
        match maybe_recv {
            Some(rx) => {
                rx.await.expect("channel failed");
                BarrierWaitResult(false)
            }
            None => BarrierWaitResult(true),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use futures::channel::mpsc::unbounded;
    use futures::stream::StreamExt;
    use tokio::task::{spawn_local, LocalSet};
    use tokio::test;

    use super::*;

    /// `N - 1` spawned tasks plus the test task rendezvous; exactly one of
    /// them must be reported as the leader.
    #[test]
    async fn test_barrier() {
        let local_set = LocalSet::new();
        local_set
            .run_until(async {
                const N: usize = 10;
                let barrier = Rc::new(Barrier::new(N));
                let (tx, mut rx) = unbounded();
                for _ in 0..N - 1 {
                    let c = barrier.clone();
                    let tx = tx.clone();
                    spawn_local(async move {
                        tx.unbounded_send(c.wait().await.is_leader()).unwrap();
                    });
                }
                // The spawned tasks have not been polled yet, so nothing has been sent.
                assert!(rx.try_next().is_err());
                let mut leader_found = barrier.wait().await.is_leader();
                for _ in 0..N - 1 {
                    if rx.next().await.unwrap() {
                        assert!(!leader_found);
                        leader_found = true;
                    }
                }
                assert!(leader_found);
            })
            .await;
    }

    /// With a single party, every call completes the group on its own.
    #[test]
    async fn test_barrier_single_party() {
        let barrier = Barrier::new(1);
        for _ in 0..3 {
            assert!(barrier.wait().await.is_leader());
        }
    }

    /// The barrier resets after each release and can be reused for more rounds.
    #[test]
    async fn test_barrier_reuse() {
        let local_set = LocalSet::new();
        local_set
            .run_until(async {
                let barrier = Rc::new(Barrier::new(2));
                for _ in 0..3 {
                    let c = barrier.clone();
                    let handle = spawn_local(async move { c.wait().await.is_leader() });
                    let main_is_leader = barrier.wait().await.is_leader();
                    let other_is_leader = handle.await.unwrap();
                    // Exactly one of the two parties leads each round.
                    assert_ne!(main_is_leader, other_is_leader);
                }
            })
            .await;
    }
}