use crate::context::ToolContext;
use crate::toolset::{ToolOutcome, ToolSet};
use async_trait::async_trait;
use oharness_core::message::{Content, ToolOutput};
use oharness_core::ToolSpec;
use serde::Deserialize;
use serde_json::{json, Value};
use std::process::Stdio;
use std::sync::OnceLock;
use std::time::Duration;
use tokio::process::Command;
/// Timeout, in seconds, used when neither the builder nor the individual call overrides it.
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Upper bound, in bytes, on the rendered stdout/stderr text handed back to the caller (64 KiB).
const MAX_OUTPUT_BYTES: usize = 1024 * 64;
/// A [`ToolSet`] exposing a single tool that runs shell commands through `/bin/bash -c`.
pub struct BashTool {
    /// The one spec this tool set advertises; built once in `BashTool::new`.
    specs: Vec<ToolSpec>,
    /// Tool name this instance answers to; calls for any other name are rejected.
    name: String,
    /// When `Some`, the child environment is cleared and only these variables are copied in;
    /// when `None`, the child inherits this process's environment.
    env_allowlist: Option<Vec<String>>,
    /// Timeout applied when a call does not supply its own `timeout_secs`.
    timeout: Duration,
}
impl Default for BashTool {
fn default() -> Self {
Self::new("bash")
}
}
impl BashTool {
pub fn new(name: impl Into<String>) -> Self {
let name = name.into();
let specs = vec![ToolSpec {
name: name.clone(),
description: "Execute a shell command via `/bin/bash -c <command>`. Returns \
combined stdout/stderr. Commands run in the configured \
workspace directory, or the current directory if no workspace \
is set. Output is truncated at 64KiB."
.to_string(),
input_schema: default_schema(),
}];
Self {
name,
timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
env_allowlist: None,
specs,
}
}
pub fn with_timeout(mut self, d: Duration) -> Self {
self.timeout = d;
self
}
pub fn with_env_allowlist<I, S>(mut self, names: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.env_allowlist = Some(names.into_iter().map(Into::into).collect());
self
}
}
#[async_trait]
impl ToolSet for BashTool {
fn specs(&self) -> &[ToolSpec] {
&self.specs
}
async fn execute(&self, name: &str, input: Value, ctx: &ToolContext) -> ToolOutcome {
if name != self.name {
return ToolOutcome::error(format!("tool `{name}` not handled by BashTool"), false);
}
if ctx.cancellation.is_cancelled() {
return ToolOutcome::Cancelled;
}
let parsed: BashInput = match serde_json::from_value(input) {
Ok(v) => v,
Err(e) => return ToolOutcome::error(format!("invalid bash input: {e}"), false),
};
let mut cmd = Command::new("/bin/bash");
cmd.arg("-c").arg(&parsed.command);
if let Some(ws) = ctx.workspace_path() {
cmd.current_dir(ws);
}
if let Some(names) = &self.env_allowlist {
cmd.env_clear();
for name in names {
if let Ok(val) = std::env::var(name) {
cmd.env(name, val);
}
}
}
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
cmd.stdin(Stdio::null());
cmd.kill_on_drop(true);
let timeout_dur = parsed
.timeout_secs
.map(Duration::from_secs)
.unwrap_or(self.timeout);
let cancellation = ctx.cancellation.clone();
let output = {
let child = match cmd.spawn() {
Ok(c) => c,
Err(e) => return ToolOutcome::error(format!("bash spawn: {e}"), true),
};
tokio::select! {
res = child.wait_with_output() => match res {
Ok(o) => o,
Err(e) => return ToolOutcome::error(format!("bash: {e}"), true),
},
_ = tokio::time::sleep(timeout_dur) => {
return ToolOutcome::error(
format!("bash: timed out after {}s", timeout_dur.as_secs()),
true,
);
}
_ = cancellation.cancelled() => {
return ToolOutcome::Cancelled;
}
}
};
let stdout = String::from_utf8_lossy(&output.stdout);
let stderr = String::from_utf8_lossy(&output.stderr);
let code = output.status.code();
let mut combined = String::new();
if !stdout.is_empty() {
combined.push_str("STDOUT:\n");
combined.push_str(&stdout);
}
if !stderr.is_empty() {
if !combined.is_empty() {
combined.push_str("\n\n");
}
combined.push_str("STDERR:\n");
combined.push_str(&stderr);
}
let (combined, truncated) = if combined.len() > MAX_OUTPUT_BYTES {
(
format!(
"{}\n\n[truncated at {MAX_OUTPUT_BYTES} bytes]",
&combined[..MAX_OUTPUT_BYTES]
),
true,
)
} else {
(combined, false)
};
let tail = match code {
Some(0) => String::new(),
Some(c) => format!("\n\n[exit code: {c}]"),
None => "\n\n[exit: killed by signal]".to_string(),
};
ToolOutcome::Success(ToolOutput {
content: vec![Content::text(format!("{combined}{tail}"))],
truncated,
})
}
}
/// Arguments of a single bash tool call, deserialized from the call's JSON input.
#[derive(Deserialize, Debug)]
struct BashInput {
    /// Shell source handed verbatim to `/bin/bash -c`.
    command: String,
    /// Optional per-call timeout in seconds; when absent the tool's configured timeout applies.
    #[serde(default)]
    timeout_secs: Option<u64>,
}
/// JSON Schema describing `BashInput`, built on first use and cloned for each caller.
fn default_schema() -> Value {
    static CACHED: OnceLock<Value> = OnceLock::new();
    let schema = CACHED.get_or_init(|| {
        let command_prop = json!({
            "type": "string",
            "description": "The shell command to execute."
        });
        let timeout_prop = json!({
            "type": "integer",
            "description": "Optional per-call timeout in seconds.",
            "minimum": 1
        });
        json!({
            "type": "object",
            "required": ["command"],
            "properties": {
                "command": command_prop,
                "timeout_secs": timeout_prop
            },
            "additionalProperties": false
        })
    });
    schema.clone()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Shorthand for `ToolContext::null()`.
    fn context() -> ToolContext {
        ToolContext::null()
    }

    /// Flattens an outcome into a string: text content for successes, the message/reason for
    /// errors and denials, and a fixed marker for cancellation.
    fn outcome_text(outcome: &ToolOutcome) -> String {
        match outcome {
            ToolOutcome::Success(output) => output
                .content
                .iter()
                .filter_map(|c| match c {
                    Content::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n"),
            ToolOutcome::ExecutionError { message, .. } => message.clone(),
            ToolOutcome::Denied { reason } => reason.clone(),
            ToolOutcome::Cancelled => String::from("<cancelled>"),
        }
    }

    #[tokio::test]
    async fn happy_path_captures_stdout() {
        let tool = BashTool::default();
        let outcome = tool
            .execute("bash", json!({"command": "echo hello world"}), &context())
            .await;
        assert!(matches!(outcome, ToolOutcome::Success(_)));
        let text = outcome_text(&outcome);
        assert!(text.contains("hello world"), "missing stdout: {text}");
    }

    #[tokio::test]
    async fn foreign_tool_name_is_rejected() {
        let tool = BashTool::default();
        let outcome = tool
            .execute("not-bash", json!({"command": "true"}), &context())
            .await;
        match outcome {
            ToolOutcome::ExecutionError { message, .. } => {
                assert!(message.contains("not handled by BashTool"), "{message}");
            }
            other => panic!("expected ExecutionError, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn malformed_input_is_rejected() {
        let tool = BashTool::default();
        let outcome = tool
            .execute("bash", json!({"cmd": "echo hi"}), &context())
            .await;
        match outcome {
            ToolOutcome::ExecutionError { message, .. } => {
                assert!(message.contains("invalid bash input"), "{message}");
            }
            other => panic!("expected ExecutionError, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn nonzero_exit_code_is_reported() {
        let tool = BashTool::default();
        let outcome = tool
            .execute("bash", json!({"command": "echo oops >&2; exit 3"}), &context())
            .await;
        assert!(matches!(outcome, ToolOutcome::Success(_)), "got {outcome:?}");
        let text = outcome_text(&outcome);
        assert!(text.contains("STDERR:\noops"), "missing stderr: {text}");
        assert!(text.contains("[exit code: 3]"), "missing exit code: {text}");
    }

    #[tokio::test]
    async fn timeout_kills_subprocess_not_leaks_it() {
        let tool = BashTool::default().with_timeout(Duration::from_secs(1));
        let start = Instant::now();
        let outcome = tool
            .execute("bash", json!({"command": "sleep 30"}), &context())
            .await;
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_secs(3),
            "bash did not return promptly on timeout: took {elapsed:?}"
        );
        match outcome {
            ToolOutcome::ExecutionError { message, .. } => {
                assert!(message.contains("timed out"), "{message}");
            }
            other => panic!("expected ExecutionError, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn cancellation_interrupts_running_command() {
        let tool = BashTool::default().with_timeout(Duration::from_secs(30));
        let ctx = ToolContext::null();
        // Cancel through a clone of the context's own token shortly after the command starts.
        let token = ctx.cancellation.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            token.cancel();
        });
        let start = Instant::now();
        let outcome = tool
            .execute("bash", json!({"command": "sleep 30"}), &ctx)
            .await;
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_secs(3),
            "cancellation was not prompt: took {elapsed:?}"
        );
        assert!(matches!(outcome, ToolOutcome::Cancelled), "got {outcome:?}");
    }

    #[tokio::test]
    async fn env_allowlist_hides_unlisted_vars() {
        std::env::set_var("OHARNESS_BASH_TEST_SECRET", "should-not-leak");
        let tool = BashTool::default().with_env_allowlist(["PATH", "HOME"]);
        let outcome = tool
            .execute("bash", json!({"command": "env"}), &context())
            .await;
        let text = outcome_text(&outcome);
        assert!(
            !text.contains("OHARNESS_BASH_TEST_SECRET"),
            "secret env var leaked through allowlist: {text}"
        );
        std::env::remove_var("OHARNESS_BASH_TEST_SECRET");
    }

    #[tokio::test]
    async fn no_allowlist_inherits_env() {
        std::env::set_var("OHARNESS_BASH_PASSTHROUGH", "visible");
        let tool = BashTool::default();
        let outcome = tool
            .execute("bash", json!({"command": "env"}), &context())
            .await;
        let text = outcome_text(&outcome);
        assert!(
            text.contains("OHARNESS_BASH_PASSTHROUGH"),
            "expected env var to passthrough without allowlist: {text}"
        );
        std::env::remove_var("OHARNESS_BASH_PASSTHROUGH");
    }

    #[tokio::test]
    async fn large_output_is_truncated_flagged() {
        let tool = BashTool::default();
        let outcome = tool
            .execute(
                "bash",
                json!({"command": "yes foo | head -c 200000"}),
                &context(),
            )
            .await;
        match outcome {
            ToolOutcome::Success(output) => {
                assert!(output.truncated, "truncated flag not set");
                let text = output
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        Content::Text { text } => Some(text.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("");
                assert!(text.contains("truncated at"), "missing truncation marker");
            }
            other => panic!("expected Success, got {other:?}"),
        }
    }
}