Skip to main content

mcp/
server.rs

//! Unified MCP Server
//!
//! This module provides a concrete MCP server implementation that aggregates
//! multiple tools behind a single `ToolProtocol` implementation, routing each
//! tool call to the protocol registered under that tool's name.
//!
//! The server acts as a dispatcher that can be deployed as an HTTP service,
//! allowing multiple agents (local or remote) to access a unified set of tools
//! through a single `ToolProtocol` interface.
//!
//! # Architecture
//!
//! ```text
//! Multiple Tools (Memory, Bash, etc.)
//!         ↓
//! UnifiedMcpServer (implements ToolProtocol)
//!         ↓
//! HTTP Endpoints (GET /tools, POST /execute)
//!         ↓
//! Agents/Clients (via McpClientProtocol)
//! ```
//!
//! # Example
//!
//! ```rust,no_run
//! use cloudllm::mcp_server::UnifiedMcpServer;
//! use cloudllm::tools::Memory;
//! use cloudllm::tool_protocols::MemoryProtocol;
//! use cloudllm::tool_protocol::ToolProtocol;
//! use std::sync::Arc;
//!
//! # async {
//! let memory = Arc::new(Memory::new());
//! let memory_protocol = Arc::new(MemoryProtocol::new(memory));
//!
//! let mut server = UnifiedMcpServer::new();
//! // `register_tool` is async: the tool is only registered once it is awaited.
//! server.register_tool("memory", memory_protocol).await;
//!
//! // Now the server implements ToolProtocol and can route calls
//! let tools = server.list_tools().await.unwrap();
//! # };
//! ```
43
44use crate::protocol::{ToolError, ToolMetadata, ToolProtocol, ToolResult};
45use async_trait::async_trait;
46use std::collections::HashMap;
47use std::error::Error;
48use std::sync::Arc;
49use tokio::sync::RwLock;
50
51/// A unified MCP server that aggregates multiple tools
52///
53/// The UnifiedMcpServer implements the ToolProtocol trait and routes
54/// tool execution requests to the appropriate underlying tool protocol
55/// implementation based on the tool name.
56///
57/// This allows a single server instance to expose multiple tools with
58/// different implementations, making it suitable for deployment as an
59/// MCP HTTP service that can be accessed by multiple agents.
60///
61/// # Thread Safety
62///
63/// The server is thread-safe and can be shared across multiple concurrent
64/// tool executions using `Arc<UnifiedMcpServer>`.
65#[derive(Clone)]
66pub struct UnifiedMcpServer {
67    /// Map of tool name to its ToolProtocol implementation
68    tools: Arc<RwLock<HashMap<String, Arc<dyn ToolProtocol>>>>,
69}
70
71impl UnifiedMcpServer {
72    /// Create a new empty unified MCP server
73    pub fn new() -> Self {
74        Self {
75            tools: Arc::new(RwLock::new(HashMap::new())),
76        }
77    }
78
79    /// Register a tool with the server
80    ///
81    /// # Arguments
82    ///
83    /// * `tool_name` - The identifier for the tool (e.g., "memory", "bash")
84    /// * `protocol` - The ToolProtocol implementation for this tool
85    ///
86    /// # Example
87    ///
88    /// ```rust,no_run
89    /// use cloudllm::mcp_server::UnifiedMcpServer;
90    /// use cloudllm::tools::Memory;
91    /// use cloudllm::tool_protocols::MemoryProtocol;
92    /// use std::sync::Arc;
93    ///
94    /// # #[tokio::main]
95    /// # async fn main() {
96    /// let memory = Arc::new(Memory::new());
97    /// let memory_protocol = Arc::new(MemoryProtocol::new(memory));
98    ///
99    /// let mut server = UnifiedMcpServer::new();
100    /// server.register_tool("memory", memory_protocol).await;
101    /// # }
102    /// ```
103    pub async fn register_tool(&mut self, tool_name: &str, protocol: Arc<dyn ToolProtocol>) {
104        let mut tools = self.tools.write().await;
105        tools.insert(tool_name.to_string(), protocol);
106    }
107
108    /// Unregister a tool from the server
109    pub async fn unregister_tool(&mut self, tool_name: &str) {
110        let mut tools = self.tools.write().await;
111        tools.remove(tool_name);
112    }
113
114    /// Check if a tool is registered
115    pub async fn has_tool(&self, tool_name: &str) -> bool {
116        let tools = self.tools.read().await;
117        tools.contains_key(tool_name)
118    }
119
120    /// Get the number of registered tools
121    pub async fn tool_count(&self) -> usize {
122        let tools = self.tools.read().await;
123        tools.len()
124    }
125}
126
127impl Default for UnifiedMcpServer {
128    fn default() -> Self {
129        Self::new()
130    }
131}
132
133#[async_trait]
134impl ToolProtocol for UnifiedMcpServer {
135    /// Execute a tool by routing to the appropriate protocol
136    ///
137    /// # Routing Logic
138    ///
139    /// 1. Look up the tool name in the registry
140    /// 2. If found, delegate to that tool's protocol
141    /// 3. If not found, return NotFound error
142    async fn execute(
143        &self,
144        tool_name: &str,
145        parameters: serde_json::Value,
146    ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
147        let tools = self.tools.read().await;
148
149        let protocol = tools.get(tool_name).cloned().ok_or_else(|| {
150            Box::new(ToolError::NotFound(tool_name.to_string())) as Box<dyn Error + Send + Sync>
151        })?;
152
153        // Drop the read lock before executing to allow concurrent access
154        drop(tools);
155
156        // Route to the appropriate tool's protocol
157        protocol.execute(tool_name, parameters).await
158    }
159
160    /// List all available tools across all registered protocols
161    ///
162    /// This aggregates tool metadata from all registered tool protocols.
163    /// Each protocol is queried at most once even if multiple tool names
164    /// are registered to the same protocol instance.
165    async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
166        let tools = self.tools.read().await;
167
168        // Deduplicate protocol instances by pointer so each protocol's list_tools()
169        // is called at most once (multiple tool names may point to the same protocol).
170        let mut seen: std::collections::HashSet<usize> = std::collections::HashSet::new();
171        let protocols: Vec<Arc<dyn ToolProtocol>> = tools
172            .values()
173            .filter(|p| seen.insert(Arc::as_ptr(*p) as *const () as usize))
174            .cloned()
175            .collect();
176
177        // Drop the read lock before making async calls
178        drop(tools);
179
180        let mut all_tools = Vec::new();
181
182        for protocol in protocols {
183            match protocol.list_tools().await {
184                Ok(mut tool_list) => all_tools.append(&mut tool_list),
185                Err(e) => {
186                    // Log but continue - we want to return what we can
187                    eprintln!("Error listing tools from protocol: {}", e);
188                }
189            }
190        }
191
192        Ok(all_tools)
193    }
194
195    /// Get metadata for a specific tool
196    ///
197    /// This searches across all registered protocols to find the tool.
198    async fn get_tool_metadata(
199        &self,
200        tool_name: &str,
201    ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
202        let all_tools = self.list_tools().await?;
203        all_tools
204            .into_iter()
205            .find(|t| t.name == tool_name)
206            .ok_or_else(|| {
207                Box::new(ToolError::NotFound(tool_name.to_string())) as Box<dyn Error + Send + Sync>
208            })
209    }
210
211    /// Protocol identifier
212    fn protocol_name(&self) -> &str {
213        "unified-mcp-server"
214    }
215
216    /// Initialize the server (initializes all registered protocols)
217    async fn initialize(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
218        let _tools = self.tools.read().await;
219
220        // Note: We can't call initialize on Arc<dyn ToolProtocol> since
221        // it takes &mut self. This is a limitation of the current design.
222        // Future: Consider a separate initialization registry or use Arc<Mutex<>>.
223
224        Ok(())
225    }
226
227    /// Shutdown the server (shuts down all registered protocols)
228    async fn shutdown(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
229        let _tools = self.tools.read().await;
230
231        // Same limitation as initialize - we need Arc<Mutex<>> for protocols
232        // that need shutdown handling.
233
234        Ok(())
235    }
236}
237
#[cfg(test)]
// `mut` bindings keep these tests valid whether registration takes `&self`
// or `&mut self`; silence the lint for the `&self` case.
#[allow(unused_mut)]
mod tests {
    use super::*;
    use crate::protocol::ToolMetadata;

    /// Mock tool protocol for testing: lists a single tool called `name` and,
    /// on execution, echoes the requested tool name plus `name` as `source`.
    struct MockToolProtocol {
        name: String,
    }

    #[async_trait]
    impl ToolProtocol for MockToolProtocol {
        async fn execute(
            &self,
            tool_name: &str,
            _parameters: serde_json::Value,
        ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
            Ok(ToolResult::success(serde_json::json!({
                "tool": tool_name,
                "source": &self.name
            })))
        }

        async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
            Ok(vec![ToolMetadata::new(&self.name, "A mock tool")])
        }

        async fn get_tool_metadata(
            &self,
            tool_name: &str,
        ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
            if tool_name == self.name {
                Ok(ToolMetadata::new(&self.name, "A mock tool"))
            } else {
                Err(Box::new(ToolError::NotFound(tool_name.to_string())))
            }
        }

        fn protocol_name(&self) -> &str {
            "mock"
        }
    }

    /// Build a mock protocol exposing one tool called `name`.
    fn mock_protocol(name: &str) -> Arc<MockToolProtocol> {
        Arc::new(MockToolProtocol {
            name: name.to_string(),
        })
    }

    #[tokio::test]
    async fn test_unified_server_creation() {
        let server = UnifiedMcpServer::new();
        assert_eq!(server.tool_count().await, 0);
        assert_eq!(server.protocol_name(), "unified-mcp-server");
    }

    #[tokio::test]
    async fn test_register_single_tool() {
        let mut server = UnifiedMcpServer::new();
        let mock = Arc::new(MockToolProtocol {
            name: "test_tool".to_string(),
        });

        server.register_tool("test_tool", mock).await;
        assert_eq!(server.tool_count().await, 1);
        assert!(server.has_tool("test_tool").await);
    }

    #[tokio::test]
    async fn test_register_multiple_tools() {
        let mut server = UnifiedMcpServer::new();
        let mock1 = Arc::new(MockToolProtocol {
            name: "tool1".to_string(),
        });
        let mock2 = Arc::new(MockToolProtocol {
            name: "tool2".to_string(),
        });

        server.register_tool("tool1", mock1).await;
        server.register_tool("tool2", mock2).await;
        assert_eq!(server.tool_count().await, 2);
        assert!(server.has_tool("tool1").await);
        assert!(server.has_tool("tool2").await);
    }

    #[tokio::test]
    async fn test_register_replaces_existing_tool() {
        let mut server = UnifiedMcpServer::new();

        server.register_tool("tool", mock_protocol("first")).await;
        server.register_tool("tool", mock_protocol("second")).await;
        assert_eq!(server.tool_count().await, 1);

        // The most recent registration wins.
        let result = server.execute("tool", serde_json::json!({})).await.unwrap();
        assert_eq!(result.output["source"], "second");
    }

    #[tokio::test]
    async fn test_execute_tool_routing() {
        let mut server = UnifiedMcpServer::new();
        let mock = Arc::new(MockToolProtocol {
            name: "router_test".to_string(),
        });

        server.register_tool("router_test", mock).await;

        let result = server.execute("router_test", serde_json::json!({})).await;

        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.success);
        assert_eq!(tool_result.output["tool"], "router_test");
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let server = UnifiedMcpServer::new();

        let result = server.execute("nonexistent", serde_json::json!({})).await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not found") || err.contains("NotFound"));
    }

    #[tokio::test]
    async fn test_list_tools_aggregation() {
        let mut server = UnifiedMcpServer::new();
        let mock1 = Arc::new(MockToolProtocol {
            name: "tool1".to_string(),
        });
        let mock2 = Arc::new(MockToolProtocol {
            name: "tool2".to_string(),
        });

        server.register_tool("tool1", mock1).await;
        server.register_tool("tool2", mock2).await;

        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 2);
        assert!(tools.iter().any(|t| t.name == "tool1"));
        assert!(tools.iter().any(|t| t.name == "tool2"));
    }

    #[tokio::test]
    async fn test_list_tools_queries_shared_protocol_once() {
        let mut server = UnifiedMcpServer::new();
        let shared = mock_protocol("shared");

        // Two names, one protocol instance.
        server.register_tool("alias_a", Arc::clone(&shared)).await;
        server.register_tool("alias_b", shared).await;
        assert_eq!(server.tool_count().await, 2);

        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "shared");
    }

    #[tokio::test]
    async fn test_get_tool_metadata() {
        let mut server = UnifiedMcpServer::new();
        let mock = Arc::new(MockToolProtocol {
            name: "metadata_test".to_string(),
        });

        server.register_tool("metadata_test", mock).await;

        let metadata = server.get_tool_metadata("metadata_test").await;
        assert!(metadata.is_ok());
        assert_eq!(metadata.unwrap().name, "metadata_test");
    }

    #[tokio::test]
    async fn test_get_tool_metadata_unknown_tool() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("known", mock_protocol("known")).await;

        assert!(server.get_tool_metadata("unknown").await.is_err());
    }

    #[tokio::test]
    async fn test_unregister_tool() {
        let mut server = UnifiedMcpServer::new();
        let mock = Arc::new(MockToolProtocol {
            name: "temp_tool".to_string(),
        });

        server.register_tool("temp_tool", mock).await;
        assert_eq!(server.tool_count().await, 1);

        server.unregister_tool("temp_tool").await;
        assert_eq!(server.tool_count().await, 0);
        assert!(!server.has_tool("temp_tool").await);
    }

    #[tokio::test]
    async fn test_unregister_missing_tool_is_noop() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("kept", mock_protocol("kept")).await;

        server.unregister_tool("missing").await;
        assert_eq!(server.tool_count().await, 1);
        assert!(server.has_tool("kept").await);
    }

    #[tokio::test]
    async fn test_clone_shares_registry() {
        let mut server = UnifiedMcpServer::new();
        let handle = server.clone();

        server
            .register_tool("shared_registry", mock_protocol("shared_registry"))
            .await;

        // Clones share the underlying registry.
        assert!(handle.has_tool("shared_registry").await);
        assert_eq!(handle.tool_count().await, 1);
    }

    #[tokio::test]
    async fn test_default_constructor() {
        let server = UnifiedMcpServer::default();
        assert_eq!(server.tool_count().await, 0);
    }
}