adk_tool/mcp/
toolset.rs

// MCP (Model Context Protocol) Toolset Integration
//
// Based on the Go implementation: adk-go/tool/mcptoolset/
// Uses the official Rust SDK: https://github.com/modelcontextprotocol/rust-sdk
//
// The McpToolset connects to an MCP server, discovers available tools,
// and exposes them as ADK-compatible tools for use with LlmAgent.
8
9use adk_core::{AdkError, ReadonlyContext, Result, Tool, ToolContext, Toolset};
10use async_trait::async_trait;
11use rmcp::{
12    RoleClient,
13    model::{CallToolRequestParam, RawContent, ResourceContents},
14    service::RunningService,
15};
16use serde_json::{Value, json};
17use std::ops::Deref;
18use std::sync::Arc;
19use tokio::sync::Mutex;
20
/// Shared, thread-safe predicate deciding which MCP tools a toolset exposes.
///
/// The closure is handed a tool name and returns `true` to keep that tool.
pub type ToolFilter = Arc<dyn Fn(&str) -> bool + Send + Sync>;
23
24/// Sanitize JSON schema for LLM compatibility.
25/// Removes fields like `$schema`, `additionalProperties`, `definitions`, `$ref`
26/// that some LLM APIs (like Gemini) don't accept.
27fn sanitize_schema(value: &mut Value) {
28    if let Value::Object(map) = value {
29        map.remove("$schema");
30        map.remove("definitions");
31        map.remove("$ref");
32        map.remove("additionalProperties");
33
34        for (_, v) in map.iter_mut() {
35            sanitize_schema(v);
36        }
37    } else if let Value::Array(arr) = value {
38        for v in arr.iter_mut() {
39            sanitize_schema(v);
40        }
41    }
42}
43
44/// MCP Toolset - connects to an MCP server and exposes its tools as ADK tools.
45///
46/// This toolset implements the ADK `Toolset` trait and bridges the gap between
47/// MCP servers and ADK agents. It:
48/// 1. Connects to an MCP server via the provided transport
49/// 2. Discovers available tools from the server
50/// 3. Converts MCP tools to ADK-compatible `Tool` implementations
51/// 4. Proxies tool execution calls to the MCP server
52///
53/// # Example
54///
55/// ```rust,ignore
56/// use adk_tool::McpToolset;
57/// use rmcp::{ServiceExt, transport::TokioChildProcess};
58/// use tokio::process::Command;
59///
60/// // Create MCP client connection to a local server
61/// let client = ().serve(TokioChildProcess::new(
62///     Command::new("npx")
63///         .arg("-y")
64///         .arg("@modelcontextprotocol/server-everything")
65/// )?).await?;
66///
67/// // Create toolset from the client
68/// let toolset = McpToolset::new(client);
69///
70/// // Add to agent
71/// let agent = LlmAgentBuilder::new("assistant")
72///     .toolset(Arc::new(toolset))
73///     .build()?;
74/// ```
75pub struct McpToolset<S = ()>
76where
77    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
78{
79    /// The running MCP client service
80    client: Arc<Mutex<RunningService<RoleClient, S>>>,
81    /// Optional filter to select which tools to expose
82    tool_filter: Option<ToolFilter>,
83    /// Name of this toolset
84    name: String,
85}
86
87impl<S> McpToolset<S>
88where
89    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
90{
91    /// Create a new MCP toolset from a running MCP client service.
92    ///
93    /// The client should already be connected and initialized.
94    /// Use `rmcp::ServiceExt::serve()` to create the client.
95    ///
96    /// # Example
97    ///
98    /// ```rust,ignore
99    /// use rmcp::{ServiceExt, transport::TokioChildProcess};
100    /// use tokio::process::Command;
101    ///
102    /// let client = ().serve(TokioChildProcess::new(
103    ///     Command::new("my-mcp-server")
104    /// )?).await?;
105    ///
106    /// let toolset = McpToolset::new(client);
107    /// ```
108    pub fn new(client: RunningService<RoleClient, S>) -> Self {
109        Self {
110            client: Arc::new(Mutex::new(client)),
111            tool_filter: None,
112            name: "mcp_toolset".to_string(),
113        }
114    }
115
116    /// Set a custom name for this toolset.
117    pub fn with_name(mut self, name: impl Into<String>) -> Self {
118        self.name = name.into();
119        self
120    }
121
122    /// Add a filter to select which tools to expose.
123    ///
124    /// The filter function receives a tool name and returns true if the tool
125    /// should be included.
126    ///
127    /// # Example
128    ///
129    /// ```rust,ignore
130    /// let toolset = McpToolset::new(client)
131    ///     .with_filter(|name| {
132    ///         matches!(name, "read_file" | "list_directory" | "search_files")
133    ///     });
134    /// ```
135    pub fn with_filter<F>(mut self, filter: F) -> Self
136    where
137        F: Fn(&str) -> bool + Send + Sync + 'static,
138    {
139        self.tool_filter = Some(Arc::new(filter));
140        self
141    }
142
143    /// Add a filter that only includes tools with the specified names.
144    ///
145    /// # Example
146    ///
147    /// ```rust,ignore
148    /// let toolset = McpToolset::new(client)
149    ///     .with_tools(&["read_file", "write_file"]);
150    /// ```
151    pub fn with_tools(self, tool_names: &[&str]) -> Self {
152        let names: Vec<String> = tool_names.iter().map(|s| s.to_string()).collect();
153        self.with_filter(move |name| names.iter().any(|n| n == name))
154    }
155
156    /// Get a cancellation token that can be used to shutdown the MCP server.
157    ///
158    /// Call `cancel()` on the returned token to cleanly shutdown the MCP server.
159    /// This should be called before exiting to avoid EPIPE errors.
160    ///
161    /// # Example
162    ///
163    /// ```rust,ignore
164    /// let toolset = McpToolset::new(client);
165    /// let cancel_token = toolset.cancellation_token().await;
166    ///
167    /// // ... use the toolset ...
168    ///
169    /// // Before exiting:
170    /// cancel_token.cancel();
171    /// ```
172    pub async fn cancellation_token(&self) -> rmcp::service::RunningServiceCancellationToken {
173        let client = self.client.lock().await;
174        client.cancellation_token()
175    }
176}
177
178#[async_trait]
179impl<S> Toolset for McpToolset<S>
180where
181    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
182{
183    fn name(&self) -> &str {
184        &self.name
185    }
186
187    async fn tools(&self, _ctx: Arc<dyn ReadonlyContext>) -> Result<Vec<Arc<dyn Tool>>> {
188        let client = self.client.lock().await;
189
190        // List all tools from the MCP server (handles pagination internally)
191        let mcp_tools = client
192            .list_all_tools()
193            .await
194            .map_err(|e| AdkError::Tool(format!("Failed to list MCP tools: {}", e)))?;
195
196        // Convert MCP tools to ADK tools
197        let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
198
199        for mcp_tool in mcp_tools {
200            let tool_name = mcp_tool.name.to_string();
201
202            // Apply filter if present
203            if let Some(ref filter) = self.tool_filter {
204                if !filter(&tool_name) {
205                    continue;
206                }
207            }
208
209            let adk_tool = McpTool {
210                name: tool_name,
211                description: mcp_tool.description.map(|d| d.to_string()).unwrap_or_default(),
212                input_schema: {
213                    let mut schema = Value::Object(mcp_tool.input_schema.as_ref().clone());
214                    sanitize_schema(&mut schema);
215                    Some(schema)
216                },
217                output_schema: mcp_tool.output_schema.map(|s| {
218                    let mut schema = Value::Object(s.as_ref().clone());
219                    sanitize_schema(&mut schema);
220                    schema
221                }),
222                client: self.client.clone(),
223            };
224
225            tools.push(Arc::new(adk_tool) as Arc<dyn Tool>);
226        }
227
228        Ok(tools)
229    }
230}
231
232/// Individual MCP tool wrapper that implements the ADK `Tool` trait.
233///
234/// This struct wraps an MCP tool and proxies execution calls to the MCP server.
235struct McpTool<S>
236where
237    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
238{
239    name: String,
240    description: String,
241    input_schema: Option<Value>,
242    output_schema: Option<Value>,
243    client: Arc<Mutex<RunningService<RoleClient, S>>>,
244}
245
246#[async_trait]
247impl<S> Tool for McpTool<S>
248where
249    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static,
250{
251    fn name(&self) -> &str {
252        &self.name
253    }
254
255    fn description(&self) -> &str {
256        &self.description
257    }
258
259    fn is_long_running(&self) -> bool {
260        false
261    }
262
263    fn parameters_schema(&self) -> Option<Value> {
264        self.input_schema.clone()
265    }
266
267    fn response_schema(&self) -> Option<Value> {
268        self.output_schema.clone()
269    }
270
271    async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
272        let client = self.client.lock().await;
273
274        // Call the MCP tool
275        let result = client
276            .call_tool(CallToolRequestParam {
277                name: self.name.clone().into(),
278                arguments: if args.is_null() || args == json!({}) {
279                    None
280                } else {
281                    // Convert Value to the expected type
282                    match args {
283                        Value::Object(map) => Some(map),
284                        _ => {
285                            return Err(AdkError::Tool(
286                                "Tool arguments must be an object".to_string(),
287                            ));
288                        }
289                    }
290                },
291            })
292            .await
293            .map_err(|e| {
294                AdkError::Tool(format!("Failed to call MCP tool '{}': {}", self.name, e))
295            })?;
296
297        // Check for error response
298        if result.is_error.unwrap_or(false) {
299            let mut error_msg = format!("MCP tool '{}' execution failed", self.name);
300
301            // Extract error details from content
302            for content in &result.content {
303                // Use Deref to access the inner RawContent
304                if let Some(text_content) = content.deref().as_text() {
305                    error_msg.push_str(": ");
306                    error_msg.push_str(&text_content.text);
307                    break;
308                }
309            }
310
311            return Err(AdkError::Tool(error_msg));
312        }
313
314        // Return structured content if available
315        if let Some(structured) = result.structured_content {
316            return Ok(json!({ "output": structured }));
317        }
318
319        // Otherwise, collect text content
320        let mut text_parts: Vec<String> = Vec::new();
321
322        for content in &result.content {
323            // Access the inner RawContent via Deref
324            let raw: &RawContent = content.deref();
325            match raw {
326                RawContent::Text(text_content) => {
327                    text_parts.push(text_content.text.clone());
328                }
329                RawContent::Image(image_content) => {
330                    // Return image data as base64
331                    text_parts.push(format!(
332                        "[Image: {} bytes, mime: {}]",
333                        image_content.data.len(),
334                        image_content.mime_type
335                    ));
336                }
337                RawContent::Resource(resource_content) => {
338                    let uri = match &resource_content.resource {
339                        ResourceContents::TextResourceContents { uri, .. } => uri,
340                        ResourceContents::BlobResourceContents { uri, .. } => uri,
341                    };
342                    text_parts.push(format!("[Resource: {}]", uri));
343                }
344                RawContent::Audio(_) => {
345                    text_parts.push("[Audio content]".to_string());
346                }
347                RawContent::ResourceLink(link) => {
348                    text_parts.push(format!("[ResourceLink: {}]", link.uri));
349                }
350            }
351        }
352
353        if text_parts.is_empty() {
354            return Err(AdkError::Tool(format!("MCP tool '{}' returned no content", self.name)));
355        }
356
357        Ok(json!({ "output": text_parts.join("\n") }))
358    }
359}
360
361// Ensure McpTool is Send + Sync
362unsafe impl<S> Send for McpTool<S> where
363    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static
364{
365}
366unsafe impl<S> Sync for McpTool<S> where
367    S: rmcp::service::Service<RoleClient> + Send + Sync + 'static
368{
369}