use std::{collections::HashMap, fmt::Display, marker::PhantomData, sync::Arc, time::Duration};
use serde::{Serialize, de::DeserializeOwned};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
sync::{RwLock, mpsc, oneshot},
};
use crate::{
actors::remote::UntypedHandle,
messaging::{Message, MsgResult},
};
use super::{
address::{self, ActorAddress, PeerId},
dencoder::{self, Dencoder},
netlayer::{AsyncMsgStream, NetLayer},
};
/// Entry point for starting a router; see [`Router::with_netlayer`].
///
/// The type itself carries no state: the running router lives in a spawned task and is
/// controlled through the [`RouterHandle`] returned by `with_netlayer`.
#[derive(Debug)]
pub struct Router;
impl Router {
pub async fn with_netlayer<N>(
mut netlayer: N,
opts: Option<RouterOpts>,
) -> Result<RouterHandle, Error>
where
N: NetLayer + Send + 'static,
<N as NetLayer>::Error: Send + std::fmt::Display,
{
let opts = opts.unwrap_or_default();
netlayer.init().await.map_err(|e| {
tracing::error!("router init: {e}");
Error::Init(e.to_string())
})?;
let host_address = netlayer.address().await.map_err(|e| {
tracing::error!("router init: failed to obtain address - {e}");
Error::Init(e.to_string())
})?;
let host_address_inner = host_address.clone();
let peers: HashMap<PeerId, UntypedHandle> = HashMap::new();
let (sender, mut receiver) =
mpsc::channel::<(RouterMessage, oneshot::Sender<Result<RouterReply, Error>>)>(1024);
let (conf_sender, conf_receiver) = oneshot::channel::<Result<(), Error>>();
tokio::spawn(async move {
let opts = Arc::new(opts);
let peers = Arc::new(RwLock::new(peers));
let _ = conf_sender.send(Ok(()));
loop {
tokio::select! {
Some((command, sender)) = receiver.recv() => {
match command {
RouterMessage::Stop => {
let _ = sender.send(Ok(RouterReply::Accepted));
return;
},
RouterMessage::Attach { handle, peer_id } => {
let addr = match peer_id {
Some(id) => ActorAddress::new_with_peer_id::<N>(&host_address_inner, id),
None => match ActorAddress::new::<N>(&host_address_inner) {
Ok(addr) => addr,
Err(err) => {
tracing::error!("router: attach - {err}");
continue;
}
},
};
peers.write().await.insert(addr.peer_id().to_owned(), handle);
let _ = sender.send(Ok(RouterReply::Address(addr)));
},
RouterMessage::Revoke(addr) => {
peers.write().await.remove(addr.peer_id());
let _ = sender.send(Ok(RouterReply::Address(addr)));
},
}
},
Ok(mut stream) = netlayer.accept() => {
let opts = opts.clone();
let peers = peers.clone();
tokio::spawn(async move {
let _ = tokio::time::timeout(
Duration::from_millis(opts.msg_read_timeout()),
async move {
let id = match try_read_id(&mut stream).await {
Ok(id) => id,
Err(_) => {
return;
},
};
let handle = match peers.read().await.get(&id) {
Some(handle) => handle.clone(),
None => {
tracing::warn!("router: recv - unknown peer {id}");
return;
},
};
let _ = try_handle_message(stream, handle, opts.as_ref()).await;
}).await;
});
}
}
}
});
conf_receiver
.await
.map_err(|e| Error::Init(e.to_string()))??;
Ok(RouterHandle {
sender,
host_address,
})
}
}
async fn try_read_id<S>(stream: &mut S) -> Result<PeerId, Error>
where
S: AsyncReadExt + Unpin,
{
let size = stream.read_u16().await.map_err(|e| {
tracing::error!("router: could not read id size - {e}");
Error::Recv(e.to_string())
})?;
let mut id_buffer: Vec<u8> = vec![0; size as usize];
stream.read_exact(&mut id_buffer).await.map_err(|e| {
tracing::error!("router: recv - {e}");
Error::Recv(e.to_string())
})?;
Ok(PeerId::new_from_bytes(&id_buffer))
}
async fn try_handle_message<S>(
mut stream: S,
handle: UntypedHandle,
opts: &RouterOpts,
) -> Result<(), Error>
where
S: AsyncMsgStream,
{
let msg_size = stream.read_u32().await.map_err(|e| {
tracing::error!("router: recv - could not read msg size - {e}");
Error::Recv(e.to_string())
})?;
if msg_size > opts.max_msg_size() {
tracing::warn!("router: recv - incoming message body exceeds size limit; dropping");
Err(Error::Recv("message too big".into()))?
}
let mut msg_buffer = vec![0; msg_size as usize];
stream.read_exact(&mut msg_buffer).await.map_err(|e| {
tracing::error!("router: recv - could not read msg - {e}");
Error::Recv(e.to_string())
})?;
let res = handle.send(msg_buffer).await.map_err(|err| {
tracing::error!("router: msg error - {err}");
Error::Send(err.to_string())
})?;
stream.write_u32(res.len() as u32).await.map_err(|err| {
tracing::error!("router: could not send response size - {err}");
Error::Send(err.to_string())
})?;
stream.write_all(&res).await.map_err(|err| {
tracing::error!("router: could not send response - {err}");
Error::Send(err.to_string())
})?;
stream.flush().await.map_err(|err| {
tracing::error!("router: could not flush response - {err}");
Error::Send(err.to_string())
})?;
Ok(())
}
/// Tunables for a router started with `Router::with_netlayer`.
#[derive(Debug)]
pub struct RouterOpts {
    /// Time budget, in milliseconds, for serving one inbound connection.
    /// This covers reading the request, calling the actor and writing the reply.
    pub msg_read_timeout: u64,
    /// Largest accepted request body, in bytes. Larger requests are rejected before reading.
    pub max_msg_size: u32,
}

impl RouterOpts {
    /// Builds options from an explicit timeout (milliseconds) and body-size limit (bytes).
    pub fn new(msg_read_timeout: u64, max_msg_size: u32) -> Self {
        RouterOpts {
            max_msg_size,
            msg_read_timeout,
        }
    }

    /// Per-connection time budget in milliseconds.
    pub fn msg_read_timeout(&self) -> u64 {
        self.msg_read_timeout
    }

    /// Maximum request body size in bytes.
    pub fn max_msg_size(&self) -> u32 {
        self.max_msg_size
    }
}

impl Default for RouterOpts {
    /// A 5 second per-connection budget and a 4 MiB body limit.
    fn default() -> Self {
        const FIVE_SECONDS_MS: u64 = 5_000;
        const FOUR_MIB: u32 = 4 * 1024 * 1024;
        RouterOpts::new(FIVE_SECONDS_MS, FOUR_MIB)
    }
}
#[derive(Debug, Clone)]
pub struct RouterHandle {
host_address: String,
sender: mpsc::Sender<(RouterMessage, oneshot::Sender<Result<RouterReply, Error>>)>,
}
impl RouterHandle {
pub async fn attach(&self, handle: UntypedHandle) -> Result<ActorAddress, Error> {
self.attach_handle(handle, None).await
}
pub async fn attach_with_id(
&self,
handle: UntypedHandle,
peer_id: PeerId,
) -> Result<ActorAddress, Error> {
self.attach_handle(handle, Some(peer_id)).await
}
async fn attach_handle(
&self,
handle: UntypedHandle,
peer_id: Option<PeerId>,
) -> Result<ActorAddress, Error> {
let (sender, receiver) = oneshot::channel();
self.sender
.send((RouterMessage::Attach { handle, peer_id }, sender))
.await
.map_err(|e| {
tracing::error!("router: {e}");
Error::Send(e.to_string())
})?;
let reply = receiver.await.map_err(|e| {
tracing::error!("router: {e}");
Error::Recv(e.to_string())
})??;
match reply {
RouterReply::Accepted => panic!("expected Address variant"),
RouterReply::Address(a) => Ok(a),
}
}
pub async fn revoke(&self, address: &ActorAddress) -> Result<ActorAddress, Error> {
let (sender, receiver) = oneshot::channel();
self.sender
.send((RouterMessage::Revoke(address.clone()), sender))
.await
.map_err(|e| {
tracing::error!("router: {e}");
Error::Send(e.to_string())
})?;
let reply = receiver.await.map_err(|e| {
tracing::error!("router: {e}");
Error::Recv(e.to_string())
})??;
match reply {
RouterReply::Accepted => panic!("expected Address variant"),
RouterReply::Address(a) => Ok(a),
}
}
pub async fn stop(&self) -> Result<(), Error> {
let (sender, receiver) = oneshot::channel();
self.sender
.send((RouterMessage::Stop, sender))
.await
.map_err(|e| {
tracing::error!("router: {e}");
Error::Send(e.to_string())
})?;
let reply = receiver.await.map_err(|e| {
tracing::error!("router: {e}");
Error::Recv(e.to_string())
})??;
match reply {
RouterReply::Accepted => Ok(()),
RouterReply::Address(_) => panic!("expected Accepted variant"),
}
}
pub fn host_address(&self) -> &str {
&self.host_address
}
}
#[derive(Debug, Clone)]
pub struct RemoteHandle<I, O, E, D: Dencoder, N: NetLayer> {
address: ActorAddress,
netlayer: N,
_ipd: PhantomData<I>,
_opd: PhantomData<O>,
_epd: PhantomData<E>,
_dpd: PhantomData<D>,
}
impl<I, O, E, D, N> RemoteHandle<I, O, E, D, N>
where
I: Serialize + DeserializeOwned,
O: Serialize + DeserializeOwned,
E: Serialize + DeserializeOwned,
D: Dencoder,
N: NetLayer,
{
pub fn new(address: &ActorAddress, netlayer: N) -> Self {
Self {
address: address.to_owned(),
netlayer,
_ipd: PhantomData,
_opd: PhantomData,
_epd: PhantomData,
_dpd: PhantomData,
}
}
pub async fn send(&self, msg: Message<I>) -> Result<MsgResult<O, E>, Error>
where
<N as NetLayer>::Error: std::fmt::Display,
{
let mut stream = self
.netlayer
.connect(self.address.host())
.await
.map_err(|err| {
tracing::error!("remote handle: failed to connect - {err}");
Error::Connect(err.to_string())
})?;
let id = self.addr().peer_id();
let id_len = self.addr().peer_id().len() as u16;
stream.write_u16(id_len).await.map_err(|err| {
tracing::error!("remote handle: failed to send peer ID size - {err}");
Error::Send(err.to_string())
})?;
stream.write_all(id.bytes()).await.map_err(|err| {
tracing::error!("remote handle: failed to send peer ID - {err}");
Error::Send(err.to_string())
})?;
let bytes = D::encode(msg).map_err(Error::Serialize)?;
stream.write_u32(bytes.len() as u32).await.map_err(|err| {
tracing::error!("remote handle: failed to send message size - {err}");
Error::Send(err.to_string())
})?;
stream.write_all(&bytes).await.map_err(|err| {
tracing::error!("remote handle: failed to send message - {err}");
Error::Send(err.to_string())
})?;
stream.flush().await.map_err(|err| {
tracing::error!("remote handle: failed to flush message - {err}");
Error::Send(err.to_string())
})?;
let size = stream.read_u32().await.map_err(|err| {
tracing::error!("remote handle: failed to receive message size - {err}");
Error::Recv(err.to_string())
})?;
let mut res_buffer = vec![0; size as usize];
stream.read_exact(&mut res_buffer).await.map_err(|err| {
tracing::error!("remote handle: failed to receive message - {err}");
Error::Recv(err.to_string())
})?;
D::decode(res_buffer).map_err(Error::Serialize)
}
pub fn addr(&self) -> &ActorAddress {
&self.address
}
}
/// Commands sent from a [`RouterHandle`] to the router's event loop.
/// Each command travels with a oneshot sender for the reply.
#[derive(Debug)]
enum RouterMessage {
/// Ends the event loop; answered with `RouterReply::Accepted`.
Stop,
/// Registers `handle` so that inbound messages addressed to its peer id are forwarded to it.
Attach {
handle: UntypedHandle,
/// Peer id to register under; `None` lets `ActorAddress::new` choose one.
peer_id: Option<PeerId>,
},
/// Removes whatever handle is registered under the address's peer id.
/// Answered with the same address.
Revoke(ActorAddress),
}
/// Successful replies from the router's event loop to a [`RouterHandle`] command.
enum RouterReply {
/// Reply to `RouterMessage::Stop`.
Accepted,
/// Reply to `Attach` (the newly registered address) and to `Revoke` (the revoked address).
Address(ActorAddress),
}
/// Errors produced by the router, [`RouterHandle`] and [`RemoteHandle`].
#[allow(missing_docs)]
#[derive(Debug)]
pub enum Error {
/// The net layer failed to initialise or report its address, or the router task failed to start.
Init(String),
/// Opening a connection to a remote endpoint failed.
Connect(String),
/// Encoding or decoding a message with the configured `Dencoder` failed.
Serialize(dencoder::Error),
/// Writing to a stream, delivering a command, or the actor call itself failed.
Send(String),
/// Reading from a stream or receiving a reply failed, or an inbound message was rejected.
Recv(String),
/// Creating an actor address failed.
Address(address::Error),
}
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::Init(ctx) => write!(f, "failed to init router: {ctx}"),
Error::Connect(ctx) => write!(f, "failed to connect to endpoint: {ctx}"),
Error::Serialize(ctx) => write!(f, "failed to encode/decode message: {ctx}"),
Error::Send(ctx) => write!(f, "failed to send message: {ctx}"),
Error::Recv(ctx) => write!(f, "failed to receive message: {ctx}"),
Error::Address(ctx) => write!(f, "failed to create address: {ctx}"),
}
}
}
impl std::error::Error for Error {}
// End-to-end tests. Each test spawns a `Mult { a: 3 }` actor, attaches it to a router backed by
// `TcpNetLayer`, and talks to it through a `RemoteHandle` over a second `TcpNetLayer`.
#[cfg(test)]
mod tests {
use crate::{
actors::{
remote::{
self,
address::PeerId,
dencoder::bitcode::BitcodeDencoder,
netlayer::tcp_layer::TcpNetLayer,
router::{RemoteHandle, Router, RouterOpts},
},
tests::{Mult, SomeError},
},
messaging::{Message, Reply},
};
// `Task(5)` sent to `Mult { a: 3 }` through the router is answered with `Task(15)`.
#[tokio::test]
async fn spawn_and_message() {
let (_, handle) = remote::spawn_untyped::<_, _, _, BitcodeDencoder>(Mult { a: 3 })
.await
.unwrap();
let router = Router::with_netlayer(TcpNetLayer::new(), Some(RouterOpts::default()))
.await
.unwrap();
let addr = router.attach(handle).await.unwrap();
let remote = RemoteHandle::<u32, u32, SomeError, BitcodeDencoder, TcpNetLayer>::new(
&addr,
TcpNetLayer::new(),
);
let res = remote.send(Message::Task(5)).await.unwrap();
assert!(matches!(res, Ok(Reply::Task(15))));
}
// `attach_with_id` keeps the caller's peer id in the returned address, and routing still works.
#[tokio::test]
async fn spawn_with_id() {
let (_, handle) = remote::spawn_untyped::<_, _, _, BitcodeDencoder>(Mult { a: 3 })
.await
.unwrap();
let router = Router::with_netlayer(TcpNetLayer::new(), Some(RouterOpts::default()))
.await
.unwrap();
let peer_id = PeerId::new().unwrap();
let addr = router
.attach_with_id(handle, peer_id.clone())
.await
.unwrap();
assert_eq!(peer_id, *addr.peer_id());
let remote = RemoteHandle::<u32, u32, SomeError, BitcodeDencoder, TcpNetLayer>::new(
&addr,
TcpNetLayer::new(),
);
let res = remote.send(Message::Task(5)).await.unwrap();
assert!(matches!(res, Ok(Reply::Task(15))));
}
// `Ping` through the router is answered with `Reply::Accepted`.
#[tokio::test]
async fn ping() {
let (_, handle) = remote::spawn_untyped::<_, _, _, BitcodeDencoder>(Mult { a: 3 })
.await
.unwrap();
let router = Router::with_netlayer(TcpNetLayer::new(), Some(RouterOpts::default()))
.await
.unwrap();
let addr = router.attach(handle).await.unwrap();
let remote = RemoteHandle::<u32, u32, SomeError, BitcodeDencoder, TcpNetLayer>::new(
&addr,
TcpNetLayer::new(),
);
let res = remote.send(Message::Ping).await.unwrap();
assert!(matches!(res, Ok(Reply::Accepted)));
}
// A remote `Stop` is accepted, and the next message to the same address fails.
// `allow_stop(true)` presumably lets the actor honour remote stop requests; confirm in
// `UntypedHandle`.
#[tokio::test]
async fn stop() {
let (_, mut handle) = remote::spawn_untyped::<_, _, _, BitcodeDencoder>(Mult { a: 3 })
.await
.unwrap();
handle.allow_stop(true);
let router = Router::with_netlayer(TcpNetLayer::new(), Some(RouterOpts::default()))
.await
.unwrap();
let addr = router.attach(handle).await.unwrap();
let remote = RemoteHandle::<u32, u32, SomeError, BitcodeDencoder, TcpNetLayer>::new(
&addr,
TcpNetLayer::new(),
);
let res = remote.send(Message::Stop).await.unwrap();
assert!(matches!(res, Ok(Reply::Accepted)));
remote.send(Message::Ping).await.unwrap_err();
}
// After `revoke`, the router no longer knows the peer id and drops the connection,
// so a subsequent send fails.
#[tokio::test]
async fn revoke() {
let (_, handle) = remote::spawn_untyped::<_, _, _, BitcodeDencoder>(Mult { a: 3 })
.await
.unwrap();
let router = Router::with_netlayer(TcpNetLayer::new(), Some(RouterOpts::default()))
.await
.unwrap();
let addr = router.attach(handle).await.unwrap();
let remote = RemoteHandle::<u32, u32, SomeError, BitcodeDencoder, TcpNetLayer>::new(
&addr,
TcpNetLayer::new(),
);
let res = remote.send(Message::Ping).await.unwrap();
assert!(matches!(res, Ok(Reply::Accepted)));
router.revoke(&addr).await.unwrap();
remote.send(Message::Ping).await.unwrap_err();
}
}