mcp_protocol_sdk/transport/
websocket.rs

1//! WebSocket transport implementation for MCP
2//!
3//! This module provides WebSocket-based transport for MCP communication,
4//! offering bidirectional, real-time communication between clients and servers.
5
6use async_trait::async_trait;
7use futures_util::{
8    sink::SinkExt,
9    stream::{SplitSink, SplitStream, StreamExt},
10};
11use serde_json::Value;
12use std::{collections::HashMap, sync::Arc, time::Duration};
13use tokio::{
14    net::{TcpListener, TcpStream},
15    sync::{broadcast, mpsc, Mutex, RwLock},
16    time::timeout,
17};
18use tokio_tungstenite::{
19    accept_async, connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream,
20};
21use url::Url;
22
23use crate::core::error::{McpError, McpResult};
24use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
25use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
26
27// ============================================================================
28// WebSocket Client Transport
29// ============================================================================
30
31/// WebSocket transport for MCP clients
32///
33/// This transport communicates with an MCP server via WebSocket connections,
34/// providing bidirectional real-time communication for both requests and notifications.
35pub struct WebSocketClientTransport {
36    ws_sender: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
37    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
38    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
39    config: TransportConfig,
40    state: Arc<RwLock<ConnectionState>>,
41    url: String,
42    message_handler: Option<tokio::task::JoinHandle<()>>,
43}
44
45impl WebSocketClientTransport {
46    /// Create a new WebSocket client transport
47    ///
48    /// # Arguments
49    /// * `url` - WebSocket URL to connect to (e.g., "ws://localhost:8080/mcp")
50    ///
51    /// # Returns
52    /// Result containing the transport or an error
53    pub async fn new<S: AsRef<str>>(url: S) -> McpResult<Self> {
54        Self::with_config(url, TransportConfig::default()).await
55    }
56
57    /// Create a new WebSocket client transport with custom configuration
58    ///
59    /// # Arguments
60    /// * `url` - WebSocket URL to connect to
61    /// * `config` - Transport configuration
62    ///
63    /// # Returns
64    /// Result containing the transport or an error
65    pub async fn with_config<S: AsRef<str>>(url: S, config: TransportConfig) -> McpResult<Self> {
66        let url_str = url.as_ref();
67
68        // Validate URL format
69        let _url_parsed = Url::parse(url_str)
70            .map_err(|e| McpError::WebSocket(format!("Invalid WebSocket URL: {}", e)))?;
71
72        tracing::debug!("Connecting to WebSocket: {}", url_str);
73
74        // Connect to WebSocket with timeout
75        let connect_timeout = Duration::from_millis(config.connect_timeout_ms.unwrap_or(30_000));
76
77        let (ws_stream, _) = timeout(connect_timeout, connect_async(url_str))
78            .await
79            .map_err(|_| McpError::WebSocket("Connection timeout".to_string()))?
80            .map_err(|e| McpError::WebSocket(format!("Failed to connect: {}", e)))?;
81
82        let (ws_sender, ws_receiver) = ws_stream.split();
83
84        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
85        let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
86        let state = Arc::new(RwLock::new(ConnectionState::Connected));
87
88        // Start message handling task
89        let message_handler = tokio::spawn(Self::handle_messages(
90            ws_receiver,
91            pending_requests.clone(),
92            notification_sender,
93            state.clone(),
94        ));
95
96        Ok(Self {
97            ws_sender: Some(ws_sender),
98            pending_requests,
99            notification_receiver: Some(notification_receiver),
100            config,
101            state,
102            url: url_str.to_string(),
103            message_handler: Some(message_handler),
104        })
105    }
106
107    async fn handle_messages(
108        mut ws_receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
109        pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
110        notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
111        state: Arc<RwLock<ConnectionState>>,
112    ) {
113        while let Some(message) = ws_receiver.next().await {
114            match message {
115                Ok(Message::Text(text)) => {
116                    tracing::trace!("Received WebSocket message: {}", text);
117
118                    // Try to parse as response first
119                    if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&text) {
120                        let mut pending = pending_requests.lock().await;
121                        if let Some(sender) = pending.remove(&response.id) {
122                            if let Err(_) = sender.send(response) {
123                                tracing::warn!("Failed to send response to waiting request");
124                            }
125                        } else {
126                            tracing::warn!(
127                                "Received response for unknown request ID: {:?}",
128                                response.id
129                            );
130                        }
131                    }
132                    // Try to parse as notification
133                    else if let Ok(notification) =
134                        serde_json::from_str::<JsonRpcNotification>(&text)
135                    {
136                        if let Err(_) = notification_sender.send(notification) {
137                            tracing::debug!("Notification receiver dropped");
138                            break;
139                        }
140                    } else {
141                        tracing::warn!("Failed to parse WebSocket message: {}", text);
142                    }
143                }
144                Ok(Message::Close(_)) => {
145                    tracing::info!("WebSocket connection closed");
146                    *state.write().await = ConnectionState::Disconnected;
147                    break;
148                }
149                Ok(Message::Ping(_data)) => {
150                    tracing::trace!("Received WebSocket ping");
151                    // Pong responses are handled automatically by tungstenite
152                }
153                Ok(Message::Pong(_)) => {
154                    tracing::trace!("Received WebSocket pong");
155                }
156                Ok(Message::Binary(_)) => {
157                    tracing::warn!("Received unexpected binary WebSocket message");
158                }
159                Ok(Message::Frame(_)) => {
160                    tracing::trace!("Received WebSocket frame (internal)");
161                    // Frame messages are internal to tungstenite
162                }
163                Err(e) => {
164                    tracing::error!("WebSocket error: {}", e);
165                    *state.write().await = ConnectionState::Error(e.to_string());
166                    break;
167                }
168            }
169        }
170
171        tracing::debug!("WebSocket message handler exiting");
172    }
173
174    async fn send_message(&mut self, message: Message) -> McpResult<()> {
175        if let Some(ref mut sender) = self.ws_sender {
176            sender
177                .send(message)
178                .await
179                .map_err(|e| McpError::WebSocket(format!("Failed to send message: {}", e)))?;
180        } else {
181            return Err(McpError::WebSocket("WebSocket not connected".to_string()));
182        }
183        Ok(())
184    }
185}
186
187#[async_trait]
188impl Transport for WebSocketClientTransport {
189    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
190        let (sender, receiver) = tokio::sync::oneshot::channel();
191
192        // Store the pending request
193        {
194            let mut pending = self.pending_requests.lock().await;
195            pending.insert(request.id.clone(), sender);
196        }
197
198        // Send the request
199        let request_text =
200            serde_json::to_string(&request).map_err(|e| McpError::Serialization(e.to_string()))?;
201
202        tracing::trace!("Sending WebSocket request: {}", request_text);
203
204        self.send_message(Message::Text(request_text.into()))
205            .await?;
206
207        // Wait for response with timeout
208        let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
209
210        let response = timeout(timeout_duration, receiver)
211            .await
212            .map_err(|_| McpError::WebSocket("Request timeout".to_string()))?
213            .map_err(|_| McpError::WebSocket("Response channel closed".to_string()))?;
214
215        Ok(response)
216    }
217
218    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
219        let notification_text = serde_json::to_string(&notification)
220            .map_err(|e| McpError::Serialization(e.to_string()))?;
221
222        tracing::trace!("Sending WebSocket notification: {}", notification_text);
223
224        self.send_message(Message::Text(notification_text.into()))
225            .await
226    }
227
228    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
229        if let Some(ref mut receiver) = self.notification_receiver {
230            match receiver.try_recv() {
231                Ok(notification) => Ok(Some(notification)),
232                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
233                Err(mpsc::error::TryRecvError::Disconnected) => Err(McpError::WebSocket(
234                    "Notification channel disconnected".to_string(),
235                )),
236            }
237        } else {
238            Ok(None)
239        }
240    }
241
242    async fn close(&mut self) -> McpResult<()> {
243        tracing::debug!("Closing WebSocket connection");
244
245        *self.state.write().await = ConnectionState::Closing;
246
247        // Send close message
248        if let Some(ref mut sender) = self.ws_sender {
249            let _ = sender.send(Message::Close(None)).await;
250        }
251
252        // Abort message handler
253        if let Some(handle) = self.message_handler.take() {
254            handle.abort();
255        }
256
257        self.ws_sender = None;
258        self.notification_receiver = None;
259
260        *self.state.write().await = ConnectionState::Disconnected;
261
262        Ok(())
263    }
264
265    fn is_connected(&self) -> bool {
266        // We'd need to check the actual state here
267        self.ws_sender.is_some()
268    }
269
270    fn connection_info(&self) -> String {
271        format!("WebSocket transport (url: {})", self.url)
272    }
273}
274
275// ============================================================================
276// WebSocket Server Transport
277// ============================================================================
278
279/// Connection state for a WebSocket client
280struct WebSocketConnection {
281    sender: SplitSink<WebSocketStream<TcpStream>, Message>,
282    _id: String, // Keep for future connection tracking/debugging
283}
284
285/// WebSocket transport for MCP servers
286///
287/// This transport serves MCP requests over WebSocket connections,
288/// allowing multiple concurrent clients with bidirectional communication.
289pub struct WebSocketServerTransport {
290    bind_addr: String,
291    config: TransportConfig, // Used for connection timeouts and limits
292    clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
293    request_handler: Arc<
294        RwLock<
295            Option<
296                Arc<
297                    dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
298                        + Send
299                        + Sync,
300                >,
301            >,
302        >,
303    >,
304    server_handle: Option<tokio::task::JoinHandle<()>>,
305    running: Arc<RwLock<bool>>,
306    shutdown_sender: Option<broadcast::Sender<()>>,
307}
308
309impl WebSocketServerTransport {
310    /// Create a new WebSocket server transport
311    ///
312    /// # Arguments
313    /// * `bind_addr` - Address to bind the WebSocket server to (e.g., "0.0.0.0:8080")
314    ///
315    /// # Returns
316    /// New WebSocket server transport instance
317    pub fn new<S: Into<String>>(bind_addr: S) -> Self {
318        Self::with_config(bind_addr, TransportConfig::default())
319    }
320
321    /// Create a new WebSocket server transport with custom configuration
322    ///
323    /// # Arguments
324    /// * `bind_addr` - Address to bind the WebSocket server to
325    /// * `config` - Transport configuration
326    ///
327    /// # Returns
328    /// New WebSocket server transport instance
329    pub fn with_config<S: Into<String>>(bind_addr: S, config: TransportConfig) -> Self {
330        let (shutdown_sender, _) = broadcast::channel(1);
331
332        Self {
333            bind_addr: bind_addr.into(),
334            config,
335            clients: Arc::new(RwLock::new(HashMap::new())),
336            request_handler: Arc::new(RwLock::new(None)),
337            server_handle: None,
338            running: Arc::new(RwLock::new(false)),
339            shutdown_sender: Some(shutdown_sender),
340        }
341    }
342
343    /// Set the request handler function
344    ///
345    /// # Arguments
346    /// * `handler` - Function that processes incoming requests
347    pub async fn set_request_handler<F>(&mut self, handler: F)
348    where
349        F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
350            + Send
351            + Sync
352            + 'static,
353    {
354        let mut request_handler = self.request_handler.write().await;
355        *request_handler = Some(Arc::new(handler));
356    }
357
358    /// Get the current configuration
359    pub fn config(&self) -> &TransportConfig {
360        &self.config
361    }
362
363    /// Get the maximum message size from config
364    pub fn max_message_size(&self) -> Option<usize> {
365        self.config.max_message_size
366    }
367
368    async fn handle_client_connection(
369        stream: TcpStream,
370        clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
371        request_handler: Arc<
372            RwLock<
373                Option<
374                    Arc<
375                        dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
376                            + Send
377                            + Sync,
378                    >,
379                >,
380            >,
381        >,
382        mut shutdown_receiver: broadcast::Receiver<()>,
383    ) {
384        let client_id = uuid::Uuid::new_v4().to_string();
385
386        let ws_stream = match accept_async(stream).await {
387            Ok(ws) => ws,
388            Err(e) => {
389                tracing::error!("Failed to accept WebSocket connection: {}", e);
390                return;
391            }
392        };
393
394        tracing::info!("New WebSocket client connected: {}", client_id);
395
396        let (ws_sender, mut ws_receiver) = ws_stream.split();
397
398        // Add client to the connections map
399        {
400            let mut clients_guard = clients.write().await;
401            clients_guard.insert(
402                client_id.clone(),
403                WebSocketConnection {
404                    sender: ws_sender,
405                    _id: client_id.clone(),
406                },
407            );
408        }
409
410        // Handle messages from this client
411        loop {
412            tokio::select! {
413                message = ws_receiver.next() => {
414                    match message {
415                        Some(Ok(Message::Text(text))) => {
416                            tracing::trace!("Received message from {}: {}", client_id, text);
417
418                            // Try to parse as request
419                            if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&text) {
420                                let handler_guard = request_handler.read().await;
421                                if let Some(ref handler) = *handler_guard {
422                                    let response_rx = handler(request.clone());
423                                    drop(handler_guard);
424
425                                    match response_rx.await {
426                                        Ok(response) => {
427                                            let response_text = match serde_json::to_string(&response) {
428                                                Ok(text) => text,
429                                                Err(e) => {
430                                                    tracing::error!("Failed to serialize response: {}", e);
431                                                    continue;
432                                                }
433                                            };
434
435                                            // Send response back to client
436                                            let mut clients_guard = clients.write().await;
437                                            if let Some(client) = clients_guard.get_mut(&client_id) {
438                                                if let Err(e) = client.sender.send(Message::Text(response_text.into())).await {
439                                                    tracing::error!("Failed to send response to client {}: {}", client_id, e);
440                                                    break;
441                                                }
442                                            }
443                                        }
444                                        Err(_) => {
445                                            tracing::error!("Request handler channel closed for client {}", client_id);
446                                        }
447                                    }
448                                } else {
449                                    tracing::warn!("No request handler configured for client {}", client_id);
450                                }
451                            }
452                            // Handle notifications (no response needed)
453                            else if let Ok(_notification) = serde_json::from_str::<JsonRpcNotification>(&text) {
454                                tracing::trace!("Received notification from client {}", client_id);
455                                // Notifications don't require responses
456                            } else {
457                                tracing::warn!("Failed to parse message from client {}: {}", client_id, text);
458                            }
459                        }
460                        Some(Ok(Message::Close(_))) => {
461                            tracing::info!("Client {} disconnected", client_id);
462                            break;
463                        }
464                        Some(Ok(Message::Ping(data))) => {
465                            tracing::trace!("Received ping from client {}", client_id);
466                            let mut clients_guard = clients.write().await;
467                            if let Some(client) = clients_guard.get_mut(&client_id) {
468                                if let Err(e) = client.sender.send(Message::Pong(data)).await {
469                                    tracing::error!("Failed to send pong to client {}: {}", client_id, e);
470                                    break;
471                                }
472                            }
473                        }
474                        Some(Ok(Message::Pong(_))) => {
475                            tracing::trace!("Received pong from client {}", client_id);
476                        }
477                        Some(Ok(Message::Binary(_))) => {
478                            tracing::warn!("Received unexpected binary message from client {}", client_id);
479                        }
480                        Some(Ok(Message::Frame(_))) => {
481                            tracing::trace!("Received WebSocket frame from client {} (internal)", client_id);
482                            // Frame messages are internal to tungstenite
483                        }
484                        Some(Err(e)) => {
485                            tracing::error!("WebSocket error for client {}: {}", client_id, e);
486                            break;
487                        }
488                        None => {
489                            tracing::info!("WebSocket stream ended for client {}", client_id);
490                            break;
491                        }
492                    }
493                }
494                _ = shutdown_receiver.recv() => {
495                    tracing::info!("Shutting down connection for client {}", client_id);
496                    break;
497                }
498            }
499        }
500
501        // Remove client from connections
502        {
503            let mut clients_guard = clients.write().await;
504            clients_guard.remove(&client_id);
505        }
506
507        tracing::info!("Client {} connection handler exiting", client_id);
508    }
509}
510
511#[async_trait]
512impl ServerTransport for WebSocketServerTransport {
513    async fn start(&mut self) -> McpResult<()> {
514        tracing::info!("Starting WebSocket server on {}", self.bind_addr);
515
516        let listener = TcpListener::bind(&self.bind_addr).await.map_err(|e| {
517            McpError::WebSocket(format!("Failed to bind to {}: {}", self.bind_addr, e))
518        })?;
519
520        let clients = self.clients.clone();
521        let request_handler = self.request_handler.clone();
522        let running = self.running.clone();
523        let shutdown_sender = self.shutdown_sender.as_ref().unwrap().clone();
524
525        *running.write().await = true;
526
527        let server_handle = tokio::spawn(async move {
528            let mut shutdown_receiver = shutdown_sender.subscribe();
529
530            loop {
531                tokio::select! {
532                    result = listener.accept() => {
533                        match result {
534                            Ok((stream, addr)) => {
535                                tracing::debug!("New connection from: {}", addr);
536
537                                tokio::spawn(Self::handle_client_connection(
538                                    stream,
539                                    clients.clone(),
540                                    request_handler.clone(),
541                                    shutdown_sender.subscribe(),
542                                ));
543                            }
544                            Err(e) => {
545                                tracing::error!("Failed to accept connection: {}", e);
546                            }
547                        }
548                    }
549                    _ = shutdown_receiver.recv() => {
550                        tracing::info!("WebSocket server shutting down");
551                        break;
552                    }
553                }
554            }
555        });
556
557        self.server_handle = Some(server_handle);
558
559        tracing::info!(
560            "WebSocket server started successfully on {}",
561            self.bind_addr
562        );
563        Ok(())
564    }
565
566    async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
567        let handler_guard = self.request_handler.read().await;
568
569        if let Some(ref handler) = *handler_guard {
570            let response_rx = handler(request);
571            drop(handler_guard);
572
573            match response_rx.await {
574                Ok(response) => Ok(response),
575                Err(_) => Err(McpError::WebSocket(
576                    "Request handler channel closed".to_string(),
577                )),
578            }
579        } else {
580            Ok(JsonRpcResponse {
581                jsonrpc: "2.0".to_string(),
582                id: request.id,
583                result: None,
584            })
585        }
586    }
587
588    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
589        let notification_text = serde_json::to_string(&notification)
590            .map_err(|e| McpError::Serialization(e.to_string()))?;
591
592        let mut clients_guard = self.clients.write().await;
593        let mut disconnected_clients = Vec::new();
594
595        for (client_id, client) in clients_guard.iter_mut() {
596            if let Err(e) = client
597                .sender
598                .send(Message::Text(notification_text.clone().into()))
599                .await
600            {
601                tracing::error!("Failed to send notification to client {}: {}", client_id, e);
602                disconnected_clients.push(client_id.clone());
603            }
604        }
605
606        // Remove disconnected clients
607        for client_id in disconnected_clients {
608            clients_guard.remove(&client_id);
609        }
610
611        Ok(())
612    }
613
614    async fn stop(&mut self) -> McpResult<()> {
615        tracing::info!("Stopping WebSocket server");
616
617        *self.running.write().await = false;
618
619        // Send shutdown signal
620        if let Some(ref sender) = self.shutdown_sender {
621            let _ = sender.send(());
622        }
623
624        // Wait for server to stop
625        if let Some(handle) = self.server_handle.take() {
626            handle.abort();
627        }
628
629        // Close all client connections
630        let mut clients_guard = self.clients.write().await;
631        for (client_id, client) in clients_guard.iter_mut() {
632            tracing::debug!("Closing connection for client {}", client_id);
633            let _ = client.sender.send(Message::Close(None)).await;
634        }
635        clients_guard.clear();
636
637        Ok(())
638    }
639
640    fn is_running(&self) -> bool {
641        // Check if we have an active server handle
642        self.server_handle.is_some()
643    }
644
645    fn server_info(&self) -> String {
646        format!("WebSocket server transport (bind: {})", self.bind_addr)
647    }
648}
649
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed server keeps its bind address and is not running.
    #[test]
    fn test_websocket_server_creation() {
        let server = WebSocketServerTransport::new("127.0.0.1:0");
        assert!(!server.is_running());
        assert_eq!(server.bind_addr, "127.0.0.1:0");
    }

    /// A custom configuration is stored unchanged alongside the bind address.
    #[test]
    fn test_websocket_server_with_config() {
        let size_limit = Some(64 * 1024);
        let mut config = TransportConfig::default();
        config.max_message_size = size_limit;

        let server = WebSocketServerTransport::with_config("0.0.0.0:9090", config);
        assert_eq!(server.config.max_message_size, size_limit);
        assert_eq!(server.bind_addr, "0.0.0.0:9090");
    }

    /// A string that is not a URL is rejected with a WebSocket error.
    #[tokio::test]
    async fn test_websocket_client_invalid_url() {
        match WebSocketClientTransport::new("invalid-url").await {
            Err(McpError::WebSocket(msg)) => assert!(msg.contains("Invalid WebSocket URL")),
            _ => panic!("Expected WebSocket error"),
        }
    }

    /// When a connection happens to succeed, the info string names the host.
    #[tokio::test]
    async fn test_websocket_client_connection_info() {
        let url = "ws://localhost:9999/test";
        match WebSocketClientTransport::new(url).await {
            Ok(client) => {
                let info = client.connection_info();
                assert!(info.contains("localhost:9999"));
            }
            // Usually nothing listens on that port; a failed connection is
            // acceptable and leaves nothing to check.
            Err(_) => {}
        }
    }
}