pid_set/
lib.rs

1//! # PID Set Library
2//!
3//! `pid_set` is a library for managing and monitoring process identifiers (PIDs) using epoll on Linux.
4//! It allows for asynchronous notification when a process exits by leveraging epoll and pidfd (process file descriptors).
5//!
6//! ## Features
7//! - Create a `PidSet` to manage multiple PIDs.
8//! - Monitor process exits using epoll.
9//! - Handle system call errors gracefully with custom errors.
10//!
11//! ## Usage
12//! Add this to your `Cargo.toml`:
13//!
14//! ```toml
15//! [dependencies]
16//! pid_set = "0.1.0"
17//! ```
18//!
19//! ## Examples
20//! Here's how you can use `PidSet` to monitor a list of PIDs:
21//!
22//! ```rust
23//! use pid_set::{PidSet, PidSetError};
24//!
25//! fn main() -> Result<(), PidSetError> {
26//!    let mut cmd1 = std::process::Command::new("sleep");
27//!    cmd1.arg("0.1");
28//!    let mut cmd2 = std::process::Command::new("sleep");
29//!    cmd2.arg("0.2");
30//!
31//!    let pids = vec![cmd1.spawn().unwrap().id(), cmd2.spawn().unwrap().id()]; // Example PIDs
32//!     let mut pid_set = PidSet::new(pids);
33//!
34//!     // Wait for any PID to exit
35//!     pid_set.wait_any()?;
36//!
37//!     // Clean up
38//!     pid_set.close()?;
39//!     Ok(())
40//! }
41//! ```
42
43use std::{collections::HashMap, usize};
44
45use libc::{EPOLLIN, EPOLL_CTL_ADD, EPOLL_CTL_DEL};
46
/// Raw file descriptor number.
type FD = i32;
/// Process identifier; matches the `u32` returned by `std::process::Child::id`.
type PID = u32;

/// A map of process IDs (PIDs) to their associated file descriptors.
///
/// `PidSet::new` stores `0` as a placeholder descriptor for every PID; the
/// placeholder is replaced by the value returned from `register_pid` when the
/// epoll instance is initialised.
type FDPidsMap = HashMap<PID, FD>;

/// Manages a set of PIDs and their corresponding epoll file descriptors.
///
/// The epoll instance is created lazily: `epoll_fd` stays `None` until
/// `init_epoll` has run successfully.
#[derive(Debug)]
pub struct PidSet {
    /// Monitored PIDs and the descriptor associated with each of them.
    fd_pids: FDPidsMap,
    /// The epoll descriptor, once it has been created.
    epoll_fd: Option<FD>,
}
58
59/// Errors that can occur in the `PidSet`.
60#[derive(Debug, thiserror::Error)]
61pub enum PidSetError {
62    #[error("Error while creating epoll file instance:`{0}`")]
63    EpollCreate(std::io::Error),
64
65    #[error("Error on pidfd_open syscall for pid `{0}`: `{1}")]
66    PidFdOpenSyscall(u32, std::io::Error),
67
68    #[error("Error on epoll_ctl: `{0}")]
69    EpollCtl(std::io::Error),
70
71    #[error("Error on epoll_wait: `{0}")]
72    EpollWait(std::io::Error),
73
74    #[error("PID not found: `{0}")]
75    PidNotFound(u32),
76
77    #[error("Error while closing epoll file descriptor: `{0}")]
78    EpollClose(std::io::Error),
79}
80
81impl PidSet {
82    /// Creates a new `PidSet` with the specified PIDs.
83    ///
84    /// # Arguments
85    ///
86    /// * `pids` - An iterator over the PIDs to monitor.
87    pub fn new<P: IntoIterator<Item = PID>>(pids: P) -> Self {
88        let fd_pids: FDPidsMap = pids.into_iter().map(|pid| (pid, 0)).collect();
89        Self {
90            fd_pids,
91            epoll_fd: None,
92        }
93    }
94
95    fn register_pid(epoll_fd: i32, pid: u32, token: u32) -> Result<FD, PidSetError> {
96        let cfd = unsafe { syscallerr(libc::syscall(libc::SYS_pidfd_open, pid, 0)) }
97            .map_err(|err| PidSetError::PidFdOpenSyscall(pid, err))?;
98        // use pid as token
99        unsafe {
100            syserr(libc::epoll_ctl(
101                epoll_fd,
102                EPOLL_CTL_ADD,
103                cfd as i32,
104                &mut libc::epoll_event {
105                    events: EPOLLIN as u32,
106                    u64: token as u64,
107                } as *mut _ as *mut libc::epoll_event,
108            ))
109        }
110        .map_err(PidSetError::EpollCtl)?;
111        Ok(cfd as i32)
112    }
113
114    fn deregister_pid(epoll_fd: i32, fd: i32) -> Result<(), PidSetError> {
115        let _ = unsafe {
116            syserr(libc::epoll_ctl(
117                epoll_fd,
118                EPOLL_CTL_DEL,
119                fd,
120                std::ptr::null_mut(),
121            ))
122        }
123        .map_err(PidSetError::EpollWait)?;
124        Ok(())
125    }
126
127    fn init_epoll(&mut self) -> Result<FD, PidSetError> {
128        // EPOLL_CLOEXEC flag disabled
129        let epoll_fd =
130            unsafe { syserr(libc::epoll_create1(0)) }.map_err(PidSetError::EpollCreate)?;
131        for (pid, fd) in &mut self.fd_pids {
132            *fd = PidSet::register_pid(epoll_fd, *pid, *pid)?;
133        }
134
135        self.epoll_fd = Some(epoll_fd);
136        Ok(epoll_fd)
137    }
138}
139
/// Converts the return value of a libc call into an `io::Result`.
///
/// A negative `status_code` signals failure; the actual error is then read
/// from `errno` via `std::io::Error::last_os_error`, because the return value
/// itself (usually `-1`) is not an error code. Non-negative values are passed
/// through unchanged.
///
/// Must be called right after the failing call, before anything else can
/// overwrite `errno`.
fn syserr(status_code: std::os::raw::c_int) -> std::io::Result<std::os::raw::c_int> {
    if status_code < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(status_code)
    }
}
146
147fn syscallerr(status_code: libc::c_long) -> std::io::Result<libc::c_long> {
148    if status_code < 0 {
149        return Err(std::io::Error::last_os_error());
150    }
151    Ok(status_code)
152}
153
154impl PidSet {
155    /// Waits for a specified number of PIDs to exit, up to the total number monitored.
156    ///
157    /// # Arguments
158    ///
159    /// * `n` - The number of PID events to wait for.
160    ///
161    /// # Errors
162    ///
163    /// Returns `PidSetError` if an error occurs during epoll wait or if a PID is not found.
164    fn wait(&mut self, n: usize) -> Result<usize, PidSetError> {
165        let max_events = self.fd_pids.len();
166        let mut total_events: usize = 0;
167        let epoll_fd = self.epoll_fd.unwrap_or(self.init_epoll()?);
168        while total_events < n {
169            let mut events: Vec<libc::epoll_event> = Vec::with_capacity(max_events);
170            let event_count = syserr(unsafe {
171                libc::epoll_wait(epoll_fd, events.as_mut_ptr(), max_events as i32, -1)
172            })
173            .map_err(PidSetError::EpollWait)? as usize;
174            unsafe { events.set_len(event_count as usize) };
175            total_events += event_count;
176
177            for event in events {
178                let cdata = event.u64 as u32;
179                // TODO: return Error if event_count is -1
180                let fd = self
181                    .fd_pids
182                    .get(&cdata)
183                    .ok_or(PidSetError::PidNotFound(cdata))?;
184                PidSet::deregister_pid(epoll_fd, *fd)?;
185
186                // remove from hashmap
187                self.fd_pids.remove(&cdata);
188            }
189        }
190        Ok(total_events)
191    }
192
193    /// Waits for all PIDs to exit.
194    ///
195    /// # Errors
196    ///
197    /// Returns `PidSetError` if an error occurs during the wait.
198    pub fn wait_all(&mut self) -> Result<(), PidSetError> {
199        self.wait(self.fd_pids.len())?;
200        Ok(())
201    }
202
203    /// Waits for any one PID to exit.
204    ///
205    /// # Errors
206    ///
207    /// Returns `PidSetError` if an error occurs during the wait.
208    pub fn wait_any(&mut self) -> Result<(), PidSetError> {
209        self.wait(1)?;
210        Ok(())
211    }
212
213    /// Closes the epoll file descriptor and cleans up the `PidSet`.
214    ///
215    /// # Errors
216    ///
217    /// Returns `PidSetError` if an error occurs while closing the epoll file descriptor.
218    pub fn close(mut self) -> Result<(), PidSetError> {
219        let epoll_fd = self.epoll_fd.unwrap_or(self.init_epoll()?);
220        unsafe { syserr(libc::close(epoll_fd)) }.map_err(PidSetError::EpollClose)?;
221        Ok(())
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::process::{Child, Command};
    use std::time::{Duration, Instant};

    /// Spawns `sleep <duration>` and returns the handle so the test can reap it.
    fn spawn_sleep(duration: &str) -> Child {
        Command::new("sleep")
            .arg(duration)
            .spawn()
            .expect("failed to spawn `sleep`")
    }

    /// Spawns one `sleep` process per entry of `durations`.
    fn spawn_all(durations: &[&str]) -> Vec<Child> {
        durations.iter().copied().map(spawn_sleep).collect()
    }

    /// Kills and reaps every child so no process outlives the test.
    fn reap(children: Vec<Child>) {
        for mut child in children {
            // Both calls are best effort: the child may already have exited.
            let _ = child.kill();
            let _ = child.wait();
        }
    }

    #[test]
    fn wait_all() {
        let children = spawn_all(&["0.1", "0.2", "0.3", "0.4", "0.5"]);
        let mut pid_set = PidSet::new(children.iter().map(Child::id));

        let outcome = pid_set.wait_all();
        // Close before reaping so the PIDs are still valid while the set is torn down.
        let closed = pid_set.close();
        reap(children);

        assert!(outcome.is_ok(), "wait_all() failed: {:?}", outcome);
        assert!(closed.is_ok(), "close() failed: {:?}", closed);
    }

    #[test]
    fn wait_any() {
        let start_time = Instant::now(); // Start the timer

        let children = spawn_all(&["0.2", "3", "3", "3", "3"]);
        let mut pid_set = PidSet::new(children.iter().map(Child::id));

        let outcome = pid_set.wait_any();
        // Measure before the clean-up so it does not count towards the timing.
        let elapsed = start_time.elapsed();
        let closed = pid_set.close();
        reap(children);

        assert!(outcome.is_ok(), "wait_any() failed: {:?}", outcome);
        assert!(closed.is_ok(), "close() failed: {:?}", closed);
        assert!(
            elapsed < Duration::from_secs(3),
            "Expected wait_any() to return in less than 3 seconds, but it took {:?}",
            elapsed
        );
    }
}