1use crate::error::Result;
6use crate::types::event::Event;
7use backon::{BackoffBuilder, ExponentialBuilder};
8use futures::StreamExt;
9use reqwest::Client as ReqClient;
10use reqwest_eventsource::{Event as EsEvent, EventSource};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::sync::RwLock as StdRwLock;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::time::Duration;
16use tokio::sync::{RwLock, broadcast, mpsc};
17use tokio_util::sync::CancellationToken;
18
19fn extract_session_id_from_raw_event(raw: &str) -> Option<String> {
20 let value: serde_json::Value = serde_json::from_str(raw).ok()?;
21 let event_type = value.get("type")?.as_str()?;
22 let properties = value.get("properties")?;
23
24 match event_type {
25 "message.part.updated" => properties
27 .get("part")
28 .and_then(|p| p.get("sessionID").or_else(|| p.get("sessionId")))
29 .and_then(|v| v.as_str())
30 .map(ToOwned::to_owned),
31 "session.idle" | "session.error" => properties
32 .get("sessionID")
33 .or_else(|| properties.get("sessionId"))
34 .and_then(|v| v.as_str())
35 .map(ToOwned::to_owned),
36 _ => None,
37 }
38}
39
40fn should_forward_event(session_filter: Option<&str>, raw: &str, ev: &Event) -> bool {
41 match session_filter {
42 None => true,
43 Some(expected_session_id) => extract_session_id_for_routing(raw, ev)
44 .map(|actual_session_id| actual_session_id == expected_session_id)
45 .unwrap_or(false),
46 }
47}
48
49fn extract_session_id_for_routing(raw: &str, ev: &Event) -> Option<String> {
50 if matches!(
51 ev,
52 Event::MessagePartUpdated { .. } | Event::SessionIdle { .. } | Event::SessionError { .. }
53 ) {
54 return extract_session_id_from_raw_event(raw);
55 }
56
57 ev.session_id().map(ToOwned::to_owned)
58}
59
/// Tuning knobs for one SSE subscription: buffer size and reconnect backoff bounds.
#[derive(Debug, Clone, Copy)]
pub struct SseOptions {
    /// Capacity of the channel buffering events between the background task
    /// and the subscriber.
    pub capacity: usize,
    /// Smallest delay used between reconnect attempts.
    pub initial_interval: Duration,
    /// Upper bound on the delay between reconnect attempts.
    pub max_interval: Duration,
}

impl Default for SseOptions {
    /// A 256-slot buffer, reconnecting after 250 ms and backing off to at most 30 s.
    fn default() -> Self {
        let capacity = 256;
        let initial_interval = Duration::from_millis(250);
        let max_interval = Duration::from_secs(30);
        Self {
            capacity,
            initial_interval,
            max_interval,
        }
    }
}
80
/// Point-in-time counters describing one SSE stream.
///
/// Values are a snapshot taken by a `stats()` accessor; later stream activity
/// does not update an already-returned value.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct SseStreamStats {
    /// Events received by the background task.
    pub events_in: u64,
    /// Events successfully handed to the subscriber's channel.
    pub events_out: u64,
    /// Events received but not delivered (filtered out, unparseable, lost to
    /// lag, or undeliverable because the receiver was gone).
    pub dropped: u64,
    /// Messages whose payload failed to deserialize.
    pub parse_errors: u64,
    /// Reconnect delays scheduled after a failed or ended connection.
    pub reconnects: u64,
    /// Most recent non-empty SSE `id` known to the stream, if any.
    pub last_event_id: Option<String>,
}
97
/// Mutable counters shared (behind an `Arc`) between a background stream task
/// and its subscription handle.
#[derive(Default, Debug)]
struct SharedSseStreamStats {
    events_in: AtomicU64,
    events_out: AtomicU64,
    dropped: AtomicU64,
    parse_errors: AtomicU64,
    reconnects: AtomicU64,
    // A std (blocking) lock suffices: it is only held for a clone or a single
    // assignment, never across an `.await`.
    last_event_id: StdRwLock<Option<String>>,
}
107
108impl SharedSseStreamStats {
109 fn snapshot(&self) -> SseStreamStats {
110 SseStreamStats {
111 events_in: self.events_in.load(Ordering::Relaxed),
112 events_out: self.events_out.load(Ordering::Relaxed),
113 dropped: self.dropped.load(Ordering::Relaxed),
114 parse_errors: self.parse_errors.load(Ordering::Relaxed),
115 reconnects: self.reconnects.load(Ordering::Relaxed),
116 last_event_id: self
117 .last_event_id
118 .read()
119 .ok()
120 .and_then(|value| value.clone()),
121 }
122 }
123
124 fn set_last_event_id(&self, id: Option<String>) {
125 if let Ok(mut guard) = self.last_event_id.write() {
126 *guard = id;
127 }
128 }
129}
130
131pub struct SseSubscription {
135 rx: mpsc::Receiver<Event>,
136 stats: Arc<SharedSseStreamStats>,
137 cancel: CancellationToken,
138 _task: tokio::task::JoinHandle<()>,
139}
140
/// One SSE message exactly as received, without JSON decoding.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RawSseEvent {
    /// The message's `id:` field (empty when the server sent none).
    pub id: String,
    /// The message's `event:` field.
    pub event: String,
    /// The message's `data:` payload.
    pub data: String,
}
151
152pub struct RawSseSubscription {
156 rx: mpsc::Receiver<RawSseEvent>,
157 stats: Arc<SharedSseStreamStats>,
158 cancel: CancellationToken,
159 _task: tokio::task::JoinHandle<()>,
160}
161
162#[derive(Clone, Copy, Debug)]
164pub struct SessionEventRouterOptions {
165 pub upstream: SseOptions,
167 pub session_capacity: usize,
169 pub subscriber_capacity: usize,
171}
172
173impl Default for SessionEventRouterOptions {
174 fn default() -> Self {
175 Self {
176 upstream: SseOptions::default(),
177 session_capacity: 256,
178 subscriber_capacity: 256,
179 }
180 }
181}
182
183#[derive(Debug)]
184struct SessionEventRouterInner {
185 per_session_channels: Arc<RwLock<HashMap<String, broadcast::Sender<Event>>>>,
186 session_capacity: usize,
187 subscriber_capacity: usize,
188 upstream_stats: Arc<SharedSseStreamStats>,
189 cancel: CancellationToken,
190 _task: tokio::task::JoinHandle<()>,
191}
192
193#[derive(Clone, Debug)]
195pub struct SessionEventRouter {
196 inner: Arc<SessionEventRouterInner>,
197}
198
199impl SessionEventRouter {
200 pub async fn subscribe(&self, session_id: &str) -> SseSubscription {
202 let sender = {
203 let mut channels = self.inner.per_session_channels.write().await;
204 channels
205 .entry(session_id.to_string())
206 .or_insert_with(|| {
207 let (tx, _rx) = broadcast::channel(self.inner.session_capacity);
208 tx
209 })
210 .clone()
211 };
212
213 let mut session_rx = sender.subscribe();
214 let (tx, rx) = mpsc::channel(self.inner.subscriber_capacity);
215 let stats = Arc::new(SharedSseStreamStats::default());
216 let cancel = CancellationToken::new();
217 let cancel_clone = cancel.clone();
218 let stats_task = Arc::clone(&stats);
219
220 let task = tokio::spawn(async move {
221 loop {
222 tokio::select! {
223 () = cancel_clone.cancelled() => {
224 return;
225 }
226 recv = session_rx.recv() => {
227 match recv {
228 Ok(ev) => {
229 stats_task.events_in.fetch_add(1, Ordering::Relaxed);
230 if tx.send(ev).await.is_err() {
231 stats_task.dropped.fetch_add(1, Ordering::Relaxed);
232 return;
233 }
234 stats_task.events_out.fetch_add(1, Ordering::Relaxed);
235 }
236 Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
237 stats_task.dropped.fetch_add(skipped, Ordering::Relaxed);
238 tracing::warn!(
239 "SessionEventRouter subscription lagged by {} event(s)",
240 skipped
241 );
242 }
243 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
244 return;
245 }
246 }
247 }
248 }
249 }
250 });
251
252 SseSubscription {
253 rx,
254 stats,
255 cancel,
256 _task: task,
257 }
258 }
259
260 pub fn stats(&self) -> SseStreamStats {
262 self.inner.upstream_stats.snapshot()
263 }
264
265 pub fn close(&self) {
267 self.inner.cancel.cancel();
268 }
269}
270
271impl Drop for SessionEventRouter {
272 fn drop(&mut self) {
273 if Arc::strong_count(&self.inner) == 1 {
274 self.inner.cancel.cancel();
275 }
276 }
277}
278
279impl RawSseSubscription {
280 pub async fn recv(&mut self) -> Option<RawSseEvent> {
284 self.rx.recv().await
285 }
286
287 pub fn stats(&self) -> SseStreamStats {
289 self.stats.snapshot()
290 }
291
292 pub fn close(&self) {
294 self.cancel.cancel();
295 }
296}
297
298impl Drop for RawSseSubscription {
299 fn drop(&mut self) {
300 self.cancel.cancel();
301 }
302}
303
304impl SseSubscription {
305 pub async fn recv(&mut self) -> Option<Event> {
309 self.rx.recv().await
310 }
311
312 pub fn stats(&self) -> SseStreamStats {
314 self.stats.snapshot()
315 }
316
317 pub fn close(&self) {
319 self.cancel.cancel();
320 }
321}
322
323impl Drop for SseSubscription {
324 fn drop(&mut self) {
325 self.cancel.cancel();
326 }
327}
328
329#[derive(Clone)]
331pub struct SseSubscriber {
332 http: ReqClient,
333 base_url: String,
334 directory: Option<String>,
335 last_event_id: Arc<RwLock<Option<String>>>,
336}
337
338impl SseSubscriber {
339 pub fn new(
343 base_url: String,
344 directory: Option<String>,
345 last_event_id: Arc<RwLock<Option<String>>>,
346 ) -> Self {
347 Self {
348 http: ReqClient::new(),
349 base_url,
350 directory,
351 last_event_id,
352 }
353 }
354
355 pub async fn subscribe_session(
365 &self,
366 session_id: &str,
367 opts: SseOptions,
368 ) -> Result<SseSubscription> {
369 let url = format!("{}/event", self.base_url);
370 self.subscribe_filtered(url, Some(session_id.to_string()), opts)
371 .await
372 }
373
374 pub async fn subscribe(&self, opts: SseOptions) -> Result<SseSubscription> {
383 self.subscribe_typed(opts).await
384 }
385
386 pub async fn subscribe_typed(&self, opts: SseOptions) -> Result<SseSubscription> {
391 let url = format!("{}/event", self.base_url);
392 self.subscribe_filtered(url, None, opts).await
393 }
394
395 pub async fn subscribe_global(&self, opts: SseOptions) -> Result<SseSubscription> {
405 self.subscribe_typed_global(opts).await
406 }
407
408 pub async fn subscribe_typed_global(&self, opts: SseOptions) -> Result<SseSubscription> {
410 let url = format!("{}/global/event", self.base_url);
411 self.subscribe_filtered(url, None, opts).await
412 }
413
414 pub async fn subscribe_raw(&self, opts: SseOptions) -> Result<RawSseSubscription> {
418 let url = format!("{}/event", self.base_url);
419 self.subscribe_raw_inner(url, opts).await
420 }
421
422 pub async fn session_event_router(
424 &self,
425 opts: SessionEventRouterOptions,
426 ) -> Result<SessionEventRouter> {
427 let mut upstream = self.subscribe_raw(opts.upstream).await?;
428 let upstream_stats = Arc::clone(&upstream.stats);
429 let endpoint = format!("{}/event", self.base_url);
430 let directory = self.directory.clone();
431 let channels = Arc::new(RwLock::new(
432 HashMap::<String, broadcast::Sender<Event>>::new(),
433 ));
434 let channels_task = channels.clone();
435 let cancel = CancellationToken::new();
436 let cancel_clone = cancel.clone();
437
438 let task = tokio::spawn(async move {
439 loop {
440 tokio::select! {
441 () = cancel_clone.cancelled() => {
442 upstream.close();
443 return;
444 }
445 maybe_raw = upstream.recv() => {
446 let Some(raw) = maybe_raw else {
447 return;
448 };
449
450 let event = match serde_json::from_str::<Event>(&raw.data) {
451 Ok(ev) => ev,
452 Err(e) => {
453 tracing::warn!(
454 "SessionEventRouter failed to parse raw event endpoint={} directory={:?} last_event_id={}: {} - Raw data: {}",
455 endpoint,
456 directory,
457 raw.id,
458 e,
459 raw.data
460 );
461 continue;
462 }
463 };
464
465 let Some(session_id) = extract_session_id_for_routing(&raw.data, &event) else {
466 continue;
467 };
468
469 let sender = channels_task.read().await.get(&session_id).cloned();
470 if let Some(sender) = sender {
471 if sender.receiver_count() == 0 {
472 channels_task.write().await.remove(&session_id);
473 continue;
474 }
475
476 let _ = sender.send(event);
477 }
478 }
479 }
480 }
481 });
482
483 Ok(SessionEventRouter {
484 inner: Arc::new(SessionEventRouterInner {
485 per_session_channels: channels,
486 session_capacity: opts.session_capacity,
487 subscriber_capacity: opts.subscriber_capacity,
488 upstream_stats,
489 cancel,
490 _task: task,
491 }),
492 })
493 }
494
495 async fn subscribe_filtered(
496 &self,
497 url: String,
498 session_filter: Option<String>,
499 opts: SseOptions,
500 ) -> Result<SseSubscription> {
501 let (tx, rx) = mpsc::channel(opts.capacity);
502 let stats = Arc::new(SharedSseStreamStats::default());
503 let cancel = CancellationToken::new();
504 let cancel_clone = cancel.clone();
505
506 let http = self.http.clone();
507 let dir = self.directory.clone();
508 let lei = self.last_event_id.clone();
509 let initial = opts.initial_interval;
510 let max = opts.max_interval;
511 let endpoint = url.clone();
512 let stats_task = Arc::clone(&stats);
513
514 stats.set_last_event_id(lei.read().await.clone());
515 let filter = session_filter;
516
517 let task = tokio::spawn(async move {
518 let backoff_builder = ExponentialBuilder::default()
522 .with_min_delay(initial)
523 .with_max_delay(max)
524 .with_factor(2.0)
525 .with_jitter();
526
527 let mut backoff = backoff_builder.build();
528
529 loop {
530 if cancel_clone.is_cancelled() {
531 break;
532 }
533
534 let mut req = http.get(&url);
535 if let Some(d) = &dir {
536 req = req.header("x-opencode-directory", d);
537 }
538 if let Some(id) = lei.read().await.clone() {
539 req = req.header("Last-Event-ID", id);
540 }
541
542 let es_result = EventSource::new(req);
543 let mut es = match es_result {
544 Ok(es) => es,
545 Err(e) => {
546 tracing::warn!(
547 "Failed to create EventSource endpoint={} directory={:?} session_filter={:?}: {:?}",
548 endpoint,
549 dir,
550 filter,
551 e
552 );
553 if let Some(delay) = backoff.next() {
554 stats_task.reconnects.fetch_add(1, Ordering::Relaxed);
555 tokio::select! {
556 () = tokio::time::sleep(delay) => {}
557 () = cancel_clone.cancelled() => { return; }
558 }
559 }
560 continue;
561 }
562 };
563
564 while let Some(event) = es.next().await {
565 if cancel_clone.is_cancelled() {
566 es.close();
567 return;
568 }
569
570 match event {
571 Ok(EsEvent::Open) => {
572 backoff = backoff_builder.build();
574 tracing::debug!(
575 "SSE connection opened endpoint={} directory={:?} session_filter={:?}",
576 endpoint,
577 dir,
578 filter
579 );
580 }
581 Ok(EsEvent::Message(msg)) => {
582 stats_task.events_in.fetch_add(1, Ordering::Relaxed);
583 if !msg.id.is_empty() {
585 *lei.write().await = Some(msg.id.clone());
586 stats_task.set_last_event_id(Some(msg.id.clone()));
587 }
588
589 match serde_json::from_str::<Event>(&msg.data) {
591 Ok(ev) => {
592 tracing::debug!(
593 "Parsed SSE event endpoint={} directory={:?} session_filter={:?}: {:?}",
594 endpoint,
595 dir,
596 filter,
597 ev
598 );
599 let should_send =
601 should_forward_event(filter.as_deref(), &msg.data, &ev);
602
603 if should_send {
604 if tx.send(ev).await.is_err() {
605 stats_task.dropped.fetch_add(1, Ordering::Relaxed);
606 es.close();
607 return;
608 }
609 stats_task.events_out.fetch_add(1, Ordering::Relaxed);
610 } else {
611 stats_task.dropped.fetch_add(1, Ordering::Relaxed);
612 }
613 }
614 Err(e) => {
615 stats_task.parse_errors.fetch_add(1, Ordering::Relaxed);
616 stats_task.dropped.fetch_add(1, Ordering::Relaxed);
617 tracing::warn!(
618 "Failed to parse SSE event endpoint={} directory={:?} session_filter={:?}: {} - Raw data: {}",
619 endpoint,
620 dir,
621 filter,
622 e,
623 msg.data
624 );
625 }
626 }
627 }
628 Err(e) => {
629 tracing::warn!(
630 "SSE error endpoint={} directory={:?} session_filter={:?}: {:?}",
631 endpoint,
632 dir,
633 filter,
634 e
635 );
636 es.close();
637 break; }
639 }
640 }
641
642 if let Some(delay) = backoff.next() {
644 stats_task.reconnects.fetch_add(1, Ordering::Relaxed);
645 tracing::debug!(
646 "SSE reconnecting endpoint={} directory={:?} session_filter={:?} after {:?}",
647 endpoint,
648 dir,
649 filter,
650 delay
651 );
652 tokio::select! {
653 () = tokio::time::sleep(delay) => {}
654 () = cancel_clone.cancelled() => { return; }
655 }
656 }
657 }
658 });
659
660 Ok(SseSubscription {
661 rx,
662 stats,
663 cancel,
664 _task: task,
665 })
666 }
667
668 async fn subscribe_raw_inner(
669 &self,
670 url: String,
671 opts: SseOptions,
672 ) -> Result<RawSseSubscription> {
673 let (tx, rx) = mpsc::channel(opts.capacity);
674 let stats = Arc::new(SharedSseStreamStats::default());
675 let cancel = CancellationToken::new();
676 let cancel_clone = cancel.clone();
677
678 let http = self.http.clone();
679 let dir = self.directory.clone();
680 let lei = self.last_event_id.clone();
681 let initial = opts.initial_interval;
682 let max = opts.max_interval;
683 let endpoint = url.clone();
684 let stats_task = Arc::clone(&stats);
685
686 stats.set_last_event_id(lei.read().await.clone());
687
688 let task = tokio::spawn(async move {
689 let backoff_builder = ExponentialBuilder::default()
690 .with_min_delay(initial)
691 .with_max_delay(max)
692 .with_factor(2.0)
693 .with_jitter();
694
695 let mut backoff = backoff_builder.build();
696
697 loop {
698 if cancel_clone.is_cancelled() {
699 break;
700 }
701
702 let mut req = http.get(&url);
703 if let Some(d) = &dir {
704 req = req.header("x-opencode-directory", d);
705 }
706 if let Some(id) = lei.read().await.clone() {
707 req = req.header("Last-Event-ID", id);
708 }
709
710 let es_result = EventSource::new(req);
711 let mut es = match es_result {
712 Ok(es) => es,
713 Err(e) => {
714 tracing::warn!(
715 "Failed to create raw EventSource endpoint={} directory={:?}: {:?}",
716 endpoint,
717 dir,
718 e
719 );
720 if let Some(delay) = backoff.next() {
721 stats_task.reconnects.fetch_add(1, Ordering::Relaxed);
722 tokio::select! {
723 () = tokio::time::sleep(delay) => {}
724 () = cancel_clone.cancelled() => { return; }
725 }
726 }
727 continue;
728 }
729 };
730
731 while let Some(event) = es.next().await {
732 if cancel_clone.is_cancelled() {
733 es.close();
734 return;
735 }
736
737 match event {
738 Ok(EsEvent::Open) => {
739 backoff = backoff_builder.build();
740 tracing::debug!(
741 "SSE raw connection opened endpoint={} directory={:?}",
742 endpoint,
743 dir
744 );
745 }
746 Ok(EsEvent::Message(msg)) => {
747 stats_task.events_in.fetch_add(1, Ordering::Relaxed);
748 if !msg.id.is_empty() {
749 *lei.write().await = Some(msg.id.clone());
750 stats_task.set_last_event_id(Some(msg.id.clone()));
751 }
752
753 let raw = RawSseEvent {
754 id: msg.id,
755 event: msg.event,
756 data: msg.data,
757 };
758
759 if tx.send(raw).await.is_err() {
760 stats_task.dropped.fetch_add(1, Ordering::Relaxed);
761 es.close();
762 return;
763 }
764 stats_task.events_out.fetch_add(1, Ordering::Relaxed);
765 }
766 Err(e) => {
767 tracing::warn!(
768 "SSE raw error endpoint={} directory={:?}: {:?}",
769 endpoint,
770 dir,
771 e
772 );
773 es.close();
774 break;
775 }
776 }
777 }
778
779 if let Some(delay) = backoff.next() {
780 stats_task.reconnects.fetch_add(1, Ordering::Relaxed);
781 tracing::debug!(
782 "SSE raw reconnecting endpoint={} directory={:?} after {:?}",
783 endpoint,
784 dir,
785 delay
786 );
787 tokio::select! {
788 () = tokio::time::sleep(delay) => {}
789 () = cancel_clone.cancelled() => { return; }
790 }
791 }
792 }
793 });
794
795 Ok(RawSseSubscription {
796 rx,
797 stats,
798 cancel,
799 _task: task,
800 })
801 }
802}
803
804#[cfg(test)]
805mod tests {
806 use super::*;
810
811 #[test]
812 fn test_sse_options_defaults() {
813 let opts = SseOptions::default();
814 assert_eq!(opts.capacity, 256);
815 assert_eq!(opts.initial_interval, Duration::from_millis(250));
816 assert_eq!(opts.max_interval, Duration::from_secs(30));
817 }
818
819 #[tokio::test]
820 async fn test_subscription_cancel_on_close() {
821 let subscriber = SseSubscriber::new(
822 "http://localhost:9999".to_string(),
823 None,
824 Arc::new(RwLock::new(None)),
825 );
826
827 let opts = SseOptions {
829 capacity: 1,
830 initial_interval: Duration::from_millis(10),
831 max_interval: Duration::from_millis(50),
832 };
833
834 let subscription = subscriber.subscribe_global(opts).await.unwrap();
835 assert_eq!(subscription.stats().events_in, 0);
836 subscription.close();
837 assert!(subscription.cancel.is_cancelled());
839 }
840
841 #[test]
842 fn test_extract_session_id_from_raw_event_accepts_session_id_variants() {
843 let message_part_with_pascal =
844 r#"{"type":"message.part.updated","properties":{"part":{"sessionID":"sess-a"}}}"#;
845 assert_eq!(
846 extract_session_id_from_raw_event(message_part_with_pascal),
847 Some("sess-a".to_string())
848 );
849
850 let message_part_with_camel =
851 r#"{"type":"message.part.updated","properties":{"part":{"sessionId":"sess-b"}}}"#;
852 assert_eq!(
853 extract_session_id_from_raw_event(message_part_with_camel),
854 Some("sess-b".to_string())
855 );
856
857 let session_idle_with_camel =
858 r#"{"type":"session.idle","properties":{"sessionId":"sess-c"}}"#;
859 assert_eq!(
860 extract_session_id_from_raw_event(session_idle_with_camel),
861 Some("sess-c".to_string())
862 );
863 }
864
865 #[test]
866 fn test_should_forward_event_drops_events_without_session_id_when_filtered() {
867 let unknown_json = r#"{"type":"server.connected","properties":{}}"#;
868 let event: Event = serde_json::from_str(unknown_json).unwrap();
869
870 assert!(should_forward_event(None, unknown_json, &event));
871 assert!(!should_forward_event(
872 Some("sess-123"),
873 unknown_json,
874 &event
875 ));
876 }
877
878 #[test]
879 fn test_should_forward_event_for_message_part_uses_raw_js_parity_fields() {
880 let json =
884 r#"{"type":"message.part.updated","properties":{"sessionId":"sess-top","delta":"hi"}}"#;
885 let event: Event = serde_json::from_str(json).unwrap();
886
887 assert!(!should_forward_event(Some("sess-top"), json, &event));
888 }
889
890 #[test]
891 fn test_extract_session_id_for_routing_prefers_raw_parity_fields() {
892 let json = r#"{"type":"message.part.updated","properties":{"part":{"type":"text","text":"","sessionID":"sess-nested"},"sessionId":"sess-top"}}"#;
893 let event: Event = serde_json::from_str(json).unwrap();
894
895 assert_eq!(
896 extract_session_id_for_routing(json, &event),
897 Some("sess-nested".to_string())
898 );
899 }
900
901 #[test]
902 fn test_extract_session_id_for_routing_falls_back_to_typed_fields() {
903 let json = r#"{"type":"message.updated","properties":{"info":{"id":"m1","sessionId":"sess-typed","role":"assistant","time":{"created":1}}}}"#;
904 let event: Event = serde_json::from_str(json).unwrap();
905
906 assert_eq!(
907 extract_session_id_for_routing(json, &event),
908 Some("sess-typed".to_string())
909 );
910 }
911
912 #[test]
913 fn test_concurrent_session_filtering_no_delta_cross_contamination() {
914 let raw_events = [
915 r#"{"type":"message.part.updated","properties":{"part":{"type":"text","text":"","sessionID":"sess-a"},"delta":"alpha"}}"#,
916 r#"{"type":"message.part.updated","properties":{"part":{"type":"text","text":"","sessionID":"sess-b"},"delta":"bravo"}}"#,
917 r#"{"type":"server.heartbeat","properties":{}}"#,
918 r#"{"type":"message.part.updated","properties":{"part":{"type":"text","text":"","sessionID":"sess-a"},"delta":"-2"}}"#,
919 r#"{"type":"message.part.updated","properties":{"part":{"type":"text","text":"","sessionID":"sess-b"},"delta":"-2"}}"#,
920 ];
921
922 let mut a = String::new();
923 let mut b = String::new();
924
925 for raw in raw_events {
926 let ev: Event = serde_json::from_str(raw).unwrap();
927
928 if should_forward_event(Some("sess-a"), raw, &ev)
929 && let Event::MessagePartUpdated { properties } = &ev
930 && let Some(delta) = &properties.delta
931 {
932 a.push_str(delta);
933 }
934
935 if should_forward_event(Some("sess-b"), raw, &ev)
936 && let Event::MessagePartUpdated { properties } = &ev
937 && let Some(delta) = &properties.delta
938 {
939 b.push_str(delta);
940 }
941 }
942
943 assert_eq!(a, "alpha-2");
944 assert_eq!(b, "bravo-2");
945 }
946
947 #[tokio::test]
948 async fn test_subscribe_raw_yields_payloads() {
949 use std::io::{Read, Write};
950 use std::net::TcpListener;
951
952 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
953 let addr = listener.local_addr().unwrap();
954
955 let server = std::thread::spawn(move || {
956 let (mut stream, _) = listener.accept().unwrap();
957
958 let mut buf = [0_u8; 1024];
959 let _ = stream.read(&mut buf);
960
961 let body = concat!(
962 "id: 1\n",
963 "event: message\n",
964 "data: {\"type\":\"server.connected\",\"properties\":{}}\n",
965 "\n",
966 "id: 2\n",
967 "event: message\n",
968 "data: {\"type\":\"server.heartbeat\",\"properties\":{}}\n",
969 "\n"
970 );
971
972 let response = format!(
973 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncache-control: no-cache\r\nconnection: close\r\ncontent-length: {}\r\n\r\n{}",
974 body.len(),
975 body
976 );
977
978 let _ = stream.write_all(response.as_bytes());
979 let _ = stream.flush();
980 });
981
982 let subscriber = SseSubscriber::new(
983 format!("http://{}", addr),
984 None,
985 Arc::new(RwLock::new(None)),
986 );
987
988 let mut sub = subscriber
989 .subscribe_raw(SseOptions {
990 capacity: 8,
991 initial_interval: Duration::from_millis(10),
992 max_interval: Duration::from_millis(20),
993 })
994 .await
995 .unwrap();
996
997 let first = tokio::time::timeout(Duration::from_secs(2), sub.recv())
998 .await
999 .unwrap()
1000 .unwrap();
1001 assert_eq!(first.id, "1");
1002 assert_eq!(first.event, "message");
1003 assert!(first.data.contains("server.connected"));
1004
1005 let second = tokio::time::timeout(Duration::from_secs(2), sub.recv())
1006 .await
1007 .unwrap()
1008 .unwrap();
1009 assert_eq!(second.id, "2");
1010 assert!(second.data.contains("server.heartbeat"));
1011
1012 let stats = sub.stats();
1013 assert_eq!(stats.events_in, 2);
1014 assert_eq!(stats.events_out, 2);
1015 assert_eq!(stats.dropped, 0);
1016 assert_eq!(stats.parse_errors, 0);
1017 assert_eq!(stats.last_event_id.as_deref(), Some("2"));
1018
1019 sub.close();
1020 let _ = server.join();
1021 }
1022
1023 #[tokio::test]
1024 async fn test_subscribe_typed_tracks_parse_errors_and_drops() {
1025 use std::io::{Read, Write};
1026 use std::net::TcpListener;
1027
1028 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1029 let addr = listener.local_addr().unwrap();
1030
1031 let server = std::thread::spawn(move || {
1032 let (mut stream, _) = listener.accept().unwrap();
1033
1034 let mut buf = [0_u8; 1024];
1035 let _ = stream.read(&mut buf);
1036
1037 let body = concat!(
1038 "id: 1\n",
1039 "event: message\n",
1040 "data: {\"type\":\"server.connected\",\"properties\":{}}\n",
1041 "\n",
1042 "id: 2\n",
1043 "event: message\n",
1044 "data: not-json\n",
1045 "\n",
1046 "id: 3\n",
1047 "event: message\n",
1048 "data: {\"type\":\"server.heartbeat\",\"properties\":{}}\n",
1049 "\n"
1050 );
1051
1052 let response = format!(
1053 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncache-control: no-cache\r\nconnection: close\r\ncontent-length: {}\r\n\r\n{}",
1054 body.len(),
1055 body
1056 );
1057
1058 let _ = stream.write_all(response.as_bytes());
1059 let _ = stream.flush();
1060 });
1061
1062 let subscriber = SseSubscriber::new(
1063 format!("http://{}", addr),
1064 None,
1065 Arc::new(RwLock::new(None)),
1066 );
1067
1068 let mut sub = subscriber
1069 .subscribe_typed(SseOptions {
1070 capacity: 8,
1071 initial_interval: Duration::from_millis(10),
1072 max_interval: Duration::from_millis(20),
1073 })
1074 .await
1075 .unwrap();
1076
1077 let first = tokio::time::timeout(Duration::from_secs(2), sub.recv())
1078 .await
1079 .unwrap()
1080 .unwrap();
1081 assert!(matches!(first, Event::ServerConnected { .. }));
1082
1083 let second = tokio::time::timeout(Duration::from_secs(2), sub.recv())
1084 .await
1085 .unwrap()
1086 .unwrap();
1087 assert!(matches!(second, Event::ServerHeartbeat { .. }));
1088
1089 tokio::time::sleep(Duration::from_millis(50)).await;
1090
1091 let stats = sub.stats();
1092 assert_eq!(stats.events_in, 3);
1093 assert_eq!(stats.events_out, 2);
1094 assert_eq!(stats.dropped, 1);
1095 assert_eq!(stats.parse_errors, 1);
1096 assert_eq!(stats.last_event_id.as_deref(), Some("3"));
1097
1098 sub.close();
1099 let _ = server.join();
1100 }
1101
1102 #[tokio::test]
1103 async fn test_session_event_router_exposes_upstream_stats() {
1104 use std::io::{Read, Write};
1105 use std::net::TcpListener;
1106
1107 let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1108 let addr = listener.local_addr().unwrap();
1109
1110 let server = std::thread::spawn(move || {
1111 let (mut stream, _) = listener.accept().unwrap();
1112
1113 let mut buf = [0_u8; 1024];
1114 let _ = stream.read(&mut buf);
1115
1116 let body = concat!(
1117 "id: 9\n",
1118 "event: message\n",
1119 "data: {\"type\":\"message.removed\",\"properties\":{\"sessionId\":\"sess-a\",\"messageId\":\"msg-1\"}}\n",
1120 "\n"
1121 );
1122
1123 let response = format!(
1124 "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncache-control: no-cache\r\nconnection: close\r\ncontent-length: {}\r\n\r\n{}",
1125 body.len(),
1126 body
1127 );
1128
1129 let _ = stream.write_all(response.as_bytes());
1130 let _ = stream.flush();
1131 });
1132
1133 let subscriber = SseSubscriber::new(
1134 format!("http://{}", addr),
1135 None,
1136 Arc::new(RwLock::new(None)),
1137 );
1138
1139 let router = subscriber
1140 .session_event_router(SessionEventRouterOptions {
1141 upstream: SseOptions {
1142 capacity: 8,
1143 initial_interval: Duration::from_millis(10),
1144 max_interval: Duration::from_millis(20),
1145 },
1146 session_capacity: 8,
1147 subscriber_capacity: 8,
1148 })
1149 .await
1150 .unwrap();
1151
1152 let mut session_sub = router.subscribe("sess-a").await;
1153 let event = tokio::time::timeout(Duration::from_secs(2), session_sub.recv())
1154 .await
1155 .unwrap()
1156 .unwrap();
1157 assert!(matches!(event, Event::MessageRemoved { .. }));
1158
1159 tokio::time::sleep(Duration::from_millis(50)).await;
1160
1161 let stats = router.stats();
1162 assert_eq!(stats.events_in, 1);
1163 assert_eq!(stats.events_out, 1);
1164 assert_eq!(stats.dropped, 0);
1165 assert_eq!(stats.parse_errors, 0);
1166 assert_eq!(stats.last_event_id.as_deref(), Some("9"));
1167
1168 session_sub.close();
1169 router.close();
1170 let _ = server.join();
1171 }
1172}