1use crate::agent::extension::{Cancel, Extension, ToolDefinition};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::tui::Theme;
4use crate::tui::ThemeKey;
5use crate::tui::visual_truncate::truncate_to_visual_lines;
6use async_trait::async_trait;
7
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::time::Instant;
14use tokio::sync::{Mutex as TokioMutex, mpsc::UnboundedSender};
15
16#[async_trait]
21pub trait BashOperations: Send + Sync {
22 async fn exec(
25 &self,
26 command: &str,
27 cwd: &Path,
28 on_data: UnboundedSender<String>,
29 signal: Option<&Cancel>,
30 timeout: Option<u64>,
31 env: Option<HashMap<String, String>>,
32 ) -> Result<Option<i32>, anyhow::Error>;
33}
34
35#[derive(Clone, Default)]
36pub struct BashToolOptions {
37 pub operations: Option<Arc<dyn BashOperations>>,
39 pub command_prefix: Option<String>,
41 pub shell_path: Option<String>,
43}
44
45pub struct BashExtension {
46 cwd: PathBuf,
47 options: BashToolOptions,
48}
49
50impl BashExtension {
51 pub fn new(cwd: PathBuf) -> Self {
52 Self {
53 cwd,
54 options: BashToolOptions::default(),
55 }
56 }
57
58 pub fn with_options(cwd: PathBuf, options: BashToolOptions) -> Self {
59 Self { cwd, options }
60 }
61
62 pub fn with_shell_path(cwd: PathBuf, shell_path: String) -> Self {
63 Self {
64 cwd,
65 options: BashToolOptions {
66 shell_path: Some(shell_path),
67 ..BashToolOptions::default()
68 },
69 }
70 }
71}
72
73impl Extension for BashExtension {
74 fn name(&self) -> Cow<'static, str> {
75 "bash".into()
76 }
77
78 fn as_any(&self) -> &dyn std::any::Any {
79 self
80 }
81
82 fn tools(&self) -> Vec<ToolDefinition> {
83 vec![ToolDefinition {
84 tool: Box::new(BashTool {
85 cwd: self.cwd.clone(),
86 shell_path: self.options.shell_path.clone(),
87 command_prefix: self.options.command_prefix.clone(),
88 operations: self.options.operations.clone(),
89 }),
90 snippet: "Execute bash commands (ls, grep, find, etc.)",
91 guidelines: &[],
92 prepare_arguments: None,
93 before_tool_call: None,
94 after_tool_call: None,
95 renderer: Some(std::sync::Arc::new(BashRenderer)),
96 }]
97 }
98}
99
100struct BashTool {
101 cwd: PathBuf,
102 shell_path: Option<String>,
103 command_prefix: Option<String>,
104 operations: Option<Arc<dyn BashOperations>>,
105}
106
/// Maximum number of trailing output lines kept in a tool result.
const DEFAULT_MAX_LINES: usize = 2000;

/// Maximum number of trailing output bytes kept in a tool result (50 KiB).
const DEFAULT_MAX_BYTES: usize = 1024 * 50;

/// Directory name (under the OS temp dir) that holds full-output `.log` files.
const BASH_TEMP_FILE_PREFIX: &str = "rab-bash";

/// Full-output logs last modified longer ago than this (one day, in seconds) are deleted.
const TEMP_FILE_MAX_AGE_SECS: u64 = 60 * 60 * 24;

/// After the shell exits, output must stay unchanged for this many milliseconds before the
/// pipe reader is stopped.
const EXIT_STDIO_GRACE_MS: u64 = 100;
119
/// Shell binary plus the arguments placed before the command string.
struct ShellConfig {
    /// Path or name of the shell executable.
    shell: String,
    /// Arguments preceding the command (always `-c`).
    args: Vec<String>,
}

/// Chooses the shell used to run a command.
///
/// An explicit `shell_path` always wins. Otherwise the shell is auto-detected:
/// - `/bin/bash` if it exists;
/// - then (Unix only) the path printed by `which bash`;
/// - finally plain `sh`.
///
/// The auto-detected value is computed once per process and cached. Probing spawns a blocking
/// `which` subprocess, which previously happened on every command execution from async code.
fn resolve_shell(shell_path: Option<&str>) -> ShellConfig {
    static DETECTED_SHELL: std::sync::OnceLock<String> = std::sync::OnceLock::new();

    let shell = match shell_path {
        Some(path) => path.to_string(),
        None => DETECTED_SHELL.get_or_init(detect_default_shell).clone(),
    };
    ShellConfig {
        shell,
        args: vec!["-c".to_string()],
    }
}

/// Looks for a bash binary on this system, falling back to `sh`.
fn detect_default_shell() -> String {
    if std::path::Path::new("/bin/bash").exists() {
        return "/bin/bash".to_string();
    }

    #[cfg(unix)]
    {
        let probe = std::process::Command::new("which")
            .arg("bash")
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::null())
            .output();
        if let Ok(output) = probe {
            if output.status.success() {
                let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
                if !path.is_empty() && std::path::Path::new(&path).exists() {
                    return path;
                }
            }
        }
    }

    "sh".to_string()
}
175
176#[cfg(unix)]
180fn kill_process_group(pid: u32) {
181 if pid > 0 {
182 let _ = std::process::Command::new("kill")
183 .arg("--")
184 .arg(format!("-{}", pid))
185 .status();
186 }
187}
188
189#[cfg(not(unix))]
190fn kill_process_group(pid: u32) {
191 let _ = pid;
192}
193
194fn spawn_bash_command(
196 command: &str,
197 cwd: &std::path::Path,
198 shell_path: Option<&str>,
199) -> std::io::Result<tokio::process::Child> {
200 let shell_cfg = resolve_shell(shell_path);
201
202 #[cfg(unix)]
203 {
204 use std::os::unix::process::CommandExt;
205 let mut std_cmd = std::process::Command::new(&shell_cfg.shell);
206 std_cmd.args(&shell_cfg.args).arg(command).current_dir(cwd);
207 unsafe {
208 std_cmd.pre_exec(|| {
209 libc::setpgid(0, 0);
210 Ok(())
211 });
212 }
213 let mut tokio_cmd = tokio::process::Command::from(std_cmd);
214 tokio_cmd
215 .stdin(std::process::Stdio::null())
216 .stdout(std::process::Stdio::piped())
217 .stderr(std::process::Stdio::piped())
218 .spawn()
219 }
220 #[cfg(not(unix))]
221 {
222 tokio::process::Command::new(&shell_cfg.shell)
223 .args(&shell_cfg.args)
224 .arg(command)
225 .current_dir(cwd)
226 .stdin(std::process::Stdio::null())
227 .stdout(std::process::Stdio::piped())
228 .stderr(std::process::Stdio::piped())
229 .spawn()
230 }
231}
232
/// Removes terminal escape sequences and unsafe control characters from command output.
///
/// Escape sequences removed in full:
/// - CSI: `ESC [` or the 8-bit `U+009B`, then parameter/intermediate characters `0x20..=0x3F`,
///   ended by a final character `0x40..=0x7E` (colors like `ESC[31m`, cursor moves, `ESC[2@`).
/// - String controls `ESC ]` (OSC: window titles, hyperlinks), `ESC P`, `ESC X`, `ESC ^` and
///   `ESC _`, ended by BEL or `ESC \`.
/// - Other escapes: `ESC`, optional intermediates `0x20..=0x2F`, then one final character
///   `0x30..=0x7E` (e.g. `ESC =`, `ESC 7`, `ESC ( B`).
///
/// A malformed sequence is abandoned at the first character that cannot belong to it, and that
/// character is handled as ordinary text.
///
/// In ordinary text, C0 controls other than tab, LF and CR are dropped, as are U+FFF9..=U+FFFB.
///
/// Each call starts from a clean state, so a sequence split across two calls is not recognized
/// as a whole.
fn sanitize_output(text: &str) -> String {
    #[derive(Clone, Copy)]
    enum State {
        Text,
        Escape,
        Csi,
        ControlString,
    }

    /// Appends `c` unless it is a disallowed control or an interlinear annotation character.
    fn push_visible(out: &mut String, c: char) {
        let code = c as u32;
        let unsafe_control = code <= 0x1f && !matches!(c, '\t' | '\n' | '\r');
        let interlinear = (0xfff9..=0xfffb).contains(&code);
        if !unsafe_control && !interlinear {
            out.push(c);
        }
    }

    let mut out = String::with_capacity(text.len());
    let mut state = State::Text;
    for c in text.chars() {
        state = match (state, c) {
            // ESC (re)starts a sequence from any state; inside a control string it may begin
            // the `ESC \` terminator, which the Escape state then consumes.
            (_, '\x1b') => State::Escape,
            (State::ControlString, '\x07') => State::Text,
            (State::ControlString, _) => State::ControlString,
            (_, '\u{9b}') => State::Csi,
            (State::Escape, '[') => State::Csi,
            (State::Escape, ']' | 'P' | 'X' | '^' | '_') => State::ControlString,
            (State::Escape, '\x20'..='\x2f') => State::Escape,
            (State::Escape, '\x30'..='\x7e') => State::Text,
            (State::Csi, '\x20'..='\x3f') => State::Csi,
            (State::Csi, '\x40'..='\x7e') => State::Text,
            (_, other) => {
                push_visible(&mut out, other);
                State::Text
            }
        };
    }
    out
}
262
/// Formats a byte count for humans: `"512B"`, `"1.5KB"`, `"2.0MB"` (binary units, one decimal).
fn format_size(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = KIB * KIB;
    match bytes {
        small if small < KIB => format!("{}B", small),
        medium if medium < MIB => format!("{:.1}KB", medium as f64 / KIB as f64),
        large => format!("{:.1}MB", large as f64 / MIB as f64),
    }
}
272
/// Outcome of [`truncate_tail`].
struct TailTruncation {
    /// Kept text. This is the input verbatim when nothing was cut; otherwise the kept lines
    /// joined with `\n`.
    content: String,
    /// Whether anything was dropped.
    truncated: bool,
    /// Line count of the input as produced by [`str::lines`] (a trailing newline adds no line).
    total_lines: usize,
    /// Number of lines in `content`.
    output_lines: usize,
    /// Byte length of `content`.
    output_bytes: usize,
    /// `"lines"` or `"bytes"`: the limit that caused truncation; `""` when not truncated.
    truncated_by: &'static str,
    /// Set when the single kept line is only the tail end of an over-long last line.
    last_line_partial: bool,
}

/// Keeps the end of `content`: at most `max_lines` lines and at most `max_bytes` bytes.
///
/// Content within both limits is returned unchanged. Otherwise whole lines are taken from the
/// end while both budgets allow. If even the last line alone exceeds `max_bytes`, only its
/// final bytes are kept. That cut is moved forward to a UTF-8 character boundary, so the kept
/// part may be slightly shorter than `max_bytes`.
fn truncate_tail(content: &str, max_lines: usize, max_bytes: usize) -> TailTruncation {
    let total_bytes = content.len();
    let lines: Vec<&str> = content.lines().collect();
    let total_lines = lines.len();

    if total_lines <= max_lines && total_bytes <= max_bytes {
        return TailTruncation {
            content: content.to_string(),
            truncated: false,
            total_lines,
            output_lines: total_lines,
            output_bytes: total_bytes,
            truncated_by: "",
            last_line_partial: false,
        };
    }

    let mut output: Vec<&str> = Vec::new();
    let mut byte_count: usize = 0;
    let mut truncated_by = "lines";
    let mut last_line_partial = false;

    for &line in lines.iter().rev().take(max_lines) {
        // Every kept line after the first also costs its joining newline.
        let cost = if output.is_empty() {
            line.len()
        } else {
            line.len() + 1
        };

        if byte_count + cost > max_bytes {
            truncated_by = "bytes";
            if output.is_empty() {
                // Even the last line alone is too big: keep as much of its end as fits.
                let tail = tail_within_bytes(line, max_bytes);
                byte_count = tail.len();
                output.push(tail);
                last_line_partial = true;
            }
            break;
        }

        output.push(line);
        byte_count += cost;
    }

    if output.len() >= max_lines && byte_count <= max_bytes {
        truncated_by = "lines";
    }

    output.reverse();
    TailTruncation {
        content: output.join("\n"),
        truncated: true,
        total_lines,
        output_lines: output.len(),
        output_bytes: byte_count,
        truncated_by,
        last_line_partial,
    }
}

/// Returns the longest suffix of `line` that is at most `max_bytes` long and starts on a char
/// boundary. Slicing at a raw byte offset would panic inside a multi-byte character.
fn tail_within_bytes(line: &str, max_bytes: usize) -> &str {
    let mut start = line.len().saturating_sub(max_bytes);
    while !line.is_char_boundary(start) {
        start += 1;
    }
    &line[start..]
}
346
347fn finish_bash_execution(
350 combined: &str,
351 exit_code: i32,
352 cancelled: bool,
353 timed_out: Option<u64>,
354 ctx: &yoagent::types::ToolContext,
355) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
356 let trunc = truncate_tail(combined, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
357
358 let mut result_text = if trunc.content.is_empty() {
359 "(no output)".to_string()
360 } else {
361 trunc.content.clone()
362 };
363
364 let full_output_path = if trunc.truncated {
366 let tmp_dir = temp_output_dir();
367 let _ = std::fs::create_dir_all(&tmp_dir);
368 let tmp_path = tmp_dir.join(format!("{}.log", uuid::Uuid::new_v4()));
369 let saved = std::fs::write(&tmp_path, combined).ok().map(|_| {
370 cleanup_stale_temp_files();
371 tmp_path
372 });
373
374 let start_line = trunc.total_lines - trunc.output_lines + 1;
375 let end_line = trunc.total_lines;
376
377 let notice = if trunc.truncated_by == "lines" {
378 format!(
379 "\n\n[Showing lines {}-{} of {}. Full output: {}]",
380 start_line,
381 end_line,
382 trunc.total_lines,
383 saved
384 .as_ref()
385 .map(|p| p.display().to_string())
386 .unwrap_or_default()
387 )
388 } else {
389 format!(
390 "\n\n[Showing lines {}-{} of {} ({} limit). Full output: {}]",
391 start_line,
392 end_line,
393 trunc.total_lines,
394 format_size(DEFAULT_MAX_BYTES),
395 saved
396 .as_ref()
397 .map(|p| p.display().to_string())
398 .unwrap_or_default()
399 )
400 };
401 result_text.push_str(¬ice);
402 saved
403 } else {
404 None
405 };
406
407 let details = if trunc.truncated || full_output_path.is_some() {
409 Some(serde_json::json!({
410 "truncation": {
411 "truncated": trunc.truncated,
412 "truncatedBy": trunc.truncated_by,
413 "totalLines": trunc.total_lines,
414 "outputLines": trunc.output_lines,
415 "outputBytes": trunc.output_bytes,
416 "lastLinePartial": trunc.last_line_partial,
417 "maxLines": DEFAULT_MAX_LINES,
418 "maxBytes": DEFAULT_MAX_BYTES,
419 },
420 "fullOutputPath": full_output_path.as_ref().map(|p| p.display().to_string()),
421 }))
422 } else {
423 None
424 };
425
426 let final_output = if cancelled {
427 format_status_output(&result_text, "Command aborted")
428 } else if let Some(secs) = timed_out {
429 format_status_output(
430 &result_text,
431 &format!("Command timed out after {} seconds", secs),
432 )
433 } else if exit_code != 0 {
434 format_status_output(
435 &result_text,
436 &format!("Command exited with code {}", exit_code),
437 )
438 } else {
439 emit_update(ctx, result_text.clone(), details.clone());
440 return Ok(into_tool_result(result_text, details));
441 };
442
443 emit_update(ctx, final_output.clone(), details.clone());
444 Err(yoagent::types::ToolError::Failed(final_output))
445}
446
447struct BashRenderer;
452
453impl ToolRenderer for BashRenderer {
454 fn render_call(
455 &self,
456 args: &serde_json::Value,
457 _width: usize,
458 theme: &dyn Theme,
459 _ctx: &ToolRenderContext,
460 ) -> Vec<String> {
461 let cmd = args
462 .get("command")
463 .and_then(|v| v.as_str())
464 .unwrap_or("...");
465 let timeout = args.get("timeout").and_then(|v| v.as_i64());
466 let timeout_suffix = timeout
467 .map(|t| theme.fg_key(ThemeKey::Muted, &format!(" (timeout {}s)", t)))
468 .unwrap_or_default();
469
470 vec![format!(
471 "{}{}",
472 theme.fg_key(ThemeKey::ToolTitle, &theme.bold(&format!("$ {}", cmd))),
473 timeout_suffix
474 )]
475 }
476
477 fn render_result(
478 &self,
479 content: &str,
480 width: usize,
481 theme: &dyn Theme,
482 ctx: &ToolRenderContext,
483 ) -> Vec<String> {
484 let mut lines: Vec<String> = Vec::new();
485
486 let clean = strip_context_truncation_footer(content)
487 .trim_end()
488 .to_string();
489 let all_lines: Vec<&str> = clean.lines().collect();
490
491 if all_lines.is_empty() || (all_lines.len() == 1 && all_lines[0].is_empty()) {
492 return lines;
493 }
494
495 let preview_count = 5;
496 let (preview_lines, hidden_line_count) = if ctx.expanded {
497 (all_lines.clone(), 0)
498 } else {
499 truncate_to_visual_lines(&all_lines, width, preview_count)
500 };
501
502 if !ctx.expanded && hidden_line_count > 0 {
504 if ctx.expand_key.is_empty() {
505 lines.push(theme.fg_key(
506 ThemeKey::Muted,
507 &format!("... {} earlier lines", hidden_line_count),
508 ));
509 } else {
510 let prefix = theme.fg_key(
513 ThemeKey::Muted,
514 &format!("... ({} earlier lines, ", hidden_line_count),
515 );
516 let key_styled = theme.fg("dim", &ctx.expand_key);
517 let suffix = theme.fg_key(ThemeKey::Muted, " to expand)");
518 lines.push(format!("{}{}{}", prefix, key_styled, suffix));
519 }
520 }
521
522 let fg_key = if ctx.is_error { "error" } else { "toolOutput" };
523 for line in &preview_lines {
524 if line.is_empty() {
525 lines.push(String::new());
526 } else {
527 lines.push(theme.fg(fg_key, line));
528 }
529 }
530
531 if let Some(secs) = ctx.duration_secs {
532 if !lines.is_empty() {
533 lines.push(String::new());
534 }
535 let is_complete = ctx.exit_code.is_some() || ctx.cancelled;
536 let label = if is_complete { "Took" } else { "Elapsed" };
537 lines.push(theme.fg_key(ThemeKey::Muted, &format!("{} {:.1}s", label, secs)));
538 }
539
540 if ctx.was_truncated {
541 if !lines.is_empty() {
542 lines.push(String::new());
543 }
544 if let Some(ref path) = ctx.full_output_path {
545 lines.push(theme.fg(
546 "warning",
547 &format!("Output truncated. Full output: {}", path),
548 ));
549 } else {
550 lines.push(theme.fg_key(ThemeKey::Warning, "Output truncated."));
551 }
552 }
553
554 lines
555 }
556}
557
/// Removes a trailing truncation footer (`[Showing lines …]` / `[Showing last …]` mentioning
/// `Full output:`) and the blank line before it.
///
/// Only outputs of at least three lines are considered. When a footer is removed, the
/// remaining lines are re-joined with `\n`. Otherwise `output` is returned unchanged.
fn strip_context_truncation_footer(output: &str) -> String {
    let lines: Vec<&str> = output.lines().collect();
    if lines.len() >= 3 {
        if let Some((last, body)) = lines.split_last() {
            let footer = last.trim();
            let mentions_range =
                footer.contains("Showing lines") || footer.contains("Showing last");
            if footer.starts_with('[') && mentions_range && footer.contains("Full output:") {
                let kept = match body.split_last() {
                    Some((blank, head)) if blank.is_empty() => head,
                    _ => body,
                };
                return kept.join("\n");
            }
        }
    }
    output.to_string()
}
578
579#[async_trait::async_trait]
580impl yoagent::types::AgentTool for BashTool {
581 fn name(&self) -> &str {
582 "bash"
583 }
584 fn label(&self) -> &str {
585 "bash"
586 }
587 fn description(&self) -> &str {
588 "Execute a bash command in the current working directory. Returns stdout and stderr. \
589 Output is truncated to last 2000 lines or 50KB (whichever is hit first). If \
590 truncated, full output is saved to a temp file. Optionally provide a timeout in seconds."
591 }
592 fn parameters_schema(&self) -> serde_json::Value {
593 serde_json::json!({
594 "type": "object",
595 "required": ["command"],
596 "properties": {
597 "command": {
598 "type": "string",
599 "description": "Bash command to execute"
600 },
601 "timeout": {
602 "type": "number",
603 "description": "Timeout in seconds (optional, no default timeout)"
604 }
605 }
606 })
607 }
608 async fn execute(
609 &self,
610 params: serde_json::Value,
611 ctx: yoagent::types::ToolContext,
612 ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
613 let command = params["command"].as_str().ok_or_else(|| {
614 yoagent::types::ToolError::InvalidArgs("Missing 'command' argument".into())
615 })?;
616 let timeout = params["timeout"].as_u64();
617 let started_at = Instant::now();
618
619 if ctx.cancel.is_cancelled() {
620 return Err(yoagent::types::ToolError::Cancelled);
621 }
622
623 let effective_command = if let Some(ref prefix) = self.command_prefix {
625 format!("{}\n{}", prefix, command)
626 } else {
627 command.to_string()
628 };
629
630 if !self.cwd.exists() {
632 return Err(yoagent::types::ToolError::Failed(format!(
633 "Working directory does not exist: {}\nCannot execute bash commands.",
634 self.cwd.display()
635 )));
636 }
637
638 if let Some(ref ops) = self.operations {
640 let (output_tx, mut output_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
641 let ops_cancel = Cancel::new();
642
643 let yo_cancel = ctx.cancel.clone();
645 let watch_cancel = ops_cancel.clone();
646 tokio::spawn(async move {
647 yo_cancel.cancelled().await;
648 watch_cancel.cancel();
649 });
650
651 let ops_command = effective_command.clone();
652 let ops_cwd = self.cwd.clone();
653 let ops = ops.clone();
654 let ops_handle = tokio::spawn(async move {
655 ops.exec(
656 &ops_command,
657 &ops_cwd,
658 output_tx,
659 Some(&ops_cancel),
660 timeout,
661 None,
662 )
663 .await
664 });
665
666 let mut combined = String::new();
668 while let Some(chunk) = output_rx.recv().await {
669 combined.push_str(&chunk);
670 emit_update(&ctx, combined.clone(), None);
671 }
672
673 let exit_code = ops_handle.await.unwrap_or(Ok(None)).unwrap_or(None);
674 let code = exit_code.unwrap_or(-1);
675
676 return finish_bash_execution(&combined, code, ctx.cancel.is_cancelled(), None, &ctx);
677 }
678
679 let mut child =
680 spawn_bash_command(&effective_command, &self.cwd, self.shell_path.as_deref()).map_err(
681 |e| yoagent::types::ToolError::Failed(format!("Failed to spawn command: {}", e)),
682 )?;
683
684 let pid = child.id().unwrap_or(0);
685
686 let combined = Arc::new(TokioMutex::new(String::new()));
688 let combined_clone = combined.clone();
689
690 let stdout_pipe = child
691 .stdout
692 .take()
693 .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stdout".into()))?;
694 let stderr_pipe = child
695 .stderr
696 .take()
697 .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stderr".into()))?;
698
699 use tokio::io::AsyncReadExt;
700 let read_task = tokio::spawn(async move {
701 let mut stdout_buf = vec![0u8; 65536];
702 let mut stderr_buf = vec![0u8; 65536];
703 let mut stdout_reader = stdout_pipe;
704 let mut stderr_reader = stderr_pipe;
705 let mut stdout_done = false;
706 let mut stderr_done = false;
707 loop {
708 tokio::select! {
709 result = stdout_reader.read(&mut stdout_buf), if !stdout_done => {
710 match result {
711 Ok(0) => stdout_done = true,
712 Ok(n) => {
713 let text = String::from_utf8_lossy(&stdout_buf[..n]);
714 let sanitized = sanitize_output(&text);
715 let mut out = combined_clone.lock().await;
716 out.push_str(&sanitized);
717 }
718 Err(_) => stdout_done = true,
719 }
720 }
721 result = stderr_reader.read(&mut stderr_buf), if !stderr_done => {
722 match result {
723 Ok(0) => stderr_done = true,
724 Ok(n) => {
725 let text = String::from_utf8_lossy(&stderr_buf[..n]);
726 let sanitized = sanitize_output(&text);
727 let mut out = combined_clone.lock().await;
728 out.push_str(&sanitized);
729 }
730 Err(_) => stderr_done = true,
731 }
732 }
733 }
734 if stdout_done && stderr_done {
735 break;
736 }
737 }
738 });
739
740 let _pid_guard = ProcessGuard::new(pid);
742
743 let cancelled = Arc::new(AtomicBool::new(false));
745 let cancel_flag = cancelled.clone();
746 let yo_cancel = ctx.cancel.clone();
747 let _cancel_monitor: tokio::task::JoinHandle<()> = tokio::spawn(async move {
748 yo_cancel.cancelled().await;
749 cancel_flag.store(true, Ordering::SeqCst);
750 kill_process_group(pid);
751 });
752
753 if let Some(ref on_update) = ctx.on_update {
755 on_update(yoagent::types::ToolResult {
756 content: vec![],
757 details: serde_json::Value::Null,
758 });
759 }
760
761 let timeout_dur = timeout.map(std::time::Duration::from_secs);
763 let throttle_ms = 100u64;
764 let mut last_update_at = Instant::now();
765
766 let exit_code: i32;
767
768 loop {
769 if cancelled.load(Ordering::SeqCst) {
770 kill_process_group(pid);
771 read_task.abort();
772 let combined_str = combined.lock().await.clone();
773 return finish_bash_execution(&combined_str, -1, true, None, &ctx);
774 }
775
776 if let Some(dur) = timeout_dur
777 && started_at.elapsed() > dur
778 {
779 kill_process_group(pid);
780 read_task.abort();
781 let combined_str = combined.lock().await.clone();
782 return finish_bash_execution(&combined_str, -1, false, timeout, &ctx);
783 }
784
785 if last_update_at.elapsed().as_millis() as u64 >= throttle_ms {
786 let out = combined.lock().await.clone();
787 if !out.is_empty() {
788 last_update_at = Instant::now();
789 emit_update(&ctx, out, None);
790 }
791 }
792
793 match child.try_wait() {
794 Ok(Some(status)) => {
795 exit_code = status.code().unwrap_or(-1);
796 let mut last_len = combined.lock().await.len();
798 loop {
799 tokio::time::sleep(std::time::Duration::from_millis(EXIT_STDIO_GRACE_MS))
800 .await;
801 let new_len = combined.lock().await.len();
802 if new_len == last_len {
803 break;
804 }
805 last_len = new_len;
806 }
807 read_task.abort();
808 break;
809 }
810 Ok(None) => {
811 tokio::time::sleep(std::time::Duration::from_millis(throttle_ms)).await;
812 }
813 Err(_) => {
814 read_task.await.ok();
815 exit_code = -1;
816 break;
817 }
818 }
819 }
820
821 let combined_str = combined.lock().await.clone();
822 if !combined_str.is_empty() {
823 emit_update(&ctx, combined_str.clone(), None);
824 }
825
826 finish_bash_execution(&combined_str, exit_code, false, None, &ctx)
827 }
828}
829
830fn cleanup_stale_temp_files() {
833 let dir = temp_output_dir();
834 let Ok(entries) = std::fs::read_dir(&dir) else {
835 return;
836 };
837 let Ok(cutoff) = std::time::SystemTime::now()
838 .checked_sub(std::time::Duration::from_secs(TEMP_FILE_MAX_AGE_SECS))
839 .ok_or(())
840 else {
841 return;
842 };
843 for entry in entries.flatten() {
844 let path = entry.path();
845 if path.extension().is_none_or(|e| e != "log") {
846 continue;
847 }
848 if let Ok(metadata) = path.metadata()
849 && let Ok(modified) = metadata.modified()
850 && modified < cutoff
851 {
852 let _ = std::fs::remove_file(&path);
853 }
854 }
855}
856
857fn temp_output_dir() -> PathBuf {
859 std::env::temp_dir().join(BASH_TEMP_FILE_PREFIX)
860}
861
/// Appends `status_msg` after a blank line, or returns it alone when there is no real output
/// (empty text or the `(no output)` placeholder).
fn format_status_output(result_text: &str, status_msg: &str) -> String {
    match result_text {
        "" | "(no output)" => status_msg.to_owned(),
        text => format!("{}\n\n{}", text, status_msg),
    }
}
870
871fn into_tool_result(
873 text: String,
874 details: Option<serde_json::Value>,
875) -> yoagent::types::ToolResult {
876 yoagent::types::ToolResult {
877 content: vec![yoagent::types::Content::Text { text }],
878 details: details.unwrap_or(serde_json::Value::Null),
879 }
880}
881
882fn emit_update(
884 ctx: &yoagent::types::ToolContext,
885 text: String,
886 details: Option<serde_json::Value>,
887) {
888 if let Some(ref on_update) = ctx.on_update {
889 on_update(into_tool_result(text, details));
890 }
891}
892
893use std::sync::Mutex;
898
899static TRACKED_PIDS: Mutex<Vec<u32>> = std::sync::Mutex::new(Vec::new());
900
901fn track_pid(pid: u32) {
902 if let Ok(mut pids) = TRACKED_PIDS.lock() {
903 pids.push(pid);
904 }
905}
906
907fn untrack_pid(pid: u32) {
908 if let Ok(mut pids) = TRACKED_PIDS.lock() {
909 pids.retain(|&p| p != pid);
910 }
911}
912
913pub fn kill_tracked_children() {
915 let pids: Vec<u32> = TRACKED_PIDS.lock().map(|p| p.clone()).unwrap_or_default();
916 for pid in pids {
917 kill_process_group(pid);
918 }
919}
920
921struct ProcessGuard {
922 pid: u32,
923}
924
925impl ProcessGuard {
926 fn new(pid: u32) -> Self {
927 if pid > 0 {
928 track_pid(pid);
929 }
930 Self { pid }
931 }
932}
933
934impl Drop for ProcessGuard {
935 fn drop(&mut self) {
936 if self.pid > 0 {
937 untrack_pid(self.pid);
938 }
939 }
940}
941
#[cfg(test)]
mod tests {
    use super::*;
    use yoagent::AgentTool;

    /// Builds a tool context wired to `cancel`, with no update callbacks.
    fn ctx_with_cancel(
        cancel: tokio_util::sync::CancellationToken,
    ) -> yoagent::types::ToolContext {
        yoagent::types::ToolContext {
            tool_call_id: "id".into(),
            tool_name: "bash".into(),
            cancel,
            on_update: None,
            on_progress: None,
        }
    }

    /// Builds a tool context with a fresh, never-cancelled token.
    fn tool_ctx() -> yoagent::types::ToolContext {
        ctx_with_cancel(tokio_util::sync::CancellationToken::new())
    }

    /// Concatenates all text items of a tool result.
    fn yo_msg_text(content: &[yoagent::types::Content]) -> String {
        content
            .iter()
            .filter_map(|c| {
                if let yoagent::types::Content::Text { text } = c {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
            .join("")
    }

    /// A tool running in the OS temp dir with default settings.
    fn make_tool() -> BashTool {
        BashTool {
            cwd: std::env::temp_dir(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        }
    }

    /// Removes a scratch directory when dropped, even if the owning test panics.
    struct ScratchDir(std::path::PathBuf);

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[tokio::test]
    async fn runs_simple_command() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("hello"));
    }

    #[tokio::test]
    async fn captures_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo err >&2"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("err"));
    }

    #[tokio::test]
    async fn cancel_aborts() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        cancel.cancel();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10"}),
                ctx_with_cancel(cancel),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Cancelled") || err.contains("aborted"),
            "expected cancellation error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn timeout_works() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"));
    }

    #[test]
    fn test_truncate_tail_no_truncation() {
        let result = truncate_tail("hello\nworld\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello\nworld\n");
    }

    #[test]
    fn test_truncate_tail_by_lines() {
        let content: String = (1..=5000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.starts_with("line 3001"));
        assert_eq!(result.content.lines().count(), 2000);
    }

    #[test]
    fn test_truncate_tail_by_bytes() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
        assert!(result.content.lines().count() < 100);
    }

    #[test]
    fn test_truncate_tail_partial_last_line() {
        let content = format!("short\n{}\n", "x".repeat(60000));
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(!result.content.starts_with("short"));
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_empty() {
        let result = truncate_tail("", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "");
    }

    #[tokio::test]
    async fn exit_code_nonzero() {
        let tool = make_tool();
        let result = tool
            .execute(serde_json::json!({"command": "exit 42"}), tool_ctx())
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exited with code 42"), "got: {}", err);
    }

    #[tokio::test]
    async fn exit_code_with_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo before && exit 1"}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("before"), "got: {}", err);
        assert!(err.contains("exited with code 1"), "got: {}", err);
    }

    #[tokio::test]
    async fn no_output() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "true"}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("(no output)"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn combined_stdout_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(
                serde_json::json!({"command": "echo out; echo err >&2"}),
                tool_ctx(),
            )
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("out"),
            "got: {}",
            yo_msg_text(&output.content)
        );
        assert!(
            yo_msg_text(&output.content).contains("err"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn runs_in_cwd() {
        let tmp = std::env::temp_dir().join(format!("rab-bash-cwd-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&tmp).unwrap();
        // Remove the scratch directory even if an assertion below fails.
        let _cleanup = ScratchDir(tmp.clone());
        std::fs::write(tmp.join("marker.txt"), "hello").unwrap();

        let tool = BashTool {
            cwd: tmp.clone(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        };
        let output = tool
            .execute(serde_json::json!({"command": "cat marker.txt"}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("hello"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn missing_command_errors() {
        let tool = make_tool();
        let result = tool.execute(serde_json::json!({}), tool_ctx()).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("command"), "got: {}", err);
    }

    #[tokio::test]
    async fn timeout_with_partial_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo start && sleep 10 && echo end", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"), "got: {}", err);
    }

    #[tokio::test]
    async fn cancel_during_long_command() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        let cancel_ctx = cancel.clone();

        let handle = tokio::spawn(async move {
            tool.execute(
                serde_json::json!({"command": "sleep 30"}),
                ctx_with_cancel(cancel_ctx),
            )
            .await
        });

        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("aborted") || err.contains("Cancelled"),
            "expected cancellation error, got: {}",
            err
        );
    }

    #[test]
    fn test_truncate_tail_exact_line_fit() {
        let lines: String = (1..=2000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(!result.truncated);
        assert!(result.content.lines().count() == 2000);
    }

    #[test]
    fn test_truncate_tail_one_over_line_limit() {
        let lines: String = (1..=2001).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.content.lines().count(), 2000);
        assert!(result.content.starts_with("line 2"));
    }

    #[test]
    fn test_truncate_tail_exact_byte_fit() {
        let line = "a".repeat(50000);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(!result.truncated);
    }

    #[test]
    fn test_truncate_tail_one_byte_over() {
        let line = "a".repeat(50001);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_single_line_under_limit() {
        let result = truncate_tail("hello world", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello world");
    }

    #[test]
    fn test_truncate_tail_trailing_newline() {
        let result = truncate_tail("a\nb\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb\n");
    }

    #[test]
    fn test_truncate_tail_no_trailing_newline() {
        let result = truncate_tail("a\nb", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb");
    }

    #[test]
    fn test_truncate_tail_single_line_exceeds_limit() {
        let content = "x".repeat(60000);
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.last_line_partial);
        assert_eq!(result.content.len(), 50000);
        assert!(result.content.ends_with("x".repeat(50000).as_str()));
    }

    #[test]
    fn test_truncate_tail_byte_count_respects_newlines() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.output_bytes <= 50000);
    }

    #[tokio::test]
    async fn truncated_by_lines_shows_footer() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("Showing lines"),
            "got: {}",
            yo_msg_text(&output.content)
        );
        assert!(
            yo_msg_text(&output.content).contains("Full output:"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn small_output_no_footer() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(!yo_msg_text(&output.content).contains("Output truncated"));
        assert!(!yo_msg_text(&output.content).contains("Full output:"));
    }

    #[tokio::test]
    async fn truncated_saves_temp_file() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("/rab-bash/"),
            "expected temp file path with /rab-bash/, got: {}",
            yo_msg_text(&output.content)
        );
    }

    /// Smoke test: cleanup must not panic whether or not the temp output dir exists.
    #[test]
    fn test_cleanup_stale_temp_files_nonexistent_dir() {
        cleanup_stale_temp_files();
    }

    #[test]
    fn test_truncate_tail_many_short_lines() {
        let content: String = (1..=10000).map(|i| format!("{}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "lines");
        assert_eq!(result.output_lines, 2000);
        assert!(
            result.content.starts_with("8001"),
            "starts with: {:?}",
            &result.content[..10]
        );
    }

    #[test]
    fn test_truncate_tail_lines_and_bytes_both_exceeded() {
        let content: String = (1..=5000)
            .map(|i| format!("line {} {}\n", i, "x".repeat(100)))
            .collect();
        let result = truncate_tail(&content, 2000, 30000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "bytes");
        assert!(result.output_lines < 2000);
    }

    #[test]
    fn test_process_guard_tracks_pid() {
        let pid = 12345u32;
        {
            let _guard = ProcessGuard::new(pid);
            let pids = TRACKED_PIDS.lock().unwrap();
            assert!(pids.contains(&pid));
        }
        let pids = TRACKED_PIDS.lock().unwrap();
        assert!(!pids.contains(&pid));
    }

    #[test]
    fn test_process_guard_zero_pid() {
        {
            let _guard = ProcessGuard::new(0);
            let pids = TRACKED_PIDS.lock().unwrap();
            assert!(!pids.contains(&0));
        }
    }
}