use crate::error::{CollabError, Result};
use crate::events::ChangeEvent;
use crate::sync::SyncMessage;
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::RwLock;
use tokio::time::sleep;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use uuid::Uuid;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientConfig {
pub server_url: String,
pub auth_token: String,
pub max_reconnect_attempts: Option<u32>,
pub max_queue_size: usize,
pub initial_backoff_ms: u64,
pub max_backoff_ms: u64,
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
server_url: String::new(),
auth_token: String::new(),
max_reconnect_attempts: None,
max_queue_size: 1000,
initial_backoff_ms: 1000,
max_backoff_ms: 30000,
}
}
}
/// Lifecycle state of a collaboration client's WebSocket connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Not connected: the link was lost, retries ran out, or the caller
    /// disconnected.
    Disconnected,
    /// The first connection is being set up.
    Connecting,
    /// The WebSocket handshake succeeded and the session is live.
    Connected,
    /// A connection attempt failed or dropped; waiting before retrying.
    Reconnecting,
}
/// Callback invoked with the event of every `Change` message the server sends.
pub type WorkspaceUpdateCallback = Box<dyn Fn(ChangeEvent) + Sync + Send>;
/// Callback invoked with the new [`ConnectionState`] on each state transition.
pub type StateChangeCallback = Box<dyn Fn(ConnectionState) + Sync + Send>;
/// WebSocket client for the collaboration server.
///
/// It reconnects with backoff and queues outgoing messages while offline.
/// Mutable state sits behind `Arc<RwLock<_>>` so the background task spawned
/// by `connect` can share it with this handle.
pub struct CollabClient {
/// Settings this client was created with.
config: ClientConfig,
/// Random (v4) per-client identifier; not read anywhere in this file.
_client_id: Uuid,
/// Current connection state, exposed through `state()`.
state: Arc<RwLock<ConnectionState>>,
/// Outgoing messages held while not connected (bounded by `max_queue_size`).
message_queue: Arc<RwLock<Vec<SyncMessage>>>,
/// Channel into the live session's writer task; `None` when no session is live.
ws_sender: Arc<RwLock<Option<mpsc::UnboundedSender<SyncMessage>>>>,
/// Handle of the background connect/reconnect loop, taken by `disconnect`.
connection_task: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
/// Callbacks run for each incoming `Change` event.
workspace_callbacks: Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
/// Callbacks run on every connection-state transition.
state_callbacks: Arc<RwLock<Vec<StateChangeCallback>>>,
/// Reconnect attempt counter maintained by the connection loop.
reconnect_count: Arc<RwLock<u32>>,
/// Set to `true` to ask the connection loop and any live session to stop.
stop_signal: Arc<RwLock<bool>>,
}
impl CollabClient {
pub async fn connect(config: ClientConfig) -> Result<Self> {
if config.server_url.is_empty() {
return Err(CollabError::InvalidInput("server_url cannot be empty".to_string()));
}
let client = Self {
config: config.clone(),
_client_id: Uuid::new_v4(),
state: Arc::new(RwLock::new(ConnectionState::Connecting)),
message_queue: Arc::new(RwLock::new(Vec::new())),
ws_sender: Arc::new(RwLock::new(None)),
connection_task: Arc::new(RwLock::new(None)),
workspace_callbacks: Arc::new(RwLock::new(Vec::new())),
state_callbacks: Arc::new(RwLock::new(Vec::new())),
reconnect_count: Arc::new(RwLock::new(0)),
stop_signal: Arc::new(RwLock::new(false)),
};
client.update_state(ConnectionState::Connecting).await;
client.start_connection_loop().await?;
Ok(client)
}
async fn start_connection_loop(&self) -> Result<()> {
let config = self.config.clone();
let state = self.state.clone();
let message_queue = self.message_queue.clone();
let ws_sender = self.ws_sender.clone();
let stop_signal = self.stop_signal.clone();
let reconnect_count = self.reconnect_count.clone();
let workspace_callbacks = self.workspace_callbacks.clone();
let state_callbacks = self.state_callbacks.clone();
let task = tokio::spawn(async move {
let mut backoff_ms = config.initial_backoff_ms;
loop {
if *stop_signal.read().await {
break;
}
match Self::try_connect(
&config,
&state,
&ws_sender,
&workspace_callbacks,
&state_callbacks,
&stop_signal,
)
.await
{
Ok(()) => {
backoff_ms = config.initial_backoff_ms;
*reconnect_count.write().await = 0;
let mut queue = message_queue.write().await;
while let Some(msg) = queue.pop() {
if let Some(ref sender) = *ws_sender.read().await {
let _ = sender.send(msg);
}
}
}
Err(e) => {
tracing::warn!("Connection failed: {}, will retry", e);
let current_count = *reconnect_count.read().await;
if let Some(max) = config.max_reconnect_attempts {
if current_count >= max {
tracing::error!("Max reconnect attempts ({}) reached", max);
*state.write().await = ConnectionState::Disconnected;
Self::notify_state_change(
&state_callbacks,
ConnectionState::Disconnected,
)
.await;
break;
}
}
*reconnect_count.write().await += 1;
*state.write().await = ConnectionState::Reconnecting;
Self::notify_state_change(&state_callbacks, ConnectionState::Reconnecting)
.await;
sleep(Duration::from_millis(backoff_ms)).await;
backoff_ms = (backoff_ms * 2).min(config.max_backoff_ms);
}
}
}
});
*self.connection_task.write().await = Some(task);
Ok(())
}
async fn try_connect(
config: &ClientConfig,
state: &Arc<RwLock<ConnectionState>>,
ws_sender: &Arc<RwLock<Option<mpsc::UnboundedSender<SyncMessage>>>>,
workspace_callbacks: &Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
state_callbacks: &Arc<RwLock<Vec<StateChangeCallback>>>,
stop_signal: &Arc<RwLock<bool>>,
) -> Result<()> {
let url = format!("{}?token={}", config.server_url, config.auth_token);
tracing::info!("Connecting to WebSocket: {}", config.server_url);
let (ws_stream, _) = connect_async(&url)
.await
.map_err(|e| CollabError::Internal(format!("WebSocket connection failed: {e}")))?;
*state.write().await = ConnectionState::Connected;
Self::notify_state_change(state_callbacks, ConnectionState::Connected).await;
tracing::info!("WebSocket connected successfully");
let (write, mut read) = ws_stream.split();
let (tx, mut rx) = mpsc::unbounded_channel();
*ws_sender.write().await = Some(tx);
let mut write_handle = write;
let write_task = tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
let json = match serde_json::to_string(&msg) {
Ok(json) => json,
Err(e) => {
tracing::error!("Failed to serialize message: {}", e);
continue;
}
};
if let Err(e) = write_handle.send(Message::Text(json)).await {
tracing::error!("Failed to send message: {}", e);
break;
}
}
});
loop {
if *stop_signal.read().await {
tracing::info!("Stop signal received, closing connection");
break;
}
tokio::select! {
msg_opt = read.next() => {
match msg_opt {
Some(Ok(Message::Text(text))) => {
Self::handle_server_message(&text, workspace_callbacks).await;
}
Some(Ok(Message::Close(_))) => {
tracing::info!("Server closed connection");
*state.write().await = ConnectionState::Disconnected;
Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
break;
}
Some(Ok(Message::Ping(_))) => {
tracing::debug!("Received ping");
}
Some(Ok(Message::Pong(_))) => {
tracing::debug!("Received pong");
}
Some(Err(e)) => {
tracing::error!("WebSocket error: {}", e);
*state.write().await = ConnectionState::Disconnected;
Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
return Err(CollabError::Internal(format!("WebSocket error: {e}")));
}
None => {
tracing::info!("WebSocket stream ended");
*state.write().await = ConnectionState::Disconnected;
Self::notify_state_change(state_callbacks, ConnectionState::Disconnected).await;
break;
}
_ => {}
}
}
() = sleep(Duration::from_millis(100)) => {
if *stop_signal.read().await {
tracing::info!("Stop signal received, closing connection");
break;
}
}
}
}
write_task.abort();
*ws_sender.write().await = None;
Err(CollabError::Internal("Connection closed".to_string()))
}
async fn handle_server_message(
text: &str,
workspace_callbacks: &Arc<RwLock<Vec<WorkspaceUpdateCallback>>>,
) {
match serde_json::from_str::<SyncMessage>(text) {
Ok(SyncMessage::Change { event }) => {
let callbacks = workspace_callbacks.read().await;
for callback in callbacks.iter() {
callback(event.clone());
}
}
Ok(SyncMessage::StateResponse {
workspace_id,
version,
state: _,
}) => {
tracing::debug!(
"Received state response for workspace {} (version {})",
workspace_id,
version
);
}
Ok(SyncMessage::Error { message }) => {
tracing::error!("Server error: {}", message);
}
Ok(SyncMessage::Pong) => {
tracing::debug!("Received pong");
}
Ok(other) => {
tracing::debug!("Received message: {:?}", other);
}
Err(e) => {
tracing::warn!("Failed to parse server message: {} - {}", e, text);
}
}
}
async fn notify_state_change(
callbacks: &Arc<RwLock<Vec<StateChangeCallback>>>,
new_state: ConnectionState,
) {
let callbacks = callbacks.read().await;
for callback in callbacks.iter() {
callback(new_state);
}
}
async fn update_state(&self, new_state: ConnectionState) {
*self.state.write().await = new_state;
let callbacks = self.state_callbacks.read().await;
for callback in callbacks.iter() {
callback(new_state);
}
}
async fn send_message(&self, message: SyncMessage) -> Result<()> {
let state = *self.state.read().await;
if state == ConnectionState::Connected {
if let Some(ref sender) = *self.ws_sender.read().await {
sender.send(message).map_err(|_| {
CollabError::Internal("Failed to send message (channel closed)".to_string())
})?;
return Ok(());
}
}
let mut queue = self.message_queue.write().await;
if queue.len() >= self.config.max_queue_size {
return Err(CollabError::InvalidInput(format!(
"Message queue full (max: {})",
self.config.max_queue_size
)));
}
queue.push(message);
drop(queue);
Ok(())
}
pub async fn on_workspace_update<F>(&self, callback: F)
where
F: Fn(ChangeEvent) + Send + Sync + 'static,
{
let mut callbacks = self.workspace_callbacks.write().await;
callbacks.push(Box::new(callback));
}
pub async fn on_state_change<F>(&self, callback: F)
where
F: Fn(ConnectionState) + Send + Sync + 'static,
{
let mut callbacks = self.state_callbacks.write().await;
callbacks.push(Box::new(callback));
}
pub async fn subscribe_to_workspace(&self, workspace_id: &str) -> Result<()> {
let workspace_id = Uuid::parse_str(workspace_id)
.map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {e}")))?;
let message = SyncMessage::Subscribe { workspace_id };
self.send_message(message).await?;
Ok(())
}
pub async fn unsubscribe_from_workspace(&self, workspace_id: &str) -> Result<()> {
let workspace_id = Uuid::parse_str(workspace_id)
.map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {e}")))?;
let message = SyncMessage::Unsubscribe { workspace_id };
self.send_message(message).await?;
Ok(())
}
pub async fn request_state(&self, workspace_id: &str, version: i64) -> Result<()> {
let workspace_id = Uuid::parse_str(workspace_id)
.map_err(|e| CollabError::InvalidInput(format!("Invalid workspace ID: {e}")))?;
let message = SyncMessage::StateRequest {
workspace_id,
version,
};
self.send_message(message).await?;
Ok(())
}
pub async fn ping(&self) -> Result<()> {
let message = SyncMessage::Ping;
self.send_message(message).await?;
Ok(())
}
pub async fn state(&self) -> ConnectionState {
*self.state.read().await
}
pub async fn queued_message_count(&self) -> usize {
self.message_queue.read().await.len()
}
pub async fn reconnect_count(&self) -> u32 {
*self.reconnect_count.read().await
}
pub async fn disconnect(&self) -> Result<()> {
*self.stop_signal.write().await = true;
*self.state.write().await = ConnectionState::Disconnected;
Self::notify_state_change(&self.state_callbacks, ConnectionState::Disconnected).await;
let task = self.connection_task.write().await.take();
if let Some(task) = task {
task.abort();
}
Ok(())
}
}
impl Drop for CollabClient {
    /// Best-effort shutdown when the client handle goes away.
    ///
    /// The connection loop is aborted synchronously. Once a runtime is
    /// available, a cleanup task then:
    /// - raises the stop signal,
    /// - marks the state `Disconnected`,
    /// - drops the writer's sender so its task can exit and release the socket.
    ///
    /// State callbacks are not invoked.
    fn drop(&mut self) {
        // `JoinHandle::abort` is synchronous, so this works even without a
        // current runtime. `try_write` avoids blocking inside `drop`.
        if let Ok(mut guard) = self.connection_task.try_write() {
            if let Some(task) = guard.take() {
                task.abort();
            }
        }
        let stop_signal = self.stop_signal.clone();
        let state = self.state.clone();
        let ws_sender = self.ws_sender.clone();
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            handle.spawn(async move {
                *stop_signal.write().await = true;
                *state.write().await = ConnectionState::Disconnected;
                // Otherwise the detached writer task waits on this channel forever.
                *ws_sender.write().await = None;
            });
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client aimed at a loopback port where nothing should be
    /// listening. It never reaches `Connected`, so outgoing messages are queued.
    async fn offline_client(max_queue_size: usize) -> CollabClient {
        let config = ClientConfig {
            server_url: "ws://127.0.0.1:1".to_string(),
            auth_token: "token".to_string(),
            max_reconnect_attempts: Some(0),
            max_queue_size,
            ..Default::default()
        };
        match CollabClient::connect(config).await {
            Ok(client) => client,
            Err(_) => panic!("connect must accept a non-empty server_url"),
        }
    }

    #[test]
    fn test_client_config_default() {
        let config = ClientConfig::default();
        assert_eq!(config.server_url, String::new());
        assert_eq!(config.auth_token, "");
        assert_eq!(config.max_reconnect_attempts, None);
        assert_eq!(config.max_queue_size, 1000);
        assert_eq!(config.initial_backoff_ms, 1000);
        assert_eq!(config.max_backoff_ms, 30000);
    }

    #[test]
    fn test_client_config_clone() {
        let config = ClientConfig {
            server_url: "ws://localhost:8080".to_string(),
            auth_token: "token123".to_string(),
            max_reconnect_attempts: Some(5),
            max_queue_size: 500,
            initial_backoff_ms: 500,
            max_backoff_ms: 10000,
        };
        let cloned = config.clone();
        assert_eq!(config.server_url, cloned.server_url);
        assert_eq!(config.auth_token, cloned.auth_token);
        assert_eq!(config.max_reconnect_attempts, cloned.max_reconnect_attempts);
        assert_eq!(config.max_queue_size, cloned.max_queue_size);
        assert_eq!(config.initial_backoff_ms, cloned.initial_backoff_ms);
        assert_eq!(config.max_backoff_ms, cloned.max_backoff_ms);
    }

    #[test]
    fn test_client_config_serialization() {
        let config = ClientConfig {
            server_url: "ws://localhost:8080".to_string(),
            auth_token: "token123".to_string(),
            max_reconnect_attempts: Some(3),
            max_queue_size: 200,
            initial_backoff_ms: 1500,
            max_backoff_ms: 20000,
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ClientConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config.server_url, deserialized.server_url);
        assert_eq!(config.auth_token, deserialized.auth_token);
        assert_eq!(config.max_reconnect_attempts, deserialized.max_reconnect_attempts);
        assert_eq!(config.max_queue_size, deserialized.max_queue_size);
        assert_eq!(config.initial_backoff_ms, deserialized.initial_backoff_ms);
        assert_eq!(config.max_backoff_ms, deserialized.max_backoff_ms);
    }

    #[test]
    fn test_connection_state_equality() {
        assert_eq!(ConnectionState::Disconnected, ConnectionState::Disconnected);
        assert_eq!(ConnectionState::Connecting, ConnectionState::Connecting);
        assert_eq!(ConnectionState::Connected, ConnectionState::Connected);
        assert_eq!(ConnectionState::Reconnecting, ConnectionState::Reconnecting);
        assert_ne!(ConnectionState::Disconnected, ConnectionState::Connected);
        assert_ne!(ConnectionState::Connecting, ConnectionState::Reconnecting);
    }

    #[test]
    fn test_connection_state_copy() {
        let state = ConnectionState::Connected;
        let copied = state;
        assert_eq!(state, copied);
    }

    #[test]
    fn test_connection_state_debug() {
        let state = ConnectionState::Connected;
        let debug_str = format!("{state:?}");
        assert!(debug_str.contains("Connected"));
    }

    #[tokio::test]
    async fn test_connect_with_empty_url() {
        let config = ClientConfig {
            server_url: String::new(),
            auth_token: "token".to_string(),
            ..Default::default()
        };
        match CollabClient::connect(config).await {
            Err(CollabError::InvalidInput(msg)) => assert!(msg.contains("server_url")),
            Err(_) => panic!("Expected InvalidInput error"),
            Ok(_) => panic!("connect must reject an empty server_url"),
        }
    }

    #[tokio::test]
    async fn test_subscribe_to_workspace_invalid_id() {
        let client = offline_client(10).await;
        let result = client.subscribe_to_workspace("invalid-uuid").await;
        assert!(matches!(
            result,
            Err(CollabError::InvalidInput(ref msg)) if msg.contains("Invalid workspace ID")
        ));
        // A rejected id must not leave anything in the queue.
        assert_eq!(client.queued_message_count().await, 0);
        assert!(client.disconnect().await.is_ok());
    }

    #[tokio::test]
    async fn test_subscribe_to_workspace_valid_id() {
        let client = offline_client(10).await;
        let workspace_id = Uuid::new_v4().to_string();
        assert!(client.subscribe_to_workspace(&workspace_id).await.is_ok());
        // Not connected, so the subscription waits in the queue.
        assert_eq!(client.queued_message_count().await, 1);
        assert!(client.disconnect().await.is_ok());
    }

    #[tokio::test]
    async fn test_queue_full_rejects_messages() {
        let client = offline_client(1).await;
        assert!(client.ping().await.is_ok());
        assert!(matches!(client.ping().await, Err(CollabError::InvalidInput(_))));
        assert_eq!(client.queued_message_count().await, 1);
        assert!(client.disconnect().await.is_ok());
    }

    #[test]
    fn test_client_config_with_max_attempts() {
        let config = ClientConfig {
            max_reconnect_attempts: Some(10),
            ..Default::default()
        };
        assert_eq!(config.max_reconnect_attempts, Some(10));
    }

    #[test]
    fn test_client_config_unlimited_attempts() {
        let config = ClientConfig {
            max_reconnect_attempts: None,
            ..Default::default()
        };
        assert_eq!(config.max_reconnect_attempts, None);
    }

    #[test]
    fn test_client_config_queue_size() {
        let config = ClientConfig {
            max_queue_size: 5000,
            ..Default::default()
        };
        assert_eq!(config.max_queue_size, 5000);
    }

    #[test]
    fn test_client_config_backoff_values() {
        let config = ClientConfig {
            initial_backoff_ms: 2000,
            max_backoff_ms: 60000,
            ..Default::default()
        };
        assert_eq!(config.initial_backoff_ms, 2000);
        assert_eq!(config.max_backoff_ms, 60000);
    }

    #[test]
    fn test_sync_message_subscribe() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Subscribe { workspace_id };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            SyncMessage::Subscribe {
                workspace_id: ws_id,
            } => {
                assert_eq!(ws_id, workspace_id);
            }
            _ => panic!("Expected Subscribe message"),
        }
    }

    #[test]
    fn test_sync_message_unsubscribe() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Unsubscribe { workspace_id };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            SyncMessage::Unsubscribe {
                workspace_id: ws_id,
            } => {
                assert_eq!(ws_id, workspace_id);
            }
            _ => panic!("Expected Unsubscribe message"),
        }
    }

    #[test]
    fn test_sync_message_ping() {
        let msg = SyncMessage::Ping;
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, SyncMessage::Ping), "Expected Ping message");
    }

    #[test]
    fn test_sync_message_pong() {
        let msg = SyncMessage::Pong;
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, SyncMessage::Pong), "Expected Pong message");
    }

    #[test]
    fn test_sync_message_error() {
        let msg = SyncMessage::Error {
            message: "Test error".to_string(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            SyncMessage::Error { message } => {
                assert_eq!(message, "Test error");
            }
            _ => panic!("Expected Error message"),
        }
    }

    #[test]
    fn test_sync_message_state_request() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::StateRequest {
            workspace_id,
            version: 42,
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            SyncMessage::StateRequest {
                workspace_id: ws_id,
                version,
            } => {
                assert_eq!(ws_id, workspace_id);
                assert_eq!(version, 42);
            }
            _ => panic!("Expected StateRequest message"),
        }
    }
}