use std::future::Future;
use std::mem::ManuallyDrop;
use std::pin::Pin;
use std::task::{Context, Poll};
use diatomic_waker::WakeSink;
use super::sender::{RecycledFuture, Sender};
use crate::channel::SendError;
use crate::util::task_set::TaskSet;
/// Common state of `EventBroadcaster` and `QueryBroadcaster`: the list of
/// per-receiver senders plus the shared polling/output machinery.
pub(super) struct BroadcasterInner<T: Clone, R> {
    /// One boxed sender per registered receiver.
    senders: Vec<Box<dyn Sender<T, R>>>,
    /// State shared with the aggregate `BroadcastFuture`.
    shared: Shared<R>,
}
impl<T: Clone, R> BroadcasterInner<T, R> {
    /// Registers a new sender and grows the shared state accordingly.
    ///
    /// # Panics
    ///
    /// Panics if the sender count would reach `u32::MAX - 2`
    /// (NOTE(review): presumably the task set reserves a couple of sentinel
    /// indices — confirm against `TaskSet`).
    pub(super) fn add(&mut self, sender: Box<dyn Sender<T, R>>) {
        assert!(self.senders.len() < (u32::MAX as usize - 2));

        self.senders.push(sender);
        // Keep exactly one output slot per sender.
        self.shared.outputs.push(None);

        // Opportunistically grow the recycled storage; an allocation failure
        // is deliberately ignored since `futures` can still grow the vector.
        if let Some(storage) = self.shared.storage.as_mut() {
            let _ = storage.try_reserve(self.senders.len());
        };
    }

    /// Returns the number of registered senders.
    pub(super) fn len(&self) -> usize {
        self.senders.len()
    }

    /// Builds one send future per sender that accepts the argument.
    ///
    /// The argument is passed by reference to every sender except the last,
    /// which receives it by value and thus saves one clone. Senders may
    /// decline the message (returning `None`), in which case no future is
    /// produced for them. The storage vector recycled from the previous
    /// broadcast is reused to hold the futures without a fresh allocation.
    #[allow(clippy::type_complexity)]
    fn futures(
        &mut self,
        arg: T,
    ) -> (
        &'_ mut Shared<R>,
        Vec<RecycledFuture<'_, Result<R, SendError>>>,
    ) {
        // Re-type the recycled storage (same layout, different lifetime and
        // item parameters) so its allocation is reused.
        let mut futures = recycle_vec(self.shared.storage.take().unwrap_or_default());

        let mut iter = self.senders.iter_mut();
        while let Some(sender) = iter.next() {
            // Last sender: move the argument instead of borrowing it.
            if iter.len() == 0 {
                if let Some(fut) = sender.send_owned(arg) {
                    futures.push(fut);
                }
                break;
            }
            if let Some(fut) = sender.send(&arg) {
                futures.push(fut);
            }
        }

        (&mut self.shared, futures)
    }
}
impl<T: Clone, R> Default for BroadcasterInner<T, R> {
    /// Creates an empty broadcaster with a fresh wake sink/source pair.
    fn default() -> Self {
        let sink = WakeSink::new();
        let source = sink.source();

        let shared = Shared {
            wake_sink: sink,
            task_set: TaskSet::new(source),
            outputs: Vec::new(),
            storage: None,
        };

        Self {
            senders: Vec::new(),
            shared,
        }
    }
}
impl<T: Clone, R> Clone for BroadcasterInner<T, R> {
    /// Duplicates every sender and the shared state.
    ///
    /// `Shared::clone` yields a logically fresh state (new wake sink, empty
    /// outputs), so the clone starts with no broadcast in flight.
    fn clone(&self) -> Self {
        let senders = self.senders.clone();
        let shared = self.shared.clone();

        Self { senders, shared }
    }
}
/// A broadcaster of events (fire-and-forget messages with no reply).
#[derive(Clone)]
pub(super) struct EventBroadcaster<T: Clone> {
    /// The generic broadcaster specialized with the unit reply type.
    inner: BroadcasterInner<T, ()>,
}
impl<T: Clone> EventBroadcaster<T> {
    /// Registers a new sender.
    pub(super) fn add(&mut self, sender: Box<dyn Sender<T, ()>>) {
        self.inner.add(sender)
    }

    /// Returns the number of registered senders.
    pub(super) fn len(&self) -> usize {
        self.inner.len()
    }

    /// Broadcasts `arg` to all senders, completing once every send completes.
    pub(super) async fn broadcast(&mut self, arg: T) -> Result<(), SendError> {
        match self.inner.senders.as_mut_slice() {
            // No sender: nothing to do.
            [] => Ok(()),
            // One sender: fast path that moves the argument rather than
            // cloning it and skips the aggregate future entirely.
            [sender] => match sender.send_owned(arg) {
                None => Ok(()),
                Some(fut) => fut.await,
            },
            _ => {
                let (shared, mut futures) = self.inner.futures(arg);
                match futures.as_mut_slice() {
                    // All senders declined the message.
                    [] => Ok(()),
                    // A single pending send needs no coordination machinery.
                    [fut] => fut.await,
                    // General case: drive all futures concurrently.
                    _ => BroadcastFuture::new(shared, futures).await,
                }
            }
        }
    }
}
impl<T: Clone> Default for EventBroadcaster<T> {
    /// Creates an event broadcaster with no registered senders.
    fn default() -> Self {
        let inner = BroadcasterInner::default();

        Self { inner }
    }
}
/// A broadcaster of queries (messages whose receivers each produce a reply).
pub(super) struct QueryBroadcaster<T: Clone, R> {
    /// The generic broadcaster with reply type `R`.
    inner: BroadcasterInner<T, R>,
}
impl<T: Clone, R> QueryBroadcaster<T, R> {
    /// Registers a new sender.
    pub(super) fn add(&mut self, sender: Box<dyn Sender<T, R>>) {
        self.inner.add(sender)
    }

    /// Returns the number of registered senders.
    pub(super) fn len(&self) -> usize {
        self.inner.len()
    }

    /// Broadcasts `arg` to all senders and returns an iterator over the
    /// replies, in the order the responding senders were registered (senders
    /// that declined the message contribute no reply).
    pub(super) async fn broadcast(
        &mut self,
        arg: T,
    ) -> Result<impl Iterator<Item = R> + '_, SendError> {
        let output_count = match self.inner.senders.as_mut_slice() {
            [] => 0,
            // Single sender: fast path storing the reply directly in slot 0.
            [sender] => {
                if let Some(fut) = sender.send_owned(arg) {
                    let output = fut.await?;
                    self.inner.shared.outputs[0] = Some(output);
                    1
                } else {
                    0
                }
            }
            _ => {
                let (shared, mut futures) = self.inner.futures(arg);
                let output_count = futures.len();
                match futures.as_mut_slice() {
                    // All senders declined the message.
                    [] => {}
                    // A single pending reply: store it at index 0, matching
                    // the future's position in the (length-1) future list.
                    [fut] => {
                        let output = fut.await?;
                        shared.outputs[0] = Some(output);
                    }
                    // General case: `BroadcastFuture` fills `shared.outputs`
                    // slot `i` with the output of future `i`.
                    _ => {
                        BroadcastFuture::new(shared, futures).await?;
                    }
                }
                output_count
            }
        };

        // Drain the populated output slots; each of the first `output_count`
        // slots is expected to hold `Some` once the broadcast succeeded.
        let outputs = self
            .inner
            .shared
            .outputs
            .iter_mut()
            .take(output_count)
            .map(|t| t.take().unwrap());

        Ok(outputs)
    }
}
impl<T: Clone, R> Default for QueryBroadcaster<T, R> {
    /// Creates a query broadcaster with no registered senders.
    fn default() -> Self {
        let inner = BroadcasterInner::default();

        Self { inner }
    }
}
impl<T: Clone, R> Clone for QueryBroadcaster<T, R> {
    /// Clones the inner broadcaster (senders plus a fresh shared state).
    fn clone(&self) -> Self {
        let inner = self.inner.clone();

        Self { inner }
    }
}
/// State shared between a broadcaster and its aggregate `BroadcastFuture`.
struct Shared<R> {
    /// Registration point for the parent (aggregate) waker.
    wake_sink: WakeSink,
    /// One sub-task waker per in-flight sender future.
    task_set: TaskSet,
    /// Reply slots, indexed like the future list built by
    /// `BroadcasterInner::futures`.
    outputs: Vec<Option<R>>,
    /// Allocation recycled for the future vector of the next broadcast. The
    /// type parameters (`'static`, `Pin`, item `R`) are placeholders only:
    /// the vector is always empty and is re-typed via `recycle_vec`, so just
    /// its allocation — never its contents — is reused.
    storage: Option<Vec<Pin<RecycledFuture<'static, R>>>>,
}
impl<R> Clone for Shared<R> {
    /// Builds a functionally fresh copy: a brand-new wake sink and task set,
    /// an all-`None` output vector of the same length, and no recycled
    /// storage carried over.
    fn clone(&self) -> Self {
        let wake_sink = WakeSink::new();
        let task_set = TaskSet::new(wake_sink.source());

        Self {
            wake_sink,
            task_set,
            outputs: (0..self.outputs.len()).map(|_| None).collect(),
            storage: None,
        }
    }
}
/// A future aggregating the completion of a collection of sender futures.
pub(super) struct BroadcastFuture<'a, R> {
    /// Shared wake/output state; output slots are filled as futures complete.
    shared: &'a mut Shared<R>,
    /// The sender futures, wrapped in `ManuallyDrop` so that `Drop` can move
    /// the vector out and recycle its allocation into `shared.storage`.
    futures: ManuallyDrop<Vec<RecycledFuture<'a, Result<R, SendError>>>>,
    /// Number of futures that have not yet produced an output.
    pending_futures_count: usize,
    /// Current polling state (uninit / pending / completed).
    state: FutureState,
}
impl<'a, R> BroadcastFuture<'a, R> {
    /// Creates a future that drives all `futures` to completion, storing the
    /// output of future `i` into `shared.outputs[i]`.
    fn new(
        shared: &'a mut Shared<R>,
        futures: Vec<RecycledFuture<'a, Result<R, SendError>>>,
    ) -> Self {
        let pending_futures_count = futures.len();

        // One sub-task waker per pending future.
        shared.task_set.resize(pending_futures_count);

        // Clear any stale outputs left over from a previous broadcast.
        shared
            .outputs
            .iter_mut()
            .take(pending_futures_count)
            .for_each(|slot| *slot = None);

        Self {
            shared,
            futures: ManuallyDrop::new(futures),
            pending_futures_count,
            state: FutureState::Uninit,
        }
    }
}
impl<R> Drop for BroadcastFuture<'_, R> {
    fn drop(&mut self) {
        // Recycle the future vector's allocation for the next broadcast.
        //
        // SAFETY: `self.futures` is taken exactly once and never accessed
        // again: this is the drop handler and the field is not read below.
        let futures = unsafe { ManuallyDrop::take(&mut self.futures) };
        self.shared.storage = Some(recycle_vec(futures));
    }
}
impl<R> Future for BroadcastFuture<'_, R> {
    type Output = Result<(), SendError>;

    /// Polls all pending sender futures, each with its own sub-task waker,
    /// and resolves once every future has completed (or any of them failed).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;

        assert_ne!(
            this.state,
            FutureState::Completed,
            "broadcast future polled after completion"
        );

        // First poll: prime every future once so each registers its own
        // sub-task waker.
        if this.state == FutureState::Uninit {
            // Drop wake-ups scheduled before this broadcast started.
            this.shared.task_set.discard_scheduled();

            for task_idx in 0..this.futures.len() {
                let output = &mut this.shared.outputs[task_idx];
                let future = std::pin::Pin::new(&mut this.futures[task_idx]);
                let task_waker_ref = this.shared.task_set.waker_of(task_idx);
                let task_cx_ref = &mut Context::from_waker(&task_waker_ref);

                match future.poll(task_cx_ref) {
                    Poll::Ready(Ok(o)) => {
                        *output = Some(o);
                        this.pending_futures_count -= 1;
                    }
                    // Any sender error fails the whole broadcast.
                    Poll::Ready(Err(SendError)) => {
                        this.state = FutureState::Completed;
                        return Poll::Ready(Err(SendError));
                    }
                    Poll::Pending => {}
                }
            }

            if this.pending_futures_count == 0 {
                this.state = FutureState::Completed;
                return Poll::Ready(Ok(()));
            }

            this.state = FutureState::Pending;
        }

        // Steady state: only re-poll futures whose sub-task wakers fired.
        loop {
            // Register the parent waker before checking for scheduled tasks
            // so that a wake-up landing in between is not lost.
            if !this.shared.task_set.has_scheduled() {
                this.shared.wake_sink.register(cx.waker());
            }

            // NOTE(review): `take_scheduled(1)` presumably returns `None`
            // when fewer than one task is scheduled — confirm with `TaskSet`.
            let scheduled_tasks = match this.shared.task_set.take_scheduled(1) {
                Some(st) => st,
                None => return Poll::Pending,
            };

            for task_idx in scheduled_tasks {
                let output = &mut this.shared.outputs[task_idx];

                // Ignore spurious wake-ups for already-completed futures.
                if output.is_some() {
                    continue;
                }

                let future = std::pin::Pin::new(&mut this.futures[task_idx]);
                let task_waker_ref = this.shared.task_set.waker_of(task_idx);
                let task_cx_ref = &mut Context::from_waker(&task_waker_ref);

                match future.poll(task_cx_ref) {
                    Poll::Ready(Ok(o)) => {
                        *output = Some(o);
                        this.pending_futures_count -= 1;
                    }
                    Poll::Ready(Err(SendError)) => {
                        this.state = FutureState::Completed;
                        return Poll::Ready(Err(SendError));
                    }
                    Poll::Pending => {}
                }
            }

            if this.pending_futures_count == 0 {
                this.state = FutureState::Completed;
                return Poll::Ready(Ok(()));
            }
        }
    }
}
/// Polling state of a [`BroadcastFuture`].
///
/// Deriving `Eq` alongside `PartialEq` (clippy: `derive_partial_eq_without_eq`)
/// and `Copy`/`Clone` for this trivially copyable state marker.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum FutureState {
    /// The future has never been polled.
    Uninit,
    /// The future was polled but some sender futures are still pending.
    Pending,
    /// The future has resolved and must not be polled again.
    Completed,
}
/// Clears a vector and re-types it, reusing its allocation.
///
/// This appears to rely on the standard library's in-place `collect`
/// specialization: because the source vector is emptied first, the mapping
/// closure is never invoked and the resulting vector takes over the original
/// allocation. The `debug_assert`s check the two prerequisites: identical
/// memory layouts and preserved capacity.
fn recycle_vec<T, U>(mut v: Vec<T>) -> Vec<U> {
    debug_assert_eq!(
        std::alloc::Layout::new::<T>(),
        std::alloc::Layout::new::<U>()
    );

    let cap = v.capacity();
    v.clear();
    // The closure can never run: the vector was just emptied.
    let v_out: Vec<U> = v.into_iter().map(|_| unreachable!()).collect();
    debug_assert_eq!(v_out.capacity(), cap);

    v_out
}
#[cfg(all(test, not(nexosim_loom)))]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    use futures_executor::block_on;
    use serde::{Deserialize, Serialize};

    use crate::channel::Receiver;

    use super::super::sender::{
        FilterMapInputSender, FilterMapReplierSender, InputSender, ReplierSender,
    };
    use super::*;
    use crate::model::{Context, Model};

    /// Test model that accumulates received values into a shared counter.
    struct SumModel {
        // Shared counter observed by the test after all receivers are done.
        inner: Arc<AtomicUsize>,
    }
    impl SumModel {
        fn new(counter: Arc<AtomicUsize>) -> Self {
            Self { inner: counter }
        }
        /// Input port: adds `by` to the shared counter.
        async fn increment(&mut self, by: usize) {
            self.inner.fetch_add(by, Ordering::Relaxed);
        }
    }
    impl Model for SumModel {
        type Env = ();
    }
    impl Serialize for SumModel {
        /// Serializes the current counter value.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            serializer.serialize_u64(self.inner.load(Ordering::Relaxed) as u64)
        }
    }
    impl<'de> Deserialize<'de> for SumModel {
        /// Deserializes a counter value into a fresh model.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let counter = usize::deserialize(deserializer)?;

            Ok(SumModel::new(Arc::new(AtomicUsize::new(counter))))
        }
    }

    /// Test model that replies with twice the queried value.
    #[derive(Serialize, Deserialize)]
    struct DoubleModel {}
    impl DoubleModel {
        fn new() -> Self {
            Self {}
        }
        /// Replier port: returns `2 * value`.
        async fn double(&mut self, value: usize) -> usize {
            2 * value
        }
    }
    impl Model for DoubleModel {
        type Env = ();
    }

    /// Broadcasts one event to `N_RECV` receivers and checks that each of
    /// them processed it exactly once.
    #[test]
    fn broadcast_event_smoke() {
        const N_RECV: usize = 4;
        const MESSAGE: usize = 42;

        let mut mailboxes = Vec::new();
        let mut broadcaster = EventBroadcaster::default();
        for _ in 0..N_RECV {
            let mailbox = Receiver::new(10);
            let address = mailbox.sender();
            let sender = Box::new(InputSender::new(SumModel::increment, address));

            broadcaster.add(sender);
            mailboxes.push(mailbox);
        }

        // Broadcast from a dedicated thread while receivers run concurrently.
        let th_broadcast = thread::spawn(move || {
            block_on(broadcaster.broadcast(MESSAGE)).unwrap();
        });

        let sum = Arc::new(AtomicUsize::new(0));
        let th_recv: Vec<_> = mailboxes
            .into_iter()
            .map(|mut mailbox| {
                thread::spawn({
                    let mut sum_model = SumModel::new(sum.clone());
                    move || {
                        let dummy_cx = Context::new_dummy();
                        block_on(mailbox.recv(&mut sum_model, &dummy_cx, &mut ())).unwrap();
                    }
                })
            })
            .collect();

        th_broadcast.join().unwrap();
        for th in th_recv {
            th.join().unwrap();
        }

        // Each of the N_RECV receivers added MESSAGE once.
        assert_eq!(sum.load(Ordering::Relaxed), N_RECV * MESSAGE);
    }

    /// Exercises filtered broadcast: each receiver only accepts its own id or
    /// the `BROADCAST_ALL` sentinel.
    #[test]
    fn broadcast_event_filter_map() {
        const N_RECV: usize = 4;
        const BROADCAST_ALL: usize = 42; // special messages

        let mut mailboxes = Vec::new();
        let mut broadcaster = EventBroadcaster::default();
        for id in 0..N_RECV {
            let mailbox = Receiver::new(10);
            let address = mailbox.sender();
            // Only accept the receiver's own id or the broadcast-all value.
            let id_filter_sender = Box::new(FilterMapInputSender::new(
                move |x: &usize| (*x == id || *x == BROADCAST_ALL).then_some(*x),
                SumModel::increment,
                address,
            ));

            broadcaster.add(id_filter_sender);
            mailboxes.push(mailbox);
        }

        // Each receiver should see its id twice and BROADCAST_ALL once.
        let th_broadcast = thread::spawn(move || {
            block_on(async {
                for id in 0..N_RECV {
                    broadcaster.broadcast(id).await.unwrap();
                }
                broadcaster.broadcast(BROADCAST_ALL).await.unwrap();
                for id in 0..N_RECV {
                    broadcaster.broadcast(id).await.unwrap();
                }
            })
        });

        let sum = Arc::new(AtomicUsize::new(0));
        let th_recv: Vec<_> = mailboxes
            .into_iter()
            .map(|mut mailbox| {
                thread::spawn({
                    let mut sum_model = SumModel::new(sum.clone());
                    move || {
                        let dummy_cx = Context::new_dummy();
                        // Three deliveries per receiver: id, BROADCAST_ALL, id.
                        block_on(async {
                            mailbox
                                .recv(&mut sum_model, &dummy_cx, &mut ())
                                .await
                                .unwrap();
                            mailbox
                                .recv(&mut sum_model, &dummy_cx, &mut ())
                                .await
                                .unwrap();
                            mailbox
                                .recv(&mut sum_model, &dummy_cx, &mut ())
                                .await
                                .unwrap();
                        });
                    }
                })
            })
            .collect();

        th_broadcast.join().unwrap();
        for th in th_recv {
            th.join().unwrap();
        }

        // Per receiver: 2*id + BROADCAST_ALL; summed over ids 0..N_RECV this
        // gives N_RECV * ((N_RECV - 1) + BROADCAST_ALL).
        assert_eq!(
            sum.load(Ordering::Relaxed),
            N_RECV * ((N_RECV - 1) + BROADCAST_ALL)
        );
    }

    /// Broadcasts one query to `N_RECV` receivers and checks the replies.
    #[test]
    fn broadcast_query_smoke() {
        const N_RECV: usize = 4;
        const MESSAGE: usize = 42;

        let mut mailboxes = Vec::new();
        let mut broadcaster = QueryBroadcaster::default();
        for _ in 0..N_RECV {
            let mailbox = Receiver::new(10);
            let address = mailbox.sender();
            let sender = Box::new(ReplierSender::new(DoubleModel::double, address));

            broadcaster.add(sender);
            mailboxes.push(mailbox);
        }

        let th_broadcast = thread::spawn(move || {
            let iter = block_on(broadcaster.broadcast(MESSAGE)).unwrap();

            iter.sum::<usize>()
        });

        let th_recv: Vec<_> = mailboxes
            .into_iter()
            .map(|mut mailbox| {
                thread::spawn({
                    let mut double_model = DoubleModel::new();
                    move || {
                        let dummy_cx = Context::new_dummy();
                        block_on(mailbox.recv(&mut double_model, &dummy_cx, &mut ())).unwrap();
                        thread::sleep(std::time::Duration::from_millis(100));
                    }
                })
            })
            .collect();

        let sum = th_broadcast.join().unwrap();
        for th in th_recv {
            th.join().unwrap();
        }

        // Every receiver doubled MESSAGE once.
        assert_eq!(sum, N_RECV * MESSAGE * 2);
    }

    /// Exercises filtered queries with a reply mapping (`3 * reply`).
    #[test]
    fn broadcast_query_filter_map() {
        const N_RECV: usize = 4;
        const BROADCAST_ALL: usize = 42; // special messages

        let mut mailboxes = Vec::new();
        let mut broadcaster = QueryBroadcaster::default();
        for id in 0..N_RECV {
            let mailbox = Receiver::new(10);
            let address = mailbox.sender();
            // Only accept own id or BROADCAST_ALL; replies are tripled.
            let sender = Box::new(FilterMapReplierSender::new(
                move |x: &usize| (*x == id || *x == BROADCAST_ALL).then_some(*x),
                |x| 3 * x,
                DoubleModel::double,
                address,
            ));

            broadcaster.add(sender);
            mailboxes.push(mailbox);
        }

        let th_broadcast = thread::spawn(move || {
            block_on(async {
                let mut sum = 0;
                for id in 0..N_RECV {
                    sum += broadcaster.broadcast(id).await.unwrap().sum::<usize>();
                }
                sum += broadcaster
                    .broadcast(BROADCAST_ALL)
                    .await
                    .unwrap()
                    .sum::<usize>();
                for id in 0..N_RECV {
                    sum += broadcaster.broadcast(id).await.unwrap().sum::<usize>();
                }

                sum
            })
        });

        let th_recv: Vec<_> = mailboxes
            .into_iter()
            .map(|mut mailbox| {
                thread::spawn({
                    let mut double_model = DoubleModel::new();
                    move || {
                        let dummy_cx = Context::new_dummy();
                        // Three queries delivered per receiver.
                        block_on(async {
                            mailbox
                                .recv(&mut double_model, &dummy_cx, &mut ())
                                .await
                                .unwrap();
                            mailbox
                                .recv(&mut double_model, &dummy_cx, &mut ())
                                .await
                                .unwrap();
                            mailbox
                                .recv(&mut double_model, &dummy_cx, &mut ())
                                .await
                                .unwrap();
                        });
                        thread::sleep(std::time::Duration::from_millis(100));
                    }
                })
            })
            .collect();

        let sum = th_broadcast.join().unwrap();
        for th in th_recv {
            th.join().unwrap();
        }

        // Per receiver: (2*id + 2*id + 2*BROADCAST_ALL) * 3, summed over ids.
        assert_eq!(
            sum,
            N_RECV * ((N_RECV - 1) + BROADCAST_ALL) * 2 * 3,
        );
    }
}
#[cfg(all(test, nexosim_loom))]
mod tests {
    use futures_channel::mpsc;
    use futures_util::StreamExt;
    use loom::model::Builder;
    use loom::sync::atomic::{AtomicBool, Ordering};
    use loom::thread;
    use recycle_box::RecycleBox;
    use waker_fn::waker_fn;

    use super::super::sender::RecycledFuture;
    use super::*;

    /// Test sender whose send future resolves when the paired
    /// `TestEventWaker` pushes a `Some` value through the channel.
    struct TestEvent<R> {
        // Channel end delivering `Some(value)` (completion) or `None`
        // (spurious wake-up).
        receiver: mpsc::UnboundedReceiver<Option<R>>,
        fut_storage: Option<RecycleBox<()>>,
    }
    impl<R: Send> Sender<(), R> for TestEvent<R> {
        fn send(&mut self, _arg: &()) -> Option<RecycledFuture<'_, Result<R, SendError>>> {
            let fut_storage = &mut self.fut_storage;
            let receiver = &mut self.receiver;

            Some(RecycledFuture::new(fut_storage, async {
                // `None` items model spurious wake-ups and are filtered out.
                let mut stream = Box::pin(receiver.filter_map(|item| async { item }));

                Ok(stream.next().await.unwrap())
            }))
        }
    }
    impl<R> Clone for TestEvent<R> {
        fn clone(&self) -> Self {
            // The loom tests never clone their senders.
            unreachable!()
        }
    }

    /// Handle used to complete (or spuriously wake) a `TestEvent` future.
    #[derive(Clone)]
    struct TestEventWaker<R> {
        sender: mpsc::UnboundedSender<Option<R>>,
    }
    impl<R> TestEventWaker<R> {
        /// Wakes the future without completing it.
        fn wake_spurious(&self) {
            let _ = self.sender.unbounded_send(None);
        }
        /// Completes the future with `value`.
        fn wake_final(&self, value: R) {
            let _ = self.sender.unbounded_send(Some(value));
        }
    }

    /// Creates a paired test sender and its completion handle.
    fn test_event<R>() -> (TestEvent<R>, TestEventWaker<R>) {
        let (sender, receiver) = mpsc::unbounded();

        (
            TestEvent {
                receiver,
                fut_storage: None,
            },
            TestEventWaker { sender },
        )
    }

    /// Checks all interleavings of a 3-receiver query broadcast completed
    /// from three concurrent threads.
    #[ignore]
    #[test]
    fn loom_broadcast_basic() {
        const DEFAULT_PREEMPTION_BOUND: usize = 3;

        let mut builder = Builder::new();
        if builder.preemption_bound.is_none() {
            builder.preemption_bound = Some(DEFAULT_PREEMPTION_BOUND);
        }

        builder.check(move || {
            let (test_event1, waker1) = test_event::<usize>();
            let (test_event2, waker2) = test_event::<usize>();
            let (test_event3, waker3) = test_event::<usize>();

            let mut broadcaster = QueryBroadcaster::default();
            broadcaster.add(Box::new(test_event1));
            broadcaster.add(Box::new(test_event2));
            broadcaster.add(Box::new(test_event3));

            let mut fut = Box::pin(broadcaster.broadcast(()));
            // Flag set by the parent waker when the broadcast is re-scheduled.
            let is_scheduled = loom::sync::Arc::new(AtomicBool::new(false));
            let is_scheduled_waker = is_scheduled.clone();

            let waker = waker_fn(move || {
                is_scheduled_waker.swap(true, Ordering::Release);
            });
            let mut cx = Context::from_waker(&waker);

            let th1 = thread::spawn(move || waker1.wake_final(3));
            let th2 = thread::spawn(move || waker2.wake_final(7));
            let th3 = thread::spawn(move || waker3.wake_final(42));

            // Poll as long as the parent waker keeps firing; otherwise wait
            // for the completion threads below.
            loop {
                match fut.as_mut().poll(&mut cx) {
                    Poll::Ready(Ok(mut res)) => {
                        assert_eq!(res.next(), Some(3));
                        assert_eq!(res.next(), Some(7));
                        assert_eq!(res.next(), Some(42));
                        assert_eq!(res.next(), None);

                        return;
                    }
                    Poll::Ready(Err(_)) => panic!("sender error"),
                    Poll::Pending => {}
                }

                if !is_scheduled.swap(false, Ordering::Acquire) {
                    break;
                }
            }

            th1.join().unwrap();
            th2.join().unwrap();
            th3.join().unwrap();

            // All completions landed, so the parent waker must have fired.
            assert!(is_scheduled.load(Ordering::Acquire));

            match fut.as_mut().poll(&mut cx) {
                Poll::Ready(Ok(mut res)) => {
                    assert_eq!(res.next(), Some(3));
                    assert_eq!(res.next(), Some(7));
                    assert_eq!(res.next(), Some(42));
                    assert_eq!(res.next(), None);
                }
                Poll::Ready(Err(_)) => panic!("sender error"),
                Poll::Pending => panic!("the future has not completed"),
            };
        });
    }

    /// Same as `loom_broadcast_basic`, with an extra thread issuing a
    /// spurious wake-up that must not confuse the aggregate future.
    #[ignore]
    #[test]
    fn loom_broadcast_spurious() {
        const DEFAULT_PREEMPTION_BOUND: usize = 3;

        let mut builder = Builder::new();
        if builder.preemption_bound.is_none() {
            builder.preemption_bound = Some(DEFAULT_PREEMPTION_BOUND);
        }

        builder.check(move || {
            let (test_event1, waker1) = test_event::<usize>();
            let (test_event2, waker2) = test_event::<usize>();

            let mut broadcaster = QueryBroadcaster::default();
            broadcaster.add(Box::new(test_event1));
            broadcaster.add(Box::new(test_event2));

            let mut fut = Box::pin(broadcaster.broadcast(()));
            // Flag set by the parent waker when the broadcast is re-scheduled.
            let is_scheduled = loom::sync::Arc::new(AtomicBool::new(false));
            let is_scheduled_waker = is_scheduled.clone();

            let waker = waker_fn(move || {
                is_scheduled_waker.swap(true, Ordering::Release);
            });
            let mut cx = Context::from_waker(&waker);

            let spurious_waker = waker1.clone();
            let th1 = thread::spawn(move || waker1.wake_final(3));
            let th2 = thread::spawn(move || waker2.wake_final(7));
            let th_spurious = thread::spawn(move || spurious_waker.wake_spurious());

            // Poll as long as the parent waker keeps firing; otherwise wait
            // for the completion threads below.
            loop {
                match fut.as_mut().poll(&mut cx) {
                    Poll::Ready(Ok(mut res)) => {
                        assert_eq!(res.next(), Some(3));
                        assert_eq!(res.next(), Some(7));
                        assert_eq!(res.next(), None);

                        return;
                    }
                    Poll::Ready(Err(_)) => panic!("sender error"),
                    Poll::Pending => {}
                }

                if !is_scheduled.swap(false, Ordering::Acquire) {
                    break;
                }
            }

            th1.join().unwrap();
            th2.join().unwrap();
            th_spurious.join().unwrap();

            // All completions landed, so the parent waker must have fired.
            assert!(is_scheduled.load(Ordering::Acquire));

            match fut.as_mut().poll(&mut cx) {
                Poll::Ready(Ok(mut res)) => {
                    assert_eq!(res.next(), Some(3));
                    assert_eq!(res.next(), Some(7));
                    assert_eq!(res.next(), None);
                }
                Poll::Ready(Err(_)) => panic!("sender error"),
                Poll::Pending => panic!("the future has not completed"),
            };
        });
    }
}