Skip to main content

cuenv_core/tasks/
process_registry.rs

1//! Global process registry for tracking and terminating spawned child processes.
2//!
3//! This module provides a centralized registry for tracking PIDs of child processes
4//! spawned during task execution. When the application receives a termination signal
5//! (e.g., Ctrl-C), the registry can terminate all tracked processes and their children.
6//!
7//! # Process Groups
8//!
9//! On Unix systems, spawned processes are placed in their own process groups using
10//! `setpgid(0, 0)`. This allows terminating the entire process tree (including any
11//! child processes spawned by the task) by sending signals to the process group.
12//!
13//! # Usage
14//!
15//! ```ignore
16//! use cuenv_core::tasks::process_registry::global_registry;
17//!
18//! // Register a process after spawning
19//! if let Some(pid) = child.id() {
20//!     global_registry().register(pid, "task_name".to_string()).await;
21//! }
22//!
23//! // Unregister when process completes
24//! global_registry().unregister(pid).await;
25//!
26//! // Terminate all on shutdown
27//! global_registry().terminate_all(Duration::from_secs(5)).await;
28//! ```
29
30use std::collections::HashMap;
31use std::sync::{Arc, OnceLock};
32use std::time::Duration;
33use tokio::sync::Mutex;
34use tracing::{debug, info, warn};
35
36/// Global process registry singleton.
37static GLOBAL_REGISTRY: OnceLock<Arc<ProcessRegistry>> = OnceLock::new();
38
39/// Returns the global process registry instance.
40///
41/// The registry is lazily initialized on first access and shared across
42/// the entire application.
43#[must_use]
44pub fn global_registry() -> Arc<ProcessRegistry> {
45    GLOBAL_REGISTRY
46        .get_or_init(|| Arc::new(ProcessRegistry::new()))
47        .clone()
48}
49
50/// Registry for tracking spawned child processes.
51///
52/// Maintains a mapping of PIDs to task names, allowing for graceful shutdown
53/// of all child processes when the application is terminated.
54pub struct ProcessRegistry {
55    /// Map of process IDs to task names for debugging/logging.
56    pids: Mutex<HashMap<u32, String>>,
57}
58
59impl ProcessRegistry {
60    /// Creates a new empty process registry.
61    #[must_use]
62    pub fn new() -> Self {
63        Self {
64            pids: Mutex::new(HashMap::new()),
65        }
66    }
67
68    /// Registers a process with the given PID and task name.
69    ///
70    /// Call this immediately after spawning a child process.
71    pub async fn register(&self, pid: u32, task_name: String) {
72        let mut pids = self.pids.lock().await;
73        debug!(pid, task = %task_name, "Registering process");
74        pids.insert(pid, task_name);
75    }
76
77    /// Unregisters a process after it has completed.
78    ///
79    /// Call this after successfully waiting for the process to exit.
80    pub async fn unregister(&self, pid: u32) {
81        let mut pids = self.pids.lock().await;
82        if let Some(task_name) = pids.remove(&pid) {
83            debug!(pid, task = %task_name, "Unregistering process");
84        }
85    }
86
87    /// Returns the number of currently tracked processes.
88    pub async fn count(&self) -> usize {
89        self.pids.lock().await.len()
90    }
91
92    /// Terminates all registered processes gracefully.
93    ///
94    /// This method:
95    /// 1. Sends SIGTERM to all process groups (allowing graceful shutdown)
96    /// 2. Waits up to `timeout` for processes to exit
97    /// 3. Sends SIGKILL to any remaining processes
98    ///
99    /// On Unix, signals are sent to the entire process group (-pid) to ensure
100    /// child processes spawned by tasks are also terminated.
101    pub async fn terminate_all(&self, timeout: Duration) {
102        let mut pids = self.pids.lock().await;
103
104        if pids.is_empty() {
105            return;
106        }
107
108        info!(count = pids.len(), "Terminating child processes");
109
110        // Phase 1: Send SIGTERM to all process groups
111        for (pid, task_name) in pids.iter() {
112            debug!(pid, task = %task_name, "Sending SIGTERM");
113            Self::send_term_signal(*pid);
114        }
115
116        // Phase 2: Wait for processes to exit (with timeout)
117        let deadline = std::time::Instant::now() + timeout;
118        while !pids.is_empty() && std::time::Instant::now() < deadline {
119            // Check which processes have exited
120            let mut exited = Vec::new();
121            for (pid, _) in pids.iter() {
122                if !Self::is_process_alive(*pid) {
123                    exited.push(*pid);
124                }
125            }
126
127            // Remove exited processes
128            for pid in exited {
129                if let Some(task_name) = pids.remove(&pid) {
130                    debug!(pid, task = %task_name, "Process exited gracefully");
131                }
132            }
133
134            if !pids.is_empty() {
135                // Short sleep before checking again
136                tokio::time::sleep(Duration::from_millis(100)).await;
137            }
138        }
139
140        // Phase 3: Force kill any remaining processes
141        for (pid, task_name) in pids.drain() {
142            warn!(pid, task = %task_name, "Force killing process after timeout");
143            Self::send_kill_signal(pid);
144        }
145    }
146
147    /// Sends SIGTERM to a process group (Unix) or terminates process (Windows).
148    #[cfg(unix)]
149    fn send_term_signal(pid: u32) {
150        // SAFETY: libc::kill with negative pid sends signal to entire process group.
151        // The pid was obtained from a spawned child process and is valid.
152        // SIGTERM is a safe signal that requests graceful termination.
153        #[expect(unsafe_code, reason = "Required for POSIX signal handling")]
154        unsafe {
155            // Use negative PID to send to entire process group
156            libc::kill(-(pid as i32), libc::SIGTERM);
157        }
158    }
159
160    /// Sends SIGKILL to a process group (Unix) or terminates process (Windows).
161    #[cfg(unix)]
162    fn send_kill_signal(pid: u32) {
163        // SAFETY: libc::kill with negative pid sends signal to entire process group.
164        // The pid was obtained from a spawned child process and is valid.
165        // SIGKILL forces immediate termination.
166        #[expect(unsafe_code, reason = "Required for POSIX signal handling")]
167        unsafe {
168            // Use negative PID to kill entire process group
169            libc::kill(-(pid as i32), libc::SIGKILL);
170        }
171    }
172
173    /// Checks if a process is still alive.
174    #[cfg(unix)]
175    fn is_process_alive(pid: u32) -> bool {
176        // SAFETY: libc::kill with signal 0 checks if process exists without sending a signal.
177        // This is a standard POSIX idiom for checking process existence.
178        #[expect(unsafe_code, reason = "Required for POSIX process existence check")]
179        unsafe {
180            libc::kill(pid as i32, 0) == 0
181        }
182    }
183
184    /// Windows implementation: terminate process using sysinfo crate.
185    #[cfg(windows)]
186    fn send_term_signal(pid: u32) {
187        use sysinfo::{Pid, ProcessRefreshKind, ProcessesToUpdate, Signal, System};
188
189        let mut system = System::new();
190        let process_pid = Pid::from(pid as usize);
191        system.refresh_processes_specifics(
192            ProcessesToUpdate::Some(&[process_pid]),
193            false,
194            ProcessRefreshKind::nothing(),
195        );
196
197        if let Some(process) = system.process(process_pid) {
198            let _ = process.kill_with(Signal::Term);
199        }
200    }
201
202    /// Windows implementation: force kill process.
203    #[cfg(windows)]
204    fn send_kill_signal(pid: u32) {
205        use sysinfo::{Pid, ProcessRefreshKind, ProcessesToUpdate, Signal, System};
206
207        let mut system = System::new();
208        let process_pid = Pid::from(pid as usize);
209        system.refresh_processes_specifics(
210            ProcessesToUpdate::Some(&[process_pid]),
211            false,
212            ProcessRefreshKind::nothing(),
213        );
214
215        if let Some(process) = system.process(process_pid) {
216            let _ = process.kill_with(Signal::Kill);
217        }
218    }
219
220    /// Windows implementation: check if process is alive.
221    #[cfg(windows)]
222    fn is_process_alive(pid: u32) -> bool {
223        use sysinfo::{Pid, ProcessRefreshKind, ProcessesToUpdate, System};
224
225        let mut system = System::new();
226        let process_pid = Pid::from(pid as usize);
227        system.refresh_processes_specifics(
228            ProcessesToUpdate::Some(&[process_pid]),
229            false,
230            ProcessRefreshKind::nothing(),
231        );
232
233        system.process(process_pid).is_some()
234    }
235}
236
237impl Default for ProcessRegistry {
238    fn default() -> Self {
239        Self::new()
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed registry tracks nothing.
    #[tokio::test]
    async fn test_registry_new() {
        let tracked = ProcessRegistry::new().count().await;
        assert_eq!(tracked, 0);
    }

    /// Registering adds one entry; unregistering the same PID removes it.
    #[tokio::test]
    async fn test_register_unregister() {
        let reg = ProcessRegistry::new();
        let pid = 1234;

        reg.register(pid, "test_task".to_string()).await;
        assert_eq!(reg.count().await, 1);

        reg.unregister(pid).await;
        assert_eq!(reg.count().await, 0);
    }

    /// Unregistering a PID that was never registered must be a harmless no-op.
    #[tokio::test]
    async fn test_unregister_nonexistent() {
        let reg = ProcessRegistry::new();

        reg.unregister(9999).await;

        assert_eq!(reg.count().await, 0);
    }

    /// With nothing registered, `terminate_all` returns without doing any work.
    #[tokio::test]
    async fn test_terminate_empty() {
        let reg = ProcessRegistry::new();

        reg.terminate_all(Duration::from_secs(1)).await;

        assert_eq!(reg.count().await, 0);
    }

    /// Repeated calls to `global_registry` yield handles to one shared instance.
    #[tokio::test]
    async fn test_global_registry_singleton() {
        let first = global_registry();
        let second = global_registry();

        assert!(Arc::ptr_eq(&first, &second));
    }
}