1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use axum::Json;
11use axum::extract::{Extension, Query, State};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::response::sse::{Event, KeepAlive, Sse};
15use futures_util::Stream;
16use serde::{Deserialize, Serialize};
17use tokio::sync::{RwLock, mpsc};
18use tokio_util::sync::CancellationToken;
19
20struct ReceiverStream<T> {
22 rx: mpsc::Receiver<T>,
23}
24
25impl<T> Stream for ReceiverStream<T> {
26 type Item = T;
27 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
28 self.rx.poll_recv(cx)
29 }
30}
31
32use forge_core::function::AuthContext;
33use forge_core::realtime::{SessionId, SubscriptionId};
34
35use super::auth::AuthMiddleware;
36use crate::realtime::Reactor;
37use crate::realtime::RealtimeMessage;
38
/// Maximum accepted length of a client-chosen subscription id (`request.id`).
///
/// NOTE(review): the handlers compare this against `str::len()`, which counts
/// bytes, while their error message says "chars"; multi-byte ids hit the limit
/// sooner than the message suggests.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
41
42fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
43 uuid::Uuid::parse_str(session_id)
44 .ok()
45 .map(SessionId::from_uuid)
46}
47
48fn same_principal(a: &AuthContext, b: &AuthContext) -> bool {
49 match (a.is_authenticated(), b.is_authenticated()) {
50 (false, false) => true,
51 (true, true) => a.principal_id().is_some() && a.principal_id() == b.principal_id(),
52 _ => false,
53 }
54}
55
56fn authorize_session_access(
57 session: &SseSessionData,
58 session_secret: &str,
59 requester_auth: &AuthContext,
60) -> Result<AuthContext, (StatusCode, Json<SseSubscribeResponse>)> {
61 if session.session_secret != session_secret {
62 return Err(subscribe_error(
63 StatusCode::UNAUTHORIZED,
64 "INVALID_SESSION_SECRET",
65 "Session secret mismatch",
66 ));
67 }
68
69 if !same_principal(&session.auth_context, requester_auth) {
70 return Err(subscribe_error(
71 StatusCode::FORBIDDEN,
72 "SESSION_PRINCIPAL_MISMATCH",
73 "Request principal does not match session principal",
74 ));
75 }
76
77 Ok(session.auth_context.clone())
78}
79
/// Tunables for the SSE transport.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Upper bound on concurrently registered SSE sessions; new connections
    /// beyond it are answered with `503 Service Unavailable`.
    pub max_sessions: usize,
    /// Capacity of each per-session channel created by the SSE handler.
    pub channel_buffer_size: usize,
    /// Interval, in seconds, between keep-alive messages on an open stream.
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    /// 10 000 sessions, 256-slot channels, 30-second keep-alive.
    fn default() -> Self {
        let max_sessions = 10_000;
        let channel_buffer_size = 256;
        let keepalive_interval_secs = 30;
        Self {
            max_sessions,
            channel_buffer_size,
            keepalive_interval_secs,
        }
    }
}
100
/// Query-string parameters accepted when opening an SSE stream.
#[derive(Debug, Deserialize)]
pub struct SseQuery {
    /// Optional auth token.
    ///
    /// When present it is validated, and the connection is rejected with `401`
    /// if validation fails. When absent the session is unauthenticated.
    pub token: Option<String>,
}
107
/// Server-side record for one open SSE session.
struct SseSessionData {
    // Auth context established when the stream was opened; subsequent
    // requests must come from the same principal.
    auth_context: AuthContext,
    // Random per-session secret delivered in the `connected` event; every
    // subscribe/unsubscribe request must present it.
    session_secret: String,
    // Client-chosen subscription id -> reactor subscription id. Only query
    // subscriptions are recorded here; job/workflow subscriptions are not.
    subscriptions: HashMap<String, SubscriptionId>,
}
114
115#[derive(Clone)]
117pub struct SseState {
118 reactor: Arc<Reactor>,
119 auth_middleware: Arc<AuthMiddleware>,
120 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
122 config: SseConfig,
123}
124
125impl SseState {
126 pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
128 Self::with_config(reactor, auth_middleware, SseConfig::default())
129 }
130
131 pub fn with_config(
133 reactor: Arc<Reactor>,
134 auth_middleware: Arc<AuthMiddleware>,
135 config: SseConfig,
136 ) -> Self {
137 Self {
138 reactor,
139 auth_middleware,
140 sessions: Arc::new(RwLock::new(HashMap::new())),
141 config,
142 }
143 }
144
145 pub async fn can_accept_session(&self) -> bool {
147 self.sessions.read().await.len() < self.config.max_sessions
148 }
149}
150
151struct SessionCleanupGuard {
154 session_id: SessionId,
155 reactor: Arc<Reactor>,
156 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
157 dropped: bool,
158}
159
160impl SessionCleanupGuard {
161 fn new(
162 session_id: SessionId,
163 reactor: Arc<Reactor>,
164 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
165 ) -> Self {
166 Self {
167 session_id,
168 reactor,
169 sessions,
170 dropped: false,
171 }
172 }
173
174 fn mark_closed(&mut self) {
176 self.dropped = true;
177 }
178}
179
180impl Drop for SessionCleanupGuard {
181 fn drop(&mut self) {
182 if self.dropped {
183 return;
184 }
185 let session_id = self.session_id;
186 let reactor = self.reactor.clone();
187 let sessions = self.sessions.clone();
188
189 crate::observability::set_active_connections("sse", -1);
192 if let Ok(handle) = tokio::runtime::Handle::try_current() {
193 handle.spawn(async move {
194 reactor.remove_session(session_id).await;
195 sessions.write().await.remove(&session_id);
196 tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
197 });
198 } else {
199 tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
201 }
202 }
203}
204
/// JSON body of an outgoing SSE event.
///
/// Serialized with an internal `type` tag whose values are the snake_case
/// variant names: `update`, `error`, `connected`.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SsePayload {
    /// New data for a subscription.
    ///
    /// `target` identifies the subscription, e.g. `sub:<id>`, `job:<id>` or
    /// `wf:<id>`.
    Update {
        target: String,
        payload: serde_json::Value,
    },
    /// An error.
    ///
    /// `target` is the affected subscription, or empty when the error is not
    /// attributed to one.
    Error {
        target: String,
        code: String,
        message: String,
    },
    /// First event on every stream.
    ///
    /// Carries the session id and secret the client must send on later
    /// subscribe/unsubscribe requests.
    Connected {
        session_id: String,
        session_secret: String,
    },
}
226
/// Message queued for one SSE session before it is serialized into an event.
#[derive(Debug)]
pub enum SseMessage {
    /// Delivered to the client as an `update` event.
    Data {
        target: String,
        payload: serde_json::Value,
    },
    /// Delivered to the client as an `error` event.
    Error {
        target: String,
        code: String,
        message: String,
    },
}
240
/// Body of a query-subscription request.
#[derive(Debug, Deserialize)]
pub struct SseSubscribeRequest {
    /// Session id from the `connected` event (UUID text).
    pub session_id: String,
    /// Session secret from the `connected` event.
    pub session_secret: String,
    /// Client-chosen subscription id; the key used by later unsubscribe calls.
    pub id: String,
    /// Function name passed to the reactor's `subscribe`.
    pub function: String,
    /// Function arguments; JSON `null` when omitted.
    #[serde(default)]
    pub args: serde_json::Value,
}

/// Body of a query-unsubscribe request.
#[derive(Debug, Deserialize)]
pub struct SseUnsubscribeRequest {
    /// Session id from the `connected` event (UUID text).
    pub session_id: String,
    /// Session secret from the `connected` event.
    pub session_secret: String,
    /// Client subscription id used in the original subscribe request.
    pub id: String,
}

/// Body of a job-subscription request.
#[derive(Debug, Deserialize)]
pub struct SseJobSubscribeRequest {
    /// Session id from the `connected` event (UUID text).
    pub session_id: String,
    /// Session secret from the `connected` event.
    pub session_secret: String,
    /// Client subscription id. Presumably echoed back in `job:<id>` update
    /// targets; verify against the reactor.
    pub id: String,
    /// Job to watch (UUID text).
    pub job_id: String,
}

/// Body of a workflow-subscription request.
#[derive(Debug, Deserialize)]
pub struct SseWorkflowSubscribeRequest {
    /// Session id from the `connected` event (UUID text).
    pub session_id: String,
    /// Session secret from the `connected` event.
    pub session_secret: String,
    /// Client subscription id. Presumably echoed back in `wf:<id>` update
    /// targets; verify against the reactor.
    pub id: String,
    /// Workflow to watch (UUID text).
    pub workflow_id: String,
}
277
/// Machine-readable error carried in SSE HTTP responses.
#[derive(Debug, Serialize)]
pub struct SseError {
    pub code: String,
    pub message: String,
}

/// Response body of the subscribe endpoints. `None` fields are omitted from the JSON.
#[derive(Debug, Serialize)]
pub struct SseSubscribeResponse {
    pub success: bool,
    /// Initial data for the new subscription (success only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
    /// Failure details (error only).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<SseError>,
}

/// Response body of the unsubscribe endpoint. `error` is omitted when `None`.
#[derive(Debug, Serialize)]
pub struct SseUnsubscribeResponse {
    pub success: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<SseError>,
}
302
303fn subscribe_error(
305 status: StatusCode,
306 code: impl Into<String>,
307 message: impl Into<String>,
308) -> (StatusCode, Json<SseSubscribeResponse>) {
309 (
310 status,
311 Json(SseSubscribeResponse {
312 success: false,
313 data: None,
314 error: Some(SseError {
315 code: code.into(),
316 message: message.into(),
317 }),
318 }),
319 )
320}
321
322fn unsubscribe_error(
324 status: StatusCode,
325 code: impl Into<String>,
326 message: impl Into<String>,
327) -> (StatusCode, Json<SseUnsubscribeResponse>) {
328 (
329 status,
330 Json(SseUnsubscribeResponse {
331 success: false,
332 error: Some(SseError {
333 code: code.into(),
334 message: message.into(),
335 }),
336 }),
337 )
338}
339
340pub async fn sse_handler(
342 State(state): State<Arc<SseState>>,
343 Query(query): Query<SseQuery>,
344) -> impl IntoResponse {
345 if !state.can_accept_session().await {
347 return (
348 StatusCode::SERVICE_UNAVAILABLE,
349 "Server at capacity".to_string(),
350 )
351 .into_response();
352 }
353
354 let session_id = SessionId::new();
355 let buffer_size = state.config.channel_buffer_size;
356 let keepalive_secs = state.config.keepalive_interval_secs;
357 let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
358 let cancel_token = CancellationToken::new();
359
360 let auth_context = if let Some(token) = &query.token {
361 match state.auth_middleware.validate_token_async(token).await {
362 Ok(claims) => super::auth::build_auth_context_from_claims(claims),
363 Err(e) => {
364 tracing::warn!("SSE token validation failed: {}", e);
365 return (
366 StatusCode::UNAUTHORIZED,
367 "Invalid authentication token".to_string(),
368 )
369 .into_response();
370 }
371 }
372 } else {
373 forge_core::function::AuthContext::unauthenticated()
374 };
375 let session_secret = uuid::Uuid::new_v4().to_string();
376
377 let reactor = state.reactor.clone();
379 let cancel = cancel_token.clone();
380
381 let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
383 reactor.register_session(session_id, rt_tx);
384
385 {
387 let mut sessions = state.sessions.write().await;
388 sessions.insert(
389 session_id,
390 SseSessionData {
391 auth_context: auth_context.clone(),
392 session_secret: session_secret.clone(),
393 subscriptions: HashMap::new(),
394 },
395 );
396 }
397
398 let sessions = state.sessions.clone();
400
401 let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
403 crate::observability::set_active_connections("sse", 1);
404
405 let bridge_cancel = cancel_token.clone();
407 tokio::spawn(async move {
408 loop {
409 tokio::select! {
410 msg = rt_rx.recv() => {
411 match msg {
412 Some(rt_msg) => {
413 if let Some(sse_msg) = convert_realtime_to_sse(rt_msg)
414 && tx.send(sse_msg).await.is_err() {
415 break;
416 }
417 }
418 None => break,
419 }
420 }
421 _ = bridge_cancel.cancelled() => break,
422 }
423 }
424 });
425
426 let (event_tx, event_rx) = mpsc::channel::<Result<Event, Infallible>>(buffer_size);
428
429 tokio::spawn(async move {
430 let mut _guard = cleanup_guard;
431
432 let connected = SsePayload::Connected {
434 session_id: session_id.to_string(),
435 session_secret: session_secret.clone(),
436 };
437 match serde_json::to_string(&connected) {
438 Ok(json) => {
439 let _ = event_tx
440 .send(Ok(Event::default().event("connected").data(json)))
441 .await;
442 }
443 Err(e) => {
444 tracing::error!("Failed to serialize SSE connected payload: {}", e);
445 }
446 }
447
448 loop {
449 tokio::select! {
450 msg = rx.recv() => {
451 match msg {
452 Some(sse_msg) => {
453 let event = match sse_msg {
454 SseMessage::Data { target, payload } => {
455 let data = SsePayload::Update { target, payload };
456 serde_json::to_string(&data).ok().map(|json| {
457 Event::default().event("update").data(json)
458 })
459 }
460 SseMessage::Error { target, code, message } => {
461 let data = SsePayload::Error { target, code, message };
462 serde_json::to_string(&data).ok().map(|json| {
463 Event::default().event("error").data(json)
464 })
465 }
466 };
467 if let Some(evt) = event
468 && event_tx.send(Ok(evt)).await.is_err()
469 {
470 break;
471 }
472 }
473 None => break,
474 }
475 }
476 _ = cancel.cancelled() => break,
477 }
478 }
479
480 _guard.mark_closed();
482 reactor.remove_session(session_id).await;
483 sessions.write().await.remove(&session_id);
484 });
485
486 let stream = ReceiverStream { rx: event_rx };
487
488 Sse::new(stream)
489 .keep_alive(
490 KeepAlive::new()
491 .interval(Duration::from_secs(keepalive_secs))
492 .text("ping"),
493 )
494 .into_response()
495}
496
497fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
499 match msg {
500 RealtimeMessage::Data {
501 subscription_id,
502 data,
503 } => Some(SseMessage::Data {
504 target: format!("sub:{}", subscription_id),
505 payload: data,
506 }),
507 RealtimeMessage::DeltaUpdate {
508 subscription_id,
509 delta,
510 } => match serde_json::to_value(&delta) {
511 Ok(payload) => Some(SseMessage::Data {
512 target: format!("sub:{}", subscription_id),
513 payload,
514 }),
515 Err(e) => {
516 tracing::error!("Failed to serialize delta update: {}", e);
517 Some(SseMessage::Error {
518 target: format!("sub:{}", subscription_id),
519 code: "SERIALIZATION_ERROR".to_string(),
520 message: "Failed to serialize update data".to_string(),
521 })
522 }
523 },
524 RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
525 Ok(payload) => Some(SseMessage::Data {
526 target: format!("job:{}", client_sub_id),
527 payload,
528 }),
529 Err(e) => {
530 tracing::error!("Failed to serialize job update: {}", e);
531 Some(SseMessage::Error {
532 target: format!("job:{}", client_sub_id),
533 code: "SERIALIZATION_ERROR".to_string(),
534 message: "Failed to serialize job update".to_string(),
535 })
536 }
537 },
538 RealtimeMessage::WorkflowUpdate {
539 client_sub_id,
540 workflow,
541 } => match serde_json::to_value(&workflow) {
542 Ok(payload) => Some(SseMessage::Data {
543 target: format!("wf:{}", client_sub_id),
544 payload,
545 }),
546 Err(e) => {
547 tracing::error!("Failed to serialize workflow update: {}", e);
548 Some(SseMessage::Error {
549 target: format!("wf:{}", client_sub_id),
550 code: "SERIALIZATION_ERROR".to_string(),
551 message: "Failed to serialize workflow update".to_string(),
552 })
553 }
554 },
555 RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
556 target: String::new(),
557 code,
558 message,
559 }),
560 RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
561 target: id,
562 code,
563 message,
564 }),
565 RealtimeMessage::Subscribe { .. }
567 | RealtimeMessage::Unsubscribe { .. }
568 | RealtimeMessage::Ping
569 | RealtimeMessage::Pong
570 | RealtimeMessage::AuthSuccess
571 | RealtimeMessage::AuthFailed { .. }
572 | RealtimeMessage::Lagging => None,
573 }
574}
575
576pub async fn sse_subscribe_handler(
578 State(state): State<Arc<SseState>>,
579 Extension(request_auth): Extension<AuthContext>,
580 Json(request): Json<SseSubscribeRequest>,
581) -> impl IntoResponse {
582 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
584 return subscribe_error(
585 StatusCode::BAD_REQUEST,
586 "INVALID_ID",
587 format!(
588 "Subscription ID too long (max {} chars)",
589 MAX_CLIENT_SUB_ID_LEN
590 ),
591 );
592 }
593
594 let Some(session_id) = try_parse_session_id(&request.session_id) else {
595 return subscribe_error(
596 StatusCode::BAD_REQUEST,
597 "INVALID_SESSION",
598 "Invalid session ID format",
599 );
600 };
601
602 let sessions = state.sessions.read().await;
604 let session_data = match sessions.get(&session_id) {
605 Some(data) => {
606 match authorize_session_access(data, &request.session_secret, &request_auth) {
607 Ok(auth) => auth,
608 Err(resp) => return resp,
609 }
610 }
611 None => {
612 return subscribe_error(
613 StatusCode::NOT_FOUND,
614 "SESSION_NOT_FOUND",
615 "Session not found or expired",
616 );
617 }
618 };
619 drop(sessions);
620
621 let result = state
623 .reactor
624 .subscribe(
625 session_id,
626 request.id.clone(),
627 request.function,
628 request.args,
629 session_data,
630 )
631 .await;
632
633 match result {
634 Ok((subscription_id, data)) => {
635 let mut sessions = state.sessions.write().await;
637 match sessions.get_mut(&session_id) {
638 Some(session) => {
639 session.subscriptions.insert(request.id, subscription_id);
640 }
641 None => {
642 return subscribe_error(
644 StatusCode::NOT_FOUND,
645 "SESSION_NOT_FOUND",
646 "Session expired during subscription",
647 );
648 }
649 }
650
651 tracing::debug!(
652 %session_id,
653 %subscription_id,
654 "SSE subscription registered"
655 );
656
657 (
658 StatusCode::OK,
659 Json(SseSubscribeResponse {
660 success: true,
661 data: Some(data),
662 error: None,
663 }),
664 )
665 }
666 Err(e) => {
667 tracing::warn!(%session_id, error = %e, "SSE subscription failed");
668 match e {
669 forge_core::ForgeError::Unauthorized(msg) => {
670 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
671 }
672 forge_core::ForgeError::Forbidden(msg) => {
673 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
674 }
675 forge_core::ForgeError::InvalidArgument(msg)
676 | forge_core::ForgeError::Validation(msg) => {
677 subscribe_error(StatusCode::BAD_REQUEST, "INVALID_ARGUMENT", msg)
678 }
679 forge_core::ForgeError::NotFound(msg) => {
680 subscribe_error(StatusCode::NOT_FOUND, "NOT_FOUND", msg)
681 }
682 _ => subscribe_error(
683 StatusCode::INTERNAL_SERVER_ERROR,
684 "SUBSCRIPTION_FAILED",
685 "Subscription failed",
686 ),
687 }
688 }
689 }
690}
691
692pub async fn sse_unsubscribe_handler(
694 State(state): State<Arc<SseState>>,
695 Extension(request_auth): Extension<AuthContext>,
696 Json(request): Json<SseUnsubscribeRequest>,
697) -> impl IntoResponse {
698 let Some(session_id) = try_parse_session_id(&request.session_id) else {
699 return unsubscribe_error(
700 StatusCode::BAD_REQUEST,
701 "INVALID_SESSION",
702 "Invalid session ID format",
703 );
704 };
705
706 let subscription_id = {
708 let sessions = state.sessions.read().await;
709 match sessions.get(&session_id) {
710 Some(session) => {
711 if session.session_secret != request.session_secret
712 || !same_principal(&session.auth_context, &request_auth)
713 {
714 return unsubscribe_error(
715 StatusCode::FORBIDDEN,
716 "SESSION_PRINCIPAL_MISMATCH",
717 "Request principal does not match session principal",
718 );
719 }
720 session.subscriptions.get(&request.id).copied()
721 }
722 None => None,
723 }
724 };
725
726 let Some(subscription_id) = subscription_id else {
727 return unsubscribe_error(
728 StatusCode::NOT_FOUND,
729 "SUBSCRIPTION_NOT_FOUND",
730 "Subscription not found",
731 );
732 };
733
734 state.reactor.unsubscribe(subscription_id);
736
737 {
739 let mut sessions = state.sessions.write().await;
740 if let Some(session) = sessions.get_mut(&session_id) {
741 session.subscriptions.remove(&request.id);
742 }
743 }
744
745 tracing::debug!(
746 %session_id,
747 %subscription_id,
748 "SSE subscription removed"
749 );
750
751 (
752 StatusCode::OK,
753 Json(SseUnsubscribeResponse {
754 success: true,
755 error: None,
756 }),
757 )
758}
759
760pub async fn sse_job_subscribe_handler(
762 State(state): State<Arc<SseState>>,
763 Extension(request_auth): Extension<AuthContext>,
764 Json(request): Json<SseJobSubscribeRequest>,
765) -> impl IntoResponse {
766 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
767 return subscribe_error(
768 StatusCode::BAD_REQUEST,
769 "INVALID_ID",
770 format!(
771 "Subscription ID too long (max {} chars)",
772 MAX_CLIENT_SUB_ID_LEN
773 ),
774 );
775 }
776
777 let Some(session_id) = try_parse_session_id(&request.session_id) else {
778 return subscribe_error(
779 StatusCode::BAD_REQUEST,
780 "INVALID_SESSION",
781 "Invalid session ID format",
782 );
783 };
784
785 let session_auth = {
787 let sessions = state.sessions.read().await;
788 match sessions.get(&session_id) {
789 Some(session) => {
790 match authorize_session_access(session, &request.session_secret, &request_auth) {
791 Ok(auth) => auth,
792 Err(resp) => return resp,
793 }
794 }
795 None => {
796 return subscribe_error(
797 StatusCode::NOT_FOUND,
798 "SESSION_NOT_FOUND",
799 "Session not found or expired",
800 );
801 }
802 }
803 };
804
805 let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
807 Ok(uuid) => uuid,
808 Err(_) => {
809 return subscribe_error(
810 StatusCode::BAD_REQUEST,
811 "INVALID_JOB_ID",
812 "Invalid job ID format",
813 );
814 }
815 };
816
817 match state
819 .reactor
820 .subscribe_job(session_id, request.id.clone(), job_uuid, &session_auth)
821 .await
822 {
823 Ok(job_data) => {
824 let data = match serde_json::to_value(&job_data) {
825 Ok(v) => v,
826 Err(e) => {
827 tracing::error!("Failed to serialize job data: {}", e);
828 return subscribe_error(
829 StatusCode::INTERNAL_SERVER_ERROR,
830 "SERIALIZE_ERROR",
831 "Failed to serialize job data",
832 );
833 }
834 };
835 tracing::debug!(
836 %session_id,
837 job_id = %request.job_id,
838 client_sub_id = %request.id,
839 "SSE job subscription registered"
840 );
841 (
842 StatusCode::OK,
843 Json(SseSubscribeResponse {
844 success: true,
845 data: Some(data),
846 error: None,
847 }),
848 )
849 }
850 Err(e) => match e {
851 forge_core::ForgeError::Unauthorized(msg) => {
852 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
853 }
854 forge_core::ForgeError::Forbidden(msg) => {
855 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
856 }
857 forge_core::ForgeError::NotFound(msg) => {
858 subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", msg)
859 }
860 _ => subscribe_error(
861 StatusCode::INTERNAL_SERVER_ERROR,
862 "SUBSCRIPTION_FAILED",
863 "Subscription failed",
864 ),
865 },
866 }
867}
868
869pub async fn sse_workflow_subscribe_handler(
871 State(state): State<Arc<SseState>>,
872 Extension(request_auth): Extension<AuthContext>,
873 Json(request): Json<SseWorkflowSubscribeRequest>,
874) -> impl IntoResponse {
875 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
876 return subscribe_error(
877 StatusCode::BAD_REQUEST,
878 "INVALID_ID",
879 format!(
880 "Subscription ID too long (max {} chars)",
881 MAX_CLIENT_SUB_ID_LEN
882 ),
883 );
884 }
885
886 let Some(session_id) = try_parse_session_id(&request.session_id) else {
887 return subscribe_error(
888 StatusCode::BAD_REQUEST,
889 "INVALID_SESSION",
890 "Invalid session ID format",
891 );
892 };
893
894 let session_auth = {
896 let sessions = state.sessions.read().await;
897 match sessions.get(&session_id) {
898 Some(session) => {
899 match authorize_session_access(session, &request.session_secret, &request_auth) {
900 Ok(auth) => auth,
901 Err(resp) => return resp,
902 }
903 }
904 None => {
905 return subscribe_error(
906 StatusCode::NOT_FOUND,
907 "SESSION_NOT_FOUND",
908 "Session not found or expired",
909 );
910 }
911 }
912 };
913
914 let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
916 Ok(uuid) => uuid,
917 Err(_) => {
918 return subscribe_error(
919 StatusCode::BAD_REQUEST,
920 "INVALID_WORKFLOW_ID",
921 "Invalid workflow ID format",
922 );
923 }
924 };
925
926 match state
928 .reactor
929 .subscribe_workflow(session_id, request.id.clone(), workflow_uuid, &session_auth)
930 .await
931 {
932 Ok(workflow_data) => {
933 let data = match serde_json::to_value(&workflow_data) {
934 Ok(v) => v,
935 Err(e) => {
936 tracing::error!("Failed to serialize workflow data: {}", e);
937 return subscribe_error(
938 StatusCode::INTERNAL_SERVER_ERROR,
939 "SERIALIZE_ERROR",
940 "Failed to serialize workflow data",
941 );
942 }
943 };
944 tracing::debug!(
945 %session_id,
946 workflow_id = %request.workflow_id,
947 client_sub_id = %request.id,
948 "SSE workflow subscription registered"
949 );
950 (
951 StatusCode::OK,
952 Json(SseSubscribeResponse {
953 success: true,
954 data: Some(data),
955 error: None,
956 }),
957 )
958 }
959 Err(e) => match e {
960 forge_core::ForgeError::Unauthorized(msg) => {
961 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
962 }
963 forge_core::ForgeError::Forbidden(msg) => {
964 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
965 }
966 forge_core::ForgeError::NotFound(msg) => {
967 subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", msg)
968 }
969 _ => subscribe_error(
970 StatusCode::INTERNAL_SERVER_ERROR,
971 "SUBSCRIPTION_FAILED",
972 "Subscription failed",
973 ),
974 },
975 }
976}
977
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_sse_payload_serialization() {
        let payload = SsePayload::Update {
            target: "sub:123".to_string(),
            payload: serde_json::json!({"id": 1}),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"update\""));
        assert!(json.contains("\"target\":\"sub:123\""));
    }

    #[test]
    fn test_sse_error_serialization() {
        let payload = SsePayload::Error {
            target: "sub:456".to_string(),
            code: "NOT_FOUND".to_string(),
            message: "Subscription not found".to_string(),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("NOT_FOUND"));
    }

    /// The `connected` event is the client handshake; pin its exact shape.
    #[test]
    fn test_sse_connected_serialization() {
        let payload = SsePayload::Connected {
            session_id: "abc".to_string(),
            session_secret: "s3cret".to_string(),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert_eq!(
            json,
            r#"{"type":"connected","session_id":"abc","session_secret":"s3cret"}"#
        );
    }

    #[test]
    fn test_subscribe_response_omits_empty_fields() {
        let ok = SseSubscribeResponse {
            success: true,
            data: None,
            error: None,
        };
        assert_eq!(serde_json::to_string(&ok).unwrap(), r#"{"success":true}"#);
    }

    #[test]
    fn test_unsubscribe_response_error_shape() {
        let resp = SseUnsubscribeResponse {
            success: false,
            error: Some(SseError {
                code: "X".to_string(),
                message: "y".to_string(),
            }),
        };
        assert_eq!(
            serde_json::to_string(&resp).unwrap(),
            r#"{"success":false,"error":{"code":"X","message":"y"}}"#
        );
    }

    #[test]
    fn test_default_config() {
        let cfg = SseConfig::default();
        assert_eq!(cfg.max_sessions, 10_000);
        assert_eq!(cfg.channel_buffer_size, 256);
        assert_eq!(cfg.keepalive_interval_secs, 30);
    }

    #[test]
    fn test_try_parse_session_id() {
        assert!(try_parse_session_id("not-a-uuid").is_none());
        assert!(try_parse_session_id("").is_none());
        assert!(try_parse_session_id(&uuid::Uuid::new_v4().to_string()).is_some());
    }

    #[test]
    fn test_same_principal_unauthenticated() {
        let anon = AuthContext::unauthenticated();
        assert!(same_principal(&anon, &anon));
    }
}