1use crate::error::{CollabError, Result};
8use crate::events::ChangeEvent;
9use crate::sync::SyncMessage;
10use futures::{SinkExt, StreamExt};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::mpsc;
15use tokio::sync::RwLock;
16use tokio::time::sleep;
17use tokio_tungstenite::{connect_async, tungstenite::Message};
18use uuid::Uuid;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ClientConfig {
23 pub server_url: String,
25 pub auth_token: String,
27 pub max_reconnect_attempts: Option<u32>,
29 pub max_queue_size: usize,
31 pub initial_backoff_ms: u64,
33 pub max_backoff_ms: u64,
35}
36
37impl Default for ClientConfig {
38 fn default() -> Self {
39 Self {
40 server_url: String::new(),
41 auth_token: String::new(),
42 max_reconnect_attempts: None,
43 max_queue_size: 1000,
44 initial_backoff_ms: 1000,
45 max_backoff_ms: 30000,
46 }
47 }
48}
49
/// Connection lifecycle of a [`CollabClient`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No live connection to the server.
    Disconnected,
    /// First connection attempt in progress; the state a client starts in.
    Connecting,
    /// A WebSocket session with the server is established.
    Connected,
    /// Waiting to retry after a failed attempt or a dropped session.
    Reconnecting,
}
62
63pub type WorkspaceUpdateCallback = Box<dyn Fn(ChangeEvent) + Send + Sync>;
65
66pub type StateChangeCallback = Box<dyn Fn(ConnectionState) + Send + Sync>;
68
69pub struct CollabClient {
71 config: ClientConfig,
73 client_id: Uuid,
75 state: Arc<RwLock<ConnectionState>>,
77 message_queue: Arc<RwLock<Vec<SyncMessage>>>,
79 ws_sender: Arc<RwLock<Option<mpsc::UnboundedSender<SyncMessage>>>>,
81 connection_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
83 workspace_callbacks: Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
85 state_callbacks: Arc<RwLock<Vec<StateChangeCallback>>>,
87 reconnect_count: Arc<RwLock<u32>>,
89 stop_signal: Arc<RwLock<bool>>,
91}
92
93impl CollabClient {
94 pub async fn connect(config: ClientConfig) -> Result<Self> {
96 if config.server_url.is_empty() {
97 return Err(CollabError::InvalidInput("server_url cannot be empty".to_string()));
98 }
99
100 let client = Self {
101 config: config.clone(),
102 client_id: Uuid::new_v4(),
103 state: Arc::new(RwLock::new(ConnectionState::Connecting)),
104 message_queue: Arc::new(RwLock::new(Vec::new())),
105 ws_sender: Arc::new(RwLock::new(None)),
106 connection_task: Arc::new(RwLock::new(None)),
107 workspace_callbacks: Arc::new(RwLock::new(Vec::new())),
108 state_callbacks: Arc::new(RwLock::new(Vec::new())),
109 reconnect_count: Arc::new(RwLock::new(0)),
110 stop_signal: Arc::new(RwLock::new(false)),
111 };
112
113 client.update_state(ConnectionState::Connecting).await;
115 client.start_connection_loop().await?;
116
117 Ok(client)
118 }
119
120 async fn start_connection_loop(&self) -> Result<()> {
122 let config = self.config.clone();
123 let state = self.state.clone();
124 let message_queue = self.message_queue.clone();
125 let ws_sender = self.ws_sender.clone();
126 let stop_signal = self.stop_signal.clone();
127 let reconnect_count = self.reconnect_count.clone();
128 let workspace_callbacks = self.workspace_callbacks.clone();
129 let state_callbacks = self.state_callbacks.clone();
130
131 let task = tokio::spawn(async move {
132 let mut backoff_ms = config.initial_backoff_ms;
133
134 loop {
135 if *stop_signal.read().await {
137 break;
138 }
139
140 match Self::try_connect(
142 &config,
143 &state,
144 &ws_sender,
145 &workspace_callbacks,
146 &state_callbacks,
147 &stop_signal,
148 )
149 .await
150 {
151 Ok(()) => {
152 backoff_ms = config.initial_backoff_ms;
154 *reconnect_count.write().await = 0;
155
156 let mut queue = message_queue.write().await;
158 while let Some(msg) = queue.pop() {
159 if let Some(ref sender) = *ws_sender.read().await {
160 let _ = sender.send(msg);
161 }
162 }
163
164 }
167 Err(e) => {
168 tracing::warn!("Connection failed: {}, will retry", e);
169
170 let current_count = *reconnect_count.read().await;
172 if let Some(max) = config.max_reconnect_attempts {
173 if current_count >= max {
174 tracing::error!("Max reconnect attempts ({}) reached", max);
175 *state.write().await = ConnectionState::Disconnected;
176 Self::notify_state_change(
177 &state_callbacks,
178 ConnectionState::Disconnected,
179 )
180 .await;
181 break;
182 }
183 }
184
185 *reconnect_count.write().await += 1;
186 *state.write().await = ConnectionState::Reconnecting;
187 Self::notify_state_change(&state_callbacks, ConnectionState::Reconnecting)
188 .await;
189
190 sleep(Duration::from_millis(backoff_ms)).await;
192 backoff_ms = (backoff_ms * 2).min(config.max_backoff_ms);
193 }
194 }
195 }
196 });
197
198 *self.connection_task.write().await = Some(task);
199 Ok(())
200 }
201
202 async fn try_connect(
204 config: &ClientConfig,
205 state: &Arc<RwLock<ConnectionState>>,
206 ws_sender: &Arc<RwLock<Option<mpsc::UnboundedSender<SyncMessage>>>>,
207 workspace_callbacks: &Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
208 state_callbacks: &Arc<RwLock<Vec<StateChangeCallback>>>,
209 stop_signal: &Arc<RwLock<bool>>,
210 ) -> Result<()> {
211 let url = format!("{}?token={}", config.server_url, config.auth_token);
213 tracing::info!("Connecting to WebSocket: {}", config.server_url);
214
215 let (ws_stream, _) = connect_async(&url)
217 .await
218 .map_err(|e| CollabError::Internal(format!("WebSocket connection failed: {e}")))?;
219
220 *state.write().await = ConnectionState::Connected;
221 Self::notify_state_change(state_callbacks, ConnectionState::Connected).await;
222
223 tracing::info!("WebSocket connected successfully");
224
225 let (write, mut read) = ws_stream.split();
227
228 let (tx, mut rx) = mpsc::unbounded_channel();
230 *ws_sender.write().await = Some(tx);
231
232 let mut write_handle = write;
234 let write_task = tokio::spawn(async move {
235 while let Some(msg) = rx.recv().await {
236 let json = match serde_json::to_string(&msg) {
237 Ok(json) => json,
238 Err(e) => {
239 tracing::error!("Failed to serialize message: {}", e);
240 continue;
241 }
242 };
243
244 if let Err(e) = write_handle.send(Message::Text(json)).await {
245 tracing::error!("Failed to send message: {}", e);
246 break;
247 }
248 }
249 });
250
251 loop {
253 if *stop_signal.read().await {
255 tracing::info!("Stop signal received, closing connection");
256 break;
257 }
258
259 tokio::select! {
260 msg_opt = read.next() => {
262 match msg_opt {
263 Some(Ok(Message::Text(text))) => {
264 Self::handle_server_message(&text, workspace_callbacks).await;
265 }
266 Some(Ok(Message::Close(_))) => {
267 tracing::info!("Server closed connection");
268 *state.write().await = ConnectionState::Disconnected;
269 Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
270 break;
271 }
272 Some(Ok(Message::Ping(_))) => {
273 tracing::debug!("Received ping");
275 }
276 Some(Ok(Message::Pong(_))) => {
277 tracing::debug!("Received pong");
278 }
279 Some(Err(e)) => {
280 tracing::error!("WebSocket error: {}", e);
281 *state.write().await = ConnectionState::Disconnected;
282 Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
283 return Err(CollabError::Internal(format!("WebSocket error: {e}")));
284 }
285 None => {
286 tracing::info!("WebSocket stream ended");
287 *state.write().await = ConnectionState::Disconnected;
288 Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
289 break;
290 }
291 _ => {}
292 }
293 }
294
295 () = sleep(Duration::from_millis(100)) => {
297 if *stop_signal.read().await {
298 tracing::info!("Stop signal received, closing connection");
299 break;
300 }
301 }
302 }
303 }
304
305 write_task.abort();
307 *ws_sender.write().await = None;
308
309 Err(CollabError::Internal("Connection closed".to_string()))
310 }
311
312 async fn handle_server_message(
314 text: &str,
315 workspace_callbacks: &Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
316 ) {
317 match serde_json::from_str::<SyncMessage>(text) {
318 Ok(SyncMessage::Change { event }) => {
319 let callbacks = workspace_callbacks.read().await;
321 for callback in callbacks.iter() {
322 callback(event.clone());
323 }
324 }
325 Ok(SyncMessage::StateResponse {
326 workspace_id,
327 version,
328 state,
329 }) => {
330 tracing::debug!(
331 "Received state response for workspace {} (version {})",
332 workspace_id,
333 version
334 );
335 }
337 Ok(SyncMessage::Error { message }) => {
338 tracing::error!("Server error: {}", message);
339 }
340 Ok(SyncMessage::Pong) => {
341 tracing::debug!("Received pong");
342 }
343 Ok(other) => {
344 tracing::debug!("Received message: {:?}", other);
345 }
346 Err(e) => {
347 tracing::warn!("Failed to parse server message: {} - {}", e, text);
348 }
349 }
350 }
351
352 async fn notify_state_change(
354 callbacks: &Arc<RwLock<Vec<StateChangeCallback>>>,
355 new_state: ConnectionState,
356 ) {
357 let callbacks = callbacks.read().await;
358 for callback in callbacks.iter() {
359 callback(new_state);
360 }
361 }
362
363 async fn update_state(&self, new_state: ConnectionState) {
365 *self.state.write().await = new_state;
366 let callbacks = self.state_callbacks.read().await;
367 for callback in callbacks.iter() {
368 callback(new_state);
369 }
370 }
371
372 async fn send_message(&self, message: SyncMessage) -> Result<()> {
374 let state = *self.state.read().await;
375
376 if state == ConnectionState::Connected {
377 if let Some(ref sender) = *self.ws_sender.read().await {
379 sender.send(message).map_err(|_| {
380 CollabError::Internal("Failed to send message (channel closed)".to_string())
381 })?;
382 return Ok(());
383 }
384 }
385
386 let mut queue = self.message_queue.write().await;
388 if queue.len() >= self.config.max_queue_size {
389 return Err(CollabError::InvalidInput(format!(
390 "Message queue full (max: {})",
391 self.config.max_queue_size
392 )));
393 }
394
395 queue.push(message);
396 Ok(())
397 }
398
399 pub async fn on_workspace_update<F>(&self, callback: F)
404 where
405 F: Fn(ChangeEvent) + Send + Sync + 'static,
406 {
407 let mut callbacks = self.workspace_callbacks.write().await;
408 callbacks.push(Box::new(callback));
409 }
410
411 pub async fn on_state_change<F>(&self, callback: F)
416 where
417 F: Fn(ConnectionState) + Send + Sync + 'static,
418 {
419 let mut callbacks = self.state_callbacks.write().await;
420 callbacks.push(Box::new(callback));
421 }
422
423 pub async fn subscribe_to_workspace(&self, workspace_id: &str) -> Result<()> {
425 let workspace_id = Uuid::parse_str(workspace_id)
426 .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {e}")))?;
427
428 let message = SyncMessage::Subscribe { workspace_id };
429 self.send_message(message).await?;
430
431 Ok(())
432 }
433
434 pub async fn unsubscribe_from_workspace(&self, workspace_id: &str) -> Result<()> {
436 let workspace_id = Uuid::parse_str(workspace_id)
437 .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {e}")))?;
438
439 let message = SyncMessage::Unsubscribe { workspace_id };
440 self.send_message(message).await?;
441
442 Ok(())
443 }
444
445 pub async fn request_state(&self, workspace_id: &str, version: i64) -> Result<()> {
447 let workspace_id = Uuid::parse_str(workspace_id)
448 .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {e}")))?;
449
450 let message = SyncMessage::StateRequest {
451 workspace_id,
452 version,
453 };
454 self.send_message(message).await?;
455
456 Ok(())
457 }
458
459 pub async fn ping(&self) -> Result<()> {
461 let message = SyncMessage::Ping;
462 self.send_message(message).await?;
463 Ok(())
464 }
465
466 pub async fn state(&self) -> ConnectionState {
468 *self.state.read().await
469 }
470
471 pub async fn queued_message_count(&self) -> usize {
473 self.message_queue.read().await.len()
474 }
475
476 pub async fn reconnect_count(&self) -> u32 {
478 *self.reconnect_count.read().await
479 }
480
481 pub async fn disconnect(&self) -> Result<()> {
483 *self.stop_signal.write().await = true;
485
486 *self.state.write().await = ConnectionState::Disconnected;
488 Self::notify_state_change(&self.state_callbacks, ConnectionState::Disconnected).await;
489
490 if let Some(task) = self.connection_task.write().await.take() {
492 task.abort();
493 }
494
495 Ok(())
496 }
497}
498
499impl Drop for CollabClient {
500 fn drop(&mut self) {
501 let stop_signal = self.stop_signal.clone();
503 let state = self.state.clone();
504 if let Ok(handle) = tokio::runtime::Handle::try_current() {
505 handle.spawn(async move {
506 *stop_signal.write().await = true;
507 *state.write().await = ConnectionState::Disconnected;
508 });
509 }
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a client that stays offline. It points at local port 1, which is
    /// assumed to have no WebSocket server, and disables reconnects. Its outgoing
    /// messages are therefore queued rather than sent.
    async fn offline_client() -> CollabClient {
        let config = ClientConfig {
            server_url: "ws://127.0.0.1:1".to_string(),
            auth_token: "token".to_string(),
            max_reconnect_attempts: Some(0),
            ..Default::default()
        };

        match CollabClient::connect(config).await {
            Ok(client) => client,
            Err(_) => panic!("connect rejected a non-empty server_url"),
        }
    }

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();

        assert_eq!(config.server_url, "");
        assert_eq!(config.auth_token, "");
        assert_eq!(config.max_reconnect_attempts, None);
        assert_eq!(config.max_queue_size, 1000);
        assert_eq!(config.initial_backoff_ms, 1000);
        assert_eq!(config.max_backoff_ms, 30000);
    }

    #[test]
    fn test_client_config_clone() {
        let config = ClientConfig {
            server_url: "ws://localhost:8080".to_string(),
            auth_token: "token123".to_string(),
            max_reconnect_attempts: Some(5),
            max_queue_size: 500,
            initial_backoff_ms: 500,
            max_backoff_ms: 10000,
        };

        let cloned = config.clone();

        assert_eq!(config.server_url, cloned.server_url);
        assert_eq!(config.auth_token, cloned.auth_token);
        assert_eq!(config.max_reconnect_attempts, cloned.max_reconnect_attempts);
        assert_eq!(config.max_queue_size, cloned.max_queue_size);
    }

    #[test]
    fn test_client_config_serialization() {
        let config = ClientConfig {
            server_url: "ws://localhost:8080".to_string(),
            auth_token: "token123".to_string(),
            max_reconnect_attempts: Some(3),
            max_queue_size: 200,
            initial_backoff_ms: 1500,
            max_backoff_ms: 20000,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ClientConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.server_url, deserialized.server_url);
        assert_eq!(config.auth_token, deserialized.auth_token);
        assert_eq!(config.max_reconnect_attempts, deserialized.max_reconnect_attempts);
    }

    #[test]
    fn test_connection_state_equality() {
        assert_eq!(ConnectionState::Disconnected, ConnectionState::Disconnected);
        assert_eq!(ConnectionState::Connecting, ConnectionState::Connecting);
        assert_eq!(ConnectionState::Connected, ConnectionState::Connected);
        assert_eq!(ConnectionState::Reconnecting, ConnectionState::Reconnecting);

        assert_ne!(ConnectionState::Disconnected, ConnectionState::Connected);
        assert_ne!(ConnectionState::Connecting, ConnectionState::Reconnecting);
    }

    #[test]
    fn test_connection_state_copy() {
        let state = ConnectionState::Connected;
        let copied = state;

        assert_eq!(state, copied);
    }

    #[test]
    fn test_connection_state_debug() {
        let state = ConnectionState::Connected;
        let debug_str = format!("{:?}", state);

        assert!(debug_str.contains("Connected"));
    }

    #[tokio::test]
    async fn test_connect_with_empty_url() {
        let config = ClientConfig {
            server_url: "".to_string(),
            auth_token: "token".to_string(),
            ..Default::default()
        };

        match CollabClient::connect(config).await {
            Err(CollabError::InvalidInput(msg)) => assert!(msg.contains("server_url")),
            Err(_) => panic!("Expected InvalidInput error"),
            Ok(_) => panic!("empty server_url was accepted"),
        }
    }

    #[tokio::test]
    async fn test_subscribe_to_workspace_invalid_id() {
        let client = offline_client().await;

        match client.subscribe_to_workspace("invalid-uuid").await {
            Err(CollabError::InvalidInput(msg)) => {
                assert!(msg.contains("Invalid workspace ID"));
            }
            Err(_) => panic!("Expected InvalidInput error"),
            Ok(()) => panic!("invalid workspace ID was accepted"),
        }
        // Rejected before anything is sent or queued.
        assert_eq!(client.queued_message_count().await, 0);

        assert!(client.disconnect().await.is_ok());
    }

    #[tokio::test]
    async fn test_subscribe_to_workspace_valid_id() {
        let client = offline_client().await;
        let workspace_id = Uuid::new_v4().to_string();

        assert!(client.subscribe_to_workspace(&workspace_id).await.is_ok());
        // Not connected, so the Subscribe message waits in the offline queue.
        assert_eq!(client.queued_message_count().await, 1);

        assert!(client.disconnect().await.is_ok());
    }

    #[test]
    fn test_client_config_with_max_attempts() {
        let config = ClientConfig {
            max_reconnect_attempts: Some(10),
            ..Default::default()
        };

        assert_eq!(config.max_reconnect_attempts, Some(10));
    }

    #[test]
    fn test_client_config_unlimited_attempts() {
        let config = ClientConfig {
            max_reconnect_attempts: None,
            ..Default::default()
        };

        assert_eq!(config.max_reconnect_attempts, None);
    }

    #[test]
    fn test_client_config_queue_size() {
        let config = ClientConfig {
            max_queue_size: 5000,
            ..Default::default()
        };

        assert_eq!(config.max_queue_size, 5000);
    }

    #[test]
    fn test_client_config_backoff_values() {
        let config = ClientConfig {
            initial_backoff_ms: 2000,
            max_backoff_ms: 60000,
            ..Default::default()
        };

        assert_eq!(config.initial_backoff_ms, 2000);
        assert_eq!(config.max_backoff_ms, 60000);
    }

    #[test]
    fn test_sync_message_subscribe() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Subscribe { workspace_id };

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::Subscribe {
                workspace_id: ws_id,
            } => {
                assert_eq!(ws_id, workspace_id);
            }
            _ => panic!("Expected Subscribe message"),
        }
    }

    #[test]
    fn test_sync_message_unsubscribe() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Unsubscribe { workspace_id };

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::Unsubscribe {
                workspace_id: ws_id,
            } => {
                assert_eq!(ws_id, workspace_id);
            }
            _ => panic!("Expected Unsubscribe message"),
        }
    }

    #[test]
    fn test_sync_message_ping() {
        let msg = SyncMessage::Ping;
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::Ping => {}
            _ => panic!("Expected Ping message"),
        }
    }

    #[test]
    fn test_sync_message_pong() {
        let msg = SyncMessage::Pong;
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::Pong => {}
            _ => panic!("Expected Pong message"),
        }
    }

    #[test]
    fn test_sync_message_error() {
        let msg = SyncMessage::Error {
            message: "Test error".to_string(),
        };

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::Error { message } => {
                assert_eq!(message, "Test error");
            }
            _ => panic!("Expected Error message"),
        }
    }

    #[test]
    fn test_sync_message_state_request() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::StateRequest {
            workspace_id,
            version: 42,
        };

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::StateRequest {
                workspace_id: ws_id,
                version,
            } => {
                assert_eq!(ws_id, workspace_id);
                assert_eq!(version, 42);
            }
            _ => panic!("Expected StateRequest message"),
        }
    }
}