1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use axum::Json;
11use axum::extract::{Extension, Query, State};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::response::sse::{Event, KeepAlive, Sse};
15use futures_util::Stream;
16use serde::{Deserialize, Serialize};
17use tokio::sync::{RwLock, mpsc};
18use tokio_util::sync::CancellationToken;
19
20struct ReceiverStream<T> {
22 rx: mpsc::Receiver<T>,
23}
24
25impl<T> Stream for ReceiverStream<T> {
26 type Item = T;
27 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
28 self.rx.poll_recv(cx)
29 }
30}
31
32use forge_core::function::AuthContext;
33use forge_core::realtime::{SessionId, SubscriptionId};
34
35use super::auth::AuthMiddleware;
36use crate::realtime::Reactor;
37use crate::realtime::RealtimeMessage;
38
/// Maximum length of a client-supplied subscription `id` accepted by the
/// subscribe endpoints.
///
/// Checked with `str::len`, so the limit is measured in bytes even though the
/// error message returned to clients says "chars".
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
41
42fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
43 uuid::Uuid::parse_str(session_id)
44 .ok()
45 .map(SessionId::from_uuid)
46}
47
48fn same_principal(a: &AuthContext, b: &AuthContext) -> bool {
49 match (a.is_authenticated(), b.is_authenticated()) {
50 (false, false) => true,
51 (true, true) => a.principal_id().is_some() && a.principal_id() == b.principal_id(),
52 _ => false,
53 }
54}
55
56fn authorize_session_access(
57 session: &SseSessionData,
58 session_secret: &str,
59 requester_auth: &AuthContext,
60) -> Result<AuthContext, (StatusCode, Json<SseSubscribeResponse>)> {
61 if session.session_secret != session_secret {
62 return Err(subscribe_error(
63 StatusCode::UNAUTHORIZED,
64 "INVALID_SESSION_SECRET",
65 "Session secret mismatch",
66 ));
67 }
68
69 if !same_principal(&session.auth_context, requester_auth) {
70 return Err(subscribe_error(
71 StatusCode::FORBIDDEN,
72 "SESSION_PRINCIPAL_MISMATCH",
73 "Request principal does not match session principal",
74 ));
75 }
76
77 Ok(session.auth_context.clone())
78}
79
/// Tunables for the SSE transport.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Upper bound on concurrently open SSE sessions; further connections are
    /// rejected with `503 Service Unavailable`.
    pub max_sessions: usize,
    /// Capacity of each per-session mpsc channel.
    pub channel_buffer_size: usize,
    /// Seconds between keep-alive pings on an open stream.
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    /// 10 000 sessions, 256-slot channels, and a 30 second keep-alive.
    fn default() -> Self {
        let max_sessions = 10_000;
        let channel_buffer_size = 256;
        let keepalive_interval_secs = 30;
        Self {
            max_sessions,
            channel_buffer_size,
            keepalive_interval_secs,
        }
    }
}
100
/// Query-string parameters accepted by `sse_handler`.
#[derive(Debug, Deserialize)]
pub struct SseQuery {
    /// Optional authentication token. When present, it is validated by the
    /// auth middleware and an invalid token is rejected with `401`. When
    /// absent, the session is unauthenticated.
    pub token: Option<String>,
}
107
/// Server-side state kept for one open SSE session, keyed by `SessionId` in
/// `SseState::sessions`.
struct SseSessionData {
    /// Auth context resolved when the stream was opened. Subscriptions run
    /// under this context, and later requests must come from the same principal.
    auth_context: AuthContext,
    /// Random secret (a UUID v4 string) sent to the client in the `connected`
    /// event and required on every subscribe/unsubscribe request.
    session_secret: String,
    /// Client-chosen subscription id -> reactor `SubscriptionId`, used by the
    /// unsubscribe endpoint. Only query subscriptions are recorded here; job
    /// and workflow subscriptions are not.
    subscriptions: HashMap<String, SubscriptionId>,
}
114
115#[derive(Clone)]
117pub struct SseState {
118 reactor: Arc<Reactor>,
119 auth_middleware: Arc<AuthMiddleware>,
120 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
122 config: SseConfig,
123}
124
125impl SseState {
126 pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
128 Self::with_config(reactor, auth_middleware, SseConfig::default())
129 }
130
131 pub fn with_config(
133 reactor: Arc<Reactor>,
134 auth_middleware: Arc<AuthMiddleware>,
135 config: SseConfig,
136 ) -> Self {
137 Self {
138 reactor,
139 auth_middleware,
140 sessions: Arc::new(RwLock::new(HashMap::new())),
141 config,
142 }
143 }
144
145 pub async fn can_accept_session(&self) -> bool {
147 self.sessions.read().await.len() < self.config.max_sessions
148 }
149}
150
151struct SessionCleanupGuard {
154 session_id: SessionId,
155 reactor: Arc<Reactor>,
156 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
157 dropped: bool,
158}
159
160impl SessionCleanupGuard {
161 fn new(
162 session_id: SessionId,
163 reactor: Arc<Reactor>,
164 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
165 ) -> Self {
166 Self {
167 session_id,
168 reactor,
169 sessions,
170 dropped: false,
171 }
172 }
173
174 fn mark_closed(&mut self) {
176 self.dropped = true;
177 }
178}
179
180impl Drop for SessionCleanupGuard {
181 fn drop(&mut self) {
182 if self.dropped {
183 return;
184 }
185 let session_id = self.session_id;
186 let reactor = self.reactor.clone();
187 let sessions = self.sessions.clone();
188
189 if let Ok(handle) = tokio::runtime::Handle::try_current() {
192 handle.spawn(async move {
193 reactor.remove_session(session_id).await;
194 sessions.write().await.remove(&session_id);
195 tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
196 });
197 } else {
198 tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
200 }
201 }
202}
203
/// JSON body of every event written to the SSE stream.
///
/// Serialized as an internally tagged object whose `"type"` field is the
/// snake_case variant name: `"update"`, `"error"` or `"connected"`.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SsePayload {
    /// New data for a subscription; emitted as an `update` event.
    Update {
        /// Address of the data: `sub:<subscription_id>`, `job:<client id>` or
        /// `wf:<client id>`.
        target: String,
        payload: serde_json::Value,
    },
    /// An error; emitted as an `error` event. `target` is empty for errors
    /// that are not tied to a particular subscription.
    Error {
        target: String,
        code: String,
        message: String,
    },
    /// First event on every stream; emitted as a `connected` event and carries
    /// the credentials required by the subscribe/unsubscribe endpoints.
    Connected {
        session_id: String,
        session_secret: String,
    },
}
225
/// Internal message passed from the realtime bridge task to the SSE event
/// task of one session.
///
/// Produced by `convert_realtime_to_sse`. The event task wraps it in the
/// matching [`SsePayload`] variant before writing it to the stream.
#[derive(Debug)]
pub enum SseMessage {
    /// Data for `target`; becomes an `update` event.
    Data {
        target: String,
        payload: serde_json::Value,
    },
    /// Error for `target` (possibly empty); becomes an `error` event.
    Error {
        target: String,
        code: String,
        message: String,
    },
}
239
/// Body of the query-subscription endpoint (`sse_subscribe_handler`).
#[derive(Debug, Deserialize)]
pub struct SseSubscribeRequest {
    /// Session id from the `connected` event; must be a UUID string.
    pub session_id: String,
    /// Session secret from the `connected` event.
    pub session_secret: String,
    /// Client-chosen subscription id (at most `MAX_CLIENT_SUB_ID_LEN` bytes);
    /// also used later to unsubscribe.
    pub id: String,
    /// Name of the function to subscribe to, forwarded to the reactor.
    pub function: String,
    /// Function arguments; `serde_json::Value::Null` when omitted.
    #[serde(default)]
    pub args: serde_json::Value,
}

/// Body of the unsubscribe endpoint (`sse_unsubscribe_handler`).
#[derive(Debug, Deserialize)]
pub struct SseUnsubscribeRequest {
    pub session_id: String,
    pub session_secret: String,
    /// The client-chosen id used when subscribing.
    pub id: String,
}

/// Body of the job-subscription endpoint (`sse_job_subscribe_handler`).
#[derive(Debug, Deserialize)]
pub struct SseJobSubscribeRequest {
    pub session_id: String,
    pub session_secret: String,
    /// Client-chosen id; updates arrive with target `job:<id>`.
    pub id: String,
    /// Job to watch; must be a UUID string.
    pub job_id: String,
}

/// Body of the workflow-subscription endpoint
/// (`sse_workflow_subscribe_handler`).
#[derive(Debug, Deserialize)]
pub struct SseWorkflowSubscribeRequest {
    pub session_id: String,
    pub session_secret: String,
    /// Client-chosen id; updates arrive with target `wf:<id>`.
    pub id: String,
    /// Workflow to watch; must be a UUID string.
    pub workflow_id: String,
}
276
/// Machine-readable error code plus human-readable message.
#[derive(Debug, Serialize)]
pub struct SseError {
    pub code: String,
    pub message: String,
}

/// Response body of the subscribe endpoints.
///
/// `data` and `error` are omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
pub struct SseSubscribeResponse {
    pub success: bool,
    /// Initial data returned by the reactor on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<SseError>,
}

/// Response body of the unsubscribe endpoint; `error` is omitted when `None`.
#[derive(Debug, Serialize)]
pub struct SseUnsubscribeResponse {
    pub success: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<SseError>,
}
301
302fn subscribe_error(
304 status: StatusCode,
305 code: impl Into<String>,
306 message: impl Into<String>,
307) -> (StatusCode, Json<SseSubscribeResponse>) {
308 (
309 status,
310 Json(SseSubscribeResponse {
311 success: false,
312 data: None,
313 error: Some(SseError {
314 code: code.into(),
315 message: message.into(),
316 }),
317 }),
318 )
319}
320
321fn unsubscribe_error(
323 status: StatusCode,
324 code: impl Into<String>,
325 message: impl Into<String>,
326) -> (StatusCode, Json<SseUnsubscribeResponse>) {
327 (
328 status,
329 Json(SseUnsubscribeResponse {
330 success: false,
331 error: Some(SseError {
332 code: code.into(),
333 message: message.into(),
334 }),
335 }),
336 )
337}
338
339pub async fn sse_handler(
341 State(state): State<Arc<SseState>>,
342 Query(query): Query<SseQuery>,
343) -> impl IntoResponse {
344 if !state.can_accept_session().await {
346 return (
347 StatusCode::SERVICE_UNAVAILABLE,
348 "Server at capacity".to_string(),
349 )
350 .into_response();
351 }
352
353 let session_id = SessionId::new();
354 let buffer_size = state.config.channel_buffer_size;
355 let keepalive_secs = state.config.keepalive_interval_secs;
356 let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
357 let cancel_token = CancellationToken::new();
358
359 let auth_context = if let Some(token) = &query.token {
360 match state.auth_middleware.validate_token_async(token).await {
361 Ok(claims) => super::auth::build_auth_context_from_claims(claims),
362 Err(e) => {
363 tracing::warn!("SSE token validation failed: {}", e);
364 return (
365 StatusCode::UNAUTHORIZED,
366 "Invalid authentication token".to_string(),
367 )
368 .into_response();
369 }
370 }
371 } else {
372 forge_core::function::AuthContext::unauthenticated()
373 };
374 let session_secret = uuid::Uuid::new_v4().to_string();
375
376 let reactor = state.reactor.clone();
378 let cancel = cancel_token.clone();
379
380 let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
382 reactor.register_session(session_id, rt_tx);
383
384 {
386 let mut sessions = state.sessions.write().await;
387 sessions.insert(
388 session_id,
389 SseSessionData {
390 auth_context: auth_context.clone(),
391 session_secret: session_secret.clone(),
392 subscriptions: HashMap::new(),
393 },
394 );
395 }
396
397 let sessions = state.sessions.clone();
399
400 let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
402
403 let bridge_cancel = cancel_token.clone();
405 tokio::spawn(async move {
406 loop {
407 tokio::select! {
408 msg = rt_rx.recv() => {
409 match msg {
410 Some(rt_msg) => {
411 if let Some(sse_msg) = convert_realtime_to_sse(rt_msg)
412 && tx.send(sse_msg).await.is_err() {
413 break;
414 }
415 }
416 None => break,
417 }
418 }
419 _ = bridge_cancel.cancelled() => break,
420 }
421 }
422 });
423
424 let (event_tx, event_rx) = mpsc::channel::<Result<Event, Infallible>>(buffer_size);
426
427 tokio::spawn(async move {
428 let mut _guard = cleanup_guard;
429
430 let connected = SsePayload::Connected {
432 session_id: session_id.to_string(),
433 session_secret: session_secret.clone(),
434 };
435 match serde_json::to_string(&connected) {
436 Ok(json) => {
437 let _ = event_tx
438 .send(Ok(Event::default().event("connected").data(json)))
439 .await;
440 }
441 Err(e) => {
442 tracing::error!("Failed to serialize SSE connected payload: {}", e);
443 }
444 }
445
446 loop {
447 tokio::select! {
448 msg = rx.recv() => {
449 match msg {
450 Some(sse_msg) => {
451 let event = match sse_msg {
452 SseMessage::Data { target, payload } => {
453 let data = SsePayload::Update { target, payload };
454 serde_json::to_string(&data).ok().map(|json| {
455 Event::default().event("update").data(json)
456 })
457 }
458 SseMessage::Error { target, code, message } => {
459 let data = SsePayload::Error { target, code, message };
460 serde_json::to_string(&data).ok().map(|json| {
461 Event::default().event("error").data(json)
462 })
463 }
464 };
465 if let Some(evt) = event
466 && event_tx.send(Ok(evt)).await.is_err()
467 {
468 break;
469 }
470 }
471 None => break,
472 }
473 }
474 _ = cancel.cancelled() => break,
475 }
476 }
477
478 _guard.mark_closed();
480 reactor.remove_session(session_id).await;
481 sessions.write().await.remove(&session_id);
482 });
483
484 let stream = ReceiverStream { rx: event_rx };
485
486 Sse::new(stream)
487 .keep_alive(
488 KeepAlive::new()
489 .interval(Duration::from_secs(keepalive_secs))
490 .text("ping"),
491 )
492 .into_response()
493}
494
495fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
497 match msg {
498 RealtimeMessage::Data {
499 subscription_id,
500 data,
501 } => Some(SseMessage::Data {
502 target: format!("sub:{}", subscription_id),
503 payload: data,
504 }),
505 RealtimeMessage::DeltaUpdate {
506 subscription_id,
507 delta,
508 } => match serde_json::to_value(&delta) {
509 Ok(payload) => Some(SseMessage::Data {
510 target: format!("sub:{}", subscription_id),
511 payload,
512 }),
513 Err(e) => {
514 tracing::error!("Failed to serialize delta update: {}", e);
515 Some(SseMessage::Error {
516 target: format!("sub:{}", subscription_id),
517 code: "SERIALIZATION_ERROR".to_string(),
518 message: "Failed to serialize update data".to_string(),
519 })
520 }
521 },
522 RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
523 Ok(payload) => Some(SseMessage::Data {
524 target: format!("job:{}", client_sub_id),
525 payload,
526 }),
527 Err(e) => {
528 tracing::error!("Failed to serialize job update: {}", e);
529 Some(SseMessage::Error {
530 target: format!("job:{}", client_sub_id),
531 code: "SERIALIZATION_ERROR".to_string(),
532 message: "Failed to serialize job update".to_string(),
533 })
534 }
535 },
536 RealtimeMessage::WorkflowUpdate {
537 client_sub_id,
538 workflow,
539 } => match serde_json::to_value(&workflow) {
540 Ok(payload) => Some(SseMessage::Data {
541 target: format!("wf:{}", client_sub_id),
542 payload,
543 }),
544 Err(e) => {
545 tracing::error!("Failed to serialize workflow update: {}", e);
546 Some(SseMessage::Error {
547 target: format!("wf:{}", client_sub_id),
548 code: "SERIALIZATION_ERROR".to_string(),
549 message: "Failed to serialize workflow update".to_string(),
550 })
551 }
552 },
553 RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
554 target: String::new(),
555 code,
556 message,
557 }),
558 RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
559 target: id,
560 code,
561 message,
562 }),
563 RealtimeMessage::Subscribe { .. }
565 | RealtimeMessage::Unsubscribe { .. }
566 | RealtimeMessage::Ping
567 | RealtimeMessage::Pong
568 | RealtimeMessage::AuthSuccess
569 | RealtimeMessage::AuthFailed { .. }
570 | RealtimeMessage::Lagging => None,
571 }
572}
573
574pub async fn sse_subscribe_handler(
576 State(state): State<Arc<SseState>>,
577 Extension(request_auth): Extension<AuthContext>,
578 Json(request): Json<SseSubscribeRequest>,
579) -> impl IntoResponse {
580 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
582 return subscribe_error(
583 StatusCode::BAD_REQUEST,
584 "INVALID_ID",
585 format!(
586 "Subscription ID too long (max {} chars)",
587 MAX_CLIENT_SUB_ID_LEN
588 ),
589 );
590 }
591
592 let Some(session_id) = try_parse_session_id(&request.session_id) else {
593 return subscribe_error(
594 StatusCode::BAD_REQUEST,
595 "INVALID_SESSION",
596 "Invalid session ID format",
597 );
598 };
599
600 let sessions = state.sessions.read().await;
602 let session_data = match sessions.get(&session_id) {
603 Some(data) => {
604 match authorize_session_access(data, &request.session_secret, &request_auth) {
605 Ok(auth) => auth,
606 Err(resp) => return resp,
607 }
608 }
609 None => {
610 return subscribe_error(
611 StatusCode::NOT_FOUND,
612 "SESSION_NOT_FOUND",
613 "Session not found or expired",
614 );
615 }
616 };
617 drop(sessions);
618
619 let result = state
621 .reactor
622 .subscribe(
623 session_id,
624 request.id.clone(),
625 request.function,
626 request.args,
627 session_data,
628 )
629 .await;
630
631 match result {
632 Ok((subscription_id, data)) => {
633 let mut sessions = state.sessions.write().await;
635 match sessions.get_mut(&session_id) {
636 Some(session) => {
637 session.subscriptions.insert(request.id, subscription_id);
638 }
639 None => {
640 return subscribe_error(
642 StatusCode::NOT_FOUND,
643 "SESSION_NOT_FOUND",
644 "Session expired during subscription",
645 );
646 }
647 }
648
649 tracing::debug!(
650 %session_id,
651 %subscription_id,
652 "SSE subscription registered"
653 );
654
655 (
656 StatusCode::OK,
657 Json(SseSubscribeResponse {
658 success: true,
659 data: Some(data),
660 error: None,
661 }),
662 )
663 }
664 Err(e) => {
665 tracing::warn!(%session_id, error = %e, "SSE subscription failed");
666 match e {
667 forge_core::ForgeError::Unauthorized(msg) => {
668 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
669 }
670 forge_core::ForgeError::Forbidden(msg) => {
671 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
672 }
673 forge_core::ForgeError::InvalidArgument(msg)
674 | forge_core::ForgeError::Validation(msg) => {
675 subscribe_error(StatusCode::BAD_REQUEST, "INVALID_ARGUMENT", msg)
676 }
677 forge_core::ForgeError::NotFound(msg) => {
678 subscribe_error(StatusCode::NOT_FOUND, "NOT_FOUND", msg)
679 }
680 _ => subscribe_error(
681 StatusCode::INTERNAL_SERVER_ERROR,
682 "SUBSCRIPTION_FAILED",
683 "Subscription failed",
684 ),
685 }
686 }
687 }
688}
689
690pub async fn sse_unsubscribe_handler(
692 State(state): State<Arc<SseState>>,
693 Extension(request_auth): Extension<AuthContext>,
694 Json(request): Json<SseUnsubscribeRequest>,
695) -> impl IntoResponse {
696 let Some(session_id) = try_parse_session_id(&request.session_id) else {
697 return unsubscribe_error(
698 StatusCode::BAD_REQUEST,
699 "INVALID_SESSION",
700 "Invalid session ID format",
701 );
702 };
703
704 let subscription_id = {
706 let sessions = state.sessions.read().await;
707 match sessions.get(&session_id) {
708 Some(session) => {
709 if session.session_secret != request.session_secret
710 || !same_principal(&session.auth_context, &request_auth)
711 {
712 return unsubscribe_error(
713 StatusCode::FORBIDDEN,
714 "SESSION_PRINCIPAL_MISMATCH",
715 "Request principal does not match session principal",
716 );
717 }
718 session.subscriptions.get(&request.id).copied()
719 }
720 None => None,
721 }
722 };
723
724 let Some(subscription_id) = subscription_id else {
725 return unsubscribe_error(
726 StatusCode::NOT_FOUND,
727 "SUBSCRIPTION_NOT_FOUND",
728 "Subscription not found",
729 );
730 };
731
732 state.reactor.unsubscribe(subscription_id);
734
735 {
737 let mut sessions = state.sessions.write().await;
738 if let Some(session) = sessions.get_mut(&session_id) {
739 session.subscriptions.remove(&request.id);
740 }
741 }
742
743 tracing::debug!(
744 %session_id,
745 %subscription_id,
746 "SSE subscription removed"
747 );
748
749 (
750 StatusCode::OK,
751 Json(SseUnsubscribeResponse {
752 success: true,
753 error: None,
754 }),
755 )
756}
757
758pub async fn sse_job_subscribe_handler(
760 State(state): State<Arc<SseState>>,
761 Extension(request_auth): Extension<AuthContext>,
762 Json(request): Json<SseJobSubscribeRequest>,
763) -> impl IntoResponse {
764 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
765 return subscribe_error(
766 StatusCode::BAD_REQUEST,
767 "INVALID_ID",
768 format!(
769 "Subscription ID too long (max {} chars)",
770 MAX_CLIENT_SUB_ID_LEN
771 ),
772 );
773 }
774
775 let Some(session_id) = try_parse_session_id(&request.session_id) else {
776 return subscribe_error(
777 StatusCode::BAD_REQUEST,
778 "INVALID_SESSION",
779 "Invalid session ID format",
780 );
781 };
782
783 let session_auth = {
785 let sessions = state.sessions.read().await;
786 match sessions.get(&session_id) {
787 Some(session) => {
788 match authorize_session_access(session, &request.session_secret, &request_auth) {
789 Ok(auth) => auth,
790 Err(resp) => return resp,
791 }
792 }
793 None => {
794 return subscribe_error(
795 StatusCode::NOT_FOUND,
796 "SESSION_NOT_FOUND",
797 "Session not found or expired",
798 );
799 }
800 }
801 };
802
803 let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
805 Ok(uuid) => uuid,
806 Err(_) => {
807 return subscribe_error(
808 StatusCode::BAD_REQUEST,
809 "INVALID_JOB_ID",
810 "Invalid job ID format",
811 );
812 }
813 };
814
815 match state
817 .reactor
818 .subscribe_job(session_id, request.id.clone(), job_uuid, &session_auth)
819 .await
820 {
821 Ok(job_data) => {
822 let data = match serde_json::to_value(&job_data) {
823 Ok(v) => v,
824 Err(e) => {
825 tracing::error!("Failed to serialize job data: {}", e);
826 return subscribe_error(
827 StatusCode::INTERNAL_SERVER_ERROR,
828 "SERIALIZE_ERROR",
829 "Failed to serialize job data",
830 );
831 }
832 };
833 tracing::debug!(
834 %session_id,
835 job_id = %request.job_id,
836 client_sub_id = %request.id,
837 "SSE job subscription registered"
838 );
839 (
840 StatusCode::OK,
841 Json(SseSubscribeResponse {
842 success: true,
843 data: Some(data),
844 error: None,
845 }),
846 )
847 }
848 Err(e) => match e {
849 forge_core::ForgeError::Unauthorized(msg) => {
850 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
851 }
852 forge_core::ForgeError::Forbidden(msg) => {
853 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
854 }
855 forge_core::ForgeError::NotFound(msg) => {
856 subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", msg)
857 }
858 _ => subscribe_error(
859 StatusCode::INTERNAL_SERVER_ERROR,
860 "SUBSCRIPTION_FAILED",
861 "Subscription failed",
862 ),
863 },
864 }
865}
866
867pub async fn sse_workflow_subscribe_handler(
869 State(state): State<Arc<SseState>>,
870 Extension(request_auth): Extension<AuthContext>,
871 Json(request): Json<SseWorkflowSubscribeRequest>,
872) -> impl IntoResponse {
873 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
874 return subscribe_error(
875 StatusCode::BAD_REQUEST,
876 "INVALID_ID",
877 format!(
878 "Subscription ID too long (max {} chars)",
879 MAX_CLIENT_SUB_ID_LEN
880 ),
881 );
882 }
883
884 let Some(session_id) = try_parse_session_id(&request.session_id) else {
885 return subscribe_error(
886 StatusCode::BAD_REQUEST,
887 "INVALID_SESSION",
888 "Invalid session ID format",
889 );
890 };
891
892 let session_auth = {
894 let sessions = state.sessions.read().await;
895 match sessions.get(&session_id) {
896 Some(session) => {
897 match authorize_session_access(session, &request.session_secret, &request_auth) {
898 Ok(auth) => auth,
899 Err(resp) => return resp,
900 }
901 }
902 None => {
903 return subscribe_error(
904 StatusCode::NOT_FOUND,
905 "SESSION_NOT_FOUND",
906 "Session not found or expired",
907 );
908 }
909 }
910 };
911
912 let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
914 Ok(uuid) => uuid,
915 Err(_) => {
916 return subscribe_error(
917 StatusCode::BAD_REQUEST,
918 "INVALID_WORKFLOW_ID",
919 "Invalid workflow ID format",
920 );
921 }
922 };
923
924 match state
926 .reactor
927 .subscribe_workflow(session_id, request.id.clone(), workflow_uuid, &session_auth)
928 .await
929 {
930 Ok(workflow_data) => {
931 let data = match serde_json::to_value(&workflow_data) {
932 Ok(v) => v,
933 Err(e) => {
934 tracing::error!("Failed to serialize workflow data: {}", e);
935 return subscribe_error(
936 StatusCode::INTERNAL_SERVER_ERROR,
937 "SERIALIZE_ERROR",
938 "Failed to serialize workflow data",
939 );
940 }
941 };
942 tracing::debug!(
943 %session_id,
944 workflow_id = %request.workflow_id,
945 client_sub_id = %request.id,
946 "SSE workflow subscription registered"
947 );
948 (
949 StatusCode::OK,
950 Json(SseSubscribeResponse {
951 success: true,
952 data: Some(data),
953 error: None,
954 }),
955 )
956 }
957 Err(e) => match e {
958 forge_core::ForgeError::Unauthorized(msg) => {
959 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
960 }
961 forge_core::ForgeError::Forbidden(msg) => {
962 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
963 }
964 forge_core::ForgeError::NotFound(msg) => {
965 subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", msg)
966 }
967 _ => subscribe_error(
968 StatusCode::INTERNAL_SERVER_ERROR,
969 "SUBSCRIPTION_FAILED",
970 "Subscription failed",
971 ),
972 },
973 }
974}
975
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_sse_payload_serialization() {
        let payload = SsePayload::Update {
            target: "sub:123".to_string(),
            payload: serde_json::json!({"id": 1}),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"update\""));
        assert!(json.contains("\"target\":\"sub:123\""));
    }

    #[test]
    fn test_sse_error_serialization() {
        let payload = SsePayload::Error {
            target: "sub:456".to_string(),
            code: "NOT_FOUND".to_string(),
            message: "Subscription not found".to_string(),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("NOT_FOUND"));
    }

    /// The `connected` event must expose both credentials under stable keys.
    #[test]
    fn test_sse_connected_serialization() {
        let payload = SsePayload::Connected {
            session_id: "abc".to_string(),
            session_secret: "s3cret".to_string(),
        };
        let json = serde_json::to_value(&payload).unwrap();
        assert_eq!(json["type"], "connected");
        assert_eq!(json["session_id"], "abc");
        assert_eq!(json["session_secret"], "s3cret");
    }

    #[test]
    fn test_default_config() {
        let config = SseConfig::default();
        assert_eq!(config.max_sessions, 10_000);
        assert_eq!(config.channel_buffer_size, 256);
        assert_eq!(config.keepalive_interval_secs, 30);
    }

    /// Error responses omit `data` and carry the given code.
    #[test]
    fn test_subscribe_error_shape() {
        let (status, Json(body)) =
            subscribe_error(StatusCode::NOT_FOUND, "SESSION_NOT_FOUND", "gone");
        assert_eq!(status, StatusCode::NOT_FOUND);
        let json = serde_json::to_value(&body).unwrap();
        assert_eq!(json["success"], false);
        assert!(json.get("data").is_none());
        assert_eq!(json["error"]["code"], "SESSION_NOT_FOUND");
        assert_eq!(json["error"]["message"], "gone");
    }

    #[test]
    fn test_unauthenticated_contexts_share_principal() {
        let a = AuthContext::unauthenticated();
        let b = AuthContext::unauthenticated();
        assert!(same_principal(&a, &b));
    }

    #[test]
    fn test_try_parse_session_id() {
        assert!(try_parse_session_id("not-a-uuid").is_none());
        assert!(try_parse_session_id("").is_none());
        let valid = uuid::Uuid::new_v4().to_string();
        assert!(try_parse_session_id(&valid).is_some());
    }
}