1use crate::error::AgentError;
12use futures_util::future::join_all;
13use log::warn;
14use regex::Regex;
15use std::process::Command;
16use tokio::time::timeout;
17
18fn block_pattern() -> &'static Regex {
20 lazy_static::lazy_static! {
21 static ref BLOCK: Regex = Regex::new(r"```\!\s*\n?([\s\S]*?)\n?```").unwrap();
22 }
23 &BLOCK
24}
25
26fn inline_pattern() -> &'static Regex {
30 lazy_static::lazy_static! {
31 static ref INLINE: Regex = Regex::new(r"(^|\s)!`([^`]+)`").unwrap();
32 }
33 &INLINE
34}
35
/// Shell used to run the `!` commands embedded in a skill prompt.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum FrontmatterShell {
    /// Run commands with `bash -c`. This is the default.
    #[default]
    Bash,
    /// Run commands with `pwsh -c` when `pwsh` can be found; otherwise the
    /// resolver falls back to bash.
    PowerShell,
}

impl FrontmatterShell {
    /// Maps a frontmatter `shell` value to a [`FrontmatterShell`].
    ///
    /// Matching is case-insensitive. Only `"powershell"` selects
    /// [`FrontmatterShell::PowerShell`]. Every other input, including the empty
    /// string and unknown names, yields [`FrontmatterShell::Bash`].
    pub fn from_str(s: &str) -> Self {
        if s.eq_ignore_ascii_case("powershell") {
            FrontmatterShell::PowerShell
        } else {
            FrontmatterShell::Bash
        }
    }
}
52
/// Captured output of a shell command that exited successfully.
///
/// Both streams are decoded from raw bytes with lossy UTF-8 conversion.
/// `Debug` is derived so that `Result<ShellOutput, String>` values can be
/// printed and used with `unwrap_err`/`expect_err`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ShellOutput {
    /// Everything the command wrote to standard output.
    stdout: String,
    /// Everything the command wrote to standard error.
    stderr: String,
}
58
/// Renders captured command output as the text substituted into the prompt.
///
/// Both streams are trimmed, and empty streams are omitted.
/// - With `inline`, stderr is rendered as `[stderr: …]` and sections are
///   joined with a single space.
/// - Otherwise, stderr is rendered as a `[stderr]` header line followed by
///   the text, and sections are joined with newlines.
///
/// Returns an empty string when both streams are blank.
fn format_shell_output(stdout: &str, stderr: &str, inline: bool) -> String {
    let out = stdout.trim();
    let err = stderr.trim();

    let mut sections: Vec<String> = Vec::with_capacity(2);
    if !out.is_empty() {
        sections.push(out.to_owned());
    }
    if !err.is_empty() {
        let rendered = if inline {
            format!("[stderr: {}]", err)
        } else {
            format!("[stderr]\n{}", err)
        };
        sections.push(rendered);
    }

    let separator = if inline { " " } else { "\n" };
    sections.join(separator)
}
81
82async fn execute_single_command(
84 command: String,
85 shell_bin: String,
86 shell_arg: String,
87 _tool_name: String,
88) -> Result<ShellOutput, String> {
89 let result = timeout(
90 std::time::Duration::from_secs(30),
91 tokio::task::spawn_blocking(move || {
92 let output = Command::new(&shell_bin)
93 .args([&shell_arg, &command])
94 .output()
95 .map_err(|e| format!("Failed to spawn shell: {}", e))?;
96
97 if !output.status.success() {
98 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
99 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
100 return Err(format!(
101 "Shell command failed (exit {}): {}",
102 output.status,
103 if !stderr.is_empty() { stderr } else { stdout }
104 ));
105 }
106
107 Ok(ShellOutput {
108 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
109 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
110 })
111 }),
112 )
113 .await;
114
115 match result {
116 Ok(Ok(Ok(output))) => Ok(output),
117 Ok(Ok(Err(e))) => Err(e),
118 Ok(Err(join_err)) => Err(format!("Shell task failed: {}", join_err)),
119 Err(_) => Err("Shell command timed out (30s)".to_string()),
120 }
121}
122
123fn resolve_shell_tool(shell: &FrontmatterShell) -> (String, String, String) {
128 match shell {
129 FrontmatterShell::Bash => ("bash".to_string(), "-c".to_string(), "Bash".to_string()),
130 FrontmatterShell::PowerShell => {
131 if which_command("pwsh").is_some() {
132 ("pwsh".to_string(), "-c".to_string(), "PowerShell".to_string())
133 } else {
134 warn!(
135 "PowerShell shell requested but 'pwsh' is not available, falling back to bash"
136 );
137 ("bash".to_string(), "-c".to_string(), "Bash".to_string())
138 }
139 }
140 }
141}
142
143#[allow(dead_code)]
145fn resolve_shell_tool_with_path(
146 shell: &FrontmatterShell,
147 path_override: &std::ffi::OsStr,
148) -> (String, String, String) {
149 match shell {
150 FrontmatterShell::Bash => ("bash".to_string(), "-c".to_string(), "Bash".to_string()),
151 FrontmatterShell::PowerShell => {
152 if which_command_in_path("pwsh", path_override).is_some() {
153 ("pwsh".to_string(), "-c".to_string(), "PowerShell".to_string())
154 } else {
155 ("bash".to_string(), "-c".to_string(), "Bash".to_string())
156 }
157 }
158 }
159}
160
161fn which_command(cmd: &str) -> Option<std::path::PathBuf> {
163 let path_var = std::env::var_os("PATH")?;
164 for dir in std::env::split_paths(&path_var) {
165 let full = dir.join(cmd);
166 if full.is_file() {
167 return Some(full);
168 }
169 }
170 None
171}
172
// There may be no non-test caller in some builds; silence the dead-code lint
// rather than gate the helper behind `cfg(test)`.
#[allow(dead_code)]
/// Searches each directory of the `PATH`-style list `path_env`, in order, for
/// a regular file named `cmd`.
///
/// On Windows, an extension-less `cmd` also matches `cmd.exe`. Executables
/// there carry that suffix (e.g. `pwsh.exe`), and `dir.join("pwsh")` alone
/// would never find them.
///
/// Returns the first matching path, or `None`. Only `is_file()` is checked;
/// execute permission is not verified.
fn which_command_in_path(cmd: &str, path_env: &std::ffi::OsStr) -> Option<std::path::PathBuf> {
    std::env::split_paths(path_env).find_map(|dir| {
        let candidate = dir.join(cmd);
        if candidate.is_file() {
            return Some(candidate);
        }
        if cfg!(windows) && candidate.extension().is_none() {
            let with_exe = candidate.with_extension("exe");
            if with_exe.is_file() {
                return Some(with_exe);
            }
        }
        None
    })
}
184
185pub async fn execute_shell_commands_in_prompt<F>(
196 text: &str,
197 shell: &FrontmatterShell,
198 skill_name: &str,
199 can_execute: Option<&F>,
200) -> String
201where
202 F: Fn(&str, &str) -> bool + Send + Sync + ?Sized,
203{
204 let mut matches: Vec<(usize, usize, String, bool)> = Vec::new();
206
207 for cap in block_pattern().captures_iter(text) {
208 if let Some(full) = cap.get(0) {
209 matches.push((full.start(), full.end(), full.as_str().to_string(), false));
210 }
211 }
212
213 if text.contains("!`") {
214 for cap in inline_pattern().captures_iter(text) {
215 if let (Some(full), Some(prefix)) = (cap.get(0), cap.get(1)) {
216 let pattern_start = prefix.end();
218 let pattern = text[pattern_start..full.end()].to_string();
219 matches.push((pattern_start, full.end(), pattern, true));
220 }
221 }
222 }
223
224 if matches.is_empty() {
225 return text.to_string();
226 }
227
228 let commands: Vec<(String, String, bool)> = matches
230 .iter()
231 .map(|(_, _, pattern, inline)| {
232 let command = if *inline {
233 if let Some(stripped) = pattern.strip_prefix("!`") {
235 stripped.strip_suffix('`')
236 .map(|s| s.trim().to_string())
237 .unwrap_or_default()
238 } else {
239 String::new()
240 }
241 } else {
242 block_pattern()
243 .captures(pattern)
244 .and_then(|c| c.get(1))
245 .map(|m| m.as_str().trim().to_string())
246 .unwrap_or_default()
247 };
248 (pattern.clone(), command, *inline)
249 })
250 .collect();
251
252 let (shell_bin, shell_arg, tool_name) = resolve_shell_tool(shell);
254
255 let futures: Vec<_> = commands
257 .into_iter()
258 .map(|(pattern, command, inline)| {
259 let shell_bin = shell_bin.to_string();
260 let shell_arg = shell_arg.to_string();
261 let tool_name = tool_name.to_string();
262 let skill_name = skill_name.to_string();
263 async move {
264 if command.is_empty() {
265 return (pattern.clone(), pattern);
266 }
267
268 if let Some(ref cb) = can_execute {
270 if !cb(&command, &tool_name) {
271 warn!(
272 "Shell command permission denied in skill '{}': {}",
273 skill_name, command
274 );
275 return (pattern.clone(), "[Permission denied]".to_string());
276 }
277 }
278
279 match execute_single_command(command, shell_bin, shell_arg, tool_name).await {
280 Ok(output) => {
281 let formatted =
282 format_shell_output(&output.stdout, &output.stderr, inline);
283 (pattern.clone(), formatted)
284 }
285 Err(e) => {
286 let error_msg = if inline {
287 format!("[Error: {}]", e)
288 } else {
289 format!("[Error]\n{}", e)
290 };
291 (pattern.clone(), error_msg)
292 }
293 }
294 }
295 })
296 .collect();
297
298 let mut results: Vec<(String, String)> = join_all(futures).await;
299
300 let mut result = text.to_string();
302 for (start, end, pattern, _) in matches.iter().rev() {
303 if let Some(pos) = results.iter().position(|(p, _)| p == pattern) {
304 let (_, replacement) = results.remove(pos);
305 result.replace_range(*start..*end, &replacement);
306 }
307 }
308
309 result
310}
311
312pub async fn execute_prompt_shell(command: &str) -> Result<String, String> {
318 let output = Command::new("sh")
319 .args(["-c", command])
320 .output()
321 .map_err(|e| e.to_string())?;
322
323 if output.status.success() {
324 Ok(String::from_utf8_lossy(&output.stdout).to_string())
325 } else {
326 Err(String::from_utf8_lossy(&output.stderr).to_string())
327 }
328}
329
330pub fn build_shell_command(program: &str, args: &[&str]) -> String {
332 let mut cmd = program.to_string();
333 for arg in args {
334 cmd.push(' ');
335 cmd.push_str(&shell_escape(arg));
336 }
337 cmd
338}
339
/// Quotes `s` so it is read as exactly one word by a POSIX shell.
///
/// Non-empty strings made only of alphanumerics, `-`, `_` and `.` are returned
/// as-is. Anything else is wrapped in single quotes. This includes the empty
/// string, which unquoted would disappear from the command line entirely.
/// Each embedded `'` becomes `'\''`: close the quote, an escaped quote, then
/// reopen.
fn shell_escape(s: &str) -> String {
    let is_plain_word = !s.is_empty()
        && s
            .chars()
            .all(|c| c.is_alphanumeric() || matches!(c, '-' | '_' | '.'));
    if is_plain_word {
        s.to_string()
    } else {
        format!("'{}'", s.replace('\'', r"'\''"))
    }
}
347
348pub fn can_execute_skill_shell(_command: &str, tool_name: &str) -> Result<(), AgentError> {
354 match tool_name {
355 "Bash" | "bash" | "PowerShell" | "powershell" => Ok(()),
356 _ => Err(AgentError::Tool(format!("Unsupported shell tool: {}", tool_name))),
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Trait-object type used to name `F` when no permission callback is passed.
    type NoPermissionCheck = dyn Fn(&str, &str) -> bool + Send + Sync;

    /// Expands `text` with bash and the given skill name, without a permission callback.
    async fn expand(text: &str, skill: &str) -> String {
        execute_shell_commands_in_prompt(
            text,
            &FrontmatterShell::Bash,
            skill,
            None::<&NoPermissionCheck>,
        )
        .await
    }

    /// Expands `text` with bash under skill "test-skill", consulting `check`.
    async fn expand_checked<C>(text: &str, check: &C) -> String
    where
        C: Fn(&str, &str) -> bool + Send + Sync,
    {
        execute_shell_commands_in_prompt(text, &FrontmatterShell::Bash, "test-skill", Some(check))
            .await
    }

    #[test]
    fn test_block_pattern_matches() {
        let text = "```!\necho hello\n```";
        let re = block_pattern();
        assert!(re.is_match(text));
        assert!(re.captures(text).unwrap().get(1).is_some());
    }

    #[test]
    fn test_block_pattern_multiline() {
        let caps = block_pattern()
            .captures("```!\necho hello\necho world\n```")
            .unwrap();
        assert_eq!(caps.get(1).unwrap().as_str().trim(), "echo hello\necho world");
    }

    #[test]
    fn test_inline_pattern_matches() {
        assert!(inline_pattern().is_match("Run !`ls` to see files"));
    }

    #[test]
    fn test_inline_pattern_no_match_without_whitespace() {
        assert!(!inline_pattern().is_match("x!`this`"));
    }

    #[test]
    fn test_inline_pattern_extract_command() {
        let caps = inline_pattern().captures("Run !`echo hi` now").unwrap();
        assert_eq!(caps.get(2).unwrap().as_str(), "echo hi");
    }

    #[test]
    fn test_format_shell_output_stdout_only() {
        assert_eq!(format_shell_output("hello world", "", false), "hello world");
    }

    #[test]
    fn test_format_shell_output_with_stderr_block() {
        let rendered = format_shell_output("stdout", "stderr msg", false);
        assert_eq!(rendered, "stdout\n[stderr]\nstderr msg");
    }

    #[test]
    fn test_format_shell_output_with_stderr_inline() {
        let rendered = format_shell_output("stdout", "stderr msg", true);
        assert_eq!(rendered, "stdout [stderr: stderr msg]");
    }

    #[test]
    fn test_format_shell_output_empty() {
        assert_eq!(format_shell_output("", "", false), "");
    }

    #[tokio::test]
    async fn test_execute_block_command() {
        let out = expand("Before ```!\necho hello\n``` After", "test-skill").await;
        for needle in ["hello", "Before", "After"] {
            assert!(out.contains(needle));
        }
        assert!(!out.contains("```!"));
    }

    #[tokio::test]
    async fn test_execute_inline_command() {
        let out = expand("Count: !`echo 42` items", "test-skill").await;
        assert!(out.contains("42"));
        assert!(!out.contains("!`echo 42`"));
    }

    #[tokio::test]
    async fn test_no_shell_commands() {
        let text = "This is plain text with no commands";
        assert_eq!(expand(text, "test").await, text);
    }

    #[tokio::test]
    async fn test_failed_command_substitutes_error() {
        let out = expand("```!\nexit 1\n```", "test").await;
        assert!(out.contains("[Error]"));
        assert!(!out.contains("```!"));
    }

    #[tokio::test]
    async fn test_multiple_commands() {
        let out = expand("A ```!\necho one\n``` B !`echo two` C", "test-skill").await;
        for needle in ["one", "two", "A", "B", "C"] {
            assert!(out.contains(needle));
        }
    }

    #[tokio::test]
    async fn test_command_with_stderr() {
        let out = expand("```!\necho out && echo err >&2\n```", "test-skill").await;
        assert!(out.contains("out"));
        assert!(out.contains("err") || out.contains("[stderr]"));
    }

    #[test]
    fn test_frontmatter_shell_from_str() {
        let cases = [
            ("bash", FrontmatterShell::Bash),
            ("powershell", FrontmatterShell::PowerShell),
            ("unknown", FrontmatterShell::Bash),
            ("", FrontmatterShell::Bash),
        ];
        for (input, expected) in cases {
            assert_eq!(FrontmatterShell::from_str(input), expected);
        }
    }

    #[test]
    fn test_shell_escape_safe() {
        assert_eq!(shell_escape("hello"), "hello");
    }

    #[test]
    fn test_shell_escape_needs_quotes() {
        assert_eq!(shell_escape("he'llo"), "'he'\\''llo'");
    }

    #[test]
    fn test_build_shell_command() {
        let line = build_shell_command("echo", &["hello", "world"]);
        assert_eq!(line, "echo hello world");
    }

    #[tokio::test]
    async fn test_execute_prompt_shell() {
        let stdout = execute_prompt_shell("echo -n test").await.unwrap();
        assert_eq!(stdout, "test");
    }

    #[test]
    fn test_can_execute_skill_shell() {
        assert!(can_execute_skill_shell("echo hello", "Bash").is_ok());
    }

    #[test]
    fn test_can_execute_skill_shell_unsupported_tool() {
        assert!(can_execute_skill_shell("echo hello", "Fish").is_err());
    }

    #[test]
    fn test_can_execute_skill_shell_powershell() {
        assert!(can_execute_skill_shell("Write-Host hello", "PowerShell").is_ok());
    }

    #[tokio::test]
    async fn test_permission_denied_substitutes_message() {
        let deny_all = |_cmd: &str, _tool: &str| false;
        let out = expand_checked("Before ```!\necho hello\n``` After", &deny_all).await;
        for needle in ["[Permission denied]", "Before", "After"] {
            assert!(out.contains(needle));
        }
        assert!(!out.contains("hello"));
    }

    #[tokio::test]
    async fn test_permission_denied_inline_substitutes_message() {
        let deny_all = |_cmd: &str, _tool: &str| false;
        let out = expand_checked("Count: !`echo 42` items", &deny_all).await;
        assert!(out.contains("[Permission denied]"));
        assert!(!out.contains("42"));
        assert!(!out.contains("!`echo 42`"));
    }

    #[tokio::test]
    async fn test_permission_allowed_executes() {
        let allow_all = |_cmd: &str, _tool: &str| true;
        let out = expand_checked("Before ```!\necho hello\n``` After", &allow_all).await;
        for needle in ["hello", "Before", "After"] {
            assert!(out.contains(needle));
        }
        assert!(!out.contains("[Permission denied]"));
    }

    #[tokio::test]
    async fn test_permission_allowed_inline_executes() {
        let allow_all = |_cmd: &str, _tool: &str| true;
        let out = expand_checked("Count: !`echo 42` items", &allow_all).await;
        assert!(out.contains("42"));
        assert!(!out.contains("[Permission denied]"));
    }

    #[tokio::test]
    async fn test_permission_selective() {
        let only_echo = |cmd: &str, _tool: &str| cmd.starts_with("echo");
        let out = expand_checked("A ```!\necho one\n``` B ```!\nexit 1\n```", &only_echo).await;
        assert!(out.contains("one"));
        assert!(out.contains("[Permission denied]"));
    }

    #[test]
    fn test_powershell_fallback_to_bash() {
        let (bin, arg, tool) = resolve_shell_tool_with_path(
            &FrontmatterShell::PowerShell,
            std::ffi::OsStr::new("/nonexistent/path"),
        );
        assert_eq!(
            (bin.as_str(), arg.as_str(), tool.as_str()),
            ("bash", "-c", "Bash")
        );
    }

    #[test]
    fn test_powershell_resolves_when_pwsh_available() {
        if let Some(path) = std::env::var_os("PATH") {
            let (bin, _arg, tool) =
                resolve_shell_tool_with_path(&FrontmatterShell::PowerShell, &path);
            let (want_bin, want_tool) = if which_command_in_path("pwsh", &path).is_some() {
                ("pwsh", "PowerShell")
            } else {
                ("bash", "Bash")
            };
            assert_eq!(bin, want_bin);
            assert_eq!(tool, want_tool);
        }
    }

    #[test]
    fn test_resolve_shell_bash() {
        let (bin, arg, tool) = resolve_shell_tool(&FrontmatterShell::Bash);
        assert_eq!(
            (bin.as_str(), arg.as_str(), tool.as_str()),
            ("bash", "-c", "Bash")
        );
    }
}