use crate::auth::AuthService;
use crate::error::{CollabError, Result};
use crate::events::EventBus;
use crate::sync::{SyncEngine, SyncMessage};
use crate::workspace;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Query, State,
},
response::Response,
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::select;
use uuid::Uuid;
/// Shared services handed to every WebSocket connection through axum's
/// `State` extractor.
///
/// Cloning is cheap: every field is an `Arc`, so clones share the same
/// underlying services.
#[derive(Clone)]
pub struct WsState {
    /// Verifies the `token` query parameter presented at upgrade time.
    pub auth: Arc<AuthService>,
    /// Registers/unregisters client subscriptions and serves workspace
    /// state snapshots.
    pub sync: Arc<SyncEngine>,
    /// Broadcast source of change events forwarded to subscribed clients.
    pub event_bus: Arc<EventBus>,
    /// Workspace membership lookups used to authorize access.
    pub workspace: Arc<workspace::WorkspaceService>,
}
/// Upgrades an HTTP request to a WebSocket and hands the socket to
/// [`handle_socket`].
///
/// The caller is identified solely by the `token` query parameter. It is
/// checked with [`AuthService::verify_token`], and its `sub` claim must parse
/// as a UUID. A missing or invalid token still yields a connection, but an
/// anonymous one (`user_id = None`), and such a connection cannot subscribe
/// to workspaces.
///
/// A bare `user_id` query parameter is deliberately *not* accepted. It is
/// unauthenticated client input, and trusting it would let anyone act as any
/// user simply by naming that user's id.
// The `HashMap` comes from axum's `Query` extractor and always uses the
// default hasher, so clippy's `implicit_hasher` generalization adds nothing.
#[allow(clippy::implicit_hasher)]
pub async fn ws_handler(
    ws: WebSocketUpgrade,
    Query(params): Query<HashMap<String, String>>,
    State(state): State<WsState>,
) -> Response {
    let user_id = params.get("token").and_then(|token| {
        state
            .auth
            .verify_token(token)
            .ok()
            .and_then(|claims| Uuid::parse_str(&claims.sub).ok())
    });
    ws.on_upgrade(move |socket| handle_socket(socket, state, user_id))
}
/// Runs the read/forward loop for one upgraded WebSocket connection.
///
/// Each connection gets a fresh `client_id`. The loop `select!`s between two
/// sources:
/// * Frames from the client. Text frames go to [`handle_client_message`], and
///   its errors are reported back to the client as `SyncMessage::Error`.
/// * Change events from the [`EventBus`]. These are forwarded as
///   `SyncMessage::Change`, but only for workspaces in this client's
///   `subscriptions`.
///
/// The loop ends when:
/// * the client closes or disconnects,
/// * the socket errors,
/// * a frame can no longer be written, or
/// * the event channel closes.
///
/// Every recorded subscription is then released in the [`SyncEngine`].
async fn handle_socket(socket: WebSocket, state: WsState, user_id: Option<Uuid>) {
    let (mut sender, mut receiver) = socket.split();
    let client_id = Uuid::new_v4();
    tracing::info!("WebSocket client connected: {} (user: {:?})", client_id, user_id);
    let mut subscriptions: Vec<Uuid> = Vec::new();
    let mut event_rx = state.event_bus.subscribe();
    loop {
        select! {
            msg = receiver.next() => {
                match msg {
                    Some(Ok(Message::Text(text))) => {
                        if let Err(e) = handle_client_message(&text, client_id, user_id, &state, &mut subscriptions, &mut sender).await {
                            tracing::error!("Error handling client message: {}", e);
                            let reply = SyncMessage::Error {
                                message: e.to_string(),
                            };
                            // Never panic here: a panic would unwind the task
                            // and skip the subscription cleanup after the loop.
                            match serde_json::to_string(&reply) {
                                Ok(json) => {
                                    if sender.send(Message::Text(json.into())).await.is_err() {
                                        tracing::info!("Client {} unreachable; closing", client_id);
                                        break;
                                    }
                                }
                                Err(err) => {
                                    tracing::error!("Failed to serialize error reply: {}", err);
                                }
                            }
                        }
                    }
                    Some(Ok(Message::Close(_))) => {
                        tracing::info!("Client {} requested close", client_id);
                        break;
                    }
                    Some(Ok(Message::Ping(data))) => {
                        // NOTE(review): axum/tungstenite may already auto-reply to
                        // pings; confirm whether this explicit Pong is redundant.
                        if sender.send(Message::Pong(data)).await.is_err() {
                            tracing::info!("Client {} unreachable; closing", client_id);
                            break;
                        }
                    }
                    Some(Err(e)) => {
                        tracing::error!("WebSocket error: {}", e);
                        break;
                    }
                    None => {
                        tracing::info!("Client {} disconnected", client_id);
                        break;
                    }
                    _ => {}
                }
            }
            event = event_rx.recv() => {
                match event {
                    Ok(change_event) => {
                        if subscriptions.contains(&change_event.workspace_id) {
                            let msg = SyncMessage::Change { event: change_event };
                            match serde_json::to_string(&msg) {
                                Ok(json) => {
                                    // A failed write means the socket is gone;
                                    // stop instead of spinning on dead events.
                                    if sender.send(Message::Text(json.into())).await.is_err() {
                                        tracing::info!("Client {} unreachable; closing", client_id);
                                        break;
                                    }
                                }
                                Err(err) => {
                                    tracing::error!("Failed to serialize change event: {}", err);
                                }
                            }
                        }
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                        tracing::warn!("Client {} lagged {} messages", client_id, n);
                    }
                    Err(_) => {
                        tracing::error!("Event channel closed");
                        break;
                    }
                }
            }
        }
    }
    // A workspace can be recorded more than once if it was subscribed
    // repeatedly; release each one exactly once.
    subscriptions.sort_unstable();
    subscriptions.dedup();
    for workspace_id in subscriptions {
        if let Err(e) = state.sync.unsubscribe(workspace_id, client_id) {
            tracing::warn!(
                "Failed to unsubscribe client {} from workspace {}: {}",
                client_id,
                workspace_id,
                e
            );
        }
    }
    tracing::info!("Client {} connection closed", client_id);
}
/// Parses one client text frame as a [`SyncMessage`] and acts on it.
///
/// * `Subscribe`: requires an authenticated member of the workspace. Registers
///   `client_id` with the [`SyncEngine`] (at most once per workspace), records
///   the workspace in `subscriptions`, and replies with the current state if
///   one exists.
/// * `Unsubscribe`: releases the subscription and forgets the workspace.
/// * `StateRequest`: requires the same access as `Subscribe`. Replies with the
///   current state only when it is newer than the client's `version`.
/// * `Ping`: replies with `Pong`.
/// * Any other variant is logged and ignored.
///
/// # Errors
///
/// * `CollabError::InvalidInput` for malformed JSON.
/// * `CollabError::AuthenticationFailed` or `CollabError::AuthorizationFailed`
///   when workspace access is denied.
/// * `CollabError::Internal` when a reply cannot be sent.
/// * Errors from the sync engine and from JSON serialization are propagated.
async fn handle_client_message(
    text: &str,
    client_id: Uuid,
    user_id: Option<Uuid>,
    state: &WsState,
    subscriptions: &mut Vec<Uuid>,
    sender: &mut futures::stream::SplitSink<WebSocket, Message>,
) -> Result<()> {
    let message: SyncMessage = serde_json::from_str(text)
        .map_err(|e| CollabError::InvalidInput(format!("Invalid JSON: {e}")))?;
    match message {
        SyncMessage::Subscribe { workspace_id } => {
            ensure_workspace_access(state, user_id, workspace_id).await?;
            // A repeated Subscribe still gets a fresh state reply, but must not
            // register the client twice or leave duplicate cleanup entries.
            if !subscriptions.contains(&workspace_id) {
                state.sync.subscribe(workspace_id, client_id)?;
                subscriptions.push(workspace_id);
            }
            tracing::info!("Client {} subscribed to workspace {}", client_id, workspace_id);
            if let Some(sync_state) = state.sync.get_state(workspace_id) {
                let response = SyncMessage::StateResponse {
                    workspace_id,
                    version: sync_state.version,
                    state: sync_state.state,
                };
                send_sync_message(sender, &response).await?;
            }
        }
        SyncMessage::Unsubscribe { workspace_id } => {
            state.sync.unsubscribe(workspace_id, client_id)?;
            subscriptions.retain(|id| *id != workspace_id);
            tracing::info!("Client {} unsubscribed from workspace {}", client_id, workspace_id);
        }
        SyncMessage::StateRequest {
            workspace_id,
            version,
        } => {
            // Same access rule as Subscribe. Without it, any client (even an
            // anonymous one) could read the state of any workspace.
            ensure_workspace_access(state, user_id, workspace_id).await?;
            if let Some(sync_state) = state.sync.get_state(workspace_id) {
                if sync_state.version > version {
                    let response = SyncMessage::StateResponse {
                        workspace_id,
                        version: sync_state.version,
                        state: sync_state.state,
                    };
                    send_sync_message(sender, &response).await?;
                }
            }
        }
        SyncMessage::Ping => {
            send_sync_message(sender, &SyncMessage::Pong).await?;
        }
        _ => {
            tracing::warn!("Unexpected message type from client {}", client_id);
        }
    }
    Ok(())
}

/// Checks that `user_id` is present (the connection is authenticated) and is a
/// member of `workspace_id` according to the workspace service.
///
/// # Errors
///
/// * `CollabError::AuthenticationFailed` when `user_id` is `None`.
/// * `CollabError::AuthorizationFailed` when the membership lookup fails.
async fn ensure_workspace_access(
    state: &WsState,
    user_id: Option<Uuid>,
    workspace_id: Uuid,
) -> Result<()> {
    let Some(uid) = user_id else {
        return Err(CollabError::AuthenticationFailed(
            "Authentication required for workspace access".to_string(),
        ));
    };
    if let Err(e) = state.workspace.get_member(workspace_id, uid).await {
        tracing::warn!(
            "User {uid} attempted to access workspace {workspace_id} without permission: {e}"
        );
        return Err(CollabError::AuthorizationFailed(format!(
            "Access denied to workspace {workspace_id}"
        )));
    }
    Ok(())
}

/// Serializes `message` to JSON and writes it to the client as a text frame.
///
/// # Errors
///
/// Propagates JSON serialization errors. A failed write is returned as
/// `CollabError::Internal`.
async fn send_sync_message(
    sender: &mut futures::stream::SplitSink<WebSocket, Message>,
    message: &SyncMessage,
) -> Result<()> {
    let json = serde_json::to_string(message)?;
    sender
        .send(Message::Text(json.into()))
        .await
        .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Subscribe` message must survive a JSON round trip as the same
    /// variant, carrying the same `workspace_id`.
    #[test]
    fn test_sync_message_serialization() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Subscribe { workspace_id };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("subscribe"));
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            SyncMessage::Subscribe {
                workspace_id: round_tripped,
            } => assert_eq!(round_tripped, workspace_id),
            _ => panic!("Wrong message type"),
        }
    }
}