1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::{
5 extract::{
6 State, WebSocketUpgrade,
7 ws::{Message, WebSocket},
8 },
9 response::Response,
10};
11use futures_util::{SinkExt, StreamExt};
12use serde::{Deserialize, Serialize};
13use sqlx::PgPool;
14use tokio::sync::{RwLock, mpsc};
15use uuid::Uuid;
16
17use forge_core::cluster::NodeId;
18use forge_core::function::AuthContext;
19use forge_core::realtime::SessionId;
20
21use super::auth::{AuthMiddleware, build_auth_context_from_claims};
22use crate::realtime::{Reactor, WebSocketMessage as ReactorMessage};
23
24fn parse_uuid(s: &str, field_name: &str) -> Result<Uuid, String> {
27 if s.len() > 36 {
29 return Err(format!("Invalid {}: too long", field_name));
30 }
31 Uuid::parse_str(s).map_err(|_| format!("Invalid {}: must be a valid UUID", field_name))
32}
33
/// Maximum length, in bytes (as measured by `str::len`), allowed for a client-chosen
/// subscription id.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
36
/// Shared state for the WebSocket endpoint, handed to [`ws_handler`] as `State<Arc<WsState>>`.
#[derive(Clone)]
pub struct WsState {
    /// Reactor used to register session channels and to manage query, job and workflow
    /// subscriptions.
    pub reactor: Arc<Reactor>,
    /// Pool used to record connected sessions in the `forge_sessions` table.
    pub db_pool: PgPool,
    /// This node's identifier, written to the `node_id` column of each session row.
    pub node_id: NodeId,
    /// Validates tokens sent in `auth` messages. When `None`, every `auth` message is
    /// answered with an "Authentication not configured" failure.
    pub auth_middleware: Option<Arc<AuthMiddleware>>,
}
45
46impl WsState {
47 pub fn new(reactor: Arc<Reactor>, db_pool: PgPool, node_id: NodeId) -> Self {
48 Self {
49 reactor,
50 db_pool,
51 node_id,
52 auth_middleware: None,
53 }
54 }
55
56 pub fn with_auth(
58 reactor: Arc<Reactor>,
59 db_pool: PgPool,
60 node_id: NodeId,
61 auth_middleware: Arc<AuthMiddleware>,
62 ) -> Self {
63 Self {
64 reactor,
65 db_pool,
66 node_id,
67 auth_middleware: Some(auth_middleware),
68 }
69 }
70}
71
72#[derive(Debug, Deserialize)]
74#[serde(tag = "type", rename_all = "snake_case")]
75pub enum ClientMessage {
76 Subscribe {
78 id: String,
79 #[serde(rename = "function")]
80 function_name: String,
81 args: Option<serde_json::Value>,
82 },
83 Unsubscribe { id: String },
85 SubscribeJob {
87 id: String,
89 job_id: String,
91 },
92 UnsubscribeJob { id: String },
94 SubscribeWorkflow {
96 id: String,
98 workflow_id: String,
100 },
101 UnsubscribeWorkflow { id: String },
103 Ping,
105 Auth {
107 #[allow(dead_code)]
108 token: String,
109 },
110}
111
/// Messages the server sends to a client, serialized as JSON objects tagged by `"type"`
/// (the snake_case variant name, e.g. `{"type":"connected"}`).
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerMessage {
    /// Sent once, right after the WebSocket upgrade completes.
    Connected,
    /// Emitted when a `ReactorMessage::Ping` is delivered to the session.
    Pong,
    /// The token from an `auth` message was accepted.
    AuthSuccess,
    /// The token from an `auth` message was rejected, or authentication is not configured.
    AuthFailed { reason: String },
    /// Query-subscription payload for the client-chosen subscription `id`. Incremental
    /// updates are wrapped as `{"delta": {"added": …, "removed": …, "updated": …}}`.
    Data { id: String, data: serde_json::Value },
    /// State of a subscribed job, tagged with the client's subscription `id`.
    JobUpdate { id: String, job: JobData },
    /// State of a subscribed workflow, tagged with the client's subscription `id`.
    WorkflowUpdate { id: String, workflow: WorkflowData },
    /// Error report. `id` carries the related client subscription id when one is known;
    /// otherwise it is `None` (serialized as `null`).
    Error {
        id: Option<String>,
        code: String,
        message: String,
    },
    // Not constructed anywhere in this file; the allowance keeps the variant as part of the
    // wire protocol. NOTE(review): confirm whether another module sends it.
    #[allow(dead_code)]
    Subscribed { id: String },
    // Same situation as `Subscribed`: unused here, kept for protocol completeness (verify).
    #[allow(dead_code)]
    Unsubscribed { id: String },
}
143
/// Snapshot of a job's state, sent to clients inside `ServerMessage::JobUpdate`.
///
/// Field declaration order is the order of keys in the serialized JSON.
#[derive(Debug, Clone, Serialize)]
pub struct JobData {
    /// Job identifier in string form (presumably the job UUID — confirm with the producer).
    pub job_id: String,
    /// Status label; the set of possible values is defined by the job subsystem, not here.
    pub status: String,
    /// Progress value; presumably a 0–100 percentage, but no range is enforced here.
    pub progress_percent: Option<i32>,
    /// Optional human-readable progress note.
    pub progress_message: Option<String>,
    /// Job result, when one is available.
    pub output: Option<serde_json::Value>,
    /// Error description, when the job reported one.
    pub error: Option<String>,
}
154
/// Snapshot of a workflow's state, sent to clients inside `ServerMessage::WorkflowUpdate`.
///
/// Field declaration order is the order of keys in the serialized JSON.
#[derive(Debug, Clone, Serialize)]
pub struct WorkflowData {
    /// Workflow identifier in string form (presumably the workflow UUID — confirm).
    pub workflow_id: String,
    /// Status label; the set of possible values is defined by the workflow subsystem.
    pub status: String,
    /// Name of the step currently executing, if any.
    pub current_step: Option<String>,
    /// Per-step status entries.
    pub steps: Vec<WorkflowStepData>,
    /// Workflow result, when one is available.
    pub output: Option<serde_json::Value>,
    /// Error description, when the workflow reported one.
    pub error: Option<String>,
}
165
/// Status of a single workflow step, as listed in `WorkflowData::steps`.
#[derive(Debug, Clone, Serialize)]
pub struct WorkflowStepData {
    /// Step name.
    pub name: String,
    /// Step status label (values defined by the workflow subsystem).
    pub status: String,
    /// Error description, when the step reported one.
    pub error: Option<String>,
}
173
174pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<WsState>>) -> Response {
176 ws.on_upgrade(move |socket| handle_socket(socket, state))
177}
178
179async fn handle_socket(socket: WebSocket, state: Arc<WsState>) {
181 let (mut ws_sender, mut ws_receiver) = socket.split();
182
183 let session_id = SessionId::new();
185 let session_uuid = session_id.0;
186 let node_uuid = state.node_id.0;
187
188 let _ = sqlx::query(
190 r#"
191 INSERT INTO forge_sessions (id, node_id, status, connected_at, last_activity)
192 VALUES ($1, $2, 'connected', NOW(), NOW())
193 ON CONFLICT (id) DO UPDATE SET status = 'connected', last_activity = NOW()
194 "#,
195 )
196 .bind(session_uuid)
197 .bind(node_uuid)
198 .execute(&state.db_pool)
199 .await;
200
201 let (reactor_tx, mut reactor_rx) = mpsc::channel::<ReactorMessage>(256);
203
204 state.reactor.register_session(session_id, reactor_tx).await;
206
207 #[allow(clippy::type_complexity)]
209 let client_to_internal: Arc<RwLock<HashMap<String, forge_core::realtime::SubscriptionId>>> =
210 Arc::new(RwLock::new(HashMap::new()));
211 let internal_to_client: Arc<RwLock<HashMap<forge_core::realtime::SubscriptionId, String>>> =
212 Arc::new(RwLock::new(HashMap::new()));
213
214 let connection_auth: Arc<RwLock<AuthContext>> =
216 Arc::new(RwLock::new(AuthContext::unauthenticated()));
217
218 let connected = ServerMessage::Connected;
219 if let Ok(json) = serde_json::to_string(&connected) {
220 let _ = ws_sender.send(Message::Text(json.into())).await;
221 }
222
223 tracing::debug!(?session_id, "WebSocket connection established");
224
225 let internal_to_client_clone = internal_to_client.clone();
227
228 let sender_handle = tokio::spawn(async move {
230 while let Some(msg) = reactor_rx.recv().await {
231 let server_msg = match msg {
232 ReactorMessage::Data {
233 subscription_id,
234 data,
235 } => {
236 let client_id = {
238 let map = internal_to_client_clone.read().await;
239 map.get(&subscription_id).cloned()
240 };
241
242 if let Some(id) = client_id {
243 ServerMessage::Data { id, data }
244 } else {
245 continue;
246 }
247 }
248 ReactorMessage::DeltaUpdate {
249 subscription_id,
250 delta,
251 } => {
252 let client_id = {
254 let map = internal_to_client_clone.read().await;
255 map.get(&subscription_id).cloned()
256 };
257
258 if let Some(id) = client_id {
259 ServerMessage::Data {
261 id,
262 data: serde_json::json!({
263 "delta": {
264 "added": delta.added,
265 "removed": delta.removed,
266 "updated": delta.updated
267 }
268 }),
269 }
270 } else {
271 continue;
272 }
273 }
274 ReactorMessage::JobUpdate { client_sub_id, job } => ServerMessage::JobUpdate {
275 id: client_sub_id,
276 job,
277 },
278 ReactorMessage::WorkflowUpdate {
279 client_sub_id,
280 workflow,
281 } => ServerMessage::WorkflowUpdate {
282 id: client_sub_id,
283 workflow,
284 },
285 ReactorMessage::Error { code, message } => ServerMessage::Error {
286 id: None,
287 code,
288 message,
289 },
290 ReactorMessage::ErrorWithId { id, code, message } => ServerMessage::Error {
291 id: Some(id),
292 code,
293 message,
294 },
295 ReactorMessage::AuthSuccess => ServerMessage::AuthSuccess,
296 ReactorMessage::AuthFailed { reason } => ServerMessage::AuthFailed { reason },
297 ReactorMessage::Ping => ServerMessage::Pong,
298 ReactorMessage::Pong => continue,
299 _ => continue,
300 };
301
302 if let Ok(json) = serde_json::to_string(&server_msg) {
303 if ws_sender.send(Message::Text(json.into())).await.is_err() {
304 break;
305 }
306 }
307 }
308 });
309
310 while let Some(msg) = ws_receiver.next().await {
311 let msg = match msg {
312 Ok(Message::Text(text)) => text,
313 Ok(Message::Close(_)) => break,
314 Ok(Message::Ping(data)) => {
315 let _ = data;
317 continue;
318 }
319 _ => continue,
320 };
321
322 let client_msg: ClientMessage = match serde_json::from_str(&msg) {
323 Ok(m) => m,
324 Err(e) => {
325 tracing::warn!("Failed to parse client message: {}", e);
326 continue;
327 }
328 };
329
330 match client_msg {
331 ClientMessage::Ping => {
332 }
334 ClientMessage::Auth { token } => {
335 if let Some(ref auth_middleware) = state.auth_middleware {
337 match auth_middleware.validate_token_async(&token).await {
338 Ok(claims) => {
339 let auth_context = build_auth_context_from_claims(claims);
340 *connection_auth.write().await = auth_context;
341
342 let _ = state
343 .reactor
344 .ws_server()
345 .send_to_session(session_id, ReactorMessage::AuthSuccess)
346 .await;
347
348 tracing::debug!(?session_id, "WebSocket authentication successful");
349 }
350 Err(e) => {
351 let _ = state
352 .reactor
353 .ws_server()
354 .send_to_session(
355 session_id,
356 ReactorMessage::AuthFailed {
357 reason: e.to_string(),
358 },
359 )
360 .await;
361
362 tracing::debug!(?session_id, error = %e, "WebSocket authentication failed");
363 }
364 }
365 } else {
366 let _ = state
368 .reactor
369 .ws_server()
370 .send_to_session(
371 session_id,
372 ReactorMessage::AuthFailed {
373 reason: "Authentication not configured".to_string(),
374 },
375 )
376 .await;
377 }
378 }
379 ClientMessage::Subscribe {
380 id,
381 function_name,
382 args,
383 } => {
384 let normalized_args = args.unwrap_or(serde_json::Value::Null);
385 let auth = connection_auth.read().await.clone();
386
387 match state
388 .reactor
389 .subscribe(session_id, id.clone(), function_name, normalized_args, auth)
390 .await
391 {
392 Ok((subscription_id, data)) => {
393 {
394 let mut map = client_to_internal.write().await;
395 map.insert(id.clone(), subscription_id);
396 }
397 {
398 let mut map = internal_to_client.write().await;
399 map.insert(subscription_id, id.clone());
400 }
401
402 tracing::debug!(?subscription_id, client_id = %id, "Subscription created");
403
404 let _ = state
409 .reactor
410 .ws_server()
411 .send_to_session(
412 session_id,
413 ReactorMessage::Data {
414 subscription_id,
415 data,
416 },
417 )
418 .await;
419 }
420 Err(e) => {
421 let _ = state
422 .reactor
423 .ws_server()
424 .send_to_session(
425 session_id,
426 ReactorMessage::Error {
427 code: "SUBSCRIBE_ERROR".to_string(),
428 message: e.to_string(),
429 },
430 )
431 .await;
432 }
433 }
434 }
435 ClientMessage::Unsubscribe { id } => {
436 let subscription_id = {
438 let map = client_to_internal.read().await;
439 map.get(&id).copied()
440 };
441
442 if let Some(sub_id) = subscription_id {
443 state.reactor.unsubscribe(sub_id).await;
444
445 {
447 let mut map = client_to_internal.write().await;
448 map.remove(&id);
449 }
450 {
451 let mut map = internal_to_client.write().await;
452 map.remove(&sub_id);
453 }
454
455 tracing::debug!(?sub_id, client_id = %id, "Subscription removed");
456 }
457 }
458 ClientMessage::SubscribeJob { id, job_id } => {
459 let job_uuid = match parse_uuid(&job_id, "job_id") {
461 Ok(uuid) => uuid,
462 Err(msg) => {
463 let _ = state
465 .reactor
466 .ws_server()
467 .send_to_session(
468 session_id,
469 ReactorMessage::Error {
470 code: "INVALID_JOB_ID".to_string(),
471 message: msg,
472 },
473 )
474 .await;
475 continue;
476 }
477 };
478
479 if id.len() > MAX_CLIENT_SUB_ID_LEN {
481 let _ = state
482 .reactor
483 .ws_server()
484 .send_to_session(
485 session_id,
486 ReactorMessage::Error {
487 code: "INVALID_ID".to_string(),
488 message: "Subscription ID too long".to_string(),
489 },
490 )
491 .await;
492 continue;
493 }
494
495 match state
496 .reactor
497 .subscribe_job(session_id, id.clone(), job_uuid)
498 .await
499 {
500 Ok(job_data) => {
501 let _ = state
503 .reactor
504 .ws_server()
505 .send_to_session(
506 session_id,
507 ReactorMessage::JobUpdate {
508 client_sub_id: id,
509 job: job_data,
510 },
511 )
512 .await;
513 }
514 Err(e) => {
515 let _ = state
517 .reactor
518 .ws_server()
519 .send_to_session(
520 session_id,
521 ReactorMessage::ErrorWithId {
522 id: id.clone(),
523 code: "NOT_FOUND".to_string(),
524 message: "Job not found".to_string(),
525 },
526 )
527 .await;
528 tracing::warn!(job_id = %job_uuid, "Job subscription failed: {}", e);
529 }
530 }
531 }
532 ClientMessage::UnsubscribeJob { id } => {
533 state.reactor.unsubscribe_job(session_id, &id).await;
534 tracing::debug!(client_id = %id, "Job subscription removed");
535 }
536 ClientMessage::SubscribeWorkflow { id, workflow_id } => {
537 let workflow_uuid = match parse_uuid(&workflow_id, "workflow_id") {
539 Ok(uuid) => uuid,
540 Err(msg) => {
541 let _ = state
542 .reactor
543 .ws_server()
544 .send_to_session(
545 session_id,
546 ReactorMessage::Error {
547 code: "INVALID_WORKFLOW_ID".to_string(),
548 message: msg,
549 },
550 )
551 .await;
552 continue;
553 }
554 };
555
556 if id.len() > MAX_CLIENT_SUB_ID_LEN {
558 let _ = state
559 .reactor
560 .ws_server()
561 .send_to_session(
562 session_id,
563 ReactorMessage::Error {
564 code: "INVALID_ID".to_string(),
565 message: "Subscription ID too long".to_string(),
566 },
567 )
568 .await;
569 continue;
570 }
571
572 match state
573 .reactor
574 .subscribe_workflow(session_id, id.clone(), workflow_uuid)
575 .await
576 {
577 Ok(workflow_data) => {
578 let _ = state
580 .reactor
581 .ws_server()
582 .send_to_session(
583 session_id,
584 ReactorMessage::WorkflowUpdate {
585 client_sub_id: id,
586 workflow: workflow_data,
587 },
588 )
589 .await;
590 }
591 Err(e) => {
592 let _ = state
593 .reactor
594 .ws_server()
595 .send_to_session(
596 session_id,
597 ReactorMessage::ErrorWithId {
598 id: id.clone(),
599 code: "NOT_FOUND".to_string(),
600 message: "Workflow not found".to_string(),
601 },
602 )
603 .await;
604 tracing::warn!(workflow_id = %workflow_uuid, "Workflow subscription failed: {}", e);
605 }
606 }
607 }
608 ClientMessage::UnsubscribeWorkflow { id } => {
609 state.reactor.unsubscribe_workflow(session_id, &id).await;
610 tracing::debug!(client_id = %id, "Workflow subscription removed");
611 }
612 }
613 }
614
615 sender_handle.abort();
616 state.reactor.remove_session(session_id).await;
617
618 let _ = sqlx::query("DELETE FROM forge_sessions WHERE id = $1")
619 .bind(session_uuid)
620 .execute(&state.db_pool)
621 .await;
622
623 tracing::debug!(?session_id, "WebSocket connection closed");
624}
625
#[cfg(test)]
mod tests {
    //! Wire-format and validation tests for the WebSocket protocol types.
    use super::*;

    #[test]
    fn test_client_message_parsing() {
        let json = r#"{"type":"ping"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, ClientMessage::Ping));
    }

    #[test]
    fn test_subscribe_message_parsing() {
        let json = r#"{"type":"subscribe","id":"sub1","function":"get_users","args":null}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::Subscribe {
                id,
                function_name,
                args,
            } => {
                assert_eq!(id, "sub1");
                // The wire name is "function"; it must land in `function_name`.
                assert_eq!(function_name, "get_users");
                assert!(args.is_none());
            }
            other => panic!("expected Subscribe, got {other:?}"),
        }
    }

    #[test]
    fn test_subscribe_job_message_parsing() {
        let json = r#"{"type":"subscribe_job","id":"j1","job_id":"550e8400-e29b-41d4-a716-446655440000"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::SubscribeJob { id, job_id } => {
                assert_eq!(id, "j1");
                assert_eq!(job_id, "550e8400-e29b-41d4-a716-446655440000");
            }
            other => panic!("expected SubscribeJob, got {other:?}"),
        }
    }

    #[test]
    fn test_server_message_serialization() {
        let json = serde_json::to_string(&ServerMessage::Connected).unwrap();
        assert_eq!(json, r#"{"type":"connected"}"#);
    }

    #[test]
    fn test_error_message_serialization() {
        let msg = ServerMessage::Error {
            id: None,
            code: "SUBSCRIBE_ERROR".to_string(),
            message: "boom".to_string(),
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "error",
                "id": null,
                "code": "SUBSCRIBE_ERROR",
                "message": "boom"
            })
        );
    }

    #[test]
    fn test_parse_uuid() {
        assert!(parse_uuid("550e8400-e29b-41d4-a716-446655440000", "job_id").is_ok());
        assert_eq!(
            parse_uuid(&"a".repeat(37), "job_id").unwrap_err(),
            "Invalid job_id: too long"
        );
        assert_eq!(
            parse_uuid("not-a-uuid", "workflow_id").unwrap_err(),
            "Invalid workflow_id: must be a valid UUID"
        );
    }
}