//! Typed Server-Sent Events (SSE) support (modkit/http/sse.rs).
use axum::response::IntoResponse;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures_core::Stream;
use futures_util::StreamExt;
use serde::Serialize;
use std::{borrow::Cow, convert::Infallible, time::Duration};
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
9
10/// Small typed SSE broadcaster built on `tokio::sync::broadcast`.
11/// - T must be `Clone` so multiple subscribers can receive the same payload.
12/// - Bounded channel drops oldest events when subscribers lag (by design).
13#[derive(Clone)]
14pub struct SseBroadcaster<T> {
15    tx: broadcast::Sender<T>,
16}
17
18impl<T: Clone + Send + 'static> SseBroadcaster<T> {
19    /// Create a broadcaster with bounded buffer capacity.
20    #[must_use]
21    pub fn new(capacity: usize) -> Self {
22        let (tx, _rx) = broadcast::channel(capacity);
23        Self { tx }
24    }
25
26    /// Broadcast a single message to current subscribers.
27    /// Errors are ignored to keep the hot path cheap (e.g., no active subscribers).
28    pub fn send(&self, value: T) {
29        if self.tx.send(value).is_err() {
30            tracing::trace!("SSE broadcast: no active receivers");
31        }
32    }
33
34    /// Subscribe to a typed stream of messages; lag/drop errors are filtered out.
35    pub fn subscribe_stream(&self) -> impl Stream<Item = T> + use<T> {
36        BroadcastStream::new(self.tx.subscribe()).filter_map(|res| async move { res.ok() })
37    }
38
39    /// Convert a typed stream into an SSE stream with JSON payloads (no event name).
40    fn wrap_stream_as_sse<U>(stream: U) -> impl Stream<Item = Result<Event, Infallible>>
41    where
42        U: Stream<Item = T>,
43        T: Serialize,
44    {
45        stream.map(|msg| {
46            let ev = Event::default().json_data(&msg).unwrap_or_else(|_| {
47                // Fallback to a tiny text marker instead of breaking the stream.
48                Event::default().data("serialization_error")
49            });
50            Ok(ev)
51        })
52    }
53
54    /// Convert a typed stream into an SSE stream with JSON payloads and a constant `event:` name.
55    fn wrap_stream_as_sse_named<U>(
56        stream: U,
57        event_name: Cow<'static, str>,
58    ) -> impl Stream<Item = Result<Event, Infallible>>
59    where
60        U: Stream<Item = T>,
61        T: Serialize,
62    {
63        stream.map(move |msg| {
64            let ev = Event::default()
65                .event(&event_name) // <-- set event name
66                .json_data(&msg)
67                .unwrap_or_else(|_| {
68                    Event::default()
69                        .event(&event_name)
70                        .data("serialization_error")
71                });
72            Ok(ev)
73        })
74    }
75
76    // -------------------------
77    // Plain (unnamed) variants
78    // -------------------------
79
80    /// Plain SSE (no extra headers), unnamed events.
81    /// Includes periodic keepalive pings to avoid idle timeouts.
82    pub fn sse_response(&self) -> Sse<impl Stream<Item = Result<Event, Infallible>> + use<T>>
83    where
84        T: Serialize,
85    {
86        let stream = Self::wrap_stream_as_sse(self.subscribe_stream());
87        Sse::new(stream).keep_alive(
88            KeepAlive::new()
89                .interval(Duration::from_secs(15))
90                .text("keepalive"),
91        )
92    }
93
94    /// SSE with custom headers applied on top of the Sse response (unnamed events).
95    pub fn sse_response_with_headers<I>(&self, headers: I) -> axum::response::Response
96    where
97        T: Serialize,
98        I: IntoIterator<Item = (axum::http::HeaderName, axum::http::HeaderValue)>,
99    {
100        let mut resp = self.sse_response().into_response();
101        let dst = resp.headers_mut();
102        for (name, value) in headers {
103            dst.insert(name, value);
104        }
105        resp
106    }
107
108    // -------------------------
109    // Named-event variants
110    // -------------------------
111
112    /// Plain SSE with a constant `event:` name for all messages (no extra headers).
113    pub fn sse_response_named<N>(
114        &self,
115        event_name: N,
116    ) -> Sse<impl Stream<Item = Result<Event, Infallible>> + use<T, N>>
117    where
118        T: Serialize,
119        N: Into<Cow<'static, str>> + 'static,
120    {
121        let stream = Self::wrap_stream_as_sse_named(self.subscribe_stream(), event_name.into());
122        Sse::new(stream).keep_alive(
123            KeepAlive::new()
124                .interval(Duration::from_secs(15))
125                .text("keepalive"),
126        )
127    }
128
129    /// SSE with custom headers and a constant `event:` name for all messages.
130    pub fn sse_response_named_with_headers<I>(
131        &self,
132        event_name: impl Into<Cow<'static, str>> + 'static,
133        headers: I,
134    ) -> axum::response::Response
135    where
136        T: Serialize,
137        I: IntoIterator<Item = (axum::http::HeaderName, axum::http::HeaderValue)>,
138    {
139        let mut resp = self.sse_response_named(event_name).into_response();
140        let dst = resp.headers_mut();
141        for (name, value) in headers {
142            dst.insert(name, value);
143        }
144        resp
145    }
146}
147
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use tokio::time::{Duration, timeout};

    /// A single sent value reaches a single subscriber.
    #[tokio::test]
    async fn broadcaster_delivers_single_event() {
        let b = SseBroadcaster::<u32>::new(16);
        let mut sub = Box::pin(b.subscribe_stream());
        b.send(42);
        let v = timeout(Duration::from_millis(200), sub.next())
            .await
            .unwrap();
        assert_eq!(v, Some(42));
    }

    /// The bounded channel drops old events when capacity is exceeded; the
    /// subscriber still sees the surviving events in order.
    #[tokio::test]
    async fn broadcaster_handles_backpressure_with_bounded_channel() {
        let capacity = 4;
        let broadcaster = SseBroadcaster::<u32>::new(capacity);

        // Create a slow consumer that doesn't read immediately.
        let mut subscriber = Box::pin(broadcaster.subscribe_stream());

        // Send more events than capacity so the oldest ones get dropped.
        let num_events = capacity * 2;
        for i in 0..num_events {
            broadcaster.send(u32::try_from(i).unwrap());
        }

        // The subscriber should only receive the most recent events
        // due to the bounded channel dropping older ones.
        let mut received = Vec::new();

        // Try to receive all events with a timeout.
        for _ in 0..num_events {
            match timeout(Duration::from_millis(10), subscriber.next()).await {
                Ok(Some(event)) => received.push(event),
                Ok(None) | Err(_) => break, // stream ended or timed out
            }
        }

        // Should have received some events, but not necessarily all,
        // due to backpressure handling.
        assert!(!received.is_empty());
        assert!(received.len() <= num_events);

        // The events we did receive should be in order (drops may create
        // gaps, but never reordering).
        for window in received.windows(2) {
            assert!(window[0] < window[1], "Events should be in order");
        }
    }

    /// Two subscribers with different consumption rates both make progress
    /// while a producer sends rapidly; the system stays stable under lag.
    #[tokio::test]
    async fn broadcaster_handles_multiple_subscribers_with_backpressure() {
        let capacity = 8;
        let broadcaster = SseBroadcaster::<String>::new(capacity);

        // Create multiple subscribers with different consumption rates.
        let mut fast_subscriber = Box::pin(broadcaster.subscribe_stream());
        let mut slow_subscriber = Box::pin(broadcaster.subscribe_stream());

        let events_sent = Arc::new(AtomicUsize::new(0));
        let events_sent_clone = events_sent.clone();

        // Producer task - sends events rapidly.
        let producer = tokio::spawn(async move {
            for i in 0..50 {
                broadcaster.send(format!("event_{i}"));
                events_sent_clone.fetch_add(1, Ordering::SeqCst);
                tokio::task::yield_now().await; // Allow other tasks to run
            }
        });

        // Fast consumer task: drains as quickly as events arrive.
        let fast_events = Arc::new(AtomicUsize::new(0));
        let fast_events_clone = fast_events.clone();
        let fast_consumer = tokio::spawn(async move {
            while let Ok(Some(_event)) =
                timeout(Duration::from_millis(100), fast_subscriber.next()).await
            {
                fast_events_clone.fetch_add(1, Ordering::SeqCst);
            }
        });

        // Slow consumer task: sleeps between events to force lag.
        let slow_events = Arc::new(AtomicUsize::new(0));
        let slow_events_clone = slow_events.clone();
        let slow_consumer = tokio::spawn(async move {
            while let Ok(Some(_event)) =
                timeout(Duration::from_millis(100), slow_subscriber.next()).await
            {
                slow_events_clone.fetch_add(1, Ordering::SeqCst);
                // Simulate slow processing.
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        });

        // Wait for producer to finish.
        producer.await.unwrap();

        // Give consumers time to process.
        tokio::time::sleep(Duration::from_millis(200)).await;

        // Cancel consumers.
        fast_consumer.abort();
        slow_consumer.abort();

        let total_sent = events_sent.load(Ordering::SeqCst);
        let fast_received = fast_events.load(Ordering::SeqCst);
        let slow_received = slow_events.load(Ordering::SeqCst);

        assert_eq!(total_sent, 50);

        // Both consumers must have made progress; due to the bounded channel,
        // neither necessarily receives all events, and the slow consumer may
        // drop more, but the system should remain stable.
        assert!(fast_received > 0);
        assert!(slow_received > 0);

        println!(
            "Sent: {total_sent}, Fast received: {fast_received}, Slow received: {slow_received}"
        );
    }

    /// Overflowing a tiny channel with large payloads must not accumulate
    /// unbounded memory: old events are dropped automatically. Reaching the
    /// end of the test without a panic is the success criterion.
    #[tokio::test]
    async fn broadcaster_prevents_unbounded_memory_growth() {
        let small_capacity = 2;
        let broadcaster = SseBroadcaster::<Vec<u8>>::new(small_capacity);

        // Create a subscriber but never consume from it.
        let _subscriber = broadcaster.subscribe_stream();

        // Send many large events (1 KiB each), far exceeding capacity.
        for i in 0..100 {
            let large_event = vec![u8::try_from(i).unwrap(); 1024];
            broadcaster.send(large_event);
        }

        // Verify we can still send and the system is responsive after the
        // buffer has been overflowed many times over.
        broadcaster.send(vec![255; 1024]);
    }

    /// Dropping a subscriber must not poison the broadcaster: later
    /// subscribers still receive new events.
    #[tokio::test]
    async fn broadcaster_handles_subscriber_drop_gracefully() {
        let broadcaster = SseBroadcaster::<u32>::new(16);

        // Create and immediately drop a subscriber.
        {
            let _subscriber = broadcaster.subscribe_stream();
            broadcaster.send(1);
        } // subscriber dropped here

        // Broadcaster should continue working with new subscribers.
        let mut new_subscriber = Box::pin(broadcaster.subscribe_stream());
        broadcaster.send(2);

        let received = timeout(Duration::from_millis(100), new_subscriber.next())
            .await
            .unwrap();
        assert_eq!(received, Some(2));
    }

    /// `send` must never block, even with a tiny capacity and no readers.
    #[tokio::test]
    async fn broadcaster_send_is_non_blocking() {
        let broadcaster = SseBroadcaster::<u32>::new(1); // Very small capacity

        // Send should not block even when no subscribers exist.
        let start = std::time::Instant::now();
        for i in 0..1000 {
            broadcaster.send(i);
        }
        let elapsed = start.elapsed();

        // Should complete very quickly since send() doesn't block.
        assert!(
            elapsed < Duration::from_millis(100),
            "Send operations took too long: {elapsed:?}"
        );
    }
}