Skip to main content

pitchfork_cli/
procs.rs

1use crate::Result;
2#[cfg(unix)]
3use crate::settings::settings;
4use miette::IntoDiagnostic;
5use once_cell::sync::Lazy;
6use std::sync::Mutex;
7use sysinfo::ProcessesToUpdate;
8
/// Wrapper around a mutex-guarded [`sysinfo::System`] used for process
/// lookups, stats collection and signal-based termination.
pub struct Procs {
    /// Cached sysinfo process snapshot; only as fresh as the last
    /// `refresh_processes` / `refresh_pids` call.
    system: Mutex<sysinfo::System>,
}
12
/// Shared [`Procs`] instance, lazily constructed with [`Procs::new`] on first access.
pub static PROCS: Lazy<Procs> = Lazy::new(Procs::new);
14
15impl Default for Procs {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl Procs {
22    pub fn new() -> Self {
23        let procs = Self {
24            system: Mutex::new(sysinfo::System::new()),
25        };
26        procs.refresh_processes();
27        procs
28    }
29
30    fn lock_system(&self) -> std::sync::MutexGuard<'_, sysinfo::System> {
31        self.system.lock().unwrap_or_else(|poisoned| {
32            warn!("System mutex was poisoned, recovering");
33            poisoned.into_inner()
34        })
35    }
36
37    pub fn title(&self, pid: u32) -> Option<String> {
38        self.lock_system()
39            .process(sysinfo::Pid::from_u32(pid))
40            .map(|p| p.name().to_string_lossy().to_string())
41    }
42
43    pub fn is_running(&self, pid: u32) -> bool {
44        self.lock_system()
45            .process(sysinfo::Pid::from_u32(pid))
46            .is_some()
47    }
48
49    /// Walk the /proc tree to find all descendant PIDs.
50    /// Kept for diagnostics/status display; no longer used in the kill path.
51    #[allow(dead_code)]
52    pub fn all_children(&self, pid: u32) -> Vec<u32> {
53        let system = self.lock_system();
54        let all = system.processes();
55        let mut children = vec![];
56        for (child_pid, process) in all {
57            let mut process = process;
58            while let Some(parent) = process.parent() {
59                if parent == sysinfo::Pid::from_u32(pid) {
60                    children.push(child_pid.as_u32());
61                    break;
62                }
63                match system.process(parent) {
64                    Some(p) => process = p,
65                    None => break,
66                }
67            }
68        }
69        children
70    }
71
72    pub async fn kill_process_group_async(
73        &self,
74        pid: u32,
75        stop_signal: i32,
76        stop_timeout: Option<std::time::Duration>,
77    ) -> Result<bool> {
78        tokio::task::spawn_blocking(move || {
79            PROCS.kill_process_group(pid, stop_signal, stop_timeout)
80        })
81        .await
82        .into_diagnostic()?
83    }
84
85    /// Kill an entire process group with graceful shutdown strategy:
86    /// 1. Send the configured stop signal to the process group (-pgid) and wait up to ~3s
87    /// 2. If any processes remain, send SIGKILL to the group
88    ///
89    /// Since daemons are spawned with setsid(), the daemon PID == PGID,
90    /// so this atomically signals all descendant processes.
91    ///
92    /// Returns `Err` if the signal could not be sent (e.g. permission denied).
93    #[cfg(unix)]
94    fn kill_process_group(
95        &self,
96        pid: u32,
97        stop_signal: i32,
98        stop_timeout: Option<std::time::Duration>,
99    ) -> Result<bool> {
100        let pgid = pid as i32;
101        let signal_name = signal_name(stop_signal);
102
103        debug!("killing process group {pgid} with {signal_name}");
104
105        // Send the stop signal to the entire process group.
106        // killpg sends to all processes in the group atomically.
107        // We intentionally skip the zombie check here because the leader may be
108        // a zombie while children in the group are still running.
109        let ret = unsafe { libc::killpg(pgid, stop_signal) };
110        if ret == -1 {
111            let err = std::io::Error::last_os_error();
112            if err.raw_os_error() == Some(libc::ESRCH) {
113                debug!("process group {pgid} no longer exists");
114                return Ok(false);
115            }
116            if err.raw_os_error() == Some(libc::EPERM) {
117                return Err(miette::miette!(
118                    "failed to send {signal_name} to process group {pgid}: permission denied"
119                ));
120            }
121            warn!("failed to send {signal_name} to process group {pgid}: {err}");
122        }
123
124        // Wait for graceful shutdown: fast initial check then slower polling.
125        // Per-daemon timeout overrides the global setting.
126        let stop_timeout = stop_timeout.unwrap_or_else(|| settings().supervisor_stop_timeout());
127        let fast_ms = 10u64;
128        let slow_ms = 50u64;
129        let total_ms = stop_timeout.as_millis().max(1) as u64;
130        let fast_count = ((total_ms / fast_ms) as usize).min(10);
131        let fast_total_ms = fast_ms * fast_count as u64;
132        let remaining_ms = total_ms.saturating_sub(fast_total_ms);
133        let slow_count = (remaining_ms / slow_ms) as usize;
134
135        let fast_checks =
136            std::iter::repeat_n(std::time::Duration::from_millis(fast_ms), fast_count);
137        let slow_checks =
138            std::iter::repeat_n(std::time::Duration::from_millis(slow_ms), slow_count);
139        let mut elapsed_ms = 0u64;
140
141        for sleep_duration in fast_checks.chain(slow_checks) {
142            std::thread::sleep(sleep_duration);
143            self.refresh_pids(&[pid]);
144            elapsed_ms += sleep_duration.as_millis() as u64;
145            if self.is_terminated_or_zombie(sysinfo::Pid::from_u32(pid)) {
146                debug!("process group {pgid} terminated after {signal_name} ({elapsed_ms} ms)",);
147                return Ok(true);
148            }
149        }
150
151        // SIGKILL the entire process group as last resort
152        warn!(
153            "process group {pgid} did not respond to {signal_name} after {}ms, sending SIGKILL",
154            stop_timeout.as_millis()
155        );
156        let ret = unsafe { libc::killpg(pgid, libc::SIGKILL) };
157        if ret == -1 {
158            let err = std::io::Error::last_os_error();
159            if err.raw_os_error() != Some(libc::ESRCH) {
160                warn!("failed to send SIGKILL to process group {pgid}: {err}");
161            }
162        }
163
164        // Brief wait for SIGKILL to take effect
165        std::thread::sleep(std::time::Duration::from_millis(100));
166        Ok(true)
167    }
168
169    #[cfg(not(unix))]
170    fn kill_process_group(
171        &self,
172        pid: u32,
173        _stop_signal: i32,
174        _stop_timeout: Option<std::time::Duration>,
175    ) -> Result<bool> {
176        self.kill(pid, 0, None)
177    }
178
179    pub async fn kill_async(
180        &self,
181        pid: u32,
182        stop_signal: i32,
183        stop_timeout: Option<std::time::Duration>,
184    ) -> Result<bool> {
185        tokio::task::spawn_blocking(move || PROCS.kill(pid, stop_signal, stop_timeout))
186            .await
187            .into_diagnostic()?
188    }
189
190    /// Kill a process with graceful shutdown strategy:
191    /// 1. Send the configured stop signal and wait up to ~3s (10ms intervals for first 100ms, then 50ms intervals)
192    /// 2. If still running, send SIGKILL to force termination
193    ///
194    /// This ensures fast-exiting processes don't wait unnecessarily,
195    /// while stubborn processes eventually get forcefully terminated.
196    ///
197    /// Returns `Err` if the signal could not be sent (e.g. permission denied
198    /// when targeting a process owned by another user/root).
199    fn kill(
200        &self,
201        pid: u32,
202        stop_signal: i32,
203        stop_timeout: Option<std::time::Duration>,
204    ) -> Result<bool> {
205        let sysinfo_pid = sysinfo::Pid::from_u32(pid);
206
207        // Check if process exists or is a zombie (already terminated but not reaped)
208        if self.is_terminated_or_zombie(sysinfo_pid) {
209            return Ok(false);
210        }
211
212        debug!("killing process {pid}");
213
214        #[cfg(windows)]
215        {
216            let _ = (stop_signal, stop_timeout);
217            if let Some(process) = self.lock_system().process(sysinfo_pid) {
218                process.kill();
219                process.wait();
220            }
221            Ok(true)
222        }
223
224        #[cfg(unix)]
225        {
226            let signal_name = signal_name(stop_signal);
227            // Send stop signal for graceful shutdown using libc::kill directly
228            // so we can distinguish EPERM (permission denied) from ESRCH
229            // (process already gone — possible in a narrow race window).
230            debug!("sending {signal_name} to process {pid}");
231            let ret = unsafe { libc::kill(pid as i32, stop_signal) };
232            if ret == -1 {
233                let err = std::io::Error::last_os_error();
234                if err.raw_os_error() == Some(libc::ESRCH) {
235                    debug!("process {pid} no longer exists");
236                    return Ok(false);
237                }
238                if err.raw_os_error() == Some(libc::EPERM) {
239                    return Err(miette::miette!(
240                        "failed to send {signal_name} to process {pid}: permission denied"
241                    ));
242                }
243                return Err(miette::miette!(
244                    "failed to send {signal_name} to process {pid}: {err}"
245                ));
246            }
247
248            // Fast check: 10ms intervals, then slower 50ms polling for stop_timeout.
249            // Per-daemon timeout overrides the global setting.
250            let stop_timeout = stop_timeout.unwrap_or_else(|| settings().supervisor_stop_timeout());
251            let fast_ms = 10u64;
252            let slow_ms = 50u64;
253            let total_ms = stop_timeout.as_millis().max(1) as u64;
254            let fast_count = ((total_ms / fast_ms) as usize).min(10);
255            let fast_total_ms = fast_ms * fast_count as u64;
256            let remaining_ms = total_ms.saturating_sub(fast_total_ms);
257            let slow_count = (remaining_ms / slow_ms) as usize;
258
259            for i in 0..fast_count {
260                std::thread::sleep(std::time::Duration::from_millis(fast_ms));
261                self.refresh_pids(&[pid]);
262                if self.is_terminated_or_zombie(sysinfo_pid) {
263                    debug!(
264                        "process {pid} terminated after {signal_name} ({} ms)",
265                        (i + 1) * fast_ms as usize
266                    );
267                    return Ok(true);
268                }
269            }
270
271            // Slower check: 50ms intervals for the remainder of stop_timeout
272            for i in 0..slow_count {
273                std::thread::sleep(std::time::Duration::from_millis(slow_ms));
274                self.refresh_pids(&[pid]);
275                if self.is_terminated_or_zombie(sysinfo_pid) {
276                    debug!(
277                        "process {pid} terminated after {signal_name} ({} ms)",
278                        fast_total_ms + (i + 1) as u64 * slow_ms
279                    );
280                    return Ok(true);
281                }
282            }
283
284            // SIGKILL as last resort after stop_timeout
285            warn!(
286                "process {pid} did not respond to {signal_name} after {}ms, sending SIGKILL",
287                stop_timeout.as_millis()
288            );
289            let ret = unsafe { libc::kill(pid as i32, libc::SIGKILL) };
290            if ret == -1 {
291                let err = std::io::Error::last_os_error();
292                if err.raw_os_error() != Some(libc::ESRCH) {
293                    warn!("failed to send SIGKILL to process {pid}: {err}");
294                }
295            }
296
297            // Brief wait for SIGKILL to take effect
298            std::thread::sleep(std::time::Duration::from_millis(100));
299            Ok(true)
300        }
301    }
302
303    /// Check if a process is terminated or is a zombie.
304    /// On Linux, zombie processes still have /proc/[pid] entries but are effectively dead.
305    /// This prevents unnecessary signal escalation for processes that have already exited.
306    fn is_terminated_or_zombie(&self, sysinfo_pid: sysinfo::Pid) -> bool {
307        let system = self.lock_system();
308        match system.process(sysinfo_pid) {
309            None => true,
310            Some(process) => {
311                #[cfg(unix)]
312                {
313                    matches!(process.status(), sysinfo::ProcessStatus::Zombie)
314                }
315                #[cfg(not(unix))]
316                {
317                    let _ = process;
318                    false
319                }
320            }
321        }
322    }
323
324    pub(crate) fn refresh_processes(&self) {
325        self.lock_system()
326            .refresh_processes(ProcessesToUpdate::All, true);
327    }
328
329    /// Refresh only specific PIDs instead of all processes.
330    /// More efficient when you only need to check a small set of known PIDs.
331    pub(crate) fn refresh_pids(&self, pids: &[u32]) {
332        let sysinfo_pids: Vec<sysinfo::Pid> =
333            pids.iter().map(|p| sysinfo::Pid::from_u32(*p)).collect();
334        self.lock_system()
335            .refresh_processes(ProcessesToUpdate::Some(&sysinfo_pids), true);
336    }
337
338    /// Get aggregated stats for multiple process groups in a single pass.
339    ///
340    /// Builds the parent→children map once (O(N)) and then BFS-es from each
341    /// root PID (O(D_i) per daemon). Total cost is O(N + ΣD_i) instead of
342    /// O(D × N) when calling `get_group_stats` in a loop.
343    pub fn get_batch_group_stats(&self, pids: &[u32]) -> Vec<(u32, Option<ProcessStats>)> {
344        let system = self.lock_system();
345        let processes = system.processes();
346
347        let now = std::time::SystemTime::now()
348            .duration_since(std::time::UNIX_EPOCH)
349            .map(|d| d.as_secs())
350            .unwrap_or(0);
351
352        // Build parent → children map once for all daemons
353        let mut children_map: std::collections::HashMap<sysinfo::Pid, Vec<sysinfo::Pid>> =
354            std::collections::HashMap::new();
355        for (child_pid, child) in processes {
356            if let Some(ppid) = child.parent() {
357                children_map.entry(ppid).or_default().push(*child_pid);
358            }
359        }
360
361        pids.iter()
362            .map(|&pid| {
363                let root_pid = sysinfo::Pid::from_u32(pid);
364                let Some(root) = processes.get(&root_pid) else {
365                    return (pid, None);
366                };
367
368                let root_disk = root.disk_usage();
369                let mut stats = ProcessStats {
370                    cpu_percent: root.cpu_usage(),
371                    memory_bytes: root.memory(),
372                    uptime_secs: now.saturating_sub(root.start_time()),
373                    disk_read_bytes: root_disk.read_bytes,
374                    disk_write_bytes: root_disk.written_bytes,
375                };
376
377                // BFS from root_pid to find all descendants
378                let mut queue = std::collections::VecDeque::new();
379                if let Some(direct_children) = children_map.get(&root_pid) {
380                    queue.extend(direct_children);
381                }
382                while let Some(child_pid) = queue.pop_front() {
383                    if let Some(child) = processes.get(&child_pid) {
384                        let disk = child.disk_usage();
385                        stats.cpu_percent += child.cpu_usage();
386                        stats.memory_bytes += child.memory();
387                        stats.disk_read_bytes += disk.read_bytes;
388                        stats.disk_write_bytes += disk.written_bytes;
389                    }
390                    if let Some(grandchildren) = children_map.get(&child_pid) {
391                        queue.extend(grandchildren);
392                    }
393                }
394
395                (pid, Some(stats))
396            })
397            .collect()
398    }
399
400    /// Get process stats (cpu%, memory bytes, uptime secs, disk I/O) for a given PID
401    pub fn get_stats(&self, pid: u32) -> Option<ProcessStats> {
402        let system = self.lock_system();
403        system.process(sysinfo::Pid::from_u32(pid)).map(|p| {
404            let now = std::time::SystemTime::now()
405                .duration_since(std::time::UNIX_EPOCH)
406                .map(|d| d.as_secs())
407                .unwrap_or(0);
408            let disk = p.disk_usage();
409            ProcessStats {
410                cpu_percent: p.cpu_usage(),
411                memory_bytes: p.memory(),
412                uptime_secs: now.saturating_sub(p.start_time()),
413                disk_read_bytes: disk.read_bytes,
414                disk_write_bytes: disk.written_bytes,
415            }
416        })
417    }
418
419    /// Get extended process information for a given PID
420    pub fn get_extended_stats(&self, pid: u32) -> Option<ExtendedProcessStats> {
421        let system = self.lock_system();
422        system.process(sysinfo::Pid::from_u32(pid)).map(|p| {
423            let now = std::time::SystemTime::now()
424                .duration_since(std::time::UNIX_EPOCH)
425                .map(|d| d.as_secs())
426                .unwrap_or(0);
427            let disk = p.disk_usage();
428
429            ExtendedProcessStats {
430                name: p.name().to_string_lossy().to_string(),
431                exe_path: p.exe().map(|e| e.to_string_lossy().to_string()),
432                cwd: p.cwd().map(|c| c.to_string_lossy().to_string()),
433                environ: p
434                    .environ()
435                    .iter()
436                    .take(20)
437                    .map(|s| s.to_string_lossy().to_string())
438                    .collect(),
439                status: format!("{:?}", p.status()),
440                cpu_percent: p.cpu_usage(),
441                memory_bytes: p.memory(),
442                virtual_memory_bytes: p.virtual_memory(),
443                uptime_secs: now.saturating_sub(p.start_time()),
444                start_time: p.start_time(),
445                disk_read_bytes: disk.read_bytes,
446                disk_write_bytes: disk.written_bytes,
447                parent_pid: p.parent().map(|pp| pp.as_u32()),
448                thread_count: p.tasks().map(|t| t.len()).unwrap_or(0),
449                user_id: p.user_id().map(|u| u.to_string()),
450            }
451        })
452    }
453}
454
/// Resource-usage snapshot for a process (or, when produced by
/// `get_batch_group_stats`, summed over a process and its descendants).
#[derive(Debug, Clone, Copy)]
pub struct ProcessStats {
    /// CPU usage as reported by sysinfo's `cpu_usage()`; rendered as a percentage.
    pub cpu_percent: f32,
    /// Memory usage in bytes, from sysinfo's `memory()`.
    pub memory_bytes: u64,
    /// Seconds between the process start time and the moment the stats were taken.
    pub uptime_secs: u64,
    /// `read_bytes` from sysinfo's `disk_usage()`.
    /// NOTE(review): displayed with a "/s" suffix — confirm the refresh
    /// interval makes this a per-second figure.
    pub disk_read_bytes: u64,
    /// `written_bytes` from sysinfo's `disk_usage()` (same caveat as reads).
    pub disk_write_bytes: u64,
}
463
464impl ProcessStats {
465    pub fn memory_display(&self) -> String {
466        format_bytes(self.memory_bytes)
467    }
468
469    pub fn cpu_display(&self) -> String {
470        format!("{:.1}%", self.cpu_percent)
471    }
472
473    pub fn uptime_display(&self) -> String {
474        format_duration(self.uptime_secs)
475    }
476
477    pub fn disk_read_display(&self) -> String {
478        format_bytes_per_sec(self.disk_read_bytes)
479    }
480
481    pub fn disk_write_display(&self) -> String {
482        format_bytes_per_sec(self.disk_write_bytes)
483    }
484}
485
/// Extended process stats with more detailed information
#[derive(Debug, Clone)]
pub struct ExtendedProcessStats {
    /// Process name (lossy UTF-8 conversion of the OS-provided name).
    pub name: String,
    /// Path of the executable, when sysinfo could determine it.
    pub exe_path: Option<String>,
    /// Current working directory, when sysinfo could determine it.
    pub cwd: Option<String>,
    /// Environment entries; `get_extended_stats` keeps at most the first 20.
    pub environ: Vec<String>,
    /// `Debug` rendering of the sysinfo process status.
    pub status: String,
    /// CPU usage as reported by sysinfo's `cpu_usage()`; rendered as a percentage.
    pub cpu_percent: f32,
    /// Memory usage in bytes, from sysinfo's `memory()`.
    pub memory_bytes: u64,
    /// Virtual memory size in bytes, from sysinfo's `virtual_memory()`.
    pub virtual_memory_bytes: u64,
    /// Seconds between `start_time` and the moment the stats were taken.
    pub uptime_secs: u64,
    /// Process start time in seconds since the Unix epoch.
    pub start_time: u64,
    /// `read_bytes` from sysinfo's `disk_usage()`.
    /// NOTE(review): displayed with a "/s" suffix — confirm the refresh
    /// interval makes this a per-second figure.
    pub disk_read_bytes: u64,
    /// `written_bytes` from sysinfo's `disk_usage()` (same caveat as reads).
    pub disk_write_bytes: u64,
    /// PID of the parent process, if one is recorded.
    pub parent_pid: Option<u32>,
    /// Number of tasks reported by sysinfo; 0 when no task list is available.
    pub thread_count: usize,
    /// String form of the owning user id, if known.
    pub user_id: Option<String>,
}
505
506impl ExtendedProcessStats {
507    pub fn memory_display(&self) -> String {
508        format_bytes(self.memory_bytes)
509    }
510
511    pub fn virtual_memory_display(&self) -> String {
512        format_bytes(self.virtual_memory_bytes)
513    }
514
515    pub fn cpu_display(&self) -> String {
516        format!("{:.1}%", self.cpu_percent)
517    }
518
519    pub fn uptime_display(&self) -> String {
520        format_duration(self.uptime_secs)
521    }
522
523    pub fn start_time_display(&self) -> String {
524        use std::time::{Duration, UNIX_EPOCH};
525        let datetime = UNIX_EPOCH + Duration::from_secs(self.start_time);
526        chrono::DateTime::<chrono::Local>::from(datetime)
527            .format("%Y-%m-%d %H:%M:%S")
528            .to_string()
529    }
530
531    pub fn disk_read_display(&self) -> String {
532        format_bytes_per_sec(self.disk_read_bytes)
533    }
534
535    pub fn disk_write_display(&self) -> String {
536        format_bytes_per_sec(self.disk_write_bytes)
537    }
538}
539
/// Render a byte count with a binary-scaled unit: plain `B` below 1024,
/// otherwise `KB`, `MB` or `GB` with one decimal place.
fn format_bytes(bytes: u64) -> String {
    const KIB: u64 = 1024;
    const MIB: u64 = KIB * 1024;
    const GIB: u64 = MIB * 1024;

    let scaled = |unit: u64| bytes as f64 / unit as f64;

    // Test the largest unit first; the last arm is the unscaled case.
    if bytes >= GIB {
        format!("{:.1}GB", scaled(GIB))
    } else if bytes >= MIB {
        format!("{:.1}MB", scaled(MIB))
    } else if bytes >= KIB {
        format!("{:.1}KB", scaled(KIB))
    } else {
        format!("{bytes}B")
    }
}
551
/// Render a duration in seconds using its two most significant units:
/// `Ns`, `Nm Ns`, `Nh Nm` or `Nd Nh`.
fn format_duration(secs: u64) -> String {
    const MINUTE: u64 = 60;
    const HOUR: u64 = 60 * MINUTE;
    const DAY: u64 = 24 * HOUR;

    match secs {
        s if s < MINUTE => format!("{s}s"),
        s if s < HOUR => format!("{}m {}s", s / MINUTE, s % MINUTE),
        s if s < DAY => {
            let (hours, mins) = (s / HOUR, (s % HOUR) / MINUTE);
            format!("{hours}h {mins}m")
        }
        s => {
            let (days, hours) = (s / DAY, (s % DAY) / HOUR);
            format!("{days}d {hours}h")
        }
    }
}
567
/// Render a byte count as a rate: plain `B/s` below 1024, otherwise
/// `KB/s`, `MB/s` or `GB/s` with one decimal place.
fn format_bytes_per_sec(bytes: u64) -> String {
    const KIB: u64 = 1024;
    const MIB: u64 = KIB * 1024;
    const GIB: u64 = MIB * 1024;

    let scaled = |unit: u64| bytes as f64 / unit as f64;

    // Test the largest unit first; the last arm is the unscaled case.
    if bytes >= GIB {
        format!("{:.1}GB/s", scaled(GIB))
    } else if bytes >= MIB {
        format!("{:.1}MB/s", scaled(MIB))
    } else if bytes >= KIB {
        format!("{:.1}KB/s", scaled(KIB))
    } else {
        format!("{bytes}B/s")
    }
}
579
580#[cfg(unix)]
581fn signal_name(sig: i32) -> &'static str {
582    match sig {
583        libc::SIGHUP => "SIGHUP",
584        libc::SIGINT => "SIGINT",
585        libc::SIGQUIT => "SIGQUIT",
586        libc::SIGTERM => "SIGTERM",
587        libc::SIGUSR1 => "SIGUSR1",
588        libc::SIGUSR2 => "SIGUSR2",
589        libc::SIGKILL => "SIGKILL",
590        _ => "UNKNOWN",
591    }
592}