use std::time::{Duration, Instant};
/// Limits how often a supervised task may be restarted.
///
/// At most `max_restarts` restarts are permitted inside any sliding `window`,
/// and `delay` is the pause taken before each restart.
#[derive(Debug, Clone, Copy)]
pub struct RestartPolicy {
    /// Maximum number of restarts permitted inside `window`.
    pub max_restarts: u32,
    /// Length of the sliding window over which restarts are counted.
    pub window: Duration,
    /// Pause inserted before each restart.
    pub delay: Duration,
}

impl Default for RestartPolicy {
    /// Ten restarts per ten minutes, pausing ten seconds before each restart.
    fn default() -> Self {
        const TEN_MINUTES: Duration = Duration::from_secs(10 * 60);
        const TEN_SECONDS: Duration = Duration::from_secs(10);
        RestartPolicy {
            max_restarts: 10,
            window: TEN_MINUTES,
            delay: TEN_SECONDS,
        }
    }
}
/// Sliding-window log of restart times, used to enforce a [`RestartPolicy`].
#[derive(Debug, Default)]
pub struct RestartTracker {
    // Instants at which restarts were accepted, in the order they were recorded.
    timestamps: Vec<Instant>,
}

impl RestartTracker {
    /// Returns a tracker with no recorded restarts.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a restart at the current instant if `policy` still allows one.
    ///
    /// Entries at least `policy.window` old are pruned first. If the number of
    /// remaining entries has already reached `policy.max_restarts`, nothing is
    /// recorded.
    ///
    /// # Errors
    ///
    /// Returns `Err((max_restarts, window_secs))` when the limit has been reached,
    /// where `window_secs` is `policy.window` truncated to whole seconds.
    pub fn try_record(&mut self, policy: &RestartPolicy) -> Result<(), (u32, u64)> {
        let now = Instant::now();
        let window = policy.window;
        self.timestamps
            .retain(|&stamp| now.duration_since(stamp) < window);

        // Cannot truncate: the vector never grows beyond `max_restarts` entries.
        let recent = self.timestamps.len() as u32;
        if recent < policy.max_restarts {
            self.timestamps.push(now);
            Ok(())
        } else {
            Err((policy.max_restarts, policy.window.as_secs()))
        }
    }

    /// The instant of the most recently recorded restart, if any remain recorded.
    pub fn last(&self) -> Option<Instant> {
        self.timestamps.iter().next_back().copied()
    }

    /// Forgets every recorded restart.
    pub fn reset(&mut self) {
        self.timestamps.truncate(0);
    }
}
/// Errors reported by the supervisor.
#[derive(Debug)]
pub enum SuperviseError<E> {
    /// The restart limit of the [`RestartPolicy`] was exceeded.
    TooManyRestarts,
    /// Carries an error of type `E`.
    ///
    /// NOTE(review): `supervise` in this file never constructs this variant;
    /// task errors are only logged there.
    Inner(E),
}

impl<E> std::fmt::Display for SuperviseError<E>
where
    E: std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Inner(inner) => write!(f, "supervisor inner error: {inner}"),
            Self::TooManyRestarts => f.write_str("supervisor: too many restarts"),
        }
    }
}

impl<E> std::error::Error for SuperviseError<E> where E: std::fmt::Display + std::fmt::Debug {}
/// Runs the future produced by `task_factory`, restarting it after each failure.
///
/// Each attempt calls `task_factory` to build a fresh future and awaits it.
///
/// - If the future resolves to `Ok(())`, the supervisor returns `Ok(())`.
/// - If it fails, the error is logged through its `Debug` impl and the restart is
///   recorded against `policy`. If the policy still allows a restart, the supervisor
///   sleeps for `policy.delay` and tries again.
///
/// The first run is not a restart, so at most `policy.max_restarts` restarts happen
/// within any `policy.window`. Sleeping uses `tokio::time::sleep`, so this is meant to
/// run on a Tokio runtime.
///
/// # Errors
///
/// Returns [`SuperviseError::TooManyRestarts`] when a failure occurs after the restart
/// limit has been reached. The task's own errors are only logged; this function never
/// returns [`SuperviseError::Inner`].
pub async fn supervise<F, Fut, E>(
    policy: RestartPolicy,
    mut task_factory: F,
) -> Result<(), SuperviseError<E>>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<(), E>>,
    E: std::fmt::Debug,
{
    // `as_millis` returns u128; clamp before narrowing so a huge delay logs u64::MAX
    // instead of a silently truncated value. The delay never changes, so compute once.
    let delay_ms = policy.delay.as_millis().min(u128::from(u64::MAX)) as u64;
    let mut tracker = RestartTracker::new();
    let mut attempt = 0u32;
    loop {
        // A policy that never trips (e.g. a zero `window`) makes this loop unbounded;
        // saturate rather than overflow, which would panic in debug builds.
        attempt = attempt.saturating_add(1);
        tracing::info!(attempt, "supervise: starting attempt");
        match task_factory().await {
            Ok(()) => {
                tracing::info!(attempt, "supervise: task exited normally");
                return Ok(());
            }
            Err(e) => {
                tracing::warn!(attempt, error = ?e, "supervise: task failed");
            }
        }
        if let Err((max, win)) = tracker.try_record(&policy) {
            tracing::error!(max, window_secs = win, "supervise: too many restarts");
            return Err(SuperviseError::TooManyRestarts);
        }
        tracing::info!(attempt, delay_ms, "supervise: scheduling restart");
        tokio::time::sleep(policy.delay).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a policy with a 60-second window, far longer than any test here runs.
    fn test_policy(max_restarts: u32, delay: Duration) -> RestartPolicy {
        RestartPolicy {
            max_restarts,
            window: Duration::from_secs(60),
            delay,
        }
    }

    #[test]
    fn rate_limit_bails_after_max() {
        let policy = test_policy(3, Duration::ZERO);
        let mut t = RestartTracker::new();
        assert!(t.try_record(&policy).is_ok());
        assert!(t.try_record(&policy).is_ok());
        assert!(t.try_record(&policy).is_ok());
        // The error reports the limit and the window in whole seconds.
        assert_eq!(t.try_record(&policy), Err((3, 60)));
    }

    #[test]
    fn reset_clears_window() {
        let policy = test_policy(2, Duration::ZERO);
        let mut t = RestartTracker::new();
        t.try_record(&policy).unwrap();
        t.try_record(&policy).unwrap();
        assert!(t.try_record(&policy).is_err());
        t.reset();
        assert!(t.last().is_none());
        assert!(t.try_record(&policy).is_ok());
    }

    #[test]
    fn last_tracks_most_recent_restart() {
        let policy = test_policy(2, Duration::ZERO);
        let mut t = RestartTracker::new();
        assert!(t.last().is_none());
        t.try_record(&policy).unwrap();
        let first = t.last().expect("a restart was just recorded");
        t.try_record(&policy).unwrap();
        assert!(t.last().expect("a restart was just recorded") >= first);
    }

    #[test]
    fn zero_window_never_limits() {
        // With a zero window every earlier entry is pruned, so the limit never trips.
        let policy = RestartPolicy {
            max_restarts: 1,
            window: Duration::ZERO,
            delay: Duration::ZERO,
        };
        let mut t = RestartTracker::new();
        for _ in 0..5 {
            assert!(t.try_record(&policy).is_ok());
        }
    }

    #[tokio::test]
    async fn supervise_immediate_success() {
        let policy = test_policy(3, Duration::from_millis(1));
        let result: Result<(), SuperviseError<&str>> =
            supervise(policy, || async { Ok::<(), &str>(()) }).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn supervise_eventual_success() {
        let count = Arc::new(AtomicU32::new(0));
        let policy = test_policy(5, Duration::from_millis(1));
        let count_clone = Arc::clone(&count);
        let result: Result<(), SuperviseError<&str>> = supervise(policy, || {
            let c = Arc::clone(&count_clone);
            async move {
                let n = c.fetch_add(1, Ordering::Relaxed);
                if n < 2 {
                    Err::<(), &str>("not yet")
                } else {
                    Ok::<(), &str>(())
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn supervise_too_many_restarts() {
        let runs = Arc::new(AtomicU32::new(0));
        let runs_in_task = Arc::clone(&runs);
        let policy = test_policy(2, Duration::from_millis(1));
        let result: Result<(), SuperviseError<&str>> = supervise(policy, || {
            let runs = Arc::clone(&runs_in_task);
            async move {
                runs.fetch_add(1, Ordering::Relaxed);
                Err::<(), &str>("always fails")
            }
        })
        .await;
        assert!(matches!(result, Err(SuperviseError::TooManyRestarts)));
        // The first run is not a restart, so the task runs `max_restarts + 1` times.
        assert_eq!(runs.load(Ordering::Relaxed), 3);
    }
}