use std::sync::Arc;
use ockam_core::{Address, AddressSet, Message, Result, Route, TransportMessage, Worker};
use tokio::{
runtime::Runtime,
sync::mpsc::{channel, Sender},
};
use crate::{
block_future,
error::Error,
node::NullWorker,
parser,
relay::{self, RelayMessage},
Cancel, Mailbox, NodeMessage,
};
/// Execution context for a worker (created by `start_worker`) or for a detached
/// context (created by `new_context`).
///
/// Holds the context's address set, a channel for sending `NodeMessage` requests
/// (presumably consumed by the node's router — confirm), the shared Tokio runtime,
/// and the mailbox that incoming messages are read from.
pub struct Context {
// First entry is the primary address (`address()`); the rest are aliases (`aliases()`).
address: AddressSet,
// Outgoing channel for control and routing requests (`NodeMessage`).
sender: Sender<NodeMessage>,
// Shared runtime; `Drop` uses it to block on `stop_worker`.
rt: Arc<Runtime>,
// Incoming messages for this context, consumed by `receive`/`receive_match`.
pub(crate) mailbox: Mailbox,
}
impl Context {
    /// Returns a shared handle to the Tokio runtime backing this context.
    ///
    /// Only the reference count is bumped; the runtime itself is not copied.
    pub fn runtime(&self) -> Arc<Runtime> {
        Arc::clone(&self.rt)
    }
}
impl Drop for Context {
fn drop(&mut self) {
let addr = self.address.first();
trace!("Running Context::drop()");
if let Err(e) = block_future(self.rt.as_ref(), async { self.stop_worker(addr).await }) {
trace!("Error occured during Context::drop(): {}", e);
};
}
}
impl Context {
pub(crate) fn new(
rt: Arc<Runtime>,
sender: Sender<NodeMessage>,
address: AddressSet,
mailbox: Mailbox,
) -> Self {
Self {
rt,
sender,
address,
mailbox,
}
}
pub fn address(&self) -> Address {
self.address.first().clone()
}
pub fn aliases(&self) -> AddressSet {
self.address
.clone()
.into_iter()
.skip(1)
.collect::<Vec<_>>()
.into()
}
pub async fn new_context<S: Into<Address>>(&self, addr: S) -> Result<Context> {
let addr = addr.into();
let ctx = NullWorker::new(Arc::clone(&self.rt), &addr, self.sender.clone());
let sender = relay::build_root::<NullWorker, _>(Arc::clone(&self.rt), &ctx.mailbox);
let (msg, mut rx) = NodeMessage::start_worker(addr.into(), sender);
self.sender
.send(msg)
.await
.map_err(|_| Error::FailedStartWorker)?;
Ok(rx
.recv()
.await
.ok_or(Error::InternalIOFailure)?
.map(|_| ctx)?)
}
pub async fn start_worker<NM, NW, S>(&self, address: S, worker: NW) -> Result<()>
where
S: Into<AddressSet>,
NM: Message + Send + 'static,
NW: Worker<Context = Context, Message = NM>,
{
let address = address.into();
let (check_addrs, mut check_rx) = NodeMessage::check_address(address.clone());
self.sender
.send(check_addrs)
.await
.map_err(|_| Error::InternalIOFailure)?;
check_rx.recv().await.ok_or(Error::InternalIOFailure)??;
let (mb_tx, mb_rx) = channel(32);
let mb = Mailbox::new(mb_rx, mb_tx);
let ctx = Context::new(self.rt.clone(), self.sender.clone(), address.clone(), mb);
let sender = relay::build::<NW, NM>(self.rt.as_ref(), worker, ctx);
let (msg, mut rx) = NodeMessage::start_worker(address, sender);
self.sender
.send(msg)
.await
.map_err(|_| Error::FailedStartWorker)?;
Ok(rx
.recv()
.await
.ok_or(Error::InternalIOFailure)?
.map(|_| ())?)
}
pub async fn stop(&mut self) -> Result<()> {
let tx = self.sender.clone();
info!("Shutting down all workers");
match tx.send(NodeMessage::StopNode).await {
Ok(()) => Ok(()),
Err(_e) => Err(Error::FailedStopNode.into()),
}
}
pub async fn stop_worker<A: Into<Address>>(&self, addr: A) -> Result<()> {
let addr = addr.into();
debug!("Shutting down worker {}", addr);
let (req, mut rx) = NodeMessage::stop_worker(addr);
self.sender.send(req).await.map_err(|e| Error::from(e))?;
Ok(rx
.recv()
.await
.ok_or(Error::InternalIOFailure)?
.map(|_| ())?)
}
pub async fn send<R, M>(&self, route: R, msg: M) -> Result<()>
where
R: Into<Route>,
M: Message + Send + 'static,
{
self.send_from_address(route, msg, self.address()).await
}
pub async fn send_from_address<R, M>(
&self,
route: R,
msg: M,
sending_address: Address,
) -> Result<()>
where
R: Into<Route>,
M: Message + Send + 'static,
{
if !self.address.as_ref().contains(&sending_address) {
return Err(Error::SenderAddressDoesntExist.into());
}
let route = route.into();
let (reply_tx, mut reply_rx) = channel(1);
let next = route.next().unwrap();
let req = NodeMessage::SenderReq(next.clone(), reply_tx);
self.sender.send(req).await.map_err(|e| Error::from(e))?;
let (addr, sender, needs_wrapping) = reply_rx
.recv()
.await
.ok_or(Error::InternalIOFailure)??
.take_sender()?;
let payload = msg.encode().unwrap();
let mut data = TransportMessage::v1(route.clone(), payload);
data.return_route.modify().append(sending_address);
let msg = if needs_wrapping {
RelayMessage::pre_router(addr, data, route)
} else {
RelayMessage::direct(addr, data, route)
};
sender.send(msg).await.map_err(|e| Error::from(e))?;
Ok(())
}
pub async fn forward(&self, data: TransportMessage) -> Result<()> {
let (reply_tx, mut reply_rx) = channel(1);
let next = data.onward_route.next().unwrap();
let req = NodeMessage::SenderReq(next.clone(), reply_tx);
self.sender.send(req).await.map_err(|e| Error::from(e))?;
let (addr, sender, needs_wrapping) = reply_rx
.recv()
.await
.ok_or(Error::InternalIOFailure)??
.take_sender()?;
let onward = data.onward_route.clone();
let msg = if needs_wrapping {
RelayMessage::pre_router(addr, data, onward)
} else {
RelayMessage::direct(addr, data, onward)
};
sender.send(msg).await.map_err(|e| Error::from(e))?;
Ok(())
}
pub async fn receive<'ctx, M: Message>(&'ctx mut self) -> Result<Cancel<'ctx, M>> {
let (msg, data, addr) = self.next_from_mailbox().await?;
Ok(Cancel::new(msg, data, addr, self))
}
pub async fn receive_match<'ctx, M, F>(&'ctx mut self, check: F) -> Result<Cancel<'ctx, M>>
where
M: Message,
F: Fn(&M) -> bool,
{
while let Ok((m, data, addr)) = self.next_from_mailbox().await {
if check(&m) {
return Ok(Cancel::new(m, data, addr, self));
} else {
let onward = data.onward_route.clone();
self.mailbox
.requeue(RelayMessage::direct(addr, data, onward))
.await;
}
}
Err(Error::FailedLoadData.into())
}
pub async fn list_workers(&self) -> Result<Vec<Address>> {
let (msg, mut reply_rx) = NodeMessage::list_workers();
self.sender.send(msg).await.map_err(|e| Error::from(e))?;
Ok(reply_rx
.recv()
.await
.ok_or(Error::InternalIOFailure)??
.take_workers()?)
}
pub async fn register<A: Into<Address>>(&self, type_: u8, addr: A) -> Result<()> {
let addr = addr.into();
let (tx, mut rx) = channel(1);
self.sender
.send(NodeMessage::Router(type_, addr, tx))
.await
.map_err(|_| Error::InternalIOFailure)?;
Ok(rx.recv().await.ok_or(Error::InternalIOFailure)??.is_ok()?)
}
async fn next_from_mailbox<M: Message>(&mut self) -> Result<(M, TransportMessage, Address)> {
loop {
let msg = self
.mailbox
.next()
.await
.ok_or_else(|| Error::FailedLoadData)?;
let (addr, data) = msg.transport();
match parser::message(&data.payload).ok() {
Some(msg) => return Ok((msg, data, addr)),
None => {
let onward = data.onward_route.clone();
self.mailbox
.requeue(RelayMessage::direct(addr, data, onward))
.await;
}
}
}
}
}