use std::io;
#[cfg(unix)]
use nix::errno::Errno;
#[cfg(target_os = "linux")]
use nix::sys::prctl;
#[cfg(unix)]
use nix::sys::signal::{self, Signal};
#[cfg(unix)]
use nix::unistd::{self, Pid};
#[cfg(unix)]
use tokio::process::Child;
/// Grace period, in milliseconds, used by `graceful_kill_process_group_default` between the
/// initial termination request and the follow-up handling of a still-running target.
pub const DEFAULT_GRACEFUL_TIMEOUT_MS: u64 = 500;
/// Portable choice of signal used to ask a process (group) to stop.
///
/// On unix the variants map to `SIGINT`, `SIGTERM` and `SIGKILL`. The default is
/// [`KillSignal::Kill`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum KillSignal {
    /// Interrupt request (`SIGINT` on unix).
    Int,
    /// Termination request (`SIGTERM` on unix).
    Term,
    /// Forced termination (`SIGKILL` on unix).
    #[default]
    Kill,
}
#[cfg(unix)]
impl KillSignal {
    /// Returns the `nix` signal corresponding to this variant.
    fn as_nix_signal(self) -> Signal {
        match self {
            Self::Int => Signal::SIGINT,
            Self::Term => Signal::SIGTERM,
            Self::Kill => Signal::SIGKILL,
        }
    }
}
/// Wraps a `nix` errno in a `std::io::Error` carrying the same raw OS error code.
#[cfg(unix)]
fn nix_err_to_io(err: Errno) -> io::Error {
    let raw_code = err as i32;
    io::Error::from_raw_os_error(raw_code)
}
/// Arms `SIGTERM` as this process's parent-death signal (Linux `PR_SET_PDEATHSIG`).
///
/// `parent_pid` is the PID the caller expects to be its parent. After arming the signal,
/// `getppid()` is compared against it. If they differ, the expected parent is already gone and
/// the armed signal can no longer fire, so the process sends `SIGTERM` to itself immediately.
///
/// Per `prctl(2)`, the "parent" here is the thread that created this process.
///
/// # Errors
/// Returns the OS error if arming the signal or the self-`kill` fails.
#[cfg(target_os = "linux")]
pub fn set_parent_death_signal(parent_pid: libc::pid_t) -> io::Result<()> {
    prctl::set_pdeathsig(Some(Signal::SIGTERM)).map_err(nix_err_to_io)?;
    let expected_parent = Pid::from_raw(parent_pid);
    if unistd::getppid() == expected_parent {
        return Ok(());
    }
    // The parent changed before the death signal was armed: terminate now.
    signal::kill(unistd::getpid(), Signal::SIGTERM).map_err(nix_err_to_io)
}
/// Fallback for targets other than Linux: does nothing and always returns `Ok(())`.
#[cfg(not(target_os = "linux"))]
pub fn set_parent_death_signal(_parent_pid: i32) -> io::Result<()> {
    io::Result::Ok(())
}
/// Starts a new session via `setsid`, leaving the process without a controlling terminal.
///
/// `setsid` fails with `EPERM` when the caller is already a process-group leader. In that case
/// this falls back to [`set_process_group`] instead of reporting the error.
///
/// # Errors
/// Any other `setsid` failure, or a failure of the fallback, is returned as an `io::Error`.
#[cfg(unix)]
pub fn detach_from_tty() -> io::Result<()> {
    if let Err(errno) = unistd::setsid() {
        return if errno == Errno::EPERM {
            set_process_group()
        } else {
            Err(nix_err_to_io(errno))
        };
    }
    Ok(())
}
/// Non-unix fallback: does nothing and always returns `Ok(())`.
#[cfg(not(unix))]
pub fn detach_from_tty() -> io::Result<()> {
    io::Result::Ok(())
}
#[cfg(unix)]
pub fn set_process_group() -> io::Result<()> {
unistd::setpgid(Pid::from_raw(0), Pid::from_raw(0)).map_err(nix_err_to_io)
}
/// Non-unix fallback: does nothing and always returns `Ok(())`.
#[cfg(not(unix))]
pub fn set_process_group() -> io::Result<()> {
    io::Result::Ok(())
}
/// Forcefully kills `pid` and its process group.
///
/// Shorthand for [`kill_process_group_by_pid_with_signal`] with [`KillSignal::Kill`].
#[cfg(unix)]
pub fn kill_process_group_by_pid(pid: u32) -> io::Result<()> {
    let forced = KillSignal::Kill;
    kill_process_group_by_pid_with_signal(pid, forced)
}
#[cfg(unix)]
pub fn kill_process_group_by_pid_with_signal(pid: u32, signal: KillSignal) -> io::Result<()> {
use std::io::ErrorKind;
let target_pid = Pid::from_raw(pid as libc::pid_t);
let pgid = unistd::getpgid(Some(target_pid));
let mut pgid_err = None;
match pgid {
Ok(group) => {
if let Err(err) = signal::killpg(group, signal.as_nix_signal()) {
let io_err = nix_err_to_io(err);
if io_err.kind() != ErrorKind::NotFound {
pgid_err = Some(io_err);
}
}
}
Err(err) => pgid_err = Some(nix_err_to_io(err)),
}
if let Err(err) = signal::kill(target_pid, signal.as_nix_signal()) {
let io_err = nix_err_to_io(err);
if io_err.kind() == ErrorKind::NotFound {
return Ok(());
}
if let Some(pgid_error) = pgid_err {
return Err(pgid_error);
}
return Err(io_err);
}
Ok(())
}
/// Non-unix fallback: does nothing and always returns `Ok(())`.
#[cfg(not(unix))]
pub fn kill_process_group_by_pid(_pid: u32) -> io::Result<()> {
    io::Result::Ok(())
}

/// Non-unix fallback: ignores both arguments and always returns `Ok(())`.
#[cfg(not(unix))]
pub fn kill_process_group_by_pid_with_signal(_pid: u32, _signal: KillSignal) -> io::Result<()> {
    io::Result::Ok(())
}
/// Forcefully kills every process in `process_group_id`.
///
/// Shorthand for [`kill_process_group_with_signal`] with [`KillSignal::Kill`].
#[cfg(unix)]
pub fn kill_process_group(process_group_id: u32) -> io::Result<()> {
    let forced = KillSignal::Kill;
    kill_process_group_with_signal(process_group_id, forced)
}
#[cfg(unix)]
pub fn kill_process_group_with_signal(process_group_id: u32, signal: KillSignal) -> io::Result<()> {
use std::io::ErrorKind;
let pgid = Pid::from_raw(process_group_id as libc::pid_t);
if let Err(err) = signal::killpg(pgid, signal.as_nix_signal()) {
let io_err = nix_err_to_io(err);
if io_err.kind() != ErrorKind::NotFound {
return Err(io_err);
}
}
Ok(())
}
/// Non-unix fallback: does nothing and always returns `Ok(())`.
#[cfg(not(unix))]
pub fn kill_process_group(_process_group_id: u32) -> io::Result<()> {
    io::Result::Ok(())
}

/// Non-unix fallback: ignores both arguments and always returns `Ok(())`.
#[cfg(not(unix))]
pub fn kill_process_group_with_signal(
    _process_group_id: u32,
    _signal: KillSignal,
) -> io::Result<()> {
    io::Result::Ok(())
}
/// Forcefully kills a tokio `Child` together with its process group.
///
/// Shorthand for [`kill_child_process_group_with_signal`] with [`KillSignal::Kill`].
#[cfg(unix)]
pub fn kill_child_process_group(child: &mut Child) -> io::Result<()> {
    let forced = KillSignal::Kill;
    kill_child_process_group_with_signal(child, forced)
}
/// Signals the process group of a tokio `Child`, and the child itself, with `signal`.
///
/// When `child.id()` returns `None` there is no PID to target, so this returns `Ok(())`
/// without sending anything. Per tokio's docs, `id()` returns `None` once the child has been
/// polled to completion.
#[cfg(unix)]
pub fn kill_child_process_group_with_signal(
    child: &mut Child,
    signal: KillSignal,
) -> io::Result<()> {
    child
        .id()
        .map_or(Ok(()), |pid| kill_process_group_by_pid_with_signal(pid, signal))
}
/// Non-unix fallback: does nothing and always returns `Ok(())`.
#[cfg(not(unix))]
pub fn kill_child_process_group(_child: &mut tokio::process::Child) -> io::Result<()> {
    io::Result::Ok(())
}

/// Non-unix fallback: ignores both arguments and always returns `Ok(())`.
#[cfg(not(unix))]
pub fn kill_child_process_group_with_signal(
    _child: &mut tokio::process::Child,
    _signal: KillSignal,
) -> io::Result<()> {
    io::Result::Ok(())
}
/// Forcefully terminates `pid` and its child processes using `taskkill /PID <pid> /T /F`.
///
/// # Errors
/// Fails if `taskkill` cannot be spawned or exits with a non-success status.
#[cfg(windows)]
pub fn kill_process(pid: u32) -> io::Result<()> {
    let pid_arg = pid.to_string();
    let status = std::process::Command::new("taskkill")
        .args(["/PID", pid_arg.as_str(), "/T", "/F"])
        .status()?;
    if !status.success() {
        return Err(io::Error::other("taskkill failed"));
    }
    Ok(())
}
/// Fallback for non-Windows targets: does nothing and always returns `Ok(())`.
///
/// This includes unix, where the process-group helpers in this module do the actual killing.
#[cfg(not(windows))]
pub fn kill_process(_pid: u32) -> io::Result<()> {
    io::Result::Ok(())
}
/// Outcome of a graceful-termination attempt made by `graceful_kill_process_group`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GracefulTerminationResult {
    /// The target went away after the initial, non-forced termination request.
    GracefulExit,
    /// The target had to be forcibly killed.
    ForcefulKill,
    /// The target was not running, or could not be found, when termination began.
    AlreadyExited,
    /// Termination could not be carried out (signalling failed or the platform is unsupported).
    Error,
}
/// Reports whether `pid` names an existing process, probing it with the null signal.
///
/// `EPERM` means the process exists but may not be signalled by us, so it still counts as
/// running.
///
/// Some PIDs cannot name a single process, and these report `false` instead of being probed:
/// - 0: `kill` would address our own process group.
/// - values above `i32::MAX`: `as pid_t` would wrap them to negative, group-addressed (or, for
///   -1, broadcast) targets.
///
/// NOTE(review): an exited but not-yet-reaped (zombie) child still answers the null signal, so
/// it is reported as running until it is waited on.
#[cfg(unix)]
fn is_process_running(pid: u32) -> bool {
    let Some(raw_pid) = libc::pid_t::try_from(pid).ok().filter(|&raw| raw > 0) else {
        return false;
    };
    match signal::kill(Pid::from_raw(raw_pid), None::<Signal>) {
        Ok(()) | Err(Errno::EPERM) => true,
        Err(_) => false,
    }
}
/// Non-unix fallback: liveness cannot be probed here, so every PID is assumed to be running.
#[cfg(not(unix))]
fn is_process_running(_pid: u32) -> bool {
    let assumed_alive = true;
    assumed_alive
}
#[cfg(unix)]
pub fn graceful_kill_process_group(
pid: u32,
initial_signal: KillSignal,
grace_period: std::time::Duration,
) -> GracefulTerminationResult {
if !is_process_running(pid) {
return GracefulTerminationResult::AlreadyExited;
}
let target_pid = Pid::from_raw(pid as libc::pid_t);
let Ok(pgid) = unistd::getpgid(Some(target_pid)) else {
return GracefulTerminationResult::AlreadyExited;
};
let signal = match initial_signal {
KillSignal::Kill => Signal::SIGTERM, other => other.as_nix_signal(),
};
if let Err(err) = signal::killpg(pgid, signal) {
if err != Errno::ESRCH {
return GracefulTerminationResult::Error;
}
return GracefulTerminationResult::AlreadyExited;
}
let deadline = std::time::Instant::now() + grace_period;
let poll_interval = std::time::Duration::from_millis(10);
while std::time::Instant::now() < deadline {
if !is_process_running(pid) {
return GracefulTerminationResult::GracefulExit;
}
std::thread::sleep(poll_interval);
}
let _ = signal::killpg(pgid, Signal::SIGKILL);
if let Err(err) = signal::kill(target_pid, Signal::SIGKILL) {
if err == Errno::ESRCH {
return GracefulTerminationResult::GracefulExit;
}
return GracefulTerminationResult::Error;
}
GracefulTerminationResult::ForcefulKill
}
/// Gracefully terminates `pid` and its descendants on non-unix targets.
///
/// On Windows, `initial_signal` is ignored. The polite request is `taskkill /PID <pid> /T`
/// (without `/F`):
/// - If that request succeeds, the function waits `grace_period` and then force-kills whatever
///   is left via [`kill_process`]. A failed forced kill is read as "the tree already exited",
///   the same interpretation used when the polite request is rejected.
/// - If the polite request is rejected, the tree is force-killed immediately.
///
/// On other non-unix targets graceful termination is unsupported and
/// [`GracefulTerminationResult::Error`] is returned.
#[cfg(not(unix))]
pub fn graceful_kill_process_group(
    pid: u32,
    initial_signal: KillSignal,
    grace_period: std::time::Duration,
) -> GracefulTerminationResult {
    #[cfg(windows)]
    {
        // Windows has no signal selection; taskkill without /F is the polite request.
        let _ = initial_signal;
        let pid_arg = pid.to_string();
        match std::process::Command::new("taskkill")
            .args(["/PID", &pid_arg, "/T"])
            .status()
        {
            Ok(status) if status.success() => {
                std::thread::sleep(grace_period);
                // Escalate like the unix implementation instead of assuming the tree complied.
                match kill_process(pid) {
                    Ok(()) => GracefulTerminationResult::ForcefulKill,
                    Err(_) => GracefulTerminationResult::GracefulExit,
                }
            }
            Ok(_) => match kill_process(pid) {
                Ok(()) => GracefulTerminationResult::ForcefulKill,
                Err(_) => GracefulTerminationResult::AlreadyExited,
            },
            Err(_) => GracefulTerminationResult::Error,
        }
    }
    #[cfg(not(windows))]
    {
        let _ = (pid, initial_signal, grace_period);
        GracefulTerminationResult::Error
    }
}
/// Gracefully terminates `pid`'s process group using [`KillSignal::Term`] and a grace period of
/// [`DEFAULT_GRACEFUL_TIMEOUT_MS`] milliseconds.
///
/// See [`graceful_kill_process_group`] for the per-platform behaviour. That function blocks the
/// calling thread while it waits.
pub fn graceful_kill_process_group_default(pid: u32) -> GracefulTerminationResult {
    let grace = std::time::Duration::from_millis(DEFAULT_GRACEFUL_TIMEOUT_MS);
    graceful_kill_process_group(pid, KillSignal::Term, grace)
}
pub async fn graceful_kill_process_group_default_async(pid: u32) -> GracefulTerminationResult {
tokio::task::spawn_blocking(move || graceful_kill_process_group_default(pid))
.await
.unwrap_or(GracefulTerminationResult::Error)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// PID used as a target that should not exist: far above common `pid_max` limits.
    const MISSING_PID: u32 = 2_000_000_000;

    #[test]
    fn test_set_parent_death_signal_no_panic() {
        #[cfg(target_os = "linux")]
        {
            // Pass our real parent so the "parent still alive" path is taken. Passing our own
            // PID makes `getppid() != parent_pid` true, and the function would then send SIGTERM
            // to this test binary.
            let parent_pid = unistd::getppid().as_raw();
            assert!(set_parent_death_signal(parent_pid).is_ok());
            // Disarm the death signal so it does not outlive this test.
            let _ = prctl::set_pdeathsig(None::<Signal>);
        }
        #[cfg(not(target_os = "linux"))]
        {
            assert!(set_parent_death_signal(0).is_ok());
        }
    }

    #[test]
    fn test_kill_nonexistent_process_group() {
        #[cfg(unix)]
        {
            // Only checked for panics here; the result is not asserted on unix.
            let _ = kill_process_group(MISSING_PID);
        }
        #[cfg(not(unix))]
        {
            assert!(kill_process_group(999_999).is_ok());
        }
    }

    #[test]
    fn test_kill_signal_values() {
        let distinct_pairs = [
            (KillSignal::Int, KillSignal::Term),
            (KillSignal::Term, KillSignal::Kill),
            (KillSignal::Int, KillSignal::Kill),
        ];
        for (left, right) in distinct_pairs {
            assert_ne!(left, right);
        }
        assert_eq!(KillSignal::default(), KillSignal::Kill);
    }

    #[test]
    fn test_graceful_termination_result_debug() {
        for outcome in [
            GracefulTerminationResult::GracefulExit,
            GracefulTerminationResult::ForcefulKill,
            GracefulTerminationResult::AlreadyExited,
            GracefulTerminationResult::Error,
        ] {
            let _ = format!("{outcome:?}");
        }
    }

    #[test]
    fn test_graceful_kill_nonexistent_process() {
        let result = graceful_kill_process_group_default(MISSING_PID);
        #[cfg(unix)]
        {
            assert_eq!(result, GracefulTerminationResult::AlreadyExited);
        }
        #[cfg(not(unix))]
        {
            let _ = result;
        }
    }

    #[tokio::test]
    async fn test_graceful_kill_nonexistent_process_async() {
        let result = graceful_kill_process_group_default_async(MISSING_PID).await;
        #[cfg(unix)]
        {
            assert_eq!(result, GracefulTerminationResult::AlreadyExited);
        }
        #[cfg(not(unix))]
        {
            let _ = result;
        }
    }

    #[cfg(unix)]
    #[test]
    fn test_is_process_running_self() {
        assert!(is_process_running(std::process::id()));
    }

    #[cfg(unix)]
    #[test]
    fn test_is_process_running_nonexistent() {
        assert!(!is_process_running(MISSING_PID));
    }
}