1#![allow(missing_docs)]
2use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::{json, Value};
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use crate::errors::{Result, SdkError};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ToolInputSchema {
18 #[serde(rename = "type")]
19 pub schema_type: String,
20 pub properties: HashMap<String, Value>,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub required: Option<Vec<String>>,
23}
24
25#[derive(Clone)]
27pub struct ToolDefinition {
28 pub name: String,
29 pub description: String,
30 pub input_schema: ToolInputSchema,
31 pub handler: Arc<dyn ToolHandler>,
32}
33
34impl std::fmt::Debug for ToolDefinition {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("ToolDefinition")
37 .field("name", &self.name)
38 .field("description", &self.description)
39 .field("input_schema", &self.input_schema)
40 .field("handler", &"<Arc<dyn ToolHandler>>")
41 .finish()
42 }
43}
44
45#[async_trait]
47pub trait ToolHandler: Send + Sync {
48 async fn execute(&self, args: Value) -> Result<ToolResult>;
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ToolResult {
54 pub content: Vec<ToolResultContent>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub is_error: Option<bool>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61#[serde(tag = "type")]
62pub enum ToolResultContent {
63 #[serde(rename = "text")]
64 Text { text: String },
65 #[serde(rename = "image")]
66 Image {
67 data: String,
68 #[serde(rename = "mimeType")]
69 mime_type: String,
70 },
71}
72
73pub struct SdkMcpServer {
75 pub name: String,
76 pub version: String,
77 pub tools: Vec<ToolDefinition>,
78}
79
80impl SdkMcpServer {
81 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
83 Self {
84 name: name.into(),
85 version: version.into(),
86 tools: Vec::new(),
87 }
88 }
89
90 pub fn add_tool(&mut self, tool: ToolDefinition) {
92 self.tools.push(tool);
93 }
94
95 pub async fn handle_message(&self, message: Value) -> Result<Value> {
97 let method = message
98 .get("method")
99 .and_then(|m| m.as_str())
100 .ok_or_else(|| SdkError::InvalidState {
101 message: "Missing method in MCP message".to_string(),
102 })?;
103
104 let id = message.get("id");
105
106 match method {
107 "initialize" => Ok(json!({
108 "jsonrpc": "2.0",
109 "id": id,
110 "result": {
111 "protocolVersion": "2024-11-05",
112 "capabilities": {
113 "tools": {}
114 },
115 "serverInfo": {
116 "name": self.name,
117 "version": self.version
118 }
119 }
120 })),
121
122 "tools/list" => {
123 let tools: Vec<Value> = self
124 .tools
125 .iter()
126 .map(|tool| {
127 json!({
128 "name": tool.name,
129 "description": tool.description,
130 "inputSchema": tool.input_schema
131 })
132 })
133 .collect();
134
135 Ok(json!({
136 "jsonrpc": "2.0",
137 "id": id,
138 "result": {
139 "tools": tools
140 }
141 }))
142 }
143
144 "tools/call" => {
145 let params = message.get("params").ok_or_else(|| SdkError::InvalidState {
146 message: "Missing params in tools/call".to_string(),
147 })?;
148
149 let tool_name = params
150 .get("name")
151 .and_then(|n| n.as_str())
152 .ok_or_else(|| SdkError::InvalidState {
153 message: "Missing tool name in tools/call".to_string(),
154 })?;
155
156 let empty_args = json!({});
157 let arguments = params.get("arguments").unwrap_or(&empty_args);
158
159 let tool = self
161 .tools
162 .iter()
163 .find(|t| t.name == tool_name)
164 .ok_or_else(|| SdkError::InvalidState {
165 message: format!("Tool not found: {tool_name}"),
166 })?;
167
168 let result = tool.handler.execute(arguments.clone()).await?;
169
170 Ok(json!({
171 "jsonrpc": "2.0",
172 "id": id,
173 "result": {
174 "content": result.content,
175 "isError": result.is_error
176 }
177 }))
178 }
179
180 "notifications/initialized" => {
181 Ok(json!({
183 "jsonrpc": "2.0",
184 "result": {}
185 }))
186 }
187
188 _ => Ok(json!({
189 "jsonrpc": "2.0",
190 "id": id,
191 "error": {
192 "code": -32601,
193 "message": format!("Method '{}' not found", method)
194 }
195 })),
196 }
197 }
198}
199
200impl SdkMcpServer {
201 pub fn to_config(self) -> crate::types::McpServerConfig {
203 use std::sync::Arc;
204 crate::types::McpServerConfig::Sdk {
205 name: self.name.clone(),
206 instance: Arc::new(self),
207 }
208 }
209}
210
211pub struct SdkMcpServerBuilder {
213 name: String,
214 version: String,
215 tools: Vec<ToolDefinition>,
216}
217
218impl SdkMcpServerBuilder {
219 pub fn new(name: impl Into<String>) -> Self {
221 Self {
222 name: name.into(),
223 version: "1.0.0".to_string(),
224 tools: Vec::new(),
225 }
226 }
227
228 pub fn version(mut self, version: impl Into<String>) -> Self {
230 self.version = version.into();
231 self
232 }
233
234 pub fn tool(mut self, tool: ToolDefinition) -> Self {
236 self.tools.push(tool);
237 self
238 }
239
240 pub fn build(self) -> SdkMcpServer {
242 SdkMcpServer {
243 name: self.name,
244 version: self.version,
245 tools: self.tools,
246 }
247 }
248}
249
250pub fn create_simple_tool<F, Fut>(
252 name: impl Into<String>,
253 description: impl Into<String>,
254 schema: ToolInputSchema,
255 handler: F,
256) -> ToolDefinition
257where
258 F: Fn(Value) -> Fut + Send + Sync + 'static,
259 Fut: std::future::Future<Output = Result<String>> + Send + 'static,
260{
261 struct SimpleHandler<F, Fut>
262 where
263 F: Fn(Value) -> Fut + Send + Sync,
264 Fut: std::future::Future<Output = Result<String>> + Send,
265 {
266 func: F,
267 }
268
269 #[async_trait]
270 impl<F, Fut> ToolHandler for SimpleHandler<F, Fut>
271 where
272 F: Fn(Value) -> Fut + Send + Sync,
273 Fut: std::future::Future<Output = Result<String>> + Send,
274 {
275 async fn execute(&self, args: Value) -> Result<ToolResult> {
276 let text = (self.func)(args).await?;
277 Ok(ToolResult {
278 content: vec![ToolResultContent::Text { text }],
279 is_error: None,
280 })
281 }
282 }
283
284 ToolDefinition {
285 name: name.into(),
286 description: description.into(),
287 input_schema: schema,
288 handler: Arc::new(SimpleHandler { func: handler }),
289 }
290}
291
/// Shorthand for `sdk_mcp::create_simple_tool`.
///
/// Takes, in order, the tool name, its description, its input schema and the
/// async closure implementing it, and expands to a call to
/// `create_simple_tool` with those four arguments. A trailing comma after the
/// last argument is accepted.
#[macro_export]
macro_rules! tool {
    ($name:expr, $desc:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::sdk_mcp::create_simple_tool($name, $desc, $schema, $handler)
    };
}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `initialize`, `tools/list` and `tools/call` against a server
    /// that exposes a single `greet` tool.
    #[tokio::test]
    async fn test_sdk_mcp_server() {
        let greet_schema = ToolInputSchema {
            schema_type: "object".to_string(),
            properties: HashMap::from([(
                "name".to_string(),
                json!({"type": "string", "description": "Name to greet"}),
            )]),
            required: Some(vec!["name".to_string()]),
        };

        let greet = create_simple_tool(
            "greet",
            "Greet a user",
            greet_schema,
            |args| async move {
                let name = args["name"].as_str().unwrap_or("stranger");
                Ok(format!("Hello, {name}!"))
            },
        );

        let mut server = SdkMcpServer::new("test-server", "1.0.0");
        server.add_tool(greet);

        // initialize: the reply carries the server's own name.
        let init_reply = server
            .handle_message(json!({
                "jsonrpc": "2.0",
                "id": 1,
                "method": "initialize"
            }))
            .await
            .unwrap();
        assert_eq!(init_reply["result"]["serverInfo"]["name"], "test-server");

        // tools/list: the registered tool is listed first.
        let list_reply = server
            .handle_message(json!({
                "jsonrpc": "2.0",
                "id": 2,
                "method": "tools/list"
            }))
            .await
            .unwrap();
        assert_eq!(list_reply["result"]["tools"][0]["name"], "greet");

        // tools/call: the closure's text comes back as the first content item.
        let call_reply = server
            .handle_message(json!({
                "jsonrpc": "2.0",
                "id": 3,
                "method": "tools/call",
                "params": {
                    "name": "greet",
                    "arguments": {
                        "name": "Alice"
                    }
                }
            }))
            .await
            .unwrap();
        assert_eq!(call_reply["result"]["content"][0]["text"], "Hello, Alice!");
    }
}