mcp_proxy/
ws_transport.rs1use std::sync::Arc;
20use std::sync::atomic::{AtomicBool, Ordering};
21
22use async_trait::async_trait;
23use futures_util::{SinkExt, StreamExt};
24use tokio::sync::Mutex;
25use tokio_tungstenite::tungstenite::Message;
26
27type WsStream =
28 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
29
30pub struct WebSocketClientTransport {
35 sink: Arc<Mutex<futures_util::stream::SplitSink<WsStream, Message>>>,
36 stream: Arc<Mutex<futures_util::stream::SplitStream<WsStream>>>,
37 connected: Arc<AtomicBool>,
38}
39
40impl WebSocketClientTransport {
41 pub async fn connect(url: &str) -> anyhow::Result<Self> {
47 let (ws_stream, _response) = tokio_tungstenite::connect_async(url)
48 .await
49 .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {e}"))?;
50
51 let (sink, stream) = ws_stream.split();
52
53 Ok(Self {
54 sink: Arc::new(Mutex::new(sink)),
55 stream: Arc::new(Mutex::new(stream)),
56 connected: Arc::new(AtomicBool::new(true)),
57 })
58 }
59
60 pub async fn connect_with_bearer_token(url: &str, token: &str) -> anyhow::Result<Self> {
64 use tokio_tungstenite::tungstenite::http::Request;
65
66 let request = Request::builder()
67 .uri(url)
68 .header("Authorization", format!("Bearer {token}"))
69 .header("Connection", "Upgrade")
70 .header("Upgrade", "websocket")
71 .header("Sec-WebSocket-Version", "13")
72 .header(
73 "Sec-WebSocket-Key",
74 tokio_tungstenite::tungstenite::handshake::client::generate_key(),
75 )
76 .body(())
77 .map_err(|e| anyhow::anyhow!("invalid WebSocket request: {e}"))?;
78
79 let (ws_stream, _response) = tokio_tungstenite::connect_async(request)
80 .await
81 .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {e}"))?;
82
83 let (sink, stream) = ws_stream.split();
84
85 Ok(Self {
86 sink: Arc::new(Mutex::new(sink)),
87 stream: Arc::new(Mutex::new(stream)),
88 connected: Arc::new(AtomicBool::new(true)),
89 })
90 }
91}
92
93#[async_trait]
94impl tower_mcp::client::ClientTransport for WebSocketClientTransport {
95 async fn send(&mut self, message: &str) -> tower_mcp::error::Result<()> {
96 let mut sink = self.sink.lock().await;
97 sink.send(Message::Text(message.into()))
98 .await
99 .map_err(|e| tower_mcp::error::Error::Transport(e.to_string()))?;
100 Ok(())
101 }
102
103 async fn recv(&mut self) -> tower_mcp::error::Result<Option<String>> {
104 let mut stream = self.stream.lock().await;
105 loop {
106 match stream.next().await {
107 Some(Ok(Message::Text(text))) => return Ok(Some(text.as_str().to_owned())),
108 Some(Ok(Message::Close(_))) | None => {
109 self.connected.store(false, Ordering::SeqCst);
110 return Ok(None);
111 }
112 Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => {
113 continue;
115 }
116 Some(Ok(Message::Binary(data))) => {
117 let text = std::str::from_utf8(&data)
119 .map_err(|e| tower_mcp::error::Error::Transport(e.to_string()))?;
120 return Ok(Some(text.to_string()));
121 }
122 Some(Err(e)) => {
123 self.connected.store(false, Ordering::SeqCst);
124 return Err(tower_mcp::error::Error::Transport(e.to_string()));
125 }
126 }
127 }
128 }
129
130 fn is_connected(&self) -> bool {
131 self.connected.load(Ordering::SeqCst)
132 }
133
134 async fn close(&mut self) -> tower_mcp::error::Result<()> {
135 self.connected.store(false, Ordering::SeqCst);
136 let mut sink = self.sink.lock().await;
137 let _ = sink.send(Message::Close(None)).await;
138 Ok(())
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback port 1. These tests assume nothing is listening there, so the
    /// connection attempt fails before any WebSocket handshake.
    const UNREACHABLE_URL: &str = "ws://127.0.0.1:1";

    /// Panics unless `result` is an error reporting a connection failure.
    fn assert_connection_failed(result: anyhow::Result<WebSocketClientTransport>) {
        let message = match result {
            Ok(_) => panic!("should fail"),
            Err(error) => error.to_string(),
        };
        assert!(
            message.contains("WebSocket connection failed"),
            "unexpected error: {message}"
        );
    }

    #[tokio::test]
    async fn connect_fails_with_invalid_url() {
        assert_connection_failed(WebSocketClientTransport::connect(UNREACHABLE_URL).await);
    }

    #[tokio::test]
    async fn connect_with_bearer_token_fails_with_invalid_url() {
        let outcome =
            WebSocketClientTransport::connect_with_bearer_token(UNREACHABLE_URL, "tok").await;
        assert_connection_failed(outcome);
    }
}