mcp_protocol_sdk/transport/
websocket.rs

1//! WebSocket transport implementation for MCP
2//!
3//! This module provides WebSocket-based transport for MCP communication,
4//! offering bidirectional, real-time communication between clients and servers.
5
6use async_trait::async_trait;
7use futures_util::{
8    sink::SinkExt,
9    stream::{SplitSink, SplitStream, StreamExt},
10};
11use serde_json::Value;
12use std::{collections::HashMap, sync::Arc, time::Duration};
13use tokio::{
14    net::{TcpListener, TcpStream},
15    sync::{Mutex, RwLock, broadcast, mpsc},
16    time::timeout,
17};
18use tokio_tungstenite::{
19    MaybeTlsStream, WebSocketStream, accept_async, connect_async, tungstenite::Message,
20};
21use url::Url;
22
23use crate::core::error::{McpError, McpResult};
24use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
25use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
26
// Type alias for the server-side request handler slot (also keeps clippy's
// type-complexity lint quiet).
//
// Layers, outermost first:
// - `Arc<RwLock<..>>`: the slot is cloned into the accept loop and handed to every
//   per-connection task, and can be replaced while the server is running.
// - `Option<..>`: `None` until a handler has been installed.
// - `Arc<dyn Fn ..>`: the handler itself; it receives a request and returns a
//   oneshot receiver on which the response is later delivered.
type RequestHandler = Arc<
    RwLock<
        Option<
            Arc<
                dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
                    + Send
                    + Sync,
            >,
        >,
    >,
>;
39
40// ============================================================================
41// WebSocket Client Transport
42// ============================================================================
43
/// WebSocket transport for MCP clients
///
/// This transport communicates with an MCP server via WebSocket connections,
/// providing bidirectional real-time communication for both requests and notifications.
///
/// The socket is split on construction: the write half stays in this struct, while the
/// read half is owned by a background task that routes incoming frames either to the
/// request that is waiting for them or to the notification queue.
pub struct WebSocketClientTransport {
    /// Write half of the WebSocket; reset to `None` by `close`.
    ws_sender: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
    /// Requests awaiting a response, keyed by JSON-RPC id. The background reader removes
    /// an entry when a response with the same id arrives and completes its oneshot.
    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
    /// Receiving end of the queue the background reader pushes notifications into;
    /// reset to `None` by `close`.
    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
    /// Transport settings; this type reads `connect_timeout_ms` and `read_timeout_ms`.
    config: TransportConfig,
    /// Connection state, updated by the background reader task and by `close`.
    state: Arc<RwLock<ConnectionState>>,
    /// URL the transport was created with, kept for `connection_info`.
    url: String,
    /// Handle of the background reader task; aborted and cleared by `close`.
    message_handler: Option<tokio::task::JoinHandle<()>>,
}
57
58impl WebSocketClientTransport {
59    /// Create a new WebSocket client transport
60    ///
61    /// # Arguments
62    /// * `url` - WebSocket URL to connect to (e.g., "ws://localhost:8080/mcp")
63    ///
64    /// # Returns
65    /// Result containing the transport or an error
66    pub async fn new<S: AsRef<str>>(url: S) -> McpResult<Self> {
67        Self::with_config(url, TransportConfig::default()).await
68    }
69
70    /// Create a new WebSocket client transport with custom configuration
71    ///
72    /// # Arguments
73    /// * `url` - WebSocket URL to connect to
74    /// * `config` - Transport configuration
75    ///
76    /// # Returns
77    /// Result containing the transport or an error
78    pub async fn with_config<S: AsRef<str>>(url: S, config: TransportConfig) -> McpResult<Self> {
79        let url_str = url.as_ref();
80
81        // Validate URL format
82        let _url_parsed = Url::parse(url_str)
83            .map_err(|e| McpError::WebSocket(format!("Invalid WebSocket URL: {e}")))?;
84
85        tracing::debug!("Connecting to WebSocket: {}", url_str);
86
87        // Connect to WebSocket with timeout
88        let connect_timeout = Duration::from_millis(config.connect_timeout_ms.unwrap_or(30_000));
89
90        let (ws_stream, _) = timeout(connect_timeout, connect_async(url_str))
91            .await
92            .map_err(|_| McpError::WebSocket("Connection timeout".to_string()))?
93            .map_err(|e| McpError::WebSocket(format!("Failed to connect: {e}")))?;
94
95        let (ws_sender, ws_receiver) = ws_stream.split();
96
97        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
98        let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
99        let state = Arc::new(RwLock::new(ConnectionState::Connected));
100
101        // Start message handling task
102        let message_handler = tokio::spawn(Self::handle_messages(
103            ws_receiver,
104            pending_requests.clone(),
105            notification_sender,
106            state.clone(),
107        ));
108
109        Ok(Self {
110            ws_sender: Some(ws_sender),
111            pending_requests,
112            notification_receiver: Some(notification_receiver),
113            config,
114            state,
115            url: url_str.to_string(),
116            message_handler: Some(message_handler),
117        })
118    }
119
120    async fn handle_messages(
121        mut ws_receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
122        pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
123        notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
124        state: Arc<RwLock<ConnectionState>>,
125    ) {
126        while let Some(message) = ws_receiver.next().await {
127            match message {
128                Ok(Message::Text(text)) => {
129                    tracing::trace!("Received WebSocket message: {}", text);
130
131                    // Try to parse as response first
132                    if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&text) {
133                        let mut pending = pending_requests.lock().await;
134                        if let Some(sender) = pending.remove(&response.id) {
135                            if sender.send(response).is_err() {
136                                tracing::warn!("Failed to send response to waiting request");
137                            }
138                        } else {
139                            tracing::warn!(
140                                "Received response for unknown request ID: {:?}",
141                                response.id
142                            );
143                        }
144                    }
145                    // Try to parse as notification
146                    else if let Ok(notification) =
147                        serde_json::from_str::<JsonRpcNotification>(&text)
148                    {
149                        if notification_sender.send(notification).is_err() {
150                            tracing::debug!("Notification receiver dropped");
151                            break;
152                        }
153                    } else {
154                        tracing::warn!("Failed to parse WebSocket message: {}", text);
155                    }
156                }
157                Ok(Message::Close(_)) => {
158                    tracing::info!("WebSocket connection closed");
159                    *state.write().await = ConnectionState::Disconnected;
160                    break;
161                }
162                Ok(Message::Ping(_data)) => {
163                    tracing::trace!("Received WebSocket ping");
164                    // Pong responses are handled automatically by tungstenite
165                }
166                Ok(Message::Pong(_)) => {
167                    tracing::trace!("Received WebSocket pong");
168                }
169                Ok(Message::Binary(_)) => {
170                    tracing::warn!("Received unexpected binary WebSocket message");
171                }
172                Ok(Message::Frame(_)) => {
173                    tracing::trace!("Received WebSocket frame (internal)");
174                    // Frame messages are internal to tungstenite
175                }
176                Err(e) => {
177                    tracing::error!("WebSocket error: {}", e);
178                    *state.write().await = ConnectionState::Error(e.to_string());
179                    break;
180                }
181            }
182        }
183
184        tracing::debug!("WebSocket message handler exiting");
185    }
186
187    async fn send_message(&mut self, message: Message) -> McpResult<()> {
188        if let Some(ref mut sender) = self.ws_sender {
189            sender
190                .send(message)
191                .await
192                .map_err(|e| McpError::WebSocket(format!("Failed to send message: {e}")))?;
193        } else {
194            return Err(McpError::WebSocket("WebSocket not connected".to_string()));
195        }
196        Ok(())
197    }
198}
199
200#[async_trait]
201impl Transport for WebSocketClientTransport {
202    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
203        let (sender, receiver) = tokio::sync::oneshot::channel();
204
205        // Store the pending request
206        {
207            let mut pending = self.pending_requests.lock().await;
208            pending.insert(request.id.clone(), sender);
209        }
210
211        // Send the request
212        let request_text =
213            serde_json::to_string(&request).map_err(|e| McpError::Serialization(e.to_string()))?;
214
215        tracing::trace!("Sending WebSocket request: {}", request_text);
216
217        self.send_message(Message::Text(request_text.into()))
218            .await?;
219
220        // Wait for response with timeout
221        let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
222
223        let response = timeout(timeout_duration, receiver)
224            .await
225            .map_err(|_| McpError::WebSocket("Request timeout".to_string()))?
226            .map_err(|_| McpError::WebSocket("Response channel closed".to_string()))?;
227
228        Ok(response)
229    }
230
231    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
232        let notification_text = serde_json::to_string(&notification)
233            .map_err(|e| McpError::Serialization(e.to_string()))?;
234
235        tracing::trace!("Sending WebSocket notification: {}", notification_text);
236
237        self.send_message(Message::Text(notification_text.into()))
238            .await
239    }
240
241    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
242        if let Some(ref mut receiver) = self.notification_receiver {
243            match receiver.try_recv() {
244                Ok(notification) => Ok(Some(notification)),
245                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
246                Err(mpsc::error::TryRecvError::Disconnected) => Err(McpError::WebSocket(
247                    "Notification channel disconnected".to_string(),
248                )),
249            }
250        } else {
251            Ok(None)
252        }
253    }
254
255    async fn close(&mut self) -> McpResult<()> {
256        tracing::debug!("Closing WebSocket connection");
257
258        *self.state.write().await = ConnectionState::Closing;
259
260        // Send close message
261        if let Some(ref mut sender) = self.ws_sender {
262            let _ = sender.send(Message::Close(None)).await;
263        }
264
265        // Abort message handler
266        if let Some(handle) = self.message_handler.take() {
267            handle.abort();
268        }
269
270        self.ws_sender = None;
271        self.notification_receiver = None;
272
273        *self.state.write().await = ConnectionState::Disconnected;
274
275        Ok(())
276    }
277
278    fn is_connected(&self) -> bool {
279        // We'd need to check the actual state here
280        self.ws_sender.is_some()
281    }
282
283    fn connection_info(&self) -> String {
284        format!("WebSocket transport (url: {})", self.url)
285    }
286}
287
288// ============================================================================
289// WebSocket Server Transport
290// ============================================================================
291
/// Connection state for a WebSocket client
///
/// One value is stored per connected client in the server's client map, keyed by the
/// id generated when the connection was accepted.
struct WebSocketConnection {
    /// Write half of the client's WebSocket; used for responses, pongs,
    /// broadcast notifications and the Close frame sent on shutdown.
    sender: SplitSink<WebSocketStream<TcpStream>, Message>,
    /// Copy of the client id (the same value as the map key).
    _id: String, // Keep for future connection tracking/debugging
}
297
/// WebSocket transport for MCP servers
///
/// This transport serves MCP requests over WebSocket connections,
/// allowing multiple concurrent clients with bidirectional communication.
pub struct WebSocketServerTransport {
    /// Address passed to `TcpListener::bind` when the server is started.
    bind_addr: String,
    /// Transport settings, exposed through `config()` and `max_message_size()`.
    // NOTE(review): the connection handling visible in this file does not consult
    // these settings — confirm where timeouts and limits are meant to be applied.
    config: TransportConfig,
    /// Connected clients keyed by generated client id; shared with the per-connection tasks.
    clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
    /// Slot holding the handler that turns incoming requests into responses.
    request_handler: RequestHandler,
    /// Handle of the accept-loop task; `None` before `start` and after `stop`.
    server_handle: Option<tokio::task::JoinHandle<()>>,
    /// Flag set to `true` by `start` and to `false` by `stop`.
    running: Arc<RwLock<bool>>,
    /// Broadcast sender used to tell the accept loop and all connection tasks to stop.
    shutdown_sender: Option<broadcast::Sender<()>>,
}
311
312impl WebSocketServerTransport {
313    /// Create a new WebSocket server transport
314    ///
315    /// # Arguments
316    /// * `bind_addr` - Address to bind the WebSocket server to (e.g., "0.0.0.0:8080")
317    ///
318    /// # Returns
319    /// New WebSocket server transport instance
320    pub fn new<S: Into<String>>(bind_addr: S) -> Self {
321        Self::with_config(bind_addr, TransportConfig::default())
322    }
323
324    /// Create a new WebSocket server transport with custom configuration
325    ///
326    /// # Arguments
327    /// * `bind_addr` - Address to bind the WebSocket server to
328    /// * `config` - Transport configuration
329    ///
330    /// # Returns
331    /// New WebSocket server transport instance
332    pub fn with_config<S: Into<String>>(bind_addr: S, config: TransportConfig) -> Self {
333        let (shutdown_sender, _) = broadcast::channel(1);
334
335        Self {
336            bind_addr: bind_addr.into(),
337            config,
338            clients: Arc::new(RwLock::new(HashMap::new())),
339            request_handler: Arc::new(RwLock::new(None)),
340            server_handle: None,
341            running: Arc::new(RwLock::new(false)),
342            shutdown_sender: Some(shutdown_sender),
343        }
344    }
345
346    /// Set the request handler function
347    ///
348    /// # Arguments
349    /// * `handler` - Function that processes incoming requests
350    pub async fn set_request_handler<F>(&mut self, handler: F)
351    where
352        F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
353            + Send
354            + Sync
355            + 'static,
356    {
357        let mut request_handler = self.request_handler.write().await;
358        *request_handler = Some(Arc::new(handler));
359    }
360
    /// Get the current configuration
    ///
    /// Returns a borrow of the configuration this transport was constructed with.
    pub fn config(&self) -> &TransportConfig {
        &self.config
    }
365
    /// Get the maximum message size from config
    ///
    /// Returns the configured value as-is; `None` means no limit was configured.
    // NOTE(review): nothing in this file's visible connection handling applies this
    // limit — confirm whether it is enforced elsewhere.
    pub fn max_message_size(&self) -> Option<usize> {
        self.config.max_message_size
    }
370
371    async fn handle_client_connection(
372        stream: TcpStream,
373        clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
374        request_handler: RequestHandler,
375        mut shutdown_receiver: broadcast::Receiver<()>,
376    ) {
377        let client_id = uuid::Uuid::new_v4().to_string();
378
379        let ws_stream = match accept_async(stream).await {
380            Ok(ws) => ws,
381            Err(e) => {
382                tracing::error!("Failed to accept WebSocket connection: {}", e);
383                return;
384            }
385        };
386
387        tracing::info!("New WebSocket client connected: {}", client_id);
388
389        let (ws_sender, mut ws_receiver) = ws_stream.split();
390
391        // Add client to the connections map
392        {
393            let mut clients_guard = clients.write().await;
394            clients_guard.insert(
395                client_id.clone(),
396                WebSocketConnection {
397                    sender: ws_sender,
398                    _id: client_id.clone(),
399                },
400            );
401        }
402
403        // Handle messages from this client
404        loop {
405            tokio::select! {
406                message = ws_receiver.next() => {
407                    match message {
408                        Some(Ok(Message::Text(text))) => {
409                            tracing::trace!("Received message from {}: {}", client_id, text);
410
411                            // Try to parse as request
412                            if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&text) {
413                                let handler_guard = request_handler.read().await;
414                                if let Some(ref handler) = *handler_guard {
415                                    let response_rx = handler(request.clone());
416                                    drop(handler_guard);
417
418                                    match response_rx.await {
419                                        Ok(response) => {
420                                            let response_text = match serde_json::to_string(&response) {
421                                                Ok(text) => text,
422                                                Err(e) => {
423                                                    tracing::error!("Failed to serialize response: {}", e);
424                                                    continue;
425                                                }
426                                            };
427
428                                            // Send response back to client
429                                            let mut clients_guard = clients.write().await;
430                                            if let Some(client) = clients_guard.get_mut(&client_id) {
431                                                if let Err(e) = client.sender.send(Message::Text(response_text.into())).await {
432                                                    tracing::error!("Failed to send response to client {}: {}", client_id, e);
433                                                    break;
434                                                }
435                                            }
436                                        }
437                                        Err(_) => {
438                                            tracing::error!("Request handler channel closed for client {}", client_id);
439                                        }
440                                    }
441                                } else {
442                                    tracing::warn!("No request handler configured for client {}", client_id);
443                                }
444                            }
445                            // Handle notifications (no response needed)
446                            else if let Ok(_notification) = serde_json::from_str::<JsonRpcNotification>(&text) {
447                                tracing::trace!("Received notification from client {}", client_id);
448                                // Notifications don't require responses
449                            } else {
450                                tracing::warn!("Failed to parse message from client {}: {}", client_id, text);
451                            }
452                        }
453                        Some(Ok(Message::Close(_))) => {
454                            tracing::info!("Client {} disconnected", client_id);
455                            break;
456                        }
457                        Some(Ok(Message::Ping(data))) => {
458                            tracing::trace!("Received ping from client {}", client_id);
459                            let mut clients_guard = clients.write().await;
460                            if let Some(client) = clients_guard.get_mut(&client_id) {
461                                if let Err(e) = client.sender.send(Message::Pong(data)).await {
462                                    tracing::error!("Failed to send pong to client {}: {}", client_id, e);
463                                    break;
464                                }
465                            }
466                        }
467                        Some(Ok(Message::Pong(_))) => {
468                            tracing::trace!("Received pong from client {}", client_id);
469                        }
470                        Some(Ok(Message::Binary(_))) => {
471                            tracing::warn!("Received unexpected binary message from client {}", client_id);
472                        }
473                        Some(Ok(Message::Frame(_))) => {
474                            tracing::trace!("Received WebSocket frame from client {} (internal)", client_id);
475                            // Frame messages are internal to tungstenite
476                        }
477                        Some(Err(e)) => {
478                            tracing::error!("WebSocket error for client {}: {}", client_id, e);
479                            break;
480                        }
481                        None => {
482                            tracing::info!("WebSocket stream ended for client {}", client_id);
483                            break;
484                        }
485                    }
486                }
487                _ = shutdown_receiver.recv() => {
488                    tracing::info!("Shutting down connection for client {}", client_id);
489                    break;
490                }
491            }
492        }
493
494        // Remove client from connections
495        {
496            let mut clients_guard = clients.write().await;
497            clients_guard.remove(&client_id);
498        }
499
500        tracing::info!("Client {} connection handler exiting", client_id);
501    }
502}
503
504#[async_trait]
505impl ServerTransport for WebSocketServerTransport {
506    async fn start(&mut self) -> McpResult<()> {
507        tracing::info!("Starting WebSocket server on {}", self.bind_addr);
508
509        let listener = TcpListener::bind(&self.bind_addr).await.map_err(|e| {
510            McpError::WebSocket(format!("Failed to bind to {}: {}", self.bind_addr, e))
511        })?;
512
513        let clients = self.clients.clone();
514        let request_handler = self.request_handler.clone();
515        let running = self.running.clone();
516        let shutdown_sender = self.shutdown_sender.as_ref().unwrap().clone();
517
518        *running.write().await = true;
519
520        let server_handle = tokio::spawn(async move {
521            let mut shutdown_receiver = shutdown_sender.subscribe();
522
523            loop {
524                tokio::select! {
525                    result = listener.accept() => {
526                        match result {
527                            Ok((stream, addr)) => {
528                                tracing::debug!("New connection from: {}", addr);
529
530                                tokio::spawn(Self::handle_client_connection(
531                                    stream,
532                                    clients.clone(),
533                                    request_handler.clone(),
534                                    shutdown_sender.subscribe(),
535                                ));
536                            }
537                            Err(e) => {
538                                tracing::error!("Failed to accept connection: {}", e);
539                            }
540                        }
541                    }
542                    _ = shutdown_receiver.recv() => {
543                        tracing::info!("WebSocket server shutting down");
544                        break;
545                    }
546                }
547            }
548        });
549
550        self.server_handle = Some(server_handle);
551
552        tracing::info!(
553            "WebSocket server started successfully on {}",
554            self.bind_addr
555        );
556        Ok(())
557    }
558
559    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
560        let notification_text = serde_json::to_string(&notification)
561            .map_err(|e| McpError::Serialization(e.to_string()))?;
562
563        let mut clients_guard = self.clients.write().await;
564        let mut disconnected_clients = Vec::new();
565
566        for (client_id, client) in clients_guard.iter_mut() {
567            if let Err(e) = client
568                .sender
569                .send(Message::Text(notification_text.clone().into()))
570                .await
571            {
572                tracing::error!("Failed to send notification to client {}: {}", client_id, e);
573                disconnected_clients.push(client_id.clone());
574            }
575        }
576
577        // Remove disconnected clients
578        for client_id in disconnected_clients {
579            clients_guard.remove(&client_id);
580        }
581
582        Ok(())
583    }
584
585    async fn stop(&mut self) -> McpResult<()> {
586        tracing::info!("Stopping WebSocket server");
587
588        *self.running.write().await = false;
589
590        // Send shutdown signal
591        if let Some(ref sender) = self.shutdown_sender {
592            let _ = sender.send(());
593        }
594
595        // Wait for server to stop
596        if let Some(handle) = self.server_handle.take() {
597            handle.abort();
598        }
599
600        // Close all client connections
601        let mut clients_guard = self.clients.write().await;
602        for (client_id, client) in clients_guard.iter_mut() {
603            tracing::debug!("Closing connection for client {}", client_id);
604            let _ = client.sender.send(Message::Close(None)).await;
605        }
606        clients_guard.clear();
607
608        Ok(())
609    }
610
611    fn is_running(&self) -> bool {
612        // Check if we have an active server handle
613        self.server_handle.is_some()
614    }
615
616    fn server_info(&self) -> String {
617        format!("WebSocket server transport (bind: {})", self.bind_addr)
618    }
619
620    fn set_request_handler(&mut self, handler: crate::transport::traits::ServerRequestHandler) {
621        // Convert the ServerRequestHandler to the WebSocket transport's expected format
622        let _ws_handler = Arc::new(move |request: JsonRpcRequest| {
623            let (tx, rx) = tokio::sync::oneshot::channel();
624            let handler_future = handler(request);
625            tokio::spawn(async move {
626                let result = handler_future.await;
627                let _ = tx.send(result.unwrap_or_else(|e| JsonRpcResponse {
628                    jsonrpc: "2.0".to_string(),
629                    id: serde_json::Value::Null,
630                    result: Some(serde_json::json!({
631                        "error": {
632                            "code": -32603,
633                            "message": e.to_string()
634                        }
635                    })),
636                }));
637            });
638            rx
639        });
640
641        // Set the handler in the WebSocket transport's request_handler field
642        tokio::spawn(async move {
643            // Note: This is a limitation - we can't easily update the async field from sync method
644            // The WebSocket transport should be updated in the future to support the new trait design
645        });
646    }
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built server remembers its bind address and is not running.
    #[test]
    fn test_websocket_server_creation() {
        let server = WebSocketServerTransport::new("127.0.0.1:0");

        assert_eq!(server.bind_addr, "127.0.0.1:0");
        assert!(!server.is_running());
    }

    /// A custom configuration is stored unchanged next to the bind address.
    #[test]
    fn test_websocket_server_with_config() {
        let limit = Some(64 * 1024);
        let mut custom = TransportConfig::default();
        custom.max_message_size = limit;

        let server = WebSocketServerTransport::with_config("0.0.0.0:9090", custom);

        assert_eq!(server.bind_addr, "0.0.0.0:9090");
        assert_eq!(server.config.max_message_size, limit);
    }

    /// A string that is not a URL is rejected before any connection attempt.
    #[tokio::test]
    async fn test_websocket_client_invalid_url() {
        let outcome = WebSocketClientTransport::new("invalid-url").await;
        assert!(outcome.is_err());

        match outcome {
            Err(McpError::WebSocket(msg)) => assert!(msg.contains("Invalid WebSocket URL")),
            _ => panic!("Expected WebSocket error"),
        }
    }

    /// The connection info mentions the target host and port.
    ///
    /// Usually nothing listens on the port, so the connection fails and the test
    /// checks nothing; the assertion only runs if a connection happens to succeed.
    #[tokio::test]
    async fn test_websocket_client_connection_info() {
        let url = "ws://localhost:9999/test";

        let Ok(transport) = WebSocketClientTransport::new(url).await else {
            return;
        };

        assert!(transport.connection_info().contains("localhost:9999"));
    }
}