Skip to main content

mcp/
server.rs

//! Unified MCP Server
//!
//! This module provides a concrete MCP server implementation that aggregates
//! multiple tools and implements the `ToolProtocol` trait, routing each tool call
//! to the appropriate underlying tool implementation.
//!
//! The server acts as a dispatcher that can be deployed as an HTTP service,
//! allowing multiple agents (local or remote) to access a unified set of tools
//! through a single `ToolProtocol` interface.
//!
//! # Architecture
//!
//! ```text
//! Multiple Tools (Memory, Bash, etc.)
//!         ↓
//! UnifiedMcpServer (implements ToolProtocol)
//!         ↓
//! HTTP Endpoints (GET /tools, POST /execute)
//!         ↓
//! Agents/Clients (via McpClientProtocol)
//! ```
//!
//! # Example
//!
//! ```ignore
//! use async_trait::async_trait;
//! use mcp::{ToolMetadata, ToolProtocol, ToolResult};
//! use mcp::UnifiedMcpServer;
//! use std::sync::Arc;
//!
//! struct MemoryProtocol;
//!
//! #[async_trait]
//! impl ToolProtocol for MemoryProtocol {
//!     async fn execute(
//!         &self,
//!         _tool_name: &str,
//!         _parameters: serde_json::Value,
//!     ) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
//!         Ok(ToolResult::success(serde_json::json!({"ok": true})))
//!     }
//!
//!     async fn list_tools(
//!         &self,
//!     ) -> Result<Vec<ToolMetadata>, Box<dyn std::error::Error + Send + Sync>> {
//!         Ok(vec![])
//!     }
//! }
//!
//! # async {
//! let memory_protocol = Arc::new(MemoryProtocol);
//!
//! let mut server = UnifiedMcpServer::new();
//! server.register_tool("memory", memory_protocol).await;
//!
//! // The server itself implements ToolProtocol and routes calls by tool name.
//! let tools = server.list_tools().await.unwrap();
//! # };
//! ```
60
61use crate::protocol::{ToolError, ToolMetadata, ToolProtocol, ToolResult};
62use async_trait::async_trait;
63use std::collections::HashMap;
64use std::error::Error;
65use std::sync::Arc;
66use tokio::sync::RwLock;
67
68/// A unified MCP server that aggregates multiple tools
69///
70/// The UnifiedMcpServer implements the ToolProtocol trait and routes
71/// tool execution requests to the appropriate underlying tool protocol
72/// implementation based on the tool name.
73///
74/// This allows a single server instance to expose multiple tools with
75/// different implementations, making it suitable for deployment as an
76/// MCP HTTP service that can be accessed by multiple agents.
77///
78/// # Thread Safety
79///
80/// The server is thread-safe and can be shared across multiple concurrent
81/// tool executions using `Arc<UnifiedMcpServer>`.
82#[derive(Clone)]
83pub struct UnifiedMcpServer {
84    /// Map of tool name to its ToolProtocol implementation
85    tools: Arc<RwLock<HashMap<String, Arc<dyn ToolProtocol>>>>,
86}
87
88impl UnifiedMcpServer {
89    /// Create a new empty unified MCP server
90    pub fn new() -> Self {
91        Self {
92            tools: Arc::new(RwLock::new(HashMap::new())),
93        }
94    }
95
96    /// Register a tool with the server
97    ///
98    /// # Arguments
99    ///
100    /// * `tool_name` - The identifier for the tool (e.g., "memory", "bash")
101    /// * `protocol` - The ToolProtocol implementation for this tool
102    ///
103    /// # Example
104    ///
105    /// ```ignore
106    /// use async_trait::async_trait;
107    /// use mcp::{ToolMetadata, ToolProtocol, ToolResult, UnifiedMcpServer};
108    /// use std::sync::Arc;
109    ///
110    /// struct MemoryProtocol;
111    ///
112    /// #[async_trait]
113    /// impl ToolProtocol for MemoryProtocol {
114    ///     async fn execute(
115    ///         &self,
116    ///         _tool_name: &str,
117    ///         _parameters: serde_json::Value,
118    ///     ) -> Result<ToolResult, Box<dyn std::error::Error + Send + Sync>> {
119    ///         Ok(ToolResult::success(serde_json::json!({"ok": true})))
120    ///     }
121    ///
122    ///     async fn list_tools(
123    ///         &self,
124    ///     ) -> Result<Vec<ToolMetadata>, Box<dyn std::error::Error + Send + Sync>> {
125    ///         Ok(vec![])
126    ///     }
127    /// }
128    ///
129    /// # #[tokio::main]
130    /// # async fn main() {
131    /// let memory_protocol = Arc::new(MemoryProtocol);
132    ///
133    /// let mut server = UnifiedMcpServer::new();
134    /// server.register_tool("memory", memory_protocol).await;
135    /// # }
136    /// ```
137    pub async fn register_tool(&mut self, tool_name: &str, protocol: Arc<dyn ToolProtocol>) {
138        let mut tools = self.tools.write().await;
139        tools.insert(tool_name.to_string(), protocol);
140    }
141
142    /// Unregister a tool from the server
143    pub async fn unregister_tool(&mut self, tool_name: &str) {
144        let mut tools = self.tools.write().await;
145        tools.remove(tool_name);
146    }
147
148    /// Check if a tool is registered
149    pub async fn has_tool(&self, tool_name: &str) -> bool {
150        let tools = self.tools.read().await;
151        tools.contains_key(tool_name)
152    }
153
154    /// Get the number of registered tools
155    pub async fn tool_count(&self) -> usize {
156        let tools = self.tools.read().await;
157        tools.len()
158    }
159}
160
161impl Default for UnifiedMcpServer {
162    fn default() -> Self {
163        Self::new()
164    }
165}
166
167#[async_trait]
168impl ToolProtocol for UnifiedMcpServer {
169    /// Execute a tool by routing to the appropriate protocol
170    ///
171    /// # Routing Logic
172    ///
173    /// 1. Look up the tool name in the registry
174    /// 2. If found, delegate to that tool's protocol
175    /// 3. If not found, return NotFound error
176    async fn execute(
177        &self,
178        tool_name: &str,
179        parameters: serde_json::Value,
180    ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
181        let tools = self.tools.read().await;
182
183        let protocol = tools.get(tool_name).cloned().ok_or_else(|| {
184            Box::new(ToolError::NotFound(tool_name.to_string())) as Box<dyn Error + Send + Sync>
185        })?;
186
187        // Drop the read lock before executing to allow concurrent access
188        drop(tools);
189
190        // Route to the appropriate tool's protocol
191        protocol.execute(tool_name, parameters).await
192    }
193
194    /// List all available tools across all registered protocols
195    ///
196    /// This aggregates tool metadata from all registered tool protocols.
197    /// Each protocol is queried at most once even if multiple tool names
198    /// are registered to the same protocol instance.
199    async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
200        let tools = self.tools.read().await;
201
202        // Deduplicate protocol instances by pointer so each protocol's list_tools()
203        // is called at most once (multiple tool names may point to the same protocol).
204        let mut seen: std::collections::HashSet<usize> = std::collections::HashSet::new();
205        let protocols: Vec<Arc<dyn ToolProtocol>> = tools
206            .values()
207            .filter(|p| seen.insert(Arc::as_ptr(*p) as *const () as usize))
208            .cloned()
209            .collect();
210
211        // Drop the read lock before making async calls
212        drop(tools);
213
214        let mut all_tools = Vec::new();
215
216        for protocol in protocols {
217            match protocol.list_tools().await {
218                Ok(mut tool_list) => all_tools.append(&mut tool_list),
219                Err(e) => {
220                    // Log but continue - we want to return what we can
221                    eprintln!("Error listing tools from protocol: {}", e);
222                }
223            }
224        }
225
226        Ok(all_tools)
227    }
228
229    /// Get metadata for a specific tool
230    ///
231    /// This searches across all registered protocols to find the tool.
232    async fn get_tool_metadata(
233        &self,
234        tool_name: &str,
235    ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
236        let all_tools = self.list_tools().await?;
237        all_tools
238            .into_iter()
239            .find(|t| t.name == tool_name)
240            .ok_or_else(|| {
241                Box::new(ToolError::NotFound(tool_name.to_string())) as Box<dyn Error + Send + Sync>
242            })
243    }
244
245    /// Protocol identifier
246    fn protocol_name(&self) -> &str {
247        "unified-mcp-server"
248    }
249
250    /// Initialize the server (initializes all registered protocols)
251    async fn initialize(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
252        let _tools = self.tools.read().await;
253
254        // Note: We can't call initialize on Arc<dyn ToolProtocol> since
255        // it takes &mut self. This is a limitation of the current design.
256        // Future: Consider a separate initialization registry or use Arc<Mutex<>>.
257
258        Ok(())
259    }
260
261    /// Shutdown the server (shuts down all registered protocols)
262    async fn shutdown(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
263        let _tools = self.tools.read().await;
264
265        // Same limitation as initialize - we need Arc<Mutex<>> for protocols
266        // that need shutdown handling.
267
268        Ok(())
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Mock tool protocol for testing.
    ///
    /// `execute` echoes the tool name it was called with plus its own `name`
    /// (as `"source"`), and `list_tools` reports a single tool named `name`.
    struct MockToolProtocol {
        name: String,
    }

    #[async_trait]
    impl ToolProtocol for MockToolProtocol {
        async fn execute(
            &self,
            tool_name: &str,
            _parameters: serde_json::Value,
        ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
            Ok(ToolResult::success(serde_json::json!({
                "tool": tool_name,
                "source": &self.name
            })))
        }

        async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
            Ok(vec![ToolMetadata::new(&self.name, "A mock tool")])
        }

        async fn get_tool_metadata(
            &self,
            tool_name: &str,
        ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
            if tool_name == self.name {
                Ok(ToolMetadata::new(&self.name, "A mock tool"))
            } else {
                Err(Box::new(ToolError::NotFound(tool_name.to_string())))
            }
        }

        fn protocol_name(&self) -> &str {
            "mock"
        }
    }

    #[tokio::test]
    async fn test_unified_server_creation() {
        let server = UnifiedMcpServer::new();
        assert_eq!(server.tool_count().await, 0);
        assert_eq!(server.protocol_name(), "unified-mcp-server");
    }

    #[tokio::test]
    async fn test_register_single_tool() {
        let mut server = UnifiedMcpServer::new();
        let mock = Arc::new(MockToolProtocol {
            name: "test_tool".to_string(),
        });

        server.register_tool("test_tool", mock).await;
        assert_eq!(server.tool_count().await, 1);
        assert!(server.has_tool("test_tool").await);
    }

    #[tokio::test]
    async fn test_register_multiple_tools() {
        let mut server = UnifiedMcpServer::new();
        let mock1 = Arc::new(MockToolProtocol {
            name: "tool1".to_string(),
        });
        let mock2 = Arc::new(MockToolProtocol {
            name: "tool2".to_string(),
        });

        server.register_tool("tool1", mock1).await;
        server.register_tool("tool2", mock2).await;
        assert_eq!(server.tool_count().await, 2);
        assert!(server.has_tool("tool1").await);
        assert!(server.has_tool("tool2").await);
    }

    #[tokio::test]
    async fn test_register_replaces_existing_protocol() {
        let mut server = UnifiedMcpServer::new();
        server
            .register_tool(
                "dup",
                Arc::new(MockToolProtocol {
                    name: "first".to_string(),
                }),
            )
            .await;
        server
            .register_tool(
                "dup",
                Arc::new(MockToolProtocol {
                    name: "second".to_string(),
                }),
            )
            .await;

        // Re-registering a name replaces the entry rather than adding a new one.
        assert_eq!(server.tool_count().await, 1);
        let result = server.execute("dup", serde_json::json!({})).await.unwrap();
        assert_eq!(result.output["source"], "second");
    }

    #[tokio::test]
    async fn test_execute_tool_routing() {
        let mut server = UnifiedMcpServer::new();
        let mock = Arc::new(MockToolProtocol {
            name: "router_test".to_string(),
        });

        server.register_tool("router_test", mock).await;

        let result = server.execute("router_test", serde_json::json!({})).await;

        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.success);
        assert_eq!(tool_result.output["tool"], "router_test");
    }

    #[tokio::test]
    async fn test_execute_routes_alias_to_shared_protocol() {
        let mut server = UnifiedMcpServer::new();
        let shared: Arc<dyn ToolProtocol> = Arc::new(MockToolProtocol {
            name: "backend".to_string(),
        });
        server.register_tool("alias", Arc::clone(&shared)).await;

        // Routing is by registry key; the protocol receives the key as tool_name.
        let result = server.execute("alias", serde_json::json!({})).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output["tool"], "alias");
        assert_eq!(result.output["source"], "backend");
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let server = UnifiedMcpServer::new();

        let result = server.execute("nonexistent", serde_json::json!({})).await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not found") || err.contains("NotFound"));
    }

    #[tokio::test]
    async fn test_list_tools_aggregation() {
        let mut server = UnifiedMcpServer::new();
        let mock1 = Arc::new(MockToolProtocol {
            name: "tool1".to_string(),
        });
        let mock2 = Arc::new(MockToolProtocol {
            name: "tool2".to_string(),
        });

        server.register_tool("tool1", mock1).await;
        server.register_tool("tool2", mock2).await;

        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 2);
        assert!(tools.iter().any(|t| t.name == "tool1"));
        assert!(tools.iter().any(|t| t.name == "tool2"));
    }

    #[tokio::test]
    async fn test_list_tools_dedups_shared_protocol() {
        let mut server = UnifiedMcpServer::new();
        let shared: Arc<dyn ToolProtocol> = Arc::new(MockToolProtocol {
            name: "shared".to_string(),
        });

        server.register_tool("alias_a", Arc::clone(&shared)).await;
        server.register_tool("alias_b", Arc::clone(&shared)).await;
        assert_eq!(server.tool_count().await, 2);

        // One protocol instance under two names must be queried only once.
        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "shared");
    }

    #[tokio::test]
    async fn test_get_tool_metadata() {
        let mut server = UnifiedMcpServer::new();
        let mock = Arc::new(MockToolProtocol {
            name: "metadata_test".to_string(),
        });

        server.register_tool("metadata_test", mock).await;

        let metadata = server.get_tool_metadata("metadata_test").await;
        assert!(metadata.is_ok());
        assert_eq!(metadata.unwrap().name, "metadata_test");
    }

    #[tokio::test]
    async fn test_get_tool_metadata_nonexistent() {
        let server = UnifiedMcpServer::new();

        let err = server
            .get_tool_metadata("missing")
            .await
            .err()
            .expect("lookup of an unregistered tool must fail");
        let message = err.to_string();
        assert!(message.contains("not found") || message.contains("NotFound"));
    }

    #[tokio::test]
    async fn test_unregister_tool() {
        let mut server = UnifiedMcpServer::new();
        let mock = Arc::new(MockToolProtocol {
            name: "temp_tool".to_string(),
        });

        server.register_tool("temp_tool", mock).await;
        assert_eq!(server.tool_count().await, 1);

        server.unregister_tool("temp_tool").await;
        assert_eq!(server.tool_count().await, 0);
        assert!(!server.has_tool("temp_tool").await);
    }

    #[tokio::test]
    async fn test_unregister_missing_is_noop() {
        let mut server = UnifiedMcpServer::new();
        server
            .register_tool(
                "keep",
                Arc::new(MockToolProtocol {
                    name: "keep".to_string(),
                }),
            )
            .await;

        server.unregister_tool("absent").await;
        assert_eq!(server.tool_count().await, 1);
        assert!(server.has_tool("keep").await);
    }

    #[tokio::test]
    async fn test_clones_share_registry() {
        let server = UnifiedMcpServer::new();
        let mut handle = server.clone();

        handle
            .register_tool(
                "shared_tool",
                Arc::new(MockToolProtocol {
                    name: "shared_tool".to_string(),
                }),
            )
            .await;

        // Clone copies only the outer Arc, so the original sees the registration.
        assert!(server.has_tool("shared_tool").await);
        assert_eq!(server.tool_count().await, 1);
    }

    #[tokio::test]
    async fn test_default_constructor() {
        let server = UnifiedMcpServer::default();
        assert_eq!(server.tool_count().await, 0);
    }
}