use crate::callbacks::CallbackType;
use crate::children::Children;
use crate::context::{BastionId, ContextState};
use crate::envelope::{RefAddr, SignedMessage};
use crate::supervisor::{SupervisionStrategy, Supervisor};
use futures::channel::oneshot::{self, Receiver};
use std::any::{type_name, Any};
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tracing::{debug, trace};
/// Marker trait for values that can be carried as a message payload.
///
/// It declares no methods; it only bundles the `Any + Send + Sync + Debug`
/// bounds under a single name.
pub trait Message: Any + Send + Sync + Debug {}

// Blanket implementation: any type meeting the bounds is a `Message`.
impl<T: Any + Send + Sync + Debug> Message for T {}
/// Sending half of the reply channel that `Msg::ask` creates for a question.
///
/// Field 0 is the oneshot sender the answer is pushed through; field 1 is the
/// `RefAddr` that `reply` uses to sign the answer it sends.
#[derive(Debug)]
pub struct AnswerSender(oneshot::Sender<SignedMessage>, RefAddr);
/// Receiving half of the reply channel that `Msg::ask` creates for a question.
///
/// It implements `Future` in this file: polling forwards to the wrapped oneshot
/// receiver and yields `Ok(SignedMessage)`, or `Err(())` when the receiver
/// reports an error.
#[derive(Debug)]
pub struct Answer(Receiver<SignedMessage>);
/// Type-erased message payload.
///
/// The private `MsgInner` records how the payload was sent (broadcast, tell or
/// ask). The payload's concrete type is recovered with `is`, `downcast` and
/// `downcast_ref`.
#[derive(Debug)]
pub struct Msg(MsgInner);
/// Storage for a `Msg`, one variant per delivery kind.
#[derive(Debug)]
enum MsgInner {
/// Payload shared behind an `Arc`; this is the only variant `Msg::try_clone`
/// can duplicate.
Broadcast(Arc<dyn Any + Send + Sync + 'static>),
/// Uniquely owned, boxed payload with no reply channel.
Tell(Box<dyn Any + Send + Sync + 'static>),
/// Uniquely owned, boxed payload plus the sender used to answer it.
Ask {
msg: Box<dyn Any + Send + Sync + 'static>,
// `Some` when built by `Msg::ask`; `Msg::take_sender` leaves `None` behind.
sender: Option<AnswerSender>,
},
}
/// Crate-internal message type (`pub(crate)`).
///
/// Every variant has a snake_case constructor in `impl BastionMessage` (for
/// example `Prune { id }` is built by `prune(id)`); user payloads travel in the
/// `Message` variant.
///
/// NOTE(review): the per-variant notes below describe the payload and what the
/// variant name suggests. How a receiver reacts to each message is decided by
/// code outside this file — confirm against the handlers.
#[derive(Debug)]
pub(crate) enum BastionMessage {
/// Unit variant; presumably a start request.
Start,
/// Unit variant; presumably a stop request.
Stop,
/// Unit variant; presumably a kill request.
Kill,
/// Carries a boxed `Deployment` (a supervisor or a children group).
Deploy(Box<Deployment>),
/// Carries the id of a single element; presumably the element to prune.
Prune {
id: BastionId,
},
/// Carries a `SupervisionStrategy`.
SuperviseWith(SupervisionStrategy),
/// Carries a `CallbackType`.
ApplyCallback(CallbackType),
/// Carries a parent id, a child id and a shared, pinned `ContextState`.
InstantiatedChild {
parent_id: BastionId,
child_id: BastionId,
state: Arc<Pin<Box<ContextState>>>,
},
/// Carries a user payload (`Msg`) sent via broadcast, tell or ask.
Message(Msg),
/// Carries the id of an element and the id of its parent.
RestartRequired {
id: BastionId,
parent_id: BastionId,
},
/// Carries the id of an element and the id of its parent.
FinishedChild {
id: BastionId,
parent_id: BastionId,
},
/// Unit variant; presumably asks for a subtree restart.
RestartSubtree,
/// Carries an element id and a shared, pinned `ContextState`.
RestoreChild {
id: BastionId,
state: Arc<Pin<Box<ContextState>>>,
},
/// Carries the id of a single element.
DropChild {
id: BastionId,
},
/// Carries a shared, pinned `ContextState`.
SetState {
state: Arc<Pin<Box<ContextState>>>,
},
/// Carries the id of a single element; presumably one that stopped.
Stopped {
id: BastionId,
},
/// Carries the id of a single element; presumably one that faulted.
Faulted {
id: BastionId,
},
/// Unit variant.
Heartbeat,
}
/// Payload of `BastionMessage::Deploy`: the thing being deployed.
#[derive(Debug)]
pub(crate) enum Deployment {
/// A `Supervisor`, as wrapped by `BastionMessage::deploy_supervisor`.
Supervisor(Supervisor),
/// A `Children` group, as wrapped by `BastionMessage::deploy_children`.
Children(Children),
}
impl AnswerSender {
    /// Sends `msg` as the answer to the question this sender belongs to.
    ///
    /// The answer is wrapped as a "tell" message, signed with the address
    /// stored in this sender, and pushed through the oneshot channel.
    ///
    /// # Errors
    ///
    /// When the channel's `send` fails, the undelivered message is unwrapped
    /// again and the original `msg` is handed back as `Err(msg)`.
    pub fn reply<M: Message>(self, msg: M) -> Result<(), M> {
        debug!("{:?}: Sending answer: {:?}", self, msg);
        let wrapped = Msg::tell(msg);
        trace!("{:?}: Sending message: {:?}", self, wrapped);

        let Self(channel, signature) = self;
        match channel.send(SignedMessage::new(wrapped, signature)) {
            Ok(()) => Ok(()),
            // The payload was boxed as an `M` a few lines up, so unwrapping it
            // back to `M` is expected to succeed.
            Err(undelivered) => Err(undelivered.msg.try_unwrap().unwrap()),
        }
    }
}
impl Msg {
    /// Wraps `msg` as a broadcast payload, shared behind an `Arc`.
    pub(crate) fn broadcast<M: Message>(msg: M) -> Self {
        Msg(MsgInner::Broadcast(Arc::new(msg)))
    }

    /// Wraps `msg` as a "tell" payload: boxed, uniquely owned, no reply channel.
    pub(crate) fn tell<M: Message>(msg: M) -> Self {
        Msg(MsgInner::Tell(Box::new(msg)))
    }

    /// Wraps `msg` as an "ask" payload and creates its reply channel.
    ///
    /// Returns the message (holding the sending half, signed with `sign`) and
    /// the `Answer` future holding the receiving half.
    pub(crate) fn ask<M: Message>(msg: M, sign: RefAddr) -> (Self, Answer) {
        let (reply_tx, reply_rx) = oneshot::channel();
        let question = MsgInner::Ask {
            msg: Box::new(msg),
            sender: Some(AnswerSender(reply_tx, sign)),
        };
        (Msg(question), Answer(reply_rx))
    }

    /// `true` when this message was created by `broadcast`.
    #[doc(hidden)]
    pub fn is_broadcast(&self) -> bool {
        matches!(&self.0, MsgInner::Broadcast(..))
    }

    /// `true` when this message was created by `tell`.
    #[doc(hidden)]
    pub fn is_tell(&self) -> bool {
        matches!(&self.0, MsgInner::Tell(..))
    }

    /// `true` when this message was created by `ask`.
    #[doc(hidden)]
    pub fn is_ask(&self) -> bool {
        matches!(&self.0, MsgInner::Ask { .. })
    }

    /// Removes and returns the reply sender of an "ask" message.
    ///
    /// Returns `None` for other message kinds, or when the sender was already
    /// taken by an earlier call.
    #[doc(hidden)]
    pub fn take_sender(&mut self) -> Option<AnswerSender> {
        debug!("{:?}: Taking sender.", self);
        match &mut self.0 {
            MsgInner::Ask { sender, .. } => sender.take(),
            _ => None,
        }
    }

    /// `true` when the payload's concrete type is `M`, whatever the delivery kind.
    #[doc(hidden)]
    pub fn is<M: Message>(&self) -> bool {
        match &self.0 {
            MsgInner::Broadcast(shared) => shared.is::<M>(),
            MsgInner::Tell(boxed) | MsgInner::Ask { msg: boxed, .. } => boxed.is::<M>(),
        }
    }

    /// Moves the payload out as an `M`.
    ///
    /// Only "tell" and "ask" payloads can be moved out. A broadcast message, or
    /// a payload of a different type, comes back unchanged as `Err`. When an
    /// "ask" payload is extracted, any reply sender still stored in the message
    /// is dropped with it.
    #[doc(hidden)]
    pub fn downcast<M: Message>(self) -> Result<M, Self> {
        trace!("{:?}: Downcasting to {}.", self, type_name::<M>());
        match self.0 {
            MsgInner::Tell(boxed) | MsgInner::Ask { msg: boxed, .. } if boxed.is::<M>() => {
                let boxed: Box<dyn Any + 'static> = boxed;
                // The guard has just confirmed the concrete type.
                Ok(*boxed.downcast().unwrap())
            }
            other => Err(Msg(other)),
        }
    }

    /// Returns a new `Arc` handle to a broadcast payload of type `M`.
    ///
    /// Returns `None` for non-broadcast messages or when the type differs.
    #[doc(hidden)]
    pub fn downcast_ref<M: Message>(&self) -> Option<Arc<M>> {
        trace!("{:?}: Downcasting to ref of {}.", self, type_name::<M>());
        match &self.0 {
            MsgInner::Broadcast(shared) if shared.is::<M>() => {
                Some(Arc::clone(shared).downcast::<M>().unwrap())
            }
            _ => None,
        }
    }

    /// Duplicates a broadcast message by cloning its `Arc`; other kinds give `None`.
    pub(crate) fn try_clone(&self) -> Option<Self> {
        trace!("{:?}: Trying to clone.", self);
        match &self.0 {
            MsgInner::Broadcast(shared) => Some(Msg(MsgInner::Broadcast(Arc::clone(shared)))),
            _ => None,
        }
    }

    /// Moves the payload out as an `M`, whatever the delivery kind.
    ///
    /// A broadcast payload is only released when its type is `M` and
    /// `Arc::try_unwrap` succeeds; otherwise the message is rebuilt and
    /// returned as `Err`. "Tell" and "ask" messages are delegated to `downcast`.
    pub(crate) fn try_unwrap<M: Message>(self) -> Result<M, Self> {
        debug!("{:?}: Trying to unwrap.", self);
        let shared = match self.0 {
            MsgInner::Broadcast(shared) => shared,
            owned => return Msg(owned).downcast(),
        };

        match shared.downcast::<M>() {
            Ok(typed) => Arc::try_unwrap(typed)
                .map_err(|still_shared| Msg(MsgInner::Broadcast(still_shared))),
            Err(untyped) => Err(Msg(MsgInner::Broadcast(untyped))),
        }
    }
}
/// Borrows the payload as `&dyn Any`, whatever the delivery kind.
impl AsRef<dyn Any> for Msg {
    fn as_ref(&self) -> &dyn Any {
        match &self.0 {
            MsgInner::Tell(boxed) | MsgInner::Ask { msg: boxed, .. } => &**boxed,
            MsgInner::Broadcast(shared) => &**shared,
        }
    }
}
impl BastionMessage {
    /// Builds a `BastionMessage::Start`.
    pub(crate) fn start() -> Self {
        BastionMessage::Start
    }

    /// Builds a `BastionMessage::Stop`.
    pub(crate) fn stop() -> Self {
        BastionMessage::Stop
    }

    /// Builds a `BastionMessage::Kill`.
    pub(crate) fn kill() -> Self {
        BastionMessage::Kill
    }

    /// Builds a `BastionMessage::Deploy` carrying a boxed supervisor deployment.
    pub(crate) fn deploy_supervisor(supervisor: Supervisor) -> Self {
        let deployment = Deployment::Supervisor(supervisor);
        BastionMessage::Deploy(Box::new(deployment))
    }

    /// Builds a `BastionMessage::Deploy` carrying a boxed children deployment.
    pub(crate) fn deploy_children(children: Children) -> Self {
        let deployment = Deployment::Children(children);
        BastionMessage::Deploy(Box::new(deployment))
    }

    /// Builds a `BastionMessage::Prune` for `id`.
    pub(crate) fn prune(id: BastionId) -> Self {
        BastionMessage::Prune { id }
    }

    /// Builds a `BastionMessage::SuperviseWith` carrying `strategy`.
    pub(crate) fn supervise_with(strategy: SupervisionStrategy) -> Self {
        BastionMessage::SuperviseWith(strategy)
    }

    /// Builds a `BastionMessage::ApplyCallback` carrying `callback_type`.
    pub(crate) fn apply_callback(callback_type: CallbackType) -> Self {
        BastionMessage::ApplyCallback(callback_type)
    }

    /// Builds a `BastionMessage::InstantiatedChild` from its three fields.
    pub(crate) fn instantiated_child(
        parent_id: BastionId,
        child_id: BastionId,
        state: Arc<Pin<Box<ContextState>>>,
    ) -> Self {
        BastionMessage::InstantiatedChild {
            parent_id,
            child_id,
            state,
        }
    }

    /// Wraps `msg` as a broadcast `Msg` inside `BastionMessage::Message`.
    pub(crate) fn broadcast<M: Message>(msg: M) -> Self {
        BastionMessage::Message(Msg::broadcast(msg))
    }

    /// Wraps `msg` as a "tell" `Msg` inside `BastionMessage::Message`.
    pub(crate) fn tell<M: Message>(msg: M) -> Self {
        BastionMessage::Message(Msg::tell(msg))
    }

    /// Wraps `msg` as an "ask" `Msg` inside `BastionMessage::Message`.
    ///
    /// Also returns the `Answer` future through which the reply is received.
    pub(crate) fn ask<M: Message>(msg: M, sign: RefAddr) -> (Self, Answer) {
        let (msg, answer) = Msg::ask(msg, sign);
        (BastionMessage::Message(msg), answer)
    }

    /// Builds a `BastionMessage::RestartRequired` for `id` under `parent_id`.
    pub(crate) fn restart_required(id: BastionId, parent_id: BastionId) -> Self {
        BastionMessage::RestartRequired { id, parent_id }
    }

    /// Builds a `BastionMessage::FinishedChild` for `id` under `parent_id`.
    pub(crate) fn finished_child(id: BastionId, parent_id: BastionId) -> Self {
        BastionMessage::FinishedChild { id, parent_id }
    }

    /// Builds a `BastionMessage::RestartSubtree`.
    pub(crate) fn restart_subtree() -> Self {
        BastionMessage::RestartSubtree
    }

    /// Builds a `BastionMessage::RestoreChild` for `id` with `state`.
    pub(crate) fn restore_child(id: BastionId, state: Arc<Pin<Box<ContextState>>>) -> Self {
        BastionMessage::RestoreChild { id, state }
    }

    /// Builds a `BastionMessage::DropChild` for `id`.
    pub(crate) fn drop_child(id: BastionId) -> Self {
        BastionMessage::DropChild { id }
    }

    /// Builds a `BastionMessage::SetState` carrying `state`.
    pub(crate) fn set_state(state: Arc<Pin<Box<ContextState>>>) -> Self {
        BastionMessage::SetState { state }
    }

    /// Builds a `BastionMessage::Stopped` for `id`.
    pub(crate) fn stopped(id: BastionId) -> Self {
        BastionMessage::Stopped { id }
    }

    /// Builds a `BastionMessage::Faulted` for `id`.
    pub(crate) fn faulted(id: BastionId) -> Self {
        BastionMessage::Faulted { id }
    }

    /// Builds a `BastionMessage::Heartbeat`.
    pub(crate) fn heartbeat() -> Self {
        BastionMessage::Heartbeat
    }

    /// Attempts to duplicate this message.
    ///
    /// Returns `None` when the message cannot be duplicated:
    ///
    /// - `Deploy` — this function has no way to copy the boxed deployment;
    /// - `Message` holding a non-broadcast `Msg` (see `Msg::try_clone`).
    ///
    /// Every other variant is rebuilt from clones of its fields.
    pub(crate) fn try_clone(&self) -> Option<Self> {
        trace!("{:?}: Trying to clone.", self);
        let clone = match self {
            BastionMessage::Start => BastionMessage::Start,
            BastionMessage::Stop => BastionMessage::Stop,
            BastionMessage::Kill => BastionMessage::Kill,
            // Report "not clonable" through the `Option` this function already
            // returns instead of panicking.
            BastionMessage::Deploy(_) => return None,
            BastionMessage::Prune { id } => BastionMessage::Prune { id: id.clone() },
            BastionMessage::SuperviseWith(strategy) => {
                BastionMessage::SuperviseWith(strategy.clone())
            }
            BastionMessage::ApplyCallback(callback_type) => {
                BastionMessage::ApplyCallback(callback_type.clone())
            }
            BastionMessage::InstantiatedChild {
                parent_id,
                child_id,
                state,
            } => BastionMessage::InstantiatedChild {
                parent_id: parent_id.clone(),
                child_id: child_id.clone(),
                state: Arc::clone(state),
            },
            BastionMessage::Message(msg) => BastionMessage::Message(msg.try_clone()?),
            BastionMessage::RestartRequired { id, parent_id } => BastionMessage::RestartRequired {
                id: id.clone(),
                parent_id: parent_id.clone(),
            },
            BastionMessage::FinishedChild { id, parent_id } => BastionMessage::FinishedChild {
                id: id.clone(),
                parent_id: parent_id.clone(),
            },
            BastionMessage::RestartSubtree => BastionMessage::RestartSubtree,
            BastionMessage::RestoreChild { id, state } => BastionMessage::RestoreChild {
                id: id.clone(),
                state: Arc::clone(state),
            },
            BastionMessage::DropChild { id } => BastionMessage::DropChild { id: id.clone() },
            BastionMessage::SetState { state } => BastionMessage::SetState {
                state: Arc::clone(state),
            },
            BastionMessage::Stopped { id } => BastionMessage::Stopped { id: id.clone() },
            BastionMessage::Faulted { id } => BastionMessage::Faulted { id: id.clone() },
            BastionMessage::Heartbeat => BastionMessage::Heartbeat,
        };
        Some(clone)
    }

    /// Extracts the user payload as an `M`.
    ///
    /// Returns `None` when this is not a `Message` variant or when
    /// `Msg::try_unwrap` cannot release the payload as an `M`.
    pub(crate) fn into_msg<M: Message>(self) -> Option<M> {
        match self {
            BastionMessage::Message(msg) => msg.try_unwrap().ok(),
            _ => None,
        }
    }
}
/// Resolves to the signed reply once it arrives; any error reported by the
/// wrapped receiver is flattened to `Err(())`.
impl Future for Answer {
    type Output = Result<SignedMessage, ()>;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        debug!("{:?}: Polling.", self);
        let receiver = &mut self.get_mut().0;
        match Pin::new(receiver).poll(ctx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(reply)) => Poll::Ready(Ok(reply)),
            Poll::Ready(Err(_)) => Poll::Ready(Err(())),
        }
    }
}
/// Dispatches on a signed message by delivery kind and payload type.
///
/// Invocation shape: `msg!(message, arm; arm; ... fallback;)`. Each typed arm
/// is one of:
///
/// - `ref var: Type => expr;` — tried when the message reports
///   `is_broadcast()`; `var` is a reference obtained through `downcast_ref`.
/// - `var: Type => expr;` — tried when the message is not a broadcast and no
///   sender could be taken from it; `var` is the owned value from `downcast`.
/// - `var: Type =!> expr;` — tried when the message is not a broadcast and a
///   sender could be taken from it; `var` is the owned value from `downcast`,
///   and a local `answer!(ctx, value)` macro that replies through that sender
///   is defined for the branch.
///
/// The last arm must be the fallback, written `var: _ => expr;` or
/// `_: _ => expr;`; it runs when no typed arm of the selected branch matches.
/// Before any handler is expanded, a local `signature!()` macro is defined that
/// expands to the signature returned by `extract()`.
#[macro_export]
macro_rules! msg {
// Entry point: start accumulating arms into three empty lists
// (broadcast arms, tell arms, ask arms).
($msg:expr, $($tokens:tt)+) => {
msg!(@internal $msg, (), (), (), $($tokens)+)
};
// `ref var: Ty => handler;` is appended to the first (broadcast) list.
(@internal
$msg:expr,
($($bvar:ident, $bty:ty, $bhandle:expr,)*),
($($tvar:ident, $tty:ty, $thandle:expr,)*),
($($avar:ident, $aty:ty, $ahandle:expr,)*),
ref $var:ident: $ty:ty => $handle:expr;
$($rest:tt)+
) => {
msg!(@internal $msg,
($($bvar, $bty, $bhandle,)* $var, $ty, $handle,),
($($tvar, $tty, $thandle,)*),
($($avar, $aty, $ahandle,)*),
$($rest)+
)
};
// `var: Ty => handler;` is appended to the second (tell) list.
(@internal
$msg:expr,
($($bvar:ident, $bty:ty, $bhandle:expr,)*),
($($tvar:ident, $tty:ty, $thandle:expr,)*),
($($avar:ident, $aty:ty, $ahandle:expr,)*),
$var:ident: $ty:ty => $handle:expr;
$($rest:tt)+
) => {
msg!(@internal $msg,
($($bvar, $bty, $bhandle,)*),
($($tvar, $tty, $thandle,)* $var, $ty, $handle,),
($($avar, $aty, $ahandle,)*),
$($rest)+
)
};
// `var: Ty =!> handler;` is appended to the third (ask) list.
(@internal
$msg:expr,
($($bvar:ident, $bty:ty, $bhandle:expr,)*),
($($tvar:ident, $tty:ty, $thandle:expr,)*),
($($avar:ident, $aty:ty, $ahandle:expr,)*),
$var:ident: $ty:ty =!> $handle:expr;
$($rest:tt)+
) => {
msg!(@internal $msg,
($($bvar, $bty, $bhandle,)*),
($($tvar, $tty, $thandle,)*),
($($avar, $aty, $ahandle,)* $var, $ty, $handle,),
$($rest)+
)
};
// Fallback written as `_: _ => handler;`: rewritten to the named form
// below, using `msg` as the variable name.
(@internal
$msg:expr,
($($bvar:ident, $bty:ty, $bhandle:expr,)*),
($($tvar:ident, $tty:ty, $thandle:expr,)*),
($($avar:ident, $aty:ty, $ahandle:expr,)*),
_: _ => $handle:expr;
) => {
msg!(@internal $msg,
($($bvar, $bty, $bhandle,)*),
($($tvar, $tty, $thandle,)*),
($($avar, $aty, $ahandle,)*),
msg: _ => $handle;
)
};
// Terminal rule, reached at the fallback arm `var: _ => handler;`:
// emits the dispatch code over the three accumulated lists.
(@internal
$msg:expr,
($($bvar:ident, $bty:ty, $bhandle:expr,)*),
($($tvar:ident, $tty:ty, $thandle:expr,)*),
($($avar:ident, $aty:ty, $ahandle:expr,)*),
$var:ident: _ => $handle:expr;
) => { {
// Split the signed message into the message itself (bound to the
// fallback's variable name) and its signature.
let mut signed = $msg;
let (mut $var, sign) = signed.extract();
// `signature!()` expands to the signature extracted above.
macro_rules! signature {
() => {
sign
};
}
// Taken before branching; whether a sender was present selects between
// the ask branch and the tell branch below.
let sender = $var.take_sender();
if $var.is_broadcast() {
// The always-false `if` gives the generated `else if` chain a head to
// attach to, even when the arm list is empty.
if false {
unreachable!();
}
$(
else if $var.is::<$bty>() {
let $bvar = &*$var.downcast_ref::<$bty>().unwrap();
{ $bhandle }
}
)*
else {
{ $handle }
}
} else if sender.is_some() {
let sender = sender.unwrap();
// Local `answer!(ctx, value)`: evaluates `ctx.signature()` into a
// binding that is not used further here, then replies with `value`
// through the sender taken above.
macro_rules! answer {
($ctx:expr, $answer:expr) => {
{
let sign = $ctx.signature();
sender.reply($answer)
}
};
}
if false {
unreachable!();
}
$(
else if $var.is::<$aty>() {
let $avar = $var.downcast::<$aty>().unwrap();
{ $ahandle }
}
)*
else {
{ $handle }
}
} else {
// Neither a broadcast nor a message with a sender: try the tell arms.
if false {
unreachable!();
}
$(
else if $var.is::<$tty>() {
let $tvar = $var.downcast::<$tty>().unwrap();
{ $thandle }
}
)*
else {
{ $handle }
}
}
} };
}
/// Replies to a signed "ask" message.
///
/// `answer!(message, value)` calls `extract()` on `message`, takes the reply
/// sender out of the extracted message and evaluates to the result of
/// `sender.reply(value)`.
///
/// # Panics
///
/// Panics with "failed to take sender" when `take_sender()` returns `None`.
#[macro_export]
macro_rules! answer {
    ($msg:expr, $answer:expr) => {{
        // The signature is not needed to reply; the underscore-prefixed name
        // keeps it alive until the end of the block without an unused warning.
        let (mut msg, _sign) = $msg.extract();
        let sender = msg.take_sender().expect("failed to take sender");
        sender.reply($answer)
    }};
}
/// Progress of a `MessageHandler` chain.
#[derive(Debug)]
enum MessageHandlerState<O> {
/// A handler has already run; holds the output it produced.
Matched(O),
/// No handler has matched yet; holds the message still waiting to be matched.
Unmatched(SignedMessage),
}
impl<O> MessageHandlerState<O> {
fn take_message(self) -> Result<SignedMessage, O> {
match self {
MessageHandlerState::Unmatched(msg) => Ok(msg),
MessageHandlerState::Matched(output) => Err(output),
}
}
fn output_or_else(self, f: impl FnOnce(SignedMessage) -> O) -> O {
match self {
MessageHandlerState::Matched(output) => output,
MessageHandlerState::Unmatched(msg) => f(msg),
}
}
}
/// Chainable matcher over a single `SignedMessage`.
///
/// Created with `new`. Each `on_question`, `on_tell` or `on_broadcast` call
/// tries its handler against the message and passes the chain on. The chain is
/// finished with `on_fallback`, which yields the output of type `O`.
#[derive(Debug)]
pub struct MessageHandler<O> {
// Either the output of the handler that matched, or the message still
// waiting to be matched.
state: MessageHandlerState<O>,
}
impl<O> MessageHandler<O> {
    /// Starts a handler chain for `msg`; nothing has matched yet.
    pub fn new(msg: SignedMessage) -> MessageHandler<O> {
        MessageHandler {
            state: MessageHandlerState::Unmatched(msg),
        }
    }

    /// Runs `f` when the message is an unanswered "ask" whose payload is a `T`.
    ///
    /// `f` receives the payload and the sender used to reply. If the message
    /// does not match, or an earlier handler already matched, the chain is
    /// passed on unchanged.
    pub fn on_question<T, F>(self, f: F) -> MessageHandler<O>
    where
        T: 'static,
        F: FnOnce(T, AnswerSender) -> O,
    {
        self.try_into_question::<T>()
            .map(|(question, sender)| Self::matched(f(question, sender)))
            .unwrap_or_else(|unchanged| unchanged)
    }

    /// Ends the chain: returns the output of the handler that matched, or
    /// calls `f` with the untyped payload and the message's signature.
    pub fn on_fallback<F>(self, f: F) -> O
    where
        F: FnOnce(&dyn Any, RefAddr) -> O,
    {
        self.state.output_or_else(|unhandled| {
            let SignedMessage { msg, sign } = unhandled;
            f(msg.as_ref(), sign)
        })
    }

    /// Runs `f` when the message is a broadcast whose payload is a `T`.
    ///
    /// `f` receives a reference to the shared payload and the message's
    /// signature. Otherwise the chain is passed on unchanged.
    pub fn on_broadcast<T, F>(self, f: F) -> MessageHandler<O>
    where
        T: 'static + Send + Sync,
        F: FnOnce(&T, RefAddr) -> O,
    {
        self.try_into_broadcast::<T>()
            .map(|(shared, addr)| Self::matched(f(&*shared, addr)))
            .unwrap_or_else(|unchanged| unchanged)
    }

    /// Runs `f` when the message is a "tell" whose payload is a `T`.
    ///
    /// `f` receives the owned payload and the message's signature. Otherwise
    /// the chain is passed on unchanged.
    pub fn on_tell<T, F>(self, f: F) -> MessageHandler<O>
    where
        T: Debug + 'static,
        F: FnOnce(T, RefAddr) -> O,
    {
        self.try_into_tell::<T>()
            .map(|(told, addr)| Self::matched(f(told, addr)))
            .unwrap_or_else(|unchanged| unchanged)
    }

    /// Builds a chain that has already produced `output`.
    fn matched(output: O) -> MessageHandler<O> {
        MessageHandler {
            state: MessageHandlerState::Matched(output),
        }
    }

    /// Extracts the payload and reply sender when the pending message is an
    /// "ask" of type `T` that still holds its sender; otherwise returns the
    /// chain unchanged as `Err`.
    fn try_into_question<T: 'static>(self) -> Result<(T, AnswerSender), MessageHandler<O>> {
        debug!("try_into_question with type {}", type_name::<T>());
        let pending = match self.state.take_message() {
            Ok(pending) => pending,
            Err(output) => return Err(MessageHandler::matched(output)),
        };

        match pending {
            SignedMessage {
                msg:
                    Msg(MsgInner::Ask {
                        msg,
                        sender: Some(sender),
                    }),
                ..
            } if msg.is::<T>() => {
                let payload: Box<dyn Any> = msg;
                Ok((*payload.downcast::<T>().unwrap(), sender))
            }
            unmatched => Err(MessageHandler::new(unmatched)),
        }
    }

    /// Extracts the shared payload and signature when the pending message is a
    /// broadcast of type `T`; otherwise returns the chain unchanged as `Err`.
    fn try_into_broadcast<T: Send + Sync + 'static>(
        self,
    ) -> Result<(Arc<T>, RefAddr), MessageHandler<O>> {
        debug!("try_into_broadcast with type {}", type_name::<T>());
        let pending = match self.state.take_message() {
            Ok(pending) => pending,
            Err(output) => return Err(MessageHandler::matched(output)),
        };

        match pending {
            SignedMessage {
                msg: Msg(MsgInner::Broadcast(msg)),
                sign,
            } if msg.is::<T>() => {
                let payload: Arc<dyn Any + Send + Sync + 'static> = msg;
                Ok((payload.downcast::<T>().unwrap(), sign))
            }
            unmatched => Err(MessageHandler::new(unmatched)),
        }
    }

    /// Extracts the owned payload and signature when the pending message is a
    /// "tell" of type `T`; otherwise returns the chain unchanged as `Err`.
    fn try_into_tell<T: Debug + 'static>(self) -> Result<(T, RefAddr), MessageHandler<O>> {
        debug!("try_into_tell with type {}", type_name::<T>());
        let pending = match self.state.take_message() {
            Ok(pending) => pending,
            Err(output) => return Err(MessageHandler::matched(output)),
        };

        match pending {
            SignedMessage {
                msg: Msg(MsgInner::Tell(msg)),
                sign,
            } if msg.is::<T>() => {
                let payload: Box<dyn Any> = msg;
                Ok((*payload.downcast::<T>().unwrap(), sign))
            }
            unmatched => Err(MessageHandler::new(unmatched)),
        }
    }
}