use crate::{parser, Context, Mailbox};
use ockam_core::{
Address, Message, Result, Route, Routed, RouterMessage, TransportMessage, Worker,
};
use std::{marker::PhantomData, sync::Arc};
use tokio::runtime::Runtime;
use tokio::sync::mpsc::{channel, Receiver, Sender};
/// Envelope handed to a worker's relay: the address it was sent to, the
/// payload to decode, and the onward route passed along to the worker.
#[derive(Clone, Debug)]
pub struct RelayMessage {
/// Address the message was delivered to; the relay publishes it via
/// `Context::message_address` while the message is being handled.
addr: Address,
/// Either a raw transport message or an already-encoded router message.
data: RelayPayload,
/// Route forwarded to the worker as part of the `Routed` message.
onward: Route,
}
impl RelayMessage {
    /// Wrap `data` so the receiving relay decodes its payload directly as the
    /// worker's message type.
    pub fn direct(addr: Address, data: TransportMessage, onward: Route) -> Self {
        Self {
            addr,
            data: RelayPayload::Direct(data),
            onward,
        }
    }

    /// Wrap `data` for a router-style worker: the whole transport message is
    /// encoded as a `RouterMessage::Route`, and its return route is kept
    /// alongside the encoded bytes.
    ///
    /// # Panics
    ///
    /// Panics if encoding the `RouterMessage` fails.
    #[inline]
    pub fn pre_router(addr: Address, data: TransportMessage, onward: Route) -> Self {
        // `data` is moved into the RouterMessage below, so copy the return route first.
        let route = data.return_.clone();
        let encoded = RouterMessage::Route(data)
            .encode()
            .expect("encoding a RouterMessage for pre-router delivery failed");
        Self {
            addr,
            data: RelayPayload::PreRouter(encoded, route),
            onward,
        }
    }

    /// Consume the envelope, returning the target address and the wrapped
    /// transport message.
    ///
    /// # Panics
    ///
    /// Panics if this message was built with [`RelayMessage::pre_router`].
    #[inline]
    pub fn transport(self) -> (Address, TransportMessage) {
        // Exhaustive match: adding a payload variant forces this to be revisited.
        let msg = match self.data {
            RelayPayload::Direct(msg) => msg,
            RelayPayload::PreRouter(..) => {
                panic!("Called transport() on invalid RelayMessage type!")
            }
        };
        (self.addr, msg)
    }
}
/// Payload carried by a `RelayMessage`.
#[derive(Clone, Debug)]
pub enum RelayPayload {
/// A transport message whose `payload` is decoded as the worker's message type.
Direct(TransportMessage),
/// Bytes of an encoded `RouterMessage::Route`, plus the return route taken
/// from the original transport message.
PreRouter(Vec<u8>, Route),
}
/// Drives a single worker: owns the worker and its `Context`, decodes incoming
/// `RelayMessage`s into `M`, and hands them to the worker.
pub struct Relay<W, M>
where
W: Worker<Context = Context>,
M: Message,
{
worker: W,
ctx: Context,
// Ties the relay to the worker's message type without storing an `M`.
_phantom: PhantomData<M>,
}
impl<W, M> Relay<W, M>
where
    W: Worker<Context = Context, Message = M>,
    M: Message + Send + 'static,
{
    /// Create a relay that will drive `worker` using `ctx`.
    pub fn new(worker: W, ctx: Context) -> Self {
        Self {
            worker,
            ctx,
            _phantom: PhantomData,
        }
    }

    /// Decode the payload of a directly delivered transport message as `M`.
    ///
    /// Returns the decoded message together with the transport message's
    /// return route. Decode failures are logged before being returned.
    #[inline]
    fn handle_direct(&mut self, msg: TransportMessage) -> Result<(M, Route)> {
        let TransportMessage {
            payload, return_, ..
        } = msg;
        parser::message::<M>(payload)
            .map_err(|e| {
                error!(
                    "Failed to decode message payload for worker {}",
                    self.ctx.address()
                );
                e.into()
            })
            // `return_` is owned here, so move it instead of cloning.
            .map(|m| (m, return_))
    }

    /// Decode the bytes of a pre-encoded router message as `M`.
    ///
    /// Decode failures are logged before being returned.
    #[inline]
    fn handle_pre_router(&mut self, msg: Vec<u8>) -> Result<M> {
        M::decode(&msg).map_err(|e| {
            error!(
                "Failed to decode wrapped router message for worker {}. \
Is your router accepting the correct message type? (ockam_core::RouterMessage)",
                self.ctx.address()
            );
            e.into()
        })
    }

    /// Main loop: initialize the worker, feed it every mailbox message, then
    /// shut it down once the mailbox yields `None`.
    ///
    /// Messages that fail to decode are dropped (the handlers log them), and
    /// worker errors are logged. In every case the per-message address is
    /// cleared before the next message.
    ///
    /// # Panics
    ///
    /// Panics if the worker's `initialize` or `shutdown` returns an error.
    async fn run(mut self) {
        self.worker.initialize(&mut self.ctx).await.unwrap();

        while let Some(RelayMessage { addr, data, onward }) = self.ctx.mailbox.next().await {
            // Expose the delivery address while this message is being handled.
            self.ctx.message_address(addr);

            let decoded = match data {
                RelayPayload::Direct(trans_msg) => self.handle_direct(trans_msg),
                RelayPayload::PreRouter(enc_msg, route) => {
                    self.handle_pre_router(enc_msg).map(|m| (m, route))
                }
            };

            // On decode failure the message is dropped; the error was already logged.
            if let Ok((msg, return_)) = decoded {
                let routed = Routed::new(msg, return_, onward);
                if let Err(e) = self.worker.handle_message(&mut self.ctx, routed).await {
                    error!(
                        "Worker {} error while handling message: {}",
                        self.ctx.address(),
                        e
                    );
                }
            }

            // Always reset, including on error paths, so no stale address lingers
            // (e.g. into `shutdown` if the mailbox closes next).
            self.ctx.message_address(None);
        }

        self.worker.shutdown(&mut self.ctx).unwrap();
    }

    /// Forward every message from `rx` into the worker's mailbox via `mb_tx`.
    ///
    /// # Panics
    ///
    /// Panics if the mailbox receiver has been dropped.
    async fn run_mailbox(mut rx: Receiver<RelayMessage>, mb_tx: Sender<RelayMessage>) {
        while let Some(msg) = rx.recv().await {
            // `SendError` returns the undelivered message, so no defensive clone
            // of each message is needed just to report its address.
            if let Err(err) = mb_tx.send(msg).await {
                panic!("Failed to route message to address '{}'", err.0.addr);
            }
        }
    }
}
/// Start the relay for `worker` on `rt` and return the sender that feeds it.
///
/// Two tasks are spawned: a forwarder that moves messages from the returned
/// sender into `ctx`'s mailbox, and the relay loop that drives `worker`.
pub(crate) fn build<W, M>(rt: &Runtime, worker: W, ctx: Context) -> Sender<RelayMessage>
where
    W: Worker<Context = Context, Message = M>,
    M: Message + Send + 'static,
{
    // Buffer size of the channel sitting in front of the worker's mailbox.
    const INBOX_CAPACITY: usize = 32;

    let (inbox_sender, inbox_receiver) = channel(INBOX_CAPACITY);
    let mailbox_sender = ctx.mailbox.sender();

    let worker_loop = Relay::<W, M>::new(worker, ctx).run();
    let forwarder = Relay::<W, M>::run_mailbox(inbox_receiver, mailbox_sender);

    rt.spawn(forwarder);
    rt.spawn(worker_loop);

    inbox_sender
}
/// Spawn only a forwarding task into `mailbox` and return its input sender.
///
/// No worker loop is started here; messages sent on the returned channel are
/// simply moved into `mailbox`.
pub(crate) fn build_root<W, M>(rt: Arc<Runtime>, mailbox: &Mailbox) -> Sender<RelayMessage>
where
    W: Worker<Context = Context, Message = M>,
    M: Message + Send + 'static,
{
    let (inbox_sender, inbox_receiver) = channel(32);
    rt.spawn(Relay::<W, M>::run_mailbox(inbox_receiver, mailbox.sender()));
    inbox_sender
}