Skip to main content

pitchfork_cli/
procs.rs

1use crate::Result;
2use crate::settings::settings;
3use miette::IntoDiagnostic;
4use once_cell::sync::Lazy;
5use std::sync::Mutex;
6use sysinfo::ProcessesToUpdate;
7
8pub struct Procs {
9    system: Mutex<sysinfo::System>,
10}
11
12pub static PROCS: Lazy<Procs> = Lazy::new(Procs::new);
13
14impl Default for Procs {
15    fn default() -> Self {
16        Self::new()
17    }
18}
19
20impl Procs {
21    pub fn new() -> Self {
22        let procs = Self {
23            system: Mutex::new(sysinfo::System::new()),
24        };
25        procs.refresh_processes();
26        procs
27    }
28
29    fn lock_system(&self) -> std::sync::MutexGuard<'_, sysinfo::System> {
30        self.system.lock().unwrap_or_else(|poisoned| {
31            warn!("System mutex was poisoned, recovering");
32            poisoned.into_inner()
33        })
34    }
35
36    pub fn title(&self, pid: u32) -> Option<String> {
37        self.lock_system()
38            .process(sysinfo::Pid::from_u32(pid))
39            .map(|p| p.name().to_string_lossy().to_string())
40    }
41
42    pub fn is_running(&self, pid: u32) -> bool {
43        self.lock_system()
44            .process(sysinfo::Pid::from_u32(pid))
45            .is_some()
46    }
47
48    /// Walk the /proc tree to find all descendant PIDs.
49    /// Kept for diagnostics/status display; no longer used in the kill path.
50    #[allow(dead_code)]
51    pub fn all_children(&self, pid: u32) -> Vec<u32> {
52        let system = self.lock_system();
53        let all = system.processes();
54        let mut children = vec![];
55        for (child_pid, process) in all {
56            let mut process = process;
57            while let Some(parent) = process.parent() {
58                if parent == sysinfo::Pid::from_u32(pid) {
59                    children.push(child_pid.as_u32());
60                    break;
61                }
62                match system.process(parent) {
63                    Some(p) => process = p,
64                    None => break,
65                }
66            }
67        }
68        children
69    }
70
71    pub async fn kill_process_group_async(
72        &self,
73        pid: u32,
74        stop_signal: i32,
75        stop_timeout: Option<std::time::Duration>,
76    ) -> Result<bool> {
77        tokio::task::spawn_blocking(move || {
78            PROCS.kill_process_group(pid, stop_signal, stop_timeout)
79        })
80        .await
81        .into_diagnostic()?
82    }
83
84    /// Kill an entire process group with graceful shutdown strategy:
85    /// 1. Send the configured stop signal to the process group (-pgid) and wait up to ~3s
86    /// 2. If any processes remain, send SIGKILL to the group
87    ///
88    /// Since daemons are spawned with setsid(), the daemon PID == PGID,
89    /// so this atomically signals all descendant processes.
90    ///
91    /// Returns `Err` if the signal could not be sent (e.g. permission denied).
92    #[cfg(unix)]
93    fn kill_process_group(
94        &self,
95        pid: u32,
96        stop_signal: i32,
97        stop_timeout: Option<std::time::Duration>,
98    ) -> Result<bool> {
99        let pgid = pid as i32;
100        let signal_name = signal_name(stop_signal);
101
102        debug!("killing process group {pgid} with {signal_name}");
103
104        // Send the stop signal to the entire process group.
105        // killpg sends to all processes in the group atomically.
106        // We intentionally skip the zombie check here because the leader may be
107        // a zombie while children in the group are still running.
108        let ret = unsafe { libc::killpg(pgid, stop_signal) };
109        if ret == -1 {
110            let err = std::io::Error::last_os_error();
111            if err.raw_os_error() == Some(libc::ESRCH) {
112                debug!("process group {pgid} no longer exists");
113                return Ok(false);
114            }
115            if err.raw_os_error() == Some(libc::EPERM) {
116                return Err(miette::miette!(
117                    "failed to send {signal_name} to process group {pgid}: permission denied"
118                ));
119            }
120            warn!("failed to send {signal_name} to process group {pgid}: {err}");
121        }
122
123        // Wait for graceful shutdown: fast initial check then slower polling.
124        // Per-daemon timeout overrides the global setting.
125        let stop_timeout = stop_timeout.unwrap_or_else(|| settings().supervisor_stop_timeout());
126        let fast_ms = 10u64;
127        let slow_ms = 50u64;
128        let total_ms = stop_timeout.as_millis().max(1) as u64;
129        let fast_count = ((total_ms / fast_ms) as usize).min(10);
130        let fast_total_ms = fast_ms * fast_count as u64;
131        let remaining_ms = total_ms.saturating_sub(fast_total_ms);
132        let slow_count = (remaining_ms / slow_ms) as usize;
133
134        let fast_checks =
135            std::iter::repeat_n(std::time::Duration::from_millis(fast_ms), fast_count);
136        let slow_checks =
137            std::iter::repeat_n(std::time::Duration::from_millis(slow_ms), slow_count);
138        let mut elapsed_ms = 0u64;
139
140        for sleep_duration in fast_checks.chain(slow_checks) {
141            std::thread::sleep(sleep_duration);
142            self.refresh_pids(&[pid]);
143            elapsed_ms += sleep_duration.as_millis() as u64;
144            if self.is_terminated_or_zombie(sysinfo::Pid::from_u32(pid)) {
145                debug!("process group {pgid} terminated after {signal_name} ({elapsed_ms} ms)",);
146                return Ok(true);
147            }
148        }
149
150        // SIGKILL the entire process group as last resort
151        warn!(
152            "process group {pgid} did not respond to {signal_name} after {}ms, sending SIGKILL",
153            stop_timeout.as_millis()
154        );
155        let ret = unsafe { libc::killpg(pgid, libc::SIGKILL) };
156        if ret == -1 {
157            let err = std::io::Error::last_os_error();
158            if err.raw_os_error() != Some(libc::ESRCH) {
159                warn!("failed to send SIGKILL to process group {pgid}: {err}");
160            }
161        }
162
163        // Brief wait for SIGKILL to take effect
164        std::thread::sleep(std::time::Duration::from_millis(100));
165        Ok(true)
166    }
167
168    #[cfg(not(unix))]
169    fn kill_process_group(
170        &self,
171        pid: u32,
172        _stop_signal: i32,
173        _stop_timeout: Option<std::time::Duration>,
174    ) -> Result<bool> {
175        self.kill(pid, 0, None)
176    }
177
178    pub async fn kill_async(
179        &self,
180        pid: u32,
181        stop_signal: i32,
182        stop_timeout: Option<std::time::Duration>,
183    ) -> Result<bool> {
184        tokio::task::spawn_blocking(move || PROCS.kill(pid, stop_signal, stop_timeout))
185            .await
186            .into_diagnostic()?
187    }
188
189    /// Kill a process with graceful shutdown strategy:
190    /// 1. Send the configured stop signal and wait up to ~3s (10ms intervals for first 100ms, then 50ms intervals)
191    /// 2. If still running, send SIGKILL to force termination
192    ///
193    /// This ensures fast-exiting processes don't wait unnecessarily,
194    /// while stubborn processes eventually get forcefully terminated.
195    ///
196    /// Returns `Err` if the signal could not be sent (e.g. permission denied
197    /// when targeting a process owned by another user/root).
198    fn kill(
199        &self,
200        pid: u32,
201        stop_signal: i32,
202        stop_timeout: Option<std::time::Duration>,
203    ) -> Result<bool> {
204        let sysinfo_pid = sysinfo::Pid::from_u32(pid);
205
206        // Check if process exists or is a zombie (already terminated but not reaped)
207        if self.is_terminated_or_zombie(sysinfo_pid) {
208            return Ok(false);
209        }
210
211        debug!("killing process {pid}");
212
213        #[cfg(windows)]
214        {
215            if let Some(process) = self.lock_system().process(sysinfo_pid) {
216                process.kill();
217                process.wait();
218            }
219            return Ok(true);
220        }
221
222        #[cfg(unix)]
223        {
224            let signal_name = signal_name(stop_signal);
225            // Send stop signal for graceful shutdown using libc::kill directly
226            // so we can distinguish EPERM (permission denied) from ESRCH
227            // (process already gone — possible in a narrow race window).
228            debug!("sending {signal_name} to process {pid}");
229            let ret = unsafe { libc::kill(pid as i32, stop_signal) };
230            if ret == -1 {
231                let err = std::io::Error::last_os_error();
232                if err.raw_os_error() == Some(libc::ESRCH) {
233                    debug!("process {pid} no longer exists");
234                    return Ok(false);
235                }
236                if err.raw_os_error() == Some(libc::EPERM) {
237                    return Err(miette::miette!(
238                        "failed to send {signal_name} to process {pid}: permission denied"
239                    ));
240                }
241                return Err(miette::miette!(
242                    "failed to send {signal_name} to process {pid}: {err}"
243                ));
244            }
245
246            // Fast check: 10ms intervals, then slower 50ms polling for stop_timeout.
247            // Per-daemon timeout overrides the global setting.
248            let stop_timeout = stop_timeout.unwrap_or_else(|| settings().supervisor_stop_timeout());
249            let fast_ms = 10u64;
250            let slow_ms = 50u64;
251            let total_ms = stop_timeout.as_millis().max(1) as u64;
252            let fast_count = ((total_ms / fast_ms) as usize).min(10);
253            let fast_total_ms = fast_ms * fast_count as u64;
254            let remaining_ms = total_ms.saturating_sub(fast_total_ms);
255            let slow_count = (remaining_ms / slow_ms) as usize;
256
257            for i in 0..fast_count {
258                std::thread::sleep(std::time::Duration::from_millis(fast_ms));
259                self.refresh_pids(&[pid]);
260                if self.is_terminated_or_zombie(sysinfo_pid) {
261                    debug!(
262                        "process {pid} terminated after {signal_name} ({} ms)",
263                        (i + 1) * fast_ms as usize
264                    );
265                    return Ok(true);
266                }
267            }
268
269            // Slower check: 50ms intervals for the remainder of stop_timeout
270            for i in 0..slow_count {
271                std::thread::sleep(std::time::Duration::from_millis(slow_ms));
272                self.refresh_pids(&[pid]);
273                if self.is_terminated_or_zombie(sysinfo_pid) {
274                    debug!(
275                        "process {pid} terminated after {signal_name} ({} ms)",
276                        fast_total_ms + (i + 1) as u64 * slow_ms
277                    );
278                    return Ok(true);
279                }
280            }
281
282            // SIGKILL as last resort after stop_timeout
283            warn!(
284                "process {pid} did not respond to {signal_name} after {}ms, sending SIGKILL",
285                stop_timeout.as_millis()
286            );
287            let ret = unsafe { libc::kill(pid as i32, libc::SIGKILL) };
288            if ret == -1 {
289                let err = std::io::Error::last_os_error();
290                if err.raw_os_error() != Some(libc::ESRCH) {
291                    warn!("failed to send SIGKILL to process {pid}: {err}");
292                }
293            }
294
295            // Brief wait for SIGKILL to take effect
296            std::thread::sleep(std::time::Duration::from_millis(100));
297            Ok(true)
298        }
299    }
300
301    /// Check if a process is terminated or is a zombie.
302    /// On Linux, zombie processes still have /proc/[pid] entries but are effectively dead.
303    /// This prevents unnecessary signal escalation for processes that have already exited.
304    fn is_terminated_or_zombie(&self, sysinfo_pid: sysinfo::Pid) -> bool {
305        let system = self.lock_system();
306        match system.process(sysinfo_pid) {
307            None => true,
308            Some(process) => {
309                #[cfg(unix)]
310                {
311                    matches!(process.status(), sysinfo::ProcessStatus::Zombie)
312                }
313                #[cfg(not(unix))]
314                {
315                    let _ = process;
316                    false
317                }
318            }
319        }
320    }
321
322    pub(crate) fn refresh_processes(&self) {
323        self.lock_system()
324            .refresh_processes(ProcessesToUpdate::All, true);
325    }
326
327    /// Refresh only specific PIDs instead of all processes.
328    /// More efficient when you only need to check a small set of known PIDs.
329    pub(crate) fn refresh_pids(&self, pids: &[u32]) {
330        let sysinfo_pids: Vec<sysinfo::Pid> =
331            pids.iter().map(|p| sysinfo::Pid::from_u32(*p)).collect();
332        self.lock_system()
333            .refresh_processes(ProcessesToUpdate::Some(&sysinfo_pids), true);
334    }
335
336    /// Get aggregated stats for multiple process groups in a single pass.
337    ///
338    /// Builds the parent→children map once (O(N)) and then BFS-es from each
339    /// root PID (O(D_i) per daemon). Total cost is O(N + ΣD_i) instead of
340    /// O(D × N) when calling `get_group_stats` in a loop.
341    pub fn get_batch_group_stats(&self, pids: &[u32]) -> Vec<(u32, Option<ProcessStats>)> {
342        let system = self.lock_system();
343        let processes = system.processes();
344
345        let now = std::time::SystemTime::now()
346            .duration_since(std::time::UNIX_EPOCH)
347            .map(|d| d.as_secs())
348            .unwrap_or(0);
349
350        // Build parent → children map once for all daemons
351        let mut children_map: std::collections::HashMap<sysinfo::Pid, Vec<sysinfo::Pid>> =
352            std::collections::HashMap::new();
353        for (child_pid, child) in processes {
354            if let Some(ppid) = child.parent() {
355                children_map.entry(ppid).or_default().push(*child_pid);
356            }
357        }
358
359        pids.iter()
360            .map(|&pid| {
361                let root_pid = sysinfo::Pid::from_u32(pid);
362                let Some(root) = processes.get(&root_pid) else {
363                    return (pid, None);
364                };
365
366                let root_disk = root.disk_usage();
367                let mut stats = ProcessStats {
368                    cpu_percent: root.cpu_usage(),
369                    memory_bytes: root.memory(),
370                    uptime_secs: now.saturating_sub(root.start_time()),
371                    disk_read_bytes: root_disk.read_bytes,
372                    disk_write_bytes: root_disk.written_bytes,
373                };
374
375                // BFS from root_pid to find all descendants
376                let mut queue = std::collections::VecDeque::new();
377                if let Some(direct_children) = children_map.get(&root_pid) {
378                    queue.extend(direct_children);
379                }
380                while let Some(child_pid) = queue.pop_front() {
381                    if let Some(child) = processes.get(&child_pid) {
382                        let disk = child.disk_usage();
383                        stats.cpu_percent += child.cpu_usage();
384                        stats.memory_bytes += child.memory();
385                        stats.disk_read_bytes += disk.read_bytes;
386                        stats.disk_write_bytes += disk.written_bytes;
387                    }
388                    if let Some(grandchildren) = children_map.get(&child_pid) {
389                        queue.extend(grandchildren);
390                    }
391                }
392
393                (pid, Some(stats))
394            })
395            .collect()
396    }
397
398    /// Get process stats (cpu%, memory bytes, uptime secs, disk I/O) for a given PID
399    pub fn get_stats(&self, pid: u32) -> Option<ProcessStats> {
400        let system = self.lock_system();
401        system.process(sysinfo::Pid::from_u32(pid)).map(|p| {
402            let now = std::time::SystemTime::now()
403                .duration_since(std::time::UNIX_EPOCH)
404                .map(|d| d.as_secs())
405                .unwrap_or(0);
406            let disk = p.disk_usage();
407            ProcessStats {
408                cpu_percent: p.cpu_usage(),
409                memory_bytes: p.memory(),
410                uptime_secs: now.saturating_sub(p.start_time()),
411                disk_read_bytes: disk.read_bytes,
412                disk_write_bytes: disk.written_bytes,
413            }
414        })
415    }
416
417    /// Get extended process information for a given PID
418    pub fn get_extended_stats(&self, pid: u32) -> Option<ExtendedProcessStats> {
419        let system = self.lock_system();
420        system.process(sysinfo::Pid::from_u32(pid)).map(|p| {
421            let now = std::time::SystemTime::now()
422                .duration_since(std::time::UNIX_EPOCH)
423                .map(|d| d.as_secs())
424                .unwrap_or(0);
425            let disk = p.disk_usage();
426
427            ExtendedProcessStats {
428                name: p.name().to_string_lossy().to_string(),
429                exe_path: p.exe().map(|e| e.to_string_lossy().to_string()),
430                cwd: p.cwd().map(|c| c.to_string_lossy().to_string()),
431                environ: p
432                    .environ()
433                    .iter()
434                    .take(20)
435                    .map(|s| s.to_string_lossy().to_string())
436                    .collect(),
437                status: format!("{:?}", p.status()),
438                cpu_percent: p.cpu_usage(),
439                memory_bytes: p.memory(),
440                virtual_memory_bytes: p.virtual_memory(),
441                uptime_secs: now.saturating_sub(p.start_time()),
442                start_time: p.start_time(),
443                disk_read_bytes: disk.read_bytes,
444                disk_write_bytes: disk.written_bytes,
445                parent_pid: p.parent().map(|pp| pp.as_u32()),
446                thread_count: p.tasks().map(|t| t.len()).unwrap_or(0),
447                user_id: p.user_id().map(|u| u.to_string()),
448            }
449        })
450    }
451}
452
/// Basic resource usage figures for one process, or for a process tree when
/// produced by `Procs::get_batch_group_stats`.
#[derive(Debug, Clone, Copy)]
pub struct ProcessStats {
    /// CPU usage as reported by sysinfo's `cpu_usage()`; summed over the tree
    /// for group stats.
    pub cpu_percent: f32,
    /// Memory as reported by sysinfo's `memory()`; summed over the tree for
    /// group stats.
    pub memory_bytes: u64,
    /// Seconds since the (root) process start time.
    pub uptime_secs: u64,
    /// sysinfo `disk_usage().read_bytes`.
    /// NOTE(review): displayed with a "/s" suffix — confirm the refresh
    /// interval makes this a per-second figure.
    pub disk_read_bytes: u64,
    /// sysinfo `disk_usage().written_bytes` (same caveat as reads).
    pub disk_write_bytes: u64,
}
461
462impl ProcessStats {
463    pub fn memory_display(&self) -> String {
464        format_bytes(self.memory_bytes)
465    }
466
467    pub fn cpu_display(&self) -> String {
468        format!("{:.1}%", self.cpu_percent)
469    }
470
471    pub fn uptime_display(&self) -> String {
472        format_duration(self.uptime_secs)
473    }
474
475    pub fn disk_read_display(&self) -> String {
476        format_bytes_per_sec(self.disk_read_bytes)
477    }
478
479    pub fn disk_write_display(&self) -> String {
480        format_bytes_per_sec(self.disk_write_bytes)
481    }
482}
483
/// Detailed per-process information, as filled in by
/// `Procs::get_extended_stats`.
#[derive(Debug, Clone)]
pub struct ExtendedProcessStats {
    /// Process name (lossy UTF-8).
    pub name: String,
    /// Executable path, when sysinfo provides one (lossy UTF-8).
    pub exe_path: Option<String>,
    /// Current working directory, when sysinfo provides one (lossy UTF-8).
    pub cwd: Option<String>,
    /// Environment entries (lossy UTF-8); `get_extended_stats` keeps at most
    /// the first 20.
    pub environ: Vec<String>,
    /// `Debug` rendering of the sysinfo process status.
    pub status: String,
    /// CPU usage as reported by sysinfo's `cpu_usage()`.
    pub cpu_percent: f32,
    /// Memory as reported by sysinfo's `memory()`.
    pub memory_bytes: u64,
    /// Virtual memory as reported by sysinfo's `virtual_memory()`.
    pub virtual_memory_bytes: u64,
    /// Seconds since `start_time`.
    pub uptime_secs: u64,
    /// Process start time in seconds since the Unix epoch.
    pub start_time: u64,
    /// sysinfo `disk_usage().read_bytes`.
    pub disk_read_bytes: u64,
    /// sysinfo `disk_usage().written_bytes`.
    pub disk_write_bytes: u64,
    /// Parent PID, when known.
    pub parent_pid: Option<u32>,
    /// Number of tasks sysinfo lists for the process; 0 when it lists none.
    pub thread_count: usize,
    /// String form of the owning user id, when known.
    pub user_id: Option<String>,
}
503
504impl ExtendedProcessStats {
505    pub fn memory_display(&self) -> String {
506        format_bytes(self.memory_bytes)
507    }
508
509    pub fn virtual_memory_display(&self) -> String {
510        format_bytes(self.virtual_memory_bytes)
511    }
512
513    pub fn cpu_display(&self) -> String {
514        format!("{:.1}%", self.cpu_percent)
515    }
516
517    pub fn uptime_display(&self) -> String {
518        format_duration(self.uptime_secs)
519    }
520
521    pub fn start_time_display(&self) -> String {
522        use std::time::{Duration, UNIX_EPOCH};
523        let datetime = UNIX_EPOCH + Duration::from_secs(self.start_time);
524        chrono::DateTime::<chrono::Local>::from(datetime)
525            .format("%Y-%m-%d %H:%M:%S")
526            .to_string()
527    }
528
529    pub fn disk_read_display(&self) -> String {
530        format_bytes_per_sec(self.disk_read_bytes)
531    }
532
533    pub fn disk_write_display(&self) -> String {
534        format_bytes_per_sec(self.disk_write_bytes)
535    }
536}
537
/// Formats a byte count with binary (1024-based) units: `B` as an integer,
/// then `KB`, `MB`, `GB` with one decimal place. `GB` is the largest unit.
///
/// The unit is chosen from the value as it will be *printed*: a count just
/// below a unit boundary that would round up to `1024.0` is shown in the next
/// unit instead (1_048_575 bytes is `1.0MB`, not `1024.0KB`).
fn format_bytes(bytes: u64) -> String {
    const UNITS: [&str; 3] = ["KB", "MB", "GB"];

    if bytes < 1024 {
        return format!("{bytes}B");
    }

    let mut value = bytes as f64 / 1024.0;
    let mut unit = 0;
    // Anything >= 1023.95 prints as "1024.0" with one decimal, so promote it.
    // Dividing by 1024 is exact in f64, so repeated division equals dividing
    // by the unit size in one step.
    while value >= 1023.95 && unit + 1 < UNITS.len() {
        value /= 1024.0;
        unit += 1;
    }
    format!("{value:.1}{}", UNITS[unit])
}
549
/// Formats a number of seconds using the two most significant units:
/// `Ns`, `Nm Ns`, `Nh Nm`, or `Nd Nh`.
fn format_duration(secs: u64) -> String {
    const MINUTE: u64 = 60;
    const HOUR: u64 = 60 * MINUTE;
    const DAY: u64 = 24 * HOUR;

    match secs {
        s if s < MINUTE => format!("{s}s"),
        s if s < HOUR => format!("{}m {}s", s / MINUTE, s % MINUTE),
        s if s < DAY => {
            let (hours, mins) = (s / HOUR, (s % HOUR) / MINUTE);
            format!("{hours}h {mins}m")
        }
        s => {
            let (days, hours) = (s / DAY, (s % DAY) / HOUR);
            format!("{days}d {hours}h")
        }
    }
}
565
/// Formats a byte rate with binary (1024-based) units and a `/s` suffix:
/// `B/s` as an integer, then `KB/s`, `MB/s`, `GB/s` with one decimal place.
///
/// The unit is chosen from the value as it will be *printed*: a rate just
/// below a unit boundary that would round up to `1024.0` is shown in the next
/// unit instead (1_048_575 is `1.0MB/s`, not `1024.0KB/s`).
fn format_bytes_per_sec(bytes: u64) -> String {
    const UNITS: [&str; 3] = ["KB/s", "MB/s", "GB/s"];

    if bytes < 1024 {
        return format!("{bytes}B/s");
    }

    let mut value = bytes as f64 / 1024.0;
    let mut unit = 0;
    // Anything >= 1023.95 prints as "1024.0" with one decimal, so promote it.
    // Dividing by 1024 is exact in f64, so repeated division equals dividing
    // by the unit size in one step.
    while value >= 1023.95 && unit + 1 < UNITS.len() {
        value /= 1024.0;
        unit += 1;
    }
    format!("{value:.1}{}", UNITS[unit])
}
577
578#[cfg(unix)]
579fn signal_name(sig: i32) -> &'static str {
580    match sig {
581        libc::SIGHUP => "SIGHUP",
582        libc::SIGINT => "SIGINT",
583        libc::SIGQUIT => "SIGQUIT",
584        libc::SIGTERM => "SIGTERM",
585        libc::SIGUSR1 => "SIGUSR1",
586        libc::SIGUSR2 => "SIGUSR2",
587        libc::SIGKILL => "SIGKILL",
588        _ => "UNKNOWN",
589    }
590}