1use crate::core::{AgentError, Result};
2use glob::glob;
3use mcp_utils::client::ServerInstructions;
4use schemars::{JsonSchema, Schema, SchemaGenerator};
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use std::collections::HashMap;
7use std::env;
8use std::path::{Path, PathBuf};
9use thiserror::Error;
10use tokio::fs;
11use tracing::warn;
12use utils::shell_expander::ShellExpander;
13use utils::substitution::substitute_parameters;
14
/// One piece of prompt content; turned into a string by [`Prompt::build`].
#[derive(Debug, Clone)]
pub enum Prompt {
    /// Literal text, returned as-is by `build`.
    Text(String),
    /// Content read from a single file.
    File {
        /// Path of the prompt file.
        path: String,
        /// Optional parameter values passed to `substitute_parameters` on the file contents.
        args: Option<HashMap<String, String>>,
        /// Working directory used for shell expansion of the file contents; `None` falls back
        /// to the process's current directory.
        cwd: Option<PathBuf>,
    },
    /// The concatenated contents of every regular file matched by a set of glob patterns.
    PromptGlobs {
        /// Glob patterns; relative patterns are resolved against `cwd`.
        patterns: Vec<String>,
        /// Base directory for relative patterns and working directory for shell expansion.
        cwd: PathBuf,
    },
    /// Instructions reported by connected MCP servers, rendered as one tagged section.
    McpInstructions(Vec<ServerInstructions>),
}
31
/// Authored description of where prompt content comes from.
///
/// Deserializes from either a bare string (treated as a file path) or a typed object with a
/// `type` field of `"text"`, `"file"` or `"glob"`. `File` serializes back to the bare-string
/// form; `Text` and `Glob` serialize to the typed object form.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptSource {
    /// Inline prompt text.
    Text { text: String },
    /// Path of a single prompt file.
    File { path: String },
    /// Glob pattern selecting one or more prompt files.
    Glob { pattern: String },
}
42
/// Raw input accepted when deserializing a [`PromptSource`].
///
/// Untagged: serde tries the bare `Path` string first, then the typed `Object` form.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum PromptSourceInput {
    /// Bare string, interpreted as a file path.
    Path(String),
    /// Typed object form.
    Object(PromptSourceObject),
}
49
/// Typed object form of a prompt source, discriminated by a `type` field.
///
/// Variant names are camelCased (`"text"`, `"file"`, `"glob"`), and unknown fields are
/// rejected. This type also supplies the object half of `PromptSource`'s JSON schema.
#[derive(schemars::JsonSchema, serde::Deserialize, serde::Serialize)]
#[serde(tag = "type", rename_all = "camelCase", deny_unknown_fields)]
enum PromptSourceObject {
    Text { text: String },
    File { path: String },
    Glob { pattern: String },
}
57
58impl<'de> Deserialize<'de> for PromptSource {
59 fn deserialize<T: Deserializer<'de>>(deserializer: T) -> std::result::Result<Self, T::Error> {
60 match serde::Deserialize::deserialize(deserializer)? {
61 PromptSourceInput::Path(path) | PromptSourceInput::Object(PromptSourceObject::File { path }) => {
62 Ok(Self::File { path })
63 }
64 PromptSourceInput::Object(PromptSourceObject::Text { text }) => Ok(Self::Text { text }),
65 PromptSourceInput::Object(PromptSourceObject::Glob { pattern }) => Ok(Self::Glob { pattern }),
66 }
67 }
68}
69
70impl Serialize for PromptSource {
71 fn serialize<T: Serializer>(&self, serializer: T) -> std::result::Result<T::Ok, T::Error> {
72 match self {
73 Self::File { path } => serializer.serialize_str(path),
74 Self::Text { text } => Serialize::serialize(&PromptSourceObject::Text { text: text.clone() }, serializer),
75 Self::Glob { pattern } => {
76 Serialize::serialize(&PromptSourceObject::Glob { pattern: pattern.clone() }, serializer)
77 }
78 }
79 }
80}
81
82impl JsonSchema for PromptSource {
83 fn schema_name() -> std::borrow::Cow<'static, str> {
84 "PromptSource".into()
85 }
86
87 fn json_schema(generator: &mut SchemaGenerator) -> Schema {
88 let object_schema = generator.subschema_for::<PromptSourceObject>().to_value();
89 Schema::try_from(serde_json::json!({
90 "description": "Authored description of a prompt source — either a file path string or a typed text, file, or glob object.",
91 "oneOf": [
92 { "type": "string" },
93 object_schema
94 ]
95 }))
96 .expect("prompt source schema must be valid")
97 }
98}
99
100impl PromptSource {
101 pub fn file(path: impl Into<String>) -> Self {
102 Self::File { path: path.into() }
103 }
104
105 pub fn path(&self) -> Option<&str> {
106 match self {
107 Self::File { path } => Some(path.as_str()),
108 Self::Glob { pattern } => Some(pattern.as_str()),
109 Self::Text { .. } => None,
110 }
111 }
112}
113
114impl From<&str> for PromptSource {
115 fn from(value: &str) -> Self {
116 Self::file(value)
117 }
118}
119
120impl From<String> for PromptSource {
121 fn from(value: String) -> Self {
122 Self::file(value)
123 }
124}
125
/// Validation failure produced while turning [`PromptSource`]s into prompts.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum PromptSourceError {
    /// The glob pattern could not be parsed; `error` holds the parser's message.
    #[error("Invalid glob pattern '{pattern}': {error}")]
    InvalidGlobPattern { pattern: String, error: String },

    /// The file does not exist, or the glob matched no regular file.
    #[error("Prompt entry '{pattern}' resolves to no files")]
    ZeroMatch { pattern: String },
}
137
138impl Prompt {
139 pub fn text(str: &str) -> Self {
140 Self::Text(str.to_string())
141 }
142
143 pub fn file(path: &str) -> Self {
144 Self::File { path: path.to_string(), args: None, cwd: None }
145 }
146
147 pub fn file_with_args(path: &str, args: HashMap<String, String>) -> Self {
148 Self::File { path: path.to_string(), args: Some(args), cwd: None }
149 }
150
151 pub fn from_globs(patterns: Vec<String>, cwd: PathBuf) -> Self {
152 Self::PromptGlobs { patterns, cwd }
153 }
154
155 pub fn from_sources(
160 project_root: &Path,
161 sources: &[PromptSource],
162 ) -> std::result::Result<Vec<Prompt>, PromptSourceError> {
163 sources
164 .iter()
165 .map(|source| match source {
166 PromptSource::Text { text } => Ok(Prompt::text(text)),
167 PromptSource::File { path } => validate_prompt_file(project_root, path)
168 .map(|()| Prompt::file(path).with_cwd(project_root.to_path_buf())),
169 PromptSource::Glob { pattern } => validate_prompt_glob(project_root, pattern)
170 .map(|()| Prompt::from_globs(vec![pattern.clone()], project_root.to_path_buf())),
171 })
172 .collect()
173 }
174
175 pub fn with_cwd(self, cwd: PathBuf) -> Self {
176 match self {
177 Self::File { path, args, .. } => Self::File { path, args, cwd: Some(cwd) },
178 Self::PromptGlobs { patterns, .. } => Self::PromptGlobs { patterns, cwd },
179 Self::Text(_) | Self::McpInstructions(_) => self,
180 }
181 }
182
183 pub fn mcp_instructions(instructions: Vec<ServerInstructions>) -> Self {
184 Self::McpInstructions(instructions)
185 }
186
187 pub async fn build(&self) -> Result<String> {
189 match self {
190 Prompt::Text(text) => Ok(text.clone()),
191 Prompt::File { path, args, cwd } => {
192 let content = Self::resolve_file(&PathBuf::from(path)).await?;
193 let substituted = substitute_parameters(&content, args);
194 let expander = ShellExpander::new();
195 Self::expand_builtins(&substituted, cwd.as_deref(), &expander).await
196 }
197 Prompt::PromptGlobs { patterns, cwd } => Self::resolve_prompt_globs(patterns, cwd).await,
198 Prompt::McpInstructions(instructions) => Ok(format_mcp_instructions(instructions)),
199 }
200 }
201
202 pub async fn build_all(prompts: &[Prompt]) -> Result<String> {
204 let mut parts = Vec::with_capacity(prompts.len());
205 for p in prompts {
206 let part = p.build().await?;
207 if !part.is_empty() {
208 parts.push(part);
209 }
210 }
211 Ok(parts.join("\n\n"))
212 }
213
214 async fn resolve_file(path: &Path) -> Result<String> {
215 fs::read_to_string(path)
216 .await
217 .map_err(|e| AgentError::IoError(format!("Failed to read file '{}': {e}", path.display())))
218 }
219
220 async fn resolve_prompt_globs(patterns: &[String], cwd: &Path) -> Result<String> {
221 let mut contents = Vec::new();
222 let expander = ShellExpander::new();
223
224 for pattern in patterns {
225 let full_pattern = if Path::new(pattern).is_absolute() {
226 pattern.clone()
227 } else {
228 cwd.join(pattern).to_string_lossy().to_string()
229 };
230
231 let paths = glob(&full_pattern)
232 .map_err(|e| AgentError::IoError(format!("Invalid glob pattern '{pattern}': {e}")))?;
233
234 let mut matched: Vec<PathBuf> = paths.filter_map(std::result::Result::ok).collect();
235 matched.sort();
236
237 for path in matched {
238 if path.is_file() {
239 match fs::read_to_string(&path).await {
240 Ok(content) => {
241 let resolved = Self::expand_builtins(&content, Some(cwd), &expander).await?;
242 contents.push(resolved);
243 }
244 Err(e) => {
245 warn!("Failed to read prompt file '{}': {e}", path.display());
246 }
247 }
248 }
249 }
250 }
251
252 Ok(contents.join("\n\n"))
253 }
254
255 async fn expand_builtins(content: &str, cwd: Option<&Path>, expander: &ShellExpander) -> Result<String> {
260 let cwd = match cwd {
261 Some(dir) => dir.to_path_buf(),
262 None => {
263 env::current_dir().map_err(|e| AgentError::IoError(format!("Failed to get current directory: {e}")))?
264 }
265 };
266 Ok(expander.expand(content, &cwd).await)
267 }
268}
269
270fn validate_prompt_file(project_root: &Path, path: &str) -> std::result::Result<(), PromptSourceError> {
271 let full_path = project_root.join(path);
272 if full_path.is_file() { Ok(()) } else { Err(PromptSourceError::ZeroMatch { pattern: path.to_string() }) }
273}
274
275fn validate_prompt_glob(project_root: &Path, pattern: &str) -> std::result::Result<(), PromptSourceError> {
276 let full_pattern = if Path::new(pattern).is_absolute() {
277 pattern.to_string()
278 } else {
279 project_root.join(pattern).to_string_lossy().to_string()
280 };
281
282 let has_file_match = glob(&full_pattern)
283 .map_err(|e| PromptSourceError::InvalidGlobPattern { pattern: pattern.to_string(), error: e.to_string() })?
284 .filter_map(std::result::Result::ok)
285 .any(|path| path.is_file());
286
287 if has_file_match { Ok(()) } else { Err(PromptSourceError::ZeroMatch { pattern: pattern.to_string() }) }
288}
289
290fn format_mcp_instructions(instructions: &[ServerInstructions]) -> String {
292 if instructions.is_empty() {
293 return String::new();
294 }
295
296 let mut parts = vec!["# MCP Server Instructions\n".to_string()];
297 parts.push("You are connected to the following MCP servers:\n".to_string());
298
299 for instr in instructions {
300 parts.push(format!("<mcp-server name=\"{}\">\n{}\n</mcp-server>\n", instr.server_name, instr.instructions));
301 }
302
303 parts.join("\n")
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Prompt::build / build_all ----

    #[tokio::test]
    async fn build_text_prompt() {
        let prompt = Prompt::text("Hello, world!");
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "Hello, world!");
    }

    #[tokio::test]
    async fn build_all_concatenates_prompts() {
        let prompts = vec![Prompt::text("Part one"), Prompt::text("Part two")];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert_eq!(result, "Part one\n\nPart two");
    }

    // ---- glob prompts ----

    #[tokio::test]
    async fn prompt_globs_resolves_single_file() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("AGENTS.md"), "# Instructions\nBe helpful").unwrap();

        let prompt = Prompt::from_globs(vec!["AGENTS.md".to_string()], dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "# Instructions\nBe helpful");
    }

    #[tokio::test]
    async fn prompt_globs_resolves_glob_pattern() {
        let dir = tempfile::tempdir().unwrap();
        let rules_dir = dir.path().join(".aether/rules");
        std::fs::create_dir_all(&rules_dir).unwrap();
        std::fs::write(rules_dir.join("a-coding.md"), "Use Rust").unwrap();
        std::fs::write(rules_dir.join("b-testing.md"), "Write tests").unwrap();

        let prompt = Prompt::from_globs(vec![".aether/rules/*.md".to_string()], dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.contains("Use Rust"));
        assert!(result.contains("Write tests"));
    }

    #[tokio::test]
    async fn prompt_globs_returns_empty_for_no_matches() {
        let dir = tempfile::tempdir().unwrap();

        let prompt = Prompt::from_globs(vec!["nonexistent*.md".to_string()], dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn prompt_globs_supports_absolute_paths() {
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("rules.md");
        std::fs::write(&file_path, "Absolute rule").unwrap();

        let prompt = Prompt::from_globs(vec![file_path.to_string_lossy().to_string()], PathBuf::from("/tmp"));
        let result = prompt.build().await.unwrap();
        assert_eq!(result, "Absolute rule");
    }

    #[tokio::test]
    async fn prompt_globs_concatenates_multiple_patterns() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("AGENTS.md"), "Agent instructions").unwrap();
        std::fs::write(dir.path().join("SYSTEM.md"), "System prompt").unwrap();

        let prompt =
            Prompt::from_globs(vec!["AGENTS.md".to_string(), "SYSTEM.md".to_string()], dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.contains("Agent instructions"));
        assert!(result.contains("System prompt"));
        assert!(result.contains("\n\n"));
    }

    #[tokio::test]
    async fn build_all_skips_empty_parts() {
        let prompts = vec![Prompt::text("Part one"), Prompt::text(""), Prompt::text("Part two")];
        let result = Prompt::build_all(&prompts).await.unwrap();
        assert_eq!(result, "Part one\n\nPart two");
    }

    // ---- shell expansion ----

    #[tokio::test]
    async fn expand_builtins_no_op_without_marker() {
        let content = "Just some plain content with no directives";
        let expander = ShellExpander::new();
        let result = Prompt::expand_builtins(content, None, &expander).await.unwrap();
        assert_eq!(result, content);
    }

    #[tokio::test]
    async fn expand_builtins_runs_shell_command() {
        let expander = ShellExpander::new();
        let result = Prompt::expand_builtins("branch: !`echo main`", None, &expander).await.unwrap();
        assert_eq!(result, "branch: main");
    }

    #[tokio::test]
    async fn expand_builtins_runs_command_in_cwd() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("sentinel.txt"), "").unwrap();

        let expander = ShellExpander::new();
        let result = Prompt::expand_builtins("files: !`ls`", Some(dir.path()), &expander).await.unwrap();
        assert!(result.contains("sentinel.txt"), "expected sentinel.txt in output: {result}");
    }

    #[tokio::test]
    async fn expand_builtins_handles_multiple_commands() {
        let expander = ShellExpander::new();
        let result = Prompt::expand_builtins("a=!`echo one`, b=!`echo two`", None, &expander).await.unwrap();
        assert_eq!(result, "a=one, b=two");
    }

    #[tokio::test]
    async fn expand_builtins_substitutes_empty_on_failure() {
        let expander = ShellExpander::new();
        let result = Prompt::expand_builtins("before !`exit 1` after", None, &expander).await.unwrap();
        assert_eq!(result, "before after");
    }

    #[tokio::test]
    async fn expand_builtins_trims_trailing_whitespace() {
        let expander = ShellExpander::new();
        let result = Prompt::expand_builtins("!`printf 'hi\\n\\n'`", None, &expander).await.unwrap();
        assert_eq!(result, "hi");
    }

    #[tokio::test]
    async fn prompt_globs_expands_shell_in_file() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("AGENTS.md"), "Instructions\n\nbranch: !`echo main`\n\nRules").unwrap();

        let prompt = Prompt::from_globs(vec!["AGENTS.md".to_string()], dir.path().to_path_buf());
        let result = prompt.build().await.unwrap();
        assert!(result.contains("Instructions"));
        assert!(result.contains("branch: main"));
        assert!(result.contains("Rules"));
        assert!(!result.contains("!`"));
    }

    // ---- PromptSource serde ----

    /// A bare JSON string is shorthand for a file source.
    #[test]
    fn prompt_source_deserializes_bare_string_as_file() {
        let source: PromptSource = serde_json::from_str("\"AGENTS.md\"").unwrap();
        assert_eq!(source, PromptSource::file("AGENTS.md"));
    }

    /// Every kind survives a serialize/deserialize round trip.
    #[test]
    fn prompt_source_round_trips_through_json() {
        let sources = vec![
            PromptSource::file("AGENTS.md"),
            PromptSource::Text { text: "Be concise".to_string() },
            PromptSource::Glob { pattern: ".aether/rules/*.md".to_string() },
        ];
        for source in sources {
            let json = serde_json::to_string(&source).unwrap();
            let parsed: PromptSource = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, source, "round trip through {json}");
        }
    }

    // ---- Prompt::from_sources validation ----

    #[test]
    fn from_sources_rejects_missing_file() {
        let dir = tempfile::tempdir().unwrap();
        let err = Prompt::from_sources(dir.path(), &[PromptSource::file("missing.md")]).unwrap_err();
        assert_eq!(err, PromptSourceError::ZeroMatch { pattern: "missing.md".to_string() });
    }

    #[test]
    fn from_sources_rejects_glob_without_matches() {
        let dir = tempfile::tempdir().unwrap();
        let sources = [PromptSource::Glob { pattern: "*.md".to_string() }];
        let err = Prompt::from_sources(dir.path(), &sources).unwrap_err();
        assert_eq!(err, PromptSourceError::ZeroMatch { pattern: "*.md".to_string() });
    }

    #[tokio::test]
    async fn from_sources_passes_text_through() {
        let dir = tempfile::tempdir().unwrap();
        let prompts =
            Prompt::from_sources(dir.path(), &[PromptSource::Text { text: "Inline".to_string() }]).unwrap();
        assert_eq!(prompts.len(), 1);
        assert_eq!(prompts[0].build().await.unwrap(), "Inline");
    }
}