use async_trait::async_trait;
use futures::future::BoxFuture;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use typed_builder::TypedBuilder;
use crate::errors::Result;
/// Prefix that marks a tool name as belonging to the `acp` MCP namespace.
pub const ACP_TOOL_PREFIX: &str = "mcp__acp__";

/// Returns `tool_name` qualified with [`ACP_TOOL_PREFIX`].
///
/// The prefix is always prepended, even if `tool_name` already carries it.
pub fn acp_tool_name(tool_name: &str) -> String {
    let mut qualified = String::with_capacity(ACP_TOOL_PREFIX.len() + tool_name.len());
    qualified.push_str(ACP_TOOL_PREFIX);
    qualified.push_str(tool_name);
    qualified
}

/// Reports whether `tool_name` begins with [`ACP_TOOL_PREFIX`].
pub fn is_acp_tool(tool_name: &str) -> bool {
    tool_name.strip_prefix(ACP_TOOL_PREFIX).is_some()
}

/// Removes one leading [`ACP_TOOL_PREFIX`] from `tool_name`.
///
/// Names without the prefix are returned unchanged.
pub fn strip_acp_prefix(tool_name: &str) -> &str {
    match tool_name.strip_prefix(ACP_TOOL_PREFIX) {
        Some(bare) => bare,
        None => tool_name,
    }
}
/// The set of MCP servers to make available, in one of several forms.
#[derive(Clone, Default)]
pub enum McpServers {
    /// No MCP servers configured; this is the `Default` variant.
    #[default]
    Empty,
    /// Servers given inline as a map of string keys to configs.
    /// NOTE(review): keys are presumably server names — confirm against callers.
    Dict(HashMap<String, McpServerConfig>),
    /// Path to an external MCP configuration.
    /// NOTE(review): file format is not visible here — TODO confirm.
    Path(PathBuf),
}
/// Configuration for a single MCP server, one variant per kind of server.
///
/// Not `Debug`, because the `Sdk` variant holds a `dyn SdkMcpServer` trait object.
#[derive(Clone)]
pub enum McpServerConfig {
    /// Server started from a command (stdio transport).
    Stdio(McpStdioServerConfig),
    /// Server reached at a URL via SSE.
    Sse(McpSseServerConfig),
    /// Server reached at a URL via HTTP.
    Http(McpHttpServerConfig),
    /// Server implemented in this process by an [`SdkMcpServer`].
    Sdk(McpSdkServerConfig),
}
/// Settings for a server launched from a command.
///
/// `None` options are omitted when serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpStdioServerConfig {
    /// Command used to start the server (presumably an executable name or path).
    pub command: String,
    /// Arguments passed to `command`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub args: Option<Vec<String>>,
    /// Environment variables (name -> value) for the server.
    /// NOTE(review): presumably applied to the spawned process — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub env: Option<HashMap<String, String>>,
}
/// Settings for a server reached over SSE.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpSseServerConfig {
    /// Server endpoint URL.
    pub url: String,
    /// Extra request headers (name -> value); omitted when serialized if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}
/// Settings for a server reached over HTTP.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpHttpServerConfig {
    /// Server endpoint URL.
    pub url: String,
    /// Extra request headers (name -> value); omitted when serialized if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
}
/// An in-process MCP server together with its name.
///
/// Cloning is cheap: the server instance is shared through an `Arc`.
#[derive(Clone)]
pub struct McpSdkServerConfig {
    /// Server name (`create_sdk_mcp_server` copies the server's own name here).
    pub name: String,
    /// The server implementation.
    pub instance: Arc<dyn SdkMcpServer>,
}
/// Optional behavioral hints attached to a tool, named after the MCP tool-annotation fields.
///
/// Fields (de)serialize in camelCase (`readOnlyHint`, `destructiveHint`, `idempotentHint`,
/// `openWorldHint`). Unset hints are left out of the serialized JSON. The builder's setters
/// take a plain `bool` and wrap it in `Some`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, TypedBuilder)]
#[builder(doc)]
#[serde(rename_all = "camelCase")]
pub struct ToolAnnotations {
    /// MCP `readOnlyHint`: the tool does not modify its environment.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub read_only_hint: Option<bool>,
    /// MCP `destructiveHint`: the tool may perform destructive updates.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub destructive_hint: Option<bool>,
    /// MCP `idempotentHint`: repeating a call with the same arguments has no extra effect.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub idempotent_hint: Option<bool>,
    /// MCP `openWorldHint`: the tool interacts with external entities.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub open_world_hint: Option<bool>,
}
/// Tool metadata as returned by [`SdkMcpServer::list_tools`].
#[derive(Debug, Clone)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// Schema describing the tool's arguments (presumably JSON Schema — confirm).
    pub input_schema: serde_json::Value,
    /// Optional behavioral hints.
    pub annotations: Option<ToolAnnotations>,
}
/// An MCP server implemented in-process as a Rust trait object.
#[async_trait]
pub trait SdkMcpServer: Send + Sync {
    /// Handles one incoming MCP message and returns the reply payload.
    ///
    /// NOTE(review): the built-in implementation reads `method`/`params` from `message`
    /// and returns only the JSON-RPC `result` body — confirm this is the contract
    /// expected of all implementors.
    async fn handle_message(&self, message: serde_json::Value) -> Result<serde_json::Value>;
    /// Returns metadata for every tool this server exposes.
    fn list_tools(&self) -> Vec<ToolDefinition>;
}
/// Executes calls to a single tool.
pub trait ToolHandler: Send + Sync {
    /// Runs the tool with the call's JSON `args`.
    ///
    /// The returned future is `'static`, so it cannot borrow `self`; the `tool!` macro
    /// satisfies this by invoking the user closure before boxing the resulting future.
    fn handle(&self, args: serde_json::Value) -> BoxFuture<'static, Result<ToolResult>>;
}
/// Outcome of a tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// Content blocks produced by the tool.
    pub content: Vec<ToolResultContent>,
    /// Whether the call failed; deserializes as `false` when the key is absent.
    ///
    /// NOTE(review): serde uses the key `is_error`, whereas MCP's `CallToolResult` names
    /// this field `isError` — confirm before relying on this struct's serde form on the wire.
    #[serde(default)]
    pub is_error: bool,
}
/// One content block of a [`ToolResult`].
///
/// Serialized with an internal `type` tag holding the lowercase variant name, e.g.
/// `{"type": "text", "text": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolResultContent {
    /// Plain text content.
    Text {
        /// The text itself.
        text: String,
    },
    /// Image content.
    Image {
        /// Image payload (presumably base64-encoded — TODO confirm).
        data: String,
        /// MIME type of `data`, serialized as `mimeType`.
        #[serde(rename = "mimeType")]
        mime_type: String,
    },
}
/// A tool to register with [`create_sdk_mcp_server`]; the `tool!` macro builds one.
pub struct SdkMcpTool {
    /// Tool name; also the key used to look the tool up on `tools/call`.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Argument schema, reported as `inputSchema` in `tools/list`.
    pub input_schema: serde_json::Value,
    /// Implementation invoked on `tools/call`.
    pub handler: Arc<dyn ToolHandler>,
    /// Optional behavioral hints, reported as `annotations` when present.
    pub annotations: Option<ToolAnnotations>,
}
/// Builds an in-process MCP server exposing `tools` and wraps it in a [`McpSdkServerConfig`].
///
/// `name` and `version` are reported in the server's `initialize` response. Tools are
/// indexed by [`SdkMcpTool::name`]; if two tools share a name, the later one wins.
pub fn create_sdk_mcp_server(
    name: impl Into<String>,
    version: impl Into<String>,
    tools: Vec<SdkMcpTool>,
) -> McpSdkServerConfig {
    let server_name: String = name.into();
    let server_version: String = version.into();
    // Collection type is inferred from the `tools` field of `DefaultSdkMcpServer`.
    let registry = tools
        .into_iter()
        .map(|tool| (tool.name.clone(), tool))
        .collect();
    let instance = Arc::new(DefaultSdkMcpServer {
        name: server_name.clone(),
        version: server_version,
        tools: registry,
    });
    McpSdkServerConfig {
        name: server_name,
        instance,
    }
}
struct DefaultSdkMcpServer {
name: String,
version: String,
tools: HashMap<String, SdkMcpTool>,
}
#[async_trait]
impl SdkMcpServer for DefaultSdkMcpServer {
    /// Handles one MCP request or notification and returns its `result` payload.
    ///
    /// Supported methods: `initialize`, `notifications/initialized`, `tools/list` and
    /// `tools/call`.
    ///
    /// # Errors
    ///
    /// * `ClaudeError::Transport` when `method` is missing, when a `tools/call` has no
    ///   `params.name`, or when the named tool is not registered.
    /// * Any error returned by the tool's handler is propagated.
    /// * `McpError::method_not_found` (converted into the crate error) for other methods.
    async fn handle_message(&self, message: serde_json::Value) -> Result<serde_json::Value> {
        let method = message["method"]
            .as_str()
            .ok_or_else(|| crate::errors::ClaudeError::Transport("Missing method".to_string()))?;
        match method {
            "initialize" => Ok(serde_json::json!({
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "tools": {}
                },
                "serverInfo": {
                    "name": self.name,
                    "version": self.version
                }
            })),
            // The initialized notification needs no real reply; answer with an empty object.
            "notifications/initialized" => Ok(serde_json::json!({})),
            "tools/list" => {
                let tools: Vec<_> = self
                    .tools
                    .values()
                    .map(|t| {
                        let mut tool_json = serde_json::json!({
                            "name": t.name,
                            "description": t.description,
                            "inputSchema": t.input_schema
                        });
                        // Only emit `annotations` when the tool declares some.
                        if let Some(ref annotations) = t.annotations {
                            tool_json["annotations"] =
                                serde_json::to_value(annotations).unwrap_or_default();
                        }
                        tool_json
                    })
                    .collect();
                Ok(serde_json::json!({ "tools": tools }))
            }
            "tools/call" => {
                let params = &message["params"];
                let tool_name = params["name"].as_str().ok_or_else(|| {
                    crate::errors::ClaudeError::Transport("Missing tool name".to_string())
                })?;
                // `arguments` is optional in an MCP `tools/call`. Hand handlers an empty
                // object instead of `null` so they can always treat args as a JSON object.
                let arguments = match &params["arguments"] {
                    serde_json::Value::Null => serde_json::json!({}),
                    args => args.clone(),
                };
                let tool = self.tools.get(tool_name).ok_or_else(|| {
                    crate::errors::ClaudeError::Transport(format!("Tool not found: {}", tool_name))
                })?;
                let result = tool.handler.handle(arguments).await?;
                // MCP's CallToolResult names the error flag `isError`; `is_error` is still
                // emitted so consumers that already read the old key keep working.
                Ok(serde_json::json!({
                    "content": result.content,
                    "isError": result.is_error,
                    "is_error": result.is_error
                }))
            }
            _ => Err(crate::errors::McpError::method_not_found(method).into()),
        }
    }

    /// Returns a fresh copy of every registered tool's metadata.
    fn list_tools(&self) -> Vec<ToolDefinition> {
        self.tools
            .values()
            .map(|t| ToolDefinition {
                name: t.name.clone(),
                description: t.description.clone(),
                input_schema: t.input_schema.clone(),
                annotations: t.annotations.clone(),
            })
            .collect()
    }
}
/// Builds a `SdkMcpTool` from a name, a description, a JSON input schema, an async
/// handler closure and, optionally, a `ToolAnnotations` value.
///
/// The handler must be `Fn(serde_json::Value) -> Fut + Send + Sync`, where `Fut` is a
/// `Send + 'static` future resolving to `anyhow::Result<ToolResult>`. Handler errors are
/// converted into the crate error type via `Into`.
///
/// The expansion names `serde_json`, `futures` and `anyhow` by path, so the invoking
/// crate must depend on them.
#[macro_export]
macro_rules! tool {
    // Internal arm shared by both public forms; `$annotations` is already an `Option`.
    (@build $name:expr, $desc:expr, $schema:expr, $handler:expr, $annotations:expr) => {{
        // Adapter from a plain async closure to the object-safe `ToolHandler` trait.
        struct Handler<F>(F);
        impl<F, Fut> $crate::types::mcp::ToolHandler for Handler<F>
        where
            F: Fn(serde_json::Value) -> Fut + Send + Sync,
            Fut: std::future::Future<Output = anyhow::Result<$crate::types::mcp::ToolResult>>
                + Send
                + 'static,
        {
            fn handle(
                &self,
                args: serde_json::Value,
            ) -> futures::future::BoxFuture<
                'static,
                $crate::errors::Result<$crate::types::mcp::ToolResult>,
            > {
                use futures::FutureExt;
                // Call the closure now so the boxed future does not borrow `self`.
                let fut = (self.0)(args);
                async move { fut.await.map_err(|e| e.into()) }.boxed()
            }
        }
        $crate::types::mcp::SdkMcpTool {
            name: $name.to_string(),
            description: $desc.to_string(),
            input_schema: $schema,
            handler: std::sync::Arc::new(Handler($handler)),
            annotations: $annotations,
        }
    }};
    ($name:expr, $desc:expr, $schema:expr, $handler:expr) => {
        $crate::tool!(@build $name, $desc, $schema, $handler, ::core::option::Option::None)
    };
    ($name:expr, $desc:expr, $schema:expr, $handler:expr, $annotations:expr) => {
        $crate::tool!(
            @build $name,
            $desc,
            $schema,
            $handler,
            ::core::option::Option::Some($annotations)
        )
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Minimal handler for building `SdkMcpTool`s in tests; always succeeds with no content.
    struct NoopHandler;

    impl ToolHandler for NoopHandler {
        fn handle(&self, _args: serde_json::Value) -> BoxFuture<'static, Result<ToolResult>> {
            Box::pin(async move {
                let outcome: Result<ToolResult> = Ok(ToolResult {
                    content: Vec::new(),
                    is_error: false,
                });
                outcome
            })
        }
    }

    #[test]
    fn test_acp_tool_name_helpers() {
        let full = acp_tool_name("Read");
        assert_eq!(full, "mcp__acp__Read");
        assert!(is_acp_tool(&full));
        assert!(!is_acp_tool("Read"));
        assert_eq!(strip_acp_prefix(&full), "Read");
        assert_eq!(strip_acp_prefix("Read"), "Read");
    }

    #[test]
    fn test_tool_annotations_serialization() {
        let annotations = ToolAnnotations::builder()
            .read_only_hint(true)
            .destructive_hint(false)
            .idempotent_hint(true)
            .open_world_hint(false)
            .build();
        let json = serde_json::to_value(&annotations).unwrap();
        assert_eq!(json["readOnlyHint"], true);
        assert_eq!(json["destructiveHint"], false);
        assert_eq!(json["idempotentHint"], true);
        assert_eq!(json["openWorldHint"], false);
    }

    #[test]
    fn test_tool_annotations_optional_fields() {
        let annotations = ToolAnnotations::builder().read_only_hint(true).build();
        let json = serde_json::to_value(&annotations).unwrap();
        assert_eq!(json["readOnlyHint"], true);
        assert!(json.get("destructiveHint").is_none());
        assert!(json.get("idempotentHint").is_none());
        assert!(json.get("openWorldHint").is_none());
    }

    #[test]
    fn test_tool_annotations_deserialization() {
        let json_str = r#"{
            "readOnlyHint": true,
            "destructiveHint": false
        }"#;
        let annotations: ToolAnnotations = serde_json::from_str(json_str).unwrap();
        assert_eq!(annotations.read_only_hint, Some(true));
        assert_eq!(annotations.destructive_hint, Some(false));
        assert_eq!(annotations.idempotent_hint, None);
        assert_eq!(annotations.open_world_hint, None);
    }

    #[test]
    fn test_tool_annotations_default() {
        let annotations = ToolAnnotations::default();
        assert_eq!(annotations.read_only_hint, None);
        assert_eq!(annotations.destructive_hint, None);
        assert_eq!(annotations.idempotent_hint, None);
        assert_eq!(annotations.open_world_hint, None);
    }

    #[test]
    fn test_tool_definition_with_annotations() {
        let def = ToolDefinition {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: json!({"type": "object"}),
            annotations: Some(ToolAnnotations::builder().read_only_hint(true).build()),
        };
        assert_eq!(def.name, "test_tool");
        assert!(def.annotations.is_some());
        assert_eq!(def.annotations.unwrap().read_only_hint, Some(true));
    }

    #[test]
    fn test_tool_definition_without_annotations() {
        let def = ToolDefinition {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: json!({"type": "object"}),
            annotations: None,
        };
        assert!(def.annotations.is_none());
    }

    #[test]
    fn test_tool_result_content_serialization() {
        let text = serde_json::to_value(ToolResultContent::Text {
            text: "hi".to_string(),
        })
        .unwrap();
        assert_eq!(text, json!({"type": "text", "text": "hi"}));

        let image = serde_json::to_value(ToolResultContent::Image {
            data: "AAAA".to_string(),
            mime_type: "image/png".to_string(),
        })
        .unwrap();
        assert_eq!(
            image,
            json!({"type": "image", "data": "AAAA", "mimeType": "image/png"})
        );
    }

    #[test]
    fn test_tool_result_is_error_defaults_to_false() {
        let result: ToolResult =
            serde_json::from_value(json!({"content": [{"type": "text", "text": "ok"}]})).unwrap();
        assert!(!result.is_error);
        assert_eq!(result.content.len(), 1);
    }

    #[test]
    fn test_create_sdk_mcp_server_lists_tools() {
        let tool = SdkMcpTool {
            name: "echo".to_string(),
            description: "Echo".to_string(),
            input_schema: json!({"type": "object"}),
            handler: Arc::new(NoopHandler),
            annotations: Some(ToolAnnotations::builder().read_only_hint(true).build()),
        };
        let config = create_sdk_mcp_server("srv", "1.0.0", vec![tool]);
        assert_eq!(config.name, "srv");

        let tools = config.instance.list_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");
        assert_eq!(tools[0].description, "Echo");
        assert_eq!(
            tools[0].annotations.as_ref().and_then(|a| a.read_only_hint),
            Some(true)
        );
    }
}