Skip to main content

forge_runtime/gateway/
sse.rs

//! Server-Sent Events (SSE) handler for real-time updates.
2
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use axum::Json;
11use axum::extract::{Extension, Query, State};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::response::sse::{Event, KeepAlive, Sse};
15use futures_util::Stream;
16use serde::{Deserialize, Serialize};
17use tokio::sync::{RwLock, mpsc};
18use tokio_util::sync::CancellationToken;
19
20/// Wraps an mpsc::Receiver as a Stream for SSE.
21struct ReceiverStream<T> {
22    rx: mpsc::Receiver<T>,
23}
24
25impl<T> Stream for ReceiverStream<T> {
26    type Item = T;
27    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
28        self.rx.poll_recv(cx)
29    }
30}
31
32use forge_core::function::AuthContext;
33use forge_core::realtime::{SessionId, SubscriptionId};
34
35use super::auth::AuthMiddleware;
36use crate::realtime::Reactor;
37use crate::realtime::RealtimeMessage;
38
/// Upper bound on the length of a client-chosen subscription ID.
///
/// The handlers compare this against `String::len()`, so the limit is
/// measured in bytes. It exists to keep oversized IDs from bloating memory.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
41
42fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
43    uuid::Uuid::parse_str(session_id)
44        .ok()
45        .map(SessionId::from_uuid)
46}
47
48fn same_principal(a: &AuthContext, b: &AuthContext) -> bool {
49    match (a.is_authenticated(), b.is_authenticated()) {
50        (false, false) => true,
51        (true, true) => a.principal_id().is_some() && a.principal_id() == b.principal_id(),
52        _ => false,
53    }
54}
55
56fn authorize_session_access(
57    session: &SseSessionData,
58    session_secret: &str,
59    requester_auth: &AuthContext,
60) -> Result<AuthContext, (StatusCode, Json<SseSubscribeResponse>)> {
61    if session.session_secret != session_secret {
62        return Err(subscribe_error(
63            StatusCode::UNAUTHORIZED,
64            "INVALID_SESSION_SECRET",
65            "Session secret mismatch",
66        ));
67    }
68
69    if !same_principal(&session.auth_context, requester_auth) {
70        return Err(subscribe_error(
71            StatusCode::FORBIDDEN,
72            "SESSION_PRINCIPAL_MISMATCH",
73            "Request principal does not match session principal",
74        ));
75    }
76
77    Ok(session.auth_context.clone())
78}
79
/// Tunables for the SSE gateway.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly
/// (for example in tests or when detecting config changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseConfig {
    /// Maximum number of concurrently tracked sessions per server instance.
    pub max_sessions: usize,
    /// Capacity of each per-session `mpsc` channel used to carry SSE messages.
    pub channel_buffer_size: usize,
    /// Interval, in seconds, between keep-alive messages on an open stream.
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    /// Defaults: 10 000 sessions, a 256-message channel buffer and a
    /// 30-second keep-alive interval.
    fn default() -> Self {
        Self {
            max_sessions: 10_000,
            channel_buffer_size: 256,
            keepalive_interval_secs: 30,
        }
    }
}
100
101/// SSE query parameters.
102#[derive(Debug, Deserialize)]
103pub struct SseQuery {
104    /// Authentication token.
105    pub token: Option<String>,
106}
107
108struct SseSessionData {
109    auth_context: AuthContext,
110    session_secret: String,
111    /// Maps client subscription ID -> internal SubscriptionId
112    subscriptions: HashMap<String, SubscriptionId>,
113}
114
115/// State for SSE handler.
116#[derive(Clone)]
117pub struct SseState {
118    reactor: Arc<Reactor>,
119    auth_middleware: Arc<AuthMiddleware>,
120    /// Per-session data: auth context and subscription mappings
121    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
122    config: SseConfig,
123}
124
125impl SseState {
126    /// Create new SSE state with default config.
127    pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
128        Self::with_config(reactor, auth_middleware, SseConfig::default())
129    }
130
131    /// Create new SSE state with custom config.
132    pub fn with_config(
133        reactor: Arc<Reactor>,
134        auth_middleware: Arc<AuthMiddleware>,
135        config: SseConfig,
136    ) -> Self {
137        Self {
138            reactor,
139            auth_middleware,
140            sessions: Arc::new(RwLock::new(HashMap::new())),
141            config,
142        }
143    }
144
145    /// Check if we can accept new sessions.
146    pub async fn can_accept_session(&self) -> bool {
147        self.sessions.read().await.len() < self.config.max_sessions
148    }
149}
150
151/// Guard to ensure session cleanup on disconnect.
152/// Spawns a cleanup task on drop to handle abrupt disconnects.
153struct SessionCleanupGuard {
154    session_id: SessionId,
155    reactor: Arc<Reactor>,
156    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
157    dropped: bool,
158}
159
160impl SessionCleanupGuard {
161    fn new(
162        session_id: SessionId,
163        reactor: Arc<Reactor>,
164        sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
165    ) -> Self {
166        Self {
167            session_id,
168            reactor,
169            sessions,
170            dropped: false,
171        }
172    }
173
174    /// Mark as cleanly closed (cleanup will be skipped).
175    fn mark_closed(&mut self) {
176        self.dropped = true;
177    }
178}
179
180impl Drop for SessionCleanupGuard {
181    fn drop(&mut self) {
182        if self.dropped {
183            return;
184        }
185        let session_id = self.session_id;
186        let reactor = self.reactor.clone();
187        let sessions = self.sessions.clone();
188
189        // Spawn cleanup task since we can't await in drop
190        // Use spawn to handle cleanup even if the runtime is shutting down
191        if let Ok(handle) = tokio::runtime::Handle::try_current() {
192            handle.spawn(async move {
193                reactor.remove_session(session_id).await;
194                sessions.write().await.remove(&session_id);
195                tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
196            });
197        } else {
198            // Runtime not available, likely shutting down. Session will be cleaned up on restart.
199            tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
200        }
201    }
202}
203
204/// SSE event payload sent to clients.
205#[derive(Debug, Serialize)]
206#[serde(tag = "type", rename_all = "snake_case")]
207pub enum SsePayload {
208    /// Subscription data update.
209    Update {
210        target: String,
211        payload: serde_json::Value,
212    },
213    /// Error message.
214    Error {
215        target: String,
216        code: String,
217        message: String,
218    },
219    /// Connection acknowledged.
220    Connected {
221        session_id: String,
222        session_secret: String,
223    },
224}
225
226/// Internal message type for SSE stream.
227#[derive(Debug)]
228pub enum SseMessage {
229    Data {
230        target: String,
231        payload: serde_json::Value,
232    },
233    Error {
234        target: String,
235        code: String,
236        message: String,
237    },
238}
239
240/// SSE subscribe request.
241#[derive(Debug, Deserialize)]
242pub struct SseSubscribeRequest {
243    pub session_id: String,
244    pub session_secret: String,
245    pub id: String,
246    pub function: String,
247    #[serde(default)]
248    pub args: serde_json::Value,
249}
250
251/// SSE unsubscribe request.
252#[derive(Debug, Deserialize)]
253pub struct SseUnsubscribeRequest {
254    pub session_id: String,
255    pub session_secret: String,
256    pub id: String,
257}
258
259/// SSE job subscribe request.
260#[derive(Debug, Deserialize)]
261pub struct SseJobSubscribeRequest {
262    pub session_id: String,
263    pub session_secret: String,
264    pub id: String,
265    pub job_id: String,
266}
267
268/// SSE workflow subscribe request.
269#[derive(Debug, Deserialize)]
270pub struct SseWorkflowSubscribeRequest {
271    pub session_id: String,
272    pub session_secret: String,
273    pub id: String,
274    pub workflow_id: String,
275}
276
277/// SSE error response.
278#[derive(Debug, Serialize)]
279pub struct SseError {
280    pub code: String,
281    pub message: String,
282}
283
284/// SSE subscribe response.
285#[derive(Debug, Serialize)]
286pub struct SseSubscribeResponse {
287    pub success: bool,
288    #[serde(skip_serializing_if = "Option::is_none")]
289    pub data: Option<serde_json::Value>,
290    #[serde(skip_serializing_if = "Option::is_none")]
291    pub error: Option<SseError>,
292}
293
294/// SSE unsubscribe response.
295#[derive(Debug, Serialize)]
296pub struct SseUnsubscribeResponse {
297    pub success: bool,
298    #[serde(skip_serializing_if = "Option::is_none")]
299    pub error: Option<SseError>,
300}
301
302/// Create an SSE subscribe error response.
303fn subscribe_error(
304    status: StatusCode,
305    code: impl Into<String>,
306    message: impl Into<String>,
307) -> (StatusCode, Json<SseSubscribeResponse>) {
308    (
309        status,
310        Json(SseSubscribeResponse {
311            success: false,
312            data: None,
313            error: Some(SseError {
314                code: code.into(),
315                message: message.into(),
316            }),
317        }),
318    )
319}
320
321/// Create an SSE unsubscribe error response.
322fn unsubscribe_error(
323    status: StatusCode,
324    code: impl Into<String>,
325    message: impl Into<String>,
326) -> (StatusCode, Json<SseUnsubscribeResponse>) {
327    (
328        status,
329        Json(SseUnsubscribeResponse {
330            success: false,
331            error: Some(SseError {
332                code: code.into(),
333                message: message.into(),
334            }),
335        }),
336    )
337}
338
339/// SSE handler for GET /events.
340pub async fn sse_handler(
341    State(state): State<Arc<SseState>>,
342    Query(query): Query<SseQuery>,
343) -> impl IntoResponse {
344    // Check session limit
345    if !state.can_accept_session().await {
346        return (
347            StatusCode::SERVICE_UNAVAILABLE,
348            "Server at capacity".to_string(),
349        )
350            .into_response();
351    }
352
353    let session_id = SessionId::new();
354    let buffer_size = state.config.channel_buffer_size;
355    let keepalive_secs = state.config.keepalive_interval_secs;
356    let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
357    let cancel_token = CancellationToken::new();
358
359    let auth_context = if let Some(token) = &query.token {
360        match state.auth_middleware.validate_token_async(token).await {
361            Ok(claims) => super::auth::build_auth_context_from_claims(claims),
362            Err(e) => {
363                tracing::warn!("SSE token validation failed: {}", e);
364                return (
365                    StatusCode::UNAUTHORIZED,
366                    "Invalid authentication token".to_string(),
367                )
368                    .into_response();
369            }
370        }
371    } else {
372        forge_core::function::AuthContext::unauthenticated()
373    };
374    let session_secret = uuid::Uuid::new_v4().to_string();
375
376    // Register session with reactor
377    let reactor = state.reactor.clone();
378    let cancel = cancel_token.clone();
379
380    // Create a bridge channel for the reactor's message format
381    let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
382    reactor.register_session(session_id, rt_tx);
383
384    // Store session data for subscription handlers
385    {
386        let mut sessions = state.sessions.write().await;
387        sessions.insert(
388            session_id,
389            SseSessionData {
390                auth_context: auth_context.clone(),
391                session_secret: session_secret.clone(),
392                subscriptions: HashMap::new(),
393            },
394        );
395    }
396
397    // Capture sessions for cleanup guard
398    let sessions = state.sessions.clone();
399
400    // Create cleanup guard - will clean up on drop if stream ends unexpectedly
401    let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
402
403    // Bridge reactor messages to SSE messages
404    let bridge_cancel = cancel_token.clone();
405    tokio::spawn(async move {
406        loop {
407            tokio::select! {
408                msg = rt_rx.recv() => {
409                    match msg {
410                        Some(rt_msg) => {
411                            if let Some(sse_msg) = convert_realtime_to_sse(rt_msg)
412                                && tx.send(sse_msg).await.is_err() {
413                                    break;
414                                }
415                        }
416                        None => break,
417                    }
418                }
419                _ = bridge_cancel.cancelled() => break,
420            }
421        }
422    });
423
424    // Spawn a task that feeds SSE events into a channel
425    let (event_tx, event_rx) = mpsc::channel::<Result<Event, Infallible>>(buffer_size);
426
427    tokio::spawn(async move {
428        let mut _guard = cleanup_guard;
429
430        // Send connected event
431        let connected = SsePayload::Connected {
432            session_id: session_id.to_string(),
433            session_secret: session_secret.clone(),
434        };
435        match serde_json::to_string(&connected) {
436            Ok(json) => {
437                let _ = event_tx
438                    .send(Ok(Event::default().event("connected").data(json)))
439                    .await;
440            }
441            Err(e) => {
442                tracing::error!("Failed to serialize SSE connected payload: {}", e);
443            }
444        }
445
446        loop {
447            tokio::select! {
448                msg = rx.recv() => {
449                    match msg {
450                        Some(sse_msg) => {
451                            let event = match sse_msg {
452                                SseMessage::Data { target, payload } => {
453                                    let data = SsePayload::Update { target, payload };
454                                    serde_json::to_string(&data).ok().map(|json| {
455                                        Event::default().event("update").data(json)
456                                    })
457                                }
458                                SseMessage::Error { target, code, message } => {
459                                    let data = SsePayload::Error { target, code, message };
460                                    serde_json::to_string(&data).ok().map(|json| {
461                                        Event::default().event("error").data(json)
462                                    })
463                                }
464                            };
465                            if let Some(evt) = event
466                                && event_tx.send(Ok(evt)).await.is_err()
467                            {
468                                break;
469                            }
470                        }
471                        None => break,
472                    }
473                }
474                _ = cancel.cancelled() => break,
475            }
476        }
477
478        // Clean shutdown
479        _guard.mark_closed();
480        reactor.remove_session(session_id).await;
481        sessions.write().await.remove(&session_id);
482    });
483
484    let stream = ReceiverStream { rx: event_rx };
485
486    Sse::new(stream)
487        .keep_alive(
488            KeepAlive::new()
489                .interval(Duration::from_secs(keepalive_secs))
490                .text("ping"),
491        )
492        .into_response()
493}
494
495/// Convert realtime message to SSE message.
496fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
497    match msg {
498        RealtimeMessage::Data {
499            subscription_id,
500            data,
501        } => Some(SseMessage::Data {
502            target: format!("sub:{}", subscription_id),
503            payload: data,
504        }),
505        RealtimeMessage::DeltaUpdate {
506            subscription_id,
507            delta,
508        } => match serde_json::to_value(&delta) {
509            Ok(payload) => Some(SseMessage::Data {
510                target: format!("sub:{}", subscription_id),
511                payload,
512            }),
513            Err(e) => {
514                tracing::error!("Failed to serialize delta update: {}", e);
515                Some(SseMessage::Error {
516                    target: format!("sub:{}", subscription_id),
517                    code: "SERIALIZATION_ERROR".to_string(),
518                    message: "Failed to serialize update data".to_string(),
519                })
520            }
521        },
522        RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
523            Ok(payload) => Some(SseMessage::Data {
524                target: format!("job:{}", client_sub_id),
525                payload,
526            }),
527            Err(e) => {
528                tracing::error!("Failed to serialize job update: {}", e);
529                Some(SseMessage::Error {
530                    target: format!("job:{}", client_sub_id),
531                    code: "SERIALIZATION_ERROR".to_string(),
532                    message: "Failed to serialize job update".to_string(),
533                })
534            }
535        },
536        RealtimeMessage::WorkflowUpdate {
537            client_sub_id,
538            workflow,
539        } => match serde_json::to_value(&workflow) {
540            Ok(payload) => Some(SseMessage::Data {
541                target: format!("wf:{}", client_sub_id),
542                payload,
543            }),
544            Err(e) => {
545                tracing::error!("Failed to serialize workflow update: {}", e);
546                Some(SseMessage::Error {
547                    target: format!("wf:{}", client_sub_id),
548                    code: "SERIALIZATION_ERROR".to_string(),
549                    message: "Failed to serialize workflow update".to_string(),
550                })
551            }
552        },
553        RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
554            target: String::new(),
555            code,
556            message,
557        }),
558        RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
559            target: id,
560            code,
561            message,
562        }),
563        // Ignore control messages
564        RealtimeMessage::Subscribe { .. }
565        | RealtimeMessage::Unsubscribe { .. }
566        | RealtimeMessage::Ping
567        | RealtimeMessage::Pong
568        | RealtimeMessage::AuthSuccess
569        | RealtimeMessage::AuthFailed { .. }
570        | RealtimeMessage::Lagging => None,
571    }
572}
573
574/// SSE subscribe handler for POST /subscribe.
575pub async fn sse_subscribe_handler(
576    State(state): State<Arc<SseState>>,
577    Extension(request_auth): Extension<AuthContext>,
578    Json(request): Json<SseSubscribeRequest>,
579) -> impl IntoResponse {
580    // Validate subscription ID length to prevent memory bloat
581    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
582        return subscribe_error(
583            StatusCode::BAD_REQUEST,
584            "INVALID_ID",
585            format!(
586                "Subscription ID too long (max {} chars)",
587                MAX_CLIENT_SUB_ID_LEN
588            ),
589        );
590    }
591
592    let Some(session_id) = try_parse_session_id(&request.session_id) else {
593        return subscribe_error(
594            StatusCode::BAD_REQUEST,
595            "INVALID_SESSION",
596            "Invalid session ID format",
597        );
598    };
599
600    // Get session data (auth context)
601    let sessions = state.sessions.read().await;
602    let session_data = match sessions.get(&session_id) {
603        Some(data) => {
604            match authorize_session_access(data, &request.session_secret, &request_auth) {
605                Ok(auth) => auth,
606                Err(resp) => return resp,
607            }
608        }
609        None => {
610            return subscribe_error(
611                StatusCode::NOT_FOUND,
612                "SESSION_NOT_FOUND",
613                "Session not found or expired",
614            );
615        }
616    };
617    drop(sessions);
618
619    // Subscribe via reactor
620    let result = state
621        .reactor
622        .subscribe(
623            session_id,
624            request.id.clone(),
625            request.function,
626            request.args,
627            session_data,
628        )
629        .await;
630
631    match result {
632        Ok((subscription_id, data)) => {
633            // Store the subscription mapping
634            let mut sessions = state.sessions.write().await;
635            match sessions.get_mut(&session_id) {
636                Some(session) => {
637                    session.subscriptions.insert(request.id, subscription_id);
638                }
639                None => {
640                    // Session was removed between read and write lock
641                    return subscribe_error(
642                        StatusCode::NOT_FOUND,
643                        "SESSION_NOT_FOUND",
644                        "Session expired during subscription",
645                    );
646                }
647            }
648
649            tracing::debug!(
650                %session_id,
651                %subscription_id,
652                "SSE subscription registered"
653            );
654
655            (
656                StatusCode::OK,
657                Json(SseSubscribeResponse {
658                    success: true,
659                    data: Some(data),
660                    error: None,
661                }),
662            )
663        }
664        Err(e) => {
665            tracing::warn!(%session_id, error = %e, "SSE subscription failed");
666            match e {
667                forge_core::ForgeError::Unauthorized(msg) => {
668                    subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
669                }
670                forge_core::ForgeError::Forbidden(msg) => {
671                    subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
672                }
673                forge_core::ForgeError::InvalidArgument(msg)
674                | forge_core::ForgeError::Validation(msg) => {
675                    subscribe_error(StatusCode::BAD_REQUEST, "INVALID_ARGUMENT", msg)
676                }
677                forge_core::ForgeError::NotFound(msg) => {
678                    subscribe_error(StatusCode::NOT_FOUND, "NOT_FOUND", msg)
679                }
680                _ => subscribe_error(
681                    StatusCode::INTERNAL_SERVER_ERROR,
682                    "SUBSCRIPTION_FAILED",
683                    "Subscription failed",
684                ),
685            }
686        }
687    }
688}
689
690/// SSE unsubscribe handler for POST /unsubscribe.
691pub async fn sse_unsubscribe_handler(
692    State(state): State<Arc<SseState>>,
693    Extension(request_auth): Extension<AuthContext>,
694    Json(request): Json<SseUnsubscribeRequest>,
695) -> impl IntoResponse {
696    let Some(session_id) = try_parse_session_id(&request.session_id) else {
697        return unsubscribe_error(
698            StatusCode::BAD_REQUEST,
699            "INVALID_SESSION",
700            "Invalid session ID format",
701        );
702    };
703
704    // Look up internal subscription ID and validate session ownership
705    let subscription_id = {
706        let sessions = state.sessions.read().await;
707        match sessions.get(&session_id) {
708            Some(session) => {
709                if session.session_secret != request.session_secret
710                    || !same_principal(&session.auth_context, &request_auth)
711                {
712                    return unsubscribe_error(
713                        StatusCode::FORBIDDEN,
714                        "SESSION_PRINCIPAL_MISMATCH",
715                        "Request principal does not match session principal",
716                    );
717                }
718                session.subscriptions.get(&request.id).copied()
719            }
720            None => None,
721        }
722    };
723
724    let Some(subscription_id) = subscription_id else {
725        return unsubscribe_error(
726            StatusCode::NOT_FOUND,
727            "SUBSCRIPTION_NOT_FOUND",
728            "Subscription not found",
729        );
730    };
731
732    // Unsubscribe via reactor
733    state.reactor.unsubscribe(subscription_id);
734
735    // Remove from session tracking
736    {
737        let mut sessions = state.sessions.write().await;
738        if let Some(session) = sessions.get_mut(&session_id) {
739            session.subscriptions.remove(&request.id);
740        }
741    }
742
743    tracing::debug!(
744        %session_id,
745        %subscription_id,
746        "SSE subscription removed"
747    );
748
749    (
750        StatusCode::OK,
751        Json(SseUnsubscribeResponse {
752            success: true,
753            error: None,
754        }),
755    )
756}
757
758/// SSE job subscribe handler for POST /subscribe-job.
759pub async fn sse_job_subscribe_handler(
760    State(state): State<Arc<SseState>>,
761    Extension(request_auth): Extension<AuthContext>,
762    Json(request): Json<SseJobSubscribeRequest>,
763) -> impl IntoResponse {
764    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
765        return subscribe_error(
766            StatusCode::BAD_REQUEST,
767            "INVALID_ID",
768            format!(
769                "Subscription ID too long (max {} chars)",
770                MAX_CLIENT_SUB_ID_LEN
771            ),
772        );
773    }
774
775    let Some(session_id) = try_parse_session_id(&request.session_id) else {
776        return subscribe_error(
777            StatusCode::BAD_REQUEST,
778            "INVALID_SESSION",
779            "Invalid session ID format",
780        );
781    };
782
783    // Validate session exists + principal binding
784    let session_auth = {
785        let sessions = state.sessions.read().await;
786        match sessions.get(&session_id) {
787            Some(session) => {
788                match authorize_session_access(session, &request.session_secret, &request_auth) {
789                    Ok(auth) => auth,
790                    Err(resp) => return resp,
791                }
792            }
793            None => {
794                return subscribe_error(
795                    StatusCode::NOT_FOUND,
796                    "SESSION_NOT_FOUND",
797                    "Session not found or expired",
798                );
799            }
800        }
801    };
802
803    // Parse job ID
804    let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
805        Ok(uuid) => uuid,
806        Err(_) => {
807            return subscribe_error(
808                StatusCode::BAD_REQUEST,
809                "INVALID_JOB_ID",
810                "Invalid job ID format",
811            );
812        }
813    };
814
815    // Subscribe to job updates via reactor
816    match state
817        .reactor
818        .subscribe_job(session_id, request.id.clone(), job_uuid, &session_auth)
819        .await
820    {
821        Ok(job_data) => {
822            let data = match serde_json::to_value(&job_data) {
823                Ok(v) => v,
824                Err(e) => {
825                    tracing::error!("Failed to serialize job data: {}", e);
826                    return subscribe_error(
827                        StatusCode::INTERNAL_SERVER_ERROR,
828                        "SERIALIZE_ERROR",
829                        "Failed to serialize job data",
830                    );
831                }
832            };
833            tracing::debug!(
834                %session_id,
835                job_id = %request.job_id,
836                client_sub_id = %request.id,
837                "SSE job subscription registered"
838            );
839            (
840                StatusCode::OK,
841                Json(SseSubscribeResponse {
842                    success: true,
843                    data: Some(data),
844                    error: None,
845                }),
846            )
847        }
848        Err(e) => match e {
849            forge_core::ForgeError::Unauthorized(msg) => {
850                subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
851            }
852            forge_core::ForgeError::Forbidden(msg) => {
853                subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
854            }
855            forge_core::ForgeError::NotFound(msg) => {
856                subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", msg)
857            }
858            _ => subscribe_error(
859                StatusCode::INTERNAL_SERVER_ERROR,
860                "SUBSCRIPTION_FAILED",
861                "Subscription failed",
862            ),
863        },
864    }
865}
866
867/// SSE workflow subscribe handler for POST /subscribe-workflow.
868pub async fn sse_workflow_subscribe_handler(
869    State(state): State<Arc<SseState>>,
870    Extension(request_auth): Extension<AuthContext>,
871    Json(request): Json<SseWorkflowSubscribeRequest>,
872) -> impl IntoResponse {
873    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
874        return subscribe_error(
875            StatusCode::BAD_REQUEST,
876            "INVALID_ID",
877            format!(
878                "Subscription ID too long (max {} chars)",
879                MAX_CLIENT_SUB_ID_LEN
880            ),
881        );
882    }
883
884    let Some(session_id) = try_parse_session_id(&request.session_id) else {
885        return subscribe_error(
886            StatusCode::BAD_REQUEST,
887            "INVALID_SESSION",
888            "Invalid session ID format",
889        );
890    };
891
892    // Validate session exists + principal binding
893    let session_auth = {
894        let sessions = state.sessions.read().await;
895        match sessions.get(&session_id) {
896            Some(session) => {
897                match authorize_session_access(session, &request.session_secret, &request_auth) {
898                    Ok(auth) => auth,
899                    Err(resp) => return resp,
900                }
901            }
902            None => {
903                return subscribe_error(
904                    StatusCode::NOT_FOUND,
905                    "SESSION_NOT_FOUND",
906                    "Session not found or expired",
907                );
908            }
909        }
910    };
911
912    // Parse workflow ID
913    let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
914        Ok(uuid) => uuid,
915        Err(_) => {
916            return subscribe_error(
917                StatusCode::BAD_REQUEST,
918                "INVALID_WORKFLOW_ID",
919                "Invalid workflow ID format",
920            );
921        }
922    };
923
924    // Subscribe to workflow updates via reactor
925    match state
926        .reactor
927        .subscribe_workflow(session_id, request.id.clone(), workflow_uuid, &session_auth)
928        .await
929    {
930        Ok(workflow_data) => {
931            let data = match serde_json::to_value(&workflow_data) {
932                Ok(v) => v,
933                Err(e) => {
934                    tracing::error!("Failed to serialize workflow data: {}", e);
935                    return subscribe_error(
936                        StatusCode::INTERNAL_SERVER_ERROR,
937                        "SERIALIZE_ERROR",
938                        "Failed to serialize workflow data",
939                    );
940                }
941            };
942            tracing::debug!(
943                %session_id,
944                workflow_id = %request.workflow_id,
945                client_sub_id = %request.id,
946                "SSE workflow subscription registered"
947            );
948            (
949                StatusCode::OK,
950                Json(SseSubscribeResponse {
951                    success: true,
952                    data: Some(data),
953                    error: None,
954                }),
955            )
956        }
957        Err(e) => match e {
958            forge_core::ForgeError::Unauthorized(msg) => {
959                subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
960            }
961            forge_core::ForgeError::Forbidden(msg) => {
962                subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
963            }
964            forge_core::ForgeError::NotFound(msg) => {
965                subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", msg)
966            }
967            _ => subscribe_error(
968                StatusCode::INTERNAL_SERVER_ERROR,
969                "SUBSCRIPTION_FAILED",
970                "Subscription failed",
971            ),
972        },
973    }
974}
975
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// An `Update` payload serializes with the snake_case `type` tag and
    /// keeps its target.
    #[test]
    fn test_sse_payload_serialization() {
        let update = SsePayload::Update {
            target: String::from("sub:123"),
            payload: serde_json::json!({"id": 1}),
        };
        let encoded = serde_json::to_string(&update).unwrap();
        for needle in ["\"type\":\"update\"", "\"target\":\"sub:123\""] {
            assert!(encoded.contains(needle));
        }
    }

    /// An `Error` payload serializes with the snake_case `type` tag and
    /// includes its error code.
    #[test]
    fn test_sse_error_serialization() {
        let failure = SsePayload::Error {
            target: String::from("sub:456"),
            code: String::from("NOT_FOUND"),
            message: String::from("Subscription not found"),
        };
        let encoded = serde_json::to_string(&failure).unwrap();
        for needle in ["\"type\":\"error\"", "NOT_FOUND"] {
            assert!(encoded.contains(needle));
        }
    }
}