Skip to main content

aether_core/core/
prompt.rs

1use crate::core::{AgentError, Result};
2use glob::glob;
3use mcp_utils::client::ServerInstructions;
4use std::collections::HashMap;
5use std::env;
6use std::path::{Path, PathBuf};
7use tokio::fs;
8use tokio::process::Command;
9use tracing::warn;
10use utils::substitution::substitute_parameters;
11
/// One source of system-prompt content; resolved to a `String` by `Prompt::build`.
#[derive(Debug, Clone)]
pub enum Prompt {
    /// Literal prompt text, returned unchanged by `build`.
    Text(String),
    /// Prompt text read from a file, with optional parameter substitution and
    /// `$SYSTEM_ENV` expansion.
    File {
        /// Path of the file to read. It is passed straight to the filesystem, so a relative
        /// path is resolved against the process working directory (not against `cwd`).
        path: String,
        /// Values handed to `substitute_parameters` for templating the file contents.
        args: Option<HashMap<String, String>>,
        /// Directory reported when the file contains `$SYSTEM_ENV`; `None` falls back to
        /// the process working directory.
        cwd: Option<PathBuf>,
    },
    /// Resolve prompt files from glob patterns relative to cwd.
    /// Absolute paths are also supported.
    PromptGlobs {
        /// Glob patterns; relative ones are joined onto `cwd`, absolute ones are used as-is.
        patterns: Vec<String>,
        /// Base directory for relative patterns and for `$SYSTEM_ENV` expansion in matched files.
        cwd: PathBuf,
    },
    /// An `<env>` block describing the working directory, platform, date and OS version;
    /// `None` uses the process working directory.
    SystemEnv(Option<PathBuf>),
    /// Instructions published by connected MCP servers, rendered as an XML-tagged section.
    McpInstructions(Vec<ServerInstructions>),
}
29
30impl Prompt {
31    pub fn text(str: &str) -> Self {
32        Self::Text(str.to_string())
33    }
34
35    pub fn file(path: &str) -> Self {
36        Self::File {
37            path: path.to_string(),
38            args: None,
39            cwd: None,
40        }
41    }
42
43    pub fn file_with_args(path: &str, args: HashMap<String, String>) -> Self {
44        Self::File {
45            path: path.to_string(),
46            args: Some(args),
47            cwd: None,
48        }
49    }
50
51    pub fn from_globs(patterns: Vec<String>, cwd: PathBuf) -> Self {
52        Self::PromptGlobs { patterns, cwd }
53    }
54
55    pub fn system_env() -> Self {
56        Self::SystemEnv(None)
57    }
58
59    pub fn with_cwd(self, cwd: PathBuf) -> Self {
60        match self {
61            Self::File { path, args, .. } => Self::File {
62                path,
63                args,
64                cwd: Some(cwd),
65            },
66            Self::SystemEnv(_) => Self::SystemEnv(Some(cwd)),
67            Self::PromptGlobs { patterns, .. } => Self::PromptGlobs { patterns, cwd },
68            Self::Text(_) | Self::McpInstructions(_) => self,
69        }
70    }
71
72    pub fn mcp_instructions(instructions: Vec<ServerInstructions>) -> Self {
73        Self::McpInstructions(instructions)
74    }
75
76    /// Resolve this `SystemPrompt` to a String
77    pub async fn build(&self) -> Result<String> {
78        match self {
79            Prompt::Text(text) => Ok(text.clone()),
80            Prompt::File { path, args, cwd } => {
81                let content = Self::resolve_file(&PathBuf::from(path)).await?;
82                let substituted = substitute_parameters(&content, args);
83                Self::expand_builtins(&substituted, cwd.as_deref()).await
84            }
85            Prompt::PromptGlobs { patterns, cwd } => {
86                Self::resolve_prompt_globs(patterns, cwd).await
87            }
88            Prompt::SystemEnv(cwd) => Self::resolve_system_env(cwd.as_deref()).await,
89            Prompt::McpInstructions(instructions) => Ok(format_mcp_instructions(instructions)),
90        }
91    }
92
93    /// Resolve multiple `SystemPrompts` and join them with double newlines
94    pub async fn build_all(prompts: &[Prompt]) -> Result<String> {
95        let mut parts = Vec::with_capacity(prompts.len());
96        for p in prompts {
97            let part = p.build().await?;
98            if !part.is_empty() {
99                parts.push(part);
100            }
101        }
102        Ok(parts.join("\n\n"))
103    }
104
105    async fn resolve_file(path: &Path) -> Result<String> {
106        fs::read_to_string(path).await.map_err(|e| {
107            AgentError::IoError(format!("Failed to read file '{}': {e}", path.display()))
108        })
109    }
110
111    async fn resolve_prompt_globs(patterns: &[String], cwd: &Path) -> Result<String> {
112        let mut contents = Vec::new();
113
114        for pattern in patterns {
115            let full_pattern = if Path::new(pattern).is_absolute() {
116                pattern.clone()
117            } else {
118                cwd.join(pattern).to_string_lossy().to_string()
119            };
120
121            let paths = glob(&full_pattern).map_err(|e| {
122                AgentError::IoError(format!("Invalid glob pattern '{pattern}': {e}"))
123            })?;
124
125            let mut matched: Vec<PathBuf> = paths.filter_map(std::result::Result::ok).collect();
126            matched.sort();
127
128            for path in matched {
129                if path.is_file() {
130                    match fs::read_to_string(&path).await {
131                        Ok(content) => {
132                            let expanded = Self::expand_builtins(&content, Some(cwd)).await?;
133                            contents.push(expanded);
134                        }
135                        Err(e) => {
136                            warn!("Failed to read prompt file '{}': {e}", path.display());
137                        }
138                    }
139                }
140            }
141        }
142
143        Ok(contents.join("\n\n"))
144    }
145
146    /// Expand builtin directives like `$SYSTEM_ENV` in prompt file content.
147    async fn expand_builtins(content: &str, cwd: Option<&Path>) -> Result<String> {
148        if !content.contains("$SYSTEM_ENV") {
149            return Ok(content.to_string());
150        }
151        let env_block = Self::resolve_system_env(cwd).await?;
152        Ok(content.replace("$SYSTEM_ENV", &env_block))
153    }
154
155    async fn resolve_system_env(cwd: Option<&Path>) -> Result<String> {
156        let cwd = match cwd {
157            Some(dir) => dir.to_path_buf(),
158            None => env::current_dir().map_err(|e| {
159                AgentError::IoError(format!("Failed to get current directory: {e}"))
160            })?,
161        };
162
163        let os_version = Command::new("uname")
164            .arg("-a")
165            .output()
166            .await
167            .ok()
168            .and_then(|output| String::from_utf8(output.stdout).ok())
169            .and_then(|version| {
170                let version = version.trim();
171                if version.is_empty() {
172                    None
173                } else {
174                    Some(format!("OS Version: {version}"))
175                }
176            });
177
178        let is_git_repo = fs::metadata(cwd.join(".git"))
179            .await
180            .map(|m| m.is_dir())
181            .unwrap_or(false);
182
183        let working_dir = if is_git_repo {
184            format!("Working directory: {} (git repo)", cwd.display())
185        } else {
186            format!("Working directory: {}", cwd.display())
187        };
188
189        let mut lines = vec![
190            working_dir,
191            format!("Platform: {}", env::consts::OS),
192            format!("Today's date: {}", chrono::Local::now().format("%Y-%m-%d")),
193        ];
194
195        if let Some(os) = os_version {
196            lines.push(os);
197        }
198
199        Ok(format!("<env>\n{}\n</env>", lines.join("\n")))
200    }
201}
202
203/// Format MCP instructions with XML tags for the system prompt.
204fn format_mcp_instructions(instructions: &[ServerInstructions]) -> String {
205    if instructions.is_empty() {
206        return String::new();
207    }
208
209    let mut parts = vec!["# MCP Server Instructions\n".to_string()];
210    parts.push("You are connected to the following MCP servers:\n".to_string());
211
212    for instr in instructions {
213        parts.push(format!(
214            "<mcp-server name=\"{}\">\n{}\n</mcp-server>\n",
215            instr.server_name, instr.instructions
216        ));
217    }
218
219    parts.join("\n")
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Create (or overwrite) `name` under `root` with `body`, panicking on I/O failure.
    fn write_file(root: &Path, name: &str, body: &str) {
        std::fs::write(root.join(name), body).unwrap();
    }

    #[tokio::test]
    async fn build_text_prompt() {
        let built = Prompt::text("Hello, world!").build().await.unwrap();
        assert_eq!(built, "Hello, world!");
    }

    #[tokio::test]
    async fn build_all_concatenates_prompts() {
        let joined = Prompt::build_all(&[Prompt::text("Part one"), Prompt::text("Part two")])
            .await
            .unwrap();
        assert_eq!(joined, "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn resolve_system_env_contains_expected_fields() {
        let env_block = Prompt::resolve_system_env(None).await.unwrap();
        for expected in [
            "<env>",
            "</env>",
            "Working directory:",
            "Platform:",
            "Today's date:",
        ] {
            assert!(
                env_block.contains(expected),
                "missing {expected:?} in {env_block}"
            );
        }
    }

    #[tokio::test]
    async fn resolve_system_env_uses_provided_cwd() {
        let tmp = std::env::temp_dir();
        let env_block = Prompt::resolve_system_env(Some(tmp.as_path()))
            .await
            .unwrap();
        assert!(env_block.contains(&tmp.display().to_string()));
    }

    #[tokio::test]
    async fn prompt_globs_resolves_single_file() {
        let tmp = tempfile::tempdir().unwrap();
        write_file(tmp.path(), "AGENTS.md", "# Instructions\nBe helpful");

        let built = Prompt::from_globs(vec!["AGENTS.md".into()], tmp.path().to_path_buf())
            .build()
            .await
            .unwrap();
        assert_eq!(built, "# Instructions\nBe helpful");
    }

    #[tokio::test]
    async fn prompt_globs_resolves_glob_pattern() {
        let tmp = tempfile::tempdir().unwrap();
        let rules = tmp.path().join(".aether/rules");
        std::fs::create_dir_all(&rules).unwrap();
        write_file(&rules, "a-coding.md", "Use Rust");
        write_file(&rules, "b-testing.md", "Write tests");

        let built = Prompt::from_globs(
            vec![".aether/rules/*.md".into()],
            tmp.path().to_path_buf(),
        )
        .build()
        .await
        .unwrap();
        for needle in ["Use Rust", "Write tests"] {
            assert!(built.contains(needle));
        }
    }

    #[tokio::test]
    async fn prompt_globs_returns_empty_for_no_matches() {
        let tmp = tempfile::tempdir().unwrap();

        let built = Prompt::from_globs(vec!["nonexistent*.md".into()], tmp.path().to_path_buf())
            .build()
            .await
            .unwrap();
        assert!(built.is_empty());
    }

    #[tokio::test]
    async fn prompt_globs_supports_absolute_paths() {
        let tmp = tempfile::tempdir().unwrap();
        let rule_file = tmp.path().join("rules.md");
        std::fs::write(&rule_file, "Absolute rule").unwrap();

        let absolute_pattern = rule_file.to_string_lossy().to_string();
        let built = Prompt::from_globs(vec![absolute_pattern], PathBuf::from("/tmp"))
            .build()
            .await
            .unwrap();
        assert_eq!(built, "Absolute rule");
    }

    #[tokio::test]
    async fn prompt_globs_concatenates_multiple_patterns() {
        let tmp = tempfile::tempdir().unwrap();
        write_file(tmp.path(), "AGENTS.md", "Agent instructions");
        write_file(tmp.path(), "SYSTEM.md", "System prompt");

        let built = Prompt::from_globs(
            vec!["AGENTS.md".into(), "SYSTEM.md".into()],
            tmp.path().to_path_buf(),
        )
        .build()
        .await
        .unwrap();
        for needle in ["Agent instructions", "System prompt", "\n\n"] {
            assert!(built.contains(needle));
        }
    }

    #[tokio::test]
    async fn build_all_skips_empty_parts() {
        let parts = [
            Prompt::text("Part one"),
            Prompt::text(""),
            Prompt::text("Part two"),
        ];
        let joined = Prompt::build_all(&parts).await.unwrap();
        assert_eq!(joined, "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn expand_builtins_replaces_system_env() {
        let expanded = Prompt::expand_builtins("Before\n$SYSTEM_ENV\nAfter", None)
            .await
            .unwrap();
        assert!(expanded.starts_with("Before\n<env>"));
        assert!(expanded.contains("</env>"));
        assert!(expanded.ends_with("</env>\nAfter"));
    }

    #[tokio::test]
    async fn expand_builtins_no_op_without_marker() {
        const PLAIN: &str = "Just some plain content with no directives";
        let expanded = Prompt::expand_builtins(PLAIN, None).await.unwrap();
        assert_eq!(expanded, PLAIN);
    }

    #[tokio::test]
    async fn expand_builtins_with_cwd() {
        let tmp = std::env::temp_dir();
        let expanded = Prompt::expand_builtins("$SYSTEM_ENV", Some(tmp.as_path()))
            .await
            .unwrap();
        assert!(expanded.contains(&tmp.display().to_string()));
    }

    #[tokio::test]
    async fn prompt_globs_expands_system_env_in_file() {
        let tmp = tempfile::tempdir().unwrap();
        write_file(
            tmp.path(),
            "AGENTS.md",
            "Instructions\n\n$SYSTEM_ENV\n\nRules",
        );

        let built = Prompt::from_globs(vec!["AGENTS.md".into()], tmp.path().to_path_buf())
            .build()
            .await
            .unwrap();
        for needle in ["Instructions", "<env>", "Rules"] {
            assert!(built.contains(needle));
        }
        assert!(!built.contains("$SYSTEM_ENV"));
    }
}