Skip to main content

forge_runtime/gateway/
sse.rs

1//! Server-Sent Events (SSE) handler for real-time updates.
2
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use axum::Json;
11use axum::extract::{Extension, Query, State};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::response::sse::{Event, KeepAlive, Sse};
15use futures_util::Stream;
16use serde::{Deserialize, Serialize};
17use tokio::sync::{RwLock, mpsc};
18use tokio_util::sync::CancellationToken;
19
20/// Wraps an mpsc::Receiver as a Stream for SSE.
21struct ReceiverStream<T> {
22    rx: mpsc::Receiver<T>,
23}
24
25impl<T> Stream for ReceiverStream<T> {
26    type Item = T;
27    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
28        self.rx.poll_recv(cx)
29    }
30}
31
32use forge_core::function::AuthContext;
33use forge_core::realtime::{SessionId, SubscriptionId};
34
35use super::auth::AuthMiddleware;
36use crate::realtime::Reactor;
37use crate::realtime::RealtimeMessage;
38
/// Upper bound on the length, in bytes (`String::len`), of a client-chosen subscription ID.
///
/// Client IDs are kept as map keys for the lifetime of a session, so this cap keeps a
/// single client from inflating server memory with oversized identifiers.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
41
42fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
43    uuid::Uuid::parse_str(session_id)
44        .ok()
45        .map(SessionId::from_uuid)
46}
47
48fn same_principal(a: &AuthContext, b: &AuthContext) -> bool {
49    match (a.is_authenticated(), b.is_authenticated()) {
50        (false, false) => true,
51        (true, true) => a.principal_id().is_some() && a.principal_id() == b.principal_id(),
52        _ => false,
53    }
54}
55
56fn authorize_session_access(
57    session: &SseSessionData,
58    session_secret: &str,
59    requester_auth: &AuthContext,
60) -> Result<AuthContext, (StatusCode, Json<SseSubscribeResponse>)> {
61    if session.session_secret != session_secret {
62        return Err(subscribe_error(
63            StatusCode::UNAUTHORIZED,
64            "INVALID_SESSION_SECRET",
65            "Session secret mismatch",
66        ));
67    }
68
69    if !same_principal(&session.auth_context, requester_auth) {
70        return Err(subscribe_error(
71            StatusCode::FORBIDDEN,
72            "SESSION_PRINCIPAL_MISMATCH",
73            "Request principal does not match session principal",
74        ));
75    }
76
77    Ok(session.auth_context.clone())
78}
79
/// Tunables for the SSE gateway.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Upper limit on concurrently open SSE sessions for this server instance.
    pub max_sessions: usize,
    /// Capacity of the bounded per-session message channels.
    pub channel_buffer_size: usize,
    /// Seconds between keepalive messages sent on the event stream.
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    /// 10 000 sessions, 256-message channel buffers, and a 30-second keepalive.
    fn default() -> Self {
        let max_sessions = 10_000;
        let channel_buffer_size = 256;
        let keepalive_interval_secs = 30;
        Self {
            max_sessions,
            channel_buffer_size,
            keepalive_interval_secs,
        }
    }
}
100
101/// SSE query parameters.
102#[derive(Debug, Deserialize)]
103pub struct SseQuery {
104    /// Authentication token.
105    pub token: Option<String>,
106}
107
108struct SseSessionData {
109    auth_context: AuthContext,
110    session_secret: String,
111    /// Maps client subscription ID -> internal SubscriptionId
112    subscriptions: HashMap<String, SubscriptionId>,
113}
114
115/// State for SSE handler.
116#[derive(Clone)]
117pub struct SseState {
118    reactor: Arc<Reactor>,
119    auth_middleware: Arc<AuthMiddleware>,
120    /// Per-session data: auth context and subscription mappings
121    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
122    config: SseConfig,
123}
124
125impl SseState {
126    /// Create new SSE state with default config.
127    pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
128        Self::with_config(reactor, auth_middleware, SseConfig::default())
129    }
130
131    /// Create new SSE state with custom config.
132    pub fn with_config(
133        reactor: Arc<Reactor>,
134        auth_middleware: Arc<AuthMiddleware>,
135        config: SseConfig,
136    ) -> Self {
137        Self {
138            reactor,
139            auth_middleware,
140            sessions: Arc::new(RwLock::new(HashMap::new())),
141            config,
142        }
143    }
144
145    /// Check if we can accept new sessions.
146    pub async fn can_accept_session(&self) -> bool {
147        self.sessions.read().await.len() < self.config.max_sessions
148    }
149}
150
151/// Guard to ensure session cleanup on disconnect.
152/// Spawns a cleanup task on drop to handle abrupt disconnects.
153struct SessionCleanupGuard {
154    session_id: SessionId,
155    reactor: Arc<Reactor>,
156    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
157    dropped: bool,
158}
159
160impl SessionCleanupGuard {
161    fn new(
162        session_id: SessionId,
163        reactor: Arc<Reactor>,
164        sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
165    ) -> Self {
166        Self {
167            session_id,
168            reactor,
169            sessions,
170            dropped: false,
171        }
172    }
173
174    /// Mark as cleanly closed (cleanup will be skipped).
175    fn mark_closed(&mut self) {
176        self.dropped = true;
177    }
178}
179
180impl Drop for SessionCleanupGuard {
181    fn drop(&mut self) {
182        if self.dropped {
183            return;
184        }
185        let session_id = self.session_id;
186        let reactor = self.reactor.clone();
187        let sessions = self.sessions.clone();
188
189        // Spawn cleanup task since we can't await in drop
190        // Use spawn to handle cleanup even if the runtime is shutting down
191        crate::observability::set_active_connections("sse", -1);
192        if let Ok(handle) = tokio::runtime::Handle::try_current() {
193            handle.spawn(async move {
194                reactor.remove_session(session_id).await;
195                sessions.write().await.remove(&session_id);
196                tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
197            });
198        } else {
199            // Runtime not available, likely shutting down. Session will be cleaned up on restart.
200            tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
201        }
202    }
203}
204
205/// SSE event payload sent to clients.
206#[derive(Debug, Serialize)]
207#[serde(tag = "type", rename_all = "snake_case")]
208pub enum SsePayload {
209    /// Subscription data update.
210    Update {
211        target: String,
212        payload: serde_json::Value,
213    },
214    /// Error message.
215    Error {
216        target: String,
217        code: String,
218        message: String,
219    },
220    /// Connection acknowledged.
221    Connected {
222        session_id: String,
223        session_secret: String,
224    },
225}
226
227/// Internal message type for SSE stream.
228#[derive(Debug)]
229pub enum SseMessage {
230    Data {
231        target: String,
232        payload: serde_json::Value,
233    },
234    Error {
235        target: String,
236        code: String,
237        message: String,
238    },
239}
240
241/// SSE subscribe request.
242#[derive(Debug, Deserialize)]
243pub struct SseSubscribeRequest {
244    pub session_id: String,
245    pub session_secret: String,
246    pub id: String,
247    pub function: String,
248    #[serde(default)]
249    pub args: serde_json::Value,
250}
251
252/// SSE unsubscribe request.
253#[derive(Debug, Deserialize)]
254pub struct SseUnsubscribeRequest {
255    pub session_id: String,
256    pub session_secret: String,
257    pub id: String,
258}
259
260/// SSE job subscribe request.
261#[derive(Debug, Deserialize)]
262pub struct SseJobSubscribeRequest {
263    pub session_id: String,
264    pub session_secret: String,
265    pub id: String,
266    pub job_id: String,
267}
268
269/// SSE workflow subscribe request.
270#[derive(Debug, Deserialize)]
271pub struct SseWorkflowSubscribeRequest {
272    pub session_id: String,
273    pub session_secret: String,
274    pub id: String,
275    pub workflow_id: String,
276}
277
278/// SSE error response.
279#[derive(Debug, Serialize)]
280pub struct SseError {
281    pub code: String,
282    pub message: String,
283}
284
285/// SSE subscribe response.
286#[derive(Debug, Serialize)]
287pub struct SseSubscribeResponse {
288    pub success: bool,
289    #[serde(skip_serializing_if = "Option::is_none")]
290    pub data: Option<serde_json::Value>,
291    #[serde(skip_serializing_if = "Option::is_none")]
292    pub error: Option<SseError>,
293}
294
295/// SSE unsubscribe response.
296#[derive(Debug, Serialize)]
297pub struct SseUnsubscribeResponse {
298    pub success: bool,
299    #[serde(skip_serializing_if = "Option::is_none")]
300    pub error: Option<SseError>,
301}
302
303/// Create an SSE subscribe error response.
304fn subscribe_error(
305    status: StatusCode,
306    code: impl Into<String>,
307    message: impl Into<String>,
308) -> (StatusCode, Json<SseSubscribeResponse>) {
309    (
310        status,
311        Json(SseSubscribeResponse {
312            success: false,
313            data: None,
314            error: Some(SseError {
315                code: code.into(),
316                message: message.into(),
317            }),
318        }),
319    )
320}
321
322/// Create an SSE unsubscribe error response.
323fn unsubscribe_error(
324    status: StatusCode,
325    code: impl Into<String>,
326    message: impl Into<String>,
327) -> (StatusCode, Json<SseUnsubscribeResponse>) {
328    (
329        status,
330        Json(SseUnsubscribeResponse {
331            success: false,
332            error: Some(SseError {
333                code: code.into(),
334                message: message.into(),
335            }),
336        }),
337    )
338}
339
340/// SSE handler for GET /events.
341pub async fn sse_handler(
342    State(state): State<Arc<SseState>>,
343    Query(query): Query<SseQuery>,
344) -> impl IntoResponse {
345    // Check session limit
346    if !state.can_accept_session().await {
347        return (
348            StatusCode::SERVICE_UNAVAILABLE,
349            "Server at capacity".to_string(),
350        )
351            .into_response();
352    }
353
354    let session_id = SessionId::new();
355    let buffer_size = state.config.channel_buffer_size;
356    let keepalive_secs = state.config.keepalive_interval_secs;
357    let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
358    let cancel_token = CancellationToken::new();
359
360    let auth_context = if let Some(token) = &query.token {
361        match state.auth_middleware.validate_token_async(token).await {
362            Ok(claims) => super::auth::build_auth_context_from_claims(claims),
363            Err(e) => {
364                tracing::warn!("SSE token validation failed: {}", e);
365                return (
366                    StatusCode::UNAUTHORIZED,
367                    "Invalid authentication token".to_string(),
368                )
369                    .into_response();
370            }
371        }
372    } else {
373        forge_core::function::AuthContext::unauthenticated()
374    };
375    let session_secret = uuid::Uuid::new_v4().to_string();
376
377    // Register session with reactor
378    let reactor = state.reactor.clone();
379    let cancel = cancel_token.clone();
380
381    // Create a bridge channel for the reactor's message format
382    let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
383    reactor.register_session(session_id, rt_tx);
384
385    // Store session data for subscription handlers
386    {
387        let mut sessions = state.sessions.write().await;
388        sessions.insert(
389            session_id,
390            SseSessionData {
391                auth_context: auth_context.clone(),
392                session_secret: session_secret.clone(),
393                subscriptions: HashMap::new(),
394            },
395        );
396    }
397
398    // Capture sessions for cleanup guard
399    let sessions = state.sessions.clone();
400
401    // Create cleanup guard - will clean up on drop if stream ends unexpectedly
402    let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
403    crate::observability::set_active_connections("sse", 1);
404
405    // Bridge reactor messages to SSE messages
406    let bridge_cancel = cancel_token.clone();
407    tokio::spawn(async move {
408        loop {
409            tokio::select! {
410                msg = rt_rx.recv() => {
411                    match msg {
412                        Some(rt_msg) => {
413                            if let Some(sse_msg) = convert_realtime_to_sse(rt_msg)
414                                && tx.send(sse_msg).await.is_err() {
415                                    break;
416                                }
417                        }
418                        None => break,
419                    }
420                }
421                _ = bridge_cancel.cancelled() => break,
422            }
423        }
424    });
425
426    // Spawn a task that feeds SSE events into a channel
427    let (event_tx, event_rx) = mpsc::channel::<Result<Event, Infallible>>(buffer_size);
428
429    tokio::spawn(async move {
430        let mut _guard = cleanup_guard;
431
432        // Send connected event
433        let connected = SsePayload::Connected {
434            session_id: session_id.to_string(),
435            session_secret: session_secret.clone(),
436        };
437        match serde_json::to_string(&connected) {
438            Ok(json) => {
439                let _ = event_tx
440                    .send(Ok(Event::default().event("connected").data(json)))
441                    .await;
442            }
443            Err(e) => {
444                tracing::error!("Failed to serialize SSE connected payload: {}", e);
445            }
446        }
447
448        loop {
449            tokio::select! {
450                msg = rx.recv() => {
451                    match msg {
452                        Some(sse_msg) => {
453                            let event = match sse_msg {
454                                SseMessage::Data { target, payload } => {
455                                    let data = SsePayload::Update { target, payload };
456                                    serde_json::to_string(&data).ok().map(|json| {
457                                        Event::default().event("update").data(json)
458                                    })
459                                }
460                                SseMessage::Error { target, code, message } => {
461                                    let data = SsePayload::Error { target, code, message };
462                                    serde_json::to_string(&data).ok().map(|json| {
463                                        Event::default().event("error").data(json)
464                                    })
465                                }
466                            };
467                            if let Some(evt) = event
468                                && event_tx.send(Ok(evt)).await.is_err()
469                            {
470                                break;
471                            }
472                        }
473                        None => break,
474                    }
475                }
476                _ = cancel.cancelled() => break,
477            }
478        }
479
480        // Clean shutdown
481        _guard.mark_closed();
482        reactor.remove_session(session_id).await;
483        sessions.write().await.remove(&session_id);
484    });
485
486    let stream = ReceiverStream { rx: event_rx };
487
488    Sse::new(stream)
489        .keep_alive(
490            KeepAlive::new()
491                .interval(Duration::from_secs(keepalive_secs))
492                .text("ping"),
493        )
494        .into_response()
495}
496
497/// Convert realtime message to SSE message.
498fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
499    match msg {
500        RealtimeMessage::Data {
501            subscription_id,
502            data,
503        } => Some(SseMessage::Data {
504            target: format!("sub:{}", subscription_id),
505            payload: data,
506        }),
507        RealtimeMessage::DeltaUpdate {
508            subscription_id,
509            delta,
510        } => match serde_json::to_value(&delta) {
511            Ok(payload) => Some(SseMessage::Data {
512                target: format!("sub:{}", subscription_id),
513                payload,
514            }),
515            Err(e) => {
516                tracing::error!("Failed to serialize delta update: {}", e);
517                Some(SseMessage::Error {
518                    target: format!("sub:{}", subscription_id),
519                    code: "SERIALIZATION_ERROR".to_string(),
520                    message: "Failed to serialize update data".to_string(),
521                })
522            }
523        },
524        RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
525            Ok(payload) => Some(SseMessage::Data {
526                target: format!("job:{}", client_sub_id),
527                payload,
528            }),
529            Err(e) => {
530                tracing::error!("Failed to serialize job update: {}", e);
531                Some(SseMessage::Error {
532                    target: format!("job:{}", client_sub_id),
533                    code: "SERIALIZATION_ERROR".to_string(),
534                    message: "Failed to serialize job update".to_string(),
535                })
536            }
537        },
538        RealtimeMessage::WorkflowUpdate {
539            client_sub_id,
540            workflow,
541        } => match serde_json::to_value(&workflow) {
542            Ok(payload) => Some(SseMessage::Data {
543                target: format!("wf:{}", client_sub_id),
544                payload,
545            }),
546            Err(e) => {
547                tracing::error!("Failed to serialize workflow update: {}", e);
548                Some(SseMessage::Error {
549                    target: format!("wf:{}", client_sub_id),
550                    code: "SERIALIZATION_ERROR".to_string(),
551                    message: "Failed to serialize workflow update".to_string(),
552                })
553            }
554        },
555        RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
556            target: String::new(),
557            code,
558            message,
559        }),
560        RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
561            target: id,
562            code,
563            message,
564        }),
565        // Ignore control messages
566        RealtimeMessage::Subscribe { .. }
567        | RealtimeMessage::Unsubscribe { .. }
568        | RealtimeMessage::Ping
569        | RealtimeMessage::Pong
570        | RealtimeMessage::AuthSuccess
571        | RealtimeMessage::AuthFailed { .. }
572        | RealtimeMessage::Lagging => None,
573    }
574}
575
576/// SSE subscribe handler for POST /subscribe.
577pub async fn sse_subscribe_handler(
578    State(state): State<Arc<SseState>>,
579    Extension(request_auth): Extension<AuthContext>,
580    Json(request): Json<SseSubscribeRequest>,
581) -> impl IntoResponse {
582    // Validate subscription ID length to prevent memory bloat
583    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
584        return subscribe_error(
585            StatusCode::BAD_REQUEST,
586            "INVALID_ID",
587            format!(
588                "Subscription ID too long (max {} chars)",
589                MAX_CLIENT_SUB_ID_LEN
590            ),
591        );
592    }
593
594    let Some(session_id) = try_parse_session_id(&request.session_id) else {
595        return subscribe_error(
596            StatusCode::BAD_REQUEST,
597            "INVALID_SESSION",
598            "Invalid session ID format",
599        );
600    };
601
602    // Get session data (auth context)
603    let sessions = state.sessions.read().await;
604    let session_data = match sessions.get(&session_id) {
605        Some(data) => {
606            match authorize_session_access(data, &request.session_secret, &request_auth) {
607                Ok(auth) => auth,
608                Err(resp) => return resp,
609            }
610        }
611        None => {
612            return subscribe_error(
613                StatusCode::NOT_FOUND,
614                "SESSION_NOT_FOUND",
615                "Session not found or expired",
616            );
617        }
618    };
619    drop(sessions);
620
621    // Subscribe via reactor
622    let result = state
623        .reactor
624        .subscribe(
625            session_id,
626            request.id.clone(),
627            request.function,
628            request.args,
629            session_data,
630        )
631        .await;
632
633    match result {
634        Ok((subscription_id, data)) => {
635            // Store the subscription mapping
636            let mut sessions = state.sessions.write().await;
637            match sessions.get_mut(&session_id) {
638                Some(session) => {
639                    session.subscriptions.insert(request.id, subscription_id);
640                }
641                None => {
642                    // Session was removed between read and write lock
643                    return subscribe_error(
644                        StatusCode::NOT_FOUND,
645                        "SESSION_NOT_FOUND",
646                        "Session expired during subscription",
647                    );
648                }
649            }
650
651            tracing::debug!(
652                %session_id,
653                %subscription_id,
654                "SSE subscription registered"
655            );
656
657            (
658                StatusCode::OK,
659                Json(SseSubscribeResponse {
660                    success: true,
661                    data: Some(data),
662                    error: None,
663                }),
664            )
665        }
666        Err(e) => {
667            tracing::warn!(%session_id, error = %e, "SSE subscription failed");
668            match e {
669                forge_core::ForgeError::Unauthorized(msg) => {
670                    subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
671                }
672                forge_core::ForgeError::Forbidden(msg) => {
673                    subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
674                }
675                forge_core::ForgeError::InvalidArgument(msg)
676                | forge_core::ForgeError::Validation(msg) => {
677                    subscribe_error(StatusCode::BAD_REQUEST, "INVALID_ARGUMENT", msg)
678                }
679                forge_core::ForgeError::NotFound(msg) => {
680                    subscribe_error(StatusCode::NOT_FOUND, "NOT_FOUND", msg)
681                }
682                _ => subscribe_error(
683                    StatusCode::INTERNAL_SERVER_ERROR,
684                    "SUBSCRIPTION_FAILED",
685                    "Subscription failed",
686                ),
687            }
688        }
689    }
690}
691
692/// SSE unsubscribe handler for POST /unsubscribe.
693pub async fn sse_unsubscribe_handler(
694    State(state): State<Arc<SseState>>,
695    Extension(request_auth): Extension<AuthContext>,
696    Json(request): Json<SseUnsubscribeRequest>,
697) -> impl IntoResponse {
698    let Some(session_id) = try_parse_session_id(&request.session_id) else {
699        return unsubscribe_error(
700            StatusCode::BAD_REQUEST,
701            "INVALID_SESSION",
702            "Invalid session ID format",
703        );
704    };
705
706    // Look up internal subscription ID and validate session ownership
707    let subscription_id = {
708        let sessions = state.sessions.read().await;
709        match sessions.get(&session_id) {
710            Some(session) => {
711                if session.session_secret != request.session_secret
712                    || !same_principal(&session.auth_context, &request_auth)
713                {
714                    return unsubscribe_error(
715                        StatusCode::FORBIDDEN,
716                        "SESSION_PRINCIPAL_MISMATCH",
717                        "Request principal does not match session principal",
718                    );
719                }
720                session.subscriptions.get(&request.id).copied()
721            }
722            None => None,
723        }
724    };
725
726    let Some(subscription_id) = subscription_id else {
727        return unsubscribe_error(
728            StatusCode::NOT_FOUND,
729            "SUBSCRIPTION_NOT_FOUND",
730            "Subscription not found",
731        );
732    };
733
734    // Unsubscribe via reactor
735    state.reactor.unsubscribe(subscription_id);
736
737    // Remove from session tracking
738    {
739        let mut sessions = state.sessions.write().await;
740        if let Some(session) = sessions.get_mut(&session_id) {
741            session.subscriptions.remove(&request.id);
742        }
743    }
744
745    tracing::debug!(
746        %session_id,
747        %subscription_id,
748        "SSE subscription removed"
749    );
750
751    (
752        StatusCode::OK,
753        Json(SseUnsubscribeResponse {
754            success: true,
755            error: None,
756        }),
757    )
758}
759
760/// SSE job subscribe handler for POST /subscribe-job.
761pub async fn sse_job_subscribe_handler(
762    State(state): State<Arc<SseState>>,
763    Extension(request_auth): Extension<AuthContext>,
764    Json(request): Json<SseJobSubscribeRequest>,
765) -> impl IntoResponse {
766    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
767        return subscribe_error(
768            StatusCode::BAD_REQUEST,
769            "INVALID_ID",
770            format!(
771                "Subscription ID too long (max {} chars)",
772                MAX_CLIENT_SUB_ID_LEN
773            ),
774        );
775    }
776
777    let Some(session_id) = try_parse_session_id(&request.session_id) else {
778        return subscribe_error(
779            StatusCode::BAD_REQUEST,
780            "INVALID_SESSION",
781            "Invalid session ID format",
782        );
783    };
784
785    // Validate session exists + principal binding
786    let session_auth = {
787        let sessions = state.sessions.read().await;
788        match sessions.get(&session_id) {
789            Some(session) => {
790                match authorize_session_access(session, &request.session_secret, &request_auth) {
791                    Ok(auth) => auth,
792                    Err(resp) => return resp,
793                }
794            }
795            None => {
796                return subscribe_error(
797                    StatusCode::NOT_FOUND,
798                    "SESSION_NOT_FOUND",
799                    "Session not found or expired",
800                );
801            }
802        }
803    };
804
805    // Parse job ID
806    let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
807        Ok(uuid) => uuid,
808        Err(_) => {
809            return subscribe_error(
810                StatusCode::BAD_REQUEST,
811                "INVALID_JOB_ID",
812                "Invalid job ID format",
813            );
814        }
815    };
816
817    // Subscribe to job updates via reactor
818    match state
819        .reactor
820        .subscribe_job(session_id, request.id.clone(), job_uuid, &session_auth)
821        .await
822    {
823        Ok(job_data) => {
824            let data = match serde_json::to_value(&job_data) {
825                Ok(v) => v,
826                Err(e) => {
827                    tracing::error!("Failed to serialize job data: {}", e);
828                    return subscribe_error(
829                        StatusCode::INTERNAL_SERVER_ERROR,
830                        "SERIALIZE_ERROR",
831                        "Failed to serialize job data",
832                    );
833                }
834            };
835            tracing::debug!(
836                %session_id,
837                job_id = %request.job_id,
838                client_sub_id = %request.id,
839                "SSE job subscription registered"
840            );
841            (
842                StatusCode::OK,
843                Json(SseSubscribeResponse {
844                    success: true,
845                    data: Some(data),
846                    error: None,
847                }),
848            )
849        }
850        Err(e) => match e {
851            forge_core::ForgeError::Unauthorized(msg) => {
852                subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
853            }
854            forge_core::ForgeError::Forbidden(msg) => {
855                subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
856            }
857            forge_core::ForgeError::NotFound(msg) => {
858                subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", msg)
859            }
860            _ => subscribe_error(
861                StatusCode::INTERNAL_SERVER_ERROR,
862                "SUBSCRIPTION_FAILED",
863                "Subscription failed",
864            ),
865        },
866    }
867}
868
869/// SSE workflow subscribe handler for POST /subscribe-workflow.
870pub async fn sse_workflow_subscribe_handler(
871    State(state): State<Arc<SseState>>,
872    Extension(request_auth): Extension<AuthContext>,
873    Json(request): Json<SseWorkflowSubscribeRequest>,
874) -> impl IntoResponse {
875    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
876        return subscribe_error(
877            StatusCode::BAD_REQUEST,
878            "INVALID_ID",
879            format!(
880                "Subscription ID too long (max {} chars)",
881                MAX_CLIENT_SUB_ID_LEN
882            ),
883        );
884    }
885
886    let Some(session_id) = try_parse_session_id(&request.session_id) else {
887        return subscribe_error(
888            StatusCode::BAD_REQUEST,
889            "INVALID_SESSION",
890            "Invalid session ID format",
891        );
892    };
893
894    // Validate session exists + principal binding
895    let session_auth = {
896        let sessions = state.sessions.read().await;
897        match sessions.get(&session_id) {
898            Some(session) => {
899                match authorize_session_access(session, &request.session_secret, &request_auth) {
900                    Ok(auth) => auth,
901                    Err(resp) => return resp,
902                }
903            }
904            None => {
905                return subscribe_error(
906                    StatusCode::NOT_FOUND,
907                    "SESSION_NOT_FOUND",
908                    "Session not found or expired",
909                );
910            }
911        }
912    };
913
914    // Parse workflow ID
915    let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
916        Ok(uuid) => uuid,
917        Err(_) => {
918            return subscribe_error(
919                StatusCode::BAD_REQUEST,
920                "INVALID_WORKFLOW_ID",
921                "Invalid workflow ID format",
922            );
923        }
924    };
925
926    // Subscribe to workflow updates via reactor
927    match state
928        .reactor
929        .subscribe_workflow(session_id, request.id.clone(), workflow_uuid, &session_auth)
930        .await
931    {
932        Ok(workflow_data) => {
933            let data = match serde_json::to_value(&workflow_data) {
934                Ok(v) => v,
935                Err(e) => {
936                    tracing::error!("Failed to serialize workflow data: {}", e);
937                    return subscribe_error(
938                        StatusCode::INTERNAL_SERVER_ERROR,
939                        "SERIALIZE_ERROR",
940                        "Failed to serialize workflow data",
941                    );
942                }
943            };
944            tracing::debug!(
945                %session_id,
946                workflow_id = %request.workflow_id,
947                client_sub_id = %request.id,
948                "SSE workflow subscription registered"
949            );
950            (
951                StatusCode::OK,
952                Json(SseSubscribeResponse {
953                    success: true,
954                    data: Some(data),
955                    error: None,
956                }),
957            )
958        }
959        Err(e) => match e {
960            forge_core::ForgeError::Unauthorized(msg) => {
961                subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
962            }
963            forge_core::ForgeError::Forbidden(msg) => {
964                subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
965            }
966            forge_core::ForgeError::NotFound(msg) => {
967                subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", msg)
968            }
969            _ => subscribe_error(
970                StatusCode::INTERNAL_SERVER_ERROR,
971                "SUBSCRIPTION_FAILED",
972                "Subscription failed",
973            ),
974        },
975    }
976}
977
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Serializes `payload` to the JSON text sent on the wire.
    fn wire(payload: &SsePayload) -> String {
        serde_json::to_string(payload).unwrap()
    }

    /// Fails the test unless every fragment appears in `json`.
    fn assert_contains_all(json: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(json.contains(fragment), "expected {fragment} in {json}");
        }
    }

    #[test]
    fn test_sse_payload_serialization() {
        let json = wire(&SsePayload::Update {
            target: "sub:123".to_string(),
            payload: serde_json::json!({"id": 1}),
        });
        assert_contains_all(&json, &["\"type\":\"update\"", "\"target\":\"sub:123\""]);
    }

    #[test]
    fn test_sse_error_serialization() {
        let json = wire(&SsePayload::Error {
            target: "sub:456".to_string(),
            code: "NOT_FOUND".to_string(),
            message: "Subscription not found".to_string(),
        });
        assert_contains_all(&json, &["\"type\":\"error\"", "NOT_FOUND"]);
    }
}