//! `leptos_sync_core/transport/websocket_client.rs`
//!
//! WebSocket client that implements [`SyncTransport`].

use std::sync::Arc;
use std::time::{Duration, SystemTime};

use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::{mpsc, RwLock};
use tokio::time::{interval, timeout};

use super::message_protocol::{MessageCodec, SyncMessage};
use super::{SyncTransport, TransportError};
use crate::crdt::ReplicaId;

13#[derive(Error, Debug)]
14pub enum WebSocketClientError {
15 #[error("Connection failed: {0}")]
16 ConnectionFailed(String),
17 #[error("Send failed: {0}")]
18 SendFailed(String),
19 #[error("Receive failed: {0}")]
20 ReceiveFailed(String),
21 #[error("Serialization failed: {0}")]
22 SerializationFailed(String),
23 #[error("Not connected")]
24 NotConnected,
25 #[error("Timeout: {0}")]
26 Timeout(String),
27 #[error("WebSocket error: {0}")]
28 WebSocketError(String),
29}
30
31impl From<WebSocketClientError> for TransportError {
32 fn from(err: WebSocketClientError) -> Self {
33 match err {
34 WebSocketClientError::ConnectionFailed(msg) => TransportError::ConnectionFailed(msg),
35 WebSocketClientError::SendFailed(msg) => TransportError::SendFailed(msg),
36 WebSocketClientError::ReceiveFailed(msg) => TransportError::ReceiveFailed(msg),
37 WebSocketClientError::SerializationFailed(msg) => TransportError::SerializationFailed(msg),
38 WebSocketClientError::NotConnected => TransportError::NotConnected,
39 WebSocketClientError::Timeout(msg) => TransportError::ConnectionFailed(msg),
40 WebSocketClientError::WebSocketError(msg) => TransportError::ConnectionFailed(msg),
41 }
42 }
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct WebSocketClientConfig {
48 pub url: String,
49 pub reconnect_attempts: u32,
50 pub heartbeat_interval: Duration,
51 pub message_timeout: Duration,
52 pub connection_timeout: Duration,
53 pub retry_delay: Duration,
54}
55
56impl Default for WebSocketClientConfig {
57 fn default() -> Self {
58 Self {
59 url: "ws://localhost:3001/sync".to_string(),
60 reconnect_attempts: 5,
61 heartbeat_interval: Duration::from_secs(30),
62 message_timeout: Duration::from_secs(10),
63 connection_timeout: Duration::from_secs(10),
64 retry_delay: Duration::from_millis(1000),
65 }
66 }
67}
68
/// Lifecycle state of a [`WebSocketClient`] connection.
///
/// A small fieldless enum, so it is `Copy` and `Eq`: snapshots can be passed
/// around by value and used as map keys or in exhaustive comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Not connected: the initial state, and the state after `disconnect`.
    Disconnected,
    /// `connect` has begun and is making its first attempt.
    Connecting,
    /// The most recent connection attempt succeeded.
    Connected,
    /// An attempt failed and another will be made after `retry_delay`.
    Reconnecting,
    /// `connect` gave up without establishing a connection.
    Failed,
}

79pub struct WebSocketClient {
81 config: WebSocketClientConfig,
82 replica_id: ReplicaId,
83 connection_state: Arc<RwLock<ConnectionState>>,
84 message_sender: mpsc::UnboundedSender<Vec<u8>>,
85 message_receiver: Arc<RwLock<mpsc::UnboundedReceiver<Vec<u8>>>>,
86 heartbeat_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
87}
88
89impl WebSocketClient {
90 pub fn new(config: WebSocketClientConfig, replica_id: ReplicaId) -> Self {
92 let (tx, rx) = mpsc::unbounded_channel();
93
94 Self {
95 config,
96 replica_id,
97 connection_state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
98 message_sender: tx,
99 message_receiver: Arc::new(RwLock::new(rx)),
100 heartbeat_task: Arc::new(RwLock::new(None)),
101 }
102 }
103
104 pub fn with_url(url: String, replica_id: ReplicaId) -> Self {
106 let config = WebSocketClientConfig {
107 url,
108 ..Default::default()
109 };
110 Self::new(config, replica_id)
111 }
112
113 pub async fn connection_state(&self) -> ConnectionState {
115 self.connection_state.read().await.clone()
116 }
117
118 pub fn replica_id(&self) -> ReplicaId {
120 self.replica_id
121 }
122
123 pub async fn connect(&self) -> Result<(), WebSocketClientError> {
125 let mut state = self.connection_state.write().await;
126 if *state == ConnectionState::Connected {
127 return Ok(());
128 }
129
130 *state = ConnectionState::Connecting;
131 drop(state);
132
133 for attempt in 0..self.config.reconnect_attempts {
135 match self.attempt_connection().await {
136 Ok(()) => {
137 let mut state = self.connection_state.write().await;
138 *state = ConnectionState::Connected;
139
140 self.start_heartbeat().await;
142
143 return Ok(());
144 }
145 Err(e) => {
146 if attempt < self.config.reconnect_attempts - 1 {
147 tracing::warn!(
148 "Connection attempt {} failed: {}. Retrying in {:?}...",
149 attempt + 1,
150 e,
151 self.config.retry_delay
152 );
153
154 let mut state = self.connection_state.write().await;
155 *state = ConnectionState::Reconnecting;
156 drop(state);
157
158 tokio::time::sleep(self.config.retry_delay).await;
159 } else {
160 let mut state = self.connection_state.write().await;
161 *state = ConnectionState::Failed;
162 return Err(e);
163 }
164 }
165 }
166 }
167
168 Err(WebSocketClientError::ConnectionFailed("Max retry attempts exceeded".to_string()))
169 }
170
171 pub async fn disconnect(&self) -> Result<(), WebSocketClientError> {
173 self.stop_heartbeat().await;
175
176 let mut state = self.connection_state.write().await;
177 *state = ConnectionState::Disconnected;
178
179 tracing::debug!("Disconnected from WebSocket server");
181 Ok(())
182 }
183
184 pub async fn send_message(&self, message: SyncMessage) -> Result<(), WebSocketClientError> {
186 if !self.is_connected().await {
187 return Err(WebSocketClientError::NotConnected);
188 }
189
190 let serialized = MessageCodec::serialize(&message)
191 .map_err(|e| WebSocketClientError::SerializationFailed(e.to_string()))?;
192
193 self.send_raw(&serialized).await
194 }
195
196 pub async fn send_raw(&self, data: &[u8]) -> Result<(), WebSocketClientError> {
198 if !self.is_connected().await {
199 return Err(WebSocketClientError::NotConnected);
200 }
201
202 tracing::debug!("Would send {} bytes via WebSocket", data.len());
204 Ok(())
205 }
206
207 pub async fn receive_message(&self) -> Result<Option<SyncMessage>, WebSocketClientError> {
209 let mut receiver = self.message_receiver.write().await;
210
211 match timeout(self.config.message_timeout, receiver.recv()).await {
212 Ok(Some(data)) => {
213 let message = MessageCodec::deserialize(&data)
214 .map_err(|e| WebSocketClientError::SerializationFailed(e.to_string()))?;
215 Ok(Some(message))
216 }
217 Ok(None) => Ok(None),
218 Err(_) => Err(WebSocketClientError::Timeout("Message receive timeout".to_string())),
219 }
220 }
221
222 pub async fn is_connected(&self) -> bool {
224 *self.connection_state.read().await == ConnectionState::Connected
225 }
226
227 async fn attempt_connection(&self) -> Result<(), WebSocketClientError> {
230 tracing::debug!("Attempting to connect to {}", self.config.url);
233
234 tokio::time::sleep(Duration::from_millis(100)).await;
236
237 Ok(())
239 }
240
241 async fn start_heartbeat(&self) {
242 let config = self.config.clone();
243 let replica_id = self.replica_id;
244 let sender = self.message_sender.clone();
245 let state = self.connection_state.clone();
246
247 let heartbeat_task = tokio::spawn(async move {
248 let mut interval = interval(config.heartbeat_interval);
249
250 loop {
251 interval.tick().await;
252
253 if *state.read().await != ConnectionState::Connected {
255 break;
256 }
257
258 let heartbeat = SyncMessage::Heartbeat {
260 replica_id,
261 timestamp: SystemTime::now(),
262 };
263
264 match MessageCodec::serialize(&heartbeat) {
265 Ok(data) => {
266 if sender.send(data).is_err() {
267 tracing::warn!("Failed to send heartbeat - connection may be lost");
268 break;
269 }
270 }
271 Err(e) => {
272 tracing::error!("Failed to serialize heartbeat: {}", e);
273 break;
274 }
275 }
276 }
277 });
278
279 let mut task_handle = self.heartbeat_task.write().await;
280 *task_handle = Some(heartbeat_task);
281 }
282
283 async fn stop_heartbeat(&self) {
284 let mut task_handle = self.heartbeat_task.write().await;
285 if let Some(task) = task_handle.take() {
286 task.abort();
287 }
288 }
289}
290
291impl SyncTransport for WebSocketClient {
292 type Error = TransportError;
293
294 fn send<'a>(&'a self, data: &'a [u8]) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Self::Error>> + Send + 'a>> {
295 Box::pin(async move {
296 self.send_raw(data).await.map_err(|e| e.into())
297 })
298 }
299
300 fn receive(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<Vec<u8>>, Self::Error>> + Send + '_>> {
301 Box::pin(async move {
302 match self.receive_message().await {
303 Ok(Some(message)) => {
304 let data = MessageCodec::serialize(&message)
305 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
306 Ok(vec![data])
307 }
308 Ok(None) => Ok(Vec::new()),
309 Err(e) => Err(e.into()),
310 }
311 })
312 }
313
314 fn is_connected(&self) -> bool {
315 true }
319}
320
321impl Clone for WebSocketClient {
322 fn clone(&self) -> Self {
323 Self::new(self.config.clone(), self.replica_id)
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crdt::ReplicaId;

    fn create_test_replica_id() -> ReplicaId {
        ReplicaId::from(uuid::Uuid::new_v4())
    }

    /// Default config with a short receive timeout, so a test waiting for a
    /// message that never arrives fails quickly instead of after 10 s.
    fn fast_config() -> WebSocketClientConfig {
        WebSocketClientConfig {
            message_timeout: Duration::from_millis(500),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_websocket_client_creation() {
        let replica_id = create_test_replica_id();
        let config = WebSocketClientConfig::default();
        let client = WebSocketClient::new(config, replica_id);

        assert_eq!(client.replica_id(), replica_id);
        assert_eq!(client.connection_state().await, ConnectionState::Disconnected);
    }

    #[tokio::test]
    async fn test_websocket_client_with_url() {
        let replica_id = create_test_replica_id();
        let client = WebSocketClient::with_url("ws://test.example.com".to_string(), replica_id);

        assert_eq!(client.config.url, "ws://test.example.com");
        assert_eq!(client.replica_id(), replica_id);
    }

    #[tokio::test]
    async fn test_connection_state_transitions() {
        let replica_id = create_test_replica_id();
        let client = WebSocketClient::new(WebSocketClientConfig::default(), replica_id);

        assert_eq!(client.connection_state().await, ConnectionState::Disconnected);

        let result = client.connect().await;
        assert!(result.is_ok());
        assert_eq!(client.connection_state().await, ConnectionState::Connected);

        let result = client.disconnect().await;
        assert!(result.is_ok());
        assert_eq!(client.connection_state().await, ConnectionState::Disconnected);
    }

    #[tokio::test]
    async fn test_send_message() {
        let replica_id = create_test_replica_id();
        let client = WebSocketClient::new(WebSocketClientConfig::default(), replica_id);

        client.connect().await.unwrap();

        // `ReplicaId` is `Copy`, so no clone is needed.
        let message = SyncMessage::Heartbeat {
            replica_id,
            timestamp: SystemTime::now(),
        };

        let result = client.send_message(message).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_send_requires_connection() {
        let client = WebSocketClient::new(fast_config(), create_test_replica_id());

        let result = client.send_raw(b"payload").await;
        assert!(matches!(result, Err(WebSocketClientError::NotConnected)));
    }

    #[tokio::test]
    async fn test_sync_transport_implementation() {
        let replica_id = create_test_replica_id();
        let client = WebSocketClient::new(fast_config(), replica_id);
        let test_data = b"test data";

        // Sends are rejected until `connect` succeeds.
        assert!(client.send(test_data).await.is_err());

        client.connect().await.unwrap();
        assert!(client.send(test_data).await.is_ok());

        // Connecting starts the heartbeat task, whose first heartbeat lands on
        // the inbound channel immediately, so exactly one message is received.
        // NOTE(review): update this if heartbeats stop looping back inbound.
        let received = client.receive().await.ok();
        assert_eq!(received.map(|batch| batch.len()), Some(1));

        client.disconnect().await.unwrap();
    }
}