use crate::broadcast::Broadcast;
use crate::callbacks::{CallbackType, Callbacks};
use crate::child_ref::ChildRef;
use crate::context::{BastionContext, BastionId, ContextState};
use crate::envelope::Envelope;
use crate::message::BastionMessage;
#[cfg(feature = "scaling")]
use crate::resizer::ActorGroupStats;
use crate::system::SYSTEM;
use anyhow::Result as AnyResult;
use async_mutex::Mutex;
use bastion_executor::pool;
use futures::pending;
use futures::poll;
use futures::prelude::*;
use lightproc::prelude::*;
use lightproc::proc_state::EmptyProcState;
use std::fmt::{self, Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tracing::{debug, error, trace, warn};
pub(crate) struct Init(pub(crate) Box<dyn Fn(BastionContext) -> Exec + Send>);
pub(crate) struct Exec(pub(crate) Pin<Box<dyn Future<Output = Result<(), ()>> + Send>>);
#[derive(Debug)]
pub(crate) struct Child {
bcast: Broadcast,
callbacks: Callbacks,
exec: Exec,
state: Arc<Mutex<Pin<Box<ContextState>>>>,
pre_start_msgs: Vec<Envelope>,
child_ref: ChildRef,
started: bool,
}
impl Init {
pub(crate) fn new<C, F>(init: C) -> Self
where
C: Fn(BastionContext) -> F + Send + 'static,
F: Future<Output = Result<(), ()>> + Send + 'static,
{
let init = Box::new(move |ctx: BastionContext| {
let fut = init(ctx);
let exec = Box::pin(fut);
Exec(exec)
});
Init(init)
}
}
impl Child {
pub(crate) fn new(
exec: Exec,
callbacks: Callbacks,
bcast: Broadcast,
state: Arc<Mutex<Pin<Box<ContextState>>>>,
child_ref: ChildRef,
) -> Self {
debug!("Child({}): Initializing.", bcast.id());
let pre_start_msgs = Vec::new();
let started = false;
Child {
bcast,
callbacks,
exec,
state,
pre_start_msgs,
child_ref,
started,
}
}
fn stack(&self) -> ProcStack {
trace!("Child({}): Creating ProcStack.", self.id());
let id = self.bcast.id().clone();
let parent = self.bcast.parent().clone().into_children().unwrap();
let path = self.bcast.path().clone();
let sender = self.bcast.sender().clone();
let parent_inner = self.bcast.parent().clone().into_children();
let child_ref_inner = self.child_ref.clone();
ProcStack::default().with_after_panic(move |_state: &mut EmptyProcState| {
warn!("Child({}): Panicked.", id);
if let Some(parent) = &parent_inner {
let used_dispatchers = parent.dispatchers();
let global_dispatcher = SYSTEM.dispatcher();
global_dispatcher.remove(used_dispatchers, &child_ref_inner);
}
let id = id.clone();
let msg = BastionMessage::restart_required(id, parent.id().clone());
let env = Envelope::new(msg, path.clone(), sender.clone());
parent.send(env).ok();
})
}
pub(crate) fn id(&self) -> &BastionId {
self.bcast.id()
}
fn stopped(&mut self) {
debug!("Child({}): Stopped.", self.id());
self.remove_from_dispatchers();
self.bcast.stopped();
}
fn faulted(&mut self) {
debug!("Child({}): Faulted.", self.id());
self.remove_from_dispatchers();
let parent = self.bcast.parent().clone().into_children().unwrap();
let path = self.bcast.path().clone();
let sender = self.bcast.sender().clone();
let msg = BastionMessage::restart_required(self.id().clone(), parent.id().clone());
let env = Envelope::new(msg, path, sender);
parent.send(env).ok();
}
async fn handle(&mut self, env: Envelope) -> Result<(), ()> {
match env {
Envelope {
msg: BastionMessage::Start,
..
} => unreachable!(),
Envelope {
msg: BastionMessage::Stop,
..
} => {
self.stopped();
#[cfg(feature = "scaling")]
self.cleanup_actors_stats().await;
self.callbacks.after_stop();
return Err(());
}
Envelope {
msg: BastionMessage::Kill,
..
} => {
self.stopped();
#[cfg(feature = "scaling")]
self.cleanup_actors_stats().await;
self.callbacks.before_restart();
return Err(());
}
Envelope {
msg: BastionMessage::Deploy(_),
..
} => unimplemented!(),
Envelope {
msg: BastionMessage::Prune { .. },
..
} => unimplemented!(),
Envelope {
msg: BastionMessage::ApplyCallback(callback_type),
..
} => self.apply_callback(callback_type),
Envelope {
msg: BastionMessage::SuperviseWith(_),
..
} => unimplemented!(),
Envelope {
msg: BastionMessage::InstantiatedChild { .. },
..
} => unreachable!(),
Envelope {
msg: BastionMessage::Message(msg),
sign,
} => {
debug!("Child({}): Received a message: {:?}", self.id(), msg);
let state = self.state.clone();
let mut guard = state.lock().await;
guard.push_message(msg, sign);
}
Envelope {
msg: BastionMessage::RestartRequired { .. },
..
} => unreachable!(),
Envelope {
msg: BastionMessage::RestartSubtree,
..
} => unreachable!(),
Envelope {
msg: BastionMessage::RestoreChild { .. },
..
} => unreachable!(),
Envelope {
msg: BastionMessage::FinishedChild { .. },
..
} => unreachable!(),
Envelope {
msg: BastionMessage::DropChild { .. },
..
} => unreachable!(),
Envelope {
msg: BastionMessage::SetState { state },
..
} => {
debug!("Child({}): Setting new state: {:?}", self.id(), state);
self.state = state;
}
Envelope {
msg: BastionMessage::Stopped { .. },
..
} => unimplemented!(),
Envelope {
msg: BastionMessage::Faulted { .. },
..
} => unimplemented!(),
Envelope {
msg: BastionMessage::Heartbeat,
..
} => unreachable!(),
}
Ok(())
}
async fn initialize(&mut self) -> Result<(), ()> {
trace!(
"Child({}): Received a new message (started=false): {:?}",
self.id(),
BastionMessage::Start
);
debug!("Child({}): Starting.", self.id());
self.callbacks.before_start();
self.started = true;
let msgs = self.pre_start_msgs.drain(..).collect::<Vec<_>>();
self.pre_start_msgs.shrink_to_fit();
debug!(
"Child({}): Replaying messages received before starting.",
self.id()
);
for msg in msgs {
trace!("Child({}): Replaying message: {:?}", self.id(), msg);
if self.handle(msg).await.is_err() {
return Err(());
}
}
Ok(())
}
fn apply_callback(&mut self, callback_type: CallbackType) {
match callback_type {
CallbackType::BeforeStart => self.callbacks.before_start(),
CallbackType::BeforeRestart => self.callbacks.before_restart(),
CallbackType::AfterRestart => self.callbacks.after_restart(),
CallbackType::AfterStop => self.callbacks.after_stop(),
}
}
#[cfg(feature = "scaling")]
async fn update_stats(&mut self) {
let guard = self.state.lock().await;
let context_state = guard.as_ref();
let storage = guard.stats();
let mut stats = ActorGroupStats::load(storage.clone());
stats.update_average_mailbox_size(context_state.mailbox_size());
stats.store(storage);
let actor_stats_table = guard.actor_stats();
actor_stats_table
.insert(self.bcast.id().clone(), context_state.mailbox_size())
.ok();
}
async fn run(mut self) {
debug!("Child({}): Launched.", self.id());
if let Err(e) = self.register_in_dispatchers() {
error!("couldn't add actor to the registry: {}", e);
return;
};
loop {
#[cfg(feature = "scaling")]
self.update_stats().await;
match poll!(&mut self.bcast.next()) {
Poll::Ready(Some(Envelope {
msg: BastionMessage::Start,
..
})) => {
if self.initialize().await.is_err() {
error!("couldn't initialize Child with id: {}", self.id());
return;
}
continue;
}
Poll::Ready(Some(msg)) if !self.started => {
trace!(
"Child({}): Received a new message (started=false): {:?}",
self.id(),
msg
);
self.pre_start_msgs.push(msg);
continue;
}
Poll::Ready(Some(msg)) => {
trace!(
"Child({}): Received a new message (started=true): {:?}",
self.id(),
msg
);
if self.handle(msg).await.is_err() {
error!("Child({}): Couldn't handle message", self.id());
return;
}
continue;
}
Poll::Ready(None) => unreachable!(),
Poll::Pending => (),
}
#[cfg(feature = "scaling")]
self.update_stats().await;
if !self.started {
pending!();
continue;
}
match poll!(&mut self.exec) {
Poll::Ready(Ok(())) => {
debug!(
"Child({}): The future finished executing successfully.",
self.id()
);
return self.stopped();
}
Poll::Ready(Err(())) => {
warn!("Child({}): The future returned an error.", self.id());
return self.faulted();
}
Poll::Pending => (),
}
pending!();
}
}
pub(crate) fn launch(self) -> RecoverableHandle<()> {
let stack = self.stack();
pool::spawn(self.run(), stack)
}
fn register_in_dispatchers(&self) -> AnyResult<()> {
if let Some(parent) = self.bcast.parent().clone().into_children() {
let child_ref = self.child_ref.clone();
let used_dispatchers = parent.dispatchers();
let global_dispatcher = SYSTEM.dispatcher();
let module_name = module_path!().to_string();
global_dispatcher.register(used_dispatchers, &child_ref, module_name)?;
}
Ok(())
}
fn remove_from_dispatchers(&self) {
if let Some(parent) = self.bcast.parent().clone().into_children() {
let child_ref = self.child_ref.clone();
let used_dispatchers = parent.dispatchers();
let global_dispatcher = SYSTEM.dispatcher();
global_dispatcher.remove(used_dispatchers, &child_ref);
}
}
#[cfg(feature = "scaling")]
async fn cleanup_actors_stats(&mut self) {
let guard = self.state.lock().await;
let actor_stats_table = guard.actor_stats();
actor_stats_table.remove(&self.bcast.id().clone()).ok();
}
}
impl Future for Exec {
type Output = Result<(), ()>;
fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
Pin::new(&mut self.get_mut().0).poll(ctx)
}
}
impl Default for Init {
fn default() -> Self {
Init::new(|_| async { Ok(()) })
}
}
impl Debug for Init {
fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
fmt.debug_struct("Init").finish()
}
}
impl Debug for Exec {
fn fmt(&self, fmt: &mut Formatter) -> fmt::Result {
fmt.debug_struct("Exec").finish()
}
}