Skip to main content

aether_core/core/
prompt.rs

1use crate::core::{AgentError, Result};
2use glob::glob;
3use mcp_utils::client::ServerInstructions;
4use schemars::{JsonSchema, Schema, SchemaGenerator};
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use std::collections::HashMap;
7use std::env;
8use std::path::{Path, PathBuf};
9use thiserror::Error;
10use tokio::fs;
11use tracing::warn;
12use utils::shell_expander::ShellExpander;
13use utils::substitution::substitute_parameters;
14
15#[derive(Debug, Clone)]
16pub enum Prompt {
17    Text(String),
18    File {
19        path: String,
20        args: Option<HashMap<String, String>>,
21        cwd: Option<PathBuf>,
22    },
23    /// Resolve prompt files from glob patterns relative to cwd.
24    /// Absolute paths are also supported.
25    PromptGlobs {
26        patterns: Vec<String>,
27        cwd: PathBuf,
28    },
29    McpInstructions(Vec<ServerInstructions>),
30}
31
/// A prompt as declared in configuration, before any filesystem resolution.
///
/// Each source is inline text, a single file path, or a glob pattern. Use
/// [`Prompt::from_sources`] to validate sources and convert them into runtime
/// [`Prompt`] values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptSource {
    /// Inline prompt text.
    Text { text: String },
    /// Path to a single prompt file.
    File { path: String },
    /// Glob pattern selecting prompt files.
    Glob { pattern: String },
}
42
43#[derive(serde::Deserialize)]
44#[serde(untagged)]
45enum PromptSourceInput {
46    Path(String),
47    Object(PromptSourceObject),
48}
49
// Typed object form of a prompt source, e.g. `{"type": "glob", "pattern": "..."}`.
// The `type` field selects the variant; `rename_all = "camelCase"` makes the
// variant tags lowercase (`text`, `file`, `glob`), and `deny_unknown_fields`
// rejects any extra keys.
//
// NOTE: plain `//` comments are used on purpose. schemars' derive turns `///`
// doc comments into schema `description`s, so doc comments here would change
// the generated JSON schema.
#[derive(schemars::JsonSchema, serde::Deserialize, serde::Serialize)]
#[serde(tag = "type", rename_all = "camelCase", deny_unknown_fields)]
enum PromptSourceObject {
    Text { text: String },
    File { path: String },
    Glob { pattern: String },
}
57
58impl<'de> Deserialize<'de> for PromptSource {
59    fn deserialize<T: Deserializer<'de>>(deserializer: T) -> std::result::Result<Self, T::Error> {
60        match serde::Deserialize::deserialize(deserializer)? {
61            PromptSourceInput::Path(path) | PromptSourceInput::Object(PromptSourceObject::File { path }) => {
62                Ok(Self::File { path })
63            }
64            PromptSourceInput::Object(PromptSourceObject::Text { text }) => Ok(Self::Text { text }),
65            PromptSourceInput::Object(PromptSourceObject::Glob { pattern }) => Ok(Self::Glob { pattern }),
66        }
67    }
68}
69
70impl Serialize for PromptSource {
71    fn serialize<T: Serializer>(&self, serializer: T) -> std::result::Result<T::Ok, T::Error> {
72        match self {
73            Self::File { path } => serializer.serialize_str(path),
74            Self::Text { text } => Serialize::serialize(&PromptSourceObject::Text { text: text.clone() }, serializer),
75            Self::Glob { pattern } => {
76                Serialize::serialize(&PromptSourceObject::Glob { pattern: pattern.clone() }, serializer)
77            }
78        }
79    }
80}
81
82impl JsonSchema for PromptSource {
83    fn schema_name() -> std::borrow::Cow<'static, str> {
84        "PromptSource".into()
85    }
86
87    fn json_schema(generator: &mut SchemaGenerator) -> Schema {
88        let object_schema = generator.subschema_for::<PromptSourceObject>().to_value();
89        Schema::try_from(serde_json::json!({
90            "description": "Authored description of a prompt source — either a file path string or a typed text, file, or glob object.",
91            "oneOf": [
92                { "type": "string" },
93                object_schema
94            ]
95        }))
96        .expect("prompt source schema must be valid")
97    }
98}
99
100impl PromptSource {
101    pub fn file(path: impl Into<String>) -> Self {
102        Self::File { path: path.into() }
103    }
104
105    pub fn path(&self) -> Option<&str> {
106        match self {
107            Self::File { path } => Some(path.as_str()),
108            Self::Glob { pattern } => Some(pattern.as_str()),
109            Self::Text { .. } => None,
110        }
111    }
112}
113
114impl From<&str> for PromptSource {
115    fn from(value: &str) -> Self {
116        Self::file(value)
117    }
118}
119
120impl From<String> for PromptSource {
121    fn from(value: String) -> Self {
122        Self::file(value)
123    }
124}
125
126/// Validation failures raised while resolving [`PromptSource`] values into [`Prompt`]s.
127#[derive(Debug, Clone, PartialEq, Eq, Error)]
128pub enum PromptSourceError {
129    /// A glob pattern is syntactically invalid.
130    #[error("Invalid glob pattern '{pattern}': {error}")]
131    InvalidGlobPattern { pattern: String, error: String },
132
133    /// A prompt file or glob did not match any files on disk.
134    #[error("Prompt entry '{pattern}' resolves to no files")]
135    ZeroMatch { pattern: String },
136}
137
138impl Prompt {
139    pub fn text(str: &str) -> Self {
140        Self::Text(str.to_string())
141    }
142
143    pub fn file(path: &str) -> Self {
144        Self::File { path: path.to_string(), args: None, cwd: None }
145    }
146
147    pub fn file_with_args(path: &str, args: HashMap<String, String>) -> Self {
148        Self::File { path: path.to_string(), args: Some(args), cwd: None }
149    }
150
151    pub fn from_globs(patterns: Vec<String>, cwd: PathBuf) -> Self {
152        Self::PromptGlobs { patterns, cwd }
153    }
154
155    /// Resolve a slice of [`PromptSource`] declarations into runtime [`Prompt`] values.
156    ///
157    /// Validates that file paths and glob patterns produce at least one matching file
158    /// under `project_root`. Text sources pass through unchanged.
159    pub fn from_sources(
160        project_root: &Path,
161        sources: &[PromptSource],
162    ) -> std::result::Result<Vec<Prompt>, PromptSourceError> {
163        sources
164            .iter()
165            .map(|source| match source {
166                PromptSource::Text { text } => Ok(Prompt::text(text)),
167                PromptSource::File { path } => validate_prompt_file(project_root, path)
168                    .map(|()| Prompt::file(path).with_cwd(project_root.to_path_buf())),
169                PromptSource::Glob { pattern } => validate_prompt_glob(project_root, pattern)
170                    .map(|()| Prompt::from_globs(vec![pattern.clone()], project_root.to_path_buf())),
171            })
172            .collect()
173    }
174
175    pub fn with_cwd(self, cwd: PathBuf) -> Self {
176        match self {
177            Self::File { path, args, .. } => Self::File { path, args, cwd: Some(cwd) },
178            Self::PromptGlobs { patterns, .. } => Self::PromptGlobs { patterns, cwd },
179            Self::Text(_) | Self::McpInstructions(_) => self,
180        }
181    }
182
183    pub fn mcp_instructions(instructions: Vec<ServerInstructions>) -> Self {
184        Self::McpInstructions(instructions)
185    }
186
187    /// Resolve this `SystemPrompt` to a String
188    pub async fn build(&self) -> Result<String> {
189        match self {
190            Prompt::Text(text) => Ok(text.clone()),
191            Prompt::File { path, args, cwd } => {
192                let content = Self::resolve_file(&PathBuf::from(path)).await?;
193                let substituted = substitute_parameters(&content, args);
194                let expander = ShellExpander::new();
195                Self::expand_builtins(&substituted, cwd.as_deref(), &expander).await
196            }
197            Prompt::PromptGlobs { patterns, cwd } => Self::resolve_prompt_globs(patterns, cwd).await,
198            Prompt::McpInstructions(instructions) => Ok(format_mcp_instructions(instructions)),
199        }
200    }
201
202    /// Resolve multiple `SystemPrompts` and join them with double newlines
203    pub async fn build_all(prompts: &[Prompt]) -> Result<String> {
204        let mut parts = Vec::with_capacity(prompts.len());
205        for p in prompts {
206            let part = p.build().await?;
207            if !part.is_empty() {
208                parts.push(part);
209            }
210        }
211        Ok(parts.join("\n\n"))
212    }
213
214    async fn resolve_file(path: &Path) -> Result<String> {
215        fs::read_to_string(path)
216            .await
217            .map_err(|e| AgentError::IoError(format!("Failed to read file '{}': {e}", path.display())))
218    }
219
220    async fn resolve_prompt_globs(patterns: &[String], cwd: &Path) -> Result<String> {
221        let mut contents = Vec::new();
222        let expander = ShellExpander::new();
223
224        for pattern in patterns {
225            let full_pattern = if Path::new(pattern).is_absolute() {
226                pattern.clone()
227            } else {
228                cwd.join(pattern).to_string_lossy().to_string()
229            };
230
231            let paths = glob(&full_pattern)
232                .map_err(|e| AgentError::IoError(format!("Invalid glob pattern '{pattern}': {e}")))?;
233
234            let mut matched: Vec<PathBuf> = paths.filter_map(std::result::Result::ok).collect();
235            matched.sort();
236
237            for path in matched {
238                if path.is_file() {
239                    match fs::read_to_string(&path).await {
240                        Ok(content) => {
241                            let resolved = Self::expand_builtins(&content, Some(cwd), &expander).await?;
242                            contents.push(resolved);
243                        }
244                        Err(e) => {
245                            warn!("Failed to read prompt file '{}': {e}", path.display());
246                        }
247                    }
248                }
249            }
250        }
251
252        Ok(contents.join("\n\n"))
253    }
254
255    /// Expand `` !`command` `` shell-interpolation markers in prompt content.
256    ///
257    /// Thin wrapper around [`ShellExpander::expand`] that resolves `cwd` from
258    /// the process working directory when `None`.
259    async fn expand_builtins(content: &str, cwd: Option<&Path>, expander: &ShellExpander) -> Result<String> {
260        let cwd = match cwd {
261            Some(dir) => dir.to_path_buf(),
262            None => {
263                env::current_dir().map_err(|e| AgentError::IoError(format!("Failed to get current directory: {e}")))?
264            }
265        };
266        Ok(expander.expand(content, &cwd).await)
267    }
268}
269
270fn validate_prompt_file(project_root: &Path, path: &str) -> std::result::Result<(), PromptSourceError> {
271    let full_path = project_root.join(path);
272    if full_path.is_file() { Ok(()) } else { Err(PromptSourceError::ZeroMatch { pattern: path.to_string() }) }
273}
274
275fn validate_prompt_glob(project_root: &Path, pattern: &str) -> std::result::Result<(), PromptSourceError> {
276    let full_pattern = if Path::new(pattern).is_absolute() {
277        pattern.to_string()
278    } else {
279        project_root.join(pattern).to_string_lossy().to_string()
280    };
281
282    let has_file_match = glob(&full_pattern)
283        .map_err(|e| PromptSourceError::InvalidGlobPattern { pattern: pattern.to_string(), error: e.to_string() })?
284        .filter_map(std::result::Result::ok)
285        .any(|path| path.is_file());
286
287    if has_file_match { Ok(()) } else { Err(PromptSourceError::ZeroMatch { pattern: pattern.to_string() }) }
288}
289
290/// Format MCP instructions with XML tags for the system prompt.
291fn format_mcp_instructions(instructions: &[ServerInstructions]) -> String {
292    if instructions.is_empty() {
293        return String::new();
294    }
295
296    let mut parts = vec!["# MCP Server Instructions\n".to_string()];
297    parts.push("You are connected to the following MCP servers:\n".to_string());
298
299    for instr in instructions {
300        parts.push(format!("<mcp-server name=\"{}\">\n{}\n</mcp-server>\n", instr.server_name, instr.instructions));
301    }
302
303    parts.join("\n")
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single prompt, panicking on error.
    async fn render(prompt: &Prompt) -> String {
        prompt.build().await.unwrap()
    }

    /// Runs shell expansion on `content` with a fresh expander.
    async fn expand(content: &str, cwd: Option<&Path>) -> String {
        let expander = ShellExpander::new();
        Prompt::expand_builtins(content, cwd, &expander).await.unwrap()
    }

    /// Creates a glob prompt rooted at `dir` for the given patterns.
    fn globs(dir: &Path, patterns: &[&str]) -> Prompt {
        let owned: Vec<String> = patterns.iter().map(|pattern| (*pattern).to_string()).collect();
        Prompt::from_globs(owned, dir.to_path_buf())
    }

    #[tokio::test]
    async fn build_text_prompt() {
        assert_eq!(render(&Prompt::text("Hello, world!")).await, "Hello, world!");
    }

    #[tokio::test]
    async fn build_all_concatenates_prompts() {
        let joined = Prompt::build_all(&[Prompt::text("Part one"), Prompt::text("Part two")]).await.unwrap();
        assert_eq!(joined, "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn prompt_globs_resolves_single_file() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("AGENTS.md"), "# Instructions\nBe helpful").unwrap();

        assert_eq!(render(&globs(tmp.path(), &["AGENTS.md"])).await, "# Instructions\nBe helpful");
    }

    #[tokio::test]
    async fn prompt_globs_resolves_glob_pattern() {
        let tmp = tempfile::tempdir().unwrap();
        let rules = tmp.path().join(".aether/rules");
        std::fs::create_dir_all(&rules).unwrap();
        std::fs::write(rules.join("a-coding.md"), "Use Rust").unwrap();
        std::fs::write(rules.join("b-testing.md"), "Write tests").unwrap();

        let output = render(&globs(tmp.path(), &[".aether/rules/*.md"])).await;
        assert!(output.contains("Use Rust"));
        assert!(output.contains("Write tests"));
    }

    #[tokio::test]
    async fn prompt_globs_returns_empty_for_no_matches() {
        let tmp = tempfile::tempdir().unwrap();

        let output = render(&globs(tmp.path(), &["nonexistent*.md"])).await;
        assert!(output.is_empty());
    }

    #[tokio::test]
    async fn prompt_globs_supports_absolute_paths() {
        let tmp = tempfile::tempdir().unwrap();
        let absolute = tmp.path().join("rules.md");
        std::fs::write(&absolute, "Absolute rule").unwrap();

        let absolute_pattern = absolute.to_string_lossy().to_string();
        let prompt = Prompt::from_globs(vec![absolute_pattern], PathBuf::from("/tmp"));
        assert_eq!(render(&prompt).await, "Absolute rule");
    }

    #[tokio::test]
    async fn prompt_globs_concatenates_multiple_patterns() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("AGENTS.md"), "Agent instructions").unwrap();
        std::fs::write(tmp.path().join("SYSTEM.md"), "System prompt").unwrap();

        let output = render(&globs(tmp.path(), &["AGENTS.md", "SYSTEM.md"])).await;
        assert!(output.contains("Agent instructions"));
        assert!(output.contains("System prompt"));
        assert!(output.contains("\n\n"));
    }

    #[tokio::test]
    async fn build_all_skips_empty_parts() {
        let parts = [Prompt::text("Part one"), Prompt::text(""), Prompt::text("Part two")];
        assert_eq!(Prompt::build_all(&parts).await.unwrap(), "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn expand_builtins_no_op_without_marker() {
        let plain = "Just some plain content with no directives";
        assert_eq!(expand(plain, None).await, plain);
    }

    #[tokio::test]
    async fn expand_builtins_runs_shell_command() {
        assert_eq!(expand("branch: !`echo main`", None).await, "branch: main");
    }

    #[tokio::test]
    async fn expand_builtins_runs_command_in_cwd() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("sentinel.txt"), "").unwrap();

        let result = expand("files: !`ls`", Some(tmp.path())).await;
        assert!(result.contains("sentinel.txt"), "expected sentinel.txt in output: {result}");
    }

    #[tokio::test]
    async fn expand_builtins_handles_multiple_commands() {
        assert_eq!(expand("a=!`echo one`, b=!`echo two`", None).await, "a=one, b=two");
    }

    #[tokio::test]
    async fn expand_builtins_substitutes_empty_on_failure() {
        assert_eq!(expand("before !`exit 1` after", None).await, "before  after");
    }

    #[tokio::test]
    async fn expand_builtins_trims_trailing_whitespace() {
        assert_eq!(expand("!`printf 'hi\\n\\n'`", None).await, "hi");
    }

    #[tokio::test]
    async fn prompt_globs_expands_shell_in_file() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("AGENTS.md"), "Instructions\n\nbranch: !`echo main`\n\nRules").unwrap();

        let output = render(&globs(tmp.path(), &["AGENTS.md"])).await;
        assert!(output.contains("Instructions"));
        assert!(output.contains("branch: main"));
        assert!(output.contains("Rules"));
        assert!(!output.contains("!`"));
    }
}