Skip to main content

rig/tool/
rmcp.rs

1//! MCP (Model Context Protocol) integration via the `rmcp` crate.
2//!
3//! This module provides:
4//! - [`McpTool`]: A wrapper that adapts an `rmcp` tool for use in Rig's tool system.
5//! - [`McpClientHandler`]: A client handler that reacts to `notifications/tools/list_changed`
6//!   by re-fetching the tool list and updating the [`ToolServer`](super::server::ToolServer).
7//!
8//! # Example
9//!
10//! ```rust,ignore
11//! use rig::tool::rmcp::McpClientHandler;
12//! use rig::tool::server::ToolServer;
13//! use rmcp::ServiceExt;
14//!
15//! // 1. Create a ToolServer and get a handle
16//! let tool_server_handle = ToolServer::new().run();
17//!
18//! // 2. Create a handler that auto-updates tools on list changes
19//! let handler = McpClientHandler::new(client_info, tool_server_handle.clone());
20//!
21//! // 3. Connect to the MCP server and register initial tools
22//! let mcp_service = handler.connect(transport).await?;
23//!
24//! // 4. Build an agent using the shared tool server handle
25//! let agent = openai_client
26//!     .agent(openai::GPT_5_2)
27//!     .preamble("You are a helpful assistant.")
28//!     .tool_server_handle(tool_server_handle)
29//!     .build();
30//! ```
31
32use std::borrow::Cow;
33use std::sync::Arc;
34
35use rmcp::ServiceExt;
36use rmcp::model::RawContent;
37use tokio::sync::RwLock;
38
39use crate::completion::ToolDefinition;
40use crate::tool::ToolDyn;
41use crate::tool::ToolError;
42use crate::tool::server::{ToolServerError, ToolServerHandle};
43use crate::wasm_compat::WasmBoxedFuture;
44
45/// A Rig tool adapter wrapping an `rmcp` MCP tool.
46///
47/// Bridges between the MCP tool protocol and Rig's [`ToolDyn`] trait,
48/// allowing MCP tools to be used seamlessly in Rig agents.
49#[derive(Clone)]
50pub struct McpTool {
51    definition: rmcp::model::Tool,
52    client: rmcp::service::ServerSink,
53}
54
55impl McpTool {
56    /// Create a new `McpTool` from an MCP tool definition and server sink.
57    pub fn from_mcp_server(
58        definition: rmcp::model::Tool,
59        client: rmcp::service::ServerSink,
60    ) -> Self {
61        Self { definition, client }
62    }
63}
64
65impl From<&rmcp::model::Tool> for ToolDefinition {
66    fn from(val: &rmcp::model::Tool) -> Self {
67        Self {
68            name: val.name.to_string(),
69            description: val.description.clone().unwrap_or(Cow::from("")).to_string(),
70            parameters: val.schema_as_json_value(),
71        }
72    }
73}
74
75impl From<rmcp::model::Tool> for ToolDefinition {
76    fn from(val: rmcp::model::Tool) -> Self {
77        Self {
78            name: val.name.to_string(),
79            description: val.description.clone().unwrap_or(Cow::from("")).to_string(),
80            parameters: val.schema_as_json_value(),
81        }
82    }
83}
84
85#[derive(Debug, thiserror::Error)]
86#[error("MCP tool error: {0}")]
87pub struct McpToolError(String);
88
89impl From<McpToolError> for ToolError {
90    fn from(e: McpToolError) -> Self {
91        ToolError::ToolCallError(Box::new(e))
92    }
93}
94
95impl ToolDyn for McpTool {
96    fn name(&self) -> String {
97        self.definition.name.to_string()
98    }
99
100    fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
101        Box::pin(async move {
102            ToolDefinition {
103                name: self.definition.name.to_string(),
104                description: self
105                    .definition
106                    .description
107                    .clone()
108                    .unwrap_or(Cow::from(""))
109                    .to_string(),
110                parameters: serde_json::to_value(&self.definition.input_schema).unwrap_or_default(),
111            }
112        })
113    }
114
115    fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
116        let name = self.definition.name.clone();
117        let arguments: Option<rmcp::model::JsonObject> =
118            serde_json::from_str(&args).unwrap_or_default();
119
120        Box::pin(async move {
121            let request = arguments
122                .map(|arguments| {
123                    rmcp::model::CallToolRequestParams::new(name.clone()).with_arguments(arguments)
124                })
125                .unwrap_or_else(|| rmcp::model::CallToolRequestParams::new(name));
126
127            let result = self
128                .client
129                .call_tool(request)
130                .await
131                .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;
132
133            if let Some(true) = result.is_error {
134                let error_msg = result
135                    .content
136                    .into_iter()
137                    .map(|x| x.raw.as_text().map(|y| y.to_owned()))
138                    .map(|x| x.map(|x| x.clone().text))
139                    .collect::<Option<Vec<String>>>();
140
141                let error_message = error_msg.map(|x| x.join("\n"));
142                if let Some(error_message) = error_message {
143                    return Err(McpToolError(error_message).into());
144                } else {
145                    return Err(McpToolError("No message returned".to_string()).into());
146                }
147            };
148
149            let mut content = String::new();
150
151            for item in result.content {
152                let chunk = match item.raw {
153                    rmcp::model::RawContent::Text(raw) => raw.text,
154                    rmcp::model::RawContent::Image(raw) => {
155                        format!("data:{};base64,{}", raw.mime_type, raw.data)
156                    }
157                    rmcp::model::RawContent::Resource(raw) => match raw.resource {
158                        rmcp::model::ResourceContents::TextResourceContents {
159                            uri,
160                            mime_type,
161                            text,
162                            ..
163                        } => {
164                            format!(
165                                "{mime_type}{uri}:{text}",
166                                mime_type =
167                                    mime_type.map(|m| format!("data:{m};")).unwrap_or_default(),
168                            )
169                        }
170                        rmcp::model::ResourceContents::BlobResourceContents {
171                            uri,
172                            mime_type,
173                            blob,
174                            ..
175                        } => format!(
176                            "{mime_type}{uri}:{blob}",
177                            mime_type = mime_type.map(|m| format!("data:{m};")).unwrap_or_default(),
178                        ),
179                    },
180                    RawContent::Audio(_) => {
181                        return Err(McpToolError(
182                            "MCP tool returned audio content, which Rig does not support yet"
183                                .to_string(),
184                        )
185                        .into());
186                    }
187                    thing => {
188                        return Err(McpToolError(format!(
189                            "MCP tool returned unsupported content: {thing:?}"
190                        ))
191                        .into());
192                    }
193                };
194
195                content.push_str(&chunk);
196            }
197
198            Ok(content)
199        })
200    }
201}
202
203/// Error type for [`McpClientHandler`] operations.
204#[derive(Debug, thiserror::Error)]
205pub enum McpClientError {
206    /// Failed to establish the MCP connection or complete the handshake.
207    #[error("MCP connection error: {0}")]
208    ConnectionError(String),
209
210    /// Failed to fetch the tool list from the MCP server.
211    #[error("Failed to fetch MCP tool list: {0}")]
212    ToolFetchError(#[from] rmcp::ServiceError),
213
214    /// Failed to update the tool server with new tools.
215    #[error("Tool server error: {0}")]
216    ToolServerError(#[from] ToolServerError),
217}
218
219/// An MCP client handler that automatically re-fetches the tool list when the
220/// server sends a `notifications/tools/list_changed` notification.
221///
222/// This handler implements [`rmcp::ClientHandler`] and bridges the MCP
223/// notification lifecycle with Rig's [`ToolServer`](super::server::ToolServer).
224/// When the MCP server's available tools change, this handler:
225/// 1. Removes previously registered MCP tools from the tool server
226/// 2. Re-fetches the full tool list from the MCP server
227/// 3. Registers the updated tools with the tool server
228///
229/// # Usage
230///
231/// Use [`McpClientHandler::connect`] for a streamlined setup that handles
232/// connection, initial tool fetch, and registration in one call:
233///
234/// ```rust,ignore
235/// let tool_server_handle = ToolServer::new().run();
236/// let handler = McpClientHandler::new(client_info, tool_server_handle.clone());
237/// let mcp_service = handler.connect(transport).await?;
238/// ```
239///
240/// The returned `RunningService` keeps the MCP connection alive. When the
241/// server updates its tools, the handler automatically syncs with the tool server.
242pub struct McpClientHandler {
243    client_info: rmcp::model::ClientInfo,
244    tool_server_handle: ToolServerHandle,
245    /// Tracks which tool names were registered by this handler so they
246    /// can be removed and replaced on list-change notifications.
247    managed_tool_names: Arc<RwLock<Vec<String>>>,
248}
249
250impl McpClientHandler {
251    /// Create a new handler with the given client info and tool server handle.
252    ///
253    /// The `tool_server_handle` should be a clone of the handle used by the agent,
254    /// so that tool updates are reflected in agent requests.
255    pub fn new(client_info: rmcp::model::ClientInfo, tool_server_handle: ToolServerHandle) -> Self {
256        Self {
257            client_info,
258            tool_server_handle,
259            managed_tool_names: Arc::new(RwLock::new(Vec::new())),
260        }
261    }
262
263    /// Connect to an MCP server, fetch the initial tool list, and register
264    /// all tools with the tool server.
265    ///
266    /// Returns the running MCP service. The connection stays alive as long as the
267    /// returned `RunningService` is held. When the server sends
268    /// `notifications/tools/list_changed`, this handler automatically re-fetches
269    /// and re-registers tools.
270    ///
271    /// # Errors
272    ///
273    /// Returns [`McpClientError`] if the connection fails, the initial tool fetch
274    /// fails, or tool registration with the tool server fails.
275    pub async fn connect<T, E, A>(
276        self,
277        transport: T,
278    ) -> Result<rmcp::service::RunningService<rmcp::service::RoleClient, Self>, McpClientError>
279    where
280        T: rmcp::transport::IntoTransport<rmcp::service::RoleClient, E, A>,
281        E: std::error::Error + Send + Sync + 'static,
282    {
283        let service = ServiceExt::serve(self, transport)
284            .await
285            .map_err(|e| McpClientError::ConnectionError(e.to_string()))?;
286
287        let tools = service.peer().list_all_tools().await?;
288
289        {
290            let handler = service.service();
291            let mut managed = handler.managed_tool_names.write().await;
292
293            for tool in tools {
294                let tool_name = tool.name.to_string();
295                let mcp_tool = McpTool::from_mcp_server(tool, service.peer().clone());
296                handler.tool_server_handle.add_tool(mcp_tool).await?;
297                managed.push(tool_name);
298            }
299        }
300
301        Ok(service)
302    }
303}
304
305impl rmcp::handler::client::ClientHandler for McpClientHandler {
306    fn get_info(&self) -> rmcp::model::ClientInfo {
307        self.client_info.clone()
308    }
309
310    async fn on_tool_list_changed(
311        &self,
312        context: rmcp::service::NotificationContext<rmcp::service::RoleClient>,
313    ) {
314        let tools = match context.peer.list_all_tools().await {
315            Ok(tools) => tools,
316            Err(e) => {
317                tracing::error!("Failed to re-fetch MCP tool list: {e}");
318                return;
319            }
320        };
321
322        let mut managed = self.managed_tool_names.write().await;
323
324        for name in managed.drain(..) {
325            if let Err(e) = self.tool_server_handle.remove_tool(&name).await {
326                tracing::warn!("Failed to remove MCP tool '{name}' during refresh: {e}");
327            }
328        }
329
330        for tool in tools {
331            let tool_name = tool.name.to_string();
332            let mcp_tool = McpTool::from_mcp_server(tool, context.peer.clone());
333            match self.tool_server_handle.add_tool(mcp_tool).await {
334                Ok(()) => {
335                    managed.push(tool_name);
336                }
337                Err(e) => {
338                    tracing::error!("Failed to register MCP tool '{tool_name}': {e}");
339                }
340            }
341        }
342
343        tracing::info!(
344            tool_count = managed.len(),
345            "MCP tool list refreshed successfully"
346        );
347    }
348}
349
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use rmcp::handler::client::ClientHandler;
    use rmcp::model::*;
    use rmcp::service::RequestContext;
    use rmcp::{RoleServer, ServerHandler, ServiceExt};
    use tokio::sync::RwLock;

    use super::McpClientHandler;
    use crate::tool::server::ToolServer;

    /// Upper bound on how long a test waits for an asynchronous tool refresh.
    const REFRESH_TIMEOUT: Duration = Duration::from_secs(5);

    /// Delay between polls of the tool server while waiting for a refresh.
    const POLL_INTERVAL: Duration = Duration::from_millis(10);

    /// An MCP server whose tool list can be swapped at runtime.
    #[derive(Clone)]
    struct DynamicToolServer {
        tools: Arc<RwLock<Vec<Tool>>>,
    }

    impl DynamicToolServer {
        fn new(tools: Vec<Tool>) -> Self {
            Self {
                tools: Arc::new(RwLock::new(tools)),
            }
        }

        async fn set_tools(&self, tools: Vec<Tool>) {
            *self.tools.write().await = tools;
        }
    }

    impl ServerHandler for DynamicToolServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
                .with_protocol_version(ProtocolVersion::LATEST)
                .with_server_info(Implementation::new("test-dynamic-server", "0.1.0"))
        }

        async fn list_tools(
            &self,
            _request: Option<PaginatedRequestParams>,
            _context: RequestContext<RoleServer>,
        ) -> Result<ListToolsResult, ErrorData> {
            let tools = self.tools.read().await.clone();
            Ok(ListToolsResult::with_all_items(tools))
        }

        async fn call_tool(
            &self,
            request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            Ok(CallToolResult::success(vec![Content::text(format!(
                "called {}",
                request.name
            ))]))
        }
    }

    fn make_tool(name: &str, description: &str) -> Tool {
        Tool::new(
            name.to_string(),
            description.to_string(),
            Arc::new(serde_json::Map::new()),
        )
    }

    #[tokio::test]
    async fn test_mcp_client_handler_initial_tool_registration() {
        let initial_tools = vec![
            make_tool("tool_a", "First tool"),
            make_tool("tool_b", "Second tool"),
        ];

        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        tokio::spawn(async move {
            let service = server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            service.waiting().await.expect("server error");
        });

        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());

        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"tool_a"));
        assert!(names.contains(&"tool_b"));
    }

    #[tokio::test]
    async fn test_mcp_client_handler_refreshes_on_tool_list_changed() {
        let initial_tools = vec![make_tool("alpha", "Alpha tool")];

        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        let server_service_handle = tokio::spawn(async move {
            server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start")
        });

        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());

        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        // Verify initial state
        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "alpha");

        // Update the server's tool list
        server
            .set_tools(vec![
                make_tool("beta", "Beta tool"),
                make_tool("gamma", "Gamma tool"),
            ])
            .await;

        // Send the notification from the server side
        let server_service = server_service_handle.await.unwrap();
        server_service
            .peer()
            .notify_tool_list_changed()
            .await
            .expect("failed to send notification");

        // The handler processes the notification asynchronously. Poll until the
        // refreshed set is visible (or the timeout elapses) instead of sleeping for
        // a fixed interval, which is flaky on slow machines and wasteful on fast ones.
        let deadline = tokio::time::Instant::now() + REFRESH_TIMEOUT;
        let defs = loop {
            let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
            let refreshed = defs.len() == 2
                && defs.iter().any(|d| d.name == "beta")
                && defs.iter().any(|d| d.name == "gamma");
            if refreshed || tokio::time::Instant::now() >= deadline {
                break defs;
            }
            tokio::time::sleep(POLL_INTERVAL).await;
        };

        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"beta"), "expected 'beta' in {names:?}");
        assert!(names.contains(&"gamma"), "expected 'gamma' in {names:?}");
        // The old tool must be gone
        assert!(
            !names.contains(&"alpha"),
            "expected 'alpha' to be removed, found {names:?}"
        );
    }

    #[tokio::test]
    async fn test_mcp_client_handler_get_info_delegates() {
        let client_info = ClientInfo::new(
            ClientCapabilities::default(),
            Implementation::new("test-client", "1.0.0"),
        );

        let tool_server_handle = ToolServer::new().run();
        let handler = McpClientHandler::new(client_info.clone(), tool_server_handle);

        let returned = handler.get_info();
        assert_eq!(returned.client_info.name, "test-client");
        assert_eq!(returned.client_info.version, "1.0.0");
    }
}