Skip to main content

forge_runtime/gateway/
sse.rs

1//! Server-Sent Events (SSE) handler for real-time updates.
2
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use axum::Json;
11use axum::extract::{Extension, Query, State};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::response::sse::{Event, KeepAlive, Sse};
15use futures_util::Stream;
16use serde::{Deserialize, Serialize};
17use tokio::sync::{RwLock, mpsc};
18use tokio_util::sync::CancellationToken;
19
20/// Wraps an mpsc::Receiver as a Stream for SSE.
21struct ReceiverStream<T> {
22    rx: mpsc::Receiver<T>,
23}
24
25impl<T> Stream for ReceiverStream<T> {
26    type Item = T;
27    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
28        self.rx.poll_recv(cx)
29    }
30}
31
32use forge_core::function::AuthContext;
33use forge_core::realtime::{SessionId, SubscriptionId};
34
35use super::auth::AuthMiddleware;
36use crate::realtime::Reactor;
37use crate::realtime::RealtimeMessage;
38
/// Upper bound on a client-chosen subscription ID.
///
/// Compared against `str::len`, i.e. measured in UTF-8 bytes, so a client cannot
/// park arbitrarily large keys in per-session state.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;

/// Cap on the subscriptions tracked for a single SSE session, bounding how much
/// reactor work one client can create.
const MAX_SUBSCRIPTIONS_PER_SESSION: usize = 100;
44
45fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
46    uuid::Uuid::parse_str(session_id)
47        .ok()
48        .map(SessionId::from_uuid)
49}
50
51fn same_principal(a: &AuthContext, b: &AuthContext) -> bool {
52    match (a.is_authenticated(), b.is_authenticated()) {
53        (false, false) => true,
54        (true, true) => a.principal_id().is_some() && a.principal_id() == b.principal_id(),
55        _ => false,
56    }
57}
58
59fn resolve_sse_auth_context(
60    request_auth: &AuthContext,
61    query_auth: Option<AuthContext>,
62) -> AuthContext {
63    query_auth.unwrap_or_else(|| request_auth.clone())
64}
65
66fn authorize_session_access(
67    session: &SseSessionData,
68    session_secret: &str,
69    requester_auth: &AuthContext,
70) -> Result<AuthContext, (StatusCode, Json<SseSubscribeResponse>)> {
71    if session.session_secret != session_secret {
72        return Err(subscribe_error(
73            StatusCode::UNAUTHORIZED,
74            "INVALID_SESSION_SECRET",
75            "Session secret mismatch",
76        ));
77    }
78
79    if !same_principal(&session.auth_context, requester_auth) {
80        return Err(subscribe_error(
81            StatusCode::FORBIDDEN,
82            "SESSION_PRINCIPAL_MISMATCH",
83            "Request principal does not match session principal",
84        ));
85    }
86
87    Ok(session.auth_context.clone())
88}
89
/// Tunables for the SSE gateway.
#[derive(Clone, Debug)]
pub struct SseConfig {
    /// Upper bound on concurrently open SSE sessions on this server instance.
    pub max_sessions: usize,
    /// Capacity of the per-session message channels.
    pub channel_buffer_size: usize,
    /// Seconds between keepalive pings on an open stream.
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    /// 10 000 sessions, 256-message channels and a 30-second keepalive.
    fn default() -> Self {
        let max_sessions = 10_000;
        let channel_buffer_size = 256;
        let keepalive_interval_secs = 30;
        SseConfig {
            max_sessions,
            channel_buffer_size,
            keepalive_interval_secs,
        }
    }
}
110
111/// SSE query parameters.
112#[derive(Debug, Deserialize)]
113pub struct SseQuery {
114    /// Authentication token.
115    pub token: Option<String>,
116}
117
118struct SseSessionData {
119    auth_context: AuthContext,
120    session_secret: String,
121    /// Maps client subscription ID -> internal SubscriptionId
122    subscriptions: HashMap<String, SubscriptionId>,
123}
124
125/// State for SSE handler.
126#[derive(Clone)]
127pub struct SseState {
128    reactor: Arc<Reactor>,
129    auth_middleware: Arc<AuthMiddleware>,
130    /// Per-session data: auth context and subscription mappings
131    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
132    config: SseConfig,
133}
134
135impl SseState {
136    /// Create new SSE state with default config.
137    pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
138        Self::with_config(reactor, auth_middleware, SseConfig::default())
139    }
140
141    /// Create new SSE state with custom config.
142    pub fn with_config(
143        reactor: Arc<Reactor>,
144        auth_middleware: Arc<AuthMiddleware>,
145        config: SseConfig,
146    ) -> Self {
147        Self {
148            reactor,
149            auth_middleware,
150            sessions: Arc::new(RwLock::new(HashMap::new())),
151            config,
152        }
153    }
154
155    /// Check if we can accept new sessions.
156    pub async fn can_accept_session(&self) -> bool {
157        self.sessions.read().await.len() < self.config.max_sessions
158    }
159}
160
161/// Guard to ensure session cleanup on disconnect.
162/// Spawns a cleanup task on drop to handle abrupt disconnects.
163struct SessionCleanupGuard {
164    session_id: SessionId,
165    reactor: Arc<Reactor>,
166    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
167    dropped: bool,
168}
169
170impl SessionCleanupGuard {
171    fn new(
172        session_id: SessionId,
173        reactor: Arc<Reactor>,
174        sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
175    ) -> Self {
176        Self {
177            session_id,
178            reactor,
179            sessions,
180            dropped: false,
181        }
182    }
183
184    /// Mark as cleanly closed (cleanup will be skipped).
185    fn mark_closed(&mut self) {
186        self.dropped = true;
187    }
188}
189
190impl Drop for SessionCleanupGuard {
191    fn drop(&mut self) {
192        if self.dropped {
193            return;
194        }
195        let session_id = self.session_id;
196        let reactor = self.reactor.clone();
197        let sessions = self.sessions.clone();
198
199        // Spawn cleanup task since we can't await in drop
200        // Use spawn to handle cleanup even if the runtime is shutting down
201        crate::observability::set_active_connections("sse", -1);
202        if let Ok(handle) = tokio::runtime::Handle::try_current() {
203            handle.spawn(async move {
204                reactor.remove_session(session_id).await;
205                sessions.write().await.remove(&session_id);
206                tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
207            });
208        } else {
209            // Runtime not available, likely shutting down. Session will be cleaned up on restart.
210            tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
211        }
212    }
213}
214
215/// SSE event payload sent to clients.
216#[derive(Debug, Serialize)]
217#[serde(tag = "type", rename_all = "snake_case")]
218pub enum SsePayload {
219    /// Subscription data update.
220    Update {
221        target: String,
222        payload: serde_json::Value,
223    },
224    /// Error message.
225    Error {
226        target: String,
227        code: String,
228        message: String,
229    },
230    /// Connection acknowledged.
231    Connected {
232        session_id: String,
233        session_secret: String,
234    },
235}
236
237/// Internal message type for SSE stream.
238#[derive(Debug)]
239pub enum SseMessage {
240    Data {
241        target: String,
242        payload: serde_json::Value,
243    },
244    Error {
245        target: String,
246        code: String,
247        message: String,
248    },
249}
250
251/// SSE subscribe request.
252#[derive(Debug, Deserialize)]
253pub struct SseSubscribeRequest {
254    pub session_id: String,
255    pub session_secret: String,
256    pub id: String,
257    pub function: String,
258    #[serde(default)]
259    pub args: serde_json::Value,
260}
261
262/// SSE unsubscribe request.
263#[derive(Debug, Deserialize)]
264pub struct SseUnsubscribeRequest {
265    pub session_id: String,
266    pub session_secret: String,
267    pub id: String,
268}
269
270/// SSE job subscribe request.
271#[derive(Debug, Deserialize)]
272pub struct SseJobSubscribeRequest {
273    pub session_id: String,
274    pub session_secret: String,
275    pub id: String,
276    pub job_id: String,
277}
278
279/// SSE workflow subscribe request.
280#[derive(Debug, Deserialize)]
281pub struct SseWorkflowSubscribeRequest {
282    pub session_id: String,
283    pub session_secret: String,
284    pub id: String,
285    pub workflow_id: String,
286}
287
288/// SSE error response.
289#[derive(Debug, Serialize)]
290pub struct SseError {
291    pub code: String,
292    pub message: String,
293}
294
295/// SSE subscribe response.
296#[derive(Debug, Serialize)]
297pub struct SseSubscribeResponse {
298    pub success: bool,
299    #[serde(skip_serializing_if = "Option::is_none")]
300    pub data: Option<serde_json::Value>,
301    #[serde(skip_serializing_if = "Option::is_none")]
302    pub error: Option<SseError>,
303}
304
305/// SSE unsubscribe response.
306#[derive(Debug, Serialize)]
307pub struct SseUnsubscribeResponse {
308    pub success: bool,
309    #[serde(skip_serializing_if = "Option::is_none")]
310    pub error: Option<SseError>,
311}
312
313/// Create an SSE subscribe error response.
314fn subscribe_error(
315    status: StatusCode,
316    code: impl Into<String>,
317    message: impl Into<String>,
318) -> (StatusCode, Json<SseSubscribeResponse>) {
319    (
320        status,
321        Json(SseSubscribeResponse {
322            success: false,
323            data: None,
324            error: Some(SseError {
325                code: code.into(),
326                message: message.into(),
327            }),
328        }),
329    )
330}
331
332/// Create an SSE unsubscribe error response.
333fn unsubscribe_error(
334    status: StatusCode,
335    code: impl Into<String>,
336    message: impl Into<String>,
337) -> (StatusCode, Json<SseUnsubscribeResponse>) {
338    (
339        status,
340        Json(SseUnsubscribeResponse {
341            success: false,
342            error: Some(SseError {
343                code: code.into(),
344                message: message.into(),
345            }),
346        }),
347    )
348}
349
350/// SSE handler for GET /events.
351pub async fn sse_handler(
352    State(state): State<Arc<SseState>>,
353    Extension(request_auth): Extension<AuthContext>,
354    Query(query): Query<SseQuery>,
355) -> impl IntoResponse {
356    // Check session limit
357    if !state.can_accept_session().await {
358        return (
359            StatusCode::SERVICE_UNAVAILABLE,
360            "Server at capacity".to_string(),
361        )
362            .into_response();
363    }
364
365    let session_id = SessionId::new();
366    let buffer_size = state.config.channel_buffer_size;
367    let keepalive_secs = state.config.keepalive_interval_secs;
368    let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
369    let cancel_token = CancellationToken::new();
370
371    let query_auth = if let Some(token) = &query.token {
372        match state.auth_middleware.validate_token_async(token).await {
373            Ok(claims) => Some(super::auth::build_auth_context_from_claims(claims)),
374            Err(e) => {
375                tracing::warn!("SSE token validation failed: {}", e);
376                return (
377                    StatusCode::UNAUTHORIZED,
378                    "Invalid authentication token".to_string(),
379                )
380                    .into_response();
381            }
382        }
383    } else {
384        None
385    };
386    let auth_context = resolve_sse_auth_context(&request_auth, query_auth);
387    let session_secret = uuid::Uuid::new_v4().to_string();
388
389    // Register session with reactor
390    let reactor = state.reactor.clone();
391    let cancel = cancel_token.clone();
392
393    // Create a bridge channel for the reactor's message format
394    let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
395    reactor.register_session(session_id, rt_tx);
396
397    // Store session data for subscription handlers
398    {
399        let mut sessions = state.sessions.write().await;
400        sessions.insert(
401            session_id,
402            SseSessionData {
403                auth_context: auth_context.clone(),
404                session_secret: session_secret.clone(),
405                subscriptions: HashMap::new(),
406            },
407        );
408    }
409
410    // Capture sessions for cleanup guard
411    let sessions = state.sessions.clone();
412
413    // Create cleanup guard - will clean up on drop if stream ends unexpectedly
414    let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
415    crate::observability::set_active_connections("sse", 1);
416
417    // Bridge reactor messages to SSE messages
418    let bridge_cancel = cancel_token.clone();
419    tokio::spawn(async move {
420        loop {
421            tokio::select! {
422                msg = rt_rx.recv() => {
423                    match msg {
424                        Some(rt_msg) => {
425                            if let Some(sse_msg) = convert_realtime_to_sse(rt_msg)
426                                && tx.send(sse_msg).await.is_err() {
427                                    break;
428                                }
429                        }
430                        None => break,
431                    }
432                }
433                _ = bridge_cancel.cancelled() => break,
434            }
435        }
436    });
437
438    // Spawn a task that feeds SSE events into a channel
439    let (event_tx, event_rx) = mpsc::channel::<Result<Event, Infallible>>(buffer_size);
440
441    tokio::spawn(async move {
442        let mut _guard = cleanup_guard;
443
444        // Send connected event
445        let connected = SsePayload::Connected {
446            session_id: session_id.to_string(),
447            session_secret: session_secret.clone(),
448        };
449        match serde_json::to_string(&connected) {
450            Ok(json) => {
451                let _ = event_tx
452                    .send(Ok(Event::default().event("connected").data(json)))
453                    .await;
454            }
455            Err(e) => {
456                tracing::error!("Failed to serialize SSE connected payload: {}", e);
457            }
458        }
459
460        loop {
461            tokio::select! {
462                msg = rx.recv() => {
463                    match msg {
464                        Some(sse_msg) => {
465                            let event = match sse_msg {
466                                SseMessage::Data { target, payload } => {
467                                    let data = SsePayload::Update { target, payload };
468                                    serde_json::to_string(&data).ok().map(|json| {
469                                        Event::default().event("update").data(json)
470                                    })
471                                }
472                                SseMessage::Error { target, code, message } => {
473                                    let data = SsePayload::Error { target, code, message };
474                                    serde_json::to_string(&data).ok().map(|json| {
475                                        Event::default().event("error").data(json)
476                                    })
477                                }
478                            };
479                            if let Some(evt) = event
480                                && event_tx.send(Ok(evt)).await.is_err()
481                            {
482                                break;
483                            }
484                        }
485                        None => break,
486                    }
487                }
488                _ = cancel.cancelled() => break,
489            }
490        }
491
492        // Clean shutdown: decrement counter and clean up session state
493        _guard.mark_closed();
494        crate::observability::set_active_connections("sse", -1);
495        reactor.remove_session(session_id).await;
496        sessions.write().await.remove(&session_id);
497    });
498
499    let stream = ReceiverStream { rx: event_rx };
500
501    Sse::new(stream)
502        .keep_alive(
503            KeepAlive::new()
504                .interval(Duration::from_secs(keepalive_secs))
505                .text("ping"),
506        )
507        .into_response()
508}
509
510/// Convert realtime message to SSE message.
511fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
512    match msg {
513        RealtimeMessage::Data {
514            subscription_id,
515            data,
516        } => Some(SseMessage::Data {
517            target: format!("sub:{}", subscription_id),
518            payload: data,
519        }),
520        RealtimeMessage::DeltaUpdate {
521            subscription_id,
522            delta,
523        } => match serde_json::to_value(&delta) {
524            Ok(payload) => Some(SseMessage::Data {
525                target: format!("sub:{}", subscription_id),
526                payload,
527            }),
528            Err(e) => {
529                tracing::error!("Failed to serialize delta update: {}", e);
530                Some(SseMessage::Error {
531                    target: format!("sub:{}", subscription_id),
532                    code: "SERIALIZATION_ERROR".to_string(),
533                    message: "Failed to serialize update data".to_string(),
534                })
535            }
536        },
537        RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
538            Ok(payload) => Some(SseMessage::Data {
539                target: format!("job:{}", client_sub_id),
540                payload,
541            }),
542            Err(e) => {
543                tracing::error!("Failed to serialize job update: {}", e);
544                Some(SseMessage::Error {
545                    target: format!("job:{}", client_sub_id),
546                    code: "SERIALIZATION_ERROR".to_string(),
547                    message: "Failed to serialize job update".to_string(),
548                })
549            }
550        },
551        RealtimeMessage::WorkflowUpdate {
552            client_sub_id,
553            workflow,
554        } => match serde_json::to_value(&workflow) {
555            Ok(payload) => Some(SseMessage::Data {
556                target: format!("wf:{}", client_sub_id),
557                payload,
558            }),
559            Err(e) => {
560                tracing::error!("Failed to serialize workflow update: {}", e);
561                Some(SseMessage::Error {
562                    target: format!("wf:{}", client_sub_id),
563                    code: "SERIALIZATION_ERROR".to_string(),
564                    message: "Failed to serialize workflow update".to_string(),
565                })
566            }
567        },
568        RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
569            target: String::new(),
570            code,
571            message,
572        }),
573        RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
574            target: id,
575            code,
576            message,
577        }),
578        // Ignore control messages
579        RealtimeMessage::Subscribe { .. }
580        | RealtimeMessage::Unsubscribe { .. }
581        | RealtimeMessage::Ping
582        | RealtimeMessage::Pong
583        | RealtimeMessage::AuthSuccess
584        | RealtimeMessage::AuthFailed { .. }
585        | RealtimeMessage::Lagging => None,
586    }
587}
588
589/// SSE subscribe handler for POST /subscribe.
590pub async fn sse_subscribe_handler(
591    State(state): State<Arc<SseState>>,
592    Extension(request_auth): Extension<AuthContext>,
593    Json(request): Json<SseSubscribeRequest>,
594) -> impl IntoResponse {
595    // Validate subscription ID length to prevent memory bloat
596    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
597        return subscribe_error(
598            StatusCode::BAD_REQUEST,
599            "INVALID_ID",
600            format!(
601                "Subscription ID too long (max {} chars)",
602                MAX_CLIENT_SUB_ID_LEN
603            ),
604        );
605    }
606
607    let Some(session_id) = try_parse_session_id(&request.session_id) else {
608        return subscribe_error(
609            StatusCode::BAD_REQUEST,
610            "INVALID_SESSION",
611            "Invalid session ID format",
612        );
613    };
614
615    // Get session data (auth context) - use write lock to atomically check limit and prevent TOCTOU
616    let sessions = state.sessions.write().await;
617    let session_data = match sessions.get(&session_id) {
618        Some(data) => {
619            // Enforce per-session subscription limit to prevent resource exhaustion
620            if data.subscriptions.len() >= MAX_SUBSCRIPTIONS_PER_SESSION {
621                return subscribe_error(
622                    StatusCode::TOO_MANY_REQUESTS,
623                    "TOO_MANY_SUBSCRIPTIONS",
624                    format!(
625                        "Session has reached the maximum of {} subscriptions",
626                        MAX_SUBSCRIPTIONS_PER_SESSION
627                    ),
628                );
629            }
630            match authorize_session_access(data, &request.session_secret, &request_auth) {
631                Ok(auth) => auth,
632                Err(resp) => return resp,
633            }
634        }
635        None => {
636            return subscribe_error(
637                StatusCode::NOT_FOUND,
638                "SESSION_NOT_FOUND",
639                "Session not found or expired",
640            );
641        }
642    };
643    drop(sessions);
644
645    // Subscribe via reactor
646    let result = state
647        .reactor
648        .subscribe(
649            session_id,
650            request.id.clone(),
651            request.function,
652            request.args,
653            session_data,
654        )
655        .await;
656
657    match result {
658        Ok((subscription_id, data)) => {
659            // Store the subscription mapping
660            let mut sessions = state.sessions.write().await;
661            match sessions.get_mut(&session_id) {
662                Some(session) => {
663                    session.subscriptions.insert(request.id, subscription_id);
664                }
665                None => {
666                    // Session was removed between read and write lock
667                    return subscribe_error(
668                        StatusCode::NOT_FOUND,
669                        "SESSION_NOT_FOUND",
670                        "Session expired during subscription",
671                    );
672                }
673            }
674
675            tracing::debug!(
676                %session_id,
677                %subscription_id,
678                "SSE subscription registered"
679            );
680
681            (
682                StatusCode::OK,
683                Json(SseSubscribeResponse {
684                    success: true,
685                    data: Some(data),
686                    error: None,
687                }),
688            )
689        }
690        Err(e) => {
691            tracing::warn!(%session_id, error = %e, "SSE subscription failed");
692            match e {
693                forge_core::ForgeError::Unauthorized(msg) => {
694                    subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
695                }
696                forge_core::ForgeError::Forbidden(msg) => {
697                    subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
698                }
699                forge_core::ForgeError::InvalidArgument(msg)
700                | forge_core::ForgeError::Validation(msg) => {
701                    subscribe_error(StatusCode::BAD_REQUEST, "INVALID_ARGUMENT", msg)
702                }
703                forge_core::ForgeError::NotFound(msg) => {
704                    subscribe_error(StatusCode::NOT_FOUND, "NOT_FOUND", msg)
705                }
706                _ => subscribe_error(
707                    StatusCode::INTERNAL_SERVER_ERROR,
708                    "SUBSCRIPTION_FAILED",
709                    "Subscription failed",
710                ),
711            }
712        }
713    }
714}
715
716/// SSE unsubscribe handler for POST /unsubscribe.
717pub async fn sse_unsubscribe_handler(
718    State(state): State<Arc<SseState>>,
719    Extension(request_auth): Extension<AuthContext>,
720    Json(request): Json<SseUnsubscribeRequest>,
721) -> impl IntoResponse {
722    let Some(session_id) = try_parse_session_id(&request.session_id) else {
723        return unsubscribe_error(
724            StatusCode::BAD_REQUEST,
725            "INVALID_SESSION",
726            "Invalid session ID format",
727        );
728    };
729
730    // Look up internal subscription ID and validate session ownership
731    let subscription_id = {
732        let sessions = state.sessions.read().await;
733        match sessions.get(&session_id) {
734            Some(session) => {
735                if session.session_secret != request.session_secret
736                    || !same_principal(&session.auth_context, &request_auth)
737                {
738                    return unsubscribe_error(
739                        StatusCode::FORBIDDEN,
740                        "SESSION_PRINCIPAL_MISMATCH",
741                        "Request principal does not match session principal",
742                    );
743                }
744                session.subscriptions.get(&request.id).copied()
745            }
746            None => None,
747        }
748    };
749
750    let Some(subscription_id) = subscription_id else {
751        return unsubscribe_error(
752            StatusCode::NOT_FOUND,
753            "SUBSCRIPTION_NOT_FOUND",
754            "Subscription not found",
755        );
756    };
757
758    // Unsubscribe via reactor
759    state.reactor.unsubscribe(subscription_id);
760
761    // Remove from session tracking
762    {
763        let mut sessions = state.sessions.write().await;
764        if let Some(session) = sessions.get_mut(&session_id) {
765            session.subscriptions.remove(&request.id);
766        }
767    }
768
769    tracing::debug!(
770        %session_id,
771        %subscription_id,
772        "SSE subscription removed"
773    );
774
775    (
776        StatusCode::OK,
777        Json(SseUnsubscribeResponse {
778            success: true,
779            error: None,
780        }),
781    )
782}
783
784/// SSE job subscribe handler for POST /subscribe-job.
785pub async fn sse_job_subscribe_handler(
786    State(state): State<Arc<SseState>>,
787    Extension(request_auth): Extension<AuthContext>,
788    Json(request): Json<SseJobSubscribeRequest>,
789) -> impl IntoResponse {
790    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
791        return subscribe_error(
792            StatusCode::BAD_REQUEST,
793            "INVALID_ID",
794            format!(
795                "Subscription ID too long (max {} chars)",
796                MAX_CLIENT_SUB_ID_LEN
797            ),
798        );
799    }
800
801    let Some(session_id) = try_parse_session_id(&request.session_id) else {
802        return subscribe_error(
803            StatusCode::BAD_REQUEST,
804            "INVALID_SESSION",
805            "Invalid session ID format",
806        );
807    };
808
809    // Validate session exists + principal binding
810    let session_auth = {
811        let sessions = state.sessions.read().await;
812        match sessions.get(&session_id) {
813            Some(session) => {
814                match authorize_session_access(session, &request.session_secret, &request_auth) {
815                    Ok(auth) => auth,
816                    Err(resp) => return resp,
817                }
818            }
819            None => {
820                return subscribe_error(
821                    StatusCode::NOT_FOUND,
822                    "SESSION_NOT_FOUND",
823                    "Session not found or expired",
824                );
825            }
826        }
827    };
828
829    // Parse job ID
830    let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
831        Ok(uuid) => uuid,
832        Err(_) => {
833            return subscribe_error(
834                StatusCode::BAD_REQUEST,
835                "INVALID_JOB_ID",
836                "Invalid job ID format",
837            );
838        }
839    };
840
841    // Subscribe to job updates via reactor
842    match state
843        .reactor
844        .subscribe_job(session_id, request.id.clone(), job_uuid, &session_auth)
845        .await
846    {
847        Ok(job_data) => {
848            let data = match serde_json::to_value(&job_data) {
849                Ok(v) => v,
850                Err(e) => {
851                    tracing::error!("Failed to serialize job data: {}", e);
852                    return subscribe_error(
853                        StatusCode::INTERNAL_SERVER_ERROR,
854                        "SERIALIZE_ERROR",
855                        "Failed to serialize job data",
856                    );
857                }
858            };
859            tracing::debug!(
860                %session_id,
861                job_id = %request.job_id,
862                client_sub_id = %request.id,
863                "SSE job subscription registered"
864            );
865            (
866                StatusCode::OK,
867                Json(SseSubscribeResponse {
868                    success: true,
869                    data: Some(data),
870                    error: None,
871                }),
872            )
873        }
874        Err(e) => match e {
875            forge_core::ForgeError::Unauthorized(msg) => {
876                subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
877            }
878            forge_core::ForgeError::Forbidden(msg) => {
879                subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
880            }
881            forge_core::ForgeError::NotFound(msg) => {
882                subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", msg)
883            }
884            _ => subscribe_error(
885                StatusCode::INTERNAL_SERVER_ERROR,
886                "SUBSCRIPTION_FAILED",
887                "Subscription failed",
888            ),
889        },
890    }
891}
892
893/// SSE workflow subscribe handler for POST /subscribe-workflow.
894pub async fn sse_workflow_subscribe_handler(
895    State(state): State<Arc<SseState>>,
896    Extension(request_auth): Extension<AuthContext>,
897    Json(request): Json<SseWorkflowSubscribeRequest>,
898) -> impl IntoResponse {
899    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
900        return subscribe_error(
901            StatusCode::BAD_REQUEST,
902            "INVALID_ID",
903            format!(
904                "Subscription ID too long (max {} chars)",
905                MAX_CLIENT_SUB_ID_LEN
906            ),
907        );
908    }
909
910    let Some(session_id) = try_parse_session_id(&request.session_id) else {
911        return subscribe_error(
912            StatusCode::BAD_REQUEST,
913            "INVALID_SESSION",
914            "Invalid session ID format",
915        );
916    };
917
918    // Validate session exists + principal binding
919    let session_auth = {
920        let sessions = state.sessions.read().await;
921        match sessions.get(&session_id) {
922            Some(session) => {
923                match authorize_session_access(session, &request.session_secret, &request_auth) {
924                    Ok(auth) => auth,
925                    Err(resp) => return resp,
926                }
927            }
928            None => {
929                return subscribe_error(
930                    StatusCode::NOT_FOUND,
931                    "SESSION_NOT_FOUND",
932                    "Session not found or expired",
933                );
934            }
935        }
936    };
937
938    // Parse workflow ID
939    let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
940        Ok(uuid) => uuid,
941        Err(_) => {
942            return subscribe_error(
943                StatusCode::BAD_REQUEST,
944                "INVALID_WORKFLOW_ID",
945                "Invalid workflow ID format",
946            );
947        }
948    };
949
950    // Subscribe to workflow updates via reactor
951    match state
952        .reactor
953        .subscribe_workflow(session_id, request.id.clone(), workflow_uuid, &session_auth)
954        .await
955    {
956        Ok(workflow_data) => {
957            let data = match serde_json::to_value(&workflow_data) {
958                Ok(v) => v,
959                Err(e) => {
960                    tracing::error!("Failed to serialize workflow data: {}", e);
961                    return subscribe_error(
962                        StatusCode::INTERNAL_SERVER_ERROR,
963                        "SERIALIZE_ERROR",
964                        "Failed to serialize workflow data",
965                    );
966                }
967            };
968            tracing::debug!(
969                %session_id,
970                workflow_id = %request.workflow_id,
971                client_sub_id = %request.id,
972                "SSE workflow subscription registered"
973            );
974            (
975                StatusCode::OK,
976                Json(SseSubscribeResponse {
977                    success: true,
978                    data: Some(data),
979                    error: None,
980                }),
981            )
982        }
983        Err(e) => match e {
984            forge_core::ForgeError::Unauthorized(msg) => {
985                subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
986            }
987            forge_core::ForgeError::Forbidden(msg) => {
988                subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
989            }
990            forge_core::ForgeError::NotFound(msg) => {
991                subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", msg)
992            }
993            _ => subscribe_error(
994                StatusCode::INTERNAL_SERVER_ERROR,
995                "SUBSCRIPTION_FAILED",
996                "Subscription failed",
997            ),
998        },
999    }
1000}
1001
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use uuid::Uuid;

    /// Builds an authenticated context for `user_id` with a single `user` role
    /// and no extra claims.
    fn authenticated_as(user_id: Uuid) -> AuthContext {
        AuthContext::authenticated(user_id, vec!["user".to_string()], HashMap::new())
    }

    #[test]
    fn test_sse_payload_serialization() {
        let payload = SsePayload::Update {
            target: "sub:123".to_string(),
            payload: serde_json::json!({"id": 1}),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"update\""));
        assert!(json.contains("\"target\":\"sub:123\""));
    }

    #[test]
    fn test_sse_error_serialization() {
        let payload = SsePayload::Error {
            target: "sub:456".to_string(),
            code: "NOT_FOUND".to_string(),
            message: "Subscription not found".to_string(),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("NOT_FOUND"));
    }

    #[test]
    fn try_parse_session_id_accepts_uuid_and_rejects_garbage() {
        assert!(try_parse_session_id(&Uuid::new_v4().to_string()).is_some());
        assert!(try_parse_session_id("not-a-uuid").is_none());
        assert!(try_parse_session_id("").is_none());
    }

    #[test]
    fn same_principal_binds_to_the_same_authenticated_user() {
        let user_id = Uuid::new_v4();
        let original = authenticated_as(user_id);
        let same_user = authenticated_as(user_id);
        let other_user = authenticated_as(Uuid::new_v4());

        assert!(same_principal(&original, &same_user));
        // A different authenticated user must never be able to reuse the session.
        assert!(!same_principal(&original, &other_user));
        assert!(!same_principal(&other_user, &original));
    }

    #[test]
    fn resolve_sse_auth_context_prefers_request_auth_when_query_token_absent() {
        let request_auth = authenticated_as(Uuid::new_v4());

        let resolved = resolve_sse_auth_context(&request_auth, None);

        assert!(resolved.is_authenticated());
        assert_eq!(resolved.principal_id(), request_auth.principal_id());
    }

    #[test]
    fn resolve_sse_auth_context_prefers_query_token_when_present() {
        let request_auth = authenticated_as(Uuid::new_v4());
        let query_auth = authenticated_as(Uuid::new_v4());

        let resolved = resolve_sse_auth_context(&request_auth, Some(query_auth.clone()));

        assert_eq!(resolved.principal_id(), query_auth.principal_id());
        // Make sure the request-level identity was actually discarded rather
        // than coincidentally matching.
        assert_ne!(resolved.principal_id(), request_auth.principal_id());
    }
}