use async_trait::async_trait;
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use crate::errors::Result;
/// MCP server configuration in one of three forms: none, an inline map, or a path.
#[derive(Clone, Default)]
pub enum McpServers {
    /// No MCP servers configured; this is the `Default` value.
    #[default]
    Empty,
    /// Servers given inline as a map of string key to [`McpServerConfig`]
    /// (presumably keyed by server name — confirm at the consumer).
    Dict(HashMap<String, McpServerConfig>),
    /// A filesystem path; NOTE(review): presumably points to a config file the consumer
    /// reads — confirm at the use site.
    Path(PathBuf),
}
/// Configuration of a single MCP server, with one variant per transport kind.
///
/// Not `Debug`: the `Sdk` variant holds an `Arc<dyn SdkMcpServer>`, and that trait does not
/// require `Debug`.
#[derive(Clone)]
pub enum McpServerConfig {
    /// Server described by a command line (stdio transport).
    Stdio(McpStdioServerConfig),
    /// Remote server identified by a URL (SSE transport).
    Sse(McpSseServerConfig),
    /// Remote server identified by a URL (HTTP transport).
    Http(McpHttpServerConfig),
    /// In-process server implemented in Rust through [`SdkMcpServer`].
    Sdk(McpSdkServerConfig),
}
/// Settings for a stdio MCP server: a command plus optional arguments and environment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpStdioServerConfig {
    /// Command to run (presumably resolved by the launcher — confirm whether `PATH` lookup applies).
    pub command: String,
    /// Extra command-line arguments; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub args: Option<Vec<String>>,
    /// Environment variables (name → value); omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<HashMap<String, String>>,
}
/// Settings for an MCP server reached over Server-Sent Events.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpSseServerConfig {
    /// Endpoint URL of the server.
    pub url: String,
    /// Extra headers (name → value); omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}
/// Settings for an MCP server reached over HTTP.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpHttpServerConfig {
    /// Endpoint URL of the server.
    pub url: String,
    /// Extra headers (name → value); omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}
/// An in-process MCP server: its name plus the shared object that answers requests.
#[derive(Clone)]
pub struct McpSdkServerConfig {
    /// Name the server is registered under.
    pub name: String,
    /// Request handler; held in an `Arc`, so clones of this config share one server instance.
    pub instance: Arc<dyn SdkMcpServer>,
}
/// An MCP server implemented in-process.
///
/// Implementors must be `Send + Sync` so they can be shared behind an `Arc`.
#[async_trait]
pub trait SdkMcpServer: Send + Sync {
    /// Handles one incoming MCP message and returns the JSON response payload, or an error.
    ///
    /// NOTE(review): `message` is presumably a JSON-RPC request object with a `method`
    /// field — confirm with the transport that calls this.
    async fn handle_message(&self, message: serde_json::Value) -> Result<serde_json::Value>;
}
/// Executes a single tool invocation.
pub trait ToolHandler: Send + Sync {
    /// Runs the tool with its JSON `args`.
    ///
    /// The returned future is a `BoxFuture<'static, _>`, i.e. boxed, `Send`, and not
    /// borrowing `self`, so it can outlive the call and be driven on another thread.
    fn handle(&self, args: serde_json::Value) -> BoxFuture<'static, Result<ToolResult>>;
}
/// Output of a tool call: a list of content items plus an error flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Content items produced by the tool.
    pub content: Vec<ToolResultContent>,
    /// Whether the tool reported failure; `false` when absent during deserialization.
    ///
    /// NOTE(review): serialized as `is_error` (no rename), whereas MCP tool results use
    /// `isError` on the wire — confirm whether this struct's serde form is meant to match MCP.
    #[serde(default)]
    pub is_error: bool,
}
/// One item of tool output, tagged on the wire by a `"type"` field (`"text"` or `"image"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolResultContent {
    /// Plain text output.
    Text {
        /// The text itself.
        text: String,
    },
    /// Image output.
    Image {
        /// Image payload as a string (presumably base64-encoded — confirm with producers).
        data: String,
        /// Media type of the image, such as `image/png`.
        ///
        /// Serialized as `mimeType`, the field name MCP image content requires; the older
        /// `mime_type` spelling is still accepted when deserializing, for compatibility.
        #[serde(rename = "mimeType", alias = "mime_type")]
        mime_type: String,
    },
}
/// A tool exposed by an in-process MCP server.
pub struct SdkMcpTool {
    /// Tool name, as advertised to clients and used to route `tools/call` requests.
    pub name: String,
    /// Human-readable description advertised to clients.
    pub description: String,
    /// Schema of the tool's arguments (presumably a JSON Schema object — confirm with callers).
    pub input_schema: serde_json::Value,
    /// Handler invoked when the tool is called.
    pub handler: Arc<dyn ToolHandler>,
}
pub fn create_sdk_mcp_server(
name: impl Into<String>,
version: impl Into<String>,
tools: Vec<SdkMcpTool>,
) -> McpSdkServerConfig {
let server = DefaultSdkMcpServer {
name: name.into(),
version: version.into(),
tools: tools.into_iter().map(|t| (t.name.clone(), t)).collect(),
};
McpSdkServerConfig {
name: server.name.clone(),
instance: Arc::new(server),
}
}
/// Built-in [`SdkMcpServer`] implementation produced by `create_sdk_mcp_server`.
struct DefaultSdkMcpServer {
    /// Server name reported in the `initialize` response.
    name: String,
    /// Server version reported in the `initialize` response.
    version: String,
    /// Registered tools, keyed by tool name.
    tools: HashMap<String, SdkMcpTool>,
}
#[async_trait]
impl SdkMcpServer for DefaultSdkMcpServer {
async fn handle_message(&self, message: serde_json::Value) -> Result<serde_json::Value> {
let method = message["method"]
.as_str()
.ok_or_else(|| crate::errors::ClaudeError::Transport("Missing method".to_string()))?;
match method {
"initialize" => {
Ok(serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"serverInfo": {
"name": self.name,
"version": self.version
}
}))
},
"tools/list" => {
let tools: Vec<_> = self
.tools
.values()
.map(|t| {
serde_json::json!({
"name": t.name,
"description": t.description,
"inputSchema": t.input_schema
})
})
.collect();
Ok(serde_json::json!({
"tools": tools
}))
},
"tools/call" => {
let params = &message["params"];
let tool_name = params["name"].as_str().ok_or_else(|| {
crate::errors::ClaudeError::Transport("Missing tool name".to_string())
})?;
let arguments = params["arguments"].clone();
let tool = self.tools.get(tool_name).ok_or_else(|| {
crate::errors::ClaudeError::Transport(format!("Tool not found: {}", tool_name))
})?;
let result = tool.handler.handle(arguments).await?;
Ok(serde_json::json!({
"content": result.content,
"isError": result.is_error
}))
},
_ => Err(crate::errors::ClaudeError::Transport(format!(
"Unknown method: {}",
method
))),
}
}
}
/// Builds an `SdkMcpTool` from a name, a description, an input schema and an async closure.
///
/// Arguments, in order:
/// - `$name`, `$desc`: anything with `to_string()`; stored as the tool's name and description.
/// - `$schema`: a `serde_json::Value` stored as the tool's input schema.
/// - `$handler`: a `Fn(serde_json::Value) -> Fut` closure (`Send + Sync`) whose future is
///   `Send + 'static` and resolves to `anyhow::Result<ToolResult>`. Its error is converted into
///   the crate's error type via `Into`.
///
/// The expansion names `serde_json`, `futures`, `anyhow` and `std` by path, so the invoking
/// crate must have those crates available under those names.
#[macro_export]
macro_rules! tool {
    ($name:expr, $desc:expr, $schema:expr, $handler:expr) => {{
        // Local adapter turning the user's closure into a `ToolHandler` trait object.
        struct ClosureToolHandler<Func>(Func);

        impl<Func, Pending> $crate::types::mcp::ToolHandler for ClosureToolHandler<Func>
        where
            Func: Fn(serde_json::Value) -> Pending + Send + Sync,
            Pending: std::future::Future<
                    Output = anyhow::Result<$crate::types::mcp::ToolResult>,
                > + Send
                + 'static,
        {
            fn handle(
                &self,
                args: serde_json::Value,
            ) -> futures::future::BoxFuture<
                'static,
                $crate::errors::Result<$crate::types::mcp::ToolResult>,
            > {
                use futures::FutureExt;
                // Call the closure now so the boxed future owns its state and does not
                // borrow `self`, as the `'static` return type requires.
                let pending = (self.0)(args);
                async move { pending.await.map_err(|err| err.into()) }.boxed()
            }
        }

        $crate::types::mcp::SdkMcpTool {
            name: $name.to_string(),
            description: $desc.to_string(),
            input_schema: $schema,
            handler: std::sync::Arc::new(ClosureToolHandler($handler)),
        }
    }};
}