use std::future::Future;
use std::mem::ManuallyDrop;
use std::pin::Pin;
use std::task::{Context, Poll};
use diatomic_waker::WakeSink;
use recycle_box::{coerce_box, RecycleBox};
use super::sender::{SendError, Sender};
use super::LineId;
use task_set::TaskSet;
mod task_set;
/// Fans a message out to a set of registered senders and awaits their futures.
///
/// Each sender is stored together with the `LineId` it was registered under,
/// which `remove` uses to find it again.
pub(super) struct Broadcaster<T: Clone + 'static, R: 'static> {
    /// Registered senders with their line identifiers. `add` appends and
    /// `remove` uses `swap_remove`, so the order is not stable across removals.
    senders: Vec<(LineId, Box<dyn Sender<T, R>>)>,
    /// State borrowed by an in-flight `BroadcastFuture`; its `futures_env`
    /// vector is kept index-aligned with `senders` by `add`/`remove`/`clear`.
    shared: Shared<R>,
}
impl<T: Clone + 'static> Broadcaster<T, ()> {
    /// Broadcasts an event to every registered sender.
    ///
    /// With no sender this is a no-op. With exactly one sender the event is
    /// sent directly, bypassing the concurrent `broadcast` machinery. With
    /// several senders all sender futures are driven concurrently.
    ///
    /// # Errors
    ///
    /// Returns `BroadcastError` if a sender returns an error.
    pub(super) async fn broadcast_event(&mut self, arg: T) -> Result<(), BroadcastError> {
        if self.senders.len() > 1 {
            return self.broadcast(arg).await;
        }

        if let Some((_, sender)) = self.senders.first_mut() {
            sender.send(arg).await.map_err(|_| BroadcastError {})
        } else {
            Ok(())
        }
    }
}
impl<T: Clone + 'static, R> Broadcaster<T, R> {
    /// Registers `sender` under the line identifier `id`.
    ///
    /// A fresh per-sender environment is appended alongside it, and the task
    /// set is resized to the new number of senders.
    pub(super) fn add(&mut self, sender: Box<dyn Sender<T, R>>, id: LineId) {
        self.senders.push((id, sender));
        self.shared.futures_env.push(FutureEnv {
            storage: None,
            output: None,
        });
        let sender_count = self.senders.len();
        self.shared.task_set.resize(sender_count);
    }

    /// Removes the first sender registered under `id`.
    ///
    /// Returns `true` if a sender was found and removed, `false` otherwise.
    /// The last sender is moved into the freed slot (`swap_remove`), and its
    /// environment is moved the same way to keep both vectors aligned.
    pub(super) fn remove(&mut self, id: LineId) -> bool {
        match self.senders.iter().position(|(line_id, _)| *line_id == id) {
            Some(idx) => {
                self.senders.swap_remove(idx);
                self.shared.futures_env.swap_remove(idx);
                self.shared.task_set.resize(self.senders.len());
                true
            }
            None => false,
        }
    }

    /// Removes all senders and their environments.
    pub(super) fn clear(&mut self) {
        self.senders.clear();
        self.shared.futures_env.clear();
        self.shared.task_set.resize(0);
    }

    /// Returns the number of registered senders.
    pub(super) fn len(&self) -> usize {
        self.senders.len()
    }

    /// Broadcasts a query to every registered sender and returns their replies.
    ///
    /// Replies are yielded in the same order as the senders. With no sender the
    /// iterator is empty. With exactly one sender the query is sent directly.
    ///
    /// # Errors
    ///
    /// Returns `BroadcastError` if a sender returns an error.
    pub(super) async fn broadcast_query(
        &mut self,
        arg: T,
    ) -> Result<impl Iterator<Item = R> + '_, BroadcastError> {
        if self.senders.len() > 1 {
            self.broadcast(arg).await?;
        } else if let Some((_, sender)) = self.senders.first_mut() {
            let reply = sender.send(arg).await.map_err(|_| BroadcastError {})?;
            self.shared.futures_env[0].output = Some(reply);
        }

        // Every environment holds an output at this point; taking them resets
        // the slots for the next broadcast.
        Ok(self
            .shared
            .futures_env
            .iter_mut()
            .map(|env| env.output.take().unwrap()))
    }

    /// Creates the sender futures for `arg` and wraps them in a `BroadcastFuture`.
    ///
    /// Every sender but the last receives a clone of `arg`; the last one
    /// receives `arg` itself, which saves one clone. Future allocations and the
    /// future vector left over from the previous broadcast are reused.
    fn broadcast(&mut self, arg: T) -> BroadcastFuture<'_, R> {
        let mut futures = recycle_vec(self.shared.storage.take().unwrap_or_default());

        let mut slots = self
            .senders
            .iter_mut()
            .zip(self.shared.futures_env.iter_mut());
        // Peel off the last sender so that it can take ownership of `arg`.
        let last_slot = slots.next_back();

        for ((_, sender), future_env) in slots {
            let cache = future_env
                .storage
                .take()
                .unwrap_or_else(|| RecycleBox::new(()));
            let future: RecycleBox<dyn Future<Output = Result<R, SendError>> + Send + '_> =
                coerce_box!(RecycleBox::recycle(cache, sender.send(arg.clone())));
            futures.push(RecycleBox::into_pin(future));
        }

        if let Some(((_, sender), future_env)) = last_slot {
            let cache = future_env
                .storage
                .take()
                .unwrap_or_else(|| RecycleBox::new(()));
            let future: RecycleBox<dyn Future<Output = Result<R, SendError>> + Send + '_> =
                coerce_box!(RecycleBox::recycle(cache, sender.send(arg)));
            futures.push(RecycleBox::into_pin(future));
        }

        BroadcastFuture::new(&mut self.shared, futures)
    }
}
impl<T: Clone + 'static, R> Default for Broadcaster<T, R> {
    /// Creates a broadcaster with no registered sender.
    ///
    /// The task set is constructed from a source of the broadcaster's own wake
    /// sink; the sink is where `BroadcastFuture::poll` registers its waker.
    fn default() -> Self {
        let wake_sink = WakeSink::new();
        let task_set = TaskSet::new(wake_sink.source());

        let shared = Shared {
            wake_sink,
            task_set,
            futures_env: Vec::new(),
            storage: None,
        };

        Self {
            senders: Vec::new(),
            shared,
        }
    }
}
/// Per-sender state kept across broadcasts.
struct FutureEnv<R> {
    /// Allocation of this sender's previous future, emptied and type-erased so
    /// that the next broadcast can reuse it instead of allocating.
    storage: Option<RecycleBox<()>>,
    /// Output of this sender's future once it has completed. While a broadcast
    /// is being polled, `Some` also marks the task as done (it is skipped).
    output: Option<R>,
}
/// Type-erased, `Send` sender future whose allocation can be recycled.
type RecycleBoxFuture<'a, R> = RecycleBox<dyn Future<Output = Result<R, SendError>> + Send + 'a>;
/// Broadcaster state that a `BroadcastFuture` mutably borrows while it runs.
struct Shared<R> {
    /// Sink on which `BroadcastFuture::poll` registers the outer task's waker.
    wake_sink: WakeSink,
    /// Provides one waker per sender task (`waker_of`) and the indices of the
    /// tasks that have been scheduled (`steal_scheduled`).
    task_set: TaskSet,
    /// One environment per sender, index-aligned with `Broadcaster::senders`.
    futures_env: Vec<FutureEnv<R>>,
    /// Empty future vector left by the previous broadcast, kept only for its
    /// capacity. The `'static` lifetime is a placeholder: the vector never
    /// holds a future between broadcasts.
    storage: Option<Vec<Pin<RecycleBoxFuture<'static, R>>>>,
}
/// Future that polls all sender futures of a broadcast concurrently.
///
/// It resolves once every sender future has produced an output, or as soon as
/// one of them fails.
pub(super) struct BroadcastFuture<'a, R> {
    /// Broadcaster state, borrowed for the whole life of the future.
    shared: &'a mut Shared<R>,
    /// Sender futures, index-aligned with `shared.futures_env`. Wrapped in
    /// `ManuallyDrop` so that the `Drop` impl can move the vector out and
    /// recycle it.
    futures: ManuallyDrop<Vec<Pin<RecycleBoxFuture<'a, R>>>>,
    /// Number of sender futures that have not produced an output yet.
    pending_futures_count: usize,
    /// Polling state machine.
    state: FutureState,
}
impl<'a, R> BroadcastFuture<'a, R> {
    /// Creates a future driving `futures`, exactly one per sender environment.
    ///
    /// Outputs left over from a previous broadcast are discarded first.
    /// `poll` treats `FutureEnv::output.is_some()` as "this task has
    /// completed", so a stale output would otherwise cause a task to be skipped.
    ///
    /// # Panics
    ///
    /// Panics if `futures.len()` differs from the number of environments in
    /// `shared`.
    fn new(shared: &'a mut Shared<R>, futures: Vec<Pin<RecycleBoxFuture<'a, R>>>) -> Self {
        let futures_count = futures.len();
        // `assert_eq!` reports both lengths on failure, unlike `assert!(a == b)`.
        assert_eq!(
            shared.futures_env.len(),
            futures_count,
            "expected exactly one sender future per future environment"
        );

        for future_env in shared.futures_env.iter_mut() {
            future_env.output = None;
        }

        BroadcastFuture {
            shared,
            futures: ManuallyDrop::new(futures),
            state: FutureState::Uninit,
            pending_futures_count: futures_count,
        }
    }
}
impl<'a, R> Drop for BroadcastFuture<'a, R> {
    /// Hands the memory used by this broadcast back to `shared` for reuse.
    ///
    /// Each future's allocation goes back to its sender's `FutureEnv::storage`,
    /// and the drained future vector goes to `shared.storage`.
    fn drop(&mut self) {
        // SAFETY: `self.futures` is only ever taken here, and it is not accessed
        // again after this line (we are in `drop`), so moving the vector out of
        // the `ManuallyDrop` cannot cause a double drop or a use-after-move.
        let mut futures = unsafe { ManuallyDrop::take(&mut self.futures) };
        // NOTE(review): `vacate_pinned` presumably drops the future in place and
        // returns its allocation as an empty `RecycleBox<()>` — confirm against
        // the recycle_box documentation.
        for (future, futures_env) in futures.drain(..).zip(self.shared.futures_env.iter_mut()) {
            futures_env.storage = Some(RecycleBox::vacate_pinned(future));
        }
        // `drain(..)` has emptied the vector, so no `'a`-bound future survives
        // the lifetime conversion to `'static` performed by `recycle_vec`.
        self.shared.storage = Some(recycle_vec(futures));
    }
}
impl<'a, R> Future for BroadcastFuture<'a, R> {
    type Output = Result<(), BroadcastError>;

    /// Drives the sender futures.
    ///
    /// The first call polls every sender future once, each with the waker of
    /// its own task from `shared.task_set`. Later calls only re-poll the
    /// futures whose task index is returned by `TaskSet::steal_scheduled`.
    /// Resolves to `Ok(())` once every future has stored its output in
    /// `FutureEnv::output`, or to `Err(BroadcastError)` as soon as one future
    /// returns an error.
    ///
    /// # Panics
    ///
    /// Panics if polled again after having resolved.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;
        assert_ne!(this.state, FutureState::Completed);
        // First poll: every sender future is polled once.
        if this.state == FutureState::Uninit {
            // NOTE(review): presumably discards task wake-ups recorded before
            // this first poll, since every task is polled below anyway —
            // confirm in `task_set`.
            this.shared.task_set.discard_scheduled();
            for task_idx in 0..this.futures.len() {
                let future_env = &mut this.shared.futures_env[task_idx];
                let future = &mut this.futures[task_idx];
                // Each sender future is polled with the waker of its own task
                // index rather than with the outer `cx`.
                let task_waker_ref = this.shared.task_set.waker_of(task_idx);
                let task_cx_ref = &mut Context::from_waker(&task_waker_ref);
                match future.as_mut().poll(task_cx_ref) {
                    Poll::Ready(Ok(output)) => {
                        future_env.output = Some(output);
                        this.pending_futures_count -= 1;
                    }
                    Poll::Ready(Err(_)) => {
                        this.state = FutureState::Completed;
                        return Poll::Ready(Err(BroadcastError {}));
                    }
                    Poll::Pending => {}
                }
            }
            if this.pending_futures_count == 0 {
                this.state = FutureState::Completed;
                return Poll::Ready(Ok(()));
            }
            this.state = FutureState::Pending;
        }
        // Re-poll scheduled sender futures until nothing more can be done.
        loop {
            // The outer waker only needs registering when no task is scheduled;
            // otherwise the scheduled tasks are processed right away below.
            if !this.shared.task_set.has_scheduled() {
                this.shared.wake_sink.register(cx.waker());
            }
            // `None`: no scheduled task can be processed now, so stay pending.
            // NOTE(review): the meaning of the `pending_futures_count` argument
            // is defined by `TaskSet::steal_scheduled` — confirm there.
            let scheduled_tasks = match this
                .shared
                .task_set
                .steal_scheduled(this.pending_futures_count)
            {
                Some(st) => st,
                None => return Poll::Pending,
            };
            for task_idx in scheduled_tasks {
                let future_env = &mut this.shared.futures_env[task_idx];
                // A future that already produced its output must not be
                // polled again.
                if future_env.output.is_some() {
                    continue;
                }
                let future = &mut this.futures[task_idx];
                let task_waker_ref = this.shared.task_set.waker_of(task_idx);
                let task_cx_ref = &mut Context::from_waker(&task_waker_ref);
                match future.as_mut().poll(task_cx_ref) {
                    Poll::Ready(Ok(output)) => {
                        future_env.output = Some(output);
                        this.pending_futures_count -= 1;
                    }
                    Poll::Ready(Err(_)) => {
                        this.state = FutureState::Completed;
                        return Poll::Ready(Err(BroadcastError {}));
                    }
                    Poll::Pending => {}
                }
            }
            if this.pending_futures_count == 0 {
                this.state = FutureState::Completed;
                return Poll::Ready(Ok(()));
            }
        }
    }
}
/// Error returned when a broadcast fails because a sender returned an error.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) struct BroadcastError {}

impl std::fmt::Display for BroadcastError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("broadcast failed: a sender returned an error")
    }
}

impl std::error::Error for BroadcastError {}
/// Polling state of a `BroadcastFuture`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum FutureState {
    /// Not polled yet: the first poll polls every sender future once.
    Uninit,
    /// Polled at least once; only scheduled sender futures are re-polled.
    Pending,
    /// Resolved; polling again panics.
    Completed,
}
/// Turns `v` into an empty `Vec<U>` that reuses the allocation of `v`.
///
/// Elements still in `v` are dropped. `T` and `U` must have the same layout
/// (checked in debug builds only). Reusing the buffer relies on the standard
/// library collecting a `vec::IntoIter` in place, which is an optimization
/// rather than a documented guarantee; a debug check on the resulting
/// capacity guards against that changing.
fn recycle_vec<T, U>(mut v: Vec<T>) -> Vec<U> {
    debug_assert_eq!(
        std::alloc::Layout::new::<T>(),
        std::alloc::Layout::new::<U>()
    );

    let initial_capacity = v.capacity();
    v.truncate(0);

    // The iterator is empty, so the closure can never run.
    let recycled = v
        .into_iter()
        .map(|_| -> U { unreachable!() })
        .collect::<Vec<U>>();
    debug_assert_eq!(recycled.capacity(), initial_capacity);

    recycled
}
#[cfg(all(test, not(asynchronix_loom)))]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    use futures_executor::block_on;

    use crate::channel::Receiver;
    use crate::time::Scheduler;
    use crate::time::{MonotonicTime, TearableAtomicTime};
    use crate::util::priority_queue::PriorityQueue;
    use crate::util::sync_cell::SyncCell;

    use super::super::*;
    use super::*;

    /// Test model wrapping a shared counter.
    struct Counter {
        inner: Arc<AtomicUsize>,
    }

    impl Counter {
        fn new(counter: Arc<AtomicUsize>) -> Self {
            Self { inner: counter }
        }

        /// Event handler: increments the counter by `by`.
        async fn inc(&mut self, by: usize) {
            self.inner.fetch_add(by, Ordering::Relaxed);
        }

        /// Query handler: increments the counter by `by` and returns its
        /// previous value.
        async fn fetch_inc(&mut self, by: usize) -> usize {
            self.inner.fetch_add(by, Ordering::Relaxed)
        }
    }

    impl Model for Counter {}

    /// Broadcasts one event to `N_RECV` mailboxes, each processed on its own
    /// thread, and checks that every receiver incremented the counter once.
    #[test]
    fn broadcast_event_smoke() {
        const N_RECV: usize = 4;

        let mut mailboxes = Vec::new();
        let mut broadcaster = Broadcaster::default();
        for id in 0..N_RECV {
            let mailbox = Receiver::new(10);
            let address = mailbox.sender();
            let sender = Box::new(EventSender::new(Counter::inc, address));
            broadcaster.add(sender, LineId(id as u64));
            mailboxes.push(mailbox);
        }

        let th_broadcast = thread::spawn(move || {
            block_on(broadcaster.broadcast_event(1)).unwrap();
        });

        let counter = Arc::new(AtomicUsize::new(0));
        let th_recv: Vec<_> = mailboxes
            .into_iter()
            .map(|mut mailbox| {
                thread::spawn({
                    let mut counter = Counter::new(counter.clone());
                    move || {
                        let dummy_address = Receiver::new(1).sender();
                        let dummy_priority_queue = Arc::new(Mutex::new(PriorityQueue::new()));
                        let dummy_time =
                            SyncCell::new(TearableAtomicTime::new(MonotonicTime::EPOCH)).reader();
                        let dummy_scheduler =
                            Scheduler::new(dummy_address, dummy_priority_queue, dummy_time);
                        block_on(mailbox.recv(&mut counter, &dummy_scheduler)).unwrap();
                    }
                })
            })
            .collect();

        th_broadcast.join().unwrap();
        for th in th_recv {
            th.join().unwrap();
        }

        assert_eq!(counter.load(Ordering::Relaxed), N_RECV);
    }

    /// Broadcasts one query to `N_RECV` mailboxes, each processed on its own
    /// thread, and checks both the collected replies and the final counter.
    #[test]
    fn broadcast_query_smoke() {
        const N_RECV: usize = 4;

        let mut mailboxes = Vec::new();
        let mut broadcaster = Broadcaster::default();
        for id in 0..N_RECV {
            let mailbox = Receiver::new(10);
            let address = mailbox.sender();
            let sender = Box::new(QuerySender::new(Counter::fetch_inc, address));
            broadcaster.add(sender, LineId(id as u64));
            mailboxes.push(mailbox);
        }

        let th_broadcast = thread::spawn(move || {
            let iter = block_on(broadcaster.broadcast_query(1)).unwrap();
            let sum: usize = iter.sum();
            // Each reply is the counter value before that receiver's increment
            // by 1, so the replies are 0, 1, ..., N_RECV - 1 in some order.
            assert_eq!(sum, N_RECV * (N_RECV - 1) / 2);
        });

        let counter = Arc::new(AtomicUsize::new(0));
        let th_recv: Vec<_> = mailboxes
            .into_iter()
            .map(|mut mailbox| {
                thread::spawn({
                    let mut counter = Counter::new(counter.clone());
                    move || {
                        let dummy_address = Receiver::new(1).sender();
                        let dummy_priority_queue = Arc::new(Mutex::new(PriorityQueue::new()));
                        let dummy_time =
                            SyncCell::new(TearableAtomicTime::new(MonotonicTime::EPOCH)).reader();
                        let dummy_scheduler =
                            Scheduler::new(dummy_address, dummy_priority_queue, dummy_time);
                        block_on(mailbox.recv(&mut counter, &dummy_scheduler)).unwrap();
                        // NOTE(review): presumably keeps the mailbox alive until
                        // the broadcaster has collected the reply — confirm.
                        thread::sleep(std::time::Duration::from_millis(100));
                    }
                })
            })
            .collect();

        th_broadcast.join().unwrap();
        for th in th_recv {
            th.join().unwrap();
        }

        assert_eq!(counter.load(Ordering::Relaxed), N_RECV);
    }
}
#[cfg(all(test, asynchronix_loom))]
mod tests {
    use futures_channel::mpsc;
    use futures_util::StreamExt;
    use loom::model::Builder;
    use loom::sync::atomic::{AtomicBool, Ordering};
    use loom::thread;
    use waker_fn::waker_fn;

    use super::super::sender::RecycledFuture;
    use super::*;

    /// Test sender whose future resolves with the first `Some` value received
    /// on an unbounded channel; `None` items are filtered out, which lets a
    /// test wake the future without completing it.
    struct TestEvent<R> {
        receiver: mpsc::UnboundedReceiver<Option<R>>,
        fut_storage: Option<RecycleBox<()>>,
    }

    impl<R: Send> Sender<(), R> for TestEvent<R> {
        fn send(&mut self, _arg: ()) -> RecycledFuture<'_, Result<R, SendError>> {
            let fut_storage = &mut self.fut_storage;
            let receiver = &mut self.receiver;
            RecycledFuture::new(fut_storage, async {
                let mut stream = Box::pin(receiver.filter_map(|item| async { item }));
                // Panics if the channel closes before a `Some` value is received.
                Ok(stream.next().await.unwrap())
            })
        }
    }

    /// Handle used by test threads to drive the matching `TestEvent` future.
    #[derive(Clone)]
    struct TestEventWaker<R> {
        sender: mpsc::UnboundedSender<Option<R>>,
    }

    impl<R> TestEventWaker<R> {
        /// Wakes the `TestEvent` future without completing it.
        fn wake_spurious(&self) {
            let _ = self.sender.unbounded_send(None);
        }
        /// Completes the `TestEvent` future with `value`.
        fn wake_final(&self, value: R) {
            let _ = self.sender.unbounded_send(Some(value));
        }
    }

    /// Creates a connected `TestEvent` / `TestEventWaker` pair.
    fn test_event<R>() -> (TestEvent<R>, TestEventWaker<R>) {
        let (sender, receiver) = mpsc::unbounded();
        (
            TestEvent {
                receiver,
                fut_storage: None,
            },
            TestEventWaker { sender },
        )
    }

    /// Three senders are completed from three threads while the broadcast
    /// future is polled on the model's main thread.
    #[test]
    fn loom_broadcast_basic() {
        const DEFAULT_PREEMPTION_BOUND: usize = 3;
        let mut builder = Builder::new();
        if builder.preemption_bound.is_none() {
            builder.preemption_bound = Some(DEFAULT_PREEMPTION_BOUND);
        }
        builder.check(move || {
            let (test_event1, waker1) = test_event::<usize>();
            let (test_event2, waker2) = test_event::<usize>();
            let (test_event3, waker3) = test_event::<usize>();
            let mut broadcaster = Broadcaster::default();
            broadcaster.add(Box::new(test_event1), LineId(1));
            broadcaster.add(Box::new(test_event2), LineId(2));
            broadcaster.add(Box::new(test_event3), LineId(3));
            let mut fut = Box::pin(broadcaster.broadcast_query(()));
            // Set to `true` whenever the outer waker is invoked.
            let is_scheduled = loom::sync::Arc::new(AtomicBool::new(false));
            let is_scheduled_waker = is_scheduled.clone();
            let waker = waker_fn(move || {
                is_scheduled_waker.swap(true, Ordering::Release);
            });
            let mut cx = Context::from_waker(&waker);
            let th1 = thread::spawn(move || waker1.wake_final(3));
            let th2 = thread::spawn(move || waker2.wake_final(7));
            let th3 = thread::spawn(move || waker3.wake_final(42));
            // Poll again after each wake-up until the future completes or no
            // wake-up is pending; the outer waker may fire at most once here.
            let mut schedule_count = 0;
            loop {
                match fut.as_mut().poll(&mut cx) {
                    Poll::Ready(Ok(mut res)) => {
                        assert_eq!(res.next(), Some(3));
                        assert_eq!(res.next(), Some(7));
                        assert_eq!(res.next(), Some(42));
                        assert_eq!(res.next(), None);
                        return;
                    }
                    Poll::Ready(Err(_)) => panic!("sender error"),
                    Poll::Pending => {}
                }
                if !is_scheduled.swap(false, Ordering::Acquire) {
                    break;
                }
                schedule_count += 1;
                assert!(schedule_count <= 1);
            }
            th1.join().unwrap();
            th2.join().unwrap();
            th3.join().unwrap();
            // The loop exited while the future was still pending, so completing
            // the remaining senders must have invoked the outer waker.
            assert!(is_scheduled.load(Ordering::Acquire));
            match fut.as_mut().poll(&mut cx) {
                Poll::Ready(Ok(mut res)) => {
                    assert_eq!(res.next(), Some(3));
                    assert_eq!(res.next(), Some(7));
                    assert_eq!(res.next(), Some(42));
                    assert_eq!(res.next(), None);
                }
                Poll::Ready(Err(_)) => panic!("sender error"),
                Poll::Pending => panic!("the future has not completed"),
            };
        });
    }

    /// Like `loom_broadcast_basic` with two senders, plus an extra thread that
    /// wakes the first sender's future without completing it.
    #[test]
    fn loom_broadcast_spurious() {
        const DEFAULT_PREEMPTION_BOUND: usize = 3;
        let mut builder = Builder::new();
        if builder.preemption_bound.is_none() {
            builder.preemption_bound = Some(DEFAULT_PREEMPTION_BOUND);
        }
        builder.check(move || {
            let (test_event1, waker1) = test_event::<usize>();
            let (test_event2, waker2) = test_event::<usize>();
            let mut broadcaster = Broadcaster::default();
            broadcaster.add(Box::new(test_event1), LineId(1));
            broadcaster.add(Box::new(test_event2), LineId(2));
            let mut fut = Box::pin(broadcaster.broadcast_query(()));
            // Set to `true` whenever the outer waker is invoked.
            let is_scheduled = loom::sync::Arc::new(AtomicBool::new(false));
            let is_scheduled_waker = is_scheduled.clone();
            let waker = waker_fn(move || {
                is_scheduled_waker.swap(true, Ordering::Release);
            });
            let mut cx = Context::from_waker(&waker);
            let spurious_waker = waker1.clone();
            let th1 = thread::spawn(move || waker1.wake_final(3));
            let th2 = thread::spawn(move || waker2.wake_final(7));
            let th_spurious = thread::spawn(move || spurious_waker.wake_spurious());
            // Same polling loop as in `loom_broadcast_basic`, but the spurious
            // wake-up allows the outer waker to fire up to twice.
            let mut schedule_count = 0;
            loop {
                match fut.as_mut().poll(&mut cx) {
                    Poll::Ready(Ok(mut res)) => {
                        assert_eq!(res.next(), Some(3));
                        assert_eq!(res.next(), Some(7));
                        assert_eq!(res.next(), None);
                        return;
                    }
                    Poll::Ready(Err(_)) => panic!("sender error"),
                    Poll::Pending => {}
                }
                if !is_scheduled.swap(false, Ordering::Acquire) {
                    break;
                }
                schedule_count += 1;
                assert!(schedule_count <= 2);
            }
            th1.join().unwrap();
            th2.join().unwrap();
            th_spurious.join().unwrap();
            // The loop exited while the future was still pending, so completing
            // the remaining senders must have invoked the outer waker.
            assert!(is_scheduled.load(Ordering::Acquire));
            match fut.as_mut().poll(&mut cx) {
                Poll::Ready(Ok(mut res)) => {
                    assert_eq!(res.next(), Some(3));
                    assert_eq!(res.next(), Some(7));
                    assert_eq!(res.next(), None);
                }
                Poll::Ready(Err(_)) => panic!("sender error"),
                Poll::Pending => panic!("the future has not completed"),
            };
        });
    }
}