llm_stack/mcp.rs
1//! MCP (Model Context Protocol) integration for `ToolRegistry`.
2//!
3//! This module defines the [`McpService`] trait - a minimal contract for MCP-like
4//! tool sources. Users implement this trait for their chosen MCP client library
5//! (e.g., `rmcp`), keeping version coupling in their code rather than in llm-core.
6//!
7//! # Design Philosophy
8//!
9//! Rather than depending on a specific MCP client library, llm-core defines a
10//! simple trait that any MCP implementation can satisfy. This means:
11//!
12//! - No forced dependency upgrades when MCP libraries release new versions
13//! - Users choose their preferred MCP client and version
14//! - Easy to mock for testing
15//! - Simple enough to implement (~50 lines for rmcp)
16//!
17//! # Example: Implementing for rmcp
18//!
//! ```rust,ignore
//! use std::future::Future;
//! use std::pin::Pin;
//! use std::sync::Arc;
//!
//! use rmcp::service::RunningService;
//! use rmcp::handler::client::ClientHandler;
//! use rmcp::RoleClient;
//! use llm_stack::{McpService, McpError, ToolDefinition, JsonSchema};
//!
//! /// Adapter: wraps rmcp RunningService to implement llm-core's McpService
//! pub struct RmcpAdapter<S: ClientHandler> {
//!     service: Arc<RunningService<RoleClient, S>>,
//! }
//!
//! impl<S: ClientHandler> RmcpAdapter<S> {
//!     pub fn new(service: Arc<RunningService<RoleClient, S>>) -> Self {
//!         Self { service }
//!     }
//! }
//!
//! impl<S: ClientHandler> McpService for RmcpAdapter<S> {
//!     // The trait methods return boxed futures (this keeps the trait
//!     // object-safe), so each body is wrapped in `Box::pin(async move { .. })`.
//!     fn list_tools(
//!         &self,
//!     ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>> {
//!         Box::pin(async move {
//!             let tools = self.service
//!                 .list_all_tools()
//!                 .await
//!                 .map_err(|e| McpError::Protocol(e.to_string()))?;
//!
//!             Ok(tools.into_iter().map(|t| ToolDefinition {
//!                 name: t.name.to_string(),
//!                 description: t.description.map(|d| d.to_string()).unwrap_or_default(),
//!                 parameters: JsonSchema::new(
//!                     serde_json::to_value(&*t.input_schema).unwrap_or_default()
//!                 ),
//!                 retry: None,
//!             }).collect())
//!         })
//!     }
//!
//!     fn call_tool(
//!         &self,
//!         name: &str,
//!         args: serde_json::Value,
//!     ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
//!         use rmcp::model::CallToolRequestParams;
//!
//!         // Build the request before creating the future: the returned future
//!         // may borrow `self`, but not `name`.
//!         let params = CallToolRequestParams {
//!             meta: None,
//!             name: name.to_string().into(),
//!             arguments: args.as_object().cloned(),
//!             task: None,
//!         };
//!
//!         Box::pin(async move {
//!             let result = self.service
//!                 .call_tool(params)
//!                 .await
//!                 .map_err(|e| McpError::ToolExecution(e.to_string()))?;
//!
//!             if result.is_error.unwrap_or(false) {
//!                 return Err(McpError::ToolExecution(extract_text(&result.content)));
//!             }
//!
//!             Ok(extract_text(&result.content))
//!         })
//!     }
//! }
//!
//! fn extract_text(content: &[rmcp::model::Content]) -> String {
//!     use rmcp::model::RawContent;
//!     content.iter().map(|c| match &c.raw {
//!         RawContent::Text(t) => t.text.clone(),
//!         _ => "[non-text]".into(),
//!     }).collect::<Vec<_>>().join("\n")
//! }
//! ```
84//!
85//! # Usage
86//!
87//! ```rust,ignore
88//! use std::sync::Arc;
89//! use llm_stack::{ToolRegistry, McpRegistryExt};
90//!
91//! // Create your MCP service (using rmcp or any other library)
92//! let mcp_service = Arc::new(RmcpAdapter::new(rmcp_client));
93//!
94//! // Register tools with llm-core
95//! let mut registry = ToolRegistry::new();
96//! registry.register_mcp_service(&mcp_service).await?;
97//! ```
98
99use std::future::Future;
100use std::pin::Pin;
101use std::sync::Arc;
102
103use crate::tool::{ToolError, ToolHandler};
104use crate::{ToolDefinition, ToolRegistry};
105
106/// Error type for MCP operations.
107#[derive(Debug, Clone, thiserror::Error)]
108pub enum McpError {
109 /// Error during MCP protocol communication.
110 #[error("MCP protocol error: {0}")]
111 Protocol(String),
112
113 /// Error during tool execution.
114 #[error("MCP tool execution error: {0}")]
115 ToolExecution(String),
116}
117
118/// Minimal contract for MCP-like tool sources.
119///
120/// Implement this trait to bridge any MCP client library with llm-core's
121/// tool system. The trait requires only two operations:
122///
123/// 1. List available tools (with their schemas)
124/// 2. Call a tool by name with JSON arguments
125///
126/// # Thread Safety
127///
128/// Implementations must be `Send + Sync` to allow use across async tasks.
129/// Typically achieved by wrapping the underlying client in `Arc`.
130///
131/// # Object Safety
132///
133/// This trait is object-safe (`dyn McpService`) to allow storing different
134/// MCP service implementations in the same registry.
135pub trait McpService: Send + Sync {
136 /// Lists all tools available from the MCP server.
137 fn list_tools(
138 &self,
139 ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>;
140
141 /// Calls a tool on the MCP server.
142 ///
143 /// # Arguments
144 ///
145 /// * `name` - The name of the tool to call
146 /// * `args` - JSON arguments matching the tool's input schema
147 ///
148 /// # Returns
149 ///
150 /// The tool's text output, or an error if execution failed.
151 fn call_tool(
152 &self,
153 name: &str,
154 args: serde_json::Value,
155 ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>>;
156}
157
158/// A [`ToolHandler`] that delegates execution to an [`McpService`].
159struct McpToolHandler {
160 service: Arc<dyn McpService>,
161 definition: ToolDefinition,
162}
163
164impl McpToolHandler {
165 fn new(service: Arc<dyn McpService>, definition: ToolDefinition) -> Self {
166 Self {
167 service,
168 definition,
169 }
170 }
171}
172
173impl std::fmt::Debug for McpToolHandler {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("McpToolHandler")
176 .field("tool", &self.definition.name)
177 .finish_non_exhaustive()
178 }
179}
180
181impl ToolHandler<()> for McpToolHandler {
182 fn definition(&self) -> ToolDefinition {
183 self.definition.clone()
184 }
185
186 fn execute<'a>(
187 &'a self,
188 input: serde_json::Value,
189 _ctx: &'a (),
190 ) -> Pin<Box<dyn Future<Output = Result<crate::tool::ToolOutput, ToolError>> + Send + 'a>> {
191 Box::pin(async move {
192 self.service
193 .call_tool(&self.definition.name, input)
194 .await
195 .map(crate::tool::ToolOutput::new)
196 .map_err(|e| ToolError::new(e.to_string()))
197 })
198 }
199}
200
201/// Extension trait for registering MCP services with a [`ToolRegistry`].
202pub trait McpRegistryExt {
203 /// Registers all tools from an MCP service.
204 ///
205 /// Discovers tools from the service and registers each one as a
206 /// `ToolHandler` that delegates execution to the service.
207 ///
208 /// # Returns
209 ///
210 /// The number of tools registered.
211 fn register_mcp_service<S: McpService + 'static>(
212 &mut self,
213 service: &Arc<S>,
214 ) -> impl Future<Output = Result<usize, McpError>> + Send;
215
216 /// Registers specific tools from an MCP service by name.
217 ///
218 /// Only registers tools whose names are in the provided list.
219 /// Tools not found on the server are silently skipped.
220 ///
221 /// # Returns
222 ///
223 /// The number of tools actually registered.
224 fn register_mcp_tools_by_name<S: McpService + 'static>(
225 &mut self,
226 service: &Arc<S>,
227 tool_names: &[&str],
228 ) -> impl Future<Output = Result<usize, McpError>> + Send;
229}
230
231impl McpRegistryExt for ToolRegistry<()> {
232 async fn register_mcp_service<S: McpService + 'static>(
233 &mut self,
234 service: &Arc<S>,
235 ) -> Result<usize, McpError> {
236 let tools = service.list_tools().await?;
237 let count = tools.len();
238
239 for definition in tools {
240 let handler =
241 McpToolHandler::new(Arc::clone(service) as Arc<dyn McpService>, definition);
242 self.register(handler);
243 }
244
245 Ok(count)
246 }
247
248 async fn register_mcp_tools_by_name<S: McpService + 'static>(
249 &mut self,
250 service: &Arc<S>,
251 tool_names: &[&str],
252 ) -> Result<usize, McpError> {
253 let tools = service.list_tools().await?;
254 let mut count = 0;
255
256 for definition in tools {
257 if tool_names.contains(&definition.name.as_str()) {
258 let handler =
259 McpToolHandler::new(Arc::clone(service) as Arc<dyn McpService>, definition);
260 self.register(handler);
261 count += 1;
262 }
263 }
264
265 Ok(count)
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::JsonSchema;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// In-memory MCP service: serves a fixed tool list and counts how often
    /// `call_tool` is invoked.
    struct MockMcpService {
        tools: Vec<ToolDefinition>,
        call_count: AtomicUsize,
    }

    impl MockMcpService {
        fn new(tools: Vec<ToolDefinition>) -> Self {
            Self {
                tools,
                call_count: AtomicUsize::new(0),
            }
        }
    }

    impl McpService for MockMcpService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            let tools = self.tools.clone();
            Box::pin(async move { Ok(tools) })
        }

        fn call_tool(
            &self,
            name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            // Built eagerly: the returned future must not borrow `name`.
            let result = format!("Called {name}");
            Box::pin(async move { Ok(result) })
        }
    }

    /// MCP service used to exercise the error paths.
    struct BrokenMcpService {
        /// `None` makes `list_tools` fail; `Some` lists the given tools.
        /// `call_tool` fails in both cases.
        tools: Option<Vec<ToolDefinition>>,
    }

    impl McpService for BrokenMcpService {
        fn list_tools(
            &self,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<ToolDefinition>, McpError>> + Send + '_>>
        {
            let listing = self
                .tools
                .clone()
                .ok_or_else(|| McpError::Protocol("server unavailable".to_string()));
            Box::pin(async move { listing })
        }

        fn call_tool(
            &self,
            name: &str,
            _args: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<String, McpError>> + Send + '_>> {
            let failure = McpError::ToolExecution(format!("{name} failed"));
            Box::pin(async move { Err(failure) })
        }
    }

    /// Builds a minimal tool definition with an object schema and no retry policy.
    fn test_tool(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("{name} description"),
            parameters: JsonSchema::new(serde_json::json!({"type": "object"})),
            retry: None,
        }
    }

    #[test]
    fn test_trait_is_object_safe() {
        fn assert_object_safe(_: &dyn McpService) {}
        let mock = MockMcpService::new(vec![]);
        assert_object_safe(&mock);
    }

    #[tokio::test]
    async fn test_register_mcp_service() {
        let service = Arc::new(MockMcpService::new(vec![
            test_tool("tool_a"),
            test_tool("tool_b"),
        ]));

        let mut registry = ToolRegistry::new();
        let count = registry.register_mcp_service(&service).await.unwrap();

        assert_eq!(count, 2);
        assert_eq!(registry.len(), 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_some());
    }

    #[tokio::test]
    async fn test_register_mcp_tools_by_name() {
        let service = Arc::new(MockMcpService::new(vec![
            test_tool("tool_a"),
            test_tool("tool_b"),
            test_tool("tool_c"),
        ]));

        let mut registry = ToolRegistry::new();
        let count = registry
            .register_mcp_tools_by_name(&service, &["tool_a", "tool_c"])
            .await
            .unwrap();

        assert_eq!(count, 2);
        assert_eq!(registry.len(), 2);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("tool_b").is_none());
        assert!(registry.get("tool_c").is_some());
    }

    /// Requested names the server does not offer are skipped, not reported as errors.
    #[tokio::test]
    async fn test_register_mcp_tools_by_name_skips_unknown() {
        let service = Arc::new(MockMcpService::new(vec![test_tool("tool_a")]));

        let mut registry = ToolRegistry::new();
        let count = registry
            .register_mcp_tools_by_name(&service, &["tool_a", "missing"])
            .await
            .unwrap();

        assert_eq!(count, 1);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("tool_a").is_some());
        assert!(registry.get("missing").is_none());
    }

    /// A failing `list_tools` surfaces its error and registers nothing.
    #[tokio::test]
    async fn test_register_propagates_list_error() {
        let service = Arc::new(BrokenMcpService { tools: None });

        let mut registry = ToolRegistry::new();
        let err = registry.register_mcp_service(&service).await.unwrap_err();

        assert!(matches!(err, McpError::Protocol(_)));
        assert_eq!(registry.len(), 0);
    }

    #[tokio::test]
    async fn test_mcp_tool_execution() {
        let service = Arc::new(MockMcpService::new(vec![test_tool("my_tool")]));

        let mut registry = ToolRegistry::new();
        registry.register_mcp_service(&service).await.unwrap();

        let handler = registry.get("my_tool").unwrap();
        let result = handler.execute(serde_json::json!({}), &()).await.unwrap();

        assert_eq!(result.content, "Called my_tool");
        assert_eq!(service.call_count.load(Ordering::SeqCst), 1);
    }

    /// A failing `call_tool` is reported to the caller as a tool error.
    #[tokio::test]
    async fn test_mcp_tool_execution_error() {
        let service = Arc::new(BrokenMcpService {
            tools: Some(vec![test_tool("flaky")]),
        });

        let mut registry = ToolRegistry::new();
        registry.register_mcp_service(&service).await.unwrap();

        let handler = registry.get("flaky").unwrap();
        let result = handler.execute(serde_json::json!({}), &()).await;

        assert!(result.is_err());
    }
}
375}