#[allow(unused_imports)]
use crate::sync_util::LockExt;
use indexmap::IndexMap;
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use serde::Deserialize;
use std::sync::{Arc, Mutex};
use tokio::task::JoinHandle;
use crate::agent::tools::ToolError;
/// Lazily-initialised, process-wide shell registry backing [`global`].
static GLOBAL: std::sync::LazyLock<BackgroundShellStore> =
    std::sync::LazyLock::new(BackgroundShellStore::default);

/// Returns a handle to the process-wide store.
///
/// The store's state sits behind an `Arc`, so every handle returned here
/// observes and mutates the same registry.
pub fn global() -> BackgroundShellStore {
    let shared: &BackgroundShellStore = &GLOBAL;
    shared.clone()
}
/// Soft cap (1 MiB) on output buffered per shell between drains; the
/// truncation notice is appended beyond it once the cap is hit.
const MAX_UNREAD_BYTES: usize = 1 << 20;
/// Maximum number of entries (running or finished) the store retains.
const STORE_CAPACITY: usize = 32;
/// Maximum number of shells allowed to be running at the same time.
const MAX_CONCURRENT_SHELLS: usize = 8;
/// Lifecycle state of a background shell.
#[derive(Debug, Clone, PartialEq)]
pub enum ShellStatus {
    /// Still running; every newly registered shell starts here.
    Running,
    /// Finished with the carried code (presumably the process exit code).
    Exited(i32),
    /// Stopped on request.
    Killed,
    /// Could not run or complete; carries the error text.
    Failed(String),
}

impl ShellStatus {
    /// `true` only for [`ShellStatus::Running`].
    pub fn is_running(&self) -> bool {
        *self == Self::Running
    }

    /// Short human-readable status, e.g. `running`, `exited(0)`,
    /// `killed`, or `failed: <reason>`.
    pub fn label(&self) -> String {
        match self {
            Self::Exited(code) => format!("exited({code})"),
            Self::Failed(reason) => format!("failed: {reason}"),
            Self::Running => String::from("running"),
            Self::Killed => String::from("killed"),
        }
    }
}
/// Registry record for one background shell.
struct ShellEntry {
    /// Command line the shell was registered with (reported by `list`).
    command: String,
    /// Lifecycle state; starts as `Running` and is changed at most once,
    /// to a terminal value.
    status: ShellStatus,
    /// Output appended since the last drain, bounded by the unread cap.
    unread: String,
    /// Set once `unread` overflowed and the truncation notice was written;
    /// cleared when the buffer is drained.
    truncated: bool,
    /// Task attached via `attach_handle`, aborted on kill or eviction.
    handle: Option<JoinHandle<()>>,
}
/// Public, read-only snapshot of one registry entry (see `list`).
#[derive(Clone, Debug, PartialEq)]
pub struct ShellInfo {
    /// Identifier the shell was registered under.
    pub id: String,
    /// Command line the shell was registered with.
    pub command: String,
    /// Status at the moment the snapshot was taken.
    pub status: ShellStatus,
}
/// Shared registry of background shells, keyed by shell id.
///
/// Cloning is cheap: all clones share one `Arc`-held map, whose iteration
/// order is insertion order (an `IndexMap`).
#[derive(Clone, Debug, Default)]
pub struct BackgroundShellStore {
    inner: Arc<Mutex<IndexMap<String, ShellEntry>>>,
}
/// Compact debug view: reports the size of the unread buffer and whether a
/// task handle is attached, rather than dumping their contents.
impl std::fmt::Debug for ShellEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            command,
            unread,
            truncated,
            status,
            handle,
        } = self;
        let mut out = f.debug_struct("ShellEntry");
        out.field("command", command);
        out.field("unread_len", &unread.len());
        out.field("truncated", truncated);
        out.field("status", status);
        out.field("has_handle", &handle.is_some());
        out.finish()
    }
}
impl BackgroundShellStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Maximum number of shells that may be running at the same time; beyond
    /// it [`try_register`](Self::try_register) refuses new shells.
    pub fn max_concurrent() -> usize {
        MAX_CONCURRENT_SHELLS
    }

    /// Locks the registry through the crate's `LockExt::lock_ignore_poison`
    /// helper (by its name, a poisoned mutex is entered rather than treated
    /// as an error).
    fn lock(&self) -> std::sync::MutexGuard<'_, IndexMap<String, ShellEntry>> {
        self.inner.lock_ignore_poison()
    }

    /// Number of entries in `map` whose status is still `Running`.
    fn count_running(map: &IndexMap<String, ShellEntry>) -> usize {
        map.values().filter(|e| e.status.is_running()).count()
    }

    /// Test-only registration that bypasses the concurrency cap.
    #[cfg(test)]
    pub fn register(&self, id: String, command: String) {
        let mut map = self.lock();
        Self::insert_locked(&mut map, id, command);
    }

    /// Registers `id` as a running shell for `command`.
    ///
    /// Returns `false` and registers nothing when [`MAX_CONCURRENT_SHELLS`]
    /// shells are already running. The check and the insert happen under a
    /// single lock acquisition, so concurrent callers cannot overshoot the cap.
    pub fn try_register(&self, id: String, command: String) -> bool {
        let mut map = self.lock();
        if Self::count_running(&map) >= MAX_CONCURRENT_SHELLS {
            return false;
        }
        Self::insert_locked(&mut map, id, command);
        true
    }

    /// Inserts a fresh `Running` entry for `id`; the caller holds the lock.
    ///
    /// When `id` is new and the store already holds [`STORE_CAPACITY`]
    /// entries, one entry is evicted first: the earliest-inserted finished
    /// shell if there is one, otherwise the earliest-inserted entry overall.
    /// The task of any evicted *or replaced* entry is aborted, never merely
    /// detached.
    fn insert_locked(map: &mut IndexMap<String, ShellEntry>, id: String, command: String) {
        if !map.contains_key(&id) && map.len() >= STORE_CAPACITY {
            let victim = map
                .iter()
                .find(|(_, e)| !e.status.is_running())
                .map(|(key, _)| key.clone())
                .or_else(|| map.keys().next().cloned());
            if let Some(victim) = victim
                && let Some(evicted) = map.shift_remove(&victim)
            {
                Self::abort_task(evicted);
            }
        }
        let fresh = ShellEntry {
            command,
            unread: String::new(),
            truncated: false,
            status: ShellStatus::Running,
            handle: None,
        };
        // Re-registering an existing id replaces its entry in place. Dropping
        // the old `JoinHandle` would only detach its task. Because
        // `append`/`finish` are keyed by id, a detached task reporting under
        // this id would then write into the new entry, so abort it instead.
        if let Some(replaced) = map.insert(id, fresh) {
            Self::abort_task(replaced);
        }
    }

    /// Aborts the task owned by a removed `entry`, if one was attached.
    fn abort_task(entry: ShellEntry) {
        if let Some(handle) = entry.handle {
            handle.abort();
        }
    }

    /// Hands the task for `id` to the store so that `kill`, eviction, and
    /// `kill_all` can abort it.
    ///
    /// If `id` is unknown (e.g. already evicted) or no longer running, the
    /// task is aborted immediately instead of being left detached.
    // NOTE(review): attaching a second handle to the same running id drops
    // (detaches) the first one; confirm callers attach exactly one task.
    pub fn attach_handle(&self, id: &str, handle: JoinHandle<()>) {
        let mut map = self.lock();
        match map.get_mut(id) {
            Some(e) if e.status.is_running() => e.handle = Some(handle),
            _ => handle.abort(),
        }
    }

    /// Appends a chunk of output to `id`'s unread buffer. Unknown ids are
    /// ignored.
    ///
    /// The buffer is capped at [`MAX_UNREAD_BYTES`]. The chunk that would
    /// overflow it is cut at a UTF-8 character boundary and a one-time
    /// truncation notice is appended. Further output is then dropped until
    /// the next [`read_new`](Self::read_new).
    pub fn append(&self, id: &str, chunk: &str) {
        let mut map = self.lock();
        let Some(e) = map.get_mut(id) else {
            return;
        };
        if e.unread.len() + chunk.len() <= MAX_UNREAD_BYTES {
            e.unread.push_str(chunk);
        } else if !e.truncated {
            e.truncated = true;
            let room = MAX_UNREAD_BYTES.saturating_sub(e.unread.len());
            let mut take = room.min(chunk.len());
            // Back off to a char boundary so the slice below cannot panic.
            while take > 0 && !chunk.is_char_boundary(take) {
                take -= 1;
            }
            e.unread.push_str(&chunk[..take]);
            e.unread.push_str(
                "\n…[background shell output exceeded the unread-buffer cap; call bash_output more often to drain it]",
            );
        }
        // Otherwise the notice is already in the buffer: drop the chunk.
    }

    /// Records the terminal `status` of `id` and releases its task handle.
    ///
    /// Only a running entry is updated, so the first terminal status wins
    /// (e.g. a `kill` is not overwritten by a later exit report).
    pub fn finish(&self, id: &str, status: ShellStatus) {
        let mut map = self.lock();
        if let Some(e) = map.get_mut(id)
            && e.status.is_running()
        {
            e.status = status;
            // The handle is dropped, not aborted. Presumably `finish` is
            // called by the task itself as it completes; confirm at call site.
            e.handle = None;
        }
    }

    /// Takes all output buffered for `id` since the previous call, along with
    /// its current status. Returns `None` if `id` is unknown or was evicted.
    ///
    /// Draining also re-arms the truncation notice for the next overflow.
    pub fn read_new(&self, id: &str) -> Option<(String, ShellStatus)> {
        let mut map = self.lock();
        let e = map.get_mut(id)?;
        let out = std::mem::take(&mut e.unread);
        e.truncated = false;
        Some((out, e.status.clone()))
    }

    /// Aborts `id`'s task (if attached) and marks it `Killed`.
    ///
    /// Returns `false` when `id` is unknown or no longer running.
    pub fn kill(&self, id: &str) -> bool {
        let mut map = self.lock();
        let Some(e) = map.get_mut(id) else {
            return false;
        };
        if !e.status.is_running() {
            return false;
        }
        if let Some(h) = e.handle.take() {
            h.abort();
        }
        e.status = ShellStatus::Killed;
        true
    }

    /// Number of shells currently in the `Running` state.
    pub fn running_count(&self) -> usize {
        Self::count_running(&self.lock())
    }

    /// Snapshot of every retained shell, in insertion order.
    pub fn list(&self) -> Vec<ShellInfo> {
        self.lock()
            .iter()
            .map(|(id, e)| ShellInfo {
                id: id.clone(),
                command: e.command.clone(),
                status: e.status.clone(),
            })
            .collect()
    }

    /// Aborts every still-running shell and marks it `Killed`. Finished
    /// entries are left untouched.
    pub fn kill_all(&self) {
        let mut map = self.lock();
        for e in map.values_mut() {
            if e.status.is_running() {
                if let Some(h) = e.handle.take() {
                    h.abort();
                }
                e.status = ShellStatus::Killed;
            }
        }
    }
}
/// Arguments accepted by the `bash_output` tool.
#[derive(Deserialize)]
pub struct BashOutputArgs {
    /// Id of the background shell whose new output should be returned.
    pub id: String,
}

/// Tool that drains newly produced output from a background shell.
pub struct BashOutputTool {
    store: BackgroundShellStore,
}

impl BashOutputTool {
    /// Creates the tool, reading from `store`.
    pub fn new(store: BackgroundShellStore) -> Self {
        BashOutputTool { store }
    }
}
impl Tool for BashOutputTool {
    const NAME: &'static str = "bash_output";
    type Error = ToolError;
    type Args = BashOutputArgs;
    type Output = String;

    /// Describes the tool to the model: one required `id` string.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        let parameters = serde_json::json!({
            "type": "object",
            "properties": {
                "id": { "type": "string", "description": "The background shell id returned by bash(background=true)." }
            },
            "required": ["id"]
        });
        ToolDefinition {
            name: Self::NAME.to_string(),
            description: "Read new output from a background shell (one started with bash(background=true)). Returns the output produced since your last call plus the shell's status (running / exited(code) / killed / failed). Poll this to follow a long-running command; call kill_shell to stop it.".to_string(),
            parameters,
        }
    }

    /// Drains the shell's unread output and prefixes it with a status
    /// header. An unknown (or evicted) id is an error.
    async fn call(&self, args: BashOutputArgs) -> Result<String, ToolError> {
        let Some((out, status)) = self.store.read_new(&args.id) else {
            return Err(ToolError::Msg(format!(
                "no background shell with id {:?} (it may have been evicted)",
                args.id
            )));
        };
        let body = if out.is_empty() {
            String::from("(no new output)")
        } else {
            out
        };
        Ok(format!("[shell {} — {}]\n{}", args.id, status.label(), body))
    }
}
/// Arguments accepted by the `kill_shell` tool.
#[derive(Deserialize)]
pub struct KillShellArgs {
    /// Id of the background shell to stop.
    pub id: String,
}

/// Tool that stops a running background shell.
pub struct KillShellTool {
    store: BackgroundShellStore,
}

impl KillShellTool {
    /// Creates the tool, acting on `store`.
    pub fn new(store: BackgroundShellStore) -> Self {
        KillShellTool { store }
    }
}
impl Tool for KillShellTool {
    const NAME: &'static str = "kill_shell";
    type Error = ToolError;
    type Args = KillShellArgs;
    type Output = String;

    /// Describes the tool to the model: one required `id` string.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        let parameters = serde_json::json!({
            "type": "object",
            "properties": {
                "id": { "type": "string", "description": "The background shell id to kill." }
            },
            "required": ["id"]
        });
        // NOTE(review): the store's `kill` only aborts the attached task.
        // "Kills the whole process group" holds only if that task tears down
        // the process group when aborted; confirm where it is spawned.
        ToolDefinition {
            name: Self::NAME.to_string(),
            description: "Stop a running background shell (one started with bash(background=true)) by id. Kills the whole process group. No-op if it already exited.".to_string(),
            parameters,
        }
    }

    /// Stops the shell. An unknown or already-finished id is reported as a
    /// normal (non-error) result.
    async fn call(&self, args: KillShellArgs) -> Result<String, ToolError> {
        let killed = self.store.kill(&args.id);
        let message = if killed {
            format!("killed background shell {}", args.id)
        } else {
            format!(
                "no running background shell with id {:?} (already exited or unknown)",
                args.id
            )
        };
        Ok(message)
    }
}
// Unit tests for `BackgroundShellStore`. They drive the store directly
// (mostly via the test-only `register`) and never spawn real shells.
#[cfg(test)]
mod tests {
    use super::*;

    // Output accumulates across appends, each `read_new` drains it, and an
    // immediate second read yields an empty string.
    #[test]
    fn read_new_drains_unread_and_reports_status() {
        let s = BackgroundShellStore::new();
        s.register("a".into(), "sleep 1".into());
        s.append("a", "line1\n");
        s.append("a", "line2\n");
        let (out, st) = s.read_new("a").unwrap();
        assert_eq!(out, "line1\nline2\n");
        assert_eq!(st, ShellStatus::Running);
        s.append("a", "line3\n");
        let (out2, _) = s.read_new("a").unwrap();
        assert_eq!(out2, "line3\n");
        let (out3, _) = s.read_new("a").unwrap();
        assert_eq!(out3, "");
    }

    #[test]
    fn read_new_unknown_id_is_none() {
        let s = BackgroundShellStore::new();
        assert!(s.read_new("nope").is_none());
    }

    // Fills the running cap, checks that one more registration is refused and
    // leaves no entry behind, then frees a slot with `finish` and registers again.
    #[test]
    fn try_register_enforces_running_cap() {
        let s = BackgroundShellStore::new();
        for n in 0..MAX_CONCURRENT_SHELLS {
            assert!(
                s.try_register(format!("s{n}"), "x".into()),
                "registration {n} within cap must succeed"
            );
        }
        assert_eq!(s.running_count(), MAX_CONCURRENT_SHELLS);
        assert!(!s.try_register("over".into(), "x".into()));
        assert!(s.read_new("over").is_none());
        s.finish("s0", ShellStatus::Exited(0));
        assert!(s.try_register("after".into(), "x".into()));
    }

    // `finish` only updates a running entry, so the second call (Killed) is
    // ignored and the first terminal status (Exited(0)) sticks.
    #[test]
    fn finish_sets_terminal_and_first_terminal_wins() {
        let s = BackgroundShellStore::new();
        s.register("a".into(), "x".into());
        assert_eq!(s.running_count(), 1);
        s.finish("a", ShellStatus::Exited(0));
        let (_, st) = s.read_new("a").unwrap();
        assert_eq!(st, ShellStatus::Exited(0));
        assert_eq!(s.running_count(), 0);
        s.finish("a", ShellStatus::Killed);
        assert_eq!(s.read_new("a").unwrap().1, ShellStatus::Exited(0));
    }

    #[test]
    fn kill_marks_killed_only_when_running() {
        let s = BackgroundShellStore::new();
        s.register("a".into(), "x".into());
        assert!(s.kill("a"));
        assert_eq!(s.read_new("a").unwrap().1, ShellStatus::Killed);
        assert!(!s.kill("a"));
        assert!(!s.kill("unknown"));
    }

    // 20 chunks of 100_000 bytes (~2 MB) overflow the 1 MiB cap. The
    // +200 byte allowance covers the appended truncation notice.
    #[test]
    fn unread_buffer_is_capped() {
        let s = BackgroundShellStore::new();
        s.register("a".into(), "flood".into());
        let chunk = "x".repeat(100_000);
        for _ in 0..20 {
            s.append("a", &chunk);
        }
        let (out, _) = s.read_new("a").unwrap();
        assert!(out.len() <= MAX_UNREAD_BYTES + 200, "len was {}", out.len());
        assert!(out.contains("exceeded the unread-buffer cap"));
    }

    // Fills the store to capacity with running entries only, so eviction falls
    // back to the earliest-inserted entry (`s0`). `s0` owns a never-ending task
    // whose guard flips `dropped` when the task's future is dropped by abort.
    #[tokio::test]
    async fn eviction_aborts_an_evicted_running_shell() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        let store = BackgroundShellStore::new();
        let dropped = Arc::new(AtomicBool::new(false));
        for n in 0..STORE_CAPACITY {
            let id = format!("s{n}");
            store.register(id.clone(), "x".to_string());
            if n == 0 {
                let flag = dropped.clone();
                let h = tokio::spawn(async move {
                    // Drop guard: records that the task's state was destroyed.
                    struct G(Arc<AtomicBool>);
                    impl Drop for G {
                        fn drop(&mut self) {
                            self.0.store(true, Ordering::SeqCst);
                        }
                    }
                    let _g = G(flag);
                    std::future::pending::<()>().await;
                });
                store.attach_handle(&id, h);
            }
        }
        // Let the spawned task be polled so the guard exists before eviction.
        // An async block does nothing until polled, so aborting a never-polled
        // task would drop it without ever creating `G`.
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        store.register("overflow".to_string(), "y".to_string());
        // Give the runtime up to ~500 ms to process the abort.
        for _ in 0..100 {
            if dropped.load(Ordering::SeqCst) {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert!(
            dropped.load(Ordering::SeqCst),
            "evicted running shell's drain task must be aborted, not detached"
        );
        assert!(
            store.list().iter().all(|s| s.id != "s0"),
            "evicted shell must be gone from the registry"
        );
    }

    // s0 (running) + term (finished) + s1..s30 makes exactly STORE_CAPACITY
    // entries. The overflow registration must evict the finished `term` rather
    // than the older but still-running `s0`.
    #[tokio::test]
    async fn eviction_prefers_terminal_over_running() {
        let store = BackgroundShellStore::new();
        store.register("s0".to_string(), "x".to_string());
        store.register("term".to_string(), "x".to_string());
        store.finish("term", ShellStatus::Exited(0));
        for n in 1..(STORE_CAPACITY - 1) {
            store.register(format!("s{n}"), "x".to_string());
        }
        store.register("overflow".to_string(), "x".to_string());
        let ids: Vec<_> = store.list().into_iter().map(|s| s.id).collect();
        assert!(ids.contains(&"s0".to_string()), "running s0 must survive");
        assert!(!ids.contains(&"term".to_string()), "terminal entry evicted");
    }

    // `list` returns entries in insertion order, with their current status.
    #[test]
    fn list_reports_all_shells() {
        let s = BackgroundShellStore::new();
        s.register("a".into(), "cmd-a".into());
        s.register("b".into(), "cmd-b".into());
        s.finish("b", ShellStatus::Exited(1));
        let rows = s.list();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].command, "cmd-a");
        assert_eq!(rows[1].status, ShellStatus::Exited(1));
    }
}