mockforge_collab/
websocket.rs

1//! WebSocket handler for real-time collaboration
2
3use crate::auth::AuthService;
4use crate::error::{CollabError, Result};
5use crate::events::EventBus;
6use crate::sync::{SyncEngine, SyncMessage};
7use crate::workspace;
8use axum::{
9    extract::{
10        ws::{Message, WebSocket, WebSocketUpgrade},
11        Query, State,
12    },
13    response::Response,
14};
15use futures::{sink::SinkExt, stream::StreamExt};
16use std::collections::HashMap;
17use std::sync::Arc;
18use tokio::select;
19use uuid::Uuid;
20
21/// WebSocket state
22#[derive(Clone)]
23pub struct WsState {
24    pub auth: Arc<AuthService>,
25    pub sync: Arc<SyncEngine>,
26    pub event_bus: Arc<EventBus>,
27    pub workspace: Arc<workspace::WorkspaceService>,
28}
29
30/// Handle WebSocket upgrade
31pub async fn ws_handler(
32    ws: WebSocketUpgrade,
33    Query(params): Query<HashMap<String, String>>,
34    State(state): State<WsState>,
35) -> Response {
36    // Extract user info from query params (token) or headers
37    let user_id = params
38        .get("token")
39        .and_then(|token| {
40            state
41                .auth
42                .verify_token(token)
43                .ok()
44                .and_then(|claims| Uuid::parse_str(&claims.sub).ok())
45        })
46        .or_else(|| {
47            // Fallback: try to get from user_id param (for development)
48            params
49                .get("user_id")
50                .and_then(|id| Uuid::parse_str(id).ok())
51        });
52
53    ws.on_upgrade(move |socket| handle_socket(socket, state, user_id))
54}
55
56/// Handle WebSocket connection
57async fn handle_socket(socket: WebSocket, state: WsState, user_id: Option<Uuid>) {
58    let (mut sender, mut receiver) = socket.split();
59
60    // Generate client ID
61    let client_id = Uuid::new_v4();
62    tracing::info!(
63        "WebSocket client connected: {} (user: {:?})",
64        client_id,
65        user_id
66    );
67
68    // Track subscribed workspaces
69    let mut subscriptions: Vec<Uuid> = Vec::new();
70
71    // Subscribe to event bus
72    let mut event_rx = state.event_bus.subscribe();
73
74    loop {
75        select! {
76            // Handle incoming messages from client
77            msg = receiver.next() => {
78                match msg {
79                    Some(Ok(Message::Text(text))) => {
80                        if let Err(e) = handle_client_message(&text, client_id, user_id, &state, &mut subscriptions, &mut sender).await {
81                            tracing::error!("Error handling client message: {}", e);
82                            let _ = sender.send(Message::Text(
83                                serde_json::to_string(&SyncMessage::Error {
84                                    message: e.to_string(),
85                                }).unwrap().into()
86                            )).await;
87                        }
88                    }
89                    Some(Ok(Message::Close(_))) => {
90                        tracing::info!("Client {} requested close", client_id);
91                        break;
92                    }
93                    Some(Ok(Message::Ping(data))) => {
94                        let _ = sender.send(Message::Pong(data)).await;
95                    }
96                    Some(Err(e)) => {
97                        tracing::error!("WebSocket error: {}", e);
98                        break;
99                    }
100                    None => {
101                        tracing::info!("Client {} disconnected", client_id);
102                        break;
103                    }
104                    _ => {}
105                }
106            }
107
108            // Handle broadcast events
109            event = event_rx.recv() => {
110                match event {
111                    Ok(change_event) => {
112                        // Only send events for subscribed workspaces
113                        if subscriptions.contains(&change_event.workspace_id) {
114                            let msg = SyncMessage::Change { event: change_event };
115                            if let Ok(json) = serde_json::to_string(&msg) {
116                                let _ = sender.send(Message::Text(json.into())).await;
117                            }
118                        }
119                    }
120                    Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
121                        tracing::warn!("Client {} lagged {} messages", client_id, n);
122                    }
123                    Err(_) => {
124                        tracing::error!("Event channel closed");
125                        break;
126                    }
127                }
128            }
129        }
130    }
131
132    // Cleanup: unsubscribe from all workspaces
133    for workspace_id in subscriptions {
134        let _ = state.sync.unsubscribe(workspace_id, client_id);
135    }
136
137    tracing::info!("Client {} connection closed", client_id);
138}
139
140/// Handle a message from the client
141async fn handle_client_message(
142    text: &str,
143    client_id: Uuid,
144    user_id: Option<Uuid>,
145    state: &WsState,
146    subscriptions: &mut Vec<Uuid>,
147    sender: &mut futures::stream::SplitSink<WebSocket, Message>,
148) -> Result<()> {
149    let message: SyncMessage = serde_json::from_str(text)
150        .map_err(|e| CollabError::InvalidInput(format!("Invalid JSON: {e}")))?;
151
152    match message {
153        SyncMessage::Subscribe { workspace_id } => {
154            // Verify user has access to workspace
155            if let Some(uid) = user_id {
156                // Check if user is a member of the workspace
157                if let Err(e) = state.workspace.get_member(workspace_id, uid).await {
158                    tracing::warn!(
159                        "User {} attempted to access workspace {} without permission: {}",
160                        uid,
161                        workspace_id,
162                        e
163                    );
164                    return Err(CollabError::AuthorizationFailed(format!(
165                        "Access denied to workspace {}",
166                        workspace_id
167                    )));
168                }
169            } else {
170                // No user ID provided - deny access in production
171                // In development, this might be allowed, but for security, we require authentication
172                return Err(CollabError::AuthenticationFailed(
173                    "Authentication required for workspace access".to_string(),
174                ));
175            }
176
177            // Subscribe to workspace
178            state.sync.subscribe(workspace_id, client_id)?;
179            subscriptions.push(workspace_id);
180
181            tracing::info!("Client {} subscribed to workspace {}", client_id, workspace_id);
182
183            // Send current state
184            if let Some(sync_state) = state.sync.get_state(workspace_id) {
185                let response = SyncMessage::StateResponse {
186                    workspace_id,
187                    version: sync_state.version,
188                    state: sync_state.state,
189                };
190                let json = serde_json::to_string(&response)?;
191                sender
192                    .send(Message::Text(json.into()))
193                    .await
194                    .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
195            }
196        }
197
198        SyncMessage::Unsubscribe { workspace_id } => {
199            state.sync.unsubscribe(workspace_id, client_id)?;
200            subscriptions.retain(|id| *id != workspace_id);
201
202            tracing::info!("Client {} unsubscribed from workspace {}", client_id, workspace_id);
203        }
204
205        SyncMessage::StateRequest {
206            workspace_id,
207            version,
208        } => {
209            // Check if client needs update
210            if let Some(sync_state) = state.sync.get_state(workspace_id) {
211                if sync_state.version > version {
212                    let response = SyncMessage::StateResponse {
213                        workspace_id,
214                        version: sync_state.version,
215                        state: sync_state.state,
216                    };
217                    let json = serde_json::to_string(&response)?;
218                    sender
219                        .send(Message::Text(json.into()))
220                        .await
221                        .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
222                }
223            }
224        }
225
226        SyncMessage::Ping => {
227            let pong = SyncMessage::Pong;
228            let json = serde_json::to_string(&pong)?;
229            sender
230                .send(Message::Text(json.into()))
231                .await
232                .map_err(|e| CollabError::Internal(format!("Failed to send: {e}")))?;
233        }
234
235        _ => {
236            tracing::warn!("Unexpected message type from client {}", client_id);
237        }
238    }
239
240    Ok(())
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a `Subscribe` message through JSON.
    ///
    /// Checks that the serialized form mentions `subscribe`, that it decodes
    /// back to the same variant, and that the `workspace_id` payload is preserved.
    #[test]
    fn test_sync_message_serialization() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Subscribe { workspace_id };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("subscribe"));

        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            SyncMessage::Subscribe {
                workspace_id: round_tripped,
            } => assert_eq!(round_tripped, workspace_id),
            _ => panic!("Wrong message type"),
        }
    }
}