//! modo/sse/broadcaster.rs
//!
//! Keyed broadcast-channel registry and stream adapters for SSE fan-out.
1use crate::error::Error;
2use axum::response::{IntoResponse, Response};
3use futures_util::{FutureExt, Stream, StreamExt};
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::pin::Pin;
7use std::sync::{Arc, RwLock};
8use std::task::{Context, Poll};
9use tokio::sync::broadcast;
10
11use super::config::SseConfig;
12use super::event::Event;
13
/// Policy for handling lagged subscribers in a broadcast stream.
///
/// A subscriber "lags" when the broadcast channel's ring buffer wraps past
/// messages it has not yet received (`RecvError::Lagged` from
/// [`tokio::sync::broadcast`]).
#[derive(Debug, Clone, Copy)]
pub enum LagPolicy {
    /// End the stream on lag — client reconnects with `Last-Event-ID`.
    End,
    /// Skip lagged messages with a warning log, continue streaming.
    Skip,
}
22
/// A stream of events from a broadcast channel.
///
/// Yields raw `T` values (not [`Event`](super::Event)). Convert downstream
/// using [`SseStreamExt::cast_events()`](super::SseStreamExt::cast_events).
///
/// # Lag behavior
///
/// Configure with [`on_lag()`](Self::on_lag):
/// - [`LagPolicy::End`] — stream terminates (safe for chat/notifications)
/// - [`LagPolicy::Skip`] — skips missed messages (safe for dashboards)
/// - Default (no call) — propagates lag as [`Error`]
pub struct BroadcastStream<T> {
    // `inner` is the boxed stream that owns the broadcast `Receiver`;
    // `_cleanup` (when present) removes the owning channel entry once this
    // subscriber is gone.
    //
    // NOTE(review): the intent here is that the `Receiver` drops
    // (decrementing `receiver_count`) before the cleanup closure inspects
    // it. Field declaration order alone is NOT sufficient for that: this
    // type has a manual `Drop` impl, and `Drop::drop` runs *before* any
    // field is dropped — verify the `Drop` impl releases the receiver (or
    // otherwise accounts for it) before invoking the cleanup closure.
    inner: Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>,
    _cleanup: Option<Box<dyn FnOnce() + Send>>,
}
41
42impl<T: Clone + Send + 'static> BroadcastStream<T> {
43    /// Wrap a raw [`broadcast::Receiver`] in a [`BroadcastStream`].
44    ///
45    /// Use this for ad-hoc broadcast channels not owned by a
46    /// [`Broadcaster`]. No auto-cleanup runs on drop — the channel is
47    /// managed by whoever owns the sender.
48    ///
49    /// Prefer [`Broadcaster::subscribe()`] for keyed channels, which
50    /// also cleans up empty channels when the last subscriber drops.
51    pub fn new(rx: broadcast::Receiver<T>) -> Self {
52        Self {
53            inner: Box::pin(unfold_default(rx)),
54            _cleanup: None,
55        }
56    }
57
58    /// Create a new broadcast stream with a cleanup closure.
59    pub(crate) fn with_cleanup(
60        rx: broadcast::Receiver<T>,
61        cleanup: impl FnOnce() + Send + 'static,
62    ) -> Self {
63        Self {
64            inner: Box::pin(unfold_default(rx)),
65            _cleanup: Some(Box::new(cleanup)),
66        }
67    }
68
69    /// Set the lag policy for this stream.
70    ///
71    /// - [`LagPolicy::End`] — end the stream on lag. Client reconnects with
72    ///   `Last-Event-ID` and replays from their store. Use for chat,
73    ///   notifications, anything where message loss is unacceptable.
74    /// - [`LagPolicy::Skip`] — skip lagged messages with a warning log and
75    ///   continue. Use for dashboards, metrics, anything where the next
76    ///   value supersedes the previous.
77    ///
78    /// Default (no call): propagate the lag error through the stream as
79    /// [`Error`] — caller handles it via standard stream combinators.
80    pub fn on_lag(mut self, policy: LagPolicy) -> Self {
81        // Reconstruct the inner stream with the new policy.
82        // We wrap the existing stream with policy handling.
83        let original = std::mem::replace(&mut self.inner, Box::pin(futures_util::stream::empty()));
84        self.inner = Box::pin(apply_lag_policy(original, policy));
85        self
86    }
87}
88
89/// Default unfold: propagate lag errors.
90fn unfold_default<T: Clone + Send + 'static>(
91    rx: broadcast::Receiver<T>,
92) -> impl Stream<Item = Result<T, Error>> {
93    futures_util::stream::unfold(rx, |mut rx| async move {
94        match rx.recv().await {
95            Ok(item) => Some((Ok(item), rx)),
96            Err(broadcast::error::RecvError::Lagged(n)) => Some((Err(Error::lagged(n)), rx)),
97            Err(broadcast::error::RecvError::Closed) => None,
98        }
99    })
100}
101
102/// Wrap a stream with lag policy handling.
103fn apply_lag_policy<T: Send + 'static>(
104    stream: Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>,
105    policy: LagPolicy,
106) -> impl Stream<Item = Result<T, Error>> + Send {
107    futures_util::stream::unfold(stream, move |mut stream| async move {
108        use futures_util::StreamExt;
109        loop {
110            match stream.next().await {
111                Some(Ok(item)) => return Some((Ok(item), stream)),
112                Some(Err(e)) if e.is_lagged() => match policy {
113                    LagPolicy::End => return None,
114                    LagPolicy::Skip => {
115                        tracing::warn!("{e}");
116                        continue;
117                    }
118                },
119                Some(Err(e)) => return Some((Err(e), stream)),
120                None => return None,
121            }
122        }
123    })
124}
125
126impl<T> Drop for BroadcastStream<T> {
127    fn drop(&mut self) {
128        if let Some(cleanup) = self._cleanup.take() {
129            cleanup();
130        }
131    }
132}
133
134impl<T> Stream for BroadcastStream<T> {
135    type Item = Result<T, Error>;
136
137    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138        self.inner.as_mut().poll_next(cx)
139    }
140}
141
142/// Convert a `Vec<T>` into a `Stream<Item = Result<T, Error>>`.
143///
144/// Use this to replay missed events from a data store before chaining
145/// with a live broadcast stream on client reconnection.
146///
147/// The returned stream yields each item wrapped in `Ok`. Chain it with
148/// a live [`BroadcastStream`] using `.chain()` from
149/// [`futures_util::StreamExt`].
150pub fn replay<T>(items: Vec<T>) -> impl Stream<Item = Result<T, Error>> + Send
151where
152    T: Send + 'static,
153{
154    futures_util::stream::iter(items.into_iter().map(Ok))
155}
156
/// Shared state behind [`Broadcaster`] handles (held in an `Arc`).
struct BroadcasterInner<K, T> {
    // Key → broadcast sender for that key's channel. Entries are created
    // lazily on subscribe and garbage-collected when no receivers remain.
    channels: RwLock<HashMap<K, broadcast::Sender<T>>>,
    // Per-channel ring-buffer capacity passed to `broadcast::channel`.
    buffer: usize,
    // SSE configuration (keep-alive interval) used when building responses.
    config: SseConfig,
}
162
/// Keyed broadcast channel registry for fan-out SSE delivery.
///
/// Each key maps to an independent broadcast channel. All subscribers of a key
/// receive every message sent to that key. Register one broadcaster per domain
/// concept (e.g., chat messages, notifications, metrics).
///
/// Cloning a `Broadcaster` is cheap (an `Arc` bump) and all clones share the
/// same channel registry.
///
/// # Construction
///
/// ```
/// use modo::sse::{Broadcaster, SseConfig};
///
/// # #[derive(Clone)]
/// # struct ChatMessage;
/// let chat: Broadcaster<String, ChatMessage> =
///     Broadcaster::new(128, SseConfig::default());
/// # let mut registry = modo::service::Registry::new();
/// registry.add(chat);
/// ```
///
/// # Channel lifecycle
///
/// - Channels are created lazily on first [`subscribe()`](Self::subscribe)
/// - Channels are auto-cleaned when the last subscriber's stream is dropped
/// - [`remove()`](Self::remove) forces immediate cleanup
pub struct Broadcaster<K, T>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    T: Clone + Send + Sync + 'static,
{
    // Shared registry; every clone points at the same map of channels.
    inner: Arc<BroadcasterInner<K, T>>,
}
194
195impl<K, T> Clone for Broadcaster<K, T>
196where
197    K: Hash + Eq + Clone + Send + Sync + 'static,
198    T: Clone + Send + Sync + 'static,
199{
200    fn clone(&self) -> Self {
201        Self {
202            inner: Arc::clone(&self.inner),
203        }
204    }
205}
206
207impl<K, T> Broadcaster<K, T>
208where
209    K: Hash + Eq + Clone + Send + Sync + 'static,
210    T: Clone + Send + Sync + 'static,
211{
212    /// Create a new broadcaster.
213    ///
214    /// - `buffer` — per-channel buffer size. When a subscriber falls behind
215    ///   by this many messages, it lags. Typical values: 64–256 for chat,
216    ///   16–64 for dashboards.
217    /// - `config` — SSE configuration (keep-alive interval).
218    pub fn new(buffer: usize, config: SseConfig) -> Self {
219        Self {
220            inner: Arc::new(BroadcasterInner {
221                channels: RwLock::new(HashMap::new()),
222                buffer,
223                config,
224            }),
225        }
226    }
227
228    /// Subscribe to a keyed channel.
229    ///
230    /// Creates the channel lazily on first subscription. Returns a stream
231    /// of raw `T` values. The stream carries a cleanup closure that removes
232    /// the channel entry when the last subscriber drops.
233    pub fn subscribe(&self, key: &K) -> BroadcastStream<T> {
234        let mut channels = self
235            .inner
236            .channels
237            .write()
238            .unwrap_or_else(|e| e.into_inner());
239
240        let sender = channels
241            .entry(key.clone())
242            .or_insert_with(|| broadcast::channel(self.inner.buffer).0);
243        let rx = sender.subscribe();
244
245        let inner_ref = Arc::clone(&self.inner);
246        let key_owned = key.clone();
247        let cleanup = move || {
248            let mut channels = inner_ref
249                .channels
250                .write()
251                .unwrap_or_else(|e| e.into_inner());
252            if let std::collections::hash_map::Entry::Occupied(entry) = channels.entry(key_owned)
253                && entry.get().receiver_count() == 0
254            {
255                entry.remove();
256            }
257        };
258
259        BroadcastStream::with_cleanup(rx, cleanup)
260    }
261
262    /// Send an event to all subscribers of a key.
263    ///
264    /// Returns the number of receivers that got the message. Returns 0
265    /// if no subscribers exist for the key — does NOT create a channel.
266    pub fn send(&self, key: &K, event: T) -> usize {
267        let channels = self
268            .inner
269            .channels
270            .read()
271            .unwrap_or_else(|e| e.into_inner());
272        if let Some(sender) = channels.get(key) {
273            match sender.send(event) {
274                Ok(count) => count,
275                Err(_) => {
276                    drop(channels);
277                    let mut channels = self
278                        .inner
279                        .channels
280                        .write()
281                        .unwrap_or_else(|e| e.into_inner());
282                    if let std::collections::hash_map::Entry::Occupied(entry) =
283                        channels.entry(key.clone())
284                        && entry.get().receiver_count() == 0
285                    {
286                        entry.remove();
287                    }
288                    0
289                }
290            }
291        } else {
292            0
293        }
294    }
295
296    /// Number of active subscribers for a key. Returns 0 if no channel exists.
297    pub fn subscriber_count(&self, key: &K) -> usize {
298        let channels = self
299            .inner
300            .channels
301            .read()
302            .unwrap_or_else(|e| e.into_inner());
303        channels.get(key).map(|s| s.receiver_count()).unwrap_or(0)
304    }
305
306    /// Manually remove a channel and disconnect all its subscribers.
307    ///
308    /// Typically not needed — channels auto-clean on last subscriber drop.
309    /// Use for explicit teardown (e.g., deleting a chat room).
310    pub fn remove(&self, key: &K) {
311        let mut channels = self
312            .inner
313            .channels
314            .write()
315            .unwrap_or_else(|e| e.into_inner());
316        channels.remove(key);
317    }
318
319    /// Access the SSE config.
320    pub fn config(&self) -> &SseConfig {
321        &self.inner.config
322    }
323
324    /// Create an SSE response with an imperative sender.
325    ///
326    /// Spawns the closure as a tokio task. The closure receives a [`super::Sender`]
327    /// for pushing events. The task runs until:
328    /// - The closure returns `Ok(())` — stream ends cleanly
329    /// - The closure returns `Err(e)` — error is logged, stream ends
330    /// - A `tx.send()` call fails — client disconnected
331    ///
332    /// Panics in the closure are caught and logged.
333    pub fn channel<F, Fut>(&self, f: F) -> Response
334    where
335        F: FnOnce(super::Sender) -> Fut + Send + 'static,
336        Fut: std::future::Future<Output = Result<(), Error>> + Send,
337    {
338        const CHANNEL_BUFFER: usize = 32;
339        let (tx, rx) = tokio::sync::mpsc::channel(CHANNEL_BUFFER);
340        let sender = super::Sender { tx };
341
342        tokio::spawn(async move {
343            let result = std::panic::AssertUnwindSafe(f(sender)).catch_unwind().await;
344            match result {
345                Ok(Ok(())) => {}
346                Ok(Err(e)) => {
347                    tracing::debug!(error = %e, "SSE channel closure ended with error")
348                }
349                Err(_) => tracing::error!("SSE channel closure panicked"),
350            }
351        });
352
353        // Wrap the mpsc receiver as a stream of Events
354        let stream = futures_util::stream::unfold(rx, |mut rx| async move {
355            rx.recv().await.map(|event| (Ok(event), rx))
356        });
357
358        self.response(stream)
359    }
360
361    /// Wrap an event stream into an SSE HTTP response.
362    ///
363    /// Applies keep-alive comments at the configured interval and sets
364    /// the `X-Accel-Buffering: no` header for nginx compatibility.
365    pub fn response<S>(&self, stream: S) -> Response
366    where
367        S: Stream<Item = Result<Event, Error>> + Send + 'static,
368    {
369        let mapped = stream.map(|result| {
370            result
371                .map(axum::response::sse::Event::from)
372                .map_err(axum::Error::new)
373        });
374
375        let keep_alive =
376            axum::response::sse::KeepAlive::new().interval(self.inner.config.keep_alive_interval());
377
378        let mut resp = axum::response::sse::Sse::new(mapped)
379            .keep_alive(keep_alive)
380            .into_response();
381
382        resp.headers_mut()
383            .insert("x-accel-buffering", http::HeaderValue::from_static("no"));
384
385        resp
386    }
387}
388
#[cfg(test)]
mod tests {
    //! Unit tests for `BroadcastStream`, lag policies, `replay`, and the
    //! `Broadcaster` registry (lifecycle, fan-out, SSE response headers).
    use super::*;
    use futures_util::StreamExt;
    use tokio::sync::broadcast;

    // Values sent before iteration are buffered and yielded in order.
    #[tokio::test]
    async fn stream_yields_sent_values() {
        let (tx, rx) = broadcast::channel(16);
        let mut stream = BroadcastStream::new(rx);
        tx.send("hello".to_string()).unwrap();
        tx.send("world".to_string()).unwrap();
        drop(tx);

        let items: Vec<String> = stream
            .by_ref()
            .filter_map(|r| async { r.ok() })
            .collect()
            .await;
        assert_eq!(items, vec!["hello", "world"]);
    }

    // Dropping the last sender closes the channel and ends the stream.
    #[tokio::test]
    async fn stream_ends_when_sender_dropped() {
        let (tx, rx) = broadcast::channel(16);
        let mut stream = BroadcastStream::new(rx);
        tx.send(1).unwrap();
        drop(tx);

        assert!(stream.next().await.unwrap().is_ok()); // 1
        assert!(stream.next().await.is_none()); // end
    }

    // Skip policy: lagged messages are silently skipped (with a warn log)
    // and the stream keeps yielding whatever survives in the buffer.
    #[tokio::test]
    async fn lag_policy_skip_continues_after_lag() {
        let (tx, rx) = broadcast::channel(2);
        let mut stream = BroadcastStream::new(rx).on_lag(LagPolicy::Skip);

        // Fill buffer beyond capacity to cause lag
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap(); // overwrites 1, receiver lags

        // Should skip lagged messages and yield the latest
        let item = stream.next().await.unwrap();
        assert!(item.is_ok());
    }

    // End policy: the first lag terminates the stream immediately.
    #[tokio::test]
    async fn lag_policy_end_terminates_on_lag() {
        let (tx, rx) = broadcast::channel(2);
        let mut stream = BroadcastStream::new(rx).on_lag(LagPolicy::End);

        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap(); // causes lag

        let item = stream.next().await;
        assert!(item.is_none()); // stream ended
    }

    // No policy configured: the lag surfaces as an Err item in the stream.
    #[tokio::test]
    async fn default_lag_policy_propagates_error() {
        let (tx, rx) = broadcast::channel(2);
        let mut stream = BroadcastStream::new(rx);

        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap(); // causes lag

        let item = stream.next().await.unwrap();
        assert!(item.is_err());
        assert!(item.unwrap_err().is_lagged());
    }

    // replay(): every item comes back Ok, in input order.
    #[tokio::test]
    async fn replay_yields_all_items() {
        let items = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let stream = replay(items);
        let collected: Vec<String> = stream.filter_map(|r| async { r.ok() }).collect().await;
        assert_eq!(collected, vec!["a", "b", "c"]);
    }

    // replay() of an empty Vec is an immediately-ended stream.
    #[tokio::test]
    async fn replay_empty_vec() {
        let stream = replay::<String>(vec![]);
        let collected: Vec<String> = stream.filter_map(|r| async { r.ok() }).collect().await;
        assert!(collected.is_empty());
    }

    // The cleanup closure installed via with_cleanup runs exactly on drop.
    #[tokio::test]
    async fn cleanup_fires_on_drop() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        let (tx, rx) = broadcast::channel::<i32>(16);
        let cleaned = Arc::new(AtomicBool::new(false));
        let cleaned_clone = cleaned.clone();

        let stream = BroadcastStream::with_cleanup(rx, move || {
            cleaned_clone.store(true, Ordering::SeqCst);
        });

        drop(stream);
        assert!(cleaned.load(Ordering::SeqCst));
        drop(tx);
    }

    // Basic round trip: subscribe, send, receive.
    #[tokio::test]
    async fn broadcaster_subscribe_and_send() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());
        let key = "room1".to_string();

        let mut stream = bc.subscribe(&key);
        assert_eq!(bc.subscriber_count(&key), 1);

        let count = bc.send(&key, "hello".into());
        assert_eq!(count, 1);

        let item = stream.next().await.unwrap().unwrap();
        assert_eq!(item, "hello");
    }

    // send() must not create a channel as a side effect.
    #[tokio::test]
    async fn broadcaster_send_to_nonexistent_key_returns_zero() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());
        let count = bc.send(&"nobody".into(), "hello".into());
        assert_eq!(count, 0);
    }

    // Fan-out: every subscriber of the key receives each message.
    #[tokio::test]
    async fn broadcaster_multiple_subscribers() {
        let bc: Broadcaster<String, i32> = Broadcaster::new(16, SseConfig::default());
        let key = "k".to_string();

        let mut s1 = bc.subscribe(&key);
        let mut s2 = bc.subscribe(&key);
        assert_eq!(bc.subscriber_count(&key), 2);

        bc.send(&key, 42);
        assert_eq!(s1.next().await.unwrap().unwrap(), 42);
        assert_eq!(s2.next().await.unwrap().unwrap(), 42);
    }

    // NOTE(review): subscriber_count() returns 0 both when the channel entry
    // was removed AND when a leaked entry merely has zero receivers, so this
    // test cannot distinguish proper map cleanup from an entry leak. A
    // stronger assertion would need a way to observe channel existence
    // (e.g., a test-only accessor over the internal map).
    #[tokio::test]
    async fn broadcaster_auto_cleanup_on_last_drop() {
        let bc: Broadcaster<String, i32> = Broadcaster::new(16, SseConfig::default());
        let key = "cleanup".to_string();

        let s1 = bc.subscribe(&key);
        let s2 = bc.subscribe(&key);
        assert_eq!(bc.subscriber_count(&key), 2);

        drop(s1);
        // Channel still exists (s2 is alive)
        assert_eq!(bc.subscriber_count(&key), 1);

        drop(s2);
        // Channel should be cleaned up
        assert_eq!(bc.subscriber_count(&key), 0);
    }

    // remove() drops the stored Sender, which ends subscriber streams.
    #[tokio::test]
    async fn broadcaster_remove_disconnects_subscribers() {
        let bc: Broadcaster<String, i32> = Broadcaster::new(16, SseConfig::default());
        let key = "rm".to_string();

        let mut stream = bc.subscribe(&key);
        bc.remove(&key);

        // Stream should end because sender was dropped
        assert!(stream.next().await.is_none());
    }

    // Clones share the same channel registry: a send through one clone is
    // visible to a subscriber created through another.
    #[tokio::test]
    async fn broadcaster_clone_shares_state() {
        let bc1: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());
        let bc2 = bc1.clone();
        let key = "shared".to_string();

        let mut stream = bc1.subscribe(&key);
        bc2.send(&key, "from_clone".into());

        let item = stream.next().await.unwrap().unwrap();
        assert_eq!(item, "from_clone");
    }

    // channel() produces a well-formed SSE response (headers only checked;
    // the body stream is not consumed here).
    #[tokio::test]
    async fn broadcaster_channel_produces_events() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());

        let response = bc.channel(|tx| async move {
            tx.send(super::Event::new("e1", "test").unwrap().data("hello"))
                .await?;
            tx.send(super::Event::new("e2", "test").unwrap().data("world"))
                .await?;
            Ok(())
        });

        // Response should have SSE headers
        assert_eq!(response.headers().get("x-accel-buffering").unwrap(), "no");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/event-stream"
        );
    }

    // config() exposes the SseConfig passed at construction.
    #[test]
    fn broadcaster_config_accessible() {
        let config = SseConfig {
            keep_alive_interval_secs: 30,
        };
        let bc: Broadcaster<String, String> = Broadcaster::new(64, config);
        assert_eq!(bc.config().keep_alive_interval_secs, 30);
    }

    // response() sets SSE headers even for an empty event stream.
    #[tokio::test]
    async fn broadcaster_response_returns_valid_response() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());
        let stream = futures_util::stream::empty::<Result<super::Event, crate::error::Error>>();
        let response = bc.response(stream);
        assert_eq!(response.headers().get("x-accel-buffering").unwrap(), "no");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/event-stream"
        );
    }

    // A closure that returns Err still yields a valid (empty) SSE response;
    // the error is only logged inside the spawned task.
    #[tokio::test]
    async fn channel_closure_error_produces_valid_response() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());

        let response =
            bc.channel(|_tx| async move { Err(crate::error::Error::internal("deliberate error")) });

        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/event-stream"
        );
        assert_eq!(response.headers().get("x-accel-buffering").unwrap(), "no");
    }

    // A panicking closure is caught by catch_unwind in the spawned task and
    // must not poison the HTTP response.
    #[tokio::test]
    async fn channel_closure_panic_produces_valid_response() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());

        let response = bc.channel(|_tx| async move {
            panic!("deliberate panic");
        });

        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/event-stream"
        );
        assert_eq!(response.headers().get("x-accel-buffering").unwrap(), "no");
    }

    // Smoke test for lock contention: concurrent subscribe+send from many
    // tasks must neither deadlock nor panic.
    #[tokio::test]
    async fn concurrent_subscribe_and_send() {
        let bc: Broadcaster<String, i32> = Broadcaster::new(256, SseConfig::default());
        let key = "concurrent".to_string();

        let mut set = tokio::task::JoinSet::new();

        for task_num in 0..10 {
            let bc = bc.clone();
            let key = key.clone();
            set.spawn(async move {
                let mut stream = bc.subscribe(&key);
                bc.send(&key, task_num);
                stream.next().await.unwrap().unwrap()
            });
        }

        let mut results = Vec::new();
        while let Some(result) = set.join_next().await {
            results.push(result.expect("Task panicked"));
        }

        assert_eq!(results.len(), 10);
    }
}