Skip to main content

rab/builtin/
bash.rs

1use crate::agent::extension::{AgentTool, Cancel, Extension, ToolOutput};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::tui::Theme;
4use crate::tui::visual_truncate::truncate_to_visual_lines;
5use anyhow::Context;
6use async_trait::async_trait;
7use std::borrow::Cow;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::time::Instant;
11use tokio::sync::{Mutex as TokioMutex, mpsc::UnboundedSender};
12
/// Extension that contributes the `bash` tool.
///
/// Holds the working directory that is copied into every tool instance it creates.
pub struct BashExtension {
    /// Directory in which spawned commands are run.
    cwd: std::path::PathBuf,
}
16
17impl BashExtension {
18    pub fn new(cwd: std::path::PathBuf) -> Self {
19        Self { cwd }
20    }
21}
22
23impl Extension for BashExtension {
24    fn name(&self) -> Cow<'static, str> {
25        "bash".into()
26    }
27
28    fn tools(&self) -> Vec<Box<dyn AgentTool>> {
29        vec![Box::new(BashTool {
30            cwd: self.cwd.clone(),
31        })]
32    }
33}
34
/// The `bash` agent tool: runs a shell command and returns its captured output.
struct BashTool {
    /// Directory in which commands are spawned.
    cwd: std::path::PathBuf,
}
38
39// ── Constants ────────────────────────────────────────────────────
40
/// Maximum number of trailing output lines kept in a tool result.
const DEFAULT_MAX_LINES: usize = 2000;
/// Maximum number of output bytes kept in a tool result.
const DEFAULT_MAX_BYTES: usize = 50 * 1024; // 50KB
/// Timeout, in seconds, applied when the caller does not supply one.
const DEFAULT_TIMEOUT_SECS: u64 = 300; // 5 minutes default timeout for all commands
44
45// ── Helpers ──────────────────────────────────────────────────────
46
/// Signal every process in the group led by `pid`.
///
/// The signal is delivered by running the external `kill` utility with a
/// negated PID (`kill -- -<pid>`), which addresses the whole process group;
/// with no signal option `kill` sends its default signal (SIGTERM).
///
/// A `pid` of 0 means "no process" and is ignored. Delivery is best-effort:
/// a failure to launch `kill`, or a group that no longer exists, is silently
/// ignored.
#[cfg(unix)]
fn kill_process_group(pid: u32) {
    if pid == 0 {
        return;
    }

    let spawned = std::process::Command::new("kill")
        .arg("--")
        .arg(format!("-{}", pid))
        // Detach the helper from our terminal so its diagnostics
        // (e.g. "No such process") never end up in the parent's output.
        .stdin(std::process::Stdio::null())
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .spawn();

    if let Ok(mut helper) = spawned {
        // A std `Child` is not reaped on drop; without a `wait` every call
        // would leave a zombie behind. Reap on a short-lived thread so callers
        // running on an async executor are never blocked.
        std::thread::spawn(move || {
            let _ = helper.wait();
        });
    }
}

/// No-op on platforms without POSIX process groups.
#[cfg(not(unix))]
fn kill_process_group(pid: u32) {
    let _ = pid;
}
62
63/// Spawn a bash command with process group setup for clean cancellation.
64fn spawn_bash_command(
65    command: &str,
66    cwd: &std::path::Path,
67) -> std::io::Result<tokio::process::Child> {
68    #[cfg(unix)]
69    {
70        use std::os::unix::process::CommandExt;
71        let mut std_cmd = std::process::Command::new("sh");
72        std_cmd.arg("-c").arg(command).current_dir(cwd);
73        unsafe {
74            std_cmd.pre_exec(|| {
75                libc::setpgid(0, 0);
76                Ok(())
77            });
78        }
79        let mut tokio_cmd = tokio::process::Command::from(std_cmd);
80        tokio_cmd
81            .stdout(std::process::Stdio::piped())
82            .stderr(std::process::Stdio::piped())
83            .spawn()
84    }
85    #[cfg(not(unix))]
86    {
87        tokio::process::Command::new("sh")
88            .arg("-c")
89            .arg(command)
90            .current_dir(cwd)
91            .stdout(std::process::Stdio::piped())
92            .stderr(std::process::Stdio::piped())
93            .spawn()
94    }
95}
96
97/// Format the final bash execution result, matching pi's bash tool output format.
98///
99/// Pi's bash tool (LLM-called) returns raw output, not the `bashExecutionToText` format.
100/// - Non-empty output → raw output (no Ran prefix, no backtick fences)
101/// - Empty output → "(no output)"
102/// - Truncated → raw output + `\n\n[Showing lines X-Y of Z... Full output: path]`
103/// - Non-zero exit → returned as Err with output + `\n\nCommand exited with code N`
104/// - Cancelled → returned as Err with output + `\n\nCommand aborted`
105fn finish_bash_execution(
106    _command: &str,
107    combined: &str,
108    exit_code: i32,
109    cancelled: bool,
110    _started_at: Instant,
111    on_update: Option<UnboundedSender<ToolOutput>>,
112) -> Result<ToolOutput, anyhow::Error> {
113    // Apply tail truncation (pi-style: keep last N lines/bytes)
114    let trunc = truncate_tail(combined, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES);
115
116    // Build output text: raw output or (no output)
117    let mut result_text = if trunc.content.is_empty() {
118        "(no output)".to_string()
119    } else {
120        trunc.content.clone()
121    };
122
123    // Truncation notice (matching pi: appended to text, not in details)
124    if trunc.truncated {
125        let tmp_dir = std::env::temp_dir().join("rab-bash");
126        let _ = std::fs::create_dir_all(&tmp_dir);
127        let tmp_path = tmp_dir.join(format!("{}.txt", uuid::Uuid::new_v4()));
128        let saved = if std::fs::write(&tmp_path, combined).is_ok() {
129            Some(tmp_path)
130        } else {
131            None
132        };
133
134        let start_line = trunc.total_lines - trunc.output_lines + 1;
135        let end_line = trunc.total_lines;
136
137        let notice = if trunc.truncated_by == "lines" {
138            format!(
139                "\n\n[Showing lines {}-{} of {}. Full output: {}]",
140                start_line,
141                end_line,
142                trunc.total_lines,
143                saved
144                    .as_ref()
145                    .map(|p| p.display().to_string())
146                    .unwrap_or_default()
147            )
148        } else {
149            format!(
150                "\n\n[Showing lines {}-{} of {} ({} limit). Full output: {}]",
151                start_line,
152                end_line,
153                trunc.total_lines,
154                format_size(DEFAULT_MAX_BYTES),
155                saved
156                    .as_ref()
157                    .map(|p| p.display().to_string())
158                    .unwrap_or_default()
159            )
160        };
161        result_text.push_str(&notice);
162    }
163
164    // Send final update (before error conversion, so UI shows the output)
165    if let Some(ref tx) = on_update {
166        let _ = tx.send(ToolOutput::ok(result_text.clone()));
167    }
168
169    // Error cases: return as Err with output + status (matching pi)
170    if cancelled {
171        let err_msg = if result_text.is_empty() || result_text == "(no output)" {
172            "Command aborted".to_string()
173        } else {
174            format!("{}\n\nCommand aborted", result_text)
175        };
176        return Err(anyhow::anyhow!("{}", err_msg));
177    }
178
179    if exit_code != 0 {
180        let err_msg = if result_text.is_empty() || result_text == "(no output)" {
181            format!("Command exited with code {}", exit_code)
182        } else {
183            format!("{}\n\nCommand exited with code {}", result_text, exit_code)
184        };
185        return Err(anyhow::anyhow!("{}", err_msg));
186    }
187
188    Ok(ToolOutput::ok(result_text))
189}
190
/// Render a byte count as a short human-readable size, matching pi's format.
///
/// Values below 1024 are printed as whole bytes (`512B`); values below
/// 1024 * 1024 as kilobytes with one decimal (`1.5KB`); everything else as
/// megabytes with one decimal (`2.0MB`).
fn format_size(bytes: usize) -> String {
    const KIB: usize = 1024;
    const MIB: usize = KIB * KIB;

    match bytes {
        small if small < KIB => format!("{}B", small),
        medium if medium < MIB => format!("{:.1}KB", medium as f64 / 1024.0),
        large => format!("{:.1}MB", large as f64 / (1024.0 * 1024.0)),
    }
}
201
/// Result of tail-based truncation (keep the last N lines/bytes).
struct TailTruncation {
    /// Output content after truncation (the original content when nothing was cut).
    content: String,
    /// Whether truncation occurred.
    truncated: bool,
    /// Number of lines in the original content.
    total_lines: usize,
    /// Number of lines kept in `content`.
    output_lines: usize,
    /// Number of bytes kept in `content`.
    // Not read by the non-test code of this module; kept for test assertions.
    #[allow(dead_code)]
    output_bytes: usize,
    /// Which limit caused the truncation: `"lines"` or `"bytes"`
    /// (empty when nothing was truncated).
    truncated_by: &'static str,
    /// Whether `content` starts in the middle of a line because that single
    /// line alone exceeded the byte limit.
    // Not read by the non-test code of this module; kept for test assertions.
    #[allow(dead_code)]
    last_line_partial: bool,
}
220
221/// Truncate content from the tail, keeping complete lines that fit within limits.
222/// Keeps the LAST N lines/bytes. Never returns partial lines unless the last line
223/// of the original content exceeds the byte limit.
224fn truncate_tail(content: &str, max_lines: usize, max_bytes: usize) -> TailTruncation {
225    let total_bytes = content.len();
226    let lines: Vec<&str> = content.lines().collect();
227    let total_lines = lines.len();
228
229    // Check if no truncation needed
230    if total_lines <= max_lines && total_bytes <= max_bytes {
231        return TailTruncation {
232            content: content.to_string(),
233            truncated: false,
234            total_lines,
235            output_lines: total_lines,
236            output_bytes: total_bytes,
237            truncated_by: "",
238            last_line_partial: false,
239        };
240    }
241
242    // Work backwards from the end
243    let mut output: Vec<&str> = Vec::new();
244    let mut byte_count: usize = 0;
245    let mut truncated_by = "lines";
246    let mut last_line_partial = false;
247
248    for line in lines.iter().rev().take(max_lines) {
249        let line_bytes = line.len();
250        let with_newline = if output.is_empty() {
251            line_bytes
252        } else {
253            line_bytes + 1 // +1 for preceding newline
254        };
255
256        if byte_count + with_newline > max_bytes {
257            truncated_by = "bytes";
258            // If we haven't added ANY lines yet and this line exceeds maxBytes,
259            // take the end of the line (partial)
260            if output.is_empty() {
261                let end_start = line.len().saturating_sub(max_bytes);
262                let truncated_line = &line[end_start..];
263                output.push(truncated_line);
264                byte_count = truncated_line.len();
265                last_line_partial = true;
266            }
267            break;
268        }
269
270        output.push(line);
271        byte_count += with_newline;
272    }
273
274    if output.len() >= max_lines && byte_count <= max_bytes {
275        truncated_by = "lines";
276    }
277
278    output.reverse();
279    TailTruncation {
280        content: output.join("\n"),
281        truncated: true,
282        total_lines,
283        output_lines: output.len(),
284        output_bytes: byte_count,
285        truncated_by,
286        last_line_partial,
287    }
288}
289
290// ── AgentTool implementation ─────────────────────────────────────
291
292#[async_trait]
293impl AgentTool for BashTool {
294    fn name(&self) -> &str {
295        "bash"
296    }
297
298    fn description(&self) -> &str {
299        "Execute a bash command in the current working directory. Returns stdout and stderr. \
300         Output is truncated to last 2000 lines or 50KB (whichever is hit first). If truncated, \
301         full output is saved to a temp file. Optionally provide a timeout in seconds."
302    }
303
304    fn parameters(&self) -> serde_json::Value {
305        serde_json::json!({
306            "type": "object",
307            "required": ["command"],
308            "properties": {
309                "command": {
310                    "type": "string",
311                    "description": "Bash command to execute"
312                },
313                "timeout": {
314                    "type": "number",
315                    "description": "Timeout in seconds (optional, no default timeout)"
316                }
317            }
318        })
319    }
320
321    fn label(&self) -> &str {
322        "Execute bash commands (ls, grep, find, etc.)"
323    }
324
325    fn renderer(&self) -> Option<Box<dyn ToolRenderer>> {
326        Some(Box::new(BashRenderer))
327    }
328
329    async fn execute(
330        &self,
331        tool_call_id: String,
332        args: serde_json::Value,
333        cancel: Cancel,
334        on_update: Option<UnboundedSender<ToolOutput>>,
335    ) -> anyhow::Result<ToolOutput> {
336        let _ = tool_call_id;
337        let command = args["command"]
338            .as_str()
339            .ok_or_else(|| anyhow::anyhow!("Missing 'command' argument"))?;
340        let timeout = args["timeout"].as_u64().or(Some(DEFAULT_TIMEOUT_SECS));
341        let started_at = Instant::now();
342
343        cancel.check()?;
344
345        // Build the command with process group setup for process-tree killing
346        let mut child = spawn_bash_command(command, &self.cwd)
347            .with_context(|| format!("Failed to spawn command: {}", command))?;
348
349        let pid = child.id().unwrap_or(0);
350
351        // Shared output buffer for streaming reads
352        let combined = Arc::new(TokioMutex::new(String::new()));
353        let combined_clone = combined.clone();
354
355        // Read stdout in a background task
356        let stdout_pipe = child
357            .stdout
358            .take()
359            .ok_or_else(|| anyhow::anyhow!("Failed to capture stdout"))?;
360        let stderr_pipe = child
361            .stderr
362            .take()
363            .ok_or_else(|| anyhow::anyhow!("Failed to capture stderr"))?;
364
365        use tokio::io::AsyncReadExt;
366        let read_task = tokio::spawn(async move {
367            let mut stdout_buf = vec![0u8; 4096];
368            let mut stderr_buf = vec![0u8; 4096];
369            let mut stdout_reader = stdout_pipe;
370            let mut stderr_reader = stderr_pipe;
371            let mut stdout_done = false;
372            let mut stderr_done = false;
373            loop {
374                tokio::select! {
375                    result = stdout_reader.read(&mut stdout_buf), if !stdout_done => {
376                        match result {
377                            Ok(0) => stdout_done = true,
378                            Ok(n) => {
379                                let mut out = combined_clone.lock().await;
380                                out.push_str(&String::from_utf8_lossy(&stdout_buf[..n]));
381                            }
382                            Err(_) => stdout_done = true,
383                        }
384                    }
385                    result = stderr_reader.read(&mut stderr_buf), if !stderr_done => {
386                        match result {
387                            Ok(0) => stderr_done = true,
388                            Ok(n) => {
389                                let mut out = combined_clone.lock().await;
390                                out.push_str(&String::from_utf8_lossy(&stderr_buf[..n]));
391                            }
392                            Err(_) => stderr_done = true,
393                        }
394                    }
395                }
396                if stdout_done && stderr_done {
397                    break;
398                }
399            }
400        });
401
402        // Set up cancellation monitor: kill the process group if cancelled
403        let cancelled = Arc::new(AtomicBool::new(false));
404        let cancel_clone = cancelled.clone();
405        let _cancel_monitor: tokio::task::JoinHandle<()> = tokio::spawn(async move {
406            while !cancel.is_cancelled() {
407                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
408            }
409            cancel_clone.store(true, Ordering::SeqCst);
410            kill_process_group(pid);
411        });
412
413        // Wait for the process to exit, with optional timeout and streaming updates
414        let timeout_dur = timeout.map(std::time::Duration::from_secs);
415        loop {
416            // Check cancellation
417            if cancelled.load(Ordering::SeqCst) {
418                kill_process_group(pid);
419                read_task.abort();
420                return Err(anyhow::anyhow!("Command aborted"));
421            }
422
423            // Check timeout
424            if let Some(dur) = timeout_dur
425                && started_at.elapsed() > dur
426            {
427                kill_process_group(pid);
428                read_task.abort();
429                return Err(anyhow::anyhow!(
430                    "Command timed out after {} seconds",
431                    timeout.unwrap_or(0)
432                ));
433            }
434
435            // Send streaming update (1s tick interval, matching pi)
436            if let Some(ref tx) = on_update {
437                let out = combined.lock().await;
438                if !out.is_empty() {
439                    let elapsed = started_at.elapsed();
440                    let display = format!(
441                        "{}\n\n[Elapsed {:.1}s]",
442                        out.trim_end(),
443                        elapsed.as_secs_f64()
444                    );
445                    let _ = tx.send(ToolOutput::ok(display));
446                }
447            }
448
449            // Check if process has exited
450            match child.try_wait() {
451                Ok(Some(status)) => {
452                    read_task.await.ok();
453                    let combined_str = combined.lock().await.clone();
454                    let exit_code = status.code().unwrap_or(-1);
455
456                    return finish_bash_execution(
457                        command,
458                        &combined_str,
459                        exit_code,
460                        false,
461                        started_at,
462                        on_update,
463                    );
464                }
465                Ok(None) => {
466                    // Still running, poll again soon (1s tick, matching pi)
467                    tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
468                }
469                Err(_) => {
470                    read_task.await.ok();
471                    let combined_str = combined.lock().await.clone();
472                    let exit_code = -1;
473                    return finish_bash_execution(
474                        command,
475                        &combined_str,
476                        exit_code,
477                        false,
478                        started_at,
479                        on_update,
480                    );
481                }
482            }
483        }
484    }
485}
486
/// Tool renderer for the `bash` tool.
///
/// Renders the call as a one-line header (a recognized command name with a
/// short detail, or `$ command`) and the result as a preview of the output
/// followed by duration, status and truncation lines.
struct BashRenderer;
490
491// ── Command detection for better headers ────────────────────────
492
493/// Try to extract a meaningful description from a command string.
494/// Returns (command_name, description) if recognized.
495fn parse_command(cmd: &str) -> Option<(&'static str, Option<String>)> {
496    let trimmed = cmd.trim();
497
498    // Skip leading env vars (VAR=value) and cd commands
499    let effective = {
500        let mut rest = trimmed;
501        loop {
502            // Check for VAR=value pattern
503            if let Some(eq_pos) = rest.find('=') {
504                let var_name = &rest[..eq_pos];
505                // Valid env var name: only alphanumeric and underscore
506                if !var_name.is_empty() && var_name.chars().all(|c| c.is_alphanumeric() || c == '_')
507                {
508                    // Skip past the value (space-separated or end)
509                    let after_eq = &rest[eq_pos + 1..];
510                    if let Some(space_pos) = after_eq.find(' ') {
511                        rest = after_eq[space_pos + 1..].trim_start();
512                        continue;
513                    } else {
514                        // No space after value - this is just a VAR=value command
515                        rest = "";
516                        break;
517                    }
518                }
519            }
520            break;
521        }
522        rest
523    };
524
525    // ls
526    if effective.starts_with("ls ") || effective == "ls" {
527        let path = extract_ls_path(effective);
528        return Some(("ls", path));
529    }
530
531    // grep
532    if effective.starts_with("grep ") || effective.starts_with("rg ") {
533        let info = extract_grep_info(effective);
534        return Some(("grep", info));
535    }
536
537    // find
538    if effective.starts_with("find ") {
539        let info = extract_find_info(effective);
540        return Some(("find", info));
541    }
542
543    // cat
544    if effective.starts_with("cat ") || effective == "cat" {
545        let path = effective.strip_prefix("cat ").map(|s| s.trim().to_string());
546        return Some(("cat", path));
547    }
548
549    // head/tail
550    if effective.starts_with("head ") || effective.starts_with("tail ") {
551        let (cmd_name, rest) = if effective.starts_with("head") {
552            ("head", effective.strip_prefix("head").unwrap_or(""))
553        } else {
554            ("tail", effective.strip_prefix("tail").unwrap_or(""))
555        };
556        let path = rest.trim();
557        let path_opt = if path.is_empty() {
558            None
559        } else {
560            Some(path.to_string())
561        };
562        return Some((cmd_name, path_opt));
563    }
564
565    // wc
566    if effective.starts_with("wc ") || effective == "wc" {
567        let path = effective.strip_prefix("wc ").map(|s| s.trim().to_string());
568        return Some(("wc", path));
569    }
570
571    None
572}
573
/// Pick the path shown for an `ls` invocation.
///
/// With no arguments the current directory (`"."`) is reported. Otherwise the
/// last argument that does not start with `-` is used; if every argument is a
/// flag the result is `None`.
fn extract_ls_path(cmd: &str) -> Option<String> {
    let rest = match cmd.strip_prefix("ls") {
        Some(rest) => rest.trim(),
        None => "",
    };
    if rest.is_empty() {
        return Some(String::from("."));
    }
    rest.split_whitespace()
        .rev()
        .find(|token| !token.starts_with('-'))
        .map(str::to_owned)
}
587
/// Summarize a `grep`/`rg` invocation as `pattern` or `pattern in file1, file2`.
///
/// Arguments are split on whitespace. Anything starting with `-` is treated as
/// a flag and skipped; flags listed in `VALUE_FLAGS` also swallow the argument
/// that follows them. The first remaining argument is the pattern, the rest
/// are files. Returns `None` when no pattern is found.
fn extract_grep_info(cmd: &str) -> Option<String> {
    // Flags whose value is passed as the next, separate argument.
    // `-n` is deliberately absent: it only turns on line numbers and takes no
    // value, so treating it as value-taking would swallow the pattern.
    const VALUE_FLAGS: [&str; 5] = ["-m", "-C", "-A", "-B", "--max-count"];

    let args = cmd
        .strip_prefix("grep")
        .or_else(|| cmd.strip_prefix("rg"))
        .unwrap_or("")
        .trim();

    let mut pattern = None;
    let mut files = Vec::new();
    let mut tokens = args.split_whitespace();
    while let Some(arg) = tokens.next() {
        if arg.starts_with('-') {
            if VALUE_FLAGS.contains(&arg) {
                // Discard the flag's value so it is not mistaken for the pattern.
                tokens.next();
            }
            continue;
        }
        if pattern.is_none() {
            pattern = Some(arg);
        } else {
            files.push(arg);
        }
    }

    let pattern = pattern?;
    Some(if files.is_empty() {
        pattern.to_string()
    } else {
        format!("{} in {}", pattern, files.join(", "))
    })
}
630
/// Summarize a `find` invocation as `path` or `path (name=PATTERN)`.
///
/// The path is the first argument that is neither a flag (starts with `-`)
/// nor the value following `-name`, `-path` or `-type`; it defaults to `"."`.
/// The pattern is the argument following the last `-name`, if any.
fn extract_find_info(cmd: &str) -> Option<String> {
    let args = match cmd.strip_prefix("find") {
        Some(rest) => rest.trim(),
        None => "",
    };
    if args.is_empty() {
        return Some(String::from("."));
    }

    // First pass: locate the starting path.
    let mut start_dir = None;
    let mut tokens = args.split_whitespace();
    while let Some(token) = tokens.next() {
        match token {
            // These flags take a value; consume it so it is not read as the path.
            "-name" | "-path" | "-type" => {
                tokens.next();
            }
            flag if flag.starts_with('-') => {}
            bare => {
                start_dir = Some(bare);
                break;
            }
        }
    }

    // Second pass: remember the value of the last `-name`.
    let mut name_pattern = None;
    let mut tokens = args.split_whitespace();
    while let Some(token) = tokens.next() {
        if token == "-name" {
            name_pattern = tokens.next();
        }
    }

    let root = start_dir.unwrap_or(".");
    Some(match name_pattern {
        Some(pattern) => format!("{} (name={})", root, pattern),
        None => root.to_string(),
    })
}
673
674/// Format a header for recognized commands.
675fn format_command_header(cmd: &str, theme: &dyn Theme) -> Option<String> {
676    let (name, desc) = parse_command(cmd)?;
677    let title = theme.fg("toolTitle", &theme.bold(name));
678    let detail = desc
679        .map(|d| format!(" {}", theme.fg("accent", &d)))
680        .unwrap_or_default();
681    Some(format!("{}{}", title, detail))
682}
683
684// ── Visual-line-aware truncation (delegated to shared module) ────
685
686impl ToolRenderer for BashRenderer {
687    fn render_call(
688        &self,
689        args: &serde_json::Value,
690        _width: usize,
691        theme: &dyn Theme,
692        _ctx: &ToolRenderContext,
693    ) -> Vec<String> {
694        let cmd = args
695            .get("command")
696            .and_then(|v| v.as_str())
697            .unwrap_or("...");
698        let timeout = args.get("timeout").and_then(|v| v.as_i64());
699        let timeout_suffix = timeout
700            .map(|t| theme.fg("muted", &format!(" (timeout {}s)", t)))
701            .unwrap_or_default();
702
703        // Detect common commands and show them with a nicer header
704        if let Some(header) = format_command_header(cmd, theme) {
705            vec![format!("{}{}", header, timeout_suffix)]
706        } else {
707            vec![format!(
708                "{}{}",
709                theme.fg("toolTitle", &theme.bold(&format!("$ {}", cmd))),
710                timeout_suffix
711            )]
712        }
713    }
714
715    fn render_result(
716        &self,
717        content: &str,
718        width: usize,
719        theme: &dyn Theme,
720        ctx: &ToolRenderContext,
721    ) -> Vec<String> {
722        let mut lines: Vec<String> = Vec::new();
723
724        // Strip truncation footer
725        let clean = strip_context_truncation_footer(content);
726        let all_lines: Vec<&str> = clean.split('\n').collect();
727
728        if all_lines.is_empty() || (all_lines.len() == 1 && all_lines[0].is_empty()) {
729            return lines;
730        }
731
732        // Visual-line-aware truncation (matching pi's truncateToVisualLines)
733        let preview_count = 5;
734        let (preview_lines, hidden_line_count) = if ctx.expanded {
735            (all_lines.clone(), 0)
736        } else {
737            truncate_to_visual_lines(&all_lines, width, preview_count)
738        };
739
740        if !ctx.expanded && hidden_line_count > 0 {
741            let hint = if ctx.expand_key.is_empty() {
742                theme.fg("muted", &format!("... {} earlier lines", hidden_line_count))
743            } else {
744                theme.fg(
745                    "muted",
746                    &format!(
747                        "... ({} earlier lines, {} to expand)",
748                        hidden_line_count, ctx.expand_key
749                    ),
750                )
751            };
752            lines.push(hint);
753        }
754
755        let fg_key = if ctx.is_error { "error" } else { "toolOutput" };
756        for line in &preview_lines {
757            if line.is_empty() {
758                lines.push(String::new());
759            } else {
760                lines.push(theme.fg(fg_key, line));
761            }
762        }
763
764        // Duration
765        if let Some(secs) = ctx.duration_secs {
766            let is_complete = ctx.exit_code.is_some() || ctx.cancelled;
767            let label = if is_complete { "Took" } else { "Elapsed" };
768            lines.push(theme.fg("muted", &format!("{} {:.1}s", label, secs)));
769        }
770
771        // Status
772        if ctx.cancelled {
773            lines.push(theme.fg("warning", "(cancelled)"));
774        } else if let Some(code) = ctx.exit_code
775            && code != 0
776        {
777            lines.push(theme.fg("warning", &format!("(exit {})", code)));
778        }
779
780        // Truncation warnings
781        if ctx.was_truncated {
782            if let Some(ref path) = ctx.full_output_path {
783                lines.push(theme.fg(
784                    "warning",
785                    &format!("Output truncated. Full output: {}", path),
786                ));
787            } else {
788                lines.push(theme.fg("warning", "Output truncated."));
789            }
790        }
791
792        lines
793    }
794}
795
/// Remove the trailing truncation notice from bash output.
///
/// The notice is a final line that, once trimmed, starts with `[`, contains
/// `Showing lines` or `Showing last`, and contains `Full output:`. When it is
/// present it is dropped together with one empty line directly before it, and
/// the remaining lines are re-joined with `\n`. Output with fewer than three
/// lines, or without such a notice, is returned unchanged.
fn strip_context_truncation_footer(output: &str) -> String {
    let lines: Vec<&str> = output.lines().collect();
    let [body @ .., footer] = lines.as_slice() else {
        return output.to_string();
    };
    if lines.len() < 3 {
        return output.to_string();
    }

    let footer = footer.trim();
    let mentions_range = footer.contains("Showing lines") || footer.contains("Showing last");
    let is_notice = footer.starts_with('[') && mentions_range && footer.contains("Full output:");
    if !is_notice {
        return output.to_string();
    }

    // Also drop the blank separator line that precedes the notice, if any.
    let kept = match body.split_last() {
        Some((separator, rest)) if separator.is_empty() => rest,
        _ => body,
    };
    kept.join("\n")
}
817
818#[cfg(test)]
819mod tests {
820    use super::*;
821
822    fn make_tool() -> BashTool {
823        BashTool {
824            cwd: std::env::temp_dir(),
825        }
826    }
827
828    #[tokio::test]
829    async fn runs_simple_command() {
830        let tool = make_tool();
831        let output = tool
832            .execute(
833                "id".into(),
834                serde_json::json!({"command": "echo hello"}),
835                Cancel::new(),
836                None,
837            )
838            .await
839            .unwrap();
840        assert!(output.content.contains("hello"));
841    }
842
843    #[tokio::test]
844    async fn captures_stderr() {
845        let tool = make_tool();
846        let output = tool
847            .execute(
848                "id".into(),
849                serde_json::json!({"command": "echo err >&2"}),
850                Cancel::new(),
851                None,
852            )
853            .await
854            .unwrap();
855        assert!(output.content.contains("err"));
856    }
857
858    #[tokio::test]
859    async fn cancel_aborts() {
860        let tool = make_tool();
861        let cancel = Cancel::new();
862        cancel.cancel();
863        let result = tool
864            .execute(
865                "id".into(),
866                serde_json::json!({"command": "sleep 10"}),
867                cancel,
868                None,
869            )
870            .await;
871        assert!(result.is_err());
872        let err = result.unwrap_err().to_string();
873        assert!(
874            err.contains("cancelled") || err.contains("aborted"),
875            "expected cancellation error, got: {}",
876            err
877        );
878    }
879
880    #[tokio::test]
881    async fn timeout_works() {
882        let tool = make_tool();
883        let result = tool
884            .execute(
885                "id".into(),
886                serde_json::json!({"command": "sleep 10", "timeout": 1}),
887                Cancel::new(),
888                None,
889            )
890            .await;
891        assert!(result.is_err());
892        let err = result.unwrap_err().to_string();
893        assert!(err.contains("timed out"));
894    }
895
896    #[test]
897    fn test_truncate_tail_no_truncation() {
898        let result = truncate_tail("hello\nworld\n", 2000, 50000);
899        assert!(!result.truncated);
900        assert_eq!(result.content, "hello\nworld\n");
901    }
902
903    #[test]
904    fn test_truncate_tail_by_lines() {
905        let content: String = (1..=5000).map(|i| format!("line {}\n", i)).collect();
906        let result = truncate_tail(&content, 2000, 50000);
907        assert!(result.truncated);
908        assert!(result.content.starts_with("line 3001"));
909        assert_eq!(result.content.lines().count(), 2000);
910    }
911
912    #[test]
913    fn test_truncate_tail_by_bytes() {
914        let content: String = (1..=100)
915            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
916            .collect();
917        let result = truncate_tail(&content, 2000, 50000);
918        assert!(result.truncated);
919        assert!(result.content.len() <= 50000);
920        assert!(result.content.lines().count() < 100);
921    }
922
923    #[test]
924    fn test_truncate_tail_partial_last_line() {
925        // A single line that exceeds the byte limit
926        let content = format!("short\n{}\n", "x".repeat(60000));
927        let result = truncate_tail(&content, 2000, 50000);
928        assert!(result.truncated);
929        assert!(!result.content.starts_with("short"));
930        assert!(result.content.len() <= 50000);
931    }
932
933    #[test]
934    fn test_truncate_tail_empty() {
935        let result = truncate_tail("", 2000, 50000);
936        assert!(!result.truncated);
937        assert_eq!(result.content, "");
938    }
939
940    // ── Exit code integration tests ──────────────────────────────
941
942    #[tokio::test]
943    async fn exit_code_nonzero() {
944        let tool = make_tool();
945        let result = tool
946            .execute(
947                "id".into(),
948                serde_json::json!({"command": "exit 42"}),
949                Cancel::new(),
950                None,
951            )
952            .await;
953        assert!(result.is_err(), "non-zero exit should return error");
954        let err = result.unwrap_err().to_string();
955        assert!(err.contains("exited with code 42"), "got: {}", err);
956    }
957
958    #[tokio::test]
959    async fn exit_code_with_output() {
960        let tool = make_tool();
961        let result = tool
962            .execute(
963                "id".into(),
964                serde_json::json!({"command": "echo before && exit 1"}),
965                Cancel::new(),
966                None,
967            )
968            .await;
969        assert!(result.is_err(), "non-zero exit should return error");
970        let err = result.unwrap_err().to_string();
971        assert!(err.contains("before"), "got: {}", err);
972        assert!(err.contains("exited with code 1"), "got: {}", err);
973    }
974
975    #[tokio::test]
976    async fn no_output() {
977        let tool = make_tool();
978        let output = tool
979            .execute(
980                "id".into(),
981                serde_json::json!({"command": "true"}),
982                Cancel::new(),
983                None,
984            )
985            .await
986            .unwrap();
987        assert!(
988            output.content.contains("(no output)"),
989            "got: {}",
990            output.content
991        );
992    }
993
994    #[tokio::test]
995    async fn combined_stdout_stderr() {
996        let tool = make_tool();
997        let output = tool
998            .execute(
999                "id".into(),
1000                serde_json::json!({"command": "echo out; echo err >&2"}),
1001                Cancel::new(),
1002                None,
1003            )
1004            .await
1005            .unwrap();
1006        assert!(output.content.contains("out"), "got: {}", output.content);
1007        assert!(output.content.contains("err"), "got: {}", output.content);
1008    }
1009
1010    #[tokio::test]
1011    async fn runs_in_cwd() {
1012        let tmp = std::env::temp_dir().join(format!("rab-bash-cwd-{}", uuid::Uuid::new_v4()));
1013        std::fs::create_dir_all(&tmp).unwrap();
1014        std::fs::write(tmp.join("marker.txt"), "hello").unwrap();
1015
1016        let tool = BashTool { cwd: tmp.clone() };
1017        let output = tool
1018            .execute(
1019                "id".into(),
1020                serde_json::json!({"command": "cat marker.txt"}),
1021                Cancel::new(),
1022                None,
1023            )
1024            .await
1025            .unwrap();
1026        assert!(output.content.contains("hello"), "got: {}", output.content);
1027    }
1028
1029    #[tokio::test]
1030    async fn missing_command_errors() {
1031        let tool = make_tool();
1032        let result = tool
1033            .execute("id".into(), serde_json::json!({}), Cancel::new(), None)
1034            .await;
1035        assert!(result.is_err());
1036        let err = result.unwrap_err().to_string();
1037        assert!(err.contains("command"), "got: {}", err);
1038    }
1039
1040    #[tokio::test]
1041    async fn timeout_with_partial_output() {
1042        let tool = make_tool();
1043        // Command that produces some output then hangs
1044        let result = tool
1045            .execute(
1046                "id".into(),
1047                serde_json::json!({"command": "echo start && sleep 10 && echo end", "timeout": 1}),
1048                Cancel::new(),
1049                None,
1050            )
1051            .await;
1052        // May timeout before process is killed, which is fine
1053        // The key is it doesn't hang forever
1054        assert!(result.is_err());
1055        let err = result.unwrap_err().to_string();
1056        assert!(err.contains("timed out"), "got: {}", err);
1057    }
1058
1059    #[tokio::test]
1060    async fn cancel_during_long_command() {
1061        let tool = make_tool();
1062        let cancel = Cancel::new();
1063        let cancel_clone = cancel.clone();
1064
1065        let handle = tokio::spawn(async move {
1066            tool.execute(
1067                "id".into(),
1068                serde_json::json!({"command": "sleep 30"}),
1069                cancel_clone,
1070                None,
1071            )
1072            .await
1073        });
1074
1075        // Give it a moment to start
1076        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1077        cancel.cancel();
1078
1079        let result = handle.await.unwrap();
1080        assert!(result.is_err());
1081        let err = result.unwrap_err().to_string();
1082        assert!(
1083            err.contains("aborted") || err.contains("cancelled"),
1084            "expected cancellation error, got: {}",
1085            err
1086        );
1087    }
1088
1089    // ── Truncation boundary tests ────────────────────────────────
1090
1091    #[test]
1092    fn test_truncate_tail_exact_line_fit() {
1093        // Content exactly at the line limit - no truncation
1094        let lines: String = (1..=2000).map(|i| format!("line {}\n", i)).collect();
1095        let result = truncate_tail(&lines, 2000, 50000);
1096        assert!(
1097            !result.truncated,
1098            "should not truncate when exactly at line limit"
1099        );
1100        assert!(result.content.lines().count() == 2000);
1101    }
1102
1103    #[test]
1104    fn test_truncate_tail_one_over_line_limit() {
1105        let lines: String = (1..=2001).map(|i| format!("line {}\n", i)).collect();
1106        let result = truncate_tail(&lines, 2000, 50000);
1107        assert!(result.truncated);
1108        assert_eq!(result.content.lines().count(), 2000);
1109        // Should keep last 2000 lines
1110        assert!(result.content.starts_with("line 2"));
1111    }
1112
1113    #[test]
1114    fn test_truncate_tail_exact_byte_fit() {
1115        // Content exactly at byte limit - no truncation
1116        let line = "a".repeat(50000);
1117        let result = truncate_tail(&line, 2000, 50000);
1118        assert!(!result.truncated);
1119    }
1120
1121    #[test]
1122    fn test_truncate_tail_one_byte_over() {
1123        // Content one byte over the limit
1124        let line = "a".repeat(50001);
1125        let result = truncate_tail(&line, 2000, 50000);
1126        assert!(result.truncated);
1127        assert!(result.content.len() <= 50000);
1128    }
1129
1130    #[test]
1131    fn test_truncate_tail_single_line_under_limit() {
1132        let result = truncate_tail("hello world", 2000, 50000);
1133        assert!(!result.truncated);
1134        assert_eq!(result.content, "hello world");
1135    }
1136
1137    #[test]
1138    fn test_truncate_tail_trailing_newline() {
1139        let result = truncate_tail("a\nb\n", 2000, 50000);
1140        assert!(!result.truncated);
1141        assert_eq!(result.content, "a\nb\n");
1142    }
1143
1144    #[test]
1145    fn test_truncate_tail_no_trailing_newline() {
1146        let result = truncate_tail("a\nb", 2000, 50000);
1147        assert!(!result.truncated);
1148        assert_eq!(result.content, "a\nb");
1149    }
1150
1151    #[test]
1152    fn test_truncate_tail_single_line_exceeds_limit() {
1153        let content = "x".repeat(60000);
1154        let result = truncate_tail(&content, 2000, 50000);
1155        assert!(result.truncated);
1156        assert!(result.last_line_partial);
1157        // Should keep the last 50000 bytes of the line
1158        assert_eq!(result.content.len(), 50000);
1159        assert!(result.content.ends_with("x".repeat(50000).as_str()));
1160    }
1161
1162    #[test]
1163    fn test_truncate_tail_byte_count_respects_newlines() {
1164        // Each line is 1000 bytes, 50 lines = 50KB, plus 49 newlines = ~49 bytes extra
1165        // At 2000 line limit, byte limit should be hit first
1166        let content: String = (1..=100)
1167            .map(|i| format!("line {} {}\n", i, "x".repeat(1000)))
1168            .collect();
1169        let result = truncate_tail(&content, 2000, 50000);
1170        assert!(result.truncated);
1171        // Output bytes should be at most 50000 (byte limit)
1172        assert!(
1173            result.output_bytes <= 50000,
1174            "output_bytes {} exceeds limit 50000",
1175            result.output_bytes
1176        );
1177    }
1178
1179    // ── Truncation footer tests ─────────────────────────────────
1180
1181    #[tokio::test]
1182    async fn truncated_by_lines_shows_footer() {
1183        let tool = make_tool();
1184        // Generate 3000 lines of output (exceeds 2000 line limit)
1185        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
1186        let output = tool
1187            .execute(
1188                "id".into(),
1189                serde_json::json!({"command": cmd}),
1190                Cancel::new(),
1191                None,
1192            )
1193            .await
1194            .unwrap();
1195        assert!(
1196            output.content.contains("Showing lines"),
1197            "got: {}",
1198            output.content
1199        );
1200        assert!(
1201            output.content.contains("Full output:"),
1202            "got: {}",
1203            output.content
1204        );
1205    }
1206
1207    #[tokio::test]
1208    async fn small_output_no_footer() {
1209        let tool = make_tool();
1210        let output = tool
1211            .execute(
1212                "id".into(),
1213                serde_json::json!({"command": "echo hello"}),
1214                Cancel::new(),
1215                None,
1216            )
1217            .await
1218            .unwrap();
1219        // Small output should not have footer markers
1220        assert!(
1221            !output.content.contains("Output truncated"),
1222            "got: {}",
1223            output.content
1224        );
1225        assert!(
1226            !output.content.contains("Full output:"),
1227            "got: {}",
1228            output.content
1229        );
1230    }
1231
1232    #[tokio::test]
1233    async fn truncated_saves_temp_file() {
1234        let tool = make_tool();
1235        // Generate enough output to exceed line limit
1236        let cmd = "for i in $(seq 1 3000); do echo \"line $i\"; done";
1237        let output = tool
1238            .execute(
1239                "id".into(),
1240                serde_json::json!({"command": cmd}),
1241                Cancel::new(),
1242                None,
1243            )
1244            .await
1245            .unwrap();
1246        // Should mention a temp file path
1247        assert!(
1248            output.content.contains("/rab-bash/"),
1249            "expected temp file path, got: {}",
1250            output.content
1251        );
1252    }
1253
1254    // ── Truncate tail: many short lines ──────────────────────────
1255
1256    #[test]
1257    fn test_truncate_tail_many_short_lines() {
1258        // 10000 very short lines, well under byte limit
1259        let content: String = (1..=10000).map(|i| format!("{}\n", i)).collect();
1260        let result = truncate_tail(&content, 2000, 50000);
1261        assert!(result.truncated);
1262        assert_eq!(result.truncated_by, "lines");
1263        assert_eq!(result.output_lines, 2000);
1264        // Should keep the last 2000 lines
1265        assert!(
1266            result.content.starts_with("8001"),
1267            "starts with: {:?}",
1268            &result.content[..10]
1269        );
1270    }
1271
1272    #[test]
1273    fn test_truncate_tail_lines_and_bytes_both_exceeded() {
1274        // Both limits exceeded - byte limit should win (more restrictive)
1275        let content: String = (1..=5000)
1276            .map(|i| format!("line {} {}\n", i, "x".repeat(100)))
1277            .collect();
1278        let result = truncate_tail(&content, 2000, 30000);
1279        assert!(result.truncated);
1280        // With 100-byte lines, 300 lines would be ~30KB + newlines
1281        // So byte limit should be hit before line limit
1282        assert_eq!(result.truncated_by, "bytes");
1283        assert!(result.output_lines < 2000);
1284    }
1285}
1286
#[cfg(test)]
mod command_tests {
    //! Tests for `parse_command`: which shell commands are recognised, the
    //! name they are reported under, and the description extracted from their
    //! arguments.

    use super::*;

    /// `ls` with flags and a path: flags are dropped, the path is kept.
    #[test]
    fn test_parse_ls() {
        let (name, desc) = parse_command("ls -la src/").expect("`ls` should be recognised");
        assert_eq!(name, "ls");
        assert_eq!(desc.as_deref(), Some("src/"));
    }

    /// Bare `ls` is described as listing ".".
    #[test]
    fn test_parse_ls_default() {
        let (name, desc) = parse_command("ls").expect("bare `ls` should be recognised");
        assert_eq!(name, "ls");
        assert_eq!(desc.as_deref(), Some("."));
    }

    /// `grep` keeps both the pattern and the search path in its description.
    #[test]
    fn test_parse_grep() {
        let (name, desc) =
            parse_command("grep -r \"pattern\" src/").expect("`grep` should be recognised");
        assert_eq!(name, "grep");
        let desc = desc.expect("`grep` should produce a description");
        assert!(desc.contains("pattern"));
        assert!(desc.contains("src/"));
    }

    /// `rg` is reported under the name "grep".
    #[test]
    fn test_parse_rg() {
        let (name, _) = parse_command("rg pattern src/").expect("`rg` should be recognised");
        assert_eq!(name, "grep");
    }

    /// `find` keeps the start directory and the name pattern.
    #[test]
    fn test_parse_find() {
        let (name, desc) =
            parse_command("find . -name \"*.rs\"").expect("`find` should be recognised");
        assert_eq!(name, "find");
        let desc = desc.expect("`find` should produce a description");
        // NOTE(review): "." also occurs inside "*.rs", so this check alone
        // does not prove the start directory made it into the description.
        assert!(desc.contains("."));
        assert!(desc.contains("*.rs"));
    }

    /// `cat` is described by the file it reads.
    #[test]
    fn test_parse_cat() {
        let (name, desc) = parse_command("cat README.md").expect("`cat` should be recognised");
        assert_eq!(name, "cat");
        assert_eq!(desc.as_deref(), Some("README.md"));
    }

    /// `head` keeps its arguments verbatim, flags included.
    #[test]
    fn test_parse_head() {
        let (name, desc) =
            parse_command("head -20 file.txt").expect("`head` should be recognised");
        assert_eq!(name, "head");
        assert_eq!(desc.as_deref(), Some("-20 file.txt"));
    }

    /// `tail` keeps its arguments verbatim, flags included.
    #[test]
    fn test_parse_tail() {
        let (name, desc) = parse_command("tail -f log.txt").expect("`tail` should be recognised");
        assert_eq!(name, "tail");
        assert_eq!(desc.as_deref(), Some("-f log.txt"));
    }

    /// `wc` keeps its arguments verbatim, flags included.
    #[test]
    fn test_parse_wc() {
        let (name, desc) = parse_command("wc -l file.txt").expect("`wc` should be recognised");
        assert_eq!(name, "wc");
        assert_eq!(desc.as_deref(), Some("-l file.txt"));
    }

    /// Commands outside the recognised set (here `echo`) yield `None`.
    #[test]
    fn test_parse_unknown() {
        assert!(parse_command("echo hello").is_none());
    }

    /// A leading `VAR=value` assignment is skipped before the command name.
    #[test]
    fn test_parse_with_env() {
        let (name, desc) = parse_command("FOO=bar ls src/")
            .expect("env-prefixed `ls` should be recognised");
        assert_eq!(name, "ls");
        assert_eq!(desc.as_deref(), Some("src/"));
    }
}