mcp_protocol_sdk/transport/
websocket.rs

1//! WebSocket transport implementation for MCP
2//!
3//! This module provides WebSocket-based transport for MCP communication,
4//! offering bidirectional, real-time communication between clients and servers.
5
6use async_trait::async_trait;
7use futures_util::{
8    sink::SinkExt,
9    stream::{SplitSink, SplitStream, StreamExt},
10};
11use serde_json::Value;
12use std::{collections::HashMap, sync::Arc, time::Duration};
13use tokio::{
14    net::{TcpListener, TcpStream},
15    sync::{broadcast, mpsc, Mutex, RwLock},
16    time::timeout,
17};
18use tokio_tungstenite::{
19    accept_async, connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream,
20};
21use url::Url;
22
23use crate::core::error::{McpError, McpResult};
24use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
25use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
26
27// ============================================================================
28// WebSocket Client Transport
29// ============================================================================
30
31/// WebSocket transport for MCP clients
32///
33/// This transport communicates with an MCP server via WebSocket connections,
34/// providing bidirectional real-time communication for both requests and notifications.
35pub struct WebSocketClientTransport {
36    ws_sender: Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>,
37    pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
38    notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
39    config: TransportConfig,
40    state: Arc<RwLock<ConnectionState>>,
41    url: String,
42    message_handler: Option<tokio::task::JoinHandle<()>>,
43}
44
45impl WebSocketClientTransport {
46    /// Create a new WebSocket client transport
47    ///
48    /// # Arguments
49    /// * `url` - WebSocket URL to connect to (e.g., "ws://localhost:8080/mcp")
50    ///
51    /// # Returns
52    /// Result containing the transport or an error
53    pub async fn new<S: AsRef<str>>(url: S) -> McpResult<Self> {
54        Self::with_config(url, TransportConfig::default()).await
55    }
56
57    /// Create a new WebSocket client transport with custom configuration
58    ///
59    /// # Arguments
60    /// * `url` - WebSocket URL to connect to
61    /// * `config` - Transport configuration
62    ///
63    /// # Returns
64    /// Result containing the transport or an error
65    pub async fn with_config<S: AsRef<str>>(url: S, config: TransportConfig) -> McpResult<Self> {
66        let url_str = url.as_ref();
67        let url_parsed = Url::parse(url_str)
68            .map_err(|e| McpError::WebSocket(format!("Invalid WebSocket URL: {}", e)))?;
69
70        tracing::debug!("Connecting to WebSocket: {}", url_str);
71
72        // Connect to WebSocket with timeout
73        let connect_timeout = Duration::from_millis(config.connect_timeout_ms.unwrap_or(30_000));
74
75        let (ws_stream, _) = timeout(connect_timeout, connect_async(&url_parsed))
76            .await
77            .map_err(|_| McpError::WebSocket("Connection timeout".to_string()))?
78            .map_err(|e| McpError::WebSocket(format!("Failed to connect: {}", e)))?;
79
80        let (ws_sender, ws_receiver) = ws_stream.split();
81
82        let pending_requests = Arc::new(Mutex::new(HashMap::new()));
83        let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
84        let state = Arc::new(RwLock::new(ConnectionState::Connected));
85
86        // Start message handling task
87        let message_handler = tokio::spawn(Self::handle_messages(
88            ws_receiver,
89            pending_requests.clone(),
90            notification_sender,
91            state.clone(),
92        ));
93
94        Ok(Self {
95            ws_sender: Some(ws_sender),
96            pending_requests,
97            notification_receiver: Some(notification_receiver),
98            config,
99            state,
100            url: url_str.to_string(),
101            message_handler: Some(message_handler),
102        })
103    }
104
105    async fn handle_messages(
106        mut ws_receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
107        pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
108        notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
109        state: Arc<RwLock<ConnectionState>>,
110    ) {
111        while let Some(message) = ws_receiver.next().await {
112            match message {
113                Ok(Message::Text(text)) => {
114                    tracing::trace!("Received WebSocket message: {}", text);
115
116                    // Try to parse as response first
117                    if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&text) {
118                        let mut pending = pending_requests.lock().await;
119                        if let Some(sender) = pending.remove(&response.id) {
120                            if let Err(_) = sender.send(response) {
121                                tracing::warn!("Failed to send response to waiting request");
122                            }
123                        } else {
124                            tracing::warn!(
125                                "Received response for unknown request ID: {:?}",
126                                response.id
127                            );
128                        }
129                    }
130                    // Try to parse as notification
131                    else if let Ok(notification) =
132                        serde_json::from_str::<JsonRpcNotification>(&text)
133                    {
134                        if let Err(_) = notification_sender.send(notification) {
135                            tracing::debug!("Notification receiver dropped");
136                            break;
137                        }
138                    } else {
139                        tracing::warn!("Failed to parse WebSocket message: {}", text);
140                    }
141                }
142                Ok(Message::Close(_)) => {
143                    tracing::info!("WebSocket connection closed");
144                    *state.write().await = ConnectionState::Disconnected;
145                    break;
146                }
147                Ok(Message::Ping(_data)) => {
148                    tracing::trace!("Received WebSocket ping");
149                    // Pong responses are handled automatically by tungstenite
150                }
151                Ok(Message::Pong(_)) => {
152                    tracing::trace!("Received WebSocket pong");
153                }
154                Ok(Message::Binary(_)) => {
155                    tracing::warn!("Received unexpected binary WebSocket message");
156                }
157                Ok(Message::Frame(_)) => {
158                    tracing::trace!("Received WebSocket frame (internal)");
159                    // Frame messages are internal to tungstenite
160                }
161                Err(e) => {
162                    tracing::error!("WebSocket error: {}", e);
163                    *state.write().await = ConnectionState::Error(e.to_string());
164                    break;
165                }
166            }
167        }
168
169        tracing::debug!("WebSocket message handler exiting");
170    }
171
172    async fn send_message(&mut self, message: Message) -> McpResult<()> {
173        if let Some(ref mut sender) = self.ws_sender {
174            sender
175                .send(message)
176                .await
177                .map_err(|e| McpError::WebSocket(format!("Failed to send message: {}", e)))?;
178        } else {
179            return Err(McpError::WebSocket("WebSocket not connected".to_string()));
180        }
181        Ok(())
182    }
183}
184
185#[async_trait]
186impl Transport for WebSocketClientTransport {
187    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
188        let (sender, receiver) = tokio::sync::oneshot::channel();
189
190        // Store the pending request
191        {
192            let mut pending = self.pending_requests.lock().await;
193            pending.insert(request.id.clone(), sender);
194        }
195
196        // Send the request
197        let request_text =
198            serde_json::to_string(&request).map_err(|e| McpError::Serialization(e))?;
199
200        tracing::trace!("Sending WebSocket request: {}", request_text);
201
202        self.send_message(Message::Text(request_text)).await?;
203
204        // Wait for response with timeout
205        let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
206
207        let response = timeout(timeout_duration, receiver)
208            .await
209            .map_err(|_| McpError::WebSocket("Request timeout".to_string()))?
210            .map_err(|_| McpError::WebSocket("Response channel closed".to_string()))?;
211
212        Ok(response)
213    }
214
215    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
216        let notification_text =
217            serde_json::to_string(&notification).map_err(|e| McpError::Serialization(e))?;
218
219        tracing::trace!("Sending WebSocket notification: {}", notification_text);
220
221        self.send_message(Message::Text(notification_text)).await
222    }
223
224    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
225        if let Some(ref mut receiver) = self.notification_receiver {
226            match receiver.try_recv() {
227                Ok(notification) => Ok(Some(notification)),
228                Err(mpsc::error::TryRecvError::Empty) => Ok(None),
229                Err(mpsc::error::TryRecvError::Disconnected) => Err(McpError::WebSocket(
230                    "Notification channel disconnected".to_string(),
231                )),
232            }
233        } else {
234            Ok(None)
235        }
236    }
237
238    async fn close(&mut self) -> McpResult<()> {
239        tracing::debug!("Closing WebSocket connection");
240
241        *self.state.write().await = ConnectionState::Closing;
242
243        // Send close message
244        if let Some(ref mut sender) = self.ws_sender {
245            let _ = sender.send(Message::Close(None)).await;
246        }
247
248        // Abort message handler
249        if let Some(handle) = self.message_handler.take() {
250            handle.abort();
251        }
252
253        self.ws_sender = None;
254        self.notification_receiver = None;
255
256        *self.state.write().await = ConnectionState::Disconnected;
257
258        Ok(())
259    }
260
261    fn is_connected(&self) -> bool {
262        // We'd need to check the actual state here
263        self.ws_sender.is_some()
264    }
265
266    fn connection_info(&self) -> String {
267        format!("WebSocket transport (url: {})", self.url)
268    }
269}
270
271// ============================================================================
272// WebSocket Server Transport
273// ============================================================================
274
275/// Connection state for a WebSocket client
276struct WebSocketConnection {
277    sender: SplitSink<WebSocketStream<TcpStream>, Message>,
278    id: String,
279}
280
281/// WebSocket transport for MCP servers
282///
283/// This transport serves MCP requests over WebSocket connections,
284/// allowing multiple concurrent clients with bidirectional communication.
285pub struct WebSocketServerTransport {
286    bind_addr: String,
287    config: TransportConfig,
288    clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
289    request_handler: Arc<
290        RwLock<
291            Option<
292                Arc<
293                    dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
294                        + Send
295                        + Sync,
296                >,
297            >,
298        >,
299    >,
300    server_handle: Option<tokio::task::JoinHandle<()>>,
301    running: Arc<RwLock<bool>>,
302    shutdown_sender: Option<broadcast::Sender<()>>,
303}
304
305impl WebSocketServerTransport {
306    /// Create a new WebSocket server transport
307    ///
308    /// # Arguments
309    /// * `bind_addr` - Address to bind the WebSocket server to (e.g., "0.0.0.0:8080")
310    ///
311    /// # Returns
312    /// New WebSocket server transport instance
313    pub fn new<S: Into<String>>(bind_addr: S) -> Self {
314        Self::with_config(bind_addr, TransportConfig::default())
315    }
316
317    /// Create a new WebSocket server transport with custom configuration
318    ///
319    /// # Arguments
320    /// * `bind_addr` - Address to bind the WebSocket server to
321    /// * `config` - Transport configuration
322    ///
323    /// # Returns
324    /// New WebSocket server transport instance
325    pub fn with_config<S: Into<String>>(bind_addr: S, config: TransportConfig) -> Self {
326        let (shutdown_sender, _) = broadcast::channel(1);
327
328        Self {
329            bind_addr: bind_addr.into(),
330            config,
331            clients: Arc::new(RwLock::new(HashMap::new())),
332            request_handler: Arc::new(RwLock::new(None)),
333            server_handle: None,
334            running: Arc::new(RwLock::new(false)),
335            shutdown_sender: Some(shutdown_sender),
336        }
337    }
338
339    /// Set the request handler function
340    ///
341    /// # Arguments
342    /// * `handler` - Function that processes incoming requests
343    pub async fn set_request_handler<F>(&mut self, handler: F)
344    where
345        F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
346            + Send
347            + Sync
348            + 'static,
349    {
350        let mut request_handler = self.request_handler.write().await;
351        *request_handler = Some(Arc::new(handler));
352    }
353
354    async fn handle_client_connection(
355        stream: TcpStream,
356        clients: Arc<RwLock<HashMap<String, WebSocketConnection>>>,
357        request_handler: Arc<
358            RwLock<
359                Option<
360                    Arc<
361                        dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
362                            + Send
363                            + Sync,
364                    >,
365                >,
366            >,
367        >,
368        mut shutdown_receiver: broadcast::Receiver<()>,
369    ) {
370        let client_id = uuid::Uuid::new_v4().to_string();
371
372        let ws_stream = match accept_async(stream).await {
373            Ok(ws) => ws,
374            Err(e) => {
375                tracing::error!("Failed to accept WebSocket connection: {}", e);
376                return;
377            }
378        };
379
380        tracing::info!("New WebSocket client connected: {}", client_id);
381
382        let (mut ws_sender, mut ws_receiver) = ws_stream.split();
383
384        // Add client to the connections map
385        {
386            let mut clients_guard = clients.write().await;
387            clients_guard.insert(
388                client_id.clone(),
389                WebSocketConnection {
390                    sender: ws_sender,
391                    id: client_id.clone(),
392                },
393            );
394        }
395
396        // Handle messages from this client
397        loop {
398            tokio::select! {
399                message = ws_receiver.next() => {
400                    match message {
401                        Some(Ok(Message::Text(text))) => {
402                            tracing::trace!("Received message from {}: {}", client_id, text);
403
404                            // Try to parse as request
405                            if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&text) {
406                                let handler_guard = request_handler.read().await;
407                                if let Some(ref handler) = *handler_guard {
408                                    let response_rx = handler(request.clone());
409                                    drop(handler_guard);
410
411                                    match response_rx.await {
412                                        Ok(response) => {
413                                            let response_text = match serde_json::to_string(&response) {
414                                                Ok(text) => text,
415                                                Err(e) => {
416                                                    tracing::error!("Failed to serialize response: {}", e);
417                                                    continue;
418                                                }
419                                            };
420
421                                            // Send response back to client
422                                            let mut clients_guard = clients.write().await;
423                                            if let Some(client) = clients_guard.get_mut(&client_id) {
424                                                if let Err(e) = client.sender.send(Message::Text(response_text)).await {
425                                                    tracing::error!("Failed to send response to client {}: {}", client_id, e);
426                                                    break;
427                                                }
428                                            }
429                                        }
430                                        Err(_) => {
431                                            tracing::error!("Request handler channel closed for client {}", client_id);
432                                        }
433                                    }
434                                } else {
435                                    tracing::warn!("No request handler configured for client {}", client_id);
436                                }
437                            }
438                            // Handle notifications (no response needed)
439                            else if let Ok(_notification) = serde_json::from_str::<JsonRpcNotification>(&text) {
440                                tracing::trace!("Received notification from client {}", client_id);
441                                // Notifications don't require responses
442                            } else {
443                                tracing::warn!("Failed to parse message from client {}: {}", client_id, text);
444                            }
445                        }
446                        Some(Ok(Message::Close(_))) => {
447                            tracing::info!("Client {} disconnected", client_id);
448                            break;
449                        }
450                        Some(Ok(Message::Ping(data))) => {
451                            tracing::trace!("Received ping from client {}", client_id);
452                            let mut clients_guard = clients.write().await;
453                            if let Some(client) = clients_guard.get_mut(&client_id) {
454                                if let Err(e) = client.sender.send(Message::Pong(data)).await {
455                                    tracing::error!("Failed to send pong to client {}: {}", client_id, e);
456                                    break;
457                                }
458                            }
459                        }
460                        Some(Ok(Message::Pong(_))) => {
461                            tracing::trace!("Received pong from client {}", client_id);
462                        }
463                        Some(Ok(Message::Binary(_))) => {
464                            tracing::warn!("Received unexpected binary message from client {}", client_id);
465                        }
466                        Some(Ok(Message::Frame(_))) => {
467                            tracing::trace!("Received WebSocket frame from client {} (internal)", client_id);
468                            // Frame messages are internal to tungstenite
469                        }
470                        Some(Err(e)) => {
471                            tracing::error!("WebSocket error for client {}: {}", client_id, e);
472                            break;
473                        }
474                        None => {
475                            tracing::info!("WebSocket stream ended for client {}", client_id);
476                            break;
477                        }
478                    }
479                }
480                _ = shutdown_receiver.recv() => {
481                    tracing::info!("Shutting down connection for client {}", client_id);
482                    break;
483                }
484            }
485        }
486
487        // Remove client from connections
488        {
489            let mut clients_guard = clients.write().await;
490            clients_guard.remove(&client_id);
491        }
492
493        tracing::info!("Client {} connection handler exiting", client_id);
494    }
495}
496
497#[async_trait]
498impl ServerTransport for WebSocketServerTransport {
499    async fn start(&mut self) -> McpResult<()> {
500        tracing::info!("Starting WebSocket server on {}", self.bind_addr);
501
502        let listener = TcpListener::bind(&self.bind_addr).await.map_err(|e| {
503            McpError::WebSocket(format!("Failed to bind to {}: {}", self.bind_addr, e))
504        })?;
505
506        let clients = self.clients.clone();
507        let request_handler = self.request_handler.clone();
508        let running = self.running.clone();
509        let shutdown_sender = self.shutdown_sender.as_ref().unwrap().clone();
510
511        *running.write().await = true;
512
513        let server_handle = tokio::spawn(async move {
514            let mut shutdown_receiver = shutdown_sender.subscribe();
515
516            loop {
517                tokio::select! {
518                    result = listener.accept() => {
519                        match result {
520                            Ok((stream, addr)) => {
521                                tracing::debug!("New connection from: {}", addr);
522
523                                tokio::spawn(Self::handle_client_connection(
524                                    stream,
525                                    clients.clone(),
526                                    request_handler.clone(),
527                                    shutdown_sender.subscribe(),
528                                ));
529                            }
530                            Err(e) => {
531                                tracing::error!("Failed to accept connection: {}", e);
532                            }
533                        }
534                    }
535                    _ = shutdown_receiver.recv() => {
536                        tracing::info!("WebSocket server shutting down");
537                        break;
538                    }
539                }
540            }
541        });
542
543        self.server_handle = Some(server_handle);
544
545        tracing::info!(
546            "WebSocket server started successfully on {}",
547            self.bind_addr
548        );
549        Ok(())
550    }
551
552    async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
553        let handler_guard = self.request_handler.read().await;
554
555        if let Some(ref handler) = *handler_guard {
556            let response_rx = handler(request);
557            drop(handler_guard);
558
559            match response_rx.await {
560                Ok(response) => Ok(response),
561                Err(_) => Err(McpError::WebSocket(
562                    "Request handler channel closed".to_string(),
563                )),
564            }
565        } else {
566            Ok(JsonRpcResponse {
567                jsonrpc: "2.0".to_string(),
568                id: request.id,
569                result: None,
570                error: Some(crate::protocol::types::JsonRpcError {
571                    code: crate::protocol::types::METHOD_NOT_FOUND,
572                    message: "No request handler configured".to_string(),
573                    data: None,
574                }),
575            })
576        }
577    }
578
579    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
580        let notification_text =
581            serde_json::to_string(&notification).map_err(|e| McpError::Serialization(e))?;
582
583        let mut clients_guard = self.clients.write().await;
584        let mut disconnected_clients = Vec::new();
585
586        for (client_id, client) in clients_guard.iter_mut() {
587            if let Err(e) = client
588                .sender
589                .send(Message::Text(notification_text.clone()))
590                .await
591            {
592                tracing::error!("Failed to send notification to client {}: {}", client_id, e);
593                disconnected_clients.push(client_id.clone());
594            }
595        }
596
597        // Remove disconnected clients
598        for client_id in disconnected_clients {
599            clients_guard.remove(&client_id);
600        }
601
602        Ok(())
603    }
604
605    async fn stop(&mut self) -> McpResult<()> {
606        tracing::info!("Stopping WebSocket server");
607
608        *self.running.write().await = false;
609
610        // Send shutdown signal
611        if let Some(ref sender) = self.shutdown_sender {
612            let _ = sender.send(());
613        }
614
615        // Wait for server to stop
616        if let Some(handle) = self.server_handle.take() {
617            handle.abort();
618        }
619
620        // Close all client connections
621        let mut clients_guard = self.clients.write().await;
622        for (client_id, client) in clients_guard.iter_mut() {
623            tracing::debug!("Closing connection for client {}", client_id);
624            let _ = client.sender.send(Message::Close(None)).await;
625        }
626        clients_guard.clear();
627
628        Ok(())
629    }
630
631    fn is_running(&self) -> bool {
632        // Check if we have an active server handle
633        self.server_handle.is_some()
634    }
635
636    fn server_info(&self) -> String {
637        format!("WebSocket server transport (bind: {})", self.bind_addr)
638    }
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_websocket_server_creation() {
        let transport = WebSocketServerTransport::new("127.0.0.1:0");
        assert_eq!(transport.bind_addr, "127.0.0.1:0");
        assert!(!transport.is_running());
    }

    #[test]
    fn test_websocket_server_with_config() {
        let mut config = TransportConfig::default();
        config.max_message_size = Some(64 * 1024);

        let transport = WebSocketServerTransport::with_config("0.0.0.0:9090", config);
        assert_eq!(transport.bind_addr, "0.0.0.0:9090");
        assert_eq!(transport.config.max_message_size, Some(64 * 1024));
    }

    #[tokio::test]
    async fn test_websocket_client_invalid_url() {
        let result = WebSocketClientTransport::new("invalid-url").await;
        assert!(result.is_err());

        if let Err(McpError::WebSocket(msg)) = result {
            assert!(msg.contains("Invalid WebSocket URL"));
        } else {
            panic!("Expected WebSocket error");
        }
    }

    #[tokio::test]
    async fn test_websocket_client_connection_info() {
        // Stand up a minimal WebSocket endpoint on an ephemeral port so the
        // handshake succeeds and the assertions below always run.
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind ephemeral port");
        let addr = listener.local_addr().expect("listener local address");
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.expect("accept TCP connection");
            // Returning the stream keeps the server side of the socket open.
            accept_async(stream).await.expect("server-side handshake")
        });

        let url = format!("ws://{}/test", addr);
        let transport = match WebSocketClientTransport::new(url.as_str()).await {
            Ok(transport) => transport,
            Err(_) => panic!("client failed to connect to the local test server"),
        };

        assert!(transport.connection_info().contains(&addr.to_string()));
        assert!(transport.is_connected());

        let _server_ws = server.await.expect("server task completed the handshake");
    }

    #[tokio::test]
    async fn test_server_handle_request_without_handler() {
        let mut transport = WebSocketServerTransport::new("127.0.0.1:0");
        let request: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list",
            "params": {}
        }))
        .expect("request JSON should deserialize");

        let response = match transport.handle_request(request).await {
            Ok(response) => response,
            Err(_) => panic!("expected an error response, not Err"),
        };

        assert_eq!(response.id, json!(1));
        assert!(response.result.is_none());
        let error = response.error.expect("METHOD_NOT_FOUND error");
        assert_eq!(error.code, crate::protocol::types::METHOD_NOT_FOUND);
    }
}