Skip to main content

atd_tools_shell/
shared.rs

1//! Shared subprocess handler for shell tools.
2//!
3//! Responsibilities:
4//! - Spawn the given program + args with piped stdout/stderr, null stdin.
5//! - Concurrently drain stdout and stderr into byte buffers, each capped at
6//!   its respective byte budget. Continue reading past the cap until EOF so
7//!   the child doesn't block on full pipe buffers; bytes past the cap are
8//!   discarded.
9//! - Wait for the child with an optional absolute deadline. On timeout:
10//!   Unix: SIGTERM, then sleep grace_ms, then Child::kill (SIGKILL).
11//!   Windows: Child::kill directly.
12//! - UTF-8-lossy decode both streams; return exit code + truncation flags +
13//!   duration.
14
15use std::path::Path;
16use std::process::Stdio;
17use std::time::{Duration, Instant};
18
19use thiserror::Error;
20use tokio::io::{AsyncRead, AsyncReadExt};
21
/// Everything [`run`] needs to launch one subprocess and bound its output.
///
/// The request only borrows its inputs, so it is cheap to copy around.
#[derive(Debug, Clone, Copy)]
pub struct RunRequest<'a> {
    /// Program to execute, passed as-is to `tokio::process::Command::new`.
    pub program: &'a str,
    /// Arguments handed to the program, in order.
    pub args: &'a [&'a str],
    /// Working directory the child is started in.
    pub cwd: &'a Path,
    /// Absolute point in time after which the child is terminated;
    /// `None` waits for the child without a time limit.
    pub deadline: Option<Instant>,
    /// Time in milliseconds allowed after SIGTERM before the forced kill
    /// (only consulted on Unix).
    pub grace_ms: u64,
    /// Maximum number of stdout bytes kept; the rest is read and discarded.
    pub max_stdout_bytes: usize,
    /// Maximum number of stderr bytes kept; the rest is read and discarded.
    pub max_stderr_bytes: usize,
}
31
/// Result of a subprocess that ran to completion before its deadline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunOutput {
    /// Exit code, or `None` if process was killed by signal.
    pub exit_code: Option<i32>,
    /// Captured stdout, decoded as lossy UTF-8.
    pub stdout: String,
    /// `true` when the child wrote more stdout than the byte budget allowed.
    pub stdout_truncated: bool,
    /// Captured stderr, decoded as lossy UTF-8.
    pub stderr: String,
    /// `true` when the child wrote more stderr than the byte budget allowed.
    pub stderr_truncated: bool,
    /// Wall-clock time from the start of `run` until the output was assembled,
    /// in milliseconds.
    pub duration_ms: u64,
}
42
43#[derive(Debug, Error)]
44pub enum RunError {
45    #[error("program not found: {program}")]
46    NotFound { program: String },
47
48    #[error("spawn failed: {0}")]
49    SpawnFailed(#[source] std::io::Error),
50
51    #[error("command timed out after {after_ms}ms")]
52    TimedOut { after_ms: u64 },
53
54    #[error("io error: {0}")]
55    Io(#[from] std::io::Error),
56}
57
58pub async fn run(req: RunRequest<'_>) -> Result<RunOutput, RunError> {
59    let start = Instant::now();
60
61    let mut cmd = tokio::process::Command::new(req.program);
62    cmd.args(req.args)
63        .current_dir(req.cwd)
64        .stdin(Stdio::null())
65        .stdout(Stdio::piped())
66        .stderr(Stdio::piped());
67
68    let mut child = match cmd.spawn() {
69        Ok(c) => c,
70        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
71            return Err(RunError::NotFound {
72                program: req.program.to_string(),
73            });
74        }
75        Err(e) => return Err(RunError::SpawnFailed(e)),
76    };
77
78    let stdout = child.stdout.take().expect("stdout was set to piped");
79    let stderr = child.stderr.take().expect("stderr was set to piped");
80
81    let max_stdout = req.max_stdout_bytes;
82    let max_stderr = req.max_stderr_bytes;
83    let stdout_task = tokio::spawn(read_capped(stdout, max_stdout));
84    let stderr_task = tokio::spawn(read_capped(stderr, max_stderr));
85
86    // Wait for the child, honoring the optional deadline.
87    let status_result = match req.deadline {
88        Some(deadline) => {
89            let remaining = deadline.saturating_duration_since(Instant::now());
90            tokio::time::timeout(remaining, child.wait()).await
91        }
92        None => Ok(child.wait().await),
93    };
94
95    let status = match status_result {
96        Ok(Ok(s)) => s,
97        Ok(Err(e)) => {
98            let _ = child.start_kill();
99            let _ = child.wait().await;
100            let _ = stdout_task.await;
101            let _ = stderr_task.await;
102            return Err(RunError::Io(e));
103        }
104        Err(_elapsed) => {
105            // Deadline hit. SIGTERM → grace → SIGKILL.
106            #[cfg(unix)]
107            {
108                // child.id() returns None only if the child has already been reaped
109                // — not possible here since child.wait() is awaited below. But guard
110                // anyway: if it ever does return None, fall through to start_kill().
111                if let Some(pid) = child.id() {
112                    // SAFETY: child.wait() has not been called yet, so the kernel still
113                    // holds this PID slot for our child. SIGTERM to a PID we own is safe;
114                    // if the child races to exit before the signal lands, kill(2) returns
115                    // ESRCH and has no effect.
116                    unsafe {
117                        libc::kill(pid as libc::pid_t, libc::SIGTERM);
118                    }
119                    tokio::time::sleep(Duration::from_millis(req.grace_ms)).await;
120                }
121            }
122            // Either the grace expired, or we're on Windows — force kill.
123            let _ = child.start_kill();
124            let _ = child.wait().await;
125            // Drain the readers so their tasks complete.
126            let _ = stdout_task.await;
127            let _ = stderr_task.await;
128            return Err(RunError::TimedOut {
129                after_ms: start.elapsed().as_millis() as u64,
130            });
131        }
132    };
133
134    // Process finished; harvest the readers.
135    let (stdout_bytes, stdout_truncated) =
136        stdout_task.await.unwrap_or_else(|_| (Vec::new(), false));
137    let (stderr_bytes, stderr_truncated) =
138        stderr_task.await.unwrap_or_else(|_| (Vec::new(), false));
139
140    Ok(RunOutput {
141        exit_code: status.code(),
142        stdout: String::from_utf8_lossy(&stdout_bytes).into_owned(),
143        stdout_truncated,
144        stderr: String::from_utf8_lossy(&stderr_bytes).into_owned(),
145        stderr_truncated,
146        duration_ms: start.elapsed().as_millis() as u64,
147    })
148}
149
150/// Drain `reader` into a Vec<u8>. Stop storing once `max` bytes are
151/// captured, but keep reading to EOF (so the writer doesn't block on a
152/// full pipe buffer) — discard the excess.
153async fn read_capped<R>(mut reader: R, max: usize) -> (Vec<u8>, bool)
154where
155    R: AsyncRead + Unpin,
156{
157    let mut buf = Vec::new();
158    let mut truncated = false;
159    let mut chunk = [0u8; 4096];
160    loop {
161        match reader.read(&mut chunk).await {
162            Ok(0) => break,
163            Ok(n) => {
164                if buf.len() >= max {
165                    truncated = true;
166                    continue;
167                }
168                let room = max - buf.len();
169                if n <= room {
170                    buf.extend_from_slice(&chunk[..n]);
171                } else {
172                    buf.extend_from_slice(&chunk[..room]);
173                    truncated = true;
174                }
175            }
176            Err(_) => break,
177        }
178    }
179    (buf, truncated)
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Working directory for the test commands: the process's current
    /// directory, or `"."` when that cannot be determined.
    fn cwd() -> std::path::PathBuf {
        match std::env::current_dir() {
            Ok(dir) => dir,
            Err(_) => std::path::PathBuf::from("."),
        }
    }

    /// Runs `bash -c <script>` in `dir`, using `budget` as the byte cap for
    /// both stdout and stderr.
    async fn bash(
        script: &str,
        dir: &Path,
        deadline: Option<Instant>,
        grace_ms: u64,
        budget: usize,
    ) -> Result<RunOutput, RunError> {
        let argv = ["-c", script];
        run(RunRequest {
            program: "bash",
            args: &argv,
            cwd: dir,
            deadline,
            grace_ms,
            max_stdout_bytes: budget,
            max_stderr_bytes: budget,
        })
        .await
    }

    #[tokio::test]
    async fn happy_run_returns_stdout_and_exit_zero() {
        let out = bash("echo hello", &cwd(), None, 1000, 4096)
            .await
            .unwrap();
        assert_eq!(out.exit_code, Some(0));
        assert_eq!(out.stdout, "hello\n");
        assert_eq!(out.stderr, "");
        assert!(!out.stdout_truncated);
    }

    #[tokio::test]
    async fn nonzero_exit_code_returned_as_ok() {
        let out = bash("exit 3", &cwd(), None, 1000, 4096).await.unwrap();
        assert_eq!(out.exit_code, Some(3));
    }

    #[tokio::test]
    async fn stderr_captured_separately() {
        let out = bash(">&2 echo oops; echo good", &cwd(), None, 1000, 4096)
            .await
            .unwrap();
        assert_eq!(out.stdout, "good\n");
        assert_eq!(out.stderr, "oops\n");
    }

    #[tokio::test]
    async fn stdout_truncated_at_budget() {
        // 10 KB written to stdout against a 1 KB budget.
        let out = bash("yes x | head -c 10240", &cwd(), None, 1000, 1024)
            .await
            .unwrap();
        assert!(out.stdout_truncated);
        assert!(out.stdout.len() <= 1024);
    }

    #[tokio::test]
    async fn stderr_truncated_at_budget() {
        // Same as above, but the 10 KB go to stderr.
        let out = bash("yes x | head -c 10240 >&2", &cwd(), None, 1000, 1024)
            .await
            .unwrap();
        assert!(out.stderr_truncated);
        assert!(out.stderr.len() <= 1024);
    }

    #[tokio::test]
    async fn timeout_triggers_sigterm_then_sigkill() {
        let start = Instant::now();
        let err = bash(
            "sleep 10",
            &cwd(),
            Some(Instant::now() + Duration::from_millis(200)),
            100,
            1024,
        )
        .await
        .unwrap_err();
        let elapsed = start.elapsed();
        assert!(
            matches!(err, RunError::TimedOut { .. }),
            "expected TimedOut, got {err:?}"
        );
        // The kill must land around deadline + grace, far below the 10s sleep.
        assert!(
            elapsed < Duration::from_secs(2),
            "took too long: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn bogus_program_returns_not_found() {
        let err = run(RunRequest {
            program: "this-program-definitely-does-not-exist-xyzzy",
            args: &[],
            cwd: &cwd(),
            deadline: None,
            grace_ms: 1000,
            max_stdout_bytes: 1024,
            max_stderr_bytes: 1024,
        })
        .await
        .unwrap_err();
        if let RunError::NotFound { program } = &err {
            assert!(program.contains("xyzzy"));
        } else {
            panic!("expected NotFound, got {err:?}");
        }
    }

    #[tokio::test]
    async fn cwd_is_honored() {
        // `pwd` run inside a tempdir must print that tempdir's canonical path.
        let dir = tempfile::tempdir().unwrap();
        let canonical = tokio::fs::canonicalize(dir.path()).await.unwrap();
        let out = bash("pwd", &canonical, None, 1000, 4096).await.unwrap();
        let printed = out.stdout.trim_end();
        assert_eq!(printed, canonical.to_string_lossy());
    }
}