use std::marker::PhantomData;
use async_mutex::Mutex;
use async_trait::async_trait;
use flume::{Receiver, Sender};
use snafu::Snafu;
use tracing::{error, info_span, Instrument};
use crate::{
executor::Executor,
traits::{Event, EventConsumer, EventProducer},
types::CompletionToken,
};
/// Errors returned by the [`EventProducer`] implementation of [`AsyncWrappedReceiver`].
#[derive(Debug, Snafu)]
pub enum WrappedReceiverError {
// The wrapped receiver was already taken by an earlier consumer registration.
// NOTE: plain `//` comments on purpose — snafu uses a variant's `///` doc comment
// as its `Display` text when no `display` attribute is given.
AlreadyRegistered,
// Callback registration is not supported by this producer.
Callback,
}
/// Wraps a flume [`Receiver`] so it can act as an [`EventProducer`].
///
/// The receiver can be handed to exactly one consumer. The first successful
/// `register_consumer` / `register_consumer_sync` call takes it, and later calls
/// fail with `AlreadyRegistered`. `X` is the executor used to spawn the task
/// that forwards received messages to that consumer.
pub struct AsyncWrappedReceiver<X: Executor, T> {
// `Some` until a consumer registration takes the receiver out; the async
// mutex guards that one-time hand-off.
inner: Mutex<Option<Receiver<T>>>,
// Type-level marker only; no executor value is stored.
_executor: PhantomData<X>,
}
impl<X: Executor, T> AsyncWrappedReceiver<X, T> {
pub fn new(bound: Option<usize>) -> (Sender<T>, Self) {
let (tx, rx) = match bound {
Some(bound) => flume::bounded(bound),
None => flume::unbounded(),
};
(tx, Self::from(rx))
}
}
impl<X: Executor, T> From<Receiver<T>> for AsyncWrappedReceiver<X, T> {
    /// Wraps an existing flume receiver.
    ///
    /// The receiver stays in the wrapper until the first consumer
    /// registration takes it out.
    fn from(receiver: Receiver<T>) -> Self {
        let slot = Mutex::new(Some(receiver));
        Self {
            _executor: PhantomData,
            inner: slot,
        }
    }
}
#[async_trait]
impl<X: Executor, T: Event> EventProducer<T> for AsyncWrappedReceiver<X, T> {
type Error = WrappedReceiverError;
async fn register_consumer<C>(&self, consumer: C) -> Result<(), Self::Error>
where
C: EventConsumer<T> + Send + Sync + 'static,
{
if let Some(channel) = self.inner.lock().await.take() {
X::spawn_async(
async move {
while let Ok(msg) = channel.recv_async().await {
if let Err(e) = consumer.accept(msg).await {
error!(?e, "Error pushing message into consumer");
}
}
}
.instrument(info_span!("WrappedReceiver internal task")),
);
Ok(())
} else {
AlreadyRegisteredSnafu.fail()
}
}
async fn register_callback<F>(
&self,
_callback: F,
_token: CompletionToken,
) -> Result<(), Self::Error>
where
F: FnOnce(T) + Send + Sync + 'static,
{
CallbackSnafu.fail()
}
fn register_consumer_sync<C>(&self, consumer: C) -> Result<(), Self::Error>
where
C: EventConsumer<T> + Send + Sync + 'static,
{
if let Some(channel) = futures::executor::block_on(self.inner.lock()).take() {
X::spawn_sync(move || {
{
while let Ok(msg) = channel.recv() {
if let Err(e) = consumer.accept_sync(msg) {
error!(?e, "Error pushing message into consumer");
}
}
}
.instrument(info_span!("WrappedReceiver internal task"))
});
Ok(())
} else {
AlreadyRegisteredSnafu.fail()
}
}
fn register_callback_sync<F>(
&self,
_callback: F,
_token: CompletionToken,
) -> Result<(), Self::Error>
where
F: FnOnce(T) + Send + Sync + 'static,
{
CallbackSnafu.fail()
}
}
#[cfg(test)]
mod tests {
use futures::channel::oneshot;
use super::*;
use crate::{
executor::{Executor, Threads},
traits::Actor,
util::{AsyncActor, WrappedEvent},
};
/// Output event emitted by the test actor; `val` carries the running sum.
#[derive(Clone, PartialEq, Eq, Debug, PartialOrd, Ord, Hash, Default)]
pub struct Output {
val: usize,
}
/// End-to-end check of `AsyncWrappedReceiver` on executor `X`.
///
/// Steps:
/// 1. Connect a capacity-1 wrapped receiver to an `AsyncActor` that sums
///    incoming `usize` values.
/// 2. Push 100 ones through the channel.
/// 3. Send a tokenized `0` directly to the actor's inbox.
/// 4. Wait for the callback registered on the actor's outbox to deliver the
///    running total.
async fn smoke<X: Executor>() {
let (tx, rx) = AsyncWrappedReceiver::<X, WrappedEvent<usize>>::new(Some(1));
let actor: AsyncActor<WrappedEvent<usize>, WrappedEvent<Output>, X> = AsyncActor::spawn(
// Actor step: add the incoming value to the state. Only an incoming event
// carrying a completion token produces an output event, tagged with that token.
|mut value: usize, mut add: WrappedEvent<usize>| {
let token = add.token();
let add = add.into_inner();
value += add;
if let Some(token) = token {
let mut event: WrappedEvent<Output> = Output { val: value }.into();
event.set_completion_token(token);
(value, Some(event))
} else {
(value, None)
}
},
0,
Some(1),
);
// Route everything sent on `tx` into the actor's inbox.
rx.register_consumer(actor.inbox().clone()).await.unwrap();
for _ in 0..100 {
tx.send_async(1_usize.into()).await.unwrap();
}
// Note: `tx`/`rx` are shadowed here by the oneshot used to receive the result.
let (tx, rx) = oneshot::channel();
let mut event: WrappedEvent<usize> = 0_usize.into();
let token = event.tokenize().unwrap();
// Register the result callback for `token` before the tokenized event is sent,
// presumably so the output event cannot be emitted before anyone listens — confirm.
actor
.outbox()
.register_callback(
move |event| {
tx.send(event).unwrap();
},
token,
)
.await
.unwrap();
// NOTE(review): presumably waits until the actor has processed what is already
// queued in its inbox — confirm against `catchup`'s definition.
actor.catchup().wait().await;
actor.inbox().accept(event).await.unwrap();
let res = rx.await.unwrap().into_inner();
println!("res.val: {}", res.val);
// NOTE(review): the tokenized event bypasses the flume channel and may overtake
// values still in flight between it and the actor, which is presumably why a
// total of 98 or 99 is also accepted — confirm.
assert!(res.val == 98 || res.val == 99 || res.val == 100);
}
/// Runs `smoke` on the async-std executor (only with the `async-std` feature).
#[cfg(feature = "async-std")]
#[async_std::test]
async fn smoke_async_std() {
smoke::<crate::executor::AsyncStd>().await;
}
/// Runs `smoke` on the `Threads` executor.
// NOTE(review): this test is not gated on `feature = "async-std"` but still uses
// `#[async_std::test]`; presumably async-std is an unconditional dev-dependency — confirm.
#[async_std::test]
async fn smoke_threads() {
smoke::<Threads>().await;
}
}