use crate::protocol::{ToolError, ToolMetadata, ToolProtocol, ToolResult};
use async_trait::async_trait;
use std::collections::HashMap;
use std::error::Error;
use std::sync::Arc;
use tokio::sync::RwLock;
/// MCP server that aggregates many tool backends behind one [`ToolProtocol`].
///
/// Each tool name maps to the backend responsible for it. The registry sits
/// behind `Arc<RwLock<..>>`, so every clone of a `UnifiedMcpServer` shares — and
/// can modify — the same set of registered tools.
#[derive(Clone)]
pub struct UnifiedMcpServer {
    /// Tool name -> backend protocol handling calls for that name.
    tools: Arc<RwLock<HashMap<String, Arc<dyn ToolProtocol>>>>,
}

impl UnifiedMcpServer {
    /// Creates a server whose tool registry is empty.
    pub fn new() -> Self {
        Self::default()
    }

    /// Maps `tool_name` to `protocol`, replacing any backend previously
    /// registered under the same name.
    ///
    /// Note: clones of this server share the registry, so registering through
    /// one handle is visible through all of them.
    pub async fn register_tool(&mut self, tool_name: &str, protocol: Arc<dyn ToolProtocol>) {
        self.tools
            .write()
            .await
            .insert(tool_name.to_owned(), protocol);
    }

    /// Removes the backend registered under `tool_name`; absent names are ignored.
    pub async fn unregister_tool(&mut self, tool_name: &str) {
        self.tools.write().await.remove(tool_name);
    }

    /// Reports whether some backend is registered under `tool_name`.
    pub async fn has_tool(&self, tool_name: &str) -> bool {
        self.tools.read().await.contains_key(tool_name)
    }

    /// Number of registered tool names (a backend registered under several
    /// names is counted once per name).
    pub async fn tool_count(&self) -> usize {
        self.tools.read().await.len()
    }
}

impl Default for UnifiedMcpServer {
    fn default() -> Self {
        let registry: HashMap<String, Arc<dyn ToolProtocol>> = HashMap::new();
        UnifiedMcpServer {
            tools: Arc::new(RwLock::new(registry)),
        }
    }
}
/// Builds the boxed `ToolError::NotFound` returned when `tool_name` cannot be resolved.
fn tool_not_found(tool_name: &str) -> Box<dyn Error + Send + Sync> {
    Box::new(ToolError::NotFound(tool_name.to_string()))
}

#[async_trait]
impl ToolProtocol for UnifiedMcpServer {
    /// Forwards the call to the backend registered under `tool_name`.
    ///
    /// The registry lock is released before the backend is awaited, so a slow
    /// tool does not block concurrent `register_tool` / `unregister_tool` calls.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::NotFound` if nothing is registered under `tool_name`;
    /// otherwise returns whatever the backend's `execute` returns.
    async fn execute(
        &self,
        tool_name: &str,
        parameters: serde_json::Value,
    ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
        // The read guard is a temporary of this statement, so it is dropped
        // before the backend call below is awaited.
        let protocol = self
            .tools
            .read()
            .await
            .get(tool_name)
            .cloned()
            .ok_or_else(|| tool_not_found(tool_name))?;
        protocol.execute(tool_name, parameters).await
    }

    /// Aggregates the tool listings of every distinct registered backend.
    ///
    /// A backend registered under several names is queried only once (backends
    /// are de-duplicated by `Arc` identity). Listing is best-effort: a backend
    /// whose `list_tools` fails is reported on stderr and skipped. The result is
    /// sorted by tool name because the registry's `HashMap` iteration order is
    /// unspecified.
    async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
        // Snapshot the distinct backends, then release the lock before awaiting them.
        let protocols: Vec<Arc<dyn ToolProtocol>> = {
            let tools = self.tools.read().await;
            let mut seen: std::collections::HashSet<usize> =
                std::collections::HashSet::with_capacity(tools.len());
            tools
                .values()
                // Compare data addresses only: the vtable half of a `dyn` pointer
                // is not guaranteed to be unique, so fat-pointer equality is unreliable.
                .filter(|p| seen.insert(Arc::as_ptr(*p) as *const () as usize))
                .cloned()
                .collect()
        };

        let mut all_tools = Vec::new();
        for protocol in protocols {
            match protocol.list_tools().await {
                Ok(mut listed) => all_tools.append(&mut listed),
                // Best-effort: one failing backend must not hide the others' tools.
                Err(e) => eprintln!(
                    "Error listing tools from protocol '{}': {}",
                    protocol.protocol_name(),
                    e
                ),
            }
        }
        all_tools.sort_by(|a, b| a.name.cmp(&b.name));
        Ok(all_tools)
    }

    /// Returns metadata for `tool_name`.
    ///
    /// Routed like [`execute`](ToolProtocol::execute): if a backend is registered
    /// under `tool_name`, only that backend is asked. Otherwise the aggregated
    /// listing of all backends is scanned for a tool with that name.
    ///
    /// # Errors
    ///
    /// Propagates the routed backend's error, or returns `ToolError::NotFound`
    /// when no backend lists a tool called `tool_name`.
    async fn get_tool_metadata(
        &self,
        tool_name: &str,
    ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
        // Guard is dropped at the end of this statement, before any backend await.
        let routed = self.tools.read().await.get(tool_name).cloned();
        match routed {
            Some(protocol) => protocol.get_tool_metadata(tool_name).await,
            None => self
                .list_tools()
                .await?
                .into_iter()
                .find(|t| t.name == tool_name)
                .ok_or_else(|| tool_not_found(tool_name)),
        }
    }

    fn protocol_name(&self) -> &str {
        "unified-mcp-server"
    }

    /// No-op for the aggregator itself.
    ///
    /// Registered backends are held as `Arc<dyn ToolProtocol>`, while
    /// `initialize` takes `&mut self`, so they are not initialized from here.
    async fn initialize(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
        Ok(())
    }

    /// No-op for the aggregator itself; registered backends (shared via `Arc`)
    /// are not shut down from here, for the same reason as `initialize`.
    async fn shutdown(&mut self) -> Result<(), Box<dyn Error + Send + Sync>> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Backend exposing exactly one tool, named `name`, that echoes its routing.
    struct MockToolProtocol {
        name: String,
    }

    #[async_trait]
    impl ToolProtocol for MockToolProtocol {
        async fn execute(
            &self,
            tool_name: &str,
            _parameters: serde_json::Value,
        ) -> Result<ToolResult, Box<dyn Error + Send + Sync>> {
            Ok(ToolResult::success(serde_json::json!({
                "tool": tool_name,
                "source": &self.name
            })))
        }

        async fn list_tools(&self) -> Result<Vec<ToolMetadata>, Box<dyn Error + Send + Sync>> {
            Ok(vec![ToolMetadata::new(&self.name, "A mock tool")])
        }

        async fn get_tool_metadata(
            &self,
            tool_name: &str,
        ) -> Result<ToolMetadata, Box<dyn Error + Send + Sync>> {
            if tool_name == self.name {
                Ok(ToolMetadata::new(&self.name, "A mock tool"))
            } else {
                Err(Box::new(ToolError::NotFound(tool_name.to_string())))
            }
        }

        fn protocol_name(&self) -> &str {
            "mock"
        }
    }

    /// Shorthand for a shared mock backend whose single tool is `name`.
    fn mock(name: &str) -> Arc<MockToolProtocol> {
        Arc::new(MockToolProtocol {
            name: name.to_string(),
        })
    }

    #[tokio::test]
    async fn test_unified_server_creation() {
        let server = UnifiedMcpServer::new();
        assert_eq!(server.tool_count().await, 0);
        assert_eq!(server.protocol_name(), "unified-mcp-server");
    }

    #[tokio::test]
    async fn test_register_single_tool() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("test_tool", mock("test_tool")).await;
        assert_eq!(server.tool_count().await, 1);
        assert!(server.has_tool("test_tool").await);
    }

    #[tokio::test]
    async fn test_register_multiple_tools() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("tool1", mock("tool1")).await;
        server.register_tool("tool2", mock("tool2")).await;
        assert_eq!(server.tool_count().await, 2);
        assert!(server.has_tool("tool1").await);
        assert!(server.has_tool("tool2").await);
    }

    #[tokio::test]
    async fn test_execute_tool_routing() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("router_test", mock("router_test")).await;
        let result = server.execute("router_test", serde_json::json!({})).await;
        assert!(result.is_ok());
        let tool_result = result.unwrap();
        assert!(tool_result.success);
        assert_eq!(tool_result.output["tool"], "router_test");
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let server = UnifiedMcpServer::new();
        let result = server.execute("nonexistent", serde_json::json!({})).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not found") || err.contains("NotFound"));
    }

    #[tokio::test]
    async fn test_list_tools_aggregation() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("tool1", mock("tool1")).await;
        server.register_tool("tool2", mock("tool2")).await;
        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 2);
        assert!(tools.iter().any(|t| t.name == "tool1"));
        assert!(tools.iter().any(|t| t.name == "tool2"));
    }

    #[tokio::test]
    async fn test_list_tools_deduplicates_shared_protocol() {
        let mut server = UnifiedMcpServer::new();
        // One backend registered under two names must be listed only once.
        let shared: Arc<dyn ToolProtocol> = mock("shared_tool");
        server.register_tool("alias_a", Arc::clone(&shared)).await;
        server.register_tool("alias_b", shared).await;
        assert_eq!(server.tool_count().await, 2);
        let tools = server.list_tools().await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "shared_tool");
    }

    #[tokio::test]
    async fn test_get_tool_metadata() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("metadata_test", mock("metadata_test")).await;
        let metadata = server.get_tool_metadata("metadata_test").await;
        assert!(metadata.is_ok());
        assert_eq!(metadata.unwrap().name, "metadata_test");
    }

    #[tokio::test]
    async fn test_get_tool_metadata_nonexistent_tool() {
        let server = UnifiedMcpServer::new();
        assert!(server.get_tool_metadata("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_unregister_tool() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("temp_tool", mock("temp_tool")).await;
        assert_eq!(server.tool_count().await, 1);
        server.unregister_tool("temp_tool").await;
        assert_eq!(server.tool_count().await, 0);
        assert!(!server.has_tool("temp_tool").await);
    }

    #[tokio::test]
    async fn test_unregister_missing_tool_is_noop() {
        let mut server = UnifiedMcpServer::new();
        server.register_tool("kept", mock("kept")).await;
        server.unregister_tool("absent").await;
        assert_eq!(server.tool_count().await, 1);
        assert!(server.has_tool("kept").await);
    }

    #[tokio::test]
    async fn test_clones_share_registry() {
        let server = UnifiedMcpServer::new();
        let mut handle = server.clone();
        handle.register_tool("shared", mock("shared")).await;
        assert!(server.has_tool("shared").await);
        assert_eq!(server.tool_count().await, 1);
    }

    #[tokio::test]
    async fn test_default_constructor() {
        let server = UnifiedMcpServer::default();
        assert_eq!(server.tool_count().await, 0);
    }
}