1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::Json;
9use axum::extract::{Extension, Query, State};
10use axum::http::StatusCode;
11use axum::response::IntoResponse;
12use axum::response::sse::{Event, KeepAlive, Sse};
13use serde::{Deserialize, Serialize};
14use tokio::sync::{RwLock, mpsc};
15use tokio_util::sync::CancellationToken;
16
17use forge_core::function::AuthContext;
18use forge_core::realtime::{SessionId, SubscriptionId};
19
20use super::auth::AuthMiddleware;
21use crate::realtime::Reactor;
22use crate::realtime::RealtimeMessage;
23
/// Maximum length, in bytes (`str::len`), of a client-chosen subscription id
/// accepted by the subscribe endpoints.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
26
27fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
28 uuid::Uuid::parse_str(session_id)
29 .ok()
30 .map(SessionId::from_uuid)
31}
32
33fn same_principal(a: &AuthContext, b: &AuthContext) -> bool {
34 match (a.is_authenticated(), b.is_authenticated()) {
35 (false, false) => true,
36 (true, true) => a.principal_id().is_some() && a.principal_id() == b.principal_id(),
37 _ => false,
38 }
39}
40
41fn authorize_session_access(
42 session: &SseSessionData,
43 session_secret: &str,
44 requester_auth: &AuthContext,
45) -> Result<AuthContext, (StatusCode, Json<SseSubscribeResponse>)> {
46 if session.session_secret != session_secret {
47 return Err(subscribe_error(
48 StatusCode::UNAUTHORIZED,
49 "INVALID_SESSION_SECRET",
50 "Session secret mismatch",
51 ));
52 }
53
54 if !same_principal(&session.auth_context, requester_auth) {
55 return Err(subscribe_error(
56 StatusCode::FORBIDDEN,
57 "SESSION_PRINCIPAL_MISMATCH",
58 "Request principal does not match session principal",
59 ));
60 }
61
62 Ok(session.auth_context.clone())
63}
64
/// Tunables for the SSE endpoints.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Maximum number of concurrently open SSE sessions; further connection
    /// attempts are answered with `503 Service Unavailable`.
    pub max_sessions: usize,
    /// Capacity of the per-session message channels.
    pub channel_buffer_size: usize,
    /// Interval, in seconds, between keep-alive messages (text `ping`).
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    /// 10 000 sessions, 256-message channels, 30-second keep-alive.
    fn default() -> Self {
        let max_sessions = 10_000;
        let channel_buffer_size = 256;
        let keepalive_interval_secs = 30;
        Self {
            max_sessions,
            channel_buffer_size,
            keepalive_interval_secs,
        }
    }
}
85
/// Query string accepted by the SSE connect endpoint.
#[derive(Debug, Deserialize)]
pub struct SseQuery {
    /// Optional auth token. When present it is validated and the session is
    /// bound to the resulting principal; a validation failure yields `401`.
    /// When absent the session is unauthenticated.
    pub token: Option<String>,
}
92
/// Server-side bookkeeping for one open SSE session.
struct SseSessionData {
    /// Auth context established when the stream was opened; follow-up requests
    /// must come from the same principal.
    auth_context: AuthContext,
    /// Secret sent to the client in the `connected` event; follow-up requests
    /// must echo it back.
    session_secret: String,
    /// Client-chosen subscription id (the request `id`) -> reactor subscription.
    subscriptions: HashMap<String, SubscriptionId>,
}
99
/// Shared state for the SSE handlers (extracted via `State<Arc<SseState>>`).
#[derive(Clone)]
pub struct SseState {
    /// Reactor that owns sessions and subscriptions and produces updates.
    reactor: Arc<Reactor>,
    /// Used to validate the optional `token` on connect.
    auth_middleware: Arc<AuthMiddleware>,
    /// Open SSE sessions keyed by session id.
    sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
    config: SseConfig,
}
109
110impl SseState {
111 pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
113 Self::with_config(reactor, auth_middleware, SseConfig::default())
114 }
115
116 pub fn with_config(
118 reactor: Arc<Reactor>,
119 auth_middleware: Arc<AuthMiddleware>,
120 config: SseConfig,
121 ) -> Self {
122 Self {
123 reactor,
124 auth_middleware,
125 sessions: Arc::new(RwLock::new(HashMap::new())),
126 config,
127 }
128 }
129
130 pub async fn can_accept_session(&self) -> bool {
132 self.sessions.read().await.len() < self.config.max_sessions
133 }
134}
135
136struct SessionCleanupGuard {
139 session_id: SessionId,
140 reactor: Arc<Reactor>,
141 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
142 dropped: bool,
143}
144
145impl SessionCleanupGuard {
146 fn new(
147 session_id: SessionId,
148 reactor: Arc<Reactor>,
149 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
150 ) -> Self {
151 Self {
152 session_id,
153 reactor,
154 sessions,
155 dropped: false,
156 }
157 }
158
159 fn mark_closed(&mut self) {
161 self.dropped = true;
162 }
163}
164
165impl Drop for SessionCleanupGuard {
166 fn drop(&mut self) {
167 if self.dropped {
168 return;
169 }
170 let session_id = self.session_id;
171 let reactor = self.reactor.clone();
172 let sessions = self.sessions.clone();
173
174 if let Ok(handle) = tokio::runtime::Handle::try_current() {
177 handle.spawn(async move {
178 reactor.remove_session(session_id).await;
179 sessions.write().await.remove(&session_id);
180 tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
181 });
182 } else {
183 tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
185 }
186 }
187}
188
/// JSON body of an outgoing SSE event.
///
/// Serialized with an internal `type` tag in snake_case (`"update"`,
/// `"error"`, `"connected"`) next to the variant's own fields.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SsePayload {
    /// New data for the subscription addressed by `target`
    /// (`sub:…`, `job:…` or `wf:…`).
    Update {
        target: String,
        payload: serde_json::Value,
    },
    /// An error; `target` is empty for errors not tied to a subscription.
    Error {
        target: String,
        code: String,
        message: String,
    },
    /// First event on a new stream: the credentials the client must present
    /// to the subscribe/unsubscribe endpoints.
    Connected {
        session_id: String,
        session_secret: String,
    },
}

/// Message passed from the reactor bridge task to the SSE stream, before it is
/// turned into an [`SsePayload`].
#[derive(Debug)]
pub enum SseMessage {
    /// Emitted to the client as an `update` event.
    Data {
        target: String,
        payload: serde_json::Value,
    },
    /// Emitted to the client as an `error` event.
    Error {
        target: String,
        code: String,
        message: String,
    },
}
224
/// Body of a query-subscription request.
///
/// `session_id` and `session_secret` are the values delivered in the stream's
/// `connected` event. `id` is chosen by the client (at most
/// `MAX_CLIENT_SUB_ID_LEN` bytes) and is later used to unsubscribe.
#[derive(Debug, Deserialize)]
pub struct SseSubscribeRequest {
    pub session_id: String,
    pub session_secret: String,
    pub id: String,
    /// Name of the function to subscribe to.
    pub function: String,
    /// Function arguments; `null` when omitted.
    #[serde(default)]
    pub args: serde_json::Value,
}

/// Body of an unsubscribe request; `id` is the client id used when subscribing.
#[derive(Debug, Deserialize)]
pub struct SseUnsubscribeRequest {
    pub session_id: String,
    pub session_secret: String,
    pub id: String,
}

/// Body of a job-subscription request; `job_id` must be a UUID string.
#[derive(Debug, Deserialize)]
pub struct SseJobSubscribeRequest {
    pub session_id: String,
    pub session_secret: String,
    pub id: String,
    pub job_id: String,
}

/// Body of a workflow-subscription request; `workflow_id` must be a UUID string.
#[derive(Debug, Deserialize)]
pub struct SseWorkflowSubscribeRequest {
    pub session_id: String,
    pub session_secret: String,
    pub id: String,
    pub workflow_id: String,
}
261
/// Machine-readable error `code` plus a human-readable `message`.
#[derive(Debug, Serialize)]
pub struct SseError {
    pub code: String,
    pub message: String,
}

/// Response of the subscribe endpoints. `data` carries the initial value on
/// success; `error` is set on failure. Absent fields are omitted from the JSON.
#[derive(Debug, Serialize)]
pub struct SseSubscribeResponse {
    pub success: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<SseError>,
}

/// Response of the unsubscribe endpoint; `error` is omitted on success.
#[derive(Debug, Serialize)]
pub struct SseUnsubscribeResponse {
    pub success: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<SseError>,
}
286
287fn subscribe_error(
289 status: StatusCode,
290 code: impl Into<String>,
291 message: impl Into<String>,
292) -> (StatusCode, Json<SseSubscribeResponse>) {
293 (
294 status,
295 Json(SseSubscribeResponse {
296 success: false,
297 data: None,
298 error: Some(SseError {
299 code: code.into(),
300 message: message.into(),
301 }),
302 }),
303 )
304}
305
306fn unsubscribe_error(
308 status: StatusCode,
309 code: impl Into<String>,
310 message: impl Into<String>,
311) -> (StatusCode, Json<SseUnsubscribeResponse>) {
312 (
313 status,
314 Json(SseUnsubscribeResponse {
315 success: false,
316 error: Some(SseError {
317 code: code.into(),
318 message: message.into(),
319 }),
320 }),
321 )
322}
323
324pub async fn sse_handler(
326 State(state): State<Arc<SseState>>,
327 Query(query): Query<SseQuery>,
328) -> impl IntoResponse {
329 if !state.can_accept_session().await {
331 return (
332 StatusCode::SERVICE_UNAVAILABLE,
333 "Server at capacity".to_string(),
334 )
335 .into_response();
336 }
337
338 let session_id = SessionId::new();
339 let buffer_size = state.config.channel_buffer_size;
340 let keepalive_secs = state.config.keepalive_interval_secs;
341 let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
342 let cancel_token = CancellationToken::new();
343
344 let auth_context = if let Some(token) = &query.token {
345 match state.auth_middleware.validate_token_async(token).await {
346 Ok(claims) => super::auth::build_auth_context_from_claims(claims),
347 Err(e) => {
348 tracing::warn!("SSE token validation failed: {}", e);
349 return (
350 StatusCode::UNAUTHORIZED,
351 "Invalid authentication token".to_string(),
352 )
353 .into_response();
354 }
355 }
356 } else {
357 forge_core::function::AuthContext::unauthenticated()
358 };
359 let session_secret = uuid::Uuid::new_v4().to_string();
360
361 let reactor = state.reactor.clone();
363 let cancel = cancel_token.clone();
364
365 let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
367 reactor.register_session(session_id, rt_tx).await;
368
369 {
371 let mut sessions = state.sessions.write().await;
372 sessions.insert(
373 session_id,
374 SseSessionData {
375 auth_context: auth_context.clone(),
376 session_secret: session_secret.clone(),
377 subscriptions: HashMap::new(),
378 },
379 );
380 }
381
382 let sessions = state.sessions.clone();
384
385 let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
387
388 let bridge_cancel = cancel_token.clone();
390 tokio::spawn(async move {
391 loop {
392 tokio::select! {
393 msg = rt_rx.recv() => {
394 match msg {
395 Some(rt_msg) => {
396 if let Some(sse_msg) = convert_realtime_to_sse(rt_msg)
397 && tx.send(sse_msg).await.is_err() {
398 break;
399 }
400 }
401 None => break,
402 }
403 }
404 _ = bridge_cancel.cancelled() => break,
405 }
406 }
407 });
408
409 let stream = async_stream::stream! {
411 let mut _guard = cleanup_guard;
413
414 let connected = SsePayload::Connected {
416 session_id: session_id.to_string(),
417 session_secret: session_secret.clone(),
418 };
419 match serde_json::to_string(&connected) {
420 Ok(json) => {
421 yield Ok::<Event, Infallible>(Event::default().event("connected").data(json));
422 }
423 Err(e) => {
424 tracing::error!("Failed to serialize SSE connected payload: {}", e);
425 }
426 }
427
428 loop {
429 tokio::select! {
430 msg = rx.recv() => {
431 match msg {
432 Some(SseMessage::Data { target, payload }) => {
433 let event_data = SsePayload::Update { target, payload };
434 match serde_json::to_string(&event_data) {
435 Ok(json) => {
436 yield Ok::<Event, Infallible>(Event::default().event("update").data(json));
437 }
438 Err(e) => {
439 tracing::error!("Failed to serialize SSE update payload: {}", e);
440 }
441 }
442 }
443 Some(SseMessage::Error { target, code, message }) => {
444 let event_data = SsePayload::Error { target, code, message };
445 match serde_json::to_string(&event_data) {
446 Ok(json) => {
447 yield Ok::<Event, Infallible>(Event::default().event("error").data(json));
448 }
449 Err(e) => {
450 tracing::error!("Failed to serialize SSE error payload: {}", e);
451 }
452 }
453 }
454 None => break,
455 }
456 }
457 _ = cancel.cancelled() => break,
458 }
459 }
460
461 _guard.mark_closed();
463 reactor.remove_session(session_id).await;
464 sessions.write().await.remove(&session_id);
465 };
466
467 Sse::new(stream)
468 .keep_alive(
469 KeepAlive::new()
470 .interval(Duration::from_secs(keepalive_secs))
471 .text("ping"),
472 )
473 .into_response()
474}
475
476fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
478 match msg {
479 RealtimeMessage::Data {
480 subscription_id,
481 data,
482 } => Some(SseMessage::Data {
483 target: format!("sub:{}", subscription_id),
484 payload: data,
485 }),
486 RealtimeMessage::DeltaUpdate {
487 subscription_id,
488 delta,
489 } => match serde_json::to_value(&delta) {
490 Ok(payload) => Some(SseMessage::Data {
491 target: format!("sub:{}", subscription_id),
492 payload,
493 }),
494 Err(e) => {
495 tracing::error!("Failed to serialize delta update: {}", e);
496 Some(SseMessage::Error {
497 target: format!("sub:{}", subscription_id),
498 code: "SERIALIZATION_ERROR".to_string(),
499 message: "Failed to serialize update data".to_string(),
500 })
501 }
502 },
503 RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
504 Ok(payload) => Some(SseMessage::Data {
505 target: format!("job:{}", client_sub_id),
506 payload,
507 }),
508 Err(e) => {
509 tracing::error!("Failed to serialize job update: {}", e);
510 Some(SseMessage::Error {
511 target: format!("job:{}", client_sub_id),
512 code: "SERIALIZATION_ERROR".to_string(),
513 message: "Failed to serialize job update".to_string(),
514 })
515 }
516 },
517 RealtimeMessage::WorkflowUpdate {
518 client_sub_id,
519 workflow,
520 } => match serde_json::to_value(&workflow) {
521 Ok(payload) => Some(SseMessage::Data {
522 target: format!("wf:{}", client_sub_id),
523 payload,
524 }),
525 Err(e) => {
526 tracing::error!("Failed to serialize workflow update: {}", e);
527 Some(SseMessage::Error {
528 target: format!("wf:{}", client_sub_id),
529 code: "SERIALIZATION_ERROR".to_string(),
530 message: "Failed to serialize workflow update".to_string(),
531 })
532 }
533 },
534 RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
535 target: String::new(),
536 code,
537 message,
538 }),
539 RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
540 target: id,
541 code,
542 message,
543 }),
544 RealtimeMessage::Subscribe { .. }
546 | RealtimeMessage::Unsubscribe { .. }
547 | RealtimeMessage::Ping
548 | RealtimeMessage::Pong
549 | RealtimeMessage::AuthSuccess
550 | RealtimeMessage::AuthFailed { .. } => None,
551 }
552}
553
554pub async fn sse_subscribe_handler(
556 State(state): State<Arc<SseState>>,
557 Extension(request_auth): Extension<AuthContext>,
558 Json(request): Json<SseSubscribeRequest>,
559) -> impl IntoResponse {
560 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
562 return subscribe_error(
563 StatusCode::BAD_REQUEST,
564 "INVALID_ID",
565 format!(
566 "Subscription ID too long (max {} chars)",
567 MAX_CLIENT_SUB_ID_LEN
568 ),
569 );
570 }
571
572 let Some(session_id) = try_parse_session_id(&request.session_id) else {
573 return subscribe_error(
574 StatusCode::BAD_REQUEST,
575 "INVALID_SESSION",
576 "Invalid session ID format",
577 );
578 };
579
580 let sessions = state.sessions.read().await;
582 let session_data = match sessions.get(&session_id) {
583 Some(data) => {
584 match authorize_session_access(data, &request.session_secret, &request_auth) {
585 Ok(auth) => auth,
586 Err(resp) => return resp,
587 }
588 }
589 None => {
590 return subscribe_error(
591 StatusCode::NOT_FOUND,
592 "SESSION_NOT_FOUND",
593 "Session not found or expired",
594 );
595 }
596 };
597 drop(sessions);
598
599 let result = state
601 .reactor
602 .subscribe(
603 session_id,
604 request.id.clone(),
605 request.function,
606 request.args,
607 session_data,
608 )
609 .await;
610
611 match result {
612 Ok((subscription_id, data)) => {
613 let mut sessions = state.sessions.write().await;
615 match sessions.get_mut(&session_id) {
616 Some(session) => {
617 session.subscriptions.insert(request.id, subscription_id);
618 }
619 None => {
620 return subscribe_error(
622 StatusCode::NOT_FOUND,
623 "SESSION_NOT_FOUND",
624 "Session expired during subscription",
625 );
626 }
627 }
628
629 tracing::debug!(
630 %session_id,
631 %subscription_id,
632 "SSE subscription registered"
633 );
634
635 (
636 StatusCode::OK,
637 Json(SseSubscribeResponse {
638 success: true,
639 data: Some(data),
640 error: None,
641 }),
642 )
643 }
644 Err(e) => {
645 tracing::warn!(%session_id, error = %e, "SSE subscription failed");
646 match e {
647 forge_core::ForgeError::Unauthorized(msg) => {
648 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
649 }
650 forge_core::ForgeError::Forbidden(msg) => {
651 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
652 }
653 forge_core::ForgeError::InvalidArgument(msg)
654 | forge_core::ForgeError::Validation(msg) => {
655 subscribe_error(StatusCode::BAD_REQUEST, "INVALID_ARGUMENT", msg)
656 }
657 forge_core::ForgeError::NotFound(msg) => {
658 subscribe_error(StatusCode::NOT_FOUND, "NOT_FOUND", msg)
659 }
660 _ => subscribe_error(
661 StatusCode::INTERNAL_SERVER_ERROR,
662 "SUBSCRIPTION_FAILED",
663 "Subscription failed",
664 ),
665 }
666 }
667 }
668}
669
670pub async fn sse_unsubscribe_handler(
672 State(state): State<Arc<SseState>>,
673 Extension(request_auth): Extension<AuthContext>,
674 Json(request): Json<SseUnsubscribeRequest>,
675) -> impl IntoResponse {
676 let Some(session_id) = try_parse_session_id(&request.session_id) else {
677 return unsubscribe_error(
678 StatusCode::BAD_REQUEST,
679 "INVALID_SESSION",
680 "Invalid session ID format",
681 );
682 };
683
684 let subscription_id = {
686 let sessions = state.sessions.read().await;
687 match sessions.get(&session_id) {
688 Some(session) => {
689 if session.session_secret != request.session_secret
690 || !same_principal(&session.auth_context, &request_auth)
691 {
692 return unsubscribe_error(
693 StatusCode::FORBIDDEN,
694 "SESSION_PRINCIPAL_MISMATCH",
695 "Request principal does not match session principal",
696 );
697 }
698 session.subscriptions.get(&request.id).copied()
699 }
700 None => None,
701 }
702 };
703
704 let Some(subscription_id) = subscription_id else {
705 return unsubscribe_error(
706 StatusCode::NOT_FOUND,
707 "SUBSCRIPTION_NOT_FOUND",
708 "Subscription not found",
709 );
710 };
711
712 state.reactor.unsubscribe(subscription_id).await;
714
715 {
717 let mut sessions = state.sessions.write().await;
718 if let Some(session) = sessions.get_mut(&session_id) {
719 session.subscriptions.remove(&request.id);
720 }
721 }
722
723 tracing::debug!(
724 %session_id,
725 %subscription_id,
726 "SSE subscription removed"
727 );
728
729 (
730 StatusCode::OK,
731 Json(SseUnsubscribeResponse {
732 success: true,
733 error: None,
734 }),
735 )
736}
737
738pub async fn sse_job_subscribe_handler(
740 State(state): State<Arc<SseState>>,
741 Extension(request_auth): Extension<AuthContext>,
742 Json(request): Json<SseJobSubscribeRequest>,
743) -> impl IntoResponse {
744 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
745 return subscribe_error(
746 StatusCode::BAD_REQUEST,
747 "INVALID_ID",
748 format!(
749 "Subscription ID too long (max {} chars)",
750 MAX_CLIENT_SUB_ID_LEN
751 ),
752 );
753 }
754
755 let Some(session_id) = try_parse_session_id(&request.session_id) else {
756 return subscribe_error(
757 StatusCode::BAD_REQUEST,
758 "INVALID_SESSION",
759 "Invalid session ID format",
760 );
761 };
762
763 let session_auth = {
765 let sessions = state.sessions.read().await;
766 match sessions.get(&session_id) {
767 Some(session) => {
768 match authorize_session_access(session, &request.session_secret, &request_auth) {
769 Ok(auth) => auth,
770 Err(resp) => return resp,
771 }
772 }
773 None => {
774 return subscribe_error(
775 StatusCode::NOT_FOUND,
776 "SESSION_NOT_FOUND",
777 "Session not found or expired",
778 );
779 }
780 }
781 };
782
783 let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
785 Ok(uuid) => uuid,
786 Err(_) => {
787 return subscribe_error(
788 StatusCode::BAD_REQUEST,
789 "INVALID_JOB_ID",
790 "Invalid job ID format",
791 );
792 }
793 };
794
795 match state
797 .reactor
798 .subscribe_job(session_id, request.id.clone(), job_uuid, &session_auth)
799 .await
800 {
801 Ok(job_data) => {
802 let data = match serde_json::to_value(&job_data) {
803 Ok(v) => v,
804 Err(e) => {
805 tracing::error!("Failed to serialize job data: {}", e);
806 return subscribe_error(
807 StatusCode::INTERNAL_SERVER_ERROR,
808 "SERIALIZE_ERROR",
809 "Failed to serialize job data",
810 );
811 }
812 };
813 tracing::debug!(
814 %session_id,
815 job_id = %request.job_id,
816 client_sub_id = %request.id,
817 "SSE job subscription registered"
818 );
819 (
820 StatusCode::OK,
821 Json(SseSubscribeResponse {
822 success: true,
823 data: Some(data),
824 error: None,
825 }),
826 )
827 }
828 Err(e) => match e {
829 forge_core::ForgeError::Unauthorized(msg) => {
830 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
831 }
832 forge_core::ForgeError::Forbidden(msg) => {
833 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
834 }
835 forge_core::ForgeError::NotFound(msg) => {
836 subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", msg)
837 }
838 _ => subscribe_error(
839 StatusCode::INTERNAL_SERVER_ERROR,
840 "SUBSCRIPTION_FAILED",
841 "Subscription failed",
842 ),
843 },
844 }
845}
846
847pub async fn sse_workflow_subscribe_handler(
849 State(state): State<Arc<SseState>>,
850 Extension(request_auth): Extension<AuthContext>,
851 Json(request): Json<SseWorkflowSubscribeRequest>,
852) -> impl IntoResponse {
853 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
854 return subscribe_error(
855 StatusCode::BAD_REQUEST,
856 "INVALID_ID",
857 format!(
858 "Subscription ID too long (max {} chars)",
859 MAX_CLIENT_SUB_ID_LEN
860 ),
861 );
862 }
863
864 let Some(session_id) = try_parse_session_id(&request.session_id) else {
865 return subscribe_error(
866 StatusCode::BAD_REQUEST,
867 "INVALID_SESSION",
868 "Invalid session ID format",
869 );
870 };
871
872 let session_auth = {
874 let sessions = state.sessions.read().await;
875 match sessions.get(&session_id) {
876 Some(session) => {
877 match authorize_session_access(session, &request.session_secret, &request_auth) {
878 Ok(auth) => auth,
879 Err(resp) => return resp,
880 }
881 }
882 None => {
883 return subscribe_error(
884 StatusCode::NOT_FOUND,
885 "SESSION_NOT_FOUND",
886 "Session not found or expired",
887 );
888 }
889 }
890 };
891
892 let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
894 Ok(uuid) => uuid,
895 Err(_) => {
896 return subscribe_error(
897 StatusCode::BAD_REQUEST,
898 "INVALID_WORKFLOW_ID",
899 "Invalid workflow ID format",
900 );
901 }
902 };
903
904 match state
906 .reactor
907 .subscribe_workflow(session_id, request.id.clone(), workflow_uuid, &session_auth)
908 .await
909 {
910 Ok(workflow_data) => {
911 let data = match serde_json::to_value(&workflow_data) {
912 Ok(v) => v,
913 Err(e) => {
914 tracing::error!("Failed to serialize workflow data: {}", e);
915 return subscribe_error(
916 StatusCode::INTERNAL_SERVER_ERROR,
917 "SERIALIZE_ERROR",
918 "Failed to serialize workflow data",
919 );
920 }
921 };
922 tracing::debug!(
923 %session_id,
924 workflow_id = %request.workflow_id,
925 client_sub_id = %request.id,
926 "SSE workflow subscription registered"
927 );
928 (
929 StatusCode::OK,
930 Json(SseSubscribeResponse {
931 success: true,
932 data: Some(data),
933 error: None,
934 }),
935 )
936 }
937 Err(e) => match e {
938 forge_core::ForgeError::Unauthorized(msg) => {
939 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
940 }
941 forge_core::ForgeError::Forbidden(msg) => {
942 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
943 }
944 forge_core::ForgeError::NotFound(msg) => {
945 subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", msg)
946 }
947 _ => subscribe_error(
948 StatusCode::INTERNAL_SERVER_ERROR,
949 "SUBSCRIPTION_FAILED",
950 "Subscription failed",
951 ),
952 },
953 }
954}
955
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Serializes a payload into a `serde_json::Value` for field-level checks.
    fn to_json(payload: &SsePayload) -> serde_json::Value {
        serde_json::to_value(payload).unwrap()
    }

    #[test]
    fn test_sse_payload_serialization() {
        let payload = SsePayload::Update {
            target: "sub:123".to_string(),
            payload: serde_json::json!({"id": 1}),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"update\""));
        assert!(json.contains("\"target\":\"sub:123\""));

        // The nested payload must be carried through untouched.
        let value = to_json(&payload);
        assert_eq!(value["payload"]["id"], 1);
    }

    #[test]
    fn test_sse_error_serialization() {
        let payload = SsePayload::Error {
            target: "sub:456".to_string(),
            code: "NOT_FOUND".to_string(),
            message: "Subscription not found".to_string(),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("NOT_FOUND"));

        let value = to_json(&payload);
        assert_eq!(value["target"], "sub:456");
        assert_eq!(value["message"], "Subscription not found");
    }

    #[test]
    fn test_sse_connected_serialization() {
        // Clients read the session credentials from this event, so the field
        // names are part of the wire contract.
        let payload = SsePayload::Connected {
            session_id: "abc".to_string(),
            session_secret: "s3cr3t".to_string(),
        };
        let value = to_json(&payload);
        assert_eq!(value["type"], "connected");
        assert_eq!(value["session_id"], "abc");
        assert_eq!(value["session_secret"], "s3cr3t");
    }
}