Skip to main content

forge_runtime/gateway/
sse.rs

1//! Server-Sent Events (SSE) handler for real-time updates.
2
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use axum::Json;
11use axum::extract::{Extension, Query, State};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::response::sse::{Event, KeepAlive, Sse};
15use futures_util::Stream;
16use serde::{Deserialize, Serialize};
17use tokio::sync::{RwLock, mpsc};
18use tokio_util::sync::CancellationToken;
19
20/// Wraps an mpsc::Receiver as a Stream for SSE.
21struct ReceiverStream<T> {
22    rx: mpsc::Receiver<T>,
23}
24
25impl<T> Stream for ReceiverStream<T> {
26    type Item = T;
27    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
28        self.rx.poll_recv(cx)
29    }
30}
31
32use forge_core::function::AuthContext;
33use forge_core::realtime::{SessionId, SubscriptionId};
34
35use super::auth::AuthMiddleware;
36use crate::realtime::Reactor;
37use crate::realtime::RealtimeMessage;
38
/// Upper bound on a client-supplied subscription ID, measured in bytes (`str::len`).
/// Rejecting oversized IDs keeps per-session subscription maps from growing through huge keys.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
41
42fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
43    uuid::Uuid::parse_str(session_id)
44        .ok()
45        .map(SessionId::from_uuid)
46}
47
48fn same_principal(a: &AuthContext, b: &AuthContext) -> bool {
49    match (a.is_authenticated(), b.is_authenticated()) {
50        (false, false) => true,
51        (true, true) => a.principal_id().is_some() && a.principal_id() == b.principal_id(),
52        _ => false,
53    }
54}
55
56fn resolve_sse_auth_context(
57    request_auth: &AuthContext,
58    query_auth: Option<AuthContext>,
59) -> AuthContext {
60    query_auth.unwrap_or_else(|| request_auth.clone())
61}
62
63fn authorize_session_access(
64    session: &SseSessionData,
65    session_secret: &str,
66    requester_auth: &AuthContext,
67) -> Result<AuthContext, (StatusCode, Json<SseSubscribeResponse>)> {
68    if session.session_secret != session_secret {
69        return Err(subscribe_error(
70            StatusCode::UNAUTHORIZED,
71            "INVALID_SESSION_SECRET",
72            "Session secret mismatch",
73        ));
74    }
75
76    if !same_principal(&session.auth_context, requester_auth) {
77        return Err(subscribe_error(
78            StatusCode::FORBIDDEN,
79            "SESSION_PRINCIPAL_MISMATCH",
80            "Request principal does not match session principal",
81        ));
82    }
83
84    Ok(session.auth_context.clone())
85}
86
/// Tunables for the SSE gateway.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Upper bound on concurrently registered SSE sessions for this server instance.
    pub max_sessions: usize,
    /// Capacity of each per-session message channel.
    pub channel_buffer_size: usize,
    /// Seconds between keep-alive messages on an SSE stream.
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    /// 10 000 sessions, 256-message channel buffers, and a 30-second keep-alive.
    fn default() -> Self {
        const DEFAULT_MAX_SESSIONS: usize = 10_000;
        const DEFAULT_CHANNEL_BUFFER: usize = 256;
        const DEFAULT_KEEPALIVE_SECS: u64 = 30;

        Self {
            max_sessions: DEFAULT_MAX_SESSIONS,
            channel_buffer_size: DEFAULT_CHANNEL_BUFFER,
            keepalive_interval_secs: DEFAULT_KEEPALIVE_SECS,
        }
    }
}
107
108/// SSE query parameters.
109#[derive(Debug, Deserialize)]
110pub struct SseQuery {
111    /// Authentication token.
112    pub token: Option<String>,
113}
114
115struct SseSessionData {
116    auth_context: AuthContext,
117    session_secret: String,
118    /// Maps client subscription ID -> internal SubscriptionId
119    subscriptions: HashMap<String, SubscriptionId>,
120}
121
122/// State for SSE handler.
123#[derive(Clone)]
124pub struct SseState {
125    reactor: Arc<Reactor>,
126    auth_middleware: Arc<AuthMiddleware>,
127    /// Per-session data: auth context and subscription mappings
128    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
129    config: SseConfig,
130}
131
132impl SseState {
133    /// Create new SSE state with default config.
134    pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
135        Self::with_config(reactor, auth_middleware, SseConfig::default())
136    }
137
138    /// Create new SSE state with custom config.
139    pub fn with_config(
140        reactor: Arc<Reactor>,
141        auth_middleware: Arc<AuthMiddleware>,
142        config: SseConfig,
143    ) -> Self {
144        Self {
145            reactor,
146            auth_middleware,
147            sessions: Arc::new(RwLock::new(HashMap::new())),
148            config,
149        }
150    }
151
152    /// Check if we can accept new sessions.
153    pub async fn can_accept_session(&self) -> bool {
154        self.sessions.read().await.len() < self.config.max_sessions
155    }
156}
157
158/// Guard to ensure session cleanup on disconnect.
159/// Spawns a cleanup task on drop to handle abrupt disconnects.
160struct SessionCleanupGuard {
161    session_id: SessionId,
162    reactor: Arc<Reactor>,
163    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
164    dropped: bool,
165}
166
167impl SessionCleanupGuard {
168    fn new(
169        session_id: SessionId,
170        reactor: Arc<Reactor>,
171        sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
172    ) -> Self {
173        Self {
174            session_id,
175            reactor,
176            sessions,
177            dropped: false,
178        }
179    }
180
181    /// Mark as cleanly closed (cleanup will be skipped).
182    fn mark_closed(&mut self) {
183        self.dropped = true;
184    }
185}
186
187impl Drop for SessionCleanupGuard {
188    fn drop(&mut self) {
189        if self.dropped {
190            return;
191        }
192        let session_id = self.session_id;
193        let reactor = self.reactor.clone();
194        let sessions = self.sessions.clone();
195
196        // Spawn cleanup task since we can't await in drop
197        // Use spawn to handle cleanup even if the runtime is shutting down
198        crate::observability::set_active_connections("sse", -1);
199        if let Ok(handle) = tokio::runtime::Handle::try_current() {
200            handle.spawn(async move {
201                reactor.remove_session(session_id).await;
202                sessions.write().await.remove(&session_id);
203                tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
204            });
205        } else {
206            // Runtime not available, likely shutting down. Session will be cleaned up on restart.
207            tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
208        }
209    }
210}
211
212/// SSE event payload sent to clients.
213#[derive(Debug, Serialize)]
214#[serde(tag = "type", rename_all = "snake_case")]
215pub enum SsePayload {
216    /// Subscription data update.
217    Update {
218        target: String,
219        payload: serde_json::Value,
220    },
221    /// Error message.
222    Error {
223        target: String,
224        code: String,
225        message: String,
226    },
227    /// Connection acknowledged.
228    Connected {
229        session_id: String,
230        session_secret: String,
231    },
232}
233
234/// Internal message type for SSE stream.
235#[derive(Debug)]
236pub enum SseMessage {
237    Data {
238        target: String,
239        payload: serde_json::Value,
240    },
241    Error {
242        target: String,
243        code: String,
244        message: String,
245    },
246}
247
248/// SSE subscribe request.
249#[derive(Debug, Deserialize)]
250pub struct SseSubscribeRequest {
251    pub session_id: String,
252    pub session_secret: String,
253    pub id: String,
254    pub function: String,
255    #[serde(default)]
256    pub args: serde_json::Value,
257}
258
259/// SSE unsubscribe request.
260#[derive(Debug, Deserialize)]
261pub struct SseUnsubscribeRequest {
262    pub session_id: String,
263    pub session_secret: String,
264    pub id: String,
265}
266
267/// SSE job subscribe request.
268#[derive(Debug, Deserialize)]
269pub struct SseJobSubscribeRequest {
270    pub session_id: String,
271    pub session_secret: String,
272    pub id: String,
273    pub job_id: String,
274}
275
276/// SSE workflow subscribe request.
277#[derive(Debug, Deserialize)]
278pub struct SseWorkflowSubscribeRequest {
279    pub session_id: String,
280    pub session_secret: String,
281    pub id: String,
282    pub workflow_id: String,
283}
284
285/// SSE error response.
286#[derive(Debug, Serialize)]
287pub struct SseError {
288    pub code: String,
289    pub message: String,
290}
291
292/// SSE subscribe response.
293#[derive(Debug, Serialize)]
294pub struct SseSubscribeResponse {
295    pub success: bool,
296    #[serde(skip_serializing_if = "Option::is_none")]
297    pub data: Option<serde_json::Value>,
298    #[serde(skip_serializing_if = "Option::is_none")]
299    pub error: Option<SseError>,
300}
301
302/// SSE unsubscribe response.
303#[derive(Debug, Serialize)]
304pub struct SseUnsubscribeResponse {
305    pub success: bool,
306    #[serde(skip_serializing_if = "Option::is_none")]
307    pub error: Option<SseError>,
308}
309
310/// Create an SSE subscribe error response.
311fn subscribe_error(
312    status: StatusCode,
313    code: impl Into<String>,
314    message: impl Into<String>,
315) -> (StatusCode, Json<SseSubscribeResponse>) {
316    (
317        status,
318        Json(SseSubscribeResponse {
319            success: false,
320            data: None,
321            error: Some(SseError {
322                code: code.into(),
323                message: message.into(),
324            }),
325        }),
326    )
327}
328
329/// Create an SSE unsubscribe error response.
330fn unsubscribe_error(
331    status: StatusCode,
332    code: impl Into<String>,
333    message: impl Into<String>,
334) -> (StatusCode, Json<SseUnsubscribeResponse>) {
335    (
336        status,
337        Json(SseUnsubscribeResponse {
338            success: false,
339            error: Some(SseError {
340                code: code.into(),
341                message: message.into(),
342            }),
343        }),
344    )
345}
346
347/// SSE handler for GET /events.
348pub async fn sse_handler(
349    State(state): State<Arc<SseState>>,
350    Extension(request_auth): Extension<AuthContext>,
351    Query(query): Query<SseQuery>,
352) -> impl IntoResponse {
353    // Check session limit
354    if !state.can_accept_session().await {
355        return (
356            StatusCode::SERVICE_UNAVAILABLE,
357            "Server at capacity".to_string(),
358        )
359            .into_response();
360    }
361
362    let session_id = SessionId::new();
363    let buffer_size = state.config.channel_buffer_size;
364    let keepalive_secs = state.config.keepalive_interval_secs;
365    let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
366    let cancel_token = CancellationToken::new();
367
368    let query_auth = if let Some(token) = &query.token {
369        match state.auth_middleware.validate_token_async(token).await {
370            Ok(claims) => Some(super::auth::build_auth_context_from_claims(claims)),
371            Err(e) => {
372                tracing::warn!("SSE token validation failed: {}", e);
373                return (
374                    StatusCode::UNAUTHORIZED,
375                    "Invalid authentication token".to_string(),
376                )
377                    .into_response();
378            }
379        }
380    } else {
381        None
382    };
383    let auth_context = resolve_sse_auth_context(&request_auth, query_auth);
384    let session_secret = uuid::Uuid::new_v4().to_string();
385
386    // Register session with reactor
387    let reactor = state.reactor.clone();
388    let cancel = cancel_token.clone();
389
390    // Create a bridge channel for the reactor's message format
391    let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
392    reactor.register_session(session_id, rt_tx);
393
394    // Store session data for subscription handlers
395    {
396        let mut sessions = state.sessions.write().await;
397        sessions.insert(
398            session_id,
399            SseSessionData {
400                auth_context: auth_context.clone(),
401                session_secret: session_secret.clone(),
402                subscriptions: HashMap::new(),
403            },
404        );
405    }
406
407    // Capture sessions for cleanup guard
408    let sessions = state.sessions.clone();
409
410    // Create cleanup guard - will clean up on drop if stream ends unexpectedly
411    let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
412    crate::observability::set_active_connections("sse", 1);
413
414    // Bridge reactor messages to SSE messages
415    let bridge_cancel = cancel_token.clone();
416    tokio::spawn(async move {
417        loop {
418            tokio::select! {
419                msg = rt_rx.recv() => {
420                    match msg {
421                        Some(rt_msg) => {
422                            if let Some(sse_msg) = convert_realtime_to_sse(rt_msg)
423                                && tx.send(sse_msg).await.is_err() {
424                                    break;
425                                }
426                        }
427                        None => break,
428                    }
429                }
430                _ = bridge_cancel.cancelled() => break,
431            }
432        }
433    });
434
435    // Spawn a task that feeds SSE events into a channel
436    let (event_tx, event_rx) = mpsc::channel::<Result<Event, Infallible>>(buffer_size);
437
438    tokio::spawn(async move {
439        let mut _guard = cleanup_guard;
440
441        // Send connected event
442        let connected = SsePayload::Connected {
443            session_id: session_id.to_string(),
444            session_secret: session_secret.clone(),
445        };
446        match serde_json::to_string(&connected) {
447            Ok(json) => {
448                let _ = event_tx
449                    .send(Ok(Event::default().event("connected").data(json)))
450                    .await;
451            }
452            Err(e) => {
453                tracing::error!("Failed to serialize SSE connected payload: {}", e);
454            }
455        }
456
457        loop {
458            tokio::select! {
459                msg = rx.recv() => {
460                    match msg {
461                        Some(sse_msg) => {
462                            let event = match sse_msg {
463                                SseMessage::Data { target, payload } => {
464                                    let data = SsePayload::Update { target, payload };
465                                    serde_json::to_string(&data).ok().map(|json| {
466                                        Event::default().event("update").data(json)
467                                    })
468                                }
469                                SseMessage::Error { target, code, message } => {
470                                    let data = SsePayload::Error { target, code, message };
471                                    serde_json::to_string(&data).ok().map(|json| {
472                                        Event::default().event("error").data(json)
473                                    })
474                                }
475                            };
476                            if let Some(evt) = event
477                                && event_tx.send(Ok(evt)).await.is_err()
478                            {
479                                break;
480                            }
481                        }
482                        None => break,
483                    }
484                }
485                _ = cancel.cancelled() => break,
486            }
487        }
488
489        // Clean shutdown: decrement counter and clean up session state
490        _guard.mark_closed();
491        crate::observability::set_active_connections("sse", -1);
492        reactor.remove_session(session_id).await;
493        sessions.write().await.remove(&session_id);
494    });
495
496    let stream = ReceiverStream { rx: event_rx };
497
498    Sse::new(stream)
499        .keep_alive(
500            KeepAlive::new()
501                .interval(Duration::from_secs(keepalive_secs))
502                .text("ping"),
503        )
504        .into_response()
505}
506
507/// Convert realtime message to SSE message.
508fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
509    match msg {
510        RealtimeMessage::Data {
511            subscription_id,
512            data,
513        } => Some(SseMessage::Data {
514            target: format!("sub:{}", subscription_id),
515            payload: data,
516        }),
517        RealtimeMessage::DeltaUpdate {
518            subscription_id,
519            delta,
520        } => match serde_json::to_value(&delta) {
521            Ok(payload) => Some(SseMessage::Data {
522                target: format!("sub:{}", subscription_id),
523                payload,
524            }),
525            Err(e) => {
526                tracing::error!("Failed to serialize delta update: {}", e);
527                Some(SseMessage::Error {
528                    target: format!("sub:{}", subscription_id),
529                    code: "SERIALIZATION_ERROR".to_string(),
530                    message: "Failed to serialize update data".to_string(),
531                })
532            }
533        },
534        RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
535            Ok(payload) => Some(SseMessage::Data {
536                target: format!("job:{}", client_sub_id),
537                payload,
538            }),
539            Err(e) => {
540                tracing::error!("Failed to serialize job update: {}", e);
541                Some(SseMessage::Error {
542                    target: format!("job:{}", client_sub_id),
543                    code: "SERIALIZATION_ERROR".to_string(),
544                    message: "Failed to serialize job update".to_string(),
545                })
546            }
547        },
548        RealtimeMessage::WorkflowUpdate {
549            client_sub_id,
550            workflow,
551        } => match serde_json::to_value(&workflow) {
552            Ok(payload) => Some(SseMessage::Data {
553                target: format!("wf:{}", client_sub_id),
554                payload,
555            }),
556            Err(e) => {
557                tracing::error!("Failed to serialize workflow update: {}", e);
558                Some(SseMessage::Error {
559                    target: format!("wf:{}", client_sub_id),
560                    code: "SERIALIZATION_ERROR".to_string(),
561                    message: "Failed to serialize workflow update".to_string(),
562                })
563            }
564        },
565        RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
566            target: String::new(),
567            code,
568            message,
569        }),
570        RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
571            target: id,
572            code,
573            message,
574        }),
575        // Ignore control messages
576        RealtimeMessage::Subscribe { .. }
577        | RealtimeMessage::Unsubscribe { .. }
578        | RealtimeMessage::Ping
579        | RealtimeMessage::Pong
580        | RealtimeMessage::AuthSuccess
581        | RealtimeMessage::AuthFailed { .. }
582        | RealtimeMessage::Lagging => None,
583    }
584}
585
586/// SSE subscribe handler for POST /subscribe.
587pub async fn sse_subscribe_handler(
588    State(state): State<Arc<SseState>>,
589    Extension(request_auth): Extension<AuthContext>,
590    Json(request): Json<SseSubscribeRequest>,
591) -> impl IntoResponse {
592    // Validate subscription ID length to prevent memory bloat
593    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
594        return subscribe_error(
595            StatusCode::BAD_REQUEST,
596            "INVALID_ID",
597            format!(
598                "Subscription ID too long (max {} chars)",
599                MAX_CLIENT_SUB_ID_LEN
600            ),
601        );
602    }
603
604    let Some(session_id) = try_parse_session_id(&request.session_id) else {
605        return subscribe_error(
606            StatusCode::BAD_REQUEST,
607            "INVALID_SESSION",
608            "Invalid session ID format",
609        );
610    };
611
612    // Get session data (auth context)
613    let sessions = state.sessions.read().await;
614    let session_data = match sessions.get(&session_id) {
615        Some(data) => {
616            match authorize_session_access(data, &request.session_secret, &request_auth) {
617                Ok(auth) => auth,
618                Err(resp) => return resp,
619            }
620        }
621        None => {
622            return subscribe_error(
623                StatusCode::NOT_FOUND,
624                "SESSION_NOT_FOUND",
625                "Session not found or expired",
626            );
627        }
628    };
629    drop(sessions);
630
631    // Subscribe via reactor
632    let result = state
633        .reactor
634        .subscribe(
635            session_id,
636            request.id.clone(),
637            request.function,
638            request.args,
639            session_data,
640        )
641        .await;
642
643    match result {
644        Ok((subscription_id, data)) => {
645            // Store the subscription mapping
646            let mut sessions = state.sessions.write().await;
647            match sessions.get_mut(&session_id) {
648                Some(session) => {
649                    session.subscriptions.insert(request.id, subscription_id);
650                }
651                None => {
652                    // Session was removed between read and write lock
653                    return subscribe_error(
654                        StatusCode::NOT_FOUND,
655                        "SESSION_NOT_FOUND",
656                        "Session expired during subscription",
657                    );
658                }
659            }
660
661            tracing::debug!(
662                %session_id,
663                %subscription_id,
664                "SSE subscription registered"
665            );
666
667            (
668                StatusCode::OK,
669                Json(SseSubscribeResponse {
670                    success: true,
671                    data: Some(data),
672                    error: None,
673                }),
674            )
675        }
676        Err(e) => {
677            tracing::warn!(%session_id, error = %e, "SSE subscription failed");
678            match e {
679                forge_core::ForgeError::Unauthorized(msg) => {
680                    subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
681                }
682                forge_core::ForgeError::Forbidden(msg) => {
683                    subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
684                }
685                forge_core::ForgeError::InvalidArgument(msg)
686                | forge_core::ForgeError::Validation(msg) => {
687                    subscribe_error(StatusCode::BAD_REQUEST, "INVALID_ARGUMENT", msg)
688                }
689                forge_core::ForgeError::NotFound(msg) => {
690                    subscribe_error(StatusCode::NOT_FOUND, "NOT_FOUND", msg)
691                }
692                _ => subscribe_error(
693                    StatusCode::INTERNAL_SERVER_ERROR,
694                    "SUBSCRIPTION_FAILED",
695                    "Subscription failed",
696                ),
697            }
698        }
699    }
700}
701
702/// SSE unsubscribe handler for POST /unsubscribe.
703pub async fn sse_unsubscribe_handler(
704    State(state): State<Arc<SseState>>,
705    Extension(request_auth): Extension<AuthContext>,
706    Json(request): Json<SseUnsubscribeRequest>,
707) -> impl IntoResponse {
708    let Some(session_id) = try_parse_session_id(&request.session_id) else {
709        return unsubscribe_error(
710            StatusCode::BAD_REQUEST,
711            "INVALID_SESSION",
712            "Invalid session ID format",
713        );
714    };
715
716    // Look up internal subscription ID and validate session ownership
717    let subscription_id = {
718        let sessions = state.sessions.read().await;
719        match sessions.get(&session_id) {
720            Some(session) => {
721                if session.session_secret != request.session_secret
722                    || !same_principal(&session.auth_context, &request_auth)
723                {
724                    return unsubscribe_error(
725                        StatusCode::FORBIDDEN,
726                        "SESSION_PRINCIPAL_MISMATCH",
727                        "Request principal does not match session principal",
728                    );
729                }
730                session.subscriptions.get(&request.id).copied()
731            }
732            None => None,
733        }
734    };
735
736    let Some(subscription_id) = subscription_id else {
737        return unsubscribe_error(
738            StatusCode::NOT_FOUND,
739            "SUBSCRIPTION_NOT_FOUND",
740            "Subscription not found",
741        );
742    };
743
744    // Unsubscribe via reactor
745    state.reactor.unsubscribe(subscription_id);
746
747    // Remove from session tracking
748    {
749        let mut sessions = state.sessions.write().await;
750        if let Some(session) = sessions.get_mut(&session_id) {
751            session.subscriptions.remove(&request.id);
752        }
753    }
754
755    tracing::debug!(
756        %session_id,
757        %subscription_id,
758        "SSE subscription removed"
759    );
760
761    (
762        StatusCode::OK,
763        Json(SseUnsubscribeResponse {
764            success: true,
765            error: None,
766        }),
767    )
768}
769
770/// SSE job subscribe handler for POST /subscribe-job.
771pub async fn sse_job_subscribe_handler(
772    State(state): State<Arc<SseState>>,
773    Extension(request_auth): Extension<AuthContext>,
774    Json(request): Json<SseJobSubscribeRequest>,
775) -> impl IntoResponse {
776    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
777        return subscribe_error(
778            StatusCode::BAD_REQUEST,
779            "INVALID_ID",
780            format!(
781                "Subscription ID too long (max {} chars)",
782                MAX_CLIENT_SUB_ID_LEN
783            ),
784        );
785    }
786
787    let Some(session_id) = try_parse_session_id(&request.session_id) else {
788        return subscribe_error(
789            StatusCode::BAD_REQUEST,
790            "INVALID_SESSION",
791            "Invalid session ID format",
792        );
793    };
794
795    // Validate session exists + principal binding
796    let session_auth = {
797        let sessions = state.sessions.read().await;
798        match sessions.get(&session_id) {
799            Some(session) => {
800                match authorize_session_access(session, &request.session_secret, &request_auth) {
801                    Ok(auth) => auth,
802                    Err(resp) => return resp,
803                }
804            }
805            None => {
806                return subscribe_error(
807                    StatusCode::NOT_FOUND,
808                    "SESSION_NOT_FOUND",
809                    "Session not found or expired",
810                );
811            }
812        }
813    };
814
815    // Parse job ID
816    let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
817        Ok(uuid) => uuid,
818        Err(_) => {
819            return subscribe_error(
820                StatusCode::BAD_REQUEST,
821                "INVALID_JOB_ID",
822                "Invalid job ID format",
823            );
824        }
825    };
826
827    // Subscribe to job updates via reactor
828    match state
829        .reactor
830        .subscribe_job(session_id, request.id.clone(), job_uuid, &session_auth)
831        .await
832    {
833        Ok(job_data) => {
834            let data = match serde_json::to_value(&job_data) {
835                Ok(v) => v,
836                Err(e) => {
837                    tracing::error!("Failed to serialize job data: {}", e);
838                    return subscribe_error(
839                        StatusCode::INTERNAL_SERVER_ERROR,
840                        "SERIALIZE_ERROR",
841                        "Failed to serialize job data",
842                    );
843                }
844            };
845            tracing::debug!(
846                %session_id,
847                job_id = %request.job_id,
848                client_sub_id = %request.id,
849                "SSE job subscription registered"
850            );
851            (
852                StatusCode::OK,
853                Json(SseSubscribeResponse {
854                    success: true,
855                    data: Some(data),
856                    error: None,
857                }),
858            )
859        }
860        Err(e) => match e {
861            forge_core::ForgeError::Unauthorized(msg) => {
862                subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
863            }
864            forge_core::ForgeError::Forbidden(msg) => {
865                subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
866            }
867            forge_core::ForgeError::NotFound(msg) => {
868                subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", msg)
869            }
870            _ => subscribe_error(
871                StatusCode::INTERNAL_SERVER_ERROR,
872                "SUBSCRIPTION_FAILED",
873                "Subscription failed",
874            ),
875        },
876    }
877}
878
879/// SSE workflow subscribe handler for POST /subscribe-workflow.
880pub async fn sse_workflow_subscribe_handler(
881    State(state): State<Arc<SseState>>,
882    Extension(request_auth): Extension<AuthContext>,
883    Json(request): Json<SseWorkflowSubscribeRequest>,
884) -> impl IntoResponse {
885    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
886        return subscribe_error(
887            StatusCode::BAD_REQUEST,
888            "INVALID_ID",
889            format!(
890                "Subscription ID too long (max {} chars)",
891                MAX_CLIENT_SUB_ID_LEN
892            ),
893        );
894    }
895
896    let Some(session_id) = try_parse_session_id(&request.session_id) else {
897        return subscribe_error(
898            StatusCode::BAD_REQUEST,
899            "INVALID_SESSION",
900            "Invalid session ID format",
901        );
902    };
903
904    // Validate session exists + principal binding
905    let session_auth = {
906        let sessions = state.sessions.read().await;
907        match sessions.get(&session_id) {
908            Some(session) => {
909                match authorize_session_access(session, &request.session_secret, &request_auth) {
910                    Ok(auth) => auth,
911                    Err(resp) => return resp,
912                }
913            }
914            None => {
915                return subscribe_error(
916                    StatusCode::NOT_FOUND,
917                    "SESSION_NOT_FOUND",
918                    "Session not found or expired",
919                );
920            }
921        }
922    };
923
924    // Parse workflow ID
925    let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
926        Ok(uuid) => uuid,
927        Err(_) => {
928            return subscribe_error(
929                StatusCode::BAD_REQUEST,
930                "INVALID_WORKFLOW_ID",
931                "Invalid workflow ID format",
932            );
933        }
934    };
935
936    // Subscribe to workflow updates via reactor
937    match state
938        .reactor
939        .subscribe_workflow(session_id, request.id.clone(), workflow_uuid, &session_auth)
940        .await
941    {
942        Ok(workflow_data) => {
943            let data = match serde_json::to_value(&workflow_data) {
944                Ok(v) => v,
945                Err(e) => {
946                    tracing::error!("Failed to serialize workflow data: {}", e);
947                    return subscribe_error(
948                        StatusCode::INTERNAL_SERVER_ERROR,
949                        "SERIALIZE_ERROR",
950                        "Failed to serialize workflow data",
951                    );
952                }
953            };
954            tracing::debug!(
955                %session_id,
956                workflow_id = %request.workflow_id,
957                client_sub_id = %request.id,
958                "SSE workflow subscription registered"
959            );
960            (
961                StatusCode::OK,
962                Json(SseSubscribeResponse {
963                    success: true,
964                    data: Some(data),
965                    error: None,
966                }),
967            )
968        }
969        Err(e) => match e {
970            forge_core::ForgeError::Unauthorized(msg) => {
971                subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
972            }
973            forge_core::ForgeError::Forbidden(msg) => {
974                subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
975            }
976            forge_core::ForgeError::NotFound(msg) => {
977                subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", msg)
978            }
979            _ => subscribe_error(
980                StatusCode::INTERNAL_SERVER_ERROR,
981                "SUBSCRIPTION_FAILED",
982                "Subscription failed",
983            ),
984        },
985    }
986}
987
#[cfg(test)]
// Test code may unwrap, index and panic freely: a panic is a test failure.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use uuid::Uuid;

    /// Builds an authenticated context for a fresh, random user.
    fn user_auth() -> AuthContext {
        AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new())
    }

    #[test]
    fn test_sse_payload_serialization() {
        let payload = SsePayload::Update {
            target: "sub:123".to_string(),
            payload: serde_json::json!({"id": 1}),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"update\""));
        assert!(json.contains("\"target\":\"sub:123\""));
    }

    #[test]
    fn test_sse_error_serialization() {
        let payload = SsePayload::Error {
            target: "sub:456".to_string(),
            code: "NOT_FOUND".to_string(),
            message: "Subscription not found".to_string(),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("NOT_FOUND"));
    }

    #[test]
    fn resolve_sse_auth_context_prefers_request_auth_when_query_token_absent() {
        let request_auth = user_auth();

        let resolved = resolve_sse_auth_context(&request_auth, None);

        assert!(resolved.is_authenticated());
        // Without this, the equality below would also pass if both were `None`.
        assert!(resolved.principal_id().is_some());
        assert_eq!(resolved.principal_id(), request_auth.principal_id());
    }

    #[test]
    fn resolve_sse_auth_context_prefers_query_token_when_present() {
        let request_auth = user_auth();
        let query_auth = user_auth();

        let resolved = resolve_sse_auth_context(&request_auth, Some(query_auth.clone()));

        assert!(resolved.is_authenticated());
        assert!(resolved.principal_id().is_some());
        assert_eq!(resolved.principal_id(), query_auth.principal_id());
        // The two users have distinct random IDs, so this proves the request
        // identity was not used.
        assert_ne!(resolved.principal_id(), request_auth.principal_id());
    }
}