use crate::bg_agent::CancelOutcome;
use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use tokio::process::Child;
/// Result of waiting on a registered background process via
/// `BgRegistry::wait_for_exit_as_caller`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProcessWaitOutcome {
    /// The process is no longer running. `code` is `None` when no exit code
    /// is known — this includes processes killed through the registry.
    Exited { code: Option<i32> },
    /// The wait deadline passed first; carries a snapshot of the entry taken
    /// at that moment. The entry itself stays registered.
    TimedOut(BgProcessSnapshot),
    /// No entry is registered under the requested PID.
    NotFound,
    /// The entry exists but was spawned by a different caller.
    Forbidden,
}
/// Lifecycle state of a registered background process, as last observed by
/// the registry.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BgProcessStatus {
    /// No exit has been observed yet.
    Running,
    /// An exit was observed. `code` is `None` when no exit code is available
    /// (e.g. the process was terminated by a signal, or its status could not
    /// be collected).
    Exited { code: Option<i32> },
    /// A kill was requested through the registry. The request does not wait,
    /// so the process may still be terminating when this is reported.
    Killed,
}
/// Point-in-time, owned copy of one registry entry; it can be handed out
/// without holding the registry lock.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BgProcessSnapshot {
    /// PID the entry is registered under.
    pub pid: u32,
    /// Command string recorded when the entry was inserted.
    pub command: String,
    /// Time elapsed between insertion and the snapshot.
    pub age: Duration,
    /// Lifecycle state at snapshot time.
    pub status: BgProcessStatus,
    /// Owner id recorded at insertion; used for `*_as_caller` permission checks.
    pub spawner: Option<u32>,
}
/// Registry-private record for one background child.
struct BgEntry {
    /// Command string supplied to `BgRegistry::insert`.
    command: String,
    /// Handle used to poll (`try_wait`) and kill (`start_kill`) the child.
    child: Child,
    /// Insertion time; `BgProcessSnapshot::age` is measured from here.
    started_at: Instant,
    /// Last observed lifecycle state.
    status: BgProcessStatus,
    /// Owner id compared against the caller in the `*_as_caller` methods.
    spawner: Option<u32>,
}
/// Mutex-guarded table of background child processes, keyed by PID.
pub struct BgRegistry {
    /// PID → entry map; each method holds the lock only for its own work.
    inner: Mutex<HashMap<u32, BgEntry>>,
}
impl BgRegistry {
pub fn new() -> Self {
Self {
inner: Mutex::new(HashMap::new()),
}
}
pub fn insert(&self, pid: u32, command: String, child: Child, spawner: Option<u32>) -> u32 {
self.inner.lock().unwrap().insert(
pid,
BgEntry {
command,
child,
started_at: Instant::now(),
status: BgProcessStatus::Running,
spawner,
},
);
pid
}
pub fn list(&self) -> Vec<(u32, String)> {
self.inner
.lock()
.unwrap()
.iter()
.map(|(pid, e)| (*pid, e.command.clone()))
.collect()
}
pub fn snapshot(&self) -> Vec<BgProcessSnapshot> {
let guard = self.inner.lock().unwrap();
let now = Instant::now();
let mut out: Vec<_> = guard
.iter()
.map(|(pid, e)| BgProcessSnapshot {
pid: *pid,
command: e.command.clone(),
age: now.saturating_duration_since(e.started_at),
status: e.status,
spawner: e.spawner,
})
.collect();
out.sort_by_key(|s| s.pid);
out
}
pub fn snapshot_for_caller(&self, caller_spawner: Option<u32>) -> Vec<BgProcessSnapshot> {
self.snapshot()
.into_iter()
.filter(|s| s.spawner == caller_spawner)
.collect()
}
pub fn len(&self) -> usize {
self.inner.lock().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.inner.lock().unwrap().is_empty()
}
pub fn reap(&self) {
let mut guard = self.inner.lock().unwrap();
for entry in guard.values_mut() {
if entry.status != BgProcessStatus::Running {
continue;
}
match entry.child.try_wait() {
Ok(Some(exit)) => {
entry.status = BgProcessStatus::Exited { code: exit.code() };
}
Ok(None) => { }
Err(e) => {
tracing::warn!(
"BgRegistry reap try_wait failed for PID {}: {e}",
entry.child.id().unwrap_or(0)
);
entry.status = BgProcessStatus::Exited { code: None };
}
}
}
}
pub fn kill(&self, pid: u32) -> bool {
let mut guard = self.inner.lock().unwrap();
let Some(entry) = guard.get_mut(&pid) else {
return false;
};
if entry.status == BgProcessStatus::Running {
if let Err(e) = entry.child.start_kill() {
tracing::warn!("BgRegistry::kill: failed to SIGTERM PID {pid}: {e}");
}
entry.status = BgProcessStatus::Killed;
}
true
}
pub fn kill_as_caller(&self, pid: u32, caller_spawner: Option<u32>) -> CancelOutcome {
let mut guard = self.inner.lock().unwrap();
let Some(entry) = guard.get_mut(&pid) else {
return CancelOutcome::NotFound;
};
if entry.spawner != caller_spawner {
return CancelOutcome::Forbidden;
}
if entry.status == BgProcessStatus::Running {
if let Err(e) = entry.child.start_kill() {
tracing::warn!("BgRegistry::kill_as_caller: SIGTERM PID {pid}: {e}");
}
entry.status = BgProcessStatus::Killed;
}
CancelOutcome::Cancelled
}
pub async fn wait_for_exit_as_caller(
&self,
pid: u32,
caller_spawner: Option<u32>,
timeout: Duration,
) -> ProcessWaitOutcome {
const POLL_INTERVAL: Duration = Duration::from_millis(100);
{
let guard = self.inner.lock().unwrap();
match guard.get(&pid) {
None => return ProcessWaitOutcome::NotFound,
Some(e) if e.spawner != caller_spawner => return ProcessWaitOutcome::Forbidden,
Some(_) => {}
}
}
let deadline = Instant::now() + timeout;
loop {
self.reap();
{
let guard = self.inner.lock().unwrap();
let Some(entry) = guard.get(&pid) else {
return ProcessWaitOutcome::NotFound;
};
match entry.status {
BgProcessStatus::Running => {}
BgProcessStatus::Exited { code } => {
return ProcessWaitOutcome::Exited { code };
}
BgProcessStatus::Killed => {
return ProcessWaitOutcome::Exited { code: None };
}
}
}
if Instant::now() >= deadline {
let guard = self.inner.lock().unwrap();
let Some(entry) = guard.get(&pid) else {
return ProcessWaitOutcome::NotFound;
};
let now = Instant::now();
return ProcessWaitOutcome::TimedOut(BgProcessSnapshot {
pid,
command: entry.command.clone(),
age: now.saturating_duration_since(entry.started_at),
status: entry.status,
spawner: entry.spawner,
});
}
let remaining = deadline.saturating_duration_since(Instant::now());
tokio::time::sleep(POLL_INTERVAL.min(remaining)).await;
}
}
pub fn kill_for_spawner(&self, spawner: u32) -> usize {
let mut guard = self.inner.lock().unwrap();
let mut count = 0;
for entry in guard.values_mut() {
if entry.spawner != Some(spawner) {
continue;
}
if entry.status == BgProcessStatus::Running {
if let Err(e) = entry.child.start_kill() {
tracing::warn!(
"BgRegistry::kill_for_spawner: SIGTERM PID {}: {e}",
entry.child.id().unwrap_or(0)
);
}
entry.status = BgProcessStatus::Killed;
count += 1;
}
}
count
}
}
impl Default for BgRegistry {
fn default() -> Self {
Self::new()
}
}
impl Drop for BgRegistry {
    /// Best-effort kill of every child still marked `Running` when the
    /// registry goes away. Failures are logged, never propagated.
    fn drop(&mut self) {
        // `&mut self` already guarantees exclusive access, so no locking is
        // needed. Recover the map even if the mutex is poisoned: panicking
        // inside `drop` while already unwinding would abort the process.
        let entries = self
            .inner
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        for (pid, entry) in entries.iter_mut() {
            if entry.status != BgProcessStatus::Running {
                continue;
            }
            // tokio's `start_kill` sends SIGKILL on Unix (not SIGTERM).
            match entry.child.start_kill() {
                Err(e) => tracing::warn!("BgRegistry drop: failed to kill PID {pid}: {e}"),
                Ok(()) => tracing::debug!("BgRegistry drop: sent SIGKILL to PID {pid}"),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns a `sleep 60` child that outlives any single test.
    fn spawn_sleep_child() -> (u32, Child) {
        let handle = tokio::process::Command::new("sleep")
            .arg("60")
            .spawn()
            .expect("spawn sleep");
        (handle.id().expect("pid"), handle)
    }

    /// Spawns `true`, which exits immediately with status 0.
    fn spawn_true_child() -> (u32, Child) {
        let handle = tokio::process::Command::new("true").spawn().expect("spawn");
        (handle.id().unwrap_or(99999), handle)
    }

    /// PIDs of the given snapshots, in their original order.
    fn pids_of(snaps: &[BgProcessSnapshot]) -> Vec<u32> {
        snaps.iter().map(|snap| snap.pid).collect()
    }

    #[test]
    fn registry_starts_empty() {
        let registry = BgRegistry::new();
        assert_eq!(registry.len(), 0);
        assert!(registry.list().is_empty());
        assert!(registry.snapshot().is_empty());
    }

    #[tokio::test]
    async fn insert_records_spawner_and_appears_in_snapshot() {
        let registry = BgRegistry::new();
        let (pid, child) = spawn_sleep_child();
        registry.insert(pid, "sleep 60".into(), child, Some(7));
        let entries = registry.snapshot();
        assert_eq!(entries.len(), 1);
        let entry = &entries[0];
        assert_eq!(entry.pid, pid);
        assert_eq!(entry.command, "sleep 60");
        assert_eq!(entry.status, BgProcessStatus::Running);
        assert_eq!(entry.spawner, Some(7));
    }

    #[tokio::test]
    async fn snapshot_for_caller_filters_by_spawner() {
        let registry = BgRegistry::new();
        let (top_pid, top_child) = spawn_sleep_child();
        let (sub7_pid, sub7_child) = spawn_sleep_child();
        let (sub9_pid, sub9_child) = spawn_sleep_child();
        registry.insert(top_pid, "a".into(), top_child, None);
        registry.insert(sub7_pid, "b".into(), sub7_child, Some(7));
        registry.insert(sub9_pid, "c".into(), sub9_child, Some(9));
        assert_eq!(pids_of(&registry.snapshot_for_caller(None)), vec![top_pid]);
        assert_eq!(pids_of(&registry.snapshot_for_caller(Some(7))), vec![sub7_pid]);
        assert!(registry.snapshot_for_caller(Some(42)).is_empty());
    }

    #[tokio::test]
    async fn reap_transitions_finished_children_to_exited() {
        let registry = BgRegistry::new();
        let (pid, child) = spawn_true_child();
        registry.insert(pid, "true".into(), child, None);
        let mut observed = None;
        let mut attempts_left = 50;
        while observed.is_none() && attempts_left > 0 {
            attempts_left -= 1;
            tokio::time::sleep(Duration::from_millis(20)).await;
            registry.reap();
            if let BgProcessStatus::Exited { code } = registry.snapshot()[0].status {
                observed = Some(code);
            }
        }
        assert_eq!(
            observed,
            Some(Some(0)),
            "reap should observe `true` exiting with code 0 within 1s"
        );
    }

    #[tokio::test]
    async fn kill_transitions_to_killed_and_returns_true() {
        let registry = BgRegistry::new();
        let (pid, child) = spawn_sleep_child();
        registry.insert(pid, "sleep 60".into(), child, None);
        assert!(registry.kill(pid));
        assert_eq!(registry.snapshot()[0].status, BgProcessStatus::Killed);
        let unknown_pid = 987654;
        assert!(!registry.kill(unknown_pid));
    }

    #[tokio::test]
    async fn kill_as_caller_enforces_spawner_scope() {
        let registry = BgRegistry::new();
        let (pid, child) = spawn_sleep_child();
        registry.insert(pid, "sleep 60".into(), child, Some(5));
        for intruder in [None, Some(99)] {
            assert_eq!(registry.kill_as_caller(pid, intruder), CancelOutcome::Forbidden);
        }
        assert_eq!(registry.snapshot()[0].status, BgProcessStatus::Running);
        assert_eq!(registry.kill_as_caller(pid, Some(5)), CancelOutcome::Cancelled);
        assert_eq!(registry.snapshot()[0].status, BgProcessStatus::Killed);
        assert_eq!(registry.kill_as_caller(987654, None), CancelOutcome::NotFound);
    }

    #[tokio::test]
    async fn wait_for_exit_returns_exited_when_child_finishes() {
        let registry = BgRegistry::new();
        let (pid, child) = spawn_true_child();
        registry.insert(pid, "true".into(), child, None);
        let budget = Duration::from_secs(2);
        let outcome = registry.wait_for_exit_as_caller(pid, None, budget).await;
        assert_eq!(outcome, ProcessWaitOutcome::Exited { code: Some(0) });
    }

    #[tokio::test]
    async fn wait_for_exit_returns_exited_when_already_killed() {
        let registry = BgRegistry::new();
        let (pid, child) = spawn_sleep_child();
        registry.insert(pid, "sleep 60".into(), child, Some(7));
        registry.kill(pid);
        let budget = Duration::from_secs(1);
        let outcome = registry.wait_for_exit_as_caller(pid, Some(7), budget).await;
        assert_eq!(outcome, ProcessWaitOutcome::Exited { code: None });
    }

    #[tokio::test]
    async fn wait_for_exit_returns_timed_out_with_snapshot() {
        let registry = BgRegistry::new();
        let (pid, child) = spawn_sleep_child();
        registry.insert(pid, "sleep 60".into(), child, None);
        let budget = Duration::from_millis(150);
        let outcome = registry.wait_for_exit_as_caller(pid, None, budget).await;
        let snap = match outcome {
            ProcessWaitOutcome::TimedOut(snap) => snap,
            other => panic!("expected TimedOut, got {other:?}"),
        };
        assert_eq!(snap.pid, pid);
        assert_eq!(snap.status, BgProcessStatus::Running);
        assert_eq!(snap.spawner, None);
        assert_eq!(
            registry.snapshot().len(),
            1,
            "entry must be preserved on timeout"
        );
    }

    #[tokio::test]
    async fn wait_for_exit_enforces_spawner_scope() {
        let registry = BgRegistry::new();
        let (pid, child) = spawn_sleep_child();
        registry.insert(pid, "sleep 60".into(), child, Some(5));
        for intruder in [None, Some(99)] {
            let outcome = registry
                .wait_for_exit_as_caller(pid, intruder, Duration::from_millis(20))
                .await;
            assert_eq!(outcome, ProcessWaitOutcome::Forbidden);
        }
    }

    #[tokio::test]
    async fn wait_for_exit_returns_not_found_for_unknown_pid() {
        let registry = BgRegistry::new();
        let outcome = registry
            .wait_for_exit_as_caller(987654, None, Duration::from_millis(10))
            .await;
        assert_eq!(outcome, ProcessWaitOutcome::NotFound);
    }

    #[tokio::test]
    async fn kill_for_spawner_kills_only_matching_running_children() {
        let registry = BgRegistry::new();
        let (top_pid, top_child) = spawn_sleep_child();
        let (a_pid, a_child) = spawn_sleep_child();
        let (b_pid, b_child) = spawn_sleep_child();
        registry.insert(top_pid, "top".into(), top_child, None);
        registry.insert(a_pid, "a".into(), a_child, Some(7));
        registry.insert(b_pid, "b".into(), b_child, Some(9));
        assert_eq!(registry.kill_for_spawner(7), 1);
        let mut status_of: HashMap<u32, BgProcessStatus> = HashMap::new();
        for snap in registry.snapshot() {
            status_of.insert(snap.pid, snap.status);
        }
        assert_eq!(status_of[&top_pid], BgProcessStatus::Running);
        assert_eq!(status_of[&a_pid], BgProcessStatus::Killed);
        assert_eq!(status_of[&b_pid], BgProcessStatus::Running);
        assert_eq!(registry.kill_for_spawner(7), 0);
        assert_eq!(registry.kill_for_spawner(99), 0);
    }
}