1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::{
5 extract::{
6 State, WebSocketUpgrade,
7 ws::{Message, WebSocket},
8 },
9 response::Response,
10};
11use futures_util::{SinkExt, StreamExt};
12use serde::{Deserialize, Serialize};
13use sqlx::PgPool;
14use tokio::sync::{RwLock, mpsc};
15use uuid::Uuid;
16
17use forge_core::cluster::NodeId;
18use forge_core::realtime::SessionId;
19
20use crate::realtime::{Reactor, WebSocketMessage as ReactorMessage};
21
22fn parse_uuid(s: &str, field_name: &str) -> Result<Uuid, String> {
25 if s.len() > 36 {
27 return Err(format!("Invalid {}: too long", field_name));
28 }
29 Uuid::parse_str(s).map_err(|_| format!("Invalid {}: must be a valid UUID", field_name))
30}
31
32const MAX_CLIENT_SUB_ID_LEN: usize = 255;
34
35#[derive(Clone)]
37pub struct WsState {
38 pub reactor: Arc<Reactor>,
39 pub db_pool: PgPool,
40 pub node_id: NodeId,
41}
42
43impl WsState {
44 pub fn new(reactor: Arc<Reactor>, db_pool: PgPool, node_id: NodeId) -> Self {
45 Self {
46 reactor,
47 db_pool,
48 node_id,
49 }
50 }
51}
52
53#[derive(Debug, Deserialize)]
55#[serde(tag = "type", rename_all = "snake_case")]
56pub enum ClientMessage {
57 Subscribe {
59 id: String,
60 #[serde(rename = "function")]
61 function_name: String,
62 args: Option<serde_json::Value>,
63 },
64 Unsubscribe { id: String },
66 SubscribeJob {
68 id: String,
70 job_id: String,
72 },
73 UnsubscribeJob { id: String },
75 SubscribeWorkflow {
77 id: String,
79 workflow_id: String,
81 },
82 UnsubscribeWorkflow { id: String },
84 Ping,
86 Auth {
88 #[allow(dead_code)]
89 token: String,
90 },
91}
92
93#[derive(Debug, Serialize)]
95#[serde(tag = "type", rename_all = "snake_case")]
96pub enum ServerMessage {
97 Connected,
99 Pong,
101 Data { id: String, data: serde_json::Value },
103 JobUpdate { id: String, job: JobData },
105 WorkflowUpdate { id: String, workflow: WorkflowData },
107 Error {
109 id: Option<String>,
110 code: String,
111 message: String,
112 },
113 #[allow(dead_code)]
115 Subscribed { id: String },
116 #[allow(dead_code)]
118 Unsubscribed { id: String },
119}
120
121#[derive(Debug, Clone, Serialize)]
123pub struct JobData {
124 pub job_id: String,
125 pub status: String,
126 pub progress_percent: Option<i32>,
127 pub progress_message: Option<String>,
128 pub output: Option<serde_json::Value>,
129 pub error: Option<String>,
130}
131
132#[derive(Debug, Clone, Serialize)]
134pub struct WorkflowData {
135 pub workflow_id: String,
136 pub status: String,
137 pub current_step: Option<String>,
138 pub steps: Vec<WorkflowStepData>,
139 pub output: Option<serde_json::Value>,
140 pub error: Option<String>,
141}
142
143#[derive(Debug, Clone, Serialize)]
145pub struct WorkflowStepData {
146 pub name: String,
147 pub status: String,
148 pub error: Option<String>,
149}
150
151pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<WsState>>) -> Response {
153 ws.on_upgrade(move |socket| handle_socket(socket, state))
154}
155
156async fn handle_socket(socket: WebSocket, state: Arc<WsState>) {
158 let (mut ws_sender, mut ws_receiver) = socket.split();
159
160 let session_id = SessionId::new();
162 let session_uuid = session_id.0;
163 let node_uuid = state.node_id.0;
164
165 let _ = sqlx::query(
167 r#"
168 INSERT INTO forge_sessions (id, node_id, status, connected_at, last_activity)
169 VALUES ($1, $2, 'connected', NOW(), NOW())
170 ON CONFLICT (id) DO UPDATE SET status = 'connected', last_activity = NOW()
171 "#,
172 )
173 .bind(session_uuid)
174 .bind(node_uuid)
175 .execute(&state.db_pool)
176 .await;
177
178 let (reactor_tx, mut reactor_rx) = mpsc::channel::<ReactorMessage>(256);
180
181 state.reactor.register_session(session_id, reactor_tx).await;
183
184 #[allow(clippy::type_complexity)]
186 let client_to_internal: Arc<RwLock<HashMap<String, forge_core::realtime::SubscriptionId>>> =
187 Arc::new(RwLock::new(HashMap::new()));
188 let internal_to_client: Arc<RwLock<HashMap<forge_core::realtime::SubscriptionId, String>>> =
189 Arc::new(RwLock::new(HashMap::new()));
190
191 let connected = ServerMessage::Connected;
193 if let Ok(json) = serde_json::to_string(&connected) {
194 let _ = ws_sender.send(Message::Text(json.into())).await;
195 }
196
197 tracing::debug!(?session_id, "WebSocket connection established");
198
199 let internal_to_client_clone = internal_to_client.clone();
201
202 let sender_handle = tokio::spawn(async move {
204 while let Some(msg) = reactor_rx.recv().await {
205 let server_msg = match msg {
206 ReactorMessage::Data {
207 subscription_id,
208 data,
209 } => {
210 let client_id = {
212 let map = internal_to_client_clone.read().await;
213 map.get(&subscription_id).cloned()
214 };
215
216 if let Some(id) = client_id {
217 ServerMessage::Data { id, data }
218 } else {
219 continue;
220 }
221 }
222 ReactorMessage::DeltaUpdate {
223 subscription_id,
224 delta,
225 } => {
226 let client_id = {
228 let map = internal_to_client_clone.read().await;
229 map.get(&subscription_id).cloned()
230 };
231
232 if let Some(id) = client_id {
233 ServerMessage::Data {
235 id,
236 data: serde_json::json!({
237 "delta": {
238 "added": delta.added,
239 "removed": delta.removed,
240 "updated": delta.updated
241 }
242 }),
243 }
244 } else {
245 continue;
246 }
247 }
248 ReactorMessage::JobUpdate { client_sub_id, job } => ServerMessage::JobUpdate {
249 id: client_sub_id,
250 job,
251 },
252 ReactorMessage::WorkflowUpdate {
253 client_sub_id,
254 workflow,
255 } => ServerMessage::WorkflowUpdate {
256 id: client_sub_id,
257 workflow,
258 },
259 ReactorMessage::Error { code, message } => ServerMessage::Error {
260 id: None,
261 code,
262 message,
263 },
264 ReactorMessage::ErrorWithId { id, code, message } => ServerMessage::Error {
265 id: Some(id),
266 code,
267 message,
268 },
269 ReactorMessage::Ping => ServerMessage::Pong,
270 ReactorMessage::Pong => continue,
271 _ => continue,
272 };
273
274 if let Ok(json) = serde_json::to_string(&server_msg) {
275 if ws_sender.send(Message::Text(json.into())).await.is_err() {
276 break;
277 }
278 }
279 }
280 });
281
282 while let Some(msg) = ws_receiver.next().await {
284 let msg = match msg {
285 Ok(Message::Text(text)) => text,
286 Ok(Message::Close(_)) => break,
287 Ok(Message::Ping(data)) => {
288 let _ = data;
290 continue;
291 }
292 _ => continue,
293 };
294
295 let client_msg: ClientMessage = match serde_json::from_str(&msg) {
297 Ok(m) => m,
298 Err(e) => {
299 tracing::warn!("Failed to parse client message: {}", e);
300 continue;
301 }
302 };
303
304 match client_msg {
305 ClientMessage::Ping => {
306 }
308 ClientMessage::Auth { token: _ } => {
309 }
311 ClientMessage::Subscribe {
312 id,
313 function_name,
314 args,
315 } => {
316 let normalized_args = args.unwrap_or(serde_json::Value::Null);
317
318 match state
320 .reactor
321 .subscribe(session_id, id.clone(), function_name, normalized_args)
322 .await
323 {
324 Ok((subscription_id, data)) => {
325 {
327 let mut map = client_to_internal.write().await;
328 map.insert(id.clone(), subscription_id);
329 }
330 {
331 let mut map = internal_to_client.write().await;
332 map.insert(subscription_id, id.clone());
333 }
334
335 tracing::debug!(?subscription_id, client_id = %id, "Subscription created");
339
340 let _ = state
349 .reactor
350 .ws_server()
351 .send_to_session(
352 session_id,
353 ReactorMessage::Data {
354 subscription_id,
355 data,
356 },
357 )
358 .await;
359 }
360 Err(e) => {
361 let _ = state
362 .reactor
363 .ws_server()
364 .send_to_session(
365 session_id,
366 ReactorMessage::Error {
367 code: "SUBSCRIBE_ERROR".to_string(),
368 message: e.to_string(),
369 },
370 )
371 .await;
372 }
373 }
374 }
375 ClientMessage::Unsubscribe { id } => {
376 let subscription_id = {
378 let map = client_to_internal.read().await;
379 map.get(&id).copied()
380 };
381
382 if let Some(sub_id) = subscription_id {
383 state.reactor.unsubscribe(sub_id).await;
384
385 {
387 let mut map = client_to_internal.write().await;
388 map.remove(&id);
389 }
390 {
391 let mut map = internal_to_client.write().await;
392 map.remove(&sub_id);
393 }
394
395 tracing::debug!(?sub_id, client_id = %id, "Subscription removed");
396 }
397 }
398 ClientMessage::SubscribeJob { id, job_id } => {
399 let job_uuid = match parse_uuid(&job_id, "job_id") {
401 Ok(uuid) => uuid,
402 Err(msg) => {
403 let _ = state
405 .reactor
406 .ws_server()
407 .send_to_session(
408 session_id,
409 ReactorMessage::Error {
410 code: "INVALID_JOB_ID".to_string(),
411 message: msg,
412 },
413 )
414 .await;
415 continue;
416 }
417 };
418
419 if id.len() > MAX_CLIENT_SUB_ID_LEN {
421 let _ = state
422 .reactor
423 .ws_server()
424 .send_to_session(
425 session_id,
426 ReactorMessage::Error {
427 code: "INVALID_ID".to_string(),
428 message: "Subscription ID too long".to_string(),
429 },
430 )
431 .await;
432 continue;
433 }
434
435 match state
436 .reactor
437 .subscribe_job(session_id, id.clone(), job_uuid)
438 .await
439 {
440 Ok(job_data) => {
441 let _ = state
443 .reactor
444 .ws_server()
445 .send_to_session(
446 session_id,
447 ReactorMessage::JobUpdate {
448 client_sub_id: id,
449 job: job_data,
450 },
451 )
452 .await;
453 }
454 Err(e) => {
455 let _ = state
457 .reactor
458 .ws_server()
459 .send_to_session(
460 session_id,
461 ReactorMessage::ErrorWithId {
462 id: id.clone(),
463 code: "NOT_FOUND".to_string(),
464 message: "Job not found".to_string(),
465 },
466 )
467 .await;
468 tracing::warn!(job_id = %job_uuid, "Job subscription failed: {}", e);
469 }
470 }
471 }
472 ClientMessage::UnsubscribeJob { id } => {
473 state.reactor.unsubscribe_job(session_id, &id).await;
474 tracing::debug!(client_id = %id, "Job subscription removed");
475 }
476 ClientMessage::SubscribeWorkflow { id, workflow_id } => {
477 let workflow_uuid = match parse_uuid(&workflow_id, "workflow_id") {
479 Ok(uuid) => uuid,
480 Err(msg) => {
481 let _ = state
482 .reactor
483 .ws_server()
484 .send_to_session(
485 session_id,
486 ReactorMessage::Error {
487 code: "INVALID_WORKFLOW_ID".to_string(),
488 message: msg,
489 },
490 )
491 .await;
492 continue;
493 }
494 };
495
496 if id.len() > MAX_CLIENT_SUB_ID_LEN {
498 let _ = state
499 .reactor
500 .ws_server()
501 .send_to_session(
502 session_id,
503 ReactorMessage::Error {
504 code: "INVALID_ID".to_string(),
505 message: "Subscription ID too long".to_string(),
506 },
507 )
508 .await;
509 continue;
510 }
511
512 match state
513 .reactor
514 .subscribe_workflow(session_id, id.clone(), workflow_uuid)
515 .await
516 {
517 Ok(workflow_data) => {
518 let _ = state
520 .reactor
521 .ws_server()
522 .send_to_session(
523 session_id,
524 ReactorMessage::WorkflowUpdate {
525 client_sub_id: id,
526 workflow: workflow_data,
527 },
528 )
529 .await;
530 }
531 Err(e) => {
532 let _ = state
533 .reactor
534 .ws_server()
535 .send_to_session(
536 session_id,
537 ReactorMessage::ErrorWithId {
538 id: id.clone(),
539 code: "NOT_FOUND".to_string(),
540 message: "Workflow not found".to_string(),
541 },
542 )
543 .await;
544 tracing::warn!(workflow_id = %workflow_uuid, "Workflow subscription failed: {}", e);
545 }
546 }
547 }
548 ClientMessage::UnsubscribeWorkflow { id } => {
549 state.reactor.unsubscribe_workflow(session_id, &id).await;
550 tracing::debug!(client_id = %id, "Workflow subscription removed");
551 }
552 }
553 }
554
555 sender_handle.abort();
557 state.reactor.remove_session(session_id).await;
558
559 let _ = sqlx::query("DELETE FROM forge_sessions WHERE id = $1")
561 .bind(session_uuid)
562 .execute(&state.db_pool)
563 .await;
564
565 tracing::debug!(?session_id, "WebSocket connection closed");
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_message_parsing() {
        let json = r#"{"type":"ping"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, ClientMessage::Ping));
    }

    #[test]
    fn test_subscribe_message_parsing() {
        let json = r#"{"type":"subscribe","id":"sub1","function":"get_users","args":null}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, ClientMessage::Subscribe { .. }));
    }

    /// A missing `args` field must decode to `None`, just like an explicit `null`.
    #[test]
    fn test_subscribe_without_args_parses() {
        let json = r#"{"type":"subscribe","id":"sub1","function":"get_users"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::Subscribe {
                id,
                function_name,
                args,
            } => {
                assert_eq!(id, "sub1");
                assert_eq!(function_name, "get_users");
                assert!(args.is_none());
            }
            other => panic!("expected Subscribe, got {other:?}"),
        }
    }

    #[test]
    fn test_subscribe_job_message_parsing() {
        let json =
            r#"{"type":"subscribe_job","id":"j1","job_id":"67e55044-10b1-426f-9247-bb680e5fe0c8"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::SubscribeJob { id, job_id } => {
                assert_eq!(id, "j1");
                assert_eq!(job_id, "67e55044-10b1-426f-9247-bb680e5fe0c8");
            }
            other => panic!("expected SubscribeJob, got {other:?}"),
        }
    }

    #[test]
    fn test_unknown_message_type_is_rejected() {
        assert!(serde_json::from_str::<ClientMessage>(r#"{"type":"bogus"}"#).is_err());
    }

    #[test]
    fn test_server_message_serialization() {
        let msg = ServerMessage::Connected;
        let json = serde_json::to_string(&msg).unwrap();
        assert_eq!(json, r#"{"type":"connected"}"#);
    }

    #[test]
    fn test_error_and_data_message_shapes() {
        let error = ServerMessage::Error {
            id: None,
            code: "NOT_FOUND".to_string(),
            message: "Job not found".to_string(),
        };
        assert_eq!(
            serde_json::to_value(&error).unwrap(),
            serde_json::json!({
                "type": "error",
                "id": null,
                "code": "NOT_FOUND",
                "message": "Job not found"
            })
        );

        let data = ServerMessage::Data {
            id: "s1".to_string(),
            data: serde_json::json!([1, 2]),
        };
        assert_eq!(
            serde_json::to_value(&data).unwrap(),
            serde_json::json!({"type": "data", "id": "s1", "data": [1, 2]})
        );
    }

    #[test]
    fn test_job_update_shape() {
        let msg = ServerMessage::JobUpdate {
            id: "j1".to_string(),
            job: JobData {
                job_id: "x".to_string(),
                status: "running".to_string(),
                progress_percent: Some(50),
                progress_message: None,
                output: None,
                error: None,
            },
        };
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            serde_json::json!({
                "type": "job_update",
                "id": "j1",
                "job": {
                    "job_id": "x",
                    "status": "running",
                    "progress_percent": 50,
                    "progress_message": null,
                    "output": null,
                    "error": null
                }
            })
        );
    }

    #[test]
    fn test_parse_uuid() {
        let valid = "67e55044-10b1-426f-9247-bb680e5fe0c8";
        assert_eq!(parse_uuid(valid, "job_id").unwrap().to_string(), valid);
        assert_eq!(
            parse_uuid("not-a-uuid", "job_id").unwrap_err(),
            "Invalid job_id: must be a valid UUID"
        );
        let too_long = "a".repeat(37);
        assert_eq!(
            parse_uuid(&too_long, "workflow_id").unwrap_err(),
            "Invalid workflow_id: too long"
        );
    }
}