use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::{Context, ExitReason, Pid, ProcessHandle, Received, Runtime};
/// Restart policy a [`Supervisor`] applies when one of its monitored children goes down.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Strategy {
/// Restart only the child that went down; its siblings are left alone.
OneForOne,
/// Terminate every surviving child, then start all children again.
OneForAll,
/// Terminate the children added after the one that went down, then restart
/// that child and all of the later ones. Earlier children are left alone.
RestForOne,
}
/// Shareable child factory: called with the runtime to produce a child's
/// `ProcessHandle`, both for the initial start and for every restart.
type ChildFn = Arc<dyn Fn(&Runtime) -> ProcessHandle + Send + Sync>;
/// Builder describing a supervisor: its restart strategy, its children and its
/// restart budget. Nothing runs until [`Supervisor::start`] is called.
pub struct Supervisor {
/// Runtime the supervisor process and its children are spawned on.
rt: Runtime,
/// Which children are restarted when one of them goes down.
strategy: Strategy,
/// Child factories, in the order they were added (the order used for start-up
/// and for `Strategy::RestForOne`).
children: Vec<ChildFn>,
/// Number of restarts tolerated before the supervisor gives up; `0` disables
/// the limit.
max_restarts: u32,
/// `Some(window)` counts restarts inside a sliding time window; `None` counts
/// them over the supervisor's whole lifetime.
within: Option<Duration>,
}
impl Runtime {
    /// Begins building a [`Supervisor`] bound to a clone of this runtime.
    ///
    /// The builder starts with no children and a default restart budget of
    /// at most 3 restarts within a 5 second window.
    pub fn supervisor(&self, strategy: Strategy) -> Supervisor {
        let rt = self.clone();
        let within = Some(Duration::from_secs(5));
        Supervisor {
            rt,
            strategy,
            children: Vec::new(),
            max_restarts: 3,
            within,
        }
    }
}
impl Supervisor {
pub fn child<F>(mut self, start: F) -> Self
where
F: Fn(&Runtime) -> ProcessHandle + Send + Sync + 'static,
{
self.children.push(Arc::new(start));
self
}
pub fn max_restarts(mut self, n: u32) -> Self {
self.max_restarts = n;
self
}
pub fn within(mut self, window: Duration) -> Self {
self.within = Some(window);
self
}
pub fn over_lifetime(mut self) -> Self {
self.within = None;
self
}
pub fn start(self) -> ProcessHandle {
let Supervisor {
rt,
strategy,
children,
max_restarts,
within,
} = self;
let sup_rt = rt.clone();
rt.spawn(move |mut ctx| async move {
let me = ctx.pid();
let mut pids: Vec<Pid> = (0..children.len())
.map(|i| start_child(&sup_rt, me, &children, i))
.collect();
let mut lifetime = 0u32;
let mut window: Vec<Instant> = Vec::new();
loop {
let dead = match ctx.recv().await {
Received::Down { pid, .. } => pid,
Received::Exit { .. } => {
for &p in &pids {
sup_rt.kill(p);
}
return;
}
_ => continue,
};
let Some(index) = pids.iter().position(|&p| p == dead) else {
continue;
};
if over_budget(&mut window, &mut lifetime, within, max_restarts) {
for &p in &pids {
sup_rt.kill(p);
}
sup_rt.exit(me, ExitReason::Crashed);
return;
}
match strategy {
Strategy::OneForOne => {
pids[index] = start_child(&sup_rt, me, &children, index);
}
Strategy::OneForAll => {
let survivors: Vec<Pid> = pids
.iter()
.enumerate()
.filter(|&(j, _)| j != index)
.map(|(_, &p)| p)
.collect();
terminate(&sup_rt, &mut ctx, survivors).await;
pids = (0..children.len())
.map(|i| start_child(&sup_rt, me, &children, i))
.collect();
}
Strategy::RestForOne => {
let later: Vec<Pid> = pids[index + 1..].to_vec();
terminate(&sup_rt, &mut ctx, later).await;
for j in index..pids.len() {
pids[j] = start_child(&sup_rt, me, &children, j);
}
}
}
}
})
}
}
fn start_child(rt: &Runtime, me: Pid, children: &[ChildFn], i: usize) -> Pid {
let pid = (children[i])(rt).pid();
rt.monitor(me, pid);
pid
}
async fn terminate(rt: &Runtime, ctx: &mut Context, mut targets: Vec<Pid>) {
for &p in &targets {
rt.kill(p);
}
while !targets.is_empty() {
if let Received::Down { pid, .. } = ctx.recv().await {
targets.retain(|&p| p != pid);
}
}
}
/// Records one restart and reports whether the restart budget is now exceeded.
///
/// * `window` – timestamps of recent restarts; used when `within` is `Some`.
///   The current instant is appended and entries older than the span are
///   dropped.
/// * `lifetime` – running restart count; used when `within` is `None`.
/// * `within` – `Some(span)` limits restarts per sliding window of that
///   length, `None` limits restarts over the whole lifetime.
/// * `max_restarts` – number of restarts tolerated; `0` disables the limit.
///
/// Returns `true` when this restart is one more than `max_restarts` allows.
fn over_budget(
    window: &mut Vec<Instant>,
    lifetime: &mut u32,
    within: Option<Duration>,
    max_restarts: u32,
) -> bool {
    let limited = max_restarts != 0;
    match within {
        Some(span) => {
            let now = Instant::now();
            window.push(now);
            window.retain(|&t| now.duration_since(t) <= span);
            // Compare in u64: widening both sides is lossless, whereas
            // narrowing the length to u32 could silently truncate it.
            limited && window.len() as u64 > u64::from(max_restarts)
        }
        None => {
            // Decide before counting and saturate the counter, so a very
            // long-lived supervisor (unlimited budget, or a budget of
            // `u32::MAX`) can neither overflow-panic nor wrap around.
            let exceeded = limited && *lifetime >= max_restarts;
            *lifetime = lifetime.saturating_add(1);
            exceeded
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Child factory: every process it starts registers itself under `name`
    /// and then parks forever, so it only goes away when it is killed.
    fn named(name: &'static str) -> impl Fn(&Runtime) -> ProcessHandle + Send + Sync + 'static {
        move |runtime: &Runtime| {
            let registry = runtime.clone();
            let label = name.to_string();
            runtime.spawn(move |ctx| async move {
                registry.register(label, ctx.pid());
                std::future::pending::<()>().await
            })
        }
    }

    /// Polls the registry until some process is registered under `name`.
    async fn wait_for(rt: &Runtime, name: &str) -> Pid {
        loop {
            match rt.whereis(name) {
                Some(pid) => break pid,
                None => tokio::task::yield_now().await,
            }
        }
    }

    /// Polls the registry until `name` resolves to a pid other than `old`.
    async fn wait_for_change(rt: &Runtime, name: &str, old: Pid) -> Pid {
        loop {
            match rt.whereis(name) {
                Some(pid) if pid != old => break pid,
                _ => tokio::task::yield_now().await,
            }
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn one_for_one_restarts_only_the_failed_child() {
        let rt = Runtime::new();
        let builder = rt
            .supervisor(Strategy::OneForOne)
            .child(named("c0"))
            .child(named("c1"));
        let sup = builder.start();

        let first = wait_for(&rt, "c0").await;
        let second = wait_for(&rt, "c1").await;

        rt.kill(first);
        let first_restarted = wait_for_change(&rt, "c0", first).await;

        assert_ne!(
            first_restarted, first,
            "the failed child is restarted with a fresh pid"
        );
        assert_eq!(
            rt.whereis("c1"),
            Some(second),
            "a sibling is left untouched"
        );
        sup.kill();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn one_for_all_restarts_every_child() {
        let rt = Runtime::new();
        let builder = rt
            .supervisor(Strategy::OneForAll)
            .child(named("a0"))
            .child(named("a1"));
        let sup = builder.start();

        let first = wait_for(&rt, "a0").await;
        let second = wait_for(&rt, "a1").await;

        rt.kill(first);
        let first_restarted = wait_for_change(&rt, "a0", first).await;
        let second_restarted = wait_for_change(&rt, "a1", second).await;

        assert_ne!(first_restarted, first);
        assert_ne!(
            second_restarted, second,
            "every child is restarted, not just the one that died"
        );
        sup.kill();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn rest_for_one_restarts_the_failed_and_later_children() {
        let rt = Runtime::new();
        let builder = rt
            .supervisor(Strategy::RestForOne)
            .child(named("r0"))
            .child(named("r1"))
            .child(named("r2"));
        let sup = builder.start();

        let head = wait_for(&rt, "r0").await;
        let middle = wait_for(&rt, "r1").await;
        let tail = wait_for(&rt, "r2").await;

        rt.kill(middle);
        let middle_restarted = wait_for_change(&rt, "r1", middle).await;
        let tail_restarted = wait_for_change(&rt, "r2", tail).await;

        assert_eq!(
            rt.whereis("r0"),
            Some(head),
            "children before the failure stay"
        );
        assert_ne!(middle_restarted, middle, "the failed child restarts");
        assert_ne!(tail_restarted, tail, "and every child after it restarts");
        sup.kill();
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn restart_intensity_stops_the_supervisor() {
        let rt = Runtime::new();
        let crashes = Arc::new(AtomicU32::new(0));

        // Each start of this child bumps the shared counter and then panics.
        let tally = crashes.clone();
        let crasher = move |r: &Runtime| {
            let tally = tally.clone();
            r.spawn(move |_ctx| async move {
                tally.fetch_add(1, Ordering::Relaxed);
                panic!("boom");
            })
        };

        // Bound to a named variable so the handle lives until the test ends.
        let _sup = rt
            .supervisor(Strategy::OneForOne)
            .max_restarts(3)
            .within(Duration::from_secs(60))
            .child(crasher)
            .start();

        // Spin until the runtime reports no processes at all.
        while rt.process_count() != 0 {
            tokio::task::yield_now().await;
        }

        let observed = crashes.load(Ordering::Relaxed);
        assert_eq!(
            observed,
            4,
            "initial start + 3 restarts, then the supervisor gives up"
        );
    }
}