// forge_runtime/gateway/websocket.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::{
5    extract::{
6        State, WebSocketUpgrade,
7        ws::{Message, WebSocket},
8    },
9    response::Response,
10};
11use futures_util::{SinkExt, StreamExt};
12use serde::{Deserialize, Serialize};
13use sqlx::PgPool;
14use tokio::sync::{RwLock, mpsc};
15use uuid::Uuid;
16
17use forge_core::cluster::NodeId;
18use forge_core::function::AuthContext;
19use forge_core::realtime::SessionId;
20
21use super::auth::{AuthMiddleware, build_auth_context_from_claims};
22use crate::realtime::{Reactor, WebSocketMessage as ReactorMessage};
23
24/// Validate and parse a string as UUID.
25/// Returns error message suitable for client display (no internal details).
26fn parse_uuid(s: &str, field_name: &str) -> Result<Uuid, String> {
27    // Limit length to prevent DoS via huge strings
28    if s.len() > 36 {
29        return Err(format!("Invalid {}: too long", field_name));
30    }
31    Uuid::parse_str(s).map_err(|_| format!("Invalid {}: must be a valid UUID", field_name))
32}
33
/// Largest accepted client-chosen subscription id, measured in bytes (`str::len`).
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
36
37/// WebSocket connection state shared across the gateway.
38#[derive(Clone)]
39pub struct WsState {
40    pub reactor: Arc<Reactor>,
41    pub db_pool: PgPool,
42    pub node_id: NodeId,
43    pub auth_middleware: Option<Arc<AuthMiddleware>>,
44}
45
46impl WsState {
47    pub fn new(reactor: Arc<Reactor>, db_pool: PgPool, node_id: NodeId) -> Self {
48        Self {
49            reactor,
50            db_pool,
51            node_id,
52            auth_middleware: None,
53        }
54    }
55
56    /// Create a new WebSocket state with authentication middleware.
57    pub fn with_auth(
58        reactor: Arc<Reactor>,
59        db_pool: PgPool,
60        node_id: NodeId,
61        auth_middleware: Arc<AuthMiddleware>,
62    ) -> Self {
63        Self {
64            reactor,
65            db_pool,
66            node_id,
67            auth_middleware: Some(auth_middleware),
68        }
69    }
70}
71
72/// Incoming WebSocket message from client.
73#[derive(Debug, Deserialize)]
74#[serde(tag = "type", rename_all = "snake_case")]
75pub enum ClientMessage {
76    /// Subscribe to a query.
77    Subscribe {
78        id: String,
79        #[serde(rename = "function")]
80        function_name: String,
81        args: Option<serde_json::Value>,
82    },
83    /// Unsubscribe from a subscription.
84    Unsubscribe { id: String },
85    /// Subscribe to job progress updates.
86    SubscribeJob {
87        /// Client-provided subscription ID (for correlation)
88        id: String,
89        /// Job UUID - MUST be validated as UUID
90        job_id: String,
91    },
92    /// Unsubscribe from job updates.
93    UnsubscribeJob { id: String },
94    /// Subscribe to workflow progress updates.
95    SubscribeWorkflow {
96        /// Client-provided subscription ID (for correlation)
97        id: String,
98        /// Workflow run UUID - MUST be validated as UUID
99        workflow_id: String,
100    },
101    /// Unsubscribe from workflow updates.
102    UnsubscribeWorkflow { id: String },
103    /// Ping for keepalive.
104    Ping,
105    /// Authentication.
106    Auth {
107        #[allow(dead_code)]
108        token: String,
109    },
110}
111
112/// Outgoing WebSocket message to client.
113#[derive(Debug, Serialize)]
114#[serde(tag = "type", rename_all = "snake_case")]
115pub enum ServerMessage {
116    /// Connection established.
117    Connected,
118    /// Ping response.
119    Pong,
120    /// Authentication successful.
121    AuthSuccess,
122    /// Authentication failed.
123    AuthFailed { reason: String },
124    /// Subscription data.
125    Data { id: String, data: serde_json::Value },
126    /// Job progress update.
127    JobUpdate { id: String, job: JobData },
128    /// Workflow progress update.
129    WorkflowUpdate { id: String, workflow: WorkflowData },
130    /// Subscription error.
131    Error {
132        id: Option<String>,
133        code: String,
134        message: String,
135    },
136    /// Subscription response (success/failure).
137    #[allow(dead_code)]
138    Subscribed { id: String },
139    /// Unsubscribed confirmation.
140    #[allow(dead_code)]
141    Unsubscribed { id: String },
142}
143
144/// Job data sent to client (subset of internal JobRecord).
145#[derive(Debug, Clone, Serialize)]
146pub struct JobData {
147    pub job_id: String,
148    pub status: String,
149    pub progress_percent: Option<i32>,
150    pub progress_message: Option<String>,
151    pub output: Option<serde_json::Value>,
152    pub error: Option<String>,
153}
154
155/// Workflow data sent to client.
156#[derive(Debug, Clone, Serialize)]
157pub struct WorkflowData {
158    pub workflow_id: String,
159    pub status: String,
160    pub current_step: Option<String>,
161    pub steps: Vec<WorkflowStepData>,
162    pub output: Option<serde_json::Value>,
163    pub error: Option<String>,
164}
165
166/// Workflow step data sent to client.
167#[derive(Debug, Clone, Serialize)]
168pub struct WorkflowStepData {
169    pub name: String,
170    pub status: String,
171    pub error: Option<String>,
172}
173
174/// WebSocket upgrade handler.
175pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<WsState>>) -> Response {
176    ws.on_upgrade(move |socket| handle_socket(socket, state))
177}
178
179/// Handle a WebSocket connection.
180async fn handle_socket(socket: WebSocket, state: Arc<WsState>) {
181    let (mut ws_sender, mut ws_receiver) = socket.split();
182
183    // Create a session for this connection
184    let session_id = SessionId::new();
185    let session_uuid = session_id.0;
186    let node_uuid = state.node_id.0;
187
188    // Insert session into database for tracking
189    let _ = sqlx::query(
190        r#"
191        INSERT INTO forge_sessions (id, node_id, status, connected_at, last_activity)
192        VALUES ($1, $2, 'connected', NOW(), NOW())
193        ON CONFLICT (id) DO UPDATE SET status = 'connected', last_activity = NOW()
194        "#,
195    )
196    .bind(session_uuid)
197    .bind(node_uuid)
198    .execute(&state.db_pool)
199    .await;
200
201    // Create channels for reactor -> websocket communication
202    let (reactor_tx, mut reactor_rx) = mpsc::channel::<ReactorMessage>(256);
203
204    // Register session with reactor
205    state.reactor.register_session(session_id, reactor_tx).await;
206
207    // Track client subscription IDs to internal subscription IDs
208    #[allow(clippy::type_complexity)]
209    let client_to_internal: Arc<RwLock<HashMap<String, forge_core::realtime::SubscriptionId>>> =
210        Arc::new(RwLock::new(HashMap::new()));
211    let internal_to_client: Arc<RwLock<HashMap<forge_core::realtime::SubscriptionId, String>>> =
212        Arc::new(RwLock::new(HashMap::new()));
213
214    // Track connection's auth context (starts unauthenticated)
215    let connection_auth: Arc<RwLock<AuthContext>> =
216        Arc::new(RwLock::new(AuthContext::unauthenticated()));
217
218    let connected = ServerMessage::Connected;
219    if let Ok(json) = serde_json::to_string(&connected) {
220        let _ = ws_sender.send(Message::Text(json.into())).await;
221    }
222
223    tracing::debug!(?session_id, "WebSocket connection established");
224
225    // Clone state for the reactor message handler
226    let internal_to_client_clone = internal_to_client.clone();
227
228    // Spawn task to forward reactor messages to WebSocket
229    let sender_handle = tokio::spawn(async move {
230        while let Some(msg) = reactor_rx.recv().await {
231            let server_msg = match msg {
232                ReactorMessage::Data {
233                    subscription_id,
234                    data,
235                } => {
236                    // Map internal subscription ID back to client ID
237                    let client_id = {
238                        let map = internal_to_client_clone.read().await;
239                        map.get(&subscription_id).cloned()
240                    };
241
242                    if let Some(id) = client_id {
243                        ServerMessage::Data { id, data }
244                    } else {
245                        continue;
246                    }
247                }
248                ReactorMessage::DeltaUpdate {
249                    subscription_id,
250                    delta,
251                } => {
252                    // Map internal subscription ID back to client ID
253                    let client_id = {
254                        let map = internal_to_client_clone.read().await;
255                        map.get(&subscription_id).cloned()
256                    };
257
258                    if let Some(id) = client_id {
259                        // Convert delta to data update
260                        ServerMessage::Data {
261                            id,
262                            data: serde_json::json!({
263                                "delta": {
264                                    "added": delta.added,
265                                    "removed": delta.removed,
266                                    "updated": delta.updated
267                                }
268                            }),
269                        }
270                    } else {
271                        continue;
272                    }
273                }
274                ReactorMessage::JobUpdate { client_sub_id, job } => ServerMessage::JobUpdate {
275                    id: client_sub_id,
276                    job,
277                },
278                ReactorMessage::WorkflowUpdate {
279                    client_sub_id,
280                    workflow,
281                } => ServerMessage::WorkflowUpdate {
282                    id: client_sub_id,
283                    workflow,
284                },
285                ReactorMessage::Error { code, message } => ServerMessage::Error {
286                    id: None,
287                    code,
288                    message,
289                },
290                ReactorMessage::ErrorWithId { id, code, message } => ServerMessage::Error {
291                    id: Some(id),
292                    code,
293                    message,
294                },
295                ReactorMessage::AuthSuccess => ServerMessage::AuthSuccess,
296                ReactorMessage::AuthFailed { reason } => ServerMessage::AuthFailed { reason },
297                ReactorMessage::Ping => ServerMessage::Pong,
298                ReactorMessage::Pong => continue,
299                _ => continue,
300            };
301
302            if let Ok(json) = serde_json::to_string(&server_msg) {
303                if ws_sender.send(Message::Text(json.into())).await.is_err() {
304                    break;
305                }
306            }
307        }
308    });
309
310    while let Some(msg) = ws_receiver.next().await {
311        let msg = match msg {
312            Ok(Message::Text(text)) => text,
313            Ok(Message::Close(_)) => break,
314            Ok(Message::Ping(data)) => {
315                // Note: Can't send directly since sender is moved, but axum handles pings
316                let _ = data;
317                continue;
318            }
319            _ => continue,
320        };
321
322        let client_msg: ClientMessage = match serde_json::from_str(&msg) {
323            Ok(m) => m,
324            Err(e) => {
325                tracing::warn!("Failed to parse client message: {}", e);
326                continue;
327            }
328        };
329
330        match client_msg {
331            ClientMessage::Ping => {
332                // Pong is handled by the reactor message sender
333            }
334            ClientMessage::Auth { token } => {
335                // Validate token and set auth context
336                if let Some(ref auth_middleware) = state.auth_middleware {
337                    match auth_middleware.validate_token_async(&token).await {
338                        Ok(claims) => {
339                            let auth_context = build_auth_context_from_claims(claims);
340                            *connection_auth.write().await = auth_context;
341
342                            let _ = state
343                                .reactor
344                                .ws_server()
345                                .send_to_session(session_id, ReactorMessage::AuthSuccess)
346                                .await;
347
348                            tracing::debug!(?session_id, "WebSocket authentication successful");
349                        }
350                        Err(e) => {
351                            let _ = state
352                                .reactor
353                                .ws_server()
354                                .send_to_session(
355                                    session_id,
356                                    ReactorMessage::AuthFailed {
357                                        reason: e.to_string(),
358                                    },
359                                )
360                                .await;
361
362                            tracing::debug!(?session_id, error = %e, "WebSocket authentication failed");
363                        }
364                    }
365                } else {
366                    // No auth middleware configured - auth not available
367                    let _ = state
368                        .reactor
369                        .ws_server()
370                        .send_to_session(
371                            session_id,
372                            ReactorMessage::AuthFailed {
373                                reason: "Authentication not configured".to_string(),
374                            },
375                        )
376                        .await;
377                }
378            }
379            ClientMessage::Subscribe {
380                id,
381                function_name,
382                args,
383            } => {
384                let normalized_args = args.unwrap_or(serde_json::Value::Null);
385                let auth = connection_auth.read().await.clone();
386
387                match state
388                    .reactor
389                    .subscribe(session_id, id.clone(), function_name, normalized_args, auth)
390                    .await
391                {
392                    Ok((subscription_id, data)) => {
393                        {
394                            let mut map = client_to_internal.write().await;
395                            map.insert(id.clone(), subscription_id);
396                        }
397                        {
398                            let mut map = internal_to_client.write().await;
399                            map.insert(subscription_id, id.clone());
400                        }
401
402                        tracing::debug!(?subscription_id, client_id = %id, "Subscription created");
403
404                        // Actually, the data is returned from subscribe, so we should send it
405                        // The sender_handle has ws_sender, so we can't send from here directly
406                        // Let's use the reactor channel to send these messages
407
408                        let _ = state
409                            .reactor
410                            .ws_server()
411                            .send_to_session(
412                                session_id,
413                                ReactorMessage::Data {
414                                    subscription_id,
415                                    data,
416                                },
417                            )
418                            .await;
419                    }
420                    Err(e) => {
421                        let _ = state
422                            .reactor
423                            .ws_server()
424                            .send_to_session(
425                                session_id,
426                                ReactorMessage::Error {
427                                    code: "SUBSCRIBE_ERROR".to_string(),
428                                    message: e.to_string(),
429                                },
430                            )
431                            .await;
432                    }
433                }
434            }
435            ClientMessage::Unsubscribe { id } => {
436                // Look up internal subscription ID
437                let subscription_id = {
438                    let map = client_to_internal.read().await;
439                    map.get(&id).copied()
440                };
441
442                if let Some(sub_id) = subscription_id {
443                    state.reactor.unsubscribe(sub_id).await;
444
445                    // Clean up mappings
446                    {
447                        let mut map = client_to_internal.write().await;
448                        map.remove(&id);
449                    }
450                    {
451                        let mut map = internal_to_client.write().await;
452                        map.remove(&sub_id);
453                    }
454
455                    tracing::debug!(?sub_id, client_id = %id, "Subscription removed");
456                }
457            }
458            ClientMessage::SubscribeJob { id, job_id } => {
459                // SECURITY: Validate UUID BEFORE any processing
460                let job_uuid = match parse_uuid(&job_id, "job_id") {
461                    Ok(uuid) => uuid,
462                    Err(msg) => {
463                        // Send error to client, do NOT log the invalid input
464                        let _ = state
465                            .reactor
466                            .ws_server()
467                            .send_to_session(
468                                session_id,
469                                ReactorMessage::Error {
470                                    code: "INVALID_JOB_ID".to_string(),
471                                    message: msg,
472                                },
473                            )
474                            .await;
475                        continue;
476                    }
477                };
478
479                // SECURITY: Limit client_sub_id length
480                if id.len() > MAX_CLIENT_SUB_ID_LEN {
481                    let _ = state
482                        .reactor
483                        .ws_server()
484                        .send_to_session(
485                            session_id,
486                            ReactorMessage::Error {
487                                code: "INVALID_ID".to_string(),
488                                message: "Subscription ID too long".to_string(),
489                            },
490                        )
491                        .await;
492                    continue;
493                }
494
495                match state
496                    .reactor
497                    .subscribe_job(session_id, id.clone(), job_uuid)
498                    .await
499                {
500                    Ok(job_data) => {
501                        // Send current job state immediately
502                        let _ = state
503                            .reactor
504                            .ws_server()
505                            .send_to_session(
506                                session_id,
507                                ReactorMessage::JobUpdate {
508                                    client_sub_id: id,
509                                    job: job_data,
510                                },
511                            )
512                            .await;
513                    }
514                    Err(e) => {
515                        // Generic error - don't expose internal details
516                        let _ = state
517                            .reactor
518                            .ws_server()
519                            .send_to_session(
520                                session_id,
521                                ReactorMessage::ErrorWithId {
522                                    id: id.clone(),
523                                    code: "NOT_FOUND".to_string(),
524                                    message: "Job not found".to_string(),
525                                },
526                            )
527                            .await;
528                        tracing::warn!(job_id = %job_uuid, "Job subscription failed: {}", e);
529                    }
530                }
531            }
532            ClientMessage::UnsubscribeJob { id } => {
533                state.reactor.unsubscribe_job(session_id, &id).await;
534                tracing::debug!(client_id = %id, "Job subscription removed");
535            }
536            ClientMessage::SubscribeWorkflow { id, workflow_id } => {
537                // SECURITY: Validate UUID BEFORE any processing
538                let workflow_uuid = match parse_uuid(&workflow_id, "workflow_id") {
539                    Ok(uuid) => uuid,
540                    Err(msg) => {
541                        let _ = state
542                            .reactor
543                            .ws_server()
544                            .send_to_session(
545                                session_id,
546                                ReactorMessage::Error {
547                                    code: "INVALID_WORKFLOW_ID".to_string(),
548                                    message: msg,
549                                },
550                            )
551                            .await;
552                        continue;
553                    }
554                };
555
556                // SECURITY: Limit client_sub_id length
557                if id.len() > MAX_CLIENT_SUB_ID_LEN {
558                    let _ = state
559                        .reactor
560                        .ws_server()
561                        .send_to_session(
562                            session_id,
563                            ReactorMessage::Error {
564                                code: "INVALID_ID".to_string(),
565                                message: "Subscription ID too long".to_string(),
566                            },
567                        )
568                        .await;
569                    continue;
570                }
571
572                match state
573                    .reactor
574                    .subscribe_workflow(session_id, id.clone(), workflow_uuid)
575                    .await
576                {
577                    Ok(workflow_data) => {
578                        // Send current workflow state immediately
579                        let _ = state
580                            .reactor
581                            .ws_server()
582                            .send_to_session(
583                                session_id,
584                                ReactorMessage::WorkflowUpdate {
585                                    client_sub_id: id,
586                                    workflow: workflow_data,
587                                },
588                            )
589                            .await;
590                    }
591                    Err(e) => {
592                        let _ = state
593                            .reactor
594                            .ws_server()
595                            .send_to_session(
596                                session_id,
597                                ReactorMessage::ErrorWithId {
598                                    id: id.clone(),
599                                    code: "NOT_FOUND".to_string(),
600                                    message: "Workflow not found".to_string(),
601                                },
602                            )
603                            .await;
604                        tracing::warn!(workflow_id = %workflow_uuid, "Workflow subscription failed: {}", e);
605                    }
606                }
607            }
608            ClientMessage::UnsubscribeWorkflow { id } => {
609                state.reactor.unsubscribe_workflow(session_id, &id).await;
610                tracing::debug!(client_id = %id, "Workflow subscription removed");
611            }
612        }
613    }
614
615    sender_handle.abort();
616    state.reactor.remove_session(session_id).await;
617
618    let _ = sqlx::query("DELETE FROM forge_sessions WHERE id = $1")
619        .bind(session_uuid)
620        .execute(&state.db_pool)
621        .await;
622
623    tracing::debug!(?session_id, "WebSocket connection closed");
624}
625
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_message_parsing() {
        let msg: ClientMessage = serde_json::from_str(r#"{"type":"ping"}"#).unwrap();
        assert!(matches!(msg, ClientMessage::Ping));
    }

    #[test]
    fn test_subscribe_message_parsing() {
        let json = r#"{"type":"subscribe","id":"sub1","function":"get_users","args":null}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::Subscribe {
                id,
                function_name,
                args,
            } => {
                assert_eq!(id, "sub1");
                // The wire field is `function`, mapped onto `function_name`.
                assert_eq!(function_name, "get_users");
                assert!(args.is_none());
            }
            other => panic!("expected subscribe, got {other:?}"),
        }
    }

    #[test]
    fn test_subscribe_without_args_parsing() {
        let json = r#"{"type":"subscribe","id":"s","function":"f"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, ClientMessage::Subscribe { args: None, .. }));
    }

    #[test]
    fn test_subscribe_job_message_parsing() {
        let json = r#"{"type":"subscribe_job","id":"j1","job_id":"abc"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        match msg {
            ClientMessage::SubscribeJob { id, job_id } => {
                assert_eq!(id, "j1");
                // Parsing does not validate the UUID; the handler does.
                assert_eq!(job_id, "abc");
            }
            other => panic!("expected subscribe_job, got {other:?}"),
        }
    }

    #[test]
    fn test_unknown_message_type_rejected() {
        assert!(serde_json::from_str::<ClientMessage>(r#"{"type":"bogus"}"#).is_err());
    }

    #[test]
    fn test_server_message_serialization() {
        let json = serde_json::to_string(&ServerMessage::Connected).unwrap();
        assert_eq!(json, r#"{"type":"connected"}"#);

        let error = ServerMessage::Error {
            id: None,
            code: "X".to_string(),
            message: "m".to_string(),
        };
        assert_eq!(
            serde_json::to_string(&error).unwrap(),
            r#"{"type":"error","id":null,"code":"X","message":"m"}"#
        );

        let data = ServerMessage::Data {
            id: "s".to_string(),
            data: serde_json::json!([1]),
        };
        assert_eq!(
            serde_json::to_string(&data).unwrap(),
            r#"{"type":"data","id":"s","data":[1]}"#
        );
    }

    #[test]
    fn test_parse_uuid() {
        let valid = "67e55044-10b1-426f-9247-bb680e5fe0c8";
        assert_eq!(parse_uuid(valid, "job_id").unwrap().to_string(), valid);
        assert_eq!(
            parse_uuid("nope", "job_id").unwrap_err(),
            "Invalid job_id: must be a valid UUID"
        );
        assert_eq!(
            parse_uuid(&"a".repeat(37), "workflow_id").unwrap_err(),
            "Invalid workflow_id: too long"
        );
    }
}