1use std::io;
24use std::os::fd::{AsRawFd, OwnedFd, RawFd};
25use std::time::{Duration, Instant};
26
27use rustix::process::{Signal, pidfd_send_signal};
28
29use crate::plan::Plan;
30use crate::workspace::Workspace;
31
32#[derive(Debug, Clone)]
34pub struct Output {
35 pub stdout: Vec<u8>,
36 pub stderr: Vec<u8>,
37 pub status: Status,
38 pub duration: Duration,
39 pub exit_code: Option<i32>,
40 pub signal: Option<i32>,
41}
42
43impl Output {
44 #[inline]
45 pub fn success(&self) -> bool {
46 self.status == Status::Exited && self.exit_code == Some(0)
47 }
48
49 #[inline]
50 pub fn stdout_str(&self) -> String {
51 String::from_utf8_lossy(&self.stdout).into_owned()
52 }
53
54 #[inline]
55 pub fn stderr_str(&self) -> String {
56 String::from_utf8_lossy(&self.stderr).into_owned()
57 }
58}
59
/// Terminal state of a monitored run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Status {
    /// The child ended without the monitor observing a terminating signal.
    Exited,
    /// The child was terminated by a signal the monitor did not send.
    Signaled,
    /// The monitor killed the child because the plan's timeout elapsed.
    Timeout,
    /// The monitor killed the child because stdout or stderr would have exceeded the plan's
    /// output limit.
    OutputLimitExceeded,
}
68
69pub fn monitor(pidfd: OwnedFd, workspace: &Workspace, plan: &Plan) -> io::Result<Output> {
71 let start = Instant::now();
72 let deadline = start + plan.timeout;
73
74 let mut stdout_buf = Vec::new();
75 let mut stderr_buf = Vec::new();
76
77 let stdout_fd = workspace.pipes.stdout.read.as_raw_fd();
78 let stderr_fd = workspace.pipes.stderr.read.as_raw_fd();
79 let pidfd_raw = pidfd.as_raw_fd();
80
81 set_nonblocking(stdout_fd)?;
82 set_nonblocking(stderr_fd)?;
83
84 let mut status = Status::Exited;
85 let mut exit_code = None;
86 let mut signal = None;
87 let mut buf = [0u8; 4096];
88
89 loop {
90 let timeout_remaining = deadline.saturating_duration_since(Instant::now());
91 if timeout_remaining.is_zero() {
92 pidfd_send_signal(&pidfd, Signal::KILL).ok();
93 status = Status::Timeout;
94 wait_for_exit(pidfd_raw)?;
95 break;
96 }
97
98 let poll_timeout = timeout_remaining.as_millis().min(100) as i32;
100 let mut fds = [
101 libc::pollfd {
102 fd: stdout_fd,
103 events: libc::POLLIN,
104 revents: 0,
105 },
106 libc::pollfd {
107 fd: stderr_fd,
108 events: libc::POLLIN,
109 revents: 0,
110 },
111 libc::pollfd {
112 fd: pidfd_raw,
113 events: libc::POLLIN,
114 revents: 0,
115 },
116 ];
117
118 let ret = unsafe { libc::poll(fds.as_mut_ptr(), 3, poll_timeout) };
119 if ret < 0 {
120 let err = io::Error::last_os_error();
121 if err.kind() == io::ErrorKind::Interrupted {
122 continue;
123 }
124 return Err(err);
125 }
126
127 if fds[0].revents & libc::POLLIN != 0 {
128 if let Ok(n) = read_nonblocking(stdout_fd, &mut buf) {
129 if n > 0 {
130 if stdout_buf.len() + n > plan.max_output as usize {
131 status = Status::OutputLimitExceeded;
132 pidfd_send_signal(&pidfd, Signal::KILL).ok();
133 wait_for_exit(pidfd_raw)?;
134 break;
135 }
136 stdout_buf.extend_from_slice(&buf[..n]);
137 }
138 }
139 }
140
141 if fds[1].revents & libc::POLLIN != 0 {
142 if let Ok(n) = read_nonblocking(stderr_fd, &mut buf) {
143 if n > 0 {
144 if stderr_buf.len() + n > plan.max_output as usize {
145 status = Status::OutputLimitExceeded;
146 pidfd_send_signal(&pidfd, Signal::KILL).ok();
147 wait_for_exit(pidfd_raw)?;
148 break;
149 }
150 stderr_buf.extend_from_slice(&buf[..n]);
151 }
152 }
153 }
154
155 if fds[2].revents & libc::POLLIN != 0 {
156 let (ec, sig) = wait_for_exit(pidfd_raw)?;
157 exit_code = ec;
158 signal = sig;
159 if sig.is_some() {
160 status = Status::Signaled;
161 }
162 break;
163 }
164
165 if (fds[0].revents & libc::POLLHUP != 0) && (fds[1].revents & libc::POLLHUP != 0) {
166 let (ec, sig) = wait_for_exit(pidfd_raw)?;
167 exit_code = ec;
168 signal = sig;
169 if sig.is_some() {
170 status = Status::Signaled;
171 }
172 break;
173 }
174 }
175
176 drain_remaining(stdout_fd, &mut stdout_buf, &mut buf, plan.max_output);
177 drain_remaining(stderr_fd, &mut stderr_buf, &mut buf, plan.max_output);
178
179 Ok(Output {
180 stdout: stdout_buf,
181 stderr: stderr_buf,
182 status,
183 duration: start.elapsed(),
184 exit_code,
185 signal,
186 })
187}
188
189pub fn write_stdin(workspace: &Workspace, data: &[u8]) -> io::Result<()> {
191 let fd = workspace.pipes.stdin.write.as_raw_fd();
192 let mut written = 0;
193 while written < data.len() {
194 let ret = unsafe {
195 libc::write(
196 fd,
197 data[written..].as_ptr().cast::<libc::c_void>(),
198 data.len() - written,
199 )
200 };
201 if ret < 0 {
202 return Err(io::Error::last_os_error());
203 }
204 written += ret as usize;
205 }
206 Ok(())
207}
208
209#[inline]
210pub(crate) fn set_nonblocking(fd: RawFd) -> io::Result<()> {
211 let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
212 if flags < 0 {
213 return Err(io::Error::last_os_error());
214 }
215 let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
216 if ret < 0 {
217 Err(io::Error::last_os_error())
218 } else {
219 Ok(())
220 }
221}
222
223#[inline]
224fn read_nonblocking(fd: RawFd, buf: &mut [u8]) -> io::Result<usize> {
225 let ret = unsafe { libc::read(fd, buf.as_mut_ptr().cast::<libc::c_void>(), buf.len()) };
226 if ret < 0 {
227 Err(io::Error::last_os_error())
228 } else {
229 Ok(ret as usize)
230 }
231}
232
233fn drain_remaining(fd: RawFd, output: &mut Vec<u8>, buf: &mut [u8], max_output: u64) {
234 let max = max_output as usize;
235 loop {
236 if output.len() >= max {
237 break;
239 }
240 match read_nonblocking(fd, buf) {
241 Ok(0) | Err(_) => break,
242 Ok(n) => {
243 let remaining = max.saturating_sub(output.len());
245 let to_add = n.min(remaining);
246 output.extend_from_slice(&buf[..to_add]);
247 }
248 }
249 }
250}
251
252pub(crate) fn wait_for_exit(pidfd: RawFd) -> io::Result<(Option<i32>, Option<i32>)> {
253 let mut siginfo: libc::siginfo_t = unsafe { std::mem::zeroed() };
254 let ret = unsafe {
255 libc::waitid(
256 libc::P_PIDFD,
257 pidfd as libc::id_t,
258 &mut siginfo,
259 libc::WEXITED,
260 )
261 };
262 if ret < 0 {
263 return Err(io::Error::last_os_error());
264 }
265
266 let code = siginfo.si_code;
267 let status = unsafe { siginfo.si_status() };
268
269 match code {
270 libc::CLD_EXITED => Ok((Some(status), None)),
271 libc::CLD_KILLED | libc::CLD_DUMPED => Ok((None, Some(status))),
272 _ => Ok((None, None)),
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Output` with empty streams and no signal, varying only the fields under test.
    fn finished(status: Status, exit_code: Option<i32>) -> Output {
        Output {
            stdout: Vec::new(),
            stderr: Vec::new(),
            status,
            duration: Duration::from_millis(100),
            exit_code,
            signal: None,
        }
    }

    #[test]
    fn output_success() {
        let ok = finished(Status::Exited, Some(0));
        assert!(ok.success());
    }

    #[test]
    fn output_failure() {
        let failed = finished(Status::Exited, Some(1));
        assert!(!failed.success());
    }
}