mockforge_collab/
websocket.rs

1//! WebSocket handler for real-time collaboration
2
3use crate::auth::AuthService;
4use crate::error::{CollabError, Result};
5use crate::events::EventBus;
6use crate::sync::{SyncEngine, SyncMessage};
7use crate::workspace;
8use axum::{
9    extract::{
10        ws::{Message, WebSocket, WebSocketUpgrade},
11        Query, State,
12    },
13    response::Response,
14};
15use futures::{sink::SinkExt, stream::StreamExt};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::select;
19use uuid::Uuid;
20
21/// WebSocket state
22#[derive(Clone)]
23pub struct WsState {
24    pub auth: Arc<AuthService>,
25    pub sync: Arc<SyncEngine>,
26    pub event_bus: Arc<EventBus>,
27    pub workspace: Arc<workspace::WorkspaceService>,
28}
29
30/// Handle WebSocket upgrade
31pub async fn ws_handler(
32    ws: WebSocketUpgrade,
33    Query(params): Query<HashMap<String, String>>,
34    State(state): State<WsState>,
35) -> Response {
36    // Extract user info from query params (token) or headers
37    let user_id = params
38        .get("token")
39        .and_then(|token| {
40            state
41                .auth
42                .verify_token(token)
43                .ok()
44                .and_then(|claims| Uuid::parse_str(&claims.sub).ok())
45        })
46        .or_else(|| {
47            // Fallback: try to get from user_id param (for development)
48            params.get("user_id").and_then(|id| Uuid::parse_str(id).ok())
49        });
50
51    ws.on_upgrade(move |socket| handle_socket(socket, state, user_id))
52}
53
54/// Handle WebSocket connection
55async fn handle_socket(socket: WebSocket, state: WsState, user_id: Option<Uuid>) {
56    let (mut sender, mut receiver) = socket.split();
57
58    // Generate client ID
59    let client_id = Uuid::new_v4();
60    tracing::info!("WebSocket client connected: {} (user: {:?})", client_id, user_id);
61
62    // Track subscribed workspaces
63    let mut subscriptions: Vec<Uuid> = Vec::new();
64
65    // Subscribe to event bus
66    let mut event_rx = state.event_bus.subscribe();
67
68    loop {
69        select! {
70            // Handle incoming messages from client
71            msg = receiver.next() => {
72                match msg {
73                    Some(Ok(Message::Text(text))) => {
74                        if let Err(e) = handle_client_message(&text, client_id, user_id, &state, &mut subscriptions, &mut sender).await {
75                            tracing::error!("Error handling client message: {}", e);
76                            let _ = sender.send(Message::Text(
77                                serde_json::to_string(&SyncMessage::Error {
78                                    message: e.to_string(),
79                                }).unwrap().into()
80                            )).await;
81                        }
82                    }
83                    Some(Ok(Message::Close(_))) => {
84                        tracing::info!("Client {} requested close", client_id);
85                        break;
86                    }
87                    Some(Ok(Message::Ping(data))) => {
88                        let _ = sender.send(Message::Pong(data)).await;
89                    }
90                    Some(Err(e)) => {
91                        tracing::error!("WebSocket error: {}", e);
92                        break;
93                    }
94                    None => {
95                        tracing::info!("Client {} disconnected", client_id);
96                        break;
97                    }
98                    _ => {}
99                }
100            }
101
102            // Handle broadcast events
103            event = event_rx.recv() => {
104                match event {
105                    Ok(change_event) => {
106                        // Only send events for subscribed workspaces
107                        if subscriptions.contains(&change_event.workspace_id) {
108                            let msg = SyncMessage::Change { event: change_event };
109                            if let Ok(json) = serde_json::to_string(&msg) {
110                                let _ = sender.send(Message::Text(json.into())).await;
111                            }
112                        }
113                    }
114                    Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
115                        tracing::warn!("Client {} lagged {} messages", client_id, n);
116                    }
117                    Err(_) => {
118                        tracing::error!("Event channel closed");
119                        break;
120                    }
121                }
122            }
123        }
124    }
125
126    // Cleanup: unsubscribe from all workspaces
127    for workspace_id in subscriptions {
128        let _ = state.sync.unsubscribe(workspace_id, client_id);
129    }
130
131    tracing::info!("Client {} connection closed", client_id);
132}
133
134/// Handle a message from the client
135async fn handle_client_message(
136    text: &str,
137    client_id: Uuid,
138    user_id: Option<Uuid>,
139    state: &WsState,
140    subscriptions: &mut Vec<Uuid>,
141    sender: &mut futures::stream::SplitSink<WebSocket, Message>,
142) -> Result<()> {
143    let message: SyncMessage = serde_json::from_str(text)
144        .map_err(|e| CollabError::InvalidInput(format!("Invalid JSON: {e}")))?;
145
146    match message {
147        SyncMessage::Subscribe { workspace_id } => {
148            // Verify user has access to workspace
149            if let Some(uid) = user_id {
150                // Check if user is a member of the workspace
151                if let Err(e) = state.workspace.get_member(workspace_id, uid).await {
152                    tracing::warn!(
153                        "User {} attempted to access workspace {} without permission: {}",
154                        uid,
155                        workspace_id,
156                        e
157                    );
158                    return Err(CollabError::AuthorizationFailed(format!(
159                        "Access denied to workspace {}",
160                        workspace_id
161                    )));
162                }
163            } else {
164                // No user ID provided - deny access in production
165                // In development, this might be allowed, but for security, we require authentication
166                return Err(CollabError::AuthenticationFailed(
167                    "Authentication required for workspace access".to_string(),
168                ));
169            }
170
171            // Subscribe to workspace
172            state.sync.subscribe(workspace_id, client_id)?;
173            subscriptions.push(workspace_id);
174
175            tracing::info!("Client {} subscribed to workspace {}", client_id, workspace_id);
176
177            // Send current state
178            if let Some(sync_state) = state.sync.get_state(workspace_id) {
179                let response = SyncMessage::StateResponse {
180                    workspace_id,
181                    version: sync_state.version,
182                    state: sync_state.state,
183                };
184                let json = serde_json::to_string(&response)?;
185                sender
186                    .send(Message::Text(json.into()))
187                    .await
188                    .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
189            }
190        }
191
192        SyncMessage::Unsubscribe { workspace_id } => {
193            state.sync.unsubscribe(workspace_id, client_id)?;
194            subscriptions.retain(|id| *id != workspace_id);
195
196            tracing::info!("Client {} unsubscribed from workspace {}", client_id, workspace_id);
197        }
198
199        SyncMessage::StateRequest {
200            workspace_id,
201            version,
202        } => {
203            // Check if client needs update
204            if let Some(sync_state) = state.sync.get_state(workspace_id) {
205                if sync_state.version > version {
206                    let response = SyncMessage::StateResponse {
207                        workspace_id,
208                        version: sync_state.version,
209                        state: sync_state.state,
210                    };
211                    let json = serde_json::to_string(&response)?;
212                    sender
213                        .send(Message::Text(json.into()))
214                        .await
215                        .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
216                }
217            }
218        }
219
220        SyncMessage::Ping => {
221            let pong = SyncMessage::Pong;
222            let json = serde_json::to_string(&pong)?;
223            sender
224                .send(Message::Text(json.into()))
225                .await
226                .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
227        }
228
229        _ => {
230            tracing::warn!("Unexpected message type from client {}", client_id);
231        }
232    }
233
234    Ok(())
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Subscribe` message serializes to JSON containing "subscribe" and
    /// deserializes back into the `Subscribe` variant.
    #[test]
    fn test_sync_message_serialization() {
        let workspace_id = Uuid::new_v4();
        let original = SyncMessage::Subscribe { workspace_id };

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("subscribe"));

        let decoded: SyncMessage = serde_json::from_str(&encoded).unwrap();
        assert!(
            matches!(decoded, SyncMessage::Subscribe { .. }),
            "Wrong message type"
        );
    }
}
256}