Skip to main content

aether_core/core/
prompt.rs

1use crate::core::{AgentError, Result};
2use glob::glob;
3use schemars::{JsonSchema, Schema, SchemaGenerator};
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::{BTreeMap, HashMap};
6use std::path::{Path, PathBuf};
7use thiserror::Error;
8use tokio::fs;
9use tracing::warn;
10use utils::shell_expander::ShellExpander;
11use utils::substitution::substitute_parameters;
12use utils::variables::VarError;
13use utils::{PathOrObject, ResourcePath, is_false, string_or_object_schema};
14
/// A single piece of prompt content that is turned into text by [`Prompt::build`].
#[derive(Debug, Clone, PartialEq)]
pub enum Prompt {
    /// Literal text, returned unchanged by [`Prompt::build`].
    Text(String),
    /// A prompt file that is read from disk each time it is built.
    File {
        /// Path of the file to read.
        path: PathBuf,
        /// Optional parameters passed to `substitute_parameters` together with the file contents.
        args: Option<HashMap<String, String>>,
        /// Directory handed to the `ShellExpander` when expanding the substituted contents.
        cwd: PathBuf,
    },
    /// MCP server instructions keyed by server name. `BTreeMap` gives a
    /// stable render order, which is useful for prompt-caching.
    McpInstructions(BTreeMap<String, String>),
}
27
/// Authored description of a prompt source — text, a file path, or a glob pattern.
///
/// Used by configuration layers to declare prompts before resolution. Convert into
/// runtime [`Prompt`] values with [`Prompt::from_sources`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptSource {
    /// Literal prompt text; resolves to a single [`Prompt::Text`].
    Text { text: String },
    /// A single prompt file. When `optional` is true, a missing file or an undefined
    /// `${VAR}` in `path` causes the entry to be skipped instead of failing.
    File { path: ResourcePath, optional: bool },
    /// Every regular file matching `pattern`, in sorted order. When `optional` is true,
    /// zero matches or an undefined `${VAR}` causes the entry to be skipped; an invalid
    /// pattern is still an error.
    Glob { pattern: ResourcePath, optional: bool },
}
38
39impl PromptSource {
40    pub fn file(path: impl Into<ResourcePath>) -> Self {
41        Self::File { path: path.into(), optional: false }
42    }
43
44    pub fn glob(pattern: impl Into<ResourcePath>) -> Self {
45        Self::Glob { pattern: pattern.into(), optional: false }
46    }
47
48    #[must_use]
49    pub fn optional(self) -> Self {
50        match self {
51            Self::File { path, .. } => Self::File { path, optional: true },
52            Self::Glob { pattern, .. } => Self::Glob { pattern, optional: true },
53            Self::Text { .. } => self,
54        }
55    }
56
57    /// The authored path/pattern string, if this source has one.
58    pub fn path(&self) -> Option<&str> {
59        match self {
60            Self::File { path, .. } => Some(path.as_authored()),
61            Self::Glob { pattern, .. } => Some(pattern.as_authored()),
62            Self::Text { .. } => None,
63        }
64    }
65
66    pub fn is_optional(&self) -> bool {
67        match self {
68            Self::File { optional, .. } | Self::Glob { optional, .. } => *optional,
69            Self::Text { .. } => false,
70        }
71    }
72}
73
74impl From<&str> for PromptSource {
75    fn from(value: &str) -> Self {
76        Self::file(value)
77    }
78}
79
80impl From<String> for PromptSource {
81    fn from(value: String) -> Self {
82        Self::file(value)
83    }
84}
85
86impl From<PromptSourceObject> for PromptSource {
87    fn from(object: PromptSourceObject) -> Self {
88        match object {
89            PromptSourceObject::Text { text } => Self::Text { text },
90            PromptSourceObject::File { path, optional } => Self::File { path, optional },
91            PromptSourceObject::Glob { pattern, optional } => Self::Glob { pattern, optional },
92        }
93    }
94}
95
96impl<'de> Deserialize<'de> for PromptSource {
97    fn deserialize<T: Deserializer<'de>>(deserializer: T) -> std::result::Result<Self, T::Error> {
98        Ok(match PathOrObject::<PromptSourceObject>::deserialize(deserializer)? {
99            PathOrObject::Path(path) => Self::File { path, optional: false },
100            PathOrObject::Object(object) => object.into(),
101        })
102    }
103}
104
105impl Serialize for PromptSource {
106    fn serialize<T: Serializer>(&self, serializer: T) -> std::result::Result<T::Ok, T::Error> {
107        match self {
108            Self::File { path, optional: false } => path.serialize(serializer),
109            Self::File { path, optional } => {
110                Serialize::serialize(&PromptSourceObject::File { path: path.clone(), optional: *optional }, serializer)
111            }
112            Self::Text { text } => Serialize::serialize(&PromptSourceObject::Text { text: text.clone() }, serializer),
113            Self::Glob { pattern, optional } => Serialize::serialize(
114                &PromptSourceObject::Glob { pattern: pattern.clone(), optional: *optional },
115                serializer,
116            ),
117        }
118    }
119}
120
121impl JsonSchema for PromptSource {
122    fn schema_name() -> std::borrow::Cow<'static, str> {
123        "PromptSource".into()
124    }
125
126    fn json_schema(generator: &mut SchemaGenerator) -> Schema {
127        string_or_object_schema(
128            include_str!("../docs/prompt_source.md"),
129            &generator.subschema_for::<PromptSourceObject>().to_value(),
130        )
131    }
132}
133
// Tagged object form of a `PromptSource`, used for (de)serialization and JSON-schema
// generation. `PromptSource` converts from this type and serializes through it.
//
// NOTE: the `///` comments on the fields below are read by the schemars derive as schema
// descriptions, so they are user-facing. Keep explanatory notes in plain `//` comments.
#[derive(schemars::JsonSchema, serde::Deserialize, serde::Serialize)]
#[serde(tag = "type", rename_all = "camelCase", deny_unknown_fields)]
enum PromptSourceObject {
    // Written as `{"type": "text", "text": ...}`.
    Text {
        /// Literal prompt text included verbatim.
        text: String,
    },
    // Written as `{"type": "file", "path": ...}`; `optional` defaults to false when absent
    // and is left out when serializing a false value.
    File {
        /// Path to a prompt file, resolved as a resource path.
        path: ResourcePath,
        /// When true, a missing file is skipped instead of raising an error.
        #[serde(default, skip_serializing_if = "is_false")]
        optional: bool,
    },
    // Written as `{"type": "glob", "pattern": ...}`, with the same `optional` handling as `File`.
    Glob {
        /// Glob pattern matching prompt files, resolved as a resource path.
        pattern: ResourcePath,
        /// When true, a zero-match glob is skipped instead of raising an error.
        #[serde(default, skip_serializing_if = "is_false")]
        optional: bool,
    },
}
156
/// Errors raised while resolving [`PromptSource`] values into [`Prompt`]s.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum PromptSourceError {
    /// A glob pattern is syntactically invalid.
    ///
    /// This is reported even for optional glob sources.
    #[error("Invalid glob pattern '{pattern}': {error}")]
    InvalidGlobPattern {
        /// The pattern as authored (before `${VAR}` resolution).
        pattern: String,
        /// The glob parser's error message.
        error: String,
    },

    /// A prompt file does not exist on disk.
    #[error("Prompt file '{path}' does not exist")]
    Missing {
        /// The path as authored (before `${VAR}` resolution).
        path: String,
    },

    /// A prompt glob matched no files.
    #[error("Prompt glob '{pattern}' matched no files")]
    ZeroMatch {
        /// The pattern as authored (before `${VAR}` resolution).
        pattern: String,
    },

    /// A `${VAR}` reference in a prompt path could not be resolved.
    #[error("Prompt entry '{pattern}' references undefined variable '{variable}'")]
    UnresolvedVariable {
        /// The authored path (file sources) or pattern (glob sources).
        pattern: String,
        /// Name of the variable reported as not found.
        variable: String,
    },
}
176
177impl Prompt {
178    pub fn text(str: &str) -> Self {
179        Self::Text(str.to_string())
180    }
181
182    pub fn file(path: impl Into<PathBuf>, cwd: impl Into<PathBuf>) -> Self {
183        Self::File { path: path.into(), args: None, cwd: cwd.into() }
184    }
185
186    /// Resolve a slice of [`PromptSource`] declarations into runtime [`Prompt`] values.
187    pub fn from_sources(
188        workspace_root: &Path,
189        sources: &[PromptSource],
190    ) -> std::result::Result<Vec<Prompt>, PromptSourceError> {
191        let mut prompts = Vec::new();
192        for source in sources {
193            if let PromptSource::Text { text } = source {
194                prompts.push(Prompt::text(text));
195                continue;
196            }
197            match resolve_source_files(workspace_root, source) {
198                Ok(paths) => {
199                    for path in paths {
200                        prompts.push(Prompt::file(path, workspace_root.to_path_buf()));
201                    }
202                }
203                Err(PromptSourceError::Missing { .. }) if source.is_optional() => {}
204                Err(PromptSourceError::UnresolvedVariable { variable, .. }) if source.is_optional() => {
205                    warn!(
206                        "Skipping optional prompt entry '{}': variable '{variable}' is not defined",
207                        source.path().unwrap_or_default()
208                    );
209                }
210                Err(error) => return Err(error),
211            }
212        }
213        Ok(prompts)
214    }
215
216    /// Resolve this `SystemPrompt` to a String
217    pub async fn build(&self) -> Result<String> {
218        match self {
219            Prompt::Text(text) => Ok(text.clone()),
220            Prompt::File { path, args, cwd } => {
221                let content = Self::resolve_file(path).await?;
222                let substituted = substitute_parameters(&content, args);
223                let expander = ShellExpander::new();
224                Ok(expander.expand(&substituted, cwd).await)
225            }
226            Prompt::McpInstructions(instructions) => Ok(format_mcp_instructions(instructions)),
227        }
228    }
229
230    /// Resolve multiple `SystemPrompts` and join them with double newlines
231    pub async fn build_all(prompts: &[Prompt]) -> Result<String> {
232        let mut parts = Vec::with_capacity(prompts.len());
233        for p in prompts {
234            let part = p.build().await?;
235            if !part.is_empty() {
236                parts.push(part);
237            }
238        }
239        Ok(parts.join("\n\n"))
240    }
241
242    async fn resolve_file(path: &Path) -> Result<String> {
243        fs::read_to_string(path)
244            .await
245            .map_err(|e| AgentError::IoError(format!("Failed to read file '{}': {e}", path.display())))
246    }
247}
248
249fn resolve_source_files(
250    workspace_root: &Path,
251    source: &PromptSource,
252) -> std::result::Result<Vec<PathBuf>, PromptSourceError> {
253    match source {
254        PromptSource::Text { .. } => Ok(Vec::new()),
255        PromptSource::File { path, .. } => {
256            let full_path = resolve_path(path, workspace_root)?;
257            if full_path.is_file() {
258                Ok(vec![full_path])
259            } else {
260                Err(PromptSourceError::Missing { path: path.as_authored().to_string() })
261            }
262        }
263        PromptSource::Glob { pattern, optional } => {
264            let full_pattern = resolve_path(pattern, workspace_root)?;
265            let mut paths: Vec<PathBuf> = glob(&full_pattern.to_string_lossy())
266                .map_err(|e| PromptSourceError::InvalidGlobPattern {
267                    pattern: pattern.as_authored().to_string(),
268                    error: e.to_string(),
269                })?
270                .filter_map(std::result::Result::ok)
271                .filter(|path| path.is_file())
272                .collect();
273            paths.sort();
274            if paths.is_empty() && !*optional {
275                Err(PromptSourceError::ZeroMatch { pattern: pattern.as_authored().to_string() })
276            } else {
277                Ok(paths)
278            }
279        }
280    }
281}
282
283fn resolve_path(path: &ResourcePath, workspace_root: &Path) -> std::result::Result<PathBuf, PromptSourceError> {
284    path.resolve(workspace_root).map_err(|VarError::NotFound(variable)| PromptSourceError::UnresolvedVariable {
285        pattern: path.as_authored().to_string(),
286        variable,
287    })
288}
289
290pub struct PromptCache {
291    prompts: Vec<Prompt>,
292    entries: Vec<(Prompt, String)>,
293}
294
295impl PromptCache {
296    pub fn new(mut prompts: Vec<Prompt>) -> Self {
297        if !prompts.iter().any(|p| matches!(p, Prompt::McpInstructions(_))) {
298            prompts.push(Prompt::McpInstructions(BTreeMap::new()));
299        }
300        Self { prompts, entries: Vec::new() }
301    }
302
303    pub fn update_mcp_instruction(&mut self, server: String, body: Option<String>) {
304        for prompt in &mut self.prompts {
305            if let Prompt::McpInstructions(map) = prompt {
306                match body {
307                    Some(text) => {
308                        map.insert(server, text);
309                    }
310                    None => {
311                        map.remove(&server);
312                    }
313                }
314                return;
315            }
316        }
317    }
318
319    pub async fn render(&mut self) -> Result<String> {
320        self.entries.truncate(self.prompts.len());
321        let mut rendered_prompt = String::new();
322        for i in 0..self.prompts.len() {
323            let prompt = &self.prompts[i];
324            match self.entries.get_mut(i) {
325                Some((cached, _)) if *cached == *prompt => {}
326                Some(entry) => *entry = (prompt.clone(), prompt.build().await?),
327                None => self.entries.push((prompt.clone(), prompt.build().await?)),
328            }
329
330            let (_, body) = &self.entries[i];
331            if !body.is_empty() {
332                if !rendered_prompt.is_empty() {
333                    rendered_prompt.push_str("\n\n");
334                }
335                rendered_prompt.push_str(body);
336            }
337        }
338        Ok(rendered_prompt)
339    }
340}
341
/// Render MCP server instructions as a Markdown section. Each server gets one
/// `<mcp-server name="…">` element, in the map's (sorted-key) order.
///
/// Returns an empty string when there are no instructions.
fn format_mcp_instructions(instructions: &BTreeMap<String, String>) -> String {
    if instructions.is_empty() {
        return String::new();
    }

    let mut sections = vec![
        String::from("# MCP Server Instructions\n"),
        String::from("You are connected to the following MCP servers:\n"),
    ];
    sections.extend(
        instructions
            .iter()
            .map(|(name, text)| format!("<mcp-server name=\"{name}\">\n{text}\n</mcp-server>\n")),
    );
    sections.join("\n")
}
357
#[cfg(test)]
mod tests {
    use std::fs::{create_dir_all, write};

    use super::*;
    use crate::testing::mcp_instructions as instructions;

    #[tokio::test]
    async fn build_text_prompt() {
        let prompt = Prompt::text("Hello, world!");
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "Hello, world!");
    }

    #[tokio::test]
    async fn build_all_concatenates_prompts() {
        let prompts = vec![Prompt::text("Part one"), Prompt::text("Part two")];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert_eq!(result, "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn build_all_concatenates_multiple_files() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("AGENTS.md"), "Agent instructions").unwrap();
        write(dir.path().join("SYSTEM.md"), "System prompt").unwrap();

        let prompts = vec![
            Prompt::file(dir.path().join("AGENTS.md"), dir.path()),
            Prompt::file(dir.path().join("SYSTEM.md"), dir.path()),
        ];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert!(result.contains("Agent instructions"));
        assert!(result.contains("System prompt"));
        assert!(result.contains("\n\n"));
    }

    #[tokio::test]
    async fn build_all_skips_empty_parts() {
        let prompts = vec![Prompt::text("Part one"), Prompt::text(""), Prompt::text("Part two")];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert_eq!(result, "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn prompt_cache_render_matches_build_all_on_first_render() {
        let prompts = vec![
            Prompt::text("first"),
            Prompt::McpInstructions(instructions(&[("srv", "body")])),
            Prompt::text("last"),
        ];
        let expected = Prompt::build_all(&prompts).await.unwrap();
        let mut cache = PromptCache::new(prompts);
        assert_eq!(cache.render().await.unwrap(), expected);
    }

    #[tokio::test]
    async fn prompt_cache_reuses_unchanged_slots() {
        use std::fs::remove_file;

        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("AGENTS.md"), "cached body").unwrap();
        let mut cache = PromptCache::new(vec![
            Prompt::file(dir.path().join("AGENTS.md"), dir.path()),
            Prompt::McpInstructions(BTreeMap::new()),
        ]);

        cache.render().await.unwrap();

        // Remove the source file to prove we cached things
        remove_file(dir.path().join("AGENTS.md")).unwrap();
        cache.update_mcp_instruction("srv".into(), Some("instr".into()));

        let rendered = cache.render().await.unwrap();
        assert!(rendered.contains("cached body"));
        assert!(rendered.contains("instr"));
    }

    #[tokio::test]
    async fn prompt_cache_rebuilds_slot_after_mcp_instruction_removed() {
        let mut cache = PromptCache::new(vec![Prompt::text("base")]);

        cache.update_mcp_instruction("srv".into(), Some("instr".into()));
        assert!(cache.render().await.unwrap().contains("instr"));

        // Removing the only server empties the MCP slot, which must then be dropped.
        cache.update_mcp_instruction("srv".into(), None);
        assert_eq!(cache.render().await.unwrap(), "base");
    }

    #[tokio::test]
    async fn prompt_cache_empty_renders_empty() {
        assert_eq!(PromptCache::new(vec![]).render().await.unwrap(), "");
    }

    #[tokio::test]
    async fn prompt_cache_drops_empty_slots() {
        // Both the explicit empty text slot and the implicit (empty) MCP slot must be skipped.
        let mut cache = PromptCache::new(vec![Prompt::text("a"), Prompt::text(""), Prompt::text("b")]);
        assert_eq!(cache.render().await.unwrap(), "a\n\nb");
    }

    #[test]
    fn format_mcp_instructions_empty_map_renders_nothing() {
        assert_eq!(format_mcp_instructions(&BTreeMap::new()), "");
    }

    #[test]
    fn format_mcp_instructions_wraps_each_server_in_tags() {
        let map = BTreeMap::from([("srv".to_string(), "body".to_string())]);
        assert_eq!(
            format_mcp_instructions(&map),
            "# MCP Server Instructions\n\nYou are connected to the following MCP servers:\n\n<mcp-server name=\"srv\">\nbody\n</mcp-server>\n"
        );
    }

    #[tokio::test]
    async fn build_file_expands_shell_commands() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("AGENTS.md"), "Instructions\n\nbranch: !`echo main`\n\nRules").unwrap();

        let prompt = Prompt::file(dir.path().join("AGENTS.md"), dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.contains("Instructions"));
        assert!(result.contains("branch: main"));
        assert!(result.contains("Rules"));
        assert!(!result.contains("!`"));
    }

    #[tokio::test]
    async fn build_file_runs_shell_in_cwd() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("sentinel.txt"), "").unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "files: !`ls`").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.contains("sentinel.txt"), "expected sentinel.txt in output: {result}");
    }

    #[tokio::test]
    async fn build_file_handles_multiple_commands() {
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "a=!`echo one`, b=!`echo two`").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "a=one, b=two");
    }

    #[tokio::test]
    async fn build_file_substitutes_empty_on_failure() {
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "before !`exit 1` after").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "before  after");
    }

    #[tokio::test]
    async fn build_file_trims_trailing_whitespace() {
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "!`printf 'hi\\n\\n'`").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "hi");
    }

    #[test]
    fn optional_file_source_skips_missing_file() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![PromptSource::file("EXISTS.md"), PromptSource::file("MISSING.md").optional()];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn optional_glob_source_skips_zero_matches() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![PromptSource::file("EXISTS.md"), PromptSource::glob("nonexistent*.md").optional()];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn required_glob_source_expands_to_one_prompt_per_match() {
        let dir = tempfile::tempdir().unwrap();
        let rules_dir = dir.path().join(".aether/rules");
        create_dir_all(&rules_dir).unwrap();
        write(rules_dir.join("a.md"), "a").unwrap();
        write(rules_dir.join("b.md"), "b").unwrap();

        let sources = vec![PromptSource::glob(".aether/rules/*.md")];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 2);
    }

    #[test]
    fn required_glob_source_with_no_matches_errors() {
        let dir = tempfile::tempdir().unwrap();
        let sources = vec![PromptSource::glob("nonexistent*.md")];
        let err = Prompt::from_sources(dir.path(), &sources).unwrap_err();
        assert!(matches!(err, PromptSourceError::ZeroMatch { .. }));
    }

    #[test]
    fn optional_glob_source_still_errors_on_invalid_pattern() {
        let dir = tempfile::tempdir().unwrap();
        let sources = vec![PromptSource::glob("[invalid").optional()];
        let err = Prompt::from_sources(dir.path(), &sources).unwrap_err();
        assert!(matches!(err, PromptSourceError::InvalidGlobPattern { .. }));
    }

    #[test]
    fn optional_file_source_skips_unresolved_variable() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![
            PromptSource::file("EXISTS.md"),
            PromptSource::file("${DEFINITELY_NOT_SET_VAR_PROMPT_FILE}/foo.md").optional(),
        ];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn optional_glob_source_skips_unresolved_variable() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![
            PromptSource::file("EXISTS.md"),
            PromptSource::glob("${DEFINITELY_NOT_SET_VAR_PROMPT_GLOB}/*.md").optional(),
        ];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn required_file_source_errors_on_unresolved_variable() {
        let dir = tempfile::tempdir().unwrap();
        let sources = vec![PromptSource::file("${DEFINITELY_NOT_SET_VAR_PROMPT_REQ}/foo.md")];
        let err = Prompt::from_sources(dir.path(), &sources).unwrap_err();
        assert!(matches!(err, PromptSourceError::UnresolvedVariable { .. }));
    }

    #[test]
    fn prompt_source_string_shorthand_is_required_file() {
        let source: PromptSource = serde_json::from_str(r#""SYSTEM.md""#).unwrap();
        assert_eq!(source, PromptSource::file("SYSTEM.md"));
    }

    #[test]
    fn optional_prompt_source_serializes_as_typed_object() {
        let source = PromptSource::file("${WORKSPACE}/AGENTS.md").optional();
        let value = serde_json::to_value(&source).unwrap();
        assert_eq!(value, serde_json::json!({"type":"file","path":"${WORKSPACE}/AGENTS.md","optional":true}));

        // Non-optional file stays as string shorthand
        let source = PromptSource::file("SYSTEM.md");
        let value = serde_json::to_value(&source).unwrap();
        assert_eq!(value, serde_json::json!("SYSTEM.md"));
    }

    #[test]
    fn optional_prompt_source_deserializes_from_typed_object() {
        let source: PromptSource =
            serde_json::from_str(r#"{"type":"file","path":"${WORKSPACE}/AGENTS.md","optional":true}"#).unwrap();
        assert_eq!(source, PromptSource::file("${WORKSPACE}/AGENTS.md").optional());
    }

    #[test]
    fn optional_glob_source_deserializes_from_typed_object() {
        let source: PromptSource =
            serde_json::from_str(r#"{"type":"glob","pattern":"${WORKSPACE}/.aether/rules/*.md","optional":true}"#)
                .unwrap();
        assert_eq!(source, PromptSource::glob("${WORKSPACE}/.aether/rules/*.md").optional());
    }
}