// forge_runtime/gateway/websocket.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::{
5    extract::{
6        State, WebSocketUpgrade,
7        ws::{Message, WebSocket},
8    },
9    response::Response,
10};
11use futures_util::{SinkExt, StreamExt};
12use serde::{Deserialize, Serialize};
13use sqlx::PgPool;
14use tokio::sync::{RwLock, mpsc};
15use uuid::Uuid;
16
17use forge_core::cluster::NodeId;
18use forge_core::realtime::SessionId;
19
20use crate::realtime::{Reactor, WebSocketMessage as ReactorMessage};
21
22/// Validate and parse a string as UUID.
23/// Returns error message suitable for client display (no internal details).
24fn parse_uuid(s: &str, field_name: &str) -> Result<Uuid, String> {
25    // Limit length to prevent DoS via huge strings
26    if s.len() > 36 {
27        return Err(format!("Invalid {}: too long", field_name));
28    }
29    Uuid::parse_str(s).map_err(|_| format!("Invalid {}: must be a valid UUID", field_name))
30}
31
/// Upper bound on a client-chosen subscription ID, measured in bytes (`str::len`).
///
/// The ID is stored per subscription and echoed back on every related server message, so
/// the subscription handlers refuse longer IDs before using them.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
34
35/// WebSocket connection state shared across the gateway.
36#[derive(Clone)]
37pub struct WsState {
38    pub reactor: Arc<Reactor>,
39    pub db_pool: PgPool,
40    pub node_id: NodeId,
41}
42
43impl WsState {
44    pub fn new(reactor: Arc<Reactor>, db_pool: PgPool, node_id: NodeId) -> Self {
45        Self {
46            reactor,
47            db_pool,
48            node_id,
49        }
50    }
51}
52
53/// Incoming WebSocket message from client.
54#[derive(Debug, Deserialize)]
55#[serde(tag = "type", rename_all = "snake_case")]
56pub enum ClientMessage {
57    /// Subscribe to a query.
58    Subscribe {
59        id: String,
60        #[serde(rename = "function")]
61        function_name: String,
62        args: Option<serde_json::Value>,
63    },
64    /// Unsubscribe from a subscription.
65    Unsubscribe { id: String },
66    /// Subscribe to job progress updates.
67    SubscribeJob {
68        /// Client-provided subscription ID (for correlation)
69        id: String,
70        /// Job UUID - MUST be validated as UUID
71        job_id: String,
72    },
73    /// Unsubscribe from job updates.
74    UnsubscribeJob { id: String },
75    /// Subscribe to workflow progress updates.
76    SubscribeWorkflow {
77        /// Client-provided subscription ID (for correlation)
78        id: String,
79        /// Workflow run UUID - MUST be validated as UUID
80        workflow_id: String,
81    },
82    /// Unsubscribe from workflow updates.
83    UnsubscribeWorkflow { id: String },
84    /// Ping for keepalive.
85    Ping,
86    /// Authentication.
87    Auth {
88        #[allow(dead_code)]
89        token: String,
90    },
91}
92
93/// Outgoing WebSocket message to client.
94#[derive(Debug, Serialize)]
95#[serde(tag = "type", rename_all = "snake_case")]
96pub enum ServerMessage {
97    /// Connection established.
98    Connected,
99    /// Ping response.
100    Pong,
101    /// Subscription data.
102    Data { id: String, data: serde_json::Value },
103    /// Job progress update.
104    JobUpdate { id: String, job: JobData },
105    /// Workflow progress update.
106    WorkflowUpdate { id: String, workflow: WorkflowData },
107    /// Subscription error.
108    Error {
109        id: Option<String>,
110        code: String,
111        message: String,
112    },
113    /// Subscription response (success/failure).
114    #[allow(dead_code)]
115    Subscribed { id: String },
116    /// Unsubscribed confirmation.
117    #[allow(dead_code)]
118    Unsubscribed { id: String },
119}
120
121/// Job data sent to client (subset of internal JobRecord).
122#[derive(Debug, Clone, Serialize)]
123pub struct JobData {
124    pub job_id: String,
125    pub status: String,
126    pub progress_percent: Option<i32>,
127    pub progress_message: Option<String>,
128    pub output: Option<serde_json::Value>,
129    pub error: Option<String>,
130}
131
132/// Workflow data sent to client.
133#[derive(Debug, Clone, Serialize)]
134pub struct WorkflowData {
135    pub workflow_id: String,
136    pub status: String,
137    pub current_step: Option<String>,
138    pub steps: Vec<WorkflowStepData>,
139    pub output: Option<serde_json::Value>,
140    pub error: Option<String>,
141}
142
143/// Workflow step data sent to client.
144#[derive(Debug, Clone, Serialize)]
145pub struct WorkflowStepData {
146    pub name: String,
147    pub status: String,
148    pub error: Option<String>,
149}
150
151/// WebSocket upgrade handler.
152pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<WsState>>) -> Response {
153    ws.on_upgrade(move |socket| handle_socket(socket, state))
154}
155
156/// Handle a WebSocket connection.
157async fn handle_socket(socket: WebSocket, state: Arc<WsState>) {
158    let (mut ws_sender, mut ws_receiver) = socket.split();
159
160    // Create a session for this connection
161    let session_id = SessionId::new();
162    let session_uuid = session_id.0;
163    let node_uuid = state.node_id.0;
164
165    // Insert session into database for tracking
166    let _ = sqlx::query(
167        r#"
168        INSERT INTO forge_sessions (id, node_id, status, connected_at, last_activity)
169        VALUES ($1, $2, 'connected', NOW(), NOW())
170        ON CONFLICT (id) DO UPDATE SET status = 'connected', last_activity = NOW()
171        "#,
172    )
173    .bind(session_uuid)
174    .bind(node_uuid)
175    .execute(&state.db_pool)
176    .await;
177
178    // Create channels for reactor -> websocket communication
179    let (reactor_tx, mut reactor_rx) = mpsc::channel::<ReactorMessage>(256);
180
181    // Register session with reactor
182    state.reactor.register_session(session_id, reactor_tx).await;
183
184    // Track client subscription IDs to internal subscription IDs
185    #[allow(clippy::type_complexity)]
186    let client_to_internal: Arc<RwLock<HashMap<String, forge_core::realtime::SubscriptionId>>> =
187        Arc::new(RwLock::new(HashMap::new()));
188    let internal_to_client: Arc<RwLock<HashMap<forge_core::realtime::SubscriptionId, String>>> =
189        Arc::new(RwLock::new(HashMap::new()));
190
191    // Send connected message
192    let connected = ServerMessage::Connected;
193    if let Ok(json) = serde_json::to_string(&connected) {
194        let _ = ws_sender.send(Message::Text(json.into())).await;
195    }
196
197    tracing::debug!(?session_id, "WebSocket connection established");
198
199    // Clone state for the reactor message handler
200    let internal_to_client_clone = internal_to_client.clone();
201
202    // Spawn task to forward reactor messages to WebSocket
203    let sender_handle = tokio::spawn(async move {
204        while let Some(msg) = reactor_rx.recv().await {
205            let server_msg = match msg {
206                ReactorMessage::Data {
207                    subscription_id,
208                    data,
209                } => {
210                    // Map internal subscription ID back to client ID
211                    let client_id = {
212                        let map = internal_to_client_clone.read().await;
213                        map.get(&subscription_id).cloned()
214                    };
215
216                    if let Some(id) = client_id {
217                        ServerMessage::Data { id, data }
218                    } else {
219                        continue;
220                    }
221                }
222                ReactorMessage::DeltaUpdate {
223                    subscription_id,
224                    delta,
225                } => {
226                    // Map internal subscription ID back to client ID
227                    let client_id = {
228                        let map = internal_to_client_clone.read().await;
229                        map.get(&subscription_id).cloned()
230                    };
231
232                    if let Some(id) = client_id {
233                        // Convert delta to data update
234                        ServerMessage::Data {
235                            id,
236                            data: serde_json::json!({
237                                "delta": {
238                                    "added": delta.added,
239                                    "removed": delta.removed,
240                                    "updated": delta.updated
241                                }
242                            }),
243                        }
244                    } else {
245                        continue;
246                    }
247                }
248                ReactorMessage::JobUpdate { client_sub_id, job } => ServerMessage::JobUpdate {
249                    id: client_sub_id,
250                    job,
251                },
252                ReactorMessage::WorkflowUpdate {
253                    client_sub_id,
254                    workflow,
255                } => ServerMessage::WorkflowUpdate {
256                    id: client_sub_id,
257                    workflow,
258                },
259                ReactorMessage::Error { code, message } => ServerMessage::Error {
260                    id: None,
261                    code,
262                    message,
263                },
264                ReactorMessage::ErrorWithId { id, code, message } => ServerMessage::Error {
265                    id: Some(id),
266                    code,
267                    message,
268                },
269                ReactorMessage::Ping => ServerMessage::Pong,
270                ReactorMessage::Pong => continue,
271                _ => continue,
272            };
273
274            if let Ok(json) = serde_json::to_string(&server_msg) {
275                if ws_sender.send(Message::Text(json.into())).await.is_err() {
276                    break;
277                }
278            }
279        }
280    });
281
282    // Handle incoming messages from client
283    while let Some(msg) = ws_receiver.next().await {
284        let msg = match msg {
285            Ok(Message::Text(text)) => text,
286            Ok(Message::Close(_)) => break,
287            Ok(Message::Ping(data)) => {
288                // Note: Can't send directly since sender is moved, but axum handles pings
289                let _ = data;
290                continue;
291            }
292            _ => continue,
293        };
294
295        // Parse client message
296        let client_msg: ClientMessage = match serde_json::from_str(&msg) {
297            Ok(m) => m,
298            Err(e) => {
299                tracing::warn!("Failed to parse client message: {}", e);
300                continue;
301            }
302        };
303
304        match client_msg {
305            ClientMessage::Ping => {
306                // Pong is handled by the reactor message sender
307            }
308            ClientMessage::Auth { token: _ } => {
309                // TODO: Validate token and set auth context
310            }
311            ClientMessage::Subscribe {
312                id,
313                function_name,
314                args,
315            } => {
316                let normalized_args = args.unwrap_or(serde_json::Value::Null);
317
318                // Subscribe through reactor
319                match state
320                    .reactor
321                    .subscribe(session_id, id.clone(), function_name, normalized_args)
322                    .await
323                {
324                    Ok((subscription_id, data)) => {
325                        // Track the mapping
326                        {
327                            let mut map = client_to_internal.write().await;
328                            map.insert(id.clone(), subscription_id);
329                        }
330                        {
331                            let mut map = internal_to_client.write().await;
332                            map.insert(subscription_id, id.clone());
333                        }
334
335                        // Send subscribed confirmation + initial data
336                        // Note: These need to go through the reactor channel
337                        // For now, we'll store the messages to send
338                        tracing::debug!(?subscription_id, client_id = %id, "Subscription created");
339
340                        // The reactor will send the initial data through the reactor_tx channel
341                        // But we need to send the subscribed confirmation first
342                        // This is a bit awkward - let's inject directly
343
344                        // Actually, the data is returned from subscribe, so we should send it
345                        // The sender_handle has ws_sender, so we can't send from here directly
346                        // Let's use the reactor channel to send these messages
347
348                        let _ = state
349                            .reactor
350                            .ws_server()
351                            .send_to_session(
352                                session_id,
353                                ReactorMessage::Data {
354                                    subscription_id,
355                                    data,
356                                },
357                            )
358                            .await;
359                    }
360                    Err(e) => {
361                        let _ = state
362                            .reactor
363                            .ws_server()
364                            .send_to_session(
365                                session_id,
366                                ReactorMessage::Error {
367                                    code: "SUBSCRIBE_ERROR".to_string(),
368                                    message: e.to_string(),
369                                },
370                            )
371                            .await;
372                    }
373                }
374            }
375            ClientMessage::Unsubscribe { id } => {
376                // Look up internal subscription ID
377                let subscription_id = {
378                    let map = client_to_internal.read().await;
379                    map.get(&id).copied()
380                };
381
382                if let Some(sub_id) = subscription_id {
383                    state.reactor.unsubscribe(sub_id).await;
384
385                    // Clean up mappings
386                    {
387                        let mut map = client_to_internal.write().await;
388                        map.remove(&id);
389                    }
390                    {
391                        let mut map = internal_to_client.write().await;
392                        map.remove(&sub_id);
393                    }
394
395                    tracing::debug!(?sub_id, client_id = %id, "Subscription removed");
396                }
397            }
398            ClientMessage::SubscribeJob { id, job_id } => {
399                // SECURITY: Validate UUID BEFORE any processing
400                let job_uuid = match parse_uuid(&job_id, "job_id") {
401                    Ok(uuid) => uuid,
402                    Err(msg) => {
403                        // Send error to client, do NOT log the invalid input
404                        let _ = state
405                            .reactor
406                            .ws_server()
407                            .send_to_session(
408                                session_id,
409                                ReactorMessage::Error {
410                                    code: "INVALID_JOB_ID".to_string(),
411                                    message: msg,
412                                },
413                            )
414                            .await;
415                        continue;
416                    }
417                };
418
419                // SECURITY: Limit client_sub_id length
420                if id.len() > MAX_CLIENT_SUB_ID_LEN {
421                    let _ = state
422                        .reactor
423                        .ws_server()
424                        .send_to_session(
425                            session_id,
426                            ReactorMessage::Error {
427                                code: "INVALID_ID".to_string(),
428                                message: "Subscription ID too long".to_string(),
429                            },
430                        )
431                        .await;
432                    continue;
433                }
434
435                match state
436                    .reactor
437                    .subscribe_job(session_id, id.clone(), job_uuid)
438                    .await
439                {
440                    Ok(job_data) => {
441                        // Send current job state immediately
442                        let _ = state
443                            .reactor
444                            .ws_server()
445                            .send_to_session(
446                                session_id,
447                                ReactorMessage::JobUpdate {
448                                    client_sub_id: id,
449                                    job: job_data,
450                                },
451                            )
452                            .await;
453                    }
454                    Err(e) => {
455                        // Generic error - don't expose internal details
456                        let _ = state
457                            .reactor
458                            .ws_server()
459                            .send_to_session(
460                                session_id,
461                                ReactorMessage::ErrorWithId {
462                                    id: id.clone(),
463                                    code: "NOT_FOUND".to_string(),
464                                    message: "Job not found".to_string(),
465                                },
466                            )
467                            .await;
468                        tracing::warn!(job_id = %job_uuid, "Job subscription failed: {}", e);
469                    }
470                }
471            }
472            ClientMessage::UnsubscribeJob { id } => {
473                state.reactor.unsubscribe_job(session_id, &id).await;
474                tracing::debug!(client_id = %id, "Job subscription removed");
475            }
476            ClientMessage::SubscribeWorkflow { id, workflow_id } => {
477                // SECURITY: Validate UUID BEFORE any processing
478                let workflow_uuid = match parse_uuid(&workflow_id, "workflow_id") {
479                    Ok(uuid) => uuid,
480                    Err(msg) => {
481                        let _ = state
482                            .reactor
483                            .ws_server()
484                            .send_to_session(
485                                session_id,
486                                ReactorMessage::Error {
487                                    code: "INVALID_WORKFLOW_ID".to_string(),
488                                    message: msg,
489                                },
490                            )
491                            .await;
492                        continue;
493                    }
494                };
495
496                // SECURITY: Limit client_sub_id length
497                if id.len() > MAX_CLIENT_SUB_ID_LEN {
498                    let _ = state
499                        .reactor
500                        .ws_server()
501                        .send_to_session(
502                            session_id,
503                            ReactorMessage::Error {
504                                code: "INVALID_ID".to_string(),
505                                message: "Subscription ID too long".to_string(),
506                            },
507                        )
508                        .await;
509                    continue;
510                }
511
512                match state
513                    .reactor
514                    .subscribe_workflow(session_id, id.clone(), workflow_uuid)
515                    .await
516                {
517                    Ok(workflow_data) => {
518                        // Send current workflow state immediately
519                        let _ = state
520                            .reactor
521                            .ws_server()
522                            .send_to_session(
523                                session_id,
524                                ReactorMessage::WorkflowUpdate {
525                                    client_sub_id: id,
526                                    workflow: workflow_data,
527                                },
528                            )
529                            .await;
530                    }
531                    Err(e) => {
532                        let _ = state
533                            .reactor
534                            .ws_server()
535                            .send_to_session(
536                                session_id,
537                                ReactorMessage::ErrorWithId {
538                                    id: id.clone(),
539                                    code: "NOT_FOUND".to_string(),
540                                    message: "Workflow not found".to_string(),
541                                },
542                            )
543                            .await;
544                        tracing::warn!(workflow_id = %workflow_uuid, "Workflow subscription failed: {}", e);
545                    }
546                }
547            }
548            ClientMessage::UnsubscribeWorkflow { id } => {
549                state.reactor.unsubscribe_workflow(session_id, &id).await;
550                tracing::debug!(client_id = %id, "Workflow subscription removed");
551            }
552        }
553    }
554
555    // Clean up on disconnect
556    sender_handle.abort();
557    state.reactor.remove_session(session_id).await;
558
559    // Remove session from database
560    let _ = sqlx::query("DELETE FROM forge_sessions WHERE id = $1")
561        .bind(session_uuid)
562        .execute(&state.db_pool)
563        .await;
564
565    tracing::debug!(?session_id, "WebSocket connection closed");
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_message_parsing() {
        let msg: ClientMessage = serde_json::from_str(r#"{"type":"ping"}"#).unwrap();
        assert!(matches!(msg, ClientMessage::Ping));
    }

    #[test]
    fn test_subscribe_message_parsing() {
        let json = r#"{"type":"subscribe","id":"sub1","function":"get_users","args":null}"#;
        match serde_json::from_str::<ClientMessage>(json).unwrap() {
            ClientMessage::Subscribe {
                id,
                function_name,
                args,
            } => {
                assert_eq!(id, "sub1");
                assert_eq!(function_name, "get_users");
                assert!(args.is_none());
            }
            other => panic!("expected Subscribe, got {other:?}"),
        }
    }

    #[test]
    fn test_subscribe_args_are_optional() {
        let json = r#"{"type":"subscribe","id":"s","function":"f"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert!(matches!(msg, ClientMessage::Subscribe { args: None, .. }));
    }

    #[test]
    fn test_subscribe_job_message_parsing() {
        let json = r#"{"type":"subscribe_job","id":"j1","job_id":"unvalidated"}"#;
        match serde_json::from_str::<ClientMessage>(json).unwrap() {
            ClientMessage::SubscribeJob { id, job_id } => {
                assert_eq!(id, "j1");
                assert_eq!(job_id, "unvalidated");
            }
            other => panic!("expected SubscribeJob, got {other:?}"),
        }
    }

    #[test]
    fn test_unknown_message_type_is_rejected() {
        assert!(serde_json::from_str::<ClientMessage>(r#"{"type":"bogus"}"#).is_err());
    }

    #[test]
    fn test_server_message_serialization() {
        let value = serde_json::to_value(ServerMessage::Connected).unwrap();
        assert_eq!(value, serde_json::json!({ "type": "connected" }));
    }

    #[test]
    fn test_error_message_serialization() {
        let msg = ServerMessage::Error {
            id: Some("sub1".to_string()),
            code: "NOT_FOUND".to_string(),
            message: "Job not found".to_string(),
        };
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            serde_json::json!({
                "type": "error",
                "id": "sub1",
                "code": "NOT_FOUND",
                "message": "Job not found"
            })
        );
    }

    #[test]
    fn test_parse_uuid() {
        let valid = "67e55044-10b1-426f-9247-bb680e5fe0c8";
        assert_eq!(parse_uuid(valid, "job_id").unwrap().to_string(), valid);
        assert_eq!(
            parse_uuid("nope", "job_id").unwrap_err(),
            "Invalid job_id: must be a valid UUID"
        );
        let too_long = format!("{valid}0");
        assert_eq!(
            parse_uuid(&too_long, "job_id").unwrap_err(),
            "Invalid job_id: too long"
        );
        assert!(parse_uuid("", "workflow_id").is_err());
    }
}