use std::collections::HashMap;
use super::protocol::{JsonRpcRequest, JsonRpcResponse, McpError};
use super::tool::SdkMcpTool;
use crate::error::{ClaudeError, Result};
/// A server that answers MCP JSON-RPC requests (`tools/list`, `tools/call`)
/// by dispatching them to a set of registered [`SdkMcpTool`]s.
///
/// Built with [`SdkMcpServer::new`] plus the `version`/`tool`/`tools` builder methods.
pub struct SdkMcpServer {
    /// Server name, exposed through [`SdkMcpServer::name`].
    name: String,
    /// Server version string, exposed through [`SdkMcpServer::server_version`].
    version: String,
    /// Registered tools, keyed by each tool's own name.
    tools: HashMap<String, SdkMcpTool>,
}
impl SdkMcpServer {
    /// Creates a server named `name` with no tools and the default version `"1.0.0"`.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            version: "1.0.0".to_string(),
            tools: HashMap::new(),
        }
    }

    /// Sets the version string returned by [`server_version`](Self::server_version).
    #[must_use]
    pub fn version(mut self, version: impl Into<String>) -> Self {
        self.version = version.into();
        self
    }

    /// Registers a single tool under its own name.
    ///
    /// Registering a second tool with the same name replaces the first.
    #[must_use]
    pub fn tool(mut self, tool: SdkMcpTool) -> Self {
        self.tools.insert(tool.name().to_string(), tool);
        self
    }

    /// Registers every tool yielded by `tools`, in iteration order.
    ///
    /// Accepts any iterable (`Vec`, array, iterator adaptor, ...). As with
    /// [`tool`](Self::tool), a later tool replaces an earlier one with the same name.
    #[must_use]
    pub fn tools(self, tools: impl IntoIterator<Item = SdkMcpTool>) -> Self {
        tools.into_iter().fold(self, Self::tool)
    }

    /// Returns the server name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the server version string.
    pub fn server_version(&self) -> &str {
        &self.version
    }

    /// Looks up a registered tool by name.
    pub fn get_tool(&self, name: &str) -> Option<&SdkMcpTool> {
        self.tools.get(name)
    }

    /// Returns all registered tools, sorted by name.
    ///
    /// `HashMap` iteration order is randomized per process. Sorting keeps the
    /// listing (and therefore the `tools/list` response) stable across runs.
    pub fn list_tools(&self) -> Vec<&SdkMcpTool> {
        let mut entries: Vec<(&String, &SdkMcpTool)> = self.tools.iter().collect();
        // Map keys are unique, so an unstable sort is deterministic here.
        entries.sort_unstable_by_key(|&(name, _)| name);
        entries.into_iter().map(|(_, tool)| tool).collect()
    }

    /// Dispatches one JSON-RPC request to the matching handler.
    ///
    /// Supported methods are `tools/list` and `tools/call`. Any other method
    /// yields a `method_not_found` error response. A request without an `id` is
    /// answered with a `null` id.
    ///
    /// # Errors
    ///
    /// Protocol-level failures (bad params, unknown tool, tool failure) are
    /// reported as `Ok` error responses. `Err` is returned only when a tool's
    /// result cannot be serialized to JSON.
    pub async fn handle_request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
        // `request` is owned, so the id can be moved out instead of cloned.
        let request_id = request.id.unwrap_or(serde_json::Value::Null);
        match request.method.as_str() {
            "tools/list" => self.handle_tools_list(request_id).await,
            "tools/call" => self.handle_tools_call(request_id, request.params).await,
            _ => Ok(JsonRpcResponse::error(
                request_id,
                McpError::method_not_found(&request.method),
            )),
        }
    }

    /// Answers `tools/list` with `{"tools": [...]}`, in name order.
    async fn handle_tools_list(&self, request_id: serde_json::Value) -> Result<JsonRpcResponse> {
        let tools: Vec<_> = self
            .list_tools()
            .into_iter()
            .map(|tool| tool.to_tool_info())
            .collect();
        Ok(JsonRpcResponse::success(
            request_id,
            serde_json::json!({
                "tools": tools
            }),
        ))
    }

    /// Answers `tools/call` by invoking the tool named in `params["name"]`.
    ///
    /// `params["arguments"]` is forwarded to the tool. When it is absent or
    /// `null`, an empty object is passed instead, because MCP defines it as an
    /// optional object.
    async fn handle_tools_call(
        &self,
        request_id: serde_json::Value,
        params: Option<serde_json::Value>,
    ) -> Result<JsonRpcResponse> {
        let mut params = match params {
            Some(p) => p,
            None => {
                return Ok(JsonRpcResponse::error(
                    request_id,
                    McpError::invalid_params("tools/call requires parameters".to_string()),
                ));
            }
        };
        // Indexing a non-object or a missing key yields `null`, so this also
        // covers malformed `params`.
        let tool_name = match params["name"].as_str() {
            Some(name) => name,
            None => {
                return Ok(JsonRpcResponse::error(
                    request_id,
                    McpError::invalid_params("Missing tool name in parameters".to_string()),
                ));
            }
        };
        let tool = match self.tools.get(tool_name) {
            Some(t) => t,
            None => {
                return Ok(JsonRpcResponse::error(
                    request_id,
                    McpError::tool_not_found(tool_name),
                ));
            }
        };
        // Move the arguments out of the owned params rather than cloning them.
        let arguments = match params.get_mut("arguments").map(std::mem::take) {
            None | Some(serde_json::Value::Null) => serde_json::json!({}),
            Some(value) => value,
        };
        match tool.invoke(arguments).await {
            Ok(result) => {
                let result_json = serde_json::to_value(result)
                    .map_err(|e| ClaudeError::Mcp(format!("Failed to serialize result: {e}")))?;
                Ok(JsonRpcResponse::success(request_id, result_json))
            }
            Err(e) => Ok(JsonRpcResponse::error(
                request_id,
                McpError::internal_error(format!("Tool execution failed: {e}")),
            )),
        }
    }
}
impl std::fmt::Debug for SdkMcpServer {
    /// Formats the server's name, version, and the names of its registered tools.
    ///
    /// Only tool names are printed. They are sorted so the output does not
    /// depend on `HashMap` iteration order, which varies between runs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tool_names: Vec<&String> = self.tools.keys().collect();
        tool_names.sort_unstable();
        f.debug_struct("SdkMcpServer")
            .field("name", &self.name)
            .field("version", &self.version)
            .field("tools", &tool_names)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::ToolResult;
    use serde_json::json;

    /// Builds a tool that echoes `input["text"]`, or `"default"` when it is absent.
    fn create_test_tool(name: &str) -> SdkMcpTool {
        SdkMcpTool::new(
            name,
            format!("Test tool {name}"),
            json!({"type": "object"}),
            |input| {
                Box::pin(async move {
                    let text = input["text"].as_str().unwrap_or("default");
                    Ok(ToolResult::text(text))
                })
            },
        )
    }

    /// Builds a JSON-RPC 2.0 request with id `1`.
    fn rpc_request(method: &str, params: Option<serde_json::Value>) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(json!(1)),
            method: method.to_string(),
            params,
        }
    }

    #[test]
    fn test_server_creation() {
        let server = SdkMcpServer::new("test-server").version("1.0.0");
        assert_eq!(server.name(), "test-server");
        assert_eq!(server.server_version(), "1.0.0");
    }

    #[test]
    fn test_default_version() {
        let server = SdkMcpServer::new("test-server");
        assert_eq!(server.server_version(), "1.0.0");
    }

    #[test]
    fn test_tool_registration() {
        let server = SdkMcpServer::new("test")
            .tool(create_test_tool("tool1"))
            .tool(create_test_tool("tool2"));
        assert_eq!(server.list_tools().len(), 2);
        assert!(server.get_tool("tool1").is_some());
        assert!(server.get_tool("tool2").is_some());
        assert!(server.get_tool("tool3").is_none());
    }

    #[test]
    fn test_multiple_tools_registration() {
        let tools = vec![create_test_tool("a"), create_test_tool("b")];
        let server = SdkMcpServer::new("test").tools(tools);
        assert_eq!(server.list_tools().len(), 2);
    }

    #[test]
    fn test_duplicate_tool_name_replaces_previous() {
        let server = SdkMcpServer::new("test")
            .tool(create_test_tool("dup"))
            .tool(create_test_tool("dup"));
        assert_eq!(server.list_tools().len(), 1);
    }

    #[tokio::test]
    async fn test_tools_list_request() {
        let server = SdkMcpServer::new("test")
            .tool(create_test_tool("tool1"))
            .tool(create_test_tool("tool2"));
        let response = server
            .handle_request(rpc_request("tools/list", None))
            .await
            .unwrap();
        assert!(response.result.is_some());
        let result = response.result.unwrap();
        let tools = result["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn test_tools_call_request() {
        let server = SdkMcpServer::new("test").tool(create_test_tool("echo"));
        let params = json!({
            "name": "echo",
            "arguments": {
                "text": "hello"
            }
        });
        let response = server
            .handle_request(rpc_request("tools/call", Some(params)))
            .await
            .unwrap();
        assert!(response.result.is_some());
        let result = response.result.unwrap();
        assert_eq!(result["content"][0]["text"], "hello");
    }

    #[tokio::test]
    async fn test_tools_call_without_arguments() {
        let server = SdkMcpServer::new("test").tool(create_test_tool("echo"));
        let response = server
            .handle_request(rpc_request("tools/call", Some(json!({"name": "echo"}))))
            .await
            .unwrap();
        let result = response.result.expect("tool call should succeed");
        assert_eq!(result["content"][0]["text"], "default");
    }

    #[tokio::test]
    async fn test_tools_call_missing_params() {
        let server = SdkMcpServer::new("test").tool(create_test_tool("echo"));
        let response = server
            .handle_request(rpc_request("tools/call", None))
            .await
            .unwrap();
        assert!(response.error.is_some());
    }

    #[tokio::test]
    async fn test_tools_call_missing_name() {
        let server = SdkMcpServer::new("test").tool(create_test_tool("echo"));
        let response = server
            .handle_request(rpc_request("tools/call", Some(json!({"arguments": {}}))))
            .await
            .unwrap();
        assert!(response.error.is_some());
    }

    #[tokio::test]
    async fn test_unknown_method() {
        let server = SdkMcpServer::new("test");
        let response = server
            .handle_request(rpc_request("unknown/method", None))
            .await
            .unwrap();
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().code, -32601);
    }

    #[tokio::test]
    async fn test_tool_not_found() {
        let server = SdkMcpServer::new("test");
        let params = json!({
            "name": "nonexistent",
            "arguments": {}
        });
        let response = server
            .handle_request(rpc_request("tools/call", Some(params)))
            .await
            .unwrap();
        assert!(response.error.is_some());
    }
}