Skip to main content

aster/mcp/transport/
websocket.rs

1//! WebSocket Transport Implementation
2//!
3//! This module implements the WebSocket transport for MCP communication.
4//! It provides full-duplex communication over WebSocket connections.
5//!
6//! # Message Format
7//!
//! Messages are sent in JSON-RPC 2.0 format as WebSocket text frames.
//! Each text frame carries exactly one JSON object.
10
11use async_trait::async_trait;
12use futures::stream::{SplitSink, SplitStream};
13use futures::{SinkExt, StreamExt};
14use std::collections::HashMap;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::Arc;
17use std::time::Duration;
18use tokio::net::TcpStream;
19use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
20use tokio_tungstenite::tungstenite::http::Request;
21use tokio_tungstenite::tungstenite::Message;
22use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
23
24use crate::mcp::error::{McpError, McpResult};
25use crate::mcp::transport::{
26    McpMessage, McpNotification, McpRequest, McpResponse, Transport, TransportConfig,
27    TransportEvent, TransportState,
28};
29use crate::mcp::types::{ConnectionOptions, TransportType};
30
/// WebSocket-specific configuration.
///
/// Plain data: the endpoint URL plus any extra HTTP headers to attach to the
/// upgrade request. Two configurations compare equal when both the URL and
/// the full header map are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketConfig {
    /// Server URL (`ws://` or `wss://`).
    pub url: String,
    /// HTTP headers for the upgrade request, keyed by header name.
    pub headers: HashMap<String, String>,
}
39
40/// Pending request waiting for response
41struct PendingRequest {
42    /// Channel to send the response
43    tx: oneshot::Sender<McpResult<McpResponse>>,
44}
45
46type WsWriter = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
47type WsReader = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
48
49/// WebSocket transport for MCP communication
50///
51/// This transport provides full-duplex communication over WebSocket connections.
52/// Messages are sent as JSON-RPC 2.0 format over WebSocket text frames.
53pub struct WebSocketTransport {
54    /// Transport configuration
55    config: WebSocketConfig,
56    /// Connection options
57    options: ConnectionOptions,
58    /// Current transport state
59    state: Arc<RwLock<TransportState>>,
60    /// WebSocket writer
61    writer: Arc<Mutex<Option<WsWriter>>>,
62    /// Message sender channel
63    message_tx: Arc<Mutex<Option<mpsc::Sender<String>>>>,
64    /// Pending requests waiting for responses
65    pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
66    /// Event channel sender
67    event_tx: Arc<Mutex<Option<mpsc::Sender<TransportEvent>>>>,
68    /// Request ID counter
69    request_counter: AtomicU64,
70    /// Shutdown signal
71    shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
72}
73
74impl WebSocketTransport {
75    /// Create a new WebSocket transport
76    pub fn new(config: WebSocketConfig, options: ConnectionOptions) -> Self {
77        Self {
78            config,
79            options,
80            state: Arc::new(RwLock::new(TransportState::Disconnected)),
81            writer: Arc::new(Mutex::new(None)),
82            message_tx: Arc::new(Mutex::new(None)),
83            pending_requests: Arc::new(Mutex::new(HashMap::new())),
84            event_tx: Arc::new(Mutex::new(None)),
85            request_counter: AtomicU64::new(1),
86            shutdown_tx: Arc::new(Mutex::new(None)),
87        }
88    }
89
90    /// Create from transport config
91    pub fn from_config(config: TransportConfig, options: ConnectionOptions) -> McpResult<Self> {
92        match config {
93            TransportConfig::WebSocket { url, headers } => {
94                Ok(Self::new(WebSocketConfig { url, headers }, options))
95            }
96            _ => Err(McpError::config(
97                "Expected WebSocket transport configuration",
98            )),
99        }
100    }
101
102    /// Generate a unique request ID
103    pub fn next_request_id(&self) -> String {
104        let id = self.request_counter.fetch_add(1, Ordering::SeqCst);
105        format!("ws-req-{}", id)
106    }
107
108    /// Set the transport state
109    async fn set_state(&self, state: TransportState) {
110        let mut current = self.state.write().await;
111        *current = state;
112    }
113
114    /// Emit a transport event
115    async fn emit_event(&self, event: TransportEvent) {
116        if let Some(tx) = self.event_tx.lock().await.as_ref() {
117            let _ = tx.send(event).await;
118        }
119    }
120
121    /// Handle incoming message from WebSocket
122    async fn handle_message(
123        message: &str,
124        pending_requests: &Arc<Mutex<HashMap<String, PendingRequest>>>,
125        event_tx: &Arc<Mutex<Option<mpsc::Sender<TransportEvent>>>>,
126    ) {
127        // Try to parse as a response first
128        if let Ok(response) = serde_json::from_str::<McpResponse>(message) {
129            let id_str = match &response.id {
130                serde_json::Value::String(s) => s.clone(),
131                serde_json::Value::Number(n) => n.to_string(),
132                _ => return,
133            };
134
135            let mut pending = pending_requests.lock().await;
136            if let Some(req) = pending.remove(&id_str) {
137                let _ = req.tx.send(Ok(response));
138            }
139            return;
140        }
141
142        // Try to parse as a notification
143        if let Ok(notification) = serde_json::from_str::<McpNotification>(message) {
144            if let Some(tx) = event_tx.lock().await.as_ref() {
145                let _ = tx
146                    .send(TransportEvent::MessageReceived(Box::new(
147                        McpMessage::Notification(notification),
148                    )))
149                    .await;
150            }
151            return;
152        }
153
154        // Try to parse as a request (server-initiated)
155        if let Ok(request) = serde_json::from_str::<McpRequest>(message) {
156            if let Some(tx) = event_tx.lock().await.as_ref() {
157                let _ = tx
158                    .send(TransportEvent::MessageReceived(Box::new(
159                        McpMessage::Request(request),
160                    )))
161                    .await;
162            }
163        }
164    }
165
166    /// Start the reader task for WebSocket
167    fn start_reader_task(&self, mut reader: WsReader, mut shutdown_rx: mpsc::Receiver<()>) {
168        let pending_requests = self.pending_requests.clone();
169        let event_tx = self.event_tx.clone();
170        let state = self.state.clone();
171
172        tokio::spawn(async move {
173            loop {
174                tokio::select! {
175                    msg = reader.next() => {
176                        match msg {
177                            Some(Ok(Message::Text(text))) => {
178                                Self::handle_message(&text, &pending_requests, &event_tx).await;
179                            }
180                            Some(Ok(Message::Close(_))) => {
181                                let mut s = state.write().await;
182                                *s = TransportState::Disconnected;
183                                if let Some(tx) = event_tx.lock().await.as_ref() {
184                                    let _ = tx.send(TransportEvent::Disconnected {
185                                        reason: Some("WebSocket closed by server".to_string()),
186                                    }).await;
187                                }
188                                break;
189                            }
190                            Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {
191                                // Ignore ping/pong frames
192                            }
193                            Some(Ok(Message::Binary(_))) => {
194                                // Ignore binary frames for now
195                            }
196                            Some(Ok(Message::Frame(_))) => {
197                                // Ignore raw frames
198                            }
199                            Some(Err(e)) => {
200                                let mut s = state.write().await;
201                                *s = TransportState::Error;
202                                if let Some(tx) = event_tx.lock().await.as_ref() {
203                                    let _ = tx.send(TransportEvent::Error {
204                                        error: e.to_string(),
205                                    }).await;
206                                }
207                                break;
208                            }
209                            None => {
210                                let mut s = state.write().await;
211                                *s = TransportState::Disconnected;
212                                if let Some(tx) = event_tx.lock().await.as_ref() {
213                                    let _ = tx.send(TransportEvent::Disconnected {
214                                        reason: Some("WebSocket stream ended".to_string()),
215                                    }).await;
216                                }
217                                break;
218                            }
219                        }
220                    }
221                    _ = shutdown_rx.recv() => {
222                        break;
223                    }
224                }
225            }
226        });
227    }
228
229    /// Start the writer task for WebSocket
230    fn start_writer_task(&self, mut writer: WsWriter, mut message_rx: mpsc::Receiver<String>) {
231        let state = self.state.clone();
232        let event_tx = self.event_tx.clone();
233
234        tokio::spawn(async move {
235            while let Some(message) = message_rx.recv().await {
236                if let Err(e) = writer.send(Message::Text(message.into())).await {
237                    let mut s = state.write().await;
238                    *s = TransportState::Error;
239                    if let Some(tx) = event_tx.lock().await.as_ref() {
240                        let _ = tx
241                            .send(TransportEvent::Error {
242                                error: e.to_string(),
243                            })
244                            .await;
245                    }
246                    break;
247                }
248            }
249        });
250    }
251}
252
253#[async_trait]
254impl Transport for WebSocketTransport {
255    fn transport_type(&self) -> TransportType {
256        TransportType::WebSocket
257    }
258
259    fn state(&self) -> TransportState {
260        self.state
261            .try_read()
262            .map(|s| *s)
263            .unwrap_or(TransportState::Disconnected)
264    }
265
266    async fn connect(&mut self) -> McpResult<()> {
267        self.set_state(TransportState::Connecting).await;
268        self.emit_event(TransportEvent::Connecting).await;
269
270        // Build the WebSocket request with headers
271        let mut request = Request::builder().uri(&self.config.url);
272
273        for (key, value) in &self.config.headers {
274            request = request.header(key, value);
275        }
276
277        let request = request.body(()).map_err(|e| {
278            McpError::transport(format!("Failed to build WebSocket request: {}", e))
279        })?;
280
281        // Connect to WebSocket server
282        let (ws_stream, _response) = connect_async(request).await.map_err(|e| {
283            McpError::transport_with_source(
284                format!("Failed to connect to WebSocket server: {}", self.config.url),
285                e,
286            )
287        })?;
288
289        // Split the stream into reader and writer
290        let (writer, reader) = ws_stream.split();
291
292        // Create channels
293        let (message_tx, message_rx) = mpsc::channel::<String>(self.options.queue_max_size);
294        let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
295        let (event_tx, _event_rx) = mpsc::channel::<TransportEvent>(100);
296
297        // Store handles
298        *self.writer.lock().await = Some(writer);
299        *self.message_tx.lock().await = Some(message_tx);
300        *self.shutdown_tx.lock().await = Some(shutdown_tx);
301        *self.event_tx.lock().await = Some(event_tx);
302
303        // Start reader and writer tasks
304        self.start_reader_task(reader, shutdown_rx);
305        self.start_writer_task(self.writer.lock().await.take().unwrap(), message_rx);
306
307        self.set_state(TransportState::Connected).await;
308        self.emit_event(TransportEvent::Connected).await;
309
310        Ok(())
311    }
312
313    async fn disconnect(&mut self) -> McpResult<()> {
314        self.set_state(TransportState::Closing).await;
315
316        // Send shutdown signal
317        if let Some(tx) = self.shutdown_tx.lock().await.take() {
318            let _ = tx.send(()).await;
319        }
320
321        // Close message channel
322        *self.message_tx.lock().await = None;
323
324        // Clear pending requests
325        let mut pending = self.pending_requests.lock().await;
326        for (_, req) in pending.drain() {
327            let _ = req.tx.send(Err(McpError::cancelled(
328                "Transport disconnected",
329                Some("disconnect".to_string()),
330            )));
331        }
332
333        self.set_state(TransportState::Disconnected).await;
334        self.emit_event(TransportEvent::Disconnected {
335            reason: Some("Disconnected by user".to_string()),
336        })
337        .await;
338
339        Ok(())
340    }
341
342    async fn send(&mut self, message: McpMessage) -> McpResult<()> {
343        let state = *self.state.read().await;
344        if state != TransportState::Connected {
345            return Err(McpError::transport("Transport is not connected"));
346        }
347
348        let json = serde_json::to_string(&message)?;
349
350        if let Some(tx) = self.message_tx.lock().await.as_ref() {
351            tx.send(json)
352                .await
353                .map_err(|e| McpError::transport(format!("Failed to send message: {}", e)))?;
354        } else {
355            return Err(McpError::transport("Message channel not available"));
356        }
357
358        Ok(())
359    }
360
361    async fn send_request(&mut self, request: McpRequest) -> McpResult<McpResponse> {
362        self.send_request_with_timeout(request, self.options.timeout)
363            .await
364    }
365
366    async fn send_request_with_timeout(
367        &mut self,
368        request: McpRequest,
369        timeout: Duration,
370    ) -> McpResult<McpResponse> {
371        let state = *self.state.read().await;
372        if state != TransportState::Connected {
373            return Err(McpError::transport("Transport is not connected"));
374        }
375
376        // Get the request ID as string
377        let id_str = match &request.id {
378            serde_json::Value::String(s) => s.clone(),
379            serde_json::Value::Number(n) => n.to_string(),
380            _ => return Err(McpError::protocol("Invalid request ID type")),
381        };
382
383        // Create response channel
384        let (tx, rx) = oneshot::channel();
385
386        // Register pending request
387        {
388            let mut pending = self.pending_requests.lock().await;
389            pending.insert(id_str.clone(), PendingRequest { tx });
390        }
391
392        // Send the request
393        let json = serde_json::to_string(&request)?;
394        if let Some(message_tx) = self.message_tx.lock().await.as_ref() {
395            message_tx
396                .send(json)
397                .await
398                .map_err(|e| McpError::transport(format!("Failed to send request: {}", e)))?;
399        } else {
400            // Remove pending request on failure
401            self.pending_requests.lock().await.remove(&id_str);
402            return Err(McpError::transport("Message channel not available"));
403        }
404
405        // Wait for response with timeout
406        match tokio::time::timeout(timeout, rx).await {
407            Ok(Ok(result)) => result,
408            Ok(Err(_)) => {
409                // Channel closed
410                self.pending_requests.lock().await.remove(&id_str);
411                Err(McpError::transport("Response channel closed"))
412            }
413            Err(_) => {
414                // Timeout
415                self.pending_requests.lock().await.remove(&id_str);
416                Err(McpError::timeout("Request timed out", timeout))
417            }
418        }
419    }
420
421    fn subscribe(&self) -> mpsc::Receiver<TransportEvent> {
422        let (tx, rx) = mpsc::channel(100);
423        let event_tx = self.event_tx.clone();
424        tokio::spawn(async move {
425            *event_tx.lock().await = Some(tx);
426        });
427        rx
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header-less configuration for `url`.
    fn config_for(url: &str) -> WebSocketConfig {
        WebSocketConfig {
            url: url.to_string(),
            headers: HashMap::new(),
        }
    }

    /// Builds a not-yet-connected transport aimed at the local test endpoint.
    fn local_transport() -> WebSocketTransport {
        WebSocketTransport::new(
            config_for("ws://localhost:8080"),
            ConnectionOptions::default(),
        )
    }

    #[test]
    fn test_websocket_config() {
        let config = config_for("ws://localhost:8080");
        assert_eq!(config.url, "ws://localhost:8080");
    }

    #[test]
    fn test_websocket_transport_new() {
        let transport = local_transport();
        assert_eq!(transport.transport_type(), TransportType::WebSocket);
        assert_eq!(transport.state(), TransportState::Disconnected);
    }

    #[test]
    fn test_from_config() {
        let config = TransportConfig::WebSocket {
            url: "ws://localhost:8080".to_string(),
            headers: HashMap::new(),
        };
        let transport = WebSocketTransport::from_config(config, ConnectionOptions::default());
        assert!(transport.is_ok());
    }

    #[test]
    fn test_from_config_wrong_type() {
        // A stdio configuration must be rejected by the WebSocket transport.
        let config = TransportConfig::Stdio {
            command: "node".to_string(),
            args: vec![],
            env: HashMap::new(),
            cwd: None,
        };
        let transport = WebSocketTransport::from_config(config, ConnectionOptions::default());
        assert!(transport.is_err());
    }

    #[test]
    fn test_next_request_id() {
        let transport = local_transport();

        let id1 = transport.next_request_id();
        let id2 = transport.next_request_id();

        assert_ne!(id1, id2);
        for id in [&id1, &id2] {
            assert!(id.starts_with("ws-req-"));
        }
    }

    #[tokio::test]
    async fn test_send_not_connected() {
        let mut transport = local_transport();

        // Sending before `connect` must fail instead of queueing.
        let request = McpRequest::new(serde_json::json!(1), "test/method");
        let result = transport.send(McpMessage::Request(request)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_connect_invalid_url() {
        let mut transport = WebSocketTransport::new(
            config_for("ws://localhost:99999/invalid"),
            ConnectionOptions::default(),
        );

        let result = transport.connect().await;
        assert!(result.is_err());
    }
}
517}