Skip to main content

rab/builtin/
bash.rs

1use crate::agent::extension::{Cancel, Extension, ToolDefinition};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::tui::Theme;
4use crate::tui::ThemeKey;
5use crate::tui::visual_truncate::truncate_to_visual_lines;
6use async_trait::async_trait;
7
8use std::borrow::Cow;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::time::Instant;
14use tokio::sync::{Mutex as TokioMutex, mpsc::UnboundedSender};
15
16// ── BashOperations (pluggable) ────────────────────────────────────
17
/// Pluggable operations for the bash tool (matching pi's BashOperations).
/// Override these to delegate command execution to remote systems (for example SSH).
///
/// When a `BashTool` is configured with an implementation of this trait, its
/// `execute` method delegates the whole command to [`BashOperations::exec`]
/// instead of spawning a local shell.
#[async_trait]
pub trait BashOperations: Send + Sync {
    /// Execute a command and stream output via the sender.
    ///
    /// * `command` – command text to run; the bash tool passes it with any
    ///   configured command prefix already prepended.
    /// * `cwd` – working directory for the command.
    /// * `on_data` – channel for output chunks. The bash tool keeps reading
    ///   until every sender has been dropped, so implementations must not
    ///   retain a clone of the sender after returning.
    /// * `signal` – optional cancellation handle; the bash tool cancels it
    ///   when the agent-level cancellation token fires.
    /// * `timeout` – optional timeout forwarded from the tool's `timeout`
    ///   argument, which the tool schema describes as seconds.
    /// * `env` – optional extra environment variables (the bash tool
    ///   currently always passes `None`).
    ///
    /// Returns the exit code (0 = success, non-zero = error, None = killed).
    async fn exec(
        &self,
        command: &str,
        cwd: &Path,
        on_data: UnboundedSender<String>,
        signal: Option<&Cancel>,
        timeout: Option<u64>,
        env: Option<HashMap<String, String>>,
    ) -> Result<Option<i32>, anyhow::Error>;
}
34
/// Optional configuration for the bash tool.
///
/// All fields default to `None`, which selects local execution with an
/// auto-resolved shell and no command prefix.
#[derive(Clone, Default)]
pub struct BashToolOptions {
    /// Custom operations for command execution. Default: local shell.
    pub operations: Option<Arc<dyn BashOperations>>,
    /// Command prefix prepended to every command (for example shell setup commands).
    /// The prefix and the command are joined with a newline.
    pub command_prefix: Option<String>,
    /// Optional explicit shell path from settings.
    pub shell_path: Option<String>,
}
44
45pub struct BashExtension {
46    cwd: PathBuf,
47    options: BashToolOptions,
48}
49
50impl BashExtension {
51    pub fn new(cwd: PathBuf) -> Self {
52        Self {
53            cwd,
54            options: BashToolOptions::default(),
55        }
56    }
57
58    pub fn with_options(cwd: PathBuf, options: BashToolOptions) -> Self {
59        Self { cwd, options }
60    }
61
62    pub fn with_shell_path(cwd: PathBuf, shell_path: String) -> Self {
63        Self {
64            cwd,
65            options: BashToolOptions {
66                shell_path: Some(shell_path),
67                ..BashToolOptions::default()
68            },
69        }
70    }
71}
72
73impl Extension for BashExtension {
74    fn name(&self) -> Cow<'static, str> {
75        "bash".into()
76    }
77
78    fn tools(&self) -> Vec<ToolDefinition> {
79        vec![ToolDefinition {
80            tool: Box::new(BashTool {
81                cwd: self.cwd.clone(),
82                shell_path: self.options.shell_path.clone(),
83                command_prefix: self.options.command_prefix.clone(),
84                operations: self.options.operations.clone(),
85            }),
86            snippet: "Execute bash commands (ls, grep, find, etc.)",
87            guidelines: &[],
88            prepare_arguments: None,
89            before_tool_call: None,
90            after_tool_call: None,
91            renderer: Some(std::sync::Arc::new(BashRenderer)),
92        }]
93    }
94}
95
/// The `bash` tool instance handed to the agent runtime.
struct BashTool {
    /// Working directory for every command; execution fails if it does not exist.
    cwd: PathBuf,
    /// Explicit shell binary. When `None`, the shell is resolved automatically.
    shell_path: Option<String>,
    /// Text placed on its own line before each command, if set.
    command_prefix: Option<String>,
    /// When set, commands are delegated to these operations instead of a local shell.
    operations: Option<Arc<dyn BashOperations>>,
}
102
103// ── Constants ────────────────────────────────────────────────────
104
/// Maximum number of trailing output lines returned to the caller.
const DEFAULT_MAX_LINES: usize = 2000;
/// Maximum number of trailing output bytes returned to the caller.
const DEFAULT_MAX_BYTES: usize = 50 * 1024; // 50KB
/// Name of the directory (under the system temp dir) where full logs of
/// truncated output are written.
const BASH_TEMP_FILE_PREFIX: &str = "pi-bash";

/// Grace period after child exit (ms) — matching pi's EXIT_STDIO_GRACE_MS.
/// Detached descendants may keep stdout/stderr pipes open; we poll until idle.
const EXIT_STDIO_GRACE_MS: u64 = 100;
112
113// ── Shell resolution (matching pi's getShellConfig) ──────────────
114
/// Shell configuration: which shell binary to use and how to pass commands.
struct ShellConfig {
    /// Shell executable (absolute path or a name looked up on `PATH`).
    shell: String,
    /// Arguments placed before the command text (always `-c` here).
    args: Vec<String>,
}

/// Pick the shell used to run commands.
///
/// Candidates are tried in this order; the first hit wins:
/// 1. The caller-supplied `shell_path`, used verbatim without any existence check.
/// 2. `/bin/bash`, if that path exists.
/// 3. On Unix only: the path printed by `which bash`, if it exists.
/// 4. Plain `sh`, resolved through `PATH` at spawn time.
///
/// Every configuration passes the command through `-c`.
fn resolve_shell(shell_path: Option<&str>) -> ShellConfig {
    let dash_c = |shell: String| ShellConfig {
        shell,
        args: vec!["-c".to_string()],
    };

    if let Some(explicit) = shell_path {
        return dash_c(explicit.to_string());
    }

    // The most common location on Unix systems.
    let default_bash = "/bin/bash";
    if std::path::Path::new(default_bash).exists() {
        return dash_c(default_bash.to_string());
    }

    // Ask `which` for a bash elsewhere on PATH.
    #[cfg(unix)]
    {
        let located = std::process::Command::new("which")
            .arg("bash")
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::null())
            .output()
            .ok()
            .filter(|probe| probe.status.success())
            .map(|probe| String::from_utf8_lossy(&probe.stdout).trim().to_string())
            .filter(|candidate| {
                !candidate.is_empty() && std::path::Path::new(candidate).exists()
            });
        if let Some(found) = located {
            return dash_c(found);
        }
    }

    // Last resort.
    dash_c("sh".to_string())
}
168
169// ── Helpers ──────────────────────────────────────────────────────
170
/// Signal an entire process group, identified by its leader's PID.
///
/// Runs the external `kill` utility with a negated PID and no explicit signal,
/// blocking until it returns. A PID of 0 is ignored, and any failure to run
/// `kill` is deliberately discarded (best effort).
#[cfg(unix)]
fn kill_process_group(pid: u32) {
    if pid == 0 {
        return;
    }
    let group_target = format!("-{}", pid);
    let _ = std::process::Command::new("kill")
        .arg("--")
        .arg(group_target)
        .status();
}

/// No-op on platforms other than Unix.
#[cfg(not(unix))]
fn kill_process_group(_pid: u32) {}
186
187/// Spawn a bash command with process group setup for clean cancellation.
188fn spawn_bash_command(
189    command: &str,
190    cwd: &std::path::Path,
191    shell_path: Option<&str>,
192) -> std::io::Result<tokio::process::Child> {
193    let shell_cfg = resolve_shell(shell_path);
194
195    #[cfg(unix)]
196    {
197        use std::os::unix::process::CommandExt;
198        let mut std_cmd = std::process::Command::new(&shell_cfg.shell);
199        std_cmd.args(&shell_cfg.args).arg(command).current_dir(cwd);
200        unsafe {
201            std_cmd.pre_exec(|| {
202                libc::setpgid(0, 0);
203                Ok(())
204            });
205        }
206        let mut tokio_cmd = tokio::process::Command::from(std_cmd);
207        tokio_cmd
208            .stdin(std::process::Stdio::null())
209            .stdout(std::process::Stdio::piped())
210            .stderr(std::process::Stdio::piped())
211            .spawn()
212    }
213    #[cfg(not(unix))]
214    {
215        tokio::process::Command::new(&shell_cfg.shell)
216            .args(&shell_cfg.args)
217            .arg(command)
218            .current_dir(cwd)
219            .stdin(std::process::Stdio::null())
220            .stdout(std::process::Stdio::piped())
221            .stderr(std::process::Stdio::piped())
222            .spawn()
223    }
224}
225
/// Clean raw command output for display/storage (matching pi's
/// sanitizeBinaryOutput + stripAnsi).
///
/// * An escape sequence starts at ESC (`U+001B`) or CSI (`U+009B`) and is
///   dropped up to and including the next ASCII letter or `~`.
/// * C0 control characters other than tab, line feed and carriage return are dropped.
/// * The interlinear annotation characters `U+FFF9..=U+FFFB` are dropped.
///
/// Everything else is copied through unchanged.
fn sanitize_output(text: &str) -> String {
    let mut cleaned = String::with_capacity(text.len());
    let mut skipping_sequence = false;

    for ch in text.chars() {
        let starts_sequence = matches!(ch, '\x1b' | '\u{9b}');

        if skipping_sequence {
            // Swallow everything until the sequence's final byte.
            let is_final = !starts_sequence && (ch.is_ascii_alphabetic() || ch == '~');
            if is_final {
                skipping_sequence = false;
            }
            continue;
        }

        if starts_sequence {
            skipping_sequence = true;
            continue;
        }

        let unwanted_control =
            matches!(ch, '\u{0}'..='\u{1f}') && !matches!(ch, '\t' | '\n' | '\r');
        let interlinear_mark = matches!(ch, '\u{fff9}'..='\u{fffb}');
        if !(unwanted_control || interlinear_mark) {
            cleaned.push(ch);
        }
    }

    cleaned
}
255
/// Render a byte count as a short human-readable size.
///
/// Values below 1024 are printed as whole bytes (`512B`), values below one
/// mebibyte as kibibytes with one decimal (`1.5KB`), and everything else as
/// mebibytes with one decimal (`5.0MB`).
fn format_size(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = 1024 * 1024;

    match bytes {
        small if small < KIB => format!("{}B", small),
        medium if medium < MIB => format!("{:.1}KB", medium as f64 / 1024.0),
        large => format!("{:.1}MB", large as f64 / (1024.0 * 1024.0)),
    }
}
265
/// Truncation result for tail-based truncation (keep last N lines/bytes).
struct TailTruncation {
    /// The retained text, lines joined with `\n`.
    content: String,
    /// Whether anything was cut off.
    truncated: bool,
    /// Number of lines in the original input.
    total_lines: usize,
    /// Number of lines kept in `content`.
    output_lines: usize,
    /// Size of the kept text in bytes.
    output_bytes: usize,
    /// Which limit caused the cut: `"lines"`, `"bytes"`, or `""` when not truncated.
    truncated_by: &'static str,
    /// True when the only kept line is itself a cut-down tail of a longer line.
    last_line_partial: bool,
}

/// Truncate content from the tail, keeping complete lines that fit within limits.
///
/// Lines are taken from the end of `content` until either `max_lines` lines
/// have been collected or adding another line (plus its joining newline) would
/// exceed `max_bytes`. If even the final line alone is larger than `max_bytes`,
/// only the end of that line is kept and `last_line_partial` is set.
///
/// The partial-line cut never splits a UTF-8 character: when the byte budget
/// would start the slice inside a multi-byte character, the slice starts at
/// the next character instead (so slightly fewer than `max_bytes` bytes are
/// kept) rather than panicking.
fn truncate_tail(content: &str, max_lines: usize, max_bytes: usize) -> TailTruncation {
    let total_bytes = content.len();
    let lines: Vec<&str> = content.lines().collect();
    let total_lines = lines.len();

    // Fast path: everything fits.
    if total_lines <= max_lines && total_bytes <= max_bytes {
        return TailTruncation {
            content: content.to_string(),
            truncated: false,
            total_lines,
            output_lines: total_lines,
            output_bytes: total_bytes,
            truncated_by: "",
            last_line_partial: false,
        };
    }

    let mut output: Vec<&str> = Vec::new();
    let mut byte_count: usize = 0;
    let mut truncated_by = "lines";
    let mut last_line_partial = false;

    for &line in lines.iter().rev().take(max_lines) {
        // Every line after the first kept one costs an extra byte for the newline.
        let with_newline = line.len() + usize::from(!output.is_empty());

        if byte_count + with_newline > max_bytes {
            truncated_by = "bytes";
            if output.is_empty() {
                // Keep the tail of an oversized final line, snapping forward
                // to a character boundary so the slice is always valid UTF-8.
                let mut start = line.len().saturating_sub(max_bytes);
                while !line.is_char_boundary(start) {
                    start += 1;
                }
                let tail = &line[start..];
                output.push(tail);
                byte_count = tail.len();
                last_line_partial = true;
            }
            break;
        }

        output.push(line);
        byte_count += with_newline;
    }

    // If the line budget was exhausted within the byte budget, lines were the limit.
    if output.len() >= max_lines && byte_count <= max_bytes {
        truncated_by = "lines";
    }

    output.reverse();
    TailTruncation {
        content: output.join("\n"),
        truncated: true,
        total_lines,
        output_lines: output.len(),
        output_bytes: byte_count,
        truncated_by,
        last_line_partial,
    }
}
339
340// ── Result formatting ────────────────────────────────────────────
341
342fn finish_bash_execution(
343    combined: &str,
344    exit_code: i32,
345    cancelled: bool,
346    timed_out: Option<u64>,
347    ctx: &yoagent::types::ToolContext,
348) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
349    let trunc = truncate_tail(combined, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
350
351    let mut result_text = if trunc.content.is_empty() {
352        "(no output)".to_string()
353    } else {
354        trunc.content.clone()
355    };
356
357    // Save full output to temp file if truncated
358    let full_output_path = if trunc.truncated {
359        let tmp_dir = std::env::temp_dir().join(BASH_TEMP_FILE_PREFIX);
360        let _ = std::fs::create_dir_all(&tmp_dir);
361        let tmp_path = tmp_dir.join(format!("{}.log", uuid::Uuid::new_v4()));
362        let saved = std::fs::write(&tmp_path, combined).ok().map(|_| tmp_path);
363
364        let start_line = trunc.total_lines - trunc.output_lines + 1;
365        let end_line = trunc.total_lines;
366
367        let notice = if trunc.truncated_by == "lines" {
368            format!(
369                "\n\n[Showing lines {}-{} of {}. Full output: {}]",
370                start_line,
371                end_line,
372                trunc.total_lines,
373                saved
374                    .as_ref()
375                    .map(|p| p.display().to_string())
376                    .unwrap_or_default()
377            )
378        } else {
379            format!(
380                "\n\n[Showing lines {}-{} of {} ({} limit). Full output: {}]",
381                start_line,
382                end_line,
383                trunc.total_lines,
384                format_size(DEFAULT_MAX_BYTES),
385                saved
386                    .as_ref()
387                    .map(|p| p.display().to_string())
388                    .unwrap_or_default()
389            )
390        };
391        result_text.push_str(&notice);
392        saved
393    } else {
394        None
395    };
396
397    // Build structured details
398    let details = if trunc.truncated || full_output_path.is_some() {
399        Some(serde_json::json!({
400            "truncation": {
401                "truncated": trunc.truncated,
402                "truncatedBy": trunc.truncated_by,
403                "totalLines": trunc.total_lines,
404                "outputLines": trunc.output_lines,
405                "outputBytes": trunc.output_bytes,
406                "lastLinePartial": trunc.last_line_partial,
407                "maxLines": DEFAULT_MAX_LINES,
408                "maxBytes": DEFAULT_MAX_BYTES,
409            },
410            "fullOutputPath": full_output_path.as_ref().map(|p| p.display().to_string()),
411        }))
412    } else {
413        None
414    };
415
416    let final_output = if cancelled {
417        if result_text.is_empty() || result_text == "(no output)" {
418            "Command aborted".to_string()
419        } else {
420            format!("{}\n\nCommand aborted", result_text)
421        }
422    } else if let Some(secs) = timed_out {
423        if result_text.is_empty() || result_text == "(no output)" {
424            format!("Command timed out after {} seconds", secs)
425        } else {
426            format!(
427                "{}\n\nCommand timed out after {} seconds",
428                result_text, secs
429            )
430        }
431    } else if exit_code != 0 {
432        if result_text.is_empty() || result_text == "(no output)" {
433            format!("Command exited with code {}", exit_code)
434        } else {
435            format!("{}\n\nCommand exited with code {}", result_text, exit_code)
436        }
437    } else {
438        if let Some(ref on_update) = ctx.on_update {
439            on_update(yoagent::types::ToolResult {
440                content: vec![yoagent::types::Content::Text {
441                    text: result_text.clone(),
442                }],
443                details: details.clone().unwrap_or(serde_json::Value::Null),
444            });
445        }
446        return Ok(yoagent::types::ToolResult {
447            content: vec![yoagent::types::Content::Text { text: result_text }],
448            details: details.unwrap_or(serde_json::Value::Null),
449        });
450    };
451
452    if let Some(ref on_update) = ctx.on_update {
453        on_update(yoagent::types::ToolResult {
454            content: vec![yoagent::types::Content::Text {
455                text: final_output.clone(),
456            }],
457            details: details.clone().unwrap_or(serde_json::Value::Null),
458        });
459    }
460
461    Err(yoagent::types::ToolError::Failed(final_output))
462}
463
464// ── Bash execution helpers (was AgentTool impl, now in yoagent) ──
465
466// ── Bash tool renderer ───────────────────────────────────────────
467
468struct BashRenderer;
469
470impl ToolRenderer for BashRenderer {
471    fn render_call(
472        &self,
473        args: &serde_json::Value,
474        _width: usize,
475        theme: &dyn Theme,
476        _ctx: &ToolRenderContext,
477    ) -> Vec<String> {
478        let cmd = args
479            .get("command")
480            .and_then(|v| v.as_str())
481            .unwrap_or("...");
482        let timeout = args.get("timeout").and_then(|v| v.as_i64());
483        let timeout_suffix = timeout
484            .map(|t| theme.fg_key(ThemeKey::Muted, &format!(" (timeout {}s)", t)))
485            .unwrap_or_default();
486
487        vec![format!(
488            "{}{}",
489            theme.fg_key(ThemeKey::ToolTitle, &theme.bold(&format!("$ {}", cmd))),
490            timeout_suffix
491        )]
492    }
493
494    fn render_result(
495        &self,
496        content: &str,
497        width: usize,
498        theme: &dyn Theme,
499        ctx: &ToolRenderContext,
500    ) -> Vec<String> {
501        let mut lines: Vec<String> = Vec::new();
502
503        let clean = strip_context_truncation_footer(content)
504            .trim_end()
505            .to_string();
506        let all_lines: Vec<&str> = clean.lines().collect();
507
508        if all_lines.is_empty() || (all_lines.len() == 1 && all_lines[0].is_empty()) {
509            return lines;
510        }
511
512        let preview_count = 5;
513        let (preview_lines, hidden_line_count) = if ctx.expanded {
514            (all_lines.clone(), 0)
515        } else {
516            truncate_to_visual_lines(&all_lines, width, preview_count)
517        };
518
519        // ── Preview hint with dim/muted styling (matching pi's keyHint) ──
520        if !ctx.expanded && hidden_line_count > 0 {
521            if ctx.expand_key.is_empty() {
522                lines.push(theme.fg_key(
523                    ThemeKey::Muted,
524                    &format!("... {} earlier lines", hidden_line_count),
525                ));
526            } else {
527                // Pi pattern: muted prefix + dim key + muted suffix
528                // e.g. "... (12 earlier lines, \x1b[2mctrl+o\x1b[22m to expand)"
529                let prefix = theme.fg_key(
530                    ThemeKey::Muted,
531                    &format!("... ({} earlier lines, ", hidden_line_count),
532                );
533                let key_styled = theme.fg("dim", &ctx.expand_key);
534                let suffix = theme.fg_key(ThemeKey::Muted, " to expand)");
535                lines.push(format!("{}{}{}", prefix, key_styled, suffix));
536            }
537        }
538
539        let fg_key = if ctx.is_error { "error" } else { "toolOutput" };
540        for line in &preview_lines {
541            if line.is_empty() {
542                lines.push(String::new());
543            } else {
544                lines.push(theme.fg(fg_key, line));
545            }
546        }
547
548        if let Some(secs) = ctx.duration_secs {
549            if !lines.is_empty() {
550                lines.push(String::new());
551            }
552            let is_complete = ctx.exit_code.is_some() || ctx.cancelled;
553            let label = if is_complete { "Took" } else { "Elapsed" };
554            lines.push(theme.fg_key(ThemeKey::Muted, &format!("{} {:.1}s", label, secs)));
555        }
556
557        if ctx.was_truncated {
558            if !lines.is_empty() {
559                lines.push(String::new());
560            }
561            if let Some(ref path) = ctx.full_output_path {
562                lines.push(theme.fg(
563                    "warning",
564                    &format!("Output truncated. Full output: {}", path),
565                ));
566            } else {
567                lines.push(theme.fg_key(ThemeKey::Warning, "Output truncated."));
568            }
569        }
570
571        lines
572    }
573}
574
/// Remove the trailing `[Showing … Full output: …]` notice from tool output.
///
/// The notice is recognised only when the text has at least three lines and
/// its last line (after trimming) starts with `[`, mentions `Showing lines` or
/// `Showing last`, and contains `Full output:`. The notice line is dropped
/// together with one blank line directly before it, if there is one; the
/// remaining lines are re-joined with `\n`. Any other input is returned as-is.
fn strip_context_truncation_footer(output: &str) -> String {
    let rows: Vec<&str> = output.lines().collect();

    let (footer, body) = match rows.split_last() {
        Some((last, rest)) if rows.len() >= 3 => (last.trim(), rest),
        _ => return output.to_string(),
    };

    let mentions_range = footer.contains("Showing lines") || footer.contains("Showing last");
    let is_notice = footer.starts_with('[') && mentions_range && footer.contains("Full output:");
    if !is_notice {
        return output.to_string();
    }

    // Drop the blank separator line that precedes the notice, if any.
    let kept = match body.split_last() {
        Some((separator, rest)) if separator.is_empty() => rest,
        _ => body,
    };
    kept.join("\n")
}
595
596#[async_trait::async_trait]
597impl yoagent::types::AgentTool for BashTool {
598    fn name(&self) -> &str {
599        "bash"
600    }
601    fn label(&self) -> &str {
602        "bash"
603    }
604    fn description(&self) -> &str {
605        "Execute a bash command in the current working directory. Returns stdout and stderr. \
606         Output is truncated to last 2000 lines or 50KB (whichever is hit first). If \
607         truncated, full output is saved to a temp file. Optionally provide a timeout in seconds."
608    }
609    fn parameters_schema(&self) -> serde_json::Value {
610        serde_json::json!({
611            "type": "object",
612            "required": ["command"],
613            "properties": {
614                "command": {
615                    "type": "string",
616                    "description": "Bash command to execute"
617                },
618                "timeout": {
619                    "type": "number",
620                    "description": "Timeout in seconds (optional, no default timeout)"
621                }
622            }
623        })
624    }
625    async fn execute(
626        &self,
627        params: serde_json::Value,
628        ctx: yoagent::types::ToolContext,
629    ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
630        let command = params["command"].as_str().ok_or_else(|| {
631            yoagent::types::ToolError::InvalidArgs("Missing 'command' argument".into())
632        })?;
633        let timeout = params["timeout"].as_u64();
634        let started_at = Instant::now();
635
636        if ctx.cancel.is_cancelled() {
637            return Err(yoagent::types::ToolError::Cancelled);
638        }
639
640        // Apply command prefix if set
641        let effective_command = if let Some(ref prefix) = self.command_prefix {
642            format!("{}\n{}", prefix, command)
643        } else {
644            command.to_string()
645        };
646
647        // Check that the working directory exists
648        if !self.cwd.exists() {
649            return Err(yoagent::types::ToolError::Failed(format!(
650                "Working directory does not exist: {}\nCannot execute bash commands.",
651                self.cwd.display()
652            )));
653        }
654
655        // If custom operations are provided, delegate entirely
656        if let Some(ref ops) = self.operations {
657            let (output_tx, mut output_rx) = tokio::sync::mpsc::unbounded_channel::<String>();
658            let ops_cancel = Cancel::new();
659
660            // Link yoagent cancellation to rab Cancel
661            let yo_cancel = ctx.cancel.clone();
662            let watch_cancel = ops_cancel.clone();
663            tokio::spawn(async move {
664                yo_cancel.cancelled().await;
665                watch_cancel.cancel();
666            });
667
668            let ops_command = effective_command.clone();
669            let ops_cwd = self.cwd.clone();
670            let ops = ops.clone();
671            let ops_handle = tokio::spawn(async move {
672                ops.exec(
673                    &ops_command,
674                    &ops_cwd,
675                    output_tx,
676                    Some(&ops_cancel),
677                    timeout,
678                    None,
679                )
680                .await
681            });
682
683            // Collect output from the channel
684            let mut combined = String::new();
685            while let Some(chunk) = output_rx.recv().await {
686                combined.push_str(&chunk);
687                if let Some(ref on_update) = ctx.on_update {
688                    on_update(yoagent::types::ToolResult {
689                        content: vec![yoagent::types::Content::Text {
690                            text: combined.clone(),
691                        }],
692                        details: serde_json::Value::Null,
693                    });
694                }
695            }
696
697            let exit_code = ops_handle.await.unwrap_or(Ok(None)).unwrap_or(None);
698            let code = exit_code.unwrap_or(-1);
699
700            return finish_bash_execution(&combined, code, ctx.cancel.is_cancelled(), None, &ctx);
701        }
702
703        let mut child =
704            spawn_bash_command(&effective_command, &self.cwd, self.shell_path.as_deref()).map_err(
705                |e| yoagent::types::ToolError::Failed(format!("Failed to spawn command: {}", e)),
706            )?;
707
708        let pid = child.id().unwrap_or(0);
709
710        // Shared output buffer for streaming reads
711        let combined = Arc::new(TokioMutex::new(String::new()));
712        let combined_clone = combined.clone();
713
714        let stdout_pipe = child
715            .stdout
716            .take()
717            .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stdout".into()))?;
718        let stderr_pipe = child
719            .stderr
720            .take()
721            .ok_or_else(|| yoagent::types::ToolError::Failed("Failed to capture stderr".into()))?;
722
723        use tokio::io::AsyncReadExt;
724        let read_task = tokio::spawn(async move {
725            let mut stdout_buf = vec![0u8; 65536];
726            let mut stderr_buf = vec![0u8; 65536];
727            let mut stdout_reader = stdout_pipe;
728            let mut stderr_reader = stderr_pipe;
729            let mut stdout_done = false;
730            let mut stderr_done = false;
731            loop {
732                tokio::select! {
733                    result = stdout_reader.read(&mut stdout_buf), if !stdout_done => {
734                        match result {
735                            Ok(0) => stdout_done = true,
736                            Ok(n) => {
737                                let text = String::from_utf8_lossy(&stdout_buf[..n]);
738                                let sanitized = sanitize_output(&text);
739                                let mut out = combined_clone.lock().await;
740                                out.push_str(&sanitized);
741                            }
742                            Err(_) => stdout_done = true,
743                        }
744                    }
745                    result = stderr_reader.read(&mut stderr_buf), if !stderr_done => {
746                        match result {
747                            Ok(0) => stderr_done = true,
748                            Ok(n) => {
749                                let text = String::from_utf8_lossy(&stderr_buf[..n]);
750                                let sanitized = sanitize_output(&text);
751                                let mut out = combined_clone.lock().await;
752                                out.push_str(&sanitized);
753                            }
754                            Err(_) => stderr_done = true,
755                        }
756                    }
757                }
758                if stdout_done && stderr_done {
759                    break;
760                }
761            }
762        });
763
764        // ── PID tracking for cleanup on shutdown signals ──
765        let _pid_guard = ProcessGuard::new(pid);
766
767        // Set up cancellation monitor: kill the process group if cancelled
768        let cancelled = Arc::new(AtomicBool::new(false));
769        let cancel_flag = cancelled.clone();
770        let yo_cancel = ctx.cancel.clone();
771        let _cancel_monitor: tokio::task::JoinHandle<()> = tokio::spawn(async move {
772            yo_cancel.cancelled().await;
773            cancel_flag.store(true, Ordering::SeqCst);
774            kill_process_group(pid);
775        });
776
777        // Send initial empty update
778        if let Some(ref on_update) = ctx.on_update {
779            on_update(yoagent::types::ToolResult {
780                content: vec![],
781                details: serde_json::Value::Null,
782            });
783        }
784
785        // Wait for the process to exit, with optional timeout and streaming updates
786        let timeout_dur = timeout.map(std::time::Duration::from_secs);
787        let throttle_ms = 100u64;
788        let mut last_update_at = Instant::now();
789
790        let exit_code: i32;
791
792        loop {
793            if cancelled.load(Ordering::SeqCst) {
794                kill_process_group(pid);
795                read_task.abort();
796                let combined_str = combined.lock().await.clone();
797                return finish_bash_execution(&combined_str, -1, true, None, &ctx);
798            }
799
800            if let Some(dur) = timeout_dur
801                && started_at.elapsed() > dur
802            {
803                kill_process_group(pid);
804                read_task.abort();
805                let combined_str = combined.lock().await.clone();
806                return finish_bash_execution(&combined_str, -1, false, timeout, &ctx);
807            }
808
809            if let Some(ref on_update) = ctx.on_update
810                && last_update_at.elapsed().as_millis() as u64 >= throttle_ms
811            {
812                let out = combined.lock().await.clone();
813                if !out.is_empty() {
814                    last_update_at = Instant::now();
815                    on_update(yoagent::types::ToolResult {
816                        content: vec![yoagent::types::Content::Text { text: out }],
817                        details: serde_json::Value::Null,
818                    });
819                }
820            }
821
822            match child.try_wait() {
823                Ok(Some(status)) => {
824                    exit_code = status.code().unwrap_or(-1);
825                    // Idle grace period — after child exit, wait for pipes to go idle
826                    let mut last_len = combined.lock().await.len();
827                    loop {
828                        tokio::time::sleep(std::time::Duration::from_millis(EXIT_STDIO_GRACE_MS))
829                            .await;
830                        let new_len = combined.lock().await.len();
831                        if new_len == last_len {
832                            break;
833                        }
834                        last_len = new_len;
835                    }
836                    read_task.abort();
837                    break;
838                }
839                Ok(None) => {
840                    tokio::time::sleep(std::time::Duration::from_millis(throttle_ms)).await;
841                }
842                Err(_) => {
843                    read_task.await.ok();
844                    exit_code = -1;
845                    break;
846                }
847            }
848        }
849
850        let combined_str = combined.lock().await.clone();
851        if let Some(ref on_update) = ctx.on_update
852            && !combined_str.is_empty()
853        {
854            on_update(yoagent::types::ToolResult {
855                content: vec![yoagent::types::Content::Text {
856                    text: combined_str.clone(),
857                }],
858                details: serde_json::Value::Null,
859            });
860        }
861
862        finish_bash_execution(&combined_str, exit_code, false, None, &ctx)
863    }
864}
865
866// ── PID tracking for cleanup on shutdown signals ────────────────
867// Matching pi's trackDetachedChildPid / untrackDetachedChildPid.
868// On SIGTERM/SIGHUP, all tracked PIDs are killed before exit.
869
870use std::sync::Mutex;
871
872static TRACKED_PIDS: Mutex<Vec<u32>> = std::sync::Mutex::new(Vec::new());
873
874fn track_pid(pid: u32) {
875    if let Ok(mut pids) = TRACKED_PIDS.lock() {
876        pids.push(pid);
877    }
878}
879
880fn untrack_pid(pid: u32) {
881    if let Ok(mut pids) = TRACKED_PIDS.lock() {
882        pids.retain(|&p| p != pid);
883    }
884}
885
886/// Kill all tracked child process groups. Called on SIGTERM/SIGHUP.
887pub fn kill_tracked_children() {
888    let pids: Vec<u32> = TRACKED_PIDS.lock().map(|p| p.clone()).unwrap_or_default();
889    for pid in pids {
890        kill_process_group(pid);
891    }
892}
893
894struct ProcessGuard {
895    pid: u32,
896}
897
898impl ProcessGuard {
899    fn new(pid: u32) -> Self {
900        if pid > 0 {
901            track_pid(pid);
902        }
903        Self { pid }
904    }
905}
906
907impl Drop for ProcessGuard {
908    fn drop(&mut self) {
909        if self.pid > 0 {
910            untrack_pid(self.pid);
911        }
912    }
913}
914
#[cfg(test)]
mod tests {
    use super::*;
    use yoagent::AgentTool;

    // ── Shared helpers ──────────────────────────────────────────

    /// Builds a tool context for tool call `"id"` of tool `"bash"` that uses
    /// the given cancellation token and has no update or progress callbacks.
    fn tool_ctx_with_cancel(
        cancel: tokio_util::sync::CancellationToken,
    ) -> yoagent::types::ToolContext {
        yoagent::types::ToolContext {
            tool_call_id: "id".into(),
            tool_name: "bash".into(),
            cancel,
            on_update: None,
            on_progress: None,
        }
    }

    /// Builds a tool context with a fresh, un-cancelled token.
    fn tool_ctx() -> yoagent::types::ToolContext {
        tool_ctx_with_cancel(tokio_util::sync::CancellationToken::new())
    }

    /// Concatenates the text of every `Content::Text` item, skipping all
    /// other content kinds.
    fn yo_msg_text(content: &[yoagent::types::Content]) -> String {
        content
            .iter()
            .filter_map(|c| {
                if let yoagent::types::Content::Text { text } = c {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
            .join("")
    }

    /// Builds a `BashTool` rooted at the system temp directory with no shell
    /// override, no command prefix and no custom operations.
    fn make_tool() -> BashTool {
        BashTool {
            cwd: std::env::temp_dir(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        }
    }

    // ── Basic execution ─────────────────────────────────────────

    #[tokio::test]
    async fn runs_simple_command() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("hello"));
    }

    #[tokio::test]
    async fn captures_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo err >&2"}), tool_ctx())
            .await
            .unwrap();
        assert!(yo_msg_text(&output.content).contains("err"));
    }

    /// A token that is already cancelled must make `execute` fail.
    #[tokio::test]
    async fn cancel_aborts() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        cancel.cancel();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10"}),
                tool_ctx_with_cancel(cancel),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Cancelled") || err.contains("aborted"),
            "expected cancellation error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn timeout_works() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "sleep 10", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"), "got: {}", err);
    }

    // ── truncate_tail: basic cases ──────────────────────────────

    #[test]
    fn test_truncate_tail_no_truncation() {
        let result = truncate_tail("hello\nworld\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello\nworld\n");
    }

    #[test]
    fn test_truncate_tail_by_lines() {
        let content: String = (1..=5000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.starts_with("line 3001"));
        assert_eq!(result.content.lines().count(), 2000);
    }

    #[test]
    fn test_truncate_tail_by_bytes() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
        assert!(result.content.lines().count() < 100);
    }

    #[test]
    fn test_truncate_tail_partial_last_line() {
        let content = format!("short\n{}\n", "x".repeat(60000));
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(!result.content.starts_with("short"));
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_empty() {
        let result = truncate_tail("", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "");
    }

    // ── Exit codes and output capture ───────────────────────────

    #[tokio::test]
    async fn exit_code_nonzero() {
        let tool = make_tool();
        let result = tool
            .execute(serde_json::json!({"command": "exit 42"}), tool_ctx())
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exited with code 42"), "got: {}", err);
    }

    #[tokio::test]
    async fn exit_code_with_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo before && exit 1"}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err(), "non-zero exit should return error");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("before"), "got: {}", err);
        assert!(err.contains("exited with code 1"), "got: {}", err);
    }

    #[tokio::test]
    async fn no_output() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "true"}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("(no output)"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn combined_stdout_stderr() {
        let tool = make_tool();
        let output = tool
            .execute(
                serde_json::json!({"command": "echo out; echo err >&2"}),
                tool_ctx(),
            )
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("out"),
            "got: {}",
            yo_msg_text(&output.content)
        );
        assert!(
            yo_msg_text(&output.content).contains("err"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn runs_in_cwd() {
        let tmp = std::env::temp_dir().join(format!("rab-bash-cwd-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&tmp).unwrap();
        std::fs::write(tmp.join("marker.txt"), "hello").unwrap();

        let tool = BashTool {
            cwd: tmp.clone(),
            shell_path: None,
            command_prefix: None,
            operations: None,
        };
        let result = tool
            .execute(serde_json::json!({"command": "cat marker.txt"}), tool_ctx())
            .await;

        // Remove the scratch directory before asserting, so it is cleaned up
        // even when the checks below fail. Cleanup is best-effort.
        let _ = std::fs::remove_dir_all(&tmp);

        let output = result.unwrap();
        assert!(
            yo_msg_text(&output.content).contains("hello"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn missing_command_errors() {
        let tool = make_tool();
        let result = tool.execute(serde_json::json!({}), tool_ctx()).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("command"), "got: {}", err);
    }

    #[tokio::test]
    async fn timeout_with_partial_output() {
        let tool = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"command": "echo start && sleep 10 && echo end", "timeout": 1}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"), "got: {}", err);
    }

    /// Cancelling the token while the command is still running must make
    /// `execute` fail with a cancellation error.
    #[tokio::test]
    async fn cancel_during_long_command() {
        let tool = make_tool();
        let cancel = tokio_util::sync::CancellationToken::new();
        let cancel_ctx = cancel.clone();

        let handle = tokio::spawn(async move {
            tool.execute(
                serde_json::json!({"command": "sleep 30"}),
                tool_ctx_with_cancel(cancel_ctx),
            )
            .await
        });

        // Give the command time to start before cancelling it.
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        cancel.cancel();

        let result = handle.await.unwrap();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("aborted") || err.contains("Cancelled"),
            "expected cancellation error, got: {}",
            err
        );
    }

    // ── truncate_tail: boundary cases ───────────────────────────

    #[test]
    fn test_truncate_tail_exact_line_fit() {
        let lines: String = (1..=2000).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(!result.truncated);
        assert!(result.content.lines().count() == 2000);
    }

    #[test]
    fn test_truncate_tail_one_over_line_limit() {
        let lines: String = (1..=2001).map(|i| format!("line {}\n", i)).collect();
        let result = truncate_tail(&lines, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.content.lines().count(), 2000);
        assert!(result.content.starts_with("line 2"));
    }

    #[test]
    fn test_truncate_tail_exact_byte_fit() {
        let line = "a".repeat(50000);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(!result.truncated);
    }

    #[test]
    fn test_truncate_tail_one_byte_over() {
        let line = "a".repeat(50001);
        let result = truncate_tail(&line, 2000, 50000);
        assert!(result.truncated);
        assert!(result.content.len() <= 50000);
    }

    #[test]
    fn test_truncate_tail_single_line_under_limit() {
        let result = truncate_tail("hello world", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "hello world");
    }

    #[test]
    fn test_truncate_tail_trailing_newline() {
        let result = truncate_tail("a\nb\n", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb\n");
    }

    #[test]
    fn test_truncate_tail_no_trailing_newline() {
        let result = truncate_tail("a\nb", 2000, 50000);
        assert!(!result.truncated);
        assert_eq!(result.content, "a\nb");
    }

    #[test]
    fn test_truncate_tail_single_line_exceeds_limit() {
        let content = "x".repeat(60000);
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.last_line_partial);
        assert_eq!(result.content.len(), 50000);
        assert!(result.content.ends_with("x".repeat(50000).as_str()));
    }

    #[test]
    fn test_truncate_tail_byte_count_respects_newlines() {
        let content: String = (1..=100)
            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
            .collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert!(result.output_bytes <= 50000);
    }

    // ── Truncation footer and spill file ────────────────────────

    #[tokio::test]
    async fn truncated_by_lines_shows_footer() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("Showing lines"),
            "got: {}",
            yo_msg_text(&output.content)
        );
        assert!(
            yo_msg_text(&output.content).contains("Full output:"),
            "got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[tokio::test]
    async fn small_output_no_footer() {
        let tool = make_tool();
        let output = tool
            .execute(serde_json::json!({"command": "echo hello"}), tool_ctx())
            .await
            .unwrap();
        assert!(!yo_msg_text(&output.content).contains("Output truncated"));
        assert!(!yo_msg_text(&output.content).contains("Full output:"));
    }

    #[tokio::test]
    async fn truncated_saves_temp_file() {
        let tool = make_tool();
        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
        let output = tool
            .execute(serde_json::json!({"command": cmd}), tool_ctx())
            .await
            .unwrap();
        assert!(
            yo_msg_text(&output.content).contains("/pi-bash/"),
            "expected temp file path with /pi-bash/, got: {}",
            yo_msg_text(&output.content)
        );
    }

    #[test]
    fn test_truncate_tail_many_short_lines() {
        let content: String = (1..=10000).map(|i| format!("{}\n", i)).collect();
        let result = truncate_tail(&content, 2000, 50000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "lines");
        assert_eq!(result.output_lines, 2000);
        // Build the preview with `chars().take(..)` rather than a byte slice:
        // slicing would itself panic on short content and hide this message.
        assert!(
            result.content.starts_with("8001"),
            "starts with: {:?}",
            result.content.chars().take(10).collect::<String>()
        );
    }

    #[test]
    fn test_truncate_tail_lines_and_bytes_both_exceeded() {
        let content: String = (1..=5000)
            .map(|i| format!("line {} {}\n", i, "x".repeat(100)))
            .collect();
        let result = truncate_tail(&content, 2000, 30000);
        assert!(result.truncated);
        assert_eq!(result.truncated_by, "bytes");
        assert!(result.output_lines < 2000);
    }

    // ── ProcessGuard tests ──────────────────────────────────────
    //
    // These read the shared `TRACKED_PIDS` list into a plain `bool` before
    // asserting. Asserting while the lock guard is alive would poison the
    // lock on failure and make every other test that locks it fail too.

    #[test]
    fn test_process_guard_tracks_pid() {
        let pid = 12345u32;
        {
            let _guard = ProcessGuard::new(pid);
            let tracked = TRACKED_PIDS.lock().unwrap().contains(&pid);
            assert!(tracked);
        }
        let tracked = TRACKED_PIDS.lock().unwrap().contains(&pid);
        assert!(!tracked);
    }

    #[test]
    fn test_process_guard_zero_pid() {
        {
            let _guard = ProcessGuard::new(0);
            let tracked = TRACKED_PIDS.lock().unwrap().contains(&0);
            assert!(!tracked);
        }
    }
}