use std::time::{Duration, Instant};
/// Rate-limiting configuration for a supervised task.
///
/// At most `max_restarts` attempts are accepted inside any sliding `window`;
/// `delay` is the pause taken before each relaunch.
#[derive(Debug, Clone, Copy)]
pub struct RestartPolicy {
    /// Maximum number of attempts accepted within one `window`.
    pub max_restarts: u32,
    /// Length of the sliding window over which attempts are counted.
    pub window: Duration,
    /// Pause taken before each relaunch.
    pub delay: Duration,
}

impl Default for RestartPolicy {
    /// Ten attempts per ten-minute window, pausing ten seconds between them.
    fn default() -> Self {
        let ten_minutes = Duration::from_secs(10 * 60);
        let ten_seconds = Duration::from_secs(10);
        RestartPolicy {
            max_restarts: 10,
            window: ten_minutes,
            delay: ten_seconds,
        }
    }
}
/// Sliding-window record of accepted restart attempts.
#[derive(Debug, Default)]
pub struct RestartTracker {
    /// Instants of accepted attempts, oldest first (appended with `Instant::now()`).
    timestamps: Vec<Instant>,
}

impl RestartTracker {
    /// Creates a tracker with no recorded attempts.
    pub fn new() -> Self {
        RestartTracker {
            timestamps: Vec::new(),
        }
    }

    /// Tries to record an attempt happening now under `policy`.
    ///
    /// Attempts older than `policy.window` are forgotten first. If fewer than
    /// `policy.max_restarts` attempts remain, the current instant is stored and
    /// `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// When the budget is already used up, nothing is stored and
    /// `Err((policy.max_restarts, policy.window.as_secs()))` is returned.
    pub fn try_record(&mut self, policy: &RestartPolicy) -> Result<(), (u32, u64)> {
        let current = Instant::now();
        // Drop every attempt that has aged out of the sliding window.
        self.timestamps
            .retain(|&stamp| current.duration_since(stamp) < policy.window);

        // The stored count never exceeds `max_restarts`, so comparing in `usize`
        // is equivalent to comparing in `u32`.
        let budget = policy.max_restarts as usize;
        if self.timestamps.len() < budget {
            self.timestamps.push(current);
            Ok(())
        } else {
            Err((policy.max_restarts, policy.window.as_secs()))
        }
    }

    /// Returns the instant of the most recently accepted attempt, if any.
    pub fn last(&self) -> Option<Instant> {
        match self.timestamps.as_slice() {
            [.., newest] => Some(*newest),
            [] => None,
        }
    }

    /// Forgets every recorded attempt.
    pub fn reset(&mut self) {
        self.timestamps.truncate(0);
    }
}
/// Error returned when a supervisor gives up on a failing task.
#[derive(Debug)]
pub enum SuperviseError<E> {
    /// The restart budget of the policy was exhausted.
    TooManyRestarts {
        /// The policy's `max_restarts` value.
        max_restarts: u32,
        /// The policy's window in whole seconds.
        window_secs: u64,
        /// The error produced by the final failed attempt.
        last_error: E,
    },
}

impl<E: std::fmt::Display> std::fmt::Display for SuperviseError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SuperviseError::TooManyRestarts {
                max_restarts,
                window_secs,
                last_error,
            } => write!(
                f,
                "supervisor: too many restarts ({} in {}s); last error: {}",
                max_restarts, window_secs, last_error
            ),
        }
    }
}

impl<E> std::error::Error for SuperviseError<E> where E: std::fmt::Display + std::fmt::Debug {}
pub async fn supervise<F, Fut, E>(
policy: RestartPolicy,
mut task_factory: F,
) -> Result<(), SuperviseError<E>>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<(), E>>,
E: std::fmt::Debug,
{
let mut tracker = RestartTracker::new();
let mut attempt = 0u32;
let mut last_error: Option<E> = None;
loop {
if let Err((max, win)) = tracker.try_record(&policy) {
match last_error {
Some(last_error) => {
tracing::error!(max, window_secs = win, "supervise: too many restarts");
return Err(SuperviseError::TooManyRestarts {
max_restarts: max,
window_secs: win,
last_error,
});
}
None => tracing::warn!(
max,
"supervise: max_restarts=0 is degenerate; running the task once"
),
}
}
attempt += 1;
if attempt > 1 {
tracing::info!(
attempt,
delay_ms = policy.delay.as_millis() as u64,
"supervise: scheduling restart"
);
tokio::time::sleep(policy.delay).await;
}
tracing::info!(attempt, "supervise: starting attempt");
match task_factory().await {
Ok(()) => {
tracing::info!(attempt, "supervise: task exited normally");
return Ok(());
}
Err(e) => {
tracing::warn!(attempt, error = ?e, "supervise: task failed");
last_error = Some(e);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn default_policy_values() {
        let policy = RestartPolicy::default();
        assert_eq!(policy.max_restarts, 10);
        assert_eq!(policy.window, Duration::from_secs(600));
        assert_eq!(policy.delay, Duration::from_secs(10));
    }

    #[test]
    fn rate_limit_bails_after_max() {
        let policy = RestartPolicy {
            max_restarts: 3,
            window: Duration::from_secs(60),
            delay: Duration::ZERO,
        };
        let mut t = RestartTracker::new();
        assert!(t.try_record(&policy).is_ok());
        assert!(t.try_record(&policy).is_ok());
        assert!(t.try_record(&policy).is_ok());
        assert!(t.try_record(&policy).is_err());
    }

    #[test]
    fn reset_clears_window() {
        let policy = RestartPolicy {
            max_restarts: 2,
            window: Duration::from_secs(60),
            delay: Duration::ZERO,
        };
        let mut t = RestartTracker::new();
        t.try_record(&policy).unwrap();
        t.try_record(&policy).unwrap();
        assert!(t.try_record(&policy).is_err());
        t.reset();
        assert!(t.try_record(&policy).is_ok());
    }

    /// With a zero-length window every previous record has already expired,
    /// so the tracker never refuses an attempt (exercises the expiry path).
    #[test]
    fn zero_window_never_rate_limits() {
        let policy = RestartPolicy {
            max_restarts: 1,
            window: Duration::ZERO,
            delay: Duration::ZERO,
        };
        let mut t = RestartTracker::new();
        for _ in 0..5 {
            assert!(t.try_record(&policy).is_ok());
        }
    }

    #[test]
    fn last_tracks_most_recent_accepted_record() {
        let policy = RestartPolicy {
            max_restarts: 1,
            window: Duration::from_secs(60),
            delay: Duration::ZERO,
        };
        let mut t = RestartTracker::new();
        assert!(t.last().is_none());
        t.try_record(&policy).unwrap();
        let first = t
            .last()
            .expect("an accepted record must be visible via last()");
        assert!(t.try_record(&policy).is_err());
        assert_eq!(t.last(), Some(first), "a rejected record must not be stored");
        t.reset();
        assert!(t.last().is_none());
    }

    #[test]
    fn too_many_restarts_display() {
        let err: SuperviseError<&str> = SuperviseError::TooManyRestarts {
            max_restarts: 3,
            window_secs: 60,
            last_error: "boom",
        };
        assert_eq!(
            err.to_string(),
            "supervisor: too many restarts (3 in 60s); last error: boom"
        );
    }

    #[tokio::test]
    async fn supervise_immediate_success() {
        let policy = RestartPolicy {
            max_restarts: 3,
            window: Duration::from_secs(60),
            delay: Duration::from_millis(1),
        };
        let result: Result<(), SuperviseError<&str>> =
            supervise(policy, || async { Ok::<(), &str>(()) }).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn supervise_eventual_success() {
        let count = Arc::new(AtomicU32::new(0));
        let policy = RestartPolicy {
            max_restarts: 5,
            window: Duration::from_secs(60),
            delay: Duration::from_millis(1),
        };
        let count_clone = count.clone();
        let result: Result<(), SuperviseError<&str>> = supervise(policy, || {
            let c = count_clone.clone();
            async move {
                let n = c.fetch_add(1, Ordering::Relaxed);
                if n < 2 {
                    Err::<(), &str>("not yet")
                } else {
                    Ok::<(), &str>(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn supervise_too_many_restarts() {
        let policy = RestartPolicy {
            max_restarts: 2,
            window: Duration::from_secs(60),
            delay: Duration::from_millis(1),
        };
        let result: Result<(), SuperviseError<&str>> =
            supervise(policy, || async { Err::<(), &str>("always fails") }).await;
        match result {
            Err(SuperviseError::TooManyRestarts {
                max_restarts,
                window_secs,
                last_error,
            }) => {
                assert_eq!(max_restarts, 2);
                assert_eq!(window_secs, 60);
                assert_eq!(last_error, "always fails");
            }
            other => panic!("expected TooManyRestarts, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn supervise_launch_count_equals_max_restarts() {
        for max in [1u32, 2, 3, 10] {
            let launches = Arc::new(AtomicU32::new(0));
            let policy = RestartPolicy {
                max_restarts: max,
                window: Duration::from_secs(60),
                delay: Duration::ZERO,
            };
            let launches_clone = launches.clone();
            let result: Result<(), SuperviseError<&str>> = supervise(policy, || {
                let c = launches_clone.clone();
                async move {
                    c.fetch_add(1, Ordering::Relaxed);
                    Err::<(), &str>("always fails")
                }
            })
            .await;
            assert!(
                matches!(result, Err(SuperviseError::TooManyRestarts { .. })),
                "max_restarts={max} must end in TooManyRestarts"
            );
            assert_eq!(
                launches.load(Ordering::Relaxed),
                max,
                "exactly max_restarts={max} launches must occur, not {} (an off-by-one)",
                max + 1
            );
        }
    }

    #[tokio::test]
    async fn supervise_zero_max_runs_task_once() {
        let launches = Arc::new(AtomicU32::new(0));
        let policy = RestartPolicy {
            max_restarts: 0,
            window: Duration::from_secs(60),
            delay: Duration::ZERO,
        };
        let launches_clone = launches.clone();
        let result: Result<(), SuperviseError<&str>> = supervise(policy, || {
            let c = launches_clone.clone();
            async move {
                c.fetch_add(1, Ordering::Relaxed);
                Err::<(), &str>("boom")
            }
        })
        .await;
        match result {
            Err(SuperviseError::TooManyRestarts { last_error, .. }) => {
                assert_eq!(last_error, "boom");
            }
            other => panic!("expected TooManyRestarts, got {other:?}"),
        }
        assert_eq!(launches.load(Ordering::Relaxed), 1);
    }
}