ipfrs_transport/
websocket.rs

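//! WebSocket transport for gateway compatibility
//!
//! Provides WebSocket-based transport for compatibility with web gateways
//! and restrictive network environments.
//!
//! Features:
//! - Text and binary message support
//! - Replies to incoming pings with pongs
//! - Configurable maximum outgoing message size and connect timeout
//!
//! Not yet implemented: automatic reconnection, periodic ping keepalive
//! (`WebSocketConfig::ping_interval` is not used by this module), and TLS
//! (`WebSocketTransport::connect` always dials a plain `ws://` URL).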
11
12use crate::transport::{
13    Connection, ConnectionMetrics, Transport, TransportCapabilities, TransportError,
14    TransportStats, TransportType,
15};
16use async_trait::async_trait;
17use bytes::Bytes;
18use futures::{SinkExt, StreamExt};
19use parking_lot::RwLock;
20use std::collections::HashMap;
21use std::net::SocketAddr;
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24use tokio::net::{TcpListener, TcpStream};
25use tokio::sync::Mutex;
26use tokio_tungstenite::{
27    accept_async, connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream,
28};
29use tracing::{debug, info};
30
/// Tunables for the WebSocket transport and the connections it creates.
///
/// The stock values live in the `Default` implementation.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Interval between keepalive pings.
    ///
    /// NOTE(review): nothing in this module reads this field, so no periodic
    /// pings are currently sent.
    pub ping_interval: Duration,
    /// Upper bound on the time spent establishing a connection.
    pub connect_timeout: Duration,
    /// Largest payload, in bytes, that a connection's `send` accepts
    /// (16 MiB by default). Incoming messages are not checked against it here.
    pub max_message_size: usize,
    /// Send payloads as binary frames when `true`, as text frames when `false`.
    pub use_binary: bool,
}
43
44impl Default for WebSocketConfig {
45    fn default() -> Self {
46        Self {
47            ping_interval: Duration::from_secs(30),
48            connect_timeout: Duration::from_secs(10),
49            max_message_size: 16 * 1024 * 1024, // 16 MB
50            use_binary: true,
51        }
52    }
53}
54
55/// WebSocket connection wrapper for client connections
56pub struct WebSocketConnection {
57    stream: Arc<Mutex<WebSocketStream<MaybeTlsStream<TcpStream>>>>,
58    remote_addr: SocketAddr,
59    metrics: Arc<RwLock<ConnectionMetrics>>,
60    created_at: Instant,
61    alive: Arc<RwLock<bool>>,
62    config: WebSocketConfig,
63}
64
65/// WebSocket connection wrapper for server connections
66pub struct WebSocketServerConnection {
67    stream: Arc<Mutex<WebSocketStream<TcpStream>>>,
68    remote_addr: SocketAddr,
69    metrics: Arc<RwLock<ConnectionMetrics>>,
70    created_at: Instant,
71    alive: Arc<RwLock<bool>>,
72    config: WebSocketConfig,
73}
74
75impl WebSocketConnection {
76    /// Create a new WebSocket client connection
77    pub fn new(
78        stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
79        remote_addr: SocketAddr,
80        config: WebSocketConfig,
81    ) -> Self {
82        debug!("WebSocket connection established to {}", remote_addr);
83
84        Self {
85            stream: Arc::new(Mutex::new(stream)),
86            remote_addr,
87            metrics: Arc::new(RwLock::new(ConnectionMetrics::default())),
88            created_at: Instant::now(),
89            alive: Arc::new(RwLock::new(true)),
90            config,
91        }
92    }
93
94    /// Send a ping frame
95    #[allow(dead_code)]
96    async fn send_ping(&self) -> Result<(), TransportError> {
97        let mut stream = self.stream.lock().await;
98        stream
99            .send(Message::Ping(vec![].into()))
100            .await
101            .map_err(|e| TransportError::SendFailed(format!("Ping failed: {}", e)))?;
102        Ok(())
103    }
104}
105
106#[async_trait]
107impl Connection for WebSocketConnection {
108    async fn send(&mut self, data: Bytes) -> Result<(), TransportError> {
109        if data.len() > self.config.max_message_size {
110            return Err(TransportError::ProtocolError(format!(
111                "Message size {} exceeds maximum {}",
112                data.len(),
113                self.config.max_message_size
114            )));
115        }
116
117        let data_len = data.len();
118        let message = if self.config.use_binary {
119            Message::Binary(data)
120        } else {
121            Message::Text(String::from_utf8_lossy(&data).to_string().into())
122        };
123
124        let mut stream = self.stream.lock().await;
125
126        stream.send(message).await.map_err(|e| {
127            *self.alive.write() = false;
128            TransportError::SendFailed(format!("WebSocket send failed: {}", e))
129        })?;
130
131        // Update metrics
132        {
133            let mut metrics = self.metrics.write();
134            metrics.bytes_sent += data_len as u64;
135        }
136
137        Ok(())
138    }
139
140    async fn receive(&mut self) -> Result<Bytes, TransportError> {
141        let mut stream = self.stream.lock().await;
142
143        loop {
144            match stream.next().await {
145                Some(Ok(message)) => match message {
146                    Message::Binary(data) => {
147                        // Update metrics
148                        {
149                            let mut metrics = self.metrics.write();
150                            metrics.bytes_received += data.len() as u64;
151                        }
152                        return Ok(Bytes::from(data));
153                    }
154                    Message::Text(text) => {
155                        // Convert Utf8Bytes to bytes::Bytes
156                        let data = Bytes::copy_from_slice(text.as_bytes());
157                        // Update metrics
158                        {
159                            let mut metrics = self.metrics.write();
160                            metrics.bytes_received += data.len() as u64;
161                        }
162                        return Ok(data);
163                    }
164                    Message::Ping(_) => {
165                        // Automatically respond with pong
166                        debug!("Received ping, sending pong");
167                        stream
168                            .send(Message::Pong(vec![].into()))
169                            .await
170                            .map_err(|e| {
171                                TransportError::SendFailed(format!("Pong failed: {}", e))
172                            })?;
173                        continue;
174                    }
175                    Message::Pong(_) => {
176                        debug!("Received pong");
177                        continue;
178                    }
179                    Message::Close(_) => {
180                        *self.alive.write() = false;
181                        return Err(TransportError::ConnectionClosed(
182                            "Received close frame".to_string(),
183                        ));
184                    }
185                    Message::Frame(_) => {
186                        // Raw frames shouldn't happen in normal operation
187                        continue;
188                    }
189                },
190                Some(Err(e)) => {
191                    *self.alive.write() = false;
192                    return Err(TransportError::ReceiveFailed(format!(
193                        "WebSocket receive error: {}",
194                        e
195                    )));
196                }
197                None => {
198                    *self.alive.write() = false;
199                    return Err(TransportError::ConnectionClosed(
200                        "WebSocket stream ended".to_string(),
201                    ));
202                }
203            }
204        }
205    }
206
207    async fn close(&mut self) -> Result<(), TransportError> {
208        *self.alive.write() = false;
209        let mut stream = self.stream.lock().await;
210        stream
211            .close(None)
212            .await
213            .map_err(|e| TransportError::ConnectionClosed(format!("Close failed: {}", e)))?;
214        debug!("WebSocket connection to {} closed", self.remote_addr);
215        Ok(())
216    }
217
218    fn is_alive(&self) -> bool {
219        *self.alive.read()
220    }
221
222    fn metrics(&self) -> ConnectionMetrics {
223        let mut metrics = self.metrics.read().clone();
224        metrics.uptime = self.created_at.elapsed();
225        metrics.active_streams = 1; // WebSocket has single stream
226        metrics
227    }
228
229    fn remote_addr(&self) -> SocketAddr {
230        self.remote_addr
231    }
232
233    fn transport_type(&self) -> TransportType {
234        TransportType::WebSocket
235    }
236}
237
238impl WebSocketServerConnection {
239    /// Create a new WebSocket server connection
240    pub fn new(
241        stream: WebSocketStream<TcpStream>,
242        remote_addr: SocketAddr,
243        config: WebSocketConfig,
244    ) -> Self {
245        debug!("WebSocket server connection accepted from {}", remote_addr);
246
247        Self {
248            stream: Arc::new(Mutex::new(stream)),
249            remote_addr,
250            metrics: Arc::new(RwLock::new(ConnectionMetrics::default())),
251            created_at: Instant::now(),
252            alive: Arc::new(RwLock::new(true)),
253            config,
254        }
255    }
256}
257
258#[async_trait]
259impl Connection for WebSocketServerConnection {
260    async fn send(&mut self, data: Bytes) -> Result<(), TransportError> {
261        if data.len() > self.config.max_message_size {
262            return Err(TransportError::ProtocolError(format!(
263                "Message size {} exceeds maximum {}",
264                data.len(),
265                self.config.max_message_size
266            )));
267        }
268
269        let data_len = data.len();
270        let message = if self.config.use_binary {
271            Message::Binary(data)
272        } else {
273            Message::Text(String::from_utf8_lossy(&data).to_string().into())
274        };
275
276        let mut stream = self.stream.lock().await;
277
278        stream.send(message).await.map_err(|e| {
279            *self.alive.write() = false;
280            TransportError::SendFailed(format!("WebSocket send failed: {}", e))
281        })?;
282
283        // Update metrics
284        {
285            let mut metrics = self.metrics.write();
286            metrics.bytes_sent += data_len as u64;
287        }
288
289        Ok(())
290    }
291
292    async fn receive(&mut self) -> Result<Bytes, TransportError> {
293        let mut stream = self.stream.lock().await;
294
295        loop {
296            match stream.next().await {
297                Some(Ok(message)) => match message {
298                    Message::Binary(data) => {
299                        // Update metrics
300                        {
301                            let mut metrics = self.metrics.write();
302                            metrics.bytes_received += data.len() as u64;
303                        }
304                        return Ok(Bytes::from(data));
305                    }
306                    Message::Text(text) => {
307                        // Convert Utf8Bytes to bytes::Bytes
308                        let data = Bytes::copy_from_slice(text.as_bytes());
309                        // Update metrics
310                        {
311                            let mut metrics = self.metrics.write();
312                            metrics.bytes_received += data.len() as u64;
313                        }
314                        return Ok(data);
315                    }
316                    Message::Ping(_) => {
317                        // Automatically respond with pong
318                        debug!("Received ping, sending pong");
319                        stream
320                            .send(Message::Pong(vec![].into()))
321                            .await
322                            .map_err(|e| {
323                                TransportError::SendFailed(format!("Pong failed: {}", e))
324                            })?;
325                        continue;
326                    }
327                    Message::Pong(_) => {
328                        debug!("Received pong");
329                        continue;
330                    }
331                    Message::Close(_) => {
332                        *self.alive.write() = false;
333                        return Err(TransportError::ConnectionClosed(
334                            "Received close frame".to_string(),
335                        ));
336                    }
337                    Message::Frame(_) => {
338                        // Raw frames shouldn't happen in normal operation
339                        continue;
340                    }
341                },
342                Some(Err(e)) => {
343                    *self.alive.write() = false;
344                    return Err(TransportError::ReceiveFailed(format!(
345                        "WebSocket receive error: {}",
346                        e
347                    )));
348                }
349                None => {
350                    *self.alive.write() = false;
351                    return Err(TransportError::ConnectionClosed(
352                        "WebSocket stream ended".to_string(),
353                    ));
354                }
355            }
356        }
357    }
358
359    async fn close(&mut self) -> Result<(), TransportError> {
360        *self.alive.write() = false;
361        let mut stream = self.stream.lock().await;
362        stream
363            .close(None)
364            .await
365            .map_err(|e| TransportError::ConnectionClosed(format!("Close failed: {}", e)))?;
366        debug!("WebSocket connection to {} closed", self.remote_addr);
367        Ok(())
368    }
369
370    fn is_alive(&self) -> bool {
371        *self.alive.read()
372    }
373
374    fn metrics(&self) -> ConnectionMetrics {
375        let mut metrics = self.metrics.read().clone();
376        metrics.uptime = self.created_at.elapsed();
377        metrics.active_streams = 1; // WebSocket has single stream
378        metrics
379    }
380
381    fn remote_addr(&self) -> SocketAddr {
382        self.remote_addr
383    }
384
385    fn transport_type(&self) -> TransportType {
386        TransportType::WebSocket
387    }
388}
389
390/// WebSocket transport implementation
391pub struct WebSocketTransport {
392    config: WebSocketConfig,
393    listener: Arc<Mutex<Option<TcpListener>>>,
394    stats: Arc<RwLock<TransportStats>>,
395    connections: Arc<RwLock<HashMap<SocketAddr, Instant>>>,
396}
397
398impl WebSocketTransport {
399    /// Create a new WebSocket transport
400    pub fn new(config: WebSocketConfig) -> Self {
401        Self {
402            config,
403            listener: Arc::new(Mutex::new(None)),
404            stats: Arc::new(RwLock::new(TransportStats::default())),
405            connections: Arc::new(RwLock::new(HashMap::new())),
406        }
407    }
408
409    /// Create with default configuration
410    pub fn default_config() -> Self {
411        Self::new(WebSocketConfig::default())
412    }
413}
414
415#[async_trait]
416impl Transport for WebSocketTransport {
417    fn transport_type(&self) -> TransportType {
418        TransportType::WebSocket
419    }
420
421    fn capabilities(&self) -> TransportCapabilities {
422        TransportCapabilities::websocket()
423    }
424
425    fn is_available(&self) -> bool {
426        // WebSocket is always available
427        true
428    }
429
430    async fn connect(&self, addr: SocketAddr) -> Result<Box<dyn Connection>, TransportError> {
431        debug!("Connecting to {} via WebSocket", addr);
432
433        // Construct WebSocket URL
434        let url = format!("ws://{}", addr);
435
436        let (ws_stream, _) = tokio::time::timeout(self.config.connect_timeout, connect_async(&url))
437            .await
438            .map_err(|_| TransportError::Timeout(self.config.connect_timeout))?
439            .map_err(|e| {
440                self.stats.write().connections_failed += 1;
441                TransportError::ConnectionFailed(format!("WebSocket connect failed: {}", e))
442            })?;
443
444        // Extract underlying TCP stream's peer address
445        let connection = WebSocketConnection::new(ws_stream, addr, self.config.clone());
446
447        // Update stats
448        {
449            let mut stats = self.stats.write();
450            stats.connections_established += 1;
451            stats.active_connections += 1;
452        }
453
454        // Track connection
455        self.connections.write().insert(addr, Instant::now());
456
457        info!("WebSocket connection established to {}", addr);
458
459        Ok(Box::new(connection))
460    }
461
462    async fn listen(&self, addr: SocketAddr) -> Result<(), TransportError> {
463        let listener = TcpListener::bind(addr).await.map_err(|e| {
464            TransportError::ConnectionFailed(format!("Failed to bind WebSocket listener: {}", e))
465        })?;
466
467        info!("WebSocket transport listening on {}", addr);
468
469        *self.listener.lock().await = Some(listener);
470        Ok(())
471    }
472
473    async fn accept(&self) -> Result<Box<dyn Connection>, TransportError> {
474        let listener = self.listener.lock().await;
475        let listener = listener
476            .as_ref()
477            .ok_or_else(|| TransportError::ProtocolError("No listener bound".to_string()))?;
478
479        let (stream, addr) = listener
480            .accept()
481            .await
482            .map_err(|e| TransportError::ConnectionFailed(format!("Accept failed: {}", e)))?;
483
484        debug!("Accepting WebSocket connection from {}", addr);
485
486        // Perform WebSocket handshake
487        let ws_stream = accept_async(stream).await.map_err(|e| {
488            TransportError::ConnectionFailed(format!("WebSocket handshake failed: {}", e))
489        })?;
490
491        let connection = WebSocketServerConnection::new(ws_stream, addr, self.config.clone());
492
493        // Update stats
494        {
495            let mut stats = self.stats.write();
496            stats.connections_established += 1;
497            stats.active_connections += 1;
498        }
499
500        // Track connection
501        self.connections.write().insert(addr, Instant::now());
502
503        info!("WebSocket connection accepted from {}", addr);
504
505        Ok(Box::new(connection))
506    }
507
508    fn stats(&self) -> TransportStats {
509        self.stats.read().clone()
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds `transport` to an ephemeral localhost port and returns the bound address.
    async fn bind_ephemeral(transport: &WebSocketTransport) -> SocketAddr {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        transport.listen(addr).await.unwrap();
        let guard = transport.listener.lock().await;
        let bound = guard.as_ref().unwrap().local_addr().unwrap();
        drop(guard);
        bound
    }

    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketConfig::default();
        assert_eq!(config.ping_interval, Duration::from_secs(30));
        assert!(config.use_binary);
        assert_eq!(config.max_message_size, 16 * 1024 * 1024);
    }

    #[tokio::test]
    async fn test_websocket_transport_creation() {
        let transport = WebSocketTransport::default_config();
        assert_eq!(transport.transport_type(), TransportType::WebSocket);
        assert!(transport.is_available());

        let caps = transport.capabilities();
        assert!(!caps.multiplexing);
        assert!(caps.encryption);
        assert_eq!(caps.max_message_size, Some(16 * 1024 * 1024));
    }

    #[tokio::test]
    async fn test_websocket_listen_and_connect() {
        let transport = Arc::new(WebSocketTransport::default_config());
        let bound_addr = bind_ephemeral(&transport).await;

        // No sleep needed: the listener is already bound, so the kernel queues
        // the TCP connection until the spawned accept() picks it up.
        let acceptor = Arc::clone(&transport);
        let accept_handle = tokio::spawn(async move { acceptor.accept().await });

        let mut client_conn = transport.connect(bound_addr).await.unwrap();
        let mut server_conn = accept_handle.await.unwrap().unwrap();

        let test_data = Bytes::from("Hello, WebSocket!");
        client_conn.send(test_data.clone()).await.unwrap();

        let received = server_conn.receive().await.unwrap();
        assert_eq!(received, test_data);

        let client_metrics = client_conn.metrics();
        assert!(client_metrics.bytes_sent > 0);

        let server_metrics = server_conn.metrics();
        assert!(server_metrics.bytes_received > 0);

        client_conn.close().await.unwrap();
        server_conn.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_websocket_rejects_oversized_message() {
        let config = WebSocketConfig {
            max_message_size: 4,
            ..WebSocketConfig::default()
        };
        let transport = Arc::new(WebSocketTransport::new(config));
        let bound_addr = bind_ephemeral(&transport).await;

        let acceptor = Arc::clone(&transport);
        let accept_handle = tokio::spawn(async move { acceptor.accept().await });

        let mut client_conn = transport.connect(bound_addr).await.unwrap();
        let _server_conn = accept_handle.await.unwrap().unwrap();

        let err = client_conn
            .send(Bytes::from_static(b"too long"))
            .await
            .unwrap_err();
        assert!(matches!(err, TransportError::ProtocolError(_)));
        // The size check runs before any I/O, so the connection stays usable.
        assert!(client_conn.is_alive());
    }
}