1use async_trait::async_trait;
12use futures::stream::{SplitSink, SplitStream};
13use futures::{SinkExt, StreamExt};
14use std::collections::HashMap;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::net::TcpStream;
19use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
20use tokio_tungstenite::tungstenite::http::Request;
21use tokio_tungstenite::tungstenite::Message;
22use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
23
24use crate::mcp::error::{McpError, McpResult};
25use crate::mcp::transport::{
26 McpMessage, McpNotification, McpRequest, McpResponse, Transport, TransportConfig,
27 TransportEvent, TransportState,
28};
29use crate::mcp::types::{ConnectionOptions, TransportType};
30
/// Connection settings for [`WebSocketTransport`].
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Server endpoint used as the URI of the WebSocket handshake request.
    pub url: String,
    /// Extra headers added verbatim to the handshake request (e.g. authentication).
    pub headers: HashMap<String, String>,
}
39
40struct PendingRequest {
42 tx: oneshot::Sender<McpResult<McpResponse>>,
44}
45
46type WsWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
47type WsReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
48
49pub struct WebSocketTransport {
54 config: WebSocketConfig,
56 options: ConnectionOptions,
58 state: Arc<RwLock<TransportState>>,
60 writer: Arc<Mutex<Option<WsWriter>>>,
62 message_tx: Arc<Mutex<Option<mpsc::Sender<String>>>>,
64 pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
66 event_tx: Arc<Mutex<Option<mpsc::Sender<TransportEvent>>>>,
68 request_counter: AtomicU64,
70 shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
72}
73
74impl WebSocketTransport {
75 pub fn new(config: WebSocketConfig, options: ConnectionOptions) -> Self {
77 Self {
78 config,
79 options,
80 state: Arc::new(RwLock::new(TransportState::Disconnected)),
81 writer: Arc::new(Mutex::new(None)),
82 message_tx: Arc::new(Mutex::new(None)),
83 pending_requests: Arc::new(Mutex::new(HashMap::new())),
84 event_tx: Arc::new(Mutex::new(None)),
85 request_counter: AtomicU64::new(1),
86 shutdown_tx: Arc::new(Mutex::new(None)),
87 }
88 }
89
90 pub fn from_config(config: TransportConfig, options: ConnectionOptions) -> McpResult<Self> {
92 match config {
93 TransportConfig::WebSocket { url, headers } => {
94 Ok(Self::new(WebSocketConfig { url, headers }, options))
95 }
96 _ => Err(McpError::config(
97 "Expected WebSocket transport configuration",
98 )),
99 }
100 }
101
102 pub fn next_request_id(&self) -> String {
104 let id = self.request_counter.fetch_add(1, Ordering::SeqCst);
105 format!("ws-req-{}", id)
106 }
107
108 async fn set_state(&self, state: TransportState) {
110 let mut current = self.state.write().await;
111 *current = state;
112 }
113
114 async fn emit_event(&self, event: TransportEvent) {
116 if let Some(tx) = self.event_tx.lock().await.as_ref() {
117 let _ = tx.send(event).await;
118 }
119 }
120
121 async fn handle_message(
123 message: &str,
124 pending_requests: &Arc<Mutex<HashMap<String, PendingRequest>>>,
125 event_tx: &Arc<Mutex<Option<mpsc::Sender<TransportEvent>>>>,
126 ) {
127 if let Ok(response) = serde_json::from_str::<McpResponse>(message) {
129 let id_str = match &response.id {
130 serde_json::Value::String(s) => s.clone(),
131 serde_json::Value::Number(n) => n.to_string(),
132 _ => return,
133 };
134
135 let mut pending = pending_requests.lock().await;
136 if let Some(req) = pending.remove(&id_str) {
137 let _ = req.tx.send(Ok(response));
138 }
139 return;
140 }
141
142 if let Ok(notification) = serde_json::from_str::<McpNotification>(message) {
144 if let Some(tx) = event_tx.lock().await.as_ref() {
145 let _ = tx
146 .send(TransportEvent::MessageReceived(Box::new(
147 McpMessage::Notification(notification),
148 )))
149 .await;
150 }
151 return;
152 }
153
154 if let Ok(request) = serde_json::from_str::<McpRequest>(message) {
156 if let Some(tx) = event_tx.lock().await.as_ref() {
157 let _ = tx
158 .send(TransportEvent::MessageReceived(Box::new(
159 McpMessage::Request(request),
160 )))
161 .await;
162 }
163 }
164 }
165
166 fn start_reader_task(&self, mut reader: WsReader, mut shutdown_rx: mpsc::Receiver<()>) {
168 let pending_requests = self.pending_requests.clone();
169 let event_tx = self.event_tx.clone();
170 let state = self.state.clone();
171
172 tokio::spawn(async move {
173 loop {
174 tokio::select! {
175 msg = reader.next() => {
176 match msg {
177 Some(Ok(Message::Text(text))) => {
178 Self::handle_message(&text, &pending_requests, &event_tx).await;
179 }
180 Some(Ok(Message::Close(_))) => {
181 let mut s = state.write().await;
182 *s = TransportState::Disconnected;
183 if let Some(tx) = event_tx.lock().await.as_ref() {
184 let _ = tx.send(TransportEvent::Disconnected {
185 reason: Some("WebSocket closed by server".to_string()),
186 }).await;
187 }
188 break;
189 }
190 Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {
191 }
193 Some(Ok(Message::Binary(_))) => {
194 }
196 Some(Ok(Message::Frame(_))) => {
197 }
199 Some(Err(e)) => {
200 let mut s = state.write().await;
201 *s = TransportState::Error;
202 if let Some(tx) = event_tx.lock().await.as_ref() {
203 let _ = tx.send(TransportEvent::Error {
204 error: e.to_string(),
205 }).await;
206 }
207 break;
208 }
209 None => {
210 let mut s = state.write().await;
211 *s = TransportState::Disconnected;
212 if let Some(tx) = event_tx.lock().await.as_ref() {
213 let _ = tx.send(TransportEvent::Disconnected {
214 reason: Some("WebSocket stream ended".to_string()),
215 }).await;
216 }
217 break;
218 }
219 }
220 }
221 _ = shutdown_rx.recv() => {
222 break;
223 }
224 }
225 }
226 });
227 }
228
229 fn start_writer_task(&self, mut writer: WsWriter, mut message_rx: mpsc::Receiver<String>) {
231 let state = self.state.clone();
232 let event_tx = self.event_tx.clone();
233
234 tokio::spawn(async move {
235 while let Some(message) = message_rx.recv().await {
236 if let Err(e) = writer.send(Message::Text(message.into())).await {
237 let mut s = state.write().await;
238 *s = TransportState::Error;
239 if let Some(tx) = event_tx.lock().await.as_ref() {
240 let _ = tx
241 .send(TransportEvent::Error {
242 error: e.to_string(),
243 })
244 .await;
245 }
246 break;
247 }
248 }
249 });
250 }
251}
252
253#[async_trait]
254impl Transport for WebSocketTransport {
255 fn transport_type(&self) -> TransportType {
256 TransportType::WebSocket
257 }
258
259 fn state(&self) -> TransportState {
260 self.state
261 .try_read()
262 .map(|s| *s)
263 .unwrap_or(TransportState::Disconnected)
264 }
265
266 async fn connect(&mut self) -> McpResult<()> {
267 self.set_state(TransportState::Connecting).await;
268 self.emit_event(TransportEvent::Connecting).await;
269
270 let mut request = Request::builder().uri(&self.config.url);
272
273 for (key, value) in &self.config.headers {
274 request = request.header(key, value);
275 }
276
277 let request = request.body(()).map_err(|e| {
278 McpError::transport(format!("Failed to build WebSocket request: {}", e))
279 })?;
280
281 let (ws_stream, _response) = connect_async(request).await.map_err(|e| {
283 McpError::transport_with_source(
284 format!("Failed to connect to WebSocket server: {}", self.config.url),
285 e,
286 )
287 })?;
288
289 let (writer, reader) = ws_stream.split();
291
292 let (message_tx, message_rx) = mpsc::channel::<String>(self.options.queue_max_size);
294 let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
295 let (event_tx, _event_rx) = mpsc::channel::<TransportEvent>(100);
296
297 *self.writer.lock().await = Some(writer);
299 *self.message_tx.lock().await = Some(message_tx);
300 *self.shutdown_tx.lock().await = Some(shutdown_tx);
301 *self.event_tx.lock().await = Some(event_tx);
302
303 self.start_reader_task(reader, shutdown_rx);
305 self.start_writer_task(self.writer.lock().await.take().unwrap(), message_rx);
306
307 self.set_state(TransportState::Connected).await;
308 self.emit_event(TransportEvent::Connected).await;
309
310 Ok(())
311 }
312
313 async fn disconnect(&mut self) -> McpResult<()> {
314 self.set_state(TransportState::Closing).await;
315
316 if let Some(tx) = self.shutdown_tx.lock().await.take() {
318 let _ = tx.send(()).await;
319 }
320
321 *self.message_tx.lock().await = None;
323
324 let mut pending = self.pending_requests.lock().await;
326 for (_, req) in pending.drain() {
327 let _ = req.tx.send(Err(McpError::cancelled(
328 "Transport disconnected",
329 Some("disconnect".to_string()),
330 )));
331 }
332
333 self.set_state(TransportState::Disconnected).await;
334 self.emit_event(TransportEvent::Disconnected {
335 reason: Some("Disconnected by user".to_string()),
336 })
337 .await;
338
339 Ok(())
340 }
341
342 async fn send(&mut self, message: McpMessage) -> McpResult<()> {
343 let state = *self.state.read().await;
344 if state != TransportState::Connected {
345 return Err(McpError::transport("Transport is not connected"));
346 }
347
348 let json = serde_json::to_string(&message)?;
349
350 if let Some(tx) = self.message_tx.lock().await.as_ref() {
351 tx.send(json)
352 .await
353 .map_err(|e| McpError::transport(format!("Failed to send message: {}", e)))?;
354 } else {
355 return Err(McpError::transport("Message channel not available"));
356 }
357
358 Ok(())
359 }
360
361 async fn send_request(&mut self, request: McpRequest) -> McpResult<McpResponse> {
362 self.send_request_with_timeout(request, self.options.timeout)
363 .await
364 }
365
366 async fn send_request_with_timeout(
367 &mut self,
368 request: McpRequest,
369 timeout: Duration,
370 ) -> McpResult<McpResponse> {
371 let state = *self.state.read().await;
372 if state != TransportState::Connected {
373 return Err(McpError::transport("Transport is not connected"));
374 }
375
376 let id_str = match &request.id {
378 serde_json::Value::String(s) => s.clone(),
379 serde_json::Value::Number(n) => n.to_string(),
380 _ => return Err(McpError::protocol("Invalid request ID type")),
381 };
382
383 let (tx, rx) = oneshot::channel();
385
386 {
388 let mut pending = self.pending_requests.lock().await;
389 pending.insert(id_str.clone(), PendingRequest { tx });
390 }
391
392 let json = serde_json::to_string(&request)?;
394 if let Some(message_tx) = self.message_tx.lock().await.as_ref() {
395 message_tx
396 .send(json)
397 .await
398 .map_err(|e| McpError::transport(format!("Failed to send request: {}", e)))?;
399 } else {
400 self.pending_requests.lock().await.remove(&id_str);
402 return Err(McpError::transport("Message channel not available"));
403 }
404
405 match tokio::time::timeout(timeout, rx).await {
407 Ok(Ok(result)) => result,
408 Ok(Err(_)) => {
409 self.pending_requests.lock().await.remove(&id_str);
411 Err(McpError::transport("Response channel closed"))
412 }
413 Err(_) => {
414 self.pending_requests.lock().await.remove(&id_str);
416 Err(McpError::timeout("Request timed out", timeout))
417 }
418 }
419 }
420
421 fn subscribe(&self) -> mpsc::Receiver<TransportEvent> {
422 let (tx, rx) = mpsc::channel(100);
423 let event_tx = self.event_tx.clone();
424 tokio::spawn(async move {
425 *event_tx.lock().await = Some(tx);
426 });
427 rx
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Default endpoint shared by the tests that never actually connect.
    const LOCAL_URL: &str = "ws://localhost:8080";

    /// Builds a header-less config pointing at `url`.
    fn config_for(url: &str) -> WebSocketConfig {
        WebSocketConfig {
            url: url.to_string(),
            headers: HashMap::new(),
        }
    }

    /// Builds a transport for `url` with default connection options.
    fn transport_for(url: &str) -> WebSocketTransport {
        WebSocketTransport::new(config_for(url), ConnectionOptions::default())
    }

    #[test]
    fn test_websocket_config() {
        let cfg = config_for(LOCAL_URL);
        assert_eq!(cfg.url, "ws://localhost:8080");
    }

    #[test]
    fn test_websocket_transport_new() {
        let fresh = transport_for(LOCAL_URL);
        assert_eq!(fresh.transport_type(), TransportType::WebSocket);
        assert_eq!(fresh.state(), TransportState::Disconnected);
    }

    #[test]
    fn test_from_config() {
        let generic = TransportConfig::WebSocket {
            url: LOCAL_URL.to_string(),
            headers: HashMap::new(),
        };
        let built = WebSocketTransport::from_config(generic, ConnectionOptions::default());
        assert!(built.is_ok());
    }

    #[test]
    fn test_from_config_wrong_type() {
        let stdio = TransportConfig::Stdio {
            command: "node".to_string(),
            args: vec![],
            env: HashMap::new(),
            cwd: None,
        };
        let built = WebSocketTransport::from_config(stdio, ConnectionOptions::default());
        assert!(built.is_err());
    }

    #[test]
    fn test_next_request_id() {
        let transport = transport_for(LOCAL_URL);

        let first = transport.next_request_id();
        let second = transport.next_request_id();

        assert_ne!(first, second);
        for id in [&first, &second] {
            assert!(id.starts_with("ws-req-"));
        }
    }

    #[tokio::test]
    async fn test_send_not_connected() {
        let mut transport = transport_for(LOCAL_URL);

        let message = McpMessage::Request(McpRequest::new(serde_json::json!(1), "test/method"));
        let outcome = transport.send(message).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_connect_invalid_url() {
        let mut transport = transport_for("ws://localhost:99999/invalid");

        let outcome = transport.connect().await;
        assert!(outcome.is_err());
    }
}