1use super::*;
4use crate::{McpToolsError, Result};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use uuid::Uuid;
8
/// Protocol version this module speaks.
///
/// `McpProtocol::new` adopts it, every message built by `McpProtocol`'s `create_*`
/// methods carries it, and `McpProtocol::validate_version` compares incoming
/// messages against it as an exact string.
pub const MCP_PROTOCOL_VERSION: &str = "1.0";
11
/// Kind of an [`McpMessage`], serialized under the envelope's `type` key.
///
/// Wire names are the snake_case form of the variant names, e.g.
/// `InitializeResult` -> `"initialize_result"`, `CallTool` -> `"call_tool"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum McpMessageType {
    /// Handshake request; `McpProtocol::create_initialize_request` uses an
    /// [`InitializeRequest`] payload.
    Initialize,
    /// Handshake reply; `McpProtocol::create_initialize_response` uses an
    /// [`InitializeResponse`] payload.
    InitializeResult,

    /// Capabilities query. No builder in this module, so the payload shape is not
    /// defined here.
    GetCapabilities,
    /// Reply to `GetCapabilities`. No builder in this module, so the payload shape is
    /// not defined here.
    CapabilitiesResult,

    /// Tool invocation; `McpProtocol::create_tool_call_request` uses a
    /// [`CallToolRequest`] payload.
    CallTool,
    /// Tool outcome; `McpProtocol::create_tool_result_response` uses a
    /// [`ToolResultResponse`] payload.
    ToolResult,

    /// Message that is not tied to a typed payload in this module.
    Notification,

    /// Failure report; `McpProtocol::create_error_response` uses an [`ErrorResponse`]
    /// payload.
    Error,
}
34
/// Envelope wrapping every message built or parsed by `McpProtocol`.
///
/// Its JSON form has the keys `id`, `type`, `payload`, `version` and `metadata`,
/// serialized in that order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpMessage {
    /// Correlation id.
    ///
    /// Requests built by `McpProtocol` get a generated id, and responses echo the
    /// request's id. Error responses may carry `None`, which is serialized as `null`.
    pub id: Option<String>,

    /// Kind of message; stored under the JSON key `type` rather than `message_type`.
    #[serde(rename = "type")]
    pub message_type: McpMessageType,

    /// Message body as raw JSON; its expected shape depends on `message_type`.
    pub payload: serde_json::Value,

    /// Protocol version the sender claims to speak.
    pub version: String,

    /// Free-form extra data. Deserializes to an empty map when the key is absent.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
55
/// Payload of an `initialize` message: who the client is and what it supports.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeRequest {
    /// Identity of the connecting client.
    pub client_info: ClientInfo,

    /// Capabilities the client advertises.
    pub capabilities: ClientCapabilities,

    /// Protocol version requested by the client. `McpProtocol` fills in its own
    /// version here.
    pub protocol_version: String,
}
68
/// Payload of an `initialize_result` message: the server's identity and capabilities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeResponse {
    /// Identity of the responding server.
    pub server_info: ServerInfo,

    /// Capabilities the server advertises.
    pub capabilities: ServerCapabilities,

    /// Protocol version the server answers with. `McpProtocol` fills in its own
    /// version here.
    pub protocol_version: String,
}
81
/// Payload of a `call_tool` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CallToolRequest {
    /// Name of the tool to invoke.
    pub name: String,

    /// Tool arguments as arbitrary JSON. This module does not validate their shape.
    pub arguments: serde_json::Value,

    /// Session identifier supplied by the caller of
    /// `McpProtocol::create_tool_call_request`.
    pub session_id: String,
}
94
/// Payload of a `tool_result` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultResponse {
    /// Output produced by the tool.
    pub content: Vec<McpContent>,

    /// Whether the tool reported a failure.
    pub is_error: bool,

    /// Optional error description, serialized as `null` when `None`.
    ///
    /// NOTE(review): nothing ties this to `is_error`; the builder stores both as given.
    pub error: Option<String>,
}
107
/// Payload of an `error` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
    /// Numeric error code, e.g. one of the constants in `error_codes`.
    pub code: i32,

    /// Human-readable description of the failure.
    pub message: String,

    /// Optional structured details, serialized as `null` when `None`.
    pub data: Option<serde_json::Value>,
}
120
/// Builds, (de)serializes and version-checks [`McpMessage`]s.
///
/// Each handler stores the protocol version it speaks and an atomic counter used when
/// generating message ids. Because the counter is atomic, ids can be produced through a
/// shared `&McpProtocol`.
#[derive(Debug)]
pub struct McpProtocol {
    /// Version written into outgoing messages and required of incoming ones.
    version: String,

    /// Incremented on every generated message id so ids stay distinct within one
    /// millisecond.
    message_counter: std::sync::atomic::AtomicU64,
}
129
130impl McpProtocol {
131 pub fn new() -> Self {
133 Self {
134 version: MCP_PROTOCOL_VERSION.to_string(),
135 message_counter: std::sync::atomic::AtomicU64::new(0),
136 }
137 }
138
139 pub fn generate_message_id(&self) -> String {
141 let counter = self
142 .message_counter
143 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
144 format!(
145 "msg-{}-{}",
146 std::time::SystemTime::now()
147 .duration_since(std::time::UNIX_EPOCH)
148 .unwrap()
149 .as_millis(),
150 counter
151 )
152 }
153
154 pub fn create_initialize_request(
156 &self,
157 client_info: ClientInfo,
158 capabilities: ClientCapabilities,
159 ) -> McpMessage {
160 let payload = InitializeRequest {
161 client_info,
162 capabilities,
163 protocol_version: self.version.clone(),
164 };
165
166 McpMessage {
167 id: Some(self.generate_message_id()),
168 message_type: McpMessageType::Initialize,
169 payload: serde_json::to_value(payload).unwrap(),
170 version: self.version.clone(),
171 metadata: HashMap::new(),
172 }
173 }
174
175 pub fn create_initialize_response(
177 &self,
178 request_id: &str,
179 server_info: ServerInfo,
180 capabilities: ServerCapabilities,
181 ) -> McpMessage {
182 let payload = InitializeResponse {
183 server_info,
184 capabilities,
185 protocol_version: self.version.clone(),
186 };
187
188 McpMessage {
189 id: Some(request_id.to_string()),
190 message_type: McpMessageType::InitializeResult,
191 payload: serde_json::to_value(payload).unwrap(),
192 version: self.version.clone(),
193 metadata: HashMap::new(),
194 }
195 }
196
197 pub fn create_tool_call_request(
199 &self,
200 tool_name: &str,
201 arguments: serde_json::Value,
202 session_id: &str,
203 ) -> McpMessage {
204 let payload = CallToolRequest {
205 name: tool_name.to_string(),
206 arguments,
207 session_id: session_id.to_string(),
208 };
209
210 McpMessage {
211 id: Some(self.generate_message_id()),
212 message_type: McpMessageType::CallTool,
213 payload: serde_json::to_value(payload).unwrap(),
214 version: self.version.clone(),
215 metadata: HashMap::new(),
216 }
217 }
218
219 pub fn create_tool_result_response(
221 &self,
222 request_id: &str,
223 content: Vec<McpContent>,
224 is_error: bool,
225 error: Option<String>,
226 ) -> McpMessage {
227 let payload = ToolResultResponse {
228 content,
229 is_error,
230 error,
231 };
232
233 McpMessage {
234 id: Some(request_id.to_string()),
235 message_type: McpMessageType::ToolResult,
236 payload: serde_json::to_value(payload).unwrap(),
237 version: self.version.clone(),
238 metadata: HashMap::new(),
239 }
240 }
241
242 pub fn create_error_response(
244 &self,
245 request_id: Option<&str>,
246 code: i32,
247 message: &str,
248 data: Option<serde_json::Value>,
249 ) -> McpMessage {
250 let payload = ErrorResponse {
251 code,
252 message: message.to_string(),
253 data,
254 };
255
256 McpMessage {
257 id: request_id.map(|s| s.to_string()),
258 message_type: McpMessageType::Error,
259 payload: serde_json::to_value(payload).unwrap(),
260 version: self.version.clone(),
261 metadata: HashMap::new(),
262 }
263 }
264
265 pub fn parse_message(&self, json: &str) -> Result<McpMessage> {
267 serde_json::from_str(json).map_err(|e| McpToolsError::Serialization(e))
268 }
269
270 pub fn serialize_message(&self, message: &McpMessage) -> Result<String> {
272 serde_json::to_string(message).map_err(|e| McpToolsError::Serialization(e))
273 }
274
275 pub fn validate_version(&self, message: &McpMessage) -> Result<()> {
277 if message.version != self.version {
278 return Err(McpToolsError::Server(format!(
279 "Protocol version mismatch: expected {}, got {}",
280 self.version, message.version
281 )));
282 }
283 Ok(())
284 }
285}
286
287impl Default for McpProtocol {
288 fn default() -> Self {
289 Self::new()
290 }
291}
292
/// Numeric error codes, typed `i32` to match `ErrorResponse::code`.
pub mod error_codes {
    // The JSON-RPC 2.0 reserved error codes.

    /// Invalid JSON was received.
    pub const PARSE_ERROR: i32 = -32700;
    /// The JSON received is not a valid request object.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The requested method does not exist.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// The method parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32602;
    /// Internal error.
    pub const INTERNAL_ERROR: i32 = -32603;

    // Application-specific codes, placed in JSON-RPC 2.0's implementation-defined
    // server-error range (-32000 to -32099).

    /// The caller is not allowed to perform the operation.
    pub const PERMISSION_DENIED: i32 = -32000;
    /// No tool with the requested name exists.
    pub const TOOL_NOT_FOUND: i32 = -32001;
    /// The tool was found but its execution failed.
    pub const TOOL_EXECUTION_ERROR: i32 = -32002;
    /// A session-related failure.
    pub const SESSION_ERROR: i32 = -32003;
}
307
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_protocol_creation() {
        let protocol = McpProtocol::new();
        assert_eq!(protocol.version, MCP_PROTOCOL_VERSION);
    }

    #[test]
    fn test_message_id_generation() {
        let protocol = McpProtocol::new();
        let id1 = protocol.generate_message_id();
        let id2 = protocol.generate_message_id();

        assert_ne!(id1, id2);
        assert!(id1.starts_with("msg-"));
        assert!(id2.starts_with("msg-"));
    }

    #[test]
    fn test_initialize_request_creation() {
        let protocol = McpProtocol::new();
        let client_info = ClientInfo {
            name: "Test Client".to_string(),
            version: "1.0.0".to_string(),
            description: "Test".to_string(),
        };
        let capabilities = ClientCapabilities {
            content_types: vec!["text".to_string()],
            features: vec!["test".to_string()],
            info: client_info.clone(),
        };

        let message = protocol.create_initialize_request(client_info, capabilities);

        assert_eq!(message.message_type, McpMessageType::Initialize);
        assert!(message.id.is_some());
        assert_eq!(message.version, MCP_PROTOCOL_VERSION);
    }

    #[test]
    fn test_message_serialization() {
        let protocol = McpProtocol::new();
        let message = McpMessage {
            id: Some("test-id".to_string()),
            message_type: McpMessageType::Notification,
            payload: serde_json::json!({"test": "data"}),
            version: MCP_PROTOCOL_VERSION.to_string(),
            metadata: HashMap::new(),
        };

        let json = protocol.serialize_message(&message).unwrap();
        let parsed = protocol.parse_message(&json).unwrap();

        assert_eq!(parsed.id, message.id);
        assert_eq!(parsed.message_type, message.message_type);
        assert_eq!(parsed.version, message.version);
    }

    /// The envelope's `message_type` must appear on the wire as `"type"` in snake_case.
    #[test]
    fn test_message_type_wire_name() {
        let protocol = McpProtocol::new();
        let message =
            protocol.create_tool_call_request("echo", serde_json::json!({}), "session-1");

        let json = protocol.serialize_message(&message).unwrap();

        assert!(json.contains(r#""type":"call_tool""#));
    }

    /// `metadata` is `#[serde(default)]`, so incoming JSON may omit it.
    #[test]
    fn test_parse_message_without_metadata() {
        let protocol = McpProtocol::new();
        let json = r#"{"id":null,"type":"notification","payload":{},"version":"1.0"}"#;

        let message = protocol.parse_message(json).unwrap();

        assert!(message.id.is_none());
        assert_eq!(message.message_type, McpMessageType::Notification);
        assert!(message.metadata.is_empty());
    }

    #[test]
    fn test_parse_message_rejects_invalid_json() {
        let protocol = McpProtocol::new();
        assert!(protocol.parse_message("not json").is_err());
    }

    #[test]
    fn test_tool_call_request_payload() {
        let protocol = McpProtocol::new();
        let message = protocol.create_tool_call_request(
            "echo",
            serde_json::json!({"text": "hi"}),
            "session-1",
        );

        assert_eq!(message.message_type, McpMessageType::CallTool);
        assert!(message.id.is_some());
        let request: CallToolRequest = serde_json::from_value(message.payload).unwrap();
        assert_eq!(request.name, "echo");
        assert_eq!(request.session_id, "session-1");
        assert_eq!(request.arguments, serde_json::json!({"text": "hi"}));
    }

    #[test]
    fn test_tool_result_response_echoes_request_id() {
        let protocol = McpProtocol::new();
        let message = protocol.create_tool_result_response("req-7", Vec::new(), false, None);

        assert_eq!(message.id.as_deref(), Some("req-7"));
        assert_eq!(message.message_type, McpMessageType::ToolResult);
        assert_eq!(message.payload["is_error"].as_bool(), Some(false));
    }

    #[test]
    fn test_error_response_without_request_id() {
        let protocol = McpProtocol::new();
        let message = protocol.create_error_response(
            None,
            error_codes::TOOL_NOT_FOUND,
            "no such tool",
            None,
        );

        assert_eq!(message.message_type, McpMessageType::Error);
        assert!(message.id.is_none());
        assert_eq!(
            message.payload["code"].as_i64(),
            Some(i64::from(error_codes::TOOL_NOT_FOUND))
        );
        assert_eq!(message.payload["message"].as_str(), Some("no such tool"));
    }

    #[test]
    fn test_validate_version_accepts_matching_version() {
        let protocol = McpProtocol::new();
        let message =
            protocol.create_tool_call_request("echo", serde_json::json!({}), "session-1");

        assert!(protocol.validate_version(&message).is_ok());
    }

    #[test]
    fn test_validate_version_rejects_mismatch() {
        let protocol = McpProtocol::new();
        let mut message =
            protocol.create_tool_call_request("echo", serde_json::json!({}), "session-1");
        message.version = "0.0".to_string();

        assert!(matches!(
            protocol.validate_version(&message),
            Err(McpToolsError::Server(_))
        ));
    }
}