1use crate::context::ToolContext;
10use crate::toolset::{ToolOutcome, ToolSet};
11use async_trait::async_trait;
12use oharness_core::message::{Content, ToolOutput};
13use oharness_core::ToolSpec;
14use serde::Deserialize;
15use serde_json::{json, Value};
16use std::process::Stdio;
17use std::sync::OnceLock;
18use std::time::Duration;
19use tokio::process::Command;
20
21const DEFAULT_TIMEOUT_SECS: u64 = 60;
22const MAX_OUTPUT_BYTES: usize = 64 * 1024;
23
24pub struct BashTool {
26 name: String,
27 timeout: Duration,
28 env_allowlist: Option<Vec<String>>,
35 specs: Vec<ToolSpec>,
36}
37
38impl Default for BashTool {
39 fn default() -> Self {
40 Self::new("bash")
41 }
42}
43
44impl BashTool {
45 pub fn new(name: impl Into<String>) -> Self {
46 let name = name.into();
47 let specs = vec![ToolSpec {
48 name: name.clone(),
49 description: "Execute a shell command via `/bin/bash -c <command>`. Returns \
50 combined stdout/stderr. Commands run in the configured \
51 workspace directory, or the current directory if no workspace \
52 is set. Output is truncated at 64KiB."
53 .to_string(),
54 input_schema: default_schema(),
55 }];
56 Self {
57 name,
58 timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
59 env_allowlist: None,
60 specs,
61 }
62 }
63
64 pub fn with_timeout(mut self, d: Duration) -> Self {
65 self.timeout = d;
66 self
67 }
68
69 pub fn with_env_allowlist<I, S>(mut self, names: I) -> Self
83 where
84 I: IntoIterator<Item = S>,
85 S: Into<String>,
86 {
87 self.env_allowlist = Some(names.into_iter().map(Into::into).collect());
88 self
89 }
90}
91
92#[async_trait]
93impl ToolSet for BashTool {
94 fn specs(&self) -> &[ToolSpec] {
95 &self.specs
96 }
97
98 async fn execute(&self, name: &str, input: Value, ctx: &ToolContext) -> ToolOutcome {
99 if name != self.name {
100 return ToolOutcome::error(format!("tool `{name}` not handled by BashTool"), false);
101 }
102 if ctx.cancellation.is_cancelled() {
103 return ToolOutcome::Cancelled;
104 }
105
106 let parsed: BashInput = match serde_json::from_value(input) {
107 Ok(v) => v,
108 Err(e) => return ToolOutcome::error(format!("invalid bash input: {e}"), false),
109 };
110
111 let mut cmd = Command::new("/bin/bash");
112 cmd.arg("-c").arg(&parsed.command);
113 if let Some(ws) = ctx.workspace_path() {
114 cmd.current_dir(ws);
115 }
116
117 if let Some(names) = &self.env_allowlist {
121 cmd.env_clear();
122 for name in names {
123 if let Ok(val) = std::env::var(name) {
124 cmd.env(name, val);
125 }
126 }
127 }
128
129 cmd.stdout(Stdio::piped());
133 cmd.stderr(Stdio::piped());
134 cmd.stdin(Stdio::null());
138
139 cmd.kill_on_drop(true);
145
146 let timeout_dur = parsed
147 .timeout_secs
148 .map(Duration::from_secs)
149 .unwrap_or(self.timeout);
150
151 let cancellation = ctx.cancellation.clone();
158 let output = {
159 let child = match cmd.spawn() {
160 Ok(c) => c,
161 Err(e) => return ToolOutcome::error(format!("bash spawn: {e}"), true),
162 };
163 tokio::select! {
164 res = child.wait_with_output() => match res {
165 Ok(o) => o,
166 Err(e) => return ToolOutcome::error(format!("bash: {e}"), true),
167 },
168 _ = tokio::time::sleep(timeout_dur) => {
169 return ToolOutcome::error(
171 format!("bash: timed out after {}s", timeout_dur.as_secs()),
172 true,
173 );
174 }
175 _ = cancellation.cancelled() => {
176 return ToolOutcome::Cancelled;
177 }
178 }
179 };
180
181 let stdout = String::from_utf8_lossy(&output.stdout);
182 let stderr = String::from_utf8_lossy(&output.stderr);
183 let code = output.status.code();
184
185 let mut combined = String::new();
186 if !stdout.is_empty() {
187 combined.push_str("STDOUT:\n");
188 combined.push_str(&stdout);
189 }
190 if !stderr.is_empty() {
191 if !combined.is_empty() {
192 combined.push_str("\n\n");
193 }
194 combined.push_str("STDERR:\n");
195 combined.push_str(&stderr);
196 }
197 let (combined, truncated) = if combined.len() > MAX_OUTPUT_BYTES {
198 (
199 format!(
200 "{}\n\n[truncated at {MAX_OUTPUT_BYTES} bytes]",
201 &combined[..MAX_OUTPUT_BYTES]
202 ),
203 true,
204 )
205 } else {
206 (combined, false)
207 };
208
209 let tail = match code {
210 Some(0) => String::new(),
211 Some(c) => format!("\n\n[exit code: {c}]"),
212 None => "\n\n[exit: killed by signal]".to_string(),
213 };
214
215 ToolOutcome::Success(ToolOutput {
216 content: vec![Content::text(format!("{combined}{tail}"))],
217 truncated,
218 })
219 }
220}
221
222#[derive(Debug, Deserialize)]
223struct BashInput {
224 command: String,
225 #[serde(default)]
226 timeout_secs: Option<u64>,
227}
228
229fn default_schema() -> Value {
230 static SCHEMA: OnceLock<Value> = OnceLock::new();
231 SCHEMA
232 .get_or_init(|| {
233 json!({
234 "type": "object",
235 "required": ["command"],
236 "properties": {
237 "command": {
238 "type": "string",
239 "description": "The shell command to execute."
240 },
241 "timeout_secs": {
242 "type": "integer",
243 "description": "Optional per-call timeout in seconds.",
244 "minimum": 1
245 }
246 },
247 "additionalProperties": false
248 })
249 })
250 .clone()
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Shorthand for `ToolContext::null()`.
    fn context() -> ToolContext {
        ToolContext::null()
    }

    /// Flattens an outcome into a single string for assertions: the text
    /// parts of a success (joined with newlines), or the message/reason of
    /// the other variants.
    fn outcome_text(outcome: &ToolOutcome) -> String {
        match outcome {
            ToolOutcome::Success(output) => output
                .content
                .iter()
                .filter_map(|c| match c {
                    Content::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n"),
            ToolOutcome::ExecutionError { message, .. } => message.clone(),
            ToolOutcome::Denied { reason } => reason.clone(),
            ToolOutcome::Cancelled => String::from("<cancelled>"),
        }
    }

    #[tokio::test]
    async fn happy_path_captures_stdout() {
        let tool = BashTool::default();
        let outcome = tool
            .execute("bash", json!({"command": "echo hello world"}), &context())
            .await;
        assert!(matches!(outcome, ToolOutcome::Success(_)));
        let text = outcome_text(&outcome);
        assert!(text.contains("hello world"), "missing stdout: {text}");
    }

    /// A command that outlives the timeout must produce a prompt error
    /// instead of blocking until the command finishes.
    #[tokio::test]
    async fn timeout_kills_subprocess_not_leaks_it() {
        let tool = BashTool::default().with_timeout(Duration::from_secs(1));
        let start = Instant::now();
        let outcome = tool
            .execute("bash", json!({"command": "sleep 30"}), &context())
            .await;
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_secs(3),
            "bash did not return promptly on timeout: took {elapsed:?}"
        );
        match outcome {
            ToolOutcome::ExecutionError { message, .. } => {
                assert!(message.contains("timed out"), "{message}");
            }
            other => panic!("expected ExecutionError, got {other:?}"),
        }
    }

    /// Cancelling the context's token while a command runs must end the call
    /// promptly with `Cancelled`.
    #[tokio::test]
    async fn cancellation_interrupts_running_command() {
        let tool = BashTool::default().with_timeout(Duration::from_secs(30));
        let ctx = ToolContext::null();
        // The spawned task cancels through a clone of the context's token
        // after the command has had time to start.
        let token = ctx.cancellation.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            token.cancel();
        });

        let start = Instant::now();
        let outcome = tool
            .execute("bash", json!({"command": "sleep 30"}), &ctx)
            .await;
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_secs(3),
            "cancellation was not prompt: took {elapsed:?}"
        );
        assert!(matches!(outcome, ToolOutcome::Cancelled), "got {outcome:?}");
    }

    /// With an allowlist, variables that are not listed must not reach the
    /// child process.
    #[tokio::test]
    async fn env_allowlist_hides_unlisted_vars() {
        std::env::set_var("OHARNESS_BASH_TEST_SECRET", "should-not-leak");

        let tool = BashTool::default().with_env_allowlist(["PATH", "HOME"]);
        let outcome = tool
            .execute("bash", json!({"command": "env"}), &context())
            .await;
        // Clean up before asserting so a failure does not leave the variable
        // set for the rest of the test process.
        std::env::remove_var("OHARNESS_BASH_TEST_SECRET");

        // Without this check the "does not contain" assertion below would
        // also pass when the command never ran.
        assert!(matches!(outcome, ToolOutcome::Success(_)), "got {outcome:?}");
        let text = outcome_text(&outcome);
        assert!(
            !text.contains("OHARNESS_BASH_TEST_SECRET"),
            "secret env var leaked through allowlist: {text}"
        );
    }

    /// Without an allowlist, the child sees the parent's environment.
    #[tokio::test]
    async fn no_allowlist_inherits_env() {
        std::env::set_var("OHARNESS_BASH_PASSTHROUGH", "visible");

        let tool = BashTool::default();
        let outcome = tool
            .execute("bash", json!({"command": "env"}), &context())
            .await;
        // Clean up before asserting so a failure does not leave the variable
        // set for the rest of the test process.
        std::env::remove_var("OHARNESS_BASH_PASSTHROUGH");

        assert!(matches!(outcome, ToolOutcome::Success(_)), "got {outcome:?}");
        let text = outcome_text(&outcome);
        assert!(
            text.contains("OHARNESS_BASH_PASSTHROUGH"),
            "expected env var to passthrough without allowlist: {text}"
        );
    }

    /// Output larger than the cap must set the `truncated` flag and carry the
    /// truncation marker in its text.
    #[tokio::test]
    async fn large_output_is_truncated_flagged() {
        let tool = BashTool::default();
        let outcome = tool
            .execute(
                "bash",
                json!({"command": "yes foo | head -c 200000"}),
                &context(),
            )
            .await;
        let text = outcome_text(&outcome);
        match outcome {
            ToolOutcome::Success(output) => {
                assert!(output.truncated, "truncated flag not set");
                assert!(text.contains("truncated at"), "missing truncation marker");
            }
            other => panic!("expected Success, got {other:?}"),
        }
    }
}