// aether_core/core/prompt.rs

1use crate::core::{AgentError, Result};
2use glob::glob;
3use schemars::{JsonSchema, Schema, SchemaGenerator};
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::{BTreeMap, HashMap};
6use std::path::{Path, PathBuf};
7use thiserror::Error;
8use tokio::fs;
9use tracing::warn;
10use utils::shell_expander::ShellExpander;
11use utils::substitution::substitute_parameters;
12use utils::variables::VarError;
13use utils::{PathOrObject, ResourcePath, is_false, string_or_object_schema};
14
/// A resolved prompt fragment that can be rendered to text with [`Prompt::build`].
#[derive(Debug, Clone, PartialEq)]
pub enum Prompt {
    /// Literal text, rendered verbatim.
    Text(String),
    /// A prompt stored in a file on disk.
    File {
        /// Location of the file whose contents are read at build time.
        path: PathBuf,
        /// Optional parameter values handed to `substitute_parameters` when building.
        args: Option<HashMap<String, String>>,
        /// Directory handed to the shell expander when building.
        cwd: PathBuf,
    },
    /// MCP server instructions keyed by server name.
    ///
    /// A `BTreeMap` iterates in key order, so the rendered output is stable
    /// across runs, which is useful for prompt-caching.
    McpInstructions(BTreeMap<String, String>),
}
27
28/// Authored description of a prompt source — text, a file path, or a glob pattern.
29///
30/// Used by configuration layers to declare prompts before resolution. Convert into
31/// runtime [`Prompt`] values with [`Prompt::from_sources`].
32#[derive(Debug, Clone, PartialEq, Eq)]
33pub enum PromptSource {
34    Text { text: String },
35    File { path: ResourcePath, optional: bool },
36    Glob { pattern: ResourcePath, optional: bool },
37}
38
39impl PromptSource {
40    pub fn file(path: impl Into<ResourcePath>) -> Self {
41        Self::File { path: path.into(), optional: false }
42    }
43
44    pub fn glob(pattern: impl Into<ResourcePath>) -> Self {
45        Self::Glob { pattern: pattern.into(), optional: false }
46    }
47
48    #[must_use]
49    pub fn optional(self) -> Self {
50        match self {
51            Self::File { path, .. } => Self::File { path, optional: true },
52            Self::Glob { pattern, .. } => Self::Glob { pattern, optional: true },
53            Self::Text { .. } => self,
54        }
55    }
56
57    /// The authored path/pattern string, if this source has one.
58    pub fn path(&self) -> Option<&str> {
59        match self {
60            Self::File { path, .. } => Some(path.as_authored()),
61            Self::Glob { pattern, .. } => Some(pattern.as_authored()),
62            Self::Text { .. } => None,
63        }
64    }
65
66    pub fn is_optional(&self) -> bool {
67        match self {
68            Self::File { optional, .. } | Self::Glob { optional, .. } => *optional,
69            Self::Text { .. } => false,
70        }
71    }
72}
73
74impl From<&str> for PromptSource {
75    fn from(value: &str) -> Self {
76        Self::file(value)
77    }
78}
79
80impl From<String> for PromptSource {
81    fn from(value: String) -> Self {
82        Self::file(value)
83    }
84}
85
86impl From<PromptSourceObject> for PromptSource {
87    fn from(object: PromptSourceObject) -> Self {
88        match object {
89            PromptSourceObject::Text { text } => Self::Text { text },
90            PromptSourceObject::File { path, optional } => Self::File { path, optional },
91            PromptSourceObject::Glob { pattern, optional } => Self::Glob { pattern, optional },
92        }
93    }
94}
95
96impl<'de> Deserialize<'de> for PromptSource {
97    fn deserialize<T: Deserializer<'de>>(deserializer: T) -> std::result::Result<Self, T::Error> {
98        Ok(match PathOrObject::<PromptSourceObject>::deserialize(deserializer)? {
99            PathOrObject::Path(path) => Self::File { path, optional: false },
100            PathOrObject::Object(object) => object.into(),
101        })
102    }
103}
104
105impl Serialize for PromptSource {
106    fn serialize<T: Serializer>(&self, serializer: T) -> std::result::Result<T::Ok, T::Error> {
107        match self {
108            Self::File { path, optional: false } => path.serialize(serializer),
109            Self::File { path, optional } => {
110                Serialize::serialize(&PromptSourceObject::File { path: path.clone(), optional: *optional }, serializer)
111            }
112            Self::Text { text } => Serialize::serialize(&PromptSourceObject::Text { text: text.clone() }, serializer),
113            Self::Glob { pattern, optional } => Serialize::serialize(
114                &PromptSourceObject::Glob { pattern: pattern.clone(), optional: *optional },
115                serializer,
116            ),
117        }
118    }
119}
120
121impl JsonSchema for PromptSource {
122    fn schema_name() -> std::borrow::Cow<'static, str> {
123        "PromptSource".into()
124    }
125
126    fn json_schema(generator: &mut SchemaGenerator) -> Schema {
127        string_or_object_schema(
128            "Authored description of a prompt source — either a file path string or a typed text, file, or glob object.",
129            &generator.subschema_for::<PromptSourceObject>().to_value(),
130        )
131    }
132}
133
134#[derive(schemars::JsonSchema, serde::Deserialize, serde::Serialize)]
135#[serde(tag = "type", rename_all = "camelCase", deny_unknown_fields)]
136enum PromptSourceObject {
137    Text {
138        text: String,
139    },
140    File {
141        path: ResourcePath,
142        #[serde(default, skip_serializing_if = "is_false")]
143        optional: bool,
144    },
145    Glob {
146        pattern: ResourcePath,
147        #[serde(default, skip_serializing_if = "is_false")]
148        optional: bool,
149    },
150}
151
152/// Errors raised while resolving [`PromptSource`] values into [`Prompt`]s.
153#[derive(Debug, Clone, PartialEq, Eq, Error)]
154pub enum PromptSourceError {
155    /// A glob pattern is syntactically invalid.
156    #[error("Invalid glob pattern '{pattern}': {error}")]
157    InvalidGlobPattern { pattern: String, error: String },
158
159    /// A prompt file does not exist on disk.
160    #[error("Prompt file '{path}' does not exist")]
161    Missing { path: String },
162
163    /// A prompt glob matched no files.
164    #[error("Prompt glob '{pattern}' matched no files")]
165    ZeroMatch { pattern: String },
166
167    /// A `${VAR}` reference in a prompt path could not be resolved.
168    #[error("Prompt entry '{pattern}' references undefined variable '{variable}'")]
169    UnresolvedVariable { pattern: String, variable: String },
170}
171
172impl Prompt {
173    pub fn text(str: &str) -> Self {
174        Self::Text(str.to_string())
175    }
176
177    pub fn file(path: impl Into<PathBuf>, cwd: impl Into<PathBuf>) -> Self {
178        Self::File { path: path.into(), args: None, cwd: cwd.into() }
179    }
180
181    /// Resolve a slice of [`PromptSource`] declarations into runtime [`Prompt`] values.
182    pub fn from_sources(
183        workspace_root: &Path,
184        sources: &[PromptSource],
185    ) -> std::result::Result<Vec<Prompt>, PromptSourceError> {
186        let mut prompts = Vec::new();
187        for source in sources {
188            if let PromptSource::Text { text } = source {
189                prompts.push(Prompt::text(text));
190                continue;
191            }
192            match resolve_source_files(workspace_root, source) {
193                Ok(paths) => {
194                    for path in paths {
195                        prompts.push(Prompt::file(path, workspace_root.to_path_buf()));
196                    }
197                }
198                Err(PromptSourceError::Missing { .. }) if source.is_optional() => {}
199                Err(PromptSourceError::UnresolvedVariable { variable, .. }) if source.is_optional() => {
200                    warn!(
201                        "Skipping optional prompt entry '{}': variable '{variable}' is not defined",
202                        source.path().unwrap_or_default()
203                    );
204                }
205                Err(error) => return Err(error),
206            }
207        }
208        Ok(prompts)
209    }
210
211    /// Resolve this `SystemPrompt` to a String
212    pub async fn build(&self) -> Result<String> {
213        match self {
214            Prompt::Text(text) => Ok(text.clone()),
215            Prompt::File { path, args, cwd } => {
216                let content = Self::resolve_file(path).await?;
217                let substituted = substitute_parameters(&content, args);
218                let expander = ShellExpander::new();
219                Ok(expander.expand(&substituted, cwd).await)
220            }
221            Prompt::McpInstructions(instructions) => Ok(format_mcp_instructions(instructions)),
222        }
223    }
224
225    /// Resolve multiple `SystemPrompts` and join them with double newlines
226    pub async fn build_all(prompts: &[Prompt]) -> Result<String> {
227        let mut parts = Vec::with_capacity(prompts.len());
228        for p in prompts {
229            let part = p.build().await?;
230            if !part.is_empty() {
231                parts.push(part);
232            }
233        }
234        Ok(parts.join("\n\n"))
235    }
236
237    async fn resolve_file(path: &Path) -> Result<String> {
238        fs::read_to_string(path)
239            .await
240            .map_err(|e| AgentError::IoError(format!("Failed to read file '{}': {e}", path.display())))
241    }
242}
243
244fn resolve_source_files(
245    workspace_root: &Path,
246    source: &PromptSource,
247) -> std::result::Result<Vec<PathBuf>, PromptSourceError> {
248    match source {
249        PromptSource::Text { .. } => Ok(Vec::new()),
250        PromptSource::File { path, .. } => {
251            let full_path = resolve_path(path, workspace_root)?;
252            if full_path.is_file() {
253                Ok(vec![full_path])
254            } else {
255                Err(PromptSourceError::Missing { path: path.as_authored().to_string() })
256            }
257        }
258        PromptSource::Glob { pattern, optional } => {
259            let full_pattern = resolve_path(pattern, workspace_root)?;
260            let mut paths: Vec<PathBuf> = glob(&full_pattern.to_string_lossy())
261                .map_err(|e| PromptSourceError::InvalidGlobPattern {
262                    pattern: pattern.as_authored().to_string(),
263                    error: e.to_string(),
264                })?
265                .filter_map(std::result::Result::ok)
266                .filter(|path| path.is_file())
267                .collect();
268            paths.sort();
269            if paths.is_empty() && !*optional {
270                Err(PromptSourceError::ZeroMatch { pattern: pattern.as_authored().to_string() })
271            } else {
272                Ok(paths)
273            }
274        }
275    }
276}
277
278fn resolve_path(path: &ResourcePath, workspace_root: &Path) -> std::result::Result<PathBuf, PromptSourceError> {
279    path.resolve(workspace_root).map_err(|VarError::NotFound(variable)| PromptSourceError::UnresolvedVariable {
280        pattern: path.as_authored().to_string(),
281        variable,
282    })
283}
284
285pub struct PromptCache {
286    prompts: Vec<Prompt>,
287    entries: Vec<(Prompt, String)>,
288}
289
290impl PromptCache {
291    pub fn new(mut prompts: Vec<Prompt>) -> Self {
292        if !prompts.iter().any(|p| matches!(p, Prompt::McpInstructions(_))) {
293            prompts.push(Prompt::McpInstructions(BTreeMap::new()));
294        }
295        Self { prompts, entries: Vec::new() }
296    }
297
298    pub fn update_mcp_instruction(&mut self, server: String, body: Option<String>) {
299        for prompt in &mut self.prompts {
300            if let Prompt::McpInstructions(map) = prompt {
301                match body {
302                    Some(text) => {
303                        map.insert(server, text);
304                    }
305                    None => {
306                        map.remove(&server);
307                    }
308                }
309                return;
310            }
311        }
312    }
313
314    pub async fn render(&mut self) -> Result<String> {
315        self.entries.truncate(self.prompts.len());
316        let mut rendered_prompt = String::new();
317        for i in 0..self.prompts.len() {
318            let prompt = &self.prompts[i];
319            match self.entries.get_mut(i) {
320                Some((cached, _)) if *cached == *prompt => {}
321                Some(entry) => *entry = (prompt.clone(), prompt.build().await?),
322                None => self.entries.push((prompt.clone(), prompt.build().await?)),
323            }
324
325            let (_, body) = &self.entries[i];
326            if !body.is_empty() {
327                if !rendered_prompt.is_empty() {
328                    rendered_prompt.push_str("\n\n");
329                }
330                rendered_prompt.push_str(body);
331            }
332        }
333        Ok(rendered_prompt)
334    }
335}
336
/// Renders MCP server instructions as one prompt section.
///
/// The section opens with a heading and an introductory line, followed by one
/// `<mcp-server name="…">` element per server in the map's key order. An
/// empty map renders as an empty string so the section can be omitted.
fn format_mcp_instructions(instructions: &BTreeMap<String, String>) -> String {
    if instructions.is_empty() {
        return String::new();
    }

    let mut sections = Vec::with_capacity(instructions.len() + 2);
    sections.push("# MCP Server Instructions\n".to_string());
    sections.push("You are connected to the following MCP servers:\n".to_string());
    sections.extend(
        instructions
            .iter()
            .map(|(server_name, body)| format!("<mcp-server name=\"{server_name}\">\n{body}\n</mcp-server>\n")),
    );

    sections.join("\n")
}
352
#[cfg(test)]
mod tests {
    use std::fs::{create_dir_all, remove_file, write};

    use super::*;
    use crate::testing::mcp_instructions as instructions;

    /// Writes `contents` to `AGENTS.md` in a fresh temp dir and builds it as a
    /// file prompt whose working directory is that temp dir.
    async fn build_agents_file(contents: &str) -> String {
        let workspace = tempfile::tempdir().unwrap();
        let prompt_path = workspace.path().join("AGENTS.md");
        write(&prompt_path, contents).unwrap();

        Prompt::file(prompt_path, workspace.path().to_path_buf()).build().await.unwrap()
    }

    /// Resolves `EXISTS.md` (which is created) followed by `extra` in a fresh workspace.
    fn resolve_alongside_existing(extra: PromptSource) -> Vec<Prompt> {
        let workspace = tempfile::tempdir().unwrap();
        write(workspace.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![PromptSource::file("EXISTS.md"), extra];
        Prompt::from_sources(workspace.path(), &sources).unwrap()
    }

    /// Resolves a single source in an empty workspace and returns the error it produces.
    fn resolve_error(source: PromptSource) -> PromptSourceError {
        let workspace = tempfile::tempdir().unwrap();
        let sources = vec![source];
        Prompt::from_sources(workspace.path(), &sources).unwrap_err()
    }

    #[tokio::test]
    async fn build_text_prompt() {
        let built = Prompt::text("Hello, world!").build().await.unwrap();
        assert_eq!(built, "Hello, world!");
    }

    #[tokio::test]
    async fn build_all_concatenates_prompts() {
        let prompts = vec![Prompt::text("Part one"), Prompt::text("Part two")];
        let joined = Prompt::build_all(&prompts).await.unwrap();
        assert_eq!(joined, "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn build_all_concatenates_multiple_files() {
        let workspace = tempfile::tempdir().unwrap();
        let agents = workspace.path().join("AGENTS.md");
        let system = workspace.path().join("SYSTEM.md");
        write(&agents, "Agent instructions").unwrap();
        write(&system, "System prompt").unwrap();

        let prompts = vec![Prompt::file(agents, workspace.path()), Prompt::file(system, workspace.path())];
        let joined = Prompt::build_all(&prompts).await.unwrap();
        assert!(joined.contains("Agent instructions"));
        assert!(joined.contains("System prompt"));
        assert!(joined.contains("\n\n"));
    }

    #[tokio::test]
    async fn build_all_skips_empty_parts() {
        let prompts = vec![Prompt::text("Part one"), Prompt::text(""), Prompt::text("Part two")];
        let joined = Prompt::build_all(&prompts).await.unwrap();
        assert_eq!(joined, "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn prompt_cache_render_matches_build_all_on_first_render() {
        let prompts = vec![
            Prompt::text("first"),
            Prompt::McpInstructions(instructions(&[("srv", "body")])),
            Prompt::text("last"),
        ];
        let expected = Prompt::build_all(&prompts).await.unwrap();

        let mut cache = PromptCache::new(prompts);
        assert_eq!(cache.render().await.unwrap(), expected);
    }

    #[tokio::test]
    async fn prompt_cache_reuses_unchanged_slots() {
        let workspace = tempfile::tempdir().unwrap();
        let agents = workspace.path().join("AGENTS.md");
        write(&agents, "cached body").unwrap();
        let mut cache = PromptCache::new(vec![
            Prompt::file(agents.clone(), workspace.path()),
            Prompt::McpInstructions(BTreeMap::new()),
        ]);

        cache.render().await.unwrap();

        // With the file gone, the second render can only succeed from the cache.
        remove_file(&agents).unwrap();
        cache.update_mcp_instruction("srv".into(), Some("instr".into()));

        let rendered = cache.render().await.unwrap();
        assert!(rendered.contains("cached body"));
        assert!(rendered.contains("instr"));
    }

    #[tokio::test]
    async fn prompt_cache_empty_renders_empty() {
        let mut cache = PromptCache::new(vec![]);
        assert_eq!(cache.render().await.unwrap(), "");
    }

    #[tokio::test]
    async fn prompt_cache_drops_empty_slots() {
        // The implicit, empty MCP-instructions slot must not add a separator.
        let mut cache = PromptCache::new(vec![Prompt::text("a"), Prompt::text("b")]);
        assert_eq!(cache.render().await.unwrap(), "a\n\nb");
    }

    #[tokio::test]
    async fn build_file_expands_shell_commands() {
        let built = build_agents_file("Instructions\n\nbranch: !`echo main`\n\nRules").await;
        assert!(built.contains("Instructions"));
        assert!(built.contains("branch: main"));
        assert!(built.contains("Rules"));
        assert!(!built.contains("!`"));
    }

    #[tokio::test]
    async fn build_file_runs_shell_in_cwd() {
        let workspace = tempfile::tempdir().unwrap();
        write(workspace.path().join("sentinel.txt"), "").unwrap();
        let prompt_path = workspace.path().join("AGENTS.md");
        write(&prompt_path, "files: !`ls`").unwrap();

        let prompt = Prompt::file(prompt_path, workspace.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.contains("sentinel.txt"), "expected sentinel.txt in output: {result}");
    }

    #[tokio::test]
    async fn build_file_handles_multiple_commands() {
        let built = build_agents_file("a=!`echo one`, b=!`echo two`").await;
        assert_eq!(built, "a=one, b=two");
    }

    #[tokio::test]
    async fn build_file_substitutes_empty_on_failure() {
        let built = build_agents_file("before !`exit 1` after").await;
        assert_eq!(built, "before  after");
    }

    #[tokio::test]
    async fn build_file_trims_trailing_whitespace() {
        let built = build_agents_file("!`printf 'hi\\n\\n'`").await;
        assert_eq!(built, "hi");
    }

    #[test]
    fn optional_file_source_skips_missing_file() {
        let prompts = resolve_alongside_existing(PromptSource::file("MISSING.md").optional());
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn optional_glob_source_skips_zero_matches() {
        let prompts = resolve_alongside_existing(PromptSource::glob("nonexistent*.md").optional());
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn required_glob_source_expands_to_one_prompt_per_match() {
        let workspace = tempfile::tempdir().unwrap();
        let rules_dir = workspace.path().join(".aether/rules");
        create_dir_all(&rules_dir).unwrap();
        for (name, contents) in [("a.md", "a"), ("b.md", "b")] {
            write(rules_dir.join(name), contents).unwrap();
        }

        let sources = vec![PromptSource::glob(".aether/rules/*.md")];
        let prompts = Prompt::from_sources(workspace.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 2);
    }

    #[test]
    fn required_glob_source_with_no_matches_errors() {
        let err = resolve_error(PromptSource::glob("nonexistent*.md"));
        assert!(matches!(err, PromptSourceError::ZeroMatch { .. }));
    }

    #[test]
    fn optional_glob_source_still_errors_on_invalid_pattern() {
        let err = resolve_error(PromptSource::glob("[invalid").optional());
        assert!(matches!(err, PromptSourceError::InvalidGlobPattern { .. }));
    }

    #[test]
    fn optional_file_source_skips_unresolved_variable() {
        let prompts = resolve_alongside_existing(
            PromptSource::file("${DEFINITELY_NOT_SET_VAR_PROMPT_FILE}/foo.md").optional(),
        );
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn optional_glob_source_skips_unresolved_variable() {
        let prompts = resolve_alongside_existing(
            PromptSource::glob("${DEFINITELY_NOT_SET_VAR_PROMPT_GLOB}/*.md").optional(),
        );
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn required_file_source_errors_on_unresolved_variable() {
        let err = resolve_error(PromptSource::file("${DEFINITELY_NOT_SET_VAR_PROMPT_REQ}/foo.md"));
        assert!(matches!(err, PromptSourceError::UnresolvedVariable { .. }));
    }

    #[test]
    fn prompt_source_string_shorthand_is_required_file() {
        let parsed: PromptSource = serde_json::from_str(r#""SYSTEM.md""#).unwrap();
        assert_eq!(parsed, PromptSource::file("SYSTEM.md"));
    }

    #[test]
    fn optional_prompt_source_serializes_as_typed_object() {
        let optional = PromptSource::file("${WORKSPACE}/AGENTS.md").optional();
        assert_eq!(
            serde_json::to_value(&optional).unwrap(),
            serde_json::json!({"type":"file","path":"${WORKSPACE}/AGENTS.md","optional":true})
        );

        // A required file keeps the compact string shorthand.
        let required = PromptSource::file("SYSTEM.md");
        assert_eq!(serde_json::to_value(&required).unwrap(), serde_json::json!("SYSTEM.md"));
    }

    #[test]
    fn optional_prompt_source_deserializes_from_typed_object() {
        let json = r#"{"type":"file","path":"${WORKSPACE}/AGENTS.md","optional":true}"#;
        let parsed: PromptSource = serde_json::from_str(json).unwrap();
        assert_eq!(parsed, PromptSource::file("${WORKSPACE}/AGENTS.md").optional());
    }

    #[test]
    fn optional_glob_source_deserializes_from_typed_object() {
        let json = r#"{"type":"glob","pattern":"${WORKSPACE}/.aether/rules/*.md","optional":true}"#;
        let parsed: PromptSource = serde_json::from_str(json).unwrap();
        assert_eq!(parsed, PromptSource::glob("${WORKSPACE}/.aether/rules/*.md").optional());
    }
}