mockforge_collab/
client.rs

1//! Collaboration client for connecting to servers
2//!
3//! This module provides a client library for connecting to MockForge collaboration servers
4//! via WebSocket. It handles connection management, automatic reconnection, message queuing,
5//! and provides an event-driven API for workspace updates.
6
7use crate::error::{CollabError, Result};
8use crate::events::ChangeEvent;
9use crate::sync::SyncMessage;
10use futures::{SinkExt, StreamExt};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::mpsc;
15use tokio::sync::RwLock;
16use tokio::time::sleep;
17use tokio_tungstenite::{connect_async, tungstenite::Message};
18use uuid::Uuid;
19
/// Client configuration
///
/// Settings passed to [`CollabClient::connect`]: where to connect, which token
/// to present, and the limits applied to reconnection and offline queuing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientConfig {
    /// Server WebSocket URL (e.g., ws://localhost:8080/ws or wss://api.example.com/ws)
    ///
    /// Must not be empty; `CollabClient::connect` rejects an empty value.
    pub server_url: String,
    /// Authentication token (JWT)
    ///
    /// Sent to the server as the `token` query parameter of the WebSocket URL.
    pub auth_token: String,
    /// Maximum reconnect attempts (None for unlimited)
    ///
    /// Once this many retries have been recorded the client stops retrying and
    /// stays `Disconnected`.
    pub max_reconnect_attempts: Option<u32>,
    /// Maximum queue size for messages (when disconnected)
    ///
    /// Sending while the queue already holds this many messages returns an error.
    pub max_queue_size: usize,
    /// Initial backoff delay in milliseconds (exponential backoff starts here)
    pub initial_backoff_ms: u64,
    /// Maximum backoff delay in milliseconds
    ///
    /// The delay is doubled after each failed attempt and capped at this value.
    pub max_backoff_ms: u64,
}
36
37impl Default for ClientConfig {
38    fn default() -> Self {
39        Self {
40            server_url: String::new(),
41            auth_token: String::new(),
42            max_reconnect_attempts: None,
43            max_queue_size: 1000,
44            initial_backoff_ms: 1000,
45            max_backoff_ms: 30000,
46        }
47    }
48}
49
/// Connection state
///
/// Lifecycle of the client's WebSocket connection. The current value is
/// returned by `CollabClient::state` and every change is passed to the
/// callbacks registered with `CollabClient::on_state_change`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Not connected
    ///
    /// Set when the connection ends, when `disconnect` is called, or when the
    /// reconnect limit has been reached.
    Disconnected,
    /// Connecting
    ///
    /// Initial state of a freshly created client.
    Connecting,
    /// Connected and ready
    ///
    /// Messages are sent immediately instead of being queued.
    Connected,
    /// Reconnecting after error
    ///
    /// A retry is pending; the client is waiting out the backoff delay.
    Reconnecting,
}
62
/// Callback function type for workspace updates
///
/// Invoked with each `ChangeEvent` received from the server. It is called
/// synchronously from the connection task, so it should return quickly.
pub type WorkspaceUpdateCallback = Box<dyn Fn(ChangeEvent) + Send + Sync>;
65
/// Callback function type for connection state changes
///
/// Invoked with the new `ConnectionState` each time the client reports a
/// state transition. It is called synchronously, so it should return quickly.
pub type StateChangeCallback = Box<dyn Fn(ConnectionState) + Send + Sync>;
68
/// Collaboration client
///
/// Owns the shared state of one WebSocket connection to a collaboration
/// server. The mutable fields are wrapped in `Arc<RwLock<..>>` because they are
/// shared with the background connection task.
pub struct CollabClient {
    /// Configuration
    config: ClientConfig,
    /// Client ID
    ///
    /// Random v4 UUID generated when the client is created.
    client_id: Uuid,
    /// Connection state
    state: Arc<RwLock<ConnectionState>>,
    /// Message queue for when disconnected
    ///
    /// Bounded by `ClientConfig::max_queue_size`.
    message_queue: Arc<RwLock<Vec<SyncMessage>>>,
    /// WebSocket connection handle
    ///
    /// Sending half of the channel drained by the WebSocket writer task;
    /// `None` while no connection is established.
    ws_sender: Arc<RwLock<Option<mpsc::UnboundedSender<SyncMessage>>>>,
    /// Connection task handle for cleanup
    connection_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// Workspace update callbacks
    workspace_callbacks: Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
    /// State change callbacks
    state_callbacks: Arc<RwLock<Vec<StateChangeCallback>>>,
    /// Reconnect attempt count
    reconnect_count: Arc<RwLock<u32>>,
    /// Stop signal
    ///
    /// Set to `true` to ask the connection task to shut down.
    stop_signal: Arc<RwLock<bool>>,
}
92
93impl CollabClient {
94    /// Create a new client and connect to server
95    pub async fn connect(config: ClientConfig) -> Result<Self> {
96        if config.server_url.is_empty() {
97            return Err(CollabError::InvalidInput("server_url cannot be empty".to_string()));
98        }
99
100        let client = Self {
101            config: config.clone(),
102            client_id: Uuid::new_v4(),
103            state: Arc::new(RwLock::new(ConnectionState::Connecting)),
104            message_queue: Arc::new(RwLock::new(Vec::new())),
105            ws_sender: Arc::new(RwLock::new(None)),
106            connection_task: Arc::new(RwLock::new(None)),
107            workspace_callbacks: Arc::new(RwLock::new(Vec::new())),
108            state_callbacks: Arc::new(RwLock::new(Vec::new())),
109            reconnect_count: Arc::new(RwLock::new(0)),
110            stop_signal: Arc::new(RwLock::new(false)),
111        };
112
113        // Start connection process
114        client.update_state(ConnectionState::Connecting).await;
115        client.start_connection_loop().await?;
116
117        Ok(client)
118    }
119
120    /// Internal: Start the connection loop with reconnection logic
121    async fn start_connection_loop(&self) -> Result<()> {
122        let config = self.config.clone();
123        let state = self.state.clone();
124        let message_queue = self.message_queue.clone();
125        let ws_sender = self.ws_sender.clone();
126        let stop_signal = self.stop_signal.clone();
127        let reconnect_count = self.reconnect_count.clone();
128        let workspace_callbacks = self.workspace_callbacks.clone();
129        let state_callbacks = self.state_callbacks.clone();
130
131        let task = tokio::spawn(async move {
132            let mut backoff_ms = config.initial_backoff_ms;
133
134            loop {
135                // Check if we should stop
136                if *stop_signal.read().await {
137                    break;
138                }
139
140                // Attempt connection
141                match Self::try_connect(
142                    &config,
143                    &state,
144                    &ws_sender,
145                    &workspace_callbacks,
146                    &state_callbacks,
147                    &stop_signal,
148                )
149                .await
150                {
151                    Ok(()) => {
152                        // Connection successful, reset backoff
153                        backoff_ms = config.initial_backoff_ms;
154                        *reconnect_count.write().await = 0;
155
156                        // Flush message queue
157                        let mut queue = message_queue.write().await;
158                        while let Some(msg) = queue.pop() {
159                            if let Some(ref sender) = *ws_sender.read().await {
160                                let _ = sender.send(msg);
161                            }
162                        }
163
164                        // Wait for connection to close
165                        // (This will happen when try_connect returns on error/disconnect)
166                    }
167                    Err(e) => {
168                        tracing::warn!("Connection failed: {}, will retry", e);
169
170                        // Check max reconnect attempts
171                        let current_count = *reconnect_count.read().await;
172                        if let Some(max) = config.max_reconnect_attempts {
173                            if current_count >= max {
174                                tracing::error!("Max reconnect attempts ({}) reached", max);
175                                *state.write().await = ConnectionState::Disconnected;
176                                Self::notify_state_change(
177                                    &state_callbacks,
178                                    ConnectionState::Disconnected,
179                                )
180                                .await;
181                                break;
182                            }
183                        }
184
185                        *reconnect_count.write().await += 1;
186                        *state.write().await = ConnectionState::Reconnecting;
187                        Self::notify_state_change(&state_callbacks, ConnectionState::Reconnecting)
188                            .await;
189
190                        // Exponential backoff
191                        sleep(Duration::from_millis(backoff_ms)).await;
192                        backoff_ms = (backoff_ms * 2).min(config.max_backoff_ms);
193                    }
194                }
195            }
196        });
197
198        *self.connection_task.write().await = Some(task);
199        Ok(())
200    }
201
202    /// Internal: Attempt to establish WebSocket connection
203    async fn try_connect(
204        config: &ClientConfig,
205        state: &Arc<RwLock<ConnectionState>>,
206        ws_sender: &Arc<RwLock<Option<mpsc::UnboundedSender<SyncMessage>>>>,
207        workspace_callbacks: &Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
208        state_callbacks: &Arc<RwLock<Vec<StateChangeCallback>>>,
209        stop_signal: &Arc<RwLock<bool>>,
210    ) -> Result<()> {
211        // Build WebSocket URL with auth token
212        let url = format!("{}?token={}", config.server_url, config.auth_token);
213        tracing::info!("Connecting to WebSocket: {}", config.server_url);
214
215        // Connect to WebSocket
216        let (ws_stream, _) = connect_async(&url)
217            .await
218            .map_err(|e| CollabError::Internal(format!("WebSocket connection failed: {}", e)))?;
219
220        *state.write().await = ConnectionState::Connected;
221        Self::notify_state_change(state_callbacks, ConnectionState::Connected).await;
222
223        tracing::info!("WebSocket connected successfully");
224
225        // Split stream into sender and receiver
226        let (mut write, mut read) = ws_stream.split();
227
228        // Create message channel for sending messages
229        let (tx, mut rx) = mpsc::unbounded_channel();
230        *ws_sender.write().await = Some(tx);
231
232        // Spawn task to handle outgoing messages
233        let mut write_handle = write;
234        let write_task = tokio::spawn(async move {
235            while let Some(msg) = rx.recv().await {
236                let json = match serde_json::to_string(&msg) {
237                    Ok(json) => json,
238                    Err(e) => {
239                        tracing::error!("Failed to serialize message: {}", e);
240                        continue;
241                    }
242                };
243
244                if let Err(e) = write_handle.send(Message::Text(json)).await {
245                    tracing::error!("Failed to send message: {}", e);
246                    break;
247                }
248            }
249        });
250
251        // Handle incoming messages
252        loop {
253            // Check for stop signal first
254            if *stop_signal.read().await {
255                tracing::info!("Stop signal received, closing connection");
256                break;
257            }
258
259            tokio::select! {
260                // Receive message from server
261                msg_opt = read.next() => {
262                    match msg_opt {
263                        Some(Ok(Message::Text(text))) => {
264                            Self::handle_server_message(&text, workspace_callbacks).await;
265                        }
266                        Some(Ok(Message::Close(_))) => {
267                            tracing::info!("Server closed connection");
268                            *state.write().await = ConnectionState::Disconnected;
269                            Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
270                            break;
271                        }
272                        Some(Ok(Message::Ping(_))) => {
273                            // Tungstenite handles pings automatically
274                            tracing::debug!("Received ping");
275                        }
276                        Some(Ok(Message::Pong(_))) => {
277                            tracing::debug!("Received pong");
278                        }
279                        Some(Err(e)) => {
280                            tracing::error!("WebSocket error: {}", e);
281                            *state.write().await = ConnectionState::Disconnected;
282                            Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
283                            return Err(CollabError::Internal(format!("WebSocket error: {}", e)));
284                        }
285                        None => {
286                            tracing::info!("WebSocket stream ended");
287                            *state.write().await = ConnectionState::Disconnected;
288                            Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
289                            break;
290                        }
291                        _ => {}
292                    }
293                }
294
295                // Periodic stop signal check
296                _ = tokio::time::sleep(Duration::from_millis(100)) => {
297                    if *stop_signal.read().await {
298                        tracing::info!("Stop signal received, closing connection");
299                        break;
300                    }
301                }
302            }
303        }
304
305        // Clean up
306        write_task.abort();
307        *ws_sender.write().await = None;
308
309        Err(CollabError::Internal("Connection closed".to_string()))
310    }
311
312    /// Internal: Handle message from server
313    async fn handle_server_message(
314        text: &str,
315        workspace_callbacks: &Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
316    ) {
317        match serde_json::from_str::<SyncMessage>(text) {
318            Ok(SyncMessage::Change { event }) => {
319                // Notify all workspace callbacks
320                let callbacks = workspace_callbacks.read().await;
321                for callback in callbacks.iter() {
322                    callback(event.clone());
323                }
324            }
325            Ok(SyncMessage::StateResponse {
326                workspace_id,
327                version,
328                state,
329            }) => {
330                tracing::debug!(
331                    "Received state response for workspace {} (version {})",
332                    workspace_id,
333                    version
334                );
335                // Could emit this as a separate event type if needed
336            }
337            Ok(SyncMessage::Error { message }) => {
338                tracing::error!("Server error: {}", message);
339            }
340            Ok(SyncMessage::Pong) => {
341                tracing::debug!("Received pong");
342            }
343            Ok(other) => {
344                tracing::debug!("Received message: {:?}", other);
345            }
346            Err(e) => {
347                tracing::warn!("Failed to parse server message: {} - {}", e, text);
348            }
349        }
350    }
351
352    /// Internal: Notify state change callbacks
353    async fn notify_state_change(
354        callbacks: &Arc<RwLock<Vec<StateChangeCallback>>>,
355        new_state: ConnectionState,
356    ) {
357        let callbacks = callbacks.read().await;
358        for callback in callbacks.iter() {
359            callback(new_state);
360        }
361    }
362
363    /// Internal: Update connection state and notify callbacks
364    async fn update_state(&self, new_state: ConnectionState) {
365        *self.state.write().await = new_state;
366        let callbacks = self.state_callbacks.read().await;
367        for callback in callbacks.iter() {
368            callback(new_state);
369        }
370    }
371
372    /// Internal: Send message (queue if disconnected)
373    async fn send_message(&self, message: SyncMessage) -> Result<()> {
374        let state = *self.state.read().await;
375
376        if state == ConnectionState::Connected {
377            // Try to send immediately
378            if let Some(ref sender) = *self.ws_sender.read().await {
379                sender.send(message).map_err(|_| {
380                    CollabError::Internal("Failed to send message (channel closed)".to_string())
381                })?;
382                return Ok(());
383            }
384        }
385
386        // Queue message if disconnected or sender unavailable
387        let mut queue = self.message_queue.write().await;
388        if queue.len() >= self.config.max_queue_size {
389            return Err(CollabError::InvalidInput(format!(
390                "Message queue full (max: {})",
391                self.config.max_queue_size
392            )));
393        }
394
395        queue.push(message);
396        Ok(())
397    }
398
399    /// Subscribe to workspace updates
400    ///
401    /// # Arguments
402    /// * `callback` - Function to call when workspace changes occur
403    pub async fn on_workspace_update<F>(&self, callback: F)
404    where
405        F: Fn(ChangeEvent) + Send + Sync + 'static,
406    {
407        let mut callbacks = self.workspace_callbacks.write().await;
408        callbacks.push(Box::new(callback));
409    }
410
411    /// Subscribe to connection state changes
412    ///
413    /// # Arguments
414    /// * `callback` - Function to call when connection state changes
415    pub async fn on_state_change<F>(&self, callback: F)
416    where
417        F: Fn(ConnectionState) + Send + Sync + 'static,
418    {
419        let mut callbacks = self.state_callbacks.write().await;
420        callbacks.push(Box::new(callback));
421    }
422
423    /// Subscribe to a workspace
424    pub async fn subscribe_to_workspace(&self, workspace_id: &str) -> Result<()> {
425        let workspace_id = Uuid::parse_str(workspace_id)
426            .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {}", e)))?;
427
428        let message = SyncMessage::Subscribe { workspace_id };
429        self.send_message(message).await?;
430
431        Ok(())
432    }
433
434    /// Unsubscribe from a workspace
435    pub async fn unsubscribe_from_workspace(&self, workspace_id: &str) -> Result<()> {
436        let workspace_id = Uuid::parse_str(workspace_id)
437            .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {}", e)))?;
438
439        let message = SyncMessage::Unsubscribe { workspace_id };
440        self.send_message(message).await?;
441
442        Ok(())
443    }
444
445    /// Request state for a workspace
446    pub async fn request_state(&self, workspace_id: &str, version: i64) -> Result<()> {
447        let workspace_id = Uuid::parse_str(workspace_id)
448            .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {}", e)))?;
449
450        let message = SyncMessage::StateRequest {
451            workspace_id,
452            version,
453        };
454        self.send_message(message).await?;
455
456        Ok(())
457    }
458
459    /// Send ping (heartbeat)
460    pub async fn ping(&self) -> Result<()> {
461        let message = SyncMessage::Ping;
462        self.send_message(message).await?;
463        Ok(())
464    }
465
466    /// Get connection state
467    pub async fn state(&self) -> ConnectionState {
468        *self.state.read().await
469    }
470
471    /// Get queued message count
472    pub async fn queued_message_count(&self) -> usize {
473        self.message_queue.read().await.len()
474    }
475
476    /// Get reconnect attempt count
477    pub async fn reconnect_count(&self) -> u32 {
478        *self.reconnect_count.read().await
479    }
480
481    /// Disconnect from server
482    pub async fn disconnect(&self) -> Result<()> {
483        // Signal stop
484        *self.stop_signal.write().await = true;
485
486        // Update state
487        *self.state.write().await = ConnectionState::Disconnected;
488        Self::notify_state_change(&self.state_callbacks, ConnectionState::Disconnected).await;
489
490        // Wait for connection task to finish
491        if let Some(task) = self.connection_task.write().await.take() {
492            task.abort();
493        }
494
495        Ok(())
496    }
497}
498
499impl Drop for CollabClient {
500    fn drop(&mut self) {
501        // Ensure we disconnect when dropped
502        let stop_signal = self.stop_signal.clone();
503        let state = self.state.clone();
504        tokio::runtime::Handle::try_current().map(|handle| {
505            handle.spawn(async move {
506                *stop_signal.write().await = true;
507                *state.write().await = ConnectionState::Disconnected;
508            });
509        });
510    }
511}