Skip to main content

rig/tool/
rmcp.rs

1//! MCP (Model Context Protocol) integration via the `rmcp` crate.
2//!
3//! This module provides:
4//! - [`McpTool`]: A wrapper that adapts an `rmcp` tool for use in Rig's tool system.
5//! - [`McpClientHandler`]: A client handler that reacts to `notifications/tools/list_changed`
6//!   by re-fetching the tool list and updating the [`ToolServer`](super::server::ToolServer).
7//!
8//! # Example
9//!
10//! ```rust,ignore
11//! use rig::tool::rmcp::McpClientHandler;
12//! use rig::tool::server::ToolServer;
13//! use rmcp::ServiceExt;
14//!
15//! // 1. Create a ToolServer and get a handle
16//! let tool_server_handle = ToolServer::new().run();
17//!
18//! // 2. Create a handler that auto-updates tools on list changes
19//! let handler = McpClientHandler::new(client_info, tool_server_handle.clone());
20//!
21//! // 3. Connect to the MCP server and register initial tools
22//! let mcp_service = handler.connect(transport).await?;
23//!
24//! // 4. Build an agent using the shared tool server handle
25//! let agent = openai_client
26//!     .agent(openai::GPT_5_2)
27//!     .preamble("You are a helpful assistant.")
28//!     .tool_server_handle(tool_server_handle)
29//!     .build();
30//! ```
31
32use std::borrow::Cow;
33use std::sync::Arc;
34
35use rmcp::ServiceExt;
36use rmcp::model::RawContent;
37use tokio::sync::RwLock;
38
39use crate::completion::ToolDefinition;
40use crate::tool::ToolDyn;
41use crate::tool::ToolError;
42use crate::tool::server::{ToolServerError, ToolServerHandle};
43use crate::wasm_compat::WasmBoxedFuture;
44
45/// A Rig tool adapter wrapping an `rmcp` MCP tool.
46///
47/// Bridges between the MCP tool protocol and Rig's [`ToolDyn`] trait,
48/// allowing MCP tools to be used seamlessly in Rig agents.
49#[derive(Clone)]
50pub struct McpTool {
51    definition: rmcp::model::Tool,
52    client: rmcp::service::ServerSink,
53}
54
55impl McpTool {
56    /// Create a new `McpTool` from an MCP tool definition and server sink.
57    pub fn from_mcp_server(
58        definition: rmcp::model::Tool,
59        client: rmcp::service::ServerSink,
60    ) -> Self {
61        Self { definition, client }
62    }
63}
64
65impl From<&rmcp::model::Tool> for ToolDefinition {
66    fn from(val: &rmcp::model::Tool) -> Self {
67        Self {
68            name: val.name.to_string(),
69            description: val.description.clone().unwrap_or(Cow::from("")).to_string(),
70            parameters: val.schema_as_json_value(),
71        }
72    }
73}
74
75impl From<rmcp::model::Tool> for ToolDefinition {
76    fn from(val: rmcp::model::Tool) -> Self {
77        Self {
78            name: val.name.to_string(),
79            description: val.description.clone().unwrap_or(Cow::from("")).to_string(),
80            parameters: val.schema_as_json_value(),
81        }
82    }
83}
84
85#[derive(Debug, thiserror::Error)]
86#[error("MCP tool error: {0}")]
87pub struct McpToolError(String);
88
89impl From<McpToolError> for ToolError {
90    fn from(e: McpToolError) -> Self {
91        ToolError::ToolCallError(Box::new(e))
92    }
93}
94
95impl ToolDyn for McpTool {
96    fn name(&self) -> String {
97        self.definition.name.to_string()
98    }
99
100    fn definition(&self, _prompt: String) -> WasmBoxedFuture<'_, ToolDefinition> {
101        Box::pin(async move {
102            ToolDefinition {
103                name: self.definition.name.to_string(),
104                description: self
105                    .definition
106                    .description
107                    .clone()
108                    .unwrap_or(Cow::from(""))
109                    .to_string(),
110                parameters: serde_json::to_value(&self.definition.input_schema).unwrap_or_default(),
111            }
112        })
113    }
114
115    fn call(&self, args: String) -> WasmBoxedFuture<'_, Result<String, ToolError>> {
116        let name = self.definition.name.clone();
117        let arguments = serde_json::from_str(&args).unwrap_or_default();
118
119        Box::pin(async move {
120            let result = self
121                .client
122                .call_tool(rmcp::model::CallToolRequestParams {
123                    name,
124                    arguments,
125                    meta: None,
126                    task: None,
127                })
128                .await
129                .map_err(|e| McpToolError(format!("Tool returned an error: {e}")))?;
130
131            if let Some(true) = result.is_error {
132                let error_msg = result
133                    .content
134                    .into_iter()
135                    .map(|x| x.raw.as_text().map(|y| y.to_owned()))
136                    .map(|x| x.map(|x| x.clone().text))
137                    .collect::<Option<Vec<String>>>();
138
139                let error_message = error_msg.map(|x| x.join("\n"));
140                if let Some(error_message) = error_message {
141                    return Err(McpToolError(error_message).into());
142                } else {
143                    return Err(McpToolError("No message returned".to_string()).into());
144                }
145            };
146
147            Ok(result
148                .content
149                .into_iter()
150                .map(|c| match c.raw {
151                    rmcp::model::RawContent::Text(raw) => raw.text,
152                    rmcp::model::RawContent::Image(raw) => {
153                        format!("data:{};base64,{}", raw.mime_type, raw.data)
154                    }
155                    rmcp::model::RawContent::Resource(raw) => match raw.resource {
156                        rmcp::model::ResourceContents::TextResourceContents {
157                            uri,
158                            mime_type,
159                            text,
160                            ..
161                        } => {
162                            format!(
163                                "{mime_type}{uri}:{text}",
164                                mime_type = mime_type
165                                    .map(|m| format!("data:{m};"))
166                                    .unwrap_or_default(),
167                            )
168                        }
169                        rmcp::model::ResourceContents::BlobResourceContents {
170                            uri,
171                            mime_type,
172                            blob,
173                            ..
174                        } => format!(
175                            "{mime_type}{uri}:{blob}",
176                            mime_type = mime_type
177                                .map(|m| format!("data:{m};"))
178                                .unwrap_or_default(),
179                        ),
180                    },
181                    RawContent::Audio(_) => {
182                        panic!("Support for audio results from an MCP tool is currently unimplemented. Come back later!")
183                    }
184                    thing => {
185                        panic!("Unsupported type found: {thing:?}")
186                    }
187                })
188                .collect::<String>())
189        })
190    }
191}
192
193/// Error type for [`McpClientHandler`] operations.
194#[derive(Debug, thiserror::Error)]
195pub enum McpClientError {
196    /// Failed to establish the MCP connection or complete the handshake.
197    #[error("MCP connection error: {0}")]
198    ConnectionError(String),
199
200    /// Failed to fetch the tool list from the MCP server.
201    #[error("Failed to fetch MCP tool list: {0}")]
202    ToolFetchError(#[from] rmcp::ServiceError),
203
204    /// Failed to update the tool server with new tools.
205    #[error("Tool server error: {0}")]
206    ToolServerError(#[from] ToolServerError),
207}
208
209/// An MCP client handler that automatically re-fetches the tool list when the
210/// server sends a `notifications/tools/list_changed` notification.
211///
212/// This handler implements [`rmcp::ClientHandler`] and bridges the MCP
213/// notification lifecycle with Rig's [`ToolServer`](super::server::ToolServer).
214/// When the MCP server's available tools change, this handler:
215/// 1. Removes previously registered MCP tools from the tool server
216/// 2. Re-fetches the full tool list from the MCP server
217/// 3. Registers the updated tools with the tool server
218///
219/// # Usage
220///
221/// Use [`McpClientHandler::connect`] for a streamlined setup that handles
222/// connection, initial tool fetch, and registration in one call:
223///
224/// ```rust,ignore
225/// let tool_server_handle = ToolServer::new().run();
226/// let handler = McpClientHandler::new(client_info, tool_server_handle.clone());
227/// let mcp_service = handler.connect(transport).await?;
228/// ```
229///
230/// The returned `RunningService` keeps the MCP connection alive. When the
231/// server updates its tools, the handler automatically syncs with the tool server.
232pub struct McpClientHandler {
233    client_info: rmcp::model::ClientInfo,
234    tool_server_handle: ToolServerHandle,
235    /// Tracks which tool names were registered by this handler so they
236    /// can be removed and replaced on list-change notifications.
237    managed_tool_names: Arc<RwLock<Vec<String>>>,
238}
239
240impl McpClientHandler {
241    /// Create a new handler with the given client info and tool server handle.
242    ///
243    /// The `tool_server_handle` should be a clone of the handle used by the agent,
244    /// so that tool updates are reflected in agent requests.
245    pub fn new(client_info: rmcp::model::ClientInfo, tool_server_handle: ToolServerHandle) -> Self {
246        Self {
247            client_info,
248            tool_server_handle,
249            managed_tool_names: Arc::new(RwLock::new(Vec::new())),
250        }
251    }
252
253    /// Connect to an MCP server, fetch the initial tool list, and register
254    /// all tools with the tool server.
255    ///
256    /// Returns the running MCP service. The connection stays alive as long as the
257    /// returned `RunningService` is held. When the server sends
258    /// `notifications/tools/list_changed`, this handler automatically re-fetches
259    /// and re-registers tools.
260    ///
261    /// # Errors
262    ///
263    /// Returns [`McpClientError`] if the connection fails, the initial tool fetch
264    /// fails, or tool registration with the tool server fails.
265    pub async fn connect<T, E, A>(
266        self,
267        transport: T,
268    ) -> Result<rmcp::service::RunningService<rmcp::service::RoleClient, Self>, McpClientError>
269    where
270        T: rmcp::transport::IntoTransport<rmcp::service::RoleClient, E, A>,
271        E: std::error::Error + Send + Sync + 'static,
272    {
273        let service = ServiceExt::serve(self, transport)
274            .await
275            .map_err(|e| McpClientError::ConnectionError(e.to_string()))?;
276
277        let tools = service.peer().list_all_tools().await?;
278
279        {
280            let handler = service.service();
281            let mut managed = handler.managed_tool_names.write().await;
282
283            for tool in tools {
284                let tool_name = tool.name.to_string();
285                let mcp_tool = McpTool::from_mcp_server(tool, service.peer().clone());
286                handler.tool_server_handle.add_tool(mcp_tool).await?;
287                managed.push(tool_name);
288            }
289        }
290
291        Ok(service)
292    }
293}
294
295impl rmcp::handler::client::ClientHandler for McpClientHandler {
296    fn get_info(&self) -> rmcp::model::ClientInfo {
297        self.client_info.clone()
298    }
299
300    async fn on_tool_list_changed(
301        &self,
302        context: rmcp::service::NotificationContext<rmcp::service::RoleClient>,
303    ) {
304        let tools = match context.peer.list_all_tools().await {
305            Ok(tools) => tools,
306            Err(e) => {
307                tracing::error!("Failed to re-fetch MCP tool list: {e}");
308                return;
309            }
310        };
311
312        let mut managed = self.managed_tool_names.write().await;
313
314        for name in managed.drain(..) {
315            if let Err(e) = self.tool_server_handle.remove_tool(&name).await {
316                tracing::warn!("Failed to remove MCP tool '{name}' during refresh: {e}");
317            }
318        }
319
320        for tool in tools {
321            let tool_name = tool.name.to_string();
322            let mcp_tool = McpTool::from_mcp_server(tool, context.peer.clone());
323            match self.tool_server_handle.add_tool(mcp_tool).await {
324                Ok(()) => {
325                    managed.push(tool_name);
326                }
327                Err(e) => {
328                    tracing::error!("Failed to register MCP tool '{tool_name}': {e}");
329                }
330            }
331        }
332
333        tracing::info!(
334            tool_count = managed.len(),
335            "MCP tool list refreshed successfully"
336        );
337    }
338}
339
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use rmcp::handler::client::ClientHandler;
    use rmcp::model::*;
    use rmcp::service::RequestContext;
    use rmcp::{RoleServer, ServerHandler, ServiceExt};
    use tokio::sync::RwLock;

    use super::McpClientHandler;
    use crate::tool::server::ToolServer;

    /// An MCP server whose tool list can be swapped at runtime.
    #[derive(Clone)]
    struct DynamicToolServer {
        tools: Arc<RwLock<Vec<Tool>>>,
    }

    impl DynamicToolServer {
        fn new(tools: Vec<Tool>) -> Self {
            Self {
                tools: Arc::new(RwLock::new(tools)),
            }
        }

        async fn set_tools(&self, tools: Vec<Tool>) {
            *self.tools.write().await = tools;
        }
    }

    impl ServerHandler for DynamicToolServer {
        fn get_info(&self) -> ServerInfo {
            ServerInfo {
                protocol_version: ProtocolVersion::V_2024_11_05,
                capabilities: ServerCapabilities::builder().enable_tools().build(),
                server_info: Implementation {
                    name: "test-dynamic-server".to_string(),
                    version: "0.1.0".to_string(),
                    ..Default::default()
                },
                instructions: None,
            }
        }

        async fn list_tools(
            &self,
            _request: Option<PaginatedRequestParams>,
            _context: RequestContext<RoleServer>,
        ) -> Result<ListToolsResult, ErrorData> {
            let tools = self.tools.read().await.clone();
            Ok(ListToolsResult {
                tools,
                next_cursor: None,
                meta: None,
            })
        }

        async fn call_tool(
            &self,
            request: CallToolRequestParams,
            _context: RequestContext<RoleServer>,
        ) -> Result<CallToolResult, ErrorData> {
            Ok(CallToolResult::success(vec![Content::text(format!(
                "called {}",
                request.name
            ))]))
        }
    }

    fn make_tool(name: &str, description: &str) -> Tool {
        Tool::new(
            name.to_string(),
            description.to_string(),
            Arc::new(serde_json::Map::new()),
        )
    }

    #[tokio::test]
    async fn test_mcp_client_handler_initial_tool_registration() {
        let initial_tools = vec![
            make_tool("tool_a", "First tool"),
            make_tool("tool_b", "Second tool"),
        ];

        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        tokio::spawn(async move {
            let _service = server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start");
            _service.waiting().await.expect("server error");
        });

        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());

        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"tool_a"));
        assert!(names.contains(&"tool_b"));
    }

    #[tokio::test]
    async fn test_mcp_client_handler_refreshes_on_tool_list_changed() {
        let initial_tools = vec![make_tool("alpha", "Alpha tool")];

        let server = DynamicToolServer::new(initial_tools);
        let tool_server_handle = ToolServer::new().run();

        let (client_to_server, server_from_client) = tokio::io::duplex(8192);
        let (server_to_client, client_from_server) = tokio::io::duplex(8192);

        let server_clone = server.clone();
        let server_service_handle = tokio::spawn(async move {
            server_clone
                .serve((server_from_client, server_to_client))
                .await
                .expect("server failed to start")
        });

        let client_info = ClientInfo::default();
        let handler = McpClientHandler::new(client_info, tool_server_handle.clone());

        let _mcp_service = handler
            .connect((client_from_server, client_to_server))
            .await
            .expect("connect failed");

        // Verify initial state
        let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "alpha");

        // Update the server's tool list
        server
            .set_tools(vec![
                make_tool("beta", "Beta tool"),
                make_tool("gamma", "Gamma tool"),
            ])
            .await;

        // Send the notification from the server side
        let server_service = server_service_handle.await.unwrap();
        server_service
            .peer()
            .notify_tool_list_changed()
            .await
            .expect("failed to send notification");

        // The handler processes the notification asynchronously. Poll until the
        // refresh lands instead of sleeping a fixed amount, which is flaky on slow
        // machines. The deadline keeps a real failure from hanging the test.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
        let defs = loop {
            let defs = tool_server_handle.get_tool_defs(None).await.unwrap();
            let refreshed = defs.iter().any(|d| d.name == "beta")
                && defs.iter().any(|d| d.name == "gamma")
                && !defs.iter().any(|d| d.name == "alpha");
            if refreshed || tokio::time::Instant::now() >= deadline {
                break defs;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        };

        assert_eq!(defs.len(), 2);

        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"beta"), "expected 'beta' in {names:?}");
        assert!(names.contains(&"gamma"), "expected 'gamma' in {names:?}");
        // The old tool must be gone
        assert!(
            !names.contains(&"alpha"),
            "expected 'alpha' to be removed, found {names:?}"
        );
    }

    #[tokio::test]
    async fn test_mcp_client_handler_get_info_delegates() {
        let client_info = ClientInfo {
            protocol_version: Default::default(),
            capabilities: ClientCapabilities::default(),
            client_info: Implementation {
                name: "test-client".to_string(),
                version: "1.0.0".to_string(),
                ..Default::default()
            },
            meta: None,
        };

        let tool_server_handle = ToolServer::new().run();
        let handler = McpClientHandler::new(client_info.clone(), tool_server_handle);

        let returned = handler.get_info();
        assert_eq!(returned.client_info.name, "test-client");
        assert_eq!(returned.client_info.version, "1.0.0");
    }
}