use async_trait::async_trait;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use crate::control::handlers::McpMessageHandler;
use crate::error::ClawError;
/// A single content item returned by a tool call.
///
/// Serialized with an internal `"type"` tag holding the lowercased variant name, e.g.
/// `{"type": "text", "text": "..."}` or
/// `{"type": "image", "data": "...", "mimeType": "image/png"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolContent {
    /// Plain text content.
    Text {
        text: String,
    },
    /// Image content. `data` is passed through unchanged
    /// (presumably base64-encoded bytes — TODO confirm with callers).
    Image {
        data: String,
        #[serde(rename = "mimeType")]
        mime_type: String,
    },
}

impl ToolContent {
    /// Builds a [`ToolContent::Text`] item.
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text { text: text.into() }
    }

    /// Builds a [`ToolContent::Image`] item from image data and its MIME type.
    pub fn image(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        Self::Image {
            data: data.into(),
            mime_type: mime_type.into(),
        }
    }
}
/// The outcome of a tool invocation, shaped like an MCP `tools/call` result.
///
/// On the wire the fields are `content` and `isError`; `isError` is omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolResult {
    /// Content items produced by the tool.
    pub content: Vec<ToolContent>,
    /// `Some(true)` flags the result as a tool-level failure.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub is_error: Option<bool>,
}

impl ToolResult {
    /// Wraps an arbitrary list of content items in a result with no error flag.
    pub fn new(content: Vec<ToolContent>) -> Self {
        ToolResult {
            is_error: None,
            content,
        }
    }

    /// A result holding a single text item and no error flag.
    pub fn text(text: impl Into<String>) -> Self {
        Self::new(vec![ToolContent::text(text)])
    }

    /// A result holding a single text item, flagged with `isError: true`.
    pub fn error(text: impl Into<String>) -> Self {
        ToolResult {
            is_error: Some(true),
            ..Self::text(text)
        }
    }
}
/// An asynchronous tool implementation that receives the raw JSON arguments of a call.
///
/// Handlers are stored as `Arc<dyn ToolHandler>`, hence the thread-safety supertraits.
#[async_trait]
pub trait ToolHandler: Sync + Send {
    /// Executes the tool with the JSON `args` supplied by the client.
    async fn call(&self, args: Value) -> Result<ToolResult, ClawError>;
}
/// Adapts a strongly typed async closure into a [`ToolHandler`].
///
/// The raw JSON arguments of each call are deserialized into `I` before `handler` runs.
///
/// `Fut` is not stored in any field; it is tied to the struct only through the
/// `F: Fn(I) -> Fut` bound, so that bound has to stay on the definition itself.
pub struct TypedToolHandler<I, F, Fut>
where
    I: DeserializeOwned + Send + Sync + 'static,
    F: Fn(I) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<ToolResult, ClawError>> + Send + 'static,
{
    handler: F,
    // Records the input type `I` without storing a value of it.
    _phantom: std::marker::PhantomData<I>,
}

impl<I, F, Fut> TypedToolHandler<I, F, Fut>
where
    I: DeserializeOwned + Send + Sync + 'static,
    F: Fn(I) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<ToolResult, ClawError>> + Send + 'static,
{
    /// Wraps `handler`, which is called with the deserialized arguments.
    pub fn new(handler: F) -> Self {
        Self {
            handler,
            _phantom: std::marker::PhantomData,
        }
    }
}

#[async_trait]
impl<I, F, Fut> ToolHandler for TypedToolHandler<I, F, Fut>
where
    I: DeserializeOwned + Send + Sync + 'static,
    F: Fn(I) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<ToolResult, ClawError>> + Send + 'static,
{
    /// Deserializes `args` into `I` and awaits the wrapped closure with it.
    ///
    /// # Errors
    /// Returns [`ClawError::ToolExecution`] naming the target type, the serde error and the
    /// raw arguments when `args` does not match `I`; otherwise returns the closure's result.
    async fn call(&self, args: Value) -> Result<ToolResult, ClawError> {
        // Deserialize from a borrow: `args` is still needed for the error message, and this
        // avoids cloning the whole JSON tree on every call just for the failure case.
        let typed_input = I::deserialize(&args).map_err(|e| {
            ClawError::ToolExecution(format!(
                "Failed to deserialize tool args into {}: {}. Raw args: {}",
                std::any::type_name::<I>(),
                e,
                args
            ))
        })?;
        (self.handler)(typed_input).await
    }
}
/// A tool served by an SDK MCP server: its advertised metadata plus the handler that runs it.
#[derive(Clone)]
pub struct SdkMcpTool {
    /// Tool name; `SdkMcpServerImpl::register_tool` stores the tool under this key.
    pub name: String,
    /// Description advertised in `tools/list`.
    pub description: String,
    /// Input schema advertised as `inputSchema` in `tools/list`.
    pub input_schema: Value,
    // Behind an `Arc`, so cloning a tool shares the handler instead of copying it.
    handler: Arc<dyn ToolHandler>,
}

// `dyn ToolHandler` is not `Debug`, so `#[derive(Debug)]` is impossible; the handler is
// left out of the output instead.
impl std::fmt::Debug for SdkMcpTool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SdkMcpTool")
            .field("name", &self.name)
            .field("description", &self.description)
            .field("input_schema", &self.input_schema)
            .finish_non_exhaustive()
    }
}

impl SdkMcpTool {
    /// Creates a tool from its metadata and the handler that executes it.
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        input_schema: Value,
        handler: Arc<dyn ToolHandler>,
    ) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            input_schema,
            handler,
        }
    }

    /// Returns this tool's entry for a `tools/list` response
    /// (`name`, `description`, `inputSchema`).
    pub fn to_tool_definition(&self) -> Value {
        json!({
            "name": self.name,
            "description": self.description,
            "inputSchema": self.input_schema,
        })
    }

    /// Runs the tool's handler with the given JSON arguments.
    pub async fn execute(&self, args: Value) -> Result<ToolResult, ClawError> {
        self.handler.call(args).await
    }
}
pub struct SdkMcpServerImpl {
pub name: String,
pub version: String,
tools: HashMap<String, SdkMcpTool>,
}
impl SdkMcpServerImpl {
pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
Self {
name: name.into(),
version: version.into(),
tools: HashMap::new(),
}
}
pub fn from_tools(
name: impl Into<String>,
version: impl Into<String>,
tools: Vec<SdkMcpTool>,
) -> Self {
let mut server = Self::new(name, version);
for tool in tools {
server.register_tool(tool);
}
server
}
pub fn register_tool(&mut self, tool: SdkMcpTool) {
self.tools.insert(tool.name.clone(), tool);
}
pub fn get_tool(&self, name: &str) -> Option<&SdkMcpTool> {
self.tools.get(name)
}
pub fn list_tools(&self) -> Vec<Value> {
self.tools
.values()
.map(|t| t.to_tool_definition())
.collect()
}
pub async fn handle_jsonrpc(&self, request: Value) -> Result<Value, ClawError> {
let method = request["method"]
.as_str()
.ok_or_else(|| ClawError::ControlError("Missing method field".to_string()))?;
match method {
"initialize" => self.handle_initialize(&request),
"notifications/initialized" => Ok(json_rpc_success(request["id"].clone(), json!({}))),
"tools/list" => self.handle_tools_list(&request),
"tools/call" => self.handle_tools_call(&request).await,
_ => Ok(json_rpc_error(
request["id"].clone(),
-32601,
format!("Method not found: {}", method),
)),
}
}
fn handle_initialize(&self, request: &Value) -> Result<Value, ClawError> {
Ok(json_rpc_success(
request["id"].clone(),
json!({
"protocolVersion": "2025-11-25",
"capabilities": {
"tools": {}
},
"serverInfo": {
"name": self.name,
"version": self.version
}
}),
))
}
fn handle_tools_list(&self, request: &Value) -> Result<Value, ClawError> {
Ok(json_rpc_success(
request["id"].clone(),
json!({
"tools": self.list_tools()
}),
))
}
async fn handle_tools_call(&self, request: &Value) -> Result<Value, ClawError> {
let params = request["params"]
.as_object()
.ok_or_else(|| ClawError::ControlError("Missing params".to_string()))?;
let name = params["name"]
.as_str()
.ok_or_else(|| ClawError::ControlError("Missing tool name".to_string()))?;
let arguments = params
.get("arguments")
.cloned()
.unwrap_or_else(|| json!({}));
let tool = match self.get_tool(name) {
Some(t) => t,
None => {
return Ok(json_rpc_error(
request["id"].clone(),
-32602,
format!("Tool not found: {}", name),
));
}
};
match tool.execute(arguments).await {
Ok(result) => Ok(json_rpc_success(request["id"].clone(), result)),
Err(e) => Ok(json_rpc_success(
request["id"].clone(),
ToolResult::error(format!("Tool execution failed: {}", e)),
)),
}
}
}
/// Builds an [`SdkMcpServerImpl`] holding `tools` and wraps it in an `Arc` for sharing.
///
/// Tools with duplicate names follow [`SdkMcpServerImpl::from_tools`]: the last one wins.
pub fn create_sdk_mcp_server(
    name: impl Into<String>,
    version: impl Into<String>,
    tools: Vec<SdkMcpTool>,
) -> Arc<SdkMcpServerImpl> {
    let server = SdkMcpServerImpl::from_tools(name, version, tools);
    Arc::new(server)
}
/// A set of SDK MCP servers addressed by their `name`.
#[derive(Default)]
pub struct SdkMcpServerRegistry {
    servers: HashMap<String, SdkMcpServerImpl>,
}

impl SdkMcpServerRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `server` under its name, replacing any server already registered with that name.
    pub fn register(&mut self, server: SdkMcpServerImpl) {
        let key = server.name.clone();
        self.servers.insert(key, server);
    }

    /// Looks up a registered server by name.
    pub fn get(&self, name: &str) -> Option<&SdkMcpServerImpl> {
        self.servers.get(name)
    }
}
#[async_trait]
impl McpMessageHandler for SdkMcpServerRegistry {
    /// Routes a JSON-RPC `message` to the server registered as `server_name`.
    ///
    /// # Errors
    /// Returns [`ClawError::ControlError`] if no such server is registered; otherwise
    /// returns whatever that server's `handle_jsonrpc` returns.
    async fn handle(&self, server_name: &str, message: Value) -> Result<Value, ClawError> {
        match self.get(server_name) {
            Some(server) => server.handle_jsonrpc(message).await,
            None => Err(ClawError::ControlError(format!(
                "Server not found: {}",
                server_name
            ))),
        }
    }
}
/// Wraps `result` in a JSON-RPC 2.0 success envelope that echoes the request `id`.
fn json_rpc_success(id: Value, result: impl Serialize) -> Value {
    json!({ "jsonrpc": "2.0", "id": id, "result": result })
}
fn json_rpc_error(id: Value, code: i32, message: String) -> Value {
json!({
"jsonrpc": "2.0",
"id": id,
"error": {
"code": code,
"message": message
}
})
}
// Unit tests for tool content/results, typed handlers, the SDK MCP server's JSON-RPC
// dispatch, and the server registry.
#[cfg(test)]
mod tests {
use super::*;
// Test double: always succeeds with a single text item containing `response`.
struct MockHandler {
response: String,
}
#[async_trait]
impl ToolHandler for MockHandler {
async fn call(&self, _args: Value) -> Result<ToolResult, ClawError> {
Ok(ToolResult::text(&self.response))
}
}
// Test double: always fails, to exercise the tool-error path.
struct ErrorHandler;
#[async_trait]
impl ToolHandler for ErrorHandler {
async fn call(&self, _args: Value) -> Result<ToolResult, ClawError> {
Err(ClawError::ControlError("Handler error".to_string()))
}
}
// --- ToolContent / ToolResult constructors ---
#[test]
fn test_tool_content_text() {
let content = ToolContent::text("Hello");
match content {
ToolContent::Text { text } => assert_eq!(text, "Hello"),
_ => panic!("Expected Text variant"),
}
}
#[test]
fn test_tool_content_image() {
let content = ToolContent::image("data123", "image/png");
match content {
ToolContent::Image { data, mime_type } => {
assert_eq!(data, "data123");
assert_eq!(mime_type, "image/png");
}
_ => panic!("Expected Image variant"),
}
}
#[test]
fn test_tool_result_text() {
let result = ToolResult::text("Success");
assert_eq!(result.content.len(), 1);
assert!(result.is_error.is_none());
}
#[test]
fn test_tool_result_error() {
let result = ToolResult::error("Failed");
assert_eq!(result.content.len(), 1);
assert_eq!(result.is_error, Some(true));
}
#[test]
fn test_tool_result_new() {
let result = ToolResult::new(vec![
ToolContent::text("Text"),
ToolContent::image("data", "image/png"),
]);
assert_eq!(result.content.len(), 2);
assert!(result.is_error.is_none());
}
// --- Handlers and SdkMcpTool ---
#[tokio::test]
async fn test_tool_handler() {
let handler = MockHandler {
response: "Test".to_string(),
};
let result = handler.call(json!({})).await.unwrap();
match &result.content[0] {
ToolContent::Text { text } => assert_eq!(text, "Test"),
_ => panic!("Expected text content"),
}
}
#[test]
fn test_sdk_mcp_tool_new() {
let handler = Arc::new(MockHandler {
response: "Test".to_string(),
});
let tool = SdkMcpTool::new(
"test_tool",
"Test description",
json!({"type": "object"}),
handler,
);
assert_eq!(tool.name, "test_tool");
assert_eq!(tool.description, "Test description");
}
#[test]
fn test_sdk_mcp_tool_to_definition() {
let handler = Arc::new(MockHandler {
response: "Test".to_string(),
});
let tool = SdkMcpTool::new(
"test_tool",
"Test description",
json!({"type": "object"}),
handler,
);
let def = tool.to_tool_definition();
assert_eq!(def["name"], "test_tool");
assert_eq!(def["description"], "Test description");
assert_eq!(def["inputSchema"]["type"], "object");
}
#[tokio::test]
async fn test_sdk_mcp_tool_execute() {
let handler = Arc::new(MockHandler {
response: "Executed".to_string(),
});
let tool = SdkMcpTool::new("test_tool", "Test", json!({"type": "object"}), handler);
let result = tool.execute(json!({})).await.unwrap();
match &result.content[0] {
ToolContent::Text { text } => assert_eq!(text, "Executed"),
_ => panic!("Expected text content"),
}
}
// --- SdkMcpServerImpl construction and tool registration ---
#[test]
fn test_sdk_mcp_server_new() {
let server = SdkMcpServerImpl::new("test_server", "1.0.0");
assert_eq!(server.name, "test_server");
assert_eq!(server.version, "1.0.0");
assert_eq!(server.tools.len(), 0);
}
#[test]
fn test_sdk_mcp_server_register_tool() {
let mut server = SdkMcpServerImpl::new("test_server", "1.0.0");
let handler = Arc::new(MockHandler {
response: "Test".to_string(),
});
let tool = SdkMcpTool::new("tool1", "Test", json!({"type": "object"}), handler);
server.register_tool(tool);
assert_eq!(server.tools.len(), 1);
assert!(server.get_tool("tool1").is_some());
}
#[test]
fn test_sdk_mcp_server_list_tools() {
let mut server = SdkMcpServerImpl::new("test_server", "1.0.0");
let handler = Arc::new(MockHandler {
response: "Test".to_string(),
});
server.register_tool(SdkMcpTool::new(
"tool1",
"Test 1",
json!({"type": "object"}),
handler.clone(),
));
server.register_tool(SdkMcpTool::new(
"tool2",
"Test 2",
json!({"type": "object"}),
handler,
));
let tools = server.list_tools();
assert_eq!(tools.len(), 2);
}
// --- JSON-RPC dispatch via handle_jsonrpc ---
#[tokio::test]
async fn test_handle_initialize() {
let server = SdkMcpServerImpl::new("test_server", "1.0.0");
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
});
let response = server.handle_jsonrpc(request).await.unwrap();
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], 1);
assert_eq!(response["result"]["serverInfo"]["name"], "test_server");
}
#[tokio::test]
async fn test_handle_tools_list() {
let mut server = SdkMcpServerImpl::new("test_server", "1.0.0");
let handler = Arc::new(MockHandler {
response: "Test".to_string(),
});
server.register_tool(SdkMcpTool::new(
"tool1",
"Test",
json!({"type": "object"}),
handler,
));
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/list"
});
let response = server.handle_jsonrpc(request).await.unwrap();
assert_eq!(response["result"]["tools"].as_array().unwrap().len(), 1);
}
#[tokio::test]
async fn test_handle_tools_call() {
let mut server = SdkMcpServerImpl::new("test_server", "1.0.0");
let handler = Arc::new(MockHandler {
response: "Result".to_string(),
});
server.register_tool(SdkMcpTool::new(
"tool1",
"Test",
json!({"type": "object"}),
handler,
));
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": "tool1",
"arguments": {}
}
});
let response = server.handle_jsonrpc(request).await.unwrap();
assert_eq!(response["result"]["content"][0]["text"], "Result");
}
// An unknown tool name yields a JSON-RPC error response with code -32602.
#[tokio::test]
async fn test_handle_tools_call_not_found() {
let server = SdkMcpServerImpl::new("test_server", "1.0.0");
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": "nonexistent",
"arguments": {}
}
});
let response = server.handle_jsonrpc(request).await.unwrap();
assert!(response["error"].is_object());
assert_eq!(response["error"]["code"], -32602);
}
// A failing handler is reported in-band (`isError: true` result), not as a JSON-RPC error.
#[tokio::test]
async fn test_handle_tools_call_handler_error() {
let mut server = SdkMcpServerImpl::new("test_server", "1.0.0");
server.register_tool(SdkMcpTool::new(
"error_tool",
"Test",
json!({"type": "object"}),
Arc::new(ErrorHandler),
));
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": "error_tool",
"arguments": {}
}
});
let response = server.handle_jsonrpc(request).await.unwrap();
assert!(response["error"].is_null());
assert!(response["result"].is_object());
assert_eq!(response["result"]["isError"], true);
}
#[tokio::test]
async fn test_handle_unknown_method() {
let server = SdkMcpServerImpl::new("test_server", "1.0.0");
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "unknown/method"
});
let response = server.handle_jsonrpc(request).await.unwrap();
assert!(response["error"].is_object());
assert_eq!(response["error"]["code"], -32601);
}
// --- SdkMcpServerRegistry ---
#[test]
fn test_registry_new() {
let registry = SdkMcpServerRegistry::new();
assert_eq!(registry.servers.len(), 0);
}
#[test]
fn test_registry_register() {
let mut registry = SdkMcpServerRegistry::new();
let server = SdkMcpServerImpl::new("test_server", "1.0.0");
registry.register(server);
assert_eq!(registry.servers.len(), 1);
assert!(registry.get("test_server").is_some());
}
#[tokio::test]
async fn test_registry_handle() {
let mut registry = SdkMcpServerRegistry::new();
let server = SdkMcpServerImpl::new("test_server", "1.0.0");
registry.register(server);
let message = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
});
let response = registry.handle("test_server", message).await.unwrap();
assert_eq!(response["result"]["serverInfo"]["name"], "test_server");
}
// Routing to an unregistered server name is an `Err`, not a JSON-RPC error response.
#[tokio::test]
async fn test_registry_handle_server_not_found() {
let registry = SdkMcpServerRegistry::new();
let message = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize"
});
let result = registry.handle("nonexistent", message).await;
assert!(result.is_err());
}
// --- JSON-RPC envelope helpers ---
#[test]
fn test_json_rpc_success() {
let response = json_rpc_success(json!(1), json!({"status": "ok"}));
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], 1);
assert_eq!(response["result"]["status"], "ok");
}
#[test]
fn test_json_rpc_error() {
let response = json_rpc_error(json!(1), -32601, "Method not found".to_string());
assert_eq!(response["jsonrpc"], "2.0");
assert_eq!(response["id"], 1);
assert_eq!(response["error"]["code"], -32601);
assert_eq!(response["error"]["message"], "Method not found");
}
// --- TypedToolHandler: JSON args are deserialized into the closure's input type ---
#[tokio::test]
async fn test_typed_tool_handler_success() {
use serde::Deserialize;
#[derive(Deserialize)]
struct AddInput {
a: f64,
b: f64,
}
let handler = TypedToolHandler::new(|input: AddInput| async move {
Ok(ToolResult::text(format!("{}", input.a + input.b)))
});
let result = handler.call(json!({"a": 3.0, "b": 4.0})).await.unwrap();
match &result.content[0] {
ToolContent::Text { text } => assert_eq!(text, "7"),
_ => panic!("Expected text"),
}
}
#[tokio::test]
async fn test_typed_tool_handler_deserialization_error() {
use serde::Deserialize;
#[derive(Deserialize)]
#[allow(dead_code)]
struct StrictInput {
count: u32,
}
let handler =
TypedToolHandler::new(|_input: StrictInput| async move { Ok(ToolResult::text("ok")) });
let result = handler.call(json!({"wrong_field": 1})).await;
assert!(result.is_err(), "Expected error for bad input");
let err = result.unwrap_err().to_string();
assert!(
err.contains("Failed to deserialize tool args"),
"Error should mention deserialization: {}",
err
);
}
// A missing `Option` field deserializes as `None`.
#[tokio::test]
async fn test_typed_tool_handler_optional_field() {
use serde::Deserialize;
#[derive(Deserialize)]
struct Input {
required: String,
optional: Option<String>,
}
let handler = TypedToolHandler::new(|input: Input| async move {
let suffix = input.optional.unwrap_or_else(|| "none".to_string());
Ok(ToolResult::text(format!("{}:{}", input.required, suffix)))
});
let r1 = handler
.call(json!({"required": "hello", "optional": "world"}))
.await
.unwrap();
match &r1.content[0] {
ToolContent::Text { text } => assert_eq!(text, "hello:world"),
_ => panic!(),
}
let r2 = handler.call(json!({"required": "hello"})).await.unwrap();
match &r2.content[0] {
ToolContent::Text { text } => assert_eq!(text, "hello:none"),
_ => panic!(),
}
}
// --- Server construction helpers: from_tools / create_sdk_mcp_server ---
#[test]
fn test_sdk_mcp_server_from_tools() {
let handler = Arc::new(MockHandler {
response: "ok".to_string(),
});
let tools = vec![
SdkMcpTool::new("t1", "Tool 1", json!({"type": "object"}), handler.clone()),
SdkMcpTool::new("t2", "Tool 2", json!({"type": "object"}), handler),
];
let server = SdkMcpServerImpl::from_tools("my_server", "2.0.0", tools);
assert_eq!(server.name, "my_server");
assert_eq!(server.version, "2.0.0");
assert_eq!(server.list_tools().len(), 2);
assert!(server.get_tool("t1").is_some());
assert!(server.get_tool("t2").is_some());
}
#[test]
fn test_create_sdk_mcp_server() {
let handler = Arc::new(MockHandler {
response: "result".to_string(),
});
let server = create_sdk_mcp_server(
"test_server",
"1.0.0",
vec![SdkMcpTool::new(
"my_tool",
"Test",
json!({"type": "object"}),
handler,
)],
);
assert_eq!(server.name, "test_server");
assert_eq!(server.list_tools().len(), 1);
}
#[test]
fn test_create_sdk_mcp_server_empty() {
let server = create_sdk_mcp_server("empty_server", "0.1.0", vec![]);
assert_eq!(server.name, "empty_server");
assert_eq!(server.list_tools().len(), 0);
}
#[tokio::test]
async fn test_create_sdk_mcp_server_can_handle_requests() {
let handler = Arc::new(MockHandler {
response: "hello".to_string(),
});
let server = create_sdk_mcp_server(
"hello_server",
"1.0.0",
vec![SdkMcpTool::new(
"greet",
"Greet",
json!({"type": "object"}),
handler,
)],
);
let request = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {"name": "greet", "arguments": {}}
});
let response = server.handle_jsonrpc(request).await.unwrap();
assert_eq!(response["result"]["content"][0]["text"], "hello");
}
}