1use anyhow::Result;
2
/// Runs `cmd` with stdin closed and stdout/stderr captured, giving up after `timeout`.
///
/// Returns `Some(output)` when the process exits within `timeout`. Returns `None` if it could
/// not be spawned, if querying its status failed, or if `timeout` elapsed; in the last two cases
/// the child is killed and reaped before returning. The exit status is polled at most every 50 ms.
///
/// Both pipes are drained on background threads while waiting. Without that, a child writing
/// more than the OS pipe buffer would block on `write` forever and be reported as timed out.
pub fn run_with_timeout(
    mut cmd: std::process::Command,
    timeout: std::time::Duration,
) -> Option<std::process::Output> {
    use std::process::Stdio;
    use std::time::{Duration, Instant};

    const POLL_INTERVAL: Duration = Duration::from_millis(50);

    // Reads `pipe` to EOF on its own thread; a read error keeps whatever was captured so far.
    fn drain<R: std::io::Read + Send + 'static>(mut pipe: R) -> std::thread::JoinHandle<Vec<u8>> {
        std::thread::spawn(move || {
            let mut buf = Vec::new();
            let _ = pipe.read_to_end(&mut buf);
            buf
        })
    }

    // Kills and reaps the child so neither a running process nor a zombie outlives this call.
    fn abort(child: &mut std::process::Child) {
        let _ = child.kill();
        let _ = child.wait();
    }

    let mut child = cmd
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .ok()?;

    let stdout_reader = child.stdout.take().map(drain);
    let stderr_reader = child.stderr.take().map(drain);

    let start = Instant::now();
    let status = loop {
        match child.try_wait() {
            Ok(Some(status)) => break status,
            Ok(None) => {
                let elapsed = start.elapsed();
                if elapsed >= timeout {
                    // The reader threads are left detached; they finish once the pipes close.
                    abort(&mut child);
                    return None;
                }
                std::thread::sleep(POLL_INTERVAL.min(timeout - elapsed));
            }
            Err(_) => {
                abort(&mut child);
                return None;
            }
        }
    };

    let gather = |reader: Option<std::thread::JoinHandle<Vec<u8>>>| {
        reader.and_then(|handle| handle.join().ok()).unwrap_or_default()
    };
    Some(std::process::Output {
        status,
        stdout: gather(stdout_reader),
        stderr: gather(stderr_reader),
    })
}
45
46pub fn is_alive(pid: u32) -> bool {
48 #[cfg(unix)]
49 {
50 unsafe { libc::kill(pid as libc::pid_t, 0) == 0 }
54 }
55 #[cfg(windows)]
56 {
57 use windows_sys::Win32::Foundation::{CloseHandle, STILL_ACTIVE, WAIT_TIMEOUT};
58 use windows_sys::Win32::System::Threading::{
59 GetExitCodeProcess, OpenProcess, WaitForSingleObject, PROCESS_QUERY_LIMITED_INFORMATION,
60 };
61
62 unsafe {
66 let handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid);
67 if handle.is_null() {
68 return false;
69 }
70 let wait = WaitForSingleObject(handle, 0);
71 if wait == WAIT_TIMEOUT {
72 CloseHandle(handle);
73 return true;
74 }
75 let mut exit_code: u32 = 0;
76 GetExitCodeProcess(handle, &mut exit_code);
77 CloseHandle(handle);
78 exit_code == STILL_ACTIVE as u32
79 }
80 }
81}
82
83pub fn terminate_gracefully(pid: u32) -> Result<()> {
86 #[cfg(unix)]
87 {
88 let ret = unsafe { libc::kill(pid as libc::pid_t, libc::SIGTERM) };
91 if ret != 0 {
92 anyhow::bail!(
93 "Failed to send SIGTERM to PID {pid}: {}",
94 std::io::Error::last_os_error()
95 );
96 }
97 Ok(())
98 }
99 #[cfg(windows)]
100 {
101 force_kill(pid)
102 }
103}
104
105pub fn force_kill(pid: u32) -> Result<()> {
107 #[cfg(unix)]
108 {
109 let ret = unsafe { libc::kill(pid as libc::pid_t, libc::SIGKILL) };
112 if ret != 0 {
113 anyhow::bail!(
114 "Failed to send SIGKILL to PID {pid}: {}",
115 std::io::Error::last_os_error()
116 );
117 }
118 Ok(())
119 }
120 #[cfg(windows)]
121 {
122 use windows_sys::Win32::Foundation::CloseHandle;
123 use windows_sys::Win32::System::Threading::{
124 OpenProcess, TerminateProcess, PROCESS_TERMINATE,
125 };
126
127 unsafe {
130 let handle = OpenProcess(PROCESS_TERMINATE, 0, pid);
131 if handle.is_null() {
132 anyhow::bail!(
133 "Failed to open PID {pid} for termination: {}",
134 std::io::Error::last_os_error()
135 );
136 }
137 let ok = TerminateProcess(handle, 1);
138 CloseHandle(handle);
139 if ok == 0 {
140 anyhow::bail!(
141 "Failed to terminate PID {pid}: {}",
142 std::io::Error::last_os_error()
143 );
144 }
145 Ok(())
146 }
147 }
148}
149
150pub fn find_pids_by_name(name: &str) -> Vec<u32> {
153 let my_pid = std::process::id();
154 let mut pids = Vec::new();
155
156 #[cfg(unix)]
157 {
158 if let Ok(output) = std::process::Command::new("pgrep")
160 .arg("-x")
161 .arg(name)
162 .output()
163 {
164 collect_pids(&output.stdout, my_pid, &mut pids);
165 }
166
167 if let Ok(output) = std::process::Command::new("pgrep")
170 .arg("-f")
171 .arg(format!("/{name}(\\s|$)"))
172 .output()
173 {
174 collect_pids(&output.stdout, my_pid, &mut pids);
175 }
176
177 pids.sort_unstable();
178 pids.dedup();
179 }
180
181 #[cfg(windows)]
182 {
183 if let Ok(output) = std::process::Command::new("tasklist")
184 .args([
185 "/FI",
186 &format!("IMAGENAME eq {name}.exe"),
187 "/FO",
188 "CSV",
189 "/NH",
190 ])
191 .output()
192 {
193 let stdout = String::from_utf8_lossy(&output.stdout);
194 for line in stdout.lines() {
195 let parts: Vec<&str> = line.split(',').collect();
196 if parts.len() >= 2 {
197 let pid_str = parts[1].trim().trim_matches('"');
198 if let Ok(pid) = pid_str.parse::<u32>() {
199 if pid != my_pid {
200 pids.push(pid);
201 }
202 }
203 }
204 }
205 }
206 }
207
208 pids
209}
210
/// Parses one PID per line from `pgrep` output and appends every PID except `exclude_pid`
/// to `out`.
///
/// Lines are trimmed first. Lines that are not a valid `u32` are skipped, and invalid UTF-8 is
/// replaced lossily.
#[cfg(unix)]
fn collect_pids(stdout: &[u8], exclude_pid: u32, out: &mut Vec<u32>) {
    let decoded = String::from_utf8_lossy(stdout);
    let parsed = decoded
        .lines()
        .filter_map(|raw| raw.trim().parse::<u32>().ok())
        .filter(|&candidate| candidate != exclude_pid);
    out.extend(parsed);
}
222
223pub fn find_killable_pids(name: &str) -> Vec<u32> {
227 let all = find_pids_by_name(name);
228 let mcp_pids = find_mcp_server_pids(name);
229 all.into_iter().filter(|p| !mcp_pids.contains(p)).collect()
230}
231
232#[cfg(unix)]
233fn find_mcp_server_pids(name: &str) -> Vec<u32> {
234 find_pids_by_name(name)
235 .into_iter()
236 .filter(|&pid| is_mcp_stdio_process(pid))
237 .collect()
238}
239
/// Non-Unix fallback: MCP stdio detection is only implemented for Unix in this module, so no
/// PID is ever treated as protected here.
#[cfg(not(unix))]
fn find_mcp_server_pids(_name: &str) -> Vec<u32> {
    vec![]
}
244
/// Heuristically decides whether `pid` must be spared from bulk kills.
///
/// Returns `true` when any of these holds:
/// - the process's `ps` line or its parent's command line contains `Cursor`, `cursor` or `code`;
/// - its command is a bare `lean-ctx` invocation without any of the words `proxy`, `dashboard`,
///   `daemon`, `stop` or `hook`;
/// - its command contains `hook observe`, `hook rewrite` or `hook redirect`.
///
/// Returns `false` if `ps` cannot be run.
#[cfg(unix)]
fn is_mcp_stdio_process(pid: u32) -> bool {
    // Substrings treated as marking an editor host in a `ps` command line.
    fn mentions_editor(text: &str) -> bool {
        ["Cursor", "cursor", "code"]
            .iter()
            .any(|&needle| text.contains(needle))
    }

    let output = match std::process::Command::new("ps")
        .args(["-o", "ppid=,command=", "-p", &pid.to_string()])
        .output()
    {
        Ok(output) => output,
        Err(_) => return false,
    };

    let text = String::from_utf8_lossy(&output.stdout);
    let line = text.trim();
    if mentions_editor(line) {
        return true;
    }

    // `line` is expected to read "<ppid> <command...>".
    let fields: Vec<&str> = line.split_whitespace().collect();

    let parent_is_editor = fields
        .first()
        .and_then(|ppid_field| ppid_field.parse::<u32>().ok())
        .and_then(|ppid| {
            std::process::Command::new("ps")
                .args(["-o", "command=", "-p", &ppid.to_string()])
                .output()
                .ok()
        })
        .map_or(false, |parent| {
            mentions_editor(&String::from_utf8_lossy(&parent.stdout))
        });
    if parent_is_editor {
        return true;
    }

    let command = fields
        .get(1..)
        .map(|rest| rest.join(" "))
        .unwrap_or_default();

    let is_bare_server = (command.ends_with("/lean-ctx") || command == "lean-ctx")
        && !["proxy", "dashboard", "daemon", "stop", "hook"]
            .iter()
            .any(|&word| command.contains(word));
    let is_hook_subcommand = ["hook observe", "hook rewrite", "hook redirect"]
        .iter()
        .any(|&phrase| command.contains(phrase));

    is_bare_server || is_hook_subcommand
}
294
295pub fn kill_all_by_name(name: &str) -> usize {
298 let pids = find_killable_pids(name);
299 if pids.is_empty() {
300 return 0;
301 }
302
303 for &pid in &pids {
304 let _ = terminate_gracefully(pid);
305 }
306
307 std::thread::sleep(std::time::Duration::from_millis(500));
308
309 let mut killed = 0;
310 for &pid in &pids {
311 if is_alive(pid) {
312 let _ = force_kill(pid);
313 }
314 killed += 1;
315 }
316
317 std::thread::sleep(std::time::Duration::from_millis(200));
318
319 killed
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// The test runner's own PID must always be reported as alive.
    #[test]
    fn current_process_is_alive() {
        let me = std::process::id();
        assert!(is_alive(me));
    }

    /// A PID far outside any realistic range must not be reported as alive.
    #[test]
    fn bogus_pid_is_not_alive() {
        let bogus = u32::MAX - 42;
        assert!(!is_alive(bogus));
    }

    #[cfg(unix)]
    #[test]
    fn run_with_timeout_returns_output_for_fast_command() {
        use std::time::Duration;

        let mut echo = std::process::Command::new("echo");
        echo.arg("hello");
        let output =
            run_with_timeout(echo, Duration::from_secs(5)).expect("fast command should complete");
        assert!(output.status.success());
        let printed = String::from_utf8_lossy(&output.stdout);
        assert_eq!(printed.trim(), "hello");
    }

    #[cfg(unix)]
    #[test]
    fn run_with_timeout_kills_slow_command() {
        use std::time::{Duration, Instant};

        let mut sleeper = std::process::Command::new("sleep");
        sleeper.arg("30");
        let began = Instant::now();
        let outcome = run_with_timeout(sleeper, Duration::from_millis(300));
        assert!(outcome.is_none(), "slow command must time out");
        assert!(
            began.elapsed() < Duration::from_secs(5),
            "timeout must not wait for the full command"
        );
    }
}