use std::collections::HashSet;
use std::io;
use std::process::{Child, Command};
use std::sync::{Arc, Mutex};
/// Shared registry of spawned LSP child-process PIDs.
///
/// The PID set lives behind an `Arc`, so every clone of a registry observes and
/// mutates the same set: a PID tracked through one clone is visible (and can be
/// untracked or killed) through any other.
#[derive(Clone, Debug, Default)]
pub struct LspChildRegistry {
    /// PIDs of tracked children; the `Mutex` allows access from multiple threads.
    inner: Arc<Mutex<HashSet<u32>>>,
}
impl LspChildRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn track(&self, pid: u32) {
if let Ok(mut set) = self.inner.lock() {
set.insert(pid);
}
}
pub fn spawn_tracked(&self, command: &mut Command) -> io::Result<Child> {
let mut set = self
.inner
.lock()
.map_err(|_| io::Error::other("LSP child registry mutex poisoned"))?;
let child = command.spawn()?;
set.insert(child.id());
Ok(child)
}
pub fn untrack(&self, pid: u32) {
if let Ok(mut set) = self.inner.lock() {
set.remove(&pid);
}
}
pub fn pids(&self) -> Vec<u32> {
self.inner
.lock()
.map(|set| set.iter().copied().collect())
.unwrap_or_default()
}
#[cfg(unix)]
pub fn kill_all(&self) -> usize {
use std::os::raw::c_int;
let pids = self.pids();
let mut killed = 0;
for pid in pids {
unsafe {
let pgid = pid as libc::pid_t;
let rc = libc::killpg(pgid, 9 as c_int);
if rc == 0 {
killed += 1;
}
}
}
killed
}
#[cfg(not(unix))]
pub fn kill_all(&self) -> usize {
let pids = self.pids();
let mut killed = 0;
for pid in pids {
if std::process::Command::new("taskkill")
.args(["/F", "/T", "/PID", &pid.to_string()])
.status()
.is_ok()
{
killed += 1;
}
}
killed
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn track_untrack_pids_round_trip() {
        let reg = LspChildRegistry::new();
        reg.track(100);
        reg.track(200);
        let mut pids = reg.pids();
        pids.sort_unstable();
        assert_eq!(pids, vec![100, 200]);
        reg.untrack(100);
        assert_eq!(reg.pids(), vec![200]);
    }

    #[test]
    fn clones_share_state() {
        let a = LspChildRegistry::new();
        let b = a.clone();
        a.track(42);
        assert_eq!(b.pids(), vec![42]);
        b.untrack(42);
        assert!(a.pids().is_empty());
    }

    #[test]
    fn untracking_unknown_pid_is_safe() {
        let reg = LspChildRegistry::new();
        reg.untrack(999);
        assert!(reg.pids().is_empty());
    }

    #[test]
    fn kill_all_with_no_pids_returns_zero() {
        let reg = LspChildRegistry::new();
        assert_eq!(reg.kill_all(), 0);
    }

    #[test]
    fn spawn_tracked_records_pid_before_returning() {
        let reg = LspChildRegistry::new();
        let mut command = if cfg!(windows) {
            let mut command = std::process::Command::new("cmd");
            command.args(["/C", "exit", "0"]);
            command
        } else {
            let mut command = std::process::Command::new("sh");
            command.args(["-c", "exit 0"]);
            command
        };
        let mut child = reg.spawn_tracked(&mut command).expect("spawn tracked");
        let pid = child.id();
        assert!(reg.pids().contains(&pid));
        let _ = child.wait();
        reg.untrack(pid);
    }

    /// A wrapper shell started in its own session backgrounds a grandchild;
    /// killing the tracked wrapper PID must take the grandchild down too.
    #[cfg(unix)]
    #[test]
    fn kill_all_kills_process_group_not_just_wrapper_pid() {
        use std::os::unix::process::CommandExt;
        use std::process::Command;
        use std::thread;
        use std::time::{Duration, Instant};

        /// True if `ps` reports `pid` as existing and not a zombie.
        fn process_running(pid: u32) -> bool {
            let Ok(pid_i) = i32::try_from(pid) else {
                return false;
            };
            let output = Command::new("ps")
                .args(["-o", "stat=", "-p", &pid_i.to_string()])
                .output()
                .expect("ps");
            if !output.status.success() {
                return false;
            }
            let stat = String::from_utf8_lossy(&output.stdout);
            !stat.is_empty() && !stat.contains('Z')
        }

        /// Polls until `pid` is gone; returns false if `timeout` elapses first.
        fn wait_until_not_running(pid: u32, timeout: Duration) -> bool {
            let started = Instant::now();
            while started.elapsed() < timeout {
                if !process_running(pid) {
                    return true;
                }
                thread::sleep(Duration::from_millis(50));
            }
            false
        }

        let dir = tempfile::tempdir().expect("tempdir");
        let pid_file = dir.path().join("grandchild.pid");
        const PID_FILE_ENV: &str = "AFT_LSP_KILLALL_TEST_PID_FILE";

        let mut cmd = Command::new("sh");
        cmd.arg("-c")
            .arg("sleep 60 & echo $! > \"$AFT_LSP_KILLALL_TEST_PID_FILE\"; wait")
            .env(PID_FILE_ENV, &pid_file);
        // SAFETY: the closure runs in the forked child before exec and only calls
        // `setsid`, which is async-signal-safe. It makes the wrapper a session
        // and process-group leader, so its PID is also its group id.
        unsafe {
            cmd.pre_exec(|| {
                if libc::setsid() == -1 {
                    return Err(std::io::Error::last_os_error());
                }
                Ok(())
            });
        }
        let mut child = cmd.spawn().expect("spawn wrapper");
        let wrapper_pid = child.id();

        // `echo $! > file` creates the file before writing to it, so the file
        // can exist while still empty. Poll until it holds a complete,
        // newline-terminated PID instead of parsing the first thing we see.
        let started = Instant::now();
        let grandchild_pid: u32 = loop {
            let parsed = std::fs::read_to_string(&pid_file)
                .ok()
                .filter(|contents| contents.ends_with('\n'))
                .and_then(|contents| contents.trim().parse().ok());
            if let Some(pid) = parsed {
                break pid;
            }
            assert!(
                started.elapsed() < Duration::from_secs(5),
                "timed out waiting for grandchild pid file"
            );
            thread::sleep(Duration::from_millis(50));
        };

        assert!(process_running(wrapper_pid), "wrapper should be running");
        assert!(
            process_running(grandchild_pid),
            "grandchild should be running"
        );

        let reg = LspChildRegistry::new();
        reg.track(wrapper_pid);
        let killed = reg.kill_all();
        assert_eq!(killed, 1, "should report 1 group killed");
        let _ = child.wait();
        assert!(
            wait_until_not_running(wrapper_pid, Duration::from_secs(5)),
            "wrapper must stop after killpg"
        );
        assert!(
            wait_until_not_running(grandchild_pid, Duration::from_secs(5)),
            "grandchild must stop after killpg (this was the npm-wrapper orphan bug)"
        );
    }
}