//! modkit/http/sse.rs — typed Server-Sent Events (SSE) broadcaster built on
//! `tokio::sync::broadcast`, with axum response helpers.
use axum::response::IntoResponse;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures_core::Stream;
use futures_util::StreamExt;
use serde::Serialize;
use std::{borrow::Cow, convert::Infallible, time::Duration};
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;

10/// Small typed SSE broadcaster built on `tokio::sync::broadcast`.
11/// - T must be `Clone` so multiple subscribers can receive the same payload.
12/// - Bounded channel drops oldest events when subscribers lag (by design).
13#[derive(Clone)]
14pub struct SseBroadcaster<T> {
15    tx: broadcast::Sender<T>,
16}
17
18impl<T: Clone + Send + 'static> SseBroadcaster<T> {
19    /// Create a broadcaster with bounded buffer capacity.
20    #[must_use]
21    pub fn new(capacity: usize) -> Self {
22        let (tx, _rx) = broadcast::channel(capacity);
23        Self { tx }
24    }
25
26    /// Broadcast a single message to current subscribers.
27    /// Errors are ignored to keep the hot path cheap (e.g., no active subscribers).
28    pub fn send(&self, value: T) {
29        _ = self.tx.send(value);
30    }
31
32    /// Subscribe to a typed stream of messages; lag/drop errors are filtered out.
33    pub fn subscribe_stream(&self) -> impl Stream<Item = T> + use<T> {
34        BroadcastStream::new(self.tx.subscribe()).filter_map(|res| async move { res.ok() })
35    }
36
37    /// Convert a typed stream into an SSE stream with JSON payloads (no event name).
38    fn wrap_stream_as_sse<U>(stream: U) -> impl Stream<Item = Result<Event, Infallible>>
39    where
40        U: Stream<Item = T>,
41        T: Serialize,
42    {
43        stream.map(|msg| {
44            let ev = Event::default().json_data(&msg).unwrap_or_else(|_| {
45                // Fallback to a tiny text marker instead of breaking the stream.
46                Event::default().data("serialization_error")
47            });
48            Ok(ev)
49        })
50    }
51
52    /// Convert a typed stream into an SSE stream with JSON payloads and a constant `event:` name.
53    fn wrap_stream_as_sse_named<U>(
54        stream: U,
55        event_name: Cow<'static, str>,
56    ) -> impl Stream<Item = Result<Event, Infallible>>
57    where
58        U: Stream<Item = T>,
59        T: Serialize,
60    {
61        stream.map(move |msg| {
62            let ev = Event::default()
63                .event(&event_name) // <-- set event name
64                .json_data(&msg)
65                .unwrap_or_else(|_| {
66                    Event::default()
67                        .event(&event_name)
68                        .data("serialization_error")
69                });
70            Ok(ev)
71        })
72    }
73
74    // -------------------------
75    // Plain (unnamed) variants
76    // -------------------------
77
78    /// Plain SSE (no extra headers), unnamed events.
79    /// Includes periodic keepalive pings to avoid idle timeouts.
80    pub fn sse_response(&self) -> Sse<impl Stream<Item = Result<Event, Infallible>> + use<T>>
81    where
82        T: Serialize,
83    {
84        let stream = Self::wrap_stream_as_sse(self.subscribe_stream());
85        Sse::new(stream).keep_alive(
86            KeepAlive::new()
87                .interval(Duration::from_secs(15))
88                .text("keepalive"),
89        )
90    }
91
92    /// SSE with custom headers applied on top of the Sse response (unnamed events).
93    pub fn sse_response_with_headers<I>(&self, headers: I) -> axum::response::Response
94    where
95        T: Serialize,
96        I: IntoIterator<Item = (axum::http::HeaderName, axum::http::HeaderValue)>,
97    {
98        let mut resp = self.sse_response().into_response();
99        let dst = resp.headers_mut();
100        for (name, value) in headers {
101            dst.insert(name, value);
102        }
103        resp
104    }
105
106    // -------------------------
107    // Named-event variants
108    // -------------------------
109
110    /// Plain SSE with a constant `event:` name for all messages (no extra headers).
111    pub fn sse_response_named<N>(
112        &self,
113        event_name: N,
114    ) -> Sse<impl Stream<Item = Result<Event, Infallible>> + use<T, N>>
115    where
116        T: Serialize,
117        N: Into<Cow<'static, str>> + 'static,
118    {
119        let stream = Self::wrap_stream_as_sse_named(self.subscribe_stream(), event_name.into());
120        Sse::new(stream).keep_alive(
121            KeepAlive::new()
122                .interval(Duration::from_secs(15))
123                .text("keepalive"),
124        )
125    }
126
127    /// SSE with custom headers and a constant `event:` name for all messages.
128    pub fn sse_response_named_with_headers<I>(
129        &self,
130        event_name: impl Into<Cow<'static, str>> + 'static,
131        headers: I,
132    ) -> axum::response::Response
133    where
134        T: Serialize,
135        I: IntoIterator<Item = (axum::http::HeaderName, axum::http::HeaderValue)>,
136    {
137        let mut resp = self.sse_response_named(event_name).into_response();
138        let dst = resp.headers_mut();
139        for (name, value) in headers {
140            dst.insert(name, value);
141        }
142        resp
143    }
144}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use tokio::time::{Duration, timeout};

    #[tokio::test]
    async fn broadcaster_delivers_single_event() {
        let b = SseBroadcaster::<u32>::new(16);
        let mut sub = Box::pin(b.subscribe_stream());
        b.send(42);
        let v = timeout(Duration::from_millis(200), sub.next())
            .await
            .unwrap();
        assert_eq!(v, Some(42));
    }

    #[tokio::test]
    async fn broadcaster_handles_backpressure_with_bounded_channel() {
        // Test that bounded channel drops old events when capacity is exceeded
        let capacity = 4;
        let broadcaster = SseBroadcaster::<u32>::new(capacity);

        // Create a slow consumer that doesn't read immediately
        let mut subscriber = Box::pin(broadcaster.subscribe_stream());

        // Send more events than capacity
        let num_events = capacity * 2;
        for i in 0..num_events {
            broadcaster.send(u32::try_from(i).unwrap());
        }

        // The subscriber should only receive the most recent events
        // due to the bounded channel dropping older ones
        let mut received = Vec::new();

        // Try to receive all events with a timeout
        for _ in 0..num_events {
            match timeout(Duration::from_millis(10), subscriber.next()).await {
                Ok(Some(event)) => received.push(event),
                Ok(None) | Err(_) => break, // None or timeout
            }
        }

        // Should have received some events, but not necessarily all
        // due to backpressure handling
        assert!(!received.is_empty());
        assert!(received.len() <= num_events);

        // The events we did receive should be in order
        for window in received.windows(2) {
            assert!(window[0] < window[1], "Events should be in order");
        }
    }

    #[tokio::test]
    async fn broadcaster_handles_multiple_subscribers_with_backpressure() {
        let capacity = 8;
        let broadcaster = SseBroadcaster::<String>::new(capacity);

        // Create multiple subscribers with different consumption rates
        let mut fast_subscriber = Box::pin(broadcaster.subscribe_stream());
        let mut slow_subscriber = Box::pin(broadcaster.subscribe_stream());

        let events_sent = Arc::new(AtomicUsize::new(0));
        let events_sent_clone = events_sent.clone();

        // Producer task - sends events rapidly
        let producer = tokio::spawn(async move {
            for i in 0..50 {
                broadcaster.send(format!("event_{i}"));
                events_sent_clone.fetch_add(1, Ordering::SeqCst);
                tokio::task::yield_now().await; // Allow other tasks to run
            }
        });

        // Fast consumer task
        let fast_events = Arc::new(AtomicUsize::new(0));
        let fast_events_clone = fast_events.clone();
        let fast_consumer = tokio::spawn(async move {
            while let Ok(Some(_event)) =
                timeout(Duration::from_millis(100), fast_subscriber.next()).await
            {
                fast_events_clone.fetch_add(1, Ordering::SeqCst);
            }
        });

        // Slow consumer task
        let slow_events = Arc::new(AtomicUsize::new(0));
        let slow_events_clone = slow_events.clone();
        let slow_consumer = tokio::spawn(async move {
            while let Ok(Some(_event)) =
                timeout(Duration::from_millis(100), slow_subscriber.next()).await
            {
                slow_events_clone.fetch_add(1, Ordering::SeqCst);
                // Simulate slow processing
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        });

        // Wait for producer to finish
        producer.await.unwrap();

        // Give consumers time to process
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Cancel consumers
        fast_consumer.abort();
        slow_consumer.abort();

        let total_sent = events_sent.load(Ordering::SeqCst);
        let fast_received = fast_events.load(Ordering::SeqCst);
        let slow_received = slow_events.load(Ordering::SeqCst);

        assert_eq!(total_sent, 50);

        // Both consumers should make progress; the bounded channel may drop
        // events for the slow one, but the system must remain stable.
        assert!(fast_received > 0);
        assert!(slow_received > 0);

        // Due to bounded channel, neither consumer necessarily receives all events
        // but the system should remain stable
        println!(
            "Sent: {total_sent}, Fast received: {fast_received}, Slow received: {slow_received}"
        );
    }

    #[tokio::test]
    async fn broadcaster_prevents_unbounded_memory_growth() {
        let small_capacity = 2;
        let broadcaster = SseBroadcaster::<Vec<u8>>::new(small_capacity);

        // Create a subscriber but don't consume from it
        let _subscriber = broadcaster.subscribe_stream();

        // Send many large events; the bounded channel should drop old events
        // automatically instead of accumulating unbounded memory.
        for i in 0..100 {
            let large_event = vec![u8::try_from(i).unwrap(); 1024]; // 1KB per event
            broadcaster.send(large_event);
        }

        // Verify the system is still responsive after the flood: a fresh
        // subscriber must receive a newly sent event. (This replaces a
        // previous `assert!(true)` that asserted nothing.)
        let mut fresh = Box::pin(broadcaster.subscribe_stream());
        broadcaster.send(vec![255; 1024]);
        let received = timeout(Duration::from_millis(100), fresh.next())
            .await
            .unwrap();
        assert_eq!(received, Some(vec![255; 1024]));
    }

    #[tokio::test]
    async fn broadcaster_handles_subscriber_drop_gracefully() {
        let broadcaster = SseBroadcaster::<u32>::new(16);

        // Create and immediately drop a subscriber
        {
            let _subscriber = broadcaster.subscribe_stream();
            broadcaster.send(1);
        } // subscriber dropped here

        // Broadcaster should continue working with new subscribers
        let mut new_subscriber = Box::pin(broadcaster.subscribe_stream());
        broadcaster.send(2);

        let received = timeout(Duration::from_millis(100), new_subscriber.next())
            .await
            .unwrap();
        assert_eq!(received, Some(2));
    }

    #[tokio::test]
    async fn broadcaster_send_is_non_blocking() {
        let broadcaster = SseBroadcaster::<u32>::new(1); // Very small capacity

        // Send should not block even when no subscribers exist
        let start = std::time::Instant::now();
        for i in 0..1000 {
            broadcaster.send(i);
        }
        let elapsed = start.elapsed();

        // Should complete very quickly since send() doesn't block
        assert!(
            elapsed < Duration::from_millis(100),
            "Send operations took too long: {elapsed:?}"
        );
    }
}