Skip to main content

opencode_sdk/
sse.rs

//! SSE (Server-Sent Events) streaming support.
//!
//! This module provides typed, raw, and per-session SSE subscriptions that reconnect
//! automatically with exponential backoff and resume via the `Last-Event-ID` header.
4
5use crate::error::Result;
6use crate::types::event::Event;
7use backon::{BackoffBuilder, ExponentialBuilder};
8use futures::StreamExt;
9use reqwest::Client as ReqClient;
10use reqwest_eventsource::{Event as EsEvent, EventSource};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::sync::RwLock as StdRwLock;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::time::Duration;
16use tokio::sync::{RwLock, broadcast, mpsc};
17use tokio_util::sync::CancellationToken;
18
19fn extract_session_id_from_raw_event(raw: &str) -> Option<String> {
20    let value: serde_json::Value = serde_json::from_str(raw).ok()?;
21    let event_type = value.get("type")?.as_str()?;
22    let properties = value.get("properties")?;
23
24    match event_type {
25        // JS SDK routes token stream events via properties.part.sessionID
26        "message.part.updated" => properties
27            .get("part")
28            .and_then(|p| p.get("sessionID").or_else(|| p.get("sessionId")))
29            .and_then(|v| v.as_str())
30            .map(ToOwned::to_owned),
31        "session.idle" | "session.error" => properties
32            .get("sessionID")
33            .or_else(|| properties.get("sessionId"))
34            .and_then(|v| v.as_str())
35            .map(ToOwned::to_owned),
36        _ => None,
37    }
38}
39
40fn should_forward_event(session_filter: Option<&str>, raw: &str, ev: &Event) -> bool {
41    match session_filter {
42        None => true,
43        Some(expected_session_id) => extract_session_id_for_routing(raw, ev)
44            .map(|actual_session_id| actual_session_id == expected_session_id)
45            .unwrap_or(false),
46    }
47}
48
49fn extract_session_id_for_routing(raw: &str, ev: &Event) -> Option<String> {
50    if matches!(
51        ev,
52        Event::MessagePartUpdated { .. } | Event::SessionIdle { .. } | Event::SessionError { .. }
53    ) {
54        return extract_session_id_from_raw_event(raw);
55    }
56
57    ev.session_id().map(ToOwned::to_owned)
58}
59
/// Tuning knobs for an SSE subscription.
#[derive(Clone, Copy, Debug)]
pub struct SseOptions {
    /// Capacity of the channel between the background stream task and the receiver
    /// (default: 256).
    pub capacity: usize,
    /// Smallest delay between reconnect attempts (default: 250ms).
    pub initial_interval: Duration,
    /// Largest delay between reconnect attempts (default: 30s).
    pub max_interval: Duration,
}

impl Default for SseOptions {
    fn default() -> Self {
        let capacity = 256;
        let initial_interval = Duration::from_millis(250);
        let max_interval = Duration::from_secs(30);
        Self {
            capacity,
            initial_interval,
            max_interval,
        }
    }
}
80
/// Point-in-time copy of an SSE stream's diagnostics counters.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SseStreamStats {
    /// SSE message frames received from the server.
    pub events_in: u64,
    /// Events handed to the subscription receiver.
    pub events_out: u64,
    /// Events discarded instead of being delivered.
    pub dropped: u64,
    /// Payloads that failed to decode as typed SSE events.
    pub parse_errors: u64,
    /// Reconnect attempts made after the stream was interrupted.
    pub reconnects: u64,
    /// Most recent `Last-Event-ID` value seen, if any.
    pub last_event_id: Option<String>,
}

/// Live counters shared between a subscription handle and its background task.
#[derive(Debug, Default)]
struct SharedSseStreamStats {
    events_in: AtomicU64,
    events_out: AtomicU64,
    dropped: AtomicU64,
    parse_errors: AtomicU64,
    reconnects: AtomicU64,
    last_event_id: StdRwLock<Option<String>>,
}

impl SharedSseStreamStats {
    /// Copy the current values into an [`SseStreamStats`].
    ///
    /// Counters are read individually (not as one atomic unit). A poisoned
    /// `last_event_id` lock is reported as `None`.
    fn snapshot(&self) -> SseStreamStats {
        let load = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        let events_in = load(&self.events_in);
        let events_out = load(&self.events_out);
        let dropped = load(&self.dropped);
        let parse_errors = load(&self.parse_errors);
        let reconnects = load(&self.reconnects);
        let last_event_id = match self.last_event_id.read() {
            Ok(guard) => (*guard).clone(),
            Err(_) => None,
        };

        SseStreamStats {
            events_in,
            events_out,
            dropped,
            parse_errors,
            reconnects,
            last_event_id,
        }
    }

    /// Record `id` as the latest `Last-Event-ID`; skipped if the lock is poisoned.
    fn set_last_event_id(&self, id: Option<String>) {
        let Ok(mut slot) = self.last_event_id.write() else {
            return;
        };
        *slot = id;
    }
}
130
131/// Handle to an active SSE subscription.
132///
133/// Dropping this handle will cancel the subscription.
134pub struct SseSubscription {
135    rx: mpsc::Receiver<Event>,
136    stats: Arc<SharedSseStreamStats>,
137    cancel: CancellationToken,
138    _task: tokio::task::JoinHandle<()>,
139}
140
/// One undecoded SSE message frame from the `/event` endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RawSseEvent {
    /// Frame id (tracked as `Last-Event-ID`); empty when the frame carried none.
    pub id: String,
    /// Frame event name.
    pub event: String,
    /// Frame `data:` payload, expected to be JSON.
    pub data: String,
}
151
152/// Handle to an active raw SSE subscription.
153///
154/// Dropping this handle will cancel the subscription.
155pub struct RawSseSubscription {
156    rx: mpsc::Receiver<RawSseEvent>,
157    stats: Arc<SharedSseStreamStats>,
158    cancel: CancellationToken,
159    _task: tokio::task::JoinHandle<()>,
160}
161
162/// Options for [`SessionEventRouter`].
163#[derive(Clone, Copy, Debug)]
164pub struct SessionEventRouterOptions {
165    /// Upstream `/event` raw stream subscription options.
166    pub upstream: SseOptions,
167    /// Per-session fan-out channel capacity (default: 256).
168    pub session_capacity: usize,
169    /// Per-subscriber output channel capacity (default: 256).
170    pub subscriber_capacity: usize,
171}
172
173impl Default for SessionEventRouterOptions {
174    fn default() -> Self {
175        Self {
176            upstream: SseOptions::default(),
177            session_capacity: 256,
178            subscriber_capacity: 256,
179        }
180    }
181}
182
183#[derive(Debug)]
184struct SessionEventRouterInner {
185    per_session_channels: Arc<RwLock<HashMap<String, broadcast::Sender<Event>>>>,
186    session_capacity: usize,
187    subscriber_capacity: usize,
188    upstream_stats: Arc<SharedSseStreamStats>,
189    cancel: CancellationToken,
190    _task: tokio::task::JoinHandle<()>,
191}
192
193/// Multiplexes one upstream `/event` stream into per-session subscriptions.
194#[derive(Clone, Debug)]
195pub struct SessionEventRouter {
196    inner: Arc<SessionEventRouterInner>,
197}
198
199impl SessionEventRouter {
200    /// Subscribe to typed events for a single session ID.
201    pub async fn subscribe(&self, session_id: &str) -> SseSubscription {
202        let sender = {
203            let mut channels = self.inner.per_session_channels.write().await;
204            channels
205                .entry(session_id.to_string())
206                .or_insert_with(|| {
207                    let (tx, _rx) = broadcast::channel(self.inner.session_capacity);
208                    tx
209                })
210                .clone()
211        };
212
213        let mut session_rx = sender.subscribe();
214        let (tx, rx) = mpsc::channel(self.inner.subscriber_capacity);
215        let stats = Arc::new(SharedSseStreamStats::default());
216        let cancel = CancellationToken::new();
217        let cancel_clone = cancel.clone();
218        let stats_task = Arc::clone(&stats);
219
220        let task = tokio::spawn(async move {
221            loop {
222                tokio::select! {
223                    () = cancel_clone.cancelled() => {
224                        return;
225                    }
226                    recv = session_rx.recv() => {
227                        match recv {
228                            Ok(ev) => {
229                                stats_task.events_in.fetch_add(1, Ordering::Relaxed);
230                                if tx.send(ev).await.is_err() {
231                                    stats_task.dropped.fetch_add(1, Ordering::Relaxed);
232                                    return;
233                                }
234                                stats_task.events_out.fetch_add(1, Ordering::Relaxed);
235                            }
236                            Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
237                                stats_task.dropped.fetch_add(skipped, Ordering::Relaxed);
238                                tracing::warn!(
239                                    "SessionEventRouter subscription lagged by {} event(s)",
240                                    skipped
241                                );
242                            }
243                            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
244                                return;
245                            }
246                        }
247                    }
248                }
249            }
250        });
251
252        SseSubscription {
253            rx,
254            stats,
255            cancel,
256            _task: task,
257        }
258    }
259
260    /// Get diagnostics for the upstream `/event` stream used by this router.
261    pub fn stats(&self) -> SseStreamStats {
262        self.inner.upstream_stats.snapshot()
263    }
264
265    /// Stop the router and all upstream activity.
266    pub fn close(&self) {
267        self.inner.cancel.cancel();
268    }
269}
270
271impl Drop for SessionEventRouter {
272    fn drop(&mut self) {
273        if Arc::strong_count(&self.inner) == 1 {
274            self.inner.cancel.cancel();
275        }
276    }
277}
278
279impl RawSseSubscription {
280    /// Receive the next raw SSE message.
281    ///
282    /// Returns `None` if the stream is closed.
283    pub async fn recv(&mut self) -> Option<RawSseEvent> {
284        self.rx.recv().await
285    }
286
287    /// Get a snapshot of stream diagnostics.
288    pub fn stats(&self) -> SseStreamStats {
289        self.stats.snapshot()
290    }
291
292    /// Close the subscription explicitly.
293    pub fn close(&self) {
294        self.cancel.cancel();
295    }
296}
297
298impl Drop for RawSseSubscription {
299    fn drop(&mut self) {
300        self.cancel.cancel();
301    }
302}
303
304impl SseSubscription {
305    /// Receive the next event.
306    ///
307    /// Returns `None` if the stream is closed.
308    pub async fn recv(&mut self) -> Option<Event> {
309        self.rx.recv().await
310    }
311
312    /// Get a snapshot of stream diagnostics.
313    pub fn stats(&self) -> SseStreamStats {
314        self.stats.snapshot()
315    }
316
317    /// Close the subscription explicitly.
318    pub fn close(&self) {
319        self.cancel.cancel();
320    }
321}
322
323impl Drop for SseSubscription {
324    fn drop(&mut self) {
325        self.cancel.cancel();
326    }
327}
328
329/// SSE subscriber for OpenCode events.
330#[derive(Clone)]
331pub struct SseSubscriber {
332    http: ReqClient,
333    base_url: String,
334    directory: Option<String>,
335    last_event_id: Arc<RwLock<Option<String>>>,
336}
337
338impl SseSubscriber {
339    // TODO(3): Accept optional ReqClient to allow connection pool sharing with HttpClient
340
341    /// Create a new SSE subscriber.
342    pub fn new(
343        base_url: String,
344        directory: Option<String>,
345        last_event_id: Arc<RwLock<Option<String>>>,
346    ) -> Self {
347        Self {
348            http: ReqClient::new(),
349            base_url,
350            directory,
351            last_event_id,
352        }
353    }
354
355    /// Subscribe to events, optionally filtered by session ID.
356    ///
357    /// OpenCode's `/event` endpoint streams all events for the configured directory.
358    /// If `session_id` is provided, events will be filtered client-side to only
359    /// include events for that session.
360    ///
361    /// # Errors
362    ///
363    /// Returns an error if the subscription cannot be created.
364    pub async fn subscribe_session(
365        &self,
366        session_id: &str,
367        opts: SseOptions,
368    ) -> Result<SseSubscription> {
369        let url = format!("{}/event", self.base_url);
370        self.subscribe_filtered(url, Some(session_id.to_string()), opts)
371            .await
372    }
373
374    /// Subscribe to all events for the configured directory.
375    ///
376    /// This uses the `/event` endpoint which streams all events for the
377    /// directory specified via the `x-opencode-directory` header.
378    ///
379    /// # Errors
380    ///
381    /// Returns an error if the subscription cannot be created.
382    pub async fn subscribe(&self, opts: SseOptions) -> Result<SseSubscription> {
383        self.subscribe_typed(opts).await
384    }
385
386    /// Subscribe to all events for the configured directory as typed [`Event`] values.
387    ///
388    /// This is equivalent to [`Self::subscribe`], but explicitly named to distinguish
389    /// it from [`Self::subscribe_raw`].
390    pub async fn subscribe_typed(&self, opts: SseOptions) -> Result<SseSubscription> {
391        let url = format!("{}/event", self.base_url);
392        self.subscribe_filtered(url, None, opts).await
393    }
394
395    /// Subscribe to global events (all directories).
396    ///
397    /// This uses the `/global/event` endpoint which streams events from all
398    /// OpenCode instances across all directories. Events are wrapped in a
399    /// `GlobalEventEnvelope` with directory context.
400    ///
401    /// # Errors
402    ///
403    /// Returns an error if the subscription cannot be created.
404    pub async fn subscribe_global(&self, opts: SseOptions) -> Result<SseSubscription> {
405        self.subscribe_typed_global(opts).await
406    }
407
408    /// Subscribe to global events as typed [`Event`] values (all directories).
409    pub async fn subscribe_typed_global(&self, opts: SseOptions) -> Result<SseSubscription> {
410        let url = format!("{}/global/event", self.base_url);
411        self.subscribe_filtered(url, None, opts).await
412    }
413
414    /// Subscribe to raw JSON SSE frames from the configured directory's `/event` stream.
415    ///
416    /// This is intended for debugging and parity verification.
417    pub async fn subscribe_raw(&self, opts: SseOptions) -> Result<RawSseSubscription> {
418        let url = format!("{}/event", self.base_url);
419        self.subscribe_raw_inner(url, opts).await
420    }
421
422    /// Create a session event router with one upstream `/event` subscription.
423    pub async fn session_event_router(
424        &self,
425        opts: SessionEventRouterOptions,
426    ) -> Result<SessionEventRouter> {
427        let mut upstream = self.subscribe_raw(opts.upstream).await?;
428        let upstream_stats = Arc::clone(&upstream.stats);
429        let endpoint = format!("{}/event", self.base_url);
430        let directory = self.directory.clone();
431        let channels = Arc::new(RwLock::new(
432            HashMap::<String, broadcast::Sender<Event>>::new(),
433        ));
434        let channels_task = channels.clone();
435        let cancel = CancellationToken::new();
436        let cancel_clone = cancel.clone();
437
438        let task = tokio::spawn(async move {
439            loop {
440                tokio::select! {
441                    () = cancel_clone.cancelled() => {
442                        upstream.close();
443                        return;
444                    }
445                    maybe_raw = upstream.recv() => {
446                        let Some(raw) = maybe_raw else {
447                            return;
448                        };
449
450                        let event = match serde_json::from_str::<Event>(&raw.data) {
451                            Ok(ev) => ev,
452                            Err(e) => {
453                                tracing::warn!(
454                                    "SessionEventRouter failed to parse raw event endpoint={} directory={:?} last_event_id={}: {} - Raw data: {}",
455                                    endpoint,
456                                    directory,
457                                    raw.id,
458                                    e,
459                                    raw.data
460                                );
461                                continue;
462                            }
463                        };
464
465                        let Some(session_id) = extract_session_id_for_routing(&raw.data, &event) else {
466                            continue;
467                        };
468
469                        let sender = channels_task.read().await.get(&session_id).cloned();
470                        if let Some(sender) = sender {
471                            if sender.receiver_count() == 0 {
472                                channels_task.write().await.remove(&session_id);
473                                continue;
474                            }
475
476                            let _ = sender.send(event);
477                        }
478                    }
479                }
480            }
481        });
482
483        Ok(SessionEventRouter {
484            inner: Arc::new(SessionEventRouterInner {
485                per_session_channels: channels,
486                session_capacity: opts.session_capacity,
487                subscriber_capacity: opts.subscriber_capacity,
488                upstream_stats,
489                cancel,
490                _task: task,
491            }),
492        })
493    }
494
495    async fn subscribe_filtered(
496        &self,
497        url: String,
498        session_filter: Option<String>,
499        opts: SseOptions,
500    ) -> Result<SseSubscription> {
501        let (tx, rx) = mpsc::channel(opts.capacity);
502        let stats = Arc::new(SharedSseStreamStats::default());
503        let cancel = CancellationToken::new();
504        let cancel_clone = cancel.clone();
505
506        let http = self.http.clone();
507        let dir = self.directory.clone();
508        let lei = self.last_event_id.clone();
509        let initial = opts.initial_interval;
510        let max = opts.max_interval;
511        let endpoint = url.clone();
512        let stats_task = Arc::clone(&stats);
513
514        stats.set_last_event_id(lei.read().await.clone());
515        let filter = session_filter;
516
517        let task = tokio::spawn(async move {
518            // Note: No max_times means the subscriber will retry indefinitely.
519            // This is intentional for long-lived SSE connections that should reconnect
520            // on any transient network failure.
521            let backoff_builder = ExponentialBuilder::default()
522                .with_min_delay(initial)
523                .with_max_delay(max)
524                .with_factor(2.0)
525                .with_jitter();
526
527            let mut backoff = backoff_builder.build();
528
529            loop {
530                if cancel_clone.is_cancelled() {
531                    break;
532                }
533
534                let mut req = http.get(&url);
535                if let Some(d) = &dir {
536                    req = req.header("x-opencode-directory", d);
537                }
538                if let Some(id) = lei.read().await.clone() {
539                    req = req.header("Last-Event-ID", id);
540                }
541
542                let es_result = EventSource::new(req);
543                let mut es = match es_result {
544                    Ok(es) => es,
545                    Err(e) => {
546                        tracing::warn!(
547                            "Failed to create EventSource endpoint={} directory={:?} session_filter={:?}: {:?}",
548                            endpoint,
549                            dir,
550                            filter,
551                            e
552                        );
553                        if let Some(delay) = backoff.next() {
554                            stats_task.reconnects.fetch_add(1, Ordering::Relaxed);
555                            tokio::select! {
556                                () = tokio::time::sleep(delay) => {}
557                                () = cancel_clone.cancelled() => { return; }
558                            }
559                        }
560                        continue;
561                    }
562                };
563
564                while let Some(event) = es.next().await {
565                    if cancel_clone.is_cancelled() {
566                        es.close();
567                        return;
568                    }
569
570                    match event {
571                        Ok(EsEvent::Open) => {
572                            // Reset backoff on successful connection
573                            backoff = backoff_builder.build();
574                            tracing::debug!(
575                                "SSE connection opened endpoint={} directory={:?} session_filter={:?}",
576                                endpoint,
577                                dir,
578                                filter
579                            );
580                        }
581                        Ok(EsEvent::Message(msg)) => {
582                            stats_task.events_in.fetch_add(1, Ordering::Relaxed);
583                            // Track last event ID
584                            if !msg.id.is_empty() {
585                                *lei.write().await = Some(msg.id.clone());
586                                stats_task.set_last_event_id(Some(msg.id.clone()));
587                            }
588
589                            // Parse event
590                            match serde_json::from_str::<Event>(&msg.data) {
591                                Ok(ev) => {
592                                    tracing::debug!(
593                                        "Parsed SSE event endpoint={} directory={:?} session_filter={:?}: {:?}",
594                                        endpoint,
595                                        dir,
596                                        filter,
597                                        ev
598                                    );
599                                    // Apply session filter if specified
600                                    let should_send =
601                                        should_forward_event(filter.as_deref(), &msg.data, &ev);
602
603                                    if should_send {
604                                        if tx.send(ev).await.is_err() {
605                                            stats_task.dropped.fetch_add(1, Ordering::Relaxed);
606                                            es.close();
607                                            return;
608                                        }
609                                        stats_task.events_out.fetch_add(1, Ordering::Relaxed);
610                                    } else {
611                                        stats_task.dropped.fetch_add(1, Ordering::Relaxed);
612                                    }
613                                }
614                                Err(e) => {
615                                    stats_task.parse_errors.fetch_add(1, Ordering::Relaxed);
616                                    stats_task.dropped.fetch_add(1, Ordering::Relaxed);
617                                    tracing::warn!(
618                                        "Failed to parse SSE event endpoint={} directory={:?} session_filter={:?}: {} - Raw data: {}",
619                                        endpoint,
620                                        dir,
621                                        filter,
622                                        e,
623                                        msg.data
624                                    );
625                                }
626                            }
627                        }
628                        Err(e) => {
629                            tracing::warn!(
630                                "SSE error endpoint={} directory={:?} session_filter={:?}: {:?}",
631                                endpoint,
632                                dir,
633                                filter,
634                                e
635                            );
636                            es.close();
637                            break; // Break inner loop to reconnect
638                        }
639                    }
640                }
641
642                // Apply backoff before reconnecting
643                if let Some(delay) = backoff.next() {
644                    stats_task.reconnects.fetch_add(1, Ordering::Relaxed);
645                    tracing::debug!(
646                        "SSE reconnecting endpoint={} directory={:?} session_filter={:?} after {:?}",
647                        endpoint,
648                        dir,
649                        filter,
650                        delay
651                    );
652                    tokio::select! {
653                        () = tokio::time::sleep(delay) => {}
654                        () = cancel_clone.cancelled() => { return; }
655                    }
656                }
657            }
658        });
659
660        Ok(SseSubscription {
661            rx,
662            stats,
663            cancel,
664            _task: task,
665        })
666    }
667
668    async fn subscribe_raw_inner(
669        &self,
670        url: String,
671        opts: SseOptions,
672    ) -> Result<RawSseSubscription> {
673        let (tx, rx) = mpsc::channel(opts.capacity);
674        let stats = Arc::new(SharedSseStreamStats::default());
675        let cancel = CancellationToken::new();
676        let cancel_clone = cancel.clone();
677
678        let http = self.http.clone();
679        let dir = self.directory.clone();
680        let lei = self.last_event_id.clone();
681        let initial = opts.initial_interval;
682        let max = opts.max_interval;
683        let endpoint = url.clone();
684        let stats_task = Arc::clone(&stats);
685
686        stats.set_last_event_id(lei.read().await.clone());
687
688        let task = tokio::spawn(async move {
689            let backoff_builder = ExponentialBuilder::default()
690                .with_min_delay(initial)
691                .with_max_delay(max)
692                .with_factor(2.0)
693                .with_jitter();
694
695            let mut backoff = backoff_builder.build();
696
697            loop {
698                if cancel_clone.is_cancelled() {
699                    break;
700                }
701
702                let mut req = http.get(&url);
703                if let Some(d) = &dir {
704                    req = req.header("x-opencode-directory", d);
705                }
706                if let Some(id) = lei.read().await.clone() {
707                    req = req.header("Last-Event-ID", id);
708                }
709
710                let es_result = EventSource::new(req);
711                let mut es = match es_result {
712                    Ok(es) => es,
713                    Err(e) => {
714                        tracing::warn!(
715                            "Failed to create raw EventSource endpoint={} directory={:?}: {:?}",
716                            endpoint,
717                            dir,
718                            e
719                        );
720                        if let Some(delay) = backoff.next() {
721                            stats_task.reconnects.fetch_add(1, Ordering::Relaxed);
722                            tokio::select! {
723                                () = tokio::time::sleep(delay) => {}
724                                () = cancel_clone.cancelled() => { return; }
725                            }
726                        }
727                        continue;
728                    }
729                };
730
731                while let Some(event) = es.next().await {
732                    if cancel_clone.is_cancelled() {
733                        es.close();
734                        return;
735                    }
736
737                    match event {
738                        Ok(EsEvent::Open) => {
739                            backoff = backoff_builder.build();
740                            tracing::debug!(
741                                "SSE raw connection opened endpoint={} directory={:?}",
742                                endpoint,
743                                dir
744                            );
745                        }
746                        Ok(EsEvent::Message(msg)) => {
747                            stats_task.events_in.fetch_add(1, Ordering::Relaxed);
748                            if !msg.id.is_empty() {
749                                *lei.write().await = Some(msg.id.clone());
750                                stats_task.set_last_event_id(Some(msg.id.clone()));
751                            }
752
753                            let raw = RawSseEvent {
754                                id: msg.id,
755                                event: msg.event,
756                                data: msg.data,
757                            };
758
759                            if tx.send(raw).await.is_err() {
760                                stats_task.dropped.fetch_add(1, Ordering::Relaxed);
761                                es.close();
762                                return;
763                            }
764                            stats_task.events_out.fetch_add(1, Ordering::Relaxed);
765                        }
766                        Err(e) => {
767                            tracing::warn!(
768                                "SSE raw error endpoint={} directory={:?}: {:?}",
769                                endpoint,
770                                dir,
771                                e
772                            );
773                            es.close();
774                            break;
775                        }
776                    }
777                }
778
779                if let Some(delay) = backoff.next() {
780                    stats_task.reconnects.fetch_add(1, Ordering::Relaxed);
781                    tracing::debug!(
782                        "SSE raw reconnecting endpoint={} directory={:?} after {:?}",
783                        endpoint,
784                        dir,
785                        delay
786                    );
787                    tokio::select! {
788                        () = tokio::time::sleep(delay) => {}
789                        () = cancel_clone.cancelled() => { return; }
790                    }
791                }
792            }
793        });
794
795        Ok(RawSseSubscription {
796            rx,
797            stats,
798            cancel,
799            _task: task,
800        })
801    }
802}
803
804#[cfg(test)]
805mod tests {
806    // TODO(2): Add tests for session filtering logic (lines 216-219), Last-Event-ID
807    // tracking/resume behavior (lines 208-210, 176-178), and backoff timing (with
808    // tokio time mocking).
809    use super::*;
810
811    #[test]
812    fn test_sse_options_defaults() {
813        let opts = SseOptions::default();
814        assert_eq!(opts.capacity, 256);
815        assert_eq!(opts.initial_interval, Duration::from_millis(250));
816        assert_eq!(opts.max_interval, Duration::from_secs(30));
817    }
818
819    #[tokio::test]
820    async fn test_subscription_cancel_on_close() {
821        let subscriber = SseSubscriber::new(
822            "http://localhost:9999".to_string(),
823            None,
824            Arc::new(RwLock::new(None)),
825        );
826
827        // This will fail to connect but we can test cancellation
828        let opts = SseOptions {
829            capacity: 1,
830            initial_interval: Duration::from_millis(10),
831            max_interval: Duration::from_millis(50),
832        };
833
834        let subscription = subscriber.subscribe_global(opts).await.unwrap();
835        assert_eq!(subscription.stats().events_in, 0);
836        subscription.close();
837        // Subscription should be cancelled
838        assert!(subscription.cancel.is_cancelled());
839    }
840
841    #[test]
842    fn test_extract_session_id_from_raw_event_accepts_session_id_variants() {
843        let message_part_with_pascal =
844            r#"{"type":"message.part.updated","properties":{"part":{"sessionID":"sess-a"}}}"#;
845        assert_eq!(
846            extract_session_id_from_raw_event(message_part_with_pascal),
847            Some("sess-a".to_string())
848        );
849
850        let message_part_with_camel =
851            r#"{"type":"message.part.updated","properties":{"part":{"sessionId":"sess-b"}}}"#;
852        assert_eq!(
853            extract_session_id_from_raw_event(message_part_with_camel),
854            Some("sess-b".to_string())
855        );
856
857        let session_idle_with_camel =
858            r#"{"type":"session.idle","properties":{"sessionId":"sess-c"}}"#;
859        assert_eq!(
860            extract_session_id_from_raw_event(session_idle_with_camel),
861            Some("sess-c".to_string())
862        );
863    }
864
865    #[test]
866    fn test_should_forward_event_drops_events_without_session_id_when_filtered() {
867        let unknown_json = r#"{"type":"server.connected","properties":{}}"#;
868        let event: Event = serde_json::from_str(unknown_json).unwrap();
869
870        assert!(should_forward_event(None, unknown_json, &event));
871        assert!(!should_forward_event(
872            Some("sess-123"),
873            unknown_json,
874            &event
875        ));
876    }
877
878    #[test]
879    fn test_should_forward_event_for_message_part_uses_raw_js_parity_fields() {
880        // sessionId is present only at the top-level properties object.
881        // For message.part.updated filtering we intentionally use raw parity
882        // with JS (`properties.part.sessionID|sessionId`) and drop when missing.
883        let json =
884            r#"{"type":"message.part.updated","properties":{"sessionId":"sess-top","delta":"hi"}}"#;
885        let event: Event = serde_json::from_str(json).unwrap();
886
887        assert!(!should_forward_event(Some("sess-top"), json, &event));
888    }
889
890    #[test]
891    fn test_extract_session_id_for_routing_prefers_raw_parity_fields() {
892        let json = r#"{"type":"message.part.updated","properties":{"part":{"type":"text","text":"","sessionID":"sess-nested"},"sessionId":"sess-top"}}"#;
893        let event: Event = serde_json::from_str(json).unwrap();
894
895        assert_eq!(
896            extract_session_id_for_routing(json, &event),
897            Some("sess-nested".to_string())
898        );
899    }
900
901    #[test]
902    fn test_extract_session_id_for_routing_falls_back_to_typed_fields() {
903        let json = r#"{"type":"message.updated","properties":{"info":{"id":"m1","sessionId":"sess-typed","role":"assistant","time":{"created":1}}}}"#;
904        let event: Event = serde_json::from_str(json).unwrap();
905
906        assert_eq!(
907            extract_session_id_for_routing(json, &event),
908            Some("sess-typed".to_string())
909        );
910    }
911
912    #[test]
913    fn test_concurrent_session_filtering_no_delta_cross_contamination() {
914        let raw_events = [
915            r#"{"type":"message.part.updated","properties":{"part":{"type":"text","text":"","sessionID":"sess-a"},"delta":"alpha"}}"#,
916            r#"{"type":"message.part.updated","properties":{"part":{"type":"text","text":"","sessionID":"sess-b"},"delta":"bravo"}}"#,
917            r#"{"type":"server.heartbeat","properties":{}}"#,
918            r#"{"type":"message.part.updated","properties":{"part":{"type":"text","text":"","sessionID":"sess-a"},"delta":"-2"}}"#,
919            r#"{"type":"message.part.updated","properties":{"part":{"type":"text","text":"","sessionID":"sess-b"},"delta":"-2"}}"#,
920        ];
921
922        let mut a = String::new();
923        let mut b = String::new();
924
925        for raw in raw_events {
926            let ev: Event = serde_json::from_str(raw).unwrap();
927
928            if should_forward_event(Some("sess-a"), raw, &ev)
929                && let Event::MessagePartUpdated { properties } = &ev
930                && let Some(delta) = &properties.delta
931            {
932                a.push_str(delta);
933            }
934
935            if should_forward_event(Some("sess-b"), raw, &ev)
936                && let Event::MessagePartUpdated { properties } = &ev
937                && let Some(delta) = &properties.delta
938            {
939                b.push_str(delta);
940            }
941        }
942
943        assert_eq!(a, "alpha-2");
944        assert_eq!(b, "bravo-2");
945    }
946
947    #[tokio::test]
948    async fn test_subscribe_raw_yields_payloads() {
949        use std::io::{Read, Write};
950        use std::net::TcpListener;
951
952        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
953        let addr = listener.local_addr().unwrap();
954
955        let server = std::thread::spawn(move || {
956            let (mut stream, _) = listener.accept().unwrap();
957
958            let mut buf = [0_u8; 1024];
959            let _ = stream.read(&mut buf);
960
961            let body = concat!(
962                "id: 1\n",
963                "event: message\n",
964                "data: {\"type\":\"server.connected\",\"properties\":{}}\n",
965                "\n",
966                "id: 2\n",
967                "event: message\n",
968                "data: {\"type\":\"server.heartbeat\",\"properties\":{}}\n",
969                "\n"
970            );
971
972            let response = format!(
973                "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncache-control: no-cache\r\nconnection: close\r\ncontent-length: {}\r\n\r\n{}",
974                body.len(),
975                body
976            );
977
978            let _ = stream.write_all(response.as_bytes());
979            let _ = stream.flush();
980        });
981
982        let subscriber = SseSubscriber::new(
983            format!("http://{}", addr),
984            None,
985            Arc::new(RwLock::new(None)),
986        );
987
988        let mut sub = subscriber
989            .subscribe_raw(SseOptions {
990                capacity: 8,
991                initial_interval: Duration::from_millis(10),
992                max_interval: Duration::from_millis(20),
993            })
994            .await
995            .unwrap();
996
997        let first = tokio::time::timeout(Duration::from_secs(2), sub.recv())
998            .await
999            .unwrap()
1000            .unwrap();
1001        assert_eq!(first.id, "1");
1002        assert_eq!(first.event, "message");
1003        assert!(first.data.contains("server.connected"));
1004
1005        let second = tokio::time::timeout(Duration::from_secs(2), sub.recv())
1006            .await
1007            .unwrap()
1008            .unwrap();
1009        assert_eq!(second.id, "2");
1010        assert!(second.data.contains("server.heartbeat"));
1011
1012        let stats = sub.stats();
1013        assert_eq!(stats.events_in, 2);
1014        assert_eq!(stats.events_out, 2);
1015        assert_eq!(stats.dropped, 0);
1016        assert_eq!(stats.parse_errors, 0);
1017        assert_eq!(stats.last_event_id.as_deref(), Some("2"));
1018
1019        sub.close();
1020        let _ = server.join();
1021    }
1022
1023    #[tokio::test]
1024    async fn test_subscribe_typed_tracks_parse_errors_and_drops() {
1025        use std::io::{Read, Write};
1026        use std::net::TcpListener;
1027
1028        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1029        let addr = listener.local_addr().unwrap();
1030
1031        let server = std::thread::spawn(move || {
1032            let (mut stream, _) = listener.accept().unwrap();
1033
1034            let mut buf = [0_u8; 1024];
1035            let _ = stream.read(&mut buf);
1036
1037            let body = concat!(
1038                "id: 1\n",
1039                "event: message\n",
1040                "data: {\"type\":\"server.connected\",\"properties\":{}}\n",
1041                "\n",
1042                "id: 2\n",
1043                "event: message\n",
1044                "data: not-json\n",
1045                "\n",
1046                "id: 3\n",
1047                "event: message\n",
1048                "data: {\"type\":\"server.heartbeat\",\"properties\":{}}\n",
1049                "\n"
1050            );
1051
1052            let response = format!(
1053                "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncache-control: no-cache\r\nconnection: close\r\ncontent-length: {}\r\n\r\n{}",
1054                body.len(),
1055                body
1056            );
1057
1058            let _ = stream.write_all(response.as_bytes());
1059            let _ = stream.flush();
1060        });
1061
1062        let subscriber = SseSubscriber::new(
1063            format!("http://{}", addr),
1064            None,
1065            Arc::new(RwLock::new(None)),
1066        );
1067
1068        let mut sub = subscriber
1069            .subscribe_typed(SseOptions {
1070                capacity: 8,
1071                initial_interval: Duration::from_millis(10),
1072                max_interval: Duration::from_millis(20),
1073            })
1074            .await
1075            .unwrap();
1076
1077        let first = tokio::time::timeout(Duration::from_secs(2), sub.recv())
1078            .await
1079            .unwrap()
1080            .unwrap();
1081        assert!(matches!(first, Event::ServerConnected { .. }));
1082
1083        let second = tokio::time::timeout(Duration::from_secs(2), sub.recv())
1084            .await
1085            .unwrap()
1086            .unwrap();
1087        assert!(matches!(second, Event::ServerHeartbeat { .. }));
1088
1089        tokio::time::sleep(Duration::from_millis(50)).await;
1090
1091        let stats = sub.stats();
1092        assert_eq!(stats.events_in, 3);
1093        assert_eq!(stats.events_out, 2);
1094        assert_eq!(stats.dropped, 1);
1095        assert_eq!(stats.parse_errors, 1);
1096        assert_eq!(stats.last_event_id.as_deref(), Some("3"));
1097
1098        sub.close();
1099        let _ = server.join();
1100    }
1101
1102    #[tokio::test]
1103    async fn test_session_event_router_exposes_upstream_stats() {
1104        use std::io::{Read, Write};
1105        use std::net::TcpListener;
1106
1107        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1108        let addr = listener.local_addr().unwrap();
1109
1110        let server = std::thread::spawn(move || {
1111            let (mut stream, _) = listener.accept().unwrap();
1112
1113            let mut buf = [0_u8; 1024];
1114            let _ = stream.read(&mut buf);
1115
1116            let body = concat!(
1117                "id: 9\n",
1118                "event: message\n",
1119                "data: {\"type\":\"message.removed\",\"properties\":{\"sessionId\":\"sess-a\",\"messageId\":\"msg-1\"}}\n",
1120                "\n"
1121            );
1122
1123            let response = format!(
1124                "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncache-control: no-cache\r\nconnection: close\r\ncontent-length: {}\r\n\r\n{}",
1125                body.len(),
1126                body
1127            );
1128
1129            let _ = stream.write_all(response.as_bytes());
1130            let _ = stream.flush();
1131        });
1132
1133        let subscriber = SseSubscriber::new(
1134            format!("http://{}", addr),
1135            None,
1136            Arc::new(RwLock::new(None)),
1137        );
1138
1139        let router = subscriber
1140            .session_event_router(SessionEventRouterOptions {
1141                upstream: SseOptions {
1142                    capacity: 8,
1143                    initial_interval: Duration::from_millis(10),
1144                    max_interval: Duration::from_millis(20),
1145                },
1146                session_capacity: 8,
1147                subscriber_capacity: 8,
1148            })
1149            .await
1150            .unwrap();
1151
1152        let mut session_sub = router.subscribe("sess-a").await;
1153        let event = tokio::time::timeout(Duration::from_secs(2), session_sub.recv())
1154            .await
1155            .unwrap()
1156            .unwrap();
1157        assert!(matches!(event, Event::MessageRemoved { .. }));
1158
1159        tokio::time::sleep(Duration::from_millis(50)).await;
1160
1161        let stats = router.stats();
1162        assert_eq!(stats.events_in, 1);
1163        assert_eq!(stats.events_out, 1);
1164        assert_eq!(stats.dropped, 0);
1165        assert_eq!(stats.parse_errors, 0);
1166        assert_eq!(stats.last_event_id.as_deref(), Some("9"));
1167
1168        session_sub.close();
1169        router.close();
1170        let _ = server.join();
1171    }
1172}