mockforge_collab/
websocket.rs

1//! WebSocket handler for real-time collaboration
2
3use crate::auth::AuthService;
4use crate::error::{CollabError, Result};
5use crate::events::EventBus;
6use crate::sync::{SyncEngine, SyncMessage};
7use axum::{
8    extract::{
9        ws::{Message, WebSocket, WebSocketUpgrade},
10        State,
11    },
12    response::Response,
13};
14use futures::{sink::SinkExt, stream::StreamExt};
15use std::sync::Arc;
16use tokio::select;
17use uuid::Uuid;
18
19/// WebSocket state
20#[derive(Clone)]
21pub struct WsState {
22    pub auth: Arc<AuthService>,
23    pub sync: Arc<SyncEngine>,
24    pub event_bus: Arc<EventBus>,
25}
26
27/// Handle WebSocket upgrade
28pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<WsState>) -> Response {
29    ws.on_upgrade(|socket| handle_socket(socket, state))
30}
31
32/// Handle WebSocket connection
33async fn handle_socket(socket: WebSocket, state: WsState) {
34    let (mut sender, mut receiver) = socket.split();
35
36    // Generate client ID
37    let client_id = Uuid::new_v4();
38    tracing::info!("WebSocket client connected: {}", client_id);
39
40    // Track subscribed workspaces
41    let mut subscriptions: Vec<Uuid> = Vec::new();
42
43    // Subscribe to event bus
44    let mut event_rx = state.event_bus.subscribe();
45
46    loop {
47        select! {
48            // Handle incoming messages from client
49            msg = receiver.next() => {
50                match msg {
51                    Some(Ok(Message::Text(text))) => {
52                        if let Err(e) = handle_client_message(&text, client_id, &state, &mut subscriptions, &mut sender).await {
53                            tracing::error!("Error handling client message: {}", e);
54                            let _ = sender.send(Message::Text(
55                                serde_json::to_string(&SyncMessage::Error {
56                                    message: e.to_string(),
57                                }).unwrap().into()
58                            )).await;
59                        }
60                    }
61                    Some(Ok(Message::Close(_))) => {
62                        tracing::info!("Client {} requested close", client_id);
63                        break;
64                    }
65                    Some(Ok(Message::Ping(data))) => {
66                        let _ = sender.send(Message::Pong(data)).await;
67                    }
68                    Some(Err(e)) => {
69                        tracing::error!("WebSocket error: {}", e);
70                        break;
71                    }
72                    None => {
73                        tracing::info!("Client {} disconnected", client_id);
74                        break;
75                    }
76                    _ => {}
77                }
78            }
79
80            // Handle broadcast events
81            event = event_rx.recv() => {
82                match event {
83                    Ok(change_event) => {
84                        // Only send events for subscribed workspaces
85                        if subscriptions.contains(&change_event.workspace_id) {
86                            let msg = SyncMessage::Change { event: change_event };
87                            if let Ok(json) = serde_json::to_string(&msg) {
88                                let _ = sender.send(Message::Text(json.into())).await;
89                            }
90                        }
91                    }
92                    Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
93                        tracing::warn!("Client {} lagged {} messages", client_id, n);
94                    }
95                    Err(_) => {
96                        tracing::error!("Event channel closed");
97                        break;
98                    }
99                }
100            }
101        }
102    }
103
104    // Cleanup: unsubscribe from all workspaces
105    for workspace_id in subscriptions {
106        let _ = state.sync.unsubscribe(workspace_id, client_id);
107    }
108
109    tracing::info!("Client {} connection closed", client_id);
110}
111
112/// Handle a message from the client
113async fn handle_client_message(
114    text: &str,
115    client_id: Uuid,
116    state: &WsState,
117    subscriptions: &mut Vec<Uuid>,
118    sender: &mut futures::stream::SplitSink<WebSocket, Message>,
119) -> Result<()> {
120    let message: SyncMessage = serde_json::from_str(text)
121        .map_err(|e| CollabError::InvalidInput(format!("Invalid JSON: {}", e)))?;
122
123    match message {
124        SyncMessage::Subscribe { workspace_id } => {
125            // TODO: Verify user has access to workspace
126
127            // Subscribe to workspace
128            state.sync.subscribe(workspace_id, client_id)?;
129            subscriptions.push(workspace_id);
130
131            tracing::info!("Client {} subscribed to workspace {}", client_id, workspace_id);
132
133            // Send current state
134            if let Some(sync_state) = state.sync.get_state(workspace_id) {
135                let response = SyncMessage::StateResponse {
136                    workspace_id,
137                    version: sync_state.version,
138                    state: sync_state.state,
139                };
140                let json = serde_json::to_string(&response)?;
141                sender
142                    .send(Message::Text(json.into()))
143                    .await
144                    .map_err(|e| CollabError::Internal(format!("Failed to send: {}", e)))?;
145            }
146        }
147
148        SyncMessage::Unsubscribe { workspace_id } => {
149            state.sync.unsubscribe(workspace_id, client_id)?;
150            subscriptions.retain(|id| *id != workspace_id);
151
152            tracing::info!("Client {} unsubscribed from workspace {}", client_id, workspace_id);
153        }
154
155        SyncMessage::StateRequest {
156            workspace_id,
157            version,
158        } => {
159            // Check if client needs update
160            if let Some(sync_state) = state.sync.get_state(workspace_id) {
161                if sync_state.version > version {
162                    let response = SyncMessage::StateResponse {
163                        workspace_id,
164                        version: sync_state.version,
165                        state: sync_state.state,
166                    };
167                    let json = serde_json::to_string(&response)?;
168                    sender
169                        .send(Message::Text(json.into()))
170                        .await
171                        .map_err(|e| CollabError::Internal(format!("Failed to send: {}", e)))?;
172                }
173            }
174        }
175
176        SyncMessage::Ping => {
177            let pong = SyncMessage::Pong;
178            let json = serde_json::to_string(&pong)?;
179            sender
180                .send(Message::Text(json.into()))
181                .await
182                .map_err(|e| CollabError::Internal(format!("Failed to send: {}", e)))?;
183        }
184
185        _ => {
186            tracing::warn!("Unexpected message type from client {}", client_id);
187        }
188    }
189
190    Ok(())
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Subscribe` message must survive a JSON round trip with its variant
    /// and its `workspace_id` payload intact.
    #[test]
    fn test_sync_message_serialization() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Subscribe { workspace_id };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("subscribe"));

        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            SyncMessage::Subscribe {
                workspace_id: decoded,
            } => assert_eq!(decoded, workspace_id),
            _ => panic!("Wrong message type"),
        }
    }
}