Skip to main content

llm_stack/
mcp.rs

1//! MCP (Model Context Protocol) integration for `ToolRegistry`.
2//!
3//! This module defines the [`McpService`] trait - a minimal contract for MCP-like
4//! tool sources. Users implement this trait for their chosen MCP client library
5//! (e.g., `rmcp`), keeping version coupling in their code rather than in llm-core.
6//!
7//! # Design Philosophy
8//!
9//! Rather than depending on a specific MCP client library, llm-core defines a
10//! simple trait that any MCP implementation can satisfy. This means:
11//!
12//! - No forced dependency upgrades when MCP libraries release new versions
13//! - Users choose their preferred MCP client and version
14//! - Easy to mock for testing
15//! - Simple enough to implement (~50 lines for rmcp)
16//!
17//! # Example: Implementing for rmcp
18//!
19//! ```rust,ignore
//! use std::future::Future;
//! use std::pin::Pin;
//! use std::sync::Arc;
//! use rmcp::service::RunningService;
//! use rmcp::handler::client::ClientHandler;
//! use rmcp::RoleClient;
//! use llm_stack::{ToolDefinition, JsonSchema};
//! use llm_stack::mcp::{McpService, McpError};
//!
//! /// Adapter: wraps rmcp RunningService to implement llm-core's McpService
//! pub struct RmcpAdapter<S: ClientHandler> {
//!     service: Arc<RunningService<RoleClient, S>>,
//! }
//!
//! impl<S: ClientHandler> RmcpAdapter<S> {
//!     pub fn new(service: Arc<RunningService<RoleClient, S>>) -> Self {
//!         Self { service }
//!     }
//! }
//!
//! // `McpService` methods return boxed futures (to stay object-safe), so the
//! // implementation wraps its async bodies in `Box::pin` instead of using `async fn`.
//! impl<S: ClientHandler> McpService for RmcpAdapter<S> {
//!     fn list_tools(
//!         &self,
//!     ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>> {
//!         Box::pin(async move {
//!             let tools = self.service
//!                 .list_all_tools()
//!                 .await
//!                 .map_err(|e| McpError::Protocol(e.to_string()))?;
//!
//!             Ok(tools.into_iter().map(|t| ToolDefinition {
//!                 name: t.name.to_string(),
//!                 description: t.description.map(|d| d.to_string()).unwrap_or_default(),
//!                 parameters: JsonSchema::new(
//!                     serde_json::to_value(&*t.input_schema).unwrap_or_default()
//!                 ),
//!                 retry: None,
//!             }).collect())
//!         })
//!     }
//!
//!     fn call_tool(
//!         &self,
//!         name: &str,
//!         args: serde_json::Value,
//!     ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
//!         use rmcp::model::CallToolRequestParams;
//!
//!         // The returned future may borrow `self` but not `name`, so build the
//!         // owned request before entering the async block.
//!         let params = CallToolRequestParams {
//!             meta: None,
//!             name: name.to_string().into(),
//!             arguments: args.as_object().cloned(),
//!             task: None,
//!         };
//!
//!         Box::pin(async move {
//!             let result = self.service
//!                 .call_tool(params)
//!                 .await
//!                 .map_err(|e| McpError::ToolExecution(e.to_string()))?;
//!
//!             if result.is_error.unwrap_or(false) {
//!                 return Err(McpError::ToolExecution(extract_text(&result.content)));
//!             }
//!
//!             Ok(extract_text(&result.content))
//!         })
//!     }
//! }
76//!
77//! fn extract_text(content: &[rmcp::model::Content]) -> String {
78//!     use rmcp::model::RawContent;
79//!     content.iter().map(|c| match &c.raw {
80//!         RawContent::Text(t) => t.text.clone(),
81//!         _ => "[non-text]".into(),
82//!     }).collect::<Vec<_>>().join("\n")
83//! }
84//! ```
85//!
86//! # Usage
87//!
88//! ```rust,ignore
89//! use std::sync::Arc;
90//! use llm_stack::ToolRegistry;
91//! use llm_stack::mcp::McpRegistryExt;
92//!
93//! // Create your MCP service (using rmcp or any other library)
94//! let mcp_service = Arc::new(RmcpAdapter::new(rmcp_client));
95//!
96//! // Register tools with llm-core
97//! let mut registry = ToolRegistry::new();
98//! registry.register_mcp_service(&mcp_service).await?;
99//! ```
100
101use std::future::Future;
102use std::pin::Pin;
103use std::sync::Arc;
104
105use crate::tool::{ToolError, ToolHandler};
106use crate::{ToolDefinition, ToolRegistry};
107
108/// Error type for MCP operations.
109#[derive(Debug, Clone, thiserror::Error)]
110pub enum McpError {
111    /// Error during MCP protocol communication.
112    #[error("MCP protocol error: {0}")]
113    Protocol(String),
114
115    /// Error during tool execution.
116    #[error("MCP tool execution error: {0}")]
117    ToolExecution(String),
118}
119
120/// Minimal contract for MCP-like tool sources.
121///
122/// Implement this trait to bridge any MCP client library with llm-core's
123/// tool system. The trait requires only two operations:
124///
125/// 1. List available tools (with their schemas)
126/// 2. Call a tool by name with JSON arguments
127///
128/// # Thread Safety
129///
130/// Implementations must be `Send + Sync` to allow use across async tasks.
131/// Typically achieved by wrapping the underlying client in `Arc`.
132///
133/// # Object Safety
134///
135/// This trait is object-safe (`dyn McpService`) to allow storing different
136/// MCP service implementations in the same registry.
137pub trait McpService: Send + Sync {
138    /// Lists all tools available from the MCP server.
139    fn list_tools(
140        &self,
141    ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>;
142
143    /// Calls a tool on the MCP server.
144    ///
145    /// # Arguments
146    ///
147    /// * `name` - The name of the tool to call
148    /// * `args` - JSON arguments matching the tool's input schema
149    ///
150    /// # Returns
151    ///
152    /// The tool's text output, or an error if execution failed.
153    fn call_tool(
154        &self,
155        name: &str,
156        args: serde_json::Value,
157    ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>>;
158}
159
160/// A [`ToolHandler`] that delegates execution to an [`McpService`].
161struct McpToolHandler {
162    service: Arc<dyn McpService>,
163    definition: ToolDefinition,
164}
165
166impl McpToolHandler {
167    fn new(service: Arc<dyn McpService>, definition: ToolDefinition) -> Self {
168        Self {
169            service,
170            definition,
171        }
172    }
173}
174
175impl std::fmt::Debug for McpToolHandler {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        f.debug_struct("McpToolHandler")
178            .field("tool", &self.definition.name)
179            .finish_non_exhaustive()
180    }
181}
182
183impl ToolHandler<()> for McpToolHandler {
184    fn definition(&self) -> ToolDefinition {
185        self.definition.clone()
186    }
187
188    fn execute<'a>(
189        &'a self,
190        input: serde_json::Value,
191        _ctx: &'a (),
192    ) -> Pin<Box<dyn Future<Output = Result<crate::tool::ToolOutput, ToolError>> + Send + 'a>> {
193        Box::pin(async move {
194            self.service
195                .call_tool(&self.definition.name, input)
196                .await
197                .map(crate::tool::ToolOutput::new)
198                .map_err(|e| ToolError::new(e.to_string()))
199        })
200    }
201}
202
203/// Extension trait for registering MCP services with a [`ToolRegistry`].
204pub trait McpRegistryExt {
205    /// Registers all tools from an MCP service.
206    ///
207    /// Discovers tools from the service and registers each one as a
208    /// `ToolHandler` that delegates execution to the service.
209    ///
210    /// # Returns
211    ///
212    /// The number of tools registered.
213    fn register_mcp_service<S: McpService + 'static>(
214        &mut self,
215        service: &Arc<S>,
216    ) -> impl Future<Output = Result<usize, McpError>> + Send;
217
218    /// Registers specific tools from an MCP service by name.
219    ///
220    /// Only registers tools whose names are in the provided list.
221    /// Tools not found on the server are silently skipped.
222    ///
223    /// # Returns
224    ///
225    /// The number of tools actually registered.
226    fn register_mcp_tools_by_name<S: McpService + 'static>(
227        &mut self,
228        service: &Arc<S>,
229        tool_names: &[&str],
230    ) -> impl Future<Output = Result<usize, McpError>> + Send;
231}
232
233impl McpRegistryExt for ToolRegistry<()> {
234    async fn register_mcp_service<S: McpService + 'static>(
235        &mut self,
236        service: &Arc<S>,
237    ) -> Result<usize, McpError> {
238        let tools = service.list_tools().await?;
239        let count = tools.len();
240
241        for definition in tools {
242            let handler =
243                McpToolHandler::new(Arc::clone(service) as Arc<dyn McpService>, definition);
244            self.register(handler);
245        }
246
247        Ok(count)
248    }
249
250    async fn register_mcp_tools_by_name<S: McpService + 'static>(
251        &mut self,
252        service: &Arc<S>,
253        tool_names: &[&str],
254    ) -> Result<usize, McpError> {
255        let tools = service.list_tools().await?;
256        let mut count = 0;
257
258        for definition in tools {
259            if tool_names.contains(&definition.name.as_str()) {
260                let handler =
261                    McpToolHandler::new(Arc::clone(service) as Arc<dyn McpService>, definition);
262                self.register(handler);
263                count += 1;
264            }
265        }
266
267        Ok(count)
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::JsonSchema;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Mock MCP service: lists a fixed set of tools and answers every call
    /// with `"Called <name>"`, counting how many calls were made.
    struct MockMcpService {
        tools: Vec<ToolDefinition>,
        call_count: AtomicUsize,
    }

    impl MockMcpService {
        fn new(tools: Vec<ToolDefinition>) -> Self {
            Self {
                tools,
                call_count: AtomicUsize::new(0),
            }
        }
    }

    impl McpService for MockMcpService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            let tools = self.tools.clone();
            Box::pin(async move { Ok(tools) })
        }

        fn call_tool(
            &self,
            name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let result = format!("Called {name}");
            Box::pin(async move { Ok(result) })
        }
    }

    /// Mock MCP service used for error paths: every tool call fails, and
    /// listing fails too when `fail_listing` is set.
    struct FailingMcpService {
        tools: Vec<ToolDefinition>,
        fail_listing: bool,
    }

    impl McpService for FailingMcpService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            let result = if self.fail_listing {
                Err(McpError::Protocol("connection closed".to_string()))
            } else {
                Ok(self.tools.clone())
            };
            Box::pin(async move { result })
        }

        fn call_tool(
            &self,
            name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            let message = format!("{name} failed");
            Box::pin(async move { Err(McpError::ToolExecution(message)) })
        }
    }

    fn test_tool(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("{name} description"),
            parameters: JsonSchema::new(serde_json::json!({"type": "object"})),
            retry: None,
        }
    }

    #[test]
    fn test_trait_is_object_safe() {
        fn assert_object_safe(_: &dyn McpService) {}
        let mock = MockMcpService::new(vec![]);
        assert_object_safe(&mock);
    }

    #[tokio::test]
    async fn test_register_mcp_service() {
        let service = Arc::new(MockMcpService::new(vec![
            test_tool("tool_a"),
            test_tool("tool_b"),
        ]));

        let mut registry = ToolRegistry::new();
        let count = registry.register_mcp_service(&service).await.unwrap();

        assert_eq!(count, 2);
        assert_eq!(registry.len(), 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_some());
    }

    #[tokio::test]
    async fn test_register_mcp_tools_by_name() {
        let service = Arc::new(MockMcpService::new(vec![
            test_tool("tool_a"),
            test_tool("tool_b"),
            test_tool("tool_c"),
        ]));

        let mut registry = ToolRegistry::new();
        let count = registry
            .register_mcp_tools_by_name(&service, &["tool_a", "tool_c"])
            .await
            .unwrap();

        assert_eq!(count, 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_none());
        assert!(registry.get("tool_c").is_some());
    }

    /// Names the server does not list are skipped without error and do not count.
    #[tokio::test]
    async fn test_register_mcp_tools_by_name_skips_unknown_names() {
        let service = Arc::new(MockMcpService::new(vec![test_tool("tool_a")]));

        let mut registry = ToolRegistry::new();
        let count = registry
            .register_mcp_tools_by_name(&service, &["tool_a", "missing"])
            .await
            .unwrap();

        assert_eq!(count, 1);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("missing").is_none());
    }

    /// A failure to list tools is returned to the caller and registers nothing.
    #[tokio::test]
    async fn test_list_tools_error_propagates() {
        let service = Arc::new(FailingMcpService {
            tools: vec![test_tool("tool_a")],
            fail_listing: true,
        });

        let mut registry = ToolRegistry::new();
        let err = registry.register_mcp_service(&service).await.unwrap_err();

        assert!(matches!(err, McpError::Protocol(_)));
        assert_eq!(registry.len(), 0);
    }

    #[tokio::test]
    async fn test_mcp_tool_execution() {
        let service = Arc::new(MockMcpService::new(vec![test_tool("my_tool")]));

        let mut registry = ToolRegistry::new();
        registry.register_mcp_service(&service).await.unwrap();

        let handler = registry.get("my_tool").unwrap();
        let result = handler.execute(serde_json::json!({}), &()).await.unwrap();

        assert_eq!(result.content, "Called my_tool");
        assert_eq!(service.call_count.load(Ordering::SeqCst), 1);
    }

    /// An `McpError` from `call_tool` surfaces as a `ToolError` from the handler.
    #[tokio::test]
    async fn test_mcp_tool_execution_error_becomes_tool_error() {
        let service = Arc::new(FailingMcpService {
            tools: vec![test_tool("my_tool")],
            fail_listing: false,
        });

        let mut registry = ToolRegistry::new();
        registry.register_mcp_service(&service).await.unwrap();

        let handler = registry.get("my_tool").unwrap();
        let result = handler.execute(serde_json::json!({}), &()).await;

        assert!(result.is_err());
    }
}