Skip to main content

llm_stack/
mcp.rs

1//! MCP (Model Context Protocol) integration for `ToolRegistry`.
2//!
//! This module defines the [`McpService`] trait — a minimal contract for MCP-like
//! tool sources. Users implement this trait for their chosen MCP client library
//! (e.g., `rmcp`), keeping version coupling in their code rather than in `llm_stack`.
6//!
7//! # Design Philosophy
8//!
//! Rather than depending on a specific MCP client library, `llm_stack` defines a
//! simple trait that any MCP implementation can satisfy. This means:
11//!
12//! - No forced dependency upgrades when MCP libraries release new versions
13//! - Users choose their preferred MCP client and version
14//! - Easy to mock for testing
15//! - Simple enough to implement (~50 lines for rmcp)
16//!
17//! # Example: Implementing for rmcp
18//!
19//! ```rust,ignore
//! use std::future::Future;
//! use std::pin::Pin;
//! use std::sync::Arc;
//! use rmcp::service::RunningService;
//! use rmcp::handler::client::ClientHandler;
//! use rmcp::RoleClient;
//! use llm_stack::{McpService, McpError, ToolDefinition, JsonSchema};
//!
//! /// Adapter: wraps an rmcp `RunningService` to implement `llm_stack`'s `McpService`
//! pub struct RmcpAdapter<S: ClientHandler> {
//!     service: Arc<RunningService<RoleClient, S>>,
//! }
//!
//! impl<S: ClientHandler> RmcpAdapter<S> {
//!     pub fn new(service: Arc<RunningService<RoleClient, S>>) -> Self {
//!         Self { service }
//!     }
//! }
//!
//! // `McpService` must stay object-safe, so its methods return boxed futures:
//! // implement them as plain `fn`s that wrap an `async` block in `Box::pin`.
//! impl<S: ClientHandler> McpService for RmcpAdapter<S> {
//!     fn list_tools(
//!         &self,
//!     ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>> {
//!         Box::pin(async move {
//!             let tools = self.service
//!                 .list_all_tools()
//!                 .await
//!                 .map_err(|e| McpError::Protocol(e.to_string()))?;
//!
//!             Ok(tools.into_iter().map(|t| ToolDefinition {
//!                 name: t.name.to_string(),
//!                 description: t.description.map(|d| d.to_string()).unwrap_or_default(),
//!                 parameters: JsonSchema::new(
//!                     serde_json::to_value(&*t.input_schema).unwrap_or_default()
//!                 ),
//!                 retry: None,
//!             }).collect())
//!         })
//!     }
//!
//!     fn call_tool(
//!         &self,
//!         name: &str,
//!         args: serde_json::Value,
//!     ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
//!         use rmcp::model::CallToolRequestParams;
//!
//!         // The returned future may only borrow `self`, so copy `name` into the
//!         // request before entering the async block.
//!         let params = CallToolRequestParams {
//!             meta: None,
//!             name: name.to_string().into(),
//!             arguments: args.as_object().cloned(),
//!             task: None,
//!         };
//!
//!         Box::pin(async move {
//!             let result = self.service
//!                 .call_tool(params)
//!                 .await
//!                 .map_err(|e| McpError::ToolExecution(e.to_string()))?;
//!
//!             if result.is_error.unwrap_or(false) {
//!                 return Err(McpError::ToolExecution(extract_text(&result.content)));
//!             }
//!
//!             Ok(extract_text(&result.content))
//!         })
//!     }
//! }
//!
//! fn extract_text(content: &[rmcp::model::Content]) -> String {
//!     use rmcp::model::RawContent;
//!     content.iter().map(|c| match &c.raw {
//!         RawContent::Text(t) => t.text.clone(),
//!         _ => "[non-text]".into(),
//!     }).collect::<Vec<_>>().join("\n")
//! }
83//! ```
84//!
85//! # Usage
86//!
87//! ```rust,ignore
88//! use std::sync::Arc;
89//! use llm_stack::{ToolRegistry, McpRegistryExt};
90//!
91//! // Create your MCP service (using rmcp or any other library)
92//! let mcp_service = Arc::new(RmcpAdapter::new(rmcp_client));
93//!
//! // Register the service's tools with the registry
95//! let mut registry = ToolRegistry::new();
96//! registry.register_mcp_service(&mcp_service).await?;
97//! ```
98
99use std::future::Future;
100use std::pin::Pin;
101use std::sync::Arc;
102
103use crate::tool::{ToolError, ToolHandler};
104use crate::{ToolDefinition, ToolRegistry};
105
106/// Error type for MCP operations.
107#[derive(Debug, Clone, thiserror::Error)]
108pub enum McpError {
109    /// Error during MCP protocol communication.
110    #[error("MCP protocol error: {0}")]
111    Protocol(String),
112
113    /// Error during tool execution.
114    #[error("MCP tool execution error: {0}")]
115    ToolExecution(String),
116}
117
118/// Minimal contract for MCP-like tool sources.
119///
120/// Implement this trait to bridge any MCP client library with llm-core's
121/// tool system. The trait requires only two operations:
122///
123/// 1. List available tools (with their schemas)
124/// 2. Call a tool by name with JSON arguments
125///
126/// # Thread Safety
127///
128/// Implementations must be `Send + Sync` to allow use across async tasks.
129/// Typically achieved by wrapping the underlying client in `Arc`.
130///
131/// # Object Safety
132///
133/// This trait is object-safe (`dyn McpService`) to allow storing different
134/// MCP service implementations in the same registry.
135pub trait McpService: Send + Sync {
136    /// Lists all tools available from the MCP server.
137    fn list_tools(
138        &self,
139    ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>;
140
141    /// Calls a tool on the MCP server.
142    ///
143    /// # Arguments
144    ///
145    /// * `name` - The name of the tool to call
146    /// * `args` - JSON arguments matching the tool's input schema
147    ///
148    /// # Returns
149    ///
150    /// The tool's text output, or an error if execution failed.
151    fn call_tool(
152        &self,
153        name: &str,
154        args: serde_json::Value,
155    ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>>;
156}
157
158/// A [`ToolHandler`] that delegates execution to an [`McpService`].
159struct McpToolHandler {
160    service: Arc<dyn McpService>,
161    definition: ToolDefinition,
162}
163
164impl McpToolHandler {
165    fn new(service: Arc<dyn McpService>, definition: ToolDefinition) -> Self {
166        Self {
167            service,
168            definition,
169        }
170    }
171}
172
173impl std::fmt::Debug for McpToolHandler {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        f.debug_struct("McpToolHandler")
176            .field("tool", &self.definition.name)
177            .finish_non_exhaustive()
178    }
179}
180
181impl ToolHandler<()> for McpToolHandler {
182    fn definition(&self) -> ToolDefinition {
183        self.definition.clone()
184    }
185
186    fn execute<'a>(
187        &'a self,
188        input: serde_json::Value,
189        _ctx: &'a (),
190    ) -> Pin<Box<dyn Future<Output = Result<crate::tool::ToolOutput, ToolError>> + Send + 'a>> {
191        Box::pin(async move {
192            self.service
193                .call_tool(&self.definition.name, input)
194                .await
195                .map(crate::tool::ToolOutput::new)
196                .map_err(|e| ToolError::new(e.to_string()))
197        })
198    }
199}
200
201/// Extension trait for registering MCP services with a [`ToolRegistry`].
202pub trait McpRegistryExt {
203    /// Registers all tools from an MCP service.
204    ///
205    /// Discovers tools from the service and registers each one as a
206    /// `ToolHandler` that delegates execution to the service.
207    ///
208    /// # Returns
209    ///
210    /// The number of tools registered.
211    fn register_mcp_service<S: McpService + 'static>(
212        &mut self,
213        service: &Arc<S>,
214    ) -> impl Future<Output = Result<usize, McpError>> + Send;
215
216    /// Registers specific tools from an MCP service by name.
217    ///
218    /// Only registers tools whose names are in the provided list.
219    /// Tools not found on the server are silently skipped.
220    ///
221    /// # Returns
222    ///
223    /// The number of tools actually registered.
224    fn register_mcp_tools_by_name<S: McpService + 'static>(
225        &mut self,
226        service: &Arc<S>,
227        tool_names: &[&str],
228    ) -> impl Future<Output = Result<usize, McpError>> + Send;
229}
230
231impl McpRegistryExt for ToolRegistry<()> {
232    async fn register_mcp_service<S: McpService + 'static>(
233        &mut self,
234        service: &Arc<S>,
235    ) -> Result<usize, McpError> {
236        let tools = service.list_tools().await?;
237        let count = tools.len();
238
239        for definition in tools {
240            let handler =
241                McpToolHandler::new(Arc::clone(service) as Arc<dyn McpService>, definition);
242            self.register(handler);
243        }
244
245        Ok(count)
246    }
247
248    async fn register_mcp_tools_by_name<S: McpService + 'static>(
249        &mut self,
250        service: &Arc<S>,
251        tool_names: &[&str],
252    ) -> Result<usize, McpError> {
253        let tools = service.list_tools().await?;
254        let mut count = 0;
255
256        for definition in tools {
257            if tool_names.contains(&definition.name.as_str()) {
258                let handler =
259                    McpToolHandler::new(Arc::clone(service) as Arc<dyn McpService>, definition);
260                self.register(handler);
261                count += 1;
262            }
263        }
264
265        Ok(count)
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::JsonSchema;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Mock MCP service for testing.
    ///
    /// Lists a fixed set of tools and counts `call_tool` invocations. When
    /// `fail_calls` is set, every call returns [`McpError::ToolExecution`].
    struct MockMcpService {
        tools: Vec<ToolDefinition>,
        call_count: AtomicUsize,
        fail_calls: bool,
    }

    impl MockMcpService {
        fn new(tools: Vec<ToolDefinition>) -> Self {
            Self {
                tools,
                call_count: AtomicUsize::new(0),
                fail_calls: false,
            }
        }

        /// Like [`MockMcpService::new`], but every tool call fails.
        fn with_failing_calls(tools: Vec<ToolDefinition>) -> Self {
            Self {
                fail_calls: true,
                ..Self::new(tools)
            }
        }
    }

    impl McpService for MockMcpService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            let tools = self.tools.clone();
            Box::pin(async move { Ok(tools) })
        }

        fn call_tool(
            &self,
            name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            // Build the result eagerly: the returned future may not borrow `name`.
            let result = if self.fail_calls {
                Err(McpError::ToolExecution(format!("{name} failed")))
            } else {
                Ok(format!("Called {name}"))
            };
            Box::pin(async move { result })
        }
    }

    /// MCP service whose every operation fails with a protocol error.
    struct UnreachableMcpService;

    impl McpService for UnreachableMcpService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            Box::pin(async { Err(McpError::Protocol("connection refused".to_string())) })
        }

        fn call_tool(
            &self,
            _name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            Box::pin(async { Err(McpError::Protocol("connection refused".to_string())) })
        }
    }

    fn test_tool(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("{name} description"),
            parameters: JsonSchema::new(serde_json::json!({"type": "object"})),
            retry: None,
        }
    }

    #[test]
    fn test_trait_is_object_safe() {
        fn assert_object_safe(_: &dyn McpService) {}
        let mock = MockMcpService::new(vec![]);
        assert_object_safe(&mock);
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            McpError::Protocol("boom".to_string()).to_string(),
            "MCP protocol error: boom"
        );
        assert_eq!(
            McpError::ToolExecution("boom".to_string()).to_string(),
            "MCP tool execution error: boom"
        );
    }

    #[tokio::test]
    async fn test_register_mcp_service() {
        let service = Arc::new(MockMcpService::new(vec![
            test_tool("tool_a"),
            test_tool("tool_b"),
        ]));

        let mut registry = ToolRegistry::new();
        let count = registry.register_mcp_service(&service).await.unwrap();

        assert_eq!(count, 2);
        assert_eq!(registry.len(), 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_some());
    }

    #[tokio::test]
    async fn test_register_mcp_tools_by_name() {
        let service = Arc::new(MockMcpService::new(vec![
            test_tool("tool_a"),
            test_tool("tool_b"),
            test_tool("tool_c"),
        ]));

        let mut registry = ToolRegistry::new();
        let count = registry
            .register_mcp_tools_by_name(&service, &["tool_a", "tool_c"])
            .await
            .unwrap();

        assert_eq!(count, 2);
        assert_eq!(registry.len(), 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_none());
        assert!(registry.get("tool_c").is_some());
    }

    #[tokio::test]
    async fn test_register_mcp_tools_by_name_skips_unknown_names() {
        let service = Arc::new(MockMcpService::new(vec![test_tool("tool_a")]));

        let mut registry = ToolRegistry::new();
        let count = registry
            .register_mcp_tools_by_name(&service, &["tool_a", "missing"])
            .await
            .unwrap();

        assert_eq!(count, 1);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("missing").is_none());
    }

    #[tokio::test]
    async fn test_list_tools_error_is_propagated() {
        let service = Arc::new(UnreachableMcpService);

        let mut registry = ToolRegistry::new();
        let err = registry.register_mcp_service(&service).await.unwrap_err();
        assert!(matches!(err, McpError::Protocol(_)));

        let err = registry
            .register_mcp_tools_by_name(&service, &["tool_a"])
            .await
            .unwrap_err();
        assert!(matches!(err, McpError::Protocol(_)));

        // Discovery failed before anything could be registered.
        assert_eq!(registry.len(), 0);
    }

    #[tokio::test]
    async fn test_mcp_tool_execution() {
        let service = Arc::new(MockMcpService::new(vec![test_tool("my_tool")]));

        let mut registry = ToolRegistry::new();
        registry.register_mcp_service(&service).await.unwrap();

        let handler = registry.get("my_tool").unwrap();
        let result = handler.execute(serde_json::json!({}), &()).await.unwrap();

        assert_eq!(result.content, "Called my_tool");
        assert_eq!(service.call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_mcp_tool_execution_error() {
        let service = Arc::new(MockMcpService::with_failing_calls(vec![test_tool("my_tool")]));

        let mut registry = ToolRegistry::new();
        registry.register_mcp_service(&service).await.unwrap();

        let handler = registry.get("my_tool").unwrap();
        let result = handler.execute(serde_json::json!({}), &()).await;

        assert!(result.is_err());
        assert_eq!(service.call_count.load(Ordering::SeqCst), 1);
    }
}