1use futures::stream::{self, StreamExt};
12use tokio::sync::{broadcast, mpsc};
13
14use crate::source::Source;
15
16pub struct BroadcastHub<T: Clone + Send + 'static> {
20 sender: broadcast::Sender<T>,
21}
22
23impl<T: Clone + Send + 'static> BroadcastHub<T> {
24 pub fn new(buffer_size: usize) -> Self {
25 assert!(buffer_size >= 1, "buffer_size must be >= 1");
26 let (sender, _rx) = broadcast::channel(buffer_size);
27 Self { sender }
28 }
29
30 pub fn attach(&self, source: Source<T>) {
33 let tx = self.sender.clone();
34 tokio::spawn(async move {
35 let mut s = source.into_boxed();
36 while let Some(item) = s.next().await {
37 let _ = tx.send(item); }
39 });
40 }
41
42 pub fn consumer(&self) -> Source<T> {
46 let rx = self.sender.subscribe();
47 let stream = stream::unfold(rx, |mut rx| async move {
48 loop {
49 match rx.recv().await {
50 Ok(item) => return Some((item, rx)),
51 Err(broadcast::error::RecvError::Lagged(_)) => continue,
52 Err(broadcast::error::RecvError::Closed) => return None,
53 }
54 }
55 });
56 Source { inner: stream.boxed() }
57 }
58
59 pub fn consumer_count(&self) -> usize {
61 self.sender.receiver_count()
62 }
63}
64
65pub struct MergeHub<T: Send + 'static> {
69 sender: mpsc::UnboundedSender<T>,
70 receiver: parking_lot::Mutex<Option<mpsc::UnboundedReceiver<T>>>,
72}
73
74impl<T: Send + 'static> Default for MergeHub<T> {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl<T: Send + 'static> MergeHub<T> {
81 pub fn new() -> Self {
82 let (tx, rx) = mpsc::unbounded_channel();
83 Self { sender: tx, receiver: parking_lot::Mutex::new(Some(rx)) }
84 }
85
86 pub fn attach(&self, source: Source<T>) {
88 let tx = self.sender.clone();
89 tokio::spawn(async move {
90 let mut s = source.into_boxed();
91 while let Some(item) = s.next().await {
92 if tx.send(item).is_err() {
93 return;
94 }
95 }
96 });
97 }
98
99 pub fn source(&self) -> Source<T> {
102 match self.receiver.lock().take() {
103 Some(rx) => Source::from_receiver(rx),
104 None => Source::empty(),
105 }
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sink::Sink;
    use std::time::Duration;

    /// Both consumers subscribe before the producer is attached, so each one
    /// receives every element.
    #[tokio::test]
    async fn broadcast_hub_fans_to_two_consumers() {
        let hub = BroadcastHub::<i32>::new(16);
        let (first, second) = (hub.consumer(), hub.consumer());

        hub.attach(Source::from_iter(vec![1, 2, 3]));

        // With the hub gone, the producer task holds the last sender, so both
        // consumer streams complete once it has forwarded everything.
        drop(hub);

        let expected = vec![1, 2, 3];
        let (got_first, got_second) =
            tokio::join!(Sink::collect(first), Sink::collect(second));
        assert_eq!(got_first, expected);
        assert_eq!(got_second, expected);
    }

    /// A consumer subscribed after the elements were sent sees none of them.
    #[tokio::test]
    async fn broadcast_hub_late_consumer_misses_earlier_elements() {
        let hub = BroadcastHub::<i32>::new(16);
        let early = hub.consumer();
        hub.attach(Source::from_iter(vec![1, 2, 3]));

        // The hub stays alive, so `early` never ends by itself: read at most
        // three elements, bounded by a timeout.
        let early_items: Vec<i32> = tokio::time::timeout(
            Duration::from_millis(200),
            early.into_boxed().take(3).collect::<Vec<i32>>(),
        )
        .await
        .unwrap_or_default();
        assert_eq!(early_items, vec![1, 2, 3]);

        // All three elements have already gone out; the timeout stands in for
        // an end-of-stream that cannot arrive while the hub is alive.
        let late = hub.consumer();
        let late_items = tokio::time::timeout(Duration::from_millis(50), Sink::collect(late))
            .await
            .unwrap_or_default();
        assert!(late_items.is_empty());
    }

    #[tokio::test]
    async fn broadcast_hub_consumer_count_grows_with_subscribers() {
        let hub = BroadcastHub::<i32>::new(4);
        assert_eq!(hub.consumer_count(), 0);
        // Keep both consumers alive while counting.
        let _subscribers = [hub.consumer(), hub.consumer()];
        assert_eq!(hub.consumer_count(), 2);
    }

    #[tokio::test]
    async fn merge_hub_aggregates_multiple_producers() {
        let hub = MergeHub::<i32>::new();
        for batch in [vec![1, 2, 3], vec![10, 20, 30]] {
            hub.attach(Source::from_iter(batch));
        }
        let merged = hub.source();
        // Release the hub's own sender so the merged stream can terminate.
        drop(hub);

        let mut got = Sink::collect(merged).await;
        got.sort_unstable();
        assert_eq!(got, vec![1, 2, 3, 10, 20, 30]);
    }

    #[tokio::test]
    async fn merge_hub_source_can_be_taken_only_once() {
        let hub = MergeHub::<i32>::new();
        hub.attach(Source::from_iter(vec![1]));
        // The first take moves the real receiver out (and drops it here).
        drop(hub.source());
        let second = hub.source();
        let items = tokio::time::timeout(Duration::from_millis(50), Sink::collect(second))
            .await
            .unwrap_or_default();
        assert!(items.is_empty());
    }
}