1use axum::response::IntoResponse;
2use axum::response::sse::{Event, KeepAlive, Sse};
3use futures_core::Stream;
4use futures_util::StreamExt;
5use serde::Serialize;
6use std::{borrow::Cow, convert::Infallible, time::Duration};
7use tokio::sync::broadcast;
8use tokio_stream::wrappers::BroadcastStream;
9
10#[derive(Clone)]
14pub struct SseBroadcaster<T> {
15 tx: broadcast::Sender<T>,
16}
17
18impl<T: Clone + Send + 'static> SseBroadcaster<T> {
19 #[must_use]
21 pub fn new(capacity: usize) -> Self {
22 let (tx, _rx) = broadcast::channel(capacity);
23 Self { tx }
24 }
25
26 pub fn send(&self, value: T) {
29 let _ = self.tx.send(value);
30 }
31
32 pub fn subscribe_stream(&self) -> impl Stream<Item = T> + use<T> {
34 BroadcastStream::new(self.tx.subscribe()).filter_map(|res| async move { res.ok() })
35 }
36
37 fn wrap_stream_as_sse<U>(stream: U) -> impl Stream<Item = Result<Event, Infallible>>
39 where
40 U: Stream<Item = T>,
41 T: Serialize,
42 {
43 stream.map(|msg| {
44 let ev = Event::default().json_data(&msg).unwrap_or_else(|_| {
45 Event::default().data("serialization_error")
47 });
48 Ok(ev)
49 })
50 }
51
52 fn wrap_stream_as_sse_named<U>(
54 stream: U,
55 event_name: Cow<'static, str>,
56 ) -> impl Stream<Item = Result<Event, Infallible>>
57 where
58 U: Stream<Item = T>,
59 T: Serialize,
60 {
61 stream.map(move |msg| {
62 let ev = Event::default()
63 .event(&event_name) .json_data(&msg)
65 .unwrap_or_else(|_| {
66 Event::default()
67 .event(&event_name)
68 .data("serialization_error")
69 });
70 Ok(ev)
71 })
72 }
73
74 pub fn sse_response(&self) -> Sse<impl Stream<Item = Result<Event, Infallible>> + use<T>>
81 where
82 T: Serialize,
83 {
84 let stream = Self::wrap_stream_as_sse(self.subscribe_stream());
85 Sse::new(stream).keep_alive(
86 KeepAlive::new()
87 .interval(Duration::from_secs(15))
88 .text("keepalive"),
89 )
90 }
91
92 pub fn sse_response_with_headers<I>(&self, headers: I) -> axum::response::Response
94 where
95 T: Serialize,
96 I: IntoIterator<Item = (axum::http::HeaderName, axum::http::HeaderValue)>,
97 {
98 let mut resp = self.sse_response().into_response();
99 let dst = resp.headers_mut();
100 for (name, value) in headers {
101 dst.insert(name, value);
102 }
103 resp
104 }
105
106 pub fn sse_response_named<N>(
112 &self,
113 event_name: N,
114 ) -> Sse<impl Stream<Item = Result<Event, Infallible>> + use<T, N>>
115 where
116 T: Serialize,
117 N: Into<Cow<'static, str>> + 'static,
118 {
119 let stream = Self::wrap_stream_as_sse_named(self.subscribe_stream(), event_name.into());
120 Sse::new(stream).keep_alive(
121 KeepAlive::new()
122 .interval(Duration::from_secs(15))
123 .text("keepalive"),
124 )
125 }
126
127 pub fn sse_response_named_with_headers<I>(
129 &self,
130 event_name: impl Into<Cow<'static, str>> + 'static,
131 headers: I,
132 ) -> axum::response::Response
133 where
134 T: Serialize,
135 I: IntoIterator<Item = (axum::http::HeaderName, axum::http::HeaderValue)>,
136 {
137 let mut resp = self.sse_response_named(event_name).into_response();
138 let dst = resp.headers_mut();
139 for (name, value) in headers {
140 dst.insert(name, value);
141 }
142 resp
143 }
144}
145
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use tokio::time::{Duration, timeout};

    /// A value sent after subscribing is delivered to the subscriber.
    #[tokio::test]
    async fn broadcaster_delivers_single_event() {
        let b = SseBroadcaster::<u32>::new(16);
        let mut sub = Box::pin(b.subscribe_stream());
        b.send(42);
        let v = timeout(Duration::from_millis(200), sub.next())
            .await
            .unwrap();
        assert_eq!(v, Some(42));
    }

    /// Overfilling the channel drops old events for a lagging subscriber but
    /// keeps the surviving events in order.
    #[tokio::test]
    async fn broadcaster_handles_backpressure_with_bounded_channel() {
        let capacity = 4;
        let broadcaster = SseBroadcaster::<u32>::new(capacity);

        let mut subscriber = Box::pin(broadcaster.subscribe_stream());

        // Send twice as many events as the channel can retain.
        let num_events = capacity * 2;
        for i in 0..num_events {
            broadcaster.send(u32::try_from(i).unwrap());
        }

        let mut received = Vec::new();

        for _ in 0..num_events {
            match timeout(Duration::from_millis(10), subscriber.next()).await {
                Ok(Some(event)) => received.push(event),
                // Stream ended or nothing more is buffered.
                Ok(None) | Err(_) => break,
            }
        }

        assert!(!received.is_empty());
        assert!(received.len() <= num_events);

        for window in received.windows(2) {
            assert!(window[0] < window[1], "Events should be in order");
        }
    }

    /// A slow subscriber must not prevent the producer or a fast subscriber
    /// from making progress.
    #[tokio::test]
    async fn broadcaster_handles_multiple_subscribers_with_backpressure() {
        let capacity = 8;
        let broadcaster = SseBroadcaster::<String>::new(capacity);

        let mut fast_subscriber = Box::pin(broadcaster.subscribe_stream());
        let mut slow_subscriber = Box::pin(broadcaster.subscribe_stream());

        let events_sent = Arc::new(AtomicUsize::new(0));
        let events_sent_clone = Arc::clone(&events_sent);

        let producer = tokio::spawn(async move {
            for i in 0..50 {
                broadcaster.send(format!("event_{i}"));
                events_sent_clone.fetch_add(1, Ordering::SeqCst);
                // Give the consumers a chance to run between sends.
                tokio::task::yield_now().await;
            }
        });

        let fast_events = Arc::new(AtomicUsize::new(0));
        let fast_events_clone = Arc::clone(&fast_events);
        let fast_consumer = tokio::spawn(async move {
            while let Ok(Some(_event)) =
                timeout(Duration::from_millis(100), fast_subscriber.next()).await
            {
                fast_events_clone.fetch_add(1, Ordering::SeqCst);
            }
        });

        let slow_events = Arc::new(AtomicUsize::new(0));
        let slow_events_clone = Arc::clone(&slow_events);
        let slow_consumer = tokio::spawn(async move {
            while let Ok(Some(_event)) =
                timeout(Duration::from_millis(100), slow_subscriber.next()).await
            {
                slow_events_clone.fetch_add(1, Ordering::SeqCst);
                // Simulate a consumer that cannot keep up with the producer.
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        });

        producer.await.unwrap();

        // Let the consumers drain whatever is still buffered.
        tokio::time::sleep(Duration::from_millis(200)).await;

        fast_consumer.abort();
        slow_consumer.abort();

        let total_sent = events_sent.load(Ordering::SeqCst);
        let fast_received = fast_events.load(Ordering::SeqCst);
        let slow_received = slow_events.load(Ordering::SeqCst);

        assert_eq!(total_sent, 50);

        assert!(fast_received > 0);
        assert!(slow_received > 0);

        println!(
            "Sent: {total_sent}, Fast received: {fast_received}, Slow received: {slow_received}"
        );
    }

    /// An idle subscriber only pins a bounded number of messages: after many
    /// sends, draining it yields at most the channel capacity, ending with
    /// the most recent message.
    #[tokio::test]
    async fn broadcaster_prevents_unbounded_memory_growth() {
        let small_capacity: usize = 2;
        let broadcaster = SseBroadcaster::<Vec<u8>>::new(small_capacity);

        // Subscribed but never polled while the producer is running.
        let subscriber = broadcaster.subscribe_stream();

        for i in 0..100 {
            let large_event = vec![u8::try_from(i).unwrap(); 1024];
            broadcaster.send(large_event);
        }

        broadcaster.send(vec![255; 1024]);

        // Dropping the only sender closes the channel, so the stream ends
        // once the retained messages have been drained.
        drop(broadcaster);
        let retained: Vec<Vec<u8>> =
            timeout(Duration::from_secs(1), subscriber.collect::<Vec<_>>())
                .await
                .expect("stream should end once the sender is dropped");

        assert!(!retained.is_empty());
        // The channel may round its capacity up to the next power of two.
        assert!(
            retained.len() <= small_capacity.next_power_of_two(),
            "retained {} messages with capacity {small_capacity}",
            retained.len()
        );
        assert_eq!(retained.last(), Some(&vec![255; 1024]));
    }

    /// Dropping one subscriber does not affect later subscribers, which only
    /// see messages sent after they subscribed.
    #[tokio::test]
    async fn broadcaster_handles_subscriber_drop_gracefully() {
        let broadcaster = SseBroadcaster::<u32>::new(16);

        {
            let _subscriber = broadcaster.subscribe_stream();
            broadcaster.send(1);
        }

        let mut new_subscriber = Box::pin(broadcaster.subscribe_stream());
        broadcaster.send(2);

        let received = timeout(Duration::from_millis(100), new_subscriber.next())
            .await
            .unwrap();
        assert_eq!(received, Some(2));
    }

    /// Sending never waits for subscribers, even with a tiny capacity and
    /// nobody listening.
    #[tokio::test]
    async fn broadcaster_send_is_non_blocking() {
        let broadcaster = SseBroadcaster::<u32>::new(1);

        let start = std::time::Instant::now();
        for i in 0..1000 {
            broadcaster.send(i);
        }
        let elapsed = start.elapsed();

        assert!(
            elapsed < Duration::from_millis(100),
            "Send operations took too long: {elapsed:?}"
        );
    }
}