Skip to main content

rab/builtin/
bash.rs

1use crate::agent::extension::{Cancel, Extension, ToolDefinition};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::tui::Theme;
4use crate::tui::ThemeKey;
5use crate::tui::visual_truncate::truncate_to_visual_lines;
6use async_trait::async_trait;
7
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::time::Instant;
14use tokio::sync::{Mutex as TokioMutex, mpsc::UnboundedSender};
15
16// ── BashOperations (pluggable) ────────────────────────────────────
17
/// Pluggable operations for the bash tool (matching pi's BashOperations).
/// Override these to delegate command execution to remote systems (for example SSH).
///
/// When `BashToolOptions::operations` is set, the bash tool hands every command to
/// this trait instead of spawning a local shell; `BashToolOptions::shell_path` is
/// not consulted on that path.
#[async_trait]
pub trait BashOperations: Send + Sync {
    /// Execute a command and stream output via the sender.
    /// Returns the exit code (0 = success, non-zero = error, None = killed).
    ///
    /// * `command` — the command text; the bash tool has already prepended any
    ///   configured command prefix (joined with a newline).
    /// * `cwd` — directory to run in; the bash tool checks that it exists locally
    ///   before calling.
    /// * `on_data` — output chunks, in order. The bash tool keeps reading until every
    ///   clone of this sender is dropped, so do not retain it after returning.
    /// * `signal` — cancellation handle; the bash tool cancels it when the tool call
    ///   is cancelled.
    /// * `timeout` — requested timeout in seconds, if any; the bash tool does not
    ///   enforce it on this path, so the implementation has to.
    /// * `env` — extra environment variables (the bash tool currently passes `None`).
    async fn exec(
        &self,
        command: &str,
        cwd: &Path,
        on_data: UnboundedSender<String>,
        signal: Option<&Cancel>,
        timeout: Option<u64>,
        env: Option<HashMap<String, String>>,
    ) -> Result<Option<i32>, anyhow::Error>;
}
34
/// Construction options for [`BashExtension`] (see `BashExtension::with_options`).
/// The default runs commands in a local, auto-detected shell with no prefix.
#[derive(Clone, Default)]
pub struct BashToolOptions {
    /// Custom operations for command execution. Default: local shell.
    pub operations: Option<Arc<dyn BashOperations>>,
    /// Command prefix prepended to every command (for example shell setup commands).
    /// It is joined to the command with a newline, so it runs as its own line.
    pub command_prefix: Option<String>,
    /// Optional explicit shell path from settings.
    /// When unset, a local shell is auto-detected (bash, falling back to `sh`).
    /// Ignored when `operations` is set.
    pub shell_path: Option<String>,
}
44
45pub struct BashExtension {
46    cwd: PathBuf,
47    options: BashToolOptions,
48}
49
50impl BashExtension {
51    pub fn new(cwd: PathBuf) -> Self {
52        Self {
53            cwd,
54            options: BashToolOptions::default(),
55        }
56    }
57
58    pub fn with_options(cwd: PathBuf, options: BashToolOptions) -> Self {
59        Self { cwd, options }
60    }
61
62    pub fn with_shell_path(cwd: PathBuf, shell_path: String) -> Self {
63        Self {
64            cwd,
65            options: BashToolOptions {
66                shell_path: Some(shell_path),
67                ..BashToolOptions::default()
68            },
69        }
70    }
71}
72
73impl Extension for BashExtension {
74    fn name(&self) -> Cow<'static, str> {
75        "bash".into()
76    }
77
78    fn tools(&self) -> Vec<ToolDefinition> {
79        vec![ToolDefinition {
80            tool: Box::new(BashTool {
81                cwd: self.cwd.clone(),
82                shell_path: self.options.shell_path.clone(),
83                command_prefix: self.options.command_prefix.clone(),
84                operations: self.options.operations.clone(),
85            }),
86            snippet: "Execute bash commands (ls, grep, find, etc.)",
87            guidelines: &[],
88            prepare_arguments: None,
89            before_tool_call: None,
90            after_tool_call: None,
91            renderer: Some(std::sync::Arc::new(BashRenderer)),
92        }]
93    }
94}
95
/// The `bash` tool instance handed to the agent: a snapshot of the extension's
/// working directory and options taken when `tools()` was called.
struct BashTool {
    /// Working directory for every command; must exist when a command runs.
    cwd: PathBuf,
    /// Explicit shell binary, or `None` to auto-detect (used only for local execution).
    shell_path: Option<String>,
    /// Text placed on its own line before every command.
    command_prefix: Option<String>,
    /// When set, commands are delegated here instead of spawning a local shell.
    operations: Option<Arc<dyn BashOperations>>,
}
102
103// ── Constants ────────────────────────────────────────────────────
104
/// Maximum number of output lines returned to the model (the tail is kept).
const DEFAULT_MAX_LINES: usize = 2000;
/// Maximum number of output bytes returned to the model (the tail is kept).
const DEFAULT_MAX_BYTES: usize = 50 * 1024; // 50KB
/// Name of the sub-directory of the system temp dir that receives the full output
/// of truncated commands (used as a directory name despite the `PREFIX` suffix).
const BASH_TEMP_FILE_PREFIX: &str = "rab-bash";

/// Maximum age before stale temp files are cleaned up (24 hours).
const TEMP_FILE_MAX_AGE_SECS: u64 = 24 * 60 * 60;

/// Grace period after child exit (ms) — matching pi's EXIT_STDIO_GRACE_MS.
/// Detached descendants may keep stdout/stderr pipes open, so after the shell
/// exits the tool polls in steps of this length while output is still arriving.
const EXIT_STDIO_GRACE_MS: u64 = 100;
115
116// ── Shell resolution (matching pi's getShellConfig) ──────────────
117
/// Shell configuration: which shell binary to use and how to pass commands.
///
/// The child is spawned as `shell <args...> <command>`.
struct ShellConfig {
    shell: String,
    args: Vec<String>,
}

/// Resolve the shell to use for command execution.
///
/// Resolution order:
/// 1. The explicit `shell_path` (from settings), used verbatim.
/// 2. `/bin/bash`, if it exists.
/// 3. On Unix only: the first executable `bash` found on `PATH`.
/// 4. Otherwise `sh`, left for the OS to look up on `PATH` at spawn time.
///
/// Every configuration passes the command via `-c`. There is no Windows-specific
/// discovery (e.g. Git Bash) yet; on Windows this normally ends up at `sh`.
fn resolve_shell(shell_path: Option<&str>) -> ShellConfig {
    let shell = shell_path
        .map(str::to_string)
        .or_else(detect_bash)
        .unwrap_or_else(|| "sh".to_string());
    ShellConfig {
        shell,
        args: vec!["-c".to_string()],
    }
}

/// Locate a bash binary: `/bin/bash` first, then (Unix only) `bash` on `PATH`.
fn detect_bash() -> Option<String> {
    if Path::new("/bin/bash").exists() {
        return Some("/bin/bash".to_string());
    }

    #[cfg(unix)]
    {
        // Scan PATH in-process instead of spawning `which`. This runs for every
        // command on systems without /bin/bash (e.g. NixOS) and is called from
        // async code. A blocking child process there is wasteful, and `which`
        // itself is not guaranteed to be installed.
        let on_path = std::env::var_os("PATH")
            .and_then(|path_var| find_executable_in_path("bash", &path_var));
        if let Some(found) = on_path {
            return Some(found.to_string_lossy().into_owned());
        }
    }

    None
}

/// Return the first `dir/name` from the `PATH`-style list `path_var` that is a
/// regular file with at least one execute bit set (what `which` would report).
///
/// Empty `PATH` entries (which POSIX treats as the current directory) are
/// skipped, so a `bash` in the working directory is never picked up.
#[cfg(unix)]
fn find_executable_in_path(name: &str, path_var: &std::ffi::OsStr) -> Option<PathBuf> {
    use std::os::unix::fs::PermissionsExt;

    std::env::split_paths(path_var)
        .filter(|dir| !dir.as_os_str().is_empty())
        .map(|dir| dir.join(name))
        .find(|candidate| {
            candidate
                .metadata()
                .is_ok_and(|meta| meta.is_file() && meta.permissions().mode() & 0o111 != 0)
        })
}
171
172// ── Helpers ──────────────────────────────────────────────────────
173
/// Best-effort termination of a spawned command and all of its descendants.
///
/// `pid` is the PID of the shell started by `spawn_bash_command`; `0` means
/// "unknown" and is ignored. Failures are ignored on purpose because the
/// process may already be gone. The helper's stdio is discarded so messages such as
/// `kill: (-123): No such process` never end up on the terminal.
///
/// On Unix the shell leads its own process group (pgid == pid), so
/// `kill -- -<pid>` delivers SIGTERM (the default signal of `kill`) to the group.
#[cfg(unix)]
fn kill_process_group(pid: u32) {
    if pid == 0 {
        return;
    }
    let _ = std::process::Command::new("kill")
        .arg("--")
        .arg(format!("-{}", pid))
        .stdin(std::process::Stdio::null())
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status();
}

/// Windows: forcibly terminate the process tree rooted at `pid`
/// (`taskkill /F /T`), so cancelled or timed-out commands don't keep running.
/// `0` is ignored and failures are swallowed, as on Unix.
#[cfg(windows)]
fn kill_process_group(pid: u32) {
    if pid == 0 {
        return;
    }
    let _ = std::process::Command::new("taskkill")
        .args(["/F", "/T", "/PID"])
        .arg(pid.to_string())
        .stdin(std::process::Stdio::null())
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status();
}

/// Other platforms: no process-tree termination is available; this is a no-op.
#[cfg(not(any(unix, windows)))]
fn kill_process_group(pid: u32) {
    let _ = pid;
}
189
190/// Spawn a bash command with process group setup for clean cancellation.
191fn spawn_bash_command(
192    command: &str,
193    cwd: &std::path::Path,
194    shell_path: Option<&str>,
195) -> std::io::Result<tokio::process::Child> {
196    let shell_cfg = resolve_shell(shell_path);
197
198    #[cfg(unix)]
199    {
200        use std::os::unix::process::CommandExt;
201        let mut std_cmd = std::process::Command::new(&shell_cfg.shell);
202        std_cmd.args(&shell_cfg.args).arg(command).current_dir(cwd);
203        unsafe {
204            std_cmd.pre_exec(|| {
205                libc::setpgid(0, 0);
206                Ok(())
207            });
208        }
209        let mut tokio_cmd = tokio::process::Command::from(std_cmd);
210        tokio_cmd
211            .stdin(std::process::Stdio::null())
212            .stdout(std::process::Stdio::piped())
213            .stderr(std::process::Stdio::piped())
214            .spawn()
215    }
216    #[cfg(not(unix))]
217    {
218        tokio::process::Command::new(&shell_cfg.shell)
219            .args(&shell_cfg.args)
220            .arg(command)
221            .current_dir(cwd)
222            .stdin(std::process::Stdio::null())
223            .stdout(std::process::Stdio::piped())
224            .stderr(std::process::Stdio::piped())
225            .spawn()
226    }
227}
228
/// Sanitize command output for display/storage (matching pi's
/// sanitizeBinaryOutput + stripAnsi).
///
/// Removed:
/// * CSI sequences (`ESC [` or the 8-bit `U+009B`, then parameter/intermediate
///   characters, then a final character in `@`..=`~`), e.g. colours `\x1b[31m`.
/// * String sequences — OSC (`ESC ]`), DCS (`ESC P`), SOS (`ESC X`),
///   PM (`ESC ^`), APC (`ESC _`) — up to BEL or ST (`ESC \` or `U+009C`). An
///   unterminated one stops at the next newline (which is kept), so a stray
///   escape cannot swallow the rest of the output.
/// * Other escapes (`ESC`, optional intermediates, one final character),
///   e.g. `ESC 7` or `ESC ( B`.
/// * C0 control characters except tab, newline and carriage return, and the
///   Unicode interlinear annotation characters U+FFF9..=U+FFFB.
///
/// A malformed sequence (an unexpected character in the middle) is abandoned
/// and that character is processed as ordinary text. Escape state is not
/// carried across calls, so a sequence split between two chunks is only
/// partially removed.
fn sanitize_output(text: &str) -> String {
    /// Position relative to a terminal escape sequence.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum State {
        /// Ordinary text.
        Text,
        /// After `ESC`, waiting for the character that selects the sequence type.
        Escape,
        /// Inside a CSI sequence, waiting for its final character.
        Csi,
        /// Inside an OSC/DCS/SOS/PM/APC string, waiting for BEL or ST.
        StringSeq,
        /// Saw `ESC` inside a string sequence: `\` completes ST.
        StringSeqEscape,
    }

    let mut result = String::with_capacity(text.len());
    let mut state = State::Text;

    for c in text.chars() {
        match state {
            State::Text => {}
            State::Escape | State::StringSeqEscape => {
                if state == State::StringSeqEscape && c == '\\' {
                    state = State::Text;
                    continue;
                }
                match c {
                    '[' => {
                        state = State::Csi;
                        continue;
                    }
                    ']' | 'P' | 'X' | '^' | '_' => {
                        state = State::StringSeq;
                        continue;
                    }
                    // A repeated ESC restarts the sequence; 0x20..=0x2F are intermediates.
                    '\x1b' | ' '..='/' => {
                        state = State::Escape;
                        continue;
                    }
                    // Final character of a short escape such as `ESC 7` or `ESC ( B`.
                    '0'..='~' => {
                        state = State::Text;
                        continue;
                    }
                    // Malformed: abandon the sequence and treat `c` as text.
                    _ => state = State::Text,
                }
            }
            State::Csi => match c {
                // Parameter (0x30..=0x3F) and intermediate (0x20..=0x2F) characters.
                ' '..='?' => continue,
                // Final character ends the sequence.
                '@'..='~' => {
                    state = State::Text;
                    continue;
                }
                '\x1b' => {
                    state = State::Escape;
                    continue;
                }
                _ => state = State::Text,
            },
            State::StringSeq => match c {
                '\x07' | '\u{9c}' => {
                    state = State::Text;
                    continue;
                }
                '\x1b' => {
                    state = State::StringSeqEscape;
                    continue;
                }
                // Unterminated string: stop at the line break and keep it.
                '\n' => state = State::Text,
                _ => continue,
            },
        }

        // Ordinary text, or the character that ended a malformed sequence.
        match c {
            '\x1b' => state = State::Escape,
            '\u{9b}' => state = State::Csi,
            '\t' | '\n' | '\r' => result.push(c),
            '\0'..='\x1f' | '\u{fff9}'..='\u{fffb}' => {}
            _ => result.push(c),
        }
    }
    result
}
258
/// Human-readable byte count: plain bytes below 1024, otherwise KB or MB
/// (1024-based) with one decimal place, e.g. `512B`, `50.0KB`, `1.5MB`.
fn format_size(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = KIB * 1024;
    match bytes {
        b if b < KIB => format!("{}B", b),
        b if b < MIB => format!("{:.1}KB", b as f64 / KIB as f64),
        b => format!("{:.1}MB", b as f64 / MIB as f64),
    }
}
268
/// Truncation result for tail-based truncation (keep last N lines/bytes).
struct TailTruncation {
    /// The kept text. When truncated, the kept lines are joined with `\n`.
    content: String,
    /// Whether anything was dropped.
    truncated: bool,
    /// Number of lines in the input (as counted by `str::lines`).
    total_lines: usize,
    /// Number of lines in `content`.
    output_lines: usize,
    /// Byte length of `content`.
    output_bytes: usize,
    /// `"lines"` or `"bytes"` — which limit cut the output; `""` when not truncated.
    truncated_by: &'static str,
    /// `true` when even the last line alone exceeded `max_bytes`, so only its end was kept.
    last_line_partial: bool,
}

/// Truncate content from the tail, keeping complete lines that fit within limits.
///
/// Lines are taken from the end until either `max_lines` lines have been kept or
/// the next line (plus its `\n` separator) would exceed `max_bytes`. If the very
/// last line is already larger than `max_bytes`, its final `max_bytes` bytes are
/// kept instead. The cut is rounded forward to a UTF-8 character boundary, so it
/// can be slightly shorter.
fn truncate_tail(content: &str, max_lines: usize, max_bytes: usize) -> TailTruncation {
    let total_bytes = content.len();
    let lines: Vec<&str> = content.lines().collect();
    let total_lines = lines.len();

    if total_lines <= max_lines && total_bytes <= max_bytes {
        return TailTruncation {
            content: content.to_string(),
            truncated: false,
            total_lines,
            output_lines: total_lines,
            output_bytes: total_bytes,
            truncated_by: "",
            last_line_partial: false,
        };
    }

    let mut output: Vec<&str> = Vec::new();
    let mut byte_count: usize = 0;
    let mut truncated_by = "lines";
    let mut last_line_partial = false;

    for line in lines.iter().rev().take(max_lines) {
        // Every kept line except the first (i.e. the last of the input) needs a `\n`.
        let needed = line.len() + usize::from(!output.is_empty());

        if byte_count + needed > max_bytes {
            truncated_by = "bytes";
            if output.is_empty() {
                // Keep the end of an oversized last line. Round the cut up to a char
                // boundary: slicing inside a multi-byte UTF-8 sequence would panic.
                let mut start = line.len().saturating_sub(max_bytes);
                while !line.is_char_boundary(start) {
                    start += 1;
                }
                let tail = &line[start..];
                output.push(tail);
                byte_count = tail.len();
                last_line_partial = true;
            }
            break;
        }

        output.push(line);
        byte_count += needed;
    }

    output.reverse();

    // `total_bytes` also counts a trailing newline and the `\r` of CRLF endings,
    // which `str::lines` drops. If every line still fit in full, nothing was cut.
    if output.len() == total_lines && !last_line_partial {
        return TailTruncation {
            content: output.join("\n"),
            truncated: false,
            total_lines,
            output_lines: total_lines,
            output_bytes: byte_count,
            truncated_by: "",
            last_line_partial: false,
        };
    }

    TailTruncation {
        content: output.join("\n"),
        truncated: true,
        total_lines,
        output_lines: output.len(),
        output_bytes: byte_count,
        truncated_by,
        last_line_partial,
    }
}
342
343// ── Result formatting ────────────────────────────────────────────
344
345fn finish_bash_execution(
346    combined: &str,
347    exit_code: i32,
348    cancelled: bool,
349    timed_out: Option<u64>,
350    ctx: &yoagent::types::ToolContext,
351) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
352    let trunc = truncate_tail(combined, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
353
354    let mut result_text = if trunc.content.is_empty() {
355        "(no output)".to_string()
356    } else {
357        trunc.content.clone()
358    };
359
360    // Save full output to temp file if truncated
361    let full_output_path = if trunc.truncated {
362        let tmp_dir = temp_output_dir();
363        let _ = std::fs::create_dir_all(&tmp_dir);
364        let tmp_path = tmp_dir.join(format!("{}.log", uuid::Uuid::new_v4()));
365        let saved = std::fs::write(&tmp_path, combined).ok().map(|_| {
366            cleanup_stale_temp_files();
367            tmp_path
368        });
369
370        let start_line = trunc.total_lines - trunc.output_lines + 1;
371        let end_line = trunc.total_lines;
372
373        let notice = if trunc.truncated_by == "lines" {
374            format!(
375                "\n\n[Showing lines {}-{} of {}. Full output: {}]",
376                start_line,
377                end_line,
378                trunc.total_lines,
379                saved
380                    .as_ref()
381                    .map(|p| p.display().to_string())
382                    .unwrap_or_default()
383            )
384        } else {
385            format!(
386                "\n\n[Showing lines {}-{} of {} ({} limit). Full output: {}]",
387                start_line,
388                end_line,
389                trunc.total_lines,
390                format_size(DEFAULT_MAX_BYTES),
391                saved
392                    .as_ref()
393                    .map(|p| p.display().to_string())
394                    .unwrap_or_default()
395            )
396        };
397        result_text.push_str(&notice);
398        saved
399    } else {
400        None
401    };
402
403    // Build structured details
404    let details = if trunc.truncated || full_output_path.is_some() {
405        Some(serde_json::json!({
406            "truncation": {
407                "truncated": trunc.truncated,
408                "truncatedBy": trunc.truncated_by,
409                "totalLines": trunc.total_lines,
410                "outputLines": trunc.output_lines,
411                "outputBytes": trunc.output_bytes,
412                "lastLinePartial": trunc.last_line_partial,
413                "maxLines": DEFAULT_MAX_LINES,
414                "maxBytes": DEFAULT_MAX_BYTES,
415            },
416            "fullOutputPath": full_output_path.as_ref().map(|p| p.display().to_string()),
417        }))
418    } else {
419        None
420    };
421
422    let final_output = if cancelled {
423        format_status_output(&result_text, "Command aborted")
424    } else if let Some(secs) = timed_out {
425        format_status_output(
426            &result_text,
427            &format!("Command timed out after {} seconds", secs),
428        )
429    } else if exit_code != 0 {
430        format_status_output(
431            &result_text,
432            &format!("Command exited with code {}", exit_code),
433        )
434    } else {
435        emit_update(ctx, result_text.clone(), details.clone());
436        return Ok(into_tool_result(result_text, details));
437    };
438
439    emit_update(ctx, final_output.clone(), details.clone());
440    Err(yoagent::types::ToolError::Failed(final_output))
441}
442
// ── Bash tool execution: the `yoagent::types::AgentTool` impl for `BashTool` follows the renderer below ──
444
445// ── Bash tool renderer ───────────────────────────────────────────
446
447struct BashRenderer;
448
449impl ToolRenderer for BashRenderer {
450    fn render_call(
451        &self,
452        args: &serde_json::Value,
453        _width: usize,
454        theme: &dyn Theme,
455        _ctx: &ToolRenderContext,
456    ) -> Vec<String> {
457        let cmd = args
458            .get("command")
459            .and_then(|v| v.as_str())
460            .unwrap_or("...");
461        let timeout = args.get("timeout").and_then(|v| v.as_i64());
462        let timeout_suffix = timeout
463            .map(|t| theme.fg_key(ThemeKey::Muted, &format!(" (timeout {}s)", t)))
464            .unwrap_or_default();
465
466        vec![format!(
467            "{}{}",
468            theme.fg_key(ThemeKey::ToolTitle, &theme.bold(&format!("$ {}", cmd))),
469            timeout_suffix
470        )]
471    }
472
473    fn render_result(
474        &self,
475        content: &str,
476        width: usize,
477        theme: &dyn Theme,
478        ctx: &ToolRenderContext,
479    ) -> Vec<String> {
480        let mut lines: Vec<String> = Vec::new();
481
482        let clean = strip_context_truncation_footer(content)
483            .trim_end()
484            .to_string();
485        let all_lines: Vec<&str> = clean.lines().collect();
486
487        if all_lines.is_empty() || (all_lines.len() == 1 && all_lines[0].is_empty()) {
488            return lines;
489        }
490
491        let preview_count = 5;
492        let (preview_lines, hidden_line_count) = if ctx.expanded {
493            (all_lines.clone(), 0)
494        } else {
495            truncate_to_visual_lines(&all_lines, width, preview_count)
496        };
497
498        // ── Preview hint with dim/muted styling (matching pi's keyHint) ──
499        if !ctx.expanded && hidden_line_count > 0 {
500            if ctx.expand_key.is_empty() {
501                lines.push(theme.fg_key(
502                    ThemeKey::Muted,
503                    &format!("... {} earlier lines", hidden_line_count),
504                ));
505            } else {
506                // Pi pattern: muted prefix + dim key + muted suffix
507                // e.g. "... (12 earlier lines, \x1b[2mctrl+o\x1b[22m to expand)"
508                let prefix = theme.fg_key(
509                    ThemeKey::Muted,
510                    &format!("... ({} earlier lines, ", hidden_line_count),
511                );
512                let key_styled = theme.fg("dim", &ctx.expand_key);
513                let suffix = theme.fg_key(ThemeKey::Muted, " to expand)");
514                lines.push(format!("{}{}{}", prefix, key_styled, suffix));
515            }
516        }
517
518        let fg_key = if ctx.is_error { "error" } else { "toolOutput" };
519        for line in &preview_lines {
520            if line.is_empty() {
521                lines.push(String::new());
522            } else {
523                lines.push(theme.fg(fg_key, line));
524            }
525        }
526
527        if let Some(secs) = ctx.duration_secs {
528            if !lines.is_empty() {
529                lines.push(String::new());
530            }
531            let is_complete = ctx.exit_code.is_some() || ctx.cancelled;
532            let label = if is_complete { "Took" } else { "Elapsed" };
533            lines.push(theme.fg_key(ThemeKey::Muted, &format!("{} {:.1}s", label, secs)));
534        }
535
536        if ctx.was_truncated {
537            if !lines.is_empty() {
538                lines.push(String::new());
539            }
540            if let Some(ref path) = ctx.full_output_path {
541                lines.push(theme.fg(
542                    "warning",
543                    &format!("Output truncated. Full output: {}", path),
544                ));
545            } else {
546                lines.push(theme.fg_key(ThemeKey::Warning, "Output truncated."));
547            }
548        }
549
550        lines
551    }
552}
553
/// Remove the truncation notice appended to bash output, so the renderer can
/// show truncation its own way.
///
/// The notice is recognised only as the *last* line of an output of at least
/// three lines, and only if that line:
/// * (trimmed) starts with `[`;
/// * mentions `Showing lines` or `Showing last`;
/// * contains `Full output:`.
///
/// A blank separator line directly before it is removed as well, and the
/// remaining lines are re-joined with `\n`. Any other input is returned unchanged.
fn strip_context_truncation_footer(output: &str) -> String {
    fn is_truncation_footer(line: &str) -> bool {
        let line = line.trim();
        line.starts_with('[')
            && (line.contains("Showing lines") || line.contains("Showing last"))
            && line.contains("Full output:")
    }

    let mut lines: Vec<&str> = output.lines().collect();
    if lines.len() < 3 || !lines.last().is_some_and(|last| is_truncation_footer(last)) {
        return output.to_string();
    }

    lines.pop();
    if lines.last().is_some_and(|line| line.is_empty()) {
        lines.pop();
    }
    lines.join("\n")
}
574
575#[async_trait::async_trait]
576impl yoagent::types::AgentTool for BashTool {
577    fn name(&self) -> &str {
578        "bash"
579    }
580    fn label(&self) -> &str {
581        "bash"
582    }
583    fn description(&self) -> &str {
584        "Execute a bash command in the current working directory. Returns stdout and stderr. \
585         Output is truncated to last 2000 lines or 50KB (whichever is hit first). If \
586         truncated, full output is saved to a temp file. Optionally provide a timeout in seconds."
587    }
588    fn parameters_schema(&self) -> serde_json::Value {
589        serde_json::json!({
590            "type": "object",
591            "required": ["command"],
592            "properties": {
593                "command": {
594                    "type": "string",
595                    "description": "Bash command to execute"
596                },
597                "timeout": {
598                    "type": "number",
599                    "description": "Timeout in seconds (optional, no default timeout)"
600                }
601            }
602        })
603    }
604    async fn execute(
605        &self,
606        params: serde_json::Value,
607        ctx: yoagent::types::ToolContext,
608    ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
609        let command = params["command"].as_str().ok_or_else(|| {
610            yoagent::types::ToolError::InvalidArgs("Missing 'command' argument".into())
611        })?;
612        let timeout = params["timeout"].as_u64();
613        let started_at = Instant::now();
614
615        if ctx.cancel.is_cancelled() {
616            return Err(yoagent::types::ToolError::Cancelled);
617        }
618
619        // Apply command prefix if set
620        let effective_command = if let Some(ref prefix) = self.command_prefix {
621            format!("{}\n{}", prefix, command)
622        } else {
623            command.to_string()
624        };
625
626        // Check that the working directory exists
627        if !self.cwd.exists() {
628            return Err(yoagent::types::ToolError::Failed(format!(
629                "Working directory does not exist: {}\nCannot execute bash commands.",
630                self.cwd.display()
631            )));
632        }
633
634        // If custom operations are provided, delegate entirely
635        if let Some(ref ops) = self.operations {
636            let (output_tx, mut output_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
637            let ops_cancel = Cancel::new();
638
639            // Link yoagent cancellation to rab Cancel
640            let yo_cancel = ctx.cancel.clone();
641            let watch_cancel = ops_cancel.clone();
642            tokio::spawn(async move {
643                yo_cancel.cancelled().await;
644                watch_cancel.cancel();
645            });
646
647            let ops_command = effective_command.clone();
648            let ops_cwd = self.cwd.clone();
649            let ops = ops.clone();
650            let ops_handle = tokio::spawn(async move {
651                ops.exec(
652                    &ops_command,
653                    &ops_cwd,
654                    output_tx,
655                    Some(&ops_cancel),
656                    timeout,
657                    None,
658                )
659                .await
660            });
661
662            // Collect output from the channel
663            let mut combined = String::new();
664            while let Some(chunk) = output_rx.recv().await {
665                combined.push_str(&chunk);
666                emit_update(&ctx, combined.clone(), None);
667            }
668
669            let exit_code = ops_handle.await.unwrap_or(Ok(None)).unwrap_or(None);
670            let code = exit_code.unwrap_or(-1);
671
672            return finish_bash_execution(&combined, code, ctx.cancel.is_cancelled(), None, &ctx);
673        }
674
675        let mut child =
676            spawn_bash_command(&effective_command, &self.cwd, self.shell_path.as_deref()).map_err(
677                |e| yoagent::types::ToolError::Failed(format!("Failed to spawn command: {}", e)),
678            )?;
679
680        let pid = child.id().unwrap_or(0);
681
682        // Shared output buffer for streaming reads
683        let combined = Arc::new(TokioMutex::new(String::new()));
684        let combined_clone = combined.clone();
685
686        let stdout_pipe = child
687            .stdout
688            .take()
689            .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stdout".into()))?;
690        let stderr_pipe = child
691            .stderr
692            .take()
693            .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stderr".into()))?;
694
695        use tokio::io::AsyncReadExt;
696        let read_task = tokio::spawn(async move {
697            let mut stdout_buf = vec![0u8; 65536];
698            let mut stderr_buf = vec![0u8; 65536];
699            let mut stdout_reader = stdout_pipe;
700            let mut stderr_reader = stderr_pipe;
701            let mut stdout_done = false;
702            let mut stderr_done = false;
703            loop {
704                tokio::select! {
705                    result = stdout_reader.read(&mut stdout_buf), if !stdout_done => {
706                        match result {
707                            Ok(0) => stdout_done = true,
708                            Ok(n) => {
709                                let text = String::from_utf8_lossy(&stdout_buf[..n]);
710                                let sanitized = sanitize_output(&text);
711                                let mut out = combined_clone.lock().await;
712                                out.push_str(&sanitized);
713                            }
714                            Err(_) => stdout_done = true,
715                        }
716                    }
717                    result = stderr_reader.read(&mut stderr_buf), if !stderr_done => {
718                        match result {
719                            Ok(0) => stderr_done = true,
720                            Ok(n) => {
721                                let text = String::from_utf8_lossy(&stderr_buf[..n]);
722                                let sanitized = sanitize_output(&text);
723                                let mut out = combined_clone.lock().await;
724                                out.push_str(&sanitized);
725                            }
726                            Err(_) => stderr_done = true,
727                        }
728                    }
729                }
730                if stdout_done && stderr_done {
731                    break;
732                }
733            }
734        });
735
736        // ── PID tracking for cleanup on shutdown signals ──
737        let _pid_guard = ProcessGuard::new(pid);
738
739        // Set up cancellation monitor: kill the process group if cancelled
740        let cancelled = Arc::new(AtomicBool::new(false));
741        let cancel_flag = cancelled.clone();
742        let yo_cancel = ctx.cancel.clone();
743        let _cancel_monitor: tokio::task::JoinHandle<()> = tokio::spawn(async move {
744            yo_cancel.cancelled().await;
745            cancel_flag.store(true, Ordering::SeqCst);
746            kill_process_group(pid);
747        });
748
749        // Send initial empty update
750        if let Some(ref on_update) = ctx.on_update {
751            on_update(yoagent::types::ToolResult {
752                content: vec![],
753                details: serde_json::Value::Null,
754            });
755        }
756
757        // Wait for the process to exit, with optional timeout and streaming updates
758        let timeout_dur = timeout.map(std::time::Duration::from_secs);
759        let throttle_ms = 100u64;
760        let mut last_update_at = Instant::now();
761
762        let exit_code: i32;
763
764        loop {
765            if cancelled.load(Ordering::SeqCst) {
766                kill_process_group(pid);
767                read_task.abort();
768                let combined_str = combined.lock().await.clone();
769                return finish_bash_execution(&combined_str, -1, true, None, &ctx);
770            }
771
772            if let Some(dur) = timeout_dur
773                && started_at.elapsed() > dur
774            {
775                kill_process_group(pid);
776                read_task.abort();
777                let combined_str = combined.lock().await.clone();
778                return finish_bash_execution(&combined_str, -1, false, timeout, &ctx);
779            }
780
781            if last_update_at.elapsed().as_millis() as u64 >= throttle_ms {
782                let out = combined.lock().await.clone();
783                if !out.is_empty() {
784                    last_update_at = Instant::now();
785                    emit_update(&ctx, out, None);
786                }
787            }
788
789            match child.try_wait() {
790                Ok(Some(status)) => {
791                    exit_code = status.code().unwrap_or(-1);
792                    // Idle grace period — after child exit, wait for pipes to go idle
793                    let mut last_len = combined.lock().await.len();
794                    loop {
795                        tokio::time::sleep(std::time::Duration::from_millis(EXIT_STDIO_GRACE_MS))
796                            .await;
797                        let new_len = combined.lock().await.len();
798                        if new_len == last_len {
799                            break;
800                        }
801                        last_len = new_len;
802                    }
803                    read_task.abort();
804                    break;
805                }
806                Ok(None) => {
807                    tokio::time::sleep(std::time::Duration::from_millis(throttle_ms)).await;
808                }
809                Err(_) => {
810                    read_task.await.ok();
811                    exit_code = -1;
812                    break;
813                }
814            }
815        }
816
817        let combined_str = combined.lock().await.clone();
818        if !combined_str.is_empty() {
819            emit_update(&ctx, combined_str.clone(), None);
820        }
821
822        finish_bash_execution(&combined_str, exit_code, false, None, &ctx)
823    }
824}
825
826/// Remove temp files in the bash output directory that are older than the max age.
827/// This is best-effort — failures are silently ignored.
828fn cleanup_stale_temp_files() {
829    let dir = temp_output_dir();
830    let Ok(entries) = std::fs::read_dir(&dir) else {
831        return;
832    };
833    let Ok(cutoff) = std::time::SystemTime::now()
834        .checked_sub(std::time::Duration::from_secs(TEMP_FILE_MAX_AGE_SECS))
835        .ok_or(())
836    else {
837        return;
838    };
839    for entry in entries.flatten() {
840        let path = entry.path();
841        if path.extension().is_none_or(|e| e != "log") {
842            continue;
843        }
844        if let Ok(metadata) = path.metadata()
845            && let Ok(modified) = metadata.modified()
846            && modified < cutoff
847        {
848            let _ = std::fs::remove_file(&path);
849        }
850    }
851}
852
853/// Return the directory where truncated bash output is saved.
854fn temp_output_dir() -> PathBuf {
855    std::env::temp_dir().join(BASH_TEMP_FILE_PREFIX)
856}
857
/// Append `status_msg` (cancelled / timed out / exit code) to the command
/// output, separated by a blank line. When there is no real output (empty text
/// or the `(no output)` placeholder), only the status message is returned.
fn format_status_output(result_text: &str, status_msg: &str) -> String {
    match result_text {
        "" | "(no output)" => status_msg.to_owned(),
        output => format!("{}\n\n{}", output, status_msg),
    }
}
866
867/// Build a ToolResult with text content and optional details.
868fn into_tool_result(
869    text: String,
870    details: Option<serde_json::Value>,
871) -> yoagent::types::ToolResult {
872    yoagent::types::ToolResult {
873        content: vec![yoagent::types::Content::Text { text }],
874        details: details.unwrap_or(serde_json::Value::Null),
875    }
876}
877
878/// Send an update to the context callback.
879fn emit_update(
880    ctx: &yoagent::types::ToolContext,
881    text: String,
882    details: Option<serde_json::Value>,
883) {
884    if let Some(ref on_update) = ctx.on_update {
885        on_update(into_tool_result(text, details));
886    }
887}
888
889// ── PID tracking for cleanup on shutdown signals ────────────────
890// Matching pi's trackDetachedChildPid / untrackDetachedChildPid.
891// On SIGTERM/SIGHUP, all tracked PIDs are killed before exit.
892
893use std::sync::Mutex;
894
/// PIDs of live child commands; `kill_tracked_children` signals each one's process group.
static TRACKED_PIDS: Mutex<Vec<u32>> = Mutex::new(Vec::new());
896
897fn track_pid(pid: u32) {
898    if let Ok(mut pids) = TRACKED_PIDS.lock() {
899        pids.push(pid);
900    }
901}
902
903fn untrack_pid(pid: u32) {
904    if let Ok(mut pids) = TRACKED_PIDS.lock() {
905        pids.retain(|&p| p != pid);
906    }
907}
908
909/// Kill all tracked child process groups. Called on SIGTERM/SIGHUP.
910pub fn kill_tracked_children() {
911    let pids: Vec<u32> = TRACKED_PIDS.lock().map(|p| p.clone()).unwrap_or_default();
912    for pid in pids {
913        kill_process_group(pid);
914    }
915}
916
917struct ProcessGuard {
918    pid: u32,
919}
920
921impl ProcessGuard {
922    fn new(pid: u32) -> Self {
923        if pid > 0 {
924            track_pid(pid);
925        }
926        Self { pid }
927    }
928}
929
930impl Drop for ProcessGuard {
931    fn drop(&mut self) {
932        if self.pid > 0 {
933            untrack_pid(self.pid);
934        }
935    }
936}
937
#[cfg(test)]
mod tests {
    use super::*;
    use yoagent::AgentTool;

    /// A `ToolContext` with a fresh, uncancelled token and no update/progress callbacks.
    fn tool_ctx() -> yoagent::types::ToolContext {
        yoagent::types::ToolContext {
            tool_call_id: "id".into(),
            tool_name: "bash".into(),
            cancel: tokio_util::sync::CancellationToken::new(),
            on_update: None,
            on_progress: None,
        }
    }

    /// Concatenate the text items of a tool result, skipping any non-text content.
    fn yo_msg_text(content: &[yoagent::types::Content]) -> String {
        content
            .iter()
            .filter_map(|c| {
                if let yoagent::types::Content::Text { text } = c {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
            .join("")
    }

    /// A `BashTool` running in the system temp dir with no explicit shell path, no command
    /// prefix, and no custom operations.
    fn make_tool() -> BashTool {
        BashTool {
            cwd: std::env::temp_dir(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        }
    }

    #[tokio::test]
    async fn runs_simple_command() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("hello"));
    }

    #[tokio::test]
    async fn captures_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo err >&2"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("err"));
    }

    #[tokio::test]
    async fn cancel_aborts() {
        let tool = make_tool();
        // Cancelled before `execute` is even called.
        let cancel = tokio_util::sync::CancellationToken::new();
        cancel.cancel();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10"}),
                yoagent::types::ToolContext {
                    tool_call_id: "id".into(),
                    tool_name: "bash".into(),
                    cancel,
                    on_update: None,
                    on_progress: None,
                },
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Cancelled") || err.contains("aborted"),
            "expected cancellation error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn timeout_works() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"), "got: {}", err);
    }

    #[test]
    fn test_truncate_tail_no_truncation() {
        let result = truncate_tail("hello\nworld\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello\nworld\n");
    }

    #[test]
    fn test_truncate_tail_by_lines() {
        let content: String = (1..=5000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.starts_with("line 3001"));
        assert_eq!(result.content.lines().count(), 2000);
    }

    #[test]
    fn test_truncate_tail_by_bytes() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
        assert!(result.content.lines().count() < 100);
    }

    #[test]
    fn test_truncate_tail_partial_last_line() {
        let content = format!("short\n{}\n", "x".repeat(60000));
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(!result.content.starts_with("short"));
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_empty() {
        let result = truncate_tail("", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "");
    }

    #[tokio::test]
    async fn exit_code_nonzero() {
        let tool = make_tool();
        let result = tool
            .execute(serde_json::json!({"command": "exit 42"}), tool_ctx())
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exited with code 42"), "got: {}", err);
    }

    #[tokio::test]
    async fn exit_code_with_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo before && exit 1"}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("before"), "got: {}", err);
        assert!(err.contains("exited with code 1"), "got: {}", err);
    }

    #[tokio::test]
    async fn no_output() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "true"}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("(no output)"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn combined_stdout_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(
                serde_json::json!({"command": "echo out; echo err >&2"}),
                tool_ctx(),
            )
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("out"),
            "got: {}",
            yo_msg_text(&output.content)
        );
        assert!(
            yo_msg_text(&output.content).contains("err"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn runs_in_cwd() {
        let tmp = std::env::temp_dir().join(format!("rab-bash-cwd-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&tmp).unwrap();
        std::fs::write(tmp.join("marker.txt"), "hello").unwrap();

        let tool = BashTool {
            cwd: tmp.clone(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        };
        let output = tool
            .execute(serde_json::json!({"command": "cat marker.txt"}), tool_ctx())
            .await;
        // Remove the scratch directory before asserting so a failure doesn't leak it.
        let _ = std::fs::remove_dir_all(&tmp);
        let text = yo_msg_text(&output.unwrap().content);
        assert!(text.contains("hello"), "got: {}", text);
    }

    #[tokio::test]
    async fn missing_command_errors() {
        let tool = make_tool();
        let result = tool.execute(serde_json::json!({}), tool_ctx()).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("command"), "got: {}", err);
    }

    #[tokio::test]
    async fn timeout_with_partial_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo start && sleep 10 && echo end", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"), "got: {}", err);
    }

    #[tokio::test]
    async fn cancel_during_long_command() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        let cancel_ctx = cancel.clone();

        let handle = tokio::spawn(async move {
            tool.execute(
                serde_json::json!({"command": "sleep 30"}),
                yoagent::types::ToolContext {
                    tool_call_id: "id".into(),
                    tool_name: "bash".into(),
                    cancel: cancel_ctx,
                    on_update: None,
                    on_progress: None,
                },
            )
            .await
        });

        // Give the command time to start before cancelling it mid-run.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("aborted") || err.contains("Cancelled"),
            "expected cancellation error, got: {}",
            err
        );
    }

    #[test]
    fn test_truncate_tail_exact_line_fit() {
        let lines: String = (1..=2000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content.lines().count(), 2000);
    }

    #[test]
    fn test_truncate_tail_one_over_line_limit() {
        let lines: String = (1..=2001).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.content.lines().count(), 2000);
        assert!(result.content.starts_with("line 2"));
    }

    #[test]
    fn test_truncate_tail_exact_byte_fit() {
        let line = "a".repeat(50000);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(!result.truncated);
    }

    #[test]
    fn test_truncate_tail_one_byte_over() {
        let line = "a".repeat(50001);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_single_line_under_limit() {
        let result = truncate_tail("hello world", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello world");
    }

    #[test]
    fn test_truncate_tail_trailing_newline() {
        let result = truncate_tail("a\nb\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb\n");
    }

    #[test]
    fn test_truncate_tail_no_trailing_newline() {
        let result = truncate_tail("a\nb", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb");
    }

    #[test]
    fn test_truncate_tail_single_line_exceeds_limit() {
        let content = "x".repeat(60000);
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.last_line_partial);
        assert_eq!(result.content.len(), 50000);
        assert!(result.content.ends_with("x".repeat(50000).as_str()));
    }

    #[test]
    fn test_truncate_tail_byte_count_respects_newlines() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.output_bytes <= 50000);
    }

    #[tokio::test]
    async fn truncated_by_lines_shows_footer() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("Showing lines"),
            "got: {}",
            yo_msg_text(&output.content)
        );
        assert!(
            yo_msg_text(&output.content).contains("Full output:"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn small_output_no_footer() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(!yo_msg_text(&output.content).contains("Output truncated"));
        assert!(!yo_msg_text(&output.content).contains("Full output:"));
    }

    #[tokio::test]
    async fn truncated_saves_temp_file() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("/rab-bash/"),
            "expected temp file path with /rab-bash/, got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[test]
    fn test_cleanup_stale_temp_files_nonexistent_dir() {
        // Best-effort cleanup must not panic, whether or not the output directory exists.
        cleanup_stale_temp_files();
    }

    #[test]
    fn test_truncate_tail_many_short_lines() {
        let content: String = (1..=10000).map(|i| format!("{}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "lines");
        assert_eq!(result.output_lines, 2000);
        assert!(
            result.content.starts_with("8001"),
            "starts with: {:?}",
            // `get` so the failure message itself cannot panic on short content.
            result.content.get(..10).unwrap_or(result.content.as_str())
        );
    }

    #[test]
    fn test_truncate_tail_lines_and_bytes_both_exceeded() {
        let content: String = (1..=5000)
            .map(|i| format!("line {} {}\n", i, "x".repeat(100)))
            .collect();
        let result = truncate_tail(&content, 2000, 30000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "bytes");
        assert!(result.output_lines < 2000);
    }

    // ── ProcessGuard tests ──────────────────────────────────────
    // These read TRACKED_PIDS under the lock but assert only after releasing it: a failed
    // assertion while holding the guard would poison the mutex for every other test.

    #[test]
    fn test_process_guard_tracks_pid() {
        // Well above the PID range of common OSes (Linux caps PIDs at 2^22), so it cannot
        // collide with a real child tracked by a concurrently running execute test.
        let pid = 999_999_001u32;
        {
            let _guard = ProcessGuard::new(pid);
            let tracked = TRACKED_PIDS.lock().unwrap().contains(&pid);
            assert!(tracked);
        }
        let tracked = TRACKED_PIDS.lock().unwrap().contains(&pid);
        assert!(!tracked);
    }

    #[test]
    fn test_process_guard_zero_pid() {
        {
            let _guard = ProcessGuard::new(0);
            let tracked = TRACKED_PIDS.lock().unwrap().contains(&0);
            assert!(!tracked);
        }
    }
}