// File: mcp_tools/common/protocol.rs

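//! MCP protocol implementation for MCP Tools: message envelope and payload
//! types, plus a handler that builds, parses, and validates protocol messages.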
2
3use super::*;
4use crate::{McpToolsError, Result};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use uuid::Uuid;
8
/// Protocol version that [`McpProtocol`] stamps on every message it builds
/// and compares against in [`McpProtocol::validate_version`].
pub const MCP_PROTOCOL_VERSION: &str = "1.0";
11
12/// MCP message types
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum McpMessageType {
16    // Initialization
17    Initialize,
18    InitializeResult,
19
20    // Capabilities
21    GetCapabilities,
22    CapabilitiesResult,
23
24    // Tool execution
25    CallTool,
26    ToolResult,
27
28    // Notifications
29    Notification,
30
31    // Errors
32    Error,
33}
34
35/// MCP protocol message
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct McpMessage {
38    /// Message ID (for request-response correlation)
39    pub id: Option<String>,
40
41    /// Message type
42    #[serde(rename = "type")]
43    pub message_type: McpMessageType,
44
45    /// Message payload
46    pub payload: serde_json::Value,
47
48    /// Protocol version
49    pub version: String,
50
51    /// Message metadata
52    #[serde(default)]
53    pub metadata: HashMap<String, serde_json::Value>,
54}
55
56/// Initialize request payload
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct InitializeRequest {
59    /// Client information
60    pub client_info: ClientInfo,
61
62    /// Client capabilities
63    pub capabilities: ClientCapabilities,
64
65    /// Protocol version
66    pub protocol_version: String,
67}
68
69/// Initialize response payload
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct InitializeResponse {
72    /// Server information
73    pub server_info: ServerInfo,
74
75    /// Server capabilities
76    pub capabilities: ServerCapabilities,
77
78    /// Protocol version
79    pub protocol_version: String,
80}
81
82/// Tool call request payload
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct CallToolRequest {
85    /// Tool name
86    pub name: String,
87
88    /// Tool arguments
89    pub arguments: serde_json::Value,
90
91    /// Session ID
92    pub session_id: String,
93}
94
95/// Tool call response payload
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct ToolResultResponse {
98    /// Tool execution result
99    pub content: Vec<McpContent>,
100
101    /// Whether the operation was successful
102    pub is_error: bool,
103
104    /// Error message if any
105    pub error: Option<String>,
106}
107
108/// Error response payload
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct ErrorResponse {
111    /// Error code
112    pub code: i32,
113
114    /// Error message
115    pub message: String,
116
117    /// Additional error data
118    pub data: Option<serde_json::Value>,
119}
120
/// Builds, parses and validates [`McpMessage`]s for a single protocol version.
///
/// Message IDs are drawn from an atomic counter. The type is therefore `Sync`,
/// and one `&McpProtocol` can generate IDs from several threads.
#[derive(Debug)]
pub struct McpProtocol {
    /// Version stamped on outgoing messages and compared against incoming
    /// ones by `validate_version`.
    version: String,

    /// Counter incremented on every generated message ID.
    message_counter: std::sync::atomic::AtomicU64,
}
129
130impl McpProtocol {
131    /// Create new MCP protocol handler
132    pub fn new() -> Self {
133        Self {
134            version: MCP_PROTOCOL_VERSION.to_string(),
135            message_counter: std::sync::atomic::AtomicU64::new(0),
136        }
137    }
138
139    /// Generate unique message ID
140    pub fn generate_message_id(&self) -> String {
141        let counter = self
142            .message_counter
143            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
144        format!(
145            "msg-{}-{}",
146            std::time::SystemTime::now()
147                .duration_since(std::time::UNIX_EPOCH)
148                .unwrap()
149                .as_millis(),
150            counter
151        )
152    }
153
154    /// Create initialize request message
155    pub fn create_initialize_request(
156        &self,
157        client_info: ClientInfo,
158        capabilities: ClientCapabilities,
159    ) -> McpMessage {
160        let payload = InitializeRequest {
161            client_info,
162            capabilities,
163            protocol_version: self.version.clone(),
164        };
165
166        McpMessage {
167            id: Some(self.generate_message_id()),
168            message_type: McpMessageType::Initialize,
169            payload: serde_json::to_value(payload).unwrap(),
170            version: self.version.clone(),
171            metadata: HashMap::new(),
172        }
173    }
174
175    /// Create initialize response message
176    pub fn create_initialize_response(
177        &self,
178        request_id: &str,
179        server_info: ServerInfo,
180        capabilities: ServerCapabilities,
181    ) -> McpMessage {
182        let payload = InitializeResponse {
183            server_info,
184            capabilities,
185            protocol_version: self.version.clone(),
186        };
187
188        McpMessage {
189            id: Some(request_id.to_string()),
190            message_type: McpMessageType::InitializeResult,
191            payload: serde_json::to_value(payload).unwrap(),
192            version: self.version.clone(),
193            metadata: HashMap::new(),
194        }
195    }
196
197    /// Create tool call request message
198    pub fn create_tool_call_request(
199        &self,
200        tool_name: &str,
201        arguments: serde_json::Value,
202        session_id: &str,
203    ) -> McpMessage {
204        let payload = CallToolRequest {
205            name: tool_name.to_string(),
206            arguments,
207            session_id: session_id.to_string(),
208        };
209
210        McpMessage {
211            id: Some(self.generate_message_id()),
212            message_type: McpMessageType::CallTool,
213            payload: serde_json::to_value(payload).unwrap(),
214            version: self.version.clone(),
215            metadata: HashMap::new(),
216        }
217    }
218
219    /// Create tool result response message
220    pub fn create_tool_result_response(
221        &self,
222        request_id: &str,
223        content: Vec<McpContent>,
224        is_error: bool,
225        error: Option<String>,
226    ) -> McpMessage {
227        let payload = ToolResultResponse {
228            content,
229            is_error,
230            error,
231        };
232
233        McpMessage {
234            id: Some(request_id.to_string()),
235            message_type: McpMessageType::ToolResult,
236            payload: serde_json::to_value(payload).unwrap(),
237            version: self.version.clone(),
238            metadata: HashMap::new(),
239        }
240    }
241
242    /// Create error response message
243    pub fn create_error_response(
244        &self,
245        request_id: Option<&str>,
246        code: i32,
247        message: &str,
248        data: Option<serde_json::Value>,
249    ) -> McpMessage {
250        let payload = ErrorResponse {
251            code,
252            message: message.to_string(),
253            data,
254        };
255
256        McpMessage {
257            id: request_id.map(|s| s.to_string()),
258            message_type: McpMessageType::Error,
259            payload: serde_json::to_value(payload).unwrap(),
260            version: self.version.clone(),
261            metadata: HashMap::new(),
262        }
263    }
264
265    /// Parse MCP message from JSON
266    pub fn parse_message(&self, json: &str) -> Result<McpMessage> {
267        serde_json::from_str(json).map_err(|e| McpToolsError::Serialization(e))
268    }
269
270    /// Serialize MCP message to JSON
271    pub fn serialize_message(&self, message: &McpMessage) -> Result<String> {
272        serde_json::to_string(message).map_err(|e| McpToolsError::Serialization(e))
273    }
274
275    /// Validate message protocol version
276    pub fn validate_version(&self, message: &McpMessage) -> Result<()> {
277        if message.version != self.version {
278            return Err(McpToolsError::Server(format!(
279                "Protocol version mismatch: expected {}, got {}",
280                self.version, message.version
281            )));
282        }
283        Ok(())
284    }
285}
286
287impl Default for McpProtocol {
288    fn default() -> Self {
289        Self::new()
290    }
291}
292
/// Numeric codes carried in the `code` field of an `ErrorResponse`.
///
/// The first group matches the reserved error codes of the JSON-RPC 2.0
/// specification. The second group lives in `-32000..=-32099`, the range
/// JSON-RPC leaves for implementation-defined server errors.
pub mod error_codes {
    // ---- JSON-RPC 2.0 reserved codes ----

    /// The received text could not be parsed as JSON.
    pub const PARSE_ERROR: i32 = -32700;
    /// The JSON was valid but is not an acceptable request.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The requested method does not exist.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// The request's parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32602;
    /// An internal error occurred while handling the request.
    pub const INTERNAL_ERROR: i32 = -32603;

    // ---- MCP Tools specific codes ----

    /// The caller is not permitted to perform the operation.
    pub const PERMISSION_DENIED: i32 = -32000;
    /// No tool with the requested name exists.
    pub const TOOL_NOT_FOUND: i32 = -32001;
    /// The tool was found but its execution failed.
    pub const TOOL_EXECUTION_ERROR: i32 = -32002;
    /// A session-related failure occurred.
    pub const SESSION_ERROR: i32 = -32003;
}
307
// Unit tests for message construction, (de)serialization and version checks.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protocol_creation() {
        let protocol = McpProtocol::new();
        assert_eq!(protocol.version, MCP_PROTOCOL_VERSION);
    }

    #[test]
    fn test_message_id_generation() {
        let protocol = McpProtocol::new();
        let id1 = protocol.generate_message_id();
        let id2 = protocol.generate_message_id();

        assert_ne!(id1, id2);
        assert!(id1.starts_with("msg-"));
        assert!(id2.starts_with("msg-"));
    }

    #[test]
    fn test_initialize_request_creation() {
        let protocol = McpProtocol::new();
        let client_info = ClientInfo {
            name: "Test Client".to_string(),
            version: "1.0.0".to_string(),
            description: "Test".to_string(),
        };
        let capabilities = ClientCapabilities {
            content_types: vec!["text".to_string()],
            features: vec!["test".to_string()],
            info: client_info.clone(),
        };

        let message = protocol.create_initialize_request(client_info, capabilities);

        assert_eq!(message.message_type, McpMessageType::Initialize);
        assert!(message.id.is_some());
        assert_eq!(message.version, MCP_PROTOCOL_VERSION);
    }

    #[test]
    fn test_tool_call_request_creation() {
        let protocol = McpProtocol::new();
        let message = protocol.create_tool_call_request(
            "echo",
            serde_json::json!({"text": "hi"}),
            "session-1",
        );

        assert_eq!(message.message_type, McpMessageType::CallTool);
        assert!(message.id.is_some());
        assert_eq!(message.payload["name"], "echo");
        assert_eq!(message.payload["session_id"], "session-1");
        assert_eq!(message.payload["arguments"]["text"], "hi");
    }

    #[test]
    fn test_error_response_creation() {
        let protocol = McpProtocol::new();
        let message = protocol.create_error_response(
            Some("req-1"),
            error_codes::TOOL_NOT_FOUND,
            "no such tool",
            None,
        );

        assert_eq!(message.id.as_deref(), Some("req-1"));
        assert_eq!(message.message_type, McpMessageType::Error);
        assert_eq!(message.payload["code"], error_codes::TOOL_NOT_FOUND);
        assert_eq!(message.payload["message"], "no such tool");
        assert!(message.payload["data"].is_null());

        // An error with no originating request carries no correlation ID.
        let orphan =
            protocol.create_error_response(None, error_codes::PARSE_ERROR, "bad json", None);
        assert!(orphan.id.is_none());
    }

    #[test]
    fn test_message_serialization() {
        let protocol = McpProtocol::new();
        let message = McpMessage {
            id: Some("test-id".to_string()),
            message_type: McpMessageType::Notification,
            payload: serde_json::json!({"test": "data"}),
            version: MCP_PROTOCOL_VERSION.to_string(),
            metadata: HashMap::new(),
        };

        let json = protocol.serialize_message(&message).unwrap();
        let parsed = protocol.parse_message(&json).unwrap();

        assert_eq!(parsed.id, message.id);
        assert_eq!(parsed.message_type, message.message_type);
        assert_eq!(parsed.version, message.version);
    }

    #[test]
    fn test_message_type_uses_snake_case_on_the_wire() {
        let protocol = McpProtocol::new();
        let message = protocol.create_tool_call_request("echo", serde_json::json!({}), "s");

        let json = protocol.serialize_message(&message).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();

        assert_eq!(value["type"], "call_tool");
        assert_eq!(value["version"], MCP_PROTOCOL_VERSION);
    }

    #[test]
    fn test_parse_rejects_malformed_input() {
        let protocol = McpProtocol::new();

        assert!(protocol.parse_message("not json").is_err());
        // Missing the required `type` field.
        assert!(protocol
            .parse_message(r#"{"id":"x","payload":{},"version":"1.0"}"#)
            .is_err());
        // Unknown message type.
        assert!(protocol
            .parse_message(r#"{"id":"x","type":"bogus","payload":{},"version":"1.0"}"#)
            .is_err());
    }

    #[test]
    fn test_version_validation() {
        let protocol = McpProtocol::new();
        let mut message =
            protocol.create_error_response(None, error_codes::INTERNAL_ERROR, "boom", None);
        assert!(protocol.validate_version(&message).is_ok());

        message.version = "0.0-unsupported".to_string();
        assert!(protocol.validate_version(&message).is_err());
    }
}