Skip to main content

mcp_proxy/
ws_transport.rs

1//! WebSocket client transport for connecting to WebSocket-based MCP backends.
2//!
3//! Implements tower-mcp's [`ClientTransport`](tower_mcp::client::ClientTransport) trait over a WebSocket connection
4//! using `tokio-tungstenite`. Messages are sent and received as text frames
5//! containing JSON-RPC payloads.
6//!
7//! # Example
8//!
9//! ```rust,no_run
10//! use mcp_proxy::ws_transport::WebSocketClientTransport;
11//!
12//! # async fn example() -> anyhow::Result<()> {
13//! let transport = WebSocketClientTransport::connect("ws://localhost:8080/ws").await?;
14//! // Pass to McpProxy::builder().backend("name", transport).await
15//! # Ok(())
16//! # }
17//! ```
18
19use std::sync::Arc;
20use std::sync::atomic::{AtomicBool, Ordering};
21
22use async_trait::async_trait;
23use futures_util::{SinkExt, StreamExt};
24use tokio::sync::Mutex;
25use tokio_tungstenite::tungstenite::Message;
26
27type WsStream =
28    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
29
30/// A WebSocket client transport for MCP backend connections.
31///
32/// Connects to a WebSocket endpoint and exchanges JSON-RPC messages
33/// as text frames. Supports `ws://` and `wss://` (TLS) URLs.
34pub struct WebSocketClientTransport {
35    sink: Arc<Mutex<futures_util::stream::SplitSink<WsStream, Message>>>,
36    stream: Arc<Mutex<futures_util::stream::SplitStream<WsStream>>>,
37    connected: Arc<AtomicBool>,
38}
39
40impl WebSocketClientTransport {
41    /// Connect to a WebSocket endpoint.
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if the WebSocket handshake fails or the URL is invalid.
46    pub async fn connect(url: &str) -> anyhow::Result<Self> {
47        let (ws_stream, _response) = tokio_tungstenite::connect_async(url)
48            .await
49            .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {e}"))?;
50
51        let (sink, stream) = ws_stream.split();
52
53        Ok(Self {
54            sink: Arc::new(Mutex::new(sink)),
55            stream: Arc::new(Mutex::new(stream)),
56            connected: Arc::new(AtomicBool::new(true)),
57        })
58    }
59
60    /// Connect to a WebSocket endpoint with a bearer token for authentication.
61    ///
62    /// The token is sent in the `Authorization` header during the handshake.
63    pub async fn connect_with_bearer_token(url: &str, token: &str) -> anyhow::Result<Self> {
64        use tokio_tungstenite::tungstenite::http::Request;
65
66        let request = Request::builder()
67            .uri(url)
68            .header("Authorization", format!("Bearer {token}"))
69            .header("Connection", "Upgrade")
70            .header("Upgrade", "websocket")
71            .header("Sec-WebSocket-Version", "13")
72            .header(
73                "Sec-WebSocket-Key",
74                tokio_tungstenite::tungstenite::handshake::client::generate_key(),
75            )
76            .body(())
77            .map_err(|e| anyhow::anyhow!("invalid WebSocket request: {e}"))?;
78
79        let (ws_stream, _response) = tokio_tungstenite::connect_async(request)
80            .await
81            .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {e}"))?;
82
83        let (sink, stream) = ws_stream.split();
84
85        Ok(Self {
86            sink: Arc::new(Mutex::new(sink)),
87            stream: Arc::new(Mutex::new(stream)),
88            connected: Arc::new(AtomicBool::new(true)),
89        })
90    }
91}
92
93#[async_trait]
94impl tower_mcp::client::ClientTransport for WebSocketClientTransport {
95    async fn send(&mut self, message: &str) -> tower_mcp::error::Result<()> {
96        let mut sink = self.sink.lock().await;
97        sink.send(Message::Text(message.into()))
98            .await
99            .map_err(|e| tower_mcp::error::Error::Transport(e.to_string()))?;
100        Ok(())
101    }
102
103    async fn recv(&mut self) -> tower_mcp::error::Result<Option<String>> {
104        let mut stream = self.stream.lock().await;
105        loop {
106            match stream.next().await {
107                Some(Ok(Message::Text(text))) => return Ok(Some(text.as_str().to_owned())),
108                Some(Ok(Message::Close(_))) | None => {
109                    self.connected.store(false, Ordering::SeqCst);
110                    return Ok(None);
111                }
112                Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {
113                    // Pong is handled automatically by tungstenite; skip control frames
114                    continue;
115                }
116                Some(Ok(Message::Binary(data))) => {
117                    // Try to interpret binary as UTF-8 text
118                    let text = std::str::from_utf8(&data)
119                        .map_err(|e| tower_mcp::error::Error::Transport(e.to_string()))?;
120                    return Ok(Some(text.to_string()));
121                }
122                Some(Err(e)) => {
123                    self.connected.store(false, Ordering::SeqCst);
124                    return Err(tower_mcp::error::Error::Transport(e.to_string()));
125                }
126            }
127        }
128    }
129
130    fn is_connected(&self) -> bool {
131        self.connected.load(Ordering::SeqCst)
132    }
133
134    async fn close(&mut self) -> tower_mcp::error::Result<()> {
135        self.connected.store(false, Ordering::SeqCst);
136        let mut sink = self.sink.lock().await;
137        let _ = sink.send(Message::Close(None)).await;
138        Ok(())
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    // Port 1 is assumed to have no listener. The URL is well-formed, so these
    // attempts get past request construction and fail at the TCP connect.
    const UNREACHABLE_URL: &str = "ws://127.0.0.1:1";

    // Contains spaces, so it is not a parseable URI and is rejected before any network I/O.
    const MALFORMED_URL: &str = "not a url";

    /// Returns the error message, failing the test if the connection unexpectedly succeeded.
    fn error_message(result: anyhow::Result<WebSocketClientTransport>) -> String {
        match result {
            Ok(_) => panic!("connection should have failed"),
            Err(e) => e.to_string(),
        }
    }

    #[tokio::test]
    async fn connect_fails_when_endpoint_unreachable() {
        let err = error_message(WebSocketClientTransport::connect(UNREACHABLE_URL).await);
        assert!(
            err.contains("WebSocket connection failed"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn connect_with_bearer_token_fails_when_endpoint_unreachable() {
        let err = error_message(
            WebSocketClientTransport::connect_with_bearer_token(UNREACHABLE_URL, "tok").await,
        );
        assert!(
            err.contains("WebSocket connection failed"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn connect_rejects_malformed_url() {
        let err = error_message(WebSocketClientTransport::connect(MALFORMED_URL).await);
        assert!(
            err.contains("WebSocket connection failed"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn connect_with_bearer_token_rejects_malformed_url() {
        let err = error_message(
            WebSocketClientTransport::connect_with_bearer_token(MALFORMED_URL, "tok").await,
        );
        assert!(
            err.contains("invalid WebSocket request"),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn connect_with_bearer_token_rejects_token_that_is_not_a_header_value() {
        // A newline is not allowed in an HTTP header value, so this fails before connecting.
        let err = error_message(
            WebSocketClientTransport::connect_with_bearer_token(UNREACHABLE_URL, "bad\ntoken")
                .await,
        );
        assert!(
            err.contains("invalid WebSocket request"),
            "unexpected error: {err}"
        );
    }
}
166}