1use std::{collections::HashMap, usize};
44
45use libc::{EPOLLIN, EPOLL_CTL_ADD, EPOLL_CTL_DEL};
46
/// Raw file descriptor (a pidfd or the epoll instance).
type FD = i32;
/// Process ID, the same `u32` that `std::process::Child::id` returns.
type PID = u32;

/// Maps each tracked PID to the pidfd opened for it.
/// Note: despite the name, the key is the PID and the value is the fd.
type FDPidsMap = HashMap<PID, FD>;
52
53pub struct PidSet {
55 fd_pids: FDPidsMap,
56 epoll_fd: Option<FD>,
57}
58
59#[derive(Debug, thiserror::Error)]
61pub enum PidSetError {
62 #[error("Error while creating epoll file instance:`{0}`")]
63 EpollCreate(std::io::Error),
64
65 #[error("Error on pidfd_open syscall for pid `{0}`: `{1}")]
66 PidFdOpenSyscall(u32, std::io::Error),
67
68 #[error("Error on epoll_ctl: `{0}")]
69 EpollCtl(std::io::Error),
70
71 #[error("Error on epoll_wait: `{0}")]
72 EpollWait(std::io::Error),
73
74 #[error("PID not found: `{0}")]
75 PidNotFound(u32),
76
77 #[error("Error while closing epoll file descriptor: `{0}")]
78 EpollClose(std::io::Error),
79}
80
81impl PidSet {
82 pub fn new<P: IntoIterator<Item = PID>>(pids: P) -> Self {
88 let fd_pids: FDPidsMap = pids.into_iter().map(|pid| (pid, 0)).collect();
89 Self {
90 fd_pids,
91 epoll_fd: None,
92 }
93 }
94
95 fn register_pid(epoll_fd: i32, pid: u32, token: u32) -> Result<FD, PidSetError> {
96 let cfd = unsafe { syscallerr(libc::syscall(libc::SYS_pidfd_open, pid, 0)) }
97 .map_err(|err| PidSetError::PidFdOpenSyscall(pid, err))?;
98 unsafe {
100 syserr(libc::epoll_ctl(
101 epoll_fd,
102 EPOLL_CTL_ADD,
103 cfd as i32,
104 &mut libc::epoll_event {
105 events: EPOLLIN as u32,
106 u64: token as u64,
107 } as *mut _ as *mut libc::epoll_event,
108 ))
109 }
110 .map_err(PidSetError::EpollCtl)?;
111 Ok(cfd as i32)
112 }
113
114 fn deregister_pid(epoll_fd: i32, fd: i32) -> Result<(), PidSetError> {
115 let _ = unsafe {
116 syserr(libc::epoll_ctl(
117 epoll_fd,
118 EPOLL_CTL_DEL,
119 fd,
120 std::ptr::null_mut(),
121 ))
122 }
123 .map_err(PidSetError::EpollWait)?;
124 Ok(())
125 }
126
127 fn init_epoll(&mut self) -> Result<FD, PidSetError> {
128 let epoll_fd =
130 unsafe { syserr(libc::epoll_create1(0)) }.map_err(PidSetError::EpollCreate)?;
131 for (pid, fd) in &mut self.fd_pids {
132 *fd = PidSet::register_pid(epoll_fd, *pid, *pid)?;
133 }
134
135 self.epoll_fd = Some(epoll_fd);
136 Ok(epoll_fd)
137 }
138}
139
140fn syserr(status_code: libc::c_int) -> std::io::Result<libc::c_int> {
141 if status_code < 0 {
142 return Err(std::io::Error::from_raw_os_error(status_code));
143 }
144 Ok(status_code)
145}
146
147fn syscallerr(status_code: libc::c_long) -> std::io::Result<libc::c_long> {
148 if status_code < 0 {
149 return Err(std::io::Error::last_os_error());
150 }
151 Ok(status_code)
152}
153
154impl PidSet {
155 fn wait(&mut self, n: usize) -> Result<usize, PidSetError> {
165 let max_events = self.fd_pids.len();
166 let mut total_events: usize = 0;
167 let epoll_fd = self.epoll_fd.unwrap_or(self.init_epoll()?);
168 while total_events < n {
169 let mut events: Vec<libc::epoll_event> = Vec::with_capacity(max_events);
170 let event_count = syserr(unsafe {
171 libc::epoll_wait(epoll_fd, events.as_mut_ptr(), max_events as i32, -1)
172 })
173 .map_err(PidSetError::EpollWait)? as usize;
174 unsafe { events.set_len(event_count as usize) };
175 total_events += event_count;
176
177 for event in events {
178 let cdata = event.u64 as u32;
179 let fd = self
181 .fd_pids
182 .get(&cdata)
183 .ok_or(PidSetError::PidNotFound(cdata))?;
184 PidSet::deregister_pid(epoll_fd, *fd)?;
185
186 self.fd_pids.remove(&cdata);
188 }
189 }
190 Ok(total_events)
191 }
192
193 pub fn wait_all(&mut self) -> Result<(), PidSetError> {
199 self.wait(self.fd_pids.len())?;
200 Ok(())
201 }
202
203 pub fn wait_any(&mut self) -> Result<(), PidSetError> {
209 self.wait(1)?;
210 Ok(())
211 }
212
213 pub fn close(mut self) -> Result<(), PidSetError> {
219 let epoll_fd = self.epoll_fd.unwrap_or(self.init_epoll()?);
220 unsafe { syserr(libc::close(epoll_fd)) }.map_err(PidSetError::EpollClose)?;
221 Ok(())
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::process::{Child, Command};
    use std::time::{Duration, Instant};

    /// Builds (without spawning) a `sleep <duration>` command.
    fn sleep_cmd(duration: &str) -> Command {
        let mut cmd = Command::new("sleep");
        cmd.arg(duration);
        cmd
    }

    /// Spawns one `sleep` child per entry in `durations`.
    ///
    /// The `Child` handles are returned so each test can reap (and, if needed,
    /// kill) the processes instead of leaving zombies or stray sleepers behind.
    fn spawn_sleepers(durations: &[&str]) -> Vec<Child> {
        durations
            .iter()
            .map(|duration| {
                sleep_cmd(duration)
                    .spawn()
                    .expect("failed to spawn `sleep`")
            })
            .collect()
    }

    /// Kills any child still running and reaps all of them; errors don't matter here.
    fn reap(children: &mut [Child]) {
        for child in children {
            let _ = child.kill();
            let _ = child.wait();
        }
    }

    #[test]
    fn wait_all() {
        let mut children = spawn_sleepers(&["0.1", "0.2", "0.3", "0.4", "0.5"]);
        let mut pid_set = PidSet::new(children.iter().map(Child::id));

        let waited = pid_set.wait_all();
        let closed = pid_set.close();

        // After wait_all every child must already have exited; try_wait reaps it.
        let mut still_running = Vec::new();
        for child in &mut children {
            if !matches!(child.try_wait(), Ok(Some(_))) {
                still_running.push(child.id());
            }
        }
        reap(&mut children);

        assert!(waited.is_ok(), "wait_all failed: {:?}", waited);
        assert!(closed.is_ok(), "close failed: {:?}", closed);
        assert!(
            still_running.is_empty(),
            "children still running after wait_all: {:?}",
            still_running
        );
    }

    #[test]
    fn wait_any() {
        let start_time = Instant::now();
        let mut children = spawn_sleepers(&["0.2", "3", "3", "3", "3"]);
        let mut pid_set = PidSet::new(children.iter().map(Child::id));

        let waited = pid_set.wait_any();
        let elapsed = start_time.elapsed();
        let closed = pid_set.close();
        // Don't leave the 3-second sleepers running past the test.
        reap(&mut children);

        assert!(waited.is_ok(), "wait_any failed: {:?}", waited);
        assert!(closed.is_ok(), "close failed: {:?}", closed);
        assert!(
            elapsed < Duration::from_secs(3),
            "Expected wait_any() to return in less than 3 seconds, but it took {:?}",
            elapsed
        );
    }
}