1use crate::protocol::{ToolError, ToolMetadata, ToolProtocol, ToolResult};
62use async_trait::async_trait;
63use std::collections::HashMap;
64use std::error::Error;
65use std::sync::Arc;
66use tokio::sync::RwLock;
67
68#[derive(Clone)]
83pub struct UnifiedMcpServer {
84 tools: Arc<RwLock<HashMap<String, Arc<dyn ToolProtocol>>>>,
86}
87
88impl UnifiedMcpServer {
89 pub fn new() -> Self {
91 Self {
92 tools: Arc::new(RwLock::new(HashMap::new())),
93 }
94 }
95
96 pub async fn register_tool(&mut self, tool_name: &str, protocol: Arc<dyn ToolProtocol>) {
138 let mut tools = self.tools.write().await;
139 tools.insert(tool_name.to_string(), protocol);
140 }
141
142 pub async fn unregister_tool(&mut self, tool_name: &str) {
144 let mut tools = self.tools.write().await;
145 tools.remove(tool_name);
146 }
147
148 pub async fn has_tool(&self, tool_name: &str) -> bool {
150 let tools = self.tools.read().await;
151 tools.contains_key(tool_name)
152 }
153
154 pub async fn tool_count(&self) -> usize {
156 let tools = self.tools.read().await;
157 tools.len()
158 }
159}
160
161impl Default for UnifiedMcpServer {
162 fn default() -> Self {
163 Self::new()
164 }
165}
166
167#[async_trait]
168impl ToolProtocol for UnifiedMcpServer {
169 async fn execute(
177 &self,
178 tool_name: &str,
179 parameters: serde_json::Value,
180 ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
181 let tools = self.tools.read().await;
182
183 let protocol = tools.get(tool_name).cloned().ok_or_else(|| {
184 Box::new(ToolError::NotFound(tool_name.to_string())) as Box<dyn Error + Send + Sync>
185 })?;
186
187 drop(tools);
189
190 protocol.execute(tool_name, parameters).await
192 }
193
194 async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
200 let tools = self.tools.read().await;
201
202 let mut seen: std::collections::HashSet<usize> = std::collections::HashSet::new();
205 let protocols: Vec<Arc<dyn ToolProtocol>> = tools
206 .values()
207 .filter(|p| seen.insert(Arc::as_ptr(*p) as *const () as usize))
208 .cloned()
209 .collect();
210
211 drop(tools);
213
214 let mut all_tools = Vec::new();
215
216 for protocol in protocols {
217 match protocol.list_tools().await {
218 Ok(mut tool_list) => all_tools.append(&mut tool_list),
219 Err(e) => {
220 eprintln!("Error listing tools from protocol: {}", e);
222 }
223 }
224 }
225
226 Ok(all_tools)
227 }
228
229 async fn get_tool_metadata(
233 &self,
234 tool_name: &str,
235 ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
236 let all_tools = self.list_tools().await?;
237 all_tools
238 .into_iter()
239 .find(|t| t.name == tool_name)
240 .ok_or_else(|| {
241 Box::new(ToolError::NotFound(tool_name.to_string())) as Box<dyn Error + Send + Sync>
242 })
243 }
244
245 fn protocol_name(&self) -> &str {
247 "unified-mcp-server"
248 }
249
250 async fn initialize(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
252 let _tools = self.tools.read().await;
253
254 Ok(())
259 }
260
261 async fn shutdown(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
263 let _tools = self.tools.read().await;
264
265 Ok(())
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ToolMetadata;

    /// Protocol stub that exposes a single tool named after `name` and echoes
    /// the requested tool name back from `execute`.
    struct MockToolProtocol {
        name: String,
    }

    #[async_trait]
    impl ToolProtocol for MockToolProtocol {
        async fn execute(
            &self,
            tool_name: &str,
            _parameters: serde_json::Value,
        ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
            let payload = serde_json::json!({
                "tool": tool_name,
                "source": &self.name
            });
            Ok(ToolResult::success(payload))
        }

        async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
            let only_tool = ToolMetadata::new(&self.name, "A mock tool");
            Ok(vec![only_tool])
        }

        async fn get_tool_metadata(
            &self,
            tool_name: &str,
        ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
            if tool_name == self.name {
                Ok(ToolMetadata::new(&self.name, "A mock tool"))
            } else {
                Err(Box::new(ToolError::NotFound(tool_name.to_string())))
            }
        }

        fn protocol_name(&self) -> &str {
            "mock"
        }
    }

    /// Builds a shared mock protocol whose single tool is called `name`.
    fn mock(name: &str) -> Arc<MockToolProtocol> {
        Arc::new(MockToolProtocol {
            name: name.to_string(),
        })
    }

    /// Builds a server with one mock protocol registered per entry in `names`,
    /// in the order given.
    async fn server_with(names: &[&str]) -> UnifiedMcpServer {
        let mut server = UnifiedMcpServer::new();
        for &name in names {
            server.register_tool(name, mock(name)).await;
        }
        server
    }

    #[tokio::test]
    async fn test_unified_server_creation() {
        let server = UnifiedMcpServer::new();
        assert_eq!(server.tool_count().await, 0);
        assert_eq!(server.protocol_name(), "unified-mcp-server");
    }

    #[tokio::test]
    async fn test_register_single_tool() {
        let server = server_with(&["test_tool"]).await;

        assert_eq!(server.tool_count().await, 1);
        assert!(server.has_tool("test_tool").await);
    }

    #[tokio::test]
    async fn test_register_multiple_tools() {
        let server = server_with(&["tool1", "tool2"]).await;

        assert_eq!(server.tool_count().await, 2);
        assert!(server.has_tool("tool1").await);
        assert!(server.has_tool("tool2").await);
    }

    #[tokio::test]
    async fn test_execute_tool_routing() {
        let server = server_with(&["router_test"]).await;

        let result = server.execute("router_test", serde_json::json!({})).await;

        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.success);
        assert_eq!(tool_result.output["tool"], "router_test");
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let server = UnifiedMcpServer::new();

        let result = server.execute("nonexistent", serde_json::json!({})).await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not found") || err.contains("NotFound"));
    }

    #[tokio::test]
    async fn test_list_tools_aggregation() {
        let server = server_with(&["tool1", "tool2"]).await;

        let tools = server.list_tools().await.unwrap();

        assert_eq!(tools.len(), 2);
        assert!(tools.iter().any(|t| t.name == "tool1"));
        assert!(tools.iter().any(|t| t.name == "tool2"));
    }

    #[tokio::test]
    async fn test_get_tool_metadata() {
        let server = server_with(&["metadata_test"]).await;

        let metadata = server.get_tool_metadata("metadata_test").await;

        assert!(metadata.is_ok());
        assert_eq!(metadata.unwrap().name, "metadata_test");
    }

    #[tokio::test]
    async fn test_unregister_tool() {
        let mut server = server_with(&["temp_tool"]).await;
        assert_eq!(server.tool_count().await, 1);

        server.unregister_tool("temp_tool").await;

        assert_eq!(server.tool_count().await, 0);
        assert!(!server.has_tool("temp_tool").await);
    }

    #[tokio::test]
    async fn test_default_constructor() {
        let server = UnifiedMcpServer::default();
        assert_eq!(server.tool_count().await, 0);
    }
}