1use crate::error::{CollabError, Result};
8use crate::events::ChangeEvent;
9use crate::sync::SyncMessage;
10use futures::{SinkExt, StreamExt};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::mpsc;
15use tokio::sync::RwLock;
16use tokio::time::sleep;
17use tokio_tungstenite::{connect_async, tungstenite::Message};
18use uuid::Uuid;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ClientConfig {
23 pub server_url: String,
25 pub auth_token: String,
27 pub max_reconnect_attempts: Option<u32>,
29 pub max_queue_size: usize,
31 pub initial_backoff_ms: u64,
33 pub max_backoff_ms: u64,
35}
36
37impl Default for ClientConfig {
38 fn default() -> Self {
39 Self {
40 server_url: String::new(),
41 auth_token: String::new(),
42 max_reconnect_attempts: None,
43 max_queue_size: 1000,
44 initial_backoff_ms: 1000,
45 max_backoff_ms: 30000,
46 }
47 }
48}
49
/// Lifecycle of the client's WebSocket connection, as reported to
/// state-change callbacks and by `CollabClient::state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Not connected: a session ended, the retry limit was hit, or the client was disconnected.
    Disconnected,
    /// The initial connection attempt is in progress.
    Connecting,
    /// A WebSocket session is established.
    Connected,
    /// Waiting to retry after a failed attempt or a dropped session.
    Reconnecting,
}
62
63pub type WorkspaceUpdateCallback = Box<dyn Fn(ChangeEvent) + Send + Sync>;
65
66pub type StateChangeCallback = Box<dyn Fn(ConnectionState) + Send + Sync>;
68
/// WebSocket client for real-time workspace collaboration, with automatic
/// reconnection and a bounded queue for messages sent while offline.
pub struct CollabClient {
    /// Settings supplied to `connect`.
    config: ClientConfig,
    /// Random identifier generated when the client is constructed.
    client_id: Uuid,
    /// Current connection state, shared with the background connection task.
    state: Arc<RwLock<ConnectionState>>,
    /// Messages held while not connected; bounded by `config.max_queue_size`.
    message_queue: Arc<RwLock<Vec<SyncMessage>>>,
    /// Channel into the active session's writer task; `None` when no session is up.
    ws_sender: Arc<RwLock<Option<mpsc::UnboundedSender<SyncMessage>>>>,
    /// Handle of the spawned connection/reconnection loop.
    connection_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// Handlers notified of incoming change events.
    workspace_callbacks: Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
    /// Handlers notified of connection-state transitions.
    state_callbacks: Arc<RwLock<Vec<StateChangeCallback>>>,
    /// Reconnect attempts counted by the background loop; reset to 0 after a successful session.
    reconnect_count: Arc<RwLock<u32>>,
    /// Set to `true` to ask the background loop to stop.
    stop_signal: Arc<RwLock<bool>>,
}
92
93impl CollabClient {
94 pub async fn connect(config: ClientConfig) -> Result<Self> {
96 if config.server_url.is_empty() {
97 return Err(CollabError::InvalidInput("server_url cannot be empty".to_string()));
98 }
99
100 let client = Self {
101 config: config.clone(),
102 client_id: Uuid::new_v4(),
103 state: Arc::new(RwLock::new(ConnectionState::Connecting)),
104 message_queue: Arc::new(RwLock::new(Vec::new())),
105 ws_sender: Arc::new(RwLock::new(None)),
106 connection_task: Arc::new(RwLock::new(None)),
107 workspace_callbacks: Arc::new(RwLock::new(Vec::new())),
108 state_callbacks: Arc::new(RwLock::new(Vec::new())),
109 reconnect_count: Arc::new(RwLock::new(0)),
110 stop_signal: Arc::new(RwLock::new(false)),
111 };
112
113 client.update_state(ConnectionState::Connecting).await;
115 client.start_connection_loop().await?;
116
117 Ok(client)
118 }
119
120 async fn start_connection_loop(&self) -> Result<()> {
122 let config = self.config.clone();
123 let state = self.state.clone();
124 let message_queue = self.message_queue.clone();
125 let ws_sender = self.ws_sender.clone();
126 let stop_signal = self.stop_signal.clone();
127 let reconnect_count = self.reconnect_count.clone();
128 let workspace_callbacks = self.workspace_callbacks.clone();
129 let state_callbacks = self.state_callbacks.clone();
130
131 let task = tokio::spawn(async move {
132 let mut backoff_ms = config.initial_backoff_ms;
133
134 loop {
135 if *stop_signal.read().await {
137 break;
138 }
139
140 match Self::try_connect(
142 &config,
143 &state,
144 &ws_sender,
145 &workspace_callbacks,
146 &state_callbacks,
147 &stop_signal,
148 )
149 .await
150 {
151 Ok(()) => {
152 backoff_ms = config.initial_backoff_ms;
154 *reconnect_count.write().await = 0;
155
156 let mut queue = message_queue.write().await;
158 while let Some(msg) = queue.pop() {
159 if let Some(ref sender) = *ws_sender.read().await {
160 let _ = sender.send(msg);
161 }
162 }
163
164 }
167 Err(e) => {
168 tracing::warn!("Connection failed: {}, will retry", e);
169
170 let current_count = *reconnect_count.read().await;
172 if let Some(max) = config.max_reconnect_attempts {
173 if current_count >= max {
174 tracing::error!("Max reconnect attempts ({}) reached", max);
175 *state.write().await = ConnectionState::Disconnected;
176 Self::notify_state_change(
177 &state_callbacks,
178 ConnectionState::Disconnected,
179 )
180 .await;
181 break;
182 }
183 }
184
185 *reconnect_count.write().await += 1;
186 *state.write().await = ConnectionState::Reconnecting;
187 Self::notify_state_change(&state_callbacks, ConnectionState::Reconnecting)
188 .await;
189
190 sleep(Duration::from_millis(backoff_ms)).await;
192 backoff_ms = (backoff_ms * 2).min(config.max_backoff_ms);
193 }
194 }
195 }
196 });
197
198 *self.connection_task.write().await = Some(task);
199 Ok(())
200 }
201
202 async fn try_connect(
204 config: &ClientConfig,
205 state: &Arc<RwLock<ConnectionState>>,
206 ws_sender: &Arc<RwLock<Option<mpsc::UnboundedSender<SyncMessage>>>>,
207 workspace_callbacks: &Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
208 state_callbacks: &Arc<RwLock<Vec<StateChangeCallback>>>,
209 stop_signal: &Arc<RwLock<bool>>,
210 ) -> Result<()> {
211 let url = format!("{}?token={}", config.server_url, config.auth_token);
213 tracing::info!("Connecting to WebSocket: {}", config.server_url);
214
215 let (ws_stream, _) = connect_async(&url)
217 .await
218 .map_err(|e| CollabError::Internal(format!("WebSocket connection failed: {}", e)))?;
219
220 *state.write().await = ConnectionState::Connected;
221 Self::notify_state_change(state_callbacks, ConnectionState::Connected).await;
222
223 tracing::info!("WebSocket connected successfully");
224
225 let (mut write, mut read) = ws_stream.split();
227
228 let (tx, mut rx) = mpsc::unbounded_channel();
230 *ws_sender.write().await = Some(tx);
231
232 let mut write_handle = write;
234 let write_task = tokio::spawn(async move {
235 while let Some(msg) = rx.recv().await {
236 let json = match serde_json::to_string(&msg) {
237 Ok(json) => json,
238 Err(e) => {
239 tracing::error!("Failed to serialize message: {}", e);
240 continue;
241 }
242 };
243
244 if let Err(e) = write_handle.send(Message::Text(json)).await {
245 tracing::error!("Failed to send message: {}", e);
246 break;
247 }
248 }
249 });
250
251 loop {
253 if *stop_signal.read().await {
255 tracing::info!("Stop signal received, closing connection");
256 break;
257 }
258
259 tokio::select! {
260 msg_opt = read.next() => {
262 match msg_opt {
263 Some(Ok(Message::Text(text))) => {
264 Self::handle_server_message(&text, workspace_callbacks).await;
265 }
266 Some(Ok(Message::Close(_))) => {
267 tracing::info!("Server closed connection");
268 *state.write().await = ConnectionState::Disconnected;
269 Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
270 break;
271 }
272 Some(Ok(Message::Ping(_))) => {
273 tracing::debug!("Received ping");
275 }
276 Some(Ok(Message::Pong(_))) => {
277 tracing::debug!("Received pong");
278 }
279 Some(Err(e)) => {
280 tracing::error!("WebSocket error: {}", e);
281 *state.write().await = ConnectionState::Disconnected;
282 Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
283 return Err(CollabError::Internal(format!("WebSocket error: {}", e)));
284 }
285 None => {
286 tracing::info!("WebSocket stream ended");
287 *state.write().await = ConnectionState::Disconnected;
288 Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
289 break;
290 }
291 _ => {}
292 }
293 }
294
295 _ = tokio::time::sleep(Duration::from_millis(100)) => {
297 if *stop_signal.read().await {
298 tracing::info!("Stop signal received, closing connection");
299 break;
300 }
301 }
302 }
303 }
304
305 write_task.abort();
307 *ws_sender.write().await = None;
308
309 Err(CollabError::Internal("Connection closed".to_string()))
310 }
311
312 async fn handle_server_message(
314 text: &str,
315 workspace_callbacks: &Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
316 ) {
317 match serde_json::from_str::<SyncMessage>(text) {
318 Ok(SyncMessage::Change { event }) => {
319 let callbacks = workspace_callbacks.read().await;
321 for callback in callbacks.iter() {
322 callback(event.clone());
323 }
324 }
325 Ok(SyncMessage::StateResponse {
326 workspace_id,
327 version,
328 state,
329 }) => {
330 tracing::debug!(
331 "Received state response for workspace {} (version {})",
332 workspace_id,
333 version
334 );
335 }
337 Ok(SyncMessage::Error { message }) => {
338 tracing::error!("Server error: {}", message);
339 }
340 Ok(SyncMessage::Pong) => {
341 tracing::debug!("Received pong");
342 }
343 Ok(other) => {
344 tracing::debug!("Received message: {:?}", other);
345 }
346 Err(e) => {
347 tracing::warn!("Failed to parse server message: {} - {}", e, text);
348 }
349 }
350 }
351
352 async fn notify_state_change(
354 callbacks: &Arc<RwLock<Vec<StateChangeCallback>>>,
355 new_state: ConnectionState,
356 ) {
357 let callbacks = callbacks.read().await;
358 for callback in callbacks.iter() {
359 callback(new_state);
360 }
361 }
362
363 async fn update_state(&self, new_state: ConnectionState) {
365 *self.state.write().await = new_state;
366 let callbacks = self.state_callbacks.read().await;
367 for callback in callbacks.iter() {
368 callback(new_state);
369 }
370 }
371
372 async fn send_message(&self, message: SyncMessage) -> Result<()> {
374 let state = *self.state.read().await;
375
376 if state == ConnectionState::Connected {
377 if let Some(ref sender) = *self.ws_sender.read().await {
379 sender.send(message).map_err(|_| {
380 CollabError::Internal("Failed to send message (channel closed)".to_string())
381 })?;
382 return Ok(());
383 }
384 }
385
386 let mut queue = self.message_queue.write().await;
388 if queue.len() >= self.config.max_queue_size {
389 return Err(CollabError::InvalidInput(format!(
390 "Message queue full (max: {})",
391 self.config.max_queue_size
392 )));
393 }
394
395 queue.push(message);
396 Ok(())
397 }
398
399 pub async fn on_workspace_update<F>(&self, callback: F)
404 where
405 F: Fn(ChangeEvent) + Send + Sync + 'static,
406 {
407 let mut callbacks = self.workspace_callbacks.write().await;
408 callbacks.push(Box::new(callback));
409 }
410
411 pub async fn on_state_change<F>(&self, callback: F)
416 where
417 F: Fn(ConnectionState) + Send + Sync + 'static,
418 {
419 let mut callbacks = self.state_callbacks.write().await;
420 callbacks.push(Box::new(callback));
421 }
422
423 pub async fn subscribe_to_workspace(&self, workspace_id: &str) -> Result<()> {
425 let workspace_id = Uuid::parse_str(workspace_id)
426 .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {}", e)))?;
427
428 let message = SyncMessage::Subscribe { workspace_id };
429 self.send_message(message).await?;
430
431 Ok(())
432 }
433
434 pub async fn unsubscribe_from_workspace(&self, workspace_id: &str) -> Result<()> {
436 let workspace_id = Uuid::parse_str(workspace_id)
437 .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {}", e)))?;
438
439 let message = SyncMessage::Unsubscribe { workspace_id };
440 self.send_message(message).await?;
441
442 Ok(())
443 }
444
445 pub async fn request_state(&self, workspace_id: &str, version: i64) -> Result<()> {
447 let workspace_id = Uuid::parse_str(workspace_id)
448 .map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {}", e)))?;
449
450 let message = SyncMessage::StateRequest {
451 workspace_id,
452 version,
453 };
454 self.send_message(message).await?;
455
456 Ok(())
457 }
458
459 pub async fn ping(&self) -> Result<()> {
461 let message = SyncMessage::Ping;
462 self.send_message(message).await?;
463 Ok(())
464 }
465
466 pub async fn state(&self) -> ConnectionState {
468 *self.state.read().await
469 }
470
471 pub async fn queued_message_count(&self) -> usize {
473 self.message_queue.read().await.len()
474 }
475
476 pub async fn reconnect_count(&self) -> u32 {
478 *self.reconnect_count.read().await
479 }
480
481 pub async fn disconnect(&self) -> Result<()> {
483 *self.stop_signal.write().await = true;
485
486 *self.state.write().await = ConnectionState::Disconnected;
488 Self::notify_state_change(&self.state_callbacks, ConnectionState::Disconnected).await;
489
490 if let Some(task) = self.connection_task.write().await.take() {
492 task.abort();
493 }
494
495 Ok(())
496 }
497}
498
499impl Drop for CollabClient {
500 fn drop(&mut self) {
501 let stop_signal = self.stop_signal.clone();
503 let state = self.state.clone();
504 tokio::runtime::Handle::try_current().map(|handle| {
505 handle.spawn(async move {
506 *stop_signal.write().await = true;
507 *state.write().await = ConnectionState::Disconnected;
508 });
509 });
510 }
511}