1use crate::transport::{
13 Connection, ConnectionMetrics, Transport, TransportCapabilities, TransportError,
14 TransportStats, TransportType,
15};
16use async_trait::async_trait;
17use bytes::Bytes;
18use futures::{SinkExt, StreamExt};
19use parking_lot::RwLock;
20use std::collections::HashMap;
21use std::net::SocketAddr;
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24use tokio::net::{TcpListener, TcpStream};
25use tokio::sync::Mutex;
26use tokio_tungstenite::{
27 accept_async, connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream,
28};
29use tracing::{debug, info};
30
/// Settings shared by the WebSocket transport and the connections it creates.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Intended spacing between keep-alive pings.
    ///
    /// NOTE(review): nothing in this file reads this value yet.
    pub ping_interval: Duration,
    /// Upper bound on how long an outbound connection attempt may take.
    pub connect_timeout: Duration,
    /// Largest payload, in bytes, that `send` accepts before rejecting it.
    pub max_message_size: usize,
    /// `true` sends payloads as binary frames; `false` sends them as text
    /// frames after a lossy UTF-8 conversion.
    pub use_binary: bool,
}

impl Default for WebSocketConfig {
    /// Returns a 30 s ping interval, a 10 s connect timeout, a 16 MiB message
    /// limit and binary framing.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;

        let ping_interval = Duration::from_secs(30);
        let connect_timeout = Duration::from_secs(10);

        Self {
            ping_interval,
            connect_timeout,
            max_message_size: 16 * MIB,
            use_binary: true,
        }
    }
}
54
55pub struct WebSocketConnection {
57 stream: Arc<Mutex<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
58 remote_addr: SocketAddr,
59 metrics: Arc<RwLock<ConnectionMetrics>>,
60 created_at: Instant,
61 alive: Arc<RwLock<bool>>,
62 config: WebSocketConfig,
63}
64
65pub struct WebSocketServerConnection {
67 stream: Arc<Mutex<WebSocketStream<TcpStream>>>,
68 remote_addr: SocketAddr,
69 metrics: Arc<RwLock<ConnectionMetrics>>,
70 created_at: Instant,
71 alive: Arc<RwLock<bool>>,
72 config: WebSocketConfig,
73}
74
75impl WebSocketConnection {
76 pub fn new(
78 stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
79 remote_addr: SocketAddr,
80 config: WebSocketConfig,
81 ) -> Self {
82 debug!("WebSocket connection established to {}", remote_addr);
83
84 Self {
85 stream: Arc::new(Mutex::new(stream)),
86 remote_addr,
87 metrics: Arc::new(RwLock::new(ConnectionMetrics::default())),
88 created_at: Instant::now(),
89 alive: Arc::new(RwLock::new(true)),
90 config,
91 }
92 }
93
94 #[allow(dead_code)]
96 async fn send_ping(&self) -> Result<(), TransportError> {
97 let mut stream = self.stream.lock().await;
98 stream
99 .send(Message::Ping(vec![].into()))
100 .await
101 .map_err(|e| TransportError::SendFailed(format!("Ping failed: {}", e)))?;
102 Ok(())
103 }
104}
105
106#[async_trait]
107impl Connection for WebSocketConnection {
108 async fn send(&mut self, data: Bytes) -> Result<(), TransportError> {
109 if data.len() > self.config.max_message_size {
110 return Err(TransportError::ProtocolError(format!(
111 "Message size {} exceeds maximum {}",
112 data.len(),
113 self.config.max_message_size
114 )));
115 }
116
117 let data_len = data.len();
118 let message = if self.config.use_binary {
119 Message::Binary(data)
120 } else {
121 Message::Text(String::from_utf8_lossy(&data).to_string().into())
122 };
123
124 let mut stream = self.stream.lock().await;
125
126 stream.send(message).await.map_err(|e| {
127 *self.alive.write() = false;
128 TransportError::SendFailed(format!("WebSocket send failed: {}", e))
129 })?;
130
131 {
133 let mut metrics = self.metrics.write();
134 metrics.bytes_sent += data_len as u64;
135 }
136
137 Ok(())
138 }
139
140 async fn receive(&mut self) -> Result<Bytes, TransportError> {
141 let mut stream = self.stream.lock().await;
142
143 loop {
144 match stream.next().await {
145 Some(Ok(message)) => match message {
146 Message::Binary(data) => {
147 {
149 let mut metrics = self.metrics.write();
150 metrics.bytes_received += data.len() as u64;
151 }
152 return Ok(Bytes::from(data));
153 }
154 Message::Text(text) => {
155 let data = Bytes::copy_from_slice(text.as_bytes());
157 {
159 let mut metrics = self.metrics.write();
160 metrics.bytes_received += data.len() as u64;
161 }
162 return Ok(data);
163 }
164 Message::Ping(_) => {
165 debug!("Received ping, sending pong");
167 stream
168 .send(Message::Pong(vec![].into()))
169 .await
170 .map_err(|e| {
171 TransportError::SendFailed(format!("Pong failed: {}", e))
172 })?;
173 continue;
174 }
175 Message::Pong(_) => {
176 debug!("Received pong");
177 continue;
178 }
179 Message::Close(_) => {
180 *self.alive.write() = false;
181 return Err(TransportError::ConnectionClosed(
182 "Received close frame".to_string(),
183 ));
184 }
185 Message::Frame(_) => {
186 continue;
188 }
189 },
190 Some(Err(e)) => {
191 *self.alive.write() = false;
192 return Err(TransportError::ReceiveFailed(format!(
193 "WebSocket receive error: {}",
194 e
195 )));
196 }
197 None => {
198 *self.alive.write() = false;
199 return Err(TransportError::ConnectionClosed(
200 "WebSocket stream ended".to_string(),
201 ));
202 }
203 }
204 }
205 }
206
207 async fn close(&mut self) -> Result<(), TransportError> {
208 *self.alive.write() = false;
209 let mut stream = self.stream.lock().await;
210 stream
211 .close(None)
212 .await
213 .map_err(|e| TransportError::ConnectionClosed(format!("Close failed: {}", e)))?;
214 debug!("WebSocket connection to {} closed", self.remote_addr);
215 Ok(())
216 }
217
218 fn is_alive(&self) -> bool {
219 *self.alive.read()
220 }
221
222 fn metrics(&self) -> ConnectionMetrics {
223 let mut metrics = self.metrics.read().clone();
224 metrics.uptime = self.created_at.elapsed();
225 metrics.active_streams = 1; metrics
227 }
228
229 fn remote_addr(&self) -> SocketAddr {
230 self.remote_addr
231 }
232
233 fn transport_type(&self) -> TransportType {
234 TransportType::WebSocket
235 }
236}
237
238impl WebSocketServerConnection {
239 pub fn new(
241 stream: WebSocketStream<TcpStream>,
242 remote_addr: SocketAddr,
243 config: WebSocketConfig,
244 ) -> Self {
245 debug!("WebSocket server connection accepted from {}", remote_addr);
246
247 Self {
248 stream: Arc::new(Mutex::new(stream)),
249 remote_addr,
250 metrics: Arc::new(RwLock::new(ConnectionMetrics::default())),
251 created_at: Instant::now(),
252 alive: Arc::new(RwLock::new(true)),
253 config,
254 }
255 }
256}
257
258#[async_trait]
259impl Connection for WebSocketServerConnection {
260 async fn send(&mut self, data: Bytes) -> Result<(), TransportError> {
261 if data.len() > self.config.max_message_size {
262 return Err(TransportError::ProtocolError(format!(
263 "Message size {} exceeds maximum {}",
264 data.len(),
265 self.config.max_message_size
266 )));
267 }
268
269 let data_len = data.len();
270 let message = if self.config.use_binary {
271 Message::Binary(data)
272 } else {
273 Message::Text(String::from_utf8_lossy(&data).to_string().into())
274 };
275
276 let mut stream = self.stream.lock().await;
277
278 stream.send(message).await.map_err(|e| {
279 *self.alive.write() = false;
280 TransportError::SendFailed(format!("WebSocket send failed: {}", e))
281 })?;
282
283 {
285 let mut metrics = self.metrics.write();
286 metrics.bytes_sent += data_len as u64;
287 }
288
289 Ok(())
290 }
291
292 async fn receive(&mut self) -> Result<Bytes, TransportError> {
293 let mut stream = self.stream.lock().await;
294
295 loop {
296 match stream.next().await {
297 Some(Ok(message)) => match message {
298 Message::Binary(data) => {
299 {
301 let mut metrics = self.metrics.write();
302 metrics.bytes_received += data.len() as u64;
303 }
304 return Ok(Bytes::from(data));
305 }
306 Message::Text(text) => {
307 let data = Bytes::copy_from_slice(text.as_bytes());
309 {
311 let mut metrics = self.metrics.write();
312 metrics.bytes_received += data.len() as u64;
313 }
314 return Ok(data);
315 }
316 Message::Ping(_) => {
317 debug!("Received ping, sending pong");
319 stream
320 .send(Message::Pong(vec![].into()))
321 .await
322 .map_err(|e| {
323 TransportError::SendFailed(format!("Pong failed: {}", e))
324 })?;
325 continue;
326 }
327 Message::Pong(_) => {
328 debug!("Received pong");
329 continue;
330 }
331 Message::Close(_) => {
332 *self.alive.write() = false;
333 return Err(TransportError::ConnectionClosed(
334 "Received close frame".to_string(),
335 ));
336 }
337 Message::Frame(_) => {
338 continue;
340 }
341 },
342 Some(Err(e)) => {
343 *self.alive.write() = false;
344 return Err(TransportError::ReceiveFailed(format!(
345 "WebSocket receive error: {}",
346 e
347 )));
348 }
349 None => {
350 *self.alive.write() = false;
351 return Err(TransportError::ConnectionClosed(
352 "WebSocket stream ended".to_string(),
353 ));
354 }
355 }
356 }
357 }
358
359 async fn close(&mut self) -> Result<(), TransportError> {
360 *self.alive.write() = false;
361 let mut stream = self.stream.lock().await;
362 stream
363 .close(None)
364 .await
365 .map_err(|e| TransportError::ConnectionClosed(format!("Close failed: {}", e)))?;
366 debug!("WebSocket connection to {} closed", self.remote_addr);
367 Ok(())
368 }
369
370 fn is_alive(&self) -> bool {
371 *self.alive.read()
372 }
373
374 fn metrics(&self) -> ConnectionMetrics {
375 let mut metrics = self.metrics.read().clone();
376 metrics.uptime = self.created_at.elapsed();
377 metrics.active_streams = 1; metrics
379 }
380
381 fn remote_addr(&self) -> SocketAddr {
382 self.remote_addr
383 }
384
385 fn transport_type(&self) -> TransportType {
386 TransportType::WebSocket
387 }
388}
389
390pub struct WebSocketTransport {
392 config: WebSocketConfig,
393 listener: Arc<Mutex<Option<TcpListener>>>,
394 stats: Arc<RwLock<TransportStats>>,
395 connections: Arc<RwLock<HashMap<SocketAddr, Instant>>>,
396}
397
398impl WebSocketTransport {
399 pub fn new(config: WebSocketConfig) -> Self {
401 Self {
402 config,
403 listener: Arc::new(Mutex::new(None)),
404 stats: Arc::new(RwLock::new(TransportStats::default())),
405 connections: Arc::new(RwLock::new(HashMap::new())),
406 }
407 }
408
409 pub fn default_config() -> Self {
411 Self::new(WebSocketConfig::default())
412 }
413}
414
415#[async_trait]
416impl Transport for WebSocketTransport {
417 fn transport_type(&self) -> TransportType {
418 TransportType::WebSocket
419 }
420
421 fn capabilities(&self) -> TransportCapabilities {
422 TransportCapabilities::websocket()
423 }
424
425 fn is_available(&self) -> bool {
426 true
428 }
429
430 async fn connect(&self, addr: SocketAddr) -> Result<Box<dyn Connection>, TransportError> {
431 debug!("Connecting to {} via WebSocket", addr);
432
433 let url = format!("ws://{}", addr);
435
436 let (ws_stream, _) = tokio::time::timeout(self.config.connect_timeout, connect_async(&url))
437 .await
438 .map_err(|_| TransportError::Timeout(self.config.connect_timeout))?
439 .map_err(|e| {
440 self.stats.write().connections_failed += 1;
441 TransportError::ConnectionFailed(format!("WebSocket connect failed: {}", e))
442 })?;
443
444 let connection = WebSocketConnection::new(ws_stream, addr, self.config.clone());
446
447 {
449 let mut stats = self.stats.write();
450 stats.connections_established += 1;
451 stats.active_connections += 1;
452 }
453
454 self.connections.write().insert(addr, Instant::now());
456
457 info!("WebSocket connection established to {}", addr);
458
459 Ok(Box::new(connection))
460 }
461
462 async fn listen(&self, addr: SocketAddr) -> Result<(), TransportError> {
463 let listener = TcpListener::bind(addr).await.map_err(|e| {
464 TransportError::ConnectionFailed(format!("Failed to bind WebSocket listener: {}", e))
465 })?;
466
467 info!("WebSocket transport listening on {}", addr);
468
469 *self.listener.lock().await = Some(listener);
470 Ok(())
471 }
472
473 async fn accept(&self) -> Result<Box<dyn Connection>, TransportError> {
474 let listener = self.listener.lock().await;
475 let listener = listener
476 .as_ref()
477 .ok_or_else(|| TransportError::ProtocolError("No listener bound".to_string()))?;
478
479 let (stream, addr) = listener
480 .accept()
481 .await
482 .map_err(|e| TransportError::ConnectionFailed(format!("Accept failed: {}", e)))?;
483
484 debug!("Accepting WebSocket connection from {}", addr);
485
486 let ws_stream = accept_async(stream).await.map_err(|e| {
488 TransportError::ConnectionFailed(format!("WebSocket handshake failed: {}", e))
489 })?;
490
491 let connection = WebSocketServerConnection::new(ws_stream, addr, self.config.clone());
492
493 {
495 let mut stats = self.stats.write();
496 stats.connections_established += 1;
497 stats.active_connections += 1;
498 }
499
500 self.connections.write().insert(addr, Instant::now());
502
503 info!("WebSocket connection accepted from {}", addr);
504
505 Ok(Box::new(connection))
506 }
507
508 fn stats(&self) -> TransportStats {
509 self.stats.read().clone()
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the expected values.
    #[test]
    fn test_websocket_config_default() {
        let defaults = WebSocketConfig::default();

        assert_eq!(defaults.ping_interval, Duration::from_secs(30));
        assert!(defaults.use_binary);
        assert_eq!(defaults.max_message_size, 16 * 1024 * 1024);
    }

    /// A freshly built transport reports the WebSocket type and capabilities.
    #[tokio::test]
    async fn test_websocket_transport_creation() {
        let transport = WebSocketTransport::default_config();

        assert_eq!(transport.transport_type(), TransportType::WebSocket);
        assert!(transport.is_available());

        let capabilities = transport.capabilities();
        assert!(!capabilities.multiplexing);
        assert!(capabilities.encryption);
        assert_eq!(capabilities.max_message_size, Some(16 * 1024 * 1024));
    }

    /// Round trip over loopback: listen, accept, connect, send, receive, close.
    #[tokio::test]
    async fn test_websocket_listen_and_connect() {
        let transport = Arc::new(WebSocketTransport::default_config());

        // Port 0 lets the OS choose a free port.
        let requested: SocketAddr = "127.0.0.1:0".parse().unwrap();
        transport.listen(requested).await.unwrap();

        // Read the real port, releasing the lock before `accept` needs it.
        let bound_addr = {
            let guard = transport.listener.lock().await;
            let bound = guard.as_ref().unwrap();
            bound.local_addr().unwrap()
        };

        let server_side = Arc::clone(&transport);
        let accept_handle = tokio::spawn(async move { server_side.accept().await });

        tokio::time::sleep(Duration::from_millis(50)).await;

        let mut client_conn = transport.connect(bound_addr).await.unwrap();
        let mut server_conn = accept_handle.await.unwrap().unwrap();

        let payload = Bytes::from("Hello, WebSocket!");
        client_conn.send(payload.clone()).await.unwrap();

        let echoed = server_conn.receive().await.unwrap();
        assert_eq!(echoed, payload);

        assert!(client_conn.metrics().bytes_sent > 0);
        assert!(server_conn.metrics().bytes_received > 0);

        client_conn.close().await.unwrap();
        server_conn.close().await.unwrap();
    }
}