leptos_sync_core/transport/
websocket.rs

1//! WebSocket transport implementation with real network communication
2
3use super::{SyncTransport, TransportError};
4use std::collections::VecDeque;
5use std::sync::Arc;
6use std::time::Duration;
7use thiserror::Error;
8use tokio::sync::{mpsc, RwLock};
9
10#[cfg(target_arch = "wasm32")]
11use wasm_bindgen::prelude::*;
12#[cfg(target_arch = "wasm32")]
13use web_sys::{CloseEvent, ErrorEvent, MessageEvent, WebSocket};
14
15#[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
16use futures_util::{SinkExt, StreamExt};
17#[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
18use tokio_tungstenite::{connect_async, tungstenite::Message};
19
20// Re-export Message for use in the code
21#[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
22use tungstenite::Message;
23
24#[derive(Error, Debug)]
25pub enum WebSocketError {
26    #[error("Connection failed: {0}")]
27    ConnectionFailed(String),
28    #[error("Send failed: {0}")]
29    SendFailed(String),
30    #[error("Receive failed: {0}")]
31    ReceiveFailed(String),
32    #[error("Not connected")]
33    NotConnected,
34    #[error("Serialization failed: {0}")]
35    SerializationFailed(String),
36    #[error("WebSocket error: {0}")]
37    WebSocketError(String),
38}
39
40impl From<WebSocketError> for TransportError {
41    fn from(err: WebSocketError) -> Self {
42        match err {
43            WebSocketError::ConnectionFailed(msg) => TransportError::ConnectionFailed(msg),
44            WebSocketError::SendFailed(msg) => TransportError::SendFailed(msg),
45            WebSocketError::ReceiveFailed(msg) => TransportError::ReceiveFailed(msg),
46            WebSocketError::NotConnected => TransportError::NotConnected,
47            WebSocketError::SerializationFailed(msg) => TransportError::SerializationFailed(msg),
48            WebSocketError::WebSocketError(msg) => TransportError::ConnectionFailed(msg),
49        }
50    }
51}
52
/// Lifecycle state of a [`WebSocketTransport`] connection.
///
/// Fieldless, so it is also `Eq` + `Hash` and can be used as a set or map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// No connection is open (initial state, or after a disconnect/close).
    Disconnected,
    /// `connect` has started and its first attempt is in progress.
    Connecting,
    /// A connection is open; sending is only permitted in this state.
    Connected,
    /// An attempt failed and `connect` is waiting to retry or retrying.
    Reconnecting,
    /// `connect` gave up, or the socket reported an error.
    Failed,
}
61
62pub struct WebSocketTransport {
63    url: String,
64    connection_state: Arc<RwLock<ConnectionState>>,
65    message_queue: Arc<RwLock<VecDeque<Vec<u8>>>>,
66    message_sender: Option<mpsc::UnboundedSender<Vec<u8>>>,
67    message_receiver: Arc<RwLock<Option<mpsc::UnboundedReceiver<Vec<u8>>>>>,
68    config: WebSocketConfig,
69    #[cfg(target_arch = "wasm32")]
70    websocket: Arc<RwLock<Option<WebSocket>>>,
71}
72
73impl WebSocketTransport {
74    pub fn new(url: String) -> Self {
75        Self::with_config(url, WebSocketConfig::default())
76    }
77
78    pub fn with_config(url: String, config: WebSocketConfig) -> Self {
79        let (tx, rx) = mpsc::unbounded_channel();
80        Self {
81            url,
82            connection_state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
83            message_queue: Arc::new(RwLock::new(VecDeque::new())),
84            message_sender: Some(tx),
85            message_receiver: Arc::new(RwLock::new(Some(rx))),
86            config,
87            #[cfg(target_arch = "wasm32")]
88            websocket: Arc::new(RwLock::new(None)),
89        }
90    }
91
92    pub fn with_reconnect_config(url: String, max_attempts: usize, delay_ms: u32) -> Self {
93        let config = WebSocketConfig {
94            max_reconnect_attempts: max_attempts,
95            reconnect_delay: Duration::from_millis(delay_ms as u64),
96            ..Default::default()
97        };
98        Self::with_config(url, config)
99    }
100
101    pub async fn connect(&self) -> Result<(), WebSocketError> {
102        let mut state = self.connection_state.write().await;
103        if *state == ConnectionState::Connected {
104            return Ok(());
105        }
106
107        *state = ConnectionState::Connecting;
108        drop(state);
109
110        // Attempt connection with retry logic
111        for attempt in 0..self.config.max_reconnect_attempts {
112            match self.attempt_connection().await {
113                Ok(()) => {
114                    let mut state = self.connection_state.write().await;
115                    *state = ConnectionState::Connected;
116                    return Ok(());
117                }
118                Err(e) => {
119                    if attempt < self.config.max_reconnect_attempts - 1 {
120                        tracing::warn!(
121                            "Connection attempt {} failed: {}. Retrying in {:?}...",
122                            attempt + 1,
123                            e,
124                            self.config.reconnect_delay
125                        );
126
127                        let mut state = self.connection_state.write().await;
128                        *state = ConnectionState::Reconnecting;
129                        drop(state);
130
131                        tokio::time::sleep(self.config.reconnect_delay).await;
132                    } else {
133                        let mut state = self.connection_state.write().await;
134                        *state = ConnectionState::Failed;
135                        return Err(e);
136                    }
137                }
138            }
139        }
140
141        let mut state = self.connection_state.write().await;
142        *state = ConnectionState::Failed;
143        Err(WebSocketError::ConnectionFailed(
144            "Max reconnection attempts exceeded".to_string(),
145        ))
146    }
147
148    async fn attempt_connection(&self) -> Result<(), WebSocketError> {
149        #[cfg(target_arch = "wasm32")]
150        {
151            self.connect_wasm().await
152        }
153
154        #[cfg(not(target_arch = "wasm32"))]
155        {
156            self.connect_native().await
157        }
158    }
159
160    #[cfg(target_arch = "wasm32")]
161    async fn connect_wasm(&self) -> Result<(), WebSocketError> {
162        use wasm_bindgen_futures::JsFuture;
163
164        let ws = WebSocket::new(&self.url).map_err(|e| {
165            WebSocketError::ConnectionFailed(format!("Failed to create WebSocket: {:?}", e))
166        })?;
167
168        // Set up event handlers
169        let message_queue = self.message_queue.clone();
170        let connection_state = self.connection_state.clone();
171
172        let onmessage = Closure::wrap(Box::new(move |event: MessageEvent| {
173            if let Some(data) = event.data().dyn_ref::<js_sys::Uint8Array>() {
174                let bytes: Vec<u8> = data.to_vec();
175                let message_queue = message_queue.clone();
176                wasm_bindgen_futures::spawn_local(async move {
177                    let mut queue = message_queue.write().await;
178                    queue.push_back(bytes);
179                });
180            }
181        }) as Box<dyn FnMut(_)>);
182
183        let onerror = Closure::wrap(Box::new(move |_event: ErrorEvent| {
184            let connection_state = connection_state.clone();
185            wasm_bindgen_futures::spawn_local(async move {
186                let mut state = connection_state.write().await;
187                *state = ConnectionState::Failed;
188            });
189        }) as Box<dyn FnMut(_)>);
190
191        let onclose = Closure::wrap(Box::new(move |_event: CloseEvent| {
192            let connection_state = connection_state.clone();
193            wasm_bindgen_futures::spawn_local(async move {
194                let mut state = connection_state.write().await;
195                *state = ConnectionState::Disconnected;
196            });
197        }) as Box<dyn FnMut(_)>);
198
199        ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
200        ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
201        ws.set_onclose(Some(onclose.as_ref().unchecked_ref()));
202
203        // Store the WebSocket and closures
204        {
205            let mut ws_guard = self.websocket.write().await;
206            *ws_guard = Some(ws);
207        }
208
209        // Keep closures alive
210        onmessage.forget();
211        onerror.forget();
212        onclose.forget();
213
214        Ok(())
215    }
216
217    #[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
218    async fn connect_native(&self) -> Result<(), WebSocketError> {
219        let (ws_stream, _) = connect_async(&self.url)
220            .await
221            .map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
222
223        let (mut write, mut read) = ws_stream.split();
224
225        // Spawn task to handle incoming messages
226        let message_queue = self.message_queue.clone();
227        tokio::spawn(async move {
228            while let Some(msg) = read.next().await {
229                match msg {
230                    Ok(Message::Binary(data)) => {
231                        let mut queue = message_queue.write().await;
232                        queue.push_back(data);
233                    }
234                    Ok(Message::Text(text)) => {
235                        let mut queue = message_queue.write().await;
236                        queue.push_back(text.into_bytes());
237                    }
238                    Ok(Message::Close(_)) => {
239                        break;
240                    }
241                    Err(e) => {
242                        tracing::error!("WebSocket read error: {}", e);
243                        break;
244                    }
245                    _ => {}
246                }
247            }
248        });
249
250        // Store the write half for sending messages
251        // Note: In a real implementation, we'd need to store this properly
252        // For now, we'll simulate the connection success
253        Ok(())
254    }
255
256    #[cfg(all(not(target_arch = "wasm32"), not(feature = "websocket")))]
257    async fn connect_native(&self) -> Result<(), WebSocketError> {
258        Err(WebSocketError::ConnectionFailed(
259            "WebSocket feature not enabled".to_string(),
260        ))
261    }
262
263    pub async fn disconnect(&self) -> Result<(), WebSocketError> {
264        let mut state = self.connection_state.write().await;
265        *state = ConnectionState::Disconnected;
266
267        // Clear message queue
268        let mut queue = self.message_queue.write().await;
269        queue.clear();
270
271        #[cfg(target_arch = "wasm32")]
272        {
273            let mut ws_guard = self.websocket.write().await;
274            if let Some(ws) = ws_guard.take() {
275                ws.close().ok();
276            }
277        }
278
279        Ok(())
280    }
281
282    pub async fn send_binary(&self, data: &[u8]) -> Result<(), WebSocketError> {
283        let state = self.connection_state.read().await;
284        if *state != ConnectionState::Connected {
285            return Err(WebSocketError::NotConnected);
286        }
287        drop(state);
288
289        #[cfg(target_arch = "wasm32")]
290        {
291            let ws_guard = self.websocket.read().await;
292            if let Some(ws) = ws_guard.as_ref() {
293                let array = js_sys::Uint8Array::new_with_length(data.len() as u32);
294                array.copy_from(data);
295                ws.send_with_u8_array(&array)
296                    .map_err(|e| WebSocketError::SendFailed(format!("Failed to send: {:?}", e)))?;
297            } else {
298                return Err(WebSocketError::NotConnected);
299            }
300        }
301
302        #[cfg(not(target_arch = "wasm32"))]
303        {
304            // In a real implementation, we'd use the stored write half
305            // For now, we'll simulate successful sending
306            tracing::debug!("Sent binary data: {} bytes", data.len());
307        }
308
309        Ok(())
310    }
311
312    pub async fn send_text(&self, text: &str) -> Result<(), WebSocketError> {
313        let state = self.connection_state.read().await;
314        if *state != ConnectionState::Connected {
315            return Err(WebSocketError::NotConnected);
316        }
317        drop(state);
318
319        #[cfg(target_arch = "wasm32")]
320        {
321            let ws_guard = self.websocket.read().await;
322            if let Some(ws) = ws_guard.as_ref() {
323                ws.send_with_str(text)
324                    .map_err(|e| WebSocketError::SendFailed(format!("Failed to send: {:?}", e)))?;
325            } else {
326                return Err(WebSocketError::NotConnected);
327            }
328        }
329
330        #[cfg(not(target_arch = "wasm32"))]
331        {
332            // In a real implementation, we'd use the stored write half
333            // For now, we'll simulate successful sending
334            tracing::debug!("Sent text: {}", text);
335        }
336
337        Ok(())
338    }
339
340    pub async fn connection_state(&self) -> ConnectionState {
341        self.connection_state.read().await.clone()
342    }
343
344    pub fn is_connected_sync(&self) -> bool {
345        match self.connection_state.try_read() {
346            Ok(state) => *state == ConnectionState::Connected,
347            Err(_) => false,
348        }
349    }
350}
351
352impl SyncTransport for WebSocketTransport {
353    type Error = TransportError;
354
355    fn send<'a>(
356        &'a self,
357        data: &'a [u8],
358    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Self::Error>> + Send + 'a>>
359    {
360        Box::pin(async move { self.send_binary(data).await.map_err(Into::into) })
361    }
362
363    fn receive(
364        &self,
365    ) -> std::pin::Pin<
366        Box<dyn std::future::Future<Output = Result<Vec<Vec<u8>>, Self::Error>> + Send + '_>,
367    > {
368        Box::pin(async move {
369            let mut queue = self.message_queue.write().await;
370            let messages = queue.drain(..).collect();
371            Ok(messages)
372        })
373    }
374
375    fn is_connected(&self) -> bool {
376        self.is_connected_sync()
377    }
378}
379
380impl Clone for WebSocketTransport {
381    fn clone(&self) -> Self {
382        let (tx, rx) = mpsc::unbounded_channel();
383        Self {
384            url: self.url.clone(),
385            connection_state: self.connection_state.clone(),
386            message_queue: self.message_queue.clone(),
387            message_sender: Some(tx),
388            message_receiver: Arc::new(RwLock::new(Some(rx))),
389            config: self.config.clone(),
390            #[cfg(target_arch = "wasm32")]
391            websocket: Arc::new(RwLock::new(None)),
392        }
393    }
394}
395
/// Tunables for [`WebSocketTransport`].
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Whether the transport should reconnect automatically. Default: `true`.
    /// NOTE(review): nothing in this module reads this flag yet.
    pub auto_reconnect: bool,
    /// Number of connection attempts [`WebSocketTransport::connect`] makes before giving up.
    /// Default: 5.
    pub max_reconnect_attempts: usize,
    /// Pause between consecutive connection attempts. Default: 1 second.
    pub reconnect_delay: Duration,
    /// Keep-alive interval. Default: 30 seconds.
    /// NOTE(review): this module does not send heartbeats yet.
    pub heartbeat_interval: Duration,
    /// Time budget for establishing a connection. Default: 10 seconds.
    /// NOTE(review): confirm which backends enforce it.
    pub connection_timeout: Duration,
}

impl Default for WebSocketConfig {
    /// Auto-reconnect on, 5 attempts one second apart, 30 s heartbeat, 10 s connect timeout.
    fn default() -> Self {
        let one_second = Duration::from_secs(1);
        Self {
            auto_reconnect: true,
            max_reconnect_attempts: 5,
            reconnect_delay: one_second,
            heartbeat_interval: one_second * 30,
            connection_timeout: one_second * 10,
        }
    }
}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a transport that gives up after a single attempt with no retry delay, so tests
    /// exercising connection failures do not sleep through the default back-off.
    fn fail_fast_transport(url: &str) -> WebSocketTransport {
        WebSocketTransport::with_reconnect_config(url.to_string(), 1, 0)
    }

    #[tokio::test]
    async fn test_websocket_transport_creation() {
        let transport = WebSocketTransport::new("ws://localhost:8080".to_string());
        assert_eq!(transport.url, "ws://localhost:8080");
        assert!(!transport.is_connected());
    }

    #[tokio::test]
    async fn test_websocket_config_default() {
        let config = WebSocketConfig::default();
        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.reconnect_delay, Duration::from_millis(1000));
        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(config.connection_timeout, Duration::from_secs(10));
    }

    #[tokio::test]
    async fn test_websocket_with_reconnect_config() {
        let transport =
            WebSocketTransport::with_reconnect_config("ws://localhost:8080".to_string(), 10, 2000);
        assert_eq!(transport.url, "ws://localhost:8080");
        assert_eq!(transport.config.max_reconnect_attempts, 10);
        assert_eq!(transport.config.reconnect_delay, Duration::from_millis(2000));
        // Settings the helper does not take keep their defaults.
        assert!(transport.config.auto_reconnect);
        assert_eq!(transport.config.connection_timeout, Duration::from_secs(10));
    }

    #[tokio::test]
    async fn test_websocket_transport_operations() {
        let transport = WebSocketTransport::new("ws://localhost:8080".to_string());

        // Initial state
        assert!(!transport.is_connected());
        let state = transport.connection_state().await;
        assert_eq!(state, ConnectionState::Disconnected);

        // Disconnecting a transport that never connected must not fail
        assert!(transport.disconnect().await.is_ok());

        // Sending while not connected must fail
        assert!(transport.send_binary(b"test data").await.is_err());
        assert!(transport.send_text("test message").await.is_err());

        // SyncTransport trait implementation while not connected
        assert!(transport.send(b"test").await.is_err());
        let received = transport.receive().await.unwrap();
        assert_eq!(received.len(), 0); // nothing was ever received
        assert!(!transport.is_connected());
    }

    #[tokio::test]
    async fn test_receive_drains_queue_and_disconnect_clears_it() {
        let transport = WebSocketTransport::new("ws://localhost:8080".to_string());
        {
            let mut queue = transport.message_queue.write().await;
            queue.push_back(vec![1]);
            queue.push_back(vec![2, 3]);
        }

        // Payloads come back oldest first and are removed from the queue.
        let received = transport.receive().await.unwrap();
        assert_eq!(received, vec![vec![1], vec![2, 3]]);
        assert!(transport.receive().await.unwrap().is_empty());

        // Disconnecting discards anything not yet drained.
        transport.message_queue.write().await.push_back(vec![4]);
        transport.disconnect().await.unwrap();
        assert!(transport.receive().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_websocket_transport_clone() {
        // Zero attempts: `connect` fails without touching the network.
        let transport1 =
            WebSocketTransport::with_reconnect_config("ws://localhost:8080".to_string(), 0, 0);
        let transport2 = transport1.clone();

        assert_eq!(transport1.url, transport2.url);
        assert_eq!(transport1.is_connected(), transport2.is_connected());

        // Clones share the connection state: a failure seen through one handle is visible
        // through the other.
        assert!(transport1.connect().await.is_err());
        assert_eq!(transport2.connection_state().await, ConnectionState::Failed);
    }

    #[tokio::test]
    async fn test_websocket_connection_state() {
        let transport = WebSocketTransport::new("ws://localhost:8080".to_string());

        let state = transport.connection_state().await;
        assert_eq!(state, ConnectionState::Disconnected);

        // Connecting to an unresolvable host must fail and leave the transport `Failed`.
        let invalid_transport = fail_fast_transport("ws://invalid:9999");
        let result = invalid_transport.connect().await;
        assert!(result.is_err());

        let state = invalid_transport.connection_state().await;
        assert_eq!(state, ConnectionState::Failed);
        assert!(!invalid_transport.is_connected());
    }

    #[tokio::test]
    async fn test_zero_attempts_fails_without_connecting() {
        let transport =
            WebSocketTransport::with_reconnect_config("ws://localhost:8080".to_string(), 0, 0);
        let err = transport.connect().await.unwrap_err();
        assert!(matches!(
            err,
            WebSocketError::ConnectionFailed(ref msg) if msg == "Max reconnection attempts exceeded"
        ));
        assert_eq!(transport.connection_state().await, ConnectionState::Failed);
    }

    #[tokio::test]
    async fn test_websocket_config_custom() {
        let config = WebSocketConfig {
            auto_reconnect: false,
            max_reconnect_attempts: 3,
            reconnect_delay: Duration::from_millis(500),
            heartbeat_interval: Duration::from_secs(60),
            connection_timeout: Duration::from_secs(5),
        };

        let transport = WebSocketTransport::with_config("ws://localhost:8080".to_string(), config);
        assert_eq!(transport.url, "ws://localhost:8080");
        assert!(!transport.config.auto_reconnect);
        assert_eq!(transport.config.max_reconnect_attempts, 3);
        assert_eq!(transport.config.reconnect_delay, Duration::from_millis(500));
        assert_eq!(transport.config.connection_timeout, Duration::from_secs(5));
    }
}