1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::sync::Arc;
6use std::time::Duration;
7
8use axum::Json;
9use axum::extract::{Query, State};
10use axum::http::StatusCode;
11use axum::response::IntoResponse;
12use axum::response::sse::{Event, KeepAlive, Sse};
13use serde::{Deserialize, Serialize};
14use tokio::sync::{RwLock, mpsc};
15use tokio_util::sync::CancellationToken;
16
17use forge_core::function::AuthContext;
18use forge_core::realtime::{SessionId, SubscriptionId};
19
20use super::auth::AuthMiddleware;
21use crate::realtime::Reactor;
22use crate::realtime::RealtimeMessage;
23
/// Upper bound on the length of a client-chosen subscription id.
///
/// Compared against `str::len`, so the limit is counted in bytes, not characters.
const MAX_CLIENT_SUB_ID_LEN: usize = 255;
26
27fn try_parse_session_id(session_id: &str) -> Option<SessionId> {
28 uuid::Uuid::parse_str(session_id)
29 .ok()
30 .map(SessionId::from_uuid)
31}
32
/// Tuning knobs for the SSE transport.
#[derive(Debug, Clone)]
pub struct SseConfig {
    /// Maximum number of concurrently open SSE sessions. The SSE handler answers
    /// `503 Service Unavailable` once this many sessions exist.
    pub max_sessions: usize,
    /// Capacity used for each per-session mpsc channel (reactor -> bridge and bridge -> stream).
    pub channel_buffer_size: usize,
    /// Seconds between `ping` keep-alive messages on an idle stream.
    pub keepalive_interval_secs: u64,
}

impl Default for SseConfig {
    fn default() -> Self {
        const DEFAULT_MAX_SESSIONS: usize = 10_000;
        const DEFAULT_CHANNEL_BUFFER: usize = 256;
        const DEFAULT_KEEPALIVE_SECS: u64 = 30;

        SseConfig {
            max_sessions: DEFAULT_MAX_SESSIONS,
            channel_buffer_size: DEFAULT_CHANNEL_BUFFER,
            keepalive_interval_secs: DEFAULT_KEEPALIVE_SECS,
        }
    }
}
53
54#[derive(Debug, Deserialize)]
56pub struct SseQuery {
57 pub token: Option<String>,
59}
60
61struct SseSessionData {
62 auth_context: AuthContext,
63 subscriptions: HashMap<String, SubscriptionId>,
65}
66
67#[derive(Clone)]
69pub struct SseState {
70 reactor: Arc<Reactor>,
71 auth_middleware: Arc<AuthMiddleware>,
72 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
74 config: SseConfig,
75}
76
77impl SseState {
78 pub fn new(reactor: Arc<Reactor>, auth_middleware: Arc<AuthMiddleware>) -> Self {
80 Self::with_config(reactor, auth_middleware, SseConfig::default())
81 }
82
83 pub fn with_config(
85 reactor: Arc<Reactor>,
86 auth_middleware: Arc<AuthMiddleware>,
87 config: SseConfig,
88 ) -> Self {
89 Self {
90 reactor,
91 auth_middleware,
92 sessions: Arc::new(RwLock::new(HashMap::new())),
93 config,
94 }
95 }
96
97 pub async fn can_accept_session(&self) -> bool {
99 self.sessions.read().await.len() < self.config.max_sessions
100 }
101}
102
103struct SessionCleanupGuard {
106 session_id: SessionId,
107 reactor: Arc<Reactor>,
108 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
109 dropped: bool,
110}
111
112impl SessionCleanupGuard {
113 fn new(
114 session_id: SessionId,
115 reactor: Arc<Reactor>,
116 sessions: Arc<RwLock<HashMap<SessionId, SseSessionData>>>,
117 ) -> Self {
118 Self {
119 session_id,
120 reactor,
121 sessions,
122 dropped: false,
123 }
124 }
125
126 fn mark_closed(&mut self) {
128 self.dropped = true;
129 }
130}
131
132impl Drop for SessionCleanupGuard {
133 fn drop(&mut self) {
134 if self.dropped {
135 return;
136 }
137 let session_id = self.session_id;
138 let reactor = self.reactor.clone();
139 let sessions = self.sessions.clone();
140
141 if let Ok(handle) = tokio::runtime::Handle::try_current() {
144 handle.spawn(async move {
145 reactor.remove_session(session_id).await;
146 sessions.write().await.remove(&session_id);
147 tracing::debug!(%session_id, "SSE session cleaned up on disconnect");
148 });
149 } else {
150 tracing::warn!(%session_id, "Could not spawn cleanup task, runtime unavailable");
152 }
153 }
154}
155
156#[derive(Debug, Serialize)]
158#[serde(tag = "type", rename_all = "snake_case")]
159pub enum SsePayload {
160 Update {
162 target: String,
163 payload: serde_json::Value,
164 },
165 Error {
167 target: String,
168 code: String,
169 message: String,
170 },
171 Connected { session_id: String },
173}
174
175#[derive(Debug)]
177pub enum SseMessage {
178 Data {
179 target: String,
180 payload: serde_json::Value,
181 },
182 Error {
183 target: String,
184 code: String,
185 message: String,
186 },
187}
188
189#[derive(Debug, Deserialize)]
191pub struct SseSubscribeRequest {
192 pub session_id: String,
193 pub id: String,
194 pub function: String,
195 #[serde(default)]
196 pub args: serde_json::Value,
197}
198
199#[derive(Debug, Deserialize)]
201pub struct SseUnsubscribeRequest {
202 pub session_id: String,
203 pub id: String,
204}
205
206#[derive(Debug, Deserialize)]
208pub struct SseJobSubscribeRequest {
209 pub session_id: String,
210 pub id: String,
211 pub job_id: String,
212}
213
214#[derive(Debug, Deserialize)]
216pub struct SseWorkflowSubscribeRequest {
217 pub session_id: String,
218 pub id: String,
219 pub workflow_id: String,
220}
221
222#[derive(Debug, Serialize)]
224pub struct SseError {
225 pub code: String,
226 pub message: String,
227}
228
229#[derive(Debug, Serialize)]
231pub struct SseSubscribeResponse {
232 pub success: bool,
233 #[serde(skip_serializing_if = "Option::is_none")]
234 pub data: Option<serde_json::Value>,
235 #[serde(skip_serializing_if = "Option::is_none")]
236 pub error: Option<SseError>,
237}
238
239#[derive(Debug, Serialize)]
241pub struct SseUnsubscribeResponse {
242 pub success: bool,
243 #[serde(skip_serializing_if = "Option::is_none")]
244 pub error: Option<SseError>,
245}
246
247fn subscribe_error(
249 status: StatusCode,
250 code: impl Into<String>,
251 message: impl Into<String>,
252) -> (StatusCode, Json<SseSubscribeResponse>) {
253 (
254 status,
255 Json(SseSubscribeResponse {
256 success: false,
257 data: None,
258 error: Some(SseError {
259 code: code.into(),
260 message: message.into(),
261 }),
262 }),
263 )
264}
265
266fn unsubscribe_error(
268 status: StatusCode,
269 code: impl Into<String>,
270 message: impl Into<String>,
271) -> (StatusCode, Json<SseUnsubscribeResponse>) {
272 (
273 status,
274 Json(SseUnsubscribeResponse {
275 success: false,
276 error: Some(SseError {
277 code: code.into(),
278 message: message.into(),
279 }),
280 }),
281 )
282}
283
284pub async fn sse_handler(
286 State(state): State<Arc<SseState>>,
287 Query(query): Query<SseQuery>,
288) -> impl IntoResponse {
289 if !state.can_accept_session().await {
291 return (
292 StatusCode::SERVICE_UNAVAILABLE,
293 "Server at capacity".to_string(),
294 )
295 .into_response();
296 }
297
298 let session_id = SessionId::new();
299 let buffer_size = state.config.channel_buffer_size;
300 let keepalive_secs = state.config.keepalive_interval_secs;
301 let (tx, mut rx) = mpsc::channel::<SseMessage>(buffer_size);
302 let cancel_token = CancellationToken::new();
303
304 let auth_context = if let Some(token) = &query.token {
305 match state.auth_middleware.validate_token(token) {
306 Ok(claims) => super::auth::build_auth_context_from_claims(claims),
307 Err(e) => {
308 tracing::debug!(
309 "SSE token validation failed, continuing unauthenticated: {}",
310 e
311 );
312 forge_core::function::AuthContext::unauthenticated()
313 }
314 }
315 } else {
316 forge_core::function::AuthContext::unauthenticated()
317 };
318
319 let reactor = state.reactor.clone();
321 let cancel = cancel_token.clone();
322
323 let (rt_tx, mut rt_rx) = mpsc::channel(buffer_size);
325 reactor.register_session(session_id, rt_tx).await;
326
327 {
329 let mut sessions = state.sessions.write().await;
330 sessions.insert(
331 session_id,
332 SseSessionData {
333 auth_context: auth_context.clone(),
334 subscriptions: HashMap::new(),
335 },
336 );
337 }
338
339 let sessions = state.sessions.clone();
341
342 let cleanup_guard = SessionCleanupGuard::new(session_id, reactor.clone(), sessions.clone());
344
345 let bridge_cancel = cancel_token.clone();
347 tokio::spawn(async move {
348 loop {
349 tokio::select! {
350 msg = rt_rx.recv() => {
351 match msg {
352 Some(rt_msg) => {
353 if let Some(sse_msg) = convert_realtime_to_sse(rt_msg) {
354 if tx.send(sse_msg).await.is_err() {
355 break;
356 }
357 }
358 }
359 None => break,
360 }
361 }
362 _ = bridge_cancel.cancelled() => break,
363 }
364 }
365 });
366
367 let stream = async_stream::stream! {
369 let mut _guard = cleanup_guard;
371
372 let connected = SsePayload::Connected {
374 session_id: session_id.to_string(),
375 };
376 match serde_json::to_string(&connected) {
377 Ok(json) => {
378 yield Ok::<Event, Infallible>(Event::default().event("connected").data(json));
379 }
380 Err(e) => {
381 tracing::error!("Failed to serialize SSE connected payload: {}", e);
382 }
383 }
384
385 loop {
386 tokio::select! {
387 msg = rx.recv() => {
388 match msg {
389 Some(SseMessage::Data { target, payload }) => {
390 let event_data = SsePayload::Update { target, payload };
391 match serde_json::to_string(&event_data) {
392 Ok(json) => {
393 yield Ok::<Event, Infallible>(Event::default().event("update").data(json));
394 }
395 Err(e) => {
396 tracing::error!("Failed to serialize SSE update payload: {}", e);
397 }
398 }
399 }
400 Some(SseMessage::Error { target, code, message }) => {
401 let event_data = SsePayload::Error { target, code, message };
402 match serde_json::to_string(&event_data) {
403 Ok(json) => {
404 yield Ok::<Event, Infallible>(Event::default().event("error").data(json));
405 }
406 Err(e) => {
407 tracing::error!("Failed to serialize SSE error payload: {}", e);
408 }
409 }
410 }
411 None => break,
412 }
413 }
414 _ = cancel.cancelled() => break,
415 }
416 }
417
418 _guard.mark_closed();
420 reactor.remove_session(session_id).await;
421 sessions.write().await.remove(&session_id);
422 };
423
424 Sse::new(stream)
425 .keep_alive(
426 KeepAlive::new()
427 .interval(Duration::from_secs(keepalive_secs))
428 .text("ping"),
429 )
430 .into_response()
431}
432
433fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option<SseMessage> {
435 match msg {
436 RealtimeMessage::Data {
437 subscription_id,
438 data,
439 } => Some(SseMessage::Data {
440 target: format!("sub:{}", subscription_id),
441 payload: data,
442 }),
443 RealtimeMessage::DeltaUpdate {
444 subscription_id,
445 delta,
446 } => match serde_json::to_value(&delta) {
447 Ok(payload) => Some(SseMessage::Data {
448 target: format!("sub:{}", subscription_id),
449 payload,
450 }),
451 Err(e) => {
452 tracing::error!("Failed to serialize delta update: {}", e);
453 Some(SseMessage::Error {
454 target: format!("sub:{}", subscription_id),
455 code: "SERIALIZATION_ERROR".to_string(),
456 message: "Failed to serialize update data".to_string(),
457 })
458 }
459 },
460 RealtimeMessage::JobUpdate { client_sub_id, job } => match serde_json::to_value(&job) {
461 Ok(payload) => Some(SseMessage::Data {
462 target: format!("job:{}", client_sub_id),
463 payload,
464 }),
465 Err(e) => {
466 tracing::error!("Failed to serialize job update: {}", e);
467 Some(SseMessage::Error {
468 target: format!("job:{}", client_sub_id),
469 code: "SERIALIZATION_ERROR".to_string(),
470 message: "Failed to serialize job update".to_string(),
471 })
472 }
473 },
474 RealtimeMessage::WorkflowUpdate {
475 client_sub_id,
476 workflow,
477 } => match serde_json::to_value(&workflow) {
478 Ok(payload) => Some(SseMessage::Data {
479 target: format!("wf:{}", client_sub_id),
480 payload,
481 }),
482 Err(e) => {
483 tracing::error!("Failed to serialize workflow update: {}", e);
484 Some(SseMessage::Error {
485 target: format!("wf:{}", client_sub_id),
486 code: "SERIALIZATION_ERROR".to_string(),
487 message: "Failed to serialize workflow update".to_string(),
488 })
489 }
490 },
491 RealtimeMessage::Error { code, message } => Some(SseMessage::Error {
492 target: String::new(),
493 code,
494 message,
495 }),
496 RealtimeMessage::ErrorWithId { id, code, message } => Some(SseMessage::Error {
497 target: id,
498 code,
499 message,
500 }),
501 RealtimeMessage::Subscribe { .. }
503 | RealtimeMessage::Unsubscribe { .. }
504 | RealtimeMessage::Ping
505 | RealtimeMessage::Pong
506 | RealtimeMessage::AuthSuccess
507 | RealtimeMessage::AuthFailed { .. } => None,
508 }
509}
510
511pub async fn sse_subscribe_handler(
513 State(state): State<Arc<SseState>>,
514 Json(request): Json<SseSubscribeRequest>,
515) -> impl IntoResponse {
516 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
518 return subscribe_error(
519 StatusCode::BAD_REQUEST,
520 "INVALID_ID",
521 format!(
522 "Subscription ID too long (max {} chars)",
523 MAX_CLIENT_SUB_ID_LEN
524 ),
525 );
526 }
527
528 let Some(session_id) = try_parse_session_id(&request.session_id) else {
529 return subscribe_error(
530 StatusCode::BAD_REQUEST,
531 "INVALID_SESSION",
532 "Invalid session ID format",
533 );
534 };
535
536 let sessions = state.sessions.read().await;
538 let session_data = match sessions.get(&session_id) {
539 Some(data) => data.auth_context.clone(),
540 None => {
541 return subscribe_error(
542 StatusCode::NOT_FOUND,
543 "SESSION_NOT_FOUND",
544 "Session not found or expired",
545 );
546 }
547 };
548 drop(sessions);
549
550 let result = state
552 .reactor
553 .subscribe(
554 session_id,
555 request.id.clone(),
556 request.function,
557 request.args,
558 session_data,
559 )
560 .await;
561
562 match result {
563 Ok((subscription_id, data)) => {
564 let mut sessions = state.sessions.write().await;
566 match sessions.get_mut(&session_id) {
567 Some(session) => {
568 session.subscriptions.insert(request.id, subscription_id);
569 }
570 None => {
571 return subscribe_error(
573 StatusCode::NOT_FOUND,
574 "SESSION_NOT_FOUND",
575 "Session expired during subscription",
576 );
577 }
578 }
579
580 tracing::debug!(
581 %session_id,
582 %subscription_id,
583 "SSE subscription registered"
584 );
585
586 (
587 StatusCode::OK,
588 Json(SseSubscribeResponse {
589 success: true,
590 data: Some(data),
591 error: None,
592 }),
593 )
594 }
595 Err(e) => {
596 tracing::warn!(%session_id, error = %e, "SSE subscription failed");
597 subscribe_error(
598 StatusCode::INTERNAL_SERVER_ERROR,
599 "SUBSCRIPTION_FAILED",
600 e.to_string(),
601 )
602 }
603 }
604}
605
606pub async fn sse_unsubscribe_handler(
608 State(state): State<Arc<SseState>>,
609 Json(request): Json<SseUnsubscribeRequest>,
610) -> impl IntoResponse {
611 let Some(session_id) = try_parse_session_id(&request.session_id) else {
612 return unsubscribe_error(
613 StatusCode::BAD_REQUEST,
614 "INVALID_SESSION",
615 "Invalid session ID format",
616 );
617 };
618
619 let subscription_id = {
621 let sessions = state.sessions.read().await;
622 sessions
623 .get(&session_id)
624 .and_then(|s| s.subscriptions.get(&request.id).copied())
625 };
626
627 let Some(subscription_id) = subscription_id else {
628 return unsubscribe_error(
629 StatusCode::NOT_FOUND,
630 "SUBSCRIPTION_NOT_FOUND",
631 "Subscription not found",
632 );
633 };
634
635 state.reactor.unsubscribe(subscription_id).await;
637
638 {
640 let mut sessions = state.sessions.write().await;
641 if let Some(session) = sessions.get_mut(&session_id) {
642 session.subscriptions.remove(&request.id);
643 }
644 }
645
646 tracing::debug!(
647 %session_id,
648 %subscription_id,
649 "SSE subscription removed"
650 );
651
652 (
653 StatusCode::OK,
654 Json(SseUnsubscribeResponse {
655 success: true,
656 error: None,
657 }),
658 )
659}
660
661pub async fn sse_job_subscribe_handler(
663 State(state): State<Arc<SseState>>,
664 Json(request): Json<SseJobSubscribeRequest>,
665) -> impl IntoResponse {
666 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
667 return subscribe_error(
668 StatusCode::BAD_REQUEST,
669 "INVALID_ID",
670 format!(
671 "Subscription ID too long (max {} chars)",
672 MAX_CLIENT_SUB_ID_LEN
673 ),
674 );
675 }
676
677 let Some(session_id) = try_parse_session_id(&request.session_id) else {
678 return subscribe_error(
679 StatusCode::BAD_REQUEST,
680 "INVALID_SESSION",
681 "Invalid session ID format",
682 );
683 };
684
685 {
687 let sessions = state.sessions.read().await;
688 if !sessions.contains_key(&session_id) {
689 return subscribe_error(
690 StatusCode::NOT_FOUND,
691 "SESSION_NOT_FOUND",
692 "Session not found or expired",
693 );
694 }
695 }
696
697 let job_uuid = match uuid::Uuid::parse_str(&request.job_id) {
699 Ok(uuid) => uuid,
700 Err(_) => {
701 return subscribe_error(
702 StatusCode::BAD_REQUEST,
703 "INVALID_JOB_ID",
704 "Invalid job ID format",
705 );
706 }
707 };
708
709 match state
711 .reactor
712 .subscribe_job(session_id, request.id.clone(), job_uuid)
713 .await
714 {
715 Ok(job_data) => {
716 let data = match serde_json::to_value(&job_data) {
717 Ok(v) => v,
718 Err(e) => {
719 tracing::error!("Failed to serialize job data: {}", e);
720 return subscribe_error(
721 StatusCode::INTERNAL_SERVER_ERROR,
722 "SERIALIZE_ERROR",
723 "Failed to serialize job data",
724 );
725 }
726 };
727 tracing::debug!(
728 %session_id,
729 job_id = %request.job_id,
730 client_sub_id = %request.id,
731 "SSE job subscription registered"
732 );
733 (
734 StatusCode::OK,
735 Json(SseSubscribeResponse {
736 success: true,
737 data: Some(data),
738 error: None,
739 }),
740 )
741 }
742 Err(e) => subscribe_error(StatusCode::NOT_FOUND, "JOB_NOT_FOUND", e.to_string()),
743 }
744}
745
746pub async fn sse_workflow_subscribe_handler(
748 State(state): State<Arc<SseState>>,
749 Json(request): Json<SseWorkflowSubscribeRequest>,
750) -> impl IntoResponse {
751 if request.id.len() > MAX_CLIENT_SUB_ID_LEN {
752 return subscribe_error(
753 StatusCode::BAD_REQUEST,
754 "INVALID_ID",
755 format!(
756 "Subscription ID too long (max {} chars)",
757 MAX_CLIENT_SUB_ID_LEN
758 ),
759 );
760 }
761
762 let Some(session_id) = try_parse_session_id(&request.session_id) else {
763 return subscribe_error(
764 StatusCode::BAD_REQUEST,
765 "INVALID_SESSION",
766 "Invalid session ID format",
767 );
768 };
769
770 {
772 let sessions = state.sessions.read().await;
773 if !sessions.contains_key(&session_id) {
774 return subscribe_error(
775 StatusCode::NOT_FOUND,
776 "SESSION_NOT_FOUND",
777 "Session not found or expired",
778 );
779 }
780 }
781
782 let workflow_uuid = match uuid::Uuid::parse_str(&request.workflow_id) {
784 Ok(uuid) => uuid,
785 Err(_) => {
786 return subscribe_error(
787 StatusCode::BAD_REQUEST,
788 "INVALID_WORKFLOW_ID",
789 "Invalid workflow ID format",
790 );
791 }
792 };
793
794 match state
796 .reactor
797 .subscribe_workflow(session_id, request.id.clone(), workflow_uuid)
798 .await
799 {
800 Ok(workflow_data) => {
801 let data = match serde_json::to_value(&workflow_data) {
802 Ok(v) => v,
803 Err(e) => {
804 tracing::error!("Failed to serialize workflow data: {}", e);
805 return subscribe_error(
806 StatusCode::INTERNAL_SERVER_ERROR,
807 "SERIALIZE_ERROR",
808 "Failed to serialize workflow data",
809 );
810 }
811 };
812 tracing::debug!(
813 %session_id,
814 workflow_id = %request.workflow_id,
815 client_sub_id = %request.id,
816 "SSE workflow subscription registered"
817 );
818 (
819 StatusCode::OK,
820 Json(SseSubscribeResponse {
821 success: true,
822 data: Some(data),
823 error: None,
824 }),
825 )
826 }
827 Err(e) => subscribe_error(StatusCode::NOT_FOUND, "WORKFLOW_NOT_FOUND", e.to_string()),
828 }
829}
830
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sse_payload_serialization() {
        let payload = SsePayload::Update {
            target: "sub:123".to_string(),
            payload: serde_json::json!({"id": 1}),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"update\""));
        assert!(json.contains("\"target\":\"sub:123\""));
    }

    #[test]
    fn test_sse_error_serialization() {
        let payload = SsePayload::Error {
            target: "sub:456".to_string(),
            code: "NOT_FOUND".to_string(),
            message: "Subscription not found".to_string(),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("NOT_FOUND"));
    }

    #[test]
    fn test_connected_payload_serialization() {
        let payload = SsePayload::Connected {
            session_id: "abc".to_string(),
        };
        let json = serde_json::to_string(&payload).unwrap();
        assert!(json.contains("\"type\":\"connected\""));
        assert!(json.contains("\"session_id\":\"abc\""));
    }

    #[test]
    fn test_sse_config_defaults() {
        let config = SseConfig::default();
        assert_eq!(config.max_sessions, 10_000);
        assert_eq!(config.channel_buffer_size, 256);
        assert_eq!(config.keepalive_interval_secs, 30);
    }

    #[test]
    fn test_try_parse_session_id() {
        assert!(try_parse_session_id("not-a-uuid").is_none());
        assert!(try_parse_session_id("").is_none());
        assert!(try_parse_session_id("00000000-0000-0000-0000-000000000000").is_some());
    }

    #[test]
    fn test_convert_errors_and_control_messages() {
        let with_id = convert_realtime_to_sse(RealtimeMessage::ErrorWithId {
            id: "sub-1".to_string(),
            code: "C".to_string(),
            message: "M".to_string(),
        });
        match with_id {
            Some(SseMessage::Error {
                target,
                code,
                message,
            }) => {
                assert_eq!(target, "sub-1");
                assert_eq!(code, "C");
                assert_eq!(message, "M");
            }
            other => panic!("unexpected conversion: {:?}", other),
        }

        let without_id = convert_realtime_to_sse(RealtimeMessage::Error {
            code: "C".to_string(),
            message: "M".to_string(),
        });
        assert!(matches!(
            &without_id,
            Some(SseMessage::Error { target, .. }) if target.is_empty()
        ));

        assert!(convert_realtime_to_sse(RealtimeMessage::Ping).is_none());
        assert!(convert_realtime_to_sse(RealtimeMessage::Pong).is_none());
        assert!(convert_realtime_to_sse(RealtimeMessage::AuthSuccess).is_none());
    }
}