1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use axum::Json;
11use axum::extract::{Extension, Query, State};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::response::sse::{Event, KeepAlive, Sse};
15use futures_util::Stream;
16use serde::{Deserialize, Serialize};
17use tokio::sync::{RwLock, mpsc};
18use tokio_util::sync::CancellationToken;
19
20struct ReceiverStream<T> {
22 rx: mpsc::Receiver<T>,
23}
24
25impl<T> Stream for ReceiverStream<T> {
26 type Item = T;
27 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
28 self.rx.poll_recv(cx)
29 }
30}
31
32use forge_core::function::AuthContext;
33use forge_core::realtime::{SessionId, SubscriptionId};
34
35use super::auth::AuthMiddleware;
36use crate::realtime::Reactor;
37use crate::realtime::RealtimeMessage;
38
/// Upper bound on the length of a client-supplied subscription `id` in subscribe requests;
/// longer ids are rejected with `INVALID_ID`.
///
/// NOTE(review): the handlers compare against `String::len()`, i.e. UTF-8 bytes, even though the
/// error message they return speaks of "chars".
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
41
42fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
43 uuid::Uuid::parse_str(session_id)
44 .ok()
45 .map(SessionId::from_uuid)
46}
47
48fn same_principal(a: &AuthContext, b: &AuthContext) -> bool {
49 match (a.is_authenticated(), b.is_authenticated()) {
50 (false, false) => true,
51 (true, true) => a.principal_id().is_some() && a.principal_id() == b.principal_id(),
52 _ => false,
53 }
54}
55
56fn resolve_sse_auth_context(
57 request_auth: &AuthContext,
58 query_auth: Option<AuthContext>,
59) -> AuthContext {
60 query_auth.unwrap_or_else(|| request_auth.clone())
61}
62
63fn authorize_session_access(
64 session: &SseSessionData,
65 session_secret: &str,
66 requester_auth: &AuthContext,
67) -> Result<AuthContext, (StatusCode, Json<SseSubscribeResponse>)> {
68 if session.session_secret != session_secret {
69 return Err(subscribe_error(
70 StatusCode::UNAUTHORIZED,
71 "INVALID_SESSION_SECRET",
72 "Session secret mismatch",
73 ));
74 }
75
76 if !same_principal(&session.auth_context, requester_auth) {
77 return Err(subscribe_error(
78 StatusCode::FORBIDDEN,
79 "SESSION_PRINCIPAL_MISMATCH",
80 "Request principal does not match session principal",
81 ));
82 }
83
84 Ok(session.auth_context.clone())
85}
86
/// Tunables for the SSE transport.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Maximum number of concurrently open SSE sessions; further connects get `503`.
    pub max_sessions: usize,
    /// Capacity of the per-session message channels created by `sse_handler`.
    pub channel_buffer_size: usize,
    /// Interval, in seconds, between keep-alive `ping` comments on idle streams.
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    /// 10 000 sessions, 256-message channels, 30 s keep-alive.
    fn default() -> Self {
        SseConfig {
            keepalive_interval_secs: 30,
            channel_buffer_size: 256,
            max_sessions: 10_000,
        }
    }
}
107
/// Query-string parameters accepted by `sse_handler`.
#[derive(Debug, Deserialize)]
pub struct SseQuery {
    /// Optional auth token. When present it is validated by the auth middleware and, if valid,
    /// its principal is used for the session instead of the request's `AuthContext` extension.
    pub token: Option<String>,
}
114
/// Server-side record of one open SSE session, keyed by `SessionId` in `SseState::sessions`.
struct SseSessionData {
    /// Principal the session was opened as; later subscribe/unsubscribe requests must come from
    /// the same principal (see `same_principal`).
    auth_context: AuthContext,
    /// Random secret sent to the client in the `connected` event; required on every
    /// subscribe/unsubscribe request for this session.
    session_secret: String,
    /// Client-chosen subscription id -> reactor `SubscriptionId`, for query subscriptions made
    /// through `sse_subscribe_handler` (job/workflow subscriptions are not recorded here).
    subscriptions: HashMap<String, SubscriptionId>,
}
121
122#[derive(Clone)]
124pub struct SseState {
125 reactor: Arc<Reactor>,
126 auth_middleware: Arc<AuthMiddleware>,
127 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
129 config: SseConfig,
130}
131
132impl SseState {
133 pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
135 Self::with_config(reactor, auth_middleware, SseConfig::default())
136 }
137
138 pub fn with_config(
140 reactor: Arc<Reactor>,
141 auth_middleware: Arc<AuthMiddleware>,
142 config: SseConfig,
143 ) -> Self {
144 Self {
145 reactor,
146 auth_middleware,
147 sessions: Arc::new(RwLock::new(HashMap::new())),
148 config,
149 }
150 }
151
152 pub async fn can_accept_session(&self) -> bool {
154 self.sessions.read().await.len() < self.config.max_sessions
155 }
156}
157
158struct SessionCleanupGuard {
161 session_id: SessionId,
162 reactor: Arc<Reactor>,
163 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
164 dropped: bool,
165}
166
167impl SessionCleanupGuard {
168 fn new(
169 session_id: SessionId,
170 reactor: Arc<Reactor>,
171 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
172 ) -> Self {
173 Self {
174 session_id,
175 reactor,
176 sessions,
177 dropped: false,
178 }
179 }
180
181 fn mark_closed(&mut self) {
183 self.dropped = true;
184 }
185}
186
187impl Drop for SessionCleanupGuard {
188 fn drop(&mut self) {
189 if self.dropped {
190 return;
191 }
192 let session_id = self.session_id;
193 let reactor = self.reactor.clone();
194 let sessions = self.sessions.clone();
195
196 crate::observability::set_active_connections("sse", -1);
199 if let Ok(handle) = tokio::runtime::Handle::try_current() {
200 handle.spawn(async move {
201 reactor.remove_session(session_id).await;
202 sessions.write().await.remove(&session_id);
203 tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
204 });
205 } else {
206 tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
208 }
209 }
210}
211
/// JSON body of the events written to the SSE stream.
///
/// Serialized with an internal `type` tag in snake_case, e.g.
/// `{"type":"update","target":"sub:…","payload":…}`.
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SsePayload {
    /// New data for a subscription; emitted as an `update` event.
    Update {
        /// Addressed subscription, prefixed `sub:`, `job:` or `wf:` by `convert_realtime_to_sse`.
        target: String,
        payload: serde_json::Value,
    },
    /// Error report; emitted as an `error` event. `target` is empty for errors that are not
    /// tied to a particular subscription.
    Error {
        target: String,
        code: String,
        message: String,
    },
    /// First event on every stream; carries the credentials the subscribe endpoints require.
    Connected {
        session_id: String,
        session_secret: String,
    },
}
233
/// Internal message passed from the reactor bridge task to the SSE event task, which turns it
/// into an [`SsePayload`] (`Data` -> `Update`, `Error` -> `Error`) before writing it out.
#[derive(Debug)]
pub enum SseMessage {
    /// Payload for the subscription identified by `target`.
    Data {
        target: String,
        payload: serde_json::Value,
    },
    /// Error for `target` (empty when not tied to a subscription).
    Error {
        target: String,
        code: String,
        message: String,
    },
}
247
/// Body of the query-subscription endpoint (`sse_subscribe_handler`).
#[derive(Debug, Deserialize)]
pub struct SseSubscribeRequest {
    /// Session id from the `connected` event; must parse as a UUID.
    pub session_id: String,
    /// Session secret from the `connected` event.
    pub session_secret: String,
    /// Client-chosen subscription id (at most `MAX_CLIENT_SUB_ID_LEN` bytes); later used to
    /// unsubscribe.
    pub id: String,
    /// Function name forwarded to `Reactor::subscribe` (presumably a query name — confirm).
    pub function: String,
    /// Arguments forwarded with `function`; JSON `null` when omitted.
    #[serde(default)]
    pub args: serde_json::Value,
}
258
/// Body of the unsubscribe endpoint (`sse_unsubscribe_handler`).
#[derive(Debug, Deserialize)]
pub struct SseUnsubscribeRequest {
    /// Session id from the `connected` event; must parse as a UUID.
    pub session_id: String,
    /// Session secret from the `connected` event.
    pub session_secret: String,
    /// Client subscription id previously passed to the subscribe endpoint.
    pub id: String,
}
266
/// Body of the job-subscription endpoint (`sse_job_subscribe_handler`).
#[derive(Debug, Deserialize)]
pub struct SseJobSubscribeRequest {
    /// Session id from the `connected` event; must parse as a UUID.
    pub session_id: String,
    /// Session secret from the `connected` event.
    pub session_secret: String,
    /// Client subscription id; updates arrive with target `job:<id>`.
    pub id: String,
    /// Job to watch; must parse as a UUID.
    pub job_id: String,
}
275
/// Body of the workflow-subscription endpoint (`sse_workflow_subscribe_handler`).
#[derive(Debug, Deserialize)]
pub struct SseWorkflowSubscribeRequest {
    /// Session id from the `connected` event; must parse as a UUID.
    pub session_id: String,
    /// Session secret from the `connected` event.
    pub session_secret: String,
    /// Client subscription id; updates arrive with target `wf:<id>`.
    pub id: String,
    /// Workflow to watch; must parse as a UUID.
    pub workflow_id: String,
}
284
/// Error detail embedded in subscribe/unsubscribe responses.
#[derive(Debug, Serialize)]
pub struct SseError {
    /// Machine-readable code, e.g. `SESSION_NOT_FOUND`.
    pub code: String,
    /// Human-readable description.
    pub message: String,
}
291
/// Response body of the subscribe endpoints.
#[derive(Debug, Serialize)]
pub struct SseSubscribeResponse {
    pub success: bool,
    /// Initial data returned on success; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
    /// Failure detail; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<SseError>,
}
301
/// Response body of the unsubscribe endpoint.
#[derive(Debug, Serialize)]
pub struct SseUnsubscribeResponse {
    pub success: bool,
    /// Failure detail; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<SseError>,
}
309
310fn subscribe_error(
312 status: StatusCode,
313 code: impl Into<String>,
314 message: impl Into<String>,
315) -> (StatusCode, Json<SseSubscribeResponse>) {
316 (
317 status,
318 Json(SseSubscribeResponse {
319 success: false,
320 data: None,
321 error: Some(SseError {
322 code: code.into(),
323 message: message.into(),
324 }),
325 }),
326 )
327}
328
329fn unsubscribe_error(
331 status: StatusCode,
332 code: impl Into<String>,
333 message: impl Into<String>,
334) -> (StatusCode, Json<SseUnsubscribeResponse>) {
335 (
336 status,
337 Json(SseUnsubscribeResponse {
338 success: false,
339 error: Some(SseError {
340 code: code.into(),
341 message: message.into(),
342 }),
343 }),
344 )
345}
346
347pub async fn sse_handler(
349 State(state): State<Arc<SseState>>,
350 Extension(request_auth): Extension<AuthContext>,
351 Query(query): Query<SseQuery>,
352) -> impl IntoResponse {
353 if !state.can_accept_session().await {
355 return (
356 StatusCode::SERVICE_UNAVAILABLE,
357 "Server at capacity".to_string(),
358 )
359 .into_response();
360 }
361
362 let session_id = SessionId::new();
363 let buffer_size = state.config.channel_buffer_size;
364 let keepalive_secs = state.config.keepalive_interval_secs;
365 let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
366 let cancel_token = CancellationToken::new();
367
368 let query_auth = if let Some(token) = &query.token {
369 match state.auth_middleware.validate_token_async(token).await {
370 Ok(claims) => Some(super::auth::build_auth_context_from_claims(claims)),
371 Err(e) => {
372 tracing::warn!("SSE token validation failed: {}", e);
373 return (
374 StatusCode::UNAUTHORIZED,
375 "Invalid authentication token".to_string(),
376 )
377 .into_response();
378 }
379 }
380 } else {
381 None
382 };
383 let auth_context = resolve_sse_auth_context(&request_auth, query_auth);
384 let session_secret = uuid::Uuid::new_v4().to_string();
385
386 let reactor = state.reactor.clone();
388 let cancel = cancel_token.clone();
389
390 let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
392 reactor.register_session(session_id, rt_tx);
393
394 {
396 let mut sessions = state.sessions.write().await;
397 sessions.insert(
398 session_id,
399 SseSessionData {
400 auth_context: auth_context.clone(),
401 session_secret: session_secret.clone(),
402 subscriptions: HashMap::new(),
403 },
404 );
405 }
406
407 let sessions = state.sessions.clone();
409
410 let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
412 crate::observability::set_active_connections("sse", 1);
413
414 let bridge_cancel = cancel_token.clone();
416 tokio::spawn(async move {
417 loop {
418 tokio::select! {
419 msg = rt_rx.recv() => {
420 match msg {
421 Some(rt_msg) => {
422 if let Some(sse_msg) = convert_realtime_to_sse(rt_msg)
423 && tx.send(sse_msg).await.is_err() {
424 break;
425 }
426 }
427 None => break,
428 }
429 }
430 _ = bridge_cancel.cancelled() => break,
431 }
432 }
433 });
434
435 let (event_tx, event_rx) = mpsc::channel::<Result<Event, Infallible>>(buffer_size);
437
438 tokio::spawn(async move {
439 let mut _guard = cleanup_guard;
440
441 let connected = SsePayload::Connected {
443 session_id: session_id.to_string(),
444 session_secret: session_secret.clone(),
445 };
446 match serde_json::to_string(&connected) {
447 Ok(json) => {
448 let _ = event_tx
449 .send(Ok(Event::default().event("connected").data(json)))
450 .await;
451 }
452 Err(e) => {
453 tracing::error!("Failed to serialize SSE connected payload: {}", e);
454 }
455 }
456
457 loop {
458 tokio::select! {
459 msg = rx.recv() => {
460 match msg {
461 Some(sse_msg) => {
462 let event = match sse_msg {
463 SseMessage::Data { target, payload } => {
464 let data = SsePayload::Update { target, payload };
465 serde_json::to_string(&data).ok().map(|json| {
466 Event::default().event("update").data(json)
467 })
468 }
469 SseMessage::Error { target, code, message } => {
470 let data = SsePayload::Error { target, code, message };
471 serde_json::to_string(&data).ok().map(|json| {
472 Event::default().event("error").data(json)
473 })
474 }
475 };
476 if let Some(evt) = event
477 && event_tx.send(Ok(evt)).await.is_err()
478 {
479 break;
480 }
481 }
482 None => break,
483 }
484 }
485 _ = cancel.cancelled() => break,
486 }
487 }
488
489 _guard.mark_closed();
491 crate::observability::set_active_connections("sse", -1);
492 reactor.remove_session(session_id).await;
493 sessions.write().await.remove(&session_id);
494 });
495
496 let stream = ReceiverStream { rx: event_rx };
497
498 Sse::new(stream)
499 .keep_alive(
500 KeepAlive::new()
501 .interval(Duration::from_secs(keepalive_secs))
502 .text("ping"),
503 )
504 .into_response()
505}
506
507fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
509 match msg {
510 RealtimeMessage::Data {
511 subscription_id,
512 data,
513 } => Some(SseMessage::Data {
514 target: format!("sub:{}", subscription_id),
515 payload: data,
516 }),
517 RealtimeMessage::DeltaUpdate {
518 subscription_id,
519 delta,
520 } => match serde_json::to_value(&delta) {
521 Ok(payload) => Some(SseMessage::Data {
522 target: format!("sub:{}", subscription_id),
523 payload,
524 }),
525 Err(e) => {
526 tracing::error!("Failed to serialize delta update: {}", e);
527 Some(SseMessage::Error {
528 target: format!("sub:{}", subscription_id),
529 code: "SERIALIZATION_ERROR".to_string(),
530 message: "Failed to serialize update data".to_string(),
531 })
532 }
533 },
534 RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
535 Ok(payload) => Some(SseMessage::Data {
536 target: format!("job:{}", client_sub_id),
537 payload,
538 }),
539 Err(e) => {
540 tracing::error!("Failed to serialize job update: {}", e);
541 Some(SseMessage::Error {
542 target: format!("job:{}", client_sub_id),
543 code: "SERIALIZATION_ERROR".to_string(),
544 message: "Failed to serialize job update".to_string(),
545 })
546 }
547 },
548 RealtimeMessage::WorkflowUpdate {
549 client_sub_id,
550 workflow,
551 } => match serde_json::to_value(&workflow) {
552 Ok(payload) => Some(SseMessage::Data {
553 target: format!("wf:{}", client_sub_id),
554 payload,
555 }),
556 Err(e) => {
557 tracing::error!("Failed to serialize workflow update: {}", e);
558 Some(SseMessage::Error {
559 target: format!("wf:{}", client_sub_id),
560 code: "SERIALIZATION_ERROR".to_string(),
561 message: "Failed to serialize workflow update".to_string(),
562 })
563 }
564 },
565 RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
566 target: String::new(),
567 code,
568 message,
569 }),
570 RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
571 target: id,
572 code,
573 message,
574 }),
575 RealtimeMessage::Subscribe { .. }
577 | RealtimeMessage::Unsubscribe { .. }
578 | RealtimeMessage::Ping
579 | RealtimeMessage::Pong
580 | RealtimeMessage::AuthSuccess
581 | RealtimeMessage::AuthFailed { .. }
582 | RealtimeMessage::Lagging => None,
583 }
584}
585
586pub async fn sse_subscribe_handler(
588 State(state): State<Arc<SseState>>,
589 Extension(request_auth): Extension<AuthContext>,
590 Json(request): Json<SseSubscribeRequest>,
591) -> impl IntoResponse {
592 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
594 return subscribe_error(
595 StatusCode::BAD_REQUEST,
596 "INVALID_ID",
597 format!(
598 "Subscription ID too long (max {} chars)",
599 MAX_CLIENT_SUB_ID_LEN
600 ),
601 );
602 }
603
604 let Some(session_id) = try_parse_session_id(&request.session_id) else {
605 return subscribe_error(
606 StatusCode::BAD_REQUEST,
607 "INVALID_SESSION",
608 "Invalid session ID format",
609 );
610 };
611
612 let sessions = state.sessions.read().await;
614 let session_data = match sessions.get(&session_id) {
615 Some(data) => {
616 match authorize_session_access(data, &request.session_secret, &request_auth) {
617 Ok(auth) => auth,
618 Err(resp) => return resp,
619 }
620 }
621 None => {
622 return subscribe_error(
623 StatusCode::NOT_FOUND,
624 "SESSION_NOT_FOUND",
625 "Session not found or expired",
626 );
627 }
628 };
629 drop(sessions);
630
631 let result = state
633 .reactor
634 .subscribe(
635 session_id,
636 request.id.clone(),
637 request.function,
638 request.args,
639 session_data,
640 )
641 .await;
642
643 match result {
644 Ok((subscription_id, data)) => {
645 let mut sessions = state.sessions.write().await;
647 match sessions.get_mut(&session_id) {
648 Some(session) => {
649 session.subscriptions.insert(request.id, subscription_id);
650 }
651 None => {
652 return subscribe_error(
654 StatusCode::NOT_FOUND,
655 "SESSION_NOT_FOUND",
656 "Session expired during subscription",
657 );
658 }
659 }
660
661 tracing::debug!(
662 %session_id,
663 %subscription_id,
664 "SSE subscription registered"
665 );
666
667 (
668 StatusCode::OK,
669 Json(SseSubscribeResponse {
670 success: true,
671 data: Some(data),
672 error: None,
673 }),
674 )
675 }
676 Err(e) => {
677 tracing::warn!(%session_id, error = %e, "SSE subscription failed");
678 match e {
679 forge_core::ForgeError::Unauthorized(msg) => {
680 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
681 }
682 forge_core::ForgeError::Forbidden(msg) => {
683 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
684 }
685 forge_core::ForgeError::InvalidArgument(msg)
686 | forge_core::ForgeError::Validation(msg) => {
687 subscribe_error(StatusCode::BAD_REQUEST, "INVALID_ARGUMENT", msg)
688 }
689 forge_core::ForgeError::NotFound(msg) => {
690 subscribe_error(StatusCode::NOT_FOUND, "NOT_FOUND", msg)
691 }
692 _ => subscribe_error(
693 StatusCode::INTERNAL_SERVER_ERROR,
694 "SUBSCRIPTION_FAILED",
695 "Subscription failed",
696 ),
697 }
698 }
699 }
700}
701
702pub async fn sse_unsubscribe_handler(
704 State(state): State<Arc<SseState>>,
705 Extension(request_auth): Extension<AuthContext>,
706 Json(request): Json<SseUnsubscribeRequest>,
707) -> impl IntoResponse {
708 let Some(session_id) = try_parse_session_id(&request.session_id) else {
709 return unsubscribe_error(
710 StatusCode::BAD_REQUEST,
711 "INVALID_SESSION",
712 "Invalid session ID format",
713 );
714 };
715
716 let subscription_id = {
718 let sessions = state.sessions.read().await;
719 match sessions.get(&session_id) {
720 Some(session) => {
721 if session.session_secret != request.session_secret
722 || !same_principal(&session.auth_context, &request_auth)
723 {
724 return unsubscribe_error(
725 StatusCode::FORBIDDEN,
726 "SESSION_PRINCIPAL_MISMATCH",
727 "Request principal does not match session principal",
728 );
729 }
730 session.subscriptions.get(&request.id).copied()
731 }
732 None => None,
733 }
734 };
735
736 let Some(subscription_id) = subscription_id else {
737 return unsubscribe_error(
738 StatusCode::NOT_FOUND,
739 "SUBSCRIPTION_NOT_FOUND",
740 "Subscription not found",
741 );
742 };
743
744 state.reactor.unsubscribe(subscription_id);
746
747 {
749 let mut sessions = state.sessions.write().await;
750 if let Some(session) = sessions.get_mut(&session_id) {
751 session.subscriptions.remove(&request.id);
752 }
753 }
754
755 tracing::debug!(
756 %session_id,
757 %subscription_id,
758 "SSE subscription removed"
759 );
760
761 (
762 StatusCode::OK,
763 Json(SseUnsubscribeResponse {
764 success: true,
765 error: None,
766 }),
767 )
768}
769
770pub async fn sse_job_subscribe_handler(
772 State(state): State<Arc<SseState>>,
773 Extension(request_auth): Extension<AuthContext>,
774 Json(request): Json<SseJobSubscribeRequest>,
775) -> impl IntoResponse {
776 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
777 return subscribe_error(
778 StatusCode::BAD_REQUEST,
779 "INVALID_ID",
780 format!(
781 "Subscription ID too long (max {} chars)",
782 MAX_CLIENT_SUB_ID_LEN
783 ),
784 );
785 }
786
787 let Some(session_id) = try_parse_session_id(&request.session_id) else {
788 return subscribe_error(
789 StatusCode::BAD_REQUEST,
790 "INVALID_SESSION",
791 "Invalid session ID format",
792 );
793 };
794
795 let session_auth = {
797 let sessions = state.sessions.read().await;
798 match sessions.get(&session_id) {
799 Some(session) => {
800 match authorize_session_access(session, &request.session_secret, &request_auth) {
801 Ok(auth) => auth,
802 Err(resp) => return resp,
803 }
804 }
805 None => {
806 return subscribe_error(
807 StatusCode::NOT_FOUND,
808 "SESSION_NOT_FOUND",
809 "Session not found or expired",
810 );
811 }
812 }
813 };
814
815 let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
817 Ok(uuid) => uuid,
818 Err(_) => {
819 return subscribe_error(
820 StatusCode::BAD_REQUEST,
821 "INVALID_JOB_ID",
822 "Invalid job ID format",
823 );
824 }
825 };
826
827 match state
829 .reactor
830 .subscribe_job(session_id, request.id.clone(), job_uuid, &session_auth)
831 .await
832 {
833 Ok(job_data) => {
834 let data = match serde_json::to_value(&job_data) {
835 Ok(v) => v,
836 Err(e) => {
837 tracing::error!("Failed to serialize job data: {}", e);
838 return subscribe_error(
839 StatusCode::INTERNAL_SERVER_ERROR,
840 "SERIALIZE_ERROR",
841 "Failed to serialize job data",
842 );
843 }
844 };
845 tracing::debug!(
846 %session_id,
847 job_id = %request.job_id,
848 client_sub_id = %request.id,
849 "SSE job subscription registered"
850 );
851 (
852 StatusCode::OK,
853 Json(SseSubscribeResponse {
854 success: true,
855 data: Some(data),
856 error: None,
857 }),
858 )
859 }
860 Err(e) => match e {
861 forge_core::ForgeError::Unauthorized(msg) => {
862 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
863 }
864 forge_core::ForgeError::Forbidden(msg) => {
865 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
866 }
867 forge_core::ForgeError::NotFound(msg) => {
868 subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", msg)
869 }
870 _ => subscribe_error(
871 StatusCode::INTERNAL_SERVER_ERROR,
872 "SUBSCRIPTION_FAILED",
873 "Subscription failed",
874 ),
875 },
876 }
877}
878
879pub async fn sse_workflow_subscribe_handler(
881 State(state): State<Arc<SseState>>,
882 Extension(request_auth): Extension<AuthContext>,
883 Json(request): Json<SseWorkflowSubscribeRequest>,
884) -> impl IntoResponse {
885 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
886 return subscribe_error(
887 StatusCode::BAD_REQUEST,
888 "INVALID_ID",
889 format!(
890 "Subscription ID too long (max {} chars)",
891 MAX_CLIENT_SUB_ID_LEN
892 ),
893 );
894 }
895
896 let Some(session_id) = try_parse_session_id(&request.session_id) else {
897 return subscribe_error(
898 StatusCode::BAD_REQUEST,
899 "INVALID_SESSION",
900 "Invalid session ID format",
901 );
902 };
903
904 let session_auth = {
906 let sessions = state.sessions.read().await;
907 match sessions.get(&session_id) {
908 Some(session) => {
909 match authorize_session_access(session, &request.session_secret, &request_auth) {
910 Ok(auth) => auth,
911 Err(resp) => return resp,
912 }
913 }
914 None => {
915 return subscribe_error(
916 StatusCode::NOT_FOUND,
917 "SESSION_NOT_FOUND",
918 "Session not found or expired",
919 );
920 }
921 }
922 };
923
924 let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
926 Ok(uuid) => uuid,
927 Err(_) => {
928 return subscribe_error(
929 StatusCode::BAD_REQUEST,
930 "INVALID_WORKFLOW_ID",
931 "Invalid workflow ID format",
932 );
933 }
934 };
935
936 match state
938 .reactor
939 .subscribe_workflow(session_id, request.id.clone(), workflow_uuid, &session_auth)
940 .await
941 {
942 Ok(workflow_data) => {
943 let data = match serde_json::to_value(&workflow_data) {
944 Ok(v) => v,
945 Err(e) => {
946 tracing::error!("Failed to serialize workflow data: {}", e);
947 return subscribe_error(
948 StatusCode::INTERNAL_SERVER_ERROR,
949 "SERIALIZE_ERROR",
950 "Failed to serialize workflow data",
951 );
952 }
953 };
954 tracing::debug!(
955 %session_id,
956 workflow_id = %request.workflow_id,
957 client_sub_id = %request.id,
958 "SSE workflow subscription registered"
959 );
960 (
961 StatusCode::OK,
962 Json(SseSubscribeResponse {
963 success: true,
964 data: Some(data),
965 error: None,
966 }),
967 )
968 }
969 Err(e) => match e {
970 forge_core::ForgeError::Unauthorized(msg) => {
971 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
972 }
973 forge_core::ForgeError::Forbidden(msg) => {
974 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
975 }
976 forge_core::ForgeError::NotFound(msg) => {
977 subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", msg)
978 }
979 _ => subscribe_error(
980 StatusCode::INTERNAL_SERVER_ERROR,
981 "SUBSCRIPTION_FAILED",
982 "Subscription failed",
983 ),
984 },
985 }
986}
987
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use uuid::Uuid;

    #[test]
    fn test_sse_payload_serialization() {
        let payload = SsePayload::Update {
            target: "sub:123".to_string(),
            payload: serde_json::json!({"id": 1}),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"update\""));
        assert!(json.contains("\"target\":\"sub:123\""));
    }

    #[test]
    fn test_sse_error_serialization() {
        let payload = SsePayload::Error {
            target: "sub:456".to_string(),
            code: "NOT_FOUND".to_string(),
            message: "Subscription not found".to_string(),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("NOT_FOUND"));
    }

    #[test]
    fn resolve_sse_auth_context_prefers_request_auth_when_query_token_absent() {
        let request_auth =
            AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new());

        let resolved = resolve_sse_auth_context(&request_auth, None);

        assert!(resolved.is_authenticated());
        assert_eq!(resolved.principal_id(), request_auth.principal_id());
    }

    #[test]
    fn resolve_sse_auth_context_prefers_query_token_when_present() {
        let request_auth =
            AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new());
        let query_auth =
            AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new());

        let resolved = resolve_sse_auth_context(&request_auth, Some(query_auth.clone()));

        assert_eq!(resolved.principal_id(), query_auth.principal_id());
    }

    #[test]
    fn sse_config_default_values() {
        let config = SseConfig::default();
        assert_eq!(config.max_sessions, 10_000);
        assert_eq!(config.channel_buffer_size, 256);
        assert_eq!(config.keepalive_interval_secs, 30);
    }

    #[test]
    fn same_principal_rejects_distinct_authenticated_users() {
        let alice =
            AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new());
        let bob =
            AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new());

        assert!(!same_principal(&alice, &bob));
    }

    #[test]
    fn convert_realtime_to_sse_skips_control_messages() {
        assert!(convert_realtime_to_sse(RealtimeMessage::Ping).is_none());
        assert!(convert_realtime_to_sse(RealtimeMessage::Pong).is_none());
        assert!(convert_realtime_to_sse(RealtimeMessage::Lagging).is_none());
    }

    #[test]
    fn convert_realtime_to_sse_targets_error_with_id() {
        let msg = RealtimeMessage::ErrorWithId {
            id: "sub-1".to_string(),
            code: "BAD".to_string(),
            message: "boom".to_string(),
        };

        match convert_realtime_to_sse(msg) {
            Some(SseMessage::Error {
                target,
                code,
                message,
            }) => {
                assert_eq!(target, "sub-1");
                assert_eq!(code, "BAD");
                assert_eq!(message, "boom");
            }
            other => panic!("unexpected conversion: {:?}", other),
        }
    }
}
1041}