adk_tool/mcp/manager/
toolset_impl.rs1use std::collections::HashMap;
9use std::sync::Arc;
10
11use adk_core::{ReadonlyContext, Result, Tool, Toolset};
12use async_trait::async_trait;
13use serde_json::Value;
14
15use super::manager::McpServerManager;
16use super::status::ServerStatus;
17
18type ServerToolMap = HashMap<String, Vec<(String, Arc<dyn Tool>)>>;
20
21struct PrefixedTool {
27 inner: Arc<dyn Tool>,
29 prefixed_name: String,
31}
32
33#[async_trait]
34impl Tool for PrefixedTool {
35 fn name(&self) -> &str {
36 &self.prefixed_name
37 }
38
39 fn description(&self) -> &str {
40 self.inner.description()
41 }
42
43 fn is_long_running(&self) -> bool {
44 self.inner.is_long_running()
45 }
46
47 fn parameters_schema(&self) -> Option<Value> {
48 self.inner.parameters_schema()
49 }
50
51 fn response_schema(&self) -> Option<Value> {
52 self.inner.response_schema()
53 }
54
55 fn required_scopes(&self) -> &[&str] {
56 self.inner.required_scopes()
57 }
58
59 fn is_read_only(&self) -> bool {
60 self.inner.is_read_only()
61 }
62
63 fn is_concurrency_safe(&self) -> bool {
64 self.inner.is_concurrency_safe()
65 }
66
67 fn is_builtin(&self) -> bool {
68 self.inner.is_builtin()
69 }
70
71 fn declaration(&self) -> Value {
72 let mut decl = self.inner.declaration();
73 if let Some(obj) = decl.as_object_mut() {
74 obj.insert("name".to_string(), Value::String(self.prefixed_name.clone()));
75 }
76 decl
77 }
78
79 fn enhanced_description(&self) -> String {
80 self.inner.enhanced_description()
81 }
82
83 async fn execute(&self, ctx: Arc<dyn adk_core::ToolContext>, args: Value) -> Result<Value> {
84 self.inner.execute(ctx, args).await
85 }
86}
87
88fn resolve_tool_names(server_tools: &ServerToolMap) -> Vec<Arc<dyn Tool>> {
94 let mut name_counts: HashMap<&str, Vec<&str>> = HashMap::new();
96 for (server_id, tools) in server_tools {
97 for (name, _) in tools {
98 name_counts.entry(name).or_default().push(server_id);
99 }
100 }
101
102 let mut result = Vec::new();
104 for (server_id, tools) in server_tools {
105 for (name, tool) in tools {
106 if name_counts[name.as_str()].len() > 1 {
107 result.push(Arc::new(PrefixedTool {
108 inner: tool.clone(),
109 prefixed_name: format!("{server_id}__{name}"),
110 }) as Arc<dyn Tool>);
111 } else {
112 result.push(tool.clone());
113 }
114 }
115 }
116 result
117}
118
119#[async_trait]
120impl Toolset for McpServerManager {
121 fn name(&self) -> &str {
122 &self.name
123 }
124
125 async fn tools(&self, ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
126 let servers = self.servers.read().await;
128
129 let mut server_tools: ServerToolMap = HashMap::new();
131
132 for (server_id, entry) in servers.iter() {
133 if entry.status != ServerStatus::Running {
134 continue;
135 }
136
137 let toolset = match &entry.toolset {
138 Some(ts) => ts,
139 None => continue,
140 };
141
142 match toolset.tools(ctx.clone()).await {
143 Ok(tools) => {
144 let named_tools: Vec<(String, Arc<dyn Tool>)> =
145 tools.into_iter().map(|t| (t.name().to_string(), t)).collect();
146 server_tools.insert(server_id.clone(), named_tools);
147 }
148 Err(e) => {
149 tracing::warn!(
150 server.id = server_id,
151 error = %e,
152 "failed to list tools from server, skipping"
153 );
154 }
155 }
156 }
157
158 Ok(resolve_tool_names(&server_tools))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::ToolContext;

    /// Minimal tool stub with a fixed name and description; `execute` always returns `"ok"`.
    struct FakeTool {
        name: String,
        description: String,
    }

    #[async_trait]
    impl Tool for FakeTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            &self.description
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            Ok(Value::from("ok"))
        }
    }

    fn make_tool(name: &str) -> Arc<dyn Tool> {
        let description = format!("Tool {name}");
        Arc::new(FakeTool { name: name.to_owned(), description })
    }

    /// Builds a `ServerToolMap` from `(server_id, tool_name)` pairs, preserving the
    /// order in which tools are given for each server.
    fn build_server_tools(pairs: &[(&str, &str)]) -> ServerToolMap {
        let mut map = ServerToolMap::new();
        for &(server_id, tool_name) in pairs {
            map.entry(server_id.to_owned())
                .or_default()
                .push((tool_name.to_owned(), make_tool(tool_name)));
        }
        map
    }

    /// Names of the given tools, sorted so assertions are independent of output order.
    fn sorted_names(tools: &[Arc<dyn Tool>]) -> Vec<String> {
        let mut names: Vec<String> = tools.iter().map(|tool| tool.name().to_owned()).collect();
        names.sort();
        names
    }

    #[test]
    fn test_resolve_no_collisions() {
        let input = build_server_tools(&[("server_a", "tool_x"), ("server_b", "tool_y")]);
        let resolved = resolve_tool_names(&input);

        assert_eq!(resolved.len(), 2);
        assert_eq!(sorted_names(&resolved), ["tool_x", "tool_y"]);
    }

    #[test]
    fn test_resolve_with_collisions() {
        let input = build_server_tools(&[("server_a", "read_file"), ("server_b", "read_file")]);
        let resolved = resolve_tool_names(&input);

        assert_eq!(resolved.len(), 2);
        assert_eq!(sorted_names(&resolved), ["server_a__read_file", "server_b__read_file"]);
    }

    #[test]
    fn test_resolve_mixed_collision_and_unique() {
        let input = build_server_tools(&[
            ("server_a", "read_file"),
            ("server_a", "unique_a"),
            ("server_b", "read_file"),
            ("server_b", "unique_b"),
        ]);
        let resolved = resolve_tool_names(&input);

        assert_eq!(resolved.len(), 4);
        assert_eq!(
            sorted_names(&resolved),
            ["server_a__read_file", "server_b__read_file", "unique_a", "unique_b"]
        );
    }

    #[test]
    fn test_resolve_empty_servers() {
        let resolved = resolve_tool_names(&ServerToolMap::new());
        assert!(resolved.is_empty());
    }

    #[test]
    fn test_prefixed_tool_delegates_description() {
        let inner = make_tool("original");
        let prefixed = PrefixedTool {
            inner: Arc::clone(&inner),
            prefixed_name: "server__original".to_owned(),
        };

        assert_eq!(prefixed.name(), "server__original");
        assert_eq!(prefixed.description(), inner.description());
        assert_eq!(prefixed.is_long_running(), inner.is_long_running());
        assert_eq!(prefixed.is_read_only(), inner.is_read_only());
        assert_eq!(prefixed.is_concurrency_safe(), inner.is_concurrency_safe());
        assert_eq!(prefixed.is_builtin(), inner.is_builtin());
    }

    #[test]
    fn test_prefixed_tool_declaration_overrides_name() {
        let prefixed = PrefixedTool {
            inner: make_tool("original"),
            prefixed_name: "server__original".to_owned(),
        };

        let declaration = prefixed.declaration();
        assert_eq!(declaration["name"], "server__original");
    }

    #[test]
    fn test_resolve_three_way_collision() {
        let input = build_server_tools(&[("a", "shared"), ("b", "shared"), ("c", "shared")]);
        let resolved = resolve_tool_names(&input);

        assert_eq!(resolved.len(), 3);
        assert_eq!(sorted_names(&resolved), ["a__shared", "b__shared", "c__shared"]);
    }
}