use std::collections::HashSet;
use std::sync::{Arc, Mutex};
/// Registry of PIDs for spawned LSP child processes.
///
/// The PID set lives behind an `Arc<Mutex<..>>`, so every clone of a registry
/// observes and mutates the same set.
#[derive(Clone, Debug, Default)]
pub struct LspChildRegistry {
    /// Tracked PIDs, shared between all clones of this registry.
    inner: Arc<Mutex<HashSet<u32>>>,
}
impl LspChildRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Locks the shared PID set, recovering the guard if a previous holder panicked.
    ///
    /// Every critical section in this type is a single `HashSet` call with no
    /// multi-step invariant to break, so the data behind a poisoned lock is still
    /// usable. Treating poison as "no PIDs" would instead make `kill_all` silently
    /// skip children that are still running.
    fn lock_set(&self) -> std::sync::MutexGuard<'_, HashSet<u32>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Starts tracking `pid`. Tracking an already-tracked PID is a no-op.
    pub fn track(&self, pid: u32) {
        self.lock_set().insert(pid);
    }

    /// Stops tracking `pid`. Untracking a PID that is not tracked is a no-op.
    pub fn untrack(&self, pid: u32) {
        self.lock_set().remove(&pid);
    }

    /// Returns a snapshot of the currently tracked PIDs, in unspecified order.
    pub fn pids(&self) -> Vec<u32> {
        self.lock_set().iter().copied().collect()
    }

    /// Sends `SIGKILL` to the process group of every tracked PID and returns how
    /// many `killpg` calls succeeded.
    ///
    /// Each tracked PID is used directly as a process-group ID, so a child must be
    /// the leader of its own group (e.g. spawned with `setsid`) for its descendants
    /// to be reached. PIDs are not untracked afterwards.
    ///
    /// PIDs `0` and `1`, and values that do not fit in `pid_t`, are skipped and not
    /// counted.
    #[cfg(unix)]
    pub fn kill_all(&self) -> usize {
        use std::convert::TryFrom;

        self.pids()
            .into_iter()
            // An unchecked `as` cast would wrap large values to negative group IDs.
            .filter_map(|pid| libc::pid_t::try_from(pid).ok())
            // Never hand 0 or 1 to `killpg`: 0 targets our own process group, and
            // glibc implements `killpg(1, sig)` as `kill(-1, sig)`, i.e. every
            // process we are allowed to signal.
            .filter(|&pgid| pgid > 1)
            .filter(|&pgid| {
                // SAFETY: `killpg` takes only integer arguments and reads or writes no
                // memory owned by this process, so it has no memory-safety preconditions.
                unsafe { libc::killpg(pgid, libc::SIGKILL) == 0 }
            })
            .count()
    }

    /// Force-kills the process tree rooted at every tracked PID using
    /// `taskkill /F /T /PID <pid>` and returns how many invocations exited
    /// successfully. PIDs are not untracked afterwards.
    #[cfg(not(unix))]
    pub fn kill_all(&self) -> usize {
        use std::process::{Command, Stdio};

        self.pids()
            .into_iter()
            .filter(|pid| {
                let status = Command::new("taskkill")
                    .args(["/F", "/T", "/PID", &pid.to_string()])
                    // Keep taskkill's status text off this process's own stdout/stderr.
                    // NOTE(review): those streams may carry a protocol — confirm with callers.
                    .stdout(Stdio::null())
                    .stderr(Stdio::null())
                    .status();
                // Launching taskkill is not the same as killing the process: only a
                // successful exit counts, matching the `rc == 0` rule of the unix branch.
                matches!(status, Ok(s) if s.success())
            })
            .count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn track_untrack_pids_round_trip() {
        let reg = LspChildRegistry::new();
        reg.track(100);
        reg.track(200);
        let mut pids = reg.pids();
        pids.sort();
        assert_eq!(pids, vec![100, 200]);
        reg.untrack(100);
        assert_eq!(reg.pids(), vec![200]);
    }

    #[test]
    fn clones_share_state() {
        let a = LspChildRegistry::new();
        let b = a.clone();
        a.track(42);
        assert_eq!(b.pids(), vec![42]);
        b.untrack(42);
        assert!(a.pids().is_empty());
    }

    #[test]
    fn untracking_unknown_pid_is_safe() {
        let reg = LspChildRegistry::new();
        reg.untrack(999);
        assert!(reg.pids().is_empty());
    }

    #[test]
    fn kill_all_with_no_pids_returns_zero() {
        let reg = LspChildRegistry::new();
        assert_eq!(reg.kill_all(), 0);
    }

    /// Regression test for the npm-wrapper orphan bug: killing a tracked wrapper must
    /// also kill the grandchild it spawned, because `kill_all` signals the whole
    /// process group rather than just the wrapper PID.
    #[cfg(unix)]
    #[test]
    fn kill_all_kills_process_group_not_just_wrapper_pid() {
        use crate::bash_background::process::is_process_alive;
        use std::io::{BufRead, BufReader};
        use std::os::unix::process::CommandExt;
        use std::process::{Command, Stdio};
        use std::thread;
        use std::time::{Duration, Instant};

        // Poll rather than sleep once: a SIGKILLed, reparented grandchild may not be
        // reaped immediately, so a single short sleep makes the final assertions racy
        // if `is_process_alive` still reports an unreaped process (review: confirm).
        let wait_until_dead = |pid: u32| -> bool {
            let deadline = Instant::now() + Duration::from_secs(2);
            while is_process_alive(pid) {
                if Instant::now() >= deadline {
                    return false;
                }
                thread::sleep(Duration::from_millis(20));
            }
            true
        };

        // `sh` backgrounds a `sleep`, prints its PID, then waits on it. `setsid` in the
        // child makes the wrapper the leader of a new process group, as `kill_all` expects.
        let mut cmd = Command::new("sh");
        cmd.arg("-c")
            .arg("sleep 60 & echo $! ; wait")
            .stdout(Stdio::piped())
            .stderr(Stdio::null());
        // SAFETY: the hook only calls `setsid`, which is async-signal-safe and therefore
        // permitted between `fork` and `exec`.
        unsafe {
            cmd.pre_exec(|| {
                if libc::setsid() == -1 {
                    return Err(std::io::Error::last_os_error());
                }
                Ok(())
            });
        }
        let mut child = cmd.spawn().expect("spawn wrapper");

        // The first line of output is the grandchild PID. The reader stays bound for the
        // whole test so the pipe remains open while `sh` is running.
        let mut stdout = BufReader::new(child.stdout.take().expect("stdout pipe"));
        let mut line = String::new();
        stdout.read_line(&mut line).expect("read grandchild PID");
        let grandchild_pid: u32 = line.trim().parse().expect("parse grandchild PID");
        let wrapper_pid = child.id();

        assert!(is_process_alive(wrapper_pid), "wrapper should be alive");
        assert!(is_process_alive(grandchild_pid), "grandchild should be alive");

        let reg = LspChildRegistry::new();
        reg.track(wrapper_pid);
        let killed = reg.kill_all();
        assert_eq!(killed, 1, "should report 1 group killed");

        // Reap the wrapper (our direct child) so it does not linger as a zombie.
        let _ = child.wait();
        assert!(
            wait_until_dead(wrapper_pid),
            "wrapper must be dead after killpg"
        );
        assert!(
            wait_until_dead(grandchild_pid),
            "grandchild must be dead after killpg (this was the npm-wrapper orphan bug)"
        );
    }
}