mockforge_collab/
websocket.rs1use crate::auth::AuthService;
4use crate::error::{CollabError, Result};
5use crate::events::EventBus;
6use crate::sync::{SyncEngine, SyncMessage};
7use axum::{
8 extract::{
9 ws::{Message, WebSocket, WebSocketUpgrade},
10 State,
11 },
12 response::Response,
13};
14use futures::{sink::SinkExt, stream::StreamExt};
15use std::sync::Arc;
16use tokio::select;
17use uuid::Uuid;
18
19#[derive(Clone)]
21pub struct WsState {
22 pub auth: Arc<AuthService>,
23 pub sync: Arc<SyncEngine>,
24 pub event_bus: Arc<EventBus>,
25}
26
27pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<WsState>) -> Response {
29 ws.on_upgrade(|socket| handle_socket(socket, state))
30}
31
32async fn handle_socket(socket: WebSocket, state: WsState) {
34 let (mut sender, mut receiver) = socket.split();
35
36 let client_id = Uuid::new_v4();
38 tracing::info!("WebSocket client connected: {}", client_id);
39
40 let mut subscriptions: Vec<Uuid> = Vec::new();
42
43 let mut event_rx = state.event_bus.subscribe();
45
46 loop {
47 select! {
48 msg = receiver.next() => {
50 match msg {
51 Some(Ok(Message::Text(text))) => {
52 if let Err(e) = handle_client_message(&text, client_id, &state, &mut subscriptions, &mut sender).await {
53 tracing::error!("Error handling client message: {}", e);
54 let _ = sender.send(Message::Text(
55 serde_json::to_string(&SyncMessage::Error {
56 message: e.to_string(),
57 }).unwrap().into()
58 )).await;
59 }
60 }
61 Some(Ok(Message::Close(_))) => {
62 tracing::info!("Client {} requested close", client_id);
63 break;
64 }
65 Some(Ok(Message::Ping(data))) => {
66 let _ = sender.send(Message::Pong(data)).await;
67 }
68 Some(Err(e)) => {
69 tracing::error!("WebSocket error: {}", e);
70 break;
71 }
72 None => {
73 tracing::info!("Client {} disconnected", client_id);
74 break;
75 }
76 _ => {}
77 }
78 }
79
80 event = event_rx.recv() => {
82 match event {
83 Ok(change_event) => {
84 if subscriptions.contains(&change_event.workspace_id) {
86 let msg = SyncMessage::Change { event: change_event };
87 if let Ok(json) = serde_json::to_string(&msg) {
88 let _ = sender.send(Message::Text(json.into())).await;
89 }
90 }
91 }
92 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
93 tracing::warn!("Client {} lagged {} messages", client_id, n);
94 }
95 Err(_) => {
96 tracing::error!("Event channel closed");
97 break;
98 }
99 }
100 }
101 }
102 }
103
104 for workspace_id in subscriptions {
106 let _ = state.sync.unsubscribe(workspace_id, client_id);
107 }
108
109 tracing::info!("Client {} connection closed", client_id);
110}
111
112async fn handle_client_message(
114 text: &str,
115 client_id: Uuid,
116 state: &WsState,
117 subscriptions: &mut Vec<Uuid>,
118 sender: &mut futures::stream::SplitSink<WebSocket, Message>,
119) -> Result<()> {
120 let message: SyncMessage = serde_json::from_str(text)
121 .map_err(|e| CollabError::InvalidInput(format!("Invalid JSON: {}", e)))?;
122
123 match message {
124 SyncMessage::Subscribe { workspace_id } => {
125 state.sync.subscribe(workspace_id, client_id)?;
129 subscriptions.push(workspace_id);
130
131 tracing::info!("Client {} subscribed to workspace {}", client_id, workspace_id);
132
133 if let Some(sync_state) = state.sync.get_state(workspace_id) {
135 let response = SyncMessage::StateResponse {
136 workspace_id,
137 version: sync_state.version,
138 state: sync_state.state,
139 };
140 let json = serde_json::to_string(&response)?;
141 sender
142 .send(Message::Text(json.into()))
143 .await
144 .map_err(|e| CollabError::Internal(format!("Failed to send: {}", e)))?;
145 }
146 }
147
148 SyncMessage::Unsubscribe { workspace_id } => {
149 state.sync.unsubscribe(workspace_id, client_id)?;
150 subscriptions.retain(|id| *id != workspace_id);
151
152 tracing::info!("Client {} unsubscribed from workspace {}", client_id, workspace_id);
153 }
154
155 SyncMessage::StateRequest {
156 workspace_id,
157 version,
158 } => {
159 if let Some(sync_state) = state.sync.get_state(workspace_id) {
161 if sync_state.version > version {
162 let response = SyncMessage::StateResponse {
163 workspace_id,
164 version: sync_state.version,
165 state: sync_state.state,
166 };
167 let json = serde_json::to_string(&response)?;
168 sender
169 .send(Message::Text(json.into()))
170 .await
171 .map_err(|e| CollabError::Internal(format!("Failed to send: {}", e)))?;
172 }
173 }
174 }
175
176 SyncMessage::Ping => {
177 let pong = SyncMessage::Pong;
178 let json = serde_json::to_string(&pong)?;
179 sender
180 .send(Message::Text(json.into()))
181 .await
182 .map_err(|e| CollabError::Internal(format!("Failed to send: {}", e)))?;
183 }
184
185 _ => {
186 tracing::warn!("Unexpected message type from client {}", client_id);
187 }
188 }
189
190 Ok(())
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Subscribe` message must survive a JSON round trip with both its
    /// variant and its `workspace_id` payload intact.
    #[test]
    fn test_sync_message_serialization() {
        let workspace_id = Uuid::new_v4();
        let msg = SyncMessage::Subscribe { workspace_id };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("subscribe"));

        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            SyncMessage::Subscribe {
                workspace_id: decoded,
            } => assert_eq!(decoded, workspace_id),
            _ => panic!("Wrong message type"),
        }
    }

    /// The payload-less `Ping` message must also round-trip through JSON.
    #[test]
    fn test_ping_round_trip() {
        let json = serde_json::to_string(&SyncMessage::Ping).unwrap();
        let deserialized: SyncMessage = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, SyncMessage::Ping));
    }
}
212}