1use axum::response::IntoResponse;
2use axum::response::sse::{Event, KeepAlive, Sse};
3use futures_core::Stream;
4use futures_util::StreamExt;
5use serde::Serialize;
6use std::{borrow::Cow, convert::Infallible, time::Duration};
7use tokio::sync::broadcast;
8use tokio_stream::wrappers::BroadcastStream;
9
10#[derive(Clone)]
14pub struct SseBroadcaster<T> {
15 tx: broadcast::Sender<T>,
16}
17
18impl<T: Clone + Send + 'static> SseBroadcaster<T> {
19 #[must_use]
21 pub fn new(capacity: usize) -> Self {
22 let (tx, _rx) = broadcast::channel(capacity);
23 Self { tx }
24 }
25
26 pub fn send(&self, value: T) {
29 if self.tx.send(value).is_err() {
30 tracing::trace!("SSE broadcast: no active receivers");
31 }
32 }
33
34 pub fn subscribe_stream(&self) -> impl Stream<Item = T> + use<T> {
36 BroadcastStream::new(self.tx.subscribe()).filter_map(|res| async move { res.ok() })
37 }
38
39 fn wrap_stream_as_sse<U>(stream: U) -> impl Stream<Item = Result<Event, Infallible>>
41 where
42 U: Stream<Item = T>,
43 T: Serialize,
44 {
45 stream.map(|msg| {
46 let ev = Event::default().json_data(&msg).unwrap_or_else(|_| {
47 Event::default().data("serialization_error")
49 });
50 Ok(ev)
51 })
52 }
53
54 fn wrap_stream_as_sse_named<U>(
56 stream: U,
57 event_name: Cow<'static, str>,
58 ) -> impl Stream<Item = Result<Event, Infallible>>
59 where
60 U: Stream<Item = T>,
61 T: Serialize,
62 {
63 stream.map(move |msg| {
64 let ev = Event::default()
65 .event(&event_name) .json_data(&msg)
67 .unwrap_or_else(|_| {
68 Event::default()
69 .event(&event_name)
70 .data("serialization_error")
71 });
72 Ok(ev)
73 })
74 }
75
76 pub fn sse_response(&self) -> Sse<impl Stream<Item = Result<Event, Infallible>> + use<T>>
83 where
84 T: Serialize,
85 {
86 let stream = Self::wrap_stream_as_sse(self.subscribe_stream());
87 Sse::new(stream).keep_alive(
88 KeepAlive::new()
89 .interval(Duration::from_secs(15))
90 .text("keepalive"),
91 )
92 }
93
94 pub fn sse_response_with_headers<I>(&self, headers: I) -> axum::response::Response
96 where
97 T: Serialize,
98 I: IntoIterator<Item = (axum::http::HeaderName, axum::http::HeaderValue)>,
99 {
100 let mut resp = self.sse_response().into_response();
101 let dst = resp.headers_mut();
102 for (name, value) in headers {
103 dst.insert(name, value);
104 }
105 resp
106 }
107
108 pub fn sse_response_named<N>(
114 &self,
115 event_name: N,
116 ) -> Sse<impl Stream<Item = Result<Event, Infallible>> + use<T, N>>
117 where
118 T: Serialize,
119 N: Into<Cow<'static, str>> + 'static,
120 {
121 let stream = Self::wrap_stream_as_sse_named(self.subscribe_stream(), event_name.into());
122 Sse::new(stream).keep_alive(
123 KeepAlive::new()
124 .interval(Duration::from_secs(15))
125 .text("keepalive"),
126 )
127 }
128
129 pub fn sse_response_named_with_headers<I>(
131 &self,
132 event_name: impl Into<Cow<'static, str>> + 'static,
133 headers: I,
134 ) -> axum::response::Response
135 where
136 T: Serialize,
137 I: IntoIterator<Item = (axum::http::HeaderName, axum::http::HeaderValue)>,
138 {
139 let mut resp = self.sse_response_named(event_name).into_response();
140 let dst = resp.headers_mut();
141 for (name, value) in headers {
142 dst.insert(name, value);
143 }
144 resp
145 }
146}
147
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use tokio::time::{Duration, timeout};

    #[tokio::test]
    async fn broadcaster_delivers_single_event() {
        let b = SseBroadcaster::<u32>::new(16);
        let mut sub = Box::pin(b.subscribe_stream());
        b.send(42);
        let v = timeout(Duration::from_millis(200), sub.next())
            .await
            .unwrap();
        assert_eq!(v, Some(42));
    }

    #[tokio::test]
    async fn broadcaster_handles_backpressure_with_bounded_channel() {
        let capacity = 4;
        let broadcaster = SseBroadcaster::<u32>::new(capacity);

        // Subscribe first, then overfill the channel before reading anything.
        let mut subscriber = Box::pin(broadcaster.subscribe_stream());

        let num_events = capacity * 2;
        for i in 0..num_events {
            broadcaster.send(u32::try_from(i).unwrap());
        }

        let mut received = Vec::new();
        for _ in 0..num_events {
            match timeout(Duration::from_millis(10), subscriber.next()).await {
                Ok(Some(event)) => received.push(event),
                Ok(None) | Err(_) => break,
            }
        }

        // The lagging subscriber skips the overwritten prefix and then sees
        // exactly the `capacity` newest events, in order.
        let expected: Vec<u32> = (capacity..num_events)
            .map(|i| u32::try_from(i).unwrap())
            .collect();
        assert_eq!(received, expected);
    }

    #[tokio::test]
    async fn broadcaster_handles_multiple_subscribers_with_backpressure() {
        let capacity = 8;
        let broadcaster = SseBroadcaster::<String>::new(capacity);

        // Both subscriptions exist before any event is produced.
        let mut fast_subscriber = Box::pin(broadcaster.subscribe_stream());
        let mut slow_subscriber = Box::pin(broadcaster.subscribe_stream());

        let events_sent = Arc::new(AtomicUsize::new(0));
        let events_sent_clone = events_sent.clone();

        // The producer owns the only broadcaster handle. Dropping it when the
        // task finishes closes the channel, which ends both subscriber streams.
        let producer = tokio::spawn(async move {
            for i in 0..50 {
                broadcaster.send(format!("event_{i}"));
                events_sent_clone.fetch_add(1, Ordering::SeqCst);
                tokio::task::yield_now().await;
            }
        });

        let fast_events = Arc::new(AtomicUsize::new(0));
        let fast_events_clone = fast_events.clone();
        let fast_consumer = tokio::spawn(async move {
            while let Ok(Some(_event)) =
                timeout(Duration::from_millis(100), fast_subscriber.next()).await
            {
                fast_events_clone.fetch_add(1, Ordering::SeqCst);
            }
        });

        let slow_events = Arc::new(AtomicUsize::new(0));
        let slow_events_clone = slow_events.clone();
        let slow_consumer = tokio::spawn(async move {
            while let Ok(Some(_event)) =
                timeout(Duration::from_millis(100), slow_subscriber.next()).await
            {
                slow_events_clone.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        });

        producer.await.unwrap();

        tokio::time::sleep(Duration::from_millis(200)).await;

        fast_consumer.abort();
        slow_consumer.abort();

        let total_sent = events_sent.load(Ordering::SeqCst);
        let fast_received = fast_events.load(Ordering::SeqCst);
        let slow_received = slow_events.load(Ordering::SeqCst);

        assert_eq!(total_sent, 50);
        assert!(fast_received > 0);
        assert!(slow_received > 0);

        println!(
            "Sent: {total_sent}, Fast received: {fast_received}, Slow received: {slow_received}"
        );
    }

    #[tokio::test]
    async fn broadcaster_prevents_unbounded_memory_growth() {
        let small_capacity = 2;
        let broadcaster = SseBroadcaster::<Vec<u8>>::new(small_capacity);

        // Subscribe before sending, so this receiver falls far behind.
        let mut subscriber = Box::pin(broadcaster.subscribe_stream());

        for i in 0..100 {
            let large_event = vec![u8::try_from(i).unwrap(); 1024];
            broadcaster.send(large_event);
        }
        broadcaster.send(vec![255; 1024]);

        // Drain whatever is retained; the broadcaster is still alive, so the
        // timeout (not end-of-stream) terminates the loop.
        let mut received = Vec::new();
        while let Ok(Some(event)) = timeout(Duration::from_millis(10), subscriber.next()).await {
            received.push(event);
        }

        // Only the newest `small_capacity` messages were kept for the lagging
        // subscriber; everything older was dropped rather than buffered.
        assert_eq!(received.len(), small_capacity);
        assert_eq!(received, vec![vec![99; 1024], vec![255; 1024]]);
    }

    #[tokio::test]
    async fn broadcaster_handles_subscriber_drop_gracefully() {
        let broadcaster = SseBroadcaster::<u32>::new(16);

        {
            // A subscriber that goes away without ever reading.
            let _subscriber = broadcaster.subscribe_stream();
            broadcaster.send(1);
        }

        // A later subscriber only sees values sent after it subscribed.
        let mut new_subscriber = Box::pin(broadcaster.subscribe_stream());
        broadcaster.send(2);

        let received = timeout(Duration::from_millis(100), new_subscriber.next())
            .await
            .unwrap();
        assert_eq!(received, Some(2));
    }

    #[tokio::test]
    async fn broadcaster_send_is_non_blocking() {
        // Capacity 1 and no subscribers: every send must return immediately.
        let broadcaster = SseBroadcaster::<u32>::new(1);
        let start = std::time::Instant::now();
        for i in 0..1000 {
            broadcaster.send(i);
        }
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(100),
            "Send operations took too long: {elapsed:?}"
        );
    }
}