nexo_driver_loop/acceptance/
shell.rs1use std::path::{Path, PathBuf};
5use std::process::Stdio;
6use std::time::{Duration, Instant};
7
8use tokio::process::Command;
9
10use crate::error::DriverError;
11
/// Runs shell command lines as child processes, capturing their output and
/// enforcing a wall-clock timeout.
#[derive(Debug, Clone)]
pub struct ShellRunner {
    /// Program that is invoked as `<shell> -c <cmd>`.
    shell: PathBuf,
    /// Maximum number of bytes retained from each of stdout and stderr.
    output_byte_limit: usize,
    /// Grace period granted to the child after it has been killed on timeout.
    forced_kill_after: Duration,
}
18
/// Outcome of a single command executed by `ShellRunner::run`.
#[derive(Debug, Clone)]
pub struct ShellResult {
    /// Exit code reported for the shell process; `None` when the run timed
    /// out or when the exit status carries no code.
    pub exit_code: Option<i32>,
    /// Captured standard output, lossily decoded as UTF-8.
    pub stdout: String,
    /// Captured standard error, lossily decoded as UTF-8.
    pub stderr: String,
    /// `true` when the command was still running once the timeout elapsed.
    pub timed_out: bool,
    /// Wall-clock time measured from just before the spawn until the result
    /// was assembled.
    pub duration: Duration,
}
27
28impl Default for ShellRunner {
29 fn default() -> Self {
30 Self {
31 shell: PathBuf::from("/bin/sh"),
32 output_byte_limit: 1024 * 1024,
33 forced_kill_after: Duration::from_secs(1),
34 }
35 }
36}
37
38impl ShellRunner {
39 pub fn new() -> Self {
40 Self::default()
41 }
42 pub fn with_shell(mut self, p: impl Into<PathBuf>) -> Self {
43 self.shell = p.into();
44 self
45 }
46 pub fn with_output_byte_limit(mut self, n: usize) -> Self {
47 self.output_byte_limit = n;
48 self
49 }
50 pub fn with_forced_kill_after(mut self, d: Duration) -> Self {
51 self.forced_kill_after = d;
52 self
53 }
54
55 pub async fn run(
56 &self,
57 cmd: &str,
58 cwd: &Path,
59 timeout: Duration,
60 ) -> Result<ShellResult, DriverError> {
61 let started = Instant::now();
62 let mut child = Command::new(&self.shell)
63 .arg("-c")
64 .arg(cmd)
65 .current_dir(cwd)
66 .stdout(Stdio::piped())
67 .stderr(Stdio::piped())
68 .kill_on_drop(true)
69 .spawn()
70 .map_err(|e| DriverError::Acceptance(format!("spawn shell: {e}")))?;
71
72 let stdout = child.stdout.take();
73 let stderr = child.stderr.take();
74 let read_stdout = read_capped(stdout, self.output_byte_limit);
75 let read_stderr = read_capped(stderr, self.output_byte_limit);
76
77 let wait = child.wait();
78 let race = async {
79 tokio::select! {
80 w = wait => w,
81 }
82 };
83 let res = tokio::time::timeout(timeout, race).await;
84
85 match res {
86 Ok(Ok(status)) => {
87 let stdout = read_stdout.await;
88 let stderr = read_stderr.await;
89 Ok(ShellResult {
90 exit_code: status.code(),
91 stdout,
92 stderr,
93 timed_out: false,
94 duration: started.elapsed(),
95 })
96 }
97 Ok(Err(e)) => Err(DriverError::Acceptance(format!("shell wait: {e}"))),
98 Err(_) => {
99 let _ = child.start_kill();
100 let _ = tokio::time::timeout(self.forced_kill_after, child.wait()).await;
101 let stdout = read_stdout.await;
102 let stderr = read_stderr.await;
103 Ok(ShellResult {
104 exit_code: None,
105 stdout,
106 stderr,
107 timed_out: true,
108 duration: started.elapsed(),
109 })
110 }
111 }
112 }
113}
114
115async fn read_capped<R>(reader: Option<R>, limit: usize) -> String
116where
117 R: tokio::io::AsyncRead + Unpin,
118{
119 use tokio::io::AsyncReadExt;
120 let Some(mut r) = reader else {
121 return String::new();
122 };
123 let mut buf = Vec::with_capacity(limit.min(8192));
124 let mut chunk = [0u8; 8192];
125 loop {
126 match r.read(&mut chunk).await {
127 Ok(0) => break,
128 Ok(n) => {
129 let take = n.min(limit.saturating_sub(buf.len()));
130 buf.extend_from_slice(&chunk[..take]);
131 if buf.len() >= limit {
132 let mut sink = [0u8; 8192];
134 while r.read(&mut sink).await.unwrap_or(0) > 0 {}
135 break;
136 }
137 }
138 Err(_) => break,
139 }
140 }
141 String::from_utf8_lossy(&buf).into_owned()
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful command reports exit code 0 and its stdout is captured.
    #[tokio::test]
    async fn echo_exit_zero() {
        let runner = ShellRunner::default();
        let workdir = std::env::temp_dir();
        let outcome = runner
            .run("echo hello", &workdir, Duration::from_secs(5))
            .await
            .unwrap();
        assert_eq!(outcome.exit_code, Some(0));
        assert!(outcome.stdout.contains("hello"));
        assert!(!outcome.timed_out);
    }

    /// A failing command surfaces its non-zero exit code.
    #[tokio::test]
    async fn false_exits_one() {
        let runner = ShellRunner::default();
        let workdir = std::env::temp_dir();
        let outcome = runner
            .run("false", &workdir, Duration::from_secs(5))
            .await
            .unwrap();
        assert_eq!(outcome.exit_code, Some(1));
    }

    /// A command that outlives the timeout is flagged and carries no exit code.
    #[tokio::test]
    async fn timeout_marks_timed_out() {
        let runner = ShellRunner::default();
        let workdir = std::env::temp_dir();
        let outcome = runner
            .run("sleep 5", &workdir, Duration::from_millis(100))
            .await
            .unwrap();
        assert!(outcome.timed_out, "expected timed_out=true");
        assert_eq!(outcome.exit_code, None);
    }

    /// The command executes inside the requested working directory.
    #[tokio::test]
    async fn cwd_is_respected() {
        let workdir = tempfile::tempdir().unwrap();
        let outcome = ShellRunner::default()
            .run("pwd", workdir.path(), Duration::from_secs(5))
            .await
            .unwrap();
        // Temp directories may sit behind a symlink, so compare against the
        // canonical path.
        let expected = std::fs::canonicalize(workdir.path())
            .unwrap()
            .display()
            .to_string();
        assert!(outcome.stdout.contains(expected.as_str()));
    }
}