mcp_sdk_rs/transport/
websocket.rs

//! WebSocket transport implementation.
//!
//! This module provides a transport that uses the WebSocket protocol for
//! communication. It is well suited to:
//! - Network-based client-server communication
//! - Real-time bidirectional messaging
//! - Web-based applications
//! - Scenarios requiring encrypted communication (WSS)
//!
//! The implementation uses tokio-tungstenite for WebSocket functionality. Each
//! connection is split into a read half and a write half, and each half is shared
//! through its own `Arc<Mutex<_>>`.
12
13use async_trait::async_trait;
14use futures::{
15    stream::{SplitSink, SplitStream},
16    Sink, SinkExt, Stream, StreamExt,
17};
18use std::{fmt::Display, pin::Pin, sync::Arc};
19use tokio::io::{AsyncRead, AsyncWrite};
20use tokio::sync::Mutex;
21use tokio_tungstenite::{
22    connect_async,
23    tungstenite::{error::Error as WsError, protocol::CloseFrame, protocol::Message as WsMessage},
24    WebSocketStream,
25};
26use url::Url;
27
28use crate::{
29    error::Error,
30    transport::{Message, Transport},
31};
32
33/// Type alias for the WebSocket connection stream
34type WebSocketConnection<S> = WebSocketStream<S>;
35
36/// A transport implementation that uses WebSocket protocol for communication.
37///
38/// This transport provides bidirectional communication over WebSocket protocol,
39/// supporting both secure (WSS) and standard (WS) connections.
40pub struct WebSocketTransport<S> {
41    read_connection: Arc<Mutex<SplitStream<WebSocketConnection<S>>>>,
42    write_connection: Arc<Mutex<SplitSink<WebSocketConnection<S>, WsMessage>>>,
43}
44
45impl<S> WebSocketTransport<S>
46where
47    S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
48{
49    /// Creates a new WebSocket transport from an existing WebSocket stream.
50    /// This is typically used on the server side when accepting a new connection.
51    ///
52    /// # Arguments
53    ///
54    /// * `stream` - The WebSocket stream from an accepted connection
55
56    pub fn from_stream(stream: WebSocketConnection<S>) -> Self {
57        let (w, r) = stream.split();
58        Self {
59            read_connection: Arc::new(Mutex::new(r)),
60            write_connection: Arc::new(Mutex::new(w)),
61        }
62    }
63
64    /// Converts an MCP message to a WebSocket message.
65    fn convert_to_ws_message(message: &Message) -> Result<WsMessage, Error> {
66        let json =
67            serde_json::to_string(message).map_err(|e| Error::Serialization(e.to_string()))?;
68        Ok(WsMessage::Text(json))
69    }
70
71    /// Parses a WebSocket message into an MCP message.
72    fn parse_ws_message(ws_message: WsMessage) -> Result<Message, Error> {
73        match ws_message {
74            WsMessage::Text(text) => {
75                serde_json::from_str(&text).map_err(|e| Error::Serialization(e.to_string()))
76            }
77            WsMessage::Binary(_) => Err(Error::Transport(
78                "Binary messages not supported".to_string(),
79            )),
80            WsMessage::Ping(_) => Ok(Message::Notification(crate::protocol::Notification {
81                jsonrpc: crate::protocol::JSONRPC_VERSION.to_string(),
82                method: "ping".to_string(),
83                params: None,
84            })),
85            WsMessage::Pong(_) => Ok(Message::Notification(crate::protocol::Notification {
86                jsonrpc: crate::protocol::JSONRPC_VERSION.to_string(),
87                method: "pong".to_string(),
88                params: None,
89            })),
90            WsMessage::Close(_) => Err(Error::Transport("Connection closed".to_string())),
91            WsMessage::Frame(_) => Err(Error::Transport("Raw frames not supported".to_string())),
92        }
93    }
94
95    /// Handle a WebSocket message, including control messages
96    async fn handle_ws_message<T, E>(
97        connection: &mut T,
98        message: WsMessage,
99    ) -> Result<Option<Message>, Error>
100    where
101        T: Sink<WsMessage, Error = E> + Unpin,
102        E: Display,
103    {
104        match message {
105            WsMessage::Ping(data) => {
106                // Automatically respond to ping with pong
107                connection
108                    .send(WsMessage::Pong(data))
109                    .await
110                    .map_err(|e| Error::Transport(e.to_string()))?;
111                Ok(None)
112            }
113            WsMessage::Pong(_) => {
114                // Ignore pong messages
115                Ok(None)
116            }
117            _ => Self::parse_ws_message(message).map(Some),
118        }
119    }
120}
121
122#[async_trait]
123impl<S> Transport for WebSocketTransport<S>
124where
125    S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
126{
127    async fn send(&self, message: Message) -> Result<(), Error> {
128        let ws_message = Self::convert_to_ws_message(&message)?;
129        let mut connection = self.write_connection.lock().await;
130        connection
131            .send(ws_message)
132            .await
133            .map_err(|e| Error::Transport(e.to_string()))
134    }
135
136    fn receive(&self) -> Pin<Box<dyn Stream<Item = Result<Message, Error>> + Send>> {
137        let read_connection = self.read_connection.clone();
138        let write_connection = self.write_connection.clone();
139
140        Box::pin(futures::stream::unfold(
141            read_connection,
142            move |read_connection| {
143                let read_connection = read_connection.clone();
144                let write_connection = write_connection.clone();
145                async move {
146                    loop {
147                        let mut guard = read_connection.lock().await;
148                        match guard.next().await {
149                            Some(Ok(ws_message)) => {
150                                drop(guard);
151                                let mut guard = write_connection.lock().await;
152                                match Self::handle_ws_message(&mut *guard, ws_message).await {
153                                    Ok(Some(message)) => {
154                                        return Some((Ok(message), read_connection.clone()))
155                                    }
156                                    Ok(None) => continue, // Control message handled, continue to next message
157                                    Err(e) => return Some((Err(e), read_connection.clone())),
158                                }
159                            }
160                            Some(Err(e)) => {
161                                return Some((
162                                    Err(Error::Transport(e.to_string())),
163                                    read_connection.clone(),
164                                ))
165                            }
166                            None => return None,
167                        }
168                    }
169                }
170            },
171        ))
172    }
173
174    async fn close(&self) -> Result<(), Error> {
175        let mut connection = self.write_connection.lock().await;
176        // Send close frame with normal closure status
177        connection
178            .send(WsMessage::Close(Some(CloseFrame {
179                code: 1000u16.into(), // Normal closure
180                reason: "Client initiated close".into(),
181            })))
182            .await
183            .map_err(|e| Error::Transport(e.to_string()))?;
184        drop(connection);
185
186        let mut connection = self.read_connection.lock().await;
187        // Wait for the close frame from the server
188        while let Some(msg) = connection.next().await {
189            match msg {
190                Ok(WsMessage::Close(_)) => break,
191                Ok(_) => continue,
192                Err(e) => {
193                    if matches!(e, WsError::ConnectionClosed) {
194                        break;
195                    }
196                    return Err(Error::Transport(e.to_string()));
197                }
198            }
199        }
200
201        Ok(())
202    }
203}
204
205// Client-specific implementation
206impl WebSocketTransport<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>> {
207    /// Creates a new WebSocket transport as a client by connecting to the specified URL.
208    ///
209    /// # Arguments
210    ///
211    /// * `url` - The WebSocket URL to connect to (ws:// or wss://)
212    ///
213    /// # Returns
214    ///
215    /// Returns a Result containing:
216    /// - Ok: The new WebSocketTransport instance
217    /// - Err: An error if connection fails
218    pub async fn new(url: &str) -> Result<Self, Error> {
219        let url = Url::parse(url).map_err(|e| Error::Transport(e.to_string()))?;
220
221        let (ws_stream, _) = connect_async(url)
222            .await
223            .map_err(|e| Error::Transport(e.to_string()))?;
224        let (w, r) = ws_stream.split();
225        Ok(Self {
226            read_connection: Arc::new(Mutex::new(r)),
227            write_connection: Arc::new(Mutex::new(w)),
228        })
229    }
230}