Skip to main content

pitchfork_cli/
procs.rs

1use crate::Result;
2use crate::settings::settings;
3use miette::IntoDiagnostic;
4use once_cell::sync::Lazy;
5use std::sync::Mutex;
6use sysinfo::ProcessesToUpdate;
7
/// Process-table cache backed by `sysinfo`.
///
/// The table is guarded by a mutex so a single instance can be shared (see the
/// `PROCS` global) and queried or refreshed from several threads.
pub struct Procs {
    // Snapshot of the OS process list. It changes only when explicitly refreshed
    // (`refresh_processes` / `refresh_pids`); plain lookups such as `title` and
    // `is_running` read whatever the last refresh captured.
    system: Mutex<sysinfo::System>,
}
11
12pub static PROCS: Lazy<Procs> = Lazy::new(Procs::new);
13
14impl Default for Procs {
15    fn default() -> Self {
16        Self::new()
17    }
18}
19
20impl Procs {
21    pub fn new() -> Self {
22        let procs = Self {
23            system: Mutex::new(sysinfo::System::new()),
24        };
25        procs.refresh_processes();
26        procs
27    }
28
29    fn lock_system(&self) -> std::sync::MutexGuard<'_, sysinfo::System> {
30        self.system.lock().unwrap_or_else(|poisoned| {
31            warn!("System mutex was poisoned, recovering");
32            poisoned.into_inner()
33        })
34    }
35
36    pub fn title(&self, pid: u32) -> Option<String> {
37        self.lock_system()
38            .process(sysinfo::Pid::from_u32(pid))
39            .map(|p| p.name().to_string_lossy().to_string())
40    }
41
42    pub fn is_running(&self, pid: u32) -> bool {
43        self.lock_system()
44            .process(sysinfo::Pid::from_u32(pid))
45            .is_some()
46    }
47
48    /// Walk the /proc tree to find all descendant PIDs.
49    /// Kept for diagnostics/status display; no longer used in the kill path.
50    #[allow(dead_code)]
51    pub fn all_children(&self, pid: u32) -> Vec<u32> {
52        let system = self.lock_system();
53        let all = system.processes();
54        let mut children = vec![];
55        for (child_pid, process) in all {
56            let mut process = process;
57            while let Some(parent) = process.parent() {
58                if parent == sysinfo::Pid::from_u32(pid) {
59                    children.push(child_pid.as_u32());
60                    break;
61                }
62                match system.process(parent) {
63                    Some(p) => process = p,
64                    None => break,
65                }
66            }
67        }
68        children
69    }
70
71    pub async fn kill_process_group_async(&self, pid: u32) -> Result<bool> {
72        tokio::task::spawn_blocking(move || PROCS.kill_process_group(pid))
73            .await
74            .into_diagnostic()?
75    }
76
77    /// Kill an entire process group with graceful shutdown strategy:
78    /// 1. Send SIGTERM to the process group (-pgid) and wait up to ~3s
79    /// 2. If any processes remain, send SIGKILL to the group
80    ///
81    /// Since daemons are spawned with setsid(), the daemon PID == PGID,
82    /// so this atomically signals all descendant processes.
83    ///
84    /// Returns `Err` if the signal could not be sent (e.g. permission denied).
85    #[cfg(unix)]
86    fn kill_process_group(&self, pid: u32) -> Result<bool> {
87        let pgid = pid as i32;
88
89        debug!("killing process group {pgid}");
90
91        // Send SIGTERM to the entire process group.
92        // killpg sends to all processes in the group atomically.
93        // We intentionally skip the zombie check here because the leader may be
94        // a zombie while children in the group are still running.
95        let ret = unsafe { libc::killpg(pgid, libc::SIGTERM) };
96        if ret == -1 {
97            let err = std::io::Error::last_os_error();
98            if err.raw_os_error() == Some(libc::ESRCH) {
99                debug!("process group {pgid} no longer exists");
100                return Ok(false);
101            }
102            if err.raw_os_error() == Some(libc::EPERM) {
103                return Err(miette::miette!(
104                    "failed to send SIGTERM to process group {pgid}: permission denied"
105                ));
106            }
107            warn!("failed to send SIGTERM to process group {pgid}: {err}");
108        }
109
110        // Wait for graceful shutdown: fast initial check then slower polling.
111        // Use the configurable stop_timeout setting to govern total wait time.
112        let stop_timeout = settings().supervisor_stop_timeout();
113        let fast_ms = 10u64;
114        let slow_ms = 50u64;
115        let total_ms = stop_timeout.as_millis().max(1) as u64;
116        let fast_count = ((total_ms / fast_ms) as usize).min(10);
117        let fast_total_ms = fast_ms * fast_count as u64;
118        let remaining_ms = total_ms.saturating_sub(fast_total_ms);
119        let slow_count = (remaining_ms / slow_ms) as usize;
120
121        let fast_checks =
122            std::iter::repeat_n(std::time::Duration::from_millis(fast_ms), fast_count);
123        let slow_checks =
124            std::iter::repeat_n(std::time::Duration::from_millis(slow_ms), slow_count);
125        let mut elapsed_ms = 0u64;
126
127        for sleep_duration in fast_checks.chain(slow_checks) {
128            std::thread::sleep(sleep_duration);
129            self.refresh_pids(&[pid]);
130            elapsed_ms += sleep_duration.as_millis() as u64;
131            if self.is_terminated_or_zombie(sysinfo::Pid::from_u32(pid)) {
132                debug!("process group {pgid} terminated after SIGTERM ({elapsed_ms} ms)",);
133                return Ok(true);
134            }
135        }
136
137        // SIGKILL the entire process group as last resort
138        warn!(
139            "process group {pgid} did not respond to SIGTERM after {}ms, sending SIGKILL",
140            stop_timeout.as_millis()
141        );
142        let ret = unsafe { libc::killpg(pgid, libc::SIGKILL) };
143        if ret == -1 {
144            let err = std::io::Error::last_os_error();
145            if err.raw_os_error() != Some(libc::ESRCH) {
146                warn!("failed to send SIGKILL to process group {pgid}: {err}");
147            }
148        }
149
150        // Brief wait for SIGKILL to take effect
151        std::thread::sleep(std::time::Duration::from_millis(100));
152        Ok(true)
153    }
154
155    #[cfg(not(unix))]
156    fn kill_process_group(&self, pid: u32) -> Result<bool> {
157        // On non-unix platforms, fall back to single-process kill
158        self.kill(pid)
159    }
160
161    pub async fn kill_async(&self, pid: u32) -> Result<bool> {
162        tokio::task::spawn_blocking(move || PROCS.kill(pid))
163            .await
164            .into_diagnostic()?
165    }
166
167    /// Kill a process with graceful shutdown strategy:
168    /// 1. Send SIGTERM and wait up to ~3s (10ms intervals for first 100ms, then 50ms intervals)
169    /// 2. If still running, send SIGKILL to force termination
170    ///
171    /// This ensures fast-exiting processes don't wait unnecessarily,
172    /// while stubborn processes eventually get forcefully terminated.
173    ///
174    /// Returns `Err` if the signal could not be sent (e.g. permission denied
175    /// when targeting a process owned by another user/root).
176    fn kill(&self, pid: u32) -> Result<bool> {
177        let sysinfo_pid = sysinfo::Pid::from_u32(pid);
178
179        // Check if process exists or is a zombie (already terminated but not reaped)
180        if self.is_terminated_or_zombie(sysinfo_pid) {
181            return Ok(false);
182        }
183
184        debug!("killing process {pid}");
185
186        #[cfg(windows)]
187        {
188            if let Some(process) = self.lock_system().process(sysinfo_pid) {
189                process.kill();
190                process.wait();
191            }
192            return Ok(true);
193        }
194
195        #[cfg(unix)]
196        {
197            // Send SIGTERM for graceful shutdown using libc::kill directly
198            // so we can distinguish EPERM (permission denied) from ESRCH
199            // (process already gone — possible in a narrow race window).
200            debug!("sending SIGTERM to process {pid}");
201            let ret = unsafe { libc::kill(pid as i32, libc::SIGTERM) };
202            if ret == -1 {
203                let err = std::io::Error::last_os_error();
204                if err.raw_os_error() == Some(libc::ESRCH) {
205                    debug!("process {pid} no longer exists");
206                    return Ok(false);
207                }
208                if err.raw_os_error() == Some(libc::EPERM) {
209                    return Err(miette::miette!(
210                        "failed to send SIGTERM to process {pid}: permission denied"
211                    ));
212                }
213                return Err(miette::miette!(
214                    "failed to send SIGTERM to process {pid}: {err}"
215                ));
216            }
217
218            // Fast check: 10ms intervals, then slower 50ms polling for stop_timeout.
219            // Cap fast_count so short timeouts (e.g. 50ms) are honoured correctly.
220            let stop_timeout = settings().supervisor_stop_timeout();
221            let fast_ms = 10u64;
222            let slow_ms = 50u64;
223            let total_ms = stop_timeout.as_millis().max(1) as u64;
224            let fast_count = ((total_ms / fast_ms) as usize).min(10);
225            let fast_total_ms = fast_ms * fast_count as u64;
226            let remaining_ms = total_ms.saturating_sub(fast_total_ms);
227            let slow_count = (remaining_ms / slow_ms) as usize;
228
229            for i in 0..fast_count {
230                std::thread::sleep(std::time::Duration::from_millis(fast_ms));
231                self.refresh_pids(&[pid]);
232                if self.is_terminated_or_zombie(sysinfo_pid) {
233                    debug!(
234                        "process {pid} terminated after SIGTERM ({} ms)",
235                        (i + 1) * fast_ms as usize
236                    );
237                    return Ok(true);
238                }
239            }
240
241            // Slower check: 50ms intervals for the remainder of stop_timeout
242            for i in 0..slow_count {
243                std::thread::sleep(std::time::Duration::from_millis(slow_ms));
244                self.refresh_pids(&[pid]);
245                if self.is_terminated_or_zombie(sysinfo_pid) {
246                    debug!(
247                        "process {pid} terminated after SIGTERM ({} ms)",
248                        fast_total_ms + (i + 1) as u64 * slow_ms
249                    );
250                    return Ok(true);
251                }
252            }
253
254            // SIGKILL as last resort after stop_timeout
255            warn!(
256                "process {pid} did not respond to SIGTERM after {}ms, sending SIGKILL",
257                stop_timeout.as_millis()
258            );
259            let ret = unsafe { libc::kill(pid as i32, libc::SIGKILL) };
260            if ret == -1 {
261                let err = std::io::Error::last_os_error();
262                if err.raw_os_error() != Some(libc::ESRCH) {
263                    warn!("failed to send SIGKILL to process {pid}: {err}");
264                }
265            }
266
267            // Brief wait for SIGKILL to take effect
268            std::thread::sleep(std::time::Duration::from_millis(100));
269            Ok(true)
270        }
271    }
272
273    /// Check if a process is terminated or is a zombie.
274    /// On Linux, zombie processes still have /proc/[pid] entries but are effectively dead.
275    /// This prevents unnecessary signal escalation for processes that have already exited.
276    fn is_terminated_or_zombie(&self, sysinfo_pid: sysinfo::Pid) -> bool {
277        let system = self.lock_system();
278        match system.process(sysinfo_pid) {
279            None => true,
280            Some(process) => {
281                #[cfg(unix)]
282                {
283                    matches!(process.status(), sysinfo::ProcessStatus::Zombie)
284                }
285                #[cfg(not(unix))]
286                {
287                    let _ = process;
288                    false
289                }
290            }
291        }
292    }
293
294    pub(crate) fn refresh_processes(&self) {
295        self.lock_system()
296            .refresh_processes(ProcessesToUpdate::All, true);
297    }
298
299    /// Refresh only specific PIDs instead of all processes.
300    /// More efficient when you only need to check a small set of known PIDs.
301    pub(crate) fn refresh_pids(&self, pids: &[u32]) {
302        let sysinfo_pids: Vec<sysinfo::Pid> =
303            pids.iter().map(|p| sysinfo::Pid::from_u32(*p)).collect();
304        self.lock_system()
305            .refresh_processes(ProcessesToUpdate::Some(&sysinfo_pids), true);
306    }
307
308    /// Get aggregated stats for multiple process groups in a single pass.
309    ///
310    /// Builds the parent→children map once (O(N)) and then BFS-es from each
311    /// root PID (O(D_i) per daemon). Total cost is O(N + ΣD_i) instead of
312    /// O(D × N) when calling `get_group_stats` in a loop.
313    pub fn get_batch_group_stats(&self, pids: &[u32]) -> Vec<(u32, Option<ProcessStats>)> {
314        let system = self.lock_system();
315        let processes = system.processes();
316
317        let now = std::time::SystemTime::now()
318            .duration_since(std::time::UNIX_EPOCH)
319            .map(|d| d.as_secs())
320            .unwrap_or(0);
321
322        // Build parent → children map once for all daemons
323        let mut children_map: std::collections::HashMap<sysinfo::Pid, Vec<sysinfo::Pid>> =
324            std::collections::HashMap::new();
325        for (child_pid, child) in processes {
326            if let Some(ppid) = child.parent() {
327                children_map.entry(ppid).or_default().push(*child_pid);
328            }
329        }
330
331        pids.iter()
332            .map(|&pid| {
333                let root_pid = sysinfo::Pid::from_u32(pid);
334                let Some(root) = processes.get(&root_pid) else {
335                    return (pid, None);
336                };
337
338                let root_disk = root.disk_usage();
339                let mut stats = ProcessStats {
340                    cpu_percent: root.cpu_usage(),
341                    memory_bytes: root.memory(),
342                    uptime_secs: now.saturating_sub(root.start_time()),
343                    disk_read_bytes: root_disk.read_bytes,
344                    disk_write_bytes: root_disk.written_bytes,
345                };
346
347                // BFS from root_pid to find all descendants
348                let mut queue = std::collections::VecDeque::new();
349                if let Some(direct_children) = children_map.get(&root_pid) {
350                    queue.extend(direct_children);
351                }
352                while let Some(child_pid) = queue.pop_front() {
353                    if let Some(child) = processes.get(&child_pid) {
354                        let disk = child.disk_usage();
355                        stats.cpu_percent += child.cpu_usage();
356                        stats.memory_bytes += child.memory();
357                        stats.disk_read_bytes += disk.read_bytes;
358                        stats.disk_write_bytes += disk.written_bytes;
359                    }
360                    if let Some(grandchildren) = children_map.get(&child_pid) {
361                        queue.extend(grandchildren);
362                    }
363                }
364
365                (pid, Some(stats))
366            })
367            .collect()
368    }
369
370    /// Get process stats (cpu%, memory bytes, uptime secs, disk I/O) for a given PID
371    pub fn get_stats(&self, pid: u32) -> Option<ProcessStats> {
372        let system = self.lock_system();
373        system.process(sysinfo::Pid::from_u32(pid)).map(|p| {
374            let now = std::time::SystemTime::now()
375                .duration_since(std::time::UNIX_EPOCH)
376                .map(|d| d.as_secs())
377                .unwrap_or(0);
378            let disk = p.disk_usage();
379            ProcessStats {
380                cpu_percent: p.cpu_usage(),
381                memory_bytes: p.memory(),
382                uptime_secs: now.saturating_sub(p.start_time()),
383                disk_read_bytes: disk.read_bytes,
384                disk_write_bytes: disk.written_bytes,
385            }
386        })
387    }
388
389    /// Get extended process information for a given PID
390    pub fn get_extended_stats(&self, pid: u32) -> Option<ExtendedProcessStats> {
391        let system = self.lock_system();
392        system.process(sysinfo::Pid::from_u32(pid)).map(|p| {
393            let now = std::time::SystemTime::now()
394                .duration_since(std::time::UNIX_EPOCH)
395                .map(|d| d.as_secs())
396                .unwrap_or(0);
397            let disk = p.disk_usage();
398
399            ExtendedProcessStats {
400                name: p.name().to_string_lossy().to_string(),
401                exe_path: p.exe().map(|e| e.to_string_lossy().to_string()),
402                cwd: p.cwd().map(|c| c.to_string_lossy().to_string()),
403                environ: p
404                    .environ()
405                    .iter()
406                    .take(20)
407                    .map(|s| s.to_string_lossy().to_string())
408                    .collect(),
409                status: format!("{:?}", p.status()),
410                cpu_percent: p.cpu_usage(),
411                memory_bytes: p.memory(),
412                virtual_memory_bytes: p.virtual_memory(),
413                uptime_secs: now.saturating_sub(p.start_time()),
414                start_time: p.start_time(),
415                disk_read_bytes: disk.read_bytes,
416                disk_write_bytes: disk.written_bytes,
417                parent_pid: p.parent().map(|pp| pp.as_u32()),
418                thread_count: p.tasks().map(|t| t.len()).unwrap_or(0),
419                user_id: p.user_id().map(|u| u.to_string()),
420            }
421        })
422    }
423}
424
425#[derive(Debug, Clone, Copy)]
426pub struct ProcessStats {
427    pub cpu_percent: f32,
428    pub memory_bytes: u64,
429    pub uptime_secs: u64,
430    pub disk_read_bytes: u64,
431    pub disk_write_bytes: u64,
432}
433
434impl ProcessStats {
435    pub fn memory_display(&self) -> String {
436        format_bytes(self.memory_bytes)
437    }
438
439    pub fn cpu_display(&self) -> String {
440        format!("{:.1}%", self.cpu_percent)
441    }
442
443    pub fn uptime_display(&self) -> String {
444        format_duration(self.uptime_secs)
445    }
446
447    pub fn disk_read_display(&self) -> String {
448        format_bytes_per_sec(self.disk_read_bytes)
449    }
450
451    pub fn disk_write_display(&self) -> String {
452        format_bytes_per_sec(self.disk_write_bytes)
453    }
454}
455
456/// Extended process stats with more detailed information
457#[derive(Debug, Clone)]
458pub struct ExtendedProcessStats {
459    pub name: String,
460    pub exe_path: Option<String>,
461    pub cwd: Option<String>,
462    pub environ: Vec<String>,
463    pub status: String,
464    pub cpu_percent: f32,
465    pub memory_bytes: u64,
466    pub virtual_memory_bytes: u64,
467    pub uptime_secs: u64,
468    pub start_time: u64,
469    pub disk_read_bytes: u64,
470    pub disk_write_bytes: u64,
471    pub parent_pid: Option<u32>,
472    pub thread_count: usize,
473    pub user_id: Option<String>,
474}
475
476impl ExtendedProcessStats {
477    pub fn memory_display(&self) -> String {
478        format_bytes(self.memory_bytes)
479    }
480
481    pub fn virtual_memory_display(&self) -> String {
482        format_bytes(self.virtual_memory_bytes)
483    }
484
485    pub fn cpu_display(&self) -> String {
486        format!("{:.1}%", self.cpu_percent)
487    }
488
489    pub fn uptime_display(&self) -> String {
490        format_duration(self.uptime_secs)
491    }
492
493    pub fn start_time_display(&self) -> String {
494        use std::time::{Duration, UNIX_EPOCH};
495        let datetime = UNIX_EPOCH + Duration::from_secs(self.start_time);
496        chrono::DateTime::<chrono::Local>::from(datetime)
497            .format("%Y-%m-%d %H:%M:%S")
498            .to_string()
499    }
500
501    pub fn disk_read_display(&self) -> String {
502        format_bytes_per_sec(self.disk_read_bytes)
503    }
504
505    pub fn disk_write_display(&self) -> String {
506        format_bytes_per_sec(self.disk_write_bytes)
507    }
508}
509
/// Formats a byte count with binary (1024-based) units: `B`, `KB`, `MB`, `GB`.
///
/// Counts below 1024 are printed exactly (`512B`). Larger counts get one
/// decimal place (`1.5KB`). A value moves up to the next unit whenever its
/// one-decimal rendering would otherwise read `1024.0`, so 1048575 bytes is
/// `1.0MB` rather than `1024.0KB`. `GB` is the largest unit used.
fn format_bytes(bytes: u64) -> String {
    const UNITS: [&str; 3] = ["KB", "MB", "GB"];
    // `{:.1}` renders values from here upward as `1024.0`.
    const ROLLOVER: f64 = 1023.95;

    if bytes < 1024 {
        return format!("{bytes}B");
    }
    let mut value = bytes as f64 / 1024.0;
    let mut unit = 0;
    while value >= ROLLOVER && unit + 1 < UNITS.len() {
        value /= 1024.0;
        unit += 1;
    }
    format!("{value:.1}{}", UNITS[unit])
}
521
/// Renders a duration given in seconds using at most two units, e.g. `45s`,
/// `3m 7s`, `2h 15m` or `4d 6h`. The smaller unit is truncated, not rounded.
fn format_duration(secs: u64) -> String {
    const MINUTE: u64 = 60;
    const HOUR: u64 = 60 * MINUTE;
    const DAY: u64 = 24 * HOUR;

    match secs {
        s if s < MINUTE => format!("{s}s"),
        s if s < HOUR => format!("{}m {}s", s / MINUTE, s % MINUTE),
        s if s < DAY => format!("{}h {}m", s / HOUR, s % HOUR / MINUTE),
        s => format!("{}d {}h", s / DAY, s % DAY / HOUR),
    }
}
537
/// Formats a byte rate with binary (1024-based) units: `B/s`, `KB/s`, `MB/s`,
/// `GB/s`.
///
/// Rates below 1024 are printed exactly (`512B/s`). Larger rates get one
/// decimal place (`1.5KB/s`). A value moves up to the next unit whenever its
/// one-decimal rendering would otherwise read `1024.0`, so 1048575 is
/// `1.0MB/s` rather than `1024.0KB/s`. `GB/s` is the largest unit used.
fn format_bytes_per_sec(bytes: u64) -> String {
    const UNITS: [&str; 3] = ["KB/s", "MB/s", "GB/s"];
    // `{:.1}` renders values from here upward as `1024.0`.
    const ROLLOVER: f64 = 1023.95;

    if bytes < 1024 {
        return format!("{bytes}B/s");
    }
    let mut value = bytes as f64 / 1024.0;
    let mut unit = 0;
    while value >= ROLLOVER && unit + 1 < UNITS.len() {
        value /= 1024.0;
        unit += 1;
    }
    format!("{value:.1}{}", UNITS[unit])
}