1use crate::protocol::{ToolError, ToolMetadata, ToolProtocol, ToolResult};
45use async_trait::async_trait;
46use std::collections::HashMap;
47use std::error::Error;
48use std::sync::Arc;
49use tokio::sync::RwLock;
50
/// Front-end that routes tool calls to per-tool [`ToolProtocol`] backends held
/// in a shared, name-keyed registry.
///
/// Cloning is cheap: the registry lives behind an `Arc`, so every clone observes
/// (and mutates) the same set of registered tools.
#[derive(Clone)]
pub struct UnifiedMcpServer {
    /// Backends keyed by the tool name they were registered under. The `RwLock`
    /// lets concurrent lookups proceed while (un)registration takes the write lock.
    tools: Arc<RwLock<HashMap<String, Arc<dyn ToolProtocol>>>>,
}
70
71impl UnifiedMcpServer {
72 pub fn new() -> Self {
74 Self {
75 tools: Arc::new(RwLock::new(HashMap::new())),
76 }
77 }
78
79 pub async fn register_tool(&mut self, tool_name: &str, protocol: Arc<dyn ToolProtocol>) {
104 let mut tools = self.tools.write().await;
105 tools.insert(tool_name.to_string(), protocol);
106 }
107
108 pub async fn unregister_tool(&mut self, tool_name: &str) {
110 let mut tools = self.tools.write().await;
111 tools.remove(tool_name);
112 }
113
114 pub async fn has_tool(&self, tool_name: &str) -> bool {
116 let tools = self.tools.read().await;
117 tools.contains_key(tool_name)
118 }
119
120 pub async fn tool_count(&self) -> usize {
122 let tools = self.tools.read().await;
123 tools.len()
124 }
125}
126
127impl Default for UnifiedMcpServer {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133#[async_trait]
134impl ToolProtocol for UnifiedMcpServer {
135 async fn execute(
143 &self,
144 tool_name: &str,
145 parameters: serde_json::Value,
146 ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
147 let tools = self.tools.read().await;
148
149 let protocol = tools.get(tool_name).cloned().ok_or_else(|| {
150 Box::new(ToolError::NotFound(tool_name.to_string())) as Box<dyn Error + Send + Sync>
151 })?;
152
153 drop(tools);
155
156 protocol.execute(tool_name, parameters).await
158 }
159
160 async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
166 let tools = self.tools.read().await;
167
168 let mut seen: std::collections::HashSet<usize> = std::collections::HashSet::new();
171 let protocols: Vec<Arc<dyn ToolProtocol>> = tools
172 .values()
173 .filter(|p| seen.insert(Arc::as_ptr(*p) as *const () as usize))
174 .cloned()
175 .collect();
176
177 drop(tools);
179
180 let mut all_tools = Vec::new();
181
182 for protocol in protocols {
183 match protocol.list_tools().await {
184 Ok(mut tool_list) => all_tools.append(&mut tool_list),
185 Err(e) => {
186 eprintln!("Error listing tools from protocol: {}", e);
188 }
189 }
190 }
191
192 Ok(all_tools)
193 }
194
195 async fn get_tool_metadata(
199 &self,
200 tool_name: &str,
201 ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
202 let all_tools = self.list_tools().await?;
203 all_tools
204 .into_iter()
205 .find(|t| t.name == tool_name)
206 .ok_or_else(|| {
207 Box::new(ToolError::NotFound(tool_name.to_string())) as Box<dyn Error + Send + Sync>
208 })
209 }
210
211 fn protocol_name(&self) -> &str {
213 "unified-mcp-server"
214 }
215
216 async fn initialize(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
218 let _tools = self.tools.read().await;
219
220 Ok(())
225 }
226
227 async fn shutdown(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
229 let _tools = self.tools.read().await;
230
231 Ok(())
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Backend double serving exactly one tool, `name`. `execute` echoes the
    /// requested tool name and the backend's own name.
    struct MockToolProtocol {
        name: String,
    }

    /// Builds a shareable mock backend serving the tool `name`.
    fn mock(name: &str) -> Arc<MockToolProtocol> {
        Arc::new(MockToolProtocol {
            name: name.to_string(),
        })
    }

    #[async_trait]
    impl ToolProtocol for MockToolProtocol {
        async fn execute(
            &self,
            tool_name: &str,
            _parameters: serde_json::Value,
        ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
            Ok(ToolResult::success(serde_json::json!({
                "tool": tool_name,
                "source": &self.name
            })))
        }

        async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
            Ok(vec![ToolMetadata::new(&self.name, "A mock tool")])
        }

        async fn get_tool_metadata(
            &self,
            tool_name: &str,
        ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
            if tool_name == self.name {
                Ok(ToolMetadata::new(&self.name, "A mock tool"))
            } else {
                Err(Box::new(ToolError::NotFound(tool_name.to_string())))
            }
        }

        fn protocol_name(&self) -> &str {
            "mock"
        }
    }

    #[tokio::test]
    async fn test_unified_server_creation() {
        let server = UnifiedMcpServer::new();
        assert_eq!(server.tool_count().await, 0);
        assert_eq!(server.protocol_name(), "unified-mcp-server");
    }

    #[tokio::test]
    async fn test_register_single_tool() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("test_tool", mock("test_tool")).await;
        assert_eq!(server.tool_count().await, 1);
        assert!(server.has_tool("test_tool").await);
    }

    #[tokio::test]
    async fn test_register_multiple_tools() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("tool1", mock("tool1")).await;
        server.register_tool("tool2", mock("tool2")).await;
        assert_eq!(server.tool_count().await, 2);
        assert!(server.has_tool("tool1").await);
        assert!(server.has_tool("tool2").await);
    }

    #[tokio::test]
    async fn test_register_replaces_existing() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("dup", mock("first")).await;
        server.register_tool("dup", mock("second")).await;
        assert_eq!(server.tool_count().await, 1);

        let result = server.execute("dup", serde_json::json!({})).await.unwrap();
        assert_eq!(result.output["source"], "second");
    }

    #[tokio::test]
    async fn test_clones_share_registry() {
        let mut server = UnifiedMcpServer::new();
        let observer = server.clone();
        server.register_tool("shared_tool", mock("shared_tool")).await;
        assert!(observer.has_tool("shared_tool").await);
        assert_eq!(observer.tool_count().await, 1);
    }

    #[tokio::test]
    async fn test_execute_tool_routing() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("router_test", mock("router_test")).await;

        let result = server.execute("router_test", serde_json::json!({})).await;

        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.success);
        assert_eq!(tool_result.output["tool"], "router_test");
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let server = UnifiedMcpServer::new();

        let result = server.execute("nonexistent", serde_json::json!({})).await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not found") || err.contains("NotFound"));
    }

    #[tokio::test]
    async fn test_list_tools_empty() {
        let server = UnifiedMcpServer::new();
        assert!(server.list_tools().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_list_tools_aggregation() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("tool1", mock("tool1")).await;
        server.register_tool("tool2", mock("tool2")).await;

        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 2);
        assert!(tools.iter().any(|t| t.name == "tool1"));
        assert!(tools.iter().any(|t| t.name == "tool2"));
    }

    #[tokio::test]
    async fn test_list_tools_deduplicates_aliased_backend() {
        let mut server = UnifiedMcpServer::new();
        let backend = mock("primary");
        server.register_tool("primary", backend.clone()).await;
        server.register_tool("alias", backend).await;
        assert_eq!(server.tool_count().await, 2);

        // The same backend instance must only be listed once.
        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "primary");

        // ...but both registered names still route to it.
        let result = server.execute("alias", serde_json::json!({})).await.unwrap();
        assert_eq!(result.output["source"], "primary");
    }

    #[tokio::test]
    async fn test_get_tool_metadata() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("metadata_test", mock("metadata_test")).await;

        let metadata = server.get_tool_metadata("metadata_test").await;
        assert!(metadata.is_ok());
        assert_eq!(metadata.unwrap().name, "metadata_test");
    }

    #[tokio::test]
    async fn test_get_tool_metadata_nonexistent() {
        let server = UnifiedMcpServer::new();
        assert!(server.get_tool_metadata("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_unregister_tool() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("temp_tool", mock("temp_tool")).await;
        assert_eq!(server.tool_count().await, 1);

        server.unregister_tool("temp_tool").await;
        assert_eq!(server.tool_count().await, 0);
        assert!(!server.has_tool("temp_tool").await);

        // Unregistering an unknown name is a no-op.
        server.unregister_tool("ghost").await;
        assert_eq!(server.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_default_constructor() {
        let server = UnifiedMcpServer::default();
        assert_eq!(server.tool_count().await, 0);
    }
}