1use crate::error::Error;
2use axum::response::{IntoResponse, Response};
3use futures_util::{FutureExt, Stream, StreamExt};
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::pin::Pin;
7use std::sync::{Arc, RwLock};
8use std::task::{Context, Poll};
9use tokio::sync::broadcast;
10
11use super::config::SseConfig;
12use super::event::Event;
13
/// What a [`BroadcastStream`] does when its receiver falls behind the channel
/// and values are overwritten.
///
/// Selected with [`BroadcastStream::on_lag`]. Without a policy, the lag is
/// surfaced to the consumer as an `Err`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LagPolicy {
    /// End the stream at the first lag.
    End,
    /// Log the lag at `warn` level and keep receiving.
    Skip,
}
22
23pub struct BroadcastStream<T> {
35 inner: Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>,
39 _cleanup: Option<Box<dyn FnOnce() + Send>>,
40}
41
42impl<T: Clone + Send + 'static> BroadcastStream<T> {
43 pub fn new(rx: broadcast::Receiver<T>) -> Self {
52 Self {
53 inner: Box::pin(unfold_default(rx)),
54 _cleanup: None,
55 }
56 }
57
58 pub(crate) fn with_cleanup(
60 rx: broadcast::Receiver<T>,
61 cleanup: impl FnOnce() + Send + 'static,
62 ) -> Self {
63 Self {
64 inner: Box::pin(unfold_default(rx)),
65 _cleanup: Some(Box::new(cleanup)),
66 }
67 }
68
69 pub fn on_lag(mut self, policy: LagPolicy) -> Self {
81 let original = std::mem::replace(&mut self.inner, Box::pin(futures_util::stream::empty()));
84 self.inner = Box::pin(apply_lag_policy(original, policy));
85 self
86 }
87}
88
89fn unfold_default<T: Clone + Send + 'static>(
91 rx: broadcast::Receiver<T>,
92) -> impl Stream<Item = Result<T, Error>> {
93 futures_util::stream::unfold(rx, |mut rx| async move {
94 match rx.recv().await {
95 Ok(item) => Some((Ok(item), rx)),
96 Err(broadcast::error::RecvError::Lagged(n)) => Some((Err(Error::lagged(n)), rx)),
97 Err(broadcast::error::RecvError::Closed) => None,
98 }
99 })
100}
101
102fn apply_lag_policy<T: Send + 'static>(
104 stream: Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>,
105 policy: LagPolicy,
106) -> impl Stream<Item = Result<T, Error>> + Send {
107 futures_util::stream::unfold(stream, move |mut stream| async move {
108 use futures_util::StreamExt;
109 loop {
110 match stream.next().await {
111 Some(Ok(item)) => return Some((Ok(item), stream)),
112 Some(Err(e)) if e.is_lagged() => match policy {
113 LagPolicy::End => return None,
114 LagPolicy::Skip => {
115 tracing::warn!("{e}");
116 continue;
117 }
118 },
119 Some(Err(e)) => return Some((Err(e), stream)),
120 None => return None,
121 }
122 }
123 })
124}
125
126impl<T> Drop for BroadcastStream<T> {
127 fn drop(&mut self) {
128 if let Some(cleanup) = self._cleanup.take() {
129 cleanup();
130 }
131 }
132}
133
134impl<T> Stream for BroadcastStream<T> {
135 type Item = Result<T, Error>;
136
137 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
138 self.inner.as_mut().poll_next(cx)
139 }
140}
141
142pub fn replay<T>(items: Vec<T>) -> impl Stream<Item = Result<T, Error>> + Send
151where
152 T: Send + 'static,
153{
154 futures_util::stream::iter(items.into_iter().map(Ok))
155}
156
157struct BroadcasterInner<K, T> {
158 channels: RwLock<HashMap<K, broadcast::Sender<T>>>,
159 buffer: usize,
160 config: SseConfig,
161}
162
163pub struct Broadcaster<K, T>
188where
189 K: Hash + Eq + Clone + Send + Sync + 'static,
190 T: Clone + Send + Sync + 'static,
191{
192 inner: Arc<BroadcasterInner<K, T>>,
193}
194
195impl<K, T> Clone for Broadcaster<K, T>
196where
197 K: Hash + Eq + Clone + Send + Sync + 'static,
198 T: Clone + Send + Sync + 'static,
199{
200 fn clone(&self) -> Self {
201 Self {
202 inner: Arc::clone(&self.inner),
203 }
204 }
205}
206
207impl<K, T> Broadcaster<K, T>
208where
209 K: Hash + Eq + Clone + Send + Sync + 'static,
210 T: Clone + Send + Sync + 'static,
211{
212 pub fn new(buffer: usize, config: SseConfig) -> Self {
219 Self {
220 inner: Arc::new(BroadcasterInner {
221 channels: RwLock::new(HashMap::new()),
222 buffer,
223 config,
224 }),
225 }
226 }
227
228 pub fn subscribe(&self, key: &K) -> BroadcastStream<T> {
234 let mut channels = self
235 .inner
236 .channels
237 .write()
238 .unwrap_or_else(|e| e.into_inner());
239
240 let sender = channels
241 .entry(key.clone())
242 .or_insert_with(|| broadcast::channel(self.inner.buffer).0);
243 let rx = sender.subscribe();
244
245 let inner_ref = Arc::clone(&self.inner);
246 let key_owned = key.clone();
247 let cleanup = move || {
248 let mut channels = inner_ref
249 .channels
250 .write()
251 .unwrap_or_else(|e| e.into_inner());
252 if let std::collections::hash_map::Entry::Occupied(entry) = channels.entry(key_owned)
253 && entry.get().receiver_count() == 0
254 {
255 entry.remove();
256 }
257 };
258
259 BroadcastStream::with_cleanup(rx, cleanup)
260 }
261
262 pub fn send(&self, key: &K, event: T) -> usize {
267 let channels = self
268 .inner
269 .channels
270 .read()
271 .unwrap_or_else(|e| e.into_inner());
272 if let Some(sender) = channels.get(key) {
273 match sender.send(event) {
274 Ok(count) => count,
275 Err(_) => {
276 drop(channels);
277 let mut channels = self
278 .inner
279 .channels
280 .write()
281 .unwrap_or_else(|e| e.into_inner());
282 if let std::collections::hash_map::Entry::Occupied(entry) =
283 channels.entry(key.clone())
284 && entry.get().receiver_count() == 0
285 {
286 entry.remove();
287 }
288 0
289 }
290 }
291 } else {
292 0
293 }
294 }
295
296 pub fn subscriber_count(&self, key: &K) -> usize {
298 let channels = self
299 .inner
300 .channels
301 .read()
302 .unwrap_or_else(|e| e.into_inner());
303 channels.get(key).map(|s| s.receiver_count()).unwrap_or(0)
304 }
305
306 pub fn remove(&self, key: &K) {
311 let mut channels = self
312 .inner
313 .channels
314 .write()
315 .unwrap_or_else(|e| e.into_inner());
316 channels.remove(key);
317 }
318
319 pub fn config(&self) -> &SseConfig {
321 &self.inner.config
322 }
323
324 pub fn channel<F, Fut>(&self, f: F) -> Response
334 where
335 F: FnOnce(super::Sender) -> Fut + Send + 'static,
336 Fut: std::future::Future<Output = Result<(), Error>> + Send,
337 {
338 const CHANNEL_BUFFER: usize = 32;
339 let (tx, rx) = tokio::sync::mpsc::channel(CHANNEL_BUFFER);
340 let sender = super::Sender { tx };
341
342 tokio::spawn(async move {
343 let result = std::panic::AssertUnwindSafe(f(sender)).catch_unwind().await;
344 match result {
345 Ok(Ok(())) => {}
346 Ok(Err(e)) => {
347 tracing::debug!(error = %e, "SSE channel closure ended with error")
348 }
349 Err(_) => tracing::error!("SSE channel closure panicked"),
350 }
351 });
352
353 let stream = futures_util::stream::unfold(rx, |mut rx| async move {
355 rx.recv().await.map(|event| (Ok(event), rx))
356 });
357
358 self.response(stream)
359 }
360
361 pub fn response<S>(&self, stream: S) -> Response
366 where
367 S: Stream<Item = Result<Event, Error>> + Send + 'static,
368 {
369 let mapped = stream.map(|result| {
370 result
371 .map(axum::response::sse::Event::from)
372 .map_err(axum::Error::new)
373 });
374
375 let keep_alive =
376 axum::response::sse::KeepAlive::new().interval(self.inner.config.keep_alive_interval());
377
378 let mut resp = axum::response::sse::Sse::new(mapped)
379 .keep_alive(keep_alive)
380 .into_response();
381
382 resp.headers_mut()
383 .insert("x-accel-buffering", http::HeaderValue::from_static("no"));
384
385 resp
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use tokio::sync::broadcast;

    #[tokio::test]
    async fn stream_yields_sent_values() {
        let (tx, rx) = broadcast::channel(16);
        let mut stream = BroadcastStream::new(rx);
        tx.send("hello".to_string()).unwrap();
        tx.send("world".to_string()).unwrap();
        drop(tx);

        let items: Vec<String> = stream
            .by_ref()
            .filter_map(|r| async { r.ok() })
            .collect()
            .await;
        assert_eq!(items, vec!["hello", "world"]);
    }

    #[tokio::test]
    async fn stream_ends_when_sender_dropped() {
        let (tx, rx) = broadcast::channel(16);
        let mut stream = BroadcastStream::new(rx);
        tx.send(1).unwrap();
        drop(tx);

        // Values buffered before the last sender was dropped are still delivered...
        assert_eq!(stream.next().await.unwrap().unwrap(), 1);
        // ...and only then does the stream end.
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn lag_policy_skip_continues_after_lag() {
        let (tx, rx) = broadcast::channel(2);
        let mut stream = BroadcastStream::new(rx).on_lag(LagPolicy::Skip);

        // Capacity 2: the third send overwrites `1`, so the receiver lags.
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap();

        // The lag error is swallowed and the stream resumes at the oldest value
        // still buffered.
        let item = stream.next().await.unwrap();
        assert_eq!(item.unwrap(), 2);
    }

    #[tokio::test]
    async fn lag_policy_end_terminates_on_lag() {
        let (tx, rx) = broadcast::channel(2);
        let mut stream = BroadcastStream::new(rx).on_lag(LagPolicy::End);

        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap();

        let item = stream.next().await;
        assert!(item.is_none());
    }

    #[tokio::test]
    async fn default_lag_policy_propagates_error() {
        let (tx, rx) = broadcast::channel(2);
        let mut stream = BroadcastStream::new(rx);

        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap();

        let item = stream.next().await.unwrap();
        assert!(item.is_err());
        assert!(item.unwrap_err().is_lagged());

        // The lag error does not end the stream; it continues with the oldest
        // retained value.
        assert_eq!(stream.next().await.unwrap().unwrap(), 2);
    }

    #[tokio::test]
    async fn replay_yields_all_items() {
        let items = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let stream = replay(items);
        let collected: Vec<String> = stream.filter_map(|r| async { r.ok() }).collect().await;
        assert_eq!(collected, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn replay_empty_vec() {
        let stream = replay::<String>(vec![]);
        let collected: Vec<String> = stream.filter_map(|r| async { r.ok() }).collect().await;
        assert!(collected.is_empty());
    }

    #[tokio::test]
    async fn cleanup_fires_on_drop() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};

        let (tx, rx) = broadcast::channel::<i32>(16);
        let cleaned = Arc::new(AtomicBool::new(false));
        let cleaned_clone = cleaned.clone();

        let stream = BroadcastStream::with_cleanup(rx, move || {
            cleaned_clone.store(true, Ordering::SeqCst);
        });

        drop(stream);
        assert!(cleaned.load(Ordering::SeqCst));
        drop(tx);
    }

    #[tokio::test]
    async fn broadcaster_subscribe_and_send() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());
        let key = "room1".to_string();

        let mut stream = bc.subscribe(&key);
        assert_eq!(bc.subscriber_count(&key), 1);

        let count = bc.send(&key, "hello".into());
        assert_eq!(count, 1);

        let item = stream.next().await.unwrap().unwrap();
        assert_eq!(item, "hello");
    }

    #[tokio::test]
    async fn broadcaster_send_to_nonexistent_key_returns_zero() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());
        let count = bc.send(&"nobody".into(), "hello".into());
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn broadcaster_multiple_subscribers() {
        let bc: Broadcaster<String, i32> = Broadcaster::new(16, SseConfig::default());
        let key = "k".to_string();

        let mut s1 = bc.subscribe(&key);
        let mut s2 = bc.subscribe(&key);
        assert_eq!(bc.subscriber_count(&key), 2);

        bc.send(&key, 42);
        assert_eq!(s1.next().await.unwrap().unwrap(), 42);
        assert_eq!(s2.next().await.unwrap().unwrap(), 42);
    }

    #[tokio::test]
    async fn broadcaster_auto_cleanup_on_last_drop() {
        let bc: Broadcaster<String, i32> = Broadcaster::new(16, SseConfig::default());
        let key = "cleanup".to_string();

        let s1 = bc.subscribe(&key);
        let s2 = bc.subscribe(&key);
        assert_eq!(bc.subscriber_count(&key), 2);

        drop(s1);
        assert_eq!(bc.subscriber_count(&key), 1);

        drop(s2);
        assert_eq!(bc.subscriber_count(&key), 0);
    }

    #[tokio::test]
    async fn broadcaster_remove_disconnects_subscribers() {
        let bc: Broadcaster<String, i32> = Broadcaster::new(16, SseConfig::default());
        let key = "rm".to_string();

        let mut stream = bc.subscribe(&key);
        bc.remove(&key);

        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn broadcaster_clone_shares_state() {
        let bc1: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());
        let bc2 = bc1.clone();
        let key = "shared".to_string();

        let mut stream = bc1.subscribe(&key);
        bc2.send(&key, "from_clone".into());

        let item = stream.next().await.unwrap().unwrap();
        assert_eq!(item, "from_clone");
    }

    #[tokio::test]
    async fn broadcaster_channel_produces_events() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());

        let response = bc.channel(|tx| async move {
            tx.send(super::Event::new("e1", "test").unwrap().data("hello"))
                .await?;
            tx.send(super::Event::new("e2", "test").unwrap().data("world"))
                .await?;
            Ok(())
        });

        assert_eq!(response.headers().get("x-accel-buffering").unwrap(), "no");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/event-stream"
        );
    }

    #[test]
    fn broadcaster_config_accessible() {
        let config = SseConfig {
            keep_alive_interval_secs: 30,
        };
        let bc: Broadcaster<String, String> = Broadcaster::new(64, config);
        assert_eq!(bc.config().keep_alive_interval_secs, 30);
    }

    #[tokio::test]
    async fn broadcaster_response_returns_valid_response() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());
        let stream = futures_util::stream::empty::<Result<super::Event, crate::error::Error>>();
        let response = bc.response(stream);
        assert_eq!(response.headers().get("x-accel-buffering").unwrap(), "no");
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/event-stream"
        );
    }

    #[tokio::test]
    async fn channel_closure_error_produces_valid_response() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());

        let response =
            bc.channel(|_tx| async move { Err(crate::error::Error::internal("deliberate error")) });

        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/event-stream"
        );
        assert_eq!(response.headers().get("x-accel-buffering").unwrap(), "no");
    }

    #[tokio::test]
    async fn channel_closure_panic_produces_valid_response() {
        let bc: Broadcaster<String, String> = Broadcaster::new(16, SseConfig::default());

        let response = bc.channel(|_tx| async move {
            panic!("deliberate panic");
        });

        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "text/event-stream"
        );
        assert_eq!(response.headers().get("x-accel-buffering").unwrap(), "no");
    }

    #[tokio::test]
    async fn concurrent_subscribe_and_send() {
        let bc: Broadcaster<String, i32> = Broadcaster::new(256, SseConfig::default());
        let key = "concurrent".to_string();

        let mut set = tokio::task::JoinSet::new();

        for task_num in 0..10 {
            let bc = bc.clone();
            let key = key.clone();
            set.spawn(async move {
                let mut stream = bc.subscribe(&key);
                bc.send(&key, task_num);
                stream.next().await.unwrap().unwrap()
            });
        }

        let mut results = Vec::new();
        while let Some(result) = set.join_next().await {
            results.push(result.expect("Task panicked"));
        }

        assert_eq!(results.len(), 10);
    }
}