use core::sync::atomic::Ordering;
use core::time::Duration;
use ockam_core::{Message, RelayMessage, Result, Routed};
use crate::debugger;
use crate::error::*;
use crate::tokio::time::timeout;
use crate::{Context, DEFAULT_TIMEOUT};
/// Wait policy used by [`Context::receive_extended`] when waiting for the next message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum MessageWait {
    /// Give up with a timeout error once this duration has elapsed.
    Timeout(Duration),
    /// Wait without any time limit.
    Blocking,
}

/// Options controlling how [`Context::receive_extended`] waits for a message.
///
/// Built with [`MessageReceiveOptions::new`] (or `Default`), which waits at
/// most [`DEFAULT_TIMEOUT`], and then adjusted with the `with_*` / `without_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct MessageReceiveOptions {
    message_wait: MessageWait,
}

impl Default for MessageReceiveOptions {
    /// Same as [`MessageReceiveOptions::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl MessageReceiveOptions {
    /// Create options that wait at most [`DEFAULT_TIMEOUT`] for a message.
    pub fn new() -> Self {
        Self {
            message_wait: MessageWait::Timeout(DEFAULT_TIMEOUT),
        }
    }

    /// Replace the wait policy with `message_wait`.
    #[must_use]
    pub(super) fn with_message_wait(mut self, message_wait: MessageWait) -> Self {
        self.message_wait = message_wait;
        self
    }

    /// Wait at most `timeout` for a message.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.message_wait = MessageWait::Timeout(timeout);
        self
    }

    /// Wait at most `timeout` whole seconds for a message.
    #[must_use]
    pub fn with_timeout_secs(mut self, timeout: u64) -> Self {
        self.message_wait = MessageWait::Timeout(Duration::from_secs(timeout));
        self
    }

    /// Wait for a message without any time limit.
    #[must_use]
    pub fn without_timeout(mut self) -> Self {
        self.message_wait = MessageWait::Blocking;
        self
    }
}
impl Context {
    /// Wait for the next message from this context's receiver that passes
    /// incoming access control.
    ///
    /// Every dequeued message decrements `mailbox_count` and is passed to
    /// `debugger::log_incoming_message`. Messages rejected by
    /// `self.mailboxes.is_incoming_authorized` are logged and discarded, and
    /// the loop goes back to waiting for the next one.
    ///
    /// Returns `Ok(None)` when the receiver yields no further message
    /// (presumably because the channel was closed; confirm against the
    /// receiver type).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the incoming access-control check.
    pub(crate) async fn receiver_next(&mut self) -> Result<Option<RelayMessage>> {
        loop {
            let relay_msg = match self.receiver.recv().await {
                Some(msg) => msg,
                // No more messages will arrive on this receiver.
                None => return Ok(None),
            };

            trace!("{}: received new message!", self.address());

            // The message has left the mailbox, so keep the fill metric in sync.
            self.mailbox_count.fetch_sub(1, Ordering::Acquire);

            debugger::log_incoming_message(self, &relay_msg);

            if self.mailboxes.is_incoming_authorized(&relay_msg).await? {
                return Ok(Some(relay_msg));
            }

            // Unauthorized: report it and keep waiting rather than failing the caller.
            warn!(
                "Message received from {} for {} did not pass incoming access control",
                relay_msg.source(),
                relay_msg.destination()
            );
            debug!(
                "Message return_route: {:?} onward_route: {:?}",
                relay_msg.return_route(),
                relay_msg.onward_route()
            );
        }
    }

    /// Receive the next authorized message and wrap it as a [`Routed`] value.
    ///
    /// # Errors
    ///
    /// - A `NodeError::Data` not-found error if the receiver yields no more messages.
    /// - Any error propagated from [`Context::receiver_next`].
    async fn next_from_mailbox<M: Message>(&mut self) -> Result<Routed<M>> {
        let msg = self
            .receiver_next()
            .await?
            .ok_or_else(|| NodeError::Data.not_found())?;
        // Clone the addresses before `into_local_message` consumes `msg`.
        let destination_addr = msg.destination().clone();
        let src_addr = msg.source().clone();
        let local_msg = msg.into_local_message();
        Ok(Routed::new(destination_addr, src_addr, local_msg))
    }

    /// Wait for the next message, giving up after the default timeout.
    ///
    /// Equivalent to `receive_extended(MessageReceiveOptions::new())`.
    ///
    /// # Errors
    ///
    /// Same as [`Context::receive_extended`].
    pub async fn receive<M: Message>(&mut self) -> Result<Routed<M>> {
        self.receive_extended(MessageReceiveOptions::new()).await
    }

    /// Wait for the next message, following the wait policy in `options`.
    ///
    /// With a timeout policy the wait is bounded by the configured duration.
    /// With the blocking policy (`MessageReceiveOptions::without_timeout`) no
    /// time limit is applied.
    ///
    /// # Errors
    ///
    /// - A `NodeError::Data` timeout error if the configured duration elapses first.
    /// - A `NodeError::Data` not-found error if the receiver yields no more messages.
    /// - Any error raised by the incoming access-control check.
    pub async fn receive_extended<M: Message>(
        &mut self,
        options: MessageReceiveOptions,
    ) -> Result<Routed<M>> {
        match options.message_wait {
            MessageWait::Timeout(timeout_duration) => {
                // NOTE(review): assuming `crate::tokio` re-exports tokio, the
                // in-flight future is dropped when `timeout` expires. If that
                // happens after a message was dequeued but while its
                // access-control check is still pending, the message is
                // discarded and `mailbox_count` has already been decremented.
                // Confirm this is acceptable.
                timeout(timeout_duration, self.next_from_mailbox())
                    .await
                    .map_err(|_| NodeError::Data.with_timeout(timeout_duration))?
            }
            MessageWait::Blocking => self.next_from_mailbox().await,
        }
    }
}