Skip to main content

mcp_utils/client/
tool_proxy.rs

1use super::McpError;
2use super::mcp_client::McpClient;
3use super::naming::split_on_server_name;
4use llm::ToolDefinition;
5use rmcp::{RoleClient, service::RunningService};
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use serde_json::{Map, Value};
9use std::collections::HashSet;
10use std::path::{Path, PathBuf};
11use std::sync::Arc;
12use tokio::fs::{create_dir_all, remove_dir_all, write};
13
14/// Resolved proxy call returned by [`ToolProxy::resolve_call`].
15#[derive(Debug)]
16pub struct ResolvedCall {
17    pub server: String,
18    pub tool: String,
19    pub arguments: Option<Map<String, Value>>,
20}
21
/// A tool-proxy that wraps multiple servers behind a single `call_tool`.
///
/// Tool definitions of the nested servers are written under `tool_dir` for
/// browsing, and calls to them are routed through the proxy's virtual
/// `call_tool` (see [`ToolProxy::resolve_call`]).
#[derive(Debug)]
pub struct ToolProxy {
    /// Name of the proxy itself; it namespaces the proxy's `call_tool`.
    name: String,
    /// Set of nested server names belonging to this proxy.
    members: HashSet<String>,
    /// Directory where tool description files are written for agent browsing.
    tool_dir: PathBuf,
    /// Synthesized instructions text for the proxy.
    instructions: String,
}
32
33/// Parsed arguments from a proxy `call_tool` invocation.
34#[derive(Deserialize, JsonSchema)]
35struct ProxyCallArgs {
36    /// The server name (directory name in the tool-proxy folder)
37    server: String,
38    /// The tool name (file name without .json)
39    tool: String,
40    /// Arguments to pass to the tool
41    arguments: Option<Map<String, Value>>,
42}
43
44impl ToolProxy {
45    pub fn new(
46        name: String,
47        members: HashSet<String>,
48        tool_dir: PathBuf,
49        server_descriptions: &[(String, String)],
50    ) -> Self {
51        let instructions = Self::build_instructions(&tool_dir, server_descriptions);
52        Self { name, members, tool_dir, instructions }
53    }
54
55    pub fn name(&self) -> &str {
56        &self.name
57    }
58
59    /// Whether `server_name` is a nested member of this proxy.
60    pub fn contains_server(&self, server_name: &str) -> bool {
61        self.members.contains(server_name)
62    }
63
64    /// Whether a namespaced tool name refers to this proxy's `call_tool`.
65    pub fn is_call_tool(&self, namespaced_tool_name: &str) -> bool {
66        split_on_server_name(namespaced_tool_name)
67            .is_some_and(|(server, tool)| tool == "call_tool" && server == self.name)
68    }
69
70    /// Parse and validate a proxy `call_tool` invocation.
71    pub fn resolve_call(&self, arguments_json: &str) -> super::Result<ResolvedCall> {
72        let args: ProxyCallArgs = serde_json::from_str(arguments_json)?;
73        if !self.contains_server(&args.server) {
74            return Err(McpError::ServerNotFound(format!(
75                "Server '{}' is not part of proxy '{}'",
76                args.server, self.name
77            )));
78        }
79        Ok(ResolvedCall { server: args.server, tool: args.tool, arguments: args.arguments })
80    }
81
82    pub fn instructions(&self) -> &str {
83        &self.instructions
84    }
85
86    pub fn tool_dir(&self) -> &Path {
87        &self.tool_dir
88    }
89
90    /// Register a new member server (e.g. after late OAuth registration).
91    pub fn add_member(&mut self, server_name: String) {
92        self.members.insert(server_name);
93    }
94
95    /// Returns the directory for a tool-proxy's tool definitions.
96    ///
97    /// Uses `$AETHER_HOME/tool-proxy/<name>` or `~/.aether/tool-proxy/<name>`.
98    pub fn dir(name: &str) -> Result<PathBuf, McpError> {
99        let base = super::aether_home().ok_or_else(|| McpError::Other("Home directory not set".into()))?;
100        Ok(base.join("tool-proxy").join(name))
101    }
102
103    /// Clean up the tool directory for a proxy, removing all tool files.
104    pub async fn clean_dir(tool_dir: &Path) -> Result<(), McpError> {
105        if tool_dir.exists() {
106            remove_dir_all(tool_dir)
107                .await
108                .map_err(|e| McpError::Other(format!("Failed to clean tool-proxy dir: {e}")))?;
109        }
110        Ok(())
111    }
112
113    /// Build the `call_tool` JSON schema used by the proxy's virtual tool.
114    pub fn call_tool_schema() -> Arc<Map<String, Value>> {
115        let schema = schemars::schema_for!(ProxyCallArgs);
116        let value = serde_json::to_value(schema).expect("schema serialization cannot fail");
117        Arc::new(value.as_object().expect("schema is always an object").clone())
118    }
119
120    /// Build a `ToolDefinition` for the proxy's `call_tool` virtual tool.
121    pub fn call_tool_definition(proxy_name: &str) -> ToolDefinition {
122        let schema = Self::call_tool_schema();
123        let namespaced_name = format!("{proxy_name}__call_tool");
124        ToolDefinition {
125            name: namespaced_name,
126            description: "Execute a tool on a nested MCP server. Browse the tool-proxy directory to discover available tools first.".to_string(),
127            parameters: Value::Object((*schema).clone()).to_string(),
128            server: Some(proxy_name.to_string()),
129        }
130    }
131
132    /// Discover tools from a connected MCP server and write them as JSON files
133    /// to `tool_dir/<server_name>/`.
134    pub async fn write_tools_to_dir(
135        server_name: &str,
136        client: &RunningService<RoleClient, McpClient>,
137        tool_dir: &Path,
138    ) -> Result<(), McpError> {
139        let tools_response = client.list_tools(None).await.map_err(|e| {
140            McpError::ToolDiscoveryFailed(format!("Failed to list tools for nested server '{server_name}': {e}"))
141        })?;
142
143        let server_dir = tool_dir.join(server_name);
144        create_dir_all(&server_dir).await?;
145
146        for tool in &tools_response.tools {
147            let entry = ToolFileEntry {
148                name: tool.name.to_string(),
149                description: tool.description.clone().unwrap_or_default().to_string(),
150                server: server_name.to_string(),
151                parameters: Value::Object((*tool.input_schema).clone()),
152            };
153
154            let file_path = server_dir.join(format!("{}.json", tool.name));
155            let json = serde_json::to_string_pretty(&entry)?;
156            write(&file_path, json).await?;
157        }
158
159        Ok(())
160    }
161
162    /// Extract a one-line description for a nested server from its peer info.
163    ///
164    /// Uses `server_info.description`, falling back to the server name.
165    pub fn extract_server_description(client: &RunningService<RoleClient, McpClient>, server_name: &str) -> String {
166        client
167            .peer_info()
168            .and_then(|info| info.server_info.description.as_deref().filter(|s| !s.is_empty()).map(ToString::to_string))
169            .unwrap_or_else(|| server_name.to_string())
170    }
171
172    /// Build proxy instructions describing the tool directory and connected servers.
173    fn build_instructions(tool_dir: &Path, server_descriptions: &[(String, String)]) -> String {
174        use std::fmt::Write;
175
176        let mut instructions = format!(
177            "You are connected to a set of MCP servers, whose tools are available at `{tool_dir}`.\n\
178             Each subdirectory in `{tool_dir}` represents a MCP server you're connected. And each subdir contains tool definitions in the form of JSON files.\n\
179             Browse or grep the directory to discover tools, then use `call_tool` to execute them.",
180            tool_dir = tool_dir.display()
181        );
182
183        if !server_descriptions.is_empty() {
184            instructions.push_str("\n\n## Connected Servers\n");
185            for (name, desc) in server_descriptions {
186                let _ = writeln!(instructions, "- **{name}**: {desc}");
187            }
188        }
189
190        instructions
191    }
192}
193
194/// A tool definition written to disk for agent browsing.
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct ToolFileEntry {
197    pub name: String,
198    pub description: String,
199    pub server: String,
200    pub parameters: Value,
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn tool_file_entry_serialization() {
        let entry = ToolFileEntry {
            name: "create_issue".to_string(),
            description: "Create a GitHub issue".to_string(),
            server: "github".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "repo": { "type": "string" },
                    "title": { "type": "string" }
                },
                "required": ["repo", "title"]
            }),
        };

        let json_str = serde_json::to_string_pretty(&entry).unwrap();
        let deserialized: ToolFileEntry = serde_json::from_str(&json_str).unwrap();

        assert_eq!(deserialized.name, "create_issue");
        assert_eq!(deserialized.server, "github");
        assert_eq!(deserialized.description, "Create a GitHub issue");
    }

    #[test]
    fn call_tool_schema_is_valid() {
        let schema = ToolProxy::call_tool_schema();
        assert_eq!(schema.get("type").unwrap(), "object");

        let properties = schema.get("properties").unwrap().as_object().unwrap();
        assert!(properties.contains_key("server"));
        assert!(properties.contains_key("tool"));
        assert!(properties.contains_key("arguments"));

        // `server` and `tool` are required; `arguments` is Option so not required
        let required = schema.get("required").unwrap().as_array().unwrap();
        assert_eq!(required.len(), 2);
        let required_names: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(required_names.contains(&"server"));
        assert!(required_names.contains(&"tool"));
    }

    #[test]
    fn tool_proxy_dir_appends_correct_suffix() {
        let dir = ToolProxy::dir("proxy").unwrap();
        assert!(
            dir.ends_with("tool-proxy/proxy"),
            "Expected path to end with tool-proxy/proxy, got: {}",
            dir.display()
        );
    }

    // Exercises the on-disk file format round-trip only; `write_tools_to_dir`
    // itself needs a live MCP client and is not covered here.
    #[test]
    fn write_and_read_tool_files() {
        let tmp = tempfile::tempdir().unwrap();
        let tool_dir = tmp.path().to_path_buf();
        let server_dir = tool_dir.join("test-server");
        std::fs::create_dir_all(&server_dir).unwrap();

        let entry = ToolFileEntry {
            name: "my_tool".to_string(),
            description: "Does stuff".to_string(),
            server: "test-server".to_string(),
            parameters: json!({"type": "object", "properties": {}}),
        };

        let file_path = server_dir.join("my_tool.json");
        let json = serde_json::to_string_pretty(&entry).unwrap();
        std::fs::write(&file_path, &json).unwrap();

        let contents = std::fs::read_to_string(&file_path).unwrap();
        let parsed: ToolFileEntry = serde_json::from_str(&contents).unwrap();
        assert_eq!(parsed.name, "my_tool");
        assert_eq!(parsed.server, "test-server");
    }

    #[test]
    fn call_tool_definition_has_correct_name_and_server() {
        let def = ToolProxy::call_tool_definition("myproxy");
        assert_eq!(def.name, "myproxy__call_tool");
        assert_eq!(def.server, Some("myproxy".to_string()));
        assert!(def.description.contains("Execute a tool"));
    }

    #[test]
    fn build_proxy_instructions_includes_tool_dir_and_servers() {
        let tool_dir = std::path::Path::new("/tmp/tool-proxy/test");
        let descriptions =
            vec![("math".to_string(), "Math tools".to_string()), ("git".to_string(), "Git tools".to_string())];
        let instr = ToolProxy::build_instructions(tool_dir, &descriptions);
        assert!(instr.contains("/tmp/tool-proxy/test"));
        assert!(instr.contains("call_tool"));
        assert!(instr.contains("## Connected Servers"));
        assert!(instr.contains("**math**"));
        assert!(instr.contains("**git**"));
    }

    #[test]
    fn build_proxy_instructions_without_servers_omits_section() {
        let instr = ToolProxy::build_instructions(Path::new("/tmp/tool-proxy/empty"), &[]);
        assert!(instr.contains("/tmp/tool-proxy/empty"));
        assert!(instr.contains("call_tool"));
        assert!(!instr.contains("## Connected Servers"));
    }

    fn make_proxy(members: &[&str]) -> ToolProxy {
        let members: HashSet<String> = members.iter().map(std::string::ToString::to_string).collect();
        ToolProxy::new(
            "myproxy".to_string(),
            members,
            PathBuf::from("/tmp/tool-proxy/myproxy"),
            &[("math".to_string(), "Math tools".to_string())],
        )
    }

    #[test]
    fn tool_proxy_contains_server() {
        let proxy = make_proxy(&["math", "git"]);
        assert!(proxy.contains_server("math"));
        assert!(proxy.contains_server("git"));
        assert!(!proxy.contains_server("unknown"));
    }

    #[test]
    fn tool_proxy_is_call_tool() {
        let proxy = make_proxy(&["math"]);
        assert!(proxy.is_call_tool("myproxy__call_tool"));
        assert!(!proxy.is_call_tool("myproxy__other_tool"));
        assert!(!proxy.is_call_tool("other__call_tool"));
        assert!(!proxy.is_call_tool("invalid"));
    }

    #[test]
    fn tool_proxy_resolve_call_success() {
        let proxy = make_proxy(&["math"]);
        let json = r#"{"server":"math","tool":"add","arguments":{"a":1,"b":2}}"#;
        let call = proxy.resolve_call(json).unwrap();
        assert_eq!(call.server, "math");
        assert_eq!(call.tool, "add");
        assert!(call.arguments.is_some());
        assert_eq!(call.arguments.unwrap().get("a").unwrap(), 1);
    }

    #[test]
    fn tool_proxy_resolve_call_without_arguments() {
        let proxy = make_proxy(&["math"]);
        let call = proxy.resolve_call(r#"{"server":"math","tool":"add"}"#).unwrap();
        assert_eq!(call.server, "math");
        assert_eq!(call.tool, "add");
        assert!(call.arguments.is_none());
    }

    #[test]
    fn tool_proxy_resolve_call_null_arguments() {
        let proxy = make_proxy(&["math"]);
        let call = proxy.resolve_call(r#"{"server":"math","tool":"add","arguments":null}"#).unwrap();
        assert!(call.arguments.is_none());
    }

    #[test]
    fn tool_proxy_resolve_call_rejects_malformed_input() {
        let proxy = make_proxy(&["math"]);
        // Not JSON at all.
        assert!(proxy.resolve_call("not json").is_err());
        // Missing the required `tool` field.
        assert!(proxy.resolve_call(r#"{"server":"math"}"#).is_err());
        // `arguments` must be an object when present.
        assert!(proxy.resolve_call(r#"{"server":"math","tool":"add","arguments":[1,2]}"#).is_err());
    }

    #[test]
    fn tool_proxy_resolve_call_unknown_server() {
        let proxy = make_proxy(&["math"]);
        let json = r#"{"server":"unknown","tool":"add","arguments":{}}"#;
        let err = proxy.resolve_call(json).unwrap_err();
        assert!(err.to_string().contains("not part of proxy"));
    }

    #[test]
    fn tool_proxy_accessors() {
        let proxy = make_proxy(&["math"]);
        assert_eq!(proxy.name(), "myproxy");
        assert_eq!(proxy.tool_dir(), Path::new("/tmp/tool-proxy/myproxy"));
        assert!(proxy.instructions().contains("call_tool"));
    }

    #[test]
    fn tool_proxy_add_member() {
        let mut proxy = make_proxy(&["math"]);
        assert!(!proxy.contains_server("git"));
        proxy.add_member("git".to_string());
        assert!(proxy.contains_server("git"));
    }
}