use std::collections::{HashMap, HashSet, VecDeque};
use std::process::{Command, Stdio};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use extism::{CurrentPlugin, Error, UserData, Val};
use serde::{Deserialize, Serialize};
use tracing::warn;
use astrid_workspace::SandboxCommand;
use crate::engine::wasm::host::util;
use crate::engine::wasm::host_state::HostState;
/// JSON request from the guest describing a command to run: the program in
/// `cmd` and its arguments in `args` (which may be omitted; defaults to empty).
///
/// NOTE(review): both fields borrow from the request bytes. serde_json can only
/// borrow a `&str` when the JSON string contains no escape sequences, so a
/// request whose `cmd`/`args` contain e.g. `\"`, `\\` or `\n` fails to parse.
/// Switching to `String`/`Cow<str>` would also require updating every caller.
#[derive(Deserialize, Debug)]
struct ProcessRequest<'a> {
    cmd: &'a str,
    #[serde(default)]
    args: Vec<&'a str>,
}

/// JSON reply for a foreground process: its output decoded lossily as UTF-8,
/// and its exit code (`-1` when no code is available).
#[derive(Serialize, Debug)]
struct ProcessResult {
    stdout: String,
    stderr: String,
    exit_code: i32,
}

/// How long a process gets to exit after `SIGINT` before it is sent `SIGKILL`.
const SIGKILL_GRACE_PERIOD: Duration = Duration::from_secs(2);
/// Registry of PIDs for processes started by `astrid_spawn_host_impl`, so they
/// can be interrupted when their tool call (or everything) is cancelled.
///
/// Each PID maps to the ID of the tool call that spawned it, if there was one.
#[derive(Default, Debug)]
pub struct ProcessTracker {
    /// Shared (via `Arc`) with the delayed `SIGKILL` task spawned on cancellation.
    active_pids: Arc<Mutex<HashMap<u32, Option<String>>>>,
}
impl ProcessTracker {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn register(&self, pid: u32, call_id: Option<String>) {
if pid == 0 {
return; }
self.active_pids
.lock()
.expect("process tracker lock poisoned")
.insert(pid, call_id);
}
pub fn unregister(&self, pid: u32) {
self.active_pids
.lock()
.expect("process tracker lock poisoned")
.remove(&pid);
}
pub fn cancel_by_call_ids(&self, call_ids: &[String], handle: &tokio::runtime::Handle) {
if call_ids.is_empty() {
return;
}
let call_id_set: HashSet<&String> = call_ids.iter().collect();
let pids: Vec<u32> = self
.active_pids
.lock()
.expect("process tracker lock poisoned")
.iter()
.filter_map(|(&pid, stored_call_id)| {
match stored_call_id {
None => Some(pid),
Some(id) => call_id_set.contains(id).then_some(pid),
}
})
.collect();
self.signal_pids(&pids, handle);
}
pub fn cancel_all(&self, handle: &tokio::runtime::Handle) {
let pids: Vec<u32> = self
.active_pids
.lock()
.expect("process tracker lock poisoned")
.keys()
.copied()
.collect();
self.signal_pids(&pids, handle);
}
fn signal_pids(&self, pids: &[u32], handle: &tokio::runtime::Handle) {
if pids.is_empty() {
return;
}
for &pid in pids {
let Some(raw) = i32::try_from(pid).ok() else {
warn!(pid, "PID overflows i32, skipping signal");
continue;
};
let _ = nix::sys::signal::kill(
nix::unistd::Pid::from_raw(raw),
nix::sys::signal::Signal::SIGINT,
);
}
let tracker = self.active_pids.clone();
let target_pids: Vec<u32> = pids.to_vec();
handle.spawn(async move {
tokio::time::sleep(SIGKILL_GRACE_PERIOD).await;
let still_active = tracker.lock().expect("process tracker lock poisoned");
for pid in target_pids {
if !still_active.contains_key(&pid) {
continue;
}
let Some(raw) = i32::try_from(pid).ok() else {
continue;
};
let _ = nix::sys::signal::kill(
nix::unistd::Pid::from_raw(raw),
nix::sys::signal::Signal::SIGKILL,
);
}
});
}
}
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_spawn_host_impl(
plugin: &mut CurrentPlugin,
inputs: &[Val],
outputs: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let req_bytes: Vec<u8> = util::get_safe_bytes(plugin, &inputs[0], util::MAX_GUEST_PAYLOAD_LEN)?;
let req: ProcessRequest = serde_json::from_slice(&req_bytes)
.map_err(|e| Error::msg(format!("failed to parse process request: {e}")))?;
let ud = user_data.get()?;
let state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
let workspace_root = state.workspace_root.clone();
let security = state.security.clone();
let capsule_id = state.capsule_id.as_str().to_owned();
let handle = state.runtime_handle.clone();
let semaphore = state.host_semaphore.clone();
let cancel_token = state.cancel_token.clone();
let process_tracker = state.process_tracker.clone();
let call_id = state.caller_context.as_ref().and_then(|msg| {
if let astrid_events::ipc::IpcPayload::ToolExecuteRequest { call_id, .. } = &msg.payload {
Some(call_id.clone())
} else {
None
}
});
drop(state);
if let Some(sec) = security {
let cmd = req.cmd.to_string();
util::bounded_block_on(&handle, &semaphore, async {
sec.check_host_process(&capsule_id, &cmd).await
})
.map_err(|e| Error::msg(format!("Security Check Failed: {e}")))?;
} else {
return Err(Error::msg(
"Security Check Failed: No security gate found for host_process capability.",
));
}
let mut inner_cmd = Command::new(req.cmd);
inner_cmd.args(&req.args);
inner_cmd.env_remove("ASTRID_SOCKET_PATH");
inner_cmd.env_remove("ASTRID_SESSION_TOKEN");
inner_cmd.env_remove("ASTRID_HOME");
let sandboxed_cmd = SandboxCommand::wrap(inner_cmd, &workspace_root)
.map_err(|e| Error::msg(format!("failed to wrap command in sandbox: {e}")))?;
let mut sandboxed_cmd = sandboxed_cmd;
sandboxed_cmd.stdout(Stdio::piped());
sandboxed_cmd.stderr(Stdio::piped());
let child = sandboxed_cmd
.spawn()
.map_err(|e| Error::msg(format!("failed to spawn command: {e}")))?;
let pid = child.id();
process_tracker.register(pid, call_id);
let output_result =
util::bounded_block_on_cancellable(&handle, &semaphore, &cancel_token, async move {
tokio::task::spawn_blocking(move || child.wait_with_output())
.await
.map_err(std::io::Error::other)
.and_then(|r| r)
});
let result = match output_result {
Some(Ok(output)) => {
process_tracker.unregister(pid);
ProcessResult {
stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
exit_code: output.status.code().unwrap_or(-1),
}
},
Some(Err(e)) => {
process_tracker.unregister(pid);
return Err(Error::msg(format!("failed to execute command: {e}")));
},
None => {
warn!(capsule_id, pid, "process cancelled");
if let Ok(raw) = i32::try_from(pid) {
let _ = nix::sys::signal::kill(
nix::unistd::Pid::from_raw(raw),
nix::sys::signal::Signal::SIGKILL,
);
}
process_tracker.unregister(pid);
ProcessResult {
stdout: String::new(),
stderr: "process cancelled".to_owned(),
exit_code: -1,
}
},
};
let result_bytes = serde_json::to_vec(&result)?;
let mem = plugin.memory_new(&result_bytes)?;
outputs[0] = plugin.memory_to_val(mem);
Ok(())
}
/// Maximum number of background processes a single host state may track at once.
pub(crate) const MAX_BACKGROUND_PROCESSES: usize = 8;
/// Per-stream cap, in bytes, on buffered stdout/stderr of a background process;
/// older bytes are discarded first.
const MAX_BUFFER_BYTES: usize = 1024 * 1024;

/// A background process spawned on behalf of a capsule, together with its
/// captured output.
///
/// Dropping a `ManagedProcess` that still holds its child kills and reaps it.
pub struct ManagedProcess {
    /// The child handle; `None` once it has been reaped or killed.
    child: Option<std::process::Child>,
    /// Most recent stdout bytes, appended by a reader thread.
    stdout_buf: Arc<Mutex<VecDeque<u8>>>,
    /// Most recent stderr bytes, appended by a reader thread.
    stderr_buf: Arc<Mutex<VecDeque<u8>>>,
    /// Human-readable command line, used for logging.
    command: String,
}
/// Forcefully terminates `child` and reaps it, returning its exit code.
///
/// On Unix, `SIGKILL` is first sent to the process group whose ID equals the
/// child's PID. Background processes are spawned with `process_group(0)`, so
/// this also takes down descendants that stayed in the child's group. The
/// child itself is then killed and waited on.
///
/// Returns `None` if waiting fails or the child has no exit code, e.g. because
/// it was terminated by a signal.
fn kill_and_reap(child: &mut std::process::Child) -> Option<i32> {
    #[cfg(unix)]
    {
        // A PID that does not fit in `pid_t` cannot name a real group. Skip the
        // group kill rather than signalling an unrelated fallback id.
        if let Ok(raw_pid) = i32::try_from(child.id()) {
            // Best effort: fails (e.g. ESRCH) when no such group exists.
            let _ = nix::sys::signal::killpg(
                nix::unistd::Pid::from_raw(raw_pid),
                nix::sys::signal::Signal::SIGKILL,
            );
        }
    }
    // Best effort: harmless if the child has already exited.
    let _ = child.kill();
    child.wait().ok().and_then(|status| status.code())
}
impl Drop for ManagedProcess {
    /// Makes sure a child still held by this handle does not outlive it: the
    /// child (and its process group on Unix) is killed and reaped.
    fn drop(&mut self) {
        let Some(mut child) = self.child.take() else {
            return;
        };
        let _ = kill_and_reap(&mut child);
    }
}
/// Removes the buffered bytes from `buf` and returns them as a string.
///
/// Invalid UTF-8 is replaced with U+FFFD. A truncated multi-byte sequence at
/// the very end (at most 3 bytes) is *left in the buffer* rather than decoded.
/// This happens when the reader thread has appended only part of a character.
/// The next drain then yields the whole character instead of two replacement
/// characters. If no further bytes ever arrive, those trailing bytes are never
/// returned.
///
/// A poisoned lock is recovered rather than propagated.
fn drain_buffer(buf: &Mutex<VecDeque<u8>>) -> String {
    let mut locked = buf.lock().unwrap_or_else(|e| e.into_inner());
    let complete = complete_utf8_prefix_len(locked.make_contiguous());
    let bytes: Vec<u8> = locked.drain(..complete).collect();
    String::from_utf8_lossy(&bytes).into_owned()
}

/// Length of `bytes` without a trailing, truncated-but-otherwise-valid UTF-8
/// sequence. Returns `bytes.len()` when the input does not end that way.
fn complete_utf8_prefix_len(bytes: &[u8]) -> usize {
    // A UTF-8 sequence is at most 4 bytes long, so a truncated one is at most 3.
    for tail_len in 1..=bytes.len().min(3) {
        let start = bytes.len() - tail_len;
        if let Err(e) = std::str::from_utf8(&bytes[start..]) {
            // `error_len() == None` means the input ended mid-sequence, and
            // `valid_up_to() == 0` means that sequence starts exactly at `start`.
            if e.valid_up_to() == 0 && e.error_len().is_none() {
                return start;
            }
        }
    }
    bytes.len()
}
fn spawn_reader_thread(
id: u64,
label: &str,
mut pipe: impl std::io::Read + Send + 'static,
buffer: Arc<Mutex<VecDeque<u8>>>,
) {
let name = format!("bg-{id}-{label}");
std::thread::Builder::new()
.name(name)
.spawn(move || {
let mut chunk = [0u8; 4096];
loop {
match pipe.read(&mut chunk) {
Ok(0) => break, Ok(n) => {
let mut locked = buffer.lock().unwrap_or_else(|e| e.into_inner());
locked.extend(&chunk[..n]);
let excess = locked.len().saturating_sub(MAX_BUFFER_BYTES);
if excess > 0 {
locked.drain(..excess);
}
},
Err(_) => break,
}
}
})
.ok(); }
/// Builds `cmd args…` and wraps it with [`SandboxCommand`] for `workspace_root`.
///
/// The Astrid session variables `ASTRID_SOCKET_PATH`, `ASTRID_SESSION_TOKEN`
/// and `ASTRID_HOME` are removed from the child's environment. Presumably this
/// keeps host credentials away from sandboxed commands.
///
/// # Errors
///
/// Returns an error if the sandbox wrapper cannot be constructed.
fn prepare_sandboxed_command(
    cmd: &str,
    args: &[&str],
    workspace_root: &std::path::Path,
) -> Result<Command, Error> {
    const SCRUBBED_ENV: [&str; 3] = ["ASTRID_SOCKET_PATH", "ASTRID_SESSION_TOKEN", "ASTRID_HOME"];

    let mut inner = Command::new(cmd);
    inner.args(args);
    for var in SCRUBBED_ENV {
        inner.env_remove(var);
    }
    SandboxCommand::wrap(inner, workspace_root)
        .map_err(|e| Error::msg(format!("failed to wrap command in sandbox: {e}")))
}
/// JSON reply to a background spawn: the ID the guest uses to address the process.
#[derive(Serialize, Debug)]
struct SpawnBackgroundResult {
    id: u64,
}

/// JSON request naming a background process by the ID returned at spawn time.
#[derive(Deserialize, Debug)]
struct BackgroundProcessRequest {
    id: u64,
}

/// JSON reply to a log read. Field order is the JSON key order.
#[derive(Serialize, Debug)]
struct ReadLogsResult {
    /// Stdout captured since the previous drain.
    stdout: String,
    /// Stderr captured since the previous drain.
    stderr: String,
    /// Whether the child was still running when polled.
    running: bool,
    /// `Some` on the read that observes the exit (`-1` if polling failed), else `None`.
    exit_code: Option<i32>,
}

/// JSON reply to a kill request, including output drained from the buffers.
/// Field order is the JSON key order.
#[derive(Serialize, Debug)]
struct KillProcessResult {
    /// Always `true` when a reply is produced.
    killed: bool,
    /// Exit code reported after reaping, if any.
    exit_code: Option<i32>,
    stdout: String,
    stderr: String,
}
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_spawn_background_host_impl(
plugin: &mut CurrentPlugin,
inputs: &[Val],
outputs: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let req_bytes: Vec<u8> = util::get_safe_bytes(plugin, &inputs[0], util::MAX_GUEST_PAYLOAD_LEN)?;
let req: ProcessRequest = serde_json::from_slice(&req_bytes)
.map_err(|e| Error::msg(format!("failed to parse process request: {e}")))?;
let ud = user_data.get()?;
let state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
if state.background_processes.len() >= MAX_BACKGROUND_PROCESSES {
return Err(Error::msg(format!(
"background process limit reached (max {MAX_BACKGROUND_PROCESSES})"
)));
}
let workspace_root = state.workspace_root.clone();
let security = state.security.clone();
let capsule_id = state.capsule_id.as_str().to_owned();
let handle = state.runtime_handle.clone();
let semaphore = state.host_semaphore.clone();
drop(state);
if let Some(sec) = security {
let cmd = req.cmd.to_string();
util::bounded_block_on(&handle, &semaphore, async {
sec.check_host_process(&capsule_id, &cmd).await
})
.map_err(|e| Error::msg(format!("Security Check Failed: {e}")))?;
} else {
return Err(Error::msg(
"Security Check Failed: No security gate found for host_process capability.",
));
}
let mut sandboxed_cmd = prepare_sandboxed_command(req.cmd, &req.args, &workspace_root)?;
#[cfg(unix)]
{
use std::os::unix::process::CommandExt as _;
sandboxed_cmd.process_group(0);
}
sandboxed_cmd.stdout(Stdio::piped());
sandboxed_cmd.stderr(Stdio::piped());
let command_str = format!("{} {}", req.cmd, req.args.join(" "));
let child = sandboxed_cmd
.spawn()
.map_err(|e| Error::msg(format!("failed to spawn background process: {e}")))?;
let stdout_buf: Arc<Mutex<VecDeque<u8>>> = Arc::new(Mutex::new(VecDeque::new()));
let stderr_buf: Arc<Mutex<VecDeque<u8>>> = Arc::new(Mutex::new(VecDeque::new()));
let mut managed = ManagedProcess {
child: Some(child),
stdout_buf: Arc::clone(&stdout_buf),
stderr_buf: Arc::clone(&stderr_buf),
command: command_str,
};
let ud2 = user_data.get()?;
let mut state = ud2
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
if state.background_processes.len() >= MAX_BACKGROUND_PROCESSES {
return Err(Error::msg(format!(
"background process limit reached (max {MAX_BACKGROUND_PROCESSES})"
)));
}
let process_id = state.next_process_id;
state.next_process_id += 1;
if let Some(child) = managed.child.as_mut() {
if let Some(stdout) = child.stdout.take() {
spawn_reader_thread(process_id, "stdout", stdout, Arc::clone(&stdout_buf));
}
if let Some(stderr) = child.stderr.take() {
spawn_reader_thread(process_id, "stderr", stderr, Arc::clone(&stderr_buf));
}
}
tracing::info!(
capsule_id = %capsule_id,
process_id = process_id,
command = %managed.command,
"Spawned background process"
);
state.background_processes.insert(process_id, managed);
drop(state);
let result = SpawnBackgroundResult { id: process_id };
let result_bytes = serde_json::to_vec(&result)?;
let mem = plugin.memory_new(&result_bytes)?;
outputs[0] = plugin.memory_to_val(mem);
Ok(())
}
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_read_process_logs_host_impl(
plugin: &mut CurrentPlugin,
inputs: &[Val],
outputs: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let req_bytes: Vec<u8> = util::get_safe_bytes(plugin, &inputs[0], 256)?;
let req: BackgroundProcessRequest = serde_json::from_slice(&req_bytes)
.map_err(|e| Error::msg(format!("failed to parse read logs request: {e}")))?;
let ud = user_data.get()?;
let mut state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
let proc = state
.background_processes
.get_mut(&req.id)
.ok_or_else(|| Error::msg(format!("no background process with id {}", req.id)))?;
let (running, exit_code) = if let Some(child) = proc.child.as_mut() {
match child.try_wait() {
Ok(Some(status)) => {
proc.child.take();
(false, status.code())
},
Ok(None) => (true, None),
Err(_) => {
proc.child.take();
(false, Some(-1))
},
}
} else {
(false, None)
};
let stdout = drain_buffer(&proc.stdout_buf);
let stderr = drain_buffer(&proc.stderr_buf);
drop(state);
let result = ReadLogsResult {
stdout,
stderr,
running,
exit_code,
};
let result_bytes = serde_json::to_vec(&result)?;
let mem = plugin.memory_new(&result_bytes)?;
outputs[0] = plugin.memory_to_val(mem);
Ok(())
}
/// Host function: kills a background process and returns its final output.
///
/// `inputs[0]` holds a JSON [`BackgroundProcessRequest`] of at most 256 bytes.
/// A JSON [`KillProcessResult`] is returned through `outputs[0]`. The entry is
/// removed from host state, so its ID cannot be used again.
///
/// # Errors
///
/// Fails if the request cannot be read or parsed, the host-state lock is
/// poisoned, or no process has the given ID.
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_kill_process_host_impl(
    plugin: &mut CurrentPlugin,
    inputs: &[Val],
    outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    let req_bytes: Vec<u8> = util::get_safe_bytes(plugin, &inputs[0], 256)?;
    let req: BackgroundProcessRequest = serde_json::from_slice(&req_bytes)
        .map_err(|e| Error::msg(format!("failed to parse kill request: {e}")))?;

    let ud = user_data.get()?;
    let mut state = ud
        .lock()
        .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
    let mut proc = state
        .background_processes
        .remove(&req.id)
        .ok_or_else(|| Error::msg(format!("no background process with id {}", req.id)))?;
    let capsule_id = state.capsule_id.as_str().to_owned();
    drop(state);

    // Kill first, drain second. Once the child's process group is dead it
    // produces no new output, so draining afterwards captures everything the
    // reader threads had buffered by then. Draining before the kill dropped
    // output produced in the interim. Data still sitting unread in the pipes
    // may still be missed.
    let exit_code = proc
        .child
        .take()
        .and_then(|mut child| kill_and_reap(&mut child));
    let stdout = drain_buffer(&proc.stdout_buf);
    let stderr = drain_buffer(&proc.stderr_buf);

    tracing::info!(
        capsule_id = %capsule_id,
        process_id = req.id,
        command = %proc.command,
        exit_code = ?exit_code,
        "Killed background process"
    );

    let result = KillProcessResult {
        killed: true,
        exit_code,
        stdout,
        stderr,
    };
    let result_bytes = serde_json::to_vec(&result)?;
    let mem = plugin.memory_new(&result_bytes)?;
    outputs[0] = plugin.memory_to_val(mem);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Accumulates drained output from `buf` until it contains `needle` or a
    /// 5-second deadline passes. Polling replaces a fixed sleep, which is
    /// flaky when the reader thread is slow to be scheduled.
    fn collect_until(buf: &Mutex<VecDeque<u8>>, needle: &str) -> String {
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        let mut collected = String::new();
        loop {
            collected.push_str(&drain_buffer(buf));
            if collected.contains(needle) || std::time::Instant::now() >= deadline {
                return collected;
            }
            std::thread::sleep(std::time::Duration::from_millis(10));
        }
    }

    #[test]
    fn buffer_cap_enforced() {
        let buf: Arc<Mutex<VecDeque<u8>>> = Arc::new(Mutex::new(VecDeque::new()));
        let mut data = vec![b'A'; MAX_BUFFER_BYTES + 500];
        data[0] = b'X';
        data[500] = b'Y';
        {
            let mut locked = buf.lock().unwrap_or_else(|e| e.into_inner());
            locked.extend(&data);
            let excess = locked.len().saturating_sub(MAX_BUFFER_BYTES);
            if excess > 0 {
                locked.drain(..excess);
            }
        }
        let locked = buf.lock().unwrap_or_else(|e| e.into_inner());
        assert_eq!(locked.len(), MAX_BUFFER_BYTES);
        assert_eq!(locked[0], b'Y');
        assert!(!locked.contains(&b'X'));
    }

    /// Exercises the real reader thread. The front byte becomes `Y` only after
    /// the final 500 excess bytes have been read and trimmed.
    #[test]
    fn reader_thread_enforces_cap() {
        let mut data = vec![b'A'; MAX_BUFFER_BYTES + 500];
        data[0] = b'X';
        data[500] = b'Y';
        let buf: Arc<Mutex<VecDeque<u8>>> = Arc::new(Mutex::new(VecDeque::new()));
        spawn_reader_thread(7, "cap", std::io::Cursor::new(data), Arc::clone(&buf));
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        loop {
            {
                let locked = buf.lock().unwrap_or_else(|e| e.into_inner());
                if locked.front() == Some(&b'Y') {
                    assert_eq!(locked.len(), MAX_BUFFER_BYTES);
                    assert!(!locked.contains(&b'X'));
                    return;
                }
            }
            assert!(
                std::time::Instant::now() < deadline,
                "reader thread did not finish in time"
            );
            std::thread::sleep(std::time::Duration::from_millis(10));
        }
    }

    #[test]
    fn drain_buffer_clears_and_returns() {
        let buf: Arc<Mutex<VecDeque<u8>>> = Arc::new(Mutex::new(VecDeque::new()));
        {
            let mut locked = buf.lock().unwrap_or_else(|e| e.into_inner());
            locked.extend(b"hello world");
        }
        let result = drain_buffer(&buf);
        assert_eq!(result, "hello world");
        let locked = buf.lock().unwrap_or_else(|e| e.into_inner());
        assert!(locked.is_empty());
    }

    #[test]
    fn drain_buffer_handles_empty() {
        let buf: Arc<Mutex<VecDeque<u8>>> = Arc::new(Mutex::new(VecDeque::new()));
        let result = drain_buffer(&buf);
        assert_eq!(result, "");
    }

    #[test]
    fn managed_process_drop_kills_child() {
        let child = Command::new("sleep")
            .arg("60")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("failed to spawn sleep");
        let raw_pid = child.id();
        let managed = ManagedProcess {
            child: Some(child),
            stdout_buf: Arc::new(Mutex::new(VecDeque::new())),
            stderr_buf: Arc::new(Mutex::new(VecDeque::new())),
            command: "sleep 60".to_string(),
        };
        drop(managed);
        #[cfg(unix)]
        {
            let pid = nix::unistd::Pid::from_raw(i32::try_from(raw_pid).unwrap_or(i32::MAX));
            let result = nix::sys::signal::kill(pid, None);
            assert!(
                result.is_err(),
                "process should be dead after ManagedProcess drop"
            );
        }
    }

    #[test]
    fn spawn_respects_limit() {
        use std::collections::HashMap;
        let mut processes: HashMap<u64, ManagedProcess> = HashMap::new();
        for i in 0..MAX_BACKGROUND_PROCESSES {
            let child = Command::new("sleep")
                .arg("60")
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .spawn()
                .expect("failed to spawn");
            processes.insert(
                i as u64,
                ManagedProcess {
                    child: Some(child),
                    stdout_buf: Arc::new(Mutex::new(VecDeque::new())),
                    stderr_buf: Arc::new(Mutex::new(VecDeque::new())),
                    command: "sleep 60".to_string(),
                },
            );
        }
        assert!(
            processes.len() >= MAX_BACKGROUND_PROCESSES,
            "at limit: should reject new spawns"
        );
        processes.remove(&0);
        assert!(
            processes.len() < MAX_BACKGROUND_PROCESSES,
            "below limit: should allow new spawns"
        );
    }

    #[test]
    fn kill_nonexistent_returns_error() {
        let processes: std::collections::HashMap<u64, ManagedProcess> =
            std::collections::HashMap::new();
        assert!(processes.get(&999).is_none());
    }

    #[test]
    fn read_logs_after_natural_exit() {
        let mut child = Command::new("echo")
            .arg("hello from echo")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("failed to spawn echo");
        let stdout_buf: Arc<Mutex<VecDeque<u8>>> = Arc::new(Mutex::new(VecDeque::new()));
        let stderr_buf: Arc<Mutex<VecDeque<u8>>> = Arc::new(Mutex::new(VecDeque::new()));
        if let Some(stdout) = child.stdout.take() {
            spawn_reader_thread(1, "stdout", stdout, Arc::clone(&stdout_buf));
        }
        let status = child.wait().expect("failed to wait");
        assert!(status.success());
        let stdout = collect_until(&stdout_buf, "hello from echo");
        let stderr = drain_buffer(&stderr_buf);
        assert!(
            stdout.contains("hello from echo"),
            "expected output after natural exit, got: {stdout}"
        );
        assert!(stderr.is_empty());
    }

    #[test]
    fn kill_returns_final_output() {
        let mut child = Command::new("echo")
            .arg("final output")
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("failed to spawn echo");
        let stdout_buf: Arc<Mutex<VecDeque<u8>>> = Arc::new(Mutex::new(VecDeque::new()));
        if let Some(stdout) = child.stdout.take() {
            spawn_reader_thread(1, "test-stdout", stdout, Arc::clone(&stdout_buf));
        }
        let _ = child.wait().expect("failed to wait for child");
        let stdout = collect_until(&stdout_buf, "final output");
        assert!(
            stdout.contains("final output"),
            "expected 'final output' in stdout, got: {stdout}"
        );
    }

    #[test]
    fn tracker_register_unregister() {
        let tracker = ProcessTracker::new();
        tracker.register(1234, None);
        tracker.register(5678, Some("call-a".into()));
        assert_eq!(tracker.active_pids.lock().unwrap().len(), 2);
        tracker.unregister(1234);
        assert_eq!(tracker.active_pids.lock().unwrap().len(), 1);
        assert!(tracker.active_pids.lock().unwrap().contains_key(&5678));
    }

    #[test]
    fn tracker_ignores_pid_zero() {
        let tracker = ProcessTracker::new();
        tracker.register(0, None);
        assert!(tracker.active_pids.lock().unwrap().is_empty());
    }

    #[test]
    fn tracker_cancel_all_empty_is_noop() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let tracker = ProcessTracker::new();
        tracker.cancel_all(rt.handle());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn tracker_cancel_all_kills_real_process() {
        let tracker = Arc::new(ProcessTracker::new());
        let child = Command::new("sleep")
            .arg("60")
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("failed to spawn sleep");
        let pid = child.id();
        tracker.register(pid, None);
        tracker.cancel_all(&tokio::runtime::Handle::current());
        let output = tokio::task::spawn_blocking(move || child.wait_with_output())
            .await
            .expect("join failed")
            .expect("wait failed");
        tracker.unregister(pid);
        assert!(!output.status.success());
        assert!(tracker.active_pids.lock().unwrap().is_empty());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn tracker_sigkill_fires_for_sigint_ignoring_process() {
        let tracker = Arc::new(ProcessTracker::new());
        let child = Command::new("sh")
            .args(["-c", "trap '' INT; sleep 60"])
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("failed to spawn sh");
        let pid = child.id();
        tracker.register(pid, None);
        tracker.cancel_all(&tokio::runtime::Handle::current());
        let output = tokio::time::timeout(
            std::time::Duration::from_secs(5),
            tokio::task::spawn_blocking(move || child.wait_with_output()),
        )
        .await
        .expect("process was not killed within 5s")
        .expect("join failed")
        .expect("wait failed");
        tracker.unregister(pid);
        assert!(!output.status.success());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn tracker_cancel_all_multiple_processes() {
        let tracker = Arc::new(ProcessTracker::new());
        let child1 = Command::new("sleep")
            .arg("60")
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("failed to spawn sleep 1");
        let child2 = Command::new("sleep")
            .arg("60")
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("failed to spawn sleep 2");
        let pid1 = child1.id();
        let pid2 = child2.id();
        tracker.register(pid1, None);
        tracker.register(pid2, None);
        assert_eq!(tracker.active_pids.lock().unwrap().len(), 2);
        tracker.cancel_all(&tokio::runtime::Handle::current());
        let out1 = tokio::task::spawn_blocking(move || child1.wait_with_output())
            .await
            .expect("join 1 failed")
            .expect("wait 1 failed");
        let out2 = tokio::task::spawn_blocking(move || child2.wait_with_output())
            .await
            .expect("join 2 failed")
            .expect("wait 2 failed");
        tracker.unregister(pid1);
        tracker.unregister(pid2);
        assert!(!out1.status.success());
        assert!(!out2.status.success());
        assert!(tracker.active_pids.lock().unwrap().is_empty());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn tracker_cancel_by_call_ids_scoped() {
        let tracker = Arc::new(ProcessTracker::new());
        let child_a = Command::new("sleep")
            .arg("60")
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("failed to spawn sleep a");
        let child_b = Command::new("sleep")
            .arg("60")
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("failed to spawn sleep b");
        let pid_a = child_a.id();
        let pid_b = child_b.id();
        tracker.register(pid_a, Some("call-a".into()));
        tracker.register(pid_b, Some("call-b".into()));
        tracker.cancel_by_call_ids(&["call-a".into()], &tokio::runtime::Handle::current());
        let out_a = tokio::task::spawn_blocking(move || child_a.wait_with_output())
            .await
            .expect("join a failed")
            .expect("wait a failed");
        assert!(!out_a.status.success());
        assert!(tracker.active_pids.lock().unwrap().contains_key(&pid_b));
        // Clean up the untouched process manually.
        if let Ok(raw) = i32::try_from(pid_b) {
            let _ = nix::sys::signal::kill(
                nix::unistd::Pid::from_raw(raw),
                nix::sys::signal::Signal::SIGKILL,
            );
        }
        let _ = tokio::task::spawn_blocking(move || child_b.wait_with_output()).await;
        tracker.unregister(pid_a);
        tracker.unregister(pid_b);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn tracker_cancel_by_call_ids_includes_none() {
        let tracker = Arc::new(ProcessTracker::new());
        let child = Command::new("sleep")
            .arg("60")
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .expect("failed to spawn sleep");
        let pid = child.id();
        tracker.register(pid, None);
        tracker.cancel_by_call_ids(&["any-id".into()], &tokio::runtime::Handle::current());
        let output = tokio::task::spawn_blocking(move || child.wait_with_output())
            .await
            .expect("join failed")
            .expect("wait failed");
        tracker.unregister(pid);
        assert!(!output.status.success());
    }
}