1use async_trait::async_trait;
4use bytes::Bytes;
5use futures_util::{SinkExt, StreamExt};
6use parking_lot::Mutex;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use tokio::io::AsyncWriteExt;
10use tokio::sync::mpsc;
11use tokio_tungstenite::{
12 connect_async,
13 tungstenite::{
14 handshake::{
15 client::generate_key,
16 server::{Request as HsRequest, Response as HsResponse},
17 },
18 http::Request,
19 protocol::Message as WsMessage,
20 },
21};
22use tracing::{debug, error, info, warn};
23
24use crate::error::{Result, TransportError};
25use crate::traits::{
26 Transport, TransportEvent, TransportReceiver, TransportSender, TransportServer,
27};
28
29use clasp_core::WS_SUBPROTOCOL;
30
31pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 1000;
34
35#[derive(Debug, Clone)]
37pub struct WebSocketConfig {
38 pub subprotocol: String,
40 pub max_message_size: usize,
42 pub ping_interval: u64,
44 pub channel_buffer_size: usize,
46}
47
48impl Default for WebSocketConfig {
49 fn default() -> Self {
50 Self {
51 subprotocol: WS_SUBPROTOCOL.to_string(),
52 max_message_size: 64 * 1024, ping_interval: 30,
54 channel_buffer_size: DEFAULT_CHANNEL_BUFFER_SIZE,
55 }
56 }
57}
58
59pub struct WebSocketTransport {
61 #[allow(dead_code)]
62 config: WebSocketConfig,
63}
64
65impl WebSocketTransport {
66 pub fn new() -> Self {
67 Self {
68 config: WebSocketConfig::default(),
69 }
70 }
71
72 pub fn with_config(config: WebSocketConfig) -> Self {
73 Self { config }
74 }
75}
76
77impl Default for WebSocketTransport {
78 fn default() -> Self {
79 Self::new()
80 }
81}
82
83pub struct WebSocketSender {
85 tx: mpsc::Sender<WsMessage>,
86 connected: Arc<Mutex<bool>>,
87}
88
89#[async_trait]
90impl TransportSender for WebSocketSender {
91 async fn send(&self, data: Bytes) -> Result<()> {
92 if !self.is_connected() {
93 return Err(TransportError::NotConnected);
94 }
95
96 self.tx
97 .send(WsMessage::Binary(data.to_vec()))
98 .await
99 .map_err(|e| TransportError::SendFailed(e.to_string()))
100 }
101
102 fn try_send(&self, data: Bytes) -> Result<()> {
103 if !self.is_connected() {
104 return Err(TransportError::NotConnected);
105 }
106
107 self.tx
108 .try_send(WsMessage::Binary(data.to_vec()))
109 .map_err(|e| match e {
110 mpsc::error::TrySendError::Full(_) => TransportError::BufferFull,
111 mpsc::error::TrySendError::Closed(_) => TransportError::ConnectionClosed,
112 })
113 }
114
115 fn is_connected(&self) -> bool {
116 *self.connected.lock()
117 }
118
119 async fn close(&self) -> Result<()> {
120 let _ = self.tx.send(WsMessage::Close(None)).await;
121 *self.connected.lock() = false;
122 Ok(())
123 }
124}
125
126pub struct WebSocketReceiver {
128 rx: mpsc::Receiver<TransportEvent>,
129}
130
131#[async_trait]
132impl TransportReceiver for WebSocketReceiver {
133 async fn recv(&mut self) -> Option<TransportEvent> {
134 self.rx.recv().await
135 }
136}
137
138#[async_trait]
139impl Transport for WebSocketTransport {
140 type Sender = WebSocketSender;
141 type Receiver = WebSocketReceiver;
142
143 async fn connect(url: &str) -> Result<(Self::Sender, Self::Receiver)> {
144 info!("Connecting to WebSocket: {}", url);
145
146 let parsed_url =
148 url::Url::parse(url).map_err(|e| TransportError::InvalidUrl(e.to_string()))?;
149
150 let host = parsed_url
151 .host_str()
152 .ok_or_else(|| TransportError::InvalidUrl("Missing host in URL".to_string()))?;
153
154 let host_header = if let Some(port) = parsed_url.port() {
155 format!("{}:{}", host, port)
156 } else {
157 host.to_string()
158 };
159
160 let ws_key = generate_key();
162 let request = Request::builder()
163 .method("GET")
164 .uri(url)
165 .header("Host", &host_header)
166 .header("Upgrade", "websocket")
167 .header("Connection", "Upgrade")
168 .header("Sec-WebSocket-Key", &ws_key)
169 .header("Sec-WebSocket-Version", "13")
170 .header("Sec-WebSocket-Protocol", WS_SUBPROTOCOL)
171 .body(())
172 .map_err(|e| TransportError::InvalidUrl(e.to_string()))?;
173
174 let (ws_stream, response) = connect_async(request)
176 .await
177 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
178
179 debug!("WebSocket connected, response: {:?}", response.status());
180
181 if let Some(protocol) = response.headers().get("Sec-WebSocket-Protocol") {
183 debug!("Server subprotocol: {:?}", protocol);
184 }
185
186 let (write, read) = ws_stream.split();
188
189 let (send_tx, mut send_rx) = mpsc::channel::<WsMessage>(DEFAULT_CHANNEL_BUFFER_SIZE);
191 let (event_tx, event_rx) = mpsc::channel::<TransportEvent>(DEFAULT_CHANNEL_BUFFER_SIZE);
192
193 let connected = Arc::new(Mutex::new(true));
194 let connected_write = connected.clone();
195 let connected_read = connected.clone();
196
197 tokio::spawn(async move {
199 let mut write = write;
200 while let Some(msg) = send_rx.recv().await {
201 if let Err(e) = write.send(msg).await {
202 error!("WebSocket write error: {}", e);
203 break;
204 }
205 }
206 *connected_write.lock() = false;
207 });
208
209 let event_tx_clone = event_tx.clone();
211 tokio::spawn(async move {
212 let mut read = read;
213
214 let _ = event_tx_clone.send(TransportEvent::Connected).await;
216
217 while let Some(result) = read.next().await {
218 match result {
219 Ok(msg) => {
220 match msg {
221 WsMessage::Binary(data) => {
222 let _ = event_tx_clone
223 .send(TransportEvent::Data(Bytes::from(data)))
224 .await;
225 }
226 WsMessage::Text(text) => {
227 warn!("Received text message, converting to bytes");
229 let _ = event_tx_clone
230 .send(TransportEvent::Data(Bytes::from(text)))
231 .await;
232 }
233 WsMessage::Ping(data) => {
234 debug!("Received ping");
235 let _ = data;
237 }
238 WsMessage::Pong(_) => {
239 debug!("Received pong");
240 }
241 WsMessage::Close(frame) => {
242 let reason = frame.map(|f| f.reason.to_string());
243 info!("WebSocket closed: {:?}", reason);
244 let _ = event_tx_clone
245 .send(TransportEvent::Disconnected { reason })
246 .await;
247 break;
248 }
249 WsMessage::Frame(_) => {
250 }
252 }
253 }
254 Err(e) => {
255 error!("WebSocket read error: {}", e);
256 let _ = event_tx_clone
257 .send(TransportEvent::Error(e.to_string()))
258 .await;
259 let _ = event_tx_clone
260 .send(TransportEvent::Disconnected {
261 reason: Some(e.to_string()),
262 })
263 .await;
264 break;
265 }
266 }
267 }
268
269 *connected_read.lock() = false;
270 });
271
272 let sender = WebSocketSender {
273 tx: send_tx,
274 connected,
275 };
276
277 let receiver = WebSocketReceiver { rx: event_rx };
278
279 Ok((sender, receiver))
280 }
281
282 fn local_addr(&self) -> Option<SocketAddr> {
283 None
284 }
285
286 fn remote_addr(&self) -> Option<SocketAddr> {
287 None
288 }
289}
290
291pub struct WebSocketServer {
293 listener: tokio::net::TcpListener,
294 config: WebSocketConfig,
295}
296
297impl WebSocketServer {
298 pub async fn bind(addr: &str) -> Result<Self> {
299 let listener = tokio::net::TcpListener::bind(addr)
300 .await
301 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
302
303 info!("WebSocket server listening on {}", addr);
304
305 Ok(Self {
306 listener,
307 config: WebSocketConfig::default(),
308 })
309 }
310
311 pub fn with_config(mut self, config: WebSocketConfig) -> Self {
312 self.config = config;
313 self
314 }
315}
316
317#[async_trait]
318impl TransportServer for WebSocketServer {
319 type Sender = WebSocketSender;
320 type Receiver = WebSocketReceiver;
321
322 async fn accept(&mut self) -> Result<(Self::Sender, Self::Receiver, SocketAddr)> {
323 let (stream, addr) = loop {
324 let (mut stream, addr) = self
325 .listener
326 .accept()
327 .await
328 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
329
330 let mut peek_buf = [0u8; 4096];
338 match stream.peek(&mut peek_buf).await {
339 Ok(n) if n > 0 => {
340 if let Ok(text) = std::str::from_utf8(&peek_buf[..n]) {
341 let lower = text.to_ascii_lowercase();
342 let is_definitely_not_ws = if text.starts_with("HEAD ")
350 || text.starts_with("POST ")
351 || text.starts_with("OPTIONS ")
352 {
353 true
354 } else if text.starts_with("GET ") {
355 let has_upgrade = lower.contains("upgrade: websocket");
356 let headers_complete = lower.contains("\r\n\r\n");
357 !has_upgrade && headers_complete
358 } else {
359 false
360 };
361
362 if is_definitely_not_ws {
363 info!("Plain HTTP probe from {}, responding 200", addr);
364 let resp = "HTTP/1.1 200 OK\r\n\
365 Content-Type: text/plain\r\n\
366 Content-Length: 3\r\n\
367 Connection: close\r\n\r\nok\n";
368 let _ = stream.try_write(resp.as_bytes());
369 let _ = stream.shutdown().await;
370 continue;
371 }
372 }
373 }
374 Ok(_) => {
375 info!("Empty TCP probe from {}", addr);
377 let _ = stream.shutdown().await;
378 continue;
379 }
380 Err(e) => {
381 warn!("Peek error from {}: {}", addr, e);
382 let _ = stream.shutdown().await;
383 continue;
384 }
385 }
386
387 break (stream, addr);
388 };
389
390 debug!("Accepted TCP connection from {}", addr);
391
392 let subprotocol = self.config.subprotocol.clone();
394 let ws_stream = tokio_tungstenite::accept_hdr_async(
395 stream,
396 |req: &HsRequest, mut response: HsResponse| {
397 if let Some(protocols) = req.headers().get("Sec-WebSocket-Protocol") {
399 if let Ok(protocols_str) = protocols.to_str() {
400 let requested: Vec<&str> =
402 protocols_str.split(',').map(|s| s.trim()).collect();
403 if requested.contains(&subprotocol.as_str()) {
404 response
406 .headers_mut()
407 .insert("Sec-WebSocket-Protocol", subprotocol.parse().unwrap());
408 }
409 }
410 }
411 Ok(response)
412 },
413 )
414 .await
415 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
416
417 info!("WebSocket client connected from {}", addr);
418
419 let (write, read) = ws_stream.split();
421
422 let buffer_size = self.config.channel_buffer_size;
424 let (send_tx, mut send_rx) = mpsc::channel::<WsMessage>(buffer_size);
425 let (event_tx, event_rx) = mpsc::channel::<TransportEvent>(buffer_size);
426
427 let connected = Arc::new(Mutex::new(true));
428 let connected_write = connected.clone();
429 let connected_read = connected.clone();
430
431 tokio::spawn(async move {
433 let mut write = write;
434 while let Some(msg) = send_rx.recv().await {
435 if let Err(e) = write.send(msg).await {
436 error!("WebSocket write error: {}", e);
437 break;
438 }
439 }
440 *connected_write.lock() = false;
441 });
442
443 let event_tx_clone = event_tx.clone();
445 tokio::spawn(async move {
446 let mut read = read;
447
448 let _ = event_tx_clone.send(TransportEvent::Connected).await;
449
450 while let Some(result) = read.next().await {
451 match result {
452 Ok(msg) => match msg {
453 WsMessage::Binary(data) => {
454 let _ = event_tx_clone
455 .send(TransportEvent::Data(Bytes::from(data)))
456 .await;
457 }
458 WsMessage::Close(frame) => {
459 let reason = frame.map(|f| f.reason.to_string());
460 let _ = event_tx_clone
461 .send(TransportEvent::Disconnected { reason })
462 .await;
463 break;
464 }
465 WsMessage::Ping(_) | WsMessage::Pong(_) => {
466 }
468 WsMessage::Text(_) => {
469 debug!("Ignoring unexpected text WebSocket frame");
470 }
471 _ => {}
472 },
473 Err(e) => {
474 let _ = event_tx_clone
475 .send(TransportEvent::Disconnected {
476 reason: Some(e.to_string()),
477 })
478 .await;
479 break;
480 }
481 }
482 }
483
484 *connected_read.lock() = false;
485 });
486
487 let sender = WebSocketSender {
488 tx: send_tx,
489 connected,
490 };
491
492 let receiver = WebSocketReceiver { rx: event_rx };
493
494 Ok((sender, receiver, addr))
495 }
496
497 fn local_addr(&self) -> Result<SocketAddr> {
498 self.listener.local_addr().map_err(TransportError::Io)
499 }
500
501 async fn close(&self) -> Result<()> {
502 Ok(())
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every default of `WebSocketConfig`. This is a synchronous check, so
    /// no async runtime is needed.
    #[test]
    fn test_websocket_config() {
        let config = WebSocketConfig::default();
        assert_eq!(config.subprotocol, "clasp");
        assert_eq!(config.max_message_size, 64 * 1024);
        assert_eq!(config.ping_interval, 30);
        assert_eq!(config.channel_buffer_size, DEFAULT_CHANNEL_BUFFER_SIZE);
    }
}