// mockforge_collab/websocket.rs
use crate::auth::AuthService;
use crate::error::{CollabError, Result};
use crate::events::EventBus;
use crate::sync::{SyncEngine, SyncMessage};
use crate::workspace;
use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        Query, State,
    },
    response::Response,
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::select;
use uuid::Uuid;
20
21#[derive(Clone)]
23pub struct WsState {
24 pub auth: Arc<AuthService>,
25 pub sync: Arc<SyncEngine>,
26 pub event_bus: Arc<EventBus>,
27 pub workspace: Arc<workspace::WorkspaceService>,
28}
29
30pub async fn ws_handler(
32 ws: WebSocketUpgrade,
33 Query(params): Query<HashMap<String, String>>,
34 State(state): State<WsState>,
35) -> Response {
36 let user_id = params
38 .get("token")
39 .and_then(|token| {
40 state
41 .auth
42 .verify_token(token)
43 .ok()
44 .and_then(|claims| Uuid::parse_str(&claims.sub).ok())
45 })
46 .or_else(|| {
47 params
49 .get("user_id")
50 .and_then(|id| Uuid::parse_str(id).ok())
51 });
52
53 ws.on_upgrade(move |socket| handle_socket(socket, state, user_id))
54}
55
56async fn handle_socket(socket: WebSocket, state: WsState, user_id: Option<Uuid>) {
58 let (mut sender, mut receiver) = socket.split();
59
60 let client_id = Uuid::new_v4();
62 tracing::info!(
63 "WebSocket client connected: {} (user: {:?})",
64 client_id,
65 user_id
66 );
67
68 let mut subscriptions: Vec<Uuid> = Vec::new();
70
71 let mut event_rx = state.event_bus.subscribe();
73
74 loop {
75 select! {
76 msg = receiver.next() => {
78 match msg {
79 Some(Ok(Message::Text(text))) => {
80 if let Err(e) = handle_client_message(&text, client_id, user_id, &state, &mut subscriptions, &mut sender).await {
81 tracing::error!("Error handling client message: {}", e);
82 let _ = sender.send(Message::Text(
83 serde_json::to_string(&SyncMessage::Error {
84 message: e.to_string(),
85 }).unwrap().into()
86 )).await;
87 }
88 }
89 Some(Ok(Message::Close(_))) => {
90 tracing::info!("Client {} requested close", client_id);
91 break;
92 }
93 Some(Ok(Message::Ping(data))) => {
94 let _ = sender.send(Message::Pong(data)).await;
95 }
96 Some(Err(e)) => {
97 tracing::error!("WebSocket error: {}", e);
98 break;
99 }
100 None => {
101 tracing::info!("Client {} disconnected", client_id);
102 break;
103 }
104 _ => {}
105 }
106 }
107
108 event = event_rx.recv() => {
110 match event {
111 Ok(change_event) => {
112 if subscriptions.contains(&change_event.workspace_id) {
114 let msg = SyncMessage::Change { event: change_event };
115 if let Ok(json) = serde_json::to_string(&msg) {
116 let _ = sender.send(Message::Text(json.into())).await;
117 }
118 }
119 }
120 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
121 tracing::warn!("Client {} lagged {} messages", client_id, n);
122 }
123 Err(_) => {
124 tracing::error!("Event channel closed");
125 break;
126 }
127 }
128 }
129 }
130 }
131
132 for workspace_id in subscriptions {
134 let _ = state.sync.unsubscribe(workspace_id, client_id);
135 }
136
137 tracing::info!("Client {} connection closed", client_id);
138}
139
140async fn handle_client_message(
142 text: &str,
143 client_id: Uuid,
144 user_id: Option<Uuid>,
145 state: &WsState,
146 subscriptions: &mut Vec<Uuid>,
147 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
148) -> Result<()> {
149 let message: SyncMessage = serde_json::from_str(text)
150 .map_err(|e| CollabError::InvalidInput(format!("Invalid JSON: {e}")))?;
151
152 match message {
153 SyncMessage::Subscribe { workspace_id } => {
154 if let Some(uid) = user_id {
156 if let Err(e) = state.workspace.get_member(workspace_id, uid).await {
158 tracing::warn!(
159 "User {} attempted to access workspace {} without permission: {}",
160 uid,
161 workspace_id,
162 e
163 );
164 return Err(CollabError::AuthorizationFailed(format!(
165 "Access denied to workspace {}",
166 workspace_id
167 )));
168 }
169 } else {
170 return Err(CollabError::AuthenticationFailed(
173 "Authentication required for workspace access".to_string(),
174 ));
175 }
176
177 state.sync.subscribe(workspace_id, client_id)?;
179 subscriptions.push(workspace_id);
180
181 tracing::info!("Client {} subscribed to workspace {}", client_id, workspace_id);
182
183 if let Some(sync_state) = state.sync.get_state(workspace_id) {
185 let response = SyncMessage::StateResponse {
186 workspace_id,
187 version: sync_state.version,
188 state: sync_state.state,
189 };
190 let json = serde_json::to_string(&response)?;
191 sender
192 .send(Message::Text(json.into()))
193 .await
194 .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
195 }
196 }
197
198 SyncMessage::Unsubscribe { workspace_id } => {
199 state.sync.unsubscribe(workspace_id, client_id)?;
200 subscriptions.retain(|id| *id != workspace_id);
201
202 tracing::info!("Client {} unsubscribed from workspace {}", client_id, workspace_id);
203 }
204
205 SyncMessage::StateRequest {
206 workspace_id,
207 version,
208 } => {
209 if let Some(sync_state) = state.sync.get_state(workspace_id) {
211 if sync_state.version > version {
212 let response = SyncMessage::StateResponse {
213 workspace_id,
214 version: sync_state.version,
215 state: sync_state.state,
216 };
217 let json = serde_json::to_string(&response)?;
218 sender
219 .send(Message::Text(json.into()))
220 .await
221 .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
222 }
223 }
224 }
225
226 SyncMessage::Ping => {
227 let pong = SyncMessage::Pong;
228 let json = serde_json::to_string(&pong)?;
229 sender
230 .send(Message::Text(json.into()))
231 .await
232 .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
233 }
234
235 _ => {
236 tracing::warn!("Unexpected message type from client {}", client_id);
237 }
238 }
239
240 Ok(())
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Subscribe` message survives a JSON round trip and keeps its
    /// `subscribe` tag in the encoded form.
    #[test]
    fn test_sync_message_serialization() {
        let original = SyncMessage::Subscribe {
            workspace_id: Uuid::new_v4(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("subscribe"));

        let decoded: SyncMessage = serde_json::from_str(&encoded).unwrap();
        assert!(
            matches!(decoded, SyncMessage::Subscribe { .. }),
            "Wrong message type"
        );
    }
}