1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
//! # PID Set Library
//!
//! `pid_set` is a library for managing and monitoring process identifiers (PIDs) using epoll on Linux.
//! It allows for asynchronous notification when a process exits by leveraging epoll and pidfd (process file descriptors).
//!
//! ## Features
//! - Create a `PidSet` to manage multiple PIDs.
//! - Monitor process exits using epoll.
//! - Handle system call errors gracefully with custom errors.
//!
//! ## Usage
//! Add this to your `Cargo.toml`:
//!
//! ```toml
//! [dependencies]
//! pid_set = "0.1.0"
//! ```
//!
//! ## Examples
//! Here's how you can use `PidSet` to monitor a list of PIDs:
//!
//! ```rust
//! use pid_set::{PidSet, PidSetError};
//!
//! fn main() -> Result<(), PidSetError> {
//!     let pids = vec![1234, 5678, 431, 9871, 2123]; // Example PIDs
//!     let mut pid_set = PidSet::new(pids)?;
//!
//!     // Wait for any PID to exit
//!     pid_set.wait_any()?;
//!
//!     // Clean up
//!     pid_set.close()?;
//!     Ok(())
//! }
//! ```

use std::{collections::HashMap, usize};

use libc::{EPOLLIN, EPOLL_CTL_ADD, EPOLL_CTL_DEL};

// Raw file descriptor, as passed to and returned from the libc calls in this file.
type FD = i32;
// Process identifier; `PidSet::new` also stores it as the epoll event token.
type PID = u32;

/// A map of process IDs (PIDs) to their associated file descriptors.
type FDPidsMap = HashMap<PID, FD>;

/// Manages a set of PIDs and their corresponding epoll file descriptors.
///
/// Each monitored PID is paired with the file descriptor returned by
/// `pidfd_open` for it, and every such descriptor is registered with a single
/// epoll instance.
pub struct PidSet {
    /// PID -> pidfd for every process that is still being monitored.
    fd_pids: FDPidsMap,
    /// The epoll instance (from `epoll_create1`) the pidfds are registered with.
    epoll_fd: FD,
}

/// Errors that can occur in the `PidSet`.
#[derive(Debug, thiserror::Error)]
pub enum PidSetError {
    #[error("Error while creating epoll file instance:`{0}`")]
    EpollCreate(std::io::Error),

    #[error("Error on pidfd_open syscall for pid `{0}`: `{1}")]
    PidFdOpenSyscall(u32, std::io::Error),

    #[error("Error on epoll_ctl: `{0}")]
    EpollCtl(std::io::Error),

    #[error("Error on epoll_wait: `{0}")]
    EpollWait(std::io::Error),

    #[error("PID not found: `{0}")]
    PidNotFound(u32),

    #[error("Error while closing epoll file descriptor: `{0}")]
    EpollClose(std::io::Error),
}

impl PidSet {
    /// Creates a new `PidSet` with the specified PIDs.
    ///
    /// # Arguments
    ///
    /// * `pids` - An iterator over the PIDs to monitor.
    ///
    /// # Errors
    ///
    /// Returns `PidSetError` if an error occurs while setting up epoll or pidfds.
    pub fn new<P: IntoIterator<Item = PID>>(pids: P) -> Result<Self, PidSetError> {
        // EPOLL_CLOEXEC flag disabled
        let epoll_fd =
            unsafe { syserr(libc::epoll_create1(0)) }.map_err(PidSetError::EpollCreate)?;
        let fd_pids: Result<FDPidsMap, PidSetError> = pids
            .into_iter()
            .map(|pid| {
                let cfd = unsafe { syscallerr(libc::syscall(libc::SYS_pidfd_open, pid, 0)) }
                    .map_err(|err| PidSetError::PidFdOpenSyscall(pid, err))?;
                // use pid as token
                unsafe {
                    syserr(libc::epoll_ctl(
                        epoll_fd,
                        EPOLL_CTL_ADD,
                        cfd as i32,
                        &mut libc::epoll_event {
                            events: EPOLLIN as u32,
                            u64: pid as u64,
                        } as *mut _ as *mut libc::epoll_event,
                    ))
                }
                .map_err(PidSetError::EpollCtl)?;
                Ok((pid, cfd as i32))
            })
            .collect();

        Ok(Self {
            fd_pids: fd_pids?,
            epoll_fd,
        })
    }
}

/// Converts the return value of a libc call into an `io::Result`.
///
/// A non-negative `status_code` is passed through unchanged. A negative one is
/// only a failure marker (the libc wrappers used in this file return `-1`), so
/// the actual error is taken from `errno` via `Error::last_os_error` rather
/// than from the return value itself. Call this immediately after the libc
/// call so `errno` has not been overwritten.
fn syserr(status_code: libc::c_int) -> std::io::Result<libc::c_int> {
    if status_code < 0 {
        return Err(std::io::Error::last_os_error());
    }
    Ok(status_code)
}

/// Turns the result of a raw `syscall` invocation into an `io::Result`.
///
/// Non-negative values are handed back untouched; any negative value is
/// replaced by the error currently described by `errno`.
fn syscallerr(status_code: libc::c_long) -> std::io::Result<libc::c_long> {
    match status_code {
        value if value >= 0 => Ok(value),
        _ => Err(std::io::Error::last_os_error()),
    }
}

impl PidSet {
    /// Waits until at least `n` process-exit events have been handled.
    ///
    /// Every PID reported by epoll is deregistered from the epoll instance,
    /// removed from the set, and has its pidfd closed.
    ///
    /// # Arguments
    ///
    /// * `n` - The number of PID events to wait for.
    ///
    /// # Returns
    ///
    /// The total number of events handled. This can exceed `n` when a single
    /// `epoll_wait` call reports several PIDs at once.
    ///
    /// # Errors
    ///
    /// * `PidSetError::EpollWait` if `epoll_wait` fails (an interrupted call is
    ///   retried instead of being reported).
    /// * `PidSetError::PidNotFound` if an event carries a PID that is not in the set.
    /// * `PidSetError::EpollCtl` if deregistering a pidfd fails.
    fn wait(&mut self, n: usize) -> Result<usize, PidSetError> {
        let max_events = self.fd_pids.len();
        // epoll_wait takes a C int; clamp instead of letting the cast wrap.
        let max_events_c = max_events.min(i32::MAX as usize) as i32;
        // One zero-initialised buffer, reused for every epoll_wait call.
        let mut events = vec![libc::epoll_event { events: 0, u64: 0 }; max_events];
        let mut total_events: usize = 0;

        while total_events < n {
            let ready = unsafe {
                libc::epoll_wait(self.epoll_fd, events.as_mut_ptr(), max_events_c, -1)
            };
            if ready < 0 {
                // errno is read directly here so that EINTR can be recognised.
                let err = std::io::Error::last_os_error();
                if err.kind() == std::io::ErrorKind::Interrupted {
                    // A signal arrived while blocking; nothing was lost, wait again.
                    continue;
                }
                return Err(PidSetError::EpollWait(err));
            }
            let event_count = ready as usize;
            total_events += event_count;

            for event in &events[..event_count] {
                // The token stored at registration time is the PID itself.
                let cdata = event.u64 as u32;
                let pidfd = *self
                    .fd_pids
                    .get(&cdata)
                    .ok_or(PidSetError::PidNotFound(cdata))?;
                unsafe {
                    syserr(libc::epoll_ctl(
                        self.epoll_fd,
                        EPOLL_CTL_DEL,
                        pidfd,
                        std::ptr::null_mut(),
                    ))
                }
                .map_err(PidSetError::EpollCtl)?;

                // remove from hashmap
                self.fd_pids.remove(&cdata);
                // The pidfd is no longer referenced anywhere; close it so it
                // does not leak. A close failure here is not actionable.
                unsafe { libc::close(pidfd) };
            }
        }
        Ok(total_events)
    }

    /// Waits for all PIDs to exit.
    ///
    /// Returns immediately when the set is empty.
    ///
    /// # Errors
    ///
    /// Returns `PidSetError` if an error occurs during the wait.
    pub fn wait_all(&mut self) -> Result<(), PidSetError> {
        self.wait(self.fd_pids.len())?;
        Ok(())
    }

    /// Waits for any one PID to exit.
    ///
    /// All PIDs reported by the same wake-up are removed from the set, not
    /// just the first one.
    ///
    /// # Errors
    ///
    /// Returns `PidSetError` if an error occurs during the wait.
    pub fn wait_any(&mut self) -> Result<(), PidSetError> {
        self.wait(1)?;
        Ok(())
    }

    /// Closes the epoll file descriptor and cleans up the `PidSet`.
    ///
    /// Every pidfd that is still being monitored is closed as well, so that
    /// consuming the set releases all descriptors it opened.
    ///
    /// # Errors
    ///
    /// Returns `PidSetError` if an error occurs while closing the epoll file descriptor.
    pub fn close(self) -> Result<(), PidSetError> {
        // Best-effort: a failing close on a pidfd cannot be expressed in the
        // error type and must not prevent closing the epoll descriptor.
        for &pidfd in self.fd_pids.values() {
            unsafe { libc::close(pidfd) };
        }
        unsafe { syserr(libc::close(self.epoll_fd)) }.map_err(PidSetError::EpollClose)?;
        Ok(())
    }
}