use std::collections::BTreeSet;
use std::pin::Pin;
use std::time::Duration;
use std::time::Instant;
use serde::Deserialize;
use serde::Serialize;
use crate::shutdown_brutal_kill;
use crate::shutdown_infinity;
use crate::shutdown_timeout;
use crate::AutoShutdown;
use crate::CallError;
use crate::ChildSpec;
use crate::ChildType;
use crate::Dest;
use crate::ExitReason;
use crate::From;
use crate::GenServer;
use crate::GenServerOptions;
use crate::Local;
use crate::Message;
use crate::Pid;
use crate::Process;
use crate::ProcessFlags;
use crate::Restart;
use crate::Shutdown;
use crate::SystemMessage;
/// A child spec plus the supervisor's runtime bookkeeping for it.
#[derive(Clone)]
struct SupervisedChild {
/// The spec used to start (and restart) the child.
spec: ChildSpec,
/// The child's pid as last recorded by the supervisor, or `None` when it has
/// not been started, has been shut down, or its start was ignored.
/// NOTE(review): after a failed restart this may still hold the exited
/// child's old pid while `restarting` is `true`.
pid: Option<Pid>,
/// `true` while a failed restart is queued for a retry
/// (`SupervisorMessage::TryAgainRestartId`).
restarting: bool,
}
/// Internal request/reply protocol between the `Supervisor` client functions
/// and the supervisor process. Not part of the public API.
#[doc(hidden)]
#[derive(Serialize, Deserialize)]
pub enum SupervisorMessage {
/// Cast: retry a pending restart of the child with this pid.
TryAgainRestartPid(Pid),
/// Cast (sent by the supervisor to itself): retry a pending restart of the
/// child with this id.
TryAgainRestartId(String),
/// Call: request child counts.
CountChildren,
/// Reply to `CountChildren`.
CountChildrenSuccess(SupervisorCounts),
/// Call: add a new child from the given spec and start it.
StartChild(Local<ChildSpec>),
/// Reply to `StartChild`: the new pid, or `None` if the start was ignored.
StartChildSuccess(Option<Pid>),
/// Reply to `StartChild` on failure.
StartChildError(SupervisorError),
/// Call: shut down the child with this id.
TerminateChild(String),
/// Reply to `TerminateChild` on success.
TerminateChildSuccess,
/// Reply to `TerminateChild` on failure.
TerminateChildError(SupervisorError),
/// Call: start again the (not running) child with this id.
RestartChild(String),
/// Reply to `RestartChild`: the new pid, or `None` if the start was ignored.
RestartChildSuccess(Option<Pid>),
/// Reply to `RestartChild` on failure.
RestartChildError(SupervisorError),
/// Call: remove the spec of the (not running) child with this id.
DeleteChild(String),
/// Reply to `DeleteChild` on success.
DeleteChildSuccess,
/// Reply to `DeleteChild` on failure.
DeleteChildError(SupervisorError),
/// Call: list all children.
WhichChildren,
/// Reply to `WhichChildren`.
WhichChildrenSuccess(Vec<SupervisorChildInfo>),
}
/// Errors returned by the `Supervisor` client functions.
#[derive(Debug, Serialize, Deserialize)]
pub enum SupervisorError {
/// The call to the supervisor process itself failed.
CallError(CallError),
/// A child with this id already exists and has a pid.
AlreadyStarted,
/// A child with this id already exists but has no pid.
AlreadyPresent,
/// The child's start function failed with this (non-ignore) exit reason.
StartError(ExitReason),
/// No child has the given id.
NotFound,
/// The operation requires the child to not be running, but it has a pid.
Running,
/// The operation is not allowed while a restart of the child is pending.
Restarting,
}
/// Information about one supervised child, as returned by
/// `Supervisor::which_children`.
///
/// The fields are public (like those of `SupervisorCounts`) so callers can
/// read the result instead of only being able to `Debug`-print it.
#[derive(Debug, Serialize, Deserialize)]
pub struct SupervisorChildInfo {
    /// The child's unique id, taken from its spec.
    pub id: String,
    /// The child's pid as recorded by the supervisor, or `None` when it has
    /// none (not started, shut down, or its start was ignored).
    pub child: Option<Pid>,
    /// Whether the child is a worker or a supervisor.
    pub child_type: ChildType,
    /// `true` while a failed restart of this child is waiting to be retried.
    pub restarting: bool,
}
/// Child counts reported by `Supervisor::count_children`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SupervisorCounts {
/// Total number of child specs.
pub specs: usize,
/// Children that currently have a pid recorded.
pub active: usize,
/// Children whose type is `ChildType::Supervisor`.
pub supervisors: usize,
/// All other children (workers).
pub workers: usize,
}
/// How a supervisor reacts when a child that must be restarted exits.
#[derive(Debug, Clone, Copy)]
pub enum SupervisionStrategy {
/// Restart only the child that exited.
OneForOne,
/// Shut down every other child, then start all children again in order.
OneForAll,
/// Shut down the children that come after the exited one, then start the
/// exited child and those later children again in order.
RestForOne,
}
/// A supervisor: a `GenServer` that starts a list of child specs, restarts
/// them when they exit according to its strategy and restart limits, and
/// shuts them down when it terminates.
#[derive(Clone)]
pub struct Supervisor {
/// Children in start order.
children: Vec<SupervisedChild>,
/// Ids of every child in `children`; used to reject duplicate ids.
identifiers: BTreeSet<String>,
/// When recent restarts happened; entries older than `max_duration` are pruned.
restarts: Vec<Instant>,
/// Which children are restarted when one exits.
strategy: SupervisionStrategy,
/// Whether the exit of significant children shuts the supervisor down.
auto_shutdown: AutoShutdown,
/// Maximum number of restarts allowed within `max_duration`.
max_restarts: usize,
/// Length of the sliding window used for `max_restarts`.
max_duration: Duration,
}
impl Supervisor {
/// Creates a supervisor with no children.
///
/// Defaults: the `OneForOne` strategy, `AutoShutdown::Never`, and a restart
/// intensity that tolerates at most 3 restarts in any 5 second window.
pub const fn new() -> Self {
    Self {
        strategy: SupervisionStrategy::OneForOne,
        auto_shutdown: AutoShutdown::Never,
        max_restarts: 3,
        max_duration: Duration::from_secs(5),
        children: Vec::new(),
        identifiers: BTreeSet::new(),
        restarts: Vec::new(),
    }
}
/// Creates a supervisor whose children are `children`, in iteration order.
///
/// # Panics
/// Panics if two of the specs share the same id (see `add_child`).
pub fn with_children<T: IntoIterator<Item = ChildSpec>>(children: T) -> Self {
    children.into_iter().fold(Self::new(), Self::add_child)
}
/// Appends `child` to the end of the start order.
///
/// # Panics
/// Panics if a child with the same id was already added.
pub fn add_child(mut self, child: ChildSpec) -> Self {
    let newly_inserted = self.identifiers.insert(child.id.clone());
    assert!(newly_inserted, "Child id was not unique!");

    let entry = SupervisedChild {
        spec: child,
        pid: None,
        restarting: false,
    };
    self.children.push(entry);
    self
}
/// Builds a `ChildSpec` (id `"Supervisor"`, type `ChildType::Supervisor`) that
/// starts a fresh clone of this supervisor with `options`, so it can be nested
/// under another supervisor.
pub fn child_spec(self, options: GenServerOptions) -> ChildSpec {
    let launch = move || self.clone().start_link(options.clone());

    ChildSpec::new("Supervisor")
        .start(launch)
        .child_type(ChildType::Supervisor)
}
/// Sets the strategy that decides which children are restarted when one exits.
pub const fn strategy(mut self, strategy: SupervisionStrategy) -> Self {
    self.strategy = strategy;
    self
}

/// Sets whether (and when) the exit of significant children shuts the
/// supervisor down.
pub const fn auto_shutdown(mut self, auto_shutdown: AutoShutdown) -> Self {
    self.auto_shutdown = auto_shutdown;
    self
}

/// Sets how many restarts are tolerated within `max_duration` before the
/// supervisor shuts itself down.
pub const fn max_restarts(mut self, max_restarts: usize) -> Self {
    self.max_restarts = max_restarts;
    self
}

/// Sets the length of the sliding window in which `max_restarts` applies.
pub const fn max_duration(mut self, max_duration: Duration) -> Self {
    self.max_duration = max_duration;
    self
}
pub async fn start_link(self, options: GenServerOptions) -> Result<Pid, ExitReason> {
GenServer::start_link(self, options).await
}
pub async fn count_children<T: Into<Dest>>(
supervisor: T,
) -> Result<SupervisorCounts, SupervisorError> {
use SupervisorMessage::*;
match Supervisor::call(supervisor, CountChildren, None).await? {
CountChildrenSuccess(counts) => Ok(counts),
_ => unreachable!(),
}
}
pub async fn start_child<T: Into<Dest>>(
supervisor: T,
child: ChildSpec,
) -> Result<Option<Pid>, SupervisorError> {
use SupervisorMessage::*;
match Supervisor::call(supervisor, StartChild(Local::new(child)), None).await? {
StartChildSuccess(pid) => Ok(pid),
StartChildError(error) => Err(error),
_ => unreachable!(),
}
}
pub async fn terminate_child<T: Into<Dest>, I: Into<String>>(
supervisor: T,
child_id: I,
) -> Result<(), SupervisorError> {
use SupervisorMessage::*;
match Supervisor::call(supervisor, TerminateChild(child_id.into()), None).await? {
TerminateChildSuccess => Ok(()),
TerminateChildError(error) => Err(error),
_ => unreachable!(),
}
}
pub async fn restart_child<T: Into<Dest>, I: Into<String>>(
supervisor: T,
child_id: I,
) -> Result<Option<Pid>, SupervisorError> {
use SupervisorMessage::*;
match Supervisor::call(supervisor, RestartChild(child_id.into()), None).await? {
RestartChildSuccess(pid) => Ok(pid),
RestartChildError(error) => Err(error),
_ => unreachable!(),
}
}
pub async fn delete_child<T: Into<Dest>, I: Into<String>>(
supervisor: T,
child_id: I,
) -> Result<(), SupervisorError> {
use SupervisorMessage::*;
match Supervisor::call(supervisor, DeleteChild(child_id.into()), None).await? {
DeleteChildSuccess => Ok(()),
DeleteChildError(error) => Err(error),
_ => unreachable!(),
}
}
pub async fn which_children<T: Into<Dest>>(
supervisor: T,
) -> Result<Vec<SupervisorChildInfo>, SupervisorError> {
use SupervisorMessage::*;
match Supervisor::call(supervisor, WhichChildren, None).await? {
WhichChildrenSuccess(info) => Ok(info),
_ => unreachable!(),
}
}
/// Starts every child in order and records its pid.
///
/// Temporary children whose start was ignored are dropped afterwards. Stops
/// at the first start failure and returns `failed_to_start_child`; children
/// that were already started are left running for the caller to clean up.
async fn start_children(&mut self) -> Result<(), ExitReason> {
    let mut ignored_temporaries: Vec<usize> = Vec::new();

    for index in 0..self.children.len() {
        let pid = match self.start_child_by_index(index).await {
            Ok(pid) => pid,
            Err(reason) => {
                #[cfg(feature = "tracing")]
                tracing::error!(reason = ?reason, child_id = ?self.children[index].spec.id, "Start error");
                #[cfg(not(feature = "tracing"))]
                let _ = reason;
                return Err(ExitReason::from("failed_to_start_child"));
            }
        };

        let child = &mut self.children[index];
        child.pid = pid;
        child.restarting = false;

        if pid.is_none() && child.is_temporary() {
            ignored_temporaries.push(index);
        }
    }

    // Pop from the back so the highest index is removed first and the
    // remaining recorded indices stay valid.
    while let Some(index) = ignored_temporaries.pop() {
        self.remove_child(index);
    }

    Ok(())
}
/// Removes the spec (and id) of the child named `child_id`.
///
/// # Errors
/// `NotFound` if there is no such child, `Restarting` if a retry is pending,
/// or `Running` if the child has a pid.
async fn delete_child_by_id(&mut self, child_id: String) -> Result<(), SupervisorError> {
    let Some(index) = self.find_child_id(&child_id) else {
        return Err(SupervisorError::NotFound);
    };

    let child = &self.children[index];
    match (child.restarting, child.pid.is_some()) {
        (true, _) => Err(SupervisorError::Restarting),
        (false, true) => Err(SupervisorError::Running),
        (false, false) => {
            self.remove_child(index);
            Ok(())
        }
    }
}
/// Shuts down the child named `child_id` (if it is running) using its
/// shutdown strategy.
///
/// Temporary children are never restarted, so their spec is removed as well.
/// Keeping it would let `restart_child` bring a temporary child back. Other
/// children keep their spec, so they can be restarted or deleted later.
///
/// # Errors
/// `NotFound` if there is no child with that id.
async fn terminate_child_by_id(&mut self, child_id: String) -> Result<(), SupervisorError> {
    let Some(index) = self.find_child_id(&child_id) else {
        return Err(SupervisorError::NotFound);
    };

    self.terminate_child_by_index(index).await;

    if self.children[index].is_temporary() {
        self.remove_child(index);
    }

    Ok(())
}
/// Starts again the child named `child_id`, which must not be running.
///
/// Returns the new pid, or `None` if the start was ignored.
///
/// # Errors
/// `NotFound`, `Restarting` (a retry is pending), `Running` (the child has a
/// pid), or `StartError` if the start function failed.
async fn restart_child_by_id(
    &mut self,
    child_id: String,
) -> Result<Option<Pid>, SupervisorError> {
    let index = self
        .find_child_id(&child_id)
        .ok_or(SupervisorError::NotFound)?;

    let current = &self.children[index];
    if current.restarting {
        return Err(SupervisorError::Restarting);
    }
    if current.pid.is_some() {
        return Err(SupervisorError::Running);
    }

    let pid = self
        .start_child_by_index(index)
        .await
        .map_err(SupervisorError::StartError)?;

    let child = &mut self.children[index];
    child.pid = pid;
    child.restarting = false;
    Ok(pid)
}
/// Shuts down every child that has a pid, newest first, using each child's
/// shutdown strategy, then drops the specs of all temporary children.
///
/// Shutdown failures are logged (with the `tracing` feature) and otherwise
/// ignored so that every child still gets a shutdown attempt.
async fn terminate_children(&mut self) {
    // Collected in descending order, so removing them in that order keeps
    // the not-yet-removed indices valid.
    let mut temporaries: Vec<usize> = Vec::new();

    for index in (0..self.children.len()).rev() {
        let child = &mut self.children[index];

        if child.is_temporary() {
            temporaries.push(index);
        }

        if let Some(pid) = child.pid.take() {
            let outcome = shutdown(pid, child.shutdown()).await;
            if let Err(reason) = outcome {
                #[cfg(feature = "tracing")]
                tracing::error!(reason = ?reason, child_pid = ?pid, "Shutdown error");
                #[cfg(not(feature = "tracing"))]
                let _ = reason;
            }
        }
    }

    for index in temporaries {
        self.remove_child(index);
    }
}
/// Shuts down the child at `index` (if it has a pid) using its shutdown
/// strategy, leaving its spec in place with no pid.
///
/// Any pending restart for the child is cancelled. This also applies when
/// the child has no pid: after a failed start during a multi-child restart,
/// the child is marked `restarting` with no pid, and terminating it must
/// still stop the queued retry from bringing it back.
async fn terminate_child_by_index(&mut self, index: usize) {
    let child = &mut self.children[index];
    child.restarting = false;

    let Some(pid) = child.pid.take() else {
        return;
    };

    // Best effort, but don't swallow the failure silently; this matches how
    // `terminate_children` reports shutdown errors.
    if let Err(reason) = shutdown(pid, child.shutdown()).await {
        #[cfg(feature = "tracing")]
        tracing::error!(reason = ?reason, child_pid = ?pid, "Shutdown error");
        #[cfg(not(feature = "tracing"))]
        let _ = reason;
    }
}
/// Starts all children. If any child fails to start, shuts down all children
/// (dropping temporary ones) and returns the failure.
async fn init_children(&mut self) -> Result<(), ExitReason> {
    match self.start_children().await {
        Ok(()) => Ok(()),
        Err(reason) => {
            self.terminate_children().await;
            Err(reason)
        }
    }
}
/// Handles the exit of the child whose pid is `pid`.
///
/// Exits from pids that are not tracked as children are ignored. Otherwise,
/// the child's restart type decides what happens:
///
/// - permanent: always restarted (per the supervision strategy);
/// - transient: restarted only if `reason` is neither normal nor
///   `"shutdown"`; on such a clean exit its spec is kept with no pid, so it
///   can still be restarted with `restart_child` or removed with `delete_child`;
/// - temporary: never restarted; its spec is removed.
///
/// # Errors
/// Returns `"shutdown"` when the restart intensity is exceeded, or when the
/// exit of a significant child triggers the configured auto-shutdown.
async fn restart_exited_child(
    &mut self,
    pid: Pid,
    reason: ExitReason,
) -> Result<(), ExitReason> {
    let Some(index) = self.find_child(pid) else {
        return Ok(());
    };

    let child = &self.children[index];

    if child.is_permanent() {
        #[cfg(feature = "tracing")]
        tracing::error!(reason = ?reason, child_id = ?child.spec.id, child_pid = ?child.pid, "Child terminated");
        if self.add_restart() {
            return Err(ExitReason::from("shutdown"));
        }
        self.restart(index).await;
        return Ok(());
    }

    if reason.is_normal() || reason == "shutdown" {
        // Only temporary children are forgotten on a clean exit. Previously a
        // transient child's spec was removed here too, so a child that simply
        // finished its work disappeared from the supervisor entirely.
        let exited = if child.is_temporary() {
            self.remove_child(index)
        } else {
            let child = &mut self.children[index];
            child.pid = None;
            child.restarting = false;
            child.clone()
        };

        return if self.check_auto_shutdown(exited) {
            Err(ExitReason::from("shutdown"))
        } else {
            Ok(())
        };
    }

    if child.is_transient() {
        #[cfg(feature = "tracing")]
        tracing::error!(reason = ?reason, child_id = ?child.spec.id, child_pid = ?child.pid, "Child terminated");
        if self.add_restart() {
            return Err(ExitReason::from("shutdown"));
        }
        self.restart(index).await;
        return Ok(());
    }

    if child.is_temporary() {
        #[cfg(feature = "tracing")]
        tracing::error!(reason = ?reason, child_id = ?child.spec.id, child_pid = ?child.pid, "Child terminated");
        let child = self.remove_child(index);
        if self.check_auto_shutdown(child) {
            return Err(ExitReason::from("shutdown"));
        }
    }

    Ok(())
}
async fn restart(&mut self, index: usize) {
use SupervisorMessage::*;
match self.strategy {
SupervisionStrategy::OneForOne => {
match self.start_child_by_index(index).await {
Ok(pid) => {
let child = &mut self.children[index];
child.pid = pid;
child.restarting = false;
}
Err(reason) => {
let id = self.children[index].id();
#[cfg(feature = "tracing")]
tracing::error!(reason = ?reason, child_id = ?id, child_pid = ?self.children[index].pid, "Start error");
self.children[index].restarting = true;
Supervisor::cast(Process::current(), TryAgainRestartId(id));
}
};
}
SupervisionStrategy::RestForOne => {
if let Some((index, reason)) = self.restart_multiple_children(index, false).await {
let id = self.children[index].id();
#[cfg(feature = "tracing")]
tracing::error!(reason = ?reason, child_id = ?id, child_pid = ?self.children[index].pid, "Start error");
self.children[index].restarting = true;
Supervisor::cast(Process::current(), TryAgainRestartId(id));
}
}
SupervisionStrategy::OneForAll => {
if let Some((index, reason)) = self.restart_multiple_children(index, true).await {
let id = self.children[index].id();
#[cfg(feature = "tracing")]
tracing::error!(reason = ?reason, child_id = ?id, child_pid = ?self.children[index].pid, "Start error");
self.children[index].restarting = true;
Supervisor::cast(Process::current(), TryAgainRestartId(id));
}
}
}
}
/// Restarts a group of children after the child at `index` exited.
///
/// With `all == true` the group is every child (one-for-all). Otherwise it is
/// the child at `index` plus every child after it (rest-for-one). Steps:
///
/// 1. Shut down every group member except `index` (which already exited),
///    newest first.
/// 2. Drop temporary members instead of restarting them, since temporary
///    children are never restarted.
/// 3. Start the remaining members again in order.
///
/// Returns the index (after temporary children were dropped) and the reason
/// of the first child that failed to start, or `None` if all started.
async fn restart_multiple_children(
    &mut self,
    index: usize,
    all: bool,
) -> Option<(usize, ExitReason)> {
    let first = if all { 0 } else { index };

    // Reverse start order, like `terminate_children`: later children (which
    // may depend on earlier ones) go down first.
    for tindex in (first..self.children.len()).rev() {
        if tindex != index {
            self.terminate_child_by_index(tindex).await;
        }
    }

    // Previously temporary children were started again here, contradicting
    // how every other path treats them. Walk backwards so each removal
    // leaves the indices still to be visited unchanged.
    for tindex in (first..self.children.len()).rev() {
        if self.children[tindex].is_temporary() {
            self.remove_child(tindex);
        }
    }

    // The group always extends to the end of the list, so after the removals
    // it is exactly `first..len` over the surviving children.
    for sindex in first..self.children.len() {
        match self.start_child_by_index(sindex).await {
            Ok(pid) => {
                let child = &mut self.children[sindex];
                child.pid = pid;
                child.restarting = false;
            }
            Err(reason) => return Some((sindex, reason)),
        }
    }

    None
}
/// Retries a restart that previously failed for the child at `index`.
///
/// The retry may be stale: the child may have been terminated or restarted
/// since the retry was queued, which clears `restarting`. Stale retries are
/// dropped before touching the restart intensity. Previously they were
/// counted first, so a stale retry could push the supervisor over
/// `max_restarts` and shut it down.
///
/// # Errors
/// Returns `"shutdown"` if counting this restart exceeds the restart intensity.
async fn try_again_restart(&mut self, index: usize) -> Result<(), ExitReason> {
    if !self.children[index].restarting {
        return Ok(());
    }

    if self.add_restart() {
        return Err(ExitReason::from("shutdown"));
    }

    self.restart(index).await;
    Ok(())
}
/// Runs the start function of the child at `index` and awaits the result.
///
/// Returns `Ok(Some(pid))` on success and `Ok(None)` when the start reported
/// `ignore`; any other failure is returned as `Err`. The pid is not recorded
/// here; callers update the child themselves.
///
/// # Panics
/// Panics if the child's spec has no start function.
async fn start_child_by_index(&mut self, index: usize) -> Result<Option<Pid>, ExitReason> {
    let child = &mut self.children[index];
    let outcome = Pin::from(child.spec.start.as_ref().unwrap()()).await;

    match outcome {
        Ok(pid) => {
            #[cfg(feature = "tracing")]
            tracing::info!(child_id = ?child.spec.id, child_pid = ?pid, "Started child");
            Ok(Some(pid))
        }
        Err(reason) if reason.is_ignore() => {
            #[cfg(feature = "tracing")]
            tracing::info!(child_id = ?child.spec.id, child_pid = ?None::<Pid>, "Started child");
            Ok(None)
        }
        Err(reason) => Err(reason),
    }
}
/// Adds `spec` as a new child at the end of the start order and starts it.
///
/// Returns the child's pid, or `None` if its start was ignored. An ignored
/// temporary child is not kept.
///
/// # Errors
/// - `AlreadyStarted` if a child with the same id exists and has a pid.
/// - `AlreadyPresent` if a child with the same id exists without a pid.
/// - `StartError` if the start function failed. The spec is then discarded
///   (as with OTP's `start_child`), so the id can be used again.
async fn start_new_child(&mut self, spec: ChildSpec) -> Result<Option<Pid>, SupervisorError> {
    if let Some(existing) = self.find_child_id(&spec.id) {
        return Err(if self.children[existing].pid.is_some() {
            SupervisorError::AlreadyStarted
        } else {
            SupervisorError::AlreadyPresent
        });
    }

    self.identifiers.insert(spec.id.clone());
    self.children.push(SupervisedChild {
        spec,
        pid: None,
        restarting: false,
    });
    let index = self.children.len() - 1;

    match self.start_child_by_index(index).await {
        Ok(pid) => {
            let child = &mut self.children[index];
            child.pid = pid;
            child.restarting = false;

            // Go through `remove_child` so the id is released as well.
            // Removing from `children` alone left a dangling id in
            // `identifiers`, and a later start with that id then failed on
            // the missing child.
            if pid.is_none() && child.is_temporary() {
                self.remove_child(index);
            }

            Ok(pid)
        }
        Err(reason) => {
            self.remove_child(index);
            Err(SupervisorError::StartError(reason))
        }
    }
}
/// Decides whether the exit of `child` should shut the supervisor down,
/// according to `auto_shutdown`.
///
/// Only significant children can trigger an auto-shutdown. With
/// `AnySignificant`, any significant exit does. With the remaining
/// non-`Never` mode (presumably `AllSignificant`), the supervisor shuts down
/// only once no significant child is still running.
///
/// `child` is expected to be already removed from `self.children`, or no
/// longer have a pid there.
fn check_auto_shutdown(&mut self, child: SupervisedChild) -> bool {
    if matches!(self.auto_shutdown, AutoShutdown::Never) || !child.spec.significant {
        return false;
    }

    if matches!(self.auto_shutdown, AutoShutdown::AnySignificant) {
        return true;
    }

    // Shut down only when the last running significant child is gone. This
    // check was previously inverted: it shut down while significant children
    // were still running, and never once the last one had exited.
    !self
        .children
        .iter()
        .any(|other| other.spec.significant && other.pid.is_some())
}
/// Records a restart and reports whether the restart intensity is exceeded.
///
/// Restarts older than `max_duration` are forgotten first. Returns `true`
/// when more than `max_restarts` restarts (including this one) fall within
/// that window, meaning the supervisor should give up and shut down.
fn add_restart(&mut self) -> bool {
    let now = Instant::now();
    let window = self.max_duration;

    // `now - window` panics when the result is not representable as an
    // `Instant` (e.g. a very large `max_duration`). Comparing elapsed time
    // keeps the same condition (`restart >= now - window`) without that risk.
    self.restarts
        .retain(|restart| now.duration_since(*restart) <= window);
    self.restarts.push(now);

    if self.restarts.len() > self.max_restarts {
        #[cfg(feature = "tracing")]
        tracing::error!(restarts = ?self.restarts.len(), threshold = ?self.max_duration, "Reached max restart intensity");
        return true;
    }

    false
}
/// Snapshots every child (id, pid, type, pending-restart flag) in start order.
fn which_children_info(&mut self) -> Vec<SupervisorChildInfo> {
    self.children
        .iter()
        .map(|entry| SupervisorChildInfo {
            id: entry.spec.id.clone(),
            child: entry.pid,
            child_type: entry.spec.child_type,
            restarting: entry.restarting,
        })
        .collect()
}
/// Counts specs, children with a pid, supervisors, and workers.
fn count_all_children(&mut self) -> SupervisorCounts {
    let specs = self.children.len();
    let active = self
        .children
        .iter()
        .filter(|entry| entry.pid.is_some())
        .count();
    let supervisors = self
        .children
        .iter()
        .filter(|entry| matches!(entry.spec.child_type, ChildType::Supervisor))
        .count();

    SupervisorCounts {
        specs,
        active,
        supervisors,
        workers: specs - supervisors,
    }
}
/// Removes the child at `index` together with its id and returns it.
fn remove_child(&mut self, index: usize) -> SupervisedChild {
    let removed = self.children.remove(index);
    self.identifiers.remove(removed.spec.id.as_str());
    removed
}
/// Returns the index of the child whose recorded pid is `pid`, if any.
fn find_child(&mut self, pid: Pid) -> Option<usize> {
    let wanted = Some(pid);
    self.children.iter().position(|entry| entry.pid == wanted)
}
/// Returns the index of the child whose spec id is `id`, if any.
fn find_child_id(&mut self, id: &str) -> Option<usize> {
    self.children
        .iter()
        .enumerate()
        .find(|(_, entry)| entry.spec.id == id)
        .map(|(index, _)| index)
}
}
impl SupervisedChild {
    /// `true` if the spec's restart type is `Restart::Permanent`.
    pub const fn is_permanent(&self) -> bool {
        matches!(self.spec.restart, Restart::Permanent)
    }

    /// `true` if the spec's restart type is `Restart::Transient`.
    pub const fn is_transient(&self) -> bool {
        matches!(self.spec.restart, Restart::Transient)
    }

    /// `true` if the spec's restart type is `Restart::Temporary`.
    pub const fn is_temporary(&self) -> bool {
        matches!(self.spec.restart, Restart::Temporary)
    }

    /// Returns an owned copy of the child's id.
    pub fn id(&self) -> String {
        self.spec.id.to_owned()
    }

    /// Returns the spec's shutdown strategy. If none is set, workers get a
    /// 5 second timeout and supervisors get `Shutdown::Infinity`.
    pub const fn shutdown(&self) -> Shutdown {
        if let Some(explicit) = self.spec.shutdown {
            return explicit;
        }

        match self.spec.child_type {
            ChildType::Supervisor => Shutdown::Infinity,
            ChildType::Worker => Shutdown::Duration(Duration::from_secs(5)),
        }
    }
}
impl Default for Supervisor {
fn default() -> Self {
Self::new()
}
}
impl GenServer for Supervisor {
    type Message = SupervisorMessage;

    /// Sets `TRAP_EXIT` (child exits are then handled in `handle_info` as
    /// `SystemMessage::Exit`) and starts every child. If a child fails to
    /// start, the children started so far are shut down and the error is
    /// returned.
    async fn init(&mut self) -> Result<(), ExitReason> {
        Process::set_flags(ProcessFlags::TRAP_EXIT);
        self.init_children().await
    }

    /// Shuts down all children before the supervisor exits.
    async fn terminate(&mut self, _reason: ExitReason) {
        self.terminate_children().await;
    }

    /// Handles the self-sent `TryAgainRestart*` casts by retrying the pending
    /// restart of the matching child; unknown children are ignored.
    async fn handle_cast(&mut self, message: Self::Message) -> Result<(), ExitReason> {
        use SupervisorMessage::*;

        let target = match message {
            TryAgainRestartPid(pid) => self.find_child(pid),
            TryAgainRestartId(id) => self.find_child_id(&id),
            _ => unreachable!(),
        };

        match target {
            Some(index) => self.try_again_restart(index).await,
            None => Ok(()),
        }
    }

    /// Serves the client requests and always replies with the matching
    /// `*Success` / `*Error` message.
    async fn handle_call(
        &mut self,
        message: Self::Message,
        _from: From,
    ) -> Result<Option<Self::Message>, ExitReason> {
        use SupervisorMessage::*;

        let reply = match message {
            CountChildren => CountChildrenSuccess(self.count_all_children()),
            StartChild(spec) => self
                .start_new_child(spec.into_inner())
                .await
                .map_or_else(StartChildError, StartChildSuccess),
            TerminateChild(child_id) => self
                .terminate_child_by_id(child_id)
                .await
                .map_or_else(TerminateChildError, |()| TerminateChildSuccess),
            RestartChild(child_id) => self
                .restart_child_by_id(child_id)
                .await
                .map_or_else(RestartChildError, RestartChildSuccess),
            DeleteChild(child_id) => self
                .delete_child_by_id(child_id)
                .await
                .map_or_else(DeleteChildError, |()| DeleteChildSuccess),
            WhichChildren => WhichChildrenSuccess(self.which_children_info()),
            _ => unreachable!(),
        };

        Ok(Some(reply))
    }

    /// Reacts to exit signals by applying the restart rules to the exited
    /// child; every other info message is ignored.
    async fn handle_info(&mut self, info: Message<Self::Message>) -> Result<(), ExitReason> {
        if let Message::System(SystemMessage::Exit(pid, reason)) = info {
            return self.restart_exited_child(pid, reason).await;
        }
        Ok(())
    }
}
impl std::convert::From<CallError> for SupervisorError {
fn from(value: CallError) -> Self {
Self::CallError(value)
}
}
/// Monitors `pid`, then shuts it down with the given strategy: wait
/// indefinitely, wait up to a timeout, or kill immediately.
async fn shutdown(pid: Pid, shutdown: Shutdown) -> Result<(), ExitReason> {
    let reference = Process::monitor(pid);

    match shutdown {
        Shutdown::Infinity => shutdown_infinity(pid, reference).await,
        Shutdown::Duration(timeout) => shutdown_timeout(pid, reference, timeout).await,
        Shutdown::BrutalKill => shutdown_brutal_kill(pid, reference).await,
    }
}