1use super::McpError;
2use super::mcp_client::McpClient;
3use super::naming::split_on_server_name;
4use llm::ToolDefinition;
5use rmcp::{RoleClient, service::RunningService};
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use serde_json::{Map, Value};
9use std::collections::HashSet;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use tokio::fs::{create_dir_all, remove_dir_all, write};
13
14#[derive(Debug)]
16pub struct ResolvedCall {
17 pub server: String,
18 pub tool: String,
19 pub arguments: Option<Map<String, Value>>,
20}
21
/// Fronts a group of nested MCP servers behind a single `call_tool` tool.
///
/// Rather than exposing each nested tool directly, the nested servers' tool
/// definitions are written as JSON files below [`ToolProxy::tool_dir`] for the model
/// to browse, and are then invoked through the proxy's `call_tool`.
pub struct ToolProxy {
    /// Proxy name; also the namespace of its `call_tool` tool.
    name: String,
    /// Names of the nested servers reachable through this proxy.
    members: HashSet<String>,
    /// Root directory holding one subdirectory of tool definitions per member server.
    tool_dir: PathBuf,
    /// Model-facing instructions, rendered once when the proxy is constructed.
    instructions: String,
}
32
33#[derive(Deserialize, JsonSchema)]
35struct ProxyCallArgs {
36 server: String,
38 tool: String,
40 arguments: Option<Map<String, Value>>,
42}
43
44impl ToolProxy {
45 pub fn new(
46 name: String,
47 members: HashSet<String>,
48 tool_dir: PathBuf,
49 server_descriptions: &[(String, String)],
50 ) -> Self {
51 let instructions = Self::build_instructions(&tool_dir, server_descriptions);
52 Self { name, members, tool_dir, instructions }
53 }
54
55 pub fn name(&self) -> &str {
56 &self.name
57 }
58
59 pub fn members(&self) -> &HashSet<String> {
60 &self.members
61 }
62
63 pub fn contains_server(&self, server_name: &str) -> bool {
65 self.members.contains(server_name)
66 }
67
68 pub fn is_call_tool(&self, namespaced_tool_name: &str) -> bool {
70 split_on_server_name(namespaced_tool_name)
71 .is_some_and(|(server, tool)| tool == "call_tool" && server == self.name)
72 }
73
74 pub fn resolve_call(&self, arguments_json: &str) -> super::Result<ResolvedCall> {
76 let args: ProxyCallArgs = serde_json::from_str(arguments_json)?;
77 if !self.contains_server(&args.server) {
78 return Err(McpError::ServerNotFound(format!(
79 "Server '{}' is not part of proxy '{}'",
80 args.server, self.name
81 )));
82 }
83 Ok(ResolvedCall { server: args.server, tool: args.tool, arguments: args.arguments })
84 }
85
86 pub fn instructions(&self) -> &str {
87 &self.instructions
88 }
89
90 pub fn tool_dir(&self) -> &Path {
91 &self.tool_dir
92 }
93
94 pub fn add_member(&mut self, server_name: String) {
96 self.members.insert(server_name);
97 }
98
99 pub fn dir(name: &str) -> Result<PathBuf, McpError> {
103 let base = super::aether_home().ok_or_else(|| McpError::Other("Home directory not set".into()))?;
104 Ok(base.join("tool-proxy").join(name))
105 }
106
107 pub async fn clean_dir(tool_dir: &Path) -> Result<(), McpError> {
109 if tool_dir.exists() {
110 remove_dir_all(tool_dir)
111 .await
112 .map_err(|e| McpError::Other(format!("Failed to clean tool-proxy dir: {e}")))?;
113 }
114 Ok(())
115 }
116
117 pub fn call_tool_schema() -> Arc<Map<String, Value>> {
119 let schema = schemars::schema_for!(ProxyCallArgs);
120 let value = serde_json::to_value(schema).expect("schema serialization cannot fail");
121 Arc::new(value.as_object().expect("schema is always an object").clone())
122 }
123
124 pub fn call_tool_definition(proxy_name: &str) -> ToolDefinition {
126 let schema = Self::call_tool_schema();
127 let namespaced_name = format!("{proxy_name}__call_tool");
128 ToolDefinition {
129 name: namespaced_name,
130 description: "Execute a tool on a nested MCP server. Browse the tool-proxy directory to discover available tools first.".to_string(),
131 parameters: Value::Object((*schema).clone()).to_string(),
132 server: Some(proxy_name.to_string()),
133 }
134 }
135
136 pub async fn write_tools_to_dir(
139 server_name: &str,
140 client: &RunningService<RoleClient, McpClient>,
141 tool_dir: &Path,
142 ) -> Result<(), McpError> {
143 let tools_response = client.list_tools(None).await.map_err(|e| {
144 McpError::ToolDiscoveryFailed(format!("Failed to list tools for nested server '{server_name}': {e}"))
145 })?;
146
147 Self::write_tool_entries_to_dir(server_name, &tools_response.tools, tool_dir).await
148 }
149
150 pub(super) async fn write_tool_entries_to_dir(
152 server_name: &str,
153 tools: &[rmcp::model::Tool],
154 tool_dir: &Path,
155 ) -> Result<(), McpError> {
156 let server_dir = tool_dir.join(server_name);
157 if server_dir.exists() {
158 remove_dir_all(&server_dir).await?;
159 }
160 create_dir_all(&server_dir).await?;
161
162 for tool in tools {
163 let entry = ToolFileEntry {
164 name: tool.name.to_string(),
165 description: tool.description.clone().unwrap_or_default().to_string(),
166 server: server_name.to_string(),
167 parameters: Value::Object((*tool.input_schema).clone()),
168 };
169
170 let file_path = server_dir.join(format!("{}.json", tool.name));
171 let json = serde_json::to_string_pretty(&entry)?;
172 write(&file_path, json).await?;
173 }
174
175 Ok(())
176 }
177
178 pub fn extract_server_description(client: &RunningService<RoleClient, McpClient>, server_name: &str) -> String {
182 client
183 .peer_info()
184 .and_then(|info| info.server_info.description.as_deref().filter(|s| !s.is_empty()).map(ToString::to_string))
185 .unwrap_or_else(|| server_name.to_string())
186 }
187
188 pub(super) fn build_instructions(tool_dir: &Path, server_descriptions: &[(String, String)]) -> String {
190 use std::fmt::Write;
191
192 let mut instructions = format!(
193 "You are connected to a set of MCP servers, whose tools are available at `{tool_dir}`.\n\
194 Each subdirectory in `{tool_dir}` represents a MCP server you're connected. And each subdir contains tool definitions in the form of JSON files.\n\
195 Browse or grep the directory to discover tools, then use `call_tool` to execute them.",
196 tool_dir = tool_dir.display()
197 );
198
199 if !server_descriptions.is_empty() {
200 instructions.push_str("\n\n## Connected Servers\n");
201 for (name, desc) in server_descriptions {
202 let _ = writeln!(instructions, "- **{name}**: {desc}");
203 }
204 }
205
206 instructions
207 }
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct ToolFileEntry {
213 pub name: String,
214 pub description: String,
215 pub server: String,
216 pub parameters: Value,
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn tool_file_entry_serialization() {
        let entry = ToolFileEntry {
            name: "create_issue".to_string(),
            description: "Create a GitHub issue".to_string(),
            server: "github".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "repo": { "type": "string" },
                    "title": { "type": "string" }
                },
                "required": ["repo", "title"]
            }),
        };

        let json_str = serde_json::to_string_pretty(&entry).unwrap();
        let deserialized: ToolFileEntry = serde_json::from_str(&json_str).unwrap();

        assert_eq!(deserialized.name, "create_issue");
        assert_eq!(deserialized.server, "github");
        assert_eq!(deserialized.description, "Create a GitHub issue");
    }

    #[test]
    fn call_tool_schema_is_valid() {
        let schema = ToolProxy::call_tool_schema();
        assert_eq!(schema.get("type").unwrap(), "object");

        let properties = schema.get("properties").unwrap().as_object().unwrap();
        assert!(properties.contains_key("server"));
        assert!(properties.contains_key("tool"));
        assert!(properties.contains_key("arguments"));

        let required = schema.get("required").unwrap().as_array().unwrap();
        assert_eq!(required.len(), 2);
        let required_names: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(required_names.contains(&"server"));
        assert!(required_names.contains(&"tool"));
    }

    #[test]
    fn tool_proxy_dir_appends_correct_suffix() {
        let dir = ToolProxy::dir("proxy").unwrap();
        assert!(
            dir.ends_with("tool-proxy/proxy"),
            "Expected path to end with tool-proxy/proxy, got: {}",
            dir.display()
        );
    }

    /// Exercises the real writer (not just serde) and checks the on-disk file round-trips.
    #[tokio::test]
    async fn write_and_read_tool_files() {
        let tmp = tempfile::tempdir().unwrap();
        let tool_dir = tmp.path().to_path_buf();

        let tools: Vec<rmcp::model::Tool> =
            vec![rmcp::model::Tool::new("my_tool", "Does stuff", Arc::new(serde_json::Map::new()))];
        ToolProxy::write_tool_entries_to_dir("test-server", &tools, &tool_dir).await.unwrap();

        let file_path = tool_dir.join("test-server").join("my_tool.json");
        let contents = std::fs::read_to_string(&file_path).unwrap();
        let parsed: ToolFileEntry = serde_json::from_str(&contents).unwrap();
        assert_eq!(parsed.name, "my_tool");
        assert_eq!(parsed.server, "test-server");
        assert_eq!(parsed.description, "Does stuff");
        assert_eq!(parsed.parameters, json!({}));
    }

    #[test]
    fn call_tool_definition_has_correct_name_and_server() {
        let def = ToolProxy::call_tool_definition("myproxy");
        assert_eq!(def.name, "myproxy__call_tool");
        assert_eq!(def.server, Some("myproxy".to_string()));
        assert!(def.description.contains("Execute a tool"));
    }

    #[test]
    fn build_proxy_instructions_includes_tool_dir_and_servers() {
        let tool_dir = std::path::Path::new("/tmp/tool-proxy/test");
        let descriptions =
            vec![("math".to_string(), "Math tools".to_string()), ("git".to_string(), "Git tools".to_string())];
        let instr = ToolProxy::build_instructions(tool_dir, &descriptions);
        assert!(instr.contains("/tmp/tool-proxy/test"));
        assert!(instr.contains("call_tool"));
        assert!(instr.contains("## Connected Servers"));
        assert!(instr.contains("**math**"));
        assert!(instr.contains("**git**"));
    }

    #[tokio::test]
    async fn write_tool_entries_to_dir_removes_stale_files() {
        let tmp = tempfile::tempdir().unwrap();
        let tool_dir = tmp.path().to_path_buf();
        let server_dir = tool_dir.join("my-server");
        std::fs::create_dir_all(&server_dir).unwrap();

        let old_entry = ToolFileEntry {
            name: "old_tool".to_string(),
            description: "Old tool".to_string(),
            server: "my-server".to_string(),
            parameters: json!({"type": "object", "properties": {}}),
        };
        std::fs::write(server_dir.join("old_tool.json"), serde_json::to_string_pretty(&old_entry).unwrap()).unwrap();
        assert!(server_dir.join("old_tool.json").exists());

        let tools: Vec<rmcp::model::Tool> =
            vec![rmcp::model::Tool::new("new_tool", "New tool", Arc::new(serde_json::Map::new()))];
        ToolProxy::write_tool_entries_to_dir("my-server", &tools, &tool_dir).await.unwrap();

        assert!(!server_dir.join("old_tool.json").exists(), "stale file should be removed");
        assert!(server_dir.join("new_tool.json").exists(), "new file should be written");
    }

    #[tokio::test]
    async fn clean_dir_tolerates_missing_dir_and_removes_existing_one() {
        let tmp = tempfile::tempdir().unwrap();
        let tool_dir = tmp.path().join("proxy");

        // Nothing there yet: must succeed without doing anything.
        ToolProxy::clean_dir(&tool_dir).await.unwrap();

        std::fs::create_dir_all(tool_dir.join("server")).unwrap();
        ToolProxy::clean_dir(&tool_dir).await.unwrap();
        assert!(!tool_dir.exists(), "existing tool dir should be removed");
    }

    fn make_proxy(members: &[&str]) -> ToolProxy {
        let members: HashSet<String> = members.iter().map(std::string::ToString::to_string).collect();
        ToolProxy::new(
            "myproxy".to_string(),
            members,
            PathBuf::from("/tmp/tool-proxy/myproxy"),
            &[("math".to_string(), "Math tools".to_string())],
        )
    }

    #[test]
    fn tool_proxy_contains_server() {
        let proxy = make_proxy(&["math", "git"]);
        assert!(proxy.contains_server("math"));
        assert!(proxy.contains_server("git"));
        assert!(!proxy.contains_server("unknown"));
    }

    #[test]
    fn tool_proxy_is_call_tool() {
        let proxy = make_proxy(&["math"]);
        assert!(proxy.is_call_tool("myproxy__call_tool"));
        assert!(!proxy.is_call_tool("myproxy__other_tool"));
        assert!(!proxy.is_call_tool("other__call_tool"));
        assert!(!proxy.is_call_tool("invalid"));
    }

    #[test]
    fn tool_proxy_resolve_call_success() {
        let proxy = make_proxy(&["math"]);
        let json = r#"{"server":"math","tool":"add","arguments":{"a":1,"b":2}}"#;
        let call = proxy.resolve_call(json).unwrap();
        assert_eq!(call.server, "math");
        assert_eq!(call.tool, "add");
        assert!(call.arguments.is_some());
        assert_eq!(call.arguments.unwrap().get("a").unwrap(), 1);
    }

    #[test]
    fn tool_proxy_resolve_call_without_arguments() {
        let proxy = make_proxy(&["math"]);
        let call = proxy.resolve_call(r#"{"server":"math","tool":"add"}"#).unwrap();
        assert_eq!(call.tool, "add");
        assert!(call.arguments.is_none());
    }

    #[test]
    fn tool_proxy_resolve_call_rejects_malformed_input() {
        let proxy = make_proxy(&["math"]);
        assert!(proxy.resolve_call("not json").is_err());
        assert!(proxy.resolve_call(r#"{"server":"math"}"#).is_err(), "missing `tool` must be rejected");
    }

    #[test]
    fn tool_proxy_resolve_call_unknown_server() {
        let proxy = make_proxy(&["math"]);
        let json = r#"{"server":"unknown","tool":"add","arguments":{}}"#;
        let err = proxy.resolve_call(json).unwrap_err();
        assert!(err.to_string().contains("not part of proxy"));
    }

    #[test]
    fn tool_proxy_accessors() {
        let proxy = make_proxy(&["math"]);
        assert_eq!(proxy.name(), "myproxy");
        assert_eq!(proxy.tool_dir(), Path::new("/tmp/tool-proxy/myproxy"));
        assert!(proxy.instructions().contains("call_tool"));
    }

    #[test]
    fn tool_proxy_add_member() {
        let mut proxy = make_proxy(&["math"]);
        assert!(!proxy.contains_server("git"));
        proxy.add_member("git".to_string());
        assert!(proxy.contains_server("git"));
    }
}
405}