Skip to main content

rab/builtin/
bash.rs

1use crate::agent::extension::{Cancel, Extension, ToolDefinition};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::tui::Theme;
4use crate::tui::ThemeKey;
5use crate::tui::visual_truncate::truncate_to_visual_lines;
6use async_trait::async_trait;
7
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::time::Instant;
14use tokio::sync::{Mutex as TokioMutex, mpsc::UnboundedSender};
15
16// ── BashOperations (pluggable) ────────────────────────────────────
17
/// Pluggable operations for the bash tool (matching pi's BashOperations).
/// Override these to delegate command execution to remote systems (for example SSH).
#[async_trait]
pub trait BashOperations: Send + Sync {
    /// Execute a command and stream output via the sender.
    ///
    /// - `command`: the full command text, already including any configured command prefix.
    /// - `cwd`: working directory to run the command in.
    /// - `on_data`: output sink. The bash tool keeps reading until every clone of this sender
    ///   has been dropped, so implementations should not keep it alive after `exec` returns.
    /// - `signal`: cancellation handle; it is cancelled when the tool call is cancelled.
    /// - `timeout`: requested timeout in seconds, if any. The bash tool does not enforce it on
    ///   this path, so implementations are responsible for honouring it.
    /// - `env`: extra environment variables (the bash tool currently always passes `None`).
    ///
    /// Returns the exit code (0 = success, non-zero = error, None = killed).
    async fn exec(
        &self,
        command: &str,
        cwd: &Path,
        on_data: UnboundedSender<String>,
        signal: Option<&Cancel>,
        timeout: Option<u64>,
        env: Option<HashMap<String, String>>,
    ) -> Result<Option<i32>, anyhow::Error>;
}
34
/// Configuration for the bash tool, passed to `BashExtension::with_options`.
#[derive(Clone, Default)]
pub struct BashToolOptions {
    /// Custom operations for command execution. Default (`None`): local shell.
    pub operations: Option<Arc<dyn BashOperations>>,
    /// Command prefix prepended to every command on its own line (for example shell setup
    /// commands).
    pub command_prefix: Option<String>,
    /// Optional explicit shell path from settings; `None` auto-detects bash, falling back to
    /// `sh`.
    pub shell_path: Option<String>,
}
44
/// Extension that contributes the `bash` tool.
pub struct BashExtension {
    /// Working directory in which the tool runs commands.
    cwd: PathBuf,
    /// Options copied into every `BashTool` this extension creates.
    options: BashToolOptions,
}
49
50impl BashExtension {
51    pub fn new(cwd: PathBuf) -> Self {
52        Self {
53            cwd,
54            options: BashToolOptions::default(),
55        }
56    }
57
58    pub fn with_options(cwd: PathBuf, options: BashToolOptions) -> Self {
59        Self { cwd, options }
60    }
61
62    pub fn with_shell_path(cwd: PathBuf, shell_path: String) -> Self {
63        Self {
64            cwd,
65            options: BashToolOptions {
66                shell_path: Some(shell_path),
67                ..BashToolOptions::default()
68            },
69        }
70    }
71}
72
73impl Extension for BashExtension {
74    fn name(&self) -> Cow<'static, str> {
75        "bash".into()
76    }
77
78    fn as_any(&self) -> &dyn std::any::Any {
79        self
80    }
81
82    fn tools(&self) -> Vec<ToolDefinition> {
83        vec![ToolDefinition {
84            tool: Box::new(BashTool {
85                cwd: self.cwd.clone(),
86                shell_path: self.options.shell_path.clone(),
87                command_prefix: self.options.command_prefix.clone(),
88                operations: self.options.operations.clone(),
89            }),
90            snippet: "Execute bash commands (ls, grep, find, etc.)",
91            guidelines: &[],
92            prepare_arguments: None,
93            before_tool_call: None,
94            after_tool_call: None,
95            renderer: Some(std::sync::Arc::new(BashRenderer)),
96        }]
97    }
98}
99
/// The `bash` agent tool: runs commands in `cwd`, either locally or through custom operations.
struct BashTool {
    /// Working directory for every command; `execute` fails if it does not exist.
    cwd: PathBuf,
    /// Explicit shell binary; `None` lets `resolve_shell` pick one.
    shell_path: Option<String>,
    /// Text prepended to every command, separated from it by a newline.
    command_prefix: Option<String>,
    /// When set, commands are delegated to these operations instead of a local shell.
    operations: Option<Arc<dyn BashOperations>>,
}
106
// ── Constants ────────────────────────────────────────────────────

/// Tail-truncation line limit for command output: keep at most the last 2000 lines.
const DEFAULT_MAX_LINES: usize = 2_000;
/// Tail-truncation byte limit for command output: 50 KiB.
const DEFAULT_MAX_BYTES: usize = 51_200;
/// Name of the directory (inside the system temp dir) where full outputs are saved.
const BASH_TEMP_FILE_PREFIX: &str = "rab-bash";

/// Saved outputs older than this are removed: 24 hours, in seconds.
const TEMP_FILE_MAX_AGE_SECS: u64 = 86_400;

/// Grace period after child exit (ms), matching pi's EXIT_STDIO_GRACE_MS.
/// Detached descendants may keep the stdout/stderr pipes open, so after the shell exits the
/// captured output is re-checked at this interval while it is still growing.
const EXIT_STDIO_GRACE_MS: u64 = 100;
119
120// ── Shell resolution (matching pi's getShellConfig) ──────────────
121
/// Shell configuration: which shell binary to use and how to pass commands.
struct ShellConfig {
    /// Shell executable: an absolute path, or a bare name resolved by the OS at spawn time.
    shell: String,
    /// Arguments placed before the command string (always `-c`).
    args: Vec<String>,
}

/// Resolve the shell used for command execution.
///
/// Resolution order (matching pi's getShellConfig):
/// 1. the explicit `shell_path` from settings, used verbatim;
/// 2. `/bin/bash`, if it exists;
/// 3. on Windows, Git for Windows' `bash.exe` under `%ProgramFiles%` / `%ProgramFiles(x86)%`;
/// 4. the first executable `bash` (`bash.exe` on Windows) in an absolute `PATH` directory;
/// 5. `sh`, left for the OS to resolve when spawning.
///
/// Every shell is invoked as `<shell> -c <command>`.
fn resolve_shell(shell_path: Option<&str>) -> ShellConfig {
    let shell = match shell_path {
        Some(path) => path.to_string(),
        None => find_bash().unwrap_or_else(|| "sh".to_string()),
    };
    ShellConfig {
        shell,
        args: vec!["-c".to_string()],
    }
}

/// Locate a bash executable following steps 2–4 of `resolve_shell`'s order.
///
/// Searches `PATH` in-process instead of spawning `which`, which is not installed everywhere
/// and cost an extra process on every command.
fn find_bash() -> Option<String> {
    if Path::new("/bin/bash").exists() {
        return Some("/bin/bash".to_string());
    }

    #[cfg(windows)]
    {
        for var in ["ProgramFiles", "ProgramFiles(x86)"] {
            if let Some(base) = std::env::var_os(var) {
                let candidate = std::path::PathBuf::from(base)
                    .join("Git")
                    .join("bin")
                    .join("bash.exe");
                if candidate.is_file() {
                    return Some(candidate.to_string_lossy().into_owned());
                }
            }
        }
    }

    let exe_name = if cfg!(windows) { "bash.exe" } else { "bash" };
    let search_path = std::env::var_os("PATH")?;
    std::env::split_paths(&search_path)
        // Relative entries (including empty ones, which mean ".") would resolve against the
        // working directory and could pick up an untrusted `bash` from the project tree.
        .filter(|dir| dir.is_absolute())
        .map(|dir| dir.join(exe_name))
        .find(|candidate| is_executable_file(candidate))
        .map(|found| found.to_string_lossy().into_owned())
}

/// Whether `path` is a regular file with an execute permission bit set (on non-Unix
/// platforms: whether it is a regular file).
fn is_executable_file(path: &Path) -> bool {
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        path.metadata()
            .is_ok_and(|meta| meta.is_file() && (meta.permissions().mode() & 0o111) != 0)
    }
    #[cfg(not(unix))]
    {
        path.is_file()
    }
}
175
176// ── Helpers ──────────────────────────────────────────────────────
177
/// Send SIGTERM (the `kill` utility's default signal) to every process in the process group
/// whose id is `pid`. Each command is spawned as the leader of its own group, so this reaches
/// the shell and all of its descendants.
///
/// Best effort: failures are ignored. The utility's stdin/stdout/stderr are detached because
/// it would otherwise inherit our terminal. Without that, a "No such process" message for an
/// already-exited group would be printed over the UI. That happens routinely, because the
/// cancellation monitor and the wait loop can both kill the same group.
#[cfg(unix)]
fn kill_process_group(pid: u32) {
    // pid 0 would become `kill -- -0`, i.e. our own process group.
    if pid == 0 {
        return;
    }
    let _ = std::process::Command::new("kill")
        .arg("--")
        .arg(format!("-{}", pid))
        .stdin(std::process::Stdio::null())
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status();
}

/// Process groups are a Unix concept; on other platforms this does nothing.
#[cfg(not(unix))]
fn kill_process_group(pid: u32) {
    let _ = pid;
}
193
194/// Spawn a bash command with process group setup for clean cancellation.
195fn spawn_bash_command(
196    command: &str,
197    cwd: &std::path::Path,
198    shell_path: Option<&str>,
199) -> std::io::Result<tokio::process::Child> {
200    let shell_cfg = resolve_shell(shell_path);
201
202    #[cfg(unix)]
203    {
204        use std::os::unix::process::CommandExt;
205        let mut std_cmd = std::process::Command::new(&shell_cfg.shell);
206        std_cmd.args(&shell_cfg.args).arg(command).current_dir(cwd);
207        unsafe {
208            std_cmd.pre_exec(|| {
209                libc::setpgid(0, 0);
210                Ok(())
211            });
212        }
213        let mut tokio_cmd = tokio::process::Command::from(std_cmd);
214        tokio_cmd
215            .stdin(std::process::Stdio::null())
216            .stdout(std::process::Stdio::piped())
217            .stderr(std::process::Stdio::piped())
218            .spawn()
219    }
220    #[cfg(not(unix))]
221    {
222        tokio::process::Command::new(&shell_cfg.shell)
223            .args(&shell_cfg.args)
224            .arg(command)
225            .current_dir(cwd)
226            .stdin(std::process::Stdio::null())
227            .stdout(std::process::Stdio::piped())
228            .stderr(std::process::Stdio::piped())
229            .spawn()
230    }
231}
232
/// Sanitize raw command output for display and storage (after pi's sanitizeBinaryOutput +
/// stripAnsi).
///
/// - Removes ANSI/VT escape sequences:
///   - CSI (`ESC [` or U+009B), ending at a final byte `@`..=`~`;
///   - control strings (`ESC ]`, `ESC P`, `ESC X`, `ESC ^`, `ESC _`), ending at BEL, ST
///     (`ESC \`) or U+009C;
///   - short `ESC` (+ intermediate bytes) + final-byte sequences such as `ESC 7` or `ESC ( B`.
/// - Drops C0 control characters except tab, LF and CR.
/// - Drops the interlinear annotation characters U+FFF9..=U+FFFB.
///
/// A sequence is abandoned as soon as a character that cannot belong to it appears, and that
/// character is then handled as ordinary text. Because of this, a stray or malformed escape
/// never swallows real output such as newlines. Every call starts in the plain-text state.
fn sanitize_output(text: &str) -> String {
    /// Position inside an escape sequence.
    #[derive(Clone, Copy)]
    enum State {
        /// Ordinary text.
        Text,
        /// Just saw ESC.
        Escape,
        /// Inside a control sequence: parameter/intermediate bytes until a final byte.
        Csi,
        /// ESC followed by intermediate bytes (0x20..=0x2F) until a final byte (0x30..=0x7E).
        EscIntermediate,
        /// Inside an OSC/DCS/SOS/PM/APC string payload.
        ControlString,
        /// Saw ESC inside a control string; a following `\` completes the ST terminator.
        ControlStringEscape,
    }

    let mut result = String::with_capacity(text.len());
    let mut state = State::Text;
    for c in text.chars() {
        // `continue` re-examines `c` in the new state after a sequence is abandoned; every
        // such path ends in `Text`, which never continues, so this loop runs at most 3 times.
        loop {
            match state {
                State::Text => match c {
                    '\x1b' => state = State::Escape,
                    '\u{9b}' => state = State::Csi,
                    '\t' | '\n' | '\r' => result.push(c),
                    '\0'..='\x1f' | '\u{fff9}'..='\u{fffb}' => {}
                    _ => result.push(c),
                },
                State::Escape => match c {
                    '[' => state = State::Csi,
                    ']' | 'P' | 'X' | '^' | '_' => state = State::ControlString,
                    '\x20'..='\x2f' => state = State::EscIntermediate,
                    // Two-character sequence such as `ESC 7` or `ESC c`: fully consumed.
                    '\x30'..='\x7e' => state = State::Text,
                    _ => {
                        state = State::Text;
                        continue;
                    }
                },
                State::Csi => match c {
                    '\x20'..='\x3f' => {}
                    '\x40'..='\x7e' => state = State::Text,
                    _ => {
                        state = State::Text;
                        continue;
                    }
                },
                State::EscIntermediate => match c {
                    '\x20'..='\x2f' => {}
                    '\x30'..='\x7e' => state = State::Text,
                    _ => {
                        state = State::Text;
                        continue;
                    }
                },
                State::ControlString => match c {
                    '\x07' | '\u{9c}' => state = State::Text,
                    '\x1b' => state = State::ControlStringEscape,
                    // Titles/hyperlinks never span lines: treat a newline as an unterminated
                    // string rather than dropping the rest of the output.
                    '\n' => {
                        state = State::Text;
                        continue;
                    }
                    _ => {}
                },
                State::ControlStringEscape => {
                    if c == '\\' {
                        state = State::Text;
                    } else {
                        // ESC not followed by `\` starts a new escape sequence.
                        state = State::Escape;
                        continue;
                    }
                }
            }
            break;
        }
    }
    result
}
262
/// Format a byte count with binary (1024-based) units: `512B`, `1.5KB`, `2.0MB`.
/// KB/MB values are shown with one decimal place.
fn format_size(bytes: usize) -> String {
    const KIB: f64 = 1024.0;
    const MIB: f64 = 1024.0 * 1024.0;
    match bytes {
        0..=1023 => format!("{}B", bytes),
        1024..=1_048_575 => format!("{:.1}KB", bytes as f64 / KIB),
        _ => format!("{:.1}MB", bytes as f64 / MIB),
    }
}
272
/// Result of `truncate_tail`: the kept tail of the output plus bookkeeping for the notice.
struct TailTruncation {
    /// Kept text: the last lines of the input joined with `\n`.
    content: String,
    /// `true` when the input exceeded `max_lines` or `max_bytes`.
    truncated: bool,
    /// Number of lines in the input, as counted by `str::lines`.
    total_lines: usize,
    /// Number of lines in `content`.
    output_lines: usize,
    /// Byte length of `content`.
    output_bytes: usize,
    /// `""` when not truncated; otherwise `"bytes"` if the byte budget stopped the scan and
    /// `"lines"` if it did not.
    truncated_by: &'static str,
    /// `true` when even the last line alone exceeded `max_bytes`, so only its tail is kept.
    last_line_partial: bool,
}

/// Truncate content from the tail, keeping complete lines that fit within both limits.
///
/// Lines are taken from the end until `max_lines` lines have been kept, or until the next line
/// (plus its `\n` separator) would exceed `max_bytes`. If the very last line is too long on
/// its own, only its trailing bytes are kept. The cut point is moved forward to the next UTF-8
/// character boundary, so slightly fewer than `max_bytes` bytes may remain and slicing cannot
/// panic on multi-byte text.
fn truncate_tail(content: &str, max_lines: usize, max_bytes: usize) -> TailTruncation {
    let total_bytes = content.len();
    let lines: Vec<&str> = content.lines().collect();
    let total_lines = lines.len();

    if total_lines <= max_lines && total_bytes <= max_bytes {
        return TailTruncation {
            content: content.to_string(),
            truncated: false,
            total_lines,
            output_lines: total_lines,
            output_bytes: total_bytes,
            truncated_by: "",
            last_line_partial: false,
        };
    }

    let mut kept: Vec<&str> = Vec::new();
    let mut byte_count: usize = 0;
    let mut hit_byte_limit = false;
    let mut last_line_partial = false;

    for &line in lines.iter().rev().take(max_lines) {
        // Every kept line after the first costs one extra byte for its `\n` separator.
        let needed = if kept.is_empty() {
            line.len()
        } else {
            line.len() + 1
        };

        if byte_count + needed > max_bytes {
            hit_byte_limit = true;
            if kept.is_empty() {
                let mut start = line.len().saturating_sub(max_bytes);
                while !line.is_char_boundary(start) {
                    start += 1;
                }
                let tail = &line[start..];
                kept.push(tail);
                byte_count = tail.len();
                last_line_partial = true;
            }
            break;
        }

        kept.push(line);
        byte_count += needed;
    }

    kept.reverse();
    TailTruncation {
        content: kept.join("\n"),
        truncated: true,
        total_lines,
        output_lines: kept.len(),
        output_bytes: byte_count,
        truncated_by: if hit_byte_limit { "bytes" } else { "lines" },
        last_line_partial,
    }
}
346
347// ── Result formatting ────────────────────────────────────────────
348
349fn finish_bash_execution(
350    combined: &str,
351    exit_code: i32,
352    cancelled: bool,
353    timed_out: Option<u64>,
354    ctx: &yoagent::types::ToolContext,
355) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
356    let trunc = truncate_tail(combined, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
357
358    let mut result_text = if trunc.content.is_empty() {
359        "(no output)".to_string()
360    } else {
361        trunc.content.clone()
362    };
363
364    // Save full output to temp file if truncated
365    let full_output_path = if trunc.truncated {
366        let tmp_dir = temp_output_dir();
367        let _ = std::fs::create_dir_all(&tmp_dir);
368        let tmp_path = tmp_dir.join(format!("{}.log", uuid::Uuid::new_v4()));
369        let saved = std::fs::write(&tmp_path, combined).ok().map(|_| {
370            cleanup_stale_temp_files();
371            tmp_path
372        });
373
374        let start_line = trunc.total_lines - trunc.output_lines + 1;
375        let end_line = trunc.total_lines;
376
377        let notice = if trunc.truncated_by == "lines" {
378            format!(
379                "\n\n[Showing lines {}-{} of {}. Full output: {}]",
380                start_line,
381                end_line,
382                trunc.total_lines,
383                saved
384                    .as_ref()
385                    .map(|p| p.display().to_string())
386                    .unwrap_or_default()
387            )
388        } else {
389            format!(
390                "\n\n[Showing lines {}-{} of {} ({} limit). Full output: {}]",
391                start_line,
392                end_line,
393                trunc.total_lines,
394                format_size(DEFAULT_MAX_BYTES),
395                saved
396                    .as_ref()
397                    .map(|p| p.display().to_string())
398                    .unwrap_or_default()
399            )
400        };
401        result_text.push_str(&notice);
402        saved
403    } else {
404        None
405    };
406
407    // Build structured details
408    let details = if trunc.truncated || full_output_path.is_some() {
409        Some(serde_json::json!({
410            "truncation": {
411                "truncated": trunc.truncated,
412                "truncatedBy": trunc.truncated_by,
413                "totalLines": trunc.total_lines,
414                "outputLines": trunc.output_lines,
415                "outputBytes": trunc.output_bytes,
416                "lastLinePartial": trunc.last_line_partial,
417                "maxLines": DEFAULT_MAX_LINES,
418                "maxBytes": DEFAULT_MAX_BYTES,
419            },
420            "fullOutputPath": full_output_path.as_ref().map(|p| p.display().to_string()),
421        }))
422    } else {
423        None
424    };
425
426    let final_output = if cancelled {
427        format_status_output(&result_text, "Command aborted")
428    } else if let Some(secs) = timed_out {
429        format_status_output(
430            &result_text,
431            &format!("Command timed out after {} seconds", secs),
432        )
433    } else if exit_code != 0 {
434        format_status_output(
435            &result_text,
436            &format!("Command exited with code {}", exit_code),
437        )
438    } else {
439        emit_update(ctx, result_text.clone(), details.clone());
440        return Ok(into_tool_result(result_text, details));
441    };
442
443    emit_update(ctx, final_output.clone(), details.clone());
444    Err(yoagent::types::ToolError::Failed(final_output))
445}
446
447// ── Bash execution helpers (was AgentTool impl, now in yoagent) ──
448
449// ── Bash tool renderer ───────────────────────────────────────────
450
451struct BashRenderer;
452
453impl ToolRenderer for BashRenderer {
454    fn render_call(
455        &self,
456        args: &serde_json::Value,
457        _width: usize,
458        theme: &dyn Theme,
459        _ctx: &ToolRenderContext,
460    ) -> Vec<String> {
461        let cmd = args
462            .get("command")
463            .and_then(|v| v.as_str())
464            .unwrap_or("...");
465        let timeout = args.get("timeout").and_then(|v| v.as_i64());
466        let timeout_suffix = timeout
467            .map(|t| theme.fg_key(ThemeKey::Muted, &format!(" (timeout {}s)", t)))
468            .unwrap_or_default();
469
470        vec![format!(
471            "{}{}",
472            theme.fg_key(ThemeKey::ToolTitle, &theme.bold(&format!("$ {}", cmd))),
473            timeout_suffix
474        )]
475    }
476
477    fn render_result(
478        &self,
479        content: &str,
480        width: usize,
481        theme: &dyn Theme,
482        ctx: &ToolRenderContext,
483    ) -> Vec<String> {
484        let mut lines: Vec<String> = Vec::new();
485
486        let clean = strip_context_truncation_footer(content)
487            .trim_end()
488            .to_string();
489        let all_lines: Vec<&str> = clean.lines().collect();
490
491        if all_lines.is_empty() || (all_lines.len() == 1 && all_lines[0].is_empty()) {
492            return lines;
493        }
494
495        let preview_count = 5;
496        let (preview_lines, hidden_line_count) = if ctx.expanded {
497            (all_lines.clone(), 0)
498        } else {
499            truncate_to_visual_lines(&all_lines, width, preview_count)
500        };
501
502        // ── Preview hint with dim/muted styling (matching pi's keyHint) ──
503        if !ctx.expanded && hidden_line_count > 0 {
504            if ctx.expand_key.is_empty() {
505                lines.push(theme.fg_key(
506                    ThemeKey::Muted,
507                    &format!("... {} earlier lines", hidden_line_count),
508                ));
509            } else {
510                // Pi pattern: muted prefix + dim key + muted suffix
511                // e.g. "... (12 earlier lines, \x1b[2mctrl+o\x1b[22m to expand)"
512                let prefix = theme.fg_key(
513                    ThemeKey::Muted,
514                    &format!("... ({} earlier lines, ", hidden_line_count),
515                );
516                let key_styled = theme.fg("dim", &ctx.expand_key);
517                let suffix = theme.fg_key(ThemeKey::Muted, " to expand)");
518                lines.push(format!("{}{}{}", prefix, key_styled, suffix));
519            }
520        }
521
522        let fg_key = if ctx.is_error { "error" } else { "toolOutput" };
523        for line in &preview_lines {
524            if line.is_empty() {
525                lines.push(String::new());
526            } else {
527                lines.push(theme.fg(fg_key, line));
528            }
529        }
530
531        if let Some(secs) = ctx.duration_secs {
532            if !lines.is_empty() {
533                lines.push(String::new());
534            }
535            let is_complete = ctx.exit_code.is_some() || ctx.cancelled;
536            let label = if is_complete { "Took" } else { "Elapsed" };
537            lines.push(theme.fg_key(ThemeKey::Muted, &format!("{} {:.1}s", label, secs)));
538        }
539
540        if ctx.was_truncated {
541            if !lines.is_empty() {
542                lines.push(String::new());
543            }
544            if let Some(ref path) = ctx.full_output_path {
545                lines.push(theme.fg(
546                    "warning",
547                    &format!("Output truncated. Full output: {}", path),
548                ));
549            } else {
550                lines.push(theme.fg_key(ThemeKey::Warning, "Output truncated."));
551            }
552        }
553
554        lines
555    }
556}
557
/// Remove the `[Showing … Full output: …]` footer appended to truncated output, together with
/// the single blank line that separates it from the output.
///
/// Input with fewer than three lines, or whose last line is not such a footer, is returned
/// unchanged. When the footer is removed, the remaining lines are re-joined with `\n`.
fn strip_context_truncation_footer(output: &str) -> String {
    let all: Vec<&str> = output.lines().collect();
    let (footer, body) = match all.split_last() {
        Some(parts) if all.len() >= 3 => parts,
        _ => return output.to_string(),
    };

    let footer = footer.trim();
    let is_footer = footer.starts_with('[')
        && (footer.contains("Showing lines") || footer.contains("Showing last"))
        && footer.contains("Full output:");
    if !is_footer {
        return output.to_string();
    }

    let body = match body.split_last() {
        Some((separator, rest)) if separator.is_empty() => rest,
        _ => body,
    };
    body.join("\n")
}
578
579#[async_trait::async_trait]
580impl yoagent::types::AgentTool for BashTool {
581    fn name(&self) -> &str {
582        "bash"
583    }
584    fn label(&self) -> &str {
585        "bash"
586    }
587    fn description(&self) -> &str {
588        "Execute a bash command in the current working directory. Returns stdout and stderr. \
589         Output is truncated to last 2000 lines or 50KB (whichever is hit first). If \
590         truncated, full output is saved to a temp file. Optionally provide a timeout in seconds."
591    }
592    fn parameters_schema(&self) -> serde_json::Value {
593        serde_json::json!({
594            "type": "object",
595            "required": ["command"],
596            "properties": {
597                "command": {
598                    "type": "string",
599                    "description": "Bash command to execute"
600                },
601                "timeout": {
602                    "type": "number",
603                    "description": "Timeout in seconds (optional, no default timeout)"
604                }
605            }
606        })
607    }
608    async fn execute(
609        &self,
610        params: serde_json::Value,
611        ctx: yoagent::types::ToolContext,
612    ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
613        let command = params["command"].as_str().ok_or_else(|| {
614            yoagent::types::ToolError::InvalidArgs("Missing 'command' argument".into())
615        })?;
616        let timeout = params["timeout"].as_u64();
617        let started_at = Instant::now();
618
619        if ctx.cancel.is_cancelled() {
620            return Err(yoagent::types::ToolError::Cancelled);
621        }
622
623        // Apply command prefix if set
624        let effective_command = if let Some(ref prefix) = self.command_prefix {
625            format!("{}\n{}", prefix, command)
626        } else {
627            command.to_string()
628        };
629
630        // Check that the working directory exists
631        if !self.cwd.exists() {
632            return Err(yoagent::types::ToolError::Failed(format!(
633                "Working directory does not exist: {}\nCannot execute bash commands.",
634                self.cwd.display()
635            )));
636        }
637
638        // If custom operations are provided, delegate entirely
639        if let Some(ref ops) = self.operations {
640            let (output_tx, mut output_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
641            let ops_cancel = Cancel::new();
642
643            // Link yoagent cancellation to rab Cancel
644            let yo_cancel = ctx.cancel.clone();
645            let watch_cancel = ops_cancel.clone();
646            tokio::spawn(async move {
647                yo_cancel.cancelled().await;
648                watch_cancel.cancel();
649            });
650
651            let ops_command = effective_command.clone();
652            let ops_cwd = self.cwd.clone();
653            let ops = ops.clone();
654            let ops_handle = tokio::spawn(async move {
655                ops.exec(
656                    &ops_command,
657                    &ops_cwd,
658                    output_tx,
659                    Some(&ops_cancel),
660                    timeout,
661                    None,
662                )
663                .await
664            });
665
666            // Collect output from the channel
667            let mut combined = String::new();
668            while let Some(chunk) = output_rx.recv().await {
669                combined.push_str(&chunk);
670                emit_update(&ctx, combined.clone(), None);
671            }
672
673            let exit_code = ops_handle.await.unwrap_or(Ok(None)).unwrap_or(None);
674            let code = exit_code.unwrap_or(-1);
675
676            return finish_bash_execution(&combined, code, ctx.cancel.is_cancelled(), None, &ctx);
677        }
678
679        let mut child =
680            spawn_bash_command(&effective_command, &self.cwd, self.shell_path.as_deref()).map_err(
681                |e| yoagent::types::ToolError::Failed(format!("Failed to spawn command: {}", e)),
682            )?;
683
684        let pid = child.id().unwrap_or(0);
685
686        // Shared output buffer for streaming reads
687        let combined = Arc::new(TokioMutex::new(String::new()));
688        let combined_clone = combined.clone();
689
690        let stdout_pipe = child
691            .stdout
692            .take()
693            .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stdout".into()))?;
694        let stderr_pipe = child
695            .stderr
696            .take()
697            .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stderr".into()))?;
698
699        use tokio::io::AsyncReadExt;
700        let read_task = tokio::spawn(async move {
701            let mut stdout_buf = vec![0u8; 65536];
702            let mut stderr_buf = vec![0u8; 65536];
703            let mut stdout_reader = stdout_pipe;
704            let mut stderr_reader = stderr_pipe;
705            let mut stdout_done = false;
706            let mut stderr_done = false;
707            loop {
708                tokio::select! {
709                    result = stdout_reader.read(&mut stdout_buf), if !stdout_done => {
710                        match result {
711                            Ok(0) => stdout_done = true,
712                            Ok(n) => {
713                                let text = String::from_utf8_lossy(&stdout_buf[..n]);
714                                let sanitized = sanitize_output(&text);
715                                let mut out = combined_clone.lock().await;
716                                out.push_str(&sanitized);
717                            }
718                            Err(_) => stdout_done = true,
719                        }
720                    }
721                    result = stderr_reader.read(&mut stderr_buf), if !stderr_done => {
722                        match result {
723                            Ok(0) => stderr_done = true,
724                            Ok(n) => {
725                                let text = String::from_utf8_lossy(&stderr_buf[..n]);
726                                let sanitized = sanitize_output(&text);
727                                let mut out = combined_clone.lock().await;
728                                out.push_str(&sanitized);
729                            }
730                            Err(_) => stderr_done = true,
731                        }
732                    }
733                }
734                if stdout_done && stderr_done {
735                    break;
736                }
737            }
738        });
739
740        // ── PID tracking for cleanup on shutdown signals ──
741        let _pid_guard = ProcessGuard::new(pid);
742
743        // Set up cancellation monitor: kill the process group if cancelled
744        let cancelled = Arc::new(AtomicBool::new(false));
745        let cancel_flag = cancelled.clone();
746        let yo_cancel = ctx.cancel.clone();
747        let _cancel_monitor: tokio::task::JoinHandle<()> = tokio::spawn(async move {
748            yo_cancel.cancelled().await;
749            cancel_flag.store(true, Ordering::SeqCst);
750            kill_process_group(pid);
751        });
752
753        // Send initial empty update
754        if let Some(ref on_update) = ctx.on_update {
755            on_update(yoagent::types::ToolResult {
756                content: vec![],
757                details: serde_json::Value::Null,
758            });
759        }
760
761        // Wait for the process to exit, with optional timeout and streaming updates
762        let timeout_dur = timeout.map(std::time::Duration::from_secs);
763        let throttle_ms = 100u64;
764        let mut last_update_at = Instant::now();
765
766        let exit_code: i32;
767
768        loop {
769            if cancelled.load(Ordering::SeqCst) {
770                kill_process_group(pid);
771                read_task.abort();
772                let combined_str = combined.lock().await.clone();
773                return finish_bash_execution(&combined_str, -1, true, None, &ctx);
774            }
775
776            if let Some(dur) = timeout_dur
777                && started_at.elapsed() > dur
778            {
779                kill_process_group(pid);
780                read_task.abort();
781                let combined_str = combined.lock().await.clone();
782                return finish_bash_execution(&combined_str, -1, false, timeout, &ctx);
783            }
784
785            if last_update_at.elapsed().as_millis() as u64 >= throttle_ms {
786                let out = combined.lock().await.clone();
787                if !out.is_empty() {
788                    last_update_at = Instant::now();
789                    emit_update(&ctx, out, None);
790                }
791            }
792
793            match child.try_wait() {
794                Ok(Some(status)) => {
795                    exit_code = status.code().unwrap_or(-1);
796                    // Idle grace period — after child exit, wait for pipes to go idle
797                    let mut last_len = combined.lock().await.len();
798                    loop {
799                        tokio::time::sleep(std::time::Duration::from_millis(EXIT_STDIO_GRACE_MS))
800                            .await;
801                        let new_len = combined.lock().await.len();
802                        if new_len == last_len {
803                            break;
804                        }
805                        last_len = new_len;
806                    }
807                    read_task.abort();
808                    break;
809                }
810                Ok(None) => {
811                    tokio::time::sleep(std::time::Duration::from_millis(throttle_ms)).await;
812                }
813                Err(_) => {
814                    read_task.await.ok();
815                    exit_code = -1;
816                    break;
817                }
818            }
819        }
820
821        let combined_str = combined.lock().await.clone();
822        if !combined_str.is_empty() {
823            emit_update(&ctx, combined_str.clone(), None);
824        }
825
826        finish_bash_execution(&combined_str, exit_code, false, None, &ctx)
827    }
828}
829
830/// Remove temp files in the bash output directory that are older than the max age.
831/// This is best-effort — failures are silently ignored.
832fn cleanup_stale_temp_files() {
833    let dir = temp_output_dir();
834    let Ok(entries) = std::fs::read_dir(&dir) else {
835        return;
836    };
837    let Ok(cutoff) = std::time::SystemTime::now()
838        .checked_sub(std::time::Duration::from_secs(TEMP_FILE_MAX_AGE_SECS))
839        .ok_or(())
840    else {
841        return;
842    };
843    for entry in entries.flatten() {
844        let path = entry.path();
845        if path.extension().is_none_or(|e| e != "log") {
846            continue;
847        }
848        if let Ok(metadata) = path.metadata()
849            && let Ok(modified) = metadata.modified()
850            && modified < cutoff
851        {
852            let _ = std::fs::remove_file(&path);
853        }
854    }
855}
856
857/// Return the directory where truncated bash output is saved.
858fn temp_output_dir() -> PathBuf {
859    std::env::temp_dir().join(BASH_TEMP_FILE_PREFIX)
860}
861
/// Append `status_msg` (cancelled / timed out / exit code) to the command output, separated
/// by a blank line. When there is no real output (empty, or the `(no output)` placeholder),
/// only the status message is returned.
fn format_status_output(result_text: &str, status_msg: &str) -> String {
    if matches!(result_text, "" | "(no output)") {
        return status_msg.to_owned();
    }
    format!("{}\n\n{}", result_text, status_msg)
}
870
871/// Build a ToolResult with text content and optional details.
872fn into_tool_result(
873    text: String,
874    details: Option<serde_json::Value>,
875) -> yoagent::types::ToolResult {
876    yoagent::types::ToolResult {
877        content: vec![yoagent::types::Content::Text { text }],
878        details: details.unwrap_or(serde_json::Value::Null),
879    }
880}
881
882/// Send an update to the context callback.
883fn emit_update(
884    ctx: &yoagent::types::ToolContext,
885    text: String,
886    details: Option<serde_json::Value>,
887) {
888    if let Some(ref on_update) = ctx.on_update {
889        on_update(into_tool_result(text, details));
890    }
891}
892
893// ── PID tracking for cleanup on shutdown signals ────────────────
894// Matching pi's trackDetachedChildPid / untrackDetachedChildPid.
895// On SIGTERM/SIGHUP, all tracked PIDs are killed before exit.
896
897use std::sync::Mutex;
898
/// PIDs registered via `track_pid` and not yet removed by `untrack_pid`;
/// `kill_tracked_children` passes each one to `kill_process_group`.
static TRACKED_PIDS: Mutex<Vec<u32>> = Mutex::new(Vec::new());
900
901fn track_pid(pid: u32) {
902    if let Ok(mut pids) = TRACKED_PIDS.lock() {
903        pids.push(pid);
904    }
905}
906
907fn untrack_pid(pid: u32) {
908    if let Ok(mut pids) = TRACKED_PIDS.lock() {
909        pids.retain(|&p| p != pid);
910    }
911}
912
913/// Kill all tracked child process groups. Called on SIGTERM/SIGHUP.
914pub fn kill_tracked_children() {
915    let pids: Vec<u32> = TRACKED_PIDS.lock().map(|p| p.clone()).unwrap_or_default();
916    for pid in pids {
917        kill_process_group(pid);
918    }
919}
920
921struct ProcessGuard {
922    pid: u32,
923}
924
925impl ProcessGuard {
926    fn new(pid: u32) -> Self {
927        if pid > 0 {
928            track_pid(pid);
929        }
930        Self { pid }
931    }
932}
933
934impl Drop for ProcessGuard {
935    fn drop(&mut self) {
936        if self.pid > 0 {
937            untrack_pid(self.pid);
938        }
939    }
940}
941
#[cfg(test)]
mod tests {
    use super::*;
    use yoagent::AgentTool;

    /// Fake PID for the `ProcessGuard` tests. `TRACKED_PIDS` is a process-wide
    /// static, and the command tests run concurrently and presumably register
    /// real child PIDs there. The fake PID must therefore never equal a real
    /// one. 4e9 is far above the kernel PID ceiling on Linux (2^22) and macOS
    /// (99_998), so a collision cannot cause spurious failures or untrack a
    /// real child.
    const FAKE_PID: u32 = 4_000_000_000;

    /// Tool context carrying `cancel`, with no update/progress callbacks.
    fn tool_ctx_with_cancel(
        cancel: tokio_util::sync::CancellationToken,
    ) -> yoagent::types::ToolContext {
        yoagent::types::ToolContext {
            tool_call_id: "id".into(),
            tool_name: "bash".into(),
            cancel,
            on_update: None,
            on_progress: None,
        }
    }

    /// Tool context with a fresh, never-cancelled token.
    fn tool_ctx() -> yoagent::types::ToolContext {
        tool_ctx_with_cancel(tokio_util::sync::CancellationToken::new())
    }

    /// Concatenate the text items of a tool result, skipping any other content.
    fn yo_msg_text(content: &[yoagent::types::Content]) -> String {
        content
            .iter()
            .filter_map(|c| {
                if let yoagent::types::Content::Text { text } = c {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect()
    }

    /// Bash tool rooted at the system temp dir with default shell settings.
    fn make_tool() -> BashTool {
        BashTool {
            cwd: std::env::temp_dir(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        }
    }

    #[tokio::test]
    async fn runs_simple_command() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("hello"));
    }

    #[tokio::test]
    async fn captures_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo err >&2"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("err"));
    }

    #[tokio::test]
    async fn cancel_aborts() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        // Cancel before the command even starts.
        cancel.cancel();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10"}),
                tool_ctx_with_cancel(cancel),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Cancelled") || err.contains("aborted"),
            "expected cancellation error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn timeout_works() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"), "got: {}", err);
    }

    #[test]
    fn test_truncate_tail_no_truncation() {
        let result = truncate_tail("hello\nworld\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello\nworld\n");
    }

    #[test]
    fn test_truncate_tail_by_lines() {
        let content: String = (1..=5000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.starts_with("line 3001"));
        assert_eq!(result.content.lines().count(), 2000);
    }

    #[test]
    fn test_truncate_tail_by_bytes() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
        assert!(result.content.lines().count() < 100);
    }

    #[test]
    fn test_truncate_tail_partial_last_line() {
        let content = format!("short\n{}\n", "x".repeat(60000));
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(!result.content.starts_with("short"));
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_empty() {
        let result = truncate_tail("", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "");
    }

    #[tokio::test]
    async fn exit_code_nonzero() {
        let tool = make_tool();
        let result = tool
            .execute(serde_json::json!({"command": "exit 42"}), tool_ctx())
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exited with code 42"), "got: {}", err);
    }

    #[tokio::test]
    async fn exit_code_with_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo before && exit 1"}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("before"), "got: {}", err);
        assert!(err.contains("exited with code 1"), "got: {}", err);
    }

    #[tokio::test]
    async fn no_output() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "true"}), tool_ctx())
            .await
            .unwrap();
        let text = yo_msg_text(&output.content);
        assert!(text.contains("(no output)"), "got: {}", text);
    }

    #[tokio::test]
    async fn combined_stdout_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(
                serde_json::json!({"command": "echo out; echo err >&2"}),
                tool_ctx(),
            )
            .await
            .unwrap();
        let text = yo_msg_text(&output.content);
        assert!(text.contains("out"), "got: {}", text);
        assert!(text.contains("err"), "got: {}", text);
    }

    #[tokio::test]
    async fn runs_in_cwd() {
        let tmp = std::env::temp_dir().join(format!("rab-bash-cwd-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&tmp).unwrap();
        std::fs::write(tmp.join("marker.txt"), "hello").unwrap();

        let tool = BashTool {
            cwd: tmp.clone(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        };
        let result = tool
            .execute(serde_json::json!({"command": "cat marker.txt"}), tool_ctx())
            .await;
        // Remove the scratch directory before asserting so a failure does not leak it.
        let _ = std::fs::remove_dir_all(&tmp);
        let text = yo_msg_text(&result.unwrap().content);
        assert!(text.contains("hello"), "got: {}", text);
    }

    #[tokio::test]
    async fn missing_command_errors() {
        let tool = make_tool();
        let result = tool.execute(serde_json::json!({}), tool_ctx()).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("command"), "got: {}", err);
    }

    #[tokio::test]
    async fn timeout_with_partial_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo start && sleep 10 && echo end", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"), "got: {}", err);
    }

    #[tokio::test]
    async fn cancel_during_long_command() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        let cancel_ctx = cancel.clone();

        let handle = tokio::spawn(async move {
            tool.execute(
                serde_json::json!({"command": "sleep 30"}),
                tool_ctx_with_cancel(cancel_ctx),
            )
            .await
        });

        // Give the command time to start before cancelling it mid-run.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("aborted") || err.contains("Cancelled"),
            "expected cancellation error, got: {}",
            err
        );
    }

    #[test]
    fn test_truncate_tail_exact_line_fit() {
        let lines: String = (1..=2000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content.lines().count(), 2000);
    }

    #[test]
    fn test_truncate_tail_one_over_line_limit() {
        let lines: String = (1..=2001).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.content.lines().count(), 2000);
        // Include the newline so "line 20..." cannot satisfy the check by accident.
        assert!(result.content.starts_with("line 2\n"));
    }

    #[test]
    fn test_truncate_tail_exact_byte_fit() {
        let line = "a".repeat(50000);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(!result.truncated);
    }

    #[test]
    fn test_truncate_tail_one_byte_over() {
        let line = "a".repeat(50001);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_single_line_under_limit() {
        let result = truncate_tail("hello world", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello world");
    }

    #[test]
    fn test_truncate_tail_trailing_newline() {
        let result = truncate_tail("a\nb\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb\n");
    }

    #[test]
    fn test_truncate_tail_no_trailing_newline() {
        let result = truncate_tail("a\nb", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb");
    }

    #[test]
    fn test_truncate_tail_single_line_exceeds_limit() {
        let content = "x".repeat(60000);
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.last_line_partial);
        assert_eq!(result.content.len(), 50000);
        assert!(result.content.ends_with("x".repeat(50000).as_str()));
    }

    #[test]
    fn test_truncate_tail_byte_count_respects_newlines() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.output_bytes <= 50000);
    }

    #[tokio::test]
    async fn truncated_by_lines_shows_footer() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        let text = yo_msg_text(&output.content);
        assert!(text.contains("Showing lines"), "got: {}", text);
        assert!(text.contains("Full output:"), "got: {}", text);
    }

    #[tokio::test]
    async fn small_output_no_footer() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        let text = yo_msg_text(&output.content);
        assert!(!text.contains("Output truncated"));
        assert!(!text.contains("Full output:"));
    }

    #[tokio::test]
    async fn truncated_saves_temp_file() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        let text = yo_msg_text(&output.content);
        assert!(
            text.contains("/rab-bash/"),
            "expected temp file path with /rab-bash/, got: {}",
            text
        );
    }

    #[test]
    fn test_cleanup_stale_temp_files_nonexistent_dir() {
        // Must not panic whether or not the output directory exists yet.
        cleanup_stale_temp_files();
    }

    #[test]
    fn test_truncate_tail_many_short_lines() {
        let content: String = (1..=10000).map(|i| format!("{}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "lines");
        assert_eq!(result.output_lines, 2000);
        assert!(
            result.content.starts_with("8001"),
            "first line: {:?}",
            result.content.lines().next()
        );
    }

    #[test]
    fn test_truncate_tail_lines_and_bytes_both_exceeded() {
        let content: String = (1..=5000)
            .map(|i| format!("line {} {}\n", i, "x".repeat(100)))
            .collect();
        let result = truncate_tail(&content, 2000, 30000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "bytes");
        assert!(result.output_lines < 2000);
    }

    // ── ProcessGuard tests ──────────────────────────────────────

    #[test]
    fn test_process_guard_tracks_pid() {
        {
            let _guard = ProcessGuard::new(FAKE_PID);
            let pids = TRACKED_PIDS.lock().unwrap();
            assert!(pids.contains(&FAKE_PID));
        }
        let pids = TRACKED_PIDS.lock().unwrap();
        assert!(!pids.contains(&FAKE_PID));
    }

    #[test]
    fn test_process_guard_zero_pid() {
        {
            let _guard = ProcessGuard::new(0);
            let pids = TRACKED_PIDS.lock().unwrap();
            assert!(!pids.contains(&0));
        }
    }
}