Skip to main content

pitchfork_cli/
procs.rs

1use crate::Result;
2use crate::settings::settings;
3use miette::IntoDiagnostic;
4use once_cell::sync::Lazy;
5use std::sync::Mutex;
6use sysinfo::ProcessesToUpdate;
7#[cfg(unix)]
8use sysinfo::Signal;
9
10pub struct Procs {
11    system: Mutex<sysinfo::System>,
12}
13
14pub static PROCS: Lazy<Procs> = Lazy::new(Procs::new);
15
16impl Default for Procs {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl Procs {
23    pub fn new() -> Self {
24        let procs = Self {
25            system: Mutex::new(sysinfo::System::new()),
26        };
27        procs.refresh_processes();
28        procs
29    }
30
31    fn lock_system(&self) -> std::sync::MutexGuard<'_, sysinfo::System> {
32        self.system.lock().unwrap_or_else(|poisoned| {
33            warn!("System mutex was poisoned, recovering");
34            poisoned.into_inner()
35        })
36    }
37
38    pub fn title(&self, pid: u32) -> Option<String> {
39        self.lock_system()
40            .process(sysinfo::Pid::from_u32(pid))
41            .map(|p| p.name().to_string_lossy().to_string())
42    }
43
44    pub fn is_running(&self, pid: u32) -> bool {
45        self.lock_system()
46            .process(sysinfo::Pid::from_u32(pid))
47            .is_some()
48    }
49
50    /// Walk the /proc tree to find all descendant PIDs.
51    /// Kept for diagnostics/status display; no longer used in the kill path.
52    #[allow(dead_code)]
53    pub fn all_children(&self, pid: u32) -> Vec<u32> {
54        let system = self.lock_system();
55        let all = system.processes();
56        let mut children = vec![];
57        for (child_pid, process) in all {
58            let mut process = process;
59            while let Some(parent) = process.parent() {
60                if parent == sysinfo::Pid::from_u32(pid) {
61                    children.push(child_pid.as_u32());
62                    break;
63                }
64                match system.process(parent) {
65                    Some(p) => process = p,
66                    None => break,
67                }
68            }
69        }
70        children
71    }
72
73    pub async fn kill_process_group_async(&self, pid: u32) -> Result<bool> {
74        let result = tokio::task::spawn_blocking(move || PROCS.kill_process_group(pid))
75            .await
76            .into_diagnostic()?;
77        Ok(result)
78    }
79
80    /// Kill an entire process group with graceful shutdown strategy:
81    /// 1. Send SIGTERM to the process group (-pgid) and wait up to ~3s
82    /// 2. If any processes remain, send SIGKILL to the group
83    ///
84    /// Since daemons are spawned with setsid(), the daemon PID == PGID,
85    /// so this atomically signals all descendant processes.
86    #[cfg(unix)]
87    fn kill_process_group(&self, pid: u32) -> bool {
88        let pgid = pid as i32;
89
90        debug!("killing process group {pgid}");
91
92        // Send SIGTERM to the entire process group.
93        // killpg sends to all processes in the group atomically.
94        // We intentionally skip the zombie check here because the leader may be
95        // a zombie while children in the group are still running.
96        let ret = unsafe { libc::killpg(pgid, libc::SIGTERM) };
97        if ret == -1 {
98            let err = std::io::Error::last_os_error();
99            if err.raw_os_error() == Some(libc::ESRCH) {
100                debug!("process group {pgid} no longer exists");
101                return false;
102            }
103            warn!("failed to send SIGTERM to process group {pgid}: {err}");
104        }
105
106        // Wait for graceful shutdown: fast initial check then slower polling.
107        // Use the configurable stop_timeout setting to govern total wait time.
108        let stop_timeout = settings().supervisor_stop_timeout();
109        let fast_ms = 10u64;
110        let slow_ms = 50u64;
111        let total_ms = stop_timeout.as_millis().max(1) as u64;
112        let fast_count = ((total_ms / fast_ms) as usize).min(10);
113        let fast_total_ms = fast_ms * fast_count as u64;
114        let remaining_ms = total_ms.saturating_sub(fast_total_ms);
115        let slow_count = (remaining_ms / slow_ms) as usize;
116
117        let fast_checks =
118            std::iter::repeat_n(std::time::Duration::from_millis(fast_ms), fast_count);
119        let slow_checks =
120            std::iter::repeat_n(std::time::Duration::from_millis(slow_ms), slow_count);
121        let mut elapsed_ms = 0u64;
122
123        for sleep_duration in fast_checks.chain(slow_checks) {
124            std::thread::sleep(sleep_duration);
125            self.refresh_pids(&[pid]);
126            elapsed_ms += sleep_duration.as_millis() as u64;
127            if self.is_terminated_or_zombie(sysinfo::Pid::from_u32(pid)) {
128                debug!("process group {pgid} terminated after SIGTERM ({elapsed_ms} ms)",);
129                return true;
130            }
131        }
132
133        // SIGKILL the entire process group as last resort
134        warn!(
135            "process group {pgid} did not respond to SIGTERM after {}ms, sending SIGKILL",
136            stop_timeout.as_millis()
137        );
138        unsafe {
139            libc::killpg(pgid, libc::SIGKILL);
140        }
141
142        // Brief wait for SIGKILL to take effect
143        std::thread::sleep(std::time::Duration::from_millis(100));
144        true
145    }
146
147    #[cfg(not(unix))]
148    fn kill_process_group(&self, pid: u32) -> bool {
149        // On non-unix platforms, fall back to single-process kill
150        self.kill(pid)
151    }
152
153    pub async fn kill_async(&self, pid: u32) -> Result<bool> {
154        let result = tokio::task::spawn_blocking(move || PROCS.kill(pid))
155            .await
156            .into_diagnostic()?;
157        Ok(result)
158    }
159
160    /// Kill a process with graceful shutdown strategy:
161    /// 1. Send SIGTERM and wait up to ~3s (10ms intervals for first 100ms, then 50ms intervals)
162    /// 2. If still running, send SIGKILL to force termination
163    ///
164    /// This ensures fast-exiting processes don't wait unnecessarily,
165    /// while stubborn processes eventually get forcefully terminated.
166    fn kill(&self, pid: u32) -> bool {
167        let sysinfo_pid = sysinfo::Pid::from_u32(pid);
168
169        // Check if process exists or is a zombie (already terminated but not reaped)
170        if self.is_terminated_or_zombie(sysinfo_pid) {
171            return false;
172        }
173
174        debug!("killing process {pid}");
175
176        #[cfg(windows)]
177        {
178            if let Some(process) = self.lock_system().process(sysinfo_pid) {
179                process.kill();
180                process.wait();
181            }
182            return true;
183        }
184
185        #[cfg(unix)]
186        {
187            // Send SIGTERM for graceful shutdown
188            if let Some(process) = self.lock_system().process(sysinfo_pid) {
189                debug!("sending SIGTERM to process {pid}");
190                process.kill_with(Signal::Term);
191            }
192
193            // Fast check: 10ms intervals, then slower 50ms polling for stop_timeout.
194            // Cap fast_count so short timeouts (e.g. 50ms) are honoured correctly.
195            let stop_timeout = settings().supervisor_stop_timeout();
196            let fast_ms = 10u64;
197            let slow_ms = 50u64;
198            let total_ms = stop_timeout.as_millis().max(1) as u64;
199            let fast_count = ((total_ms / fast_ms) as usize).min(10);
200            let fast_total_ms = fast_ms * fast_count as u64;
201            let remaining_ms = total_ms.saturating_sub(fast_total_ms);
202            let slow_count = (remaining_ms / slow_ms) as usize;
203
204            for i in 0..fast_count {
205                std::thread::sleep(std::time::Duration::from_millis(fast_ms));
206                self.refresh_pids(&[pid]);
207                if self.is_terminated_or_zombie(sysinfo_pid) {
208                    debug!(
209                        "process {pid} terminated after SIGTERM ({} ms)",
210                        (i + 1) * fast_ms as usize
211                    );
212                    return true;
213                }
214            }
215
216            // Slower check: 50ms intervals for the remainder of stop_timeout
217            for i in 0..slow_count {
218                std::thread::sleep(std::time::Duration::from_millis(slow_ms));
219                self.refresh_pids(&[pid]);
220                if self.is_terminated_or_zombie(sysinfo_pid) {
221                    debug!(
222                        "process {pid} terminated after SIGTERM ({} ms)",
223                        fast_total_ms + (i + 1) as u64 * slow_ms
224                    );
225                    return true;
226                }
227            }
228
229            // SIGKILL as last resort after stop_timeout
230            if let Some(process) = self.lock_system().process(sysinfo_pid) {
231                warn!(
232                    "process {pid} did not respond to SIGTERM after {}ms, sending SIGKILL",
233                    stop_timeout.as_millis()
234                );
235                process.kill_with(Signal::Kill);
236                process.wait();
237            }
238
239            true
240        }
241    }
242
243    /// Check if a process is terminated or is a zombie.
244    /// On Linux, zombie processes still have /proc/[pid] entries but are effectively dead.
245    /// This prevents unnecessary signal escalation for processes that have already exited.
246    fn is_terminated_or_zombie(&self, sysinfo_pid: sysinfo::Pid) -> bool {
247        let system = self.lock_system();
248        match system.process(sysinfo_pid) {
249            None => true,
250            Some(process) => {
251                #[cfg(unix)]
252                {
253                    matches!(process.status(), sysinfo::ProcessStatus::Zombie)
254                }
255                #[cfg(not(unix))]
256                {
257                    let _ = process;
258                    false
259                }
260            }
261        }
262    }
263
264    pub(crate) fn refresh_processes(&self) {
265        self.lock_system()
266            .refresh_processes(ProcessesToUpdate::All, true);
267    }
268
269    /// Refresh only specific PIDs instead of all processes.
270    /// More efficient when you only need to check a small set of known PIDs.
271    pub(crate) fn refresh_pids(&self, pids: &[u32]) {
272        let sysinfo_pids: Vec<sysinfo::Pid> =
273            pids.iter().map(|p| sysinfo::Pid::from_u32(*p)).collect();
274        self.lock_system()
275            .refresh_processes(ProcessesToUpdate::Some(&sysinfo_pids), true);
276    }
277
278    /// Get aggregated stats for multiple process groups in a single pass.
279    ///
280    /// Builds the parent→children map once (O(N)) and then BFS-es from each
281    /// root PID (O(D_i) per daemon). Total cost is O(N + ΣD_i) instead of
282    /// O(D × N) when calling `get_group_stats` in a loop.
283    pub fn get_batch_group_stats(&self, pids: &[u32]) -> Vec<(u32, Option<ProcessStats>)> {
284        let system = self.lock_system();
285        let processes = system.processes();
286
287        let now = std::time::SystemTime::now()
288            .duration_since(std::time::UNIX_EPOCH)
289            .map(|d| d.as_secs())
290            .unwrap_or(0);
291
292        // Build parent → children map once for all daemons
293        let mut children_map: std::collections::HashMap<sysinfo::Pid, Vec<sysinfo::Pid>> =
294            std::collections::HashMap::new();
295        for (child_pid, child) in processes {
296            if let Some(ppid) = child.parent() {
297                children_map.entry(ppid).or_default().push(*child_pid);
298            }
299        }
300
301        pids.iter()
302            .map(|&pid| {
303                let root_pid = sysinfo::Pid::from_u32(pid);
304                let Some(root) = processes.get(&root_pid) else {
305                    return (pid, None);
306                };
307
308                let root_disk = root.disk_usage();
309                let mut stats = ProcessStats {
310                    cpu_percent: root.cpu_usage(),
311                    memory_bytes: root.memory(),
312                    uptime_secs: now.saturating_sub(root.start_time()),
313                    disk_read_bytes: root_disk.read_bytes,
314                    disk_write_bytes: root_disk.written_bytes,
315                };
316
317                // BFS from root_pid to find all descendants
318                let mut queue = std::collections::VecDeque::new();
319                if let Some(direct_children) = children_map.get(&root_pid) {
320                    queue.extend(direct_children);
321                }
322                while let Some(child_pid) = queue.pop_front() {
323                    if let Some(child) = processes.get(&child_pid) {
324                        let disk = child.disk_usage();
325                        stats.cpu_percent += child.cpu_usage();
326                        stats.memory_bytes += child.memory();
327                        stats.disk_read_bytes += disk.read_bytes;
328                        stats.disk_write_bytes += disk.written_bytes;
329                    }
330                    if let Some(grandchildren) = children_map.get(&child_pid) {
331                        queue.extend(grandchildren);
332                    }
333                }
334
335                (pid, Some(stats))
336            })
337            .collect()
338    }
339
340    /// Get process stats (cpu%, memory bytes, uptime secs, disk I/O) for a given PID
341    pub fn get_stats(&self, pid: u32) -> Option<ProcessStats> {
342        let system = self.lock_system();
343        system.process(sysinfo::Pid::from_u32(pid)).map(|p| {
344            let now = std::time::SystemTime::now()
345                .duration_since(std::time::UNIX_EPOCH)
346                .map(|d| d.as_secs())
347                .unwrap_or(0);
348            let disk = p.disk_usage();
349            ProcessStats {
350                cpu_percent: p.cpu_usage(),
351                memory_bytes: p.memory(),
352                uptime_secs: now.saturating_sub(p.start_time()),
353                disk_read_bytes: disk.read_bytes,
354                disk_write_bytes: disk.written_bytes,
355            }
356        })
357    }
358
359    /// Get extended process information for a given PID
360    pub fn get_extended_stats(&self, pid: u32) -> Option<ExtendedProcessStats> {
361        let system = self.lock_system();
362        system.process(sysinfo::Pid::from_u32(pid)).map(|p| {
363            let now = std::time::SystemTime::now()
364                .duration_since(std::time::UNIX_EPOCH)
365                .map(|d| d.as_secs())
366                .unwrap_or(0);
367            let disk = p.disk_usage();
368
369            ExtendedProcessStats {
370                name: p.name().to_string_lossy().to_string(),
371                exe_path: p.exe().map(|e| e.to_string_lossy().to_string()),
372                cwd: p.cwd().map(|c| c.to_string_lossy().to_string()),
373                environ: p
374                    .environ()
375                    .iter()
376                    .take(20)
377                    .map(|s| s.to_string_lossy().to_string())
378                    .collect(),
379                status: format!("{:?}", p.status()),
380                cpu_percent: p.cpu_usage(),
381                memory_bytes: p.memory(),
382                virtual_memory_bytes: p.virtual_memory(),
383                uptime_secs: now.saturating_sub(p.start_time()),
384                start_time: p.start_time(),
385                disk_read_bytes: disk.read_bytes,
386                disk_write_bytes: disk.written_bytes,
387                parent_pid: p.parent().map(|pp| pp.as_u32()),
388                thread_count: p.tasks().map(|t| t.len()).unwrap_or(0),
389                user_id: p.user_id().map(|u| u.to_string()),
390            }
391        })
392    }
393}
394
395#[derive(Debug, Clone, Copy)]
396pub struct ProcessStats {
397    pub cpu_percent: f32,
398    pub memory_bytes: u64,
399    pub uptime_secs: u64,
400    pub disk_read_bytes: u64,
401    pub disk_write_bytes: u64,
402}
403
404impl ProcessStats {
405    pub fn memory_display(&self) -> String {
406        format_bytes(self.memory_bytes)
407    }
408
409    pub fn cpu_display(&self) -> String {
410        format!("{:.1}%", self.cpu_percent)
411    }
412
413    pub fn uptime_display(&self) -> String {
414        format_duration(self.uptime_secs)
415    }
416
417    pub fn disk_read_display(&self) -> String {
418        format_bytes_per_sec(self.disk_read_bytes)
419    }
420
421    pub fn disk_write_display(&self) -> String {
422        format_bytes_per_sec(self.disk_write_bytes)
423    }
424}
425
426/// Extended process stats with more detailed information
427#[derive(Debug, Clone)]
428pub struct ExtendedProcessStats {
429    pub name: String,
430    pub exe_path: Option<String>,
431    pub cwd: Option<String>,
432    pub environ: Vec<String>,
433    pub status: String,
434    pub cpu_percent: f32,
435    pub memory_bytes: u64,
436    pub virtual_memory_bytes: u64,
437    pub uptime_secs: u64,
438    pub start_time: u64,
439    pub disk_read_bytes: u64,
440    pub disk_write_bytes: u64,
441    pub parent_pid: Option<u32>,
442    pub thread_count: usize,
443    pub user_id: Option<String>,
444}
445
446impl ExtendedProcessStats {
447    pub fn memory_display(&self) -> String {
448        format_bytes(self.memory_bytes)
449    }
450
451    pub fn virtual_memory_display(&self) -> String {
452        format_bytes(self.virtual_memory_bytes)
453    }
454
455    pub fn cpu_display(&self) -> String {
456        format!("{:.1}%", self.cpu_percent)
457    }
458
459    pub fn uptime_display(&self) -> String {
460        format_duration(self.uptime_secs)
461    }
462
463    pub fn start_time_display(&self) -> String {
464        use std::time::{Duration, UNIX_EPOCH};
465        let datetime = UNIX_EPOCH + Duration::from_secs(self.start_time);
466        chrono::DateTime::<chrono::Local>::from(datetime)
467            .format("%Y-%m-%d %H:%M:%S")
468            .to_string()
469    }
470
471    pub fn disk_read_display(&self) -> String {
472        format_bytes_per_sec(self.disk_read_bytes)
473    }
474
475    pub fn disk_write_display(&self) -> String {
476        format_bytes_per_sec(self.disk_write_bytes)
477    }
478}
479
/// Format a byte count with binary (1024-based) units,
/// e.g. `"512B"`, `"1.5KB"`, `"2.0GB"`, `"1.0TB"`.
fn format_bytes(bytes: u64) -> String {
    format_scaled_bytes(bytes, "")
}

/// Format a duration given in seconds using its two most significant units,
/// e.g. `"45s"`, `"3m 20s"`, `"5h 12m"`, `"2d 4h"`.
fn format_duration(secs: u64) -> String {
    const MINUTE: u64 = 60;
    const HOUR: u64 = 60 * MINUTE;
    const DAY: u64 = 24 * HOUR;

    if secs < MINUTE {
        format!("{secs}s")
    } else if secs < HOUR {
        format!("{}m {}s", secs / MINUTE, secs % MINUTE)
    } else if secs < DAY {
        format!("{}h {}m", secs / HOUR, (secs % HOUR) / MINUTE)
    } else {
        format!("{}d {}h", secs / DAY, (secs % DAY) / HOUR)
    }
}

/// Format a byte rate with binary units and a `/s` suffix,
/// e.g. `"512B/s"`, `"1.5KB/s"`.
fn format_bytes_per_sec(bytes: u64) -> String {
    format_scaled_bytes(bytes, "/s")
}

/// Shared renderer for [`format_bytes`] and [`format_bytes_per_sec`].
///
/// Values below 1 KiB are printed as an exact integer; larger values use one
/// decimal place in the largest unit up to TB. The unit is chosen so the
/// *rounded* value stays below 1024, so e.g. 1_048_575 bytes prints as
/// `"1.0MB"` rather than `"1024.0KB"`.
fn format_scaled_bytes(bytes: u64, suffix: &str) -> String {
    const UNITS: [&str; 4] = ["KB", "MB", "GB", "TB"];

    if bytes < 1024 {
        return format!("{bytes}B{suffix}");
    }

    let mut value = bytes as f64 / 1024.0;
    let mut unit = 0;
    // Promote while the one-decimal rendering would read 1024.0 or more.
    while unit + 1 < UNITS.len() && (value * 10.0).round() >= 10240.0 {
        value /= 1024.0;
        unit += 1;
    }
    format!("{value:.1}{}{suffix}", UNITS[unit])
}