Skip to main content

clasp_transport/
websocket.rs

1//! WebSocket transport implementation
2
3use async_trait::async_trait;
4use bytes::Bytes;
5use futures_util::{SinkExt, StreamExt};
6use parking_lot::Mutex;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use tokio::io::AsyncWriteExt;
10use tokio::sync::mpsc;
11use tokio_tungstenite::{
12    connect_async,
13    tungstenite::{
14        handshake::{
15            client::generate_key,
16            server::{Request as HsRequest, Response as HsResponse},
17        },
18        http::Request,
19        protocol::Message as WsMessage,
20    },
21};
22use tracing::{debug, error, info, warn};
23
24use crate::error::{Result, TransportError};
25use crate::traits::{
26    Transport, TransportEvent, TransportReceiver, TransportSender, TransportServer,
27};
28
29use clasp_core::WS_SUBPROTOCOL;
30
/// Capacity, in messages, of each per-connection send and event queue.
///
/// A generous capacity absorbs bursts, so `try_send` is less likely to report
/// a full buffer while the connection is under load.
pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 1_000;
34
35/// WebSocket configuration
36#[derive(Debug, Clone)]
37pub struct WebSocketConfig {
38    /// Subprotocol to use
39    pub subprotocol: String,
40    /// Maximum message size
41    pub max_message_size: usize,
42    /// Ping interval in seconds
43    pub ping_interval: u64,
44    /// Channel buffer size for send/receive queues
45    pub channel_buffer_size: usize,
46}
47
48impl Default for WebSocketConfig {
49    fn default() -> Self {
50        Self {
51            subprotocol: WS_SUBPROTOCOL.to_string(),
52            max_message_size: 64 * 1024, // 64KB
53            ping_interval: 30,
54            channel_buffer_size: DEFAULT_CHANNEL_BUFFER_SIZE,
55        }
56    }
57}
58
59/// WebSocket transport
60pub struct WebSocketTransport {
61    #[allow(dead_code)]
62    config: WebSocketConfig,
63}
64
65impl WebSocketTransport {
66    pub fn new() -> Self {
67        Self {
68            config: WebSocketConfig::default(),
69        }
70    }
71
72    pub fn with_config(config: WebSocketConfig) -> Self {
73        Self { config }
74    }
75}
76
77impl Default for WebSocketTransport {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83/// WebSocket sender
84pub struct WebSocketSender {
85    tx: mpsc::Sender<WsMessage>,
86    connected: Arc<Mutex<bool>>,
87}
88
89#[async_trait]
90impl TransportSender for WebSocketSender {
91    async fn send(&self, data: Bytes) -> Result<()> {
92        if !self.is_connected() {
93            return Err(TransportError::NotConnected);
94        }
95
96        self.tx
97            .send(WsMessage::Binary(data.to_vec()))
98            .await
99            .map_err(|e| TransportError::SendFailed(e.to_string()))
100    }
101
102    fn try_send(&self, data: Bytes) -> Result<()> {
103        if !self.is_connected() {
104            return Err(TransportError::NotConnected);
105        }
106
107        self.tx
108            .try_send(WsMessage::Binary(data.to_vec()))
109            .map_err(|e| match e {
110                mpsc::error::TrySendError::Full(_) => TransportError::BufferFull,
111                mpsc::error::TrySendError::Closed(_) => TransportError::ConnectionClosed,
112            })
113    }
114
115    fn is_connected(&self) -> bool {
116        *self.connected.lock()
117    }
118
119    async fn close(&self) -> Result<()> {
120        let _ = self.tx.send(WsMessage::Close(None)).await;
121        *self.connected.lock() = false;
122        Ok(())
123    }
124}
125
126/// WebSocket receiver
127pub struct WebSocketReceiver {
128    rx: mpsc::Receiver<TransportEvent>,
129}
130
131#[async_trait]
132impl TransportReceiver for WebSocketReceiver {
133    async fn recv(&mut self) -> Option<TransportEvent> {
134        self.rx.recv().await
135    }
136}
137
138#[async_trait]
139impl Transport for WebSocketTransport {
140    type Sender = WebSocketSender;
141    type Receiver = WebSocketReceiver;
142
143    async fn connect(url: &str) -> Result<(Self::Sender, Self::Receiver)> {
144        info!("Connecting to WebSocket: {}", url);
145
146        // Parse the URL to extract host for the Host header
147        let parsed_url =
148            url::Url::parse(url).map_err(|e| TransportError::InvalidUrl(e.to_string()))?;
149
150        let host = parsed_url
151            .host_str()
152            .ok_or_else(|| TransportError::InvalidUrl("Missing host in URL".to_string()))?;
153
154        let host_header = if let Some(port) = parsed_url.port() {
155            format!("{}:{}", host, port)
156        } else {
157            host.to_string()
158        };
159
160        // Build a complete WebSocket upgrade request with all required headers
161        let ws_key = generate_key();
162        let request = Request::builder()
163            .method("GET")
164            .uri(url)
165            .header("Host", &host_header)
166            .header("Upgrade", "websocket")
167            .header("Connection", "Upgrade")
168            .header("Sec-WebSocket-Key", &ws_key)
169            .header("Sec-WebSocket-Version", "13")
170            .header("Sec-WebSocket-Protocol", WS_SUBPROTOCOL)
171            .body(())
172            .map_err(|e| TransportError::InvalidUrl(e.to_string()))?;
173
174        // Connect
175        let (ws_stream, response) = connect_async(request)
176            .await
177            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
178
179        debug!("WebSocket connected, response: {:?}", response.status());
180
181        // Check subprotocol
182        if let Some(protocol) = response.headers().get("Sec-WebSocket-Protocol") {
183            debug!("Server subprotocol: {:?}", protocol);
184        }
185
186        // Split the WebSocket stream
187        let (write, read) = ws_stream.split();
188
189        // Create channels with larger buffers for better load handling
190        let (send_tx, mut send_rx) = mpsc::channel::<WsMessage>(DEFAULT_CHANNEL_BUFFER_SIZE);
191        let (event_tx, event_rx) = mpsc::channel::<TransportEvent>(DEFAULT_CHANNEL_BUFFER_SIZE);
192
193        let connected = Arc::new(Mutex::new(true));
194        let connected_write = connected.clone();
195        let connected_read = connected.clone();
196
197        // Spawn writer task
198        tokio::spawn(async move {
199            let mut write = write;
200            while let Some(msg) = send_rx.recv().await {
201                if let Err(e) = write.send(msg).await {
202                    error!("WebSocket write error: {}", e);
203                    break;
204                }
205            }
206            *connected_write.lock() = false;
207        });
208
209        // Spawn reader task
210        let event_tx_clone = event_tx.clone();
211        tokio::spawn(async move {
212            let mut read = read;
213
214            // Send connected event
215            let _ = event_tx_clone.send(TransportEvent::Connected).await;
216
217            while let Some(result) = read.next().await {
218                match result {
219                    Ok(msg) => {
220                        match msg {
221                            WsMessage::Binary(data) => {
222                                let _ = event_tx_clone
223                                    .send(TransportEvent::Data(Bytes::from(data)))
224                                    .await;
225                            }
226                            WsMessage::Text(text) => {
227                                // Convert text to bytes (shouldn't happen in Clasp)
228                                warn!("Received text message, converting to bytes");
229                                let _ = event_tx_clone
230                                    .send(TransportEvent::Data(Bytes::from(text)))
231                                    .await;
232                            }
233                            WsMessage::Ping(data) => {
234                                debug!("Received ping");
235                                // Pong is handled automatically by tungstenite
236                                let _ = data;
237                            }
238                            WsMessage::Pong(_) => {
239                                debug!("Received pong");
240                            }
241                            WsMessage::Close(frame) => {
242                                let reason = frame.map(|f| f.reason.to_string());
243                                info!("WebSocket closed: {:?}", reason);
244                                let _ = event_tx_clone
245                                    .send(TransportEvent::Disconnected { reason })
246                                    .await;
247                                break;
248                            }
249                            WsMessage::Frame(_) => {
250                                // Raw frame, ignore
251                            }
252                        }
253                    }
254                    Err(e) => {
255                        error!("WebSocket read error: {}", e);
256                        let _ = event_tx_clone
257                            .send(TransportEvent::Error(e.to_string()))
258                            .await;
259                        let _ = event_tx_clone
260                            .send(TransportEvent::Disconnected {
261                                reason: Some(e.to_string()),
262                            })
263                            .await;
264                        break;
265                    }
266                }
267            }
268
269            *connected_read.lock() = false;
270        });
271
272        let sender = WebSocketSender {
273            tx: send_tx,
274            connected,
275        };
276
277        let receiver = WebSocketReceiver { rx: event_rx };
278
279        Ok((sender, receiver))
280    }
281
282    fn local_addr(&self) -> Option<SocketAddr> {
283        None
284    }
285
286    fn remote_addr(&self) -> Option<SocketAddr> {
287        None
288    }
289}
290
291/// WebSocket server
292pub struct WebSocketServer {
293    listener: tokio::net::TcpListener,
294    config: WebSocketConfig,
295}
296
297impl WebSocketServer {
298    pub async fn bind(addr: &str) -> Result<Self> {
299        let listener = tokio::net::TcpListener::bind(addr)
300            .await
301            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
302
303        info!("WebSocket server listening on {}", addr);
304
305        Ok(Self {
306            listener,
307            config: WebSocketConfig::default(),
308        })
309    }
310
311    pub fn with_config(mut self, config: WebSocketConfig) -> Self {
312        self.config = config;
313        self
314    }
315}
316
317#[async_trait]
318impl TransportServer for WebSocketServer {
319    type Sender = WebSocketSender;
320    type Receiver = WebSocketReceiver;
321
322    async fn accept(&mut self) -> Result<(Self::Sender, Self::Receiver, SocketAddr)> {
323        let (stream, addr) = loop {
324            let (mut stream, addr) = self
325                .listener
326                .accept()
327                .await
328                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
329
330            // Peek at incoming bytes to detect plain HTTP requests.
331            // Load balancers, health checkers, and platform routers may probe
332            // the WS port with plain HTTP — not a WebSocket upgrade. Without
333            // this intercept, tungstenite rejects them as bad handshakes.
334            //
335            // Use a 4 KiB buffer so that the Upgrade header is visible even
336            // when reverse proxies (e.g. Caddy) forward many browser headers.
337            let mut peek_buf = [0u8; 4096];
338            match stream.peek(&mut peek_buf).await {
339                Ok(n) if n > 0 => {
340                    if let Ok(text) = std::str::from_utf8(&peek_buf[..n]) {
341                        let lower = text.to_ascii_lowercase();
342                        // Only non-GET methods are definitely not WebSocket.
343                        // GET requests without "upgrade: websocket" are plain
344                        // HTTP only if the header block is complete (contains
345                        // the blank line terminator "\r\n\r\n").  If the peek
346                        // didn't capture the full headers we must assume it
347                        // could still be a WebSocket upgrade and let
348                        // tungstenite handle it.
349                        let is_definitely_not_ws = if text.starts_with("HEAD ")
350                            || text.starts_with("POST ")
351                            || text.starts_with("OPTIONS ")
352                        {
353                            true
354                        } else if text.starts_with("GET ") {
355                            let has_upgrade = lower.contains("upgrade: websocket");
356                            let headers_complete = lower.contains("\r\n\r\n");
357                            !has_upgrade && headers_complete
358                        } else {
359                            false
360                        };
361
362                        if is_definitely_not_ws {
363                            info!("Plain HTTP probe from {}, responding 200", addr);
364                            let resp = "HTTP/1.1 200 OK\r\n\
365                                        Content-Type: text/plain\r\n\
366                                        Content-Length: 3\r\n\
367                                        Connection: close\r\n\r\nok\n";
368                            let _ = stream.try_write(resp.as_bytes());
369                            let _ = stream.shutdown().await;
370                            continue;
371                        }
372                    }
373                }
374                Ok(_) => {
375                    // Empty peek — TCP probe, just close
376                    info!("Empty TCP probe from {}", addr);
377                    let _ = stream.shutdown().await;
378                    continue;
379                }
380                Err(e) => {
381                    warn!("Peek error from {}: {}", addr, e);
382                    let _ = stream.shutdown().await;
383                    continue;
384                }
385            }
386
387            break (stream, addr);
388        };
389
390        debug!("Accepted TCP connection from {}", addr);
391
392        // Upgrade to WebSocket with subprotocol negotiation
393        let subprotocol = self.config.subprotocol.clone();
394        let ws_stream = tokio_tungstenite::accept_hdr_async(
395            stream,
396            |req: &HsRequest, mut response: HsResponse| {
397                // Check if client requested our subprotocol
398                if let Some(protocols) = req.headers().get("Sec-WebSocket-Protocol") {
399                    if let Ok(protocols_str) = protocols.to_str() {
400                        // Client may request multiple protocols, comma-separated
401                        let requested: Vec<&str> =
402                            protocols_str.split(',').map(|s| s.trim()).collect();
403                        if requested.contains(&subprotocol.as_str()) {
404                            // Add our subprotocol to the response
405                            response
406                                .headers_mut()
407                                .insert("Sec-WebSocket-Protocol", subprotocol.parse().unwrap());
408                        }
409                    }
410                }
411                Ok(response)
412            },
413        )
414        .await
415        .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
416
417        info!("WebSocket client connected from {}", addr);
418
419        // Split the stream
420        let (write, read) = ws_stream.split();
421
422        // Create channels with configurable buffer size for better load handling
423        let buffer_size = self.config.channel_buffer_size;
424        let (send_tx, mut send_rx) = mpsc::channel::<WsMessage>(buffer_size);
425        let (event_tx, event_rx) = mpsc::channel::<TransportEvent>(buffer_size);
426
427        let connected = Arc::new(Mutex::new(true));
428        let connected_write = connected.clone();
429        let connected_read = connected.clone();
430
431        // Spawn writer task
432        tokio::spawn(async move {
433            let mut write = write;
434            while let Some(msg) = send_rx.recv().await {
435                if let Err(e) = write.send(msg).await {
436                    error!("WebSocket write error: {}", e);
437                    break;
438                }
439            }
440            *connected_write.lock() = false;
441        });
442
443        // Spawn reader task
444        let event_tx_clone = event_tx.clone();
445        tokio::spawn(async move {
446            let mut read = read;
447
448            let _ = event_tx_clone.send(TransportEvent::Connected).await;
449
450            while let Some(result) = read.next().await {
451                match result {
452                    Ok(msg) => match msg {
453                        WsMessage::Binary(data) => {
454                            let _ = event_tx_clone
455                                .send(TransportEvent::Data(Bytes::from(data)))
456                                .await;
457                        }
458                        WsMessage::Close(frame) => {
459                            let reason = frame.map(|f| f.reason.to_string());
460                            let _ = event_tx_clone
461                                .send(TransportEvent::Disconnected { reason })
462                                .await;
463                            break;
464                        }
465                        WsMessage::Ping(_) | WsMessage::Pong(_) => {
466                            // tungstenite auto-responds to Ping with Pong
467                        }
468                        WsMessage::Text(_) => {
469                            debug!("Ignoring unexpected text WebSocket frame");
470                        }
471                        _ => {}
472                    },
473                    Err(e) => {
474                        let _ = event_tx_clone
475                            .send(TransportEvent::Disconnected {
476                                reason: Some(e.to_string()),
477                            })
478                            .await;
479                        break;
480                    }
481                }
482            }
483
484            *connected_read.lock() = false;
485        });
486
487        let sender = WebSocketSender {
488            tx: send_tx,
489            connected,
490        };
491
492        let receiver = WebSocketReceiver { rx: event_rx };
493
494        Ok((sender, receiver, addr))
495    }
496
497    fn local_addr(&self) -> Result<SocketAddr> {
498        self.listener.local_addr().map_err(TransportError::Io)
499    }
500
501    async fn close(&self) -> Result<()> {
502        // TCP listener doesn't need explicit close
503        Ok(())
504    }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every default so an accidental change to any of them is caught.
    /// Nothing here is async, so no runtime is needed.
    #[test]
    fn test_websocket_config() {
        let config = WebSocketConfig::default();
        assert_eq!(config.subprotocol, "clasp");
        assert_eq!(config.max_message_size, 64 * 1024);
        assert_eq!(config.ping_interval, 30);
        assert_eq!(config.channel_buffer_size, DEFAULT_CHANNEL_BUFFER_SIZE);
    }
}