1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use axum::Json;
11use axum::extract::{Extension, Query, State};
12use axum::http::StatusCode;
13use axum::response::IntoResponse;
14use axum::response::sse::{Event, KeepAlive, Sse};
15use futures_util::Stream;
16use serde::{Deserialize, Serialize};
17use tokio::sync::{RwLock, mpsc};
18use tokio_util::sync::CancellationToken;
19
20struct ReceiverStream<T> {
22 rx: mpsc::Receiver<T>,
23}
24
25impl<T> Stream for ReceiverStream<T> {
26 type Item = T;
27 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
28 self.rx.poll_recv(cx)
29 }
30}
31
32use forge_core::function::AuthContext;
33use forge_core::realtime::{SessionId, SubscriptionId};
34
35use super::auth::AuthMiddleware;
36use crate::realtime::Reactor;
37use crate::realtime::RealtimeMessage;
38
/// Longest client-chosen subscription id the subscribe endpoints accept.
///
/// Compared against `String::len`, so the unit is bytes.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;

/// Number of tracked subscriptions at which a session stops accepting new ones.
const MAX_SUBSCRIPTIONS_PER_SESSION: usize = 100;
44
45fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
46 uuid::Uuid::parse_str(session_id)
47 .ok()
48 .map(SessionId::from_uuid)
49}
50
51fn same_principal(a: &AuthContext, b: &AuthContext) -> bool {
52 match (a.is_authenticated(), b.is_authenticated()) {
53 (false, false) => true,
54 (true, true) => a.principal_id().is_some() && a.principal_id() == b.principal_id(),
55 _ => false,
56 }
57}
58
59fn resolve_sse_auth_context(
60 request_auth: &AuthContext,
61 query_auth: Option<AuthContext>,
62) -> AuthContext {
63 query_auth.unwrap_or_else(|| request_auth.clone())
64}
65
66fn authorize_session_access(
67 session: &SseSessionData,
68 session_secret: &str,
69 requester_auth: &AuthContext,
70) -> Result<AuthContext, (StatusCode, Json<SseSubscribeResponse>)> {
71 if session.session_secret != session_secret {
72 return Err(subscribe_error(
73 StatusCode::UNAUTHORIZED,
74 "INVALID_SESSION_SECRET",
75 "Session secret mismatch",
76 ));
77 }
78
79 if !same_principal(&session.auth_context, requester_auth) {
80 return Err(subscribe_error(
81 StatusCode::FORBIDDEN,
82 "SESSION_PRINCIPAL_MISMATCH",
83 "Request principal does not match session principal",
84 ));
85 }
86
87 Ok(session.auth_context.clone())
88}
89
/// Tunables for the SSE endpoint.
#[derive(Clone, Debug)]
pub struct SseConfig {
    /// New sessions are refused once this many are open.
    pub max_sessions: usize,
    /// Capacity of the per-session message channels.
    pub channel_buffer_size: usize,
    /// Seconds between keep-alive messages on an open stream.
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    /// 10,000 sessions, 256-slot channels and a 30 second keep-alive.
    fn default() -> Self {
        SseConfig {
            max_sessions: 10_000,
            channel_buffer_size: 256,
            keepalive_interval_secs: 30,
        }
    }
}
110
111#[derive(Debug, Deserialize)]
113pub struct SseQuery {
114 pub token: Option<String>,
116}
117
118struct SseSessionData {
119 auth_context: AuthContext,
120 session_secret: String,
121 subscriptions: HashMap<String, SubscriptionId>,
123}
124
125#[derive(Clone)]
127pub struct SseState {
128 reactor: Arc<Reactor>,
129 auth_middleware: Arc<AuthMiddleware>,
130 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
132 config: SseConfig,
133}
134
135impl SseState {
136 pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
138 Self::with_config(reactor, auth_middleware, SseConfig::default())
139 }
140
141 pub fn with_config(
143 reactor: Arc<Reactor>,
144 auth_middleware: Arc<AuthMiddleware>,
145 config: SseConfig,
146 ) -> Self {
147 Self {
148 reactor,
149 auth_middleware,
150 sessions: Arc::new(RwLock::new(HashMap::new())),
151 config,
152 }
153 }
154
155 pub async fn can_accept_session(&self) -> bool {
157 self.sessions.read().await.len() < self.config.max_sessions
158 }
159}
160
161struct SessionCleanupGuard {
164 session_id: SessionId,
165 reactor: Arc<Reactor>,
166 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
167 dropped: bool,
168}
169
170impl SessionCleanupGuard {
171 fn new(
172 session_id: SessionId,
173 reactor: Arc<Reactor>,
174 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
175 ) -> Self {
176 Self {
177 session_id,
178 reactor,
179 sessions,
180 dropped: false,
181 }
182 }
183
184 fn mark_closed(&mut self) {
186 self.dropped = true;
187 }
188}
189
190impl Drop for SessionCleanupGuard {
191 fn drop(&mut self) {
192 if self.dropped {
193 return;
194 }
195 let session_id = self.session_id;
196 let reactor = self.reactor.clone();
197 let sessions = self.sessions.clone();
198
199 crate::observability::set_active_connections("sse", -1);
202 if let Ok(handle) = tokio::runtime::Handle::try_current() {
203 handle.spawn(async move {
204 reactor.remove_session(session_id).await;
205 sessions.write().await.remove(&session_id);
206 tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
207 });
208 } else {
209 tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
211 }
212 }
213}
214
215#[derive(Debug, Serialize)]
217#[serde(tag = "type", rename_all = "snake_case")]
218pub enum SsePayload {
219 Update {
221 target: String,
222 payload: serde_json::Value,
223 },
224 Error {
226 target: String,
227 code: String,
228 message: String,
229 },
230 Connected {
232 session_id: String,
233 session_secret: String,
234 },
235}
236
237#[derive(Debug)]
239pub enum SseMessage {
240 Data {
241 target: String,
242 payload: serde_json::Value,
243 },
244 Error {
245 target: String,
246 code: String,
247 message: String,
248 },
249}
250
251#[derive(Debug, Deserialize)]
253pub struct SseSubscribeRequest {
254 pub session_id: String,
255 pub session_secret: String,
256 pub id: String,
257 pub function: String,
258 #[serde(default)]
259 pub args: serde_json::Value,
260}
261
262#[derive(Debug, Deserialize)]
264pub struct SseUnsubscribeRequest {
265 pub session_id: String,
266 pub session_secret: String,
267 pub id: String,
268}
269
270#[derive(Debug, Deserialize)]
272pub struct SseJobSubscribeRequest {
273 pub session_id: String,
274 pub session_secret: String,
275 pub id: String,
276 pub job_id: String,
277}
278
279#[derive(Debug, Deserialize)]
281pub struct SseWorkflowSubscribeRequest {
282 pub session_id: String,
283 pub session_secret: String,
284 pub id: String,
285 pub workflow_id: String,
286}
287
288#[derive(Debug, Serialize)]
290pub struct SseError {
291 pub code: String,
292 pub message: String,
293}
294
295#[derive(Debug, Serialize)]
297pub struct SseSubscribeResponse {
298 pub success: bool,
299 #[serde(skip_serializing_if = "Option::is_none")]
300 pub data: Option<serde_json::Value>,
301 #[serde(skip_serializing_if = "Option::is_none")]
302 pub error: Option<SseError>,
303}
304
305#[derive(Debug, Serialize)]
307pub struct SseUnsubscribeResponse {
308 pub success: bool,
309 #[serde(skip_serializing_if = "Option::is_none")]
310 pub error: Option<SseError>,
311}
312
313fn subscribe_error(
315 status: StatusCode,
316 code: impl Into<String>,
317 message: impl Into<String>,
318) -> (StatusCode, Json<SseSubscribeResponse>) {
319 (
320 status,
321 Json(SseSubscribeResponse {
322 success: false,
323 data: None,
324 error: Some(SseError {
325 code: code.into(),
326 message: message.into(),
327 }),
328 }),
329 )
330}
331
332fn unsubscribe_error(
334 status: StatusCode,
335 code: impl Into<String>,
336 message: impl Into<String>,
337) -> (StatusCode, Json<SseUnsubscribeResponse>) {
338 (
339 status,
340 Json(SseUnsubscribeResponse {
341 success: false,
342 error: Some(SseError {
343 code: code.into(),
344 message: message.into(),
345 }),
346 }),
347 )
348}
349
350pub async fn sse_handler(
352 State(state): State<Arc<SseState>>,
353 Extension(request_auth): Extension<AuthContext>,
354 Query(query): Query<SseQuery>,
355) -> impl IntoResponse {
356 if !state.can_accept_session().await {
358 return (
359 StatusCode::SERVICE_UNAVAILABLE,
360 "Server at capacity".to_string(),
361 )
362 .into_response();
363 }
364
365 let session_id = SessionId::new();
366 let buffer_size = state.config.channel_buffer_size;
367 let keepalive_secs = state.config.keepalive_interval_secs;
368 let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
369 let cancel_token = CancellationToken::new();
370
371 let query_auth = if let Some(token) = &query.token {
372 match state.auth_middleware.validate_token_async(token).await {
373 Ok(claims) => Some(super::auth::build_auth_context_from_claims(claims)),
374 Err(e) => {
375 tracing::warn!("SSE token validation failed: {}", e);
376 return (
377 StatusCode::UNAUTHORIZED,
378 "Invalid authentication token".to_string(),
379 )
380 .into_response();
381 }
382 }
383 } else {
384 None
385 };
386 let auth_context = resolve_sse_auth_context(&request_auth, query_auth);
387 let session_secret = uuid::Uuid::new_v4().to_string();
388
389 let reactor = state.reactor.clone();
391 let cancel = cancel_token.clone();
392
393 let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
395 reactor.register_session(session_id, rt_tx);
396
397 {
399 let mut sessions = state.sessions.write().await;
400 sessions.insert(
401 session_id,
402 SseSessionData {
403 auth_context: auth_context.clone(),
404 session_secret: session_secret.clone(),
405 subscriptions: HashMap::new(),
406 },
407 );
408 }
409
410 let sessions = state.sessions.clone();
412
413 let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
415 crate::observability::set_active_connections("sse", 1);
416
417 let bridge_cancel = cancel_token.clone();
419 tokio::spawn(async move {
420 loop {
421 tokio::select! {
422 msg = rt_rx.recv() => {
423 match msg {
424 Some(rt_msg) => {
425 if let Some(sse_msg) = convert_realtime_to_sse(rt_msg)
426 && tx.send(sse_msg).await.is_err() {
427 break;
428 }
429 }
430 None => break,
431 }
432 }
433 _ = bridge_cancel.cancelled() => break,
434 }
435 }
436 });
437
438 let (event_tx, event_rx) = mpsc::channel::<Result<Event, Infallible>>(buffer_size);
440
441 tokio::spawn(async move {
442 let mut _guard = cleanup_guard;
443
444 let connected = SsePayload::Connected {
446 session_id: session_id.to_string(),
447 session_secret: session_secret.clone(),
448 };
449 match serde_json::to_string(&connected) {
450 Ok(json) => {
451 let _ = event_tx
452 .send(Ok(Event::default().event("connected").data(json)))
453 .await;
454 }
455 Err(e) => {
456 tracing::error!("Failed to serialize SSE connected payload: {}", e);
457 }
458 }
459
460 loop {
461 tokio::select! {
462 msg = rx.recv() => {
463 match msg {
464 Some(sse_msg) => {
465 let event = match sse_msg {
466 SseMessage::Data { target, payload } => {
467 let data = SsePayload::Update { target, payload };
468 serde_json::to_string(&data).ok().map(|json| {
469 Event::default().event("update").data(json)
470 })
471 }
472 SseMessage::Error { target, code, message } => {
473 let data = SsePayload::Error { target, code, message };
474 serde_json::to_string(&data).ok().map(|json| {
475 Event::default().event("error").data(json)
476 })
477 }
478 };
479 if let Some(evt) = event
480 && event_tx.send(Ok(evt)).await.is_err()
481 {
482 break;
483 }
484 }
485 None => break,
486 }
487 }
488 _ = cancel.cancelled() => break,
489 }
490 }
491
492 _guard.mark_closed();
494 crate::observability::set_active_connections("sse", -1);
495 reactor.remove_session(session_id).await;
496 sessions.write().await.remove(&session_id);
497 });
498
499 let stream = ReceiverStream { rx: event_rx };
500
501 Sse::new(stream)
502 .keep_alive(
503 KeepAlive::new()
504 .interval(Duration::from_secs(keepalive_secs))
505 .text("ping"),
506 )
507 .into_response()
508}
509
510fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
512 match msg {
513 RealtimeMessage::Data {
514 subscription_id,
515 data,
516 } => Some(SseMessage::Data {
517 target: format!("sub:{}", subscription_id),
518 payload: data,
519 }),
520 RealtimeMessage::DeltaUpdate {
521 subscription_id,
522 delta,
523 } => match serde_json::to_value(&delta) {
524 Ok(payload) => Some(SseMessage::Data {
525 target: format!("sub:{}", subscription_id),
526 payload,
527 }),
528 Err(e) => {
529 tracing::error!("Failed to serialize delta update: {}", e);
530 Some(SseMessage::Error {
531 target: format!("sub:{}", subscription_id),
532 code: "SERIALIZATION_ERROR".to_string(),
533 message: "Failed to serialize update data".to_string(),
534 })
535 }
536 },
537 RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
538 Ok(payload) => Some(SseMessage::Data {
539 target: format!("job:{}", client_sub_id),
540 payload,
541 }),
542 Err(e) => {
543 tracing::error!("Failed to serialize job update: {}", e);
544 Some(SseMessage::Error {
545 target: format!("job:{}", client_sub_id),
546 code: "SERIALIZATION_ERROR".to_string(),
547 message: "Failed to serialize job update".to_string(),
548 })
549 }
550 },
551 RealtimeMessage::WorkflowUpdate {
552 client_sub_id,
553 workflow,
554 } => match serde_json::to_value(&workflow) {
555 Ok(payload) => Some(SseMessage::Data {
556 target: format!("wf:{}", client_sub_id),
557 payload,
558 }),
559 Err(e) => {
560 tracing::error!("Failed to serialize workflow update: {}", e);
561 Some(SseMessage::Error {
562 target: format!("wf:{}", client_sub_id),
563 code: "SERIALIZATION_ERROR".to_string(),
564 message: "Failed to serialize workflow update".to_string(),
565 })
566 }
567 },
568 RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
569 target: String::new(),
570 code,
571 message,
572 }),
573 RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
574 target: id,
575 code,
576 message,
577 }),
578 RealtimeMessage::Subscribe { .. }
580 | RealtimeMessage::Unsubscribe { .. }
581 | RealtimeMessage::Ping
582 | RealtimeMessage::Pong
583 | RealtimeMessage::AuthSuccess
584 | RealtimeMessage::AuthFailed { .. }
585 | RealtimeMessage::Lagging => None,
586 }
587}
588
589pub async fn sse_subscribe_handler(
591 State(state): State<Arc<SseState>>,
592 Extension(request_auth): Extension<AuthContext>,
593 Json(request): Json<SseSubscribeRequest>,
594) -> impl IntoResponse {
595 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
597 return subscribe_error(
598 StatusCode::BAD_REQUEST,
599 "INVALID_ID",
600 format!(
601 "Subscription ID too long (max {} chars)",
602 MAX_CLIENT_SUB_ID_LEN
603 ),
604 );
605 }
606
607 let Some(session_id) = try_parse_session_id(&request.session_id) else {
608 return subscribe_error(
609 StatusCode::BAD_REQUEST,
610 "INVALID_SESSION",
611 "Invalid session ID format",
612 );
613 };
614
615 let sessions = state.sessions.write().await;
617 let session_data = match sessions.get(&session_id) {
618 Some(data) => {
619 if data.subscriptions.len() >= MAX_SUBSCRIPTIONS_PER_SESSION {
621 return subscribe_error(
622 StatusCode::TOO_MANY_REQUESTS,
623 "TOO_MANY_SUBSCRIPTIONS",
624 format!(
625 "Session has reached the maximum of {} subscriptions",
626 MAX_SUBSCRIPTIONS_PER_SESSION
627 ),
628 );
629 }
630 match authorize_session_access(data, &request.session_secret, &request_auth) {
631 Ok(auth) => auth,
632 Err(resp) => return resp,
633 }
634 }
635 None => {
636 return subscribe_error(
637 StatusCode::NOT_FOUND,
638 "SESSION_NOT_FOUND",
639 "Session not found or expired",
640 );
641 }
642 };
643 drop(sessions);
644
645 let result = state
647 .reactor
648 .subscribe(
649 session_id,
650 request.id.clone(),
651 request.function,
652 request.args,
653 session_data,
654 )
655 .await;
656
657 match result {
658 Ok((subscription_id, data)) => {
659 let mut sessions = state.sessions.write().await;
661 match sessions.get_mut(&session_id) {
662 Some(session) => {
663 session.subscriptions.insert(request.id, subscription_id);
664 }
665 None => {
666 return subscribe_error(
668 StatusCode::NOT_FOUND,
669 "SESSION_NOT_FOUND",
670 "Session expired during subscription",
671 );
672 }
673 }
674
675 tracing::debug!(
676 %session_id,
677 %subscription_id,
678 "SSE subscription registered"
679 );
680
681 (
682 StatusCode::OK,
683 Json(SseSubscribeResponse {
684 success: true,
685 data: Some(data),
686 error: None,
687 }),
688 )
689 }
690 Err(e) => {
691 tracing::warn!(%session_id, error = %e, "SSE subscription failed");
692 match e {
693 forge_core::ForgeError::Unauthorized(msg) => {
694 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
695 }
696 forge_core::ForgeError::Forbidden(msg) => {
697 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
698 }
699 forge_core::ForgeError::InvalidArgument(msg)
700 | forge_core::ForgeError::Validation(msg) => {
701 subscribe_error(StatusCode::BAD_REQUEST, "INVALID_ARGUMENT", msg)
702 }
703 forge_core::ForgeError::NotFound(msg) => {
704 subscribe_error(StatusCode::NOT_FOUND, "NOT_FOUND", msg)
705 }
706 _ => subscribe_error(
707 StatusCode::INTERNAL_SERVER_ERROR,
708 "SUBSCRIPTION_FAILED",
709 "Subscription failed",
710 ),
711 }
712 }
713 }
714}
715
716pub async fn sse_unsubscribe_handler(
718 State(state): State<Arc<SseState>>,
719 Extension(request_auth): Extension<AuthContext>,
720 Json(request): Json<SseUnsubscribeRequest>,
721) -> impl IntoResponse {
722 let Some(session_id) = try_parse_session_id(&request.session_id) else {
723 return unsubscribe_error(
724 StatusCode::BAD_REQUEST,
725 "INVALID_SESSION",
726 "Invalid session ID format",
727 );
728 };
729
730 let subscription_id = {
732 let sessions = state.sessions.read().await;
733 match sessions.get(&session_id) {
734 Some(session) => {
735 if session.session_secret != request.session_secret
736 || !same_principal(&session.auth_context, &request_auth)
737 {
738 return unsubscribe_error(
739 StatusCode::FORBIDDEN,
740 "SESSION_PRINCIPAL_MISMATCH",
741 "Request principal does not match session principal",
742 );
743 }
744 session.subscriptions.get(&request.id).copied()
745 }
746 None => None,
747 }
748 };
749
750 let Some(subscription_id) = subscription_id else {
751 return unsubscribe_error(
752 StatusCode::NOT_FOUND,
753 "SUBSCRIPTION_NOT_FOUND",
754 "Subscription not found",
755 );
756 };
757
758 state.reactor.unsubscribe(subscription_id);
760
761 {
763 let mut sessions = state.sessions.write().await;
764 if let Some(session) = sessions.get_mut(&session_id) {
765 session.subscriptions.remove(&request.id);
766 }
767 }
768
769 tracing::debug!(
770 %session_id,
771 %subscription_id,
772 "SSE subscription removed"
773 );
774
775 (
776 StatusCode::OK,
777 Json(SseUnsubscribeResponse {
778 success: true,
779 error: None,
780 }),
781 )
782}
783
784pub async fn sse_job_subscribe_handler(
786 State(state): State<Arc<SseState>>,
787 Extension(request_auth): Extension<AuthContext>,
788 Json(request): Json<SseJobSubscribeRequest>,
789) -> impl IntoResponse {
790 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
791 return subscribe_error(
792 StatusCode::BAD_REQUEST,
793 "INVALID_ID",
794 format!(
795 "Subscription ID too long (max {} chars)",
796 MAX_CLIENT_SUB_ID_LEN
797 ),
798 );
799 }
800
801 let Some(session_id) = try_parse_session_id(&request.session_id) else {
802 return subscribe_error(
803 StatusCode::BAD_REQUEST,
804 "INVALID_SESSION",
805 "Invalid session ID format",
806 );
807 };
808
809 let session_auth = {
811 let sessions = state.sessions.read().await;
812 match sessions.get(&session_id) {
813 Some(session) => {
814 match authorize_session_access(session, &request.session_secret, &request_auth) {
815 Ok(auth) => auth,
816 Err(resp) => return resp,
817 }
818 }
819 None => {
820 return subscribe_error(
821 StatusCode::NOT_FOUND,
822 "SESSION_NOT_FOUND",
823 "Session not found or expired",
824 );
825 }
826 }
827 };
828
829 let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
831 Ok(uuid) => uuid,
832 Err(_) => {
833 return subscribe_error(
834 StatusCode::BAD_REQUEST,
835 "INVALID_JOB_ID",
836 "Invalid job ID format",
837 );
838 }
839 };
840
841 match state
843 .reactor
844 .subscribe_job(session_id, request.id.clone(), job_uuid, &session_auth)
845 .await
846 {
847 Ok(job_data) => {
848 let data = match serde_json::to_value(&job_data) {
849 Ok(v) => v,
850 Err(e) => {
851 tracing::error!("Failed to serialize job data: {}", e);
852 return subscribe_error(
853 StatusCode::INTERNAL_SERVER_ERROR,
854 "SERIALIZE_ERROR",
855 "Failed to serialize job data",
856 );
857 }
858 };
859 tracing::debug!(
860 %session_id,
861 job_id = %request.job_id,
862 client_sub_id = %request.id,
863 "SSE job subscription registered"
864 );
865 (
866 StatusCode::OK,
867 Json(SseSubscribeResponse {
868 success: true,
869 data: Some(data),
870 error: None,
871 }),
872 )
873 }
874 Err(e) => match e {
875 forge_core::ForgeError::Unauthorized(msg) => {
876 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
877 }
878 forge_core::ForgeError::Forbidden(msg) => {
879 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
880 }
881 forge_core::ForgeError::NotFound(msg) => {
882 subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", msg)
883 }
884 _ => subscribe_error(
885 StatusCode::INTERNAL_SERVER_ERROR,
886 "SUBSCRIPTION_FAILED",
887 "Subscription failed",
888 ),
889 },
890 }
891}
892
893pub async fn sse_workflow_subscribe_handler(
895 State(state): State<Arc<SseState>>,
896 Extension(request_auth): Extension<AuthContext>,
897 Json(request): Json<SseWorkflowSubscribeRequest>,
898) -> impl IntoResponse {
899 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
900 return subscribe_error(
901 StatusCode::BAD_REQUEST,
902 "INVALID_ID",
903 format!(
904 "Subscription ID too long (max {} chars)",
905 MAX_CLIENT_SUB_ID_LEN
906 ),
907 );
908 }
909
910 let Some(session_id) = try_parse_session_id(&request.session_id) else {
911 return subscribe_error(
912 StatusCode::BAD_REQUEST,
913 "INVALID_SESSION",
914 "Invalid session ID format",
915 );
916 };
917
918 let session_auth = {
920 let sessions = state.sessions.read().await;
921 match sessions.get(&session_id) {
922 Some(session) => {
923 match authorize_session_access(session, &request.session_secret, &request_auth) {
924 Ok(auth) => auth,
925 Err(resp) => return resp,
926 }
927 }
928 None => {
929 return subscribe_error(
930 StatusCode::NOT_FOUND,
931 "SESSION_NOT_FOUND",
932 "Session not found or expired",
933 );
934 }
935 }
936 };
937
938 let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
940 Ok(uuid) => uuid,
941 Err(_) => {
942 return subscribe_error(
943 StatusCode::BAD_REQUEST,
944 "INVALID_WORKFLOW_ID",
945 "Invalid workflow ID format",
946 );
947 }
948 };
949
950 match state
952 .reactor
953 .subscribe_workflow(session_id, request.id.clone(), workflow_uuid, &session_auth)
954 .await
955 {
956 Ok(workflow_data) => {
957 let data = match serde_json::to_value(&workflow_data) {
958 Ok(v) => v,
959 Err(e) => {
960 tracing::error!("Failed to serialize workflow data: {}", e);
961 return subscribe_error(
962 StatusCode::INTERNAL_SERVER_ERROR,
963 "SERIALIZE_ERROR",
964 "Failed to serialize workflow data",
965 );
966 }
967 };
968 tracing::debug!(
969 %session_id,
970 workflow_id = %request.workflow_id,
971 client_sub_id = %request.id,
972 "SSE workflow subscription registered"
973 );
974 (
975 StatusCode::OK,
976 Json(SseSubscribeResponse {
977 success: true,
978 data: Some(data),
979 error: None,
980 }),
981 )
982 }
983 Err(e) => match e {
984 forge_core::ForgeError::Unauthorized(msg) => {
985 subscribe_error(StatusCode::UNAUTHORIZED, "UNAUTHORIZED", msg)
986 }
987 forge_core::ForgeError::Forbidden(msg) => {
988 subscribe_error(StatusCode::FORBIDDEN, "FORBIDDEN", msg)
989 }
990 forge_core::ForgeError::NotFound(msg) => {
991 subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", msg)
992 }
993 _ => subscribe_error(
994 StatusCode::INTERNAL_SERVER_ERROR,
995 "SUBSCRIPTION_FAILED",
996 "Subscription failed",
997 ),
998 },
999 }
1000}
1001
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use uuid::Uuid;

    /// Builds an authenticated context for a fresh random UUID.
    fn fresh_authenticated_context() -> AuthContext {
        AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new())
    }

    // Update payloads carry the snake_case `type` tag and their target.
    #[test]
    fn test_sse_payload_serialization() {
        let update = SsePayload::Update {
            target: "sub:123".to_string(),
            payload: serde_json::json!({"id": 1}),
        };

        let serialized = serde_json::to_string(&update).unwrap();

        assert!(serialized.contains("\"type\":\"update\""));
        assert!(serialized.contains("\"target\":\"sub:123\""));
    }

    // Error payloads carry the `error` tag and keep their code.
    #[test]
    fn test_sse_error_serialization() {
        let error = SsePayload::Error {
            target: "sub:456".to_string(),
            code: "NOT_FOUND".to_string(),
            message: "Subscription not found".to_string(),
        };

        let serialized = serde_json::to_string(&error).unwrap();

        assert!(serialized.contains("\"type\":\"error\""));
        assert!(serialized.contains("NOT_FOUND"));
    }

    // Without a query token the request's own context is used.
    #[test]
    fn resolve_sse_auth_context_prefers_request_auth_when_query_token_absent() {
        let request_auth = fresh_authenticated_context();

        let resolved = resolve_sse_auth_context(&request_auth, None);

        assert!(resolved.is_authenticated());
        assert_eq!(resolved.principal_id(), request_auth.principal_id());
    }

    // A context derived from the query token overrides the request's.
    #[test]
    fn resolve_sse_auth_context_prefers_query_token_when_present() {
        let request_auth = fresh_authenticated_context();
        let query_auth = fresh_authenticated_context();

        let resolved = resolve_sse_auth_context(&request_auth, Some(query_auth.clone()));

        assert_eq!(resolved.principal_id(), query_auth.principal_id());
    }
}