use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::OnceLock;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use caliban_agent_core::{Tool, ToolContext, ToolError};
use caliban_provider::{ContentBlock, TextBlock};
use caliban_sandbox::SandboxedShim;
use serde::Deserialize;
use serde_json::{Value, json};
/// Default per-stream cap, in bytes, for the stdout and stderr `RingBuffer`s of
/// each background job; used by `BashBgRegistry::new`.
///
/// NOTE(review): `5 * 1024 * 1024 * 1024` is 5 GiB per stream, so up to 10 GiB
/// can be retained per job. The product also overflows `usize` on 32-bit
/// targets. If 5 MiB was intended, drop one `* 1024` — confirm.
pub const DEFAULT_RING_CAP_BYTES: usize = 5 * 1024 * 1024 * 1024;
/// Default time `KillShellTool::new` waits after the initial (non-forced) kill
/// before escalating to a forced kill.
pub const KILL_GRACE: Duration = Duration::from_secs(5);
/// Lifecycle state of a background shell job.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BashStatus {
    /// No exit has been observed yet.
    Running,
    /// The process exited with the given exit code.
    Exited(i32),
    /// The job ended without an exit code, was cancelled, or waiting on it failed.
    Killed,
}

impl BashStatus {
    /// Lowercase label used in tool output.
    ///
    /// The exit code carried by `Exited` is not part of the label.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        use BashStatus::{Exited, Killed, Running};
        match self {
            Exited(_) => "exited",
            Killed => "killed",
            Running => "running",
        }
    }
}
/// Bounded byte buffer that keeps only the most recent `cap` bytes of a stream
/// while tracking absolute stream offsets.
///
/// Invariant: `dropped + buf.len() == written`. In other words, the first byte
/// held in `buf` sits at absolute offset `dropped`, and the last one ends at
/// `written`.
#[derive(Debug)]
pub struct RingBuffer {
    /// Maximum number of bytes retained.
    cap: usize,
    /// Total bytes ever pushed (absolute end offset).
    written: u64,
    /// Bytes evicted from the front (absolute offset of `buf[0]`).
    dropped: u64,
    /// Retained tail of the stream.
    buf: VecDeque<u8>,
}

impl RingBuffer {
    /// Creates an empty buffer that retains at most `cap` bytes.
    ///
    /// Only up to 64 KiB is preallocated; the deque grows on demand.
    #[must_use]
    pub fn with_cap(cap: usize) -> Self {
        Self {
            cap,
            written: 0,
            dropped: 0,
            buf: VecDeque::with_capacity(cap.min(64 * 1024)),
        }
    }

    /// Appends `bytes`, evicting the oldest data once `cap` is exceeded.
    ///
    /// Returns the new total number of bytes written (the absolute end offset).
    pub fn push(&mut self, bytes: &[u8]) -> u64 {
        let to_take = if bytes.len() > self.cap {
            // The chunk alone overflows the cap. Everything currently held,
            // plus the head of the chunk, is evicted.
            let start = bytes.len() - self.cap;
            self.buf.clear();
            self.dropped = self.written + start as u64;
            &bytes[start..]
        } else {
            bytes
        };
        let new_total = self.buf.len() + to_take.len();
        if new_total > self.cap {
            // `drop_n <= buf.len()` because `to_take.len() <= cap` here.
            let drop_n = new_total - self.cap;
            // Evict in bulk instead of popping one byte at a time.
            self.buf.drain(..drop_n);
            self.dropped += drop_n as u64;
        }
        self.buf.extend(to_take);
        self.written += bytes.len() as u64;
        self.written
    }

    /// Total bytes ever pushed, including evicted ones.
    #[must_use]
    pub fn written(&self) -> u64 {
        self.written
    }

    /// Bytes evicted from the front; also the absolute offset of the oldest
    /// retained byte.
    #[must_use]
    pub fn dropped(&self) -> u64 {
        self.dropped
    }

    /// Number of bytes currently retained.
    #[must_use]
    pub fn len(&self) -> usize {
        self.buf.len()
    }

    /// `true` when no bytes are retained.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }

    /// Returns all retained bytes (lossily decoded as UTF-8) and the absolute
    /// end offset.
    #[must_use]
    pub fn snapshot(&self) -> (String, u64) {
        let bytes: Vec<u8> = self.buf.iter().copied().collect();
        let text = String::from_utf8_lossy(&bytes).into_owned();
        (text, self.written)
    }

    /// Returns `(text, start, end)`: the retained bytes at absolute offsets
    /// `start..end`, decoded lossily as UTF-8.
    ///
    /// `start` is `since` clamped into the retained window `[dropped, written]`.
    /// Offsets before the window were evicted, and offsets past the end have no
    /// data yet, so `start <= end` always holds. `end` is `written`, which is
    /// the offset to pass on the next incremental poll.
    #[must_use]
    pub fn read_since(&self, since: u64) -> (String, u64, u64) {
        // `dropped <= written` by the struct invariant, so `clamp` cannot panic.
        let start_offset = since.clamp(self.dropped, self.written);
        let skip = usize::try_from(start_offset - self.dropped).unwrap_or(usize::MAX);
        let bytes: Vec<u8> = self.buf.iter().copied().skip(skip).collect();
        let text = String::from_utf8_lossy(&bytes).into_owned();
        (text, start_offset, self.written)
    }
}
/// State for one background shell started by `spawn_background`.
///
/// Shared via `Arc` between the registry, the stdout/stderr drain tasks and the
/// exit-watcher task.
pub struct BashJob {
/// Short random id returned to the caller (see `new_shell_id`).
pub id: String,
/// Command string passed to `/bin/sh -c`.
pub command: String,
/// When the job was created; reported as `age` by BashOutput.
pub started_at: Instant,
/// Current lifecycle status; set by the watcher task when the job ends.
pub status: Mutex<BashStatus>,
/// Most recent stdout bytes.
pub stdout: Mutex<RingBuffer>,
/// Most recent stderr bytes.
pub stderr: Mutex<RingBuffer>,
/// Child pid. On unix it is also the process-group id, because `build_shell`
/// calls `process_group(0)`. Reset to `None` by the watcher once the job ends.
pub pid: Mutex<Option<u32>>,
/// Cancelled by `kill_job_now`; the watcher task stops waiting when it fires.
pub cancel: tokio_util::sync::CancellationToken,
}
impl std::fmt::Debug for BashJob {
    /// Shows only the identifying fields and current status. Buffers, pid and
    /// the cancel token are elided (hence `finish_non_exhaustive`).
    ///
    /// Briefly locks `status`; panics if that mutex is poisoned.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("BashJob");
        builder.field("id", &self.id);
        builder.field("command", &self.command);
        builder.field("started_at", &self.started_at);
        builder.field("status", &*self.status.lock().unwrap());
        builder.finish_non_exhaustive()
    }
}
impl BashJob {
    /// Returns a copy of the job's current status.
    ///
    /// # Panics
    /// If the status mutex is poisoned.
    #[must_use]
    pub fn snapshot_status(&self) -> BashStatus {
        let guard = self.status.lock().unwrap();
        *guard
    }
}
/// Registry of background shell jobs, keyed by shell id.
pub struct BashBgRegistry {
/// Jobs by id. Within this file, entries leave only via `remove`, so finished
/// jobs stay listed (with their buffers) until removed.
jobs: Mutex<HashMap<String, Arc<BashJob>>>,
/// Per-stream ring-buffer cap that `spawn_background` gives each new job.
cap_bytes: usize,
}
impl std::fmt::Debug for BashBgRegistry {
    /// Reports the number of jobs (not the jobs themselves) and the buffer cap.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let job_total = self.jobs.lock().unwrap().len();
        let mut builder = f.debug_struct("BashBgRegistry");
        builder.field("jobs", &job_total);
        builder.field("cap_bytes", &self.cap_bytes);
        builder.finish()
    }
}
impl BashBgRegistry {
#[must_use]
pub fn new() -> Self {
Self::with_cap(DEFAULT_RING_CAP_BYTES)
}
#[must_use]
pub fn with_cap(cap_bytes: usize) -> Self {
Self {
jobs: Mutex::new(HashMap::new()),
cap_bytes,
}
}
#[must_use]
pub fn new_for_test(cap_bytes: usize) -> Arc<Self> {
Arc::new(Self::with_cap(cap_bytes))
}
#[must_use]
pub fn cap_bytes(&self) -> usize {
self.cap_bytes
}
#[must_use]
pub fn job_count(&self) -> usize {
self.jobs.lock().unwrap().len()
}
#[must_use]
pub fn running_count(&self) -> usize {
self.jobs
.lock()
.unwrap()
.values()
.filter(|j| j.snapshot_status() == BashStatus::Running)
.count()
}
pub fn insert(&self, job: Arc<BashJob>) {
self.jobs.lock().unwrap().insert(job.id.clone(), job);
}
#[must_use]
pub fn get(&self, id: &str) -> Option<Arc<BashJob>> {
self.jobs.lock().unwrap().get(id).cloned()
}
pub fn remove(&self, id: &str) -> Option<Arc<BashJob>> {
self.jobs.lock().unwrap().remove(id)
}
#[must_use]
pub fn list(&self) -> Vec<(String, BashStatus, String)> {
self.jobs
.lock()
.unwrap()
.values()
.map(|j| (j.id.clone(), j.snapshot_status(), j.command.clone()))
.collect()
}
pub fn kill_all(&self) {
let ids: Vec<String> = self.jobs.lock().unwrap().keys().cloned().collect();
for id in ids {
if let Some(job) = self.get(&id)
&& job.snapshot_status() == BashStatus::Running
{
kill_job_now(&job, true);
}
}
}
}
impl Default for BashBgRegistry {
fn default() -> Self {
Self::new()
}
}
static GLOBAL_REGISTRY: OnceLock<Arc<BashBgRegistry>> = OnceLock::new();

/// Returns the process-wide registry used by the `with_global_registry`
/// constructors, creating it with default settings on first use.
#[must_use]
pub fn global_registry() -> Arc<BashBgRegistry> {
    let shared = GLOBAL_REGISTRY.get_or_init(|| Arc::new(BashBgRegistry::new()));
    Arc::clone(shared)
}
/// Asks a background job to stop (`force_kill == false`) or kills it outright.
///
/// On unix the signal goes to the job's whole process group (the child is
/// spawned with `process_group(0)` in `build_shell`). The signal is SIGTERM
/// when `force_kill` is false and SIGKILL when it is true.
///
/// Only a forced kill cancels `job.cancel`. Cancelling makes the watcher task in
/// `spawn_background` mark the job `Killed` right away and SIGKILL only the
/// direct child. Doing that after a mere SIGTERM would defeat the grace period
/// in `KillShellTool`: the job would look finished immediately, escalation to a
/// group-wide SIGKILL would never happen, and group members that ignore
/// SIGTERM would survive. After a plain SIGTERM the watcher instead records
/// the child's real exit once `wait()` returns.
///
/// On non-unix targets no signal can be sent, so the watcher is always
/// cancelled.
fn kill_job_now(job: &BashJob, force_kill: bool) {
    let pid = *job.pid.lock().unwrap();
    #[cfg(unix)]
    {
        // `pid` is `None` once the watcher has finished; nothing to signal then.
        if let Some(p) = pid {
            let signal = if force_kill { libc::SIGKILL } else { libc::SIGTERM };
            super::signal_process_group(p, signal);
        }
        if force_kill {
            job.cancel.cancel();
        }
    }
    #[cfg(not(unix))]
    {
        let _ = (pid, force_kill);
        job.cancel.cancel();
    }
}
/// Returns a fresh 12-character id for a background shell.
///
/// The id is the first 12 characters of a random v4 UUID in its simple
/// (hyphen-free) form. That form is ASCII, so truncating by bytes is safe.
#[must_use]
pub fn new_shell_id() -> String {
    let mut id = uuid::Uuid::new_v4().simple().to_string();
    id.truncate(12);
    id
}
pub(super) fn build_shell(
command: &str,
cwd: &std::path::Path,
sandbox: Option<&Arc<SandboxedShim>>,
) -> Result<tokio::process::Command, ToolError> {
use std::process::Stdio;
let mut shell = tokio::process::Command::new("/bin/sh");
shell
.arg("-c")
.arg(command)
.current_dir(cwd)
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true);
#[cfg(unix)]
shell.process_group(0);
if let Some(shim) = sandbox {
shim.wrap_command(&mut shell, command)
.map_err(|e| ToolError::execution(std::io::Error::other(format!("sandbox: {e}"))))?;
}
Ok(shell)
}
/// Spawns `command` as a background job and returns its shell id immediately.
///
/// The process is built by `build_shell`: `/bin/sh -c` in `cwd`, optionally
/// wrapped by `sandbox`. It is registered in `registry` as `Running` before this
/// returns. Three detached tokio tasks are started:
/// - two copy stdout/stderr into the job's ring buffers, each capped at
///   `registry.cap_bytes()`;
/// - one waits for the child to exit or for `job.cancel` to fire, then records
///   the final status.
///
/// Must be called inside a tokio runtime, since it uses `tokio::spawn`.
///
/// # Errors
/// Returns a `ToolError` if the sandbox wrap fails or the process cannot be spawned.
///
/// # Panics
/// Panics via `expect("piped")` if the child has no stdout/stderr handle.
/// `build_shell` configures both as piped. NOTE(review): this assumes the
/// sandbox wrap does not override that.
pub fn spawn_background(
registry: &Arc<BashBgRegistry>,
command: String,
cwd: &std::path::Path,
sandbox: Option<&Arc<SandboxedShim>>,
) -> Result<String, ToolError> {
let id = new_shell_id();
let cap = registry.cap_bytes();
let mut shell = build_shell(&command, cwd, sandbox)?;
let mut child = shell.spawn().map_err(ToolError::execution)?;
// Tokio's `Child::id` is `None` once the child has been reaped.
let pid = child.id();
let stdout = child.stdout.take().expect("piped");
let stderr = child.stderr.take().expect("piped");
let job = Arc::new(BashJob {
id: id.clone(),
command,
started_at: Instant::now(),
status: Mutex::new(BashStatus::Running),
stdout: Mutex::new(RingBuffer::with_cap(cap)),
stderr: Mutex::new(RingBuffer::with_cap(cap)),
pid: Mutex::new(pid),
cancel: tokio_util::sync::CancellationToken::new(),
});
// Registered before any task starts, so the id resolves as soon as it is returned.
registry.insert(job.clone());
let stdout_job = job.clone();
tokio::spawn(async move {
drain_to_ring(stdout, &stdout_job, true).await;
});
let stderr_job = job.clone();
tokio::spawn(async move {
drain_to_ring(stderr, &stderr_job, false).await;
});
// Watcher task: owns `child`. `exit == None` means the job was cancelled.
let watch_job = job.clone();
tokio::spawn(async move {
let exit = tokio::select! {
r = child.wait() => Some(r),
() = watch_job.cancel.cancelled() => None,
};
{
// A normal exit with a code gives `Exited(code)`. All other outcomes give
// `Killed`: no exit code (e.g. killed by a signal on unix), a wait error,
// or cancellation.
let mut status_lock = watch_job.status.lock().unwrap();
if let Some(Ok(s)) = exit {
if let Some(code) = s.code() {
*status_lock = BashStatus::Exited(code);
} else {
*status_lock = BashStatus::Killed;
}
} else {
*status_lock = BashStatus::Killed;
}
}
// The child was not observed to exit: request a kill of the direct child
// before dropping the handle (the command is also `kill_on_drop(true)`).
if !matches!(exit, Some(Ok(_))) {
let _ = child.start_kill();
drop(child);
}
// `kill_job_now` only signals when a pid is present, so a finished job is
// never signalled again.
*watch_job.pid.lock().unwrap() = None;
});
Ok(id)
}
/// Copies everything read from `reader` into one of `job`'s ring buffers:
/// stdout when `to_stdout` is true, otherwise stderr.
///
/// Stops at EOF or at the first read error; both simply end the copy.
async fn drain_to_ring<R>(mut reader: R, job: &BashJob, to_stdout: bool)
where
    R: tokio::io::AsyncRead + Unpin,
{
    use tokio::io::AsyncReadExt;

    let ring = if to_stdout { &job.stdout } else { &job.stderr };
    let mut chunk = [0u8; 8192];
    while let Ok(n) = reader.read(&mut chunk).await {
        if n == 0 {
            break;
        }
        // The guard is a temporary and is released before the next `.await`.
        ring.lock().unwrap().push(&chunk[..n]);
    }
}
/// `BashOutput` tool: reports the buffered output and status of a background shell job.
#[derive(Debug)]
pub struct BashOutputTool {
/// Registry in which `shell_id` is looked up.
registry: Arc<BashBgRegistry>,
/// JSON schema returned by `input_schema`, built on first use.
schema: OnceLock<Value>,
}
impl BashOutputTool {
    /// Creates the tool backed by `registry`. The input schema is built lazily.
    #[must_use]
    pub fn new(registry: Arc<BashBgRegistry>) -> Self {
        let schema = OnceLock::new();
        Self { registry, schema }
    }

    /// Creates the tool backed by the process-wide `global_registry()`.
    #[must_use]
    pub fn with_global_registry() -> Self {
        let registry = global_registry();
        Self::new(registry)
    }
}
/// Deserialized arguments for the `BashOutput` tool.
#[derive(Debug, Deserialize)]
struct BashOutputInput {
    /// Id returned by `Bash(background=true)`.
    shell_id: String,
    /// Absolute stdout offset to resume from. Also used for stderr when
    /// `stderr_since_offset` is absent, which keeps the original single-offset
    /// behavior.
    #[serde(default)]
    since_offset: Option<u64>,
    /// Absolute stderr offset to resume from. stdout and stderr are counted
    /// independently, so reusing the stdout end offset for stderr can skip
    /// stderr bytes, or report nothing new, when the streams differ in length.
    #[serde(default)]
    stderr_since_offset: Option<u64>,
}

#[async_trait]
impl Tool for BashOutputTool {
    fn name(&self) -> &'static str {
        "BashOutput"
    }

    fn description(&self) -> &'static str {
        "Read the latest stdout/stderr from a background shell launched via Bash with background:true. Optional since_offset returns only the slice past that absolute byte offset (for incremental polling). stdout and stderr are counted separately; pass stderr_since_offset to poll stderr independently (it defaults to since_offset)."
    }

    fn input_schema(&self) -> &Value {
        self.schema.get_or_init(|| {
            json!({
                "type": "object",
                "properties": {
                    "shell_id": { "type": "string", "description": "Shell id returned by Bash(background=true)." },
                    "since_offset": { "type": "integer", "minimum": 0, "description": "Return only stdout bytes after this absolute byte offset (for incremental polling); also applied to stderr unless stderr_since_offset is given." },
                    "stderr_since_offset": { "type": "integer", "minimum": 0, "description": "Return only stderr bytes after this absolute byte offset. stderr offsets are independent of stdout offsets. Defaults to since_offset." }
                },
                "required": ["shell_id"]
            })
        })
    }

    /// Reports the status, age and retained stdout/stderr of the job. Each
    /// stream is prefixed with the absolute byte range it covers, so callers
    /// can resume from the range end.
    ///
    /// # Errors
    /// A `ToolError` if the input does not parse or the id is unknown.
    async fn invoke(&self, input: Value, _cx: ToolContext) -> Result<Vec<ContentBlock>, ToolError> {
        let parsed: BashOutputInput = crate::parse_input(input)?;
        let job = self.registry.get(&parsed.shell_id).ok_or_else(|| {
            ToolError::execution(std::io::Error::other(format!(
                "no background shell with id {}",
                parsed.shell_id
            )))
        })?;
        // stdout and stderr are separate offset spaces, so each gets its own cursor.
        let stdout_since = parsed.since_offset.unwrap_or(0);
        let stderr_since = parsed
            .stderr_since_offset
            .or(parsed.since_offset)
            .unwrap_or(0);
        let (stdout_text, stdout_start, stdout_end) =
            job.stdout.lock().unwrap().read_since(stdout_since);
        let (stderr_text, stderr_start, stderr_end) =
            job.stderr.lock().unwrap().read_since(stderr_since);
        let status = job.snapshot_status();
        let age = job.started_at.elapsed();
        let header = format!(
            "shell_id: {} status: {} age: {:.1}s\nstdout (bytes {}..{}):\n",
            job.id,
            status.as_str(),
            age.as_secs_f32(),
            stdout_start,
            stdout_end,
        );
        let mid = format!("\nstderr (bytes {stderr_start}..{stderr_end}):\n");
        let text = format!("{header}{stdout_text}{mid}{stderr_text}");
        Ok(vec![ContentBlock::Text(TextBlock {
            text,
            cache_control: None,
        })])
    }
}
/// `KillShell` tool: stops a background shell job, escalating to a forced kill
/// after a grace period.
#[derive(Debug)]
pub struct KillShellTool {
/// Registry in which `shell_id` is looked up.
registry: Arc<BashBgRegistry>,
/// How long to wait after the initial (non-forced) kill before forcing.
grace: Duration,
/// JSON schema returned by `input_schema`, built on first use.
schema: OnceLock<Value>,
}
impl KillShellTool {
    /// Creates the tool with the default [`KILL_GRACE`] period.
    #[must_use]
    pub fn new(registry: Arc<BashBgRegistry>) -> Self {
        Self::with_grace(registry, KILL_GRACE)
    }

    /// Creates the tool with a custom `grace` period between the initial
    /// kill request and the forced kill.
    #[must_use]
    pub fn with_grace(registry: Arc<BashBgRegistry>, grace: Duration) -> Self {
        let schema = OnceLock::new();
        Self {
            registry,
            grace,
            schema,
        }
    }

    /// Creates the tool backed by the process-wide `global_registry()`.
    #[must_use]
    pub fn with_global_registry() -> Self {
        let registry = global_registry();
        Self::new(registry)
    }
}
/// Deserialized arguments for the `KillShell` tool.
#[derive(Debug, Deserialize)]
struct KillShellInput {
/// Id returned by `Bash(background=true)`.
shell_id: String,
}
#[async_trait]
impl Tool for KillShellTool {
fn name(&self) -> &'static str {
"KillShell"
}
fn description(&self) -> &'static str {
"Terminate a background shell launched via Bash with background:true. Sends SIGTERM, waits ~5s, then SIGKILL. Reaps the child."
}
fn input_schema(&self) -> &Value {
self.schema.get_or_init(|| {
json!({
"type": "object",
"properties": {
"shell_id": { "type": "string", "description": "Shell id returned by Bash(background=true)." }
},
"required": ["shell_id"]
})
})
}
async fn invoke(&self, input: Value, _cx: ToolContext) -> Result<Vec<ContentBlock>, ToolError> {
let parsed: KillShellInput = crate::parse_input(input)?;
let job = self.registry.get(&parsed.shell_id).ok_or_else(|| {
ToolError::execution(std::io::Error::other(format!(
"no background shell with id {}",
parsed.shell_id
)))
})?;
if job.snapshot_status() != BashStatus::Running {
return Ok(vec![ContentBlock::Text(TextBlock {
text: format!(
"Shell {} is already in status {}; no action taken.",
job.id,
job.snapshot_status().as_str()
),
cache_control: None,
})]);
}
kill_job_now(&job, false);
let deadline = Instant::now() + self.grace;
while Instant::now() < deadline {
if job.snapshot_status() != BashStatus::Running {
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
if job.snapshot_status() == BashStatus::Running {
kill_job_now(&job, true);
tokio::time::sleep(Duration::from_millis(200)).await;
}
let status = job.snapshot_status();
let consumed_stdout = job.stdout.lock().unwrap().written();
let consumed_stderr = job.stderr.lock().unwrap().written();
Ok(vec![ContentBlock::Text(TextBlock {
text: format!(
"Killed shell {}; status={}; consumed_stdout={} bytes, consumed_stderr={} bytes",
job.id,
status.as_str(),
consumed_stdout,
consumed_stderr,
),
cache_control: None,
})])
}
}
/// Unit tests for the ring buffer, the shell builder and the background-job
/// tools. The async tests spawn real `/bin/sh` processes.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    fn ctx() -> ToolContext {
        ToolContext {
            tool_use_id: "t1".into(),
            cancel: CancellationToken::new(),
            hooks: None,
            turn_index: 0,
        }
    }

    #[test]
    fn ring_buffer_drops_oldest_at_cap() {
        let mut rb = RingBuffer::with_cap(16);
        rb.push(b"0123456789ABCDEF");
        assert_eq!(rb.len(), 16);
        assert_eq!(rb.written(), 16);
        rb.push(b"GHIJKLMN");
        assert_eq!(rb.len(), 16);
        assert_eq!(rb.dropped(), 8);
        let (text, end) = rb.snapshot();
        assert_eq!(text, "89ABCDEFGHIJKLMN");
        assert_eq!(end, 24);
    }

    #[test]
    fn ring_buffer_handles_chunk_bigger_than_cap() {
        let mut rb = RingBuffer::with_cap(4);
        rb.push(b"0123456789");
        let (text, end) = rb.snapshot();
        assert_eq!(text, "6789");
        assert_eq!(end, 10);
        assert_eq!(rb.dropped(), 6);
    }

    #[test]
    fn ring_buffer_read_since_returns_tail() {
        let mut rb = RingBuffer::with_cap(32);
        rb.push(b"hello world");
        let (text, start, end) = rb.read_since(6);
        assert_eq!(text, "world");
        assert_eq!(start, 6);
        assert_eq!(end, 11);
    }

    #[test]
    fn build_shell_without_sandbox_invokes_bin_sh_directly() {
        let cmd = build_shell("echo hi", &std::env::current_dir().unwrap(), None).unwrap();
        let std_cmd = cmd.as_std();
        assert_eq!(std_cmd.get_program(), "/bin/sh");
        let args: Vec<String> = std_cmd
            .get_args()
            .map(|a| a.to_string_lossy().into_owned())
            .collect();
        assert_eq!(args, ["-c", "echo hi"]);
    }

    #[test]
    fn build_shell_routes_through_the_sandbox_wrap() {
        let policy = caliban_sandbox::Policy {
            enabled: true,
            ..Default::default()
        };
        let shim = Arc::new(caliban_sandbox::SandboxedShim::new(policy).unwrap());
        let cmd = build_shell("echo hi", &std::env::current_dir().unwrap(), Some(&shim)).unwrap();
        let program = cmd.as_std().get_program().to_string_lossy().into_owned();
        if shim.is_active() {
            assert_ne!(
                program, "/bin/sh",
                "an active sandbox must wrap the shell program",
            );
        } else {
            assert_eq!(program, "/bin/sh");
        }
    }

    #[tokio::test]
    async fn spawn_background_returns_shell_id_immediately() {
        let reg = BashBgRegistry::new_for_test(1024 * 1024);
        let start = Instant::now();
        let id = spawn_background(
            &reg,
            "sleep 5".into(),
            &std::env::current_dir().unwrap(),
            None,
        )
        .unwrap();
        assert!(start.elapsed() < Duration::from_millis(500));
        assert_eq!(id.len(), 12);
        assert!(reg.get(&id).is_some());
        assert_eq!(reg.running_count(), 1);
        if let Some(job) = reg.get(&id) {
            kill_job_now(&job, true);
        }
    }

    #[tokio::test]
    async fn bash_output_returns_streaming_stdout() {
        let reg = BashBgRegistry::new_for_test(1024 * 1024);
        let id = spawn_background(
            &reg,
            "printf 'hello'; sleep 30".into(),
            &std::env::current_dir().unwrap(),
            None,
        )
        .unwrap();
        // Poll until the drain task has captured the output.
        for _ in 0..50 {
            tokio::time::sleep(Duration::from_millis(50)).await;
            let job = reg.get(&id).unwrap();
            let (text, _e) = job.stdout.lock().unwrap().snapshot();
            if text.contains("hello") {
                break;
            }
        }
        let tool = BashOutputTool::new(reg.clone());
        let out = tool.invoke(json!({"shell_id": id}), ctx()).await.unwrap();
        let ContentBlock::Text(t) = &out[0] else {
            panic!("expected Text")
        };
        assert!(t.text.contains("hello"), "out: {}", t.text);
        assert!(t.text.contains("status: running"), "out: {}", t.text);
        if let Some(job) = reg.get(&id) {
            kill_job_now(&job, true);
        }
    }

    #[tokio::test]
    async fn bash_output_supports_since_offset() {
        let reg = BashBgRegistry::new_for_test(1024 * 1024);
        let id = spawn_background(
            &reg,
            "printf 'aaaaa'; sleep 30".into(),
            &std::env::current_dir().unwrap(),
            None,
        )
        .unwrap();
        for _ in 0..50 {
            tokio::time::sleep(Duration::from_millis(50)).await;
            let job = reg.get(&id).unwrap();
            if job.stdout.lock().unwrap().written() >= 5 {
                break;
            }
        }
        let tool = BashOutputTool::new(reg.clone());
        let out = tool
            .invoke(json!({"shell_id": id, "since_offset": 3}), ctx())
            .await
            .unwrap();
        let ContentBlock::Text(t) = &out[0] else {
            panic!("expected Text")
        };
        assert!(t.text.contains("bytes 3..5"), "out: {}", t.text);
        if let Some(job) = reg.get(&id) {
            kill_job_now(&job, true);
        }
    }

    #[tokio::test]
    async fn kill_shell_terminates_running_job() {
        let reg = BashBgRegistry::new_for_test(1024 * 1024);
        let id = spawn_background(
            &reg,
            "sleep 60".into(),
            &std::env::current_dir().unwrap(),
            None,
        )
        .unwrap();
        assert_eq!(reg.running_count(), 1);
        let tool = KillShellTool::with_grace(reg.clone(), Duration::from_millis(500));
        let out = tool
            .invoke(json!({"shell_id": id.clone()}), ctx())
            .await
            .unwrap();
        let ContentBlock::Text(t) = &out[0] else {
            panic!("expected Text")
        };
        assert!(t.text.contains("Killed shell"), "out: {}", t.text);
        for _ in 0..20 {
            if reg.running_count() == 0 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        assert_eq!(reg.running_count(), 0);
    }

    #[tokio::test]
    async fn kill_all_terminates_every_running_job() {
        let reg = BashBgRegistry::new_for_test(1024 * 1024);
        let ids: Vec<String> = (0..3)
            .map(|_| {
                spawn_background(
                    &reg,
                    "sleep 60".into(),
                    &std::env::current_dir().unwrap(),
                    None,
                )
                .unwrap()
            })
            .collect();
        assert_eq!(reg.running_count(), 3);
        reg.kill_all();
        for _ in 0..40 {
            if reg.running_count() == 0 {
                break;
            }
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
        assert_eq!(reg.running_count(), 0);
        for id in ids {
            let job = reg.get(&id).unwrap();
            assert_ne!(job.snapshot_status(), BashStatus::Running);
        }
    }

    #[tokio::test]
    async fn bash_output_unknown_id_returns_error() {
        let reg = BashBgRegistry::new_for_test(1024);
        let tool = BashOutputTool::new(reg);
        let err = tool
            .invoke(json!({"shell_id": "doesnotexist"}), ctx())
            .await
            .unwrap_err();
        assert!(matches!(err, ToolError::Execution(_)));
        let msg = format!("{err}");
        assert!(msg.contains("no background shell"), "msg: {msg}");
    }

    #[test]
    fn new_shell_id_is_12_chars() {
        let id = new_shell_id();
        assert_eq!(id.len(), 12);
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }
}