1use crate::core::{AgentError, Result};
2use glob::glob;
3use schemars::{JsonSchema, Schema, SchemaGenerator};
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5use std::collections::{BTreeMap, HashMap};
6use std::path::{Path, PathBuf};
7use thiserror::Error;
8use tokio::fs;
9use tracing::warn;
10use utils::shell_expander::ShellExpander;
11use utils::substitution::substitute_parameters;
12use utils::variables::VarError;
13use utils::{PathOrObject, ResourcePath, is_false, string_or_object_schema};
14
/// A single buildable piece of a prompt.
///
/// Each variant is turned into a string by [`Prompt::build`].
#[derive(Debug, Clone, PartialEq)]
pub enum Prompt {
    /// Literal text; `build` returns a copy of it unchanged.
    Text(String),
    /// A file whose content is read when the prompt is built.
    File {
        /// Path of the file to read.
        path: PathBuf,
        /// Optional arguments handed to `substitute_parameters` together with the file content.
        args: Option<HashMap<String, String>>,
        /// Directory handed to `ShellExpander::expand` when expanding the substituted content.
        cwd: PathBuf,
    },
    /// MCP server instructions keyed by server name, rendered by `format_mcp_instructions`.
    McpInstructions(BTreeMap<String, String>),
}
27
/// Authored description of where prompt content comes from.
///
/// Sources are turned into concrete [`Prompt`] values by [`Prompt::from_sources`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptSource {
    /// Inline text; becomes a `Prompt::Text`.
    Text { text: String },
    /// A single file. When `optional` is true, `Prompt::from_sources` skips it if the file is missing.
    File { path: ResourcePath, optional: bool },
    /// A glob pattern that expands to one prompt per matching file. When `optional` is true,
    /// zero matches is not an error.
    Glob { pattern: ResourcePath, optional: bool },
}
38
39impl PromptSource {
40 pub fn file(path: impl Into<ResourcePath>) -> Self {
41 Self::File { path: path.into(), optional: false }
42 }
43
44 pub fn glob(pattern: impl Into<ResourcePath>) -> Self {
45 Self::Glob { pattern: pattern.into(), optional: false }
46 }
47
48 #[must_use]
49 pub fn optional(self) -> Self {
50 match self {
51 Self::File { path, .. } => Self::File { path, optional: true },
52 Self::Glob { pattern, .. } => Self::Glob { pattern, optional: true },
53 Self::Text { .. } => self,
54 }
55 }
56
57 pub fn path(&self) -> Option<&str> {
59 match self {
60 Self::File { path, .. } => Some(path.as_authored()),
61 Self::Glob { pattern, .. } => Some(pattern.as_authored()),
62 Self::Text { .. } => None,
63 }
64 }
65
66 pub fn is_optional(&self) -> bool {
67 match self {
68 Self::File { optional, .. } | Self::Glob { optional, .. } => *optional,
69 Self::Text { .. } => false,
70 }
71 }
72}
73
74impl From<&str> for PromptSource {
75 fn from(value: &str) -> Self {
76 Self::file(value)
77 }
78}
79
80impl From<String> for PromptSource {
81 fn from(value: String) -> Self {
82 Self::file(value)
83 }
84}
85
86impl From<PromptSourceObject> for PromptSource {
87 fn from(object: PromptSourceObject) -> Self {
88 match object {
89 PromptSourceObject::Text { text } => Self::Text { text },
90 PromptSourceObject::File { path, optional } => Self::File { path, optional },
91 PromptSourceObject::Glob { pattern, optional } => Self::Glob { pattern, optional },
92 }
93 }
94}
95
96impl<'de> Deserialize<'de> for PromptSource {
97 fn deserialize<T: Deserializer<'de>>(deserializer: T) -> std::result::Result<Self, T::Error> {
98 Ok(match PathOrObject::<PromptSourceObject>::deserialize(deserializer)? {
99 PathOrObject::Path(path) => Self::File { path, optional: false },
100 PathOrObject::Object(object) => object.into(),
101 })
102 }
103}
104
105impl Serialize for PromptSource {
106 fn serialize<T: Serializer>(&self, serializer: T) -> std::result::Result<T::Ok, T::Error> {
107 match self {
108 Self::File { path, optional: false } => path.serialize(serializer),
109 Self::File { path, optional } => {
110 Serialize::serialize(&PromptSourceObject::File { path: path.clone(), optional: *optional }, serializer)
111 }
112 Self::Text { text } => Serialize::serialize(&PromptSourceObject::Text { text: text.clone() }, serializer),
113 Self::Glob { pattern, optional } => Serialize::serialize(
114 &PromptSourceObject::Glob { pattern: pattern.clone(), optional: *optional },
115 serializer,
116 ),
117 }
118 }
119}
120
121impl JsonSchema for PromptSource {
122 fn schema_name() -> std::borrow::Cow<'static, str> {
123 "PromptSource".into()
124 }
125
126 fn json_schema(generator: &mut SchemaGenerator) -> Schema {
127 string_or_object_schema(
128 "Authored description of a prompt source — either a file path string or a typed text, file, or glob object.",
129 &generator.subschema_for::<PromptSourceObject>().to_value(),
130 )
131 }
132}
133
/// Typed object form of a [`PromptSource`] as it appears in serialized data.
///
/// The variant is selected by a `type` tag with camelCase variant names, and unknown fields
/// are rejected.
#[derive(schemars::JsonSchema, serde::Deserialize, serde::Serialize)]
#[serde(tag = "type", rename_all = "camelCase", deny_unknown_fields)]
enum PromptSourceObject {
    /// Inline prompt text.
    Text {
        text: String,
    },
    /// A single prompt file.
    File {
        path: ResourcePath,
        // Defaults to `false` when absent; skipped on output when `is_false` accepts the value.
        #[serde(default, skip_serializing_if = "is_false")]
        optional: bool,
    },
    /// A glob pattern matching prompt files.
    Glob {
        pattern: ResourcePath,
        // Defaults to `false` when absent; skipped on output when `is_false` accepts the value.
        #[serde(default, skip_serializing_if = "is_false")]
        optional: bool,
    },
}
151
/// Errors produced while resolving [`PromptSource`] entries to files.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum PromptSourceError {
    /// The resolved glob pattern was rejected by `glob`; `error` holds its message.
    #[error("Invalid glob pattern '{pattern}': {error}")]
    InvalidGlobPattern { pattern: String, error: String },

    /// A file source resolved to a path that is not an existing file; `path` is the authored path.
    #[error("Prompt file '{path}' does not exist")]
    Missing { path: String },

    /// A non-optional glob source matched no files; `pattern` is the authored pattern.
    #[error("Prompt glob '{pattern}' matched no files")]
    ZeroMatch { pattern: String },

    /// Resolving the authored path or pattern hit a variable that is not defined.
    #[error("Prompt entry '{pattern}' references undefined variable '{variable}'")]
    UnresolvedVariable { pattern: String, variable: String },
}
171
172impl Prompt {
173 pub fn text(str: &str) -> Self {
174 Self::Text(str.to_string())
175 }
176
177 pub fn file(path: impl Into<PathBuf>, cwd: impl Into<PathBuf>) -> Self {
178 Self::File { path: path.into(), args: None, cwd: cwd.into() }
179 }
180
181 pub fn from_sources(
183 workspace_root: &Path,
184 sources: &[PromptSource],
185 ) -> std::result::Result<Vec<Prompt>, PromptSourceError> {
186 let mut prompts = Vec::new();
187 for source in sources {
188 if let PromptSource::Text { text } = source {
189 prompts.push(Prompt::text(text));
190 continue;
191 }
192 match resolve_source_files(workspace_root, source) {
193 Ok(paths) => {
194 for path in paths {
195 prompts.push(Prompt::file(path, workspace_root.to_path_buf()));
196 }
197 }
198 Err(PromptSourceError::Missing { .. }) if source.is_optional() => {}
199 Err(PromptSourceError::UnresolvedVariable { variable, .. }) if source.is_optional() => {
200 warn!(
201 "Skipping optional prompt entry '{}': variable '{variable}' is not defined",
202 source.path().unwrap_or_default()
203 );
204 }
205 Err(error) => return Err(error),
206 }
207 }
208 Ok(prompts)
209 }
210
211 pub async fn build(&self) -> Result<String> {
213 match self {
214 Prompt::Text(text) => Ok(text.clone()),
215 Prompt::File { path, args, cwd } => {
216 let content = Self::resolve_file(path).await?;
217 let substituted = substitute_parameters(&content, args);
218 let expander = ShellExpander::new();
219 Ok(expander.expand(&substituted, cwd).await)
220 }
221 Prompt::McpInstructions(instructions) => Ok(format_mcp_instructions(instructions)),
222 }
223 }
224
225 pub async fn build_all(prompts: &[Prompt]) -> Result<String> {
227 let mut parts = Vec::with_capacity(prompts.len());
228 for p in prompts {
229 let part = p.build().await?;
230 if !part.is_empty() {
231 parts.push(part);
232 }
233 }
234 Ok(parts.join("\n\n"))
235 }
236
237 async fn resolve_file(path: &Path) -> Result<String> {
238 fs::read_to_string(path)
239 .await
240 .map_err(|e| AgentError::IoError(format!("Failed to read file '{}': {e}", path.display())))
241 }
242}
243
244fn resolve_source_files(
245 workspace_root: &Path,
246 source: &PromptSource,
247) -> std::result::Result<Vec<PathBuf>, PromptSourceError> {
248 match source {
249 PromptSource::Text { .. } => Ok(Vec::new()),
250 PromptSource::File { path, .. } => {
251 let full_path = resolve_path(path, workspace_root)?;
252 if full_path.is_file() {
253 Ok(vec![full_path])
254 } else {
255 Err(PromptSourceError::Missing { path: path.as_authored().to_string() })
256 }
257 }
258 PromptSource::Glob { pattern, optional } => {
259 let full_pattern = resolve_path(pattern, workspace_root)?;
260 let mut paths: Vec<PathBuf> = glob(&full_pattern.to_string_lossy())
261 .map_err(|e| PromptSourceError::InvalidGlobPattern {
262 pattern: pattern.as_authored().to_string(),
263 error: e.to_string(),
264 })?
265 .filter_map(std::result::Result::ok)
266 .filter(|path| path.is_file())
267 .collect();
268 paths.sort();
269 if paths.is_empty() && !*optional {
270 Err(PromptSourceError::ZeroMatch { pattern: pattern.as_authored().to_string() })
271 } else {
272 Ok(paths)
273 }
274 }
275 }
276}
277
278fn resolve_path(path: &ResourcePath, workspace_root: &Path) -> std::result::Result<PathBuf, PromptSourceError> {
279 path.resolve(workspace_root).map_err(|VarError::NotFound(variable)| PromptSourceError::UnresolvedVariable {
280 pattern: path.as_authored().to_string(),
281 variable,
282 })
283}
284
285pub struct PromptCache {
286 prompts: Vec<Prompt>,
287 entries: Vec<(Prompt, String)>,
288}
289
290impl PromptCache {
291 pub fn new(mut prompts: Vec<Prompt>) -> Self {
292 if !prompts.iter().any(|p| matches!(p, Prompt::McpInstructions(_))) {
293 prompts.push(Prompt::McpInstructions(BTreeMap::new()));
294 }
295 Self { prompts, entries: Vec::new() }
296 }
297
298 pub fn update_mcp_instruction(&mut self, server: String, body: Option<String>) {
299 for prompt in &mut self.prompts {
300 if let Prompt::McpInstructions(map) = prompt {
301 match body {
302 Some(text) => {
303 map.insert(server, text);
304 }
305 None => {
306 map.remove(&server);
307 }
308 }
309 return;
310 }
311 }
312 }
313
314 pub async fn render(&mut self) -> Result<String> {
315 self.entries.truncate(self.prompts.len());
316 let mut rendered_prompt = String::new();
317 for i in 0..self.prompts.len() {
318 let prompt = &self.prompts[i];
319 match self.entries.get_mut(i) {
320 Some((cached, _)) if *cached == *prompt => {}
321 Some(entry) => *entry = (prompt.clone(), prompt.build().await?),
322 None => self.entries.push((prompt.clone(), prompt.build().await?)),
323 }
324
325 let (_, body) = &self.entries[i];
326 if !body.is_empty() {
327 if !rendered_prompt.is_empty() {
328 rendered_prompt.push_str("\n\n");
329 }
330 rendered_prompt.push_str(body);
331 }
332 }
333 Ok(rendered_prompt)
334 }
335}
336
/// Formats MCP server instructions as a markdown section.
///
/// Returns an empty string when there are no instructions. Otherwise emits a heading, an
/// introductory line, and one `<mcp-server>` element per server in map order, each
/// separated by a blank line.
fn format_mcp_instructions(instructions: &BTreeMap<String, String>) -> String {
    if instructions.is_empty() {
        return String::new();
    }

    let mut rendered = String::new();
    rendered.push_str("# MCP Server Instructions\n");
    rendered.push('\n');
    rendered.push_str("You are connected to the following MCP servers:\n");

    for (server_name, body) in instructions {
        // Blank line between the previous section and this server element.
        rendered.push('\n');
        rendered.push_str(&format!("<mcp-server name=\"{server_name}\">\n{body}\n</mcp-server>\n"));
    }

    rendered
}
352
#[cfg(test)]
mod tests {
    use std::fs::{create_dir_all, remove_file, write};

    use super::*;
    use crate::testing::mcp_instructions as instructions;

    // --- Prompt::build / Prompt::build_all -------------------------------------------------

    /// A text prompt builds to its own content.
    #[tokio::test]
    async fn build_text_prompt() {
        let prompt = Prompt::text("Hello, world!");
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "Hello, world!");
    }

    /// Parts are joined with a blank line.
    #[tokio::test]
    async fn build_all_concatenates_prompts() {
        let prompts = vec![Prompt::text("Part one"), Prompt::text("Part two")];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert_eq!(result, "Part one\n\nPart two");
    }

    /// File prompts are read from disk and joined like any other part.
    #[tokio::test]
    async fn build_all_concatenates_multiple_files() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("AGENTS.md"), "Agent instructions").unwrap();
        write(dir.path().join("SYSTEM.md"), "System prompt").unwrap();

        let prompts = vec![
            Prompt::file(dir.path().join("AGENTS.md"), dir.path()),
            Prompt::file(dir.path().join("SYSTEM.md"), dir.path()),
        ];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert!(result.contains("Agent instructions"));
        assert!(result.contains("System prompt"));
        assert!(result.contains("\n\n"));
    }

    /// Empty parts contribute neither content nor a separator.
    #[tokio::test]
    async fn build_all_skips_empty_parts() {
        let prompts = vec![Prompt::text("Part one"), Prompt::text(""), Prompt::text("Part two")];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert_eq!(result, "Part one\n\nPart two");
    }

    // --- PromptCache -----------------------------------------------------------------------

    /// The first render of a cache equals a plain `build_all` over the same prompts.
    #[tokio::test]
    async fn prompt_cache_render_matches_build_all_on_first_render() {
        let prompts = vec![
            Prompt::text("first"),
            Prompt::McpInstructions(instructions(&[("srv", "body")])),
            Prompt::text("last"),
        ];
        let expected = Prompt::build_all(&prompts).await.unwrap();
        let mut cache = PromptCache::new(prompts);
        assert_eq!(cache.render().await.unwrap(), expected);
    }

    /// An unchanged file slot is served from the cache even after the file is deleted,
    /// while the changed MCP slot is rebuilt.
    #[tokio::test]
    async fn prompt_cache_reuses_unchanged_slots() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("AGENTS.md"), "cached body").unwrap();
        let mut cache = PromptCache::new(vec![
            Prompt::file(dir.path().join("AGENTS.md"), dir.path()),
            Prompt::McpInstructions(BTreeMap::new()),
        ]);

        cache.render().await.unwrap();

        // Rebuilding the file slot would now fail, so success proves the cache was used.
        remove_file(dir.path().join("AGENTS.md")).unwrap();
        cache.update_mcp_instruction("srv".into(), Some("instr".into()));

        let rendered = cache.render().await.unwrap();
        assert!(rendered.contains("cached body"));
        assert!(rendered.contains("instr"));
    }

    /// A cache with no user prompts (only the implicit empty MCP slot) renders nothing.
    #[tokio::test]
    async fn prompt_cache_empty_renders_empty() {
        assert_eq!(PromptCache::new(vec![]).render().await.unwrap(), "");
    }

    /// Empty slots are dropped without leaving stray separators: both the explicit empty
    /// text slot in the middle and the implicit empty MCP slot appended by `new`.
    #[tokio::test]
    async fn prompt_cache_drops_empty_slots() {
        let mut cache = PromptCache::new(vec![Prompt::text("a"), Prompt::text(""), Prompt::text("b")]);
        assert_eq!(cache.render().await.unwrap(), "a\n\nb");
    }

    // --- Shell expansion in file prompts ---------------------------------------------------

    /// Inline shell commands in a file are replaced by their output.
    #[tokio::test]
    async fn build_file_expands_shell_commands() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("AGENTS.md"), "Instructions\n\nbranch: !`echo main`\n\nRules").unwrap();

        let prompt = Prompt::file(dir.path().join("AGENTS.md"), dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.contains("Instructions"));
        assert!(result.contains("branch: main"));
        assert!(result.contains("Rules"));
        assert!(!result.contains("!`"));
    }

    /// Shell commands run with the prompt's `cwd` as working directory.
    #[tokio::test]
    async fn build_file_runs_shell_in_cwd() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("sentinel.txt"), "").unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "files: !`ls`").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.contains("sentinel.txt"), "expected sentinel.txt in output: {result}");
    }

    /// Several commands in one file are each expanded.
    #[tokio::test]
    async fn build_file_handles_multiple_commands() {
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "a=!`echo one`, b=!`echo two`").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "a=one, b=two");
    }

    /// A failing command expands to the empty string.
    #[tokio::test]
    async fn build_file_substitutes_empty_on_failure() {
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "before !`exit 1` after").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "before after");
    }

    /// Trailing whitespace in command output is trimmed.
    #[tokio::test]
    async fn build_file_trims_trailing_whitespace() {
        let dir = tempfile::tempdir().unwrap();
        let prompt_path = dir.path().join("AGENTS.md");
        write(&prompt_path, "!`printf 'hi\\n\\n'`").unwrap();

        let prompt = Prompt::file(prompt_path, dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "hi");
    }

    // --- Prompt::from_sources --------------------------------------------------------------

    /// A missing optional file is skipped rather than reported.
    #[test]
    fn optional_file_source_skips_missing_file() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![PromptSource::file("EXISTS.md"), PromptSource::file("MISSING.md").optional()];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    /// An optional glob with no matches contributes no prompts.
    #[test]
    fn optional_glob_source_skips_zero_matches() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![PromptSource::file("EXISTS.md"), PromptSource::glob("nonexistent*.md").optional()];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    /// A glob yields one prompt per matching file.
    #[test]
    fn required_glob_source_expands_to_one_prompt_per_match() {
        let dir = tempfile::tempdir().unwrap();
        let rules_dir = dir.path().join(".aether/rules");
        create_dir_all(&rules_dir).unwrap();
        write(rules_dir.join("a.md"), "a").unwrap();
        write(rules_dir.join("b.md"), "b").unwrap();

        let sources = vec![PromptSource::glob(".aether/rules/*.md")];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 2);
    }

    /// A required glob with no matches is an error.
    #[test]
    fn required_glob_source_with_no_matches_errors() {
        let dir = tempfile::tempdir().unwrap();
        let sources = vec![PromptSource::glob("nonexistent*.md")];
        let err = Prompt::from_sources(dir.path(), &sources).unwrap_err();
        assert!(matches!(err, PromptSourceError::ZeroMatch { .. }));
    }

    /// Being optional does not excuse a malformed pattern.
    #[test]
    fn optional_glob_source_still_errors_on_invalid_pattern() {
        let dir = tempfile::tempdir().unwrap();
        let sources = vec![PromptSource::glob("[invalid").optional()];
        let err = Prompt::from_sources(dir.path(), &sources).unwrap_err();
        assert!(matches!(err, PromptSourceError::InvalidGlobPattern { .. }));
    }

    /// An optional file whose path uses an undefined variable is skipped.
    #[test]
    fn optional_file_source_skips_unresolved_variable() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![
            PromptSource::file("EXISTS.md"),
            PromptSource::file("${DEFINITELY_NOT_SET_VAR_PROMPT_FILE}/foo.md").optional(),
        ];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    /// An optional glob whose pattern uses an undefined variable is skipped.
    #[test]
    fn optional_glob_source_skips_unresolved_variable() {
        let dir = tempfile::tempdir().unwrap();
        write(dir.path().join("EXISTS.md"), "exists").unwrap();

        let sources = vec![
            PromptSource::file("EXISTS.md"),
            PromptSource::glob("${DEFINITELY_NOT_SET_VAR_PROMPT_GLOB}/*.md").optional(),
        ];
        let prompts = Prompt::from_sources(dir.path(), &sources).unwrap();
        assert_eq!(prompts.len(), 1);
    }

    /// A required file whose path uses an undefined variable is an error.
    #[test]
    fn required_file_source_errors_on_unresolved_variable() {
        let dir = tempfile::tempdir().unwrap();
        let sources = vec![PromptSource::file("${DEFINITELY_NOT_SET_VAR_PROMPT_REQ}/foo.md")];
        let err = Prompt::from_sources(dir.path(), &sources).unwrap_err();
        assert!(matches!(err, PromptSourceError::UnresolvedVariable { .. }));
    }

    // --- PromptSource (de)serialization ----------------------------------------------------

    /// A bare JSON string deserializes to a required file source.
    #[test]
    fn prompt_source_string_shorthand_is_required_file() {
        let source: PromptSource = serde_json::from_str(r#""SYSTEM.md""#).unwrap();
        assert_eq!(source, PromptSource::file("SYSTEM.md"));
    }

    /// Optional files serialize as typed objects; required files use the string shorthand.
    #[test]
    fn optional_prompt_source_serializes_as_typed_object() {
        let source = PromptSource::file("${WORKSPACE}/AGENTS.md").optional();
        let value = serde_json::to_value(&source).unwrap();
        assert_eq!(value, serde_json::json!({"type":"file","path":"${WORKSPACE}/AGENTS.md","optional":true}));

        let source = PromptSource::file("SYSTEM.md");
        let value = serde_json::to_value(&source).unwrap();
        assert_eq!(value, serde_json::json!("SYSTEM.md"));
    }

    /// A typed file object with `optional: true` round-trips into an optional file source.
    #[test]
    fn optional_prompt_source_deserializes_from_typed_object() {
        let source: PromptSource =
            serde_json::from_str(r#"{"type":"file","path":"${WORKSPACE}/AGENTS.md","optional":true}"#).unwrap();
        assert_eq!(source, PromptSource::file("${WORKSPACE}/AGENTS.md").optional());
    }

    /// A typed glob object with `optional: true` deserializes into an optional glob source.
    #[test]
    fn optional_glob_source_deserializes_from_typed_object() {
        let source: PromptSource =
            serde_json::from_str(r#"{"type":"glob","pattern":"${WORKSPACE}/.aether/rules/*.md","optional":true}"#)
                .unwrap();
        assert_eq!(source, PromptSource::glob("${WORKSPACE}/.aether/rules/*.md").optional());
    }
}