Skip to main content

nexo_driver_loop/acceptance/
shell.rs

1//! Cross-platform-ish (Unix-default) shell runner used by the
2//! shell-command criterion and the built-in custom verifiers.
3
4use std::path::{Path, PathBuf};
5use std::process::Stdio;
6use std::time::{Duration, Instant};
7
8use tokio::process::Command;
9
10use crate::error::DriverError;
11
/// Executes command strings through `<shell> -c`, enforcing a timeout and
/// capping how much of each output stream is kept.
///
/// Built with [`ShellRunner::new`] / [`Default`] and adjusted through the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ShellRunner {
    /// Interpreter that receives `-c <cmd>`.
    shell: PathBuf,
    /// Maximum number of bytes retained from stdout and, separately, from stderr.
    output_byte_limit: usize,
    /// Extra time allowed after a timeout-triggered kill.
    forced_kill_after: Duration,
}
18
/// Outcome of a single [`ShellRunner::run`] invocation.
#[derive(Debug, Clone)]
pub struct ShellResult {
    /// Exit code reported by the shell; `None` when the run timed out or the
    /// process did not report a code (e.g. it was terminated by a signal).
    pub exit_code: Option<i32>,
    /// Captured standard output, capped at the runner's byte limit and
    /// decoded lossily as UTF-8.
    pub stdout: String,
    /// Captured standard error, capped and decoded the same way as `stdout`.
    pub stderr: String,
    /// `true` when the command outlived its timeout and a kill was issued.
    pub timed_out: bool,
    /// Wall-clock time measured from just before the shell was spawned.
    pub duration: Duration,
}
27
28impl Default for ShellRunner {
29    fn default() -> Self {
30        Self {
31            shell: PathBuf::from("/bin/sh"),
32            output_byte_limit: 1024 * 1024,
33            forced_kill_after: Duration::from_secs(1),
34        }
35    }
36}
37
38impl ShellRunner {
39    pub fn new() -> Self {
40        Self::default()
41    }
42    pub fn with_shell(mut self, p: impl Into<PathBuf>) -> Self {
43        self.shell = p.into();
44        self
45    }
46    pub fn with_output_byte_limit(mut self, n: usize) -> Self {
47        self.output_byte_limit = n;
48        self
49    }
50    pub fn with_forced_kill_after(mut self, d: Duration) -> Self {
51        self.forced_kill_after = d;
52        self
53    }
54
55    pub async fn run(
56        &self,
57        cmd: &str,
58        cwd: &Path,
59        timeout: Duration,
60    ) -> Result<ShellResult, DriverError> {
61        let started = Instant::now();
62        let mut child = Command::new(&self.shell)
63            .arg("-c")
64            .arg(cmd)
65            .current_dir(cwd)
66            .stdout(Stdio::piped())
67            .stderr(Stdio::piped())
68            .kill_on_drop(true)
69            .spawn()
70            .map_err(|e| DriverError::Acceptance(format!("spawn shell: {e}")))?;
71
72        let stdout = child.stdout.take();
73        let stderr = child.stderr.take();
74        let read_stdout = read_capped(stdout, self.output_byte_limit);
75        let read_stderr = read_capped(stderr, self.output_byte_limit);
76
77        let wait = child.wait();
78        let race = async {
79            tokio::select! {
80                w = wait => w,
81            }
82        };
83        let res = tokio::time::timeout(timeout, race).await;
84
85        match res {
86            Ok(Ok(status)) => {
87                let stdout = read_stdout.await;
88                let stderr = read_stderr.await;
89                Ok(ShellResult {
90                    exit_code: status.code(),
91                    stdout,
92                    stderr,
93                    timed_out: false,
94                    duration: started.elapsed(),
95                })
96            }
97            Ok(Err(e)) => Err(DriverError::Acceptance(format!("shell wait: {e}"))),
98            Err(_) => {
99                let _ = child.start_kill();
100                let _ = tokio::time::timeout(self.forced_kill_after, child.wait()).await;
101                let stdout = read_stdout.await;
102                let stderr = read_stderr.await;
103                Ok(ShellResult {
104                    exit_code: None,
105                    stdout,
106                    stderr,
107                    timed_out: true,
108                    duration: started.elapsed(),
109                })
110            }
111        }
112    }
113}
114
115async fn read_capped<R>(reader: Option<R>, limit: usize) -> String
116where
117    R: tokio::io::AsyncRead + Unpin,
118{
119    use tokio::io::AsyncReadExt;
120    let Some(mut r) = reader else {
121        return String::new();
122    };
123    let mut buf = Vec::with_capacity(limit.min(8192));
124    let mut chunk = [0u8; 8192];
125    loop {
126        match r.read(&mut chunk).await {
127            Ok(0) => break,
128            Ok(n) => {
129                let take = n.min(limit.saturating_sub(buf.len()));
130                buf.extend_from_slice(&chunk[..take]);
131                if buf.len() >= limit {
132                    // Drain rest to let the child finish writing.
133                    let mut sink = [0u8; 8192];
134                    while r.read(&mut sink).await.unwrap_or(0) > 0 {}
135                    break;
136                }
137            }
138            Err(_) => break,
139        }
140    }
141    // UTF-8 lossy, then truncate at last valid char if cap landed
142    // mid-codepoint.
143    String::from_utf8_lossy(&buf).into_owned()
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `cmd` with a default runner from the system temp directory.
    async fn run_in_temp_dir(cmd: &str, timeout: Duration) -> ShellResult {
        ShellRunner::default()
            .run(cmd, &std::env::temp_dir(), timeout)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn echo_exit_zero() {
        let result = run_in_temp_dir("echo hello", Duration::from_secs(5)).await;
        assert!(!result.timed_out);
        assert_eq!(result.exit_code, Some(0));
        assert!(result.stdout.contains("hello"));
    }

    #[tokio::test]
    async fn false_exits_one() {
        let result = run_in_temp_dir("false", Duration::from_secs(5)).await;
        assert_eq!(result.exit_code, Some(1));
    }

    #[tokio::test]
    async fn timeout_marks_timed_out() {
        let result = run_in_temp_dir("sleep 5", Duration::from_millis(100)).await;
        assert!(result.timed_out, "expected timed_out=true");
        assert_eq!(result.exit_code, None);
    }

    #[tokio::test]
    async fn cwd_is_respected() {
        let dir = tempfile::tempdir().unwrap();
        let expected = std::fs::canonicalize(dir.path())
            .unwrap()
            .display()
            .to_string();
        let result = ShellRunner::default()
            .run("pwd", dir.path(), Duration::from_secs(5))
            .await
            .unwrap();
        assert!(result.stdout.contains(expected.as_str()));
    }
}