1use crate::agent::extension::{AgentTool, Cancel, Extension, ToolOutput};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::tui::Theme;
4use crate::tui::visual_truncate::truncate_to_visual_lines;
5use anyhow::Context;
6use async_trait::async_trait;
7use std::borrow::Cow;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::time::Instant;
11use tokio::sync::{Mutex as TokioMutex, mpsc::UnboundedSender};
12
/// Extension that exposes the `bash` tool to the agent.
pub struct BashExtension {
    /// Working directory copied into every tool this extension creates.
    cwd: std::path::PathBuf,
}

impl BashExtension {
    /// Builds the extension, remembering `cwd` for the tools it hands out later.
    pub fn new(cwd: std::path::PathBuf) -> Self {
        BashExtension { cwd }
    }
}
22
23impl Extension for BashExtension {
24 fn name(&self) -> Cow<'static, str> {
25 "bash".into()
26 }
27
28 fn tools(&self) -> Vec<Box<dyn AgentTool>> {
29 vec![Box::new(BashTool {
30 cwd: self.cwd.clone(),
31 })]
32 }
33}
34
/// The `bash` agent tool: runs shell commands and returns their combined output.
struct BashTool {
    /// Directory in which every command is spawned.
    cwd: std::path::PathBuf,
}
38
39const DEFAULT_MAX_LINES: usize = 2000;
42const DEFAULT_MAX_BYTES: usize = 50 * 1024; const DEFAULT_TIMEOUT_SECS: u64 = 300; #[cfg(unix)]
49fn kill_process_group(pid: u32) {
50 if pid > 0 {
51 let _ = std::process::Command::new("kill")
52 .arg("--")
53 .arg(format!("-{}", pid))
54 .spawn();
55 }
56}
57
/// Fallback for non-Unix targets: intentionally does nothing with `_pid`.
#[cfg(not(unix))]
fn kill_process_group(_pid: u32) {}
62
63fn spawn_bash_command(
65 command: &str,
66 cwd: &std::path::Path,
67) -> std::io::Result<tokio::process::Child> {
68 #[cfg(unix)]
69 {
70 use std::os::unix::process::CommandExt;
71 let mut std_cmd = std::process::Command::new("sh");
72 std_cmd.arg("-c").arg(command).current_dir(cwd);
73 unsafe {
74 std_cmd.pre_exec(|| {
75 libc::setpgid(0, 0);
76 Ok(())
77 });
78 }
79 let mut tokio_cmd = tokio::process::Command::from(std_cmd);
80 tokio_cmd
81 .stdout(std::process::Stdio::piped())
82 .stderr(std::process::Stdio::piped())
83 .spawn()
84 }
85 #[cfg(not(unix))]
86 {
87 tokio::process::Command::new("sh")
88 .arg("-c")
89 .arg(command)
90 .current_dir(cwd)
91 .stdout(std::process::Stdio::piped())
92 .stderr(std::process::Stdio::piped())
93 .spawn()
94 }
95}
96
97fn finish_bash_execution(
106 _command: &str,
107 combined: &str,
108 exit_code: i32,
109 cancelled: bool,
110 _started_at: Instant,
111 on_update: Option<UnboundedSender<ToolOutput>>,
112) -> Result<ToolOutput, anyhow::Error> {
113 let trunc = truncate_tail(combined, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
115
116 let mut result_text = if trunc.content.is_empty() {
118 "(no output)".to_string()
119 } else {
120 trunc.content.clone()
121 };
122
123 if trunc.truncated {
125 let tmp_dir = std::env::temp_dir().join("rab-bash");
126 let _ = std::fs::create_dir_all(&tmp_dir);
127 let tmp_path = tmp_dir.join(format!("{}.txt", uuid::Uuid::new_v4()));
128 let saved = if std::fs::write(&tmp_path, combined).is_ok() {
129 Some(tmp_path)
130 } else {
131 None
132 };
133
134 let start_line = trunc.total_lines - trunc.output_lines + 1;
135 let end_line = trunc.total_lines;
136
137 let notice = if trunc.truncated_by == "lines" {
138 format!(
139 "\n\n[Showing lines {}-{} of {}. Full output: {}]",
140 start_line,
141 end_line,
142 trunc.total_lines,
143 saved
144 .as_ref()
145 .map(|p| p.display().to_string())
146 .unwrap_or_default()
147 )
148 } else {
149 format!(
150 "\n\n[Showing lines {}-{} of {} ({} limit). Full output: {}]",
151 start_line,
152 end_line,
153 trunc.total_lines,
154 format_size(DEFAULT_MAX_BYTES),
155 saved
156 .as_ref()
157 .map(|p| p.display().to_string())
158 .unwrap_or_default()
159 )
160 };
161 result_text.push_str(¬ice);
162 }
163
164 if let Some(ref tx) = on_update {
166 let _ = tx.send(ToolOutput::ok(result_text.clone()));
167 }
168
169 if cancelled {
171 let err_msg = if result_text.is_empty() || result_text == "(no output)" {
172 "Command aborted".to_string()
173 } else {
174 format!("{}\n\nCommand aborted", result_text)
175 };
176 return Err(anyhow::anyhow!("{}", err_msg));
177 }
178
179 if exit_code != 0 {
180 let err_msg = if result_text.is_empty() || result_text == "(no output)" {
181 format!("Command exited with code {}", exit_code)
182 } else {
183 format!("{}\n\nCommand exited with code {}", result_text, exit_code)
184 };
185 return Err(anyhow::anyhow!("{}", err_msg));
186 }
187
188 Ok(ToolOutput::ok(result_text))
189}
190
/// Renders a byte count as `NB`, `N.NKB` or `N.NMB` (1024-based units).
fn format_size(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = 1024 * 1024;

    match bytes {
        small if small < KIB => format!("{}B", small),
        medium if medium < MIB => format!("{:.1}KB", medium as f64 / 1024.0),
        large => format!("{:.1}MB", large as f64 / (1024.0 * 1024.0)),
    }
}
201
/// Outcome of `truncate_tail`: the retained tail of some text plus bookkeeping
/// about what was cut.
struct TailTruncation {
    /// The retained text (the whole input when nothing was cut).
    content: String,
    /// `true` when the input exceeded the line or byte limit.
    truncated: bool,
    /// Number of lines in the original input.
    #[allow(dead_code)]
    total_lines: usize,
    /// Number of lines present in `content`.
    #[allow(dead_code)]
    output_lines: usize,
    /// Bytes accounted for the retained lines, including joining newlines.
    #[allow(dead_code)]
    output_bytes: usize,
    /// Which limit caused the cut: `"lines"`, `"bytes"`, or `""` when not truncated.
    #[allow(dead_code)]
    truncated_by: &'static str,
    /// `true` when even the last line alone exceeded the byte limit and only
    /// its tail was kept.
    #[allow(dead_code)]
    last_line_partial: bool,
}
220
221fn truncate_tail(content: &str, max_lines: usize, max_bytes: usize) -> TailTruncation {
225 let total_bytes = content.len();
226 let lines: Vec<&str> = content.lines().collect();
227 let total_lines = lines.len();
228
229 if total_lines <= max_lines && total_bytes <= max_bytes {
231 return TailTruncation {
232 content: content.to_string(),
233 truncated: false,
234 total_lines,
235 output_lines: total_lines,
236 output_bytes: total_bytes,
237 truncated_by: "",
238 last_line_partial: false,
239 };
240 }
241
242 let mut output: Vec<&str> = Vec::new();
244 let mut byte_count: usize = 0;
245 let mut truncated_by = "lines";
246 let mut last_line_partial = false;
247
248 for line in lines.iter().rev().take(max_lines) {
249 let line_bytes = line.len();
250 let with_newline = if output.is_empty() {
251 line_bytes
252 } else {
253 line_bytes + 1 };
255
256 if byte_count + with_newline > max_bytes {
257 truncated_by = "bytes";
258 if output.is_empty() {
261 let end_start = line.len().saturating_sub(max_bytes);
262 let truncated_line = &line[end_start..];
263 output.push(truncated_line);
264 byte_count = truncated_line.len();
265 last_line_partial = true;
266 }
267 break;
268 }
269
270 output.push(line);
271 byte_count += with_newline;
272 }
273
274 if output.len() >= max_lines && byte_count <= max_bytes {
275 truncated_by = "lines";
276 }
277
278 output.reverse();
279 TailTruncation {
280 content: output.join("\n"),
281 truncated: true,
282 total_lines,
283 output_lines: output.len(),
284 output_bytes: byte_count,
285 truncated_by,
286 last_line_partial,
287 }
288}
289
290#[async_trait]
293impl AgentTool for BashTool {
294 fn name(&self) -> &str {
295 "bash"
296 }
297
298 fn description(&self) -> &str {
299 "Execute a bash command in the current working directory. Returns stdout and stderr. \
300 Output is truncated to last 2000 lines or 50KB (whichever is hit first). If truncated, \
301 full output is saved to a temp file. Optionally provide a timeout in seconds."
302 }
303
304 fn parameters(&self) -> serde_json::Value {
305 serde_json::json!({
306 "type": "object",
307 "required": ["command"],
308 "properties": {
309 "command": {
310 "type": "string",
311 "description": "Bash command to execute"
312 },
313 "timeout": {
314 "type": "number",
315 "description": "Timeout in seconds (optional, no default timeout)"
316 }
317 }
318 })
319 }
320
321 fn label(&self) -> &str {
322 "Execute bash commands (ls, grep, find, etc.)"
323 }
324
325 fn renderer(&self) -> Option<Box<dyn ToolRenderer>> {
326 Some(Box::new(BashRenderer))
327 }
328
329 async fn execute(
330 &self,
331 tool_call_id: String,
332 args: serde_json::Value,
333 cancel: Cancel,
334 on_update: Option<UnboundedSender<ToolOutput>>,
335 ) -> anyhow::Result<ToolOutput> {
336 let _ = tool_call_id;
337 let command = args["command"]
338 .as_str()
339 .ok_or_else(|| anyhow::anyhow!("Missing 'command' argument"))?;
340 let timeout = args["timeout"].as_u64().or(Some(DEFAULT_TIMEOUT_SECS));
341 let started_at = Instant::now();
342
343 cancel.check()?;
344
345 let mut child = spawn_bash_command(command, &self.cwd)
347 .with_context(|| format!("Failed to spawn command: {}", command))?;
348
349 let pid = child.id().unwrap_or(0);
350
351 let combined = Arc::new(TokioMutex::new(String::new()));
353 let combined_clone = combined.clone();
354
355 let stdout_pipe = child
357 .stdout
358 .take()
359 .ok_or_else(|| anyhow::anyhow!("Failed to capture stdout"))?;
360 let stderr_pipe = child
361 .stderr
362 .take()
363 .ok_or_else(|| anyhow::anyhow!("Failed to capture stderr"))?;
364
365 use tokio::io::AsyncReadExt;
366 let read_task = tokio::spawn(async move {
367 let mut stdout_buf = vec![0u8; 4096];
368 let mut stderr_buf = vec![0u8; 4096];
369 let mut stdout_reader = stdout_pipe;
370 let mut stderr_reader = stderr_pipe;
371 let mut stdout_done = false;
372 let mut stderr_done = false;
373 loop {
374 tokio::select! {
375 result = stdout_reader.read(&mut stdout_buf), if !stdout_done => {
376 match result {
377 Ok(0) => stdout_done = true,
378 Ok(n) => {
379 let mut out = combined_clone.lock().await;
380 out.push_str(&String::from_utf8_lossy(&stdout_buf[..n]));
381 }
382 Err(_) => stdout_done = true,
383 }
384 }
385 result = stderr_reader.read(&mut stderr_buf), if !stderr_done => {
386 match result {
387 Ok(0) => stderr_done = true,
388 Ok(n) => {
389 let mut out = combined_clone.lock().await;
390 out.push_str(&String::from_utf8_lossy(&stderr_buf[..n]));
391 }
392 Err(_) => stderr_done = true,
393 }
394 }
395 }
396 if stdout_done && stderr_done {
397 break;
398 }
399 }
400 });
401
402 let cancelled = Arc::new(AtomicBool::new(false));
404 let cancel_clone = cancelled.clone();
405 let _cancel_monitor: tokio::task::JoinHandle<()> = tokio::spawn(async move {
406 while !cancel.is_cancelled() {
407 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
408 }
409 cancel_clone.store(true, Ordering::SeqCst);
410 kill_process_group(pid);
411 });
412
413 let timeout_dur = timeout.map(std::time::Duration::from_secs);
415 loop {
416 if cancelled.load(Ordering::SeqCst) {
418 kill_process_group(pid);
419 read_task.abort();
420 return Err(anyhow::anyhow!("Command aborted"));
421 }
422
423 if let Some(dur) = timeout_dur
425 && started_at.elapsed() > dur
426 {
427 kill_process_group(pid);
428 read_task.abort();
429 return Err(anyhow::anyhow!(
430 "Command timed out after {} seconds",
431 timeout.unwrap_or(0)
432 ));
433 }
434
435 if let Some(ref tx) = on_update {
437 let out = combined.lock().await;
438 if !out.is_empty() {
439 let elapsed = started_at.elapsed();
440 let display = format!(
441 "{}\n\n[Elapsed {:.1}s]",
442 out.trim_end(),
443 elapsed.as_secs_f64()
444 );
445 let _ = tx.send(ToolOutput::ok(display));
446 }
447 }
448
449 match child.try_wait() {
451 Ok(Some(status)) => {
452 read_task.await.ok();
453 let combined_str = combined.lock().await.clone();
454 let exit_code = status.code().unwrap_or(-1);
455
456 return finish_bash_execution(
457 command,
458 &combined_str,
459 exit_code,
460 false,
461 started_at,
462 on_update,
463 );
464 }
465 Ok(None) => {
466 tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
468 }
469 Err(_) => {
470 read_task.await.ok();
471 let combined_str = combined.lock().await.clone();
472 let exit_code = -1;
473 return finish_bash_execution(
474 command,
475 &combined_str,
476 exit_code,
477 false,
478 started_at,
479 on_update,
480 );
481 }
482 }
483 }
484 }
485}
486
/// Stateless renderer for `bash` tool calls and results; handed out by `BashTool::renderer`.
struct BashRenderer;
490
491fn parse_command(cmd: &str) -> Option<(&'static str, Option<String>)> {
496 let trimmed = cmd.trim();
497
498 let effective = {
500 let mut rest = trimmed;
501 loop {
502 if let Some(eq_pos) = rest.find('=') {
504 let var_name = &rest[..eq_pos];
505 if !var_name.is_empty() && var_name.chars().all(|c| c.is_alphanumeric() || c == '_')
507 {
508 let after_eq = &rest[eq_pos + 1..];
510 if let Some(space_pos) = after_eq.find(' ') {
511 rest = after_eq[space_pos + 1..].trim_start();
512 continue;
513 } else {
514 rest = "";
516 break;
517 }
518 }
519 }
520 break;
521 }
522 rest
523 };
524
525 if effective.starts_with("ls ") || effective == "ls" {
527 let path = extract_ls_path(effective);
528 return Some(("ls", path));
529 }
530
531 if effective.starts_with("grep ") || effective.starts_with("rg ") {
533 let info = extract_grep_info(effective);
534 return Some(("grep", info));
535 }
536
537 if effective.starts_with("find ") {
539 let info = extract_find_info(effective);
540 return Some(("find", info));
541 }
542
543 if effective.starts_with("cat ") || effective == "cat" {
545 let path = effective.strip_prefix("cat ").map(|s| s.trim().to_string());
546 return Some(("cat", path));
547 }
548
549 if effective.starts_with("head ") || effective.starts_with("tail ") {
551 let (cmd_name, rest) = if effective.starts_with("head") {
552 ("head", effective.strip_prefix("head").unwrap_or(""))
553 } else {
554 ("tail", effective.strip_prefix("tail").unwrap_or(""))
555 };
556 let path = rest.trim();
557 let path_opt = if path.is_empty() {
558 None
559 } else {
560 Some(path.to_string())
561 };
562 return Some((cmd_name, path_opt));
563 }
564
565 if effective.starts_with("wc ") || effective == "wc" {
567 let path = effective.strip_prefix("wc ").map(|s| s.trim().to_string());
568 return Some(("wc", path));
569 }
570
571 None
572}
573
/// Picks the directory an `ls` invocation lists: `"."` when no arguments are
/// given, otherwise the last argument that is not a flag (if any).
fn extract_ls_path(cmd: &str) -> Option<String> {
    let args = cmd.strip_prefix("ls").unwrap_or_default().trim();
    if args.is_empty() {
        return Some(String::from("."));
    }
    args.split_whitespace()
        .rev()
        .find(|token| !token.starts_with('-'))
        .map(str::to_owned)
}
587
/// Summarises a `grep`/`rg` invocation as `PATTERN` or `PATTERN in FILE, FILE…`.
///
/// Flags are skipped. Flags that take a value (`-A`, `-B`, `-C`, `-m`,
/// `--max-count`) also consume the following token, and `-e`/`--regexp`
/// supplies the pattern. `-n` only enables line numbers and takes no value,
/// so it is treated as a plain flag. Returns `None` when nothing usable is found.
fn extract_grep_info(cmd: &str) -> Option<String> {
    let args = cmd
        .strip_prefix("grep")
        .or_else(|| cmd.strip_prefix("rg"))
        .unwrap_or("")
        .trim();
    if args.is_empty() {
        return None;
    }

    let mut pattern: Option<&str> = None;
    let mut files: Vec<&str> = Vec::new();
    let mut tokens = args.split_whitespace();
    while let Some(arg) = tokens.next() {
        match arg {
            // The value of an explicit pattern flag is the pattern itself.
            "-e" | "--regexp" => {
                if let Some(explicit) = tokens.next() {
                    pattern.get_or_insert(explicit);
                }
            }
            // These flags are followed by a number that is not an operand.
            "-A" | "-B" | "-C" | "-m" | "--max-count" => {
                tokens.next();
            }
            flag if flag.starts_with('-') => {}
            positional if pattern.is_none() => pattern = Some(positional),
            positional => files.push(positional),
        }
    }

    let mut desc = String::new();
    if let Some(p) = pattern {
        desc.push_str(p);
    }
    if !files.is_empty() {
        desc.push_str(" in ");
        desc.push_str(&files.join(", "));
    }
    if desc.is_empty() { None } else { Some(desc) }
}
630
/// Summarises a `find` invocation as `PATH` or `PATH (name=PATTERN)`.
///
/// The path is the first operand that is neither a flag nor the value of
/// `-name`, `-path` or `-type`; it defaults to `"."`. The name is the token
/// following the last `-name` seen in a separate scan of the raw tokens.
fn extract_find_info(cmd: &str) -> Option<String> {
    let args = cmd.strip_prefix("find").unwrap_or_default().trim();
    if args.is_empty() {
        return Some(String::from("."));
    }

    // First scan: locate the search root.
    let mut search_root: Option<&str> = None;
    let mut tokens = args.split_whitespace();
    while let Some(token) = tokens.next() {
        match token {
            // These predicates own the token that follows them.
            "-name" | "-path" | "-type" => {
                tokens.next();
            }
            flag if flag.starts_with('-') => {}
            operand => {
                search_root = Some(operand);
                break;
            }
        }
    }

    // Second scan: each `-name` overwrites the previous one with its value.
    let mut name_filter: Option<&str> = None;
    let mut tokens = args.split_whitespace();
    while let Some(token) = tokens.next() {
        if token == "-name" {
            name_filter = tokens.next();
        }
    }

    let root = search_root.unwrap_or(".");
    Some(match name_filter {
        Some(pattern) => format!("{} (name={})", root, pattern),
        None => root.to_string(),
    })
}
673
674fn format_command_header(cmd: &str, theme: &dyn Theme) -> Option<String> {
676 let (name, desc) = parse_command(cmd)?;
677 let title = theme.fg("toolTitle", &theme.bold(name));
678 let detail = desc
679 .map(|d| format!(" {}", theme.fg("accent", &d)))
680 .unwrap_or_default();
681 Some(format!("{}{}", title, detail))
682}
683
684impl ToolRenderer for BashRenderer {
687 fn render_call(
688 &self,
689 args: &serde_json::Value,
690 _width: usize,
691 theme: &dyn Theme,
692 _ctx: &ToolRenderContext,
693 ) -> Vec<String> {
694 let cmd = args
695 .get("command")
696 .and_then(|v| v.as_str())
697 .unwrap_or("...");
698 let timeout = args.get("timeout").and_then(|v| v.as_i64());
699 let timeout_suffix = timeout
700 .map(|t| theme.fg("muted", &format!(" (timeout {}s)", t)))
701 .unwrap_or_default();
702
703 if let Some(header) = format_command_header(cmd, theme) {
705 vec![format!("{}{}", header, timeout_suffix)]
706 } else {
707 vec![format!(
708 "{}{}",
709 theme.fg("toolTitle", &theme.bold(&format!("$ {}", cmd))),
710 timeout_suffix
711 )]
712 }
713 }
714
715 fn render_result(
716 &self,
717 content: &str,
718 width: usize,
719 theme: &dyn Theme,
720 ctx: &ToolRenderContext,
721 ) -> Vec<String> {
722 let mut lines: Vec<String> = Vec::new();
723
724 let clean = strip_context_truncation_footer(content);
726 let all_lines: Vec<&str> = clean.split('\n').collect();
727
728 if all_lines.is_empty() || (all_lines.len() == 1 && all_lines[0].is_empty()) {
729 return lines;
730 }
731
732 let preview_count = 5;
734 let (preview_lines, hidden_line_count) = if ctx.expanded {
735 (all_lines.clone(), 0)
736 } else {
737 truncate_to_visual_lines(&all_lines, width, preview_count)
738 };
739
740 if !ctx.expanded && hidden_line_count > 0 {
741 let hint = if ctx.expand_key.is_empty() {
742 theme.fg("muted", &format!("... {} earlier lines", hidden_line_count))
743 } else {
744 theme.fg(
745 "muted",
746 &format!(
747 "... ({} earlier lines, {} to expand)",
748 hidden_line_count, ctx.expand_key
749 ),
750 )
751 };
752 lines.push(hint);
753 }
754
755 let fg_key = if ctx.is_error { "error" } else { "toolOutput" };
756 for line in &preview_lines {
757 if line.is_empty() {
758 lines.push(String::new());
759 } else {
760 lines.push(theme.fg(fg_key, line));
761 }
762 }
763
764 if let Some(secs) = ctx.duration_secs {
766 let is_complete = ctx.exit_code.is_some() || ctx.cancelled;
767 let label = if is_complete { "Took" } else { "Elapsed" };
768 lines.push(theme.fg("muted", &format!("{} {:.1}s", label, secs)));
769 }
770
771 if ctx.cancelled {
773 lines.push(theme.fg("warning", "(cancelled)"));
774 } else if let Some(code) = ctx.exit_code
775 && code != 0
776 {
777 lines.push(theme.fg("warning", &format!("(exit {})", code)));
778 }
779
780 if ctx.was_truncated {
782 if let Some(ref path) = ctx.full_output_path {
783 lines.push(theme.fg(
784 "warning",
785 &format!("Output truncated. Full output: {}", path),
786 ));
787 } else {
788 lines.push(theme.fg("warning", "Output truncated."));
789 }
790 }
791
792 lines
793 }
794}
795
/// Removes a trailing `[Showing … Full output: …]` footer (and the blank line
/// before it) from `output`.
///
/// Text with fewer than three lines, or whose last line is not such a
/// footer, is returned unchanged.
fn strip_context_truncation_footer(output: &str) -> String {
    let lines: Vec<&str> = output.lines().collect();
    if lines.len() < 3 {
        return output.to_string();
    }
    let Some((last_line, body)) = lines.split_last() else {
        return output.to_string();
    };

    let footer = last_line.trim();
    let is_footer = footer.starts_with('[')
        && footer.contains("Full output:")
        && (footer.contains("Showing lines") || footer.contains("Showing last"));
    if !is_footer {
        return output.to_string();
    }

    // Also drop the single blank separator line in front of the footer.
    let body = match body.split_last() {
        Some((separator, head)) if separator.is_empty() => head,
        _ => body,
    };
    body.join("\n")
}
817
// Tests for the bash tool: end-to-end `execute` runs (these spawn real shell
// commands) and unit tests for `truncate_tail`.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a tool that runs commands in the system temp directory.
    fn make_tool() -> BashTool {
        BashTool {
            cwd: std::env::temp_dir(),
        }
    }

    // A successful command returns its stdout.
    #[tokio::test]
    async fn runs_simple_command() {
        let tool = make_tool();
        let output = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "echo hello"}),
                Cancel::new(),
                None,
            )
            .await
            .unwrap();
        assert!(output.content.contains("hello"));
    }

    // Text written to stderr is part of the returned content.
    #[tokio::test]
    async fn captures_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "echo err >&2"}),
                Cancel::new(),
                None,
            )
            .await
            .unwrap();
        assert!(output.content.contains("err"));
    }

    // A token that is already cancelled makes `execute` fail.
    #[tokio::test]
    async fn cancel_aborts() {
        let tool = make_tool();
        let cancel = Cancel::new();
        cancel.cancel();
        let result = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "sleep 10"}),
                cancel,
                None,
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("cancelled") || err.contains("aborted"),
            "expected cancellation error, got: {}",
            err
        );
    }

    // A command outliving its timeout yields a "timed out" error.
    #[tokio::test]
    async fn timeout_works() {
        let tool = make_tool();
        let result = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "sleep 10", "timeout": 1}),
                Cancel::new(),
                None,
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"));
    }

    // Input within both limits is returned verbatim.
    #[test]
    fn test_truncate_tail_no_truncation() {
        let result = truncate_tail("hello\nworld\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello\nworld\n");
    }

    // With 5000 lines and a 2000-line limit, the last 2000 lines are kept.
    #[test]
    fn test_truncate_tail_by_lines() {
        let content: String = (1..=5000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.starts_with("line 3001"));
        assert_eq!(result.content.lines().count(), 2000);
    }

    // 100 lines of roughly 1KB each exceed the byte limit before the line limit.
    #[test]
    fn test_truncate_tail_by_bytes() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
        assert!(result.content.lines().count() < 100);
    }

    // The final line alone exceeds the byte limit, so only its tail is kept.
    #[test]
    fn test_truncate_tail_partial_last_line() {
        let content = format!("short\n{}\n", "x".repeat(60000));
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(!result.content.starts_with("short"));
        assert!(result.content.len() <= 50000);
    }

    // Empty input is not truncated.
    #[test]
    fn test_truncate_tail_empty() {
        let result = truncate_tail("", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "");
    }

    // A non-zero exit code is reported as an error naming the code.
    #[tokio::test]
    async fn exit_code_nonzero() {
        let tool = make_tool();
        let result = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "exit 42"}),
                Cancel::new(),
                None,
            )
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exited with code 42"), "got: {}", err);
    }

    // The error for a failing command still carries the output produced before it failed.
    #[tokio::test]
    async fn exit_code_with_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "echo before && exit 1"}),
                Cancel::new(),
                None,
            )
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("before"), "got: {}", err);
        assert!(err.contains("exited with code 1"), "got: {}", err);
    }

    // A silent successful command yields the "(no output)" placeholder.
    #[tokio::test]
    async fn no_output() {
        let tool = make_tool();
        let output = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "true"}),
                Cancel::new(),
                None,
            )
            .await
            .unwrap();
        assert!(
            output.content.contains("(no output)"),
            "got: {}",
            output.content
        );
    }

    // stdout and stderr both end up in the same content string.
    #[tokio::test]
    async fn combined_stdout_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "echo out; echo err >&2"}),
                Cancel::new(),
                None,
            )
            .await
            .unwrap();
        assert!(output.content.contains("out"), "got: {}", output.content);
        assert!(output.content.contains("err"), "got: {}", output.content);
    }

    // Commands run inside the tool's `cwd`: a relative path resolves against it.
    // NOTE(review): the temporary directory created here is not removed afterwards.
    #[tokio::test]
    async fn runs_in_cwd() {
        let tmp = std::env::temp_dir().join(format!("rab-bash-cwd-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&tmp).unwrap();
        std::fs::write(tmp.join("marker.txt"), "hello").unwrap();

        let tool = BashTool { cwd: tmp.clone() };
        let output = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "cat marker.txt"}),
                Cancel::new(),
                None,
            )
            .await
            .unwrap();
        assert!(output.content.contains("hello"), "got: {}", output.content);
    }

    // Arguments without `command` are rejected with an error mentioning it.
    #[tokio::test]
    async fn missing_command_errors() {
        let tool = make_tool();
        let result = tool
            .execute("id".into(), serde_json::json!({}), Cancel::new(), None)
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("command"), "got: {}", err);
    }

    // A timeout is still reported when the command printed something first.
    #[tokio::test]
    async fn timeout_with_partial_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "echo start && sleep 10 && echo end", "timeout": 1}),
                Cancel::new(),
                None,
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"), "got: {}", err);
    }

    // Cancelling while the command is running aborts it with an error.
    #[tokio::test]
    async fn cancel_during_long_command() {
        let tool = make_tool();
        let cancel = Cancel::new();
        let cancel_clone = cancel.clone();

        let handle = tokio::spawn(async move {
            tool.execute(
                "id".into(),
                serde_json::json!({"command": "sleep 30"}),
                cancel_clone,
                None,
            )
            .await
        });

        // Give the command a moment to start before cancelling.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("aborted") || err.contains("cancelled"),
            "expected cancellation error, got: {}",
            err
        );
    }

    // Exactly `max_lines` lines is still within the limit.
    #[test]
    fn test_truncate_tail_exact_line_fit() {
        let lines: String = (1..=2000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(
            !result.truncated,
            "should not truncate when exactly at line limit"
        );
        assert!(result.content.lines().count() == 2000);
    }

    // One line over the limit drops exactly the first line.
    #[test]
    fn test_truncate_tail_one_over_line_limit() {
        let lines: String = (1..=2001).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.content.lines().count(), 2000);
        assert!(result.content.starts_with("line 2"));
    }

    // Exactly `max_bytes` bytes is still within the limit.
    #[test]
    fn test_truncate_tail_exact_byte_fit() {
        let line = "a".repeat(50000);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(!result.truncated);
    }

    // One byte over the limit triggers truncation.
    #[test]
    fn test_truncate_tail_one_byte_over() {
        let line = "a".repeat(50001);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
    }

    // A short single line passes through unchanged.
    #[test]
    fn test_truncate_tail_single_line_under_limit() {
        let result = truncate_tail("hello world", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello world");
    }

    // A trailing newline is preserved when nothing is truncated.
    #[test]
    fn test_truncate_tail_trailing_newline() {
        let result = truncate_tail("a\nb\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb\n");
    }

    // Input without a trailing newline is preserved as well.
    #[test]
    fn test_truncate_tail_no_trailing_newline() {
        let result = truncate_tail("a\nb", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb");
    }

    // A single oversized line is cut to its last `max_bytes` bytes.
    #[test]
    fn test_truncate_tail_single_line_exceeds_limit() {
        let content = "x".repeat(60000);
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.last_line_partial);
        assert_eq!(result.content.len(), 50000);
        assert!(result.content.ends_with("x".repeat(50000).as_str()));
    }

    // The reported byte count stays within the limit.
    #[test]
    fn test_truncate_tail_byte_count_respects_newlines() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(
            result.output_bytes <= 50000,
            "output_bytes {} exceeds limit 50000",
            result.output_bytes
        );
    }

    // Output truncated by line count carries the "Showing lines … Full output:" footer.
    #[tokio::test]
    async fn truncated_by_lines_shows_footer() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": cmd}),
                Cancel::new(),
                None,
            )
            .await
            .unwrap();
        assert!(
            output.content.contains("Showing lines"),
            "got: {}",
            output.content
        );
        assert!(
            output.content.contains("Full output:"),
            "got: {}",
            output.content
        );
    }

    // Small output has no truncation footer.
    #[tokio::test]
    async fn small_output_no_footer() {
        let tool = make_tool();
        let output = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": "echo hello"}),
                Cancel::new(),
                None,
            )
            .await
            .unwrap();
        assert!(
            !output.content.contains("Output truncated"),
            "got: {}",
            output.content
        );
        assert!(
            !output.content.contains("Full output:"),
            "got: {}",
            output.content
        );
    }

    // The footer of truncated output names a file under the `rab-bash` temp directory.
    #[tokio::test]
    async fn truncated_saves_temp_file() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(
                "id".into(),
                serde_json::json!({"command": cmd}),
                Cancel::new(),
                None,
            )
            .await
            .unwrap();
        assert!(
            output.content.contains("/rab-bash/"),
            "expected temp file path, got: {}",
            output.content
        );
    }

    // Many short lines hit the line limit first.
    #[test]
    fn test_truncate_tail_many_short_lines() {
        let content: String = (1..=10000).map(|i| format!("{}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "lines");
        assert_eq!(result.output_lines, 2000);
        assert!(
            result.content.starts_with("8001"),
            "starts with: {:?}",
            &result.content[..10]
        );
    }

    // When both limits are exceeded and bytes run out first, "bytes" is reported.
    #[test]
    fn test_truncate_tail_lines_and_bytes_both_exceeded() {
        let content: String = (1..=5000)
            .map(|i| format!("line {} {}\n", i, "x".repeat(100)))
            .collect();
        let result = truncate_tail(&content, 2000, 30000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "bytes");
        assert!(result.output_lines < 2000);
    }
}
1286
// Tests for `parse_command`, which recognises common commands for display.
#[cfg(test)]
mod command_tests {
    use super::*;

    // `ls` with flags reports the listed path.
    #[test]
    fn test_parse_ls() {
        let result = parse_command("ls -la src/");
        assert!(result.is_some());
        let (name, desc) = result.unwrap();
        assert_eq!(name, "ls");
        assert_eq!(desc, Some("src/".to_string()));
    }

    // Bare `ls` defaults to the current directory.
    #[test]
    fn test_parse_ls_default() {
        let result = parse_command("ls");
        assert!(result.is_some());
        let (name, desc) = result.unwrap();
        assert_eq!(name, "ls");
        assert_eq!(desc, Some(".".to_string()));
    }

    // `grep` reports both the pattern and the searched path.
    #[test]
    fn test_parse_grep() {
        let result = parse_command("grep -r \"pattern\" src/");
        assert!(result.is_some());
        let (name, desc) = result.unwrap();
        assert_eq!(name, "grep");
        assert!(desc.is_some());
        let desc = desc.unwrap();
        assert!(desc.contains("pattern"));
        assert!(desc.contains("src/"));
    }

    // `rg` is reported under the name "grep".
    #[test]
    fn test_parse_rg() {
        let result = parse_command("rg pattern src/");
        assert!(result.is_some());
        let (name, _) = result.unwrap();
        assert_eq!(name, "grep");
    }

    // `find` reports the search root and the `-name` pattern.
    #[test]
    fn test_parse_find() {
        let result = parse_command("find . -name \"*.rs\"");
        assert!(result.is_some());
        let (name, desc) = result.unwrap();
        assert_eq!(name, "find");
        assert!(desc.is_some());
        let desc = desc.unwrap();
        assert!(desc.contains("."));
        assert!(desc.contains("*.rs"));
    }

    // `cat` reports its argument.
    #[test]
    fn test_parse_cat() {
        let result = parse_command("cat README.md");
        assert!(result.is_some());
        let (name, desc) = result.unwrap();
        assert_eq!(name, "cat");
        assert_eq!(desc, Some("README.md".to_string()));
    }

    // `head` keeps its full argument string, flags included.
    #[test]
    fn test_parse_head() {
        let result = parse_command("head -20 file.txt");
        assert!(result.is_some());
        let (name, desc) = result.unwrap();
        assert_eq!(name, "head");
        assert_eq!(desc, Some("-20 file.txt".to_string()));
    }

    // `tail` keeps its full argument string, flags included.
    #[test]
    fn test_parse_tail() {
        let result = parse_command("tail -f log.txt");
        assert!(result.is_some());
        let (name, desc) = result.unwrap();
        assert_eq!(name, "tail");
        assert_eq!(desc, Some("-f log.txt".to_string()));
    }

    // `wc` keeps its full argument string, flags included.
    #[test]
    fn test_parse_wc() {
        let result = parse_command("wc -l file.txt");
        assert!(result.is_some());
        let (name, desc) = result.unwrap();
        assert_eq!(name, "wc");
        assert_eq!(desc, Some("-l file.txt".to_string()));
    }

    // Commands outside the recognised set are not parsed.
    #[test]
    fn test_parse_unknown() {
        let result = parse_command("echo hello");
        assert!(result.is_none());
    }

    // A leading environment assignment is skipped before matching.
    #[test]
    fn test_parse_with_env() {
        let result = parse_command("FOO=bar ls src/");
        assert!(result.is_some());
        let (name, desc) = result.unwrap();
        assert_eq!(name, "ls");
        assert_eq!(desc, Some("src/".to_string()));
    }
}
1391}