1use async_trait::async_trait;
4use futures::future::BoxFuture;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10use crate::errors::Result;
11
12#[derive(Clone, Default)]
14pub enum McpServers {
15 #[default]
17 Empty,
18 Dict(HashMap<String, McpServerConfig>),
20 Path(PathBuf),
22}
23
24#[derive(Clone)]
26pub enum McpServerConfig {
27 Stdio(McpStdioServerConfig),
29 Sse(McpSseServerConfig),
31 Http(McpHttpServerConfig),
33 Sdk(McpSdkServerConfig),
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct McpStdioServerConfig {
40 pub command: String,
42 #[serde(skip_serializing_if = "Option::is_none")]
44 pub args: Option<Vec<String>>,
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub env: Option<HashMap<String, String>>,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct McpSseServerConfig {
53 pub url: String,
55 #[serde(skip_serializing_if = "Option::is_none")]
57 pub headers: Option<HashMap<String, String>>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct McpHttpServerConfig {
63 pub url: String,
65 #[serde(skip_serializing_if = "Option::is_none")]
67 pub headers: Option<HashMap<String, String>>,
68}
69
70#[derive(Clone)]
72pub struct McpSdkServerConfig {
73 pub name: String,
75 pub instance: Arc<dyn SdkMcpServer>,
77}
78
79#[async_trait]
81pub trait SdkMcpServer: Send + Sync {
82 async fn handle_message(&self, message: serde_json::Value) -> Result<serde_json::Value>;
84}
85
86pub trait ToolHandler: Send + Sync {
88 fn handle(&self, args: serde_json::Value) -> BoxFuture<'static, Result<ToolResult>>;
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct ToolResult {
95 pub content: Vec<ToolResultContent>,
97 #[serde(default)]
99 pub is_error: bool,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104#[serde(tag = "type", rename_all = "lowercase")]
105pub enum ToolResultContent {
106 Text {
108 text: String,
110 },
111 Image {
113 data: String,
115 mime_type: String,
117 },
118}
119
120pub struct SdkMcpTool {
122 pub name: String,
124 pub description: String,
126 pub input_schema: serde_json::Value,
128 pub handler: Arc<dyn ToolHandler>,
130}
131
132pub fn create_sdk_mcp_server(
134 name: impl Into<String>,
135 version: impl Into<String>,
136 tools: Vec<SdkMcpTool>,
137) -> McpSdkServerConfig {
138 let server = DefaultSdkMcpServer {
139 name: name.into(),
140 version: version.into(),
141 tools: tools.into_iter().map(|t| (t.name.clone(), t)).collect(),
142 };
143
144 McpSdkServerConfig {
145 name: server.name.clone(),
146 instance: Arc::new(server),
147 }
148}
149
150struct DefaultSdkMcpServer {
152 name: String,
153 version: String,
154 tools: HashMap<String, SdkMcpTool>,
155}
156
157#[async_trait]
158impl SdkMcpServer for DefaultSdkMcpServer {
159 async fn handle_message(&self, message: serde_json::Value) -> Result<serde_json::Value> {
160 let method = message["method"]
162 .as_str()
163 .ok_or_else(|| crate::errors::ClaudeError::Transport("Missing method".to_string()))?;
164
165 match method {
166 "initialize" => {
167 Ok(serde_json::json!({
169 "protocolVersion": "2024-11-05",
170 "capabilities": {
171 "tools": {}
172 },
173 "serverInfo": {
174 "name": self.name,
175 "version": self.version
176 }
177 }))
178 },
179 "tools/list" => {
180 let tools: Vec<_> = self
182 .tools
183 .values()
184 .map(|t| {
185 serde_json::json!({
186 "name": t.name,
187 "description": t.description,
188 "inputSchema": t.input_schema
189 })
190 })
191 .collect();
192
193 Ok(serde_json::json!({
194 "tools": tools
195 }))
196 },
197 "tools/call" => {
198 let params = &message["params"];
200 let tool_name = params["name"].as_str().ok_or_else(|| {
201 crate::errors::ClaudeError::Transport("Missing tool name".to_string())
202 })?;
203 let arguments = params["arguments"].clone();
204
205 let tool = self.tools.get(tool_name).ok_or_else(|| {
206 crate::errors::ClaudeError::Transport(format!("Tool not found: {}", tool_name))
207 })?;
208
209 let result = tool.handler.handle(arguments).await?;
210
211 Ok(serde_json::json!({
212 "content": result.content,
213 "isError": result.is_error
214 }))
215 },
216 _ => Err(crate::errors::ClaudeError::Transport(format!(
217 "Unknown method: {}",
218 method
219 ))),
220 }
221 }
222}
223
/// Builds an `SdkMcpTool` from a name, a description, a JSON schema and an
/// async closure.
///
/// Arguments, in order:
///
/// * `$name` — any expression with a `to_string()` method; becomes the tool name.
/// * `$desc` — any expression with a `to_string()` method; becomes the description.
/// * `$schema` — a `serde_json::Value` used as the tool's input schema.
/// * `$handler` — a `Fn(serde_json::Value) -> Fut + Send + Sync` where `Fut`
///   is a `Send + 'static` future yielding `anyhow::Result<ToolResult>`.
///   The error is converted with `.into()` to the error type of the crate's
///   `Result`.
///
/// The expansion names `serde_json`, `futures` and `anyhow` by bare path, so
/// those crates must be resolvable where the macro is invoked.
#[macro_export]
macro_rules! tool {
    ($name:expr, $desc:expr, $schema:expr, $handler:expr) => {{
        // Items declared by a macro are not hygienic: a helper called plain
        // `Handler` would shadow any caller type of that name used inside
        // `$handler`. The prefixed name keeps the helper out of the way.
        struct __SdkToolClosureHandler<F>(F);

        impl<F, Fut> $crate::types::mcp::ToolHandler for __SdkToolClosureHandler<F>
        where
            F: Fn(serde_json::Value) -> Fut + Send + Sync,
            Fut: std::future::Future<Output = anyhow::Result<$crate::types::mcp::ToolResult>>
                + Send
                + 'static,
        {
            fn handle(
                &self,
                args: serde_json::Value,
            ) -> futures::future::BoxFuture<
                'static,
                $crate::errors::Result<$crate::types::mcp::ToolResult>,
            > {
                use futures::FutureExt;
                // Call the closure now so the boxed future owns everything
                // it needs and does not borrow `self`.
                let fut = (self.0)(args);
                async move { fut.await.map_err(|e| e.into()) }.boxed()
            }
        }

        $crate::types::mcp::SdkMcpTool {
            name: $name.to_string(),
            description: $desc.to_string(),
            input_schema: $schema,
            handler: std::sync::Arc::new(__SdkToolClosureHandler($handler)),
        }
    }};
}