1use std::io;
24use std::os::fd::{AsRawFd, OwnedFd, RawFd};
25use std::time::{Duration, Instant};
26
27use rustix::process::{pidfd_send_signal, Signal};
28
29use crate::plan::Plan;
30use crate::workspace::Workspace;
31
/// Result of supervising one child-process run.
#[derive(Debug, Clone)]
pub struct Output {
    /// Bytes captured from the child's stdout pipe; `monitor` caps this at the plan's
    /// `max_output`.
    pub stdout: Vec<u8>,
    /// Bytes captured from the child's stderr pipe; `monitor` caps this at the plan's
    /// `max_output`.
    pub stderr: Vec<u8>,
    /// How the run ended.
    pub status: Status,
    /// Wall-clock time measured by `monitor` from the start of supervision until the
    /// output was assembled.
    pub duration: Duration,
    /// Exit code when the child exited normally. `monitor` leaves this `None` when the
    /// child was killed by a signal, and also for timeout and output-limit runs.
    pub exit_code: Option<i32>,
    /// Number of the signal that terminated the child, when it was killed by one.
    pub signal: Option<i32>,
}
42
43impl Output {
44 #[inline]
45 pub fn success(&self) -> bool {
46 self.status == Status::Exited && self.exit_code == Some(0)
47 }
48
49 #[inline]
50 pub fn stdout_str(&self) -> String {
51 String::from_utf8_lossy(&self.stdout).into_owned()
52 }
53
54 #[inline]
55 pub fn stderr_str(&self) -> String {
56 String::from_utf8_lossy(&self.stderr).into_owned()
57 }
58}
59
/// How a supervised run ended, as decided by `monitor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// The child was reaped without `waitid` reporting a terminating signal
    /// (this is also `monitor`'s starting value).
    Exited,
    /// `waitid` reported the child was killed by a signal (`CLD_KILLED` or `CLD_DUMPED`).
    Signaled,
    /// The plan's timeout elapsed; `monitor` sent `SIGKILL` and reaped the child.
    Timeout,
    /// stdout or stderr grew past the plan's `max_output`; `monitor` sent `SIGKILL`
    /// and reaped the child.
    OutputLimitExceeded,
}
68
69pub fn monitor(pidfd: OwnedFd, workspace: &Workspace, plan: &Plan) -> io::Result<Output> {
71 let start = Instant::now();
72 let deadline = start + plan.timeout;
73
74 let mut stdout_buf = Vec::new();
75 let mut stderr_buf = Vec::new();
76
77 let stdout_fd = workspace.pipes.stdout.read.as_raw_fd();
78 let stderr_fd = workspace.pipes.stderr.read.as_raw_fd();
79 let pidfd_raw = pidfd.as_raw_fd();
80
81 set_nonblocking(stdout_fd)?;
82 set_nonblocking(stderr_fd)?;
83
84 let mut status = Status::Exited;
85 let mut exit_code = None;
86 let mut signal = None;
87 let mut buf = [0u8; 4096];
88
89 loop {
90 let timeout_remaining = deadline.saturating_duration_since(Instant::now());
91 if timeout_remaining.is_zero() {
92 pidfd_send_signal(&pidfd, Signal::KILL).ok();
93 status = Status::Timeout;
94 wait_for_exit(pidfd_raw)?;
95 break;
96 }
97
98 let poll_timeout = timeout_remaining.as_millis().min(100) as i32;
100 let mut fds = [
101 libc::pollfd { fd: stdout_fd, events: libc::POLLIN, revents: 0 },
102 libc::pollfd { fd: stderr_fd, events: libc::POLLIN, revents: 0 },
103 libc::pollfd { fd: pidfd_raw, events: libc::POLLIN, revents: 0 },
104 ];
105
106 let ret = unsafe { libc::poll(fds.as_mut_ptr(), 3, poll_timeout) };
107 if ret < 0 {
108 let err = io::Error::last_os_error();
109 if err.kind() == io::ErrorKind::Interrupted {
110 continue;
111 }
112 return Err(err);
113 }
114
115 if fds[0].revents & libc::POLLIN != 0 {
116 if let Ok(n) = read_nonblocking(stdout_fd, &mut buf) {
117 if n > 0 {
118 if stdout_buf.len() + n > plan.max_output as usize {
119 status = Status::OutputLimitExceeded;
120 pidfd_send_signal(&pidfd, Signal::KILL).ok();
121 wait_for_exit(pidfd_raw)?;
122 break;
123 }
124 stdout_buf.extend_from_slice(&buf[..n]);
125 }
126 }
127 }
128
129 if fds[1].revents & libc::POLLIN != 0 {
130 if let Ok(n) = read_nonblocking(stderr_fd, &mut buf) {
131 if n > 0 {
132 if stderr_buf.len() + n > plan.max_output as usize {
133 status = Status::OutputLimitExceeded;
134 pidfd_send_signal(&pidfd, Signal::KILL).ok();
135 wait_for_exit(pidfd_raw)?;
136 break;
137 }
138 stderr_buf.extend_from_slice(&buf[..n]);
139 }
140 }
141 }
142
143 if fds[2].revents & libc::POLLIN != 0 {
144 let (ec, sig) = wait_for_exit(pidfd_raw)?;
145 exit_code = ec;
146 signal = sig;
147 if sig.is_some() {
148 status = Status::Signaled;
149 }
150 break;
151 }
152
153 if (fds[0].revents & libc::POLLHUP != 0) && (fds[1].revents & libc::POLLHUP != 0) {
154 let (ec, sig) = wait_for_exit(pidfd_raw)?;
155 exit_code = ec;
156 signal = sig;
157 if sig.is_some() {
158 status = Status::Signaled;
159 }
160 break;
161 }
162 }
163
164 drain_remaining(stdout_fd, &mut stdout_buf, &mut buf, plan.max_output);
165 drain_remaining(stderr_fd, &mut stderr_buf, &mut buf, plan.max_output);
166
167 Ok(Output {
168 stdout: stdout_buf,
169 stderr: stderr_buf,
170 status,
171 duration: start.elapsed(),
172 exit_code,
173 signal,
174 })
175}
176
177pub fn write_stdin(workspace: &Workspace, data: &[u8]) -> io::Result<()> {
179 let fd = workspace.pipes.stdin.write.as_raw_fd();
180 let mut written = 0;
181 while written < data.len() {
182 let ret = unsafe {
183 libc::write(fd, data[written..].as_ptr().cast::<libc::c_void>(), data.len() - written)
184 };
185 if ret < 0 {
186 return Err(io::Error::last_os_error());
187 }
188 written += ret as usize;
189 }
190 Ok(())
191}
192
193#[inline]
194pub(crate) fn set_nonblocking(fd: RawFd) -> io::Result<()> {
195 let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
196 if flags < 0 {
197 return Err(io::Error::last_os_error());
198 }
199 let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
200 if ret < 0 { Err(io::Error::last_os_error()) } else { Ok(()) }
201}
202
203#[inline]
204fn read_nonblocking(fd: RawFd, buf: &mut [u8]) -> io::Result<usize> {
205 let ret = unsafe { libc::read(fd, buf.as_mut_ptr().cast::<libc::c_void>(), buf.len()) };
206 if ret < 0 { Err(io::Error::last_os_error()) } else { Ok(ret as usize) }
207}
208
209fn drain_remaining(fd: RawFd, output: &mut Vec<u8>, buf: &mut [u8], max_output: u64) {
210 let max = max_output as usize;
211 loop {
212 if output.len() >= max {
213 break;
215 }
216 match read_nonblocking(fd, buf) {
217 Ok(0) | Err(_) => break,
218 Ok(n) => {
219 let remaining = max.saturating_sub(output.len());
221 let to_add = n.min(remaining);
222 output.extend_from_slice(&buf[..to_add]);
223 }
224 }
225 }
226}
227
228pub(crate) fn wait_for_exit(pidfd: RawFd) -> io::Result<(Option<i32>, Option<i32>)> {
229 let mut siginfo: libc::siginfo_t = unsafe { std::mem::zeroed() };
230 let ret = unsafe {
231 libc::waitid(libc::P_PIDFD, pidfd as libc::id_t, &mut siginfo, libc::WEXITED)
232 };
233 if ret < 0 {
234 return Err(io::Error::last_os_error());
235 }
236
237 let code = siginfo.si_code;
238 let status = unsafe { siginfo.si_status() };
239
240 match code {
241 libc::CLD_EXITED => Ok((Some(status), None)),
242 libc::CLD_KILLED | libc::CLD_DUMPED => Ok((None, Some(status))),
243 _ => Ok((None, None)),
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Output` with empty streams, no signal, and a fixed 100 ms duration.
    fn finished(status: Status, exit_code: Option<i32>) -> Output {
        Output {
            stdout: Vec::new(),
            stderr: Vec::new(),
            status,
            duration: Duration::from_millis(100),
            exit_code,
            signal: None,
        }
    }

    #[test]
    fn output_success() {
        let run = finished(Status::Exited, Some(0));
        assert!(run.success());
    }

    #[test]
    fn output_failure() {
        let run = finished(Status::Exited, Some(1));
        assert!(!run.success());
    }
}
276}