use std::cell::RefCell;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Time a SIGTERM'd child (and its process group) is given to exit or finish
/// writing its output before SIGKILL is sent.
pub const SUBPROCESS_TERM_GRACE: Duration = Duration::from_millis(2_000);
/// Interrupt context installed for the current thread by [`install`].
#[derive(Clone, Default)]
struct OpInterrupt {
    /// Cooperative cancellation flag; the operation should stop once it reads `true`.
    cancel: Option<Arc<AtomicBool>>,
    /// Instant at (or after) which the operation counts as interrupted.
    deadline: Option<Instant>,
}

thread_local! {
    /// Innermost interrupt context installed on this thread, if any.
    static CURRENT: RefCell<Option<OpInterrupt>> = const { RefCell::new(None) };
}

/// Scope guard returned by [`install`]. Dropping it reinstates whatever
/// context (possibly none) was current before the matching `install` call.
///
/// The context lives in a thread-local, so the guard is deliberately `!Send`:
/// dropping it on a different thread would overwrite *that* thread's slot and
/// leave this thread's context installed indefinitely. Guards must be dropped
/// in reverse order of installation, which ordinary lexical scoping ensures.
#[must_use = "dropping the guard immediately uninstalls the interrupt context"]
pub struct OpInterruptGuard {
    /// Context that was current before this guard's `install` call.
    prev: Option<OpInterrupt>,
    /// Raw-pointer marker that makes the guard `!Send` and `!Sync`.
    _not_send: std::marker::PhantomData<*const ()>,
}

impl Drop for OpInterruptGuard {
    fn drop(&mut self) {
        let prev = self.prev.take();
        // `try_with` instead of `with`: if the guard is dropped while thread-locals
        // are being torn down, the slot may already be gone, and panicking inside
        // `drop` during an unwind would abort the process.
        let _ = CURRENT.try_with(|slot| *slot.borrow_mut() = prev);
    }
}

/// Installs `cancel` and `deadline` as this thread's interrupt context until
/// the returned guard is dropped.
///
/// Installs nest: the new context fully replaces any outer one (the two are
/// not merged), and the outer one comes back when the guard drops.
pub fn install(cancel: Option<Arc<AtomicBool>>, deadline: Option<Instant>) -> OpInterruptGuard {
    let prev = CURRENT.with(|slot| slot.borrow_mut().replace(OpInterrupt { cancel, deadline }));
    OpInterruptGuard {
        prev,
        _not_send: std::marker::PhantomData,
    }
}

/// Reports whether the operation on this thread should stop. That is the case
/// when the installed cancel flag is set or the installed deadline has passed.
///
/// Returns `false` when no context is installed.
pub fn requested() -> bool {
    CURRENT.with(|slot| {
        slot.borrow().as_ref().is_some_and(|ctx| {
            let cancelled = ctx
                .cancel
                .as_ref()
                .is_some_and(|token| token.load(Ordering::SeqCst));
            // Short-circuits so the clock is only read when not already cancelled.
            cancelled || ctx.deadline.is_some_and(|deadline| Instant::now() >= deadline)
        })
    })
}
/// Prepares `command` so that, on Unix, the spawned child leads a brand-new
/// process group whose id equals the child's pid. A later group signal can
/// then reach the child and every descendant that stays in that group.
///
/// Leaves `command` untouched on non-Unix targets.
pub fn configure_kill_group(command: &mut std::process::Command) {
    // `process_group(0)` asks the OS to create a new group led by the child.
    #[cfg(unix)]
    std::os::unix::process::CommandExt::process_group(command, 0);
    #[cfg(not(unix))]
    let _ = command;
}
/// Sends `signal` to the process group led by `pid`, then to `pid` itself.
///
/// The group signal reaches descendants that share the child's group (see
/// `configure_kill_group`). The direct signal is a fallback for a child that is
/// not a group leader. Errors from `kill(2)` are ignored (for example `ESRCH`
/// when the group no longer exists), because this is best-effort cleanup.
///
/// A `pid` of 0 or 1, or one that does not fit in an `i32`, is ignored without
/// sending anything. `kill(0, _)` would signal the caller's own process group,
/// and `kill(-1, _)` would signal every process the caller is allowed to signal.
///
/// Does nothing on non-Unix targets.
pub fn signal_pid_and_group(pid: u32, signal: i32) {
    #[cfg(unix)]
    {
        extern "C" {
            fn kill(pid: i32, sig: i32) -> i32;
        }
        let pid = match i32::try_from(pid) {
            Ok(pid) if pid > 1 => pid,
            // Out-of-range or degenerate pids would turn into broadcast targets.
            _ => return,
        };
        // SAFETY: kill(2) takes only integer arguments and touches no Rust memory;
        // `pid > 1` guarantees `-pid` neither overflows nor becomes 0 or -1.
        unsafe {
            kill(-pid, signal);
            kill(pid, signal);
        }
    }
    #[cfg(not(unix))]
    {
        let _ = (pid, signal);
    }
}
/// Outcome of `wait_child_interruptible`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChildWait {
    /// The child exited on its own with this status.
    Exited(std::process::ExitStatus),
    /// The timeout elapsed first. The child was killed (and, on Unix, its
    /// process group was sent SIGKILL) and then waited for.
    TimedOut,
    /// An interrupt was requested and the child was terminated. Holds the exit
    /// status if one could be collected while reaping it.
    Interrupted(Option<std::process::ExitStatus>),
}
/// Waits for `child` to exit, polling every 20ms. It honours both the
/// interrupt context installed on this thread and an optional `timeout`.
///
/// * Returns [`ChildWait::Exited`] when the child exits on its own.
/// * When `requested()` reports an interrupt, terminates the child via
///   `terminate_child_group` and returns [`ChildWait::Interrupted`] with
///   whatever status that recovered.
/// * When `timeout` elapses, sends SIGKILL to the child and its group, calls
///   `Child::kill`, reaps the child, and returns [`ChildWait::TimedOut`].
///
/// The interrupt check runs before the timeout check on every poll.
///
/// # Errors
/// Propagates any error from `Child::try_wait`.
pub fn wait_child_interruptible(
    child: &mut std::process::Child,
    timeout: Option<Duration>,
) -> std::io::Result<ChildWait> {
    const SIGKILL: i32 = 9;
    const POLL_INTERVAL: Duration = Duration::from_millis(20);
    let give_up_at = timeout.map(|limit| Instant::now() + limit);
    loop {
        match child.try_wait()? {
            Some(status) => return Ok(ChildWait::Exited(status)),
            None if requested() => {
                return Ok(ChildWait::Interrupted(terminate_child_group(child)));
            }
            None if give_up_at.is_some_and(|at| Instant::now() >= at) => {
                if let Some(pid) = child_pid(child) {
                    signal_pid_and_group(pid, SIGKILL);
                }
                let _ = child.kill();
                let _ = child.wait();
                return Ok(ChildWait::TimedOut);
            }
            None => std::thread::sleep(POLL_INTERVAL),
        }
    }
}
/// Terminates `child` and its process group, escalating from SIGTERM to SIGKILL.
///
/// On Unix:
/// 1. The child and its group are sent SIGTERM.
/// 2. The child is polled every 20ms for up to [`SUBPROCESS_TERM_GRACE`].
/// 3. Whether or not it exits in time, the group is then sent SIGKILL, so
///    descendants that ignored SIGTERM do not linger.
///
/// On other targets (or if Unix escalation did not reap the child) this falls
/// back to `Child::kill` followed by `Child::wait`.
///
/// Returns the child's exit status, or `None` if the final `wait` failed.
pub fn terminate_child_group(child: &mut std::process::Child) -> Option<std::process::ExitStatus> {
    #[cfg(unix)]
    {
        const SIGTERM: i32 = 15;
        const SIGKILL: i32 = 9;

        /// Signals only the process group `pgid`, never a bare pid.
        fn signal_group_only(pgid: u32, signal: i32) {
            extern "C" {
                fn kill(pid: i32, sig: i32) -> i32;
            }
            if let Ok(pgid) = i32::try_from(pgid) {
                if pgid > 1 {
                    // SAFETY: kill(2) takes only integers and touches no Rust memory;
                    // `pgid > 1` keeps `-pgid` away from the 0 / -1 broadcast targets.
                    unsafe { kill(-pgid, signal) };
                }
            }
        }

        if let Some(pid) = child_pid(child) {
            signal_pid_and_group(pid, SIGTERM);
            let grace_deadline = Instant::now() + SUBPROCESS_TERM_GRACE;
            loop {
                match child.try_wait() {
                    Ok(Some(status)) => {
                        // The leader has just been reaped, so its bare pid may now be
                        // reused by an unrelated process. POSIX does not reuse a
                        // process-group id while the group still has members, so
                        // sweeping only the group is safe.
                        signal_group_only(pid, SIGKILL);
                        return Some(status);
                    }
                    Ok(None) if Instant::now() < grace_deadline => {
                        std::thread::sleep(Duration::from_millis(20));
                    }
                    Ok(None) | Err(_) => break,
                }
            }
            // Not reaped yet, so the pid is still ours: force both leader and group.
            signal_pid_and_group(pid, SIGKILL);
        }
    }
    let _ = child.kill();
    child.wait().ok()
}
/// Returns the OS process id of `child`, or `None` when it reports 0.
fn child_pid(child: &std::process::Child) -> Option<u32> {
    match child.id() {
        0 => None,
        pid => Some(pid),
    }
}
/// Collects the bytes produced by a `spawn_pipe_drain` thread.
///
/// * `killed == true`: waits at most 100ms and returns whatever arrived, or an
///   empty buffer on timeout or disconnect.
/// * `killed == false`: polls every 20ms. If the drain thread disconnects
///   without sending, returns an empty buffer. If an interrupt is requested
///   while still waiting, it:
///   1. sends SIGTERM to `child_pid` and its group;
///   2. waits up to [`SUBPROCESS_TERM_GRACE`] for the output;
///   3. sends SIGKILL;
///   4. if nothing arrived during the grace period, waits a final 100ms.
///
/// NOTE(review): `capture_output_interruptible` only takes the `killed == false`
/// path after the child has been reaped, so `child_pid` may name a recycled
/// process by the time these signals are sent. Confirm this is acceptable.
pub(crate) fn drain_captured_pipe(
    rx: &std::sync::mpsc::Receiver<Vec<u8>>,
    killed: bool,
    child_pid: u32,
) -> Vec<u8> {
    use std::sync::mpsc::RecvTimeoutError;
    const SIGTERM: i32 = 15;
    const SIGKILL: i32 = 9;
    let poll_interval = Duration::from_millis(20);
    let final_wait = Duration::from_millis(100);

    if !killed {
        loop {
            let failure = match rx.recv_timeout(poll_interval) {
                Ok(chunk) => return chunk,
                Err(failure) => failure,
            };
            if matches!(failure, RecvTimeoutError::Disconnected) {
                // The drain thread went away without delivering anything.
                return Vec::new();
            }
            if !requested() {
                continue;
            }
            // Interrupted while the pipe is still open: ask politely, give the
            // group the grace period to finish writing, then force it.
            signal_pid_and_group(child_pid, SIGTERM);
            let during_grace = rx.recv_timeout(SUBPROCESS_TERM_GRACE);
            signal_pid_and_group(child_pid, SIGKILL);
            if let Ok(chunk) = during_grace {
                return chunk;
            }
            break;
        }
    }
    rx.recv_timeout(final_wait).unwrap_or_default()
}
/// Reads `reader` to EOF on a background thread and delivers everything read
/// as a single message on the returned channel.
///
/// Read errors are ignored, and bytes read before the error are still sent.
/// If the receiver has already been dropped, the send fails silently.
pub(crate) fn spawn_pipe_drain<R: std::io::Read + Send + 'static>(
    mut reader: R,
) -> std::sync::mpsc::Receiver<Vec<u8>> {
    let (sender, receiver) = std::sync::mpsc::channel();
    std::thread::spawn(move || {
        let mut collected: Vec<u8> = Vec::new();
        let _ = std::io::Read::read_to_end(&mut reader, &mut collected);
        let _ = sender.send(collected);
    });
    receiver
}
/// Runs `command` to completion and captures its stdout and stderr. It honours
/// the interrupt context installed on this thread via `install`.
///
/// Forces piped stdout/stderr and a null stdin, and spawns the child in its own
/// process group so an interrupt can signal every process in that group. No
/// timeout is applied here; only the installed cancel flag or deadline can stop
/// the child early.
///
/// NOTE(review): if the child is interrupted and its status cannot be recovered,
/// `ExitStatus::default()` is reported. std documents that value as a
/// *successful* status, so `status.success()` cannot distinguish it from a
/// clean exit.
///
/// # Errors
/// Propagates spawn failures and errors from waiting on the child.
pub fn capture_output_interruptible(
    command: &mut std::process::Command,
) -> std::io::Result<std::process::Output> {
    use std::process::{ExitStatus, Output, Stdio};
    use std::sync::mpsc::Receiver;

    configure_kill_group(
        command
            .stdin(Stdio::null())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped()),
    );
    let mut child = command.spawn()?;
    let pid = child.id();
    // Both drains start before waiting, so the child cannot stall on a full pipe
    // buffer while this thread polls for its exit.
    let stdout_rx = child.stdout.take().map(spawn_pipe_drain);
    let stderr_rx = child.stderr.take().map(spawn_pipe_drain);

    let outcome = wait_child_interruptible(&mut child, None)?;
    let killed = !matches!(outcome, ChildWait::Exited(_));
    let status = match outcome {
        ChildWait::Exited(status) => status,
        // No timeout is passed above, so this arm is not expected in practice.
        ChildWait::TimedOut => ExitStatus::default(),
        ChildWait::Interrupted(recovered) => recovered.unwrap_or_default(),
    };

    let collect = |rx: Option<Receiver<Vec<u8>>>| {
        rx.map_or_else(Vec::new, |rx| drain_captured_pipe(&rx, killed, pid))
    };
    // Struct fields evaluate in written order: stdout is drained before stderr.
    Ok(Output {
        status,
        stdout: collect(stdout_rx),
        stderr: collect(stderr_rx),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn requested_is_false_without_context() {
        assert!(!requested());
    }

    #[test]
    fn cancel_token_trips_requested_and_guard_restores() {
        let token = Arc::new(AtomicBool::new(false));
        let guard = install(Some(token.clone()), None);
        assert!(!requested());
        token.store(true, Ordering::SeqCst);
        assert!(requested());
        drop(guard);
        assert!(!requested());
    }

    #[test]
    fn deadline_trips_requested() {
        // A deadline of "now" has passed by the time `requested` reads the clock,
        // and unlike `now - 1ms` it cannot fail on a clock that started recently.
        let _guard = install(None, Some(Instant::now()));
        assert!(requested());
    }

    #[test]
    fn nested_installs_restore_in_order() {
        let outer_token = Arc::new(AtomicBool::new(true));
        let _outer = install(Some(outer_token), None);
        assert!(requested());
        {
            let _inner = install(None, None);
            assert!(!requested());
        }
        assert!(requested());
    }

    #[cfg(unix)]
    #[test]
    fn capture_collects_both_streams_and_status() {
        let mut command = std::process::Command::new("sh");
        command.args(["-c", "printf out; printf err >&2; exit 3"]);
        let output = capture_output_interruptible(&mut command).expect("capture");
        assert_eq!(output.stdout, b"out");
        assert_eq!(output.stderr, b"err");
        assert_eq!(output.status.code(), Some(3));
    }

    #[cfg(unix)]
    #[test]
    fn capture_stops_promptly_when_cancelled() {
        let mut command = std::process::Command::new("sh");
        command.args(["-c", "sleep 30"]);
        let _guard = install(Some(Arc::new(AtomicBool::new(true))), None);
        let started = Instant::now();
        let output = capture_output_interruptible(&mut command).expect("capture");
        assert!(started.elapsed() < Duration::from_secs(10));
        assert!(!output.status.success());
    }

    #[cfg(unix)]
    #[test]
    fn interrupted_wait_kills_process_group() {
        let mut command = std::process::Command::new("sh");
        command.args(["-c", "sleep 30 & wait"]);
        configure_kill_group(&mut command);
        let mut child = command.spawn().expect("spawn sh");
        let pgid = child.id();
        let cancel = Arc::new(AtomicBool::new(true));
        let _guard = install(Some(cancel), None);
        let started = Instant::now();
        let outcome = wait_child_interruptible(&mut child, None).expect("wait");
        assert!(matches!(outcome, ChildWait::Interrupted(_)));
        assert!(started.elapsed() < Duration::from_secs(10));
        extern "C" {
            fn kill(pid: i32, sig: i32) -> i32;
        }
        // SAFETY: signal 0 is an existence probe; kill(2) takes only integers.
        let group_gone = || unsafe { kill(-(pgid as i32), 0) } != 0;
        let deadline = Instant::now() + Duration::from_secs(5);
        while !group_gone() && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(50));
        }
        assert!(group_gone(), "process group {pgid} survived interrupt");
    }
}