1use crate::core::{AgentError, Result};
2use glob::glob;
3use schemars::{JsonSchema, Schema, SchemaGenerator};
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::{BTreeMap, HashMap};
6use std::path::{Path, PathBuf};
7use thiserror::Error;
8use tokio::fs;
9use tracing::warn;
10use utils::shell_expander::ShellExpander;
11use utils::substitution::substitute_parameters;
12use utils::variables::VarError;
13use utils::{PathOrObject, ResourcePath, is_false, string_or_object_schema};
14
/// A single piece of the system prompt, rendered by [`Prompt::build`].
#[derive(Debug, Clone, PartialEq)]
pub enum Prompt {
    /// Literal text, returned verbatim when built.
    Text(String),
    /// A prompt whose content is read from disk each time it is built.
    File {
        /// File to read.
        path: PathBuf,
        /// Parameters handed to `substitute_parameters` together with the file content.
        args: Option<HashMap<String, String>>,
        /// Working directory used when evaluating shell expansions (e.g. ``!`echo main` ``)
        /// found in the file content.
        cwd: PathBuf,
    },
    /// Instructions reported by MCP servers, keyed by server name. Builds to an
    /// empty string when the map is empty.
    McpInstructions(BTreeMap<String, String>),
}
27
/// Where a prompt comes from, as written in configuration.
///
/// A required file serializes to (and deserializes from) a bare path string.
/// Every other form uses the tagged object representation
/// (`{"type": "text" | "file" | "glob", ...}`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptSource {
    /// Inline prompt text.
    Text { text: String },
    /// A single file. When `optional`, a missing file or an undefined path
    /// variable causes the entry to be skipped instead of failing.
    File { path: ResourcePath, optional: bool },
    /// Every regular file matching a glob pattern, in sorted order. When
    /// `optional`, zero matches or an undefined path variable is not an error;
    /// an unparsable pattern still is.
    Glob { pattern: ResourcePath, optional: bool },
}
38
39impl PromptSource {
40 pub fn file(path: impl Into<ResourcePath>) -> Self {
41 Self::File { path: path.into(), optional: false }
42 }
43
44 pub fn glob(pattern: impl Into<ResourcePath>) -> Self {
45 Self::Glob { pattern: pattern.into(), optional: false }
46 }
47
48 #[must_use]
49 pub fn optional(self) -> Self {
50 match self {
51 Self::File { path, .. } => Self::File { path, optional: true },
52 Self::Glob { pattern, .. } => Self::Glob { pattern, optional: true },
53 Self::Text { .. } => self,
54 }
55 }
56
57 pub fn path(&self) -> Option<&str> {
59 match self {
60 Self::File { path, .. } => Some(path.as_authored()),
61 Self::Glob { pattern, .. } => Some(pattern.as_authored()),
62 Self::Text { .. } => None,
63 }
64 }
65
66 pub fn is_optional(&self) -> bool {
67 match self {
68 Self::File { optional, .. } | Self::Glob { optional, .. } => *optional,
69 Self::Text { .. } => false,
70 }
71 }
72}
73
74impl From<&str> for PromptSource {
75 fn from(value: &str) -> Self {
76 Self::file(value)
77 }
78}
79
80impl From<String> for PromptSource {
81 fn from(value: String) -> Self {
82 Self::file(value)
83 }
84}
85
86impl From<PromptSourceObject> for PromptSource {
87 fn from(object: PromptSourceObject) -> Self {
88 match object {
89 PromptSourceObject::Text { text } => Self::Text { text },
90 PromptSourceObject::File { path, optional } => Self::File { path, optional },
91 PromptSourceObject::Glob { pattern, optional } => Self::Glob { pattern, optional },
92 }
93 }
94}
95
96impl<'de> Deserialize<'de> for PromptSource {
97 fn deserialize<T: Deserializer<'de>>(deserializer: T) -> std::result::Result<Self, T::Error> {
98 Ok(match PathOrObject::<PromptSourceObject>::deserialize(deserializer)? {
99 PathOrObject::Path(path) => Self::File { path, optional: false },
100 PathOrObject::Object(object) => object.into(),
101 })
102 }
103}
104
105impl Serialize for PromptSource {
106 fn serialize<T: Serializer>(&self, serializer: T) -> std::result::Result<T::Ok, T::Error> {
107 match self {
108 Self::File { path, optional: false } => path.serialize(serializer),
109 Self::File { path, optional } => {
110 Serialize::serialize(&PromptSourceObject::File { path: path.clone(), optional: *optional }, serializer)
111 }
112 Self::Text { text } => Serialize::serialize(&PromptSourceObject::Text { text: text.clone() }, serializer),
113 Self::Glob { pattern, optional } => Serialize::serialize(
114 &PromptSourceObject::Glob { pattern: pattern.clone(), optional: *optional },
115 serializer,
116 ),
117 }
118 }
119}
120
121impl JsonSchema for PromptSource {
122 fn schema_name() -> std::borrow::Cow<'static, str> {
123 "PromptSource".into()
124 }
125
126 fn json_schema(generator: &mut SchemaGenerator) -> Schema {
127 string_or_object_schema(
128 include_str!("../docs/prompt_source.md"),
129 &generator.subschema_for::<PromptSourceObject>().to_value(),
130 )
131 }
132}
133
// Wire format for the object form of `PromptSource`, internally tagged by a
// camelCase `type` field ("text" / "file" / "glob"); unknown fields are rejected.
// NOTE: plain `//` comments are used deliberately. `///` doc comments on a
// `schemars::JsonSchema` derive become schema descriptions and would change the
// generated schema.
#[derive(schemars::JsonSchema, serde::Deserialize, serde::Serialize)]
#[serde(tag = "type", rename_all = "camelCase", deny_unknown_fields)]
enum PromptSourceObject {
    // Inline prompt text.
    Text {
        text: String,
    },
    // A single file path (variables such as `${WORKSPACE}` are kept as authored).
    File {
        path: ResourcePath,
        // Defaults to `false` when absent. It is skipped on output when `is_false`
        // reports true, i.e. presumably when it is `false`.
        #[serde(default, skip_serializing_if = "is_false")]
        optional: bool,
    },
    // A glob pattern expanded to every matching regular file.
    Glob {
        pattern: ResourcePath,
        // Same defaulting/skipping rule as `File::optional`.
        #[serde(default, skip_serializing_if = "is_false")]
        optional: bool,
    },
}
156
/// Errors produced while resolving [`PromptSource`]s into prompt files.
///
/// `pattern` / `path` fields carry the path exactly as authored (before
/// variable expansion), so messages match what the user wrote in config.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum PromptSourceError {
    /// The glob pattern could not be parsed. This is reported even for optional sources.
    #[error("Invalid glob pattern '{pattern}': {error}")]
    InvalidGlobPattern { pattern: String, error: String },

    /// A file source does not point at an existing regular file.
    #[error("Prompt file '{path}' does not exist")]
    Missing { path: String },

    /// A required glob matched no regular files.
    #[error("Prompt glob '{pattern}' matched no files")]
    ZeroMatch { pattern: String },

    /// The source path references a variable that is not defined.
    #[error("Prompt entry '{pattern}' references undefined variable '{variable}'")]
    UnresolvedVariable { pattern: String, variable: String },
}
176
177impl Prompt {
178 pub fn text(str: &str) -> Self {
179 Self::Text(str.to_string())
180 }
181
182 pub fn file(path: impl Into<PathBuf>, cwd: impl Into<PathBuf>) -> Self {
183 Self::File { path: path.into(), args: None, cwd: cwd.into() }
184 }
185
186 pub fn from_sources(
188 workspace_root: &Path,
189 sources: &[PromptSource],
190 ) -> std::result::Result<Vec<Prompt>, PromptSourceError> {
191 let mut prompts = Vec::new();
192 for source in sources {
193 if let PromptSource::Text { text } = source {
194 prompts.push(Prompt::text(text));
195 continue;
196 }
197 match resolve_source_files(workspace_root, source) {
198 Ok(paths) => {
199 for path in paths {
200 prompts.push(Prompt::file(path, workspace_root.to_path_buf()));
201 }
202 }
203 Err(PromptSourceError::Missing { .. }) if source.is_optional() => {}
204 Err(PromptSourceError::UnresolvedVariable { variable, .. }) if source.is_optional() => {
205 warn!(
206 "Skipping optional prompt entry '{}': variable '{variable}' is not defined",
207 source.path().unwrap_or_default()
208 );
209 }
210 Err(error) => return Err(error),
211 }
212 }
213 Ok(prompts)
214 }
215
216 pub async fn build(&self) -> Result<String> {
218 match self {
219 Prompt::Text(text) => Ok(text.clone()),
220 Prompt::File { path, args, cwd } => {
221 let content = Self::resolve_file(path).await?;
222 let substituted = substitute_parameters(&content, args);
223 let expander = ShellExpander::new();
224 Ok(expander.expand(&substituted, cwd).await)
225 }
226 Prompt::McpInstructions(instructions) => Ok(format_mcp_instructions(instructions)),
227 }
228 }
229
230 pub async fn build_all(prompts: &[Prompt]) -> Result<String> {
232 let mut parts = Vec::with_capacity(prompts.len());
233 for p in prompts {
234 let part = p.build().await?;
235 if !part.is_empty() {
236 parts.push(part);
237 }
238 }
239 Ok(parts.join("\n\n"))
240 }
241
242 async fn resolve_file(path: &Path) -> Result<String> {
243 fs::read_to_string(path)
244 .await
245 .map_err(|e| AgentError::IoError(format!("Failed to read file '{}': {e}", path.display())))
246 }
247}
248
249fn resolve_source_files(
250 workspace_root: &Path,
251 source: &PromptSource,
252) -> std::result::Result<Vec<PathBuf>, PromptSourceError> {
253 match source {
254 PromptSource::Text { .. } => Ok(Vec::new()),
255 PromptSource::File { path, .. } => {
256 let full_path = resolve_path(path, workspace_root)?;
257 if full_path.is_file() {
258 Ok(vec![full_path])
259 } else {
260 Err(PromptSourceError::Missing { path: path.as_authored().to_string() })
261 }
262 }
263 PromptSource::Glob { pattern, optional } => {
264 let full_pattern = resolve_path(pattern, workspace_root)?;
265 let mut paths: Vec<PathBuf> = glob(&full_pattern.to_string_lossy())
266 .map_err(|e| PromptSourceError::InvalidGlobPattern {
267 pattern: pattern.as_authored().to_string(),
268 error: e.to_string(),
269 })?
270 .filter_map(std::result::Result::ok)
271 .filter(|path| path.is_file())
272 .collect();
273 paths.sort();
274 if paths.is_empty() && !*optional {
275 Err(PromptSourceError::ZeroMatch { pattern: pattern.as_authored().to_string() })
276 } else {
277 Ok(paths)
278 }
279 }
280 }
281}
282
283fn resolve_path(path: &ResourcePath, workspace_root: &Path) -> std::result::Result<PathBuf, PromptSourceError> {
284 path.resolve(workspace_root).map_err(|VarError::NotFound(variable)| PromptSourceError::UnresolvedVariable {
285 pattern: path.as_authored().to_string(),
286 variable,
287 })
288}
289
/// Incrementally renders a list of prompts.
///
/// Previously built output is reused for every slot whose `Prompt` value is
/// unchanged since the last render.
pub struct PromptCache {
    /// Prompts to render, in order.
    prompts: Vec<Prompt>,
    /// Per-index cache: the `Prompt` value last built for that slot and its output.
    entries: Vec<(Prompt, String)>,
}
294
295impl PromptCache {
296 pub fn new(mut prompts: Vec<Prompt>) -> Self {
297 if !prompts.iter().any(|p| matches!(p, Prompt::McpInstructions(_))) {
298 prompts.push(Prompt::McpInstructions(BTreeMap::new()));
299 }
300 Self { prompts, entries: Vec::new() }
301 }
302
303 pub fn update_mcp_instruction(&mut self, server: String, body: Option<String>) {
304 for prompt in &mut self.prompts {
305 if let Prompt::McpInstructions(map) = prompt {
306 match body {
307 Some(text) => {
308 map.insert(server, text);
309 }
310 None => {
311 map.remove(&server);
312 }
313 }
314 return;
315 }
316 }
317 }
318
319 pub async fn render(&mut self) -> Result<String> {
320 self.entries.truncate(self.prompts.len());
321 let mut rendered_prompt = String::new();
322 for i in 0..self.prompts.len() {
323 let prompt = &self.prompts[i];
324 match self.entries.get_mut(i) {
325 Some((cached, _)) if *cached == *prompt => {}
326 Some(entry) => *entry = (prompt.clone(), prompt.build().await?),
327 None => self.entries.push((prompt.clone(), prompt.build().await?)),
328 }
329
330 let (_, body) = &self.entries[i];
331 if !body.is_empty() {
332 if !rendered_prompt.is_empty() {
333 rendered_prompt.push_str("\n\n");
334 }
335 rendered_prompt.push_str(body);
336 }
337 }
338 Ok(rendered_prompt)
339 }
340}
341
/// Formats MCP server instructions as a markdown section with one
/// `<mcp-server name="...">` element per server, in key (server name) order.
///
/// Returns an empty string when there are no instructions, so the slot is
/// dropped from the rendered prompt. Names and bodies are inserted verbatim.
fn format_mcp_instructions(instructions: &BTreeMap<String, String>) -> String {
    if instructions.is_empty() {
        return String::new();
    }

    let mut sections = vec![
        "# MCP Server Instructions\n".to_string(),
        "You are connected to the following MCP servers:\n".to_string(),
    ];
    sections.extend(
        instructions
            .iter()
            .map(|(server_name, body)| format!("<mcp-server name=\"{server_name}\">\n{body}\n</mcp-server>\n")),
    );
    sections.join("\n")
}
357
#[cfg(test)]
mod tests {
    use std::fs::{create_dir_all, write};

    use super::*;
    use crate::testing::mcp_instructions as instructions;

    #[tokio::test]
    async fn build_text_prompt() {
        let prompt = Prompt::text("Hello, world!");
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "Hello, world!");
    }

    #[tokio::test]
    async fn build_all_concatenates_prompts() {
        let prompts = vec![Prompt::text("Part one"), Prompt::text("Part two")];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert_eq!(result, "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn build_all_concatenates_multiple_files() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("AGENTS.md"), "Agent instructions").unwrap();
        write(dir.path().join("SYSTEM.md"), "System prompt").unwrap();

        let prompts = vec![
            Prompt::file(dir.path().join("AGENTS.md"), dir.path()),
            Prompt::file(dir.path().join("SYSTEM.md"), dir.path()),
        ];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert!(result.contains("Agent instructions"));
        assert!(result.contains("System prompt"));
        assert!(result.contains("\n\n"));
    }

    #[tokio::test]
    async fn build_all_skips_empty_parts() {
        let prompts = vec![Prompt::text("Part one"), Prompt::text(""), Prompt::text("Part two")];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert_eq!(result, "Part one\n\nPart two");
    }

    #[tokio::test]
    async fn prompt_cache_render_matches_build_all_on_first_render() {
        let prompts = vec![
            Prompt::text("first"),
            Prompt::McpInstructions(instructions(&[("srv", "body")])),
            Prompt::text("last"),
        ];
        let expected = Prompt::build_all(&prompts).await.unwrap();
        let mut cache = PromptCache::new(prompts);
        assert_eq!(cache.render().await.unwrap(), expected);
    }

    #[tokio::test]
    async fn prompt_cache_reuses_unchanged_slots() {
        use std::fs::remove_file;

        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("AGENTS.md"), "cached body").unwrap();
        let mut cache = PromptCache::new(vec![
            Prompt::file(dir.path().join("AGENTS.md"), dir.path()),
            Prompt::McpInstructions(BTreeMap::new()),
        ]);

        cache.render().await.unwrap();

        // Deleting the file proves the second render serves the cached body:
        // rebuilding the file slot would fail with an I/O error.
        remove_file(dir.path().join("AGENTS.md")).unwrap();
        cache.update_mcp_instruction("srv".into(), Some("instr".into()));

        let rendered = cache.render().await.unwrap();
        assert!(rendered.contains("cached body"));
        assert!(rendered.contains("instr"));
    }

    #[tokio::test]
    async fn prompt_cache_empty_renders_empty() {
        assert_eq!(PromptCache::new(vec![]).render().await.unwrap(), "");
    }

    #[tokio::test]
    async fn prompt_cache_drops_empty_slots() {
        // Both the explicit empty text slot and the implicit (empty) MCP slot
        // appended by `PromptCache::new` must be left out of the output.
        let mut cache = PromptCache::new(vec![Prompt::text("a"), Prompt::text(""), Prompt::text("b")]);
        assert_eq!(cache.render().await.unwrap(), "a\n\nb");
    }

    #[tokio::test]
    async fn build_file_expands_shell_commands() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("AGENTS.md"), "Instructions\n\nbranch: !`echo main`\n\nRules").unwrap();

        let prompt = Prompt::file(dir.path().join("AGENTS.md"), dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.contains("Instructions"));
        assert!(result.contains("branch: main"));
        assert!(result.contains("Rules"));
        assert!(!result.contains("!`"));
    }

    #[tokio::test]
    async fn build_file_runs_shell_in_cwd() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("sentinel.txt"), "").unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "files: !`ls`").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.contains("sentinel.txt"), "expected sentinel.txt in output: {result}");
    }

    #[tokio::test]
    async fn build_file_handles_multiple_commands() {
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "a=!`echo one`, b=!`echo two`").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "a=one, b=two");
    }

    #[tokio::test]
    async fn build_file_substitutes_empty_on_failure() {
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "before !`exit 1` after").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "before after");
    }

    #[tokio::test]
    async fn build_file_trims_trailing_whitespace() {
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "!`printf 'hi\\n\\n'`").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "hi");
    }

    #[test]
    fn optional_file_source_skips_missing_file() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![PromptSource::file("EXISTS.md"), PromptSource::file("MISSING.md").optional()];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn optional_glob_source_skips_zero_matches() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![PromptSource::file("EXISTS.md"), PromptSource::glob("nonexistent*.md").optional()];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn required_glob_source_expands_to_one_prompt_per_match() {
        let dir = tempfile::tempdir().unwrap();
        let rules_dir = dir.path().join(".aether/rules");
        create_dir_all(&rules_dir).unwrap();
        write(rules_dir.join("a.md"), "a").unwrap();
        write(rules_dir.join("b.md"), "b").unwrap();

        let sources = vec![PromptSource::glob(".aether/rules/*.md")];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 2);
    }

    #[test]
    fn required_glob_source_with_no_matches_errors() {
        let dir = tempfile::tempdir().unwrap();
        let sources = vec![PromptSource::glob("nonexistent*.md")];
        let err = Prompt::from_sources(dir.path(), &sources).unwrap_err();
        assert!(matches!(err, PromptSourceError::ZeroMatch { .. }));
    }

    #[test]
    fn optional_glob_source_still_errors_on_invalid_pattern() {
        let dir = tempfile::tempdir().unwrap();
        let sources = vec![PromptSource::glob("[invalid").optional()];
        let err = Prompt::from_sources(dir.path(), &sources).unwrap_err();
        assert!(matches!(err, PromptSourceError::InvalidGlobPattern { .. }));
    }

    #[test]
    fn optional_file_source_skips_unresolved_variable() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![
            PromptSource::file("EXISTS.md"),
            PromptSource::file("${DEFINITELY_NOT_SET_VAR_PROMPT_FILE}/foo.md").optional(),
        ];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn optional_glob_source_skips_unresolved_variable() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![
            PromptSource::file("EXISTS.md"),
            PromptSource::glob("${DEFINITELY_NOT_SET_VAR_PROMPT_GLOB}/*.md").optional(),
        ];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    #[test]
    fn required_file_source_errors_on_unresolved_variable() {
        let dir = tempfile::tempdir().unwrap();
        let sources = vec![PromptSource::file("${DEFINITELY_NOT_SET_VAR_PROMPT_REQ}/foo.md")];
        let err = Prompt::from_sources(dir.path(), &sources).unwrap_err();
        assert!(matches!(err, PromptSourceError::UnresolvedVariable { .. }));
    }

    #[test]
    fn prompt_source_string_shorthand_is_required_file() {
        let source: PromptSource = serde_json::from_str(r#""SYSTEM.md""#).unwrap();
        assert_eq!(source, PromptSource::file("SYSTEM.md"));
    }

    #[test]
    fn optional_prompt_source_serializes_as_typed_object() {
        let source = PromptSource::file("${WORKSPACE}/AGENTS.md").optional();
        let value = serde_json::to_value(&source).unwrap();
        assert_eq!(value, serde_json::json!({"type":"file","path":"${WORKSPACE}/AGENTS.md","optional":true}));

        // A required file keeps the compact string shorthand.
        let source = PromptSource::file("SYSTEM.md");
        let value = serde_json::to_value(&source).unwrap();
        assert_eq!(value, serde_json::json!("SYSTEM.md"));
    }

    #[test]
    fn optional_prompt_source_deserializes_from_typed_object() {
        let source: PromptSource =
            serde_json::from_str(r#"{"type":"file","path":"${WORKSPACE}/AGENTS.md","optional":true}"#).unwrap();
        assert_eq!(source, PromptSource::file("${WORKSPACE}/AGENTS.md").optional());
    }

    #[test]
    fn optional_glob_source_deserializes_from_typed_object() {
        let source: PromptSource =
            serde_json::from_str(r#"{"type":"glob","pattern":"${WORKSPACE}/.aether/rules/*.md","optional":true}"#)
                .unwrap();
        assert_eq!(source, PromptSource::glob("${WORKSPACE}/.aether/rules/*.md").optional());
    }
}