llm_stack/mcp.rs
1//! MCP (Model Context Protocol) integration for `ToolRegistry`.
2//!
3//! This module defines the [`McpService`] trait - a minimal contract for MCP-like
4//! tool sources. Users implement this trait for their chosen MCP client library
5//! (e.g., `rmcp`), keeping version coupling in their code rather than in llm-core.
6//!
7//! # Design Philosophy
8//!
9//! Rather than depending on a specific MCP client library, llm-core defines a
10//! simple trait that any MCP implementation can satisfy. This means:
11//!
12//! - No forced dependency upgrades when MCP libraries release new versions
13//! - Users choose their preferred MCP client and version
14//! - Easy to mock for testing
15//! - Simple enough to implement (~50 lines for rmcp)
16//!
17//! # Example: Implementing for rmcp
18//!
//! ```rust,ignore
//! use std::future::Future;
//! use std::pin::Pin;
//! use std::sync::Arc;
//! use rmcp::service::RunningService;
//! use rmcp::handler::client::ClientHandler;
//! use rmcp::RoleClient;
//! use llm_stack::{ToolDefinition, JsonSchema};
//! use llm_stack::mcp::{McpService, McpError};
//!
//! /// Adapter: wraps rmcp RunningService to implement llm-core's McpService
//! pub struct RmcpAdapter<S: ClientHandler> {
//!     service: Arc<RunningService<RoleClient, S>>,
//! }
//!
//! impl<S: ClientHandler> RmcpAdapter<S> {
//!     pub fn new(service: Arc<RunningService<RoleClient, S>>) -> Self {
//!         Self { service }
//!     }
//! }
//!
//! // `McpService` returns boxed futures (to stay object-safe), so each method
//! // wraps its async body in `Box::pin` instead of being an `async fn`.
//! impl<S: ClientHandler> McpService for RmcpAdapter<S> {
//!     fn list_tools(
//!         &self,
//!     ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>> {
//!         Box::pin(async move {
//!             let tools = self
//!                 .service
//!                 .list_all_tools()
//!                 .await
//!                 .map_err(|e| McpError::Protocol(e.to_string()))?;
//!
//!             Ok(tools
//!                 .into_iter()
//!                 .map(|t| ToolDefinition {
//!                     name: t.name.to_string(),
//!                     description: t.description.map(|d| d.to_string()).unwrap_or_default(),
//!                     parameters: JsonSchema::new(
//!                         serde_json::to_value(&*t.input_schema).unwrap_or_default(),
//!                     ),
//!                     retry: None,
//!                 })
//!                 .collect())
//!         })
//!     }
//!
//!     fn call_tool(
//!         &self,
//!         name: &str,
//!         args: serde_json::Value,
//!     ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
//!         use rmcp::model::CallToolRequestParams;
//!
//!         // Build the request before boxing: the returned future may only
//!         // borrow `self`, not `name`.
//!         let params = CallToolRequestParams {
//!             meta: None,
//!             name: name.to_string().into(),
//!             arguments: args.as_object().cloned(),
//!             task: None,
//!         };
//!
//!         Box::pin(async move {
//!             let result = self
//!                 .service
//!                 .call_tool(params)
//!                 .await
//!                 .map_err(|e| McpError::ToolExecution(e.to_string()))?;
//!
//!             if result.is_error.unwrap_or(false) {
//!                 return Err(McpError::ToolExecution(extract_text(&result.content)));
//!             }
//!
//!             Ok(extract_text(&result.content))
//!         })
//!     }
//! }
//!
//! fn extract_text(content: &[rmcp::model::Content]) -> String {
//!     use rmcp::model::RawContent;
//!     content
//!         .iter()
//!         .map(|c| match &c.raw {
//!             RawContent::Text(t) => t.text.clone(),
//!             _ => "[non-text]".into(),
//!         })
//!         .collect::<Vec<_>>()
//!         .join("\n")
//! }
//! ```
85//!
86//! # Usage
87//!
88//! ```rust,ignore
89//! use std::sync::Arc;
90//! use llm_stack::ToolRegistry;
91//! use llm_stack::mcp::McpRegistryExt;
92//!
93//! // Create your MCP service (using rmcp or any other library)
94//! let mcp_service = Arc::new(RmcpAdapter::new(rmcp_client));
95//!
96//! // Register tools with llm-core
97//! let mut registry = ToolRegistry::new();
98//! registry.register_mcp_service(&mcp_service).await?;
99//! ```
100
101use std::future::Future;
102use std::pin::Pin;
103use std::sync::Arc;
104
105use crate::tool::{ToolError, ToolHandler};
106use crate::{ToolDefinition, ToolRegistry};
107
108/// Error type for MCP operations.
109#[derive(Debug, Clone, thiserror::Error)]
110pub enum McpError {
111 /// Error during MCP protocol communication.
112 #[error("MCP protocol error: {0}")]
113 Protocol(String),
114
115 /// Error during tool execution.
116 #[error("MCP tool execution error: {0}")]
117 ToolExecution(String),
118}
119
120/// Minimal contract for MCP-like tool sources.
121///
122/// Implement this trait to bridge any MCP client library with llm-core's
123/// tool system. The trait requires only two operations:
124///
125/// 1. List available tools (with their schemas)
126/// 2. Call a tool by name with JSON arguments
127///
128/// # Thread Safety
129///
130/// Implementations must be `Send + Sync` to allow use across async tasks.
131/// Typically achieved by wrapping the underlying client in `Arc`.
132///
133/// # Object Safety
134///
135/// This trait is object-safe (`dyn McpService`) to allow storing different
136/// MCP service implementations in the same registry.
137pub trait McpService: Send + Sync {
138 /// Lists all tools available from the MCP server.
139 fn list_tools(
140 &self,
141 ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>;
142
143 /// Calls a tool on the MCP server.
144 ///
145 /// # Arguments
146 ///
147 /// * `name` - The name of the tool to call
148 /// * `args` - JSON arguments matching the tool's input schema
149 ///
150 /// # Returns
151 ///
152 /// The tool's text output, or an error if execution failed.
153 fn call_tool(
154 &self,
155 name: &str,
156 args: serde_json::Value,
157 ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>>;
158}
159
160/// A [`ToolHandler`] that delegates execution to an [`McpService`].
161struct McpToolHandler {
162 service: Arc<dyn McpService>,
163 definition: ToolDefinition,
164}
165
166impl McpToolHandler {
167 fn new(service: Arc<dyn McpService>, definition: ToolDefinition) -> Self {
168 Self {
169 service,
170 definition,
171 }
172 }
173}
174
175impl std::fmt::Debug for McpToolHandler {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 f.debug_struct("McpToolHandler")
178 .field("tool", &self.definition.name)
179 .finish_non_exhaustive()
180 }
181}
182
183impl ToolHandler<()> for McpToolHandler {
184 fn definition(&self) -> ToolDefinition {
185 self.definition.clone()
186 }
187
188 fn execute<'a>(
189 &'a self,
190 input: serde_json::Value,
191 _ctx: &'a (),
192 ) -> Pin<Box<dyn Future<Output = Result<crate::tool::ToolOutput, ToolError>> + Send + 'a>> {
193 Box::pin(async move {
194 self.service
195 .call_tool(&self.definition.name, input)
196 .await
197 .map(crate::tool::ToolOutput::new)
198 .map_err(|e| ToolError::new(e.to_string()))
199 })
200 }
201}
202
203/// Extension trait for registering MCP services with a [`ToolRegistry`].
204pub trait McpRegistryExt {
205 /// Registers all tools from an MCP service.
206 ///
207 /// Discovers tools from the service and registers each one as a
208 /// `ToolHandler` that delegates execution to the service.
209 ///
210 /// # Returns
211 ///
212 /// The number of tools registered.
213 fn register_mcp_service<S: McpService + 'static>(
214 &mut self,
215 service: &Arc<S>,
216 ) -> impl Future<Output = Result<usize, McpError>> + Send;
217
218 /// Registers specific tools from an MCP service by name.
219 ///
220 /// Only registers tools whose names are in the provided list.
221 /// Tools not found on the server are silently skipped.
222 ///
223 /// # Returns
224 ///
225 /// The number of tools actually registered.
226 fn register_mcp_tools_by_name<S: McpService + 'static>(
227 &mut self,
228 service: &Arc<S>,
229 tool_names: &[&str],
230 ) -> impl Future<Output = Result<usize, McpError>> + Send;
231}
232
233impl McpRegistryExt for ToolRegistry<()> {
234 async fn register_mcp_service<S: McpService + 'static>(
235 &mut self,
236 service: &Arc<S>,
237 ) -> Result<usize, McpError> {
238 let tools = service.list_tools().await?;
239 let count = tools.len();
240
241 for definition in tools {
242 let handler =
243 McpToolHandler::new(Arc::clone(service) as Arc<dyn McpService>, definition);
244 self.register(handler);
245 }
246
247 Ok(count)
248 }
249
250 async fn register_mcp_tools_by_name<S: McpService + 'static>(
251 &mut self,
252 service: &Arc<S>,
253 tool_names: &[&str],
254 ) -> Result<usize, McpError> {
255 let tools = service.list_tools().await?;
256 let mut count = 0;
257
258 for definition in tools {
259 if tool_names.contains(&definition.name.as_str()) {
260 let handler =
261 McpToolHandler::new(Arc::clone(service) as Arc<dyn McpService>, definition);
262 self.register(handler);
263 count += 1;
264 }
265 }
266
267 Ok(count)
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::JsonSchema;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Mock MCP service that lists a fixed set of tools and answers every
    /// call with `"Called <name>"`, counting invocations.
    struct MockMcpService {
        tools: Vec<ToolDefinition>,
        call_count: AtomicUsize,
    }

    impl MockMcpService {
        fn new(tools: Vec<ToolDefinition>) -> Self {
            Self {
                tools,
                call_count: AtomicUsize::new(0),
            }
        }
    }

    impl McpService for MockMcpService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            let tools = self.tools.clone();
            Box::pin(async move { Ok(tools) })
        }

        fn call_tool(
            &self,
            name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let result = format!("Called {name}");
            Box::pin(async move { Ok(result) })
        }
    }

    /// Mock MCP service whose tool calls always fail. Discovery fails as well
    /// when `list_fails` is set.
    struct FailingMcpService {
        tools: Vec<ToolDefinition>,
        list_fails: bool,
    }

    impl McpService for FailingMcpService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            let result = if self.list_fails {
                Err(McpError::Protocol("connection closed".to_string()))
            } else {
                Ok(self.tools.clone())
            };
            Box::pin(async move { result })
        }

        fn call_tool(
            &self,
            _name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            Box::pin(async { Err(McpError::ToolExecution("tool crashed".to_string())) })
        }
    }

    fn test_tool(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("{name} description"),
            parameters: JsonSchema::new(serde_json::json!({"type": "object"})),
            retry: None,
        }
    }

    #[test]
    fn test_trait_is_object_safe() {
        fn assert_object_safe(_: &dyn McpService) {}
        let mock = MockMcpService::new(vec![]);
        assert_object_safe(&mock);
    }

    #[tokio::test]
    async fn test_register_mcp_service() {
        let service = Arc::new(MockMcpService::new(vec![
            test_tool("tool_a"),
            test_tool("tool_b"),
        ]));

        let mut registry = ToolRegistry::new();
        let count = registry.register_mcp_service(&service).await.unwrap();

        assert_eq!(count, 2);
        assert_eq!(registry.len(), 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_some());
    }

    #[tokio::test]
    async fn test_register_mcp_tools_by_name() {
        let service = Arc::new(MockMcpService::new(vec![
            test_tool("tool_a"),
            test_tool("tool_b"),
            test_tool("tool_c"),
        ]));

        let mut registry = ToolRegistry::new();
        let count = registry
            .register_mcp_tools_by_name(&service, &["tool_a", "tool_c"])
            .await
            .unwrap();

        assert_eq!(count, 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_none());
        assert!(registry.get("tool_c").is_some());
    }

    #[tokio::test]
    async fn test_register_mcp_tools_by_name_skips_unknown_names() {
        let service = Arc::new(MockMcpService::new(vec![test_tool("tool_a")]));

        let mut registry = ToolRegistry::new();
        let count = registry
            .register_mcp_tools_by_name(&service, &["tool_a", "missing"])
            .await
            .unwrap();

        // Names the server does not list are skipped, not reported as errors.
        assert_eq!(count, 1);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("missing").is_none());
    }

    #[tokio::test]
    async fn test_register_propagates_discovery_error() {
        let service = Arc::new(FailingMcpService {
            tools: vec![test_tool("tool_a")],
            list_fails: true,
        });

        let mut registry = ToolRegistry::new();
        let err = registry.register_mcp_service(&service).await.unwrap_err();

        assert!(matches!(err, McpError::Protocol(_)));
        assert_eq!(registry.len(), 0);
    }

    #[tokio::test]
    async fn test_mcp_tool_execution() {
        let service = Arc::new(MockMcpService::new(vec![test_tool("my_tool")]));

        let mut registry = ToolRegistry::new();
        registry.register_mcp_service(&service).await.unwrap();

        let handler = registry.get("my_tool").unwrap();
        let result = handler.execute(serde_json::json!({}), &()).await.unwrap();

        assert_eq!(result.content, "Called my_tool");
        assert_eq!(service.call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_mcp_tool_execution_error_is_reported() {
        let service = Arc::new(FailingMcpService {
            tools: vec![test_tool("broken")],
            list_fails: false,
        });

        let mut registry = ToolRegistry::new();
        registry.register_mcp_service(&service).await.unwrap();

        let handler = registry.get("broken").unwrap();
        let result = handler.execute(serde_json::json!({}), &()).await;

        assert!(result.is_err());
    }
}