Skip to main content

forge_runtime/gateway/
sse.rs

1//! Server-Sent Events (SSE) handler for real-time updates.
2
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::Json;
9use axum::extract::{Extension, Query, State};
10use axum::http::StatusCode;
11use axum::response::IntoResponse;
12use axum::response::sse::{Event, KeepAlive, Sse};
13use serde::{Deserialize, Serialize};
14use tokio::sync::{RwLock, mpsc};
15use tokio_util::sync::CancellationToken;
16
17use forge_core::function::AuthContext;
18use forge_core::realtime::{SessionId, SubscriptionId};
19
20use super::auth::AuthMiddleware;
21use crate::realtime::Reactor;
22use crate::realtime::RealtimeMessage;
23
/// Upper bound on the length of a client-chosen subscription ID.
///
/// IDs are stored per session for the session's lifetime, so this caps the
/// memory a single client can pin. Handlers compare it against `str::len`,
/// so the limit is measured in bytes.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
26
27fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
28    uuid::Uuid::parse_str(session_id)
29        .ok()
30        .map(SessionId::from_uuid)
31}
32
33fn same_principal(a: &AuthContext, b: &AuthContext) -> bool {
34    match (a.is_authenticated(), b.is_authenticated()) {
35        (false, false) => true,
36        (true, true) => a.principal_id().is_some() && a.principal_id() == b.principal_id(),
37        _ => false,
38    }
39}
40
41fn authorize_session_access(
42    session: &SseSessionData,
43    session_secret: &str,
44    requester_auth: &AuthContext,
45) -> Result<AuthContext, (StatusCode, Json<SseSubscribeResponse>)> {
46    if session.session_secret != session_secret {
47        return Err(subscribe_error(
48            StatusCode::UNAUTHORIZED,
49            "INVALID_SESSION_SECRET",
50            "Session secret mismatch",
51        ));
52    }
53
54    if !same_principal(&session.auth_context, requester_auth) {
55        return Err(subscribe_error(
56            StatusCode::FORBIDDEN,
57            "SESSION_PRINCIPAL_MISMATCH",
58            "Request principal does not match session principal",
59        ));
60    }
61
62    Ok(session.auth_context.clone())
63}
64
/// Tunables for the SSE endpoints.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Maximum number of concurrently open SSE sessions on this server instance.
    pub max_sessions: usize,
    /// Capacity of each per-session message channel.
    pub channel_buffer_size: usize,
    /// Seconds between keep-alive messages on an otherwise idle stream.
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    /// 10 000 sessions, 256-message channels and a 30-second keep-alive.
    fn default() -> Self {
        const MAX_SESSIONS: usize = 10_000;
        const CHANNEL_BUFFER_SIZE: usize = 256;
        const KEEPALIVE_INTERVAL_SECS: u64 = 30;

        Self {
            keepalive_interval_secs: KEEPALIVE_INTERVAL_SECS,
            channel_buffer_size: CHANNEL_BUFFER_SIZE,
            max_sessions: MAX_SESSIONS,
        }
    }
}
85
86/// SSE query parameters.
87#[derive(Debug, Deserialize)]
88pub struct SseQuery {
89    /// Authentication token.
90    pub token: Option<String>,
91}
92
93struct SseSessionData {
94    auth_context: AuthContext,
95    session_secret: String,
96    /// Maps client subscription ID -> internal SubscriptionId
97    subscriptions: HashMap<String, SubscriptionId>,
98}
99
100/// State for SSE handler.
101#[derive(Clone)]
102pub struct SseState {
103    reactor: Arc<Reactor>,
104    auth_middleware: Arc<AuthMiddleware>,
105    /// Per-session data: auth context and subscription mappings
106    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
107    config: SseConfig,
108}
109
110impl SseState {
111    /// Create new SSE state with default config.
112    pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
113        Self::with_config(reactor, auth_middleware, SseConfig::default())
114    }
115
116    /// Create new SSE state with custom config.
117    pub fn with_config(
118        reactor: Arc<Reactor>,
119        auth_middleware: Arc<AuthMiddleware>,
120        config: SseConfig,
121    ) -> Self {
122        Self {
123            reactor,
124            auth_middleware,
125            sessions: Arc::new(RwLock::new(HashMap::new())),
126            config,
127        }
128    }
129
130    /// Check if we can accept new sessions.
131    pub async fn can_accept_session(&self) -> bool {
132        self.sessions.read().await.len() < self.config.max_sessions
133    }
134}
135
136/// Guard to ensure session cleanup on disconnect.
137/// Spawns a cleanup task on drop to handle abrupt disconnects.
138struct SessionCleanupGuard {
139    session_id: SessionId,
140    reactor: Arc<Reactor>,
141    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
142    dropped: bool,
143}
144
145impl SessionCleanupGuard {
146    fn new(
147        session_id: SessionId,
148        reactor: Arc<Reactor>,
149        sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
150    ) -> Self {
151        Self {
152            session_id,
153            reactor,
154            sessions,
155            dropped: false,
156        }
157    }
158
159    /// Mark as cleanly closed (cleanup will be skipped).
160    fn mark_closed(&mut self) {
161        self.dropped = true;
162    }
163}
164
165impl Drop for SessionCleanupGuard {
166    fn drop(&mut self) {
167        if self.dropped {
168            return;
169        }
170        let session_id = self.session_id;
171        let reactor = self.reactor.clone();
172        let sessions = self.sessions.clone();
173
174        // Spawn cleanup task since we can't await in drop
175        // Use spawn to handle cleanup even if the runtime is shutting down
176        if let Ok(handle) = tokio::runtime::Handle::try_current() {
177            handle.spawn(async move {
178                reactor.remove_session(session_id).await;
179                sessions.write().await.remove(&session_id);
180                tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
181            });
182        } else {
183            // Runtime not available, likely shutting down. Session will be cleaned up on restart.
184            tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
185        }
186    }
187}
188
189/// SSE event payload sent to clients.
190#[derive(Debug, Serialize)]
191#[serde(tag = "type", rename_all = "snake_case")]
192pub enum SsePayload {
193    /// Subscription data update.
194    Update {
195        target: String,
196        payload: serde_json::Value,
197    },
198    /// Error message.
199    Error {
200        target: String,
201        code: String,
202        message: String,
203    },
204    /// Connection acknowledged.
205    Connected {
206        session_id: String,
207        session_secret: String,
208    },
209}
210
211/// Internal message type for SSE stream.
212#[derive(Debug)]
213pub enum SseMessage {
214    Data {
215        target: String,
216        payload: serde_json::Value,
217    },
218    Error {
219        target: String,
220        code: String,
221        message: String,
222    },
223}
224
225/// SSE subscribe request.
226#[derive(Debug, Deserialize)]
227pub struct SseSubscribeRequest {
228    pub session_id: String,
229    pub session_secret: String,
230    pub id: String,
231    pub function: String,
232    #[serde(default)]
233    pub args: serde_json::Value,
234}
235
236/// SSE unsubscribe request.
237#[derive(Debug, Deserialize)]
238pub struct SseUnsubscribeRequest {
239    pub session_id: String,
240    pub session_secret: String,
241    pub id: String,
242}
243
244/// SSE job subscribe request.
245#[derive(Debug, Deserialize)]
246pub struct SseJobSubscribeRequest {
247    pub session_id: String,
248    pub session_secret: String,
249    pub id: String,
250    pub job_id: String,
251}
252
253/// SSE workflow subscribe request.
254#[derive(Debug, Deserialize)]
255pub struct SseWorkflowSubscribeRequest {
256    pub session_id: String,
257    pub session_secret: String,
258    pub id: String,
259    pub workflow_id: String,
260}
261
262/// SSE error response.
263#[derive(Debug, Serialize)]
264pub struct SseError {
265    pub code: String,
266    pub message: String,
267}
268
269/// SSE subscribe response.
270#[derive(Debug, Serialize)]
271pub struct SseSubscribeResponse {
272    pub success: bool,
273    #[serde(skip_serializing_if = "Option::is_none")]
274    pub data: Option<serde_json::Value>,
275    #[serde(skip_serializing_if = "Option::is_none")]
276    pub error: Option<SseError>,
277}
278
279/// SSE unsubscribe response.
280#[derive(Debug, Serialize)]
281pub struct SseUnsubscribeResponse {
282    pub success: bool,
283    #[serde(skip_serializing_if = "Option::is_none")]
284    pub error: Option<SseError>,
285}
286
287/// Create an SSE subscribe error response.
288fn subscribe_error(
289    status: StatusCode,
290    code: impl Into<String>,
291    message: impl Into<String>,
292) -> (StatusCode, Json<SseSubscribeResponse>) {
293    (
294        status,
295        Json(SseSubscribeResponse {
296            success: false,
297            data: None,
298            error: Some(SseError {
299                code: code.into(),
300                message: message.into(),
301            }),
302        }),
303    )
304}
305
306/// Create an SSE unsubscribe error response.
307fn unsubscribe_error(
308    status: StatusCode,
309    code: impl Into<String>,
310    message: impl Into<String>,
311) -> (StatusCode, Json<SseUnsubscribeResponse>) {
312    (
313        status,
314        Json(SseUnsubscribeResponse {
315            success: false,
316            error: Some(SseError {
317                code: code.into(),
318                message: message.into(),
319            }),
320        }),
321    )
322}
323
324/// SSE handler for GET /events.
325pub async fn sse_handler(
326    State(state): State<Arc<SseState>>,
327    Query(query): Query<SseQuery>,
328) -> impl IntoResponse {
329    // Check session limit
330    if !state.can_accept_session().await {
331        return (
332            StatusCode::SERVICE_UNAVAILABLE,
333            "Server at capacity".to_string(),
334        )
335            .into_response();
336    }
337
338    let session_id = SessionId::new();
339    let buffer_size = state.config.channel_buffer_size;
340    let keepalive_secs = state.config.keepalive_interval_secs;
341    let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
342    let cancel_token = CancellationToken::new();
343
344    let auth_context = if let Some(token) = &query.token {
345        match state.auth_middleware.validate_token_async(token).await {
346            Ok(claims) => super::auth::build_auth_context_from_claims(claims),
347            Err(e) => {
348                tracing::warn!("SSE token validation failed: {}", e);
349                return (
350                    StatusCode::UNAUTHORIZED,
351                    "Invalid authentication token".to_string(),
352                )
353                    .into_response();
354            }
355        }
356    } else {
357        forge_core::function::AuthContext::unauthenticated()
358    };
359    let session_secret = uuid::Uuid::new_v4().to_string();
360
361    // Register session with reactor
362    let reactor = state.reactor.clone();
363    let cancel = cancel_token.clone();
364
365    // Create a bridge channel for the reactor's message format
366    let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
367    reactor.register_session(session_id, rt_tx).await;
368
369    // Store session data for subscription handlers
370    {
371        let mut sessions = state.sessions.write().await;
372        sessions.insert(
373            session_id,
374            SseSessionData {
375                auth_context: auth_context.clone(),
376                session_secret: session_secret.clone(),
377                subscriptions: HashMap::new(),
378            },
379        );
380    }
381
382    // Capture sessions for cleanup guard
383    let sessions = state.sessions.clone();
384
385    // Create cleanup guard - will clean up on drop if stream ends unexpectedly
386    let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
387
388    // Bridge reactor messages to SSE messages
389    let bridge_cancel = cancel_token.clone();
390    tokio::spawn(async move {
391        loop {
392            tokio::select! {
393                msg = rt_rx.recv() => {
394                    match msg {
395                        Some(rt_msg) => {
396                            if let Some(sse_msg) = convert_realtime_to_sse(rt_msg)
397                                && tx.send(sse_msg).await.is_err() {
398                                    break;
399                                }
400                        }
401                        None => break,
402                    }
403                }
404                _ = bridge_cancel.cancelled() => break,
405            }
406        }
407    });
408
409    // Create the SSE stream
410    let stream = async_stream::stream! {
411        // Move guard into stream so it's dropped when stream ends
412        let mut _guard = cleanup_guard;
413
414        // Send connected event
415        let connected = SsePayload::Connected {
416            session_id: session_id.to_string(),
417            session_secret: session_secret.clone(),
418        };
419        match serde_json::to_string(&connected) {
420            Ok(json) => {
421                yield Ok::<Event, Infallible>(Event::default().event("connected").data(json));
422            }
423            Err(e) => {
424                tracing::error!("Failed to serialize SSE connected payload: {}", e);
425            }
426        }
427
428        loop {
429            tokio::select! {
430                msg = rx.recv() => {
431                    match msg {
432                        Some(SseMessage::Data { target, payload }) => {
433                            let event_data = SsePayload::Update { target, payload };
434                            match serde_json::to_string(&event_data) {
435                                Ok(json) => {
436                                    yield Ok::<Event, Infallible>(Event::default().event("update").data(json));
437                                }
438                                Err(e) => {
439                                    tracing::error!("Failed to serialize SSE update payload: {}", e);
440                                }
441                            }
442                        }
443                        Some(SseMessage::Error { target, code, message }) => {
444                            let event_data = SsePayload::Error { target, code, message };
445                            match serde_json::to_string(&event_data) {
446                                Ok(json) => {
447                                    yield Ok::<Event, Infallible>(Event::default().event("error").data(json));
448                                }
449                                Err(e) => {
450                                    tracing::error!("Failed to serialize SSE error payload: {}", e);
451                                }
452                            }
453                        }
454                        None => break,
455                    }
456                }
457                _ = cancel.cancelled() => break,
458            }
459        }
460
461        // Clean shutdown - mark guard as handled so Drop doesn't duplicate cleanup
462        _guard.mark_closed();
463        reactor.remove_session(session_id).await;
464        sessions.write().await.remove(&session_id);
465    };
466
467    Sse::new(stream)
468        .keep_alive(
469            KeepAlive::new()
470                .interval(Duration::from_secs(keepalive_secs))
471                .text("ping"),
472        )
473        .into_response()
474}
475
476/// Convert realtime message to SSE message.
477fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
478    match msg {
479        RealtimeMessage::Data {
480            subscription_id,
481            data,
482        } => Some(SseMessage::Data {
483            target: format!("sub:{}", subscription_id),
484            payload: data,
485        }),
486        RealtimeMessage::DeltaUpdate {
487            subscription_id,
488            delta,
489        } => match serde_json::to_value(&delta) {
490            Ok(payload) => Some(SseMessage::Data {
491                target: format!("sub:{}", subscription_id),
492                payload,
493            }),
494            Err(e) => {
495                tracing::error!("Failed to serialize delta update: {}", e);
496                Some(SseMessage::Error {
497                    target: format!("sub:{}", subscription_id),
498                    code: "SERIALIZATION_ERROR".to_string(),
499                    message: "Failed to serialize update data".to_string(),
500                })
501            }
502        },
503        RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
504            Ok(payload) => Some(SseMessage::Data {
505                target: format!("job:{}", client_sub_id),
506                payload,
507            }),
508            Err(e) => {
509                tracing::error!("Failed to serialize job update: {}", e);
510                Some(SseMessage::Error {
511                    target: format!("job:{}", client_sub_id),
512                    code: "SERIALIZATION_ERROR".to_string(),
513                    message: "Failed to serialize job update".to_string(),
514                })
515            }
516        },
517        RealtimeMessage::WorkflowUpdate {
518            client_sub_id,
519            workflow,
520        } => match serde_json::to_value(&workflow) {
521            Ok(payload) => Some(SseMessage::Data {
522                target: format!("wf:{}", client_sub_id),
523                payload,
524            }),
525            Err(e) => {
526                tracing::error!("Failed to serialize workflow update: {}", e);
527                Some(SseMessage::Error {
528                    target: format!("wf:{}", client_sub_id),
529                    code: "SERIALIZATION_ERROR".to_string(),
530                    message: "Failed to serialize workflow update".to_string(),
531                })
532            }
533        },
534        RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
535            target: String::new(),
536            code,
537            message,
538        }),
539        RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
540            target: id,
541            code,
542            message,
543        }),
544        // Ignore control messages
545        RealtimeMessage::Subscribe { .. }
546        | RealtimeMessage::Unsubscribe { .. }
547        | RealtimeMessage::Ping
548        | RealtimeMessage::Pong
549        | RealtimeMessage::AuthSuccess
550        | RealtimeMessage::AuthFailed { .. } => None,
551    }
552}
553
554/// SSE subscribe handler for POST /subscribe.
555pub async fn sse_subscribe_handler(
556    State(state): State<Arc<SseState>>,
557    Extension(request_auth): Extension<AuthContext>,
558    Json(request): Json<SseSubscribeRequest>,
559) -> impl IntoResponse {
560    // Validate subscription ID length to prevent memory bloat
561    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
562        return subscribe_error(
563            StatusCode::BAD_REQUEST,
564            "INVALID_ID",
565            format!(
566                "Subscription ID too long (max {} chars)",
567                MAX_CLIENT_SUB_ID_LEN
568            ),
569        );
570    }
571
572    let Some(session_id) = try_parse_session_id(&request.session_id) else {
573        return subscribe_error(
574            StatusCode::BAD_REQUEST,
575            "INVALID_SESSION",
576            "Invalid session ID format",
577        );
578    };
579
580    // Get session data (auth context)
581    let sessions = state.sessions.read().await;
582    let session_data = match sessions.get(&session_id) {
583        Some(data) => {
584            match authorize_session_access(data, &request.session_secret, &request_auth) {
585                Ok(auth) => auth,
586                Err(resp) => return resp,
587            }
588        }
589        None => {
590            return subscribe_error(
591                StatusCode::NOT_FOUND,
592                "SESSION_NOT_FOUND",
593                "Session not found or expired",
594            );
595        }
596    };
597    drop(sessions);
598
599    // Subscribe via reactor
600    let result = state
601        .reactor
602        .subscribe(
603            session_id,
604            request.id.clone(),
605            request.function,
606            request.args,
607            session_data,
608        )
609        .await;
610
611    match result {
612        Ok((subscription_id, data)) => {
613            // Store the subscription mapping
614            let mut sessions = state.sessions.write().await;
615            match sessions.get_mut(&session_id) {
616                Some(session) => {
617                    session.subscriptions.insert(request.id, subscription_id);
618                }
619                None => {
620                    // Session was removed between read and write lock
621                    return subscribe_error(
622                        StatusCode::NOT_FOUND,
623                        "SESSION_NOT_FOUND",
624                        "Session expired during subscription",
625                    );
626                }
627            }
628
629            tracing::debug!(
630                %session_id,
631                %subscription_id,
632                "SSE subscription registered"
633            );
634
635            (
636                StatusCode::OK,
637                Json(SseSubscribeResponse {
638                    success: true,
639                    data: Some(data),
640                    error: None,
641                }),
642            )
643        }
644        Err(e) => {
645            tracing::warn!(%session_id, error = %e, "SSE subscription failed");
646            match e {
647                forge_core::ForgeError::Unauthorized(msg) => {
648                    subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
649                }
650                forge_core::ForgeError::Forbidden(msg) => {
651                    subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
652                }
653                forge_core::ForgeError::InvalidArgument(msg)
654                | forge_core::ForgeError::Validation(msg) => {
655                    subscribe_error(StatusCode::BAD_REQUEST, "INVALID_ARGUMENT", msg)
656                }
657                forge_core::ForgeError::NotFound(msg) => {
658                    subscribe_error(StatusCode::NOT_FOUND, "NOT_FOUND", msg)
659                }
660                _ => subscribe_error(
661                    StatusCode::INTERNAL_SERVER_ERROR,
662                    "SUBSCRIPTION_FAILED",
663                    "Subscription failed",
664                ),
665            }
666        }
667    }
668}
669
670/// SSE unsubscribe handler for POST /unsubscribe.
671pub async fn sse_unsubscribe_handler(
672    State(state): State<Arc<SseState>>,
673    Extension(request_auth): Extension<AuthContext>,
674    Json(request): Json<SseUnsubscribeRequest>,
675) -> impl IntoResponse {
676    let Some(session_id) = try_parse_session_id(&request.session_id) else {
677        return unsubscribe_error(
678            StatusCode::BAD_REQUEST,
679            "INVALID_SESSION",
680            "Invalid session ID format",
681        );
682    };
683
684    // Look up internal subscription ID and validate session ownership
685    let subscription_id = {
686        let sessions = state.sessions.read().await;
687        match sessions.get(&session_id) {
688            Some(session) => {
689                if session.session_secret != request.session_secret
690                    || !same_principal(&session.auth_context, &request_auth)
691                {
692                    return unsubscribe_error(
693                        StatusCode::FORBIDDEN,
694                        "SESSION_PRINCIPAL_MISMATCH",
695                        "Request principal does not match session principal",
696                    );
697                }
698                session.subscriptions.get(&request.id).copied()
699            }
700            None => None,
701        }
702    };
703
704    let Some(subscription_id) = subscription_id else {
705        return unsubscribe_error(
706            StatusCode::NOT_FOUND,
707            "SUBSCRIPTION_NOT_FOUND",
708            "Subscription not found",
709        );
710    };
711
712    // Unsubscribe via reactor
713    state.reactor.unsubscribe(subscription_id).await;
714
715    // Remove from session tracking
716    {
717        let mut sessions = state.sessions.write().await;
718        if let Some(session) = sessions.get_mut(&session_id) {
719            session.subscriptions.remove(&request.id);
720        }
721    }
722
723    tracing::debug!(
724        %session_id,
725        %subscription_id,
726        "SSE subscription removed"
727    );
728
729    (
730        StatusCode::OK,
731        Json(SseUnsubscribeResponse {
732            success: true,
733            error: None,
734        }),
735    )
736}
737
738/// SSE job subscribe handler for POST /subscribe-job.
739pub async fn sse_job_subscribe_handler(
740    State(state): State<Arc<SseState>>,
741    Extension(request_auth): Extension<AuthContext>,
742    Json(request): Json<SseJobSubscribeRequest>,
743) -> impl IntoResponse {
744    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
745        return subscribe_error(
746            StatusCode::BAD_REQUEST,
747            "INVALID_ID",
748            format!(
749                "Subscription ID too long (max {} chars)",
750                MAX_CLIENT_SUB_ID_LEN
751            ),
752        );
753    }
754
755    let Some(session_id) = try_parse_session_id(&request.session_id) else {
756        return subscribe_error(
757            StatusCode::BAD_REQUEST,
758            "INVALID_SESSION",
759            "Invalid session ID format",
760        );
761    };
762
763    // Validate session exists + principal binding
764    let session_auth = {
765        let sessions = state.sessions.read().await;
766        match sessions.get(&session_id) {
767            Some(session) => {
768                match authorize_session_access(session, &request.session_secret, &request_auth) {
769                    Ok(auth) => auth,
770                    Err(resp) => return resp,
771                }
772            }
773            None => {
774                return subscribe_error(
775                    StatusCode::NOT_FOUND,
776                    "SESSION_NOT_FOUND",
777                    "Session not found or expired",
778                );
779            }
780        }
781    };
782
783    // Parse job ID
784    let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
785        Ok(uuid) => uuid,
786        Err(_) => {
787            return subscribe_error(
788                StatusCode::BAD_REQUEST,
789                "INVALID_JOB_ID",
790                "Invalid job ID format",
791            );
792        }
793    };
794
795    // Subscribe to job updates via reactor
796    match state
797        .reactor
798        .subscribe_job(session_id, request.id.clone(), job_uuid, &session_auth)
799        .await
800    {
801        Ok(job_data) => {
802            let data = match serde_json::to_value(&job_data) {
803                Ok(v) => v,
804                Err(e) => {
805                    tracing::error!("Failed to serialize job data: {}", e);
806                    return subscribe_error(
807                        StatusCode::INTERNAL_SERVER_ERROR,
808                        "SERIALIZE_ERROR",
809                        "Failed to serialize job data",
810                    );
811                }
812            };
813            tracing::debug!(
814                %session_id,
815                job_id = %request.job_id,
816                client_sub_id = %request.id,
817                "SSE job subscription registered"
818            );
819            (
820                StatusCode::OK,
821                Json(SseSubscribeResponse {
822                    success: true,
823                    data: Some(data),
824                    error: None,
825                }),
826            )
827        }
828        Err(e) => match e {
829            forge_core::ForgeError::Unauthorized(msg) => {
830                subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
831            }
832            forge_core::ForgeError::Forbidden(msg) => {
833                subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
834            }
835            forge_core::ForgeError::NotFound(msg) => {
836                subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", msg)
837            }
838            _ => subscribe_error(
839                StatusCode::INTERNAL_SERVER_ERROR,
840                "SUBSCRIPTION_FAILED",
841                "Subscription failed",
842            ),
843        },
844    }
845}
846
847/// SSE workflow subscribe handler for POST /subscribe-workflow.
848pub async fn sse_workflow_subscribe_handler(
849    State(state): State<Arc<SseState>>,
850    Extension(request_auth): Extension<AuthContext>,
851    Json(request): Json<SseWorkflowSubscribeRequest>,
852) -> impl IntoResponse {
853    if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
854        return subscribe_error(
855            StatusCode::BAD_REQUEST,
856            "INVALID_ID",
857            format!(
858                "Subscription ID too long (max {} chars)",
859                MAX_CLIENT_SUB_ID_LEN
860            ),
861        );
862    }
863
864    let Some(session_id) = try_parse_session_id(&request.session_id) else {
865        return subscribe_error(
866            StatusCode::BAD_REQUEST,
867            "INVALID_SESSION",
868            "Invalid session ID format",
869        );
870    };
871
872    // Validate session exists + principal binding
873    let session_auth = {
874        let sessions = state.sessions.read().await;
875        match sessions.get(&session_id) {
876            Some(session) => {
877                match authorize_session_access(session, &request.session_secret, &request_auth) {
878                    Ok(auth) => auth,
879                    Err(resp) => return resp,
880                }
881            }
882            None => {
883                return subscribe_error(
884                    StatusCode::NOT_FOUND,
885                    "SESSION_NOT_FOUND",
886                    "Session not found or expired",
887                );
888            }
889        }
890    };
891
892    // Parse workflow ID
893    let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
894        Ok(uuid) => uuid,
895        Err(_) => {
896            return subscribe_error(
897                StatusCode::BAD_REQUEST,
898                "INVALID_WORKFLOW_ID",
899                "Invalid workflow ID format",
900            );
901        }
902    };
903
904    // Subscribe to workflow updates via reactor
905    match state
906        .reactor
907        .subscribe_workflow(session_id, request.id.clone(), workflow_uuid, &session_auth)
908        .await
909    {
910        Ok(workflow_data) => {
911            let data = match serde_json::to_value(&workflow_data) {
912                Ok(v) => v,
913                Err(e) => {
914                    tracing::error!("Failed to serialize workflow data: {}", e);
915                    return subscribe_error(
916                        StatusCode::INTERNAL_SERVER_ERROR,
917                        "SERIALIZE_ERROR",
918                        "Failed to serialize workflow data",
919                    );
920                }
921            };
922            tracing::debug!(
923                %session_id,
924                workflow_id = %request.workflow_id,
925                client_sub_id = %request.id,
926                "SSE workflow subscription registered"
927            );
928            (
929                StatusCode::OK,
930                Json(SseSubscribeResponse {
931                    success: true,
932                    data: Some(data),
933                    error: None,
934                }),
935            )
936        }
937        Err(e) => match e {
938            forge_core::ForgeError::Unauthorized(msg) => {
939                subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
940            }
941            forge_core::ForgeError::Forbidden(msg) => {
942                subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
943            }
944            forge_core::ForgeError::NotFound(msg) => {
945                subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", msg)
946            }
947            _ => subscribe_error(
948                StatusCode::INTERNAL_SERVER_ERROR,
949                "SUBSCRIPTION_FAILED",
950                "Subscription failed",
951            ),
952        },
953    }
954}
955
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `Update` serializes as an internally tagged object with exactly these keys.
    #[test]
    fn test_sse_payload_serialization() {
        let payload = SsePayload::Update {
            target: "sub:123".to_string(),
            payload: json!({"id": 1}),
        };
        let value = serde_json::to_value(&payload).unwrap();
        assert_eq!(
            value,
            json!({"type": "update", "target": "sub:123", "payload": {"id": 1}})
        );
    }

    #[test]
    fn test_sse_error_serialization() {
        let payload = SsePayload::Error {
            target: "sub:456".to_string(),
            code: "NOT_FOUND".to_string(),
            message: "Subscription not found".to_string(),
        };
        let value = serde_json::to_value(&payload).unwrap();
        assert_eq!(
            value,
            json!({
                "type": "error",
                "target": "sub:456",
                "code": "NOT_FOUND",
                "message": "Subscription not found"
            })
        );
    }

    #[test]
    fn test_sse_connected_serialization() {
        let payload = SsePayload::Connected {
            session_id: "session".to_string(),
            session_secret: "secret".to_string(),
        };
        let value = serde_json::to_value(&payload).unwrap();
        assert_eq!(
            value,
            json!({"type": "connected", "session_id": "session", "session_secret": "secret"})
        );
    }

    #[test]
    fn test_sse_config_defaults() {
        let config = SseConfig::default();
        assert_eq!(config.max_sessions, 10_000);
        assert_eq!(config.channel_buffer_size, 256);
        assert_eq!(config.keepalive_interval_secs, 30);
    }

    #[test]
    fn test_try_parse_session_id() {
        assert!(try_parse_session_id("not-a-uuid").is_none());
        assert!(try_parse_session_id("").is_none());
        assert!(try_parse_session_id(&uuid::Uuid::new_v4().to_string()).is_some());
    }

    #[test]
    fn test_same_principal_anonymous() {
        let anonymous = AuthContext::unauthenticated();
        assert!(same_principal(&anonymous, &anonymous));
    }
}