Skip to main content

ai_agent/utils/
prompt_shell_execution.rs

1// Source: ~/claudecode/openclaudecode/src/utils/promptShellExecution.ts
2//! Skill prompt shell command execution.
3//!
4//! Parses skill markdown content and executes embedded shell commands.
5//! Supports two syntaxes:
6//! - Code blocks: ```! command ```
7//! - Inline: !`command`
8//!
9//! Results are substituted back into the prompt text.
10
11use crate::error::AgentError;
12use futures_util::future::join_all;
13use log::warn;
14use regex::Regex;
15use std::process::Command;
16use tokio::time::timeout;
17
18/// Regex for code block shell commands: ```! command ```
19fn block_pattern() -> &'static Regex {
20    lazy_static::lazy_static! {
21        static ref BLOCK: Regex = Regex::new(r"```\!\s*\n?([\s\S]*?)\n?```").unwrap();
22    }
23    &BLOCK
24}
25
26/// Regex for inline shell commands: !`command`
27/// Requires whitespace or start-of-line before ! to prevent false matches.
28/// Uses (^|\s) capture group instead of lookbehind (Rust regex requires fixed-width).
29fn inline_pattern() -> &'static Regex {
30    lazy_static::lazy_static! {
31        static ref INLINE: Regex = Regex::new(r"(^|\s)!`([^`]+)`").unwrap();
32    }
33    &INLINE
34}
35
/// Shell requested by a skill's frontmatter.
///
/// [`FrontmatterShell::Bash`] is the default variant.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum FrontmatterShell {
    /// Run embedded commands with bash.
    #[default]
    Bash,
    /// Run embedded commands with PowerShell.
    PowerShell,
}

impl FrontmatterShell {
    /// Maps a frontmatter shell name onto a variant.
    ///
    /// The comparison is case-insensitive. Only `"powershell"` selects
    /// [`FrontmatterShell::PowerShell`]; every other value, including the
    /// empty string, yields [`FrontmatterShell::Bash`].
    pub fn from_str(s: &str) -> Self {
        if s.to_lowercase() == "powershell" {
            Self::PowerShell
        } else {
            Self::Bash
        }
    }
}
52
/// Captured output of one successfully executed shell command.
struct ShellOutput {
    /// Everything the command wrote to standard output.
    stdout: String,
    /// Everything the command wrote to standard error.
    stderr: String,
}
58
/// Renders captured command output for substitution into the prompt.
///
/// Both streams are trimmed and empty streams are omitted. In inline mode the
/// result stays on one line: stderr is wrapped as `[stderr: ...]` and joined
/// to stdout with a space. In block mode stderr is introduced by a `[stderr]`
/// line and joined to stdout with a newline.
fn format_shell_output(stdout: &str, stderr: &str, inline: bool) -> String {
    let out = stdout.trim();
    let err = stderr.trim();

    // Nothing on stderr: the (possibly empty) trimmed stdout is the answer.
    if err.is_empty() {
        return out.to_string();
    }

    let (separator, err_section) = if inline {
        (" ", format!("[stderr: {}]", err))
    } else {
        ("\n", format!("[stderr]\n{}", err))
    };

    if out.is_empty() {
        err_section
    } else {
        format!("{}{}{}", out, separator, err_section)
    }
}
81
82/// Execute a single shell command, returning output
83async fn execute_single_command(
84    command: String,
85    shell_bin: String,
86    shell_arg: String,
87    _tool_name: String,
88) -> Result<ShellOutput, String> {
89    let result = timeout(
90        std::time::Duration::from_secs(30),
91        tokio::task::spawn_blocking(move || {
92            let output = Command::new(&shell_bin)
93                .args([&shell_arg, &command])
94                .output()
95                .map_err(|e| format!("Failed to spawn shell: {}", e))?;
96
97            if !output.status.success() {
98                let stderr = String::from_utf8_lossy(&output.stderr).to_string();
99                let stdout = String::from_utf8_lossy(&output.stdout).to_string();
100                return Err(format!(
101                    "Shell command failed (exit {}): {}",
102                    output.status,
103                    if !stderr.is_empty() { stderr } else { stdout }
104                ));
105            }
106
107            Ok(ShellOutput {
108                stdout: String::from_utf8_lossy(&output.stdout).to_string(),
109                stderr: String::from_utf8_lossy(&output.stderr).to_string(),
110            })
111        }),
112    )
113    .await;
114
115    match result {
116        Ok(Ok(Ok(output))) => Ok(output),
117        Ok(Ok(Err(e))) => Err(e),
118        Ok(Err(join_err)) => Err(format!("Shell task failed: {}", join_err)),
119        Err(_) => Err("Shell command timed out (30s)".to_string()),
120    }
121}
122
123/// Resolve the shell binary and argument based on the requested `FrontmatterShell`.
124///
125/// When `FrontmatterShell::PowerShell` is requested, attempts to find `pwsh`.
126/// Falls back to `bash -c` with a warning log if `pwsh` is not available.
127fn resolve_shell_tool(shell: &FrontmatterShell) -> (String, String, String) {
128    match shell {
129        FrontmatterShell::Bash => ("bash".to_string(), "-c".to_string(), "Bash".to_string()),
130        FrontmatterShell::PowerShell => {
131            if which_command("pwsh").is_some() {
132                ("pwsh".to_string(), "-c".to_string(), "PowerShell".to_string())
133            } else {
134                warn!(
135                    "PowerShell shell requested but 'pwsh' is not available, falling back to bash"
136                );
137                ("bash".to_string(), "-c".to_string(), "Bash".to_string())
138            }
139        }
140    }
141}
142
143/// Internal: resolve shell with an optional custom PATH override (for testing)
144#[allow(dead_code)]
145fn resolve_shell_tool_with_path(
146    shell: &FrontmatterShell,
147    path_override: &std::ffi::OsStr,
148) -> (String, String, String) {
149    match shell {
150        FrontmatterShell::Bash => ("bash".to_string(), "-c".to_string(), "Bash".to_string()),
151        FrontmatterShell::PowerShell => {
152            if which_command_in_path("pwsh", path_override).is_some() {
153                ("pwsh".to_string(), "-c".to_string(), "PowerShell".to_string())
154            } else {
155                ("bash".to_string(), "-c".to_string(), "Bash".to_string())
156            }
157        }
158    }
159}
160
/// Looks up `cmd` in the directories listed in the process `PATH`.
///
/// Returns the full path of the first matching regular file, or `None` when
/// `PATH` is unset or no directory contains the command. See
/// [`which_command_in_path`] for the matching rules.
fn which_command(cmd: &str) -> Option<std::path::PathBuf> {
    let path_var = std::env::var_os("PATH")?;
    which_command_in_path(cmd, &path_var)
}

/// Looks up `cmd` in the directories of an explicit `PATH`-style string.
///
/// Directories are searched in order. In each one the bare name is tried
/// first; on platforms whose executables carry a suffix (for example `.exe`
/// on Windows) the suffixed name is tried as well, so that `pwsh` resolves to
/// `pwsh.exe`. Only regular files count as a match.
fn which_command_in_path(cmd: &str, path_env: &std::ffi::OsStr) -> Option<std::path::PathBuf> {
    let suffix = std::env::consts::EXE_SUFFIX;
    // Only build a second candidate when the platform has a suffix and the
    // caller did not already include it.
    let suffixed_name = if suffix.is_empty() || cmd.ends_with(suffix) {
        None
    } else {
        Some(format!("{}{}", cmd, suffix))
    };

    for dir in std::env::split_paths(path_env) {
        let plain = dir.join(cmd);
        if plain.is_file() {
            return Some(plain);
        }
        if let Some(name) = &suffixed_name {
            let with_suffix = dir.join(name);
            if with_suffix.is_file() {
                return Some(with_suffix);
            }
        }
    }
    None
}
184
185/// Parse shell commands from text and execute them, substituting output back.
186///
187/// Scans for both block (` ```! ```) and inline (`!```) patterns.
188/// Commands are executed in parallel. On failure, the error is substituted
189/// back into the text in place of the command.
190///
191/// The optional `can_execute` callback is invoked before each command to check
192/// permissions. If it returns `false`, the command is skipped and `[Permission
193/// denied]` is substituted instead. The callback receives `(command, tool_name)`
194/// where `tool_name` is "Bash" or "PowerShell".
195pub async fn execute_shell_commands_in_prompt<F>(
196    text: &str,
197    shell: &FrontmatterShell,
198    skill_name: &str,
199    can_execute: Option<&F>,
200) -> String
201where
202    F: Fn(&str, &str) -> bool + Send + Sync + ?Sized,
203{
204    // Collect all matches with their positions and types
205    let mut matches: Vec<(usize, usize, String, bool)> = Vec::new();
206
207    for cap in block_pattern().captures_iter(text) {
208        if let Some(full) = cap.get(0) {
209            matches.push((full.start(), full.end(), full.as_str().to_string(), false));
210        }
211    }
212
213    if text.contains("!`") {
214        for cap in inline_pattern().captures_iter(text) {
215            if let (Some(full), Some(prefix)) = (cap.get(0), cap.get(1)) {
216                // Start from the `!` character, not the whitespace prefix
217                let pattern_start = prefix.end();
218                let pattern = text[pattern_start..full.end()].to_string();
219                matches.push((pattern_start, full.end(), pattern, true));
220            }
221        }
222    }
223
224    if matches.is_empty() {
225        return text.to_string();
226    }
227
228    // Build command list
229    let commands: Vec<(String, String, bool)> = matches
230        .iter()
231        .map(|(_, _, pattern, inline)| {
232            let command = if *inline {
233                // Pattern is !`command` - extract between !` and `
234                if let Some(stripped) = pattern.strip_prefix("!`") {
235                    stripped.strip_suffix('`')
236                        .map(|s| s.trim().to_string())
237                        .unwrap_or_default()
238                } else {
239                    String::new()
240                }
241            } else {
242                block_pattern()
243                    .captures(pattern)
244                    .and_then(|c| c.get(1))
245                    .map(|m| m.as_str().trim().to_string())
246                    .unwrap_or_default()
247            };
248            (pattern.clone(), command, *inline)
249        })
250        .collect();
251
252    // Resolve shell binary
253    let (shell_bin, shell_arg, tool_name) = resolve_shell_tool(shell);
254
255    // Execute all commands in parallel
256    let futures: Vec<_> = commands
257        .into_iter()
258        .map(|(pattern, command, inline)| {
259            let shell_bin = shell_bin.to_string();
260            let shell_arg = shell_arg.to_string();
261            let tool_name = tool_name.to_string();
262            let skill_name = skill_name.to_string();
263            async move {
264                if command.is_empty() {
265                    return (pattern.clone(), pattern);
266                }
267
268                // Permission check
269                if let Some(ref cb) = can_execute {
270                    if !cb(&command, &tool_name) {
271                        warn!(
272                            "Shell command permission denied in skill '{}': {}",
273                            skill_name, command
274                        );
275                        return (pattern.clone(), "[Permission denied]".to_string());
276                    }
277                }
278
279                match execute_single_command(command, shell_bin, shell_arg, tool_name).await {
280                    Ok(output) => {
281                        let formatted =
282                            format_shell_output(&output.stdout, &output.stderr, inline);
283                        (pattern.clone(), formatted)
284                    }
285                    Err(e) => {
286                        let error_msg = if inline {
287                            format!("[Error: {}]", e)
288                        } else {
289                            format!("[Error]\n{}", e)
290                        };
291                        (pattern.clone(), error_msg)
292                    }
293                }
294            }
295        })
296        .collect();
297
298    let mut results: Vec<(String, String)> = join_all(futures).await;
299
300    // Build result by replacing matches in reverse order to preserve positions
301    let mut result = text.to_string();
302    for (start, end, pattern, _) in matches.iter().rev() {
303        if let Some(pos) = results.iter().position(|(p, _)| p == pattern) {
304            let (_, replacement) = results.remove(pos);
305            result.replace_range(*start..*end, &replacement);
306        }
307    }
308
309    result
310}
311
312// ============================================================================
313// Legacy helpers (kept for backwards compatibility)
314// ============================================================================
315
/// Runs `command` through `sh -c` and returns its output (legacy API).
///
/// On a successful exit the captured stdout is returned as `Ok`. On an
/// unsuccessful exit the captured stderr is returned as `Err`. A failure to
/// start `sh` is returned as `Err` with the I/O error's message.
///
/// The process is awaited synchronously inside this `async fn`, so the
/// calling task is blocked until the command finishes.
pub async fn execute_prompt_shell(command: &str) -> Result<String, String> {
    let finished = Command::new("sh")
        .args(["-c", command])
        .output()
        .map_err(|spawn_err| spawn_err.to_string())?;

    if !finished.status.success() {
        return Err(String::from_utf8_lossy(&finished.stderr).into_owned());
    }
    Ok(String::from_utf8_lossy(&finished.stdout).into_owned())
}
329
330/// Build a shell command with proper escaping (legacy API)
331pub fn build_shell_command(program: &str, args: &[&str]) -> String {
332    let mut cmd = program.to_string();
333    for arg in args {
334        cmd.push(' ');
335        cmd.push_str(&shell_escape(arg));
336    }
337    cmd
338}
339
/// Quotes `s` so a POSIX shell reads it as exactly one word.
///
/// Non-empty strings made only of alphanumerics, `-`, `_` and `.` are returned
/// unchanged. Everything else is wrapped in single quotes, with each embedded
/// single quote rewritten as `'\''`. The empty string becomes `''`; returning
/// it unquoted would make the argument disappear from the command line.
fn shell_escape(s: &str) -> String {
    let is_plain = !s.is_empty()
        && s.chars()
            .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.');

    if is_plain {
        s.to_string()
    } else {
        format!("'{}'", s.replace('\'', "'\\''"))
    }
}
347
348/// Check if a shell command in a skill should be allowed.
349///
350/// This is a synchronous pre-check. The actual permission gating
351/// happens in `execute_shell_commands_in_prompt` via the `can_execute` callback.
352/// This function validates that the tool name is recognized.
353pub fn can_execute_skill_shell(_command: &str, tool_name: &str) -> Result<(), AgentError> {
354    match tool_name {
355        "Bash" | "bash" | "PowerShell" | "powershell" => Ok(()),
356        _ => Err(AgentError::Tool(format!("Unsupported shell tool: {}", tool_name))),
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Callback type named when a test passes no permission gate.
    type Gate = dyn Fn(&str, &str) -> bool + Send + Sync;

    /// Runs `text` through the bash executor without a permission callback.
    async fn run_ungated(text: &str) -> String {
        execute_shell_commands_in_prompt(
            text,
            &FrontmatterShell::Bash,
            "test-skill",
            None::<&Gate>,
        )
        .await
    }

    /// Runs `text` through the bash executor with `gate` as the permission callback.
    async fn run_gated<F>(text: &str, gate: &F) -> String
    where
        F: Fn(&str, &str) -> bool + Send + Sync,
    {
        execute_shell_commands_in_prompt(text, &FrontmatterShell::Bash, "test-skill", Some(gate))
            .await
    }

    #[test]
    fn test_block_pattern_matches() {
        let text = "```!\necho hello\n```";
        assert!(block_pattern().is_match(text));
        let cap = block_pattern().captures(text).unwrap();
        assert!(cap.get(1).is_some());
    }

    #[test]
    fn test_block_pattern_multiline() {
        let text = "```!\necho hello\necho world\n```";
        let cap = block_pattern().captures(text).unwrap();
        let cmd = cap.get(1).unwrap().as_str().trim();
        assert_eq!(cmd, "echo hello\necho world");
    }

    #[test]
    fn test_inline_pattern_matches() {
        assert!(inline_pattern().is_match("Run !`ls` to see files"));
    }

    #[test]
    fn test_inline_pattern_no_match_without_whitespace() {
        assert!(!inline_pattern().is_match("x!`this`"));
    }

    #[test]
    fn test_inline_pattern_extract_command() {
        let cap = inline_pattern().captures("Run !`echo hi` now").unwrap();
        assert_eq!(cap.get(2).unwrap().as_str(), "echo hi");
    }

    #[test]
    fn test_format_shell_output_stdout_only() {
        assert_eq!(format_shell_output("hello world", "", false), "hello world");
    }

    #[test]
    fn test_format_shell_output_with_stderr_block() {
        assert_eq!(
            format_shell_output("stdout", "stderr msg", false),
            "stdout\n[stderr]\nstderr msg"
        );
    }

    #[test]
    fn test_format_shell_output_with_stderr_inline() {
        assert_eq!(
            format_shell_output("stdout", "stderr msg", true),
            "stdout [stderr: stderr msg]"
        );
    }

    #[test]
    fn test_format_shell_output_empty() {
        assert_eq!(format_shell_output("", "", false), "");
    }

    #[tokio::test]
    async fn test_execute_block_command() {
        let result = run_ungated("Before ```!\necho hello\n``` After").await;
        assert!(result.contains("hello"));
        assert!(result.contains("Before"));
        assert!(result.contains("After"));
        assert!(!result.contains("```!"));
    }

    #[tokio::test]
    async fn test_execute_inline_command() {
        let result = run_ungated("Count: !`echo 42` items").await;
        assert!(result.contains("42"));
        assert!(!result.contains("!`echo 42`"));
    }

    #[tokio::test]
    async fn test_no_shell_commands() {
        let text = "This is plain text with no commands";
        let result = run_ungated(text).await;
        assert_eq!(result, text);
    }

    #[tokio::test]
    async fn test_failed_command_substitutes_error() {
        let result = run_ungated("```!\nexit 1\n```").await;
        assert!(result.contains("[Error]"));
        assert!(!result.contains("```!"));
    }

    #[tokio::test]
    async fn test_multiple_commands() {
        let result = run_ungated("A ```!\necho one\n``` B !`echo two` C").await;
        assert!(result.contains("one"));
        assert!(result.contains("two"));
        assert!(result.contains("A"));
        assert!(result.contains("B"));
        assert!(result.contains("C"));
    }

    #[tokio::test]
    async fn test_command_with_stderr() {
        let result = run_ungated("```!\necho out && echo err >&2\n```").await;
        assert!(result.contains("out"));
        // Both the marker and the stderr text must be present; either one
        // alone would not show that stderr was captured and labelled.
        assert!(result.contains("[stderr]"));
        assert!(result.contains("err"));
    }

    #[test]
    fn test_frontmatter_shell_from_str() {
        assert_eq!(FrontmatterShell::from_str("bash"), FrontmatterShell::Bash);
        assert_eq!(
            FrontmatterShell::from_str("powershell"),
            FrontmatterShell::PowerShell
        );
        assert_eq!(FrontmatterShell::from_str("unknown"), FrontmatterShell::Bash);
        assert_eq!(FrontmatterShell::from_str(""), FrontmatterShell::Bash);
    }

    #[test]
    fn test_shell_escape_safe() {
        assert_eq!(shell_escape("hello"), "hello");
    }

    #[test]
    fn test_shell_escape_needs_quotes() {
        // Each embedded single quote is rewritten as '\'' and the whole
        // argument is wrapped in single quotes.
        assert_eq!(shell_escape("he'llo"), "'he'\\''llo'");
    }

    #[test]
    fn test_build_shell_command() {
        assert_eq!(build_shell_command("echo", &["hello", "world"]), "echo hello world");
    }

    #[tokio::test]
    async fn test_execute_prompt_shell() {
        // `printf` is used instead of `echo -n`: some `sh` implementations
        // print the `-n` literally.
        let result = execute_prompt_shell("printf test").await;
        assert_eq!(result.unwrap(), "test");
    }

    #[test]
    fn test_can_execute_skill_shell() {
        assert!(can_execute_skill_shell("echo hello", "Bash").is_ok());
    }

    #[test]
    fn test_can_execute_skill_shell_unsupported_tool() {
        assert!(can_execute_skill_shell("echo hello", "Fish").is_err());
    }

    #[test]
    fn test_can_execute_skill_shell_powershell() {
        assert!(can_execute_skill_shell("Write-Host hello", "PowerShell").is_ok());
    }

    /// Test that denied commands show [Permission denied]
    #[tokio::test]
    async fn test_permission_denied_substitutes_message() {
        // Callback that denies all commands
        let deny_all = |_cmd: &str, _tool: &str| false;
        let result = run_gated("Before ```!\necho hello\n``` After", &deny_all).await;
        assert!(result.contains("[Permission denied]"));
        assert!(result.contains("Before"));
        assert!(result.contains("After"));
        assert!(!result.contains("hello"));
    }

    /// Test that denied inline commands show [Permission denied]
    #[tokio::test]
    async fn test_permission_denied_inline_substitutes_message() {
        let deny_all = |_cmd: &str, _tool: &str| false;
        let result = run_gated("Count: !`echo 42` items", &deny_all).await;
        assert!(result.contains("[Permission denied]"));
        assert!(!result.contains("42"));
        assert!(!result.contains("!`echo 42`"));
    }

    /// Test that allowed commands run normally with can_execute callback
    #[tokio::test]
    async fn test_permission_allowed_executes() {
        let allow_all = |_cmd: &str, _tool: &str| true;
        let result = run_gated("Before ```!\necho hello\n``` After", &allow_all).await;
        assert!(result.contains("hello"));
        assert!(result.contains("Before"));
        assert!(result.contains("After"));
        assert!(!result.contains("[Permission denied]"));
    }

    /// Test that allowed inline commands run normally
    #[tokio::test]
    async fn test_permission_allowed_inline_executes() {
        let allow_all = |_cmd: &str, _tool: &str| true;
        let result = run_gated("Count: !`echo 42` items", &allow_all).await;
        assert!(result.contains("42"));
        assert!(!result.contains("[Permission denied]"));
    }

    /// Test selective allow/deny: allow only "echo" commands
    #[tokio::test]
    async fn test_permission_selective() {
        let selective = |cmd: &str, _tool: &str| cmd.starts_with("echo");
        let result = run_gated("A ```!\necho one\n``` B ```!\nexit 1\n```", &selective).await;
        assert!(result.contains("one"));
        assert!(result.contains("[Permission denied]"));
        // The denied `exit 1` must never run, so no error may be reported.
        assert!(!result.contains("[Error]"));
    }

    /// Test that PowerShell falls back to bash when pwsh is not available.
    /// Uses a synthetic PATH that contains no pwsh binary.
    #[test]
    fn test_powershell_fallback_to_bash() {
        // Use a PATH that definitely has no pwsh
        let fake_path = std::ffi::OsStr::new("/nonexistent/path");
        let (bin, arg, tool) =
            resolve_shell_tool_with_path(&FrontmatterShell::PowerShell, fake_path);
        assert_eq!(bin, "bash");
        assert_eq!(arg, "-c");
        assert_eq!(tool, "Bash");
    }

    /// Test that PowerShell resolves to pwsh when available
    #[test]
    fn test_powershell_resolves_when_pwsh_available() {
        // If PATH is not set, there's nothing to assert
        if let Some(path) = std::env::var_os("PATH") {
            let (bin, _arg, tool) =
                resolve_shell_tool_with_path(&FrontmatterShell::PowerShell, path.as_ref());
            if which_command_in_path("pwsh", path.as_ref()).is_some() {
                // pwsh is available
                assert_eq!(bin, "pwsh");
                assert_eq!(tool, "PowerShell");
            } else {
                // pwsh is not available - falls back to bash
                assert_eq!(bin, "bash");
                assert_eq!(tool, "Bash");
            }
        }
    }

    /// Test resolve_shell_tool for bash
    #[test]
    fn test_resolve_shell_bash() {
        let (bin, arg, tool) = resolve_shell_tool(&FrontmatterShell::Bash);
        assert_eq!(bin, "bash");
        assert_eq!(arg, "-c");
        assert_eq!(tool, "Bash");
    }
}