mockforge_collab/
websocket.rs1use crate::auth::AuthService;
4use crate::error::{CollabError, Result};
5use crate::events::EventBus;
6use crate::sync::{SyncEngine, SyncMessage};
7use crate::workspace;
8use axum::{
9 extract::{
10 ws::{Message, WebSocket, WebSocketUpgrade},
11 Query, State,
12 },
13 response::Response,
14};
15use futures::{sink::SinkExt, stream::StreamExt};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::select;
19use uuid::Uuid;
20
21#[derive(Clone)]
23pub struct WsState {
24 pub auth: Arc<AuthService>,
26 pub sync: Arc<SyncEngine>,
28 pub event_bus: Arc<EventBus>,
30 pub workspace: Arc<workspace::WorkspaceService>,
32}
33
34#[allow(clippy::implicit_hasher)]
36pub async fn ws_handler(
37 ws: WebSocketUpgrade,
38 Query(params): Query<HashMap<String, String>>,
39 State(state): State<WsState>,
40) -> Response {
41 let user_id = params
43 .get("token")
44 .and_then(|token| {
45 state
46 .auth
47 .verify_token(token)
48 .ok()
49 .and_then(|claims| Uuid::parse_str(&claims.sub).ok())
50 })
51 .or_else(|| {
52 params.get("user_id").and_then(|id| Uuid::parse_str(id).ok())
54 });
55
56 ws.on_upgrade(move |socket| handle_socket(socket, state, user_id))
57}
58
59async fn handle_socket(socket: WebSocket, state: WsState, user_id: Option<Uuid>) {
61 let (mut sender, mut receiver) = socket.split();
62
63 let client_id = Uuid::new_v4();
65 tracing::info!("WebSocket client connected: {} (user: {:?})", client_id, user_id);
66
67 let mut subscriptions: Vec<Uuid> = Vec::new();
69
70 let mut event_rx = state.event_bus.subscribe();
72
73 loop {
74 select! {
75 msg = receiver.next() => {
77 match msg {
78 Some(Ok(Message::Text(text))) => {
79 if let Err(e) = handle_client_message(&text, client_id, user_id, &state, &mut subscriptions, &mut sender).await {
80 tracing::error!("Error handling client message: {}", e);
81 let _ = sender.send(Message::Text(
82 serde_json::to_string(&SyncMessage::Error {
83 message: e.to_string(),
84 }).unwrap().into()
85 )).await;
86 }
87 }
88 Some(Ok(Message::Close(_))) => {
89 tracing::info!("Client {} requested close", client_id);
90 break;
91 }
92 Some(Ok(Message::Ping(data))) => {
93 let _ = sender.send(Message::Pong(data)).await;
94 }
95 Some(Err(e)) => {
96 tracing::error!("WebSocket error: {}", e);
97 break;
98 }
99 None => {
100 tracing::info!("Client {} disconnected", client_id);
101 break;
102 }
103 _ => {}
104 }
105 }
106
107 event = event_rx.recv() => {
109 match event {
110 Ok(change_event) => {
111 if subscriptions.contains(&change_event.workspace_id) {
113 let msg = SyncMessage::Change { event: change_event };
114 if let Ok(json) = serde_json::to_string(&msg) {
115 let _ = sender.send(Message::Text(json.into())).await;
116 }
117 }
118 }
119 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
120 tracing::warn!("Client {} lagged {} messages", client_id, n);
121 }
122 Err(_) => {
123 tracing::error!("Event channel closed");
124 break;
125 }
126 }
127 }
128 }
129 }
130
131 for workspace_id in subscriptions {
133 let _ = state.sync.unsubscribe(workspace_id, client_id);
134 }
135
136 tracing::info!("Client {} connection closed", client_id);
137}
138
139async fn handle_client_message(
141 text: &str,
142 client_id: Uuid,
143 user_id: Option<Uuid>,
144 state: &WsState,
145 subscriptions: &mut Vec<Uuid>,
146 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
147) -> Result<()> {
148 let message: SyncMessage = serde_json::from_str(text)
149 .map_err(|e| CollabError::InvalidInput(format!("Invalid JSON: {e}")))?;
150
151 match message {
152 SyncMessage::Subscribe { workspace_id } => {
153 if let Some(uid) = user_id {
155 if let Err(e) = state.workspace.get_member(workspace_id, uid).await {
157 tracing::warn!(
158 "User {uid} attempted to access workspace {workspace_id} without permission: {e}"
159 );
160 return Err(CollabError::AuthorizationFailed(format!(
161 "Access denied to workspace {workspace_id}"
162 )));
163 }
164 } else {
165 return Err(CollabError::AuthenticationFailed(
168 "Authentication required for workspace access".to_string(),
169 ));
170 }
171
172 state.sync.subscribe(workspace_id, client_id)?;
174 subscriptions.push(workspace_id);
175
176 tracing::info!("Client {} subscribed to workspace {}", client_id, workspace_id);
177
178 if let Some(sync_state) = state.sync.get_state(workspace_id) {
180 let response = SyncMessage::StateResponse {
181 workspace_id,
182 version: sync_state.version,
183 state: sync_state.state,
184 };
185 let json = serde_json::to_string(&response)?;
186 sender
187 .send(Message::Text(json.into()))
188 .await
189 .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
190 }
191 }
192
193 SyncMessage::Unsubscribe { workspace_id } => {
194 state.sync.unsubscribe(workspace_id, client_id)?;
195 subscriptions.retain(|id| *id != workspace_id);
196
197 tracing::info!("Client {} unsubscribed from workspace {}", client_id, workspace_id);
198 }
199
200 SyncMessage::StateRequest {
201 workspace_id,
202 version,
203 } => {
204 if let Some(sync_state) = state.sync.get_state(workspace_id) {
206 if sync_state.version > version {
207 let response = SyncMessage::StateResponse {
208 workspace_id,
209 version: sync_state.version,
210 state: sync_state.state,
211 };
212 let json = serde_json::to_string(&response)?;
213 sender
214 .send(Message::Text(json.into()))
215 .await
216 .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
217 }
218 }
219 }
220
221 SyncMessage::Ping => {
222 let pong = SyncMessage::Pong;
223 let json = serde_json::to_string(&pong)?;
224 sender
225 .send(Message::Text(json.into()))
226 .await
227 .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
228 }
229
230 _ => {
231 tracing::warn!("Unexpected message type from client {}", client_id);
232 }
233 }
234
235 Ok(())
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Subscribe` message must survive a JSON round trip with its
    /// workspace id intact, not just its variant tag.
    #[test]
    fn test_sync_message_serialization() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Subscribe { workspace_id };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("subscribe"));

        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            SyncMessage::Subscribe {
                workspace_id: decoded,
            } => assert_eq!(decoded, workspace_id),
            _ => panic!("Wrong message type"),
        }
    }

    /// The payload-less `Ping` variant must also round-trip.
    #[test]
    fn test_ping_round_trip() {
        let json = serde_json::to_string(&SyncMessage::Ping).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, SyncMessage::Ping));
    }
}