use crate::restart::{RestartIntensity, RestartPolicy, RestartStrategy, RestartTracker};
use crate::supervisor_common::{WorkerTermination, run_worker};
use crate::types::{ChildExitReason, ChildId, ChildInfo, ChildType, WorkerContext};
use crate::worker::Worker;
use std::fmt;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
/// Blueprint for a single worker child: its id, the factory that builds the
/// worker value, its restart policy and the context handed to the factory.
pub(crate) struct StatefulWorkerSpec<W: Worker> {
/// Identifier of the child.
pub id: ChildId,
/// Factory called with the context to build a worker instance.
pub worker_factory: Arc<dyn Fn(Arc<WorkerContext>) -> W + Send + Sync>,
/// Restart policy associated with this worker.
pub restart_policy: RestartPolicy,
/// Context passed to `worker_factory` whenever a worker is created.
pub context: Arc<WorkerContext>,
}
/// Hand-written `Clone` that adds no `W: Clone` bound: the factory and the
/// context are shared through their `Arc`s, the id is copied.
impl<W: Worker> Clone for StatefulWorkerSpec<W> {
    fn clone(&self) -> Self {
        let Self {
            id,
            worker_factory,
            restart_policy,
            context,
        } = self;
        Self {
            id: id.clone(),
            worker_factory: Arc::clone(worker_factory),
            restart_policy: *restart_policy,
            context: Arc::clone(context),
        }
    }
}
impl<W: Worker> StatefulWorkerSpec<W> {
    /// Builds a spec from a child id, a worker factory, a restart policy and
    /// the context that the factory will receive.
    pub(crate) fn new(
        id: impl Into<String>,
        factory: impl Fn(Arc<WorkerContext>) -> W + Send + Sync + 'static,
        restart_policy: RestartPolicy,
        context: Arc<WorkerContext>,
    ) -> Self {
        let id = id.into();
        // Type-erase the factory so specs with different closures share one type.
        let worker_factory: Arc<dyn Fn(Arc<WorkerContext>) -> W + Send + Sync> =
            Arc::new(factory);
        Self {
            id,
            worker_factory,
            restart_policy,
            context,
        }
    }

    /// Calls the factory with a new `Arc` handle to the stored context and
    /// returns the worker it produces.
    pub(crate) fn create_worker(&self) -> W {
        let context = Arc::clone(&self.context);
        (self.worker_factory)(context)
    }
}
/// Debug output shows only `id` and `restart_policy`; the remaining fields
/// are omitted and the struct is rendered as non-exhaustive.
impl<W: Worker> fmt::Debug for StatefulWorkerSpec<W> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("StatefulWorkerSpec");
        out.field("id", &self.id);
        out.field("restart_policy", &self.restart_policy);
        out.finish_non_exhaustive()
    }
}
/// A spawned worker: the spec it was created from plus the handle of the
/// tokio task that runs it.
pub(crate) struct StatefulWorkerProcess<W: Worker> {
/// Spec used to create this worker; cloned when the worker is restarted.
pub spec: StatefulWorkerSpec<W>,
/// Join handle of the worker task; `None` once `stop` or `Drop` has taken it.
pub handle: Option<JoinHandle<()>>,
}
impl<W: Worker> StatefulWorkerProcess<W> {
pub(crate) fn spawn<Cmd>(
spec: StatefulWorkerSpec<W>,
supervisor_name: String,
control_tx: mpsc::UnboundedSender<Cmd>,
) -> Self
where
Cmd: From<WorkerTermination> + Send + 'static,
{
let worker = spec.create_worker();
let worker_id = spec.id.clone();
let handle = tokio::spawn(async move {
run_worker(supervisor_name, worker_id, worker, control_tx, None).await;
});
Self {
spec,
handle: Some(handle),
}
}
pub(crate) fn spawn_with_link<Cmd>(
spec: StatefulWorkerSpec<W>,
supervisor_name: String,
control_tx: mpsc::UnboundedSender<Cmd>,
init_tx: tokio::sync::oneshot::Sender<Result<(), String>>,
) -> Self
where
Cmd: From<WorkerTermination> + Send + 'static,
{
let worker = spec.create_worker();
let worker_id = spec.id.clone();
let handle = tokio::spawn(async move {
run_worker(
supervisor_name,
worker_id,
worker,
control_tx,
Some(init_tx),
)
.await;
});
Self {
spec,
handle: Some(handle),
}
}
pub(crate) async fn stop(&mut self) {
if let Some(handle) = self.handle.take() {
handle.abort();
drop(handle.await);
}
}
}
/// Dropping a process aborts its task if the handle is still present, so a
/// worker never outlives the value that tracks it.
impl<W: Worker> Drop for StatefulWorkerProcess<W> {
    fn drop(&mut self) {
        let Some(task) = self.handle.take() else {
            return;
        };
        task.abort();
    }
}
/// Specification of one child of a supervisor: either a worker or a nested
/// supervisor.
pub(crate) enum StatefulChildSpec<W: Worker> {
/// A worker child.
Worker(StatefulWorkerSpec<W>),
/// A nested supervisor; the spec is shared through an `Arc`.
Supervisor(Arc<StatefulSupervisorSpec<W>>),
}
/// Clones a child spec: worker specs are cloned, supervisor specs share the
/// same `Arc`.
impl<W: Worker> Clone for StatefulChildSpec<W> {
    fn clone(&self) -> Self {
        match self {
            Self::Supervisor(shared) => Self::Supervisor(Arc::clone(shared)),
            Self::Worker(worker_spec) => Self::Worker(worker_spec.clone()),
        }
    }
}
/// Declarative description of a supervisor: its name, its children, how it
/// restarts them and the context shared with its workers.
pub struct StatefulSupervisorSpec<W: Worker> {
/// Supervisor name; also used as its id when it is a child of another supervisor.
pub(crate) name: String,
/// Child specs, in the order they were added.
pub(crate) children: Vec<StatefulChildSpec<W>>,
/// Strategy applied when a child has to be restarted.
pub(crate) restart_strategy: RestartStrategy,
/// Restart intensity settings handed to the restart tracker.
pub(crate) restart_intensity: RestartIntensity,
/// Context given to every worker added through `with_worker`.
pub(crate) context: Arc<WorkerContext>,
}
/// Hand-written `Clone`: name and child list are cloned, the two restart
/// settings are copied and the context `Arc` is shared.
impl<W: Worker> Clone for StatefulSupervisorSpec<W> {
    fn clone(&self) -> Self {
        let Self {
            name,
            children,
            restart_strategy,
            restart_intensity,
            context,
        } = self;
        Self {
            name: name.clone(),
            children: children.clone(),
            restart_strategy: *restart_strategy,
            restart_intensity: *restart_intensity,
            context: Arc::clone(context),
        }
    }
}
impl<W: Worker> StatefulSupervisorSpec<W> {
    /// Creates an empty spec named `name` with default restart settings and a
    /// freshly created worker context.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        let children = Vec::new();
        let restart_strategy = RestartStrategy::default();
        let restart_intensity = RestartIntensity::default();
        let context = Arc::new(WorkerContext::new());
        Self {
            name,
            children,
            restart_strategy,
            restart_intensity,
            context,
        }
    }

    /// Returns the spec with its restart strategy replaced by `strategy`.
    #[must_use]
    pub fn with_restart_strategy(self, strategy: RestartStrategy) -> Self {
        Self {
            restart_strategy: strategy,
            ..self
        }
    }

    /// Returns the spec with its restart intensity replaced by `intensity`.
    #[must_use]
    pub fn with_restart_intensity(self, intensity: RestartIntensity) -> Self {
        Self {
            restart_intensity: intensity,
            ..self
        }
    }

    /// Appends a worker child built by `factory`; the worker receives this
    /// supervisor's context.
    #[must_use]
    pub fn with_worker(
        mut self,
        id: impl Into<String>,
        factory: impl Fn(Arc<WorkerContext>) -> W + Send + Sync + 'static,
        restart_policy: RestartPolicy,
    ) -> Self {
        let context = Arc::clone(&self.context);
        let worker_spec = StatefulWorkerSpec::new(id, factory, restart_policy, context);
        self.children.push(StatefulChildSpec::Worker(worker_spec));
        self
    }

    /// Appends a nested supervisor child. The nested spec keeps its own context.
    #[must_use]
    pub fn with_supervisor(mut self, supervisor: StatefulSupervisorSpec<W>) -> Self {
        let nested = Arc::new(supervisor);
        self.children.push(StatefulChildSpec::Supervisor(nested));
        self
    }

    /// Borrows the context shared with this supervisor's workers.
    #[must_use]
    pub fn context(&self) -> &Arc<WorkerContext> {
        &self.context
    }
}
/// A child tracked by a running supervisor.
pub(crate) enum StatefulChild<W: Worker> {
/// A spawned worker task.
Worker(StatefulWorkerProcess<W>),
/// A started nested supervisor.
Supervisor {
/// Handle used to send commands to the nested supervisor.
handle: StatefulSupervisorHandle<W>,
/// Spec the nested supervisor was started from; reused on restart.
spec: Arc<StatefulSupervisorSpec<W>>,
},
}
impl<W: Worker> StatefulChild<W> {
    /// Identifier of the child: the worker's spec id, or the nested
    /// supervisor's name.
    #[inline]
    pub fn id(&self) -> &str {
        match self {
            Self::Supervisor { spec, .. } => spec.name.as_str(),
            Self::Worker(process) => process.spec.id.as_str(),
        }
    }

    /// Whether this child is a worker or a supervisor.
    #[inline]
    pub fn child_type(&self) -> ChildType {
        if matches!(self, Self::Worker(_)) {
            ChildType::Worker
        } else {
            ChildType::Supervisor
        }
    }

    /// Restart policy of the child. Workers report the policy from their
    /// spec; nested supervisors always report `Permanent`.
    #[inline]
    #[allow(clippy::unnecessary_wraps)]
    pub fn restart_policy(&self) -> Option<RestartPolicy> {
        let policy = match self {
            Self::Supervisor { .. } => RestartPolicy::Permanent,
            Self::Worker(process) => process.spec.restart_policy,
        };
        Some(policy)
    }

    /// Stops the child: workers are stopped via `stop`, nested supervisors
    /// are sent a shutdown request whose result is ignored.
    pub async fn shutdown(&mut self) {
        match self {
            Self::Supervisor { handle, .. } => {
                let _shutdown_result = handle.shutdown().await;
            }
            Self::Worker(process) => process.stop().await,
        }
    }
}
/// Owned copy of the spec needed to start a replacement for a child, taken
/// so the child slot itself can be overwritten.
pub(crate) enum StatefulRestartInfo<W: Worker> {
/// Spec of the worker to respawn.
Worker(StatefulWorkerSpec<W>),
/// Shared spec of the nested supervisor to start again.
Supervisor(Arc<StatefulSupervisorSpec<W>>),
}
/// Errors reported by a stateful supervisor and its handle.
#[derive(Debug)]
pub enum StatefulSupervisorError {
    /// The named supervisor has no children.
    NoChildren(String),
    /// All children of the named supervisor failed (restart intensity limit exceeded).
    AllChildrenFailed(String),
    /// The named supervisor is shutting down and rejects the operation.
    ShuttingDown(String),
    /// A child with this id already exists.
    ChildAlreadyExists(String),
    /// No child with this id is known.
    ChildNotFound(String),
    /// The child reported an initialization failure.
    InitializationFailed {
        /// Id of the child that failed to initialize.
        child_id: String,
        /// Failure description.
        reason: String,
    },
    /// The child did not finish initializing within the timeout.
    InitializationTimeout {
        /// Id of the child that timed out.
        child_id: String,
        /// The timeout that elapsed.
        timeout: std::time::Duration,
    },
}

/// Human-readable message for each error variant.
impl fmt::Display for StatefulSupervisorError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NoChildren(name) => write!(f, "stateful supervisor '{name}' has no children"),
            Self::AllChildrenFailed(name) => write!(
                f,
                "all children failed for stateful supervisor '{name}' - restart intensity limit exceeded"
            ),
            Self::ShuttingDown(name) => write!(
                f,
                "stateful supervisor '{name}' is shutting down - operation not permitted"
            ),
            Self::ChildAlreadyExists(id) => write!(
                f,
                "child with id '{id}' already exists - use a unique identifier"
            ),
            Self::ChildNotFound(id) => write!(
                f,
                "child with id '{id}' not found - it may have already terminated"
            ),
            Self::InitializationFailed { child_id, reason } => {
                write!(f, "child '{child_id}' initialization failed: {reason}")
            }
            Self::InitializationTimeout { child_id, timeout } => write!(
                f,
                "child '{child_id}' initialization timed out after {timeout:?}"
            ),
        }
    }
}

impl std::error::Error for StatefulSupervisorError {}
/// Messages processed by the supervisor runtime's command loop.
pub(crate) enum StatefulSupervisorCommand<W: Worker> {
/// Start a new worker child and reply with its id.
StartChild {
spec: StatefulWorkerSpec<W>,
respond_to: oneshot::Sender<Result<ChildId, StatefulSupervisorError>>,
},
/// Start a new worker child and wait up to `timeout` for its
/// initialization result before replying.
StartChildLinked {
spec: StatefulWorkerSpec<W>,
timeout: std::time::Duration,
respond_to: oneshot::Sender<Result<ChildId, StatefulSupervisorError>>,
},
/// Stop and remove the child with the given id.
TerminateChild {
id: ChildId,
respond_to: oneshot::Sender<Result<(), StatefulSupervisorError>>,
},
/// Reply with a description of every current child.
WhichChildren {
respond_to: oneshot::Sender<Result<Vec<ChildInfo>, StatefulSupervisorError>>,
},
/// Reply with the supervisor's restart strategy.
GetRestartStrategy {
respond_to: oneshot::Sender<RestartStrategy>,
},
/// Reply with the number of whole seconds since the runtime was created.
GetUptime {
respond_to: oneshot::Sender<u64>,
},
/// Notification that a child terminated; built from a `WorkerTermination`.
ChildTerminated {
id: ChildId,
reason: ChildExitReason,
},
/// Shut down all children and end the command loop.
Shutdown,
}
impl<W: Worker> From<WorkerTermination> for StatefulSupervisorCommand<W> {
fn from(term: WorkerTermination) -> Self {
StatefulSupervisorCommand::ChildTerminated {
id: term.id,
reason: term.reason,
}
}
}
/// State owned by a running supervisor's command loop.
pub(crate) struct StatefulSupervisorRuntime<W: Worker> {
// Supervisor name, used in log fields and passed to spawned workers.
name: String,
// Currently tracked children.
children: Vec<StatefulChild<W>>,
// Receiving end of the command channel.
control_rx: mpsc::UnboundedReceiver<StatefulSupervisorCommand<W>>,
// Sender cloned into every spawned worker.
control_tx: mpsc::UnboundedSender<StatefulSupervisorCommand<W>>,
// Strategy applied when a child must be restarted.
restart_strategy: RestartStrategy,
// Tracks restarts against the configured restart intensity.
restart_tracker: RestartTracker,
// Creation time, used to compute uptime.
created_at: std::time::Instant,
}
impl<W: Worker> StatefulSupervisorRuntime<W> {
pub(crate) fn new(
spec: StatefulSupervisorSpec<W>,
control_rx: mpsc::UnboundedReceiver<StatefulSupervisorCommand<W>>,
control_tx: mpsc::UnboundedSender<StatefulSupervisorCommand<W>>,
) -> Self {
let mut children = Vec::with_capacity(spec.children.len());
for child_spec in spec.children {
match child_spec {
StatefulChildSpec::Worker(worker_spec) => {
let worker = StatefulWorkerProcess::spawn(
worker_spec,
spec.name.clone(),
control_tx.clone(),
);
children.push(StatefulChild::Worker(worker));
}
StatefulChildSpec::Supervisor(supervisor_spec) => {
let supervisor = StatefulSupervisorHandle::start((*supervisor_spec).clone());
children.push(StatefulChild::Supervisor {
handle: supervisor,
spec: Arc::clone(&supervisor_spec),
});
}
}
}
Self {
name: spec.name,
children,
control_rx,
control_tx,
restart_strategy: spec.restart_strategy,
restart_tracker: RestartTracker::new(spec.restart_intensity),
created_at: std::time::Instant::now(),
}
}
/// Command loop of the supervisor. Handles commands until a `Shutdown`
/// command arrives or the channel yields `None`, then shuts down all
/// remaining children. Failures to deliver a reply are ignored.
pub(crate) async fn run(mut self) {
    loop {
        let Some(command) = self.control_rx.recv().await else {
            break;
        };
        match command {
            StatefulSupervisorCommand::Shutdown => break,
            StatefulSupervisorCommand::ChildTerminated { id, reason } => {
                self.handle_child_terminated(id, reason).await;
            }
            StatefulSupervisorCommand::StartChild { spec, respond_to } => {
                let outcome = self.handle_start_child(spec);
                let _delivered = respond_to.send(outcome);
            }
            StatefulSupervisorCommand::StartChildLinked {
                spec,
                timeout,
                respond_to,
            } => {
                let outcome = self.handle_start_child_linked(spec, timeout).await;
                let _delivered = respond_to.send(outcome);
            }
            StatefulSupervisorCommand::TerminateChild { id, respond_to } => {
                let outcome = self.handle_terminate_child(&id).await;
                let _delivered = respond_to.send(outcome);
            }
            StatefulSupervisorCommand::WhichChildren { respond_to } => {
                let outcome = self.handle_which_children();
                let _delivered = respond_to.send(outcome);
            }
            StatefulSupervisorCommand::GetRestartStrategy { respond_to } => {
                let _delivered = respond_to.send(self.restart_strategy);
            }
            StatefulSupervisorCommand::GetUptime { respond_to } => {
                let seconds = self.created_at.elapsed().as_secs();
                let _delivered = respond_to.send(seconds);
            }
        }
    }
    // Reached both on explicit shutdown and when the channel is exhausted.
    self.shutdown_children().await;
}
/// Spawns a new worker child and records it.
///
/// # Errors
/// Returns `ChildAlreadyExists` when a child with the same id is present.
fn handle_start_child(
    &mut self,
    spec: StatefulWorkerSpec<W>,
) -> Result<ChildId, StatefulSupervisorError> {
    let already_present = self.children.iter().any(|child| child.id() == spec.id);
    if already_present {
        return Err(StatefulSupervisorError::ChildAlreadyExists(spec.id));
    }

    let id = spec.id.clone();
    let process =
        StatefulWorkerProcess::spawn(spec, self.name.clone(), self.control_tx.clone());
    self.children.push(StatefulChild::Worker(process));
    tracing::debug!(supervisor = %self.name, child = %id, "dynamically started child");
    Ok(id)
}
/// Spawns a worker with an init channel and waits up to `timeout` for its
/// initialization result. The child is recorded only on success; otherwise
/// the spawned process is dropped when this function returns.
///
/// # Errors
/// * `ChildAlreadyExists` - a child with the same id is present.
/// * `InitializationFailed` - the worker reported an error, or the init
///   channel was closed without a reply.
/// * `InitializationTimeout` - no reply arrived within `timeout`.
async fn handle_start_child_linked(
    &mut self,
    spec: StatefulWorkerSpec<W>,
    timeout: std::time::Duration,
) -> Result<ChildId, StatefulSupervisorError> {
    let duplicate = self.children.iter().any(|child| child.id() == spec.id);
    if duplicate {
        return Err(StatefulSupervisorError::ChildAlreadyExists(spec.id));
    }

    let id = spec.id.clone();
    let (init_tx, init_rx) = oneshot::channel();
    let worker = StatefulWorkerProcess::spawn_with_link(
        spec,
        self.name.clone(),
        self.control_tx.clone(),
        init_tx,
    );

    // Layers: timeout result -> channel receive result -> worker's own result.
    let failure = match tokio::time::timeout(timeout, init_rx).await {
        Ok(Ok(Ok(()))) => {
            self.children.push(StatefulChild::Worker(worker));
            tracing::debug!(
                supervisor = %self.name,
                child = %id,
                "linked child started successfully"
            );
            return Ok(id);
        }
        Ok(Ok(Err(reason))) => {
            tracing::error!(
                supervisor = %self.name,
                child = %id,
                reason = %reason,
                "linked child initialization failed"
            );
            StatefulSupervisorError::InitializationFailed {
                child_id: id,
                reason,
            }
        }
        // The sender was dropped without a reply; reported as a panic.
        Ok(Err(_)) => {
            tracing::error!(
                supervisor = %self.name,
                child = %id,
                "linked child panicked during initialization"
            );
            StatefulSupervisorError::InitializationFailed {
                child_id: id,
                reason: "worker panicked during initialization".to_owned(),
            }
        }
        Err(_) => {
            tracing::error!(
                supervisor = %self.name,
                child = %id,
                timeout_secs = ?timeout.as_secs(),
                "linked child initialization timed out"
            );
            StatefulSupervisorError::InitializationTimeout {
                child_id: id,
                timeout,
            }
        }
    };
    Err(failure)
}
/// Removes the child with the given id from the list and shuts it down.
///
/// # Errors
/// Returns `ChildNotFound` when no child has that id.
async fn handle_terminate_child(&mut self, id: &str) -> Result<(), StatefulSupervisorError> {
    let Some(position) = self.children.iter().position(|child| child.id() == id) else {
        return Err(StatefulSupervisorError::ChildNotFound(id.to_owned()));
    };

    let mut removed = self.children.remove(position);
    removed.shutdown().await;
    tracing::debug!(supervisor = %self.name, child = %id, "terminated child");
    Ok(())
}
/// Describes every current child (id, type and restart policy), in list
/// order. Always returns `Ok`.
#[allow(clippy::unnecessary_wraps)]
fn handle_which_children(&self) -> Result<Vec<ChildInfo>, StatefulSupervisorError> {
    let mut info = Vec::with_capacity(self.children.len());
    for child in &self.children {
        info.push(ChildInfo {
            id: child.id().to_owned(),
            child_type: child.child_type(),
            restart_policy: child.restart_policy(),
        });
    }
    Ok(info)
}
/// Reacts to a child termination notice.
///
/// Unknown ids are logged and ignored. Workers restart according to their
/// policy (`Permanent`: always, `Temporary`: never, `Transient`: only on an
/// abnormal exit); nested supervisors always restart. A child that is not
/// restarted is removed from the list. Otherwise the restart is recorded
/// and, unless the tracker reports the limit as exceeded, the configured
/// restart strategy is applied.
///
/// NOTE(review): when the limit is exceeded this method only shuts the
/// children down and returns; it does not itself stop the command loop.
/// Confirm whether the supervisor is expected to exit at that point.
#[allow(clippy::indexing_slicing)]
async fn handle_child_terminated(&mut self, id: ChildId, reason: ChildExitReason) {
    tracing::debug!(
        supervisor = %self.name,
        child = %id,
        reason = ?reason,
        "child terminated"
    );

    let Some(position) = self.children.iter().position(|child| child.id() == id) else {
        tracing::warn!(
            supervisor = %self.name,
            child = %id,
            "terminated child not found in list"
        );
        return;
    };

    // `position` was just found, so indexing cannot go out of bounds.
    let should_restart = if let StatefulChild::Worker(process) = &self.children[position] {
        match process.spec.restart_policy {
            RestartPolicy::Permanent => true,
            RestartPolicy::Temporary => false,
            RestartPolicy::Transient => reason == ChildExitReason::Abnormal,
        }
    } else {
        true
    };

    if should_restart {
        if self.restart_tracker.record_restart() {
            tracing::error!(
                supervisor = %self.name,
                "restart intensity exceeded, shutting down"
            );
            self.shutdown_children().await;
            return;
        }
        match self.restart_strategy {
            RestartStrategy::OneForOne => self.restart_child(position).await,
            RestartStrategy::OneForAll => self.restart_all_children().await,
            RestartStrategy::RestForOne => self.restart_from(position).await,
        }
    } else {
        tracing::debug!(
            supervisor = %self.name,
            child = %id,
            policy = ?self.children[position].restart_policy(),
            reason = ?reason,
            "not restarting child"
        );
        self.children.remove(position);
    }
}
/// Restarts the child at `position` in place (one_for_one): the current
/// child is shut down and its slot is overwritten with a fresh worker or
/// nested supervisor started from the same spec.
///
/// `position` must be a valid index into the child list.
#[allow(clippy::indexing_slicing)]
async fn restart_child(&mut self, position: usize) {
    self.children[position].shutdown().await;

    // Take an owned copy of the spec so the slot can be overwritten below.
    let plan = match &self.children[position] {
        StatefulChild::Supervisor { spec, .. } => {
            StatefulRestartInfo::Supervisor(Arc::clone(spec))
        }
        StatefulChild::Worker(process) => StatefulRestartInfo::Worker(process.spec.clone()),
    };

    match plan {
        StatefulRestartInfo::Supervisor(spec) => {
            let child_name = spec.name.clone();
            tracing::debug!(
                supervisor = %self.name,
                child_supervisor = %child_name,
                "restarting supervisor"
            );
            let handle = StatefulSupervisorHandle::start((*spec).clone());
            self.children[position] = StatefulChild::Supervisor { handle, spec };
            tracing::debug!(
                supervisor = %self.name,
                child_supervisor = %child_name,
                "supervisor restarted"
            );
        }
        StatefulRestartInfo::Worker(spec) => {
            let worker_id = spec.id.clone();
            tracing::debug!(
                supervisor = %self.name,
                worker = %worker_id,
                "restarting worker"
            );
            let process =
                StatefulWorkerProcess::spawn(spec, self.name.clone(), self.control_tx.clone());
            self.children[position] = StatefulChild::Worker(process);
            tracing::debug!(
                supervisor = %self.name,
                worker = %worker_id,
                "worker restarted"
            );
        }
    }
}
async fn restart_all_children(&mut self) {
tracing::debug!(
supervisor = %self.name,
"restarting all children (one_for_all)"
);
for child in &mut self.children {
child.shutdown().await;
}
for child in &mut self.children {
if let StatefulChild::Worker(worker) = child {
let spec = worker.spec.clone();
let new_worker = StatefulWorkerProcess::spawn(
spec.clone(),
self.name.clone(),
self.control_tx.clone(),
);
*child = StatefulChild::Worker(new_worker);
tracing::debug!(
supervisor = %self.name,
child = %spec.id,
"child restarted"
);
}
}
}
#[allow(clippy::indexing_slicing)]
async fn restart_from(&mut self, position: usize) {
tracing::debug!(
supervisor = %self.name,
position = %position,
"restarting from position (rest_for_one)"
);
for i in position..self.children.len() {
self.children[i].shutdown().await;
if let StatefulChild::Worker(worker) = &self.children[i] {
let spec = worker.spec.clone();
let new_worker = StatefulWorkerProcess::spawn(
spec.clone(),
self.name.clone(),
self.control_tx.clone(),
);
self.children[i] = StatefulChild::Worker(new_worker);
tracing::debug!(
supervisor = %self.name,
child = %spec.id,
"child restarted"
);
}
}
}
/// Drains the child list, shutting down each child in list order and
/// logging its id afterwards. The list is empty when this returns.
async fn shutdown_children(&mut self) {
    let supervisor = self.name.as_str();
    for mut child in self.children.drain(..) {
        child.shutdown().await;
        tracing::debug!(
            supervisor = %supervisor,
            child = %child.id(),
            "shut down child"
        );
    }
}
}
/// Cloneable handle for talking to a running supervisor over its command
/// channel.
#[derive(Clone)]
pub struct StatefulSupervisorHandle<W: Worker> {
/// Supervisor name, shared between clones of the handle.
pub(crate) name: Arc<String>,
/// Sending end of the supervisor's command channel.
pub(crate) control_tx: mpsc::UnboundedSender<StatefulSupervisorCommand<W>>,
}
impl<W: Worker> StatefulSupervisorHandle<W> {
/// Starts a supervisor from `spec`: creates its command channel, builds
/// the runtime (which starts the children) and runs the command loop on a
/// new tokio task. Returns a handle connected to that loop.
#[must_use]
pub fn start(spec: StatefulSupervisorSpec<W>) -> Self {
    let (control_tx, control_rx) = mpsc::unbounded_channel();
    let name = Arc::new(spec.name.clone());
    let runtime = StatefulSupervisorRuntime::new(spec, control_rx, control_tx.clone());

    let task_name = Arc::clone(&name);
    tokio::spawn(async move {
        runtime.run().await;
        tracing::debug!(name = %*task_name, "supervisor stopped");
    });

    Self { name, control_tx }
}
/// Asks the supervisor to start a new worker child and returns its id.
///
/// # Errors
/// Returns `ShuttingDown` when the command cannot be sent or no reply is
/// received; otherwise forwards the supervisor's own result.
pub async fn start_child(
    &self,
    id: impl Into<String>,
    factory: impl Fn(Arc<WorkerContext>) -> W + Send + Sync + 'static,
    restart_policy: RestartPolicy,
    context: Arc<WorkerContext>,
) -> Result<ChildId, StatefulSupervisorError> {
    let (respond_to, reply) = oneshot::channel();
    let spec = StatefulWorkerSpec::new(id, factory, restart_policy, context);
    let command = StatefulSupervisorCommand::StartChild { spec, respond_to };
    if self.control_tx.send(command).is_err() {
        return Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned()));
    }
    match reply.await {
        Ok(outcome) => outcome,
        Err(_) => Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned())),
    }
}
/// Asks the supervisor to start a new worker child and to wait up to
/// `timeout` for the worker's initialization result before replying.
///
/// # Errors
/// Returns `ShuttingDown` when the command cannot be sent or no reply is
/// received; otherwise forwards the supervisor's own result.
pub async fn start_child_linked(
    &self,
    id: impl Into<String>,
    factory: impl Fn(Arc<WorkerContext>) -> W + Send + Sync + 'static,
    restart_policy: RestartPolicy,
    context: Arc<WorkerContext>,
    timeout: std::time::Duration,
) -> Result<ChildId, StatefulSupervisorError> {
    let (respond_to, reply) = oneshot::channel();
    let spec = StatefulWorkerSpec::new(id, factory, restart_policy, context);
    let command = StatefulSupervisorCommand::StartChildLinked {
        spec,
        timeout,
        respond_to,
    };
    if self.control_tx.send(command).is_err() {
        return Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned()));
    }
    match reply.await {
        Ok(outcome) => outcome,
        Err(_) => Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned())),
    }
}
/// Asks the supervisor to stop and remove the child with the given id.
///
/// # Errors
/// Returns `ShuttingDown` when the command cannot be sent or no reply is
/// received; otherwise forwards the supervisor's own result.
pub async fn terminate_child(&self, id: &str) -> Result<(), StatefulSupervisorError> {
    let (respond_to, reply) = oneshot::channel();
    let command = StatefulSupervisorCommand::TerminateChild {
        id: id.to_owned(),
        respond_to,
    };
    if self.control_tx.send(command).is_err() {
        return Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned()));
    }
    match reply.await {
        Ok(outcome) => outcome,
        Err(_) => Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned())),
    }
}
/// Asks the supervisor for a description of its current children.
///
/// # Errors
/// Returns `ShuttingDown` when the command cannot be sent or no reply is
/// received; otherwise forwards the supervisor's own result.
pub async fn which_children(&self) -> Result<Vec<ChildInfo>, StatefulSupervisorError> {
    let (respond_to, reply) = oneshot::channel();
    let command = StatefulSupervisorCommand::WhichChildren { respond_to };
    if self.control_tx.send(command).is_err() {
        return Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned()));
    }
    match reply.await {
        Ok(outcome) => outcome,
        Err(_) => Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned())),
    }
}
/// Sends a shutdown request to the supervisor. Returns as soon as the
/// request is queued; it does not wait for the children to stop.
///
/// # Errors
/// Returns `ShuttingDown` when the command cannot be sent.
#[allow(clippy::unused_async)]
pub async fn shutdown(&self) -> Result<(), StatefulSupervisorError> {
    let request = StatefulSupervisorCommand::Shutdown;
    self.control_tx
        .send(request)
        .map_err(|_| StatefulSupervisorError::ShuttingDown(self.name().to_owned()))
}
/// Name of the supervisor this handle talks to.
#[must_use]
pub fn name(&self) -> &str {
    &self.name
}
/// Asks the supervisor for its restart strategy.
///
/// # Errors
/// Returns `ShuttingDown` when the command cannot be sent or no reply is
/// received.
pub async fn restart_strategy(&self) -> Result<RestartStrategy, StatefulSupervisorError> {
    let (respond_to, reply) = oneshot::channel();
    let command = StatefulSupervisorCommand::GetRestartStrategy { respond_to };
    if self.control_tx.send(command).is_err() {
        return Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned()));
    }
    let Ok(strategy) = reply.await else {
        return Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned()));
    };
    Ok(strategy)
}
/// Asks the supervisor how many whole seconds have passed since its
/// runtime was created.
///
/// # Errors
/// Returns `ShuttingDown` when the command cannot be sent or no reply is
/// received.
pub async fn uptime(&self) -> Result<u64, StatefulSupervisorError> {
    let (respond_to, reply) = oneshot::channel();
    let command = StatefulSupervisorCommand::GetUptime { respond_to };
    if self.control_tx.send(command).is_err() {
        return Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned()));
    }
    let Ok(seconds) = reply.await else {
        return Err(StatefulSupervisorError::ShuttingDown(self.name().to_owned()));
    };
    Ok(seconds)
}
}