Skip to main content

adk_tool/mcp/manager/
toolset_impl.rs

//! Toolset trait implementation for McpServerManager.
//!
//! This module contains the [`Toolset`](adk_core::Toolset) trait implementation
//! for [`McpServerManager`], tool name collision resolution, and the
//! [`PrefixedTool`] wrapper, which delegates every [`Tool`](adk_core::Tool) method
//! to the wrapped tool while overriding the tool name it reports.
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use adk_core::{ReadonlyContext, Result, Tool, Toolset};
12use async_trait::async_trait;
13use serde_json::Value;
14
15use super::manager::McpServerManager;
16use super::status::ServerStatus;
17
18/// Per-server tool list: each entry is `(tool_name, tool_arc)`.
19type ServerToolMap = HashMap<String, Vec<(String, Arc<dyn Tool>)>>;
20
21/// A wrapper around an `Arc<dyn Tool>` that overrides the tool name with a
22/// prefixed version to resolve name collisions across multiple MCP servers.
23///
24/// All [`Tool`] trait methods delegate to the inner tool, except `name()` and
25/// `declaration()` which use the prefixed name.
26struct PrefixedTool {
27    /// The original tool being wrapped.
28    inner: Arc<dyn Tool>,
29    /// The prefixed name in the format `{server_id}__{tool_name}`.
30    prefixed_name: String,
31}
32
33#[async_trait]
34impl Tool for PrefixedTool {
35    fn name(&self) -> &str {
36        &self.prefixed_name
37    }
38
39    fn description(&self) -> &str {
40        self.inner.description()
41    }
42
43    fn is_long_running(&self) -> bool {
44        self.inner.is_long_running()
45    }
46
47    fn parameters_schema(&self) -> Option<Value> {
48        self.inner.parameters_schema()
49    }
50
51    fn response_schema(&self) -> Option<Value> {
52        self.inner.response_schema()
53    }
54
55    fn required_scopes(&self) -> &[&str] {
56        self.inner.required_scopes()
57    }
58
59    fn is_read_only(&self) -> bool {
60        self.inner.is_read_only()
61    }
62
63    fn is_concurrency_safe(&self) -> bool {
64        self.inner.is_concurrency_safe()
65    }
66
67    fn is_builtin(&self) -> bool {
68        self.inner.is_builtin()
69    }
70
71    fn declaration(&self) -> Value {
72        let mut decl = self.inner.declaration();
73        if let Some(obj) = decl.as_object_mut() {
74            obj.insert("name".to_string(), Value::String(self.prefixed_name.clone()));
75        }
76        decl
77    }
78
79    fn enhanced_description(&self) -> String {
80        self.inner.enhanced_description()
81    }
82
83    async fn execute(&self, ctx: Arc<dyn adk_core::ToolContext>, args: Value) -> Result<Value> {
84        self.inner.execute(ctx, args).await
85    }
86}
87
88/// Resolve tool name collisions across multiple servers.
89///
90/// For tool names that appear in two or more servers, the tool is wrapped in a
91/// [`PrefixedTool`] with the format `{server_id}__{tool_name}`. Tools with
92/// unique names across all servers retain their original names.
93fn resolve_tool_names(server_tools: &ServerToolMap) -> Vec<Arc<dyn Tool>> {
94    // Step 1: Count occurrences of each tool name across all servers
95    let mut name_counts: HashMap<&str, Vec<&str>> = HashMap::new();
96    for (server_id, tools) in server_tools {
97        for (name, _) in tools {
98            name_counts.entry(name).or_default().push(server_id);
99        }
100    }
101
102    // Step 2: For names appearing in multiple servers, prefix with server_id
103    let mut result = Vec::new();
104    for (server_id, tools) in server_tools {
105        for (name, tool) in tools {
106            if name_counts[name.as_str()].len() > 1 {
107                result.push(Arc::new(PrefixedTool {
108                    inner: tool.clone(),
109                    prefixed_name: format!("{server_id}__{name}"),
110                }) as Arc<dyn Tool>);
111            } else {
112                result.push(tool.clone());
113            }
114        }
115    }
116    result
117}
118
119#[async_trait]
120impl Toolset for McpServerManager {
121    fn name(&self) -> &str {
122        &self.name
123    }
124
125    async fn tools(&self, ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
126        // Acquire read lock to iterate over servers
127        let servers = self.servers.read().await;
128
129        // Collect tools from each Running server
130        let mut server_tools: ServerToolMap = HashMap::new();
131
132        for (server_id, entry) in servers.iter() {
133            if entry.status != ServerStatus::Running {
134                continue;
135            }
136
137            let toolset = match &entry.toolset {
138                Some(ts) => ts,
139                None => continue,
140            };
141
142            match toolset.tools(ctx.clone()).await {
143                Ok(tools) => {
144                    let named_tools: Vec<(String, Arc<dyn Tool>)> =
145                        tools.into_iter().map(|t| (t.name().to_string(), t)).collect();
146                    server_tools.insert(server_id.clone(), named_tools);
147                }
148                Err(e) => {
149                    tracing::warn!(
150                        server.id = server_id,
151                        error = %e,
152                        "failed to list tools from server, skipping"
153                    );
154                }
155            }
156        }
157
158        Ok(resolve_tool_names(&server_tools))
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::ToolContext;

    /// Minimal [`Tool`] used to exercise collision resolution: a fixed name and
    /// description, and an `execute` that always yields the string `"ok"`.
    struct FakeTool {
        name: String,
        description: String,
    }

    #[async_trait]
    impl Tool for FakeTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            self.description.as_str()
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            Ok(Value::from("ok"))
        }
    }

    /// Builds a [`FakeTool`] called `name` with the description `Tool {name}`.
    fn make_tool(name: &str) -> Arc<dyn Tool> {
        let description = format!("Tool {name}");
        Arc::new(FakeTool { name: name.to_owned(), description })
    }

    /// Registers `server_id` in `map` with one fake tool per entry of `tool_names`,
    /// each listed under its own name.
    fn add_server(map: &mut ServerToolMap, server_id: &str, tool_names: &[&str]) {
        let tools = tool_names.iter().map(|&name| (name.to_owned(), make_tool(name))).collect();
        map.insert(server_id.to_owned(), tools);
    }

    /// Returns the reported names of `tools`, sorted so assertions ignore ordering.
    fn sorted_names(tools: &[Arc<dyn Tool>]) -> Vec<String> {
        let mut names: Vec<String> = tools.iter().map(|tool| tool.name().to_owned()).collect();
        names.sort();
        names
    }

    #[test]
    fn test_resolve_no_collisions() {
        let mut server_tools = ServerToolMap::new();
        add_server(&mut server_tools, "server_a", &["tool_x"]);
        add_server(&mut server_tools, "server_b", &["tool_y"]);

        let resolved = resolve_tool_names(&server_tools);
        assert_eq!(resolved.len(), 2);
        assert_eq!(sorted_names(&resolved), ["tool_x", "tool_y"]);
    }

    #[test]
    fn test_resolve_with_collisions() {
        let mut server_tools = ServerToolMap::new();
        add_server(&mut server_tools, "server_a", &["read_file"]);
        add_server(&mut server_tools, "server_b", &["read_file"]);

        let resolved = resolve_tool_names(&server_tools);
        assert_eq!(resolved.len(), 2);
        assert_eq!(sorted_names(&resolved), ["server_a__read_file", "server_b__read_file"]);
    }

    #[test]
    fn test_resolve_mixed_collision_and_unique() {
        let mut server_tools = ServerToolMap::new();
        add_server(&mut server_tools, "server_a", &["read_file", "unique_a"]);
        add_server(&mut server_tools, "server_b", &["read_file", "unique_b"]);

        let resolved = resolve_tool_names(&server_tools);
        assert_eq!(resolved.len(), 4);
        assert_eq!(
            sorted_names(&resolved),
            ["server_a__read_file", "server_b__read_file", "unique_a", "unique_b"]
        );
    }

    #[test]
    fn test_resolve_empty_servers() {
        let server_tools = ServerToolMap::new();
        assert!(resolve_tool_names(&server_tools).is_empty());
    }

    #[test]
    fn test_prefixed_tool_delegates_description() {
        let inner = make_tool("original");
        let prefixed = PrefixedTool {
            inner: Arc::clone(&inner),
            prefixed_name: "server__original".to_owned(),
        };

        assert_eq!(prefixed.name(), "server__original");
        assert_eq!(prefixed.description(), inner.description());
        assert_eq!(prefixed.is_long_running(), inner.is_long_running());
        assert_eq!(prefixed.is_read_only(), inner.is_read_only());
        assert_eq!(prefixed.is_concurrency_safe(), inner.is_concurrency_safe());
        assert_eq!(prefixed.is_builtin(), inner.is_builtin());
    }

    #[test]
    fn test_prefixed_tool_declaration_overrides_name() {
        let prefixed = PrefixedTool {
            inner: make_tool("original"),
            prefixed_name: "server__original".to_owned(),
        };

        assert_eq!(prefixed.declaration()["name"], "server__original");
    }

    #[test]
    fn test_resolve_three_way_collision() {
        let mut server_tools = ServerToolMap::new();
        add_server(&mut server_tools, "a", &["shared"]);
        add_server(&mut server_tools, "b", &["shared"]);
        add_server(&mut server_tools, "c", &["shared"]);

        let resolved = resolve_tool_names(&server_tools);
        assert_eq!(resolved.len(), 3);
        assert_eq!(sorted_names(&resolved), ["a__shared", "b__shared", "c__shared"]);
    }
}