Skip to main content

aether_cli/show_prompt/
run.rs

1use std::collections::BTreeMap;
2
3use super::PromptArgs;
4use crate::error::CliError;
5use crate::resolve::resolve_agent_spec;
6use crate::runtime::RuntimeBuilder;
7use aether_core::agent_spec::McpJsonFileRef;
8use aether_core::core::Prompt;
9use aether_project::load_agent_catalog;
10use llm::ToolDefinition;
11use serde_json::Value;
12
13pub async fn run_prompt(args: PromptArgs) -> Result<(), CliError> {
14    let cwd = args.cwd.canonicalize().map_err(CliError::IoError)?;
15    let catalog = load_agent_catalog(&cwd).map_err(|e| CliError::AgentError(e.to_string()))?;
16    let spec = resolve_agent_spec(&catalog, args.agent.as_deref(), &cwd)?;
17
18    let mcp_refs = args.mcp_configs.into_iter().map(McpJsonFileRef::direct).collect();
19    let info = RuntimeBuilder::from_spec(cwd, spec).mcp_configs(mcp_refs).build_prompt_info().await?;
20
21    let system_prompt = build_prompt(&info.spec.prompts, args.system_prompt.as_deref()).await?;
22    let tools_output = build_tools(&info.tool_definitions);
23
24    println!("{system_prompt}");
25
26    if !tools_output.is_empty() {
27        println!();
28        println!("--- Tools ({} tools) ---", info.tool_definitions.len());
29        println!();
30        println!("{tools_output}");
31    }
32
33    println!();
34    println!("{}", format_stats(system_prompt.len(), tools_output.len(), info.tool_definitions.len()));
35
36    Ok(())
37}
38
39pub async fn build_prompt(prompts: &[Prompt], custom: Option<&str>) -> Result<String, CliError> {
40    let mut prompts = prompts.to_vec();
41    if let Some(custom) = custom {
42        prompts.push(Prompt::text(custom));
43    }
44    Prompt::build_all(&prompts).await.map_err(|e| CliError::AgentError(e.to_string()))
45}
46
47pub fn build_tools(tools: &[ToolDefinition]) -> String {
48    if tools.is_empty() {
49        return String::new();
50    }
51
52    let mut grouped: BTreeMap<&str, Vec<Value>> = BTreeMap::new();
53    for tool in tools {
54        let server = tool.server.as_deref().unwrap_or("(built-in)");
55        let input_schema = serde_json::from_str::<Value>(&tool.parameters).unwrap_or(Value::Null);
56        let entry = serde_json::json!({
57            "name": tool.name,
58            "description": tool.description,
59            "input_schema": input_schema,
60        });
61        grouped.entry(server).or_default().push(entry);
62    }
63
64    let mut sections = Vec::new();
65    for (server, entries) in &grouped {
66        let json = serde_json::to_string_pretty(entries).unwrap_or_default();
67        sections.push(format!("Server: {server}\n{json}"));
68    }
69
70    sections.join("\n\n")
71}
72
/// Formats the size summary printed after the prompt and tool listing.
///
/// The output is a `---` separator followed by four rows. Each row is an 18-column
/// label with the value right-aligned in 8 columns; wider values are printed in full,
/// not truncated. The token estimate is `(prompt_chars + tool_schema_chars) / 4`
/// using integer division, so it rounds down.
pub fn format_stats(prompt_chars: usize, tool_schema_chars: usize, tool_count: usize) -> String {
    let approx_tokens = (prompt_chars + tool_schema_chars) / 4;
    let rows: [(&str, usize); 4] = [
        ("Prompt chars:     ", prompt_chars),
        ("Tool schema chars:", tool_schema_chars),
        ("Est. tokens:     ~", approx_tokens),
        ("MCP tools:        ", tool_count),
    ];

    let mut lines = vec![String::from("---")];
    lines.extend(rows.iter().map(|(label, value)| format!("{label}{value:>8}")));
    lines.join("\n")
}
83
#[cfg(test)]
mod tests {
    use super::*;

    fn tool(name: &str, desc: &str, params: &str, server: Option<&str>) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: desc.to_string(),
            parameters: params.to_string(),
            server: server.map(String::from),
        }
    }

    /// Splits `build_tools` output into `(server, entries)` pairs, one per `Server:` section,
    /// parsing each section's JSON array so tests can check group membership precisely
    /// instead of only searching the whole output for substrings.
    fn parse_sections(output: &str) -> Vec<(String, Vec<Value>)> {
        output
            .split("\n\n")
            .map(|section| {
                let (header, json) = section.split_once('\n').expect("section has a header line");
                let server = header.strip_prefix("Server: ").expect("header starts with `Server: `");
                let entries: Vec<Value> =
                    serde_json::from_str(json).expect("section body is a JSON array");
                (server.to_string(), entries)
            })
            .collect()
    }

    /// Tool names of one parsed section, in output order.
    fn names(entries: &[Value]) -> Vec<&str> {
        entries
            .iter()
            .map(|entry| entry["name"].as_str().expect("tool name is a string"))
            .collect()
    }

    #[test]
    fn format_stats_computes_token_estimate() {
        let output = format_stats(12000, 8500, 14);
        assert_eq!(
            output,
            "---\n\
             Prompt chars:        12000\n\
             Tool schema chars:    8500\n\
             Est. tokens:     ~    5125\n\
             MCP tools:              14"
        );
    }

    #[test]
    fn format_stats_handles_zero() {
        let output = format_stats(0, 0, 0);
        assert_eq!(
            output,
            "---\n\
             Prompt chars:            0\n\
             Tool schema chars:       0\n\
             Est. tokens:     ~       0\n\
             MCP tools:               0"
        );
    }

    #[test]
    fn format_stats_handles_small_values() {
        let output = format_stats(3, 0, 1);
        assert_eq!(
            output,
            "---\n\
             Prompt chars:            3\n\
             Tool schema chars:       0\n\
             Est. tokens:     ~       0\n\
             MCP tools:               1"
        );
    }

    #[test]
    fn build_tools_groups_by_server() {
        let tools = vec![
            tool("fs_read", "Read a file", r#"{"type":"object"}"#, Some("filesystem")),
            tool("git_log", "Show log", r#"{"type":"object"}"#, Some("git")),
            tool("fs_write", "Write a file", r#"{"type":"object"}"#, Some("filesystem")),
        ];
        let sections = parse_sections(&build_tools(&tools));
        // BTreeMap sorts: filesystem < git
        assert_eq!(sections.len(), 2);
        assert_eq!(sections[0].0, "filesystem");
        assert_eq!(sections[1].0, "git");
        // Both filesystem tools land in the filesystem group (in input order);
        // the git group holds only its own tool.
        assert_eq!(names(&sections[0].1), ["fs_read", "fs_write"]);
        assert_eq!(names(&sections[1].1), ["git_log"]);
    }

    #[test]
    fn build_tools_handles_no_server() {
        let tools = vec![tool("builtin_tool", "A built-in", r#"{"type":"object"}"#, None)];
        let output = build_tools(&tools);
        assert!(output.contains("Server: (built-in)"));
        assert!(output.contains("builtin_tool"));
    }

    #[test]
    fn build_tools_produces_api_format() {
        let tools = vec![tool("my_tool", "Does stuff", r#"{"type":"object","properties":{}}"#, Some("test"))];
        let output = build_tools(&tools);
        // Strip "Server: test\n" prefix to get the JSON
        let json_start = output.find('[').unwrap();
        let parsed: Vec<Value> = serde_json::from_str(&output[json_start..]).unwrap();
        assert_eq!(parsed.len(), 1);
        let entry = &parsed[0];
        assert_eq!(entry["name"], "my_tool");
        assert_eq!(entry["description"], "Does stuff");
        assert!(entry["input_schema"].is_object());
    }

    #[test]
    fn build_tools_empty() {
        assert_eq!(build_tools(&[]), "");
    }

    #[test]
    fn build_tools_malformed_params() {
        let tools = vec![tool("bad_tool", "Broken params", "not valid json", Some("srv"))];
        let sections = parse_sections(&build_tools(&tools));
        assert_eq!(sections.len(), 1);
        assert_eq!(sections[0].0, "srv");
        let entry = &sections[0].1[0];
        assert_eq!(entry["name"], "bad_tool");
        // Unparseable parameters degrade to a null schema rather than failing the listing.
        assert!(entry["input_schema"].is_null());
    }
}