1use crate::agent::extension::{Cancel, Extension, ToolDefinition};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::tui::Theme;
4use crate::tui::ThemeKey;
5use crate::tui::visual_truncate::truncate_to_visual_lines;
6use async_trait::async_trait;
7
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::time::Instant;
14use tokio::sync::{Mutex as TokioMutex, mpsc::UnboundedSender};
15
/// Pluggable backend that runs a shell command on behalf of the bash tool.
///
/// When [`BashToolOptions::operations`] holds an implementation, the tool
/// delegates execution to it instead of spawning a local shell process.
#[async_trait]
pub trait BashOperations: Send + Sync {
    /// Runs `command` with `cwd` as its working directory.
    ///
    /// * `on_data` - receives output chunks; the tool appends them in arrival
    ///   order to build the combined output.
    /// * `signal` - optional cancellation handle; the tool cancels it when the
    ///   tool call itself is cancelled.
    /// * `timeout` - optional limit forwarded from the tool's `timeout`
    ///   argument (described to the model as seconds).
    /// * `env` - optional environment variables for the command
    ///   (NOTE(review): merge-vs-replace semantics are up to the implementor).
    ///
    /// Returns the command's exit code when one is available; the tool treats
    /// `None` as `-1`.
    async fn exec(
        &self,
        command: &str,
        cwd: &Path,
        on_data: UnboundedSender<String>,
        signal: Option<&Cancel>,
        timeout: Option<u64>,
        env: Option<HashMap<String, String>>,
    ) -> Result<Option<i32>, anyhow::Error>;
}
34
/// Optional configuration for the bash tool.
#[derive(Clone, Default)]
pub struct BashToolOptions {
    /// Custom execution backend; when `None`, commands are spawned as local
    /// shell processes.
    pub operations: Option<Arc<dyn BashOperations>>,
    /// Text placed on its own line in front of every command before it runs.
    pub command_prefix: Option<String>,
    /// Explicit shell binary to use; when `None`, a shell is auto-detected.
    pub shell_path: Option<String>,
}
44
/// Extension that contributes the `bash` tool to the agent.
pub struct BashExtension {
    /// Working directory in which commands are executed.
    cwd: PathBuf,
    /// Options copied into every tool instance this extension creates.
    options: BashToolOptions,
}
49
50impl BashExtension {
51 pub fn new(cwd: PathBuf) -> Self {
52 Self {
53 cwd,
54 options: BashToolOptions::default(),
55 }
56 }
57
58 pub fn with_options(cwd: PathBuf, options: BashToolOptions) -> Self {
59 Self { cwd, options }
60 }
61
62 pub fn with_shell_path(cwd: PathBuf, shell_path: String) -> Self {
63 Self {
64 cwd,
65 options: BashToolOptions {
66 shell_path: Some(shell_path),
67 ..BashToolOptions::default()
68 },
69 }
70 }
71}
72
73impl Extension for BashExtension {
74 fn name(&self) -> Cow<'static, str> {
75 "bash".into()
76 }
77
78 fn tools(&self) -> Vec<ToolDefinition> {
79 vec![ToolDefinition {
80 tool: Box::new(BashTool {
81 cwd: self.cwd.clone(),
82 shell_path: self.options.shell_path.clone(),
83 command_prefix: self.options.command_prefix.clone(),
84 operations: self.options.operations.clone(),
85 }),
86 snippet: "Execute bash commands (ls, grep, find, etc.)",
87 guidelines: &[],
88 prepare_arguments: None,
89 before_tool_call: None,
90 after_tool_call: None,
91 renderer: Some(std::sync::Arc::new(BashRenderer)),
92 }]
93 }
94}
95
/// The `bash` tool instance handed to the agent runtime.
struct BashTool {
    /// Working directory for every command.
    cwd: PathBuf,
    /// Explicit shell binary; `None` means auto-detect.
    shell_path: Option<String>,
    /// Optional text prepended (on its own line) to every command.
    command_prefix: Option<String>,
    /// Optional custom execution backend used instead of a local process.
    operations: Option<Arc<dyn BashOperations>>,
}
102
/// Maximum number of trailing output lines returned to the caller.
const DEFAULT_MAX_LINES: usize = 2000;
/// Maximum size in bytes (50 KiB) of the returned output tail.
const DEFAULT_MAX_BYTES: usize = 50 * 1024;
/// Directory name, under the system temp dir, that holds full-output logs.
const BASH_TEMP_FILE_PREFIX: &str = "rab-bash";

/// Age in seconds (24 hours) after which saved output logs are deleted.
const TEMP_FILE_MAX_AGE_SECS: u64 = 24 * 60 * 60;

/// Polling interval in milliseconds used after the child exits to detect
/// whether output is still arriving on its pipes.
const EXIT_STDIO_GRACE_MS: u64 = 100;
115
/// Program and leading arguments used to launch a shell command.
struct ShellConfig {
    /// Shell executable (absolute path or a name resolved via `PATH`).
    shell: String,
    /// Arguments placed before the command string.
    args: Vec<String>,
}

/// Picks the shell used to run commands.
///
/// Preference order: the explicit `shell_path`, then `/bin/bash`, then (on
/// Unix) whatever `which bash` reports, and finally plain `sh`. Every choice
/// receives the command through `-c`.
fn resolve_shell(shell_path: Option<&str>) -> ShellConfig {
    const SYSTEM_BASH: &str = "/bin/bash";

    // All candidates are invoked the same way; only the binary differs.
    let config_for = |shell: String| ShellConfig {
        shell,
        args: vec![String::from("-c")],
    };

    if let Some(explicit) = shell_path {
        return config_for(explicit.to_owned());
    }

    if std::path::Path::new(SYSTEM_BASH).exists() {
        return config_for(SYSTEM_BASH.to_owned());
    }

    #[cfg(unix)]
    {
        let lookup = std::process::Command::new("which")
            .arg("bash")
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::null())
            .output();

        match lookup {
            Ok(found) if found.status.success() => {
                let candidate = String::from_utf8_lossy(&found.stdout).trim().to_string();
                if !candidate.is_empty() && std::path::Path::new(&candidate).exists() {
                    return config_for(candidate);
                }
            }
            _ => {}
        }
    }

    // Last resort: rely on `sh` being resolvable through PATH.
    config_for("sh".to_owned())
}
171
/// Asks the `kill` utility to signal the process group whose id is `pid`.
///
/// A `pid` of zero is ignored. Failures to run `kill` are ignored as well;
/// this is a best-effort cleanup.
#[cfg(unix)]
fn kill_process_group(pid: u32) {
    if pid == 0 {
        return;
    }
    // A negative target addresses the whole process group; `--` keeps it from
    // being parsed as an option.
    let group_target = format!("-{}", pid);
    let _ = std::process::Command::new("kill")
        .arg("--")
        .arg(group_target)
        .status();
}

/// No-op on platforms without Unix process groups.
#[cfg(not(unix))]
fn kill_process_group(_pid: u32) {}
189
190fn spawn_bash_command(
192 command: &str,
193 cwd: &std::path::Path,
194 shell_path: Option<&str>,
195) -> std::io::Result<tokio::process::Child> {
196 let shell_cfg = resolve_shell(shell_path);
197
198 #[cfg(unix)]
199 {
200 use std::os::unix::process::CommandExt;
201 let mut std_cmd = std::process::Command::new(&shell_cfg.shell);
202 std_cmd.args(&shell_cfg.args).arg(command).current_dir(cwd);
203 unsafe {
204 std_cmd.pre_exec(|| {
205 libc::setpgid(0, 0);
206 Ok(())
207 });
208 }
209 let mut tokio_cmd = tokio::process::Command::from(std_cmd);
210 tokio_cmd
211 .stdin(std::process::Stdio::null())
212 .stdout(std::process::Stdio::piped())
213 .stderr(std::process::Stdio::piped())
214 .spawn()
215 }
216 #[cfg(not(unix))]
217 {
218 tokio::process::Command::new(&shell_cfg.shell)
219 .args(&shell_cfg.args)
220 .arg(command)
221 .current_dir(cwd)
222 .stdin(std::process::Stdio::null())
223 .stdout(std::process::Stdio::piped())
224 .stderr(std::process::Stdio::piped())
225 .spawn()
226 }
227}
228
/// Removes terminal control sequences and unprintable control characters.
///
/// Dropped:
/// * CSI sequences (`ESC [ ... final`, or the single-character CSI U+009B),
/// * control strings such as OSC (`ESC ] ...`) up to their terminator
///   (BEL, `ESC \` or U+009C),
/// * other short escapes (`ESC 7`, `ESC ( B`, ...),
/// * C0 control characters other than tab, line feed and carriage return,
/// * the interlinear annotation characters U+FFF9..=U+FFFB.
///
/// A malformed sequence ends at the first character that cannot belong to it,
/// and that character is then treated as ordinary text. A line feed ends an
/// unterminated control string, so a stray introducer cannot swallow the rest
/// of the output.
///
/// Parser state is not kept between calls: a sequence split across two
/// separately sanitized chunks is only recognised up to the chunk boundary.
fn sanitize_output(text: &str) -> String {
    #[derive(Clone, Copy)]
    enum State {
        /// Ordinary text.
        Text,
        /// Just saw ESC.
        Escape,
        /// Inside `ESC` + intermediate bytes, waiting for the final byte.
        EscapeIntermediate,
        /// Inside a CSI sequence, waiting for the final byte.
        Csi,
        /// Inside an OSC/DCS/SOS/PM/APC string, waiting for its terminator.
        ControlString,
        /// Saw ESC inside a control string (possible `ESC \` terminator).
        ControlStringEscape,
    }

    /// Whether a character outside any escape sequence is kept.
    fn is_kept(c: char) -> bool {
        let code = c as u32;
        let dropped_control = code <= 0x1f && !matches!(c, '\t' | '\n' | '\r');
        let interlinear = (0xfff9..=0xfffb).contains(&code);
        !(dropped_control || interlinear)
    }

    let mut result = String::with_capacity(text.len());
    let mut state = State::Text;

    for c in text.chars() {
        // A state may hand the current character back (by switching state
        // without `break`) so it is re-examined; every path reaches
        // `State::Text`, which always consumes it.
        loop {
            match state {
                State::Text => {
                    match c {
                        '\x1b' => state = State::Escape,
                        '\u{9b}' => state = State::Csi,
                        _ => {
                            if is_kept(c) {
                                result.push(c);
                            }
                        }
                    }
                    break;
                }
                State::Escape => match c {
                    '\x1b' => break,
                    '[' => {
                        state = State::Csi;
                        break;
                    }
                    ']' | 'P' | 'X' | '^' | '_' => {
                        state = State::ControlString;
                        break;
                    }
                    '\x20'..='\x2f' => {
                        state = State::EscapeIntermediate;
                        break;
                    }
                    '\x30'..='\x7e' => {
                        state = State::Text;
                        break;
                    }
                    _ => state = State::Text,
                },
                State::EscapeIntermediate => match c {
                    '\x1b' => {
                        state = State::Escape;
                        break;
                    }
                    '\x20'..='\x2f' => break,
                    '\x30'..='\x7e' => {
                        state = State::Text;
                        break;
                    }
                    _ => state = State::Text,
                },
                State::Csi => match c {
                    '\x1b' => {
                        state = State::Escape;
                        break;
                    }
                    // Parameter and intermediate bytes.
                    '\x20'..='\x3f' => break,
                    // Final byte (letters, `~`, ...).
                    '\x40'..='\x7e' => {
                        state = State::Text;
                        break;
                    }
                    _ => state = State::Text,
                },
                State::ControlString => match c {
                    '\x07' | '\u{9c}' => {
                        state = State::Text;
                        break;
                    }
                    '\x1b' => {
                        state = State::ControlStringEscape;
                        break;
                    }
                    // Safety valve: never let an unterminated string eat lines.
                    '\n' => state = State::Text,
                    _ => break,
                },
                State::ControlStringEscape => match c {
                    '\\' => {
                        state = State::Text;
                        break;
                    }
                    // Not a string terminator: the ESC starts a new sequence.
                    _ => state = State::Escape,
                },
            }
        }
    }

    result
}
258
/// Renders a byte count as `B`, `KB` or `MB` using 1024-based units; the
/// scaled units are shown with one decimal place.
fn format_size(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = KIB * KIB;

    match bytes {
        b if b < KIB => format!("{}B", b),
        b if b < MIB => format!("{:.1}KB", b as f64 / 1024.0),
        b => format!("{:.1}MB", b as f64 / (1024.0 * 1024.0)),
    }
}
268
/// Outcome of [`truncate_tail`].
struct TailTruncation {
    /// The retained tail (or the whole input when nothing was cut).
    content: String,
    /// Whether the input exceeded a limit.
    truncated: bool,
    /// Number of lines in the original input.
    total_lines: usize,
    /// Number of lines present in `content`.
    output_lines: usize,
    /// Size of the retained tail in bytes.
    output_bytes: usize,
    /// Which limit applied: `"lines"`, `"bytes"`, or `""` when not truncated.
    truncated_by: &'static str,
    /// True when a single over-long line was itself cut to fit `max_bytes`.
    last_line_partial: bool,
}

/// Keeps the end of `content`, limited to `max_lines` lines and `max_bytes`
/// bytes.
///
/// Whole lines are collected from the end until either limit would be
/// exceeded. If even the final line is larger than `max_bytes`, only its
/// trailing bytes are kept; the cut is moved forward to the next UTF-8
/// character boundary so a multi-byte character is never split (the result may
/// then be a few bytes shorter than `max_bytes`).
///
/// Input within both limits is returned unchanged, trailing newline included.
fn truncate_tail(content: &str, max_lines: usize, max_bytes: usize) -> TailTruncation {
    let total_bytes = content.len();
    let lines: Vec<&str> = content.lines().collect();
    let total_lines = lines.len();

    if total_lines <= max_lines && total_bytes <= max_bytes {
        return TailTruncation {
            content: content.to_string(),
            truncated: false,
            total_lines,
            output_lines: total_lines,
            output_bytes: total_bytes,
            truncated_by: "",
            last_line_partial: false,
        };
    }

    // Collected newest-first, reversed before joining.
    let mut output: Vec<&str> = Vec::new();
    let mut byte_count: usize = 0;
    let mut truncated_by = "lines";
    let mut last_line_partial = false;

    for line in lines.iter().rev().take(max_lines) {
        let line_bytes = line.len();
        // Every line except the last one costs an extra joining newline.
        let with_newline = if output.is_empty() {
            line_bytes
        } else {
            line_bytes + 1
        };

        if byte_count + with_newline > max_bytes {
            truncated_by = "bytes";
            if output.is_empty() {
                // The final line alone is too large: keep only its tail.
                let mut start = line.len().saturating_sub(max_bytes);
                // `start` may point into the middle of a multi-byte character;
                // slicing there would panic. `line.len()` is always a
                // boundary, so this loop terminates.
                while !line.is_char_boundary(start) {
                    start += 1;
                }
                let truncated_line = &line[start..];
                output.push(truncated_line);
                byte_count = truncated_line.len();
                last_line_partial = true;
            }
            break;
        }

        output.push(line);
        byte_count += with_newline;
    }

    if output.len() >= max_lines && byte_count <= max_bytes {
        truncated_by = "lines";
    }

    output.reverse();
    TailTruncation {
        content: output.join("\n"),
        truncated: true,
        total_lines,
        output_lines: output.len(),
        output_bytes: byte_count,
        truncated_by,
        last_line_partial,
    }
}
342
343fn finish_bash_execution(
346 combined: &str,
347 exit_code: i32,
348 cancelled: bool,
349 timed_out: Option<u64>,
350 ctx: &yoagent::types::ToolContext,
351) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
352 let trunc = truncate_tail(combined, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
353
354 let mut result_text = if trunc.content.is_empty() {
355 "(no output)".to_string()
356 } else {
357 trunc.content.clone()
358 };
359
360 let full_output_path = if trunc.truncated {
362 let tmp_dir = temp_output_dir();
363 let _ = std::fs::create_dir_all(&tmp_dir);
364 let tmp_path = tmp_dir.join(format!("{}.log", uuid::Uuid::new_v4()));
365 let saved = std::fs::write(&tmp_path, combined).ok().map(|_| {
366 cleanup_stale_temp_files();
367 tmp_path
368 });
369
370 let start_line = trunc.total_lines - trunc.output_lines + 1;
371 let end_line = trunc.total_lines;
372
373 let notice = if trunc.truncated_by == "lines" {
374 format!(
375 "\n\n[Showing lines {}-{} of {}. Full output: {}]",
376 start_line,
377 end_line,
378 trunc.total_lines,
379 saved
380 .as_ref()
381 .map(|p| p.display().to_string())
382 .unwrap_or_default()
383 )
384 } else {
385 format!(
386 "\n\n[Showing lines {}-{} of {} ({} limit). Full output: {}]",
387 start_line,
388 end_line,
389 trunc.total_lines,
390 format_size(DEFAULT_MAX_BYTES),
391 saved
392 .as_ref()
393 .map(|p| p.display().to_string())
394 .unwrap_or_default()
395 )
396 };
397 result_text.push_str(¬ice);
398 saved
399 } else {
400 None
401 };
402
403 let details = if trunc.truncated || full_output_path.is_some() {
405 Some(serde_json::json!({
406 "truncation": {
407 "truncated": trunc.truncated,
408 "truncatedBy": trunc.truncated_by,
409 "totalLines": trunc.total_lines,
410 "outputLines": trunc.output_lines,
411 "outputBytes": trunc.output_bytes,
412 "lastLinePartial": trunc.last_line_partial,
413 "maxLines": DEFAULT_MAX_LINES,
414 "maxBytes": DEFAULT_MAX_BYTES,
415 },
416 "fullOutputPath": full_output_path.as_ref().map(|p| p.display().to_string()),
417 }))
418 } else {
419 None
420 };
421
422 let final_output = if cancelled {
423 format_status_output(&result_text, "Command aborted")
424 } else if let Some(secs) = timed_out {
425 format_status_output(
426 &result_text,
427 &format!("Command timed out after {} seconds", secs),
428 )
429 } else if exit_code != 0 {
430 format_status_output(
431 &result_text,
432 &format!("Command exited with code {}", exit_code),
433 )
434 } else {
435 emit_update(ctx, result_text.clone(), details.clone());
436 return Ok(into_tool_result(result_text, details));
437 };
438
439 emit_update(ctx, final_output.clone(), details.clone());
440 Err(yoagent::types::ToolError::Failed(final_output))
441}
442
443struct BashRenderer;
448
449impl ToolRenderer for BashRenderer {
450 fn render_call(
451 &self,
452 args: &serde_json::Value,
453 _width: usize,
454 theme: &dyn Theme,
455 _ctx: &ToolRenderContext,
456 ) -> Vec<String> {
457 let cmd = args
458 .get("command")
459 .and_then(|v| v.as_str())
460 .unwrap_or("...");
461 let timeout = args.get("timeout").and_then(|v| v.as_i64());
462 let timeout_suffix = timeout
463 .map(|t| theme.fg_key(ThemeKey::Muted, &format!(" (timeout {}s)", t)))
464 .unwrap_or_default();
465
466 vec![format!(
467 "{}{}",
468 theme.fg_key(ThemeKey::ToolTitle, &theme.bold(&format!("$ {}", cmd))),
469 timeout_suffix
470 )]
471 }
472
473 fn render_result(
474 &self,
475 content: &str,
476 width: usize,
477 theme: &dyn Theme,
478 ctx: &ToolRenderContext,
479 ) -> Vec<String> {
480 let mut lines: Vec<String> = Vec::new();
481
482 let clean = strip_context_truncation_footer(content)
483 .trim_end()
484 .to_string();
485 let all_lines: Vec<&str> = clean.lines().collect();
486
487 if all_lines.is_empty() || (all_lines.len() == 1 && all_lines[0].is_empty()) {
488 return lines;
489 }
490
491 let preview_count = 5;
492 let (preview_lines, hidden_line_count) = if ctx.expanded {
493 (all_lines.clone(), 0)
494 } else {
495 truncate_to_visual_lines(&all_lines, width, preview_count)
496 };
497
498 if !ctx.expanded && hidden_line_count > 0 {
500 if ctx.expand_key.is_empty() {
501 lines.push(theme.fg_key(
502 ThemeKey::Muted,
503 &format!("... {} earlier lines", hidden_line_count),
504 ));
505 } else {
506 let prefix = theme.fg_key(
509 ThemeKey::Muted,
510 &format!("... ({} earlier lines, ", hidden_line_count),
511 );
512 let key_styled = theme.fg("dim", &ctx.expand_key);
513 let suffix = theme.fg_key(ThemeKey::Muted, " to expand)");
514 lines.push(format!("{}{}{}", prefix, key_styled, suffix));
515 }
516 }
517
518 let fg_key = if ctx.is_error { "error" } else { "toolOutput" };
519 for line in &preview_lines {
520 if line.is_empty() {
521 lines.push(String::new());
522 } else {
523 lines.push(theme.fg(fg_key, line));
524 }
525 }
526
527 if let Some(secs) = ctx.duration_secs {
528 if !lines.is_empty() {
529 lines.push(String::new());
530 }
531 let is_complete = ctx.exit_code.is_some() || ctx.cancelled;
532 let label = if is_complete { "Took" } else { "Elapsed" };
533 lines.push(theme.fg_key(ThemeKey::Muted, &format!("{} {:.1}s", label, secs)));
534 }
535
536 if ctx.was_truncated {
537 if !lines.is_empty() {
538 lines.push(String::new());
539 }
540 if let Some(ref path) = ctx.full_output_path {
541 lines.push(theme.fg(
542 "warning",
543 &format!("Output truncated. Full output: {}", path),
544 ));
545 } else {
546 lines.push(theme.fg_key(ThemeKey::Warning, "Output truncated."));
547 }
548 }
549
550 lines
551 }
552}
553
/// Removes the trailing `[Showing ... Full output: ...]` notice, plus the
/// blank line directly above it, from tool output.
///
/// Output with fewer than three lines, or whose last line is not such a
/// notice, is returned unchanged. When the notice is stripped the remaining
/// lines are re-joined with `\n`.
fn strip_context_truncation_footer(output: &str) -> String {
    let lines: Vec<&str> = output.lines().collect();
    if lines.len() < 3 {
        return output.to_string();
    }

    let (footer, body) = match lines.split_last() {
        Some((last, rest)) => (last.trim(), rest),
        None => return output.to_string(),
    };

    let mentions_range = footer.contains("Showing lines") || footer.contains("Showing last");
    let is_footer = footer.starts_with('[') && mentions_range && footer.contains("Full output:");
    if !is_footer {
        return output.to_string();
    }

    // Drop the empty separator line that precedes the notice, if present.
    let body = match body.split_last() {
        Some((separator, rest)) if separator.is_empty() => rest,
        _ => body,
    };
    body.join("\n")
}
574
575#[async_trait::async_trait]
576impl yoagent::types::AgentTool for BashTool {
577 fn name(&self) -> &str {
578 "bash"
579 }
580 fn label(&self) -> &str {
581 "bash"
582 }
583 fn description(&self) -> &str {
584 "Execute a bash command in the current working directory. Returns stdout and stderr. \
585 Output is truncated to last 2000 lines or 50KB (whichever is hit first). If \
586 truncated, full output is saved to a temp file. Optionally provide a timeout in seconds."
587 }
588 fn parameters_schema(&self) -> serde_json::Value {
589 serde_json::json!({
590 "type": "object",
591 "required": ["command"],
592 "properties": {
593 "command": {
594 "type": "string",
595 "description": "Bash command to execute"
596 },
597 "timeout": {
598 "type": "number",
599 "description": "Timeout in seconds (optional, no default timeout)"
600 }
601 }
602 })
603 }
604 async fn execute(
605 &self,
606 params: serde_json::Value,
607 ctx: yoagent::types::ToolContext,
608 ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
609 let command = params["command"].as_str().ok_or_else(|| {
610 yoagent::types::ToolError::InvalidArgs("Missing 'command' argument".into())
611 })?;
612 let timeout = params["timeout"].as_u64();
613 let started_at = Instant::now();
614
615 if ctx.cancel.is_cancelled() {
616 return Err(yoagent::types::ToolError::Cancelled);
617 }
618
619 let effective_command = if let Some(ref prefix) = self.command_prefix {
621 format!("{}\n{}", prefix, command)
622 } else {
623 command.to_string()
624 };
625
626 if !self.cwd.exists() {
628 return Err(yoagent::types::ToolError::Failed(format!(
629 "Working directory does not exist: {}\nCannot execute bash commands.",
630 self.cwd.display()
631 )));
632 }
633
634 if let Some(ref ops) = self.operations {
636 let (output_tx, mut output_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
637 let ops_cancel = Cancel::new();
638
639 let yo_cancel = ctx.cancel.clone();
641 let watch_cancel = ops_cancel.clone();
642 tokio::spawn(async move {
643 yo_cancel.cancelled().await;
644 watch_cancel.cancel();
645 });
646
647 let ops_command = effective_command.clone();
648 let ops_cwd = self.cwd.clone();
649 let ops = ops.clone();
650 let ops_handle = tokio::spawn(async move {
651 ops.exec(
652 &ops_command,
653 &ops_cwd,
654 output_tx,
655 Some(&ops_cancel),
656 timeout,
657 None,
658 )
659 .await
660 });
661
662 let mut combined = String::new();
664 while let Some(chunk) = output_rx.recv().await {
665 combined.push_str(&chunk);
666 emit_update(&ctx, combined.clone(), None);
667 }
668
669 let exit_code = ops_handle.await.unwrap_or(Ok(None)).unwrap_or(None);
670 let code = exit_code.unwrap_or(-1);
671
672 return finish_bash_execution(&combined, code, ctx.cancel.is_cancelled(), None, &ctx);
673 }
674
675 let mut child =
676 spawn_bash_command(&effective_command, &self.cwd, self.shell_path.as_deref()).map_err(
677 |e| yoagent::types::ToolError::Failed(format!("Failed to spawn command: {}", e)),
678 )?;
679
680 let pid = child.id().unwrap_or(0);
681
682 let combined = Arc::new(TokioMutex::new(String::new()));
684 let combined_clone = combined.clone();
685
686 let stdout_pipe = child
687 .stdout
688 .take()
689 .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stdout".into()))?;
690 let stderr_pipe = child
691 .stderr
692 .take()
693 .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stderr".into()))?;
694
695 use tokio::io::AsyncReadExt;
696 let read_task = tokio::spawn(async move {
697 let mut stdout_buf = vec![0u8; 65536];
698 let mut stderr_buf = vec![0u8; 65536];
699 let mut stdout_reader = stdout_pipe;
700 let mut stderr_reader = stderr_pipe;
701 let mut stdout_done = false;
702 let mut stderr_done = false;
703 loop {
704 tokio::select! {
705 result = stdout_reader.read(&mut stdout_buf), if !stdout_done => {
706 match result {
707 Ok(0) => stdout_done = true,
708 Ok(n) => {
709 let text = String::from_utf8_lossy(&stdout_buf[..n]);
710 let sanitized = sanitize_output(&text);
711 let mut out = combined_clone.lock().await;
712 out.push_str(&sanitized);
713 }
714 Err(_) => stdout_done = true,
715 }
716 }
717 result = stderr_reader.read(&mut stderr_buf), if !stderr_done => {
718 match result {
719 Ok(0) => stderr_done = true,
720 Ok(n) => {
721 let text = String::from_utf8_lossy(&stderr_buf[..n]);
722 let sanitized = sanitize_output(&text);
723 let mut out = combined_clone.lock().await;
724 out.push_str(&sanitized);
725 }
726 Err(_) => stderr_done = true,
727 }
728 }
729 }
730 if stdout_done && stderr_done {
731 break;
732 }
733 }
734 });
735
736 let _pid_guard = ProcessGuard::new(pid);
738
739 let cancelled = Arc::new(AtomicBool::new(false));
741 let cancel_flag = cancelled.clone();
742 let yo_cancel = ctx.cancel.clone();
743 let _cancel_monitor: tokio::task::JoinHandle<()> = tokio::spawn(async move {
744 yo_cancel.cancelled().await;
745 cancel_flag.store(true, Ordering::SeqCst);
746 kill_process_group(pid);
747 });
748
749 if let Some(ref on_update) = ctx.on_update {
751 on_update(yoagent::types::ToolResult {
752 content: vec![],
753 details: serde_json::Value::Null,
754 });
755 }
756
757 let timeout_dur = timeout.map(std::time::Duration::from_secs);
759 let throttle_ms = 100u64;
760 let mut last_update_at = Instant::now();
761
762 let exit_code: i32;
763
764 loop {
765 if cancelled.load(Ordering::SeqCst) {
766 kill_process_group(pid);
767 read_task.abort();
768 let combined_str = combined.lock().await.clone();
769 return finish_bash_execution(&combined_str, -1, true, None, &ctx);
770 }
771
772 if let Some(dur) = timeout_dur
773 && started_at.elapsed() > dur
774 {
775 kill_process_group(pid);
776 read_task.abort();
777 let combined_str = combined.lock().await.clone();
778 return finish_bash_execution(&combined_str, -1, false, timeout, &ctx);
779 }
780
781 if last_update_at.elapsed().as_millis() as u64 >= throttle_ms {
782 let out = combined.lock().await.clone();
783 if !out.is_empty() {
784 last_update_at = Instant::now();
785 emit_update(&ctx, out, None);
786 }
787 }
788
789 match child.try_wait() {
790 Ok(Some(status)) => {
791 exit_code = status.code().unwrap_or(-1);
792 let mut last_len = combined.lock().await.len();
794 loop {
795 tokio::time::sleep(std::time::Duration::from_millis(EXIT_STDIO_GRACE_MS))
796 .await;
797 let new_len = combined.lock().await.len();
798 if new_len == last_len {
799 break;
800 }
801 last_len = new_len;
802 }
803 read_task.abort();
804 break;
805 }
806 Ok(None) => {
807 tokio::time::sleep(std::time::Duration::from_millis(throttle_ms)).await;
808 }
809 Err(_) => {
810 read_task.await.ok();
811 exit_code = -1;
812 break;
813 }
814 }
815 }
816
817 let combined_str = combined.lock().await.clone();
818 if !combined_str.is_empty() {
819 emit_update(&ctx, combined_str.clone(), None);
820 }
821
822 finish_bash_execution(&combined_str, exit_code, false, None, &ctx)
823 }
824}
825
826fn cleanup_stale_temp_files() {
829 let dir = temp_output_dir();
830 let Ok(entries) = std::fs::read_dir(&dir) else {
831 return;
832 };
833 let Ok(cutoff) = std::time::SystemTime::now()
834 .checked_sub(std::time::Duration::from_secs(TEMP_FILE_MAX_AGE_SECS))
835 .ok_or(())
836 else {
837 return;
838 };
839 for entry in entries.flatten() {
840 let path = entry.path();
841 if path.extension().is_none_or(|e| e != "log") {
842 continue;
843 }
844 if let Ok(metadata) = path.metadata()
845 && let Ok(modified) = metadata.modified()
846 && modified < cutoff
847 {
848 let _ = std::fs::remove_file(&path);
849 }
850 }
851}
852
853fn temp_output_dir() -> PathBuf {
855 std::env::temp_dir().join(BASH_TEMP_FILE_PREFIX)
856}
857
/// Appends `status_msg` to the command output, separated by a blank line.
///
/// When there is no real output (empty text or the `(no output)` placeholder)
/// the status message is returned on its own.
fn format_status_output(result_text: &str, status_msg: &str) -> String {
    match result_text {
        "" | "(no output)" => status_msg.to_owned(),
        body => [body, status_msg].join("\n\n"),
    }
}
866
867fn into_tool_result(
869 text: String,
870 details: Option<serde_json::Value>,
871) -> yoagent::types::ToolResult {
872 yoagent::types::ToolResult {
873 content: vec![yoagent::types::Content::Text { text }],
874 details: details.unwrap_or(serde_json::Value::Null),
875 }
876}
877
878fn emit_update(
880 ctx: &yoagent::types::ToolContext,
881 text: String,
882 details: Option<serde_json::Value>,
883) {
884 if let Some(ref on_update) = ctx.on_update {
885 on_update(into_tool_result(text, details));
886 }
887}
888
889use std::sync::Mutex;
894
/// Pids of the shell processes that are currently running.
static TRACKED_PIDS: Mutex<Vec<u32>> = Mutex::new(Vec::new());

/// Registers `pid` as running. Does nothing if the registry lock is poisoned.
fn track_pid(pid: u32) {
    let Ok(mut tracked) = TRACKED_PIDS.lock() else {
        return;
    };
    tracked.push(pid);
}

/// Removes every registration of `pid`. Does nothing if the registry lock is
/// poisoned.
fn untrack_pid(pid: u32) {
    let Ok(mut tracked) = TRACKED_PIDS.lock() else {
        return;
    };
    tracked.retain(|&known| known != pid);
}
908
909pub fn kill_tracked_children() {
911 let pids: Vec<u32> = TRACKED_PIDS.lock().map(|p| p.clone()).unwrap_or_default();
912 for pid in pids {
913 kill_process_group(pid);
914 }
915}
916
917struct ProcessGuard {
918 pid: u32,
919}
920
921impl ProcessGuard {
922 fn new(pid: u32) -> Self {
923 if pid > 0 {
924 track_pid(pid);
925 }
926 Self { pid }
927 }
928}
929
930impl Drop for ProcessGuard {
931 fn drop(&mut self) {
932 if self.pid > 0 {
933 untrack_pid(self.pid);
934 }
935 }
936}
937
// Unit and integration tests for the bash tool. The `#[tokio::test]` cases
// spawn real shell commands in the system temp directory.
#[cfg(test)]
mod tests {
    use super::*;
    use yoagent::AgentTool;

    /// Minimal tool context: fresh cancellation token, no callbacks.
    fn tool_ctx() -> yoagent::types::ToolContext {
        yoagent::types::ToolContext {
            tool_call_id: "id".into(),
            tool_name: "bash".into(),
            cancel: tokio_util::sync::CancellationToken::new(),
            on_update: None,
            on_progress: None,
        }
    }

    /// Concatenates the text of every text content item.
    fn yo_msg_text(content: &[yoagent::types::Content]) -> String {
        content
            .iter()
            .filter_map(|c| {
                if let yoagent::types::Content::Text { text } = c {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
            .join("")
    }

    /// Tool with default settings that runs in the system temp directory.
    fn make_tool() -> BashTool {
        BashTool {
            cwd: std::env::temp_dir(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        }
    }

    // --- basic execution ---

    #[tokio::test]
    async fn runs_simple_command() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("hello"));
    }

    #[tokio::test]
    async fn captures_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo err >&2"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("err"));
    }

    // The token is cancelled before `execute` is called.
    #[tokio::test]
    async fn cancel_aborts() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        cancel.cancel();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10"}),
                yoagent::types::ToolContext {
                    tool_call_id: "id".into(),
                    tool_name: "bash".into(),
                    cancel,
                    on_update: None,
                    on_progress: None,
                },
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Cancelled") || err.contains("aborted"),
            "expected cancellation error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn timeout_works() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"));
    }

    // --- truncate_tail: basic behavior ---

    #[test]
    fn test_truncate_tail_no_truncation() {
        let result = truncate_tail("hello\nworld\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello\nworld\n");
    }

    #[test]
    fn test_truncate_tail_by_lines() {
        let content: String = (1..=5000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.starts_with("line 3001"));
        assert_eq!(result.content.lines().count(), 2000);
    }

    #[test]
    fn test_truncate_tail_by_bytes() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
        assert!(result.content.lines().count() < 100);
    }

    #[test]
    fn test_truncate_tail_partial_last_line() {
        let content = format!("short\n{}\n", "x".repeat(60000));
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(!result.content.starts_with("short"));
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_empty() {
        let result = truncate_tail("", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "");
    }

    // --- exit status and output handling ---

    #[tokio::test]
    async fn exit_code_nonzero() {
        let tool = make_tool();
        let result = tool
            .execute(serde_json::json!({"command": "exit 42"}), tool_ctx())
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exited with code 42"), "got: {}", err);
    }

    #[tokio::test]
    async fn exit_code_with_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo before && exit 1"}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("before"), "got: {}", err);
        assert!(err.contains("exited with code 1"), "got: {}", err);
    }

    #[tokio::test]
    async fn no_output() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "true"}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("(no output)"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn combined_stdout_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(
                serde_json::json!({"command": "echo out; echo err >&2"}),
                tool_ctx(),
            )
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("out"),
            "got: {}",
            yo_msg_text(&output.content)
        );
        assert!(
            yo_msg_text(&output.content).contains("err"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    // Uses a unique scratch directory containing a marker file.
    #[tokio::test]
    async fn runs_in_cwd() {
        let tmp = std::env::temp_dir().join(format!("rab-bash-cwd-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&tmp).unwrap();
        std::fs::write(tmp.join("marker.txt"), "hello").unwrap();

        let tool = BashTool {
            cwd: tmp.clone(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        };
        let output = tool
            .execute(serde_json::json!({"command": "cat marker.txt"}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("hello"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn missing_command_errors() {
        let tool = make_tool();
        let result = tool.execute(serde_json::json!({}), tool_ctx()).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("command"), "got: {}", err);
    }

    #[tokio::test]
    async fn timeout_with_partial_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo start && sleep 10 && echo end", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"), "got: {}", err);
    }

    // Cancels roughly 200ms after the command has been started.
    #[tokio::test]
    async fn cancel_during_long_command() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        let cancel_ctx = cancel.clone();

        let handle = tokio::spawn(async move {
            tool.execute(
                serde_json::json!({"command": "sleep 30"}),
                yoagent::types::ToolContext {
                    tool_call_id: "id".into(),
                    tool_name: "bash".into(),
                    cancel: cancel_ctx,
                    on_update: None,
                    on_progress: None,
                },
            )
            .await
        });

        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("aborted") || err.contains("Cancelled"),
            "expected cancellation error, got: {}",
            err
        );
    }

    // --- truncate_tail: boundary conditions ---

    #[test]
    fn test_truncate_tail_exact_line_fit() {
        let lines: String = (1..=2000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(!result.truncated);
        assert!(result.content.lines().count() == 2000);
    }

    #[test]
    fn test_truncate_tail_one_over_line_limit() {
        let lines: String = (1..=2001).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.content.lines().count(), 2000);
        assert!(result.content.starts_with("line 2"));
    }

    #[test]
    fn test_truncate_tail_exact_byte_fit() {
        let line = "a".repeat(50000);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(!result.truncated);
    }

    #[test]
    fn test_truncate_tail_one_byte_over() {
        let line = "a".repeat(50001);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_single_line_under_limit() {
        let result = truncate_tail("hello world", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello world");
    }

    #[test]
    fn test_truncate_tail_trailing_newline() {
        let result = truncate_tail("a\nb\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb\n");
    }

    #[test]
    fn test_truncate_tail_no_trailing_newline() {
        let result = truncate_tail("a\nb", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb");
    }

    #[test]
    fn test_truncate_tail_single_line_exceeds_limit() {
        let content = "x".repeat(60000);
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.last_line_partial);
        assert_eq!(result.content.len(), 50000);
        assert!(result.content.ends_with("x".repeat(50000).as_str()));
    }

    #[test]
    fn test_truncate_tail_byte_count_respects_newlines() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.output_bytes <= 50000);
    }

    // --- truncation footer and saved full output ---

    #[tokio::test]
    async fn truncated_by_lines_shows_footer() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("Showing lines"),
            "got: {}",
            yo_msg_text(&output.content)
        );
        assert!(
            yo_msg_text(&output.content).contains("Full output:"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn small_output_no_footer() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(!yo_msg_text(&output.content).contains("Output truncated"));
        assert!(!yo_msg_text(&output.content).contains("Full output:"));
    }

    #[tokio::test]
    async fn truncated_saves_temp_file() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("/rab-bash/"),
            "expected temp file path with /rab-bash/, got: {}",
            yo_msg_text(&output.content)
        );
    }

    // Smoke test: the cleanup must simply not panic.
    #[test]
    fn test_cleanup_stale_temp_files_nonexistent_dir() {
        cleanup_stale_temp_files();
    }

    #[test]
    fn test_truncate_tail_many_short_lines() {
        let content: String = (1..=10000).map(|i| format!("{}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "lines");
        assert_eq!(result.output_lines, 2000);
        assert!(
            result.content.starts_with("8001"),
            "starts with: {:?}",
            &result.content[..10]
        );
    }

    // With a 30000-byte budget the byte limit is reached before 2000 lines.
    #[test]
    fn test_truncate_tail_lines_and_bytes_both_exceeded() {
        let content: String = (1..=5000)
            .map(|i| format!("line {} {}\n", i, "x".repeat(100)))
            .collect();
        let result = truncate_tail(&content, 2000, 30000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "bytes");
        assert!(result.output_lines < 2000);
    }

    // --- ProcessGuard ---

    // The inner scope ends the guard's lifetime; its pid must then be gone
    // from the registry.
    #[test]
    fn test_process_guard_tracks_pid() {
        let pid = 12345u32;
        {
            let _guard = ProcessGuard::new(pid);
            let pids = TRACKED_PIDS.lock().unwrap();
            assert!(pids.contains(&pid));
        }
        let pids = TRACKED_PIDS.lock().unwrap();
        assert!(!pids.contains(&pid));
    }

    // Pid 0 stands for "no process" and must never be registered.
    #[test]
    fn test_process_guard_zero_pid() {
        {
            let _guard = ProcessGuard::new(0);
            let pids = TRACKED_PIDS.lock().unwrap();
            assert!(!pids.contains(&0));
        }
    }
}