1use crate::error::{CollabError, Result};
8use crate::events::ChangeEvent;
9use crate::sync::SyncMessage;
10use futures::{SinkExt, StreamExt};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::mpsc;
15use tokio::sync::RwLock;
16use tokio::time::sleep;
17use tokio_tungstenite::{connect_async, tungstenite::Message};
18use uuid::Uuid;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ClientConfig {
23 pub server_url: String,
25 pub auth_token: String,
27 pub max_reconnect_attempts: Option<u32>,
29 pub max_queue_size: usize,
31 pub initial_backoff_ms: u64,
33 pub max_backoff_ms: u64,
35}
36
37impl Default for ClientConfig {
38 fn default() -> Self {
39 Self {
40 server_url: String::new(),
41 auth_token: String::new(),
42 max_reconnect_attempts: None,
43 max_queue_size: 1000,
44 initial_backoff_ms: 1000,
45 max_backoff_ms: 30000,
46 }
47 }
48}
49
/// Lifecycle state of a [`CollabClient`] connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No live connection: explicitly disconnected, closed by the server,
    /// or reconnect attempts exhausted.
    Disconnected,
    /// Initial state while the client is being set up and first connecting.
    Connecting,
    /// The WebSocket handshake has completed.
    Connected,
    /// A connection was lost or failed; waiting out the backoff delay before retrying.
    Reconnecting,
}
62
63pub type WorkspaceUpdateCallback = Box<dyn Fn(ChangeEvent) + Send + Sync>;
65
66pub type StateChangeCallback = Box<dyn Fn(ConnectionState) + Send + Sync>;
68
69pub struct CollabClient {
71 config: ClientConfig,
73 _client_id: Uuid,
75 state: Arc<RwLock<ConnectionState>>,
77 message_queue: Arc<RwLock<Vec<SyncMessage>>>,
79 ws_sender: Arc<RwLock<Option<mpsc::UnboundedSender<SyncMessage>>>>,
81 connection_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
83 workspace_callbacks: Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
85 state_callbacks: Arc<RwLock<Vec<StateChangeCallback>>>,
87 reconnect_count: Arc<RwLock<u32>>,
89 stop_signal: Arc<RwLock<bool>>,
91}
92
93impl CollabClient {
94 pub async fn connect(config: ClientConfig) -> Result<Self> {
100 if config.server_url.is_empty() {
101 return Err(CollabError::InvalidInput("server_url cannot be empty".to_string()));
102 }
103
104 let client = Self {
105 config: config.clone(),
106 _client_id: Uuid::new_v4(),
107 state: Arc::new(RwLock::new(ConnectionState::Connecting)),
108 message_queue: Arc::new(RwLock::new(Vec::new())),
109 ws_sender: Arc::new(RwLock::new(None)),
110 connection_task: Arc::new(RwLock::new(None)),
111 workspace_callbacks: Arc::new(RwLock::new(Vec::new())),
112 state_callbacks: Arc::new(RwLock::new(Vec::new())),
113 reconnect_count: Arc::new(RwLock::new(0)),
114 stop_signal: Arc::new(RwLock::new(false)),
115 };
116
117 client.update_state(ConnectionState::Connecting).await;
119 client.start_connection_loop().await?;
120
121 Ok(client)
122 }
123
124 async fn start_connection_loop(&self) -> Result<()> {
126 let config = self.config.clone();
127 let state = self.state.clone();
128 let message_queue = self.message_queue.clone();
129 let ws_sender = self.ws_sender.clone();
130 let stop_signal = self.stop_signal.clone();
131 let reconnect_count = self.reconnect_count.clone();
132 let workspace_callbacks = self.workspace_callbacks.clone();
133 let state_callbacks = self.state_callbacks.clone();
134
135 let task = tokio::spawn(async move {
136 let mut backoff_ms = config.initial_backoff_ms;
137
138 loop {
139 if *stop_signal.read().await {
141 break;
142 }
143
144 match Self::try_connect(
146 &config,
147 &state,
148 &ws_sender,
149 &workspace_callbacks,
150 &state_callbacks,
151 &stop_signal,
152 )
153 .await
154 {
155 Ok(()) => {
156 backoff_ms = config.initial_backoff_ms;
158 *reconnect_count.write().await = 0;
159
160 let mut queue = message_queue.write().await;
162 while let Some(msg) = queue.pop() {
163 if let Some(ref sender) = *ws_sender.read().await {
164 let _ = sender.send(msg);
165 }
166 }
167
168 }
171 Err(e) => {
172 tracing::warn!("Connection failed: {}, will retry", e);
173
174 let current_count = *reconnect_count.read().await;
176 if let Some(max) = config.max_reconnect_attempts {
177 if current_count >= max {
178 tracing::error!("Max reconnect attempts ({}) reached", max);
179 *state.write().await = ConnectionState::Disconnected;
180 Self::notify_state_change(
181 &state_callbacks,
182 ConnectionState::Disconnected,
183 )
184 .await;
185 break;
186 }
187 }
188
189 *reconnect_count.write().await += 1;
190 *state.write().await = ConnectionState::Reconnecting;
191 Self::notify_state_change(&state_callbacks, ConnectionState::Reconnecting)
192 .await;
193
194 sleep(Duration::from_millis(backoff_ms)).await;
196 backoff_ms = (backoff_ms * 2).min(config.max_backoff_ms);
197 }
198 }
199 }
200 });
201
202 *self.connection_task.write().await = Some(task);
203 Ok(())
204 }
205
206 async fn try_connect(
208 config: &ClientConfig,
209 state: &Arc<RwLock<ConnectionState>>,
210 ws_sender: &Arc<RwLock<Option<mpsc::UnboundedSender<SyncMessage>>>>,
211 workspace_callbacks: &Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
212 state_callbacks: &Arc<RwLock<Vec<StateChangeCallback>>>,
213 stop_signal: &Arc<RwLock<bool>>,
214 ) -> Result<()> {
215 let url = format!("{}?token={}", config.server_url, config.auth_token);
217 tracing::info!("Connecting to WebSocket: {}", config.server_url);
218
219 let (ws_stream, _) = connect_async(&url)
221 .await
222 .map_err(|e| CollabError::Internal(format!("WebSocket connection failed: {e}")))?;
223
224 *state.write().await = ConnectionState::Connected;
225 Self::notify_state_change(state_callbacks, ConnectionState::Connected).await;
226
227 tracing::info!("WebSocket connected successfully");
228
229 let (write, mut read) = ws_stream.split();
231
232 let (tx, mut rx) = mpsc::unbounded_channel();
234 *ws_sender.write().await = Some(tx);
235
236 let mut write_handle = write;
238 let write_task = tokio::spawn(async move {
239 while let Some(msg) = rx.recv().await {
240 let json = match serde_json::to_string(&msg) {
241 Ok(json) => json,
242 Err(e) => {
243 tracing::error!("Failed to serialize message: {}", e);
244 continue;
245 }
246 };
247
248 if let Err(e) = write_handle.send(Message::Text(json)).await {
249 tracing::error!("Failed to send message: {}", e);
250 break;
251 }
252 }
253 });
254
255 loop {
257 if *stop_signal.read().await {
259 tracing::info!("Stop signal received, closing connection");
260 break;
261 }
262
263 tokio::select! {
264 msg_opt = read.next() => {
266 match msg_opt {
267 Some(Ok(Message::Text(text))) => {
268 Self::handle_server_message(&text, workspace_callbacks).await;
269 }
270 Some(Ok(Message::Close(_))) => {
271 tracing::info!("Server closed connection");
272 *state.write().await = ConnectionState::Disconnected;
273 Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
274 break;
275 }
276 Some(Ok(Message::Ping(_))) => {
277 tracing::debug!("Received ping");
279 }
280 Some(Ok(Message::Pong(_))) => {
281 tracing::debug!("Received pong");
282 }
283 Some(Err(e)) => {
284 tracing::error!("WebSocket error: {}", e);
285 *state.write().await = ConnectionState::Disconnected;
286 Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
287 return Err(CollabError::Internal(format!("WebSocket error: {e}")));
288 }
289 None => {
290 tracing::info!("WebSocket stream ended");
291 *state.write().await = ConnectionState::Disconnected;
292 Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
293 break;
294 }
295 _ => {}
296 }
297 }
298
299 () = sleep(Duration::from_millis(100)) => {
301 if *stop_signal.read().await {
302 tracing::info!("Stop signal received, closing connection");
303 break;
304 }
305 }
306 }
307 }
308
309 write_task.abort();
311 *ws_sender.write().await = None;
312
313 Err(CollabError::Internal("Connection closed".to_string()))
314 }
315
316 async fn handle_server_message(
318 text: &str,
319 workspace_callbacks: &Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
320 ) {
321 match serde_json::from_str::<SyncMessage>(text) {
322 Ok(SyncMessage::Change { event }) => {
323 let callbacks = workspace_callbacks.read().await;
325 for callback in callbacks.iter() {
326 callback(event.clone());
327 }
328 }
329 Ok(SyncMessage::StateResponse {
330 workspace_id,
331 version,
332 state: _,
333 }) => {
334 tracing::debug!(
335 "Received state response for workspace {} (version {})",
336 workspace_id,
337 version
338 );
339 }
341 Ok(SyncMessage::Error { message }) => {
342 tracing::error!("Server error: {}", message);
343 }
344 Ok(SyncMessage::Pong) => {
345 tracing::debug!("Received pong");
346 }
347 Ok(other) => {
348 tracing::debug!("Received message: {:?}", other);
349 }
350 Err(e) => {
351 tracing::warn!("Failed to parse server message: {} - {}", e, text);
352 }
353 }
354 }
355
356 async fn notify_state_change(
358 callbacks: &Arc<RwLock<Vec<StateChangeCallback>>>,
359 new_state: ConnectionState,
360 ) {
361 let callbacks = callbacks.read().await;
362 for callback in callbacks.iter() {
363 callback(new_state);
364 }
365 }
366
367 async fn update_state(&self, new_state: ConnectionState) {
369 *self.state.write().await = new_state;
370 let callbacks = self.state_callbacks.read().await;
371 for callback in callbacks.iter() {
372 callback(new_state);
373 }
374 }
375
376 async fn send_message(&self, message: SyncMessage) -> Result<()> {
378 let state = *self.state.read().await;
379
380 if state == ConnectionState::Connected {
381 if let Some(ref sender) = *self.ws_sender.read().await {
383 sender.send(message).map_err(|_| {
384 CollabError::Internal("Failed to send message (channel closed)".to_string())
385 })?;
386 return Ok(());
387 }
388 }
389
390 let mut queue = self.message_queue.write().await;
392 if queue.len() >= self.config.max_queue_size {
393 return Err(CollabError::InvalidInput(format!(
394 "Message queue full (max: {})",
395 self.config.max_queue_size
396 )));
397 }
398
399 queue.push(message);
400 drop(queue);
401 Ok(())
402 }
403
404 pub async fn on_workspace_update<F>(&self, callback: F)
409 where
410 F: Fn(ChangeEvent) + Send + Sync + 'static,
411 {
412 let mut callbacks = self.workspace_callbacks.write().await;
413 callbacks.push(Box::new(callback));
414 }
415
416 pub async fn on_state_change<F>(&self, callback: F)
421 where
422 F: Fn(ConnectionState) + Send + Sync + 'static,
423 {
424 let mut callbacks = self.state_callbacks.write().await;
425 callbacks.push(Box::new(callback));
426 }
427
428 pub async fn subscribe_to_workspace(&self, workspace_id: &str) -> Result<()> {
434 let workspace_id = Uuid::parse_str(workspace_id)
435 .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {e}")))?;
436
437 let message = SyncMessage::Subscribe { workspace_id };
438 self.send_message(message).await?;
439
440 Ok(())
441 }
442
443 pub async fn unsubscribe_from_workspace(&self, workspace_id: &str) -> Result<()> {
449 let workspace_id = Uuid::parse_str(workspace_id)
450 .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {e}")))?;
451
452 let message = SyncMessage::Unsubscribe { workspace_id };
453 self.send_message(message).await?;
454
455 Ok(())
456 }
457
458 pub async fn request_state(&self, workspace_id: &str, version: i64) -> Result<()> {
464 let workspace_id = Uuid::parse_str(workspace_id)
465 .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {e}")))?;
466
467 let message = SyncMessage::StateRequest {
468 workspace_id,
469 version,
470 };
471 self.send_message(message).await?;
472
473 Ok(())
474 }
475
476 pub async fn ping(&self) -> Result<()> {
482 let message = SyncMessage::Ping;
483 self.send_message(message).await?;
484 Ok(())
485 }
486
487 pub async fn state(&self) -> ConnectionState {
489 *self.state.read().await
490 }
491
492 pub async fn queued_message_count(&self) -> usize {
494 self.message_queue.read().await.len()
495 }
496
497 pub async fn reconnect_count(&self) -> u32 {
499 *self.reconnect_count.read().await
500 }
501
502 pub async fn disconnect(&self) -> Result<()> {
508 *self.stop_signal.write().await = true;
510
511 *self.state.write().await = ConnectionState::Disconnected;
513 Self::notify_state_change(&self.state_callbacks, ConnectionState::Disconnected).await;
514
515 let task = self.connection_task.write().await.take();
517 if let Some(task) = task {
518 task.abort();
519 }
520
521 Ok(())
522 }
523}
524
525impl Drop for CollabClient {
526 fn drop(&mut self) {
527 let stop_signal = self.stop_signal.clone();
529 let state = self.state.clone();
530 if let Ok(handle) = tokio::runtime::Handle::try_current() {
531 handle.spawn(async move {
532 *stop_signal.write().await = true;
533 *state.write().await = ConnectionState::Disconnected;
534 });
535 }
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Config aimed at a local port that is expected to refuse connections, so the
    /// client stays offline (and queues messages) for the duration of a test.
    fn offline_config() -> ClientConfig {
        ClientConfig {
            server_url: "ws://127.0.0.1:1".to_string(),
            auth_token: "token".to_string(),
            ..Default::default()
        }
    }

    /// Connects with `config`, panicking if client construction fails.
    async fn connect_offline(config: ClientConfig) -> CollabClient {
        match CollabClient::connect(config).await {
            Ok(client) => client,
            Err(_) => panic!("connect should succeed for a non-empty server_url"),
        }
    }

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();

        assert_eq!(config.server_url, String::new());
        assert_eq!(config.auth_token, "");
        assert_eq!(config.max_reconnect_attempts, None);
        assert_eq!(config.max_queue_size, 1000);
        assert_eq!(config.initial_backoff_ms, 1000);
        assert_eq!(config.max_backoff_ms, 30000);
    }

    #[test]
    fn test_client_config_clone() {
        let config = ClientConfig {
            server_url: "ws://localhost:8080".to_string(),
            auth_token: "token123".to_string(),
            max_reconnect_attempts: Some(5),
            max_queue_size: 500,
            initial_backoff_ms: 500,
            max_backoff_ms: 10000,
        };

        let cloned = config.clone();

        assert_eq!(config.server_url, cloned.server_url);
        assert_eq!(config.auth_token, cloned.auth_token);
        assert_eq!(config.max_reconnect_attempts, cloned.max_reconnect_attempts);
        assert_eq!(config.max_queue_size, cloned.max_queue_size);
    }

    #[test]
    fn test_client_config_serialization() {
        let config = ClientConfig {
            server_url: "ws://localhost:8080".to_string(),
            auth_token: "token123".to_string(),
            max_reconnect_attempts: Some(3),
            max_queue_size: 200,
            initial_backoff_ms: 1500,
            max_backoff_ms: 20000,
        };

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ClientConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.server_url, deserialized.server_url);
        assert_eq!(config.auth_token, deserialized.auth_token);
        assert_eq!(config.max_reconnect_attempts, deserialized.max_reconnect_attempts);
    }

    #[test]
    fn test_connection_state_equality() {
        assert_eq!(ConnectionState::Disconnected, ConnectionState::Disconnected);
        assert_eq!(ConnectionState::Connecting, ConnectionState::Connecting);
        assert_eq!(ConnectionState::Connected, ConnectionState::Connected);
        assert_eq!(ConnectionState::Reconnecting, ConnectionState::Reconnecting);

        assert_ne!(ConnectionState::Disconnected, ConnectionState::Connected);
        assert_ne!(ConnectionState::Connecting, ConnectionState::Reconnecting);
    }

    #[test]
    fn test_connection_state_copy() {
        let state = ConnectionState::Connected;
        let copied = state;

        assert_eq!(state, copied);
    }

    #[test]
    fn test_connection_state_debug() {
        let state = ConnectionState::Connected;
        let debug_str = format!("{state:?}");

        assert!(debug_str.contains("Connected"));
    }

    #[tokio::test]
    async fn test_connect_with_empty_url() {
        let config = ClientConfig {
            server_url: String::new(),
            auth_token: "token".to_string(),
            ..Default::default()
        };

        let result = CollabClient::connect(config).await;
        assert!(result.is_err());

        if let Err(e) = result {
            match e {
                CollabError::InvalidInput(msg) => {
                    assert!(msg.contains("server_url"));
                }
                _ => panic!("Expected InvalidInput error"),
            }
        }
    }

    #[tokio::test]
    async fn test_subscribe_to_workspace_invalid_id() {
        let client = connect_offline(offline_config()).await;

        match client.subscribe_to_workspace("invalid-uuid").await {
            Err(CollabError::InvalidInput(msg)) => {
                assert!(msg.contains("Invalid workspace ID"));
            }
            _ => panic!("Expected InvalidInput error"),
        }
        // Rejected input must never reach the offline queue.
        assert_eq!(client.queued_message_count().await, 0);

        assert!(client.disconnect().await.is_ok());
    }

    #[tokio::test]
    async fn test_subscribe_to_workspace_valid_id() {
        let client = connect_offline(offline_config()).await;
        let workspace_id = Uuid::new_v4().to_string();

        // The client is not connected, so the message is queued instead of sent.
        assert!(client.subscribe_to_workspace(&workspace_id).await.is_ok());
        assert_eq!(client.queued_message_count().await, 1);

        assert!(client.disconnect().await.is_ok());
    }

    #[tokio::test]
    async fn test_send_rejected_when_queue_full() {
        let client = connect_offline(ClientConfig {
            max_queue_size: 1,
            ..offline_config()
        })
        .await;

        assert!(client.ping().await.is_ok());
        assert!(matches!(
            client.ping().await,
            Err(CollabError::InvalidInput(_))
        ));
        assert_eq!(client.queued_message_count().await, 1);

        assert!(client.disconnect().await.is_ok());
    }

    #[tokio::test]
    async fn test_disconnect_sets_disconnected_state() {
        let client = connect_offline(offline_config()).await;

        assert!(client.disconnect().await.is_ok());
        assert_eq!(client.state().await, ConnectionState::Disconnected);
    }

    #[test]
    fn test_client_config_with_max_attempts() {
        let config = ClientConfig {
            max_reconnect_attempts: Some(10),
            ..Default::default()
        };

        assert_eq!(config.max_reconnect_attempts, Some(10));
    }

    #[test]
    fn test_client_config_unlimited_attempts() {
        let config = ClientConfig {
            max_reconnect_attempts: None,
            ..Default::default()
        };

        assert_eq!(config.max_reconnect_attempts, None);
    }

    #[test]
    fn test_client_config_queue_size() {
        let config = ClientConfig {
            max_queue_size: 5000,
            ..Default::default()
        };

        assert_eq!(config.max_queue_size, 5000);
    }

    #[test]
    fn test_client_config_backoff_values() {
        let config = ClientConfig {
            initial_backoff_ms: 2000,
            max_backoff_ms: 60000,
            ..Default::default()
        };

        assert_eq!(config.initial_backoff_ms, 2000);
        assert_eq!(config.max_backoff_ms, 60000);
    }

    #[test]
    fn test_sync_message_subscribe() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Subscribe { workspace_id };

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::Subscribe {
                workspace_id: ws_id,
            } => {
                assert_eq!(ws_id, workspace_id);
            }
            _ => panic!("Expected Subscribe message"),
        }
    }

    #[test]
    fn test_sync_message_unsubscribe() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Unsubscribe { workspace_id };

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::Unsubscribe {
                workspace_id: ws_id,
            } => {
                assert_eq!(ws_id, workspace_id);
            }
            _ => panic!("Expected Unsubscribe message"),
        }
    }

    #[test]
    fn test_sync_message_ping() {
        let msg = SyncMessage::Ping;
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::Ping => {}
            _ => panic!("Expected Ping message"),
        }
    }

    #[test]
    fn test_sync_message_pong() {
        let msg = SyncMessage::Pong;
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::Pong => {}
            _ => panic!("Expected Pong message"),
        }
    }

    #[test]
    fn test_sync_message_error() {
        let msg = SyncMessage::Error {
            message: "Test error".to_string(),
        };

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::Error { message } => {
                assert_eq!(message, "Test error");
            }
            _ => panic!("Expected Error message"),
        }
    }

    #[test]
    fn test_sync_message_state_request() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::StateRequest {
            workspace_id,
            version: 42,
        };

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();

        match deserialized {
            SyncMessage::StateRequest {
                workspace_id: ws_id,
                version,
            } => {
                assert_eq!(ws_id, workspace_id);
                assert_eq!(version, 42);
            }
            _ => panic!("Expected StateRequest message"),
        }
    }
}