Skip to main content

pitchfork_cli/
procs.rs

1use crate::Result;
2use miette::IntoDiagnostic;
3use once_cell::sync::Lazy;
4use std::sync::Mutex;
5use sysinfo::ProcessesToUpdate;
6#[cfg(unix)]
7use sysinfo::Signal;
8
9pub struct Procs {
10    system: Mutex<sysinfo::System>,
11}
12
13pub static PROCS: Lazy<Procs> = Lazy::new(Procs::new);
14
15impl Default for Procs {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl Procs {
22    pub fn new() -> Self {
23        let procs = Self {
24            system: Mutex::new(sysinfo::System::new()),
25        };
26        procs.refresh_processes();
27        procs
28    }
29
30    fn lock_system(&self) -> std::sync::MutexGuard<'_, sysinfo::System> {
31        self.system.lock().unwrap_or_else(|poisoned| {
32            warn!("System mutex was poisoned, recovering");
33            poisoned.into_inner()
34        })
35    }
36
37    pub fn title(&self, pid: u32) -> Option<String> {
38        self.lock_system()
39            .process(sysinfo::Pid::from_u32(pid))
40            .map(|p| p.name().to_string_lossy().to_string())
41    }
42
43    pub fn is_running(&self, pid: u32) -> bool {
44        self.lock_system()
45            .process(sysinfo::Pid::from_u32(pid))
46            .is_some()
47    }
48
49    /// Walk the /proc tree to find all descendant PIDs.
50    /// Kept for diagnostics/status display; no longer used in the kill path.
51    #[allow(dead_code)]
52    pub fn all_children(&self, pid: u32) -> Vec<u32> {
53        let system = self.lock_system();
54        let all = system.processes();
55        let mut children = vec![];
56        for (child_pid, process) in all {
57            let mut process = process;
58            while let Some(parent) = process.parent() {
59                if parent == sysinfo::Pid::from_u32(pid) {
60                    children.push(child_pid.as_u32());
61                    break;
62                }
63                match system.process(parent) {
64                    Some(p) => process = p,
65                    None => break,
66                }
67            }
68        }
69        children
70    }
71
72    pub async fn kill_process_group_async(&self, pid: u32) -> Result<bool> {
73        let result = tokio::task::spawn_blocking(move || PROCS.kill_process_group(pid))
74            .await
75            .into_diagnostic()?;
76        Ok(result)
77    }
78
79    /// Kill an entire process group with graceful shutdown strategy:
80    /// 1. Send SIGTERM to the process group (-pgid) and wait up to ~3s
81    /// 2. If any processes remain, send SIGKILL to the group
82    ///
83    /// Since daemons are spawned with setsid(), the daemon PID == PGID,
84    /// so this atomically signals all descendant processes.
85    #[cfg(unix)]
86    fn kill_process_group(&self, pid: u32) -> bool {
87        let pgid = pid as i32;
88
89        debug!("killing process group {pgid}");
90
91        // Send SIGTERM to the entire process group.
92        // killpg sends to all processes in the group atomically.
93        // We intentionally skip the zombie check here because the leader may be
94        // a zombie while children in the group are still running.
95        let ret = unsafe { libc::killpg(pgid, libc::SIGTERM) };
96        if ret == -1 {
97            let err = std::io::Error::last_os_error();
98            if err.raw_os_error() == Some(libc::ESRCH) {
99                debug!("process group {pgid} no longer exists");
100                return false;
101            }
102            warn!("failed to send SIGTERM to process group {pgid}: {err}");
103        }
104
105        // Wait for graceful shutdown: fast initial check then slower polling.
106        let fast_checks = std::iter::repeat_n(std::time::Duration::from_millis(10), 10);
107        let slow_checks = std::iter::repeat_n(std::time::Duration::from_millis(50), 58);
108        let mut elapsed_ms = 0u64;
109
110        for sleep_duration in fast_checks.chain(slow_checks) {
111            std::thread::sleep(sleep_duration);
112            self.refresh_pids(&[pid]);
113            elapsed_ms += sleep_duration.as_millis() as u64;
114            if self.is_terminated_or_zombie(sysinfo::Pid::from_u32(pid)) {
115                debug!("process group {pgid} terminated after SIGTERM ({elapsed_ms} ms)",);
116                return true;
117            }
118        }
119
120        // SIGKILL the entire process group as last resort
121        warn!("process group {pgid} did not respond to SIGTERM after ~3s, sending SIGKILL");
122        unsafe {
123            libc::killpg(pgid, libc::SIGKILL);
124        }
125
126        // Brief wait for SIGKILL to take effect
127        std::thread::sleep(std::time::Duration::from_millis(100));
128        true
129    }
130
131    #[cfg(not(unix))]
132    fn kill_process_group(&self, pid: u32) -> bool {
133        // On non-unix platforms, fall back to single-process kill
134        self.kill(pid)
135    }
136
137    pub async fn kill_async(&self, pid: u32) -> Result<bool> {
138        let result = tokio::task::spawn_blocking(move || PROCS.kill(pid))
139            .await
140            .into_diagnostic()?;
141        Ok(result)
142    }
143
144    /// Kill a process with graceful shutdown strategy:
145    /// 1. Send SIGTERM and wait up to ~3s (10ms intervals for first 100ms, then 50ms intervals)
146    /// 2. If still running, send SIGKILL to force termination
147    ///
148    /// This ensures fast-exiting processes don't wait unnecessarily,
149    /// while stubborn processes eventually get forcefully terminated.
150    fn kill(&self, pid: u32) -> bool {
151        let sysinfo_pid = sysinfo::Pid::from_u32(pid);
152
153        // Check if process exists or is a zombie (already terminated but not reaped)
154        if self.is_terminated_or_zombie(sysinfo_pid) {
155            return false;
156        }
157
158        debug!("killing process {pid}");
159
160        #[cfg(windows)]
161        {
162            if let Some(process) = self.lock_system().process(sysinfo_pid) {
163                process.kill();
164                process.wait();
165            }
166            return true;
167        }
168
169        #[cfg(unix)]
170        {
171            // Send SIGTERM for graceful shutdown
172            if let Some(process) = self.lock_system().process(sysinfo_pid) {
173                debug!("sending SIGTERM to process {pid}");
174                process.kill_with(Signal::Term);
175            }
176
177            // Fast check: 10ms intervals for first 100ms (for processes that exit immediately)
178            for i in 0..10 {
179                std::thread::sleep(std::time::Duration::from_millis(10));
180                self.refresh_pids(&[pid]);
181                if self.is_terminated_or_zombie(sysinfo_pid) {
182                    debug!(
183                        "process {pid} terminated after SIGTERM ({} ms)",
184                        (i + 1) * 10
185                    );
186                    return true;
187                }
188            }
189
190            // Slower check: 50ms intervals for up to ~3 more seconds (100ms + 2900ms = 3000ms total)
191            for i in 0..58 {
192                std::thread::sleep(std::time::Duration::from_millis(50));
193                self.refresh_pids(&[pid]);
194                if self.is_terminated_or_zombie(sysinfo_pid) {
195                    debug!(
196                        "process {pid} terminated after SIGTERM ({} ms)",
197                        100 + (i + 1) * 50
198                    );
199                    return true;
200                }
201            }
202
203            // SIGKILL as last resort after ~3s
204            if let Some(process) = self.lock_system().process(sysinfo_pid) {
205                warn!("process {pid} did not respond to SIGTERM after ~3s, sending SIGKILL");
206                process.kill_with(Signal::Kill);
207                process.wait();
208            }
209
210            true
211        }
212    }
213
214    /// Check if a process is terminated or is a zombie.
215    /// On Linux, zombie processes still have /proc/[pid] entries but are effectively dead.
216    /// This prevents unnecessary signal escalation for processes that have already exited.
217    fn is_terminated_or_zombie(&self, sysinfo_pid: sysinfo::Pid) -> bool {
218        let system = self.lock_system();
219        match system.process(sysinfo_pid) {
220            None => true,
221            Some(process) => {
222                #[cfg(unix)]
223                {
224                    matches!(process.status(), sysinfo::ProcessStatus::Zombie)
225                }
226                #[cfg(not(unix))]
227                {
228                    let _ = process;
229                    false
230                }
231            }
232        }
233    }
234
235    pub(crate) fn refresh_processes(&self) {
236        self.lock_system()
237            .refresh_processes(ProcessesToUpdate::All, true);
238    }
239
240    /// Refresh only specific PIDs instead of all processes.
241    /// More efficient when you only need to check a small set of known PIDs.
242    pub(crate) fn refresh_pids(&self, pids: &[u32]) {
243        let sysinfo_pids: Vec<sysinfo::Pid> =
244            pids.iter().map(|p| sysinfo::Pid::from_u32(*p)).collect();
245        self.lock_system()
246            .refresh_processes(ProcessesToUpdate::Some(&sysinfo_pids), true);
247    }
248
249    /// Get process stats (cpu%, memory bytes, uptime secs, disk I/O) for a given PID
250    pub fn get_stats(&self, pid: u32) -> Option<ProcessStats> {
251        let system = self.lock_system();
252        system.process(sysinfo::Pid::from_u32(pid)).map(|p| {
253            let now = std::time::SystemTime::now()
254                .duration_since(std::time::UNIX_EPOCH)
255                .map(|d| d.as_secs())
256                .unwrap_or(0);
257            let disk = p.disk_usage();
258            ProcessStats {
259                cpu_percent: p.cpu_usage(),
260                memory_bytes: p.memory(),
261                uptime_secs: now.saturating_sub(p.start_time()),
262                disk_read_bytes: disk.read_bytes,
263                disk_write_bytes: disk.written_bytes,
264            }
265        })
266    }
267
268    /// Get extended process information for a given PID
269    pub fn get_extended_stats(&self, pid: u32) -> Option<ExtendedProcessStats> {
270        let system = self.lock_system();
271        system.process(sysinfo::Pid::from_u32(pid)).map(|p| {
272            let now = std::time::SystemTime::now()
273                .duration_since(std::time::UNIX_EPOCH)
274                .map(|d| d.as_secs())
275                .unwrap_or(0);
276            let disk = p.disk_usage();
277
278            ExtendedProcessStats {
279                name: p.name().to_string_lossy().to_string(),
280                exe_path: p.exe().map(|e| e.to_string_lossy().to_string()),
281                cwd: p.cwd().map(|c| c.to_string_lossy().to_string()),
282                environ: p
283                    .environ()
284                    .iter()
285                    .take(20)
286                    .map(|s| s.to_string_lossy().to_string())
287                    .collect(),
288                status: format!("{:?}", p.status()),
289                cpu_percent: p.cpu_usage(),
290                memory_bytes: p.memory(),
291                virtual_memory_bytes: p.virtual_memory(),
292                uptime_secs: now.saturating_sub(p.start_time()),
293                start_time: p.start_time(),
294                disk_read_bytes: disk.read_bytes,
295                disk_write_bytes: disk.written_bytes,
296                parent_pid: p.parent().map(|pp| pp.as_u32()),
297                thread_count: p.tasks().map(|t| t.len()).unwrap_or(0),
298                user_id: p.user_id().map(|u| u.to_string()),
299            }
300        })
301    }
302}
303
/// Lightweight per-process resource snapshot.
#[derive(Debug, Clone, Copy)]
pub struct ProcessStats {
    /// CPU usage as reported by sysinfo, in percent.
    pub cpu_percent: f32,
    /// Memory usage in bytes.
    pub memory_bytes: u64,
    /// Seconds elapsed since the process started.
    pub uptime_secs: u64,
    /// Bytes read from disk. NOTE(review): sysinfo counts this since the previous refresh,
    /// so the "/s" display label only holds if refreshes are ~1 s apart — confirm.
    pub disk_read_bytes: u64,
    /// Bytes written to disk (same caveat as `disk_read_bytes`).
    pub disk_write_bytes: u64,
}
312
313impl ProcessStats {
314    pub fn memory_display(&self) -> String {
315        format_bytes(self.memory_bytes)
316    }
317
318    pub fn cpu_display(&self) -> String {
319        format!("{:.1}%", self.cpu_percent)
320    }
321
322    pub fn uptime_display(&self) -> String {
323        format_duration(self.uptime_secs)
324    }
325
326    pub fn disk_read_display(&self) -> String {
327        format_bytes_per_sec(self.disk_read_bytes)
328    }
329
330    pub fn disk_write_display(&self) -> String {
331        format_bytes_per_sec(self.disk_write_bytes)
332    }
333}
334
/// Detailed per-process information for inspection views.
#[derive(Debug, Clone)]
pub struct ExtendedProcessStats {
    /// Process name (lossily converted to UTF-8).
    pub name: String,
    /// Executable path, if known.
    pub exe_path: Option<String>,
    /// Working directory, if known.
    pub cwd: Option<String>,
    /// A truncated list of `KEY=VALUE` environment entries.
    pub environ: Vec<String>,
    /// Debug rendering of the process status.
    pub status: String,
    /// CPU usage in percent.
    pub cpu_percent: f32,
    /// Memory usage in bytes.
    pub memory_bytes: u64,
    /// Virtual memory size in bytes.
    pub virtual_memory_bytes: u64,
    /// Seconds elapsed since the process started.
    pub uptime_secs: u64,
    /// Start time in seconds since the Unix epoch.
    pub start_time: u64,
    /// Bytes read from disk (see `ProcessStats::disk_read_bytes` for the rate caveat).
    pub disk_read_bytes: u64,
    /// Bytes written to disk.
    pub disk_write_bytes: u64,
    /// Parent PID, if any.
    pub parent_pid: Option<u32>,
    /// Number of threads/tasks reported, 0 when unavailable.
    pub thread_count: usize,
    /// Owning user id rendered as a string, if known.
    pub user_id: Option<String>,
}
354
355impl ExtendedProcessStats {
356    pub fn memory_display(&self) -> String {
357        format_bytes(self.memory_bytes)
358    }
359
360    pub fn virtual_memory_display(&self) -> String {
361        format_bytes(self.virtual_memory_bytes)
362    }
363
364    pub fn cpu_display(&self) -> String {
365        format!("{:.1}%", self.cpu_percent)
366    }
367
368    pub fn uptime_display(&self) -> String {
369        format_duration(self.uptime_secs)
370    }
371
372    pub fn start_time_display(&self) -> String {
373        use std::time::{Duration, UNIX_EPOCH};
374        let datetime = UNIX_EPOCH + Duration::from_secs(self.start_time);
375        chrono::DateTime::<chrono::Local>::from(datetime)
376            .format("%Y-%m-%d %H:%M:%S")
377            .to_string()
378    }
379
380    pub fn disk_read_display(&self) -> String {
381        format_bytes_per_sec(self.disk_read_bytes)
382    }
383
384    pub fn disk_write_display(&self) -> String {
385        format_bytes_per_sec(self.disk_write_bytes)
386    }
387}
388
/// Render a byte count with binary (1024-based) units.
///
/// Values below 1 KiB are shown as whole bytes (`"512B"`). Larger values get one decimal
/// place in KB, MB, or GB. GB is the largest unit, so terabyte-scale values stay in GB.
fn format_bytes(bytes: u64) -> String {
    const KIB: u64 = 1024;
    const MIB: u64 = KIB * 1024;
    const GIB: u64 = MIB * 1024;
    match bytes {
        b if b < KIB => format!("{b}B"),
        b if b < MIB => format!("{:.1}KB", b as f64 / KIB as f64),
        b if b < GIB => format!("{:.1}MB", b as f64 / MIB as f64),
        b => format!("{:.1}GB", b as f64 / GIB as f64),
    }
}
400
/// Render a duration in seconds using its two most significant units.
///
/// Examples: `"45s"`, `"3m 7s"`, `"2h 15m"`, `"4d 6h"`. Lower units are truncated,
/// not rounded.
fn format_duration(secs: u64) -> String {
    const MINUTE: u64 = 60;
    const HOUR: u64 = 60 * MINUTE;
    const DAY: u64 = 24 * HOUR;
    match secs {
        s if s < MINUTE => format!("{s}s"),
        s if s < HOUR => format!("{}m {}s", s / MINUTE, s % MINUTE),
        s if s < DAY => format!("{}h {}m", s / HOUR, (s % HOUR) / MINUTE),
        s => format!("{}d {}h", s / DAY, (s % DAY) / HOUR),
    }
}
416
/// Render a byte count as a rate with binary (1024-based) units and a `/s` suffix.
///
/// Values below 1 KiB are shown as whole bytes (`"512B/s"`). Larger values get one decimal
/// place in KB, MB, or GB.
fn format_bytes_per_sec(bytes: u64) -> String {
    // Largest unit first, so the first scale that fits wins.
    const SCALES: [(u64, &str); 3] = [(1 << 30, "GB"), (1 << 20, "MB"), (1 << 10, "KB")];
    SCALES
        .iter()
        .find(|(scale, _)| bytes >= *scale)
        .map(|&(scale, unit)| format!("{:.1}{unit}/s", bytes as f64 / scale as f64))
        .unwrap_or_else(|| format!("{bytes}B/s"))
}