use crate::launch::Launch;
use crate::{Actor, BoundedOutbox, UnboundedOutbox};
use futures::Stream;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
/// Convenience methods for launching any [`Actor`] onto the Tokio runtime.
///
/// Every [`Actor`] gets these methods through the blanket impl that follows
/// this trait.
pub trait ActorExt: Actor {
    /// Launches the actor using a custom [`Launch`] strategy.
    ///
    /// The strategy must accept the actor's message type; whatever it produces
    /// (`L::Result<Self>`) is returned unchanged.
    fn with<L>(self, launch: L) -> L::Result<Self>
    where
        L: Launch<Message = Self::Message>,
    {
        launch.launch(self)
    }

    /// Spawns the actor's `run` future on the current Tokio runtime, feeding it
    /// `inbox` as its message stream.
    ///
    /// The returned [`JoinHandle`] is the one produced by [`tokio::spawn`] for
    /// that future.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, as [`tokio::spawn`] does.
    fn start_with<I>(self, inbox: I) -> JoinHandle<()>
    where
        I: Stream<Item = Self::Message> + Send + 'static,
    {
        tokio::spawn(self.run(inbox))
    }

    /// Spawns the actor with a bounded mailbox that queues at most
    /// `mailbox_capacity` messages.
    ///
    /// The mailbox is a [`tokio::sync::mpsc::channel`]. Its sender is moved
    /// into the returned [`BoundedOutbox`] together with the actor's task
    /// handle.
    ///
    /// # Panics
    ///
    /// - Panics if `mailbox_capacity` is 0. Tokio's bounded channel requires a
    ///   buffer greater than zero.
    /// - Panics if called outside a Tokio runtime.
    #[must_use = "the returned outbox owns the only sender for the actor's mailbox"]
    fn start_with_mailbox_capacity(self, mailbox_capacity: usize) -> BoundedOutbox<Self::Message> {
        let (sender, receiver) = tokio::sync::mpsc::channel(mailbox_capacity);
        let receiver_stream = ReceiverStream::new(receiver);
        BoundedOutbox::new(sender, self.start_with(receiver_stream))
    }

    /// Spawns the actor with an unbounded mailbox.
    ///
    /// The mailbox is a [`tokio::sync::mpsc::unbounded_channel`], so queued
    /// messages are not limited in number. Its sender is moved into the
    /// returned [`UnboundedOutbox`] together with the actor's task handle.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime.
    #[must_use = "the returned outbox owns the only sender for the actor's mailbox"]
    fn start(self) -> UnboundedOutbox<Self::Message> {
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        let receiver_stream = UnboundedReceiverStream::new(receiver);
        UnboundedOutbox::new(sender, self.start_with(receiver_stream))
    }
}

/// Blanket impl: every [`Actor`] automatically gets the [`ActorExt`] helpers.
impl<A: Actor> ActorExt for A {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Actor, AsyncClose};
    use futures::StreamExt;
    use std::pin::pin;
    use tokio::sync::oneshot;

    /// Test actor that counts every message it receives and reports the total
    /// through `result_tx` once its inbox stream ends.
    struct CountActor {
        result_tx: oneshot::Sender<u64>,
    }

    impl Actor for CountActor {
        type Message = String;

        async fn run(self, inbox: impl Stream<Item = Self::Message> + Send) {
            let mut inbox = pin!(inbox);
            let mut count = 0;
            while inbox.next().await.is_some() {
                count += 1;
            }
            // The receiver may already have been dropped (e.g. by a failing
            // test). There is nobody to report to then, so the error is ignored.
            let _ = self.result_tx.send(count);
        }
    }

    /// Builds a `CountActor` together with the receiver for its final count.
    fn create_count_actor() -> (CountActor, oneshot::Receiver<u64>) {
        let (tx, rx) = oneshot::channel();
        (CountActor { result_tx: tx }, rx)
    }

    /// Messages sent through an unbounded outbox all reach the actor before
    /// it finishes.
    #[tokio::test]
    async fn test_unbounded_start() {
        let (actor, rx) = create_count_actor();
        let outbox = actor.start();
        outbox.send("msg1".to_string()).expect("Failed to send 1");
        outbox.send("msg2".to_string()).expect("Failed to send 2");
        outbox.close().await;
        let count = rx.await.expect("Actor did not return result");
        assert_eq!(count, 2);
    }

    /// Messages sent through a bounded outbox (below capacity) all reach the
    /// actor before it finishes.
    #[tokio::test]
    async fn test_bounded_start() {
        let (actor, rx) = create_count_actor();
        let outbox = actor.start_with_mailbox_capacity(5);
        for i in 0..3 {
            outbox
                .send(format!("msg{}", i))
                .await
                .expect("Failed to send");
        }
        outbox.close().await;
        let count = rx.await.expect("Actor did not return result");
        assert_eq!(count, 3);
    }

    /// With capacity 1 and an actor that never reads its inbox, a second send
    /// must not complete.
    #[tokio::test]
    async fn test_bounded_start_with_backpressure() {
        /// Actor that signals it has started, then idles without reading its
        /// inbox until `stop_rx` resolves or its sender is dropped.
        struct BlockingActor {
            start_tx: oneshot::Sender<()>,
            stop_rx: oneshot::Receiver<()>,
        }

        impl Actor for BlockingActor {
            type Message = String;

            async fn run(self, _: impl Stream<Item = Self::Message> + Send) {
                self.start_tx.send(()).unwrap();
                // An Err here only means the stop sender was dropped, which is
                // treated as a stop signal too.
                self.stop_rx.await.unwrap_or(());
            }
        }

        let (start_tx, start_rx) = oneshot::channel();
        let (stop_tx, stop_rx) = oneshot::channel();
        let actor = BlockingActor { start_tx, stop_rx };
        let outbox = actor.start_with_mailbox_capacity(1);
        start_rx.await.unwrap();

        // Occupies the mailbox's single slot.
        outbox.send("m1".to_string()).await.unwrap();

        // Nothing drains the mailbox, so this send cannot complete in time.
        let send_future = outbox.send("m2".to_string());
        let timeout_result =
            tokio::time::timeout(tokio::time::Duration::from_millis(50), send_future).await;
        assert!(
            timeout_result.is_err(),
            "Send should have timed out due to full channel"
        );

        // Stop the actor explicitly. It never reads its inbox, so closing the
        // mailbox alone would not end it.
        drop(stop_tx);
        outbox.close().await;
    }
}