mockforge_collab/
websocket.rs1use crate::auth::AuthService;
4use crate::error::{CollabError, Result};
5use crate::events::EventBus;
6use crate::sync::{SyncEngine, SyncMessage};
7use crate::workspace;
8use axum::{
9 extract::{
10 ws::{Message, WebSocket, WebSocketUpgrade},
11 Query, State,
12 },
13 response::Response,
14};
15use futures::{sink::SinkExt, stream::StreamExt};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::select;
19use uuid::Uuid;
20
21#[derive(Clone)]
23pub struct WsState {
24 pub auth: Arc<AuthService>,
25 pub sync: Arc<SyncEngine>,
26 pub event_bus: Arc<EventBus>,
27 pub workspace: Arc<workspace::WorkspaceService>,
28}
29
30pub async fn ws_handler(
32 ws: WebSocketUpgrade,
33 Query(params): Query<HashMap<String, String>>,
34 State(state): State<WsState>,
35) -> Response {
36 let user_id = params
38 .get("token")
39 .and_then(|token| {
40 state
41 .auth
42 .verify_token(token)
43 .ok()
44 .and_then(|claims| Uuid::parse_str(&claims.sub).ok())
45 })
46 .or_else(|| {
47 params.get("user_id").and_then(|id| Uuid::parse_str(id).ok())
49 });
50
51 ws.on_upgrade(move |socket| handle_socket(socket, state, user_id))
52}
53
54async fn handle_socket(socket: WebSocket, state: WsState, user_id: Option<Uuid>) {
56 let (mut sender, mut receiver) = socket.split();
57
58 let client_id = Uuid::new_v4();
60 tracing::info!("WebSocket client connected: {} (user: {:?})", client_id, user_id);
61
62 let mut subscriptions: Vec<Uuid> = Vec::new();
64
65 let mut event_rx = state.event_bus.subscribe();
67
68 loop {
69 select! {
70 msg = receiver.next() => {
72 match msg {
73 Some(Ok(Message::Text(text))) => {
74 if let Err(e) = handle_client_message(&text, client_id, user_id, &state, &mut subscriptions, &mut sender).await {
75 tracing::error!("Error handling client message: {}", e);
76 let _ = sender.send(Message::Text(
77 serde_json::to_string(&SyncMessage::Error {
78 message: e.to_string(),
79 }).unwrap().into()
80 )).await;
81 }
82 }
83 Some(Ok(Message::Close(_))) => {
84 tracing::info!("Client {} requested close", client_id);
85 break;
86 }
87 Some(Ok(Message::Ping(data))) => {
88 let _ = sender.send(Message::Pong(data)).await;
89 }
90 Some(Err(e)) => {
91 tracing::error!("WebSocket error: {}", e);
92 break;
93 }
94 None => {
95 tracing::info!("Client {} disconnected", client_id);
96 break;
97 }
98 _ => {}
99 }
100 }
101
102 event = event_rx.recv() => {
104 match event {
105 Ok(change_event) => {
106 if subscriptions.contains(&change_event.workspace_id) {
108 let msg = SyncMessage::Change { event: change_event };
109 if let Ok(json) = serde_json::to_string(&msg) {
110 let _ = sender.send(Message::Text(json.into())).await;
111 }
112 }
113 }
114 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
115 tracing::warn!("Client {} lagged {} messages", client_id, n);
116 }
117 Err(_) => {
118 tracing::error!("Event channel closed");
119 break;
120 }
121 }
122 }
123 }
124 }
125
126 for workspace_id in subscriptions {
128 let _ = state.sync.unsubscribe(workspace_id, client_id);
129 }
130
131 tracing::info!("Client {} connection closed", client_id);
132}
133
134async fn handle_client_message(
136 text: &str,
137 client_id: Uuid,
138 user_id: Option<Uuid>,
139 state: &WsState,
140 subscriptions: &mut Vec<Uuid>,
141 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
142) -> Result<()> {
143 let message: SyncMessage = serde_json::from_str(text)
144 .map_err(|e| CollabError::InvalidInput(format!("Invalid JSON: {e}")))?;
145
146 match message {
147 SyncMessage::Subscribe { workspace_id } => {
148 if let Some(uid) = user_id {
150 if let Err(e) = state.workspace.get_member(workspace_id, uid).await {
152 tracing::warn!(
153 "User {} attempted to access workspace {} without permission: {}",
154 uid,
155 workspace_id,
156 e
157 );
158 return Err(CollabError::AuthorizationFailed(format!(
159 "Access denied to workspace {}",
160 workspace_id
161 )));
162 }
163 } else {
164 return Err(CollabError::AuthenticationFailed(
167 "Authentication required for workspace access".to_string(),
168 ));
169 }
170
171 state.sync.subscribe(workspace_id, client_id)?;
173 subscriptions.push(workspace_id);
174
175 tracing::info!("Client {} subscribed to workspace {}", client_id, workspace_id);
176
177 if let Some(sync_state) = state.sync.get_state(workspace_id) {
179 let response = SyncMessage::StateResponse {
180 workspace_id,
181 version: sync_state.version,
182 state: sync_state.state,
183 };
184 let json = serde_json::to_string(&response)?;
185 sender
186 .send(Message::Text(json.into()))
187 .await
188 .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
189 }
190 }
191
192 SyncMessage::Unsubscribe { workspace_id } => {
193 state.sync.unsubscribe(workspace_id, client_id)?;
194 subscriptions.retain(|id| *id != workspace_id);
195
196 tracing::info!("Client {} unsubscribed from workspace {}", client_id, workspace_id);
197 }
198
199 SyncMessage::StateRequest {
200 workspace_id,
201 version,
202 } => {
203 if let Some(sync_state) = state.sync.get_state(workspace_id) {
205 if sync_state.version > version {
206 let response = SyncMessage::StateResponse {
207 workspace_id,
208 version: sync_state.version,
209 state: sync_state.state,
210 };
211 let json = serde_json::to_string(&response)?;
212 sender
213 .send(Message::Text(json.into()))
214 .await
215 .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
216 }
217 }
218 }
219
220 SyncMessage::Ping => {
221 let pong = SyncMessage::Pong;
222 let json = serde_json::to_string(&pong)?;
223 sender
224 .send(Message::Text(json.into()))
225 .await
226 .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
227 }
228
229 _ => {
230 tracing::warn!("Unexpected message type from client {}", client_id);
231 }
232 }
233
234 Ok(())
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Subscribe` message serializes to JSON containing "subscribe" and
    /// deserializes back to the same variant.
    #[test]
    fn test_sync_message_serialization() {
        let workspace_id = Uuid::new_v4();
        let encoded = serde_json::to_string(&SyncMessage::Subscribe { workspace_id }).unwrap();
        assert!(encoded.contains("subscribe"));

        let decoded: SyncMessage = serde_json::from_str(&encoded).unwrap();
        if !matches!(decoded, SyncMessage::Subscribe { .. }) {
            panic!("Wrong message type");
        }
    }
}