1use adk_core::{AdkError, ReadonlyContext, Result, Tool, ToolContext, Toolset};
10use async_trait::async_trait;
11use rmcp::{
12 RoleClient,
13 model::{CallToolRequestParam, RawContent, ResourceContents},
14 service::RunningService,
15};
16use serde_json::{Value, json};
17use std::ops::Deref;
18use std::sync::Arc;
19use tokio::sync::Mutex;
20
/// Predicate over an MCP tool name: when installed on an `McpToolset`, a tool is exposed only
/// if this returns `true` for its name.
pub type ToolFilter = Arc<dyn Fn(&str) -> bool + Send + Sync>;
23
24fn sanitize_schema(value: &mut Value) {
28 if let Value::Object(map) = value {
29 map.remove("$schema");
30 map.remove("definitions");
31 map.remove("$ref");
32 map.remove("additionalProperties");
33
34 for (_, v) in map.iter_mut() {
35 sanitize_schema(v);
36 }
37 } else if let Value::Array(arr) = value {
38 for v in arr.iter_mut() {
39 sanitize_schema(v);
40 }
41 }
42}
43
/// A [`Toolset`] that exposes the tools of a connected MCP server as ADK [`Tool`]s.
///
/// `S` is the rmcp client service type of the running connection; it defaults to `()`.
pub struct McpToolset<S = ()>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// Running MCP client connection; a clone of this `Arc` is handed to every tool the
    /// toolset produces.
    client: Arc<Mutex<RunningService<RoleClient, S>>>,
    /// Optional name predicate; when set, only tools for which it returns `true` are exposed.
    tool_filter: Option<ToolFilter>,
    /// Name reported by `Toolset::name`.
    name: String,
}
86
87impl<S> McpToolset<S>
88where
89 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
90{
91 pub fn new(client: RunningService<RoleClient, S>) -> Self {
109 Self {
110 client: Arc::new(Mutex::new(client)),
111 tool_filter: None,
112 name: "mcp_toolset".to_string(),
113 }
114 }
115
116 pub fn with_name(mut self, name: impl Into<String>) -> Self {
118 self.name = name.into();
119 self
120 }
121
122 pub fn with_filter<F>(mut self, filter: F) -> Self
136 where
137 F: Fn(&str) -> bool + Send + Sync + 'static,
138 {
139 self.tool_filter = Some(Arc::new(filter));
140 self
141 }
142
143 pub fn with_tools(self, tool_names: &[&str]) -> Self {
152 let names: Vec<String> = tool_names.iter().map(|s| s.to_string()).collect();
153 self.with_filter(move |name| names.iter().any(|n| n == name))
154 }
155
156 pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
173 let client = self.client.lock().await;
174 client.cancellation_token()
175 }
176}
177
178#[async_trait]
179impl<S> Toolset for McpToolset<S>
180where
181 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
182{
183 fn name(&self) -> &str {
184 &self.name
185 }
186
187 async fn tools(&self, _ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
188 let client = self.client.lock().await;
189
190 let mcp_tools = client
192 .list_all_tools()
193 .await
194 .map_err(|e| AdkError::Tool(format!("Failed to list MCP tools: {}", e)))?;
195
196 let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
198
199 for mcp_tool in mcp_tools {
200 let tool_name = mcp_tool.name.to_string();
201
202 if let Some(ref filter) = self.tool_filter {
204 if !filter(&tool_name) {
205 continue;
206 }
207 }
208
209 let adk_tool = McpTool {
210 name: tool_name,
211 description: mcp_tool.description.map(|d| d.to_string()).unwrap_or_default(),
212 input_schema: {
213 let mut schema = Value::Object(mcp_tool.input_schema.as_ref().clone());
214 sanitize_schema(&mut schema);
215 Some(schema)
216 },
217 output_schema: mcp_tool.output_schema.map(|s| {
218 let mut schema = Value::Object(s.as_ref().clone());
219 sanitize_schema(&mut schema);
220 schema
221 }),
222 client: self.client.clone(),
223 };
224
225 tools.push(Arc::new(adk_tool) as Arc<dyn Tool>);
226 }
227
228 Ok(tools)
229 }
230}
231
/// A single MCP server tool adapted to the ADK [`Tool`] trait.
struct McpTool<S>
where
    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
{
    /// Tool name as reported by the server; also sent as the name in `call_tool` requests.
    name: String,
    /// Server-provided description (empty string when the server gave none).
    description: String,
    /// Input JSON schema returned by `parameters_schema`.
    input_schema: Option<Value>,
    /// Optional output JSON schema returned by `response_schema`.
    output_schema: Option<Value>,
    /// Running MCP client connection used to execute the tool.
    client: Arc<Mutex<RunningService<RoleClient, S>>>,
}
245
246#[async_trait]
247impl<S> Tool for McpTool<S>
248where
249 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
250{
251 fn name(&self) -> &str {
252 &self.name
253 }
254
255 fn description(&self) -> &str {
256 &self.description
257 }
258
259 fn is_long_running(&self) -> bool {
260 false
261 }
262
263 fn parameters_schema(&self) -> Option<Value> {
264 self.input_schema.clone()
265 }
266
267 fn response_schema(&self) -> Option<Value> {
268 self.output_schema.clone()
269 }
270
271 async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
272 let client = self.client.lock().await;
273
274 let result = client
276 .call_tool(CallToolRequestParam {
277 name: self.name.clone().into(),
278 arguments: if args.is_null() || args == json!({}) {
279 None
280 } else {
281 match args {
283 Value::Object(map) => Some(map),
284 _ => {
285 return Err(AdkError::Tool(
286 "Tool arguments must be an object".to_string(),
287 ));
288 }
289 }
290 },
291 })
292 .await
293 .map_err(|e| {
294 AdkError::Tool(format!("Failed to call MCP tool '{}': {}", self.name, e))
295 })?;
296
297 if result.is_error.unwrap_or(false) {
299 let mut error_msg = format!("MCP tool '{}' execution failed", self.name);
300
301 for content in &result.content {
303 if let Some(text_content) = content.deref().as_text() {
305 error_msg.push_str(": ");
306 error_msg.push_str(&text_content.text);
307 break;
308 }
309 }
310
311 return Err(AdkError::Tool(error_msg));
312 }
313
314 if let Some(structured) = result.structured_content {
316 return Ok(json!({ "output": structured }));
317 }
318
319 let mut text_parts: Vec<String> = Vec::new();
321
322 for content in &result.content {
323 let raw: &RawContent = content.deref();
325 match raw {
326 RawContent::Text(text_content) => {
327 text_parts.push(text_content.text.clone());
328 }
329 RawContent::Image(image_content) => {
330 text_parts.push(format!(
332 "[Image: {} bytes, mime: {}]",
333 image_content.data.len(),
334 image_content.mime_type
335 ));
336 }
337 RawContent::Resource(resource_content) => {
338 let uri = match &resource_content.resource {
339 ResourceContents::TextResourceContents { uri, .. } => uri,
340 ResourceContents::BlobResourceContents { uri, .. } => uri,
341 };
342 text_parts.push(format!("[Resource: {}]", uri));
343 }
344 RawContent::Audio(_) => {
345 text_parts.push("[Audio content]".to_string());
346 }
347 RawContent::ResourceLink(link) => {
348 text_parts.push(format!("[ResourceLink: {}]", link.uri));
349 }
350 }
351 }
352
353 if text_parts.is_empty() {
354 return Err(AdkError::Tool(format!("MCP tool '{}' returned no content", self.name)));
355 }
356
357 Ok(json!({ "output": text_parts.join("\n") }))
358 }
359}
360
361unsafe impl<S> Send for McpTool<S> where
363 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static
364{
365}
366unsafe impl<S> Sync for McpTool<S> where
367 S: rmcp::service::Service<RoleClient> + Send + Sync + 'static
368{
369}