use crate::agent::extension::{AgentTool, Cancel, Extension, ToolOutput};
use crate::agent::extension::{ToolRenderContext, ToolRenderer};
use crate::tui::Theme;
use crate::tui::visual_truncate::truncate_to_visual_lines;
use anyhow::Context;
use async_trait::async_trait;
use std::borrow::Cow;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use tokio::sync::{Mutex as TokioMutex, mpsc::UnboundedSender};
/// Extension that contributes the `bash` tool to the agent.
///
/// Holds the working directory that every spawned command runs in.
#[derive(Debug, Clone)]
pub struct BashExtension {
    /// Directory used as the current working directory for spawned commands.
    cwd: std::path::PathBuf,
}
impl BashExtension {
pub fn new(cwd: std::path::PathBuf) -> Self {
Self { cwd }
}
}
impl Extension for BashExtension {
    /// Name under which this extension is registered.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed("bash")
    }

    /// Returns a single `BashTool` bound to a copy of this extension's
    /// working directory.
    fn tools(&self) -> Vec<Box<dyn AgentTool>> {
        let tool = BashTool {
            cwd: self.cwd.clone(),
        };
        vec![Box::new(tool)]
    }
}
/// The tool that actually runs shell commands; created by `BashExtension::tools`.
struct BashTool {
// Working directory passed to every spawned command.
cwd: std::path::PathBuf,
}
/// Maximum number of trailing output lines returned to the caller.
const DEFAULT_MAX_LINES: usize = 2_000;
/// Maximum number of trailing output bytes returned to the caller (50 KiB).
const DEFAULT_MAX_BYTES: usize = 50 << 10;
/// Timeout, in seconds, applied when the caller does not supply one.
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
/// Asks every process in the process group led by `pid` to terminate.
///
/// Runs the external `kill -- -<pid>` utility (default signal). A `pid` of 0
/// is treated as "unknown" and ignored. All failures are ignored: this is a
/// best-effort cleanup and the group may already be gone.
#[cfg(unix)]
fn kill_process_group(pid: u32) {
    if pid == 0 {
        return;
    }
    // Detach the helper from our stdio so a "No such process" complaint from
    // `kill` cannot be written into the host terminal.
    let spawned = std::process::Command::new("kill")
        .arg("--")
        .arg(format!("-{}", pid))
        .stdin(std::process::Stdio::null())
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .spawn();
    if let Ok(mut helper) = spawned {
        // Reap the helper off-thread: dropping a `Child` without waiting on it
        // leaves a zombie, and callers must not block on it.
        std::thread::spawn(move || {
            let _ = helper.wait();
        });
    }
}

/// Process groups are not supported on this platform; this is a no-op.
#[cfg(not(unix))]
fn kill_process_group(pid: u32) {
    let _ = pid;
}
/// Spawns `sh -c <command>` in `cwd` with stdout and stderr piped.
///
/// On Unix the child is made the leader of a new process group (pgid equal
/// to its pid) so that `kill_process_group` can signal it together with
/// everything it starts.
///
/// Stdin is connected to the null device: the command gets EOF straight away
/// instead of reading from (or being stopped for touching) the host terminal.
///
/// # Errors
/// Returns the I/O error reported by the OS when the process cannot be
/// spawned.
fn spawn_bash_command(
    command: &str,
    cwd: &std::path::Path,
) -> std::io::Result<tokio::process::Child> {
    #[cfg(unix)]
    let mut shell = {
        use std::os::unix::process::CommandExt;
        let mut std_cmd = std::process::Command::new("sh");
        // Safe standard-library equivalent of calling `setpgid(0, 0)` from a
        // `pre_exec` hook; no `unsafe` required.
        std_cmd.process_group(0);
        tokio::process::Command::from(std_cmd)
    };
    #[cfg(not(unix))]
    let mut shell = tokio::process::Command::new("sh");

    shell
        .arg("-c")
        .arg(command)
        .current_dir(cwd)
        .stdin(std::process::Stdio::null())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
}
fn finish_bash_execution(
_command: &str,
combined: &str,
exit_code: i32,
cancelled: bool,
_started_at: Instant,
on_update: Option<UnboundedSender<ToolOutput>>,
) -> Result<ToolOutput, anyhow::Error> {
let trunc = truncate_tail(combined, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
let mut result_text = if trunc.content.is_empty() {
"(no output)".to_string()
} else {
trunc.content.clone()
};
if trunc.truncated {
let tmp_dir = std::env::temp_dir().join("rab-bash");
let _ = std::fs::create_dir_all(&tmp_dir);
let tmp_path = tmp_dir.join(format!("{}.txt", uuid::Uuid::new_v4()));
let saved = if std::fs::write(&tmp_path, combined).is_ok() {
Some(tmp_path)
} else {
None
};
let start_line = trunc.total_lines - trunc.output_lines + 1;
let end_line = trunc.total_lines;
let notice = if trunc.truncated_by == "lines" {
format!(
"\n\n[Showing lines {}-{} of {}. Full output: {}]",
start_line,
end_line,
trunc.total_lines,
saved
.as_ref()
.map(|p| p.display().to_string())
.unwrap_or_default()
)
} else {
format!(
"\n\n[Showing lines {}-{} of {} ({} limit). Full output: {}]",
start_line,
end_line,
trunc.total_lines,
format_size(DEFAULT_MAX_BYTES),
saved
.as_ref()
.map(|p| p.display().to_string())
.unwrap_or_default()
)
};
result_text.push_str(¬ice);
}
if let Some(ref tx) = on_update {
let _ = tx.send(ToolOutput::ok(result_text.clone()));
}
if cancelled {
let err_msg = if result_text.is_empty() || result_text == "(no output)" {
"Command aborted".to_string()
} else {
format!("{}\n\nCommand aborted", result_text)
};
return Err(anyhow::anyhow!("{}", err_msg));
}
if exit_code != 0 {
let err_msg = if result_text.is_empty() || result_text == "(no output)" {
format!("Command exited with code {}", exit_code)
} else {
format!("{}\n\nCommand exited with code {}", result_text, exit_code)
};
return Err(anyhow::anyhow!("{}", err_msg));
}
Ok(ToolOutput::ok(result_text))
}
/// Renders a byte count as a short label: whole bytes below 1 KiB (`"512B"`),
/// otherwise one decimal place in `KB` or `MB` (binary multiples).
fn format_size(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = KIB * KIB;
    match bytes {
        small if small < KIB => format!("{}B", small),
        medium if medium < MIB => format!("{:.1}KB", medium as f64 / 1024.0),
        large => format!("{:.1}MB", large as f64 / (1024.0 * 1024.0)),
    }
}
/// Outcome of [`truncate_tail`].
struct TailTruncation {
    /// The retained tail, lines joined with `\n` (the untouched input when
    /// nothing was dropped).
    content: String,
    /// Whether the input exceeded a limit and went through truncation.
    truncated: bool,
    /// Number of lines in the original input.
    total_lines: usize,
    /// Number of lines present in `content`.
    output_lines: usize,
    /// Size of `content` in bytes.
    // Only read by unit tests, hence the lint exemption.
    #[allow(dead_code)]
    output_bytes: usize,
    /// `"lines"` or `"bytes"` depending on which limit cut the output;
    /// empty when `truncated` is false.
    truncated_by: &'static str,
    /// True when the only retained line is itself cut (a single line longer
    /// than the byte limit).
    // Only read by unit tests, hence the lint exemption.
    #[allow(dead_code)]
    last_line_partial: bool,
}

/// Keeps the end of `content`, limited to `max_lines` lines and `max_bytes`
/// bytes (newlines between kept lines count towards the byte budget).
///
/// Input within both limits is returned unchanged. Otherwise whole lines are
/// collected from the end until a limit is reached. If even the last line
/// exceeds `max_bytes`, its trailing bytes are kept instead, starting at the
/// next UTF-8 character boundary so the slice is always valid (the result may
/// then be a few bytes shorter than `max_bytes`).
fn truncate_tail(content: &str, max_lines: usize, max_bytes: usize) -> TailTruncation {
    let total_bytes = content.len();
    let lines: Vec<&str> = content.lines().collect();
    let total_lines = lines.len();

    if total_lines <= max_lines && total_bytes <= max_bytes {
        return TailTruncation {
            content: content.to_string(),
            truncated: false,
            total_lines,
            output_lines: total_lines,
            output_bytes: total_bytes,
            truncated_by: "",
            last_line_partial: false,
        };
    }

    // Collected newest-first; reversed before joining.
    let mut kept: Vec<&str> = Vec::new();
    let mut byte_count: usize = 0;
    // Stays "lines" unless the byte budget is what stops the walk below.
    let mut truncated_by = "lines";
    let mut last_line_partial = false;

    for &line in lines.iter().rev().take(max_lines) {
        // Every kept line except the first costs one extra byte for its separator.
        let cost = if kept.is_empty() {
            line.len()
        } else {
            line.len() + 1
        };
        if byte_count + cost > max_bytes {
            truncated_by = "bytes";
            if kept.is_empty() {
                // The last line alone is too long: keep its tail. Slicing at a
                // raw byte offset would panic inside a multi-byte character,
                // so move forward to the next boundary.
                let mut start = line.len().saturating_sub(max_bytes);
                while !line.is_char_boundary(start) {
                    start += 1;
                }
                let tail = &line[start..];
                kept.push(tail);
                byte_count = tail.len();
                last_line_partial = true;
            }
            break;
        }
        kept.push(line);
        byte_count += cost;
    }

    kept.reverse();
    TailTruncation {
        content: kept.join("\n"),
        truncated: true,
        total_lines,
        output_lines: kept.len(),
        output_bytes: byte_count,
        truncated_by,
        last_line_partial,
    }
}
#[async_trait]
impl AgentTool for BashTool {
fn name(&self) -> &str {
"bash"
}
fn description(&self) -> &str {
"Execute a bash command in the current working directory. Returns stdout and stderr. \
Output is truncated to last 2000 lines or 50KB (whichever is hit first). If truncated, \
full output is saved to a temp file. Optionally provide a timeout in seconds."
}
fn parameters(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"required": ["command"],
"properties": {
"command": {
"type": "string",
"description": "Bash command to execute"
},
"timeout": {
"type": "number",
"description": "Timeout in seconds (optional, no default timeout)"
}
}
})
}
fn label(&self) -> &str {
"Execute bash commands (ls, grep, find, etc.)"
}
fn renderer(&self) -> Option<Box<dyn ToolRenderer>> {
Some(Box::new(BashRenderer))
}
async fn execute(
&self,
tool_call_id: String,
args: serde_json::Value,
cancel: Cancel,
on_update: Option<UnboundedSender<ToolOutput>>,
) -> anyhow::Result<ToolOutput> {
let _ = tool_call_id;
let command = args["command"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("Missing 'command' argument"))?;
let timeout = args["timeout"].as_u64().or(Some(DEFAULT_TIMEOUT_SECS));
let started_at = Instant::now();
cancel.check()?;
let mut child = spawn_bash_command(command, &self.cwd)
.with_context(|| format!("Failed to spawn command: {}", command))?;
let pid = child.id().unwrap_or(0);
let combined = Arc::new(TokioMutex::new(String::new()));
let combined_clone = combined.clone();
let stdout_pipe = child
.stdout
.take()
.ok_or_else(|| anyhow::anyhow!("Failed to capture stdout"))?;
let stderr_pipe = child
.stderr
.take()
.ok_or_else(|| anyhow::anyhow!("Failed to capture stderr"))?;
use tokio::io::AsyncReadExt;
let read_task = tokio::spawn(async move {
let mut stdout_buf = vec![0u8; 4096];
let mut stderr_buf = vec![0u8; 4096];
let mut stdout_reader = stdout_pipe;
let mut stderr_reader = stderr_pipe;
let mut stdout_done = false;
let mut stderr_done = false;
loop {
tokio::select! {
result = stdout_reader.read(&mut stdout_buf), if !stdout_done => {
match result {
Ok(0) => stdout_done = true,
Ok(n) => {
let mut out = combined_clone.lock().await;
out.push_str(&String::from_utf8_lossy(&stdout_buf[..n]));
}
Err(_) => stdout_done = true,
}
}
result = stderr_reader.read(&mut stderr_buf), if !stderr_done => {
match result {
Ok(0) => stderr_done = true,
Ok(n) => {
let mut out = combined_clone.lock().await;
out.push_str(&String::from_utf8_lossy(&stderr_buf[..n]));
}
Err(_) => stderr_done = true,
}
}
}
if stdout_done && stderr_done {
break;
}
}
});
let cancelled = Arc::new(AtomicBool::new(false));
let cancel_clone = cancelled.clone();
let _cancel_monitor: tokio::task::JoinHandle<()> = tokio::spawn(async move {
while !cancel.is_cancelled() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
cancel_clone.store(true, Ordering::SeqCst);
kill_process_group(pid);
});
let timeout_dur = timeout.map(std::time::Duration::from_secs);
loop {
if cancelled.load(Ordering::SeqCst) {
kill_process_group(pid);
read_task.abort();
return Err(anyhow::anyhow!("Command aborted"));
}
if let Some(dur) = timeout_dur
&& started_at.elapsed() > dur
{
kill_process_group(pid);
read_task.abort();
return Err(anyhow::anyhow!(
"Command timed out after {} seconds",
timeout.unwrap_or(0)
));
}
if let Some(ref tx) = on_update {
let out = combined.lock().await;
if !out.is_empty() {
let elapsed = started_at.elapsed();
let display = format!(
"{}\n\n[Elapsed {:.1}s]",
out.trim_end(),
elapsed.as_secs_f64()
);
let _ = tx.send(ToolOutput::ok(display));
}
}
match child.try_wait() {
Ok(Some(status)) => {
read_task.await.ok();
let combined_str = combined.lock().await.clone();
let exit_code = status.code().unwrap_or(-1);
return finish_bash_execution(
command,
&combined_str,
exit_code,
false,
started_at,
on_update,
);
}
Ok(None) => {
tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
}
Err(_) => {
read_task.await.ok();
let combined_str = combined.lock().await.clone();
let exit_code = -1;
return finish_bash_execution(
command,
&combined_str,
exit_code,
false,
started_at,
on_update,
);
}
}
}
}
}
/// Stateless renderer for `bash` tool calls and their results (see its `ToolRenderer` impl).
struct BashRenderer;
/// Recognises a handful of common read-only commands so the UI can show a
/// compact header instead of the raw command line.
///
/// Returns `(display name, optional detail)` for `ls`, `grep`/`rg` (both shown
/// as `"grep"`), `find`, `cat`, `head`, `tail` and `wc`; `None` for anything
/// else. Leading `NAME=value` environment assignments are skipped first.
///
/// Command lines containing shell control syntax (pipes, `;`, `&`,
/// redirections, backticks, `$(`, newlines) are never summarised. A header
/// such as `ls x` for `ls src; rm -rf x` would hide part of what is executed,
/// so those fall back to the raw command display. The check is purely
/// textual, so such characters inside quotes also disable the summary.
fn parse_command(cmd: &str) -> Option<(&'static str, Option<String>)> {
    const SHELL_CONTROL: &[char] = &['|', ';', '&', '<', '>', '`', '\n'];

    let trimmed = cmd.trim();
    if trimmed.contains(SHELL_CONTROL) || trimmed.contains("$(") {
        return None;
    }

    // Skip leading `NAME=value ` assignments (NAME is alphanumeric/underscore).
    let mut effective = trimmed;
    while let Some((var_name, value_and_rest)) = effective.split_once('=') {
        let is_assignment = !var_name.is_empty()
            && var_name.chars().all(|c| c.is_alphanumeric() || c == '_');
        if !is_assignment {
            break;
        }
        effective = match value_and_rest.split_once(' ') {
            Some((_value, rest)) => rest.trim_start(),
            // Only assignments, no command at all.
            None => "",
        };
    }

    // `args` is `None` when the program name stands alone.
    let (program, args) = match effective.split_once(' ') {
        Some((program, rest)) => (program, Some(rest.trim())),
        None => (effective, None),
    };

    let summary = match (program, args) {
        ("ls", _) => ("ls", extract_ls_path(effective)),
        ("grep" | "rg", Some(_)) => ("grep", extract_grep_info(effective)),
        ("find", Some(_)) => ("find", extract_find_info(effective)),
        ("cat", _) => ("cat", args.map(str::to_string)),
        ("wc", _) => ("wc", args.map(str::to_string)),
        ("head", Some(rest)) => ("head", (!rest.is_empty()).then(|| rest.to_string())),
        ("tail", Some(rest)) => ("tail", (!rest.is_empty()).then(|| rest.to_string())),
        _ => return None,
    };
    Some(summary)
}
/// Picks the directory an `ls` invocation refers to.
///
/// With no arguments the answer is `"."`. Otherwise it is the last
/// whitespace-separated argument that does not start with `-`, or `None`
/// when only options were given.
fn extract_ls_path(cmd: &str) -> Option<String> {
    let arguments = cmd.strip_prefix("ls").unwrap_or_default().trim();
    if arguments.is_empty() {
        return Some(String::from("."));
    }
    arguments
        .split_whitespace()
        .rev()
        .find(|word| !word.starts_with('-'))
        .map(str::to_owned)
}
/// Summarises a `grep`/`rg` invocation as `"<pattern>"` or
/// `"<pattern> in <file>, <file>…"`.
///
/// Arguments are split on whitespace (no shell-quote handling). The first
/// positional argument is the pattern and the rest are files. Options are
/// skipped; options that take a separate value (`-A`, `-B`, `-C`, `-m`,
/// `--max-count`, `-f`, `--file`) also skip that value, and `-e`/`--regexp`
/// supply the pattern. `-n` is a plain flag and consumes nothing.
///
/// Returns `None` when there is nothing to describe.
fn extract_grep_info(cmd: &str) -> Option<String> {
    // Options whose following word is their value rather than a positional.
    const VALUE_FLAGS: [&str; 7] = ["-A", "-B", "-C", "-m", "--max-count", "-f", "--file"];
    // Options whose following word is the search pattern itself.
    const PATTERN_FLAGS: [&str; 2] = ["-e", "--regexp"];

    let args = cmd
        .strip_prefix("grep")
        .or_else(|| cmd.strip_prefix("rg"))
        .unwrap_or("")
        .trim();
    if args.is_empty() {
        return None;
    }

    let mut pattern: Option<&str> = None;
    let mut files: Vec<&str> = Vec::new();
    let mut words = args.split_whitespace();
    while let Some(word) = words.next() {
        if PATTERN_FLAGS.contains(&word) {
            // Keep the first pattern seen; always consume the value.
            let value = words.next();
            pattern = pattern.or(value);
        } else if VALUE_FLAGS.contains(&word) {
            words.next();
        } else if word.starts_with('-') {
            continue;
        } else if pattern.is_none() {
            pattern = Some(word);
        } else {
            files.push(word);
        }
    }

    let mut desc = String::new();
    if let Some(found) = pattern {
        desc.push_str(found);
    }
    if !files.is_empty() {
        desc.push_str(" in ");
        desc.push_str(files.join(", ").as_str());
    }
    if desc.is_empty() { None } else { Some(desc) }
}
/// Summarises a `find` invocation as `"<path>"` or `"<path> (name=<glob>)"`.
///
/// `find` expects its starting point before any expression, so the path is
/// the first argument (after the optional symlink options `-H`/`-L`/`-P`)
/// unless that argument already starts an expression (`-…`, `(` or `!`), in
/// which case the path defaults to `"."`. This keeps option values such as
/// the `2` in `-maxdepth 2` from being reported as the path.
///
/// The name is the value of the last `-name` option that has one. Arguments
/// are split on whitespace (no shell-quote handling).
fn extract_find_info(cmd: &str) -> Option<String> {
    let args = cmd.strip_prefix("find").unwrap_or("").trim();
    if args.is_empty() {
        return Some(".".to_string());
    }

    let path = args
        .split_whitespace()
        .find(|word| !matches!(*word, "-H" | "-L" | "-P"))
        .filter(|word| !word.starts_with(|c: char| matches!(c, '-' | '(' | '!')))
        .unwrap_or(".");

    let mut name: Option<&str> = None;
    let mut words = args.split_whitespace();
    while let Some(word) = words.next() {
        if word == "-name" {
            // A trailing `-name` without a value must not erase an earlier one.
            if let Some(value) = words.next() {
                name = Some(value);
            }
        }
    }

    Some(match name {
        Some(glob) => format!("{} (name={})", path, glob),
        None => path.to_string(),
    })
}
/// Builds the themed header for a recognised command: the bold command name
/// in the `toolTitle` colour, followed by a space and the detail in the
/// `accent` colour when there is one. Returns `None` when `parse_command`
/// does not recognise `cmd`.
fn format_command_header(cmd: &str, theme: &dyn Theme) -> Option<String> {
    parse_command(cmd).map(|(name, desc)| {
        let mut header = theme.fg("toolTitle", &theme.bold(name));
        if let Some(detail) = desc {
            header.push(' ');
            header.push_str(&theme.fg("accent", &detail));
        }
        header
    })
}
impl ToolRenderer for BashRenderer {
    /// Renders the one-line header of a tool call.
    ///
    /// Recognised commands use the compact header from
    /// `format_command_header`; anything else is shown as `$ <command>`. An
    /// integer `timeout` argument is appended as ` (timeout Ns)`.
    fn render_call(
        &self,
        args: &serde_json::Value,
        _width: usize,
        theme: &dyn Theme,
        _ctx: &ToolRenderContext,
    ) -> Vec<String> {
        let cmd = args
            .get("command")
            .and_then(serde_json::Value::as_str)
            .unwrap_or("...");

        let timeout_suffix = match args.get("timeout").and_then(serde_json::Value::as_i64) {
            Some(secs) => theme.fg("muted", &format!(" (timeout {}s)", secs)),
            None => String::new(),
        };

        let header = format_command_header(cmd, theme)
            .unwrap_or_else(|| theme.fg("toolTitle", &theme.bold(&format!("$ {}", cmd))));

        vec![format!("{}{}", header, timeout_suffix)]
    }

    /// Renders the output of a tool call.
    ///
    /// The truncation footer is stripped from `content` first. Collapsed
    /// views show what `truncate_to_visual_lines` returns for a five-line
    /// preview plus a hint about hidden lines; expanded views show
    /// everything. Duration, cancellation / non-zero exit code and truncation
    /// notes from `ctx` follow. Empty content yields no lines at all.
    fn render_result(
        &self,
        content: &str,
        width: usize,
        theme: &dyn Theme,
        ctx: &ToolRenderContext,
    ) -> Vec<String> {
        let clean = strip_context_truncation_footer(content);
        let all_lines: Vec<&str> = clean.split('\n').collect();
        if matches!(all_lines.as_slice(), [] | [""]) {
            return Vec::new();
        }

        let preview_count = 5;
        let (preview_lines, hidden_line_count) = if ctx.expanded {
            (all_lines.clone(), 0)
        } else {
            truncate_to_visual_lines(&all_lines, width, preview_count)
        };

        let mut rendered: Vec<String> = Vec::new();

        // Nothing is ever hidden in expanded mode, so the count alone decides.
        if hidden_line_count > 0 {
            let hint_text = if ctx.expand_key.is_empty() {
                format!("... {} earlier lines", hidden_line_count)
            } else {
                format!(
                    "... ({} earlier lines, {} to expand)",
                    hidden_line_count, ctx.expand_key
                )
            };
            rendered.push(theme.fg("muted", &hint_text));
        }

        let fg_key = if ctx.is_error { "error" } else { "toolOutput" };
        rendered.extend(preview_lines.iter().map(|line| {
            // Blank lines stay unstyled.
            if line.is_empty() {
                String::new()
            } else {
                theme.fg(fg_key, line)
            }
        }));

        if let Some(secs) = ctx.duration_secs {
            let finished = ctx.cancelled || ctx.exit_code.is_some();
            let label = if finished { "Took" } else { "Elapsed" };
            rendered.push(theme.fg("muted", &format!("{} {:.1}s", label, secs)));
        }

        if ctx.cancelled {
            rendered.push(theme.fg("warning", "(cancelled)"));
        } else if let Some(code) = ctx.exit_code.filter(|code| *code != 0) {
            rendered.push(theme.fg("warning", &format!("(exit {})", code)));
        }

        if ctx.was_truncated {
            let note = match &ctx.full_output_path {
                Some(path) => format!("Output truncated. Full output: {}", path),
                None => "Output truncated.".to_string(),
            };
            rendered.push(theme.fg("warning", &note));
        }

        rendered
    }
}
/// Removes a trailing `[Showing lines … Full output: …]` footer from tool
/// output, together with the blank line directly above it.
///
/// Output with fewer than three lines, or whose last line is not such a
/// footer, is returned unchanged.
fn strip_context_truncation_footer(output: &str) -> String {
    let mut lines: Vec<&str> = output.lines().collect();
    if lines.len() < 3 {
        return output.to_owned();
    }

    let footer = lines[lines.len() - 1].trim();
    let mentions_range = footer.contains("Showing lines") || footer.contains("Showing last");
    let is_footer = footer.starts_with('[') && mentions_range && footer.contains("Full output:");
    if !is_footer {
        return output.to_owned();
    }

    // Drop the footer, then the separator line above it if it is empty.
    lines.pop();
    if lines.last().is_some_and(|line| line.is_empty()) {
        lines.pop();
    }
    lines.join("\n")
}
// Tests for command execution (`BashTool::execute`) and for `truncate_tail`.
// The execution tests spawn real `sh` processes, so they need a Unix-like
// environment providing `sh`, `sleep`, `seq`, `cat` and `echo`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a tool rooted in the system temp directory.
fn make_tool() -> BashTool {
BashTool {
cwd: std::env::temp_dir(),
}
}
// stdout of a successful command ends up in the result.
#[tokio::test]
async fn runs_simple_command() {
let tool = make_tool();
let output = tool
.execute(
"id".into(),
serde_json::json!({"command": "echo hello"}),
Cancel::new(),
None,
)
.await
.unwrap();
assert!(output.content.contains("hello"));
}
// stderr is captured as well.
#[tokio::test]
async fn captures_stderr() {
let tool = make_tool();
let output = tool
.execute(
"id".into(),
serde_json::json!({"command": "echo err >&2"}),
Cancel::new(),
None,
)
.await
.unwrap();
assert!(output.content.contains("err"));
}
// A handle that is already cancelled makes `execute` fail.
#[tokio::test]
async fn cancel_aborts() {
let tool = make_tool();
let cancel = Cancel::new();
cancel.cancel();
let result = tool
.execute(
"id".into(),
serde_json::json!({"command": "sleep 10"}),
cancel,
None,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("cancelled") || err.contains("aborted"),
"expected cancellation error, got: {}",
err
);
}
// A command running longer than its timeout is reported as timed out.
#[tokio::test]
async fn timeout_works() {
let tool = make_tool();
let result = tool
.execute(
"id".into(),
serde_json::json!({"command": "sleep 10", "timeout": 1}),
Cancel::new(),
None,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("timed out"));
}
// Input within both limits is returned unchanged, trailing newline included.
#[test]
fn test_truncate_tail_no_truncation() {
let result = truncate_tail("hello\nworld\n", 2000, 50000);
assert!(!result.truncated);
assert_eq!(result.content, "hello\nworld\n");
}
// 5000 lines with a 2000-line limit keep lines 3001..=5000.
#[test]
fn test_truncate_tail_by_lines() {
let content: String = (1..=5000).map(|i| format!("line {}\n", i)).collect();
let result = truncate_tail(&content, 2000, 50000);
assert!(result.truncated);
assert!(result.content.starts_with("line 3001"));
assert_eq!(result.content.lines().count(), 2000);
}
// About 100 KB in 100 lines: the byte limit applies before the line limit.
#[test]
fn test_truncate_tail_by_bytes() {
let content: String = (1..=100)
.map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
.collect();
let result = truncate_tail(&content, 2000, 50000);
assert!(result.truncated);
assert!(result.content.len() <= 50000);
assert!(result.content.lines().count() < 100);
}
// The last line alone exceeds the byte limit, so only its tail is kept.
#[test]
fn test_truncate_tail_partial_last_line() {
let content = format!("short\n{}\n", "x".repeat(60000));
let result = truncate_tail(&content, 2000, 50000);
assert!(result.truncated);
assert!(!result.content.starts_with("short"));
assert!(result.content.len() <= 50000);
}
#[test]
fn test_truncate_tail_empty() {
let result = truncate_tail("", 2000, 50000);
assert!(!result.truncated);
assert_eq!(result.content, "");
}
// A non-zero exit status is surfaced as an error naming the code.
#[tokio::test]
async fn exit_code_nonzero() {
let tool = make_tool();
let result = tool
.execute(
"id".into(),
serde_json::json!({"command": "exit 42"}),
Cancel::new(),
None,
)
.await;
assert!(result.is_err(), "non-zero exit should return error");
let err = result.unwrap_err().to_string();
assert!(err.contains("exited with code 42"), "got: {}", err);
}
// The error for a failing command still carries the output it produced.
#[tokio::test]
async fn exit_code_with_output() {
let tool = make_tool();
let result = tool
.execute(
"id".into(),
serde_json::json!({"command": "echo before && exit 1"}),
Cancel::new(),
None,
)
.await;
assert!(result.is_err(), "non-zero exit should return error");
let err = result.unwrap_err().to_string();
assert!(err.contains("before"), "got: {}", err);
assert!(err.contains("exited with code 1"), "got: {}", err);
}
// A silent successful command yields the "(no output)" placeholder.
#[tokio::test]
async fn no_output() {
let tool = make_tool();
let output = tool
.execute(
"id".into(),
serde_json::json!({"command": "true"}),
Cancel::new(),
None,
)
.await
.unwrap();
assert!(
output.content.contains("(no output)"),
"got: {}",
output.content
);
}
// stdout and stderr are merged into one result.
#[tokio::test]
async fn combined_stdout_stderr() {
let tool = make_tool();
let output = tool
.execute(
"id".into(),
serde_json::json!({"command": "echo out; echo err >&2"}),
Cancel::new(),
None,
)
.await
.unwrap();
assert!(output.content.contains("out"), "got: {}", output.content);
assert!(output.content.contains("err"), "got: {}", output.content);
}
// The command runs in the tool's configured working directory.
// NOTE(review): the temporary directory created here is not removed afterwards.
#[tokio::test]
async fn runs_in_cwd() {
let tmp = std::env::temp_dir().join(format!("rab-bash-cwd-{}", uuid::Uuid::new_v4()));
std::fs::create_dir_all(&tmp).unwrap();
std::fs::write(tmp.join("marker.txt"), "hello").unwrap();
let tool = BashTool { cwd: tmp.clone() };
let output = tool
.execute(
"id".into(),
serde_json::json!({"command": "cat marker.txt"}),
Cancel::new(),
None,
)
.await
.unwrap();
assert!(output.content.contains("hello"), "got: {}", output.content);
}
// Omitting the required `command` argument is an error.
#[tokio::test]
async fn missing_command_errors() {
let tool = make_tool();
let result = tool
.execute("id".into(), serde_json::json!({}), Cancel::new(), None)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("command"), "got: {}", err);
}
// Only the timeout message is asserted here; the partial output ("start") is not checked.
#[tokio::test]
async fn timeout_with_partial_output() {
let tool = make_tool();
let result = tool
.execute(
"id".into(),
serde_json::json!({"command": "echo start && sleep 10 && echo end", "timeout": 1}),
Cancel::new(),
None,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("timed out"), "got: {}", err);
}
// Cancelling while the command is running aborts it.
#[tokio::test]
async fn cancel_during_long_command() {
let tool = make_tool();
let cancel = Cancel::new();
let cancel_clone = cancel.clone();
let handle = tokio::spawn(async move {
tool.execute(
"id".into(),
serde_json::json!({"command": "sleep 30"}),
cancel_clone,
None,
)
.await
});
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
cancel.cancel();
let result = handle.await.unwrap();
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("aborted") || err.contains("cancelled"),
"expected cancellation error, got: {}",
err
);
}
// Boundary cases around the line limit.
#[test]
fn test_truncate_tail_exact_line_fit() {
let lines: String = (1..=2000).map(|i| format!("line {}\n", i)).collect();
let result = truncate_tail(&lines, 2000, 50000);
assert!(
!result.truncated,
"should not truncate when exactly at line limit"
);
assert!(result.content.lines().count() == 2000);
}
#[test]
fn test_truncate_tail_one_over_line_limit() {
let lines: String = (1..=2001).map(|i| format!("line {}\n", i)).collect();
let result = truncate_tail(&lines, 2000, 50000);
assert!(result.truncated);
assert_eq!(result.content.lines().count(), 2000);
assert!(result.content.starts_with("line 2"));
}
// Boundary cases around the byte limit.
#[test]
fn test_truncate_tail_exact_byte_fit() {
let line = "a".repeat(50000);
let result = truncate_tail(&line, 2000, 50000);
assert!(!result.truncated);
}
#[test]
fn test_truncate_tail_one_byte_over() {
let line = "a".repeat(50001);
let result = truncate_tail(&line, 2000, 50000);
assert!(result.truncated);
assert!(result.content.len() <= 50000);
}
#[test]
fn test_truncate_tail_single_line_under_limit() {
let result = truncate_tail("hello world", 2000, 50000);
assert!(!result.truncated);
assert_eq!(result.content, "hello world");
}
// Untruncated input keeps its trailing newline (or the lack of one).
#[test]
fn test_truncate_tail_trailing_newline() {
let result = truncate_tail("a\nb\n", 2000, 50000);
assert!(!result.truncated);
assert_eq!(result.content, "a\nb\n");
}
#[test]
fn test_truncate_tail_no_trailing_newline() {
let result = truncate_tail("a\nb", 2000, 50000);
assert!(!result.truncated);
assert_eq!(result.content, "a\nb");
}
// A single over-long ASCII line is cut to exactly the byte limit.
#[test]
fn test_truncate_tail_single_line_exceeds_limit() {
let content = "x".repeat(60000);
let result = truncate_tail(&content, 2000, 50000);
assert!(result.truncated);
assert!(result.last_line_partial);
assert_eq!(result.content.len(), 50000);
assert!(result.content.ends_with("x".repeat(50000).as_str()));
}
// The reported byte count, separators included, stays within the limit.
#[test]
fn test_truncate_tail_byte_count_respects_newlines() {
let content: String = (1..=100)
.map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
.collect();
let result = truncate_tail(&content, 2000, 50000);
assert!(result.truncated);
assert!(
result.output_bytes <= 50000,
"output_bytes {} exceeds limit 50000",
result.output_bytes
);
}
// End-to-end: output over the line limit gets the truncation footer.
#[tokio::test]
async fn truncated_by_lines_shows_footer() {
let tool = make_tool();
let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
let output = tool
.execute(
"id".into(),
serde_json::json!({"command": cmd}),
Cancel::new(),
None,
)
.await
.unwrap();
assert!(
output.content.contains("Showing lines"),
"got: {}",
output.content
);
assert!(
output.content.contains("Full output:"),
"got: {}",
output.content
);
}
// Small output carries no truncation footer.
#[tokio::test]
async fn small_output_no_footer() {
let tool = make_tool();
let output = tool
.execute(
"id".into(),
serde_json::json!({"command": "echo hello"}),
Cancel::new(),
None,
)
.await
.unwrap();
assert!(
!output.content.contains("Output truncated"),
"got: {}",
output.content
);
assert!(
!output.content.contains("Full output:"),
"got: {}",
output.content
);
}
// The footer names a file under the `rab-bash` temp directory.
// The check assumes `/` as the path separator.
#[tokio::test]
async fn truncated_saves_temp_file() {
let tool = make_tool();
let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
let output = tool
.execute(
"id".into(),
serde_json::json!({"command": cmd}),
Cancel::new(),
None,
)
.await
.unwrap();
assert!(
output.content.contains("/rab-bash/"),
"expected temp file path, got: {}",
output.content
);
}
// Many short lines: the line limit applies and is reported as the cause.
#[test]
fn test_truncate_tail_many_short_lines() {
let content: String = (1..=10000).map(|i| format!("{}\n", i)).collect();
let result = truncate_tail(&content, 2000, 50000);
assert!(result.truncated);
assert_eq!(result.truncated_by, "lines");
assert_eq!(result.output_lines, 2000);
assert!(
result.content.starts_with("8001"),
"starts with: {:?}",
&result.content[..10]
);
}
// Both limits exceeded: the byte limit is hit first and is reported as the cause.
#[test]
fn test_truncate_tail_lines_and_bytes_both_exceeded() {
let content: String = (1..=5000)
.map(|i| format!("line {} {}\n", i, "x".repeat(100)))
.collect();
let result = truncate_tail(&content, 2000, 30000);
assert!(result.truncated);
assert_eq!(result.truncated_by, "bytes");
assert!(result.output_lines < 2000);
}
}
// Tests for `parse_command`, which maps a command line to the
// (display name, detail) pair used for the call header.
#[cfg(test)]
mod command_tests {
use super::*;
// `ls`: the detail is the last non-option argument.
#[test]
fn test_parse_ls() {
let result = parse_command("ls -la src/");
assert!(result.is_some());
let (name, desc) = result.unwrap();
assert_eq!(name, "ls");
assert_eq!(desc, Some("src/".to_string()));
}
// `ls` without arguments defaults to the current directory.
#[test]
fn test_parse_ls_default() {
let result = parse_command("ls");
assert!(result.is_some());
let (name, desc) = result.unwrap();
assert_eq!(name, "ls");
assert_eq!(desc, Some(".".to_string()));
}
// `grep`: the detail mentions both the pattern and the searched path.
#[test]
fn test_parse_grep() {
let result = parse_command("grep -r \"pattern\" src/");
assert!(result.is_some());
let (name, desc) = result.unwrap();
assert_eq!(name, "grep");
assert!(desc.is_some());
let desc = desc.unwrap();
assert!(desc.contains("pattern"));
assert!(desc.contains("src/"));
}
// `rg` is displayed under the name "grep".
#[test]
fn test_parse_rg() {
let result = parse_command("rg pattern src/");
assert!(result.is_some());
let (name, _) = result.unwrap();
assert_eq!(name, "grep");
}
// `find`: the detail mentions the start path and the `-name` value.
#[test]
fn test_parse_find() {
let result = parse_command("find . -name \"*.rs\"");
assert!(result.is_some());
let (name, desc) = result.unwrap();
assert_eq!(name, "find");
assert!(desc.is_some());
let desc = desc.unwrap();
assert!(desc.contains("."));
assert!(desc.contains("*.rs"));
}
#[test]
fn test_parse_cat() {
let result = parse_command("cat README.md");
assert!(result.is_some());
let (name, desc) = result.unwrap();
assert_eq!(name, "cat");
assert_eq!(desc, Some("README.md".to_string()));
}
// `head`, `tail` and `wc` keep their whole argument string, options included.
#[test]
fn test_parse_head() {
let result = parse_command("head -20 file.txt");
assert!(result.is_some());
let (name, desc) = result.unwrap();
assert_eq!(name, "head");
assert_eq!(desc, Some("-20 file.txt".to_string()));
}
#[test]
fn test_parse_tail() {
let result = parse_command("tail -f log.txt");
assert!(result.is_some());
let (name, desc) = result.unwrap();
assert_eq!(name, "tail");
assert_eq!(desc, Some("-f log.txt".to_string()));
}
#[test]
fn test_parse_wc() {
let result = parse_command("wc -l file.txt");
assert!(result.is_some());
let (name, desc) = result.unwrap();
assert_eq!(name, "wc");
assert_eq!(desc, Some("-l file.txt".to_string()));
}
// Commands outside the recognised set give `None`.
#[test]
fn test_parse_unknown() {
let result = parse_command("echo hello");
assert!(result.is_none());
}
// A leading `NAME=value` assignment is skipped before matching.
#[test]
fn test_parse_with_env() {
let result = parse_command("FOO=bar ls src/");
assert!(result.is_some());
let (name, desc) = result.unwrap();
assert_eq!(name, "ls");
assert_eq!(desc, Some("src/".to_string()));
}
}