Skip to main content

mockforge_collab/
websocket.rs

1//! WebSocket handler for real-time collaboration
2
3use crate::auth::AuthService;
4use crate::error::{CollabError, Result};
5use crate::events::EventBus;
6use crate::sync::{SyncEngine, SyncMessage};
7use crate::workspace;
8use axum::{
9    extract::{
10        ws::{Message, WebSocket, WebSocketUpgrade},
11        Query, State,
12    },
13    response::Response,
14};
15use futures::{sink::SinkExt, stream::StreamExt};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::select;
19use uuid::Uuid;
20
21/// WebSocket state
22#[derive(Clone)]
23pub struct WsState {
24    /// Authentication service
25    pub auth: Arc<AuthService>,
26    /// Sync engine
27    pub sync: Arc<SyncEngine>,
28    /// Event bus
29    pub event_bus: Arc<EventBus>,
30    /// Workspace service
31    pub workspace: Arc<workspace::WorkspaceService>,
32}
33
34/// Handle WebSocket upgrade
35#[allow(clippy::implicit_hasher)]
36pub async fn ws_handler(
37    ws: WebSocketUpgrade,
38    Query(params): Query<HashMap<String, String>>,
39    State(state): State<WsState>,
40) -> Response {
41    // Extract user info from query params (token) or headers
42    let user_id = params
43        .get("token")
44        .and_then(|token| {
45            state
46                .auth
47                .verify_token(token)
48                .ok()
49                .and_then(|claims| Uuid::parse_str(&claims.sub).ok())
50        })
51        .or_else(|| {
52            // Fallback: try to get from user_id param (for development)
53            params.get("user_id").and_then(|id| Uuid::parse_str(id).ok())
54        });
55
56    ws.on_upgrade(move |socket| handle_socket(socket, state, user_id))
57}
58
59/// Handle WebSocket connection
60async fn handle_socket(socket: WebSocket, state: WsState, user_id: Option<Uuid>) {
61    let (mut sender, mut receiver) = socket.split();
62
63    // Generate client ID
64    let client_id = Uuid::new_v4();
65    tracing::info!("WebSocket client connected: {} (user: {:?})", client_id, user_id);
66
67    // Track subscribed workspaces
68    let mut subscriptions: Vec<Uuid> = Vec::new();
69
70    // Subscribe to event bus
71    let mut event_rx = state.event_bus.subscribe();
72
73    loop {
74        select! {
75            // Handle incoming messages from client
76            msg = receiver.next() => {
77                match msg {
78                    Some(Ok(Message::Text(text))) => {
79                        if let Err(e) = handle_client_message(&text, client_id, user_id, &state, &mut subscriptions, &mut sender).await {
80                            tracing::error!("Error handling client message: {}", e);
81                            let _ = sender.send(Message::Text(
82                                serde_json::to_string(&SyncMessage::Error {
83                                    message: e.to_string(),
84                                }).unwrap().into()
85                            )).await;
86                        }
87                    }
88                    Some(Ok(Message::Close(_))) => {
89                        tracing::info!("Client {} requested close", client_id);
90                        break;
91                    }
92                    Some(Ok(Message::Ping(data))) => {
93                        let _ = sender.send(Message::Pong(data)).await;
94                    }
95                    Some(Err(e)) => {
96                        tracing::error!("WebSocket error: {}", e);
97                        break;
98                    }
99                    None => {
100                        tracing::info!("Client {} disconnected", client_id);
101                        break;
102                    }
103                    _ => {}
104                }
105            }
106
107            // Handle broadcast events
108            event = event_rx.recv() => {
109                match event {
110                    Ok(change_event) => {
111                        // Only send events for subscribed workspaces
112                        if subscriptions.contains(&change_event.workspace_id) {
113                            let msg = SyncMessage::Change { event: change_event };
114                            if let Ok(json) = serde_json::to_string(&msg) {
115                                let _ = sender.send(Message::Text(json.into())).await;
116                            }
117                        }
118                    }
119                    Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
120                        tracing::warn!("Client {} lagged {} messages", client_id, n);
121                    }
122                    Err(_) => {
123                        tracing::error!("Event channel closed");
124                        break;
125                    }
126                }
127            }
128        }
129    }
130
131    // Cleanup: unsubscribe from all workspaces
132    for workspace_id in subscriptions {
133        let _ = state.sync.unsubscribe(workspace_id, client_id);
134    }
135
136    tracing::info!("Client {} connection closed", client_id);
137}
138
139/// Handle a message from the client
140async fn handle_client_message(
141    text: &str,
142    client_id: Uuid,
143    user_id: Option<Uuid>,
144    state: &WsState,
145    subscriptions: &mut Vec<Uuid>,
146    sender: &mut futures::stream::SplitSink<WebSocket, Message>,
147) -> Result<()> {
148    let message: SyncMessage = serde_json::from_str(text)
149        .map_err(|e| CollabError::InvalidInput(format!("Invalid JSON: {e}")))?;
150
151    match message {
152        SyncMessage::Subscribe { workspace_id } => {
153            // Verify user has access to workspace
154            if let Some(uid) = user_id {
155                // Check if user is a member of the workspace
156                if let Err(e) = state.workspace.get_member(workspace_id, uid).await {
157                    tracing::warn!(
158                        "User {uid} attempted to access workspace {workspace_id} without permission: {e}"
159                    );
160                    return Err(CollabError::AuthorizationFailed(format!(
161                        "Access denied to workspace {workspace_id}"
162                    )));
163                }
164            } else {
165                // No user ID provided - deny access in production
166                // In development, this might be allowed, but for security, we require authentication
167                return Err(CollabError::AuthenticationFailed(
168                    "Authentication required for workspace access".to_string(),
169                ));
170            }
171
172            // Subscribe to workspace
173            state.sync.subscribe(workspace_id, client_id)?;
174            subscriptions.push(workspace_id);
175
176            tracing::info!("Client {} subscribed to workspace {}", client_id, workspace_id);
177
178            // Send current state
179            if let Some(sync_state) = state.sync.get_state(workspace_id) {
180                let response = SyncMessage::StateResponse {
181                    workspace_id,
182                    version: sync_state.version,
183                    state: sync_state.state,
184                };
185                let json = serde_json::to_string(&response)?;
186                sender
187                    .send(Message::Text(json.into()))
188                    .await
189                    .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
190            }
191        }
192
193        SyncMessage::Unsubscribe { workspace_id } => {
194            state.sync.unsubscribe(workspace_id, client_id)?;
195            subscriptions.retain(|id| *id != workspace_id);
196
197            tracing::info!("Client {} unsubscribed from workspace {}", client_id, workspace_id);
198        }
199
200        SyncMessage::StateRequest {
201            workspace_id,
202            version,
203        } => {
204            // Check if client needs update
205            if let Some(sync_state) = state.sync.get_state(workspace_id) {
206                if sync_state.version > version {
207                    let response = SyncMessage::StateResponse {
208                        workspace_id,
209                        version: sync_state.version,
210                        state: sync_state.state,
211                    };
212                    let json = serde_json::to_string(&response)?;
213                    sender
214                        .send(Message::Text(json.into()))
215                        .await
216                        .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
217                }
218            }
219        }
220
221        SyncMessage::Ping => {
222            let pong = SyncMessage::Pong;
223            let json = serde_json::to_string(&pong)?;
224            sender
225                .send(Message::Text(json.into()))
226                .await
227                .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
228        }
229
230        _ => {
231            tracing::warn!("Unexpected message type from client {}", client_id);
232        }
233    }
234
235    Ok(())
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Subscribe` message must survive a JSON round trip with both its variant and its
    /// `workspace_id` intact. Its serialized form must also contain the string `subscribe`.
    #[test]
    fn test_sync_message_serialization() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Subscribe { workspace_id };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("subscribe"));

        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            SyncMessage::Subscribe {
                workspace_id: round_tripped,
            } => assert_eq!(round_tripped, workspace_id),
            _ => panic!("Wrong message type"),
        }
    }
}