use std::time::Duration;
use crate::command::Command;
use crate::error::Result;
use crate::result::ProcessResult;
use crate::runner::{JobRunner, ProcessRunner};
/// When the [`Supervisor`] restarts a finished incarnation.
///
/// Runner errors (e.g. a spawn failure) are retried under every policy except
/// [`RestartPolicy::Never`]. With the `cancellation` feature, a cancelled
/// incarnation is always terminal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RestartPolicy {
/// Restart after every incarnation, clean or not, until the stop predicate
/// matches or the restart budget is spent.
Always,
/// Restart only when the incarnation did not exit with code 0, including runs
/// that report no exit code at all (such as timeouts).
OnCrash,
/// Never restart: the first result (or runner error) is final.
Never,
}
/// Why a supervision run ended with a [`SupervisionOutcome`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum StopReason {
/// The `stop_when` predicate returned `true` for the latest result.
Predicate,
/// The restart policy did not ask for a restart (e.g. a clean exit under
/// `OnCrash`, or any exit under `Never`).
PolicySatisfied,
/// A restart was wanted, but the `max_restarts` budget was already spent.
RestartsExhausted,
}
/// Summary of a supervision run that ended on a completed incarnation rather
/// than on a runner error.
#[derive(Debug)]
#[non_exhaustive]
pub struct SupervisionOutcome {
/// Result of the last incarnation, i.e. the one that ended supervision.
pub final_result: ProcessResult<String>,
/// Restarts performed; the command was launched `restarts + 1` times.
pub restarts: u32,
/// Why supervision stopped.
pub stopped: StopReason,
/// Number of failure-storm pauses taken during the run.
pub storm_pauses: u32,
}
/// Predicate consulted after each completed incarnation; returning `true`
/// stops supervision with [`StopReason::Predicate`].
type StopPredicate = dyn Fn(&ProcessResult<String>) -> bool + Send + Sync;

/// Restarts a [`Command`] according to a [`RestartPolicy`], with exponential
/// backoff, optional jitter, an optional restart budget, an optional stop
/// predicate, and an optional failure-storm pause.
///
/// Configure with [`Supervisor::new`] and the builder methods, then drive it
/// with [`Supervisor::run`].
pub struct Supervisor<R: ProcessRunner = JobRunner> {
    /// The command launched for every incarnation.
    command: Command,
    /// Executes `command` and collects its result.
    runner: R,
    /// Whether a finished incarnation is restarted.
    policy: RestartPolicy,
    /// Restart budget; `None` means unlimited.
    max_restarts: Option<u32>,
    /// Delay before the first restart; zero disables backoff.
    backoff_base: Duration,
    /// Growth factor per restart, sanitised to `>= 1.0` when `run` starts.
    backoff_factor: f64,
    /// Upper bound for a single pre-jitter backoff delay.
    max_backoff: Duration,
    /// Whether backoff delays and storm pauses are randomly scaled.
    jitter: bool,
    /// Half-life of the failure-storm score.
    failure_decay: Duration,
    /// Storm score above which a storm pause is taken.
    failure_threshold: f64,
    /// Pause taken when a failure storm trips; `None` disables the guard.
    storm_pause: Option<Duration>,
    /// Optional early-stop predicate.
    stop_when: Option<Box<StopPredicate>>,
}
impl<R: ProcessRunner> std::fmt::Debug for Supervisor<R> {
    /// Prints the supervisor's tuning knobs. The command, the runner and the
    /// predicate closure are not printed; only whether a predicate is installed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructure: a newly added field fails to compile here until
        // it is either printed or explicitly skipped.
        let Self {
            command: _,
            runner: _,
            policy,
            max_restarts,
            backoff_base,
            backoff_factor,
            max_backoff,
            jitter,
            failure_decay,
            failure_threshold,
            storm_pause,
            stop_when,
        } = self;
        let has_stop_when = stop_when.is_some();
        let mut builder = f.debug_struct("Supervisor");
        builder.field("policy", policy);
        builder.field("max_restarts", max_restarts);
        builder.field("backoff_base", backoff_base);
        builder.field("backoff_factor", backoff_factor);
        builder.field("max_backoff", max_backoff);
        builder.field("jitter", jitter);
        builder.field("failure_decay", failure_decay);
        builder.field("failure_threshold", failure_threshold);
        builder.field("storm_pause", storm_pause);
        builder.field("has_stop_when", &has_stop_when);
        builder.finish_non_exhaustive()
    }
}
impl Supervisor<JobRunner> {
    /// Creates a supervisor for `command`, launched through a fresh [`JobRunner`].
    ///
    /// Defaults:
    /// - restart on crash;
    /// - unlimited restarts;
    /// - 200 ms backoff base, doubling per restart, capped at 30 s, with jitter on;
    /// - no stop predicate;
    /// - no failure-storm pause.
    ///
    /// Once a storm pause is configured, the storm score has a 30 s half-life and
    /// trips above 5.0.
    pub fn new(command: Command) -> Self {
        const BACKOFF_BASE: Duration = Duration::from_millis(200);
        const BACKOFF_FACTOR: f64 = 2.0;
        const MAX_BACKOFF: Duration = Duration::from_secs(30);
        const FAILURE_DECAY: Duration = Duration::from_secs(30);
        const FAILURE_THRESHOLD: f64 = 5.0;

        let runner = JobRunner::new();
        Self {
            command,
            runner,
            policy: RestartPolicy::OnCrash,
            max_restarts: None,
            backoff_base: BACKOFF_BASE,
            backoff_factor: BACKOFF_FACTOR,
            max_backoff: MAX_BACKOFF,
            jitter: true,
            failure_decay: FAILURE_DECAY,
            failure_threshold: FAILURE_THRESHOLD,
            storm_pause: None,
            stop_when: None,
        }
    }
}
impl<R: ProcessRunner> Supervisor<R> {
    /// Replaces the runner used to launch each incarnation.
    ///
    /// Every other setting (command, policy, backoff, storm guard, stop
    /// predicate) is kept unchanged.
    #[must_use]
    pub fn with_runner<R2: ProcessRunner>(self, runner: R2) -> Supervisor<R2> {
        Supervisor {
            command: self.command,
            runner,
            policy: self.policy,
            max_restarts: self.max_restarts,
            backoff_base: self.backoff_base,
            backoff_factor: self.backoff_factor,
            max_backoff: self.max_backoff,
            jitter: self.jitter,
            failure_decay: self.failure_decay,
            failure_threshold: self.failure_threshold,
            storm_pause: self.storm_pause,
            stop_when: self.stop_when,
        }
    }

    /// Sets when a finished incarnation is restarted (see [`RestartPolicy`]).
    #[must_use]
    pub fn restart(mut self, policy: RestartPolicy) -> Self {
        self.policy = policy;
        self
    }

    /// Caps the number of restarts.
    ///
    /// `0` runs the command exactly once. Without a cap, restarts continue for
    /// as long as the policy asks for them.
    #[must_use]
    pub fn max_restarts(mut self, n: u32) -> Self {
        self.max_restarts = Some(n);
        self
    }

    /// Configures exponential backoff.
    ///
    /// Before restart number `k + 1` the supervisor waits `base * factor^k`,
    /// capped by [`max_backoff`](Self::max_backoff) and then jittered if
    /// enabled. A zero `base` disables the delay. When [`run`](Self::run)
    /// starts, a `factor` below `1.0` or one that is not finite is treated as
    /// `1.0`, giving a constant delay.
    #[must_use]
    pub fn backoff(mut self, base: Duration, factor: f64) -> Self {
        self.backoff_base = base;
        self.backoff_factor = factor;
        self
    }

    /// Caps a single backoff delay.
    ///
    /// Jitter is applied after the cap, so a jittered delay can reach up to
    /// about 1.5x this value.
    #[must_use]
    pub fn max_backoff(mut self, cap: Duration) -> Self {
        self.max_backoff = cap;
        self
    }

    /// Enables or disables jitter.
    ///
    /// When enabled, each backoff delay and storm pause is scaled by a random
    /// factor in `[0.5, 1.5)`.
    #[must_use]
    pub fn jitter(mut self, enabled: bool) -> Self {
        self.jitter = enabled;
        self
    }

    /// Enables the failure-storm guard.
    ///
    /// When the decayed failure score exceeds
    /// [`failure_threshold`](Self::failure_threshold), the supervisor sleeps
    /// for `pause` (jittered if enabled) before restarting, then resets the
    /// score.
    #[must_use]
    pub fn storm_pause(mut self, pause: Duration) -> Self {
        self.storm_pause = Some(pause);
        self
    }

    /// Sets the half-life of the failure-storm score.
    ///
    /// A zero half-life means only the latest failure counts.
    #[must_use]
    pub fn failure_decay(mut self, decay: Duration) -> Self {
        self.failure_decay = decay;
        self
    }

    /// Sets the storm score above which a storm pause is taken.
    ///
    /// Each failure adds `1.0` to the decayed score.
    #[must_use]
    pub fn failure_threshold(mut self, threshold: f64) -> Self {
        self.failure_threshold = threshold;
        self
    }

    /// Installs a predicate checked after every completed incarnation.
    ///
    /// The predicate runs before the restart policy is consulted. When it
    /// returns `true`, supervision stops with [`StopReason::Predicate`]. It is
    /// not consulted for runner errors.
    #[must_use]
    pub fn stop_when(
        mut self,
        predicate: impl Fn(&ProcessResult<String>) -> bool + Send + Sync + 'static,
    ) -> Self {
        self.stop_when = Some(Box::new(predicate));
        self
    }

    /// Runs the command, restarting it per the configured policy, until supervision stops.
    ///
    /// After each completed incarnation, the stop predicate is checked first,
    /// then the restart policy, then the restart budget. Before the backoff
    /// sleep, failed incarnations (non-zero or missing exit code) and runner
    /// errors feed the failure-storm guard. Clean restarts under
    /// [`RestartPolicy::Always`] do not.
    ///
    /// # Errors
    ///
    /// Returns the runner's error when it will not be retried. That happens
    /// under [`RestartPolicy::Never`], once the restart budget is spent, or,
    /// with the `cancellation` feature, when the error is `Error::Cancelled`.
    pub async fn run(self) -> Result<SupervisionOutcome> {
        // Sanitise once: NaN/inf or shrinking factors collapse to a constant delay.
        let factor = if self.backoff_factor.is_finite() {
            self.backoff_factor.max(1.0)
        } else {
            1.0
        };
        let mut restarts: u32 = 0;
        let mut storm = StormState::new();
        loop {
            // Each arm either returns (supervision ends) or yields whether this
            // incarnation counts as a failure for the storm guard.
            let failed = match self.runner.output(&self.command).await {
                Ok(result) => {
                    if let Some(predicate) = &self.stop_when
                        && predicate(&result)
                    {
                        return Ok(self.outcome(result, restarts, &storm, StopReason::Predicate));
                    }
                    let crashed = result.code() != Some(0);
                    let wants_restart = match self.policy {
                        RestartPolicy::Always => true,
                        RestartPolicy::OnCrash => crashed,
                        RestartPolicy::Never => false,
                    };
                    if !wants_restart {
                        return Ok(self.outcome(
                            result,
                            restarts,
                            &storm,
                            StopReason::PolicySatisfied,
                        ));
                    }
                    if self.budget_exhausted(restarts) {
                        return Ok(self.outcome(
                            result,
                            restarts,
                            &storm,
                            StopReason::RestartsExhausted,
                        ));
                    }
                    crashed
                }
                Err(err) => {
                    #[cfg(feature = "cancellation")]
                    if matches!(err, crate::Error::Cancelled { .. }) {
                        return Err(err);
                    }
                    let wants_restart = !matches!(self.policy, RestartPolicy::Never);
                    if !wants_restart || self.budget_exhausted(restarts) {
                        return Err(err);
                    }
                    // A runner error always counts as a failure.
                    true
                }
            };
            if failed {
                self.storm_gate(&mut storm).await;
            }
            self.sleep_backoff(restarts, factor).await;
            // Saturate: an unbounded `Always` loop must not overflow-panic.
            restarts = restarts.saturating_add(1);
        }
    }

    /// Whether `restarts` has already reached the configured `max_restarts` cap.
    fn budget_exhausted(&self, restarts: u32) -> bool {
        self.max_restarts.is_some_and(|max| restarts >= max)
    }

    /// Packages the final result with the run's counters.
    fn outcome(
        &self,
        final_result: ProcessResult<String>,
        restarts: u32,
        storm: &StormState,
        stopped: StopReason,
    ) -> SupervisionOutcome {
        SupervisionOutcome {
            final_result,
            restarts,
            stopped,
            storm_pauses: storm.pauses,
        }
    }

    /// Feeds one failure into the storm score.
    ///
    /// If the decayed score now exceeds the threshold, this sleeps for the
    /// (optionally jittered) storm pause and resets the tracker. It does
    /// nothing when no storm pause is configured.
    async fn storm_gate(&self, storm: &mut StormState) {
        let Some(pause) = self.storm_pause else {
            return;
        };
        let now = tokio::time::Instant::now();
        // With no previous failure recorded (first failure, or first after a
        // pause), the score is 0 anyway, so a zero elapsed time is harmless.
        let elapsed = storm
            .last_failure_at
            .map_or(Duration::ZERO, |at| now.saturating_duration_since(at));
        storm.last_failure_at = Some(now);
        storm.score = decayed_failure_score(storm.score, elapsed, self.failure_decay);
        // Kept as a negated `>` on purpose: a NaN threshold never trips.
        let tripped = storm.score > self.failure_threshold;
        if !tripped {
            return;
        }
        let pause = apply_jitter(pause, self.jitter);
        #[cfg(feature = "tracing")]
        tracing::warn!(
            target: "processkit",
            pause_ms = u64::try_from(pause.as_millis()).unwrap_or(u64::MAX),
            "supervisor failure storm — pausing restarts"
        );
        if !pause.is_zero() {
            tokio::time::sleep(pause).await;
        }
        storm.score = 0.0;
        storm.last_failure_at = None;
        storm.pauses += 1;
    }

    /// Sleeps for the capped, optionally jittered backoff that precedes restart `restarts + 1`.
    async fn sleep_backoff(&self, restarts: u32, factor: f64) {
        let delay = backoff_delay(self.backoff_base, factor, restarts, self.max_backoff);
        let delay = apply_jitter(delay, self.jitter);
        #[cfg(feature = "tracing")]
        tracing::debug!(
            target: "processkit",
            restart = restarts.saturating_add(1),
            delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX),
            "supervisor restarting child"
        );
        if !delay.is_zero() {
            tokio::time::sleep(delay).await;
        }
    }
}
/// Per-`run` bookkeeping for the failure-storm guard.
struct StormState {
/// Exponentially decayed count of recent failures (each failure adds 1.0).
score: f64,
/// When the storm gate last recorded a failure; `None` initially and after
/// every storm pause.
last_failure_at: Option<tokio::time::Instant>,
/// Number of storm pauses taken so far.
pauses: u32,
}
impl StormState {
fn new() -> Self {
StormState {
score: 0.0,
last_failure_at: None,
pauses: 0,
}
}
}
/// Decays `prev` by `elapsed` under the given `half_life`, then adds one new failure.
///
/// A zero half-life forgets all history, and so does a decayed value that
/// is not finite (e.g. a NaN or infinite `prev`). In both cases the result
/// is exactly `1.0`.
fn decayed_failure_score(prev: f64, elapsed: Duration, half_life: Duration) -> f64 {
    const NEW_FAILURE: f64 = 1.0;
    if half_life.is_zero() {
        return NEW_FAILURE;
    }
    let half_lives_elapsed = elapsed.as_secs_f64() / half_life.as_secs_f64();
    let surviving = prev * 0.5_f64.powf(half_lives_elapsed);
    Some(surviving)
        .filter(|s| s.is_finite())
        .map_or(NEW_FAILURE, |s| s + NEW_FAILURE)
}
/// Delay before restart number `n + 1`: `base * factor^n`, capped at `cap`.
///
/// A zero `base` always yields zero. A non-finite product, or one at or
/// above `cap`, yields `cap`. Callers pass a sanitised `factor >= 1.0`.
fn backoff_delay(base: Duration, factor: f64, n: u32, cap: Duration) -> Duration {
    if base.is_zero() {
        return Duration::ZERO;
    }
    // `powi` takes an i32; larger restart counts saturate the exponent.
    let exponent = i32::try_from(n).unwrap_or(i32::MAX);
    let secs = base.as_secs_f64() * factor.powi(exponent);
    let below_cap = secs.is_finite() && secs < cap.as_secs_f64();
    if below_cap {
        Duration::from_secs_f64(secs).min(cap)
    } else {
        cap
    }
}
/// Scales `delay` by a random factor in `[0.5, 1.5)` when `enabled`.
///
/// Disabled jitter and zero delays pass through unchanged. The product
/// saturates at [`Duration::MAX`] instead of panicking. `Duration::mul_f64`
/// would panic on overflow, which any delay near `Duration::MAX` hits
/// whenever the factor exceeds 1.0 (e.g. `max_backoff(Duration::MAX)` used
/// as "uncapped").
fn apply_jitter(delay: Duration, enabled: bool) -> Duration {
    if !enabled || delay.is_zero() {
        return delay;
    }
    Duration::try_from_secs_f64(delay.as_secs_f64() * jitter_factor()).unwrap_or(Duration::MAX)
}

/// Returns a pseudo-random factor in `[0.5, 1.5)`.
///
/// The randomness comes from std's randomly keyed `RandomState` hasher, so no
/// RNG dependency is needed. Not suitable for anything security-sensitive.
fn jitter_factor() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u64(0x9E37_79B9_7F4A_7C15);
    let bits = hasher.finish();
    // Keep the top 53 bits (the f64 mantissa width) for an evenly spaced value in [0, 1).
    let unit = (bits >> 11) as f64 / (1u64 << 53) as f64;
    0.5 + unit
}
#[cfg(test)]
mod tests {
// Supervisor tests. Runner replies are scripted with `SeqRunner`. Tests that
// assert elapsed time use `start_paused = true`, so tokio's clock advances
// virtually through sleeps and `elapsed()` equals the scheduled delays exactly.
use super::*;
use std::collections::VecDeque;
use std::sync::Mutex;
use std::sync::atomic::{AtomicU32, Ordering};
/// Fake runner that replays a scripted queue of replies, one per `output`
/// call, and panics if a test asks for more runs than it scripted.
struct SeqRunner {
replies: Mutex<VecDeque<Result<ProcessResult<String>>>>,
}
impl SeqRunner {
fn new(replies: Vec<Result<ProcessResult<String>>>) -> Self {
SeqRunner {
replies: Mutex::new(replies.into()),
}
}
}
#[async_trait::async_trait]
impl ProcessRunner for SeqRunner {
async fn output(&self, _command: &Command) -> Result<ProcessResult<String>> {
self.replies
.lock()
.expect("replies lock")
.pop_front()
.expect("SeqRunner ran out of scripted replies")
}
}
// NOTE(review): the helpers below assume `ProcessResult::new` takes
// (program, stdout, stderr, exit code, timed_out, timeout); confirm against
// its definition.
/// A clean incarnation: exit code 0.
fn ok() -> Result<ProcessResult<String>> {
Ok(ProcessResult::new(
"fake".into(),
"out".into(),
String::new(),
Some(0),
false,
None,
))
}
/// A crashed incarnation exiting with `code`.
fn fail(code: i32) -> Result<ProcessResult<String>> {
Ok(ProcessResult::new(
"fake".into(),
String::new(),
"boom".into(),
Some(code),
false,
None,
))
}
/// A timed-out incarnation: no exit code, which the supervisor treats as a crash.
fn timeout() -> Result<ProcessResult<String>> {
Ok(ProcessResult::new(
"fake".into(),
String::new(),
String::new(),
None,
true,
Some(Duration::from_secs(1)),
))
}
/// A runner-level error: the binary could not be spawned.
fn spawn_err() -> Result<ProcessResult<String>> {
Err(crate::Error::Spawn {
program: "fake".into(),
source: std::io::Error::new(std::io::ErrorKind::NotFound, "no such binary"),
})
}
/// Supervisor over `runner` with zero backoff and no jitter, so a test only
/// sleeps if it opts into a storm pause.
fn supervise(runner: SeqRunner) -> Supervisor<SeqRunner> {
Supervisor::new(Command::new("fake"))
.with_runner(runner)
.backoff(Duration::ZERO, 1.0)
.jitter(false)
}
#[tokio::test]
async fn on_crash_restarts_until_success() {
let outcome = supervise(SeqRunner::new(vec![fail(1), fail(1), ok()]))
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 2);
assert_eq!(outcome.stopped, StopReason::PolicySatisfied);
assert!(outcome.final_result.is_success());
}
#[tokio::test]
async fn zero_max_restarts_means_a_single_run() {
// The scripted `ok()` must never be consumed.
let outcome = supervise(SeqRunner::new(vec![fail(1), ok()]))
.max_restarts(0)
.run()
.await
.expect("supervision completes with the single run's result");
assert_eq!(outcome.restarts, 0);
assert_eq!(outcome.stopped, StopReason::RestartsExhausted);
assert_eq!(outcome.final_result.code(), Some(1));
}
#[tokio::test]
async fn on_crash_accepts_a_clean_first_run() {
let outcome = supervise(SeqRunner::new(vec![ok()]))
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 0);
assert_eq!(outcome.stopped, StopReason::PolicySatisfied);
}
#[tokio::test]
async fn predicate_beats_policy() {
// `Always` would restart, but the predicate is consulted first.
let outcome = supervise(SeqRunner::new(vec![ok()]))
.restart(RestartPolicy::Always)
.stop_when(|res| res.code() == Some(0))
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 0);
assert_eq!(outcome.stopped, StopReason::Predicate);
}
#[tokio::test]
async fn always_restarts_clean_runs_until_predicate() {
// Predicate matches on its third call (counter value 2).
let seen = AtomicU32::new(0);
let outcome = supervise(SeqRunner::new(vec![ok(), ok(), ok()]))
.restart(RestartPolicy::Always)
.stop_when(move |_| seen.fetch_add(1, Ordering::SeqCst) == 2)
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 2, "third run matched the predicate");
assert_eq!(outcome.stopped, StopReason::Predicate);
}
#[tokio::test]
async fn never_reports_a_failing_run_without_restarting() {
let outcome = supervise(SeqRunner::new(vec![fail(3)]))
.restart(RestartPolicy::Never)
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 0);
assert_eq!(outcome.stopped, StopReason::PolicySatisfied);
assert_eq!(outcome.final_result.code(), Some(3));
}
#[tokio::test]
async fn exhausting_the_budget_reports_the_last_failure() {
let runner = SeqRunner::new(vec![fail(7), fail(7), fail(7)]);
let outcome = supervise(runner)
.max_restarts(2)
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 2, "two restarts = three runs");
assert_eq!(outcome.stopped, StopReason::RestartsExhausted);
assert_eq!(outcome.final_result.code(), Some(7));
}
#[tokio::test]
async fn a_timeout_counts_as_a_crash() {
let outcome = supervise(SeqRunner::new(vec![timeout(), ok()]))
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 1);
assert!(outcome.final_result.is_success());
}
#[tokio::test]
async fn terminal_spawn_error_surfaces_as_err() {
// First error is retried; the second hits the budget and is returned.
let err = supervise(SeqRunner::new(vec![spawn_err(), spawn_err()]))
.max_restarts(1)
.run()
.await
.expect_err("the budget-exhausting attempt errored");
assert!(matches!(err, crate::Error::Spawn { .. }), "got {err:?}");
}
#[tokio::test]
async fn spawn_error_is_retried_like_a_crash() {
let outcome = supervise(SeqRunner::new(vec![spawn_err(), ok()]))
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 1);
assert_eq!(outcome.stopped, StopReason::PolicySatisfied);
}
#[cfg(feature = "cancellation")]
#[tokio::test]
async fn cancelled_incarnation_is_terminal_under_always() {
// Budget and policy would both allow a restart; cancellation overrides them.
let err = supervise(SeqRunner::new(vec![
Err(crate::Error::Cancelled {
program: "fake".into(),
}),
ok(),
]))
.restart(RestartPolicy::Always)
.max_restarts(5)
.run()
.await
.expect_err("a cancelled incarnation is terminal");
assert!(matches!(err, crate::Error::Cancelled { .. }), "got {err:?}");
}
#[tokio::test]
async fn never_returns_a_spawn_error_directly() {
let err = supervise(SeqRunner::new(vec![spawn_err()]))
.restart(RestartPolicy::Never)
.run()
.await
.expect_err("Never does not retry a spawn failure");
assert!(matches!(err, crate::Error::Spawn { .. }), "got {err:?}");
}
#[tokio::test(start_paused = true)]
async fn backoff_doubles_per_restart_without_jitter() {
// 200 ms before restart 1, then 400 ms before restart 2.
let start = tokio::time::Instant::now();
let outcome = Supervisor::new(Command::new("fake"))
.with_runner(SeqRunner::new(vec![fail(1), fail(1), ok()]))
.backoff(Duration::from_millis(200), 2.0)
.jitter(false)
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 2);
assert_eq!(start.elapsed(), Duration::from_millis(600));
}
#[tokio::test(start_paused = true)]
async fn max_backoff_caps_the_delay() {
// 200 ms, then min(400 ms, 300 ms).
let start = tokio::time::Instant::now();
let outcome = Supervisor::new(Command::new("fake"))
.with_runner(SeqRunner::new(vec![fail(1), fail(1), ok()]))
.backoff(Duration::from_millis(200), 2.0)
.max_backoff(Duration::from_millis(300))
.jitter(false)
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 2);
assert_eq!(start.elapsed(), Duration::from_millis(500));
}
#[tokio::test(start_paused = true)]
async fn jitter_stays_within_its_band() {
// Jitter is on by default (this test does not use `supervise`).
let start = tokio::time::Instant::now();
let outcome = Supervisor::new(Command::new("fake"))
.with_runner(SeqRunner::new(vec![fail(1), ok()]))
.backoff(Duration::from_millis(1000), 1.0)
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 1);
let waited = start.elapsed();
assert!(
waited >= Duration::from_millis(500) && waited <= Duration::from_millis(1500),
"jittered delay out of [0.5, 1.5] band: {waited:?}"
);
}
#[tokio::test(start_paused = true)]
async fn nonsense_backoff_factor_decays_to_constant_delay() {
// A factor of 0.0 is clamped to 1.0: 100 ms + 100 ms.
let start = tokio::time::Instant::now();
let outcome = Supervisor::new(Command::new("fake"))
.with_runner(SeqRunner::new(vec![fail(1), fail(1), ok()]))
.backoff(Duration::from_millis(100), 0.0)
.jitter(false)
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 2);
assert_eq!(start.elapsed(), Duration::from_millis(200));
}
#[test]
fn jitter_factor_is_in_band() {
for _ in 0..256 {
let f = jitter_factor();
assert!((0.5..1.5).contains(&f), "factor out of band: {f}");
}
}
#[test]
fn decayed_failure_score_math() {
// Score = prev * 0.5^(elapsed / half_life) + 1.
let hl = Duration::from_secs(30);
assert_eq!(decayed_failure_score(0.0, Duration::ZERO, hl), 1.0);
assert_eq!(decayed_failure_score(1.0, Duration::ZERO, hl), 2.0);
assert_eq!(decayed_failure_score(2.0, hl, hl), 2.0);
assert_eq!(decayed_failure_score(4.0, hl, hl), 3.0);
// 100 half-lives: the old score has essentially vanished.
let aged = decayed_failure_score(8.0, Duration::from_secs(3000), hl);
assert!((aged - 1.0).abs() < 1e-9, "got {aged}");
// A zero half-life, or a non-finite score, resets to a single failure.
assert_eq!(
decayed_failure_score(100.0, Duration::ZERO, Duration::ZERO),
1.0
);
assert_eq!(decayed_failure_score(f64::NAN, Duration::ZERO, hl), 1.0);
}
#[tokio::test(start_paused = true)]
async fn storm_guard_is_off_by_default() {
let start = tokio::time::Instant::now();
let outcome = supervise(SeqRunner::new(vec![
fail(1),
fail(1),
fail(1),
fail(1),
ok(),
]))
.run()
.await
.expect("supervision");
assert_eq!(outcome.storm_pauses, 0);
assert_eq!(start.elapsed(), Duration::ZERO, "no hidden pauses");
}
#[tokio::test(start_paused = true)]
async fn storm_trips_past_the_threshold() {
// With a long half-life the scores are 1, 2, 3; only the third exceeds 2.5.
let start = tokio::time::Instant::now();
let outcome = supervise(SeqRunner::new(vec![fail(1), fail(1), fail(1), ok()]))
.storm_pause(Duration::from_secs(1))
.failure_threshold(2.5)
.failure_decay(Duration::from_secs(1000))
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 3);
assert_eq!(outcome.storm_pauses, 1);
assert_eq!(start.elapsed(), Duration::from_secs(1));
}
#[tokio::test(start_paused = true)]
async fn spaced_failures_decay_below_the_threshold() {
// 10 s between failures and a 1 s half-life: the score stays just above 1.
let outcome = Supervisor::new(Command::new("fake"))
.with_runner(SeqRunner::new(vec![fail(1), fail(1), fail(1), ok()]))
.backoff(Duration::from_secs(10), 1.0)
.jitter(false)
.storm_pause(Duration::from_secs(1))
.failure_threshold(2.5)
.failure_decay(Duration::from_secs(1))
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 3);
assert_eq!(outcome.storm_pauses, 0);
}
#[tokio::test(start_paused = true)]
async fn storm_pause_resets_the_score() {
// Threshold 1.5: every second failure reaches score 2 and trips; the reset
// makes the next failure start again at 1.
let outcome = supervise(SeqRunner::new(vec![
fail(1),
fail(1),
fail(1),
fail(1),
ok(),
]))
.storm_pause(Duration::from_secs(1))
.failure_threshold(1.5)
.failure_decay(Duration::from_secs(1000))
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 4);
assert_eq!(outcome.storm_pauses, 2);
}
#[tokio::test(start_paused = true)]
async fn exhausted_budget_wins_over_the_storm_gate() {
// The second failure would trip the storm, but the budget check returns first.
let start = tokio::time::Instant::now();
let outcome = supervise(SeqRunner::new(vec![fail(1), fail(1)]))
.max_restarts(1)
.storm_pause(Duration::from_secs(60))
.failure_threshold(1.5)
.failure_decay(Duration::from_secs(1000))
.run()
.await
.expect("supervision");
assert_eq!(outcome.stopped, StopReason::RestartsExhausted);
assert_eq!(outcome.storm_pauses, 0);
assert_eq!(start.elapsed(), Duration::ZERO);
}
#[tokio::test(start_paused = true)]
async fn storm_pause_is_jittered_within_the_band() {
// Jitter is on by default; zero backoff isolates the storm pause.
let start = tokio::time::Instant::now();
let outcome = Supervisor::new(Command::new("fake"))
.with_runner(SeqRunner::new(vec![fail(1), ok()]))
.backoff(Duration::ZERO, 1.0)
.storm_pause(Duration::from_millis(1000))
.failure_threshold(0.5)
.run()
.await
.expect("supervision");
assert_eq!(outcome.storm_pauses, 1);
let waited = start.elapsed();
assert!(
waited >= Duration::from_millis(500) && waited <= Duration::from_millis(1500),
"jittered storm pause out of [0.5, 1.5] band: {waited:?}"
);
}
#[tokio::test(start_paused = true)]
async fn clean_restarts_under_always_do_not_feed_the_storm_score() {
// Clean runs skip the storm gate, so even a low threshold never trips.
let seen = AtomicU32::new(0);
let outcome = supervise(SeqRunner::new(vec![ok(), ok(), ok()]))
.restart(RestartPolicy::Always)
.storm_pause(Duration::from_secs(60))
.failure_threshold(1.5)
.failure_decay(Duration::from_secs(1000))
.stop_when(move |_| seen.fetch_add(1, Ordering::SeqCst) == 2)
.run()
.await
.expect("supervision");
assert_eq!(outcome.restarts, 2);
assert_eq!(outcome.storm_pauses, 0);
}
#[cfg(feature = "cancellation")]
#[tokio::test(start_paused = true)]
async fn cancellation_is_terminal_before_any_storm_pause() {
// A 0.0 threshold would trip on the first failure, but cancellation
// returns before the storm gate runs.
let start = tokio::time::Instant::now();
let err = supervise(SeqRunner::new(vec![Err(crate::Error::Cancelled {
program: "fake".into(),
})]))
.storm_pause(Duration::from_secs(60))
.failure_threshold(0.0)
.run()
.await
.expect_err("cancelled is terminal");
assert!(matches!(err, crate::Error::Cancelled { .. }), "got {err:?}");
assert_eq!(start.elapsed(), Duration::ZERO, "no storm pause was taken");
}
#[test]
fn backoff_delay_math() {
let base = Duration::from_millis(100);
let cap = Duration::from_secs(30);
assert_eq!(backoff_delay(base, 2.0, 0, cap), base);
assert_eq!(backoff_delay(base, 2.0, 1, cap), Duration::from_millis(200));
assert_eq!(backoff_delay(base, 2.0, 3, cap), Duration::from_millis(800));
// The exponent overflows to infinity and is clamped to the cap.
assert_eq!(backoff_delay(base, 2.0, 1_000, cap), cap);
assert_eq!(backoff_delay(Duration::ZERO, 2.0, 5, cap), Duration::ZERO);
}
}