use std::collections::HashMap;
use std::process::{Child, Command};
use std::time::{Duration, Instant};
/// Result of a single recovery step, produced either by `Recovery::on_stall`
/// (spawn side) or by `Recovery::try_reap` (reap side).
#[derive(Debug)]
pub enum RecoveryOutcome {
    /// The recovery command was started; `child_pid` is the id reported by the
    /// spawned `Child`.
    Spawned { child_pid: u32 },
    /// The request was suppressed because the same pid fired within the
    /// debounce window; nothing was spawned.
    Debounced,
    /// Spawning the recovery command failed with the contained I/O error.
    SpawnFailed(std::io::Error),
    /// A previously spawned child has exited and its status was collected.
    Reaped {
        child_pid: u32,
        status: std::process::ExitStatus,
    },
    /// A child outlived the configured timeout and `Child::kill` succeeded.
    Killed { child_pid: u32 },
    /// Polling or killing a child returned an I/O error; the child is no
    /// longer tracked afterwards.
    ReapFailed(std::io::Error),
}
/// Bookkeeping for one spawned recovery command that has not been reaped yet.
struct Outstanding {
    /// Handle used to poll (`try_wait`) and, on timeout, `kill` the process.
    child: Child,
    /// Taken right after a successful spawn; compared against the timeout.
    spawned_at: Instant,
    /// Set once `kill` has succeeded so the child is not killed a second time
    /// while waiting for its exit status to become available.
    killed: bool,
}
/// Runs a shell command in response to a stalled process, with per-pid
/// debouncing and an optional upper bound on how long the command may run.
pub struct Recovery {
    /// Shell command line; every `{pid}` occurrence is replaced with the
    /// stalled pid before the command is handed to `/bin/sh -c`.
    template: String,
    /// Minimum interval between two firings for the same stalled pid.
    debounce: Duration,
    /// When each stalled pid last fired. Entries older than ten times
    /// `debounce` are pruned on every `on_stall` call.
    last_fired: HashMap<u32, Instant>,
    /// Maximum run time for a spawned command before `try_reap` kills it;
    /// `None` disables the limit.
    timeout: Option<Duration>,
    /// Children that were spawned and are still being tracked (not yet reaped).
    outstanding: HashMap<u32, Outstanding>,
}
impl Recovery {
pub fn new(template: String, debounce: Duration) -> Self {
Self::with_timeout(template, debounce, None)
}
pub fn with_timeout(template: String, debounce: Duration, timeout: Option<Duration>) -> Self {
Recovery {
template,
debounce,
last_fired: HashMap::new(),
timeout,
outstanding: HashMap::new(),
}
}
pub fn on_stall(&mut self, pid: u32) -> RecoveryOutcome {
let now = Instant::now();
let prune_threshold = self.debounce.saturating_mul(10);
self.last_fired
.retain(|_, &mut fired_at| now.duration_since(fired_at) < prune_threshold);
if let Some(prev) = self.last_fired.get(&pid) {
if now.duration_since(*prev) < self.debounce {
return RecoveryOutcome::Debounced;
}
}
let rendered = self.template.replace("{pid}", &pid.to_string());
self.last_fired.insert(pid, now);
match Command::new("/bin/sh").arg("-c").arg(&rendered).spawn() {
Ok(child) => {
let child_pid = child.id();
self.outstanding.insert(
pid,
Outstanding {
child,
spawned_at: Instant::now(),
killed: false,
},
);
RecoveryOutcome::Spawned { child_pid }
}
Err(e) => RecoveryOutcome::SpawnFailed(e),
}
}
pub fn try_reap(&mut self) -> Vec<RecoveryOutcome> {
let mut outcomes = Vec::new();
let pids: Vec<u32> = self.outstanding.keys().copied().collect();
for pid in pids {
let entry = match self.outstanding.get_mut(&pid) {
Some(e) => e,
None => continue,
};
match entry.child.try_wait() {
Ok(Some(status)) => {
let child_pid = entry.child.id();
self.outstanding.remove(&pid);
outcomes.push(RecoveryOutcome::Reaped { child_pid, status });
}
Ok(None) => {
if let Some(to) = self.timeout {
if entry.spawned_at.elapsed() >= to {
if entry.killed {
continue;
}
let child_pid = entry.child.id();
match entry.child.kill() {
Ok(()) => {
entry.killed = true;
outcomes.push(RecoveryOutcome::Killed { child_pid });
}
Err(e) if e.kind() == std::io::ErrorKind::InvalidInput => {
match entry.child.try_wait() {
Ok(Some(status)) => {
let child_pid = entry.child.id();
self.outstanding.remove(&pid);
outcomes.push(RecoveryOutcome::Reaped {
child_pid,
status,
});
}
_ => {
}
}
}
Err(e) => {
self.outstanding.remove(&pid);
outcomes.push(RecoveryOutcome::ReapFailed(e));
}
}
}
}
}
Err(e) => {
self.outstanding.remove(&pid);
outcomes.push(RecoveryOutcome::ReapFailed(e));
}
}
}
outcomes
}
}
impl Drop for Recovery {
    /// Makes sure no tracked child is left behind un-waited.
    ///
    /// `std::process::Child` has no `Drop` of its own, so a child that is
    /// merely polled once and then dropped keeps running and is never reaped.
    /// Every tracked child is therefore drained here: one that has already
    /// exited is reaped by the poll; any other is killed and then waited for.
    /// Errors are ignored because nothing can be reported from `drop`.
    fn drop(&mut self) {
        for (_, mut entry) in self.outstanding.drain() {
            match entry.child.try_wait() {
                // Already exited; `try_wait` collected the status.
                Ok(Some(_)) => {}
                // Still running (or state unknown): terminate, then reap so
                // no zombie remains. `wait` follows a kill request, so it is
                // expected to return promptly.
                Ok(None) | Err(_) => {
                    let _ = entry.child.kill();
                    let _ = entry.child.wait();
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Extracts the exit status from a `Reaped` outcome; `None` otherwise.
    fn reaped_status(outcome: RecoveryOutcome) -> Option<std::process::ExitStatus> {
        if let RecoveryOutcome::Reaped { status, .. } = outcome {
            Some(status)
        } else {
            None
        }
    }

    /// Calls `try_reap` every `pause` until `pick` accepts one of the returned
    /// outcomes or `limit` has elapsed. Returns `None` on timeout.
    fn poll_reap<T>(
        rec: &mut Recovery,
        limit: Duration,
        pause: Duration,
        mut pick: impl FnMut(RecoveryOutcome) -> Option<T>,
    ) -> Option<T> {
        let deadline = Instant::now() + limit;
        loop {
            if Instant::now() >= deadline {
                return None;
            }
            if let Some(found) = rec.try_reap().into_iter().find_map(&mut pick) {
                return Some(found);
            }
            std::thread::sleep(pause);
        }
    }

    /// Fires `on_stall(pid)` and panics unless a child was spawned.
    fn spawn_or_panic(rec: &mut Recovery, pid: u32) {
        match rec.on_stall(pid) {
            RecoveryOutcome::Spawned { .. } => {}
            other => panic!("expected Spawned, got {other:?}"),
        }
    }

    #[test]
    fn debounces_repeat_calls_for_same_pid() {
        let mut rec = Recovery::new("true".to_string(), Duration::from_secs(10));
        let first = rec.on_stall(1);
        let second = rec.on_stall(1);
        assert!(matches!(first, RecoveryOutcome::Spawned { .. }));
        assert!(matches!(second, RecoveryOutcome::Debounced));
    }

    #[test]
    fn debounce_is_per_pid() {
        let mut rec = Recovery::new("true".to_string(), Duration::from_secs(10));
        let a = rec.on_stall(1);
        let b = rec.on_stall(2);
        assert!(matches!(a, RecoveryOutcome::Spawned { .. }));
        assert!(matches!(b, RecoveryOutcome::Spawned { .. }));
    }

    #[test]
    fn template_substitutes_every_pid_token() {
        // The shell `test` succeeds only if both `{pid}` tokens were replaced.
        let mut rec = Recovery::new(
            "test \"{pid}-{pid}\" = \"7-7\"".to_string(),
            Duration::from_secs(0),
        );
        spawn_or_panic(&mut rec, 7);
        // Poll with a deadline instead of sleeping a fixed interval and
        // reaping once, so a slow host cannot make the test fail spuriously.
        let reaped = poll_reap(
            &mut rec,
            Duration::from_millis(500),
            Duration::from_millis(10),
            reaped_status,
        );
        assert!(
            matches!(reaped, Some(s) if s.success()),
            "expected Reaped(success) for pid 7; got {:?}",
            reaped
        );
    }

    #[test]
    fn spawn_returns_immediately_for_slow_template() {
        let mut rec = Recovery::new("sleep 1".to_string(), Duration::ZERO);
        let start = Instant::now();
        spawn_or_panic(&mut rec, 42);
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(50),
            "spawn blocked for {elapsed:?}; expected non-blocking"
        );
    }

    #[test]
    fn try_reap_surfaces_reaped_for_fast_child() {
        let mut rec = Recovery::new("true".to_string(), Duration::ZERO);
        spawn_or_panic(&mut rec, 99);
        let status = poll_reap(
            &mut rec,
            Duration::from_millis(500),
            Duration::from_millis(20),
            reaped_status,
        );
        match status {
            Some(s) => assert!(s.success(), "expected success from 'true'"),
            None => panic!("timed out waiting for Reaped"),
        }
    }

    #[test]
    fn try_reap_kills_after_timeout() {
        let mut rec = Recovery::with_timeout(
            "sleep 5".to_string(),
            Duration::ZERO,
            Some(Duration::from_millis(100)),
        );
        spawn_or_panic(&mut rec, 7);
        let killed = poll_reap(
            &mut rec,
            Duration::from_millis(1_000),
            Duration::from_millis(30),
            |o| {
                if matches!(o, RecoveryOutcome::Killed { .. }) {
                    Some(())
                } else {
                    None
                }
            },
        );
        if killed.is_none() {
            panic!("timed out waiting for Killed");
        }
    }

    #[test]
    fn drop_does_not_leak_zombies() {
        // No assertion: this only checks that dropping a `Recovery` with a
        // freshly spawned child neither panics nor hangs.
        {
            let mut rec = Recovery::new("true".to_string(), Duration::ZERO);
            spawn_or_panic(&mut rec, 999);
        }
    }

    #[test]
    fn with_timeout_constructor_accepts_optional_duration() {
        let _none = Recovery::with_timeout("true".to_string(), Duration::ZERO, None);
        let _some = Recovery::with_timeout(
            "true".to_string(),
            Duration::ZERO,
            Some(Duration::from_millis(50)),
        );
    }

    #[test]
    fn last_fired_hashmap_is_pruned_after_debounce_times_ten() {
        // Observes pruning only indirectly: after ten debounce periods have
        // passed the same pid is allowed to fire again.
        let debounce = Duration::from_millis(10);
        let mut rec = Recovery::new("true".to_string(), debounce);
        assert!(matches!(rec.on_stall(1), RecoveryOutcome::Spawned { .. }));
        assert!(matches!(rec.on_stall(1), RecoveryOutcome::Debounced));
        let prune_threshold = debounce.saturating_mul(10);
        std::thread::sleep(prune_threshold + Duration::from_millis(40));
        assert!(matches!(rec.on_stall(1), RecoveryOutcome::Spawned { .. }));
    }
}