Skip to main content

aether_core/core/
prompt.rs

1use crate::core::{AgentError, Result};
2use glob::glob;
3use mcp_utils::client::ServerInstructions;
4use std::collections::HashMap;
5use std::env;
6use std::path::{Path, PathBuf};
7use tokio::fs;
8use tracing::warn;
9use utils::shell_expander::ShellExpander;
10use utils::substitution::substitute_parameters;
11
12#[derive(Debug, Clone)]
13pub enum Prompt {
14    Text(String),
15    File {
16        path: String,
17        args: Option<HashMap<String, String>>,
18        cwd: Option<PathBuf>,
19    },
20    /// Resolve prompt files from glob patterns relative to cwd.
21    /// Absolute paths are also supported.
22    PromptGlobs {
23        patterns: Vec<String>,
24        cwd: PathBuf,
25    },
26    McpInstructions(Vec<ServerInstructions>),
27}
28
29impl Prompt {
30    pub fn text(str: &str) -> Self {
31        Self::Text(str.to_string())
32    }
33
34    pub fn file(path: &str) -> Self {
35        Self::File { path: path.to_string(), args: None, cwd: None }
36    }
37
38    pub fn file_with_args(path: &str, args: HashMap<String, String>) -> Self {
39        Self::File { path: path.to_string(), args: Some(args), cwd: None }
40    }
41
42    pub fn from_globs(patterns: Vec<String>, cwd: PathBuf) -> Self {
43        Self::PromptGlobs { patterns, cwd }
44    }
45
46    pub fn with_cwd(self, cwd: PathBuf) -> Self {
47        match self {
48            Self::File { path, args, .. } => Self::File { path, args, cwd: Some(cwd) },
49            Self::PromptGlobs { patterns, .. } => Self::PromptGlobs { patterns, cwd },
50            Self::Text(_) | Self::McpInstructions(_) => self,
51        }
52    }
53
54    pub fn mcp_instructions(instructions: Vec<ServerInstructions>) -> Self {
55        Self::McpInstructions(instructions)
56    }
57
58    /// Resolve this `SystemPrompt` to a String
59    pub async fn build(&self) -> Result<String> {
60        match self {
61            Prompt::Text(text) => Ok(text.clone()),
62            Prompt::File { path, args, cwd } => {
63                let content = Self::resolve_file(&PathBuf::from(path)).await?;
64                let substituted = substitute_parameters(&content, args);
65                let expander = ShellExpander::new();
66                Self::expand_builtins(&substituted, cwd.as_deref(), &expander).await
67            }
68            Prompt::PromptGlobs { patterns, cwd } => Self::resolve_prompt_globs(patterns, cwd).await,
69            Prompt::McpInstructions(instructions) => Ok(format_mcp_instructions(instructions)),
70        }
71    }
72
73    /// Resolve multiple `SystemPrompts` and join them with double newlines
74    pub async fn build_all(prompts: &[Prompt]) -> Result<String> {
75        let mut parts = Vec::with_capacity(prompts.len());
76        for p in prompts {
77            let part = p.build().await?;
78            if !part.is_empty() {
79                parts.push(part);
80            }
81        }
82        Ok(parts.join("\n\n"))
83    }
84
85    async fn resolve_file(path: &Path) -> Result<String> {
86        fs::read_to_string(path)
87            .await
88            .map_err(|e| AgentError::IoError(format!("Failed to read file '{}': {e}", path.display())))
89    }
90
91    async fn resolve_prompt_globs(patterns: &[String], cwd: &Path) -> Result<String> {
92        let mut contents = Vec::new();
93        let expander = ShellExpander::new();
94
95        for pattern in patterns {
96            let full_pattern = if Path::new(pattern).is_absolute() {
97                pattern.clone()
98            } else {
99                cwd.join(pattern).to_string_lossy().to_string()
100            };
101
102            let paths = glob(&full_pattern)
103                .map_err(|e| AgentError::IoError(format!("Invalid glob pattern '{pattern}': {e}")))?;
104
105            let mut matched: Vec<PathBuf> = paths.filter_map(std::result::Result::ok).collect();
106            matched.sort();
107
108            for path in matched {
109                if path.is_file() {
110                    match fs::read_to_string(&path).await {
111                        Ok(content) => {
112                            let resolved = Self::expand_builtins(&content, Some(cwd), &expander).await?;
113                            contents.push(resolved);
114                        }
115                        Err(e) => {
116                            warn!("Failed to read prompt file '{}': {e}", path.display());
117                        }
118                    }
119                }
120            }
121        }
122
123        Ok(contents.join("\n\n"))
124    }
125
126    /// Expand `` !`command` `` shell-interpolation markers in prompt content.
127    ///
128    /// Thin wrapper around [`ShellExpander::expand`] that resolves `cwd` from
129    /// the process working directory when `None`.
130    async fn expand_builtins(content: &str, cwd: Option<&Path>, expander: &ShellExpander) -> Result<String> {
131        let cwd = match cwd {
132            Some(dir) => dir.to_path_buf(),
133            None => {
134                env::current_dir().map_err(|e| AgentError::IoError(format!("Failed to get current directory: {e}")))?
135            }
136        };
137        Ok(expander.expand(content, &cwd).await)
138    }
139}
140
141/// Format MCP instructions with XML tags for the system prompt.
142fn format_mcp_instructions(instructions: &[ServerInstructions]) -> String {
143    if instructions.is_empty() {
144        return String::new();
145    }
146
147    let mut parts = vec!["# MCP Server Instructions\n".to_string()];
148    parts.push("You are connected to the following MCP servers:\n".to_string());
149
150    for instr in instructions {
151        parts.push(format!("<mcp-server name=\"{}\">\n{}\n</mcp-server>\n", instr.server_name, instr.instructions));
152    }
153
154    parts.join("\n")
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PromptGlobs` prompt for `patterns` rooted at `root` and resolves it.
    async fn resolve_globs(patterns: &[&str], root: &Path) -> String {
        let owned = patterns.iter().map(|p| (*p).to_string()).collect::<Vec<_>>();
        Prompt::from_globs(owned, root.to_path_buf()).build().await.unwrap()
    }

    /// Runs `Prompt::expand_builtins` on `content` with a fresh `ShellExpander`.
    async fn expand(content: &str, cwd: Option<&Path>) -> String {
        let shell = ShellExpander::new();
        Prompt::expand_builtins(content, cwd, &shell).await.unwrap()
    }

    #[tokio::test]
    async fn build_text_prompt() {
        let built = Prompt::text("Hello, world!").build().await.unwrap();
        assert_eq!(built, "Hello, world!");
    }

    #[tokio::test]
    async fn build_all_concatenates_prompts() {
        let joined = Prompt::build_all(&[Prompt::text("Part one"), Prompt::text("Part two")]).await.unwrap();
        assert_eq!(joined, "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn prompt_globs_resolves_single_file() {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        std::fs::write(root.join("AGENTS.md"), "# Instructions\nBe helpful").unwrap();

        assert_eq!(resolve_globs(&["AGENTS.md"], root).await, "# Instructions\nBe helpful");
    }

    #[tokio::test]
    async fn prompt_globs_resolves_glob_pattern() {
        let tmp = tempfile::tempdir().unwrap();
        let rules = tmp.path().join(".aether/rules");
        std::fs::create_dir_all(&rules).unwrap();
        for (name, body) in [("a-coding.md", "Use Rust"), ("b-testing.md", "Write tests")] {
            std::fs::write(rules.join(name), body).unwrap();
        }

        let out = resolve_globs(&[".aether/rules/*.md"], tmp.path()).await;
        assert!(out.contains("Use Rust"));
        assert!(out.contains("Write tests"));
    }

    #[tokio::test]
    async fn prompt_globs_returns_empty_for_no_matches() {
        let tmp = tempfile::tempdir().unwrap();
        assert!(resolve_globs(&["nonexistent*.md"], tmp.path()).await.is_empty());
    }

    #[tokio::test]
    async fn prompt_globs_supports_absolute_paths() {
        let tmp = tempfile::tempdir().unwrap();
        let target = tmp.path().join("rules.md");
        std::fs::write(&target, "Absolute rule").unwrap();

        let absolute = target.to_string_lossy().to_string();
        let out = resolve_globs(&[absolute.as_str()], Path::new("/tmp")).await;
        assert_eq!(out, "Absolute rule");
    }

    #[tokio::test]
    async fn prompt_globs_concatenates_multiple_patterns() {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();
        std::fs::write(root.join("AGENTS.md"), "Agent instructions").unwrap();
        std::fs::write(root.join("SYSTEM.md"), "System prompt").unwrap();

        let out = resolve_globs(&["AGENTS.md", "SYSTEM.md"], root).await;
        for needle in ["Agent instructions", "System prompt", "\n\n"] {
            assert!(out.contains(needle));
        }
    }

    #[tokio::test]
    async fn build_all_skips_empty_parts() {
        let parts = [Prompt::text("Part one"), Prompt::text(""), Prompt::text("Part two")];
        assert_eq!(Prompt::build_all(&parts).await.unwrap(), "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn expand_builtins_no_op_without_marker() {
        let plain = "Just some plain content with no directives";
        assert_eq!(expand(plain, None).await, plain);
    }

    #[tokio::test]
    async fn expand_builtins_runs_shell_command() {
        assert_eq!(expand("branch: !`echo main`", None).await, "branch: main");
    }

    #[tokio::test]
    async fn expand_builtins_runs_command_in_cwd() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("sentinel.txt"), "").unwrap();

        let listing = expand("files: !`ls`", Some(tmp.path())).await;
        assert!(listing.contains("sentinel.txt"), "expected sentinel.txt in output: {listing}");
    }

    #[tokio::test]
    async fn expand_builtins_handles_multiple_commands() {
        assert_eq!(expand("a=!`echo one`, b=!`echo two`", None).await, "a=one, b=two");
    }

    #[tokio::test]
    async fn expand_builtins_substitutes_empty_on_failure() {
        assert_eq!(expand("before !`exit 1` after", None).await, "before  after");
    }

    #[tokio::test]
    async fn expand_builtins_trims_trailing_whitespace() {
        assert_eq!(expand("!`printf 'hi\\n\\n'`", None).await, "hi");
    }

    #[tokio::test]
    async fn prompt_globs_expands_shell_in_file() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("AGENTS.md"), "Instructions\n\nbranch: !`echo main`\n\nRules").unwrap();

        let out = resolve_globs(&["AGENTS.md"], tmp.path()).await;
        for expected in ["Instructions", "branch: main", "Rules"] {
            assert!(out.contains(expected));
        }
        assert!(!out.contains("!`"));
    }
}