leptos_sync_core/transport/
websocket.rs1use super::{SyncTransport, TransportError};
4use std::collections::VecDeque;
5use std::sync::Arc;
6use std::time::Duration;
7use thiserror::Error;
8use tokio::sync::{mpsc, RwLock};
9
10#[cfg(target_arch = "wasm32")]
11use wasm_bindgen::prelude::*;
12#[cfg(target_arch = "wasm32")]
13use web_sys::{CloseEvent, ErrorEvent, MessageEvent, WebSocket};
14
15#[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
16use futures_util::{SinkExt, StreamExt};
17#[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
18use tokio_tungstenite::{connect_async, tungstenite::Message};
19
20#[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
22use tungstenite::Message;
23
24#[derive(Error, Debug)]
25pub enum WebSocketError {
26 #[error("Connection failed: {0}")]
27 ConnectionFailed(String),
28 #[error("Send failed: {0}")]
29 SendFailed(String),
30 #[error("Receive failed: {0}")]
31 ReceiveFailed(String),
32 #[error("Not connected")]
33 NotConnected,
34 #[error("Serialization failed: {0}")]
35 SerializationFailed(String),
36 #[error("WebSocket error: {0}")]
37 WebSocketError(String),
38}
39
40impl From<WebSocketError> for TransportError {
41 fn from(err: WebSocketError) -> Self {
42 match err {
43 WebSocketError::ConnectionFailed(msg) => TransportError::ConnectionFailed(msg),
44 WebSocketError::SendFailed(msg) => TransportError::SendFailed(msg),
45 WebSocketError::ReceiveFailed(msg) => TransportError::ReceiveFailed(msg),
46 WebSocketError::NotConnected => TransportError::NotConnected,
47 WebSocketError::SerializationFailed(msg) => TransportError::SerializationFailed(msg),
48 WebSocketError::WebSocketError(msg) => TransportError::ConnectionFailed(msg),
49 }
50 }
51}
52
/// Lifecycle of a [`WebSocketTransport`] connection.
///
/// A small fieldless enum, so it is `Copy` and can be compared, hashed and passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// No connection: the initial state, and the state after `disconnect()` or a closed socket.
    Disconnected,
    /// `connect()` is making its first attempt.
    Connecting,
    /// The connection is open and usable for sending.
    Connected,
    /// A connection attempt failed and `connect()` is waiting before retrying.
    Reconnecting,
    /// All connection attempts failed, or the socket reported an error.
    Failed,
}
61
62pub struct WebSocketTransport {
63 url: String,
64 connection_state: Arc<RwLock<ConnectionState>>,
65 message_queue: Arc<RwLock<VecDeque<Vec<u8>>>>,
66 message_sender: Option<mpsc::UnboundedSender<Vec<u8>>>,
67 message_receiver: Arc<RwLock<Option<mpsc::UnboundedReceiver<Vec<u8>>>>>,
68 config: WebSocketConfig,
69 #[cfg(target_arch = "wasm32")]
70 websocket: Arc<RwLock<Option<WebSocket>>>,
71}
72
73impl WebSocketTransport {
74 pub fn new(url: String) -> Self {
75 Self::with_config(url, WebSocketConfig::default())
76 }
77
78 pub fn with_config(url: String, config: WebSocketConfig) -> Self {
79 let (tx, rx) = mpsc::unbounded_channel();
80 Self {
81 url,
82 connection_state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
83 message_queue: Arc::new(RwLock::new(VecDeque::new())),
84 message_sender: Some(tx),
85 message_receiver: Arc::new(RwLock::new(Some(rx))),
86 config,
87 #[cfg(target_arch = "wasm32")]
88 websocket: Arc::new(RwLock::new(None)),
89 }
90 }
91
92 pub fn with_reconnect_config(url: String, max_attempts: usize, delay_ms: u32) -> Self {
93 let config = WebSocketConfig {
94 max_reconnect_attempts: max_attempts,
95 reconnect_delay: Duration::from_millis(delay_ms as u64),
96 ..Default::default()
97 };
98 Self::with_config(url, config)
99 }
100
101 pub async fn connect(&self) -> Result<(), WebSocketError> {
102 let mut state = self.connection_state.write().await;
103 if *state == ConnectionState::Connected {
104 return Ok(());
105 }
106
107 *state = ConnectionState::Connecting;
108 drop(state);
109
110 for attempt in 0..self.config.max_reconnect_attempts {
112 match self.attempt_connection().await {
113 Ok(()) => {
114 let mut state = self.connection_state.write().await;
115 *state = ConnectionState::Connected;
116 return Ok(());
117 }
118 Err(e) => {
119 if attempt < self.config.max_reconnect_attempts - 1 {
120 tracing::warn!(
121 "Connection attempt {} failed: {}. Retrying in {:?}...",
122 attempt + 1,
123 e,
124 self.config.reconnect_delay
125 );
126
127 let mut state = self.connection_state.write().await;
128 *state = ConnectionState::Reconnecting;
129 drop(state);
130
131 tokio::time::sleep(self.config.reconnect_delay).await;
132 } else {
133 let mut state = self.connection_state.write().await;
134 *state = ConnectionState::Failed;
135 return Err(e);
136 }
137 }
138 }
139 }
140
141 let mut state = self.connection_state.write().await;
142 *state = ConnectionState::Failed;
143 Err(WebSocketError::ConnectionFailed(
144 "Max reconnection attempts exceeded".to_string(),
145 ))
146 }
147
148 async fn attempt_connection(&self) -> Result<(), WebSocketError> {
149 #[cfg(target_arch = "wasm32")]
150 {
151 self.connect_wasm().await
152 }
153
154 #[cfg(not(target_arch = "wasm32"))]
155 {
156 self.connect_native().await
157 }
158 }
159
160 #[cfg(target_arch = "wasm32")]
161 async fn connect_wasm(&self) -> Result<(), WebSocketError> {
162 use wasm_bindgen_futures::JsFuture;
163
164 let ws = WebSocket::new(&self.url).map_err(|e| {
165 WebSocketError::ConnectionFailed(format!("Failed to create WebSocket: {:?}", e))
166 })?;
167
168 let message_queue = self.message_queue.clone();
170 let connection_state = self.connection_state.clone();
171
172 let onmessage = Closure::wrap(Box::new(move |event: MessageEvent| {
173 if let Some(data) = event.data().dyn_ref::<js_sys::Uint8Array>() {
174 let bytes: Vec<u8> = data.to_vec();
175 let message_queue = message_queue.clone();
176 wasm_bindgen_futures::spawn_local(async move {
177 let mut queue = message_queue.write().await;
178 queue.push_back(bytes);
179 });
180 }
181 }) as Box<dyn FnMut(_)>);
182
183 let onerror = Closure::wrap(Box::new(move |_event: ErrorEvent| {
184 let connection_state = connection_state.clone();
185 wasm_bindgen_futures::spawn_local(async move {
186 let mut state = connection_state.write().await;
187 *state = ConnectionState::Failed;
188 });
189 }) as Box<dyn FnMut(_)>);
190
191 let onclose = Closure::wrap(Box::new(move |_event: CloseEvent| {
192 let connection_state = connection_state.clone();
193 wasm_bindgen_futures::spawn_local(async move {
194 let mut state = connection_state.write().await;
195 *state = ConnectionState::Disconnected;
196 });
197 }) as Box<dyn FnMut(_)>);
198
199 ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
200 ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
201 ws.set_onclose(Some(onclose.as_ref().unchecked_ref()));
202
203 {
205 let mut ws_guard = self.websocket.write().await;
206 *ws_guard = Some(ws);
207 }
208
209 onmessage.forget();
211 onerror.forget();
212 onclose.forget();
213
214 Ok(())
215 }
216
217 #[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
218 async fn connect_native(&self) -> Result<(), WebSocketError> {
219 let (ws_stream, _) = connect_async(&self.url)
220 .await
221 .map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
222
223 let (mut write, mut read) = ws_stream.split();
224
225 let message_queue = self.message_queue.clone();
227 tokio::spawn(async move {
228 while let Some(msg) = read.next().await {
229 match msg {
230 Ok(Message::Binary(data)) => {
231 let mut queue = message_queue.write().await;
232 queue.push_back(data);
233 }
234 Ok(Message::Text(text)) => {
235 let mut queue = message_queue.write().await;
236 queue.push_back(text.into_bytes());
237 }
238 Ok(Message::Close(_)) => {
239 break;
240 }
241 Err(e) => {
242 tracing::error!("WebSocket read error: {}", e);
243 break;
244 }
245 _ => {}
246 }
247 }
248 });
249
250 Ok(())
254 }
255
256 #[cfg(all(not(target_arch = "wasm32"), not(feature = "websocket")))]
257 async fn connect_native(&self) -> Result<(), WebSocketError> {
258 Err(WebSocketError::ConnectionFailed(
259 "WebSocket feature not enabled".to_string(),
260 ))
261 }
262
263 pub async fn disconnect(&self) -> Result<(), WebSocketError> {
264 let mut state = self.connection_state.write().await;
265 *state = ConnectionState::Disconnected;
266
267 let mut queue = self.message_queue.write().await;
269 queue.clear();
270
271 #[cfg(target_arch = "wasm32")]
272 {
273 let mut ws_guard = self.websocket.write().await;
274 if let Some(ws) = ws_guard.take() {
275 ws.close().ok();
276 }
277 }
278
279 Ok(())
280 }
281
282 pub async fn send_binary(&self, data: &[u8]) -> Result<(), WebSocketError> {
283 let state = self.connection_state.read().await;
284 if *state != ConnectionState::Connected {
285 return Err(WebSocketError::NotConnected);
286 }
287 drop(state);
288
289 #[cfg(target_arch = "wasm32")]
290 {
291 let ws_guard = self.websocket.read().await;
292 if let Some(ws) = ws_guard.as_ref() {
293 let array = js_sys::Uint8Array::new_with_length(data.len() as u32);
294 array.copy_from(data);
295 ws.send_with_u8_array(&array)
296 .map_err(|e| WebSocketError::SendFailed(format!("Failed to send: {:?}", e)))?;
297 } else {
298 return Err(WebSocketError::NotConnected);
299 }
300 }
301
302 #[cfg(not(target_arch = "wasm32"))]
303 {
304 tracing::debug!("Sent binary data: {} bytes", data.len());
307 }
308
309 Ok(())
310 }
311
312 pub async fn send_text(&self, text: &str) -> Result<(), WebSocketError> {
313 let state = self.connection_state.read().await;
314 if *state != ConnectionState::Connected {
315 return Err(WebSocketError::NotConnected);
316 }
317 drop(state);
318
319 #[cfg(target_arch = "wasm32")]
320 {
321 let ws_guard = self.websocket.read().await;
322 if let Some(ws) = ws_guard.as_ref() {
323 ws.send_with_str(text)
324 .map_err(|e| WebSocketError::SendFailed(format!("Failed to send: {:?}", e)))?;
325 } else {
326 return Err(WebSocketError::NotConnected);
327 }
328 }
329
330 #[cfg(not(target_arch = "wasm32"))]
331 {
332 tracing::debug!("Sent text: {}", text);
335 }
336
337 Ok(())
338 }
339
340 pub async fn connection_state(&self) -> ConnectionState {
341 self.connection_state.read().await.clone()
342 }
343
344 pub fn is_connected_sync(&self) -> bool {
345 match self.connection_state.try_read() {
346 Ok(state) => *state == ConnectionState::Connected,
347 Err(_) => false,
348 }
349 }
350}
351
352impl SyncTransport for WebSocketTransport {
353 type Error = TransportError;
354
355 fn send<'a>(
356 &'a self,
357 data: &'a [u8],
358 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Self::Error>> + Send + 'a>>
359 {
360 Box::pin(async move { self.send_binary(data).await.map_err(Into::into) })
361 }
362
363 fn receive(
364 &self,
365 ) -> std::pin::Pin<
366 Box<dyn std::future::Future<Output = Result<Vec<Vec<u8>>, Self::Error>> + Send + '_>,
367 > {
368 Box::pin(async move {
369 let mut queue = self.message_queue.write().await;
370 let messages = queue.drain(..).collect();
371 Ok(messages)
372 })
373 }
374
375 fn is_connected(&self) -> bool {
376 self.is_connected_sync()
377 }
378}
379
380impl Clone for WebSocketTransport {
381 fn clone(&self) -> Self {
382 let (tx, rx) = mpsc::unbounded_channel();
383 Self {
384 url: self.url.clone(),
385 connection_state: self.connection_state.clone(),
386 message_queue: self.message_queue.clone(),
387 message_sender: Some(tx),
388 message_receiver: Arc::new(RwLock::new(Some(rx))),
389 config: self.config.clone(),
390 #[cfg(target_arch = "wasm32")]
391 websocket: Arc::new(RwLock::new(None)),
392 }
393 }
394}
395
/// Connection settings for [`WebSocketTransport`].
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Whether reconnection should happen automatically.
    /// NOTE(review): not read by `WebSocketTransport` in this file — confirm intended use.
    pub auto_reconnect: bool,
    /// How many connection attempts `WebSocketTransport::connect` makes before giving up.
    pub max_reconnect_attempts: usize,
    /// Pause between two failed connection attempts.
    pub reconnect_delay: Duration,
    /// Presumably the keep-alive interval.
    /// NOTE(review): not read by `WebSocketTransport` in this file — confirm.
    pub heartbeat_interval: Duration,
    /// Intended upper bound for establishing a connection.
    pub connection_timeout: Duration,
}

impl Default for WebSocketConfig {
    /// Auto-reconnect enabled, 5 attempts 1 s apart, 30 s heartbeat, 10 s connection timeout.
    fn default() -> Self {
        let one_second = Duration::from_secs(1);
        WebSocketConfig {
            auto_reconnect: true,
            max_reconnect_attempts: 5,
            reconnect_delay: one_second,
            heartbeat_interval: one_second * 30,
            connection_timeout: one_second * 10,
        }
    }
}
417
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_websocket_transport_creation() {
        let transport = WebSocketTransport::new("ws://localhost:8080".to_string());
        assert_eq!(transport.url, "ws://localhost:8080");
        assert!(!transport.is_connected());
    }

    #[tokio::test]
    async fn test_websocket_config_default() {
        let config = WebSocketConfig::default();
        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.reconnect_delay, Duration::from_millis(1000));
        assert_eq!(config.heartbeat_interval, Duration::from_secs(30));
        assert_eq!(config.connection_timeout, Duration::from_secs(10));
    }

    #[tokio::test]
    async fn test_websocket_with_reconnect_config() {
        let transport =
            WebSocketTransport::with_reconnect_config("ws://localhost:8080".to_string(), 10, 2000);
        assert_eq!(transport.url, "ws://localhost:8080");
        // The retry policy must actually reach the transport's configuration...
        assert_eq!(transport.config.max_reconnect_attempts, 10);
        assert_eq!(transport.config.reconnect_delay, Duration::from_millis(2000));
        // ...while everything else keeps its default.
        assert!(transport.config.auto_reconnect);
    }

    #[tokio::test]
    async fn test_websocket_transport_operations() {
        let transport = WebSocketTransport::new("ws://localhost:8080".to_string());

        assert!(!transport.is_connected());
        assert_eq!(
            transport.connection_state().await,
            ConnectionState::Disconnected
        );

        // Disconnecting an idle transport is harmless.
        assert!(transport.disconnect().await.is_ok());

        // Every send path refuses to work without a connection.
        assert!(transport.send_binary(b"test data").await.is_err());
        assert!(transport.send_text("test message").await.is_err());
        assert!(transport.send(b"test").await.is_err());

        let received = transport.receive().await.unwrap();
        assert!(received.is_empty());
        assert!(!transport.is_connected());
    }

    #[tokio::test]
    async fn test_websocket_transport_clone() {
        let transport1 = WebSocketTransport::new("ws://localhost:8080".to_string());
        let transport2 = transport1.clone();

        assert_eq!(transport1.url, transport2.url);
        assert_eq!(transport1.is_connected(), transport2.is_connected());
        assert_eq!(
            transport1.config.max_reconnect_attempts,
            transport2.config.max_reconnect_attempts
        );
    }

    #[tokio::test]
    async fn test_websocket_connection_state() {
        let transport = WebSocketTransport::new("ws://localhost:8080".to_string());
        assert_eq!(
            transport.connection_state().await,
            ConnectionState::Disconnected
        );

        // One attempt with no back-off keeps this fast; the default policy would sleep 1 s
        // between each of its 5 attempts before giving up.
        let invalid_transport =
            WebSocketTransport::with_reconnect_config("ws://invalid:9999".to_string(), 1, 0);
        assert!(invalid_transport.connect().await.is_err());
        assert_eq!(
            invalid_transport.connection_state().await,
            ConnectionState::Failed
        );
    }

    #[tokio::test]
    async fn test_websocket_config_custom() {
        let config = WebSocketConfig {
            auto_reconnect: false,
            max_reconnect_attempts: 3,
            reconnect_delay: Duration::from_millis(500),
            heartbeat_interval: Duration::from_secs(60),
            connection_timeout: Duration::from_secs(5),
        };

        let transport = WebSocketTransport::with_config("ws://localhost:8080".to_string(), config);
        assert_eq!(transport.url, "ws://localhost:8080");
        assert!(!transport.config.auto_reconnect);
        assert_eq!(transport.config.max_reconnect_attempts, 3);
        assert_eq!(transport.config.reconnect_delay, Duration::from_millis(500));
        assert_eq!(transport.config.heartbeat_interval, Duration::from_secs(60));
        assert_eq!(transport.config.connection_timeout, Duration::from_secs(5));
    }
}