use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use async_trait::async_trait;
use caliban_agent_core::{Tool, ToolContext, ToolError};
use caliban_provider::{ContentBlock, TextBlock};
use caliban_sandbox::SandboxedShim;
use serde::Deserialize;
use serde_json::{Value, json};
use tokio::io::AsyncReadExt;
use super::bash_bg::{BashBgRegistry, global_registry, spawn_background};
use crate::workspace::WorkspaceRoot;
/// Maximum number of stdout bytes kept from a foreground command (30 KiB).
const STDOUT_CAP: usize = 30 << 10;
/// Maximum number of stderr bytes kept from a foreground command (30 KiB).
const STDERR_CAP: usize = 30 << 10;
/// The `Bash` tool: runs a shell command in the workspace, either in the foreground or as a
/// tracked background shell.
#[derive(Debug)]
pub struct BashTool {
    /// Workspace against which a relative `cwd` argument is resolved.
    root: Arc<WorkspaceRoot>,
    /// JSON schema for the tool input, built on first use.
    schema: OnceLock<Value>,
    /// Optional sandbox shim handed to the shell builder; `None` means no shim is supplied.
    sandbox: Option<Arc<SandboxedShim>>,
    /// Registry that background shells are spawned into.
    bg_registry: Arc<BashBgRegistry>,
}
impl BashTool {
    /// Creates a tool rooted at `root` with no sandbox shim and the process-wide background
    /// registry.
    #[must_use]
    pub fn new(root: WorkspaceRoot) -> Self {
        Self::with_sandbox(root, None)
    }

    /// Creates a tool rooted at `root` that passes `sandbox` to the shell builder, using the
    /// process-wide background registry.
    #[must_use]
    pub fn with_sandbox(root: WorkspaceRoot, sandbox: Option<Arc<SandboxedShim>>) -> Self {
        Self {
            root: Arc::new(root),
            schema: OnceLock::new(),
            sandbox,
            bg_registry: global_registry(),
        }
    }

    /// Swaps in a different background-shell registry (builder style).
    #[must_use]
    pub fn with_bg_registry(self, registry: Arc<BashBgRegistry>) -> Self {
        Self {
            bg_registry: registry,
            ..self
        }
    }

    /// Returns `true` only when a sandbox shim is configured and it reports itself active.
    #[must_use]
    pub fn is_sandboxed(&self) -> bool {
        matches!(&self.sandbox, Some(shim) if shim.is_active())
    }
}
/// Arguments of a `Bash` tool call, deserialized from the JSON described by `input_schema`.
#[derive(Debug, Deserialize)]
struct BashInput {
    /// The shell command line to run.
    command: String,
    /// Requested timeout; `invoke` falls back to 60 and clamps the value to `1..=600`.
    #[serde(default)]
    timeout_seconds: Option<u64>,
    /// Working directory relative to the workspace root; absent means the root itself.
    #[serde(default)]
    cwd: Option<String>,
    /// Run as a tracked background shell instead of waiting for completion.
    #[serde(default)]
    background: bool,
}
/// Reads `reader` to EOF, keeps at most `cap` bytes, and returns them lossily decoded as UTF-8.
///
/// Bytes past `cap` are still read, then discarded. Stopping early would drop the reader
/// (closing the pipe's read end), and the child would get `SIGPIPE`/`EPIPE` on its next write.
/// That would turn large-but-healthy output into a "killed by signal" failure.
/// When anything was discarded, a one-line truncation notice is appended so the cut is visible.
/// A read error is treated like EOF: whatever was captured so far is returned.
async fn read_capped<R: AsyncReadExt + Unpin>(mut reader: R, cap: usize) -> String {
    let mut buf = Vec::with_capacity(cap);
    let mut chunk = [0u8; 4096];
    let mut dropped: usize = 0;
    loop {
        let n = match reader.read(&mut chunk).await {
            Ok(0) | Err(_) => break,
            Ok(n) => n,
        };
        // Keep what still fits; count (but keep draining) the rest.
        let take = n.min(cap.saturating_sub(buf.len()));
        buf.extend_from_slice(&chunk[..take]);
        dropped = dropped.saturating_add(n - take);
    }
    let text = String::from_utf8_lossy(&buf).into_owned();
    if dropped == 0 {
        text
    } else {
        format!("{text}\n... (output truncated: {dropped} more bytes not shown)")
    }
}
/// Which event ended the foreground wait in `BashTool::invoke`.
enum RawOutcome {
    /// Waiting on the shell finished; carries the result of `wait`.
    Done(Result<std::process::ExitStatus, std::io::Error>),
    /// The per-call timeout elapsed first.
    Timeout,
    /// The call's cancellation token fired first.
    Cancelled,
}
/// Best-effort forced termination of a foreground command, followed by reaping the shell.
///
/// On Unix, when the pid is known, `SIGKILL` is first sent through
/// `super::signal_process_group`. This is meant to also take down descendants the shell
/// started (the `cancellation_kills_subprocess_tree` test exercises this).
/// NOTE(review): it presumably relies on `build_shell` making the shell a process-group
/// leader — confirm there.
/// Then the direct child is killed through tokio and awaited. Every error is ignored, because
/// this only runs on the timeout/cancellation paths where there is nothing better to do.
// `child_pid` is only read by the Unix branch; elsewhere it is intentionally unused.
#[cfg_attr(not(unix), allow(unused_variables))]
async fn kill_process_tree(child_pid: Option<u32>, child: &mut tokio::process::Child) {
    #[cfg(unix)]
    if let Some(pid) = child_pid {
        super::signal_process_group(pid, libc::SIGKILL);
    }
    let _ = child.start_kill();
    let _ = child.wait().await;
}
#[async_trait]
impl Tool for BashTool {
    fn name(&self) -> &'static str {
        "Bash"
    }

    fn description(&self) -> &'static str {
        "Run a shell command via /bin/sh -c. Captures stdout and stderr. Enforces a timeout. Returns exit code, stdout, and stderr."
    }

    fn input_schema(&self) -> &Value {
        self.schema.get_or_init(|| json!({
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "Shell command to run"
                },
                "timeout_seconds": {
                    "type": "integer",
                    "description": "Maximum seconds to wait before killing the process (default 60, min 1, max 600)",
                    "minimum": 1,
                    "maximum": 600
                },
                "cwd": {
                    "type": "string",
                    "description": "Working directory for the command, relative to workspace root (default: workspace root)"
                },
                "background": {
                    "type": "boolean",
                    "description": "When true, runs the command in the background; the call returns immediately with a shell_id usable with BashOutput and KillShell."
                }
            },
            "required": ["command"]
        }))
    }

    /// Runs the requested command and reports its exit code, stdout and stderr.
    ///
    /// With `background: true` the command is handed to `spawn_background`, and the call
    /// returns immediately with the new shell id.
    ///
    /// Otherwise the command runs in the foreground:
    /// - stdout and stderr are captured, each capped at `STDOUT_CAP` / `STDERR_CAP`;
    /// - the run is bounded by `timeout_seconds` (default 60, clamped to `1..=600`) and by
    ///   `cx.cancel`.
    ///
    /// # Errors
    ///
    /// - Errors from `crate::parse_input`, `WorkspaceRoot::resolve`, `spawn_background` or
    ///   `build_shell` are propagated.
    /// - `ToolError::Execution` if spawning or waiting fails, or if the command exits non-zero
    ///   or by signal (the message carries the full report).
    /// - `ToolError::Execution` if the timeout elapses; the process tree is killed first.
    /// - `ToolError::Cancelled` if `cx.cancel` fires first; the process tree is killed first.
    async fn invoke(&self, input: Value, cx: ToolContext) -> Result<Vec<ContentBlock>, ToolError> {
        let parsed: BashInput = crate::parse_input(input)?;
        let cwd = match parsed.cwd {
            Some(ref c) => self.root.resolve(c)?,
            None => self.root.root().to_path_buf(),
        };

        if parsed.background {
            let id = spawn_background(
                &self.bg_registry,
                parsed.command.clone(),
                &cwd,
                self.sandbox.as_ref(),
            )?;
            return Ok(vec![ContentBlock::Text(TextBlock {
                text: format!(
                    "Started background shell {id}: {} (use BashOutput/KillShell with shell_id={id})",
                    parsed.command,
                ),
                cache_control: None,
            })]);
        }

        let timeout_secs = parsed.timeout_seconds.unwrap_or(60).clamp(1, 600);
        let timeout = Duration::from_secs(timeout_secs);
        let mut shell = super::bash_bg::build_shell(&parsed.command, &cwd, self.sandbox.as_ref())?;
        let mut child = shell.spawn().map_err(ToolError::execution)?;
        let child_pid = child.id();
        let stdout_pipe = child.stdout.take().expect("piped");
        let stderr_pipe = child.stderr.take().expect("piped");
        // Drain both pipes in their own tasks so the child never blocks on a full pipe.
        let mut read_out = tokio::spawn(read_capped(stdout_pipe, STDOUT_CAP));
        let mut read_err = tokio::spawn(read_capped(stderr_pipe, STDERR_CAP));

        // The timeout and cancellation must cover both the shell exiting AND its pipes reaching
        // EOF. A backgrounded grandchild (`cmd &`) inherits the pipes and can hold them open
        // long after `/bin/sh` exits. Awaiting the readers outside this race would hang the
        // call past its timeout and ignore cancellation.
        let mut stdout_str = String::new();
        let mut stderr_str = String::new();
        let outcome = {
            let finished = async {
                let status = child.wait().await;
                if status.is_ok() {
                    stdout_str = (&mut read_out).await.unwrap_or_default();
                    stderr_str = (&mut read_err).await.unwrap_or_default();
                }
                status
            };
            tokio::pin!(finished);
            tokio::select! {
                result = &mut finished => RawOutcome::Done(result),
                () = tokio::time::sleep(timeout) => RawOutcome::Timeout,
                () = cx.cancel.cancelled() => RawOutcome::Cancelled,
            }
        };

        match outcome {
            RawOutcome::Done(status_result) => {
                let status = status_result.map_err(ToolError::execution)?;
                let exit_code_num = status.code();
                let exit_code_display = exit_code_num
                    .map_or_else(|| "(killed by signal)".to_string(), |c| c.to_string());
                let stdout_section = if stdout_str.is_empty() {
                    "(empty)".to_string()
                } else {
                    stdout_str
                };
                let stderr_section = if stderr_str.is_empty() {
                    "(empty)".to_string()
                } else {
                    stderr_str
                };
                let text = format!(
                    "→ Bash command: {}\n→ Exit code: {}\n→ Stdout:\n{}\n→ Stderr:\n{}",
                    parsed.command, exit_code_display, stdout_section, stderr_section,
                );
                // A non-zero exit (or death by signal) surfaces as an error carrying the report.
                if exit_code_num != Some(0) {
                    return Err(ToolError::execution(std::io::Error::other(text)));
                }
                Ok(vec![ContentBlock::Text(TextBlock {
                    text,
                    cache_control: None,
                })])
            }
            RawOutcome::Timeout => {
                // Also reaches descendants still holding the pipes after the shell exited.
                kill_process_tree(child_pid, &mut child).await;
                read_out.abort();
                read_err.abort();
                Err(ToolError::execution(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    format!("command timed out after {timeout_secs}s"),
                )))
            }
            RawOutcome::Cancelled => {
                kill_process_tree(child_pid, &mut child).await;
                read_out.abort();
                read_err.abort();
                Err(ToolError::Cancelled)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use serde_json::json;
    use tempfile::TempDir;
    use tokio_util::sync::CancellationToken;

    use super::*;

    /// A default (foreground, global-registry) tool rooted at `dir`.
    fn tool_in(dir: &TempDir) -> BashTool {
        BashTool::new(WorkspaceRoot::new(dir.path()))
    }

    /// A tool context whose cancellation is driven by `cancel`.
    fn ctx_with(cancel: CancellationToken) -> ToolContext {
        ToolContext {
            tool_use_id: "t1".into(),
            cancel,
            hooks: None,
            turn_index: 0,
        }
    }

    /// A tool context that is never cancelled.
    fn ctx() -> ToolContext {
        ctx_with(CancellationToken::new())
    }

    /// A token that a spawned task cancels roughly 100 ms from now (needs a running runtime).
    fn cancelled_soon() -> CancellationToken {
        let token = CancellationToken::new();
        let trigger = token.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            trigger.cancel();
        });
        token
    }

    /// Runs `command` and asserts it fails with `ToolError::Execution` mentioning `needle`.
    async fn assert_execution_error(command: &str, needle: &str) {
        let tmp = TempDir::new().unwrap();
        let err = tool_in(&tmp)
            .invoke(json!({ "command": command }), ctx())
            .await
            .unwrap_err();
        let msg = format!("{err}");
        assert!(
            matches!(err, ToolError::Execution(_)),
            "wrong variant: {err:?}"
        );
        assert!(msg.contains(needle), "msg: {msg}");
    }

    #[tokio::test]
    async fn echo_succeeds() {
        let tmp = TempDir::new().unwrap();
        let out = tool_in(&tmp)
            .invoke(json!({ "command": "echo hi" }), ctx())
            .await
            .unwrap();
        let ContentBlock::Text(block) = &out[0] else {
            panic!("expected Text block")
        };
        for needle in ["hi", "Exit code: 0"] {
            assert!(block.text.contains(needle), "output: {}", block.text);
        }
    }

    #[tokio::test]
    async fn nonzero_exit_returns_tool_error() {
        assert_execution_error("exit 7", "Exit code: 7").await;
    }

    #[tokio::test]
    async fn command_not_found_returns_tool_error() {
        assert_execution_error("this-command-definitely-does-not-exist-zzz", "Exit code: 127")
            .await;
    }

    #[tokio::test]
    async fn timeout_fires() {
        let tmp = TempDir::new().unwrap();
        let tool = tool_in(&tmp);
        let start = Instant::now();
        let err = tool
            .invoke(json!({ "command": "sleep 5", "timeout_seconds": 1 }), ctx())
            .await
            .unwrap_err();
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_secs() < 3,
            "timeout took too long: {elapsed:?}"
        );
        let s = format!("{err}");
        let lowered = s.to_lowercase();
        assert!(
            lowered.contains("timed out") || lowered.contains("timeout"),
            "error message: {s}"
        );
    }

    #[tokio::test]
    async fn cancellation_kills_process() {
        let tmp = TempDir::new().unwrap();
        let tool = tool_in(&tmp);
        let cx = ctx_with(cancelled_soon());
        let start = Instant::now();
        let err = tool
            .invoke(json!({ "command": "sleep 30" }), cx)
            .await
            .unwrap_err();
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_millis() < 1000,
            "cancellation took too long: {elapsed:?}"
        );
        assert!(
            matches!(err, ToolError::Cancelled),
            "expected Cancelled, got: {err}"
        );
    }

    #[tokio::test]
    async fn background_true_returns_immediately_with_shell_id() {
        let tmp = TempDir::new().unwrap();
        let reg = crate::shell::bash_bg::BashBgRegistry::new_for_test(1024 * 1024);
        let tool = tool_in(&tmp).with_bg_registry(reg.clone());
        let start = Instant::now();
        let out = tool
            .invoke(json!({ "command": "sleep 30", "background": true }), ctx())
            .await
            .unwrap();
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_secs(1),
            "background invoke blocked: {elapsed:?}"
        );
        let ContentBlock::Text(block) = &out[0] else {
            panic!("expected Text")
        };
        assert!(
            block.text.contains("Started background shell"),
            "out: {}",
            block.text
        );
        assert_eq!(reg.running_count(), 1);
        reg.kill_all();
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn cancellation_kills_subprocess_tree() {
        use std::process::Command;

        let tmp = TempDir::new().unwrap();
        let tool = tool_in(&tmp);
        // A per-process sleep duration doubles as a unique marker to search for in `ps`.
        let marker_seconds: u32 = 30000 + (std::process::id() % 1000);
        let cmd = format!("/bin/sleep {marker_seconds} & wait");
        let cx = ctx_with(cancelled_soon());

        let err = tool.invoke(json!({ "command": cmd }), cx).await.unwrap_err();
        assert!(matches!(err, ToolError::Cancelled));

        // Give the kernel a moment to tear the killed processes down.
        tokio::time::sleep(Duration::from_millis(300)).await;
        let needle = format!("sleep {marker_seconds}");
        let output = Command::new("ps")
            .args(["-eo", "pid,command"])
            .output()
            .expect("ps should run");
        let ps_text = String::from_utf8_lossy(&output.stdout);
        let surviving: Vec<&str> = ps_text
            .lines()
            .filter(|line| {
                line.contains(needle.as_str())
                    && !["grep", "ps -eo"].iter().any(|noise| line.contains(*noise))
            })
            .collect();
        assert!(
            surviving.is_empty(),
            "subprocess(es) survived cancellation:\n{}",
            surviving.join("\n"),
        );
    }
}