use crate::channel_types::OneshotReceiver;
use crate::relay::CtrlSignal;
use crate::tokio::runtime::Handle;
use crate::Context;
use cfg_if::cfg_if;
use ockam_core::{Message, RelayMessage, Result, Routed, Worker};
#[cfg(feature = "std")]
use opentelemetry::trace::FutureExt;
/// Couples a user [`Worker`] with the [`Context`] it runs in.
///
/// A relay is consumed by its message loop (`run`, spawned via `init`), which
/// owns both the worker and its context for as long as the loop runs.
pub struct WorkerRelay<W> {
    /// The user-provided worker whose `initialize` / `handle_message` /
    /// `shutdown` callbacks the relay drives.
    worker: W,
    /// The worker's context; used by the relay to receive messages, read the
    /// primary address and reach the router.
    ctx: Context,
}
impl<W: Worker> WorkerRelay<W> {
pub fn new(worker: W, ctx: Context) -> Self {
Self { worker, ctx }
}
}
impl<W, M> WorkerRelay<W>
where
    W: Worker<Context = Context, Message = M>,
    M: Message + Send + 'static,
{
    /// Converts a raw [`RelayMessage`] taken from the mailbox into the
    /// [`Routed`] envelope passed to the worker's `handle_message`.
    fn wrap_direct_message(relay_msg: RelayMessage) -> Routed<M> {
        let destination = relay_msg.destination().clone();
        let source = relay_msg.source().clone();
        let payload = relay_msg.into_local_message();
        Routed::new(destination, source, payload)
    }

    /// Waits for the next message addressed to this worker and dispatches it.
    ///
    /// Returns `Ok(true)` once a message has been handled and `Ok(false)` when
    /// the receiver yields no further messages. Errors from receiving or from
    /// the worker's handler are propagated to the caller.
    async fn recv_message(&mut self) -> Result<bool> {
        let next = self.ctx.receiver_next().await?;
        let relay_msg = if let Some(msg) = next {
            msg
        } else {
            trace!("No more messages for worker {}", self.ctx.primary_address());
            return Ok(false);
        };

        cfg_if! {
            if #[cfg(feature = "std")] {
                // Propagate the message's tracing context: store it on the
                // worker context and attach it to the handler future.
                let tracing_context = relay_msg.local_message().tracing_context();
                self.ctx.set_tracing_context(tracing_context.clone());
                let routed = Self::wrap_direct_message(relay_msg);
                let handling = self.worker.handle_message(&mut self.ctx, routed);
                handling
                    .with_context(tracing_context.update().extract())
                    .await?;
            } else {
                self.worker
                    .handle_message(&mut self.ctx, Self::wrap_direct_message(relay_msg))
                    .await?;
            }
        }

        Ok(true)
    }

    /// Drives the worker for its whole lifetime.
    ///
    /// Steps:
    /// 1. Initialises the worker. On failure, logs the error, runs the
    ///    shutdown path with `stopped_from_router = false`, and returns
    ///    without handling any message.
    /// 2. Handles messages until the receiver is exhausted. With `std`, the
    ///    loop also ends when `ctrl_rx` resolves. Errors from individual
    ///    handlers are logged and do not end the loop.
    /// 3. Shuts the worker down and acknowledges the stop.
    #[cfg_attr(not(feature = "std"), allow(unused_mut))]
    #[cfg_attr(not(feature = "std"), allow(unused_variables))]
    async fn run(mut self, mut ctrl_rx: OneshotReceiver<CtrlSignal>) {
        if let Err(e) = self.worker.initialize(&mut self.ctx).await {
            error!(
                "Failure during '{}' worker initialisation: {}",
                self.ctx.primary_address(),
                e
            );
            shutdown_and_stop_ack(&mut self.worker, &mut self.ctx, false).await;
            return;
        }

        #[cfg(feature = "std")]
        loop {
            // Race the next message against the control signal; each branch
            // reports whether the loop should keep going.
            let keep_running = crate::tokio::select! {
                result = self.recv_message() => match result {
                    Ok(handled) => handled,
                    Err(e) => {
                        #[cfg(feature = "debugger")]
                        error!("Error encountered during '{}' message handling: {:?}", self.ctx.primary_address(), e);
                        #[cfg(not(feature = "debugger"))]
                        error!("Error encountered during '{}' message handling: {}", self.ctx.primary_address(), e);
                        true
                    }
                },
                _ = &mut ctrl_rx => {
                    debug!(primary_address=%self.ctx.primary_address(), "Relay received shutdown signal, terminating!");
                    false
                }
            };
            if !keep_running {
                break;
            }
        }

        #[cfg(not(feature = "std"))]
        loop {
            match self.recv_message().await {
                Ok(true) => continue,
                Ok(false) => break,
                Err(e) => error!(
                    "Error encountered during '{}' message handling: {}",
                    self.ctx.primary_address(),
                    e
                ),
            }
        }

        // The loops above only exit on the control signal or an exhausted
        // receiver; both are reported as `stopped_from_router = true`.
        shutdown_and_stop_ack(&mut self.worker, &mut self.ctx, true).await;
    }

    /// Builds a relay for `worker` and spawns its message loop on `rt`.
    ///
    /// The spawned task's handle is not kept.
    pub(crate) fn init(
        rt: &Handle,
        worker: W,
        ctx: Context,
        ctrl_rx: OneshotReceiver<CtrlSignal>,
    ) {
        let task = WorkerRelay::new(worker, ctx).run(ctrl_rx);
        rt.spawn(task);
    }
}
/// Shuts `worker` down and acknowledges the stop to the router.
///
/// Steps:
/// 1. Calls the worker's `shutdown`. A failure there is logged and does not
///    prevent the remaining steps.
/// 2. If `stopped_from_router` is `false`, asks the router to stop the
///    worker's primary address.
/// 3. Sends a stop ACK for the primary address.
///
/// Every failure is logged rather than returned. If the router cannot be
/// obtained, steps 2 and 3 are skipped.
async fn shutdown_and_stop_ack<W>(worker: &mut W, ctx: &mut Context, stopped_from_router: bool)
where
    W: Worker<Context = Context>,
{
    if let Err(e) = worker.shutdown(ctx).await {
        error!(
            "Failure during '{}' worker shutdown: {}",
            ctx.primary_address(),
            e
        );
    }

    let router = match ctx.router() {
        Ok(router) => router,
        Err(e) => {
            // Keep the underlying cause in the log. Without a router neither
            // the stop request nor the ACK can be sent.
            error!(
                "Failure during '{}' worker shutdown. Can't get router: {}",
                ctx.primary_address(),
                e
            );
            return;
        }
    };

    if !stopped_from_router {
        // This branch only runs when `stopped_from_router` is false, so the
        // flag previously passed as `!stopped_from_router` is always `true`.
        if let Err(e) = router.stop_address(ctx.primary_address(), true) {
            error!(
                "Failure during '{}' worker shutdown: {}",
                ctx.primary_address(),
                e
            );
        }
    }

    trace!("Sending shutdown ACK");
    if let Err(e) = router.stop_ack(ctx.primary_address()) {
        error!(
            "Failed to send stop ACK for worker '{}': {}",
            ctx.primary_address(),
            e
        );
    }
}