use async_trait::async_trait;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use thiserror::Error;
/// Errors that can occur while looking up or executing an [`SdkTool`].
#[derive(Debug, Error)]
pub enum SdkToolError {
    /// The input was rejected before the tool body ran. Examples in this module:
    /// JSON that cannot be deserialized into the tool's input type, or a
    /// request for a tool name the server does not know.
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    /// The tool accepted its input but failed while running.
    #[error("Execution failed: {0}")]
    ExecutionFailed(String),

    /// A `serde_json` error, converted automatically through `From`
    /// (so `?` on a `serde_json::Result` yields this variant).
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
}
/// A named, self-describing tool that maps a JSON input to a JSON output.
///
/// Implementors must be `Send + Sync`. Servers in this module hold tools as
/// `Arc<dyn SdkTool>`.
#[async_trait]
pub trait SdkTool: Send + Sync {
    /// Name under which the tool is registered and looked up.
    fn name(&self) -> &str;

    /// Free-form description of the tool.
    fn description(&self) -> &str;

    /// JSON value describing the input the tool accepts
    /// (a JSON-Schema-style object for [`FunctionTool`]).
    fn input_schema(&self) -> Value;

    /// Runs the tool on `input`.
    ///
    /// # Errors
    ///
    /// Returns an [`SdkToolError`] when the input is rejected, the tool fails,
    /// or JSON conversion fails.
    async fn execute(&self, input: Value) -> Result<Value, SdkToolError>;
}
/// An [`SdkTool`] backed by an async handler function or closure.
///
/// On execution the JSON input is deserialized into `I` and passed to
/// `handler`. The `O` value the handler resolves to is then serialized back
/// into JSON.
///
/// Type parameters:
/// - `F`: the handler, called as `handler(input)`.
/// - `Fut`: the future returned by the handler.
/// - `I`: the deserialized input type.
/// - `O`: the serializable output type.
pub struct FunctionTool<F, Fut, I, O> {
    name: String,
    description: String,
    handler: F,
    // `Fut`, `I` and `O` are never stored, so they are tracked through a
    // function-pointer marker. `fn() -> T` is always `Send + Sync`, which makes
    // the struct's auto traits depend on `F` alone. A plain
    // `PhantomData<(Fut, I, O)>` would make `FunctionTool` `Sync` only when the
    // handler's *future* is `Sync`. That would reject ordinary async handlers
    // whose futures are merely `Send`.
    _phantom: PhantomData<fn() -> (Fut, I, O)>,
}

impl<F, Fut, I, O> FunctionTool<F, Fut, I, O> {
    /// Creates a tool with the given `name` and `description` that forwards
    /// every execution to `handler`.
    pub fn new(name: String, description: String, handler: F) -> Self {
        Self {
            name,
            description,
            handler,
            _phantom: PhantomData,
        }
    }
}

#[async_trait]
impl<F, Fut, I, O> SdkTool for FunctionTool<F, Fut, I, O>
where
    F: Fn(I) -> Fut + Send + Sync,
    // `Send` is sufficient. The handler's future is owned and awaited inside
    // the `Send` future that `async_trait` boxes, and no reference to it is
    // ever handed out.
    Fut: Future<Output = Result<O, SdkToolError>> + Send,
    I: DeserializeOwned + Send,
    O: Serialize + Send,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> &str {
        &self.description
    }

    /// Returns a permissive schema that accepts any JSON object.
    ///
    /// The shape of `I` is not introspected. Input that does not match `I` is
    /// rejected later, when `execute` deserializes it.
    fn input_schema(&self) -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {},
            "additionalProperties": true
        })
    }

    /// Deserializes `input`, runs the handler, and serializes its output.
    ///
    /// # Errors
    ///
    /// - [`SdkToolError::InvalidInput`] if `input` cannot be deserialized into `I`.
    /// - Any error returned by the handler, propagated unchanged.
    /// - [`SdkToolError::Json`] if the handler's output cannot be serialized.
    async fn execute(&self, input: Value) -> Result<Value, SdkToolError> {
        let typed_input: I = serde_json::from_value(input).map_err(|e| {
            SdkToolError::InvalidInput(format!("Failed to deserialize input: {}", e))
        })?;
        let output = (self.handler)(typed_input).await?;
        let json_output = serde_json::to_value(output)?;
        Ok(json_output)
    }
}

/// Collects tools and produces an [`SdkMcpServer`].
///
/// Tools are keyed by name. Registering a second tool under an existing name
/// replaces the first.
#[must_use = "a builder does nothing until `build` is called"]
pub struct SdkMcpServerBuilder {
    name: String,
    tools: HashMap<String, Arc<dyn SdkTool>>,
}

impl SdkMcpServerBuilder {
    /// Starts a builder for a server called `name` with no tools registered.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            tools: HashMap::new(),
        }
    }

    /// Registers an async `handler` as a tool called `name`.
    ///
    /// The handler receives the tool input deserialized into `I` and resolves
    /// to a serializable `O`. Its future only needs to be `Send`, not `Sync`.
    /// A previously registered tool with the same name is replaced.
    pub fn tool<F, Fut, I, O>(mut self, name: &str, description: &str, handler: F) -> Self
    where
        F: Fn(I) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<O, SdkToolError>> + Send + 'static,
        I: DeserializeOwned + Send + 'static,
        O: Serialize + Send + 'static,
    {
        let tool = FunctionTool::new(name.to_string(), description.to_string(), handler);
        self.tools.insert(name.to_string(), Arc::new(tool));
        self
    }

    /// Registers a pre-built tool under the name reported by [`SdkTool::name`].
    ///
    /// A previously registered tool with the same name is replaced.
    pub fn add_tool(mut self, tool: Arc<dyn SdkTool>) -> Self {
        let name = tool.name().to_string();
        self.tools.insert(name, tool);
        self
    }

    /// Consumes the builder and returns the finished server.
    pub fn build(self) -> SdkMcpServer {
        SdkMcpServer {
            name: self.name,
            tools: self.tools,
        }
    }
}
/// An in-process server that holds a fixed set of named tools.
///
/// Cloning is shallow for the tools themselves: the clone shares the same
/// `Arc<dyn SdkTool>` instances.
#[derive(Clone)]
pub struct SdkMcpServer {
    // Server name, reported by `name()` and used in "not found" errors.
    name: String,
    // Registered tools, keyed by the name they were registered under.
    tools: HashMap<String, Arc<dyn SdkTool>>,
}
impl std::fmt::Debug for SdkMcpServer {
    // Tools are `dyn SdkTool` objects without a `Debug` bound, so only the
    // number of registered tools is printed next to the server name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tool_count = self.tools.len();
        let mut out = f.debug_struct("SdkMcpServer");
        out.field("name", &self.name);
        out.field("tool_count", &tool_count);
        out.finish()
    }
}
impl SdkMcpServer {
    /// Returns the server name given to the builder.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Looks up a registered tool by name.
    pub fn get_tool(&self, name: &str) -> Option<&Arc<dyn SdkTool>> {
        self.tools.get(name)
    }

    /// Returns every registered tool, sorted by the name it was registered under.
    ///
    /// Tools live in a `HashMap`, whose iteration order is unspecified and can
    /// change between runs. Sorting keeps listings stable and reproducible.
    pub fn list_tools(&self) -> Vec<&Arc<dyn SdkTool>> {
        let mut entries: Vec<(&String, &Arc<dyn SdkTool>)> = self.tools.iter().collect();
        // Map keys are unique, so an unstable sort is still deterministic.
        entries.sort_unstable_by_key(|&(name, _)| name);
        entries.into_iter().map(|(_, tool)| tool).collect()
    }

    /// Executes the tool called `name` with the given JSON `input`.
    ///
    /// # Errors
    ///
    /// Returns [`SdkToolError::InvalidInput`] if no tool with that name is
    /// registered. Otherwise returns whatever the tool's `execute` returns.
    pub async fn execute_tool(&self, name: &str, input: Value) -> Result<Value, SdkToolError> {
        let tool = self.get_tool(name).ok_or_else(|| {
            SdkToolError::InvalidInput(format!(
                "Tool '{}' not found in server '{}'",
                name, self.name
            ))
        })?;
        tool.execute(input).await
    }

    /// Returns `true` if a tool with this name is registered.
    pub fn has_tool(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    /// Returns the number of registered tools.
    pub fn tool_count(&self) -> usize {
        self.tools.len()
    }
}
// Unit tests for direct `FunctionTool` execution, server construction via
// `SdkMcpServerBuilder`, dispatch through `SdkMcpServer::execute_tool`, and the
// error variants returned for missing tools, undeserializable input, and
// handler failures.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    // Input shape shared by the test tools: `{"value": <i32>}`.
    #[derive(Deserialize)]
    struct TestInput {
        value: i32,
    }

    // Output shape shared by the test tools; serializes as `{"result": <i32>}`.
    #[derive(Serialize, PartialEq, Debug)]
    struct TestOutput {
        result: i32,
    }

    // A `FunctionTool` used without a server deserializes its JSON input, runs
    // the handler, and serializes the handler's output.
    #[tokio::test]
    async fn test_function_tool_execution() {
        let tool = FunctionTool::new(
            "double".to_string(),
            "Double a number".to_string(),
            |input: TestInput| async move {
                Ok(TestOutput {
                    result: input.value * 2,
                })
            },
        );
        let input = serde_json::json!({"value": 21});
        let output = tool.execute(input).await.expect("execution failed");
        assert_eq!(output, serde_json::json!({"result": 42}));
    }

    // The builder records the server name and registers each `.tool(...)`
    // under its given name; unregistered names are reported as absent.
    #[tokio::test]
    async fn test_sdk_server_builder() {
        let server = SdkMcpServerBuilder::new("test-server")
            .tool("add", "Add two numbers", |input: TestInput| async move {
                Ok(TestOutput {
                    result: input.value + 10,
                })
            })
            .tool(
                "multiply",
                "Multiply by two",
                |input: TestInput| async move {
                    Ok(TestOutput {
                        result: input.value * 2,
                    })
                },
            )
            .build();
        assert_eq!(server.name(), "test-server");
        assert_eq!(server.tool_count(), 2);
        assert!(server.has_tool("add"));
        assert!(server.has_tool("multiply"));
        assert!(!server.has_tool("nonexistent"));
    }

    // `execute_tool` dispatches by name and returns the tool's JSON output.
    #[tokio::test]
    async fn test_server_execute_tool() {
        let server = SdkMcpServerBuilder::new("calculator")
            .tool("double", "Double a number", |input: TestInput| async move {
                Ok(TestOutput {
                    result: input.value * 2,
                })
            })
            .build();
        let result = server
            .execute_tool("double", serde_json::json!({"value": 5}))
            .await
            .expect("execution failed");
        assert_eq!(result, serde_json::json!({"result": 10}));
    }

    // Calling an unregistered tool yields `InvalidInput` with a "not found"
    // message.
    #[tokio::test]
    async fn test_tool_not_found() {
        let server = SdkMcpServerBuilder::new("empty").build();
        let result = server.execute_tool("missing", serde_json::json!({})).await;
        assert!(result.is_err());
        match result {
            Err(SdkToolError::InvalidInput(msg)) => {
                assert!(msg.contains("not found"));
            }
            _ => panic!("Expected InvalidInput error"),
        }
    }

    // `{}` lacks the required `value` field, so deserialization into
    // `TestInput` fails and surfaces as `InvalidInput`.
    #[tokio::test]
    async fn test_invalid_input_deserialization() {
        let server = SdkMcpServerBuilder::new("test")
            .tool("strict", "Strict input", |input: TestInput| async move {
                Ok(TestOutput {
                    result: input.value,
                })
            })
            .build();
        let result = server.execute_tool("strict", serde_json::json!({})).await;
        assert!(result.is_err());
        match result {
            Err(SdkToolError::InvalidInput(_)) => {}
            _ => panic!("Expected InvalidInput error"),
        }
    }

    // An error returned by the handler is propagated to the caller unchanged.
    #[tokio::test]
    async fn test_tool_execution_error() {
        let server = SdkMcpServerBuilder::new("test")
            .tool("failing", "Always fails", |_input: TestInput| async move {
                Err::<TestOutput, _>(SdkToolError::ExecutionFailed(
                    "intentional failure".to_string(),
                ))
            })
            .build();
        let result = server
            .execute_tool("failing", serde_json::json!({"value": 1}))
            .await;
        assert!(result.is_err());
        match result {
            Err(SdkToolError::ExecutionFailed(msg)) => {
                assert_eq!(msg, "intentional failure");
            }
            _ => panic!("Expected ExecutionFailed error"),
        }
    }

    // A cloned server keeps the same name and registered tools.
    #[tokio::test]
    async fn test_server_clone() {
        let server = SdkMcpServerBuilder::new("original")
            .tool("tool1", "First tool", |input: TestInput| async move {
                Ok(TestOutput {
                    result: input.value,
                })
            })
            .build();
        let cloned = server.clone();
        assert_eq!(server.name(), cloned.name());
        assert_eq!(server.tool_count(), cloned.tool_count());
        assert!(cloned.has_tool("tool1"));
    }

    // `list_tools` returns every registered tool; only membership is checked,
    // not order.
    #[tokio::test]
    async fn test_list_tools() {
        let server = SdkMcpServerBuilder::new("test")
            .tool("tool1", "First", |input: TestInput| async move {
                Ok(TestOutput {
                    result: input.value,
                })
            })
            .tool("tool2", "Second", |input: TestInput| async move {
                Ok(TestOutput {
                    result: input.value,
                })
            })
            .build();
        let tools = server.list_tools();
        assert_eq!(tools.len(), 2);
        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"tool1"));
        assert!(names.contains(&"tool2"));
    }
}