cc_sdk/
sdk_mcp.rs

1#![allow(missing_docs)]
2//! SDK MCP Server - In-process MCP server implementation
3//!
4//! This module provides an in-process MCP server that runs directly within your
5//! Rust application, eliminating the need for separate processes.
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::{json, Value};
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use crate::errors::{Result, SdkError};
14
15/// Tool input schema definition
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ToolInputSchema {
18    #[serde(rename = "type")]
19    pub schema_type: String,
20    pub properties: HashMap<String, Value>,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub required: Option<Vec<String>>,
23}
24
25/// Tool definition
26#[derive(Clone)]
27pub struct ToolDefinition {
28    pub name: String,
29    pub description: String,
30    pub input_schema: ToolInputSchema,
31    pub handler: Arc<dyn ToolHandler>,
32}
33
34impl std::fmt::Debug for ToolDefinition {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("ToolDefinition")
37            .field("name", &self.name)
38            .field("description", &self.description)
39            .field("input_schema", &self.input_schema)
40            .field("handler", &"<Arc<dyn ToolHandler>>")
41            .finish()
42    }
43}
44
45/// Tool handler trait
46#[async_trait]
47pub trait ToolHandler: Send + Sync {
48    async fn execute(&self, args: Value) -> Result<ToolResult>;
49}
50
51/// Tool execution result
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ToolResult {
54    pub content: Vec<ToolResultContent>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub is_error: Option<bool>,
57}
58
59/// Tool result content types
60#[derive(Debug, Clone, Serialize, Deserialize)]
61#[serde(tag = "type")]
62pub enum ToolResultContent {
63    #[serde(rename = "text")]
64    Text { text: String },
65    #[serde(rename = "image")]
66    Image {
67        data: String,
68        #[serde(rename = "mimeType")]
69        mime_type: String,
70    },
71}
72
73/// SDK MCP Server
74pub struct SdkMcpServer {
75    pub name: String,
76    pub version: String,
77    pub tools: Vec<ToolDefinition>,
78}
79
80impl SdkMcpServer {
81    /// Create a new SDK MCP server
82    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
83        Self {
84            name: name.into(),
85            version: version.into(),
86            tools: Vec::new(),
87        }
88    }
89
90    /// Add a tool to the server
91    pub fn add_tool(&mut self, tool: ToolDefinition) {
92        self.tools.push(tool);
93    }
94
95    /// Handle MCP protocol messages
96    pub async fn handle_message(&self, message: Value) -> Result<Value> {
97        let method = message
98            .get("method")
99            .and_then(|m| m.as_str())
100            .ok_or_else(|| SdkError::InvalidState {
101                message: "Missing method in MCP message".to_string(),
102            })?;
103
104        let id = message.get("id");
105
106        match method {
107            "initialize" => Ok(json!({
108                "jsonrpc": "2.0",
109                "id": id,
110                "result": {
111                    "protocolVersion": "2024-11-05",
112                    "capabilities": {
113                        "tools": {}
114                    },
115                    "serverInfo": {
116                        "name": self.name,
117                        "version": self.version
118                    }
119                }
120            })),
121
122            "tools/list" => {
123                let tools: Vec<Value> = self
124                    .tools
125                    .iter()
126                    .map(|tool| {
127                        json!({
128                            "name": tool.name,
129                            "description": tool.description,
130                            "inputSchema": tool.input_schema
131                        })
132                    })
133                    .collect();
134
135                Ok(json!({
136                    "jsonrpc": "2.0",
137                    "id": id,
138                    "result": {
139                        "tools": tools
140                    }
141                }))
142            }
143
144            "tools/call" => {
145                let params = message.get("params").ok_or_else(|| SdkError::InvalidState {
146                    message: "Missing params in tools/call".to_string(),
147                })?;
148
149                let tool_name = params
150                    .get("name")
151                    .and_then(|n| n.as_str())
152                    .ok_or_else(|| SdkError::InvalidState {
153                        message: "Missing tool name in tools/call".to_string(),
154                    })?;
155
156                let empty_args = json!({});
157                let arguments = params.get("arguments").unwrap_or(&empty_args);
158
159                // Find and execute the tool
160                let tool = self
161                    .tools
162                    .iter()
163                    .find(|t| t.name == tool_name)
164                    .ok_or_else(|| SdkError::InvalidState {
165                        message: format!("Tool not found: {tool_name}"),
166                    })?;
167
168                let result = tool.handler.execute(arguments.clone()).await?;
169
170                Ok(json!({
171                    "jsonrpc": "2.0",
172                    "id": id,
173                    "result": {
174                        "content": result.content,
175                        "isError": result.is_error
176                    }
177                }))
178            }
179
180            "notifications/initialized" => {
181                // Acknowledge initialization notification
182                Ok(json!({
183                    "jsonrpc": "2.0",
184                    "result": {}
185                }))
186            }
187
188            _ => Ok(json!({
189                "jsonrpc": "2.0",
190                "id": id,
191                "error": {
192                    "code": -32601,
193                    "message": format!("Method '{}' not found", method)
194                }
195            })),
196        }
197    }
198}
199
200impl SdkMcpServer {
201    /// Convert to McpServerConfig
202    pub fn to_config(self) -> crate::types::McpServerConfig {
203        use std::sync::Arc;
204        crate::types::McpServerConfig::Sdk {
205            name: self.name.clone(),
206            instance: Arc::new(self),
207        }
208    }
209}
210
211/// Builder for creating SDK MCP servers
212pub struct SdkMcpServerBuilder {
213    name: String,
214    version: String,
215    tools: Vec<ToolDefinition>,
216}
217
218impl SdkMcpServerBuilder {
219    /// Create a new builder
220    pub fn new(name: impl Into<String>) -> Self {
221        Self {
222            name: name.into(),
223            version: "1.0.0".to_string(),
224            tools: Vec::new(),
225        }
226    }
227
228    /// Set server version
229    pub fn version(mut self, version: impl Into<String>) -> Self {
230        self.version = version.into();
231        self
232    }
233
234    /// Add a tool
235    pub fn tool(mut self, tool: ToolDefinition) -> Self {
236        self.tools.push(tool);
237        self
238    }
239
240    /// Build the server
241    pub fn build(self) -> SdkMcpServer {
242        SdkMcpServer {
243            name: self.name,
244            version: self.version,
245            tools: self.tools,
246        }
247    }
248}
249
250/// Helper function to create a simple text-based tool
251pub fn create_simple_tool<F, Fut>(
252    name: impl Into<String>,
253    description: impl Into<String>,
254    schema: ToolInputSchema,
255    handler: F,
256) -> ToolDefinition
257where
258    F: Fn(Value) -> Fut + Send + Sync + 'static,
259    Fut: std::future::Future<Output = Result<String>> + Send + 'static,
260{
261    struct SimpleHandler<F, Fut>
262    where
263        F: Fn(Value) -> Fut + Send + Sync,
264        Fut: std::future::Future<Output = Result<String>> + Send,
265    {
266        func: F,
267    }
268
269    #[async_trait]
270    impl<F, Fut> ToolHandler for SimpleHandler<F, Fut>
271    where
272        F: Fn(Value) -> Fut + Send + Sync,
273        Fut: std::future::Future<Output = Result<String>> + Send,
274    {
275        async fn execute(&self, args: Value) -> Result<ToolResult> {
276            let text = (self.func)(args).await?;
277            Ok(ToolResult {
278                content: vec![ToolResultContent::Text { text }],
279                is_error: None,
280            })
281        }
282    }
283
284    ToolDefinition {
285        name: name.into(),
286        description: description.into(),
287        input_schema: schema,
288        handler: Arc::new(SimpleHandler { func: handler }),
289    }
290}
291
/// Define a text-returning tool; shorthand for
/// [`create_simple_tool`](crate::sdk_mcp::create_simple_tool).
///
/// Takes four arguments in order: the tool's name, its description, its
/// [`ToolInputSchema`](crate::sdk_mcp::ToolInputSchema), and the async
/// handler closure. A trailing comma is accepted.
#[macro_export]
macro_rules! tool {
    ($name:expr, $desc:expr, $schema:expr, $handler:expr $(,)?) => {
        $crate::sdk_mcp::create_simple_tool($name, $desc, $schema, $handler)
    };
}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema with a single required string argument `name`.
    fn greet_schema() -> ToolInputSchema {
        let mut props = HashMap::new();
        props.insert(
            "name".to_string(),
            json!({"type": "string", "description": "Name to greet"}),
        );
        ToolInputSchema {
            schema_type: "object".to_string(),
            properties: props,
            required: Some(vec!["name".to_string()]),
        }
    }

    /// Server exposing one `greet` tool that echoes `arguments.name`.
    fn greet_server() -> SdkMcpServer {
        let mut server = SdkMcpServer::new("test-server", "1.0.0");
        server.add_tool(create_simple_tool(
            "greet",
            "Greet a user",
            greet_schema(),
            |args| async move {
                let name = args["name"].as_str().unwrap_or("stranger");
                Ok(format!("Hello, {name}!"))
            },
        ));
        server
    }

    #[tokio::test]
    async fn test_sdk_mcp_server() {
        let server = greet_server();

        // initialize
        let init_msg = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize"
        });
        let response = server.handle_message(init_msg).await.unwrap();
        assert_eq!(response["result"]["serverInfo"]["name"], "test-server");

        // tools/list
        let list_msg = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/list"
        });
        let response = server.handle_message(list_msg).await.unwrap();
        assert_eq!(response["result"]["tools"][0]["name"], "greet");
        assert_eq!(response["result"]["tools"][0]["inputSchema"]["type"], "object");

        // tools/call
        let call_msg = json!({
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": {
                "name": "greet",
                "arguments": {
                    "name": "Alice"
                }
            }
        });
        let response = server.handle_message(call_msg).await.unwrap();
        assert_eq!(response["id"], 3);
        assert_eq!(response["result"]["content"][0]["text"], "Hello, Alice!");
    }

    #[tokio::test]
    async fn unknown_method_returns_method_not_found() {
        let server = greet_server();
        let msg = json!({"jsonrpc": "2.0", "id": 7, "method": "does/not/exist"});

        let response = server.handle_message(msg).await.unwrap();
        assert_eq!(response["id"], 7);
        assert_eq!(response["error"]["code"], -32601);
        assert_eq!(
            response["error"]["message"],
            "Method 'does/not/exist' not found"
        );
    }

    #[tokio::test]
    async fn malformed_requests_are_errors() {
        let server = greet_server();

        let no_method = json!({"jsonrpc": "2.0", "id": 1});
        assert!(server.handle_message(no_method).await.is_err());

        let no_params = json!({"jsonrpc": "2.0", "id": 2, "method": "tools/call"});
        assert!(server.handle_message(no_params).await.is_err());

        let unknown_tool = json!({
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": {"name": "nope"}
        });
        assert!(server.handle_message(unknown_tool).await.is_err());
    }

    #[test]
    fn builder_defaults_version_and_keeps_tools() {
        let server = SdkMcpServerBuilder::new("built")
            .tool(create_simple_tool(
                "noop",
                "Does nothing",
                greet_schema(),
                |_args| async move { Ok(String::from("ok")) },
            ))
            .build();

        assert_eq!(server.name, "built");
        assert_eq!(server.version, "1.0.0");
        assert_eq!(server.tools.len(), 1);
        assert_eq!(server.tools[0].name, "noop");
    }
}