1use async_trait::async_trait;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::time::Duration;
23use tokio::sync::mpsc;
24
25use crate::mcp::error::{McpError, McpResult};
26use crate::mcp::types::{ConnectionOptions, TransportType};
27
28pub type RequestId = serde_json::Value;
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct McpRequest {
34 pub jsonrpc: String,
36 pub id: RequestId,
38 pub method: String,
40 #[serde(skip_serializing_if = "Option::is_none")]
42 pub params: Option<serde_json::Value>,
43}
44
45impl McpRequest {
46 pub fn new(id: impl Into<RequestId>, method: impl Into<String>) -> Self {
48 Self {
49 jsonrpc: "2.0".to_string(),
50 id: id.into(),
51 method: method.into(),
52 params: None,
53 }
54 }
55
56 pub fn with_params(
58 id: impl Into<RequestId>,
59 method: impl Into<String>,
60 params: serde_json::Value,
61 ) -> Self {
62 Self {
63 jsonrpc: "2.0".to_string(),
64 id: id.into(),
65 method: method.into(),
66 params: Some(params),
67 }
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct McpResponse {
74 pub jsonrpc: String,
76 pub id: RequestId,
78 #[serde(skip_serializing_if = "Option::is_none")]
80 pub result: Option<serde_json::Value>,
81 #[serde(skip_serializing_if = "Option::is_none")]
83 pub error: Option<McpErrorData>,
84}
85
86impl McpResponse {
87 pub fn success(id: RequestId, result: serde_json::Value) -> Self {
89 Self {
90 jsonrpc: "2.0".to_string(),
91 id,
92 result: Some(result),
93 error: None,
94 }
95 }
96
97 pub fn error(id: RequestId, error: McpErrorData) -> Self {
99 Self {
100 jsonrpc: "2.0".to_string(),
101 id,
102 result: None,
103 error: Some(error),
104 }
105 }
106
107 pub fn is_error(&self) -> bool {
109 self.error.is_some()
110 }
111
112 pub fn into_result(self) -> McpResult<serde_json::Value> {
114 if let Some(error) = self.error {
115 Err(McpError::server(error.code, error.message, error.data))
116 } else {
117 self.result
118 .ok_or_else(|| McpError::protocol("Response contains neither result nor error"))
119 }
120 }
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct McpErrorData {
126 pub code: i32,
128 pub message: String,
130 #[serde(skip_serializing_if = "Option::is_none")]
132 pub data: Option<serde_json::Value>,
133}
134
135impl McpErrorData {
136 pub fn new(code: i32, message: impl Into<String>) -> Self {
138 Self {
139 code,
140 message: message.into(),
141 data: None,
142 }
143 }
144
145 pub fn with_data(code: i32, message: impl Into<String>, data: serde_json::Value) -> Self {
147 Self {
148 code,
149 message: message.into(),
150 data: Some(data),
151 }
152 }
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct McpNotification {
158 pub jsonrpc: String,
160 pub method: String,
162 #[serde(skip_serializing_if = "Option::is_none")]
164 pub params: Option<serde_json::Value>,
165}
166
167impl McpNotification {
168 pub fn new(method: impl Into<String>) -> Self {
170 Self {
171 jsonrpc: "2.0".to_string(),
172 method: method.into(),
173 params: None,
174 }
175 }
176
177 pub fn with_params(method: impl Into<String>, params: serde_json::Value) -> Self {
179 Self {
180 jsonrpc: "2.0".to_string(),
181 method: method.into(),
182 params: Some(params),
183 }
184 }
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
189#[serde(untagged)]
190pub enum McpMessage {
191 Request(McpRequest),
193 Response(McpResponse),
195 Notification(McpNotification),
197}
198
199impl McpMessage {
200 pub fn id(&self) -> Option<&RequestId> {
202 match self {
203 McpMessage::Request(req) => Some(&req.id),
204 McpMessage::Response(resp) => Some(&resp.id),
205 McpMessage::Notification(_) => None,
206 }
207 }
208
209 pub fn method(&self) -> Option<&str> {
211 match self {
212 McpMessage::Request(req) => Some(&req.method),
213 McpMessage::Response(_) => None,
214 McpMessage::Notification(notif) => Some(¬if.method),
215 }
216 }
217}
218
219#[derive(Debug, Clone)]
221pub enum TransportConfig {
222 Stdio {
224 command: String,
226 args: Vec<String>,
228 env: HashMap<String, String>,
230 cwd: Option<String>,
232 },
233 Http {
235 url: String,
237 headers: HashMap<String, String>,
239 },
240 Sse {
242 url: String,
244 headers: HashMap<String, String>,
246 },
247 WebSocket {
249 url: String,
251 headers: HashMap<String, String>,
253 },
254}
255
256impl TransportConfig {
257 pub fn transport_type(&self) -> TransportType {
259 match self {
260 TransportConfig::Stdio { .. } => TransportType::Stdio,
261 TransportConfig::Http { .. } => TransportType::Http,
262 TransportConfig::Sse { .. } => TransportType::Sse,
263 TransportConfig::WebSocket { .. } => TransportType::WebSocket,
264 }
265 }
266}
267
/// Lifecycle state reported by a transport; the default is `Disconnected`.
#[derive(Debug, Default)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum TransportState {
    /// No connection; this is the `Default` value.
    #[default]
    Disconnected,
    /// A connection attempt is under way.
    Connecting,
    /// The connection is up.
    Connected,
    /// The connection is being shut down.
    Closing,
    /// The transport is in a failed state.
    Error,
}
283
284#[derive(Debug, Clone)]
286pub enum TransportEvent {
287 Connecting,
289 Connected,
291 Disconnected { reason: Option<String> },
293 Error { error: String },
295 MessageReceived(Box<McpMessage>),
297}
298
299#[async_trait]
304pub trait Transport: Send + Sync {
305 fn transport_type(&self) -> TransportType;
307
308 fn state(&self) -> TransportState;
310
311 async fn connect(&mut self) -> McpResult<()>;
315
316 async fn disconnect(&mut self) -> McpResult<()>;
320
321 async fn send(&mut self, message: McpMessage) -> McpResult<()>;
325
326 async fn send_request(&mut self, request: McpRequest) -> McpResult<McpResponse>;
330
331 async fn send_request_with_timeout(
335 &mut self,
336 request: McpRequest,
337 timeout: Duration,
338 ) -> McpResult<McpResponse>;
339
340 fn subscribe(&self) -> mpsc::Receiver<TransportEvent>;
344
345 fn is_connected(&self) -> bool {
347 self.state() == TransportState::Connected
348 }
349}
350
351pub type BoxedTransport = Box<dyn Transport>;
353
354pub type SharedTransport = Arc<tokio::sync::Mutex<BoxedTransport>>;
356
357pub struct TransportFactory;
359
360impl TransportFactory {
361 pub fn create(
365 config: TransportConfig,
366 options: ConnectionOptions,
367 ) -> McpResult<BoxedTransport> {
368 match config {
369 TransportConfig::Stdio {
370 command,
371 args,
372 env,
373 cwd,
374 } => {
375 use super::stdio::{StdioConfig, StdioTransport};
376 Ok(Box::new(StdioTransport::new(
377 StdioConfig {
378 command,
379 args,
380 env,
381 cwd,
382 },
383 options,
384 )))
385 }
386 TransportConfig::Http { url, headers } => {
387 use super::http::{HttpConfig, HttpTransport};
388 Ok(Box::new(HttpTransport::new(
389 HttpConfig { url, headers },
390 options,
391 )))
392 }
393 TransportConfig::Sse { url, headers } => {
394 use super::http::{HttpConfig, HttpTransport};
396 Ok(Box::new(HttpTransport::new(
397 HttpConfig { url, headers },
398 options,
399 )))
400 }
401 TransportConfig::WebSocket { url, headers } => {
402 use super::websocket::{WebSocketConfig, WebSocketTransport};
403 Ok(Box::new(WebSocketTransport::new(
404 WebSocketConfig { url, headers },
405 options,
406 )))
407 }
408 }
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_request_new() {
        let req = McpRequest::new(serde_json::json!(1), "test/method");
        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.id, serde_json::json!(1));
        assert_eq!(req.method, "test/method");
        assert!(req.params.is_none());
    }

    #[test]
    fn test_mcp_request_with_params() {
        let params = serde_json::json!({"key": "value"});
        let req =
            McpRequest::with_params(serde_json::json!("req-1"), "test/method", params.clone());
        assert_eq!(req.params, Some(params));
    }

    /// `params: None` must be left out of the wire form rather than sent as `null`.
    #[test]
    fn test_mcp_request_serialization_omits_missing_params() {
        let req = McpRequest::new(serde_json::json!(7), "tools/list");
        let wire = serde_json::to_value(&req).unwrap();
        assert_eq!(
            wire,
            serde_json::json!({"jsonrpc": "2.0", "id": 7, "method": "tools/list"})
        );
    }

    #[test]
    fn test_mcp_response_success() {
        let result = serde_json::json!({"status": "ok"});
        let resp = McpResponse::success(serde_json::json!(1), result.clone());
        assert!(!resp.is_error());
        assert_eq!(resp.result, Some(result));
    }

    #[test]
    fn test_mcp_response_error() {
        let error = McpErrorData::new(-32600, "Invalid Request");
        let resp = McpResponse::error(serde_json::json!(1), error);
        assert!(resp.is_error());
        assert!(resp.result.is_none());
    }

    #[test]
    fn test_mcp_response_into_result() {
        let result = serde_json::json!({"data": 42});
        let resp = McpResponse::success(serde_json::json!(1), result.clone());
        let res = resp.into_result();
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), result);
    }

    #[test]
    fn test_mcp_response_into_result_error() {
        let error = McpErrorData::new(-32600, "Invalid Request");
        let resp = McpResponse::error(serde_json::json!(1), error);
        let res = resp.into_result();
        assert!(res.is_err());
    }

    /// A response carrying neither `result` nor `error` is a protocol violation.
    #[test]
    fn test_mcp_response_into_result_without_result_or_error() {
        let resp = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(1),
            result: None,
            error: None,
        };
        assert!(resp.into_result().is_err());
    }

    #[test]
    fn test_mcp_notification() {
        let notif = McpNotification::new("notifications/test");
        assert_eq!(notif.jsonrpc, "2.0");
        assert_eq!(notif.method, "notifications/test");
        assert!(notif.params.is_none());
    }

    #[test]
    fn test_mcp_notification_with_params() {
        let params = serde_json::json!({"event": "update"});
        let notif = McpNotification::with_params("notifications/test", params.clone());
        assert_eq!(notif.params, Some(params));
    }

    #[test]
    fn test_transport_config_type() {
        let stdio = TransportConfig::Stdio {
            command: "node".to_string(),
            args: vec![],
            env: HashMap::new(),
            cwd: None,
        };
        assert_eq!(stdio.transport_type(), TransportType::Stdio);

        let http = TransportConfig::Http {
            url: "http://localhost:8080".to_string(),
            headers: HashMap::new(),
        };
        assert_eq!(http.transport_type(), TransportType::Http);

        let sse = TransportConfig::Sse {
            url: "http://localhost:8080/sse".to_string(),
            headers: HashMap::new(),
        };
        assert_eq!(sse.transport_type(), TransportType::Sse);

        let ws = TransportConfig::WebSocket {
            url: "ws://localhost:8080".to_string(),
            headers: HashMap::new(),
        };
        assert_eq!(ws.transport_type(), TransportType::WebSocket);
    }

    #[test]
    fn test_transport_state_default() {
        let state = TransportState::default();
        assert_eq!(state, TransportState::Disconnected);
    }

    #[test]
    fn test_mcp_message_id() {
        let req = McpRequest::new(serde_json::json!(1), "test");
        let msg = McpMessage::Request(req);
        assert_eq!(msg.id(), Some(&serde_json::json!(1)));

        let resp = McpResponse::success(serde_json::json!(2), serde_json::json!({}));
        let msg = McpMessage::Response(resp);
        assert_eq!(msg.id(), Some(&serde_json::json!(2)));

        let notif = McpNotification::new("test");
        let msg = McpMessage::Notification(notif);
        assert!(msg.id().is_none());
    }

    #[test]
    fn test_mcp_message_method() {
        let req = McpRequest::new(serde_json::json!(1), "test/method");
        let msg = McpMessage::Request(req);
        assert_eq!(msg.method(), Some("test/method"));

        let resp = McpResponse::success(serde_json::json!(1), serde_json::json!({}));
        let msg = McpMessage::Response(resp);
        assert!(msg.method().is_none());

        let notif = McpNotification::new("notifications/test");
        let msg = McpMessage::Notification(notif);
        assert_eq!(msg.method(), Some("notifications/test"));
    }

    #[test]
    fn test_mcp_error_data() {
        let error = McpErrorData::new(-32600, "Invalid Request");
        assert_eq!(error.code, -32600);
        assert_eq!(error.message, "Invalid Request");
        assert!(error.data.is_none());

        let error_with_data = McpErrorData::with_data(
            -32602,
            "Invalid params",
            serde_json::json!({"field": "name"}),
        );
        assert!(error_with_data.data.is_some());
    }

    /// The untagged `McpMessage` must pick the right variant from the wire shape.
    #[test]
    fn test_mcp_message_deserializes_each_variant() {
        let req: McpMessage = serde_json::from_value(serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list"
        }))
        .unwrap();
        assert!(matches!(req, McpMessage::Request(_)));

        let resp: McpMessage = serde_json::from_value(serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"tools": []}
        }))
        .unwrap();
        assert!(matches!(resp, McpMessage::Response(_)));

        let err: McpMessage = serde_json::from_value(serde_json::json!({
            "jsonrpc": "2.0",
            "id": 2,
            "error": {"code": -32601, "message": "Method not found"}
        }))
        .unwrap();
        match err {
            McpMessage::Response(r) => assert!(r.is_error()),
            other => panic!("expected a response, got {other:?}"),
        }

        let notif: McpMessage = serde_json::from_value(serde_json::json!({
            "jsonrpc": "2.0",
            "method": "notifications/initialized"
        }))
        .unwrap();
        assert!(matches!(notif, McpMessage::Notification(_)));
    }
}
546}