use std::any::type_name;
use std::future::Future;
use std::pin::Pin;
use std::task::{self, Poll};
use tokio::sync::{mpsc, oneshot};
/// A value that can be sent to an agent through an [`Addr`].
///
/// Each message type names the type of the reply the handling agent sends back.
pub trait Message: 'static + Send {
    /// Type of the value the handling agent replies with.
    type Reply: 'static + Send;
}

/// Message delivered to an agent by `spawn_agent` before its address is handed out;
/// the reply carries no data and only signals that startup has finished.
pub struct AgentStarted;

impl Message for AgentStarted {
    type Reply = ();
}
/// Per-message handle passed to an agent while it processes message `M`.
///
/// Consuming it through one of the `reply*` methods delivers the reply to the
/// [`ReplyFuture`] created alongside it. It also exposes the agent's own [`Addr`].
pub struct Context<A, M: Message> {
    /// Sending half of the one-shot reply channel.
    tx: oneshot::Sender<M::Reply>,
    /// Address of the agent this context belongs to.
    addr: Addr<A>,
}
impl<A, M: Message> Context<A, M> {
    /// Creates a context bound to (a clone of) `addr`, together with the future
    /// that will receive the reply sent through it.
    fn new(addr: &Addr<A>) -> (Self, ReplyFuture<M>) {
        let (reply_tx, reply_rx) = oneshot::channel();
        let context = Self {
            tx: reply_tx,
            addr: addr.clone(),
        };
        (context, ReplyFuture { rx: reply_rx })
    }

    /// Sends `reply` back to the waiting [`ReplyFuture`].
    ///
    /// If the receiving side has already been dropped, the reply is silently
    /// discarded.
    pub fn reply(self, reply: M::Reply) {
        // A send error only means nobody is waiting for the reply any more.
        let _unsent = self.tx.send(reply);
    }

    /// Computes the reply by calling `f` and sends it, exactly like [`Context::reply`].
    // NOTE(review): `dead_code` is presumably allowed because nothing in the crate
    // calls this yet — confirm before removing the attribute.
    #[allow(dead_code)]
    pub fn reply_with<F>(self, f: F)
    where
        F: FnOnce() -> M::Reply,
    {
        self.reply(f());
    }

    /// Spawns a tokio task that awaits `f` and sends its output as the reply.
    ///
    /// Returns immediately; the handler does not wait for `f` to complete.
    pub fn reply_later<F>(self, f: F)
    where
        F: Future<Output = M::Reply> + Send + 'static,
    {
        let reply_tx = self.tx;
        tokio::spawn(async move {
            let reply = f.await;
            // Ignored for the same reason as in `reply`: the receiver may be gone.
            let _unsent = reply_tx.send(reply);
        });
    }

    /// Returns the address of the agent handling this message, so that the
    /// handler can, for example, send messages to itself.
    pub fn addr(&self) -> &Addr<A> {
        &self.addr
    }
}
/// Future that resolves to an agent's reply to a message of type `M`.
///
/// It is created together with a [`Context`]. The value arrives when that
/// context's reply channel is used.
///
/// # Panics
///
/// Polling panics with "agent did not send a reply" if the reply sender is
/// dropped without a value being sent (for example, when the `Context` is
/// dropped without replying).
pub struct ReplyFuture<M: Message> {
    /// Receiving half of the one-shot reply channel.
    rx: oneshot::Receiver<M::Reply>,
}
impl<M: Message> Future for ReplyFuture<M> {
    type Output = M::Reply;

    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
        let receiver = Pin::new(&mut self.rx);
        match receiver.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(outcome) => Poll::Ready(outcome.expect("agent did not send a reply")),
        }
    }
}
/// Boxed one-shot closure that is queued on an agent's mailbox and later run
/// against the agent's state.
pub type DispatchFn<A> = Box<dyn FnOnce(&mut A) + Send + 'static>;

/// An actor whose state is moved into a tokio task and mutated only by the
/// closures received on its mailbox.
pub trait Agent: Send + Sized + 'static {
    /// Moves `self` into a new tokio task and runs every [`DispatchFn`] received
    /// on `rx`, in the order received, until `recv` returns `None`.
    ///
    /// Each dispatch runs inside `tokio::task::block_in_place`, so handlers may
    /// block. Note that `block_in_place` panics when called from a
    /// current-thread runtime, so a multi-threaded runtime is required.
    fn spawn_loop(mut self, mut rx: mpsc::Receiver<DispatchFn<Self>>) {
        let mailbox = async move {
            loop {
                let dispatch = match rx.recv().await {
                    Some(dispatch) => dispatch,
                    None => break,
                };
                tokio::task::block_in_place(|| dispatch(&mut self));
            }
        };
        tokio::spawn(mailbox);
    }

    /// Handles the [`AgentStarted`] message that `spawn_agent` sends before
    /// returning the agent's address.
    ///
    /// The default implementation acknowledges immediately.
    fn started(&mut self, cx: Context<Self, AgentStarted>) {
        cx.reply(())
    }
}
pub async fn spawn_agent<A: Agent>(agent: A) -> Addr<A> {
log::trace!("Starting agent {:?}", type_name::<A>());
let (addr, rx) = Addr::new();
let (cx, reply_fut) = Context::new(&addr);
let tx = addr.tx.clone();
tokio::spawn(async move {
let send_fut = tx.send(Box::new(move |agent: &mut A| {
agent.started(cx);
}));
if send_fut.await.is_err() {
panic!("agent stopped before startup completed");
}
});
agent.spawn_loop(rx);
reply_fut.await;
log::trace!("Started agent {:?}", type_name::<A>());
addr
}
/// Implemented by agents that can process messages of type `M`.
pub trait Handler<M: Message>: Sized {
    /// Processes `message`. The reply is delivered by consuming `cx`
    /// (for example with `cx.reply(...)`).
    ///
    /// Dropping `cx` without replying makes the sender's `ReplyFuture` panic
    /// when it is awaited.
    fn handle(&mut self, message: M, cx: Context<Self, M>);
}
/// Cloneable handle used to send messages to an agent of type `A`.
pub struct Addr<A> {
    /// Sending half of the agent's mailbox channel.
    tx: mpsc::Sender<DispatchFn<A>>,
}
impl<A> Addr<A> {
fn new() -> (Addr<A>, mpsc::Receiver<DispatchFn<A>>) {
let (tx, rx) = mpsc::channel(8);
let addr = Addr { tx };
(addr, rx)
}
pub fn send<M>(&self, message: M) -> ReplyFuture<M>
where
M: Message,
A: Handler<M> + Send + 'static,
{
log::trace!(
"Sending message {:?} to agent {:?}",
type_name::<M>(),
type_name::<A>()
);
let (cx, reply_fut) = Context::new(self);
let tx = self.tx.clone();
tokio::spawn(async move {
let send_fut = tx.send(Box::new(move |agent: &mut A| {
agent.handle(message, cx);
}));
if send_fut.await.is_err() {
panic!("tried to send message to stopped agent");
}
});
reply_fut
}
}
impl<A> Clone for Addr<A> {
fn clone(&self) -> Self {
let tx = self.tx.clone();
Addr { tx }
}
}
/// Anything that accepts messages of type `M` and yields a [`ReplyFuture`].
///
/// Because the trait is generic over the message rather than the agent, callers
/// can hold a `dyn Sender<M>` without knowing which agent type handles `M`.
pub trait Sender<M: Message>: Send + Sync {
    /// Sends `message` and returns a future that resolves to the reply.
    fn send(&self, message: M) -> ReplyFuture<M>;
}

impl<M, A> Sender<M> for Addr<A>
where
    A: Handler<M> + Send + 'static,
    M: Message,
{
    fn send(&self, message: M) -> ReplyFuture<M> {
        // Qualified path resolves to the inherent `Addr::send`, not this trait
        // method, so the call forwards instead of recursing.
        <Addr<A>>::send(self, message)
    }
}