Skip to main content

koda_core/tools/
shell.rs

//! Shell command execution tool (Bash).
//!
//! Runs commands as child processes with timeout protection.
//! Output line cap is set by `OutputCaps` (context-scaled).
//!
//! ## Parameters
//!
//! - **`command`** (required) — The shell command to execute
//! - **`timeout`** (optional, default 60, capped at 300) — Timeout in seconds;
//!   ignored when `background` is true
//! - **`background`** (optional, default false) — Run in background, return PID
//!
//! ## Background mode
//!
//! When `background: true` the command is spawned detached and control returns
//! immediately with the PID. Use for dev servers, file watchers, and other
//! long-running processes. Background processes are tracked in `BgRegistry`.
//!
//! ## Safety
//!
//! - Commands are classified by `bash_safety::classify_bash_command`
//! - Destructive commands (`rm -rf`, `git push --force`) always need confirmation
//! - Path escapes outside the project root are flagged by `bash_path_lint`
//! - Output is capped to prevent context overflow (verbose output is truncated)
//!
//! ## Best practices (sent to the model)
//!
//! - Use Bash only for builds, tests, git, and commands without a dedicated tool
//! - Never use Bash for file ops — use Read/Write/Edit/Grep/List instead
//! - Suppress verbose output: pipe to `tail`, use `--quiet`, avoid `-v` flags
30
31use crate::engine::{EngineEvent, EngineSink};
32use crate::providers::ToolDefinition;
33use crate::tools::bg_process::BgRegistry;
34use anyhow::Result;
35use serde_json::{Value, json};
36use std::path::Path;
37use tokio::io::{AsyncBufReadExt, BufReader};
38
/// Timeout, in seconds, used when the caller does not supply `timeout`.
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Upper bound applied to any requested timeout, so an LLM-supplied value
/// cannot tie up the tool for an arbitrarily long time.
const MAX_TIMEOUT_SECS: u64 = 300;
/// How many trailing stderr lines the summary shows (stderr is high-signal).
const SUMMARY_STDERR_LINES: usize = 50;
/// How many trailing stdout lines the summary shows.
const SUMMARY_STDOUT_TAIL: usize = 20;
/// Byte budget for in-memory line collection. Pathological commands (`yes`,
/// `cat /dev/urandom | base64`) can emit gigabytes before the timeout fires.
/// After this many bytes have been collected, lines keep streaming to the TUI
/// but are no longer stored in the in-memory Vec. What gets persisted is
/// bounded separately by `MAX_FULL_OUTPUT_BYTES`.
const MAX_COLLECT_BYTES: usize = 10 * 1024 * 1024; // 10 MB
52
/// Outcome of one shell invocation: a short text for the model plus the
/// complete text kept for later retrieval.
#[derive(Clone, Debug)]
pub struct ShellOutput {
    /// Condensed report that goes into the model's context window.
    pub summary: String,
    /// Complete output intended for DB storage / RecallContext lookup.
    /// Background commands set this to `None` because nothing is captured.
    pub full_output: Option<String>,
}
62
63/// Return tool definitions for the LLM.
64pub fn definitions() -> Vec<ToolDefinition> {
65    vec![ToolDefinition {
66        name: "Bash".to_string(),
67        description: "Execute a shell command. Use ONLY for builds, tests, git, \
68            and commands without a dedicated tool. Never use for file ops \
69            (use Read/Write/Edit/Grep/List instead). Suppress verbose output: \
70            pipe to tail, use --quiet, avoid -v flags. \
71            Set background=true for long-running processes (dev servers, watchers) \
72            — returns immediately with the PID."
73            .to_string(),
74        parameters: json!({
75            "type": "object",
76            "properties": {
77                "command": {
78                    "type": "string",
79                    "description": "The shell command to execute"
80                },
81                "timeout": {
82                    "type": "integer",
83                    "description": "Timeout in seconds (default: 60, ignored when background=true)"
84                },
85                "background": {
86                    "type": "boolean",
87                    "description": "Run in background and return immediately with PID (default: false). \
88                        Use for dev servers, file watchers, and other long-running processes."
89                }
90            },
91            "required": ["command"]
92        }),
93    }]
94}
95
96/// Execute a shell command with timeout, output capping, and optional streaming.
97///
98/// When `sink` is provided, each line of stdout/stderr is emitted as a
99/// `ToolOutputLine` event as it arrives — giving the TUI a live terminal feel.
100/// The full output is still collected and returned as the tool result.
101///
102/// When `args["background"]` is `true`, the process is spawned detached and
103/// this function returns immediately with the PID.  The `BgRegistry` tracks
104/// the child so it is cleaned up (SIGTERM) when the session ends.
105pub async fn run_shell_command(
106    project_root: &Path,
107    args: &Value,
108    max_lines: usize,
109    bg: &BgRegistry,
110    sink: Option<(&dyn EngineSink, &str)>,
111    trust: &crate::trust::TrustMode,
112) -> Result<ShellOutput> {
113    let command = args["command"]
114        .as_str()
115        .ok_or_else(|| anyhow::anyhow!("Missing 'command' argument"))?;
116    let background = args["background"].as_bool().unwrap_or(false);
117
118    tracing::info!(
119        "Running shell command (background={background}): [{} chars]",
120        command.len()
121    );
122
123    if background {
124        let msg = spawn_background(project_root, command, bg, trust)?;
125        return Ok(ShellOutput {
126            summary: msg,
127            full_output: None,
128        });
129    }
130
131    let timeout_secs = args["timeout"]
132        .as_u64()
133        .unwrap_or(DEFAULT_TIMEOUT_SECS)
134        .min(MAX_TIMEOUT_SECS);
135
136    // Spawn via sandbox wrapper (enforced for all trust modes).
137    let mut child = crate::sandbox::build(command, project_root, trust)?
138        .stdout(std::process::Stdio::piped())
139        .stderr(std::process::Stdio::piped())
140        .spawn()
141        .map_err(|e| anyhow::anyhow!("Failed to execute command: {e}"))?;
142
143    let stdout = child.stdout.take().unwrap();
144    let stderr = child.stderr.take().unwrap();
145
146    let mut stdout_lines: Vec<String> = Vec::new();
147    let mut stderr_lines: Vec<String> = Vec::new();
148
149    // Read stdout and stderr concurrently, streaming lines as they arrive.
150    // Lines are always streamed to the TUI, but collection into Vec stops
151    // once max_lines or MAX_COLLECT_BYTES is reached (OOM protection).
152    let sink_info = sink.map(|(s, id)| (s, id.to_string()));
153    let result = tokio::time::timeout(
154        std::time::Duration::from_secs(timeout_secs),
155        read_streams(
156            stdout,
157            stderr,
158            &mut stdout_lines,
159            &mut stderr_lines,
160            max_lines,
161            &sink_info,
162        ),
163    )
164    .await;
165
166    match result {
167        Ok(Ok(())) => {
168            // Wait for exit status after streams are drained.
169            let status = child
170                .wait()
171                .await
172                .map_err(|e| anyhow::anyhow!("wait: {e}"))?;
173            let exit_code = status.code().unwrap_or(-1);
174
175            let summary = format_summary(exit_code, &stdout_lines, &stderr_lines);
176            let full = format_full_output(exit_code, &stdout_lines, &stderr_lines);
177
178            Ok(ShellOutput {
179                summary,
180                full_output: Some(full),
181            })
182        }
183        Ok(Err(e)) => Err(anyhow::anyhow!("Stream read error: {e}")),
184        Err(_) => {
185            // Timeout — kill the child.
186            let _ = child.kill().await;
187            let msg = format!("Command timed out after {timeout_secs}s: {command}");
188            Ok(ShellOutput {
189                summary: msg.clone(),
190                full_output: Some(msg),
191            })
192        }
193    }
194}
195
196/// Read stdout and stderr concurrently, collecting lines and optionally streaming them.
197///
198/// Lines are always streamed to the TUI sink (if present), but collection into
199/// the Vecs is gated by two caps:
200///   - `max_lines` — total stdout + stderr lines collected
201///   - `MAX_COLLECT_BYTES` — total bytes collected (OOM protection)
202///
203/// Once either cap is hit, new lines are still streamed to the TUI but silently
204/// dropped from the Vecs. This keeps the TUI responsive while bounding memory
205/// for pathological commands.
206async fn read_streams(
207    stdout: tokio::process::ChildStdout,
208    stderr: tokio::process::ChildStderr,
209    stdout_lines: &mut Vec<String>,
210    stderr_lines: &mut Vec<String>,
211    max_lines: usize,
212    sink_info: &Option<(&dyn EngineSink, String)>,
213) -> std::io::Result<()> {
214    let mut stdout_reader = BufReader::new(stdout).lines();
215    let mut stderr_reader = BufReader::new(stderr).lines();
216
217    let mut stdout_done = false;
218    let mut stderr_done = false;
219    let mut collected_bytes: usize = 0;
220    let mut collected_lines: usize = 0;
221
222    while !stdout_done || !stderr_done {
223        tokio::select! {
224            line = stdout_reader.next_line(), if !stdout_done => {
225                match line? {
226                    Some(l) => {
227                        if let Some((sink, id)) = sink_info {
228                            sink.emit(EngineEvent::ToolOutputLine {
229                                id: id.clone(),
230                                line: l.clone(),
231                                is_stderr: false,
232                            });
233                        }
234                        if collected_lines < max_lines
235                            && collected_bytes < MAX_COLLECT_BYTES
236                        {
237                            collected_bytes += l.len();
238                            collected_lines += 1;
239                            stdout_lines.push(l);
240                        }
241                    }
242                    None => stdout_done = true,
243                }
244            }
245            line = stderr_reader.next_line(), if !stderr_done => {
246                match line? {
247                    Some(l) => {
248                        if let Some((sink, id)) = sink_info {
249                            sink.emit(EngineEvent::ToolOutputLine {
250                                id: id.clone(),
251                                line: l.clone(),
252                                is_stderr: true,
253                            });
254                        }
255                        if collected_lines < max_lines
256                            && collected_bytes < MAX_COLLECT_BYTES
257                        {
258                            collected_bytes += l.len();
259                            collected_lines += 1;
260                            stderr_lines.push(l);
261                        }
262                    }
263                    None => stderr_done = true,
264                }
265            }
266        }
267    }
268    Ok(())
269}
270
271/// Spawn a command in the background and register it.
272///
273/// Returns immediately with PID + instructions. Sync because `spawn()` doesn't
274/// need to await — only `output()` / `wait()` block.
275fn spawn_background(
276    project_root: &Path,
277    command: &str,
278    bg: &BgRegistry,
279    trust: &crate::trust::TrustMode,
280) -> Result<String> {
281    // Spawn via sandbox wrapper (enforced for all trust modes).
282    // Detach stdio so the process doesn't block on terminal I/O.
283    let child = crate::sandbox::build(command, project_root, trust)?
284        .stdin(std::process::Stdio::null())
285        .stdout(std::process::Stdio::null())
286        .stderr(std::process::Stdio::null())
287        .spawn()
288        .map_err(|e| anyhow::anyhow!("Failed to spawn background command: {e}"))?;
289
290    let pid = child
291        .id()
292        .ok_or_else(|| anyhow::anyhow!("Spawned process has no PID (already exited)"))?;
293
294    bg.insert(pid, command.to_string(), child);
295
296    Ok(format!(
297        "Background process started.\n  PID:     {pid}\n  Command: {command}\n\
298         To stop:  Bash{{command: \"kill {pid}\"}}\n\
299         To force: Bash{{command: \"kill -9 {pid}\"}}\n\
300         Note: process will be stopped automatically when the session ends."
301    ))
302}
303
304/// Build a compact summary for the model's context window.
305///
306/// Includes all stderr (high-signal — errors/warnings) and only the tail
307/// of stdout (low-signal — build progress noise).  Line counts let the
308/// model decide whether to retrieve the full output via RecallContext.
309fn format_summary(exit_code: i32, stdout_lines: &[String], stderr_lines: &[String]) -> String {
310    let mut out = format!(
311        "Exit code: {exit_code} | stdout: {} lines | stderr: {} lines",
312        stdout_lines.len(),
313        stderr_lines.len(),
314    );
315
316    // Stderr first — always include (capped at SUMMARY_STDERR_LINES).
317    if !stderr_lines.is_empty() {
318        let (label, text) = if stderr_lines.len() > SUMMARY_STDERR_LINES {
319            let skipped = stderr_lines.len() - SUMMARY_STDERR_LINES;
320            (
321                format!(
322                    "\n\n--- stderr (last {} of {}, {skipped} skipped) ---",
323                    SUMMARY_STDERR_LINES,
324                    stderr_lines.len(),
325                ),
326                stderr_lines[stderr_lines.len() - SUMMARY_STDERR_LINES..].join("\n"),
327            )
328        } else {
329            (
330                format!("\n\n--- stderr ({} lines) ---", stderr_lines.len()),
331                stderr_lines.join("\n"),
332            )
333        };
334        out.push_str(&label);
335        out.push('\n');
336        out.push_str(&text);
337    }
338
339    // Stdout tail — only last N lines.
340    if !stdout_lines.is_empty() {
341        let (label, text) = if stdout_lines.len() > SUMMARY_STDOUT_TAIL {
342            (
343                format!(
344                    "\n\n--- stdout (last {} of {}) ---",
345                    SUMMARY_STDOUT_TAIL,
346                    stdout_lines.len(),
347                ),
348                stdout_lines[stdout_lines.len() - SUMMARY_STDOUT_TAIL..].join("\n"),
349            )
350        } else {
351            (
352                format!("\n\n--- stdout ({} lines) ---", stdout_lines.len()),
353                stdout_lines.join("\n"),
354            )
355        };
356        out.push_str(&label);
357        out.push('\n');
358        out.push_str(&text);
359    }
360
361    // Hint for the model.
362    if stdout_lines.len() > SUMMARY_STDOUT_TAIL || stderr_lines.len() > SUMMARY_STDERR_LINES {
363        out.push_str("\n\nFull output stored. Use RecallContext to search if needed.");
364    }
365
366    out
367}
368
/// Build the full output for DB storage.
///
/// Stored in `messages.full_content` and searchable via RecallContext.
/// The text is the exit code followed by a stdout section and a stderr
/// section (each omitted when empty). It is capped at 2 MB — generous enough
/// for RecallContext to find errors deep in build/test output, while still
/// preventing pathological commands from bloating the SQLite DB. When the cap
/// applies, the text is cut at the nearest character boundary at or below
/// 2 MB and a truncation marker is appended.
fn format_full_output(exit_code: i32, stdout_lines: &[String], stderr_lines: &[String]) -> String {
    const MAX_FULL_OUTPUT_BYTES: usize = 2 * 1024 * 1024; // 2 MB

    let mut out = format!("Exit code: {exit_code}\n");
    if !stdout_lines.is_empty() {
        out.push_str("\n--- stdout ---\n");
        out.push_str(&stdout_lines.join("\n"));
    }
    if !stderr_lines.is_empty() {
        out.push_str("\n\n--- stderr ---\n");
        out.push_str(&stderr_lines.join("\n"));
    }

    // Hard cap to prevent DB bloat from pathological commands.
    if out.len() > MAX_FULL_OUTPUT_BYTES {
        // `String::truncate` panics if the cut lands inside a multi-byte
        // character, so back up to a boundary *before* truncating.
        // Index 0 is always a boundary, so the loop terminates.
        let mut cut = MAX_FULL_OUTPUT_BYTES;
        while !out.is_char_boundary(cut) {
            cut -= 1;
        }
        out.truncate(cut);
        out.push_str("\n\n[... output truncated at 2MB ...]");
    }

    out
}
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::bg_process::BgRegistry;

    /// Fresh, empty background registry for tests that do not inspect it.
    fn bg() -> BgRegistry {
        BgRegistry::new()
    }

    /// Run a foreground command in a fresh temp dir under `TrustMode::Safe`,
    /// with no streaming sink, and unwrap the result.
    async fn run_fg(args: serde_json::Value, max_lines: usize) -> ShellOutput {
        let tmp = tempfile::tempdir().unwrap();
        run_shell_command(
            tmp.path(),
            &args,
            max_lines,
            &bg(),
            None,
            &crate::trust::TrustMode::Safe,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn shell_timeout_returns_timeout_message() {
        let result = run_fg(serde_json::json!({"command": "sleep 5", "timeout": 1}), 256).await;
        assert!(
            result.summary.contains("timed out"),
            "Expected timeout message, got: {}",
            result.summary
        );
    }

    #[tokio::test]
    async fn shell_respects_custom_timeout_parameter() {
        let result = run_fg(
            serde_json::json!({"command": "echo hello", "timeout": 5}),
            256,
        )
        .await;
        assert!(
            result.summary.contains("hello"),
            "Fast command should succeed: {}",
            result.summary
        );
    }

    #[tokio::test]
    async fn shell_default_timeout_is_applied_when_not_specified() {
        let result = run_fg(serde_json::json!({"command": "echo world"}), 256).await;
        assert!(
            result.summary.contains("world"),
            "Command without explicit timeout should work: {}",
            result.summary
        );
    }

    #[tokio::test]
    async fn background_spawn_returns_pid() {
        // Called directly (not via `run_fg`) so the registry can be inspected
        // and the temp dir outlives the spawned process's registration.
        let tmp = tempfile::tempdir().unwrap();
        let registry = BgRegistry::new();
        let args = serde_json::json!({"command": "sleep 60", "background": true});
        let result = run_shell_command(
            tmp.path(),
            &args,
            256,
            &registry,
            None,
            &crate::trust::TrustMode::Safe,
        )
        .await
        .unwrap();
        assert!(
            result.summary.contains("Background process started"),
            "{}",
            result.summary
        );
        assert!(result.summary.contains("PID:"), "{}", result.summary);
        assert!(result.summary.contains("kill"), "{}", result.summary);
        assert!(
            result.full_output.is_none(),
            "background has no full_output"
        );
        assert_eq!(registry.len(), 1);
    }

    #[tokio::test]
    async fn background_false_runs_synchronously() {
        let result = run_fg(
            serde_json::json!({"command": "echo sync", "background": false}),
            256,
        )
        .await;
        assert!(result.summary.contains("sync"), "{}", result.summary);
        assert!(
            !result.summary.contains("PID:"),
            "foreground should not have PID line: {}",
            result.summary
        );
    }

    #[test]
    fn test_format_summary_short_output() {
        let stdout: Vec<String> = vec!["hello", "world"]
            .into_iter()
            .map(String::from)
            .collect();
        let stderr: Vec<String> = vec![];
        let summary = format_summary(0, &stdout, &stderr);
        assert!(summary.contains("Exit code: 0"));
        assert!(summary.contains("stdout: 2 lines"));
        assert!(summary.contains("hello"));
        assert!(summary.contains("world"));
        // Short output should NOT have the RecallContext hint
        assert!(!summary.contains("RecallContext"));
    }

    #[test]
    fn test_format_summary_long_stdout_truncated() {
        let stdout: Vec<String> = (0..100).map(|i| format!("line {i}")).collect();
        let stderr: Vec<String> = vec!["warning: something".into()];
        let summary = format_summary(0, &stdout, &stderr);
        // Should contain last 20 lines
        assert!(summary.contains("line 99"));
        assert!(summary.contains("line 80"));
        // Should NOT contain early lines
        assert!(!summary.contains("line 0\n"));
        // Should show truncation metadata
        assert!(summary.contains("last 20 of 100"));
        // Stderr should be fully included
        assert!(summary.contains("warning: something"));
        // Should have RecallContext hint
        assert!(summary.contains("RecallContext"));
    }

    #[test]
    fn test_format_full_output_includes_everything() {
        let stdout: Vec<String> = (0..100).map(|i| format!("line {i}")).collect();
        let stderr: Vec<String> = vec!["err1".into(), "err2".into()];
        let full = format_full_output(1, &stdout, &stderr);
        assert!(full.contains("Exit code: 1"));
        assert!(full.contains("line 0"));
        assert!(full.contains("line 99"));
        assert!(full.contains("err1"));
        assert!(full.contains("err2"));
    }

    #[test]
    fn test_format_full_output_capped_at_2mb() {
        // Each line is ~16 bytes; 200K lines ≈ 3.2 MB → should truncate.
        let stdout: Vec<String> = (0..200_000).map(|i| format!("line {i}: padding")).collect();
        let full = format_full_output(0, &stdout, &[]);
        assert!(full.len() <= 2 * 1024 * 1024 + 50); // 2MB + truncation message
        assert!(full.contains("truncated at 2MB"));
    }

    #[test]
    fn test_shell_output_has_full_output() {
        // Verify ShellOutput struct works correctly
        let so = ShellOutput {
            summary: "Exit code: 0".into(),
            full_output: Some("full output here".into()),
        };
        assert_eq!(so.summary, "Exit code: 0");
        assert_eq!(so.full_output.unwrap(), "full output here");
    }

    #[tokio::test]
    async fn collection_stops_at_max_lines() {
        // Generate 50 lines of output but cap collection at 10.
        let result = run_fg(serde_json::json!({"command": "seq 1 50"}), 10).await;
        // Summary should reflect that we only collected 10 lines.
        assert!(
            result.summary.contains("stdout: 10 lines"),
            "Expected 10 collected lines, got: {}",
            result.summary
        );
        // Compare whole lines: a substring check for "1" would also match
        // "10" or the header, and the last line has no trailing newline.
        let full = result.full_output.unwrap();
        let collected: Vec<&str> = full.lines().collect();
        assert!(collected.contains(&"1"), "Should contain first line");
        assert!(
            collected.contains(&"10"),
            "Should contain the last line under the cap"
        );
        assert!(
            !collected.contains(&"11"),
            "Should NOT contain the first line past the cap"
        );
        assert!(!collected.contains(&"50"), "Should NOT contain line 50");
    }

    #[test]
    fn test_timeout_capped_at_max() {
        // Mirrors the clamping expression used by `run_shell_command`.
        let args = serde_json::json!({"command": "echo hi", "timeout": 99999});
        let t = args["timeout"]
            .as_u64()
            .unwrap_or(DEFAULT_TIMEOUT_SECS)
            .min(MAX_TIMEOUT_SECS);
        assert_eq!(t, MAX_TIMEOUT_SECS);
    }

    #[tokio::test]
    async fn streaming_emits_lines_to_sink() {
        use std::sync::{Arc, Mutex};

        /// Collects ToolOutputLine events for testing.
        #[derive(Debug, Default)]
        struct CaptureSink {
            lines: Mutex<Vec<(String, bool)>>,
        }
        impl crate::engine::EngineSink for CaptureSink {
            fn emit(&self, event: EngineEvent) {
                if let EngineEvent::ToolOutputLine {
                    line, is_stderr, ..
                } = event
                {
                    self.lines.lock().unwrap().push((line, is_stderr));
                }
            }
        }

        let tmp = tempfile::tempdir().unwrap();
        let sink = Arc::new(CaptureSink::default());
        let args = serde_json::json!({
            "command": "echo alpha && echo bravo && echo charlie >&2"
        });
        let result = run_shell_command(
            tmp.path(),
            &args,
            256,
            &bg(),
            Some((sink.as_ref(), "test_id")),
            &crate::trust::TrustMode::Safe,
        )
        .await
        .unwrap();

        // Summary should contain the output
        assert!(result.summary.contains("alpha"));
        assert!(result.summary.contains("bravo"));
        assert!(result.summary.contains("charlie"));

        // Full output should contain everything
        let full = result.full_output.unwrap();
        assert!(full.contains("alpha"));
        assert!(full.contains("bravo"));
        assert!(full.contains("charlie"));

        // Streaming lines should have been emitted
        let lines = sink.lines.lock().unwrap();
        assert!(
            lines.len() >= 3,
            "Expected at least 3 streamed lines, got {}: {lines:?}",
            lines.len()
        );
        // At least one stdout and one stderr line
        assert!(lines.iter().any(|(_, is_stderr)| !is_stderr));
        assert!(lines.iter().any(|(_, is_stderr)| *is_stderr));
    }
}