1use std::process::{ExitStatus, Stdio};
6use std::time::Duration;
7
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
9use tokio::process::{Child, Command};
10use tokio::time::{timeout_at, Instant};
11
/// Size in bytes of the scratch buffer handed to each `read` call while draining a
/// child's stdout/stderr pipe.
const READ_CHUNK_SIZE: usize = 8_192;
13
/// Per-stream caps on how much child output is retained in memory.
///
/// Output past a cap is still read from the pipe and counted, but not stored
/// (see `CollectedStream::overflowed`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamLimit {
    /// Maximum number of stdout bytes kept in `CollectedStream::bytes`.
    pub max_stdout_bytes: usize,
    /// Maximum number of stderr bytes kept in `CollectedStream::bytes`.
    pub max_stderr_bytes: usize,
}
22
/// Output captured from one child stream (stdout or stderr).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CollectedStream {
    /// Leading bytes of the stream, truncated to the configured limit.
    pub bytes: Vec<u8>,
    /// Total number of bytes read from the stream, including any that were
    /// discarded after the limit was reached (saturates at `usize::MAX`).
    pub total_bytes: usize,
    /// `true` when at least one byte was discarded because the limit was reached.
    pub overflowed: bool,
}
33
/// Failure observed while writing to, or shutting down, the child's stdin.
///
/// Only the `ErrorKind` is kept. This is presumably so the type can stay
/// `Copy + Eq` (`std::io::Error` is neither).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StdinWriteError {
    /// Kind of the underlying I/O error.
    pub kind: std::io::ErrorKind,
}
40
41#[derive(Clone, Debug, PartialEq, Eq)]
43pub struct ProcessRunResult {
44 pub status: ExitStatus,
46 pub stdout: CollectedStream,
48 pub stderr: CollectedStream,
50 pub stdin_write_error: Option<StdinWriteError>,
52 pub stdin_close_error: Option<StdinWriteError>,
54}
55
/// Reasons `run_command` can fail to produce a `ProcessRunResult`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProcessRunError {
    /// The process could not be spawned.
    SpawnFailed,
    /// The deadline elapsed before the run finished. Any still-running child is
    /// killed and reaped.
    Timeout,
    /// Waiting for the child's exit status returned an I/O error.
    WaitFailed,
    /// The stdout pipe was unavailable, or reading it failed or panicked.
    StdoutReadFailed,
    /// The stderr pipe was unavailable, or reading it failed or panicked.
    StderrReadFailed,
}
70
71pub async fn run_command(
77 mut cmd: Command,
78 stdin_bytes: Option<Vec<u8>>,
79 timeout: Duration,
80 limits: StreamLimit,
81) -> Result<ProcessRunResult, ProcessRunError> {
82 cmd.kill_on_drop(true);
83 cmd.stdin(Stdio::piped());
84 cmd.stdout(Stdio::piped());
85 cmd.stderr(Stdio::piped());
86
87 let mut child = cmd.spawn().map_err(|_| ProcessRunError::SpawnFailed)?;
88 let deadline = Instant::now() + timeout;
89
90 let stdout_handle = child
91 .stdout
92 .take()
93 .ok_or(ProcessRunError::StdoutReadFailed)?;
94 let stderr_handle = child
95 .stderr
96 .take()
97 .ok_or(ProcessRunError::StderrReadFailed)?;
98
99 let stdout_task = tokio::spawn(read_stream(stdout_handle, limits.max_stdout_bytes));
100 let stderr_task = tokio::spawn(read_stream(stderr_handle, limits.max_stderr_bytes));
101
102 let (stdin_write_error, stdin_close_error) =
103 write_and_close_stdin(&mut child, stdin_bytes, deadline).await?;
104
105 let status = wait_for_exit_or_timeout(&mut child, deadline).await?;
106
107 let stdout = join_stream_task(stdout_task, ProcessRunError::StdoutReadFailed).await?;
108 let stderr = join_stream_task(stderr_task, ProcessRunError::StderrReadFailed).await?;
109
110 Ok(ProcessRunResult {
111 status,
112 stdout,
113 stderr,
114 stdin_write_error,
115 stdin_close_error,
116 })
117}
118
119async fn write_and_close_stdin(
120 child: &mut Child,
121 stdin_bytes: Option<Vec<u8>>,
122 deadline: Instant,
123) -> Result<(Option<StdinWriteError>, Option<StdinWriteError>), ProcessRunError> {
124 let Some(mut stdin) = child.stdin.take() else {
125 return Ok((None, None));
126 };
127
128 let mut write_error = None;
129 let mut close_error = None;
130
131 if let Some(bytes) = stdin_bytes {
132 if !bytes.is_empty() {
133 match timeout_at(deadline, stdin.write_all(&bytes)).await {
134 Ok(Ok(())) => {}
135 Ok(Err(err)) => {
136 write_error = Some(StdinWriteError { kind: err.kind() });
137 }
138 Err(_) => {
139 kill_and_reap(child).await;
140 return Err(ProcessRunError::Timeout);
141 }
142 }
143 }
144 }
145
146 match timeout_at(deadline, stdin.shutdown()).await {
147 Ok(Ok(())) => {}
148 Ok(Err(err)) => {
149 close_error = Some(StdinWriteError { kind: err.kind() });
150 }
151 Err(_) => {
152 kill_and_reap(child).await;
153 return Err(ProcessRunError::Timeout);
154 }
155 }
156
157 Ok((write_error, close_error))
158}
159
160async fn wait_for_exit_or_timeout(
161 child: &mut Child,
162 deadline: Instant,
163) -> Result<ExitStatus, ProcessRunError> {
164 match timeout_at(deadline, child.wait()).await {
165 Ok(Ok(status)) => Ok(status),
166 Ok(Err(_)) => Err(ProcessRunError::WaitFailed),
167 Err(_) => {
168 kill_and_reap(child).await;
169 Err(ProcessRunError::Timeout)
170 }
171 }
172}
173
174async fn kill_and_reap(child: &mut Child) {
175 let _ = child.kill().await;
176 let _ = child.wait().await;
177}
178
179async fn join_stream_task(
180 handle: tokio::task::JoinHandle<Result<CollectedStream, std::io::Error>>,
181 map_err: ProcessRunError,
182) -> Result<CollectedStream, ProcessRunError> {
183 let joined = handle.await.map_err(|_| map_err)?;
184 joined.map_err(|_| map_err)
185}
186
187async fn read_stream<R>(mut reader: R, max_bytes: usize) -> Result<CollectedStream, std::io::Error>
188where
189 R: AsyncRead + Unpin,
190{
191 let mut collected = Vec::new();
192 let mut total_bytes = 0usize;
193 let mut overflowed = false;
194 let mut chunk = [0u8; READ_CHUNK_SIZE];
195
196 loop {
197 let n = reader.read(&mut chunk).await?;
198 if n == 0 {
199 break;
200 }
201
202 total_bytes = total_bytes.saturating_add(n);
203 let remaining = max_bytes.saturating_sub(collected.len());
204 if remaining > 0 {
205 let take = remaining.min(n);
206 collected.extend_from_slice(&chunk[..take]);
207 if take < n {
208 overflowed = true;
209 }
210 } else {
211 overflowed = true;
212 }
213 }
214
215 Ok(CollectedStream {
216 bytes: collected,
217 total_bytes,
218 overflowed,
219 })
220}
221
// These tests depend on `/bin/sh`, `head` and `/dev/zero`, so they only run on unix.
#[cfg(all(test, unix))]
mod tests {
    use super::*;

    use std::path::PathBuf;

    /// A throwaway shell script in its own temp directory.
    ///
    /// The directory is removed on drop, so a failing assertion does not leak files.
    struct TestProgram {
        dir: PathBuf,
        script: PathBuf,
    }

    impl TestProgram {
        /// Builds a command that runs the script through `/bin/sh`.
        ///
        /// Invoking the interpreter, instead of exec'ing the freshly written file
        /// directly, avoids sporadic `ETXTBSY` ("Text file busy") spawn failures.
        /// Those happen when another test thread forks while a write handle to the
        /// script is still inherited.
        fn command(&self) -> Command {
            let mut cmd = Command::new("/bin/sh");
            cmd.arg(&self.script);
            cmd
        }
    }

    impl Drop for TestProgram {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.dir);
        }
    }

    fn write_test_program(script_body: &str) -> TestProgram {
        let dir =
            std::env::temp_dir().join(format!("mfm-process-exec-test-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        let script = dir.join("app.sh");
        std::fs::write(&script, format!("#!/bin/sh\n{script_body}\n")).expect("write script");
        TestProgram { dir, script }
    }

    #[tokio::test]
    async fn timeout_covers_stdin_write() {
        // The script never reads stdin, so a 4 MiB payload fills the pipe buffer
        // and blocks. The deadline must fire during the stdin write itself.
        let program = write_test_program("while :; do :; done");

        let stdin = vec![b'a'; 4 * 1024 * 1024];
        let err = run_command(
            program.command(),
            Some(stdin),
            Duration::from_millis(20),
            StreamLimit {
                max_stdout_bytes: 256,
                max_stderr_bytes: 256,
            },
        )
        .await
        .expect_err("expected timeout");

        assert_eq!(err, ProcessRunError::Timeout);
    }

    #[tokio::test]
    async fn stream_overflow_is_detected_without_unbounded_growth() {
        let program = write_test_program("head -c 16384 /dev/zero");
        let out = run_command(
            program.command(),
            None,
            Duration::from_secs(2),
            StreamLimit {
                max_stdout_bytes: 1024,
                max_stderr_bytes: 1024,
            },
        )
        .await
        .expect("run should succeed");

        assert!(out.status.success());
        assert!(out.stdout.overflowed);
        assert_eq!(out.stdout.bytes.len(), 1024);
        assert!(out.stdout.bytes.iter().all(|&b| b == 0));
        // The reader keeps draining past the cap, so every byte is still counted.
        assert_eq!(out.stdout.total_bytes, 16384);
        assert!(!out.stderr.overflowed);
        assert!(out.stderr.bytes.is_empty());
    }
}