1use crate::agent::extension::{Cancel, Extension, ToolDefinition};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::tui::Theme;
4use crate::tui::ThemeKey;
5use crate::tui::visual_truncate::truncate_to_visual_lines;
6use async_trait::async_trait;
7
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::time::Instant;
14use tokio::sync::{Mutex as TokioMutex, mpsc::UnboundedSender};
15
16#[async_trait]
21pub trait BashOperations: Send + Sync {
22 async fn exec(
25 &self,
26 command: &str,
27 cwd: &Path,
28 on_data: UnboundedSender<String>,
29 signal: Option<&Cancel>,
30 timeout: Option<u64>,
31 env: Option<HashMap<String, String>>,
32 ) -> Result<Option<i32>, anyhow::Error>;
33}
34
35#[derive(Clone, Default)]
36pub struct BashToolOptions {
37 pub operations: Option<Arc<dyn BashOperations>>,
39 pub command_prefix: Option<String>,
41 pub shell_path: Option<String>,
43}
44
45pub struct BashExtension {
46 cwd: PathBuf,
47 options: BashToolOptions,
48}
49
50impl BashExtension {
51 pub fn new(cwd: PathBuf) -> Self {
52 Self {
53 cwd,
54 options: BashToolOptions::default(),
55 }
56 }
57
58 pub fn with_options(cwd: PathBuf, options: BashToolOptions) -> Self {
59 Self { cwd, options }
60 }
61
62 pub fn with_shell_path(cwd: PathBuf, shell_path: String) -> Self {
63 Self {
64 cwd,
65 options: BashToolOptions {
66 shell_path: Some(shell_path),
67 ..BashToolOptions::default()
68 },
69 }
70 }
71}
72
73impl Extension for BashExtension {
74 fn name(&self) -> Cow<'static, str> {
75 "bash".into()
76 }
77
78 fn tools(&self) -> Vec<ToolDefinition> {
79 vec![ToolDefinition {
80 tool: Box::new(BashTool {
81 cwd: self.cwd.clone(),
82 shell_path: self.options.shell_path.clone(),
83 command_prefix: self.options.command_prefix.clone(),
84 operations: self.options.operations.clone(),
85 }),
86 snippet: "Execute bash commands (ls, grep, find, etc.)",
87 guidelines: &[],
88 prepare_arguments: None,
89 before_tool_call: None,
90 after_tool_call: None,
91 renderer: Some(std::sync::Arc::new(BashRenderer)),
92 }]
93 }
94}
95
96struct BashTool {
97 cwd: PathBuf,
98 shell_path: Option<String>,
99 command_prefix: Option<String>,
100 operations: Option<Arc<dyn BashOperations>>,
101}
102
/// Maximum number of trailing output lines handed back to the model.
const DEFAULT_MAX_LINES: usize = 2000;
/// Maximum number of output bytes handed back to the model (50 KiB).
const DEFAULT_MAX_BYTES: usize = 50 * 1024;
/// Name of the subdirectory of the system temp dir that receives full output logs.
const BASH_TEMP_FILE_PREFIX: &str = "pi-bash";

/// Quiet period (ms) after the shell exits: collection of remaining output
/// stops once no new output arrives within this interval.
const EXIT_STDIO_GRACE_MS: u64 = 100;
112
/// A shell binary plus the arguments that precede the command string.
struct ShellConfig {
    /// Path or name of the shell executable.
    shell: String,
    /// Arguments placed before the command (always `-c`).
    args: Vec<String>,
}

/// Picks the shell used to run commands.
///
/// Resolution order:
/// 1. the explicit `shell_path`;
/// 2. `/bin/bash` if it exists;
/// 3. on Unix, the path printed by `which bash` (if it exists);
/// 4. plain `sh`, looked up on `PATH`.
///
/// Every choice is invoked as `<shell> -c <command>`.
fn resolve_shell(shell_path: Option<&str>) -> ShellConfig {
    let dash_c = |shell: String| ShellConfig {
        shell,
        args: vec!["-c".to_string()],
    };

    if let Some(explicit) = shell_path {
        return dash_c(explicit.to_string());
    }

    if std::path::Path::new("/bin/bash").exists() {
        return dash_c("/bin/bash".to_string());
    }

    #[cfg(unix)]
    {
        let lookup = std::process::Command::new("which")
            .arg("bash")
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::null())
            .output();
        if let Ok(found) = lookup {
            if found.status.success() {
                let candidate = String::from_utf8_lossy(&found.stdout).trim().to_string();
                if !candidate.is_empty() && std::path::Path::new(&candidate).exists() {
                    return dash_c(candidate);
                }
            }
        }
    }

    dash_c("sh".to_string())
}
168
169#[cfg(unix)]
173fn kill_process_group(pid: u32) {
174 if pid > 0 {
175 let _ = std::process::Command::new("kill")
176 .arg("--")
177 .arg(format!("-{}", pid))
178 .status();
179 }
180}
181
182#[cfg(not(unix))]
183fn kill_process_group(pid: u32) {
184 let _ = pid;
185}
186
187fn spawn_bash_command(
189 command: &str,
190 cwd: &std::path::Path,
191 shell_path: Option<&str>,
192) -> std::io::Result<tokio::process::Child> {
193 let shell_cfg = resolve_shell(shell_path);
194
195 #[cfg(unix)]
196 {
197 use std::os::unix::process::CommandExt;
198 let mut std_cmd = std::process::Command::new(&shell_cfg.shell);
199 std_cmd.args(&shell_cfg.args).arg(command).current_dir(cwd);
200 unsafe {
201 std_cmd.pre_exec(|| {
202 libc::setpgid(0, 0);
203 Ok(())
204 });
205 }
206 let mut tokio_cmd = tokio::process::Command::from(std_cmd);
207 tokio_cmd
208 .stdin(std::process::Stdio::null())
209 .stdout(std::process::Stdio::piped())
210 .stderr(std::process::Stdio::piped())
211 .spawn()
212 }
213 #[cfg(not(unix))]
214 {
215 tokio::process::Command::new(&shell_cfg.shell)
216 .args(&shell_cfg.args)
217 .arg(command)
218 .current_dir(cwd)
219 .stdin(std::process::Stdio::null())
220 .stdout(std::process::Stdio::piped())
221 .stderr(std::process::Stdio::piped())
222 .spawn()
223 }
224}
225
/// Removes terminal control sequences and most control characters from
/// command output before it is shown to the model or the TUI.
///
/// Stripped:
/// * CSI sequences (`ESC [` or C1 `\u{9b}`, parameter/intermediate bytes
///   `0x20..=0x3F`, final byte `0x40..=0x7E`), e.g. `ESC[31m`, `ESC[?25h`;
/// * string sequences OSC (`ESC ]`), DCS (`ESC P`), SOS (`ESC X`),
///   PM (`ESC ^`) and APC (`ESC _`), up to BEL, `ESC \` or C1 ST (`\u{9c}`).
///   This covers window titles and OSC 8 hyperlinks;
/// * other escapes: `ESC`, optional intermediates `0x20..=0x2F`, then one
///   final character, e.g. `ESC 7` or `ESC ( B`;
/// * C0 control characters other than tab, line feed and carriage return;
/// * the interlinear annotation characters U+FFF9..=U+FFFB.
///
/// A malformed CSI (an unexpected character before the final byte) is
/// abandoned and that character is processed as ordinary text. A sequence
/// still open at the end of `text` is dropped.
fn sanitize_output(text: &str) -> String {
    /// Scanner state.
    #[derive(Clone, Copy)]
    enum State {
        /// Ordinary text.
        Text,
        /// After ESC (and any intermediate bytes).
        Escape,
        /// Inside a CSI sequence, waiting for the final byte.
        Csi,
        /// Inside an OSC/DCS/SOS/PM/APC string, waiting for BEL or ST.
        Str,
        /// ESC seen inside a string; `\` completes the ST terminator.
        StrEscape,
    }

    /// Handles `c` as ordinary text and returns the next state.
    fn on_text(c: char, out: &mut String) -> State {
        match c {
            '\x1b' => State::Escape,
            '\u{9b}' => State::Csi,
            '\t' | '\n' | '\r' => {
                out.push(c);
                State::Text
            }
            '\0'..='\x1f' | '\u{fff9}'..='\u{fffb}' => State::Text,
            _ => {
                out.push(c);
                State::Text
            }
        }
    }

    /// Handles the character that follows an ESC.
    fn on_escape(c: char) -> State {
        match c {
            '[' | '\u{9b}' => State::Csi,
            ']' | 'P' | 'X' | '^' | '_' => State::Str,
            // A repeated ESC restarts the escape; intermediates keep it open.
            '\x1b' | '\x20'..='\x2f' => State::Escape,
            // Any other character is the final byte of a short escape.
            _ => State::Text,
        }
    }

    let mut out = String::with_capacity(text.len());
    let mut state = State::Text;
    for c in text.chars() {
        state = match state {
            State::Text => on_text(c, &mut out),
            State::Escape => on_escape(c),
            State::Csi => match c {
                '\x40'..='\x7e' => State::Text,
                '\x20'..='\x3f' => State::Csi,
                _ => on_text(c, &mut out),
            },
            State::Str => match c {
                '\x07' | '\u{9c}' => State::Text,
                '\x1b' => State::StrEscape,
                _ => State::Str,
            },
            State::StrEscape => match c {
                '\\' => State::Text,
                // Not an ST: the ESC aborts the string and starts a new escape.
                _ => on_escape(c),
            },
        };
    }
    out
}
255
/// Formats a byte count for display.
///
/// Counts below 1 KiB are shown as whole bytes (`"512B"`). Larger counts are
/// shown in `KB` or `MB` (base 1024) with one decimal place (`"1.5KB"`).
fn format_size(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = KIB * KIB;
    match bytes {
        b if b < KIB => format!("{}B", b),
        b if b < MIB => format!("{:.1}KB", b as f64 / KIB as f64),
        b => format!("{:.1}MB", b as f64 / MIB as f64),
    }
}
265
/// Result of keeping only the tail of a piece of output.
struct TailTruncation {
    /// The retained text. When truncated, lines are re-joined with `\n`, so a
    /// trailing newline (or `\r\n` endings) of the input is not preserved.
    content: String,
    /// Whether the truncation path was taken.
    truncated: bool,
    /// Number of lines in the input, as counted by `str::lines`.
    total_lines: usize,
    /// Number of lines in `content`.
    output_lines: usize,
    /// Length of `content` in bytes.
    output_bytes: usize,
    /// `"lines"` or `"bytes"`: the limit that stopped the scan; `""` when not truncated.
    truncated_by: &'static str,
    /// Whether the input's last line alone exceeded the byte budget, so only its
    /// tail is kept.
    last_line_partial: bool,
}

/// Keeps the tail of `content`: at most `max_lines` lines and at most
/// `max_bytes` bytes (line bytes plus the `\n` separators between them).
///
/// Input within both limits is returned unchanged. Otherwise whole lines are
/// taken from the end until the next one would break either limit. If even
/// the final line alone exceeds `max_bytes`, only its last bytes are kept. The
/// cut is moved forward to a UTF-8 character boundary, so slightly fewer than
/// `max_bytes` bytes may be kept, and `last_line_partial` is set.
fn truncate_tail(content: &str, max_lines: usize, max_bytes: usize) -> TailTruncation {
    let total_bytes = content.len();
    let lines: Vec<&str> = content.lines().collect();
    let total_lines = lines.len();

    if total_lines <= max_lines && total_bytes <= max_bytes {
        return TailTruncation {
            content: content.to_string(),
            truncated: false,
            total_lines,
            output_lines: total_lines,
            output_bytes: total_bytes,
            truncated_by: "",
            last_line_partial: false,
        };
    }

    let mut output: Vec<&str> = Vec::new();
    let mut byte_count: usize = 0;
    // Stays "lines" unless the byte budget stops the scan below.
    let mut truncated_by = "lines";
    let mut last_line_partial = false;

    for &line in lines.iter().rev().take(max_lines) {
        // Every kept line except the first one visited adds a '\n' separator.
        let cost = if output.is_empty() {
            line.len()
        } else {
            line.len() + 1
        };

        if byte_count + cost > max_bytes {
            truncated_by = "bytes";
            if output.is_empty() {
                // The last line alone is too long: keep its tail. Start on a char
                // boundary so the slice cannot panic on multi-byte UTF-8.
                let mut start = line.len().saturating_sub(max_bytes);
                while !line.is_char_boundary(start) {
                    start += 1;
                }
                let tail = &line[start..];
                byte_count = tail.len();
                output.push(tail);
                last_line_partial = true;
            }
            break;
        }

        output.push(line);
        byte_count += cost;
    }

    output.reverse();
    TailTruncation {
        content: output.join("\n"),
        truncated: true,
        total_lines,
        output_lines: output.len(),
        output_bytes: byte_count,
        truncated_by,
        last_line_partial,
    }
}
339
340fn finish_bash_execution(
343 combined: &str,
344 exit_code: i32,
345 cancelled: bool,
346 timed_out: Option<u64>,
347 ctx: &yoagent::types::ToolContext,
348) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
349 let trunc = truncate_tail(combined, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
350
351 let mut result_text = if trunc.content.is_empty() {
352 "(no output)".to_string()
353 } else {
354 trunc.content.clone()
355 };
356
357 let full_output_path = if trunc.truncated {
359 let tmp_dir = std::env::temp_dir().join(BASH_TEMP_FILE_PREFIX);
360 let _ = std::fs::create_dir_all(&tmp_dir);
361 let tmp_path = tmp_dir.join(format!("{}.log", uuid::Uuid::new_v4()));
362 let saved = std::fs::write(&tmp_path, combined).ok().map(|_| tmp_path);
363
364 let start_line = trunc.total_lines - trunc.output_lines + 1;
365 let end_line = trunc.total_lines;
366
367 let notice = if trunc.truncated_by == "lines" {
368 format!(
369 "\n\n[Showing lines {}-{} of {}. Full output: {}]",
370 start_line,
371 end_line,
372 trunc.total_lines,
373 saved
374 .as_ref()
375 .map(|p| p.display().to_string())
376 .unwrap_or_default()
377 )
378 } else {
379 format!(
380 "\n\n[Showing lines {}-{} of {} ({} limit). Full output: {}]",
381 start_line,
382 end_line,
383 trunc.total_lines,
384 format_size(DEFAULT_MAX_BYTES),
385 saved
386 .as_ref()
387 .map(|p| p.display().to_string())
388 .unwrap_or_default()
389 )
390 };
391 result_text.push_str(¬ice);
392 saved
393 } else {
394 None
395 };
396
397 let details = if trunc.truncated || full_output_path.is_some() {
399 Some(serde_json::json!({
400 "truncation": {
401 "truncated": trunc.truncated,
402 "truncatedBy": trunc.truncated_by,
403 "totalLines": trunc.total_lines,
404 "outputLines": trunc.output_lines,
405 "outputBytes": trunc.output_bytes,
406 "lastLinePartial": trunc.last_line_partial,
407 "maxLines": DEFAULT_MAX_LINES,
408 "maxBytes": DEFAULT_MAX_BYTES,
409 },
410 "fullOutputPath": full_output_path.as_ref().map(|p| p.display().to_string()),
411 }))
412 } else {
413 None
414 };
415
416 let final_output = if cancelled {
417 if result_text.is_empty() || result_text == "(no output)" {
418 "Command aborted".to_string()
419 } else {
420 format!("{}\n\nCommand aborted", result_text)
421 }
422 } else if let Some(secs) = timed_out {
423 if result_text.is_empty() || result_text == "(no output)" {
424 format!("Command timed out after {} seconds", secs)
425 } else {
426 format!(
427 "{}\n\nCommand timed out after {} seconds",
428 result_text, secs
429 )
430 }
431 } else if exit_code != 0 {
432 if result_text.is_empty() || result_text == "(no output)" {
433 format!("Command exited with code {}", exit_code)
434 } else {
435 format!("{}\n\nCommand exited with code {}", result_text, exit_code)
436 }
437 } else {
438 if let Some(ref on_update) = ctx.on_update {
439 on_update(yoagent::types::ToolResult {
440 content: vec![yoagent::types::Content::Text {
441 text: result_text.clone(),
442 }],
443 details: details.clone().unwrap_or(serde_json::Value::Null),
444 });
445 }
446 return Ok(yoagent::types::ToolResult {
447 content: vec![yoagent::types::Content::Text { text: result_text }],
448 details: details.unwrap_or(serde_json::Value::Null),
449 });
450 };
451
452 if let Some(ref on_update) = ctx.on_update {
453 on_update(yoagent::types::ToolResult {
454 content: vec![yoagent::types::Content::Text {
455 text: final_output.clone(),
456 }],
457 details: details.clone().unwrap_or(serde_json::Value::Null),
458 });
459 }
460
461 Err(yoagent::types::ToolError::Failed(final_output))
462}
463
464struct BashRenderer;
469
470impl ToolRenderer for BashRenderer {
471 fn render_call(
472 &self,
473 args: &serde_json::Value,
474 _width: usize,
475 theme: &dyn Theme,
476 _ctx: &ToolRenderContext,
477 ) -> Vec<String> {
478 let cmd = args
479 .get("command")
480 .and_then(|v| v.as_str())
481 .unwrap_or("...");
482 let timeout = args.get("timeout").and_then(|v| v.as_i64());
483 let timeout_suffix = timeout
484 .map(|t| theme.fg_key(ThemeKey::Muted, &format!(" (timeout {}s)", t)))
485 .unwrap_or_default();
486
487 vec![format!(
488 "{}{}",
489 theme.fg_key(ThemeKey::ToolTitle, &theme.bold(&format!("$ {}", cmd))),
490 timeout_suffix
491 )]
492 }
493
494 fn render_result(
495 &self,
496 content: &str,
497 width: usize,
498 theme: &dyn Theme,
499 ctx: &ToolRenderContext,
500 ) -> Vec<String> {
501 let mut lines: Vec<String> = Vec::new();
502
503 let clean = strip_context_truncation_footer(content)
504 .trim_end()
505 .to_string();
506 let all_lines: Vec<&str> = clean.lines().collect();
507
508 if all_lines.is_empty() || (all_lines.len() == 1 && all_lines[0].is_empty()) {
509 return lines;
510 }
511
512 let preview_count = 5;
513 let (preview_lines, hidden_line_count) = if ctx.expanded {
514 (all_lines.clone(), 0)
515 } else {
516 truncate_to_visual_lines(&all_lines, width, preview_count)
517 };
518
519 if !ctx.expanded && hidden_line_count > 0 {
521 if ctx.expand_key.is_empty() {
522 lines.push(theme.fg_key(
523 ThemeKey::Muted,
524 &format!("... {} earlier lines", hidden_line_count),
525 ));
526 } else {
527 let prefix = theme.fg_key(
530 ThemeKey::Muted,
531 &format!("... ({} earlier lines, ", hidden_line_count),
532 );
533 let key_styled = theme.fg("dim", &ctx.expand_key);
534 let suffix = theme.fg_key(ThemeKey::Muted, " to expand)");
535 lines.push(format!("{}{}{}", prefix, key_styled, suffix));
536 }
537 }
538
539 let fg_key = if ctx.is_error { "error" } else { "toolOutput" };
540 for line in &preview_lines {
541 if line.is_empty() {
542 lines.push(String::new());
543 } else {
544 lines.push(theme.fg(fg_key, line));
545 }
546 }
547
548 if let Some(secs) = ctx.duration_secs {
549 if !lines.is_empty() {
550 lines.push(String::new());
551 }
552 let is_complete = ctx.exit_code.is_some() || ctx.cancelled;
553 let label = if is_complete { "Took" } else { "Elapsed" };
554 lines.push(theme.fg_key(ThemeKey::Muted, &format!("{} {:.1}s", label, secs)));
555 }
556
557 if ctx.was_truncated {
558 if !lines.is_empty() {
559 lines.push(String::new());
560 }
561 if let Some(ref path) = ctx.full_output_path {
562 lines.push(theme.fg(
563 "warning",
564 &format!("Output truncated. Full output: {}", path),
565 ));
566 } else {
567 lines.push(theme.fg_key(ThemeKey::Warning, "Output truncated."));
568 }
569 }
570
571 lines
572 }
573}
574
/// Removes a trailing `[Showing lines|last …. Full output: …]` footer line.
///
/// The footer is removed together with the blank line directly above it, if
/// there is one. Output with fewer than three lines, or whose last line is not
/// such a footer, is returned unchanged. When a footer is removed, the
/// remaining lines are re-joined with `\n`.
fn strip_context_truncation_footer(output: &str) -> String {
    let lines: Vec<&str> = output.lines().collect();
    if lines.len() < 3 {
        return output.to_string();
    }

    let (footer, body) = match lines.split_last() {
        Some(pair) => pair,
        None => return output.to_string(),
    };

    let footer = footer.trim();
    let is_footer = footer.starts_with('[')
        && ["Showing lines", "Showing last"]
            .iter()
            .any(|&marker| footer.contains(marker))
        && footer.contains("Full output:");
    if !is_footer {
        return output.to_string();
    }

    let body = match body.split_last() {
        Some((last, rest)) if last.is_empty() => rest,
        _ => body,
    };
    body.join("\n")
}
595
596#[async_trait::async_trait]
597impl yoagent::types::AgentTool for BashTool {
598 fn name(&self) -> &str {
599 "bash"
600 }
601 fn label(&self) -> &str {
602 "bash"
603 }
604 fn description(&self) -> &str {
605 "Execute a bash command in the current working directory. Returns stdout and stderr. \
606 Output is truncated to last 2000 lines or 50KB (whichever is hit first). If \
607 truncated, full output is saved to a temp file. Optionally provide a timeout in seconds."
608 }
609 fn parameters_schema(&self) -> serde_json::Value {
610 serde_json::json!({
611 "type": "object",
612 "required": ["command"],
613 "properties": {
614 "command": {
615 "type": "string",
616 "description": "Bash command to execute"
617 },
618 "timeout": {
619 "type": "number",
620 "description": "Timeout in seconds (optional, no default timeout)"
621 }
622 }
623 })
624 }
625 async fn execute(
626 &self,
627 params: serde_json::Value,
628 ctx: yoagent::types::ToolContext,
629 ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
630 let command = params["command"].as_str().ok_or_else(|| {
631 yoagent::types::ToolError::InvalidArgs("Missing 'command' argument".into())
632 })?;
633 let timeout = params["timeout"].as_u64();
634 let started_at = Instant::now();
635
636 if ctx.cancel.is_cancelled() {
637 return Err(yoagent::types::ToolError::Cancelled);
638 }
639
640 let effective_command = if let Some(ref prefix) = self.command_prefix {
642 format!("{}\n{}", prefix, command)
643 } else {
644 command.to_string()
645 };
646
647 if !self.cwd.exists() {
649 return Err(yoagent::types::ToolError::Failed(format!(
650 "Working directory does not exist: {}\nCannot execute bash commands.",
651 self.cwd.display()
652 )));
653 }
654
655 if let Some(ref ops) = self.operations {
657 let (output_tx, mut output_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
658 let ops_cancel = Cancel::new();
659
660 let yo_cancel = ctx.cancel.clone();
662 let watch_cancel = ops_cancel.clone();
663 tokio::spawn(async move {
664 yo_cancel.cancelled().await;
665 watch_cancel.cancel();
666 });
667
668 let ops_command = effective_command.clone();
669 let ops_cwd = self.cwd.clone();
670 let ops = ops.clone();
671 let ops_handle = tokio::spawn(async move {
672 ops.exec(
673 &ops_command,
674 &ops_cwd,
675 output_tx,
676 Some(&ops_cancel),
677 timeout,
678 None,
679 )
680 .await
681 });
682
683 let mut combined = String::new();
685 while let Some(chunk) = output_rx.recv().await {
686 combined.push_str(&chunk);
687 if let Some(ref on_update) = ctx.on_update {
688 on_update(yoagent::types::ToolResult {
689 content: vec![yoagent::types::Content::Text {
690 text: combined.clone(),
691 }],
692 details: serde_json::Value::Null,
693 });
694 }
695 }
696
697 let exit_code = ops_handle.await.unwrap_or(Ok(None)).unwrap_or(None);
698 let code = exit_code.unwrap_or(-1);
699
700 return finish_bash_execution(&combined, code, ctx.cancel.is_cancelled(), None, &ctx);
701 }
702
703 let mut child =
704 spawn_bash_command(&effective_command, &self.cwd, self.shell_path.as_deref()).map_err(
705 |e| yoagent::types::ToolError::Failed(format!("Failed to spawn command: {}", e)),
706 )?;
707
708 let pid = child.id().unwrap_or(0);
709
710 let combined = Arc::new(TokioMutex::new(String::new()));
712 let combined_clone = combined.clone();
713
714 let stdout_pipe = child
715 .stdout
716 .take()
717 .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stdout".into()))?;
718 let stderr_pipe = child
719 .stderr
720 .take()
721 .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stderr".into()))?;
722
723 use tokio::io::AsyncReadExt;
724 let read_task = tokio::spawn(async move {
725 let mut stdout_buf = vec![0u8; 65536];
726 let mut stderr_buf = vec![0u8; 65536];
727 let mut stdout_reader = stdout_pipe;
728 let mut stderr_reader = stderr_pipe;
729 let mut stdout_done = false;
730 let mut stderr_done = false;
731 loop {
732 tokio::select! {
733 result = stdout_reader.read(&mut stdout_buf), if !stdout_done => {
734 match result {
735 Ok(0) => stdout_done = true,
736 Ok(n) => {
737 let text = String::from_utf8_lossy(&stdout_buf[..n]);
738 let sanitized = sanitize_output(&text);
739 let mut out = combined_clone.lock().await;
740 out.push_str(&sanitized);
741 }
742 Err(_) => stdout_done = true,
743 }
744 }
745 result = stderr_reader.read(&mut stderr_buf), if !stderr_done => {
746 match result {
747 Ok(0) => stderr_done = true,
748 Ok(n) => {
749 let text = String::from_utf8_lossy(&stderr_buf[..n]);
750 let sanitized = sanitize_output(&text);
751 let mut out = combined_clone.lock().await;
752 out.push_str(&sanitized);
753 }
754 Err(_) => stderr_done = true,
755 }
756 }
757 }
758 if stdout_done && stderr_done {
759 break;
760 }
761 }
762 });
763
764 let _pid_guard = ProcessGuard::new(pid);
766
767 let cancelled = Arc::new(AtomicBool::new(false));
769 let cancel_flag = cancelled.clone();
770 let yo_cancel = ctx.cancel.clone();
771 let _cancel_monitor: tokio::task::JoinHandle<()> = tokio::spawn(async move {
772 yo_cancel.cancelled().await;
773 cancel_flag.store(true, Ordering::SeqCst);
774 kill_process_group(pid);
775 });
776
777 if let Some(ref on_update) = ctx.on_update {
779 on_update(yoagent::types::ToolResult {
780 content: vec![],
781 details: serde_json::Value::Null,
782 });
783 }
784
785 let timeout_dur = timeout.map(std::time::Duration::from_secs);
787 let throttle_ms = 100u64;
788 let mut last_update_at = Instant::now();
789
790 let exit_code: i32;
791
792 loop {
793 if cancelled.load(Ordering::SeqCst) {
794 kill_process_group(pid);
795 read_task.abort();
796 let combined_str = combined.lock().await.clone();
797 return finish_bash_execution(&combined_str, -1, true, None, &ctx);
798 }
799
800 if let Some(dur) = timeout_dur
801 && started_at.elapsed() > dur
802 {
803 kill_process_group(pid);
804 read_task.abort();
805 let combined_str = combined.lock().await.clone();
806 return finish_bash_execution(&combined_str, -1, false, timeout, &ctx);
807 }
808
809 if let Some(ref on_update) = ctx.on_update
810 && last_update_at.elapsed().as_millis() as u64 >= throttle_ms
811 {
812 let out = combined.lock().await.clone();
813 if !out.is_empty() {
814 last_update_at = Instant::now();
815 on_update(yoagent::types::ToolResult {
816 content: vec![yoagent::types::Content::Text { text: out }],
817 details: serde_json::Value::Null,
818 });
819 }
820 }
821
822 match child.try_wait() {
823 Ok(Some(status)) => {
824 exit_code = status.code().unwrap_or(-1);
825 let mut last_len = combined.lock().await.len();
827 loop {
828 tokio::time::sleep(std::time::Duration::from_millis(EXIT_STDIO_GRACE_MS))
829 .await;
830 let new_len = combined.lock().await.len();
831 if new_len == last_len {
832 break;
833 }
834 last_len = new_len;
835 }
836 read_task.abort();
837 break;
838 }
839 Ok(None) => {
840 tokio::time::sleep(std::time::Duration::from_millis(throttle_ms)).await;
841 }
842 Err(_) => {
843 read_task.await.ok();
844 exit_code = -1;
845 break;
846 }
847 }
848 }
849
850 let combined_str = combined.lock().await.clone();
851 if let Some(ref on_update) = ctx.on_update
852 && !combined_str.is_empty()
853 {
854 on_update(yoagent::types::ToolResult {
855 content: vec![yoagent::types::Content::Text {
856 text: combined_str.clone(),
857 }],
858 details: serde_json::Value::Null,
859 });
860 }
861
862 finish_bash_execution(&combined_str, exit_code, false, None, &ctx)
863 }
864}
865
866use std::sync::Mutex;
871
/// PIDs registered by `ProcessGuard`s that are currently alive; read by
/// `kill_tracked_children`.
static TRACKED_PIDS: Mutex<Vec<u32>> = Mutex::new(Vec::new());

/// Locks [`TRACKED_PIDS`], recovering the list if a previous holder panicked.
///
/// The guarded operations (push/retain/clone) cannot leave the list half
/// updated, so a poisoned lock still holds valid data. Skipping the update, as
/// a plain `if let Ok(..)` would, silently stops tracking altogether.
fn lock_tracked_pids() -> std::sync::MutexGuard<'static, Vec<u32>> {
    TRACKED_PIDS
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Adds `pid` to the tracked set.
fn track_pid(pid: u32) {
    lock_tracked_pids().push(pid);
}

/// Removes every occurrence of `pid` from the tracked set.
fn untrack_pid(pid: u32) {
    lock_tracked_pids().retain(|&p| p != pid);
}
885
886pub fn kill_tracked_children() {
888 let pids: Vec<u32> = TRACKED_PIDS.lock().map(|p| p.clone()).unwrap_or_default();
889 for pid in pids {
890 kill_process_group(pid);
891 }
892}
893
894struct ProcessGuard {
895 pid: u32,
896}
897
898impl ProcessGuard {
899 fn new(pid: u32) -> Self {
900 if pid > 0 {
901 track_pid(pid);
902 }
903 Self { pid }
904 }
905}
906
907impl Drop for ProcessGuard {
908 fn drop(&mut self) {
909 if self.pid > 0 {
910 untrack_pid(self.pid);
911 }
912 }
913}
914
/// Tests for the bash tool.
///
/// Tests that call `BashTool::execute` spawn real shell processes in the system
/// temp directory. They rely on common utilities (`echo`, `sleep`, `seq`,
/// `cat`, `true`) being available.
#[cfg(test)]
mod tests {
    use super::*;
    use yoagent::AgentTool;

    /// Builds an uncancelled tool context with no update callbacks.
    fn tool_ctx() -> yoagent::types::ToolContext {
        yoagent::types::ToolContext {
            tool_call_id: "id".into(),
            tool_name: "bash".into(),
            cancel: tokio_util::sync::CancellationToken::new(),
            on_update: None,
            on_progress: None,
        }
    }

    /// Concatenates the text parts of a tool result; non-text parts are skipped.
    fn yo_msg_text(content: &[yoagent::types::Content]) -> String {
        content
            .iter()
            .filter_map(|c| {
                if let yoagent::types::Content::Text { text } = c {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
            .join("")
    }

    /// A local-execution tool rooted in the system temp directory.
    fn make_tool() -> BashTool {
        BashTool {
            cwd: std::env::temp_dir(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        }
    }

    #[tokio::test]
    async fn runs_simple_command() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("hello"));
    }

    // stderr is captured into the same output as stdout.
    #[tokio::test]
    async fn captures_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo err >&2"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("err"));
    }

    // A token that is already cancelled makes execute fail before running anything.
    #[tokio::test]
    async fn cancel_aborts() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        cancel.cancel();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10"}),
                yoagent::types::ToolContext {
                    tool_call_id: "id".into(),
                    tool_name: "bash".into(),
                    cancel,
                    on_update: None,
                    on_progress: None,
                },
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Cancelled") || err.contains("aborted"),
            "expected cancellation error, got: {}",
            err
        );
    }

    // A 1-second timeout interrupts `sleep 10`.
    #[tokio::test]
    async fn timeout_works() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"));
    }

    #[test]
    fn test_truncate_tail_no_truncation() {
        let result = truncate_tail("hello\nworld\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello\nworld\n");
    }

    #[test]
    fn test_truncate_tail_by_lines() {
        let content: String = (1..=5000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.starts_with("line 3001"));
        assert_eq!(result.content.lines().count(), 2000);
    }

    #[test]
    fn test_truncate_tail_by_bytes() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
        assert!(result.content.lines().count() < 100);
    }

    // The over-long last line is kept partially, and the earlier line is dropped.
    #[test]
    fn test_truncate_tail_partial_last_line() {
        let content = format!("short\n{}\n", "x".repeat(60000));
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(!result.content.starts_with("short"));
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_empty() {
        let result = truncate_tail("", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "");
    }

    // A non-zero exit is reported as an error carrying the exit code.
    #[tokio::test]
    async fn exit_code_nonzero() {
        let tool = make_tool();
        let result = tool
            .execute(serde_json::json!({"command": "exit 42"}), tool_ctx())
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exited with code 42"), "got: {}", err);
    }

    // The error message includes the output produced before the failure.
    #[tokio::test]
    async fn exit_code_with_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo before && exit 1"}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("before"), "got: {}", err);
        assert!(err.contains("exited with code 1"), "got: {}", err);
    }

    // A successful command with empty output yields the "(no output)" placeholder.
    #[tokio::test]
    async fn no_output() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "true"}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("(no output)"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn combined_stdout_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(
                serde_json::json!({"command": "echo out; echo err >&2"}),
                tool_ctx(),
            )
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("out"),
            "got: {}",
            yo_msg_text(&output.content)
        );
        assert!(
            yo_msg_text(&output.content).contains("err"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    // Commands run in the tool's `cwd`. NOTE: the temporary directory created
    // here is not removed afterwards.
    #[tokio::test]
    async fn runs_in_cwd() {
        let tmp = std::env::temp_dir().join(format!("rab-bash-cwd-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&tmp).unwrap();
        std::fs::write(tmp.join("marker.txt"), "hello").unwrap();

        let tool = BashTool {
            cwd: tmp.clone(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        };
        let output = tool
            .execute(serde_json::json!({"command": "cat marker.txt"}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("hello"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn missing_command_errors() {
        let tool = make_tool();
        let result = tool.execute(serde_json::json!({}), tool_ctx()).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("command"), "got: {}", err);
    }

    #[tokio::test]
    async fn timeout_with_partial_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo start && sleep 10 && echo end", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"), "got: {}", err);
    }

    // Cancelling the token while `sleep 30` runs ends the call with a cancellation error.
    #[tokio::test]
    async fn cancel_during_long_command() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        let cancel_ctx = cancel.clone();

        let handle = tokio::spawn(async move {
            tool.execute(
                serde_json::json!({"command": "sleep 30"}),
                yoagent::types::ToolContext {
                    tool_call_id: "id".into(),
                    tool_name: "bash".into(),
                    cancel: cancel_ctx,
                    on_update: None,
                    on_progress: None,
                },
            )
            .await
        });

        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("aborted") || err.contains("Cancelled"),
            "expected cancellation error, got: {}",
            err
        );
    }

    #[test]
    fn test_truncate_tail_exact_line_fit() {
        let lines: String = (1..=2000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(!result.truncated);
        assert!(result.content.lines().count() == 2000);
    }

    #[test]
    fn test_truncate_tail_one_over_line_limit() {
        let lines: String = (1..=2001).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.content.lines().count(), 2000);
        assert!(result.content.starts_with("line 2"));
    }

    #[test]
    fn test_truncate_tail_exact_byte_fit() {
        let line = "a".repeat(50000);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(!result.truncated);
    }

    #[test]
    fn test_truncate_tail_one_byte_over() {
        let line = "a".repeat(50001);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_single_line_under_limit() {
        let result = truncate_tail("hello world", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello world");
    }

    #[test]
    fn test_truncate_tail_trailing_newline() {
        let result = truncate_tail("a\nb\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb\n");
    }

    #[test]
    fn test_truncate_tail_no_trailing_newline() {
        let result = truncate_tail("a\nb", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb");
    }

    // A single over-long line keeps exactly its last `max_bytes` bytes.
    #[test]
    fn test_truncate_tail_single_line_exceeds_limit() {
        let content = "x".repeat(60000);
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.last_line_partial);
        assert_eq!(result.content.len(), 50000);
        assert!(result.content.ends_with("x".repeat(50000).as_str()));
    }

    #[test]
    fn test_truncate_tail_byte_count_respects_newlines() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.output_bytes <= 50000);
    }

    // End-to-end: 3000 lines exceed the 2000-line limit, so a footer is appended.
    #[tokio::test]
    async fn truncated_by_lines_shows_footer() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("Showing lines"),
            "got: {}",
            yo_msg_text(&output.content)
        );
        assert!(
            yo_msg_text(&output.content).contains("Full output:"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn small_output_no_footer() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(!yo_msg_text(&output.content).contains("Output truncated"));
        assert!(!yo_msg_text(&output.content).contains("Full output:"));
    }

    // The footer's path points into the `pi-bash` temp subdirectory
    // (assumes a `/`-separated temp path, i.e. a Unix-like system).
    #[tokio::test]
    async fn truncated_saves_temp_file() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("/pi-bash/"),
            "expected temp file path with /pi-bash/, got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[test]
    fn test_truncate_tail_many_short_lines() {
        let content: String = (1..=10000).map(|i| format!("{}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "lines");
        assert_eq!(result.output_lines, 2000);
        assert!(
            result.content.starts_with("8001"),
            "starts with: {:?}",
            &result.content[..10]
        );
    }

    // When both limits are exceeded, the byte limit is hit first here and is reported.
    #[test]
    fn test_truncate_tail_lines_and_bytes_both_exceeded() {
        let content: String = (1..=5000)
            .map(|i| format!("line {} {}\n", i, "x".repeat(100)))
            .collect();
        let result = truncate_tail(&content, 2000, 30000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "bytes");
        assert!(result.output_lines < 2000);
    }

    // The guard registers its PID while alive and unregisters it on drop.
    // 12345 is an arbitrary value; TRACKED_PIDS is shared with other tests.
    #[test]
    fn test_process_guard_tracks_pid() {
        let pid = 12345u32;
        {
            let _guard = ProcessGuard::new(pid);
            let pids = TRACKED_PIDS.lock().unwrap();
            assert!(pids.contains(&pid));
        }
        let pids = TRACKED_PIDS.lock().unwrap();
        assert!(!pids.contains(&pid));
    }

    // PID 0 means "unknown" and must never be registered.
    #[test]
    fn test_process_guard_zero_pid() {
        {
            let _guard = ProcessGuard::new(0);
            let pids = TRACKED_PIDS.lock().unwrap();
            assert!(!pids.contains(&0));
        }
    }
}